     CacheConfig,
     Constant,
     MHACallable,
+    PagedResourceHandler,
     PrepareMetadataCallable,
     PrepareMetadataHostCallable,
     ResourceHandler,
@@ -84,14 +85,14 @@ def allocate(self, sequence_info: SequenceInfo) -> torch.Tensor:
         return buffer


-class TrtllmKVResourceHandler(ResourceHandler):
+class TrtllmKVResourceHandler(PagedResourceHandler):
     """Resource handler for TRT-LLM unified KV cache.

-    Uses ResourceHandler (not PagedResourceHandler) so the interface calls allocate()
-    directly, allowing us to create the cache with the exact layout thop.attention expects.
+    Extends PagedResourceHandler so the cache interface recognizes it as a paged resource
+    and creates a proper KVCacheManager with correct parameters (num_layers, num_kv_heads, etc.).

-    Uses kv_factor=2 (unified K+V) and kv_layout="HND" to match what thop.attention expects.
-    The cache is allocated with shape [num_blocks, 2, num_kv_heads, tokens_per_block, head_dim].
+    Uses kv_layout="HND" to request HND format from AD's cache manager.
+    The cache is converted to HND format in interface.py via permute+contiguous.
     """

     def __init__(
@@ -103,44 +104,32 @@ def __init__(
         trtllm_config: "TrtllmAttentionGlobalState",
         cache_config: CacheConfig,
     ) -> None:
-        # Store attributes for TRT-LLM attention
+        # Initialize parent class with token_shape = (2, num_kv_heads, head_dim)
+        # The 2 is kv_factor for the unified K+V cache
+        # Use HND layout for thop.attention kernel compatibility
+        super().__init__(2, num_kv_heads, head_dim, dtype=dtype, kv_layout="HND")
+
+        # Store additional attributes for TRT-LLM attention
         self.num_kv_heads = num_kv_heads
         self.head_dim = head_dim
-        self.dtype = dtype
         self.kv_factor = 2  # Unified K+V cache
-        self.kv_layout = "HND"  # Matches thop.attention kernel's per-block layout
         self.layer_idx = layer_idx
         self._trtllm_config = trtllm_config
         self._cache_config = cache_config

-    def allocate(self, sequence_info: SequenceInfo) -> torch.Tensor:
-        """Allocate cache via KVCacheManager or simple allocation."""
-        # Configure global state first (first time only)
-        if not self._trtllm_config.is_configured:
-            self._trtllm_config.configure(sequence_info)
-
-        # Set model config for FP8 KV cache support (first time only)
-        if self._trtllm_config._num_layers == 0:
-            cache_dtype = self.dtype
-            self._trtllm_config.set_model_config(
-                num_layers=len(TrtllmAttention._num_kv_heads_per_layer),
-                num_kv_heads_per_layer=TrtllmAttention._num_kv_heads_per_layer,
-                head_dim=TrtllmAttention._head_dim,
-                dtype=cache_dtype,
-            )
+    def __eq__(self, other: "TrtllmKVResourceHandler") -> bool:
+        """Check compatibility for KVCacheManager resource grouping.

-        # Allocate unified KV cache with correct layout for thop.attention
-        # Shape: [num_blocks, kv_factor=2, num_kv_heads, tokens_per_block, head_dim] (HND layout)
-        cache = torch.empty(
-            sequence_info.num_blocks,
-            self.kv_factor,  # 2 for K and V
-            self.num_kv_heads,
-            sequence_info.tokens_per_block,
-            self.head_dim,
-            device=sequence_info.device,
-            dtype=self.dtype,
+        Return True so KVCacheManager manages all layers' KV caches together.
+        """
+        if not isinstance(other, TrtllmKVResourceHandler):
+            return False
+        return (
+            self.head_dim == other.head_dim
+            and self.dtype == other.dtype
+            and self.kv_factor == other.kv_factor
+            and self.kv_layout == other.kv_layout
         )
-        return cache


 @dataclass
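
To make the grouping behavior concrete, here is a minimal, self-contained sketch in plain Python. It is not the real KVCacheManager API (the `FakeHandler` and `group_into_pools` names are hypothetical); it only illustrates why a value-based `__eq__` like the one added above lets a manager collapse per-layer handlers that compare equal into a single shared pool.

# Hypothetical sketch: handlers that compare equal land in the same pool.
from dataclasses import dataclass
from typing import List

@dataclass
class FakeHandler:  # stand-in for TrtllmKVResourceHandler, same compared fields
    head_dim: int
    dtype: str
    kv_factor: int
    kv_layout: str

def group_into_pools(handlers: List[FakeHandler]) -> List[List[FakeHandler]]:
    pools: List[List[FakeHandler]] = []
    for h in handlers:
        for pool in pools:
            if pool[0] == h:  # dataclass __eq__ compares all fields, like the diff above
                pool.append(h)
                break
        else:
            pools.append([h])
    return pools

# 32 identical layers end up in one pool instead of 32 separate ones.
layers = [FakeHandler(128, "bf16", 2, "HND") for _ in range(32)]
assert len(group_into_pools(layers)) == 1
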
@@ -175,8 +164,15 @@ class TrtllmLayerState:

     def __post_init__(self):
         """Initialize tensors - use shared tensors from global state where possible."""
-        # Pool mapping needs to be pre-allocated before init_from_shared
-        # Other tensors will come from shared state via init_from_shared()
+        # Pool pointers and mapping are per-layer (each layer has its own cache buffer)
+        # These are NOT shared across layers
+        if self.host_kv_cache_pool_pointers is None:
+            # Pool pointers: [num_pools, 2] for K and V pointers
+            # With per-layer caches, each layer uses pool 0, which points to its own buffer
+            self.host_kv_cache_pool_pointers = torch.zeros(
+                1, 2, dtype=torch.int64, device="cpu", pin_memory=True
+            )
+
         if self.host_kv_cache_pool_mapping is None:
             # Pool mapping: 2D [num_layers, 2] format expected by thop.attention
             max_layers = 256
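
The shapes used here are easy to exercise in isolation. The sketch below (plain PyTorch, assuming a CUDA-enabled build since pinned memory and `data_ptr()` need it; the tensor names and toy sizes are hypothetical) shows a [num_pools, 2] pointer table and a [max_layers, 2] mapping table, with a device buffer's raw address written into the K slot and the V slot left at 0, mirroring the per-layer setup described above.

import torch

# Shapes mirror the per-layer tensors above: (K ptr, V ptr) per pool, (pool index, layer offset) per layer.
pool_pointers = torch.zeros(1, 2, dtype=torch.int64, device="cpu", pin_memory=True)
pool_mapping = torch.zeros(256, 2, dtype=torch.int32, device="cpu", pin_memory=True)

if torch.cuda.is_available():
    # Stand-in per-layer cache buffer; only its raw device address enters the pool table.
    layer_cache = torch.empty(8, 2, 4, 64, 128, dtype=torch.float16, device="cuda")
    pool_pointers[0, 0] = layer_cache.data_ptr()  # K pointer
    pool_pointers[0, 1] = 0  # V pointer stays 0; V is reached via block offsets
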
@@ -185,17 +181,22 @@ def __post_init__(self):
             )

     def init_from_shared(self, global_state: "TrtllmAttentionGlobalState") -> None:
-        """Initialize layer to use shared tensors from global state."""
-        # All layers share the same tensors (single KV cache pool)
+        """Initialize layer to use shared tensors from global state.
+
+        NOTE: Pool pointers (host_kv_cache_pool_pointers) are NOT shared because each layer
+        has its own cache buffer with a different data_ptr(). These are initialized in
+        __post_init__ and set per-layer in _prepare_trtllm_metadata.
+        """
+        # All layers share sequence/batch metadata tensors
         self.sequence_length = global_state._shared_sequence_length
         self.context_lengths = global_state._shared_context_lengths
         self.kv_cache_block_offsets = global_state._shared_kv_cache_block_offsets
         self.host_past_key_value_lengths = global_state._shared_host_past_key_value_lengths
         self.host_context_lengths = global_state._shared_host_context_lengths
         self.host_request_types = global_state._shared_host_request_types
         self.host_total_kv_lens = global_state._shared_host_total_kv_lens
-        self.host_kv_cache_pool_pointers = global_state._shared_host_kv_cache_pool_pointers
-        # Keep host_kv_cache_pool_mapping from __post_init__ - it's layer-specific
+        # NOTE: host_kv_cache_pool_pointers is NOT shared - each layer has its own from __post_init__
+        # NOTE: host_kv_cache_pool_mapping is NOT shared - each layer has its own from __post_init__


 class TrtllmAttentionGlobalState:
@@ -324,32 +325,63 @@ def _init_gpu_buffers(self, max_pages: int, max_seqs: int) -> None:
         self._gpu_buffers_initialized = True

     def _init_pool_pointers(
-        self, ad_pool_pointers: torch.Tensor, ad_pool_mapping: torch.Tensor, num_layers: int
+        self,
+        ad_pool_pointers: Optional[torch.Tensor],
+        ad_pool_mapping: Optional[torch.Tensor],
+        num_layers: int,
+        fallback_kv_cache_ptr: Optional[int] = None,
     ) -> None:
-        """Initialize pool pointers once from AD's KVCacheManager.
+        """Initialize pool pointers once from AD's KVCacheManager or fallback.

         This is called once during first host_prepare to set up the static pool info.
         Pool pointers don't change between requests - only block offsets do.
+
+        Args:
+            ad_pool_pointers: Pool pointers from SequenceInfo (if available)
+            ad_pool_mapping: Pool mapping from SequenceInfo (if available)
+            num_layers: Number of transformer layers
+            fallback_kv_cache_ptr: Fallback KV cache data pointer (if AD pool not available)
         """
         if self._pool_pointers_initialized:
             return

-        if ad_pool_pointers is None or ad_pool_mapping is None:
-            return
-
-        if ad_pool_pointers.numel() == 0 or ad_pool_pointers[0, 0].item() == 0:
-            return
-
-        # Set pool pointers (these are static for the lifetime of the cache)
-        self._shared_host_kv_cache_pool_pointers[0, 0] = ad_pool_pointers[0, 0].item()
-        self._shared_host_kv_cache_pool_pointers[0, 1] = 0  # v_ptr=0 for interleaved
-
-        # Set pool mapping for all layers
-        for layer_i in range(min(num_layers, ad_pool_mapping.shape[0])):
-            self._shared_host_kv_cache_pool_mapping[layer_i, 0] = ad_pool_mapping[layer_i, 0].item()
-            self._shared_host_kv_cache_pool_mapping[layer_i, 1] = ad_pool_mapping[layer_i, 1].item()
+        # Check if AD pool pointers are valid
+        use_ad_pool = (
+            ad_pool_pointers is not None
+            and ad_pool_mapping is not None
+            and ad_pool_pointers.numel() > 0
+            and ad_pool_pointers[0, 0].item() != 0
+        )

-        self._pool_pointers_initialized = True
+        if use_ad_pool:
+            # Use AD's pool pointers directly
+            # K and V are interleaved in blocks, so only the K ptr is needed
+            # V ptr is set to 0 - the kernel uses block offsets to find V
+            self._shared_host_kv_cache_pool_pointers[0, 0] = ad_pool_pointers[0, 0].item()
+            self._shared_host_kv_cache_pool_pointers[0, 1] = 0
+
+            # Set pool mapping for all layers
+            for layer_i in range(min(num_layers, ad_pool_mapping.shape[0])):
+                self._shared_host_kv_cache_pool_mapping[layer_i, 0] = ad_pool_mapping[
+                    layer_i, 0
+                ].item()
+                self._shared_host_kv_cache_pool_mapping[layer_i, 1] = ad_pool_mapping[
+                    layer_i, 1
+                ].item()
+
+            self._pool_pointers_initialized = True
+        elif fallback_kv_cache_ptr is not None and fallback_kv_cache_ptr != 0:
+            # Fallback: use the kv_cache tensor's data pointer directly
+            # V ptr is set to 0 - the kernel uses block offsets to find V
+            self._shared_host_kv_cache_pool_pointers[0, 0] = fallback_kv_cache_ptr  # K ptr
+            self._shared_host_kv_cache_pool_pointers[0, 1] = 0  # V ptr = 0
+
+            # All layers map to pool 0 with a layer offset
+            for layer_i in range(num_layers):
+                self._shared_host_kv_cache_pool_mapping[layer_i, 0] = 0  # pool index
+                self._shared_host_kv_cache_pool_mapping[layer_i, 1] = layer_i  # layer offset
+
+            self._pool_pointers_initialized = True

     def get_or_create_layer_state(
         self,
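
The AD-pool-versus-fallback decision above can be restated as a small standalone function, which may be easier to reason about than the method embedded in the global state. This is only an illustrative sketch (the function name, the explicit `out_*` output tensors, and the toy usage are hypothetical, not part of the diff); it reproduces the same selection rules: prefer valid AD pool pointers, otherwise fall back to a raw cache pointer with every layer mapped to pool 0.

import torch
from typing import Optional

def resolve_pool_pointers(
    ad_pool_pointers: Optional[torch.Tensor],
    ad_pool_mapping: Optional[torch.Tensor],
    num_layers: int,
    fallback_kv_cache_ptr: Optional[int],
    out_pointers: torch.Tensor,  # [1, 2] int64: (K ptr, V ptr)
    out_mapping: torch.Tensor,   # [num_layers, 2] int32: (pool index, layer offset)
) -> bool:
    """Isolated restatement of the branch above; returns True once pointers are set."""
    use_ad_pool = (
        ad_pool_pointers is not None
        and ad_pool_mapping is not None
        and ad_pool_pointers.numel() > 0
        and ad_pool_pointers[0, 0].item() != 0
    )
    if use_ad_pool:
        out_pointers[0, 0] = ad_pool_pointers[0, 0].item()
        out_pointers[0, 1] = 0  # V is reached via block offsets (base + 1)
        n = min(num_layers, ad_pool_mapping.shape[0])
        out_mapping[:n] = ad_pool_mapping[:n].to(out_mapping.dtype)
        return True
    if fallback_kv_cache_ptr:
        out_pointers[0, 0] = fallback_kv_cache_ptr
        out_pointers[0, 1] = 0
        out_mapping[:num_layers, 0] = 0  # every layer reads pool 0
        out_mapping[:num_layers, 1] = torch.arange(num_layers, dtype=out_mapping.dtype)
        return True
    return False

# Toy usage of the fallback path with a fake pointer value:
ptrs = torch.zeros(1, 2, dtype=torch.int64)
mapping = torch.zeros(32, 2, dtype=torch.int32)
assert resolve_pool_pointers(None, None, 32, 12345, ptrs, mapping)
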
@@ -527,7 +559,8 @@ def _host_prepare_trtllm_metadata(
     )
     page_in_seq = global_state._gpu_page_idx[:total_pages]

-    # base_offsets on GPU
+    # Block offsets: multiplier = num_layers * kv_factor for interleaved K/V
+    # K and V have different offsets: K = base, V = base + 1
     kv_factor = 2
     multiplier = num_layers * kv_factor
     torch.mul(
@@ -540,10 +573,10 @@ def _host_prepare_trtllm_metadata(
     # Fill block offsets using advanced indexing (only zero the slice we need)
     global_state._shared_kv_cache_block_offsets[:, :num_seq, :, :].zero_()
     global_state._shared_kv_cache_block_offsets[0, seq_indices, 0, page_in_seq] = (
-        base_offsets
+        base_offsets  # K
     )
     global_state._shared_kv_cache_block_offsets[0, seq_indices, 1, page_in_seq] = (
-        base_offsets + 1
+        base_offsets + 1  # V
     )

     # Mark that host_prepare has run
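
A small worked example of the offset arithmetic above, with assumed toy sizes (32 layers, kv_factor 2, so multiplier 64):

import torch

num_layers, kv_factor = 32, 2
multiplier = num_layers * kv_factor       # 64
cache_loc = torch.tensor([5, 9, 12])      # logical block indices for one sequence
k_offsets = cache_loc * multiplier        # tensor([320, 576, 768])
v_offsets = k_offsets + 1                 # tensor([321, 577, 769])
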
@@ -665,7 +698,13 @@ def _prepare_trtllm_metadata(
665698 f"Expected kv_cache shape [pages, 2, heads, tokens, dim], got { kv_cache .shape } "
666699 )
667700
668- num_layers = state .num_layers if state .num_layers > 0 else 32
701+ # Get num_layers - prefer from pool_mapping (accurate) over state.num_layers (may be stale)
702+ if ad_pool_mapping is not None and ad_pool_mapping .numel () > 0 :
703+ num_layers = ad_pool_mapping .shape [0 ]
704+ elif state .num_layers > 0 :
705+ num_layers = state .num_layers
706+ else :
707+ num_layers = 32
669708
670709 # Compute input sequence lengths from cumulative sums
671710 input_seq_lens = (cu_seqlen_host [1 : num_seq + 1 ] - cu_seqlen_host [:num_seq ]).int ()
@@ -686,7 +725,7 @@ def _prepare_trtllm_metadata(
     state.sequence_length[:num_seq].copy_(seq_len_with_cache.cuda())
     state.context_lengths[:num_seq].copy_(input_seq_lens.cuda())

-    # Set up KV cache pool pointers
+    # Set up KV cache pool pointers - use AD's pool pointers
     use_ad_pool = (
         ad_pool_pointers is not None
         and ad_pool_mapping is not None
@@ -709,12 +748,18 @@ def _prepare_trtllm_metadata(
             state.host_kv_cache_pool_mapping[layer_i, 0] = ad_pool_mapping[layer_i, 0].item()
             state.host_kv_cache_pool_mapping[layer_i, 1] = ad_pool_mapping[layer_i, 1].item()

+    # Mark pool pointers as initialized in global state to skip redundant init in host_prepare_fn
+    _global_state._pool_pointers_initialized = True
+
     # Block offsets: convert flat cache_loc to per-sequence block indices
     pages_per_seq = (cu_num_pages_host[1 : num_seq + 1] - cu_num_pages_host[:num_seq]).int()
     max_blocks = pages_per_seq.max().item() if num_seq > 0 else 1
     _global_state.set_max_blocks_per_seq(max_blocks)

     # Fill block offsets
+    # AD's cache_loc contains LOGICAL block indices from KVCacheManager
+    # Multiplier = num_layers * kv_factor for the interleaved K/V layout
+    # K and V have different offsets: K = base, V = base + 1
     kv_factor = 2
     multiplier = num_layers * kv_factor
     state.kv_cache_block_offsets.zero_()
@@ -723,10 +768,32 @@ def _prepare_trtllm_metadata(
         n_pages = pages_per_seq[i].item()
         if n_pages > 0:
             base_offsets = cache_loc[offset : offset + n_pages] * multiplier
-            state.kv_cache_block_offsets[0, i, 0, :n_pages] = base_offsets
-            state.kv_cache_block_offsets[0, i, 1, :n_pages] = base_offsets + 1
+            state.kv_cache_block_offsets[0, i, 0, :n_pages] = base_offsets  # K
+            state.kv_cache_block_offsets[0, i, 1, :n_pages] = base_offsets + 1  # V
         offset += n_pages

+    # Debug: print info for layers 0, 1, 31 (first few calls only)
+    debug_count = getattr(_global_state, "_debug_count", 0)
+    if debug_count < 10 and state.layer_idx in [0, 1, 31]:
+        _global_state._debug_count = debug_count + 1
+        print(
+            f"\n[DEBUG #{debug_count}] === Layer {state.layer_idx}, num_seq={num_seq}, num_layers={num_layers} ==="
+        )
+        k_ptr = state.host_kv_cache_pool_pointers[0, 0].item()
+        v_ptr = state.host_kv_cache_pool_pointers[0, 1].item()
+        print(f"[DEBUG] pool_pointers: K={k_ptr}, V={v_ptr}")
+        mapping = state.host_kv_cache_pool_mapping[state.layer_idx].tolist()
+        print(f"[DEBUG] pool_mapping layer {state.layer_idx}: {mapping}")
+        if state.layer_idx == 0:
+            total_pages = cache_loc.shape[0] if cache_loc.numel() > 0 else 0
+            print(f"[DEBUG] cache_loc[:5]: {cache_loc[: min(5, total_pages)].tolist()}")
+            print(
+                f"[DEBUG] K_offsets[:5]: {(cache_loc[: min(5, total_pages)] * multiplier).tolist()}"
+            )
+            print(
+                f"[DEBUG] V_offsets[:5]: {(cache_loc[: min(5, total_pages)] * multiplier + 1).tolist()}"
+            )
+
     # Return tensors
     max_blocks_per_seq = state.kv_cache_block_offsets.shape[3]

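
The debug block above gates its output by stashing a counter on the global state with getattr. The same pattern in isolation, with hypothetical names (`debug_gate`, `_Carrier`) that are not part of the diff:

def debug_gate(carrier, limit: int = 10, attr: str = "_debug_count") -> bool:
    """Return True only for the first `limit` calls, tracking the count on `carrier`."""
    count = getattr(carrier, attr, 0)
    if count >= limit:
        return False
    setattr(carrier, attr, count + 1)
    return True

class _Carrier:  # stand-in for the global state object
    pass

state = _Carrier()
assert sum(debug_gate(state) for _ in range(25)) == 10
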
@@ -793,14 +860,32 @@ def trtllm_mha_with_cache(
     if not kv_cache.is_cuda:
         raise RuntimeError(f"kv_cache must be on CUDA, got {kv_cache.device}")

-    # Validate unified KV cache format
-    # Expected shape: [num_blocks, 2, num_kv_heads, tokens_per_block, head_dim] (HND layout)
-    # This shape is created by TrtllmKVResourceHandler.allocate() which permutes the base allocation
+    # Validate KV cache format
+    # TrtllmKVResourceHandler configures AD to allocate the cache in HND format:
+    #   [num_blocks, 2, num_kv_heads, tokens_per_block, head_dim]
+    # This matches what thop.attention expects.
     assert kv_cache.dim() == 5, f"kv_cache must be 5D, got {kv_cache.dim()}D"
     assert kv_cache.shape[1] == 2, (
-        f"kv_cache.shape[1] must be 2 (kv_factor), got {kv_cache.shape[1]}"
+        f"kv_cache must be in HND format [B, 2, H, T, D] with shape[1]=2, "
+        f"got shape {kv_cache.shape}. Ensure TrtllmKVResourceHandler is used."
     )

+    # Lazy initialization of model config (done once on first attention call)
+    if _trtllm_config._num_layers == 0:
+        # Track layer config
+        TrtllmAttention._track_layer_config(num_kv_heads, head_dim, kv_cache.dtype)
+
+        # Once all layers have been seen, set model config
+        # We infer num_layers from layer_idx (assumes layers are processed in order)
+        expected_num_layers = len(TrtllmAttention._num_kv_heads_per_layer)
+        if layer_idx == expected_num_layers - 1 or expected_num_layers >= 32:
+            _trtllm_config.set_model_config(
+                num_layers=expected_num_layers,
+                num_kv_heads_per_layer=TrtllmAttention._num_kv_heads_per_layer,
+                head_dim=head_dim,
+                dtype=kv_cache.dtype,
+            )
+
     # Get batch dimensions
     num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
     num_seq = num_prefill + num_decode
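
The lazy "collect per-layer configs, finalize when the last layer is seen" pattern above can be sketched independently of TrtllmAttention. The class and names below are hypothetical illustrations, not the diff's API; like the diff, the sketch assumes layers are visited in order so seeing the highest layer index means the model is fully described.

from typing import List, Optional

class LayerConfigTracker:
    def __init__(self, num_layers: int) -> None:
        self.num_layers = num_layers
        self.kv_heads_per_layer: List[Optional[int]] = [None] * num_layers
        self.finalized = False

    def track(self, layer_idx: int, num_kv_heads: int) -> None:
        self.kv_heads_per_layer[layer_idx] = num_kv_heads
        # Finalize once the highest layer index has been seen (assumes in-order layers).
        if layer_idx == self.num_layers - 1 and not self.finalized:
            assert all(h is not None for h in self.kv_heads_per_layer)
            self.finalized = True

tracker = LayerConfigTracker(num_layers=4)
for i in range(4):
    tracker.track(i, num_kv_heads=8)
assert tracker.finalized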