
Gradient synchronization in data-parallel trainers #12

@cgarciae

Description


Hey, great job with nanodl!

I was just looking through the code and noticed that in LaMDA's Trainer the gradients are not being averaged across devices here:

https://github.com/HMUNACHI/nanodl/blob/18c7f8e3da3c0bbfe2df3638a5e87857ec84868d/nanodl/__src/models/lamda.py#L564-L565

Not sure if this is happening elsewhere, but to keep the weights in sync you usually apply jax.lax.pmean over the gradients before passing them to apply_gradients, e.g.

grads = jax.lax.pmean(grads, axis_name='devices')
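For reference, here is a minimal sketch of the usual pattern inside a pmapped train step. The loss function, batch structure, and names below are illustrative, not nanodl's actual code; the key point is the pmean before apply_gradients, with the axis_name matching the one given to jax.pmap:

```python
import jax
import jax.numpy as jnp
from flax.training import train_state


def train_step(state: train_state.TrainState, batch):
    # Hypothetical loss: a simple squared error over illustrative batch keys.
    def loss_fn(params):
        preds = state.apply_fn({'params': params}, batch['inputs'])
        return jnp.mean((preds - batch['targets']) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(state.params)

    # Average gradients (and metrics) across devices so every replica
    # applies the same update and the replicated params stay in sync.
    grads = jax.lax.pmean(grads, axis_name='devices')
    loss = jax.lax.pmean(loss, axis_name='devices')

    state = state.apply_gradients(grads=grads)
    return state, loss


# axis_name here must match the one used in lax.pmean above.
p_train_step = jax.pmap(train_step, axis_name='devices')
```

Without the pmean, each device updates its parameters with only its local gradients, so the per-device parameter copies drift apart over training.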
