Delete enigma
Browse files- enigma/EnBERT.py +0 -206
- enigma/TrainEnigma.ipynb +0 -919
- enigma/config_enigma.json +0 -13
- enigma/enigma.cpp +0 -364
- enigma/generate.py +0 -126
- enigma/model.py +0 -388
- enigma/run.py +0 -100
enigma/EnBERT.py
DELETED
|
@@ -1,206 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
simple BERT architecture model, paired with one more layer of
|
| 3 |
-
masked self-attention, to predict next token
|
| 4 |
-
"""
|
| 5 |
-
|
| 6 |
-
import torch
|
| 7 |
-
import os
|
| 8 |
-
current_directory = os.path.dirname(os.path.abspath(__file__))
|
| 9 |
-
os.chdir(current_directory)
|
| 10 |
-
|
| 11 |
-
import torch.nn as nn
|
| 12 |
-
from torch.nn import functional as F
|
| 13 |
-
|
| 14 |
-
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 15 |
-
|
| 16 |
-
# hyperparams
|
| 17 |
-
batch_size = 8
|
| 18 |
-
block_size = 32
|
| 19 |
-
max_iters = 10
|
| 20 |
-
eval_interval = 10
|
| 21 |
-
learning_rate = 3e-4
|
| 22 |
-
eval_iters = 5
|
| 23 |
-
d_model = 256
|
| 24 |
-
n_layer = 16
|
| 25 |
-
n_head = 12
|
| 26 |
-
dropout = 0.2
|
| 27 |
-
norm_eps = 1e-5
|
| 28 |
-
|
| 29 |
-
class SWiGLU(nn.Module):
|
| 30 |
-
""" SWiGLU(x) = σ(x) ⊙ ReLU(x) + (1−σ(x)) ⊙ x """
|
| 31 |
-
|
| 32 |
-
def forward(self, x):
|
| 33 |
-
sigmoid_output = torch.sigmoid(x)
|
| 34 |
-
relu_output = F.relu(x)
|
| 35 |
-
out = sigmoid_output * relu_output + (1 - sigmoid_output) * x
|
| 36 |
-
|
| 37 |
-
return out
|
| 38 |
-
|
| 39 |
-
class UnMaskedHead(nn.Module):
|
| 40 |
-
""" single head of self attention """
|
| 41 |
-
def __init__(self, d_model, head_size, dropout):
|
| 42 |
-
super().__init__()
|
| 43 |
-
self.key = nn.Linear(d_model, head_size, bias=True)
|
| 44 |
-
self.query = nn.Linear(d_model, head_size, bias=True)
|
| 45 |
-
self.value = nn.Linear(d_model, head_size, bias=True)
|
| 46 |
-
self.dropout = nn.Dropout(dropout)
|
| 47 |
-
|
| 48 |
-
def forward(self, x):
|
| 49 |
-
B, T, C = x.shape
|
| 50 |
-
key = self.key(x)
|
| 51 |
-
query = self.query(x)
|
| 52 |
-
|
| 53 |
-
weights = query @ key.transpose(-2, -1) * key.shape[-1]**-0.5
|
| 54 |
-
weights = F.softmax(weights, dim=-1)
|
| 55 |
-
weights = self.dropout(weights)
|
| 56 |
-
|
| 57 |
-
value = self.value(x)
|
| 58 |
-
out = weights @ value
|
| 59 |
-
return out
|
| 60 |
-
|
| 61 |
-
class MaskedHead(nn.Module):
|
| 62 |
-
""" one head of self-attention """
|
| 63 |
-
def __init__(self, head_size, dropout, d_model):
|
| 64 |
-
super().__init__()
|
| 65 |
-
self.key = nn.Linear(d_model, head_size, bias=True)
|
| 66 |
-
self.query = nn.Linear(d_model, head_size, bias=True)
|
| 67 |
-
self.value = nn.Linear(d_model, head_size, bias=True)
|
| 68 |
-
self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
|
| 69 |
-
|
| 70 |
-
self.dropout = nn.Dropout(dropout)
|
| 71 |
-
|
| 72 |
-
def forward(self, x):
|
| 73 |
-
B,T,C = x.shape
|
| 74 |
-
k = self.key(x)
|
| 75 |
-
q = self.query(x)
|
| 76 |
-
|
| 77 |
-
wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)
|
| 78 |
-
wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
|
| 79 |
-
wei = F.softmax(wei, dim=-1) # (B, T, T)
|
| 80 |
-
wei = self.dropout(wei)
|
| 81 |
-
|
| 82 |
-
v = self.value(x)
|
| 83 |
-
out = wei @ v
|
| 84 |
-
return out
|
| 85 |
-
|
| 86 |
-
class MultiUnMasked(nn.Module):
|
| 87 |
-
def __init__(self, d_model, n_head, dropout):
|
| 88 |
-
head_size = d_model // n_head
|
| 89 |
-
super().__init__()
|
| 90 |
-
self.heads = nn.ModuleList([UnMaskedHead(d_model=d_model, dropout=dropout, head_size=head_size) for _ in range(n_head)])
|
| 91 |
-
self.proj = nn.Linear(n_head * head_size, d_model)
|
| 92 |
-
self.dropout = nn.Dropout(dropout)
|
| 93 |
-
|
| 94 |
-
def forward(self, x):
|
| 95 |
-
out = torch.cat([h(x) for h in self.heads], dim=-1)
|
| 96 |
-
out = self.dropout(self.proj(out))
|
| 97 |
-
return out
|
| 98 |
-
|
| 99 |
-
class MultiMasked(nn.Module):
|
| 100 |
-
def __init__(self, d_model, n_head, dropout):
|
| 101 |
-
head_size = d_model // n_head
|
| 102 |
-
super().__init__()
|
| 103 |
-
self.heads = nn.ModuleList([MaskedHead(d_model=d_model, dropout=dropout, head_size=head_size) for _ in range(n_head)])
|
| 104 |
-
self.proj = nn.Linear(n_head * head_size, d_model)
|
| 105 |
-
self.dropout = nn.Dropout(dropout)
|
| 106 |
-
|
| 107 |
-
def forward(self, x):
|
| 108 |
-
out = torch.cat([h(x) for h in self.heads], dim=-1)
|
| 109 |
-
out = self.dropout(self.proj(out))
|
| 110 |
-
return out
|
| 111 |
-
|
| 112 |
-
class FeedForward(nn.Module):
|
| 113 |
-
def __init__(self, d_model, dropout):
|
| 114 |
-
super().__init__()
|
| 115 |
-
self.net = nn.Sequential(
|
| 116 |
-
nn.Linear(d_model, 4*d_model),
|
| 117 |
-
nn.GELU(),
|
| 118 |
-
nn.Linear(4*d_model, d_model),
|
| 119 |
-
nn.Dropout(dropout)
|
| 120 |
-
)
|
| 121 |
-
|
| 122 |
-
def forward(self, x):
|
| 123 |
-
return self.net(x)
|
| 124 |
-
|
| 125 |
-
class Block(nn.Module):
|
| 126 |
-
def __init__(self, d_model, n_head, norm_eps, dropout):
|
| 127 |
-
super().__init__()
|
| 128 |
-
self.sa_masked = MultiMasked(n_head=n_head, d_model=d_model, dropout=dropout)
|
| 129 |
-
self.sa_unmasked = MultiUnMasked(n_head=n_head, d_model=d_model, dropout=dropout)
|
| 130 |
-
self.ffwd = FeedForward(d_model, dropout=dropout)
|
| 131 |
-
self.norm1 = nn.LayerNorm(d_model, eps=norm_eps)
|
| 132 |
-
self.norm2 = nn.LayerNorm(d_model, eps=norm_eps)
|
| 133 |
-
|
| 134 |
-
def forward(self, x):
|
| 135 |
-
x2 = x + self.sa_unmasked(self.norm1(x))
|
| 136 |
-
x = x2 + self.norm2(self.ffwd(x2))
|
| 137 |
-
|
| 138 |
-
x2 = x + self.sa_masked(self.norm1(x))
|
| 139 |
-
x = x2 + self.norm2(self.ffwd(x2))
|
| 140 |
-
return x
|
| 141 |
-
|
| 142 |
-
class EnigmaBERT(nn.Module):
|
| 143 |
-
def __init__(self, vocab_size):
|
| 144 |
-
super().__init__()
|
| 145 |
-
self.toked_model = nn.Embedding(vocab_size, d_model)
|
| 146 |
-
self.pos_encod = nn.Embedding(block_size, d_model)
|
| 147 |
-
self.block = nn.Sequential(*[Block(d_model=d_model, dropout=dropout, norm_eps=norm_eps, n_head=n_head) for _ in range(n_layer)])
|
| 148 |
-
self.norm_final = nn.LayerNorm(d_model, eps=norm_eps)
|
| 149 |
-
self.linear_final = nn.Linear(d_model, vocab_size)
|
| 150 |
-
self.apply(self._init_weights)
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
def _init_weights(self, module):
|
| 154 |
-
if isinstance(module, nn.Linear):
|
| 155 |
-
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 156 |
-
if module.bias is not None:
|
| 157 |
-
torch.nn.init.zeros_(module.bias.data)
|
| 158 |
-
elif isinstance(module, nn.Embedding):
|
| 159 |
-
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 160 |
-
|
| 161 |
-
def forward(self, idx, targets=None):
|
| 162 |
-
B, T = idx.shape
|
| 163 |
-
|
| 164 |
-
toked_model = self.toked_model(idx)
|
| 165 |
-
pos_encod = self.pos_encod(torch.arange(T, device=device))
|
| 166 |
-
x = toked_model + pos_encod
|
| 167 |
-
x = self.block(x)
|
| 168 |
-
x = self.norm_final(x)
|
| 169 |
-
logits = self.linear_final(x)
|
| 170 |
-
|
| 171 |
-
if targets is None:
|
| 172 |
-
loss = None
|
| 173 |
-
|
| 174 |
-
else:
|
| 175 |
-
B, T, C = logits.shape
|
| 176 |
-
logits = logits.view(B*T, C)
|
| 177 |
-
targets = targets.view(B*T)
|
| 178 |
-
loss = F.cross_entropy(logits, targets)
|
| 179 |
-
|
| 180 |
-
return logits, loss
|
| 181 |
-
|
| 182 |
-
def generate(self, idx, max_new_tokens, temperature=1.0, top_k=0):
|
| 183 |
-
generated_tokens = []
|
| 184 |
-
|
| 185 |
-
for _ in range(max_new_tokens):
|
| 186 |
-
idx_cond = idx[:, -block_size:]
|
| 187 |
-
logits, _ = self(idx_cond)
|
| 188 |
-
logits = logits[:, -1, :]
|
| 189 |
-
|
| 190 |
-
scaled_logits = logits / temperature
|
| 191 |
-
if top_k > 0:
|
| 192 |
-
scaled_logits = self._top_k_filtering(scaled_logits, top_k)
|
| 193 |
-
|
| 194 |
-
probs = F.softmax(scaled_logits, dim=-1)
|
| 195 |
-
sampled_idx = torch.multinomial(probs, num_samples=1)
|
| 196 |
-
generated_tokens.append(sampled_idx.item())
|
| 197 |
-
idx = torch.cat((idx, sampled_idx), dim=1)
|
| 198 |
-
|
| 199 |
-
return generated_tokens
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
def _top_k_filtering(self, logits, top_k):
|
| 203 |
-
values, indices = torch.topk(logits, top_k, dim=-1)
|
| 204 |
-
min_value = values[:, -1].unsqueeze(-1).expand_as(logits)
|
| 205 |
-
filtered_logits = torch.where(logits < min_value, torch.ones_like(logits) * -float('inf'), logits)
|
| 206 |
-
return filtered_logits
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
enigma/TrainEnigma.ipynb
DELETED
|
@@ -1,919 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"cells": [
|
| 3 |
-
{
|
| 4 |
-
"cell_type": "code",
|
| 5 |
-
"execution_count": 1,
|
| 6 |
-
"metadata": {
|
| 7 |
-
"colab": {
|
| 8 |
-
"base_uri": "https://localhost:8080/"
|
| 9 |
-
},
|
| 10 |
-
"id": "WXpJBLyr30Rx",
|
| 11 |
-
"outputId": "2806070a-648f-42ca-fa8a-9aeb8f99ceb7"
|
| 12 |
-
},
|
| 13 |
-
"outputs": [
|
| 14 |
-
{
|
| 15 |
-
"output_type": "stream",
|
| 16 |
-
"name": "stdout",
|
| 17 |
-
"text": [
|
| 18 |
-
"Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n"
|
| 19 |
-
]
|
| 20 |
-
}
|
| 21 |
-
],
|
| 22 |
-
"source": [
|
| 23 |
-
"from google.colab import drive\n",
|
| 24 |
-
"drive.mount('/content/drive')"
|
| 25 |
-
]
|
| 26 |
-
},
|
| 27 |
-
{
|
| 28 |
-
"cell_type": "code",
|
| 29 |
-
"execution_count": 2,
|
| 30 |
-
"metadata": {
|
| 31 |
-
"colab": {
|
| 32 |
-
"base_uri": "https://localhost:8080/"
|
| 33 |
-
},
|
| 34 |
-
"id": "r7WUm0VL4bN4",
|
| 35 |
-
"outputId": "bfdefb82-479e-4f91-9a01-299ff76756e9"
|
| 36 |
-
},
|
| 37 |
-
"outputs": [
|
| 38 |
-
{
|
| 39 |
-
"output_type": "stream",
|
| 40 |
-
"name": "stdout",
|
| 41 |
-
"text": [
|
| 42 |
-
"485.52 million letters\n"
|
| 43 |
-
]
|
| 44 |
-
}
|
| 45 |
-
],
|
| 46 |
-
"source": [
|
| 47 |
-
"import torch\n",
|
| 48 |
-
"\n",
|
| 49 |
-
"# importing the data\n",
|
| 50 |
-
"file_path = '/content/drive/MyDrive/train2.txt'\n",
|
| 51 |
-
"with open(file_path, 'r', encoding='utf-8') as file:\n",
|
| 52 |
-
" dna_seq = file.read()\n",
|
| 53 |
-
"file.close()\n",
|
| 54 |
-
"\n",
|
| 55 |
-
"print(f\"{(len(dna_seq)/1e6):.2f} million letters\")"
|
| 56 |
-
]
|
| 57 |
-
},
|
| 58 |
-
{
|
| 59 |
-
"cell_type": "code",
|
| 60 |
-
"execution_count": 3,
|
| 61 |
-
"metadata": {
|
| 62 |
-
"id": "Cdhybhz9owTK"
|
| 63 |
-
},
|
| 64 |
-
"outputs": [],
|
| 65 |
-
"source": [
|
| 66 |
-
"class PerCharTokenizer:\n",
|
| 67 |
-
" \"\"\"\n",
|
| 68 |
-
" Args:\n",
|
| 69 |
-
" - chars (list): all bases along with special tokens represented as characters\n",
|
| 70 |
-
" - vocab_size (int): size of vocabulary\n",
|
| 71 |
-
"\n",
|
| 72 |
-
" Working:\n",
|
| 73 |
-
" - vocab contains all the bases and ['P', 'M', 'U'] as padding, mask and unknown token\n",
|
| 74 |
-
" - encode(): iterates over each character a time and the looks up for the position in vocab\n",
|
| 75 |
-
" and returns it's position as integer\n",
|
| 76 |
-
" - decode(): takes input of a list of integers and returns the specific item from vocab\n",
|
| 77 |
-
" \"\"\"\n",
|
| 78 |
-
" def __init__(self):\n",
|
| 79 |
-
" super().__init__()\n",
|
| 80 |
-
" self.chars = ['\\n', 'A', 'T', 'G', 'C', 'P', 'M', 'U', ' ']\n",
|
| 81 |
-
" self.vocab_size = len(self.chars)\n",
|
| 82 |
-
" self.string_to_index = {ch: i for i, ch in enumerate(self.chars)}\n",
|
| 83 |
-
" self.index_to_string = {i: ch for i, ch in enumerate(self.chars)}\n",
|
| 84 |
-
"\n",
|
| 85 |
-
" def encode(self, string):\n",
|
| 86 |
-
" encoded = []\n",
|
| 87 |
-
" for char in string:\n",
|
| 88 |
-
" if char in self.string_to_index:\n",
|
| 89 |
-
" encoded.append(self.string_to_index[char])\n",
|
| 90 |
-
" else:\n",
|
| 91 |
-
" special_index = len(self.string_to_index)\n",
|
| 92 |
-
" self.string_to_index[char] = special_index\n",
|
| 93 |
-
" self.index_to_string[special_index] = char\n",
|
| 94 |
-
" encoded.append(special_index)\n",
|
| 95 |
-
" return encoded\n",
|
| 96 |
-
"\n",
|
| 97 |
-
" def decode(self, integer):\n",
|
| 98 |
-
" decoded = []\n",
|
| 99 |
-
" for i in integer:\n",
|
| 100 |
-
" if i in self.index_to_string:\n",
|
| 101 |
-
" decoded.append(self.index_to_string[i])\n",
|
| 102 |
-
" else:\n",
|
| 103 |
-
" continue\n",
|
| 104 |
-
" return ''.join(decoded)"
|
| 105 |
-
]
|
| 106 |
-
},
|
| 107 |
-
{
|
| 108 |
-
"cell_type": "code",
|
| 109 |
-
"execution_count": 4,
|
| 110 |
-
"metadata": {
|
| 111 |
-
"colab": {
|
| 112 |
-
"base_uri": "https://localhost:8080/"
|
| 113 |
-
},
|
| 114 |
-
"id": "6Ou9txgmAdIB",
|
| 115 |
-
"outputId": "cb5dd462-8b2a-445a-9524-1b484f288c64"
|
| 116 |
-
},
|
| 117 |
-
"outputs": [
|
| 118 |
-
{
|
| 119 |
-
"output_type": "stream",
|
| 120 |
-
"name": "stdout",
|
| 121 |
-
"text": [
|
| 122 |
-
"train data 436.97million, val data 48.55million\n"
|
| 123 |
-
]
|
| 124 |
-
}
|
| 125 |
-
],
|
| 126 |
-
"source": [
|
| 127 |
-
"token = PerCharTokenizer()\n",
|
| 128 |
-
"data = torch.tensor(token.encode(dna_seq), dtype=torch.long)\n",
|
| 129 |
-
"\n",
|
| 130 |
-
"# Train and test splits\n",
|
| 131 |
-
"n = int(0.9*len(data)) # first 90% will be train, rest val\n",
|
| 132 |
-
"train_data = data[:n]\n",
|
| 133 |
-
"val_data = data[n:]\n",
|
| 134 |
-
"print(f\"train data {(len(train_data)/1e6):.2f}million, val data {(len(val_data)/1e6):.2f}million\")"
|
| 135 |
-
]
|
| 136 |
-
},
|
| 137 |
-
{
|
| 138 |
-
"cell_type": "code",
|
| 139 |
-
"execution_count": 5,
|
| 140 |
-
"metadata": {
|
| 141 |
-
"id": "ebFKQQ9NAq4e"
|
| 142 |
-
},
|
| 143 |
-
"outputs": [],
|
| 144 |
-
"source": [
|
| 145 |
-
"# hyperparams\n",
|
| 146 |
-
"batch_size = 10\n",
|
| 147 |
-
"block_size = 512\n",
|
| 148 |
-
"max_iters = 5000\n",
|
| 149 |
-
"eval_interval = 100\n",
|
| 150 |
-
"learning_rate = 3e-4\n",
|
| 151 |
-
"eval_iters = 100\n",
|
| 152 |
-
"d_model = 384\n",
|
| 153 |
-
"n_layers = 12\n",
|
| 154 |
-
"n_head = 12\n",
|
| 155 |
-
"dropout = 0.25\n",
|
| 156 |
-
"norm_eps = 1e-4"
|
| 157 |
-
]
|
| 158 |
-
},
|
| 159 |
-
{
|
| 160 |
-
"cell_type": "code",
|
| 161 |
-
"execution_count": 6,
|
| 162 |
-
"metadata": {
|
| 163 |
-
"id": "dZMiYkr37cmU"
|
| 164 |
-
},
|
| 165 |
-
"outputs": [],
|
| 166 |
-
"source": [
|
| 167 |
-
"import math\n",
|
| 168 |
-
"import torch.nn as nn\n",
|
| 169 |
-
"from torch.nn import functional as F\n",
|
| 170 |
-
"\n",
|
| 171 |
-
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
|
| 172 |
-
"\n",
|
| 173 |
-
"class AttentionHead(nn.Module):\n",
|
| 174 |
-
" \"\"\"\n",
|
| 175 |
-
" initialize a single head of self attention.\n",
|
| 176 |
-
"\n",
|
| 177 |
-
" Args:\n",
|
| 178 |
-
" - d_model (int): dimensionality of the model's hidden layers\n",
|
| 179 |
-
" - head_size (int): dimensionality of each attention head\n",
|
| 180 |
-
" - dropout (float): dropout probability\n",
|
| 181 |
-
" - block_size (int): the maximum sequence length for positional encoding\n",
|
| 182 |
-
" \"\"\"\n",
|
| 183 |
-
" def __init__(self, d_model, head_size, dropout, block_size):\n",
|
| 184 |
-
" super().__init__()\n",
|
| 185 |
-
" self.key = nn.Linear(d_model, head_size, bias=True)\n",
|
| 186 |
-
" self.query = nn.Linear(d_model, head_size, bias=True)\n",
|
| 187 |
-
" self.value = nn.Linear(d_model, head_size, bias=False)\n",
|
| 188 |
-
" self.dropout = nn.Dropout(dropout)\n",
|
| 189 |
-
" self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))\n",
|
| 190 |
-
"\n",
|
| 191 |
-
" self.rel_pos_emb = nn.Parameter(torch.randn(block_size, block_size, head_size))\n",
|
| 192 |
-
"\n",
|
| 193 |
-
" def forward(self, x, mask=False):\n",
|
| 194 |
-
" \"\"\"\n",
|
| 195 |
-
" forward pass of a single attention head.\n",
|
| 196 |
-
"\n",
|
| 197 |
-
" Args:\n",
|
| 198 |
-
" - x (Tensor): input tensor.\n",
|
| 199 |
-
" - mask (bool): flag indicating whether to apply masking\n",
|
| 200 |
-
" Returns:\n",
|
| 201 |
-
" - out (Tensor): output tensor after self attention\n",
|
| 202 |
-
" \"\"\"\n",
|
| 203 |
-
" B, T, C = x.shape\n",
|
| 204 |
-
" key = self.key(x)\n",
|
| 205 |
-
" query = self.query(x)\n",
|
| 206 |
-
" scores = torch.matmul(query, key.transpose(-2, -1)) / (key.shape[-1] ** -0.5)\n",
|
| 207 |
-
"\n",
|
| 208 |
-
" rel_pos_scores = torch.einsum('btc,tvc->btv', query, self.rel_pos_emb[:T, :T])\n",
|
| 209 |
-
" scores += rel_pos_scores\n",
|
| 210 |
-
"\n",
|
| 211 |
-
" if mask:\n",
|
| 212 |
-
" scores = scores.masked_fill(self.tril[:T, :T] == 0, float('-inf'))\n",
|
| 213 |
-
" weights = F.softmax(scores, dim=-1)\n",
|
| 214 |
-
" weights = self.dropout(weights)\n",
|
| 215 |
-
"\n",
|
| 216 |
-
" value = self.value(x)\n",
|
| 217 |
-
" out = torch.matmul(weights, value)\n",
|
| 218 |
-
" return out\n",
|
| 219 |
-
"\n",
|
| 220 |
-
"class MultiHeadAttention(nn.Module):\n",
|
| 221 |
-
" \"\"\"\n",
|
| 222 |
-
" initialize a multi-head attention module.\n",
|
| 223 |
-
"\n",
|
| 224 |
-
" Args:\n",
|
| 225 |
-
" - d_model (int): dimensionality of the model's hidden layers\n",
|
| 226 |
-
" - n_head (int): no of attention heads\n",
|
| 227 |
-
" - dropout (float): dropout probability\n",
|
| 228 |
-
" - block_size (int): context length\n",
|
| 229 |
-
" \"\"\"\n",
|
| 230 |
-
" def __init__(self, d_model, n_head, dropout, block_size):\n",
|
| 231 |
-
" head_size = d_model // n_head\n",
|
| 232 |
-
" super().__init__()\n",
|
| 233 |
-
" self.heads = nn.ModuleList([AttentionHead(d_model=d_model, dropout=dropout, head_size=head_size, block_size=block_size) for _ in range(n_head)])\n",
|
| 234 |
-
" self.proj = nn.Linear(n_head * head_size, d_model)\n",
|
| 235 |
-
" self.dropout = nn.Dropout(dropout)\n",
|
| 236 |
-
"\n",
|
| 237 |
-
" def forward(self, x, mask):\n",
|
| 238 |
-
" \"\"\"\n",
|
| 239 |
-
" forward pass of the multi-head attention module\n",
|
| 240 |
-
"\n",
|
| 241 |
-
" Args:\n",
|
| 242 |
-
" - x (Tensor): input tensor\n",
|
| 243 |
-
" - mask (bool): flag indicating whether to apply masking\n",
|
| 244 |
-
"\n",
|
| 245 |
-
" Returns:\n",
|
| 246 |
-
" - out (Tensor): output tensor after multi-head attention\n",
|
| 247 |
-
"\n",
|
| 248 |
-
" \"\"\"\n",
|
| 249 |
-
" out = torch.cat([h(x, mask=mask) for h in self.heads], dim=-1)\n",
|
| 250 |
-
" out = self.dropout(self.proj(out))\n",
|
| 251 |
-
" return out\n",
|
| 252 |
-
"\n",
|
| 253 |
-
"class FeedForward(nn.Module):\n",
|
| 254 |
-
" \"\"\"\n",
|
| 255 |
-
" initialize a feedforward network module\n",
|
| 256 |
-
"\n",
|
| 257 |
-
" Args:\n",
|
| 258 |
-
" - d_model (int): the dimensionality of the model's hidden layers\n",
|
| 259 |
-
" - dropout (float): dropout probability\n",
|
| 260 |
-
"\n",
|
| 261 |
-
" \"\"\"\n",
|
| 262 |
-
" def __init__(self, d_model, dropout):\n",
|
| 263 |
-
" super().__init__()\n",
|
| 264 |
-
" self.net = nn.Sequential(\n",
|
| 265 |
-
" nn.Linear(d_model, 5*d_model),\n",
|
| 266 |
-
" nn.GELU(),\n",
|
| 267 |
-
" nn.Linear(5*d_model, d_model),\n",
|
| 268 |
-
" nn.Dropout(dropout)\n",
|
| 269 |
-
" )\n",
|
| 270 |
-
"\n",
|
| 271 |
-
" def forward(self, x):\n",
|
| 272 |
-
" \"\"\"\n",
|
| 273 |
-
" forward pass of the feedforward network module\n",
|
| 274 |
-
"\n",
|
| 275 |
-
" Args:\n",
|
| 276 |
-
" - x (Tensor): input tensor\n",
|
| 277 |
-
"\n",
|
| 278 |
-
" Returns:\n",
|
| 279 |
-
" - out (Tensor): output tensor after passing through the feedforward network\n",
|
| 280 |
-
" \"\"\"\n",
|
| 281 |
-
" return self.net(x)\n",
|
| 282 |
-
"\n",
|
| 283 |
-
"class EncoderNetwork(nn.Module):\n",
|
| 284 |
-
" \"\"\"\n",
|
| 285 |
-
" initialize an encoder network module\n",
|
| 286 |
-
"\n",
|
| 287 |
-
" Args:\n",
|
| 288 |
-
" - d_model (int): dimensionality of the model's hidden layers\n",
|
| 289 |
-
" - n_head (int): no of attention heads in multi-head attention layers\n",
|
| 290 |
-
" - norm_eps (float): epsilon value for layer normalization\n",
|
| 291 |
-
" - dropout (float): dropout probability\n",
|
| 292 |
-
" - block_size (int): the maximum sequence length for positional encoding\n",
|
| 293 |
-
" \"\"\"\n",
|
| 294 |
-
" def __init__(self, d_model, n_head, norm_eps, dropout, block_size):\n",
|
| 295 |
-
" super().__init__()\n",
|
| 296 |
-
" self.s_att = MultiHeadAttention(n_head=n_head, d_model=d_model, dropout=dropout, block_size=block_size)\n",
|
| 297 |
-
" self.ffwd = FeedForward(d_model, dropout)\n",
|
| 298 |
-
" self.dropout = nn.Dropout(dropout)\n",
|
| 299 |
-
" self.norm1 = nn.LayerNorm(d_model, eps=norm_eps)\n",
|
| 300 |
-
" self.norm2 = nn.LayerNorm(d_model, eps=norm_eps)\n",
|
| 301 |
-
"\n",
|
| 302 |
-
" def forward(self, src):\n",
|
| 303 |
-
" \"\"\"\n",
|
| 304 |
-
" forward pass of the encoder network module.\n",
|
| 305 |
-
"\n",
|
| 306 |
-
" Args:\n",
|
| 307 |
-
" - src (Tensor): input tensor representing source data\n",
|
| 308 |
-
"\n",
|
| 309 |
-
" Returns:\n",
|
| 310 |
-
" - src (Tensor): output tensor after passing through the encoder network\n",
|
| 311 |
-
" \"\"\"\n",
|
| 312 |
-
" src2 = self.s_att(src, mask=False)\n",
|
| 313 |
-
" src = src + self.dropout(src2)\n",
|
| 314 |
-
" src = self.norm1(src)\n",
|
| 315 |
-
"\n",
|
| 316 |
-
" src2 = self.ffwd(src)\n",
|
| 317 |
-
" src = src + self.dropout(src2)\n",
|
| 318 |
-
" src = self.norm2(src)\n",
|
| 319 |
-
"\n",
|
| 320 |
-
" return src\n",
|
| 321 |
-
"\n",
|
| 322 |
-
"class DecoderNetwork(nn.Module):\n",
|
| 323 |
-
" \"\"\"\n",
|
| 324 |
-
" initialize a decoder network module\n",
|
| 325 |
-
"\n",
|
| 326 |
-
" Args:\n",
|
| 327 |
-
" - d_model (int): dimensionality of the model's hidden layers\n",
|
| 328 |
-
" - n_head (int): no of attention heads in multi-head attention layers\n",
|
| 329 |
-
" - norm_eps (float): epsilon value for layer normalization\n",
|
| 330 |
-
" - dropout (float): dropout probability\n",
|
| 331 |
-
" - block_size (int): the maximum sequence length for positional encoding\n",
|
| 332 |
-
" \"\"\"\n",
|
| 333 |
-
" def __init__(self, d_model, n_head, norm_eps, dropout, block_size):\n",
|
| 334 |
-
" super().__init__()\n",
|
| 335 |
-
" self.s_att = MultiHeadAttention(n_head=n_head, d_model=d_model, dropout=dropout, block_size=block_size)\n",
|
| 336 |
-
" self.ffwd = FeedForward(d_model, dropout)\n",
|
| 337 |
-
" self.dropout = nn.Dropout(dropout)\n",
|
| 338 |
-
" self.norm1 = nn.LayerNorm(d_model, eps=norm_eps)\n",
|
| 339 |
-
" self.norm2 = nn.LayerNorm(d_model, eps=norm_eps)\n",
|
| 340 |
-
"\n",
|
| 341 |
-
" def forward(self, src, att):\n",
|
| 342 |
-
" \"\"\"\n",
|
| 343 |
-
" forward pass of the decoder network module.\n",
|
| 344 |
-
"\n",
|
| 345 |
-
" Args:\n",
|
| 346 |
-
" - src (Tensor): input tensor, same as the encoder's inputs\n",
|
| 347 |
-
" - trg (Tensor): encoder's attention matrix\n",
|
| 348 |
-
"\n",
|
| 349 |
-
" Returns:\n",
|
| 350 |
-
" - src_f (Tensor): final output tensor\n",
|
| 351 |
-
" \"\"\"\n",
|
| 352 |
-
" src2 = self.s_att(src, mask=True)\n",
|
| 353 |
-
" src = src + self.dropout(src2)\n",
|
| 354 |
-
" src = src + self.norm1(src)\n",
|
| 355 |
-
"\n",
|
| 356 |
-
" att = src + att\n",
|
| 357 |
-
" att2 = self.s_att(att, mask=False)\n",
|
| 358 |
-
" att2 = att + self.dropout(att2)\n",
|
| 359 |
-
" trg = att2 + self.norm1(att2)\n",
|
| 360 |
-
"\n",
|
| 361 |
-
" src_f2 = self.ffwd(self.norm2(trg))\n",
|
| 362 |
-
" src_f = src_f2 + self.dropout(src_f2)\n",
|
| 363 |
-
" src_f = self.norm2(src_f)\n",
|
| 364 |
-
"\n",
|
| 365 |
-
" return src_f\n",
|
| 366 |
-
"\n",
|
| 367 |
-
"class Transformer(nn.Module):\n",
|
| 368 |
-
" \"\"\"\n",
|
| 369 |
-
" initialize a Transformer model\n",
|
| 370 |
-
"\n",
|
| 371 |
-
" Args:\n",
|
| 372 |
-
" - vocab_size (int): size of the vocabulary\n",
|
| 373 |
-
" - d_model (int): dimensionality of the model's hidden layers\n",
|
| 374 |
-
" - block_size (int): maximum sequence length for positional encoding/context length\n",
|
| 375 |
-
" - n_layers (int): number of encoder and decoder layers in the Transformer\n",
|
| 376 |
-
" - n_head (int): number of attention heads in multi-head attention layers\n",
|
| 377 |
-
" - norm_eps (float): epsilon value for layer normalization\n",
|
| 378 |
-
" - dropout (float): dropout probability\n",
|
| 379 |
-
" \"\"\"\n",
|
| 380 |
-
" def __init__(self, vocab_size):\n",
|
| 381 |
-
" super().__init__()\n",
|
| 382 |
-
" self.block_size = block_size\n",
|
| 383 |
-
" self.toked_model = nn.Embedding(vocab_size, d_model)\n",
|
| 384 |
-
" self.pos_encod = nn.Embedding(block_size, d_model)\n",
|
| 385 |
-
" self.enc_layer = nn.ModuleList([EncoderNetwork(n_head=n_head, norm_eps=norm_eps, block_size=block_size, dropout=dropout, d_model=d_model) for _ in range(n_layers)])\n",
|
| 386 |
-
" self.dec_layer = nn.ModuleList([DecoderNetwork(n_head=n_head, norm_eps=norm_eps, block_size=block_size, dropout=dropout, d_model=d_model) for _ in range(n_layers)])\n",
|
| 387 |
-
"\n",
|
| 388 |
-
" self.norm_final = nn.LayerNorm(d_model)\n",
|
| 389 |
-
" self.linear_final = nn.Linear(d_model, vocab_size)\n",
|
| 390 |
-
" self.dropout = nn.Dropout(dropout)\n",
|
| 391 |
-
" self.apply(self._init_weights)\n",
|
| 392 |
-
"\n",
|
| 393 |
-
" def _init_weights(self, module):\n",
|
| 394 |
-
" \"\"\"\n",
|
| 395 |
-
" initialize weights of linear and embedding layers\n",
|
| 396 |
-
"\n",
|
| 397 |
-
" Args:\n",
|
| 398 |
-
" - module (nn.Module): the module to initialize weights for\n",
|
| 399 |
-
" \"\"\"\n",
|
| 400 |
-
" if isinstance(module, nn.Linear):\n",
|
| 401 |
-
" torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n",
|
| 402 |
-
" if module.bias is not None:\n",
|
| 403 |
-
" torch.nn.init.zeros_(module.bias.data)\n",
|
| 404 |
-
" elif isinstance(module, nn.Embedding):\n",
|
| 405 |
-
" torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n",
|
| 406 |
-
"\n",
|
| 407 |
-
" def forward(self, idx, targets=None):\n",
|
| 408 |
-
" \"\"\"\n",
|
| 409 |
-
" forward pass of the transformer model\n",
|
| 410 |
-
"\n",
|
| 411 |
-
" Args:\n",
|
| 412 |
-
" - idx (Tensor): input tensor representing token indices\n",
|
| 413 |
-
" - targets (Tensor): target tensor for computing loss during training\n",
|
| 414 |
-
"\n",
|
| 415 |
-
" Returns:\n",
|
| 416 |
-
" - logits (Tensor): output logits from the final linear layer\n",
|
| 417 |
-
" - loss (Tensor): optional. computed cross-entropy loss if targets are provided, else None\n",
|
| 418 |
-
" \"\"\"\n",
|
| 419 |
-
" B, T = idx.shape\n",
|
| 420 |
-
"\n",
|
| 421 |
-
" toked_model = self.toked_model(idx)\n",
|
| 422 |
-
" pos_encod = self.pos_encod(torch.arange(T, device=device))\n",
|
| 423 |
-
" x = toked_model + pos_encod\n",
|
| 424 |
-
"\n",
|
| 425 |
-
" for layer in self.enc_layer:\n",
|
| 426 |
-
" x_out = layer(x)\n",
|
| 427 |
-
"\n",
|
| 428 |
-
" for layer in self.dec_layer:\n",
|
| 429 |
-
" x_final = layer(x, x_out)\n",
|
| 430 |
-
"\n",
|
| 431 |
-
" x_final = self.norm_final(x_final)\n",
|
| 432 |
-
" logits = self.linear_final(x_final)\n",
|
| 433 |
-
"\n",
|
| 434 |
-
" if targets is None:\n",
|
| 435 |
-
" loss = None\n",
|
| 436 |
-
"\n",
|
| 437 |
-
" else:\n",
|
| 438 |
-
" B, T, C = logits.shape\n",
|
| 439 |
-
" logits = logits.view(B*T, C)\n",
|
| 440 |
-
" targets = targets.view(B*T)\n",
|
| 441 |
-
" loss = F.cross_entropy(logits, targets)\n",
|
| 442 |
-
"\n",
|
| 443 |
-
" return logits, loss\n",
|
| 444 |
-
" def generate(self, idx, max_new_tokens, temperature=1.0, top_k=0):\n",
|
| 445 |
-
" \"\"\"\n",
|
| 446 |
-
" generate new tokens using the trained model\n",
|
| 447 |
-
"\n",
|
| 448 |
-
" Args:\n",
|
| 449 |
-
" - idx (Tensor): input tensor representing initial token indices\n",
|
| 450 |
-
" - max_new_tokens (int): max no of new tokens to generate\n",
|
| 451 |
-
" - temperature (float): softmax temperature for sampling\n",
|
| 452 |
-
" - top_k (int): no of top tokens to consider in sampling\n",
|
| 453 |
-
"\n",
|
| 454 |
-
" Returns:\n",
|
| 455 |
-
" - generated_tokens (list): list of generated token indices\n",
|
| 456 |
-
" \"\"\"\n",
|
| 457 |
-
" generated_tokens = []\n",
|
| 458 |
-
"\n",
|
| 459 |
-
" for _ in range(max_new_tokens):\n",
|
| 460 |
-
" idx_cond = idx[:, -self.block_size:]\n",
|
| 461 |
-
" logits, _ = self(idx_cond)\n",
|
| 462 |
-
" logits = logits[:, -1, :]\n",
|
| 463 |
-
"\n",
|
| 464 |
-
" scaled_logits = logits / temperature\n",
|
| 465 |
-
" if top_k > 0:\n",
|
| 466 |
-
" scaled_logits = self._top_k_filtering(scaled_logits, top_k)\n",
|
| 467 |
-
"\n",
|
| 468 |
-
" probs = F.softmax(scaled_logits, dim=-1)\n",
|
| 469 |
-
" sampled_idx = torch.multinomial(probs, num_samples=1)\n",
|
| 470 |
-
" generated_tokens.append(sampled_idx.item())\n",
|
| 471 |
-
" idx = torch.cat((idx, sampled_idx), dim=1)\n",
|
| 472 |
-
"\n",
|
| 473 |
-
" return generated_tokens\n",
|
| 474 |
-
"\n",
|
| 475 |
-
" def generate_masked_tokens(self, idx, masked_indices, temperature=1.0, top_k=0):\n",
|
| 476 |
-
" \"\"\"\n",
|
| 477 |
-
" Generate predictions for masked tokens using the trained model.\n",
|
| 478 |
-
"\n",
|
| 479 |
-
" Args:\n",
|
| 480 |
-
" - idx (Tensor): input tensor representing token indices\n",
|
| 481 |
-
" - masked_indices (Tensor): tensor of indices indicating masked positions\n",
|
| 482 |
-
" - temperature (float): softmax temperature for sampling\n",
|
| 483 |
-
" - top_k (int): no of top tokens to consider in sampling\n",
|
| 484 |
-
"\n",
|
| 485 |
-
" Returns:\n",
|
| 486 |
-
" - predicted_tokens (Tensor): tensor of predicted token indices\n",
|
| 487 |
-
" \"\"\"\n",
|
| 488 |
-
" B, T = idx.shape\n",
|
| 489 |
-
"\n",
|
| 490 |
-
" toked_model = self.toked_model(idx)\n",
|
| 491 |
-
" pos_encod = self.pos_encod(torch.arange(T, device=device))\n",
|
| 492 |
-
" x = toked_model + pos_encod\n",
|
| 493 |
-
"\n",
|
| 494 |
-
" for layer in self.enc_layer:\n",
|
| 495 |
-
" x_out = layer(x)\n",
|
| 496 |
-
"\n",
|
| 497 |
-
" for layer in self.dec_layer:\n",
|
| 498 |
-
" x_final = layer(x, x_out)\n",
|
| 499 |
-
"\n",
|
| 500 |
-
" x_masked = x_final.clone()\n",
|
| 501 |
-
" x_masked[masked_indices] = self.toked_model(torch.tensor([6], device=device))\n",
|
| 502 |
-
"\n",
|
| 503 |
-
" x_masked = self.norm_final(x_masked)\n",
|
| 504 |
-
" logits = self.linear_final(x_masked)\n",
|
| 505 |
-
"\n",
|
| 506 |
-
" masked_logits = logits[masked_indices].view(-1, logits.size(-1))\n",
|
| 507 |
-
" scaled_logits = masked_logits / temperature\n",
|
| 508 |
-
" if top_k > 0:\n",
|
| 509 |
-
" scaled_logits = self._top_k_filtering(scaled_logits, top_k)\n",
|
| 510 |
-
"\n",
|
| 511 |
-
" probs = F.softmax(scaled_logits, dim=-1)\n",
|
| 512 |
-
" predicted_indices = torch.argmax(probs, dim=-1)\n",
|
| 513 |
-
"\n",
|
| 514 |
-
" return predicted_indices\n",
|
| 515 |
-
"\n",
|
| 516 |
-
" def _top_k_filtering(self, logits, top_k):\n",
|
| 517 |
-
" \"\"\"\n",
|
| 518 |
-
" filter logits to keep only the top-k tokens\n",
|
| 519 |
-
"\n",
|
| 520 |
-
" Args:\n",
|
| 521 |
-
" - logits (Tensor): input tensor representing unscaled logits\n",
|
| 522 |
-
" - top_k (int): no of top tokens to keep\n",
|
| 523 |
-
"\n",
|
| 524 |
-
" Returns:\n",
|
| 525 |
-
" - filtered_logits (Tensor): filtered logits with only top-k tokens remaining\n",
|
| 526 |
-
" \"\"\"\n",
|
| 527 |
-
" values, indices = torch.topk(logits, top_k, dim=-1)\n",
|
| 528 |
-
" min_value = values[:, -1].unsqueeze(-1).expand_as(logits)\n",
|
| 529 |
-
" filtered_logits = torch.where(logits < min_value, torch.ones_like(logits) * -float('inf'), logits)\n",
|
| 530 |
-
"\n",
|
| 531 |
-
" return filtered_logits"
|
| 532 |
-
]
|
| 533 |
-
},
|
| 534 |
-
{
|
| 535 |
-
"cell_type": "code",
|
| 536 |
-
"execution_count": 7,
|
| 537 |
-
"metadata": {
|
| 538 |
-
"colab": {
|
| 539 |
-
"base_uri": "https://localhost:8080/",
|
| 540 |
-
"height": 816
|
| 541 |
-
},
|
| 542 |
-
"id": "X9VOBZFr7g3W",
|
| 543 |
-
"outputId": "aa376025-0a37-4b93-e90a-9d95c6ef2c11"
|
| 544 |
-
},
|
| 545 |
-
"outputs": [
|
| 546 |
-
{
|
| 547 |
-
"output_type": "stream",
|
| 548 |
-
"name": "stdout",
|
| 549 |
-
"text": [
|
| 550 |
-
"2.5 billion parameters\n",
|
| 551 |
-
"step 0: train loss 2.2869, val loss 2.2884\n",
|
| 552 |
-
"step 100: train loss 1.3312, val loss 1.3281\n",
|
| 553 |
-
"step 200: train loss 1.3233, val loss 1.3181\n",
|
| 554 |
-
"step 300: train loss 1.3209, val loss 1.3196\n",
|
| 555 |
-
"step 400: train loss 1.3215, val loss 1.3203\n",
|
| 556 |
-
"step 500: train loss 1.1974, val loss 1.1994\n",
|
| 557 |
-
"step 600: train loss 0.3350, val loss 0.3365\n",
|
| 558 |
-
"step 700: train loss 0.0703, val loss 0.0702\n",
|
| 559 |
-
"step 800: train loss 0.0143, val loss 0.0143\n",
|
| 560 |
-
"step 900: train loss 0.0049, val loss 0.0047\n",
|
| 561 |
-
"step 1000: train loss 0.0041, val loss 0.0037\n",
|
| 562 |
-
"step 1100: train loss 0.0035, val loss 0.0036\n",
|
| 563 |
-
"step 1200: train loss 0.0038, val loss 0.0035\n",
|
| 564 |
-
"step 1300: train loss 0.0035, val loss 0.0033\n",
|
| 565 |
-
"step 1400: train loss 0.0035, val loss 0.0033\n",
|
| 566 |
-
"step 1500: train loss 0.0033, val loss 0.0033\n",
|
| 567 |
-
"step 1600: train loss 0.0033, val loss 0.0034\n",
|
| 568 |
-
"step 1700: train loss 0.0033, val loss 0.0033\n",
|
| 569 |
-
"step 1800: train loss 0.0033, val loss 0.0031\n",
|
| 570 |
-
"step 1900: train loss 0.0031, val loss 0.0031\n",
|
| 571 |
-
"step 2000: train loss 0.0032, val loss 0.0032\n"
|
| 572 |
-
]
|
| 573 |
-
},
|
| 574 |
-
{
|
| 575 |
-
"output_type": "error",
|
| 576 |
-
"ename": "KeyboardInterrupt",
|
| 577 |
-
"evalue": "",
|
| 578 |
-
"traceback": [
|
| 579 |
-
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
| 580 |
-
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
| 581 |
-
"\u001b[0;32m<ipython-input-7-44818790f2dc>\u001b[0m in \u001b[0;36m<cell line: 45>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[0mxb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0myb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'train'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 56\u001b[0;31m \u001b[0mlogits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mxb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0myb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 57\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mset_to_none\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
| 582 |
-
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1510\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1511\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1512\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1513\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
| 583 |
-
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1518\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1519\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1521\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
| 584 |
-
"\u001b[0;32m<ipython-input-6-b2af72f89b89>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, idx, targets)\u001b[0m\n\u001b[1;32m 261\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 262\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mlayer\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdec_layer\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 263\u001b[0;31m \u001b[0mx_final\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlayer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx_out\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 264\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 265\u001b[0m \u001b[0mx_final\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm_final\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_final\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
| 585 |
-
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1510\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1511\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1512\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1513\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
| 586 |
-
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1518\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1519\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1521\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
| 587 |
-
"\u001b[0;32m<ipython-input-6-b2af72f89b89>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, src, att)\u001b[0m\n\u001b[1;32m 189\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 190\u001b[0m \u001b[0matt\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msrc\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0matt\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 191\u001b[0;31m \u001b[0matt2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0ms_att\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0matt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 192\u001b[0m \u001b[0matt2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0matt\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdropout\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0matt2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 193\u001b[0m \u001b[0mtrg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0matt2\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0matt2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
| 588 |
-
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1510\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1511\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1512\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1513\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
| 589 |
-
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1518\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1519\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1521\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
| 590 |
-
"\u001b[0;32m<ipython-input-6-b2af72f89b89>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, mask)\u001b[0m\n\u001b[1;32m 81\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 82\u001b[0m \"\"\"\n\u001b[0;32m---> 83\u001b[0;31m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mh\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmask\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mh\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mheads\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 84\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdropout\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mproj\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 85\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
| 591 |
-
"\u001b[0;32m<ipython-input-6-b2af72f89b89>\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 81\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 82\u001b[0m \"\"\"\n\u001b[0;32m---> 83\u001b[0;31m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mh\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmask\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mh\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mheads\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 84\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdropout\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mproj\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 85\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
| 592 |
-
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1510\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1511\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1512\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1513\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
| 593 |
-
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1518\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1519\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1521\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
| 594 |
-
"\u001b[0;32m<ipython-input-6-b2af72f89b89>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, mask)\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0mweights\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdropout\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mweights\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 50\u001b[0;31m \u001b[0mvalue\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 51\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmatmul\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mweights\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 52\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
| 595 |
-
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1510\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1511\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1512\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1513\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
| 596 |
-
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1518\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1519\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1521\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
| 597 |
-
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 114\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 116\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 117\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 118\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mextra_repr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
| 598 |
-
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
|
| 599 |
-
]
|
| 600 |
-
}
|
| 601 |
-
],
|
| 602 |
-
"source": [
|
| 603 |
-
"import timeit\n",
|
| 604 |
-
"\n",
|
| 605 |
-
"start_time = timeit.default_timer()\n",
|
| 606 |
-
"# data loading\n",
|
| 607 |
-
"def get_batch(split):\n",
|
| 608 |
-
"\n",
|
| 609 |
-
" data = train_data if split == 'train' else val_data\n",
|
| 610 |
-
" ix = torch.randint(len(data) - block_size, (batch_size,))\n",
|
| 611 |
-
" x = torch.stack([data[i:i+block_size] for i in ix])\n",
|
| 612 |
-
" y = torch.stack([data[i+1:i+block_size+1] for i in ix])\n",
|
| 613 |
-
" x, y = x.to(device), y.to(device)\n",
|
| 614 |
-
" return x, y\n",
|
| 615 |
-
"\n",
|
| 616 |
-
"@torch.no_grad()\n",
|
| 617 |
-
"def estimate_loss():\n",
|
| 618 |
-
" out = {}\n",
|
| 619 |
-
" model.eval()\n",
|
| 620 |
-
" for split in ['train', 'val']:\n",
|
| 621 |
-
" losses = torch.zeros(eval_iters)\n",
|
| 622 |
-
" for k in range(eval_iters):\n",
|
| 623 |
-
" X, Y = get_batch(split)\n",
|
| 624 |
-
" logits, loss = model(X, Y)\n",
|
| 625 |
-
" losses[k] = loss.item()\n",
|
| 626 |
-
" out[split] = losses.mean()\n",
|
| 627 |
-
" model.train()\n",
|
| 628 |
-
" return out\n",
|
| 629 |
-
"\n",
|
| 630 |
-
"vocab_size = token.vocab_size\n",
|
| 631 |
-
"model = Transformer(vocab_size)\n",
|
| 632 |
-
"# checkpoint_path = '/content/drive/MyDrive/enigma-2.5b.pth'\n",
|
| 633 |
-
"# checkpoint = torch.load(checkpoint_path)\n",
|
| 634 |
-
"# model.load_state_dict(checkpoint)\n",
|
| 635 |
-
"m = model.to(device)\n",
|
| 636 |
-
"\n",
|
| 637 |
-
"# no of parameters\n",
|
| 638 |
-
"n_param = sum(p.numel() for p in m.parameters())/1e9\n",
|
| 639 |
-
"print(f\"{n_param:.1f} billion parameters\")\n",
|
| 640 |
-
"\n",
|
| 641 |
-
"# optimizer\n",
|
| 642 |
-
"optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n",
|
| 643 |
-
"steps = []\n",
|
| 644 |
-
"train_losses = []\n",
|
| 645 |
-
"val_losses = []\n",
|
| 646 |
-
"\n",
|
| 647 |
-
"for iter in range(max_iters):\n",
|
| 648 |
-
"\n",
|
| 649 |
-
" if iter % eval_interval == 0 or iter == max_iters - 1:\n",
|
| 650 |
-
" losses = estimate_loss()\n",
|
| 651 |
-
" print(f\"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}\")\n",
|
| 652 |
-
"\n",
|
| 653 |
-
" steps.append(iter)\n",
|
| 654 |
-
" train_losses.append(losses['train'])\n",
|
| 655 |
-
" val_losses.append(losses['val'])\n",
|
| 656 |
-
"\n",
|
| 657 |
-
" xb, yb = get_batch('train')\n",
|
| 658 |
-
" logits, loss = model(xb, yb)\n",
|
| 659 |
-
" optimizer.zero_grad(set_to_none=True)\n",
|
| 660 |
-
" loss.backward()\n",
|
| 661 |
-
" optimizer.step()"
|
| 662 |
-
]
|
| 663 |
-
},
|
| 664 |
-
{
|
| 665 |
-
"cell_type": "code",
|
| 666 |
-
"execution_count": 8,
|
| 667 |
-
"metadata": {
|
| 668 |
-
"id": "tzJMKoA35uIV",
|
| 669 |
-
"colab": {
|
| 670 |
-
"base_uri": "https://localhost:8080/"
|
| 671 |
-
},
|
| 672 |
-
"outputId": "ba527bf5-695c-4a8f-acc4-bd60d549eaad"
|
| 673 |
-
},
|
| 674 |
-
"outputs": [
|
| 675 |
-
{
|
| 676 |
-
"output_type": "stream",
|
| 677 |
-
"name": "stdout",
|
| 678 |
-
"text": [
|
| 679 |
-
"total parameters: 2.5 billion\n",
|
| 680 |
-
"trained in 1.82hrs\n"
|
| 681 |
-
]
|
| 682 |
-
}
|
| 683 |
-
],
|
| 684 |
-
"source": [
|
| 685 |
-
"end_time = timeit.default_timer()\n",
|
| 686 |
-
"print(f\"total parameters: {n_param:.1f} billion\")\n",
|
| 687 |
-
"print(f\"trained in {((end_time - start_time)/3600):.2f}hrs\")"
|
| 688 |
-
]
|
| 689 |
-
},
|
| 690 |
-
{
|
| 691 |
-
"cell_type": "code",
|
| 692 |
-
"source": [
|
| 693 |
-
"model_save_name = f'enigma-{n_param:.1f}b_v1.pth'\n",
|
| 694 |
-
"path = f\"/content/drive/MyDrive/{model_save_name}\"\n",
|
| 695 |
-
"torch.save(model.state_dict(), path)"
|
| 696 |
-
],
|
| 697 |
-
"metadata": {
|
| 698 |
-
"id": "eB47Yn9aNrrO"
|
| 699 |
-
},
|
| 700 |
-
"execution_count": 10,
|
| 701 |
-
"outputs": []
|
| 702 |
-
},
|
| 703 |
-
{
|
| 704 |
-
"cell_type": "code",
|
| 705 |
-
"source": [
|
| 706 |
-
"# 8-bit quantization\n",
|
| 707 |
-
"\n",
|
| 708 |
-
"import torch\n",
|
| 709 |
-
"import torch.quantization\n",
|
| 710 |
-
"\n",
|
| 711 |
-
"checkpoint_path = '/content/drive/MyDrive/enigma-2.5b.pth'\n",
|
| 712 |
-
"checkpoint = torch.load(checkpoint_path)\n",
|
| 713 |
-
"model.load_state_dict(checkpoint)\n",
|
| 714 |
-
"model = model.to(device)\n",
|
| 715 |
-
"\n",
|
| 716 |
-
"quantized_model = torch.quantization.quantize_dynamic(\n",
|
| 717 |
-
" model,\n",
|
| 718 |
-
" dtype=torch.qint8\n",
|
| 719 |
-
")\n",
|
| 720 |
-
"quantized_model_file = f'/content/drive/MyDrive/enigma-2.5b-quant.pth'\n",
|
| 721 |
-
"torch.save(quantized_model.state_dict(), quantized_model_file)\n",
|
| 722 |
-
"\n",
|
| 723 |
-
"print(\"Quantized model saved successfully.\")"
|
| 724 |
-
],
|
| 725 |
-
"metadata": {
|
| 726 |
-
"id": "7iGQdNHgms_U"
|
| 727 |
-
},
|
| 728 |
-
"execution_count": null,
|
| 729 |
-
"outputs": []
|
| 730 |
-
},
|
| 731 |
-
{
|
| 732 |
-
"cell_type": "code",
|
| 733 |
-
"source": [
|
| 734 |
-
"# pruning\n",
|
| 735 |
-
"\n",
|
| 736 |
-
"import torch\n",
|
| 737 |
-
"from torch import nn\n",
|
| 738 |
-
"from torch.utils.model_zoo import load_url\n",
|
| 739 |
-
"import torch.nn.utils.prune as prune\n",
|
| 740 |
-
"\n",
|
| 741 |
-
"parameters_to_prune = [(model.encoder.self_attn, 'weight'), (model.encoder.linear1, 'weight')]\n",
|
| 742 |
-
"prune.global_unstructured(\n",
|
| 743 |
-
" parameters_to_prune,\n",
|
| 744 |
-
" pruning_method=prune.L1Unstructured,\n",
|
| 745 |
-
" amount=0.15,\n",
|
| 746 |
-
")\n",
|
| 747 |
-
"\n",
|
| 748 |
-
"torch.save(model.state_dict(), 'enigma-2.5b_pruned.pth')"
|
| 749 |
-
],
|
| 750 |
-
"metadata": {
|
| 751 |
-
"id": "YTJ19n4OFvZj"
|
| 752 |
-
},
|
| 753 |
-
"execution_count": null,
|
| 754 |
-
"outputs": []
|
| 755 |
-
},
|
| 756 |
-
{
|
| 757 |
-
"cell_type": "code",
|
| 758 |
-
"execution_count": null,
|
| 759 |
-
"metadata": {
|
| 760 |
-
"id": "K2FDOp7Quibq"
|
| 761 |
-
},
|
| 762 |
-
"outputs": [],
|
| 763 |
-
"source": [
|
| 764 |
-
"class Generator(Transformer):\n",
|
| 765 |
-
" def __init__(self, vocab_size, block_size):\n",
|
| 766 |
-
" super().__init__(vocab_size)\n",
|
| 767 |
-
" self.vocab_size = vocab_size\n",
|
| 768 |
-
" self.block_size = block_size\n",
|
| 769 |
-
"\n",
|
| 770 |
-
" def generate(self, idx, max_new_tokens, temperature=1.0, top_k=0):\n",
|
| 771 |
-
" \"\"\"\n",
|
| 772 |
-
" generate new tokens using the trained model\n",
|
| 773 |
-
"\n",
|
| 774 |
-
" Args:\n",
|
| 775 |
-
" - idx (Tensor): input tensor representing initial token indices\n",
|
| 776 |
-
" - max_new_tokens (int): max no of new tokens to generate\n",
|
| 777 |
-
" - temperature (float): softmax temperature for sampling\n",
|
| 778 |
-
" - top_k (int): no of top tokens to consider in sampling\n",
|
| 779 |
-
"\n",
|
| 780 |
-
" Returns:\n",
|
| 781 |
-
" - generated_tokens (list): list of generated token indices\n",
|
| 782 |
-
" \"\"\"\n",
|
| 783 |
-
" generated_tokens = []\n",
|
| 784 |
-
"\n",
|
| 785 |
-
" for _ in range(max_new_tokens):\n",
|
| 786 |
-
" idx_cond = idx[:, -self.block_size:]\n",
|
| 787 |
-
" logits, _ = self(idx_cond)\n",
|
| 788 |
-
" logits = logits[:, -1, :]\n",
|
| 789 |
-
"\n",
|
| 790 |
-
" scaled_logits = logits / temperature\n",
|
| 791 |
-
" if top_k > 0:\n",
|
| 792 |
-
" scaled_logits = self._top_k_filtering(scaled_logits, top_k)\n",
|
| 793 |
-
"\n",
|
| 794 |
-
" probs = F.softmax(scaled_logits, dim=-1)\n",
|
| 795 |
-
" sampled_idx = torch.multinomial(probs, num_samples=1)\n",
|
| 796 |
-
" generated_tokens.append(sampled_idx.item())\n",
|
| 797 |
-
" idx = torch.cat((idx, sampled_idx), dim=1)\n",
|
| 798 |
-
"\n",
|
| 799 |
-
" return generated_tokens\n",
|
| 800 |
-
"\n",
|
| 801 |
-
" def generate_masked_tokens(self, idx, masked_indices, temperature=1.0, top_k=0):\n",
|
| 802 |
-
" \"\"\"\n",
|
| 803 |
-
" Generate predictions for masked tokens using the trained model.\n",
|
| 804 |
-
"\n",
|
| 805 |
-
" Args:\n",
|
| 806 |
-
" - idx (Tensor): input tensor representing token indices\n",
|
| 807 |
-
" - masked_indices (Tensor): tensor of indices indicating masked positions\n",
|
| 808 |
-
" - temperature (float): softmax temperature for sampling\n",
|
| 809 |
-
" - top_k (int): no of top tokens to consider in sampling\n",
|
| 810 |
-
"\n",
|
| 811 |
-
" Returns:\n",
|
| 812 |
-
" - predicted_tokens (Tensor): tensor of predicted token indices\n",
|
| 813 |
-
" \"\"\"\n",
|
| 814 |
-
" B, T = idx.shape\n",
|
| 815 |
-
"\n",
|
| 816 |
-
" toked_model = self.toked_model(idx)\n",
|
| 817 |
-
" pos_encod = self.pos_encod(torch.arange(T, device=device))\n",
|
| 818 |
-
" x = toked_model + pos_encod\n",
|
| 819 |
-
"\n",
|
| 820 |
-
" for layer in self.enc_layer:\n",
|
| 821 |
-
" x_out = layer(x)\n",
|
| 822 |
-
"\n",
|
| 823 |
-
" for layer in self.dec_layer:\n",
|
| 824 |
-
" x_final = layer(x, x_out)\n",
|
| 825 |
-
"\n",
|
| 826 |
-
" x_masked = x_final.clone()\n",
|
| 827 |
-
" x_masked[masked_indices] = self.toked_model(torch.tensor([6], device=device))\n",
|
| 828 |
-
"\n",
|
| 829 |
-
" x_masked = self.norm_final(x_masked)\n",
|
| 830 |
-
" logits = self.linear_final(x_masked)\n",
|
| 831 |
-
"\n",
|
| 832 |
-
" masked_logits = logits[masked_indices].view(-1, logits.size(-1))\n",
|
| 833 |
-
" scaled_logits = masked_logits / temperature\n",
|
| 834 |
-
" if top_k > 0:\n",
|
| 835 |
-
" scaled_logits = self._top_k_filtering(scaled_logits, top_k)\n",
|
| 836 |
-
"\n",
|
| 837 |
-
" probs = F.softmax(scaled_logits, dim=-1)\n",
|
| 838 |
-
" predicted_indices = torch.argmax(probs, dim=-1)\n",
|
| 839 |
-
"\n",
|
| 840 |
-
" return predicted_indices\n",
|
| 841 |
-
"\n",
|
| 842 |
-
" def _top_k_filtering(self, logits, top_k):\n",
|
| 843 |
-
" \"\"\"\n",
|
| 844 |
-
" filter logits to keep only the top-k tokens\n",
|
| 845 |
-
"\n",
|
| 846 |
-
" Args:\n",
|
| 847 |
-
" - logits (Tensor): input tensor representing unscaled logits\n",
|
| 848 |
-
" - top_k (int): no of top tokens to keep\n",
|
| 849 |
-
"\n",
|
| 850 |
-
" Returns:\n",
|
| 851 |
-
" - filtered_logits (Tensor): filtered logits with only top-k tokens remaining\n",
|
| 852 |
-
" \"\"\"\n",
|
| 853 |
-
" values, indices = torch.topk(logits, top_k, dim=-1)\n",
|
| 854 |
-
" min_value = values[:, -1].unsqueeze(-1).expand_as(logits)\n",
|
| 855 |
-
" filtered_logits = torch.where(logits < min_value, torch.ones_like(logits) * -float('inf'), logits)\n",
|
| 856 |
-
"\n",
|
| 857 |
-
" return filtered_logits"
|
| 858 |
-
]
|
| 859 |
-
},
|
| 860 |
-
{
|
| 861 |
-
"cell_type": "code",
|
| 862 |
-
"execution_count": null,
|
| 863 |
-
"metadata": {
|
| 864 |
-
"colab": {
|
| 865 |
-
"base_uri": "https://localhost:8080/",
|
| 866 |
-
"height": 429
|
| 867 |
-
},
|
| 868 |
-
"id": "c5CknylV4S2m",
|
| 869 |
-
"outputId": "12314d78-9147-4e60-f8b5-84207b97a1c7"
|
| 870 |
-
},
|
| 871 |
-
"outputs": [
|
| 872 |
-
{
|
| 873 |
-
"output_type": "error",
|
| 874 |
-
"ename": "RuntimeError",
|
| 875 |
-
"evalue": "Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)",
|
| 876 |
-
"traceback": [
|
| 877 |
-
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
| 878 |
-
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
|
| 879 |
-
"\u001b[0;32m<ipython-input-17-db17ec37b06c>\u001b[0m in \u001b[0;36m<cell line: 5>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mtarget_text\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\"AGTTCTGCGAT\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mcontext\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtoken\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mencode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtarget_text\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlong\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0mgenerated_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtoken\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdecode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgenerator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgenerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcontext\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_new_tokens\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtemperature\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.9\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtop_k\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"{target_text}{generated_output}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
| 880 |
-
"\u001b[0;32m<ipython-input-16-39da0e3e4598>\u001b[0m in \u001b[0;36mgenerate\u001b[0;34m(self, idx, max_new_tokens, temperature, top_k)\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmax_new_tokens\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0midx_cond\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0midx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mblock_size\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 24\u001b[0;31m \u001b[0mlogits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0midx_cond\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 25\u001b[0m \u001b[0mlogits\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlogits\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
|
| 881 |
-
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1510\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1511\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1512\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1513\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
| 882 |
-
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1518\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1519\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1521\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
| 883 |
-
"\u001b[0;32m<ipython-input-7-b2af72f89b89>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, idx, targets)\u001b[0m\n\u001b[1;32m 253\u001b[0m \u001b[0mB\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mT\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0midx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 254\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 255\u001b[0;31m \u001b[0mtoked_model\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtoked_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 256\u001b[0m \u001b[0mpos_encod\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpos_encod\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mT\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 257\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtoked_model\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mpos_encod\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
| 884 |
-
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1510\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1511\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1512\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1513\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
| 885 |
-
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1518\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1519\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1521\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
| 886 |
-
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/sparse.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 162\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 163\u001b[0;31m return F.embedding(\n\u001b[0m\u001b[1;32m 164\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpadding_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax_norm\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 165\u001b[0m self.norm_type, self.scale_grad_by_freq, self.sparse)\n",
|
| 887 |
-
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36membedding\u001b[0;34m(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)\u001b[0m\n\u001b[1;32m 2235\u001b[0m \u001b[0;31m# remove once script supports set_grad_enabled\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2236\u001b[0m \u001b[0m_no_grad_embedding_renorm_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_norm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnorm_type\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2237\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0membedding\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpadding_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mscale_grad_by_freq\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msparse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2238\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2239\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
|
| 888 |
-
"\u001b[0;31mRuntimeError\u001b[0m: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)"
|
| 889 |
-
]
|
| 890 |
-
}
|
| 891 |
-
],
|
| 892 |
-
"source": [
|
| 893 |
-
"generator = Generator(vocab_size, block_size)\n",
|
| 894 |
-
"\n",
|
| 895 |
-
"target_text = \"AGTTCTGCGAT\"\n",
|
| 896 |
-
"context = torch.tensor([token.encode(target_text)], dtype=torch.long, device=device)\n",
|
| 897 |
-
"generated_output = token.decode(generator.generate(context, max_new_tokens=100, temperature=0.9, top_k=5))\n",
|
| 898 |
-
"print(f\"{target_text}{generated_output}\")"
|
| 899 |
-
]
|
| 900 |
-
}
|
| 901 |
-
],
|
| 902 |
-
"metadata": {
|
| 903 |
-
"accelerator": "GPU",
|
| 904 |
-
"colab": {
|
| 905 |
-
"gpuType": "T4",
|
| 906 |
-
"machine_shape": "hm",
|
| 907 |
-
"provenance": []
|
| 908 |
-
},
|
| 909 |
-
"kernelspec": {
|
| 910 |
-
"display_name": "Python 3",
|
| 911 |
-
"name": "python3"
|
| 912 |
-
},
|
| 913 |
-
"language_info": {
|
| 914 |
-
"name": "python"
|
| 915 |
-
}
|
| 916 |
-
},
|
| 917 |
-
"nbformat": 4,
|
| 918 |
-
"nbformat_minor": 0
|
| 919 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
enigma/config_enigma.json
DELETED
|
@@ -1,13 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"batch_size": 10,
|
| 3 |
-
"block_size": 512,
|
| 4 |
-
"max_iters": 5000,
|
| 5 |
-
"eval_interval": 50,
|
| 6 |
-
"learning_rate": 3e-5,
|
| 7 |
-
"eval_iters": 100,
|
| 8 |
-
"d_model": 384,
|
| 9 |
-
"n_head": 12,
|
| 10 |
-
"n_layer": 12,
|
| 11 |
-
"dropout": 0.2,
|
| 12 |
-
"norm_eps": 1e-5
|
| 13 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
enigma/enigma.cpp
DELETED
|
@@ -1,364 +0,0 @@
|
|
| 1 |
-
#include <torch/torch.h>
|
| 2 |
-
#include <iostream>
|
| 3 |
-
#include <vector>
|
| 4 |
-
|
| 5 |
-
// Define device
|
| 6 |
-
torch::Device device(torch::kCUDA);
|
| 7 |
-
|
| 8 |
-
// Define constants
|
| 9 |
-
const int batch_size = 8;
|
| 10 |
-
const int block_size = 32;
|
| 11 |
-
const int max_iters = 1000;
|
| 12 |
-
const int eval_interval = 50;
|
| 13 |
-
const int eval_iters = 5;
|
| 14 |
-
const int d_model = 256;
|
| 15 |
-
const int n_layer = 16;
|
| 16 |
-
const int n_head = 12;
|
| 17 |
-
const float dropout = 0.2;
|
| 18 |
-
const float norm_eps = 1e-5;
|
| 19 |
-
const int vocab_size = 5;
|
| 20 |
-
|
| 21 |
-
// sample data
|
| 22 |
-
torch::Tensor train_data = torch::rand({1000, block_size});
|
| 23 |
-
torch::Tensor val_data = torch::rand({500, block_size});
|
| 24 |
-
|
| 25 |
-
// Data loading function
|
| 26 |
-
std::pair<torch::Tensor, torch::Tensor> get_batch(const std::string& split) {
|
| 27 |
-
torch::Tensor data = (split == "train") ? train_data : val_data;
|
| 28 |
-
torch::Tensor ix = torch::randint(data.size(0) - block_size, {batch_size});
|
| 29 |
-
torch::Tensor x = torch::empty({batch_size, block_size});
|
| 30 |
-
torch::Tensor y = torch::empty({batch_size, block_size});
|
| 31 |
-
for (int i = 0; i < batch_size; ++i) {
|
| 32 |
-
x[i] = data.index({ix[i], ix[i] + block_size});
|
| 33 |
-
y[i] = data.index({ix[i] + 1, ix[i] + block_size + 1});
|
| 34 |
-
}
|
| 35 |
-
return std::make_pair(x.to(device), y.to(device));
|
| 36 |
-
}
|
| 37 |
-
|
| 38 |
-
// Custom classes and functions
|
| 39 |
-
class SWiGLU : public torch::nn::Module {
|
| 40 |
-
public:
|
| 41 |
-
SWiGLU() {}
|
| 42 |
-
|
| 43 |
-
torch::Tensor forward(torch::Tensor x) {
|
| 44 |
-
torch::Tensor sigmoid_output = torch::sigmoid(x);
|
| 45 |
-
torch::Tensor relu_output = torch::relu(x);
|
| 46 |
-
torch::Tensor out = sigmoid_output * relu_output + (1 - sigmoid_output) * x;
|
| 47 |
-
return out;
|
| 48 |
-
}
|
| 49 |
-
};
|
| 50 |
-
|
| 51 |
-
class UnMaskedHeadImpl : public torch::nn::Module {
|
| 52 |
-
public:
|
| 53 |
-
UnMaskedHeadImpl(int d_model, int head_size, float dropout)
|
| 54 |
-
: key(register_module("key", torch::nn::Linear(d_model, head_size))),
|
| 55 |
-
query(register_module("query", torch::nn::Linear(d_model, head_size))),
|
| 56 |
-
value(register_module("value", torch::nn::Linear(d_model, head_size))),
|
| 57 |
-
dropout(torch::nn::Dropout(dropout)) {
|
| 58 |
-
register_module("dropout", dropout);
|
| 59 |
-
}
|
| 60 |
-
|
| 61 |
-
torch::Tensor forward(torch::Tensor x) {
|
| 62 |
-
torch::Tensor key_out = key->forward(x);
|
| 63 |
-
torch::Tensor query_out = query->forward(x);
|
| 64 |
-
|
| 65 |
-
torch::Tensor weights = query_out.matmul(key_out.transpose(-2, -1)) * std::sqrt(key_out.size(-1));
|
| 66 |
-
weights = torch::softmax(weights, -1);
|
| 67 |
-
weights = dropout(weights);
|
| 68 |
-
|
| 69 |
-
torch::Tensor value_out = value->forward(x);
|
| 70 |
-
torch::Tensor out = weights.matmul(value_out);
|
| 71 |
-
return out;
|
| 72 |
-
}
|
| 73 |
-
|
| 74 |
-
private:
|
| 75 |
-
torch::nn::Linear key, query, value;
|
| 76 |
-
torch::nn::Dropout dropout;
|
| 77 |
-
};
|
| 78 |
-
|
| 79 |
-
TORCH_MODULE(UnMaskedHead);
|
| 80 |
-
|
| 81 |
-
class MaskedHeadImpl : public torch::nn::Module {
|
| 82 |
-
public:
|
| 83 |
-
MaskedHeadImpl(int head_size, float dropout, int d_model)
|
| 84 |
-
: key(register_module("key", torch::nn::Linear(d_model, head_size))),
|
| 85 |
-
query(register_module("query", torch::nn::Linear(d_model, head_size))),
|
| 86 |
-
value(register_module("value", torch::nn::Linear(d_model, head_size))),
|
| 87 |
-
dropout(torch::nn::Dropout(dropout)) {
|
| 88 |
-
register_buffer("tril", torch::tril(torch::ones(block_size, block_size)));
|
| 89 |
-
}
|
| 90 |
-
|
| 91 |
-
torch::Tensor forward(torch::Tensor x) {
|
| 92 |
-
torch::Tensor key_out = key->forward(x);
|
| 93 |
-
torch::Tensor query_out = query->forward(x);
|
| 94 |
-
|
| 95 |
-
torch::Tensor weights = query_out.matmul(key_out.transpose(-2, -1)) * std::sqrt(key_out.size(-1));
|
| 96 |
-
weights = weights.masked_fill(tril[:x.size(1), :x.size(1)] == 0, std::numeric_limits<float>::lowest());
|
| 97 |
-
weights = torch::softmax(weights, -1);
|
| 98 |
-
weights = dropout(weights);
|
| 99 |
-
|
| 100 |
-
torch::Tensor value_out = value->forward(x);
|
| 101 |
-
torch::Tensor out = weights.matmul(value_out);
|
| 102 |
-
return out;
|
| 103 |
-
}
|
| 104 |
-
|
| 105 |
-
private:
|
| 106 |
-
torch::nn::Linear key, query, value;
|
| 107 |
-
torch::nn::Dropout dropout;
|
| 108 |
-
torch::Tensor tril;
|
| 109 |
-
};
|
| 110 |
-
|
| 111 |
-
TORCH_MODULE(MaskedHead);
|
| 112 |
-
|
| 113 |
-
class MultiUnMaskedImpl : public torch::nn::Module {
|
| 114 |
-
public:
|
| 115 |
-
MultiUnMaskedImpl(int d_model, int n_head, float dropout)
|
| 116 |
-
: proj(register_module("proj", torch::nn::Linear(n_head * (d_model / n_head), d_model))),
|
| 117 |
-
dropout(torch::nn::Dropout(dropout)) {
|
| 118 |
-
for (int i = 0; i < n_head; ++i) {
|
| 119 |
-
heads.push_back(register_module("head" + std::to_string(i), UnMaskedHead(d_model, d_model / n_head, dropout)));
|
| 120 |
-
}
|
| 121 |
-
}
|
| 122 |
-
|
| 123 |
-
torch::Tensor forward(torch::Tensor x) {
|
| 124 |
-
std::vector<torch::Tensor> head_outputs;
|
| 125 |
-
for (auto& head : heads) {
|
| 126 |
-
head_outputs.push_back(head->forward(x));
|
| 127 |
-
}
|
| 128 |
-
torch::Tensor out = torch::cat(head_outputs, -1);
|
| 129 |
-
out = dropout(out);
|
| 130 |
-
out = proj(out);
|
| 131 |
-
return out;
|
| 132 |
-
}
|
| 133 |
-
|
| 134 |
-
private:
|
| 135 |
-
torch::nn::Linear proj;
|
| 136 |
-
torch::nn::Dropout dropout;
|
| 137 |
-
std::vector<UnMaskedHead> heads;
|
| 138 |
-
};
|
| 139 |
-
|
| 140 |
-
TORCH_MODULE(MultiUnMasked);
|
| 141 |
-
|
| 142 |
-
class MultiMaskedImpl : public torch::nn::Module {
|
| 143 |
-
public:
|
| 144 |
-
MultiMaskedImpl(int d_model, int n_head, float dropout)
|
| 145 |
-
: proj(register_module("proj", torch::nn::Linear(n_head * (d_model / n_head), d_model))),
|
| 146 |
-
dropout(torch::nn::Dropout(dropout)) {
|
| 147 |
-
for (int i = 0; i < n_head; ++i) {
|
| 148 |
-
heads.push_back(register_module("head" + std::to_string(i), MaskedHead(d_model, d_model / n_head, dropout)));
|
| 149 |
-
}
|
| 150 |
-
}
|
| 151 |
-
|
| 152 |
-
torch::Tensor forward(torch::Tensor x) {
|
| 153 |
-
std::vector<torch::Tensor> head_outputs;
|
| 154 |
-
for (auto& head : heads) {
|
| 155 |
-
head_outputs.push_back(head->forward(x));
|
| 156 |
-
}
|
| 157 |
-
torch::Tensor out = torch::cat(head_outputs, -1);
|
| 158 |
-
out = dropout(out);
|
| 159 |
-
out = proj(out);
|
| 160 |
-
return out;
|
| 161 |
-
}
|
| 162 |
-
|
| 163 |
-
private:
|
| 164 |
-
torch::nn::Linear proj;
|
| 165 |
-
torch::nn::Dropout dropout;
|
| 166 |
-
std::vector<MaskedHead> heads;
|
| 167 |
-
};
|
| 168 |
-
|
| 169 |
-
TORCH_MODULE(MultiMasked);
|
| 170 |
-
|
| 171 |
-
class FeedForwardImpl : public torch::nn::Module {
|
| 172 |
-
public:
|
| 173 |
-
FeedForwardImpl(int d_model, float dropout)
|
| 174 |
-
: net(register_module("net", torch::nn::Sequential(
|
| 175 |
-
torch::nn::Linear(d_model, 4 * d_model),
|
| 176 |
-
torch::nn::GELU(),
|
| 177 |
-
torch::nn::Linear(4 * d_model, d_model),
|
| 178 |
-
torch::nn::Dropout(dropout)
|
| 179 |
-
))) {}
|
| 180 |
-
|
| 181 |
-
torch::Tensor forward(torch::Tensor x) {
|
| 182 |
-
return net->forward(x);
|
| 183 |
-
}
|
| 184 |
-
|
| 185 |
-
private:
|
| 186 |
-
torch::nn::Sequential net;
|
| 187 |
-
};
|
| 188 |
-
|
| 189 |
-
TORCH_MODULE(FeedForward);
|
| 190 |
-
|
| 191 |
-
class BlockImpl : public torch::nn::Module {
|
| 192 |
-
public:
|
| 193 |
-
BlockImpl(int d_model, int n_head, float norm_eps, float dropout)
|
| 194 |
-
: sa_masked(MultiMasked(d_model, n_head, dropout)),
|
| 195 |
-
sa_unmasked(MultiUnMasked(d_model, n_head, dropout)),
|
| 196 |
-
ffwd(FeedForward(d_model, dropout)),
|
| 197 |
-
norm1(torch::nn::LayerNorm(torch::nn::LayerNormOptions({d_model}).eps(norm_eps))),
|
| 198 |
-
norm2(torch::nn::LayerNorm(torch::nn::LayerNormOptions({d_model}).eps(norm_eps))) {}
|
| 199 |
-
|
| 200 |
-
torch::Tensor forward(torch::Tensor x) {
|
| 201 |
-
torch::Tensor x2 = x + sa_unmasked->forward(norm1->forward(x));
|
| 202 |
-
x = x2 + ffwd->forward(norm2->forward(x2));
|
| 203 |
-
|
| 204 |
-
x2 = x + sa_masked->forward(norm1->forward(x));
|
| 205 |
-
x = x2 + ffwd->forward(norm2->forward(x2));
|
| 206 |
-
|
| 207 |
-
return x;
|
| 208 |
-
}
|
| 209 |
-
|
| 210 |
-
private:
|
| 211 |
-
MultiMasked sa_masked;
|
| 212 |
-
MultiUnMasked sa_unmasked;
|
| 213 |
-
FeedForward ffwd;
|
| 214 |
-
torch::nn::LayerNorm norm1, norm2;
|
| 215 |
-
};
|
| 216 |
-
|
| 217 |
-
TORCH_MODULE(Block);
|
| 218 |
-
|
| 219 |
-
class EnigmaImpl : public torch::nn::Module {
|
| 220 |
-
public:
|
| 221 |
-
EnigmaImpl(int vocab_size, int block_size, int d_model, int n_layer, int n_head, float dropout, float norm_eps)
|
| 222 |
-
: toked_model(register_module("toked_model", torch::nn::Embedding(vocab_size, d_model))),
|
| 223 |
-
pos_encod(register_module("pos_encod", torch::nn::Embedding(block_size, d_model))),
|
| 224 |
-
norm_final(torch::nn::LayerNorm(torch::nn::LayerNormOptions({d_model}).eps(norm_eps))),
|
| 225 |
-
linear_final(register_module("linear_final", torch::nn::Linear(d_model, vocab_size))) {
|
| 226 |
-
for (int i = 0; i < n_layer; ++i) {
|
| 227 |
-
block_layers.push_back(register_module("block" + std::to_string(i), Block(d_model, n_head, norm_eps, dropout)));
|
| 228 |
-
}
|
| 229 |
-
register_buffer("block_size", torch::tensor(block_size));
|
| 230 |
-
_init_weights(this);
|
| 231 |
-
}
|
| 232 |
-
|
| 233 |
-
void _init_weights(torch::nn::Module* module) {
|
| 234 |
-
auto parameters = module->named_parameters();
|
| 235 |
-
for (auto& param : parameters) {
|
| 236 |
-
if (param.key().find("weight") != std::string::npos) {
|
| 237 |
-
torch::nn::init::normal_(param.value(), 0.0, 0.02);
|
| 238 |
-
} else if (param.key().find("bias") != std::string::npos) {
|
| 239 |
-
torch::nn::init::zeros_(param.value());
|
| 240 |
-
}
|
| 241 |
-
}
|
| 242 |
-
}
|
| 243 |
-
|
| 244 |
-
std::pair<torch::Tensor, torch::Tensor> forward(torch::Tensor idx, torch::Tensor targets=torch::Tensor()) {
|
| 245 |
-
torch::Tensor toked_model_out = toked_model->forward(idx);
|
| 246 |
-
torch::Tensor pos_encod_out = pos_encod->forward(torch::arange(idx.size(1)));
|
| 247 |
-
torch::Tensor x = toked_model_out + pos_encod_out;
|
| 248 |
-
|
| 249 |
-
for (auto& block : block_layers) {
|
| 250 |
-
x = block->forward(x);
|
| 251 |
-
}
|
| 252 |
-
|
| 253 |
-
torch::Tensor logits = linear_final->forward(norm_final->forward(x));
|
| 254 |
-
|
| 255 |
-
if (!targets.numel()) {
|
| 256 |
-
return {logits, torch::Tensor()};
|
| 257 |
-
} else {
|
| 258 |
-
logits = logits.view({-1, logits.size(-1)});
|
| 259 |
-
targets = targets.view({-1});
|
| 260 |
-
torch::Tensor loss = torch::nn::functional::cross_entropy(logits, targets);
|
| 261 |
-
return {logits, loss};
|
| 262 |
-
}
|
| 263 |
-
}
|
| 264 |
-
|
| 265 |
-
std::vector<std::vector<std::pair<torch::Tensor, float>>> complex_generate(torch::Tensor idx, int max_new_tokens, float temperature=1.0, int top_k=3, int beam_width=5) {
|
| 266 |
-
std::vector<std::vector<std::pair<torch::Tensor, float>>> completed_beams;
|
| 267 |
-
torch::Tensor current_idx = idx.clone();
|
| 268 |
-
std::vector<std::pair<torch::Tensor, float>> beam = {std::make_pair(current_idx, 0.0)};
|
| 269 |
-
|
| 270 |
-
for (int i = 0; i < max_new_tokens; ++i) {
|
| 271 |
-
std::vector<std::pair<torch::Tensor, float>> new_beam;
|
| 272 |
-
|
| 273 |
-
for (auto& beam_item : beam) {
|
| 274 |
-
torch::Tensor& current_idx = beam_item.first;
|
| 275 |
-
torch::Tensor logits, loss;
|
| 276 |
-
std::tie(logits, loss) = forward(current_idx);
|
| 277 |
-
logits = logits.index({torch::indexing::Slice(), -1}); // Get last token predictions
|
| 278 |
-
|
| 279 |
-
// Apply softmax and temperature
|
| 280 |
-
torch::Tensor probs = torch::nn::functional::softmax(logits / temperature, -1);
|
| 281 |
-
|
| 282 |
-
// Top-k sampling
|
| 283 |
-
if (top_k > 0) {
|
| 284 |
-
probs = top_k_filtering(probs, top_k);
|
| 285 |
-
}
|
| 286 |
-
|
| 287 |
-
// Sample from the distribution
|
| 288 |
-
torch::Tensor sampled_idx = torch::multinomial(probs, beam_width, true);
|
| 289 |
-
|
| 290 |
-
for (int j = 0; j < beam_width; ++j) {
|
| 291 |
-
torch::Tensor new_idx = torch::cat({current_idx, sampled_idx.index({torch::indexing::Slice(), j})}, 1);
|
| 292 |
-
torch::Tensor new_log_prob = beam_item.second + torch::log(probs.index({torch::indexing::Slice(), sampled_idx.index({torch::indexing::Slice(), j})}));
|
| 293 |
-
new_beam.push_back(std::make_pair(new_idx, new_log_prob.item()));
|
| 294 |
-
}
|
| 295 |
-
}
|
| 296 |
-
|
| 297 |
-
// Sort new beam by log probabilities
|
| 298 |
-
std::sort(new_beam.begin(), new_beam.end(), [](const std::pair<torch::Tensor, float>& a, const std::pair<torch::Tensor, float>& b) {
|
| 299 |
-
return a.second > b.second;
|
| 300 |
-
});
|
| 301 |
-
|
| 302 |
-
// Only keep top beams
|
| 303 |
-
beam = std::vector<std::pair<torch::Tensor, float>>(new_beam.begin(), new_beam.begin() + beam_width);
|
| 304 |
-
}
|
| 305 |
-
|
| 306 |
-
completed_beams.push_back(beam);
|
| 307 |
-
return completed_beams;
|
| 308 |
-
}
|
| 309 |
-
|
| 310 |
-
std::vector<std::vector<std::pair<torch::Tensor, float>>> top_k_filtering(torch::Tensor logits, int top_k) {
|
| 311 |
-
torch::Tensor top_values, top_indices;
|
| 312 |
-
std::tie(top_values, top_indices) = torch::topk(logits, top_k, -1);
|
| 313 |
-
|
| 314 |
-
torch::Tensor min_value = torch::index_select(top_values, -1, torch::tensor({top_k-1}));
|
| 315 |
-
torch::Tensor filtered_logits = torch::where(logits < min_value, torch::full_like(logits, -std::numeric_limits<float>::infinity()), logits);
|
| 316 |
-
return filtered_logits;
|
| 317 |
-
}
|
| 318 |
-
|
| 319 |
-
private:
|
| 320 |
-
torch::nn::Embedding toked_model, pos_encod;
|
| 321 |
-
std::vector<Block> block_layers;
|
| 322 |
-
torch::nn::LayerNorm norm_final;
|
| 323 |
-
torch::nn::Linear linear_final;
|
| 324 |
-
int block_size;
|
| 325 |
-
};
|
| 326 |
-
|
| 327 |
-
TORCH_MODULE(Enigma);
|
| 328 |
-
|
| 329 |
-
int main() {
|
| 330 |
-
// Set seed
|
| 331 |
-
torch::manual_seed(1400);
|
| 332 |
-
|
| 333 |
-
// Create model
|
| 334 |
-
Enigma model(vocab_size, block_size, d_model, n_layer, n_head, dropout, norm_eps);
|
| 335 |
-
model->to(device);
|
| 336 |
-
|
| 337 |
-
// Define optimizer
|
| 338 |
-
torch::optim::AdamW optimizer(model->parameters(), torch::optim::AdamWOptions(learning_rate));
|
| 339 |
-
|
| 340 |
-
// Training loop
|
| 341 |
-
std::vector<float> train_losses, val_losses;
|
| 342 |
-
for (int iter = 0; iter < max_iters; ++iter) {
|
| 343 |
-
if (iter % eval_interval == 0 || iter == max_iters - 1) {
|
| 344 |
-
// Evaluate and print losses
|
| 345 |
-
auto losses = estimate_loss();
|
| 346 |
-
std::cout << "step " << iter << ": train loss " << losses["train"] << ", val loss " << losses["val"] << std::endl;
|
| 347 |
-
|
| 348 |
-
// Save losses for plotting
|
| 349 |
-
train_losses.push_back(losses["train"]);
|
| 350 |
-
val_losses.push_back(losses["val"]);
|
| 351 |
-
}
|
| 352 |
-
|
| 353 |
-
// Get batch, forward pass, loss calculation, backward pass, optimizer step
|
| 354 |
-
auto [xb, yb] = get_batch("train");
|
| 355 |
-
torch::Tensor logits, loss;
|
| 356 |
-
std::tie(logits, loss) = model->forward(xb, yb);
|
| 357 |
-
|
| 358 |
-
optimizer.zero_grad();
|
| 359 |
-
loss.backward();
|
| 360 |
-
optimizer.step();
|
| 361 |
-
}
|
| 362 |
-
|
| 363 |
-
return 0;
|
| 364 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
enigma/generate.py
DELETED
|
@@ -1,126 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
current_directory = os.path.dirname(os.path.abspath(__file__))
|
| 3 |
-
os.chdir(current_directory)
|
| 4 |
-
|
| 5 |
-
with open('../parquet files/new_dna.txt', 'r', encoding='utf-8') as file:
|
| 6 |
-
captions = file.read()
|
| 7 |
-
|
| 8 |
-
print(f"{(len(captions)/1e6):.2f} million letters")
|
| 9 |
-
|
| 10 |
-
from tokenizer import PerCharTokenizer
|
| 11 |
-
|
| 12 |
-
tokenizer = PerCharTokenizer()
|
| 13 |
-
vocab_size = tokenizer.vocab_size
|
| 14 |
-
|
| 15 |
-
import torch
|
| 16 |
-
import torch.nn as nn
|
| 17 |
-
from torch.nn import functional as F
|
| 18 |
-
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 19 |
-
|
| 20 |
-
from model import Transformer
|
| 21 |
-
model = Transformer(vocab_size=vocab_size)
|
| 22 |
-
|
| 23 |
-
class Generator(Transformer):
|
| 24 |
-
def __init__(self, vocab_size):
|
| 25 |
-
super().__init__()
|
| 26 |
-
self.vocab_size = vocab_size
|
| 27 |
-
self.block_size = Transformer.block_size
|
| 28 |
-
|
| 29 |
-
def generate(self, idx, max_new_tokens, temperature=1.0, top_k=0):
|
| 30 |
-
"""
|
| 31 |
-
generate new tokens using the trained model
|
| 32 |
-
|
| 33 |
-
Args:
|
| 34 |
-
- idx (Tensor): input tensor representing initial token indices
|
| 35 |
-
- max_new_tokens (int): max no of new tokens to generate
|
| 36 |
-
- temperature (float): softmax temperature for sampling
|
| 37 |
-
- top_k (int): no of top tokens to consider in sampling
|
| 38 |
-
|
| 39 |
-
Returns:
|
| 40 |
-
- generated_tokens (list): list of generated token indices
|
| 41 |
-
"""
|
| 42 |
-
generated_tokens = []
|
| 43 |
-
|
| 44 |
-
for _ in range(max_new_tokens):
|
| 45 |
-
idx_cond = idx[:, -self.block_size:]
|
| 46 |
-
logits, _ = self(idx_cond)
|
| 47 |
-
logits = logits[:, -1, :]
|
| 48 |
-
|
| 49 |
-
scaled_logits = logits / temperature
|
| 50 |
-
if top_k > 0:
|
| 51 |
-
scaled_logits = self._top_k_filtering(scaled_logits, top_k)
|
| 52 |
-
|
| 53 |
-
probs = F.softmax(scaled_logits, dim=-1)
|
| 54 |
-
sampled_idx = torch.multinomial(probs, num_samples=1)
|
| 55 |
-
generated_tokens.append(sampled_idx.item())
|
| 56 |
-
idx = torch.cat((idx, sampled_idx), dim=1)
|
| 57 |
-
|
| 58 |
-
return generated_tokens
|
| 59 |
-
|
| 60 |
-
def generate_masked_tokens(self, idx, masked_indices, temperature=1.0, top_k=0):
|
| 61 |
-
"""
|
| 62 |
-
Generate predictions for masked tokens using the trained model.
|
| 63 |
-
|
| 64 |
-
Args:
|
| 65 |
-
- idx (Tensor): input tensor representing token indices
|
| 66 |
-
- masked_indices (Tensor): tensor of indices indicating masked positions
|
| 67 |
-
- temperature (float): softmax temperature for sampling
|
| 68 |
-
- top_k (int): no of top tokens to consider in sampling
|
| 69 |
-
|
| 70 |
-
Returns:
|
| 71 |
-
- predicted_tokens (Tensor): tensor of predicted token indices
|
| 72 |
-
"""
|
| 73 |
-
B, T = idx.shape
|
| 74 |
-
|
| 75 |
-
toked_model = self.toked_model(idx)
|
| 76 |
-
pos_encod = self.pos_encod(torch.arange(T, device=device))
|
| 77 |
-
x = toked_model + pos_encod
|
| 78 |
-
|
| 79 |
-
for layer in self.enc_layer:
|
| 80 |
-
x_out = layer(x)
|
| 81 |
-
|
| 82 |
-
for layer in self.dec_layer:
|
| 83 |
-
x_final = layer(x, x_out)
|
| 84 |
-
|
| 85 |
-
x_masked = x_final.clone()
|
| 86 |
-
x_masked[masked_indices] = self.toked_model(torch.tensor([6], device=device))
|
| 87 |
-
|
| 88 |
-
x_masked = self.norm_final(x_masked)
|
| 89 |
-
logits = self.linear_final(x_masked)
|
| 90 |
-
|
| 91 |
-
masked_logits = logits[masked_indices].view(-1, logits.size(-1))
|
| 92 |
-
scaled_logits = masked_logits / temperature
|
| 93 |
-
if top_k > 0:
|
| 94 |
-
scaled_logits = self._top_k_filtering(scaled_logits, top_k)
|
| 95 |
-
|
| 96 |
-
probs = F.softmax(scaled_logits, dim=-1)
|
| 97 |
-
predicted_indices = torch.argmax(probs, dim=-1)
|
| 98 |
-
|
| 99 |
-
return predicted_indices
|
| 100 |
-
|
| 101 |
-
def _top_k_filtering(self, logits, top_k):
|
| 102 |
-
"""
|
| 103 |
-
filter logits to keep only the top-k tokens
|
| 104 |
-
|
| 105 |
-
Args:
|
| 106 |
-
- logits (Tensor): input tensor representing unscaled logits
|
| 107 |
-
- top_k (int): no of top tokens to keep
|
| 108 |
-
|
| 109 |
-
Returns:
|
| 110 |
-
- filtered_logits (Tensor): filtered logits with only top-k tokens remaining
|
| 111 |
-
"""
|
| 112 |
-
values, indices = torch.topk(logits, top_k, dim=-1)
|
| 113 |
-
min_value = values[:, -1].unsqueeze(-1).expand_as(logits)
|
| 114 |
-
filtered_logits = torch.where(logits < min_value, torch.ones_like(logits) * -float('inf'), logits)
|
| 115 |
-
|
| 116 |
-
return filtered_logits
|
| 117 |
-
|
| 118 |
-
checkpoint_path = '../trained models/enigma_47m.pth'
|
| 119 |
-
checkpoint = torch.load(checkpoint_path)
|
| 120 |
-
model.load_state_dict(checkpoint)
|
| 121 |
-
m = model.to(device)
|
| 122 |
-
|
| 123 |
-
target_text = "AGTTCTGCGAT"
|
| 124 |
-
context = torch.tensor([tokenizer.encode(target_text)], dtype=torch.long, device=device)
|
| 125 |
-
generated_output = tokenizer.decode(Generator.generate(context, max_new_tokens=10, temperature=0.5, top_k=5))
|
| 126 |
-
print(f"{target_text}{generated_output}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
enigma/model.py
DELETED
|
@@ -1,388 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
transformer based model, but with few minimal tweaks
|
| 3 |
-
trained a 2.5billion parameters model with current set configurations
|
| 4 |
-
"""
|
| 5 |
-
|
| 6 |
-
import torch
|
| 7 |
-
import json
|
| 8 |
-
import os
|
| 9 |
-
current_directory = os.path.dirname(os.path.abspath(__file__))
|
| 10 |
-
os.chdir(current_directory)
|
| 11 |
-
|
| 12 |
-
import torch.nn as nn
|
| 13 |
-
from torch.nn import functional as F
|
| 14 |
-
|
| 15 |
-
with open('config_enigma.json', 'r', encoding='utf-8') as file:
|
| 16 |
-
params = json.load(file)
|
| 17 |
-
|
| 18 |
-
batch_size = params['batch_size']
|
| 19 |
-
block_size = params['block_size']
|
| 20 |
-
n_head = params['n_head']
|
| 21 |
-
d_model = params['d_model']
|
| 22 |
-
n_layers = params['n_layer']
|
| 23 |
-
dropout = params['dropout']
|
| 24 |
-
norm_eps = params['norm_eps']
|
| 25 |
-
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 26 |
-
|
| 27 |
-
class AttentionHead(nn.Module):
|
| 28 |
-
"""
|
| 29 |
-
initialize a single head of self attention.
|
| 30 |
-
|
| 31 |
-
Args:
|
| 32 |
-
- d_model (int): dimensionality of the model's hidden layers
|
| 33 |
-
- head_size (int): dimensionality of each attention head
|
| 34 |
-
- dropout (float): dropout probability
|
| 35 |
-
- block_size (int): the maximum sequence length for positional encoding
|
| 36 |
-
"""
|
| 37 |
-
def __init__(self, d_model, head_size, dropout, block_size):
|
| 38 |
-
super().__init__()
|
| 39 |
-
self.key = nn.Linear(d_model, head_size, bias=True)
|
| 40 |
-
self.query = nn.Linear(d_model, head_size, bias=True)
|
| 41 |
-
self.value = nn.Linear(d_model, head_size, bias=False)
|
| 42 |
-
self.dropout = nn.Dropout(dropout)
|
| 43 |
-
self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
|
| 44 |
-
|
| 45 |
-
self.rel_pos_emb = nn.Parameter(torch.randn(block_size, block_size, head_size))
|
| 46 |
-
|
| 47 |
-
def forward(self, x, mask=False):
|
| 48 |
-
"""
|
| 49 |
-
forward pass of a single attention head.
|
| 50 |
-
|
| 51 |
-
Args:
|
| 52 |
-
- x (Tensor): input tensor.
|
| 53 |
-
- mask (bool): flag indicating whether to apply masking
|
| 54 |
-
|
| 55 |
-
Returns:
|
| 56 |
-
- out (Tensor): output tensor after self attention
|
| 57 |
-
"""
|
| 58 |
-
B, T, C = x.shape
|
| 59 |
-
key = self.key(x)
|
| 60 |
-
query = self.query(x)
|
| 61 |
-
|
| 62 |
-
scores = torch.matmul(query, key.transpose(-2, -1)) / (key.shape[-1] ** -0.5)
|
| 63 |
-
rel_pos_scores = torch.einsum('btc,tvc->btv', query, self.rel_pos_emb[:T, :T])
|
| 64 |
-
scores += rel_pos_scores
|
| 65 |
-
|
| 66 |
-
if mask:
|
| 67 |
-
scores = scores.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
|
| 68 |
-
|
| 69 |
-
weights = F.softmax(scores, dim=-1)
|
| 70 |
-
weights = self.dropout(weights)
|
| 71 |
-
|
| 72 |
-
value = self.value(x)
|
| 73 |
-
out = torch.matmul(weights, value)
|
| 74 |
-
return out
|
| 75 |
-
|
| 76 |
-
class MultiHeadAttention(nn.Module):
|
| 77 |
-
"""
|
| 78 |
-
initialize a multi-head attention module.
|
| 79 |
-
|
| 80 |
-
Args:
|
| 81 |
-
- d_model (int): dimensionality of the model's hidden layers
|
| 82 |
-
- n_head (int): no of attention heads
|
| 83 |
-
- dropout (float): dropout probability
|
| 84 |
-
- block_size (int): context length
|
| 85 |
-
"""
|
| 86 |
-
def __init__(self, d_model, n_head, dropout, block_size):
|
| 87 |
-
head_size = d_model // n_head
|
| 88 |
-
super().__init__()
|
| 89 |
-
self.heads = nn.ModuleList([AttentionHead(d_model=d_model, dropout=dropout, head_size=head_size, block_size=block_size) for _ in range(n_head)])
|
| 90 |
-
self.proj = nn.Linear(n_head * head_size, d_model)
|
| 91 |
-
self.dropout = nn.Dropout(dropout)
|
| 92 |
-
|
| 93 |
-
def forward(self, x, mask):
|
| 94 |
-
"""
|
| 95 |
-
forward pass of the multi-head attention module
|
| 96 |
-
|
| 97 |
-
Args:
|
| 98 |
-
- x (Tensor): input tensor
|
| 99 |
-
- mask (bool): flag indicating whether to apply masking
|
| 100 |
-
|
| 101 |
-
Returns:
|
| 102 |
-
- out (Tensor): output tensor after multi-head attention
|
| 103 |
-
|
| 104 |
-
"""
|
| 105 |
-
out = torch.cat([h(x, mask=mask) for h in self.heads], dim=-1)
|
| 106 |
-
out = self.dropout(self.proj(out))
|
| 107 |
-
return out
|
| 108 |
-
|
| 109 |
-
class FeedForward(nn.Module):
|
| 110 |
-
"""
|
| 111 |
-
initialize a feedforward network module
|
| 112 |
-
|
| 113 |
-
Args:
|
| 114 |
-
- d_model (int): the dimensionality of the model's hidden layers
|
| 115 |
-
- dropout (float): dropout probability
|
| 116 |
-
|
| 117 |
-
"""
|
| 118 |
-
def __init__(self, d_model, dropout):
|
| 119 |
-
super().__init__()
|
| 120 |
-
self.net = nn.Sequential(
|
| 121 |
-
nn.Linear(d_model, 10*d_model),
|
| 122 |
-
nn.GELU(),
|
| 123 |
-
nn.Linear(10*d_model, d_model),
|
| 124 |
-
nn.Dropout(dropout)
|
| 125 |
-
)
|
| 126 |
-
|
| 127 |
-
def forward(self, x):
|
| 128 |
-
"""
|
| 129 |
-
forward pass of the feedforward network module
|
| 130 |
-
|
| 131 |
-
Args:
|
| 132 |
-
- x (Tensor): input tensor
|
| 133 |
-
|
| 134 |
-
Returns:
|
| 135 |
-
- out (Tensor): output tensor after passing through the feedforward network
|
| 136 |
-
"""
|
| 137 |
-
return self.net(x)
|
| 138 |
-
|
| 139 |
-
class EncoderNetwork(nn.Module):
|
| 140 |
-
"""
|
| 141 |
-
initialize an encoder network module
|
| 142 |
-
|
| 143 |
-
Args:
|
| 144 |
-
- d_model (int): dimensionality of the model's hidden layers
|
| 145 |
-
- n_head (int): no of attention heads in multi-head attention layers
|
| 146 |
-
- norm_eps (float): epsilon value for layer normalization
|
| 147 |
-
- dropout (float): dropout probability
|
| 148 |
-
- block_size (int): the maximum sequence length for positional encoding
|
| 149 |
-
"""
|
| 150 |
-
def __init__(self, d_model, n_head, norm_eps, dropout, block_size):
|
| 151 |
-
super().__init__()
|
| 152 |
-
self.s_att = MultiHeadAttention(n_head=n_head, d_model=d_model, dropout=dropout, block_size=block_size)
|
| 153 |
-
self.ffwd = FeedForward(d_model, dropout)
|
| 154 |
-
self.dropout = nn.Dropout(dropout)
|
| 155 |
-
self.norm1 = nn.LayerNorm(d_model, eps=norm_eps)
|
| 156 |
-
self.norm2 = nn.LayerNorm(d_model, eps=norm_eps)
|
| 157 |
-
|
| 158 |
-
def forward(self, src):
|
| 159 |
-
"""
|
| 160 |
-
forward pass of the encoder network module.
|
| 161 |
-
|
| 162 |
-
Args:
|
| 163 |
-
- src (Tensor): input tensor representing source data
|
| 164 |
-
|
| 165 |
-
Returns:
|
| 166 |
-
- src (Tensor): output tensor after passing through the encoder network
|
| 167 |
-
"""
|
| 168 |
-
src2 = self.s_att(src, mask=False)
|
| 169 |
-
src = src + self.dropout(src2)
|
| 170 |
-
src = self.norm1(src)
|
| 171 |
-
|
| 172 |
-
src2 = self.ffwd(src)
|
| 173 |
-
src = src + self.dropout(src2)
|
| 174 |
-
src = self.norm2(src)
|
| 175 |
-
|
| 176 |
-
return src
|
| 177 |
-
|
| 178 |
-
class DecoderNetwork(nn.Module):
|
| 179 |
-
"""
|
| 180 |
-
initialize a decoder network module
|
| 181 |
-
|
| 182 |
-
Args:
|
| 183 |
-
- d_model (int): dimensionality of the model's hidden layers
|
| 184 |
-
- n_head (int): no of attention heads in multi-head attention layers
|
| 185 |
-
- norm_eps (float): epsilon value for layer normalization
|
| 186 |
-
- dropout (float): dropout probability
|
| 187 |
-
- block_size (int): the maximum sequence length for positional encoding
|
| 188 |
-
"""
|
| 189 |
-
def __init__(self, d_model, n_head, norm_eps, dropout, block_size):
|
| 190 |
-
super().__init__()
|
| 191 |
-
self.s_att = MultiHeadAttention(n_head=n_head, d_model=d_model, dropout=dropout, block_size=block_size)
|
| 192 |
-
self.ffwd = FeedForward(d_model, dropout)
|
| 193 |
-
self.dropout = nn.Dropout(dropout)
|
| 194 |
-
self.norm1 = nn.LayerNorm(d_model, eps=norm_eps)
|
| 195 |
-
self.norm2 = nn.LayerNorm(d_model, eps=norm_eps)
|
| 196 |
-
|
| 197 |
-
def forward(self, src, att):
|
| 198 |
-
"""
|
| 199 |
-
forward pass of the decoder network module.
|
| 200 |
-
|
| 201 |
-
Args:
|
| 202 |
-
- src (Tensor): input tensor, same as the encoder's inputs
|
| 203 |
-
- trg (Tensor): encoder's attention matrix
|
| 204 |
-
|
| 205 |
-
Returns:
|
| 206 |
-
- src_f (Tensor): final output tensor
|
| 207 |
-
"""
|
| 208 |
-
src2 = self.s_att(src, mask=True)
|
| 209 |
-
src = src + self.dropout(src2)
|
| 210 |
-
src = src + self.norm1(src)
|
| 211 |
-
|
| 212 |
-
att = src + att
|
| 213 |
-
att2 = self.s_att(att, mask=False)
|
| 214 |
-
att2 = att + self.dropout(att2)
|
| 215 |
-
trg = att2 + self.norm1(att2)
|
| 216 |
-
|
| 217 |
-
src_f2 = self.ffwd(self.norm2(trg))
|
| 218 |
-
src_f = src_f + self.dropout(src_f2)
|
| 219 |
-
src_f = self.norm2(src_f)
|
| 220 |
-
|
| 221 |
-
return src_f
|
| 222 |
-
|
| 223 |
-
class Transformer(nn.Module):
|
| 224 |
-
"""
|
| 225 |
-
initialize a Transformer model
|
| 226 |
-
|
| 227 |
-
Args:
|
| 228 |
-
- vocab_size (int): size of the vocabulary
|
| 229 |
-
- d_model (int): dimensionality of the model's hidden layers
|
| 230 |
-
- block_size (int): maximum sequence length for positional encoding/context length
|
| 231 |
-
- n_layers (int): number of encoder and decoder layers in the Transformer
|
| 232 |
-
- n_head (int): number of attention heads in multi-head attention layers
|
| 233 |
-
- norm_eps (float): epsilon value for layer normalization
|
| 234 |
-
- dropout (float): dropout probability
|
| 235 |
-
"""
|
| 236 |
-
def __init__(self, vocab_size):
|
| 237 |
-
super().__init__()
|
| 238 |
-
self.block_size = block_size
|
| 239 |
-
self.toked_model = nn.Embedding(vocab_size, d_model)
|
| 240 |
-
self.pos_encod = nn.Embedding(block_size, d_model)
|
| 241 |
-
self.enc_layer = nn.ModuleList([EncoderNetwork(n_head=n_head, norm_eps=norm_eps, block_size=block_size, dropout=dropout, d_model=d_model) for _ in range(n_layers)])
|
| 242 |
-
self.dec_layer = nn.ModuleList([DecoderNetwork(n_head=n_head, norm_eps=norm_eps, block_size=block_size, dropout=dropout, d_model=d_model) for _ in range(n_layers)])
|
| 243 |
-
|
| 244 |
-
self.norm_final = nn.LayerNorm(d_model)
|
| 245 |
-
self.linear_final = nn.Linear(d_model, vocab_size)
|
| 246 |
-
self.dropout = nn.Dropout(dropout)
|
| 247 |
-
self.apply(self._init_weights)
|
| 248 |
-
|
| 249 |
-
def _init_weights(self, module):
|
| 250 |
-
"""
|
| 251 |
-
initialize weights of linear and embedding layers
|
| 252 |
-
|
| 253 |
-
Args:
|
| 254 |
-
- module (nn.Module): the module to initialize weights for
|
| 255 |
-
"""
|
| 256 |
-
if isinstance(module, nn.Linear):
|
| 257 |
-
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 258 |
-
if module.bias is not None:
|
| 259 |
-
torch.nn.init.zeros_(module.bias.data)
|
| 260 |
-
elif isinstance(module, nn.Embedding):
|
| 261 |
-
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 262 |
-
|
| 263 |
-
def forward(self, idx, targets=None):
|
| 264 |
-
"""
|
| 265 |
-
forward pass of the transformer model
|
| 266 |
-
|
| 267 |
-
Args:
|
| 268 |
-
- idx (Tensor): input tensor representing token indices
|
| 269 |
-
- targets (Tensor): target tensor for computing loss during training
|
| 270 |
-
|
| 271 |
-
Returns:
|
| 272 |
-
- logits (Tensor): output logits from the final linear layer
|
| 273 |
-
- loss (Tensor): optional. computed cross-entropy loss if targets are provided, else None
|
| 274 |
-
"""
|
| 275 |
-
B, T = idx.shape
|
| 276 |
-
|
| 277 |
-
toked_model = self.toked_model(idx)
|
| 278 |
-
pos_encod = self.pos_encod(torch.arange(T, device=device))
|
| 279 |
-
x = toked_model + pos_encod
|
| 280 |
-
|
| 281 |
-
for layer in self.enc_layer:
|
| 282 |
-
x_out = layer(x)
|
| 283 |
-
|
| 284 |
-
for layer in self.dec_layer:
|
| 285 |
-
x_final = layer(x, x_out)
|
| 286 |
-
|
| 287 |
-
x_final = self.norm_final(x_final)
|
| 288 |
-
logits = self.linear_final(x_final)
|
| 289 |
-
|
| 290 |
-
if targets is None:
|
| 291 |
-
loss = None
|
| 292 |
-
|
| 293 |
-
else:
|
| 294 |
-
B, T, C = logits.shape
|
| 295 |
-
logits = logits.view(B*T, C)
|
| 296 |
-
targets = targets.view(B*T)
|
| 297 |
-
loss = F.cross_entropy(logits, targets)
|
| 298 |
-
|
| 299 |
-
return logits, loss
|
| 300 |
-
|
| 301 |
-
def generate(self, idx, max_new_tokens, temperature=1.0, top_k=0):
|
| 302 |
-
"""
|
| 303 |
-
generate new tokens using the trained model
|
| 304 |
-
|
| 305 |
-
Args:
|
| 306 |
-
- idx (Tensor): input tensor representing initial token indices
|
| 307 |
-
- max_new_tokens (int): max no of new tokens to generate
|
| 308 |
-
- temperature (float): softmax temperature for sampling
|
| 309 |
-
- top_k (int): no of top tokens to consider in sampling
|
| 310 |
-
|
| 311 |
-
Returns:
|
| 312 |
-
- generated_tokens (list): list of generated token indices
|
| 313 |
-
"""
|
| 314 |
-
generated_tokens = []
|
| 315 |
-
|
| 316 |
-
for _ in range(max_new_tokens):
|
| 317 |
-
idx_cond = idx[:, -self.block_size:]
|
| 318 |
-
logits, _ = self(idx_cond)
|
| 319 |
-
logits = logits[:, -1, :]
|
| 320 |
-
|
| 321 |
-
scaled_logits = logits / temperature
|
| 322 |
-
if top_k > 0:
|
| 323 |
-
scaled_logits = self._top_k_filtering(scaled_logits, top_k)
|
| 324 |
-
|
| 325 |
-
probs = F.softmax(scaled_logits, dim=-1)
|
| 326 |
-
sampled_idx = torch.multinomial(probs, num_samples=1)
|
| 327 |
-
generated_tokens.append(sampled_idx.item())
|
| 328 |
-
idx = torch.cat((idx, sampled_idx), dim=1)
|
| 329 |
-
|
| 330 |
-
return generated_tokens
|
| 331 |
-
|
| 332 |
-
def generate_masked_tokens(self, idx, masked_indices, temperature=1.0, top_k=0):
|
| 333 |
-
"""
|
| 334 |
-
Generate predictions for masked tokens using the trained model.
|
| 335 |
-
|
| 336 |
-
Args:
|
| 337 |
-
- idx (Tensor): input tensor representing token indices
|
| 338 |
-
- masked_indices (Tensor): tensor of indices indicating masked positions
|
| 339 |
-
- temperature (float): softmax temperature for sampling
|
| 340 |
-
- top_k (int): no of top tokens to consider in sampling
|
| 341 |
-
|
| 342 |
-
Returns:
|
| 343 |
-
- predicted_tokens (Tensor): tensor of predicted token indices
|
| 344 |
-
"""
|
| 345 |
-
B, T = idx.shape
|
| 346 |
-
|
| 347 |
-
toked_model = self.toked_model(idx)
|
| 348 |
-
pos_encod = self.pos_encod(torch.arange(T, device=device))
|
| 349 |
-
x = toked_model + pos_encod
|
| 350 |
-
|
| 351 |
-
for layer in self.enc_layer:
|
| 352 |
-
x_out = layer(x)
|
| 353 |
-
|
| 354 |
-
for layer in self.dec_layer:
|
| 355 |
-
x_final = layer(x, x_out)
|
| 356 |
-
|
| 357 |
-
x_masked = x_final.clone()
|
| 358 |
-
x_masked[masked_indices] = self.toked_model(torch.tensor([6], device=device))
|
| 359 |
-
|
| 360 |
-
x_masked = self.norm_final(x_masked)
|
| 361 |
-
logits = self.linear_final(x_masked)
|
| 362 |
-
|
| 363 |
-
masked_logits = logits[masked_indices].view(-1, logits.size(-1))
|
| 364 |
-
scaled_logits = masked_logits / temperature
|
| 365 |
-
if top_k > 0:
|
| 366 |
-
scaled_logits = self._top_k_filtering(scaled_logits, top_k)
|
| 367 |
-
|
| 368 |
-
probs = F.softmax(scaled_logits, dim=-1)
|
| 369 |
-
predicted_indices = torch.argmax(probs, dim=-1)
|
| 370 |
-
|
| 371 |
-
return predicted_indices
|
| 372 |
-
|
| 373 |
-
def _top_k_filtering(self, logits, top_k):
|
| 374 |
-
"""
|
| 375 |
-
filter logits to keep only the top-k tokens
|
| 376 |
-
|
| 377 |
-
Args:
|
| 378 |
-
- logits (Tensor): input tensor representing unscaled logits
|
| 379 |
-
- top_k (int): no of top tokens to keep
|
| 380 |
-
|
| 381 |
-
Returns:
|
| 382 |
-
- filtered_logits (Tensor): filtered logits with only top-k tokens remaining
|
| 383 |
-
"""
|
| 384 |
-
values, indices = torch.topk(logits, top_k, dim=-1)
|
| 385 |
-
min_value = values[:, -1].unsqueeze(-1).expand_as(logits)
|
| 386 |
-
filtered_logits = torch.where(logits < min_value, torch.ones_like(logits) * -float('inf'), logits)
|
| 387 |
-
|
| 388 |
-
return filtered_logits
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
enigma/run.py
DELETED
|
@@ -1,100 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
use this file to train the model
|
| 3 |
-
|
| 4 |
-
working:
|
| 5 |
-
- imports vatious dependencies first, and then loads the training data
|
| 6 |
-
- tokenizes it, per-character basis
|
| 7 |
-
- loads the required hyper-parameters and the model file
|
| 8 |
-
- trains it till 'max_iters' and saves the model state, and generates outputs
|
| 9 |
-
|
| 10 |
-
with the current set configuration, model can reach upto ~60million parameters
|
| 11 |
-
and can become ~99% accurate with next token prediction
|
| 12 |
-
"""
|
| 13 |
-
|
| 14 |
-
import torch
|
| 15 |
-
import json
|
| 16 |
-
import os
|
| 17 |
-
current_directory = os.path.dirname(os.path.abspath(__file__))
|
| 18 |
-
os.chdir(current_directory)
|
| 19 |
-
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 20 |
-
|
| 21 |
-
with open('../parquet files/new_dna.txt', 'r', encoding='utf-8') as file:
|
| 22 |
-
captions = file.read()
|
| 23 |
-
|
| 24 |
-
print(f"{(len(captions)/1e6):.2f} million letters")
|
| 25 |
-
|
| 26 |
-
from ..tokenizer import PerCharTokenizer
|
| 27 |
-
|
| 28 |
-
tokenizer = PerCharTokenizer()
|
| 29 |
-
vocab_size = tokenizer.vocab_size
|
| 30 |
-
# Train and test splits
|
| 31 |
-
data = torch.tensor(tokenizer.encode(captions), dtype=torch.long)
|
| 32 |
-
n = int(0.9*len(data)) # first 90% will be train, rest val
|
| 33 |
-
train_data = data[:n]
|
| 34 |
-
val_data = data[n:]
|
| 35 |
-
|
| 36 |
-
with open('/config_enigma.json', 'r', encoding='utf-8') as file:
|
| 37 |
-
params = json.load(file)
|
| 38 |
-
|
| 39 |
-
# required parameters
|
| 40 |
-
batch_size = params['batch_size']
|
| 41 |
-
block_size = params['block_size']
|
| 42 |
-
max_iters = params['max_iters']
|
| 43 |
-
eval_interval = params['eval_interval']
|
| 44 |
-
eval_iters = params['eval_iters']
|
| 45 |
-
learning_rate = params['learning_rate']
|
| 46 |
-
|
| 47 |
-
torch.manual_seed(1400)
|
| 48 |
-
# data loading
|
| 49 |
-
def get_batch(split):
|
| 50 |
-
# generate a small batch of data of inputs x and targets y
|
| 51 |
-
data = train_data if split == 'train' else val_data
|
| 52 |
-
ix = torch.randint(len(data) - block_size, (batch_size,))
|
| 53 |
-
x = torch.stack([data[i:i+block_size] for i in ix])
|
| 54 |
-
y = torch.stack([data[i+1:i+block_size+1] for i in ix])
|
| 55 |
-
x, y = x.to(device), y.to(device)
|
| 56 |
-
return x, y
|
| 57 |
-
|
| 58 |
-
@torch.no_grad()
|
| 59 |
-
def estimate_loss():
|
| 60 |
-
out = {}
|
| 61 |
-
model.eval()
|
| 62 |
-
for split in ['train', 'val']:
|
| 63 |
-
losses = torch.zeros(eval_iters)
|
| 64 |
-
for k in range(eval_iters):
|
| 65 |
-
X, Y = get_batch(split)
|
| 66 |
-
logits, loss = model(X, Y)
|
| 67 |
-
losses[k] = loss.item()
|
| 68 |
-
out[split] = losses.mean()
|
| 69 |
-
model.train()
|
| 70 |
-
return out
|
| 71 |
-
|
| 72 |
-
from model import Transformer
|
| 73 |
-
model = Transformer(vocab_size=vocab_size)
|
| 74 |
-
m = model.to(device)
|
| 75 |
-
|
| 76 |
-
# no of parameters
|
| 77 |
-
n_param = sum(p.numel() for p in m.parameters())/1e6
|
| 78 |
-
print(f"{n_param:.2f} million")
|
| 79 |
-
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
|
| 80 |
-
steps = []
|
| 81 |
-
train_losses = []
|
| 82 |
-
val_losses = []
|
| 83 |
-
|
| 84 |
-
for iter in range(max_iters):
|
| 85 |
-
|
| 86 |
-
if iter % eval_interval == 0 or iter == max_iters - 1:
|
| 87 |
-
losses = estimate_loss()
|
| 88 |
-
print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
|
| 89 |
-
|
| 90 |
-
steps.append(iter)
|
| 91 |
-
train_losses.append(losses['train'])
|
| 92 |
-
val_losses.append(losses['val'])
|
| 93 |
-
|
| 94 |
-
xb, yb = get_batch('train')
|
| 95 |
-
logits, loss = model(xb, yb)
|
| 96 |
-
optimizer.zero_grad(set_to_none=True)
|
| 97 |
-
loss.backward()
|
| 98 |
-
optimizer.step()
|
| 99 |
-
|
| 100 |
-
torch.save(model.state_dict(), f'enigma_{n_param:.0f}m.pth')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|