import math
from typing import Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention.flex_attention import (
    _DEFAULT_SPARSE_BLOCK_SIZE,
    _vmap_for_bhqkv,
    create_block_mask,
    create_mask,
    flex_attention,
)

try:
    from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
except ImportError:
    from torch._higher_order_ops.flex_attention import TransformGetItemToIndex
from torch._dynamo import disable


def generate_alibi_bias(H=12):
    """Per-head ALiBi-style slopes: 2 ** (-(h + 1) / H) for h in [0, H)."""
    alibi_bias = torch.tensor([-((h + 1) / H) for h in range(H)])
    return torch.exp2(alibi_bias)


def get_rel_bias_func(scale, coords=None, qk_scale=1.0):
    """Build a flex_attention score_mod that penalizes distant patch pairs.

    `coords` has shape (B, L, 2); `scale` holds one slope per head.
    """
    def patch_coords_rel_bias(score, b, h, q_idx, kv_idx):
        if coords is None:
            return score
        with torch.no_grad():
            dx = coords[b, q_idx][0] - coords[b, kv_idx][0]
            dy = coords[b, q_idx][1] - coords[b, kv_idx][1]
            dist = torch.sqrt(dx * dx + dy * dy)
            dist = dist.clamp(max=1000)  # cap the maximum distance
            dist = torch.log1p(dist)
        bias = dist * scale[h] * qk_scale
        return score - bias  # closer pairs → larger score

    return patch_coords_rel_bias


def key_padding_mask(mask):
    """Build a flex_attention mask_mod from a (B, L) boolean padding mask.

    Positions where `mask[b, kv_idx]` is True (padding) are masked out.
    """
    def padding_mask(b, h, q_idx, kv_idx):
        return ~mask[b, kv_idx]

    return padding_mask


class FlexCore(nn.Module):
    """Thin wrapper around flex_attention so a forward hook can be attached."""

    def forward(self, q, k, v, score_mod=None, block_mask=None, return_lse=False):
        """`return_lse=True` should be used with the ATTN_MAP_VIS wrapper.

        With return_lse=True, flex_attention returns (attention output, logsumexp);
        the logsumexp is not a full attention-score map.
        """
        return flex_attention(
            q, k, v, score_mod=score_mod, block_mask=block_mask, return_lse=return_lse
        )


class Flex_Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 12,
        qkv_bias: bool = True,
        proj_drop: float = 0.,
        use_rel_bias: bool = True,
    ):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
        if qkv_bias:
            self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
        else:
            self.in_proj_bias = None

        self.f_attn = FlexCore()
        self.out_proj = nn.Linear(dim, dim)
        self.out_drop = nn.Dropout(proj_drop)
        self.max_distance = 16

    def build_rel_bias(self, coords):
        return torch.log1p(torch.cdist(coords, coords, p=2))

    def forward(self, x, coords=None, attn_mask: Optional[torch.Tensor] = None, return_attn_score=False):
        N, L, C = x.shape

        # Fused QKV projection; ensure contiguous chunks before reshaping.
        x_proj = F.linear(x, self.in_proj_weight, self.in_proj_bias).contiguous()
        q, k, v = [t.contiguous() for t in x_proj.chunk(3, -1)]
        q = q.reshape(N, L, self.num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
        k = k.reshape(N, L, self.num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
        v = v.reshape(N, L, self.num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()

        block_mask = None
        if attn_mask is not None:
            # Build the padding block mask on the same device as the inputs.
            block_mask = create_block_mask(
                key_padding_mask(attn_mask), N, self.num_heads, L, L, device=x.device
            )

        qk_scale = q.size(-1) ** -0.5
        score_mod = None
        if coords is not None:
            score_mod = get_rel_bias_func(
                generate_alibi_bias(self.num_heads).to(coords.device), coords, qk_scale
            )

        x = self.f_attn(
            q, k, v,
            score_mod=score_mod,
            block_mask=block_mask,
            return_lse=return_attn_score,
        )
        if return_attn_score:
            # flex_attention returns (output, logsumexp) when return_lse=True;
            # the logsumexp is intended to be captured by a forward hook on FlexCore.
            x, _ = x

        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.reshape(N, L, C).contiguous()
        x = self.out_proj(x)
        x = self.out_drop(x)
        return x
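

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module):
# the batch size, sequence length, embedding dim, random coords, and the
# padding pattern below are assumptions chosen for a quick smoke test.
# A CUDA device is preferred; eager flex_attention is used as-is.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    N, L, C, H = 2, 128, 768, 12                       # assumed batch, tokens, dim, heads
    x = torch.randn(N, L, C, device=device)
    coords = torch.rand(N, L, 2, device=device) * 32   # assumed 2-D patch coordinates
    pad = torch.zeros(N, L, dtype=torch.bool, device=device)
    pad[:, -16:] = True                                # pretend the last 16 tokens are padding

    attn = Flex_Attention(dim=C, num_heads=H).to(device)
    out = attn(x, coords=coords, attn_mask=pad)
    print(out.shape)                                   # expected: torch.Size([2, 128, 768])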