Create model.py
model.py
ADDED
@@ -0,0 +1,165 @@
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

# NOTE: `sdpa_ctx_prefer_flash` (an SDPA-backend context manager) and `AsciiCodec`
# (the byte-level tokenizer) are referenced below but not defined in this file;
# they are assumed to be provided by sibling modules in this repository.


class CausalSelfAttention(nn.Module):
    def __init__(self, dim: int, n_heads: int, attn_dropout: float = 0.0):
        super().__init__()
        assert dim % n_heads == 0
        self.nh = n_heads
        self.hd = dim // n_heads
        self.qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.proj = nn.Linear(dim, dim, bias=False)
        self.attn_dropout = attn_dropout

    def forward(self, x):
        B, T, C = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)
        q = q.view(B, T, self.nh, self.hd).transpose(1, 2)
        k = k.view(B, T, self.nh, self.hd).transpose(1, 2)
        v = v.view(B, T, self.nh, self.hd).transpose(1, 2)
        if x.is_cuda:
            # Fused SDPA path (flash kernels preferred when available).
            with sdpa_ctx_prefer_flash():
                y = F.scaled_dot_product_attention(
                    q, k, v, is_causal=True,
                    dropout_p=self.attn_dropout if self.training else 0.0)
        else:
            # Explicit CPU fallback (note: attention dropout is not applied here).
            scale = 1.0 / math.sqrt(self.hd)
            att = (q @ k.transpose(-2, -1)) * scale
            mask = torch.full((1, 1, T, T), float("-inf"), device=x.device)
            mask = torch.triu(mask, diagonal=1)
            att = (att + mask).softmax(dim=-1)
            y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.proj(y)


def _normalize_cell(X):  # X: [V,D]
    # Center the cell's anchor points and rescale to unit RMS radius.
    Xc = X - X.mean(dim=0, keepdim=True)
    r = Xc.pow(2).sum(dim=1).mean().sqrt().clamp_min(1e-6)
    return Xc / r


class CrystalBank(nn.Module):
    def __init__(self, regions: int, dim: int):
        super().__init__()
        pts = torch.randn(regions, 5, dim) / math.sqrt(dim)
        with torch.no_grad():
            for i in range(regions):
                pts[i] = _normalize_cell(pts[i])
        self.anchors = nn.Parameter(pts)  # [C,5,D]


class GeometricGate(nn.Module):
    def __init__(self, dim: int, regions: int, tau: float = 0.08):
        super().__init__()
        self.bank = CrystalBank(regions, dim)
        self.nav = nn.Linear(dim, dim, bias=False)
        self.tau = tau
        self.scale = nn.Parameter(torch.tensor(1.0))  # residual mix scaler

    def forward(self, h: torch.Tensor, punct_mask: Optional[torch.Tensor] = None,
                alpha_gate: float = 1.0, hard_mask_gate: bool = False):
        B, T, D = h.shape
        C = self.bank.anchors.size(0)
        H = self.nav(h).reshape(B * T, D)        # [BT,D]
        A = self.bank.anchors.reshape(C * 5, D)  # [C*5,D]
        # Squared distances via the ||x||^2 + ||a||^2 - 2<x,a> expansion.
        x2 = (H * H).sum(dim=-1, keepdim=True)   # [BT,1]
        a2 = (A * A).sum(dim=-1).unsqueeze(0)    # [1,C*5]
        xa = H @ A.T                             # [BT,C*5]
        d2 = (x2 + a2 - 2 * xa).clamp_min(0.0).view(B * T, C, 5)
        # Soft-min over each region's 5 anchor points, then softmax over regions.
        s = -torch.logsumexp(-d2 / max(1e-6, self.tau), dim=-1)  # [BT,C]
        w = F.softmax(-s / max(1e-6, self.tau), dim=-1).view(B, T, C)

        centroids = self.bank.anchors.mean(dim=1)  # [C,D]
        g = w @ centroids                          # [B,T,D]

        if punct_mask is not None:
            # Alpha soft mask reduces the gate on punctuation tokens.
            pm = punct_mask.float().unsqueeze(-1)  # [B,T,1], 1 on punct
            if hard_mask_gate:
                g = g * (1.0 - pm)  # zero gate on punct
            else:
                g = g * (1.0 - pm * (1.0 - alpha_gate))

        return w, self.scale * g


class MLP(nn.Module):
    def __init__(self, dim, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        hidden = int(dim * mlp_ratio)
        self.fc1 = nn.Linear(dim, hidden)
        self.fc2 = nn.Linear(hidden, dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = F.gelu(x, approximate="tanh")
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class GatedBlock(nn.Module):
    def __init__(self, dim, n_heads, mlp_ratio, dropout, regions):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = CausalSelfAttention(dim, n_heads, attn_dropout=dropout)
        self.gate = GeometricGate(dim, regions=regions, tau=0.08)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, mlp_ratio=mlp_ratio, dropout=dropout)
        self.mix_sdpa = nn.Parameter(torch.tensor(1.0))  # stage-adjustable

    def forward(self, x, punct_mask=None, return_gate=False, alpha_gate=1.0, hard_mask_gate=False):
        h = self.norm1(x)
        att = self.attn(h) * self.mix_sdpa
        w, g = self.gate(h, punct_mask=punct_mask, alpha_gate=alpha_gate, hard_mask_gate=hard_mask_gate)
        x = x + att + g
        x = x + self.mlp(self.norm2(x))
        return (x, w) if return_gate else x


class CrystalBeeper(nn.Module):
    def __init__(self, model_cfg: dict):
        super().__init__()
        self.cfg = model_cfg
        D, L, H = model_cfg["dim"], model_cfg["n_layers"], model_cfg["n_heads"]
        ctx = model_cfg["context"]

        # Ingress
        self.use_ascii = bool(model_cfg.get("use_ascii", True))
        self.codec = AsciiCodec()
        V_ascii = self.codec.vocab_size

        self.token_emb = nn.Embedding(V_ascii, D)
        self.pos_emb = nn.Parameter(torch.zeros(1, ctx, D))
        self.drop = nn.Dropout(model_cfg.get("resid_dropout", 0.1))

        regions = [int(model_cfg.get("regions_per_block", 64)) for _ in range(L)]
        self.blocks = nn.ModuleList([
            GatedBlock(D, H, model_cfg["mlp_ratio"], model_cfg["dropout"], regions[i])
            for i in range(L)
        ])
        self.norm = nn.LayerNorm(D)

        # Output heads
        self.ascii_head = nn.Linear(D, V_ascii, bias=False)
        self.bpe_head = nn.Linear(D, model_cfg.get("vocab_size", 8192), bias=False)  # optional

        # Global tug (Rose)
        self.rose_proj = nn.Linear(D, D, bias=False)
        self.rose_anchors = nn.Parameter(torch.randn(3, D) / math.sqrt(D))

        self.apply(self._init)

    @staticmethod
    def _init(m):
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if getattr(m, "bias", None) is not None:
                nn.init.zeros_(m.bias)

    def backbone(self, idx, punct_mask=None, return_routes=False, alpha_gate=1.0, hard_mask_gate=False):
        B, T = idx.shape
        x = self.token_emb(idx) + self.pos_emb[:, :T, :]
        x = self.drop(x)
        routes = []
        for blk in self.blocks:
            x, w = blk(x, punct_mask=punct_mask, return_gate=True,
                       alpha_gate=alpha_gate, hard_mask_gate=hard_mask_gate)
            if return_routes:
                routes.append(w)
        x = self.norm(x)
        return (x, routes) if return_routes else (x, None)

    def forward(self, idx, punct_mask=None, head="ascii", return_routes=False, alpha_gate=1.0, hard_mask_gate=False):
        h, routes = self.backbone(idx, punct_mask=punct_mask, return_routes=return_routes,
                                  alpha_gate=alpha_gate, hard_mask_gate=hard_mask_gate)
        if head == "ascii":
            logits = self.ascii_head(h)
        elif head == "bpe":
            logits = self.bpe_head(h)
        else:
            raise ValueError("head must be 'ascii' or 'bpe'")
        return logits, routes
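

# ---------------------------------------------------------------------------
# Minimal smoke test (a sketch, not part of the original commit): exercises a
# single GatedBlock on CPU with assumed toy sizes (dim=64, 4 heads, 8 regions,
# T=16). The full CrystalBeeper is not constructed here because it depends on
# the external AsciiCodec.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    blk = GatedBlock(dim=64, n_heads=4, mlp_ratio=4.0, dropout=0.1, regions=8)
    x = torch.randn(2, 16, 64)                    # [B, T, D]
    punct = torch.zeros(2, 16, dtype=torch.bool)  # no punctuation tokens
    y, w = blk(x, punct_mask=punct, return_gate=True)
    print(y.shape, w.shape)  # expected: [2, 16, 64] and [2, 16, 8]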