Skip to content

Commit 581ae14

Browse files
fix: use correct tokenization and token type IDs for BERT-style cross encoder
1 parent 196fa06 commit 581ae14

File tree

5 files changed

+107
-34
lines changed

5 files changed

+107
-34
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,12 @@ All notable changes to this project will be documented in this file.
44
The format is based on [Keep a Changelog](http://keepachangelog.com/)
55
and this project adheres to [Semantic Versioning](http://semver.org/).
66

7+
## [0.5.4] - 2025-09-10
8+
9+
### Changed
10+
11+
- Fix: use correct tokenization and token type IDs for BERT-style sentence pairs in the cross encoder
12+
713
## [0.5.3] - 2025-09-01
814

915
### Changed

hugot_test.go

Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -820,29 +820,50 @@ func crossEncoderPipeline(t *testing.T, session *Session) {
820820
pipeline, err := NewPipeline(session, config)
821821
checkT(t, err)
822822

823-
query := "What is the capital of France?"
823+
query := "Organic skincare products for sensitive skin"
824824
documents := []string{
825-
"Paris is the capital of France.",
826-
"The Eiffel Tower is in Paris.",
827-
"France is a country in Europe.",
825+
"Eco-friendly kitchenware for modern homes",
826+
"Biodegradable cleaning supplies for eco-conscious consumers",
827+
"Organic cotton baby clothes for sensitive skin",
828+
"Natural organic skincare range for sensitive skin",
829+
"Tech gadgets for smart homes: 2024 edition",
830+
"Sustainable gardening tools and compost solutions",
831+
"Sensitive skin-friendly facial cleansers and toners",
832+
"Organic food wraps and storage solutions",
833+
"All-natural pet food for dogs with allergies",
834+
"Yoga mats made from recycled materials",
835+
}
836+
837+
type Expected struct {
838+
Document string
839+
Score float32
840+
}
841+
842+
expectedRoberta := []Expected{
843+
{Document: "Natural organic skincare range for sensitive skin", Score: 0.95478064},
844+
{Document: "Organic cotton baby clothes for sensitive skin", Score: 0.8185698},
845+
{Document: "Sensitive skin-friendly facial cleansers and toners", Score: 0.5848757},
846+
{Document: "Organic food wraps and storage solutions", Score: 0.2567817},
847+
{Document: "Biodegradable cleaning supplies for eco-conscious consumers", Score: 0.22029042},
848+
{Document: "Yoga mats made from recycled materials", Score: 0.20082192},
849+
{Document: "Sustainable gardening tools and compost solutions", Score: 0.19299757},
850+
{Document: "All-natural pet food for dogs with allergies", Score: 0.18836288},
851+
{Document: "Eco-friendly kitchenware for modern homes", Score: 0.18346606},
852+
{Document: "Tech gadgets for smart homes: 2024 edition", Score: 0.16224432},
828853
}
829854

830855
inputs := append([]string{query}, documents...)
831856
output, err := pipeline.Run(inputs)
832857
checkT(t, err)
833-
834858
results := output.(*pipelines.CrossEncoderOutput).Results
835-
if len(results) != 3 {
836-
t.Errorf("Expected 3 results, got %d", len(results))
837-
}
838-
if results[0].Document != "Paris is the capital of France." {
839-
t.Errorf("Expected 'Paris is the capital of France.' as best document, got '%s'", results[0].Document)
840-
}
841-
if results[0].Score <= results[1].Score {
842-
t.Errorf("Expected result 0 to have higher score than result 1, but got %f and %f", results[0].Score, results[1].Score)
843-
}
844-
if results[1].Score <= results[2].Score {
845-
t.Errorf("Expected result 1 to have higher score than result 2, but got %f and %f", results[1].Score, results[2].Score)
859+
860+
for i, expected := range expectedRoberta {
861+
if expected.Document != results[i].Document {
862+
t.Errorf("Expected document '%s', got '%s'", expected.Document, results[i].Document)
863+
}
864+
if math.Abs(float64(expected.Score-results[i].Score)) > 0.01 {
865+
t.Errorf("Expected score '%f', got '%f'", expected.Score, results[i].Score)
866+
}
846867
}
847868
}
848869

pipelines/crossEncoder.go

Lines changed: 59 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,11 @@ type CrossEncoderPipeline struct {
2222
}
2323

2424
type CrossEncoderStats struct {
25-
TotalQueries uint64
26-
TotalDocuments uint64
27-
AverageLatency time.Duration
28-
AverageBatchSize float64
29-
TruncatedSequences uint64
30-
FilteredResults uint64
25+
TotalQueries uint64
26+
TotalDocuments uint64
27+
AverageLatency time.Duration
28+
AverageBatchSize float64
29+
FilteredResults uint64
3130
}
3231

3332
type CrossEncoderResult struct {
@@ -122,7 +121,6 @@ func (p *CrossEncoderPipeline) GetStats() []string {
122121
fmt.Sprintf("Total documents scored: %d", p.stats.TotalDocuments),
123122
fmt.Sprintf("Average latency per query: %s", avgLatency),
124123
fmt.Sprintf("Average batch size: %.2f", p.stats.AverageBatchSize),
125-
fmt.Sprintf("Truncated sequences: %d", p.stats.TruncatedSequences),
126124
fmt.Sprintf("Filtered results: %d", p.stats.FilteredResults),
127125
fmt.Sprintf("Tokenizer: Total time=%s, Execution count=%d, Average query time=%s",
128126
time.Duration(p.Model.Tokenizer.TokenizerTimings.TotalNS),
@@ -142,6 +140,14 @@ func (p *CrossEncoderPipeline) Validate() error {
142140
validationErrors = append(validationErrors, fmt.Errorf("cross encoder pipeline requires a tokenizer"))
143141
}
144142

143+
if p.Model.SeparatorToken == "" {
144+
validationErrors = append(validationErrors, fmt.Errorf("cross encoder pipeline requires a separator token to be set in the model"))
145+
}
146+
147+
if p.Model.SeparatorToken != "[SEP]" && p.Model.SeparatorToken != "</s>" {
148+
validationErrors = append(validationErrors, fmt.Errorf("cross encoder pipeline only supports [SEP] (BERT) and </s> (Roberta) as separator tokens, got %s", p.Model.SeparatorToken))
149+
}
150+
145151
outDims := p.Model.OutputsMeta[0].Dimensions
146152
if len(outDims) != 2 {
147153
validationErrors = append(validationErrors, fmt.Errorf("pipeline configuration invalid: cross encoder must have 2 dimensional output"))
@@ -167,16 +173,47 @@ func (p *CrossEncoderPipeline) Validate() error {
167173
return errors.Join(validationErrors...)
168174
}
169175

176+
func patchBertSequenceTokenTypeIDs(batch *pipelineBackends.PipelineBatch, sepToken string) {
177+
// Fix token_type_ids for BERT-style models when we manually concatenated the pair as a single sequence.
178+
// Pattern expected: [CLS] query [SEP] doc [SEP]
179+
// HF sets token_type_ids=0 up to and including first [SEP], then 1 for remainder (including final [SEP]).
180+
for index := range batch.Input {
181+
input := &batch.Input[index]
182+
// Only adjust if type ids exist and are all zero
183+
allZero := true
184+
for _, t := range input.TypeIDs {
185+
if t != 0 {
186+
allZero = false
187+
break
188+
}
189+
}
190+
if !allZero || len(input.TypeIDs) == 0 {
191+
continue
192+
}
193+
// Find first [SEP] token index (skip position 0 which should be [CLS])
194+
firstSep := -1
195+
for iTok := 1; iTok < len(input.Tokens); iTok++ {
196+
if input.Tokens[iTok] == sepToken {
197+
firstSep = iTok
198+
break
199+
}
200+
}
201+
if firstSep == -1 || firstSep == len(input.Tokens)-1 { // nothing to split
202+
continue
203+
}
204+
for iTok := firstSep + 1; iTok < len(input.TypeIDs); iTok++ {
205+
input.TypeIDs[iTok] = 1
206+
}
207+
}
208+
}
209+
170210
func (p *CrossEncoderPipeline) Preprocess(batch *pipelineBackends.PipelineBatch, inputs []string) error {
171211
start := time.Now()
172212

173213
pipelineBackends.TokenizeInputs(batch, p.Model.Tokenizer, inputs)
174214

175-
// Track truncated sequences (tokenizer already handles truncation)
176-
for _, tokenizedInput := range batch.Input {
177-
if len(tokenizedInput.TokenIDs) >= p.Model.Tokenizer.MaxAllowedTokens {
178-
atomic.AddUint64(&p.stats.TruncatedSequences, 1)
179-
}
215+
if p.Model != nil && p.Model.Tokenizer != nil && p.Model.SeparatorToken == "[SEP]" {
216+
patchBertSequenceTokenTypeIDs(batch, p.Model.SeparatorToken)
180217
}
181218

182219
atomic.AddUint64(&p.Model.Tokenizer.TokenizerTimings.NumCalls, 1)
@@ -289,8 +326,16 @@ func (p *CrossEncoderPipeline) runBatch(query string, documents []string, startI
289326
var runErrors []error
290327

291328
inputs := make([]string, len(documents))
329+
sep := p.Model.SeparatorToken
330+
292331
for i, doc := range documents {
293-
inputs[i] = fmt.Sprintf("[CLS] %s [SEP] %s [SEP]", query, doc)
332+
if sep == "</s>" {
333+
// RoBERTa style: query </s> </s> document
334+
inputs[i] = fmt.Sprintf("%s%s%s%s", query, sep, sep, doc)
335+
} else {
336+
// BERT style: query [SEP] document [SEP]
337+
inputs[i] = fmt.Sprintf("%s%s%s", query, sep, doc)
338+
}
294339
}
295340

296341
batch := pipelineBackends.NewBatch(len(inputs))
@@ -300,6 +345,7 @@ func (p *CrossEncoderPipeline) runBatch(query string, documents []string, startI
300345
}(batch)
301346

302347
runErrors = append(runErrors, p.Preprocess(batch, inputs))
348+
303349
if e := errors.Join(runErrors...); e != nil {
304350
return nil, e
305351
}

scripts/run-unit-tests-container.sh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@ echo "XLA tests completed."
2222

2323
# echo "Running training tests..."
2424

25-
# gotestsum --format testname --junitfile=$folder/unit-training.xml --jsonfile=$folder/unit-training.json -- -coverprofile=$folder/cover-training.out -coverpkg ./... -tags=ORT,XLA,TRAINING -timeout 60m
25+
gotestsum --format testname --junitfile=$folder/unit-training.xml --jsonfile=$folder/unit-training.json -- -coverprofile=$folder/cover-training.out -coverpkg ./... -tags=ORT,XLA,TRAINING -timeout 60m
2626

27-
# echo "Training tests completed."
27+
echo "Training tests completed."
2828

2929
# echo "Running simplego tests..."
3030

@@ -36,7 +36,7 @@ echo "merging coverage files"
3636
head -n 1 $folder/cover-ort.out > $folder/cover.out
3737
tail -n +2 $folder/cover-ort.out >> $folder/cover.out
3838
tail -n +2 $folder/cover-xla.out >> $folder/cover.out
39-
# tail -n +2 $folder/cover-training.out >> $folder/cover.out
39+
tail -n +2 $folder/cover-training.out >> $folder/cover.out
4040
# tail -n +2 $folder/cover-go.out >> $folder/cover.out
4141

4242
head -n 1 $folder/cover.out > $folder/cover.dedup.out

testData/downloadModels.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,12 @@ func main() {
5757
options := hugot.NewDownloadOptions()
5858
options.OnnxFilePath = model.onnxFilePath
5959
options.ExternalDataPath = model.externalDataPath
60-
fmt.Println(fmt.Sprintf("Downloading %s", model.name))
60+
fmt.Printf("Downloading %s\n", model.name)
6161
outPath, dlErr := hugot.DownloadModel(model.name, "./models", options)
6262
if dlErr != nil {
6363
panic(dlErr)
6464
}
65-
fmt.Println(fmt.Sprintf("Downloaded %s to %s", model.name, outPath))
65+
fmt.Printf("Downloaded %s to %s\n", model.name, outPath)
6666
}
6767
} else {
6868
panic(err)

0 commit comments

Comments
 (0)