@@ -22,12 +22,11 @@ type CrossEncoderPipeline struct {
2222}
2323
2424type CrossEncoderStats struct {
25- TotalQueries uint64
26- TotalDocuments uint64
27- AverageLatency time.Duration
28- AverageBatchSize float64
29- TruncatedSequences uint64
30- FilteredResults uint64
25+ TotalQueries uint64
26+ TotalDocuments uint64
27+ AverageLatency time.Duration
28+ AverageBatchSize float64
29+ FilteredResults uint64
3130}
3231
3332type CrossEncoderResult struct {
@@ -122,7 +121,6 @@ func (p *CrossEncoderPipeline) GetStats() []string {
122121 fmt .Sprintf ("Total documents scored: %d" , p .stats .TotalDocuments ),
123122 fmt .Sprintf ("Average latency per query: %s" , avgLatency ),
124123 fmt .Sprintf ("Average batch size: %.2f" , p .stats .AverageBatchSize ),
125- fmt .Sprintf ("Truncated sequences: %d" , p .stats .TruncatedSequences ),
126124 fmt .Sprintf ("Filtered results: %d" , p .stats .FilteredResults ),
127125 fmt .Sprintf ("Tokenizer: Total time=%s, Execution count=%d, Average query time=%s" ,
128126 time .Duration (p .Model .Tokenizer .TokenizerTimings .TotalNS ),
@@ -142,6 +140,14 @@ func (p *CrossEncoderPipeline) Validate() error {
142140 validationErrors = append (validationErrors , fmt .Errorf ("cross encoder pipeline requires a tokenizer" ))
143141 }
144142
143+ if p .Model .SeparatorToken == "" {
144+ validationErrors = append (validationErrors , fmt .Errorf ("cross encoder pipeline requires a separator token to be set in the model" ))
145+ }
146+
147+ if p .Model .SeparatorToken != "[SEP]" && p .Model .SeparatorToken != "</s>" {
148+ validationErrors = append (validationErrors , fmt .Errorf ("cross encoder pipeline only supports [SEP] (BERT) and </s> (Roberta) as separator tokens, got %s" , p .Model .SeparatorToken ))
149+ }
150+
145151 outDims := p .Model .OutputsMeta [0 ].Dimensions
146152 if len (outDims ) != 2 {
147153 validationErrors = append (validationErrors , fmt .Errorf ("pipeline configuration invalid: cross encoder must have 2 dimensional output" ))
@@ -167,16 +173,47 @@ func (p *CrossEncoderPipeline) Validate() error {
167173 return errors .Join (validationErrors ... )
168174}
169175
176+ func patchBertSequenceTokenTypeIDs (batch * pipelineBackends.PipelineBatch , sepToken string ) {
177+ // Fix token_type_ids for BERT-style models when we manually concatenated the pair as a single sequence.
178+ // Pattern expected: [CLS] query [SEP] doc [SEP]
179+ // HF sets token_type_ids=0 up to and including first [SEP], then 1 for remainder (including final [SEP]).
180+ for index := range batch .Input {
181+ input := & batch .Input [index ]
182+ // Only adjust if type ids exist and are all zero
183+ allZero := true
184+ for _ , t := range input .TypeIDs {
185+ if t != 0 {
186+ allZero = false
187+ break
188+ }
189+ }
190+ if ! allZero || len (input .TypeIDs ) == 0 {
191+ continue
192+ }
193+ // Find first [SEP] token index (skip position 0 which should be [CLS])
194+ firstSep := - 1
195+ for iTok := 1 ; iTok < len (input .Tokens ); iTok ++ {
196+ if input .Tokens [iTok ] == sepToken {
197+ firstSep = iTok
198+ break
199+ }
200+ }
201+ if firstSep == - 1 || firstSep == len (input .Tokens )- 1 { // nothing to split
202+ continue
203+ }
204+ for iTok := firstSep + 1 ; iTok < len (input .TypeIDs ); iTok ++ {
205+ input .TypeIDs [iTok ] = 1
206+ }
207+ }
208+ }
209+
170210func (p * CrossEncoderPipeline ) Preprocess (batch * pipelineBackends.PipelineBatch , inputs []string ) error {
171211 start := time .Now ()
172212
173213 pipelineBackends .TokenizeInputs (batch , p .Model .Tokenizer , inputs )
174214
175- // Track truncated sequences (tokenizer already handles truncation)
176- for _ , tokenizedInput := range batch .Input {
177- if len (tokenizedInput .TokenIDs ) >= p .Model .Tokenizer .MaxAllowedTokens {
178- atomic .AddUint64 (& p .stats .TruncatedSequences , 1 )
179- }
215+ if p .Model != nil && p .Model .Tokenizer != nil && p .Model .SeparatorToken == "[SEP]" {
216+ patchBertSequenceTokenTypeIDs (batch , p .Model .SeparatorToken )
180217 }
181218
182219 atomic .AddUint64 (& p .Model .Tokenizer .TokenizerTimings .NumCalls , 1 )
@@ -289,8 +326,16 @@ func (p *CrossEncoderPipeline) runBatch(query string, documents []string, startI
289326 var runErrors []error
290327
291328 inputs := make ([]string , len (documents ))
329+ sep := p .Model .SeparatorToken
330+
292331 for i , doc := range documents {
293- inputs [i ] = fmt .Sprintf ("[CLS] %s [SEP] %s [SEP]" , query , doc )
332+ if sep == "</s>" {
333+ // RoBERTa style: query </s> </s> document
334+ inputs [i ] = fmt .Sprintf ("%s%s%s%s" , query , sep , sep , doc )
335+ } else {
336+ // BERT style: query [SEP] document [SEP]
337+ inputs [i ] = fmt .Sprintf ("%s%s%s" , query , sep , doc )
338+ }
294339 }
295340
296341 batch := pipelineBackends .NewBatch (len (inputs ))
@@ -300,6 +345,7 @@ func (p *CrossEncoderPipeline) runBatch(query string, documents []string, startI
300345 }(batch )
301346
302347 runErrors = append (runErrors , p .Preprocess (batch , inputs ))
348+
303349 if e := errors .Join (runErrors ... ); e != nil {
304350 return nil , e
305351 }
0 commit comments