import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

# NOTE: `sdpa_ctx_prefer_flash` (an SDPA backend context manager) and
# `AsciiCodec` (the byte/ASCII tokenizer) are referenced below but are defined
# elsewhere in this repository; their imports are assumed to live here.

class CausalSelfAttention(nn.Module):
    def __init__(self, dim: int, n_heads: int, attn_dropout: float = 0.0):
        super().__init__()
        assert dim % n_heads == 0
        self.nh = n_heads; self.hd = dim // n_heads
        self.qkv = nn.Linear(dim, 3*dim, bias=False)
        self.proj = nn.Linear(dim, dim, bias=False)
        self.attn_dropout = attn_dropout
    def forward(self, x):
        B,T,C = x.shape
        qkv = self.qkv(x); q,k,v = qkv.chunk(3, dim=-1)
        q = q.view(B,T,self.nh,self.hd).transpose(1,2)
        k = k.view(B,T,self.nh,self.hd).transpose(1,2)
        v = v.view(B,T,self.nh,self.hd).transpose(1,2)
        if x.is_cuda:
            # Fused kernel path: sdpa_ctx_prefer_flash is expected to prefer the
            # flash attention backend where it is available.
            with sdpa_ctx_prefer_flash():
                y = F.scaled_dot_product_attention(q,k,v,is_causal=True,
                    dropout_p=self.attn_dropout if self.training else 0.0)
        else:
            # CPU fallback: explicit causal mask + softmax (no attention dropout here).
            scale = 1.0 / math.sqrt(self.hd)
            att = (q @ k.transpose(-2,-1)) * scale
            mask = torch.full((1,1,T,T), float("-inf"), device=x.device)
            mask = torch.triu(mask, diagonal=1)   # keep -inf strictly above the diagonal
            att = (att + mask).softmax(dim=-1)
            y = att @ v
        y = y.transpose(1,2).contiguous().view(B,T,C)
        return self.proj(y)

def _normalize_cell(X):  # X: [V,D]
    # Center the cell's points and rescale so their RMS radius is 1.
    Xc = X - X.mean(dim=0, keepdim=True)
    r = Xc.pow(2).sum(dim=1).mean().sqrt().clamp_min(1e-6)
    return Xc / r

class CrystalBank(nn.Module):
    def __init__(self, regions: int, dim: int):
        super().__init__()
        pts = torch.randn(regions, 5, dim) / math.sqrt(dim)
        with torch.no_grad():
            for i in range(regions): pts[i] = _normalize_cell(pts[i])
        self.anchors = nn.Parameter(pts)  # [C,5,D]

class GeometricGate(nn.Module):
    def __init__(self, dim: int, regions: int, tau: float = 0.08):
        super().__init__()
        self.bank = CrystalBank(regions, dim)
        self.nav  = nn.Linear(dim, dim, bias=False)
        self.tau  = tau
        self.scale = nn.Parameter(torch.tensor(1.0))  # residual mix scaler

    def forward(self, h: torch.Tensor, punct_mask: Optional[torch.Tensor] = None, alpha_gate: float = 1.0, hard_mask_gate: bool = False):
        # h: [B,T,D] hidden states. Returns routing weights w [B,T,C] and a
        # gate tensor g [B,T,D] that is added to the residual stream.
        B,T,D = h.shape
        C = self.bank.anchors.size(0)
        H = self.nav(h).reshape(B*T, D)           # [BT,D] navigated tokens
        A = self.bank.anchors.reshape(C*5, D)     # [C*5,D] all anchor points
        # squared distances via expansion
        x2 = (H*H).sum(dim=-1, keepdim=True)      # [BT,1]
        a2 = (A*A).sum(dim=-1).unsqueeze(0)       # [1,C*5]
        xa = H @ A.T                               # [BT,C*5]
        d2 = (x2 + a2 - 2*xa).clamp_min(0.0).view(B*T, C, 5)
        # Smoothed minimum of d2/tau over the 5 anchor points of each region,
        # followed by a softmax over regions to obtain routing weights.
        s = -torch.logsumexp(-d2 / max(1e-6,self.tau), dim=-1)  # [BT,C]
        w = F.softmax(-s / max(1e-6,self.tau), dim=-1).view(B,T,C)

        centroids = self.bank.anchors.mean(dim=1)  # [C,D]
        g = (w @ centroids)                        # [B,T,D]

        if punct_mask is not None:
            # alpha soft mask reduces gate on punctuation tokens
            pm = punct_mask.float().unsqueeze(-1)  # [B,T,1], 1 on punct
            if hard_mask_gate:
                g = g * (1.0 - pm)                 # zero gate on punct
            else:
                g = g * (1.0 - pm*(1.0 - alpha_gate))

        return w, self.scale * g

class MLP(nn.Module):
    def __init__(self, dim, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        hidden = int(dim*mlp_ratio)
        self.fc1 = nn.Linear(dim, hidden)
        self.fc2 = nn.Linear(hidden, dim)
        self.drop = nn.Dropout(dropout)
    def forward(self, x):
        x = self.fc1(x)
        x = F.gelu(x, approximate="tanh")
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class GatedBlock(nn.Module):
    def __init__(self, dim, n_heads, mlp_ratio, dropout, regions):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn  = CausalSelfAttention(dim, n_heads, attn_dropout=dropout)
        self.gate  = GeometricGate(dim, regions=regions, tau=0.08)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp   = MLP(dim, mlp_ratio=mlp_ratio, dropout=dropout)
        self.mix_sdpa = nn.Parameter(torch.tensor(1.0))  # stage-adjustable

    def forward(self, x, punct_mask=None, return_gate=False, alpha_gate=1.0, hard_mask_gate=False):
        h = self.norm1(x)
        att = self.attn(h) * self.mix_sdpa       # attention branch, stage-adjustable weight
        w, g = self.gate(h, punct_mask=punct_mask, alpha_gate=alpha_gate, hard_mask_gate=hard_mask_gate)
        x = x + att + g                          # residual update: attention + geometric gate
        x = x + self.mlp(self.norm2(x))
        return (x, w) if return_gate else x

class CrystalBeeper(nn.Module):
    def __init__(self, model_cfg: dict):
        super().__init__()
        self.cfg = model_cfg
        D, L, H = model_cfg["dim"], model_cfg["n_layers"], model_cfg["n_heads"]
        ctx     = model_cfg["context"]

        # Ingress
        self.use_ascii = bool(model_cfg.get("use_ascii", True))
        self.codec = AsciiCodec()
        V_ascii = self.codec.vocab_size

        self.token_emb = nn.Embedding(V_ascii, D)
        self.pos_emb   = nn.Parameter(torch.zeros(1, ctx, D))
        self.drop      = nn.Dropout(model_cfg.get("resid_dropout", 0.1))

        regions = [int(model_cfg.get("regions_per_block", 64)) for _ in range(L)]
        self.blocks = nn.ModuleList([
            GatedBlock(D, H, model_cfg["mlp_ratio"], model_cfg["dropout"], regions[i])
            for i in range(L)
        ])
        self.norm = nn.LayerNorm(D)

        # Output heads
        self.ascii_head = nn.Linear(D, V_ascii, bias=False)
        self.bpe_head   = nn.Linear(D, model_cfg.get("vocab_size", 8192), bias=False)  # optional

        # Global tug (Rose)
        self.rose_proj    = nn.Linear(D, D, bias=False)
        self.rose_anchors = nn.Parameter(torch.randn(3, D) / math.sqrt(D))

        self.apply(self._init)

    @staticmethod
    def _init(m):
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if getattr(m, "bias", None) is not None: nn.init.zeros_(m.bias)

    def backbone(self, idx, punct_mask=None, return_routes=False, alpha_gate=1.0, hard_mask_gate=False):
        B,T = idx.shape
        x = self.token_emb(idx) + self.pos_emb[:, :T, :]
        x = self.drop(x)
        routes = []
        for blk in self.blocks:
            x, w = blk(x, punct_mask=punct_mask, return_gate=True, alpha_gate=alpha_gate, hard_mask_gate=hard_mask_gate)
            if return_routes: routes.append(w)
        x = self.norm(x)
        return (x, routes) if return_routes else (x, None)

    def forward(self, idx, punct_mask=None, head="ascii", return_routes=False, alpha_gate=1.0, hard_mask_gate=False):
        h, routes = self.backbone(idx, punct_mask=punct_mask, return_routes=return_routes, alpha_gate=alpha_gate, hard_mask_gate=hard_mask_gate)
        if head == "ascii": logits = self.ascii_head(h)
        elif head == "bpe": logits = self.bpe_head(h)
        else: raise ValueError("head must be 'ascii' or 'bpe'")
        return logits, routes
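
# --- Minimal usage sketch (not part of the original module) ---
# A hedged example of how CrystalBeeper might be instantiated and exercised on
# CPU. The config keys mirror the constructor above; batch/sequence sizes and
# the punct_mask construction are illustrative only, and AsciiCodec is assumed
# to expose `vocab_size` exactly as used in __init__.
if __name__ == "__main__":
    cfg = {
        "dim": 128, "n_layers": 2, "n_heads": 4, "context": 64,
        "mlp_ratio": 4.0, "dropout": 0.1, "resid_dropout": 0.1,
        "regions_per_block": 16, "use_ascii": True,
    }
    model = CrystalBeeper(cfg)
    model.eval()  # disable dropout for a deterministic smoke test
    idx = torch.randint(0, model.codec.vocab_size, (2, 32))   # [B,T] token ids
    punct = torch.zeros(2, 32, dtype=torch.bool)              # no punctuation marked
    logits, routes = model(idx, punct_mask=punct, head="ascii", return_routes=True)
    print(logits.shape)                   # expected: [2, 32, vocab_size]
    print(len(routes), routes[0].shape)   # per-block routing weights, each [B,T,regions]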