|
|
|
|
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.optim as optim |
|
from torchvision import datasets, transforms |
|
from torch.utils.data import DataLoader, Subset |
|
import random |
|
|
|
|
|
|
|
|
|
# Run configuration for the MNIST demo.
BATCH_SIZE = 64   # samples per mini-batch
EPOCHS = 3        # passes over the training subset
LR = 1e-3         # Adam learning rate
# Prefer the GPU when one is visible, otherwise fall back to CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
|
|
|
|
|
# Preprocessing: PIL image -> float tensor in [0, 1] -> rescaled to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Full MNIST splits; downloaded into ./data on first run.
train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root="./data", train=False, download=True, transform=transform)

# Sample small random subsets so the demo trains in seconds, not minutes.
# NOTE(review): random.sample is unseeded, so the subsets differ per run.
train_indices = random.sample(range(len(train_dataset)), 2000)
test_indices = random.sample(range(len(test_dataset)), 500)
train_subset = Subset(train_dataset, train_indices)
test_subset = Subset(test_dataset, test_indices)

train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_subset, batch_size=BATCH_SIZE)
|
|
|
|
|
|
|
|
|
class CNN_Extractor(nn.Module):
    """Two-stage convolutional front-end producing 64-channel feature maps.

    Each stage is Conv(3x3, pad 1) -> ReLU -> 2x2 max-pool, so the spatial
    resolution is halved twice (e.g. 28x28 input -> 7x7 output).
    """

    def __init__(self):
        super().__init__()
        stages = [
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Return feature maps of shape (B, 64, H/4, W/4) for input (B, 1, H, W)."""
        return self.conv(x)
|
|
|
class TransformerClassifier(nn.Module):
    """Image classifier: CNN feature extractor -> Transformer encoder over
    the flattened spatial positions -> mean pooling -> linear head.

    Args:
        num_classes: number of output classes.
        emb_dim: embedding dimension fed to the Transformer encoder.
        nhead: number of attention heads per encoder layer.
        num_layers: number of Transformer encoder layers.
        ff_dim: hidden size of the encoder's feed-forward sublayer.
        num_patches: expected sequence length, i.e. H*W of the CNN feature
            map (49 = 7*7 for 28x28 MNIST input after two 2x2 max-pools).
        cnn_channels: channel count produced by CNN_Extractor (64).
    """

    def __init__(self, num_classes=10, emb_dim=64, nhead=4, num_layers=2,
                 ff_dim=128, num_patches=49, cnn_channels=64):
        super().__init__()
        self.cnn = CNN_Extractor()

        # Learned positional embedding: one vector per spatial position,
        # broadcast over the batch dimension.
        self.pos_emb = nn.Parameter(torch.randn(1, num_patches, emb_dim))
        # Project CNN channels into the Transformer embedding dimension.
        self.proj = nn.Linear(cnn_channels, emb_dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=emb_dim, nhead=nhead, dim_feedforward=ff_dim, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.cls_head = nn.Linear(emb_dim, num_classes)

    def forward(self, x):
        """Classify a batch of single-channel images.

        Args:
            x: float tensor of shape (B, 1, H, W); the CNN's H/4 * W/4
               feature grid must equal ``num_patches`` (28x28 by default).

        Returns:
            Logits of shape (B, num_classes).

        Raises:
            ValueError: if the feature-map size does not match the
                positional-embedding length.
        """
        feat = self.cnn(x)  # (B, C, H', W')
        B, C, H, W = feat.shape
        if H * W != self.pos_emb.size(1):
            raise ValueError(
                f"feature map has {H * W} positions, expected {self.pos_emb.size(1)}"
            )
        # Flatten the spatial grid into a sequence of per-position vectors.
        seq = feat.permute(0, 2, 3, 1).reshape(B, H * W, C)
        seq = self.proj(seq) + self.pos_emb
        enc = self.encoder(seq)
        # Mean-pool over positions (no dedicated CLS token is used).
        pooled = enc.mean(dim=1)
        return self.cls_head(pooled)
|
|
|
|
|
|
|
|
|
# --- Training -------------------------------------------------------------
model = TransformerClassifier().to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)

for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0
    running_correct = 0
    for images, targets in train_loader:
        images = images.to(DEVICE)
        targets = targets.to(DEVICE)

        logits = model(images)
        loss = criterion(logits, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight the batch loss by batch size so the epoch average is exact
        # even when the final batch is smaller.
        running_loss += loss.item() * images.size(0)
        running_correct += (logits.argmax(1) == targets).sum().item()

    n_train = len(train_loader.dataset)
    acc = running_correct / n_train
    print(f"Epoch {epoch+1}: Train Loss={running_loss/n_train:.4f}, Acc={acc:.4f}")
|
|
|
|
|
# --- Evaluation on the held-out test subset -------------------------------
model.eval()
correct = 0
with torch.no_grad():  # inference only: skip autograd bookkeeping
    for images, targets in test_loader:
        images, targets = images.to(DEVICE), targets.to(DEVICE)
        preds = model(images).argmax(1)
        correct += (preds == targets).sum().item()

print(f"Test Accuracy: {correct/len(test_loader.dataset):.4f}")
|
|