@@ -860,6 +860,32 @@ def independent_gradient_partition_epilogue(self):
         self._clear_previous_reduced_grads()
 
         if self.cpu_offload is False:
+            # Pre-compute the gradient norm for Muon clipping if needed
+            grad_norm_for_muon = None
+            if self.is_gradient_accumulation_boundary:
+                # Check whether any parameter group uses Muon
+                uses_muon = False
+                for i, _ in enumerate(self.bit16_groups):
+                    if len(self.params_in_partition[i]) > 0 and getattr(self.params_in_partition[i][0], 'use_muon', False):
+                        uses_muon = True
+                        break
+
+                # Compute the unscaled gradient norm if Muon is used and clipping is enabled
+                if uses_muon and self.clip_grad > 0.:
+                    # Compute the gradient norm before the Muon update
+                    norm_groups = []
+                    for i, group in enumerate(self.bit16_groups):
+                        if i not in self.averaged_gradients or self.averaged_gradients[i] is None:
+                            all_grad_tensors = self.get_all_grad_tensors(self.params_in_partition[i],
+                                                                         dtype=self.gradient_accumulation_dtype)
+                        else:
+                            all_grad_tensors = self.all_grad_tensors[i]
+                        if all_grad_tensors is not None:
+                            norm_groups.append(self.get_grad_norm_direct(all_grad_tensors, self.params_in_partition[i]))
+
+                    if len(norm_groups) > 0:
+                        grad_norm_for_muon = torch.linalg.vector_norm(torch.stack(norm_groups), ord=2)
+
             for i, _ in enumerate(self.bit16_groups):
                 if i not in self.all_grad_tensors or self.all_grad_tensors[i] is None:
                     self.all_grad_tensors[i] = self.get_all_grad_tensors(self.params_in_partition[i],
@@ -877,7 +903,8 @@ def independent_gradient_partition_epilogue(self):
                                                                      dtype=self.gradient_accumulation_dtype,
                                                                      device=get_accelerator().current_device_name(),
                                                                      param_group_idx=i,
-                                                                     return_tensor_list=True)
+                                                                     return_tensor_list=True,
+                                                                     grad_norm=grad_norm_for_muon)
                 self.all_grad_tensors[i] = None
 
         self._release_ipg_buffers()
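
Note on the pre-computed norm above: grad_norm_for_muon is the L2 norm over the per-group norms, and the clip factor applied later in get_flat_partition follows the usual max(1, norm / clip_grad) rule. Below is a minimal, self-contained sketch of that arithmetic with plain tensors and no ZeRO partitioning; global_grad_norm, clip_factor, and the toy inputs are illustrative names, not functions in this file.

import torch

def global_grad_norm(per_group_grads):
    # Per-group L2 norms, then the L2 norm of those norms == the global L2 norm.
    norm_groups = [torch.linalg.vector_norm(torch.cat([g.flatten() for g in grads]), ord=2)
                   for grads in per_group_grads]
    return torch.linalg.vector_norm(torch.stack(norm_groups), ord=2)

def clip_factor(grad_norm, clip_grad):
    # Gradients are divided by this factor; 1.0 means no clipping.
    return max(1.0, float(grad_norm) / clip_grad)

# Toy usage
grads = [[torch.randn(4, 4), torch.randn(8)], [torch.randn(3, 3)]]
norm = global_grad_norm(grads)
print(norm, clip_factor(norm, clip_grad=1.0))
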
@@ -1894,7 +1921,8 @@ def get_flat_partition(self,
                            dtype,
                            device,
                            param_group_idx,
-                           return_tensor_list=False):
+                           return_tensor_list=False,
+                           grad_norm=None):
         if len(tensor_list) == 0:
             # This condition can fire when we have small parameters and many ranks.
             zero_buffer = torch.zeros(int(partition_size), dtype=dtype, device=device)
@@ -1916,11 +1944,22 @@ def get_flat_partition(self,
             flatten_bf_list = [torch.zeros([total_size], dtype=dtype, device=device)]
             self.optimizer.state[flatten_copy]["momentum_buffer"] = self.flatten(flatten_bf_list)
 
+        # Calculate the clip factor if gradient clipping is enabled and a grad_norm is provided
+        clip_factor = 1.0
+        if self.clip_grad > 0. and grad_norm is not None:
+            # grad_norm is already unscaled (divided by the loss scale)
+            clip_factor = max(1.0, grad_norm / self.clip_grad)
+
         buffer_idx = 0
         for i, tensor in enumerate(tensor_list):
             grad_accum = self.all_grad_tensors[param_group_idx][i]
             if getattr(tensor, 'use_muon', False) and 'muon' in self.optimizer.__class__.__name__.lower():
                 assert tensor.ndim > 1, f"if use muon, then tensor dim > 1, got {tensor.size()}"
+
+                # Apply gradient clipping before muon_update
+                if clip_factor > 1.0:
+                    grad_accum = grad_accum / clip_factor
+
                 buffer = torch.narrow(self.optimizer.state[flatten_copy]["momentum_buffer"], 0, buffer_idx,
                                       tensor.numel()).view(tensor.size())
                 grad_accum = muon_update(grad_accum, buffer, self.optimizer.param_groups[param_group_idx]['momentum'])
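
The ordering in the branch above is the point of this change: the accumulated gradient is scaled down by the clip factor first, and only then handed to muon_update together with the per-parameter momentum buffer. A simplified standalone sketch of that flow follows; toy_muon_update is only a momentum-accumulation stand-in, not DeepSpeed's actual muon_update (which also orthogonalizes the 2-D update), and all names here are illustrative.

import torch

def toy_muon_update(grad, momentum_buffer, momentum=0.95):
    # Stand-in: momentum accumulation only; the real muon_update also
    # orthogonalizes the resulting 2-D update.
    momentum_buffer.mul_(momentum).add_(grad)
    return momentum_buffer.clone()

def clipped_muon_step(grad_accum, momentum_buffer, clip_grad, grad_norm, momentum=0.95):
    # Mirrors the diff: divide by the clip factor, then run the Muon-style update.
    clip = max(1.0, float(grad_norm) / clip_grad) if clip_grad > 0. else 1.0
    if clip > 1.0:
        grad_accum = grad_accum / clip
    return toy_muon_update(grad_accum, momentum_buffer, momentum)

g = torch.randn(16, 16)
buf = torch.zeros_like(g)
update = clipped_muon_step(g, buf, clip_grad=1.0, grad_norm=torch.linalg.vector_norm(g))
print(update.shape)
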
@@ -2058,15 +2097,20 @@ def step(self, closure=None):
         see_memory_usage('Before norm calculation')
         scaled_global_grad_norm = self.scaled_global_norm()
         self._global_grad_norm = scaled_global_grad_norm / prev_scale
+        unscaled_grad_norm = self._global_grad_norm  # Store the unscaled norm for use in get_flat_partition
         see_memory_usage('After norm before optimizer')
 
         # Step 2:- run optimizer and upscaling simultaneously
         for i, group in enumerate(self.bit16_groups):
             self.timers(OPTIMIZER_GRADIENTS_TIMER).start()
             partition_id = dist.get_rank(group=self.real_dp_process_group[i])
+
+            # Check whether this param group uses Muon (clipping is already done in get_flat_partition)
+            uses_muon = len(self.params_in_partition[i]) > 0 and getattr(self.params_in_partition[i][0], 'use_muon', False)
+
             if self.cpu_offload:
                 single_grad_partition = self.single_partition_of_fp32_groups[i].grad
-                self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm)
+                self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm, skip_clipping=uses_muon)
 
                 self.timers(OPTIMIZER_GRADIENTS_TIMER).stop()
                 self.timers(OPTIMIZER_STEP_TIMER).start()
@@ -2108,7 +2152,7 @@ def step(self, closure=None):
 
                 self.averaged_gradients[i] = None
                 self.all_grad_tensors[i] = None
-                self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm)
+                self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm, skip_clipping=uses_muon)
 
                 self.timers(OPTIMIZER_GRADIENTS_TIMER).stop()
 
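
Why skip_clipping exists: unscale_and_clip_grads folds loss-scale removal and gradient clipping into a single combined scale. For Muon groups the clipping half was already applied to the raw gradient in get_flat_partition, so only the unscale half should run here; otherwise the gradient would be clipped twice. A small sketch of the combined-scale arithmetic follows; combined_scale is an illustrative helper mirroring the method in the last hunk below, not part of the file.

def combined_scale(loss_scale, total_norm, clip_grad, skip_clipping):
    # total_norm is still multiplied by loss_scale here ("norm is in fact norm*scale").
    scale = loss_scale
    if clip_grad > 0. and not skip_clipping:
        clip = max(1.0, (total_norm / loss_scale + 1e-6) / clip_grad)
        scale = clip * loss_scale
    return scale

# Muon group: already clipped upstream, so only unscale by the loss scale.
print(combined_scale(loss_scale=1024.0, total_norm=2048.0, clip_grad=1.0, skip_clipping=True))   # 1024.0
# Non-Muon group: unscale and clip in one division by the combined scale.
print(combined_scale(loss_scale=1024.0, total_norm=2048.0, clip_grad=1.0, skip_clipping=False))  # ~2048.0
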
@@ -2168,10 +2212,10 @@ def _average_expert_grad_norms(self, norm_groups):
                 dist.all_reduce(scaled_norm_tensor, group=self.real_dp_process_group[i])
                 norm_groups[i] = scaled_norm_tensor.to(self.device)
 
-    def unscale_and_clip_grads(self, grad_groups_flat, total_norm):
+    def unscale_and_clip_grads(self, grad_groups_flat, total_norm, skip_clipping=False):
         # compute combined scale factor for this group
         combined_scale = self.loss_scale
-        if self.clip_grad > 0.:
+        if self.clip_grad > 0. and not skip_clipping:
             # norm is in fact norm*scale
             clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
             clip = torch.clamp(clip, min=1.0)
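
As a sanity check of the overall design, clipping with the unscaled norm before the optimizer update (the Muon path) and clipping via the combined scale (the standard path) yield the same effective gradient, because the clip factor commutes with unscaling. The toy demonstration below ignores the 1e-6 epsilon above and the Muon momentum/orthogonalization step; all variables are illustrative.

import torch

torch.manual_seed(0)
loss_scale, clip_grad = 1024.0, 1.0
grad_unscaled = torch.randn(32, 32)
grad_scaled = grad_unscaled * loss_scale

# Path A (Muon groups): clip first using the unscaled norm, then only unscale.
clip_a = max(1.0, float(torch.linalg.vector_norm(grad_unscaled)) / clip_grad)
path_a = (grad_scaled / clip_a) / loss_scale

# Path B (standard groups): unscale and clip together via the combined scale.
clip_b = max(1.0, (float(torch.linalg.vector_norm(grad_scaled)) / loss_scale) / clip_grad)
path_b = grad_scaled / (clip_b * loss_scale)

print(torch.allclose(path_a, path_b))  # True: the two orderings agree
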