ankitkushwaha90 committed on
Commit
4efecb4
·
verified ·
1 Parent(s): 6fc6ec8

Create transformer-CNN.py

Browse files
Files changed (1) hide show
  1. transformer-CNN.py +114 -0
transformer-CNN.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# cnn_transformer_mini.py
# Hybrid CNN + Transformer classifier trained on a small MNIST subset.

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import random

# ------------------------
# Config
# ------------------------
BATCH_SIZE = 64
EPOCHS = 3
LR = 1e-3
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ------------------------
# Mini Dataset (MNIST subset)
# ------------------------
# ToTensor scales pixels to [0, 1]; Normalize then maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

train_dataset = datasets.MNIST(
    root="./data", train=True, download=True, transform=transform
)
test_dataset = datasets.MNIST(
    root="./data", train=False, download=True, transform=transform
)

# Keep only a small random sample (2000 train / 500 test) for fast training.
# NOTE(review): random.sample is unseeded, so the subset differs per run.
train_subset = Subset(train_dataset, random.sample(range(len(train_dataset)), 2000))
test_subset = Subset(test_dataset, random.sample(range(len(test_dataset)), 500))

train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_subset, batch_size=BATCH_SIZE)

# ------------------------
# Model
# ------------------------
class CNN_Extractor(nn.Module):
    """Convolutional frontend: turns a (B, 1, 28, 28) image batch into
    (B, 64, 7, 7) feature maps via two conv/ReLU/pool stages."""

    def __init__(self):
        super().__init__()
        stages = [
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # -> 32 x 14 x 14
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # -> 64 x 7 x 7
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Return the (B, 64, 7, 7) feature maps for input images x."""
        return self.conv(x)
class TransformerClassifier(nn.Module):
    """CNN feature extractor followed by a Transformer encoder over the
    7x7 spatial grid; mean-pooled tokens feed a linear classification head."""

    def __init__(self, num_classes=10, emb_dim=64, nhead=4, num_layers=2, ff_dim=128):
        super().__init__()
        self.cnn = CNN_Extractor()

        # Each of the 7*7 = 49 spatial positions becomes one token.
        self.pos_emb = nn.Parameter(torch.randn(1, 49, emb_dim))
        self.proj = nn.Linear(64, emb_dim)

        layer = nn.TransformerEncoderLayer(
            d_model=emb_dim, nhead=nhead, dim_feedforward=ff_dim, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)

        self.cls_head = nn.Linear(emb_dim, num_classes)

    def forward(self, x):
        """Return class logits (B, num_classes) for images x of shape (B, 1, 28, 28)."""
        features = self.cnn(x)  # (B, 64, 7, 7)
        batch, channels, height, width = features.shape

        # Flatten the spatial grid into a token sequence: (B, 49, 64).
        tokens = features.permute(0, 2, 3, 1).reshape(batch, height * width, channels)
        tokens = self.proj(tokens) + self.pos_emb  # (B, 49, emb_dim)

        encoded = self.encoder(tokens)  # (B, 49, emb_dim)
        pooled = encoded.mean(dim=1)    # average-pool over the 49 tokens
        return self.cls_head(pooled)
# ------------------------
# Train & Evaluate
# ------------------------
model = TransformerClassifier().to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)

for epoch in range(EPOCHS):
    model.train()
    total_loss, total_correct = 0, 0
    for batch_imgs, batch_labels in train_loader:
        batch_imgs = batch_imgs.to(DEVICE)
        batch_labels = batch_labels.to(DEVICE)

        logits = model(batch_imgs)
        loss = criterion(logits, batch_labels)

        # Standard optimizer step: clear grads, backprop, update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate per-sample loss and hit counts for epoch averages.
        total_loss += loss.item() * batch_imgs.size(0)
        total_correct += (logits.argmax(1) == batch_labels).sum().item()

    acc = total_correct / len(train_loader.dataset)
    print(f"Epoch {epoch+1}: Train Loss={total_loss/len(train_loader.dataset):.4f}, Acc={acc:.4f}")

# Eval
model.eval()
correct = 0
with torch.no_grad():
    for batch_imgs, batch_labels in test_loader:
        batch_imgs = batch_imgs.to(DEVICE)
        batch_labels = batch_labels.to(DEVICE)
        preds = model(batch_imgs)
        correct += (preds.argmax(1) == batch_labels).sum().item()

print(f"Test Accuracy: {correct/len(test_loader.dataset):.4f}")