from typing import Any, Dict, List, Literal, Optional, Union

import torch
from PIL import Image
from torch import nn
from transformers import AutoConfig, AutoModel, AutoProcessor


class Transformer(nn.Module):
    """Sentence Transformers module wrapping the multimodal jina-embeddings-v4 model.

    Inputs may be strings or ``PIL.Image.Image`` objects; each modality is preprocessed with
    the model's processor and encoded into one dense vector per input, returned in the
    original batch order.
    """

    save_in_root: bool = True

    def __init__(
        self,
        model_name_or_path: str = 'jinaai/jina-embeddings-v4',
        max_seq_length: Optional[int] = None,
        config_args: Optional[Dict[str, Any]] = None,
        model_args: Optional[Dict[str, Any]] = None,
        tokenizer_args: Optional[Dict[str, Any]] = None,
        cache_dir: Optional[str] = None,
        backend: Literal['torch', 'onnx', 'openvino'] = 'torch',
        **kwargs,
    ) -> None:
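        """Load the config, model weights, and processor for `model_name_or_path`.

        A `default_task` entry in `model_args` is popped out and validated against
        `config.task_names`; only the 'torch' backend is supported.
        """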
        super().__init__()
        if backend != 'torch':
            raise ValueError(
                f"Backend '{backend}' is not supported, please use 'torch' instead"
            )

        config_kwargs = config_args or {}
        model_kwargs = model_args or {}
        tokenizer_kwargs = tokenizer_args or {}

        self.config = AutoConfig.from_pretrained(
            model_name_or_path, cache_dir=cache_dir, **config_kwargs
        )
        # Pop `default_task` from the model kwargs so it is not forwarded to
        # `AutoModel.from_pretrained`, then validate it against the supported task names.
        self.default_task = model_kwargs.pop('default_task', None)
        if self.default_task and self.default_task not in self.config.task_names:
            raise ValueError(
                f"Invalid task: {self.default_task}. Must be one of {self.config.task_names}."
            )

        self.model = AutoModel.from_pretrained(
            model_name_or_path, config=self.config, cache_dir=cache_dir, **model_kwargs
        )
        self.processor = AutoProcessor.from_pretrained(
            model_name_or_path,
            cache_dir=cache_dir,
            **tokenizer_kwargs,
        )
        self.max_seq_length = max_seq_length or 8192

    def tokenize(
        self, texts: List[Union[str, Image.Image]], padding: Union[str, bool] = True
    ) -> Dict[str, torch.Tensor]:
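        """Split the batch into text and image inputs and preprocess each modality.

        Returned keys are prefixed with 'text_' or 'image_'; the original positions are kept
        under 'text_indices' and 'image_indices' so that `forward` can restore the input order.
        """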
        encoding = {}
        text_indices = []
        image_indices = []

        # Partition the batch by modality, remembering each item's original position.
        for i, text in enumerate(texts):
            if isinstance(text, str):
                text_indices.append(i)
            elif isinstance(text, Image.Image):
                image_indices.append(i)
            else:
                raise ValueError(f'Invalid input type: {type(text)}')

        if text_indices:
            _texts = [texts[i] for i in text_indices]
            text_features = self.processor.process_texts(_texts, max_length=self.max_seq_length)
            for key, value in text_features.items():
                encoding[f'text_{key}'] = value
            encoding['text_indices'] = text_indices

        if image_indices:
            _images = [texts[i] for i in image_indices]
            img_features = self.processor.process_images(_images)
            for key, value in img_features.items():
                encoding[f'image_{key}'] = value
            encoding['image_indices'] = image_indices

        return encoding

    def forward(self, features: Dict[str, torch.Tensor], task: Optional[str] = None) -> Dict[str, torch.Tensor]:
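        """Encode the preprocessed text and image batches for the given task.

        Embeddings are restored to the original input order and stored under
        ``features['sentence_embedding']``.
        """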
        self.model.eval()

        if task is None:
            if self.default_task is None:
                raise ValueError(
                    "Task must be specified before encoding data. You can set it either when "
                    "loading the model (e.g., model_kwargs={'default_task': 'retrieval'}) or "
                    "pass it as an argument to the encode method (e.g., model.encode(texts, task='retrieval'))."
                )
            task = self.default_task
        elif task not in self.config.task_names:
            raise ValueError(f"Invalid task: {task}. Must be one of {self.config.task_names}.")

        device = self.model.device.type
        all_embeddings = []

        with torch.no_grad():
            if any(k.startswith('text_') for k in features):
                # Strip the 'text_' prefix added in `tokenize` and move the batch to the model's device.
                text_batch = {
                    k[len('text_'):]: v.to(device)
                    for k, v in features.items()
                    if k.startswith('text_') and k != 'text_indices'
                }
                text_indices = features.get('text_indices', [])

                with torch.autocast(device_type=device):
                    text_embeddings = self.model(**text_batch, task_label=task).single_vec_emb
                    if self.config.truncate_dim:
                        text_embeddings = text_embeddings[:, :self.config.truncate_dim]

                for i, embedding in enumerate(text_embeddings):
                    all_embeddings.append((text_indices[i], embedding))

            if any(k.startswith('image_') for k in features):
                image_batch = {
                    k[len('image_'):]: v.to(device)
                    for k, v in features.items()
                    if k.startswith('image_') and k != 'image_indices'
                }
                image_indices = features.get('image_indices', [])

                with torch.autocast(device_type=device):
                    img_embeddings = self.model(**image_batch, task_label=task).single_vec_emb
                    if self.config.truncate_dim:
                        img_embeddings = img_embeddings[:, :self.config.truncate_dim]

                for i, embedding in enumerate(img_embeddings):
                    all_embeddings.append((image_indices[i], embedding))

        if not all_embeddings:
            raise RuntimeError('No embeddings were generated')

        # Restore the original input order before stacking text and image embeddings together.
        all_embeddings.sort(key=lambda x: x[0])
        combined_embeddings = torch.stack([emb for _, emb in all_embeddings])
        features['sentence_embedding'] = combined_embeddings

        return features
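

# Minimal usage sketch (illustrative, not part of the module): load the model through
# Sentence Transformers with this custom module and encode a few texts. The 'retrieval'
# task and the model_kwargs={'default_task': ...} pattern follow the error message in
# `forward`; trust_remote_code=True is assumed to be needed for the model's custom code.
if __name__ == '__main__':
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer(
        'jinaai/jina-embeddings-v4',
        trust_remote_code=True,
        model_kwargs={'default_task': 'retrieval'},
    )
    embeddings = model.encode(['A photo of a cat', 'A street scene at night'])
    print(embeddings.shape)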