3 changes: 3 additions & 0 deletions src/bmi/estimators/_histogram.py
@@ -14,6 +14,7 @@
from bmi.interface import BaseModel, IMutualInformationPointEstimator
from bmi.utils import ProductSpace

import gc

class HistogramEstimatorParams(BaseModel):
n_bins_x: pydantic.PositiveInt
@@ -75,5 +76,7 @@ def estimate(self, x: ArrayLike, y: ArrayLike) -> float:
for j in range(range_y):
if p_xy[i, j] > 0:
mi += p_xy[i, j] * (np.log(p_xy[i, j]) - np.log(p_x[i]) - np.log(p_y[j]))
del x, y
gc.collect()

return mi
4 changes: 4 additions & 0 deletions src/bmi/estimators/_kde.py
@@ -10,6 +10,7 @@
from bmi.interface import BaseModel, IMutualInformationPointEstimator
from bmi.utils import ProductSpace

import gc

def _differential_entropy(estimator: KernelDensity, samples: np.ndarray) -> float:
"""Estimates the differential entropy of a distribution by fitting
@@ -133,6 +134,9 @@ def estimate_entropies(self, x: ArrayLike, y: ArrayLike) -> DifferentialEntropie

mutual_information = h_x + h_y - h_xy

del space, x, y
gc.collect()

return DifferentialEntropies(
entropy_x=h_x,
entropy_y=h_y,
4 changes: 4 additions & 0 deletions src/bmi/estimators/ksg.py
@@ -14,6 +14,7 @@

_AllowedContinuousMetric = Literal["euclidean", "manhattan", "chebyshev"]

import gc

class KSGEnsembleParameters(pydantic.BaseModel):
neighborhoods: list[int]
@@ -124,12 +125,15 @@ def fit(self, x: ArrayLike, y: ArrayLike) -> None:
# We calculate mean(digammas) over all the points rather than the batch
digammas_mean_contribution = np.sum(digammas / n_points)
digammas_dict[k].append(digammas_mean_contribution)
del x, y

for k, digammas in digammas_dict.items():
# As the mean over all the points was calculated for each chunk separately,
# we should add the contributions from each chunk
mi_estimate = _DIGAMMA(k) - np.sum(digammas) + _DIGAMMA(n_points)
self._mi_dict[k] = max(0.0, mi_estimate)

gc.collect()

self._fitted = True
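A quick numerical illustration of the chunk-wise accumulation described in the comments above; the array and the chunking below are made up for the example and are not part of the estimator.

import numpy as np

values = np.arange(10.0)            # stand-in for the per-point digamma terms
n_points = len(values)
chunks = np.array_split(values, 3)  # processed in batches, as in fit()

# Each chunk contributes sum(chunk) / n_points; summing the contributions
# reproduces the mean over all points.
mean_from_chunks = sum(np.sum(chunk / n_points) for chunk in chunks)
assert np.isclose(mean_from_chunks, values.mean())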

21 changes: 17 additions & 4 deletions src/bmi/estimators/neural/_basic_training.py
@@ -9,6 +9,7 @@
from bmi.estimators.neural._training_log import TrainingLog
from bmi.estimators.neural._types import BatchedPoints, Critic, Point

import gc

def get_batch(xs: BatchedPoints, ys: BatchedPoints, key: jax.Array, batch_size: Optional[int]):
if batch_size is not None:
@@ -37,7 +38,7 @@ def basic_training(
max_n_steps: int = 2_000,
early_stopping: bool = True,
learning_rate: float = 0.1,
verbose: bool = True,
verbose: bool = True
) -> tuple[TrainingLog, eqx.Module]:
"""Simple training loop, which samples mini-batches
from (xs, ys) and maximizes mutual information according to
@@ -72,10 +73,15 @@ def loss(critic, xs, ys):
training_log = TrainingLog(
max_n_steps=max_n_steps, early_stopping=early_stopping, verbose=verbose
)
keys = jax.random.split(rng, max_n_steps)
for n_step, key in enumerate(keys, start=1):
# We no longer use the following line here and in '_mine_estimator.py':
# keys = jax.random.split(rng, max_n_steps)
Review comment (Member):
Very nice fix! Would you mind adding more in-text context, e.g.,

# We don't use 
# keys = jax.random.split(rng, max_n_steps)
# because of memory leaks. See:
# https://github.com/jax-ml/jax/issues/17432

for future reference, so we don't forget why this should be avoided?

# Due to memory leaks outlined in:
# https://github.com/jax-ml/jax/issues/17432
key = rng
for n_step in range(1, max_n_steps+1):
# run step
batch_xs, batch_ys = get_batch(xs, ys, key, batch_size)
key, subkey = jax.random.split(key) # new
batch_xs, batch_ys = get_batch(xs, ys, subkey, batch_size)
critic, opt_state, mi_train = step(critic, opt_state, batch_xs, batch_ys)

# logging train
Expand All @@ -90,6 +96,13 @@ def loss(critic, xs, ys):
if training_log.early_stop():
break

batch_xs.delete()
batch_ys.delete()

del batch_xs, batch_ys

training_log.finish()
jax.clear_caches() # clears jit/compilation & staging caches
gc.collect() # free Python objects if you dropped all refs

return training_log, critic
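For reference, a minimal standalone sketch of the key-handling pattern adopted in this hunk, assuming a placeholder dummy_step function that is not part of the repository: one key is split off per iteration instead of materializing all keys up front, and JAX caches are cleared after training.

import gc

import jax


def dummy_step(key: jax.Array) -> jax.Array:
    # Stand-in for one training step that consumes a fresh PRNG key.
    return jax.random.normal(key, (4,)).sum()


def run(max_n_steps: int = 1_000, seed: int = 0) -> None:
    key = jax.random.PRNGKey(seed)
    for _ in range(max_n_steps):
        # Split one key per step; pre-splitting all keys with
        # jax.random.split(rng, max_n_steps) has been reported to leak memory:
        # https://github.com/jax-ml/jax/issues/17432
        key, subkey = jax.random.split(key)
        dummy_step(subkey)
    jax.clear_caches()  # drop jit/compilation and staging caches
    gc.collect()        # reclaim Python objects once references are dropped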
1 change: 1 addition & 0 deletions src/bmi/estimators/neural/_critics.py
@@ -54,6 +54,7 @@ def __init__(
self.layers.append(eqx.nn.Linear(dims[-1], 1, key=key_final))

def __call__(self, x: Point, y: Point) -> jax.Array:
# print(f"Critic - x shape {x.shape}, y shape {y.shape}")
Review comment (Member):
I think this can be removed.

z = jnp.concatenate([x, y])

for layer in self.layers[:-1]:
14 changes: 11 additions & 3 deletions src/bmi/estimators/neural/_mine_estimator.py
@@ -26,6 +26,7 @@
from bmi.interface import BaseModel, EstimateResult, IMutualInformationPointEstimator
from bmi.utils import ProductSpace

import gc

def logmeanexp(vs):
return logsumexp(vs) - jnp.log(len(vs))
@@ -211,9 +212,10 @@ def step(
training_log = TrainingLog(
max_n_steps=max_n_steps, early_stopping=early_stopping, verbose=verbose
)
keys = jax.random.split(rng, max_n_steps)
for n_step, key in enumerate(keys, start=1):
key_sample, key_test = jax.random.split(key)
key = rng
for n_step in range(1, max_n_steps+1):
key, key_step = jax.random.split(key)
key_sample, key_test = jax.random.split(key_step)

# sample
xs_batch, ys_batch_paired, ys_batch_unpaired = _sample_paired_unpaired(
@@ -245,8 +247,14 @@
# early stop?
if training_log.early_stop():
break
xs_batch.delete()
ys_batch_paired.delete()
ys_batch_unpaired.delete()
del xs_batch, ys_batch_paired, ys_batch_unpaired

training_log.finish()
jax.clear_caches() # clears jit/compilation & staging caches
Review comment (Member):
Very nice fix and explanation!

gc.collect() # free Python objects if you dropped all refs

return training_log, critic
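As a small standalone illustration of the explicit buffer-release pattern used above (the batch array below is hypothetical; in the hunk it corresponds to xs_batch and the two ys batches): Array.delete() frees the device buffer eagerly, and del drops the remaining Python reference.

import jax.numpy as jnp

batch = jnp.ones((512, 16))   # hypothetical mini-batch
total = float(batch.sum())    # finish any work that still needs the buffer
batch.delete()                # eagerly release the backing device buffer
del batch                     # drop the Python reference as well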
