# Uploaded by: ankitkushwaha90
# Commit message: Create transformer-CNN.py
# Commit: 4efecb4 (verified)
# cnn_transformer_mini.py
# Hybrid CNN + Transformer for MNIST (mini dataset)
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import random
# ------------------------
# Config
# ------------------------
BATCH_SIZE = 64   # images per mini-batch for both data loaders
EPOCHS = 3        # full passes over the training subset
LR = 1e-3         # Adam learning rate
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# ------------------------
# Mini Dataset (MNIST subset)
# ------------------------
# ToTensor scales pixels to [0, 1]; Normalize((0.5,), (0.5,)) then maps them
# to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root="./data", train=False, download=True, transform=transform)

# Take only a small random subset for quick training.
TRAIN_SUBSET_SIZE = 2000
TEST_SUBSET_SIZE = 500
SUBSET_SEED = 0
# A dedicated, seeded RNG draws the same subset on every run, so accuracy
# numbers are comparable between runs, and it leaves the global `random`
# module state untouched.
_subset_rng = random.Random(SUBSET_SEED)
train_subset = Subset(
    train_dataset, _subset_rng.sample(range(len(train_dataset)), TRAIN_SUBSET_SIZE)
)
test_subset = Subset(
    test_dataset, _subset_rng.sample(range(len(test_dataset)), TEST_SUBSET_SIZE)
)
train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_subset, batch_size=BATCH_SIZE)
# ------------------------
# Model
# ------------------------
class CNN_Extractor(nn.Module):
    """Two-stage convolutional feature extractor.

    Each stage is a 3x3 convolution (stride 1, padding 1, so height and width
    are kept), a ReLU, and a 2x2 max-pool that halves height and width.
    A 1 x 28 x 28 input therefore becomes a 64 x 7 x 7 feature map.
    """

    def __init__(self):
        super().__init__()
        stages = []
        # (in_channels, out_channels) of each conv stage, in order.
        for in_ch, out_ch in ((1, 32), (32, 64)):
            stages.extend([
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2),  # 28 -> 14 after stage 1, 14 -> 7 after stage 2
            ])
        # One flat Sequential places the convolutions at conv.0 and conv.3,
        # which fixes the state_dict key names.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Return the feature maps for ``x``; (B, 64, 7, 7) for (B, 1, 28, 28) input."""
        return self.conv(x)
class TransformerClassifier(nn.Module):
    """CNN features -> Transformer encoder -> mean pooling -> linear classifier.

    Every spatial position of the CNN feature map becomes one token. Each token
    is projected to ``emb_dim``, gets a learned positional embedding, and goes
    through the encoder. Encoder outputs are averaged over tokens and mapped to
    class logits.

    Args:
        num_classes: number of output logits.
        emb_dim: token width (``d_model`` of the encoder).
        nhead: attention heads per encoder layer.
        num_layers: number of stacked encoder layers.
        ff_dim: hidden width of each encoder layer's feed-forward sublayer.
    """

    def __init__(self, num_classes=10, emb_dim=64, nhead=4, num_layers=2, ff_dim=128):
        super().__init__()
        self.cnn = CNN_Extractor()
        # One learned vector per token position. 49 = 7 * 7 assumes the CNN
        # emits a 7x7 map (28x28 inputs); other sizes will fail to broadcast.
        self.pos_emb = nn.Parameter(torch.randn(1, 49, emb_dim))
        # Input width 64 must match the channel count CNN_Extractor produces.
        self.proj = nn.Linear(64, emb_dim)
        # TransformerEncoder clones the given layer num_layers times.
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=emb_dim,
                nhead=nhead,
                dim_feedforward=ff_dim,
                batch_first=True,
            ),
            num_layers=num_layers,
        )
        self.cls_head = nn.Linear(emb_dim, num_classes)

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for the images ``x``."""
        fmap = self.cnn(x)  # (B, 64, 7, 7)
        batch, channels, height, width = fmap.shape
        # Channels-last, then flatten the spatial grid row-major into tokens.
        tokens = fmap.permute(0, 2, 3, 1).reshape(batch, height * width, channels)
        tokens = self.proj(tokens) + self.pos_emb  # (B, 49, emb_dim)
        encoded = self.encoder(tokens)             # (B, 49, emb_dim)
        pooled = encoded.mean(dim=1)               # average over all tokens
        return self.cls_head(pooled)
# ------------------------
# Train & Evaluate
# ------------------------
model = TransformerClassifier().to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)


def _run_train_epoch(net, loader, loss_fn, opt):
    """Run one optimization pass of ``net`` over ``loader``.

    Returns ``(mean_loss, accuracy)`` over every sample in ``loader.dataset``.
    Predictions come from each batch's forward pass, before that batch's
    update is applied.
    """
    net.train()
    loss_sum = 0
    hits = 0
    for batch_imgs, batch_labels in loader:
        batch_imgs = batch_imgs.to(DEVICE)
        batch_labels = batch_labels.to(DEVICE)
        logits = net(batch_imgs)
        batch_loss = loss_fn(logits, batch_labels)
        opt.zero_grad()
        batch_loss.backward()
        opt.step()
        # The loss is a batch mean; weight it by batch size so the epoch
        # average is per sample even when the last batch is smaller.
        loss_sum += batch_loss.item() * batch_imgs.size(0)
        hits += (logits.argmax(1) == batch_labels).sum().item()
    n_samples = len(loader.dataset)
    return loss_sum / n_samples, hits / n_samples


def _evaluate(net, loader):
    """Return the classification accuracy of ``net`` on ``loader`` without tracking gradients."""
    net.eval()
    hits = 0
    with torch.no_grad():
        for batch_imgs, batch_labels in loader:
            batch_imgs, batch_labels = batch_imgs.to(DEVICE), batch_labels.to(DEVICE)
            hits += (net(batch_imgs).argmax(1) == batch_labels).sum().item()
    return hits / len(loader.dataset)


for epoch in range(EPOCHS):
    train_loss, train_acc = _run_train_epoch(model, train_loader, criterion, optimizer)
    print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, Acc={train_acc:.4f}")

# Eval (once, after all training epochs)
print(f"Test Accuracy: {_evaluate(model, test_loader):.4f}")