izhx commited on
Commit
64a2954
·
verified ·
1 Parent(s): eb0f9be

Create custom_st.py

Browse files
Files changed (1) hide show
  1. custom_st.py +221 -0
custom_st.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from io import BytesIO
2
+ from typing import Any, Dict, Optional, List
3
+ import torch
4
+ from PIL import Image
5
+ from sentence_transformers.models import Transformer as BaseTransformer
6
+ from transformers import AutoModelForVision2Seq, AutoProcessor
7
+
8
+
9
+ class MultiModalTransformer(BaseTransformer):
10
+ def __init__(
11
+ self,
12
+ model_name_or_path: str,
13
+ cache_dir: Optional[str] = None,
14
+ tokenizer_args: Optional[Dict[str, Any]] = None,
15
+ min_image_tokens: int = 256,
16
+ max_image_tokens: int = 1280,
17
+ max_length: int = 1800,
18
+ **kwargs,
19
+ ):
20
+ super().__init__(model_name_or_path, **kwargs)
21
+ if tokenizer_args is None:
22
+ tokenizer_args = {}
23
+ tokenizer_args.pop("trust_remote_code", None)
24
+
25
+ # Initialize processor
26
+ min_pixels = min_image_tokens * 28 * 28
27
+ max_pixels = max_image_tokens * 28 * 28
28
+ self.processor = AutoProcessor.from_pretrained(
29
+ model_name_or_path, min_pixels=min_pixels, max_pixels=max_pixels, **kwargs
30
+ )
31
+ self.processor.tokenizer.padding_side = 'right'
32
+ self.sep = ' '
33
+ self.max_length = max_length
34
+ self.normalize = True
35
+
36
+ def _load_model(
37
+ self,
38
+ model_name_or_path: str,
39
+ config,
40
+ cache_dir: str,
41
+ backend: str,
42
+ is_peft_model: bool,
43
+ **model_args,
44
+ ) -> None:
45
+ model_args.pop("trust_remote_code", None)
46
+ self.auto_model = AutoModelForVision2Seq.from_pretrained(
47
+ model_name_or_path, torch_dtype=torch.float16, **model_args
48
+ )
49
+
50
+ def forward(
51
+ self, features: Dict[str, torch.Tensor], **kwargs
52
+ ) -> Dict[str, torch.Tensor]:
53
+ if features.get("inputs_embeds", None) is None:
54
+ features["inputs_embeds"] = self.auto_model.base_model.embed_tokens(features["input_ids"])
55
+ if features.get("pixel_values", None) is not None:
56
+ features["pixel_values"] = features["pixel_values"].type(self.auto_model.visual.get_dtype())
57
+ image_embeds = self.auto_model.visual(
58
+ features["pixel_values"], grid_thw=features["image_grid_thw"]
59
+ )
60
+ image_mask = features["input_ids"] == self.auto_model.config.image_token_id
61
+ features["inputs_embeds"][image_mask] = image_embeds
62
+ # features.pop("pixel_values")
63
+ # features.pop("image_grid_thw")
64
+ # features.pop("input_ids")
65
+ inputs = {k: v for k, v in features.items() if k in 'position_ids,attention_mask,inputs_embeds'}
66
+ outputs = self.auto_model.model(
67
+ **inputs,
68
+ return_dict=True,
69
+ output_hidden_states=True,
70
+ # **kwargs
71
+ )
72
+ # pooling_mask = features["attention_mask"] if features.get("pooling_mask", None) is None else features["pooling_mask"]
73
+ # left_padding = (pooling_mask[:, -1].sum() == pooling_mask.shape[0]) # TODO
74
+ # if left_padding:
75
+ # embeddings = outputs.last_hidden_state
76
+ # else:
77
+ # sequence_lengths = pooling_mask.sum(dim=1) - 1
78
+ # embeddings = outputs.last_hidden_state[torch.arange(
79
+ # outputs.last_hidden_state.shape[0], device=outputs.last_hidden_state.device
80
+ # ), sequence_lengths]
81
+ features.update({"token_embeddings": outputs.last_hidden_state})
82
+ return features
83
+
84
+ def tokenize(self, texts: List[List[Dict[str, Any]]] | List[str]) -> Dict[str, torch.Tensor]:
85
+ default_instruction = 'You are a helpful assistant.'
86
+
87
+ all_texts, all_images = list(), list()
88
+ for item in texts:
89
+ if isinstance(item, str):
90
+ txt, img, inst = item, None, default_instruction
91
+ elif isinstance(item, dict):
92
+ txt = item.get('text', None)
93
+ img = item.get('image', None)
94
+ inst = item.get('prompt', default_instruction)
95
+ else:
96
+ raise RuntimeError(f'Input format not supported! {item=}')
97
+
98
+ input_str = ''
99
+ if img is None:
100
+ all_images = None # All examples in the same batch are consistent
101
+ # or will have ValueError: Could not make a flat list of images from xxxx
102
+ else:
103
+ input_str += '<|vision_start|><|image_pad|><|vision_end|>'
104
+ img = fetch_image(img)
105
+ all_images.append(img)
106
+ if txt is not None:
107
+ input_str += txt
108
+ msg = f'<|im_start|>system\n{inst}<|im_end|>\n<|im_start|>user\n{input_str}<|im_end|>\n<|im_start|>assistant\n<|endoftext|>'
109
+ all_texts.append(msg)
110
+
111
+ inputs = self.processor(
112
+ text=all_texts,
113
+ images=all_images,
114
+ padding="longest",
115
+ truncation=True,
116
+ max_length=self.max_seq_length,
117
+ return_tensors='pt'
118
+ )
119
+ return inputs
120
+
121
+
122
+ ### Copied from qwen_vl_utils.vision_process.py
123
+ import base64
124
+ from io import BytesIO
125
+ import requests
126
+
127
+ IMAGE_FACTOR = 28
128
+ MIN_PIXELS = 4 * 28 * 28
129
+ MAX_PIXELS = 16384 * 28 * 28
130
+ MAX_RATIO = 200
131
+
132
+
133
+ def round_by_factor(number: int, factor: int) -> int:
134
+ """Returns the closest integer to 'number' that is divisible by 'factor'."""
135
+ return round(number / factor) * factor
136
+
137
+
138
+ def ceil_by_factor(number: int, factor: int) -> int:
139
+ """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
140
+ return math.ceil(number / factor) * factor
141
+
142
+
143
+ def floor_by_factor(number: int, factor: int) -> int:
144
+ """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
145
+ return math.floor(number / factor) * factor
146
+
147
+
148
+ def smart_resize(
149
+ height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
150
+ ) -> tuple[int, int]:
151
+ """
152
+ Rescales the image so that the following conditions are met:
153
+
154
+ 1. Both dimensions (height and width) are divisible by 'factor'.
155
+
156
+ 2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
157
+
158
+ 3. The aspect ratio of the image is maintained as closely as possible.
159
+ """
160
+ h_bar = max(factor, round_by_factor(height, factor))
161
+ w_bar = max(factor, round_by_factor(width, factor))
162
+ if h_bar * w_bar > max_pixels:
163
+ beta = math.sqrt((height * width) / max_pixels)
164
+ h_bar = floor_by_factor(height / beta, factor)
165
+ w_bar = floor_by_factor(width / beta, factor)
166
+ elif h_bar * w_bar < min_pixels:
167
+ beta = math.sqrt(min_pixels / (height * width))
168
+ h_bar = ceil_by_factor(height * beta, factor)
169
+ w_bar = ceil_by_factor(width * beta, factor)
170
+
171
+ if max(h_bar, w_bar) / min(h_bar, w_bar) > MAX_RATIO:
172
+ logging.warning(
173
+ f"Absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(h_bar, w_bar) / min(h_bar, w_bar)}"
174
+ )
175
+ if h_bar > w_bar:
176
+ h_bar = w_bar * MAX_RATIO
177
+ else:
178
+ w_bar = h_bar * MAX_RATIO
179
+ return h_bar, w_bar
180
+
181
+
182
+ def fetch_image(image: str | Image.Image, size_factor: int = IMAGE_FACTOR) -> Image.Image:
183
+ image_obj = None
184
+ if isinstance(image, Image.Image):
185
+ image_obj = image
186
+ elif image.startswith("http://") or image.startswith("https://"):
187
+ image_obj = Image.open(requests.get(image, stream=True).raw)
188
+ elif image.startswith("file://"):
189
+ image_obj = Image.open(image[7:])
190
+ elif image.startswith("data:image"):
191
+ if "base64," in image:
192
+ _, base64_data = image.split("base64,", 1)
193
+ data = base64.b64decode(base64_data)
194
+ image_obj = Image.open(BytesIO(data))
195
+ else:
196
+ image_obj = Image.open(image)
197
+ if image_obj is None:
198
+ raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
199
+ image = image_obj.convert("RGB")
200
+ ## resize
201
+ # if "resized_height" in ele and "resized_width" in ele:
202
+ # resized_height, resized_width = smart_resize(
203
+ # ele["resized_height"],
204
+ # ele["resized_width"],
205
+ # factor=size_factor,
206
+ # )
207
+ # else:
208
+ width, height = image.size
209
+ # min_pixels = ele.get("min_pixels", MIN_PIXELS)
210
+ # max_pixels = ele.get("max_pixels", MAX_PIXELS)
211
+ resized_height, resized_width = smart_resize(
212
+ height,
213
+ width,
214
+ factor=size_factor,
215
+ min_pixels=MIN_PIXELS,
216
+ max_pixels=MAX_PIXELS,
217
+ )
218
+ image = image.resize((resized_width, resized_height))
219
+
220
+ return image
221
+ ###