clean
Browse files
Signed-off-by: Isotr0py <[email protected]>
- modeling_ovis2_5.py +22 -67
modeling_ovis2_5.py
CHANGED
|
@@ -36,73 +36,6 @@ VISUAL_ATOM_ID = -300
|
|
| 36 |
INDICATOR_IDS = [-301, -302, -303, -304]
|
| 37 |
|
| 38 |
|
| 39 |
-
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
| 40 |
-
def rotate_half(x):
|
| 41 |
-
"""Rotates half the hidden dims of the input."""
|
| 42 |
-
x1 = x[..., : x.shape[-1] // 2]
|
| 43 |
-
x2 = x[..., x.shape[-1] // 2 :]
|
| 44 |
-
return torch.cat((-x2, x1), dim=-1)
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
|
| 48 |
-
"""Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).
|
| 49 |
-
|
| 50 |
-
Explanation:
|
| 51 |
-
Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding
|
| 52 |
-
sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For
|
| 53 |
-
vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately.
|
| 54 |
-
Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding.
|
| 55 |
-
For text embedding part, we just apply 1D rotary position embedding. The three rotary position indices (temporal,
|
| 56 |
-
height and width) of text embedding are always the same, so the text embedding rotary position embedding has no
|
| 57 |
-
difference with modern LLMs.
|
| 58 |
-
|
| 59 |
-
Args:
|
| 60 |
-
q (`torch.Tensor`): The query tensor.
|
| 61 |
-
k (`torch.Tensor`): The key tensor.
|
| 62 |
-
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 63 |
-
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 64 |
-
position_ids (`torch.Tensor`):
|
| 65 |
-
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
|
| 66 |
-
used to pass offset position ids when working with a KV-cache.
|
| 67 |
-
mrope_section(`List(int)`):
|
| 68 |
-
Multimodal rope section is for channel dimension of temporal, height and width in rope calculation.
|
| 69 |
-
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 70 |
-
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 71 |
-
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 72 |
-
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 73 |
-
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 74 |
-
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 75 |
-
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 76 |
-
Returns:
|
| 77 |
-
`tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
|
| 78 |
-
"""
|
| 79 |
-
mrope_section = mrope_section * 2
|
| 80 |
-
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
|
| 81 |
-
unsqueeze_dim
|
| 82 |
-
)
|
| 83 |
-
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
|
| 84 |
-
unsqueeze_dim
|
| 85 |
-
)
|
| 86 |
-
|
| 87 |
-
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 88 |
-
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 89 |
-
return q_embed, k_embed
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
def apply_rotary_pos_emb_vision(
|
| 93 |
-
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
|
| 94 |
-
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 95 |
-
orig_q_dtype = q.dtype
|
| 96 |
-
orig_k_dtype = k.dtype
|
| 97 |
-
q, k = q.float(), k.float()
|
| 98 |
-
cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
|
| 99 |
-
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 100 |
-
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 101 |
-
q_embed = q_embed.to(orig_q_dtype)
|
| 102 |
-
k_embed = k_embed.to(orig_k_dtype)
|
| 103 |
-
return q_embed, k_embed
|
| 104 |
-
|
| 105 |
-
|
| 106 |
# copied from qwen2.5-vl
|
| 107 |
class VisionRotaryEmbedding(nn.Module):
|
| 108 |
def __init__(self, dim: int, theta: float = 10000.0) -> None:
|
|
@@ -266,6 +199,28 @@ def apply_rotary_pos_emb_flashatt(
|
|
| 266 |
return q_embed, k_embed
|
| 267 |
|
| 268 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 269 |
class Siglip2Attention(nn.Module):
|
| 270 |
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 271 |
|
|
|
|
| 36 |
INDICATOR_IDS = [-301, -302, -303, -304]
|
| 37 |
|
| 38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
# copied from qwen2.5-vl
|
| 40 |
class VisionRotaryEmbedding(nn.Module):
|
| 41 |
def __init__(self, dim: int, theta: float = 10000.0) -> None:
|
|
|
|
| 199 |
return q_embed, k_embed
|
| 200 |
|
| 201 |
|
| 202 |
+
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
| 203 |
+
def rotate_half(x):
|
| 204 |
+
"""Rotates half the hidden dims of the input."""
|
| 205 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 206 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 207 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def apply_rotary_pos_emb_vision(
|
| 211 |
+
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
|
| 212 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 213 |
+
orig_q_dtype = q.dtype
|
| 214 |
+
orig_k_dtype = k.dtype
|
| 215 |
+
q, k = q.float(), k.float()
|
| 216 |
+
cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
|
| 217 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 218 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 219 |
+
q_embed = q_embed.to(orig_q_dtype)
|
| 220 |
+
k_embed = k_embed.to(orig_k_dtype)
|
| 221 |
+
return q_embed, k_embed
|
| 222 |
+
|
| 223 |
+
|
| 224 |
class Siglip2Attention(nn.Module):
|
| 225 |
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 226 |
|