diff --git a/csrc/trtllm_fused_moe_routing_renormalize.cu b/csrc/trtllm_fused_moe_routing_renormalize.cu index 87b9358e..0414b5aa 100644 --- a/csrc/trtllm_fused_moe_routing_renormalize.cu +++ b/csrc/trtllm_fused_moe_routing_renormalize.cu @@ -441,7 +441,9 @@ void run(Data const& data, void* stream) { TVM_FFI_ICHECK_LE(data.mPaddingLog2, 8) << "Routing kernel expects padding log2 < 8, got " << data.mPaddingLog2; - bool const useSingleBlock = data.mNumTokens <= BlockKernelMaxNumTokens; + // FIXME: routingIndicesBlockKernel breaks the vllm + gpt-oss DeepEP + // bool const useSingleBlock = data.mNumTokens <= BlockKernelMaxNumTokens; + bool const useSingleBlock = false; bool const useSingleCluster = data.mNumTokens <= ((data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr) diff --git a/include/flashinfer/trtllm/fused_moe/DevKernel.h b/include/flashinfer/trtllm/fused_moe/DevKernel.h index bb8786b3..d3a72592 100644 --- a/include/flashinfer/trtllm/fused_moe/DevKernel.h +++ b/include/flashinfer/trtllm/fused_moe/DevKernel.h @@ -182,23 +182,26 @@ namespace moe::dev { #define LAUNCH_ROUTING_WITH_NUM_EXPERTS(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ stream, extraFlag1, extraFlag2, numExperts) \ if (data.mDtypeExpW == tg::Dtype::Fp32 && extraFlag1) { \ - LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, true), kernel, numBlocks, \ - numThreads, smemSize, stream); \ + LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, true, false), kernel, \ + numBlocks, numThreads, smemSize, stream); \ } else if (data.mDtypeExpW == tg::Dtype::Fp32 && extraFlag2) { \ - LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, false), kernel, numBlocks, \ - numThreads, smemSize, stream); \ + LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, false, true), kernel, \ + numBlocks, numThreads, smemSize, stream); \ } else if (data.mDtypeExpW == tg::Dtype::Fp32) { \ - LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, 
numExperts, false), kernel, numBlocks, \ - numThreads, smemSize, stream); \ + LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, false, false), kernel, \ + numBlocks, numThreads, smemSize, stream); \ } else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag1) { \ - LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, true), \ - kernel, numBlocks, numThreads, smemSize, stream); \ + LAUNCH_PDL(data, coopLaunch, \ + LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, true, false), kernel, \ + numBlocks, numThreads, smemSize, stream); \ } else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag2) { \ - LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, true), \ - kernel, numBlocks, numThreads, smemSize, stream); \ + LAUNCH_PDL(data, coopLaunch, \ + LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, false, true), kernel, \ + numBlocks, numThreads, smemSize, stream); \ } else if (data.mDtypeExpW == tg::Dtype::Bfloat16) { \ - LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, false), \ - kernel, numBlocks, numThreads, smemSize, stream); \ + LAUNCH_PDL(data, coopLaunch, \ + LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, false, false), kernel, \ + numBlocks, numThreads, smemSize, stream); \ } else { \ FLASHINFER_WARN("Unsupported dtypeExpW"); \ }