jupyterjazz committed
Commit 9ef2e43 · 1 Parent(s): 9624180

refactor: image loading in st wrapper

Signed-off-by: jupyterjazz <[email protected]>

Files changed (1)
  1. custom_st.py +80 -42
custom_st.py CHANGED
@@ -1,32 +1,34 @@
+from io import BytesIO
+from pathlib import Path
 from typing import Any, Dict, List, Literal, Optional, Union
 
+import requests
 import torch
 from PIL import Image
 from torch import nn
-from transformers import AutoConfig, AutoProcessor, AutoModel
+from transformers import AutoConfig, AutoModel, AutoProcessor
 
 
 class Transformer(nn.Module):
 
     save_in_root: bool = True
-
+
     def __init__(
         self,
-        model_name_or_path: str = 'jinaai/jina-embeddings-v4',
+        model_name_or_path: str = "jinaai/jina-embeddings-v4",
         max_seq_length: Optional[int] = None,
         config_args: Optional[Dict[str, Any]] = None,
         model_args: Optional[Dict[str, Any]] = None,
         tokenizer_args: Optional[Dict[str, Any]] = None,
         cache_dir: Optional[str] = None,
-        backend: Literal['torch', 'onnx', 'openvino'] = 'torch',
+        backend: Literal["torch", "onnx", "openvino"] = "torch",
         **kwargs,
     ) -> None:
         super(Transformer, self).__init__()
-        if backend != 'torch':
+        if backend != "torch":
             raise ValueError(
-                f'Backend \'{backend}\' is not supported, please use \'torch\' instead'
+                f"Backend '{backend}' is not supported, please use 'torch' instead"
             )
-
         config_kwargs = config_args or {}
         model_kwargs = model_args or {}
         tokenizer_kwargs = tokenizer_args or {}
@@ -34,9 +36,11 @@ class Transformer(nn.Module):
         self.config = AutoConfig.from_pretrained(
             model_name_or_path, cache_dir=cache_dir, **config_kwargs
         )
-        self.default_task = model_args.pop('default_task', None)
+        self.default_task = model_args.pop("default_task", None)
         if self.default_task and self.default_task not in self.config.task_names:
-            raise ValueError(f"Invalid task: {self.default_task}. Must be one of {self.config.task_names}.")
+            raise ValueError(
+                f"Invalid task: {self.default_task}. Must be one of {self.config.task_names}."
+            )
 
         self.model = AutoModel.from_pretrained(
             model_name_or_path, config=self.config, cache_dir=cache_dir, **model_kwargs
@@ -45,6 +49,7 @@ class Transformer(nn.Module):
         self.processor = AutoProcessor.from_pretrained(
             model_name_or_path,
             cache_dir=cache_dir,
+            use_fast=True,
             **tokenizer_kwargs,
         )
         self.max_seq_length = max_seq_length or 8192
@@ -55,33 +60,52 @@ class Transformer(nn.Module):
         encoding = {}
         text_indices = []
         image_indices = []
-
         for i, text in enumerate(texts):
             if isinstance(text, str):
-                text_indices.append(i)
+                # Remove Query: or Passage: prefixes when checking for URLs or file paths
+                clean_text = text
+                if text.startswith("Query: "):
+                    clean_text = text[len("Query: ") :]
+                elif text.startswith("Passage: "):
+                    clean_text = text[len("Passage: ") :]
+
+                if clean_text.startswith("http"):
+                    response = requests.get(clean_text)
+                    texts[i] = Image.open(BytesIO(response.content)).convert("RGB")
+                    image_indices.append(i)
+                elif Path(clean_text).is_file():
+                    try:
+                        texts[i] = Image.open(clean_text).convert("RGB")
+                        image_indices.append(i)
+                    except Exception as e:
+                        text_indices.append(i)
+                else:
+                    text_indices.append(i)
             elif isinstance(text, Image.Image):
                 image_indices.append(i)
             else:
-                raise ValueError(f'Invalid input type: {type(text)}')
-
+                raise ValueError(f"Invalid input type: {type(text)}")
         if text_indices:
             _texts = [texts[i] for i in text_indices]
-            text_features = self.processor.process_texts(_texts, max_length=self.max_seq_length)
+            text_features = self.processor.process_texts(
+                _texts, max_length=self.max_seq_length
+            )
             for key, value in text_features.items():
-                encoding[f'text_{key}'] = value
-            encoding['text_indices'] = text_indices
-
+                encoding[f"text_{key}"] = value
+            encoding["text_indices"] = text_indices
+
         if image_indices:
             _images = [texts[i] for i in image_indices]
            img_features = self.processor.process_images(_images)
             for key, value in img_features.items():
-                encoding[f'image_{key}'] = value
-            encoding['image_indices'] = image_indices
-
+                encoding[f"image_{key}"] = value
+            encoding["image_indices"] = image_indices
+
         return encoding
-
 
-    def forward(self, features: Dict[str, torch.Tensor], task: Optional[str] = None) -> Dict[str, torch.Tensor]:
+    def forward(
+        self, features: Dict[str, torch.Tensor], task: Optional[str] = None
+    ) -> Dict[str, torch.Tensor]:
         self.model.eval()
 
         if task is None:
@@ -94,41 +118,55 @@ class Transformer(nn.Module):
             task = self.default_task
         else:
             if task not in self.config.task_names:
-                raise ValueError(f"Invalid task: {task}. Must be one of {self.config.task_names}.")
+                raise ValueError(
+                    f"Invalid task: {task}. Must be one of {self.config.task_names}."
+                )
 
         device = self.model.device.type
         all_embeddings = []
-
+
         with torch.no_grad():
-            if any(k.startswith('text_') for k in features.keys()):
-                text_batch = {k[len('text_'):]: v.to(device) for k, v in features.items() if k.startswith('text_') and k != 'text_indices'}
-                text_indices = features.get('text_indices', [])
-
+            if any(k.startswith("text_") for k in features.keys()):
+                text_batch = {
+                    k[len("text_") :]: v.to(device)
+                    for k, v in features.items()
+                    if k.startswith("text_") and k != "text_indices"
+                }
+                text_indices = features.get("text_indices", [])
+
                 with torch.autocast(device_type=device):
-                    text_embeddings = self.model(**text_batch, task_label=task).single_vec_emb
+                    text_embeddings = self.model(
+                        **text_batch, task_label=task
+                    ).single_vec_emb
                     if self.config.truncate_dim:
-                        text_embeddings = text_embeddings[:, :self.config.truncate_dim]
-
+                        text_embeddings = text_embeddings[:, : self.config.truncate_dim]
+
                     for i, embedding in enumerate(text_embeddings):
                         all_embeddings.append((text_indices[i], embedding))
-
-            if any(k.startswith('image_') for k in features.keys()):
-                image_batch = {k[len('image_'):]: v.to(device) for k, v in features.items() if k.startswith('image_') and k != 'image_indices'}
-                image_indices = features.get('image_indices', [])
-
+
+            if any(k.startswith("image_") for k in features.keys()):
+                image_batch = {
+                    k[len("image_") :]: v.to(device)
+                    for k, v in features.items()
+                    if k.startswith("image_") and k != "image_indices"
+                }
+                image_indices = features.get("image_indices", [])
+
                 with torch.autocast(device_type=device):
-                    img_embeddings = self.model(**image_batch, task_label=task).single_vec_emb
+                    img_embeddings = self.model(
+                        **image_batch, task_label=task
+                    ).single_vec_emb
                    if self.config.truncate_dim:
-                        img_embeddings = img_embeddings[:, :self.config.truncate_dim]
-
+                        img_embeddings = img_embeddings[:, : self.config.truncate_dim]
+
                     for i, embedding in enumerate(img_embeddings):
                        all_embeddings.append((image_indices[i], embedding))
 
         if not all_embeddings:
-            raise RuntimeError('No embeddings were generated')
+            raise RuntimeError("No embeddings were generated")
 
         all_embeddings.sort(key=lambda x: x[0])  # sort by original index
         combined_embeddings = torch.stack([emb for _, emb in all_embeddings])
-        features['sentence_embedding'] = combined_embeddings
-
+        features["sentence_embedding"] = combined_embeddings
+
         return features
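
For context, a minimal usage sketch of what the refactored tokenize path enables: strings that look like image URLs or existing file paths (optionally prefixed with "Query: " or "Passage: ") are opened as PIL images and routed through the image branch, everything else goes through the text branch, and forward stitches the embeddings back together in the original input order. The model id comes from the file above; the URL, local path, task name, and the trust_remote_code loading below are illustrative assumptions, not part of this commit.

from sentence_transformers import SentenceTransformer

# Illustrative sketch: assumes the checkpoint ships this custom_st.py as its
# Sentence Transformers module and is loaded with trust_remote_code=True, and
# that the installed sentence-transformers forwards extra encode() kwargs
# (here: task) to this module's forward().
model = SentenceTransformer("jinaai/jina-embeddings-v4", trust_remote_code=True)

inputs = [
    "Query: a photo of a cat",      # plain text -> text branch
    "https://example.com/cat.jpg",  # hypothetical URL -> downloaded and embedded as an image
    "/path/to/local/cat.png",       # hypothetical path -> opened with PIL if the file exists
]

# "retrieval" is assumed to be one of config.task_names; omitting task falls
# back to the default_task passed through model_kwargs.
embeddings = model.encode(inputs, task="retrieval")
print(embeddings.shape)  # one row per input, in the original order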