import pytorch_lightning as pl
import torch
import torch.nn as nn
from torchmetrics import Accuracy


class SimilarityVectorModel(pl.LightningModule):
    """Binary classifier (MLP) over a vector of similarity features.

    The network maps an input with ``len(attrs)`` features through fully
    connected layers (ReLU between them) to a single logit. Training uses
    ``BCEWithLogitsLoss`` on that logit; :meth:`predict` returns the sigmoid
    probability.

    Args:
        lr: Learning rate passed to the optimizer.
        weight_decay: Weight decay passed to the optimizer.
        optimizer: Optimizer name; one of the keys of ``OPTIMIZERS``.
        batch_size: Stored as an attribute/hyperparameter only; it is not used
            by this module itself (presumably read by data-loading code —
            verify against caller).
        attrs: Input attributes; only ``len(attrs)`` (the input width) is
            used by the model.
        hidden_sizes: Widths of the hidden layers (any iterable of ints;
            empty gives a single linear layer).
    """

    # Optimizer names accepted by ``configure_optimizers``.
    OPTIMIZERS = {
        "Adadelta": torch.optim.Adadelta,
        "Adagrad": torch.optim.Adagrad,
        "Adam": torch.optim.Adam,
        "RMSprop": torch.optim.RMSprop,
        "SGD": torch.optim.SGD,
    }

    def __init__(self, lr, weight_decay, optimizer, batch_size, attrs, hidden_sizes):
        super().__init__()
        # Hyperparameters
        self.attrs = attrs
        self.lr = lr
        self.weight_decay = weight_decay
        self.optimizer = optimizer
        self.batch_size = batch_size
        self.save_hyperparameters()

        # Input width is the number of attributes; the output is one logit.
        # Unpacking accepts any iterable of hidden sizes, not just a list.
        layer_sizes = [len(attrs), *hidden_sizes, 1]
        last_hidden = len(layer_sizes) - 2
        layers = []
        for i, (in_size, out_size) in enumerate(zip(layer_sizes, layer_sizes[1:])):
            layers.append(nn.Linear(in_size, out_size))
            # No activation after the final layer: the loss expects raw logits.
            if i < last_hidden:
                layers.append(nn.ReLU())
        self.layers = nn.Sequential(*layers)

        self.sigmoid = nn.Sigmoid()
        self.criterion = nn.BCEWithLogitsLoss()
        # One metric instance per stage so accumulated states never mix
        # across train/val/test; Lightning computes and resets each one at
        # epoch end when the metric object itself is logged.
        self.train_accuracy = Accuracy(task="binary")
        self.val_accuracy = Accuracy(task="binary")
        self.test_accuracy = Accuracy(task="binary")

    def forward(self, x):
        """Return the raw logit(s) for input ``x``."""
        return self.layers(x)

    def predict(self, x):
        """Return the sigmoid probability for input ``x``."""
        # Same float cast as the step methods use before calling the model.
        return self.sigmoid(self(x.float()))

    def _shared_step(self, batch, stage, metric, prog_bar):
        """Compute the loss, update ``metric`` and log both for one batch.

        ``batch`` is unpacked as ``(sim, label)``; ``label`` is assumed to be
        shape (N,) with 0/1 values (TODO confirm), since it is unsqueezed to
        (N, 1) to match the logits.
        """
        sim, label = batch
        logits = self(sim.float())
        # BCEWithLogitsLoss requires a float target shaped like the logits.
        target = label.float().unsqueeze(1)
        loss = self.criterion(logits, target)
        # Feed probabilities, not logits: binary Accuracy only applies a
        # sigmoid itself when some value lies outside [0, 1], so a batch of
        # small logits would otherwise be thresholded incorrectly.
        metric(self.sigmoid(logits), target.long())
        n = target.size(0)
        self.log(f"{stage}_loss", loss, on_step=False, on_epoch=True, batch_size=n)
        self.log(
            f"{stage}_acc",
            metric,
            on_step=False,
            on_epoch=True,
            prog_bar=prog_bar,
            batch_size=n,
        )
        return loss

    def training_step(self, batch, batch_idx):
        """Run one training batch and return its loss."""
        return self._shared_step(batch, "train", self.train_accuracy, prog_bar=False)

    def validation_step(self, batch, batch_idx):
        """Run one validation batch and return its loss."""
        return self._shared_step(batch, "val", self.val_accuracy, prog_bar=True)

    def test_step(self, batch, batch_idx):
        """Run one test batch and return its loss."""
        return self._shared_step(batch, "test", self.test_accuracy, prog_bar=True)

    def configure_optimizers(self):
        """Build the optimizer named by ``self.optimizer``.

        Raises:
            KeyError: if ``self.optimizer`` is not a key of ``OPTIMIZERS``.
        """
        try:
            optimizer_cls = self.OPTIMIZERS[self.optimizer]
        except KeyError:
            raise KeyError(
                f"Unknown optimizer {self.optimizer!r}; "
                f"expected one of {sorted(self.OPTIMIZERS)}"
            ) from None
        return optimizer_cls(
            self.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )