# rnn_transformer_mini.py
# Hybrid RNN + Transformer for text classification (toy dataset)

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import random
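
# Seed Python's and torch's RNGs so the toy run is reproducible (optional addition)
random.seed(0)
torch.manual_seed(0)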

# ------------------------
# Mini synthetic dataset
# ------------------------
# Class 0: greetings, Class 1: food related
MAX_LEN = 6  # fixed sequence length, shared by encode() and the positional embedding
class MiniTextDataset(Dataset):
    def __init__(self, size=200, vocab=None):
        self.samples = []
        greetings = ["hello", "hi", "good morning", "hey", "greetings"]
        food = ["i love pizza", "burger is tasty", "rice and curry", "pasta is great", "i eat apple"]

        # Draw sentences roughly 50/50 from the two classes
        for _ in range(size):
            if random.random() < 0.5:
                sent = random.choice(greetings)
                label = 0
            else:
                sent = random.choice(food)
                label = 1
            self.samples.append((sent, label))

        # Build a word-level vocab, or reuse one that is passed in.
        # The test set must reuse the training vocab, otherwise the same
        # word can map to different ids in train and test.
        if vocab is not None:
            self.vocab = vocab
        else:
            self.vocab = {"<pad>": 0, "<unk>": 1}
            idx = 2
            for s, _ in self.samples:
                for w in s.split():
                    if w not in self.vocab:
                        self.vocab[w] = idx
                        idx += 1
        self.inv_vocab = {i: w for w, i in self.vocab.items()}

    def encode(self, text, max_len=MAX_LEN):
        # Map words to ids (<unk>=1 for unknown words), then pad with
        # <pad>=0 or truncate to exactly max_len tokens
        ids = [self.vocab.get(w, 1) for w in text.split()]
        if len(ids) < max_len:
            ids += [0] * (max_len - len(ids))
        return ids[:max_len]

    def __len__(self): return len(self.samples)
    def __getitem__(self, idx):
        text, label = self.samples[idx]
        return torch.tensor(self.encode(text)), torch.tensor(label)

train_ds = MiniTextDataset(size=200)
test_ds  = MiniTextDataset(size=50, vocab=train_ds.vocab)  # share the train vocab so ids match
train_dl = DataLoader(train_ds, batch_size=16, shuffle=True)
test_dl  = DataLoader(test_ds, batch_size=16)
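
# Example of encode() behavior, using an assumed partly out-of-vocab phrase:
# in-vocab words get their ids, unknown words become 1, and the result is
# padded with 0 up to MAX_LEN (exact ids depend on the random sample order)
print("encoded:", train_ds.encode("hello there friend"))  # e.g. [id, 1, 1, 0, 0, 0]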

# ------------------------
# Model: RNN + Transformer
# ------------------------
class RNNTransformer(nn.Module):
    def __init__(self, vocab_size, emb_dim=64, rnn_hidden=64, nhead=4,
                 num_layers=2, num_classes=2, max_len=MAX_LEN):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        # BiLSTM first: its output size is rnn_hidden*2 (forward + backward)
        self.rnn = nn.LSTM(emb_dim, rnn_hidden, batch_first=True, bidirectional=True)

        # Learnable positional embedding; max_len must match the dataset's encode()
        self.pos_emb = nn.Parameter(torch.randn(1, max_len, rnn_hidden*2))
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=rnn_hidden*2, nhead=nhead, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.fc = nn.Linear(rnn_hidden*2, num_classes)

    def forward(self, x):
        emb = self.emb(x)                     # (B, L, E)
        rnn_out, _ = self.rnn(emb)            # (B, L, H*2)
        seq = rnn_out + self.pos_emb          # add learnable positional embedding
        # Note: no src_key_padding_mask is passed, so attention also sees
        # <pad> positions; acceptable for this toy task
        enc = self.transformer(seq)           # (B, L, H*2)
        pooled = enc.mean(dim=1)              # mean-pool over the sequence
        return self.fc(pooled)

# ------------------------
# Train & Evaluate
# ------------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = RNNTransformer(vocab_size=len(train_ds.vocab)).to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
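
# Sanity check (a minimal sketch): one forward pass on the untrained model
# should produce logits of shape (batch_size, num_classes)
xb, _ = next(iter(train_dl))
with torch.no_grad():
    assert model(xb.to(DEVICE)).shape == (xb.size(0), 2)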

for epoch in range(5):
    model.train()
    total_loss, total_correct = 0, 0
    for x, y in train_dl:
        x, y = x.to(DEVICE), y.to(DEVICE)
        out = model(x)
        loss = criterion(out, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * x.size(0)
        total_correct += (out.argmax(1) == y).sum().item()

    acc = total_correct / len(train_ds)
    print(f"Epoch {epoch+1}: Train Loss={total_loss/len(train_ds):.4f}, Acc={acc:.4f}")

# Eval
model.eval()
correct = 0
with torch.no_grad():
    for x, y in test_dl:
        x, y = x.to(DEVICE), y.to(DEVICE)
        out = model(x)
        correct += (out.argmax(1) == y).sum().item()
print(f"Test Accuracy: {correct/len(test_ds):.4f}")