AbstractPhil committed
Commit
4c9ea57
·
verified ·
1 Parent(s): 2c70d6c

Create model.py

Files changed (1)
  1. model.py +165 -0
model.py ADDED
@@ -0,0 +1,165 @@
+ import math
+ from typing import Optional
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ # NOTE: sdpa_ctx_prefer_flash and AsciiCodec are used below but are not defined in
+ # this file; they are expected to be provided elsewhere in the repository.
+
+ class CausalSelfAttention(nn.Module):
+     def __init__(self, dim: int, n_heads: int, attn_dropout: float = 0.0):
+         super().__init__()
+         assert dim % n_heads == 0
+         self.nh = n_heads; self.hd = dim // n_heads
+         self.qkv = nn.Linear(dim, 3*dim, bias=False)
+         self.proj = nn.Linear(dim, dim, bias=False)
+         self.attn_dropout = attn_dropout
+     def forward(self, x):
+         B,T,C = x.shape
+         qkv = self.qkv(x); q,k,v = qkv.chunk(3, dim=-1)
+         q = q.view(B,T,self.nh,self.hd).transpose(1,2)
+         k = k.view(B,T,self.nh,self.hd).transpose(1,2)
+         v = v.view(B,T,self.nh,self.hd).transpose(1,2)
+         if x.is_cuda:
+             # fused SDPA path (flash attention preferred) on GPU
+             with sdpa_ctx_prefer_flash():
+                 y = F.scaled_dot_product_attention(q, k, v, is_causal=True,
+                                                    dropout_p=self.attn_dropout if self.training else 0.0)
+         else:
+             # explicit causal attention fallback on CPU
+             scale = 1.0 / math.sqrt(self.hd)
+             att = (q @ k.transpose(-2,-1)) * scale
+             mask = torch.full((1,1,T,T), float("-inf"), device=x.device)
+             mask = torch.triu(mask, diagonal=1)
+             att = (att + mask).softmax(dim=-1)
+             y = att @ v
+         y = y.transpose(1,2).contiguous().view(B,T,C)
+         return self.proj(y)
+
+ def _normalize_cell(X):  # X: [V,D]
+     Xc = X - X.mean(dim=0, keepdim=True)
+     r = Xc.pow(2).sum(dim=1).mean().sqrt().clamp_min(1e-6)
+     return Xc / r
+
+ class CrystalBank(nn.Module):
+     def __init__(self, regions: int, dim: int):
+         super().__init__()
+         # 5 learnable anchor points per region, centered and scaled to unit RMS radius
+         pts = torch.randn(regions, 5, dim) / math.sqrt(dim)
+         with torch.no_grad():
+             for i in range(regions): pts[i] = _normalize_cell(pts[i])
+         self.anchors = nn.Parameter(pts)  # [C,5,D]
+
+ class GeometricGate(nn.Module):
+     def __init__(self, dim: int, regions: int, tau: float = 0.08):
+         super().__init__()
+         self.bank = CrystalBank(regions, dim)
+         self.nav = nn.Linear(dim, dim, bias=False)
+         self.tau = tau
+         self.scale = nn.Parameter(torch.tensor(1.0))  # residual mix scaler
+
+     def forward(self, h: torch.Tensor, punct_mask: Optional[torch.Tensor] = None, alpha_gate: float = 1.0, hard_mask_gate: bool = False):
+         B,T,D = h.shape
+         C = self.bank.anchors.size(0)
+         H = self.nav(h).reshape(B*T, D)        # [BT,D]
+         A = self.bank.anchors.reshape(C*5, D)  # [C*5,D]
+         # squared distances via expansion: ||x-a||^2 = ||x||^2 + ||a||^2 - 2<x,a>
+         x2 = (H*H).sum(dim=-1, keepdim=True)   # [BT,1]
+         a2 = (A*A).sum(dim=-1).unsqueeze(0)    # [1,C*5]
+         xa = H @ A.T                           # [BT,C*5]
+         d2 = (x2 + a2 - 2*xa).clamp_min(0.0).view(B*T, C, 5)
+         # soft-min over the 5 anchors of each region, then softmax over regions
+         s = -torch.logsumexp(-d2 / max(1e-6, self.tau), dim=-1)      # [BT,C]
+         w = F.softmax(-s / max(1e-6, self.tau), dim=-1).view(B,T,C)
+
+         centroids = self.bank.anchors.mean(dim=1)  # [C,D]
+         g = w @ centroids                          # [B,T,D]
+
+         if punct_mask is not None:
+             # alpha soft mask reduces the gate on punctuation tokens
+             pm = punct_mask.float().unsqueeze(-1)  # [B,T,1], 1 on punct
+             if hard_mask_gate:
+                 g = g * (1.0 - pm)                 # zero gate on punct
+             else:
+                 g = g * (1.0 - pm*(1.0 - alpha_gate))
+
+         return w, self.scale * g
+
+ class MLP(nn.Module):
+     def __init__(self, dim, mlp_ratio=4.0, dropout=0.1):
+         super().__init__()
+         hidden = int(dim*mlp_ratio)
+         self.fc1 = nn.Linear(dim, hidden)
+         self.fc2 = nn.Linear(hidden, dim)
+         self.drop = nn.Dropout(dropout)
+     def forward(self, x):
+         x = self.fc1(x)
+         x = F.gelu(x, approximate="tanh")
+         x = self.drop(x)
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
+
+ class GatedBlock(nn.Module):
+     def __init__(self, dim, n_heads, mlp_ratio, dropout, regions):
+         super().__init__()
+         self.norm1 = nn.LayerNorm(dim)
+         self.attn = CausalSelfAttention(dim, n_heads, attn_dropout=dropout)
+         self.gate = GeometricGate(dim, regions=regions, tau=0.08)
+         self.norm2 = nn.LayerNorm(dim)
+         self.mlp = MLP(dim, mlp_ratio=mlp_ratio, dropout=dropout)
+         self.mix_sdpa = nn.Parameter(torch.tensor(1.0))  # stage-adjustable
+
+     def forward(self, x, punct_mask=None, return_gate=False, alpha_gate=1.0, hard_mask_gate=False):
+         h = self.norm1(x)
+         att = self.attn(h) * self.mix_sdpa
+         w, g = self.gate(h, punct_mask=punct_mask, alpha_gate=alpha_gate, hard_mask_gate=hard_mask_gate)
+         x = x + att + g
+         x = x + self.mlp(self.norm2(x))
+         return (x, w) if return_gate else x
+
+ class CrystalBeeper(nn.Module):
+     def __init__(self, model_cfg: dict):
+         super().__init__()
+         self.cfg = model_cfg
+         D, L, H = model_cfg["dim"], model_cfg["n_layers"], model_cfg["n_heads"]
+         ctx = model_cfg["context"]
+
+         # Ingress
+         self.use_ascii = bool(model_cfg.get("use_ascii", True))
+         self.codec = AsciiCodec()
+         V_ascii = self.codec.vocab_size
+
+         self.token_emb = nn.Embedding(V_ascii, D)
+         self.pos_emb = nn.Parameter(torch.zeros(1, ctx, D))
+         self.drop = nn.Dropout(model_cfg.get("resid_dropout", 0.1))
+
+         regions = [int(model_cfg.get("regions_per_block", 64)) for _ in range(L)]
+         self.blocks = nn.ModuleList([
+             GatedBlock(D, H, model_cfg["mlp_ratio"], model_cfg["dropout"], regions[i])
+             for i in range(L)
+         ])
+         self.norm = nn.LayerNorm(D)
+
+         # Output heads
+         self.ascii_head = nn.Linear(D, V_ascii, bias=False)
+         self.bpe_head = nn.Linear(D, model_cfg.get("vocab_size", 8192), bias=False)  # optional
+
+         # Global tug (Rose)
+         self.rose_proj = nn.Linear(D, D, bias=False)
+         self.rose_anchors = nn.Parameter(torch.randn(3, D) / math.sqrt(D))
+
+         self.apply(self._init)
+
+     @staticmethod
+     def _init(m):
+         if isinstance(m, (nn.Linear, nn.Embedding)):
+             nn.init.normal_(m.weight, mean=0.0, std=0.02)
+             if getattr(m, "bias", None) is not None: nn.init.zeros_(m.bias)
+
+     def backbone(self, idx, punct_mask=None, return_routes=False, alpha_gate=1.0, hard_mask_gate=False):
+         B,T = idx.shape
+         x = self.token_emb(idx) + self.pos_emb[:, :T, :]
+         x = self.drop(x)
+         routes = []
+         for blk in self.blocks:
+             x, w = blk(x, punct_mask=punct_mask, return_gate=True, alpha_gate=alpha_gate, hard_mask_gate=hard_mask_gate)
+             if return_routes: routes.append(w)
+         x = self.norm(x)
+         return (x, routes) if return_routes else (x, None)
+
+     def forward(self, idx, punct_mask=None, head="ascii", return_routes=False, alpha_gate=1.0, hard_mask_gate=False):
+         h, routes = self.backbone(idx, punct_mask=punct_mask, return_routes=return_routes, alpha_gate=alpha_gate, hard_mask_gate=hard_mask_gate)
+         if head == "ascii": logits = self.ascii_head(h)
+         elif head == "bpe": logits = self.bpe_head(h)
+         else: raise ValueError("head must be 'ascii' or 'bpe'")
+         return logits, routes
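
For reference, a minimal usage sketch of the committed module follows. It assumes this model.py is importable together with whatever module provides AsciiCodec and sdpa_ctx_prefer_flash (neither is part of this commit), and the config values shown are illustrative placeholders mirroring the keys read in CrystalBeeper.__init__, not the defaults used in training.

# Minimal usage sketch (assumptions: model.py is on the path along with the
# AsciiCodec / sdpa_ctx_prefer_flash helpers it references; config values are
# illustrative only).
import torch
from model import CrystalBeeper

cfg = {
    "dim": 256, "n_layers": 4, "n_heads": 4, "context": 512,
    "mlp_ratio": 4.0, "dropout": 0.1, "resid_dropout": 0.1,
    "regions_per_block": 64, "vocab_size": 8192, "use_ascii": True,
}
model = CrystalBeeper(cfg)

idx = torch.randint(0, model.codec.vocab_size, (2, 128))  # [B,T] ASCII token ids
punct = torch.zeros(2, 128, dtype=torch.bool)             # True where a token is punctuation
logits, routes = model(idx, punct_mask=punct, head="ascii", return_routes=True)
print(logits.shape)                   # [2, 128, V_ascii]
print(len(routes), routes[0].shape)   # n_layers, [2, 128, regions_per_block]

With return_routes=True, each entry of routes is the per-layer region-assignment weight tensor w produced by GeometricGate, which is useful for inspecting how tokens are routed across the crystal regions.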