
Commit ac34d2c

kevinzakka and claude committed
Add MetricsManager for logging custom metrics during training
Adds a MetricsManager so users can log custom per-step metrics without hacking reward functions or adding zero-weight reward terms. Metrics terms use the same callable signature as rewards (env, **params) but have no weight, no dt scaling, and no normalization by episode length. Episode values are true per-step averages (sum / step_count) logged under "Episode_Metrics/{term_name}".

Closes #584

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent e78ebce commit ac34d2c

File tree

src/mjlab/envs/manager_based_rl_env.py
src/mjlab/managers/__init__.py
src/mjlab/managers/metrics_manager.py
tests/test_metrics_manager.py

4 files changed: +288 -0 lines changed
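For orientation before the diffs, here is a minimal usage sketch. The term body, the made-up env.my_upright_flag attribute, and the metric name are illustrative assumptions, not part of this commit; only the new metrics config field, MetricsTermCfg, and the (env, **params) callable signature come from the change itself.

# Hypothetical usage sketch (not part of this commit).
import torch

from mjlab.managers import MetricsTermCfg


def upright_fraction(env) -> torch.Tensor:
  """Per-step metric: one value per environment, kept in [0, 1]."""
  # env.my_upright_flag is a made-up per-env boolean tensor standing in
  # for whatever quantity the user wants to track.
  return env.my_upright_flag.float()


# In the environment configuration (the `metrics` field added by this commit):
#   metrics={"upright_fraction": MetricsTermCfg(func=upright_fraction, params={})}
# Each finished episode then logs "Episode_Metrics/upright_fraction" as the
# true per-step average over that episode, so the value stays in [0, 1].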

src/mjlab/envs/manager_based_rl_env.py

Lines changed: 17 additions & 0 deletions
@@ -22,6 +22,11 @@
   NullCurriculumManager,
 )
 from mjlab.managers.event_manager import EventManager, EventTermCfg
+from mjlab.managers.metrics_manager import (
+  MetricsManager,
+  MetricsTermCfg,
+  NullMetricsManager,
+)
 from mjlab.managers.observation_manager import ObservationGroupCfg, ObservationManager
 from mjlab.managers.reward_manager import RewardManager, RewardTermCfg
 from mjlab.managers.termination_manager import TerminationManager, TerminationTermCfg

@@ -114,6 +119,9 @@ class ManagerBasedRlEnvCfg:
   curriculum: dict[str, CurriculumTermCfg] = field(default_factory=dict)
   """Curriculum terms for adaptive difficulty."""

+  metrics: dict[str, MetricsTermCfg] = field(default_factory=dict)
+  """Custom metric terms for logging per-step values as episode averages."""
+
   is_finite_horizon: bool = False
   """Whether the task has a finite or infinite horizon. Defaults to False (infinite).

@@ -291,6 +299,11 @@ def load_managers(self) -> None:
     else:
       self.curriculum_manager = NullCurriculumManager()
     print_info(f"[INFO] {self.curriculum_manager}")
+    if len(self.cfg.metrics) > 0:
+      self.metrics_manager = MetricsManager(self.cfg.metrics, self)
+    else:
+      self.metrics_manager = NullMetricsManager()
+    print_info(f"[INFO] {self.metrics_manager}")

     # Configure spaces for the environment.
     self._configure_gym_env_spaces()

@@ -367,6 +380,7 @@ def step(self, action: torch.Tensor) -> types.VecEnvStepReturn:
     self.reset_time_outs = self.termination_manager.time_outs

     self.reward_buf = self.reward_manager.compute(dt=self.step_dt)
+    self.metrics_manager.compute()

     # Reset envs that terminated/timed-out and log the episode info.
     reset_env_ids = self.reset_buf.nonzero(as_tuple=False).squeeze(-1)

@@ -485,6 +499,9 @@ def _reset_idx(self, env_ids: torch.Tensor | None = None) -> None:
     # rewards manager.
     info = self.reward_manager.reset(env_ids)
     self.extras["log"].update(info)
+    # metrics manager.
+    info = self.metrics_manager.reset(env_ids)
+    self.extras["log"].update(info)
     # curriculum manager.
     info = self.curriculum_manager.reset(env_ids)
     self.extras["log"].update(info)

src/mjlab/managers/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -18,6 +18,9 @@
 from mjlab.managers.manager_base import ManagerBase as ManagerBase
 from mjlab.managers.manager_base import ManagerTermBase as ManagerTermBase
 from mjlab.managers.manager_base import ManagerTermBaseCfg as ManagerTermBaseCfg
+from mjlab.managers.metrics_manager import MetricsManager as MetricsManager
+from mjlab.managers.metrics_manager import MetricsTermCfg as MetricsTermCfg
+from mjlab.managers.metrics_manager import NullMetricsManager as NullMetricsManager
 from mjlab.managers.observation_manager import (
   ObservationGroupCfg as ObservationGroupCfg,
 )
src/mjlab/managers/metrics_manager.py

Lines changed: 144 additions & 0 deletions
@@ -0,0 +1,144 @@
"""Metrics manager for logging custom per-step metrics during training."""

from __future__ import annotations

from copy import deepcopy
from dataclasses import dataclass
from typing import TYPE_CHECKING, Sequence

import torch
from prettytable import PrettyTable

from mjlab.managers.manager_base import ManagerBase, ManagerTermBaseCfg

if TYPE_CHECKING:
  from mjlab.envs.manager_based_rl_env import ManagerBasedRlEnv


@dataclass(kw_only=True)
class MetricsTermCfg(ManagerTermBaseCfg):
  """Configuration for a metrics term."""

  pass


class MetricsManager(ManagerBase):
  """Accumulates per-step metric values, reports episode averages.

  Unlike rewards, metrics have no weight, no dt scaling, and no
  normalization by episode length. Episode values are true per-step
  averages (sum / step_count), so a metric in [0, 1] stays in [0, 1]
  in the logger.
  """

  _env: ManagerBasedRlEnv

  def __init__(self, cfg: dict[str, MetricsTermCfg], env: ManagerBasedRlEnv):
    self._term_names: list[str] = list()
    self._term_cfgs: list[MetricsTermCfg] = list()
    self._class_term_cfgs: list[MetricsTermCfg] = list()

    self.cfg = deepcopy(cfg)
    super().__init__(env=env)

    self._episode_sums: dict[str, torch.Tensor] = {}
    for term_name in self._term_names:
      self._episode_sums[term_name] = torch.zeros(
        self.num_envs, dtype=torch.float, device=self.device
      )
    self._step_count = torch.zeros(self.num_envs, dtype=torch.long, device=self.device)
    self._step_values = torch.zeros(
      (self.num_envs, len(self._term_names)), dtype=torch.float, device=self.device
    )

  def __str__(self) -> str:
    msg = f"<MetricsManager> contains {len(self._term_names)} active terms.\n"
    table = PrettyTable()
    table.title = "Active Metrics Terms"
    table.field_names = ["Index", "Name"]
    table.align["Name"] = "l"
    for index, name in enumerate(self._term_names):
      table.add_row([index, name])
    msg += table.get_string()
    msg += "\n"
    return msg

  # Properties.

  @property
  def active_terms(self) -> list[str]:
    return self._term_names

  # Methods.

  def reset(
    self, env_ids: torch.Tensor | slice | None = None
  ) -> dict[str, torch.Tensor]:
    if env_ids is None:
      env_ids = slice(None)
    extras = {}
    counts = self._step_count[env_ids].float()
    # Avoid division by zero for envs that haven't stepped.
    safe_counts = torch.clamp(counts, min=1.0)
    for key in self._episode_sums:
      episode_avg = torch.mean(self._episode_sums[key][env_ids] / safe_counts)
      extras["Episode_Metrics/" + key] = episode_avg
      self._episode_sums[key][env_ids] = 0.0
    self._step_count[env_ids] = 0
    for term_cfg in self._class_term_cfgs:
      term_cfg.func.reset(env_ids=env_ids)
    return extras

  def compute(self) -> None:
    self._step_count += 1
    for term_idx, (name, term_cfg) in enumerate(
      zip(self._term_names, self._term_cfgs, strict=False)
    ):
      value = term_cfg.func(self._env, **term_cfg.params)
      self._episode_sums[name] += value
      self._step_values[:, term_idx] = value

  def get_active_iterable_terms(
    self, env_idx: int
  ) -> Sequence[tuple[str, Sequence[float]]]:
    terms = []
    for idx, name in enumerate(self._term_names):
      terms.append((name, [self._step_values[env_idx, idx].cpu().item()]))
    return terms

  def _prepare_terms(self):
    for term_name, term_cfg in self.cfg.items():
      term_cfg: MetricsTermCfg | None
      if term_cfg is None:
        print(f"term: {term_name} set to None, skipping...")
        continue
      self._resolve_common_term_cfg(term_name, term_cfg)
      self._term_names.append(term_name)
      self._term_cfgs.append(term_cfg)
      if hasattr(term_cfg.func, "reset") and callable(term_cfg.func.reset):
        self._class_term_cfgs.append(term_cfg)


class NullMetricsManager:
  """Placeholder for absent metrics manager that safely no-ops all operations."""

  def __init__(self):
    self.active_terms: list[str] = []
    self.cfg = None

  def __str__(self) -> str:
    return "<NullMetricsManager> (inactive)"

  def __repr__(self) -> str:
    return "NullMetricsManager()"

  def get_active_iterable_terms(
    self, env_idx: int
  ) -> Sequence[tuple[str, Sequence[float]]]:
    return []

  def reset(self, env_ids: torch.Tensor | None = None) -> dict[str, float]:
    return {}

  def compute(self) -> None:
    pass
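Because _prepare_terms registers any term whose func exposes a callable reset() as a class term, metrics can carry per-episode state that the manager clears on reset. A minimal sketch of such a stateful term; the class name and the env.my_signal attribute it reads are made up for illustration:

# Hypothetical stateful metric term (not part of this commit).
import torch

from mjlab.managers import MetricsTermCfg


class PeakSignalMagnitude:
  """Tracks the largest per-env |signal| seen so far in the episode."""

  def __init__(self, cfg: MetricsTermCfg, env):
    self._peak = torch.zeros(env.num_envs, device=env.device)

  def __call__(self, env, **kwargs) -> torch.Tensor:
    # env.my_signal is a made-up (num_envs, dim) tensor; any per-env
    # quantity works as long as the return shape is (num_envs,).
    self._peak = torch.maximum(self._peak, env.my_signal.abs().max(dim=-1).values)
    return self._peak

  def reset(self, env_ids: torch.Tensor | None = None):
    # Called by MetricsManager.reset() so state does not leak across episodes.
    if env_ids is not None:
      self._peak[env_ids] = 0.0


# Config entry, mirroring the class-based term in the tests below:
#   metrics={"peak_signal": MetricsTermCfg(func=PeakSignalMagnitude, params={})}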

tests/test_metrics_manager.py

Lines changed: 124 additions & 0 deletions
@@ -0,0 +1,124 @@
"""Tests for metrics manager functionality."""

from unittest.mock import Mock

import pytest
import torch

from mjlab.managers.metrics_manager import (
  MetricsManager,
  MetricsTermCfg,
  NullMetricsManager,
)


class SimpleTestMetric:
  """A class-based metric that tracks state."""

  def __init__(self, cfg: MetricsTermCfg, env):
    self.call_count = torch.zeros(env.num_envs, device=env.device)

  def __call__(self, env, **kwargs):
    self.call_count += 1
    return torch.ones(env.num_envs, device=env.device) * 0.5

  def reset(self, env_ids: torch.Tensor | None = None, env=None):
    if env_ids is not None and len(env_ids) > 0:
      self.call_count[env_ids] = 0


@pytest.fixture
def mock_env():
  env = Mock()
  env.num_envs = 4
  env.device = "cpu"
  env.scene = {"robot": Mock()}
  return env


def test_episode_averages_and_reset(mock_env):
  """Compute for N steps, reset a subset, verify averages and zeroing."""
  cfg = {
    "term": MetricsTermCfg(
      func=lambda env: torch.ones(env.num_envs, device=env.device) * 0.5,
      params={},
    )
  }
  manager = MetricsManager(cfg, mock_env)

  for _ in range(10):
    manager.compute()

  info = manager.reset(env_ids=torch.tensor([0, 1]))

  # Each env: sum=5.0, count=10, avg=0.5. Mean across 2 reset envs = 0.5.
  assert info["Episode_Metrics/term"].item() == pytest.approx(0.5)
  # Reset envs zeroed; non-reset envs untouched.
  assert manager._episode_sums["term"][0] == 0.0
  assert manager._step_count[0] == 0
  assert manager._episode_sums["term"][2] == pytest.approx(5.0)
  assert manager._step_count[2] == 10


def test_early_termination_uses_per_env_step_count(mock_env):
  """Envs with different episode lengths get correct per-step averages."""
  step = [0]

  def step_dependent_metric(env):
    step[0] += 1
    return torch.full((env.num_envs,), float(step[0]), device=env.device)

  cfg = {"m": MetricsTermCfg(func=step_dependent_metric, params={})}
  manager = MetricsManager(cfg, mock_env)

  # 4 steps for all envs: values are 1, 2, 3, 4.
  for _ in range(4):
    manager.compute()
  # Env 0: sum=10, count=4. Reset it (env 1 keeps accumulating).
  manager.reset(env_ids=torch.tensor([0]))

  # 2 more steps: values are 5, 6.
  for _ in range(2):
    manager.compute()
  # Env 0: sum=11, count=2, avg=5.5.
  # Env 1: sum=21, count=6, avg=3.5.
  info = manager.reset(env_ids=torch.tensor([0, 1]))
  # Mean of [5.5, 3.5] = 4.5.
  assert info["Episode_Metrics/m"].item() == pytest.approx(4.5)


def test_class_based_metric_reset_targets_correct_envs(mock_env):
  """Class-based term's reset() is called with the correct env_ids."""
  cfg = {"term": MetricsTermCfg(func=SimpleTestMetric, params={})}
  manager = MetricsManager(cfg, mock_env)
  term = manager._class_term_cfgs[0].func

  for _ in range(10):
    manager.compute()

  manager.reset(env_ids=torch.tensor([0, 2]))

  assert term.call_count[0] == 0
  assert term.call_count[1] == 10
  assert term.call_count[2] == 0
  assert term.call_count[3] == 10


def test_null_metrics_manager(mock_env):
  """NullMetricsManager doesn't crash and returns empty dict on reset."""
  manager = NullMetricsManager()
  manager.compute()
  assert manager.reset(env_ids=torch.tensor([0])) == {}


def test_none_terms_are_skipped(mock_env):
  """None terms in config are skipped without error."""
  cfg: dict[str, MetricsTermCfg | None] = {
    "valid": MetricsTermCfg(
      func=lambda env: torch.ones(env.num_envs, device=env.device),
      params={},
    ),
    "skipped": None,
  }
  manager = MetricsManager(cfg, mock_env)  # type: ignore[arg-type]
  assert manager._term_names == ["valid"]
