Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion csrc/fmhaReduction.cu
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ __global__ void __launch_bounds__(NumThreadsPerCta, 2)
// The O pointer.
DtypeO* oPtr = reinterpret_cast<DtypeO*>(params.ptrO) + oOffset;
// The attentionSinks pointer.
float const* attentionSinksPtr = params.ptrAttentionSinks + headIdxO;
float const* attentionSinksPtr = params.ptrAttentionSinks != nullptr ? (params.ptrAttentionSinks + headIdxO) : nullptr;

// Whether to store the softmax stats.
bool const storesSoftmaxStats{params.ptrSoftmaxStats != nullptr};
Expand Down
97 changes: 86 additions & 11 deletions csrc/trtllm_fmha_kernel_launcher.cu
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,16 @@ class TllmGenFmhaRunnerCache {
void trtllm_paged_attention_launcher(
void* out, void* out_scale_factor, void* query, void* key_cache, void* value_cache,
void* workspace_buffer, int* block_tables, int* seq_lens, int* cum_seq_lens_q,
int* cum_seq_lens_kv, float* attention_sinks, Data_type q_data_type, Data_type kv_data_type,
int* cum_seq_lens_kv, float* attention_sinks, float* lse,
void* k_cache_scales, void* v_cache_scales,
Data_type q_data_type, Data_type kv_data_type,
Data_type o_data_type, TllmPagedAttentionMode mode, int64_t batch_size, int64_t max_q_len,
int64_t max_kv_len, int64_t num_pages_in_mem_pool, int64_t num_qo_heads, int64_t num_kv_heads,
int64_t head_dim_qk, int64_t head_dim_vo, int64_t page_size, int64_t kv_stride_keys_values,
int64_t kv_stride_heads, int64_t kv_stride_batch, int64_t max_num_blocks_per_seq,
double bmm1_scale, double bmm2_scale, const float* bmm1_scale_log2_ptr,
const float* bmm2_scale_ptr, double o_sf_scale, int64_t o_sf_vec_size, int64_t o_sf_start_index,
int64_t window_left, int64_t sum_seq_q, int64_t sparse_mla_top_k, int64_t sm_count,
int64_t window_left, int64_t sum_seq_q, int64_t sparse_mla_top_k, int64_t lse_stride_tokens, int64_t lse_stride_heads, int64_t sm_count,
bool enable_pdl, int64_t workspace_size, cudaStream_t stream) {
if (num_qo_heads % num_kv_heads != 0) {
std::ostringstream err_msg;
Expand All @@ -97,7 +99,9 @@ void trtllm_paged_attention_launcher(
// Common params
runner_params.qPtr = query;
runner_params.kPtr = key_cache;
runner_params.kSfBasePtr = k_cache_scales;
runner_params.vPtr = value_cache;
runner_params.vSfBasePtr = v_cache_scales;
runner_params.kvPageIdxPtr = block_tables;
runner_params.seqLensKvPtr = seq_lens;
runner_params.oPtr = out;
Expand Down Expand Up @@ -146,6 +150,9 @@ void trtllm_paged_attention_launcher(
<< "Only decode MLA supports sparse MLA";

AlignedAllocator float_allocator(workspace_buffer, workspace_size);
runner_params.lsePtr = lse;
runner_params.lseStrideTokens = lse_stride_tokens;
runner_params.lseStrideHeads = lse_stride_heads;
if (mode == TllmPagedAttentionMode::Context) {
runner_params.mMaskType = TrtllmGenAttentionMaskType::Causal;
runner_params.mKernelType = FmhaKernelType::Context;
Expand All @@ -154,6 +161,10 @@ void trtllm_paged_attention_launcher(

runner_params.cumSeqLensQPtr = cum_seq_lens_q;
runner_params.cumSeqLensKvPtr = cum_seq_lens_kv;

runner_params.softmaxStatsPtr = float_allocator.aligned_alloc<float2>(
sizeof(float2) * num_qo_heads * runner_params.mSumOfSeqLensQ, 16,
"trtllm_gen_softmax_workspace");
} else {
// ForGen
runner_params.mMaskType = TrtllmGenAttentionMaskType::Dense;
Expand All @@ -171,6 +182,9 @@ void trtllm_paged_attention_launcher(
// todo(Yingyi): add softmax buffer later for lse return
runner_params.multiCtasKvCounterPtr = float_allocator.aligned_alloc<int32_t>(
num_semaphores * sizeof(uint32_t), 16, "trtllm_gen_counter_workspace");
runner_params.softmaxStatsPtr = float_allocator.aligned_alloc<float2>(
sizeof(float2) * num_qo_heads * runner_params.mSumOfSeqLensQ, 16,
"trtllm_gen_softmax_workspace");
// scratch takes the rest of the workspace buffer
runner_params.multiCtasKvScratchPtr =
float_allocator.aligned_alloc<void>(0, 16, "trtllm_gen_scratch_workspace");
Expand Down Expand Up @@ -213,7 +227,8 @@ void trtllm_paged_attention_decode(
TensorView seq_lens, int64_t max_kv_len, Variant<double, ffi::Tensor> bmm1_scale,
Variant<double, ffi::Tensor> bmm2_scale, double o_sf_scale, int64_t o_sf_vec_size,
int64_t o_sf_start_index, int64_t window_left, int64_t sparse_mla_top_k, int64_t sm_count,
bool enable_pdl, int64_t workspace_size, Optional<TensorView> attention_sinks) {
bool enable_pdl, int64_t workspace_size, Optional<TensorView> attention_sinks,
Optional<TensorView> k_cache_scales, Optional<TensorView> v_cache_scales, Optional<TensorView> lse) {
auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype());
auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype());
TVM_FFI_ICHECK_EQ(key_cache.ndim(), value_cache.ndim());
Expand Down Expand Up @@ -249,8 +264,12 @@ void trtllm_paged_attention_decode(
int num_kv_heads = key_cache.size(-3);
int kv_stride_keys_values = key_cache.stride(-2); // key/values
int kv_stride_heads = key_cache.stride(-3); // head

int kv_stride_batch = key_cache.stride(0); // batch
if (is_4bit(kv_data_type)) {
kv_stride_keys_values *= 2;
kv_stride_heads *= 2;
kv_stride_batch *= 2;
}

const auto stream = get_stream(query.device());
void* output_sf_ptr =
Expand Down Expand Up @@ -281,17 +300,39 @@ void trtllm_paged_attention_decode(
float* bmm2_scale_ptr = maybe_bmm2_scale_tensor.has_value()
? static_cast<float*>(maybe_bmm2_scale_tensor.value().data_ptr())
: nullptr;

float* lse_ptr = nullptr;
int lse_stride_tokens = 0;
int lse_stride_heads = 0;
if (lse.has_value()) {
TVM_FFI_ICHECK_EQ(lse.value().dtype(), dl_float32) << "lse must be a float tensor";
lse_ptr = static_cast<float*>(lse.value().data_ptr());
lse_stride_tokens = lse.value().stride(0);
lse_stride_heads = lse.value().stride(2);
}

void* k_cache_scales_ptr = nullptr;
void* v_cache_scales_ptr = nullptr;
if (k_cache_scales.has_value()) {
k_cache_scales_ptr = k_cache_scales.value().data_ptr();
}
if (v_cache_scales.has_value()) {
v_cache_scales_ptr = v_cache_scales.value().data_ptr();
}

trtllm_paged_attention_launcher(
out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(),
workspace_buffer.data_ptr(), static_cast<int*>(block_tables.data_ptr()),
static_cast<int*>(seq_lens.data_ptr()),
/*cum_seq_lens_q=*/nullptr,
/*cum_seq_lens_kv=*/nullptr, attention_sinks_ptr, q_data_type, kv_data_type, o_data_type,
/*cum_seq_lens_kv=*/nullptr, attention_sinks_ptr, lse_ptr,
k_cache_scales_ptr, v_cache_scales_ptr,
q_data_type, kv_data_type, o_data_type,
TllmPagedAttentionMode::ForGen, batch_size, /*max_q_len=*/q_len_per_request, max_kv_len,
num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size,
kv_stride_keys_values, kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq,
bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale,
o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sparse_mla_top_k, sm_count,
o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sparse_mla_top_k, lse_stride_tokens, lse_stride_heads, sm_count,
enable_pdl, workspace_size, stream);
}

Expand All @@ -302,7 +343,9 @@ void trtllm_paged_attention_context(
Variant<double, ffi::Tensor> bmm1_scale, Variant<double, ffi::Tensor> bmm2_scale,
double o_sf_scale, int64_t o_sf_vec_size, int64_t o_sf_start_index, int64_t batch_size,
int64_t window_left, TensorView cum_seq_lens_q, TensorView cum_seq_lens_kv, int64_t sm_count,
bool enable_pdl, int64_t workspace_size, Optional<TensorView> attention_sinks) {
bool enable_pdl, int64_t workspace_size, Optional<TensorView> attention_sinks,
Optional<TensorView> k_cache_scales, Optional<TensorView> v_cache_scales, Optional<TensorView> lse) {

auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype());
auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype());
auto o_data_type = dl_dtype_to_tllm_data_type(out.dtype());
Expand All @@ -329,7 +372,12 @@ void trtllm_paged_attention_context(
int kv_stride_keys_values = key_cache.stride(-2); // key/values
int kv_stride_heads = key_cache.stride(-3); // head
int kv_stride_batch = key_cache.stride(0); // batch

if (is_4bit(kv_data_type)) {
kv_stride_keys_values *= 2;
kv_stride_heads *= 2;
kv_stride_batch *= 2;
}

const auto stream = get_stream(query.device());
void* output_sf_ptr =
out_scale_factor.has_value() ? out_scale_factor.value().data_ptr() : nullptr;
Expand Down Expand Up @@ -361,18 +409,38 @@ void trtllm_paged_attention_context(
? static_cast<float*>(maybe_bmm2_scale_tensor.value().data_ptr())
: nullptr;

float* lse_ptr = nullptr;
int lse_stride_tokens = 0;
int lse_stride_heads = 0;
if (lse.has_value()) {
TVM_FFI_ICHECK_EQ(lse.value().dtype(), dl_float32) << "lse must be a float tensor";
lse_ptr = static_cast<float*>(lse.value().data_ptr());
lse_stride_tokens = lse.value().stride(0);
lse_stride_heads = lse.value().stride(1);
}

void* k_cache_scales_ptr = nullptr;
void* v_cache_scales_ptr = nullptr;
if (k_cache_scales.has_value()) {
k_cache_scales_ptr = k_cache_scales.value().data_ptr();
}
if (v_cache_scales.has_value()) {
v_cache_scales_ptr = v_cache_scales.value().data_ptr();
}

trtllm_paged_attention_launcher(
out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(),
workspace_buffer.data_ptr(), static_cast<int*>(block_tables.data_ptr()),
static_cast<int*>(seq_lens.data_ptr()),
/*cum_seq_lens_q=*/static_cast<int*>(cum_seq_lens_q.data_ptr()),
/*cum_seq_lens_kv=*/static_cast<int*>(cum_seq_lens_kv.data_ptr()), attention_sinks_ptr,
/*cum_seq_lens_kv=*/static_cast<int*>(cum_seq_lens_kv.data_ptr()), attention_sinks_ptr, lse_ptr,
k_cache_scales_ptr, v_cache_scales_ptr,
q_data_type, kv_data_type, o_data_type, TllmPagedAttentionMode::Context, batch_size,
max_q_len, max_kv_len, num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q,
head_dim_o, page_size, kv_stride_keys_values, kv_stride_heads, kv_stride_batch,
max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr,
bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q,
/*sparse_mla_top_k=*/0, sm_count, enable_pdl, workspace_size, stream);
/*sparse_mla_top_k=*/0, lse_stride_tokens, lse_stride_heads,sm_count, enable_pdl, workspace_size, stream);
}

void trtllm_ragged_attention_launcher(
Expand All @@ -385,6 +453,7 @@ void trtllm_ragged_attention_launcher(
int64_t batch_size, int64_t window_left, int64_t sm_count, bool enable_pdl, bool is_causal,
int64_t k_stride_keys_values, int64_t k_stride_heads, int64_t k_stride_batch,
int64_t v_stride_keys_values, int64_t v_stride_heads, int64_t v_stride_batch,
int64_t lse_stride_tokens, int64_t lse_stride_heads,
int64_t workspace_size, cudaStream_t stream) {
if (num_qo_heads % num_kv_heads != 0) {
std::ostringstream err_msg;
Expand Down Expand Up @@ -441,6 +510,8 @@ void trtllm_ragged_attention_launcher(
runner_params.mMaskType =
is_causal ? TrtllmGenAttentionMaskType::Causal : TrtllmGenAttentionMaskType::Dense;
runner_params.lsePtr = lse;
runner_params.lseStrideTokens = lse_stride_tokens;
runner_params.lseStrideHeads = lse_stride_heads;

AlignedAllocator float_allocator(workspace_buffer, workspace_size);
size_t max_batch_size = 8192;
Expand Down Expand Up @@ -482,9 +553,13 @@ void trtllm_ragged_attention(TensorView out, TensorView query, TensorView key, T
attention_sinks_ptr = static_cast<float*>(attention_sinks.value().data_ptr());
}
float* lse_ptr = nullptr;
int lse_stride_tokens = 0;
int lse_stride_heads = 0;
if (lse.has_value()) {
TVM_FFI_ICHECK_EQ(lse.value().dtype(), dl_float32) << "lse must be a float tensor";
lse_ptr = static_cast<float*>(lse.value().data_ptr());
lse_stride_tokens = lse.value().stride(0);
lse_stride_heads = lse.value().stride(1);
}
TVM_FFI_ICHECK_EQ(out.ndim(), 3) << "out must be a 3D tensor";
TVM_FFI_ICHECK_EQ(query.ndim(), 3) << "query must be a 3D tensor";
Expand Down Expand Up @@ -535,7 +610,7 @@ void trtllm_ragged_attention(TensorView out, TensorView query, TensorView key, T
num_qo_heads, num_kv_heads, head_dim_qk, head_dim_v, sum_seq_q, sum_seq_kv, bmm1_scale_value,
bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale, batch_size, window_left,
sm_count, enable_pdl, is_causal, k_stride_keys_values, k_stride_heads, k_stride_batch,
v_stride_keys_values, v_stride_heads, v_stride_batch, workspace_size, stream);
v_stride_keys_values, v_stride_heads, v_stride_batch, lse_stride_tokens, lse_stride_heads, workspace_size, stream);
}

namespace trtllm_cubin_loader {
Expand Down
14 changes: 7 additions & 7 deletions csrc/trtllm_fused_moe_routing_deepseek.cu
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ __global__ void routingMainKernel(KernelParams params) {
}
}

// note that for invalid scores, we simply use a negative value:
// they work well even with the compacted format used in topK, and
// sigmoid / bias activated scores cannot be negative
static constexpr float invalidScoreFloat = -1.F;
const OutputT invalidScore = OutputT{invalidScoreFloat};
// note that for invalid scores, we use a very negative value:
// they work well even with the compacted format used in topK.
// With negative bias allowed, sigmoid + bias scores can be negative,
// so we use a very negative value to ensure invalid scores are always less than valid ones
static constexpr float invalidScoreFloat = -1e10F;

// load bias already; each warp represents one expert group
auto threadExpert = threadIdx.x;
Expand Down Expand Up @@ -101,8 +101,8 @@ __global__ void routingMainKernel(KernelParams params) {
smemScoreSigmoid[threadExpert] = scoreSigmoid;
}
// get the score with bias
// note that with invalid values, because sigmoid is < 1 and bias is -1,
// we must get a negative value, which is smaller than any valid value
// note: with invalid values, we use invalidScoreFloat which is guaranteed
// to be smaller than any valid value (even when bias is negative)
auto scoreBias = float{scoreSigmoid + float{biasVal}};

if (expertSelected) {
Expand Down
61 changes: 59 additions & 2 deletions csrc/trtllm_moe_allreduce_fusion.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,28 @@ using tvm::ffi::Optional;
} \
}()

#define DISPATCH_FLOATING_TYPES_FOR_SCALE(dtype, ScaleType, ...) \
[&] { \
switch (encode_dlpack_dtype(dtype)) { \
case float16_code: { \
using ScaleType = half; \
return __VA_ARGS__(); \
} \
case bfloat16_code: { \
using ScaleType = __nv_bfloat16; \
return __VA_ARGS__(); \
} \
case float32_code: { \
using ScaleType = float; \
return __VA_ARGS__(); \
} \
default: \
TVM_FFI_LOG_AND_THROW(NotImplementedError) \
<< "Unsupported expert_scale_factor dtype; only float16, bfloat16 " \
"and float32 are supported in trtllm_moe_finalize_allreduce_fusion"; \
} \
}()

void trtllm_moe_allreduce_fusion(
int64_t world_size, int64_t world_rank, int64_t token_num, int64_t hidden_size,
TensorView workspace_ptrs, bool launch_with_pdl, TensorView residual_in, TensorView rms_gamma,
Expand Down Expand Up @@ -86,12 +108,13 @@ void trtllm_moe_finalize_allreduce_fusion(
TensorView expanded_idx_to_permuted_idx, TensorView norm_out, TensorView residual_out,
bool launch_with_pdl, TensorView workspace, int64_t const world_rank, int64_t const world_size,
double const eps, Optional<TensorView> shared_expert_output,
Optional<TensorView> expert_scale_factor) {
Optional<TensorView> expert_scale_factor, Optional<float> routing_scaling_factor) {
DISPATCH_FLOATING_TYPES_FOR_ALLREDUCE(residual_in.dtype(), c_type, [&] {
MoeFinalizeAllReduceFusionParams<c_type> params;

int hidden_dim = residual_in.size(-1);
int top_k = expanded_idx_to_permuted_idx.size(-1);
int num_tokens = residual_in.size(0);

params.quant_out = nullptr;
params.scale_out = nullptr;
Expand All @@ -101,6 +124,14 @@ void trtllm_moe_finalize_allreduce_fusion(
// size: num_token * hidden_dim
params.size = residual_in.numel();
params.hidden_dim = hidden_dim;
FLASHINFER_CHECK(
expanded_idx_to_permuted_idx.size(0) == num_tokens,
"expanded_idx_to_permuted_idx.size(0) must equal num_tokens, got "
"expanded_idx_to_permuted_idx.size(0)=%d and num_tokens=%d",
expanded_idx_to_permuted_idx.size(0), num_tokens);
params.routing_scaling_factor =
routing_scaling_factor.has_value() ? static_cast<float>(routing_scaling_factor.value())
: 1.0f;

// workspace: AR scratch space
params.workspace = reinterpret_cast<void**>(workspace.data_ptr());
Expand All @@ -125,7 +156,33 @@ void trtllm_moe_finalize_allreduce_fusion(
params.norm_out = norm_out.data_ptr();
params.residual_out = residual_out.data_ptr();

auto status = moefinalize_allreduce_fusion_op(params, launch_with_pdl);
// Record norm_out dtype so kernels can specialize behavior if needed.
params.norm_out_dtype = encode_dlpack_dtype(norm_out.dtype());

// Dispatch on expert_scale_factor dtype for ScaleType. If none is provided, default to float
// and select NormOutT based on norm_out dtype.
cudaError_t status;
if (!expert_scale_factor.has_value()) {
auto norm_dtype = encode_dlpack_dtype(norm_out.dtype());
if (norm_dtype == float8_e4m3fn_code) {
status =
moefinalize_allreduce_fusion_op<c_type, float, __nv_fp8_e4m3>(params, launch_with_pdl);
} else {
status = moefinalize_allreduce_fusion_op<c_type, float, c_type>(params, launch_with_pdl);
}
} else {
DISPATCH_FLOATING_TYPES_FOR_SCALE(
expert_scale_factor.value().dtype(), ScaleType, [&]() {
auto norm_dtype = encode_dlpack_dtype(norm_out.dtype());
if (norm_dtype == float8_e4m3fn_code) {
status = moefinalize_allreduce_fusion_op<c_type, ScaleType, __nv_fp8_e4m3>(
params, launch_with_pdl);
} else {
status =
moefinalize_allreduce_fusion_op<c_type, ScaleType, c_type>(params, launch_with_pdl);
}
});
}
TVM_FFI_ICHECK(status == cudaSuccess)
<< "moefinalize_allreduce_fusion_op failed with error code " << cudaGetErrorString(status);
});
Expand Down
Loading
Loading