tensorrt_llm/_torch/pyexecutor/_util.py (32 changes: 14 additions & 18 deletions)
@@ -40,8 +40,7 @@
 from .sampler import (EarlyStopSampler, EarlyStopWithMMResult, TorchSampler,
                       TRTLLMSampler)
 from .scheduler import (BindCapacityScheduler, BindMicroBatchScheduler,
-                        KVCacheV2DummyScheduler, SimpleScheduler,
-                        SimpleUnifiedScheduler)
+                        SimpleScheduler, SimpleUnifiedScheduler)
 from .seq_slot_manager import SeqSlotManager
 
 GB = 1 << 30
@@ -860,33 +859,30 @@ def create_py_executor_instance(
     if scheduler_capacity == 1 and mapping.enable_attention_dp and kv_cache_manager:
         scheduler_capacity += 1
 
+    # KVCacheManagerV2 always uses Python scheduler (SimpleUnifiedScheduler)
+    # regardless of TLLM_USE_PYTHON_SCHEDULER environment variable
     use_python_scheduler = os.getenv("TLLM_USE_PYTHON_SCHEDULER", "0") == "1"
-    if use_python_scheduler and not isinstance(kv_cache_manager,
-                                               KVCacheManagerV2):
+    is_kv_cache_v2 = isinstance(kv_cache_manager, KVCacheManagerV2)
+
+    if is_kv_cache_v2 or use_python_scheduler:
         scheduler = SimpleUnifiedScheduler(
             max_batch_size=max_batch_size,
             max_num_tokens=max_num_tokens,
-            kv_cache_manager=kv_cache_manager.impl
-            if kv_cache_manager is not None else None,
+            kv_cache_manager=kv_cache_manager if is_kv_cache_v2 else
+            (kv_cache_manager.impl if kv_cache_manager is not None else None),
             peft_cache_manager=peft_cache_manager.impl
             if peft_cache_manager is not None else None,
             scheduler_policy=scheduler_config.capacity_scheduler_policy,
             ctx_chunk_config=ctx_chunk_config,
             two_step_lookahead=mapping.has_pp(),
             scheduler_capacity=scheduler_capacity)
     else:
-        if isinstance(kv_cache_manager, KVCacheManagerV2):
-            capacity_scheduler = KVCacheV2DummyScheduler(
-                scheduler_capacity,
-                kv_cache_manager if kv_cache_manager is not None else None)
-        else:
-            capacity_scheduler = BindCapacityScheduler(
-                scheduler_capacity,
-                kv_cache_manager.impl if kv_cache_manager is not None else None,
-                peft_cache_manager.impl
-                if peft_cache_manager is not None else None,
-                scheduler_config.capacity_scheduler_policy,
-                two_step_lookahead=mapping.has_pp())
+        capacity_scheduler = BindCapacityScheduler(
+            scheduler_capacity,
+            kv_cache_manager.impl if kv_cache_manager is not None else None,
+            peft_cache_manager.impl if peft_cache_manager is not None else None,
+            scheduler_config.capacity_scheduler_policy,
+            two_step_lookahead=mapping.has_pp())
 
         mb_scheduler = BindMicroBatchScheduler(max_batch_size, max_num_tokens,
                                                ctx_chunk_config)
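
The branch above decides which scheduler stack is used. A minimal standalone sketch of that decision (illustrative only, not part of the diff; pick_scheduler and its boolean argument are hypothetical stand-ins for the isinstance check on KVCacheManagerV2):

import os

def pick_scheduler(is_kv_cache_v2: bool) -> str:
    """Return which scheduler path the new code would take."""
    use_python_scheduler = os.getenv("TLLM_USE_PYTHON_SCHEDULER", "0") == "1"
    if is_kv_cache_v2 or use_python_scheduler:
        # KVCacheManagerV2 always takes the Python path, regardless of the env var.
        return "SimpleUnifiedScheduler"
    # Legacy path: C++-bound capacity and micro-batch schedulers.
    return "SimpleScheduler"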
tensorrt_llm/_torch/pyexecutor/resource_manager.py (241 changes: 241 additions & 0 deletions)
@@ -36,6 +36,7 @@
_KVCache)
from tensorrt_llm.runtime.kv_cache_manager_v2._common import GPU_LEVEL
from tensorrt_llm.runtime.kv_cache_manager_v2._config import DataRole
from tensorrt_llm.runtime.kv_cache_manager_v2._exceptions import OutOfPagesError
from tensorrt_llm.runtime.kv_cache_manager_v2._utils import (exact_div,
typed_range)
from tensorrt_llm.sampling_params import SamplingParams
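
The new OutOfPagesError import is what drives the eviction loop added later in this file: the generation path bumps kv_cache.capacity and evicts another request whenever the pool is exhausted. A minimal sketch of that retry pattern (only the exception name and its import path come from the diff; the other names are hypothetical stand-ins):

from tensorrt_llm.runtime.kv_cache_manager_v2._exceptions import OutOfPagesError

def grow_capacity_or_evict(kv_cache, pick_victim, max_attempts: int) -> bool:
    """Try to grow a request's KV capacity by one token slot, evicting on failure."""
    for _ in range(max_attempts):
        try:
            kv_cache.capacity += 1      # may raise OutOfPagesError
            return True
        except OutOfPagesError:
            victim = pick_victim()      # LIFO choice among other scheduled requests
            if victim is None:
                return False            # nothing left to evict
            victim.suspend()            # frees the victim's GPU pages
    return False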
@@ -1471,6 +1472,8 @@ def __init__(
else:
self.max_attention_window_vec = [None]

self._scheduler_prepared_resources = False # Track if scheduler handled resources

if isinstance(num_kv_heads, int):
self.num_kv_heads_per_layer = [
(num_kv_heads + tp_size - 1) // tp_size
@@ -1707,6 +1710,25 @@ def get_num_free_blocks(self) -> int:

@nvtx_range("prepare_resources_kv_cache_manager_v2")
def prepare_resources(self, scheduled_batch: ScheduledRequests):
"""
Prepare resources for scheduled requests.

For the MAX_UTILIZATION policy, resources have already been allocated during
scheduling via prepare_resources_for_max_utilization, so this method only checks
the flag and skips allocation.
For other policies (GUARANTEED_NO_EVICT), resources are allocated here.
"""
# Check if the scheduler already prepared resources
# TODO: remove this flag and turn this check into an assertion once the kv_cache_v2 dummy scheduler is removed
if self._scheduler_prepared_resources:
[Review thread anchored on this line]
Member: The overall design of KV cache manager v2 is to allocate resources in the scheduling stage. We could delete the resource preparation here and keep only the assertion.
Collaborator (author): Added a TODO; this path will be removed once the dummy scheduler is no longer needed.

# Reset flag for next round
self._scheduler_prepared_resources = False
else:
# Resources not allocated by scheduler, do it here (GUARANTEED_NO_EVICT path)
self._prepare_resources_guaranteed_no_evict(scheduled_batch)

def _prepare_resources_guaranteed_no_evict(
self, scheduled_batch: ScheduledRequests):
"""Prepare resources for GUARANTEED_NO_EVICT scheduling policy."""
with request_context(self.is_draft, scheduled_batch):
context_batch = scheduled_batch.context_requests
generation_batch = scheduled_batch.generation_requests
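
The flag handshake added in this hunk can be summarized in isolation: the MAX_UTILIZATION path allocates during scheduling and sets _scheduler_prepared_resources, and prepare_resources then only resets it. A simplified sketch, not part of the diff; apart from the flag and the two prepare_* method names, the class and helper names are illustrative:

class _ResourceHandshakeSketch:
    """Mimics how the KV cache manager coordinates with the scheduler."""

    def __init__(self):
        self._scheduler_prepared_resources = False

    def prepare_resources_for_max_utilization(self, context_reqs, generation_reqs):
        # MAX_UTILIZATION: allocation (and eviction) happens here, during scheduling.
        self._scheduler_prepared_resources = True
        return context_reqs, generation_reqs

    def prepare_resources(self, scheduled_batch):
        if self._scheduler_prepared_resources:
            # Scheduler already allocated; just reset the flag for the next iteration.
            self._scheduler_prepared_resources = False
        else:
            # GUARANTEED_NO_EVICT: allocate now.
            self._prepare_guaranteed_no_evict(scheduled_batch)

    def _prepare_guaranteed_no_evict(self, scheduled_batch):
        pass  # placeholder for the allocation logic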
@@ -1767,6 +1789,225 @@ def _kv_connector_should_add_sequence(self, request: LlmRequest) -> bool:
return self.kv_connector_manager is None or self.kv_connector_manager.should_add_sequence(
request)

def _can_evict_request(self, req: LlmRequest) -> bool:
"""Check if a request is eligible for eviction."""
if req.state == LlmRequestState.GENERATION_IN_PROGRESS:
return True
elif req.state == LlmRequestState.CONTEXT_INIT and \
hasattr(req, 'context_current_position') and \
req.context_current_position > 0:
return True
return False

def _try_evict_from_list(
self, request_list, current_req: LlmRequest,
evicted_requests: List[LlmRequest]) -> Optional[LlmRequest]:
"""Try to evict a request from the given list (LIFO order)."""
for req in request_list:
if req == current_req:
continue

if req in evicted_requests:
continue

req_id = req.py_request_id

if req_id not in self.kv_cache_map:
continue

kv_cache = self.kv_cache_map[req_id]

if kv_cache.status is not _KVCache.Status.ACTIVE:
continue

if not self._can_evict_request(req):
continue

kv_cache.suspend()
return req

return None

def _try_evict_request_for_capacity(
self, current_req: LlmRequest,
new_generation_batch: List[LlmRequest],
new_context_batch: Optional[List[LlmRequest]],
context_requests: List[LlmRequest],
generation_requests: List[LlmRequest],
evicted_requests: List[LlmRequest]) -> Optional[LlmRequest]:
"""
Try to evict requests to make room for capacity allocation.

Based on a LIFO (last in, first out) eviction strategy:
- Try to evict from new_generation_batch first (most recent)
- If no candidate found, try from all scheduled requests (updated lists)

Returns:
Evicted request or None
"""
# Try to evict from new_generation_batch first (LIFO)
evicted_request = self._try_evict_from_list(
reversed(new_generation_batch), current_req, evicted_requests)

if evicted_request is None:
# If no candidate in generation batch, try all scheduled requests
# Use updated lists (new_context_batch) if available, otherwise use original lists
if new_context_batch is not None:
all_scheduled_requests = new_context_batch + new_generation_batch
else:
all_scheduled_requests = context_requests + generation_requests

evicted_request = self._try_evict_from_list(all_scheduled_requests,
current_req,
evicted_requests)

return evicted_request
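
To make the eviction order concrete, here is the two-pass LIFO search above reduced to plain lists (an illustrative sketch, not part of the diff; the real helpers additionally check kv_cache_map membership, the ACTIVE status, and _can_evict_request before suspending a victim):

from typing import List, Optional

def pick_victim_lifo(new_generation_batch: List[int],
                     all_scheduled: List[int],
                     current_req: int,
                     already_evicted: List[int]) -> Optional[int]:
    # Pass 1: newest generation requests first (reverse order = last in, first out).
    for req in reversed(new_generation_batch):
        if req != current_req and req not in already_evicted:
            return req
    # Pass 2: fall back to every scheduled request.
    for req in all_scheduled:
        if req != current_req and req not in already_evicted:
            return req
    return None

# Example: with generation batch [1, 2, 3] and request 3 asking for more capacity,
# request 2 (the most recently admitted other request) is evicted first.
assert pick_victim_lifo([1, 2, 3], [0, 1, 2, 3], current_req=3, already_evicted=[]) == 2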

@nvtx_range("prepare_resources_for_max_utilization")
def prepare_resources_for_max_utilization(
self, context_requests: List[LlmRequest],
generation_requests: List[LlmRequest]
) -> tuple[List[LlmRequest], List[LlmRequest]]:
"""
Allocate KV cache resources for max utilization scheduling.
Handles eviction when out of pages.

Args:
context_requests: List of context requests to schedule
generation_requests: List of generation requests to schedule

Returns:
Tuple of (scheduled_context, scheduled_generation) after resource allocation
"""
evicted_requests: List[LlmRequest] = []

# Create a ScheduledRequests object for context management
scheduled_batch = ScheduledRequests()
scheduled_batch.context_requests = list(context_requests)
scheduled_batch.generation_requests = list(generation_requests)

with request_context(self.is_draft, scheduled_batch):
new_generation_batch: List[LlmRequest] = []

for req in generation_requests:
if req in evicted_requests:
continue

# Handle missing kv_cache_map entry
if req.py_request_id not in self.kv_cache_map:
logger.warning(
f"Request {req.py_request_id} not in kv_cache_map, skipping"
)
continue

kv_cache = self.kv_cache_map[req.py_request_id]

if not kv_cache.is_active:
result = kv_cache.resume(
torch.cuda.current_stream().cuda_stream)
if not result:
continue

# Max utilization scheduling: try to increase capacity for this generation request.
# Evict other requests repeatedly until there is enough capacity or no candidates remain.
max_eviction_attempts = len(generation_requests) - len(
evicted_requests)
capacity_increased = False

for _ in range(max_eviction_attempts):
try:
kv_cache.capacity += 1
new_generation_batch.append(req)
capacity_increased = True
break
except OutOfPagesError:
evicted = self._try_evict_request_for_capacity(
req, new_generation_batch, None, context_requests,
generation_requests, evicted_requests)
if evicted is None:
# No more requests to evict
break
if evicted in new_generation_batch:
new_generation_batch.remove(evicted)
evicted_requests.append(evicted)

if not capacity_increased:
# Could not increase capacity even after evicting all possible requests
continue

# Allocate KV Cache for context requests
new_context_batch: List[LlmRequest] = []
for req in context_requests:
beam_width = req.sampling_config.beam_width
if 'cp_type' in self.mapping.cp_config and CpType.STAR == self.mapping.cp_config[
'cp_type']:
raise RuntimeError(
"Star attention is not supported for kv cache manager v2"
)
else:
kv_cache = None
if req.is_first_context_chunk and self._kv_connector_should_add_sequence(
req):
if req.py_request_id in self.kv_cache_map:
kv_cache = self.kv_cache_map[req.py_request_id]
else:
# The last token cannot be recovered, so it is excluded from the input tokens used to look up reusable blocks.
kv_cache = self._create_kv_cache(
req.py_request_id, req.lora_task_id,
req.get_tokens(0)[:-1]
if self.enable_block_reuse else None)
assert beam_width == 1, "Currently, KVCacheManagerV2 only supports beam width 1"
if not self.enable_block_reuse:
assert kv_cache.num_committed_tokens == 0
kv_cache.stop_committing()
else:
req.context_current_position = kv_cache.num_committed_tokens
chunk_size = req.context_chunk_size
if req.context_current_position + req.context_chunk_size < req.prompt_len:
floored_end_position = (
req.context_current_position +
req.context_chunk_size
) // self.tokens_per_block * self.tokens_per_block
chunk_size = floored_end_position - req.context_current_position

req.context_chunk_size = min(
chunk_size,
req.prompt_len - req.context_current_position)

success = kv_cache.resume(
torch.cuda.current_stream().cuda_stream)
if not success:
continue
try:
kv_cache.capacity = req.prompt_len
new_context_batch.append(req)
except OutOfPagesError:
kv_cache.suspend()
continue

if self.kv_connector_manager is not None:
block_ids = self.get_cache_indices(req)
self.kv_connector_manager.update_state_after_alloc(
req, block_ids)
else:
assert req.py_request_id in self.kv_cache_map, f"req.py_request_id {req.py_request_id} not in kv_cache_map"
kv_cache = self.kv_cache_map[req.py_request_id]
assert kv_cache.status is _KVCache.Status.ACTIVE, f"kv_cache {req.py_request_id} is not active"
new_context_batch.append(req)

# Update scheduled_batch for kv_connector_manager
scheduled_batch.context_requests = new_context_batch
scheduled_batch.generation_requests = new_generation_batch

if self.kv_connector_manager is not None:
self.kv_connector_manager.build_scheduler_output(
scheduled_batch, self)

# Set flag to indicate scheduler handled resource preparation
self._scheduler_prepared_resources = True

return new_context_batch, new_generation_batch
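
The chunk-size adjustment in the context loop above floors each intermediate chunk to a block boundary, so only the final chunk of a prompt can end mid-block. A small worked sketch of that arithmetic (the function name is illustrative, not from the diff):

def next_context_chunk(position: int, chunk_size: int, prompt_len: int,
                       tokens_per_block: int) -> int:
    """Return the chunk size actually used for this context step."""
    if position + chunk_size < prompt_len:
        # Not the last chunk: floor the end position to a block boundary.
        floored_end = (position + chunk_size) // tokens_per_block * tokens_per_block
        chunk_size = floored_end - position
    # Never run past the end of the prompt.
    return min(chunk_size, prompt_len - position)

# Example: position=0, chunk_size=100, tokens_per_block=32, prompt_len=300
# -> the intermediate chunk is floored from 100 to 96 tokens (3 full blocks).
assert next_context_chunk(0, 100, 300, 32) == 96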

def get_kv_cache_stats(self):

class KVCacheStatus: