Skip to content

Commit 281abfc

Browse files
Add cross_encoder serving and fix text_classification token_type_ids (#444)
1 parent bbd4d83 commit 281abfc

File tree

5 files changed

+225
-2
lines changed

5 files changed

+225
-2
lines changed

lib/bumblebee/text.ex

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,68 @@ defmodule Bumblebee.Text do
361361
defdelegate text_classification(model_info, tokenizer, opts \\ []),
362362
to: Bumblebee.Text.TextClassification
363363

364+
@type cross_encoding_input :: {String.t(), String.t()}
@type cross_encoding_output :: %{score: number()}

@doc """
Builds serving for cross-encoder models.

Cross-encoders score text pairs by encoding them jointly through a
transformer with full cross-attention. This is commonly used for
reranking search results, semantic similarity, and natural language
inference tasks.

The serving accepts `t:cross_encoding_input/0` and returns
`t:cross_encoding_output/0`. A list of inputs is also supported.

## Options

  * `:compile` - compiles all computations for predefined input shapes
    during serving initialization. Should be a keyword list with the
    following keys:

    * `:batch_size` - the maximum batch size of the input. Inputs
      are optionally padded to always match this batch size

    * `:sequence_length` - the maximum input sequence length. Input
      sequences are always padded/truncated to match that length.
      A list can be given, in which case the serving compiles
      a separate computation for each length and then inputs are
      matched to the smallest bounding length

    It is advised to set this option in production and also configure
    a defn compiler using `:defn_options` to maximally reduce inference
    time.

  * `:defn_options` - the options for JIT compilation. Defaults to `[]`

  * `:preallocate_params` - when `true`, explicitly allocates params
    on the device configured by `:defn_options`. You may want to set
    this option when using partitioned serving, to allocate params
    on each of the devices. When using this option, you should first
    load the parameters into the host. This can be done by passing
    `backend: {EXLA.Backend, client: :host}` to `load_model/1` and friends.
    Defaults to `false`

## Examples

    {:ok, model_info} = Bumblebee.load_model({:hf, "cross-encoder/ms-marco-MiniLM-L-6-v2"})
    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "cross-encoder/ms-marco-MiniLM-L-6-v2"})

    serving = Bumblebee.Text.cross_encoding(model_info, tokenizer)

    Nx.Serving.run(serving, {"How many people live in Berlin?", "Berlin has a population of 3.5 million."})
    #=> %{score: 8.761}

"""
@spec cross_encoding(
        Bumblebee.model_info(),
        Bumblebee.Tokenizer.t(),
        keyword()
      ) :: Nx.Serving.t()
defdelegate cross_encoding(model_info, tokenizer, opts \\ []),
  to: Bumblebee.Text.CrossEncoding
364426
@type text_embedding_input :: String.t()
365427
@type text_embedding_output :: %{embedding: Nx.Tensor.t()}
366428

lib/bumblebee/text/cross_encoding.ex (filename inferred from the module name; missing from this diff header)

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
defmodule Bumblebee.Text.CrossEncoding do
  # Implements the `Bumblebee.Text.cross_encoding/3` serving: scores
  # {text, text} pairs with a sequence-classification model by encoding
  # each pair jointly and reading the single regression/classification
  # logit as the pair's score.
  @moduledoc false

  alias Bumblebee.Shared

  def cross_encoding(model_info, tokenizer, opts \\ []) do
    %{model: model, params: params, spec: spec} = model_info
    # Cross-encoders are sequence-classification heads over a text pair.
    Shared.validate_architecture!(spec, :for_sequence_classification)

    opts =
      Keyword.validate!(opts, [
        :compile,
        defn_options: [],
        preallocate_params: false
      ])

    preallocate_params = opts[:preallocate_params]
    defn_options = opts[:defn_options]

    # When :compile is given, both keys are mandatory so the serving can
    # build fixed-shape computations up front.
    compile =
      if compile = opts[:compile] do
        compile
        |> Keyword.validate!([:batch_size, :sequence_length])
        |> Shared.require_options!([:batch_size, :sequence_length])
      end

    batch_size = compile[:batch_size]
    sequence_length = compile[:sequence_length]

    tokenizer =
      Bumblebee.configure(tokenizer, length: sequence_length)

    {_init_fun, predict_fun} = Axon.build(model)

    # The model outputs a single logit per pair; drop the trailing axis
    # so each input maps to a scalar score.
    scores_fun = fn params, input ->
      outputs = predict_fun.(params, input)
      Nx.squeeze(outputs.logits, axes: [-1])
    end

    batch_keys = Shared.sequence_batch_keys(sequence_length)

    Nx.Serving.new(
      fn batch_key, defn_options ->
        params = Shared.maybe_preallocate(params, preallocate_params, defn_options)

        scope = {:cross_encoding, batch_key}

        scores_fun =
          Shared.compile_or_jit(scores_fun, scope, defn_options, compile != nil, fn ->
            {:sequence_length, sequence_length} = batch_key

            # Templates for ahead-of-time compilation. token_type_ids is
            # required, since pair inputs distinguish the two segments.
            inputs = %{
              "input_ids" => Nx.template({batch_size, sequence_length}, :u32),
              "attention_mask" => Nx.template({batch_size, sequence_length}, :u32),
              "token_type_ids" => Nx.template({batch_size, sequence_length}, :u32)
            }

            [params, inputs]
          end)

        fn inputs ->
          inputs = Shared.maybe_pad(inputs, batch_size)
          scores_fun.(params, inputs) |> Shared.serving_post_computation()
        end
      end,
      defn_options
    )
    |> Nx.Serving.batch_size(batch_size)
    |> Nx.Serving.process_options(batch_keys: batch_keys)
    |> Nx.Serving.client_preprocessing(fn input ->
      {pairs, multi?} = Shared.validate_serving_input!(input, &validate_pair/1)

      # Tokenize on the host backend so preprocessing does not touch the
      # accelerator device.
      inputs =
        Nx.with_default_backend(Nx.BinaryBackend, fn ->
          Bumblebee.apply_tokenizer(tokenizer, pairs)
        end)

      batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length)
      batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key)

      {batch, multi?}
    end)
    |> Nx.Serving.client_postprocessing(fn {scores, _metadata}, multi? ->
      scores
      |> Nx.to_list()
      |> Enum.map(&%{score: &1})
      |> Shared.normalize_output(multi?)
    end)
  end

  # Accepts only a {binary, binary} pair; anything else is rejected with
  # a descriptive error used by validate_serving_input!/2.
  defp validate_pair({text1, text2}) when is_binary(text1) and is_binary(text2),
    do: {:ok, {text1, text2}}

  defp validate_pair(value),
    do: {:error, "expected a {string, string} pair, got: #{inspect(value)}"}
