Create transformer-CNN.py
Browse files- transformer-CNN.py +114 -0
transformer-CNN.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# cnn_transformer_mini.py
|
2 |
+
# Hybrid CNN + Transformer for MNIST (mini dataset)
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.optim as optim
|
7 |
+
from torchvision import datasets, transforms
|
8 |
+
from torch.utils.data import DataLoader, Subset
|
9 |
+
import random
|
10 |
+
|
11 |
+
# ------------------------
# Config
# ------------------------
BATCH_SIZE = 64   # samples per mini-batch
EPOCHS = 3        # full passes over the (small) training subset
LR = 1e-3         # Adam learning rate
# Prefer the GPU when one is visible; otherwise fall back to CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
18 |
+
|
19 |
+
# ------------------------
# Mini Dataset (MNIST subset)
# ------------------------
# ToTensor gives pixels in [0, 1]; Normalize((0.5,), (0.5,)) maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root="./data", train=False, download=True, transform=transform)

# Keep only a small random slice of each split so training finishes quickly.
train_indices = random.sample(range(len(train_dataset)), 2000)
test_indices = random.sample(range(len(test_dataset)), 500)
train_subset = Subset(train_dataset, train_indices)
test_subset = Subset(test_dataset, test_indices)

train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_subset, batch_size=BATCH_SIZE)
36 |
+
|
37 |
+
# ------------------------
|
38 |
+
# Model
|
39 |
+
# ------------------------
|
40 |
+
class CNN_Extractor(nn.Module):
    """Two-stage convolutional feature extractor.

    Maps a (B, 1, 28, 28) MNIST batch to a (B, 64, 7, 7) feature map;
    each MaxPool2d(2) halves the spatial resolution.
    """

    def __init__(self):
        super().__init__()
        # conv -> relu -> pool, twice; channels grow 1 -> 32 -> 64.
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),   # -> (B, 32, 14, 14)
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),   # -> (B, 64, 7, 7)
        )

    def forward(self, x):
        """Return the (B, 64, 7, 7) feature map for the image batch *x*."""
        return self.conv(x)
|
54 |
+
|
55 |
+
class TransformerClassifier(nn.Module):
    """CNN feature extractor followed by a Transformer encoder classifier.

    The CNN output (B, C, H, W) is flattened into H*W tokens of width C,
    projected to *emb_dim*, tagged with a learned positional embedding,
    encoded, mean-pooled over tokens, and classified.

    Args:
        num_classes: size of the output logit vector.
        emb_dim: token embedding width fed to the encoder.
        nhead: attention heads per encoder layer (must divide emb_dim).
        num_layers: number of stacked encoder layers.
        ff_dim: hidden width of each layer's feed-forward block.
        num_tokens: sequence length produced by the CNN (H*W; 7*7 = 49 for
            28x28 MNIST inputs). Exposed so other input sizes can be used
            instead of the previously hard-coded 49.
    """

    def __init__(self, num_classes=10, emb_dim=64, nhead=4, num_layers=2,
                 ff_dim=128, num_tokens=49):
        super().__init__()
        self.cnn = CNN_Extractor()

        # Learned positional embedding, one vector per spatial token.
        self.pos_emb = nn.Parameter(torch.randn(1, num_tokens, emb_dim))
        # Project the CNN's 64 channels into the encoder's embedding space.
        self.proj = nn.Linear(64, emb_dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=emb_dim, nhead=nhead, dim_feedforward=ff_dim, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.cls_head = nn.Linear(emb_dim, num_classes)

    def forward(self, x):
        """Return (B, num_classes) logits for the image batch *x*.

        Raises:
            ValueError: if the CNN's H*W token count does not match the
                positional embedding (previously a cryptic broadcast error).
        """
        feat = self.cnn(x)                                   # (B, 64, H, W)
        B, C, H, W = feat.shape
        if H * W != self.pos_emb.shape[1]:
            raise ValueError(
                f"expected {self.pos_emb.shape[1]} tokens, got {H * W}; "
                "set num_tokens to match the CNN output size"
            )
        seq = feat.permute(0, 2, 3, 1).reshape(B, H * W, C)  # (B, H*W, C)
        seq = self.proj(seq) + self.pos_emb                  # (B, H*W, emb_dim)
        enc = self.encoder(seq)                              # (B, H*W, emb_dim)
        pooled = enc.mean(dim=1)  # mean-pool tokens (no dedicated CLS token)
        return self.cls_head(pooled)
|
79 |
+
|
80 |
+
# ------------------------
# Train & Evaluate
# ------------------------
model = TransformerClassifier().to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)

for epoch in range(EPOCHS):
    model.train()
    running_loss = 0
    running_correct = 0

    for images, labels in train_loader:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        logits = model(images)
        loss = criterion(logits, labels)

        # Standard backprop step: clear grads, differentiate, update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight the mean batch loss by batch size so the epoch figure
        # is an exact per-sample average even with a ragged last batch.
        running_loss += loss.item() * images.size(0)
        running_correct += (logits.argmax(1) == labels).sum().item()

    n_train = len(train_loader.dataset)
    acc = running_correct / n_train
    print(f"Epoch {epoch+1}: Train Loss={running_loss/n_train:.4f}, Acc={acc:.4f}")
|
104 |
+
|
105 |
+
# Eval
model.eval()
correct = 0
with torch.no_grad():   # inference only: skip autograd bookkeeping
    for images, labels in test_loader:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)
        preds = model(images).argmax(1)
        correct += (preds == labels).sum().item()

print(f"Test Accuracy: {correct/len(test_loader.dataset):.4f}")
|