Merged

Commits (26):
557db0a  wip: not compiles yet  (nekorobov, Feb 5, 2026)
45cdb86  fix: compiles, but hangs in autotuning  (nekorobov, Feb 5, 2026)
d8c15b4  banned splitK and tileN 256, unit test works  (nekorobov, Feb 5, 2026)
8a7a269  Merge remote-tracking branch 'origin/main' into nkorobov/mxfp8-trtllm…  (IwakuraRein, Feb 5, 2026)
77c49a7  upd  (IwakuraRein, Feb 5, 2026)
3e1a29f  add mxfp8 bench  (IwakuraRein, Feb 5, 2026)
b12c461  fix test  (IwakuraRein, Feb 6, 2026)
46eddfa  upd comments  (IwakuraRein, Feb 6, 2026)
b046320  drop tile==8 and use unroll loop 2x  (IwakuraRein, Feb 6, 2026)
acf0c39  fix test  (IwakuraRein, Feb 6, 2026)
2702ee2  WAR: drop all UnrollLoop2xForMma kernels  (IwakuraRein, Feb 6, 2026)
1dc688d  Merge remote-tracking branch 'origin/main' into siyuanf/mxfp8-trtllm-…  (IwakuraRein, Feb 7, 2026)
4e83b82  address comment  (IwakuraRein, Feb 9, 2026)
aae1719  fix unit test  (IwakuraRein, Feb 9, 2026)
73d7594  fix hang and segfault  (nekorobov, Feb 10, 2026)
4354ec4  use permute cache in unit test (WIP)  (IwakuraRein, Feb 10, 2026)
0944312  use permute cache in unit test (WIP)  (IwakuraRein, Feb 10, 2026)
aa85e94  Revert "use permute cache in unit test (WIP)"  (IwakuraRein, Feb 11, 2026)
a7ebf1e  Merge remote-tracking branch 'origin/main' into siyuanf/mxfp8-trtllm-…  (IwakuraRein, Feb 12, 2026)
4815a0c  address comments  (IwakuraRein, Feb 13, 2026)
e18d73c  intermediate_size_factor  (IwakuraRein, Feb 13, 2026)
b9f198d  Merge remote-tracking branch 'origin/main' into siyuanf/mxfp8-trtllm-…  (IwakuraRein, Feb 13, 2026)
c310276  address comments  (IwakuraRein, Feb 13, 2026)
33acaa2  quick fix  (IwakuraRein, Feb 13, 2026)
03cac02  fix intermediate_size_factor initialization  (IwakuraRein, Feb 14, 2026)
19417d1  allow split k  (IwakuraRein, Feb 14, 2026)
425 changes: 294 additions & 131 deletions benchmarks/bench_trtllm_gen_fused_moe_autotuner.py

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions csrc/trtllm_batched_gemm_runner.cu
@@ -112,6 +112,11 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(
continue;
}

if (options.mDtypeA == tg::Dtype::MxE4m3 && options.mDtypeB == tg::Dtype::MxE4m3 &&
options.mNumSlicesForSplitK > 1) {
continue;
}

if (mOptions.transposeMmaOutput && options.mEpilogueTileM == mOptions.epilogueTileM) {
mPassingConfigIndices.push_back(i);
}
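The added guard excludes split-K kernel configurations when both GEMM operands are MxE4m3 (MxFP8): any candidate with mNumSlicesForSplitK > 1 is skipped before the passing-config list is built. Below is a minimal Python sketch of that filtering rule; GemmConfig and passes_mxfp8_splitk_guard are hypothetical stand-ins for the C++ options checked above, not part of this PR.

from dataclasses import dataclass

# Hypothetical, simplified stand-in for the C++ batched GEMM config options.
@dataclass
class GemmConfig:
    dtype_a: str
    dtype_b: str
    num_slices_for_split_k: int

def passes_mxfp8_splitk_guard(cfg: GemmConfig) -> bool:
    # Mirrors the new C++ guard: an MxE4m3 x MxE4m3 batched GEMM config is
    # rejected whenever it uses split-K (more than one K slice).
    if (cfg.dtype_a == "MxE4m3" and cfg.dtype_b == "MxE4m3"
            and cfg.num_slices_for_split_k > 1):
        return False
    return True

candidates = [
    GemmConfig("MxE4m3", "MxE4m3", 1),  # kept
    GemmConfig("MxE4m3", "MxE4m3", 2),  # skipped by the new guard
    GemmConfig("E4m3", "E4m3", 2),      # unaffected
]
passing = [c for c in candidates if passes_mxfp8_splitk_guard(c)]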
199 changes: 134 additions & 65 deletions csrc/trtllm_fused_moe_kernel_launcher.cu

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions flashinfer/artifacts.py
@@ -89,7 +89,7 @@ class ArtifactPath:

TRTLLM_GEN_FMHA: str = "75d477a640f268ea9ad117cc596eb39245713b9e/fmha/trtllm-gen/"
TRTLLM_GEN_BMM: str = (
"e1e11bbfe0743743620ef997a6d5e8e2dbdf01cf/batched_gemm-2a674db-79e4d37"
"456b1ae890d436c794b17e4435b41b849d3e5950/batched_gemm-2a674db-3a84a12"
)
TRTLLM_GEN_GEMM: str = (
"1fddc48b7b48af33914d040051b3e2ee9ba4701e/gemm-145d1b1-9b113e3"
@@ -110,7 +110,7 @@ class CheckSumHash:
"e014d7a54c396733ef012b223603c1be2861019f88faa5dcc882ed1ecfe5c2d9"
)
TRTLLM_GEN_BMM: str = (
"03b1a419b594b7a4613ea8437c172dc2627d56bd360be25aa604859dc12a05fb"
"b9121fed5dd7700b7c2a0dcbcf2ef022483855cf585263324275b0072cca6bb7"
)
DEEPGEMM: str = "1a2a166839042dbd2a57f48051c82cd1ad032815927c753db269a4ed10d0ffbf"
TRTLLM_GEN_GEMM: str = (
1 change: 1 addition & 0 deletions flashinfer/fused_moe/__init__.py
@@ -16,6 +16,7 @@

from .core import (
ActivationType,
Fp8QuantizationType,
RoutingMethodType,
WeightLayout,
convert_to_block_layout,
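Because __init__.py now re-exports the new enum, callers can pick the FP8 quantization scheme from the package namespace. A small sketch, assuming the package is importable as flashinfer; the values mirror the IntEnum added in flashinfer/fused_moe/core.py below.

from flashinfer.fused_moe import Fp8QuantizationType

# IntEnum members compare equal to their integer values.
assert Fp8QuantizationType.NoneFp8 == 0
assert Fp8QuantizationType.DeepSeekFp8 == 1
assert Fp8QuantizationType.MxFp8 == 2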
72 changes: 57 additions & 15 deletions flashinfer/fused_moe/core.py
@@ -173,6 +173,26 @@ class WeightLayout(IntEnum):
BlockMajorK = 2


# The type of gated activation function
# Please keep this in sync with the counterpart defined in include/flashinfer/trtllm/fused_moe/runner.h
class GatedActType(IntEnum):
# SwiGlu
SwiGlu = 0
# GeGlu
GeGlu = 1


# The type of FP8 quantization
# Please keep this in sync with the counterpart defined in trtllm_fused_moe_kernel_launcher.cu
class Fp8QuantizationType(IntEnum):
# No FP8 quantization
NoneFp8 = 0
# DeepSeek FP8
DeepSeekFp8 = 1
# MxFp8 x MxFp8
MxFp8 = 2


@functools.cache
def is_trtllm_moe_supported(
dtype_weights: DtypeTrtllmGen,
@@ -986,7 +1006,7 @@ def __init__(
num_local_experts: int,
dtype_act: DtypeTrtllmGen,
dtype_weights: DtypeTrtllmGen,
use_deepseek_fp8: bool,
fp8_quantization_type: Fp8QuantizationType,
hidden_size: int,
intermediate_size: int,
activation_type: int = ActivationType.Swiglu,
Expand All @@ -998,7 +1018,7 @@ def __init__(
self.top_k = top_k
self.dtype_act = dtype_act
self.dtype_weights = dtype_weights
self.use_deepseek_fp8 = use_deepseek_fp8
self.fp8_quantization_type = fp8_quantization_type
self.top_k = top_k
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
@@ -1025,7 +1045,7 @@ def get_valid_tactics(
instance_key = (
self.dtype_act,
self.dtype_weights,
self.use_deepseek_fp8,
self.fp8_quantization_type,
self.top_k,
self.hidden_size,
self.intermediate_size,
@@ -1114,16 +1134,28 @@ def forward(
and self.dtype_weights == DtypeTrtllmGen.E4m3
):
# FP8 operations
if self.use_deepseek_fp8:
if (
self.fp8_quantization_type == Fp8QuantizationType.DeepSeekFp8
or self.fp8_quantization_type == Fp8QuantizationType.MxFp8
):
# FP8 block scale
current_num_tokens = hidden_states.shape[0]
current_hidden_size = hidden_states.shape[1]
current_hidden_states_scale = torch.full(
(current_hidden_size // 128, current_num_tokens),
2.0,
dtype=torch.float,
device=hidden_states.device,
)
if self.fp8_quantization_type == Fp8QuantizationType.DeepSeekFp8:
current_hidden_states_scale = torch.full(
(current_hidden_size // 128, current_num_tokens),
2.0,
dtype=torch.float,
device=hidden_states.device,
)
elif self.fp8_quantization_type == Fp8QuantizationType.MxFp8:
current_hidden_states_scale = extra_inputs[0]

else:
raise ValueError(
f"Unsupported FP8 quantization type: {self.fp8_quantization_type}"
)

moe_op.trtllm_fp8_block_scale_moe(
routing_logits,
topk_ids,
Expand All @@ -1149,6 +1181,7 @@ def forward(
kwargs["weight_layout"],
kwargs["enable_pdl"],
[-1, -1] if tactic == -1 else tactic,
self.fp8_quantization_type,
)
else:
# FP8 per tensor scale
@@ -1319,7 +1352,7 @@ def trtllm_bf16_moe_op(
num_local_experts=local_num_experts,
dtype_act=dtype_act,
dtype_weights=dtype_weights,
use_deepseek_fp8=False,
fp8_quantization_type=Fp8QuantizationType.NoneFp8,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
weight_layout=weight_layout,
@@ -1452,7 +1485,7 @@ def trtllm_fp8_per_tensor_scale_moe_op(
num_local_experts=local_num_experts,
dtype_act=dtype_act,
dtype_weights=dtype_weights,
use_deepseek_fp8=False, # per_tensor mode
fp8_quantization_type=Fp8QuantizationType.NoneFp8, # per_tensor mode
hidden_size=hidden_size,
intermediate_size=intermediate_size,
weight_layout=WeightLayout.MajorK,
@@ -1569,6 +1602,7 @@ def trtllm_fp8_block_scale_moe_op(
weight_layout: int = 0,
enable_pdl: Optional[bool] = None,
tune_max_num_tokens: int = 8192,
fp8_quantization_type: Fp8QuantizationType = Fp8QuantizationType.DeepSeekFp8,
) -> torch.Tensor:
# Determine routing mode: compute from logits or use pre-computed
if routing_logits is None:
@@ -1619,7 +1653,7 @@ def trtllm_fp8_block_scale_moe_op(
num_local_experts=local_num_experts,
dtype_act=dtype_act,
dtype_weights=dtype_weights,
use_deepseek_fp8=True, # block_scale mode
fp8_quantization_type=fp8_quantization_type, # block_scale mode
hidden_size=hidden_size,
intermediate_size=intermediate_size,
weight_layout=weight_layout,
@@ -1682,6 +1716,7 @@ def trtllm_fp8_block_scale_moe_op(
weight_layout,
enable_pdl,
[-1, -1] if tactic == -1 else tactic,
fp8_quantization_type,
)

return result
@@ -1712,6 +1747,7 @@ def _fake_trtllm_fp8_block_scale_moe(
weight_layout: int = 0,
enable_pdl: Optional[bool] = None,
tune_max_num_tokens: int = 8192,
fp8_quantization_type: Fp8QuantizationType = Fp8QuantizationType.DeepSeekFp8,
):
seq_len = hidden_states.shape[0]
hidden_size = hidden_states.shape[1]
@@ -1809,7 +1845,7 @@ def trtllm_fp4_block_scale_moe_op(
num_local_experts=num_local_experts,
dtype_act=dtype_act,
dtype_weights=dtype_weights,
use_deepseek_fp8=False,
fp8_quantization_type=Fp8QuantizationType.NoneFp8,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
activation_type=activation_type,
@@ -2007,7 +2043,7 @@ def trtllm_mxint4_block_scale_moe_op(
num_local_experts=num_local_experts,
dtype_act=dtype_act,
dtype_weights=dtype_weights,
use_deepseek_fp8=False,
fp8_quantization_type=Fp8QuantizationType.NoneFp8,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
activation_type=ActivationType.Swiglu,
@@ -2303,6 +2339,7 @@ def trtllm_fp8_block_scale_moe(
weight_layout: int = 0,
enable_pdl: Optional[bool] = None,
tune_max_num_tokens: int = 8192,
fp8_quantization_type: Fp8QuantizationType = Fp8QuantizationType.DeepSeekFp8,
) -> torch.Tensor:
"""FP8 block scale MoE operation.

@@ -2326,6 +2363,7 @@ def trtllm_fp8_block_scale_moe(
routing_method_type: Type of routing method to use (default: 0)
enable_pdl: Whether to enable Programmatic Dependent Launch (PDL). Auto-enabled for >= sm90.
tune_max_num_tokens(int): Maximum number of tokens for tuning. (default: 8192)
fp8_quantization_type: Type of FP8 quantization to use (default: DeepSeekFp8)
Returns:
torch.Tensor: Output tensor of shape [seq_len, hidden_size]
"""
@@ -2357,6 +2395,7 @@ def trtllm_fp8_block_scale_moe(
weight_layout,
enable_pdl,
tune_max_num_tokens,
fp8_quantization_type,
)


@@ -2384,6 +2423,7 @@ def trtllm_fp8_block_scale_routed_moe(
enable_pdl: Optional[bool] = None,
output: Optional[torch.Tensor] = None,
tune_max_num_tokens: int = 8192,
fp8_quantization_type: Fp8QuantizationType = Fp8QuantizationType.DeepSeekFp8,
) -> torch.Tensor:
"""FP8 block scale MoE operation with pre-computed routing (packed format).

@@ -2418,6 +2458,7 @@ def trtllm_fp8_block_scale_routed_moe(
output (Optional[torch.Tensor]): shape [seq_len, hidden_size]
Optional inplace output tensor.
tune_max_num_tokens(int): Maximum number of tokens for tuning. (default: 8192)
fp8_quantization_type: Type of FP8 quantization to use (default: DeepSeekFp8)
Returns:
torch.Tensor: Output tensor of shape [seq_len, hidden_size]
"""
@@ -2446,6 +2487,7 @@ def trtllm_fp8_block_scale_routed_moe(
weight_layout,
enable_pdl,
tune_max_num_tokens,
fp8_quantization_type,
)


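Taken together, the core.py changes thread a new fp8_quantization_type keyword (default DeepSeekFp8) through trtllm_fp8_block_scale_moe and trtllm_fp8_block_scale_routed_moe, so MxFP8 block scaling reuses the same entry points as DeepSeek FP8. A hedged usage sketch follows: the unchanged tensor arguments (routing logits, hidden states, block scales, expert weights, and so on) are represented by a placeholder dict because their full signature is not visible in this diff, and importing the MoE entry point from flashinfer.fused_moe is an assumption.

import torch
from flashinfer.fused_moe import Fp8QuantizationType, trtllm_fp8_block_scale_moe

def run_block_scale_moe_mxfp8(moe_args: dict) -> torch.Tensor:
    # moe_args stands in for the existing positional/keyword arguments of
    # trtllm_fp8_block_scale_moe (routing logits, hidden states, block scales,
    # expert weights, ...), which this PR does not change.
    return trtllm_fp8_block_scale_moe(
        **moe_args,
        fp8_quantization_type=Fp8QuantizationType.MxFp8,  # new keyword; defaults to DeepSeekFp8
    )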