# transformer_mini_addition.py
# Train a tiny Transformer encoder-decoder on a toy task: "a + b =" -> "c"
# Example: input tokens ["3","+","5","="] -> output tokens ["8"]
# Run: python transformer_mini_addition.py
import math
import random
from typing import Iterator, List, Optional, Tuple

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

# -----------------------
# Config
# -----------------------
# Special tokens must have distinct, non-empty spellings. If they share one
# string, VOCAB.index() returns the same id for PAD, BOS and EOS and the
# TOK2ID/ID2TOK maps silently drop entries, so padding, start-of-sequence and
# end-of-sequence become indistinguishable.
VOCAB = [
    "0", "1", "2", "3", "4", "5", "6", "7", "8", "9",
    "+", "=", "<pad>", "<bos>", "<eos>",
]
TOK2ID = {t: i for i, t in enumerate(VOCAB)}
ID2TOK = dict(enumerate(VOCAB))
if len(TOK2ID) != len(VOCAB):
    raise ValueError("VOCAB must not contain duplicate tokens")
PAD, BOS, EOS = TOK2ID["<pad>"], TOK2ID["<bos>"], TOK2ID["<eos>"]

MAX_IN_LEN = 4   # "d + d =" -> 4 tokens
MAX_OUT_LEN = 3  # at most 2 result digits ("0".."18") + EOS
TGT_LEN = MAX_OUT_LEN + 1  # BOS + MAX_OUT_LEN (EOS is already counted above)

EMB_DIM = 128
FF_DIM = 256
N_HEAD = 4
N_LAYERS = 2
DROPOUT = 0.1

BATCH_SIZE = 128
STEPS = 1500      # optimizer steps (mini-batches); keep small for a mini run
EVAL_EVERY = 50   # run validation / checkpointing every N optimizer steps
LR = 3e-4         # peak learning rate, reached at the end of warm-up
WARMUP = 100      # warm-up length in optimizer steps
TRAIN_SIZE = 8000
VAL_SIZE = 500
SEED = 42
VAL_SEED = 1234   # fixed validation set => comparable val metrics across evals
SAVE_PATH = "mini_transformer_addition.pt"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# -----------------------
# Data: generate on-the-fly mini dataset
# -----------------------
def encode(seq: List[str], max_len: int) -> List[int]:
    """Map tokens to ids, right-pad with PAD and truncate to exactly max_len."""
    ids = [TOK2ID[s] for s in seq]
    ids += [PAD] * (max_len - len(ids))  # negative repeat count -> no padding
    return ids[:max_len]


def sample_pair(rng: Optional[random.Random] = None) -> Tuple[List[int], List[int]]:
    """Draw one random problem "a + b =" with single-digit a, b.

    Args:
        rng: source of randomness; defaults to the global ``random`` module.

    Returns:
        (src, tgt): src is MAX_IN_LEN ids; tgt is [BOS, digits..., EOS]
        right-padded with PAD to TGT_LEN ids.
    """
    r = random if rng is None else rng
    a, b = r.randint(0, 9), r.randint(0, 9)
    inp = [str(a), "+", str(b), "="]  # length 4
    tgt = [BOS] + [TOK2ID[t] for t in str(a + b)] + [EOS]
    tgt += [PAD] * (TGT_LEN - len(tgt))
    return encode(inp, MAX_IN_LEN), tgt


class MiniAddDataset(Dataset):
    """Map-style dataset of ``size`` addition problems.

    With ``seed=None`` every ``__getitem__`` draws a fresh random problem
    (suitable for training). With an integer seed the samples are generated
    once up front, so every pass sees identical data (stable validation).
    """

    def __init__(self, size=5000, seed: Optional[int] = None):
        self.size = size
        self._fixed: Optional[List[Tuple[List[int], List[int]]]] = None
        if seed is not None:
            rng = random.Random(seed)
            self._fixed = [sample_pair(rng) for _ in range(size)]

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        src, tgt = sample_pair() if self._fixed is None else self._fixed[idx]
        return torch.tensor(src), torch.tensor(tgt)


