Commit aba6a8d

Refactor gradient reduction logic for Muon parameters
1 parent 19397c5 commit aba6a8d

1 file changed: +3 -3 lines changed

deepspeed/runtime/zero/stage_1_and_2.py

Lines changed: 3 additions & 3 deletions
@@ -1439,9 +1439,9 @@ def average_tensor(
             stream = get_accelerator().current_stream()
 
         with get_accelerator().stream(stream):
-            # Use pre-detected Muon flag from initialization
-            if not self.reduce_scatter or self.uses_muon:
-                # Force full all-reduce for Muon parameters even when reduce_scatter is enabled
+            # Check if current configuration requires full all-reduce
+            if not self.reduce_scatter or any(self.group_uses_muon):
+                # Force full all-reduce for Muon parameters or when reduce_scatter is disabled
                 self.gradient_reduction_w_predivide(tensor, communication_data_type)
                 return
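
The changed condition can be read in isolation as a small predicate. The sketch below is a minimal illustration, not DeepSpeed code: `needs_full_all_reduce` is a hypothetical helper, and it assumes `group_uses_muon` is a sequence of booleans with one entry per parameter group, which is what the `any(...)` call in the diff suggests.

```python
from typing import Sequence


def needs_full_all_reduce(reduce_scatter: bool, group_uses_muon: Sequence[bool]) -> bool:
    """Hypothetical standalone mirror of the condition on the '+' line.

    Assumption: group_uses_muon holds one boolean per parameter group.
    """
    # Full all-reduce is taken when reduce-scatter is disabled, or when
    # at least one parameter group is flagged as using Muon.
    return (not reduce_scatter) or any(group_uses_muon)


if __name__ == "__main__":
    # reduce_scatter enabled, no Muon groups: reduce-scatter path is kept.
    print(needs_full_all_reduce(True, [False, False]))
    # reduce_scatter enabled, one Muon group: full all-reduce is forced.
    print(needs_full_all_reduce(True, [False, True]))
```

Under this reading the check does not depend on the tensor being reduced: a single Muon group is enough to route every call through `gradient_reduction_w_predivide`. An empty list makes `any(...)` return `False`, so the `reduce_scatter` setting alone decides in that case.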
