@@ -362,103 +362,79 @@ defmodule Bumblebee.Text.Gemma3 do
       ) do
     name = opts[:name]
 
-    # Use cached attention mask
-    {attention_mask, cache} = Layers.Decoder.cached_attention_mask(attention_mask, cache)
-    offset = Layers.Decoder.get_cache_offset(cache)
-
-    state = %{
-      hidden_state: hidden_state,
-      hidden_states: Axon.container({hidden_state}),
-      attentions: Axon.container({}),
-      cache: cache
-    }
-
-    outputs =
-      for idx <- 0..(spec.num_blocks - 1), reduce: state do
-        state ->
-          block_attention_head_mask = Axon.nx(attention_head_mask, & &1[idx])
-          block_cache = Layers.Decoder.get_block_cache(state.cache, idx)
-
-          # Gemma 3 alternates between local (sliding window) and global attention
-          # Every global_attention_layer_interval-th layer uses global attention
-          attention_window_size =
-            if rem(idx + 1, spec.global_attention_layer_interval) == 0 do
-              # Global attention (no window)
-              nil
-            else
-              # Local attention with sliding window
-              {spec.sliding_window, spec.sliding_window}
-            end
-
-          {hidden_state, attention, block_cache} =
-            gemma3_block(state.hidden_state,
-              attention_mask: attention_mask,
-              attention_head_mask: block_attention_head_mask,
-              block_cache: block_cache,
-              offset: offset,
-              position_ids: position_ids,
-              attention_window_size: attention_window_size,
-              spec: spec,
-              name: join(name, "blocks.#{idx}")
-            )
-
-          cache = Layers.Decoder.put_block_cache(state.cache, idx, block_cache)
-
-          %{
-            hidden_state: hidden_state,
-            hidden_states: Layers.append(state.hidden_states, hidden_state),
-            attentions: Layers.append(state.attentions, attention),
-            cache: cache
-          }
+    # QK-norm functions for Gemma 3 (both use shift: 1.0 for the (1 + weight) formula)
+    query_norm = &Layers.rms_norm(&1, shift: 1.0, epsilon: spec.layer_norm_epsilon, name: &2)
+    key_norm = &Layers.rms_norm(&1, shift: 1.0, epsilon: spec.layer_norm_epsilon, name: &2)
+
+    # Per-layer attention window size for alternating local/global attention
+    # Every global_attention_layer_interval-th layer uses global attention
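+    # (e.g. with global_attention_layer_interval: 6, blocks at idx 5, 11, 17, ... use global attention)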
+    attention_window_size = fn idx ->
+      if rem(idx + 1, spec.global_attention_layer_interval) == 0 do
+        nil
+      else
+        {spec.sliding_window, spec.sliding_window}
       end
+    end
 
-    outputs = update_in(outputs.cache, &Layers.Decoder.update_cache_offset(&1, hidden_state))
+    # Custom block_type function for Gemma 3's unique block structure
+    block_type = fn hidden_state, steps, block_name ->
+      gemma3_block_impl(hidden_state, steps, block_name, spec)
+    end
 
-    %{
-      hidden_state: outputs.hidden_state,
-      hidden_states: outputs.hidden_states,
-      attentions: outputs.attentions,
-      cache: outputs.cache
-    }
+    Layers.Transformer.blocks(hidden_state,
+      attention_mask: attention_mask,
+      attention_head_mask: attention_head_mask,
+      cache: cache,
+      num_blocks: spec.num_blocks,
+      num_attention_heads: spec.num_attention_heads,
+      num_key_value_heads: spec.num_key_value_heads,
+      hidden_size: spec.hidden_size,
+      attention_head_size: spec.attention_head_size,
+      kernel_initializer: kernel_initializer(spec),
+      layer_norm:
+        &Layers.rms_norm(&1,
+          shift: 1.0,
+          name: &2,
+          epsilon: spec.layer_norm_epsilon,
+          upcast: :all
+        ),
+      ffn:
+        &gated_ffn(&1, spec.intermediate_size, spec.hidden_size,
+          name: &2,
+          activation: spec.activation
+        ),
+      block_type: block_type,
+      causal: true,
+      rotary_embedding: [
+        position_ids: position_ids,
+        max_positions: spec.max_positions,
+        base: spec.rotary_embedding_base,
+        scaling_strategy: spec.rotary_embedding_scaling_strategy
+      ],
+      attention_window_size: attention_window_size,
+      query_norm: query_norm,
+      key_norm: key_norm,
+      query_use_bias: spec.use_attention_bias,
+      key_use_bias: spec.use_attention_bias,
+      value_use_bias: spec.use_attention_bias,
+      output_use_bias: spec.use_attention_bias,
+      name: join(name, "blocks")
+    )
   end
 
-  defp gemma3_block(hidden_state, opts) do
-    attention_mask = opts[:attention_mask]
-    attention_head_mask = opts[:attention_head_mask]
-    block_cache = opts[:block_cache]
-    offset = opts[:offset]
-    position_ids = opts[:position_ids]
-    attention_window_size = opts[:attention_window_size]
-    spec = opts[:spec]
-    name = opts[:name]
-
-    {self_attention_cache, cross_attention_cache} =
-      Layers.Decoder.get_attention_caches(block_cache)
-
-    # Self-attention with pre-norm (input_layernorm)
+  # Custom block implementation for Gemma 3's unique normalization structure:
+  # - Post-attention norm BEFORE residual add
+  # - Pre/post FFN norms
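+  # Flow: norm -> attention -> post-attention norm -> residual add -> pre-FFN norm -> ffn -> post-FFN norm -> residual add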
+  defp gemma3_block_impl(hidden_state, steps, name, spec) do
+    # Pre-attention norm + attention (using the provided steps)
     shortcut = hidden_state
 
-    hidden_state =
-      Layers.rms_norm(hidden_state,
-        shift: 1.0,
-        name: join(name, "self_attention_norm"),
-        epsilon: spec.layer_norm_epsilon,
-        upcast: :all
-      )
-
-    {hidden_state, attention, self_attention_cache} =
-      gemma3_attention(hidden_state, hidden_state, hidden_state,
-        attention_mask: attention_mask,
-        attention_head_mask: attention_head_mask,
-        attention_cache: self_attention_cache,
-        offset: offset,
-        position_ids: position_ids,
-        attention_window_size: attention_window_size,
-        spec: spec,
-        name: join(name, "self_attention")
-      )
+    {hidden_state, attention_info} =
+      hidden_state
+      |> steps.self_attention_norm.()
+      |> steps.self_attention.()
 
-    # Post-attention norm BEFORE residual add (Gemma 3 specific)
+    # Post-attention norm BEFORE residual (Gemma 3 specific)
     hidden_state =
       Layers.rms_norm(hidden_state,
         shift: 1.0,
@@ -467,7 +443,6 @@ defmodule Bumblebee.Text.Gemma3 do
         upcast: :all
       )
 
-    # Residual add AFTER post_attention_norm
     hidden_state = Axon.add(shortcut, hidden_state)
 
     # FFN with pre/post norms (Gemma 3 specific)
@@ -481,11 +456,7 @@ defmodule Bumblebee.Text.Gemma3 do
         upcast: :all
       )
 
-    hidden_state =
-      gated_ffn(hidden_state, spec.intermediate_size, spec.hidden_size,
-        name: join(name, "ffn"),
-        activation: spec.activation
-      )
+    hidden_state = steps.ffn.(hidden_state)
 
     hidden_state =
       Layers.rms_norm(hidden_state,
@@ -497,126 +468,13 @@ defmodule Bumblebee.Text.Gemma3 do
 
     hidden_state = Axon.add(shortcut, hidden_state)
 
-    block_cache =
-      Layers.Decoder.put_attention_caches(
-        block_cache,
-        self_attention_cache,
-        cross_attention_cache
-      )
-
-    {hidden_state, attention, block_cache}
-  end
-
-  defp gemma3_attention(query, key, value, opts) do
-    attention_mask = opts[:attention_mask]
-    attention_head_mask = opts[:attention_head_mask]
-    attention_cache = opts[:attention_cache]
-    offset = opts[:offset]
-    position_ids = opts[:position_ids]
-    attention_window_size = opts[:attention_window_size]
-    spec = opts[:spec]
-    name = opts[:name]
-
-    num_heads = spec.num_attention_heads
-    num_key_value_heads = spec.num_key_value_heads
-    attention_head_size = spec.attention_head_size
-    inner_size = num_heads * attention_head_size
-    inner_kv_size = num_key_value_heads * attention_head_size
-
-    # Project Q, K, V
-    query =
-      query
-      |> Axon.dense(inner_size,
-        kernel_initializer: kernel_initializer(spec),
-        name: join(name, "query"),
-        use_bias: spec.use_attention_bias
-      )
-      |> Layers.split_heads(num_heads)
-
-    key =
-      key
-      |> Axon.dense(inner_kv_size,
-        kernel_initializer: kernel_initializer(spec),
-        name: join(name, "key"),
-        use_bias: spec.use_attention_bias
-      )
-      |> Layers.split_heads(num_key_value_heads)
-
-    value =
-      value
-      |> Axon.dense(inner_kv_size,
-        kernel_initializer: kernel_initializer(spec),
-        name: join(name, "value"),
-        use_bias: spec.use_attention_bias
-      )
-      |> Layers.split_heads(num_key_value_heads)
-
-    # Apply QK-norm (Gemma 3 specific) - uses (1 + weight) formula like other norms
-    query =
-      Layers.rms_norm(query,
-        shift: 1.0,
-        name: join(name, "q_norm"),
-        epsilon: spec.layer_norm_epsilon
-      )
-
-    key =
-      Layers.rms_norm(key,
-        shift: 1.0,
-        name: join(name, "k_norm"),
-        epsilon: spec.layer_norm_epsilon
-      )
-
-    # Apply rotary embeddings
-    {query, key} =
-      Layers.rotary_embedding(query, key, position_ids, attention_mask, attention_head_size,
-        max_positions: spec.max_positions,
-        base: spec.rotary_embedding_base,
-        scaling_strategy: spec.rotary_embedding_scaling_strategy
-      )
-
-    # Replicate K/V heads for GQA
-    num_key_value_groups = div(num_heads, num_key_value_heads)
-    key = repeat_states(key, num_key_value_groups)
-    value = repeat_states(value, num_key_value_groups)
-
-    # Update cache
-    {key, value, attention_cache} =
-      Layers.Decoder.cached_attention_key_values(key, value, attention_cache, offset)
-
-    # Compute attention
-    # Layers.attention signature: (query, key, value, key_mask, head_mask, bias, offset, opts)
-    {attention_output, attention_weights} =
-      Layers.attention(
-        query,
-        key,
-        value,
-        attention_mask,
-        attention_head_mask,
-        Layers.none(),
-        offset,
-        scale: true,
-        causal: true,
-        window_size: attention_window_size,
-        dropout_rate: 0.0
-      )
-
-    # Output projection
-    hidden_state =
-      attention_output
-      |> Layers.flatten_trailing()
-      |> Axon.dense(spec.hidden_size,
-        kernel_initializer: kernel_initializer(spec),
-        name: join(name, "output"),
-        use_bias: spec.use_attention_bias
-      )
-
-    {hidden_state, attention_weights, attention_cache}
-  end
-
-  defp repeat_states(state, 1), do: state
+    # Handle cross-attention (required by the block interface but not used by Gemma 3)
+    {_hidden_state, cross_attention_info} =
+      steps.cross_attention_maybe.(hidden_state, fn _ ->
+        raise "cross attention not supported"
+      end)
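+    # The hidden state from this step is discarded; cross_attention_info is only
+    # passed through so the return tuple below matches the expected block shape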
 
-  defp repeat_states(state, times) do
-    Layers.repeat_interleave(state, times, axis: 2)
+    {hidden_state, attention_info, cross_attention_info}
   end
 
   defp gated_ffn(hidden_state, intermediate_size, output_size, opts) do
@@ -702,9 +560,9 @@ defmodule Bumblebee.Text.Gemma3 do
702560 "decoder.blocks.{n}.self_attention.key" => "model.layers.{n}.self_attn.k_proj" ,
703561 "decoder.blocks.{n}.self_attention.value" => "model.layers.{n}.self_attn.v_proj" ,
704562 "decoder.blocks.{n}.self_attention.output" => "model.layers.{n}.self_attn.o_proj" ,
705- # QK-norm (Gemma 3 specific)
706- "decoder.blocks.{n}.self_attention.q_norm " => "model.layers.{n}.self_attn.q_norm" ,
707- "decoder.blocks.{n}.self_attention.k_norm " => "model.layers.{n}.self_attn.k_norm" ,
563+ # QK-norm (Gemma 3 specific) - uses query_norm/key_norm from shared infrastructure
564+ "decoder.blocks.{n}.self_attention.query_norm " => "model.layers.{n}.self_attn.q_norm" ,
565+ "decoder.blocks.{n}.self_attention.key_norm " => "model.layers.{n}.self_attn.k_norm" ,
708566 # Layer norms
709567 "decoder.blocks.{n}.self_attention_norm" => "model.layers.{n}.input_layernorm" ,
710568 "decoder.blocks.{n}.post_attention_norm" => "model.layers.{n}.post_attention_layernorm" ,