class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with a flash-SDPA fast path and an explicit-mask fallback."""
    def __init__(self, dim: int, n_heads: int, attn_dropout: float = 0.0):
        super().__init__()
        assert dim % n_heads == 0
        self.nh = n_heads
        self.hd = dim // n_heads
        self.qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.proj = nn.Linear(dim, dim, bias=False)
        self.attn_dropout = attn_dropout

    def forward(self, x):
        B, T, C = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)
        # [B, T, C] -> [B, nh, T, hd]
        q = q.view(B, T, self.nh, self.hd).transpose(1, 2)
        k = k.view(B, T, self.nh, self.hd).transpose(1, 2)
        v = v.view(B, T, self.nh, self.hd).transpose(1, 2)
        if x.is_cuda:
            # fused kernel path; dropout is applied inside SDPA during training only
            with sdpa_ctx_prefer_flash():
                y = F.scaled_dot_product_attention(
                    q, k, v, is_causal=True,
                    dropout_p=self.attn_dropout if self.training else 0.0)
        else:
            # explicit causal-mask fallback (CPU); attention dropout is applied only on the SDPA path
            scale = 1.0 / math.sqrt(self.hd)
            att = (q @ k.transpose(-2, -1)) * scale
            mask = torch.full((1, 1, T, T), float("-inf"), device=x.device)
            mask = torch.triu(mask, diagonal=1)  # -inf strictly above the diagonal
            att = (att + mask).softmax(dim=-1)
            y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.proj(y)
def _normalize_cell(X):  # X: [V, D]
    """Center a cell's points and scale them to unit RMS radius."""
    Xc = X - X.mean(dim=0, keepdim=True)
    r = Xc.pow(2).sum(dim=1).mean().sqrt().clamp_min(1e-6)
    return Xc / r


class CrystalBank(nn.Module):
    """Learnable bank of `regions` cells, each a set of 5 anchor points in D dimensions."""
    def __init__(self, regions: int, dim: int):
        super().__init__()
        pts = torch.randn(regions, 5, dim) / math.sqrt(dim)
        with torch.no_grad():
            for i in range(regions):
                pts[i] = _normalize_cell(pts[i])
        self.anchors = nn.Parameter(pts)  # [C, 5, D]
class GeometricGate(nn.Module):
    """Routes each token to anchor regions by soft-min distance and emits a centroid-based gate vector."""
    def __init__(self, dim: int, regions: int, tau: float = 0.08):
        super().__init__()
        self.bank = CrystalBank(regions, dim)
        self.nav = nn.Linear(dim, dim, bias=False)
        self.tau = tau
        self.scale = nn.Parameter(torch.tensor(1.0))  # residual mix scaler

    def forward(self, h: torch.Tensor, punct_mask: Optional[torch.Tensor] = None,
                alpha_gate: float = 1.0, hard_mask_gate: bool = False):
        B, T, D = h.shape
        C = self.bank.anchors.size(0)
        H = self.nav(h).reshape(B * T, D)        # [BT, D]
        A = self.bank.anchors.reshape(C * 5, D)  # [C*5, D]
        # squared distances via the expansion ||x - a||^2 = ||x||^2 + ||a||^2 - 2<x, a>
        x2 = (H * H).sum(dim=-1, keepdim=True)   # [BT, 1]
        a2 = (A * A).sum(dim=-1).unsqueeze(0)    # [1, C*5]
        xa = H @ A.T                             # [BT, C*5]
        d2 = (x2 + a2 - 2 * xa).clamp_min(0.0).view(B * T, C, 5)
        # soft-min over the 5 anchors of each region, then softmax over regions
        s = -torch.logsumexp(-d2 / max(1e-6, self.tau), dim=-1)       # [BT, C]
        w = F.softmax(-s / max(1e-6, self.tau), dim=-1).view(B, T, C)
        centroids = self.bank.anchors.mean(dim=1)  # [C, D]
        g = w @ centroids                          # [B, T, D]
        if punct_mask is not None:
            # alpha soft mask reduces the gate on punctuation tokens
            pm = punct_mask.float().unsqueeze(-1)  # [B, T, 1], 1 on punctuation
            if hard_mask_gate:
                g = g * (1.0 - pm)  # zero the gate on punctuation
            else:
                g = g * (1.0 - pm * (1.0 - alpha_gate))
        return w, self.scale * g
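
# A minimal shape sketch for GeometricGate (illustrative values, not from the original file):
# for hidden states h of shape [B, T, D] the gate returns routing weights
# w: [B, T, regions] (softmax over regions) and a gate vector g: [B, T, D]
# built from region centroids, optionally attenuated on punctuation tokens, e.g.
#   gate = GeometricGate(dim=128, regions=8)
#   w, g = gate(torch.randn(2, 16, 128))  # w: [2, 16, 8], g: [2, 16, 128]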
class MLP(nn.Module):
    def __init__(self, dim, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        hidden = int(dim * mlp_ratio)
        self.fc1 = nn.Linear(dim, hidden)
        self.fc2 = nn.Linear(hidden, dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = F.gelu(x, approximate="tanh")
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
class GatedBlock(nn.Module):
    """Pre-norm transformer block whose residual adds both the attention output and the geometric gate."""
    def __init__(self, dim, n_heads, mlp_ratio, dropout, regions):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = CausalSelfAttention(dim, n_heads, attn_dropout=dropout)
        self.gate = GeometricGate(dim, regions=regions, tau=0.08)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, mlp_ratio=mlp_ratio, dropout=dropout)
        self.mix_sdpa = nn.Parameter(torch.tensor(1.0))  # stage-adjustable attention mix

    def forward(self, x, punct_mask=None, return_gate=False, alpha_gate=1.0, hard_mask_gate=False):
        h = self.norm1(x)
        att = self.attn(h) * self.mix_sdpa
        w, g = self.gate(h, punct_mask=punct_mask, alpha_gate=alpha_gate, hard_mask_gate=hard_mask_gate)
        x = x + att + g
        x = x + self.mlp(self.norm2(x))
        return (x, w) if return_gate else x
class CrystalBeeper(nn.Module):
    def __init__(self, model_cfg: dict):
        super().__init__()
        self.cfg = model_cfg
        D, L, H = model_cfg["dim"], model_cfg["n_layers"], model_cfg["n_heads"]
        ctx = model_cfg["context"]
        # Ingress
        self.use_ascii = bool(model_cfg.get("use_ascii", True))
        self.codec = AsciiCodec()
        V_ascii = self.codec.vocab_size
        self.token_emb = nn.Embedding(V_ascii, D)
        self.pos_emb = nn.Parameter(torch.zeros(1, ctx, D))
        self.drop = nn.Dropout(model_cfg.get("resid_dropout", 0.1))
        regions = [int(model_cfg.get("regions_per_block", 64)) for _ in range(L)]
        self.blocks = nn.ModuleList([
            GatedBlock(D, H, model_cfg["mlp_ratio"], model_cfg["dropout"], regions[i])
            for i in range(L)
        ])
        self.norm = nn.LayerNorm(D)
        # Output heads
        self.ascii_head = nn.Linear(D, V_ascii, bias=False)
        self.bpe_head = nn.Linear(D, model_cfg.get("vocab_size", 8192), bias=False)  # optional
        # Global tug (Rose)
        self.rose_proj = nn.Linear(D, D, bias=False)
        self.rose_anchors = nn.Parameter(torch.randn(3, D) / math.sqrt(D))
        self.apply(self._init)

    @staticmethod
    def _init(m):
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if getattr(m, "bias", None) is not None:
                nn.init.zeros_(m.bias)

    def backbone(self, idx, punct_mask=None, return_routes=False, alpha_gate=1.0, hard_mask_gate=False):
        B, T = idx.shape
        x = self.token_emb(idx) + self.pos_emb[:, :T, :]
        x = self.drop(x)
        routes = []
        for blk in self.blocks:
            x, w = blk(x, punct_mask=punct_mask, return_gate=True,
                       alpha_gate=alpha_gate, hard_mask_gate=hard_mask_gate)
            if return_routes:
                routes.append(w)
        x = self.norm(x)
        return (x, routes) if return_routes else (x, None)

    def forward(self, idx, punct_mask=None, head="ascii", return_routes=False, alpha_gate=1.0, hard_mask_gate=False):
        h, routes = self.backbone(idx, punct_mask=punct_mask, return_routes=return_routes,
                                  alpha_gate=alpha_gate, hard_mask_gate=hard_mask_gate)
        if head == "ascii":
            logits = self.ascii_head(h)
        elif head == "bpe":
            logits = self.bpe_head(h)
        else:
            raise ValueError("head must be 'ascii' or 'bpe'")
        return logits, routes
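
# Hedged usage sketch (smoke test), not part of the original module: assumes the
# imports and the AsciiCodec class defined earlier in this file are available;
# the config values below are illustrative, not canonical.
if __name__ == "__main__":
    cfg = {
        "dim": 128, "n_layers": 2, "n_heads": 4, "context": 64,
        "mlp_ratio": 4.0, "dropout": 0.1, "regions_per_block": 8,
    }
    model = CrystalBeeper(cfg)
    idx = torch.randint(0, model.codec.vocab_size, (2, 32))  # [B, T] ASCII token ids
    logits, routes = model(idx, head="ascii", return_routes=True)
    print(logits.shape, len(routes))  # expected: [2, 32, V_ascii] and n_layers route maps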