
Commit 65a72d7

MrGeva and cursoragent committed
fix(autodeploy): Fix TRT-LLM attention KV cache layout for KVCacheManager
Fix garbage output when using thop.attention with AD's KVCacheManager by correctly
configuring the interleaved K/V block layout.

Key changes:
- TrtllmKVResourceHandler now extends PagedResourceHandler for proper KVCacheManager
  integration with HND layout support
- Configure KVCacheManager to use SELF cache type (kv_factor=2) when handlers request
  HND layout, avoiding memory-doubling copies
- Fix pool pointers: K ptr = AD's base address, V ptr = 0 (kernel uses block offsets
  to locate V)
- Fix pool mapping: Use AD's layer offsets directly
- Fix block offsets: Use multiplier = num_layers * kv_factor (64) with K = base_offsets
  and V = base_offsets + 1 for interleaved layout

Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent b0d1d1a commit 65a72d7
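
As a quick aside on the offset arithmetic described in the commit message: each logical
cache block expands into num_layers * kv_factor physical blocks in the pool, K uses the
base offset and V uses base + 1. A minimal illustrative sketch (the helper name and the
example values are made up, not code from this commit):

import torch

def kv_block_offsets(block_idx: torch.Tensor, num_layers: int, kv_factor: int = 2):
    # multiplier = num_layers * kv_factor, e.g. 32 * 2 = 64 as in the message above
    multiplier = num_layers * kv_factor
    k_offsets = block_idx * multiplier  # K = base_offsets
    v_offsets = k_offsets + 1           # V = base_offsets + 1 (interleaved K/V)
    return k_offsets, v_offsets

k, v = kv_block_offsets(torch.tensor([0, 1, 5]), num_layers=32)
print(k.tolist(), v.tolist())  # [0, 64, 320] [1, 65, 321]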

3 files changed: +277 -96 lines changed

tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py

Lines changed: 7 additions & 2 deletions
@@ -1105,15 +1105,20 @@ class PagedResourceHandler(ManagedResourceHandler):
     The PagedResourceHandler can be used to handle resources that support paging such as kv-caches.
     """
 
-    def __init__(self, *token_shape: int, dtype: torch.dtype) -> None:
+    def __init__(
+        self, *token_shape: int, dtype: torch.dtype, kv_layout: Literal["NHD", "HND"] = "NHD"
+    ) -> None:
         """Initialize the PagedResourceHandler.
 
         Args:
-            page_shape: The shape of a single page of the resource.
+            token_shape: The shape of a single token's worth of data in the resource.
             dtype: The dtype of the resource.
+            kv_layout: Memory layout for KV cache. "NHD" = [blocks, tokens, kv_factor, heads, dim],
+                "HND" = [blocks, kv_factor, heads, tokens, dim]. Default is "NHD".
         """
         self.token_shape = token_shape
         self.dtype = dtype
+        self.kv_layout = kv_layout
 
 
 class StateResourceHandler(ManagedResourceHandler):
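
To illustrate the two layouts named in the new kv_layout docstring above, a small sketch
with assumed sizes (the permute+contiguous here is illustrative, not the actual
interface.py code):

import torch

num_blocks, tokens_per_block, kv_factor, num_heads, head_dim = 4, 16, 2, 8, 64

# "NHD": [blocks, tokens, kv_factor, heads, dim]
cache_nhd = torch.zeros(num_blocks, tokens_per_block, kv_factor, num_heads, head_dim)

# "HND": [blocks, kv_factor, heads, tokens, dim]
cache_hnd = cache_nhd.permute(0, 2, 3, 1, 4).contiguous()
print(cache_hnd.shape)  # torch.Size([4, 2, 8, 16, 64])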

tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_attention.py

Lines changed: 153 additions & 68 deletions
@@ -57,6 +57,7 @@
     CacheConfig,
     Constant,
     MHACallable,
+    PagedResourceHandler,
     PrepareMetadataCallable,
     PrepareMetadataHostCallable,
     ResourceHandler,
@@ -84,14 +85,14 @@ def allocate(self, sequence_info: SequenceInfo) -> torch.Tensor:
         return buffer
 
 
-class TrtllmKVResourceHandler(ResourceHandler):
+class TrtllmKVResourceHandler(PagedResourceHandler):
     """Resource handler for TRT-LLM unified KV cache.
 
-    Uses ResourceHandler (not PagedResourceHandler) so the interface calls allocate()
-    directly, allowing us to create the cache with the exact layout thop.attention expects.
+    Extends PagedResourceHandler so the cache interface recognizes it as a paged resource
+    and creates a proper KVCacheManager with correct parameters (num_layers, num_kv_heads, etc.).
 
-    Uses kv_factor=2 (unified K+V) and kv_layout="HND" to match what thop.attention expects.
-    The cache is allocated with shape [num_blocks, 2, num_kv_heads, tokens_per_block, head_dim].
+    Uses kv_layout="HND" to request HND format from AD's cache manager.
+    The cache is converted to HND format in interface.py via permute+contiguous.
     """
 
     def __init__(
@@ -103,44 +104,32 @@ def __init__(
         trtllm_config: "TrtllmAttentionGlobalState",
         cache_config: CacheConfig,
     ) -> None:
-        # Store attributes for TRT-LLM attention
+        # Initialize parent class with token_shape = (2, num_kv_heads, head_dim)
+        # The 2 is kv_factor for unified K+V cache
+        # Use HND layout for thop.attention kernel compatibility
+        super().__init__(2, num_kv_heads, head_dim, dtype=dtype, kv_layout="HND")
+
+        # Store additional attributes for TRT-LLM attention
         self.num_kv_heads = num_kv_heads
         self.head_dim = head_dim
-        self.dtype = dtype
         self.kv_factor = 2  # Unified K+V cache
-        self.kv_layout = "HND"  # Matches thop.attention kernel's per-block layout
         self.layer_idx = layer_idx
         self._trtllm_config = trtllm_config
         self._cache_config = cache_config
 
-    def allocate(self, sequence_info: SequenceInfo) -> torch.Tensor:
-        """Allocate cache via KVCacheManager or simple allocation."""
-        # Configure global state first (first time only)
-        if not self._trtllm_config.is_configured:
-            self._trtllm_config.configure(sequence_info)
-
-        # Set model config for FP8 KV cache support (first time only)
-        if self._trtllm_config._num_layers == 0:
-            cache_dtype = self.dtype
-            self._trtllm_config.set_model_config(
-                num_layers=len(TrtllmAttention._num_kv_heads_per_layer),
-                num_kv_heads_per_layer=TrtllmAttention._num_kv_heads_per_layer,
-                head_dim=TrtllmAttention._head_dim,
-                dtype=cache_dtype,
-            )
+    def __eq__(self, other: "TrtllmKVResourceHandler") -> bool:
+        """Check compatibility for KVCacheManager resource grouping.
 
-        # Allocate unified KV cache with correct layout for thop.attention
-        # Shape: [num_blocks, kv_factor=2, num_kv_heads, tokens_per_block, head_dim] (HND layout)
-        cache = torch.empty(
-            sequence_info.num_blocks,
-            self.kv_factor,  # 2 for K and V
-            self.num_kv_heads,
-            sequence_info.tokens_per_block,
-            self.head_dim,
-            device=sequence_info.device,
-            dtype=self.dtype,
+        Return True so KVCacheManager manages all layers' KV caches together.
+        """
+        if not isinstance(other, TrtllmKVResourceHandler):
+            return False
+        return (
+            self.head_dim == other.head_dim
+            and self.dtype == other.dtype
+            and self.kv_factor == other.kv_factor
+            and self.kv_layout == other.kv_layout
         )
-        return cache
 
 
 @dataclass
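
The __eq__ above exists so that compatible per-layer handlers can be grouped into a single
KVCacheManager pool. A rough, hypothetical sketch of that grouping idea (FakeHandler and
group_compatible are stand-ins, not the real interface types):

import torch
from dataclasses import dataclass

@dataclass
class FakeHandler:
    head_dim: int
    dtype: torch.dtype
    kv_factor: int = 2
    kv_layout: str = "HND"

def group_compatible(handlers):
    # Handlers that compare equal can share one pool; unequal ones get their own group.
    groups = []
    for h in handlers:
        for g in groups:
            if g[0] == h:
                g.append(h)
                break
        else:
            groups.append([h])
    return groups

layers = [FakeHandler(head_dim=64, dtype=torch.float16) for _ in range(4)]
print(len(group_compatible(layers)))  # 1 -> all layers managed together
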
@@ -175,8 +164,15 @@ class TrtllmLayerState:
 
     def __post_init__(self):
         """Initialize tensors - use shared tensors from global state where possible."""
-        # Pool mapping needs to be pre-allocated before init_from_shared
-        # Other tensors will come from shared state via init_from_shared()
+        # Pool pointers and mapping are per-layer (each layer has its own cache buffer)
+        # These are NOT shared across layers
+        if self.host_kv_cache_pool_pointers is None:
+            # Pool pointers: [num_pools, 2] for K and V pointers
+            # With per-layer caches, each layer uses pool 0 which points to its own buffer
+            self.host_kv_cache_pool_pointers = torch.zeros(
+                1, 2, dtype=torch.int64, device="cpu", pin_memory=True
+            )
+
         if self.host_kv_cache_pool_mapping is None:
             # Pool mapping: 2D [num_layers, 2] format expected by thop.attention
             max_layers = 256
@@ -185,17 +181,22 @@ def __post_init__(self):
             )
 
     def init_from_shared(self, global_state: "TrtllmAttentionGlobalState") -> None:
-        """Initialize layer to use shared tensors from global state."""
-        # All layers share the same tensors (single KV cache pool)
+        """Initialize layer to use shared tensors from global state.
+
+        NOTE: Pool pointers (host_kv_cache_pool_pointers) are NOT shared because each layer
+        has its own cache buffer with a different data_ptr(). These are initialized in
+        __post_init__ and set per-layer in _prepare_trtllm_metadata.
+        """
+        # All layers share sequence/batch metadata tensors
         self.sequence_length = global_state._shared_sequence_length
         self.context_lengths = global_state._shared_context_lengths
         self.kv_cache_block_offsets = global_state._shared_kv_cache_block_offsets
         self.host_past_key_value_lengths = global_state._shared_host_past_key_value_lengths
         self.host_context_lengths = global_state._shared_host_context_lengths
         self.host_request_types = global_state._shared_host_request_types
         self.host_total_kv_lens = global_state._shared_host_total_kv_lens
-        self.host_kv_cache_pool_pointers = global_state._shared_host_kv_cache_pool_pointers
-        # Keep host_kv_cache_pool_mapping from __post_init__ - it's layer-specific
+        # NOTE: host_kv_cache_pool_pointers is NOT shared - each layer has its own from __post_init__
+        # NOTE: host_kv_cache_pool_mapping is NOT shared - each layer has its own from __post_init__
 
 
 class TrtllmAttentionGlobalState:
@@ -324,32 +325,63 @@ def _init_gpu_buffers(self, max_pages: int, max_seqs: int) -> None:
         self._gpu_buffers_initialized = True
 
     def _init_pool_pointers(
-        self, ad_pool_pointers: torch.Tensor, ad_pool_mapping: torch.Tensor, num_layers: int
+        self,
+        ad_pool_pointers: Optional[torch.Tensor],
+        ad_pool_mapping: Optional[torch.Tensor],
+        num_layers: int,
+        fallback_kv_cache_ptr: Optional[int] = None,
     ) -> None:
-        """Initialize pool pointers once from AD's KVCacheManager.
+        """Initialize pool pointers once from AD's KVCacheManager or fallback.
 
         This is called once during first host_prepare to set up the static pool info.
         Pool pointers don't change between requests - only block offsets do.
+
+        Args:
+            ad_pool_pointers: Pool pointers from SequenceInfo (if available)
+            ad_pool_mapping: Pool mapping from SequenceInfo (if available)
+            num_layers: Number of transformer layers
+            fallback_kv_cache_ptr: Fallback KV cache data pointer (if AD pool not available)
         """
         if self._pool_pointers_initialized:
             return
 
-        if ad_pool_pointers is None or ad_pool_mapping is None:
-            return
-
-        if ad_pool_pointers.numel() == 0 or ad_pool_pointers[0, 0].item() == 0:
-            return
-
-        # Set pool pointers (these are static for the lifetime of the cache)
-        self._shared_host_kv_cache_pool_pointers[0, 0] = ad_pool_pointers[0, 0].item()
-        self._shared_host_kv_cache_pool_pointers[0, 1] = 0  # v_ptr=0 for interleaved
-
-        # Set pool mapping for all layers
-        for layer_i in range(min(num_layers, ad_pool_mapping.shape[0])):
-            self._shared_host_kv_cache_pool_mapping[layer_i, 0] = ad_pool_mapping[layer_i, 0].item()
-            self._shared_host_kv_cache_pool_mapping[layer_i, 1] = ad_pool_mapping[layer_i, 1].item()
+        # Check if AD pool pointers are valid
+        use_ad_pool = (
+            ad_pool_pointers is not None
+            and ad_pool_mapping is not None
+            and ad_pool_pointers.numel() > 0
+            and ad_pool_pointers[0, 0].item() != 0
+        )
 
-        self._pool_pointers_initialized = True
+        if use_ad_pool:
+            # Use AD's pool pointers directly
+            # K and V are interleaved in blocks, so only K ptr is needed
+            # V ptr is set to 0 - kernel uses block offsets to find V
+            self._shared_host_kv_cache_pool_pointers[0, 0] = ad_pool_pointers[0, 0].item()
+            self._shared_host_kv_cache_pool_pointers[0, 1] = 0
+
+            # Set pool mapping for all layers
+            for layer_i in range(min(num_layers, ad_pool_mapping.shape[0])):
+                self._shared_host_kv_cache_pool_mapping[layer_i, 0] = ad_pool_mapping[
+                    layer_i, 0
+                ].item()
+                self._shared_host_kv_cache_pool_mapping[layer_i, 1] = ad_pool_mapping[
+                    layer_i, 1
+                ].item()
+
+            self._pool_pointers_initialized = True
+        elif fallback_kv_cache_ptr is not None and fallback_kv_cache_ptr != 0:
+            # Fallback: Use kv_cache tensor's data pointer directly
+            # V ptr is set to 0 - kernel uses block offsets to find V
+            self._shared_host_kv_cache_pool_pointers[0, 0] = fallback_kv_cache_ptr  # K ptr
+            self._shared_host_kv_cache_pool_pointers[0, 1] = 0  # V ptr = 0
+
+            # All layers map to pool 0 with layer offset
+            for layer_i in range(num_layers):
+                self._shared_host_kv_cache_pool_mapping[layer_i, 0] = 0  # pool index
+                self._shared_host_kv_cache_pool_mapping[layer_i, 1] = layer_i  # layer offset
+
+            self._pool_pointers_initialized = True
 
     def get_or_create_layer_state(
         self,
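
A condensed sketch of the pointer/mapping scheme set up above, with made-up sizes (the
real tensors live on TrtllmAttentionGlobalState; this is illustrative only):

import torch

num_layers = 32

# Pool pointers: [num_pools, 2]. K ptr is the pool's base address; V ptr stays 0 because
# K and V are interleaved and the kernel locates V via block offsets.
kv_pool = torch.zeros(1024, dtype=torch.uint8)  # stand-in for the real cache buffer
pool_pointers = torch.zeros(1, 2, dtype=torch.int64)
pool_pointers[0, 0] = kv_pool.data_ptr()  # K ptr = base address
pool_pointers[0, 1] = 0                   # V ptr = 0

# Pool mapping: [num_layers, 2] = (pool index, layer offset), as in the fallback branch.
pool_mapping = torch.zeros(num_layers, 2, dtype=torch.int32)
pool_mapping[:, 1] = torch.arange(num_layers)
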
@@ -527,7 +559,8 @@ def _host_prepare_trtllm_metadata(
     )
     page_in_seq = global_state._gpu_page_idx[:total_pages]
 
-    # base_offsets on GPU
+    # Block offsets: multiplier = num_layers * kv_factor for interleaved K/V
+    # K and V have different offsets: K = base, V = base + 1
     kv_factor = 2
     multiplier = num_layers * kv_factor
     torch.mul(
@@ -540,10 +573,10 @@
     # Fill block offsets using advanced indexing (only zero the slice we need)
     global_state._shared_kv_cache_block_offsets[:, :num_seq, :, :].zero_()
     global_state._shared_kv_cache_block_offsets[0, seq_indices, 0, page_in_seq] = (
-        base_offsets
+        base_offsets  # K
     )
     global_state._shared_kv_cache_block_offsets[0, seq_indices, 1, page_in_seq] = (
-        base_offsets + 1
+        base_offsets + 1  # V
     )
 
     # Mark that host_prepare has run
@@ -665,7 +698,13 @@ def _prepare_trtllm_metadata(
             f"Expected kv_cache shape [pages, 2, heads, tokens, dim], got {kv_cache.shape}"
         )
 
-    num_layers = state.num_layers if state.num_layers > 0 else 32
+    # Get num_layers - prefer from pool_mapping (accurate) over state.num_layers (may be stale)
+    if ad_pool_mapping is not None and ad_pool_mapping.numel() > 0:
+        num_layers = ad_pool_mapping.shape[0]
+    elif state.num_layers > 0:
+        num_layers = state.num_layers
+    else:
+        num_layers = 32
 
     # Compute input sequence lengths from cumulative sums
     input_seq_lens = (cu_seqlen_host[1 : num_seq + 1] - cu_seqlen_host[:num_seq]).int()
@@ -686,7 +725,7 @@
     state.sequence_length[:num_seq].copy_(seq_len_with_cache.cuda())
     state.context_lengths[:num_seq].copy_(input_seq_lens.cuda())
 
-    # Set up KV cache pool pointers
+    # Set up KV cache pool pointers - use AD's pool pointers
     use_ad_pool = (
         ad_pool_pointers is not None
         and ad_pool_mapping is not None
@@ -709,12 +748,18 @@
             state.host_kv_cache_pool_mapping[layer_i, 0] = ad_pool_mapping[layer_i, 0].item()
             state.host_kv_cache_pool_mapping[layer_i, 1] = ad_pool_mapping[layer_i, 1].item()
 
+        # Mark pool pointers as initialized in global state to skip redundant init in host_prepare_fn
+        _global_state._pool_pointers_initialized = True
+
     # Block offsets: convert flat cache_loc to per-sequence block indices
     pages_per_seq = (cu_num_pages_host[1 : num_seq + 1] - cu_num_pages_host[:num_seq]).int()
     max_blocks = pages_per_seq.max().item() if num_seq > 0 else 1
     _global_state.set_max_blocks_per_seq(max_blocks)
 
     # Fill block offsets
+    # AD's cache_loc contains LOGICAL block indices from KVCacheManager
+    # Multiplier = num_layers * kv_factor for interleaved K/V layout
+    # K and V have different offsets: K = base, V = base + 1
     kv_factor = 2
     multiplier = num_layers * kv_factor
     state.kv_cache_block_offsets.zero_()
@@ -723,10 +768,32 @@
         n_pages = pages_per_seq[i].item()
         if n_pages > 0:
             base_offsets = cache_loc[offset : offset + n_pages] * multiplier
-            state.kv_cache_block_offsets[0, i, 0, :n_pages] = base_offsets
-            state.kv_cache_block_offsets[0, i, 1, :n_pages] = base_offsets + 1
+            state.kv_cache_block_offsets[0, i, 0, :n_pages] = base_offsets  # K
+            state.kv_cache_block_offsets[0, i, 1, :n_pages] = base_offsets + 1  # V
         offset += n_pages
 
+    # Debug: print info for layers 0, 1, 31 (first few calls only)
+    debug_count = getattr(_global_state, "_debug_count", 0)
+    if debug_count < 10 and state.layer_idx in [0, 1, 31]:
+        _global_state._debug_count = debug_count + 1
+        print(
+            f"\n[DEBUG #{debug_count}] === Layer {state.layer_idx}, num_seq={num_seq}, num_layers={num_layers} ==="
+        )
+        k_ptr = state.host_kv_cache_pool_pointers[0, 0].item()
+        v_ptr = state.host_kv_cache_pool_pointers[0, 1].item()
+        print(f"[DEBUG] pool_pointers: K={k_ptr}, V={v_ptr}")
+        mapping = state.host_kv_cache_pool_mapping[state.layer_idx].tolist()
+        print(f"[DEBUG] pool_mapping layer {state.layer_idx}: {mapping}")
+        if state.layer_idx == 0:
+            total_pages = cache_loc.shape[0] if cache_loc.numel() > 0 else 0
+            print(f"[DEBUG] cache_loc[:5]: {cache_loc[: min(5, total_pages)].tolist()}")
+            print(
+                f"[DEBUG] K_offsets[:5]: {(cache_loc[: min(5, total_pages)] * multiplier).tolist()}"
+            )
+            print(
+                f"[DEBUG] V_offsets[:5]: {(cache_loc[: min(5, total_pages)] * multiplier + 1).tolist()}"
+            )
+
     # Return tensors
     max_blocks_per_seq = state.kv_cache_block_offsets.shape[3]
 
@@ -793,14 +860,32 @@ def trtllm_mha_with_cache(
     if not kv_cache.is_cuda:
         raise RuntimeError(f"kv_cache must be on CUDA, got {kv_cache.device}")
 
-    # Validate unified KV cache format
-    # Expected shape: [num_blocks, 2, num_kv_heads, tokens_per_block, head_dim] (HND layout)
-    # This shape is created by TrtllmKVResourceHandler.allocate() which permutes the base allocation
+    # Validate KV cache format
+    # TrtllmKVResourceHandler configures AD to allocate cache in HND format:
+    # [num_blocks, 2, num_kv_heads, tokens_per_block, head_dim]
+    # This matches what thop.attention expects.
     assert kv_cache.dim() == 5, f"kv_cache must be 5D, got {kv_cache.dim()}D"
     assert kv_cache.shape[1] == 2, (
-        f"kv_cache.shape[1] must be 2 (kv_factor), got {kv_cache.shape[1]}"
+        f"kv_cache must be in HND format [B, 2, H, T, D] with shape[1]=2, "
+        f"got shape {kv_cache.shape}. Ensure TrtllmKVResourceHandler is used."
     )
 
+    # Lazy initialization of model config (done once on first attention call)
+    if _trtllm_config._num_layers == 0:
+        # Track layer config
+        TrtllmAttention._track_layer_config(num_kv_heads, head_dim, kv_cache.dtype)
+
+        # Once all layers have been seen, set model config
+        # We infer num_layers from layer_idx (assumes layers are processed in order)
+        expected_num_layers = len(TrtllmAttention._num_kv_heads_per_layer)
+        if layer_idx == expected_num_layers - 1 or expected_num_layers >= 32:
+            _trtllm_config.set_model_config(
+                num_layers=expected_num_layers,
+                num_kv_heads_per_layer=TrtllmAttention._num_kv_heads_per_layer,
+                head_dim=head_dim,
+                dtype=kv_cache.dtype,
+            )
+
     # Get batch dimensions
     num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
     num_seq = num_prefill + num_decode
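
The lazy model-config block above follows a "track each layer, finalize once" pattern. A
hypothetical, simplified sketch of that pattern (ConfigTracker is not a real class in this
file):

class ConfigTracker:
    num_kv_heads_per_layer: list = []

    @classmethod
    def track(cls, layer_idx: int, num_kv_heads: int, last_layer_idx: int) -> None:
        cls.num_kv_heads_per_layer.append(num_kv_heads)
        # Finalize once the last layer has reported its config.
        if layer_idx == last_layer_idx:
            print(f"model config: {len(cls.num_kv_heads_per_layer)} layers")

for i in range(4):
    ConfigTracker.track(i, num_kv_heads=8, last_layer_idx=3)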
