-
Notifications
You must be signed in to change notification settings - Fork 24
Expand file tree
/
Copy pathppo.py
More file actions
480 lines (426 loc) · 20.4 KB
/
ppo.py
File metadata and controls
480 lines (426 loc) · 20.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
# Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations
import torch
import torch.nn as nn
import torch.optim as optim
from itertools import chain
from mjlab.rsl_rl.modules import ActorCritic
from mjlab.rsl_rl.modules.rnd import RandomNetworkDistillation
from mjlab.rsl_rl.storage import RolloutStorage
from mjlab.rsl_rl.utils import string_to_callable
class PPO:
    """Proximal Policy Optimization algorithm (https://arxiv.org/abs/1707.06347).

    Optionally extended with Random Network Distillation (intrinsic rewards), symmetry-based
    data augmentation / mirror loss, and data-parallel training via ``torch.distributed``.
    """

    # The actor-critic module being optimized.
    policy: ActorCritic

    def __init__(
        self,
        policy,
        num_learning_epochs=5,
        num_mini_batches=4,
        clip_param=0.2,
        gamma=0.99,
        lam=0.95,
        value_loss_coef=1.0,
        entropy_coef=0.01,
        learning_rate=0.001,
        max_grad_norm=1.0,
        use_clipped_value_loss=True,
        schedule="adaptive",
        desired_kl=0.01,
        device="cpu",
        normalize_advantage_per_mini_batch=False,
        # RND parameters
        rnd_cfg: dict | None = None,
        # Symmetry parameters
        symmetry_cfg: dict | None = None,
        # Distributed training parameters
        multi_gpu_cfg: dict | None = None,
    ):
        """Create the PPO algorithm.

        Args:
            policy: Actor-critic module to optimize; it is moved to ``device``.
            num_learning_epochs: Number of passes over the rollout per :meth:`update`.
            num_mini_batches: Number of mini-batches the rollout is split into per epoch.
            clip_param: Clipping range for the probability ratio and, when enabled, the value update.
            gamma: Discount factor (also used for time-out bootstrapping).
            lam: Lambda forwarded to the storage's return computation (presumably GAE lambda).
            value_loss_coef: Weight of the value loss in the total loss.
            entropy_coef: Weight of the entropy bonus in the total loss.
            learning_rate: Initial Adam learning rate for the policy.
            max_grad_norm: Max gradient norm used to clip the policy gradients.
            use_clipped_value_loss: Whether to use the clipped value loss.
            schedule: Learning-rate schedule; ``"adaptive"`` adapts it from the KL divergence.
            desired_kl: Target KL for the adaptive schedule (``None`` disables adaptation).
            device: Torch device for the policy, the RND module and the storage.
            normalize_advantage_per_mini_batch: Normalize advantages per mini-batch instead of
                once over the whole rollout.
            rnd_cfg: Keyword arguments for ``RandomNetworkDistillation`` plus an optional
                ``"learning_rate"`` for its optimizer. The caller's dict is not modified.
            symmetry_cfg: Dict with ``"use_data_augmentation"``, ``"use_mirror_loss"`` and
                ``"data_augmentation_func"`` (callable or importable string; a string is resolved
                in place in this dict). :meth:`update` additionally reads ``"_env"`` and, when the
                mirror loss is used, ``"mirror_loss_coeff"``.
            multi_gpu_cfg: ``{"global_rank": ..., "world_size": ...}`` for distributed training.

        Raises:
            ValueError: If data augmentation is enabled but the augmentation function is not callable.
        """
        # Device-related parameters
        self.device = device
        self.is_multi_gpu = multi_gpu_cfg is not None
        # Multi-GPU parameters
        if multi_gpu_cfg is not None:
            self.gpu_global_rank = multi_gpu_cfg["global_rank"]
            self.gpu_world_size = multi_gpu_cfg["world_size"]
        else:
            self.gpu_global_rank = 0
            self.gpu_world_size = 1

        # RND components
        if rnd_cfg is not None:
            # Work on a shallow copy: popping the PPO-only key from the caller's dict would
            # silently drop it if the same config is reused or logged afterwards.
            rnd_cfg = dict(rnd_cfg)
            # Extract parameters used in ppo
            rnd_lr = rnd_cfg.pop("learning_rate", 1e-3)
            # Create RND module
            self.rnd = RandomNetworkDistillation(device=self.device, **rnd_cfg)
            # Create RND optimizer (only the predictor is trained)
            params = self.rnd.predictor.parameters()
            self.rnd_optimizer = optim.Adam(params, lr=rnd_lr)
        else:
            self.rnd = None
            self.rnd_optimizer = None

        # Symmetry components
        if symmetry_cfg is not None:
            # Check if symmetry is enabled
            use_symmetry = symmetry_cfg["use_data_augmentation"] or symmetry_cfg["use_mirror_loss"]
            # Print that we are not using symmetry
            if not use_symmetry:
                print("Symmetry not used for learning. We will use it for logging instead.")
            # If function is a string then resolve it to a function
            if isinstance(symmetry_cfg["data_augmentation_func"], str):
                symmetry_cfg["data_augmentation_func"] = string_to_callable(symmetry_cfg["data_augmentation_func"])
            # Check valid configuration
            if symmetry_cfg["use_data_augmentation"] and not callable(symmetry_cfg["data_augmentation_func"]):
                raise ValueError(
                    "Data augmentation enabled but the function is not callable:"
                    f" {symmetry_cfg['data_augmentation_func']}"
                )
            # Store symmetry configuration
            self.symmetry = symmetry_cfg
        else:
            self.symmetry = None

        # PPO components
        self.policy = policy
        self.policy.to(self.device)
        # Create optimizer
        self.optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate)
        # Rollout storage is created later by init_storage()
        self.storage: RolloutStorage = None  # type: ignore
        self.transition = RolloutStorage.Transition()

        # PPO parameters
        self.clip_param = clip_param
        self.num_learning_epochs = num_learning_epochs
        self.num_mini_batches = num_mini_batches
        self.value_loss_coef = value_loss_coef
        self.entropy_coef = entropy_coef
        self.gamma = gamma
        self.lam = lam
        self.max_grad_norm = max_grad_norm
        self.use_clipped_value_loss = use_clipped_value_loss
        self.desired_kl = desired_kl
        self.schedule = schedule
        self.learning_rate = learning_rate
        self.normalize_advantage_per_mini_batch = normalize_advantage_per_mini_batch

    def init_storage(self, training_type, num_envs, num_transitions_per_env, obs, actions_shape):
        """Create the rollout storage on ``self.device``; all arguments are forwarded to ``RolloutStorage``."""
        self.storage = RolloutStorage(
            training_type,
            num_envs,
            num_transitions_per_env,
            obs,
            actions_shape,
            self.device,
        )

    def act(self, obs):
        """Sample actions for ``obs`` and record the pre-step half of the current transition.

        Stores (detached) actions, values, action log-probabilities, distribution mean/std,
        the observations and, for recurrent policies, the hidden states in ``self.transition``.

        Returns:
            The sampled actions.
        """
        if self.policy.is_recurrent:
            self.transition.hidden_states = self.policy.get_hidden_states()
        # compute the actions and values
        self.transition.actions = self.policy.act(obs).detach()
        self.transition.values = self.policy.evaluate(obs).detach()
        self.transition.actions_log_prob = self.policy.get_actions_log_prob(self.transition.actions).detach()
        self.transition.action_mean = self.policy.action_mean.detach()
        self.transition.action_sigma = self.policy.action_std.detach()
        # need to record obs before env.step()
        self.transition.observations = obs
        return self.transition.actions

    def process_env_step(self, obs, rewards, dones, extras):
        """Complete the current transition with the environment's response and store it.

        Updates the observation normalizers (policy and RND) with ``obs``, records rewards and
        dones, adds RND intrinsic rewards, bootstraps time-outs with ``gamma * value``, pushes the
        transition to storage, clears it and resets the policy for ``dones``.

        Args:
            obs: Observations returned by the environment step (presumably post-step).
            rewards: Extrinsic rewards; cloned before being modified.
            dones: Done flags.
            extras: Extra info; an optional ``"time_outs"`` entry triggers bootstrapping.
        """
        # update the normalizers
        self.policy.update_normalization(obs)
        if self.rnd:
            self.rnd.update_normalization(obs)

        # Record the rewards and dones
        # Note: we clone here because later on we bootstrap the rewards based on timeouts
        self.transition.rewards = rewards.clone()
        self.transition.dones = dones

        # Compute the intrinsic rewards and add to extrinsic rewards
        if self.rnd:
            # Compute the intrinsic rewards
            self.intrinsic_rewards = self.rnd.get_intrinsic_reward(obs)
            # Add intrinsic rewards to extrinsic rewards
            self.transition.rewards += self.intrinsic_rewards

        # Bootstrapping on time outs
        # NOTE(review): assumes values are (num_envs, 1) and time_outs are (num_envs,) — confirm.
        if "time_outs" in extras:
            self.transition.rewards += self.gamma * torch.squeeze(
                self.transition.values * extras["time_outs"].unsqueeze(1).to(self.device), 1
            )

        # record the transition
        self.storage.add_transitions(self.transition)
        self.transition.clear()
        self.policy.reset(dones)

    def compute_returns(self, obs):
        """Compute returns/advantages in storage, bootstrapping from the value of ``obs``.

        Advantages are normalized over the whole rollout here unless per-mini-batch
        normalization is enabled (then :meth:`update` normalizes them instead).
        """
        # compute value for the last step
        last_values = self.policy.evaluate(obs).detach()
        self.storage.compute_returns(
            last_values, self.gamma, self.lam, normalize_advantage=not self.normalize_advantage_per_mini_batch
        )

    def update(self):  # noqa: C901
        """Optimize the policy (and RND predictor) on the rollout currently in storage.

        For every mini-batch of every epoch this computes the clipped surrogate loss, the
        (optionally clipped) value loss, an entropy bonus and, when configured, the symmetry
        (mirror) and RND losses. With the adaptive schedule the learning rate is adapted from
        the KL divergence. An optimizer step is skipped on all ranks when any rank sees a
        non-finite loss. The storage is cleared at the end.

        Returns:
            Losses averaged over ``num_learning_epochs * num_mini_batches`` updates, keyed by
            ``"value_function"``, ``"surrogate"``, ``"entropy"`` and, when enabled, ``"rnd"``
            and ``"symmetry"``.
        """
        mean_value_loss = 0
        mean_surrogate_loss = 0
        mean_entropy = 0
        # -- RND loss
        mean_rnd_loss = 0 if self.rnd else None
        # -- Symmetry loss
        mean_symmetry_loss = 0 if self.symmetry else None

        # generator for mini batches
        if self.policy.is_recurrent:
            generator = self.storage.recurrent_mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
        else:
            generator = self.storage.mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)

        # iterate over batches
        for (
            obs_batch,
            actions_batch,
            target_values_batch,
            advantages_batch,
            returns_batch,
            old_actions_log_prob_batch,
            old_mu_batch,
            old_sigma_batch,
            hid_states_batch,
            masks_batch,
        ) in generator:
            # original batch size
            # we assume policy group is always there and needs augmentation
            original_batch_size = obs_batch.batch_size[0]

            # check if we should normalize advantages per mini batch
            if self.normalize_advantage_per_mini_batch:
                with torch.no_grad():
                    advantages_batch = (advantages_batch - advantages_batch.mean()) / (advantages_batch.std() + 1e-8)

            # Perform symmetric augmentation
            if self.symmetry and self.symmetry["use_data_augmentation"]:
                # augmentation using symmetry
                data_augmentation_func = self.symmetry["data_augmentation_func"]
                # returned shape: [batch_size * num_aug, ...]
                obs_batch, actions_batch = data_augmentation_func(
                    obs=obs_batch,
                    actions=actions_batch,
                    env=self.symmetry["_env"],
                )
                # compute number of augmentations per sample
                num_aug = int(obs_batch.batch_size[0] / original_batch_size)
                # repeat the rest of the batch
                # -- actor
                old_actions_log_prob_batch = old_actions_log_prob_batch.repeat(num_aug, 1)
                # -- critic
                target_values_batch = target_values_batch.repeat(num_aug, 1)
                advantages_batch = advantages_batch.repeat(num_aug, 1)
                returns_batch = returns_batch.repeat(num_aug, 1)

            # Recompute actions log prob and entropy for current batch of transitions
            # Note: we need to do this because we updated the policy with the new parameters
            # -- actor
            self.policy.act(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[0])
            actions_log_prob_batch = self.policy.get_actions_log_prob(actions_batch)
            # -- critic
            value_batch = self.policy.evaluate(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[1])
            # -- entropy
            # we only keep the entropy of the first augmentation (the original one)
            mu_batch = self.policy.action_mean[:original_batch_size]
            sigma_batch = self.policy.action_std[:original_batch_size]
            entropy_batch = self.policy.entropy[:original_batch_size]

            # KL-adaptive learning rate
            if self.desired_kl is not None and self.schedule == "adaptive":
                self._update_learning_rate(mu_batch, sigma_batch, old_mu_batch, old_sigma_batch)

            # Surrogate loss (clamp log-ratio to prevent exp overflow → Inf → NaN)
            log_ratio = actions_log_prob_batch - torch.squeeze(old_actions_log_prob_batch)
            log_ratio = torch.clamp(log_ratio, -20.0, 20.0)
            ratio = torch.exp(log_ratio)
            surrogate = -torch.squeeze(advantages_batch) * ratio
            surrogate_clipped = -torch.squeeze(advantages_batch) * torch.clamp(
                ratio, 1.0 - self.clip_param, 1.0 + self.clip_param
            )
            surrogate_loss = torch.max(surrogate, surrogate_clipped).mean()

            # Value function loss
            if self.use_clipped_value_loss:
                value_clipped = target_values_batch + (value_batch - target_values_batch).clamp(
                    -self.clip_param, self.clip_param
                )
                value_losses = (value_batch - returns_batch).pow(2)
                value_losses_clipped = (value_clipped - returns_batch).pow(2)
                value_loss = torch.max(value_losses, value_losses_clipped).mean()
            else:
                value_loss = (returns_batch - value_batch).pow(2).mean()

            loss = surrogate_loss + self.value_loss_coef * value_loss - self.entropy_coef * entropy_batch.mean()

            # Symmetry loss
            if self.symmetry:
                # obtain the symmetric actions
                # if we did augmentation before then we don't need to augment again
                if not self.symmetry["use_data_augmentation"]:
                    data_augmentation_func = self.symmetry["data_augmentation_func"]
                    obs_batch, _ = data_augmentation_func(obs=obs_batch, actions=None, env=self.symmetry["_env"])

                # actions predicted by the actor for symmetrically-augmented observations
                mean_actions_batch = self.policy.act_inference(obs_batch.detach().clone())

                # compute the symmetrically augmented actions
                # note: we are assuming the first augmentation is the original one.
                #   We do not use the action_batch from earlier since that action was sampled from the distribution.
                #   However, the symmetry loss is computed using the mean of the distribution.
                action_mean_orig = mean_actions_batch[:original_batch_size]
                _, actions_mean_symm_batch = data_augmentation_func(
                    obs=None, actions=action_mean_orig, env=self.symmetry["_env"]
                )

                # compute the loss (we skip the first augmentation as it is the original one)
                mse_loss = torch.nn.MSELoss()
                symmetry_loss = mse_loss(
                    mean_actions_batch[original_batch_size:], actions_mean_symm_batch.detach()[original_batch_size:]
                )
                # add the loss to the total loss
                if self.symmetry["use_mirror_loss"]:
                    loss += self.symmetry["mirror_loss_coeff"] * symmetry_loss
                else:
                    symmetry_loss = symmetry_loss.detach()

            # Random Network Distillation loss
            # TODO: Move this processing to inside RND module.
            if self.rnd:
                # extract the rnd_state
                # TODO: Check if we still need torch no grad. It is just an affine transformation.
                with torch.no_grad():
                    rnd_state_batch = self.rnd.get_rnd_state(obs_batch[:original_batch_size])
                    rnd_state_batch = self.rnd.state_normalizer(rnd_state_batch)
                # predict the embedding and the target
                predicted_embedding = self.rnd.predictor(rnd_state_batch)
                target_embedding = self.rnd.target(rnd_state_batch).detach()
                # compute the loss as the mean squared error
                mseloss = torch.nn.MSELoss()
                rnd_loss = mseloss(predicted_embedding, target_embedding)

            # Compute the gradients
            # -- For PPO
            self.optimizer.zero_grad()
            loss.backward()
            # -- For RND
            if self.rnd:
                self.rnd_optimizer.zero_grad()  # type: ignore
                rnd_loss.backward()

            # Skip the optimizer step on non-finite losses (prevents corrupting all parameters).
            # The decision is agreed across ranks: if only some GPUs skipped, they would not enter
            # the gradient all-reduce in reduce_parameters() and the collectives would desync/hang.
            if self.rnd:
                step_is_safe = self._losses_are_finite(loss, rnd_loss)
            else:
                step_is_safe = self._losses_are_finite(loss)

            if not step_is_safe:
                self.optimizer.zero_grad()
                if self.rnd_optimizer:
                    self.rnd_optimizer.zero_grad()
            else:
                # Collect gradients from all GPUs
                if self.is_multi_gpu:
                    self.reduce_parameters()
                # Apply the gradients
                # -- For PPO
                nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                self.optimizer.step()
                # -- For RND
                if self.rnd_optimizer:
                    self.rnd_optimizer.step()

            # Store the losses
            mean_value_loss += value_loss.item()
            mean_surrogate_loss += surrogate_loss.item()
            mean_entropy += entropy_batch.mean().item()
            # -- RND loss
            if mean_rnd_loss is not None:
                mean_rnd_loss += rnd_loss.item()
            # -- Symmetry loss
            if mean_symmetry_loss is not None:
                mean_symmetry_loss += symmetry_loss.item()

        # -- For PPO
        num_updates = self.num_learning_epochs * self.num_mini_batches
        mean_value_loss /= num_updates
        mean_surrogate_loss /= num_updates
        mean_entropy /= num_updates
        # -- For RND
        if mean_rnd_loss is not None:
            mean_rnd_loss /= num_updates
        # -- For Symmetry
        if mean_symmetry_loss is not None:
            mean_symmetry_loss /= num_updates
        # -- Clear the storage
        self.storage.clear()

        # construct the loss dictionary
        loss_dict = {
            "value_function": mean_value_loss,
            "surrogate": mean_surrogate_loss,
            "entropy": mean_entropy,
        }
        if self.rnd:
            loss_dict["rnd"] = mean_rnd_loss
        if self.symmetry:
            loss_dict["symmetry"] = mean_symmetry_loss

        return loss_dict

    # ------------------------------------------------------------------
    # Helper functions
    # ------------------------------------------------------------------

    def _update_learning_rate(self, mu_batch, sigma_batch, old_mu_batch, old_sigma_batch):
        """Adapt ``self.learning_rate`` from the KL divergence between the old and new policies.

        The KL(old || new) of the Gaussian action distributions (summed over the last axis) is
        averaged over the batch and, in multi-GPU mode, across ranks. On rank 0 the learning rate
        is divided by 1.5 (floor 1e-5) when the KL exceeds ``2 * desired_kl``, and multiplied by
        1.5 (cap 1e-2) when it is positive but below ``desired_kl / 2``. The result is broadcast
        from rank 0 (multi-GPU) and written to every optimizer parameter group.
        """
        with torch.inference_mode():
            kl = torch.sum(
                torch.log(sigma_batch / old_sigma_batch + 1.0e-5)
                + (torch.square(old_sigma_batch) + torch.square(old_mu_batch - mu_batch))
                / (2.0 * torch.square(sigma_batch))
                - 0.5,
                axis=-1,
            )
            kl_mean = torch.mean(kl)

            # Reduce the KL divergence across all GPUs
            if self.is_multi_gpu:
                torch.distributed.all_reduce(kl_mean, op=torch.distributed.ReduceOp.SUM)
                kl_mean /= self.gpu_world_size

            # Update the learning rate
            # Perform this adaptation only on the main process
            # TODO: Is this needed? If KL-divergence is the "same" across all GPUs,
            #       then the learning rate should be the same across all GPUs.
            if self.gpu_global_rank == 0:
                if kl_mean > self.desired_kl * 2.0:
                    self.learning_rate = max(1e-5, self.learning_rate / 1.5)
                elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0:
                    self.learning_rate = min(1e-2, self.learning_rate * 1.5)

            # Update the learning rate for all GPUs
            if self.is_multi_gpu:
                lr_tensor = torch.tensor(self.learning_rate, device=self.device)
                torch.distributed.broadcast(lr_tensor, src=0)
                self.learning_rate = lr_tensor.item()

            # Update the learning rate for all parameter groups
            for param_group in self.optimizer.param_groups:
                param_group["lr"] = self.learning_rate

    def _losses_are_finite(self, *losses):
        """Return True iff every given loss tensor is finite on every participating rank.

        In multi-GPU mode the local flag is all-reduced with MIN, so all ranks reach the same
        decision and stay in lockstep for the gradient all-reduce that may follow.
        """
        finite = torch.stack([torch.isfinite(value).all() for value in losses]).all()
        if self.is_multi_gpu:
            flag = finite.to(torch.int32)
            torch.distributed.all_reduce(flag, op=torch.distributed.ReduceOp.MIN)
            finite = flag.bool()
        return bool(finite.item())

    def broadcast_parameters(self):
        """Broadcast the policy (and RND predictor) parameters from rank 0 to all GPUs."""
        # obtain the model parameters on current GPU
        model_params = [self.policy.state_dict()]
        if self.rnd:
            model_params.append(self.rnd.predictor.state_dict())
        # broadcast the model parameters
        torch.distributed.broadcast_object_list(model_params, src=0)
        # load the model parameters on all GPUs from source GPU
        self.policy.load_state_dict(model_params[0])
        if self.rnd:
            self.rnd.predictor.load_state_dict(model_params[1])

    def reduce_parameters(self):
        """Average the gradients of all parameters that have one across all GPUs.

        Called after the backward pass. Gradients are flattened into one buffer, non-finite
        entries are replaced by 0, the buffer is all-reduced (sum / world size) and copied back
        into each parameter's ``.grad``.
        """
        # Create a tensor to store the gradients
        grads = [param.grad.view(-1) for param in self.policy.parameters() if param.grad is not None]
        if self.rnd:
            grads += [param.grad.view(-1) for param in self.rnd.parameters() if param.grad is not None]
        all_grads = torch.cat(grads)
        # Replace NaN/Inf gradients with 0 before reduction (prevents one GPU poisoning all others)
        all_grads = torch.nan_to_num(all_grads, nan=0.0, posinf=0.0, neginf=0.0)

        # Average the gradients across all GPUs
        torch.distributed.all_reduce(all_grads, op=torch.distributed.ReduceOp.SUM)
        all_grads /= self.gpu_world_size

        # Get all parameters (same order as used to build the buffer above)
        all_params = self.policy.parameters()
        if self.rnd:
            all_params = chain(all_params, self.rnd.parameters())

        # Update the gradients for all parameters with the reduced gradients
        offset = 0
        for param in all_params:
            if param.grad is not None:
                numel = param.numel()
                # copy data back from shared buffer
                param.grad.data.copy_(all_grads[offset : offset + numel].view_as(param.grad.data))
                # update the offset for the next parameter
                offset += numel