Commit df3d5ae

Browse files
committed
fix: align several usages with the authors' example; fix indexing issues when the lengths of the two modalities don't match; normalize embeddings as recommended by the authors.
- Prepare inputs for generation with cache_position handling for Qwen2.5-Omni.

1 file changed (+15 −3):
mteb/models/model_implementations/e5_omni.py

```diff
@@ -104,9 +104,9 @@ def encode(
         max_len = max(len(batch_texts), len(batch_images))
         for i in range(max_len):
             content = []
-            if batch_texts:
+            if i < len(batch_texts):
                 content.append({"type": "text", "text": batch_texts[i]})
-            if batch_images:
+            if i < len(batch_images):
                 content.append({"type": "image", "image": batch_images[i]})
             messages.append([{"role": "user", "content": content}])
 
```
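This hunk is the indexing fix from the commit message: `if batch_texts:` only checks that the list is non-empty, so once `i` runs past the end of the shorter modality list, `batch_texts[i]` or `batch_images[i]` raises an `IndexError`. A minimal, standalone sketch of the corrected loop (the literal inputs are made up for illustration):

```python
# Standalone sketch of the bounds-checked message building; the sample
# inputs are hypothetical, not taken from the repository.
batch_texts = ["a photo of a cat", "a photo of a dog"]
batch_images = ["cat.png"]  # shorter list: only the first item has an image

messages = []
max_len = max(len(batch_texts), len(batch_images))
for i in range(max_len):
    content = []
    # The old guard `if batch_texts:` was True whenever the list was non-empty,
    # so batch_texts[i] / batch_images[i] could index past the shorter list.
    if i < len(batch_texts):
        content.append({"type": "text", "text": batch_texts[i]})
    if i < len(batch_images):
        content.append({"type": "image", "image": batch_images[i]})
    messages.append([{"role": "user", "content": content}])

print(messages[1])  # text-only message instead of an IndexError
```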
```diff
@@ -121,6 +121,7 @@ def encode(
 
         image_inputs = None
         video_inputs = None
+        audio_inputs = None
         if batch_images:
             from qwen_vl_utils import process_vision_info
 
```
```diff
@@ -130,12 +131,21 @@ def encode(
             text=texts,
             images=image_inputs,
             videos=video_inputs,
+            audio=audio_inputs,
             padding=True,
             return_tensors="pt",
             truncation=True,
             max_length=512,
         ).to(self.device)
 
+        # Prepare inputs for generation to handle cache_position and other requirements for Qwen2.5-Omni
+        cache_position = torch.arange(
+            0, model_inputs["input_ids"].shape[1], device=self.device
+        )
+        model_inputs = self.model.prepare_inputs_for_generation(
+            **model_inputs, use_cache=True, cache_position=cache_position
+        )
+
         outputs = self.model(**model_inputs, output_hidden_states=True)
 
         # For E5-Omni, we use the last hidden state of the last token
```
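The second addition builds a `cache_position` tensor covering every position in the padded batch and routes the processor output through `prepare_inputs_for_generation`, which is how the commit message's "cache_position handling for Qwen2.5-Omni" is implemented. A small sketch of the tensor being constructed (the token ids and device below are placeholders, not from the repository):

```python
# Sketch of the cache_position tensor: one index per token position in the
# padded batch, built with torch.arange. Ids and device are placeholders.
import torch

device = "cpu"  # stands in for self.device in the diff
input_ids = torch.tensor([[101, 7592, 2088, 102],
                          [101, 7592,  102,   0]])  # hypothetical padded batch

cache_position = torch.arange(0, input_ids.shape[1], device=device)
print(cache_position)  # tensor([0, 1, 2, 3])

# In the diff this tensor is then passed along with use_cache=True to
# self.model.prepare_inputs_for_generation(...), so the Qwen2.5-Omni forward
# pass receives the cache bookkeeping it expects.
```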
```diff
@@ -144,12 +154,14 @@ def encode(
 
         # Find the last non-padding token
         attention_mask = model_inputs["attention_mask"]
-        # Qwen2.5-Omni uses right padding by default in many setups
         sequence_lengths = attention_mask.sum(dim=1) - 1
         embeddings = last_hidden_state[
             torch.arange(last_hidden_state.size(0)), sequence_lengths
         ]
 
+        # Normalize embeddings as recommended by the authors
+        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1)
+
         all_embeddings.append(embeddings.cpu().to(torch.float32))
 
         return torch.cat(all_embeddings, dim=0).numpy()
```
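The final hunk keeps the last-token pooling (take the hidden state at each sequence's last non-padding position, which assumes right padding) and adds L2 normalization, so cosine similarity between embeddings reduces to a dot product. A self-contained sketch with made-up shapes:

```python
# Sketch of last-token pooling followed by L2 normalization; the tensor
# shapes and attention mask below are made up for illustration.
import torch

batch, seq_len, hidden = 2, 5, 8
last_hidden_state = torch.randn(batch, seq_len, hidden)
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [1, 1, 1, 0, 0]])  # second sequence is right-padded

sequence_lengths = attention_mask.sum(dim=1) - 1   # index of the last real token
embeddings = last_hidden_state[torch.arange(batch), sequence_lengths]
embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1)

print(embeddings.norm(dim=-1))  # each row now has unit L2 norm
```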
