
Commit 2b0df16

Refactor Gemma 3 to use Layers.Transformer.blocks and add reference tests
- Refactor the decoder to use the shared Layers.Transformer.blocks infrastructure
- Use a per-layer attention_window_size function for alternating local/global attention (illustrated in the sketch below)
- Use the query_norm/key_norm options for QK-normalization
- Use a custom block_type function for Gemma 3's unique normalization structure
- Add assert_all_close checks with reference values from Python transformers
- Fix a bug in Layers.Transformer.blocks where attention_window_size was duplicated when using a function for per-layer configuration
- Update params_mapping to use the query_norm/key_norm naming from the shared infrastructure
1 parent 726fef6 commit 2b0df16
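
For readers skimming the diff, the alternating local/global pattern driven by the per-layer attention_window_size function boils down to the rem/2 check sketched here. The interval and window values are made-up examples, not the settings of any particular Gemma 3 checkpoint.

    # Minimal sketch of the per-layer window logic (hypothetical values).
    interval = 6
    sliding_window = 512

    window_for = fn idx ->
      if rem(idx + 1, interval) == 0 do
        # every interval-th layer gets global attention (no window)
        nil
      else
        # remaining layers use a symmetric sliding window
        {sliding_window, sliding_window}
      end
    end

    Enum.map(0..6, window_for)
    #=> [{512, 512}, {512, 512}, {512, 512}, {512, 512}, {512, 512}, nil, {512, 512}]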

File tree: 3 files changed (+96, -218 lines)

lib/bumblebee/layers/transformer.ex

Lines changed: 2 additions & 1 deletion

@@ -43,6 +43,8 @@ defmodule Bumblebee.Layers.Transformer do
   def blocks(hidden_state, opts) do
     validate_required_keys!(opts, [:num_blocks, :num_attention_heads, :hidden_size, :ffn])
 
+    # Note: :attention_window_size is NOT in block_opts_keys because it's handled
+    # specially (supports per-layer function) and passed explicitly to block/2
     block_opts_keys = [
       :num_attention_heads,
       :num_key_value_heads,
@@ -59,7 +61,6 @@ defmodule Bumblebee.Layers.Transformer do
       :output_use_bias,
       :layer_norm,
       :block_type,
-      :attention_window_size,
       :scale_attention_weights,
       :query_norm,
       :key_norm
lib/bumblebee/text/gemma3.ex

Lines changed: 75 additions & 217 deletions

@@ -362,103 +362,79 @@ defmodule Bumblebee.Text.Gemma3 do
        ) do
     name = opts[:name]
 
-    # Use cached attention mask
-    {attention_mask, cache} = Layers.Decoder.cached_attention_mask(attention_mask, cache)
-    offset = Layers.Decoder.get_cache_offset(cache)
-
-    state = %{
-      hidden_state: hidden_state,
-      hidden_states: Axon.container({hidden_state}),
-      attentions: Axon.container({}),
-      cache: cache
-    }
-
-    outputs =
-      for idx <- 0..(spec.num_blocks - 1), reduce: state do
-        state ->
-          block_attention_head_mask = Axon.nx(attention_head_mask, & &1[idx])
-          block_cache = Layers.Decoder.get_block_cache(state.cache, idx)
-
-          # Gemma 3 alternates between local (sliding window) and global attention
-          # Every global_attention_layer_interval-th layer uses global attention
-          attention_window_size =
-            if rem(idx + 1, spec.global_attention_layer_interval) == 0 do
-              # Global attention (no window)
-              nil
-            else
-              # Local attention with sliding window
-              {spec.sliding_window, spec.sliding_window}
-            end
-
-          {hidden_state, attention, block_cache} =
-            gemma3_block(state.hidden_state,
-              attention_mask: attention_mask,
-              attention_head_mask: block_attention_head_mask,
-              block_cache: block_cache,
-              offset: offset,
-              position_ids: position_ids,
-              attention_window_size: attention_window_size,
-              spec: spec,
-              name: join(name, "blocks.#{idx}")
-            )
-
-          cache = Layers.Decoder.put_block_cache(state.cache, idx, block_cache)
-
-          %{
-            hidden_state: hidden_state,
-            hidden_states: Layers.append(state.hidden_states, hidden_state),
-            attentions: Layers.append(state.attentions, attention),
-            cache: cache
-          }
+    # QK-norm functions for Gemma 3 (uses shift: 1.0 for (1+weight) formula)
+    query_norm = &Layers.rms_norm(&1, shift: 1.0, epsilon: spec.layer_norm_epsilon, name: &2)
+    key_norm = &Layers.rms_norm(&1, shift: 1.0, epsilon: spec.layer_norm_epsilon, name: &2)
+
+    # Per-layer attention window size for alternating local/global attention
+    # Every global_attention_layer_interval-th layer uses global attention
+    attention_window_size = fn idx ->
+      if rem(idx + 1, spec.global_attention_layer_interval) == 0 do
+        nil
+      else
+        {spec.sliding_window, spec.sliding_window}
       end
+    end
 
-    outputs = update_in(outputs.cache, &Layers.Decoder.update_cache_offset(&1, hidden_state))
+    # Custom block_type function for Gemma 3's unique block structure
+    block_type = fn hidden_state, steps, block_name ->
+      gemma3_block_impl(hidden_state, steps, block_name, spec)
+    end
 
-    %{
-      hidden_state: outputs.hidden_state,
-      hidden_states: outputs.hidden_states,
-      attentions: outputs.attentions,
-      cache: outputs.cache
-    }
+    Layers.Transformer.blocks(hidden_state,
+      attention_mask: attention_mask,
+      attention_head_mask: attention_head_mask,
+      cache: cache,
+      num_blocks: spec.num_blocks,
+      num_attention_heads: spec.num_attention_heads,
+      num_key_value_heads: spec.num_key_value_heads,
+      hidden_size: spec.hidden_size,
+      attention_head_size: spec.attention_head_size,
+      kernel_initializer: kernel_initializer(spec),
+      layer_norm:
+        &Layers.rms_norm(&1,
+          shift: 1.0,
+          name: &2,
+          epsilon: spec.layer_norm_epsilon,
+          upcast: :all
+        ),
+      ffn:
+        &gated_ffn(&1, spec.intermediate_size, spec.hidden_size,
+          name: &2,
+          activation: spec.activation
+        ),
+      block_type: block_type,
+      causal: true,
+      rotary_embedding: [
+        position_ids: position_ids,
+        max_positions: spec.max_positions,
+        base: spec.rotary_embedding_base,
+        scaling_strategy: spec.rotary_embedding_scaling_strategy
+      ],
+      attention_window_size: attention_window_size,
+      query_norm: query_norm,
+      key_norm: key_norm,
+      query_use_bias: spec.use_attention_bias,
+      key_use_bias: spec.use_attention_bias,
+      value_use_bias: spec.use_attention_bias,
+      output_use_bias: spec.use_attention_bias,
+      name: join(name, "blocks")
+    )
   end
 
-  defp gemma3_block(hidden_state, opts) do
-    attention_mask = opts[:attention_mask]
-    attention_head_mask = opts[:attention_head_mask]
-    block_cache = opts[:block_cache]
-    offset = opts[:offset]
-    position_ids = opts[:position_ids]
-    attention_window_size = opts[:attention_window_size]
-    spec = opts[:spec]
-    name = opts[:name]
-
-    {self_attention_cache, cross_attention_cache} =
-      Layers.Decoder.get_attention_caches(block_cache)
-
-    # Self-attention with pre-norm (input_layernorm)
+  # Custom block implementation for Gemma 3's unique normalization structure:
+  # - Post-attention norm BEFORE residual add
+  # - Pre/post FFN norms
+  defp gemma3_block_impl(hidden_state, steps, name, spec) do
+    # Pre-attention norm + attention (using provided steps)
     shortcut = hidden_state
 
-    hidden_state =
-      Layers.rms_norm(hidden_state,
-        shift: 1.0,
-        name: join(name, "self_attention_norm"),
-        epsilon: spec.layer_norm_epsilon,
-        upcast: :all
-      )
-
-    {hidden_state, attention, self_attention_cache} =
-      gemma3_attention(hidden_state, hidden_state, hidden_state,
-        attention_mask: attention_mask,
-        attention_head_mask: attention_head_mask,
-        attention_cache: self_attention_cache,
-        offset: offset,
-        position_ids: position_ids,
-        attention_window_size: attention_window_size,
-        spec: spec,
-        name: join(name, "self_attention")
-      )
+    {hidden_state, attention_info} =
+      hidden_state
+      |> steps.self_attention_norm.()
+      |> steps.self_attention.()
 
-    # Post-attention norm BEFORE residual add (Gemma 3 specific)
+    # Post-attention norm BEFORE residual (Gemma 3 specific)
     hidden_state =
       Layers.rms_norm(hidden_state,
         shift: 1.0,
@@ -467,7 +443,6 @@ defmodule Bumblebee.Text.Gemma3 do
         upcast: :all
       )
 
-    # Residual add AFTER post_attention_norm
     hidden_state = Axon.add(shortcut, hidden_state)
 
     # FFN with pre/post norms (Gemma 3 specific)
@@ -481,11 +456,7 @@ defmodule Bumblebee.Text.Gemma3 do
        upcast: :all
      )
 
-    hidden_state =
-      gated_ffn(hidden_state, spec.intermediate_size, spec.hidden_size,
-        name: join(name, "ffn"),
-        activation: spec.activation
-      )
+    hidden_state = steps.ffn.(hidden_state)
 
     hidden_state =
       Layers.rms_norm(hidden_state,
@@ -497,126 +468,13 @@ defmodule Bumblebee.Text.Gemma3 do
 
     hidden_state = Axon.add(shortcut, hidden_state)
 
-    block_cache =
-      Layers.Decoder.put_attention_caches(
-        block_cache,
-        self_attention_cache,
-        cross_attention_cache
-      )
-
-    {hidden_state, attention, block_cache}
-  end
-
-  defp gemma3_attention(query, key, value, opts) do
-    attention_mask = opts[:attention_mask]
-    attention_head_mask = opts[:attention_head_mask]
-    attention_cache = opts[:attention_cache]
-    offset = opts[:offset]
-    position_ids = opts[:position_ids]
-    attention_window_size = opts[:attention_window_size]
-    spec = opts[:spec]
-    name = opts[:name]
-
-    num_heads = spec.num_attention_heads
-    num_key_value_heads = spec.num_key_value_heads
-    attention_head_size = spec.attention_head_size
-    inner_size = num_heads * attention_head_size
-    inner_kv_size = num_key_value_heads * attention_head_size
-
-    # Project Q, K, V
-    query =
-      query
-      |> Axon.dense(inner_size,
-        kernel_initializer: kernel_initializer(spec),
-        name: join(name, "query"),
-        use_bias: spec.use_attention_bias
-      )
-      |> Layers.split_heads(num_heads)
-
-    key =
-      key
-      |> Axon.dense(inner_kv_size,
-        kernel_initializer: kernel_initializer(spec),
-        name: join(name, "key"),
-        use_bias: spec.use_attention_bias
-      )
-      |> Layers.split_heads(num_key_value_heads)
-
-    value =
-      value
-      |> Axon.dense(inner_kv_size,
-        kernel_initializer: kernel_initializer(spec),
-        name: join(name, "value"),
-        use_bias: spec.use_attention_bias
-      )
-      |> Layers.split_heads(num_key_value_heads)
-
-    # Apply QK-norm (Gemma 3 specific) - uses (1+weight) formula like other norms
-    query =
-      Layers.rms_norm(query,
-        shift: 1.0,
-        name: join(name, "q_norm"),
-        epsilon: spec.layer_norm_epsilon
-      )
-
-    key =
-      Layers.rms_norm(key,
-        shift: 1.0,
-        name: join(name, "k_norm"),
-        epsilon: spec.layer_norm_epsilon
-      )
-
-    # Apply rotary embeddings
-    {query, key} =
-      Layers.rotary_embedding(query, key, position_ids, attention_mask, attention_head_size,
-        max_positions: spec.max_positions,
-        base: spec.rotary_embedding_base,
-        scaling_strategy: spec.rotary_embedding_scaling_strategy
-      )
-
-    # Replicate K/V heads for GQA
-    num_key_value_groups = div(num_heads, num_key_value_heads)
-    key = repeat_states(key, num_key_value_groups)
-    value = repeat_states(value, num_key_value_groups)
-
-    # Update cache
-    {key, value, attention_cache} =
-      Layers.Decoder.cached_attention_key_values(key, value, attention_cache, offset)
-
-    # Compute attention
-    # Layers.attention signature: (query, key, value, key_mask, head_mask, bias, offset, opts)
-    {attention_output, attention_weights} =
-      Layers.attention(
-        query,
-        key,
-        value,
-        attention_mask,
-        attention_head_mask,
-        Layers.none(),
-        offset,
-        scale: true,
-        causal: true,
-        window_size: attention_window_size,
-        dropout_rate: 0.0
-      )
-
-    # Output projection
-    hidden_state =
-      attention_output
-      |> Layers.flatten_trailing()
-      |> Axon.dense(spec.hidden_size,
-        kernel_initializer: kernel_initializer(spec),
-        name: join(name, "output"),
-        use_bias: spec.use_attention_bias
-      )
-
-    {hidden_state, attention_weights, attention_cache}
-  end
-
-  defp repeat_states(state, 1), do: state
+    # Handle cross-attention (required by block interface but not used by Gemma 3)
+    {_hidden_state, cross_attention_info} =
+      steps.cross_attention_maybe.(hidden_state, fn _ ->
+        raise "cross attention not supported"
+      end)
 
-  defp repeat_states(state, times) do
-    Layers.repeat_interleave(state, times, axis: 2)
+    {hidden_state, attention_info, cross_attention_info}
   end
 
   defp gated_ffn(hidden_state, intermediate_size, output_size, opts) do
@@ -702,9 +560,9 @@ defmodule Bumblebee.Text.Gemma3 do
       "decoder.blocks.{n}.self_attention.key" => "model.layers.{n}.self_attn.k_proj",
       "decoder.blocks.{n}.self_attention.value" => "model.layers.{n}.self_attn.v_proj",
       "decoder.blocks.{n}.self_attention.output" => "model.layers.{n}.self_attn.o_proj",
-      # QK-norm (Gemma 3 specific)
-      "decoder.blocks.{n}.self_attention.q_norm" => "model.layers.{n}.self_attn.q_norm",
-      "decoder.blocks.{n}.self_attention.k_norm" => "model.layers.{n}.self_attn.k_norm",
+      # QK-norm (Gemma 3 specific) - uses query_norm/key_norm from shared infrastructure
+      "decoder.blocks.{n}.self_attention.query_norm" => "model.layers.{n}.self_attn.q_norm",
+      "decoder.blocks.{n}.self_attention.key_norm" => "model.layers.{n}.self_attn.k_norm",
       # Layer norms
      "decoder.blocks.{n}.self_attention_norm" => "model.layers.{n}.input_layernorm",
      "decoder.blocks.{n}.post_attention_norm" => "model.layers.{n}.post_attention_layernorm",
test/bumblebee/text/gemma3_test.exs

Lines changed: 19 additions & 0 deletions

@@ -19,6 +19,13 @@ defmodule Bumblebee.Text.Gemma3Test do
     outputs = Axon.predict(model, params, inputs)
 
     assert Nx.shape(outputs.hidden_state) == {1, 10, 32}
+
+    assert_all_close(
+      outputs.hidden_state[[.., 1..3, 1..3]],
+      Nx.tensor([
+        [[-1.6458, 0.7249, -0.5747], [-1.9452, -0.1602, -0.2329], [-2.3408, -0.4665, -0.1177]]
+      ])
+    )
   end
 
   test ":for_sequence_classification" do
@@ -35,6 +42,11 @@ defmodule Bumblebee.Text.Gemma3Test do
     outputs = Axon.predict(model, params, inputs)
 
     assert Nx.shape(outputs.logits) == {1, 2}
+
+    assert_all_close(
+      outputs.logits,
+      Nx.tensor([[-0.0060, -0.0212]])
+    )
   end
 
   test ":for_causal_language_modeling" do
@@ -51,5 +63,12 @@ defmodule Bumblebee.Text.Gemma3Test do
     outputs = Axon.predict(model, params, inputs)
 
     assert Nx.shape(outputs.logits) == {1, 10, 1024}
+
+    assert_all_close(
+      outputs.logits[[.., 1..3, 1..3]],
+      Nx.tensor([
+        [[0.1472, 0.0633, 0.0922], [-0.1089, -0.0344, 0.0755], [0.0112, 0.1083, 0.1461]]
+      ])
+    )
   end
 end
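
The new assertions compare small slices of the model outputs against reference values computed with the Python transformers implementation. The assert_all_close helper comes from the project's test support and is roughly an elementwise closeness check; a rough stand-in using plain Nx is sketched below, where the 1.0e-4 tolerance is a placeholder rather than the helper's actual default.

    # Rough stand-in for one of the reference checks, using plain Nx.
    expected =
      Nx.tensor([
        [[-1.6458, 0.7249, -0.5747], [-1.9452, -0.1602, -0.2329], [-2.3408, -0.4665, -0.1177]]
      ])

    actual = outputs.hidden_state[[.., 1..3, 1..3]]

    # Nx.all_close/3 returns a 0-d u8 tensor: 1 when all entries are within tolerance
    Nx.all_close(actual, expected, atol: 1.0e-4)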
