
feat: Add native FP8 model support with scale_inv dequantization#443

Closed
nyo16 wants to merge 0 commits into elixir-nx:main from nyo16:main

Conversation

Contributor

nyo16 commented Jan 8, 2026

Summary

Add native FP8 quantized model support for models like Qwen3-FP8. This enables loading and running FP8 models with per-block
scale factors (scale_inv) for dequantization.

Changes

bumblebee.ex

  • Add :preserve_source_types option to load_model/2 to keep FP8 types during loading

pytorch_params.ex

  • Pass preserve_source_types through param loading pipeline
  • Modify ensure_type/3 to preserve FP8 types when the option is set
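The effect on ensure_type/3 can be sketched like this (a Python sketch of the decision logic, not the actual Elixir; the FP8 type tags are assumptions for illustration):

```python
# Hypothetical FP8 type tags, standing in for Nx's FP8 types.
FP8_TYPES = {("f8e4m3", 8), ("f8e5m2", 8)}

def ensure_type(source_type, policy_type, preserve_source_types=False):
    # With the option set, FP8 weights keep their source type instead of
    # being cast to the model's compute type (e.g. f32/bf16).
    if preserve_source_types and source_type in FP8_TYPES:
        return source_type
    return policy_type
```

Without the option, behavior is unchanged: every parameter is cast to the policy type, which is what keeps non-FP8 models backward compatible.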

layers.ex

  • Add fp8_aware_dense/3, a dense layer that handles FP8-quantized weights
  • Implement block-wise dequantization using the scale_inv parameter
  • Automatically fall back to identity scaling (1.0) for non-FP8 models
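The block-wise dequantization can be sketched numerically as follows (NumPy here for a language-neutral illustration; the 128×128 block size and names are assumptions, and FP8 storage is simulated with float16 since NumPy has no FP8 dtype):

```python
import numpy as np

def dequantize_blockwise(w_q, scale_inv, block=128):
    # Each scale_inv entry covers one (block x block) tile of the weight
    # matrix; dequantization multiplies every element by its tile's scale.
    rows, cols = w_q.shape
    scales = np.repeat(
        np.repeat(scale_inv.astype(np.float32), block, axis=0), block, axis=1
    )
    return w_q.astype(np.float32) * scales[:rows, :cols]
```

The non-FP8 fallback corresponds to a scale_inv of all ones, which leaves the weights unchanged.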

layers/transformer.ex

  • Add :attention_dense option to blocks/2, block/2, and multi_head_attention/4
  • Allows passing a custom dense function for the Q, K, V, and output projections
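The option amounts to dependency-injecting the dense layer into attention; a minimal Python sketch of the pattern (all names illustrative, not Bumblebee's actual API):

```python
def default_dense(x, name):
    # Stands in for a regular linear projection layer.
    return f"dense({x}, {name})"

def multi_head_attention(hidden, attention_dense=None):
    # Fall back to the standard dense layer when no override is given,
    # so existing (non-FP8) call sites are unaffected.
    dense = attention_dense or default_dense
    return tuple(dense(hidden, n) for n in ("query", "key", "value", "output"))
```

Passing fp8_aware_dense/3 through this option is what routes the FP8 dequantization into all four attention projections without changing the transformer code itself.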

text/qwen3.ex

  • Update decoder to use fp8_aware_dense for attention via attention_dense option
  • Update gated_ffn to use fp8_aware_dense for FFN layers
  • Add scale_inv to params_mapping for all attention and FFN layers

Test plan

  • FP8 model (Qwen3-0.6B-FP8) generates correct output ("Paris" for capital of France)
  • Non-FP8 model (Qwen3-0.6B) still works correctly (backward compatible)
  • Tested on RTX 5070 Ti (Blackwell, SM 12.0)

Dependencies

Requires (merge in order):

  1. elixir-nx/safetensors - FP8 file I/O
  2. elixir-nx/nx - FP8 type system support

Usage

```elixir
# Load FP8 model with native weights
{:ok, model_info} =
  Bumblebee.load_model(
    {:hf, "Qwen/Qwen3-0.6B-FP8"},
    architecture: :for_causal_language_modeling,
    preserve_source_types: true
  )

# Use normally - scale_inv dequantization happens automatically
serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config)
Nx.Serving.run(serving, "The capital of France is")
# => "Paris..."
```

nyo16 marked this pull request as draft January 8, 2026 17:39
@josevalim
Contributor

Thank you! This PR should probably wait until this is done: elixir-nx/nx#1657 (comment)

