I have not seen this behavior with the MNIST dataset, but gradients frequently exploded after a few thousand epochs on the CIFAR-10 dataset (the loss increased abruptly). Gradient clipping in the Trainer class mitigated the problem:
# Train loop
pbar = tqdm(enumerate(range(num_epochs)))
for idx, epoch in pbar:
    opt.zero_grad()
    loss = self.get_train_loss(**kwargs)
    loss.backward()
    # Clip gradients to a maximum L2 norm of 1.0 before the optimizer step
    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
    opt.step()
    pbar.set_description(f'Epoch {idx}, loss: {loss.item():.3f}')

I'm wondering if there are other ways to mitigate this problem, such as modifications to the UNet architecture, adding dropout layers, or even trying different initializations.
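On the initialization idea: a minimal sketch of what re-initializing the weights could look like, using `nn.Module.apply` to walk every submodule. The `SmallNet` class here is a hypothetical stand-in for the UNet (not from this repo), and the choice of Kaiming-normal init for convolutions is just one common option for ReLU networks, not a confirmed fix:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the UNet; the init logic is the part that matters.
class SmallNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.head = nn.Linear(16, 10)

    def forward(self, x):
        h = torch.relu(self.conv(x)).mean(dim=(2, 3))  # global average pool
        return self.head(h)

def init_weights(module):
    # Kaiming init for convs (suited to ReLU), small normal init for linears.
    if isinstance(module, nn.Conv2d):
        nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, std=0.02)
        nn.init.zeros_(module.bias)

model = SmallNet()
model.apply(init_weights)  # recursively applies init_weights to every submodule
```

Dropout could be added the same way, by inserting `nn.Dropout2d` layers between the UNet's conv blocks, though whether either change actually prevents the blow-up would need to be tested against CIFAR-10.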