lakshmi97 committed
Commit 70bba6c · 1 Parent(s): c59792b

Removed all files

__init__.py DELETED
File without changes
configuration_phi4mm.py DELETED
@@ -1,235 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- """ Phi-4-MM model configuration"""
17
-
18
- from transformers.configuration_utils import PretrainedConfig
19
- from transformers.utils import logging
20
-
21
-
22
- logger = logging.get_logger(__name__)
23
-
24
-
25
- class Phi4MMConfig(PretrainedConfig):
26
- r"""
27
- This is the configuration class to store the configuration of a [`Phi4MMModel`]. It is used to instantiate a Phi-4-MM
28
- model according to the specified arguments, defining the model architecture.
29
-
30
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
31
- documentation from [`PretrainedConfig`] for more information.
32
-
33
- Args:
34
- vocab_size (`int`, *optional*, defaults to 200064):
35
- Vocabulary size of the Phi-4-MM model. Defines the number of different tokens that can be represented by the
36
- `inputs_ids` passed when calling [`Phi4MMModel`].
37
- hidden_size (`int`, *optional*, defaults to 3072):
38
- Dimension of the hidden representations.
39
- intermediate_size (`int`, *optional*, defaults to 8192):
40
- Dimension of the MLP representations.
41
- num_hidden_layers (`int`, *optional*, defaults to 32):
42
- Number of hidden layers in the Transformer decoder.
43
- num_attention_heads (`int`, *optional*, defaults to 32):
44
- Number of attention heads for each attention layer in the Transformer decoder.
45
- num_key_value_heads (`int`, *optional*):
46
- This is the number of key_value heads that should be used to implement Grouped Query Attention. If
47
- `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
48
- `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
49
- converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
50
- by meanpooling all the original heads within that group. For more details checkout [this
51
- paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
52
- `num_attention_heads`.
53
- resid_pdrop (`float`, *optional*, defaults to 0.0):
54
- Dropout probability for mlp outputs.
55
- embd_pdrop (`float`, *optional*, defaults to 0.0):
56
- The dropout ratio for the embeddings.
57
- attention_dropout (`float`, *optional*, defaults to 0.0):
58
- The dropout ratio after computing the attention scores.
59
- hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
60
- The non-linear activation function (function or string) in the decoder.
61
- max_position_embeddings (`int`, *optional*, defaults to 4096):
62
- The maximum sequence length that this model might ever be used with.
63
- original_max_position_embeddings (`int`, *optional*, defaults to 4096):
64
- The maximum sequence length that this model was trained with. This is used to determine the size of the
65
- original RoPE embeddings when using long scaling.
66
- initializer_range (`float`, *optional*, defaults to 0.02):
67
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
68
- rms_norm_eps (`float`, *optional*, defaults to 1e-05):
69
- The epsilon value used for the RMSNorm.
70
- use_cache (`bool`, *optional*, defaults to `True`):
71
- Whether or not the model should return the last key/values attentions (not used by all models). Only
72
- relevant if `config.is_decoder=True`.
73
- tie_word_embeddings (`bool`, *optional*, defaults to `False`):
74
- Whether to tie the input and output word embeddings.
75
- rope_theta (`float`, *optional*, defaults to 10000.0):
76
- The base period of the RoPE embeddings.
77
- rope_scaling (`dict`, *optional*):
78
- The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must
79
- contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be `longrope` and
80
- the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size
81
- divided by the number of attention heads divided by 2.
82
- partial_rotary_factor (`float`, *optional*, defaults to 1.0):
83
- Percentage of the query and keys which will have rotary embedding.
84
- bos_token_id (`int`, *optional*, defaults to 199999):
85
- The id of the "beginning-of-sequence" token.
86
- eos_token_id (`int`, *optional*, defaults to 199999):
87
- The id of the "end-of-sequence" token.
88
- pad_token_id (`int`, *optional*, defaults to 199999):
89
- The id of the padding token.
90
- sliding_window (`int`, *optional*):
91
- Sliding window attention window size. If `None`, no sliding window is applied.
92
-
93
- Example:
94
-
95
- ```python
96
- >>> from transformers import Phi4MMModel, Phi4MMConfig
97
-
98
- >>> # Initializing a Phi-4-MM style configuration
99
- >>> configuration = Phi4MMConfig.from_pretrained("TBA")
100
-
101
- >>> # Initializing a model from the configuration
102
- >>> model = Phi4MMModel(configuration)
103
-
104
- >>> # Accessing the model configuration
105
- >>> configuration = model.config
106
- ```"""
107
-
108
- model_type = "phi4mm"
109
- keys_to_ignore_at_inference = ["past_key_values"]
110
-
111
- def __init__(
112
- self,
113
- vocab_size=200064,
114
- hidden_size=3072,
115
- intermediate_size=8192,
116
- num_hidden_layers=32,
117
- num_attention_heads=32,
118
- num_key_value_heads=None,
119
- resid_pdrop=0.0,
120
- embd_pdrop=0.0,
121
- attention_dropout=0.0,
122
- hidden_act="silu",
123
- max_position_embeddings=4096,
124
- original_max_position_embeddings=4096,
125
- initializer_range=0.02,
126
- rms_norm_eps=1e-5,
127
- use_cache=True,
128
- tie_word_embeddings=False,
129
- rope_theta=10000.0,
130
- rope_scaling=None,
131
- partial_rotary_factor=1,
132
- bos_token_id=199999,
133
- eos_token_id=199999,
134
- pad_token_id=199999,
135
- sliding_window=None,
136
- embd_layer: str = "default",
137
- img_processor=None,
138
- audio_processor=None,
139
- vision_lora=None,
140
- speech_lora=None,
141
- **kwargs,
142
- ):
143
- self.embd_layer = embd_layer
144
- self.img_processor = img_processor
145
- self.audio_processor = audio_processor
146
- self.vision_lora = vision_lora
147
- self.speech_lora = speech_lora
148
-
149
- self.vocab_size = vocab_size
150
- self.hidden_size = hidden_size
151
- self.intermediate_size = intermediate_size
152
- self.num_hidden_layers = num_hidden_layers
153
- self.num_attention_heads = num_attention_heads
154
-
155
- if num_key_value_heads is None:
156
- num_key_value_heads = num_attention_heads
157
-
158
- self.num_key_value_heads = num_key_value_heads
159
- self.resid_pdrop = resid_pdrop
160
- self.embd_pdrop = embd_pdrop
161
- self.attention_dropout = attention_dropout
162
- self.hidden_act = hidden_act
163
- self.max_position_embeddings = max_position_embeddings
164
- self.original_max_position_embeddings = original_max_position_embeddings
165
- self.initializer_range = initializer_range
166
- self.rms_norm_eps = rms_norm_eps
167
- self.use_cache = use_cache
168
- self.rope_theta = rope_theta
169
- self.rope_scaling = rope_scaling
170
- self.partial_rotary_factor = partial_rotary_factor
171
- self._rope_scaling_adjustment()
172
- self._rope_scaling_validation()
173
- self.sliding_window = sliding_window
174
-
175
- super().__init__(
176
- bos_token_id=bos_token_id,
177
- eos_token_id=eos_token_id,
178
- pad_token_id=pad_token_id,
179
- tie_word_embeddings=tie_word_embeddings,
180
- **kwargs,
181
- )
182
-
183
- def _rope_scaling_adjustment(self):
184
- """
185
- Adjust the `type` of the `rope_scaling` configuration for backward compatibility.
186
- """
187
- if self.rope_scaling is None:
188
- return
189
-
190
- rope_scaling_type = self.rope_scaling.get("type", None)
191
-
192
- # For backward compatibility if previous version used "su" or "yarn"
193
- if rope_scaling_type is not None and rope_scaling_type in ["su", "yarn"]:
194
- self.rope_scaling["type"] = "longrope"
195
-
196
- def _rope_scaling_validation(self):
197
- """
198
- Validate the `rope_scaling` configuration.
199
- """
200
- if self.rope_scaling is None:
201
- return
202
-
203
- if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 3:
204
- raise ValueError(
205
- "`rope_scaling` must be a dictionary with three fields, `type`, `short_factor` and `long_factor`, "
206
- f"got {self.rope_scaling}"
207
- )
208
- rope_scaling_type = self.rope_scaling.get("type", None)
209
- rope_scaling_short_factor = self.rope_scaling.get("short_factor", None)
210
- rope_scaling_long_factor = self.rope_scaling.get("long_factor", None)
211
- if rope_scaling_type is None or rope_scaling_type not in ["longrope"]:
212
- raise ValueError(f"`rope_scaling`'s type field must be one of ['longrope'], got {rope_scaling_type}")
213
- if not (
214
- isinstance(rope_scaling_short_factor, list)
215
- and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor)
216
- ):
217
- raise ValueError(
218
- f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}"
219
- )
220
- rotary_ndims = int(self.hidden_size // self.num_attention_heads * self.partial_rotary_factor)
221
- if not len(rope_scaling_short_factor) == rotary_ndims // 2:
222
- raise ValueError(
223
- f"`rope_scaling`'s short_factor field must have length {rotary_ndims // 2}, got {len(rope_scaling_short_factor)}"
224
- )
225
- if not (
226
- isinstance(rope_scaling_long_factor, list)
227
- and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor)
228
- ):
229
- raise ValueError(
230
- f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}"
231
- )
232
- if not len(rope_scaling_long_factor) == rotary_ndims // 2:
233
- raise ValueError(
234
- f"`rope_scaling`'s long_factor field must have length {rotary_ndims // 2}, got {len(rope_scaling_long_factor)}"
235
- )
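For reference while this file is removed: the validation above ties both factor lists to a length of `hidden_size // num_attention_heads * partial_rotary_factor // 2`. A minimal sketch of a configuration that passes it, assuming the deleted `configuration_phi4mm.py` is still importable locally (hypothetical values):

```python
# Sketch only: with the defaults above (hidden_size=3072, num_attention_heads=32,
# partial_rotary_factor=1), rotary_ndims = 3072 // 32 = 96, so each list needs 48 entries.
from configuration_phi4mm import Phi4MMConfig  # assumes the deleted file is on the path

config = Phi4MMConfig(
    rope_scaling={
        "type": "longrope",          # "su" / "yarn" are rewritten to "longrope" for back-compat
        "short_factor": [1.0] * 48,  # rotary_ndims // 2 values
        "long_factor": [1.0] * 48,
    },
)
print(config.rope_scaling["type"])  # longrope
```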
 
modeling_phi4mm.py DELETED
The diff for this file is too large to render. See raw diff
 
processing_phi4mm.py DELETED
@@ -1,733 +0,0 @@
1
- # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """
16
- Processor class for Phi4MM
17
- """
18
- import re
19
- from typing import List, Optional, Tuple, Union
20
- import math
21
- from enum import Enum
22
-
23
- import numpy as np
24
- import scipy
25
- import torch
26
- import torchvision
27
-
28
- from transformers import AutoFeatureExtractor, AutoImageProcessor
29
- from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
30
- from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
31
- from transformers.image_utils import (
32
- ImageInput,
33
- make_list_of_images,
34
- valid_images,
35
- )
36
- from transformers.processing_utils import ProcessorMixin
37
- from transformers.tokenization_utils_base import PaddingStrategy, TextInput, TruncationStrategy
38
- from transformers.utils import TensorType, logging
39
- from torch.nn.utils.rnn import pad_sequence
40
-
41
-
42
- logger = logging.get_logger(__name__)
43
-
44
- # Special tokens
45
- _COMPATIBLE_IMAGE_SPECIAL_TOKEN_PATTERN = r'<\|image_\d+\|>' # For backward compatibility
46
- _COMPATIBLE_AUDIO_SPECIAL_TOKEN_PATTERN = r'<\|audio_\d+\|>' # For backward compatibility
47
- _IMAGE_SPECIAL_TOKEN = '<|endoftext10|>'
48
- _AUDIO_SPECIAL_TOKEN = '<|endoftext11|>'
49
- _IMAGE_SPECIAL_TOKEN_ID = 200010 # '<|endoftext10|>', or we can better name it (in `tokenizer_config.json`)
50
- _AUDIO_SPECIAL_TOKEN_ID = 200011 # '<|endoftext11|>'
51
-
52
-
53
- class InputMode(Enum):
54
- LANGUAGE = 0
55
- VISION = 1
56
- SPEECH = 2
57
- VISION_SPEECH = 3
58
-
59
-
60
- class Phi4MMImageProcessor(BaseImageProcessor):
61
- r"""
62
- Constructs a Phi4MM image processor.
63
- """
64
- model_input_names = ["input_image_embeds", "image_sizes", "image_attention_mask"]
65
-
66
- def __init__(
67
- self,
68
- dynamic_hd,
69
- **kwargs,
70
- ) -> None:
71
- super().__init__(**kwargs)
72
- self.dynamic_hd = dynamic_hd
73
-
74
- def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size):
75
- best_ratio_diff = float('inf')
76
- best_ratio = (1, 1)
77
- area = width * height
78
- for ratio in target_ratios:
79
- target_aspect_ratio = ratio[0] / ratio[1]
80
- ratio_diff = abs(aspect_ratio - target_aspect_ratio)
81
- if ratio_diff < best_ratio_diff:
82
- best_ratio_diff = ratio_diff
83
- best_ratio = ratio
84
- elif ratio_diff == best_ratio_diff:
85
- if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
86
- best_ratio = ratio
87
- return best_ratio
88
-
89
- def dynamic_preprocess(self, image, min_num=1, max_num=12, image_size=384, mask_size=27, use_thumbnail=True):
90
- orig_width, orig_height = image.size
91
-
92
- w_crop_num = math.ceil(orig_width/float(image_size))
93
- h_crop_num = math.ceil(orig_height/float(image_size))
94
- if w_crop_num * h_crop_num > max_num:
95
-
96
- aspect_ratio = orig_width / orig_height
97
-
98
- # calculate the existing image aspect ratio
99
- target_ratios = set(
100
- (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
101
- i * j <= max_num and i * j >= min_num)
102
- target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
103
-
104
- # find the closest aspect ratio to the target
105
- target_aspect_ratio = self.find_closest_aspect_ratio(
106
- aspect_ratio, target_ratios, orig_width, orig_height, image_size)
107
-
108
- # calculate the target width and height
109
- target_width = image_size * target_aspect_ratio[0]
110
- target_height = image_size * target_aspect_ratio[1]
111
- else:
112
- target_width = image_size * w_crop_num
113
- target_height = image_size * h_crop_num
114
- target_aspect_ratio = (w_crop_num, h_crop_num)
115
-
116
- # Calculate the ratio
117
- ratio_width = target_width / orig_width
118
- ratio_height = target_height / orig_height
119
- if ratio_width < ratio_height:
120
- new_size = (target_width, int(orig_height * ratio_width))
121
- padding_width = 0
122
- padding_height = target_height - int(orig_height * ratio_width)
123
- else:
124
- new_size = (int(orig_width * ratio_height), target_height)
125
- padding_width = target_width - int(orig_width * ratio_height)
126
- padding_height = 0
127
-
128
- attention_mask = torch.ones((int(mask_size*target_aspect_ratio[1]), int(mask_size*target_aspect_ratio[0])))
129
- if padding_width >= 14:
130
- attention_mask[:, -math.floor(padding_width/14):] = 0
131
- if padding_height >= 14:
132
- attention_mask[-math.floor(padding_height/14):,:] = 0
133
- assert attention_mask.sum() > 0
134
-
135
- if min(new_size[1], target_height) < 10 or min(new_size[0], target_width) < 10:
136
- raise ValueError(f'the aspect ratio is very extreme {new_size}')
137
-
138
- image = torchvision.transforms.functional.resize(image, [new_size[1], new_size[0]],)
139
-
140
- resized_img = torchvision.transforms.functional.pad(image, [0, 0, padding_width, padding_height], fill=[255,255,255])
141
-
142
- return resized_img, attention_mask
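To make the crop-grid arithmetic above concrete, here is a standalone walk-through of the resize-and-pad computation for a hypothetical 1000x600 input at the 448-pixel base resolution used by `preprocess` below (a sketch that mirrors the logic rather than calling the class):

```python
import math

# Hypothetical input size; mirrors the `else` branch of dynamic_preprocess above.
orig_width, orig_height, image_size = 1000, 600, 448

w_crop_num = math.ceil(orig_width / image_size)   # 3 crops wide
h_crop_num = math.ceil(orig_height / image_size)  # 2 crops tall
target_width = image_size * w_crop_num            # 1344
target_height = image_size * h_crop_num           # 896

ratio_width = target_width / orig_width           # 1.344
ratio_height = target_height / orig_height        # ~1.493
# The smaller ratio wins: resize to 1344 x 806, then pad 896 - 806 = 90 white rows
# at the bottom; floor(90 / 14) = 6 rows of the patch attention mask get zeroed.
new_height = int(orig_height * ratio_width)       # 806
padding_height = target_height - new_height       # 90
print((w_crop_num, h_crop_num), new_height, padding_height)
```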
143
-
144
- def pad_to_max_num_crops(self, images, max_crops=5):
145
- """
146
- images: B x 3 x H x W, B<=max_crops
147
- """
148
- B, _, H, W = images.shape
149
- if B < max_crops:
150
- pad = torch.zeros(max_crops - B, 3, H, W, dtype=images.dtype, device=images.device)
151
- images = torch.cat([images, pad], dim=0)
152
- return images
153
-
154
- def pad_mask_to_max_num_crops(self, masks, max_crops=5):
155
- B, H, W = masks.shape
156
- if B < max_crops:
157
- pad = torch.ones(max_crops - B, H, W, dtype=masks.dtype, device=masks.device)
158
- masks = torch.cat([masks, pad], dim=0)
159
- return masks
160
-
161
- def preprocess(
162
- self,
163
- images: ImageInput,
164
- return_tensors: Optional[Union[str, TensorType]] = None,
165
- ):
166
- """
167
- Args:
168
- images (`ImageInput`):
169
- Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
170
- passing in images with pixel values between 0 and 1, set `do_rescale=False`.
171
- return_tensors (`str` or `TensorType`, *optional*):
172
- The type of tensors to return. Can be one of:
173
- - Unset: Return a list of `np.ndarray`.
174
- - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
175
- - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
176
- - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
177
- - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
178
- """
179
- images = make_list_of_images(images)
180
-
181
- if not valid_images(images):
182
- raise ValueError(
183
- "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
184
- "torch.Tensor, tf.Tensor or jax.ndarray."
185
- )
186
-
187
- # Basic settings.
188
- img_processor = torchvision.transforms.Compose([
189
- torchvision.transforms.ToTensor(),
190
- torchvision.transforms.Normalize(
191
- (0.5, 0.5, 0.5),
192
- (0.5, 0.5, 0.5)
193
- ),
194
- ])
195
- dyhd_base_resolution = 448
196
-
197
- # Dynamic HD
198
- base_resolution = dyhd_base_resolution
199
- images = [image.convert('RGB') for image in images]
200
- # cover 384 and 448 resolution
201
- mask_resolution = base_resolution // 14
202
- elems, image_attention_masks = [], []
203
- for im in images:
204
- elem, attention_mask = self.dynamic_preprocess(im, max_num=self.dynamic_hd, image_size=base_resolution, mask_size=mask_resolution)
205
- elems.append(elem)
206
- image_attention_masks.append(attention_mask)
207
- hd_images = [img_processor(im) for im in elems]
208
- global_image = [torch.nn.functional.interpolate(im.unsqueeze(0).float(), size=(base_resolution, base_resolution), mode='bicubic',).to(im.dtype) for im in hd_images]
209
- shapes = [[im.size(1), im.size(2)] for im in hd_images]
210
- mask_shapes = [[mask.size(0), mask.size(1)] for mask in image_attention_masks]
211
- global_attention_mask = [torch.ones((1, mask_resolution, mask_resolution)) for _ in hd_images]
212
- hd_images_reshape = [im.reshape(1, 3,
213
- h//base_resolution,
214
- base_resolution,
215
- w//base_resolution,
216
- base_resolution
217
- ).permute(0,2,4,1,3,5).reshape(-1, 3, base_resolution, base_resolution).contiguous() for im, (h, w) in zip(hd_images, shapes)]
218
- attention_masks_reshape = [mask.reshape(1,
219
- h//mask_resolution,
220
- mask_resolution,
221
- w//mask_resolution,
222
- mask_resolution
223
- ).permute(0,1,3,2,4).reshape(-1, mask_resolution, mask_resolution).contiguous() for mask, (h, w) in zip(image_attention_masks, mask_shapes)]
224
- downsample_attention_masks = [mask[:,0::2,0::2].reshape(1,
225
- h//mask_resolution,
226
- w//mask_resolution,
227
- mask_resolution//2+mask_resolution%2,
228
- mask_resolution//2+mask_resolution%2
229
- ).permute(0,1,3,2,4) for mask, (h,w) in zip(attention_masks_reshape, mask_shapes)]
230
- downsample_attention_masks = [mask.reshape(mask.size(1)*mask.size(2), mask.size(3)*mask.size(4))for mask in downsample_attention_masks]
231
- num_img_tokens = [256 + 1 + int(mask.sum().item()) + int(mask[:,0].sum().item()) + 16 for mask in downsample_attention_masks]
232
-
233
- hd_images_reshape = [torch.cat([_global_image] + [_im], dim=0) for _global_image, _im in zip(global_image, hd_images_reshape)]
234
- hd_masks_reshape = [torch.cat([_global_mask] + [_mask], dim=0) for _global_mask, _mask in zip(global_attention_mask, attention_masks_reshape)]
235
- max_crops = max([img.size(0) for img in hd_images_reshape])
236
- image_transformed = [self.pad_to_max_num_crops(im, max_crops) for im in hd_images_reshape]
237
- image_transformed = torch.stack(image_transformed, dim=0)
238
- mask_transformed = [self.pad_mask_to_max_num_crops(mask, max_crops) for mask in hd_masks_reshape]
239
- mask_transformed = torch.stack(mask_transformed, dim=0)
240
-
241
- returned_input_image_embeds = image_transformed
242
- returned_image_sizes = torch.tensor(shapes, dtype=torch.long)
243
- returned_image_attention_mask = mask_transformed
244
- returned_num_img_tokens = num_img_tokens
245
-
246
- data = {
247
- "input_image_embeds": returned_input_image_embeds,
248
- "image_sizes": returned_image_sizes,
249
- "image_attention_mask": returned_image_attention_mask,
250
- "num_img_tokens": returned_num_img_tokens,
251
- }
252
-
253
- return BatchFeature(data=data, tensor_type=return_tensors)
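A short usage sketch for the image preprocessor above; the `dynamic_hd` value and the image path are hypothetical, and the import assumes the deleted `processing_phi4mm.py` is available locally:

```python
from PIL import Image
from processing_phi4mm import Phi4MMImageProcessor  # assumes the deleted file is importable

image_processor = Phi4MMImageProcessor(dynamic_hd=16)  # hypothetical max number of HD crops
image = Image.open("example.jpg")                      # placeholder path

features = image_processor.preprocess([image], return_tensors="pt")
# Returned keys mirror model_input_names plus the per-image token counts:
# input_image_embeds, image_sizes, image_attention_mask, num_img_tokens
print(features["input_image_embeds"].shape, features["num_img_tokens"])
```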
254
-
255
-
256
- AudioInput = Tuple[Union[np.ndarray, torch.Tensor], int]
257
- AudioInputs = List[AudioInput]
258
-
259
-
260
- def speechlib_mel(sample_rate, n_fft, n_mels, fmin=None, fmax=None):
261
- """Create a Mel filter-bank the same as SpeechLib FbankFC.
262
-
263
- Args:
264
- sample_rate (int): Sample rate in Hz. number > 0 [scalar]
265
- n_fft (int): FFT size. int > 0 [scalar]
266
- n_mels (int): Mel filter size. int > 0 [scalar]
267
- fmin (float): lowest frequency (in Hz). If None use 0.0.
268
- float >= 0 [scalar]
269
- fmax: highest frequency (in Hz). If None use sample_rate / 2.
270
- float >= 0 [scalar]
271
-
272
- Returns
273
- out (numpy.ndarray): Mel transform matrix
274
- [shape=(n_mels, 1 + n_fft/2)]
275
- """
276
-
277
- bank_width = int(n_fft // 2 + 1)
278
- if fmax is None:
279
- fmax = sample_rate / 2
280
- if fmin is None:
281
- fmin = 0
282
- assert fmin >= 0, "fmin cannot be negative"
283
- assert fmin < fmax <= sample_rate / 2, "fmax must be between (fmin, samplerate / 2]"
284
-
285
- def mel(f):
286
- return 1127.0 * np.log(1.0 + f / 700.0)
287
-
288
- def bin2mel(fft_bin):
289
- return 1127.0 * np.log(1.0 + fft_bin * sample_rate / (n_fft * 700.0))
290
-
291
- def f2bin(f):
292
- return int((f * n_fft / sample_rate) + 0.5)
293
-
294
- # Spec 1: FFT bin range [f2bin(fmin) + 1, f2bin(fmax) - 1]
295
- klo = f2bin(fmin) + 1
296
- khi = f2bin(fmax)
297
-
298
- khi = max(khi, klo)
299
-
300
- # Spec 2: SpeechLib uses triangles in Mel space
301
- mlo = mel(fmin)
302
- mhi = mel(fmax)
303
- m_centers = np.linspace(mlo, mhi, n_mels + 2)
304
- ms = (mhi - mlo) / (n_mels + 1)
305
-
306
- matrix = np.zeros((n_mels, bank_width), dtype=np.float32)
307
- for m in range(0, n_mels):
308
- left = m_centers[m]
309
- center = m_centers[m + 1]
310
- right = m_centers[m + 2]
311
- for fft_bin in range(klo, khi):
312
- mbin = bin2mel(fft_bin)
313
- if left < mbin < right:
314
- matrix[m, fft_bin] = 1.0 - abs(center - mbin) / ms
315
-
316
- return matrix
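The audio feature extractor below builds this filter bank as `speechlib_mel(16000, 512, 80, fmin=None, fmax=7690).T`; a quick shape sanity check, assuming the deleted module is importable:

```python
from processing_phi4mm import speechlib_mel  # assumes the deleted file is importable

# 16 kHz audio, 512-point FFT, 80 mel bins, capped at 7690 Hz, matching how
# Phi4MMAudioFeatureExtractor.__init__ builds it below.
mel = speechlib_mel(16000, 512, 80, fmin=None, fmax=7690)
print(mel.shape)  # (80, 257): n_mels x (1 + n_fft // 2)
```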
317
-
318
-
319
- class Phi4MMAudioFeatureExtractor(SequenceFeatureExtractor):
320
- model_input_names = ["input_audio_embeds", "audio_embed_sizes", "audio_attention_mask"]
321
-
322
- def __init__(self, audio_compression_rate, audio_downsample_rate, audio_feat_stride, **kwargs):
323
- feature_size = 80
324
- sampling_rate = 16000
325
- padding_value = 0.0
326
- super().__init__(feature_size, sampling_rate, padding_value, **kwargs)
327
-
328
- self.compression_rate = audio_compression_rate
329
- self.qformer_compression_rate = audio_downsample_rate
330
- self.feat_stride = audio_feat_stride
331
-
332
- self._eightk_method = "fillzero"
333
- self._mel = speechlib_mel(16000, 512, 80, fmin=None, fmax=7690).T
334
-
335
- self._hamming400 = np.hamming(400) # for 16k audio
336
- self._hamming200 = np.hamming(200) # for 8k audio
337
-
338
- def duration_to_frames(self, duration):
339
- """duration in s, estimated frames"""
340
- frame_rate = 10
341
-
342
- num_frames = duration * 1000 // frame_rate
343
- return num_frames
344
-
345
- def __call__(
346
- self,
347
- audios: List[AudioInput],
348
- return_tensors: Optional[Union[str, TensorType]] = None,
349
- ):
350
- # Ref: https://github.com/huggingface/transformers/blob/v4.47.0/src/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py#L161
351
- returned_input_audio_embeds = []
352
- returned_audio_embed_sizes = []
353
- audio_frames_list = []
354
-
355
- for audio_data, sample_rate in audios:
356
- audio_embeds = self._extract_features(audio_data, sample_rate)
357
- audio_frames = len(audio_embeds) * self.feat_stride
358
- audio_embed_size = self._compute_audio_embed_size(audio_frames)
359
-
360
- returned_input_audio_embeds.append(torch.tensor(audio_embeds))
361
- returned_audio_embed_sizes.append(torch.tensor(audio_embed_size).long())
362
- audio_frames_list.append(audio_frames)
363
-
364
- returned_input_audio_embeds = pad_sequence(
365
- returned_input_audio_embeds, batch_first=True
366
- )
367
- returned_audio_embed_sizes = torch.stack(returned_audio_embed_sizes, dim=0)
368
- audio_frames = torch.tensor(audio_frames_list)
369
- returned_audio_attention_mask = torch.arange(0, audio_frames.max()).unsqueeze(0) < audio_frames.unsqueeze(1) if len(audios) > 1 else None
370
-
371
- data = {
372
- "input_audio_embeds": returned_input_audio_embeds,
373
- "audio_embed_sizes": returned_audio_embed_sizes,
374
- }
375
- if returned_audio_attention_mask is not None:
376
- data["audio_attention_mask"] = returned_audio_attention_mask
377
-
378
- return BatchFeature(data=data, tensor_type=return_tensors)
379
-
380
- def _extract_spectrogram(self, wav, fs):
381
- """Extract spectrogram features from waveform.
382
- Args:
383
- wav (1D array): waveform of the input
384
- fs (int): sampling rate of the waveform, 16000 or 8000.
385
- If fs=8000, the waveform will be resampled to 16000Hz.
386
- Output:
387
- log_fbank (2D array): a TxD matrix of log Mel filterbank features.
388
- D=80, and T is the number of frames.
389
- """
390
- if wav.ndim > 1:
391
- wav = np.squeeze(wav)
392
-
393
- # by default, we extract the mean if stereo
394
- if len(wav.shape) == 2:
395
- wav = wav.mean(1)
396
-
397
- # Resample to 16000 or 8000 if needed
398
- if fs > 16000:
399
- wav = scipy.signal.resample_poly(wav, 1, fs // 16000)
400
- fs = 16000
401
- elif 8000 < fs < 16000:
402
- wav = scipy.signal.resample_poly(wav, 1, fs // 8000)
403
- fs = 8000
404
- elif fs < 8000:
405
- raise RuntimeError(f"Unsupported sample rate {fs}")
406
-
407
- if fs == 8000:
408
- if self._eightk_method == "resample":
409
- # Input audio is 8 kHz. Convert to 16 kHz before feature
410
- # extraction
411
- wav = scipy.signal.resample_poly(wav, 2, 1)
412
- fs = 16000
413
- # Do nothing here for fillzero method
414
- elif fs != 16000:
415
- # Input audio is not a supported sample rate.
416
- raise RuntimeError(f"Input data using an unsupported sample rate: {fs}")
417
-
418
- preemphasis = 0.97
419
-
420
- if fs == 8000:
421
- n_fft = 256
422
- win_length = 200
423
- hop_length = 80
424
- fft_window = self._hamming200
425
- elif fs == 16000:
426
- n_fft = 512
427
- win_length = 400
428
- hop_length = 160
429
- fft_window = self._hamming400
430
-
431
- # Spec 1: SpeechLib cut remaining sample insufficient for a hop
432
- n_batch = (wav.shape[0] - win_length) // hop_length + 1
433
- # Here we don't use stride_tricks since the input array may not satisfy
434
- # memory layout requirement and we need writeable output
435
- # Here we only use a list of views before copying to the destination
436
- # so it is more efficient than broadcasting
437
- y_frames = np.array(
438
- [wav[_stride : _stride + win_length] for _stride in range(0, hop_length * n_batch, hop_length)],
439
- dtype=np.float32,
440
- )
441
-
442
- # Spec 2: SpeechLib applies preemphasis within each batch
443
- y_frames_prev = np.roll(y_frames, 1, axis=1)
444
- y_frames_prev[:, 0] = y_frames_prev[:, 1]
445
- y_frames = (y_frames - preemphasis * y_frames_prev) * 32768
446
-
447
- S = np.fft.rfft(fft_window * y_frames, n=n_fft, axis=1).astype(np.complex64)
448
-
449
- if fs == 8000:
450
- # Need to pad the output to look like 16 kHz data but with zeros in
451
- # the 4 to 8 kHz bins.
452
- frames, bins = S.shape
453
- padarray = np.zeros((frames, bins))
454
- S = np.concatenate((S[:, 0:-1], padarray), axis=1) # Nyquist bin gets set to zero
455
-
456
- spec = np.abs(S).astype(np.float32)
457
- return spec
458
-
459
- def _extract_features(self, wav, fs):
460
- """Extract log filterbank features from waveform.
461
- Args:
462
- wav (1D array): waveform of the input
463
- fs (int): sampling rate of the waveform, 16000 or 8000.
464
- If fs=8000, the waveform will be resampled to 16000Hz.
465
- Output:
466
- log_fbank (2D array): a TxD matrix of log Mel filterbank features.
467
- D=80, and T is the number of frames.
468
- """
469
- spec = self._extract_spectrogram(wav, fs)
470
- spec_power = spec**2
471
-
472
- fbank_power = np.clip(spec_power.dot(self._mel), 1.0, None)
473
- log_fbank = np.log(fbank_power).astype(np.float32)
474
-
475
- return log_fbank
476
-
477
- def _compute_audio_embed_size(self, audio_frames):
478
- integer = audio_frames // self.compression_rate
479
- remainder = audio_frames % self.compression_rate
480
-
481
- result = integer if remainder == 0 else integer + 1
482
-
483
- integer = result // self.qformer_compression_rate
484
- remainder = result % self.qformer_compression_rate
485
- result = integer if remainder == 0 else integer + 1 # qformer compression
486
-
487
- return result
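The two rounding passes above are just a chained ceiling division; a standalone sketch with hypothetical rates (the real values arrive as `audio_compression_rate` and `audio_downsample_rate` in the constructor):

```python
import math

# Hypothetical rates for illustration only.
compression_rate, qformer_compression_rate = 8, 1
audio_frames = 1000

# Equivalent to the integer/remainder logic in _compute_audio_embed_size above.
embed_size = math.ceil(math.ceil(audio_frames / compression_rate) / qformer_compression_rate)
print(embed_size)  # 125
```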
488
-
489
-
490
- class Phi4MMProcessor(ProcessorMixin):
491
- r"""
492
- Constructs a Phi4MM processor which wraps an image processor, an audio processor, and a GPT tokenizer into a single processor.
493
-
494
- [`Phi4MMProcessor`] offers all the functionalities of [`Phi4MMImageProcessor`] and [`GPT2Tokenizer`]. See the
495
- [`~Phi4MMProcessor.__call__`] and [`~Phi4MMProcessor.decode`] for more information.
496
-
497
- Args:
498
- image_processor ([`Phi4MMImageProcessor`], *optional*):
499
- The image processor is a required input.
500
- tokenizer ([`GPT2Tokenizer`], *optional*):
501
- The tokenizer is a required input.
502
- """
503
-
504
- attributes = ["image_processor", "audio_processor", "tokenizer"]
505
- tokenizer_class = "GPT2TokenizerFast"
506
- image_processor_class = "AutoImageProcessor" # Phi4MMImageProcessor will be registered later
507
- audio_processor_class = "AutoFeatureExtractor" # Phi4MMAudioFeatureExtractor will be registered later
508
-
509
- def __init__(self, image_processor, audio_processor, tokenizer):
510
- self.image_processor = image_processor
511
- self.audio_processor = audio_processor
512
- self.tokenizer = tokenizer
513
-
514
- def __call__(
515
- self,
516
- text: Union[TextInput, List[TextInput]],
517
- images: Optional[ImageInput] = None,
518
- audios: Optional[AudioInputs] = None,
519
- padding: Union[bool, str, PaddingStrategy] = False,
520
- truncation: Optional[Union[bool, str, TruncationStrategy]] = None,
521
- max_length=None,
522
- return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
523
- ) -> BatchFeature:
524
- """
525
- Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
526
- and `kwargs` arguments to GPT2Tokenizer's [`~GPT2Tokenizer.__call__`] if `text` is not `None` to encode
527
- the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
528
- Phi4MMImageProcessor's [`~Phi4MMImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
529
- of the above two methods for more information.
530
-
531
- Args:
532
- text (`str`, `List[str]`, `List[List[str]]`):
533
- The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
534
- (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
535
- `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
536
- images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
537
- The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
538
- tensor. Both channels-first and channels-last formats are supported.
539
- padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
540
- Select a strategy to pad the returned sequences (according to the model's padding side and padding
541
- index) among:
542
- - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
543
- sequence is provided).
544
- - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
545
- acceptable input length for the model if that argument is not provided.
546
- - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
547
- lengths).
548
- max_length (`int`, *optional*):
549
- Maximum length of the returned list and optionally padding length (see above).
550
- truncation (`bool`, *optional*):
551
- Activates truncation to cut input sequences longer than `max_length` to `max_length`.
552
- return_tensors (`str` or [`~utils.TensorType`], *optional*):
553
- If set, will return tensors of a particular framework. Acceptable values are:
554
-
555
- - `'tf'`: Return TensorFlow `tf.constant` objects.
556
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
557
- - `'np'`: Return NumPy `np.ndarray` objects.
558
- - `'jax'`: Return JAX `jnp.ndarray` objects.
559
-
560
- Returns:
561
- [`BatchFeature`]: A [`BatchFeature`] with the following fields:
562
-
563
- - **input_ids** -- List of token ids to be fed to a model.
564
- - **input_image_embeds** -- Pixel values to be fed to a model.
565
- - **image_sizes** -- List of tuples specifying the size of each image in `input_image_embeds`.
566
- - **image_attention_mask** -- List of attention masks for each image in `input_image_embeds`.
567
- - **input_audio_embeds** -- Audio embeddings to be fed to a model.
568
- - **audio_embed_sizes** -- List of integers specifying the size of each audio in `input_audio_embeds`.
569
- - **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
570
- """
571
- image_inputs = self.image_processor(images, return_tensors=return_tensors) if images is not None else {}
572
- audio_inputs = self.audio_processor(audios, return_tensors=return_tensors) if audios is not None else {}
573
- inputs = self._convert_images_audios_text_to_inputs(
574
- image_inputs,
575
- audio_inputs,
576
- text,
577
- padding=padding,
578
- truncation=truncation,
579
- max_length=max_length,
580
- return_tensors=return_tensors,
581
- )
582
-
583
- # identify the input mode
584
- if len(image_inputs) > 0 and len(audio_inputs) > 0:
585
- input_mode = InputMode.VISION_SPEECH
586
- elif len(image_inputs) > 0:
587
- input_mode = InputMode.VISION
588
- elif len(audio_inputs) > 0:
589
- input_mode = InputMode.SPEECH
590
- else:
591
- input_mode = InputMode.LANGUAGE
592
- inputs["input_mode"] = torch.tensor([input_mode.value], dtype=torch.long)
593
-
594
- return inputs
595
-
596
- @property
597
- def special_image_token_id(self):
598
- return self.tokenizer.convert_tokens_to_ids(self.special_image_token)
599
-
600
- def get_special_image_token_id(self):
601
- return self.tokenizer.convert_tokens_to_ids(self.special_image_token)
602
-
603
- @property
604
- def chat_template(self):
605
- return self.tokenizer.chat_template
606
-
607
- def _convert_images_audios_text_to_inputs(
608
- self, images, audios, text, padding=False, truncation=None, max_length=None, return_tensors=None
609
- ):
610
- # prepare image id to image input ids
611
- if len(images) > 0:
612
- input_image_embeds = images["input_image_embeds"]
613
- image_sizes = images["image_sizes"]
614
- image_attention_mask = images["image_attention_mask"]
615
- num_img_tokens = images['num_img_tokens']
616
- else:
617
- input_image_embeds = torch.tensor([])
618
- image_sizes = torch.tensor([])
619
- image_attention_mask = torch.tensor([])
620
- num_img_tokens = []
621
-
622
- # prepare audio id to audio input ids
623
- if len(audios) > 0:
624
- input_audio_embeds = audios["input_audio_embeds"]
625
- audio_embed_sizes = audios["audio_embed_sizes"]
626
- audio_attention_mask = audios.get("audio_attention_mask", None)
627
- else:
628
- input_audio_embeds = torch.tensor([])
629
- audio_embed_sizes = torch.tensor([])
630
- audio_attention_mask = None
631
-
632
- # Replace certain special tokens for compatibility
633
- # Ref: https://stackoverflow.com/questions/11475885/python-replace-regex
634
- if isinstance(text, str):
635
- text = [text]
636
- assert isinstance(text, list)
637
- processed_text = [re.sub(_COMPATIBLE_IMAGE_SPECIAL_TOKEN_PATTERN, _IMAGE_SPECIAL_TOKEN, t) for t in text]
638
- processed_text = [re.sub(_COMPATIBLE_AUDIO_SPECIAL_TOKEN_PATTERN, _AUDIO_SPECIAL_TOKEN, t) for t in processed_text]
639
-
640
- input_ids_list = [self.tokenizer(t).input_ids for t in processed_text]
641
-
642
- img_cnt, audio_cnt = 0, 0 # only needed for later assertion
643
- image_token_count_iter = iter(num_img_tokens)
644
- audio_embed_size_iter = iter(audio_embed_sizes.tolist())
645
- new_input_ids_list = []
646
- for input_ids in input_ids_list:
647
- i = 0
648
- while i < len(input_ids):
649
- token_id = input_ids[i]
650
- if token_id == _AUDIO_SPECIAL_TOKEN_ID:
651
- token_count = next(audio_embed_size_iter)
652
- audio_cnt += 1
653
- elif token_id == _IMAGE_SPECIAL_TOKEN_ID:
654
- token_count = next(image_token_count_iter)
655
- img_cnt += 1
656
- else:
657
- i += 1
658
- continue
659
- tokens = [token_id] * token_count
660
- input_ids = input_ids[:i] + tokens + input_ids[i + 1:]
661
- i += token_count
662
- input_ids = torch.tensor(input_ids, dtype=torch.long)
663
- new_input_ids_list.append(input_ids)
664
- lengths = torch.tensor([len(input_ids) for input_ids in new_input_ids_list])
665
- max_len = lengths.max()
666
- input_ids = input_ids.new_full((len(new_input_ids_list), max_len), self.tokenizer.pad_token_id)
667
- # batched inference requires left padding
668
- for i in range(len(new_input_ids_list)):
669
- input_ids[i, max_len - len(new_input_ids_list[i]):] = new_input_ids_list[i]
670
-
671
- # If the below assertion fails, it might be that input pure-text
672
- # messages contain image/audio special tokens literally
673
- # (<|endoftext10|>, <|endoftext11|>).
674
- assert (
675
- img_cnt == len(num_img_tokens)
676
- ), (
677
- f"Number of image tokens in prompt_token_ids ({img_cnt}) "
678
- f"does not match number of images ({len(num_img_tokens)})"
679
- )
680
- assert (
681
- audio_cnt == len(audio_embed_sizes)
682
- ), (
683
- f"Number of audio tokens in prompt_token_ids ({audio_cnt}) "
684
- f"does not match number of audios ({len(audio_embed_sizes)})"
685
- )
686
-
687
- # prepare attention mask
688
- seq_range = torch.arange(max_len - 1, -1, -1)
689
- attention_mask = seq_range.unsqueeze(0) < lengths.unsqueeze(1)
690
-
691
- # prepare batch feature
692
- data = {
693
- "input_ids": input_ids,
694
- "input_image_embeds": input_image_embeds,
695
- "image_sizes": image_sizes,
696
- "image_attention_mask": image_attention_mask,
697
- "input_audio_embeds": input_audio_embeds,
698
- "audio_embed_sizes": audio_embed_sizes,
699
- "audio_attention_mask": audio_attention_mask,
700
- "attention_mask": attention_mask,
701
- }
702
-
703
- return BatchFeature(
704
- data=data
705
- )
706
-
707
- # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
708
- def batch_decode(self, *args, **kwargs):
709
- """
710
- This method forwards all its arguments to GPT2Tokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
711
- refer to the docstring of this method for more information.
712
- """
713
- return self.tokenizer.batch_decode(*args, **kwargs)
714
-
715
- # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
716
- def decode(self, *args, **kwargs):
717
- """
718
- This method forwards all its arguments to GPT2Tokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
719
- the docstring of this method for more information.
720
- """
721
- return self.tokenizer.decode(*args, **kwargs)
722
-
723
- @property
724
- # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
725
- def model_input_names(self):
726
- tokenizer_input_names = self.tokenizer.model_input_names
727
- image_processor_input_names = self.image_processor.model_input_names
728
- audio_processor_input_names = self.audio_processor.model_input_names
729
- return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names + audio_processor_input_names))
730
-
731
-
732
- AutoImageProcessor.register("Phi4MMImageProcessor", Phi4MMImageProcessor)
733
- AutoFeatureExtractor.register("Phi4MMAudioFeatureExtractor", Phi4MMAudioFeatureExtractor)
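Until this commit, the registrations above made the custom classes loadable with `trust_remote_code`. A hedged end-to-end sketch (placeholder repo id and file path; the exact chat/prompt format is an assumption, but `<|image_1|>` is the placeholder pattern the processor rewrites above):

```python
from PIL import Image
from transformers import AutoProcessor

# Placeholder repo id -- substitute the checkpoint that actually ships these files.
processor = AutoProcessor.from_pretrained("<repo-id>", trust_remote_code=True)

prompt = "<|user|><|image_1|>What is shown in this image?<|end|><|assistant|>"
image = Image.open("example.jpg")  # placeholder path

# Audio clips would be passed the same way, as audios=[(waveform, sampling_rate)] tuples.
inputs = processor(text=prompt, images=[image], return_tensors="pt")
print(inputs["input_ids"].shape, inputs["input_image_embeds"].shape)
```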
 
speech_conformer_encoder.py DELETED
The diff for this file is too large to render. See raw diff
 
vision_siglip_navit.py DELETED
@@ -1,1717 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """ Siglip model configuration"""
16
-
17
- import os
18
- from typing import Union
19
-
20
- from transformers.configuration_utils import PretrainedConfig
21
- from transformers.utils import logging
22
-
23
-
24
- logger = logging.get_logger(__name__)
25
-
26
- SIGLIP_PRETRAINED_CONFIG_ARCHIVE_MAP = {
27
- "google/siglip-base-patch16-224": "https://huggingface.co/google/siglip-base-patch16-224/resolve/main/config.json",
28
- }
29
-
30
-
31
- class SiglipTextConfig(PretrainedConfig):
32
- r"""
33
- This is the configuration class to store the configuration of a [`SiglipTextModel`]. It is used to instantiate a
34
- Siglip text encoder according to the specified arguments, defining the model architecture. Instantiating a
35
- configuration with the defaults will yield a similar configuration to that of the text encoder of the Siglip
36
- [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
37
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
38
- documentation from [`PretrainedConfig`] for more information.
39
- Args:
40
- vocab_size (`int`, *optional*, defaults to 32000):
41
- Vocabulary size of the Siglip text model. Defines the number of different tokens that can be represented by
42
- the `inputs_ids` passed when calling [`SiglipModel`].
43
- hidden_size (`int`, *optional*, defaults to 768):
44
- Dimensionality of the encoder layers and the pooler layer.
45
- intermediate_size (`int`, *optional*, defaults to 3072):
46
- Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
47
- num_hidden_layers (`int`, *optional*, defaults to 12):
48
- Number of hidden layers in the Transformer encoder.
49
- num_attention_heads (`int`, *optional*, defaults to 12):
50
- Number of attention heads for each attention layer in the Transformer encoder.
51
- max_position_embeddings (`int`, *optional*, defaults to 64):
52
- The maximum sequence length that this model might ever be used with. Typically set this to something large
53
- just in case (e.g., 512 or 1024 or 2048).
54
- hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
55
- The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
56
- `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
57
- layer_norm_eps (`float`, *optional*, defaults to 1e-06):
58
- The epsilon used by the layer normalization layers.
59
- attention_dropout (`float`, *optional*, defaults to 0.0):
60
- The dropout ratio for the attention probabilities.
61
- pad_token_id (`int`, *optional*, defaults to 1):
62
- The id of the padding token in the vocabulary.
63
- bos_token_id (`int`, *optional*, defaults to 49406):
64
- The id of the beginning-of-sequence token in the vocabulary.
65
- eos_token_id (`int`, *optional*, defaults to 49407):
66
- The id of the end-of-sequence token in the vocabulary.
67
- Example:
68
- ```python
69
- >>> from transformers import SiglipTextConfig, SiglipTextModel
70
- >>> # Initializing a SiglipTextConfig with google/siglip-base-patch16-224 style configuration
71
- >>> configuration = SiglipTextConfig()
72
- >>> # Initializing a SiglipTextModel (with random weights) from the google/siglip-base-patch16-224 style configuration
73
- >>> model = SiglipTextModel(configuration)
74
- >>> # Accessing the model configuration
75
- >>> configuration = model.config
76
- ```"""
77
-
78
- model_type = "siglip_text_model"
79
-
80
- def __init__(
81
- self,
82
- vocab_size=32000,
83
- hidden_size=768,
84
- intermediate_size=3072,
85
- num_hidden_layers=12,
86
- num_attention_heads=12,
87
- max_position_embeddings=64,
88
- hidden_act="gelu_pytorch_tanh",
89
- layer_norm_eps=1e-6,
90
- attention_dropout=0.0,
91
- # This differs from `CLIPTokenizer`'s default and from openai/siglip
92
- # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538
93
- pad_token_id=1,
94
- bos_token_id=49406,
95
- eos_token_id=49407,
96
- _flash_attn_2_enabled=True,
97
- **kwargs,
98
- ):
99
- super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
100
-
101
- self.vocab_size = vocab_size
102
- self.hidden_size = hidden_size
103
- self.intermediate_size = intermediate_size
104
- self.num_hidden_layers = num_hidden_layers
105
- self.num_attention_heads = num_attention_heads
106
- self.max_position_embeddings = max_position_embeddings
107
- self.layer_norm_eps = layer_norm_eps
108
- self.hidden_act = hidden_act
109
- self.attention_dropout = attention_dropout
110
- self._flash_attn_2_enabled = _flash_attn_2_enabled
111
-
112
- @classmethod
113
- def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
114
- cls._set_token_in_kwargs(kwargs)
115
-
116
- config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
117
-
118
- # get the text config dict if we are loading from SiglipConfig
119
- if config_dict.get("model_type") == "siglip":
120
- config_dict = config_dict["text_config"]
121
-
122
- if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
123
- logger.warning(
124
- f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
125
- f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
126
- )
127
-
128
- return cls.from_dict(config_dict, **kwargs)
129
-
130
-
131
- class SiglipVisionConfig(PretrainedConfig):
132
- r"""
133
- This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a
134
- Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a
135
- configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip
136
- [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
137
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
138
- documentation from [`PretrainedConfig`] for more information.
139
- Args:
140
- hidden_size (`int`, *optional*, defaults to 768):
141
- Dimensionality of the encoder layers and the pooler layer.
142
- intermediate_size (`int`, *optional*, defaults to 3072):
143
- Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
144
- num_hidden_layers (`int`, *optional*, defaults to 12):
145
- Number of hidden layers in the Transformer encoder.
146
- num_attention_heads (`int`, *optional*, defaults to 12):
147
- Number of attention heads for each attention layer in the Transformer encoder.
148
- num_channels (`int`, *optional*, defaults to 3):
149
- Number of channels in the input images.
150
- image_size (`int`, *optional*, defaults to 224):
151
- The size (resolution) of each image.
152
- patch_size (`int`, *optional*, defaults to 16):
153
- The size (resolution) of each patch.
154
- hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
155
- The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
156
- `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
157
- layer_norm_eps (`float`, *optional*, defaults to 1e-06):
158
- The epsilon used by the layer normalization layers.
159
- attention_dropout (`float`, *optional*, defaults to 0.0):
160
- The dropout ratio for the attention probabilities.
161
- Example:
162
- ```python
163
- >>> from transformers import SiglipVisionConfig, SiglipVisionModel
164
- >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration
165
- >>> configuration = SiglipVisionConfig()
166
- >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration
167
- >>> model = SiglipVisionModel(configuration)
168
- >>> # Accessing the model configuration
169
- >>> configuration = model.config
170
- ```"""
171
-
172
- model_type = "siglip_vision_model"
173
-
174
- def __init__(
175
- self,
176
- hidden_size=768,
177
- intermediate_size=3072,
178
- num_hidden_layers=12,
179
- num_attention_heads=12,
180
- num_channels=3,
181
- image_size=224,
182
- patch_size=16,
183
- hidden_act="gelu_pytorch_tanh",
184
- layer_norm_eps=1e-6,
185
- attention_dropout=0.0,
186
- _flash_attn_2_enabled=True,
187
- **kwargs,
188
- ):
189
- super().__init__(**kwargs)
190
-
191
- self.hidden_size = hidden_size
192
- self.intermediate_size = intermediate_size
193
- self.num_hidden_layers = num_hidden_layers
194
- self.num_attention_heads = num_attention_heads
195
- self.num_channels = num_channels
196
- self.patch_size = patch_size
197
- self.image_size = image_size
198
- self.attention_dropout = attention_dropout
199
- self.layer_norm_eps = layer_norm_eps
200
- self.hidden_act = hidden_act
201
- self._flash_attn_2_enabled = _flash_attn_2_enabled
202
-
203
- @classmethod
204
- def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
205
- cls._set_token_in_kwargs(kwargs)
206
-
207
- config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
208
-
209
- # get the vision config dict if we are loading from SiglipConfig
210
- if config_dict.get("model_type") == "siglip":
211
- config_dict = config_dict["vision_config"]
212
-
213
- if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
214
- logger.warning(
215
- f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
216
- f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
217
- )
218
-
219
- return cls.from_dict(config_dict, **kwargs)
220
-
221
-
222
- class SiglipConfig(PretrainedConfig):
223
- r"""
224
- [`SiglipConfig`] is the configuration class to store the configuration of a [`SiglipModel`]. It is used to
225
- instantiate a Siglip model according to the specified arguments, defining the text model and vision model configs.
226
- Instantiating a configuration with the defaults will yield a similar configuration to that of the Siglip
227
- [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
228
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
229
- documentation from [`PretrainedConfig`] for more information.
230
- Args:
231
- text_config (`dict`, *optional*):
232
- Dictionary of configuration options used to initialize [`SiglipTextConfig`].
233
- vision_config (`dict`, *optional*):
234
- Dictionary of configuration options used to initialize [`SiglipVisionConfig`].
235
- kwargs (*optional*):
236
- Dictionary of keyword arguments.
237
- Example:
238
- ```python
239
- >>> from transformers import SiglipConfig, SiglipModel
240
- >>> # Initializing a SiglipConfig with google/siglip-base-patch16-224 style configuration
241
- >>> configuration = SiglipConfig()
242
- >>> # Initializing a SiglipModel (with random weights) from the google/siglip-base-patch16-224 style configuration
243
- >>> model = SiglipModel(configuration)
244
- >>> # Accessing the model configuration
245
- >>> configuration = model.config
246
- >>> # We can also initialize a SiglipConfig from a SiglipTextConfig and a SiglipVisionConfig
247
- >>> from transformers import SiglipTextConfig, SiglipVisionConfig
248
- >>> # Initializing a SiglipText and SiglipVision configuration
249
- >>> config_text = SiglipTextConfig()
250
- >>> config_vision = SiglipVisionConfig()
251
- >>> config = SiglipConfig.from_text_vision_configs(config_text, config_vision)
252
- ```"""
253
-
254
- model_type = "siglip"
255
-
256
- def __init__(self, text_config=None, vision_config=None, **kwargs):
257
- super().__init__(**kwargs)
258
-
259
- if text_config is None:
260
- text_config = {}
261
- logger.info("`text_config` is `None`. Initializing the `SiglipTextConfig` with default values.")
262
-
263
- if vision_config is None:
264
- vision_config = {}
265
- logger.info("`vision_config` is `None`. initializing the `SiglipVisionConfig` with default values.")
266
-
267
- self.text_config = SiglipTextConfig(**text_config)
268
- self.vision_config = SiglipVisionConfig(**vision_config)
269
-
270
- self.initializer_factor = 1.0
271
-
272
- @classmethod
273
- def from_text_vision_configs(cls, text_config: SiglipTextConfig, vision_config: SiglipVisionConfig, **kwargs):
274
- r"""
275
- Instantiate a [`SiglipConfig`] (or a derived class) from siglip text model configuration and siglip vision
276
- model configuration.
277
- Returns:
278
- [`SiglipConfig`]: An instance of a configuration object
279
- """
280
-
281
- return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
282
-
283
- # coding=utf-8
284
- # Copyright 2024 Google AI and The HuggingFace Team. All rights reserved.
285
- #
286
- # Licensed under the Apache License, Version 2.0 (the "License");
287
- # you may not use this file except in compliance with the License.
288
- # You may obtain a copy of the License at
289
- #
290
- # http://www.apache.org/licenses/LICENSE-2.0
291
- #
292
- # Unless required by applicable law or agreed to in writing, software
293
- # distributed under the License is distributed on an "AS IS" BASIS,
294
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
295
- # See the License for the specific language governing permissions and
296
- # limitations under the License.
297
- """ PyTorch Siglip model."""
298
-
299
-
300
- import math
301
- import warnings
302
- from dataclasses import dataclass
303
- from typing import Any, Optional, Tuple, Union
304
-
305
- import numpy as np
306
- import torch
307
- import torch.nn.functional as F
308
- import torch.utils.checkpoint
309
- from torch import nn
310
- from torch.nn.init import _calculate_fan_in_and_fan_out
311
-
312
- from transformers.activations import ACT2FN
313
- from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
314
- from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
315
- from transformers.modeling_utils import PreTrainedModel
316
- from transformers.utils import (
317
- ModelOutput,
318
- add_start_docstrings,
319
- add_start_docstrings_to_model_forward,
320
- is_flash_attn_2_available,
321
- logging,
322
- replace_return_docstrings,
323
- )
324
-
325
- logger = logging.get_logger(__name__)
326
-
327
- _CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"
328
-
329
- SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
330
- "google/siglip-base-patch16-224",
331
- # See all SigLIP models at https://huggingface.co/models?filter=siglip
332
- ]
333
-
334
- if is_flash_attn_2_available():
335
- from flash_attn import flash_attn_func, flash_attn_varlen_func
336
- from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
337
-
338
-
339
- # Copied from transformers.models.llama.modeling_llama._get_unpad_data
340
- def _get_unpad_data(attention_mask):
341
- seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
342
- indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
343
- max_seqlen_in_batch = seqlens_in_batch.max().item()
344
-     cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
345
- return (
346
- indices,
347
- cu_seqlens,
348
- max_seqlen_in_batch,
349
- )
350
-
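
As a rough sketch of what `_get_unpad_data` computes, the snippet below (illustrative only, using a made-up 2x4 padding mask) runs the same three tensor operations and shows the resulting token indices and cumulative sequence lengths.

```python
# Minimal sketch (not part of the original file): the unpad bookkeeping
# produced by _get_unpad_data for a small, hand-written padding mask.
import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])                      # 1 = real token, 0 = padding
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)   # tensor([3, 2])
indices = torch.nonzero(attention_mask.flatten()).flatten()        # flat positions of real tokens
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
print(indices.tolist(), cu_seqlens.tolist())                       # [0, 1, 2, 4, 5] [0, 3, 5]
```
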
351
-
352
- def _trunc_normal_(tensor, mean, std, a, b):
353
- # Cut & paste from PyTorch official master until it's in a few official releases - RW
354
- # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
355
- def norm_cdf(x):
356
- # Computes standard normal cumulative distribution function
357
- return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
358
-
359
- if (mean < a - 2 * std) or (mean > b + 2 * std):
360
- warnings.warn(
361
- "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
362
- "The distribution of values may be incorrect.",
363
- stacklevel=2,
364
- )
365
-
366
- # Values are generated by using a truncated uniform distribution and
367
- # then using the inverse CDF for the normal distribution.
368
- # Get upper and lower cdf values
369
- l = norm_cdf((a - mean) / std)
370
- u = norm_cdf((b - mean) / std)
371
-
372
- # Uniformly fill tensor with values from [l, u], then translate to
373
- # [2l-1, 2u-1].
374
- tensor.uniform_(2 * l - 1, 2 * u - 1)
375
-
376
- # Use inverse cdf transform for normal distribution to get truncated
377
- # standard normal
378
- if tensor.dtype in [torch.float16, torch.bfloat16]:
379
- # The `erfinv_` op is not (yet?) defined in float16+cpu, bfloat16+gpu
380
- og_dtype = tensor.dtype
381
- tensor = tensor.to(torch.float32)
382
- tensor.erfinv_()
383
- tensor = tensor.to(og_dtype)
384
- else:
385
- tensor.erfinv_()
386
-
387
- # Transform to proper mean, std
388
- tensor.mul_(std * math.sqrt(2.0))
389
- tensor.add_(mean)
390
-
391
- # Clamp to ensure it's in the proper range
392
- if tensor.dtype == torch.float16:
393
- # The `clamp_` op is not (yet?) defined in float16+cpu
394
- tensor = tensor.to(torch.float32)
395
- tensor.clamp_(min=a, max=b)
396
- tensor = tensor.to(torch.float16)
397
- else:
398
- tensor.clamp_(min=a, max=b)
399
-
400
-
401
- def trunc_normal_tf_(
402
- tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
403
- ) -> torch.Tensor:
404
- """Fills the input Tensor with values drawn from a truncated
405
- normal distribution. The values are effectively drawn from the
406
-     normal distribution :math:`\\mathcal{N}(\\text{mean}, \\text{std}^2)`
407
- with values outside :math:`[a, b]` redrawn until they are within
408
- the bounds. The method used for generating the random values works
409
-     best when :math:`a \\leq \\text{mean} \\leq b`.
410
-     NOTE: this 'tf' variant behaves closer to the TensorFlow / JAX implementations, where the
411
- bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
412
-     and the result is subsequently scaled and shifted by the mean and std args.
413
- Args:
414
- tensor: an n-dimensional `torch.Tensor`
415
- mean: the mean of the normal distribution
416
- std: the standard deviation of the normal distribution
417
- a: the minimum cutoff value
418
- b: the maximum cutoff value
419
- """
420
- with torch.no_grad():
421
- _trunc_normal_(tensor, 0, 1.0, a, b)
422
- tensor.mul_(std).add_(mean)
423
-
424
-
425
- def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
426
- fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
427
- if mode == "fan_in":
428
- denom = fan_in
429
- elif mode == "fan_out":
430
- denom = fan_out
431
- elif mode == "fan_avg":
432
- denom = (fan_in + fan_out) / 2
433
-
434
- variance = scale / denom
435
-
436
- if distribution == "truncated_normal":
437
- # constant is stddev of standard normal truncated to (-2, 2)
438
- trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
439
- elif distribution == "normal":
440
- with torch.no_grad():
441
- tensor.normal_(std=math.sqrt(variance))
442
- elif distribution == "uniform":
443
- bound = math.sqrt(3 * variance)
444
- with torch.no_grad():
445
- tensor.uniform_(-bound, bound)
446
- else:
447
- raise ValueError(f"invalid distribution {distribution}")
448
-
449
-
450
- def lecun_normal_(tensor):
451
- variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
452
-
453
-
454
- def default_flax_embed_init(tensor):
455
- variance_scaling_(tensor, mode="fan_in", distribution="normal")
456
-
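
As a rough sketch of the fan-in scaling performed by `variance_scaling_` and `lecun_normal_` above, the snippet below (illustrative only; the 768 -> 3072 layer size is arbitrary) writes the truncated-normal case out explicitly with plain `torch.nn.init` calls.

```python
# Minimal sketch (not part of the original file): lecun_normal_-style fan-in
# variance scaling, spelled out for an arbitrary 768 -> 3072 Linear layer.
import math
import torch
from torch import nn
from torch.nn.init import _calculate_fan_in_and_fan_out

linear = nn.Linear(768, 3072)
fan_in, fan_out = _calculate_fan_in_and_fan_out(linear.weight)   # fan_in = 768
std = math.sqrt(1.0 / fan_in) / 0.87962566103423978              # correct for truncation to (-2, 2)
nn.init.trunc_normal_(linear.weight, mean=0.0, std=std, a=-2 * std, b=2 * std)
nn.init.zeros_(linear.bias)
```
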
457
-
458
- @dataclass
459
- # Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip
460
- class SiglipVisionModelOutput(ModelOutput):
461
- """
462
- Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
463
- Args:
464
-         image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
465
- The image embeddings obtained by applying the projection layer to the pooler_output.
466
- last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
467
- Sequence of hidden-states at the output of the last layer of the model.
468
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
469
- Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
470
- one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
471
- Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
472
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
473
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
474
- sequence_length)`.
475
-             Attention weights after the attention softmax, used to compute the weighted average in the self-attention
476
- heads.
477
- """
478
-
479
- image_embeds: Optional[torch.FloatTensor] = None
480
- last_hidden_state: torch.FloatTensor = None
481
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
482
- attentions: Optional[Tuple[torch.FloatTensor]] = None
483
-
484
-
485
- @dataclass
486
- # Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip
487
- class SiglipTextModelOutput(ModelOutput):
488
- """
489
- Base class for text model's outputs that also contains a pooling of the last hidden states.
490
- Args:
491
-         text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
492
- The text embeddings obtained by applying the projection layer to the pooler_output.
493
- last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
494
- Sequence of hidden-states at the output of the last layer of the model.
495
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
496
- Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
497
- one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
498
- Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
499
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
500
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
501
- sequence_length)`.
502
-             Attention weights after the attention softmax, used to compute the weighted average in the self-attention
503
- heads.
504
- """
505
-
506
- text_embeds: Optional[torch.FloatTensor] = None
507
- last_hidden_state: torch.FloatTensor = None
508
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
509
- attentions: Optional[Tuple[torch.FloatTensor]] = None
510
-
511
-
512
- @dataclass
513
- # Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip
514
- class SiglipOutput(ModelOutput):
515
- """
516
- Args:
517
- loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
518
- Contrastive loss for image-text similarity.
519
- logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
520
- The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
521
- similarity scores.
522
- logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
523
- The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
524
- similarity scores.
525
-         text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
526
- The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`].
527
-         image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
528
- The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`].
529
- text_model_output(`BaseModelOutputWithPooling`):
530
- The output of the [`SiglipTextModel`].
531
- vision_model_output(`BaseModelOutputWithPooling`):
532
- The output of the [`SiglipVisionModel`].
533
- """
534
-
535
- loss: Optional[torch.FloatTensor] = None
536
- logits_per_image: torch.FloatTensor = None
537
- logits_per_text: torch.FloatTensor = None
538
- text_embeds: torch.FloatTensor = None
539
- image_embeds: torch.FloatTensor = None
540
- text_model_output: BaseModelOutputWithPooling = None
541
- vision_model_output: BaseModelOutputWithPooling = None
542
-
543
- def to_tuple(self) -> Tuple[Any]:
544
- return tuple(
545
- self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
546
- for k in self.keys()
547
- )
548
-
549
-
550
- class SiglipVisionEmbeddings(nn.Module):
551
- def __init__(self, config: SiglipVisionConfig):
552
- super().__init__()
553
- self.config = config
554
- self.embed_dim = config.hidden_size
555
- self.image_size = config.image_size
556
- self.patch_size = config.patch_size
557
-
558
- self.patch_embedding = nn.Conv2d(
559
- in_channels=config.num_channels,
560
- out_channels=self.embed_dim,
561
- kernel_size=self.patch_size,
562
- stride=self.patch_size,
563
- padding="valid",
564
- )
565
-
566
- self.num_patches_per_side = self.image_size // self.patch_size
567
- self.num_patches = self.num_patches_per_side**2
568
- self.num_positions = self.num_patches
569
- self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
570
-
571
- def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor:
572
- batch_size = pixel_values.size(0)
573
-
574
- patch_embeds = self.patch_embedding(pixel_values)
575
- embeddings = patch_embeds.flatten(2).transpose(1, 2)
576
-
577
- max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3)
578
- max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
579
- boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side)
580
- position_ids = torch.full(
581
- size=(
582
- batch_size,
583
- max_nb_patches_h * max_nb_patches_w,
584
- ),
585
- fill_value=0,
586
- )
587
-
588
- for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
589
- nb_patches_h = p_attn_mask[:, 0].sum()
590
- nb_patches_w = p_attn_mask[0].sum()
591
-
592
- fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
593
- fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
594
-
595
- bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
596
- bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
597
-
598
- pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten()
599
- position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
600
-
601
- position_ids = position_ids.to(self.position_embedding.weight.device)
602
-
603
- embeddings = embeddings + self.position_embedding(position_ids)
604
- return embeddings
605
-
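
The bucketized position ids built in `SiglipVisionEmbeddings.forward` above can be hard to picture; the snippet below is an illustrative sketch for a hypothetical case where the position table was trained for `num_patches_per_side = 4` but the current image contributes only 3 patches per side.

```python
# Minimal sketch (not part of the original file): how fractional patch
# coordinates are bucketized into position ids for a hypothetical
# num_patches_per_side = 4 with only 3 valid patches along one side.
import torch

num_patches_per_side = 4
boundaries = torch.arange(1 / num_patches_per_side, 1.0, 1 / num_patches_per_side)  # tensor([0.25, 0.50, 0.75])
fractional_coords = torch.arange(0, 1 - 1e-6, 1 / 3)       # tensor([0.0000, 0.3333, 0.6667])
bucket_coords = torch.bucketize(fractional_coords, boundaries, right=True)
print(bucket_coords.tolist())                              # [0, 1, 2] -> rows/cols of the 4x4 position table
```
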
606
-
607
- # Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Siglip
608
- class SiglipTextEmbeddings(nn.Module):
609
- def __init__(self, config: SiglipTextConfig):
610
- super().__init__()
611
- embed_dim = config.hidden_size
612
-
613
- self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
614
- self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
615
-
616
- # position_ids (1, len position emb) is contiguous in memory and exported when serialized
617
- self.register_buffer(
618
- "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
619
- )
620
-
621
- def forward(
622
- self,
623
- input_ids: Optional[torch.LongTensor] = None,
624
- position_ids: Optional[torch.LongTensor] = None,
625
- inputs_embeds: Optional[torch.FloatTensor] = None,
626
- ) -> torch.Tensor:
627
- seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
628
-
629
- if position_ids is None:
630
- position_ids = self.position_ids[:, :seq_length]
631
-
632
- if inputs_embeds is None:
633
- inputs_embeds = self.token_embedding(input_ids)
634
-
635
- position_embeddings = self.position_embedding(position_ids)
636
- embeddings = inputs_embeds + position_embeddings
637
-
638
- return embeddings
639
-
640
-
641
- class SiglipAttention(nn.Module):
642
- """Multi-headed attention from 'Attention Is All You Need' paper"""
643
-
644
- # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
645
- def __init__(self, config):
646
- super().__init__()
647
- self.config = config
648
- self.embed_dim = config.hidden_size
649
- self.num_heads = config.num_attention_heads
650
- self.head_dim = self.embed_dim // self.num_heads
651
- if self.head_dim * self.num_heads != self.embed_dim:
652
- raise ValueError(
653
- f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
654
- f" {self.num_heads})."
655
- )
656
- self.scale = self.head_dim**-0.5
657
- self.dropout = config.attention_dropout
658
-
659
- self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
660
- self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
661
- self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
662
- self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
663
-
664
- def forward(
665
- self,
666
- hidden_states: torch.Tensor,
667
- attention_mask: Optional[torch.Tensor] = None,
668
- output_attentions: Optional[bool] = False,
669
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
670
- """Input shape: Batch x Time x Channel"""
671
-
672
- batch_size, q_len, _ = hidden_states.size()
673
-
674
- query_states = self.q_proj(hidden_states)
675
- key_states = self.k_proj(hidden_states)
676
- value_states = self.v_proj(hidden_states)
677
-
678
- query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
679
- key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
680
- value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
681
-
682
- k_v_seq_len = key_states.shape[-2]
683
- attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
684
-
685
- if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
686
- raise ValueError(
687
- f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
688
- f" {attn_weights.size()}"
689
- )
690
-
691
- if attention_mask is not None:
692
- if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
693
- raise ValueError(
694
- f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
695
- )
696
- attn_weights = attn_weights + attention_mask
697
-
698
- # upcast attention to fp32
699
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
700
- attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
701
- attn_output = torch.matmul(attn_weights, value_states)
702
-
703
- if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
704
- raise ValueError(
705
- f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
706
- f" {attn_output.size()}"
707
- )
708
-
709
- attn_output = attn_output.transpose(1, 2).contiguous()
710
- attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
711
-
712
- attn_output = self.out_proj(attn_output)
713
-
714
- return attn_output, attn_weights
715
-
716
-
717
- class SiglipFlashAttention2(SiglipAttention):
718
- """
719
-     SigLIP flash attention module. This module inherits from `SiglipAttention`, as the weights of the module stay
720
- untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
721
- flash attention and deal with padding tokens in case the input contains any of them.
722
- """
723
-
724
- def __init__(self, *args, **kwargs):
725
- super().__init__(*args, **kwargs)
726
- self.is_causal = False # Hack to make sure we don't use a causal mask
727
-
728
- def forward(
729
- self,
730
- hidden_states: torch.Tensor,
731
- attention_mask: Optional[torch.LongTensor] = None,
732
- position_ids: Optional[torch.LongTensor] = None,
733
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
734
- output_attentions: bool = False,
735
- use_cache: bool = False,
736
- **kwargs,
737
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
738
- output_attentions = False
739
-
740
- bsz, q_len, _ = hidden_states.size()
741
-
742
- query_states = self.q_proj(hidden_states)
743
- key_states = self.k_proj(hidden_states)
744
- value_states = self.v_proj(hidden_states)
745
-
746
- # Flash attention requires the input to have the shape
747
-         # batch_size x seq_length x num_heads x head_dim
748
- # therefore we just need to keep the original shape
749
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
750
- key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
751
- value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
752
-
753
- kv_seq_len = key_states.shape[-2]
754
- if past_key_value is not None:
755
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
756
- # cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
757
- # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
758
-
759
- # if past_key_value is not None:
760
- # cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
761
- # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
762
-
763
-         # TODO: These transposes are quite inefficient, but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
764
- # to be able to avoid many of these transpose/reshape/view.
765
- query_states = query_states.transpose(1, 2)
766
- key_states = key_states.transpose(1, 2)
767
- value_states = value_states.transpose(1, 2)
768
-
769
- dropout_rate = self.dropout if self.training else 0.0
770
-
771
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
772
-         # therefore the input hidden states get silently cast to float32. Hence, we need to
773
-         # cast them back to the correct dtype just to be sure everything works as expected.
774
-         # This might slow down training & inference, so it is recommended not to cast the LayerNorms
775
- # in fp32. (LlamaRMSNorm handles it correctly)
776
-
777
- input_dtype = query_states.dtype
778
- if input_dtype == torch.float32:
779
- if torch.is_autocast_enabled():
780
- target_dtype = torch.get_autocast_gpu_dtype()
781
- # Handle the case where the model is quantized
782
- elif hasattr(self.config, "_pre_quantization_dtype"):
783
- target_dtype = self.config._pre_quantization_dtype
784
- else:
785
- target_dtype = self.q_proj.weight.dtype
786
-
787
- logger.warning_once(
788
-                 "The input hidden states seem to be silently cast to float32, this might be related to the fact that"
789
-                 " you have upcast embedding or layer norm layers to float32. We will cast back the input to"
790
- f" {target_dtype}."
791
- )
792
-
793
- query_states = query_states.to(target_dtype)
794
- key_states = key_states.to(target_dtype)
795
- value_states = value_states.to(target_dtype)
796
-
797
- attn_output = self._flash_attention_forward(
798
- query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
799
- )
800
-
801
- attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous()
802
- attn_output = self.out_proj(attn_output)
803
-
804
- if not output_attentions:
805
- attn_weights = None
806
-
807
- return attn_output, attn_weights
808
-
809
- def _flash_attention_forward(
810
- self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
811
- ):
812
- """
813
-         Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
814
-         it first unpads the input, then computes the attention scores and pads the final attention scores.
815
- Args:
816
- query_states (`torch.Tensor`):
817
- Input query states to be passed to Flash Attention API
818
- key_states (`torch.Tensor`):
819
- Input key states to be passed to Flash Attention API
820
- value_states (`torch.Tensor`):
821
- Input value states to be passed to Flash Attention API
822
- attention_mask (`torch.Tensor`):
823
- The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
824
- position of padding tokens and 1 for the position of non-padding tokens.
825
-             dropout (`float`, *optional*):
826
- Attention dropout
827
- softmax_scale (`float`, *optional*):
828
-                 The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
829
- """
830
-
831
- # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
832
- causal = self.is_causal and query_length != 1
833
-
834
- # Contains at least one padding token in the sequence
835
- if attention_mask is not None:
836
- batch_size = query_states.shape[0]
837
- query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
838
- query_states, key_states, value_states, attention_mask, query_length
839
- )
840
-
841
- cu_seqlens_q, cu_seqlens_k = cu_seq_lens
842
- max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
843
-
844
- attn_output_unpad = flash_attn_varlen_func(
845
- query_states,
846
- key_states,
847
- value_states,
848
- cu_seqlens_q=cu_seqlens_q,
849
- cu_seqlens_k=cu_seqlens_k,
850
- max_seqlen_q=max_seqlen_in_batch_q,
851
- max_seqlen_k=max_seqlen_in_batch_k,
852
- dropout_p=dropout,
853
- softmax_scale=softmax_scale,
854
- causal=causal,
855
- )
856
-
857
- attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
858
- else:
859
- attn_output = flash_attn_func(
860
- query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
861
- )
862
-
863
- return attn_output
864
-
865
- def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
866
- indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
867
- batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
868
-
869
- key_layer = index_first_axis(
870
- key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
871
- )
872
- value_layer = index_first_axis(
873
- value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
874
- )
875
- if query_length == kv_seq_len:
876
- query_layer = index_first_axis(
877
- query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
878
- )
879
- cu_seqlens_q = cu_seqlens_k
880
- max_seqlen_in_batch_q = max_seqlen_in_batch_k
881
- indices_q = indices_k
882
- elif query_length == 1:
883
- max_seqlen_in_batch_q = 1
884
- cu_seqlens_q = torch.arange(
885
- batch_size + 1, dtype=torch.int32, device=query_layer.device
886
- ) # There is a memcpy here, that is very bad.
887
- indices_q = cu_seqlens_q[:-1]
888
- query_layer = query_layer.squeeze(1)
889
- else:
890
- # The -q_len: slice assumes left padding.
891
- attention_mask = attention_mask[:, -query_length:]
892
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
893
-
894
- return (
895
- query_layer,
896
- key_layer,
897
- value_layer,
898
- indices_q,
899
- (cu_seqlens_q, cu_seqlens_k),
900
- (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
901
- )
902
-
903
-
904
- # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
905
- class SiglipMLP(nn.Module):
906
- def __init__(self, config):
907
- super().__init__()
908
- self.config = config
909
- self.activation_fn = ACT2FN[config.hidden_act]
910
- self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
911
- self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
912
-
913
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
914
- hidden_states = self.fc1(hidden_states)
915
- hidden_states = self.activation_fn(hidden_states)
916
- hidden_states = self.fc2(hidden_states)
917
- return hidden_states
918
-
919
-
920
- # Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip
921
- class SiglipEncoderLayer(nn.Module):
922
- def __init__(self, config: SiglipConfig):
923
- super().__init__()
924
- self.embed_dim = config.hidden_size
925
- self.self_attn = (
926
- SiglipAttention(config)
927
- if not getattr(config, "_flash_attn_2_enabled", False)
928
- else SiglipFlashAttention2(config)
929
- )
930
- self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
931
- self.mlp = SiglipMLP(config)
932
- self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
933
-
934
- def forward(
935
- self,
936
- hidden_states: torch.Tensor,
937
- attention_mask: torch.Tensor,
938
- output_attentions: Optional[bool] = False,
939
- ) -> Tuple[torch.FloatTensor]:
940
- """
941
- Args:
942
- hidden_states (`torch.FloatTensor`):
943
- Input to the layer of shape `(batch, seq_len, embed_dim)`.
944
- attention_mask (`torch.FloatTensor`):
945
- Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
946
- output_attentions (`bool`, *optional*, defaults to `False`):
947
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
948
- returned tensors for more detail.
949
- """
950
- residual = hidden_states
951
-
952
- hidden_states = self.layer_norm1(hidden_states)
953
- hidden_states, attn_weights = self.self_attn(
954
- hidden_states=hidden_states,
955
- attention_mask=attention_mask,
956
- output_attentions=output_attentions,
957
- )
958
- hidden_states = residual + hidden_states
959
-
960
- residual = hidden_states
961
- hidden_states = self.layer_norm2(hidden_states)
962
- hidden_states = self.mlp(hidden_states)
963
- hidden_states = residual + hidden_states
964
-
965
- outputs = (hidden_states,)
966
-
967
- if output_attentions:
968
- outputs += (attn_weights,)
969
-
970
- return outputs
971
-
972
-
973
- class SiglipPreTrainedModel(PreTrainedModel):
974
- """
975
- An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
976
- models.
977
- """
978
-
979
- config_class = SiglipConfig
980
- base_model_prefix = "siglip"
981
- supports_gradient_checkpointing = True
982
-
983
- def _init_weights(self, module):
984
- """Initialize the weights"""
985
-
986
- if isinstance(module, SiglipVisionEmbeddings):
987
- width = (
988
- self.config.vision_config.hidden_size
989
- if isinstance(self.config, SiglipConfig)
990
- else self.config.hidden_size
991
- )
992
- nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
993
- elif isinstance(module, nn.Embedding):
994
- default_flax_embed_init(module.weight)
995
- elif isinstance(module, SiglipAttention):
996
- nn.init.normal_(module.q_proj.weight)
997
- nn.init.normal_(module.k_proj.weight)
998
- nn.init.normal_(module.v_proj.weight)
999
- nn.init.normal_(module.out_proj.weight)
1000
- nn.init.zeros_(module.q_proj.bias)
1001
- nn.init.zeros_(module.k_proj.bias)
1002
- nn.init.zeros_(module.v_proj.bias)
1003
- nn.init.zeros_(module.out_proj.bias)
1004
- elif isinstance(module, SiglipMLP):
1005
- nn.init.normal_(module.fc1.weight)
1006
- nn.init.normal_(module.fc2.weight)
1007
- nn.init.normal_(module.fc1.bias, std=1e-6)
1008
- nn.init.normal_(module.fc2.bias, std=1e-6)
1009
- elif isinstance(module, SiglipMultiheadAttentionPoolingHead):
1010
- nn.init.normal_(module.probe.data)
1011
- nn.init.normal_(module.attention.in_proj_weight.data)
1012
- nn.init.zeros_(module.attention.in_proj_bias.data)
1013
- elif isinstance(module, SiglipModel):
1014
- logit_scale_init = torch.tensor(0.0)
1015
- module.logit_scale.data.fill_(logit_scale_init)
1016
- module.logit_bias.data.zero_()
1017
- elif isinstance(module, (nn.Linear, nn.Conv2d)):
1018
- lecun_normal_(module.weight)
1019
- if module.bias is not None:
1020
- nn.init.zeros_(module.bias)
1021
- elif isinstance(module, nn.LayerNorm):
1022
- module.bias.data.zero_()
1023
- module.weight.data.fill_(1.0)
1024
-
1025
-
1026
- SIGLIP_START_DOCSTRING = r"""
1027
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1028
-     library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
1029
- etc.)
1030
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1031
-     Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
1032
- and behavior.
1033
- Parameters:
1034
- config ([`SiglipConfig`]): Model configuration class with all the parameters of the model.
1035
- Initializing with a config file does not load the weights associated with the model, only the
1036
- configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1037
- """
1038
-
1039
- SIGLIP_TEXT_INPUTS_DOCSTRING = r"""
1040
- Args:
1041
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1042
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1043
- it.
1044
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1045
- [`PreTrainedTokenizer.__call__`] for details.
1046
- [What are input IDs?](../glossary#input-ids)
1047
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1048
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1049
- - 1 for tokens that are **not masked**,
1050
- - 0 for tokens that are **masked**.
1051
- [What are attention masks?](../glossary#attention-mask)
1052
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1053
-         Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
1054
- config.max_position_embeddings - 1]`.
1055
- [What are position IDs?](../glossary#position-ids)
1056
- output_attentions (`bool`, *optional*):
1057
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1058
- tensors for more detail.
1059
- output_hidden_states (`bool`, *optional*):
1060
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1061
- more detail.
1062
- return_dict (`bool`, *optional*):
1063
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1064
- """
1065
-
1066
- SIGLIP_VISION_INPUTS_DOCSTRING = r"""
1067
- Args:
1068
- pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
1069
- Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
1070
- [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
1071
- output_attentions (`bool`, *optional*):
1072
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1073
- tensors for more detail.
1074
- output_hidden_states (`bool`, *optional*):
1075
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1076
- more detail.
1077
- return_dict (`bool`, *optional*):
1078
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1079
- """
1080
-
1081
- SIGLIP_INPUTS_DOCSTRING = r"""
1082
- Args:
1083
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1084
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1085
- it.
1086
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1087
- [`PreTrainedTokenizer.__call__`] for details.
1088
- [What are input IDs?](../glossary#input-ids)
1089
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1090
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1091
- - 1 for tokens that are **not masked**,
1092
- - 0 for tokens that are **masked**.
1093
- [What are attention masks?](../glossary#attention-mask)
1094
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1095
-         Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
1096
- config.max_position_embeddings - 1]`.
1097
- [What are position IDs?](../glossary#position-ids)
1098
- pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
1099
- Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
1100
- [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
1101
- return_loss (`bool`, *optional*):
1102
- Whether or not to return the contrastive loss.
1103
- output_attentions (`bool`, *optional*):
1104
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1105
- tensors for more detail.
1106
- output_hidden_states (`bool`, *optional*):
1107
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1108
- more detail.
1109
- return_dict (`bool`, *optional*):
1110
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1111
- """
1112
-
1113
-
1114
- # Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Siglip
1115
- class SiglipEncoder(nn.Module):
1116
- """
1117
- Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
1118
- [`SiglipEncoderLayer`].
1119
- Args:
1120
- config: SiglipConfig
1121
- """
1122
-
1123
- def __init__(self, config: SiglipConfig):
1124
- super().__init__()
1125
- self.config = config
1126
- self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
1127
- self.gradient_checkpointing = False
1128
-
1129
- # Ignore copy
1130
- def forward(
1131
- self,
1132
- inputs_embeds,
1133
- attention_mask: Optional[torch.Tensor] = None,
1134
- output_attentions: Optional[bool] = None,
1135
- output_hidden_states: Optional[bool] = None,
1136
- return_dict: Optional[bool] = None,
1137
- ) -> Union[Tuple, BaseModelOutput]:
1138
- r"""
1139
- Args:
1140
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
1141
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
1142
- This is useful if you want more control over how to convert `input_ids` indices into associated vectors
1143
- than the model's internal embedding lookup matrix.
1144
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1145
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1146
- - 1 for tokens that are **not masked**,
1147
- - 0 for tokens that are **masked**.
1148
- [What are attention masks?](../glossary#attention-mask)
1149
- output_attentions (`bool`, *optional*):
1150
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1151
- returned tensors for more detail.
1152
- output_hidden_states (`bool`, *optional*):
1153
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
1154
- for more detail.
1155
- return_dict (`bool`, *optional*):
1156
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1157
- """
1158
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1159
- output_hidden_states = (
1160
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1161
- )
1162
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1163
-
1164
- encoder_states = () if output_hidden_states else None
1165
- all_attentions = () if output_attentions else None
1166
-
1167
- hidden_states = inputs_embeds
1168
- for encoder_layer in self.layers:
1169
- if output_hidden_states:
1170
- encoder_states = encoder_states + (hidden_states,)
1171
- if self.gradient_checkpointing and self.training:
1172
- layer_outputs = self._gradient_checkpointing_func(
1173
- encoder_layer.__call__,
1174
- hidden_states,
1175
- attention_mask,
1176
- output_attentions,
1177
- )
1178
- else:
1179
- layer_outputs = encoder_layer(
1180
- hidden_states,
1181
- attention_mask,
1182
- output_attentions=output_attentions,
1183
- )
1184
-
1185
- hidden_states = layer_outputs[0]
1186
-
1187
- if output_attentions:
1188
- all_attentions = all_attentions + (layer_outputs[1],)
1189
-
1190
- if output_hidden_states:
1191
- encoder_states = encoder_states + (hidden_states,)
1192
-
1193
- if not return_dict:
1194
- return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
1195
- return BaseModelOutput(
1196
- last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
1197
- )
1198
-
1199
-
1200
- class SiglipTextTransformer(nn.Module):
1201
- def __init__(self, config: SiglipTextConfig):
1202
- super().__init__()
1203
- self.config = config
1204
- embed_dim = config.hidden_size
1205
- self.embeddings = SiglipTextEmbeddings(config)
1206
- self.encoder = SiglipEncoder(config)
1207
- self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
1208
-
1209
- self.head = nn.Linear(embed_dim, embed_dim)
1210
-
1211
- @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
1212
- @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig)
1213
- def forward(
1214
- self,
1215
- input_ids: Optional[torch.Tensor] = None,
1216
- attention_mask: Optional[torch.Tensor] = None,
1217
- position_ids: Optional[torch.Tensor] = None,
1218
- output_attentions: Optional[bool] = None,
1219
- output_hidden_states: Optional[bool] = None,
1220
- return_dict: Optional[bool] = None,
1221
- ) -> Union[Tuple, BaseModelOutputWithPooling]:
1222
- r"""
1223
- Returns:
1224
- """
1225
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1226
- output_hidden_states = (
1227
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1228
- )
1229
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1230
-
1231
- if input_ids is None:
1232
- raise ValueError("You have to specify input_ids")
1233
-
1234
- input_shape = input_ids.size()
1235
- input_ids = input_ids.view(-1, input_shape[-1])
1236
-
1237
- hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
1238
-
1239
- # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model.
1240
- # expand attention_mask
1241
- if attention_mask is not None:
1242
- # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
1243
- attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
1244
-
1245
- encoder_outputs = self.encoder(
1246
- inputs_embeds=hidden_states,
1247
- attention_mask=attention_mask,
1248
- output_attentions=output_attentions,
1249
- output_hidden_states=output_hidden_states,
1250
- return_dict=return_dict,
1251
- )
1252
-
1253
- last_hidden_state = encoder_outputs[0]
1254
- last_hidden_state = self.final_layer_norm(last_hidden_state)
1255
-
1256
- # Assuming "sticky" EOS tokenization, last token is always EOS.
1257
- pooled_output = last_hidden_state[:, -1, :]
1258
- pooled_output = self.head(pooled_output)
1259
-
1260
- if not return_dict:
1261
- return (last_hidden_state, pooled_output) + encoder_outputs[1:]
1262
-
1263
- return BaseModelOutputWithPooling(
1264
- last_hidden_state=last_hidden_state,
1265
- pooler_output=pooled_output,
1266
- hidden_states=encoder_outputs.hidden_states,
1267
- attentions=encoder_outputs.attentions,
1268
- )
1269
-
1270
-
1271
- @add_start_docstrings(
1272
- """The text model from SigLIP without any head or projection on top.""",
1273
- SIGLIP_START_DOCSTRING,
1274
- )
1275
- class SiglipTextModel(SiglipPreTrainedModel):
1276
- config_class = SiglipTextConfig
1277
-
1278
- _no_split_modules = ["SiglipTextEmbeddings", "SiglipEncoderLayer"]
1279
-
1280
- def __init__(self, config: SiglipTextConfig):
1281
- super().__init__(config)
1282
- self.text_model = SiglipTextTransformer(config)
1283
- # Initialize weights and apply final processing
1284
- self.post_init()
1285
-
1286
- def get_input_embeddings(self) -> nn.Module:
1287
- return self.text_model.embeddings.token_embedding
1288
-
1289
- def set_input_embeddings(self, value):
1290
- self.text_model.embeddings.token_embedding = value
1291
-
1292
- @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
1293
- @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig)
1294
- def forward(
1295
- self,
1296
- input_ids: Optional[torch.Tensor] = None,
1297
- attention_mask: Optional[torch.Tensor] = None,
1298
- position_ids: Optional[torch.Tensor] = None,
1299
- output_attentions: Optional[bool] = None,
1300
- output_hidden_states: Optional[bool] = None,
1301
- return_dict: Optional[bool] = None,
1302
- ) -> Union[Tuple, BaseModelOutputWithPooling]:
1303
- r"""
1304
- Returns:
1305
- Examples:
1306
- ```python
1307
- >>> from transformers import AutoTokenizer, SiglipTextModel
1308
- >>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224")
1309
- >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
1310
- >>> # important: make sure to set padding="max_length" as that's how the model was trained
1311
- >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
1312
- >>> outputs = model(**inputs)
1313
- >>> last_hidden_state = outputs.last_hidden_state
1314
- >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
1315
- ```"""
1316
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1317
-
1318
- return self.text_model(
1319
- input_ids=input_ids,
1320
- attention_mask=attention_mask,
1321
- position_ids=position_ids,
1322
- output_attentions=output_attentions,
1323
- output_hidden_states=output_hidden_states,
1324
- return_dict=return_dict,
1325
- )
1326
-
1327
-
1328
- class SiglipVisionTransformer(nn.Module):
1329
- def __init__(self, config: SiglipVisionConfig):
1330
- super().__init__()
1331
- self.config = config
1332
- embed_dim = config.hidden_size
1333
-
1334
- self.embeddings = SiglipVisionEmbeddings(config)
1335
- self.encoder = SiglipEncoder(config)
1336
- self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
1337
- self.head = SiglipMultiheadAttentionPoolingHead(config)
1338
-
1339
- @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
1340
- @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig)
1341
- def forward(
1342
- self,
1343
- pixel_values,
1344
- patch_attention_mask: Optional[torch.BoolTensor] = None,
1345
- output_attentions: Optional[bool] = None,
1346
- output_hidden_states: Optional[bool] = None,
1347
- return_dict: Optional[bool] = None,
1348
- ) -> Union[Tuple, BaseModelOutputWithPooling]:
1349
- r"""
1350
- Returns:
1351
- """
1352
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1353
- output_hidden_states = (
1354
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1355
- )
1356
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1357
-
1358
- batch_size = pixel_values.size(0)
1359
- if patch_attention_mask is None:
1360
- patch_attention_mask = torch.ones(
1361
- size=(
1362
- batch_size,
1363
- pixel_values.size(2) // self.config.patch_size,
1364
- pixel_values.size(3) // self.config.patch_size,
1365
- ),
1366
- dtype=torch.bool,
1367
- device=pixel_values.device,
1368
- )
1369
-
1370
- hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
1371
-
1372
- patch_attention_mask = patch_attention_mask.view(batch_size, -1)
1373
- # The call to `_upad_input` in `_flash_attention_forward` is expensive
1374
- # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
1375
-         # we avoid passing the attention_mask, which is equivalent to attending to the full sequence
1376
- if not torch.any(~patch_attention_mask):
1377
-             attention_mask = None
1378
- else:
1379
- attention_mask = (
1380
- _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
1381
- if not self.config._flash_attn_2_enabled
1382
- else patch_attention_mask
1383
- )
1384
-
1385
- encoder_outputs = self.encoder(
1386
- inputs_embeds=hidden_states,
1387
- attention_mask=attention_mask,
1388
- output_attentions=output_attentions,
1389
- output_hidden_states=output_hidden_states,
1390
- return_dict=return_dict,
1391
- )
1392
-
1393
- last_hidden_state = encoder_outputs[0]
1394
- last_hidden_state = self.post_layernorm(last_hidden_state)
1395
-
1396
- pooled_output = self.head(
1397
- hidden_state=last_hidden_state,
1398
- attention_mask=patch_attention_mask,
1399
- )
1400
-
1401
- if not return_dict:
1402
- return (last_hidden_state, pooled_output) + encoder_outputs[1:]
1403
-
1404
- return BaseModelOutputWithPooling(
1405
- last_hidden_state=last_hidden_state,
1406
- pooler_output=pooled_output,
1407
- hidden_states=encoder_outputs.hidden_states,
1408
- attentions=encoder_outputs.attentions,
1409
- )
1410
-
1411
-
1412
- class SiglipMultiheadAttentionPoolingHead(nn.Module):
1413
- """Multihead Attention Pooling."""
1414
-
1415
- def __init__(self, config: SiglipVisionConfig):
1416
- super().__init__()
1417
-
1418
- self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
1419
- self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
1420
- self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
1421
- self.mlp = SiglipMLP(config)
1422
-
1423
- def forward(self, hidden_state, attention_mask):
1424
- batch_size = hidden_state.shape[0]
1425
- probe = self.probe.repeat(batch_size, 1, 1)
1426
-
1427
- hidden_state = self.attention(
1428
- query=probe, key=hidden_state, value=hidden_state, key_padding_mask=~attention_mask
1429
- )[0]
1430
-
1431
- residual = hidden_state
1432
- hidden_state = self.layernorm(hidden_state)
1433
- hidden_state = residual + self.mlp(hidden_state)
1434
-
1435
- return hidden_state[:, 0]
1436
-
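
As a rough sketch of the probe-based pooling in `SiglipMultiheadAttentionPoolingHead`, the snippet below (illustrative only, with made-up sizes) shows the same attention call collapsing a sequence of patch features into one vector per image.

```python
# Minimal sketch (not part of the original file): learned-probe attention
# pooling with made-up sizes, reducing (batch, seq, hidden) -> (batch, hidden).
import torch
from torch import nn

hidden_size, num_heads, batch_size, seq_len = 64, 4, 2, 9
probe = nn.Parameter(torch.randn(1, 1, hidden_size))
attention = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)

hidden_state = torch.randn(batch_size, seq_len, hidden_size)        # patch features
pooled = attention(probe.repeat(batch_size, 1, 1), hidden_state, hidden_state)[0]
print(pooled[:, 0].shape)                                           # torch.Size([2, 64])
```
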
1437
-
1438
- @add_start_docstrings(
1439
- """The vision model from SigLIP without any head or projection on top.""",
1440
- SIGLIP_START_DOCSTRING,
1441
- )
1442
- class SiglipVisionModel(SiglipPreTrainedModel):
1443
- config_class = SiglipVisionConfig
1444
- main_input_name = "pixel_values"
1445
-
1446
- def __init__(self, config: SiglipVisionConfig):
1447
- super().__init__(config)
1448
-
1449
- self.vision_model = SiglipVisionTransformer(config)
1450
-
1451
- # Initialize weights and apply final processing
1452
- self.post_init()
1453
-
1454
- def get_input_embeddings(self) -> nn.Module:
1455
- return self.vision_model.embeddings.patch_embedding
1456
-
1457
- @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
1458
- @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig)
1459
- def forward(
1460
- self,
1461
- pixel_values,
1462
- patch_attention_mask: Optional[torch.BoolTensor] = None,
1463
- output_attentions: Optional[bool] = None,
1464
- output_hidden_states: Optional[bool] = None,
1465
- return_dict: Optional[bool] = None,
1466
- ) -> Union[Tuple, BaseModelOutputWithPooling]:
1467
- r"""
1468
- Returns:
1469
- Examples:
1470
- ```python
1471
- >>> from PIL import Image
1472
- >>> import requests
1473
- >>> from transformers import AutoProcessor, SiglipVisionModel
1474
- >>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224")
1475
- >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
1476
- >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1477
- >>> image = Image.open(requests.get(url, stream=True).raw)
1478
- >>> inputs = processor(images=image, return_tensors="pt")
1479
- >>> outputs = model(**inputs)
1480
- >>> last_hidden_state = outputs.last_hidden_state
1481
- >>> pooled_output = outputs.pooler_output # pooled features
1482
- ```"""
1483
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1484
-
1485
- return self.vision_model(
1486
- pixel_values=pixel_values,
1487
- patch_attention_mask=patch_attention_mask,
1488
- output_attentions=output_attentions,
1489
- output_hidden_states=output_hidden_states,
1490
- return_dict=return_dict,
1491
- )
1492
-
1493
-
1494
- @add_start_docstrings(SIGLIP_START_DOCSTRING)
1495
- class SiglipModel(SiglipPreTrainedModel):
1496
- config_class = SiglipConfig
1497
-
1498
- def __init__(self, config: SiglipConfig):
1499
- super().__init__(config)
1500
-
1501
- if not isinstance(config.text_config, SiglipTextConfig):
1502
- raise ValueError(
1503
- "config.text_config is expected to be of type SiglipTextConfig but is of type"
1504
- f" {type(config.text_config)}."
1505
- )
1506
-
1507
- if not isinstance(config.vision_config, SiglipVisionConfig):
1508
- raise ValueError(
1509
- "config.vision_config is expected to be of type SiglipVisionConfig but is of type"
1510
- f" {type(config.vision_config)}."
1511
- )
1512
-
1513
- text_config = config.text_config
1514
- vision_config = config.vision_config
1515
-
1516
- self.text_model = SiglipTextTransformer(text_config)
1517
- self.vision_model = SiglipVisionTransformer(vision_config)
1518
-
1519
- self.logit_scale = nn.Parameter(torch.randn(1))
1520
- self.logit_bias = nn.Parameter(torch.randn(1))
1521
-
1522
- # Initialize weights and apply final processing
1523
- self.post_init()
1524
-
1525
- @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
1526
- def get_text_features(
1527
- self,
1528
- input_ids: Optional[torch.Tensor] = None,
1529
- attention_mask: Optional[torch.Tensor] = None,
1530
- position_ids: Optional[torch.Tensor] = None,
1531
- output_attentions: Optional[bool] = None,
1532
- output_hidden_states: Optional[bool] = None,
1533
- return_dict: Optional[bool] = None,
1534
- ) -> torch.FloatTensor:
1535
- r"""
1536
- Returns:
1537
-             text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
1538
- applying the projection layer to the pooled output of [`SiglipTextModel`].
1539
- Examples:
1540
- ```python
1541
- >>> from transformers import AutoTokenizer, AutoModel
1542
- >>> import torch
1543
- >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
1544
- >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
1545
- >>> # important: make sure to set padding="max_length" as that's how the model was trained
1546
- >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
1547
- >>> with torch.no_grad():
1548
- ... text_features = model.get_text_features(**inputs)
1549
- ```"""
1550
- # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
1551
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1552
- output_hidden_states = (
1553
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1554
- )
1555
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1556
-
1557
- text_outputs = self.text_model(
1558
- input_ids=input_ids,
1559
- attention_mask=attention_mask,
1560
- position_ids=position_ids,
1561
- output_attentions=output_attentions,
1562
- output_hidden_states=output_hidden_states,
1563
- return_dict=return_dict,
1564
- )
1565
-
1566
- pooled_output = text_outputs[1]
1567
-
1568
- return pooled_output
1569
-
1570
- @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
- def get_image_features(
- self,
- pixel_values: Optional[torch.FloatTensor] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> torch.FloatTensor:
- r"""
- Returns:
- image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings corresponding
- to the pooled output of [`SiglipVisionModel`].
- Examples:
- ```python
- >>> from PIL import Image
- >>> import requests
- >>> from transformers import AutoProcessor, AutoModel
- >>> import torch
- >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
- >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
- >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
- >>> image = Image.open(requests.get(url, stream=True).raw)
- >>> inputs = processor(images=image, return_tensors="pt")
- >>> with torch.no_grad():
- ... image_features = model.get_image_features(**inputs)
- ```"""
- # Use SiglipModel's config for some fields (if specified) instead of those of vision & text components.
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- vision_outputs = self.vision_model(
- pixel_values=pixel_values,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
-
- pooled_output = vision_outputs[1]
-
- return pooled_output
-
1614
- @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=SiglipOutput, config_class=SiglipConfig)
- def forward(
- self,
- input_ids: Optional[torch.LongTensor] = None,
- pixel_values: Optional[torch.FloatTensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- return_loss: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple, SiglipOutput]:
- r"""
- Returns:
- Examples:
- ```python
- >>> from PIL import Image
- >>> import requests
- >>> from transformers import AutoProcessor, AutoModel
- >>> import torch
- >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
- >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
- >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
- >>> image = Image.open(requests.get(url, stream=True).raw)
- >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"]
- >>> # important: we pass `padding="max_length"` since the model was trained with this
- >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")
- >>> with torch.no_grad():
- ... outputs = model(**inputs)
- >>> logits_per_image = outputs.logits_per_image
- >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities
- >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
- 31.9% that image 0 is 'a photo of 2 cats'
- ```"""
- # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- vision_outputs = self.vision_model(
- pixel_values=pixel_values,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
-
- text_outputs = self.text_model(
- input_ids=input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
-
- image_embeds = vision_outputs[1]
- text_embeds = text_outputs[1]
-
- # normalized features
- image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
- text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
-
- # cosine similarity as logits
- logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale.exp() + self.logit_bias
- logits_per_image = logits_per_text.t()
-
- loss = None
- if return_loss:
- raise NotImplementedError("SigLIP loss to be implemented")
-
- if not return_dict:
- output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
- return ((loss,) + output) if loss is not None else output
-
- return SiglipOutput(
- loss=loss,
- logits_per_image=logits_per_image,
- logits_per_text=logits_per_text,
- text_embeds=text_embeds,
- image_embeds=image_embeds,
- text_model_output=text_outputs,
- vision_model_output=vision_outputs,
- )
-
-
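Note that the deleted `forward` above raises `NotImplementedError("SigLIP loss to be implemented")` when `return_loss=True`. For reference, here is a minimal sketch of the pairwise sigmoid loss described in the SigLIP paper, applied to the `logits_per_text` computed above. This is an illustration added for this review, not part of the removed file, and it assumes a square batch in which the i-th text matches the i-th image:

```python
import torch
import torch.nn.functional as F


def siglip_loss(logits_per_text: torch.Tensor) -> torch.Tensor:
    # logits_per_text: (num_texts, num_images), already scaled by logit_scale
    # and shifted by logit_bias, with matching pairs on the diagonal.
    eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device)
    # +1 on the diagonal (positive pairs), -1 everywhere else (negative pairs).
    pair_labels = 2 * eye - torch.ones_like(logits_per_text)
    # -log sigmoid(label * logit), summed over images, averaged over texts.
    loglik = F.logsigmoid(pair_labels * logits_per_text)
    return -loglik.sum(dim=-1).mean()
```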
- def get_siglip_vision_model(_flash_attn_2_enabled=True, **kwargs):
- siglip_vision_config = {
- "hidden_size": 1152,
- "image_size": 448,
- "intermediate_size": 4304,
- "model_type": "siglip_vision_model",
- "num_attention_heads": 16,
- "num_hidden_layers": 27,
- "patch_size": 14,
- }
-
- model_config = SiglipVisionConfig(**siglip_vision_config, _flash_attn_2_enabled=_flash_attn_2_enabled, **kwargs)
-
- vision_model = SiglipVisionModel(model_config).vision_model
-
- return vision_model
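For context, a short usage sketch of the factory above. The import path is hypothetical (the module is being deleted in this commit); the input resolution follows the config above (448×448 images with patch size 14, i.e. 32 × 32 = 1024 patches of width 1152):

```python
import torch

# Hypothetical import path for the removed module that defined get_siglip_vision_model.
from vision_siglip_navit import get_siglip_vision_model

# Disable flash-attention so the sketch runs on CPU.
vision_encoder = get_siglip_vision_model(_flash_attn_2_enabled=False)

pixel_values = torch.randn(1, 3, 448, 448)  # one 448x448 RGB image
with torch.no_grad():
    outputs = vision_encoder(pixel_values=pixel_values)

# Expected under the config above: torch.Size([1, 1024, 1152])
print(outputs.last_hidden_state.shape)
```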