# cnn_transformer_mini.py
# Hybrid CNN + Transformer for MNIST (mini dataset)
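#
# Pipeline: CNN backbone -> 7x7 feature map flattened into 49 tokens
# -> learned positional embeddings -> TransformerEncoder -> mean pooling
# -> linear classification head.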

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import random

# ------------------------
# Config
# ------------------------
BATCH_SIZE = 64
EPOCHS = 3
LR = 1e-3
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
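
# Optional seeding (an assumption, not part of the original script): makes the
# random train/test subsets below reproducible across runs.
torch.manual_seed(0)
random.seed(0)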

# ------------------------
# Mini Dataset (MNIST subset)
# ------------------------
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root="./data", train=False, download=True, transform=transform)

# Keep only a small random subset for quick training
train_subset = Subset(train_dataset, random.sample(range(len(train_dataset)), 2000))
test_subset = Subset(test_dataset, random.sample(range(len(test_dataset)), 500))

train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_subset, batch_size=BATCH_SIZE)
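
# Each batch from these loaders is (imgs, labels): imgs of shape (B, 1, 28, 28)
# in [-1, 1] after normalization, labels of shape (B,) with values 0-9.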

# ------------------------
# Model
# ------------------------
class CNN_Extractor(nn.Module):
    """Simple CNN to get feature maps"""
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),   # 32 x 14 x 14
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),   # 64 x 7 x 7
        )

    def forward(self, x):
        return self.conv(x)  # (B, 64, 7, 7)

class TransformerClassifier(nn.Module):
    def __init__(self, num_classes=10, emb_dim=64, nhead=4, num_layers=2, ff_dim=128):
        super().__init__()
        self.cnn = CNN_Extractor()

        # Flatten the 7x7 CNN feature map into a 49-token sequence (assumes 28x28 input)
        self.pos_emb = nn.Parameter(torch.randn(1, 49, emb_dim))
        self.proj = nn.Linear(64, emb_dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=emb_dim, nhead=nhead, dim_feedforward=ff_dim, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
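        # Note: nn.TransformerEncoderLayer requires d_model % nhead == 0 (here 64 / 4 = 16 dims per head).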

        self.cls_head = nn.Linear(emb_dim, num_classes)

    def forward(self, x):
        feat = self.cnn(x)                                    # (B, 64, 7, 7)
        B, C, H, W = feat.shape
        seq = feat.permute(0, 2, 3, 1).reshape(B, H * W, C)   # (B, 49, 64)
        seq = self.proj(seq) + self.pos_emb                   # (B, 49, emb_dim)
        enc = self.encoder(seq)                               # (B, 49, emb_dim)
        pooled = enc.mean(dim=1)                              # mean-pool tokens (no CLS token is used)
        return self.cls_head(pooled)
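
# Quick shape sanity check (a minimal sketch): a dummy MNIST-sized batch of 2
# images should yield one logit per class, i.e. shape (2, 10).
with torch.no_grad():
    _logits = TransformerClassifier()(torch.randn(2, 1, 28, 28))
assert _logits.shape == (2, 10), f"unexpected logits shape: {_logits.shape}"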

# ------------------------
# Train & Evaluate
# ------------------------
model = TransformerClassifier().to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)

for epoch in range(EPOCHS):
    model.train()
    total_loss, total_correct = 0, 0
    for imgs, labels in train_loader:
        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
        out = model(imgs)
        loss = criterion(out, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * imgs.size(0)
        total_correct += (out.argmax(1) == labels).sum().item()

    acc = total_correct / len(train_loader.dataset)
    print(f"Epoch {epoch+1}: Train Loss={total_loss / len(train_loader.dataset):.4f}, Acc={acc:.4f}")

# Eval
model.eval()
correct = 0
with torch.no_grad():
    for imgs, labels in test_loader:
        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
        out = model(imgs)
        correct += (out.argmax(1) == labels).sum().item()

print(f"Test Accuracy: {correct/len(test_loader.dataset):.4f}")