Commit 0c4cfc3

nyo16 and claude committed
feat: Add native FP8 model support with scale_inv dequantization
Add comprehensive FP8 quantized model support for models like Qwen3-FP8. This enables loading and running FP8 models with per-block scale factors.

Changes:

bumblebee.ex:
- Add :preserve_source_types option to load_model/2 to keep FP8 types

pytorch_params.ex:
- Pass preserve_source_types through the param loading pipeline
- Modify ensure_type/3 to preserve FP8 types when the option is set

layers.ex:
- Add fp8_aware_dense/3 layer that handles FP8 quantized weights
- Implements block-wise dequantization using the scale_inv parameter
- Automatically falls back to identity scaling for non-FP8 models

layers/transformer.ex:
- Add :attention_dense option to blocks/2, block/2 and multi_head_attention/4
- Allows a custom dense function for the Q, K, V and output projections

text/qwen3.ex:
- Update the decoder to use fp8_aware_dense for attention via attention_dense
- Update gated_ffn to use fp8_aware_dense for the FFN layers
- Add scale_inv to params_mapping for all attention and FFN layers

The implementation supports both:
- Pre-dequantization: convert FP8 -> F32 before loading
- Native FP8: load FP8 weights directly and apply scale_inv at runtime

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent bbd4d83 commit 0c4cfc3

5 files changed: +257 -27 lines changed

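The native FP8 path described in the commit message is driven entirely by the new :preserve_source_types option on Bumblebee.load_model/2. A minimal usage sketch of that path, assuming an FP8 checkpoint with per-block scale_inv tensors (the repository id below is a placeholder, not something this commit pins down):

# Keep FP8 weight types from the checkpoint instead of casting them to the
# model's parameter type; scale_inv is then applied at runtime by the
# fp8_aware_dense layers. The repo id is illustrative only.
repo = {:hf, "Qwen/Qwen3-8B-FP8"}

{:ok, %{model: model, params: params, spec: spec}} =
  Bumblebee.load_model(repo, preserve_source_types: true)

With the default (preserve_source_types: false) the loader keeps its existing behaviour and casts loaded tensors to the type the model expects.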

lib/bumblebee.ex

Lines changed: 3 additions & 2 deletions
@@ -607,7 +607,8 @@ defmodule Bumblebee do
       :params_filename,
       :log_params_diff,
       :backend,
-      :type
+      :type,
+      :preserve_source_types
     ])
 
   with {:ok, repo_files} <- get_repo_files(repository),
@@ -654,7 +655,7 @@ defmodule Bumblebee do
     [
       params_mapping: params_mapping,
       loader_fun: loader_fun
-    ] ++ Keyword.take(opts, [:backend, :log_params_diff])
+    ] ++ Keyword.take(opts, [:backend, :log_params_diff, :preserve_source_types])
 
   params = Bumblebee.Conversion.PyTorchParams.load_params!(model, input_template, paths, opts)
   {:ok, params}

lib/bumblebee/conversion/pytorch_params.ex

Lines changed: 17 additions & 8 deletions
@@ -28,6 +28,11 @@ defmodule Bumblebee.Conversion.PyTorchParams do
       and loads the params file. Defaults to
       `Bumblebee.Conversion.PyTorchLoader.load!/1`
 
+    * `:preserve_source_types` - when `true`, preserves FP8 types from the
+      source file instead of converting them to the model's expected type.
+      This is useful for loading quantized models that use FP8 weights.
+      Defaults to `false`
+
   """
   @spec load_params!(Axon.t(), map(), Path.t() | list(Path.t()), keyword()) :: %Axon.ModelState{}
   def load_params!(model, input_template, path, opts \\ []) do
@@ -36,6 +41,7 @@ defmodule Bumblebee.Conversion.PyTorchParams do
       |> Keyword.validate!([
         :log_params_diff,
         :backend,
+        :preserve_source_types,
         params_mapping: %{},
         loader_fun: &Bumblebee.Conversion.PyTorchLoader.load!/1
       ])
@@ -58,7 +64,8 @@ defmodule Bumblebee.Conversion.PyTorchParams do
     model_state = Axon.trace_init(model, input_template)
 
     params_expr = model_state.data
-    {params, diff} = init_params(model, params_expr, pytorch_state, opts[:params_mapping])
+    preserve_source_types = opts[:preserve_source_types] || false
+    {params, diff} = init_params(model, params_expr, pytorch_state, opts[:params_mapping], preserve_source_types)
     model_state = %{model_state | data: params}
 
     params_complete? = diff.missing == [] and diff.mismatched == []
@@ -95,15 +102,15 @@ defmodule Bumblebee.Conversion.PyTorchParams do
     Nx.Container.impl_for(value) != nil
   end
 
-  defp init_params(model, params_expr, pytorch_state, params_mapping) do
+  defp init_params(model, params_expr, pytorch_state, params_mapping, preserve_source_types) do
     layers =
       model
       |> Utils.Axon.nodes_with_names()
       |> Enum.filter(fn {layer, _name} -> layer.parameters != [] end)
 
     prefixes = infer_prefixes(layers, pytorch_state, params_mapping)
 
-    diff = %{missing: [], mismatched: [], used_keys: []}
+    diff = %{missing: [], mismatched: [], used_keys: [], preserve_source_types: preserve_source_types}
 
     {params, diff} =
       layers
@@ -155,7 +162,7 @@ defmodule Bumblebee.Conversion.PyTorchParams do
 
         case verify_param_shape(param_expr, value) do
           :ok ->
-            value = ensure_type(param_expr, value)
+            value = ensure_type(param_expr, value, diff.preserve_source_types)
             {value, diff}
 
           {:error, expected, actual} ->
@@ -507,11 +514,13 @@ defmodule Bumblebee.Conversion.PyTorchParams do
     Utils.Nx.map(expr, &Nx.shape/1)
   end
 
-  defp ensure_type(param_expr, value) do
+  defp ensure_type(param_expr, value, preserve_source_types \\ false) do
     Utils.Nx.zip_with(param_expr, value, fn expr, tensor ->
-      case {Nx.type(expr), Nx.type(tensor)} do
-        {type, type} -> tensor
-        {expected, _actual} -> Nx.as_type(tensor, expected)
+      case {Nx.type(expr), Nx.type(tensor), preserve_source_types} do
+        {type, type, _} -> tensor
+        # Preserve FP8 types when preserve_source_types is enabled
+        {_expected, {:f, 8, _format}, true} -> tensor
+        {expected, _actual, _} -> Nx.as_type(tensor, expected)
       end
     end)
  end
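
The new third clause in ensure_type/3 is the core of the change: an FP8 tensor is kept as-is only when the flag is set. Below is a standalone sketch of just that decision, using the same {:f, 8, _format} pattern the commit matches on; the concrete FP8 type tuple is an assumption of this sketch and depends on the Nx/loader version in use.

defmodule EnsureTypeSketch do
  # Returns the type a loaded tensor should end up with, mirroring the
  # three case clauses in ensure_type/3 above.

  # Types already agree: keep the tensor as-is.
  def resolve(type, type, _preserve?), do: type

  # FP8 source tensor and preservation enabled: keep the quantized type.
  # The 3-tuple FP8 representation here is illustrative.
  def resolve(_expected, {:f, 8, _format} = fp8_type, true), do: fp8_type

  # Otherwise cast to the type the model expects.
  def resolve(expected, _actual, _preserve?), do: expected
end

EnsureTypeSketch.resolve({:bf, 16}, {:f, 8, :e4m3}, true)  #=> {:f, 8, :e4m3}
EnsureTypeSketch.resolve({:bf, 16}, {:f, 8, :e4m3}, false) #=> {:bf, 16}
EnsureTypeSketch.resolve({:f, 32}, {:f, 32}, false)        #=> {:f, 32}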

lib/bumblebee/layers.ex

Lines changed: 122 additions & 0 deletions
@@ -438,6 +438,128 @@ defmodule Bumblebee.Layers do
     |> Nx.add(bias)
   end
 
+  @doc """
+  Adds an FP8-aware dense layer to the network.
+
+  This layer supports an optional `scale_inv` parameter for FP8 quantized
+  weights. When `scale_inv` is provided, it is used to dequantize the kernel
+  block-wise before the matmul, accounting for FP8 quantization scaling.
+
+  The kernel parameter uses the standard dense layout (transposed from PyTorch).
+
+  ## Options
+
+    * `:name` - layer name
+
+    * `:kernel_initializer` - initializer for `kernel` weights.
+      Defaults to `:glorot_uniform`
+
+    * `:use_bias` - whether the layer should add bias to the output.
+      Defaults to `false`
+
+    * `:block_size` - the block size used for FP8 quantization.
+      Defaults to 128
+
+  """
+  def fp8_aware_dense(%Axon{} = x, units, opts \\ []) do
+    opts =
+      Keyword.validate!(opts, [
+        :name,
+        kernel_initializer: :glorot_uniform,
+        use_bias: false,
+        block_size: 128
+      ])
+
+    name = opts[:name]
+    block_size = opts[:block_size]
+
+    kernel_shape = &Axon.Shape.dense_kernel(&1, units)
+    bias_shape = &Axon.Shape.dense_bias(&1, units)
+
+    # Scale shape: [input_blocks, output_blocks] where block_size is typically 128.
+    # This matches the transposed layout from PyTorch (the kernel is transposed, so is the scale).
+    # For non-FP8 models, scale_inv will be initialized to 1.0
+    scale_shape = fn input_shape ->
+      in_features = elem(input_shape, tuple_size(input_shape) - 1)
+      out_features = units
+      # Round up to handle cases where dimensions aren't exact multiples of block_size
+      out_blocks = div(out_features + block_size - 1, block_size)
+      in_blocks = div(in_features + block_size - 1, block_size)
+      # Note: [in_blocks, out_blocks] to match the transposed scale_inv from PyTorch
+      {in_blocks, out_blocks}
+    end
+
+    kernel = Axon.param("kernel", kernel_shape, initializer: opts[:kernel_initializer])
+
+    # scale_inv is initialized to 1.0 (identity) for non-FP8 models.
+    # For FP8 models, it will be loaded from the checkpoint
+    scale_inv = Axon.param("scale_inv", scale_shape, initializer: :ones)
+
+    {inputs, op} =
+      if opts[:use_bias] do
+        bias = Axon.param("bias", bias_shape, initializer: :zeros)
+        {[x, kernel, scale_inv, bias], &fp8_aware_dense_impl(&1, &2, &3, &4, &5, block_size)}
+      else
+        {[x, kernel, scale_inv], &fp8_aware_dense_impl(&1, &2, &3, nil, &4, block_size)}
+      end
+
+    Axon.layer(op, inputs, name: name, op_name: :fp8_aware_dense)
+  end
+
+  deftransformp fp8_aware_dense_impl(x, kernel, scale_inv, bias, _opts, block_size) do
+    # Dequantize the kernel using scale_inv before the matmul
+    # kernel: [in_features, out_features]
+    # scale_inv: [in_blocks, out_blocks] (transposed from the PyTorch layout)
+    # Each block_size x block_size block of the kernel is multiplied by its scale
+    kernel_dequant = dequantize_kernel(kernel, scale_inv, block_size)
+
+    # Do the matmul with the dequantized kernel
+    # x: [batch, seq_len, in_features]
+    # kernel_dequant: [in_features, out_features]
+    # result: [batch, seq_len, out_features]
+    result = Nx.dot(x, [-1], kernel_dequant, [0])
+
+    # Add bias if present
+    if bias do
+      Nx.add(result, bias)
+    else
+      result
+    end
+  end
+
+  defp dequantize_kernel(kernel, scale_inv, block_size) do
+    # kernel: [in_features, out_features]
+    # scale_inv: [in_blocks, out_blocks] where in_blocks = ceil(in_features / block_size)
+    #
+    # To dequantize: each element kernel[i, o] is multiplied by
+    # scale_inv[div(i, block_size), div(o, block_size)].
+    # This is done by expanding scale_inv to match the kernel shape
+
+    {in_features, out_features} = Nx.shape(kernel)
+    {in_blocks, out_blocks} = Nx.shape(scale_inv)
+
+    # Expand scale_inv to [in_features, out_features].
+    # Each scale value is replicated block_size times in both dimensions
+    scale_expanded =
+      scale_inv
+      # Replicate along the input dimension: [in_blocks, out_blocks] -> [in_blocks * block_size, out_blocks]
+      |> Nx.reshape({in_blocks, 1, out_blocks})
+      |> Nx.broadcast({in_blocks, block_size, out_blocks})
+      |> Nx.reshape({in_blocks * block_size, out_blocks})
+      # Replicate along the output dimension: [..., out_blocks] -> [..., out_blocks * block_size]
+      |> Nx.reshape({in_blocks * block_size, out_blocks, 1})
+      |> Nx.broadcast({in_blocks * block_size, out_blocks, block_size})
+      |> Nx.reshape({in_blocks * block_size, out_blocks * block_size})
+
+    # Slice to the exact kernel dimensions (in case they're not exact multiples of block_size)
+    scale_expanded =
+      scale_expanded
+      |> Nx.slice([0, 0], [in_features, out_features])
+
+    # Convert the kernel to higher precision for dequantization, then multiply by the scale
+    kernel_f32 = Nx.as_type(kernel, {:f, 32})
+    Nx.multiply(kernel_f32, scale_expanded)
+  end
+
   @doc """
   Adds a 1-dimensional convolution layer to the network.
 
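
The reshape/broadcast dance in dequantize_kernel/3 just tiles each scale over its block_size x block_size block. A toy-sized sketch with block_size 2 (values and shapes invented for illustration) makes the expansion visible:

# block_size shrunk to 2 so the expanded scale matrix is easy to inspect
block_size = 2
kernel = Nx.iota({4, 4}, type: :f32)             # stands in for the FP8 kernel, [in, out]
scale_inv = Nx.tensor([[0.5, 2.0], [1.0, 4.0]])  # [in_blocks, out_blocks]

{in_blocks, out_blocks} = Nx.shape(scale_inv)

scale_expanded =
  scale_inv
  # replicate each scale block_size times along the input dimension
  |> Nx.reshape({in_blocks, 1, out_blocks})
  |> Nx.broadcast({in_blocks, block_size, out_blocks})
  |> Nx.reshape({in_blocks * block_size, out_blocks})
  # replicate each scale block_size times along the output dimension
  |> Nx.reshape({in_blocks * block_size, out_blocks, 1})
  |> Nx.broadcast({in_blocks * block_size, out_blocks, block_size})
  |> Nx.reshape({in_blocks * block_size, out_blocks * block_size})

# scale_expanded is now
#   [[0.5, 0.5, 2.0, 2.0],
#    [0.5, 0.5, 2.0, 2.0],
#    [1.0, 1.0, 4.0, 4.0],
#    [1.0, 1.0, 4.0, 4.0]]
# so every 2x2 block of the kernel gets multiplied by its own scale factor.
dequantized = Nx.multiply(kernel, scale_expanded)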

lib/bumblebee/layers/transformer.ex

Lines changed: 23 additions & 7 deletions
@@ -63,7 +63,8 @@ defmodule Bumblebee.Layers.Transformer do
       :block_type,
       :attention_scale,
       :query_norm,
-      :key_norm
+      :key_norm,
+      :attention_dense
     ]
 
     opts =
@@ -354,7 +355,8 @@ defmodule Bumblebee.Layers.Transformer do
         attention_scale: nil,
         rotary_embedding: nil,
         query_norm: nil,
-        key_norm: nil
+        key_norm: nil,
+        attention_dense: nil
       ])
 
     name = opts[:name]
@@ -386,6 +388,7 @@ defmodule Bumblebee.Layers.Transformer do
     rotary_embedding = opts[:rotary_embedding]
     query_norm = opts[:query_norm]
     key_norm = opts[:key_norm]
+    attention_dense = opts[:attention_dense]
 
     ffn_fun =
       case ffn do
@@ -446,6 +449,7 @@ defmodule Bumblebee.Layers.Transformer do
         rotary_embedding: rotary_embedding,
         query_norm: query_norm,
         key_norm: key_norm,
+        attention_dense: attention_dense,
         name: join(name, "self_attention")
       )
 
@@ -491,6 +495,7 @@ defmodule Bumblebee.Layers.Transformer do
         attention_window_size: attention_window_size,
         attention_scale: attention_scale,
         rotary_embedding: rotary_embedding,
+        attention_dense: attention_dense,
         name: join(name, "cross_attention")
       )
 
@@ -772,7 +777,8 @@ defmodule Bumblebee.Layers.Transformer do
        output_use_bias: true,
        rotary_embedding: nil,
        query_norm: nil,
-       key_norm: nil
+       key_norm: nil,
+       attention_dense: nil
      ])
 
    attention_mask = opts[:attention_mask]
@@ -792,6 +798,7 @@ defmodule Bumblebee.Layers.Transformer do
    rotary_embedding = opts[:rotary_embedding]
    query_norm = opts[:query_norm]
    key_norm = opts[:key_norm]
+   attention_dense = opts[:attention_dense]
 
    query_use_bias = opts[:query_use_bias]
    key_use_bias = opts[:key_use_bias]
@@ -804,9 +811,18 @@ defmodule Bumblebee.Layers.Transformer do
    inner_size = num_heads * attention_head_size
    inner_kv_size = num_key_value_heads * attention_head_size
 
+   # Helper to create dense layer, using custom attention_dense if provided
+   dense_fn = fn input, units, dense_opts ->
+     if attention_dense do
+       attention_dense.(input, units, dense_opts)
+     else
+       Axon.dense(input, units, dense_opts)
+     end
+   end
+
    query =
      query
-     |> Axon.dense(inner_size,
+     |> dense_fn.(inner_size,
        kernel_initializer: kernel_initializer,
        name: join(name, "query"),
        use_bias: query_use_bias
@@ -815,7 +831,7 @@ defmodule Bumblebee.Layers.Transformer do
 
    key =
      key
-     |> Axon.dense(inner_kv_size,
+     |> dense_fn.(inner_kv_size,
        kernel_initializer: kernel_initializer,
        name: join(name, "key"),
        use_bias: key_use_bias
@@ -824,7 +840,7 @@ defmodule Bumblebee.Layers.Transformer do
 
    value =
      value
-     |> Axon.dense(inner_kv_size,
+     |> dense_fn.(inner_kv_size,
        kernel_initializer: kernel_initializer,
        name: join(name, "value"),
        use_bias: value_use_bias
@@ -937,7 +953,7 @@ defmodule Bumblebee.Layers.Transformer do
    attention_output =
      attention_output
      |> Layers.flatten_trailing()
-     |> Axon.dense(hidden_size,
+     |> dense_fn.(hidden_size,
        kernel_initializer: kernel_initializer,
        name: join(name, "output"),
        use_bias: output_use_bias
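
According to the commit message, text/qwen3.ex routes its attention projections through fp8_aware_dense/3 using this new :attention_dense option; that file's diff is not shown on this page, so the fragment below is only an illustrative sketch of such wiring, not the actual qwen3.ex change. The spec fields and the remaining block options are elided.

# :attention_dense expects a 3-arity builder with the same shape as
# Axon.dense/3 (input, units, opts). The opts passed for the query, key,
# value and output projections are :kernel_initializer, :name and
# :use_bias, all of which fp8_aware_dense/3 accepts, so the capture can be
# passed directly.
Bumblebee.Layers.Transformer.blocks(hidden_state,
  attention_dense: &Bumblebee.Layers.fp8_aware_dense/3,
  num_blocks: spec.num_blocks,
  num_attention_heads: spec.num_attention_heads,
  hidden_size: spec.hidden_size,
  # ...plus the block options the model already configures
  # (ffn, attention_head_size, rotary_embedding, layer_norm, and so on)
  name: "decoder.blocks"
)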
