Commit 0bd4630

[https://nvbugs/5854860][fix] Fix cutedsl argmax on sm120 (#11181)
Signed-off-by: Dongfeng Yu <dongfengy@nvidia.com>
1 parent ad2d1df commit 0bd4630

File tree

2 files changed: +11 −18 lines changed


tensorrt_llm/_torch/cute_dsl_kernels/argmax.py

Lines changed: 11 additions & 5 deletions
@@ -219,7 +219,7 @@ def ptx_select_argmax_candidate(
 
 @cute.jit
 def warp_argmax_redux(current_max: Float32, current_argmax: Int32):
-    """Redux-based warp argmax - only works on sm_100+ (Blackwell)."""
+    """Redux-based warp argmax - only works on sm_100f (Blackwell)."""
     warp_max = ptx_redux_sync_max_f32(current_max)
     candidate_idx = ptx_select_argmax_candidate(current_max, warp_max, current_argmax)
     winning_idx = ptx_redux_sync_min_u32(candidate_idx)
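The change above only touches the docstring, but the surrounding reduction is worth spelling out: warp_argmax_redux takes the warp-wide max with one redux.sync, has every lane whose value equals that max propose its candidate index, and breaks ties with a warp-wide min over the proposals. A minimal pure-Python model of those semantics, assuming non-winning lanes contribute a UINT32_MAX sentinel (an illustration detail, not read from the kernel):

# Pure-Python model of the two-step redux warp argmax over 32 lanes.
# redux.sync.max.f32 -> warp-wide max; ptx_select_argmax_candidate ->
# each lane proposes its index if it holds the max, else a sentinel
# (assumed here); redux.sync.min.u32 -> lowest winning index.
UINT32_MAX = 0xFFFFFFFF

def warp_argmax_model(values, indices):
    warp_max = max(values)                          # redux.sync.max.f32
    candidates = [idx if v == warp_max else UINT32_MAX
                  for v, idx in zip(values, indices)]
    winning_idx = min(candidates)                   # redux.sync.min.u32
    return warp_max, winning_idx

# Lanes 5 and 20 tie for the max; the min-reduction picks lane 5's index.
vals = [0.0] * 32
vals[5] = vals[20] = 3.5
print(warp_argmax_model(vals, list(range(32))))     # (3.5, 5)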
@@ -324,7 +324,7 @@ def _initialize_cluster(self, tidx: cutlass.Int32, mbar_ptr: cute.Pointer, num_w
 class ArgmaxKernel(ReductionBase):
     def __init__(self, dtype: Type[cutlass.Numeric], N: int, use_redux: bool = False):
         super().__init__(dtype, N, stage=1, reduction_dtype=cutlass.Float32)
-        # use_redux=True for Blackwell (sm_100+), False for Hopper (sm_90)
+        # use_redux=True for Blackwell (sm_100f), False for Hopper (sm_90)
         self.use_redux = use_redux
 
     def _calculate_threads_per_row(self):
@@ -582,6 +582,11 @@ def _should_use_torch_fallback(N: int, dtype: torch.dtype) -> bool:
         return True
     if N % _VOCAB_SIZE_ALIGNMENT != 0:
         return True
+    # Fall back on sm_120 - CUTLASS DSL JIT not well-supported for this setup
+    from ..._utils import get_sm_version
+
+    if get_sm_version() >= 120:
+        return True
     return False
 
 def argmax(x: torch.Tensor) -> torch.Tensor:
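With this guard in place, sm_120 never reaches the CuTe DSL kernel at all. A hedged sketch of the dispatch this guard feeds, where _cute_dsl_argmax is a hypothetical stand-in for the compiled-kernel path rather than the module's real entry point:

import torch

def argmax_with_fallback(x: torch.Tensor) -> torch.Tensor:
    # Unsupported dtype, unaligned vocab size, or (after this commit)
    # sm_120: take the eager torch path instead of the CuTe DSL kernel.
    N = x.shape[-1]
    if _should_use_torch_fallback(N, x.dtype):
        return torch.argmax(x, dim=-1)
    return _cute_dsl_argmax(x)  # hypothetical compiled-kernel path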
@@ -618,11 +623,12 @@ def convert_from_dlpack(tensor):
     out_tensor = convert_from_dlpack(out)
     current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
 
-    # Detect compute capability: use redux instructions only on Blackwell (sm_100+)
-    # redux.sync.max.f32 is only supported on sm_100+
+    # Detect compute capability: use redux instructions only on Blackwell (sm_100f)
+    # redux.sync.max.f32 is only supported on sm_100f
     from ..._utils import get_sm_version
 
-    use_redux = get_sm_version() >= 100  # sm_100+ (Blackwell)
+    sm_version = get_sm_version()
+    use_redux = sm_version >= 100 and sm_version < 120
 
     compile_key = (dtype, N, use_redux)
     if compile_key not in _argmax_compile_cache:
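Because use_redux now varies by architecture, including it in compile_key keeps the Blackwell redux kernel and the Hopper shuffle kernel for the same (dtype, N) cached separately. A schematic of the lookup, assuming a plain dict cache; ArgmaxKernel comes from the hunk above, and the actual JIT-compile step is elided:

_argmax_compile_cache: dict = {}

def _get_kernel(dtype, N: int, sm_version: int):
    # Redux only on sm_100..sm_11x (Blackwell); sm_90 (Hopper) uses the
    # shuffle path, and sm_120 never gets here (torch fallback above).
    use_redux = 100 <= sm_version < 120
    compile_key = (dtype, N, use_redux)
    if compile_key not in _argmax_compile_cache:
        kernel = ArgmaxKernel(dtype, N, use_redux=use_redux)
        _argmax_compile_cache[compile_key] = kernel  # JIT compile elided
    return _argmax_compile_cache[compile_key]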

tests/integration/test_lists/waives.txt

Lines changed: 0 additions & 13 deletions
@@ -262,7 +262,6 @@ accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_fp8_chunked_p
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=2-overlap_scheduler=True] SKIP (https://nvbugs/5800591)
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp4-mtp_nextn=0-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=True] SKIP (https://nvbugs/5800646)
 full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-ep4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5800672)
-full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[cutlass-one_model-overlap_scheduler] SKIP (https://nvbugs/5800679)
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[cutlass-one_model-overlap_scheduler] SKIP (https://nvbugs/5853997)
 examples/test_medusa.py::test_llm_medusa_with_qaunt_base_model_1gpu[fp8-use_cpp_session-medusa-vicuna-7b-v1.3-4-heads-float16-bs1] SKIP (https://nvbugs/5802248)
 unittest/_torch/modeling/test_modeling_llama.py::TestLlama::test_llama_verification_with_kv_cache_relocation SKIP (https://nvbugs/5804923)
@@ -345,18 +344,6 @@ unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[DeepEP] S
 unittest/_torch/modeling/test_modeling_nemotron_h.py::test_nemotron_h_cuda_graph_overlap_scheduler SKIP (https://nvbugs/5843316)
 examples/test_mistral.py::test_mistral_with_bf16_lora_torch[mistral-7b-v0.1] SKIP (https://nvbugs/5846178)
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[latency] SKIP (https://nvbugs/5846024)
-full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[pp4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5846154)
-full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp2pp2-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5846154)
-full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-pp4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5846154)
-full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[pp4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True] SKIP (https://nvbugs/5846154)
-full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[ep4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True] SKIP (https://nvbugs/5846154)
-full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp2pp2-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True] SKIP (https://nvbugs/5846154)
-full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp2pp2-mtp_nextn=2-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True] SKIP (https://nvbugs/5846154)
-full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-pp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5846154)
-full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True] SKIP (https://nvbugs/5846154)
-full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5846154)
-full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp2pp2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5846154)
-full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5846154)
 accuracy/test_disaggregated_serving.py::TestDeepSeekV32Exp::test_auto_dtype[False] SKIP (https://nvbugs/5847284)
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-dp4-trtllm-fp8] SKIP (https://nvbugs/5850183)
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-dp4-trtllm-fp8] SKIP (https://nvbugs/5850183)
