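"""Sentence Transformers integration module for jinaai/jina-embeddings-v4.

Defines a custom Transformer module that loads the Hugging Face config, model,
and processor, and turns mixed text/image batches into single-vector embeddings.
"""
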
from typing import Any, Dict, List, Literal, Optional, Union
import torch
from PIL import Image
from torch import nn
from transformers import AutoConfig, AutoProcessor, AutoModel


class Transformer(nn.Module):
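    """Custom Sentence Transformers module wrapping jinaai/jina-embeddings-v4.

    The module accepts a mixed batch of strings and PIL images, pre-processes
    them with the model's processor, and exposes single-vector embeddings under
    the 'sentence_embedding' key. Only the 'torch' backend is supported.
    """
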
    save_in_root: bool = True

    def __init__(
        self,
        model_name_or_path: str = 'jinaai/jina-embeddings-v4',
        max_seq_length: Optional[int] = None,
        config_args: Optional[Dict[str, Any]] = None,
        model_args: Optional[Dict[str, Any]] = None,
        tokenizer_args: Optional[Dict[str, Any]] = None,
        cache_dir: Optional[str] = None,
        backend: Literal['torch', 'onnx', 'openvino'] = 'torch',
        **kwargs,
    ) -> None:
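        """Load config, model, and processor from Hugging Face.

        Args:
            model_name_or_path: Model id or local path.
            max_seq_length: Maximum text sequence length (defaults to 8192).
            config_args: Extra kwargs for AutoConfig.from_pretrained.
            model_args: Extra kwargs for AutoModel.from_pretrained; may also
                include a 'default_task' entry used when forward() receives
                no explicit task.
            tokenizer_args: Extra kwargs for AutoProcessor.from_pretrained.
            cache_dir: Hugging Face cache directory.
            backend: Only 'torch' is supported.
        """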
        super(Transformer, self).__init__()
        if backend != 'torch':
            raise ValueError(
                f'Backend \'{backend}\' is not supported, please use \'torch\' instead'
            )
        config_kwargs = config_args or {}
        model_kwargs = model_args or {}
        tokenizer_kwargs = tokenizer_args or {}

        self.config = AutoConfig.from_pretrained(
            model_name_or_path, cache_dir=cache_dir, **config_kwargs
        )
        # Pop 'default_task' before the remaining kwargs are forwarded to
        # AutoModel.from_pretrained; popping from model_kwargs also works when
        # model_args is None.
        self.default_task = model_kwargs.pop('default_task', None)
        if self.default_task and self.default_task not in self.config.task_names:
            raise ValueError(
                f"Invalid task: {self.default_task}. Must be one of {self.config.task_names}."
            )
        self.model = AutoModel.from_pretrained(
            model_name_or_path, config=self.config, cache_dir=cache_dir, **model_kwargs
        )
        self.processor = AutoProcessor.from_pretrained(
            model_name_or_path,
            cache_dir=cache_dir,
            **tokenizer_kwargs,
        )
        self.max_seq_length = max_seq_length or 8192

    def tokenize(
        self, texts: List[Union[str, Image.Image]], padding: Union[str, bool] = True
    ) -> Dict[str, torch.Tensor]:
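        """Pre-process a mixed batch of strings and PIL images.

        Text and image inputs are encoded separately by the processor; their
        feature keys are prefixed with 'text_' / 'image_', and the original
        batch positions are stored so forward() can restore the input order.
        """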
        encoding = {}
        text_indices = []
        image_indices = []
        # Split the batch into text and image inputs, remembering positions.
        for i, text in enumerate(texts):
            if isinstance(text, str):
                text_indices.append(i)
            elif isinstance(text, Image.Image):
                image_indices.append(i)
            else:
                raise ValueError(f'Invalid input type: {type(text)}')

        if text_indices:
            _texts = [texts[i] for i in text_indices]
            text_features = self.processor.process_texts(_texts, max_length=self.max_seq_length)
            for key, value in text_features.items():
                encoding[f'text_{key}'] = value
            encoding['text_indices'] = text_indices

        if image_indices:
            _images = [texts[i] for i in image_indices]
            img_features = self.processor.process_images(_images)
            for key, value in img_features.items():
                encoding[f'image_{key}'] = value
            encoding['image_indices'] = image_indices

        return encoding

    def forward(self, features: Dict[str, torch.Tensor], task: Optional[str] = None) -> Dict[str, torch.Tensor]:
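        """Encode pre-processed features into sentence embeddings.

        A task name must be available, either passed explicitly or set via
        'default_task' at load time. Embeddings are optionally truncated to
        config.truncate_dim and returned under features['sentence_embedding']
        in the original input order.
        """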
        self.model.eval()
        if task is None:
            if self.default_task is None:
                raise ValueError(
                    "Task must be specified before encoding data. You can set it either during "
                    "loading the model (e.g., model_kwargs={'default_task': 'retrieval'}) or "
                    "pass it as an argument to the encode method (e.g., model.encode(texts, task='retrieval'))."
                )
            task = self.default_task
        else:
            if task not in self.config.task_names:
                raise ValueError(f"Invalid task: {task}. Must be one of {self.config.task_names}.")

        device = self.model.device.type
        all_embeddings = []

        with torch.no_grad():
            if any(k.startswith('text_') for k in features.keys()):
                text_batch = {
                    k[len('text_'):]: v.to(device)
                    for k, v in features.items()
                    if k.startswith('text_') and k != 'text_indices'
                }
                text_indices = features.get('text_indices', [])
                with torch.autocast(device_type=device):
                    text_embeddings = self.model(**text_batch, task_label=task).single_vec_emb
                    if self.config.truncate_dim:
                        text_embeddings = text_embeddings[:, :self.config.truncate_dim]
                    for i, embedding in enumerate(text_embeddings):
                        all_embeddings.append((text_indices[i], embedding))

            if any(k.startswith('image_') for k in features.keys()):
                image_batch = {
                    k[len('image_'):]: v.to(device)
                    for k, v in features.items()
                    if k.startswith('image_') and k != 'image_indices'
                }
                image_indices = features.get('image_indices', [])
                with torch.autocast(device_type=device):
                    img_embeddings = self.model(**image_batch, task_label=task).single_vec_emb
                    if self.config.truncate_dim:
                        img_embeddings = img_embeddings[:, :self.config.truncate_dim]
                    for i, embedding in enumerate(img_embeddings):
                        all_embeddings.append((image_indices[i], embedding))

        if not all_embeddings:
            raise RuntimeError('No embeddings were generated')

        all_embeddings.sort(key=lambda x: x[0])  # sort by original index
        combined_embeddings = torch.stack([emb for _, emb in all_embeddings])
        features['sentence_embedding'] = combined_embeddings
        return features
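
# A minimal usage sketch (illustrative, not part of the module): it assumes this
# file is wired in as the model repo's custom Sentence Transformers module and
# that sentence-transformers is installed. The task name must be one of
# config.task_names; 'retrieval' is the example used in the error message above.
#
#   from sentence_transformers import SentenceTransformer
#
#   model = SentenceTransformer(
#       'jinaai/jina-embeddings-v4',
#       trust_remote_code=True,
#       model_kwargs={'default_task': 'retrieval'},
#   )
#   embeddings = model.encode(['A sample sentence.'])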