clean

Signed-off-by: Isotr0py <[email protected]>

modeling_ovis2_5.py  +22 -67

@@ -36,73 +36,6 @@ VISUAL_ATOM_ID = -300
 INDICATOR_IDS = [-301, -302, -303, -304]
 
 
-# Copied from transformers.models.llama.modeling_llama.rotate_half
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
-    """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).
-
-    Explanation:
-        Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding
-        sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For
-        vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately.
-        Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding.
-        For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal,
-        height and width) of text embedding is always the same, so the text embedding rotary position embedding has no
-        difference with modern LLMs.
-
-    Args:
-        q (`torch.Tensor`): The query tensor.
-        k (`torch.Tensor`): The key tensor.
-        cos (`torch.Tensor`): The cosine part of the rotary embedding.
-        sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`):
-            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
-            used to pass offsetted position ids when working with a KV-cache.
-        mrope_section(`List(int)`):
-            Multimodal rope section is for channel dimension of temporal, height and width in rope calculation.
-        unsqueeze_dim (`int`, *optional*, defaults to 1):
-            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
-            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
-            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
-            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
-            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
-            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
-    Returns:
-        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
-    """
-    mrope_section = mrope_section * 2
-    cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
-        unsqueeze_dim
-    )
-    sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
-        unsqueeze_dim
-    )
-
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
-
-
-def apply_rotary_pos_emb_vision(
-    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
-) -> tuple[torch.Tensor, torch.Tensor]:
-    orig_q_dtype = q.dtype
-    orig_k_dtype = k.dtype
-    q, k = q.float(), k.float()
-    cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    q_embed = q_embed.to(orig_q_dtype)
-    k_embed = k_embed.to(orig_k_dtype)
-    return q_embed, k_embed
-
-
 # copied from qwen2.5-vl
 class VisionRotaryEmbedding(nn.Module):
     def __init__(self, dim: int, theta: float = 10000.0) -> None:
@@ -266,6 +199,28 @@ def apply_rotary_pos_emb_flashatt(
     return q_embed, k_embed
 
 
+# Copied from transformers.models.llama.modeling_llama.rotate_half
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb_vision(
+    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+) -> tuple[torch.Tensor, torch.Tensor]:
+    orig_q_dtype = q.dtype
+    orig_k_dtype = k.dtype
+    q, k = q.float(), k.float()
+    cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    q_embed = q_embed.to(orig_q_dtype)
+    k_embed = k_embed.to(orig_k_dtype)
+    return q_embed, k_embed
+
+
 class Siglip2Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
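Note (not part of the diff): the docstring of the removed apply_multimodal_rotary_pos_emb explains how the channel dimension is split into temporal/height/width sections before applying RoPE. A minimal sketch of that splitting step, using an assumed mrope_section and illustrative tensor shapes rather than values from any real config:

import torch

# Assumed for illustration only; real sizes come from the model config.
mrope_section = [16, 24, 24]        # t/h/w split of half the head_dim (16 + 24 + 24 = 64)
cos = torch.randn(3, 1, 10, 128)    # (3, batch, seq_len, head_dim): one row each for t, h, w position indices

# Doubling the sections matches the [cos, cos] duplication across head_dim; chunk i then
# takes the t/h/w row i % 3, interleaving temporal, height and width sections along the channels.
sections = mrope_section * 2
mixed = torch.cat(
    [m[i % 3] for i, m in enumerate(cos.split(sections, dim=-1))], dim=-1
)
print(mixed.shape)                  # torch.Size([1, 10, 128])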