shivendrra commited on
Commit
82ea75a
·
verified ·
1 Parent(s): 8bc677c

Delete enigma

Browse files
enigma/EnBERT.py DELETED
@@ -1,206 +0,0 @@
1
- """
2
- simple BERT architecture model, paired with one more layer of
3
- masked self-attention, to predict next token
4
- """
5
-
6
- import torch
7
- import os
8
# Resolve the directory that contains this module and make it the process
# working directory. This runs at import time, so importing the module changes
# the working directory for the whole process.
_module_path = os.path.abspath(__file__)
current_directory = os.path.dirname(_module_path)
os.chdir(current_directory)
10
-
11
- import torch.nn as nn
12
- from torch.nn import functional as F
13
-
14
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# hyperparams
# The first group is not referenced by the model classes visible in this file;
# presumably it is consumed by the training script -- confirm against caller.
batch_size = 8
max_iters = 10
eval_interval = 10
eval_iters = 5
learning_rate = 3e-4

# Model shape.
block_size = 32   # size of the causal mask and of the position-embedding table
d_model = 256     # embedding / hidden width
n_layer = 16      # number of stacked Block modules
# NOTE(review): d_model // n_head == 21, so the concatenated heads have
# 252 features; the per-module output projection maps them back to d_model.
n_head = 12
dropout = 0.2     # dropout probability passed to every sub-module
norm_eps = 1e-5   # eps passed to the LayerNorm layers
28
-
29
class SWiGLU(nn.Module):
    """Element-wise sigmoid-gated blend of ``ReLU(x)`` and ``x``.

    Computes ``σ(x) ⊙ ReLU(x) + (1 − σ(x)) ⊙ x``. The module holds no
    parameters and the output has the same shape as the input.
    """

    def forward(self, x):
        """Return the gated activation of ``x`` (same shape as ``x``)."""
        gate = torch.sigmoid(x)
        rectified_part = gate * F.relu(x)
        passthrough_part = (1 - gate) * x
        return rectified_part + passthrough_part
38
-
39
class UnMaskedHead(nn.Module):
    """Single self-attention head without a causal mask.

    Every position attends to every other position of the sequence.

    Args:
        d_model: size of the last dimension of the input.
        head_size: size of the last dimension of the output.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, head_size, dropout):
        super().__init__()
        # The three projections are created in key, query, value order.
        self.key = nn.Linear(d_model, head_size, bias=True)
        self.query = nn.Linear(d_model, head_size, bias=True)
        self.value = nn.Linear(d_model, head_size, bias=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over ``x`` of shape (B, T, d_model); returns (B, T, head_size)."""
        # The unpacking also rejects inputs that are not 3-dimensional.
        batch, seq_len, channels = x.shape
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        # Scaled dot-product scores, (B, T, T), normalised over the key axis.
        scale = k.shape[-1] ** -0.5
        attn = F.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)
        attn = self.dropout(attn)
        return attn @ v
