# Source: Attention_is_all_you_need_transformers / mini-transformers-ANN.py
# Author: ankitkushwaha90
# Commit message: Create mini-transformers-ANN.py
# Commit: 6fc6ec8 (verified)
# transformer_mini_addition.py
# Train a tiny Transformer encoder-decoder on a toy task: "a + b =" -> "c"
# Example: input tokens ["3","+","5","="] -> output tokens ["8"]
# Run: python transformer_mini_addition.py
import math
import random
from typing import List, Tuple
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
# -----------------------
# Config
# -----------------------
# Token inventory: digits 0-9, the "+" and "=" symbols, then the special tokens.
VOCAB = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "=", "<pad>", "<s>", "</s>"]
TOK2ID = {tok: idx for idx, tok in enumerate(VOCAB)}
ID2TOK = dict(enumerate(VOCAB))
PAD, BOS, EOS = (TOK2ID[special] for special in ("<pad>", "<s>", "</s>"))

# Sequence lengths
MAX_IN_LEN = 4   # the prompt "d + d =" is exactly 4 tokens
MAX_OUT_LEN = 3  # the answer is 1 or 2 digits, plus EOS

# Model size
EMB_DIM = 128
FF_DIM = 256
N_HEAD = 4
N_LAYERS = 2
DROPOUT = 0.1

# Optimisation
BATCH_SIZE = 128
STEPS = 1500  # keep small for a mini training run
LR = 3e-4
WARMUP = 100

SAVE_PATH = "mini_transformer_addition.pt"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed both RNGs this script draws from (torch for init/shuffling, random for data).
torch.manual_seed(42)
random.seed(42)
# -----------------------
# Data: generate on-the-fly mini dataset
# -----------------------
def encode(seq: List[str], max_len: int) -> List[int]:
    """Map tokens to ids and force the result to exactly ``max_len`` entries.

    Every token is looked up in TOK2ID (an unknown token raises KeyError).
    Longer sequences are truncated; shorter ones are right-padded with PAD.
    """
    token_ids = [TOK2ID[tok] for tok in seq][:max_len]
    return token_ids + [PAD] * (max_len - len(token_ids))
def sample_pair() -> Tuple[List[int], List[int]]:
    """Draw one random single-digit addition problem and its answer.

    Returns:
        ``(src, tgt)`` where ``src`` is the MAX_IN_LEN ids of ``"a + b ="`` and
        ``tgt`` is ``[BOS, answer digits..., EOS]`` right-padded with PAD up to
        ``MAX_OUT_LEN + 2`` ids (the answer "0".."18" needs 1 or 2 digits).
    """
    a = random.randint(0, 9)
    b = random.randint(0, 9)
    src = encode([str(a), "+", str(b), "="], MAX_IN_LEN)

    answer_ids = [TOK2ID[digit] for digit in str(a + b)]
    tgt = [BOS, *answer_ids, EOS]
    target_len = MAX_OUT_LEN + 2
    tgt.extend([PAD] * (target_len - len(tgt)))
    return src, tgt
class MiniAddDataset(Dataset):
    """Map-style dataset of random single-digit addition problems.

    ``size`` only sets the reported length. ``idx`` is ignored: each access
    draws a fresh problem via ``sample_pair()``, so the same index can yield
    different items on different calls.
    """

    def __init__(self, size=5000):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # Returns (src_tensor, tgt_tensor).
        return tuple(map(torch.tensor, sample_pair()))
