
Commit 726fef6

Merge branch 'main' into feat/add-functiongemma-support

2 parents 1fc7aaf + b1b7cf1

10 files changed (+1099 −4 lines)

lib/bumblebee.ex

Lines changed: 4 additions & 0 deletions

@@ -185,6 +185,9 @@ defmodule Bumblebee do
       "Phi3ForCausalLM" => {Bumblebee.Text.Phi3, :for_causal_language_modeling},
       "Phi3ForSequenceClassification" => {Bumblebee.Text.Phi3, :for_sequence_classification},
       "Phi3ForTokenClassification" => {Bumblebee.Text.Phi3, :for_token_classification},
+      "Qwen3Model" => {Bumblebee.Text.Qwen3, :base},
+      "Qwen3ForCausalLM" => {Bumblebee.Text.Qwen3, :for_causal_language_modeling},
+      "Qwen3ForSequenceClassification" => {Bumblebee.Text.Qwen3, :for_sequence_classification},
       "ResNetForImageClassification" => {Bumblebee.Vision.ResNet, :for_image_classification},
       "ResNetModel" => {Bumblebee.Vision.ResNet, :base},
       "RobertaForMaskedLM" => {Bumblebee.Text.Roberta, :for_masked_language_modeling},
@@ -267,6 +270,7 @@ defmodule Bumblebee do
       "mbart" => :mbart,
       "phi" => :code_gen,
       "phi3" => :llama,
+      "qwen3" => :qwen2,
       "roberta" => :roberta,
       "smollm3" => :smollm3,
       "t5" => :t5,

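With the architecture mapping above, a Qwen3 checkpoint loads like any other supported model. A minimal sketch, assuming a Qwen3 causal LM checkpoint on the Hugging Face Hub (the repository name below is illustrative):

    # Load model, tokenizer, and generation config (repo name illustrative)
    repo = {:hf, "Qwen/Qwen3-0.6B"}
    {:ok, model_info} = Bumblebee.load_model(repo)
    {:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
    {:ok, generation_config} = Bumblebee.load_generation_config(repo)

    # Build a generation serving and run a prompt through it
    serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config)
    Nx.Serving.run(serving, "The capital of France is")
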
lib/bumblebee/layers/transformer.ex

Lines changed: 39 additions & 3 deletions

@@ -59,7 +59,10 @@ defmodule Bumblebee.Layers.Transformer do
       :output_use_bias,
       :layer_norm,
       :block_type,
-      :scale_attention_weights
+      :attention_window_size,
+      :scale_attention_weights,
+      :query_norm,
+      :key_norm
     ]

     opts =
@@ -348,7 +351,9 @@ defmodule Bumblebee.Layers.Transformer do
         layer_norm: [],
         attention_window_size: nil,
         scale_attention_weights: true,
-        rotary_embedding: nil
+        rotary_embedding: nil,
+        query_norm: nil,
+        key_norm: nil
       ])

     name = opts[:name]
@@ -378,6 +383,8 @@ defmodule Bumblebee.Layers.Transformer do
     attention_window_size = opts[:attention_window_size]
     scale_attention_weights = opts[:scale_attention_weights]
     rotary_embedding = opts[:rotary_embedding]
+    query_norm = opts[:query_norm]
+    key_norm = opts[:key_norm]

     ffn_fun =
       case ffn do
@@ -436,6 +443,8 @@ defmodule Bumblebee.Layers.Transformer do
         attention_window_size: attention_window_size,
         scale_attention_weights: scale_attention_weights,
         rotary_embedding: rotary_embedding,
+        query_norm: query_norm,
+        key_norm: key_norm,
         name: join(name, "self_attention")
       )

@@ -721,6 +730,14 @@ defmodule Bumblebee.Layers.Transformer do

       * `:max_positions` - the maximum number of distinct positions

+    * `:query_norm` - a function that applies normalization to the query
+      projection before rotary embedding. The function should accept two
+      arguments: the input and a name for the layer. Defaults to `nil`
+
+    * `:key_norm` - a function that applies normalization to the key
+      projection before rotary embedding. The function should accept two
+      arguments: the input and a name for the layer. Defaults to `nil`
+
     * `:name` - the prefix for layer names

   ## References
@@ -752,7 +769,9 @@ defmodule Bumblebee.Layers.Transformer do
         key_use_bias: true,
         value_use_bias: true,
         output_use_bias: true,
-        rotary_embedding: nil
+        rotary_embedding: nil,
+        query_norm: nil,
+        key_norm: nil
       ])

     attention_mask = opts[:attention_mask]
@@ -770,6 +789,8 @@ defmodule Bumblebee.Layers.Transformer do
     scale_attention_weights = opts[:scale_attention_weights]
     dropout_rate = opts[:dropout_rate]
     rotary_embedding = opts[:rotary_embedding]
+    query_norm = opts[:query_norm]
+    key_norm = opts[:key_norm]

     query_use_bias = opts[:query_use_bias]
     key_use_bias = opts[:key_use_bias]
@@ -809,6 +830,21 @@ defmodule Bumblebee.Layers.Transformer do
       )
       |> Layers.split_heads(num_key_value_heads)

+    # Apply query and key normalization if configured (before rotary embedding)
+    query =
+      if query_norm do
+        query_norm.(query, join(name, "query_norm"))
+      else
+        query
+      end
+
+    key =
+      if key_norm do
+        key_norm.(key, join(name, "key_norm"))
+      else
+        key
+      end
+
     {query, key} =
       case rotary_embedding do
         opts when is_list(opts) ->
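The new `:query_norm` and `:key_norm` hooks each take a two-argument function (the input and a layer name), as documented in the hunk above. A sketch of how a model implementation might wire in per-projection RMS norm; the `Bumblebee.Layers.rms_norm/2` call and the epsilon value are assumptions modeled on other Bumblebee models, not part of this commit:

    # Hypothetical caller-side wiring; only :query_norm/:key_norm come from this commit
    query_norm = fn query, name ->
      Bumblebee.Layers.rms_norm(query, epsilon: 1.0e-6, name: name)
    end

    key_norm = fn key, name ->
      Bumblebee.Layers.rms_norm(key, epsilon: 1.0e-6, name: name)
    end

    Bumblebee.Layers.Transformer.blocks(hidden_state,
      query_norm: query_norm,
      key_norm: key_norm,
      # ...the usual block options (num_blocks, num_attention_heads, etc.)...
      name: "decoder.blocks"
    )
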

lib/bumblebee/text.ex

Lines changed: 79 additions & 0 deletions

@@ -385,6 +385,9 @@ defmodule Bumblebee.Text do
         Note that we currently assume that the CLS token is the first token
         in the sequence

+      * `:last_token_pooling` - takes the embedding for the last non-padding
+        token in each sequence
+
       By default no pooling is applied

     * `:embedding_processor` - a post-processing step to apply to the
@@ -444,6 +447,82 @@ defmodule Bumblebee.Text do
   defdelegate text_embedding(model_info, tokenizer, opts \\ []),
     to: Bumblebee.Text.TextEmbedding

+  @type text_reranking_qwen3_input :: {String.t(), String.t()} | [{String.t(), String.t()}]
+  @type text_reranking_qwen3_output :: %{
+          scores: text_reranking_qwen3_score() | list(text_reranking_qwen3_score())
+        }
+  @type text_reranking_qwen3_score :: %{score: number(), query: String.t(), document: String.t()}
+
+  @doc """
+  Builds a serving for text reranking with Qwen3 reranker models.
+
+  The serving expects input in one of the following formats:
+
+    * `{query, document}` - a tuple with query and document text
+    * `[{query1, doc1}, {query2, doc2}, ...]` - a list of query-document pairs
+
+  ## Options
+
+    * `:yes_token` - the token ID corresponding to "yes" for relevance scoring.
+      If not provided, will be inferred from the tokenizer
+
+    * `:no_token` - the token ID corresponding to "no" for relevance scoring.
+      If not provided, will be inferred from the tokenizer
+
+    * `:instruction_prefix` - the instruction prefix to use. Defaults to the
+      Qwen3 reranker format
+
+    * `:instruction_suffix` - the instruction suffix to use. Defaults to the
+      Qwen3 reranker format
+
+    * `:task_description` - the task description to include in prompts. Defaults
+      to "Given a web search query, retrieve relevant passages that answer the query"
+
+    * `:compile` - compiles all computations for predefined input shapes
+      during serving initialization. Should be a keyword list with the
+      following keys:
+
+        * `:batch_size` - the maximum batch size of the input. Inputs
+          are optionally padded to always match this batch size
+
+        * `:sequence_length` - the maximum input sequence length. Input
+          sequences are always padded/truncated to match that length
+
+      It is advised to set this option in production and also configure
+      a defn compiler using `:defn_options` to maximally reduce inference
+      time
+
+    * `:defn_options` - the options for JIT compilation. Defaults to `[]`
+
+    * `:preallocate_params` - when `true`, explicitly allocates params
+      on the device configured in `:defn_options`. You may want to set
+      this option when using partitioned models on the GPU. Defaults to `false`
+
+  ## Examples
+
+      {:ok, model_info} = Bumblebee.load_model({:hf, "Qwen/Qwen3-Reranker-0.6B"})
+      {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "Qwen/Qwen3-Reranker-0.6B"})
+
+      serving = Bumblebee.Text.text_reranking_qwen3(model_info, tokenizer)
+
+      query = "What is the capital of France?"
+      documents = [
+        "Paris is the capital of France.",
+        "Berlin is the capital of Germany."
+      ]
+
+      pairs = Enum.map(documents, &{query, &1})
+      Nx.Serving.run(serving, pairs)
+
+  """
+  @spec text_reranking_qwen3(
+          Bumblebee.model_info(),
+          Bumblebee.Tokenizer.t(),
+          keyword()
+        ) :: Nx.Serving.t()
+  defdelegate text_reranking_qwen3(model_info, tokenizer, opts \\ []),
+    to: Bumblebee.Text.TextRerankingQwen3
+
   @type fill_mask_input :: String.t()
   @type fill_mask_output :: %{predictions: list(fill_mask_prediction())}
   @type fill_mask_prediction :: %{score: number(), token: String.t()}
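
Per the types above, a list input yields `%{scores: [...]}` with one score map per query-document pair, so ranking candidates is a simple sort. A short follow-up sketch continuing the docstring example:

    # Continue from the docstring example: rank documents by relevance
    %{scores: scores} = Nx.Serving.run(serving, pairs)

    ranked =
      scores
      |> Enum.sort_by(& &1.score, :desc)
      |> Enum.map(&{&1.document, &1.score})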

lib/bumblebee/text/pre_trained_tokenizer.ex

Lines changed: 7 additions & 0 deletions

@@ -200,6 +200,13 @@ defmodule Bumblebee.Text.PreTrainedTokenizer do
       },
       default_template_options: [language_token: "eng_Latn"]
     },
+    qwen2: %{
+      special_tokens: %{
+        unk: "<|endoftext|>",
+        eos: "<|endoftext|>",
+        pad: "<|endoftext|>"
+      }
+    },
     roberta: %{
       special_tokens: %{
         bos: "<s>",
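
Together with the `"qwen3" => :qwen2` tokenizer-type mapping in lib/bumblebee.ex above, this makes `<|endoftext|>` the unk/eos/pad token whenever a Qwen tokenizer is loaded. A quick sketch (repository name illustrative):

    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "Qwen/Qwen3-0.6B"})

    # Batched inputs are padded with the <|endoftext|> pad token configured above
    inputs = Bumblebee.apply_tokenizer(tokenizer, ["Hello world", "Hi"])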
