10 changes: 6 additions & 4 deletions lmdeploy/pytorch/models/deepseek_mtp.py
@@ -12,6 +12,7 @@
 from lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_down_linear, build_gateup_linear, build_o_proj,
                                         build_rowwise_linear)
 from lmdeploy.pytorch.nn.moe import build_fused_moe
+from lmdeploy.pytorch.nn.rotary_embedding import get_rope_parameters, get_rope_theta
 from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
 from lmdeploy.utils import get_logger

@@ -130,9 +131,10 @@ def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device

         self.softmax_scale = self.q_head_dim**(-0.5)

-        if config.rope_scaling is not None:
-            mscale_all_dim = config.rope_scaling.get('mscale_all_dim', 0)
-            scaling_factor = config.rope_scaling['factor']
+        rope_scaling = get_rope_parameters(config)
+        if rope_scaling is not None:
+            mscale_all_dim = rope_scaling.get('mscale_all_dim', 0)
+            scaling_factor = rope_scaling['factor']
             if mscale_all_dim:
                 mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
                 self.softmax_scale = self.softmax_scale * mscale * mscale
@@ -390,7 +392,7 @@ def __init__(self,
         rope_dim = config.qk_rope_head_dim if getattr(config, 'use_mla', True) else (config.hidden_size //
                                                                                      config.num_attention_heads)
         rope_max_pos_emb = config.max_position_embeddings
-        rope_base = config.rope_theta
+        rope_base = get_rope_theta(config)

         rope_params = dict(emb_type=emb_type, dim=rope_dim, max_position_embeddings=rope_max_pos_emb, base=rope_base)
         update_params = build_rotary_params(config)
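
For reference, the pattern this hunk touches (and the matching hunks in deepseek_v2.py and deepseek_v32.py below) is the YaRN correction of the attention softmax scale; only the source of the rope parameters changes, from `config.rope_scaling` to `get_rope_parameters(config)`. A minimal standalone sketch of that correction, with illustrative `rope_scaling` values and `q_head_dim`, and with `yarn_get_mscale` assumed to follow the standard YaRN formula (it mirrors the `get_mscale` helper visible in the rotary_embedding.py hunks further down):

```python
import math


def yarn_get_mscale(scale: float = 1.0, mscale: float = 1.0) -> float:
    """Standard YaRN attention-scale correction (assumed implementation)."""
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


# Illustrative DeepSeek-style values, not taken from a real config.
rope_scaling = {'factor': 40.0, 'mscale_all_dim': 1.0}
q_head_dim = 192  # qk_nope_head_dim + qk_rope_head_dim

softmax_scale = q_head_dim**(-0.5)
if rope_scaling is not None:
    mscale_all_dim = rope_scaling.get('mscale_all_dim', 0)
    scaling_factor = rope_scaling['factor']
    if mscale_all_dim:
        mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
        softmax_scale = softmax_scale * mscale * mscale

print(softmax_scale)  # the corrected scale the attention layer would use
```
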
10 changes: 6 additions & 4 deletions lmdeploy/pytorch/models/deepseek_v2.py
@@ -19,6 +19,7 @@
 from lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_down_linear, build_gateup_linear, build_o_proj,
                                         build_rowwise_linear)
 from lmdeploy.pytorch.nn.moe import MoeType, SoftmaxTopK, build_fused_moe
+from lmdeploy.pytorch.nn.rotary_embedding import get_rope_parameters, get_rope_theta
 from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

 from .utils.cudagraph import CudaGraphMixin
@@ -441,9 +442,10 @@ def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device

         self.softmax_scale = self.q_head_dim**(-0.5)

-        if config.rope_scaling is not None:
-            mscale_all_dim = config.rope_scaling.get('mscale_all_dim', 0)
-            scaling_factor = config.rope_scaling['factor']
+        rope_scaling = get_rope_parameters(config)
+        if rope_scaling is not None:
+            mscale_all_dim = rope_scaling.get('mscale_all_dim', 0)
+            scaling_factor = rope_scaling['factor']
             if mscale_all_dim:
                 mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
                 self.softmax_scale = self.softmax_scale * mscale * mscale
@@ -987,7 +989,7 @@ def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device
         rope_dim = config.qk_rope_head_dim if getattr(config, 'use_mla', True) else (config.hidden_size //
                                                                                      config.num_attention_heads)
         rope_max_pos_emb = config.max_position_embeddings
-        rope_base = config.rope_theta
+        rope_base = get_rope_theta(config)

         rope_params = dict(emb_type=emb_type, dim=rope_dim, max_position_embeddings=rope_max_pos_emb, base=rope_base)
         update_params = build_rotary_params(config)

10 changes: 6 additions & 4 deletions lmdeploy/pytorch/models/deepseek_v32.py
@@ -12,6 +12,7 @@
 from lmdeploy.pytorch.nn.eplb import EPLBManager
 from lmdeploy.pytorch.nn.linear import build_colwise_linear, build_o_proj, build_rowwise_linear
 from lmdeploy.pytorch.nn.nsa import IndexerTopKFP8
+from lmdeploy.pytorch.nn.rotary_embedding import get_rope_parameters, get_rope_theta

 from .deepseek_v2 import (DeepseekV2Attention, DeepseekV2BMM, DeepseekV2DecoderLayer, DeepseekV2ForCausalLM,
                           DeepseekV2MLP, DeepseekV2Model, DeepseekV2MoE, yarn_get_mscale)
@@ -197,9 +198,10 @@ def __init__(self, config: Any, layer_idx: int, dtype: torch.dtype = None, devic

         self.softmax_scale = self.q_head_dim**(-0.5)

-        if config.rope_scaling is not None:
-            mscale_all_dim = config.rope_scaling.get('mscale_all_dim', 0)
-            scaling_factor = config.rope_scaling['factor']
+        rope_scaling = get_rope_parameters(config)
+        if rope_scaling is not None:
+            mscale_all_dim = rope_scaling.get('mscale_all_dim', 0)
+            scaling_factor = rope_scaling['factor']
             if mscale_all_dim:
                 mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
                 self.softmax_scale = self.softmax_scale * mscale * mscale
@@ -381,7 +383,7 @@ def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device
         rope_dim = config.qk_rope_head_dim if getattr(config, 'use_mla', True) else (config.hidden_size //
                                                                                      config.num_attention_heads)
         rope_max_pos_emb = config.max_position_embeddings
-        rope_base = config.rope_theta
+        rope_base = get_rope_theta(config)

         rope_params = dict(emb_type=emb_type, dim=rope_dim, max_position_embeddings=rope_max_pos_emb, base=rope_base)
         update_params = build_rotary_params(config)

5 changes: 3 additions & 2 deletions lmdeploy/pytorch/models/internlm2_ve.py
@@ -9,6 +9,7 @@
 from lmdeploy.pytorch.models.internlm2 import InternLM2Attention, InternLM2MLP
 from lmdeploy.pytorch.nn import RMSNorm, RopeType, build_rotary_embedding
 from lmdeploy.pytorch.nn.linear import build_rowwise_linear
+from lmdeploy.pytorch.nn.rotary_embedding import get_rope_parameters, get_rope_theta
 from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

 from .utils.cudagraph import CudaGraphMixin
@@ -114,7 +115,7 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:
         self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)

         # build rotary embedding in Model
-        rope_scaling = config.rope_scaling
+        rope_scaling = get_rope_parameters(config)
         scaling_factor = 1.0
         emb_type = RopeType.LinearScaling
         if rope_scaling is not None:
@@ -128,7 +129,7 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:
                 raise RuntimeError(f'Unsupported rope type: {rope_type}')
         rope_dim = config.hidden_size // config.num_attention_heads
         rope_max_pos_emb = config.max_position_embeddings
-        rope_base = config.rope_theta
+        rope_base = get_rope_theta(config)
         self.rotary_emb = build_rotary_embedding(
             rope_dim,
             rope_max_pos_emb,

3 changes: 2 additions & 1 deletion lmdeploy/pytorch/models/llama4.py
@@ -13,6 +13,7 @@
 from lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_merged_colwise_linear, build_qkv_proj,
                                         build_rowwise_linear)
 from lmdeploy.pytorch.nn.moe import build_fused_moe
+from lmdeploy.pytorch.nn.rotary_embedding import get_rope_theta
 from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

 from .utils.cudagraph import CudaGraphMixin
@@ -459,7 +460,7 @@ def __init__(self, config: Llama4VisionConfig, dtype: torch.dtype = None, device
         frequencies_x = img_idx % idx  # get the coordinates of the 2d matrix along x
         frequencies_y = img_idx // idx  # get the coordinates of the 2d matrix along y
         freq_dim = config.hidden_size // config.num_attention_heads // 2
-        rope_freq = 1.0 / (config.rope_theta**(torch.arange(0, freq_dim, 2)[:(freq_dim // 2)].float() / freq_dim))
+        rope_freq = 1.0 / (get_rope_theta(config)**(torch.arange(0, freq_dim, 2)[:(freq_dim // 2)].float() / freq_dim))
         freqs_x = ((frequencies_x + 1)[..., None] * rope_freq[None, None, :]).repeat_interleave(2, dim=-1)
         freqs_y = ((frequencies_y + 1)[..., None] * rope_freq[None, None, :]).repeat_interleave(2, dim=-1)
         freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2]

7 changes: 4 additions & 3 deletions lmdeploy/pytorch/models/minicpm3.py
@@ -11,7 +11,8 @@
 from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
 from lmdeploy.pytorch.nn import Attention, RMSNorm, RopeType, SiluAndMul, build_rotary_embedding
 from lmdeploy.pytorch.nn.linear import build_colwise_linear, build_merged_colwise_linear, build_rowwise_linear
-from lmdeploy.pytorch.nn.rotary_embedding import ApplyRotaryEmb, LongRoPEScalingParameters
+from lmdeploy.pytorch.nn.rotary_embedding import (ApplyRotaryEmb, LongRoPEScalingParameters, get_rope_parameters,
+                                                  get_rope_theta)
 from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

 from .utils.cudagraph import CudaGraphMixin
@@ -298,8 +299,8 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:
         emb_type = RopeType.LinearScaling
         rope_dim = config.qk_rope_head_dim
         rope_max_pos_emb = config.max_position_embeddings
-        rope_base = config.rope_theta
-        rope_scaling = config.rope_scaling
+        rope_base = get_rope_theta(config)
+        rope_scaling = get_rope_parameters(config)
         if rope_scaling is not None:
             scaling_type = rope_scaling['type']
             assert scaling_type in ['longrope', 'su']

7 changes: 4 additions & 3 deletions lmdeploy/pytorch/models/phi3_moe.py
@@ -9,7 +9,8 @@
 from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, LayerNorm, RopeType
 from lmdeploy.pytorch.nn.linear import build_qkv_proj, build_rowwise_linear
 from lmdeploy.pytorch.nn.moe import build_fused_moe
-from lmdeploy.pytorch.nn.rotary_embedding import LongRoPEScalingParameters, build_rotary_embedding
+from lmdeploy.pytorch.nn.rotary_embedding import (LongRoPEScalingParameters, build_rotary_embedding,
+                                                   get_rope_parameters, get_rope_theta)
 from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

 from .utils.cudagraph import CudaGraphMixin
@@ -273,8 +274,8 @@ def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device
         emb_type = RopeType.LinearScaling
         rope_dim = config.hidden_size // config.num_attention_heads
         rope_max_pos_emb = config.max_position_embeddings
-        rope_base = config.rope_theta
-        rope_scaling = config.rope_scaling
+        rope_base = get_rope_theta(config)
+        rope_scaling = get_rope_parameters(config)
         if rope_scaling is not None:
             scaling_type = rope_scaling['type']
             assert scaling_type in ['longrope', 'su']
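
A note on the two longrope users above: minicpm3.py and phi3_moe.py still read the legacy `type` key from the scaling dict, while `build_rotary_params` in rotary_embedding.py (below) falls back from `rope_type` to `type` for backward compatibility. A small sketch of that fallback, using placeholder scaling dicts (key names other than `rope_type`/`type` are illustrative, not taken from a real checkpoint):

```python
from typing import Any, Dict


def resolve_rope_type(rope_scaling: Dict[str, Any]) -> str:
    """BC: 'rope_type' was originally 'type' (same fallback as build_rotary_params)."""
    return rope_scaling.get('rope_type', rope_scaling.get('type', 'default'))


# Phi-3-style longrope dict keyed the old way ('type'); values are placeholders.
legacy = {'type': 'longrope', 'long_factor': [1.0] * 48, 'short_factor': [1.0] * 48}
# Newer-style dict keyed with 'rope_type'.
modern = {'rope_type': 'yarn', 'factor': 4.0}

print(resolve_rope_type(legacy))  # -> 'longrope'
print(resolve_rope_type(modern))  # -> 'yarn'
```
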
38 changes: 29 additions & 9 deletions lmdeploy/pytorch/nn/rotary_embedding.py
@@ -11,21 +11,30 @@
                                               YarnParameters)


+def get_rope_parameters(config: PretrainedConfig):
+    """Try get rope parameters from config."""
+    if hasattr(config, 'rope_parameters'):
+        # for transformers v5
+        return config.rope_parameters
+    else:
+        return getattr(config, 'rope_scaling', None)
+
+
 def _get_default_rope_parameters(config: PretrainedConfig):
     """Get default rope parameters."""
     return dict(emb_type=RopeType.Default, scaling_factor=1.0)


 def _get_linear_scaling_rope_parameters(config: PretrainedConfig):
     """Get linear rope parameters."""
-    rope_scaling = config.rope_scaling
+    rope_scaling = get_rope_parameters(config=config)
     scaling_factor = rope_scaling['factor']
     return dict(emb_type=RopeType.LinearScaling, scaling_factor=scaling_factor)


 def _get_dynamic_ntk_parameters(config: PretrainedConfig):
     """Get dynamic ntk parameters."""
-    rope_scaling = config.rope_scaling
+    rope_scaling = get_rope_parameters(config=config)
     scaling_factor = rope_scaling['factor']
     return dict(emb_type=RopeType.DynamicNTKScaling, scaling_factor=scaling_factor)

@@ -38,7 +47,7 @@ def get_mscale(scale, mscale=1):
             return 1.0
         return 0.1 * mscale * math.log(scale) + 1.0

-    rope_scaling = config.rope_scaling
+    rope_scaling = get_rope_parameters(config=config)
     factor = rope_scaling['factor']
     params = YarnParameters()
     params.beta_fast = rope_scaling.get('beta_fast', params.beta_fast)
@@ -66,7 +75,7 @@

 def _get_longrope_parameters(config: PretrainedConfig):
     """Get longrope parameters."""
-    rope_scaling = config.rope_scaling
+    rope_scaling = get_rope_parameters(config=config)
     scaling_factor = rope_scaling.get('factor', 1.0)
     long_factor = rope_scaling['long_factor']
     short_factor = rope_scaling['short_factor']
@@ -84,7 +93,7 @@ def _get_longrope_parameters(config: PretrainedConfig):

 def _get_llama3_parameters(config: PretrainedConfig):
     """Get llama rope parameters."""
-    rope_scaling = config.rope_scaling
+    rope_scaling = get_rope_parameters(config=config)
     params = Llama3Parameters()
     scaling_factor = rope_scaling['factor']
     params.low_freq_factor = rope_scaling['low_freq_factor']
@@ -104,7 +113,7 @@ def _get_fope_parameters(config: PretrainedConfig):
         return dict()

     params = FopeParameters()
-    rope_scaling = config.rope_scaling
+    rope_scaling = get_rope_parameters(config=config)
     params.num_inv_freq = rope_scaling.get('fope_num_inv_freq', rope_scaling.get('num_inv_freq', params.num_inv_freq))
     params.num_key_value_heads = config.num_key_value_heads
     params.fope_sep_head = rope_scaling['fope_sep_head']
@@ -115,10 +124,10 @@ def build_rotary_params(config: PretrainedConfig):
     """Get scaling_factor rotary params, and emb_type."""
     params = dict(emb_type=RopeType.Default)
     # cannot access config.rope_scaling when the model is "Qwen/Qwen2-Math-RM-72B"
-    rope_scaling = getattr(config, 'rope_scaling', None)
+    rope_scaling = get_rope_parameters(config=config)
     if rope_scaling is not None:
         # BC: "rope_type" was originally "type"
-        rope_type_str = config.rope_scaling.get('rope_type', config.rope_scaling.get('type', 'default'))
+        rope_type_str = rope_scaling.get('rope_type', rope_scaling.get('type', 'default'))
         if rope_type_str == 'fope':
             rope_type_str = 'default'
         build_funcs = dict(default=_get_default_rope_parameters,
@@ -176,14 +185,25 @@ def build_rotary_embedding(dim: int,
     return impl


+def get_rope_theta(config: PretrainedConfig, default: int = 10000) -> int:
+    """Get rope theta from config."""
+    if hasattr(config, 'rope_parameters'):
+        # for transformers v5
+        rope_base = config.rope_parameters.get('rope_theta', default)
+    else:
+        rope_base = getattr(config, 'rope_theta', default)
+    return rope_base
+
+
 def build_rotary_embedding_from_config(config: PretrainedConfig, device: torch.device = None) -> nn.Module:
     """Build rotary embedding op from config."""
     emb_type = RopeType.LinearScaling
     rope_dim = getattr(config, 'head_dim', None)
     if rope_dim is None:
         rope_dim = config.hidden_size // config.num_attention_heads
     rope_max_pos_emb = config.max_position_embeddings
-    rope_base = config.rope_theta
+
+    rope_base = get_rope_theta(config, default=10000)
     rope_params = dict(emb_type=emb_type, dim=rope_dim, max_position_embeddings=rope_max_pos_emb, base=rope_base)
     update_params = build_rotary_params(config)
     rope_params.update(update_params)
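
Taken together, the two new helpers give callers one place to look regardless of whether the config carries transformers v4-style `rope_theta`/`rope_scaling` attributes or a v5-style `rope_parameters` dict. A self-contained sketch of that behaviour with mock config objects (the helpers are inlined copies of the ones added above so the snippet runs without lmdeploy; the mock attribute layouts are assumptions about the two config shapes):

```python
from types import SimpleNamespace


def get_rope_parameters(config):
    """Inlined copy of the helper added above."""
    if hasattr(config, 'rope_parameters'):
        # for transformers v5
        return config.rope_parameters
    return getattr(config, 'rope_scaling', None)


def get_rope_theta(config, default: int = 10000) -> int:
    """Inlined copy of the helper added above."""
    if hasattr(config, 'rope_parameters'):
        return config.rope_parameters.get('rope_theta', default)
    return getattr(config, 'rope_theta', default)


# transformers v4-style config: separate rope_theta / rope_scaling attributes.
v4_cfg = SimpleNamespace(rope_theta=10000.0, rope_scaling={'rope_type': 'yarn', 'factor': 4.0})
# transformers v5-style config: everything folded into rope_parameters (assumed shape).
v5_cfg = SimpleNamespace(rope_parameters={'rope_type': 'yarn', 'factor': 4.0, 'rope_theta': 10000.0})
# config with no rope information at all: both helpers fall back gracefully.
bare_cfg = SimpleNamespace()

for cfg in (v4_cfg, v5_cfg, bare_cfg):
    print(get_rope_theta(cfg), get_rope_parameters(cfg))
```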