xxyyy123 committed (verified)
Commit ab82b02 · 1 Parent(s): 5a1f9b9

Add SDPA fallback for Siglip2Navit attention (#8)

- Add SDPA fallback for Siglip2Navit attention (521a718db3f4012656d2e9305bacb5774081a98c)

Files changed (1):
  1. modeling_ovis2_5.py +62 -12
modeling_ovis2_5.py CHANGED

@@ -4,8 +4,6 @@ from typing import Dict, List, Optional, Tuple, Union
 import PIL.Image
 import numpy as np
 import torch
-from flash_attn import flash_attn_varlen_func
-from flash_attn.layers.rotary import apply_rotary_emb
 from torch import Tensor, nn
 from torch.nn import functional as F
 from transformers import (
@@ -19,9 +17,16 @@ from transformers.activations import ACT2FN
 from transformers.generation.utils import GenerateOutput
 from transformers.modeling_outputs import BaseModelOutputWithNoAttention
 from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import is_flash_attn_2_available
 
 from .configuration_ovis2_5 import Siglip2NavitConfig, Ovis2_5_Config
 
+
+if is_flash_attn_2_available():
+    from flash_attn import flash_attn_varlen_func
+    from flash_attn.layers.rotary import apply_rotary_emb
+
+
 IMAGE_PLACEHOLDER = "<image>"
 IMAGE_PLACEHOLDER_ID = -200
 VIDEO_PLACEHOLDER = "<video>"
@@ -30,6 +35,7 @@ VIDEO_PLACEHOLDER_ID = -201
 VISUAL_ATOM_ID = -300
 INDICATOR_IDS = [-301, -302, -303, -304]
 
+
 # copied from qwen2.5-vl
 class VisionRotaryEmbedding(nn.Module):
     def __init__(self, dim: int, theta: float = 10000.0) -> None:
@@ -86,7 +92,6 @@ class Siglip2VisionEmbeddings(nn.Module):
     ) -> torch.Tensor:
         """
         Resize positional embeddings to image-specific size and pad to a fixed size.
-
         Args:
             positional_embeddings (`torch.Tensor`):
                 Position embeddings of shape (height, width, embed_dim)
@@ -94,7 +99,6 @@ class Siglip2VisionEmbeddings(nn.Module):
                 Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
             max_length (`int`):
                 Maximum length of the positional embeddings to pad resized positional embeddings to
-
         Returns:
             `torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim)
         """
@@ -193,6 +197,28 @@ def apply_rotary_pos_emb_flashatt(
     return q_embed, k_embed
 
 
+# Copied from transformers.models.llama.modeling_llama.rotate_half
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb_vision(
+    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+) -> tuple[torch.Tensor, torch.Tensor]:
+    orig_q_dtype = q.dtype
+    orig_k_dtype = k.dtype
+    q, k = q.float(), k.float()
+    cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    q_embed = q_embed.to(orig_q_dtype)
+    k_embed = k_embed.to(orig_k_dtype)
+    return q_embed, k_embed
+
+
 class Siglip2Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
@@ -238,14 +264,41 @@ class Siglip2Attention(nn.Module):
 
         if self.use_rope:
             cos, sin = position_embeddings
-            queries, keys = apply_rotary_pos_emb_flashatt(queries.unsqueeze(0), keys.unsqueeze(0), cos, sin)
+            if is_flash_attn_2_available():
+                queries, keys = apply_rotary_pos_emb_flashatt(queries.unsqueeze(0), keys.unsqueeze(0), cos, sin)
+            else:
+                queries, keys = apply_rotary_pos_emb_vision(queries.unsqueeze(0), keys.unsqueeze(0), cos, sin)
             queries = queries.squeeze(0)
             keys = keys.squeeze(0)
 
         max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
-        attn_output = flash_attn_varlen_func(queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
-            seq_length, -1
-        )
+        if is_flash_attn_2_available():
+            attn_output = flash_attn_varlen_func(queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
+                seq_length, -1
+            )
+        else:
+            batch_size = cu_seqlens.shape[0] - 1
+            outputs = []
+            cu = cu_seqlens.tolist()
+            for i in range(batch_size):
+                start_idx = cu[i]
+                end_idx = cu[i + 1]
+                # Each sequence is processed independently.
+                q_i = queries[start_idx:end_idx].unsqueeze(0)
+                k_i = keys[start_idx:end_idx].unsqueeze(0)
+                v_i = values[start_idx:end_idx].unsqueeze(0)
+                # (1, seq_len, num_heads, head_dim) ->
+                # (1, num_heads, seq_len, head_dim)
+                q_i, k_i, v_i = [x.transpose(1, 2) for x in (q_i, k_i, v_i)]
+                output_i = F.scaled_dot_product_attention(q_i,
+                                                          k_i,
+                                                          v_i,
+                                                          dropout_p=0.0)
+                # (1, num_heads, seq_len, head_dim) -> (seq_len, embed_dim)
+                output_i = output_i.transpose(1, 2).reshape(-1, self.embed_dim)
+                outputs.append(output_i)
+            attn_output = torch.cat(outputs, dim=0)
+
         attn_output = self.out_proj(attn_output)
         return attn_output
 
@@ -310,7 +363,6 @@ class Siglip2Encoder(nn.Module):
     """
     Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`Siglip2EncoderLayer`].
-
     Args:
        config: Siglip2NavitConfig
    """
@@ -415,10 +467,8 @@ class Siglip2Encoder(nn.Module):
                than the model's internal embedding lookup matrix.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
-
                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
@@ -946,4 +996,4 @@ class Ovis2_5(OvisPreTrainedModel):
 AutoConfig.register('siglip2_navit', Siglip2NavitConfig)
 AutoModel.register(Siglip2NavitConfig, Siglip2NavitModel)
 AutoConfig.register("ovis2_5", Ovis2_5_Config)
-AutoModelForCausalLM.register(Ovis2_5_Config, Ovis2_5)
+AutoModelForCausalLM.register(Ovis2_5_Config, Ovis2_5)
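A note on the rotary-embedding fallback: when flash-attn is unavailable, the commit applies RoPE with the pure-PyTorch `apply_rotary_pos_emb_vision` instead of `apply_rotary_pos_emb_flashatt`. The sketch below is a minimal, self-contained illustration of its shape and dtype conventions only; the tensor sizes and the simple 1-D rotary table are made up for the demo (in the model the `cos`/`sin` tensors come from the vision tower's rotary embedding), so treat it as a toy, not the model's code path.

```python
import torch

# Same helpers as in the diff, reproduced so the snippet runs standalone.
def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb_vision(q, k, cos, sin):
    orig_q_dtype, orig_k_dtype = q.dtype, k.dtype
    q, k = q.float(), k.float()
    cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed.to(orig_q_dtype), k_embed.to(orig_k_dtype)

# Toy sizes; head_dim must be even so rotate_half can split it.
seq_len, num_heads, head_dim = 6, 2, 8
q = torch.randn(1, seq_len, num_heads, head_dim, dtype=torch.bfloat16)
k = torch.randn(1, seq_len, num_heads, head_dim, dtype=torch.bfloat16)

# A plain 1-D rotary table, just to produce cos/sin of shape (seq_len, head_dim).
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)  # (seq_len, head_dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)                       # (seq_len, head_dim)
cos, sin = emb.cos(), emb.sin()

q_rot, k_rot = apply_rotary_pos_emb_vision(q, k, cos, sin)
print(q_rot.shape, q_rot.dtype)  # torch.Size([1, 6, 2, 8]) torch.bfloat16
```

As in the added code, the rotation is computed in float32 and cast back to the input dtype.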
 
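The attention here operates on a packed layout: visual tokens from several images are concatenated into one long sequence and delimited by `cu_seqlens` (cumulative sequence lengths), which both `flash_attn_varlen_func` and the new SDPA loop consume. Below is a minimal sketch of that layout; `num_heads`, `head_dim`, and the per-image token counts are illustrative values, not values from the model.

```python
import torch
import torch.nn.functional as F

# Illustrative sizes only; in the model, num_heads * head_dim == embed_dim.
num_heads, head_dim = 4, 8
embed_dim = num_heads * head_dim

# Three images packed into one sequence of 5, 3 and 7 visual tokens.
seq_lens = torch.tensor([5, 3, 7])
cu_seqlens = F.pad(seq_lens.cumsum(0), (1, 0)).int()  # tensor([ 0,  5,  8, 15])
total_tokens = int(cu_seqlens[-1])

queries = torch.randn(total_tokens, num_heads, head_dim)
keys = torch.randn(total_tokens, num_heads, head_dim)
values = torch.randn(total_tokens, num_heads, head_dim)

# Same per-sequence loop as the SDPA fallback above: each image's tokens
# attend only to tokens of the same image.
outputs = []
for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
    q_i = queries[start:end].unsqueeze(0).transpose(1, 2)  # (1, heads, len_i, dim)
    k_i = keys[start:end].unsqueeze(0).transpose(1, 2)
    v_i = values[start:end].unsqueeze(0).transpose(1, 2)
    out_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
    outputs.append(out_i.transpose(1, 2).reshape(-1, embed_dim))

attn_output = torch.cat(outputs, dim=0)
print(attn_output.shape)  # torch.Size([15, 32]): (total_tokens, embed_dim)
```

The fallback trades the fused varlen kernel for one `scaled_dot_product_attention` call per image, so it needs only stock PyTorch at the cost of some speed on many-image batches.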
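As a sanity check on the fallback's semantics (not part of the commit), the per-sequence loop should match a single SDPA call over the packed sequence with a block-diagonal boolean mask, since both keep attention from crossing image boundaries. A sketch of that cross-check, again with made-up sizes:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_heads, head_dim = 4, 8
seq_lens = [5, 3, 7]
total = sum(seq_lens)

q = torch.randn(total, num_heads, head_dim)
k = torch.randn(total, num_heads, head_dim)
v = torch.randn(total, num_heads, head_dim)

# Reference: one SDPA call over the whole packed sequence with a
# block-diagonal mask, so tokens never attend across image boundaries.
mask = torch.zeros(total, total, dtype=torch.bool)
start = 0
for n in seq_lens:
    mask[start:start + n, start:start + n] = True
    start += n
ref = F.scaled_dot_product_attention(
    q.transpose(0, 1), k.transpose(0, 1), v.transpose(0, 1), attn_mask=mask
).transpose(0, 1)

# Per-sequence loop, as in the commit's SDPA fallback.
outs, start = [], 0
for n in seq_lens:
    q_i = q[start:start + n].unsqueeze(0).transpose(1, 2)
    k_i = k[start:start + n].unsqueeze(0).transpose(1, 2)
    v_i = v[start:start + n].unsqueeze(0).transpose(1, 2)
    o_i = F.scaled_dot_product_attention(q_i, k_i, v_i)
    outs.append(o_i.transpose(1, 2).squeeze(0))
    start += n
loop = torch.cat(outs, dim=0)

print(torch.allclose(ref, loop, atol=1e-6))  # expect: True
```

When flash-attn is installed, `flash_attn_varlen_func` computes the same attention with a fused kernel, so the two code paths should agree up to floating-point differences.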