14 changes: 14 additions & 0 deletions .claude/settings.local.json
@@ -0,0 +1,14 @@
{
"permissions": {
"allow": [
"Bash(python -c \"from tensorrt_llm._torch.pyexecutor._util import KvCacheCreator; from tensorrt_llm._torch.pyexecutor.py_executor_creator import create_py_executor; print\\(''Import successful''\\)\")",
"Bash(nvidia-smi --query-gpu=index,name,memory.used,memory.total --format=csv,noheader)",
"Bash(timeout 600 pytest -s tests/integration/defs/accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=True-overlap_scheduler=True])",
"Bash(pkill -f 'trtllm-serve')",
"Bash(pkill -f 'mpi4py')",
"Bash(nvidia-smi --query-gpu=index,memory.used --format=csv,noheader)",
"Bash(timeout 300 pytest -s tests/integration/defs/accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=False-overlap_scheduler=True])",
"Bash(pkill -f 'orted')"
]
}
}
2 changes: 2 additions & 0 deletions tensorrt_llm/_torch/attention_backend/interface.py
@@ -64,6 +64,8 @@ class AttentionMetadata:
max_num_sequences: Optional[int] = None
# The KV cache manager.
kv_cache_manager: Union[KVCacheManager, KVCacheManagerV2]
# Draft KV cache manager for one-model speculative decoding with separate KV cache layouts
draft_kv_cache_manager: Optional[Union[KVCacheManager, KVCacheManagerV2]] = None
mapping: Optional[Mapping] = None

enable_flash_mla: bool = False
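A minimal sketch of how a consumer of AttentionMetadata might select between the two managers; the `for_draft_layer` flag below is invented for illustration and is not part of this change:

# Sketch (not from this PR): pick the manager for a draft layer when a
# separate draft KV cache layout is in use; fall back to the target manager.
manager = attn_metadata.kv_cache_manager
if for_draft_layer and attn_metadata.draft_kv_cache_manager is not None:
    manager = attn_metadata.draft_kv_cache_manager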
2 changes: 2 additions & 0 deletions tensorrt_llm/_torch/attention_backend/sparse/rocket.py
@@ -974,6 +974,7 @@ def add_dummy_requests(
use_mrope: bool = False,
max_beam_width: int = 1,
num_extra_decoding_steps: int = 0,
draft_kv_cache_manager=None,
):
requests = super().add_dummy_requests(
request_ids=request_ids,
@@ -984,6 +985,7 @@ def add_dummy_requests(
use_mrope=use_mrope,
max_beam_width=max_beam_width,
num_extra_decoding_steps=num_extra_decoding_steps,
draft_kv_cache_manager=draft_kv_cache_manager,
)
if prepare_resource:
for req in requests:
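The Rocket sparse backend only forwards the new keyword to the base implementation. A hedged usage sketch, assuming `kv_cache_manager` is a Rocket KV cache manager instance and `draft_manager` is the draft-model manager (both names are placeholders):

# Sketch only: dummy warmup requests also carry the draft manager so draft
# cache resources can be prepared alongside the target cache.
dummy_requests = kv_cache_manager.add_dummy_requests(
    request_ids=[0, 1],
    prepare_resource=True,
    draft_kv_cache_manager=draft_manager,
)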
46 changes: 46 additions & 0 deletions tensorrt_llm/_torch/attention_backend/trtllm.py
@@ -679,6 +679,12 @@ class TrtllmAttentionMetadata(AttentionMetadata):
helix_is_inactive_rank: Optional[torch.Tensor] = None
helix_is_inactive_rank_cpu: Optional[torch.Tensor] = None

# Block offsets for the target and draft KV caches
kv_cache_block_offsets: Optional[torch.Tensor] = None
host_kv_cache_block_offsets: Optional[torch.Tensor] = None
draft_kv_cache_block_offsets: Optional[torch.Tensor] = None
draft_host_kv_cache_block_offsets: Optional[torch.Tensor] = None

@property
def max_seq_len(self) -> int:
"""
@@ -786,6 +792,27 @@ def _post_init_with_buffers(self, buffers) -> None:
)
self.block_ids_per_seq = None
self.kv_block_ids_per_seq = None

# Allocate separate block offset tensors for draft KV cache manager
# Used in one-model speculative decoding with different KV cache layouts
if self.draft_kv_cache_manager is not None:
self.draft_kv_cache_block_offsets = self.get_empty(
buffers,
[
self.draft_kv_cache_manager.num_pools,
self.max_num_sequences, 2,
self.draft_kv_cache_manager.max_blocks_per_seq
],
cache_name="draft_kv_cache_block_offsets",
dtype=torch.int32,
capture_graph=capture_graph,
)
self.draft_host_kv_cache_block_offsets = torch.empty_like(
self.draft_kv_cache_block_offsets,
device='cpu',
pin_memory=True,
)

if self.enable_flash_mla:
self.block_ids_per_seq = self.get_empty(
buffers,
@@ -987,6 +1014,25 @@ def prepare(self) -> None:
assert self.kv_lens[:self.num_seqs].max(
) <= self.kv_cache_manager.max_seq_len, error_message

# Also prepare draft KV cache block offsets if draft_kv_cache_manager exists
if self.draft_kv_cache_manager is not None:
# Copy block offsets for all context requests
self.draft_kv_cache_manager.impl.copy_batch_block_offsets(
self.draft_host_kv_cache_block_offsets,
self.request_ids[:self.num_contexts], 1, 0)
# Copy block offsets for all generation requests
self.draft_kv_cache_manager.impl.copy_batch_block_offsets(
self.draft_host_kv_cache_block_offsets,
self.request_ids[self.num_contexts:], self.beam_width,
self.num_contexts)
for pool_idx in range(
self.draft_host_kv_cache_block_offsets.shape[0]):
self.draft_kv_cache_block_offsets[
pool_idx, :self.num_seqs].copy_(
self.draft_host_kv_cache_block_offsets[
pool_idx, :self.num_seqs],
non_blocking=True)

self.kv_lens_cuda_runtime = self.kv_lens_cuda[:self.num_seqs]
# Don't use self.kv_lens here because it includes extra tokens.
# Use actual KV length (without extra tokens) for kv_lens_runtime,
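The draft block offsets follow the same staging pattern as the target offsets: values are written into a pinned host tensor, then copied to the device tensor pool by pool with non-blocking copies. A self-contained sketch of that pattern (shapes are illustrative and a CUDA device is assumed):

import torch

num_pools, max_num_sequences, max_blocks_per_seq = 1, 8, 64
host_offsets = torch.empty(num_pools, max_num_sequences, 2, max_blocks_per_seq,
                           dtype=torch.int32, pin_memory=True)
device_offsets = torch.empty_like(host_offsets, device='cuda')

num_seqs = 4  # only the active sequences are copied
for pool_idx in range(host_offsets.shape[0]):
    device_offsets[pool_idx, :num_seqs].copy_(
        host_offsets[pool_idx, :num_seqs], non_blocking=True)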
2 changes: 2 additions & 0 deletions tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -1842,6 +1842,7 @@ def forward(
inputs_embeds: Optional[torch.FloatTensor] = None,
spec_metadata: Optional[SpecMetadata] = None,
return_context_logits: bool = False,
resource_manager=None,
**kwargs,
) -> torch.Tensor:
return super().forward(attn_metadata=attn_metadata,
@@ -1850,6 +1851,7 @@ def forward(
inputs_embeds=inputs_embeds,
spec_metadata=spec_metadata,
return_context_logits=return_context_logits,
resource_manager=resource_manager,
**kwargs)

def load_weights(self, weights: ConsumableWeightsDict):
4 changes: 3 additions & 1 deletion tensorrt_llm/_torch/models/modeling_llama.py
@@ -1401,6 +1401,7 @@ def forward(
inputs_embeds: Optional[torch.FloatTensor] = None,
return_context_logits: bool = False,
spec_metadata: Optional[SpecMetadata] = None,
resource_manager=None,
**kwargs,
) -> torch.Tensor:
multimodal_params = kwargs.get("multimodal_params", [])
@@ -1422,7 +1423,8 @@ def forward(
position_ids,
inputs_embeds,
spec_metadata=spec_metadata,
return_context_logits=return_context_logits)
return_context_logits=return_context_logits,
resource_manager=resource_manager)

def infer_max_seq_len(self):
if self.model_config.attn_backend.upper() != 'TRTLLM':
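The model forwards simply thread `resource_manager` through to the base class so the one-model spec worker can reach the KV cache managers at runtime. A minimal sketch of the call pattern (object names are placeholders, not defined in this PR):

# Sketch only: resource_manager is forwarded unchanged and defaults to None
# for models that do not use one-model speculative decoding.
logits = model.forward(
    attn_metadata=attn_metadata,
    input_ids=input_ids,
    position_ids=position_ids,
    resource_manager=resource_manager,
)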
19 changes: 14 additions & 5 deletions tensorrt_llm/_torch/models/modeling_speculative.py
@@ -19,7 +19,8 @@
WeightsLoadingConfig)
from ..modules.rms_norm import RMSNorm
from ..pyexecutor.guided_decoder import CapturableGuidedDecoder
from ..speculative import SpecMetadata, get_spec_worker
from ..speculative import (SpecMetadata, get_spec_worker,
should_use_separate_draft_kv_cache)
from ..utils import AuxStreamType
from .checkpoints.base_weight_mapper import BaseWeightMapper
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM, TModel,
@@ -914,6 +915,7 @@ def __init__(self, model: TModel, model_config: ModelConfig[TConfig]):
vocab_size=model_config.pretrained_config.vocab_size)
self.draft_model = None
self.draft_config = None
self.use_separate_draft_kv_cache = False
spec_config = getattr(model_config, 'spec_config', None)
if spec_config and spec_config.spec_dec_mode.use_one_engine():
if spec_config.spec_dec_mode.is_eagle3_one_model():
@@ -947,11 +949,16 @@ def __init__(self, model: TModel, model_config: ModelConfig[TConfig]):
self.draft_config.quant_config.kv_cache_quant_algo = \
model_config.quant_config.kv_cache_quant_algo

self.use_separate_draft_kv_cache = should_use_separate_draft_kv_cache(
spec_config)

self.draft_model = get_draft_model(model_config, self.draft_config,
self.lm_head, self.model)
self.spec_worker = get_spec_worker(model_config.spec_config,
model_config,
model_config.mapping)
self.spec_worker = get_spec_worker(
model_config.spec_config,
model_config,
model_config.mapping,
use_separate_draft_kv_cache=self.use_separate_draft_kv_cache)
self.epilogue.append(self.draft_model)
self.epilogue.append(self.spec_worker)

@@ -970,6 +977,7 @@ def forward(
inputs_embeds: Optional[torch.FloatTensor] = None,
return_context_logits: bool = False,
spec_metadata: Optional[SpecMetadata] = None,
resource_manager=None,
**kwargs,
) -> torch.Tensor:
hidden_states = self.model(
@@ -1013,7 +1021,8 @@ def forward(
logits=logits,
attn_metadata=attn_metadata,
spec_metadata=spec_metadata,
draft_model=self.draft_model)
draft_model=self.draft_model,
resource_manager=resource_manager)
else:
logits = self.logits_processor.forward(
hidden_states,
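The draft path is now gated by `should_use_separate_draft_kv_cache`, and the decision is forwarded into the spec worker. A sketch of that wiring in isolation, with placeholder objects; the helper's internals are not shown in this diff, so it is assumed to depend only on the spec config:

# Sketch only: mirrors the wiring added above.
use_separate = should_use_separate_draft_kv_cache(spec_config)
spec_worker = get_spec_worker(
    spec_config,
    model_config,
    model_config.mapping,
    use_separate_draft_kv_cache=use_separate,
)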