
Commit 13df7c3

Fix the issues
Signed-off-by: ziyixiong-nv <219238287+ziyixiong-nv@users.noreply.github.com>
1 parent d1f23ca commit 13df7c3

6 files changed: +205 -111 lines changed

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 46 additions & 0 deletions
@@ -10,6 +10,7 @@
 from ..speculative.utils import SpecDecodingTensor
 from ..speculative.interface import SpecMetadata
 from ..speculative.spec_tree_manager import SpecTreeManager
+from ..pyexecutor.resource_manager import KVCacheManager

 from tensorrt_llm._utils import get_sm_version
 from tensorrt_llm.bindings.internal import thop
@@ -647,6 +648,9 @@ class TrtllmAttentionMetadata(AttentionMetadata):
                                             init=True,
                                             repr=False)

+    # Draft KV cache manager for one-model speculative decoding with separate KV cache layouts
+    draft_kv_cache_manager: Optional["KVCacheManager"] = None
+
     # Flags to enable spec-dec mode (multi-query mode) in TRTLLM XQA Kernels
     # spec decoding mode can be enabled for non-TRTLLM-gen kernels (pre-Blackwell XQA kernels)
     # is_spec_decoding_enabled specifies if spec-dec mode is supported for the entire runtime.
@@ -796,6 +800,29 @@ def _post_init_with_buffers(self, buffers) -> None:
        )
        self.block_ids_per_seq = None
        self.kv_block_ids_per_seq = None
+
+        # Allocate separate block offset tensors for draft KV cache manager
+        # Used in one-model speculative decoding with different KV cache layouts
+        if self.draft_kv_cache_manager is not None:
+            self.draft_kv_cache_block_offsets = self.get_empty(
+                buffers,
+                [
+                    self.draft_kv_cache_manager.num_pools,
+                    self.max_num_sequences, 2,
+                    self.draft_kv_cache_manager.max_blocks_per_seq
+                ],
+                cache_name="draft_kv_cache_block_offsets",
+                dtype=torch.int32,
+                capture_graph=capture_graph,
+            )
+            self.draft_host_kv_cache_block_offsets = torch.empty_like(
+                self.draft_kv_cache_block_offsets,
+                device='cpu',
+                pin_memory=True,
+            )
+        else:
+            self.draft_kv_cache_block_offsets = None
+            self.draft_host_kv_cache_block_offsets = None
        if self.enable_flash_mla:
            self.block_ids_per_seq = self.get_empty(
                buffers,
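
The device/pinned-host buffer pair added above follows a common CUDA staging idiom: the device tensor is what the attention kernels read, while a same-shaped pinned CPU tensor is filled on the host and copied over asynchronously. A minimal, self-contained sketch of that idiom in plain PyTorch (shapes and names are illustrative, not the TRT-LLM API):

```python
import torch

# Illustrative shapes only; the real buffers are sized from the draft
# KV cache manager (num_pools, max_num_sequences, 2, max_blocks_per_seq).
num_pools, max_num_sequences, max_blocks_per_seq = 1, 8, 16
shape = (num_pools, max_num_sequences, 2, max_blocks_per_seq)

cuda = torch.cuda.is_available()
device = "cuda" if cuda else "cpu"

# Device-side block offsets consumed by the kernels.
draft_block_offsets = torch.zeros(shape, dtype=torch.int32, device=device)

# Pinned host mirror: pinning is what makes a later
# copy_(..., non_blocking=True) an asynchronous host-to-device transfer.
host_block_offsets = torch.empty(shape, dtype=torch.int32, pin_memory=cuda)
```
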
@@ -1007,6 +1034,25 @@ def prepare(self) -> None:
        assert self.kv_lens[:self.num_seqs].max(
        ) <= self.kv_cache_manager.max_seq_len, error_message

+        # Also prepare draft KV cache block offsets if draft_kv_cache_manager exists
+        if self.draft_kv_cache_manager is not None:
+            # Copy blocks for all context requests
+            self.draft_kv_cache_manager.impl.copy_batch_block_offsets(
+                self.draft_host_kv_cache_block_offsets,
+                self.request_ids[:self.num_contexts], 1, 0)
+            # Copy blocks for all generation requests
+            self.draft_kv_cache_manager.impl.copy_batch_block_offsets(
+                self.draft_host_kv_cache_block_offsets,
+                self.request_ids[self.num_contexts:], self.beam_width,
+                self.num_contexts)
+            for pool_idx in range(
+                    self.draft_host_kv_cache_block_offsets.shape[0]):
+                self.draft_kv_cache_block_offsets[
+                    pool_idx, :self.num_seqs].copy_(
+                        self.draft_host_kv_cache_block_offsets[
+                            pool_idx, :self.num_seqs],
+                        non_blocking=True)
+
        self.kv_lens_cuda_runtime = self.kv_lens_cuda[:self.num_seqs]
        # Don't use self.kv_lens here because it includes extra tokens.
        # Use actual KV length (without extra tokens) for kv_lens_runtime,
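
The prepare() change mirrors what is already done for the target cache: the manager fills the pinned host tensor, and only the rows for active sequences are copied to the device, pool by pool, with non_blocking=True. A standalone sketch of that copy pattern (plain PyTorch, illustrative names; copy_batch_block_offsets itself is the TRT-LLM binding and is not reproduced here):

```python
import torch

# Illustrative stand-in for the per-pool host-to-device copy done in prepare().
num_pools, max_num_sequences, max_blocks_per_seq = 2, 8, 16
num_seqs = 3  # active sequences in this batch
shape = (num_pools, max_num_sequences, 2, max_blocks_per_seq)

cuda = torch.cuda.is_available()
dev_offsets = torch.zeros(shape, dtype=torch.int32,
                          device="cuda" if cuda else "cpu")
host_offsets = torch.empty(shape, dtype=torch.int32, pin_memory=cuda)

# In the real code the KV cache manager fills host_offsets; here we just
# stamp a value for the active sequences.
host_offsets[:, :num_seqs] = 1

# Copy only the active rows, one pool at a time; with a pinned source the
# non_blocking copy can overlap with other host-side preparation work.
for pool_idx in range(host_offsets.shape[0]):
    dev_offsets[pool_idx, :num_seqs].copy_(host_offsets[pool_idx, :num_seqs],
                                           non_blocking=True)
```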

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 82 additions & 88 deletions
@@ -402,9 +402,12 @@ def configure_kv_cache_capacity(self, py_executor: PyExecutor) -> None:
        # get kv cache stats for both model and draft model
        kv_stats = py_executor.resource_manager.resource_managers.get(
            ResourceManagerType.KV_CACHE_MANAGER).get_kv_cache_stats()
-        kv_stats_draft = py_executor.resource_manager.resource_managers.get(
-            ResourceManagerType.DRAFT_KV_CACHE_MANAGER).get_kv_cache_stats(
-            ) if self._draft_model_engine is not None else None
+        # Get draft KV cache stats if present (either from two-model mode or one-model
+        # mode with separate draft KV cache)
+        draft_kv_cache_manager = py_executor.resource_manager.resource_managers.get(
+            ResourceManagerType.DRAFT_KV_CACHE_MANAGER)
+        kv_stats_draft = draft_kv_cache_manager.get_kv_cache_stats(
+        ) if draft_kv_cache_manager is not None else None

        # get total allocated bytes
        allocated_bytes = kv_stats.allocated_bytes + (
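
The draft-stats lookup now keys off whether a DRAFT_KV_CACHE_MANAGER resource is registered rather than whether a separate draft engine exists, which is what makes the one-model case work. A reduced sketch of that lookup pattern with hypothetical stand-in types (not the TRT-LLM classes):

```python
from typing import Optional

class FakeDraftKvCacheManager:
    """Hypothetical stand-in for the draft KV cache manager resource."""

    def get_kv_cache_stats(self) -> dict:
        return {"allocated_bytes": 1 << 28}

def get_draft_kv_stats(resource_managers: dict,
                       key: str = "DRAFT_KV_CACHE_MANAGER") -> Optional[dict]:
    # The lookup depends only on whether the draft manager resource is
    # registered, not on whether a separate draft model engine exists.
    draft_kv_cache_manager = resource_managers.get(key)
    return (draft_kv_cache_manager.get_kv_cache_stats()
            if draft_kv_cache_manager is not None else None)

print(get_draft_kv_stats({}))  # None: no draft KV cache manager registered
print(get_draft_kv_stats({"DRAFT_KV_CACHE_MANAGER": FakeDraftKvCacheManager()}))
```
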
@@ -525,83 +528,42 @@ def _create_one_model_draft_kv_cache_manager(
        Create a KV cache manager for draft model layers in one-model mode
        when target and draft have different KV cache layouts.
        """
-        draft_pretrained_config = self._draft_config.pretrained_config
-        quant_config = self._draft_config.quant_config
-
-        # Determine KV cache dtype
-        if quant_config is not None and quant_config.quant_mode.has_fp8_kv_cache(
-        ):
-            kv_cache_dtype = tensorrt_llm.bindings.DataType.FP8
-        elif quant_config is not None and quant_config.quant_mode.has_fp4_kv_cache(
-        ):
-            kv_cache_dtype = tensorrt_llm.bindings.DataType.NVFP4
-        else:
-            kv_cache_dtype = str_dtype_to_binding(
-                torch_dtype_to_str(self._model_engine.dtype))
-
-        num_draft_layers = draft_pretrained_config.num_hidden_layers
-
        # Get target model's num_hidden_layers to compute correct layer indices.
        # Draft model layers in one-model mode start at target_num_layers.
        target_pretrained_config = self._model_engine.model.model_config.pretrained_config
        target_num_layers = target_pretrained_config.num_hidden_layers
+        num_draft_layers = self._draft_config.pretrained_config.num_hidden_layers

        # Create layer_mask: False for target layers, True for draft layers.
        # This ensures the draft KV cache manager uses the correct layer indices
        # (e.g., layers 32, 33, ... instead of 0, 1, ...).
        layer_mask = [False] * target_num_layers + [True] * num_draft_layers

-        if is_mla(draft_pretrained_config):
-            # Draft uses MLA
-            return self._kv_cache_manager_cls(
-                self._kv_cache_config,
-                tensorrt_llm.bindings.internal.batch_manager.CacheType.
-                SELFKONLY,
-                num_layers=num_draft_layers,
-                num_kv_heads=1,
-                head_dim=draft_pretrained_config.kv_lora_rank +
-                draft_pretrained_config.qk_rope_head_dim,
-                tokens_per_block=self._tokens_per_block,
-                max_seq_len=self._max_seq_len,
-                max_batch_size=self._max_batch_size,
-                mapping=self._mapping,
-                dtype=kv_cache_dtype,
-                spec_config=self._speculative_config,
-                max_num_tokens=self._max_num_tokens,
-                is_draft=True,
-                is_estimating_kv_cache=estimating_kv_cache,
-                execution_stream=self._execution_stream,
-                layer_mask=layer_mask,
-            )
-        else:
-            # Draft uses standard attention
-            hidden_size = draft_pretrained_config.hidden_size
-            num_attention_heads = draft_pretrained_config.num_attention_heads
-            num_key_value_heads = getattr(draft_pretrained_config,
-                                          'num_key_value_heads',
-                                          num_attention_heads)
-            head_dim = getattr(draft_pretrained_config, "head_dim", None)
-            if not isinstance(head_dim, int):
-                head_dim = hidden_size // num_attention_heads
-
-            return self._kv_cache_manager_cls(
-                self._kv_cache_config,
-                tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
-                num_layers=num_draft_layers,
-                num_kv_heads=num_key_value_heads,
-                head_dim=head_dim,
-                tokens_per_block=self._tokens_per_block,
-                max_seq_len=self._max_seq_len,
-                max_batch_size=self._max_batch_size,
-                mapping=self._mapping,
-                dtype=kv_cache_dtype,
-                spec_config=self._speculative_config,
-                max_num_tokens=self._max_num_tokens,
-                is_draft=True,
-                is_estimating_kv_cache=estimating_kv_cache,
-                execution_stream=self._execution_stream,
-                layer_mask=layer_mask,
-            )
+        # Get the appropriate KV cache manager class for the draft model
+        draft_kv_cache_manager_cls = get_kv_cache_manager_cls(
+            self._draft_config)
+
+        return _create_kv_cache_manager(
+            model_engine=None,
+            kv_cache_manager_cls=draft_kv_cache_manager_cls,
+            mapping=self._mapping,
+            kv_cache_config=self._kv_cache_config,
+            tokens_per_block=self._tokens_per_block,
+            max_seq_len=self._max_seq_len,
+            max_batch_size=self._max_batch_size,
+            spec_config=self._speculative_config,
+            sparse_attn_config=None,  # Not applicable for draft in one-model mode
+            max_num_tokens=self._max_num_tokens,
+            max_beam_width=self._max_beam_width,
+            kv_connector_manager=None,  # Not supported for draft models
+            estimating_kv_cache=estimating_kv_cache,
+            execution_stream=self._execution_stream,
+            # One-model draft specific overrides
+            model_config=self._draft_config,
+            dtype=self._model_engine.dtype,
+            is_draft=True,
+            layer_mask=layer_mask,
+        )

    def build_managers(self,
                       resources: Dict,
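
The net effect of this refactor is that draft layers of the one-model graph reuse the generic _create_kv_cache_manager path, and the only draft-specific piece is the layer_mask, which marks the trailing layers of the merged model as the ones owned by the draft KV cache. A tiny illustration of the indexing idea (plain Python, made-up layer counts):

```python
# Illustration of the layer_mask built above (made-up layer counts).
target_num_layers = 32
num_draft_layers = 2

layer_mask = [False] * target_num_layers + [True] * num_draft_layers

# The draft KV cache manager therefore sees global layer indices 32 and 33
# rather than 0 and 1, matching how the merged one-model graph numbers them.
draft_layer_indices = [i for i, is_draft_layer in enumerate(layer_mask)
                       if is_draft_layer]
assert draft_layer_indices == [32, 33]
```
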
@@ -641,7 +603,7 @@ def teardown_managers(self, resources: Dict) -> None:


 def _create_kv_cache_manager(
-        model_engine: PyTorchModelEngine,
+        model_engine: Optional[PyTorchModelEngine],
        kv_cache_manager_cls,
        mapping: Mapping,
        kv_cache_config: KvCacheConfig,
@@ -654,13 +616,42 @@ def _create_kv_cache_manager(
        max_beam_width: int,
        kv_connector_manager: Optional[KvCacheConnectorManager],
        estimating_kv_cache: bool,
-        execution_stream: Optional[torch.cuda.Stream] = None) -> KVCacheManager:
+        execution_stream: Optional[torch.cuda.Stream] = None,
+        # Optional overrides for one-model draft case (when model_engine is None)
+        model_config: Optional[ModelConfig] = None,
+        dtype: Optional[torch.dtype] = None,
+        is_draft: Optional[bool] = None,
+        layer_mask: Optional[List[bool]] = None) -> KVCacheManager:
    """
+    Create a KVCacheManager instance.
+
+    Args:
+        model_engine: The model engine (can be None if model_config is provided)
+        model_config: Optional ModelConfig to use instead of extracting from model_engine
+        dtype: Optional dtype override (required if model_engine is None)
+        is_draft: Optional is_draft flag override (required if model_engine is None)
+        layer_mask: Optional layer mask for one-model draft KV cache
+
    Returns:
-        A KVCacheManager instance for the given model_engine
+        A KVCacheManager instance
    """
-    config = model_engine.model.model_config.pretrained_config
-    quant_config = model_engine.model.model_config.quant_config
+    # Extract config from model_engine or use provided model_config
+    if model_config is not None:
+        config = model_config.pretrained_config
+        quant_config = model_config.quant_config
+        _model_config = model_config
+    else:
+        config = model_engine.model.model_config.pretrained_config
+        quant_config = model_engine.model.model_config.quant_config
+        _model_config = model_engine.model.model_config
+
+    # Determine dtype
+    if dtype is None:
+        dtype = model_engine.dtype
+
+    # Determine is_draft
+    if is_draft is None:
+        is_draft = model_engine.is_draft_model

    hidden_size = config.hidden_size
    num_attention_heads = config.num_attention_heads
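
The new optional arguments let the same helper serve both the normal path (everything derived from model_engine) and the one-model draft path (model_engine=None, everything passed explicitly). A reduced sketch of the fallback resolution with hypothetical stand-in types (not the real ModelConfig or PyTorchModelEngine):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeModelConfig:
    num_hidden_layers: int

@dataclass
class FakeEngine:
    model_config: FakeModelConfig
    dtype: str = "bfloat16"
    is_draft_model: bool = False

def resolve_kv_cache_inputs(model_engine: Optional[FakeEngine],
                            model_config: Optional[FakeModelConfig] = None,
                            dtype: Optional[str] = None,
                            is_draft: Optional[bool] = None):
    # Each override is optional; when omitted the value comes from the engine,
    # so pre-existing call sites are unaffected.
    config = model_config if model_config is not None else model_engine.model_config
    dtype = dtype if dtype is not None else model_engine.dtype
    is_draft = is_draft if is_draft is not None else model_engine.is_draft_model
    return config, dtype, is_draft

# Normal path: derive everything from the engine.
engine = FakeEngine(FakeModelConfig(num_hidden_layers=32))
print(resolve_kv_cache_inputs(engine))

# One-model draft path: no engine, all values supplied explicitly.
print(resolve_kv_cache_inputs(None, FakeModelConfig(2), "bfloat16", True))
```
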
@@ -676,8 +667,7 @@
    ):
        kv_cache_dtype = tensorrt_llm.bindings.DataType.NVFP4
    else:
-        kv_cache_dtype = str_dtype_to_binding(
-            torch_dtype_to_str(model_engine.dtype))
+        kv_cache_dtype = str_dtype_to_binding(torch_dtype_to_str(dtype))

    num_hidden_layers = config.num_hidden_layers

@@ -695,12 +685,13 @@
            dtype=kv_cache_dtype,
            spec_config=spec_config,
            max_beam_width=max_beam_width,
-            is_draft=model_engine.is_draft_model,
+            is_draft=is_draft,
            kv_connector_manager=kv_connector_manager
            if not estimating_kv_cache else None,
            sparse_attn_config=sparse_attn_config,
            is_estimating_kv_cache=estimating_kv_cache,
            execution_stream=execution_stream,
+            layer_mask=layer_mask,
        )
    elif is_nemotron_hybrid(config):
@@ -712,9 +703,10 @@
                "Connector manager is not supported for MambaHybridCacheManager."
            )

-        config = model_engine.model.model_config.pretrained_config
        num_layers = config.hybrid_override_pattern.count("*")
-        layer_mask = [char == "*" for char in config.hybrid_override_pattern]
+        hybrid_layer_mask = [
+            char == "*" for char in config.hybrid_override_pattern
+        ]
        mamba_num_layers = config.hybrid_override_pattern.count("M")
        mamba_layer_mask = [
            char == "M" for char in config.hybrid_override_pattern
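
Renaming the local variable to hybrid_layer_mask keeps it from shadowing the new layer_mask parameter of _create_kv_cache_manager, so the pattern-derived mask and the caller-supplied draft mask stay distinct. For reference, a small example of how such a pattern string maps to the masks (the pattern value here is made up):

```python
# Made-up hybrid_override_pattern: "*" marks attention layers (KV cache),
# "M" marks Mamba layers (SSM state cache).
hybrid_override_pattern = "M*MM*M"

hybrid_layer_mask = [char == "*" for char in hybrid_override_pattern]
mamba_layer_mask = [char == "M" for char in hybrid_override_pattern]

num_layers = hybrid_override_pattern.count("*")        # attention layers
mamba_num_layers = hybrid_override_pattern.count("M")  # Mamba layers

assert hybrid_layer_mask == [False, True, False, False, True, False]
assert num_layers == 2 and mamba_num_layers == 4
```
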
@@ -729,12 +721,13 @@
            mamba_num_layers,
            mamba_layer_mask,
            config.torch_dtype,
-            model_engine.model.model_config.quant_config.mamba_ssm_cache_dtype,
+            quant_config.mamba_ssm_cache_dtype
+            if quant_config is not None else None,
            # kv cache parameters
            kv_cache_config,
            tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
            num_layers=num_layers,
-            layer_mask=layer_mask,
+            layer_mask=hybrid_layer_mask,
            num_kv_heads=num_key_value_heads,
            head_dim=head_dim,
            tokens_per_block=tokens_per_block,
@@ -755,13 +748,12 @@
            raise NotImplementedError(
                "Connector manager is not supported for MambaHybridCacheManager."
            )
-        config = model_engine.model.model_config.pretrained_config
        mamba_layer_mask = [
            True if i %
            config.full_attention_interval != config.full_attention_interval -
            1 else False for i in range(num_hidden_layers)
        ]
-        layer_mask = [
+        hybrid_layer_mask = [
            False if i %
            config.full_attention_interval != config.full_attention_interval -
            1 else True for i in range(num_hidden_layers)
@@ -779,12 +771,13 @@
            num_mamba_layers,
            mamba_layer_mask,
            config.torch_dtype,
-            model_engine.model.model_config.quant_config.mamba_ssm_cache_dtype,
+            quant_config.mamba_ssm_cache_dtype
+            if quant_config is not None else None,
            # kv cache parameters
            kv_cache_config,
            tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
            num_layers=num_layers,
-            layer_mask=layer_mask,
+            layer_mask=hybrid_layer_mask,
            num_kv_heads=num_key_value_heads,
            head_dim=head_dim,
            tokens_per_block=tokens_per_block,
@@ -800,7 +793,7 @@
        # NOTE: this is a workaround for VSWA to switch to calculate_max_num_blocks_from_cpp in KVCahceManager
        is_vswa = kv_cache_config.max_attention_window is not None and len(
            set(kv_cache_config.max_attention_window)) > 1
-        binding_model_config = model_engine.model.model_config.get_bindings_model_config(
+        binding_model_config = _model_config.get_bindings_model_config(
            tokens_per_block=tokens_per_block) if is_vswa else None

        kv_cache_manager = kv_cache_manager_cls(
@@ -818,12 +811,13 @@
            max_num_tokens=max_num_tokens,
            model_config=binding_model_config,
            max_beam_width=max_beam_width,
-            is_draft=model_engine.is_draft_model,
+            is_draft=is_draft,
            kv_connector_manager=kv_connector_manager
            if not estimating_kv_cache else None,
            sparse_attn_config=sparse_attn_config,
            is_estimating_kv_cache=estimating_kv_cache,
            execution_stream=execution_stream,
+            layer_mask=layer_mask,
        )
    return kv_cache_manager