# 8000 training / 500 validation items, both served in batches of BATCH_SIZE.
train_ds = MiniAddDataset(size=8000)
val_ds = MiniAddDataset(size=500)
train_dl = DataLoader(dataset=train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_dl = DataLoader(dataset=val_ds, batch_size=BATCH_SIZE, shuffle=False)
# -----------------------
# Model
# -----------------------
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to batch-first inputs.

    Even feature indices get ``sin(pos * w_k)`` and odd indices get
    ``cos(pos * w_k)``, with ``w_k = 10000 ** (-2k / d_model)``. The table is
    precomputed for positions ``[0, max_len)`` and registered as a buffer, so
    it moves with ``.to(device)`` and is stored in ``state_dict``.

    Args:
        d_model: feature size of the inputs (odd sizes are supported).
        max_len: longest sequence length that can be encoded.
    """

    def __init__(self, d_model, max_len=64):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = pos * div  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than frequencies,
        # so only the first d_model // 2 angles feed the cosine columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        ``x`` is expected as (batch, seq_len, d_model).

        Raises:
            ValueError: if ``seq_len`` exceeds the precomputed ``max_len``.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len {max_len}"
            )
        return x + self.pe[:, :seq_len, :]
class TinyTransformer(nn.Module):
    """Small encoder-decoder Transformer over a shared token vocabulary.

    The encoder reads the source ids; the decoder reads the (teacher-forced)
    target ids under a causal mask and attends to the encoder output.
    ``forward`` returns per-position vocabulary logits shaped (B, T, V).
    """

    def __init__(self, vocab_size: int):
        super().__init__()
        # Submodules are created in a fixed order so seeded initialisation is reproducible.
        self.src_emb = nn.Embedding(vocab_size, EMB_DIM, padding_idx=PAD)
        self.tgt_emb = nn.Embedding(vocab_size, EMB_DIM, padding_idx=PAD)
        self.pos_enc_src = PositionalEncoding(EMB_DIM, max_len=MAX_IN_LEN + 8)
        self.pos_enc_tgt = PositionalEncoding(EMB_DIM, max_len=MAX_OUT_LEN + 8)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=EMB_DIM, nhead=N_HEAD, dim_feedforward=FF_DIM, dropout=DROPOUT, batch_first=True
        )
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=EMB_DIM, nhead=N_HEAD, dim_feedforward=FF_DIM, dropout=DROPOUT, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=N_LAYERS)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=N_LAYERS)
        self.lm_head = nn.Linear(EMB_DIM, vocab_size)

    def make_padding_mask(self, seq, pad_idx=PAD):
        """Return a bool mask, same shape as ``seq``, that is True where ``seq == pad_idx``.

        For (batch, seq_len) id tensors this is the ``*_key_padding_mask``
        layout the batch-first nn.Transformer* modules accept; True marks keys
        that attention must ignore.
        """
        return seq == pad_idx

    def generate_square_subsequent_mask(self, sz, device=None):
        """Return an (sz, sz) bool causal mask; True above the diagonal means "blocked".

        Args:
            sz: target sequence length.
            device: where to allocate the mask. Defaults to the module-level
                DEVICE for backward compatibility. ``forward`` passes the
                inputs' device so the mask always matches the tensors it masks.
        """
        if device is None:
            device = DEVICE
        return torch.triu(torch.ones(sz, sz, device=device), diagonal=1).bool()

    def forward(self, src_ids, tgt_ids):
        """Compute decoder logits.

        Args:
            src_ids: (B, S) source token ids.
            tgt_ids: (B, T) decoder-input token ids (BOS-prefixed).

        Returns:
            (B, T, V) unnormalised scores over the vocabulary.
        """
        src_key_padding_mask = self.make_padding_mask(src_ids)  # (B, S)
        tgt_key_padding_mask = self.make_padding_mask(tgt_ids)  # (B, T)
        tgt_mask = self.generate_square_subsequent_mask(tgt_ids.size(1), device=tgt_ids.device)

        memory = self.encoder(
            self.pos_enc_src(self.src_emb(src_ids)),
            src_key_padding_mask=src_key_padding_mask,
        )
        tgt = self.pos_enc_tgt(self.tgt_emb(tgt_ids))
        out = self.decoder(
            tgt, memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask,
        )
        return self.lm_head(out)  # (B, T, V)
# -----------------------
# Training utils
# -----------------------
class WarmupAdam(torch.optim.Adam):
    """Adam with the Noam warmup / inverse-square-root learning-rate schedule.

    On the k-th call to ``step`` (k counts from 1), before the Adam update,
    every param group's lr is set to

        lr * min(k ** -0.5, k * warmup_steps ** -1.5) * warmup_steps ** 0.5

    This is a linear ramp that reaches ``lr`` exactly at k == warmup_steps and
    then decays proportionally to 1 / sqrt(k).

    Args:
        params: parameters or param-group dicts, as accepted by Adam.
        lr: peak learning rate (reached at ``warmup_steps``).
        warmup_steps: length of the linear ramp, in optimizer steps.
    """

    def __init__(self, params, lr, warmup_steps=1000):
        super().__init__(params, lr=lr, betas=(0.9, 0.98), eps=1e-9)
        self._base_lr = lr
        self.warmup_steps = warmup_steps
        self._step = 0

    def step(self, closure=None):
        self._step += 1
        k, warmup = self._step, self.warmup_steps
        scale = min(k ** (-0.5), k * warmup ** (-1.5))
        current_lr = self._base_lr * scale * warmup ** 0.5
        for group in self.param_groups:
            group['lr'] = current_lr
        return super().step(closure=closure)
def shift_right(tgt):
    """Split a padded target batch into decoder input and labels (teacher forcing).

    Rows of ``tgt`` look like ``[BOS, y..., EOS, PAD...]``. The decoder input
    drops the last column and the labels drop the first, so label position t
    is the token that follows decoder-input position t.

    Returns:
        ``(tgt[:, :-1], tgt[:, 1:])``
    """
    decoder_input = tgt[:, :-1]
    labels = tgt[:, 1:]
    return decoder_input, labels
def accuracy_from_logits(logits, labels, pad=PAD):
    """Token-level accuracy of argmax predictions, ignoring ``pad`` labels.

    Args:
        logits: scores with the vocabulary on the last dim, e.g. (B, T, V).
        labels: target ids shaped like ``logits.shape[:-1]``, e.g. (B, T).
        pad: label id excluded from both the numerator and the denominator.

    Returns:
        Fraction of non-pad positions predicted correctly (exactly 1.0 when all
        are correct), or 0.0 when there are no non-pad labels.
    """
    preds = logits.argmax(-1)
    mask = labels.ne(pad)
    total = mask.sum().item()
    # Explicit guard instead of an epsilon in the denominator, which would
    # bias every accuracy slightly below its true value.
    if total == 0:
        return 0.0
    correct = (preds.eq(labels) & mask).sum().item()
    return correct / total
# -----------------------
# Train
# -----------------------
# Model on the target device, a token-level loss that skips PAD labels, and
# the warmup-scheduled Adam optimizer that run_epoch() steps.
model = TinyTransformer(vocab_size=len(VOCAB))
model = model.to(DEVICE)
criterion = nn.CrossEntropyLoss(ignore_index=PAD)
optim = WarmupAdam(params=model.parameters(), lr=LR, warmup_steps=WARMUP)
def run_epoch(dl, train=True):
    """Make one pass over ``dl`` and return ``(mean_loss, mean_token_accuracy)``.

    Uses the module-level ``model``, ``criterion`` and ``optim``.
    - ``train=True``: model in training mode; ``optim`` takes one step per batch
      after the global gradient norm is clipped to 1.0.
    - ``train=False``: model in eval mode and autograd disabled, so no graph
      is built while evaluating.

    Per-batch loss/accuracy are averaged with batch-size weights. An empty
    loader yields ``(nan, nan)`` instead of raising ZeroDivisionError.
    """
    model.train(train)
    total_loss, total_acc, n = 0.0, 0.0, 0
    # Only track gradients when we are going to backpropagate.
    with torch.set_grad_enabled(train):
        for src, tgt in dl:
            src, tgt = src.to(DEVICE), tgt.to(DEVICE)
            dec_inp, labels = shift_right(tgt)  # teacher forcing
            logits = model(src, dec_inp)
            loss = criterion(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
            acc = accuracy_from_logits(logits, labels)
            if train:
                optim.zero_grad(set_to_none=True)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optim.step()
            bs = src.size(0)
            total_loss += loss.item() * bs
            total_acc += acc * bs
            n += bs
    if n == 0:
        nan = float("nan")
        return nan, nan
    return total_loss / n, total_acc / n
# STEPS is a budget of optimizer updates. run_epoch() makes a full pass over
# train_dl (one update per batch), so convert the budget into whole epochs.
# Otherwise training would run STEPS * len(train_dl) updates.
steps_per_epoch = max(1, len(train_dl))
num_epochs = max(1, math.ceil(STEPS / steps_per_epoch))
best_val = float("-inf")  # the first evaluation always writes a checkpoint
global_step = 0
for epoch in range(1, num_epochs + 1):
    tr_loss, tr_acc = run_epoch(train_dl, train=True)
    global_step += steps_per_epoch
    val_loss, val_acc = run_epoch(val_dl, train=False)
    print(f"[epoch {epoch:3d} | step {global_step:4d}] train loss {tr_loss:.3f} acc {tr_acc:.3f} | val loss {val_loss:.3f} acc {val_acc:.3f}")
    if val_acc > best_val:
        best_val = val_acc
        torch.save({"model": model.state_dict()}, SAVE_PATH)
print(f"Saved best model to: {SAVE_PATH}")
# -----------------------
# Inference demo
# -----------------------
def encode_inp(a: int, b: int):
    """Return a (1, MAX_IN_LEN) id tensor on DEVICE for the prompt ``"a + b ="``.

    ``a`` and ``b`` must each stringify to a single vocabulary token (0-9),
    because ``encode`` looks every token up in TOK2ID.
    """
    tokens = [str(a), "+", str(b), "="]
    batch = [encode(tokens, MAX_IN_LEN)]
    return torch.tensor(batch, device=DEVICE)
def greedy_decode(src_ids, max_len=MAX_OUT_LEN+2):
    """Greedy autoregressive decoding with the module-level ``model``.

    Every row starts from BOS. On each step the argmax token of the last
    decoder position is appended, until every row has produced EOS or the
    sequences hold ``max_len`` tokens (BOS included).

    Args:
        src_ids: (batch, src_len) encoder token ids, e.g. from ``encode_inp``.
            Batches larger than 1 are supported. A row that has already
            emitted EOS is extended with PAD while the others keep decoding.
        max_len: maximum decoded length, including the leading BOS.

    Returns:
        For batch size 1, a flat list of ids starting with BOS (ending with EOS
        if one was produced). For larger batches, one such list per row.
    """
    model.eval()
    batch_size = src_ids.size(0)
    device = src_ids.device
    with torch.no_grad():
        ys = torch.full((batch_size, 1), BOS, dtype=torch.long, device=device)
        done = torch.zeros(batch_size, dtype=torch.bool, device=device)
        for _ in range(max_len - 1):
            logits = model(src_ids, ys)
            next_tok = logits[:, -1, :].argmax(-1)  # (B,)
            # Finished rows emit PAD, which the decoder's padding mask ignores.
            next_tok = next_tok.masked_fill(done, PAD)
            ys = torch.cat([ys, next_tok.unsqueeze(1)], dim=1)
            done = done | next_tok.eq(EOS)
            if bool(done.all()):
                break
    return ys.squeeze(0).tolist()
def detok(ids: List[int]) -> str:
    """Convert decoded ids back into the answer string.

    PAD and BOS ids are dropped, every remaining id is mapped through ID2TOK,
    and everything from the first ``"</s>"`` onward is discarded before the
    tokens are concatenated.
    """
    kept = [ID2TOK[tok_id] for tok_id in ids if tok_id not in (PAD, BOS)]
    if "</s>" in kept:
        kept = kept[: kept.index("</s>")]
    return "".join(kept)
# Reload the best checkpoint written during training. If none exists (no
# evaluation ever saved one), keep the weights already in memory instead of
# crashing on a missing file.
try:
    ckpt = torch.load(SAVE_PATH, map_location=DEVICE)
except FileNotFoundError:
    print(f"No checkpoint at {SAVE_PATH}; using the in-memory weights.")
else:
    model.load_state_dict(ckpt["model"])

for a, b in [(3, 5), (9, 8), (0, 0), (7, 2), (4, 6)]:
    src = encode_inp(a, b)
    out_ids = greedy_decode(src)
    print(f"{a}+{b} => {detok(out_ids)}")