Isotr0py committed on
Commit 8d086dc · 1 Parent(s): 5456300

add sdpa fallback


Signed-off-by: Isotr0py <[email protected]>
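In short: when flash-attn is not installed, the vision attention now falls back to torch.nn.functional.scaled_dot_product_attention, applied independently to each packed sequence delimited by cu_seqlens. Below is a minimal standalone sketch of that dispatch, assuming the same is_flash_attn_2_available() gate used in the diff; the helper names varlen_sdpa_fallback and varlen_attention are illustrative, not part of the repo.

import torch
import torch.nn.functional as F
from transformers.utils import is_flash_attn_2_available

if is_flash_attn_2_available():
    from flash_attn import flash_attn_varlen_func


def varlen_sdpa_fallback(q, k, v, cu_seqlens):
    # q, k, v: (total_tokens, num_heads, head_dim); cu_seqlens marks the packed sequence boundaries.
    outputs = []
    for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        # Slice one sequence and move heads in front of the sequence dim for SDPA.
        q_i, k_i, v_i = (x[start:end].unsqueeze(0).transpose(1, 2) for x in (q, k, v))
        out_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
        outputs.append(out_i.transpose(1, 2).reshape(end - start, -1))
    return torch.cat(outputs, dim=0)


def varlen_attention(q, k, v, cu_seqlens, max_seqlen):
    if is_flash_attn_2_available():
        # Fast path: fused varlen kernel over the packed tokens.
        return flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(q.shape[0], -1)
    # Pure-PyTorch fallback: attend within each sequence separately.
    return varlen_sdpa_fallback(q, k, v, cu_seqlens)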

  Files changed (1)
  1. modeling_ovis2_5.py +106 -6
modeling_ovis2_5.py CHANGED
@@ -4,8 +4,6 @@ from typing import Dict, List, Optional, Tuple, Union
 import PIL.Image
 import numpy as np
 import torch
-from flash_attn import flash_attn_varlen_func
-from flash_attn.layers.rotary import apply_rotary_emb
 from torch import Tensor, nn
 from torch.nn import functional as F
 from transformers import (
@@ -19,9 +17,16 @@ from transformers.activations import ACT2FN
 from transformers.generation.utils import GenerateOutput
 from transformers.modeling_outputs import BaseModelOutputWithNoAttention
 from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import is_flash_attn_2_available
 
 from .configuration_ovis2_5 import Siglip2NavitConfig, Ovis2_5_Config
 
+
+if is_flash_attn_2_available():
+    from flash_attn import flash_attn_varlen_func
+    from flash_attn.layers.rotary import apply_rotary_emb
+
+
 IMAGE_PLACEHOLDER = "<image>"
 IMAGE_PLACEHOLDER_ID = -200
 VIDEO_PLACEHOLDER = "<video>"
@@ -30,6 +35,74 @@ VIDEO_PLACEHOLDER_ID = -201
 VISUAL_ATOM_ID = -300
 INDICATOR_IDS = [-301, -302, -303, -304]
 
+
+# Copied from transformers.models.llama.modeling_llama.rotate_half
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).
+
+    Explanation:
+        Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding
+        sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For
+        vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately.
+        Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding.
+        For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal,
+        height and width) of text embedding is always the same, so the text embedding rotary position embedding has no
+        difference with modern LLMs.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`):
+            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
+            used to pass offsetted position ids when working with a KV-cache.
+        mrope_section(`List(int)`):
+            Multimodal rope section is for channel dimension of temporal, height and width in rope calculation.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    mrope_section = mrope_section * 2
+    cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
+        unsqueeze_dim
+    )
+    sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
+        unsqueeze_dim
+    )
+
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+def apply_rotary_pos_emb_vision(
+    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+) -> tuple[torch.Tensor, torch.Tensor]:
+    orig_q_dtype = q.dtype
+    orig_k_dtype = k.dtype
+    q, k = q.float(), k.float()
+    cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    q_embed = q_embed.to(orig_q_dtype)
+    k_embed = k_embed.to(orig_k_dtype)
+    return q_embed, k_embed
+
+
 # copied from qwen2.5-vl
 class VisionRotaryEmbedding(nn.Module):
     def __init__(self, dim: int, theta: float = 10000.0) -> None:
@@ -238,14 +311,41 @@ class Siglip2Attention(nn.Module):
 
         if self.use_rope:
             cos, sin = position_embeddings
-            queries, keys = apply_rotary_pos_emb_flashatt(queries.unsqueeze(0), keys.unsqueeze(0), cos, sin)
+            if is_flash_attn_2_available():
+                queries, keys = apply_rotary_pos_emb_flashatt(queries.unsqueeze(0), keys.unsqueeze(0), cos, sin)
+            else:
+                queries, keys = apply_rotary_pos_emb_vision(queries.unsqueeze(0), keys.unsqueeze(0), cos, sin)
             queries = queries.squeeze(0)
             keys = keys.squeeze(0)
 
         max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
-        attn_output = flash_attn_varlen_func(queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
-            seq_length, -1
-        )
+        if is_flash_attn_2_available():
+            attn_output = flash_attn_varlen_func(queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
+                seq_length, -1
+            )
+        else:
+            batch_size = cu_seqlens.shape[0] - 1
+            outputs = []
+            cu = cu_seqlens.tolist()
+            for i in range(batch_size):
+                start_idx = cu[i]
+                end_idx = cu[i + 1]
+                # Each sequence is processed independently.
+                q_i = queries[start_idx:end_idx].unsqueeze(0)
+                k_i = keys[start_idx:end_idx].unsqueeze(0)
+                v_i = values[start_idx:end_idx].unsqueeze(0)
+                # (1, seq_len, num_heads, head_dim) ->
+                # (1, num_heads, seq_len, head_dim)
+                q_i, k_i, v_i = [x.transpose(1, 2) for x in (q_i, k_i, v_i)]
+                output_i = F.scaled_dot_product_attention(q_i,
+                                                          k_i,
+                                                          v_i,
+                                                          dropout_p=0.0)
+                # (1, num_heads, seq_len, head_dim) -> (seq_len, embed_dim)
+                output_i = output_i.transpose(1, 2).reshape(-1, self.embed_dim)
+                outputs.append(output_i)
+            attn_output = torch.cat(outputs, dim=0)
+
         attn_output = self.out_proj(attn_output)
         return attn_output
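A quick way to convince yourself the fallback is equivalent: per-sequence SDPA over the packed tokens should match a single SDPA call with a block-diagonal mask that blocks attention across sequence boundaries. A small self-contained check on toy shapes (CPU only, not part of the commit):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_heads, head_dim = 2, 8
cu_seqlens = torch.tensor([0, 3, 7, 12])  # three packed sequences of lengths 3, 4, 5
total = int(cu_seqlens[-1])

q = torch.randn(total, num_heads, head_dim)
k = torch.randn(total, num_heads, head_dim)
v = torch.randn(total, num_heads, head_dim)

# Path 1: per-sequence SDPA, as in the fallback added by this commit.
per_seq = []
for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
    q_i, k_i, v_i = (x[start:end].unsqueeze(0).transpose(1, 2) for x in (q, k, v))
    out_i = F.scaled_dot_product_attention(q_i, k_i, v_i)
    per_seq.append(out_i.transpose(1, 2).squeeze(0))
per_seq = torch.cat(per_seq, dim=0)

# Path 2: one SDPA call over all packed tokens with a block-diagonal boolean mask,
# so tokens can only attend within their own sequence.
seq_ids = torch.repeat_interleave(
    torch.arange(len(cu_seqlens) - 1), cu_seqlens[1:] - cu_seqlens[:-1]
)
block_mask = seq_ids[:, None] == seq_ids[None, :]
packed = F.scaled_dot_product_attention(
    q.transpose(0, 1), k.transpose(0, 1), v.transpose(0, 1), attn_mask=block_mask
).transpose(0, 1)

print(torch.allclose(per_seq, packed, atol=1e-6))  # expected: True

The per-sequence loop trades one kernel launch per image/video for never materializing the (total x total) mask, which keeps the fallback memory-friendly for long packed visual sequences.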