Gradient explosion in (modded) lab 3 #19

@gmontamat

Description

I have not seen this behavior with the MNIST dataset, but with CIFAR-10 the gradients very commonly exploded after a few thousand epochs (the loss increased abruptly). Adding gradient clipping in the Trainer class mitigated the problem:

        # Train loop
        pbar = tqdm(range(num_epochs))
        for epoch in pbar:
            opt.zero_grad()
            loss = self.get_train_loss(**kwargs)
            loss.backward()
            # Clip the global gradient norm to 1.0 before the optimizer step
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            opt.step()
            pbar.set_description(f'Epoch {epoch}, loss: {loss.item():.3f}')

I'm wondering if there are other ways to mitigate this problem, such as modifying the UNet architecture, adding dropout layers, or trying different weight initializations.
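For what it's worth, two of these ideas can be sketched in a few lines of PyTorch. This is only an illustration, not code from the lab: the layer sizes and the dropout rate are made up, and whether Kaiming init or spatial dropout actually helps here would need to be tested on the actual UNet.

```python
import torch
import torch.nn as nn

def init_weights(m: nn.Module) -> None:
    """Kaiming-normal init for conv/linear weights, zero biases.

    Applied recursively via Module.apply(); layer types not listed
    here are left at their PyTorch defaults.
    """
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)

# Hypothetical stand-in for one UNet conv stage, with spatial dropout
# (Dropout2d zeroes whole channels) inserted as a light regularizer.
block = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Dropout2d(p=0.1),
    nn.Conv2d(16, 16, kernel_size=3, padding=1),
)
block.apply(init_weights)
```

A lower learning rate or a short warmup schedule might also be worth trying before touching the architecture, since clipping only treats the symptom.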
