[Draft] Muon Optimizer Support for ZeRO3 #7798
PKUWZP wants to merge 15 commits into deepspeedai:master from
Conversation
…s non contiguous version + test + everything else
PKUWZP
left a comment
@pengdurice Also two more comments:
- It seems that we have excessive tensor allocations: multiple torch.empty, torch.zeros, and .clone() calls create memory pressure. Consider reusing buffers where possible.
- Synchronous all_gather: the distributed operations could potentially be overlapped with computation.
I think we need to rework the PR; let's take some time to refine the code.
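As a sketch of the second point above: reusing a preallocated gather buffer and launching the collective with async_op=True lets independent work overlap with communication. The names and the placeholder computation below are illustrative only, not the PR's code.

```python
# Illustrative sketch only (not the PR's code): reuse preallocated gather buffers
# and overlap the all_gather with independent computation via async_op=True.
import torch
import torch.distributed as dist

def gather_momentum_with_overlap(local_momentum, local_grad, gathered_buffers):
    # gathered_buffers: a list of world_size preallocated tensors, reused across
    # steps so we avoid fresh torch.empty allocations on every call.
    work = dist.all_gather(gathered_buffers, local_momentum, async_op=True)

    # Do work that does not depend on the gathered momentum while the collective runs.
    scaled_grad = local_grad * 1.0  # placeholder computation

    work.wait()  # block only when the full momentum is actually needed
    return torch.cat(gathered_buffers), scaled_grad
```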
self.optimizer_swapper.swap_in_optimizer_state(parameter=self.fp32_partitioned_groups_flat[i])
for idx, dest_offset in params_to_subgroup_maps[i]:
    momentum_buffer[idx] = self.optimizer.state[self.fp32_partitioned_groups_flat[i]]["momentum_buffer"].narrow(0, dest_offset, param.partition_numel()).clone()
self.optimizer_swapper.swap_out_optimizer_state(parameter=self.fp32_partitioned_groups_flat[i])
@pengdurice Here is a bug. The variable param refers to the last parameter from the previous loop (for param in self.ipg_buckets[...]), not the parameter corresponding to idx. We should change it to use_muon_params[idx].partition_numel().
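A tiny self-contained illustration of the pitfall (not the PR's code): after the first loop ends, the loop variable keeps its last value, so reusing it inside a later loop silently applies the last element everywhere.

```python
# Stand-alone illustration of the stale loop-variable pitfall described above.
sizes = [3, 5, 7]
for param in sizes:          # `param` ends up bound to the last element, 7
    pass

buggy = [param for _ in range(len(sizes))]         # [7, 7, 7] -- wrong
fixed = [sizes[idx] for idx in range(len(sizes))]  # [3, 5, 7] -- index explicitly
print(buggy, fixed)
```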
yeah, good catch, fixed it.
self.dp_process_group = self.parameter_offload.dp_process_group
self.sequence_parallel_size = groups._get_sequence_parallel_world_size()

self.all2all_process_group = all2all_process_group
@pengdurice Question: where did we set up the all2all_process_group? It seems that it's never set.
It is doubly assigned; there is an identical line nearby, so I deleted this one.
# params_pad = params + [torch.empty_like(params[-1])] * (world_sz - len(params) % world_sz)
grads_pad = [param.grad for param in params] + [torch.empty_like(params[-1].grad)] * (world_sz - len(params) % world_sz)
gathered_momentums_pad = gathered_momentums + [torch.empty_like(gathered_momentums[-1])] * (world_sz - len(gathered_momentums) % world_sz)
for base_i in range(len(params))[::world_sz]:
@pengdurice There's a padding error here. When len(params) % world_sz == 0, this adds world_sz empty tensors instead of 0. Should we change it to: (world_sz - len(params) % world_sz) % world_sz ?
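A quick self-contained check of the proposed fix (world_sz and the list lengths below are made up):

```python
# The extra "% world_sz" makes the padding collapse to 0 when len(params)
# is already a multiple of world_sz, instead of adding world_sz empty tensors.
world_sz = 4
for n in (7, 8):  # 8 is the exactly-divisible case
    pad_current = world_sz - n % world_sz                # 1 for n=7, but 4 for n=8
    pad_proposed = (world_sz - n % world_sz) % world_sz  # 1 for n=7, 0 for n=8
    print(n, pad_current, pad_proposed)
```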
deepspeed/runtime/zero/stage3.py
Outdated
self.reduce_scatter = reduce_scatter

self.use_muon = 'muon' in self.optimizer.__class__.__name__.lower()
@pengdurice This is very fragile and depends purely on the class naming convention. Can we leverage isinstance() instead?
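A minimal sketch of what the isinstance() check could look like; Muon is a placeholder class name here, since the actual class is defined in deepspeed/runtime/zero/muon/muon_optimizer.py and may be named differently.

```python
# Sketch of a type-based check instead of matching on the class-name string.
# `Muon` is a hypothetical name; import the real optimizer class from
# deepspeed/runtime/zero/muon/muon_optimizer.py (actual name may differ).
try:
    from deepspeed.runtime.zero.muon.muon_optimizer import Muon  # hypothetical
except ImportError:
    Muon = None

def uses_muon(optimizer) -> bool:
    """Return True if the wrapped optimizer is (a subclass of) the Muon optimizer."""
    return Muon is not None and isinstance(optimizer, Muon)
```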
self.reduce_scatter = reduce_scatter

self.use_muon = 'muon' in self.optimizer.__class__.__name__.lower()
self.save_muon_momentum_buffer_in_memory = ds_config.get('save_muon_momentum_buffer_in_memory', False)
@pengdurice Can we add save_muon_momentum_buffer_in_memory to the config schema and document it?
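For reference, the flag is read from the top level of ds_config in the excerpt above, so a documented example config might look something like the following; the surrounding keys are just a minimal illustration.

```python
# Assumed placement based on the ds_config.get(...) call above: the flag sits at
# the top level of the DeepSpeed config. Surrounding keys are illustrative only.
ds_config = {
    "train_batch_size": 8,
    "zero_optimization": {"stage": 3},
    # Keep Muon momentum buffers resident in memory instead of swapping them
    # in and out with the NVMe-offloaded optimizer state (default False).
    "save_muon_momentum_buffer_in_memory": True,
}
```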
params_to_subgroup_maps[i].append((idx, dest_offset))
idx += 1
params_size_offset += param.grad.numel()
# if optimizer is swappable, swap in the momentum buffer of the parameters that need to be updated using muon and then swap them out
@pengdurice This doubles NVMe I/O overhead. Can we consider consolidating into a single swap in/out cycle?
The thought is that muon_update is also time consuming, and splitting the muon_update per subgroup may introduce its own overhead. It's worth evaluating which is better ;-)
gathered_params_momentums = self._partitioned_buffers_all_gather(use_muon_params, momentum_buffer, communication_data_type)
for i in params_to_subgroup_maps:
    if self._swappable_optimizer_subgroup(i) and not self.save_muon_momentum_buffer_in_memory:
        self.optimizer_swapper.swap_in_optimizer_state(parameter=self.fp32_partitioned_groups_flat[i])
@pengdurice Again same thing here, can we consolidate the two swaps into one swap?
Hi @pengdurice @PKUWZP, I have a question. I saw there is an option that saves the momentum buffer in memory, yet for the Adam optimizer there is no such option. Is that because for Adam this need is covered by ZeRO offload, while for the Muon optimizer ZeRO offload is not available yet, so this is used as a temporary solution? Thanks!
Hi, thank you for your question. For the Adam optimizer this is handled by the optimizer's own code (see DeepSpeed/deepspeed/runtime/zero/muon/muon_optimizer.py), and since Adam only does element-wise operations there is no need to handle it specially.
However, the Muon update needs cross-element operations, so we have to do it in this file. It is not about offload: the buffer can be on GPU or CPU depending on whether offload is enabled. Hope that answers your question!
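To make the "cross-element operations" point concrete: Muon orthogonalizes the 2-D momentum matrix with a Newton–Schulz style iteration built from matrix products, whereas Adam only needs per-element state. The sketch below is illustrative, not DeepSpeed's implementation; the step count and coefficients follow the commonly cited quintic Newton–Schulz iteration from the reference Muon code.

```python
# Illustrative Muon-style update for one 2-D parameter (not DeepSpeed's code).
# The matrix products in the orthogonalization are why the full momentum matrix
# must be gathered, unlike Adam's purely element-wise update.
import torch

def newton_schulz_orthogonalize(m: torch.Tensor, steps: int = 5) -> torch.Tensor:
    a, b, c = 3.4445, -4.7750, 2.0315  # coefficients from the reference Muon code
    x = m / (m.norm() + 1e-7)
    transposed = x.shape[0] > x.shape[1]
    if transposed:
        x = x.T
    for _ in range(steps):
        s = x @ x.T                    # cross-element: uses the whole matrix
        x = a * x + (b * s + c * (s @ s)) @ x
    return x.T if transposed else x

def muon_step(param, grad, momentum, lr=0.02, beta=0.95):
    momentum.mul_(beta).add_(grad)                  # element-wise momentum update
    update = newton_schulz_orthogonalize(momentum)  # needs the full 2-D matrix
    param.add_(update, alpha=-lr)
```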
Authors: @pengdurice @PKUWZP
We aim to add the Muon optimizer to ZeRO Stage 3 in this draft PR:
- The momentum buffers are kept in the optimizer state alongside self.fp32_partitioned_groups_flat; when device == NVME, we make sure that the momentum buffers can be swapped in and out along with the other components of the optimizer states.
- The momentum buffers are partitioned in the same way as self.fp32_partitioned_groups_flat to save memory footprint. So, before the Muon update, we need to perform all_gather on top of each data-parallel group rank. The Muon updates of the parameters are also divided across the data-parallel ranks, and the results are all-gathered once all updates are complete. After the all_gather, the momentum buffers are partitioned and flattened again.

Next steps: