import math
from typing import Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention.flex_attention import (
    _DEFAULT_SPARSE_BLOCK_SIZE,
    create_block_mask,
    create_mask,
    flex_attention,
)
from torch.nn.attention.flex_attention import _vmap_for_bhqkv

try:
    from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
except ImportError:
    from torch._higher_order_ops.flex_attention import TransformGetItemToIndex

from torch._dynamo import disable


def generate_alibi_bias(H=12):
    """Per-head ALiBi-style slopes: 2 ** (-(h + 1) / H) for heads h = 0 .. H - 1."""
    alibi_bias = []
    for h in range(H):
        alibi_bias.append(-((h + 1) / H))
    alibi_bias = torch.tensor(alibi_bias)
    alibi_bias = torch.exp2(alibi_bias)
    return alibi_bias
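# Hedged sanity check (illustrative only, not executed at import time): with the
# default H = 12 the slopes form the geometric sequence
# 2 ** (-1/12), 2 ** (-2/12), ..., 2 ** (-12/12) = 0.5, one positive slope per head.
#
#   >>> generate_alibi_bias(12)[0].item()    # ~0.9439
#   >>> generate_alibi_bias(12)[-1].item()   # 0.5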


def get_rel_bias_func(scale, coords=None, qk_scale=1.0):
    """Build a flex_attention score_mod that subtracts a distance-based bias.

    `scale` holds one slope per head (see generate_alibi_bias), `coords` is a
    (B, L, 2) tensor of per-token 2D coordinates, and `qk_scale` should match
    the softmax scaling applied to the raw attention scores.
    """
    def patch_coords_rel_bias(score, b, h, q_idx, kv_idx):
        if coords is None:
            return score
        with torch.no_grad():
            # Log-scaled, clamped Euclidean distance between query and key tokens.
            dx = coords[b, q_idx][0] - coords[b, kv_idx][0]
            dy = coords[b, q_idx][1] - coords[b, kv_idx][1]
            dist = torch.sqrt(dx * dx + dy * dy)
            dist = dist.clamp(max=1000)
            dist = torch.log1p(dist)
        # Subtract outside the no_grad block so gradients still flow through `score`.
        bias = dist * scale[h] * qk_scale
        return score - bias
    return patch_coords_rel_bias
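# Hedged usage sketch of the score_mod on its own (shapes, head count, and the
# random `coords` tensor below are illustrative assumptions):
#
#   B, H, L, E = 2, 12, 64, 32
#   q, k, v = (torch.randn(B, H, L, E) for _ in range(3))
#   coords = torch.rand(B, L, 2) * 100                     # per-token 2D positions
#   score_mod = get_rel_bias_func(generate_alibi_bias(H), coords, qk_scale=E ** -0.5)
#   out = flex_attention(q, k, v, score_mod=score_mod)     # (B, H, L, E)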


def key_padding_mask(mask):
    """Wrap a boolean key-padding mask (True = padded) as a flex_attention mask_mod."""
    def padding_mask(b, h, q_idx, kv_idx):
        # Keep a (q_idx, kv_idx) pair only when the key position is not padding.
        return ~mask[b, kv_idx]
    return padding_mask
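# Hedged sketch of turning a padding mask into a BlockMask (batch size, sequence
# length, and the `pad` tensor are illustrative assumptions):
#
#   B, H, L = 2, 12, 64
#   pad = torch.zeros(B, L, dtype=torch.bool)
#   pad[:, 48:] = True                            # mark the tail tokens as padding
#   block_mask = create_block_mask(key_padding_mask(pad), B, H, L, L, device=pad.device)
#   # block_mask can then be passed to flex_attention(..., block_mask=block_mask)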


class FlexCore(nn.Module):
    """
    Thin wrapper around flex_attention so a "forward hook" (e.g. for attention-map
    visualisation) can be attached to this module.
    """
    def forward(self, q, k, v, score_mod=None, block_mask=None, return_lse=False):
        """
        `return_lse=True` is meant to be used together with the ATTN_MAP_VIS wrapper.
        flex_attention never returns the full attention score matrix; with
        `return_lse=True` it returns the attention output together with the per-row
        logsumexp, which the hook can consume.
        """
        return flex_attention(q, k, v, score_mod=score_mod, block_mask=block_mask, return_lse=return_lse)
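# Hedged sketch of the intended forward-hook pattern; the hook body and the
# `captured` dict are illustrative assumptions, not part of this module:
#
#   core = FlexCore()
#   captured = {}
#
#   def attn_map_vis_hook(module, inputs, output):
#       # With return_lse=True the output is a (attention output, logsumexp) tuple.
#       if isinstance(output, tuple):
#           captured["lse"] = output[1].detach()
#
#   handle = core.register_forward_hook(attn_map_vis_hook)
#   # ... run the model with return_lse=True, then handle.remove()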


class Flex_Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 12,
        qkv_bias: bool = True,
        proj_drop: float = 0.,
        use_rel_bias: bool = True,
    ):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.scale = self.head_dim ** -0.5
        self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
        if qkv_bias:
            self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
        else:
            self.in_proj_bias = None

        self.f_attn = FlexCore()

        self.out_proj = nn.Linear(dim, dim)
        self.out_drop = nn.Dropout(proj_drop)

        self.max_distance = 16  # currently unused in this module

    def build_rel_bias(self, coords):
        # Dense (L, L) log-distance matrix, kept for reference/debugging; the
        # flex_attention path computes the per-pair log-distance inside the score_mod.
        return torch.log1p(torch.cdist(coords, coords, p=2))

    def forward(self, x, coords=None, attn_mask: Optional[torch.Tensor] = None, return_attn_score=False):
        N, L, C = x.shape

        x_proj = F.linear(x, self.in_proj_weight, self.in_proj_bias).contiguous()
        q, k, v = [t.contiguous() for t in x_proj.chunk(3, -1)]

        q = q.reshape(N, L, self.num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
        k = k.reshape(N, L, self.num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
        v = v.reshape(N, L, self.num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()

        if attn_mask is not None:
            # attn_mask: (N, L) boolean key-padding mask, True marks padded positions.
            block_mask = create_block_mask(
                key_padding_mask(attn_mask), N, self.num_heads, L, L, device=x.device
            )
        else:
            block_mask = None

        qk_scale = q.size(-1) ** -0.5
        if coords is not None:
            slopes = generate_alibi_bias(self.num_heads).to(coords.device)
            score_mod = get_rel_bias_func(slopes, coords, qk_scale)
        else:
            score_mod = None

        x = self.f_attn(
            q, k, v,
            score_mod=score_mod,
            block_mask=block_mask,
            return_lse=return_attn_score,
        )
        if return_attn_score:
            # With return_lse=True, flex_attention returns (output, logsumexp); the
            # logsumexp is consumed by the forward hook, only the output is kept here.
            x, _ = x

        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.reshape(N, L, C).contiguous()

        x = self.out_proj(x)
        x = self.out_drop(x)
        return x
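# Hedged end-to-end usage sketch; the dimensions and tensors below are
# illustrative assumptions, not values taken from the surrounding code base:
#
#   attn = Flex_Attention(dim=384, num_heads=12)
#   x = torch.randn(2, 64, 384)                   # (batch, tokens, channels)
#   coords = torch.rand(2, 64, 2) * 100           # per-token 2D coordinates
#   pad = torch.zeros(2, 64, dtype=torch.bool)    # True = padded key position
#   pad[:, 48:] = True
#   y = attn(x, coords=coords, attn_mask=pad)     # -> (2, 64, 384)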