 from lmdeploy.pytorch.distributed import get_dist_manager
 from lmdeploy.utils import get_logger
 
+from ..moe import MoeType
 from ..op_backend import DlinferOpsBackend
 
 logger = get_logger('lmdeploy')
@@ -42,6 +43,19 @@ def is_Ascend310P(cls) -> bool:
     def is_Ascend910(cls) -> bool:
         return cls.device_name().startswith(cls.Ascend910)
 
+    @classmethod
+    @lru_cache(maxsize=1)
+    def soc_version(cls) -> int:
+        return torch.npu.get_soc_version()
+
+    @classmethod
+    def is_A2(cls) -> bool:
+        return 220 <= cls.soc_version() <= 225
+
+    @classmethod
+    def is_A3(cls) -> bool:
+        return 250 <= cls.soc_version() <= 255
+
 
 class AscendKVQuantMeta:
     has_set_value: bool = False
@@ -94,7 +108,7 @@ class AscendOpsBackend(DlinferOpsBackend):
     half_negative_inf = torch.finfo(torch.float16).min
     total_slots = None
     max_batches = None
-    max_tokens_accros_dp = 0
+    graph_capture_sizes = None
 
     @staticmethod
     def get_name() -> str:
@@ -235,27 +249,90 @@ def get_kv_start_indices_and_attention_mask(is_decoding, is_unpaged_prefill, q_s
 
             return kv_start_indices, attention_mask
 
-        def get_max_tokens_across_dp():
-            dist_ctx = get_dist_manager().current_context()
-            if dist_ctx.dist_config.dp > 1:
-                total_token_current_rank = torch.sum(step_context.q_seqlens).to(step_context.q_seqlens.dtype)
-                if cls.enable_graph and step_context.is_decoding:
+        def get_tokens_across_dp(dp_size, tp_size, ep_size, ep_group):
+            num_tokens, max_tokens_across_dp = None, None
+            if ep_size <= 1:
+                pass
+            else:
+                is_graph = cls.enable_graph and step_context.is_decoding
+                # get the number of tokens for this step
+                if is_graph:
                     from dlinfer.framework.lmdeploy_ext.cudagraph.ascend_cudagraph import get_ascend_compatible_size
-                    total_token_current_rank_item = total_token_current_rank.item()
-                    total_token_current_rank = torch.tensor(
-                        [get_ascend_compatible_size(total_token_current_rank_item)],
-                        dtype=total_token_current_rank.dtype,
-                        device=total_token_current_rank.device,
+                    total_tokens_current_rank_actual = step_context.q_seqlens.size(0)
+                    num_tokens = get_ascend_compatible_size(total_tokens_current_rank_actual)
+                    total_tokens_current_rank = torch.tensor(
+                        [num_tokens],
+                        dtype=step_context.q_seqlens.dtype,
+                        device=torch.npu.current_device(),
                     )
-                world_size = dist_ctx.dist_config.world_size
-                total_token_buffer = torch.zeros(world_size,
-                                                 dtype=step_context.q_seqlens.dtype,
-                                                 device=torch.npu.current_device())
-                dist.all_gather_into_tensor(total_token_buffer, total_token_current_rank, dist_ctx.ep_gpu_group)
-                max_tokens_accros_dp = torch.max(total_token_buffer).item()
+                else:
+                    total_tokens_current_rank = torch.sum(step_context.q_seqlens).to(step_context.q_seqlens.dtype)
+                    num_tokens = total_tokens_current_rank.item()
+                # get the max token count across data-parallel ranks
+                if dp_size == 1:
+                    max_tokens_across_dp = num_tokens
+                    return num_tokens, max_tokens_across_dp
+                else:
+                    total_tokens_buffer = torch.zeros([dp_size * tp_size],
+                                                      dtype=step_context.q_seqlens.dtype,
+                                                      device=torch.npu.current_device())
+                    dist.all_gather_into_tensor(total_tokens_buffer, total_tokens_current_rank, ep_group)
+                    max_tokens_across_dp = torch.max(total_tokens_buffer).item()
+            return num_tokens, max_tokens_across_dp
+
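The gather in get_tokens_across_dp sizes its buffer as dp_size * tp_size because ep_group spans all attention dp x tp ranks, one slot per rank; every rank then takes the max over the same buffer, so all ranks agree on a common padded token count. A minimal self-contained sketch of that agree-on-the-max pattern, assuming an already-initialized process group (the helper name is illustrative, not part of this patch):

    import torch
    import torch.distributed as dist

    def max_tokens_across_ranks(local_num_tokens: int, group=None) -> int:
        # One slot per rank in the group; each rank contributes a single scalar.
        world_size = dist.get_world_size(group)
        local = torch.tensor([local_num_tokens], dtype=torch.int64)
        buffer = torch.zeros(world_size, dtype=torch.int64)
        dist.all_gather_into_tensor(buffer, local, group=group)
        # Every rank sees the same buffer, so every rank computes the same max.
        return int(buffer.max().item())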
+        def get_ep_meta():
+            dist_ctx = get_dist_manager().current_context()
+            dp_size, tp_size, ep_size = dist_ctx.dist_config.dp, dist_ctx.dist_config.tp, dist_ctx.dist_config.ep
+            tp_rank, ep_rank = dist_ctx.attn_tp_group.rank, dist_ctx.ep_rank
+            tp_group = dist_ctx.attn_tp_group.gpu_group
+            ep_group = dist_ctx.ep_gpu_group
+            return dp_size, tp_size, ep_size, tp_rank, ep_rank, tp_group, ep_group
+
+        def get_mc2_token_capacity(tp_size):
+            if cls.graph_capture_sizes:
+                max_num_tokens = min(max(cls.graph_capture_sizes), 512)
+            else:
+                # NOTE: To save memory, we cap the max number of tokens to 512.
+                max_num_tokens = min(cls.max_batches * 1, 512)
+            num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
+            return num_tokens_per_tp_rank * tp_size
+
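The rounding in get_mc2_token_capacity makes the returned capacity a multiple of tp_size, so token slots split evenly across TP ranks. A worked example with illustrative values: graph_capture_sizes = [1, 8, 16, 24] and tp_size = 8 give max_num_tokens = min(24, 512) = 24, num_tokens_per_tp_rank = (24 + 8 - 1) // 8 = 3, and a capacity of 3 * 8 = 24; a max_num_tokens of 20 would likewise round up to a capacity of 24.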
+        def select_moe_type(max_tokens_across_dp, dp_size, tp_size, ep_size):
+            if ep_size <= 1:
+                return MoeType.ALLGATHER
+            mc2_token_capacity = get_mc2_token_capacity(tp_size)
+            if SocVersion.is_A2():
+                if max_tokens_across_dp <= mc2_token_capacity and tp_size * dp_size >= 16:
+                    moe_type = MoeType.MC2
+                else:
+                    # TODO: currently w4a8_dynamic does not support allgather-EP; we need to use all2all
+                    moe_type = MoeType.ALLGATHER
+            elif SocVersion.is_A3():
+                if max_tokens_across_dp <= mc2_token_capacity:
+                    moe_type = MoeType.MC2
+                else:
+                    moe_type = MoeType.ALLTOALL
             else:
-                max_tokens_accros_dp = torch.sum(step_context.q_seqlens).item()
-            return max_tokens_accros_dp
+                raise ValueError(f'Unsupported soc_version: {SocVersion.soc_version()}')
+
+            if moe_type == MoeType.ALLGATHER and not step_context.is_decoding:
+                moe_type = MoeType.ALLGATHER
+            return moe_type
+
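Summarizing the dispatch above: without expert parallelism (ep_size <= 1) the choice is always ALLGATHER; on A2-series SoCs, MC2 is chosen only when the max padded token count fits the MC2 capacity and the cluster is large enough (tp_size * dp_size >= 16), falling back to ALLGATHER otherwise; on A3-series SoCs, MC2 is chosen under capacity and ALLTOALL over it.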
+        def update_pad_size(num_tokens, max_tokens_across_dp, tp_size, ep_size, moe_type):
+            if ep_size <= 1:
+                return 0
+            # is_graph = cls.enable_graph and step_context.is_decoding
+            # num_running_tokens = max_tokens_across_dp if is_graph else num_tokens
+            if moe_type == MoeType.ALLGATHER:
+                pad_size = 0
+            elif moe_type == MoeType.ALLTOALL:
+                pad_size = tp_size - num_tokens
+            elif moe_type == MoeType.MC2:
+                pad_size = (max_tokens_across_dp + tp_size - 1) // tp_size * tp_size - num_tokens
+            if isinstance(pad_size, torch.Tensor):
+                pad_size = pad_size.item()
+            return pad_size
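A worked example of the MC2 branch in update_pad_size, with illustrative numbers: tp_size = 8, max_tokens_across_dp = 21, and num_tokens = 17 on this rank give a padded target of (21 + 8 - 1) // 8 * 8 = 24 and pad_size = 24 - 17 = 7, so every rank pads its tokens to the same 24-slot shape before the MC2 dispatch/combine collectives.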
 
         q_seqlens_cpu, kv_seqlens_cpu, kv_seqlens_expanded = get_cpu_seqlens(step_context.is_decoding,
                                                                              is_unpaged_prefill)
@@ -267,7 +344,6 @@ def get_max_tokens_across_dp():
                                                                  is_unpaged_prefill, q_seqlens_list,
                                                                  kv_seqlens_list, max_q_seq_len,
                                                                  max_kv_seq_len)
-        cls.max_tokens_accros_dp = get_max_tokens_across_dp()
 
         if not cls.enable_graph and step_context.kv_quant_policy == 8:
             record_file = os.getenv('ASCEND_QUANT_RECORD_FILE')
@@ -300,8 +376,27 @@ def get_max_tokens_across_dp():
             quant_policy=step_context.kv_quant_policy,
             quant_meta=AscendKVQuantMeta.quant_meta,
         )
-
         step_context.attn_metadata = attn_metadata
+
+        dp_size, tp_size, ep_size, tp_rank, ep_rank, tp_group, ep_group = get_ep_meta()
+        num_tokens, max_tokens_across_dp = get_tokens_across_dp(dp_size, tp_size, ep_size, ep_group)
+        moe_type = select_moe_type(max_tokens_across_dp, dp_size, tp_size, ep_size)
+        pad_size = update_pad_size(num_tokens, max_tokens_across_dp, tp_size, ep_size, moe_type)
+        mlp_meta_cls = cls.get_mlp_metadata_cls()
+        mlp_metadata = mlp_meta_cls(
+            max_tokens_across_dp=max_tokens_across_dp,
+            pad_size=pad_size,
+            dp_size=dp_size,
+            tp_size=tp_size,
+            ep_size=ep_size,
+            tp_rank=tp_rank,
+            ep_rank=ep_rank,
+            tp_group=tp_group,
+            ep_group=ep_group,
+            moe_type=moe_type,
+        )
+        step_context.mlp_metadata = mlp_metadata
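The metadata assembled above gives downstream MoE/MLP layers a single struct on step_context carrying the parallel topology (dp/tp/ep sizes, ranks, and process groups) together with the per-step padding and dispatch choice, rather than each layer re-querying the distributed context. A sketch of the shape such a class might take, assuming a plain dataclass (the real class comes from cls.get_mlp_metadata_cls() and is not shown in this diff; the name MlpMetadata is hypothetical):

    from dataclasses import dataclass
    from typing import Any

    @dataclass
    class MlpMetadata:  # hypothetical name; the actual class is provided by get_mlp_metadata_cls()
        max_tokens_across_dp: int
        pad_size: int
        dp_size: int
        tp_size: int
        ep_size: int
        tp_rank: int
        ep_rank: int
        tp_group: Any  # torch.distributed ProcessGroup
        ep_group: Any
        moe_type: Any  # a MoeType enum member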
+        # torch.npu.synchronize()
         return step_context
 
     @staticmethod
@@ -310,7 +405,38 @@ def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig, cache_
         """Build graph runner."""
         AscendOpsBackend.enable_graph = not backend_config.eager_mode
         AscendOpsBackend.max_batches = cache_config.max_batches
-        from dlinfer.framework.lmdeploy_ext.cudagraph.ascend_cudagraph import AscendGraphRunner
+        from dlinfer.framework.lmdeploy_ext.cudagraph.ascend_cudagraph import (AscendGraphRunner,
+                                                                               get_ascend_compatible_size)
+
+        @lru_cache
+        def _get_graph_capture_sizes(max_batches: int):
+            """Compute the batch sizes to capture.
+
+            Generate compatible sizes up to max_batches (without exceeding it), then append max_batches itself to
+            ensure it can be handled.
+            """
+            if backend_config.eager_mode:
+                return None
+            ret = []
+            batch_size = 1
+
+            # Generate candidate batch sizes and map each through get_ascend_compatible_size.
+            # Only include sizes that do not exceed max_batches.
+            while batch_size <= max_batches:
+                compatible_size = get_ascend_compatible_size(batch_size)
+                if compatible_size > max_batches:
+                    break
+                if not ret or compatible_size > ret[-1]:
+                    ret.append(compatible_size)
+                batch_size = compatible_size + 1
+
+            # Append max_batches itself to guarantee it can be handled.
+            if max_batches not in ret:
+                ret.append(max_batches)
+            return ret
+
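An illustrative trace of the loop above, assuming get_ascend_compatible_size rounds 1 -> 1, 2 -> 8, 9 -> 16, and 17 -> 24 (the real mapping lives in dlinfer and may differ): with max_batches = 20 the loop collects [1, 8, 16], breaks when the next compatible size (24) would exceed 20, and then appends max_batches itself, so _get_graph_capture_sizes(20) returns [1, 8, 16, 20].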
+        AscendOpsBackend.graph_capture_sizes = _get_graph_capture_sizes(cache_config.max_batches)
+
         return AscendGraphRunner(model, model_config, cache_config, backend_config, device)
 
     @staticmethod
@@ -337,6 +463,7 @@ def device_count():
     @staticmethod
     def support_ray():
         """Support ray."""
+        # return False
         if not _envs.ascend_set_rt_visable_devices_by_ray:
             os.environ['RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES'] = '1'
         return True