fix rotary embedding for transformers v5 #4303
Conversation
Pull request overview
Updates RoPE (rotary embedding) configuration handling to improve compatibility with Transformers v5 configs by introducing helpers to read rope_theta and (partially) rope_parameters.
Changes:
- Added `_get_rope_parameters()` and `get_rope_theta()` helpers in `rotary_embedding.py` (a sketch of the parameters helper follows this list).
- Switched multiple model implementations from direct `config.rope_theta` access to `get_rope_theta(config)`.
- Updated RoPE scaling param extraction in `rotary_embedding.py` to read from `rope_parameters` when present.
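The diff excerpt further below shows `get_rope_theta()`; the companion `_get_rope_parameters()` helper is not included in it. The following is a minimal sketch of what such a lookup could look like, assuming transformers v5 folds the former `rope_scaling` dict into `rope_parameters`; the actual implementation in `rotary_embedding.py` may differ.

```python
from transformers import PretrainedConfig


def _get_rope_parameters(config: PretrainedConfig) -> dict:
    """Hypothetical sketch: return RoPE scaling parameters for both v4- and v5-style configs."""
    rope_params = getattr(config, 'rope_parameters', None)
    if rope_params is not None:
        # transformers v5: theta and scaling fields live in a single dict
        return rope_params
    # pre-v5: fall back to the separate `rope_scaling` attribute (may be absent or None)
    return getattr(config, 'rope_scaling', None) or {}
```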
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| lmdeploy/pytorch/nn/rotary_embedding.py | Adds v5-aware helpers and updates RoPE param/theta extraction logic. |
| lmdeploy/pytorch/models/phi3_moe.py | Uses get_rope_theta when building rotary embedding. |
| lmdeploy/pytorch/models/minicpm3.py | Uses get_rope_theta when building rotary embedding. |
| lmdeploy/pytorch/models/llama4.py | Uses get_rope_theta for vision rotary frequency computation. |
| lmdeploy/pytorch/models/internlm2_ve.py | Uses get_rope_theta when building rotary embedding. |
| lmdeploy/pytorch/models/deepseek_v32.py | Uses get_rope_theta when building rotary embedding params. |
| lmdeploy/pytorch/models/deepseek_v2.py | Uses get_rope_theta when building rotary embedding params. |
| lmdeploy/pytorch/models/deepseek_mtp.py | Uses get_rope_theta when building rotary embedding params. |
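For the model files listed above, the call-site change is roughly of this shape. This is illustrative only: the import path follows the file location in the table, and the wrapper function name is hypothetical; the exact call sites in each model may differ.

```python
from transformers import PretrainedConfig

from lmdeploy.pytorch.nn.rotary_embedding import get_rope_theta  # assumed import path


def _resolve_theta(config: PretrainedConfig) -> float:
    # Before this PR, model code read the top-level attribute directly:
    #     rope_theta = config.rope_theta
    # which can fail on transformers v5 configs that only carry `rope_parameters`.
    # The helper resolves theta from either layout:
    return get_rope_theta(config)
```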
```python
def get_rope_theta(config: PretrainedConfig, default: int = 100000) -> int:
    """Get rope theta from config."""
    if hasattr(config, 'rope_parameters'):
        # for transformers v5
        rope_base = config.rope_parameters.get('rope_theta', default)
    else:
        rope_base = config.rope_theta
    return rope_base
```
The new Transformers v5 compatibility paths (_get_rope_parameters / get_rope_theta) don’t appear to be covered by tests. Adding a small unit test that constructs a dummy PretrainedConfig with rope_parameters (and without rope_scaling / rope_theta) would help prevent regressions in RoPE parameter parsing.
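A minimal sketch of such a test, assuming the import path matches the file location listed in the table and that a plain `PretrainedConfig` does not define `rope_parameters` by default; the expected values are illustrative.

```python
from transformers import PretrainedConfig

from lmdeploy.pytorch.nn.rotary_embedding import get_rope_theta  # assumed import path


def test_get_rope_theta_v5_config():
    # transformers v5 style: theta lives inside `rope_parameters`
    config = PretrainedConfig()
    config.rope_parameters = {'rope_theta': 500000.0}
    assert get_rope_theta(config) == 500000.0


def test_get_rope_theta_v4_config():
    # pre-v5 style: theta is a top-level attribute
    config = PretrainedConfig()
    config.rope_theta = 10000.0
    assert get_rope_theta(config) == 10000.0
```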
Thanks for your contribution; we appreciate it a lot. The following instructions will make your pull request healthier and help it receive feedback more easily. If you do not understand some items, don't worry: just make the pull request and seek help from the maintainers.
Motivation
Please describe the motivation for this PR and the goal you want to achieve through it.
Modification
Please briefly describe the modifications made in this PR.
BC-breaking (Optional)
Does the modification introduce changes that break backward compatibility with downstream repositories?
If so, please describe how it breaks compatibility and how downstream projects should modify their code to remain compatible with this PR.
Use cases (Optional)
If this PR introduces a new feature, it is better to list some use cases here and update the documentation.
Checklist