Commit e368d87

nil0x9 authored and HAOCHENYE committed
[Refactor] refactor internal_metrics and ut code
1 parent c52127d commit e368d87

2 files changed: +70, -34 lines

tests/utils/test_internal_metrics.py

Lines changed: 29 additions & 22 deletions
@@ -1,4 +1,5 @@
 import os
+from sys import intern
 from typing import cast
 import torch
 from torch import nn
@@ -17,7 +18,7 @@
 from xtuner.v1.module.grouped_linear.moe_group_linear import GroupedLinear
 from xtuner.v1.float8.float8_gmm_tile_wise import TileWiseFloat8GroupedLinear
 from xtuner.v1.utils import internal_metrics
-from xtuner.v1.utils.internal_metrics import InternalMetricsConfig, InternalMetricsRecorder
+from xtuner.v1.utils.internal_metrics import InternalMetricsConfig, InternalMetricsRecorder, internal_metrics_adapter
 from xtuner.v1.utils.device import get_device
 
 
@@ -98,38 +99,44 @@ def test_internal_metrics_run(self):
         metrics = metrics_recorder.pop_metrics(data_batches)
 
         # Check that all expected top-level keys exist
-        assert "weight_rms" in metrics
-        assert "router_logits_max" in metrics
-        assert "router_logits_mean" in metrics
-        assert "maxvio" in metrics
-        assert "drop_ratio" in metrics
+        self.assertIn("weight_rms", metrics, "Expected `weight_rms` in metrics")
+        self.assertIn("router_logits_max", metrics, "Expected `router_logits_max` in metrics")
+        self.assertIn("router_logits_mean", metrics, "Expected `router_logits_mean` in metrics")
+        self.assertIn("maxvio", metrics, "Expected `maxvio` in metrics")
+        self.assertIn("drop_ratio", metrics, "Expected `drop_ratio` in metrics")
+
+        internal_metrics_adapter.validate_python(metrics)
 
         if DEVICE != "npu":
-            assert "attn_max_lse" in metrics or "attn_max_logits" in metrics
+            self.assertTrue(
+                any(key in metrics for key in ["attn_max_lse", "attn_max_logits"]),
+                "Expected either `attn_max_lse` or `attn_max_logits` in metrics"
+            )
 
         # Check that all values are valid floats (not NaN or Inf)
         for metric_name, metric_dict in metrics.items():
-            assert isinstance(metric_dict, dict), f"{metric_name} should be a dict"
-            for key, value in metric_dict.items():
-                assert isinstance(value, float), f"{metric_name}[{key}] should be float"
-                assert not torch.isnan(torch.tensor(value)), f"{metric_name}[{key}] is NaN"
-                assert not torch.isinf(torch.tensor(value)), f"{metric_name}[{key}] is Inf"
+            self.assertIsInstance(metric_dict, dict, f"{metric_name} should be a dict")
+            for key, value in cast(dict, metric_dict).items():
+                self.assertIsInstance(value, float, f"{metric_name}[{key}] should be float")
+                self.assertFalse(torch.isnan(torch.tensor(value)), f"{metric_name}[{key}] is NaN")
+                self.assertFalse(torch.isinf(torch.tensor(value)), f"{metric_name}[{key}] is Inf")
 
         for key in ["embed_tokens", "lm_head"] + [f"layers.{i}" for i in range(model.config.num_hidden_layers)]:
-            assert key in metrics["weight_rms"], f"key: {key}, weight_rms: {metrics['weight_rms']}"
+            self.assertIn(key, metrics["weight_rms"])  # type: ignore[attr-defined]
 
         for key in [f"layer{i}" for i in range(model.config.num_hidden_layers)]:
-            assert key in metrics["maxvio"], f"key: {key}, maxvio: {metrics['maxvio']}"
-            assert key in metrics["drop_ratio"], f"key: {key}, drop_ratio: {metrics['drop_ratio']}"
-            assert key in metrics["router_logits_max"], f"key: {key}, router_logits_max: {metrics['router_logits_max']}"
-            assert key in metrics["router_logits_mean"], f"key: {key}, router_logits_mean: {metrics['router_logits_mean']}"
+            self.assertIn(key, metrics["maxvio"])  # type: ignore[attr-defined]
+            self.assertIn(key, metrics["drop_ratio"])  # type: ignore[attr-defined]
+            self.assertIn(key, metrics["router_logits_max"])  # type: ignore[attr-defined]
+            self.assertIn(key, metrics["router_logits_mean"])  # type: ignore[attr-defined]
 
         if DEVICE != "npu":
            for layer in range(model.config.num_hidden_layers):
-                assert (
-                    f"layers.{layer}.self_attn" in metrics["attn_max_lse"] or  # type: ignore[attr-defined]
-                    f"layers.{layer}.self_attn" in metrics["attn_max_logits"]  # type: ignore[attr-defined]
+                self.assertTrue(
+                    f"layers.{layer}.self_attn" in metrics.get("attn_max_lse", {}) or  # type: ignore[attr-defined]
+                    f"layers.{layer}.self_attn" in metrics.get("attn_max_logits", {}),  # type: ignore[attr-defined]
+                    f"Expected `layers.{layer}.self_attn` in either `attn_max_lse` or `attn_max_logits`"
                 )
 
-        assert "total" in metrics["maxvio"]
-        assert "total" in metrics["drop_ratio"]
+        self.assertIn("total", metrics["maxvio"])  # type: ignore[attr-defined]
+        self.assertIn("total", metrics["drop_ratio"])  # type: ignore[attr-defined]

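The refactored test no longer relies on bare `assert` statements: membership and type checks go through `unittest` assertions, and the whole metrics dict is checked against the `InternalMetrics` schema via `internal_metrics_adapter.validate_python`, the pydantic `TypeAdapter` that the second file below defines. As a rough illustration of that validation pattern, here is a minimal, self-contained sketch; the `Metrics` TypedDict and the sample payloads are invented for the example and are not the project's actual schema:

from pydantic import TypeAdapter, ValidationError
from typing_extensions import TypedDict


class Metrics(TypedDict, total=False):
    # Illustrative subset; the real InternalMetrics declares more keys.
    weight_rms: dict[str, float]
    maxvio: dict[str, float]


metrics_adapter = TypeAdapter(Metrics)

# Accepted: every sub-dict maps str -> float, and missing keys are fine with total=False.
metrics_adapter.validate_python({"weight_rms": {"lm_head": 0.02}, "maxvio": {"total": 0.1}})

# Rejected: "high" cannot be coerced to float, so a ValidationError is raised.
try:
    metrics_adapter.validate_python({"maxvio": {"total": "high"}})
except ValidationError as exc:
    print(exc)
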
xtuner/v1/utils/internal_metrics.py

Lines changed: 41 additions & 12 deletions
@@ -1,9 +1,10 @@
 from collections import defaultdict
+from typing import cast
 
 import torch
 import torch.distributed as dist
 from mmengine.dist import get_world_size
-from pydantic import BaseModel, ConfigDict, model_validator
+from pydantic import BaseModel, ConfigDict, TypeAdapter, model_validator
 from torch import nn
 from torch.utils.hooks import RemovableHandle
 from typing_extensions import TypedDict
@@ -44,6 +45,9 @@ class InternalMetrics(TypedDict, total=False):
     attn_max_logits: dict[str, float]
 
 
+internal_metrics_adapter = TypeAdapter(InternalMetrics)
+
+
 class InternalMetricsConfig(BaseModel):
     model_config = ConfigDict(extra="forbid")
     internal_metrics_interval: int | None = None
@@ -71,30 +75,38 @@ class InternalMetricsRecorder:
     def __init__(self, internal_metrics_cfg: InternalMetricsConfig, model: XTunerBaseModel):
         self.internal_metrics_cfg = internal_metrics_cfg
         self.model = model
+
         self.hooks: list[RemovableHandle] = []
+
         self._attn_monitor_type: str | None = None
         self.attn_max_lse: dict[str, torch.Tensor] = {}
         self.attn_max_logits: dict[str, torch.Tensor] = {}
+
         self.metrics = self._init_metrics_dict()
         self._closed = False
 
     def _init_metrics_dict(self) -> InternalMetrics:
         metrics: InternalMetrics = {}
+
         if self.internal_metrics_cfg.monitor_weights_rms_norm:
             metrics["weight_rms"] = {}
+
         if self.internal_metrics_cfg.monitor_attn_logits_stats:
             attn_cfg: MHAConfig | MLAConfig = self.model.config.attention  # type: ignore[attr-defined]
+
             if isinstance(attn_cfg, MLAConfig):
                 attn_impl = "flash_attention"
             else:
                 attn_impl = attn_cfg.attn_impl
+
             if attn_impl == "eager_attention":
                 # We typically won't use eager attn, but implement it here anyway
                 self._attn_monitor_type = "attn_logits"
                 metrics["attn_max_logits"] = {}
             elif not (DEVICE == "npu" and attn_impl == "flash_attention"):
                 self._attn_monitor_type = "softmax_lse"
                 metrics["attn_max_lse"] = {}
+
         for module in self.model.modules():
             if isinstance(module, ATTENTION_CLS):
                 if self._attn_monitor_type == "attn_logits":
@@ -117,20 +129,28 @@ def _init_metrics_dict(self) -> InternalMetrics:
     def calculate_module_weight_rms(self, module: nn.Module, layer_name: str, dtype: torch.dtype = torch.float32):
         """Calculate the RMS of the module's parameters."""
         self._check_closed()
+
         if "weight_rms" not in self.metrics:
             return
+
         all_params = [param.data for param in module.parameters() if param.requires_grad]
+
         if not all_params:
             return
+
         grouped_params = group_tensors_by_device_mesh_and_placements(all_params)  # type: ignore[arg-type]
+
         total_norms = []
         total_numel = 0
+
         for params in grouped_params.values():
             total_norm = cal_total_norm(params, norm_type=2.0, foreach=True, dtype=dtype)
             total_norms.append(total_norm)
             total_numel += sum(p.numel() for p in params)
+
         param_l2_norm = torch.linalg.vector_norm(torch.stack(total_norms), ord=2.0, dtype=dtype)
         param_rms = param_l2_norm / total_numel**0.5
+
         self.metrics["weight_rms"][layer_name] = param_rms.item()
 
     def register_attn_output_hook(self, module: nn.Module):
@@ -140,6 +160,7 @@ def register_attn_output_hook(self, module: nn.Module):
         def hook(module, input, output):
             if output.get("softmax_lse") is not None:
                 self.attn_max_lse[module.name] = torch.max(self.attn_max_lse[module.name], output["softmax_lse"].max())
+
             if output.get("attn_logits") is not None:
                 self.attn_max_logits[module.name] = max(self.attn_max_logits[module.name], output["attn_logits"].max())
 
@@ -150,9 +171,11 @@ def hook(module, input, output):
     def pop_metrics(self, data_batches: list[ModelItem]):
         """Run a dummy forward to get metrics."""
         self._check_closed()
+
         for name, module in self.model.named_modules():
             if self.internal_metrics_cfg.monitor_attn_logits_stats and isinstance(module, ATTENTION_CLS):
                 self.register_attn_output_hook(module)
+
             if self.internal_metrics_cfg.monitor_weights_rms_norm and isinstance(module, RMS_NORM_MONITOR_MODULES):
                 self.calculate_module_weight_rms(module, self._clean_module_name(name), dtype=torch.float32)
 
@@ -172,6 +195,7 @@ def pop_metrics(self, data_batches: list[ModelItem]):
             data_batch = data_batches[i]
             seq_ctx = data_batch["seq_ctx"]
             output = self.model(seq_ctx=seq_ctx, loss_ctx=None, **additional_kwargs)
+
             if (
                 self.internal_metrics_cfg.monitor_moe_load_balance_stats
                 and (cur_tokens_per_expert := output.get("tokens_per_expert_global")) is not None
@@ -195,15 +219,19 @@ def pop_metrics(self, data_batches: list[ModelItem]):
         if tokens_per_expert_global is not None:
             avg_count_load = tokens_per_expert_global.mean(1)
             max_load_i = torch.amax(tokens_per_expert_global, dim=1)
+
             maxvio_all_layers = (max_load_i - avg_count_load) / avg_count_load
+
             centered_tokens_per_expert = tokens_per_expert_global - avg_count_load[:, None]
             drop_ratio_all_layers = (centered_tokens_per_expert).abs().mean(dim=1) / avg_count_load
+
             if "drop_ratio" in self.metrics:
                 self.metrics["drop_ratio"].update(
                     {f"layer{idx}": drop_ratio_all_layers[idx].item() for idx in range(drop_ratio_all_layers.shape[0])}
                 )
                 drop_ratio = drop_ratio_all_layers.mean()
                 self.metrics["drop_ratio"]["total"] = drop_ratio.item()
+
             if "maxvio" in self.metrics:
                 self.metrics["maxvio"].update(
                     {f"layer{idx}": maxvio_all_layers[idx].item() for idx in range(max_load_i.shape[0])}
@@ -216,23 +244,27 @@ def pop_metrics(self, data_batches: list[ModelItem]):
                 # [bsz/intra_layer_micro_batch, ]
                 local_router_logits_max = torch.max(torch.stack(router_logits_list))
                 dist.all_reduce(local_router_logits_max, op=dist.ReduceOp.MAX)
+
                 self.metrics["router_logits_max"][layer_name] = local_router_logits_max.item()
 
         if "router_logits_mean" in self.metrics and router_logits_mean:
             for layer_name, router_logits_list in router_logits_mean.items():
                 # [bsz/intra_layer_micro_batch, ]
                 local_router_logits_mean = torch.mean(torch.stack(router_logits_list))
                 dist.all_reduce(local_router_logits_mean.div_(get_world_size()), op=dist.ReduceOp.SUM)
+
                 self.metrics["router_logits_mean"][layer_name] = local_router_logits_mean.item()
 
         if "attn_max_lse" in self.metrics and self._attn_monitor_type == "softmax_lse":
             for layer_name, local_attn_max_lse in self.attn_max_lse.items():
                 dist.all_reduce(local_attn_max_lse, op=dist.ReduceOp.MAX)
+
                 self.metrics["attn_max_lse"][layer_name] = local_attn_max_lse.item()
 
         if "attn_max_logits" in self.metrics and self._attn_monitor_type == "attn_logits":
             for layer_name, local_attn_max_logits in self.attn_max_logits.items():
                 dist.all_reduce(local_attn_max_logits, op=dist.ReduceOp.MAX)
+
                 self.metrics["attn_max_logits"][layer_name] = local_attn_max_logits.item()
 
         self._maybe_reset_attn_max_lse_or_logits(self.attn_max_lse)
@@ -241,7 +273,7 @@ def pop_metrics(self, data_batches: list[ModelItem]):
         for hook in self.hooks:
             hook.remove()
 
-        return self.metrics
+        return internal_metrics_adapter.validate_python(self.metrics)
 
     def close(self):
         if not self._closed:
@@ -266,6 +298,7 @@ def __del__(self):
     def _maybe_reset_attn_max_lse_or_logits(self, target: dict[str, torch.Tensor]):
        if not target:
            return
+
        for v in target.values():
            if isinstance(v, torch.Tensor):
                v.fill_(SMALL_VAL)
@@ -293,16 +326,12 @@ def need_dummy_forward(self) -> bool:
 
 
 def flatten_internal_metrics_for_logs(metrics: InternalMetrics, sep: str = "/") -> dict:
+    internal_metrics_adapter.validate_python(metrics)
+
     items = []
     for name, sub_metrics in metrics.items():
-        if isinstance(sub_metrics, dict):
-            for k, v in sub_metrics.items():
-                if isinstance(v, (float, int)):
-                    items.append((f"{name}{sep}{k}", v))
-                else:
-                    raise ValueError(f"Unsupported metric value type: expected float or int, but got {type(v)}")
-        else:
-            raise ValueError(
-                f"Unsupported metric type for internal metrics: expected dict, but got {type(sub_metrics)}"
-            )
+        sub_metrics = cast(dict, sub_metrics)
+        for k, v in sub_metrics.items():
+            items.append((f"{name}{sep}{k}", v))
+
     return dict(items)
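
For context on the load-balance statistics computed in pop_metrics above: maxvio is the relative gap between the most-loaded expert and the average load per layer, and drop_ratio is the mean absolute deviation of expert load relative to the average. A small sketch that reproduces the same arithmetic on made-up token counts (two layers, four experts; the tensor values are illustrative only):

import torch

# Hypothetical routing counts per layer and expert (values invented for the example).
tokens_per_expert_global = torch.tensor([[10.0, 10.0, 10.0, 10.0],
                                         [4.0, 12.0, 8.0, 16.0]])

avg_count_load = tokens_per_expert_global.mean(1)                      # tensor([10., 10.])
max_load_i = torch.amax(tokens_per_expert_global, dim=1)               # tensor([10., 16.])
maxvio_all_layers = (max_load_i - avg_count_load) / avg_count_load     # tensor([0.0, 0.6])

centered = tokens_per_expert_global - avg_count_load[:, None]
drop_ratio_all_layers = centered.abs().mean(dim=1) / avg_count_load    # tensor([0.0, 0.4])

print(maxvio_all_layers, drop_ratio_all_layers)

The perfectly balanced first layer scores zero on both metrics, while the skewed second layer shows a 0.6 max violation and a 0.4 drop ratio.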

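The slimmed-down flatten_internal_metrics_for_logs drops its manual type checks because the adapter has already validated the nested dict-of-dicts, leaving only the key-joining step. A rough usage sketch under that assumption, with illustrative input values:

# Example nested metrics (values invented); "/" matches the function's default separator.
metrics = {
    "weight_rms": {"embed_tokens": 0.021, "lm_head": 0.017},
    "maxvio": {"layer0": 0.6, "total": 0.6},
}

# Equivalent flattening: "<metric><sep><key>" -> value.
flat = {f"{name}/{key}": value
        for name, sub_metrics in metrics.items()
        for key, value in sub_metrics.items()}

print(flat)
# {'weight_rms/embed_tokens': 0.021, 'weight_rms/lm_head': 0.017,
#  'maxvio/layer0': 0.6, 'maxvio/total': 0.6}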