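# Minimal encoder-decoder Transformer that learns to add two single-digit
# numbers, e.g. "3+5=" -> "8". Training pairs are generated on the fly; the
# best checkpoint by validation accuracy is saved and reused for greedy decoding.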
import math
import random
from typing import List, Tuple

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader


VOCAB = ["0","1","2","3","4","5","6","7","8","9","+","=","<pad>","<s>","</s>"]
PAD, BOS, EOS = VOCAB.index("<pad>"), VOCAB.index("<s>"), VOCAB.index("</s>")
TOK2ID = {t: i for i, t in enumerate(VOCAB)}
ID2TOK = {i: t for t, i in TOK2ID.items()}
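
# Sanity check: the token <-> id mapping should round-trip exactly.
assert all(ID2TOK[TOK2ID[t]] == t for t in VOCAB)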


MAX_IN_LEN = 4
MAX_OUT_LEN = 3
EMB_DIM = 128
FF_DIM = 256
N_HEAD = 4
N_LAYERS = 2
DROPOUT = 0.1
BATCH_SIZE = 128
STEPS = 1500
LR = 3e-4
WARMUP = 100
SAVE_PATH = "mini_transformer_addition.pt"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(42)
random.seed(42)


def encode(seq: List[str], max_len: int) -> List[int]:
    """Map tokens to ids, then pad (or truncate) to max_len."""
    ids = [TOK2ID[s] for s in seq]
    if len(ids) < max_len:
        ids += [PAD] * (max_len - len(ids))
    return ids[:max_len]


def sample_pair() -> Tuple[List[int], List[int]]:
    """Generate one training pair: src encodes "a+b=", tgt is [BOS, digits of a+b, EOS] plus padding."""
    a, b = random.randint(0, 9), random.randint(0, 9)
    c = a + b
    inp = [str(a), "+", str(b), "="]
    out_tokens = list(str(c))
    tgt = [BOS] + [TOK2ID[t] for t in out_tokens] + [EOS]

    # Pad the target to a fixed length of MAX_OUT_LEN + 2 (room for BOS and EOS).
    max_len = MAX_OUT_LEN + 2
    if len(tgt) < max_len:
        tgt += [PAD] * (max_len - len(tgt))
    return encode(inp, MAX_IN_LEN), tgt
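
# Example: with a=3 and b=5, sample_pair() yields src ids for ["3", "+", "5", "="]
# and tgt ids for [<s>, "8", </s>, <pad>, <pad>] (MAX_OUT_LEN + 2 = 5 positions).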


class MiniAddDataset(Dataset):
    """Samples a fresh random addition pair on every access; `size` only sets the epoch length."""

    def __init__(self, size=5000):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        src, tgt = sample_pair()
        return torch.tensor(src), torch.tensor(tgt)


train_ds = MiniAddDataset(size=8000)
val_ds = MiniAddDataset(size=500)
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE)
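
# Each batch is a pair of LongTensors shaped (batch, MAX_IN_LEN) and
# (batch, MAX_OUT_LEN + 2), e.g.:
#   src_batch, tgt_batch = next(iter(train_dl))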


class PositionalEncoding(nn.Module):
    """Standard sinusoidal positional encoding:
    PE(pos, 2i) = sin(pos / 10000^(2i/d_model)), PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)).
    """

    def __init__(self, d_model, max_len=64):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer("pe", pe.unsqueeze(0))  # shape (1, max_len, d_model)

    def forward(self, x):
        # x: (batch, seq_len, d_model); add the encoding for the first seq_len positions.
        return x + self.pe[:, :x.size(1), :]


class TinyTransformer(nn.Module):
    def __init__(self, vocab_size: int):
        super().__init__()
        self.src_emb = nn.Embedding(vocab_size, EMB_DIM, padding_idx=PAD)
        self.tgt_emb = nn.Embedding(vocab_size, EMB_DIM, padding_idx=PAD)
        self.pos_enc_src = PositionalEncoding(EMB_DIM, max_len=MAX_IN_LEN + 8)
        self.pos_enc_tgt = PositionalEncoding(EMB_DIM, max_len=MAX_OUT_LEN + 8)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=EMB_DIM, nhead=N_HEAD, dim_feedforward=FF_DIM, dropout=DROPOUT, batch_first=True
        )
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=EMB_DIM, nhead=N_HEAD, dim_feedforward=FF_DIM, dropout=DROPOUT, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=N_LAYERS)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=N_LAYERS)
        self.lm_head = nn.Linear(EMB_DIM, vocab_size)

    def make_padding_mask(self, seq, pad_idx=PAD):
        # True at positions that hold padding and should be ignored by attention.
        return seq == pad_idx

    def generate_square_subsequent_mask(self, sz):
        # Boolean causal mask: True above the diagonal blocks attention to future positions.
        return torch.triu(torch.ones(sz, sz, device=DEVICE), diagonal=1).bool()

    def forward(self, src_ids, tgt_ids):
        # Padding masks for encoder and decoder, plus a causal mask for the decoder.
        src_key_padding_mask = self.make_padding_mask(src_ids)
        tgt_key_padding_mask = self.make_padding_mask(tgt_ids)
        tgt_mask = self.generate_square_subsequent_mask(tgt_ids.size(1))

        # Encode the source sequence.
        src = self.src_emb(src_ids)
        src = self.pos_enc_src(src)
        memory = self.encoder(src, src_key_padding_mask=src_key_padding_mask)

        # Decode with causal self-attention and cross-attention over the encoder memory.
        tgt = self.tgt_emb(tgt_ids)
        tgt = self.pos_enc_tgt(tgt)
        out = self.decoder(
            tgt, memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask,
        )
        logits = self.lm_head(out)
        return logits
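
# Shapes with batch_first=True: src_ids (B, MAX_IN_LEN), tgt_ids (B, T),
# and the returned logits (B, T, vocab_size).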


class WarmupAdam(torch.optim.Adam):
    """Adam with a Noam-style schedule: linear warmup to the base LR, then inverse-sqrt decay."""

    def __init__(self, params, lr, warmup_steps=1000):
        super().__init__(params, lr=lr, betas=(0.9, 0.98), eps=1e-9)
        self.warmup_steps = warmup_steps
        self._step = 0
        self._base_lr = lr

    def step(self, closure=None):
        self._step += 1
        # scale ramps linearly while step < warmup_steps, then decays as step^-0.5;
        # the warmup_steps^0.5 factor makes the peak LR equal to _base_lr.
        scale = min(self._step ** (-0.5), self._step * (self.warmup_steps ** (-1.5)))
        for g in self.param_groups:
            g['lr'] = self._base_lr * scale * (self.warmup_steps ** 0.5)
        return super().step(closure=closure)


def shift_right(tgt):
    """
    Teacher forcing: the model sees BOS + y[:-1] and predicts y.
    Here tgt is already [BOS, y..., EOS, PAD...], so we return
    inp = tgt[:, :-1] and labels = tgt[:, 1:].
    """
    return tgt[:, :-1], tgt[:, 1:]
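
# Example: a tgt row [<s>, "8", </s>, <pad>, <pad>] is split into decoder
# input [<s>, "8", </s>, <pad>] and labels ["8", </s>, <pad>, <pad>].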


def accuracy_from_logits(logits, labels, pad=PAD):
    # Token-level accuracy over non-padding positions.
    preds = logits.argmax(-1)
    mask = labels.ne(pad)
    correct = (preds.eq(labels) & mask).sum().item()
    total = mask.sum().item() + 1e-9
    return correct / total


model = TinyTransformer(vocab_size=len(VOCAB)).to(DEVICE)
criterion = nn.CrossEntropyLoss(ignore_index=PAD)
optim = WarmupAdam(model.parameters(), lr=LR, warmup_steps=WARMUP)


def run_epoch(dl, train=True):
    # One full pass over the DataLoader; returns mean loss and token accuracy.
    model.train(train)
    total_loss, total_acc, n = 0.0, 0.0, 0
    for src, tgt in dl:
        src, tgt = src.to(DEVICE), tgt.to(DEVICE)
        dec_inp, labels = shift_right(tgt)
        # Disable autograd during evaluation passes.
        with torch.set_grad_enabled(train):
            logits = model(src, dec_inp)
            loss = criterion(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
        acc = accuracy_from_logits(logits, labels)

        if train:
            optim.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optim.step()

        bs = src.size(0)
        total_loss += loss.item() * bs
        total_acc += acc * bs
        n += bs
    return total_loss / n, total_acc / n


best_val = 0.0
for step in range(1, STEPS + 1):
    # Each "step" here is one full pass over train_dl (8000 freshly sampled pairs).
    tr_loss, tr_acc = run_epoch(train_dl, train=True)
    if step % 50 == 0:
        val_loss, val_acc = run_epoch(val_dl, train=False)
        print(f"[step {step:4d}] train loss {tr_loss:.3f} acc {tr_acc:.3f} | val loss {val_loss:.3f} acc {val_acc:.3f}")
        if val_acc > best_val:
            best_val = val_acc
            torch.save({"model": model.state_dict()}, SAVE_PATH)

print(f"Saved best model to: {SAVE_PATH}")


def encode_inp(a: int, b: int):
    seq = [str(a), "+", str(b), "="]
    return torch.tensor([encode(seq, MAX_IN_LEN)], device=DEVICE)


def greedy_decode(src_ids, max_len=MAX_OUT_LEN + 2):
    model.eval()
    with torch.no_grad():
        # Start from BOS and repeatedly append the highest-probability next token.
        ys = torch.tensor([[BOS]], device=DEVICE)
        for _ in range(max_len - 1):
            logits = model(src_ids, ys)
            next_tok = logits[:, -1, :].argmax(-1, keepdim=True)
            ys = torch.cat([ys, next_tok], dim=1)
            if next_tok.item() == EOS:
                break
    return ys.squeeze(0).tolist()


def detok(ids: List[int]) -> str:
    # Drop PAD/BOS, stop at EOS, and join the remaining digit tokens.
    toks = [ID2TOK[i] for i in ids if i not in (PAD, BOS)]
    out = []
    for t in toks:
        if t == "</s>":
            break
        out.append(t)
    return "".join(out)


# Reload the best checkpoint and run a few greedy-decoded examples.
ckpt = torch.load(SAVE_PATH, map_location=DEVICE)
model.load_state_dict(ckpt["model"])

for (a, b) in [(3, 5), (9, 8), (0, 0), (7, 2), (4, 6)]:
    src = encode_inp(a, b)
    out_ids = greedy_decode(src)
    print(f"{a}+{b} => {detok(out_ids)}")