import math
from typing import Optional, Sequence, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention.flex_attention import (
    _DEFAULT_SPARSE_BLOCK_SIZE,
    create_block_mask,
    create_mask,
    flex_attention,
)
from torch.nn.attention.flex_attention import _vmap_for_bhqkv
try:
    from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
except ImportError:
    from torch._higher_order_ops.flex_attention import TransformGetItemToIndex
from torch._dynamo import disable
def generate_alibi_bias(H=12):
    # Per-head slopes for an ALiBi-style distance penalty: head h gets 2 ** (-(h + 1) / H).
    alibi_bias = []
    for h in range(H):
        alibi_bias.append(-((h + 1) / H))
    alibi_bias = torch.tensor(alibi_bias)
    alibi_bias = torch.exp2(alibi_bias)
    return alibi_bias
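# Illustration (not part of the module API): generate_alibi_bias(H=4) gives
# tensor([0.8409, 0.7071, 0.5946, 0.5000]), i.e. 2**(-1/4) ... 2**(-4/4), so the
# distance penalty decays geometrically across heads.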
def get_rel_bias_func(scale, coords=None, qk_scale=1.0):
    def patch_coords_rel_bias(score, b, h, q_idx, kv_idx):
        if coords is None:
            return score
        with torch.no_grad():
            dx = coords[b, q_idx][0] - coords[b, kv_idx][0]
            dy = coords[b, q_idx][1] - coords[b, kv_idx][1]
            dist = torch.sqrt(dx * dx + dy * dy)
            dist = dist.clamp(max=1000)  # max distance
            dist = torch.log1p(dist)
            bias = dist * scale[h] * qk_scale
        return score - bias  # closer → larger score
    return patch_coords_rel_bias
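# Worked example (illustrative numbers only): for two patches 3 units apart, a head
# slope of 0.5 and qk_scale of 0.125, the score is reduced by
# log1p(3) * 0.5 * 0.125 ≈ 0.0866, while a patch attending to itself loses nothing.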
def key_padding_mask(mask):
    # mask[b, kv_idx] == True marks a padded key position; the mask_mod returns
    # True only for keys that may be attended to.
    def padding_mask(b, h, q_idx, kv_idx):
        return ~mask[b, kv_idx]
    return padding_mask
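# Illustration (assumed mask layout): with mask = torch.tensor([[False, False, True]])
# the third key of batch 0 is padding, so padding_mask(0, h, q_idx, 2) is False and
# that column is blocked for every query.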
class FlexCore(nn.Module):
    """
    Thin wrapper around flex_attention so that a "forward hook" can be registered on the attention call.
    """
    def forward(self, q, k, v, score_mod=None, block_mask=None, return_lse=False):
        """
        return_lse=True should be used together with the ATTN_MAP_VIS wrapper.
        Even when return_lse is True, flex_attention(...) returns the attention output
        (plus the per-row logsumexp), not the full attention-score matrix.
        """
        return flex_attention(q, k, v, score_mod=score_mod, block_mask=block_mask, return_lse=return_lse)
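# Minimal sketch (illustrative, not the actual ATTN_MAP_VIS wrapper) of how a forward
# hook on FlexCore could capture the logsumexp when return_lse=True and pass only the
# attention output onward; the names here are assumptions, not repo API:
#
#     captured = {}
#     def _capture_lse(module, inputs, output):
#         if isinstance(output, tuple):      # (attn_out, lse) when return_lse=True
#             captured["lse"] = output[1].detach()
#             return output[0]               # returned value replaces the module output
#     core = FlexCore()
#     core.register_forward_hook(_capture_lse)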
class Flex_Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 12,
        qkv_bias: bool = True,
        proj_drop: float = 0.,
        use_rel_bias: bool = True,
    ):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
        if qkv_bias:
            self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
        else:
            self.in_proj_bias = None
        self.f_attn = FlexCore()
        self.out_proj = nn.Linear(dim, dim)
        self.out_drop = nn.Dropout(proj_drop)
        self.max_distance = 16
    def build_rel_bias(self, coords):
        return torch.log1p(torch.cdist(coords, coords, p=2))
    def forward(self, x, coords=None, attn_mask: Optional[torch.Tensor] = None, return_attn_score=False):
        N, L, C = x.shape
        # ensure contiguous projection before chunking
        x_proj = F.linear(x, self.in_proj_weight, self.in_proj_bias).contiguous()
        q, k, v = [t.contiguous() for t in x_proj.chunk(3, -1)]
        q = q.reshape(N, L, self.num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
        k = k.reshape(N, L, self.num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
        v = v.reshape(N, L, self.num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
        if attn_mask is not None:
            mask_func = create_block_mask(
                key_padding_mask(attn_mask), N, self.num_heads, L, L, device=x.device
            )
        qk_scale = q.size(-1) ** -0.5
        # When return_attn_score is True, FlexCore returns (output, logsumexp); the
        # ATTN_MAP_VIS wrapper is expected to unpack it before the projection below.
        x = self.f_attn(
            q, k, v,
            score_mod=(
                get_rel_bias_func(generate_alibi_bias(self.num_heads).to(coords.device), coords, qk_scale)
                if coords is not None else None
            ),
            block_mask=mask_func if attn_mask is not None else None,
            return_lse=return_attn_score,
        )
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.reshape(N, L, C).contiguous()
        x = self.out_proj(x)
        x = self.out_drop(x)
        return x
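# Example usage (illustrative sketch; shapes and values are assumptions, and it
# presumes a recent PyTorch build where flex_attention runs in eager mode):
if __name__ == "__main__":
    torch.manual_seed(0)
    attn = Flex_Attention(dim=768, num_heads=12)
    x = torch.randn(2, 256, 768)
    coords = torch.rand(2, 256, 2) * 16.0            # (x, y) coordinate per token
    pad_mask = torch.zeros(2, 256, dtype=torch.bool)
    pad_mask[:, 224:] = True                         # last 32 tokens are padding
    out = attn(x, coords=coords, attn_mask=pad_mask)
    print(out.shape)                                 # torch.Size([2, 256, 768])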