
Commit da967d0

[TRTLLM-10334] [feat] Support overlap scheduler for disagg ctx instances (#10755)
Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
1 parent 58dc4be commit da967d0

12 files changed (+46, -44 lines)

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 3 additions & 9 deletions
@@ -104,7 +104,7 @@ class BatchState:
 
     iter_start_time: float = 0
     iter_stats: IterationStats = None
-    ctx_transmission_reqs: list[LlmRequest] = None
+    all_requests: list[LlmRequest] = None
 
 
 @dataclasses.dataclass
@@ -1804,6 +1804,7 @@ def _executor_loop_overlap(self):
 
         if self.previous_batch is not None and should_process_previous_batch:
             self._update_requests(self.previous_batch.sample_state)
+            self._send_kv_async(self.previous_batch.all_requests)
 
         if self.drafter is not None and self.use_spec_decode and should_process_previous_batch:
             # Cleanup previous draft resources used in the draft model
@@ -1829,9 +1830,6 @@
 
         self._update_request_states(scheduled_batch)
 
-        ctx_transmission_reqs = self._send_kv_async(
-            scheduled_batch.all_requests())
-
         if self.previous_batch is not None and should_process_previous_batch:
             self._process_previous_batch()
         else:
@@ -1846,7 +1844,6 @@
                 sample_state=sample_state,
                 iter_start_time=iter_start_time,
                 iter_stats=iter_stats,
-                ctx_transmission_reqs=ctx_transmission_reqs)
+                all_requests=scheduled_batch.all_requests())
         elif not can_queue_this_rank:
             # If the batch is empty on this rank, we need to clear the previous batch.
             self.previous_batch = None
@@ -1949,10 +1947,6 @@ def _accept_draft_tokens(
         return result_tensors, num_accepted_tokens
 
     def _process_previous_batch(self):
-        if self.kv_cache_transceiver and self.previous_batch.ctx_transmission_reqs:
-            for req in self.previous_batch.ctx_transmission_reqs:
-                req.state = LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS
-
         self._handle_canceled_requests()
         finished_requests = self._handle_responses()
         scheduled_requests = self.previous_batch.sample_state.scheduled_requests
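
Net effect of the py_executor.py change: the asynchronous KV-cache send for a context batch is no longer issued in the same iteration that schedules it. The scheduled requests are instead carried on the new BatchState.all_requests field, and the send is kicked off in the next iteration, right after the previous batch's sampler outputs have been consumed by _update_requests; the bookkeeping in _process_previous_batch that flipped ctx_transmission_reqs into DISAGG_CONTEXT_TRANS_IN_PROGRESS is dropped along with that field. A minimal, self-contained sketch of that control flow (the Request/BatchState stubs and the send_kv_async/overlap_loop names below are illustrative placeholders, not TensorRT-LLM APIs):

import dataclasses
from typing import List, Optional


@dataclasses.dataclass
class Request:
    rid: int


@dataclasses.dataclass
class BatchState:
    # Requests of the batch, kept so the KV send can happen one iteration later.
    all_requests: Optional[List[Request]] = None


def send_kv_async(requests: List[Request]) -> None:
    # Placeholder: start transmitting KV cache for finished context requests.
    print("async KV send for", [r.rid for r in requests])


def overlap_loop(batches: List[List[Request]]) -> None:
    previous_batch: Optional[BatchState] = None
    for scheduled in batches:
        if previous_batch is not None:
            # First the previous iteration's sampling results are consumed
            # (update_requests), then the deferred KV send is issued.
            send_kv_async(previous_batch.all_requests)
        # Forward pass and sampling for the current batch would overlap here (omitted).
        previous_batch = BatchState(all_requests=list(scheduled))


overlap_loop([[Request(0), Request(1)], [Request(2)]])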

tensorrt_llm/executor/base_worker.py

Lines changed: 7 additions & 9 deletions
@@ -448,15 +448,13 @@ def _enqueue_request(self,
             context_phase_params = request.disaggregated_params.get_context_phase_params(
             )
 
-            if self._is_pytorch_backend:
-                if not self.llm_args.disable_overlap_scheduler:
-                    is_disaggregated = self.engine.kv_cache_transceiver is not None
-                    if is_disaggregated and (
-                            request_type
-                            == tllm.RequestType.REQUEST_TYPE_CONTEXT_ONLY):
-                        raise ValueError(
-                            "Context only requests are not supported in pytorch backend when overlap is enabled."
-                        )
+            if self._is_pytorch_backend and not self.llm_args.disable_overlap_scheduler \
+                and self.llm_args.kv_cache_config.enable_block_reuse \
+                and self.engine.kv_cache_transceiver is not None \
+                and request_type == tllm.RequestType.REQUEST_TYPE_CONTEXT_ONLY:
+                raise ValueError(
+                    "Context only requests are not supported in pytorch backend when overlap is enabled with block reuse."
+                )
 
         assert request.id is not None
 
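
Net effect of the base_worker.py change: a context-only request on the PyTorch backend is now rejected only when the overlap scheduler, KV-cache block reuse, and a KV-cache transceiver are all active at once; overlap scheduling alone no longer disqualifies disaggregated context instances. A standalone sketch of the relaxed guard (a hypothetical free function taking boolean flags, not the actual _enqueue_request signature):

def reject_context_only(is_pytorch_backend: bool, disable_overlap_scheduler: bool,
                        enable_block_reuse: bool, has_kv_cache_transceiver: bool,
                        is_context_only: bool) -> None:
    # Mirrors the combined condition in the diff above.
    if (is_pytorch_backend and not disable_overlap_scheduler
            and enable_block_reuse and has_kv_cache_transceiver
            and is_context_only):
        raise ValueError(
            "Context only requests are not supported in pytorch backend "
            "when overlap is enabled with block reuse.")


# Newly accepted combination: overlap scheduler on, block reuse off.
reject_context_only(True, False, False, True, True)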

tests/integration/defs/.test_durations

Lines changed: 2 additions & 2 deletions
@@ -144,8 +144,8 @@
   "accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_nixl_backend": 71.2399792142678,
   "accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False]": 286.7775873204227537,
   "accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True]": 286.6778334858827293,
-  "accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False]": 781.7928658421151,
-  "accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True]": 270.3750694899354,
+  "accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False-True]": 781.7928658421151,
+  "accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True-True]": 270.3750694899354,
   "accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=1-ctx_pp=2]": 195.4896494857967,
   "accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=1-ctx_pp=4]": 205.93911361903884,
   "accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=2-ctx_pp=2]": 188.56422709790058,

tests/integration/defs/accuracy/test_disaggregated_serving.py

Lines changed: 10 additions & 4 deletions
@@ -524,20 +524,26 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
 
     @skip_pre_hopper
     @pytest.mark.skip_less_device(2)
-    @pytest.mark.parametrize("disable_overlap_scheduler", [False, True])
+    @pytest.mark.parametrize("ctx_disable_overlap_scheduler", [False, True])
+    @pytest.mark.parametrize("gen_disable_overlap_scheduler", [False, True])
     @pytest.mark.parametrize("ctx_enable_block_reuse", [True, False])
     @pytest.mark.parametrize("gen_enable_block_reuse", [True, False])
-    def test_auto_dtype(self, disable_overlap_scheduler, ctx_enable_block_reuse,
+    def test_auto_dtype(self, ctx_disable_overlap_scheduler,
+                        gen_disable_overlap_scheduler, ctx_enable_block_reuse,
                         gen_enable_block_reuse):
+        if ctx_enable_block_reuse and not ctx_disable_overlap_scheduler:
+            pytest.skip(
+                "Skip this test because overlap scheduler is not supported with block reuse for context server"
+            )
         ctx_server_config = {
-            "disable_overlap_scheduler": True,
+            "disable_overlap_scheduler": ctx_disable_overlap_scheduler,
             "kv_cache_config": {
                 "enable_block_reuse": ctx_enable_block_reuse
             }
         }
         ctx_server_config["cache_transceiver_config"] = {"backend": "DEFAULT"}
         gen_server_config = {
-            "disable_overlap_scheduler": disable_overlap_scheduler,
+            "disable_overlap_scheduler": gen_disable_overlap_scheduler,
             "kv_cache_config": {
                 "enable_block_reuse": gen_enable_block_reuse
             }
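
The combination this test newly exercises is a context server that keeps the overlap scheduler enabled while block reuse is off; the reverse pairing (block reuse on with the overlap scheduler on) is skipped as unsupported. A sketch of the server config dicts the test builds for one such parametrization (concrete values chosen for illustration; only fields visible in the diff above are included):

# Context server: overlap scheduler enabled (the new case), block reuse disabled.
ctx_server_config = {
    "disable_overlap_scheduler": False,
    "kv_cache_config": {"enable_block_reuse": False},
    "cache_transceiver_config": {"backend": "DEFAULT"},
}

# Generation server: overlap scheduler and block reuse toggled independently.
gen_server_config = {
    "disable_overlap_scheduler": False,
    "kv_cache_config": {"enable_block_reuse": True},
}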

tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp1cp2_deepseek_v3_lite_bf16_tllm_gen.yaml

Lines changed: 0 additions & 1 deletion
@@ -3,7 +3,6 @@ port: 8000
 model: DeepSeek-V3-Lite/bf16
 free_gpu_memory_fraction: 0.25
 backend: "pytorch"
-disable_overlap_scheduler: True
 cuda_graph_config: null
 context_servers:
   num_instances: 1

tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp4_gentp4_deepseek_r1_v2_fp4_tllm.yaml

Lines changed: 0 additions & 2 deletions
@@ -18,7 +18,6 @@ context_servers:
     enable_block_reuse: false
     free_gpu_memory_fraction: 0.80
   dtype: fp8
-  disable_overlap_scheduler: true
   moe_config:
     backend: TRTLLM
   cuda_graph_config: null
@@ -44,7 +43,6 @@ generation_servers:
     enable_block_reuse: false
     free_gpu_memory_fraction: 0.80
   dtype: fp8
-  disable_overlap_scheduler: true
   moe_config:
     backend: TRTLLM
   cuda_graph_config:

tests/integration/defs/disaggregated/test_configs/disagg_config_deepseek_v3_lite_empty_batch.yaml

Lines changed: 0 additions & 1 deletion
@@ -16,7 +16,6 @@ context_servers:
   pipeline_parallel_size: 1
   print_iter_log: true
   cuda_graph_config: null
-  disable_overlap_scheduler: true
   kv_cache_config:
     enable_block_reuse: false
     free_gpu_memory_fraction: 0.05

tests/integration/defs/disaggregated/test_configs/disagg_config_overlap.yaml

Lines changed: 2 additions & 2 deletions
@@ -12,9 +12,9 @@ context_servers:
   tensor_parallel_size: 1
   pipeline_parallel_size: 1
   kv_cache_config:
+    enable_block_reuse: False
     free_gpu_memory_fraction: 0.2
     enable_partial_reuse: False
-  disable_overlap_scheduler: True
   cache_transceiver_config:
     backend: DEFAULT
   urls:
@@ -27,9 +27,9 @@ generation_servers:
   max_num_tokens: 4096
   max_seq_len: 4096
   kv_cache_config:
+    enable_block_reuse: False
     free_gpu_memory_fraction: 0.2
     enable_partial_reuse: False
-  disable_overlap_scheduler: False
   cache_transceiver_config:
     backend: DEFAULT
   urls:

tests/integration/test_lists/qa/llm_function_core.txt

Lines changed: 2 additions & 2 deletions
@@ -304,8 +304,8 @@ accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[F
 accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True]
 accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[True]
 accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[False]
-accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False]
-accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True]
+accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False-True]
+accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True-True]
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=True-overlap_scheduler=True]
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=False-overlap_scheduler=False]

tests/integration/test_lists/qa/llm_function_core_sanity.txt

Lines changed: 2 additions & 2 deletions
@@ -162,8 +162,8 @@ accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_nixl_backend
 accuracy/test_disaggregated_serving.py::TestDeepSeekV32Exp::test_auto_dtype[False]
 accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[False]
 accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[True]
-accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False]
-accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True]
+accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False-True]
+accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True-True]
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=1-ctx_pp=2]
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=1-ctx_pp=4]
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=2-ctx_pp=2]