end

lib/bumblebee/text/text_classification.ex

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ defmodule Bumblebee.Text.TextClassification do
3232
sequence_length = compile[:sequence_length]
3333

3434
tokenizer =
35-
Bumblebee.configure(tokenizer, length: sequence_length, return_token_type_ids: false)
35+
Bumblebee.configure(tokenizer, length: sequence_length)
3636

3737
{_init_fun, predict_fun} = Axon.build(model)
3838

@@ -58,7 +58,8 @@ defmodule Bumblebee.Text.TextClassification do
5858

5959
inputs = %{
6060
"input_ids" => Nx.template({batch_size, sequence_length}, :u32),
61-
"attention_mask" => Nx.template({batch_size, sequence_length}, :u32)
61+
"attention_mask" => Nx.template({batch_size, sequence_length}, :u32),
62+
"token_type_ids" => Nx.template({batch_size, sequence_length}, :u32)
6263
}
6364

6465
[params, inputs]
test/bumblebee/text/cross_encoding_test.exs (filename inferred from the module name; missing from this diff header)

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
defmodule Bumblebee.Text.CrossEncodingTest do
  use ExUnit.Case, async: true

  import Bumblebee.TestHelpers

  @moduletag serving_test_tags()

  test "scores sentence pairs" do
    {:ok, model_info} = Bumblebee.load_model({:hf, "cross-encoder/ms-marco-MiniLM-L-6-v2"})
    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "cross-encoder/ms-marco-MiniLM-L-6-v2"})

    serving = Bumblebee.Text.cross_encoding(model_info, tokenizer)

    query = "How many people live in Berlin?"

    # Single pair
    assert %{score: score} =
             Nx.Serving.run(
               serving,
               {query, "Berlin has a population of 3,520,031 registered inhabitants."}
             )

    assert_in_delta score, 8.76, 0.01

    # Multiple pairs (batch)
    assert [%{score: relevant_score}, %{score: irrelevant_score}] =
             Nx.Serving.run(serving, [
               {query, "Berlin has a population of 3,520,031 registered inhabitants."},
               {query, "New York City is famous for its skyscrapers."}
             ])

    assert relevant_score > irrelevant_score
    assert_in_delta relevant_score, 8.76, 0.01
    assert_in_delta irrelevant_score, -11.24, 0.01
  end
end

test/bumblebee/text/text_classification_test.exs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,32 @@ defmodule Bumblebee.Text.TextClassificationTest do
2222
]
2323
} = Nx.Serving.run(serving, text)
2424
end
25+
26+
test "scores sentence pairs correctly for cross-encoder reranking" do
27+
{:ok, model_info} = Bumblebee.load_model({:hf, "cross-encoder/ms-marco-MiniLM-L-6-v2"})
28+
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "cross-encoder/ms-marco-MiniLM-L-6-v2"})
29+
30+
serving =
31+
Bumblebee.Text.TextClassification.text_classification(model_info, tokenizer,
32+
scores_function: :none
33+
)
34+
35+
query = "How many people live in Berlin?"
36+
37+
# Relevant document should score higher than irrelevant
38+
%{predictions: [%{score: relevant_score}]} =
39+
Nx.Serving.run(
40+
serving,
41+
{query, "Berlin has a population of 3,520,031 registered inhabitants."}
42+
)
43+
44+
%{predictions: [%{score: irrelevant_score}]} =
45+
Nx.Serving.run(serving, {query, "New York City is famous for its skyscrapers."})
46+
47+
assert relevant_score > irrelevant_score
48+
49+
# Verify scores match Python sentence-transformers reference values
50+
assert_in_delta relevant_score, 8.76, 0.01
51+
assert_in_delta irrelevant_score, -11.24, 0.01
52+
end
2553
end

0 commit comments

Comments
 (0)