# -----------------------
# Model
# -----------------------
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a (B, L, d_model) input."""

    def __init__(self, d_model, max_len=64):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(pos * div)
        # For odd d_model there is one fewer cosine column than sine columns.
        pe[:, 1::2] = torch.cos(pos * div[: d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        if x.size(1) > self.pe.size(1):
            raise ValueError(
                f"sequence length {x.size(1)} exceeds max_len {self.pe.size(1)}"
            )
        return x + self.pe[:, : x.size(1), :]


class TinyTransformer(nn.Module):
    """Small batch-first encoder-decoder Transformer producing (B, T, V) logits."""

    def __init__(self, vocab_size: int):
        super().__init__()
        self.src_emb = nn.Embedding(vocab_size, EMB_DIM, padding_idx=PAD)
        self.tgt_emb = nn.Embedding(vocab_size, EMB_DIM, padding_idx=PAD)
        self.pos_enc_src = PositionalEncoding(EMB_DIM, max_len=MAX_IN_LEN + 8)
        self.pos_enc_tgt = PositionalEncoding(EMB_DIM, max_len=MAX_OUT_LEN + 8)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=EMB_DIM, nhead=N_HEAD, dim_feedforward=FF_DIM,
            dropout=DROPOUT, batch_first=True,
        )
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=EMB_DIM, nhead=N_HEAD, dim_feedforward=FF_DIM,
            dropout=DROPOUT, batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=N_LAYERS)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=N_LAYERS)
        self.lm_head = nn.Linear(EMB_DIM, vocab_size)

    def make_padding_mask(self, seq, pad_idx=PAD):
        # (B, L) bool mask, True at padding positions: the *_key_padding_mask
        # layout the batch_first nn.Transformer* modules accept.
        return seq == pad_idx

    def generate_square_subsequent_mask(self, sz, device=None):
        # (sz, sz) bool causal mask: True above the diagonal blocks attention
        # to future positions. Built on `device` (default: DEVICE).
        dev = DEVICE if device is None else device
        return torch.triu(torch.ones(sz, sz, device=dev), diagonal=1).bool()

    def forward(self, src_ids, tgt_ids):
        # src_ids: (B, S) ; tgt_ids: (B, T)
        src_key_padding_mask = self.make_padding_mask(src_ids)  # (B, S)
        tgt_key_padding_mask = self.make_padding_mask(tgt_ids)  # (B, T)
        tgt_mask = self.generate_square_subsequent_mask(
            tgt_ids.size(1), device=tgt_ids.device
        )

        src = self.pos_enc_src(self.src_emb(src_ids))
        memory = self.encoder(src, src_key_padding_mask=src_key_padding_mask)

        tgt = self.pos_enc_tgt(self.tgt_emb(tgt_ids))
        out = self.decoder(
            tgt, memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask,
        )
        return self.lm_head(out)  # (B, T, V)


# -----------------------
# Training utils
# -----------------------
class WarmupAdam(torch.optim.Adam):
    """Adam with a "Noam"-style schedule normalised to peak at ``lr``.

    Before each update the learning rate is set to
    ``lr * min(step / warmup, sqrt(warmup / step))``: linear warm-up to ``lr``
    over ``warmup_steps`` optimizer steps, then inverse-square-root decay.

    Raises:
        ValueError: if ``warmup_steps`` < 1 (the formula divides by it).
    """

    def __init__(self, params, lr, warmup_steps=1000):
        if warmup_steps < 1:
            raise ValueError(f"warmup_steps must be >= 1, got {warmup_steps}")
        super().__init__(params, lr=lr, betas=(0.9, 0.98), eps=1e-9)
        self.warmup_steps = warmup_steps
        self._step_num = 0
        self._base_lr = lr

    def step(self, closure=None):
        self._step_num += 1
        scale = min(self._step_num ** (-0.5),
                    self._step_num * (self.warmup_steps ** (-1.5)))
        for g in self.param_groups:
            g["lr"] = self._base_lr * scale * (self.warmup_steps ** 0.5)
        return super().step(closure=closure)


def shift_right(tgt):
    """Split targets for teacher forcing.

    ``tgt`` is [BOS, y..., EOS, PAD...]; the decoder input is ``tgt[:, :-1]``
    and the labels are ``tgt[:, 1:]``, so position t predicts token t+1.
    """
    return tgt[:, :-1], tgt[:, 1:]


def accuracy_from_logits(logits, labels, pad=PAD):
    """Token-level accuracy over non-PAD labels. logits: (B,T,V), labels: (B,T)."""
    preds = logits.argmax(-1)
    mask = labels.ne(pad)
    correct = (preds.eq(labels) & mask).sum().item()
    total = mask.sum().item() + 1e-9
    return correct / total


def _batch_loss_acc(model, criterion, src, tgt):
    """Forward one batch; return (loss tensor, token accuracy float)."""
    src, tgt = src.to(DEVICE), tgt.to(DEVICE)
    dec_inp, labels = shift_right(tgt)
    logits = model(src, dec_inp)
    loss = criterion(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
    return loss, accuracy_from_logits(logits, labels)


def _optimize(loss, model, optim):
    """Backprop ``loss`` and take one clipped optimizer step."""
    optim.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optim.step()


def run_epoch(dl, train=True, *, model, criterion, optim=None):
    """One full pass over ``dl``.

    Returns the batch-size-weighted mean (loss, token accuracy). When
    ``train`` is False no autograd graph is built and no update is made.

    Raises:
        ValueError: if ``train`` is True and no ``optim`` is given.
    """
    if train and optim is None:
        raise ValueError("optim is required when train=True")
    model.train(train)
    total_loss, total_acc, n = 0.0, 0.0, 0
    with torch.set_grad_enabled(train):
        for src, tgt in dl:
            loss, acc = _batch_loss_acc(model, criterion, src, tgt)
            if train:
                _optimize(loss, model, optim)
            bs = src.size(0)
            total_loss += loss.item() * bs
            total_acc += acc * bs
            n += bs
    n = max(n, 1)  # empty loader -> (0.0, 0.0) instead of ZeroDivisionError
    return total_loss / n, total_acc / n


def _cycle(dl) -> Iterator:
    """Yield batches forever, starting a fresh epoch (re-shuffled/resampled)
    each time; unlike itertools.cycle, batches are not cached."""
    while True:
        yield from dl


# -----------------------
# Inference helpers
# -----------------------
def encode_inp(a: int, b: int):
    """Encode "a + b =" as a (1, MAX_IN_LEN) id tensor on DEVICE."""
    seq = [str(a), "+", str(b), "="]
    return torch.tensor([encode(seq, MAX_IN_LEN)], device=DEVICE)


@torch.no_grad()
def greedy_decode(src_ids, max_len=MAX_OUT_LEN + 2, *, model):
    """Greedy autoregressive decoding.

    src_ids: (B, S) token ids. The output starts with BOS followed by at most
    ``max_len - 1`` generated ids; decoding stops once every row contains
    EOS. Returns a flat list for B == 1, otherwise a list of lists.
    """
    model.eval()
    ys = torch.full((src_ids.size(0), 1), BOS, dtype=torch.long,
                    device=src_ids.device)
    for _ in range(max_len - 1):
        logits = model(src_ids, ys)
        next_tok = logits[:, -1, :].argmax(-1, keepdim=True)  # (B, 1)
        ys = torch.cat([ys, next_tok], dim=1)
        if (ys == EOS).any(dim=1).all():
            break
    return ys.squeeze(0).tolist()


def detok(ids: List[int]) -> str:
    """Turn decoded ids into text: skip PAD/BOS, stop at the first EOS."""
    out = []
    for i in ids:
        if i == EOS:
            break
        if i in (PAD, BOS):
            continue
        out.append(ID2TOK[i])
    return "".join(out)


# -----------------------
# Train + demo
# -----------------------
def main():
    torch.manual_seed(SEED)
    random.seed(SEED)

    train_dl = DataLoader(MiniAddDataset(size=TRAIN_SIZE),
                          batch_size=BATCH_SIZE, shuffle=True)
    val_dl = DataLoader(MiniAddDataset(size=VAL_SIZE, seed=VAL_SEED),
                        batch_size=BATCH_SIZE)

    model = TinyTransformer(vocab_size=len(VOCAB)).to(DEVICE)
    criterion = nn.CrossEntropyLoss(ignore_index=PAD)
    optim = WarmupAdam(model.parameters(), lr=LR, warmup_steps=WARMUP)

    batches = _cycle(train_dl)
    best_val = float("-inf")
    saved = False
    win_loss, win_acc, win_n = 0.0, 0.0, 0  # train metrics since last eval
    for step in range(1, STEPS + 1):
        model.train()
        src, tgt = next(batches)
        loss, acc = _batch_loss_acc(model, criterion, src, tgt)
        _optimize(loss, model, optim)
        bs = src.size(0)
        win_loss += loss.item() * bs
        win_acc += acc * bs
        win_n += bs

        if step % EVAL_EVERY == 0:
            tr_loss, tr_acc = win_loss / win_n, win_acc / win_n
            win_loss, win_acc, win_n = 0.0, 0.0, 0
            val_loss, val_acc = run_epoch(val_dl, train=False,
                                          model=model, criterion=criterion)
            print(f"[step {step:4d}] train loss {tr_loss:.3f} acc {tr_acc:.3f} | "
                  f"val loss {val_loss:.3f} acc {val_acc:.3f}")
            if val_acc > best_val:
                best_val = val_acc
                torch.save({"model": model.state_dict()}, SAVE_PATH)
                saved = True

    if saved:
        print(f"Saved best model to: {SAVE_PATH}")
        # Reload the best checkpoint (the in-memory weights are the last ones).
        ckpt = torch.load(SAVE_PATH, map_location=DEVICE)
        model.load_state_dict(ckpt["model"])

    for a, b in [(3, 5), (9, 8), (0, 0), (7, 2), (4, 6)]:
        out_ids = greedy_decode(encode_inp(a, b), model=model)
        print(f"{a}+{b} => {detok(out_ids)}")


if __name__ == "__main__":
    main()