diff --git a/csrc/fmhaReduction.cu b/csrc/fmhaReduction.cu index e329e1c1..4f309365 100644 --- a/csrc/fmhaReduction.cu +++ b/csrc/fmhaReduction.cu @@ -111,7 +111,7 @@ __global__ void __launch_bounds__(NumThreadsPerCta, 2) // The O pointer. DtypeO* oPtr = reinterpret_cast(params.ptrO) + oOffset; // The attentionSinks pointer. - float const* attentionSinksPtr = params.ptrAttentionSinks + headIdxO; + float const* attentionSinksPtr = params.ptrAttentionSinks != nullptr ? (params.ptrAttentionSinks + headIdxO) : nullptr; // Whether to store the softmax stats. bool const storesSoftmaxStats{params.ptrSoftmaxStats != nullptr}; diff --git a/csrc/trtllm_fmha_kernel_launcher.cu b/csrc/trtllm_fmha_kernel_launcher.cu index 89fe53b8..5e4b240b 100644 --- a/csrc/trtllm_fmha_kernel_launcher.cu +++ b/csrc/trtllm_fmha_kernel_launcher.cu @@ -75,14 +75,16 @@ class TllmGenFmhaRunnerCache { void trtllm_paged_attention_launcher( void* out, void* out_scale_factor, void* query, void* key_cache, void* value_cache, void* workspace_buffer, int* block_tables, int* seq_lens, int* cum_seq_lens_q, - int* cum_seq_lens_kv, float* attention_sinks, Data_type q_data_type, Data_type kv_data_type, + int* cum_seq_lens_kv, float* attention_sinks, float* lse, + void* k_cache_scales, void* v_cache_scales, + Data_type q_data_type, Data_type kv_data_type, Data_type o_data_type, TllmPagedAttentionMode mode, int64_t batch_size, int64_t max_q_len, int64_t max_kv_len, int64_t num_pages_in_mem_pool, int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim_qk, int64_t head_dim_vo, int64_t page_size, int64_t kv_stride_keys_values, int64_t kv_stride_heads, int64_t kv_stride_batch, int64_t max_num_blocks_per_seq, double bmm1_scale, double bmm2_scale, const float* bmm1_scale_log2_ptr, const float* bmm2_scale_ptr, double o_sf_scale, int64_t o_sf_vec_size, int64_t o_sf_start_index, - int64_t window_left, int64_t sum_seq_q, int64_t sparse_mla_top_k, int64_t sm_count, + int64_t window_left, int64_t sum_seq_q, int64_t 
sparse_mla_top_k, int64_t lse_stride_tokens, int64_t lse_stride_heads, int64_t sm_count, bool enable_pdl, int64_t workspace_size, cudaStream_t stream) { if (num_qo_heads % num_kv_heads != 0) { std::ostringstream err_msg; @@ -97,7 +99,9 @@ void trtllm_paged_attention_launcher( // Common params runner_params.qPtr = query; runner_params.kPtr = key_cache; + runner_params.kSfBasePtr = k_cache_scales; runner_params.vPtr = value_cache; + runner_params.vSfBasePtr = v_cache_scales; runner_params.kvPageIdxPtr = block_tables; runner_params.seqLensKvPtr = seq_lens; runner_params.oPtr = out; @@ -146,6 +150,9 @@ void trtllm_paged_attention_launcher( << "Only decode MLA supports sparse MLA"; AlignedAllocator float_allocator(workspace_buffer, workspace_size); + runner_params.lsePtr = lse; + runner_params.lseStrideTokens = lse_stride_tokens; + runner_params.lseStrideHeads = lse_stride_heads; if (mode == TllmPagedAttentionMode::Context) { runner_params.mMaskType = TrtllmGenAttentionMaskType::Causal; runner_params.mKernelType = FmhaKernelType::Context; @@ -154,6 +161,10 @@ void trtllm_paged_attention_launcher( runner_params.cumSeqLensQPtr = cum_seq_lens_q; runner_params.cumSeqLensKvPtr = cum_seq_lens_kv; + + runner_params.softmaxStatsPtr = float_allocator.aligned_alloc( + sizeof(float2) * num_qo_heads * runner_params.mSumOfSeqLensQ, 16, + "trtllm_gen_softmax_workspace"); } else { // ForGen runner_params.mMaskType = TrtllmGenAttentionMaskType::Dense; @@ -171,6 +182,9 @@ void trtllm_paged_attention_launcher( // todo(Yingyi): add softmax buffer later for lse return runner_params.multiCtasKvCounterPtr = float_allocator.aligned_alloc( num_semaphores * sizeof(uint32_t), 16, "trtllm_gen_counter_workspace"); + runner_params.softmaxStatsPtr = float_allocator.aligned_alloc( + sizeof(float2) * num_qo_heads * runner_params.mSumOfSeqLensQ, 16, + "trtllm_gen_softmax_workspace"); // scratch takes the rest of the workspace buffer runner_params.multiCtasKvScratchPtr = float_allocator.aligned_alloc(0, 
16, "trtllm_gen_scratch_workspace"); @@ -213,7 +227,8 @@ void trtllm_paged_attention_decode( TensorView seq_lens, int64_t max_kv_len, Variant bmm1_scale, Variant bmm2_scale, double o_sf_scale, int64_t o_sf_vec_size, int64_t o_sf_start_index, int64_t window_left, int64_t sparse_mla_top_k, int64_t sm_count, - bool enable_pdl, int64_t workspace_size, Optional attention_sinks) { + bool enable_pdl, int64_t workspace_size, Optional attention_sinks, + Optional k_cache_scales, Optional v_cache_scales, Optional lse) { auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype()); auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype()); TVM_FFI_ICHECK_EQ(key_cache.ndim(), value_cache.ndim()); @@ -249,8 +264,12 @@ void trtllm_paged_attention_decode( int num_kv_heads = key_cache.size(-3); int kv_stride_keys_values = key_cache.stride(-2); // key/values int kv_stride_heads = key_cache.stride(-3); // head - int kv_stride_batch = key_cache.stride(0); // batch + if (is_4bit(kv_data_type)) { + kv_stride_keys_values *= 2; + kv_stride_heads *= 2; + kv_stride_batch *= 2; + } const auto stream = get_stream(query.device()); void* output_sf_ptr = @@ -281,17 +300,39 @@ void trtllm_paged_attention_decode( float* bmm2_scale_ptr = maybe_bmm2_scale_tensor.has_value() ? 
static_cast(maybe_bmm2_scale_tensor.value().data_ptr()) : nullptr; + + float* lse_ptr = nullptr; + int lse_stride_tokens = 0; + int lse_stride_heads = 0; + if (lse.has_value()) { + TVM_FFI_ICHECK_EQ(lse.value().dtype(), dl_float32) << "lse must be a float tensor"; + lse_ptr = static_cast(lse.value().data_ptr()); + lse_stride_tokens = lse.value().stride(0); + lse_stride_heads = lse.value().stride(1); + } + + void* k_cache_scales_ptr = nullptr; + void* v_cache_scales_ptr = nullptr; + if (k_cache_scales.has_value()) { + k_cache_scales_ptr = k_cache_scales.value().data_ptr(); + } + if (v_cache_scales.has_value()) { + v_cache_scales_ptr = v_cache_scales.value().data_ptr(); + } + trtllm_paged_attention_launcher( out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(), workspace_buffer.data_ptr(), static_cast(block_tables.data_ptr()), static_cast(seq_lens.data_ptr()), /*cum_seq_lens_q=*/nullptr, - /*cum_seq_lens_kv=*/nullptr, attention_sinks_ptr, q_data_type, kv_data_type, o_data_type, + /*cum_seq_lens_kv=*/nullptr, attention_sinks_ptr, lse_ptr, + k_cache_scales_ptr, v_cache_scales_ptr, + q_data_type, kv_data_type, o_data_type, TllmPagedAttentionMode::ForGen, batch_size, /*max_q_len=*/q_len_per_request, max_kv_len, num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size, kv_stride_keys_values, kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale, - o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sparse_mla_top_k, sm_count, + o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sparse_mla_top_k, lse_stride_tokens, lse_stride_heads, sm_count, enable_pdl, workspace_size, stream); } @@ -302,7 +343,9 @@ void trtllm_paged_attention_context( Variant bmm1_scale, Variant bmm2_scale, double o_sf_scale, int64_t o_sf_vec_size, int64_t o_sf_start_index, int64_t batch_size, int64_t window_left, TensorView
cum_seq_lens_q, TensorView cum_seq_lens_kv, int64_t sm_count, - bool enable_pdl, int64_t workspace_size, Optional attention_sinks) { + bool enable_pdl, int64_t workspace_size, Optional attention_sinks, + Optional k_cache_scales, Optional v_cache_scales, Optional lse) { + auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype()); auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype()); auto o_data_type = dl_dtype_to_tllm_data_type(out.dtype()); @@ -329,7 +372,12 @@ void trtllm_paged_attention_context( int kv_stride_keys_values = key_cache.stride(-2); // key/values int kv_stride_heads = key_cache.stride(-3); // head int kv_stride_batch = key_cache.stride(0); // batch - + if (is_4bit(kv_data_type)) { + kv_stride_keys_values *= 2; + kv_stride_heads *= 2; + kv_stride_batch *= 2; + } + const auto stream = get_stream(query.device()); void* output_sf_ptr = out_scale_factor.has_value() ? out_scale_factor.value().data_ptr() : nullptr; @@ -361,18 +409,38 @@ void trtllm_paged_attention_context( ? 
static_cast(maybe_bmm2_scale_tensor.value().data_ptr()) : nullptr; + float* lse_ptr = nullptr; + int lse_stride_tokens = 0; + int lse_stride_heads = 0; + if (lse.has_value()) { + TVM_FFI_ICHECK_EQ(lse.value().dtype(), dl_float32) << "lse must be a float tensor"; + lse_ptr = static_cast(lse.value().data_ptr()); + lse_stride_tokens = lse.value().stride(0); + lse_stride_heads = lse.value().stride(1); + } + + void* k_cache_scales_ptr = nullptr; + void* v_cache_scales_ptr = nullptr; + if (k_cache_scales.has_value()) { + k_cache_scales_ptr = k_cache_scales.value().data_ptr(); + } + if (v_cache_scales.has_value()) { + v_cache_scales_ptr = v_cache_scales.value().data_ptr(); + } + trtllm_paged_attention_launcher( out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(), workspace_buffer.data_ptr(), static_cast(block_tables.data_ptr()), static_cast(seq_lens.data_ptr()), /*cum_seq_lens_q=*/static_cast(cum_seq_lens_q.data_ptr()), - /*cum_seq_lens_kv=*/static_cast(cum_seq_lens_kv.data_ptr()), attention_sinks_ptr, + /*cum_seq_lens_kv=*/static_cast(cum_seq_lens_kv.data_ptr()), attention_sinks_ptr, lse_ptr, + k_cache_scales_ptr, v_cache_scales_ptr, q_data_type, kv_data_type, o_data_type, TllmPagedAttentionMode::Context, batch_size, max_q_len, max_kv_len, num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size, kv_stride_keys_values, kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, - /*sparse_mla_top_k=*/0, sm_count, enable_pdl, workspace_size, stream); + /*sparse_mla_top_k=*/0, lse_stride_tokens, lse_stride_heads,sm_count, enable_pdl, workspace_size, stream); } void trtllm_ragged_attention_launcher( @@ -385,6 +453,7 @@ void trtllm_ragged_attention_launcher( int64_t batch_size, int64_t window_left, int64_t sm_count, bool enable_pdl, bool is_causal, int64_t 
k_stride_keys_values, int64_t k_stride_heads, int64_t k_stride_batch, int64_t v_stride_keys_values, int64_t v_stride_heads, int64_t v_stride_batch, + int64_t lse_stride_tokens, int64_t lse_stride_heads, int64_t workspace_size, cudaStream_t stream) { if (num_qo_heads % num_kv_heads != 0) { std::ostringstream err_msg; @@ -441,6 +510,8 @@ void trtllm_ragged_attention_launcher( runner_params.mMaskType = is_causal ? TrtllmGenAttentionMaskType::Causal : TrtllmGenAttentionMaskType::Dense; runner_params.lsePtr = lse; + runner_params.lseStrideTokens = lse_stride_tokens; + runner_params.lseStrideHeads = lse_stride_heads; AlignedAllocator float_allocator(workspace_buffer, workspace_size); size_t max_batch_size = 8192; @@ -482,9 +553,13 @@ void trtllm_ragged_attention(TensorView out, TensorView query, TensorView key, T attention_sinks_ptr = static_cast(attention_sinks.value().data_ptr()); } float* lse_ptr = nullptr; + int lse_stride_tokens = 0; + int lse_stride_heads = 0; if (lse.has_value()) { TVM_FFI_ICHECK_EQ(lse.value().dtype(), dl_float32) << "lse must be a float tensor"; lse_ptr = static_cast(lse.value().data_ptr()); + lse_stride_tokens = lse.value().stride(0); + lse_stride_heads = lse.value().stride(1); } TVM_FFI_ICHECK_EQ(out.ndim(), 3) << "out must be a 3D tensor"; TVM_FFI_ICHECK_EQ(query.ndim(), 3) << "query must be a 3D tensor"; @@ -535,7 +610,7 @@ void trtllm_ragged_attention(TensorView out, TensorView query, TensorView key, T num_qo_heads, num_kv_heads, head_dim_qk, head_dim_v, sum_seq_q, sum_seq_kv, bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale, batch_size, window_left, sm_count, enable_pdl, is_causal, k_stride_keys_values, k_stride_heads, k_stride_batch, - v_stride_keys_values, v_stride_heads, v_stride_batch, workspace_size, stream); + v_stride_keys_values, v_stride_heads, v_stride_batch, lse_stride_tokens, lse_stride_heads, workspace_size, stream); } namespace trtllm_cubin_loader { diff --git 
a/csrc/trtllm_fused_moe_routing_deepseek.cu b/csrc/trtllm_fused_moe_routing_deepseek.cu index 7f9a6642..b9411681 100644 --- a/csrc/trtllm_fused_moe_routing_deepseek.cu +++ b/csrc/trtllm_fused_moe_routing_deepseek.cu @@ -56,11 +56,11 @@ __global__ void routingMainKernel(KernelParams params) { } } - // note that for invalid scores, we simply use a negative value: - // they work well even with the compacted format used in topK, and - // sigmoid / bias activated scores cannot be negative - static constexpr float invalidScoreFloat = -1.F; - const OutputT invalidScore = OutputT{invalidScoreFloat}; + // note that for invalid scores, we use a very negative value: + // they work well even with the compacted format used in topK. + // With negative bias allowed, sigmoid + bias scores can be negative, + // so we use a very negative value to ensure invalid scores are always less than valid ones + static constexpr float invalidScoreFloat = -1e10F; // load bias already; each warp represents one expert group auto threadExpert = threadIdx.x; @@ -101,8 +101,8 @@ __global__ void routingMainKernel(KernelParams params) { smemScoreSigmoid[threadExpert] = scoreSigmoid; } // get the score with bias - // note that with invalid values, because sigmoid is < 1 and bias is -1, - // we must get a negative value, which is smaller than any valid value + // note: with invalid values, we use invalidScoreFloat which is guaranteed + // to be smaller than any valid value (even when bias is negative) auto scoreBias = float{scoreSigmoid + float{biasVal}}; if (expertSelected) { diff --git a/csrc/trtllm_moe_allreduce_fusion.cu b/csrc/trtllm_moe_allreduce_fusion.cu index ac1ce171..fbab2a5e 100644 --- a/csrc/trtllm_moe_allreduce_fusion.cu +++ b/csrc/trtllm_moe_allreduce_fusion.cu @@ -24,6 +24,28 @@ using tvm::ffi::Optional; } \ }() +#define DISPATCH_FLOATING_TYPES_FOR_SCALE(dtype, ScaleType, ...) 
\ + [&] { \ + switch (encode_dlpack_dtype(dtype)) { \ + case float16_code: { \ + using ScaleType = half; \ + return __VA_ARGS__(); \ + } \ + case bfloat16_code: { \ + using ScaleType = __nv_bfloat16; \ + return __VA_ARGS__(); \ + } \ + case float32_code: { \ + using ScaleType = float; \ + return __VA_ARGS__(); \ + } \ + default: \ + TVM_FFI_LOG_AND_THROW(NotImplementedError) \ + << "Unsupported expert_scale_factor dtype; only float16, bfloat16 " \ + "and float32 are supported in trtllm_moe_finalize_allreduce_fusion"; \ + } \ + }() + void trtllm_moe_allreduce_fusion( int64_t world_size, int64_t world_rank, int64_t token_num, int64_t hidden_size, TensorView workspace_ptrs, bool launch_with_pdl, TensorView residual_in, TensorView rms_gamma, @@ -86,12 +108,13 @@ void trtllm_moe_finalize_allreduce_fusion( TensorView expanded_idx_to_permuted_idx, TensorView norm_out, TensorView residual_out, bool launch_with_pdl, TensorView workspace, int64_t const world_rank, int64_t const world_size, double const eps, Optional shared_expert_output, - Optional expert_scale_factor) { + Optional expert_scale_factor, Optional routing_scaling_factor) { DISPATCH_FLOATING_TYPES_FOR_ALLREDUCE(residual_in.dtype(), c_type, [&] { MoeFinalizeAllReduceFusionParams params; int hidden_dim = residual_in.size(-1); int top_k = expanded_idx_to_permuted_idx.size(-1); + int num_tokens = residual_in.size(0); params.quant_out = nullptr; params.scale_out = nullptr; @@ -101,6 +124,14 @@ void trtllm_moe_finalize_allreduce_fusion( // size: num_token * hidden_dim params.size = residual_in.numel(); params.hidden_dim = hidden_dim; + FLASHINFER_CHECK( + expanded_idx_to_permuted_idx.size(0) == num_tokens, + "expanded_idx_to_permuted_idx.size(0) must equal num_tokens, got " + "expanded_idx_to_permuted_idx.size(0)=%d and num_tokens=%d", + expanded_idx_to_permuted_idx.size(0), num_tokens); + params.routing_scaling_factor = + routing_scaling_factor.has_value() ? 
static_cast(routing_scaling_factor.value()) + : 1.0f; // workspace: AR scratch space params.workspace = reinterpret_cast(workspace.data_ptr()); @@ -125,7 +156,33 @@ void trtllm_moe_finalize_allreduce_fusion( params.norm_out = norm_out.data_ptr(); params.residual_out = residual_out.data_ptr(); - auto status = moefinalize_allreduce_fusion_op(params, launch_with_pdl); + // Record norm_out dtype so kernels can specialize behavior if needed. + params.norm_out_dtype = encode_dlpack_dtype(norm_out.dtype()); + + // Dispatch on expert_scale_factor dtype for ScaleType. If none is provided, default to float + // and select NormOutT based on norm_out dtype. + cudaError_t status; + if (!expert_scale_factor.has_value()) { + auto norm_dtype = encode_dlpack_dtype(norm_out.dtype()); + if (norm_dtype == float8_e4m3fn_code) { + status = + moefinalize_allreduce_fusion_op(params, launch_with_pdl); + } else { + status = moefinalize_allreduce_fusion_op(params, launch_with_pdl); + } + } else { + DISPATCH_FLOATING_TYPES_FOR_SCALE( + expert_scale_factor.value().dtype(), ScaleType, [&]() { + auto norm_dtype = encode_dlpack_dtype(norm_out.dtype()); + if (norm_dtype == float8_e4m3fn_code) { + status = moefinalize_allreduce_fusion_op( + params, launch_with_pdl); + } else { + status = + moefinalize_allreduce_fusion_op(params, launch_with_pdl); + } + }); + } TVM_FFI_ICHECK(status == cudaSuccess) << "moefinalize_allreduce_fusion_op failed with error code " << cudaGetErrorString(status); }); diff --git a/flashinfer/comm/trtllm_ar.py b/flashinfer/comm/trtllm_ar.py index 33bb7ac9..ea792d96 100644 --- a/flashinfer/comm/trtllm_ar.py +++ b/flashinfer/comm/trtllm_ar.py @@ -363,6 +363,7 @@ def trtllm_moe_finalize_allreduce_fusion( eps: float, shared_expert_output: Optional[torch.Tensor], expert_scale_factor: Optional[torch.Tensor], + routing_scaling_factor: Optional[float], ) -> None: module.trtllm_moe_finalize_allreduce_fusion( allreduce_in, @@ -378,6 +379,7 @@ def trtllm_moe_finalize_allreduce_fusion( 
eps, shared_expert_output, expert_scale_factor, + routing_scaling_factor, ) return SimpleNamespace( @@ -1046,6 +1048,7 @@ def trtllm_moe_finalize_allreduce_fusion( eps: float, shared_expert_output: Optional[torch.Tensor], expert_scale_factor: Optional[torch.Tensor], + routing_scaling_factor: Optional[float] = None, ) -> None: """ Parameters: @@ -1062,6 +1065,7 @@ def trtllm_moe_finalize_allreduce_fusion( - eps: the epsilon value. - shared_expert_output: the shared expert output tensor. [token_num, hidden_dim] - expert_scale_factor: the expert scale factor tensor. [token_num, top_k] + - routing_scaling_factor: optional scalar multiplier applied to routing scores. """ required_lamport_comm_size = allreduce_in.numel() * 2 * world_size @@ -1086,4 +1090,5 @@ def trtllm_moe_finalize_allreduce_fusion( eps=eps, shared_expert_output=shared_expert_output, expert_scale_factor=expert_scale_factor, + routing_scaling_factor=routing_scaling_factor, ) diff --git a/flashinfer/decode.py b/flashinfer/decode.py index cc865ae5..101d2776 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -720,9 +720,7 @@ def __init__( if use_tensor_cores: self._jit_module = get_batch_prefill_jit_module( jit_args[0], - gen_customize_batch_prefill_module( - "fa2", *jit_args - ).build_and_load(), + gen_customize_batch_prefill_module("fa2", *jit_args).build_and_load(), ) else: self._jit_module = get_batch_decode_jit_module( @@ -735,9 +733,7 @@ def __init__( self._kv_layout = kv_layout self._float_workspace_buffer = float_workspace_buffer self.device = float_workspace_buffer.device - self._int_workspace_buffer = torch.empty( - (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device - ) + self._int_workspace_buffer = torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=self.device) self._pin_memory_int_workspace_buffer = torch.empty( (8 * 1024 * 1024,), dtype=torch.uint8, @@ -746,28 +742,18 @@ def __init__( ) self._kv_lens_buffer: Optional[torch.Tensor] = None if backend == "trtllm-gen": - 
self._kv_lens_buffer = torch.empty( - (32768,), dtype=torch.int32, device=self.device - ) + self._kv_lens_buffer = torch.empty((32768,), dtype=torch.int32, device=self.device) if use_cuda_graph: if not torch.is_tensor(paged_kv_indptr_buffer): - raise ValueError( - "paged_kv_indptr_buffer should be a torch.Tensor in cudagraph mode" - ) + raise ValueError("paged_kv_indptr_buffer should be a torch.Tensor in cudagraph mode") if not torch.is_tensor(paged_kv_indices_buffer): - raise ValueError( - "paged_kv_indices_buffer should be a torch.Tensor in cudagraph mode" - ) + raise ValueError("paged_kv_indices_buffer should be a torch.Tensor in cudagraph mode") if not torch.is_tensor(paged_kv_last_page_len_buffer): - raise ValueError( - "paged_kv_last_page_len_buffer should be a torch.Tensor in cudagraph mode" - ) + raise ValueError("paged_kv_last_page_len_buffer should be a torch.Tensor in cudagraph mode") self._fixed_batch_size = len(paged_kv_last_page_len_buffer) if len(paged_kv_indptr_buffer) != self._fixed_batch_size + 1: - raise ValueError( - "The size of paged_kv_indptr_buffer should be batch_size + 1" - ) + raise ValueError("The size of paged_kv_indptr_buffer should be batch_size + 1") else: self._fixed_batch_size = 0 @@ -795,9 +781,7 @@ def use_tensor_cores(self) -> bool: def is_cuda_graph_enabled(self) -> bool: return self._use_cuda_graph - def reset_workspace_buffer( - self, float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor - ) -> None: + def reset_workspace_buffer(self, float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor) -> None: r"""Reset the workspace buffer. Parameters @@ -910,10 +894,7 @@ def plan( The :meth:`plan` method cannot be used in Cuda Graph or in ``torch.compile``. 
""" - self._workspace_size = ( - self._float_workspace_buffer.numel() - * self._float_workspace_buffer.element_size() - ) + self._workspace_size = self._float_workspace_buffer.numel() * self._float_workspace_buffer.element_size() batch_size = len(last_page_len) if logits_soft_cap is None: @@ -929,29 +910,17 @@ def plan( ) ) if len(indices) > len(self._paged_kv_indices_buf): - raise ValueError( - "The size of indices should be less than or equal to the allocated buffer" - ) + raise ValueError("The size of indices should be less than or equal to the allocated buffer") self._paged_kv_indptr_buf.copy_(indptr, non_blocking=non_blocking) - self._paged_kv_last_page_len_buf.copy_( - last_page_len, non_blocking=non_blocking - ) + self._paged_kv_last_page_len_buf.copy_(last_page_len, non_blocking=non_blocking) self._paged_kv_indices_buf[: len(indices)].copy_( indices, non_blocking=(indices.device == self.device) and non_blocking ) else: - self._paged_kv_indptr_buf = indptr.to( - self.device, non_blocking=non_blocking - ) - self._paged_kv_indices_buf = indices.to( - self.device, non_blocking=non_blocking - ) - self._paged_kv_last_page_len_buf = last_page_len.to( - self.device, non_blocking=non_blocking - ) - self._qo_indptr_buf = qo_indptr_host.to( - self.device, non_blocking=non_blocking - ) + self._paged_kv_indptr_buf = indptr.to(self.device, non_blocking=non_blocking) + self._paged_kv_indices_buf = indices.to(self.device, non_blocking=non_blocking) + self._paged_kv_last_page_len_buf = last_page_len.to(self.device, non_blocking=non_blocking) + self._qo_indptr_buf = qo_indptr_host.to(self.device, non_blocking=non_blocking) indptr_host = indptr.to("cpu") last_page_len_host = last_page_len.to("cpu") @@ -967,9 +936,7 @@ def plan( kv_data_type = q_data_type kv_data_type = canonicalize_torch_dtype(kv_data_type) if fixed_split_size is not None and not self.use_tensor_cores: - raise ValueError( - "fixed_split_size is only supported by tensor core decode for now." 
- ) + raise ValueError("fixed_split_size is only supported by tensor core decode for now.") if fixed_split_size is None: fixed_split_size = -1 @@ -988,14 +955,9 @@ def plan( if self._backend == "trtllm-gen": assert logits_soft_cap == 0.0 self._max_kv_len = max(kv_lens_arr_host).item() - self._kv_lens_buffer[: len(kv_lens_arr_host)].copy_( - kv_lens_arr_host, non_blocking=non_blocking - ) + self._kv_lens_buffer[: len(kv_lens_arr_host)].copy_(kv_lens_arr_host, non_blocking=non_blocking) if self._block_tables is None: - blocks_per_seq = [ - (seq_len + page_size - 1) // page_size - for seq_len in kv_lens_arr_host - ] + blocks_per_seq = [(seq_len + page_size - 1) // page_size for seq_len in kv_lens_arr_host] max_num_blocks_per_seq = max(blocks_per_seq) self._block_tables = torch.zeros( (batch_size, max_num_blocks_per_seq), @@ -1005,11 +967,9 @@ def plan( block_id = indptr[0] for i in range(batch_size): num_blocks_needed = blocks_per_seq[i] - self._block_tables[i, :num_blocks_needed] = ( - self._paged_kv_indices_buf[ - block_id : block_id + num_blocks_needed - ] - ) + self._block_tables[i, :num_blocks_needed] = self._paged_kv_indices_buf[ + block_id : block_id + num_blocks_needed + ] block_id += num_blocks_needed self._cached_module = get_trtllm_gen_decode_module( q_data_type, @@ -1127,9 +1087,7 @@ def forward( self._sm_scale = sm_scale self._rope_scale = rope_scale self._rope_theta = rope_theta - return self.run( - q, paged_kv_cache, q_scale=q_scale, k_scale=k_scale, v_scale=v_scale - ) + return self.run(q, paged_kv_cache, q_scale=q_scale, k_scale=k_scale, v_scale=v_scale) @overload def run( @@ -1235,9 +1193,7 @@ def run( page_size = k_cache.shape[1] else: page_size = k_cache.shape[2] - _check_cached_qkv_data_type( - q, k_cache, self._cached_q_data_type, self._cached_kv_data_type - ) + _check_cached_qkv_data_type(q, k_cache, self._cached_q_data_type, self._cached_kv_data_type) # Convert NHD layout to HND for trtllm-gen backend if self._backend == "trtllm-gen" and 
self._kv_layout == "NHD": @@ -1272,13 +1228,9 @@ def run( if return_lse: if lse is None: - lse = torch.empty( - (q.size(0), q.size(1)), dtype=torch.float32, device=q.device - ) + lse = torch.empty((q.size(0), q.size(1)), dtype=torch.float32, device=q.device) else: - check_shape_dtype_device( - lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse" - ) + check_shape_dtype_device(lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse") if out is None: out = torch.empty_like(q) @@ -1543,9 +1495,7 @@ def __init__( """ self._float_workspace_buffer = float_workspace_buffer self.device = float_workspace_buffer.device - self._int_workspace_buffer = torch.empty( - (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device - ) + self._int_workspace_buffer = torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=self.device) self._pin_memory_int_workspace_buffer = torch.empty( (8 * 1024 * 1024,), dtype=torch.uint8, @@ -1555,22 +1505,14 @@ def __init__( if use_cuda_graph: if not torch.is_tensor(paged_kv_indptr_buffer): - raise ValueError( - "paged_kv_indptr_buffer should be a torch.Tensor in cudagraph mode" - ) + raise ValueError("paged_kv_indptr_buffer should be a torch.Tensor in cudagraph mode") if not torch.is_tensor(paged_kv_indices_buffer): - raise ValueError( - "paged_kv_indices_buffer should be a torch.Tensor in cudagraph mode" - ) + raise ValueError("paged_kv_indices_buffer should be a torch.Tensor in cudagraph mode") if not torch.is_tensor(paged_kv_last_page_len_buffer): - raise ValueError( - "paged_kv_last_page_len_buffer should be a torch.Tensor in cudagraph mode" - ) + raise ValueError("paged_kv_last_page_len_buffer should be a torch.Tensor in cudagraph mode") self._fixed_batch_size = len(paged_kv_last_page_len_buffer) if len(paged_kv_indptr_buffer) != self._fixed_batch_size + 1: - raise ValueError( - "The size of paged_kv_indptr_buffer should be batch_size + 1" - ) + raise ValueError("The size of paged_kv_indptr_buffer should be batch_size + 1") else: 
self._fixed_batch_size = 0 @@ -1588,9 +1530,7 @@ def is_cuda_graph_enabled(self) -> bool: def use_tensor_cores(self) -> bool: return self._use_tensor_cores - def reset_workspace_buffer( - self, float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor - ) -> None: + def reset_workspace_buffer(self, float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor) -> None: r"""Reset the workspace buffer. Parameters @@ -1682,9 +1622,7 @@ def plan( ) ) if len(indices) > len(self._paged_kv_indices_buf): - raise ValueError( - "The size of indices should be less than or equal to the allocated buffer" - ) + raise ValueError("The size of indices should be less than or equal to the allocated buffer") self._paged_kv_indptr_buf.copy_(indptr) self._paged_kv_indices_buf[: len(indices)] = indices self._paged_kv_last_page_len_buf.copy_(last_page_len) @@ -1787,8 +1725,7 @@ def run( device_arch = major * 10 + minor if device_arch != 80: raise GPUArchitectureError( - f"MLA decode kernel is not supported on this GPU (SM{device_arch}). " - "Supported architecture: SM80." + f"MLA decode kernel is not supported on this GPU (SM{device_arch}). " "Supported architecture: SM80." 
) window_left = self._window_left logits_soft_cap = self._logits_soft_cap @@ -1810,9 +1747,7 @@ def run( if out is None: out = torch.empty_like(q_nope, device=device) else: - check_shape_dtype_device( - out, q_nope.shape, q_nope.dtype, q_nope.device, "out" - ) + check_shape_dtype_device(out, q_nope.shape, q_nope.dtype, q_nope.device, "out") if return_lse: if lse is None: @@ -1916,6 +1851,9 @@ def _paged_run( enable_pdl, workspace_size, sinks, + None, + None, + None, ) return out @@ -2077,6 +2015,9 @@ def trtllm_batch_decode_with_kv_cache( q_len_per_req: Optional[int] = 1, o_scale: Optional[float] = 1.0, mask: Optional[torch.Tensor] = None, + kv_cache_scales: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + return_lse: bool = False, + lse: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, FP4Tensor]: """ Parameters @@ -2163,17 +2104,13 @@ def trtllm_batch_decode_with_kv_cache( if kv_cache.shape[1] == 1: k_cache, v_cache = kv_cache, kv_cache else: - assert kv_cache.shape[1] == 2, ( - "When kv_cache is a single tensor, the second dimension must be 1 or 2" - ) + assert kv_cache.shape[1] == 2, "When kv_cache is a single tensor, the second dimension must be 1 or 2" # NOTE(Zihao): unbind transforms [num_pages, 2, ...] to ([num_pages, ...], [num_pages, ...]) # it doesn't change underlying storage k_cache, v_cache = kv_cache.unbind(dim=1) if backend == "auto": - backend = ( - "trtllm-gen" if get_compute_capability(query.device)[0] == 10 else "xqa" - ) + backend = "trtllm-gen" if get_compute_capability(query.device)[0] == 10 else "xqa" if backend == "xqa": # xqa backend doesn't support nvfp4 output @@ -2218,9 +2155,7 @@ def trtllm_batch_decode_with_kv_cache( sm_count = get_device_sm_count(query.device) if out_dtype == "nvfp4" or (out_dtype is None and isinstance(out, FP4Tensor)): - assert query.dtype == torch.float8_e4m3fn, ( - "query must be fp8 when out_dtype is nvfp4." - ) + assert query.dtype == torch.float8_e4m3fn, "query must be fp8 when out_dtype is nvfp4." 
assert o_sf_scale is not None assert o_sf_vec_size in [None, 16], "only o_sf_vec_size = 16 is supported" o_sf_vec_size = o_sf_vec_size or 16 @@ -2242,9 +2177,7 @@ def trtllm_batch_decode_with_kv_cache( round_up(query.shape[0], 128), round_up(query.shape[1] * query.shape[2] // o_sf_vec_size, 4), ) - out_scale_factor = torch.empty( - fp4_out_scale_shape, dtype=torch.float8_e4m3fn, device=query.device - ) + out_scale_factor = torch.empty(fp4_out_scale_shape, dtype=torch.float8_e4m3fn, device=query.device) o_sf_start_index = 0 out = torch.empty(fp4_out_shape, dtype=torch.uint8, device=query.device) else: @@ -2254,9 +2187,7 @@ def trtllm_batch_decode_with_kv_cache( assert isinstance(out, torch.Tensor) # Use uint8 as the container dtype to compliant with next fp4 gemm. - check_shape_dtype_device( - out, fp4_out_shape, torch.uint8, query.device, "out" - ) + check_shape_dtype_device(out, fp4_out_shape, torch.uint8, query.device, "out") check_shape_dtype_device( out_scale_factor, @@ -2267,10 +2198,7 @@ def trtllm_batch_decode_with_kv_cache( ) # Check o_sf_start_index is valid - if ( - o_sf_start_index < 0 - or o_sf_start_index + out.shape[0] > out_scale_factor.shape[0] - ): + if o_sf_start_index < 0 or o_sf_start_index + out.shape[0] > out_scale_factor.shape[0]: raise ValueError( f"o_sf_start_index is out of the valid range of out_scale_factor. 
" f"o_sf_start_index={o_sf_start_index}, out.shape[0]={out.shape[0]}, " @@ -2297,6 +2225,19 @@ def trtllm_batch_decode_with_kv_cache( if isinstance(bmm2_scale, torch.Tensor): assert bmm2_scale.dtype == torch.float32 + k_cache_scale = None + v_cache_scale = None + if kv_cache_scales is not None: + k_cache_scale, v_cache_scale = kv_cache_scales + + if return_lse and lse is None: + lse = torch.empty( + query.shape[0], + query.shape[1], + device=query.device, + dtype=torch.float32, + ) + run_func( out, out_scale_factor, @@ -2323,13 +2264,16 @@ def trtllm_batch_decode_with_kv_cache( enable_pdl, workspace_buffer.numel() * workspace_buffer.element_size(), sinks, + k_cache_scale, + v_cache_scale, + lse, ) - return ( - out - if out_dtype != "nvfp4" - else FP4Tensor(out, out_scale_factor, o_sf_start_index, query.shape) - ) + out = out if out_dtype != "nvfp4" else FP4Tensor(out, out_scale_factor, o_sf_start_index, query.shape) + if return_lse: + return out, lse + else: + return out else: raise KeyError(f"Backend {backend} not supported") @@ -2420,9 +2364,7 @@ def xqa_batch_decode_with_kv_cache( if kv_cache.shape[1] == 1: k_cache, v_cache = kv_cache, kv_cache else: - assert kv_cache.shape[1] == 2, ( - "When kv_cache is a single tensor, the second dimension must be 1 or 2" - ) + assert kv_cache.shape[1] == 2, "When kv_cache is a single tensor, the second dimension must be 1 or 2" # NOTE(Zihao): unbind transforms [num_pages, 2, ...] 
to ([num_pages, ...], [num_pages, ...]) # it doesn't change underlying storage k_cache, v_cache = kv_cache.unbind(dim=1) @@ -2540,47 +2482,31 @@ def fast_decode_plan( if batch_size != self._fixed_batch_size: raise ValueError( "The batch size should be fixed in cudagraph mode, the runtime batch size {} " - " mismatches the batch size set during initialization {}".format( - batch_size, self._fixed_batch_size - ) + " mismatches the batch size set during initialization {}".format(batch_size, self._fixed_batch_size) ) if len(indices) > len(self._paged_kv_indices_buf): - raise ValueError( - "The size of indices should be less than or equal to the allocated buffer" - ) + raise ValueError("The size of indices should be less than or equal to the allocated buffer") else: self._paged_kv_indptr_buf = indptr self._paged_kv_indices_buf = indices self._paged_kv_last_page_len_buf = last_page_len if self.use_tensor_cores: - self._qo_indptr_buf = qo_indptr_host.to( - self.device, non_blocking=non_blocking - ) + self._qo_indptr_buf = qo_indptr_host.to(self.device, non_blocking=non_blocking) # Create empty tensors for dtype info if needed empty_q_data = torch.empty( 0, - dtype=( - getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type - ), + dtype=(getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type), device=self.device, ) empty_kv_cache = torch.empty( 0, - dtype=( - getattr(torch, kv_data_type) - if isinstance(kv_data_type, str) - else kv_data_type - ), + dtype=(getattr(torch, kv_data_type) if isinstance(kv_data_type, str) else kv_data_type), device=self.device, ) - indptr_host = ( - global_override_indptr_cpu - if global_override_indptr_cpu is not None - else indptr.cpu() - ) + indptr_host = global_override_indptr_cpu if global_override_indptr_cpu is not None else indptr.cpu() with torch.cuda.device(self.device): if self.use_tensor_cores: @@ -2588,9 +2514,7 @@ def fast_decode_plan( if page_size == 1: # When page size is 1, last_page_len 
is always 1. # Directly construct the host tensor rather than executing a device-to-host copy. - last_page_len_host = torch.ones( - (batch_size,), dtype=torch.int32, device="cpu" - ) + last_page_len_host = torch.ones((batch_size,), dtype=torch.int32, device="cpu") else: last_page_len_host = last_page_len.cpu() diff --git a/flashinfer/mla.py b/flashinfer/mla.py index 3b59ad99..2429433e 100644 --- a/flashinfer/mla.py +++ b/flashinfer/mla.py @@ -48,19 +48,13 @@ def _check_cutlass_shape(q_nope_pe, ckv_kpe_cache, kv_len, page_table): if H != 128: raise ValueError(f"Expected 128 heads for q_nope_pe, got {H}") if D_q != D_ckv or D_q != 576: - raise ValueError( - f"Expected head dim 576 for q_nope_pe and ckv_kpe_cache, got {D_q} and {D_ckv}" - ) + raise ValueError(f"Expected head dim 576 for q_nope_pe and ckv_kpe_cache, got {D_q} and {D_ckv}") B_block_table, block_num = page_table.shape block_size = ckv_kpe_cache.shape[1] if B_q != B_block_table: - raise ValueError( - f"Expected batch size {B_q} for q_nope_pe and block_table, got {B_q} and {B_block_table}" - ) + raise ValueError(f"Expected batch size {B_q} for q_nope_pe and block_table, got {B_q} and {B_block_table}") if block_num % (128 / block_size) != 0: - raise ValueError( - f"Expected block_num % (128 / block_size) == 0, got {block_num=} and {block_size=}" - ) + raise ValueError(f"Expected block_num % (128 / block_size) == 0, got {block_num=} and {block_size=}") def _check_trtllm_gen_mla_shape( @@ -90,27 +84,19 @@ def _check_trtllm_gen_mla_shape( # raise ValueError(f"Expected 128 heads for query, got {H}") # todo(Yingyi): should we check num_heads == 128? Is this deepseek only? 
if D_q != D_ckv or D_q != 576: - raise ValueError( - f"Expected head dim 576 for query and kv_cache, got {D_q} and {D_ckv}" - ) + raise ValueError(f"Expected head dim 576 for query and kv_cache, got {D_q} and {D_ckv}") if sparse_mla_top_k > 0: page_table_shape = page_table.shape if page_table_shape != (B_q, Q_len, sparse_mla_top_k): - raise ValueError( - f"Expected page_table.shape == (B_q, Q_len, sparse_mla_top_k), got {page_table_shape}" - ) + raise ValueError(f"Expected page_table.shape == (B_q, Q_len, sparse_mla_top_k), got {page_table_shape}") else: B_block_table, block_num = page_table.shape block_size = page_size if B_q != B_block_table: - raise ValueError( - f"Expected batch size {B_q} for query and block_table, got {B_q} and {B_block_table}" - ) + raise ValueError(f"Expected batch size {B_q} for query and block_table, got {B_q} and {B_block_table}") if block_num % (128 / block_size) != 0: - raise ValueError( - f"Expected block_num % (128 / block_size) == 0, got {block_num=} and {block_size=}" - ) + raise ValueError(f"Expected block_num % (128 / block_size) == 0, got {block_num=} and {block_size=}") @functools.cache @@ -249,9 +235,7 @@ def __init__( self._backend = backend return - self._int_workspace_buffer = torch.empty( - (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device - ) + self._int_workspace_buffer = torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=self.device) self._pin_memory_int_workspace_buffer = torch.empty( self._int_workspace_buffer.shape, dtype=self._int_workspace_buffer.dtype, @@ -437,16 +421,12 @@ def run( if return_lse: raise ValueError("return_lse does not support cutlass backend for now.") if profiler_buffer is not None: - raise ValueError( - "profiler_buffer does not support cutlass backend for now." 
- ) + raise ValueError("profiler_buffer does not support cutlass backend for now.") self._cached_module = get_mla_module() if out is None: out = torch.empty_like(q_nope) else: - check_shape_dtype_device( - out, q_nope.shape, q_nope.dtype, q_nope.device, "out" - ) + check_shape_dtype_device(out, q_nope.shape, q_nope.dtype, q_nope.device, "out") q_nope_pe = torch.cat([q_nope, q_pe], dim=-1) ckv_kpe_cache = torch.cat([ckv_cache, kpe_cache], dim=-1) _check_cutlass_shape(q_nope_pe, ckv_kpe_cache, kv_len, page_table) @@ -464,9 +444,7 @@ def run( if profiler_buffer is None: if self._use_profiler: - raise ValueError( - "Profiler is enabled, profiler_buffer must be provided" - ) + raise ValueError("Profiler is enabled, profiler_buffer must be provided") num_heads = q_nope.shape[1] page_size = self._page_size sm_scale = self._sm_scale @@ -476,17 +454,13 @@ def run( if out is None: out = torch.empty_like(q_nope) else: - check_shape_dtype_device( - out, q_nope.shape, q_nope.dtype, q_nope.device, "out" - ) + check_shape_dtype_device(out, q_nope.shape, q_nope.dtype, q_nope.device, "out") if return_lse: if lse is None: lse = torch.empty(q_nope.shape[:2], dtype=torch.float32, device=device) else: - check_shape_dtype_device( - lse, q_nope.shape[:2], torch.float32, q_nope.device, "lse" - ) + check_shape_dtype_device(lse, q_nope.shape[:2], torch.float32, q_nope.device, "lse") profiler_args = (profiler_buffer,) if self._use_profiler else () self._cached_module.run( self._float_workspace_buffer, @@ -526,6 +500,8 @@ def trtllm_batch_decode_with_kv_cache_mla( bmm1_scale: Union[float, torch.Tensor] = 1.0, bmm2_scale: Union[float, torch.Tensor] = 1.0, sinks: Optional[List[torch.Tensor]] = None, + return_lse: bool = False, + lse: Optional[torch.Tensor] = None, enable_pdl: bool = None, backend: str = "auto", ) -> torch.Tensor: @@ -574,9 +550,7 @@ def trtllm_batch_decode_with_kv_cache_mla( When both are provided, the dynamic scale factor tensors will be used. 
""" if backend == "auto": - backend = ( - "trtllm-gen" if get_compute_capability(query.device)[0] == 10 else "xqa" - ) + backend = "trtllm-gen" if get_compute_capability(query.device)[0] == 10 else "xqa" if isinstance(bmm1_scale, torch.Tensor): assert bmm1_scale.dtype == torch.float32 bmm1_scale = bmm1_scale * log2e @@ -594,9 +568,7 @@ def trtllm_batch_decode_with_kv_cache_mla( if sinks is not None: raise ValueError("XQA MLA does not support sinks") if query.size(1) != 1: - raise ValueError( - f"XQA MLA only supports q_len_per_request == 1, got {query.size(1)}" - ) + raise ValueError(f"XQA MLA only supports q_len_per_request == 1, got {query.size(1)}") return xqa_batch_decode_with_kv_cache_mla( query, kv_cache, @@ -614,16 +586,12 @@ def trtllm_batch_decode_with_kv_cache_mla( enable_pdl, ) elif backend == "trtllm-gen": - enable_pdl = ( - device_support_pdl(query.device) if enable_pdl is None else enable_pdl - ) + enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl run_func = get_trtllm_gen_fmha_module().trtllm_paged_attention_decode sm_count = get_device_sm_count(query.device) block_size = kv_cache.size(-2) - if ( - block_size != 32 and block_size != 64 - ): # todo(Yingyi): add support for more block sizes? + if block_size != 32 and block_size != 64: # todo(Yingyi): add support for more block sizes? 
raise ValueError(f"Supported block_size are 32 and 64, got {block_size}") _check_trtllm_gen_mla_shape( @@ -641,15 +609,23 @@ def trtllm_batch_decode_with_kv_cache_mla( out_shape = query.shape[:-1] + (kv_lora_rank,) out = torch.empty(out_shape, dtype=torch.bfloat16, device=query.device) else: - batch_size, _, num_q_heads, _ = query.shape check_shape_dtype_device( out, - [batch_size, num_q_heads, kv_lora_rank], + [*query.shape[:-1], kv_lora_rank], torch.bfloat16, query.device, "out", ) + if return_lse and lse is None: + lse = torch.empty( + query.shape[0], + query.shape[1], + query.shape[2], + device=query.device, + dtype=torch.float32, + ) + run_func( out, None, # fp4 output not supported in wrapper api yet. @@ -671,9 +647,15 @@ def trtllm_batch_decode_with_kv_cache_mla( enable_pdl, workspace_buffer.numel() * workspace_buffer.element_size(), sinks, + None, + None, + lse, ) - return out + if return_lse: + return out, lse + else: + return out else: raise ValueError(f"Backend {backend} not supported") @@ -732,13 +714,9 @@ def xqa_batch_decode_with_kv_cache_mla( block_size = kv_cache.size(-2) q_len_per_request = query.size(1) if q_len_per_request != 1: - raise ValueError( - f"XQA MLA only supports q_len_per_request == 1, got {q_len_per_request}" - ) + raise ValueError(f"XQA MLA only supports q_len_per_request == 1, got {q_len_per_request}") if query.dtype != torch.float8_e4m3fn or kv_cache.dtype != torch.float8_e4m3fn: - raise ValueError( - f"XQA MLA only supports fp8 tensor core operation, got {query.dtype} and {kv_cache.dtype}" - ) + raise ValueError(f"XQA MLA only supports fp8 tensor core operation, got {query.dtype} and {kv_cache.dtype}") if sinks is not None: raise ValueError("XQA MLA does not support sinks") diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index a6e32a67..3cf9b9c2 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -3439,6 +3439,9 @@ def trtllm_batch_context_with_kv_cache( kv_layout: str = "HND", enable_pdl: Optional[bool] 
= None, sinks: Optional[List[torch.Tensor]] = None, + kv_cache_scales: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + return_lse: bool = False, + lse: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, FP4Tensor]: """ Parameters @@ -3600,6 +3603,20 @@ def trtllm_batch_context_with_kv_cache( bmm1_scale = bmm1_scale * log2e if isinstance(bmm2_scale, torch.Tensor): assert bmm2_scale.dtype == torch.float32 + + k_cache_scale = None + v_cache_scale = None + if kv_cache_scales is not None: + k_cache_scale, v_cache_scale = kv_cache_scales + + if return_lse and lse is None: + lse = torch.empty( + query.shape[0], + query.shape[1], + device=query.device, + dtype=torch.float32, + ) + workspace_size = workspace_buffer.numel() * workspace_buffer.element_size() run_func( out, @@ -3625,12 +3642,19 @@ def trtllm_batch_context_with_kv_cache( enable_pdl, workspace_size, sinks, + k_cache_scale, + v_cache_scale, + lse ) - return ( + out = ( out if out_dtype != "nvfp4" else FP4Tensor(out, out_scale_factor, o_sf_start_index, query.shape) ) + if return_lse: + return out, lse + else: + return out @flashinfer_api @@ -3710,7 +3734,3 @@ def fmha_v2_prefill_deepseek( is_e4m3, is_bf16_output, ) - if return_lse: - return out, lse - else: - return out diff --git a/include/flashinfer/comm/trtllm_moe_allreduce_fusion.cuh b/include/flashinfer/comm/trtllm_moe_allreduce_fusion.cuh index 143e25de..2976261b 100644 --- a/include/flashinfer/comm/trtllm_moe_allreduce_fusion.cuh +++ b/include/flashinfer/comm/trtllm_moe_allreduce_fusion.cuh @@ -665,6 +665,8 @@ struct AllReduceFusionParams { void* residual_in; void* residual_out; void* norm_out; + // Encoded DLPack dtype for norm_out (or -1 if unspecified). 
+ int64_t norm_out_dtype = -1; void* quant_out; void* scale_out; void* rms_gamma; @@ -700,6 +702,8 @@ struct MoeFinalizeAllReduceFusionParams : public AllReduceFusionParams { // Refer to kernel implementation on layout of those params // number of active experts on current device int top_k; + // Optional scalar multiplier applied to routing scores before combine. + float routing_scaling_factor = 1.0f; // [num_tokens, top_k] void* expert_scale_factor = nullptr; void* shared_expert_output = nullptr; @@ -716,7 +720,9 @@ struct LamportComm { clear_ptr = &reinterpret_cast(workspace[NRanks * 3])[4]; flag_value = *flag_ptr; int comm_size = reinterpret_cast(workspace[NRanks * 3])[3]; + // printf("comm_size: %d\n", comm_size); clear_size = *clear_ptr; + // printf("clear_size: %d\n", clear_size); int data_offset = flag_value % 3; int clear_offset = (flag_value + 2) % 3; for (int r = 0; r < NRanks; ++r) { @@ -800,7 +806,7 @@ __device__ __forceinline__ vec_t rms_norm(vec_t const& } template + typename NormOutT = T, uint32_t VEC_SIZE> __device__ __forceinline__ void fused_op(vec_t const& val, int access_id, int token_id, int access_id_in_token, AllReduceFusionParams& params) { if constexpr (AllReduceOut) { @@ -818,7 +824,13 @@ __device__ __forceinline__ void fused_op(vec_t const& val, int acce vec_t norm_val; norm_val = rms_norm(residual_val, gamma_val, params.rms_eps, params.hidden_dim); if constexpr (NormOut) { - norm_val.store(reinterpret_cast(params.norm_out) + access_id * VEC_SIZE); + // Allow norm_out to use a different dtype (e.g. FP8) via NormOutT. 
+ auto norm_out_ptr = + reinterpret_cast(params.norm_out) + access_id * VEC_SIZE; +#pragma unroll + for (int i = 0; i < VEC_SIZE; ++i) { + norm_out_ptr[i] = static_cast(norm_val[i]); + } } #if CUDA_VERSION >= 12080 if constexpr (QuantOut) { @@ -1067,8 +1079,8 @@ __global__ void moereduce_allreduce_fusion_kernel_oneshot_lamport( } // * Fuse - fused_op(sum_val, idx, tidx, - access_id_in_token, params); + fused_op( + sum_val, idx, tidx, access_id_in_token, params); } comm.update(params.size * NRanks); cudaTriggerProgrammaticLaunchCompletion(); @@ -1231,130 +1243,120 @@ cudaError_t moereduction_allreduce_fusion_op(MoeReductionAllReduceFusionParams + typename NormOutT = T, typename ScaleType = T> __global__ void moefinalize_allreduce_fusion_kernel_oneshot_lamport( MoeFinalizeAllReduceFusionParams params) { -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - namespace cg = cooperative_groups; - cg::cluster_group cluster = cg::this_cluster(); - cg::grid_group grid = cg::this_grid(); - - static constexpr int VEC_SIZE = details::kBytesPerAccess / sizeof(T); + int tid = threadIdx.x; + int token_id = blockIdx.x; // one block per token + static constexpr int VEC_SIZE = 16 / sizeof(T); - // Each token is handled by one cluster - // which token is handled by current cluster - int token_id = grid.cluster_rank(); - // total number of token - int num_token = params.size / params.hidden_dim; - // Each thread handle VEC_SIZE num elem in token. 
Total cluster.num_threads() to handle one - // token For current token, which VEC_SIZE is handled by current thread (in unit of - // VEC_SIZE) - int access_id_in_token = cluster.thread_rank(); - // Across all token, which VEC_SIZE is handled by current thread (in unit of - // VEC_SIZE) - int access_id = token_id * params.hidden_dim / VEC_SIZE + access_id_in_token; - // Persistent kernel - // stride to next token handled by current cta - int token_stride = grid.num_clusters(); - // stride in unit of VEC_SIZE - int access_stride = token_stride * params.hidden_dim / VEC_SIZE; - // Total number of access in unit of VEC_SIZE to handle (token_num * hidden_dim) - // This is within one rank - int tot_access = params.size / VEC_SIZE; - vec_t clear_vec; - clear_vec.fill(neg_zero_v); + extern __shared__ __align__(16) uint8_t smem_raw[]; + T* smem_buffer = reinterpret_cast(smem_raw); cudaGridDependencySynchronize(); LamportComm comm(params.workspace, params.rank); - int clear_access = comm.clear_size / VEC_SIZE; - - // * MoE related - int threadid_in_cluster = cluster.thread_rank(); - // Start Offset within one token's hidden_size of element - // Current thread handle token[thread_offset_within_token : thread_offset_within_token + - // VEC_SIZE] - int thread_offset_within_token = threadid_in_cluster * VEC_SIZE; + int num_vecs_per_token = params.hidden_dim / VEC_SIZE; + size_t copy_bytes = params.hidden_dim * sizeof(T); int top_k = params.top_k; bool use_scale_factor = params.expert_scale_factor != nullptr; - // Persistent Kernel - // Each cluster iterate through all token it need to handle - for (int token_id = grid.cluster_rank(); token_id < num_token; token_id += grid.num_clusters()) { - if (thread_offset_within_token >= params.hidden_dim) { - break; - } + vec_t clear_vec; + clear_vec.fill(neg_zero_v); - // * MoE finalize + // Compute MoE finalize for this token directly into shared memory that will be + // TMA-copied to peers. 
+ for (int idx = tid; idx < num_vecs_per_token; idx += blockDim.x) { vec_t accumulator; accumulator.fill(0.f); - for (int k = 0; k < top_k; k++) { - int const expanded_idx = token_id * top_k + k; - int32_t const permuted_idx = params.expanded_idx_to_permuted_idx[expanded_idx]; - + // Accumulate selected experts + for (int k = 0; k < top_k; ++k) { + int expanded_idx = token_id * top_k + k; + int32_t permuted_idx = params.expanded_idx_to_permuted_idx[expanded_idx]; if (permuted_idx == -1) continue; - int thread_offset_across_token = - permuted_idx * params.hidden_dim + thread_offset_within_token; - float block_scale = 1.0; + int thread_offset_across_token = permuted_idx * params.hidden_dim + idx * VEC_SIZE; + float block_scale = params.routing_scaling_factor; if (use_scale_factor) { - block_scale = - static_cast(static_cast(params.expert_scale_factor)[expanded_idx]); + block_scale = static_cast( + static_cast(params.expert_scale_factor)[expanded_idx]) * + params.routing_scaling_factor; } vec_t permuted_data; permuted_data.load(reinterpret_cast(params.allreduce_in) + thread_offset_across_token); - // * acc += scale(data) #pragma unroll for (int i = 0; i < VEC_SIZE; ++i) { - // assume computation is done in ScaleType accumulator[i] += static_cast(static_cast(permuted_data[i]) * block_scale); } } - // * Add shared expert output + // Add shared expert output if provided if (params.shared_expert_output) { - // * Load shared expert output - int thread_offset_across_token = token_id * params.hidden_dim + thread_offset_within_token; + int thread_offset_across_token = token_id * params.hidden_dim + idx * VEC_SIZE; vec_t shared_expert_output; shared_expert_output.load(reinterpret_cast(params.shared_expert_output) + thread_offset_across_token); -#pragma unroll accumulator = vec_add(accumulator, shared_expert_output); } - // * AR Store - int idx = token_id * params.hidden_dim / VEC_SIZE + access_id_in_token; remove_neg_zero(accumulator); + accumulator.store(smem_buffer + idx * 
VEC_SIZE); + } + __syncthreads(); -#pragma unroll - for (int r = 0; r < NRanks; ++r) { - // STG.128 to remote rank - int offset = (params.rank * tot_access + idx) * VEC_SIZE; - accumulator.store_global_volatile(reinterpret_cast(comm.data_bufs[r]) + offset); - } + if (tid < NRanks) { + int r = tid; + size_t dst_offset_elems = + (params.rank * (params.size / VEC_SIZE) + token_id * num_vecs_per_token) * VEC_SIZE; + void* dst_ptr = reinterpret_cast(comm.data_bufs[r]) + dst_offset_elems; + void* src_ptr = smem_buffer; + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + asm volatile( + "cp.async.bulk.global.shared::cta.bulk_group [%0], [%1], %2;" + : + : "l"(__cvta_generic_to_global(dst_ptr)), + "r"(static_cast(__cvta_generic_to_shared(src_ptr))), + "r"(static_cast(copy_bytes)) + : "memory"); +#endif + // TODO - do a fallback for architectures before hopper } - // * Clear previous buffer - for (int idx = access_id; idx < clear_access; idx += access_stride) { - clear_vec.store(reinterpret_cast(comm.clear_buf) + idx * VEC_SIZE); +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + __syncthreads(); + if (tid == 0) { + asm volatile("cp.async.bulk.commit_group;"); + asm volatile("cp.async.bulk.wait_group 0;"); } +#endif + __syncthreads(); - // * AR Load + Fusion - for (int idx = access_id, tidx = token_id; idx < tot_access; - idx += access_stride, tidx += token_stride) { - // * AR Load + { + // Clear the entire previous Lamport buffer (size given by comm.clear_size), + // treating it as a 1D array of VEC_SIZE-wide vectors.
+ int clear_access = comm.clear_size / VEC_SIZE; + int global_thread_id = blockIdx.x * blockDim.x + tid; + int total_threads = gridDim.x * blockDim.x; + + for (int idx = global_thread_id; idx < clear_access; idx += total_threads) { + clear_vec.store(reinterpret_cast(comm.clear_buf) + idx * VEC_SIZE); + } + } + + for (int idx = tid; idx < num_vecs_per_token; idx += blockDim.x) { vec_t vals[NRanks]; bool done = false; while (!done) { done = true; #pragma unroll for (int r = 0; r < NRanks; ++r) { - // LDG.128 from local rank vals[r].load_global_volatile(reinterpret_cast(comm.data_bufs[r]) + - (r * tot_access + idx) * VEC_SIZE); + (r * (params.size / VEC_SIZE) + token_id * num_vecs_per_token + idx) * + VEC_SIZE); done &= !has_neg_zero(vals[r]); } } @@ -1364,29 +1366,31 @@ __global__ void moefinalize_allreduce_fusion_kernel_oneshot_lamport( sum_val = vec_add(sum_val, vals[r]); } - // * Fuse: AllReduceOut is always false in finalize_moe_allreduce - fused_op(sum_val, idx, tidx, - access_id_in_token, params); + // Fuse finalize (residual/norm/quant) using the summed value. + int access_id = token_id * num_vecs_per_token + idx; + int access_id_in_token = idx; // Thread-local offset within the token (in vec units). 
+ fused_op( + sum_val, access_id, token_id, access_id_in_token, params); } + comm.update(params.size * NRanks); cudaTriggerProgrammaticLaunchCompletion(); -#endif } template + typename NormOutT = T, typename ScaleType = T> cudaError_t launch_oneshot_moefinalize_lamport(MoeFinalizeAllReduceFusionParams const& params, cudaLaunchConfig_t& cfg) { FLASHINFER_CUDA_CALL(cudaLaunchKernelEx( &cfg, moefinalize_allreduce_fusion_kernel_oneshot_lamport, + NormOutT, ScaleType>, params)); return cudaSuccess; } template + typename NormOutT = T, typename ScaleType = T> cudaError_t moefinalize_allreduce_fusion_kernel_launcher( MoeFinalizeAllReduceFusionParams const& params, bool launch_with_pdl) { int token_num = params.size / params.hidden_dim; @@ -1396,44 +1400,25 @@ cudaError_t moefinalize_allreduce_fusion_kernel_launcher( token_num); oneshot = true; } - // Only support one shot - // FLASHINFER_CHECK(oneshot, "only support one shot"); - // Each token is handled by one cluster - int cluster_num = token_num; - // Total number of threads (within one cluster) that's need to handle one token - // given that each thread handle VEC_SIZE static constexpr int VEC_SIZE = details::kBytesPerAccess / sizeof(T); int threads_per_token = params.hidden_dim / VEC_SIZE; - // Total number of warp (within one cluster) that's need to handle one token - // given that each thread handle VEC_SIZE - int warps_per_token = (threads_per_token + 31) / 32; - int cluster_size = 8; - while (warps_per_token % cluster_size != 0) { - cluster_size /= 2; - } - int block_size = warps_per_token / cluster_size * 32; - FLASHINFER_CHECK(block_size <= 1024 && cluster_size > 0, - "block_size <= 1024 && cluster_size > 0"); - int sm_count = get_sm_count(); - int grid_size = (std::min(sm_count, cluster_num * cluster_size) / cluster_size) * cluster_size; - cudaLaunchConfig_t cfg; - cudaLaunchAttribute attribute[2]; - cfg.gridDim = grid_size; - cfg.blockDim = block_size; - cfg.dynamicSmemBytes = 0; + 
FLASHINFER_CHECK(threads_per_token <= 1024, "threads_per_token must be <= 1024"); + + cudaLaunchConfig_t cfg{}; + cfg.gridDim = token_num; + cfg.blockDim = threads_per_token; + cfg.dynamicSmemBytes = params.hidden_dim * sizeof(T); cfg.stream = params.stream; + cudaLaunchAttribute attribute[1]; attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; attribute[0].val.programmaticStreamSerializationAllowed = launch_with_pdl ? 1 : 0; - attribute[1].id = cudaLaunchAttributeClusterDimension; - attribute[1].val.clusterDim.x = cluster_size; - attribute[1].val.clusterDim.y = 1; - attribute[1].val.clusterDim.z = 1; cfg.attrs = attribute; - cfg.numAttrs = 2; + cfg.numAttrs = 1; + if (oneshot) { - FLASHINFER_CUDA_CALL( - (launch_oneshot_moefinalize_lamport( - params, cfg))); + FLASHINFER_CUDA_CALL((launch_oneshot_moefinalize_lamport(params, + cfg))); } return cudaSuccess; } @@ -1466,7 +1451,7 @@ cudaError_t moefinalize_allreduce_fusion_kernel_launcher( } \ }() -template +template cudaError_t moefinalize_allreduce_fusion_op(MoeFinalizeAllReduceFusionParams const& params, bool launch_with_pdl) { static constexpr int VEC_SIZE = details::kBytesPerAccess / sizeof(T); @@ -1485,7 +1470,8 @@ cudaError_t moefinalize_allreduce_fusion_op(MoeFinalizeAllReduceFusionParams return cudaErrorNotSupported; } FLASHINFER_CUDA_CALL( - (moefinalize_allreduce_fusion_kernel_launcher( + (moefinalize_allreduce_fusion_kernel_launcher( (params), (launch_with_pdl)))); }); return status; diff --git a/include/flashinfer/trtllm/fmha/fmhaKernels.cuh b/include/flashinfer/trtllm/fmha/fmhaKernels.cuh index 7fb695ed..7a24acdc 100644 --- a/include/flashinfer/trtllm/fmha/fmhaKernels.cuh +++ b/include/flashinfer/trtllm/fmha/fmhaKernels.cuh @@ -264,8 +264,9 @@ class TllmGenFmhaKernel { if (params.lsePtr != nullptr) { flashinfer::ComputeLSEFromMD(params.softmaxStatsPtr, params.lsePtr, - params.mSumOfSeqLensQ * params.mNumHeadsQ, params.enable_pdl, - params.stream); + params.mSumOfSeqLensQ, 
params.mNumHeadsQ, + params.lseStrideTokens, params.lseStrideHeads, + params.enable_pdl, params.stream); } // Break the while op. break; diff --git a/include/flashinfer/trtllm/fmha/fmhaRunner.cuh b/include/flashinfer/trtllm/fmha/fmhaRunner.cuh index 98eb72c9..75fa374d 100644 --- a/include/flashinfer/trtllm/fmha/fmhaRunner.cuh +++ b/include/flashinfer/trtllm/fmha/fmhaRunner.cuh @@ -32,7 +32,8 @@ class TllmGenFmhaRunner { mDtypeQ == DATA_TYPE_E4M3 || mDtypeQ == DATA_TYPE_FP16 || mDtypeQ == DATA_TYPE_BF16, "Unsupported Q data type: " + std::string(toStr(mDtypeQ))); FLASHINFER_CHECK( - mDtypeKv == DATA_TYPE_E4M3 || mDtypeKv == DATA_TYPE_FP16 || mDtypeKv == DATA_TYPE_BF16, + mDtypeKv == DATA_TYPE_E4M3 || mDtypeKv == DATA_TYPE_FP16 || mDtypeKv == DATA_TYPE_BF16 || + mDtypeKv == DATA_TYPE_E2M1, "Unsupported Kv data type: " + std::string(toStr(mDtypeKv))); FLASHINFER_CHECK(mDtypeOut == DATA_TYPE_E4M3 || mDtypeOut == DATA_TYPE_FP16 || mDtypeOut == DATA_TYPE_BF16 || mDtypeOut == DATA_TYPE_E2M1, diff --git a/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h b/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h index ab48bc04..4164b257 100755 --- a/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h +++ b/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h @@ -225,6 +225,9 @@ struct TllmGenFmhaRunnerParams { // The LSE buffer. float* lsePtr; + int lseStrideTokens; + int lseStrideHeads; + // Attention sink float const* ptrAttentionSinks{nullptr}; // The output buffer. diff --git a/include/flashinfer/trtllm/fmha/kernelParams.h b/include/flashinfer/trtllm/fmha/kernelParams.h index 6e62c055..5cec8f81 100644 --- a/include/flashinfer/trtllm/fmha/kernelParams.h +++ b/include/flashinfer/trtllm/fmha/kernelParams.h @@ -417,7 +417,7 @@ struct KernelParams { // Prepare pointers for TMA descriptors. 
static std::tuple getDevicePtrs( - TllmGenFmhaRunnerParams const& runnerParams, int32_t bytesPerElt) { + TllmGenFmhaRunnerParams const& runnerParams, int32_t bitsPerElt) { // Declare the q, k, v ptrs. void const *qPtr{runnerParams.qPtr}, *kPtr{runnerParams.kPtr}, *vPtr{runnerParams.vPtr}; @@ -425,11 +425,10 @@ struct KernelParams { if (isPackedQkv(runnerParams.mQkvLayout)) { qPtr = runnerParams.qkvPtr; kPtr = reinterpret_cast(reinterpret_cast(runnerParams.qkvPtr) + - runnerParams.mNumHeadsQ * runnerParams.mHeadDimQk * - bytesPerElt); + runnerParams.mNumHeadsQ * (runnerParams.mHeadDimQk * bitsPerElt / 8)); vPtr = reinterpret_cast(reinterpret_cast(runnerParams.qkvPtr) + (runnerParams.mNumHeadsQ + runnerParams.mNumHeadsKv) * - runnerParams.mHeadDimQk * bytesPerElt); + (runnerParams.mHeadDimQk * bitsPerElt / 8)); } // Set K and V pointer from pagedKv tensor. else if (isPagedKv(runnerParams.mQkvLayout)) { @@ -445,7 +444,7 @@ struct KernelParams { int32_t const maxHeadDimKv{std::max(runnerParams.mHeadDimQk, runnerParams.mHeadDimV)}; vPtr = reinterpret_cast( reinterpret_cast(runnerParams.kvPtr) + - runnerParams.mNumHeadsKv * runnerParams.mMaxSeqLenCacheKv * maxHeadDimKv * bytesPerElt); + runnerParams.mNumHeadsKv * runnerParams.mMaxSeqLenCacheKv * (maxHeadDimKv * bitsPerElt / 8)); } // Return the pointers. @@ -555,7 +554,7 @@ struct KernelParams { memset(¶ms, 0, sizeof(KernelParams)); // Get the device pointers for TMA descriptors. - auto [qPtr, kPtr, vPtr] = getDevicePtrs(options, get_size_in_bytes(kernelMeta.mDataTypeKv)); + auto [qPtr, kPtr, vPtr] = getDevicePtrs(options, get_size_in_bits(kernelMeta.mDataTypeKv)); // The maximum headDim of K and V. // Note that contiguousKv or pagedKv will pad K and V to maxHeadDimKv. 
diff --git a/include/flashinfer/trtllm/fmha/lse.cuh b/include/flashinfer/trtllm/fmha/lse.cuh index b41d084a..fc2c7316 100644 --- a/include/flashinfer/trtllm/fmha/lse.cuh +++ b/include/flashinfer/trtllm/fmha/lse.cuh @@ -23,23 +23,27 @@ limitations under the License. namespace flashinfer { -__global__ void ComputeLSEFromMDKernel(float2* __restrict__ md, float* __restrict__ lse, int n) { +__global__ void ComputeLSEFromMDKernel(float2* __restrict__ md, float* __restrict__ lse, int num_tokens, int num_heads, int lse_stride_tokens, int lse_stride_heads) { int elem_idx = blockIdx.x * blockDim.x + threadIdx.x; - if (elem_idx >= n) return; + if (elem_idx >= num_tokens * num_heads) return; #if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.wait;"); #endif float2 md_elem = md[elem_idx]; float m = md_elem.x; float d = md_elem.y; - lse[elem_idx] = math::log2e * m + math::ptx_log2(d); + int token_idx = elem_idx / num_heads; + int head_idx = elem_idx % num_heads; + int elem_idx_lse = token_idx * lse_stride_tokens + head_idx * lse_stride_heads; + lse[elem_idx_lse] = m + math::loge2 * math::ptx_log2(d); #if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.launch_dependents;"); #endif } -inline cudaError_t ComputeLSEFromMD(float2* md, float* lse, int n, bool launch_with_pdl, - cudaStream_t stream) { +inline cudaError_t ComputeLSEFromMD(float2* md, float* lse, int num_tokens, int num_heads, int lse_stride_tokens, int lse_stride_heads, + bool launch_with_pdl, cudaStream_t stream) { + int n = num_tokens * num_heads; int num_threads = std::min(1024, UpPowerOfTwo(n)); int num_blocks = ceil_div(n, num_threads); cudaLaunchConfig_t config; @@ -53,7 +57,7 @@ inline cudaError_t ComputeLSEFromMD(float2* md, float* lse, int n, bool launch_w config.numAttrs = 1; config.attrs = attrs; - FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, ComputeLSEFromMDKernel, md, lse, n)); + 
FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, ComputeLSEFromMDKernel, md, lse, num_tokens, num_heads, lse_stride_tokens, lse_stride_heads)); return cudaSuccess; }