fix-image-pooling (#9)
- fix: image pooling (725b8ba6ba8cff17579843ca46e5eb21f7d5ea37)
- chore: remove prints (660fe4c4d743be00c7bdcb17c740414c21c53374)
modeling_jina_embeddings_v4.py  +10 -12
The single-vector pooling path for document images previously pooled one sequence's image tokens and restored the batch dimension with .unsqueeze(0), so batched document inputs were not pooled per item. The fix makes the pooling batch-aware: it locates each row's vision_start_token_id / vision_end_token_id positions, builds a boolean mask over the image-token span, and mean-pools the masked hidden states. (The removed block below is reconstructed; only its trailing .unsqueeze(0) is legible in this view.)

@@ -216,22 +216,21 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         Project the hidden states to single-vector embeddings.
         """
         if self._input_has_image(input_ids[0]):  # got document image
-            img_start_pos = torch.where(
-                input_ids[0] == self.config.vision_start_token_id
-            )[0][0]
-            img_end_pos = torch.where(
-                input_ids[0] == self.config.vision_end_token_id
-            )[0][0]
-            pooled_output = (
-                hidden_states[0][img_start_pos : img_end_pos + 1]
-                .mean(dim=0)
-                .unsqueeze(0)
-            )
+            img_start_positions = torch.where(input_ids == self.config.vision_start_token_id)[1]
+            img_end_positions = torch.where(input_ids == self.config.vision_end_token_id)[1]
+
+            batch_size, seq_len = input_ids.shape
+            position_indices = torch.arange(seq_len, device=input_ids.device).expand(batch_size, -1)
+            image_mask = (position_indices >= img_start_positions.unsqueeze(1)) & (position_indices <= img_end_positions.unsqueeze(1))
+
+            masked_hidden_states = hidden_states * image_mask.unsqueeze(-1)
+            pooled_output = masked_hidden_states.sum(dim=1) / image_mask.sum(dim=1, keepdim=True)
 
         else:  # got query text
             pooled_output = torch.sum(
                 hidden_states * attention_mask.unsqueeze(-1), dim=1
             ) / torch.sum(attention_mask, dim=1, keepdim=True)
+
         single_vec_emb = self.single_vector_projector(pooled_output)
         return torch.nn.functional.normalize(single_vec_emb, dim=-1)
@@ -317,7 +316,6 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             embeddings = embeddings[:, :truncate_dim]
         else:
             embeddings = embeddings.multi_vec_emb
-
         results.append(
             embeddings.cpu()
             if return_numpy
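For reference, a minimal self-contained sketch of the new mask-and-mean pooling on dummy tensors. The marker ids, shapes, and one-image-per-sequence layout are illustrative assumptions (the real ids come from config.vision_start_token_id / config.vision_end_token_id); the final assertion checks that the vectorized form equals pooling each row's image span individually.

# Sketch of the batch-aware image pooling; all values are dummies.
import torch

VISION_START, VISION_END = 1001, 1002  # assumed marker ids, not the real config values

batch_size, seq_len, hidden = 2, 8, 4
torch.manual_seed(0)
hidden_states = torch.randn(batch_size, seq_len, hidden)

# One <vision_start> ... <vision_end> span per sequence, at different offsets.
input_ids = torch.full((batch_size, seq_len), 100)
input_ids[0, 1], input_ids[0, 5] = VISION_START, VISION_END
input_ids[1, 2], input_ids[1, 6] = VISION_START, VISION_END

# torch.where(cond)[1] yields the column of each marker, one per row,
# assuming exactly one image per sequence (as the model code does).
img_start_positions = torch.where(input_ids == VISION_START)[1]
img_end_positions = torch.where(input_ids == VISION_END)[1]

position_indices = torch.arange(seq_len).expand(batch_size, -1)
image_mask = (position_indices >= img_start_positions.unsqueeze(1)) & (
    position_indices <= img_end_positions.unsqueeze(1)
)

# Masked mean over each row's own image span; the batch dimension is preserved.
pooled = (hidden_states * image_mask.unsqueeze(-1)).sum(dim=1) / image_mask.sum(
    dim=1, keepdim=True
)

for i in range(batch_size):
    ref = hidden_states[i, img_start_positions[i] : img_end_positions[i] + 1].mean(dim=0)
    assert torch.allclose(pooled[i], ref, atol=1e-6)

The per-row mask is what keeps the pooling correct for batch_size > 1, where a single slice-based mean cannot follow each sequence's own image span.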
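The query-text branch left untouched by this commit is standard attention-masked mean pooling, followed by projection and L2 normalization. A sketch with made-up shapes and a stand-in Linear for single_vector_projector (the real projector's dimensions live in the model, so everything here is illustrative):

# Mean pooling over non-padding tokens, then project and L2-normalize.
import torch

batch_size, seq_len, hidden, out_dim = 2, 6, 4, 3
torch.manual_seed(0)
hidden_states = torch.randn(batch_size, seq_len, hidden)
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                               [1, 1, 1, 0, 0, 0]])  # right-padded batch

# Zero out padding positions, then divide by each row's real token count.
pooled_output = torch.sum(
    hidden_states * attention_mask.unsqueeze(-1), dim=1
) / torch.sum(attention_mask, dim=1, keepdim=True)

single_vector_projector = torch.nn.Linear(hidden, out_dim)  # stand-in layer
single_vec_emb = single_vector_projector(pooled_output)
single_vec_emb = torch.nn.functional.normalize(single_vec_emb, dim=-1)

print(single_vec_emb.norm(dim=-1))  # each row has unit norm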