marc-match-ai / marcai /pl /similarity_vector_model.py
RvanB's picture
Add files from other repo
fbf7e95
raw
history blame
2.74 kB
import pytorch_lightning as pl
import torch
import torch.nn as nn
from torchmetrics import Accuracy
class SimilarityVectorModel(pl.LightningModule):
    """MLP binary classifier over a vector of per-attribute similarity scores.

    The network maps ``len(attrs)`` input features through ReLU-activated
    linear layers of widths ``hidden_sizes`` to a single output logit.
    Training uses ``BCEWithLogitsLoss``, so ``forward`` returns raw logits;
    use ``predict`` for probabilities.

    Args:
        lr: Learning rate passed to the optimizer.
        weight_decay: Weight decay passed to the optimizer.
        optimizer: Optimizer name; one of "Adadelta", "Adagrad", "Adam",
            "RMSprop", "SGD".
        batch_size: Stored as a hyperparameter; not used inside this class
            (presumably read by the data-loading code — verify).
        attrs: Attribute collection; only its length (the number of input
            features) is used here.
        hidden_sizes: Widths of the hidden layers, in order. Any iterable of
            ints; an empty one yields a single linear layer.
    """

    def __init__(self, lr, weight_decay, optimizer, batch_size, attrs, hidden_sizes):
        super().__init__()

        # Hyperparameters
        self.attrs = attrs
        self.lr = lr
        self.weight_decay = weight_decay
        self.optimizer = optimizer
        self.batch_size = batch_size
        self.save_hyperparameters()

        # Linear -> ReLU -> ... -> Linear, ending in one logit. There is no
        # activation after the final layer because BCEWithLogitsLoss expects
        # raw logits.
        layer_sizes = [len(attrs), *hidden_sizes, 1]
        last_hidden = len(layer_sizes) - 2
        layers = []
        for i, (in_size, out_size) in enumerate(zip(layer_sizes, layer_sizes[1:])):
            layers.append(nn.Linear(in_size, out_size))
            if i < last_hidden:
                layers.append(nn.ReLU())
        self.layers = nn.Sequential(*layers)
        self.sigmoid = nn.Sigmoid()
        self.criterion = nn.BCEWithLogitsLoss()

        # One metric instance per stage so train/val/test counts never mix.
        # `accuracy` keeps its original attribute name and serves training.
        self.accuracy = Accuracy(task="binary")
        self.val_accuracy = Accuracy(task="binary")
        self.test_accuracy = Accuracy(task="binary")

    def forward(self, x):
        """Return raw logits (last dimension 1) for input features ``x``.

        The last dimension of ``x`` must equal ``len(attrs)``.
        """
        return self.layers(x)

    def predict(self, x):
        """Return probabilities: the sigmoid of ``forward(x)``."""
        return self.sigmoid(self(x))

    def _shared_step(self, batch, stage):
        """Compute the loss and update the stage's accuracy for one batch.

        Args:
            batch: A ``(sim, label)`` pair.
            stage: "train", "val" or "test"; selects the metric and the
                ``{stage}_loss`` / ``{stage}_acc`` log keys.

        Returns:
            The BCE-with-logits loss for the batch.
        """
        sim, label = batch
        logits = self(sim.float())
        # BCEWithLogitsLoss needs a target with the logits' size and a float
        # dtype. `label` is presumably a (batch,) vector of 0/1 values — TODO
        # confirm against the datamodule.
        target = label.reshape(logits.shape).to(logits.dtype)
        loss = self.criterion(logits, target)

        # Feed probabilities, not logits, to the metric. torchmetrics applies
        # a sigmoid only when some preds fall outside [0, 1], so a batch whose
        # logits all lie in [0, 1] would otherwise be thresholded at logit 0.5
        # instead of probability 0.5.
        metric = {
            "train": self.accuracy,
            "val": self.val_accuracy,
            "test": self.test_accuracy,
        }[stage]
        metric.update(torch.sigmoid(logits), target.long())

        # Logging the metric object lets Lightning compute and reset it at
        # epoch end. Only val/test accuracy goes to the progress bar.
        self.log(f"{stage}_loss", loss, on_step=False, on_epoch=True)
        self.log(
            f"{stage}_acc",
            metric,
            on_step=False,
            on_epoch=True,
            prog_bar=stage != "train",
        )
        return loss

    def training_step(self, batch, batch_idx):
        """Run one training batch and return the loss to backpropagate."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Run one validation batch and return its loss."""
        return self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        """Run one test batch and return its loss."""
        return self._shared_step(batch, "test")

    def configure_optimizers(self):
        """Build the optimizer named by ``self.optimizer`` over all parameters.

        Raises:
            KeyError: if ``self.optimizer`` is not a supported optimizer name.
        """
        optimizers = {
            "Adadelta": torch.optim.Adadelta,
            "Adagrad": torch.optim.Adagrad,
            "Adam": torch.optim.Adam,
            "RMSprop": torch.optim.RMSprop,
            "SGD": torch.optim.SGD,
        }
        try:
            optimizer_cls = optimizers[self.optimizer]
        except KeyError:
            raise KeyError(
                f"Unknown optimizer {self.optimizer!r}; "
                f"expected one of {sorted(optimizers)}"
            ) from None
        return optimizer_cls(
            self.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )