
Conversation

@RGoldsack

Changed the key creation step, in line with jax-ml/jax#17432, to fix a memory leak when repeatedly calling estimators, e.g. bmi.estimators.HistogramEstimator(). The updates are limited to the estimator computation functions where the keys, x, and y are used directly, e.g. the estimate() function within the class HistogramEstimatorParams(BaseModel).
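
For illustration, a minimal sketch of the key-handling pattern in question (the sketch is mine, not necessarily the exact change made in this PR; max_n_steps and the loop body are placeholders): pre-splitting every key up front with jax.random.split(rng, max_n_steps) is the pattern being avoided, and splitting a single subkey per iteration is one common way to keep only one key alive at a time.

import jax

rng = jax.random.PRNGKey(0)
max_n_steps = 100  # placeholder value for this sketch

# Pattern avoided because of the memory leak (see jax-ml/jax#17432):
#   keys = jax.random.split(rng, max_n_steps)
#   for n_step, key in enumerate(keys, start=1):
#       ...

# Splitting one subkey per iteration instead:
for n_step in range(1, max_n_steps + 1):
    rng, key = jax.random.split(rng)
    ...  # use `key` for this step only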

Files for issue replication:
bmi_estimate_mi.py
memory_test.py

Changed key creation step in line with jax-ml/jax#17432 to fix memory leak when repeatedly calling estimators.
No changes to functionality, just tidying up previously added .DS_Store files.
Same as last commit, trying to remove remaining .DS_Store files without more appearing
Finally seem to have removed all of the .DS_Store files
@a-marx a-marx requested a review from pawel-czyz September 16, 2025 09:27

@pawel-czyz pawel-czyz left a comment


Hi,

thanks a lot for the improvements! I like them a lot 🙂
I only left some very minor comments about code clarity. I think we can merge as soon as they are addressed and the PR passes the GitHub CI checks 🙂
(I think that running pre-commit as described here will automatically resolve all, or at least most, of the CI's complaints.)

Thank you once again!

from bmi.interface import BaseModel, IMutualInformationPointEstimator
from bmi.utils import ProductSpace

import gc # new

pawel-czyz (Member):
The comment # new will become redundant in a few months, so I'd suggest removing it from this PR.

)
keys = jax.random.split(rng, max_n_steps)
for n_step, key in enumerate(keys, start=1):
# keys = jax.random.split(rng, max_n_steps)

pawel-czyz (Member):
Very nice fix! Would you mind adding more in-text context, e.g.,

# We don't use 
# keys = jax.random.split(rng, max_n_steps)
# because of memory leaks. See:
# https://github.com/jax-ml/jax/issues/17432

for future reference, so we don't forget why this should be avoided?

self.layers.append(eqx.nn.Linear(dims[-1], 1, key=key_final))

def __call__(self, x: Point, y: Point) -> jax.Array:
# print(f"Critic - x shape {x.shape}, y shape {y.shape}")

pawel-czyz (Member):
I think this can be removed.

from bmi.interface import BaseModel, EstimateResult, IMutualInformationPointEstimator
from bmi.utils import ProductSpace

import gc # new

pawel-czyz (Member):
The comment can be removed here as well.

xs_batch.delete()
ys_batch_paired.delete()
ys_batch_unpaired.delete()
# xs_test.delete()

pawel-czyz (Member):
Can # xs_test.delete() be removed as well?

ys_batch_unpaired.delete()
# xs_test.delete()
# ys_test_unpaired.delete()
del xs_batch, ys_batch_paired, ys_batch_unpaired #, xs_test, ys_test_unpaired

pawel-czyz (Member):
(Similarly here: the commented-out variables should be removed.)
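
For context on what these calls do (an illustrative sketch of mine, not code from the PR): jax.Array.delete() eagerly frees the device buffer backing an array, and del then removes the Python reference so the deleted array cannot be reused by accident.

import jax.numpy as jnp

x = jnp.ones((1024, 1024))
x.delete()  # eagerly frees the device buffer backing x
del x       # drop the Python reference as well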

del xs_batch, ys_batch_paired, ys_batch_unpaired #, xs_test, ys_test_unpaired

training_log.finish()
jax.clear_caches() # clears jit/compilation & staging caches

pawel-czyz (Member):
Very nice fix and explanation!
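
As a usage illustration (the sketch is mine; the loop, the gc.collect() call, and its placement are assumptions rather than code from this PR), the replication scenario of repeatedly constructing estimators can combine cache clearing with garbage collection between runs:

import gc
import jax

for _ in range(10):
    # Construct and run an estimator here, e.g. bmi.estimators.HistogramEstimator()
    # (see bmi_estimate_mi.py / memory_test.py attached above for the actual scripts).
    jax.clear_caches()  # clears jit/compilation & staging caches
    gc.collect()        # reclaim remaining Python-side garbage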

pawel-czyz and others added 2 commits September 16, 2025 12:50
Removed '# new' from the 'gc' imports and the previously added changes. Updated the information about memory leaks in the neural-network basic training file and noted the corresponding change in the MINE estimator. Removed deprecated for loops.