Isotr0py committed on
Commit ad6d85f · 1 Parent(s): 8d086dc

Signed-off-by: Isotr0py <[email protected]>

Files changed (1)
  1. modeling_ovis2_5.py +22 -67
modeling_ovis2_5.py CHANGED
@@ -36,73 +36,6 @@ VISUAL_ATOM_ID = -300
 INDICATOR_IDS = [-301, -302, -303, -304]
 
 
-# Copied from transformers.models.llama.modeling_llama.rotate_half
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
-    """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).
-
-    Explanation:
-        Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding
-        sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For
-        vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately.
-        Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding.
-        For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal,
-        height and width) of text embedding is always the same, so the text embedding rotary position embedding has no
-        difference with modern LLMs.
-
-    Args:
-        q (`torch.Tensor`): The query tensor.
-        k (`torch.Tensor`): The key tensor.
-        cos (`torch.Tensor`): The cosine part of the rotary embedding.
-        sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`):
-            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
-            used to pass offsetted position ids when working with a KV-cache.
-        mrope_section(`List(int)`):
-            Multimodal rope section is for channel dimension of temporal, height and width in rope calculation.
-        unsqueeze_dim (`int`, *optional*, defaults to 1):
-            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
-            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
-            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
-            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
-            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
-            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
-    Returns:
-        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
-    """
-    mrope_section = mrope_section * 2
-    cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
-        unsqueeze_dim
-    )
-    sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
-        unsqueeze_dim
-    )
-
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
-
-
-def apply_rotary_pos_emb_vision(
-    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
-) -> tuple[torch.Tensor, torch.Tensor]:
-    orig_q_dtype = q.dtype
-    orig_k_dtype = k.dtype
-    q, k = q.float(), k.float()
-    cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    q_embed = q_embed.to(orig_q_dtype)
-    k_embed = k_embed.to(orig_k_dtype)
-    return q_embed, k_embed
-
-
 # copied from qwen2.5-vl
 class VisionRotaryEmbedding(nn.Module):
     def __init__(self, dim: int, theta: float = 10000.0) -> None:
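The docstring of the removed `apply_multimodal_rotary_pos_emb` describes splitting the channel dimension into temporal, height, and width chunks and cycling through the three position axes. A minimal sketch of just that channel split, with made-up section sizes (not taken from the Ovis2.5 config):

import torch

mrope_section = [2, 3, 3]           # temporal / height / width channel counts (illustrative)
head_dim = 2 * sum(mrope_section)   # rotate-half tables cover both halves of head_dim
seq_len = 4

# cos carries one row per position axis: (3, seq_len, head_dim)
cos = torch.randn(3, seq_len, head_dim)

sections = mrope_section * 2        # doubled so the split spans the full head_dim
chunks = cos.split(sections, dim=-1)
# Chunk i takes the (i % 3)-th axis, cycling temporal, height, width
merged = torch.cat([m[i % 3] for i, m in enumerate(chunks)], dim=-1)
assert merged.shape == (seq_len, head_dim)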
@@ -266,6 +199,28 @@ def apply_rotary_pos_emb_flashatt(
     return q_embed, k_embed
 
 
+# Copied from transformers.models.llama.modeling_llama.rotate_half
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb_vision(
+    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+) -> tuple[torch.Tensor, torch.Tensor]:
+    orig_q_dtype = q.dtype
+    orig_k_dtype = k.dtype
+    q, k = q.float(), k.float()
+    cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    q_embed = q_embed.to(orig_q_dtype)
+    k_embed = k_embed.to(orig_k_dtype)
+    return q_embed, k_embed
+
+
 class Siglip2Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
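The net effect of the commit is that `rotate_half` and `apply_rotary_pos_emb_vision` move below the flash-attention RoPE helper while the unused `apply_multimodal_rotary_pos_emb` is dropped. A minimal usage sketch for the relocated helpers, assuming `modeling_ovis2_5.py` is importable from the working directory; shapes follow the qwen2.5-vl vision convention of (seq_len, num_heads, head_dim) for q/k and (seq_len, head_dim) for the cos/sin tables:

import torch

from modeling_ovis2_5 import apply_rotary_pos_emb_vision

seq_len, num_heads, head_dim = 16, 4, 64

# Standard rotate-half tables: frequencies over half the head dim,
# duplicated so cos/sin span the full head dim.
pos = torch.arange(seq_len).float()
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(pos, inv_freq)        # (seq_len, head_dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)   # (seq_len, head_dim)
cos, sin = emb.cos(), emb.sin()

q = torch.randn(seq_len, num_heads, head_dim, dtype=torch.bfloat16)
k = torch.randn(seq_len, num_heads, head_dim, dtype=torch.bfloat16)

q_rot, k_rot = apply_rotary_pos_emb_vision(q, k, cos, sin)
# The rotation runs in float32 internally and restores the input dtype.
assert q_rot.dtype == torch.bfloat16 and q_rot.shape == q.shape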