Commit 342db06

[TRTLLM-10279][feat] Support different KV cache layout for one-model spec dec
Signed-off-by: ziyixiong-nv <219238287+ziyixiong-nv@users.noreply.github.com>
1 parent 3bd319d commit 342db06

16 files changed: +504 -101 lines changed

tensorrt_llm/_torch/attention_backend/flashinfer.py

Lines changed: 9 additions & 0 deletions
@@ -10,6 +10,7 @@
 from typing_extensions import Self

 from tensorrt_llm.functional import AttentionMaskType
+from tensorrt_llm.logger import logger
 from tensorrt_llm.models.modeling_utils import QuantConfig

 from ..utils import get_global_attrs, get_model_extra_attrs
@@ -61,6 +62,9 @@ class FlashInferAttentionMetadata(AttentionMetadata):
     # so set kv_layout as "HND" here
     kv_layout: Literal["NHD", "HND"] = "HND"

+    # Draft KV cache manager for one-model speculative decoding.
+    draft_kv_cache_manager: Optional[object] = None
+
     paged_kv_indptr_decode: torch.Tensor = field(init=False)
     paged_kv_indptr_prefill: torch.Tensor = field(init=False)
     _paged_kv_indices: torch.Tensor = field(init=False, repr=False)
@@ -127,6 +131,11 @@ def positions(self) -> torch.Tensor:

     def __post_init__(self) -> None:
         super().__post_init__()
+        if self.draft_kv_cache_manager is not None:
+            logger.warning(
+                "draft_kv_cache_manager is not supported in FlashInfer backend. "
+                "One-model speculative decoding with separate KV cache layouts "
+                "may not work correctly.")
         self._post_init_with_buffers(self.cuda_graph_buffers)

     def _post_init_with_buffers(self, buffers) -> None:

tensorrt_llm/_torch/attention_backend/interface.py

Lines changed: 3 additions & 0 deletions
@@ -56,6 +56,9 @@ class AttentionMetadata:
     max_num_sequences: Optional[int] = None
     # The KV cache manager.
     kv_cache_manager: KVCacheManager
+    # Draft KV cache manager for one-model speculative decoding.
+    # Used when draft and target models have different KV cache layouts.
+    draft_kv_cache_manager: Optional[KVCacheManager] = None
     mapping: Optional[Mapping] = None

     enable_flash_mla: bool = False
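
The new field is a plain optional attribute on the metadata dataclass: callers driving one-model speculative decoding can hand both cache managers to the attention backend, while everyone else leaves the draft one as None. A minimal standalone sketch of that shape, using stub types rather than the TensorRT-LLM classes:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class CacheManagerStub:
        # Hypothetical stand-in for a KV cache manager.
        num_pools: int
        max_blocks_per_seq: int

    @dataclass
    class AttentionMetadataSketch:
        kv_cache_manager: CacheManagerStub
        # Only set for one-model spec dec when the draft model's KV cache
        # layout differs from the target model's.
        draft_kv_cache_manager: Optional[CacheManagerStub] = None

    target_mgr = CacheManagerStub(num_pools=1, max_blocks_per_seq=64)
    draft_mgr = CacheManagerStub(num_pools=1, max_blocks_per_seq=32)
    metadata = AttentionMetadataSketch(kv_cache_manager=target_mgr,
                                       draft_kv_cache_manager=draft_mgr)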

tensorrt_llm/_torch/attention_backend/sparse/rocket.py

Lines changed: 2 additions & 0 deletions
@@ -974,6 +974,7 @@ def add_dummy_requests(
         use_mrope: bool = False,
         max_beam_width: int = 1,
         num_extra_decoding_steps: int = 0,
+        draft_kv_cache_manager=None,
     ):
         requests = super().add_dummy_requests(
             request_ids=request_ids,
@@ -984,6 +985,7 @@ def add_dummy_requests(
             use_mrope=use_mrope,
             max_beam_width=max_beam_width,
             num_extra_decoding_steps=num_extra_decoding_steps,
+            draft_kv_cache_manager=draft_kv_cache_manager,
         )
         if prepare_resource:
             for req in requests:

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 46 additions & 0 deletions
@@ -10,6 +10,7 @@
 from ..speculative.utils import SpecDecodingTensor
 from ..speculative.interface import SpecMetadata
 from ..speculative.spec_tree_manager import SpecTreeManager
+from ..pyexecutor.resource_manager import KVCacheManager

 from tensorrt_llm._utils import get_sm_version
 from tensorrt_llm.bindings.internal import thop
@@ -647,6 +648,9 @@ class TrtllmAttentionMetadata(AttentionMetadata):
         init=True,
         repr=False)

+    # Draft KV cache manager for one-model speculative decoding with separate KV cache layouts
+    draft_kv_cache_manager: Optional["KVCacheManager"] = None
+
     # Flags to enable spec-dec mode (multi-query mode) in TRTLLM XQA Kernels
     # spec decoding mode can be enabled for non-TRTLLM-gen kernels (pre-Blackwell XQA kernels)
     # is_spec_decoding_enabled specifies if spec-dec mode is supported for the entire runtime.
@@ -796,6 +800,29 @@ def _post_init_with_buffers(self, buffers) -> None:
         )
         self.block_ids_per_seq = None
         self.kv_block_ids_per_seq = None
+
+        # Allocate separate block offset tensors for draft KV cache manager
+        # Used in one-model speculative decoding with different KV cache layouts
+        if self.draft_kv_cache_manager is not None:
+            self.draft_kv_cache_block_offsets = self.get_empty(
+                buffers,
+                [
+                    self.draft_kv_cache_manager.num_pools,
+                    self.max_num_sequences, 2,
+                    self.draft_kv_cache_manager.max_blocks_per_seq
+                ],
+                cache_name="draft_kv_cache_block_offsets",
+                dtype=torch.int32,
+                capture_graph=capture_graph,
+            )
+            self.draft_host_kv_cache_block_offsets = torch.empty_like(
+                self.draft_kv_cache_block_offsets,
+                device='cpu',
+                pin_memory=True,
+            )
+        else:
+            self.draft_kv_cache_block_offsets = None
+            self.draft_host_kv_cache_block_offsets = None
         if self.enable_flash_mla:
             self.block_ids_per_seq = self.get_empty(
                 buffers,
@@ -1007,6 +1034,25 @@ def prepare(self) -> None:
         assert self.kv_lens[:self.num_seqs].max(
         ) <= self.kv_cache_manager.max_seq_len, error_message

+        # Also prepare draft KV cache block offsets if draft_kv_cache_manager exists
+        if self.draft_kv_cache_manager is not None:
+            # Copy blocks for all context requests
+            self.draft_kv_cache_manager.impl.copy_batch_block_offsets(
+                self.draft_host_kv_cache_block_offsets,
+                self.request_ids[:self.num_contexts], 1, 0)
+            # Copy blocks for all generation requests
+            self.draft_kv_cache_manager.impl.copy_batch_block_offsets(
+                self.draft_host_kv_cache_block_offsets,
+                self.request_ids[self.num_contexts:], self.beam_width,
+                self.num_contexts)
+            for pool_idx in range(
+                    self.draft_host_kv_cache_block_offsets.shape[0]):
+                self.draft_kv_cache_block_offsets[
+                    pool_idx, :self.num_seqs].copy_(
+                        self.draft_host_kv_cache_block_offsets[
+                            pool_idx, :self.num_seqs],
+                        non_blocking=True)
+
         self.kv_lens_cuda_runtime = self.kv_lens_cuda[:self.num_seqs]
         # Don't use self.kv_lens here because it includes extra tokens.
         # Use actual KV length (without extra tokens) for kv_lens_runtime,
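
The prepare() hunk stages the draft block offsets in a pinned host tensor (filled by copy_batch_block_offsets) and then uploads each pool's slice of active sequences to the device asynchronously. A self-contained PyTorch sketch of that staging pattern, with made-up sizes and a dummy fill standing in for copy_batch_block_offsets:

    import torch

    num_pools, max_num_sequences, max_blocks_per_seq = 1, 8, 16
    num_seqs = 4
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Device tensor shaped like draft_kv_cache_block_offsets above:
    # [num_pools, max_num_sequences, 2, max_blocks_per_seq].
    draft_offsets = torch.zeros(
        (num_pools, max_num_sequences, 2, max_blocks_per_seq),
        dtype=torch.int32, device=device)

    # Host staging buffer; pinning it only matters when CUDA is present.
    host_offsets = torch.zeros(
        (num_pools, max_num_sequences, 2, max_blocks_per_seq),
        dtype=torch.int32)
    if torch.cuda.is_available():
        host_offsets = host_offsets.pin_memory()

    # Stand-in for copy_batch_block_offsets: write some block ids on the host.
    host_offsets[:, :num_seqs] = torch.arange(
        2 * max_blocks_per_seq, dtype=torch.int32).reshape(2, max_blocks_per_seq)

    # Per-pool async upload of just the active sequences.
    for pool_idx in range(host_offsets.shape[0]):
        draft_offsets[pool_idx, :num_seqs].copy_(
            host_offsets[pool_idx, :num_seqs], non_blocking=True)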

tensorrt_llm/_torch/models/modeling_speculative.py

Lines changed: 14 additions & 5 deletions
@@ -17,7 +17,8 @@
                              WeightsLoadingConfig)
 from ..modules.rms_norm import RMSNorm
 from ..pyexecutor.guided_decoder import CapturableGuidedDecoder
-from ..speculative import SpecMetadata, get_spec_worker
+from ..speculative import (SpecMetadata, get_spec_worker,
+                           should_use_separate_draft_kv_cache)
 from ..utils import AuxStreamType
 from .checkpoints.base_weight_mapper import BaseWeightMapper
 from .modeling_utils import (DecoderModel, DecoderModelForCausalLM, TModel,
@@ -880,6 +881,7 @@ def __init__(self, model: TModel, model_config: ModelConfig[TConfig]):
             vocab_size=model_config.pretrained_config.vocab_size)
         self.draft_model = None
         self.draft_config = None
+        self.use_separate_draft_kv_cache = False
         spec_config = getattr(model_config, 'spec_config', None)
         if spec_config and spec_config.spec_dec_mode.use_one_engine():
             if spec_config.spec_dec_mode.is_eagle3_one_model():
@@ -913,11 +915,16 @@ def __init__(self, model: TModel, model_config: ModelConfig[TConfig]):
                 self.draft_config.quant_config.kv_cache_quant_algo = \
                     model_config.quant_config.kv_cache_quant_algo

+            self.use_separate_draft_kv_cache = should_use_separate_draft_kv_cache(
+                spec_config)
+
             self.draft_model = get_draft_model(model_config, self.draft_config,
                                                self.lm_head, self.model)
-            self.spec_worker = get_spec_worker(model_config.spec_config,
-                                               model_config,
-                                               model_config.mapping)
+            self.spec_worker = get_spec_worker(
+                model_config.spec_config,
+                model_config,
+                model_config.mapping,
+                use_separate_draft_kv_cache=self.use_separate_draft_kv_cache)

         if self.draft_config is not None and model_config.spec_config.eagle3_model_arch == "llama3":
             for key, value in self.draft_config.extra_attrs.items():
@@ -934,6 +941,7 @@ def forward(
         inputs_embeds: Optional[torch.FloatTensor] = None,
         return_context_logits: bool = False,
         spec_metadata: Optional[SpecMetadata] = None,
+        resource_manager=None,
         **kwargs,
     ) -> torch.Tensor:
         hidden_states = self.model(
@@ -978,7 +986,8 @@ def forward(
                 logits=logits,
                 attn_metadata=attn_metadata,
                 spec_metadata=spec_metadata,
-                draft_model=self.draft_model)
+                draft_model=self.draft_model,
+                resource_manager=resource_manager)
         else:
             logits = self.logits_processor.forward(
                 hidden_states,
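
On the modeling side the change is plumbing: a flag computed once from the spec config decides whether the worker treats the draft KV cache as separate, and an optional resource_manager is threaded through forward() down to the spec worker so it can reach the draft KV cache manager at runtime. A hypothetical sketch of that wiring, with placeholder names and logic rather than the TensorRT-LLM implementation:

    from typing import Any, Optional

    def should_use_separate_draft_kv_cache_sketch(spec_config: dict) -> bool:
        # Placeholder predicate: the real check would compare the draft and
        # target KV cache layouts from the speculative-decoding config.
        return spec_config.get("draft_kv_layout") != spec_config.get("target_kv_layout")

    class SpecWorkerSketch:
        def __init__(self, use_separate_draft_kv_cache: bool):
            self.use_separate_draft_kv_cache = use_separate_draft_kv_cache

        def __call__(self, logits: Any, resource_manager: Optional[dict] = None):
            # When the flag is set, look up the draft KV cache manager from the
            # resource manager passed through forward(); otherwise ignore it.
            draft_mgr = None
            if self.use_separate_draft_kv_cache and resource_manager is not None:
                draft_mgr = resource_manager.get("draft_kv_cache_manager")
            return logits, draft_mgr

    spec_config = {"target_kv_layout": "HND", "draft_kv_layout": "NHD"}
    worker = SpecWorkerSketch(should_use_separate_draft_kv_cache_sketch(spec_config))
    resources = {"draft_kv_cache_manager": object()}
    logits, draft_mgr = worker(logits=[0.1, 0.9], resource_manager=resources)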
