 from collections import defaultdict
+from typing import cast

 import torch
 import torch.distributed as dist
 from mmengine.dist import get_world_size
-from pydantic import BaseModel, ConfigDict, model_validator
+from pydantic import BaseModel, ConfigDict, TypeAdapter, model_validator
 from torch import nn
 from torch.utils.hooks import RemovableHandle
 from typing_extensions import TypedDict
@@ -44,6 +45,9 @@ class InternalMetrics(TypedDict, total=False):
     attn_max_logits: dict[str, float]


+internal_metrics_adapter = TypeAdapter(InternalMetrics)
+
+
 class InternalMetricsConfig(BaseModel):
     model_config = ConfigDict(extra="forbid")
     internal_metrics_interval: int | None = None
@@ -71,30 +75,38 @@ class InternalMetricsRecorder:
     def __init__(self, internal_metrics_cfg: InternalMetricsConfig, model: XTunerBaseModel):
         self.internal_metrics_cfg = internal_metrics_cfg
         self.model = model
+
         self.hooks: list[RemovableHandle] = []
+
         self._attn_monitor_type: str | None = None
         self.attn_max_lse: dict[str, torch.Tensor] = {}
         self.attn_max_logits: dict[str, torch.Tensor] = {}
+
         self.metrics = self._init_metrics_dict()
         self._closed = False

     def _init_metrics_dict(self) -> InternalMetrics:
         metrics: InternalMetrics = {}
+
         if self.internal_metrics_cfg.monitor_weights_rms_norm:
             metrics["weight_rms"] = {}
+
         if self.internal_metrics_cfg.monitor_attn_logits_stats:
             attn_cfg: MHAConfig | MLAConfig = self.model.config.attention  # type: ignore[attr-defined]
+
             if isinstance(attn_cfg, MLAConfig):
                 attn_impl = "flash_attention"
             else:
                 attn_impl = attn_cfg.attn_impl
+
             if attn_impl == "eager_attention":
                 # We typically won't use eager attn, but implement it here anyway
                 self._attn_monitor_type = "attn_logits"
                 metrics["attn_max_logits"] = {}
             elif not (DEVICE == "npu" and attn_impl == "flash_attention"):
                 self._attn_monitor_type = "softmax_lse"
                 metrics["attn_max_lse"] = {}
+
         for module in self.model.modules():
             if isinstance(module, ATTENTION_CLS):
                 if self._attn_monitor_type == "attn_logits":
@@ -117,20 +129,28 @@ def _init_metrics_dict(self) -> InternalMetrics:
     def calculate_module_weight_rms(self, module: nn.Module, layer_name: str, dtype: torch.dtype = torch.float32):
         """Calculate the RMS of the module's parameters."""
         self._check_closed()
+
         if "weight_rms" not in self.metrics:
             return
+
         all_params = [param.data for param in module.parameters() if param.requires_grad]
+
         if not all_params:
             return
+
         grouped_params = group_tensors_by_device_mesh_and_placements(all_params)  # type: ignore[arg-type]
+
         total_norms = []
         total_numel = 0
+
         for params in grouped_params.values():
             total_norm = cal_total_norm(params, norm_type=2.0, foreach=True, dtype=dtype)
             total_norms.append(total_norm)
             total_numel += sum(p.numel() for p in params)
+
         param_l2_norm = torch.linalg.vector_norm(torch.stack(total_norms), ord=2.0, dtype=dtype)
         param_rms = param_l2_norm / total_numel**0.5
+
         self.metrics["weight_rms"][layer_name] = param_rms.item()

     def register_attn_output_hook(self, module: nn.Module):
@@ -140,6 +160,7 @@ def register_attn_output_hook(self, module: nn.Module):
         def hook(module, input, output):
             if output.get("softmax_lse") is not None:
                 self.attn_max_lse[module.name] = torch.max(self.attn_max_lse[module.name], output["softmax_lse"].max())
+
             if output.get("attn_logits") is not None:
                 self.attn_max_logits[module.name] = max(self.attn_max_logits[module.name], output["attn_logits"].max())

@@ -150,9 +171,11 @@ def hook(module, input, output):
     def pop_metrics(self, data_batches: list[ModelItem]):
         """Run a dummy forward to get metrics."""
         self._check_closed()
+
         for name, module in self.model.named_modules():
             if self.internal_metrics_cfg.monitor_attn_logits_stats and isinstance(module, ATTENTION_CLS):
                 self.register_attn_output_hook(module)
+
             if self.internal_metrics_cfg.monitor_weights_rms_norm and isinstance(module, RMS_NORM_MONITOR_MODULES):
                 self.calculate_module_weight_rms(module, self._clean_module_name(name), dtype=torch.float32)

@@ -172,6 +195,7 @@ def pop_metrics(self, data_batches: list[ModelItem]):
             data_batch = data_batches[i]
             seq_ctx = data_batch["seq_ctx"]
             output = self.model(seq_ctx=seq_ctx, loss_ctx=None, **additional_kwargs)
+
             if (
                 self.internal_metrics_cfg.monitor_moe_load_balance_stats
                 and (cur_tokens_per_expert := output.get("tokens_per_expert_global")) is not None
@@ -195,15 +219,19 @@ def pop_metrics(self, data_batches: list[ModelItem]):
         if tokens_per_expert_global is not None:
             avg_count_load = tokens_per_expert_global.mean(1)
             max_load_i = torch.amax(tokens_per_expert_global, dim=1)
+
             maxvio_all_layers = (max_load_i - avg_count_load) / avg_count_load
+
             centered_tokens_per_expert = tokens_per_expert_global - avg_count_load[:, None]
             drop_ratio_all_layers = (centered_tokens_per_expert).abs().mean(dim=1) / avg_count_load
+
             if "drop_ratio" in self.metrics:
                 self.metrics["drop_ratio"].update(
                     {f"layer{idx}": drop_ratio_all_layers[idx].item() for idx in range(drop_ratio_all_layers.shape[0])}
                 )
                 drop_ratio = drop_ratio_all_layers.mean()
                 self.metrics["drop_ratio"]["total"] = drop_ratio.item()
+
             if "maxvio" in self.metrics:
                 self.metrics["maxvio"].update(
                     {f"layer{idx}": maxvio_all_layers[idx].item() for idx in range(max_load_i.shape[0])}
@@ -216,23 +244,27 @@ def pop_metrics(self, data_batches: list[ModelItem]):
                 # [bsz/intra_layer_micro_batch, ]
                 local_router_logits_max = torch.max(torch.stack(router_logits_list))
                 dist.all_reduce(local_router_logits_max, op=dist.ReduceOp.MAX)
+
                 self.metrics["router_logits_max"][layer_name] = local_router_logits_max.item()

         if "router_logits_mean" in self.metrics and router_logits_mean:
             for layer_name, router_logits_list in router_logits_mean.items():
                 # [bsz/intra_layer_micro_batch, ]
                 local_router_logits_mean = torch.mean(torch.stack(router_logits_list))
                 dist.all_reduce(local_router_logits_mean.div_(get_world_size()), op=dist.ReduceOp.SUM)
+
                 self.metrics["router_logits_mean"][layer_name] = local_router_logits_mean.item()

         if "attn_max_lse" in self.metrics and self._attn_monitor_type == "softmax_lse":
             for layer_name, local_attn_max_lse in self.attn_max_lse.items():
                 dist.all_reduce(local_attn_max_lse, op=dist.ReduceOp.MAX)
+
                 self.metrics["attn_max_lse"][layer_name] = local_attn_max_lse.item()

         if "attn_max_logits" in self.metrics and self._attn_monitor_type == "attn_logits":
             for layer_name, local_attn_max_logits in self.attn_max_logits.items():
                 dist.all_reduce(local_attn_max_logits, op=dist.ReduceOp.MAX)
+
                 self.metrics["attn_max_logits"][layer_name] = local_attn_max_logits.item()

         self._maybe_reset_attn_max_lse_or_logits(self.attn_max_lse)
@@ -241,7 +273,7 @@ def pop_metrics(self, data_batches: list[ModelItem]):
         for hook in self.hooks:
             hook.remove()

-        return self.metrics
+        return internal_metrics_adapter.validate_python(self.metrics)

     def close(self):
         if not self._closed:
@@ -266,6 +298,7 @@ def __del__(self):
     def _maybe_reset_attn_max_lse_or_logits(self, target: dict[str, torch.Tensor]):
         if not target:
             return
+
         for v in target.values():
             if isinstance(v, torch.Tensor):
                 v.fill_(SMALL_VAL)
@@ -293,16 +326,12 @@ def need_dummy_forward(self) -> bool:


 def flatten_internal_metrics_for_logs(metrics: InternalMetrics, sep: str = "/") -> dict:
+    internal_metrics_adapter.validate_python(metrics)
+
     items = []
     for name, sub_metrics in metrics.items():
-        if isinstance(sub_metrics, dict):
-            for k, v in sub_metrics.items():
-                if isinstance(v, (float, int)):
-                    items.append((f"{name}{sep}{k}", v))
-                else:
-                    raise ValueError(f"Unsupported metric value type: expected float or int, but got {type(v)}")
-        else:
-            raise ValueError(
-                f"Unsupported metric type for internal metrics: expected dict, but got {type(sub_metrics)}"
-            )
+        sub_metrics = cast(dict, sub_metrics)
+        for k, v in sub_metrics.items():
+            items.append((f"{name}{sep}{k}", v))
+
     return dict(items)
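
A quick way to exercise the new internal_metrics_adapter in isolation: TypeAdapter(InternalMetrics).validate_python accepts a well-formed metrics dict and raises a ValidationError when a leaf value is not numeric, which is what lets the rewritten flatten_internal_metrics_for_logs drop its hand-written isinstance checks. The sketch below is illustrative only and assumes pydantic v2; DemoMetrics is a hypothetical stand-in for InternalMetrics so the snippet stays self-contained.

from pydantic import TypeAdapter, ValidationError
from typing_extensions import TypedDict


class DemoMetrics(TypedDict, total=False):
    # Hypothetical stand-in for InternalMetrics: every field is a dict of floats.
    weight_rms: dict[str, float]
    maxvio: dict[str, float]


demo_adapter = TypeAdapter(DemoMetrics)

# Well-formed metrics pass through unchanged.
metrics = demo_adapter.validate_python({"weight_rms": {"layer0": 0.02, "total": 0.05}})

# A non-numeric leaf is rejected before it can reach the logger.
try:
    demo_adapter.validate_python({"weight_rms": {"layer0": "oops"}})
except ValidationError as exc:
    print(exc.error_count(), "validation error(s)")

# Flattening then mirrors flatten_internal_metrics_for_logs: "<name><sep><key>" -> value.
flat = {f"{name}/{k}": v for name, sub in metrics.items() for k, v in sub.items()}
print(flat)  # {'weight_rms/layer0': 0.02, 'weight_rms/total': 0.05}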
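
The weight-RMS path in calculate_module_weight_rms boils down to an L2 norm over all trainable parameters divided by the square root of their total element count. A minimal single-device sketch of that arithmetic, assuming plain local tensors rather than the DTensor grouping performed by group_tensors_by_device_mesh_and_placements and cal_total_norm:

import torch
from torch import nn

module = nn.Linear(16, 16)

# One L2 norm per parameter (the real code computes one norm per device-mesh group).
params = [p.data for p in module.parameters() if p.requires_grad]
per_param_norms = [torch.linalg.vector_norm(p.float(), ord=2.0) for p in params]
total_numel = sum(p.numel() for p in params)

# ||W||_2 over all parameters, then RMS = ||W||_2 / sqrt(numel).
param_l2_norm = torch.linalg.vector_norm(torch.stack(per_param_norms), ord=2.0)
param_rms = param_l2_norm / total_numel**0.5

# Cross-check against the direct definition sqrt(mean(w^2)).
flat = torch.cat([p.reshape(-1).float() for p in params])
assert torch.allclose(param_rms, flat.pow(2).mean().sqrt(), atol=1e-6)
print(param_rms.item())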