
Commit c85ffe5

ziyixiong-nv committed
Fix for CUDA graph padding
Signed-off-by: ziyixiong-nv <219238287+ziyixiong-nv@users.noreply.github.com>
1 parent d3cea1a

3 files changed (+27, -9 lines)


tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py

Lines changed: 9 additions & 1 deletion
@@ -16,6 +16,7 @@
 from ..modules.multi_stream_utils import with_multi_stream
 from ..speculative.eagle3 import Eagle3ResourceManager
 from ..speculative.mtp import SampleStateTensorsMTP
+from ..speculative.utils import get_draft_kv_cache_manager
 from ..utils import make_weak_ref, piecewise_cuda_graph
 from .llm_request import get_draft_token_length
 from .mamba_cache_manager import MambaCacheManager
@@ -439,12 +440,19 @@ def _get_padded_batch(self, batch: ScheduledRequests,
         if available_blocks < 1:
             return 0
 
+        # Get draft KV cache manager only for one-model speculative decoding.
+        # In two-model mode, each model has its own KV cache manager, so
+        # draft_kv_cache_manager should be None.
+        draft_kv_cache_manager = get_draft_kv_cache_manager(
+            self.spec_config, resource_manager)
+
         self.padding_dummy_request = kv_cache_manager.add_dummy_requests(
             [CUDA_GRAPH_DUMMY_REQUEST_ID],
             is_gen=True,
             max_num_draft_tokens=runtime_draft_len,
             use_mrope=self.config.use_mrope,
-            max_beam_width=self.config.max_beam_width)[0]
+            max_beam_width=self.config.max_beam_width,
+            draft_kv_cache_manager=draft_kv_cache_manager)[0]
         self.padding_dummy_request.is_cuda_graph_dummy = True
         spec_res_mgr = resource_manager.get_resource_manager(
             ResourceManagerType.SPEC_RESOURCE_MANAGER)
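
For context on why the padding path needs this: in one-model speculative decoding the target engine also manages a separate draft KV cache (per the helper's docstring below), so a CUDA-graph padding dummy request has to reserve blocks in both caches, not just the target's. Below is a minimal, self-contained sketch of that accounting; ToyKVCacheManager is a hypothetical stand-in, not the TensorRT-LLM API.

# Hypothetical stand-in illustrating the padding fix; not the TensorRT-LLM API.
class ToyKVCacheManager:
    def __init__(self, free_blocks: int):
        self.free_blocks = free_blocks

    def add_dummy_request(self, draft_kv_cache_manager=None):
        # Reserve a block for the CUDA-graph padding dummy request.
        self.free_blocks -= 1
        # After the fix: also charge the draft KV cache in one-model mode.
        if draft_kv_cache_manager is not None:
            draft_kv_cache_manager.free_blocks -= 1


target = ToyKVCacheManager(free_blocks=4)
draft = ToyKVCacheManager(free_blocks=4)

# Before the fix, the draft pool was never charged for the padding dummy:
target.add_dummy_request()                              # draft.free_blocks == 4
# After the fix, both pools account for it:
target.add_dummy_request(draft_kv_cache_manager=draft)  # draft.free_blocks == 3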

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 3 additions & 8 deletions
@@ -44,8 +44,8 @@
 from ..models.modeling_utils import DecoderModelForCausalLM
 from ..modules.fused_moe.moe_load_balancer import (MoeLoadBalancer,
                                                    MoeLoadBalancerIterContext)
-from ..speculative import (SpecMetadata, get_num_extra_kv_tokens,
-                           get_spec_metadata,
+from ..speculative import (SpecMetadata, get_draft_kv_cache_manager,
+                           get_num_extra_kv_tokens, get_spec_metadata,
                            update_spec_config_from_model_config)
 from ..speculative.drafting_loops import BaseDraftingLoopWrapper
 from ..speculative.eagle3 import (Eagle3OneModelSpecMetadata,
@@ -550,12 +550,7 @@ def _get_draft_kv_cache_manager(
         Returns the draft KV cache manager only in one-model speculative decoding
         mode where the target model manages a separate draft KV cache.
         """
-        if self.spec_config is None:
-            return None
-        if not self.spec_config.spec_dec_mode.use_one_engine():
-            return None
-        return resource_manager.get_resource_manager(
-            ResourceManagerType.DRAFT_KV_CACHE_MANAGER)
+        return get_draft_kv_cache_manager(self.spec_config, resource_manager)
 
     @contextmanager
     def set_warmup_flag(self):
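
After this hunk, _get_draft_kv_cache_manager in the model engine is a one-line delegation to the shared helper added in tensorrt_llm/_torch/speculative/utils.py below, so the CUDA graph runner and the model engine resolve the draft KV cache manager with a single piece of logic.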

tensorrt_llm/_torch/speculative/utils.py

Lines changed: 15 additions & 0 deletions
@@ -247,6 +247,21 @@ def get_num_extra_kv_tokens(spec_config):
     return 0
 
 
+def get_draft_kv_cache_manager(spec_config, resource_manager):
+    """
+    Returns the draft KV cache manager only in one-model speculative decoding
+    mode where the target model manages a separate draft KV cache.
+    """
+    from ..pyexecutor.resource_manager import ResourceManagerType
+
+    if spec_config is None:
+        return None
+    if not spec_config.spec_dec_mode.use_one_engine():
+        return None
+    return resource_manager.get_resource_manager(
+        ResourceManagerType.DRAFT_KV_CACHE_MANAGER)
+
+
 def update_spec_config_from_model_config(spec_config, model_config):
     if spec_config.spec_dec_mode.is_mtp_one_model():
         # Use `max_draft_len` for several low-level APIs. TODO: Remove this after distinguishing them.
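
A hedged usage sketch of the new helper follows; the call site is hypothetical, the spec_config and resource_manager parameters are assumed to be the executor's existing objects, and the import path mirrors the file location above. Note the ResourceManagerType import is kept local inside the helper, presumably to avoid a circular import between the speculative and pyexecutor packages.

# Usage sketch (hypothetical call site; spec_config / resource_manager are
# assumed to come from the executor).
from tensorrt_llm._torch.speculative.utils import get_draft_kv_cache_manager


def resolve_draft_kv_cache_manager(spec_config, resource_manager):
    draft_kv_cache_manager = get_draft_kv_cache_manager(spec_config,
                                                        resource_manager)
    if draft_kv_cache_manager is None:
        # Either speculative decoding is off, or two-model mode is active and
        # the draft model owns its own KV cache manager: nothing extra to pad.
        return None
    return draft_kv_cache_manager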
