xxyyy123 Isotr0py committed on
Commit 181a23b · verified · 1 Parent(s): 5456300

Add SDPA fallback for Siglip2Navit attention (#5)


- add sdpa fallback (8d086dc4f92f19f736480202f96ab109e74add20)
- clean (ad6d85fb24806f7fa8f0f51c20dfbae4c18f3bb8)


Co-authored-by: Isotr0py <[email protected]>

Files changed (1)
  1. modeling_ovis2_5.py +61 -6
modeling_ovis2_5.py CHANGED
```diff
@@ -4,8 +4,6 @@ from typing import Dict, List, Optional, Tuple, Union
 import PIL.Image
 import numpy as np
 import torch
-from flash_attn import flash_attn_varlen_func
-from flash_attn.layers.rotary import apply_rotary_emb
 from torch import Tensor, nn
 from torch.nn import functional as F
 from transformers import (
@@ -19,9 +17,16 @@ from transformers.activations import ACT2FN
 from transformers.generation.utils import GenerateOutput
 from transformers.modeling_outputs import BaseModelOutputWithNoAttention
 from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import is_flash_attn_2_available
 
 from .configuration_ovis2_5 import Siglip2NavitConfig, Ovis2_5_Config
 
+
+if is_flash_attn_2_available():
+    from flash_attn import flash_attn_varlen_func
+    from flash_attn.layers.rotary import apply_rotary_emb
+
+
 IMAGE_PLACEHOLDER = "<image>"
 IMAGE_PLACEHOLDER_ID = -200
 VIDEO_PLACEHOLDER = "<video>"
@@ -30,6 +35,7 @@ VIDEO_PLACEHOLDER_ID = -201
 VISUAL_ATOM_ID = -300
 INDICATOR_IDS = [-301, -302, -303, -304]
 
+
 # copied from qwen2.5-vl
 class VisionRotaryEmbedding(nn.Module):
     def __init__(self, dim: int, theta: float = 10000.0) -> None:
@@ -193,6 +199,28 @@ def apply_rotary_pos_emb_flashatt(
     return q_embed, k_embed
 
 
+# Copied from transformers.models.llama.modeling_llama.rotate_half
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb_vision(
+    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+) -> tuple[torch.Tensor, torch.Tensor]:
+    orig_q_dtype = q.dtype
+    orig_k_dtype = k.dtype
+    q, k = q.float(), k.float()
+    cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    q_embed = q_embed.to(orig_q_dtype)
+    k_embed = k_embed.to(orig_k_dtype)
+    return q_embed, k_embed
+
+
 class Siglip2Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
@@ -238,14 +266,41 @@ class Siglip2Attention(nn.Module):
 
         if self.use_rope:
            cos, sin = position_embeddings
-            queries, keys = apply_rotary_pos_emb_flashatt(queries.unsqueeze(0), keys.unsqueeze(0), cos, sin)
+            if is_flash_attn_2_available():
+                queries, keys = apply_rotary_pos_emb_flashatt(queries.unsqueeze(0), keys.unsqueeze(0), cos, sin)
+            else:
+                queries, keys = apply_rotary_pos_emb_vision(queries.unsqueeze(0), keys.unsqueeze(0), cos, sin)
             queries = queries.squeeze(0)
             keys = keys.squeeze(0)
 
         max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
-        attn_output = flash_attn_varlen_func(queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
-            seq_length, -1
-        )
+        if is_flash_attn_2_available():
+            attn_output = flash_attn_varlen_func(queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
+                seq_length, -1
+            )
+        else:
+            batch_size = cu_seqlens.shape[0] - 1
+            outputs = []
+            cu = cu_seqlens.tolist()
+            for i in range(batch_size):
+                start_idx = cu[i]
+                end_idx = cu[i + 1]
+                # Each sequence is processed independently.
+                q_i = queries[start_idx:end_idx].unsqueeze(0)
+                k_i = keys[start_idx:end_idx].unsqueeze(0)
+                v_i = values[start_idx:end_idx].unsqueeze(0)
+                # (1, seq_len, num_heads, head_dim) ->
+                # (1, num_heads, seq_len, head_dim)
+                q_i, k_i, v_i = [x.transpose(1, 2) for x in (q_i, k_i, v_i)]
+                output_i = F.scaled_dot_product_attention(q_i,
+                                                          k_i,
+                                                          v_i,
+                                                          dropout_p=0.0)
+                # (1, num_heads, seq_len, head_dim) -> (seq_len, embed_dim)
+                output_i = output_i.transpose(1, 2).reshape(-1, self.embed_dim)
+                outputs.append(output_i)
+            attn_output = torch.cat(outputs, dim=0)
+
         attn_output = self.out_proj(attn_output)
         return attn_output
 
```
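For reference, the SDPA branch can be exercised on its own. The sketch below is a minimal, self-contained reconstruction of that fallback path, assuming the packed `(total_tokens, num_heads, head_dim)` layout used in `Siglip2Attention` and an `embed_dim` equal to `num_heads * head_dim`; the helper name `sdpa_varlen_attention` and the toy `cu_seqlens` construction are illustrative only and not part of the repository.

```python
import torch
import torch.nn.functional as F


def sdpa_varlen_attention(queries: torch.Tensor,
                          keys: torch.Tensor,
                          values: torch.Tensor,
                          cu_seqlens: torch.Tensor,
                          embed_dim: int) -> torch.Tensor:
    """Attend to each [start, end) slice of the packed tokens independently.

    queries/keys/values: (total_tokens, num_heads, head_dim)
    cu_seqlens: cumulative sequence lengths, e.g. tensor([0, 3, 8])
    Returns: (total_tokens, embed_dim), with embed_dim = num_heads * head_dim.
    """
    outputs = []
    cu = cu_seqlens.tolist()
    for start_idx, end_idx in zip(cu[:-1], cu[1:]):
        # (seq_len, num_heads, head_dim) -> (1, num_heads, seq_len, head_dim)
        q_i = queries[start_idx:end_idx].unsqueeze(0).transpose(1, 2)
        k_i = keys[start_idx:end_idx].unsqueeze(0).transpose(1, 2)
        v_i = values[start_idx:end_idx].unsqueeze(0).transpose(1, 2)
        out_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
        # (1, num_heads, seq_len, head_dim) -> (seq_len, embed_dim)
        outputs.append(out_i.transpose(1, 2).reshape(-1, embed_dim))
    return torch.cat(outputs, dim=0)


# Toy usage: two packed sequences of 3 and 5 visual tokens, 4 heads of dim 8.
num_heads, head_dim = 4, 8
seqlens = torch.tensor([3, 5])
cu_seqlens = F.pad(seqlens.cumsum(0), (1, 0)).to(torch.int32)  # tensor([0, 3, 8])
q = k = v = torch.randn(int(seqlens.sum()), num_heads, head_dim)
out = sdpa_varlen_attention(q, k, v, cu_seqlens, num_heads * head_dim)
print(out.shape)  # torch.Size([8, 32])
```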
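Design note: looping over `cu_seqlens` keeps the fallback close to `flash_attn_varlen_func` semantics (tokens from different sequences never attend to each other) without materializing any attention mask. An alternative would be a single `F.scaled_dot_product_attention` call over all packed tokens with a block-diagonal `attn_mask` derived from `cu_seqlens`, which trades the Python loop for O(total_tokens²) mask memory.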