from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
def generate_alibi_bias(H=12):
    """Per-head slopes in the spirit of ALiBi: head h gets 2 ** (-(h + 1) / H)."""
    alibi_bias = torch.tensor([-(h + 1) / H for h in range(H)])
    return torch.exp2(alibi_bias)
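# Example (illustrative): for H = 4 the slopes are 2**(-1/4), 2**(-2/4),
# 2**(-3/4), 2**(-1) ≈ 0.841, 0.707, 0.595, 0.500, so the first head applies
# the strongest distance penalty and later heads progressively weaker ones.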
def get_rel_bias_func(scale, coords=None, qk_scale=1.0):
    """Build a flex_attention score_mod that subtracts a log-distance bias.

    `scale` holds one slope per head (see generate_alibi_bias) and `coords` is a
    (B, L, 2) tensor of patch coordinates. If `coords` is None the score is
    returned unchanged.
    """
    def patch_coords_rel_bias(score, b, h, q_idx, kv_idx):
        if coords is None:
            return score
        with torch.no_grad():
            dx = coords[b, q_idx][0] - coords[b, kv_idx][0]
            dy = coords[b, q_idx][1] - coords[b, kv_idx][1]
            dist = torch.sqrt(dx * dx + dy * dy)
            dist = dist.clamp(max=1000)  # cap the distance so the penalty stays bounded
            dist = torch.log1p(dist)
        bias = dist * scale[h] * qk_scale
        return score - bias  # closer patches keep a larger score
    return patch_coords_rel_bias
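# Illustrative values (assumption about typical inputs): with per-head slope
# s = scale[h] * qk_scale, a key at Euclidean distance d from the query is
# penalized by s * log1p(min(d, 1000)); d = 0 gives no penalty, and far-away
# patches are progressively down-weighted.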
def key_padding_mask(mask):
    """Turn a (B, L) boolean padding mask (True = padded key) into a mask_mod
    for create_block_mask: a key may be attended to iff it is not padded."""
    def padding_mask(b, h, q_idx, kv_idx):
        return ~mask[b, kv_idx]
    return padding_mask
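# Usage sketch (illustrative, names are placeholders): given a boolean padding
# mask `pad` of shape (B, L) with True marking padded keys,
#   block_mask = create_block_mask(key_padding_mask(pad), B, None, L, L)
# passing None for H broadcasts the same mask over all heads; forward() below
# passes self.num_heads explicitly, which is equivalent for this mask_mod.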
class FlexCore(nn.Module):
    """Thin wrapper around flex_attention so a forward hook can be registered on it."""
    def forward(self, q, k, v, score_mod=None, block_mask=None, return_lse=False):
        """
        return_lse=True is meant to be used together with an attention-map
        visualization wrapper (ATTN_MAP_VIS). Note that with return_lse=True
        flex_attention returns (output, logsumexp): the log-sum-exp of the
        scores, not the attention map itself.
        """
        return flex_attention(q, k, v, score_mod=score_mod, block_mask=block_mask, return_lse=return_lse)
class Flex_Attention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 12,
qkv_bias: bool = True,
proj_drop: float = 0.,
use_rel_bias: bool = True,
):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.use_rel_bias = use_rel_bias
        self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
        if qkv_bias:
            self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
        else:
            self.in_proj_bias = None
        self.f_attn = FlexCore()
        self.out_proj = nn.Linear(dim, dim)
        self.out_drop = nn.Dropout(proj_drop)
        self.max_distance = 16  # kept for reference; the score_mod clamps distances at 1000 instead
    def build_rel_bias(self, coords):
        """Dense pairwise log-distance bias (not used by forward; kept for reference)."""
        return torch.log1p(torch.cdist(coords, coords, p=2))
def forward(self, x, coords=None, attn_mask: Optional[torch.Tensor] = None, return_attn_score=False):
N, L, C = x.shape
# ensure contiguous projection before chunking
x_proj = F.linear(x, self.in_proj_weight, self.in_proj_bias).contiguous()
q, k, v = [t.contiguous() for t in x_proj.chunk(3, -1)]
q = q.reshape(N, L, self.num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
k = k.reshape(N, L, self.num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
v = v.reshape(N, L, self.num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
        if attn_mask is not None:
            # attn_mask: (N, L) boolean, True = padded key position
            mask_func = create_block_mask(
                key_padding_mask(attn_mask), N, self.num_heads, L, L, device=x.device
            )
        qk_scale = q.size(-1) ** -0.5
        use_bias = self.use_rel_bias and coords is not None
        x = self.f_attn(
            q, k, v,
            score_mod=get_rel_bias_func(generate_alibi_bias(self.num_heads).to(coords.device), coords, qk_scale) if use_bias else None,
            block_mask=mask_func if attn_mask is not None else None,
            return_lse=return_attn_score,
        )
        if return_attn_score:
            # flex_attention returns (output, logsumexp) when return_lse=True; the
            # logsumexp is meant to be captured by a forward hook on self.f_attn.
            x, _lse = x
x = x.permute(0, 2, 1, 3).contiguous()
x = x.reshape(N, L, C).contiguous()
x = self.out_proj(x)
x = self.out_drop(x)
return x
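

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch; shapes, device, and the padding
    # convention are assumptions, and flex_attention performs best on CUDA with
    # torch.compile, so treat this as a sanity check rather than a benchmark).
    torch.manual_seed(0)
    B, L, dim, heads = 2, 64, 192, 12
    device = "cuda" if torch.cuda.is_available() else "cpu"
    attn = Flex_Attention(dim, num_heads=heads).to(device)
    x = torch.randn(B, L, dim, device=device)
    coords = torch.rand(B, L, 2, device=device) * 32          # patch (x, y) positions
    pad = torch.zeros(B, L, dtype=torch.bool, device=device)  # True = padded key
    pad[:, -8:] = True                                        # pad out the last 8 tokens
    out = attn(x, coords=coords, attn_mask=pad)
    print(out.shape)  # expected: torch.Size([2, 64, 192])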