
Commit 1a693b5

Authored and committed by fy817

Fix Muon optimizer conflict with gradient clipping in ZeRO 1/2

Signed-off-by: fy817 <277645218@qq.com>
1 parent: 8a9369d

File tree: 1 file changed (+50 −6 lines)


deepspeed/runtime/zero/stage_1_and_2.py

Lines changed: 50 additions & 6 deletions
@@ -860,6 +860,32 @@ def independent_gradient_partition_epilogue(self):
             self._clear_previous_reduced_grads()
 
         if self.cpu_offload is False:
+            # Pre-compute gradient norm for Muon clipping if needed
+            grad_norm_for_muon = None
+            if self.is_gradient_accumulation_boundary:
+                # Check if any parameter group uses Muon
+                uses_muon = False
+                for i, _ in enumerate(self.bit16_groups):
+                    if len(self.params_in_partition[i]) > 0 and getattr(self.params_in_partition[i][0], 'use_muon', False):
+                        uses_muon = True
+                        break
+
+                # Compute unscaled gradient norm if Muon is used and clipping is enabled
+                if uses_muon and self.clip_grad > 0.:
+                    # Compute gradient norm before Muon update
+                    norm_groups = []
+                    for i, group in enumerate(self.bit16_groups):
+                        if not i in self.averaged_gradients or self.averaged_gradients[i] is None:
+                            all_grad_tensors = self.get_all_grad_tensors(self.params_in_partition[i],
+                                                                         dtype=self.gradient_accumulation_dtype)
+                        else:
+                            all_grad_tensors = self.all_grad_tensors[i]
+                        if all_grad_tensors is not None:
+                            norm_groups.append(self.get_grad_norm_direct(all_grad_tensors, self.params_in_partition[i]))
+
+                    if len(norm_groups) > 0:
+                        grad_norm_for_muon = torch.linalg.vector_norm(torch.stack(norm_groups), ord=2)
+
             for i, _ in enumerate(self.bit16_groups):
                 if i not in self.all_grad_tensors or self.all_grad_tensors[i] is None:
                     self.all_grad_tensors[i] = self.get_all_grad_tensors(self.params_in_partition[i],
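
Note on the norm reduction above: the added block collects one norm per parameter group and combines them with torch.linalg.vector_norm over the stacked values, i.e. an L2 norm of the per-group L2 norms. A minimal standalone sketch of the same math with toy tensors (not DeepSpeed state):

import torch

# Toy per-group gradient norms, standing in for get_grad_norm_direct() results.
norm_groups = [torch.tensor(3.0), torch.tensor(4.0)]

# Same reduction as the added code: L2 norm of the per-group L2 norms,
# i.e. sqrt(3**2 + 4**2) == 5.0.
grad_norm_for_muon = torch.linalg.vector_norm(torch.stack(norm_groups), ord=2)
print(grad_norm_for_muon)  # tensor(5.)
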
@@ -877,7 +903,8 @@ def independent_gradient_partition_epilogue(self):
                                                                      dtype=self.gradient_accumulation_dtype,
                                                                      device=get_accelerator().current_device_name(),
                                                                      param_group_idx=i,
-                                                                     return_tensor_list=True)
+                                                                     return_tensor_list=True,
+                                                                     grad_norm=grad_norm_for_muon)
                 self.all_grad_tensors[i] = None
 
         self._release_ipg_buffers()
@@ -1894,7 +1921,8 @@ def get_flat_partition(self,
                            dtype,
                            device,
                            param_group_idx,
-                           return_tensor_list=False):
+                           return_tensor_list=False,
+                           grad_norm=None):
         if len(tensor_list) == 0:
             # This condition can fire when we have small parameteters and many ranks.
             zero_buffer = torch.zeros(int(partition_size), dtype=dtype, device=device)
@@ -1916,11 +1944,22 @@ def get_flat_partition(self,
             flatten_bf_list = [torch.zeros([total_size], dtype=dtype, device=device)]
             self.optimizer.state[flatten_copy]["momentum_buffer"] = self.flatten(flatten_bf_list)
 
+        # Calculate clip factor if gradient clipping is enabled and grad_norm is provided
+        clip_factor = 1.0
+        if self.clip_grad > 0. and grad_norm is not None:
+            # grad_norm is already unscaled (divided by loss_scale)
+            clip_factor = max(1.0, grad_norm / self.clip_grad)
+
         buffer_idx = 0
         for i, tensor in enumerate(tensor_list):
             grad_accum = self.all_grad_tensors[param_group_idx][i]
             if getattr(tensor, 'use_muon', False) and 'muon' in self.optimizer.__class__.__name__.lower():
                 assert tensor.ndim > 1, f"if use muon, then tensor dim > 1, got {tensor.size()}"
+
+                # Apply gradient clipping before muon_update
+                if clip_factor > 1.0:
+                    grad_accum = grad_accum / clip_factor
+
                 buffer = torch.narrow(self.optimizer.state[flatten_copy]["momentum_buffer"], 0, buffer_idx,
                                       tensor.numel()).view(tensor.size())
                 grad_accum = muon_update(grad_accum, buffer, self.optimizer.param_groups[param_group_idx]['momentum'])
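
The key ordering in this hunk is that the accumulated gradient is clipped before it feeds the Muon momentum buffer, which is why the later unscale_and_clip_grads call can skip clipping for these groups. A minimal sketch of that ordering, using a plain momentum accumulation as a simplified stand-in for DeepSpeed's muon_update (the real update also orthogonalizes the result; only the clip-then-momentum ordering is illustrated):

import torch

def clip_then_momentum(grad, buffer, momentum, grad_norm, clip_grad):
    # Same clip factor as the diff: shrink only, never amplify.
    clip_factor = max(1.0, grad_norm / clip_grad) if clip_grad > 0. else 1.0
    if clip_factor > 1.0:
        grad = grad / clip_factor  # clip BEFORE updating the momentum buffer
    buffer.mul_(momentum).add_(grad)  # simplified momentum step, stand-in for muon_update
    return buffer

grad = torch.full((2, 2), 10.0)
buffer = torch.zeros(2, 2)
out = clip_then_momentum(grad, buffer, momentum=0.95, grad_norm=20.0, clip_grad=1.0)
print(out)  # every entry is 10.0 / 20 = 0.5: the clipped gradient seeded the buffer
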
@@ -2058,15 +2097,20 @@ def step(self, closure=None):
         see_memory_usage('Before norm calculation')
         scaled_global_grad_norm = self.scaled_global_norm()
         self._global_grad_norm = scaled_global_grad_norm / prev_scale
+        unscaled_grad_norm = self._global_grad_norm  # Store unscaled norm for use in get_flat_partition
         see_memory_usage('After norm before optimizer')
 
         # Step 2:- run optimizer and upscaling simultaneously
         for i, group in enumerate(self.bit16_groups):
             self.timers(OPTIMIZER_GRADIENTS_TIMER).start()
             partition_id = dist.get_rank(group=self.real_dp_process_group[i])
+
+            # Check if this param group uses Muon (clipping already done in get_flat_partition)
+            uses_muon = len(self.params_in_partition[i]) > 0 and getattr(self.params_in_partition[i][0], 'use_muon', False)
+
             if self.cpu_offload:
                 single_grad_partition = self.single_partition_of_fp32_groups[i].grad
-                self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm)
+                self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm, skip_clipping=uses_muon)
 
                 self.timers(OPTIMIZER_GRADIENTS_TIMER).stop()
                 self.timers(OPTIMIZER_STEP_TIMER).start()
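
For illustration, the uses_muon flag added here inspects only the first parameter of the partition, relying on each group being tagged homogeneously. A tiny standalone example of the same getattr check (the use_muon attribute is assumed to be set elsewhere by the Muon integration; the parameters here are toys):

import torch

matrix_param = torch.nn.Parameter(torch.randn(4, 4))
matrix_param.use_muon = True  # hypothetical tagging done by the Muon integration
bias_param = torch.nn.Parameter(torch.randn(4))

params_in_partition = [[matrix_param], [bias_param]]
for i, group in enumerate(params_in_partition):
    uses_muon = len(group) > 0 and getattr(group[0], 'use_muon', False)
    print(f"group {i}: skip_clipping={uses_muon}")  # group 0: True, group 1: False
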
@@ -2108,7 +2152,7 @@ def step(self, closure=None):
 
                 self.averaged_gradients[i] = None
                 self.all_grad_tensors[i] = None
-                self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm)
+                self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm, skip_clipping=uses_muon)
 
                 self.timers(OPTIMIZER_GRADIENTS_TIMER).stop()
 
@@ -2168,10 +2212,10 @@ def _average_expert_grad_norms(self, norm_groups):
             dist.all_reduce(scaled_norm_tensor, group=self.real_dp_process_group[i])
             norm_groups[i] = scaled_norm_tensor.to(self.device)
 
-    def unscale_and_clip_grads(self, grad_groups_flat, total_norm):
+    def unscale_and_clip_grads(self, grad_groups_flat, total_norm, skip_clipping=False):
         # compute combined scale factor for this group
         combined_scale = self.loss_scale
-        if self.clip_grad > 0.:
+        if self.clip_grad > 0. and not skip_clipping:
             # norm is in fact norm*scale
             clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
             clip = torch.clamp(clip, min=1.0)
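
To make the skip_clipping semantics concrete: when clipping is skipped, only the loss-scale division remains, because Muon groups were already clipped inside get_flat_partition. A hedged standalone restatement of the scale computation (plain floats instead of optimizer attributes; the final combination combined_scale = clip * loss_scale is not shown in the hunk and is assumed here):

import torch

def combined_scale(total_norm, loss_scale, clip_grad, skip_clipping=False):
    # Gradients are later divided by the returned value.
    scale = loss_scale
    if clip_grad > 0. and not skip_clipping:
        # total_norm still carries the loss scale at this point
        clip = torch.clamp(torch.tensor((total_norm / loss_scale + 1e-6) / clip_grad), min=1.0)
        scale = clip * loss_scale  # assumed combination, mirroring the surrounding method
    return scale

print(combined_scale(total_norm=2048.0, loss_scale=1024.0, clip_grad=1.0))                      # ~2048: unscale and clip
print(combined_scale(total_norm=2048.0, loss_scale=1024.0, clip_grad=1.0, skip_clipping=True))  # 1024.0: unscale only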
