
Fix RuntimeError: clamp action std to prevent negative values#8

Open
diaskabdualiev wants to merge 1 commit into unitreerobotics:main from diaskabdualiev:fix/clamp-action-std

Conversation

@diaskabdualiev

Summary

  • Fix crash: RuntimeError: normal expects all elements of std >= 0.0 in actor_critic.py:146 during PPO training
  • Root cause: When noise_std_type="scalar", self.std is a raw nn.Parameter without a lower bound. The optimizer can push it negative, crashing torch.distributions.Normal
  • Especially affects multi-GPU training: gradient averaging via all_reduce in reduce_parameters() amplifies conflicting gradient directions across GPUs

Reproduction

Train Mjlab-Velocity-Rough-Unitree-Go2 with multiple GPUs:

python scripts/train.py Mjlab-Velocity-Rough-Unitree-Go2 \
  --gpu-ids 0 1 2 3 4 5 6 7 8 9 \
  --env.scene.num-envs=8192

Crashes at iteration 1 with:

RuntimeError: normal expects all elements of std >= 0.0
  File "mjlab/rsl_rl/modules/actor_critic.py", line 146, in act
    return self.distribution.sample()

Fix

Add torch.clamp(std, min=1e-6) before creating the Normal distribution to guarantee std remains positive.

Test plan

  • Train Mjlab-Velocity-Rough-Unitree-Go2 with multi-GPU setup — no crash
  • Train Mjlab-Velocity-Flat-Unitree-Go2 — verify no regression

🤖 Generated with Claude Code

Four issues caused training to fail on `Mjlab-Velocity-Rough-Unitree-Go2`
with multiple GPUs:

1. **Ratio overflow**: `exp(log_prob_new - log_prob_old)` overflows to Inf
   when the policy changes significantly between collection and update,
   especially with unstable early training (robot falling immediately).
   Fix: clamp log-ratio to [-20, 20] before exp().

2. **NaN gradient propagation across GPUs**: `all_reduce(SUM)` in
   `reduce_parameters()` causes one GPU with NaN gradients to poison
   all other GPUs. Fix: `nan_to_num()` on gradients before all_reduce.

3. **NaN loss corrupts parameters**: when loss becomes NaN, `backward()`
   produces NaN gradients and `optimizer.step()` writes NaN into all
   parameters, making recovery impossible. Fix: skip optimizer step
   when loss is not finite.

4. **Negative/NaN std**: `self.std` parameter (noise_std_type="scalar")
   can become negative or NaN after optimizer step, crashing
   `torch.distributions.Normal`. Fix: clamp + nan_to_num in
   `update_distribution()` for both ActorCritic and ActorCriticRecurrent.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
