Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 24 additions & 13 deletions mjlab/rsl_rl/algorithms/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,8 +293,10 @@ def update(self): # noqa: C901
for param_group in self.optimizer.param_groups:
param_group["lr"] = self.learning_rate

# Surrogate loss
ratio = torch.exp(actions_log_prob_batch - torch.squeeze(old_actions_log_prob_batch))
# Surrogate loss (clamp log-ratio to prevent exp overflow → Inf → NaN)
log_ratio = actions_log_prob_batch - torch.squeeze(old_actions_log_prob_batch)
log_ratio = torch.clamp(log_ratio, -20.0, 20.0)
ratio = torch.exp(log_ratio)
surrogate = -torch.squeeze(advantages_batch) * ratio
surrogate_clipped = -torch.squeeze(advantages_batch) * torch.clamp(
ratio, 1.0 - self.clip_param, 1.0 + self.clip_param
Expand Down Expand Up @@ -371,17 +373,23 @@ def update(self): # noqa: C901
self.rnd_optimizer.zero_grad() # type: ignore
rnd_loss.backward()

# Collect gradients from all GPUs
if self.is_multi_gpu:
self.reduce_parameters()

# Apply the gradients
# -- For PPO
nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
self.optimizer.step()
# -- For RND
if self.rnd_optimizer:
self.rnd_optimizer.step()
# Skip optimizer step if loss is NaN or Inf (prevents corrupting all parameters)
if not torch.isfinite(loss):
self.optimizer.zero_grad()
if self.rnd_optimizer:
self.rnd_optimizer.zero_grad()
else:
# Collect gradients from all GPUs
if self.is_multi_gpu:
self.reduce_parameters()

# Apply the gradients
# -- For PPO
nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
self.optimizer.step()
# -- For RND
if self.rnd_optimizer:
self.rnd_optimizer.step()

# Store the losses
mean_value_loss += value_loss.item()
Expand Down Expand Up @@ -449,6 +457,9 @@ def reduce_parameters(self):
grads += [param.grad.view(-1) for param in self.rnd.parameters() if param.grad is not None]
all_grads = torch.cat(grads)

# Replace NaN/Inf gradients with 0 before reduction (prevents one GPU poisoning all others)
all_grads = torch.nan_to_num(all_grads, nan=0.0, posinf=0.0, neginf=0.0)

# Average the gradients across all GPUs
torch.distributed.all_reduce(all_grads, op=torch.distributed.ReduceOp.SUM)
all_grads /= self.gpu_world_size
Expand Down
5 changes: 4 additions & 1 deletion mjlab/rsl_rl/modules/actor_critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,10 @@ def update_distribution(self, obs):
std = torch.exp(self.log_std).expand_as(mean)
else:
raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
# create distribution
# create distribution (guard against non-positive/NaN/Inf std and NaN mean from optimizer or gradient explosion)
std = torch.clamp(std, min=1e-6)
std = torch.nan_to_num(std, nan=1.0, posinf=1.0, neginf=1e-6)
mean = torch.nan_to_num(mean, nan=0.0)
self.distribution = Normal(mean, std)

def act(self, obs, **kwargs):
Expand Down
5 changes: 4 additions & 1 deletion mjlab/rsl_rl/modules/actor_critic_recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,10 @@ def update_distribution(self, obs):
std = torch.exp(self.log_std).expand_as(mean)
else:
raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
# create distribution
# create distribution (guard against non-positive/NaN/Inf std and NaN mean from optimizer or gradient explosion)
std = torch.clamp(std, min=1e-6)
std = torch.nan_to_num(std, nan=1.0, posinf=1.0, neginf=1e-6)
mean = torch.nan_to_num(mean, nan=0.0)
self.distribution = Normal(mean, std)

def act(self, obs, masks=None, hidden_states=None):
Expand Down