
Commit d41042e

sufubao authored and committed
feat: add R3 routing support for LLM inference
Add routing capture and management infrastructure to support R3-style request routing across model inference backends. Includes routing manager, request/batch extensions, API endpoint additions, and backend integration for deepseek2, mixtral, qwen3_moe, llama, and gpt_oss models. clean code
1 parent bbdc7ba commit d41042e

34 files changed: +527 −32 lines changed


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -7,3 +7,4 @@ dist
 .vscode
 tmp/
 requirements-musa.txt
+CLAUDE.md

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py

Lines changed: 5 additions & 0 deletions
@@ -33,6 +33,7 @@ def __init__(
         num_fused_shared_experts: int = 0,
         layer_num: int = 0,
         network_config: Dict[str, Any] = None,
+        moe_layer_index: int = 0,
     ) -> None:
         super().__init__(data_type=data_type)
         self.w1_weight_name = gate_proj_name
@@ -50,6 +51,7 @@ def __init__(
         self.enable_ep_moe = get_env_start_args().enable_ep_moe
         self.n_routed_experts = n_routed_experts
         self.num_fused_shared_experts = num_fused_shared_experts
+        self.moe_layer_index = moe_layer_index
         self._init_config(network_config)
         self._init_redundancy_expert_params()
         self._init_parallel_params()
@@ -130,6 +132,7 @@ def experts(
         topk_group: int,
         num_expert_group: int,
         is_prefill: Optional[bool] = None,
+        microbatch_index: int = 0,
     ) -> torch.Tensor:
         """Backward compatible method that routes to platform-specific implementation."""
         return self.fuse_moe_impl(
@@ -145,6 +148,8 @@ def experts(
             topk_group=topk_group,
             num_expert_group=num_expert_group,
             is_prefill=is_prefill,
+            moe_layer_index=self.moe_layer_index,
+            microbatch_index=microbatch_index,
         )

     def low_latency_dispatch(

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py

Lines changed: 8 additions & 0 deletions
@@ -8,6 +8,7 @@
 from lightllm.common.quantization import Quantcfg
 from lightllm.common.quantization.quantize_method import QuantizationMethod
 from lightllm.utils.log_utils import init_logger
+from lightllm.common.basemodel.routing_manager import g_routing_capture_manager

 logger = init_logger(__name__)

@@ -46,6 +47,7 @@ def __init__(
         num_fused_shared_experts: int = 0,
         layer_num: int = 0,
         network_config: Dict[str, Any] = None,
+        moe_layer_index: int = 0,
     ) -> None:
         network_config["norm_topk_prob"] = None
         super().__init__(
@@ -62,6 +64,7 @@ def __init__(
             num_fused_shared_experts=num_fused_shared_experts,
             layer_num=layer_num,
             network_config=network_config,
+            moe_layer_index=moe_layer_index,
         )

         self.hidden_size = network_config["hidden_size"]
@@ -144,10 +147,15 @@ def experts(
         topk_group: int,
         num_expert_group: int,
         is_prefill: Optional[bool] = None,
+        microbatch_index: int = 0,
     ):

         topk_weights, topk_ids = self._router(router_logits, top_k)

+        # Rollout router replay
+        if g_routing_capture_manager is not None:
+            g_routing_capture_manager.capture(self.moe_layer_index, topk_ids, microbatch_index)
+
         w1, w1_scale = self.w1
         w2, w2_scale = self.w2
         use_fp8_w8a8 = self.quant_method is not None

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py

Lines changed: 2 additions & 0 deletions
@@ -62,5 +62,7 @@ def __call__(
         topk_group: int,
         num_expert_group: int,
         is_prefill: Optional[bool] = None,
+        moe_layer_index: Optional[int] = None,
+        microbatch_index: int = 0,
     ) -> torch.Tensor:
         pass

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py

Lines changed: 7 additions & 0 deletions
@@ -3,6 +3,7 @@
 from lightllm.common.quantization.no_quant import WeightPack
 from lightllm.common.quantization.quantize_method import QuantizationMethod
 from .base_impl import FuseMoeBaseImpl
+from lightllm.common.basemodel.routing_manager import g_routing_capture_manager


 class FuseMoeTriton(FuseMoeBaseImpl):
@@ -124,6 +125,8 @@ def __call__(
         topk_group: int,
         num_expert_group: int,
         is_prefill: Optional[bool] = None,
+        moe_layer_index: Optional[int] = None,
+        microbatch_index: int = 0,
     ):
         topk_weights, topk_ids = self._select_experts(
             input_tensor=input_tensor,
@@ -136,6 +139,10 @@ def __call__(
             num_expert_group=num_expert_group,
             scoring_func=scoring_func,
         )
+
+        if g_routing_capture_manager is not None and moe_layer_index is not None:
+            g_routing_capture_manager.capture(moe_layer_index, topk_ids, microbatch_index)
+
         output = self._fused_experts(
             input_tensor=input_tensor,
             w13=w13,

lightllm/common/basemodel/routing_manager.py

Lines changed: 155 additions & 0 deletions
@@ -0,0 +1,155 @@
+import torch
+import numpy as np
+from typing import Optional
+from lightllm.utils.log_utils import init_logger
+from lightllm.utils.dist_utils import get_current_rank_in_dp
+from lightllm.server.router.dynamic_prompt.shared_arr import SharedArray
+from lightllm.utils.envs_utils import get_unique_server_name
+
+logger = init_logger(__name__)
+
+
+def get_routing_config_shm() -> SharedArray:
+    """Get shared memory for MoE routing config: arr[0]=num_moe_layers, arr[1]=topk."""
+    service_name = get_unique_server_name()
+    return SharedArray(f"{service_name}_routing_config", shape=(2,), dtype=np.int32)
+
+
+class RoutingCaptureManager:
+    """Captures MoE routing decisions"""
+
+    def __init__(
+        self,
+        num_moe_layers: int,
+        topk: int,
+        num_experts: int,
+        batch_max_tokens: int,
+        kv_cache_size: int,
+        enable_overlap: bool = False,
+    ):
+        self.num_moe_layers = num_moe_layers
+        self.topk = topk
+        self.num_experts = num_experts
+        self.batch_max_tokens = batch_max_tokens
+        self.kv_cache_size = kv_cache_size
+
+        self.dtype = torch.int8 if num_experts <= 127 else torch.int16
+        dtype_bytes = 1 if self.dtype == torch.int8 else 2
+
+        self.num_slots = 2 if enable_overlap else 1
+
+        gpu_buffer_size = self.num_slots * num_moe_layers * batch_max_tokens * topk * dtype_bytes
+        self.gpu_buffer = torch.zeros(
+            (self.num_slots, num_moe_layers, batch_max_tokens, topk),
+            dtype=self.dtype,
+            device="cuda",
+        )
+
+        cpu_buffer_size = num_moe_layers * kv_cache_size * topk * dtype_bytes
+        self.cpu_buffer = torch.zeros(
+            (num_moe_layers, kv_cache_size, topk),
+            dtype=self.dtype,
+            device="cpu",
+            pin_memory=True,
+        )
+
+        self.flush_streams = [torch.cuda.Stream() for _ in range(self.num_slots)]
+        self.flush_events = [torch.cuda.Event() for _ in range(self.num_slots)]
+
+        dtype_name = "int8" if self.dtype == torch.int8 else "int16"
+        logger.info(
+            f"RoutingCaptureManager initialized: {num_moe_layers} MoE layers, topk={topk}, "
+            f"slots={self.num_slots}, GPU={gpu_buffer_size / 1024 / 1024:.2f}MB, "
+            f"CPU={cpu_buffer_size / 1024 / 1024:.2f}MB, dtype={dtype_name}"
+        )
+
+    def capture(self, moe_layer_index: int, topk_ids: torch.Tensor, microbatch_index: int = 0) -> None:
+        num_tokens = topk_ids.shape[0]
+        self.gpu_buffer[microbatch_index, moe_layer_index, :num_tokens, :] = topk_ids.to(self.dtype)
+
+    def flush_to_cpu_async(self, mem_indexes: torch.Tensor, microbatch_index: int) -> None:
+        num_tokens = mem_indexes.shape[0]
+        if num_tokens == 0:
+            return
+
+        slot = microbatch_index % self.num_slots
+        stream = self.flush_streams[slot]
+        event = self.flush_events[slot]
+
+        stream.wait_stream(torch.cuda.current_stream())
+
+        with torch.cuda.stream(stream):
+            cpu_indexes = mem_indexes.cpu()
+            self.cpu_buffer[:, cpu_indexes, :] = self.gpu_buffer[slot, :, :num_tokens, :].cpu()
+            event.record()
+
+    def sync_events(self) -> None:
+        """Synchronize all flush events. Call once before batch extraction."""
+        for event in self.flush_events:
+            event.synchronize()
+
+    def extract_for_request(self, mem_indexes: torch.Tensor) -> np.ndarray:
+        self.sync_events()
+        return self.cpu_buffer[:, mem_indexes, :].numpy()
+
+    def extract_for_request_no_sync(self, mem_indexes: torch.Tensor) -> np.ndarray:
+        return self.cpu_buffer[:, mem_indexes, :].numpy()
+
+
+g_routing_capture_manager: Optional[RoutingCaptureManager] = None
+
+
+def create_routing_capture_manager(
+    num_moe_layers: int,
+    topk: int,
+    num_experts: int,
+    batch_max_tokens: int,
+    kv_cache_size: int,
+    enable_overlap: bool = False,
+) -> None:
+    global g_routing_capture_manager
+    assert g_routing_capture_manager is None, "RoutingCaptureManager already exists"
+    g_routing_capture_manager = RoutingCaptureManager(
+        num_moe_layers=num_moe_layers,
+        topk=topk,
+        num_experts=num_experts,
+        batch_max_tokens=batch_max_tokens,
+        kv_cache_size=kv_cache_size,
+        enable_overlap=enable_overlap,
+    )
+
+
+def init_routing_capture(model, num_moe_layers: int) -> None:
+    if get_current_rank_in_dp() != 0:
+        # Skipping routing capture initialization on non-zero rank
+        return
+
+    if num_moe_layers == 0:
+        logger.warning(
+            "enable_return_routed_experts is set but no MoE layers found. " "Routing capture will not be enabled."
+        )
+        return
+
+    num_experts = model.config.get("n_routed_experts", model.config.get("num_experts", 0))
+    topk = model.config.get("num_experts_per_tok", 0)
+    assert num_experts > 0 and topk > 0
+    enable_overlap = getattr(model.args, "enable_decode_microbatch_overlap", False)
+
+    logger.info(
+        f"Initializing routing capture: num_moe_layers={num_moe_layers}, "
+        f"topk={topk}, num_experts={num_experts}, enable_overlap={enable_overlap}"
+    )
+
+    create_routing_capture_manager(
+        num_moe_layers=num_moe_layers,
+        topk=topk,
+        num_experts=num_experts,
+        batch_max_tokens=model.max_total_token_num,
+        kv_cache_size=model.mem_manager.size + 1,
+        enable_overlap=enable_overlap,
+    )
+
+    shm = get_routing_config_shm()
+    shm.arr[0] = num_moe_layers
+    shm.arr[1] = topk
+    logger.info(f"Shared routing config set: num_moe_layers={num_moe_layers}, topk={topk}")

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 2 additions & 0 deletions
@@ -312,6 +312,7 @@ def _moe_ffn(
             use_grouped_topk=self.n_group,
             topk_group=self.topk_group,
             num_expert_group=self.n_group,
+            microbatch_index=infer_state.microbatch_index,
         )

         if self.n_shared_experts is not None and layer_weight.num_fused_shared_experts == 0:
@@ -339,6 +340,7 @@ def _moe_ffn_edp(
             topk_group=self.topk_group,
             num_expert_group=self.n_group,
             is_prefill=infer_state.is_prefill,
+            microbatch_index=infer_state.microbatch_index,
         )

         if self.n_shared_experts is not None:

lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py

Lines changed: 4 additions & 0 deletions
@@ -242,6 +242,9 @@ def _init_moe(self):
         # When == 0, there are no fused shared experts; shared experts are loaded and run separately.
         if self.num_fused_shared_experts == 0:
             self._load_mlp(f"model.layers.{self.layer_num_}.mlp.shared_experts", is_shared_experts=True)
+        first_moe = self.network_config_["first_k_dense_replace"]
+        freq = self.network_config_.get("moe_layer_freq", 1)
+        moe_layer_index = (self.layer_num_ - first_moe) // freq
         self.experts = FusedMoeWeight(
             gate_proj_name="gate_proj",
             down_proj_name="down_proj",
@@ -256,6 +259,7 @@ def _init_moe(self):
             num_fused_shared_experts=self.num_fused_shared_experts,
             layer_num=self.layer_num_,
             network_config=self.network_config_,
+            moe_layer_index=moe_layer_index,
         )

     def _init_ffn(self):
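
To make the new moe_layer_index mapping concrete (hypothetical config values, for illustration only): with first_k_dense_replace = 3 and moe_layer_freq = 1, transformer layers 3, 4, 5, ... map to MoE layer indices 0, 1, 2, ..., which is how the RoutingCaptureManager buffers are indexed.

    # Hypothetical values, for illustration only.
    first_moe, freq = 3, 1
    print([(layer_num - first_moe) // freq for layer_num in (3, 4, 5)])  # -> [0, 1, 2]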

lightllm/models/deepseek2/model.py

Lines changed: 4 additions & 0 deletions
@@ -6,6 +6,7 @@
 from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
 from lightllm.models.llama.model import LlamaTpPartModel
 from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class
+from lightllm.common.basemodel.routing_manager import init_routing_capture
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args, get_added_mtp_kv_layer_num
 from lightllm.distributed.communication_op import dist_group_manager
@@ -49,6 +50,9 @@ def _init_some_value(self):
     def _init_custom(self):
         self._init_to_get_yarn_rotary()
         dist_group_manager.new_deepep_group(self.config["n_routed_experts"], self.config["hidden_size"])
+        if self.args.enable_return_routed_experts:
+            num_moe_layers = sum(1 for w in self.trans_layers_weight if w.is_moe)
+            init_routing_capture(self, num_moe_layers)

     def _verify_params(self):
         return super()._verify_params()

lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py

Lines changed: 1 addition & 0 deletions
@@ -51,6 +51,7 @@ def _ffn(self, input, infer_state, layer_weight: GptOssTransformerLayerWeight) -
             use_grouped_topk=False,
             topk_group=None,
             num_expert_group=None,
+            microbatch_index=infer_state.microbatch_index,
         )
         return hidden_states.view(num_tokens, hidden_dim)

