 from typing import Optional
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.dist_utils import get_current_rank_in_dp
+from lightllm.server.router.dynamic_prompt.shared_arr import SharedArray
+from lightllm.utils.envs_utils import get_unique_server_name
 
 logger = init_logger(__name__)
 
+
+class SharedRoutingConfig:
+    """Shared MoE routing configuration across processes."""
+
+    def __init__(self):
+        service_name = get_unique_server_name()
+        # Two int32 slots: [num_moe_layers, topk]
+        self._shm = SharedArray(f"{service_name}_routing_config", shape=(2,), dtype=np.int32)
+
+    @property
+    def num_moe_layers(self) -> int:
+        return int(self._shm.arr[0])
+
+    @num_moe_layers.setter
+    def num_moe_layers(self, value: int):
+        self._shm.arr[0] = value
+
+    @property
+    def topk(self) -> int:
+        return int(self._shm.arr[1])
+
+    @topk.setter
+    def topk(self, value: int):
+        self._shm.arr[1] = value
+
+    def is_initialized(self) -> bool:
+        return self.num_moe_layers > 0 and self.topk > 0
+
+
+# Global shared routing config (lazily initialized)
+_shared_routing_config: Optional[SharedRoutingConfig] = None
+
+
+def get_shared_routing_config() -> SharedRoutingConfig:
+    """Get or create the shared routing config."""
+    global _shared_routing_config
+    if _shared_routing_config is None:
+        _shared_routing_config = SharedRoutingConfig()
+    return _shared_routing_config
+
+
 # MoE layer counter for auto-incrementing moe_layer_index
 _moe_layer_counter: int = 0
 
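For context, a minimal reader-side sketch of how another process could use this shared config (the consumer is not part of this diff; the flattened-array layout and the helper name are assumptions). The model process publishes `num_moe_layers` and `topk` once at init, and any process attached under the same server name can recover the per-request routing shape:

```python
# Hypothetical reader-side helper (not in this diff). extract_for_request() returns an
# array shaped [num_moe_layers, num_tokens, topk]; if that array is flattened for IPC,
# the shared config lets the receiving process restore the shape.
import numpy as np

def restore_routing_shape(flat_ids: np.ndarray) -> np.ndarray:
    cfg = get_shared_routing_config()
    assert cfg.is_initialized(), "routing config not yet published by the model process"
    return flat_ids.reshape(cfg.num_moe_layers, -1, cfg.topk)
```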
@@ -75,12 +118,8 @@ def __init__(
         )
 
     def capture(self, moe_layer_index: int, topk_ids: torch.Tensor, microbatch_index: int = 0) -> None:
-        assert (
-            0 <= moe_layer_index < self.num_moe_layers
-        ), f"moe_layer_index {moe_layer_index} out of range [0, {self.num_moe_layers})"
-        slot = microbatch_index % self.num_slots
         num_tokens = topk_ids.shape[0]
-        self.gpu_buffer[slot, moe_layer_index, :num_tokens, :] = topk_ids.to(self.dtype)
+        self.gpu_buffer[microbatch_index, moe_layer_index, :num_tokens, :] = topk_ids.to(self.dtype)
 
     def flush_to_cpu_async(self, mem_indexes: torch.Tensor, microbatch_index: int) -> None:
         num_tokens = mem_indexes.shape[0]
@@ -98,9 +137,20 @@ def flush_to_cpu_async(self, mem_indexes: torch.Tensor, microbatch_index: int) -
         self.cpu_buffer[:, cpu_indexes, :] = self.gpu_buffer[slot, :, :num_tokens, :].cpu()
         event.record()
 
-    def extract_for_request(self, mem_indexes: torch.Tensor) -> np.ndarray:
+    def sync_events(self) -> None:
+        """Synchronize all flush events. Call once before batch extraction."""
         for event in self.flush_events:
             event.synchronize()
+
+    def extract_for_request(self, mem_indexes: torch.Tensor) -> np.ndarray:
+        self.sync_events()
+        return self.cpu_buffer[:, mem_indexes, :].numpy()
+
+    def extract_for_request_no_sync(self, mem_indexes: torch.Tensor) -> np.ndarray:
+        """Extract routing data without synchronizing events.
+
+        Call sync_events() once before using this method in a batch.
+        """
         return self.cpu_buffer[:, mem_indexes, :].numpy()
 
 
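The split between `sync_events()` and `extract_for_request_no_sync()` lets a caller pay the CUDA event synchronization once per batch instead of once per finished request. A minimal sketch of that pattern, assuming the caller already holds the manager and each request's `mem_indexes` tensor (the helper name and argument shapes here are illustrative):

```python
from typing import Iterable, List
import numpy as np
import torch

def extract_finished_requests(mgr, per_request_mem_indexes: Iterable[torch.Tensor]) -> List[np.ndarray]:
    # Wait once for all pending GPU -> CPU flushes to land in cpu_buffer ...
    mgr.sync_events()
    # ... then read each request's rows without re-synchronizing per request.
    return [mgr.extract_for_request_no_sync(idx) for idx in per_request_mem_indexes]
```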
@@ -132,8 +182,6 @@ def init_routing_capture(model) -> None:
         return
 
     # Only create routing capture manager on rank 0
-    # Routing decisions are identical across all TP ranks, so we only need to capture on rank 0
-    # which is the rank that communicates results back to the Router/HTTP server
     if get_current_rank_in_dp() != 0:
         logger.info("Skipping routing capture initialization on non-zero rank")
         return
@@ -145,16 +193,9 @@ def init_routing_capture(model) -> None:
         )
         return
 
-    n_routed_experts = model.config.get("n_routed_experts", model.config.get("num_experts", 0))
-    if n_routed_experts == 0:
-        logger.warning(
-            "enable_return_routed_experts is set but n_routed_experts=0. " "Routing capture will not be enabled."
-        )
-        return
-
-    topk = model.config.get("num_experts_per_tok", 1)
-    num_experts = n_routed_experts
-
+    num_experts = model.config.get("n_routed_experts", model.config.get("num_experts", 0))
+    topk = model.config.get("num_experts_per_tok", 0)
+    assert num_experts > 0 and topk > 0
     enable_overlap = getattr(model.args, "enable_decode_microbatch_overlap", False)
 
     logger.info(
@@ -167,11 +208,16 @@ def init_routing_capture(model) -> None:
         topk=topk,
         num_experts=num_experts,
         batch_max_tokens=model.max_total_token_num,
-        # Add 1 to handle potential edge case where mem_index == size
         kv_cache_size=model.mem_manager.size + 1,
         enable_overlap=enable_overlap,
     )
 
+    # Set shared routing config for cross-process access
+    shared_config = get_shared_routing_config()
+    shared_config.num_moe_layers = num_moe_layers
+    shared_config.topk = topk
+    logger.info(f"Shared routing config set: num_moe_layers={num_moe_layers}, topk={topk}")
+
 
 def flush_routing_capture(mem_indexes: torch.Tensor, microbatch_index: int = 0) -> None:
     if g_routing_capture_manager is not None:
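Putting the pieces together, a rough sketch of the per-step flow implied by this diff; the layer loop and the `layer.route(hidden)` call are placeholders, not lightllm's actual forward code:

```python
import torch

def capture_one_decode_step(mgr, moe_layers, hidden: torch.Tensor,
                            mem_indexes: torch.Tensor, microbatch_index: int = 0) -> None:
    # During the forward pass each MoE layer records its per-token topk expert ids
    # into the GPU staging buffer for this microbatch.
    for layer_idx, layer in enumerate(moe_layers):
        topk_ids = layer.route(hidden)  # placeholder: [num_tokens, topk] expert ids
        mgr.capture(layer_idx, topk_ids, microbatch_index)
    # After the step, asynchronously copy the staged ids into the CPU buffer at the
    # tokens' KV-cache positions and record a CUDA event for later extraction.
    mgr.flush_to_cpu_async(mem_indexes, microbatch_index)
```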