from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def generate_alibi_bias(H=12):
    """Per-head ALiBi-style slopes: head h gets slope 2**(-(h + 1) / H).

    For H=12 the slopes run from 2**(-1/12) ≈ 0.944 down to 2**(-1) = 0.5.
    """
    exponents = -torch.arange(1, H + 1, dtype=torch.float32) / H
    return torch.exp2(exponents)


def get_rel_bias_func(scale, coords=None, qk_scale=1.0):
    """Build a flex_attention score_mod that penalizes spatially distant patch pairs.

    `scale` holds one slope per head (here produced by generate_alibi_bias),
    `coords` is a (B, L, 2) tensor of per-token coordinates, and `qk_scale`
    matches the softmax scaling applied to the raw attention scores.
    """
    def patch_coords_rel_bias(score, b, h, q_idx, kv_idx):
        if coords is None:
            return score
        with torch.no_grad():
            # Euclidean distance between the query and key token coordinates.
            dx = coords[b, q_idx][0] - coords[b, kv_idx][0]
            dy = coords[b, q_idx][1] - coords[b, kv_idx][1]
            dist = torch.sqrt(dx * dx + dy * dy)
            dist = dist.clamp(max=1000)  # cap the maximum distance
            dist = torch.log1p(dist)     # compress large distances
            bias = dist * scale[h] * qk_scale
        return score - bias  # closer → larger score
    return patch_coords_rel_bias


def key_padding_mask(mask):
    """Build a flex_attention mask_mod from a (B, L) boolean mask where True marks padded keys."""
    def padding_mask(b, h, q_idx, kv_idx):
        return ~mask[b, kv_idx]
    return padding_mask


class FlexCore(nn.Module):
    """
        For using "forward hook"
    """
    def forward(self, q, k, v, score_mod=None, block_mask=None, return_lse=False):
        """
            "return_lse=True" should be used with ATTN_MAP_VIS wrapper.
            Though return_lse is "True", _flex_attention(...) only have an output (attention output, not attention scores).
        """
        return flex_attention(q, k, v, score_mod=score_mod, block_mask=block_mask, return_lse=return_lse)
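

# A minimal sketch of the forward-hook pattern mentioned in FlexCore's docstring:
# stash the log-sum-exp that flex_attention returns when return_lse=True, e.g. for
# attention-map visualization. The helper name and the `store` dict are illustrative
# assumptions, not the actual ATTN_MAP_VIS API.
def register_lse_hook(core: FlexCore, store: dict):
    def _grab_lse(module, inputs, output):
        # With return_lse=True, flex_attention returns (attn_out, lse).
        if isinstance(output, tuple) and len(output) == 2:
            store["lse"] = output[1].detach()
    return core.register_forward_hook(_grab_lse)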


class Flex_Attention(nn.Module):
    def __init__(
            self,
            dim: int,
            num_heads: int = 12,
            qkv_bias: bool = True,
            proj_drop: float = 0.,
            use_rel_bias: bool = True,
    ):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.scale = self.head_dim ** -0.5
        self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
        if qkv_bias:
            self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
        else:
            self.in_proj_bias = None
            
        self.f_attn = FlexCore()

        self.out_proj = nn.Linear(dim, dim)
        self.out_drop = nn.Dropout(proj_drop)
        
        self.max_distance = 16

    def build_rel_bias(self, coords):
        # Dense log-distance matrix over the patch coordinates (not used by forward()).
        return torch.log1p(torch.cdist(coords, coords, p=2))

    def forward(self, x, coords=None, attn_mask: Optional[torch.Tensor] = None, return_attn_score=False):
        N, L, C = x.shape
        # ensure contiguous projection before chunking
        x_proj = F.linear(x, self.in_proj_weight, self.in_proj_bias).contiguous()
        q, k, v = [t.contiguous() for t in x_proj.chunk(3, -1)]

        q = q.reshape(N, L, self.num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
        k = k.reshape(N, L, self.num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
        v = v.reshape(N, L, self.num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
        
        # Build the padding block mask, if a key-padding mask was given.
        block_mask = None
        if attn_mask is not None:
            block_mask = create_block_mask(
                key_padding_mask(attn_mask), N, self.num_heads, L, L, device=q.device
            )

        qk_scale = q.size(-1) ** -0.5
        score_mod = None
        if coords is not None:
            score_mod = get_rel_bias_func(
                generate_alibi_bias(self.num_heads).to(coords.device), coords, qk_scale
            )

        x = self.f_attn(
            q, k, v,
            score_mod=score_mod,
            block_mask=block_mask,
            return_lse=return_attn_score,
        )
        if return_attn_score and isinstance(x, tuple):
            # flex_attention returns (output, logsumexp) when return_lse=True; the lse is
            # left for a forward hook on FlexCore, so only the output is carried forward.
            x = x[0]
        
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.reshape(N, L, C).contiguous()

        x = self.out_proj(x)
        x = self.out_drop(x)
        return x
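

if __name__ == "__main__":
    # Minimal usage sketch / smoke test. The shapes, the 14x14 patch grid, and the
    # padding pattern below are illustrative assumptions, not values from the module.
    # Eager flex_attention support differs across PyTorch versions and devices, so we
    # prefer CUDA when it is available.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.manual_seed(0)

    B, Hp, Wp, D = 2, 14, 14, 768
    L = Hp * Wp

    # (row, col) coordinates for every patch token, shared across the batch.
    ys, xs = torch.meshgrid(torch.arange(Hp), torch.arange(Wp), indexing="ij")
    coords = torch.stack([ys, xs], dim=-1).reshape(1, L, 2).float()
    coords = coords.expand(B, -1, -1).contiguous().to(device)

    # True marks padded keys to be ignored; here the last 5 tokens of sample 1.
    pad = torch.zeros(B, L, dtype=torch.bool, device=device)
    pad[1, -5:] = True

    attn = Flex_Attention(dim=D, num_heads=12).to(device)
    x = torch.randn(B, L, D, device=device)

    # Optionally capture the log-sum-exp via the hypothetical hook helper above.
    captured = {}
    handle = register_lse_hook(attn.f_attn, captured)
    out = attn(x, coords=coords, attn_mask=pad, return_attn_score=True)
    handle.remove()

    print(out.shape)                                # expected: torch.Size([2, 196, 768])
    print(captured.get("lse", None) is not None)    # True if the lse was captured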