Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 44 additions & 10 deletions csrc/trtllm_fmha_kernel_launcher.cu
Original file line number Diff line number Diff line change
Expand Up @@ -73,16 +73,17 @@ class TllmGenFmhaRunnerCache {
void trtllm_paged_attention_launcher(
void* out, void* out_scale_factor, void* query, void* key_cache, void* value_cache,
void* workspace_buffer, int* block_tables, int* seq_lens, int* cum_seq_lens_q,
int* cum_seq_lens_kv, float* attention_sinks,
int* cum_seq_lens_kv, float* attention_sinks, float* lse,
void* k_cache_scales, void* v_cache_scales,
Data_type q_data_type, Data_type kv_data_type,
Data_type o_data_type, TllmPagedAttentionMode mode, int64_t batch_size, int64_t max_q_len,
int64_t max_kv_len, int64_t num_pages_in_mem_pool, int64_t num_qo_heads, int64_t num_kv_heads,
int64_t head_dim_qk, int64_t head_dim_vo, int64_t page_size, int64_t kv_stride_keys_values,
int64_t kv_stride_heads, int64_t kv_stride_batch, int64_t max_num_blocks_per_seq,
double bmm1_scale, double bmm2_scale, double o_sf_scale, int64_t o_sf_vec_size, int64_t o_sf_start_index,
int64_t window_left, int64_t sum_seq_q, int64_t sm_count,
bool enable_pdl, int64_t workspace_size, cudaStream_t stream) {
int64_t window_left, int64_t sum_seq_q,
int64_t lse_stride_tokens, int64_t lse_stride_heads,
int64_t sm_count, bool enable_pdl, int64_t workspace_size, cudaStream_t stream) {
if (num_qo_heads % num_kv_heads != 0) {
std::ostringstream err_msg;
err_msg << "num_qo_heads must be a multiple of num_kv_heads, got num_kv_heads: " << num_kv_heads
Expand Down Expand Up @@ -166,6 +167,12 @@ void trtllm_paged_attention_launcher(
runner_params.multiCtasKvScratchPtr =
float_allocator.aligned_alloc<void>(0, 16, "trtllm_gen_scratch_workspace");
}
runner_params.softmaxStatsPtr = float_allocator.aligned_alloc<float2>(
sizeof(float2) * num_qo_heads * runner_params.mSumOfSeqLensQ, 16,
"trtllm_gen_softmax_workspace");
runner_params.lsePtr = lse;
runner_params.lseStrideTokens = lse_stride_tokens;
runner_params.lseStrideHeads = lse_stride_heads;

auto [foundKernels, kinfo] = fmha_runner->isSupportedWithInfo(runner_params);
if (!foundKernels) {
Expand Down Expand Up @@ -206,7 +213,7 @@ void trtllm_paged_attention_decode(TensorView out, Optional<TensorView> out_scal
int64_t o_sf_start_index, int64_t window_left, int64_t sm_count,
bool enable_pdl, int64_t workspace_size,
Optional<TensorView> attention_sinks, Optional<TensorView> k_cache_scales,
Optional<TensorView> v_cache_scales) {
Optional<TensorView> v_cache_scales, Optional<TensorView> lse) {
auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype());
auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype());
TVM_FFI_ICHECK_EQ(key_cache.ndim(), value_cache.ndim());
Expand Down Expand Up @@ -260,6 +267,16 @@ void trtllm_paged_attention_decode(TensorView out, Optional<TensorView> out_scal
attention_sinks_ptr = static_cast<float*>(attention_sinks.value().data_ptr());
}

float* lse_ptr = nullptr;
int lse_stride_tokens = 0;
int lse_stride_heads = 0;
if (lse.has_value()) {
TVM_FFI_ICHECK_EQ(lse.value().dtype(), dl_float32) << "lse must be a float tensor";
lse_ptr = static_cast<float*>(lse.value().data_ptr());
lse_stride_tokens = lse.value().stride(0);
lse_stride_heads = lse.value().stride(2);
}

void* k_cache_scales_ptr = nullptr;
void* v_cache_scales_ptr = nullptr;
if (k_cache_scales.has_value()) {
Expand All @@ -274,14 +291,14 @@ void trtllm_paged_attention_decode(TensorView out, Optional<TensorView> out_scal
workspace_buffer.data_ptr(), static_cast<int*>(block_tables.data_ptr()),
static_cast<int*>(seq_lens.data_ptr()),
/*cum_seq_lens_q=*/nullptr,
/*cum_seq_lens_kv=*/nullptr, attention_sinks_ptr,
/*cum_seq_lens_kv=*/nullptr, attention_sinks_ptr, lse_ptr,
k_cache_scales_ptr, v_cache_scales_ptr,
q_data_type, kv_data_type, o_data_type,
TllmPagedAttentionMode::ForGen, batch_size, /*max_q_len=*/q_len_per_request, max_kv_len,
num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size,
kv_stride_keys_values, kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq, bmm1_scale,
bmm2_scale, o_sf_scale, o_sf_vec_size, o_sf_start_index,
window_left, sum_seq_q, sm_count, enable_pdl, workspace_size, stream);
window_left, sum_seq_q, lse_stride_tokens, lse_stride_heads, sm_count, enable_pdl, workspace_size, stream);
}

void trtllm_paged_attention_context(TensorView out, Optional<TensorView> out_scale_factor,
Expand All @@ -294,7 +311,7 @@ void trtllm_paged_attention_context(TensorView out, Optional<TensorView> out_sca
TensorView cum_seq_lens_q, TensorView cum_seq_lens_kv,
int64_t sm_count, bool enable_pdl, int64_t workspace_size,
Optional<TensorView> attention_sinks, Optional<TensorView> k_cache_scales,
Optional<TensorView> v_cache_scales) {
Optional<TensorView> v_cache_scales, Optional<TensorView> lse) {
auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype());
auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype());
auto o_data_type = dl_dtype_to_tllm_data_type(out.dtype());
Expand Down Expand Up @@ -338,6 +355,16 @@ void trtllm_paged_attention_context(TensorView out, Optional<TensorView> out_sca
attention_sinks_ptr = static_cast<float*>(attention_sinks.value().data_ptr());
}

float* lse_ptr = nullptr;
int lse_stride_tokens = 0;
int lse_stride_heads = 0;
if (lse.has_value()) {
TVM_FFI_ICHECK_EQ(lse.value().dtype(), dl_float32) << "lse must be a float tensor";
lse_ptr = static_cast<float*>(lse.value().data_ptr());
lse_stride_tokens = lse.value().stride(0);
lse_stride_heads = lse.value().stride(1);
}

void* k_cache_scales_ptr = nullptr;
void* v_cache_scales_ptr = nullptr;
if (k_cache_scales.has_value()) {
Expand All @@ -352,13 +379,13 @@ void trtllm_paged_attention_context(TensorView out, Optional<TensorView> out_sca
workspace_buffer.data_ptr(), static_cast<int*>(block_tables.data_ptr()),
static_cast<int*>(seq_lens.data_ptr()),
/*cum_seq_lens_q=*/static_cast<int*>(cum_seq_lens_q.data_ptr()),
/*cum_seq_lens_kv=*/static_cast<int*>(cum_seq_lens_kv.data_ptr()), attention_sinks_ptr,
/*cum_seq_lens_kv=*/static_cast<int*>(cum_seq_lens_kv.data_ptr()), attention_sinks_ptr, lse_ptr,
k_cache_scales_ptr, v_cache_scales_ptr,
q_data_type, kv_data_type, o_data_type, TllmPagedAttentionMode::Context, batch_size,
max_q_len, max_kv_len, num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q,
head_dim_o, page_size, kv_stride_keys_values, kv_stride_heads, kv_stride_batch,
max_num_blocks_per_seq, bmm1_scale, bmm2_scale, o_sf_scale, o_sf_vec_size, o_sf_start_index,
window_left, sum_seq_q, sm_count, enable_pdl, workspace_size, stream);
window_left, sum_seq_q, lse_stride_tokens, lse_stride_heads, sm_count, enable_pdl, workspace_size, stream);
}

void trtllm_ragged_attention_launcher(
Expand All @@ -370,6 +397,7 @@ void trtllm_ragged_attention_launcher(
double o_sf_scale, int64_t batch_size, int64_t window_left, int64_t sm_count, bool enable_pdl,
bool is_causal, int64_t k_stride_keys_values, int64_t k_stride_heads, int64_t k_stride_batch,
int64_t v_stride_keys_values, int64_t v_stride_heads, int64_t v_stride_batch,
int64_t lse_stride_tokens, int64_t lse_stride_heads,
int64_t workspace_size, cudaStream_t stream) {
if (num_qo_heads % num_kv_heads != 0) {
std::ostringstream err_msg;
Expand Down Expand Up @@ -422,6 +450,8 @@ void trtllm_ragged_attention_launcher(
runner_params.mMaskType =
is_causal ? TrtllmGenAttentionMaskType::Causal : TrtllmGenAttentionMaskType::Dense;
runner_params.lsePtr = lse;
runner_params.lseStrideTokens = lse_stride_tokens;
runner_params.lseStrideHeads = lse_stride_heads;

AlignedAllocator float_allocator(workspace_buffer, workspace_size);
size_t max_batch_size = 8192;
Expand Down Expand Up @@ -463,9 +493,13 @@ void trtllm_ragged_attention(TensorView out, TensorView query, TensorView key, T
attention_sinks_ptr = static_cast<float*>(attention_sinks.value().data_ptr());
}
float* lse_ptr = nullptr;
int lse_stride_tokens = 0;
int lse_stride_heads = 0;
if (lse.has_value()) {
TVM_FFI_ICHECK_EQ(lse.value().dtype(), dl_float32) << "lse must be a float tensor";
lse_ptr = static_cast<float*>(lse.value().data_ptr());
lse_stride_tokens = lse.value().stride(0);
lse_stride_heads = lse.value().stride(1);
}
TVM_FFI_ICHECK_EQ(out.ndim(), 3) << "out must be a 3D tensor";
TVM_FFI_ICHECK_EQ(query.ndim(), 3) << "query must be a 3D tensor";
Expand Down Expand Up @@ -497,7 +531,7 @@ void trtllm_ragged_attention(TensorView out, TensorView query, TensorView key, T
num_qo_heads, num_kv_heads, head_dim_qk, head_dim_v, sum_seq_q, sum_seq_kv, bmm1_scale,
bmm2_scale, o_sf_scale, batch_size, window_left, sm_count, enable_pdl, is_causal,
k_stride_keys_values, k_stride_heads, k_stride_batch, v_stride_keys_values, v_stride_heads,
v_stride_batch, workspace_size, stream);
v_stride_batch, lse_stride_tokens, lse_stride_heads, workspace_size, stream);
}

namespace trtllm_cubin_loader {
Expand Down
43 changes: 39 additions & 4 deletions flashinfer/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -1919,6 +1919,9 @@ def _paged_run(
enable_pdl,
workspace_size,
sinks,
None,
None,
None,
)
return out

Expand Down Expand Up @@ -2079,6 +2082,8 @@ def trtllm_batch_decode_with_kv_cache(
q_len_per_req: Optional[int] = 1,
o_scale: Optional[float] = 1.0,
kv_cache_scales: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
return_lse: bool = False,
lse: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, FP4Tensor]:
"""
Parameters
Expand Down Expand Up @@ -2299,6 +2304,14 @@ def trtllm_batch_decode_with_kv_cache(
if kv_cache_scales is not None:
k_cache_scale, v_cache_scale = kv_cache_scales

if return_lse and lse is None:
lse = torch.empty(
query.shape[0],
query.shape[1],
device=query.device,
dtype=torch.float32,
)

run_func(
out,
out_scale_factor,
Expand Down Expand Up @@ -2326,13 +2339,18 @@ def trtllm_batch_decode_with_kv_cache(
sinks,
k_cache_scale,
v_cache_scale,
lse
)

return (
out = (
out
if out_dtype != "nvfp4"
else FP4Tensor(out, out_scale_factor, o_sf_start_index, query.shape)
)
if return_lse:
return out, lse
else:
return out
else:
raise KeyError(f"Backend {backend} not supported")

Expand Down Expand Up @@ -2545,6 +2563,8 @@ def trtllm_batch_decode_with_kv_cache_mla(
bmm1_scale_log2_tensor: Optional[torch.Tensor] = None,
bmm2_scale_tensor: Optional[torch.Tensor] = None,
sinks: Optional[List[torch.Tensor]] = None,
return_lse: bool = False,
lse: Optional[torch.Tensor] = None,
enable_pdl: bool = None,
backend: str = "auto",
) -> torch.Tensor:
Expand Down Expand Up @@ -2651,10 +2671,10 @@ def trtllm_batch_decode_with_kv_cache_mla(
out_shape = query.shape[:-1] + (kv_lora_rank,)
out = torch.empty(out_shape, dtype=torch.bfloat16, device=query.device)
else:
batch_size, _, num_q_heads, _ = query.shape
batch_size, q_len_per_request, num_q_heads, _ = query.shape
check_shape_dtype_device(
out,
[batch_size, num_q_heads, kv_lora_rank],
[batch_size, q_len_per_request, num_q_heads, kv_lora_rank],
torch.bfloat16,
query.device,
"out",
Expand All @@ -2670,6 +2690,15 @@ def trtllm_batch_decode_with_kv_cache_mla(
"Dynamic scale factors bmm1_scale_tensor and bmm2_scale_tensor are only supported for fp8 tensor core operation"
)

if return_lse and lse is None:
lse = torch.empty(
query.shape[0],
query.shape[1],
query.shape[2],
device=query.device,
dtype=torch.float32,
)

run_func(
out,
None, # fp4 output not supported in wrapper api yet.
Expand All @@ -2690,9 +2719,15 @@ def trtllm_batch_decode_with_kv_cache_mla(
enable_pdl,
workspace_buffer.numel() * workspace_buffer.element_size(),
sinks,
None,
None,
lse,
)

return out
if return_lse:
return out, lse
else:
return out
else:
raise ValueError(f"Backend {backend} not supported")

Expand Down
17 changes: 16 additions & 1 deletion flashinfer/prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -3339,6 +3339,8 @@ def trtllm_batch_context_with_kv_cache(
enable_pdl: Optional[bool] = None,
sinks: Optional[List[torch.Tensor]] = None,
kv_cache_scales: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
return_lse: bool = False,
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gate return lse just based on if set to None or not?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

they already have this API for ragged_attention, I'm just making it consistent

lse: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, FP4Tensor]:
"""
Parameters
Expand Down Expand Up @@ -3505,6 +3507,14 @@ def trtllm_batch_context_with_kv_cache(
if kv_cache_scales is not None:
k_cache_scale, v_cache_scale = kv_cache_scales

if return_lse and lse is None:
lse = torch.empty(
query.shape[0],
query.shape[1],
device=query.device,
dtype=torch.float32,
)

workspace_size = workspace_buffer.numel() * workspace_buffer.element_size()
run_func(
out,
Expand Down Expand Up @@ -3532,9 +3542,14 @@ def trtllm_batch_context_with_kv_cache(
sinks,
k_cache_scale,
v_cache_scale,
lse
)
return (
out = (
out
if out_dtype != "nvfp4"
else FP4Tensor(out, out_scale_factor, o_sf_start_index, query.shape)
)
if return_lse:
return out, lse
else:
return out
5 changes: 3 additions & 2 deletions include/flashinfer/trtllm/fmha/fmhaKernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,9 @@ class TllmGenFmhaKernel {

if (params.lsePtr != nullptr) {
flashinfer::ComputeLSEFromMD(params.softmaxStatsPtr, params.lsePtr,
params.mSumOfSeqLensQ * params.mNumHeadsQ, params.enable_pdl,
params.stream);
params.mSumOfSeqLensQ, params.mNumHeadsQ,
params.lseStrideTokens, params.lseStrideHeads,
params.enable_pdl, params.stream);
}
// Break the while op.
break;
Expand Down
3 changes: 3 additions & 0 deletions include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,9 @@ struct TllmGenFmhaRunnerParams {
// The LSE buffer.
float* lsePtr;

int lseStrideTokens;
int lseStrideHeads;

// Attention sink
float const* ptrAttentionSinks{nullptr};
// The output buffer.
Expand Down
16 changes: 10 additions & 6 deletions include/flashinfer/trtllm/fmha/lse.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -23,23 +23,27 @@ limitations under the License.

namespace flashinfer {

__global__ void ComputeLSEFromMDKernel(float2* __restrict__ md, float* __restrict__ lse, int n) {
__global__ void ComputeLSEFromMDKernel(float2* __restrict__ md, float* __restrict__ lse, int num_tokens, int num_heads, int lse_stride_tokens, int lse_stride_heads) {
int elem_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (elem_idx >= n) return;
if (elem_idx >= num_tokens * num_heads) return;
#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
#endif
float2 md_elem = md[elem_idx];
float m = md_elem.x;
float d = md_elem.y;
lse[elem_idx] = math::log2e * m + math::ptx_log2(d);
int token_idx = elem_idx / num_heads;
int head_idx = elem_idx % num_heads;
int elem_idx_lse = token_idx * lse_stride_tokens + head_idx * lse_stride_heads;
lse[elem_idx_lse] = m + math::loge2 * math::ptx_log2(d);
#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
#endif
}

inline cudaError_t ComputeLSEFromMD(float2* md, float* lse, int n, bool launch_with_pdl,
cudaStream_t stream) {
inline cudaError_t ComputeLSEFromMD(float2* md, float* lse, int num_tokens, int num_heads, int lse_stride_tokens, int lse_stride_heads,
bool launch_with_pdl, cudaStream_t stream) {
int n = num_tokens * num_heads;
int num_threads = std::min(1024, UpPowerOfTwo(n));
int num_blocks = ceil_div(n, num_threads);
cudaLaunchConfig_t config;
Expand All @@ -53,7 +57,7 @@ inline cudaError_t ComputeLSEFromMD(float2* md, float* lse, int n, bool launch_w
config.numAttrs = 1;
config.attrs = attrs;

FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, ComputeLSEFromMDKernel, md, lse, n));
FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, ComputeLSEFromMDKernel, md, lse, num_tokens, num_heads, lse_stride_tokens, lse_stride_heads));
return cudaSuccess;
}

Expand Down
Loading