# rnn_transformer_mini.py
# Hybrid RNN + Transformer for text classification (toy dataset)
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import random
# ------------------------
# Mini synthetic dataset
# ------------------------
# Class 0: greetings, Class 1: food related
class MiniTextDataset(Dataset):
    def __init__(self, size=200, vocab=None):
        self.samples = []
        greetings = ["hello", "hi", "good morning", "hey", "greetings"]
        food = ["i love pizza", "burger is tasty", "rice and curry", "pasta is great", "i eat apple"]
        for _ in range(size):
            if random.random() < 0.5:
                sent = random.choice(greetings)
                label = 0
            else:
                sent = random.choice(food)
                label = 1
            self.samples.append((sent, label))
        # Build the vocabulary, or reuse a shared one so that train and test
        # sets encode words with the same ids.
        if vocab is not None:
            self.vocab = vocab
        else:
            self.vocab = {"<pad>": 0, "<unk>": 1}
            for s, _ in self.samples:
                for w in s.split():
                    if w not in self.vocab:
                        self.vocab[w] = len(self.vocab)
        self.inv_vocab = {i: w for w, i in self.vocab.items()}

    def encode(self, text, max_len=6):
        # Map words to ids (<unk>=1 for unseen words), pad with <pad>=0, truncate to max_len
        ids = [self.vocab.get(w, 1) for w in text.split()]
        if len(ids) < max_len:
            ids += [0] * (max_len - len(ids))
        return ids[:max_len]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        text, label = self.samples[idx]
        return torch.tensor(self.encode(text)), torch.tensor(label)
train_ds = MiniTextDataset(size=200)
# Share the training vocabulary so test sentences are encoded with the same word ids
test_ds = MiniTextDataset(size=50, vocab=train_ds.vocab)
train_dl = DataLoader(train_ds, batch_size=16, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=16)
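# Quick sanity check (illustrative): show how one sentence is encoded and what a
# batch from the DataLoader looks like. Uses only objects defined above.
example_text = "i love pizza"
print("Encoded:", train_ds.encode(example_text))  # e.g. three word ids followed by <pad>=0
xb, yb = next(iter(train_dl))
print("Batch shapes:", xb.shape, yb.shape)        # (16, 6) token ids, (16,) labels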
# ------------------------
# Model: RNN + Transformer
# ------------------------
class RNNTransformer(nn.Module):
    def __init__(self, vocab_size, emb_dim=64, rnn_hidden=64, nhead=4, num_layers=2, num_classes=2):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        # Bidirectional LSTM doubles the feature size to rnn_hidden*2
        self.rnn = nn.LSTM(emb_dim, rnn_hidden, batch_first=True, bidirectional=True)
        self.pos_emb = nn.Parameter(torch.randn(1, 6, rnn_hidden * 2))  # learnable positions, seq_len=6
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=rnn_hidden * 2, nhead=nhead, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(rnn_hidden * 2, num_classes)

    def forward(self, x):
        emb = self.emb(x)              # (B, L, E)
        rnn_out, _ = self.rnn(emb)     # (B, L, H*2)
        seq = rnn_out + self.pos_emb   # add learnable positional embedding
        enc = self.transformer(seq)    # (B, L, H*2)
        pooled = enc.mean(dim=1)       # mean pooling over the sequence
        return self.fc(pooled)
# ------------------------
# Train & Evaluate
# ------------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = RNNTransformer(vocab_size=len(train_ds.vocab)).to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
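# Shape check (illustrative): run a dummy batch of zeros through the model to
# confirm the output is (batch, num_classes); the batch size of 4 is arbitrary.
with torch.no_grad():
    _dummy = torch.zeros(4, 6, dtype=torch.long, device=DEVICE)
    print("Model output shape:", model(_dummy).shape)  # torch.Size([4, 2])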
for epoch in range(5):
    model.train()
    total_loss, total_correct = 0, 0
    for x, y in train_dl:
        x, y = x.to(DEVICE), y.to(DEVICE)
        out = model(x)
        loss = criterion(out, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * x.size(0)
        total_correct += (out.argmax(1) == y).sum().item()
    acc = total_correct / len(train_ds)
    print(f"Epoch {epoch+1}: Train Loss={total_loss/len(train_ds):.4f}, Acc={acc:.4f}")
# Eval
model.eval()
correct = 0
with torch.no_grad():
    for x, y in test_dl:
        x, y = x.to(DEVICE), y.to(DEVICE)
        out = model(x)
        correct += (out.argmax(1) == y).sum().item()
print(f"Test Accuracy: {correct/len(test_ds):.4f}")