Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion csrc/fmhaReduction.cu
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ __global__ void __launch_bounds__(NumThreadsPerCta, 2)
// The O pointer.
DtypeO* oPtr = reinterpret_cast<DtypeO*>(params.ptrO) + oOffset;
// The attentionSinks pointer.
float const* attentionSinksPtr = params.ptrAttentionSinks + headIdxO;
float const* attentionSinksPtr = params.ptrAttentionSinks != nullptr ? (params.ptrAttentionSinks + headIdxO) : nullptr;

// Whether to store the softmax stats.
bool const storesSoftmaxStats{params.ptrSoftmaxStats != nullptr};
Expand Down
97 changes: 86 additions & 11 deletions csrc/trtllm_fmha_kernel_launcher.cu
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,16 @@ class TllmGenFmhaRunnerCache {
void trtllm_paged_attention_launcher(
void* out, void* out_scale_factor, void* query, void* key_cache, void* value_cache,
void* workspace_buffer, int* block_tables, int* seq_lens, int* cum_seq_lens_q,
int* cum_seq_lens_kv, float* attention_sinks, Data_type q_data_type, Data_type kv_data_type,
int* cum_seq_lens_kv, float* attention_sinks, float* lse,
void* k_cache_scales, void* v_cache_scales,
Data_type q_data_type, Data_type kv_data_type,
Data_type o_data_type, TllmPagedAttentionMode mode, int64_t batch_size, int64_t max_q_len,
int64_t max_kv_len, int64_t num_pages_in_mem_pool, int64_t num_qo_heads, int64_t num_kv_heads,
int64_t head_dim_qk, int64_t head_dim_vo, int64_t page_size, int64_t kv_stride_keys_values,
int64_t kv_stride_heads, int64_t kv_stride_batch, int64_t max_num_blocks_per_seq,
double bmm1_scale, double bmm2_scale, const float* bmm1_scale_log2_ptr,
const float* bmm2_scale_ptr, double o_sf_scale, int64_t o_sf_vec_size, int64_t o_sf_start_index,
int64_t window_left, int64_t sum_seq_q, int64_t sparse_mla_top_k, int64_t sm_count,
int64_t window_left, int64_t sum_seq_q, int64_t sparse_mla_top_k, int64_t lse_stride_tokens, int64_t lse_stride_heads, int64_t sm_count,
bool enable_pdl, int64_t workspace_size, cudaStream_t stream) {
if (num_qo_heads % num_kv_heads != 0) {
std::ostringstream err_msg;
Expand All @@ -97,7 +99,9 @@ void trtllm_paged_attention_launcher(
// Common params
runner_params.qPtr = query;
runner_params.kPtr = key_cache;
runner_params.kSfBasePtr = k_cache_scales;
runner_params.vPtr = value_cache;
runner_params.vSfBasePtr = v_cache_scales;
runner_params.kvPageIdxPtr = block_tables;
runner_params.seqLensKvPtr = seq_lens;
runner_params.oPtr = out;
Expand Down Expand Up @@ -146,6 +150,9 @@ void trtllm_paged_attention_launcher(
<< "Only decode MLA supports sparse MLA";

AlignedAllocator float_allocator(workspace_buffer, workspace_size);
runner_params.lsePtr = lse;
runner_params.lseStrideTokens = lse_stride_tokens;
runner_params.lseStrideHeads = lse_stride_heads;
if (mode == TllmPagedAttentionMode::Context) {
runner_params.mMaskType = TrtllmGenAttentionMaskType::Causal;
runner_params.mKernelType = FmhaKernelType::Context;
Expand All @@ -154,6 +161,10 @@ void trtllm_paged_attention_launcher(

runner_params.cumSeqLensQPtr = cum_seq_lens_q;
runner_params.cumSeqLensKvPtr = cum_seq_lens_kv;

runner_params.softmaxStatsPtr = float_allocator.aligned_alloc<float2>(
sizeof(float2) * num_qo_heads * runner_params.mSumOfSeqLensQ, 16,
"trtllm_gen_softmax_workspace");
} else {
// ForGen
runner_params.mMaskType = TrtllmGenAttentionMaskType::Dense;
Expand All @@ -171,6 +182,9 @@ void trtllm_paged_attention_launcher(
// todo(Yingyi): add softmax buffer later for lse return
runner_params.multiCtasKvCounterPtr = float_allocator.aligned_alloc<int32_t>(
num_semaphores * sizeof(uint32_t), 16, "trtllm_gen_counter_workspace");
runner_params.softmaxStatsPtr = float_allocator.aligned_alloc<float2>(
sizeof(float2) * num_qo_heads * runner_params.mSumOfSeqLensQ, 16,
"trtllm_gen_softmax_workspace");
// scratch takes the rest of the workspace buffer
runner_params.multiCtasKvScratchPtr =
float_allocator.aligned_alloc<void>(0, 16, "trtllm_gen_scratch_workspace");
Expand Down Expand Up @@ -213,7 +227,8 @@ void trtllm_paged_attention_decode(
TensorView seq_lens, int64_t max_kv_len, Variant<double, ffi::Tensor> bmm1_scale,
Variant<double, ffi::Tensor> bmm2_scale, double o_sf_scale, int64_t o_sf_vec_size,
int64_t o_sf_start_index, int64_t window_left, int64_t sparse_mla_top_k, int64_t sm_count,
bool enable_pdl, int64_t workspace_size, Optional<TensorView> attention_sinks) {
bool enable_pdl, int64_t workspace_size, Optional<TensorView> attention_sinks,
Optional<TensorView> k_cache_scales, Optional<TensorView> v_cache_scales, Optional<TensorView> lse) {
auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype());
auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype());
TVM_FFI_ICHECK_EQ(key_cache.ndim(), value_cache.ndim());
Expand Down Expand Up @@ -249,8 +264,12 @@ void trtllm_paged_attention_decode(
int num_kv_heads = key_cache.size(-3);
int kv_stride_keys_values = key_cache.stride(-2); // key/values
int kv_stride_heads = key_cache.stride(-3); // head

int kv_stride_batch = key_cache.stride(0); // batch
if (is_4bit(kv_data_type)) {
kv_stride_keys_values *= 2;
kv_stride_heads *= 2;
kv_stride_batch *= 2;
}

const auto stream = get_stream(query.device());
void* output_sf_ptr =
Expand Down Expand Up @@ -281,17 +300,39 @@ void trtllm_paged_attention_decode(
float* bmm2_scale_ptr = maybe_bmm2_scale_tensor.has_value()
? static_cast<float*>(maybe_bmm2_scale_tensor.value().data_ptr())
: nullptr;

float* lse_ptr = nullptr;
int lse_stride_tokens = 0;
int lse_stride_heads = 0;
if (lse.has_value()) {
TVM_FFI_ICHECK_EQ(lse.value().dtype(), dl_float32) << "lse must be a float tensor";
lse_ptr = static_cast<float*>(lse.value().data_ptr());
lse_stride_tokens = lse.value().stride(0);
lse_stride_heads = lse.value().stride(2);
}

void* k_cache_scales_ptr = nullptr;
void* v_cache_scales_ptr = nullptr;
if (k_cache_scales.has_value()) {
k_cache_scales_ptr = k_cache_scales.value().data_ptr();
}
if (v_cache_scales.has_value()) {
v_cache_scales_ptr = v_cache_scales.value().data_ptr();
}

trtllm_paged_attention_launcher(
out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(),
workspace_buffer.data_ptr(), static_cast<int*>(block_tables.data_ptr()),
static_cast<int*>(seq_lens.data_ptr()),
/*cum_seq_lens_q=*/nullptr,
/*cum_seq_lens_kv=*/nullptr, attention_sinks_ptr, q_data_type, kv_data_type, o_data_type,
/*cum_seq_lens_kv=*/nullptr, attention_sinks_ptr, lse_ptr,
k_cache_scales_ptr, v_cache_scales_ptr,
q_data_type, kv_data_type, o_data_type,
TllmPagedAttentionMode::ForGen, batch_size, /*max_q_len=*/q_len_per_request, max_kv_len,
num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size,
kv_stride_keys_values, kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq,
bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale,
o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sparse_mla_top_k, sm_count,
o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sparse_mla_top_k, lse_stride_tokens, lse_stride_heads, sm_count,
enable_pdl, workspace_size, stream);
}

Expand All @@ -302,7 +343,9 @@ void trtllm_paged_attention_context(
Variant<double, ffi::Tensor> bmm1_scale, Variant<double, ffi::Tensor> bmm2_scale,
double o_sf_scale, int64_t o_sf_vec_size, int64_t o_sf_start_index, int64_t batch_size,
int64_t window_left, TensorView cum_seq_lens_q, TensorView cum_seq_lens_kv, int64_t sm_count,
bool enable_pdl, int64_t workspace_size, Optional<TensorView> attention_sinks) {
bool enable_pdl, int64_t workspace_size, Optional<TensorView> attention_sinks,
Optional<TensorView> k_cache_scales, Optional<TensorView> v_cache_scales, Optional<TensorView> lse) {

auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype());
auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype());
auto o_data_type = dl_dtype_to_tllm_data_type(out.dtype());
Expand All @@ -329,7 +372,12 @@ void trtllm_paged_attention_context(
int kv_stride_keys_values = key_cache.stride(-2); // key/values
int kv_stride_heads = key_cache.stride(-3); // head
int kv_stride_batch = key_cache.stride(0); // batch

if (is_4bit(kv_data_type)) {
kv_stride_keys_values *= 2;
kv_stride_heads *= 2;
kv_stride_batch *= 2;
}

const auto stream = get_stream(query.device());
void* output_sf_ptr =
out_scale_factor.has_value() ? out_scale_factor.value().data_ptr() : nullptr;
Expand Down Expand Up @@ -361,18 +409,38 @@ void trtllm_paged_attention_context(
? static_cast<float*>(maybe_bmm2_scale_tensor.value().data_ptr())
: nullptr;

float* lse_ptr = nullptr;
int lse_stride_tokens = 0;
int lse_stride_heads = 0;
if (lse.has_value()) {
TVM_FFI_ICHECK_EQ(lse.value().dtype(), dl_float32) << "lse must be a float tensor";
lse_ptr = static_cast<float*>(lse.value().data_ptr());
lse_stride_tokens = lse.value().stride(0);
lse_stride_heads = lse.value().stride(1);
}

void* k_cache_scales_ptr = nullptr;
void* v_cache_scales_ptr = nullptr;
if (k_cache_scales.has_value()) {
k_cache_scales_ptr = k_cache_scales.value().data_ptr();
}
if (v_cache_scales.has_value()) {
v_cache_scales_ptr = v_cache_scales.value().data_ptr();
}

trtllm_paged_attention_launcher(
out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(),
workspace_buffer.data_ptr(), static_cast<int*>(block_tables.data_ptr()),
static_cast<int*>(seq_lens.data_ptr()),
/*cum_seq_lens_q=*/static_cast<int*>(cum_seq_lens_q.data_ptr()),
/*cum_seq_lens_kv=*/static_cast<int*>(cum_seq_lens_kv.data_ptr()), attention_sinks_ptr,
/*cum_seq_lens_kv=*/static_cast<int*>(cum_seq_lens_kv.data_ptr()), attention_sinks_ptr, lse_ptr,
k_cache_scales_ptr, v_cache_scales_ptr,
q_data_type, kv_data_type, o_data_type, TllmPagedAttentionMode::Context, batch_size,
max_q_len, max_kv_len, num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q,
head_dim_o, page_size, kv_stride_keys_values, kv_stride_heads, kv_stride_batch,
max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr,
bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q,
/*sparse_mla_top_k=*/0, sm_count, enable_pdl, workspace_size, stream);
/*sparse_mla_top_k=*/0, lse_stride_tokens, lse_stride_heads,sm_count, enable_pdl, workspace_size, stream);
}

void trtllm_ragged_attention_launcher(
Expand All @@ -385,6 +453,7 @@ void trtllm_ragged_attention_launcher(
int64_t batch_size, int64_t window_left, int64_t sm_count, bool enable_pdl, bool is_causal,
int64_t k_stride_keys_values, int64_t k_stride_heads, int64_t k_stride_batch,
int64_t v_stride_keys_values, int64_t v_stride_heads, int64_t v_stride_batch,
int64_t lse_stride_tokens, int64_t lse_stride_heads,
int64_t workspace_size, cudaStream_t stream) {
if (num_qo_heads % num_kv_heads != 0) {
std::ostringstream err_msg;
Expand Down Expand Up @@ -441,6 +510,8 @@ void trtllm_ragged_attention_launcher(
runner_params.mMaskType =
is_causal ? TrtllmGenAttentionMaskType::Causal : TrtllmGenAttentionMaskType::Dense;
runner_params.lsePtr = lse;
runner_params.lseStrideTokens = lse_stride_tokens;
runner_params.lseStrideHeads = lse_stride_heads;

AlignedAllocator float_allocator(workspace_buffer, workspace_size);
size_t max_batch_size = 8192;
Expand Down Expand Up @@ -482,9 +553,13 @@ void trtllm_ragged_attention(TensorView out, TensorView query, TensorView key, T
attention_sinks_ptr = static_cast<float*>(attention_sinks.value().data_ptr());
}
float* lse_ptr = nullptr;
int lse_stride_tokens = 0;
int lse_stride_heads = 0;
if (lse.has_value()) {
TVM_FFI_ICHECK_EQ(lse.value().dtype(), dl_float32) << "lse must be a float tensor";
lse_ptr = static_cast<float*>(lse.value().data_ptr());
lse_stride_tokens = lse.value().stride(0);
lse_stride_heads = lse.value().stride(1);
}
TVM_FFI_ICHECK_EQ(out.ndim(), 3) << "out must be a 3D tensor";
TVM_FFI_ICHECK_EQ(query.ndim(), 3) << "query must be a 3D tensor";
Expand Down Expand Up @@ -535,7 +610,7 @@ void trtllm_ragged_attention(TensorView out, TensorView query, TensorView key, T
num_qo_heads, num_kv_heads, head_dim_qk, head_dim_v, sum_seq_q, sum_seq_kv, bmm1_scale_value,
bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale, batch_size, window_left,
sm_count, enable_pdl, is_causal, k_stride_keys_values, k_stride_heads, k_stride_batch,
v_stride_keys_values, v_stride_heads, v_stride_batch, workspace_size, stream);
v_stride_keys_values, v_stride_heads, v_stride_batch, lse_stride_tokens, lse_stride_heads, workspace_size, stream);
}

namespace trtllm_cubin_loader {
Expand Down
14 changes: 7 additions & 7 deletions csrc/trtllm_fused_moe_routing_deepseek.cu
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ __global__ void routingMainKernel(KernelParams params) {
}
}

// note that for invalid scores, we simply use a negative value:
// they work well even with the compacted format used in topK, and
// sigmoid / bias activated scores cannot be negative
static constexpr float invalidScoreFloat = -1.F;
const OutputT invalidScore = OutputT{invalidScoreFloat};
// note that for invalid scores, we use a very negative value:
// they work well even with the compacted format used in topK.
// With negative bias allowed, sigmoid + bias scores can be negative,
// so we use a very negative value to ensure invalid scores are always less than valid ones
static constexpr float invalidScoreFloat = -1e10F;

// load bias already; each warp represents one expert group
auto threadExpert = threadIdx.x;
Expand Down Expand Up @@ -101,8 +101,8 @@ __global__ void routingMainKernel(KernelParams params) {
smemScoreSigmoid[threadExpert] = scoreSigmoid;
}
// get the score with bias
// note that with invalid values, because sigmoid is < 1 and bias is -1,
// we must get a negative value, which is smaller than any valid value
// note: with invalid values, we use invalidScoreFloat which is guaranteed
// to be smaller than any valid value (even when bias is negative)
auto scoreBias = float{scoreSigmoid + float{biasVal}};

if (expertSelected) {
Expand Down
61 changes: 59 additions & 2 deletions csrc/trtllm_moe_allreduce_fusion.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,28 @@ using tvm::ffi::Optional;
} \
}()

#define DISPATCH_FLOATING_TYPES_FOR_SCALE(dtype, ScaleType, ...) \
[&] { \
switch (encode_dlpack_dtype(dtype)) { \
case float16_code: { \
using ScaleType = half; \
return __VA_ARGS__(); \
} \
case bfloat16_code: { \
using ScaleType = __nv_bfloat16; \
return __VA_ARGS__(); \
} \
case float32_code: { \
using ScaleType = float; \
return __VA_ARGS__(); \
} \
default: \
TVM_FFI_LOG_AND_THROW(NotImplementedError) \
<< "Unsupported expert_scale_factor dtype; only float16, bfloat16 " \
"and float32 are supported in trtllm_moe_finalize_allreduce_fusion"; \
} \
}()

void trtllm_moe_allreduce_fusion(
int64_t world_size, int64_t world_rank, int64_t token_num, int64_t hidden_size,
TensorView workspace_ptrs, bool launch_with_pdl, TensorView residual_in, TensorView rms_gamma,
Expand Down Expand Up @@ -86,12 +108,13 @@ void trtllm_moe_finalize_allreduce_fusion(
TensorView expanded_idx_to_permuted_idx, TensorView norm_out, TensorView residual_out,
bool launch_with_pdl, TensorView workspace, int64_t const world_rank, int64_t const world_size,
double const eps, Optional<TensorView> shared_expert_output,
Optional<TensorView> expert_scale_factor) {
Optional<TensorView> expert_scale_factor, Optional<float> routing_scaling_factor) {
DISPATCH_FLOATING_TYPES_FOR_ALLREDUCE(residual_in.dtype(), c_type, [&] {
MoeFinalizeAllReduceFusionParams<c_type> params;

int hidden_dim = residual_in.size(-1);
int top_k = expanded_idx_to_permuted_idx.size(-1);
int num_tokens = residual_in.size(0);

params.quant_out = nullptr;
params.scale_out = nullptr;
Expand All @@ -101,6 +124,14 @@ void trtllm_moe_finalize_allreduce_fusion(
// size: num_token * hidden_dim
params.size = residual_in.numel();
params.hidden_dim = hidden_dim;
FLASHINFER_CHECK(
expanded_idx_to_permuted_idx.size(0) == num_tokens,
"expanded_idx_to_permuted_idx.size(0) must equal num_tokens, got "
"expanded_idx_to_permuted_idx.size(0)=%d and num_tokens=%d",
expanded_idx_to_permuted_idx.size(0), num_tokens);
params.routing_scaling_factor =
routing_scaling_factor.has_value() ? static_cast<float>(routing_scaling_factor.value())
: 1.0f;

// workspace: AR scratch space
params.workspace = reinterpret_cast<void**>(workspace.data_ptr());
Expand All @@ -125,7 +156,33 @@ void trtllm_moe_finalize_allreduce_fusion(
params.norm_out = norm_out.data_ptr();
params.residual_out = residual_out.data_ptr();

auto status = moefinalize_allreduce_fusion_op(params, launch_with_pdl);
// Record norm_out dtype so kernels can specialize behavior if needed.
params.norm_out_dtype = encode_dlpack_dtype(norm_out.dtype());

// Dispatch on expert_scale_factor dtype for ScaleType. If none is provided, default to float
// and select NormOutT based on norm_out dtype.
cudaError_t status;
if (!expert_scale_factor.has_value()) {
auto norm_dtype = encode_dlpack_dtype(norm_out.dtype());
if (norm_dtype == float8_e4m3fn_code) {
status =
moefinalize_allreduce_fusion_op<c_type, float, __nv_fp8_e4m3>(params, launch_with_pdl);
} else {
status = moefinalize_allreduce_fusion_op<c_type, float, c_type>(params, launch_with_pdl);
}
} else {
DISPATCH_FLOATING_TYPES_FOR_SCALE(
expert_scale_factor.value().dtype(), ScaleType, [&]() {
auto norm_dtype = encode_dlpack_dtype(norm_out.dtype());
if (norm_dtype == float8_e4m3fn_code) {
status = moefinalize_allreduce_fusion_op<c_type, ScaleType, __nv_fp8_e4m3>(
params, launch_with_pdl);
} else {
status =
moefinalize_allreduce_fusion_op<c_type, ScaleType, c_type>(params, launch_with_pdl);
}
});
}
TVM_FFI_ICHECK(status == cudaSuccess)
<< "moefinalize_allreduce_fusion_op failed with error code " << cudaGetErrorString(status);
});
Expand Down
Loading
Loading