@@ -385,6 +385,9 @@ defmodule Bumblebee.Text do
385385 Note that we currently assume that the CLS token is the first token
386386 in the sequence
387387
388+ * `:last_token_pooling` - takes the embedding for the last non-padding
389+ token in each sequence
390+
388391 By default no pooling is applied
389392
390393 * `:embedding_processor` - a post-processing step to apply to the
@@ -444,6 +447,82 @@ defmodule Bumblebee.Text do
444447 defdelegate text_embedding ( model_info , tokenizer , opts \\ [ ] ) ,
445448 to: Bumblebee.Text.TextEmbedding
446449
450+ @ type text_reranking_qwen3_input :: { String . t ( ) , String . t ( ) } | [ { String . t ( ) , String . t ( ) } ]
451+ @ type text_reranking_qwen3_output :: % {
452+ scores: text_reranking_qwen3_score ( ) | list ( text_reranking_qwen3_score ( ) )
453+ }
454+ @ type text_reranking_qwen3_score :: % { score: number ( ) , query: String . t ( ) , document: String . t ( ) }
455+
456+ @ doc """
457+ Builds a serving for text reranking with Qwen3 reranker models.
458+
459+ The serving expects input in one of the following formats:
460+
461+ * `{query, document}` - a tuple with query and document text
462+ * `[{query1, doc1}, {query2, doc2}, ...]` - a list of query-document pairs
463+
464+ ## Options
465+
466+ * `:yes_token` - the token ID corresponding to "yes" for relevance scoring.
467+ If not provided, will be inferred from the tokenizer
468+
469+ * `:no_token` - the token ID corresponding to "no" for relevance scoring.
470+ If not provided, will be inferred from the tokenizer
471+
472+ * `:instruction_prefix` - the instruction prefix to use. Defaults to the
473+ Qwen3 reranker format
474+
475+ * `:instruction_suffix` - the instruction suffix to use. Defaults to the
476+ Qwen3 reranker format
477+
478+ * `:task_description` - the task description to include in prompts. Defaults
479+ to "Given a web search query, retrieve relevant passages that answer the query"
480+
481+ * `:compile` - compiles all computations for predefined input shapes
482+ during serving initialization. Should be a keyword list with the
483+ following keys:
484+
485+ * `:batch_size` - the maximum batch size of the input. Inputs
486+ are optionally padded to always match this batch size
487+
488+ * `:sequence_length` - the maximum input sequence length. Input
489+ sequences are always padded/truncated to match that length
490+
491+ It is advised to set this option in production and also configure
492+ a defn compiler using `:defn_options` to maximally reduce inference
493+ time
494+
495+ * `:defn_options` - the options for JIT compilation. Defaults to `[]`
496+
497+ * `:preallocate_params` - when `true`, explicitly allocates params
498+ on the device configured in `:defn_options`. You may want to set
499+ this option when using partitioned models on the GPU. Defaults to `false`
500+
501+ ## Examples
502+
503+ {:ok, model_info} = Bumblebee.load_model({:hf, "Qwen/Qwen3-Reranker-0.6B"})
504+ {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "Qwen/Qwen3-Reranker-0.6B"})
505+
506+ serving = Bumblebee.Text.text_reranking_qwen3(model_info, tokenizer)
507+
508+ query = "What is the capital of France?"
509+ documents = [
510+ "Paris is the capital of France.",
511+ "Berlin is the capital of Germany."
512+ ]
513+
514+ pairs = Enum.map(documents, &{query, &1})
515+ Nx.Serving.run(serving, pairs)
516+
517+ """
518+ @ spec text_reranking_qwen3 (
519+ Bumblebee . model_info ( ) ,
520+ Bumblebee.Tokenizer . t ( ) ,
521+ keyword ( )
522+ ) :: Nx.Serving . t ( )
523+ defdelegate text_reranking_qwen3 ( model_info , tokenizer , opts \\ [ ] ) ,
524+ to: Bumblebee.Text.TextRerankingQwen3
525+
447526 @ type fill_mask_input :: String . t ( )
448527 @ type fill_mask_output :: % { predictions: list ( fill_mask_prediction ( ) ) }
449528 @ type fill_mask_prediction :: % { score: number ( ) , token: String . t ( ) }
0 commit comments