ankitkushwaha90 committed on
Commit
049f424
·
verified ·
1 Parent(s): 4efecb4

Create transformer-RNN.py

Browse files
Files changed (1) hide show
  1. transformer-RNN.py +114 -0
transformer-RNN.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # transformer-RNN.py
2
+ # Hybrid RNN + Transformer for text classification (toy dataset)
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.optim as optim
7
+ from torch.utils.data import Dataset, DataLoader
8
+ import random
9
+
10
+ # ------------------------
11
+ # Mini synthetic dataset
12
+ # ------------------------
13
+ # Class 0: greetings, Class 1: food related
14
class MiniTextDataset(Dataset):
    """Toy two-class text dataset: label 0 = greetings, label 1 = food phrases.

    Every sample is a phrase drawn at random (using the ``random`` module)
    from a small fixed pool, so seed ``random`` first for reproducible data.

    The vocabulary is derived from the complete phrase pool in a fixed order
    rather than from the randomly drawn samples.  Because of that, every
    instance maps a given word to the same id, which is required when one
    instance is used for training and another for evaluation.

    Args:
        size: Number of ``(sentence, label)`` samples to generate.
        vocab: Optional word -> id mapping to use instead of the built-in one.
            ``encode`` pads with id 0 and maps unknown words to id 1, so a
            custom mapping should reserve those two ids accordingly.

    Attributes:
        samples: List of ``(sentence, label)`` tuples.
        vocab: Word -> id mapping used by ``encode``.
        inv_vocab: Id -> word mapping (inverse of ``vocab``).
    """

    GREETINGS = ("hello", "hi", "good morning", "hey", "greetings")
    FOOD = (
        "i love pizza",
        "burger is tasty",
        "rice and curry",
        "pasta is great",
        "i eat apple",
    )

    def __init__(self, size=200, vocab=None):
        self.samples = []
        for _ in range(size):
            # Roughly balanced classes: coin flip, then pick a phrase.
            if random.random() < 0.5:
                sent = random.choice(self.GREETINGS)
                label = 0
            else:
                sent = random.choice(self.FOOD)
                label = 1
            self.samples.append((sent, label))

        if vocab is None:
            # Deterministic vocabulary: iterate the whole pool in a fixed
            # order so ids never depend on which samples were drawn.
            vocab = {"<pad>": 0, "<unk>": 1}
            for phrase in self.GREETINGS + self.FOOD:
                for word in phrase.split():
                    if word not in vocab:
                        vocab[word] = len(vocab)
        else:
            # Copy so later changes to the caller's dict do not leak in.
            vocab = dict(vocab)

        self.vocab = vocab
        self.inv_vocab = {i: w for w, i in self.vocab.items()}

    def encode(self, text, max_len=6):
        """Convert ``text`` into a list of exactly ``max_len`` token ids.

        The text is split on whitespace; words missing from the vocabulary
        become id 1 (``<unk>``).  Short sequences are right-padded with id 0
        (``<pad>``) and long ones are truncated.
        """
        ids = [self.vocab.get(w, 1) for w in text.split()]
        if len(ids) < max_len:
            ids += [0] * (max_len - len(ids))
        return ids[:max_len]

    def __len__(self):
        """Return the number of generated samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(token_ids, label)`` tensors for the sample at ``idx``."""
        text, label = self.samples[idx]
        return torch.tensor(self.encode(text)), torch.tensor(label)
48
+
49
train_ds = MiniTextDataset(size=200)
test_ds = MiniTextDataset(size=50)
# The model's embedding table is sized from and trained on train_ds.vocab, so
# the test split must encode words with that same word -> id mapping.
# encode() reads self.vocab lazily in __getitem__, so rebinding it here is
# enough to make both splits agree.
test_ds.vocab = train_ds.vocab
test_ds.inv_vocab = train_ds.inv_vocab
train_dl = DataLoader(train_ds, batch_size=16, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=16)
53
+
54
+ # ------------------------
55
+ # Model: RNN + Transformer
56
+ # ------------------------
57
class RNNTransformer(nn.Module):
    """Hybrid classifier: embedding -> bidirectional LSTM -> Transformer encoder.

    The LSTM output (width ``rnn_hidden * 2`` because it is bidirectional)
    gets a learnable positional embedding added, is passed through a stack of
    Transformer encoder layers, mean-pooled over the sequence axis, and
    projected to class logits.

    Args:
        vocab_size: Number of rows in the embedding table.
        emb_dim: Token embedding width.
        rnn_hidden: LSTM hidden size per direction.
        nhead: Number of attention heads in each encoder layer.
        num_layers: Number of Transformer encoder layers.
        num_classes: Number of output logits.
        max_len: Longest sequence length supported by the learnable
            positional embedding (defaults to 6, the dataset's padded length).
    """

    def __init__(self, vocab_size, emb_dim=64, rnn_hidden=64, nhead=4,
                 num_layers=2, num_classes=2, max_len=6):
        super().__init__()
        d_model = rnn_hidden * 2  # forward + backward LSTM states

        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.rnn = nn.LSTM(emb_dim, rnn_hidden, batch_first=True, bidirectional=True)

        # One learnable vector per position, broadcast over the batch axis.
        self.pos_emb = nn.Parameter(torch.randn(1, max_len, d_model))
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, x):
        """Return class logits of shape (B, num_classes) for token ids ``x`` of shape (B, L).

        Raises:
            ValueError: If ``L`` exceeds the positional embedding length.
        """
        seq_len = x.size(1)
        max_len = self.pos_emb.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {max_len}"
            )

        emb = self.emb(x)                           # (B, L, E)
        rnn_out, _ = self.rnn(emb)                  # (B, L, H*2)
        # Slice so sequences shorter than max_len are accepted too.
        seq = rnn_out + self.pos_emb[:, :seq_len]   # add learnable pos emb
        # NOTE(review): no padding mask is passed, so <pad> positions take
        # part in attention and in the average below.
        enc = self.transformer(seq)                 # (B, L, H*2)
        pooled = enc.mean(dim=1)                    # average pooling
        return self.fc(pooled)
78
+
79
+ # ------------------------
80
+ # Train & Evaluate
81
+ # ------------------------
82
# Prefer the GPU when one is available; everything below follows DEVICE.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = RNNTransformer(vocab_size=len(train_ds.vocab)).to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# ---- training: 5 passes over the training loader ----
for epoch in range(5):
    model.train()
    total_loss = 0
    total_correct = 0

    for tokens, labels in train_dl:
        tokens = tokens.to(DEVICE)
        labels = labels.to(DEVICE)

        logits = model(tokens)
        batch_loss = criterion(logits, labels)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # criterion returns the batch mean, so weight it by the batch size
        # to get a per-sample average when dividing by the dataset size.
        total_loss += batch_loss.item() * tokens.size(0)
        predictions = logits.argmax(1)
        total_correct += (predictions == labels).sum().item()

    acc = total_correct / len(train_ds)
    print(f"Epoch {epoch+1}: Train Loss={total_loss/len(train_ds):.4f}, Acc={acc:.4f}")

# ---- evaluation: accuracy on the held-out loader, no gradients ----
model.eval()
correct = 0
with torch.no_grad():
    for tokens, labels in test_dl:
        tokens = tokens.to(DEVICE)
        labels = labels.to(DEVICE)
        predictions = model(tokens).argmax(1)
        correct += (predictions == labels).sum().item()
print(f"Test Accuracy: {correct/len(test_ds):.4f}")
114
+