Commit 1496dea

add block scale moe python api for asm_moe (#341)
* ck fuse moe gemm stage1 with act
* ck fuse moe gemm stage1 with act
* fp8 switch to ck
* move to ck develop
* cutting some moe ut cases
* fix bug
* fix profile bugs
* revert changes of 'get_2stage_cfgs'
* re-trigger CI
* moe i8 switch ck
* Retrigger CI
* add block scale moe python api for asm_moe
* blockscale moe
* fix blockscale bugs
1 parent 781c686 commit 1496dea

2 files changed: +37 −1 lines changed


aiter/fused_moe_bf16_asm.py

Lines changed: 36 additions & 1 deletion
```diff
@@ -54,6 +54,7 @@ def asm_moe(hidden_states,
             fc2_smooth_scale=None,  # [expert(local_expert:EP), 1, inter_dim]
             a16=False,
             per_tensor_quant_scale=None,
+            block_shape=None,
             expert_mask=None,
             activation = ActivationType.Silu
             ):
@@ -93,7 +94,41 @@ def asm_moe(hidden_states,
         else:
             raise ValueError(
                 f"Invalid args: {w1.dtype} {w1.shape=} {w2.shape=}")
-
+    elif block_shape is not None:
+        assert dtype == torch.bfloat16, "asm_moe for block_scale only support bfloat16 hidden_states"
+        assert block_shape == (
+            128, 128), "asm_moe for block_scale only support (128, 128)"
+        assert w1.dtype == torch.float8_e4m3fnuz, "asm_moe for block_scale only support float8_e4m3fnuz weight"
+        assert w2.shape[2] * 2 == w1.shape[1], "aiter moe for block_scale only support g1u1"
+        scale_blk_n, scale_blk_k = block_shape
+        hidden_states = hidden_states.view(M *
+                                           model_dim//scale_blk_k, scale_blk_k)
+
+        a1_q, a1_scale = pertoken_quant(
+            hidden_states.view(-1, model_dim // scale_blk_k, scale_blk_k), quant_dtype=torch.float8_e4m3fnuz
+        )
+        a1_q = a1_q.view(-1, model_dim)
+        a1_scale = a1_scale.squeeze(-1).t().contiguous()
+
+
+        scale_blk_n, scale_blk_k = block_shape
+        aiter.fmoe_fp8_blockscale_g1u1(
+            moe_buf,
+            a1_q,
+            w1,
+            w2,
+            sorted_ids,
+            sorted_weights,
+            sorted_expert_ids,
+            num_valid_ids,
+            topk,
+            a1_scale,
+            fc1_scale,
+            fc2_scale,
+            scale_blk_n,
+            scale_blk_k,
+            None,
+        )
     else:
         # a8w8 fmoe, opt: smooth quant
         a8_type = w1.dtype if w1.dtype != torch.int32 and w1.dtype != torch.uint32 else torch.float8_e4m3fnuz
```
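
In the new `elif block_shape is not None` branch, activations are quantized per (1 × 128) group along the hidden dimension before the asm kernel runs, with one fp8 scale per group. Below is a minimal plain-PyTorch sketch of that grouping; it assumes `pertoken_quant` computes a symmetric max-based scale per group, which is not shown in this diff:

```python
import torch

def blockscale_act_quant_sketch(x: torch.Tensor, blk_k: int = 128):
    """Per-(1, blk_k)-group fp8 quantization of activations.

    Stand-in for the pertoken_quant call in the new branch; the
    max-based symmetric scaling is an assumption about its semantics.
    """
    M, model_dim = x.shape
    assert model_dim % blk_k == 0
    groups = x.view(M, model_dim // blk_k, blk_k).float()
    fp8_max = torch.finfo(torch.float8_e4m3fnuz).max
    # One scale per blk_k-wide group along the hidden dimension.
    scale = groups.abs().amax(dim=-1, keepdim=True) / fp8_max
    scale = scale.clamp(min=1e-10)  # guard all-zero groups
    a1_q = (groups / scale).to(torch.float8_e4m3fnuz).view(M, model_dim)
    # Match the diff's layout: scales come out [model_dim // blk_k, M].
    a1_scale = scale.squeeze(-1).t().contiguous()
    return a1_q, a1_scale
```

The trailing transpose mirrors the `a1_scale.squeeze(-1).t().contiguous()` in the diff, which hands the kernel one contiguous scale row per 128-wide K block.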

op_tests/test_moe.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -90,6 +90,7 @@ def asm_moe_test(
         a16,
         None,
         None,
+        None,
         activation,
     )
 
```

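The one-line change in op_tests/test_moe.py above simply passes `None` for the new `block_shape` slot, keeping the existing tests on the old paths. For callers that do opt in, a hypothetical end-to-end invocation might look like the sketch below. Only the bf16 activations, fp8 g1u1 weights, and `block_shape=(128, 128)` constraints come from this commit's asserts; the tensor sizes, the per-block weight-scale layouts, and the `fc1_scale`/`fc2_scale` keyword plumbing are assumptions for illustration:

```python
import torch
from aiter.fused_moe_bf16_asm import asm_moe

E, model_dim, inter_dim, M, topk = 8, 4096, 1024, 32, 2
dev = "cuda"

hidden_states = torch.randn(M, model_dim, dtype=torch.bfloat16, device=dev)
# g1u1: w1 packs gate+up projections, so w1.shape[1] == 2 * w2.shape[2].
w1 = torch.randn(E, 2 * inter_dim, model_dim, device=dev).to(torch.float8_e4m3fnuz)
w2 = torch.randn(E, model_dim, inter_dim, device=dev).to(torch.float8_e4m3fnuz)
# One scale per (128, 128) weight block; this layout is an assumption.
fc1_scale = torch.rand(E, 2 * inter_dim // 128, model_dim // 128, device=dev)
fc2_scale = torch.rand(E, model_dim // 128, inter_dim // 128, device=dev)
topk_weights = torch.rand(M, topk, dtype=torch.float32, device=dev)
topk_ids = torch.randint(0, E, (M, topk), dtype=torch.int32, device=dev)

out = asm_moe(hidden_states, w1, w2, topk_weights, topk_ids,
              fc1_scale=fc1_scale, fc2_scale=fc2_scale,
              block_shape=(128, 128))
```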