We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 19397c5 commit aba6a8dCopy full SHA for aba6a8d
deepspeed/runtime/zero/stage_1_and_2.py
@@ -1439,9 +1439,9 @@ def average_tensor(
1439
stream = get_accelerator().current_stream()
1440
1441
with get_accelerator().stream(stream):
1442
- # Use pre-detected Muon flag from initialization
1443
- if not self.reduce_scatter or self.uses_muon:
1444
- # Force full all-reduce for Muon parameters even when reduce_scatter is enabled
+ # Check if current configuration requires full all-reduce
+ if not self.reduce_scatter or any(self.group_uses_muon):
+ # Force full all-reduce for Muon parameters or when reduce_scatter is disabled
1445
self.gradient_reduction_w_predivide(tensor, communication_data_type)
1446
return
1447
0 commit comments