@@ -219,7 +219,7 @@ def ptx_select_argmax_candidate(
219219
220220 @cute .jit
221221 def warp_argmax_redux (current_max : Float32 , current_argmax : Int32 ):
222- """Redux-based warp argmax - only works on sm_100+ (Blackwell)."""
222+ """Redux-based warp argmax - only works on sm_100f (Blackwell)."""
223223 warp_max = ptx_redux_sync_max_f32 (current_max )
224224 candidate_idx = ptx_select_argmax_candidate (current_max , warp_max , current_argmax )
225225 winning_idx = ptx_redux_sync_min_u32 (candidate_idx )
@@ -324,7 +324,7 @@ def _initialize_cluster(self, tidx: cutlass.Int32, mbar_ptr: cute.Pointer, num_w
324324 class ArgmaxKernel (ReductionBase ):
325325 def __init__ (self , dtype : Type [cutlass .Numeric ], N : int , use_redux : bool = False ):
326326 super ().__init__ (dtype , N , stage = 1 , reduction_dtype = cutlass .Float32 )
327- # use_redux=True for Blackwell (sm_100+ ), False for Hopper (sm_90)
327+ # use_redux=True for Blackwell (sm_100f ), False for Hopper (sm_90)
328328 self .use_redux = use_redux
329329
330330 def _calculate_threads_per_row (self ):
@@ -582,6 +582,11 @@ def _should_use_torch_fallback(N: int, dtype: torch.dtype) -> bool:
582582 return True
583583 if N % _VOCAB_SIZE_ALIGNMENT != 0 :
584584 return True
585+ # Fall back on sm_120 and newer (check below is >= 120) - the CUTLASS DSL JIT is not well supported there
586+ from ..._utils import get_sm_version
587+
588+ if get_sm_version () >= 120 :
589+ return True
585590 return False
586591
587592 def argmax (x : torch .Tensor ) -> torch .Tensor :
@@ -618,11 +623,12 @@ def convert_from_dlpack(tensor):
618623 out_tensor = convert_from_dlpack (out )
619624 current_stream = cuda .CUstream (torch .cuda .current_stream ().cuda_stream )
620625
621- # Detect compute capability: use redux instructions only on Blackwell (sm_100+ )
622- # redux.sync.max.f32 is only supported on sm_100+
626+ # Detect compute capability: use redux instructions only on Blackwell (sm_100f )
627+ # redux.sync.max.f32 is only supported on sm_100f
623628 from ..._utils import get_sm_version
624629
625- use_redux = get_sm_version () >= 100 # sm_100+ (Blackwell)
630+ sm_version = get_sm_version ()
631+ use_redux = sm_version >= 100 and sm_version < 120
626632
627633 compile_key = (dtype , N , use_redux )
628634 if compile_key not in _argmax_compile_cache :
0 commit comments