Support ignore layers in quant config for qwen3 models #4293
RunningLeon wants to merge 4 commits into InternLM:main
Conversation
Pull request overview
This pull request adds support for ignoring specific layers in quantization configurations, allowing fine-grained control over which layers are quantized. The changes involve significant refactoring of how quantization configurations are passed and used throughout the codebase.
Changes:
- Introduced a new `QuantizationConfig` dataclass to centralize quantization configuration management (see the sketch after this list)
- Added a `prefix` parameter throughout the model hierarchy to enable layer-specific quantization control
- Refactored quantization config handling from a dictionary-based to an object-based approach
- Removed the `model_format` parameter from several spec_decode interfaces, consolidating it into the new config system
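To make the shape of the change concrete, here is a minimal sketch of such a config object. The field and method names (`ignored_layers`, `get_quant_method`) are inferred from the diff snippets quoted later in this review; this is not the PR's exact code.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class QuantizationConfig:
    """Sketch of a centralized quantization config (assumed shape)."""
    quant_method: Optional[str] = None
    ignored_layers: List[str] = field(default_factory=list)

    def get_quant_method(self, prefix: str = '') -> Optional[str]:
        """Return the quant method for the module at `prefix`
        (e.g. 'model.layers.0.mlp.down_proj'), or None when that module is
        listed in `ignored_layers` and should stay unquantized
        (assumed fallback behavior)."""
        if not prefix or not self.ignored_layers:
            return self.quant_method
        # Matching strategy mirrors the diff; its pitfalls are discussed below.
        is_ignore = any(prefix in layer_name for layer_name in self.ignored_layers)
        return None if is_ignore else self.quant_method
```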
Reviewed changes
Copilot reviewed 20 out of 20 changed files in this pull request and generated 4 comments.
Summary per file:
| File | Description |
|---|---|
| lmdeploy/pytorch/config.py | Added QuantizationConfig dataclass with methods for layer-specific quantization control and config parsing |
| lmdeploy/pytorch/model_inputs.py | Added quant_config field to BuildModelContext for passing quantization config during model building |
| lmdeploy/pytorch/models/utils/model.py | Added update_quant_config method to handle layer name transformations for ignored layers |
| lmdeploy/pytorch/models/patch.py | Refactored quantization config patching and removed old _patch_quantization_config function |
| lmdeploy/pytorch/models/qwen*.py | Added prefix parameter propagation throughout model hierarchies for layer identification |
| lmdeploy/pytorch/models/glm4_1v.py, internvl3_hf.py, qwen3_vl.py | Changed rename_weight to classmethod to support static usage |
| `lmdeploy/pytorch/nn/linear/__init__.py` | Updated linear layer builders to accept and use prefix for quantization control |
| lmdeploy/pytorch/nn/norm.py | Updated RMSNorm to accept prefix and use QuantizationConfig object |
| `lmdeploy/pytorch/nn/moe/__init__.py` | Updated MoE builder to accept prefix and use QuantizationConfig object |
| lmdeploy/pytorch/engine/model_agent/agent.py | Updated to pass quant_config through BuildModelContext |
| `lmdeploy/pytorch/engine/executor/__init__.py` | Added model_format parameter passing to build_executor |
| lmdeploy/pytorch/spec_decode/*.py | Removed model_format parameter from build_model methods |
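A builder that receives both the config and its own dotted `prefix` can then decide per layer whether to quantize. The following is a hypothetical sketch of that flow, not lmdeploy's actual builder signature (the real builders in lmdeploy/pytorch/nn/linear take more arguments):

```python
import torch.nn as nn


def build_linear(in_features: int,
                 out_features: int,
                 bias: bool = True,
                 quant_config=None,
                 prefix: str = ''):
    """Hypothetical builder: skip quantization for ignored layers."""
    quant_method = quant_config.get_quant_method(prefix) if quant_config is not None else None
    if quant_method is None:
        # Layer is ignored (or there is no quant config): build a plain linear.
        return nn.Linear(in_features, out_features, bias=bias)
    # Otherwise dispatch to the quantized implementation for `quant_method`
    # (omitted in this sketch).
    raise NotImplementedError(f'quant method {quant_method!r} not covered here')
```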
Comments suppressed due to low confidence (1)
lmdeploy/pytorch/models/qwen3.py:75
- The RMSNorm instances on lines 74-75 (q_norm and k_norm) are instantiated without the prefix parameter, while other RMSNorm instances in this file (like input_layernorm and post_attention_layernorm) are being updated to include prefix. For consistency and to ensure the quantization ignore layers functionality works correctly for these norms, they should also include the prefix parameter.
These should be updated to:

    self.q_norm = RMSNorm(head_dim, config.rms_norm_eps, dtype=dtype, device=device, prefix=add_prefix('q_norm', prefix))
    self.k_norm = RMSNorm(head_dim, config.rms_norm_eps, dtype=dtype, device=device, prefix=add_prefix('k_norm', prefix))

Current code (lines 74-75):

    self.q_norm = RMSNorm(head_dim, config.rms_norm_eps, dtype=dtype, device=device)
    self.k_norm = RMSNorm(head_dim, config.rms_norm_eps, dtype=dtype, device=device)
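The `add_prefix` used in the suggested lines is the usual helper for composing dotted module names; a minimal version of such a helper (hypothetical, shown only to make the prefix convention explicit) is:

```python
def add_prefix(name: str, prefix: str) -> str:
    """Join a child module name onto its parent's dotted prefix, e.g.
    add_prefix('q_norm', 'model.layers.0.self_attn')
    -> 'model.layers.0.self_attn.q_norm'."""
    return name if not prefix else f'{prefix}.{name}'
```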
Code under review:

    if not prefix or not self.ignored_layers:
        return self.quant_method

    is_ignore = any([prefix in layer_name for layer_name in self.ignored_layers])
The check on line 600 uses substring matching with `prefix in layer_name`, which can produce false positives. For example, if prefix is "layer" and ignored_layers contains "my_layer_norm", it would incorrectly match.
A more robust approach is to check whether the prefix exactly matches the layer name or a complete dotted segment of it, e.g. using `layer_name.startswith(prefix)` or explicit component matching with proper delimiter handling.
Suggested change:

    - is_ignore = any([prefix in layer_name for layer_name in self.ignored_layers])
    + def _matches_prefix(p: str, layer_name: str) -> bool:
    +     """Return True if p and layer_name refer to the same module or
    +     one is a dotted-prefix of the other.
    +     """
    +     if not p or not layer_name:
    +         return False
    +     if p == layer_name:
    +         return True
    +     if layer_name.startswith(p + '.'):
    +         return True
    +     if p.startswith(layer_name + '.'):
    +         return True
    +     return False
    + is_ignore = any(_matches_prefix(prefix, layer_name)
    +                 for layer_name in self.ignored_layers)
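A quick, purely illustrative check of the difference (the helper is restated so the snippet is self-contained):

```python
def _matches_prefix(p: str, layer_name: str) -> bool:
    if not p or not layer_name:
        return False
    return (p == layer_name
            or layer_name.startswith(p + '.')
            or p.startswith(layer_name + '.'))


# Plain substring matching produces a false positive:
assert 'layer' in 'my_layer_norm'                      # matches, although unrelated
# Dotted-prefix matching does not:
assert not _matches_prefix('layer', 'my_layer_norm')
# It still matches genuine parent/child module relationships:
assert _matches_prefix('model.layers.0.mlp', 'model.layers.0.mlp.down_proj')
```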
Diff under review:

      # build attention layer norm
    - self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)
    + self.post_attention_layernorm = RMSNorm(config.hidden_size,
    +                                         config.rms_norm_eps,
The post_attention_layernorm RMSNorm is missing the quant_config parameter on lines 310-314, while the input_layernorm on lines 302-307 correctly includes it. This inconsistency could cause issues with quantization configuration. Both layer norms should have the same parameters for consistency.
The post_attention_layernorm should include quant_config=quantization_config to match the pattern used for input_layernorm.
Suggested change:

    -     config.rms_norm_eps,
    +     config.rms_norm_eps,
    +     quant_config=quantization_config,
Code under review:

    quant_config = getattr(hf_config, 'quantization_config', None)
    if quant_config is None:
        if model_format == 'fp8':
            from lmdeploy.pytorch.envs import scale_fmt
            quant_config = dict(quant_method='fp8', fmt='e4m3', weight_block_size=[128, 128], scale_fmt=scale_fmt)

    if quant_config is None:
        return cls()
In the from_config method, when hf_config already has a quantization_config set, the model_format parameter (if provided as 'fp8') will be ignored. This behavior differs slightly from the removed _patch_quantization_config function which would log a warning. Consider adding a warning log when model_format is provided but ignored due to existing quantization_config, to help users understand why their model_format specification is not being applied.
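A hedged sketch of what such a warning could look like (helper name and placement are illustrative only; the variables follow the snippet above):

```python
import logging

logger = logging.getLogger(__name__)  # lmdeploy has its own logger utilities; plain logging keeps the sketch self-contained


def warn_if_model_format_ignored(hf_config, model_format):
    """Warn when an explicit model_format is superseded by the checkpoint's
    own quantization_config, mirroring the removed _patch_quantization_config
    behavior described above."""
    quant_config = getattr(hf_config, 'quantization_config', None)
    if quant_config is not None and model_format == 'fp8':
        logger.warning('model_format="fp8" is ignored because the checkpoint '
                       'already provides a quantization_config.')
```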
Code under review:

    if '.experts' in layer_name:
        added_ignore_layers.add(layer_name.split('.experts', 1)[0] + '.experts')
    else:
        added_ignore_layers.add(layer_name.replace('.down_proj', '.down_proj'))
The logic on line 79 appears to be incorrect: it replaces '.down_proj' with '.down_proj' (the same text), which is a no-op and looks like a copy-paste error from the '.gate_proj' case above.
Looking at line 74, the '.gate_proj' case correctly replaces it with '.gate_up_proj'. For '.down_proj', this line should likely either:
- be removed, if no transformation is needed for down_proj in non-MoE layers, or
- use a different replacement target.

Based on the surrounding pattern, either the replacement should not happen at all or it needs different logic.
Suggested change:

    - added_ignore_layers.add(layer_name.replace('.down_proj', '.down_proj'))
    + added_ignore_layers.add(layer_name)
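Putting the pieces of this discussion together, the name mapping in `update_quant_config` presumably looks something like the sketch below. The function name and the `.gate_proj` branch are reconstructed from the comments above, not copied from the PR:

```python
def _expand_ignored_layers(ignored_layers):
    """Hypothetical sketch: map ignored layer names from the checkpoint onto
    the names of the fused modules built at runtime."""
    added_ignore_layers = set()
    for layer_name in ignored_layers:
        if '.experts' in layer_name:
            # Ignoring any single expert means ignoring the whole fused experts module.
            added_ignore_layers.add(layer_name.split('.experts', 1)[0] + '.experts')
        elif '.gate_proj' in layer_name:
            # gate_proj is fused into gate_up_proj at build time
            # (per the review comment on the '.gate_proj' case at line 74).
            added_ignore_layers.add(layer_name.replace('.gate_proj', '.gate_up_proj'))
        else:
            # down_proj and other names are kept unchanged, hence the suggested fix above.
            added_ignore_layers.add(layer_name)
    return added_ignore_layers
```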
Motivation
Support ignoring the layers listed in the quantization config.
Modification
- Add a `QuantizationConfig` dataclass in `lmdeploy/pytorch/config.py` and pass it to model building through `BuildModelContext`.
- Propagate a `prefix` argument through the qwen* model hierarchies, linear builders, `RMSNorm`, and the MoE builder so that layers listed in the config's ignore list can be skipped during quantization.
- Remove the `model_format` parameter from the spec_decode `build_model` interfaces, consolidating it into the new config handling.
BC-breaking (Optional)
Does the modification introduce changes that break the backward-compatibility of the downstream repositories?
If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR.
Use cases (Optional)
If this PR introduces a new feature, it is better to list some use cases here, and update the documentation.
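One illustrative use case, with field names assumed to follow common FP8 / compressed-tensors checkpoints rather than taken from this PR: a quantized checkpoint whose config keeps a few sensitive layers in the original dtype.

```python
# Illustrative quantization_config a checkpoint might ship (field names assumed);
# 'fmt' and 'weight_block_size' follow the fp8 defaults shown in the diff above.
quantization_config = {
    'quant_method': 'fp8',
    'fmt': 'e4m3',
    'weight_block_size': [128, 128],
    'ignore': ['lm_head', 'model.layers.0.mlp.down_proj'],
}
# With this PR, lmdeploy would skip quantization for the listed layers instead of
# quantizing them along with the rest of the model.
```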
Checklist