60
-
61
class MaskedHead(nn.Module):
    """Single causal self-attention head.

    Position ``t`` may only attend to positions ``<= t``; everything above the
    diagonal of the score matrix is set to ``-inf`` before the softmax.

    Args:
        head_size: size of the last dimension of the output.
        dropout: dropout probability applied to the attention weights.
        d_model: size of the last dimension of the input.
    """

    def __init__(self, head_size, dropout, d_model):
        super().__init__()
        self.key = nn.Linear(d_model, head_size, bias=True)
        self.query = nn.Linear(d_model, head_size, bias=True)
        self.value = nn.Linear(d_model, head_size, bias=True)
        # Lower-triangular matrix of ones sized by the module-level block_size,
        # registered as a buffer (not a parameter).
        causal = torch.ones(block_size, block_size).tril()
        self.register_buffer('tril', causal)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over ``x`` of shape (B, T, d_model); returns (B, T, head_size).

        The mask buffer is sliced to (T, T), so T is expected to be at most
        ``block_size``.
        """
        _, T, _ = x.shape
        queries = self.query(x)
        keys = self.key(x)

        # (B, T, hs) @ (B, hs, T) -> (B, T, T), scaled by 1/sqrt(head_size)
        scores = queries @ keys.transpose(-2, -1) * keys.shape[-1] ** -0.5
        blocked = self.tril[:T, :T] == 0
        scores = scores.masked_fill(blocked, float('-inf'))
        probs = self.dropout(F.softmax(scores, dim=-1))

        return probs @ self.value(x)
85
-
86
class MultiUnMasked(nn.Module):
    """Multi-head self-attention built from ``UnMaskedHead`` modules.

    Args:
        d_model: input and output feature size.
        n_head: number of heads; each head outputs ``d_model // n_head`` features.
        dropout: dropout probability, used inside the heads and on the output.
    """

    def __init__(self, d_model, n_head, dropout):
        super().__init__()
        head_size = d_model // n_head
        heads = [
            UnMaskedHead(d_model=d_model, head_size=head_size, dropout=dropout)
            for _ in range(n_head)
        ]
        self.heads = nn.ModuleList(heads)
        # The concatenated heads have n_head * head_size features, which can be
        # smaller than d_model when d_model is not a multiple of n_head.
        self.proj = nn.Linear(n_head * head_size, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run every head on ``x``, concatenate on the last axis and project."""
        per_head = [head(x) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        return self.dropout(self.proj(merged))
98
-
99
class MultiMasked(nn.Module):
    """Multi-head self-attention built from causal ``MaskedHead`` modules.

    Args:
        d_model: input and output feature size.
        n_head: number of heads; each head outputs ``d_model // n_head`` features.
        dropout: dropout probability, used inside the heads and on the output.
    """

    def __init__(self, d_model, n_head, dropout):
        super().__init__()
        head_size = d_model // n_head
        heads = [
            MaskedHead(d_model=d_model, head_size=head_size, dropout=dropout)
            for _ in range(n_head)
        ]
        self.heads = nn.ModuleList(heads)
        # The concatenated heads have n_head * head_size features, which can be
        # smaller than d_model when d_model is not a multiple of n_head.
        self.proj = nn.Linear(n_head * head_size, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run every head on ``x``, concatenate on the last axis and project."""
        per_head = [head(x) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        return self.dropout(self.proj(merged))
111
-
112
class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear -> Dropout.

    Args:
        d_model: input and output feature size; the hidden layer is 4x wider.
        dropout: dropout probability applied after the second linear layer.
    """

    def __init__(self, d_model, dropout):
        super().__init__()
        hidden = 4 * d_model
        layers = [
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Linear(hidden, d_model),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network to the last dimension of ``x``."""
        return self.net(x)
124
-
125
class Block(nn.Module):
    """Transformer block that runs unmasked attention, then masked attention.

    Each attention pass is followed by the same feed-forward network, and both
    passes share the same two LayerNorm modules.

    Args:
        d_model: feature size of the residual stream.
        n_head: number of heads in each attention module.
        norm_eps: eps for the LayerNorm layers.
        dropout: dropout probability for attention and feed-forward modules.
    """

    def __init__(self, d_model, n_head, norm_eps, dropout):
        super().__init__()
        self.sa_masked = MultiMasked(n_head=n_head, d_model=d_model, dropout=dropout)
        self.sa_unmasked = MultiUnMasked(n_head=n_head, d_model=d_model, dropout=dropout)
        self.ffwd = FeedForward(d_model, dropout=dropout)
        self.norm1 = nn.LayerNorm(d_model, eps=norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=norm_eps)

    def forward(self, x):
        """Apply both attention + feed-forward passes; shape is preserved."""
        # NOTE(review): norm1 is applied to the attention *input*, whereas
        # norm2 is applied to the feed-forward *output*; norm1, norm2 and ffwd
        # are reused by both passes. Confirm this is intended.
        for attention in (self.sa_unmasked, self.sa_masked):
            attended = x + attention(self.norm1(x))
            x = attended + self.norm2(self.ffwd(attended))
        return x
141
-
142
class EnigmaBERT(nn.Module):
    """Token-level language model: embeddings -> stacked ``Block`` -> LM head.

    The model size comes from the module-level hyperparameters (``d_model``,
    ``block_size``, ``n_layer``, ``n_head``, ``dropout``, ``norm_eps``).

    Args:
        vocab_size: number of distinct token ids; also the size of the logits.
    """

    def __init__(self, vocab_size):
        super().__init__()
        self.toked_model = nn.Embedding(vocab_size, d_model)
        self.pos_encod = nn.Embedding(block_size, d_model)
        self.block = nn.Sequential(*[
            Block(d_model=d_model, dropout=dropout, norm_eps=norm_eps, n_head=n_head)
            for _ in range(n_layer)
        ])
        self.norm_final = nn.LayerNorm(d_model, eps=norm_eps)
        self.linear_final = nn.Linear(d_model, vocab_size)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02) and zero Linear biases.

        Other module types (e.g. LayerNorm) keep their default initialisation.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            nn.init.zeros_(module.bias)

    def forward(self, idx, targets=None):
        """Compute logits and, when ``targets`` is given, the cross-entropy loss.

        Args:
            idx: integer tensor of token ids with shape (B, T). T must not
                exceed the size of the position-embedding table.
            targets: optional integer tensor with B*T elements.

        Returns:
            ``(logits, loss)``. Without targets, ``logits`` has shape
            (B, T, vocab_size) and ``loss`` is ``None``. With targets,
            ``logits`` is flattened to (B*T, vocab_size) and ``loss`` is the
            scalar cross-entropy.
        """
        B, T = idx.shape

        tok_emb = self.toked_model(idx)
        # Build the position indices on the input's own device rather than on
        # the module-level `device`, so the model also works after it has been
        # moved to a different device than that global names.
        positions = torch.arange(T, device=idx.device)
        pos_emb = self.pos_encod(positions)

        x = self.block(tok_emb + pos_emb)
        logits = self.linear_final(self.norm_final(x))

        if targets is None:
            return logits, None

        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        loss = F.cross_entropy(logits, targets.view(B * T))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=0):
        """Sample ``max_new_tokens`` tokens autoregressively.

        Sampling runs without gradient tracking. The module's train/eval mode
        is not changed here, so dropout stays active unless the caller has
        switched the model to ``eval()``.

        Args:
            idx: integer tensor of shape (1, T) with the prompt token ids.
                Each sampled token is read with ``.item()``, which requires a
                batch size of 1.
            max_new_tokens: number of tokens to sample.
            temperature: divisor applied to the logits before the softmax.
            top_k: when > 0, sample only among the ``top_k`` largest logits.

        Returns:
            List of the sampled token ids (Python ints), prompt excluded.
        """
        # Crop to the size the position table was actually built with, instead
        # of re-reading the mutable module-level `block_size`.
        context_len = self.pos_encod.num_embeddings
        generated_tokens = []

        for _ in range(max_new_tokens):
            idx_cond = idx[:, -context_len:]
            logits, _ = self(idx_cond)
            scaled_logits = logits[:, -1, :] / temperature  # last position only
            if top_k > 0:
                scaled_logits = self._top_k_filtering(scaled_logits, top_k)

            probs = F.softmax(scaled_logits, dim=-1)
            sampled_idx = torch.multinomial(probs, num_samples=1)
            generated_tokens.append(sampled_idx.item())
            idx = torch.cat((idx, sampled_idx), dim=1)

        return generated_tokens

    def _top_k_filtering(self, logits, top_k):
        """Set every logit below the ``top_k``-th largest of its row to ``-inf``.

        Args:
            logits: 2-D tensor (rows, vocab).
            top_k: number of entries to keep per row; values larger than the
                vocabulary size are clamped so that nothing is filtered.

        Returns:
            Tensor of the same shape; entries tied with the threshold are kept.
        """
        # torch.topk raises when k exceeds the dimension size, so clamp first.
        top_k = min(top_k, logits.size(-1))
        values, _ = torch.topk(logits, top_k, dim=-1)
        min_value = values[:, -1].unsqueeze(-1)
        return logits.masked_fill(logits < min_value, float('-inf'))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
enigma/TrainEnigma.ipynb DELETED
@@ -1,919 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 1,
6
- "metadata": {
7
- "colab": {
8
- "base_uri": "https://localhost:8080/"
9
- },
10
- "id": "WXpJBLyr30Rx",
11
- "outputId": "2806070a-648f-42ca-fa8a-9aeb8f99ceb7"
12
- },
13
- "outputs": [
14
- {
15
- "output_type": "stream",
16
- "name": "stdout",
17
- "text": [
18
- "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n"
19
- ]
20
- }
21
- ],
22
- "source": [
23
- "from google.colab import drive\n",
24
- "drive.mount('/content/drive')"
25
- ]
26
- },
27
- {
28
- "cell_type": "code",
29
- "execution_count": 2,
30
- "metadata": {
31
- "colab": {
32
- "base_uri": "https://localhost:8080/"
33
- },
34
- "id": "r7WUm0VL4bN4",
35
- "outputId": "bfdefb82-479e-4f91-9a01-299ff76756e9"
36
- },
37
- "outputs": [
38
- {
39
- "output_type": "stream",
40
- "name": "stdout",
41
- "text": [
42
- "485.52 million letters\n"
43
- ]
44
- }
45
- ],
46
- "source": [
47
- "import torch\n",
48
- "\n",
49
- "# importing the data\n",
50
- "file_path = '/content/drive/MyDrive/train2.txt'\n",
51
- "with open(file_path, 'r', encoding='utf-8') as file:\n",
52
- " dna_seq = file.read()\n",
53
- "file.close()\n",
54
- "\n",
55
- "print(f\"{(len(dna_seq)/1e6):.2f} million letters\")"
56
- ]
57
- },
58
- {
59
- "cell_type": "code",
60
- "execution_count": 3,
61
- "metadata": {
62
- "id": "Cdhybhz9owTK"
63
- },
64
- "outputs": [],
65
- "source": [
66
- "class PerCharTokenizer:\n",
67
- " \"\"\"\n",
68
- " Args:\n",
69
- " - chars (list): all bases along with special tokens represented as characters\n",
70
- " - vocab_size (int): size of vocabulary\n",
71
- "\n",
72
- " Working:\n",
73
- " - vocab contains all the bases and ['P', 'M', 'U'] as padding, mask and unknown token\n",
74
- " - encode(): iterates over each character a time and the looks up for the position in vocab\n",
75
- " and returns it's position as integer\n",
76
- " - decode(): takes input of a list of integers and returns the specific item from vocab\n",
77
- " \"\"\"\n",
78
- " def __init__(self):\n",
79
- " super().__init__()\n",
80
- " self.chars = ['\\n', 'A', 'T', 'G', 'C', 'P', 'M', 'U', ' ']\n",
81
- " self.vocab_size = len(self.chars)\n",
82
- " self.string_to_index = {ch: i for i, ch in enumerate(self.chars)}\n",
83
- " self.index_to_string = {i: ch for i, ch in enumerate(self.chars)}\n",
84
- "\n",
85
- " def encode(self, string):\n",
86
- " encoded = []\n",
87
- " for char in string:\n",
88
- " if char in self.string_to_index:\n",
89
- " encoded.append(self.string_to_index[char])\n",
90
- " else:\n",
91
- " special_index = len(self.string_to_index)\n",
92
- " self.string_to_index[char] = special_index\n",
93
- " self.index_to_string[special_index] = char\n",
94
- " encoded.append(special_index)\n",
95
- " return encoded\n",
96
- "\n",
97
- " def decode(self, integer):\n",
98
- " decoded = []\n",
99
- " for i in integer:\n",
100
- " if i in self.index_to_string:\n",
101
- " decoded.append(self.index_to_string[i])\n",
102
- " else:\n",
103
- " continue\n",
104
- " return ''.join(decoded)"
105
- ]
106
- },
107
- {
108
- "cell_type": "code",
109
- "execution_count": 4,
110
- "metadata": {
111
- "colab": {
112
- "base_uri": "https://localhost:8080/"
113
- },
114
- "id": "6Ou9txgmAdIB",
115
- "outputId": "cb5dd462-8b2a-445a-9524-1b484f288c64"
116
- },
117
- "outputs": [
118
- {
119
- "output_type": "stream",
120
- "name": "stdout",
121
- "text": [
122
- "train data 436.97million, val data 48.55million\n"
123
- ]
124
- }
125
- ],
126
- "source": [
127
- "token = PerCharTokenizer()\n",
128
- "data = torch.tensor(token.encode(dna_seq), dtype=torch.long)\n",
129
- "\n",
130
- "# Train and test splits\n",
131
- "n = int(0.9*len(data)) # first 90% will be train, rest val\n",
132
- "train_data = data[:n]\n",
133
- "val_data = data[n:]\n",
134
- "print(f\"train data {(len(train_data)/1e6):.2f}million, val data {(len(val_data)/1e6):.2f}million\")"
135
- ]
136
- },
137
- {
138
- "cell_type": "code",
139
- "execution_count": 5,
140
- "metadata": {
141
- "id": "ebFKQQ9NAq4e"
142
- },
143
- "outputs": [],
144
- "source": [
145
- "# hyperparams\n",
146
- "batch_size = 10\n",
147
- "block_size = 512\n",
148
- "max_iters = 5000\n",
149
- "eval_interval = 100\n",
150
- "learning_rate = 3e-4\n",
151
- "eval_iters = 100\n",
152
- "d_model = 384\n",
153
- "n_layers = 12\n",
154
- "n_head = 12\n",
155
- "dropout = 0.25\n",
156
- "norm_eps = 1e-4"
157
- ]
158
- },
159
- {
160
- "cell_type": "code",
161
- "execution_count": 6,
162
- "metadata": {
163
- "id": "dZMiYkr37cmU"
164
- },
165
- "outputs": [],
166
- "source": [
167
- "import math\n",
168
- "import torch.nn as nn\n",
169
- "from torch.nn import functional as F\n",
170
- "\n",
171
- "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
172
- "\n",
173
- "class AttentionHead(nn.Module):\n",
174
- " \"\"\"\n",
175
- " initialize a single head of self attention.\n",
176
- "\n",
177
- " Args:\n",
178
- " - d_model (int): dimensionality of the model's hidden layers\n",
179
- " - head_size (int): dimensionality of each attention head\n",
180
- " - dropout (float): dropout probability\n",
181
- " - block_size (int): the maximum sequence length for positional encoding\n",
182
- " \"\"\"\n",
183
- " def __init__(self, d_model, head_size, dropout, block_size):\n",
184
- " super().__init__()\n",
185
- " self.key = nn.Linear(d_model, head_size, bias=True)\n",
186
- " self.query = nn.Linear(d_model, head_size, bias=True)\n",
187
- " self.value = nn.Linear(d_model, head_size, bias=False)\n",
188
- " self.dropout = nn.Dropout(dropout)\n",
189
- " self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))\n",
190
- "\n",
191
- " self.rel_pos_emb = nn.Parameter(torch.randn(block_size, block_size, head_size))\n",
192
- "\n",
193
- " def forward(self, x, mask=False):\n",
194
- " \"\"\"\n",
195
- " forward pass of a single attention head.\n",
196
- "\n",
197
- " Args:\n",
198
- " - x (Tensor): input tensor.\n",
199
- " - mask (bool): flag indicating whether to apply masking\n",
200
- " Returns:\n",
201
- " - out (Tensor): output tensor after self attention\n",
202
- " \"\"\"\n",
203
- " B, T, C = x.shape\n",
204
- " key = self.key(x)\n",
205
- " query = self.query(x)\n",
206
- " scores = torch.matmul(query, key.transpose(-2, -1)) / (key.shape[-1] ** -0.5)\n",
207
- "\n",
208
- " rel_pos_scores = torch.einsum('btc,tvc->btv', query, self.rel_pos_emb[:T, :T])\n",
209
- " scores += rel_pos_scores\n",
210
- "\n",
211
- " if mask:\n",
212
- " scores = scores.masked_fill(self.tril[:T, :T] == 0, float('-inf'))\n",
213
- " weights = F.softmax(scores, dim=-1)\n",
214
- " weights = self.dropout(weights)\n",
215
- "\n",
216
- " value = self.value(x)\n",
217
- " out = torch.matmul(weights, value)\n",
218
- " return out\n",
219
- "\n",
220
- "class MultiHeadAttention(nn.Module):\n",
221
- " \"\"\"\n",
222
- " initialize a multi-head attention module.\n",
223
- "\n",
224
- " Args:\n",
225
- " - d_model (int): dimensionality of the model's hidden layers\n",
226
- " - n_head (int): no of attention heads\n",
227
- " - dropout (float): dropout probability\n",
228
- " - block_size (int): context length\n",
229
- " \"\"\"\n",
230
- " def __init__(self, d_model, n_head, dropout, block_size):\n",
231
- " head_size = d_model // n_head\n",
232
- " super().__init__()\n",
233
- " self.heads = nn.ModuleList([AttentionHead(d_model=d_model, dropout=dropout, head_size=head_size, block_size=block_size) for _ in range(n_head)])\n",
234
- " self.proj = nn.Linear(n_head * head_size, d_model)\n",
235
- " self.dropout = nn.Dropout(dropout)\n",
236
- "\n",
237
- " def forward(self, x, mask):\n",
238
- " \"\"\"\n",
239
- " forward pass of the multi-head attention module\n",
240
- "\n",
241
- " Args:\n",
242
- " - x (Tensor): input tensor\n",
243
- " - mask (bool): flag indicating whether to apply masking\n",
244
- "\n",
245
- " Returns:\n",
246
- " - out (Tensor): output tensor after multi-head attention\n",
247
- "\n",
248
- " \"\"\"\n",
249
- " out = torch.cat([h(x, mask=mask) for h in self.heads], dim=-1)\n",
250
- " out = self.dropout(self.proj(out))\n",
251
- " return out\n",
252
- "\n",
253
- "class FeedForward(nn.Module):\n",
254
- " \"\"\"\n",
255
- " initialize a feedforward network module\n",
256
- "\n",
257
- " Args:\n",
258
- " - d_model (int): the dimensionality of the model's hidden layers\n",
259
- " - dropout (float): dropout probability\n",
260
- "\n",
261
- " \"\"\"\n",
262
- " def __init__(self, d_model, dropout):\n",
263
- " super().__init__()\n",
264
- " self.net = nn.Sequential(\n",
265
- " nn.Linear(d_model, 5*d_model),\n",
266
- " nn.GELU(),\n",
267
- " nn.Linear(5*d_model, d_model),\n",
268
- " nn.Dropout(dropout)\n",
269
- " )\n",
270
- "\n",
271
- " def forward(self, x):\n",
272
- " \"\"\"\n",
273
- " forward pass of the feedforward network module\n",
274
- "\n",
275
- " Args:\n",
276
- " - x (Tensor): input tensor\n",
277
- "\n",
278
- " Returns:\n",
279
- " - out (Tensor): output tensor after passing through the feedforward network\n",
280
- " \"\"\"\n",
281
- " return self.net(x)\n",
282
- "\n",
283
- "class EncoderNetwork(nn.Module):\n",
284
- " \"\"\"\n",
285
- " initialize an encoder network module\n",
286
- "\n",
287
- " Args:\n",
288
- " - d_model (int): dimensionality of the model's hidden layers\n",
289
- " - n_head (int): no of attention heads in multi-head attention layers\n",
290
- " - norm_eps (float): epsilon value for layer normalization\n",
291
- " - dropout (float): dropout probability\n",
292
- " - block_size (int): the maximum sequence length for positional encoding\n",
293
- " \"\"\"\n",
294
- " def __init__(self, d_model, n_head, norm_eps, dropout, block_size):\n",
295
- " super().__init__()\n",
296
- " self.s_att = MultiHeadAttention(n_head=n_head, d_model=d_model, dropout=dropout, block_size=block_size)\n",
297
- " self.ffwd = FeedForward(d_model, dropout)\n",
298
- " self.dropout = nn.Dropout(dropout)\n",
299
- " self.norm1 = nn.LayerNorm(d_model, eps=norm_eps)\n",
300
- " self.norm2 = nn.LayerNorm(d_model, eps=norm_eps)\n",
301
- "\n",
302
- " def forward(self, src):\n",
303
- " \"\"\"\n",
304
- " forward pass of the encoder network module.\n",
305
- "\n",
306
- " Args:\n",
307
- " - src (Tensor): input tensor representing source data\n",
308
- "\n",
309
- " Returns:\n",
310
- " - src (Tensor): output tensor after passing through the encoder network\n",
311
- " \"\"\"\n",
312
- " src2 = self.s_att(src, mask=False)\n",
313
- " src = src + self.dropout(src2)\n",
314
- " src = self.norm1(src)\n",
315
- "\n",
316
- " src2 = self.ffwd(src)\n",
317
- " src = src + self.dropout(src2)\n",
318
- " src = self.norm2(src)\n",
319
- "\n",
320
- " return src\n",
321
- "\n",
322
- "class DecoderNetwork(nn.Module):\n",
323
- " \"\"\"\n",
324
- " initialize a decoder network module\n",
325
- "\n",
326
- " Args:\n",
327
- " - d_model (int): dimensionality of the model's hidden layers\n",
328
- " - n_head (int): no of attention heads in multi-head attention layers\n",
329
- " - norm_eps (float): epsilon value for layer normalization\n",
330
- " - dropout (float): dropout probability\n",
331
- " - block_size (int): the maximum sequence length for positional encoding\n",
332
- " \"\"\"\n",
333
- " def __init__(self, d_model, n_head, norm_eps, dropout, block_size):\n",
334
- " super().__init__()\n",
335
- " self.s_att = MultiHeadAttention(n_head=n_head, d_model=d_model, dropout=dropout, block_size=block_size)\n",
336
- " self.ffwd = FeedForward(d_model, dropout)\n",
337
- " self.dropout = nn.Dropout(dropout)\n",
338
- " self.norm1 = nn.LayerNorm(d_model, eps=norm_eps)\n",
339
- " self.norm2 = nn.LayerNorm(d_model, eps=norm_eps)\n",
340
- "\n",
341
- " def forward(self, src, att):\n",
342
- " \"\"\"\n",
343
- " forward pass of the decoder network module.\n",
344
- "\n",
345
- " Args:\n",
346
- " - src (Tensor): input tensor, same as the encoder's inputs\n",
347
- " - att (Tensor): encoder's attention matrix\n",
348
- "\n",
349
- " Returns:\n",
350
- " - src_f (Tensor): final output tensor\n",
351
- " \"\"\"\n",
352
- " src2 = self.s_att(src, mask=True)\n",
353
- " src = src + self.dropout(src2)\n",
354
- " src = src + self.norm1(src)\n",
355
- "\n",
356
- " att = src + att\n",
357
- " att2 = self.s_att(att, mask=False)\n",
358
- " att2 = att + self.dropout(att2)\n",
359
- " trg = att2 + self.norm1(att2)\n",
360
- "\n",
361
- " src_f2 = self.ffwd(self.norm2(trg))\n",
362
- " src_f = src_f2 + self.dropout(src_f2)\n",
363
- " src_f = self.norm2(src_f)\n",
364
- "\n",
365
- " return src_f\n",
366
- "\n",
367
- "class Transformer(nn.Module):\n",
368
- " \"\"\"\n",
369
- " initialize a Transformer model\n",
370
- "\n",
371
- " Args:\n",
372
- " - vocab_size (int): size of the vocabulary\n",
373
- " - d_model (int): dimensionality of the model's hidden layers\n",
374
- " - block_size (int): maximum sequence length for positional encoding/context length\n",
375
- " - n_layers (int): number of encoder and decoder layers in the Transformer\n",
376
- " - n_head (int): number of attention heads in multi-head attention layers\n",
377
- " - norm_eps (float): epsilon value for layer normalization\n",
378
- " - dropout (float): dropout probability\n",
379
- " \"\"\"\n",
380
- " def __init__(self, vocab_size):\n",
381
- " super().__init__()\n",
382
- " self.block_size = block_size\n",
383
- " self.toked_model = nn.Embedding(vocab_size, d_model)\n",
384
- " self.pos_encod = nn.Embedding(block_size, d_model)\n",
385
- " self.enc_layer = nn.ModuleList([EncoderNetwork(n_head=n_head, norm_eps=norm_eps, block_size=block_size, dropout=dropout, d_model=d_model) for _ in range(n_layers)])\n",
386
- " self.dec_layer = nn.ModuleList([DecoderNetwork(n_head=n_head, norm_eps=norm_eps, block_size=block_size, dropout=dropout, d_model=d_model) for _ in range(n_layers)])\n",
387
- "\n",
388
- " self.norm_final = nn.LayerNorm(d_model)\n",
389
- " self.linear_final = nn.Linear(d_model, vocab_size)\n",
390
- " self.dropout = nn.Dropout(dropout)\n",
391
- " self.apply(self._init_weights)\n",
392
- "\n",
393
- " def _init_weights(self, module):\n",
394
- " \"\"\"\n",
395
- " initialize weights of linear and embedding layers\n",
396
- "\n",
397
- " Args:\n",
398
- " - module (nn.Module): the module to initialize weights for\n",
399
- " \"\"\"\n",
400
- " if isinstance(module, nn.Linear):\n",
401
- " torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n",
402
- " if module.bias is not None:\n",
403
- " torch.nn.init.zeros_(module.bias.data)\n",
404
- " elif isinstance(module, nn.Embedding):\n",
405
- " torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n",
406
- "\n",
407
- " def forward(self, idx, targets=None):\n",
408
- " \"\"\"\n",
409
- " forward pass of the transformer model\n",
410
- "\n",
411
- " Args:\n",
412
- " - idx (Tensor): input tensor representing token indices\n",
413
- " - targets (Tensor): target tensor for computing loss during training\n",
414
- "\n",
415
- " Returns:\n",
416
- " - logits (Tensor): output logits from the final linear layer\n",
417
- " - loss (Tensor): optional. computed cross-entropy loss if targets are provided, else None\n",
418
- " \"\"\"\n",
419
- " B, T = idx.shape\n",
420
- "\n",
421
- " toked_model = self.toked_model(idx)\n",
422
- " pos_encod = self.pos_encod(torch.arange(T, device=device))\n",
423
- " x = toked_model + pos_encod\n",
424
- "\n",
425
- " for layer in self.enc_layer:\n",
426
- " x_out = layer(x)\n",
427
- "\n",
428
- " for layer in self.dec_layer:\n",
429
- " x_final = layer(x, x_out)\n",
430
- "\n",
431
- " x_final = self.norm_final(x_final)\n",
432
- " logits = self.linear_final(x_final)\n",
433
- "\n",
434
- " if targets is None:\n",
435
- " loss = None\n",
436
- "\n",
437
- " else:\n",
438
- " B, T, C = logits.shape\n",
439
- " logits = logits.view(B*T, C)\n",
440
- " targets = targets.view(B*T)\n",
441
- " loss = F.cross_entropy(logits, targets)\n",
442
- "\n",
443
- " return logits, loss\n",
444
- " def generate(self, idx, max_new_tokens, temperature=1.0, top_k=0):\n",
445
- " \"\"\"\n",
446
- " generate new tokens using the trained model\n",
447
- "\n",
448
- " Args:\n",
449
- " - idx (Tensor): input tensor representing initial token indices\n",
450
- " - max_new_tokens (int): max no of new tokens to generate\n",
451
- " - temperature (float): softmax temperature for sampling\n",
452
- " - top_k (int): no of top tokens to consider in sampling\n",
453
- "\n",
454
- " Returns:\n",
455
- " - generated_tokens (list): list of generated token indices\n",
456
- " \"\"\"\n",
457
- " generated_tokens = []\n",
458
- "\n",
459
- " for _ in range(max_new_tokens):\n",
460
- " idx_cond = idx[:, -self.block_size:]\n",
461
- " logits, _ = self(idx_cond)\n",
462
- " logits = logits[:, -1, :]\n",
463
- "\n",
464
- " scaled_logits = logits / temperature\n",
465
- " if top_k > 0:\n",
466
- " scaled_logits = self._top_k_filtering(scaled_logits, top_k)\n",
467
- "\n",
468
- " probs = F.softmax(scaled_logits, dim=-1)\n",
469
- " sampled_idx = torch.multinomial(probs, num_samples=1)\n",
470
- " generated_tokens.append(sampled_idx.item())\n",
471
- " idx = torch.cat((idx, sampled_idx), dim=1)\n",
472
- "\n",
473
- " return generated_tokens\n",
474
- "\n",
475
- " def generate_masked_tokens(self, idx, masked_indices, temperature=1.0, top_k=0):\n",
476
- " \"\"\"\n",
477
- " Generate predictions for masked tokens using the trained model.\n",
478
- "\n",
479
- " Args:\n",
480
- " - idx (Tensor): input tensor representing token indices\n",
481
- " - masked_indices (Tensor): tensor of indices indicating masked positions\n",
482
- " - temperature (float): softmax temperature for sampling\n",
483
- " - top_k (int): no of top tokens to consider in sampling\n",
484
- "\n",
485
- " Returns:\n",
486
- " - predicted_tokens (Tensor): tensor of predicted token indices\n",
487
- " \"\"\"\n",
488
- " B, T = idx.shape\n",
489
- "\n",
490
- " toked_model = self.toked_model(idx)\n",
491
- " pos_encod = self.pos_encod(torch.arange(T, device=device))\n",
492
- " x = toked_model + pos_encod\n",
493
- "\n",
494
- " for layer in self.enc_layer:\n",
495
- " x_out = layer(x)\n",
496
- "\n",
497
- " for layer in self.dec_layer:\n",
498
- " x_final = layer(x, x_out)\n",
499
- "\n",
500
- " x_masked = x_final.clone()\n",
501
- " x_masked[masked_indices] = self.toked_model(torch.tensor([6], device=device))\n",
502
- "\n",
503
- " x_masked = self.norm_final(x_masked)\n",
504
- " logits = self.linear_final(x_masked)\n",
505
- "\n",
506
- " masked_logits = logits[masked_indices].view(-1, logits.size(-1))\n",
507
- " scaled_logits = masked_logits / temperature\n",
508
- " if top_k > 0:\n",
509
- " scaled_logits = self._top_k_filtering(scaled_logits, top_k)\n",
510
- "\n",
511
- " probs = F.softmax(scaled_logits, dim=-1)\n",
512
- " predicted_indices = torch.argmax(probs, dim=-1)\n",
513
- "\n",
514
- " return predicted_indices\n",
515
- "\n",
516
- " def _top_k_filtering(self, logits, top_k):\n",
517
- " \"\"\"\n",
518
- " filter logits to keep only the top-k tokens\n",
519
- "\n",
520
- " Args:\n",
521
- " - logits (Tensor): input tensor representing unscaled logits\n",
522
- " - top_k (int): no of top tokens to keep\n",
523
- "\n",
524
- " Returns:\n",
525
- " - filtered_logits (Tensor): filtered logits with only top-k tokens remaining\n",
526
- " \"\"\"\n",
527
- " values, indices = torch.topk(logits, top_k, dim=-1)\n",
528
- " min_value = values[:, -1].unsqueeze(-1).expand_as(logits)\n",
529
- " filtered_logits = torch.where(logits < min_value, torch.ones_like(logits) * -float('inf'), logits)\n",
530
- "\n",
531
- " return filtered_logits"
532
- ]
533
- },
534
- {
535
- "cell_type": "code",
536
- "execution_count": 7,
537
- "metadata": {
538
- "colab": {
539
- "base_uri": "https://localhost:8080/",
540
- "height": 816
541
- },
542
- "id": "X9VOBZFr7g3W",
543
- "outputId": "aa376025-0a37-4b93-e90a-9d95c6ef2c11"
544
- },
545
- "outputs": [
546
- {
547
- "output_type": "stream",
548
- "name": "stdout",
549
- "text": [
550
- "2.5 billion parameters\n",
551
- "step 0: train loss 2.2869, val loss 2.2884\n",
552
- "step 100: train loss 1.3312, val loss 1.3281\n",
553
- "step 200: train loss 1.3233, val loss 1.3181\n",
554
- "step 300: train loss 1.3209, val loss 1.3196\n",
555
- "step 400: train loss 1.3215, val loss 1.3203\n",
556
- "step 500: train loss 1.1974, val loss 1.1994\n",
557
- "step 600: train loss 0.3350, val loss 0.3365\n",
558
- "step 700: train loss 0.0703, val loss 0.0702\n",
559
- "step 800: train loss 0.0143, val loss 0.0143\n",
560
- "step 900: train loss 0.0049, val loss 0.0047\n",
561
- "step 1000: train loss 0.0041, val loss 0.0037\n",
562
- "step 1100: train loss 0.0035, val loss 0.0036\n",
563
- "step 1200: train loss 0.0038, val loss 0.0035\n",
564
- "step 1300: train loss 0.0035, val loss 0.0033\n",
565
- "step 1400: train loss 0.0035, val loss 0.0033\n",
566
- "step 1500: train loss 0.0033, val loss 0.0033\n",
567
- "step 1600: train loss 0.0033, val loss 0.0034\n",
568
- "step 1700: train loss 0.0033, val loss 0.0033\n",
569
- "step 1800: train loss 0.0033, val loss 0.0031\n",
570
- "step 1900: train loss 0.0031, val loss 0.0031\n",
571
- "step 2000: train loss 0.0032, val loss 0.0032\n"
572
- ]
573
- },
574
- {
575
- "output_type": "error",
576
- "ename": "KeyboardInterrupt",
577
- "evalue": "",
578
- "traceback": [
579
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
580
- "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
581
- "\u001b[0;32m<ipython-input-7-44818790f2dc>\u001b[0m in \u001b[0;36m<cell line: 45>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[0mxb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0myb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'train'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 56\u001b[0;31m \u001b[0mlogits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mxb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0myb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 57\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mset_to_none\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
582
- "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1510\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1511\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1512\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1513\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
583
- "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1518\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1519\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1521\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
584
- "\u001b[0;32m<ipython-input-6-b2af72f89b89>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, idx, targets)\u001b[0m\n\u001b[1;32m 261\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 262\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mlayer\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdec_layer\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 263\u001b[0;31m \u001b[0mx_final\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlayer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx_out\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 264\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 265\u001b[0m \u001b[0mx_final\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm_final\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_final\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
585
- "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1510\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1511\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1512\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1513\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
586
- "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1518\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1519\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1521\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
587
- "\u001b[0;32m<ipython-input-6-b2af72f89b89>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, src, att)\u001b[0m\n\u001b[1;32m 189\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 190\u001b[0m \u001b[0matt\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msrc\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0matt\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 191\u001b[0;31m \u001b[0matt2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0ms_att\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0matt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 192\u001b[0m \u001b[0matt2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0matt\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdropout\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0matt2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 193\u001b[0m \u001b[0mtrg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0matt2\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0matt2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
588
- "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1510\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1511\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1512\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1513\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
589
- "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1518\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1519\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1521\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
590
- "\u001b[0;32m<ipython-input-6-b2af72f89b89>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, mask)\u001b[0m\n\u001b[1;32m 81\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 82\u001b[0m \"\"\"\n\u001b[0;32m---> 83\u001b[0;31m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mh\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmask\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mh\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mheads\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 84\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdropout\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mproj\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 85\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
591
- "\u001b[0;32m<ipython-input-6-b2af72f89b89>\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 81\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 82\u001b[0m \"\"\"\n\u001b[0;32m---> 83\u001b[0;31m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mh\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmask\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mh\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mheads\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 84\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdropout\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mproj\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 85\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
592
- "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1510\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1511\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1512\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1513\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
593
- "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1518\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1519\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1521\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
594
- "\u001b[0;32m<ipython-input-6-b2af72f89b89>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, mask)\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0mweights\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdropout\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mweights\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 50\u001b[0;31m \u001b[0mvalue\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 51\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmatmul\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mweights\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 52\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
595
- "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1510\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1511\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1512\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1513\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
596
- "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1518\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1519\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1521\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
597
- "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 114\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 116\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 117\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 118\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mextra_repr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
598
- "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
599
- ]
600
- }
601
- ],
602
- "source": [
603
- "import timeit\n",
604
- "\n",
605
- "start_time = timeit.default_timer()\n",
606
- "# data loading\n",
607
- "def get_batch(split):\n",
608
- "\n",
609
- " data = train_data if split == 'train' else val_data\n",
610
- " ix = torch.randint(len(data) - block_size, (batch_size,))\n",
611
- " x = torch.stack([data[i:i+block_size] for i in ix])\n",
612
- " y = torch.stack([data[i+1:i+block_size+1] for i in ix])\n",
613
- " x, y = x.to(device), y.to(device)\n",
614
- " return x, y\n",
615
- "\n",
616
- "@torch.no_grad()\n",
617
- "def estimate_loss():\n",
618
- " out = {}\n",
619
- " model.eval()\n",
620
- " for split in ['train', 'val']:\n",
621
- " losses = torch.zeros(eval_iters)\n",
622
- " for k in range(eval_iters):\n",
623
- " X, Y = get_batch(split)\n",
624
- " logits, loss = model(X, Y)\n",
625
- " losses[k] = loss.item()\n",
626
- " out[split] = losses.mean()\n",
627
- " model.train()\n",
628
- " return out\n",
629
- "\n",
630
- "vocab_size = token.vocab_size\n",
631
- "model = Transformer(vocab_size)\n",
632
- "# checkpoint_path = '/content/drive/MyDrive/enigma-2.5b.pth'\n",
633
- "# checkpoint = torch.load(checkpoint_path)\n",
634
- "# model.load_state_dict(checkpoint)\n",
635
- "m = model.to(device)\n",
636
- "\n",
637
- "# no of parameters\n",
638
- "n_param = sum(p.numel() for p in m.parameters())/1e9\n",
639
- "print(f\"{n_param:.1f} billion parameters\")\n",
640
- "\n",
641
- "# optimizer\n",
642
- "optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n",
643
- "steps = []\n",
644
- "train_losses = []\n",
645
- "val_losses = []\n",
646
- "\n",
647
- "for iter in range(max_iters):\n",
648
- "\n",
649
- " if iter % eval_interval == 0 or iter == max_iters - 1:\n",
650
- " losses = estimate_loss()\n",
651
- " print(f\"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}\")\n",
652
- "\n",
653
- " steps.append(iter)\n",
654
- " train_losses.append(losses['train'])\n",
655
- " val_losses.append(losses['val'])\n",
656
- "\n",
657
- " xb, yb = get_batch('train')\n",
658
- " logits, loss = model(xb, yb)\n",
659
- " optimizer.zero_grad(set_to_none=True)\n",
660
- " loss.backward()\n",
661
- " optimizer.step()"
662
- ]
663
- },
664
- {
665
- "cell_type": "code",
666
- "execution_count": 8,
667
- "metadata": {
668
- "id": "tzJMKoA35uIV",
669
- "colab": {
670
- "base_uri": "https://localhost:8080/"
671
- },
672
- "outputId": "ba527bf5-695c-4a8f-acc4-bd60d549eaad"
673
- },
674
- "outputs": [
675
- {
676
- "output_type": "stream",
677
- "name": "stdout",
678
- "text": [
679
- "total parameters: 2.5 billion\n",
680
- "trained in 1.82hrs\n"
681
- ]
682
- }
683
- ],
684
- "source": [
685
- "end_time = timeit.default_timer()\n",
686
- "print(f\"total parameters: {n_param:.1f} billion\")\n",
687
- "print(f\"trained in {((end_time - start_time)/3600):.2f}hrs\")"
688
- ]
689
- },
690
- {
691
- "cell_type": "code",
692
- "source": [
693
- "model_save_name = f'enigma-{n_param:.1f}b_v1.pth'\n",
694
- "path = f\"/content/drive/MyDrive/{model_save_name}\"\n",
695
- "torch.save(model.state_dict(), path)"
696
- ],
697
- "metadata": {
698
- "id": "eB47Yn9aNrrO"
699
- },
700
- "execution_count": 10,
701
- "outputs": []
702
- },
703
- {
704
- "cell_type": "code",
705
- "source": [
706
- "# 8-bit quantization\n",
707
- "\n",
708
- "import torch\n",
709
- "import torch.quantization\n",
710
- "\n",
711
- "checkpoint_path = '/content/drive/MyDrive/enigma-2.5b.pth'\n",
712
- "checkpoint = torch.load(checkpoint_path)\n",
713
- "model.load_state_dict(checkpoint)\n",
714
- "model = model.to(device)\n",
715
- "\n",
716
- "quantized_model = torch.quantization.quantize_dynamic(\n",
717
- " model,\n",
718
- " dtype=torch.qint8\n",
719
- ")\n",
720
- "quantized_model_file = f'/content/drive/MyDrive/enigma-2.5b-quant.pth'\n",
721
- "torch.save(quantized_model.state_dict(), quantized_model_file)\n",
722
- "\n",
723
- "print(\"Quantized model saved successfully.\")"
724
- ],
725
- "metadata": {
726
- "id": "7iGQdNHgms_U"
727
- },
728
- "execution_count": null,
729
- "outputs": []
730
- },
731
- {
732
- "cell_type": "code",
733
- "source": [
734
- "# pruning\n",
735
- "\n",
736
- "import torch\n",
737
- "from torch import nn\n",
738
- "from torch.utils.model_zoo import load_url\n",
739
- "import torch.nn.utils.prune as prune\n",
740
- "\n",
741
- "parameters_to_prune = [(model.encoder.self_attn, 'weight'), (model.encoder.linear1, 'weight')]\n",
742
- "prune.global_unstructured(\n",
743
- " parameters_to_prune,\n",
744
- " pruning_method=prune.L1Unstructured,\n",
745
- " amount=0.15,\n",
746
- ")\n",
747
- "\n",
748
- "torch.save(model.state_dict(), 'enigma-2.5b_pruned.pth')"
749
- ],
750
- "metadata": {
751
- "id": "YTJ19n4OFvZj"
752
- },
753
- "execution_count": null,
754
- "outputs": []
755
- },
756
- {
757
- "cell_type": "code",
758
- "execution_count": null,
759
- "metadata": {
760
- "id": "K2FDOp7Quibq"
761
- },
762
- "outputs": [],
763
- "source": [
764
- "class Generator(Transformer):\n",
765
- " def __init__(self, vocab_size, block_size):\n",
766
- " super().__init__(vocab_size)\n",
767
- " self.vocab_size = vocab_size\n",
768
- " self.block_size = block_size\n",
769
- "\n",
770
- " def generate(self, idx, max_new_tokens, temperature=1.0, top_k=0):\n",
771
- " \"\"\"\n",
772
- " generate new tokens using the trained model\n",
773
- "\n",
774
- " Args:\n",
775
- " - idx (Tensor): input tensor representing initial token indices\n",
776
- " - max_new_tokens (int): max no of new tokens to generate\n",
777
- " - temperature (float): softmax temperature for sampling\n",
778
- " - top_k (int): no of top tokens to consider in sampling\n",
779
- "\n",
780
- " Returns:\n",
781
- " - generated_tokens (list): list of generated token indices\n",
782
- " \"\"\"\n",
783
- " generated_tokens = []\n",
784
- "\n",
785
- " for _ in range(max_new_tokens):\n",
786
- " idx_cond = idx[:, -self.block_size:]\n",
787
- " logits, _ = self(idx_cond)\n",
788
- " logits = logits[:, -1, :]\n",
789
- "\n",
790
- " scaled_logits = logits / temperature\n",
791
- " if top_k > 0:\n",
792
- " scaled_logits = self._top_k_filtering(scaled_logits, top_k)\n",
793
- "\n",
794
- " probs = F.softmax(scaled_logits, dim=-1)\n",
795
- " sampled_idx = torch.multinomial(probs, num_samples=1)\n",
796
- " generated_tokens.append(sampled_idx.item())\n",
797
- " idx = torch.cat((idx, sampled_idx), dim=1)\n",
798
- "\n",
799
- " return generated_tokens\n",
800
- "\n",
801
- " def generate_masked_tokens(self, idx, masked_indices, temperature=1.0, top_k=0):\n",
802
- " \"\"\"\n",
803
- " Generate predictions for masked tokens using the trained model.\n",
804
- "\n",
805
- " Args:\n",
806
- " - idx (Tensor): input tensor representing token indices\n",
807
- " - masked_indices (Tensor): tensor of indices indicating masked positions\n",
808
- " - temperature (float): softmax temperature for sampling\n",
809
- " - top_k (int): no of top tokens to consider in sampling\n",
810
- "\n",
811
- " Returns:\n",
812
- " - predicted_tokens (Tensor): tensor of predicted token indices\n",
813
- " \"\"\"\n",
814
- " B, T = idx.shape\n",
815
- "\n",
816
- " toked_model = self.toked_model(idx)\n",
817
- " pos_encod = self.pos_encod(torch.arange(T, device=device))\n",
818
- " x = toked_model + pos_encod\n",
819
- "\n",
820
- " for layer in self.enc_layer:\n",
821
- " x_out = layer(x)\n",
822
- "\n",
823
- " for layer in self.dec_layer:\n",
824
- " x_final = layer(x, x_out)\n",
825
- "\n",
826
- " x_masked = x_final.clone()\n",
827
- " x_masked[masked_indices] = self.toked_model(torch.tensor([6], device=device))\n",
828
- "\n",
829
- " x_masked = self.norm_final(x_masked)\n",
830
- " logits = self.linear_final(x_masked)\n",
831
- "\n",
832
- " masked_logits = logits[masked_indices].view(-1, logits.size(-1))\n",
833
- " scaled_logits = masked_logits / temperature\n",
834
- " if top_k > 0:\n",
835
- " scaled_logits = self._top_k_filtering(scaled_logits, top_k)\n",
836
- "\n",
837
- " probs = F.softmax(scaled_logits, dim=-1)\n",
838
- " predicted_indices = torch.argmax(probs, dim=-1)\n",
839
- "\n",
840
- " return predicted_indices\n",
841
- "\n",
842
- " def _top_k_filtering(self, logits, top_k):\n",
843
- " \"\"\"\n",
844
- " filter logits to keep only the top-k tokens\n",
845
- "\n",
846
- " Args:\n",
847
- " - logits (Tensor): input tensor representing unscaled logits\n",
848
- " - top_k (int): no of top tokens to keep\n",
849
- "\n",
850
- " Returns:\n",
851
- " - filtered_logits (Tensor): filtered logits with only top-k tokens remaining\n",
852
- " \"\"\"\n",
853
- " values, indices = torch.topk(logits, top_k, dim=-1)\n",
854
- " min_value = values[:, -1].unsqueeze(-1).expand_as(logits)\n",
855
- " filtered_logits = torch.where(logits < min_value, torch.ones_like(logits) * -float('inf'), logits)\n",
856
- "\n",
857
- " return filtered_logits"
858
- ]
859
- },
860
- {
861
- "cell_type": "code",
862
- "execution_count": null,
863
- "metadata": {
864
- "colab": {
865
- "base_uri": "https://localhost:8080/",
866
- "height": 429
867
- },
868
- "id": "c5CknylV4S2m",
869
- "outputId": "12314d78-9147-4e60-f8b5-84207b97a1c7"
870
- },
871
- "outputs": [
872
- {
873
- "output_type": "error",
874
- "ename": "RuntimeError",
875
- "evalue": "Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)",
876
- "traceback": [
877
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
878
- "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
879
- "\u001b[0;32m<ipython-input-17-db17ec37b06c>\u001b[0m in \u001b[0;36m<cell line: 5>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mtarget_text\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\"AGTTCTGCGAT\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mcontext\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtoken\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mencode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtarget_text\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlong\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0mgenerated_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtoken\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdecode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgenerator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgenerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcontext\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_new_tokens\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtemperature\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.9\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtop_k\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"{target_text}{generated_output}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
880
- "\u001b[0;32m<ipython-input-16-39da0e3e4598>\u001b[0m in \u001b[0;36mgenerate\u001b[0;34m(self, idx, max_new_tokens, temperature, top_k)\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmax_new_tokens\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0midx_cond\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0midx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mblock_size\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 24\u001b[0;31m \u001b[0mlogits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0midx_cond\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 25\u001b[0m \u001b[0mlogits\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlogits\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
881
- "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1510\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1511\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1512\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1513\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
882
- "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1518\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1519\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1521\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
883
- "\u001b[0;32m<ipython-input-7-b2af72f89b89>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, idx, targets)\u001b[0m\n\u001b[1;32m 253\u001b[0m \u001b[0mB\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mT\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0midx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 254\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 255\u001b[0;31m \u001b[0mtoked_model\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtoked_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 256\u001b[0m \u001b[0mpos_encod\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpos_encod\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mT\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 257\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtoked_model\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mpos_encod\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
884
- "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1510\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1511\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1512\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1513\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
885
- "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1518\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1519\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1521\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
886
- "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/sparse.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 162\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 163\u001b[0;31m return F.embedding(\n\u001b[0m\u001b[1;32m 164\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpadding_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax_norm\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 165\u001b[0m self.norm_type, self.scale_grad_by_freq, self.sparse)\n",
887
- "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36membedding\u001b[0;34m(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)\u001b[0m\n\u001b[1;32m 2235\u001b[0m \u001b[0;31m# remove once script supports set_grad_enabled\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2236\u001b[0m \u001b[0m_no_grad_embedding_renorm_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_norm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnorm_type\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2237\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0membedding\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpadding_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mscale_grad_by_freq\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msparse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2238\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2239\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
888
- "\u001b[0;31mRuntimeError\u001b[0m: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)"
889
- ]
890
- }
891
- ],
892
- "source": [
893
- "generator = Generator(vocab_size, block_size)\n",
894
- "\n",
895
- "target_text = \"AGTTCTGCGAT\"\n",
896
- "context = torch.tensor([token.encode(target_text)], dtype=torch.long, device=device)\n",
897
- "generated_output = token.decode(generator.generate(context, max_new_tokens=100, temperature=0.9, top_k=5))\n",
898
- "print(f\"{target_text}{generated_output}\")"
899
- ]
900
- }
901
- ],
902
- "metadata": {
903
- "accelerator": "GPU",
904
- "colab": {
905
- "gpuType": "T4",
906
- "machine_shape": "hm",
907
- "provenance": []
908
- },
909
- "kernelspec": {
910
- "display_name": "Python 3",
911
- "name": "python3"
912
- },
913
- "language_info": {
914
- "name": "python"
915
- }
916
- },
917
- "nbformat": 4,
918
- "nbformat_minor": 0
919
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
enigma/config_enigma.json DELETED
@@ -1,13 +0,0 @@
1
- {
2
- "batch_size": 10,
3
- "block_size": 512,
4
- "max_iters": 5000,
5
- "eval_interval": 50,
6
- "learning_rate": 3e-5,
7
- "eval_iters": 100,
8
- "d_model": 384,
9
- "n_head": 12,
10
- "n_layer": 12,
11
- "dropout": 0.2,
12
- "norm_eps": 1e-5
13
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
enigma/enigma.cpp DELETED
@@ -1,364 +0,0 @@
1
- #include <torch/torch.h>
2
- #include <iostream>
3
- #include <vector>
4
-
5
- // Define device
6
- torch::Device device(torch::kCUDA);
7
-
8
- // Define constants
9
- const int batch_size = 8;
10
- const int block_size = 32;
11
- const int max_iters = 1000;
12
- const int eval_interval = 50;
13
- const int eval_iters = 5;
14
- const int d_model = 256;
15
- const int n_layer = 16;
16
- const int n_head = 12;
17
- const float dropout = 0.2;
18
- const float norm_eps = 1e-5;
19
- const int vocab_size = 5;
20
-
21
- // sample data
22
- torch::Tensor train_data = torch::rand({1000, block_size});
23
- torch::Tensor val_data = torch::rand({500, block_size});
24
-
25
- // Data loading function
26
- std::pair<torch::Tensor, torch::Tensor> get_batch(const std::string& split) {
27
- torch::Tensor data = (split == "train") ? train_data : val_data;
28
- torch::Tensor ix = torch::randint(data.size(0) - block_size, {batch_size});
29
- torch::Tensor x = torch::empty({batch_size, block_size});
30
- torch::Tensor y = torch::empty({batch_size, block_size});
31
- for (int i = 0; i < batch_size; ++i) {
32
- x[i] = data.index({ix[i], ix[i] + block_size});
33
- y[i] = data.index({ix[i] + 1, ix[i] + block_size + 1});
34
- }
35
- return std::make_pair(x.to(device), y.to(device));
36
- }
37
-
38
- // Custom classes and functions
39
- class SWiGLU : public torch::nn::Module {
40
- public:
41
- SWiGLU() {}
42
-
43
- torch::Tensor forward(torch::Tensor x) {
44
- torch::Tensor sigmoid_output = torch::sigmoid(x);
45
- torch::Tensor relu_output = torch::relu(x);
46
- torch::Tensor out = sigmoid_output * relu_output + (1 - sigmoid_output) * x;
47
- return out;
48
- }
49
- };
50
-
51
- class UnMaskedHeadImpl : public torch::nn::Module {
52
- public:
53
- UnMaskedHeadImpl(int d_model, int head_size, float dropout)
54
- : key(register_module("key", torch::nn::Linear(d_model, head_size))),
55
- query(register_module("query", torch::nn::Linear(d_model, head_size))),
56
- value(register_module("value", torch::nn::Linear(d_model, head_size))),
57
- dropout(torch::nn::Dropout(dropout)) {
58
- register_module("dropout", dropout);
59
- }
60
-
61
- torch::Tensor forward(torch::Tensor x) {
62
- torch::Tensor key_out = key->forward(x);
63
- torch::Tensor query_out = query->forward(x);
64
-
65
- torch::Tensor weights = query_out.matmul(key_out.transpose(-2, -1)) * std::sqrt(key_out.size(-1));
66
- weights = torch::softmax(weights, -1);
67
- weights = dropout(weights);
68
-
69
- torch::Tensor value_out = value->forward(x);
70
- torch::Tensor out = weights.matmul(value_out);
71
- return out;
72
- }
73
-
74
- private:
75
- torch::nn::Linear key, query, value;
76
- torch::nn::Dropout dropout;
77
- };
78
-
79
- TORCH_MODULE(UnMaskedHead);
80
-
81
- class MaskedHeadImpl : public torch::nn::Module {
82
- public:
83
- MaskedHeadImpl(int head_size, float dropout, int d_model)
84
- : key(register_module("key", torch::nn::Linear(d_model, head_size))),
85
- query(register_module("query", torch::nn::Linear(d_model, head_size))),
86
- value(register_module("value", torch::nn::Linear(d_model, head_size))),
87
- dropout(torch::nn::Dropout(dropout)) {
88
- register_buffer("tril", torch::tril(torch::ones(block_size, block_size)));
89
- }
90
-
91
- torch::Tensor forward(torch::Tensor x) {
92
- torch::Tensor key_out = key->forward(x);
93
- torch::Tensor query_out = query->forward(x);
94
-
95
- torch::Tensor weights = query_out.matmul(key_out.transpose(-2, -1)) * std::sqrt(key_out.size(-1));
96
- weights = weights.masked_fill(tril[:x.size(1), :x.size(1)] == 0, std::numeric_limits<float>::lowest());
97
- weights = torch::softmax(weights, -1);
98
- weights = dropout(weights);
99
-
100
- torch::Tensor value_out = value->forward(x);
101
- torch::Tensor out = weights.matmul(value_out);
102
- return out;
103
- }
104
-
105
- private:
106
- torch::nn::Linear key, query, value;
107
- torch::nn::Dropout dropout;
108
- torch::Tensor tril;
109
- };
110
-
111
- TORCH_MODULE(MaskedHead);
112
-
113
- class MultiUnMaskedImpl : public torch::nn::Module {
114
- public:
115
- MultiUnMaskedImpl(int d_model, int n_head, float dropout)
116
- : proj(register_module("proj", torch::nn::Linear(n_head * (d_model / n_head), d_model))),
117
- dropout(torch::nn::Dropout(dropout)) {
118
- for (int i = 0; i < n_head; ++i) {
119
- heads.push_back(register_module("head" + std::to_string(i), UnMaskedHead(d_model, d_model / n_head, dropout)));
120
- }
121
- }
122
-
123
- torch::Tensor forward(torch::Tensor x) {
124
- std::vector<torch::Tensor> head_outputs;
125
- for (auto& head : heads) {
126
- head_outputs.push_back(head->forward(x));
127
- }
128
- torch::Tensor out = torch::cat(head_outputs, -1);
129
- out = dropout(out);
130
- out = proj(out);
131
- return out;
132
- }
133
-
134
- private:
135
- torch::nn::Linear proj;
136
- torch::nn::Dropout dropout;
137
- std::vector<UnMaskedHead> heads;
138
- };
139
-
140
- TORCH_MODULE(MultiUnMasked);
141
-
142
- class MultiMaskedImpl : public torch::nn::Module {
143
- public:
144
- MultiMaskedImpl(int d_model, int n_head, float dropout)
145
- : proj(register_module("proj", torch::nn::Linear(n_head * (d_model / n_head), d_model))),
146
- dropout(torch::nn::Dropout(dropout)) {
147
- for (int i = 0; i < n_head; ++i) {
148
- heads.push_back(register_module("head" + std::to_string(i), MaskedHead(d_model, d_model / n_head, dropout)));
149
- }
150
- }
151
-
152
- torch::Tensor forward(torch::Tensor x) {
153
- std::vector<torch::Tensor> head_outputs;
154
- for (auto& head : heads) {
155
- head_outputs.push_back(head->forward(x));
156
- }
157
- torch::Tensor out = torch::cat(head_outputs, -1);
158
- out = dropout(out);
159
- out = proj(out);
160
- return out;
161
- }
162
-
163
- private:
164
- torch::nn::Linear proj;
165
- torch::nn::Dropout dropout;
166
- std::vector<MaskedHead> heads;
167
- };
168
-
169
- TORCH_MODULE(MultiMasked);
170
-
171
- class FeedForwardImpl : public torch::nn::Module {
172
- public:
173
- FeedForwardImpl(int d_model, float dropout)
174
- : net(register_module("net", torch::nn::Sequential(
175
- torch::nn::Linear(d_model, 4 * d_model),
176
- torch::nn::GELU(),
177
- torch::nn::Linear(4 * d_model, d_model),
178
- torch::nn::Dropout(dropout)
179
- ))) {}
180
-
181
- torch::Tensor forward(torch::Tensor x) {
182
- return net->forward(x);
183
- }
184
-
185
- private:
186
- torch::nn::Sequential net;
187
- };
188
-
189
- TORCH_MODULE(FeedForward);
190
-
191
- class BlockImpl : public torch::nn::Module {
192
- public:
193
- BlockImpl(int d_model, int n_head, float norm_eps, float dropout)
194
- : sa_masked(MultiMasked(d_model, n_head, dropout)),
195
- sa_unmasked(MultiUnMasked(d_model, n_head, dropout)),
196
- ffwd(FeedForward(d_model, dropout)),
197
- norm1(torch::nn::LayerNorm(torch::nn::LayerNormOptions({d_model}).eps(norm_eps))),
198
- norm2(torch::nn::LayerNorm(torch::nn::LayerNormOptions({d_model}).eps(norm_eps))) {}
199
-
200
- torch::Tensor forward(torch::Tensor x) {
201
- torch::Tensor x2 = x + sa_unmasked->forward(norm1->forward(x));
202
- x = x2 + ffwd->forward(norm2->forward(x2));
203
-
204
- x2 = x + sa_masked->forward(norm1->forward(x));
205
- x = x2 + ffwd->forward(norm2->forward(x2));
206
-
207
- return x;
208
- }
209
-
210
- private:
211
- MultiMasked sa_masked;
212
- MultiUnMasked sa_unmasked;
213
- FeedForward ffwd;
214
- torch::nn::LayerNorm norm1, norm2;
215
- };
216
-
217
- TORCH_MODULE(Block);
218
-
219
- class EnigmaImpl : public torch::nn::Module {
220
- public:
221
- EnigmaImpl(int vocab_size, int block_size, int d_model, int n_layer, int n_head, float dropout, float norm_eps)
222
- : toked_model(register_module("toked_model", torch::nn::Embedding(vocab_size, d_model))),
223
- pos_encod(register_module("pos_encod", torch::nn::Embedding(block_size, d_model))),
224
- norm_final(torch::nn::LayerNorm(torch::nn::LayerNormOptions({d_model}).eps(norm_eps))),
225
- linear_final(register_module("linear_final", torch::nn::Linear(d_model, vocab_size))) {
226
- for (int i = 0; i < n_layer; ++i) {
227
- block_layers.push_back(register_module("block" + std::to_string(i), Block(d_model, n_head, norm_eps, dropout)));
228
- }
229
- register_buffer("block_size", torch::tensor(block_size));
230
- _init_weights(this);
231
- }
232
-
233
- void _init_weights(torch::nn::Module* module) {
234
- auto parameters = module->named_parameters();
235
- for (auto& param : parameters) {
236
- if (param.key().find("weight") != std::string::npos) {
237
- torch::nn::init::normal_(param.value(), 0.0, 0.02);
238
- } else if (param.key().find("bias") != std::string::npos) {
239
- torch::nn::init::zeros_(param.value());
240
- }
241
- }
242
- }
243
-
244
- std::pair<torch::Tensor, torch::Tensor> forward(torch::Tensor idx, torch::Tensor targets=torch::Tensor()) {
245
- torch::Tensor toked_model_out = toked_model->forward(idx);
246
- torch::Tensor pos_encod_out = pos_encod->forward(torch::arange(idx.size(1)));
247
- torch::Tensor x = toked_model_out + pos_encod_out;
248
-
249
- for (auto& block : block_layers) {
250
- x = block->forward(x);
251
- }
252
-
253
- torch::Tensor logits = linear_final->forward(norm_final->forward(x));
254
-
255
- if (!targets.numel()) {
256
- return {logits, torch::Tensor()};
257
- } else {
258
- logits = logits.view({-1, logits.size(-1)});
259
- targets = targets.view({-1});
260
- torch::Tensor loss = torch::nn::functional::cross_entropy(logits, targets);
261
- return {logits, loss};
262
- }
263
- }
264
-
265
- std::vector<std::vector<std::pair<torch::Tensor, float>>> complex_generate(torch::Tensor idx, int max_new_tokens, float temperature=1.0, int top_k=3, int beam_width=5) {
266
- std::vector<std::vector<std::pair<torch::Tensor, float>>> completed_beams;
267
- torch::Tensor current_idx = idx.clone();
268
- std::vector<std::pair<torch::Tensor, float>> beam = {std::make_pair(current_idx, 0.0)};
269
-
270
- for (int i = 0; i < max_new_tokens; ++i) {
271
- std::vector<std::pair<torch::Tensor, float>> new_beam;
272
-
273
- for (auto& beam_item : beam) {
274
- torch::Tensor& current_idx = beam_item.first;
275
- torch::Tensor logits, loss;
276
- std::tie(logits, loss) = forward(current_idx);
277
- logits = logits.index({torch::indexing::Slice(), -1}); // Get last token predictions
278
-
279
- // Apply softmax and temperature
280
- torch::Tensor probs = torch::nn::functional::softmax(logits / temperature, -1);
281
-
282
- // Top-k sampling
283
- if (top_k > 0) {
284
- probs = top_k_filtering(probs, top_k);
285
- }
286
-
287
- // Sample from the distribution
288
- torch::Tensor sampled_idx = torch::multinomial(probs, beam_width, true);
289
-
290
- for (int j = 0; j < beam_width; ++j) {
291
- torch::Tensor new_idx = torch::cat({current_idx, sampled_idx.index({torch::indexing::Slice(), j})}, 1);
292
- torch::Tensor new_log_prob = beam_item.second + torch::log(probs.index({torch::indexing::Slice(), sampled_idx.index({torch::indexing::Slice(), j})}));
293
- new_beam.push_back(std::make_pair(new_idx, new_log_prob.item()));
294
- }
295
- }
296
-
297
- // Sort new beam by log probabilities
298
- std::sort(new_beam.begin(), new_beam.end(), [](const std::pair<torch::Tensor, float>& a, const std::pair<torch::Tensor, float>& b) {
299
- return a.second > b.second;
300
- });
301
-
302
- // Only keep top beams
303
- beam = std::vector<std::pair<torch::Tensor, float>>(new_beam.begin(), new_beam.begin() + beam_width);
304
- }
305
-
306
- completed_beams.push_back(beam);
307
- return completed_beams;
308
- }
309
-
310
- std::vector<std::vector<std::pair<torch::Tensor, float>>> top_k_filtering(torch::Tensor logits, int top_k) {
311
- torch::Tensor top_values, top_indices;
312
- std::tie(top_values, top_indices) = torch::topk(logits, top_k, -1);
313
-
314
- torch::Tensor min_value = torch::index_select(top_values, -1, torch::tensor({top_k-1}));
315
- torch::Tensor filtered_logits = torch::where(logits < min_value, torch::full_like(logits, -std::numeric_limits<float>::infinity()), logits);
316
- return filtered_logits;
317
- }
318
-
319
- private:
320
- torch::nn::Embedding toked_model, pos_encod;
321
- std::vector<Block> block_layers;
322
- torch::nn::LayerNorm norm_final;
323
- torch::nn::Linear linear_final;
324
- int block_size;
325
- };
326
-
327
- TORCH_MODULE(Enigma);
328
-
329
- int main() {
330
- // Set seed
331
- torch::manual_seed(1400);
332
-
333
- // Create model
334
- Enigma model(vocab_size, block_size, d_model, n_layer, n_head, dropout, norm_eps);
335
- model->to(device);
336
-
337
- // Define optimizer
338
- torch::optim::AdamW optimizer(model->parameters(), torch::optim::AdamWOptions(learning_rate));
339
-
340
- // Training loop
341
- std::vector<float> train_losses, val_losses;
342
- for (int iter = 0; iter < max_iters; ++iter) {
343
- if (iter % eval_interval == 0 || iter == max_iters - 1) {
344
- // Evaluate and print losses
345
- auto losses = estimate_loss();
346
- std::cout << "step " << iter << ": train loss " << losses["train"] << ", val loss " << losses["val"] << std::endl;
347
-
348
- // Save losses for plotting
349
- train_losses.push_back(losses["train"]);
350
- val_losses.push_back(losses["val"]);
351
- }
352
-
353
- // Get batch, forward pass, loss calculation, backward pass, optimizer step
354
- auto [xb, yb] = get_batch("train");
355
- torch::Tensor logits, loss;
356
- std::tie(logits, loss) = model->forward(xb, yb);
357
-
358
- optimizer.zero_grad();
359
- loss.backward();
360
- optimizer.step();
361
- }
362
-
363
- return 0;
364
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
enigma/generate.py DELETED
@@ -1,126 +0,0 @@
1
- import os
2
- current_directory = os.path.dirname(os.path.abspath(__file__))
3
- os.chdir(current_directory)
4
-
5
- with open('../parquet files/new_dna.txt', 'r', encoding='utf-8') as file:
6
- captions = file.read()
7
-
8
- print(f"{(len(captions)/1e6):.2f} million letters")
9
-
10
- from tokenizer import PerCharTokenizer
11
-
12
- tokenizer = PerCharTokenizer()
13
- vocab_size = tokenizer.vocab_size
14
-
15
- import torch
16
- import torch.nn as nn
17
- from torch.nn import functional as F
18
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
19
-
20
- from model import Transformer
21
- model = Transformer(vocab_size=vocab_size)
22
-
23
- class Generator(Transformer):
24
- def __init__(self, vocab_size):
25
- super().__init__()
26
- self.vocab_size = vocab_size
27
- self.block_size = Transformer.block_size
28
-
29
- def generate(self, idx, max_new_tokens, temperature=1.0, top_k=0):
30
- """
31
- generate new tokens using the trained model
32
-
33
- Args:
34
- - idx (Tensor): input tensor representing initial token indices
35
- - max_new_tokens (int): max no of new tokens to generate
36
- - temperature (float): softmax temperature for sampling
37
- - top_k (int): no of top tokens to consider in sampling
38
-
39
- Returns:
40
- - generated_tokens (list): list of generated token indices
41
- """
42
- generated_tokens = []
43
-
44
- for _ in range(max_new_tokens):
45
- idx_cond = idx[:, -self.block_size:]
46
- logits, _ = self(idx_cond)
47
- logits = logits[:, -1, :]
48
-
49
- scaled_logits = logits / temperature
50
- if top_k > 0:
51
- scaled_logits = self._top_k_filtering(scaled_logits, top_k)
52
-
53
- probs = F.softmax(scaled_logits, dim=-1)
54
- sampled_idx = torch.multinomial(probs, num_samples=1)
55
- generated_tokens.append(sampled_idx.item())
56
- idx = torch.cat((idx, sampled_idx), dim=1)
57
-
58
- return generated_tokens
59
-
60
- def generate_masked_tokens(self, idx, masked_indices, temperature=1.0, top_k=0):
61
- """
62
- Generate predictions for masked tokens using the trained model.
63
-
64
- Args:
65
- - idx (Tensor): input tensor representing token indices
66
- - masked_indices (Tensor): tensor of indices indicating masked positions
67
- - temperature (float): softmax temperature for sampling
68
- - top_k (int): no of top tokens to consider in sampling
69
-
70
- Returns:
71
- - predicted_tokens (Tensor): tensor of predicted token indices
72
- """
73
- B, T = idx.shape
74
-
75
- toked_model = self.toked_model(idx)
76
- pos_encod = self.pos_encod(torch.arange(T, device=device))
77
- x = toked_model + pos_encod
78
-
79
- for layer in self.enc_layer:
80
- x_out = layer(x)
81
-
82
- for layer in self.dec_layer:
83
- x_final = layer(x, x_out)
84
-
85
- x_masked = x_final.clone()
86
- x_masked[masked_indices] = self.toked_model(torch.tensor([6], device=device))
87
-
88
- x_masked = self.norm_final(x_masked)
89
- logits = self.linear_final(x_masked)
90
-
91
- masked_logits = logits[masked_indices].view(-1, logits.size(-1))
92
- scaled_logits = masked_logits / temperature
93
- if top_k > 0:
94
- scaled_logits = self._top_k_filtering(scaled_logits, top_k)
95
-
96
- probs = F.softmax(scaled_logits, dim=-1)
97
- predicted_indices = torch.argmax(probs, dim=-1)
98
-
99
- return predicted_indices
100
-
101
- def _top_k_filtering(self, logits, top_k):
102
- """
103
- filter logits to keep only the top-k tokens
104
-
105
- Args:
106
- - logits (Tensor): input tensor representing unscaled logits
107
- - top_k (int): no of top tokens to keep
108
-
109
- Returns:
110
- - filtered_logits (Tensor): filtered logits with only top-k tokens remaining
111
- """
112
- values, indices = torch.topk(logits, top_k, dim=-1)
113
- min_value = values[:, -1].unsqueeze(-1).expand_as(logits)
114
- filtered_logits = torch.where(logits < min_value, torch.ones_like(logits) * -float('inf'), logits)
115
-
116
- return filtered_logits
117
-
118
- checkpoint_path = '../trained models/enigma_47m.pth'
119
- checkpoint = torch.load(checkpoint_path)
120
- model.load_state_dict(checkpoint)
121
- m = model.to(device)
122
-
123
- target_text = "AGTTCTGCGAT"
124
- context = torch.tensor([tokenizer.encode(target_text)], dtype=torch.long, device=device)
125
- generated_output = tokenizer.decode(Generator.generate(context, max_new_tokens=10, temperature=0.5, top_k=5))
126
- print(f"{target_text}{generated_output}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
enigma/model.py DELETED
@@ -1,388 +0,0 @@
1
- """
2
- transformer based model, but with few minimal tweaks
3
- trained a 2.5billion parameters model with current set configurations
4
- """
5
-
6
- import torch
7
- import json
8
- import os
9
- current_directory = os.path.dirname(os.path.abspath(__file__))
10
- os.chdir(current_directory)
11
-
12
- import torch.nn as nn
13
- from torch.nn import functional as F
14
-
15
# hyper-parameters come from the JSON config; the path is relative because the
# working directory was switched to this file's folder above
with open('config_enigma.json', 'r', encoding='utf-8') as file:
    params = json.load(file)

# note: the config key is 'n_layer', the module-level name is 'n_layers'
batch_size, block_size, n_head, d_model, n_layers, dropout, norm_eps = (
    params[key]
    for key in ('batch_size', 'block_size', 'n_head', 'd_model', 'n_layer', 'dropout', 'norm_eps')
)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
26
-
27
class AttentionHead(nn.Module):
    """
    single head of self attention with a learned relative-position term

    Args:
      - d_model (int): dimensionality of the model's hidden layers
      - head_size (int): dimensionality of each attention head
      - dropout (float): dropout probability applied to the attention weights
      - block_size (int): the maximum sequence length the head can attend over
    """
    def __init__(self, d_model, head_size, dropout, block_size):
        super().__init__()
        self.key = nn.Linear(d_model, head_size, bias=True)
        self.query = nn.Linear(d_model, head_size, bias=True)
        self.value = nn.Linear(d_model, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)
        # lower-triangular ones, used as the causal mask; registered as a buffer so
        # it moves with the module but is not trained
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        # one learned vector for every (query position, key position) pair
        self.rel_pos_emb = nn.Parameter(torch.randn(block_size, block_size, head_size))

    def forward(self, x, mask=False):
        """
        forward pass of a single attention head.

        Args:
          - x (Tensor): input tensor of shape (B, T, d_model) with T <= block_size
          - mask (bool): flag indicating whether to apply causal masking

        Returns:
          - out (Tensor): output tensor after self attention, shape (B, T, head_size)
        """
        _, T, _ = x.shape
        key = self.key(x)
        query = self.query(x)

        # scaled dot-product attention: scale by head_size ** -0.5, i.e. divide by
        # sqrt(head_size). the previous code divided by head_size ** -0.5, which
        # multiplied the scores by sqrt(head_size) instead of shrinking them
        scores = torch.matmul(query, key.transpose(-2, -1)) * key.shape[-1] ** -0.5
        # relative-position term: query at position t dotted with the embedding for (t, v)
        rel_pos_scores = torch.einsum('btc,tvc->btv', query, self.rel_pos_emb[:T, :T])
        scores = scores + rel_pos_scores

        if mask:
            # positions above the diagonal (future tokens) get zero weight after softmax
            scores = scores.masked_fill(self.tril[:T, :T] == 0, float('-inf'))

        weights = F.softmax(scores, dim=-1)
        weights = self.dropout(weights)

        value = self.value(x)
        out = torch.matmul(weights, value)
        return out
75
-
76
class MultiHeadAttention(nn.Module):
    """
    several attention heads run side by side, concatenated and projected back to d_model

    Args:
      - d_model (int): dimensionality of the model's hidden layers
      - n_head (int): no of attention heads
      - dropout (float): dropout probability
      - block_size (int): context length
    """
    def __init__(self, d_model, n_head, dropout, block_size):
        super().__init__()
        # integer division: each head gets an equal slice, any remainder is dropped
        per_head = d_model // n_head

        heads = []
        for _ in range(n_head):
            heads.append(
                AttentionHead(d_model=d_model, dropout=dropout, head_size=per_head, block_size=block_size)
            )
        self.heads = nn.ModuleList(heads)
        self.proj = nn.Linear(n_head * per_head, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """
        forward pass of the multi-head attention module

        Args:
          - x (Tensor): input tensor
          - mask (bool): flag indicating whether to apply masking, passed to every head

        Returns:
          - out (Tensor): output tensor after multi-head attention
        """
        head_outputs = [head(x, mask=mask) for head in self.heads]
        # join the heads along the feature dimension before the output projection
        merged = torch.cat(head_outputs, dim=-1)
        return self.dropout(self.proj(merged))
108
-
109
class FeedForward(nn.Module):
    """
    position-wise feedforward network: expand to 10x d_model, GELU, project back, dropout

    Args:
      - d_model (int): the dimensionality of the model's hidden layers
      - dropout (float): dropout probability
    """
    def __init__(self, d_model, dropout):
        super().__init__()
        hidden = 10 * d_model
        layers = [
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Linear(hidden, d_model),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """
        forward pass of the feedforward network module

        Args:
          - x (Tensor): input tensor whose last dimension is d_model

        Returns:
          - out (Tensor): output tensor after passing through the feedforward network
        """
        out = self.net(x)
        return out
138
-
139
class EncoderNetwork(nn.Module):
    """
    one encoder layer: unmasked self attention, then a feedforward network,
    each followed by a residual add and layer normalization

    Args:
      - d_model (int): dimensionality of the model's hidden layers
      - n_head (int): no of attention heads in multi-head attention layers
      - norm_eps (float): epsilon value for layer normalization
      - dropout (float): dropout probability
      - block_size (int): the maximum sequence length for positional encoding
    """
    def __init__(self, d_model, n_head, norm_eps, dropout, block_size):
        super().__init__()
        self.s_att = MultiHeadAttention(n_head=n_head, d_model=d_model, dropout=dropout, block_size=block_size)
        self.ffwd = FeedForward(d_model, dropout)
        self.dropout = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model, eps=norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=norm_eps)

    def forward(self, src):
        """
        forward pass of the encoder network module.

        Args:
          - src (Tensor): input tensor representing source data

        Returns:
          - src (Tensor): output tensor after passing through the encoder network
        """
        # attention sub-layer: residual add first, normalization afterwards
        attended = self.s_att(src, mask=False)
        src = self.norm1(src + self.dropout(attended))

        # feedforward sub-layer, same residual-then-norm pattern
        transformed = self.ffwd(src)
        return self.norm2(src + self.dropout(transformed))
177
-
178
class DecoderNetwork(nn.Module):
    """
    one decoder layer: masked self attention, a second unmasked attention pass
    that mixes in the encoder output, then a feedforward network

    Args:
      - d_model (int): dimensionality of the model's hidden layers
      - n_head (int): no of attention heads in multi-head attention layers
      - norm_eps (float): epsilon value for layer normalization
      - dropout (float): dropout probability
      - block_size (int): the maximum sequence length for positional encoding
    """
    def __init__(self, d_model, n_head, norm_eps, dropout, block_size):
        super().__init__()
        self.s_att = MultiHeadAttention(n_head=n_head, d_model=d_model, dropout=dropout, block_size=block_size)
        self.ffwd = FeedForward(d_model, dropout)
        self.dropout = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model, eps=norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=norm_eps)

    def forward(self, src, att):
        """
        forward pass of the decoder network module.

        Args:
          - src (Tensor): input tensor, same as the encoder's inputs
          - att (Tensor): encoder output, added to the decoder state before the
            second attention pass; must be broadcastable with src

        Returns:
          - src_f (Tensor): final output tensor
        """
        # pass 1: causal self attention over the decoder input
        src2 = self.s_att(src, mask=True)
        src = src + self.dropout(src2)
        # NOTE(review): the normalized activations are added onto the residual here,
        # whereas the encoder layer replaces the residual with them -- confirm intended
        src = src + self.norm1(src)

        # pass 2: add the encoder output and attend again without a mask; the same
        # s_att and norm1 modules (and therefore the same weights) serve both passes
        att = src + att
        att2 = self.s_att(att, mask=False)
        att2 = att + self.dropout(att2)
        trg = att2 + self.norm1(att2)

        # feedforward sub-layer. the residual is taken from trg: the previous code
        # read src_f before it was assigned and raised UnboundLocalError on every call
        src_f2 = self.ffwd(self.norm2(trg))
        src_f = trg + self.dropout(src_f2)
        src_f = self.norm2(src_f)

        return src_f
222
-
223
class Transformer(nn.Module):
    """
    encoder-decoder transformer that predicts token logits

    the model dimensions (d_model, block_size, n_layers, n_head, norm_eps, dropout)
    are not constructor arguments; they are read from the module-level values
    loaded from the config file

    Args:
      - vocab_size (int): size of the vocabulary
    """
    def __init__(self, vocab_size):
        super().__init__()
        self.block_size = block_size
        self.toked_model = nn.Embedding(vocab_size, d_model)
        self.pos_encod = nn.Embedding(block_size, d_model)
        self.enc_layer = nn.ModuleList([EncoderNetwork(n_head=n_head, norm_eps=norm_eps, block_size=block_size, dropout=dropout, d_model=d_model) for _ in range(n_layers)])
        self.dec_layer = nn.ModuleList([DecoderNetwork(n_head=n_head, norm_eps=norm_eps, block_size=block_size, dropout=dropout, d_model=d_model) for _ in range(n_layers)])

        self.norm_final = nn.LayerNorm(d_model)
        self.linear_final = nn.Linear(d_model, vocab_size)
        # not applied anywhere in this class; kept so the attribute stays available
        self.dropout = nn.Dropout(dropout)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """
        initialize weights of linear and embedding layers

        Args:
          - module (nn.Module): the module to initialize weights for
        """
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias.data)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _decode_hidden(self, idx):
        """
        embed idx and run it through the encoder stack and then the decoder stack

        Args:
          - idx (Tensor): token indices of shape (B, T) with T <= block_size

        Returns:
          - hidden (Tensor): decoder output before the final norm and projection
        """
        _, T = idx.shape

        toked_model = self.toked_model(idx)
        # build the positions on the input's device so the model also works when it
        # lives on a device other than the module-level default
        pos_encod = self.pos_encod(torch.arange(T, device=idx.device))
        x = toked_model + pos_encod

        # chain the layers: each one consumes the previous layer's output. the previous
        # loops fed every layer the raw embeddings and kept only the last layer's result,
        # so all but one layer of each stack had no effect
        memory = x
        for layer in self.enc_layer:
            memory = layer(memory)

        hidden = x
        for layer in self.dec_layer:
            hidden = layer(hidden, memory)

        return hidden

    def forward(self, idx, targets=None):
        """
        forward pass of the transformer model

        Args:
          - idx (Tensor): input tensor representing token indices, shape (B, T)
          - targets (Tensor): target tensor for computing loss during training

        Returns:
          - logits (Tensor): output logits from the final linear layer; shape (B, T, C)
            when targets is None, flattened to (B*T, C) when targets are given
          - loss (Tensor): optional. computed cross-entropy loss if targets are provided, else None
        """
        x_final = self.norm_final(self._decode_hidden(idx))
        logits = self.linear_final(x_final)

        if targets is None:
            loss = None

        else:
            # cross_entropy wants (N, C) logits against (N,) class indices
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=0):
        """
        generate new tokens using the trained model

        Args:
          - idx (Tensor): input tensor representing initial token indices; the batch
            size must be 1 because every sampled token is read back with .item()
          - max_new_tokens (int): max no of new tokens to generate
          - temperature (float): softmax temperature for sampling
          - top_k (int): no of top tokens to consider in sampling, 0 disables the filter

        Returns:
          - generated_tokens (list): list of generated token indices
        """
        generated_tokens = []

        # sampling only produces python ints, so no autograd graph is needed
        with torch.no_grad():
            for _ in range(max_new_tokens):
                # keep at most block_size tokens of context
                idx_cond = idx[:, -self.block_size:]
                logits, _ = self(idx_cond)
                # only the prediction for the last position is used
                logits = logits[:, -1, :]

                scaled_logits = logits / temperature
                if top_k > 0:
                    scaled_logits = self._top_k_filtering(scaled_logits, top_k)

                probs = F.softmax(scaled_logits, dim=-1)
                sampled_idx = torch.multinomial(probs, num_samples=1)
                generated_tokens.append(sampled_idx.item())
                idx = torch.cat((idx, sampled_idx), dim=1)

        return generated_tokens

    def generate_masked_tokens(self, idx, masked_indices, temperature=1.0, top_k=0, mask_token_id=6):
        """
        Generate predictions for masked tokens using the trained model.

        Args:
          - idx (Tensor): input tensor representing token indices
          - masked_indices (Tensor): tensor of indices indicating masked positions
          - temperature (float): softmax temperature applied to the logits
          - top_k (int): no of top tokens to keep before the argmax, 0 disables the filter
          - mask_token_id (int): token id whose embedding overwrites the masked positions;
            defaults to 6 -- presumably the tokenizer's mask token, confirm against the tokenizer

        Returns:
          - predicted_tokens (Tensor): tensor of predicted token indices
        """
        x_masked = self._decode_hidden(idx).clone()
        # the masked positions of the decoder output are replaced by the mask token's embedding
        mask_token = torch.tensor([mask_token_id], device=idx.device)
        x_masked[masked_indices] = self.toked_model(mask_token)

        x_masked = self.norm_final(x_masked)
        logits = self.linear_final(x_masked)

        masked_logits = logits[masked_indices].view(-1, logits.size(-1))
        scaled_logits = masked_logits / temperature
        if top_k > 0:
            scaled_logits = self._top_k_filtering(scaled_logits, top_k)

        probs = F.softmax(scaled_logits, dim=-1)
        # greedy choice: the most probable token at every masked position
        predicted_indices = torch.argmax(probs, dim=-1)

        return predicted_indices

    def _top_k_filtering(self, logits, top_k):
        """
        filter logits to keep only the top-k tokens

        Args:
          - logits (Tensor): input tensor representing unscaled logits
          - top_k (int): no of top tokens to keep. values larger than the size of the
            last dimension keep every token, values <= 0 disable filtering

        Returns:
          - filtered_logits (Tensor): filtered logits with only top-k tokens remaining
        """
        # torch.topk raises when k is larger than the dimension it searches, so clamp it
        k = min(top_k, logits.size(-1))
        if k <= 0:
            return logits

        # smallest surviving value per row, kept with a trailing dim so it broadcasts
        kth_value = torch.topk(logits, k, dim=-1).values[..., -1:]
        return logits.masked_fill(logits < kth_value, float('-inf'))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
enigma/run.py DELETED
@@ -1,100 +0,0 @@
1
- """
2
- use this file to train the model
3
-
4
- working:
5
- - imports various dependencies first, and then loads the training data
6
- - tokenizes it, per-character basis
7
- - loads the required hyper-parameters and the model file
8
- - trains it till 'max_iters' and saves the model state, and generates outputs
9
-
10
- with the current set configuration, model can reach upto ~60million parameters
11
- and can become ~99% accurate with next token prediction
12
- """
13
-
14
- import torch
15
- import json
16
- import os
17
- current_directory = os.path.dirname(os.path.abspath(__file__))
18
- os.chdir(current_directory)
19
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
20
-
21
# read the whole training corpus into memory
with open('../parquet files/new_dna.txt', 'r', encoding='utf-8') as file:
    captions = file.read()

print(f"{(len(captions)/1e6):.2f} million letters")

# NOTE(review): a relative import only resolves when this file is executed as part
# of a package (python -m ...); running it directly as a script will fail here
from ..tokenizer import PerCharTokenizer

tokenizer = PerCharTokenizer()
vocab_size = tokenizer.vocab_size
# Train and test splits
data = torch.tensor(tokenizer.encode(captions), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

# the working directory was switched to this file's folder above, so the config is
# opened by its relative name; the previous '/config_enigma.json' pointed at the
# filesystem root
with open('config_enigma.json', 'r', encoding='utf-8') as file:
    params = json.load(file)

# required parameters
batch_size = params['batch_size']
block_size = params['block_size']
max_iters = params['max_iters']
eval_interval = params['eval_interval']
eval_iters = params['eval_iters']
learning_rate = params['learning_rate']

# fixed seed so batch sampling and weight initialization are repeatable
torch.manual_seed(1400)
48
- # data loading
49
def get_batch(split):
    """
    draw a random batch of (input, target) windows from one of the data splits

    Args:
      - split (str): 'train' selects the training data, anything else the validation data

    Returns:
      - x (Tensor): batch_size windows of block_size tokens, moved to device
      - y (Tensor): the same windows shifted one token to the right, moved to device
    """
    if split == 'train':
        source = train_data
    else:
        source = val_data

    # random start offsets; the upper bound leaves room for the shifted target window
    starts = torch.randint(len(source) - block_size, (batch_size,))

    inputs, targets = [], []
    for start in starts:
        inputs.append(source[start:start + block_size])
        targets.append(source[start + 1:start + block_size + 1])

    return torch.stack(inputs).to(device), torch.stack(targets).to(device)
57
-
58
@torch.no_grad()
def estimate_loss():
    """
    average the model's loss over eval_iters random batches of each split

    the model is put in eval mode for the measurement and back in train mode afterwards

    Returns:
      - out (dict): maps 'train' and 'val' to the mean loss (0-dim Tensor) of that split
    """
    model.eval()
    averages = {}
    for split in ('train', 'val'):
        per_batch = torch.zeros(eval_iters)
        for k in range(eval_iters):
            _, loss = model(*get_batch(split))
            per_batch[k] = loss.item()
        averages[split] = per_batch.mean()
    model.train()
    return averages
71
-
72
from model import Transformer
model = Transformer(vocab_size=vocab_size)
m = model.to(device)

# no of parameters
n_param = sum(p.numel() for p in m.parameters())/1e6
print(f"{n_param:.2f} million")
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
# loss history, one entry per evaluation
steps = []
train_losses = []
val_losses = []

# the counter is named 'step' so the builtin iter() is not shadowed
for step in range(max_iters):

    # evaluate every eval_interval steps and once more on the final step
    if step % eval_interval == 0 or step == max_iters - 1:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

        steps.append(step)
        # store plain floats rather than 0-dim tensors
        train_losses.append(losses['train'].item())
        val_losses.append(losses['val'].item())

    xb, yb = get_batch('train')
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# the file name carries the parameter count in millions, rounded to a whole number
torch.save(model.state_dict(), f'enigma_{n_param:.0f}m.pth')