Reduce peak VRAM in moe.py by avoiding cached router indices #464
shedrachIkenna wants to merge 1 commit into meta-llama:main from
Conversation
Made the forward pass more memory efficient by reshaping and expanding router_indices on the fly, twice, instead of keeping the expanded tensor in a variable that sat in memory unused between the gather and scatter steps.
Hi @shedrachIkenna! Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!
Currently, the expanded router indices in the moe.py forward pass are computed on the fly for the gather operation, but they are recomputed and assigned to a variable before the scatter_add_ operation. Holding that index tensor in memory while the experts are executing increases peak memory pressure during the most memory-intensive part of the layer (the expert MLPs).
This PR switches to computing the indices on the fly for both the gather and scatter_add_ operations. That way the index tensor can be garbage collected while the experts are running, which reduces the peak memory usage of the forward pass at negligible compute overhead.
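For illustration, here is a minimal PyTorch sketch of the pattern, assuming a gather → experts → scatter_add_ dispatch like the one described above. The shapes and names (`hidden_states`, `router_indices`, `expert_mlp`) are hypothetical stand-ins, not the actual moe.py code:

```python
import torch

num_tokens, hidden_dim, top_k = 8, 16, 2
hidden_states = torch.randn(num_tokens, hidden_dim)

# Flat token ids chosen by the router, one per (token, expert-slot) pair.
router_indices = torch.randint(0, num_tokens, (num_tokens * top_k,))

def expert_mlp(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the expert MLPs, the memory-intensive part of the layer.
    return torch.relu(x)

# Gather: build the expanded index tensor inline, so no named variable
# keeps a reference to it once torch.gather returns.
routed_in = torch.gather(
    hidden_states,
    0,
    router_indices.reshape(-1, 1).expand(-1, hidden_dim),
)

# While the experts execute (the peak-memory region), the expanded
# indices above are unreferenced and can be reclaimed by the allocator.
routed_out = expert_mlp(routed_in)

# Scatter: recompute the same expanded indices on the fly.
output = torch.zeros_like(hidden_states)
output.scatter_add_(
    0,
    router_indices.reshape(-1, 1).expand(-1, hidden_dim),
    routed_out,
)
```

The trade-off is recomputing a cheap indexing expression twice in exchange for not keeping any reference alive across the expert computation, so the allocator can release the tensor during the memory peak; since `reshape(-1, 1).expand(-1, hidden_dim)` produces a view rather than a copy, the recomputation itself is essentially free.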