diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index 4fea1e0b4e6..0e55f4d35e0 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -40,8 +40,7 @@ from .sampler import (EarlyStopSampler, EarlyStopWithMMResult, TorchSampler, TRTLLMSampler) from .scheduler import (BindCapacityScheduler, BindMicroBatchScheduler, - KVCacheV2DummyScheduler, SimpleScheduler, - SimpleUnifiedScheduler) + SimpleScheduler, SimpleUnifiedScheduler) from .seq_slot_manager import SeqSlotManager GB = 1 << 30 @@ -860,14 +859,17 @@ def create_py_executor_instance( if scheduler_capacity == 1 and mapping.enable_attention_dp and kv_cache_manager: scheduler_capacity += 1 + # KVCacheManagerV2 always uses Python scheduler (SimpleUnifiedScheduler) + # regardless of TLLM_USE_PYTHON_SCHEDULER environment variable use_python_scheduler = os.getenv("TLLM_USE_PYTHON_SCHEDULER", "0") == "1" - if use_python_scheduler and not isinstance(kv_cache_manager, - KVCacheManagerV2): + is_kv_cache_v2 = isinstance(kv_cache_manager, KVCacheManagerV2) + + if is_kv_cache_v2 or use_python_scheduler: scheduler = SimpleUnifiedScheduler( max_batch_size=max_batch_size, max_num_tokens=max_num_tokens, - kv_cache_manager=kv_cache_manager.impl - if kv_cache_manager is not None else None, + kv_cache_manager=kv_cache_manager if is_kv_cache_v2 else + (kv_cache_manager.impl if kv_cache_manager is not None else None), peft_cache_manager=peft_cache_manager.impl if peft_cache_manager is not None else None, scheduler_policy=scheduler_config.capacity_scheduler_policy, @@ -875,18 +877,12 @@ def create_py_executor_instance( two_step_lookahead=mapping.has_pp(), scheduler_capacity=scheduler_capacity) else: - if isinstance(kv_cache_manager, KVCacheManagerV2): - capacity_scheduler = KVCacheV2DummyScheduler( - scheduler_capacity, - kv_cache_manager if kv_cache_manager is not None else None) - else: - capacity_scheduler = BindCapacityScheduler( - scheduler_capacity, - kv_cache_manager.impl if kv_cache_manager is not None else None, - peft_cache_manager.impl - if peft_cache_manager is not None else None, - scheduler_config.capacity_scheduler_policy, - two_step_lookahead=mapping.has_pp()) + capacity_scheduler = BindCapacityScheduler( + scheduler_capacity, + kv_cache_manager.impl if kv_cache_manager is not None else None, + peft_cache_manager.impl if peft_cache_manager is not None else None, + scheduler_config.capacity_scheduler_policy, + two_step_lookahead=mapping.has_pp()) mb_scheduler = BindMicroBatchScheduler(max_batch_size, max_num_tokens, ctx_chunk_config) diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index a5548946d88..db801baddc2 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -36,6 +36,7 @@ _KVCache) from tensorrt_llm.runtime.kv_cache_manager_v2._common import GPU_LEVEL from tensorrt_llm.runtime.kv_cache_manager_v2._config import DataRole +from tensorrt_llm.runtime.kv_cache_manager_v2._exceptions import OutOfPagesError from tensorrt_llm.runtime.kv_cache_manager_v2._utils import (exact_div, typed_range) from tensorrt_llm.sampling_params import SamplingParams @@ -1471,6 +1472,8 @@ def __init__( else: self.max_attention_window_vec = [None] + self._scheduler_prepared_resources = False # Track if scheduler handled resources + if isinstance(num_kv_heads, int): self.num_kv_heads_per_layer = [ (num_kv_heads + tp_size - 1) // tp_size @@ 
-1707,6 +1710,25 @@ def get_num_free_blocks(self) -> int:
 
     @nvtx_range("prepare_resources_kv_cache_manager_v2")
     def prepare_resources(self, scheduled_batch: ScheduledRequests):
+        """
+        Prepare resources for scheduled requests.
+
+        Under the MAX_UTILIZATION policy, resources were already allocated by the
+        scheduler's prepare_resources call, so this method only resets the flag.
+        For other policies (GUARANTEED_NO_EVICT), resources are allocated here.
+        """
+        # Check whether the scheduler already prepared resources
+        # TODO: remove this flag and turn the check into an assertion once the kv_cache_v2 dummy scheduler is removed
+        if self._scheduler_prepared_resources:
+            # Reset the flag for the next scheduling round
+            self._scheduler_prepared_resources = False
+        else:
+            # Resources were not allocated by the scheduler; allocate them here (GUARANTEED_NO_EVICT path)
+            self._prepare_resources_guaranteed_no_evict(scheduled_batch)
+
+    def _prepare_resources_guaranteed_no_evict(
+            self, scheduled_batch: ScheduledRequests):
+        """Prepare resources for the GUARANTEED_NO_EVICT scheduling policy."""
         with request_context(self.is_draft, scheduled_batch):
             context_batch = scheduled_batch.context_requests
             generation_batch = scheduled_batch.generation_requests
@@ -1767,6 +1789,225 @@ def _kv_connector_should_add_sequence(self, request: LlmRequest) -> bool:
         return self.kv_connector_manager is None or self.kv_connector_manager.should_add_sequence(
             request)
 
+    def _can_evict_request(self, req: LlmRequest) -> bool:
+        """Check whether a request is eligible for eviction."""
+        if req.state == LlmRequestState.GENERATION_IN_PROGRESS:
+            return True
+        elif req.state == LlmRequestState.CONTEXT_INIT and \
+                hasattr(req, 'context_current_position') and \
+                req.context_current_position > 0:
+            return True
+        return False
+
+    def _try_evict_from_list(
+            self, request_list, current_req: LlmRequest,
+            evicted_requests: List[LlmRequest]) -> Optional[LlmRequest]:
+        """Try to evict a request from the given list (LIFO order)."""
+        for req in request_list:
+            if req == current_req:
+                continue
+
+            if req in evicted_requests:
+                continue
+
+            req_id = req.py_request_id
+
+            if req_id not in self.kv_cache_map:
+                continue
+
+            kv_cache = self.kv_cache_map[req_id]
+
+            if kv_cache.status is not _KVCache.Status.ACTIVE:
+                continue
+
+            if not self._can_evict_request(req):
+                continue
+
+            kv_cache.suspend()
+            return req
+
+        return None
+
+    def _try_evict_request_for_capacity(
+            self, current_req: LlmRequest,
+            new_generation_batch: List[LlmRequest],
+            new_context_batch: Optional[List[LlmRequest]],
+            context_requests: List[LlmRequest],
+            generation_requests: List[LlmRequest],
+            evicted_requests: List[LlmRequest]) -> Optional[LlmRequest]:
+        """
+        Try to evict requests to make room for capacity allocation.
+ + Based on LIFO (Last In First Out) eviction strategy: + - Try to evict from new_generation_batch first (most recent) + - If no candidate found, try from all scheduled requests (updated lists) + + Returns: + Evicted request or None + """ + # Try to evict from new_generation_batch first (LIFO) + evicted_request = self._try_evict_from_list( + reversed(new_generation_batch), current_req, evicted_requests) + + if evicted_request is None: + # If no candidate in generation batch, try all scheduled requests + # Use updated lists (new_context_batch) if available, otherwise use original lists + if new_context_batch is not None: + all_scheduled_requests = new_context_batch + new_generation_batch + else: + all_scheduled_requests = context_requests + generation_requests + + evicted_request = self._try_evict_from_list(all_scheduled_requests, + current_req, + evicted_requests) + + return evicted_request + + @nvtx_range("prepare_resources_for_max_utilization") + def prepare_resources_for_max_utilization( + self, context_requests: List[LlmRequest], + generation_requests: List[LlmRequest] + ) -> tuple[List[LlmRequest], List[LlmRequest]]: + """ + Allocate KV cache resources for max utilization scheduling. + Handles eviction when out of pages. + + Args: + context_requests: List of context requests to schedule + generation_requests: List of generation requests to schedule + + Returns: + Tuple of (scheduled_context, scheduled_generation) after resource allocation + """ + evicted_requests: List[LlmRequest] = [] + + # Create a ScheduledRequest object for context management + scheduled_batch = ScheduledRequests() + scheduled_batch.context_requests = list(context_requests) + scheduled_batch.generation_requests = list(generation_requests) + + with request_context(self.is_draft, scheduled_batch): + new_generation_batch: List[LlmRequest] = [] + + for req in generation_requests: + if req in evicted_requests: + continue + + # Handle missing kv_cache_map entry + if req.py_request_id not in self.kv_cache_map: + logger.warning( + f"Request {req.py_request_id} not in kv_cache_map, skipping" + ) + continue + + kv_cache = self.kv_cache_map[req.py_request_id] + + if not kv_cache.is_active: + result = kv_cache.resume( + torch.cuda.current_stream().cuda_stream) + if not result: + continue + + # Max Utilization Scheduler: Try to increase capacity for generation + # Recursively try to evict requests until we have enough capacity + max_eviction_attempts = len(generation_requests) - len( + evicted_requests) + capacity_increased = False + + for _ in range(max_eviction_attempts): + try: + kv_cache.capacity += 1 + new_generation_batch.append(req) + capacity_increased = True + break + except OutOfPagesError: + evicted = self._try_evict_request_for_capacity( + req, new_generation_batch, None, context_requests, + generation_requests, evicted_requests) + if evicted is None: + # No more requests to evict + break + if evicted in new_generation_batch: + new_generation_batch.remove(evicted) + evicted_requests.append(evicted) + + if not capacity_increased: + # Could not increase capacity even after evicting all possible requests + continue + + # Allocate KV Cache for context requests + new_context_batch: List[LlmRequest] = [] + for req in context_requests: + beam_width = req.sampling_config.beam_width + if 'cp_type' in self.mapping.cp_config and CpType.STAR == self.mapping.cp_config[ + 'cp_type']: + raise RuntimeError( + "Star attention is not supported for kv cache manager v2" + ) + else: + kv_cache = None + if req.is_first_context_chunk 
and self._kv_connector_should_add_sequence( + req): + if req.py_request_id in self.kv_cache_map: + kv_cache = self.kv_cache_map[req.py_request_id] + else: + # Last token cannot be recovered, so we don't include it in the input tokens to look up for the block that can be reused. + kv_cache = self._create_kv_cache( + req.py_request_id, req.lora_task_id, + req.get_tokens(0)[:-1] + if self.enable_block_reuse else None) + assert beam_width == 1, "Currently, KVCacheManagerV2 only supports beam width 1" + if not self.enable_block_reuse: + assert kv_cache.num_committed_tokens == 0 + kv_cache.stop_committing() + else: + req.context_current_position = kv_cache.num_committed_tokens + chunk_size = req.context_chunk_size + if req.context_current_position + req.context_chunk_size < req.prompt_len: + floored_end_position = ( + req.context_current_position + + req.context_chunk_size + ) // self.tokens_per_block * self.tokens_per_block + chunk_size = floored_end_position - req.context_current_position + + req.context_chunk_size = min( + chunk_size, + req.prompt_len - req.context_current_position) + + success = kv_cache.resume( + torch.cuda.current_stream().cuda_stream) + if not success: + continue + try: + kv_cache.capacity = req.prompt_len + new_context_batch.append(req) + except OutOfPagesError: + kv_cache.suspend() + continue + + if self.kv_connector_manager is not None: + block_ids = self.get_cache_indices(req) + self.kv_connector_manager.update_state_after_alloc( + req, block_ids) + else: + assert req.py_request_id in self.kv_cache_map, f"req.py_request_id {req.py_request_id} not in kv_cache_map" + kv_cache = self.kv_cache_map[req.py_request_id] + assert kv_cache.status is _KVCache.Status.ACTIVE, f"kv_cache {req.py_request_id} is not active" + new_context_batch.append(req) + + # Update scheduled_batch for kv_connector_manager + scheduled_batch.context_requests = new_context_batch + scheduled_batch.generation_requests = new_generation_batch + + if self.kv_connector_manager is not None: + self.kv_connector_manager.build_scheduler_output( + scheduled_batch, self) + + # Set flag to indicate scheduler handled resource preparation + self._scheduler_prepared_resources = True + + return new_context_batch, new_generation_batch + def get_kv_cache_stats(self): class KVCacheStatus: diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler.py b/tensorrt_llm/_torch/pyexecutor/scheduler.py index 6631057251f..8fdf2299ad1 100644 --- a/tensorrt_llm/_torch/pyexecutor/scheduler.py +++ b/tensorrt_llm/_torch/pyexecutor/scheduler.py @@ -230,6 +230,65 @@ def schedule_request( return scheduled_requests, scheduled_disagg_gen_init_requests, [] +class KVCacheV2MaxUtilizationScheduler(CapacityScheduler): + """ + Max Utilization scheduler for KVCacheManagerV2. + This scheduler maximizes GPU utilization by allowing request eviction/pausing. + """ + no_schedule_until_state = LlmRequestState.CONTEXT_INIT + no_schedule_after_state = LlmRequestState.GENERATION_COMPLETE + + def __init__(self, max_num_requests: int, kv_cache_manager): + super(KVCacheV2MaxUtilizationScheduler, self).__init__() + self.max_num_requests = max_num_requests + self.kv_cache_manager = kv_cache_manager + + def schedule_request( + self, active_requests: RequestList + ) -> tuple[list[LlmRequest], list[LlmRequest], list[LlmRequest]]: + """ + Schedule requests with max utilization policy. + Note: This is a simplified scheduler that delegates resource preparation + to the prepare_resources method. 
+ """ + scheduled_requests = [] + scheduled_disagg_gen_init_requests = [] + + for request in active_requests: + req_state = request.state + # if request cannot be scheduled yet or request should no longer be scheduled, skip + if not req_state == LlmRequestState.DISAGG_GENERATION_INIT and ( + req_state.value < self.no_schedule_until_state.value + or req_state.value >= self.no_schedule_after_state.value): + continue + + if len(scheduled_requests) >= self.max_num_requests: + break + + if req_state == LlmRequestState.DISAGG_GENERATION_INIT: + scheduled_disagg_gen_init_requests.append(request) + else: + scheduled_requests.append(request) + + return scheduled_requests, scheduled_disagg_gen_init_requests, [] + + def prepare_resources(self, context_requests: RequestList, + generation_requests: RequestList): + """ + Prepare resources for max utilization scheduling. + Delegates to KVCacheManagerV2.prepare_resources_for_max_utilization(). + + Args: + context_requests: List of context requests to schedule + generation_requests: List of generation requests to schedule + + Returns: + Tuple of (new_context_batch, new_generation_batch) after resource allocation + """ + return self.kv_cache_manager.prepare_resources_for_max_utilization( + context_requests, generation_requests) + + class MicroBatchScheduler(ABC): @abstractmethod @@ -935,6 +994,76 @@ def _try_scheduling_request( return True +class KVCacheV2DummyPolicy(SchedulerPolicyBase): + """ + Policy wrapper for KVCacheV2DummyScheduler. + Delegates to KVCacheV2DummyScheduler for scheduling logic. + """ + + def __init__(self): + self.delegate_scheduler = None + + def schedule( + self, scheduler: 'PyCapacityScheduler', + active_requests: RequestList) -> tuple[RequestList, RequestList]: + # Lazy initialization of delegate scheduler + if self.delegate_scheduler is None: + self.delegate_scheduler = KVCacheV2DummyScheduler( + max_num_requests=scheduler.max_num_requests, + kv_cache_manager=scheduler.kv_cache_manager) + + # Delegate to KVCacheV2DummyScheduler + scheduled_requests, scheduled_disagg_gen_init_requests, paused_requests = \ + self.delegate_scheduler.schedule_request(active_requests) + + # Combine scheduled and disagg requests (PyCapacityScheduler will classify them later) + all_scheduled = scheduled_requests + scheduled_disagg_gen_init_requests + + return all_scheduled, paused_requests + + +class KVCacheV2MaxUtilizationPolicy(SchedulerPolicyBase): + """ + Policy wrapper for KVCacheV2MaxUtilizationScheduler. + Delegates to KVCacheV2MaxUtilizationScheduler for scheduling logic. 
+ """ + + def __init__(self): + self.delegate_scheduler = None + + def schedule( + self, scheduler: 'PyCapacityScheduler', + active_requests: RequestList) -> tuple[RequestList, RequestList]: + # Lazy initialization of delegate scheduler + if self.delegate_scheduler is None: + self.delegate_scheduler = KVCacheV2MaxUtilizationScheduler( + max_num_requests=scheduler.max_num_requests, + kv_cache_manager=scheduler.kv_cache_manager) + + # Delegate to KVCacheV2MaxUtilizationScheduler + scheduled_requests, scheduled_disagg_gen_init_requests, paused_requests = \ + self.delegate_scheduler.schedule_request(active_requests) + + # Combine scheduled and disagg requests (PyCapacityScheduler will classify them later) + all_scheduled = scheduled_requests + scheduled_disagg_gen_init_requests + + return all_scheduled, paused_requests + + def prepare_resources( + self, context_requests: RequestList, + generation_requests: RequestList + ) -> tuple[RequestList, RequestList]: + """ + Prepare resources for MAX_UTILIZATION policy. + Delegates to KVCacheV2MaxUtilizationScheduler.prepare_resources(). + """ + if self.delegate_scheduler is not None and hasattr( + self.delegate_scheduler, 'prepare_resources'): + return self.delegate_scheduler.prepare_resources( + context_requests, generation_requests) + return context_requests, generation_requests + + class NoEvictScheduledBlocksManager: """ Python equivalent of C++ kv_cache_manager::NoEvictScheduledBlocksManager. @@ -1088,7 +1217,23 @@ def __init__( def _create_policy(self) -> SchedulerPolicyBase: """Create the appropriate policy based on configuration.""" - if self.kv_cache_manager is None: + # Import here to avoid circular dependency + from .resource_manager import KVCacheManagerV2 + + # Check if using KVCacheManagerV2 + is_kv_cache_v2 = isinstance(self.kv_cache_manager, KVCacheManagerV2) + + if is_kv_cache_v2: + # For KVCacheManagerV2, use specialized policies + if self.scheduler_policy == CapacitySchedulerPolicy.GUARANTEED_NO_EVICT: + return KVCacheV2DummyPolicy() + elif self.scheduler_policy == CapacitySchedulerPolicy.MAX_UTILIZATION: + return KVCacheV2MaxUtilizationPolicy() + else: + raise ValueError( + f"Unsupported scheduler policy for KVCacheManagerV2: {self.scheduler_policy}" + ) + elif self.kv_cache_manager is None: return MaxRequestsPolicy() elif self.scheduler_policy == CapacitySchedulerPolicy.MAX_UTILIZATION: return MaxUtilizationPolicy() @@ -1290,6 +1435,23 @@ def _classify_output( fitting_requests.append(req) return fitting_requests, fitting_disagg_gen_init_requests + def prepare_resources( + self, context_requests: RequestList, + generation_requests: RequestList + ) -> tuple[RequestList, RequestList]: + """ + Prepare resources for scheduled requests. + Delegates to the internal policy's prepare_resources method if it exists. 
+ + :param context_requests: List of scheduled context requests + :param generation_requests: List of scheduled generation requests + :return: Tuple of (updated context_requests, updated generation_requests) + """ + if hasattr(self._policy, 'prepare_resources'): + return self._policy.prepare_resources(context_requests, + generation_requests) + return context_requests, generation_requests + class SimpleUnifiedScheduler(RequestScheduler): @@ -1350,6 +1512,13 @@ def schedule_request(self, active_requests: RequestList, context_requests, generation_requests = \ self.micro_batch_scheduler.schedule(fitting_requests, inflight_request_ids) + # Step 3: Resource Preparation (for schedulers that need it, e.g., KVCacheV2MaxUtilization) + # This delegates to PyCapacityScheduler.prepare_resources() which delegates to the policy + # For KVCacheV2MaxUtilizationPolicy, this allocates KV cache resources + if hasattr(self.capacity_scheduler, 'prepare_resources'): + context_requests, generation_requests = \ + self.capacity_scheduler.prepare_resources(context_requests, generation_requests) + return SchedulerOutput( context_requests=context_requests, generation_requests=generation_requests, diff --git a/tests/unittest/_torch/executor/test_kv_cache_v2_scheduler.py b/tests/unittest/_torch/executor/test_kv_cache_v2_scheduler.py new file mode 100644 index 00000000000..dd3f49294cb --- /dev/null +++ b/tests/unittest/_torch/executor/test_kv_cache_v2_scheduler.py @@ -0,0 +1,125 @@ +import json +from pathlib import Path + +import pytest +from utils.llm_data import llm_models_root + +from tensorrt_llm import LLM, SamplingParams +from tensorrt_llm.llmapi import ( + CapacitySchedulerPolicy, + CudaGraphConfig, + KvCacheConfig, + SchedulerConfig, +) + + +# A test case of mmlu_llama from lm_eval +@pytest.fixture(scope="module") +def test_case(): + with open(Path(__file__).parent / "test_overlap_scheduler_input.json") as f: + return json.load(f) + + +@pytest.fixture(scope="module") +def model_path(): + return llm_models_root() / "gpt_oss/gpt-oss-20b" + + +def create_llm( + model_dir, + disable_overlap_scheduler, + sampler_type, + env_overrides=None, + kv_cache_config=None, + scheduler_config=None, +): + """Create LLM with specific overlap scheduler setting""" + pytorch_config = dict( + disable_overlap_scheduler=disable_overlap_scheduler, sampler_type=sampler_type + ) + + if kv_cache_config is None: + kv_cache_config = KvCacheConfig(enable_block_reuse=False) + + llm_kwargs = dict( + model=str(model_dir), + tensor_parallel_size=1, + trust_remote_code=True, + enable_chunked_prefill=True, + cuda_graph_config=CudaGraphConfig(), + **pytorch_config, + kv_cache_config=kv_cache_config, + max_num_tokens=128, # Only one request longer than max_num_tokens is required to test chunked prefill + env_overrides=env_overrides, + ) + + if scheduler_config is not None: + llm_kwargs["scheduler_config"] = scheduler_config + + return LLM(**llm_kwargs) + + +def test_kv_cache_v2_policy_consistency(model_path, test_case): + """ + Test that KVCacheManagerV2 produces consistent outputs between + GUARANTEED_NO_EVICT and MAX_UTILIZATION policies. 
+ """ + # Test configuration + prompts = test_case["prompts"][:2] # Use fewer prompts for faster test + max_new_tokens = 50 # Shorter for faster test + + sampling_config = SamplingParams( + max_tokens=max_new_tokens, + temperature=0.0, # Deterministic for comparison + top_p=1.0, + n=1, + use_beam_search=False, + ) + + # KVCacheConfig for V2 + kv_cache_config = KvCacheConfig( + free_gpu_memory_fraction=0.7, + dtype="auto", + use_kv_cache_manager_v2=True, + enable_block_reuse=False, + ) + + # Test with GUARANTEED_NO_EVICT + scheduler_config_no_evict = SchedulerConfig( + capacity_scheduler_policy=CapacitySchedulerPolicy.GUARANTEED_NO_EVICT + ) + + with create_llm( + model_path, + disable_overlap_scheduler=False, + sampler_type="TorchSampler", + kv_cache_config=kv_cache_config, + scheduler_config=scheduler_config_no_evict, + ) as llm: + outputs_no_evict = llm.generate(prompts, sampling_params=sampling_config) + texts_no_evict = [output.outputs[0].text for output in outputs_no_evict] + + # Test with MAX_UTILIZATION + scheduler_config_max_util = SchedulerConfig( + capacity_scheduler_policy=CapacitySchedulerPolicy.MAX_UTILIZATION + ) + + with create_llm( + model_path, + disable_overlap_scheduler=False, + sampler_type="TorchSampler", + kv_cache_config=kv_cache_config, + scheduler_config=scheduler_config_max_util, + ) as llm: + outputs_max_util = llm.generate(prompts, sampling_params=sampling_config) + texts_max_util = [output.outputs[0].text for output in outputs_max_util] + + # Verify outputs are consistent between policies + for i, (no_evict, max_util) in enumerate(zip(texts_no_evict, texts_max_util)): + assert no_evict == max_util, ( + f"Output mismatch at index {i}:\nNO_EVICT: {no_evict}\nMAX_UTIL: {max_util}" + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])
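
Note on scheduler routing: taken together, the changes in create_py_executor_instance and PyCapacityScheduler._create_policy now select the scheduler from the KV-cache-manager type, the capacity policy, and TLLM_USE_PYTHON_SCHEDULER. The sketch below is my reading of that decision written as a plain function; select_scheduler and its string labels are descriptive stand-ins only, and PEFT, draft-model and disaggregated details are omitted.

def select_scheduler(kv_cache_manager_is_v2: bool,
                     use_python_scheduler_env: bool,
                     capacity_policy: str) -> str:
    """Summarize the routing in create_py_executor_instance / _create_policy."""
    if kv_cache_manager_is_v2 or use_python_scheduler_env:
        # SimpleUnifiedScheduler wraps PyCapacityScheduler, whose policy depends
        # on the manager type and the configured capacity policy.
        if not kv_cache_manager_is_v2:
            return "SimpleUnifiedScheduler + standard policy"
        if capacity_policy == "GUARANTEED_NO_EVICT":
            return "SimpleUnifiedScheduler + KVCacheV2DummyPolicy"
        if capacity_policy == "MAX_UTILIZATION":
            return "SimpleUnifiedScheduler + KVCacheV2MaxUtilizationPolicy"
        raise ValueError(
            f"Unsupported scheduler policy for KVCacheManagerV2: {capacity_policy}")
    return "SimpleScheduler (BindCapacityScheduler + BindMicroBatchScheduler)"


assert select_scheduler(True, False, "MAX_UTILIZATION") == \
    "SimpleUnifiedScheduler + KVCacheV2MaxUtilizationPolicy"
assert select_scheduler(False, False, "GUARANTEED_NO_EVICT") == \
    "SimpleScheduler (BindCapacityScheduler + BindMicroBatchScheduler)"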
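
Note on the MAX_UTILIZATION allocation loop: the heart of the new path in resource_manager.py is a grow-or-evict loop. Each scheduled generation request tries kv_cache.capacity += 1, and on OutOfPagesError the manager suspends already-scheduled requests in LIFO order and retries. The following is a minimal, self-contained sketch of that shape; PagePool, ToyKVCache and grow_generation_batch are invented stand-ins for illustration, not the real TensorRT-LLM classes or their APIs.

from dataclasses import dataclass
from typing import List, Optional


class OutOfPages(Exception):
    """Raised when the shared page pool cannot back a capacity increase."""


@dataclass
class PagePool:
    free_pages: int
    tokens_per_page: int = 4

    def reserve(self, pages: int) -> None:
        if pages > self.free_pages:
            raise OutOfPages()
        self.free_pages -= pages

    def release(self, pages: int) -> None:
        self.free_pages += pages


@dataclass(eq=False)
class ToyKVCache:
    pool: PagePool
    capacity: int = 0   # tokens currently reserved for this request
    pages: int = 0      # pages currently held from the pool

    def grow(self, tokens: int = 1) -> None:
        """Counterpart of `kv_cache.capacity += 1`; may raise OutOfPages."""
        needed = -(-(self.capacity + tokens) // self.pool.tokens_per_page)
        if needed > self.pages:
            self.pool.reserve(needed - self.pages)
            self.pages = needed
        self.capacity += tokens

    def suspend(self) -> None:
        """Counterpart of `kv_cache.suspend()`: release pages so others can grow."""
        self.pool.release(self.pages)
        self.pages = 0
        self.capacity = 0


def grow_generation_batch(requests: List[ToyKVCache]) -> List[ToyKVCache]:
    """Grow every generation request by one token, suspending already-scheduled
    requests in LIFO order when the pool runs dry (the same shape as the loop in
    prepare_resources_for_max_utilization)."""
    scheduled: List[ToyKVCache] = []
    evicted: List[ToyKVCache] = []
    for kv in requests:
        if kv in evicted:
            continue
        for _ in range(len(requests)):      # bounded, like max_eviction_attempts
            try:
                kv.grow(1)
                scheduled.append(kv)
                break
            except OutOfPages:
                victim: Optional[ToyKVCache] = next(
                    (v for v in reversed(scheduled) if v is not kv), None)
                if victim is None:
                    break                   # nothing left to evict; skip this request
                victim.suspend()
                scheduled.remove(victim)
                evicted.append(victim)
    return scheduled


# One free page left, two requests that each already hold one full page:
pool = PagePool(free_pages=1, tokens_per_page=4)
first = ToyKVCache(pool, capacity=4, pages=1)
second = ToyKVCache(pool, capacity=4, pages=1)
survivors = grow_generation_batch([first, second])
assert survivors == [second] and first.capacity == 0   # `first` was suspended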
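
Note on context chunk sizing: the chunking arithmetic added for MAX_UTILIZATION context requests floors every intermediate chunk so it ends on a tokens_per_block boundary; only the final chunk of a prompt keeps a ragged tail. A small worked example of that formula follows; floor_chunk_to_block is a hypothetical free-function transcription of the in-method logic, and the block size and prompt length are made up.

def floor_chunk_to_block(current_position: int, chunk_size: int,
                         prompt_len: int, tokens_per_block: int) -> int:
    """Mirror the chunk-size flooring in prepare_resources_for_max_utilization:
    intermediate chunks are shrunk to end on a block boundary; the last chunk
    simply takes whatever remains of the prompt."""
    if current_position + chunk_size < prompt_len:
        floored_end = ((current_position + chunk_size)
                       // tokens_per_block) * tokens_per_block
        chunk_size = floored_end - current_position
    return min(chunk_size, prompt_len - current_position)


# tokens_per_block=32, prompt of 200 tokens, nominal chunk of 100 tokens:
assert floor_chunk_to_block(0, 100, 200, 32) == 96    # 0..96 ends on a block boundary
assert floor_chunk_to_block(96, 100, 200, 32) == 96   # 96..192 ends on a block boundary
assert floor_chunk_to_block(192, 100, 200, 32) == 8   # final chunk keeps the ragged tail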
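
Note on the scheduler/manager hand-off: under MAX_UTILIZATION, SimpleUnifiedScheduler.schedule_request ends up calling KVCacheManagerV2.prepare_resources_for_max_utilization (via PyCapacityScheduler.prepare_resources and KVCacheV2MaxUtilizationPolicy), which sets _scheduler_prepared_resources so the later KVCacheManagerV2.prepare_resources call is a no-op for that iteration, while GUARANTEED_NO_EVICT still allocates there. A stripped-down sketch of just that flag handshake; ManagerV2Stub is not the real KVCacheManagerV2 interface.

class ManagerV2Stub:
    """Only the _scheduler_prepared_resources handshake, nothing else."""

    def __init__(self) -> None:
        self._scheduler_prepared_resources = False
        self.no_evict_calls = 0

    def prepare_resources_for_max_utilization(self, ctx, gen):
        # ... allocate capacity / evict as in the real method ...
        self._scheduler_prepared_resources = True     # remember the work was done
        return ctx, gen

    def prepare_resources(self, scheduled_batch) -> None:
        if self._scheduler_prepared_resources:
            self._scheduler_prepared_resources = False   # consume the flag
            return                                       # MAX_UTILIZATION already allocated
        self.no_evict_calls += 1                         # GUARANTEED_NO_EVICT fallback


manager = ManagerV2Stub()

# MAX_UTILIZATION iteration: the scheduler allocates during scheduling ...
ctx, gen = manager.prepare_resources_for_max_utilization([], [])
manager.prepare_resources(scheduled_batch=None)
assert manager.no_evict_calls == 0        # ... so the later call was a no-op

# GUARANTEED_NO_EVICT iteration: nobody set the flag, so allocation happens here
manager.prepare_resources(scheduled_batch=None)
assert manager.no_evict_calls == 1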