Commit 8317872 ("temp code")
1 parent a376b4b

File tree: 13 files changed (+253, -68 lines)


lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py

Lines changed: 149 additions & 22 deletions
@@ -14,6 +14,7 @@
 from lmdeploy.pytorch.distributed import get_dist_manager
 from lmdeploy.utils import get_logger
 
+from ..moe import MoeType
 from ..op_backend import DlinferOpsBackend
 
 logger = get_logger('lmdeploy')
@@ -42,6 +43,19 @@ def is_Ascend310P(cls) -> bool:
     def is_Ascend910(cls) -> bool:
         return cls.device_name().startswith(cls.Ascend910)
 
+    @classmethod
+    @lru_cache(maxsize=1)
+    def soc_version(cls) -> int:
+        return torch.npu.get_soc_version()
+
+    @classmethod
+    def is_A2(cls) -> bool:
+        return 220 <= cls.soc_version() <= 225
+
+    @classmethod
+    def is_A3(cls) -> bool:
+        return 250 <= cls.soc_version() <= 255
+
 
 class AscendKVQuantMeta:
     has_set_value: bool = False
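The new checks reduce to integer range tests on the SoC version code (220-225 for the A2 family, 250-255 for A3). A minimal standalone sketch of the same classification, with the torch.npu.get_soc_version() call replaced by a plain integer argument (the helper name below is illustrative, not part of the commit):

def classify_soc(soc_version: int) -> str:
    # Same ranges as SocVersion.is_A2 / is_A3 above.
    if 220 <= soc_version <= 225:
        return 'A2'
    if 250 <= soc_version <= 255:
        return 'A3'
    return 'unknown'

assert classify_soc(223) == 'A2'
assert classify_soc(251) == 'A3'
assert classify_soc(105) == 'unknown'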
@@ -94,7 +108,7 @@ class AscendOpsBackend(DlinferOpsBackend):
     half_negative_inf = torch.finfo(torch.float16).min
     total_slots = None
     max_batches = None
-    max_tokens_accros_dp = 0
+    graph_capture_sizes = None
 
     @staticmethod
     def get_name() -> str:
@@ -235,27 +249,90 @@ def get_kv_start_indices_and_attention_mask(is_decoding, is_unpaged_prefill, q_s
 
             return kv_start_indices, attention_mask
 
-        def get_max_tokens_across_dp():
-            dist_ctx = get_dist_manager().current_context()
-            if dist_ctx.dist_config.dp > 1:
-                total_token_current_rank = torch.sum(step_context.q_seqlens).to(step_context.q_seqlens.dtype)
-                if cls.enable_graph and step_context.is_decoding:
+        def get_tokens_across_dp(dp_size, tp_size, ep_size, ep_group):
+            num_tokens, max_tokens_across_dp = None, None
+            if ep_size <= 1:
+                pass
+            else:
+                is_graph = cls.enable_graph and step_context.is_decoding
+                # get num tokens at runtime
+                if is_graph:
                     from dlinfer.framework.lmdeploy_ext.cudagraph.ascend_cudagraph import get_ascend_compatible_size
-                    total_token_current_rank_item = total_token_current_rank.item()
-                    total_token_current_rank = torch.tensor(
-                        [get_ascend_compatible_size(total_token_current_rank_item)],
-                        dtype=total_token_current_rank.dtype,
-                        device=total_token_current_rank.device,
+                    total_tokens_current_rank_actual = step_context.q_seqlens.size(0)
+                    num_tokens = get_ascend_compatible_size(total_tokens_current_rank_actual)
+                    total_tokens_current_rank = torch.tensor(
+                        [num_tokens],
+                        dtype=step_context.q_seqlens.dtype,
+                        device=torch.npu.current_device(),
                     )
-                world_size = dist_ctx.dist_config.world_size
-                total_token_buffer = torch.zeros(world_size,
-                                                 dtype=step_context.q_seqlens.dtype,
-                                                 device=torch.npu.current_device())
-                dist.all_gather_into_tensor(total_token_buffer, total_token_current_rank, dist_ctx.ep_gpu_group)
-                max_tokens_accros_dp = torch.max(total_token_buffer).item()
+                else:
+                    total_tokens_current_rank = torch.sum(step_context.q_seqlens).to(step_context.q_seqlens.dtype)
+                    num_tokens = total_tokens_current_rank.item()
+                # get max tokens across data parallel ranks
+                if dp_size == 1:
+                    max_tokens_across_dp = num_tokens
+                    return num_tokens, max_tokens_across_dp
+                else:
+                    total_tokens_buffer = torch.zeros([dp_size * tp_size],
+                                                      dtype=step_context.q_seqlens.dtype,
+                                                      device=torch.npu.current_device())
+                    dist.all_gather_into_tensor(total_tokens_buffer, total_tokens_current_rank, ep_group)
+                    max_tokens_across_dp = torch.max(total_tokens_buffer).item()
+            return num_tokens, max_tokens_across_dp
+
+        def get_ep_meta():
+            dist_ctx = get_dist_manager().current_context()
+            dp_size, tp_size, ep_size = dist_ctx.dist_config.dp, dist_ctx.dist_config.tp, dist_ctx.dist_config.ep
+            tp_rank, ep_rank = dist_ctx.attn_tp_group.rank, dist_ctx.ep_rank
+            tp_group = dist_ctx.attn_tp_group.gpu_group
+            ep_group = dist_ctx.ep_gpu_group
+            return dp_size, tp_size, ep_size, tp_rank, ep_rank, tp_group, ep_group
+
+        def get_mc2_token_capacity(tp_size):
+            if cls.graph_capture_sizes:
+                max_num_tokens = min(max(cls.graph_capture_sizes), 512)
+            else:
+                # NOTE: To save memory, we cap the max number of tokens to 512.
+                max_num_tokens = min(cls.max_batches * 1, 512)
+            num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
+            return num_tokens_per_tp_rank * tp_size
+
+        def select_moe_type(max_tokens_across_dp, dp_size, tp_size, ep_size):
+            if ep_size <= 1:
+                return MoeType.ALLGATHER
+            mc2_token_capacity = get_mc2_token_capacity(tp_size)
+            if SocVersion.is_A2():
+                if max_tokens_across_dp <= mc2_token_capacity and tp_size * dp_size >= 16:
+                    moe_type = MoeType.MC2
+                else:
+                    # TODO: Currently, w4a8_dynamic does not support allgatherep; we need to use all2all
+                    moe_type = MoeType.ALLGATHER
+            elif SocVersion.is_A3():
+                if max_tokens_across_dp <= mc2_token_capacity:
+                    moe_type = MoeType.MC2
+                else:
+                    moe_type = MoeType.ALLTOALL
             else:
-                max_tokens_accros_dp = torch.sum(step_context.q_seqlens).item()
-            return max_tokens_accros_dp
+                raise ValueError(f'Unsupported soc_version: {SocVersion.soc_version()}')
+
+            if moe_type == MoeType.ALLGATHER and not step_context.is_decoding:
+                moe_type = MoeType.ALLGATHER
+            return moe_type
+
+        def update_pad_size(num_tokens, max_tokens_across_dp, tp_size, ep_size, moe_type):
+            if ep_size <= 1:
+                return 0
+            # is_graph = cls.enable_graph and step_context.is_decoding
+            # num_running_tokens = max_tokens_across_dp if is_graph else num_tokens
+            if moe_type == MoeType.ALLGATHER:
+                pad_size = 0
+            elif moe_type == MoeType.ALLTOALL:
+                pad_size = tp_size - num_tokens
+            elif moe_type == MoeType.MC2:
+                pad_size = (max_tokens_across_dp + tp_size - 1) // tp_size * tp_size - num_tokens
+            if isinstance(pad_size, torch.Tensor):
+                pad_size = pad_size.item()
+            return pad_size
 
         q_seqlens_cpu, kv_seqlens_cpu, kv_seqlens_expanded = get_cpu_seqlens(step_context.is_decoding,
                                                                              is_unpaged_prefill)
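The three helpers added here reduce to integer arithmetic once the parallel sizes are known. Below is a minimal standalone sketch of the same capacity, dispatch and padding rules with made-up sizes; the is_a2 flag stands in for the SocVersion checks, and no torch or NPU is required:

from enum import Enum, auto


class MoeType(Enum):
    ALLGATHER = auto()
    ALLTOALL = auto()
    MC2 = auto()


def mc2_token_capacity(tp_size, graph_capture_sizes=None, max_batches=256):
    # Mirrors get_mc2_token_capacity: cap at 512 tokens, then round up to a
    # multiple of tp_size so every TP rank gets an equal share.
    max_num_tokens = min(max(graph_capture_sizes), 512) if graph_capture_sizes else min(max_batches, 512)
    return (max_num_tokens + tp_size - 1) // tp_size * tp_size


def select_moe_type(max_tokens_across_dp, dp_size, tp_size, ep_size, is_a2=True):
    # Mirrors the A2/A3 branches of select_moe_type above.
    if ep_size <= 1:
        return MoeType.ALLGATHER
    capacity = mc2_token_capacity(tp_size)
    if is_a2:
        return MoeType.MC2 if (max_tokens_across_dp <= capacity and tp_size * dp_size >= 16) else MoeType.ALLGATHER
    return MoeType.MC2 if max_tokens_across_dp <= capacity else MoeType.ALLTOALL


def pad_size_for(num_tokens, max_tokens_across_dp, tp_size, moe_type):
    # Mirrors update_pad_size: MC2 pads the local batch up to the tp-aligned
    # global maximum so every rank contributes an equally sized chunk.
    if moe_type == MoeType.ALLGATHER:
        return 0
    if moe_type == MoeType.ALLTOALL:
        return tp_size - num_tokens
    return (max_tokens_across_dp + tp_size - 1) // tp_size * tp_size - num_tokens


# Hypothetical decode step: dp=4, tp=8, ep=32, 97 tokens on this rank,
# 120 tokens on the busiest rank across data-parallel groups.
moe = select_moe_type(max_tokens_across_dp=120, dp_size=4, tp_size=8, ep_size=32)
print(moe, pad_size_for(num_tokens=97, max_tokens_across_dp=120, tp_size=8, moe_type=moe))
# MoeType.MC2 23  (120 is already a multiple of tp_size, so pad = 120 - 97)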
@@ -267,7 +344,6 @@ def get_max_tokens_across_dp():
                                                             is_unpaged_prefill, q_seqlens_list,
                                                             kv_seqlens_list, max_q_seq_len,
                                                             max_kv_seq_len)
-        cls.max_tokens_accros_dp = get_max_tokens_across_dp()
 
         if not cls.enable_graph and step_context.kv_quant_policy == 8:
             record_file = os.getenv('ASCEND_QUANT_RECORD_FILE')
@@ -300,8 +376,27 @@ def get_max_tokens_across_dp():
             quant_policy=step_context.kv_quant_policy,
             quant_meta=AscendKVQuantMeta.quant_meta,
         )
-
         step_context.attn_metadata = attn_metadata
+
+        dp_size, tp_size, ep_size, tp_rank, ep_rank, tp_group, ep_group = get_ep_meta()
+        num_tokens, max_tokens_across_dp = get_tokens_across_dp(dp_size, tp_size, ep_size, ep_group)
+        moe_type = select_moe_type(max_tokens_across_dp, dp_size, tp_size, ep_size)
+        pad_size = update_pad_size(num_tokens, max_tokens_across_dp, tp_size, ep_size, moe_type)
+        mlp_meta_cls = cls.get_mlp_metadata_cls()
+        mlp_metadata = mlp_meta_cls(
+            max_tokens_across_dp=max_tokens_across_dp,
+            pad_size=pad_size,
+            dp_size=dp_size,
+            tp_size=tp_size,
+            ep_size=ep_size,
+            tp_rank=tp_rank,
+            ep_rank=ep_rank,
+            tp_group=tp_group,
+            ep_group=ep_group,
+            moe_type=moe_type,
+        )
+        step_context.mlp_metadata = mlp_metadata
+        # torch.npu.synchronize()
         return step_context
 
     @staticmethod
@@ -310,7 +405,38 @@ def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig, cache_
         """Build graph runner."""
         AscendOpsBackend.enable_graph = not backend_config.eager_mode
         AscendOpsBackend.max_batches = cache_config.max_batches
-        from dlinfer.framework.lmdeploy_ext.cudagraph.ascend_cudagraph import AscendGraphRunner
+        from dlinfer.framework.lmdeploy_ext.cudagraph.ascend_cudagraph import (AscendGraphRunner,
+                                                                               get_ascend_compatible_size)
+
+        @lru_cache
+        def _get_graph_capture_sizes(max_batches: int):
+            """Capture batch sizes.
+
+            Generate compatible sizes up to max_batches (not exceeding it), then add max_batches itself to ensure it can
+            be handled.
+            """
+            if backend_config.eager_mode:
+                return None
+            ret = []
+            batch_size = 1
+
+            # Generate batch sizes and apply get_ascend_compatible_size.
+            # Only include sizes that do not exceed max_batches.
+            while batch_size <= max_batches:
+                compatible_size = get_ascend_compatible_size(batch_size)
+                if compatible_size > max_batches:
+                    break
+                if not ret or compatible_size > ret[-1]:
+                    ret.append(compatible_size)
+                batch_size = compatible_size + 1
+
+            # Add max_batches itself to ensure it can be handled.
+            if max_batches not in ret:
+                ret.append(max_batches)
+            return ret
+
+        AscendOpsBackend.graph_capture_sizes = _get_graph_capture_sizes(cache_config.max_batches)
+
         return AscendGraphRunner(model, model_config, cache_config, backend_config, device)
 
     @staticmethod
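_get_graph_capture_sizes depends on get_ascend_compatible_size from dlinfer, whose rounding policy is not visible in this diff. The sketch below reproduces the same loop under the assumption, for illustration only, that a compatible size is the next multiple of 8; the real dlinfer policy may differ:

def next_multiple_of_8(batch_size: int) -> int:
    # Stand-in for dlinfer's get_ascend_compatible_size; the real rounding
    # policy may differ, this only shows how the loop behaves.
    return (batch_size + 7) // 8 * 8


def get_graph_capture_sizes(max_batches: int, compatible=next_multiple_of_8):
    # Same logic as _get_graph_capture_sizes above, minus the eager-mode check.
    ret = []
    batch_size = 1
    while batch_size <= max_batches:
        size = compatible(batch_size)
        if size > max_batches:
            break
        if not ret or size > ret[-1]:
            ret.append(size)
        batch_size = size + 1
    if max_batches not in ret:
        ret.append(max_batches)
    return ret


print(get_graph_capture_sizes(26))  # [8, 16, 24, 26]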
@@ -337,6 +463,7 @@ def device_count():
     @staticmethod
     def support_ray():
         """Support ray."""
+        # return False
         if not _envs.ascend_set_rt_visable_devices_by_ray:
             os.environ['RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES'] = '1'
         return True

lmdeploy/pytorch/backends/dlinfer/moe.py

Lines changed: 38 additions & 23 deletions
@@ -1,44 +1,55 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import os
+from dataclasses import dataclass
+from enum import Enum, auto
 from typing import Callable, List
 
 import torch
 
-from lmdeploy.pytorch.distributed import get_dist_manager
-from lmdeploy.pytorch.kernels.dlinfer import DlinferDistContext, fused_moe, moe_gating_topk_softmax
+from lmdeploy.pytorch.kernels.dlinfer import fused_moe, moe_gating_topk_softmax
 
-from ..moe import FusedMoEBuilder, FusedMoEImpl, SoftmaxTopKBuilder, SoftmaxTopKImpl
+from ..moe import FusedMoEBuilder, FusedMoEImpl, MLPMetaData, SoftmaxTopKBuilder, SoftmaxTopKImpl
 
 
-def get_dist_ctx():
-    dist_ctx = get_dist_manager().current_context()
+class MoeType(Enum):
+    NATIVE = auto()
+    ALLGATHER = auto()
+    ALLTOALL = auto()
+    MC2 = auto()
+    UNDEFINED = auto()
 
-    return DlinferDistContext(dp_size=dist_ctx.dist_config.dp,
-                              tp_size=dist_ctx.dist_config.tp,
-                              ep_size=dist_ctx.dist_config.ep,
-                              dp_rank=dist_ctx.dp_rank,
-                              tp_rank=dist_ctx.attn_tp_group.rank,
-                              ep_rank=dist_ctx.ep_rank,
-                              max_tokens_accros_dp=1,
-                              tp_group=dist_ctx.attn_tp_group.gpu_group,
-                              ep_group=dist_ctx.ep_gpu_group)
 
+@dataclass
+class DlinferMLPMetadata(MLPMetaData):
+    max_tokens_across_dp: int = 1
+    pad_size: int = 0
+    dp_size: int = 1
+    tp_size: int = 1
+    ep_size: int = 1
+    tp_rank: int = 0
+    ep_rank: int = 0
+    tp_group: torch.distributed.ProcessGroup = None
+    ep_group: torch.distributed.ProcessGroup = None
+    moe_type: MoeType = MoeType.UNDEFINED
 
-class DlinferSoftmaxTopKImpl(SoftmaxTopKImpl):
+
+class DlinferSoftmaxTopKImpl(SoftmaxTopKImpl[DlinferMLPMetadata]):
     """Dlinfer softmax topk implementation."""
 
     def __init__(self, top_k: int, dim: int = -1, n_groups: int = -1):
         self.top_k = top_k
         self.dim = dim
         if n_groups != -1:
             raise NotImplementedError('Group router not supported')
-        self.dist_ctx = get_dist_ctx()
 
-    def forward(self, x: torch.Tensor):
-        routing_weights, selected_experts = moe_gating_topk_softmax(x, self.top_k, self.dist_ctx)
+    def forward(self, x: torch.Tensor, mlp_metadata: DlinferMLPMetadata):
+        routing_weights, selected_experts = moe_gating_topk_softmax(x, self.top_k, mlp_metadata.pad_size,
+                                                                    mlp_metadata.tp_size, mlp_metadata.ep_size,
+                                                                    mlp_metadata.tp_rank)
         return routing_weights, selected_experts
 
 
-class DlinferSoftmaxTopKBuilder(SoftmaxTopKBuilder):
+class DlinferSoftmaxTopKBuilder(SoftmaxTopKBuilder[DlinferMLPMetadata]):
     """Dlinfer softmax topk implementation builder."""
 
     @staticmethod
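With this change the distributed context is no longer cached on the op at construction time; the router receives a per-step metadata object instead. The sketch below shows that calling convention with a local stand-in for DlinferMLPMetadata and a plain torch top-k in place of moe_gating_topk_softmax (both stand-ins are illustrative, not the real classes or kernel):

from dataclasses import dataclass
from enum import Enum, auto

import torch


class MoeType(Enum):
    ALLGATHER = auto()
    UNDEFINED = auto()


@dataclass
class MLPMetadataSketch:
    # Local stand-in mirroring DlinferMLPMetadata's fields.
    max_tokens_across_dp: int = 1
    pad_size: int = 0
    tp_size: int = 1
    ep_size: int = 1
    tp_rank: int = 0
    moe_type: MoeType = MoeType.UNDEFINED


def route(router_logits: torch.Tensor, top_k: int, meta: MLPMetadataSketch):
    # Plain torch top-k in place of moe_gating_topk_softmax; the point here is
    # only that the metadata travels with every call instead of being cached.
    weights = torch.softmax(router_logits, dim=-1)
    topk_weights, topk_ids = torch.topk(weights, top_k, dim=-1)
    return topk_weights, topk_ids, meta.pad_size


meta = MLPMetadataSketch(max_tokens_across_dp=64, pad_size=4, tp_size=8, ep_size=16)
w, ids, pad = route(torch.randn(6, 32), top_k=2, meta=meta)
print(w.shape, ids.shape, pad)  # torch.Size([6, 2]) torch.Size([6, 2]) 4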
@@ -47,7 +58,7 @@ def build(top_k: int, dim: int = -1, n_groups: int = -1):
         return DlinferSoftmaxTopKImpl(top_k, dim, n_groups)
 
 
-class DlinferFusedMoEImpl(FusedMoEImpl):
+class DlinferFusedMoEImpl(FusedMoEImpl[DlinferMLPMetadata]):
     """Dlinfer fused moe implementation."""
 
     def __init__(self,
@@ -61,12 +72,13 @@ def __init__(self,
         self.renormalize = renormalize
         self.ep_size = ep_size
         self.ep_group = ep_group
-        self.dist_ctx = get_dist_ctx()
 
     def update_weights(self, gate_up_weights: torch.Tensor, down_weights: torch.Tensor):
         """Update weights."""
         device_type = gate_up_weights.device.type
         if device_type in ['npu']:
+            if os.getenv('DLINFER_RESET_MOE_UPDATE_WEIGHTS', '0') == '1':
+                return gate_up_weights, down_weights
             return gate_up_weights.transpose(-1, -2).contiguous(), down_weights.transpose(-1, -2).contiguous()
         return gate_up_weights, down_weights
 
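The new DLINFER_RESET_MOE_UPDATE_WEIGHTS check gives deployments a switch that skips the NPU weight transpose in update_weights; whether skipping it is valid depends on the layout the fused_moe kernel expects, which this diff does not show. If used, the flag has to be set before the MoE weights are loaded, for example:

import os

# Must be set before DlinferFusedMoEImpl.update_weights runs, i.e. before the
# engine loads the MoE weights; '0' (the default) keeps the transpose.
os.environ['DLINFER_RESET_MOE_UPDATE_WEIGHTS'] = '1'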

@@ -84,6 +96,7 @@ def forward(self,
                 topk_ids: torch.LongTensor,
                 gate_up_weights: torch.Tensor,
                 down_weights: torch.Tensor,
+                mlp_metadata: DlinferMLPMetadata,
                 gate_up_bias: torch.Tensor = None,
                 down_bias: torch.Tensor = None,
                 expert_list: List[int] = None,
@@ -92,10 +105,12 @@
         assert gate_up_bias is None
         assert down_bias is None
         return fused_moe(hidden_states, gate_up_weights, down_weights, topk_weights, topk_ids, self.top_k,
-                         self.renormalize, self.dist_ctx)
+                         self.renormalize, mlp_metadata.pad_size, mlp_metadata.tp_size, mlp_metadata.ep_size,
+                         mlp_metadata.tp_rank, mlp_metadata.ep_rank, mlp_metadata.tp_group, mlp_metadata.ep_group,
+                         mlp_metadata.moe_type)
 
 
-class DlinferFusedMoEBuilder(FusedMoEBuilder):
+class DlinferFusedMoEBuilder(FusedMoEBuilder[DlinferMLPMetadata]):
     """Dlinfer fused moe builder."""
 
     @staticmethod

lmdeploy/pytorch/backends/dlinfer/op_backend.py

Lines changed: 5 additions & 0 deletions
@@ -67,6 +67,11 @@ def get_attention_metadata_cls():
         from .attention import DlinferAttentionMetadata
         return DlinferAttentionMetadata
 
+    @staticmethod
+    def get_mlp_metadata_cls():
+        from .moe import DlinferMLPMetadata
+        return DlinferMLPMetadata
+
     @staticmethod
     def get_k_block_shape(
         block_size: int,
