@@ -47,7 +47,7 @@ def __init__(
         scheduler=None,
         weighting=None,
         use_lt=False,
-        N_epochs_with_same_weights=10,
+        reset_weighting_at_epoch_start=True,
     ):
         """
         Initialization of the :class:`AutoregressiveSolver` class.
@@ -69,8 +69,11 @@ def __init__(
             If ``None``, uniform weighting is used. Default is ``None``.
         :param bool use_lt: Whether to use LabelTensors.
             Default is ``False``.
-        :param int N_epochs_with_same_weights: Number of epochs to keep the same adaptive weights
-            before recomputing them. Default is ``10``.
+        :param bool reset_weighting_at_epoch_start: If ``True``, reset the
+            running averages used for adaptive weighting at the start of
+            each epoch. Default is ``True``. This parameter targets an
+            advanced use case: setting it to ``False`` can improve
+            stability, especially when the data per epoch are very scarce.
         """

         super().__init__(
@@ -82,11 +85,9 @@ def __init__(
             weighting=weighting,
             use_lt=use_lt,
         )
-        # cache for per-condition adaptive weights and epoch-based update control
-        # this is the most generic way to implement periodic weight updates I found
-        self._cached_weights = {}
-        self._epochs_since_update = 0
-        self.N_epochs_with_same_weights = N_epochs_with_same_weights
+        self._running_avg_step_losses = {}
+        self._running_step_counts = {}
+        self.reset_weighting_at_epoch_start = reset_weighting_at_epoch_start

     @staticmethod
     def unroll(
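Taken together, the hunks above replace the epoch-counted weight cache with per-condition running averages controlled by a single flag. A hypothetical usage sketch follows; the `problem` and `model` arguments are illustrative assumptions, since the full constructor signature is not shown in this diff:

# Hypothetical instantiation sketch; `problem` and `model` are assumed
# placeholders, not part of this diff. Setting the flag to False keeps
# the running loss averages across epochs (per the new docstring, an
# advanced option for when each epoch sees very little data).
solver = AutoregressiveSolver(
    problem=problem,
    model=model,
    use_lt=False,
    reset_weighting_at_epoch_start=False,
)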
@@ -165,7 +166,9 @@ def decide_starting_indices(

         return indices

-    def loss_data(self, unroll, eps=None, aggregation_strategy=None, condition_name=None):
+    def loss_data(
+        self, unroll, eps=None, aggregation_strategy=None, condition_name=None
+    ):
         """
         Compute the autoregressive multi-step data loss.

@@ -197,32 +200,68 @@ def loss_data(self, unroll, eps=None, aggregation_strategy=None, condition_name=
             step_loss = self._loss_fn(predicted_state, target_state)
             losses.append(step_loss)

-            if logger.isEnabledFor(logging.DEBUG) and (step <= 3 or torch.isnan(step_loss)):
+            if step <= 3 or torch.isnan(step_loss):
                 logger.debug(
                     " Step %d: loss=%.4e, pred=[%.3f, %.3f]",
                     step,
                     float(step_loss.item()),
-                    float(predicted_state.min()),
-                    float(predicted_state.max()),
+                    float(predicted_state.detach().min()),
+                    float(predicted_state.detach().max()),
                 )
-
+
             current_state = predicted_state

         step_losses = torch.stack(losses)  # [unroll_length]

         with torch.no_grad():
             condition_name = condition_name or "default"
             weights = self.get_weights(condition_name, step_losses, eps)
-            if logger.isEnabledFor(logging.DEBUG):
-                logger.debug(" Losses: %s", step_losses.detach().cpu().numpy().round(4))
-                logger.debug(" Weights: %s", weights.cpu().numpy().round(4))
-                logger.debug(" Weight ratio: %.1f", float(weights.max() / weights.min()))
+
+            logger.debug(
+                " Losses: %s", step_losses.detach().cpu().numpy().round(4)
+            )
+            logger.debug(" Weights: %s", weights.cpu().numpy().round(4))
+            logger.debug(
+                " Weight ratio: %.1f", float(weights.max() / weights.min())
+            )

         if aggregation_strategy is None:
             aggregation_strategy = torch.sum

         return aggregation_strategy(step_losses * weights)

+    def get_weights(self, condition_name, step_losses, eps):
+        """
+        Update the condition's running loss average and return adaptive weights.
+        :param str condition_name: Name of the condition.
+        :param torch.Tensor step_losses: 1D tensor of per-step losses.
+        :param float eps: Weighting parameter.
+        :return: Weights tensor.
+        :rtype: torch.Tensor
+        """
+        key = condition_name or "default"
+        x = step_losses.detach()
+
+        if x.dim() != 1:
+            raise ValueError(
+                f"step_losses must be a 1D tensor, got shape {x.shape}"
+            )
+
+        if key not in self._running_avg_step_losses:
+            self._running_avg_step_losses[key] = x.clone()
+            self._running_step_counts[key] = 1
+        else:
+            self._running_step_counts[key] += 1
+            k = self._running_step_counts[key]
+            # incremental update of the running average
+            self._running_avg_step_losses[key] += (
+                x - self._running_avg_step_losses[key]
+            ) / k
+
+        return self._compute_adaptive_weights(
+            self._running_avg_step_losses[key], eps
+        )
+
     def _compute_adaptive_weights(self, step_losses, eps):
         """
         Actual computation of adaptive weights.
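The `else` branch of the new `get_weights` is the standard incremental-mean update. A self-contained sanity check, independent of the solver, showing that it reproduces the arithmetic mean of everything seen so far:

import torch

# Standalone check of the incremental update avg += (x - avg) / k:
# after k batches it equals the plain mean over all k tensors.
batches = [torch.tensor([1.0, 4.0]), torch.tensor([3.0, 0.0]), torch.tensor([5.0, 2.0])]
avg, k = batches[0].clone(), 1
for x in batches[1:]:
    k += 1
    avg += (x - avg) / k
assert torch.allclose(avg, torch.stack(batches).mean(dim=0))  # tensor([3., 2.])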
@@ -231,38 +270,25 @@ def _compute_adaptive_weights(self, step_losses, eps):
         :return: Computed weights tensor.
         :rtype: torch.Tensor
         """
-        print(f"updating weights, eps={eps}")
+        logger.debug("updating weights, eps=%s", eps)

         if eps is None:
             return torch.ones_like(step_losses) / step_losses.numel()

+        # normalize to mean 1 (avoids overly large exponents)
+        step_losses = step_losses / (step_losses.mean() + 1e-12)
+
         log_w = torch.clamp(-eps * torch.cumsum(step_losses, dim=0), -20, 20)
         return torch.softmax(log_w, dim=0)

-    def get_weights(self, condition_name, step_losses, eps):
-        """
-        Return cached weights or compute new ones.
-        :param str condition_name: Name of the condition.
-        :param torch.Tensor step_losses: 1D tensor of per-step losses.
-        :param float eps: Weighting parameter.
-        :return: Weights tensor.
-        :rtype: torch.Tensor
-        """
-        cached = self._cached_weights.get(condition_name, None)
-        if cached is None:
-            cached = self._compute_adaptive_weights(step_losses, eps).cpu()
-            self._cached_weights[condition_name] = cached
-        return cached.to(step_losses.device)
-
-    def on_train_epoch_end(self):
+    def on_train_epoch_start(self):
         """
-        Hook called by Lightning at the end of each epoch.
-        Forces periodic recalculation of weights by clearing the cache.
+        Hook called by Lightning at the beginning of each epoch.
+        Clears the dictionaries used for the weighting estimate, if enabled.
         """
-        self._epochs_since_update += 1
-        if self._epochs_since_update >= self.N_epochs_with_same_weights:
-            self._cached_weights.clear()
-            self._epochs_since_update = 0
+        if self.reset_weighting_at_epoch_start:
+            self._running_avg_step_losses.clear()
+            self._running_step_counts.clear()

     def predict(self, initial_state, num_steps):
         """