-
Notifications
You must be signed in to change notification settings - Fork 7
Improve memory issues #180
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Changed key creation step inline with jax-ml/jax#17432 to fix memory leak when repeatedly calling estimators.
No changes to functionality, just tidying up previously added .DS_Store files.
Same as last commit, trying to remove remaining .DS_Store files without more appearing
Finally seem to have removed all of the .DS_Store files
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi,
thanks a lot for the improvements! I like them a lot 🙂
I only left some very minor comments, corresponding to some code clarity improvements. I think we can merge it as soon as they are addressed and the PR passes the GitHub CI checks 🙂
(I think that running pre-commit as described here will resolve automatically all (or at least most) of the CI's complaints.)
Thank you once again!
src/bmi/estimators/_histogram.py
Outdated
| from bmi.interface import BaseModel, IMutualInformationPointEstimator | ||
| from bmi.utils import ProductSpace | ||
|
|
||
| import gc # new |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The comment # new will become redundant in a few months, so I'd suggest removing it from this PR.
| ) | ||
| keys = jax.random.split(rng, max_n_steps) | ||
| for n_step, key in enumerate(keys, start=1): | ||
| # keys = jax.random.split(rng, max_n_steps) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very nice fix! Would you mind adding more in-text context, e.g.,
# We don't use
# keys = jax.random.split(rng, max_n_steps)
# because of memory leaks. See:
# https://github.com/jax-ml/jax/issues/17432for future reference, so we don't forget the reason why to avoid this in the future?
| self.layers.append(eqx.nn.Linear(dims[-1], 1, key=key_final)) | ||
|
|
||
| def __call__(self, x: Point, y: Point) -> jax.Array: | ||
| # print(f"Critic - x shape {x.shape}, y shape {y.shape}") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this can be removed.
| from bmi.interface import BaseModel, EstimateResult, IMutualInformationPointEstimator | ||
| from bmi.utils import ProductSpace | ||
|
|
||
| import gc # new |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The comment can be removed here as well.
| xs_batch.delete() | ||
| ys_batch_paired.delete() | ||
| ys_batch_unpaired.delete() | ||
| # xs_test.delete() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can # xs_test.detele() be removed as well?
| ys_batch_unpaired.delete() | ||
| # xs_test.delete() | ||
| # ys_test_unpaired.delete() | ||
| del xs_batch, ys_batch_paired, ys_batch_unpaired #, xs_test, ys_test_unpaired |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(Similarly here: the commented out variables should be removed).
| del xs_batch, ys_batch_paired, ys_batch_unpaired #, xs_test, ys_test_unpaired | ||
|
|
||
| training_log.finish() | ||
| jax.clear_caches() # clears jit/compilation & staging caches |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very nice fix and explanation!
Removed '# new' from 'gc' imports and previous changes. Updated information regarding memory leaks in neural network basic training file and mentioned its change in MINE estimation. Removed deprecated for loops.
Changed key creation step in line with jax-ml/jax#17432 to fix memory leak when repeatedly calling estimators i.e.
bmi.estimators.HistogramEstimator(). The updates are only to the estimator computation functions where the keys, x and y are directly used i.e. the functionestimate()within classHistogramEstimatorParams(BaseModel).Files for issue replication:
bmi_estimate_mi.py
memory_test.py