File size: 2,741 Bytes
fbf7e95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import pytorch_lightning as pl
import torch
import torch.nn as nn
from torchmetrics import Accuracy


class SimilarityVectorModel(pl.LightningModule):
    """Feed-forward binary classifier over a fixed-width input vector.

    The network is a stack of ``nn.Linear`` layers with ``nn.ReLU`` between
    consecutive layers. The final layer has a single output unit and no
    activation, so ``forward`` returns raw logits; use ``predict`` to obtain
    probabilities.

    Args:
        lr: Learning rate handed to the optimizer.
        weight_decay: Weight decay handed to the optimizer.
        optimizer: Name of the optimizer to use; must be one of the keys
            listed in ``configure_optimizers``.
        batch_size: Stored as a hyperparameter; not used by this class itself.
        attrs: Sized collection whose length defines the input width of the
            first linear layer.
        hidden_sizes: Sequence of hidden layer widths (may be empty).
    """

    def __init__(self, lr, weight_decay, optimizer, batch_size, attrs, hidden_sizes):
        super().__init__()

        # Hyperparameters
        self.attrs = attrs
        self.lr = lr
        self.weight_decay = weight_decay
        self.optimizer = optimizer
        self.batch_size = batch_size
        self.save_hyperparameters()

        # Create model layers. Unpacking (instead of list concatenation)
        # accepts any sequence for hidden_sizes, e.g. a tuple.
        layer_sizes = [len(attrs), *hidden_sizes, 1]
        last_index = len(layer_sizes) - 2
        layers = []
        for index, (in_size, out_size) in enumerate(
            zip(layer_sizes, layer_sizes[1:])
        ):
            layers.append(nn.Linear(in_size, out_size))

            # No activation after the output layer: it must emit raw logits
            # for BCEWithLogitsLoss.
            if index < last_index:
                layers.append(nn.ReLU())

        self.layers = nn.Sequential(*layers)

        self.sigmoid = nn.Sigmoid()
        self.criterion = nn.BCEWithLogitsLoss()
        self.accuracy = Accuracy(task="binary")

    def forward(self, x):
        """Return the raw logits for ``x``."""
        return self.layers(x)

    def predict(self, x):
        """Return the sigmoid of the logits for ``x`` (values in [0, 1])."""
        return self.sigmoid(self(x))

    def _shared_step(self, batch, stage, prog_bar=False):
        """Compute loss and accuracy for one batch and log them under ``stage``.

        Args:
            batch: Pair ``(sim, label)``; ``label`` gets a trailing dimension
                added so it lines up with the single-unit model output.
            stage: Prefix for the logged metric names (``"train"``, ``"val"``
                or ``"test"``).
            prog_bar: Whether the accuracy is shown in the progress bar.

        Returns:
            The loss tensor for the batch.
        """
        sim, label = batch
        logits = self(sim.float())
        label = label.unsqueeze(1)

        # BCEWithLogitsLoss requires a floating-point target; casting here
        # makes integer/bool labels from the dataloader work too.
        loss = self.criterion(logits, label.float())

        # Feed probabilities to the metric. Given raw logits, the binary
        # accuracy only applies a sigmoid when some value lies outside
        # [0, 1]; a batch of small logits would otherwise be thresholded
        # at 0.5 instead of at logit 0.
        acc = self.accuracy(self.sigmoid(logits), label.long())

        self.log(f"{stage}_loss", loss, on_step=False, on_epoch=True)
        self.log(
            f"{stage}_acc", acc, on_step=False, on_epoch=True, prog_bar=prog_bar
        )

        return loss

    def training_step(self, batch, batch_idx):
        """Run one training batch; logs ``train_loss`` and ``train_acc``."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Run one validation batch; logs ``val_loss`` and ``val_acc``."""
        return self._shared_step(batch, "val", prog_bar=True)

    def test_step(self, batch, batch_idx):
        """Run one test batch; logs ``test_loss`` and ``test_acc``."""
        return self._shared_step(batch, "test", prog_bar=True)

    def configure_optimizers(self):
        """Instantiate the optimizer named by ``self.optimizer``.

        Returns:
            The optimizer over all model parameters, built with ``self.lr``
            and ``self.weight_decay``.

        Raises:
            KeyError: If ``self.optimizer`` is not a supported name.
        """
        optimizers = {
            "Adadelta": torch.optim.Adadelta,
            "Adagrad": torch.optim.Adagrad,
            "Adam": torch.optim.Adam,
            "RMSprop": torch.optim.RMSprop,
            "SGD": torch.optim.SGD,
        }
        try:
            optimizer_cls = optimizers[self.optimizer]
        except KeyError:
            # Same exception type as a plain dict lookup, but with a message
            # that tells the caller which names are valid.
            raise KeyError(
                f"Unknown optimizer {self.optimizer!r}; "
                f"expected one of {sorted(optimizers)}"
            ) from None
        return optimizer_cls(
            self.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )