From d3f41583eb6dce0255097cd782e6fa6bdd26af70 Mon Sep 17 00:00:00 2001 From: diaskabdualiev1 Date: Fri, 6 Feb 2026 19:06:44 +0500 Subject: [PATCH] Fix NaN loss and crash in multi-GPU PPO training Four issues caused training to fail on `Mjlab-Velocity-Rough-Unitree-Go2` with multiple GPUs: 1. **Ratio overflow**: `exp(log_prob_new - log_prob_old)` overflows to Inf when the policy changes significantly between collection and update, especially with unstable early training (robot falling immediately). Fix: clamp log-ratio to [-20, 20] before exp(). 2. **NaN gradient propagation across GPUs**: `all_reduce(SUM)` in `reduce_parameters()` causes one GPU with NaN gradients to poison all other GPUs. Fix: `nan_to_num()` on gradients before all_reduce. 3. **NaN loss corrupts parameters**: when loss becomes NaN, `backward()` produces NaN gradients and `optimizer.step()` writes NaN into all parameters, making recovery impossible. Fix: skip optimizer step when loss is not finite. 4. **Negative/NaN std**: `self.std` parameter (noise_std_type="scalar") can become negative or NaN after optimizer step, crashing `torch.distributions.Normal`. Fix: clamp + nan_to_num in `update_distribution()` for both ActorCritic and ActorCriticRecurrent. 
Co-Authored-By: Claude Opus 4.6 --- mjlab/rsl_rl/algorithms/ppo.py | 37 ++++++++++++------- mjlab/rsl_rl/modules/actor_critic.py | 5 ++- .../rsl_rl/modules/actor_critic_recurrent.py | 5 ++- 3 files changed, 32 insertions(+), 15 deletions(-) diff --git a/mjlab/rsl_rl/algorithms/ppo.py b/mjlab/rsl_rl/algorithms/ppo.py index d433609..578f28b 100644 --- a/mjlab/rsl_rl/algorithms/ppo.py +++ b/mjlab/rsl_rl/algorithms/ppo.py @@ -293,8 +293,10 @@ def update(self): # noqa: C901 for param_group in self.optimizer.param_groups: param_group["lr"] = self.learning_rate - # Surrogate loss - ratio = torch.exp(actions_log_prob_batch - torch.squeeze(old_actions_log_prob_batch)) + # Surrogate loss (clamp log-ratio to prevent exp overflow → Inf → NaN) + log_ratio = actions_log_prob_batch - torch.squeeze(old_actions_log_prob_batch) + log_ratio = torch.clamp(log_ratio, -20.0, 20.0) + ratio = torch.exp(log_ratio) surrogate = -torch.squeeze(advantages_batch) * ratio surrogate_clipped = -torch.squeeze(advantages_batch) * torch.clamp( ratio, 1.0 - self.clip_param, 1.0 + self.clip_param @@ -371,17 +373,23 @@ def update(self): # noqa: C901 self.rnd_optimizer.zero_grad() # type: ignore rnd_loss.backward() - # Collect gradients from all GPUs - if self.is_multi_gpu: - self.reduce_parameters() - - # Apply the gradients - # -- For PPO - nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) - self.optimizer.step() - # -- For RND - if self.rnd_optimizer: - self.rnd_optimizer.step() + # Skip optimizer step if loss is NaN (prevents corrupting all parameters) + if not torch.isfinite(loss): + self.optimizer.zero_grad() + if self.rnd_optimizer: + self.rnd_optimizer.zero_grad() + else: + # Collect gradients from all GPUs + if self.is_multi_gpu: + self.reduce_parameters() + + # Apply the gradients + # -- For PPO + nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) + self.optimizer.step() + # -- For RND + if self.rnd_optimizer: + self.rnd_optimizer.step() # Store the 
losses mean_value_loss += value_loss.item() @@ -449,6 +457,9 @@ def reduce_parameters(self): grads += [param.grad.view(-1) for param in self.rnd.parameters() if param.grad is not None] all_grads = torch.cat(grads) + # Replace NaN/Inf gradients with 0 before reduction (prevents one GPU poisoning all others) + all_grads = torch.nan_to_num(all_grads, nan=0.0, posinf=0.0, neginf=0.0) + # Average the gradients across all GPUs torch.distributed.all_reduce(all_grads, op=torch.distributed.ReduceOp.SUM) all_grads /= self.gpu_world_size diff --git a/mjlab/rsl_rl/modules/actor_critic.py b/mjlab/rsl_rl/modules/actor_critic.py index 825a255..bef1e44 100644 --- a/mjlab/rsl_rl/modules/actor_critic.py +++ b/mjlab/rsl_rl/modules/actor_critic.py @@ -136,7 +136,10 @@ def update_distribution(self, obs): std = torch.exp(self.log_std).expand_as(mean) else: raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'") - # create distribution + # create distribution (guard against negative/NaN std from optimizer or gradient explosion) + std = torch.clamp(std, min=1e-6) + std = torch.nan_to_num(std, nan=1.0, posinf=1.0, neginf=1e-6) + mean = torch.nan_to_num(mean, nan=0.0) self.distribution = Normal(mean, std) def act(self, obs, **kwargs): diff --git a/mjlab/rsl_rl/modules/actor_critic_recurrent.py b/mjlab/rsl_rl/modules/actor_critic_recurrent.py index e8da59c..0a03e8f 100644 --- a/mjlab/rsl_rl/modules/actor_critic_recurrent.py +++ b/mjlab/rsl_rl/modules/actor_critic_recurrent.py @@ -153,7 +153,10 @@ def update_distribution(self, obs): std = torch.exp(self.log_std).expand_as(mean) else: raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. 
Should be 'scalar' or 'log'") - # create distribution + # create distribution (guard against negative/NaN std from optimizer or gradient explosion) + std = torch.clamp(std, min=1e-6) + std = torch.nan_to_num(std, nan=1.0, posinf=1.0, neginf=1e-6) + mean = torch.nan_to_num(mean, nan=0.0) self.distribution = Normal(mean, std) def act(self, obs, masks=None, hidden_states=None):