ankitkushwaha90 committed on
Commit
6fc6ec8
·
verified ·
1 Parent(s): a44ffc6

Create mini-transformers-ANN.py

Files changed (1)
  1. mini-transformers-ANN.py +242 -0
mini-transformers-ANN.py ADDED
@@ -0,0 +1,242 @@
# mini-transformers-ANN.py
# Train a tiny Transformer encoder-decoder on a toy task: "a + b =" -> "c"
# Example: input tokens ["3","+","5","="] -> output tokens ["8"]
# Run: python mini-transformers-ANN.py

import math
import random
from typing import List, Tuple

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# -----------------------
# Config
# -----------------------
VOCAB = ["0","1","2","3","4","5","6","7","8","9","+","=","<pad>","<s>","</s>"]
PAD, BOS, EOS = VOCAB.index("<pad>"), VOCAB.index("<s>"), VOCAB.index("</s>")
TOK2ID = {t: i for i, t in enumerate(VOCAB)}
ID2TOK = {i: t for t, i in TOK2ID.items()}

MAX_IN_LEN = 4      # "d + d =" -> 4 tokens
MAX_OUT_LEN = 3     # sums are 1-2 digits; BOS/EOS are accounted for separately
EMB_DIM = 128
FF_DIM = 256
N_HEAD = 4
N_LAYERS = 2
DROPOUT = 0.1
BATCH_SIZE = 128
STEPS = 1500        # each "step" in the loop below is one full pass over the training set
LR = 3e-4
WARMUP = 100
SAVE_PATH = "mini_transformer_addition.pt"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(42)
random.seed(42)

# -----------------------
# Data: generate an on-the-fly mini dataset
# -----------------------
def encode(seq: List[str], max_len: int) -> List[int]:
    ids = [TOK2ID[s] for s in seq]
    if len(ids) < max_len:
        ids += [PAD] * (max_len - len(ids))
    return ids[:max_len]

def sample_pair() -> Tuple[List[int], List[int]]:
    a, b = random.randint(0, 9), random.randint(0, 9)
    c = a + b
    inp = [str(a), "+", str(b), "="]                       # length 4
    out_tokens = list(str(c))                              # "0".."18"
    tgt = [BOS] + [TOK2ID[t] for t in out_tokens] + [EOS]  # BOS ... EOS
    # pad to MAX_OUT_LEN + 2 (room for BOS/EOS)
    max_len = MAX_OUT_LEN + 2
    if len(tgt) < max_len:
        tgt += [PAD] * (max_len - len(tgt))
    return encode(inp, MAX_IN_LEN), tgt

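# Concretely (the ids follow directly from VOCAB above), sample_pair() with a=3, b=5 yields:
#   src = [3, 10, 5, 11]         # "3" "+" "5" "="
#   tgt = [13, 8, 14, 12, 12]    # <s> "8" </s> <pad> <pad>
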
class MiniAddDataset(Dataset):
    def __init__(self, size=5000):
        self.size = size
    def __len__(self): return self.size
    def __getitem__(self, idx):
        src, tgt = sample_pair()
        return torch.tensor(src), torch.tensor(tgt)

train_ds = MiniAddDataset(size=8000)
val_ds = MiniAddDataset(size=500)
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE)

# -----------------------
# Model
# -----------------------
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=64):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        return x + self.pe[:, :x.size(1), :]

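# A note on the class above: it implements the standard sinusoidal encoding from
# "Attention Is All You Need",
#   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
#   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
# so every position gets a unique, smoothly varying vector that attention can use
# to recover token order.
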
class TinyTransformer(nn.Module):
    def __init__(self, vocab_size: int):
        super().__init__()
        self.src_emb = nn.Embedding(vocab_size, EMB_DIM, padding_idx=PAD)
        self.tgt_emb = nn.Embedding(vocab_size, EMB_DIM, padding_idx=PAD)
        self.pos_enc_src = PositionalEncoding(EMB_DIM, max_len=MAX_IN_LEN + 8)
        self.pos_enc_tgt = PositionalEncoding(EMB_DIM, max_len=MAX_OUT_LEN + 8)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=EMB_DIM, nhead=N_HEAD, dim_feedforward=FF_DIM, dropout=DROPOUT, batch_first=True
        )
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=EMB_DIM, nhead=N_HEAD, dim_feedforward=FF_DIM, dropout=DROPOUT, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=N_LAYERS)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=N_LAYERS)
        self.lm_head = nn.Linear(EMB_DIM, vocab_size)

    def make_padding_mask(self, seq, pad_idx=PAD):
        # with batch_first=True the key-padding mask is simply (batch, seq_len),
        # True at PAD positions
        return (seq == pad_idx)

    def generate_square_subsequent_mask(self, sz):
        # causal mask for the decoder: True entries are blocked, so each position
        # can attend only to itself and earlier positions
        return torch.triu(torch.ones(sz, sz, device=DEVICE), diagonal=1).bool()

    def forward(self, src_ids, tgt_ids):
        # src_ids: (B, S) ; tgt_ids: (B, T)
        src_key_padding_mask = self.make_padding_mask(src_ids)  # (B, S)
        tgt_key_padding_mask = self.make_padding_mask(tgt_ids)  # (B, T)
        tgt_mask = self.generate_square_subsequent_mask(tgt_ids.size(1))

        src = self.src_emb(src_ids)
        src = self.pos_enc_src(src)
        memory = self.encoder(src, src_key_padding_mask=src_key_padding_mask)

        tgt = self.tgt_emb(tgt_ids)
        tgt = self.pos_enc_tgt(tgt)
        out = self.decoder(
            tgt, memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask,
        )
        logits = self.lm_head(out)  # (B, T, V)
        return logits

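# For sz=3 the causal mask above looks like (True = attention blocked):
#   [[False,  True,  True],
#    [False, False,  True],
#    [False, False, False]]
# Row t is the query position: it sees columns 0..t only, which is what lets the
# decoder be trained on all positions in parallel without peeking ahead.
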
# -----------------------
# Training utils
# -----------------------
class WarmupAdam(torch.optim.Adam):
    def __init__(self, params, lr, warmup_steps=1000):
        super().__init__(params, lr=lr, betas=(0.9, 0.98), eps=1e-9)
        self.warmup_steps = warmup_steps
        self._step = 0
        self._base_lr = lr

    def step(self, closure=None):
        self._step += 1
        # Noam schedule, normalized so the lr peaks at _base_lr after warmup_steps
        scale = min(self._step ** (-0.5), self._step * (self.warmup_steps ** (-1.5)))
        for g in self.param_groups:
            g['lr'] = self._base_lr * scale * (self.warmup_steps ** 0.5)
        return super().step(closure=closure)

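# With LR=3e-4 and WARMUP=100 the schedule ramps linearly, peaks, then decays as 1/sqrt(step):
#   step  50 -> lr = 3e-4 * 0.05 * 10 = 1.5e-4
#   step 100 -> lr = 3e-4 * 0.10 * 10 = 3.0e-4   (peak = LR)
#   step 400 -> lr = 3e-4 * 0.05 * 10 = 1.5e-4
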
def shift_right(tgt):
    """
    Teacher forcing: the model sees BOS + y[:-1] and predicts y.
    Here tgt is already [BOS, y..., EOS, PAD...], so we return
    inp = tgt[:, :-1] and label = tgt[:, 1:].
    """
    return tgt[:, :-1], tgt[:, 1:]

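# Example with the padded target for 3+5 from above:
#   tgt   = [13,  8, 14, 12, 12]   # <s> "8" </s> <pad> <pad>
#   inp   = [13,  8, 14, 12]       # decoder input
#   label = [ 8, 14, 12, 12]       # prediction targets (PAD is ignored by the loss)
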
def accuracy_from_logits(logits, labels, pad=PAD):
    # logits: (B, T, V), labels: (B, T); PAD positions are excluded from the count
    preds = logits.argmax(-1)
    mask = labels.ne(pad)
    correct = (preds.eq(labels) & mask).sum().item()
    total = mask.sum().item() + 1e-9
    return correct / total

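# e.g. labels [[8, 14, 12, 12]] with preds [[8, 14, 0, 3]] scores 2/2 = 1.0,
# because the two PAD positions are masked out of both numerator and denominator.
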
# -----------------------
# Train
# -----------------------
model = TinyTransformer(vocab_size=len(VOCAB)).to(DEVICE)
criterion = nn.CrossEntropyLoss(ignore_index=PAD)
optim = WarmupAdam(model.parameters(), lr=LR, warmup_steps=WARMUP)

def run_epoch(dl, train=True):
    model.train(train)
    total_loss, total_acc, n = 0.0, 0.0, 0
    # disable autograd during evaluation so the validation pass is cheaper
    with torch.set_grad_enabled(train):
        for src, tgt in dl:
            src, tgt = src.to(DEVICE), tgt.to(DEVICE)
            dec_inp, labels = shift_right(tgt)
            logits = model(src, dec_inp)
            loss = criterion(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
            acc = accuracy_from_logits(logits, labels)

            if train:
                optim.zero_grad(set_to_none=True)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optim.step()

            bs = src.size(0)
            total_loss += loss.item() * bs
            total_acc += acc * bs
            n += bs
    return total_loss / n, total_acc / n

best_val = 0.0
for step in range(1, STEPS + 1):
    # note: each "step" here is one full pass (epoch) over train_ds
    tr_loss, tr_acc = run_epoch(train_dl, train=True)
    if step % 50 == 0:
        val_loss, val_acc = run_epoch(val_dl, train=False)
        print(f"[step {step:4d}] train loss {tr_loss:.3f} acc {tr_acc:.3f} | val loss {val_loss:.3f} acc {val_acc:.3f}")
        if val_acc > best_val:
            best_val = val_acc
            torch.save({"model": model.state_dict()}, SAVE_PATH)

print(f"Saved best model to: {SAVE_PATH}")

# -----------------------
# Inference demo
# -----------------------
def encode_inp(a: int, b: int):
    seq = [str(a), "+", str(b), "="]
    return torch.tensor([encode(seq, MAX_IN_LEN)], device=DEVICE)

def greedy_decode(src_ids, max_len=MAX_OUT_LEN + 2):
    model.eval()
    with torch.no_grad():
        # start with BOS, then repeatedly append the most likely next token
        ys = torch.tensor([[BOS]], device=DEVICE)
        for _ in range(max_len - 1):
            logits = model(src_ids, ys)
            next_tok = logits[:, -1, :].argmax(-1, keepdim=True)  # (B, 1)
            ys = torch.cat([ys, next_tok], dim=1)
            if next_tok.item() == EOS:
                break
    return ys.squeeze(0).tolist()

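# Illustrative trace (not a recorded run): for src "3+5=", a trained model would be
# expected to grow ys as [13] -> [13, 8] -> [13, 8, 14], i.e. <s> "8" </s>.
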
def detok(ids: List[int]) -> str:
    toks = [ID2TOK[i] for i in ids if i not in (PAD, BOS)]
    out = []
    for t in toks:
        if t == "</s>":
            break
        out.append(t)
    return "".join(out)

# Reload the best checkpoint (optional: the in-memory weights may come from a later,
# lower-scoring epoch than the saved best)
ckpt = torch.load(SAVE_PATH, map_location=DEVICE)
model.load_state_dict(ckpt["model"])

for (a, b) in [(3, 5), (9, 8), (0, 0), (7, 2), (4, 6)]:
    src = encode_inp(a, b)
    out_ids = greedy_decode(src)
    print(f"{a}+{b} => {detok(out_ids)}")