
Conversation

@AymenKallala

Summary

This PR introduces MultiVectorEncoder, a new model class for ColBERT-style multi-vector encoding in sentence-transformers. Unlike the standard SentenceTransformer, which produces a single embedding per text, MultiVectorEncoder produces multiple embeddings (one per token) and computes similarity via MaxSim (maximum similarity) between token embeddings.
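
For reference, the score follows the standard ColBERT late-interaction formulation: for each query token, take the maximum dot product against all document token embeddings, then sum those maxima over the query tokens. A minimal sketch of that computation for a single query/document pair (illustrative only; the function name maxsim_score is not part of the PR's API, whose actual implementation lives in similarity.py):

import torch

def maxsim_score(query_emb: torch.Tensor, doc_emb: torch.Tensor) -> torch.Tensor:
    """MaxSim score between one query and one document.

    query_emb: [num_query_tokens, dim], doc_emb: [num_doc_tokens, dim],
    with each token vector L2-normalized so dot products equal cosines.
    """
    # [num_query_tokens, num_doc_tokens] token-to-token similarity matrix
    token_sims = query_emb @ doc_emb.T
    # best-matching document token for each query token, summed over the query
    return token_sims.max(dim=1).values.sum()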

Key Features

  • MultiVectorEncoder class: Extends SentenceTransformer with multi-vector encoding capabilities
  • LateInteractionPooling module: A pooling layer that preserves token-level embeddings (see the sketch after this list), with optional:
    • Dimension projection (e.g., 768 → 128)
    • Special token masking ([CLS], [SEP])
    • L2 normalization per token
  • MaxSim similarity functions: Implementations of late interaction similarity computation
  • Query/Document encoding: Dedicated encode_query() and encode_document() methods with automatic prompt handling
  • Ranking: Built-in rank() method for document ranking
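
To make the pooling behaviour concrete, here is a rough sketch of what a LateInteractionPooling-style forward pass does, following the usual sentence-transformers module convention of passing a features dict with a "token_embeddings" entry. The class name and internals below are illustrative assumptions, not the PR's implementation:

from torch import nn

class TokenLevelPoolingSketch(nn.Module):
    """Illustrative sketch: keep one vector per token, optionally project
    to a smaller dimension and L2-normalize each token embedding."""

    def __init__(self, word_embedding_dimension: int, output_dimension: int = 128):
        super().__init__()
        # optional linear projection, e.g. 768 -> 128
        self.project = nn.Linear(word_embedding_dimension, output_dimension, bias=False)

    def forward(self, features: dict) -> dict:
        token_embeddings = features["token_embeddings"]     # [batch, seq_len, dim]
        token_embeddings = self.project(token_embeddings)   # [batch, seq_len, 128]
        # L2-normalize per token so MaxSim dot products are cosine similarities
        # ([CLS]/[SEP] masking from the real module is omitted here for brevity)
        features["token_embeddings"] = nn.functional.normalize(token_embeddings, p=2, dim=-1)
        return features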

Changes

File | Lines | Description
sentence_transformers/multi_vec_encoder/MultiVectorEncoder.py | 534 | Main encoder class extending SentenceTransformer
sentence_transformers/multi_vec_encoder/LateInteractionPooling.py | 197 | Token-preserving pooling layer with projection
sentence_transformers/multi_vec_encoder/similarity.py | 111 | MaxSim similarity functions
sentence_transformers/multi_vec_encoder/__init__.py | 12 | Package exports
sentence_transformers/__init__.py | +2 | Added MultiVectorEncoder export
tests/multi_vec_encoder/test_multi_vec_encoder.py | 230 | Comprehensive pytest tests

Usage Example

Option 1: Create from a pre-trained transformer

from sentence_transformers import MultiVectorEncoder

# Automatically creates Transformer + LateInteractionPooling pipeline
model = MultiVectorEncoder("bert-base-uncased")

# Encode queries and documents
query_embeddings = model.encode_query(["What is machine learning?"])
doc_embeddings = model.encode_document([
    "Machine learning is a subset of artificial intelligence.",
    "The weather is nice today.",
])

# Each embedding is a 2D tensor: [num_tokens, dim]
print(f"Query shape: {query_embeddings[0].shape}")  # [7, 128]
print(f"Doc shape: {doc_embeddings[0].shape}")      #[11, 128]

# Compute similarity scores using MaxSim
scores = model.similarity(query_embeddings, doc_embeddings)
print(f"Scores: {scores}")  # Shape: [1, 2]

Option 2: Create from custom modules

from sentence_transformers import MultiVectorEncoder
from sentence_transformers.multi_vec_encoder import LateInteractionPooling
from sentence_transformers.models import Transformer

# Create custom pipeline with specific configuration
transformer = Transformer("bert-base-uncased")
pooling = LateInteractionPooling(
    word_embedding_dimension=transformer.get_word_embedding_dimension(),
    output_dimension=128,      # Project to 128 dimensions
    normalize=True,            # L2-normalize each token
    skip_cls_token=False,      # Keep [CLS] token
    skip_sep_token=False,      # Keep [SEP] token
)

model = MultiVectorEncoder(modules=[transformer, pooling])

Document Ranking

from sentence_transformers import MultiVectorEncoder

model = MultiVectorEncoder("bert-base-uncased")

documents = [
    "Machine learning is a subset of artificial intelligence.",
    "The weather is nice today.",
    "Deep learning uses neural networks with many layers.",
]

# Rank documents by relevance to query
results = model.rank(
    query="What is machine learning?",
    documents=documents,
    top_k=2,
    return_documents=True,
)

for result in results:
    print(f"Score: {result['score']:.2f} - {result['text']}")

Similarity Scores

# Compute similarity for all pairs
queries = ["What is AI?", "How's the weather?"]
documents = ["AI is artificial intelligence.", "It's sunny outside."]

q_emb = model.encode_query(queries)
d_emb = model.encode_document(documents)

# score[i,j] = similarity(query[i], doc[j])
similarity_scores = model.similarity(q_emb, d_emb)
print(f"Similarity scores: {similarity_scores}")  # Shape: [2,2]

Pairwise Similarity

# Compute similarity for corresponding pairs only
queries = ["What is AI?", "How's the weather?"]
documents = ["AI is artificial intelligence.", "It's sunny outside."]

q_emb = model.encode_query(queries)
d_emb = model.encode_document(documents)

# Pairwise: score[i] = similarity(query[i], doc[i])
pairwise_scores = model.similarity_pairwise(q_emb, d_emb)
print(f"Pairwise scores: {pairwise_scores}")  # Shape: [2]

Future Work

  • Pre-trained model integration: Load existing ColBERT checkpoints (e.g., colbert-ir/colbertv2.0, Stanford ColBERT weights) directly via MultiVectorEncoder
  • Model card: Add MultiVectorEncoderModelCardData for proper model documentation
  • Training support: Add training-related losses and evaluators

Related

@AymenKallala AymenKallala marked this pull request as ready for review January 25, 2026 12:50
@tomaarsen
Member

tomaarsen commented Jan 30, 2026

Hello!

This is quite solid, quite reminiscent of PyLate. I'm quite interested in this architecture in Sentence Transformers, although I had planned it for after the #3554 refactor. That refactor introduces new Base... classes (Model, Trainer, DataCollator, etc.) and would simplify new architectures like multi-vector models. You may have already noticed that although subclassing SentenceTransformer is convenient, you also inherit some features that multi-vector models don't actually use (e.g. truncate_dim). The refactor changes that.

I think this is a very strong start though, and I'd be glad to work on top of this after #3554. For context, my current TODO is:

I think sticking to that order is best for the project, so I'll likely get back to this after v5.4 is merged. What do you think?

  • Tom Aarsen

@AymenKallala
Author

Sure! Thanks for giving it a first review. I am happy to keep working on it when it becomes more of a priority.
