Commit 611b216

clean

1 parent 5dfcf8b commit 611b216

File tree: 8 files changed, +128 −58 lines changed


lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py

Lines changed: 1 addition & 2 deletions
@@ -13,7 +13,7 @@
 from lightllm.utils.envs_utils import get_redundancy_expert_ids, get_redundancy_expert_num, get_env_start_args
 from lightllm.utils.dist_utils import get_global_world_size, get_global_rank
 from lightllm.utils.log_utils import init_logger
-from lightllm.common.basemodel.routing_manager import g_routing_capture_manager, get_next_moe_layer_index
+from lightllm.common.basemodel.routing_manager import get_next_moe_layer_index

 logger = init_logger(__name__)

@@ -105,7 +105,6 @@ def _init_parallel_params(self):
             f"redundancy_expertids: {self.redundancy_expert_ids}"
         )
         self.local_n_routed_experts = self.n_routed_experts // self.global_world_size + self.redundancy_expert_num
-        self.split_inter_size = self.moe_intermediate_size
         n_experts_per_rank = self.n_routed_experts // self.global_world_size
         start_expert_id = self.global_rank_ * n_experts_per_rank
         self.local_expert_ids = (

lightllm/common/basemodel/routing_manager.py

Lines changed: 65 additions & 19 deletions
@@ -3,9 +3,52 @@
 from typing import Optional
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.dist_utils import get_current_rank_in_dp
+from lightllm.server.router.dynamic_prompt.shared_arr import SharedArray
+from lightllm.utils.envs_utils import get_unique_server_name

 logger = init_logger(__name__)

+
+class SharedRoutingConfig:
+    """Shared MoE routing configuration across processes."""
+
+    def __init__(self):
+        service_name = get_unique_server_name()
+        # Shape: [num_moe_layers, topk]
+        self._shm = SharedArray(f"{service_name}_routing_config", shape=(2,), dtype=np.int32)
+
+    @property
+    def num_moe_layers(self) -> int:
+        return int(self._shm.arr[0])
+
+    @num_moe_layers.setter
+    def num_moe_layers(self, value: int):
+        self._shm.arr[0] = value
+
+    @property
+    def topk(self) -> int:
+        return int(self._shm.arr[1])
+
+    @topk.setter
+    def topk(self, value: int):
+        self._shm.arr[1] = value
+
+    def is_initialized(self) -> bool:
+        return self.num_moe_layers > 0 and self.topk > 0
+
+
+# Global shared routing config (lazy initialized)
+_shared_routing_config: Optional[SharedRoutingConfig] = None
+
+
+def get_shared_routing_config() -> SharedRoutingConfig:
+    """Get or create the shared routing config."""
+    global _shared_routing_config
+    if _shared_routing_config is None:
+        _shared_routing_config = SharedRoutingConfig()
+    return _shared_routing_config
+
+
 # MoE layer counter for auto-incrementing moe_layer_index
 _moe_layer_counter: int = 0

@@ -75,12 +118,8 @@ def __init__(
         )

     def capture(self, moe_layer_index: int, topk_ids: torch.Tensor, microbatch_index: int = 0) -> None:
-        assert (
-            0 <= moe_layer_index < self.num_moe_layers
-        ), f"moe_layer_index {moe_layer_index} out of range [0, {self.num_moe_layers})"
-        slot = microbatch_index % self.num_slots
         num_tokens = topk_ids.shape[0]
-        self.gpu_buffer[slot, moe_layer_index, :num_tokens, :] = topk_ids.to(self.dtype)
+        self.gpu_buffer[microbatch_index, moe_layer_index, :num_tokens, :] = topk_ids.to(self.dtype)

     def flush_to_cpu_async(self, mem_indexes: torch.Tensor, microbatch_index: int) -> None:
         num_tokens = mem_indexes.shape[0]
@@ -98,9 +137,20 @@ def flush_to_cpu_async(self, mem_indexes: torch.Tensor, microbatch_index: int) -> None:
         self.cpu_buffer[:, cpu_indexes, :] = self.gpu_buffer[slot, :, :num_tokens, :].cpu()
         event.record()

-    def extract_for_request(self, mem_indexes: torch.Tensor) -> np.ndarray:
+    def sync_events(self) -> None:
+        """Synchronize all flush events. Call once before batch extraction."""
         for event in self.flush_events:
             event.synchronize()
+
+    def extract_for_request(self, mem_indexes: torch.Tensor) -> np.ndarray:
+        self.sync_events()
+        return self.cpu_buffer[:, mem_indexes, :].numpy()
+
+    def extract_for_request_no_sync(self, mem_indexes: torch.Tensor) -> np.ndarray:
+        """Extract routing data without synchronizing events.
+
+        Call sync_events() once before using this method in a batch.
+        """
         return self.cpu_buffer[:, mem_indexes, :].numpy()


@@ -132,8 +182,6 @@ def init_routing_capture(model) -> None:
         return

     # Only create routing capture manager on rank 0
-    # Routing decisions are identical across all TP ranks, so we only need to capture on rank 0
-    # which is the rank that communicates results back to the Router/HTTP server
     if get_current_rank_in_dp() != 0:
         logger.info("Skipping routing capture initialization on non-zero rank")
         return
@@ -145,16 +193,9 @@ def init_routing_capture(model) -> None:
         )
         return

-    n_routed_experts = model.config.get("n_routed_experts", model.config.get("num_experts", 0))
-    if n_routed_experts == 0:
-        logger.warning(
-            "enable_return_routed_experts is set but n_routed_experts=0. " "Routing capture will not be enabled."
-        )
-        return
-
-    topk = model.config.get("num_experts_per_tok", 1)
-    num_experts = n_routed_experts
-
+    num_experts = model.config.get("n_routed_experts", model.config.get("num_experts", 0))
+    topk = model.config.get("num_experts_per_tok", 0)
+    assert num_experts > 0 and topk > 0
     enable_overlap = getattr(model.args, "enable_decode_microbatch_overlap", False)

     logger.info(
@@ -167,11 +208,16 @@ def init_routing_capture(model) -> None:
         topk=topk,
         num_experts=num_experts,
         batch_max_tokens=model.max_total_token_num,
-        # Add 1 to handle potential edge case where mem_index == size
         kv_cache_size=model.mem_manager.size + 1,
         enable_overlap=enable_overlap,
     )

+    # Set shared routing config for cross-process access
+    shared_config = get_shared_routing_config()
+    shared_config.num_moe_layers = num_moe_layers
+    shared_config.topk = topk
+    logger.info(f"Shared routing config set: num_moe_layers={num_moe_layers}, topk={topk}")
+

 def flush_routing_capture(mem_indexes: torch.Tensor, microbatch_index: int = 0) -> None:
     if g_routing_capture_manager is not None:
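
The SharedRoutingConfig added above publishes num_moe_layers and topk through a tiny shared-memory array, so the HTTP server process can size and read each request's routing buffer without the per-request counters this commit removes from Req. Below is a minimal, self-contained sketch of the same publish/attach pattern using the standard library's multiprocessing.shared_memory in place of LightLLM's SharedArray; the class name, segment name, and example values are illustrative only.

# Sketch of the shared-config pattern, assuming only the standard library.
import numpy as np
from multiprocessing import shared_memory


class TinySharedRoutingConfig:
    """Two int32 slots shared across processes: [num_moe_layers, topk]."""

    def __init__(self, name: str = "demo_routing_config"):
        nbytes = 2 * np.dtype(np.int32).itemsize
        try:
            # Writer side: create the segment and zero it.
            self._shm = shared_memory.SharedMemory(name=name, create=True, size=nbytes)
            self.arr = np.ndarray((2,), dtype=np.int32, buffer=self._shm.buf)
            self.arr[:] = 0
        except FileExistsError:
            # Reader side: attach to the segment another process created.
            self._shm = shared_memory.SharedMemory(name=name, create=False)
            self.arr = np.ndarray((2,), dtype=np.int32, buffer=self._shm.buf)

    @property
    def num_moe_layers(self) -> int:
        return int(self.arr[0])

    @property
    def topk(self) -> int:
        return int(self.arr[1])

    def is_initialized(self) -> bool:
        return self.num_moe_layers > 0 and self.topk > 0


if __name__ == "__main__":
    writer = TinySharedRoutingConfig()
    writer.arr[0], writer.arr[1] = 58, 8    # model process publishes its config
    reader = TinySharedRoutingConfig()      # e.g. the HTTP server process attaches
    print(reader.num_moe_layers, reader.topk, reader.is_initialized())
    # Release the demo segment.
    del writer.arr, reader.arr
    reader._shm.close()
    writer._shm.close()
    writer._shm.unlink()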

lightllm/common/quantization/w8a8.py

Lines changed: 1 addition & 1 deletion
@@ -72,7 +72,7 @@ def quantize(self, weight: torch.Tensor, output: WeightPack) -> None:
         weight = weight.float().cuda(self.device_id_)
         scale = weight.abs().max(dim=-1)[0] / 127
         weight = weight / scale.reshape(-1, 1)
-        weight = torch.round(weight.clamp(min=-128, max=127)).to(dtype=torch.int8)
+        weight = torch.round(weight.clamp(min=-127, max=127)).to(dtype=torch.int8)
         output.weight.copy_(weight)
         output.weight_scale.copy_(scale)
         return
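
The clamp change above keeps the quantized weights in the symmetric int8 range [-127, 127]. Since the per-row scale is max(|w|) / 127, the scaled values already fall in [-127, 127], and excluding -128 guarantees the dequantized magnitude never exceeds the row's original maximum. A rough CPU-only sketch of this per-row symmetric scheme follows; the helper name and shapes are illustrative, not the repo's API.

# Per-row symmetric int8 quantization with the [-127, 127] clamp, as a sketch.
import torch


def quantize_per_row_int8(weight: torch.Tensor):
    weight = weight.float()
    scale = weight.abs().max(dim=-1)[0] / 127           # one scale per output row
    q = weight / scale.reshape(-1, 1)
    q = torch.round(q.clamp(min=-127, max=127)).to(torch.int8)
    return q, scale


if __name__ == "__main__":
    w = torch.randn(4, 16)
    q, scale = quantize_per_row_int8(w)
    w_hat = q.float() * scale.reshape(-1, 1)             # dequantize
    print((w - w_hat).abs().max())                        # small round-off error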

lightllm/server/core/objs/req.py

Lines changed: 29 additions & 16 deletions
@@ -1,6 +1,7 @@
 import os
 import math
 import ctypes
+import base64
 import numpy as np
 import time
 from .sampling_params import SamplingParams
@@ -122,9 +123,6 @@ class Req(ctypes.Structure):
         ("cpu_cache_match_page_indexes", CpuCachePageList),
         # Block size for chunked hashing
         ("cpu_cache_token_page_size", ctypes.c_int),
-        ("routing_data_num_moe_layers", ctypes.c_int),
-        ("routing_data_num_tokens", ctypes.c_int),
-        ("routing_data_topk", ctypes.c_int),
     ]

     def get_str(self):
@@ -183,10 +181,6 @@ def init(
         self.stop_str_matched = False
         self.stop_str_matched_token_index = -1

-        self.routing_data_num_moe_layers = 0
-        self.routing_data_num_tokens = 0
-        self.routing_data_topk = 0
-
         self.post_init()

         self.cpu_cache_token_page_size = get_env_start_args().cpu_cache_token_page_size
@@ -240,25 +234,21 @@ def create_routing_data_shm_array(self, num_moe_layers: int, num_tokens: int, topk: int):
         shape = (num_moe_layers, num_tokens, topk)
         self.shm_routing_data = ShmArray(name, shape, dtype=np.int32)
         self.shm_routing_data.create_shm()
-        self.routing_data_num_moe_layers = num_moe_layers
-        self.routing_data_num_tokens = num_tokens
-        self.routing_data_topk = topk
         return

-    def link_routing_data_shm_array(self):
-        if self.routing_data_num_moe_layers == 0:
+    def link_routing_data_shm_array(self, num_moe_layers: int, topk: int):
+        if num_moe_layers == 0:
             return
         service_uni_name = get_unique_server_name()
         name = f"{service_uni_name}_shm_routing_{self.index_in_shm_mem}"
-        shape = (self.routing_data_num_moe_layers, self.routing_data_num_tokens, self.routing_data_topk)
+        # num_tokens equals shm_cur_kv_len at the time of creation
+        shape = (num_moe_layers, self.shm_cur_kv_len, topk)
         self.shm_routing_data = ShmArray(name, shape, dtype=np.int32)
         self.shm_routing_data.link_shm()
         return

     def get_routing_data(self):
-        if self.routing_data_num_moe_layers == 0 or not hasattr(self, "shm_routing_data"):
-            return None
-        if self.shm_routing_data is None:
+        if not hasattr(self, "shm_routing_data") or self.shm_routing_data is None:
             return None
         return self.shm_routing_data.arr

@@ -268,6 +258,29 @@ def close_routing_data_shm_array(self):
         self.shm_routing_data = None
         return

+    def get_routing_metadata(self, num_moe_layers: int, topk: int):
+        """Safely extract routing data and format for API response.
+
+        Returns a dict with shape, dtype, and base64-encoded data, or None if unavailable.
+        """
+        if num_moe_layers == 0 or topk == 0:
+            return None
+        try:
+            self.link_routing_data_shm_array(num_moe_layers, topk)
+            routing_data = self.get_routing_data()
+            if routing_data is None:
+                return None
+            return {
+                "shape": list(routing_data.shape),
+                "dtype": str(routing_data.dtype),
+                "data": base64.b64encode(routing_data.tobytes()).decode("ascii"),
+            }
+        except Exception as e:
+            logger.warning(f"Failed to read routing data for req {self.request_id}: {e}")
+            return None
+        finally:
+            self.close_routing_data_shm_array()
+
     def get_prompt_ids(self):
         return self.shm_prompt_ids.arr[: self.input_len].tolist()

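
On the consumer side, the routed_experts dict returned by get_routing_metadata() can be decoded back into an array of shape (num_moe_layers, num_tokens, topk) by reversing the base64 encoding. A small round-trip sketch; the payload below is fabricated for illustration rather than taken from a real response.

# Decode a "routed_experts" payload of the form produced above.
import base64
import numpy as np

# Stand-in for metadata["routed_experts"] from an API response.
example = np.arange(2 * 3 * 2, dtype=np.int32).reshape(2, 3, 2)
payload = {
    "shape": list(example.shape),
    "dtype": str(example.dtype),
    "data": base64.b64encode(example.tobytes()).decode("ascii"),
}

routed = np.frombuffer(
    base64.b64decode(payload["data"]), dtype=payload["dtype"]
).reshape(payload["shape"])

assert np.array_equal(routed, example)
print(routed.shape)  # (num_moe_layers, num_tokens, topk)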

lightllm/server/core/objs/sampling_params.py

Lines changed: 0 additions & 1 deletion
@@ -497,7 +497,6 @@ def to_dict(self):
             "add_spaces_between_special_tokens": self.add_spaces_between_special_tokens,
             "print_eos_token": self.print_eos_token,
             "disable_prompt_cache": self.disable_prompt_cache,
-            "return_routed_experts": self.return_routed_experts,
         }

     def to_origin_dict(self):

lightllm/server/core/objs/start_args_type.py

Lines changed: 2 additions & 0 deletions
@@ -159,3 +159,5 @@ class StartArgs:
     # multi_modal
     enable_multimodal: bool = field(default=False)
     enable_multimodal_audio: bool = field(default=False)
+
+    enable_return_routed_experts: bool = field(default=False)

lightllm/server/httpserver/manager.py

Lines changed: 10 additions & 14 deletions
@@ -10,7 +10,6 @@
 import hashlib
 import datetime
 import pickle
-import base64
 from frozendict import frozendict

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -30,6 +29,7 @@
 from lightllm.server.core.objs.shm_req_manager import ShmReqManager
 from lightllm.server.core.objs.atomic_array_lock import AtomicShmArrayLock, AsyncLock, AtomicLockItem
 from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
+from lightllm.common.basemodel.routing_manager import get_shared_routing_config
 from lightllm.utils.log_utils import init_logger
 from lightllm.server.metrics.manager import MetricClient
 from lightllm.utils.statics_utils import MovingAverage
@@ -115,6 +115,9 @@ def __init__(
         # If the timemark is not updated for a pre-set time, a prob request will be sent to the backend.
         self.latest_success_infer_time_mark = SharedInt(f"{get_unique_server_name()}_latest_success_infer_time_mark")
         self.latest_success_infer_time_mark.set_value(int(time.time()))
+
+        # Cache routing config for MoE expert routing data extraction
+        self._routing_config = get_shared_routing_config() if args.enable_return_routed_experts else None
         return

     async def _alloc_resource(self, items, md5sums, token_nums, datas):
@@ -779,19 +782,12 @@ async def handle_loop(self):
                 else:
                     finish_status = FinishStatus(req.finish_status.status)

-                if req.sample_params.return_routed_experts and req.routing_data_num_moe_layers > 0:
-                    try:
-                        req.link_routing_data_shm_array()
-                        routing_data = req.get_routing_data()
-                        if routing_data is not None:
-                            metadata["routed_experts"] = {
-                                "shape": list(routing_data.shape),
-                                "dtype": str(routing_data.dtype),
-                                "data": base64.b64encode(routing_data.tobytes()).decode("ascii"),
-                            }
-                        req.close_routing_data_shm_array()
-                    except Exception as e:
-                        logger.warning(f"Failed to read routing data for req {req_id}: {e}")
+                if self._routing_config is not None and self._routing_config.is_initialized():
+                    routing_meta = req.get_routing_metadata(
+                        self._routing_config.num_moe_layers, self._routing_config.topk
+                    )
+                    if routing_meta is not None:
+                        metadata["routed_experts"] = routing_meta

                 token_list.append((req_id, text, metadata, finish_status))
             else:

lightllm/server/router/model_infer/infer_batch.py

Lines changed: 20 additions & 5 deletions
@@ -114,16 +114,25 @@ def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache:

         return req_objs

-    def _extract_routing_data(self, req: "InferReq"):
+    def _extract_routing_data(self, req: "InferReq", sync: bool = True):
+        """Extract MoE routing data for a completed request.
+
+        Args:
+            req: The inference request to extract routing data for.
+            sync: If True, synchronize CUDA events before extraction. Set to False
+                when processing multiple requests in batch after calling
+                g_routing_capture_manager.sync_events() once.
+        """
         mem_indexes = self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len]
         num_moe_layers = g_routing_capture_manager.num_moe_layers
         topk = g_routing_capture_manager.topk
         num_tokens = req.cur_kv_len
-        logger.debug(f"R3: Extracting routing for req {req.req_id}: {num_moe_layers}x{num_tokens}x{topk}")
-        routing_data = g_routing_capture_manager.extract_for_request(mem_indexes.cpu())
+        if sync:
+            routing_data = g_routing_capture_manager.extract_for_request(mem_indexes.cpu())
+        else:
+            routing_data = g_routing_capture_manager.extract_for_request_no_sync(mem_indexes.cpu())
         req.shm_req.create_routing_data_shm_array(num_moe_layers, num_tokens, topk)
         req.shm_req.shm_routing_data.arr[:] = routing_data
-        logger.debug(f"R3: Successfully extracted routing data for req {req.req_id}")

     def free_a_req_mem(self, free_token_index: List, req: "InferReq"):
         if self.radix_cache is None:
@@ -161,14 +170,20 @@ def _filter(self, finished_request_ids: List[int]):
         if len(finished_request_ids) == 0:
             return

+        # Optimization: sync CUDA events once for batch routing data extraction
+        need_routing_data = g_routing_capture_manager is not None
+        if need_routing_data:
+            g_routing_capture_manager.sync_events()
+
         free_req_index = []
         free_token_index = []
         for request_id in finished_request_ids:
             req: InferReq = self.requests_mapping.pop(request_id)
             if self.args.diverse_mode:
                 req.clear_master_slave_state()

-            self._extract_routing_data(req)
+            if need_routing_data:
+                self._extract_routing_data(req, sync=False)

             self.free_a_req_mem(free_token_index, req)
