Create mini-transformers-ANN.py
mini-transformers-ANN.py  +242 -0
ADDED
@@ -0,0 +1,242 @@
# mini-transformers-ANN.py
# Train a tiny Transformer encoder-decoder on a toy task: "a + b =" -> "c"
# Example: input tokens ["3","+","5","="] -> output tokens ["8"]
# Run: python mini-transformers-ANN.py

import math
import random
from typing import List, Tuple

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# -----------------------
# Config
# -----------------------
VOCAB = ["0","1","2","3","4","5","6","7","8","9","+","=","<pad>","<s>","</s>"]
PAD, BOS, EOS = VOCAB.index("<pad>"), VOCAB.index("<s>"), VOCAB.index("</s>")
TOK2ID = {t: i for i, t in enumerate(VOCAB)}
ID2TOK = {i: t for t, i in TOK2ID.items()}

MAX_IN_LEN = 4    # "d + d =" -> 4 tokens
MAX_OUT_LEN = 3   # could be 1 or 2 digits + EOS
EMB_DIM = 128
FF_DIM = 256
N_HEAD = 4
N_LAYERS = 2
DROPOUT = 0.1
BATCH_SIZE = 128
STEPS = 1500      # keep small for a mini training run (each step is one full pass over train_dl)
LR = 3e-4
WARMUP = 100
SAVE_PATH = "mini_transformer_addition.pt"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(42)
random.seed(42)

# -----------------------
# Data: generate on-the-fly mini dataset
# -----------------------
def encode(seq: List[str], max_len: int) -> List[int]:
    ids = [TOK2ID[s] for s in seq]
    if len(ids) < max_len:
        ids += [PAD] * (max_len - len(ids))
    return ids[:max_len]

def sample_pair() -> Tuple[List[int], List[int]]:
    a, b = random.randint(0, 9), random.randint(0, 9)
    c = a + b
    inp = [str(a), "+", str(b), "="]                       # length 4
    out_tokens = list(str(c))                              # "0".."18"
    tgt = [BOS] + [TOK2ID[t] for t in out_tokens] + [EOS]  # BOS ... EOS
    # pad to MAX_OUT_LEN + 2 (BOS/EOS)
    max_len = MAX_OUT_LEN + 2
    if len(tgt) < max_len:
        tgt += [PAD] * (max_len - len(tgt))
    return encode(inp, MAX_IN_LEN), tgt

class MiniAddDataset(Dataset):
    def __init__(self, size=5000):
        self.size = size
    def __len__(self): return self.size
    def __getitem__(self, idx):
        src, tgt = sample_pair()
        return torch.tensor(src), torch.tensor(tgt)

train_ds = MiniAddDataset(size=8000)
val_ds = MiniAddDataset(size=500)
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE)

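# Worked example with the vocabulary above ("+" -> 10, "=" -> 11,
# <pad> -> 12, <s> -> 13, </s> -> 14):
#   source "3 + 5 ="  -> src ids [3, 10, 5, 11]
#   answer "8"        -> tgt ids [13, 8, 14, 12, 12]   # <s> 8 </s> <pad> <pad>
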
# -----------------------
# Model
# -----------------------
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=64):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model)
    def forward(self, x):
        return x + self.pe[:, :x.size(1), :]

class TinyTransformer(nn.Module):
    def __init__(self, vocab_size: int):
        super().__init__()
        self.src_emb = nn.Embedding(vocab_size, EMB_DIM, padding_idx=PAD)
        self.tgt_emb = nn.Embedding(vocab_size, EMB_DIM, padding_idx=PAD)
        self.pos_enc_src = PositionalEncoding(EMB_DIM, max_len=MAX_IN_LEN + 8)
        self.pos_enc_tgt = PositionalEncoding(EMB_DIM, max_len=MAX_OUT_LEN + 8)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=EMB_DIM, nhead=N_HEAD, dim_feedforward=FF_DIM, dropout=DROPOUT, batch_first=True
        )
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=EMB_DIM, nhead=N_HEAD, dim_feedforward=FF_DIM, dropout=DROPOUT, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=N_LAYERS)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=N_LAYERS)
        self.lm_head = nn.Linear(EMB_DIM, vocab_size)

    def make_padding_mask(self, seq, pad_idx=PAD):
        # key padding mask: True where the token is <pad>; shape (batch, seq_len),
        # which is what the *_key_padding_mask arguments of nn.TransformerEncoder/Decoder expect
        return (seq == pad_idx)

    def generate_square_subsequent_mask(self, sz):
        # causal mask for the decoder (tgt): allow attention to previous positions only
        return torch.triu(torch.ones(sz, sz, device=DEVICE), diagonal=1).bool()

    def forward(self, src_ids, tgt_ids):
        # src_ids: (B, S) ; tgt_ids: (B, T)
        src_key_padding_mask = self.make_padding_mask(src_ids)  # (B, S)
        tgt_key_padding_mask = self.make_padding_mask(tgt_ids)  # (B, T)
        tgt_mask = self.generate_square_subsequent_mask(tgt_ids.size(1))

        src = self.src_emb(src_ids)
        src = self.pos_enc_src(src)
        memory = self.encoder(src, src_key_padding_mask=src_key_padding_mask)

        tgt = self.tgt_emb(tgt_ids)
        tgt = self.pos_enc_tgt(tgt)
        out = self.decoder(
            tgt, memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask
        )
        logits = self.lm_head(out)  # (B, T, V)
        return logits

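# Mask conventions (PyTorch boolean masks: True = position is masked out):
#   - the key padding masks are True wherever the token id equals <pad>
#   - generate_square_subsequent_mask(4), for example, yields
#       [[False,  True,  True,  True],
#        [False, False,  True,  True],
#        [False, False, False,  True],
#        [False, False, False, False]]
#     so decoder position t can only attend to positions <= t.
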
# -----------------------
# Training utils
# -----------------------
class WarmupAdam(torch.optim.Adam):
    def __init__(self, params, lr, warmup_steps=1000):
        super().__init__(params, lr=lr, betas=(0.9, 0.98), eps=1e-9)
        self.warmup_steps = warmup_steps
        self._step = 0
        self._base_lr = lr
    def step(self, closure=None):
        self._step += 1
        scale = min(self._step ** (-0.5), self._step * (self.warmup_steps ** (-1.5)))
        for g in self.param_groups:
            g['lr'] = self._base_lr * scale * (self.warmup_steps ** 0.5)
        return super().step(closure=closure)

def shift_right(tgt):
    """
    Teacher forcing: model sees BOS + y[:-1] and predicts y.
    Here tgt is already [BOS, y..., EOS, PAD...]
    We return inp=tgt[:, :-1], label=tgt[:, 1:]
    """
    return tgt[:, :-1], tgt[:, 1:]

def accuracy_from_logits(logits, labels, pad=PAD):
    # logits: (B, T, V), labels: (B, T)
    preds = logits.argmax(-1)
    mask = labels.ne(pad)
    correct = (preds.eq(labels) & mask).sum().item()
    total = mask.sum().item() + 1e-9
    return correct / total

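# Worked example of the teacher-forcing shift, assuming a target row
# [<s>, "8", </s>, <pad>, <pad>] (ids [13, 8, 14, 12, 12]):
#   decoder input (tgt[:, :-1]) : [13,  8, 14, 12]
#   labels        (tgt[:,  1:]) : [ 8, 14, 12, 12]   # PAD labels are ignored by the loss
# WarmupAdam implements a Noam-style schedule,
#   lr(step) = LR * sqrt(WARMUP) * min(step**-0.5, step * WARMUP**-1.5),
# which ramps linearly up to LR at step == WARMUP and then decays as step**-0.5.
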
# -----------------------
# Train
# -----------------------
model = TinyTransformer(vocab_size=len(VOCAB)).to(DEVICE)
criterion = nn.CrossEntropyLoss(ignore_index=PAD)
optim = WarmupAdam(model.parameters(), lr=LR, warmup_steps=WARMUP)

def run_epoch(dl, train=True):
    model.train(train)
    total_loss, total_acc, n = 0.0, 0.0, 0
    for src, tgt in dl:
        src, tgt = src.to(DEVICE), tgt.to(DEVICE)
        dec_inp, labels = shift_right(tgt)
        logits = model(src, dec_inp)
        loss = criterion(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
        acc = accuracy_from_logits(logits, labels)

        if train:
            optim.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optim.step()

        bs = src.size(0)
        total_loss += loss.item() * bs
        total_acc += acc * bs
        n += bs
    return total_loss / n, total_acc / n

best_val = 0.0
for step in range(1, STEPS + 1):
    tr_loss, tr_acc = run_epoch(train_dl, train=True)
    if step % 50 == 0:
        val_loss, val_acc = run_epoch(val_dl, train=False)
        print(f"[step {step:4d}] train loss {tr_loss:.3f} acc {tr_acc:.3f} | val loss {val_loss:.3f} acc {val_acc:.3f}")
        if val_acc > best_val:
            best_val = val_acc
            torch.save({"model": model.state_dict()}, SAVE_PATH)

print(f"Saved best model to: {SAVE_PATH}")

# -----------------------
# Inference demo
# -----------------------
def encode_inp(a: int, b: int):
    seq = [str(a), "+", str(b), "="]
    return torch.tensor([encode(seq, MAX_IN_LEN)], device=DEVICE)

def greedy_decode(src_ids, max_len=MAX_OUT_LEN + 2):
    model.eval()
    with torch.no_grad():
        # Start with BOS
        ys = torch.tensor([[BOS]], device=DEVICE)
        for _ in range(max_len - 1):
            logits = model(src_ids, ys)
            next_tok = logits[:, -1, :].argmax(-1, keepdim=True)  # (B, 1)
            ys = torch.cat([ys, next_tok], dim=1)
            if next_tok.item() == EOS:
                break
    return ys.squeeze(0).tolist()

def detok(ids: List[int]) -> str:
    toks = [ID2TOK[i] for i in ids if i not in (PAD, BOS)]
    out = []
    for t in toks:
        if t == "</s>": break
        out.append(t)
    return "".join(out)

# Load best (optional, already in memory)
ckpt = torch.load(SAVE_PATH, map_location=DEVICE)
model.load_state_dict(ckpt["model"])

for (a, b) in [(3, 5), (9, 8), (0, 0), (7, 2), (4, 6)]:
    src = encode_inp(a, b)
    out_ids = greedy_decode(src)
    print(f"{a}+{b} => {detok(out_ids)}")
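
# Optional sanity check (a quick sketch using only the helpers defined above):
# greedy-decode every single-digit pair and report exact-match accuracy.
n_correct = 0
for a in range(10):
    for b in range(10):
        pred = detok(greedy_decode(encode_inp(a, b)))
        n_correct += (pred == str(a + b))
print(f"Exact-match accuracy over all 100 pairs: {n_correct/100:.2%}")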