Commit 08f4845

Update stage_1_and_2.py
1 parent ceb84ba commit 08f4845

1 file changed: +6 -4 lines changed


deepspeed/runtime/zero/stage_1_and_2.py

Lines changed: 6 additions & 4 deletions
@@ -283,11 +283,15 @@ def _enforce_cpu_offload():
 
         self.low_precision_master_weights_and_grads = self.master_weights_and_grads_dtype != torch.float32
 
+        # Check for Muon optimizer usage
+        self.uses_muon = any(getattr(param, 'use_muon', False) for group in self.optimizer.param_groups for param in group['params'])
+
         if self.reduce_scatter and self.partition_gradients:
             valid_reduce_scatter_dtypes = (torch.float16, torch.bfloat16, torch.float32)
             assert self.communication_data_type in valid_reduce_scatter_dtypes, f"{self.zero_stage_string} supports {valid_reduce_scatter_dtypes} communication_data_type with reduce scatter enabled. Got: '{self.communication_data_type}'"
             assert self.gradient_predivide_factor == 1.0, f"gradient_predivide_factor != 1.0 is not yet supported with {self.zero_stage_string} with reduce scatter enabled"
             assert self.postscale_gradients, f"pre-scale gradients is not yet supported with {self.zero_stage_string} with reduce scatter enabled"
+            assert not self.uses_muon, f"{self.zero_stage_string} with reduce_scatter=True is incompatible with Muon optimizer. Please disable reduce_scatter or use a different optimizer."
 
         # param flattened by groups
         self.bit16_groups = []
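
This hunk moves Muon detection to initialization: the engine scans the wrapped optimizer's param_groups once for parameters tagged with use_muon, caches the result as self.uses_muon, and the new assert rejects the reduce_scatter + Muon combination up front. Below is a minimal standalone sketch of that detection pattern; the toy SGD optimizer and the manual use_muon tagging are illustrative assumptions, standing in for whatever Muon integration sets the flag in practice.

# Minimal sketch of the cached detection pattern, outside DeepSpeed.
# Assumption: a Muon-aware setup tags matrix parameters with `use_muon`.
import torch

params = [torch.nn.Parameter(torch.randn(4, 4)) for _ in range(3)]
params[1].use_muon = True  # hypothetical tagging done by a Muon integration

optimizer = torch.optim.SGD(params, lr=0.1)

# Same scan the commit adds to __init__: any flagged parameter in any
# param group means the Muon-safe communication path is needed.
uses_muon = any(
    getattr(param, 'use_muon', False)
    for group in optimizer.param_groups
    for param in group['params']
)
print(uses_muon)  # True: at least one parameter is flagged for Muon
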
@@ -1187,10 +1191,8 @@ def average_tensor(self, tensor: torch.Tensor, communication_data_type: torch.dtype):
         stream = get_accelerator().current_stream()
 
         with get_accelerator().stream(stream):
-            # Check if any parameter uses Muon optimizer (needs full gradient for orthogonalization)
-            uses_muon = any(getattr(param, 'use_muon', False) for group in self.bit16_groups for param in group)
-
-            if not self.reduce_scatter or uses_muon:
+            # Use pre-detected Muon flag from initialization
+            if not self.reduce_scatter or self.uses_muon:
                 # Force full all-reduce for Muon parameters even when reduce_scatter is enabled
                 self.gradient_reduction_w_predivide(tensor, communication_data_type)
                 return
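
This hunk replaces the per-call scan over self.bit16_groups with the flag cached at initialization, so average_tensor no longer re-walks every parameter on each gradient reduction. The reason Muon needs the full all-reduce path, per the comment in the diff: its update orthogonalizes each 2D gradient matrix (via a Newton-Schulz iteration), so every rank needs the complete gradient, whereas reduce-scatter leaves each rank only its own partition. A hedged sketch contrasting the two collectives follows; reduce_gradient is a hypothetical helper, not DeepSpeed's implementation, and it assumes an initialized process group and a gradient whose element count divides evenly by the world size.

# Sketch contrasting the two reduction paths the commit dispatches between.
# Assumes torch.distributed is initialized and grad.numel() % world_size == 0.
import torch
import torch.distributed as dist

def reduce_gradient(grad: torch.Tensor, uses_muon: bool) -> torch.Tensor:
    world_size = dist.get_world_size()
    if uses_muon:
        # Full all-reduce: every rank ends up holding the complete averaged
        # gradient matrix, which Muon's orthogonalization step requires.
        dist.all_reduce(grad)
        return grad / world_size
    # ZeRO-style reduce-scatter: each rank keeps only its 1/world_size shard,
    # enough for elementwise optimizers like Adam but not for Muon.
    flat = grad.flatten()
    shard = torch.empty(flat.numel() // world_size, dtype=grad.dtype, device=grad.device)
    dist.reduce_scatter_tensor(shard, flat)
    return shard / world_size
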
