feat: Add cross_encoder serving and fix text_classification token_type_ids #444
Previously, text_classification set return_token_type_ids: false, which broke sentence-pair inputs such as the query-document pairs used by cross-encoder rerankers. token_type_ids are now included, so rerankers produce correct scores. Closes elixir-nx#251
Pull request overview
This PR adds cross-encoder support and fixes token_type_ids handling for sentence-pair classification tasks. Cross-encoder models like cross-encoder/ms-marco-MiniLM-L-6-v2 require token_type_ids to distinguish query tokens from document tokens, which were previously disabled in text_classification.
Changes:
- Fixed text_classification to include token_type_ids for sentence-pair inputs
- Added new cross_encoder serving with a dedicated API for reranking use cases
- Added comprehensive test coverage for both changes
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| lib/bumblebee/text/text_classification.ex | Removed return_token_type_ids: false configuration and added token_type_ids to compile template |
| lib/bumblebee/text/cross_encoder.ex | New module implementing cross-encoder serving with pair validation and score extraction |
| lib/bumblebee/text.ex | Added cross_encoder function documentation, type specs, and public API delegation |
| test/bumblebee/text/text_classification_test.exs | Added test verifying correct scoring for cross-encoder sentence pairs |
| test/bumblebee/text/cross_encoder_test.exs | New test file with comprehensive coverage for single and batch pair scoring |
jonatanklosko
left a comment
Just one note about naming and looks good to me!
lib/bumblebee/text.ex
defdelegate cross_encoder(model_info, tokenizer, opts \\ []),
  to: Bumblebee.Text.CrossEncoder
Nit: let's rename to "cross encoding" to match all the other functions we already have (embedding, classification, etc.):
- defdelegate cross_encoder(model_info, tokenizer, opts \\ []),
-   to: Bumblebee.Text.CrossEncoder
+ defdelegate cross_encoding(model_info, tokenizer, opts \\ []),
+   to: Bumblebee.Text.CrossEncoding
force-pushed from f1cdc18 to 15cb657
I was experimenting with rerankers for georgeguimaraes/arcana and found that cross-encoder models like cross-encoder/ms-marco-MiniLM-L-6-v2 weren't producing correct scores.
The issue: text_classification was setting return_token_type_ids: false, which breaks sentence-pair inputs. Cross-encoders need token_type_ids to distinguish query tokens from document tokens. Without them, scores don't match Python's sentence-transformers.
Changes:
- Fixed text_classification to include token_type_ids (also added it to the compile template)
- Added cross_encoder serving with a cleaner API for the reranking use case
The token_type_ids fix also benefits other sentence-pair tasks like NLI and entailment. If you don't want that change in text_classification, let me know.
Closes #251
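For readers unfamiliar with token_type_ids, here is a minimal sketch of what they encode for a query-document pair. It is plain Elixir and does not use Bumblebee or a real tokenizer: the module name `PairEncodingSketch`, both function names, and the pre-split token lists are hypothetical. It assumes the BERT-style layout `[CLS] query [SEP] document [SEP]`, and it assumes that a model given no token_type_ids treats every position as segment 0.

```elixir
defmodule PairEncodingSketch do
  @moduledoc """
  Illustrative only: builds BERT-style inputs for a {query, document} pair.
  Token strings stand in for vocabulary ids; this is not Bumblebee's tokenizer.
  """

  # Assumed layout: [CLS] query [SEP] document [SEP]
  def encode_pair(query_tokens, document_tokens) do
    first_segment = ["[CLS]"] ++ query_tokens ++ ["[SEP]"]
    second_segment = document_tokens ++ ["[SEP]"]

    %{
      tokens: first_segment ++ second_segment,
      # 0 marks the query segment, 1 marks the document segment
      token_type_ids:
        List.duplicate(0, length(first_segment)) ++
          List.duplicate(1, length(second_segment))
    }
  end

  # Assumption: when token_type_ids are not passed, the model sees all zeros,
  # so the boundary between query and document is lost.
  def drop_token_type_ids(%{token_type_ids: ids} = encoded) do
    %{encoded | token_type_ids: List.duplicate(0, length(ids))}
  end
end

encoded = PairEncodingSketch.encode_pair(["how", "tall"], ["it", "is", "tall"])
without_types = PairEncodingSketch.drop_token_type_ids(encoded)

IO.inspect(encoded.token_type_ids, label: "with token_type_ids")
IO.inspect(without_types.token_type_ids, label: "without token_type_ids")
```

With the segment ids dropped, the tokens are unchanged but the model can no longer tell which ones belong to the query, which is consistent with the score mismatch described above.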