import pytorch_lightning as pl
import torch
import torch.nn as nn
from torchmetrics import Accuracy


class SimilarityVectorModel(pl.LightningModule):
    """Feed-forward binary classifier over a vector of per-attribute
    similarity scores."""

    def __init__(self, lr, weight_decay, optimizer, batch_size, attrs, hidden_sizes):
        super().__init__()

        self.attrs = attrs
        self.lr = lr
        self.weight_decay = weight_decay
        self.optimizer = optimizer
        self.batch_size = batch_size
        self.save_hyperparameters()

        # Build an MLP with one input per attribute, the requested hidden
        # layers, and a single output logit; ReLU follows every layer except
        # the last, which stays linear so it can feed BCEWithLogitsLoss.
        layer_sizes = [len(attrs)] + hidden_sizes + [1]
        layers = []
        for i in range(len(layer_sizes) - 1):
            in_size, out_size = layer_sizes[i], layer_sizes[i + 1]
            layers.append(nn.Linear(in_size, out_size))
            if i < len(layer_sizes) - 2:
                layers.append(nn.ReLU())
        self.layers = nn.Sequential(*layers)
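        # With hypothetical hyperparameters, e.g. four attrs and
        # hidden_sizes=[16, 8], the stack above is:
        #   Linear(4, 16) -> ReLU -> Linear(16, 8) -> ReLU -> Linear(8, 1)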

        # BCEWithLogitsLoss works on raw logits for numerical stability;
        # the sigmoid is kept separate and applied only at inference time.
        self.sigmoid = nn.Sigmoid()
        self.criterion = nn.BCEWithLogitsLoss()
        self.accuracy = Accuracy(task="binary")

    def forward(self, x):
        # Return raw logits; the loss applies the sigmoid internally.
        return self.layers(x)

    def predict(self, x):
        # Probabilities in [0, 1] for inference.
        return self.sigmoid(self(x))
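
    # Assumed usage, not part of the original module: threshold predict()'s
    # probabilities at 0.5 for a hard decision, e.g. model.predict(x) > 0.5.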

    def training_step(self, batch, batch_idx):
        sim, label = batch
        pred = self(sim.float())
        # BCEWithLogitsLoss expects float targets shaped like the logits.
        label = label.unsqueeze(1).float()

        loss = self.criterion(pred, label)
        acc = self.accuracy(pred, label.long())

        self.log("train_loss", loss, on_step=False, on_epoch=True)
        self.log("train_acc", acc, on_step=False, on_epoch=True)

        return loss

    def validation_step(self, batch, batch_idx):
        sim, label = batch
        pred = self(sim.float())
        label = label.unsqueeze(1).float()

        loss = self.criterion(pred, label)
        acc = self.accuracy(pred, label.long())

        self.log("val_loss", loss, on_step=False, on_epoch=True)
        self.log("val_acc", acc, on_step=False, on_epoch=True, prog_bar=True)

        return loss

    def test_step(self, batch, batch_idx):
        sim, label = batch
        pred = self(sim.float())
        label = label.unsqueeze(1).float()

        loss = self.criterion(pred, label)
        acc = self.accuracy(pred, label.long())

        self.log("test_loss", loss, on_step=False, on_epoch=True)
        self.log("test_acc", acc, on_step=False, on_epoch=True, prog_bar=True)

        return loss

    def configure_optimizers(self):
        # Resolve the optimizer name stored as a hyperparameter to its
        # torch.optim class.
        optimizers = {
            "Adadelta": torch.optim.Adadelta,
            "Adagrad": torch.optim.Adagrad,
            "Adam": torch.optim.Adam,
            "RMSprop": torch.optim.RMSprop,
            "SGD": torch.optim.SGD,
        }
        return optimizers[self.optimizer](
            self.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )
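

# Minimal usage sketch; the hyperparameters and attribute names below are
# illustrative assumptions, not values prescribed by this module.
if __name__ == "__main__":
    model = SimilarityVectorModel(
        lr=1e-3,
        weight_decay=0.0,
        optimizer="Adam",
        batch_size=32,
        attrs=["name", "address", "phone", "email"],
        hidden_sizes=[16, 8],
    )
    sims = torch.rand(32, 4)          # one similarity score per attribute
    print(model.predict(sims).shape)  # torch.Size([32, 1]), values in [0, 1]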