Guilherme34 commited on
Commit
cf73bb6
·
verified ·
1 Parent(s): 040782e

Upload image_processing_minicpmv.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. image_processing_minicpmv.py +407 -0
image_processing_minicpmv.py ADDED
@@ -0,0 +1,407 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2025 The OpenBMB Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+ from typing import Any
18
+ from typing import Dict
19
+ from typing import List
20
+ from typing import Optional
21
+ from typing import Union
22
+
23
+ import numpy as np
24
+ import PIL
25
+ import PIL.Image
26
+ import PIL.ImageSequence
27
+ import torch
28
+ from PIL import Image
29
+ from transformers import AutoImageProcessor
30
+ from transformers.image_processing_utils import BaseImageProcessor
31
+ from transformers.image_processing_utils import BatchFeature
32
+ from transformers.image_transforms import to_channel_dimension_format
33
+ from transformers.image_utils import ChannelDimension
34
+ from transformers.image_utils import infer_channel_dimension_format
35
+ from transformers.image_utils import is_torch_tensor
36
+ from transformers.image_utils import to_numpy_array
37
+ from transformers.image_utils import valid_images
38
+ from transformers.utils import is_torch_device
39
+ from transformers.utils import is_torch_dtype
40
+ from transformers.utils import requires_backends
41
+ from transformers.utils import TensorType
42
+
43
+
44
+ def recursive_converter(converter, value):
45
+ if isinstance(value, list):
46
+ new_value = []
47
+ for v in value:
48
+ new_value += [recursive_converter(converter, v)]
49
+ return new_value
50
+ else:
51
+ return converter(value)
52
+
53
+
54
+ class MiniCPMOBatchFeature(BatchFeature):
55
+ r"""
56
+ Extend from BatchFeature for supporting various image size
57
+ """
58
+
59
+ def __init__(self, data: Optional[Dict[str, Any]] = None, tensor_type: Union[None, str, TensorType] = None):
60
+ super().__init__(data)
61
+ self.convert_to_tensors(tensor_type=tensor_type)
62
+
63
+ def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None):
64
+ if tensor_type is None:
65
+ return self
66
+
67
+ is_tensor, as_tensor = self._get_is_as_tensor_fns(tensor_type)
68
+
69
+ def converter(value):
70
+ try:
71
+ if not is_tensor(value):
72
+ tensor = as_tensor(value)
73
+ return tensor
74
+ except: # noqa E722
75
+ if key == "overflowing_values":
76
+ raise ValueError("Unable to create tensor returning overflowing values of different lengths. ")
77
+ raise ValueError(
78
+ "Unable to create tensor, you should probably activate padding "
79
+ "with 'padding=True' to have batched tensors with the same length."
80
+ )
81
+
82
+ for key, value in self.items():
83
+ self[key] = recursive_converter(converter, value)
84
+ return self
85
+
86
+ def to(self, *args, **kwargs) -> "MiniCPMOBatchFeature":
87
+ requires_backends(self, ["torch"])
88
+ import torch
89
+
90
+ def cast_tensor(v):
91
+ # check if v is a floating point
92
+ if torch.is_floating_point(v):
93
+ # cast and send to device
94
+ return v.to(*args, **kwargs)
95
+ elif device is not None:
96
+ return v.to(device=device)
97
+ else:
98
+ return v
99
+
100
+ new_data = {}
101
+ device = kwargs.get("device")
102
+ # Check if the args are a device or a dtype
103
+ if device is None and len(args) > 0:
104
+ # device should be always the first argument
105
+ arg = args[0]
106
+ if is_torch_dtype(arg):
107
+ # The first argument is a dtype
108
+ pass
109
+ elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int):
110
+ device = arg
111
+ else:
112
+ # it's something else
113
+ raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
114
+ # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
115
+ for k, v in self.items():
116
+ new_data[k] = recursive_converter(cast_tensor, v)
117
+ self.data = new_data
118
+ return self
119
+
120
+
121
+ class MiniCPMVImageProcessor(BaseImageProcessor):
122
+ model_input_names = ["pixel_values"]
123
+
124
+ def __init__(self, max_slice_nums=9, scale_resolution=448, patch_size=14, **kwargs):
125
+ super().__init__(**kwargs)
126
+ self.max_slice_nums = max_slice_nums
127
+ self.scale_resolution = scale_resolution
128
+ self.patch_size = patch_size
129
+ self.use_image_id = kwargs.pop("use_image_id", False)
130
+ self.image_feature_size = kwargs.pop("image_feature_size", 64)
131
+ self.im_start_token = kwargs.pop("im_start", "<image>")
132
+ self.im_end_token = kwargs.pop("im_end", "</image>")
133
+ self.slice_start_token = kwargs.pop("slice_start", "<slice>")
134
+ self.slice_end_token = kwargs.pop("slice_end", "</slice>")
135
+ self.unk_token = kwargs.pop("unk", "<unk>")
136
+ self.im_id_start = kwargs.pop("im_id_start", "<image_id>")
137
+ self.im_id_end = kwargs.pop("im_id_end", "</image_id>")
138
+ self.slice_mode = kwargs.pop("slice_mode", True)
139
+
140
+ self.mean = np.array(kwargs.pop("norm_mean", [0.5, 0.5, 0.5]))
141
+ self.std = np.array(kwargs.pop("norm_std", [0.5, 0.5, 0.5]))
142
+ self.version = kwargs.pop("version", 2.0)
143
+
144
+ def ensure_divide(self, length, patch_size):
145
+ return max(round(length / patch_size) * patch_size, patch_size)
146
+
147
+ def find_best_resize(self, original_size, scale_resolution, patch_size, allow_upscale=False):
148
+ width, height = original_size
149
+ if (width * height > scale_resolution * scale_resolution) or allow_upscale:
150
+ r = width / height
151
+ height = int(scale_resolution / math.sqrt(r))
152
+ width = int(height * r)
153
+ best_width = self.ensure_divide(width, patch_size)
154
+ best_height = self.ensure_divide(height, patch_size)
155
+ return (best_width, best_height)
156
+
157
+ def get_refine_size(self, original_size, grid, scale_resolution, patch_size, allow_upscale=False):
158
+ width, height = original_size
159
+ grid_x, grid_y = grid
160
+
161
+ refine_width = self.ensure_divide(width, grid_x)
162
+ refine_height = self.ensure_divide(height, grid_y)
163
+
164
+ grid_width = refine_width / grid_x
165
+ grid_height = refine_height / grid_y
166
+
167
+ best_grid_size = self.find_best_resize(
168
+ (grid_width, grid_height), scale_resolution, patch_size, allow_upscale=allow_upscale
169
+ )
170
+ refine_size = (best_grid_size[0] * grid_x, best_grid_size[1] * grid_y)
171
+ return refine_size
172
+
173
+ def split_to_patches(self, image, grid):
174
+ patches = []
175
+ width, height = image.size
176
+ grid_x = int(width / grid[0])
177
+ grid_y = int(height / grid[1])
178
+ for i in range(0, height, grid_y):
179
+ images = []
180
+ for j in range(0, width, grid_x):
181
+ box = (j, i, j + grid_x, i + grid_y)
182
+ patch = image.crop(box)
183
+ images.append(patch)
184
+ patches.append(images)
185
+ return patches
186
+
187
+ def slice_image(self, image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False):
188
+ original_size = image.size
189
+ source_image = None
190
+ best_grid = self.get_sliced_grid(original_size, max_slice_nums, never_split)
191
+ patches = []
192
+
193
+ if best_grid is None:
194
+ # dont need to slice, upsample
195
+ best_size = self.find_best_resize(original_size, scale_resolution, patch_size, allow_upscale=True)
196
+ source_image = image.resize(best_size, resample=Image.Resampling.BICUBIC)
197
+ else:
198
+ # source image, down-sampling and ensure divided by patch_size
199
+ best_resize = self.find_best_resize(original_size, scale_resolution, patch_size)
200
+ source_image = image.copy().resize(best_resize, resample=Image.Resampling.BICUBIC)
201
+ refine_size = self.get_refine_size(
202
+ original_size, best_grid, scale_resolution, patch_size, allow_upscale=True
203
+ )
204
+ refine_image = image.resize(refine_size, resample=Image.Resampling.BICUBIC)
205
+ patches = self.split_to_patches(refine_image, best_grid)
206
+
207
+ return source_image, patches, best_grid
208
+
209
+ def get_grid_placeholder(self, grid):
210
+ if grid is None:
211
+ return ""
212
+ slice_image_placeholder = (
213
+ self.slice_start_token + self.unk_token * self.image_feature_size + self.slice_end_token
214
+ )
215
+
216
+ cols = grid[0]
217
+ rows = grid[1]
218
+ slices = []
219
+ for i in range(rows):
220
+ lines = []
221
+ for j in range(cols):
222
+ lines.append(slice_image_placeholder)
223
+ slices.append("".join(lines))
224
+
225
+ slice_placeholder = "\n".join(slices)
226
+ return slice_placeholder
227
+
228
+ def get_image_id_placeholder(self, idx=0):
229
+ return f"{self.im_id_start}{idx}{self.im_id_end}"
230
+
231
+ def get_sliced_images(self, image, max_slice_nums=None):
232
+ slice_images = []
233
+
234
+ if not self.slice_mode:
235
+ return [image]
236
+
237
+ max_slice_nums = self.max_slice_nums if max_slice_nums is None else int(max_slice_nums)
238
+ assert max_slice_nums > 0
239
+ source_image, patches, sliced_grid = self.slice_image(
240
+ image, max_slice_nums, self.scale_resolution, self.patch_size # default: 9 # default: 448 # default: 14
241
+ )
242
+
243
+ slice_images.append(source_image)
244
+ if len(patches) > 0:
245
+ for i in range(len(patches)):
246
+ for j in range(len(patches[0])):
247
+ slice_images.append(patches[i][j])
248
+ return slice_images
249
+
250
+ def get_sliced_grid(self, image_size, max_slice_nums, nerver_split=False):
251
+ original_width, original_height = image_size
252
+ log_ratio = math.log(original_width / original_height)
253
+ ratio = original_width * original_height / (self.scale_resolution * self.scale_resolution)
254
+ multiple = min(math.ceil(ratio), max_slice_nums)
255
+ if multiple <= 1 or nerver_split:
256
+ return None
257
+ candidate_split_grids_nums = []
258
+ for i in [multiple - 1, multiple, multiple + 1]:
259
+ if i == 1 or i > max_slice_nums:
260
+ continue
261
+ candidate_split_grids_nums.append(i)
262
+
263
+ candidate_grids = []
264
+ for split_grids_nums in candidate_split_grids_nums:
265
+ m = 1
266
+ while m <= split_grids_nums:
267
+ if split_grids_nums % m == 0:
268
+ candidate_grids.append([m, split_grids_nums // m])
269
+ m += 1
270
+
271
+ best_grid = [1, 1]
272
+ min_error = float("inf")
273
+ for grid in candidate_grids:
274
+ error = abs(log_ratio - math.log(grid[0] / grid[1]))
275
+ if error < min_error:
276
+ best_grid = grid
277
+ min_error = error
278
+
279
+ return best_grid
280
+
281
+ def get_slice_image_placeholder(self, image_size, image_idx=0, max_slice_nums=None, use_image_id=None):
282
+ max_slice_nums = self.max_slice_nums if max_slice_nums is None else int(max_slice_nums)
283
+ assert max_slice_nums > 0
284
+ grid = self.get_sliced_grid(image_size=image_size, max_slice_nums=max_slice_nums)
285
+
286
+ image_placeholder = self.im_start_token + self.unk_token * self.image_feature_size + self.im_end_token
287
+ use_image_id = self.use_image_id if use_image_id is None else bool(use_image_id)
288
+ if use_image_id:
289
+ final_placeholder = self.get_image_id_placeholder(image_idx) + image_placeholder
290
+ else:
291
+ final_placeholder = image_placeholder
292
+
293
+ if self.slice_mode:
294
+ final_placeholder = final_placeholder + self.get_grid_placeholder(grid=grid)
295
+ return final_placeholder
296
+
297
+ def to_pil_image(self, image, rescale=None) -> PIL.Image.Image:
298
+ """
299
+ Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
300
+ needed.
301
+
302
+ Args:
303
+ image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor`):
304
+ The image to convert to the PIL Image format.
305
+ rescale (`bool`, *optional*):
306
+ Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will
307
+ default to `True` if the image type is a floating type, `False` otherwise.
308
+ """
309
+ if isinstance(image, PIL.Image.Image):
310
+ return image
311
+ if is_torch_tensor(image):
312
+ image = image.numpy()
313
+
314
+ if isinstance(image, np.ndarray):
315
+ if rescale is None:
316
+ # rescale default to the array being of floating type.
317
+ rescale = isinstance(image.flat[0], np.floating)
318
+ # If the channel as been moved to first dim, we put it back at the end.
319
+ if image.ndim == 3 and image.shape[0] in [1, 3]:
320
+ image = image.transpose(1, 2, 0)
321
+ if rescale:
322
+ image = image * 255
323
+ image = image.astype(np.uint8)
324
+ return PIL.Image.fromarray(image)
325
+ return image
326
+
327
+ def reshape_by_patch(self, image):
328
+ """
329
+ :param image: shape [3, H, W]
330
+ :param patch_size:
331
+ :return: [3, patch_size, HW/patch_size]
332
+ """
333
+ image = torch.from_numpy(image)
334
+ patch_size = self.patch_size
335
+ patches = torch.nn.functional.unfold(image, (patch_size, patch_size), stride=(patch_size, patch_size))
336
+
337
+ patches = patches.reshape(image.size(0), patch_size, patch_size, -1)
338
+ patches = patches.permute(0, 1, 3, 2).reshape(image.size(0), patch_size, -1)
339
+ return patches.numpy()
340
+
341
+ def preprocess(
342
+ self,
343
+ images: Union[Image.Image, List[Image.Image], List[List[Image.Image]]],
344
+ do_pad: Optional[bool] = True,
345
+ max_slice_nums: int = None,
346
+ return_tensors: Optional[Union[str, TensorType]] = None,
347
+ **kwargs,
348
+ ) -> MiniCPMOBatchFeature:
349
+ if isinstance(images, Image.Image):
350
+ images_list = [[images]]
351
+ elif isinstance(images[0], Image.Image):
352
+ images_list = [images]
353
+ else:
354
+ images_list = images
355
+
356
+ new_images_list = []
357
+ image_sizes_list = []
358
+ tgt_sizes_list = []
359
+
360
+ for _images in images_list:
361
+ if _images is None or len(_images) == 0:
362
+ new_images_list.append([])
363
+ image_sizes_list.append([])
364
+ tgt_sizes_list.append([])
365
+ continue
366
+ if not valid_images(_images):
367
+ raise ValueError(
368
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
369
+ "torch.Tensor, tf.Tensor or jax.ndarray."
370
+ )
371
+
372
+ _images = [self.to_pil_image(image).convert("RGB") for image in _images]
373
+ input_data_format = infer_channel_dimension_format(np.array(_images[0]))
374
+
375
+ new_images = []
376
+ image_sizes = [image.size for image in _images]
377
+ tgt_sizes = []
378
+ for image in _images:
379
+ image_patches = self.get_sliced_images(image, max_slice_nums)
380
+ image_patches = [to_numpy_array(image).astype(np.float32) / 255 for image in image_patches]
381
+ image_patches = [
382
+ self.normalize(image=image, mean=self.mean, std=self.std, input_data_format=input_data_format)
383
+ for image in image_patches
384
+ ]
385
+ image_patches = [
386
+ to_channel_dimension_format(image, ChannelDimension.FIRST, input_channel_dim=input_data_format)
387
+ for image in image_patches
388
+ ]
389
+ for slice_image in image_patches:
390
+ new_images.append(self.reshape_by_patch(slice_image))
391
+ tgt_sizes.append(
392
+ np.array((slice_image.shape[1] // self.patch_size, slice_image.shape[2] // self.patch_size))
393
+ )
394
+
395
+ if tgt_sizes:
396
+ tgt_sizes = np.vstack(tgt_sizes)
397
+
398
+ new_images_list.append(new_images)
399
+ image_sizes_list.append(image_sizes)
400
+ tgt_sizes_list.append(tgt_sizes)
401
+ return MiniCPMOBatchFeature(
402
+ data={"pixel_values": new_images_list, "image_sizes": image_sizes_list, "tgt_sizes": tgt_sizes_list},
403
+ tensor_type=return_tensors,
404
+ )
405
+
406
+
407
+ AutoImageProcessor.register("MiniCPMVImageProcessor", MiniCPMVImageProcessor)