Enable loading pre-quantized INT4 weights in Llama4 #330

Open

jiawenliu64 wants to merge 2 commits into main from
Conversation
Generate INT4 MP8 checkpoint:

```
torchrun --nproc-per-node=8 -m models.llama4.scripts.quantize --ckpt_dir ../checkpoints/Llama-4-Scout-17B-16E-Instruct --output_dir ../checkpoints/Llama-4-Scout-17B-16E-Instruct-INT4 --quantization_mode int4_mixed --world_size 8
```

Verify generated INT4 MP8 checkpoint with int4_mixed on single GPU (output):

```
PYTHONPATH=$(git rev-parse --show-toplevel) torchrun --nproc_per_node=1 -m models.llama4.scripts.chat_completion ../checkpoints/Llama-4-Scout-17B-16E-Instruct-INT4 --world_size 1 --quantization-mode int4_mixed
```

Generate FP8 MP8 checkpoint:

```
torchrun --nproc-per-node=8 -m models.llama4.scripts.quantize --ckpt_dir ../checkpoints/Llama-4-Scout-17B-16E-Instruct --output_dir ../checkpoints/Llama-4-Scout-17B-16E-Instruct-FP8 --quantization_mode fp8_mixed --world_size 8
```

Verify generated FP8 MP8 checkpoint with fp8_mixed (output):

```
PYTHONPATH=$(git rev-parse --show-toplevel) torchrun --nproc_per_node=8 -m models.llama4.scripts.chat_completion ../checkpoints/Llama-4-Scout-17B-16E-Instruct-FP8 --world_size 8 --quantization-mode fp8_mixed
```

Verify BF16 MP8 checkpoint (output):

```
PYTHONPATH=$(git rev-parse --show-toplevel) torchrun --nproc_per_node=8 -m models.llama4.scripts.chat_completion ../checkpoints/Llama-4-Scout-17B-16E-Instruct --world_size 8
```

Verify BF16 MP8 checkpoint with fp8_mixed (output):

```
PYTHONPATH=$(git rev-parse --show-toplevel) torchrun --nproc_per_node=8 -m models.llama4.scripts.chat_completion ../checkpoints/Llama-4-Scout-17B-16E-Instruct --world_size 8 --quantization-mode fp8_mixed
```

Verify BF16 MP8 checkpoint with int4_mixed on single GPU (output):

```
PYTHONPATH=$(git rev-parse --show-toplevel) torchrun --nproc_per_node=1 -m models.llama4.scripts.chat_completion ../checkpoints/Llama-4-Scout-17B-16E-Instruct --world_size 1 --quantization-mode int4_mixed
```
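For readers unfamiliar with what an INT4 checkpoint contains, the core idea is that two signed 4-bit weight codes are packed into each byte, alongside per-row scales used to dequantize at load time. The sketch below is illustrative only (NumPy, symmetric per-row quantization); it is not the PR's actual packing scheme, and all function names are hypothetical.

```python
import numpy as np

def quantize_int4(w):
    """Illustrative symmetric per-row INT4 quantization.
    Returns packed uint8 codes (two 4-bit values per byte) and per-row scales."""
    scale = np.abs(w).max(axis=1, keepdims=True) / 7.0  # map [-max, max] onto [-7, 7]
    q = np.clip(np.round(w / scale), -8, 7).astype(np.int8)
    # Pack adjacent pairs: even columns in the low nibble, odd columns in the high nibble.
    lo = (q[:, 0::2] & 0x0F).astype(np.uint8)
    hi = (q[:, 1::2] & 0x0F).astype(np.uint8)
    return lo | (hi << 4), scale

def dequantize_int4(packed, scale, cols):
    """Unpack the 4-bit codes, sign-extend them, and rescale to float."""
    lo = (packed & 0x0F).astype(np.int8)
    hi = ((packed >> 4) & 0x0F).astype(np.int8)
    lo = np.where(lo > 7, lo - 16, lo)  # sign-extend 4-bit two's complement
    hi = np.where(hi > 7, hi - 16, hi)
    q = np.empty((packed.shape[0], cols), dtype=np.int8)
    q[:, 0::2] = lo
    q[:, 1::2] = hi
    return q * scale

rng = np.random.default_rng(0)
w = rng.standard_normal((4, 8)).astype(np.float32)
packed, scale = quantize_int4(w)
w_hat = dequantize_int4(packed, scale, w.shape[1])
```

The packed tensor uses half the bytes of an INT8 checkpoint, and the round-trip error is bounded by half a quantization step per element.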
jianyuh approved these changes on Apr 24, 2025
shethaadit approved these changes on Apr 25, 2025
ashwinb reviewed on Apr 26, 2025
```python
self.int4_weight = int4_weight
dtype = torch.get_default_dtype()
if int4_weight:
```
Contributor
this feels like complexity that truly doesn't belong at this layer. can we please keep it outside into quantization code somehow?
Contributor
we don't want llama-models to become torchao or vllm or whatever really. it is not a full fledged all powerful inference engine.
ashwinb requested changes on Apr 26, 2025
```python
torch.set_default_tensor_type(torch.BFloat16Tensor)
model = Transformer(model_args)
print("Loading state dict...")
model.load_state_dict(state_dict, strict=False)
```
Contributor
if you move the model.load_state_dict() to convert_to_quantized_model() then you can do the following:
- change the structure of the Transformer from the outside in this code path (whatever you are doing with Experts)
- move all this scale ckpt paths complexity into quantization land
nobody reading generation.py should know about quantization unless they want to dig into it.
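The separation the reviewer suggests can be sketched roughly as follows. This is a hypothetical illustration, not the PR's actual code: the stub `Transformer`, the string-valued `experts` attribute, and the key-remapping logic are all stand-ins for the real restructuring and scale-checkpoint handling.

```python
class Transformer:
    """Stub standing in for the real Llama4 Transformer."""
    def __init__(self):
        self.experts = "bf16_experts"
        self.loaded = False

    def load_state_dict(self, state_dict, strict=True):
        self.loaded = True

def convert_to_quantized_model(model, state_dict, quantization_mode):
    """Quantization module owns restructuring, scale-key remapping, and loading."""
    if quantization_mode == "int4_mixed":
        model.experts = "int4_experts"  # restructure Experts from the outside
        # Hypothetical remap of scale-checkpoint keys; lives here, not in generation.py.
        state_dict = {k.replace(".scale", ".int4_scale"): v
                      for k, v in state_dict.items()}
    model.load_state_dict(state_dict, strict=False)
    return model

def build(state_dict, quantization_mode=None):
    """generation.py builds a plain BF16 model; no quantization branches here."""
    model = Transformer()
    if quantization_mode:
        return convert_to_quantized_model(model, state_dict, quantization_mode)
    model.load_state_dict(state_dict, strict=False)
    return model

m = build({"w.scale": 1.0}, quantization_mode="int4_mixed")
```

With `load_state_dict` moved inside `convert_to_quantized_model`, the quantization code can freely restructure the model and remap checkpoint keys before loading, and `generation.py` never branches on the quantization mode.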