
Commit 370c217

add end-to-end-test and improve mean
1 parent 2eadb5a commit 370c217

File tree

2 files changed: +218 -62 lines


pina/solver/autoregressive_solver/autoregressive_solver.py

Lines changed: 66 additions & 40 deletions
```diff
@@ -47,7 +47,7 @@ def __init__(
         scheduler=None,
         weighting=None,
         use_lt=False,
-        N_epochs_with_same_weights=10,
+        reset_weighting_at_epoch_start=True,
     ):
         """
         Initialization of the :class:`AutoregressiveSolver` class.
```
```diff
@@ -69,8 +69,11 @@ def __init__(
             If ``None``, uniform weighting is used. Default is ``None``.
         :param bool use_lt: Whether to use LabelTensors.
             Default is ``False``.
-        :param int N_epochs_with_same_weights: Number of epochs to keep the same adaptive weights
-            before recomputing them. Default is ``10``.
+        :param bool reset_weighting_at_epoch_start: If ``True``, resets
+            the running averages used for adaptive weighting at the start
+            of each epoch. Default is ``True``. This targets an advanced
+            use case: setting it to ``False`` can improve stability,
+            especially when data per epoch are very scarce.
         """

         super().__init__(
```
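The new flag only changes *when* the running loss averages are cleared. For intuition, here is a toy, self-contained sketch of the trade-off; the loop below merely mimics the solver's bookkeeping with one batch per epoch and is not PINA API:

```python
import torch

# One batch of per-step losses per epoch; losses grow with the epoch index.
def final_average(reset_at_epoch_start, epochs=2):
    avg, count = None, 0
    for epoch in range(epochs):
        if reset_at_epoch_start:  # what on_train_epoch_start does when True
            avg, count = None, 0
        batch = torch.tensor([1.0, 4.0]) * (epoch + 1)
        count += 1
        avg = batch.clone() if avg is None else avg + (batch - avg) / count
    return avg

print(final_average(True))   # tensor([2., 8.]): only the last epoch counts
print(final_average(False))  # tensor([1.5, 6.]): averaged across epochs
```

With very few batches per epoch, ``True`` re-estimates the averages from a tiny sample each time, which is why the docstring suggests ``False`` when data per epoch are scarce.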
```diff
@@ -82,11 +85,9 @@ def __init__(
             weighting=weighting,
             use_lt=use_lt,
         )
-        # cache for per-condition adaptive weights and epoch-based update control
-        # this is the most generic way to implement periodic weight updates I found
-        self._cached_weights = {}
-        self._epochs_since_update = 0
-        self.N_epochs_with_same_weights = N_epochs_with_same_weights
+        self._running_avg_step_losses = {}
+        self._running_step_counts = {}
+        self.reset_weighting_at_epoch_start = reset_weighting_at_epoch_start

     @staticmethod
     def unroll(
```
```diff
@@ -165,7 +166,9 @@ def decide_starting_indices(

         return indices

-    def loss_data(self, unroll, eps=None, aggregation_strategy=None, condition_name=None):
+    def loss_data(
+        self, unroll, eps=None, aggregation_strategy=None, condition_name=None
+    ):
         """
         Compute the autoregressive multi-step data loss.
```
```diff
@@ -197,32 +200,68 @@ def loss_data(self, unroll, eps=None, aggregation_strategy=None, condition_name=
             step_loss = self._loss_fn(predicted_state, target_state)
             losses.append(step_loss)

-            if logger.isEnabledFor(logging.DEBUG) and (step <= 3 or torch.isnan(step_loss)):
+            if step <= 3 or torch.isnan(step_loss):
                 logger.debug(
                     " Step %d: loss=%.4e, pred=[%.3f, %.3f]",
                     step,
                     float(step_loss.item()),
-                    float(predicted_state.min()),
-                    float(predicted_state.max()),
+                    float(predicted_state.detach().min()),
+                    float(predicted_state.detach().max()),
                 )
-
+
             current_state = predicted_state

         step_losses = torch.stack(losses)  # [unroll_length]

         with torch.no_grad():
             condition_name = condition_name or "default"
             weights = self.get_weights(condition_name, step_losses, eps)
-            if logger.isEnabledFor(logging.DEBUG):
-                logger.debug(" Losses: %s", step_losses.detach().cpu().numpy().round(4))
-                logger.debug(" Weights: %s", weights.cpu().numpy().round(4))
-                logger.debug(" Weight ratio: %.1f", float(weights.max() / weights.min()))
+
+            logger.debug(
+                " Losses: %s", step_losses.detach().cpu().numpy().round(4)
+            )
+            logger.debug(" Weights: %s", weights.cpu().numpy().round(4))
+            logger.debug(
+                " Weight ratio: %.1f", float(weights.max() / weights.min())
+            )

         if aggregation_strategy is None:
             aggregation_strategy = torch.sum

         return aggregation_strategy(step_losses * weights)

+    def get_weights(self, condition_name, step_losses, eps):
+        """
+        Update the running average of per-step losses and compute weights.
+        :param str condition_name: Name of the condition.
+        :param torch.Tensor step_losses: 1D tensor of per-step losses.
+        :param float eps: Weighting parameter.
+        :return: Weights tensor.
+        :rtype: torch.Tensor
+        """
+        key = condition_name or "default"
+        x = step_losses.detach()
+
+        if x.dim() != 1:
+            raise ValueError(
+                f"step_losses must be a 1D tensor, got shape {x.shape}"
+            )
+
+        if key not in self._running_avg_step_losses:
+            self._running_avg_step_losses[key] = x.clone()
+            self._running_step_counts[key] = 1
+        else:
+            self._running_step_counts[key] += 1
+            k = self._running_step_counts[key]
+            # incremental update of the running average
+            self._running_avg_step_losses[key] += (
+                x - self._running_avg_step_losses[key]
+            ) / k
+
+        return self._compute_adaptive_weights(
+            self._running_avg_step_losses[key], eps
+        )
+
     def _compute_adaptive_weights(self, step_losses, eps):
         """
         Actual computation of adaptive weights.
```
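The replacement for the old weight cache is an online mean: each call folds the new per-step losses into a per-condition average. A minimal standalone check (re-implemented here for illustration, not imported from PINA) that the incremental rule reproduces the plain mean without storing any history:

```python
import torch

batches = [
    torch.tensor([1.0, 2.0]),
    torch.tensor([3.0, 6.0]),
    torch.tensor([5.0, 1.0]),
]

avg, k = None, 0
for x in batches:
    k += 1
    # same update as in get_weights: avg_k = avg_{k-1} + (x - avg_{k-1}) / k
    avg = x.clone() if avg is None else avg + (x - avg) / k

assert torch.allclose(avg, torch.stack(batches).mean(dim=0))
print(avg)  # tensor([3., 3.])
```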
```diff
@@ -231,38 +270,25 @@ def _compute_adaptive_weights(self, step_losses, eps):
         :return: Computed weights tensor.
         :rtype: torch.Tensor
         """
-        print(f"updating weights, eps={eps}")
+        logger.debug(f"updating weights, eps={eps}")

         if eps is None:
             return torch.ones_like(step_losses) / step_losses.numel()

+        # normalize to mean 1 (avoid too large exponents)
+        step_losses = step_losses / (step_losses.mean() + 1e-12)
+
         log_w = torch.clamp(-eps * torch.cumsum(step_losses, dim=0), -20, 20)
         return torch.softmax(log_w, dim=0)

-    def get_weights(self, condition_name, step_losses, eps):
-        """
-        Return cached weights or compute new ones.
-        :param str condition_name: Name of the condition.
-        :param torch.Tensor step_losses: 1D tensor of per-step losses.
-        :param float eps: Weighting parameter.
-        :return: Weights tensor.
-        :rtype: torch.Tensor
-        """
-        cached = self._cached_weights.get(condition_name, None)
-        if cached is None:
-            cached = self._compute_adaptive_weights(step_losses, eps).cpu()
-            self._cached_weights[condition_name] = cached
-        return cached.to(step_losses.device)
-
-    def on_train_epoch_end(self):
+    def on_train_epoch_start(self):
         """
-        Hook called by Lightning at the end of each epoch.
-        Forces periodic recalculation of weights by clearing the cache.
+        Hook called by Lightning at the beginning of each epoch.
+        Clears the dictionaries used for the weighting estimate.
         """
-        self._epochs_since_update += 1
-        if self._epochs_since_update >= self.N_epochs_with_same_weights:
-            self._cached_weights.clear()
-            self._epochs_since_update = 0
+        if self.reset_weighting_at_epoch_start:
+            self._running_avg_step_losses.clear()
+            self._running_step_counts.clear()

     def predict(self, initial_state, num_steps):
         """
```
