Initial commit for files
- README.md +57 -0
- exercise-1.ipynb +1167 -0
- exercise-2.ipynb +432 -0
- exercise-3.ipynb +372 -0
- final-code.ipynb +482 -0
- names.txt +0 -0
- starter-code.ipynb +623 -0
README.md
ADDED
@@ -0,0 +1,57 @@
## SET 1 - MAKEMORE (PART 4) 🔗

[Documentation Site](https://muzzammilshah.github.io/Road-to-GPT/Makemore-part4/)

[Commit History](https://github.com/MuzzammilShah/NeuralNetworks-LanguageModels-4/commits/main)

### **Overview**

In this repository, we take the 2-layer MLP (with BatchNorm) from the previous video/lecture and **backpropagate through it manually, without using PyTorch autograd's `loss.backward()`**. So we will be manually backpropagating through the cross-entropy loss, the 2nd linear layer, tanh, batchnorm, the 1st linear layer, and the embedding table.
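To make "manual backprop" concrete, here is a minimal sketch of the pattern used throughout the notebooks (my own toy example, not code from this repository): derive a gradient by hand, run autograd, and compare the two.

```python
import torch

# toy example: loss = (w * x).sum(); the hand-derived gradient of loss w.r.t. w is x
x = torch.randn(5)
w = torch.randn(5, requires_grad=True)
loss = (w * x).sum()
loss.backward()                            # PyTorch fills in w.grad

dw_manual = x                              # gradient derived by hand
print(torch.allclose(dw_manual, w.grad))   # True if the derivation is correct
```

The notebooks do exactly this, but layer by layer through the whole MLP, using a small `cmp` utility to report whether each manual gradient matches PyTorch's exactly or approximately.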
### **🗂️Repository Structure**

```plaintext
├── .gitignore
├── starter-code.ipynb
├── exercise-1.ipynb
├── exercise-2.ipynb
├── exercise-3.ipynb
├── final-code.ipynb
├── README.md
└── names.txt
```

- **Jupyter Notebooks**: Step-by-step implementation and exploration of the concepts.
- **README.md**: Overview and guide for this repository.
- **names.txt**: Supplementary data file used in training the model.

### **📄Instructions**

To get the best understanding:

- The format and structure of this particular section of the project is different from what I've implemented so far, as Andrej Karpathy himself put it: "I recommend you work through the exercise yourself but work with it in tandem and whenever you are stuck unpause the video and see me give away the answer. This video is not super intended to be simply watched."

- So, keeping this in mind, we will focus more on the notebooks themselves and only make notes whenever absolutely necessary.

- You will find my notes/key points as comments in the code cells (apart from the timestamps and their headers, which are in their normal format, of course).

### **⭐Documentation**

For a better reading experience and detailed notes, visit my **[Road to GPT Documentation Site](https://muzzammilshah.github.io/Road-to-GPT/)**.

> **💡Pro Tip**: This site provides an interactive and visually rich explanation of the notes and code. It is highly recommended you view this project from there.

### **✍🏻Acknowledgments**

Notes and implementations inspired by the **Makemore - Part 4** video by [Andrej Karpathy](https://karpathy.ai/).

For more of my projects, visit my [Portfolio Site](https://muhammedshah.com).
exercise-1.ipynb
ADDED
@@ -0,0 +1,1167 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {
|
6 |
+
"id": "rToK0Tku8PPn"
|
7 |
+
},
|
8 |
+
"source": [
|
9 |
+
"## makemore: becoming a backprop ninja"
|
10 |
+
]
|
11 |
+
},
|
12 |
+
{
|
13 |
+
"cell_type": "code",
|
14 |
+
"execution_count": 1,
|
15 |
+
"metadata": {
|
16 |
+
"id": "ChBbac4y8PPq"
|
17 |
+
},
|
18 |
+
"outputs": [],
|
19 |
+
"source": [
|
20 |
+
"import torch\n",
|
21 |
+
"import torch.nn.functional as F\n",
|
22 |
+
"import matplotlib.pyplot as plt # for making figures\n",
|
23 |
+
"%matplotlib inline"
|
24 |
+
]
|
25 |
+
},
|
26 |
+
{
|
27 |
+
"cell_type": "code",
|
28 |
+
"execution_count": 2,
|
29 |
+
"metadata": {
|
30 |
+
"id": "klmu3ZG08PPr"
|
31 |
+
},
|
32 |
+
"outputs": [
|
33 |
+
{
|
34 |
+
"name": "stdout",
|
35 |
+
"output_type": "stream",
|
36 |
+
"text": [
|
37 |
+
"32033\n",
|
38 |
+
"15\n",
|
39 |
+
"['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']\n"
|
40 |
+
]
|
41 |
+
}
|
42 |
+
],
|
43 |
+
"source": [
|
44 |
+
"# read in all the words\n",
|
45 |
+
"words = open('names.txt', 'r').read().splitlines()\n",
|
46 |
+
"print(len(words))\n",
|
47 |
+
"print(max(len(w) for w in words))\n",
|
48 |
+
"print(words[:8])"
|
49 |
+
]
|
50 |
+
},
|
51 |
+
{
|
52 |
+
"cell_type": "code",
|
53 |
+
"execution_count": 3,
|
54 |
+
"metadata": {
|
55 |
+
"id": "BCQomLE_8PPs"
|
56 |
+
},
|
57 |
+
"outputs": [
|
58 |
+
{
|
59 |
+
"name": "stdout",
|
60 |
+
"output_type": "stream",
|
61 |
+
"text": [
|
62 |
+
"{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}\n",
|
63 |
+
"27\n"
|
64 |
+
]
|
65 |
+
}
|
66 |
+
],
|
67 |
+
"source": [
|
68 |
+
"# build the vocabulary of characters and mappings to/from integers\n",
|
69 |
+
"chars = sorted(list(set(''.join(words))))\n",
|
70 |
+
"stoi = {s:i+1 for i,s in enumerate(chars)}\n",
|
71 |
+
"stoi['.'] = 0\n",
|
72 |
+
"itos = {i:s for s,i in stoi.items()}\n",
|
73 |
+
"vocab_size = len(itos)\n",
|
74 |
+
"print(itos)\n",
|
75 |
+
"print(vocab_size)"
|
76 |
+
]
|
77 |
+
},
|
78 |
+
{
|
79 |
+
"cell_type": "code",
|
80 |
+
"execution_count": 4,
|
81 |
+
"metadata": {
|
82 |
+
"id": "V_zt2QHr8PPs"
|
83 |
+
},
|
84 |
+
"outputs": [
|
85 |
+
{
|
86 |
+
"name": "stdout",
|
87 |
+
"output_type": "stream",
|
88 |
+
"text": [
|
89 |
+
"torch.Size([182625, 3]) torch.Size([182625])\n",
|
90 |
+
"torch.Size([22655, 3]) torch.Size([22655])\n",
|
91 |
+
"torch.Size([22866, 3]) torch.Size([22866])\n"
|
92 |
+
]
|
93 |
+
}
|
94 |
+
],
|
95 |
+
"source": [
|
96 |
+
"# build the dataset\n",
|
97 |
+
"block_size = 3 # context length: how many characters do we take to predict the next one?\n",
|
98 |
+
"\n",
|
99 |
+
"def build_dataset(words):\n",
|
100 |
+
" X, Y = [], []\n",
|
101 |
+
"\n",
|
102 |
+
" for w in words:\n",
|
103 |
+
" context = [0] * block_size\n",
|
104 |
+
" for ch in w + '.':\n",
|
105 |
+
" ix = stoi[ch]\n",
|
106 |
+
" X.append(context)\n",
|
107 |
+
" Y.append(ix)\n",
|
108 |
+
" context = context[1:] + [ix] # crop and append\n",
|
109 |
+
"\n",
|
110 |
+
" X = torch.tensor(X)\n",
|
111 |
+
" Y = torch.tensor(Y)\n",
|
112 |
+
" print(X.shape, Y.shape)\n",
|
113 |
+
" return X, Y\n",
|
114 |
+
"\n",
|
115 |
+
"import random\n",
|
116 |
+
"random.seed(42)\n",
|
117 |
+
"random.shuffle(words)\n",
|
118 |
+
"n1 = int(0.8*len(words))\n",
|
119 |
+
"n2 = int(0.9*len(words))\n",
|
120 |
+
"\n",
|
121 |
+
"Xtr, Ytr = build_dataset(words[:n1]) # 80%\n",
|
122 |
+
"Xdev, Ydev = build_dataset(words[n1:n2]) # 10%\n",
|
123 |
+
"Xte, Yte = build_dataset(words[n2:]) # 10%"
|
124 |
+
]
|
125 |
+
},
|
126 |
+
{
|
127 |
+
"cell_type": "code",
|
128 |
+
"execution_count": 5,
|
129 |
+
"metadata": {
|
130 |
+
"id": "eg20-vsg8PPt"
|
131 |
+
},
|
132 |
+
"outputs": [],
|
133 |
+
"source": [
|
134 |
+
"# ok biolerplate done, now we get to the action:"
|
135 |
+
]
|
136 |
+
},
|
137 |
+
{
|
138 |
+
"cell_type": "code",
|
139 |
+
"execution_count": 5,
|
140 |
+
"metadata": {
|
141 |
+
"id": "MJPU8HT08PPu"
|
142 |
+
},
|
143 |
+
"outputs": [],
|
144 |
+
"source": [
|
145 |
+
"# utility function we will use later when comparing manual gradients to PyTorch gradients\n",
|
146 |
+
"def cmp(s, dt, t):\n",
|
147 |
+
" ex = torch.all(dt == t.grad).item()\n",
|
148 |
+
" app = torch.allclose(dt, t.grad)\n",
|
149 |
+
" maxdiff = (dt - t.grad).abs().max().item()\n",
|
150 |
+
" print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')"
|
151 |
+
]
|
152 |
+
},
|
153 |
+
{
|
154 |
+
"cell_type": "code",
|
155 |
+
"execution_count": 26,
|
156 |
+
"metadata": {
|
157 |
+
"id": "ZlFLjQyT8PPu"
|
158 |
+
},
|
159 |
+
"outputs": [
|
160 |
+
{
|
161 |
+
"name": "stdout",
|
162 |
+
"output_type": "stream",
|
163 |
+
"text": [
|
164 |
+
"4137\n"
|
165 |
+
]
|
166 |
+
}
|
167 |
+
],
|
168 |
+
"source": [
|
169 |
+
"n_embd = 10 # the dimensionality of the character embedding vectors\n",
|
170 |
+
"n_hidden = 64 # the number of neurons in the hidden layer of the MLP\n",
|
171 |
+
"\n",
|
172 |
+
"g = torch.Generator().manual_seed(2147483647) # for reproducibility\n",
|
173 |
+
"C = torch.randn((vocab_size, n_embd), generator=g)\n",
|
174 |
+
"# Layer 1\n",
|
175 |
+
"W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5/3)/((n_embd * block_size)**0.5)\n",
|
176 |
+
"b1 = torch.randn(n_hidden, generator=g) * 0.1 # using b1 just for fun, it's useless because of BN\n",
|
177 |
+
"# Layer 2\n",
|
178 |
+
"W2 = torch.randn((n_hidden, vocab_size), generator=g) * 0.1\n",
|
179 |
+
"b2 = torch.randn(vocab_size, generator=g) * 0.1\n",
|
180 |
+
"# BatchNorm parameters\n",
|
181 |
+
"bngain = torch.randn((1, n_hidden))*0.1 + 1.0\n",
|
182 |
+
"bnbias = torch.randn((1, n_hidden))*0.1\n",
|
183 |
+
"\n",
|
184 |
+
"# Note: I am initializating many of these parameters in non-standard ways\n",
|
185 |
+
"# because sometimes initializating with e.g. all zeros could mask an incorrect\n",
|
186 |
+
"# implementation of the backward pass.\n",
|
187 |
+
"\n",
|
188 |
+
"parameters = [C, W1, b1, W2, b2, bngain, bnbias]\n",
|
189 |
+
"print(sum(p.nelement() for p in parameters)) # number of parameters in total\n",
|
190 |
+
"for p in parameters:\n",
|
191 |
+
" p.requires_grad = True"
|
192 |
+
]
|
193 |
+
},
|
194 |
+
{
|
195 |
+
"cell_type": "code",
|
196 |
+
"execution_count": 27,
|
197 |
+
"metadata": {
|
198 |
+
"id": "QY-y96Y48PPv"
|
199 |
+
},
|
200 |
+
"outputs": [],
|
201 |
+
"source": [
|
202 |
+
"batch_size = 32\n",
|
203 |
+
"n = batch_size # a shorter variable also, for convenience\n",
|
204 |
+
"# construct a minibatch\n",
|
205 |
+
"ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)\n",
|
206 |
+
"Xb, Yb = Xtr[ix], Ytr[ix] # batch X,Y"
|
207 |
+
]
|
208 |
+
},
|
209 |
+
{
|
210 |
+
"cell_type": "code",
|
211 |
+
"execution_count": 28,
|
212 |
+
"metadata": {
|
213 |
+
"id": "8ofj1s6d8PPv"
|
214 |
+
},
|
215 |
+
"outputs": [
|
216 |
+
{
|
217 |
+
"data": {
|
218 |
+
"text/plain": [
|
219 |
+
"tensor(3.3221, grad_fn=<NegBackward0>)"
|
220 |
+
]
|
221 |
+
},
|
222 |
+
"execution_count": 28,
|
223 |
+
"metadata": {},
|
224 |
+
"output_type": "execute_result"
|
225 |
+
}
|
226 |
+
],
|
227 |
+
"source": [
|
228 |
+
"# forward pass, \"chunkated\" into smaller steps that are possible to backward one at a time\n",
|
229 |
+
"\n",
|
230 |
+
"emb = C[Xb] # embed the characters into vectors\n",
|
231 |
+
"embcat = emb.view(emb.shape[0], -1) # concatenate the vectors\n",
|
232 |
+
"# Linear layer 1\n",
|
233 |
+
"hprebn = embcat @ W1 + b1 # hidden layer pre-activation\n",
|
234 |
+
"# BatchNorm layer\n",
|
235 |
+
"bnmeani = 1/n*hprebn.sum(0, keepdim=True)\n",
|
236 |
+
"bndiff = hprebn - bnmeani\n",
|
237 |
+
"bndiff2 = bndiff**2\n",
|
238 |
+
"bnvar = 1/(n-1)*(bndiff2).sum(0, keepdim=True) # note: Bessel's correction (dividing by n-1, not n)\n",
|
239 |
+
"bnvar_inv = (bnvar + 1e-5)**-0.5\n",
|
240 |
+
"bnraw = bndiff * bnvar_inv\n",
|
241 |
+
"hpreact = bngain * bnraw + bnbias\n",
|
242 |
+
"# Non-linearity\n",
|
243 |
+
"h = torch.tanh(hpreact) # hidden layer\n",
|
244 |
+
"# Linear layer 2\n",
|
245 |
+
"logits = h @ W2 + b2 # output layer\n",
|
246 |
+
"# cross entropy loss (same as F.cross_entropy(logits, Yb))\n",
|
247 |
+
"logit_maxes = logits.max(1, keepdim=True).values\n",
|
248 |
+
"norm_logits = logits - logit_maxes # subtract max for numerical stability\n",
|
249 |
+
"counts = norm_logits.exp()\n",
|
250 |
+
"counts_sum = counts.sum(1, keepdims=True) #DONE\n",
|
251 |
+
"counts_sum_inv = counts_sum**-1 # if I use (1.0 / counts_sum) instead then I can't get backprop to be bit exact... #DONE\n",
|
252 |
+
"probs = counts * counts_sum_inv #DONE\n",
|
253 |
+
"logprobs = probs.log() #DONE\n",
|
254 |
+
"loss = -logprobs[range(n), Yb].mean() #DONE\n",
|
255 |
+
"\n",
|
256 |
+
"# PyTorch backward pass\n",
|
257 |
+
"for p in parameters:\n",
|
258 |
+
" p.grad = None\n",
|
259 |
+
"for t in [logprobs, probs, counts, counts_sum, counts_sum_inv, # afaik there is no cleaner way\n",
|
260 |
+
" norm_logits, logit_maxes, logits, h, hpreact, bnraw,\n",
|
261 |
+
" bnvar_inv, bnvar, bndiff2, bndiff, hprebn, bnmeani,\n",
|
262 |
+
" embcat, emb]:\n",
|
263 |
+
" t.retain_grad()\n",
|
264 |
+
"loss.backward()\n",
|
265 |
+
"loss"
|
266 |
+
]
|
267 |
+
},
|
268 |
+
{
|
269 |
+
"cell_type": "markdown",
|
270 |
+
"metadata": {},
|
271 |
+
"source": [
|
272 |
+
"---------\n",
|
273 |
+
"\n",
|
274 |
+
"### **EXERCISE 1**"
|
275 |
+
]
|
276 |
+
},
|
277 |
+
{
|
278 |
+
"cell_type": "markdown",
|
279 |
+
"metadata": {},
|
280 |
+
"source": [
|
281 |
+
"[13:01](https://www.youtube.com/watch?v=q8SA3rM6ckI&t=781s) to [19:05](https://youtu.be/q8SA3rM6ckI?si=mm8M8feWFToF4STA&t=1145) `cmp('logprobs', dlogprobs, logprobs)`"
|
282 |
+
]
|
283 |
+
},
|
284 |
+
{
|
285 |
+
"cell_type": "code",
|
286 |
+
"execution_count": 9,
|
287 |
+
"metadata": {},
|
288 |
+
"outputs": [
|
289 |
+
{
|
290 |
+
"name": "stdout",
|
291 |
+
"output_type": "stream",
|
292 |
+
"text": [
|
293 |
+
"torch.Size([32, 27])\n"
|
294 |
+
]
|
295 |
+
},
|
296 |
+
{
|
297 |
+
"data": {
|
298 |
+
"text/plain": [
|
299 |
+
"tensor([-4.0562, -3.0820, -3.6629, -3.2621, -4.1229, -3.4201, -3.2428, -3.9554,\n",
|
300 |
+
" -3.1259, -4.2500, -3.1582, -1.6256, -2.8483, -2.9654, -2.9990, -3.1882,\n",
|
301 |
+
" -3.9132, -3.0643, -3.5065, -3.5153, -2.8832, -3.0837, -4.2941, -4.0007,\n",
|
302 |
+
" -3.4440, -2.9220, -3.1386, -3.8946, -2.6488, -3.5292, -3.3408, -3.1560],\n",
|
303 |
+
" grad_fn=<IndexBackward0>)"
|
304 |
+
]
|
305 |
+
},
|
306 |
+
"execution_count": 9,
|
307 |
+
"metadata": {},
|
308 |
+
"output_type": "execute_result"
|
309 |
+
}
|
310 |
+
],
|
311 |
+
"source": [
|
312 |
+
"print(logprobs.shape)\n",
|
313 |
+
"logprobs[range(n), Yb]"
|
314 |
+
]
|
315 |
+
},
|
316 |
+
{
|
317 |
+
"cell_type": "code",
|
318 |
+
"execution_count": 10,
|
319 |
+
"metadata": {},
|
320 |
+
"outputs": [
|
321 |
+
{
|
322 |
+
"name": "stdout",
|
323 |
+
"output_type": "stream",
|
324 |
+
"text": [
|
325 |
+
"torch.Size([32])\n"
|
326 |
+
]
|
327 |
+
},
|
328 |
+
{
|
329 |
+
"data": {
|
330 |
+
"text/plain": [
|
331 |
+
"tensor([ 8, 14, 15, 22, 0, 19, 9, 14, 5, 1, 20, 3, 8, 14, 12, 0, 11, 0,\n",
|
332 |
+
" 26, 9, 25, 0, 1, 1, 7, 18, 9, 3, 5, 9, 0, 18])"
|
333 |
+
]
|
334 |
+
},
|
335 |
+
"execution_count": 10,
|
336 |
+
"metadata": {},
|
337 |
+
"output_type": "execute_result"
|
338 |
+
}
|
339 |
+
],
|
340 |
+
"source": [
|
341 |
+
"print(Yb.shape)\n",
|
342 |
+
"Yb"
|
343 |
+
]
|
344 |
+
},
|
345 |
+
{
|
346 |
+
"cell_type": "code",
|
347 |
+
"execution_count": null,
|
348 |
+
"metadata": {},
|
349 |
+
"outputs": [],
|
350 |
+
"source": [
|
351 |
+
"#simple breakdown\n",
|
352 |
+
"#now here we know there are 32 examples, for explaination lets assume we only have 3 in total i.e. a,b,c\n",
|
353 |
+
"\n",
|
354 |
+
"#loss = - (a + b + c) / 3 ==> so we are basically doing the mean calculation here\n",
|
355 |
+
"#loss = - (1/3a + 1/3b + 1/3c) ==> same equation\n",
|
356 |
+
"#so now, when we take the derivative wrt a\n",
|
357 |
+
"#dloss/da = -1/3 ==>where 3 is the number of elements we consider, so we can also say that it is -1/n, therefore\n",
|
358 |
+
"#dloss/dn = -1/n"
|
359 |
+
]
|
360 |
+
},
|
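A quick, self-contained check of the reasoning in the cell above (a tiny sketch of my own, not part of the notebook): for `loss = -(a + b + c)/3`, the gradient with respect to each element is `-1/3`, i.e. `-1/n`.

```python
import torch

abc = torch.tensor([0.7, 1.2, 0.4], requires_grad=True)
loss = -abc.mean()          # same as -(a + b + c) / 3
loss.backward()
print(abc.grad)             # tensor([-0.3333, -0.3333, -0.3333]) == -1/n for each element
```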
361 |
+
{
|
362 |
+
"cell_type": "code",
|
363 |
+
"execution_count": 29,
|
364 |
+
"metadata": {},
|
365 |
+
"outputs": [
|
366 |
+
{
|
367 |
+
"name": "stdout",
|
368 |
+
"output_type": "stream",
|
369 |
+
"text": [
|
370 |
+
"logprobs | exact: True | approximate: True | maxdiff: 0.0\n"
|
371 |
+
]
|
372 |
+
}
|
373 |
+
],
|
374 |
+
"source": [
|
375 |
+
"#So based on our calculation above\n",
|
376 |
+
"dlogprobs = torch.zeros_like(logprobs) #same as torch.zeros((32, 27)) as we need the same shape as logprobs. So instead of hardcoding the values we did this\n",
|
377 |
+
"dlogprobs[range(n), Yb] = -1.0/n #as we need to do it for each of the elements in the array\n",
|
378 |
+
"\n",
|
379 |
+
"#Now, lets check\n",
|
380 |
+
"cmp('logprobs', dlogprobs, logprobs)"
|
381 |
+
]
|
382 |
+
},
|
383 |
+
{
|
384 |
+
"cell_type": "markdown",
|
385 |
+
"metadata": {},
|
386 |
+
"source": [
|
387 |
+
"[19:06](https://youtu.be/q8SA3rM6ckI?si=mO61nJLwtQpxsjju&t=1146) to [20:55](https://youtu.be/q8SA3rM6ckI?si=fgJsPGOCdJIIRYC9&t=1255) `cmp('probs', dprobs, probs)`"
|
388 |
+
]
|
389 |
+
},
|
390 |
+
{
|
391 |
+
"cell_type": "code",
|
392 |
+
"execution_count": 30,
|
393 |
+
"metadata": {},
|
394 |
+
"outputs": [
|
395 |
+
{
|
396 |
+
"name": "stdout",
|
397 |
+
"output_type": "stream",
|
398 |
+
"text": [
|
399 |
+
"probs | exact: True | approximate: True | maxdiff: 0.0\n"
|
400 |
+
]
|
401 |
+
}
|
402 |
+
],
|
403 |
+
"source": [
|
404 |
+
"dprobs = (1.0/probs) * dlogprobs #we had to take the derivative of logprobs, which was 1/x --> d/dx(log(x)) = 1/x \n",
|
405 |
+
"#then we multiplied it with dlogprobs (the one we calculated before this for the chainrule)\n",
|
406 |
+
"\n",
|
407 |
+
"cmp('probs', dprobs, probs)"
|
408 |
+
]
|
409 |
+
},
|
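As a standalone sanity check of that chain-rule step (my own toy tensors, not the notebook's): the local derivative of `log(x)` is `1/x`, and multiplying by the upstream gradient gives the full gradient.

```python
import torch

probs_toy = (torch.rand(4, 3) + 0.1).requires_grad_(True)   # keep values away from 0
logprobs_toy = probs_toy.log()
upstream = torch.randn(4, 3)                                 # stand-in for dlogprobs
logprobs_toy.backward(upstream)

dprobs_toy = (1.0 / probs_toy.detach()) * upstream           # manual chain rule
print(torch.allclose(dprobs_toy, probs_toy.grad))            # True
```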
410 |
+
{
|
411 |
+
"cell_type": "markdown",
|
412 |
+
"metadata": {},
|
413 |
+
"source": [
|
414 |
+
"[20:56](https://youtu.be/q8SA3rM6ckI?si=sNM67lNSfsmUke2Y&t=1256) to [26:21](https://youtu.be/q8SA3rM6ckI?si=5MWGHdf1v-72g5ib&t=1581) `cmp('counts_sum_inv', dcounts_sum_inv, counts_sum_inv)`"
|
415 |
+
]
|
416 |
+
},
|
417 |
+
{
|
418 |
+
"cell_type": "code",
|
419 |
+
"execution_count": 31,
|
420 |
+
"metadata": {},
|
421 |
+
"outputs": [
|
422 |
+
{
|
423 |
+
"name": "stdout",
|
424 |
+
"output_type": "stream",
|
425 |
+
"text": [
|
426 |
+
"counts_sum_inv | exact: True | approximate: True | maxdiff: 0.0\n"
|
427 |
+
]
|
428 |
+
}
|
429 |
+
],
|
430 |
+
"source": [
|
431 |
+
"# probs = counts * counts_sum_inv, now here before we do the multiplication, take a look at the matrix dimensions using `.shape`\n",
|
432 |
+
"# You would see that `counts` would have 3x3 and `counts_sum_inv` will have 3x1\n",
|
433 |
+
"# So before the backpropagation calculation, there is 'broadcasting' happening where the value of b is been replicated/broadcasted multiple time across the matrix\n",
|
434 |
+
"\n",
|
435 |
+
"# Example\n",
|
436 |
+
"# c = a * b\n",
|
437 |
+
"# a[3x3] * b[3x1] ---->\n",
|
438 |
+
"# a[1,1]*b1 + a[1,2]*b1 + a[1,3]*b1\n",
|
439 |
+
"# a[2,1]*b2 + a[2,2]*b2 + a[2,3]*b2\n",
|
440 |
+
"# a[3,1]*b3 + a[3,2]*b3 + a[2,3]*b3\n",
|
441 |
+
"# ====> c[3x3]\n",
|
442 |
+
"\n",
|
443 |
+
"# The point of this is just to show that there are two things happening internally: The broadcasting and then the backpropagation\n",
|
444 |
+
"\n",
|
445 |
+
"# (first case) The derivative of c wrt b will be a\n",
|
446 |
+
"# So, here just `counts` will remain -> then `dprobs` is multiplied because chain rule.\n",
|
447 |
+
"# Finally, in order to make `dcounts_sum_inv` the same dimension as `counts_sum_inv` we sum all of them by 1 and also keepdims as true\n",
|
448 |
+
"\n",
|
449 |
+
"dcounts_sum_inv = (counts * dprobs).sum(1, keepdims=True) # So this is our final manually calcualted equation\n",
|
450 |
+
"\n",
|
451 |
+
"cmp('counts_sum_inv', dcounts_sum_inv, counts_sum_inv)"
|
452 |
+
]
|
453 |
+
},
|
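The broadcasting argument above can be checked on a tiny standalone example (my own made-up shapes, not the notebook's tensors): for `c = a * b` with `a` of shape [3, 3] and `b` of shape [3, 1], the gradient flowing into `b` is the row-wise sum of `a * dc`.

```python
import torch

a = torch.randn(3, 3)
b = torch.randn(3, 1, requires_grad=True)
c = a * b                                        # b is broadcast across the 3 columns
upstream = torch.randn(3, 3)                     # stand-in for dprobs
c.backward(upstream)

db_manual = (a * upstream).sum(1, keepdim=True)  # sum over the broadcast dimension
print(torch.allclose(db_manual, b.grad))         # True
```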
454 |
+
{
|
455 |
+
"cell_type": "markdown",
|
456 |
+
"metadata": {},
|
457 |
+
"source": [
|
458 |
+
"[26:26](https://youtu.be/q8SA3rM6ckI?si=TBwv2QkGmkp-d8JR&t=1586) to [27:56](https://youtu.be/q8SA3rM6ckI?si=awbZx9fZ_-WB_q5M&t=1676) first contribution of `counts`"
|
459 |
+
]
|
460 |
+
},
|
461 |
+
{
|
462 |
+
"cell_type": "code",
|
463 |
+
"execution_count": 32,
|
464 |
+
"metadata": {},
|
465 |
+
"outputs": [],
|
466 |
+
"source": [
|
467 |
+
"# Here we have to calculate the second half of `dcounts` i.e. (Second case) The derivative of c wrt a will be b\n",
|
468 |
+
"\n",
|
469 |
+
"dcounts = counts_sum_inv * dprobs\n",
|
470 |
+
"\n",
|
471 |
+
"#but we cant compare it yet as `counts` is later depended on top again as well, which we will check"
|
472 |
+
]
|
473 |
+
},
|
474 |
+
{
|
475 |
+
"cell_type": "markdown",
|
476 |
+
"metadata": {},
|
477 |
+
"source": [
|
478 |
+
"[27:57](https://youtu.be/q8SA3rM6ckI?si=APAFn28Pf8HVpbM3&t=1677) to [28:59](https://youtu.be/q8SA3rM6ckI?si=O5ja7cEm2xS_yuzN&t=1740) `cmp('counts_sum', dcounts_sum, counts_sum)`"
|
479 |
+
]
|
480 |
+
},
|
481 |
+
{
|
482 |
+
"cell_type": "code",
|
483 |
+
"execution_count": 33,
|
484 |
+
"metadata": {},
|
485 |
+
"outputs": [
|
486 |
+
{
|
487 |
+
"name": "stdout",
|
488 |
+
"output_type": "stream",
|
489 |
+
"text": [
|
490 |
+
"counts_sum | exact: True | approximate: True | maxdiff: 0.0\n"
|
491 |
+
]
|
492 |
+
}
|
493 |
+
],
|
494 |
+
"source": [
|
495 |
+
"# counts_sum_inv = counts_sum**-1\n",
|
496 |
+
"\n",
|
497 |
+
"# Okay so for this, the derivative of x^-1 is -(x^-2)\n",
|
498 |
+
"\n",
|
499 |
+
"dcounts_sum = (-counts_sum**-2) * dcounts_sum_inv #Remember for this its the one before the `26:26 to 27:56 first contribution of counts` section\n",
|
500 |
+
"\n",
|
501 |
+
"cmp('counts_sum', dcounts_sum, counts_sum)"
|
502 |
+
]
|
503 |
+
},
|
504 |
+
{
|
505 |
+
"cell_type": "markdown",
|
506 |
+
"metadata": {},
|
507 |
+
"source": [
|
508 |
+
"[29:00](https://youtu.be/q8SA3rM6ckI?si=UsxgAcBfiU5GAHaz&t=1740) to [32:26](https://youtu.be/q8SA3rM6ckI?si=nsXvTD-8RWvUAubq&t=1947) `cmp('counts', dcounts, counts)`"
|
509 |
+
]
|
510 |
+
},
|
511 |
+
{
|
512 |
+
"cell_type": "code",
|
513 |
+
"execution_count": 34,
|
514 |
+
"metadata": {},
|
515 |
+
"outputs": [
|
516 |
+
{
|
517 |
+
"name": "stdout",
|
518 |
+
"output_type": "stream",
|
519 |
+
"text": [
|
520 |
+
"counts | exact: True | approximate: True | maxdiff: 0.0\n"
|
521 |
+
]
|
522 |
+
}
|
523 |
+
],
|
524 |
+
"source": [
|
525 |
+
"# counts_sum = counts.sum(1, keepdims=True)\n",
|
526 |
+
"\n",
|
527 |
+
"# Now here we know the shape of counts_sum is 32 by 1 and the shape of counts is 32 by 27. So we need to broadcast counts_sum 27 times\n",
|
528 |
+
"# We are dirctly using a PyTorch function where it keeps adding numbers from `counts`\n",
|
529 |
+
"\n",
|
530 |
+
"dcounts += torch.ones_like(counts) * dcounts_sum #Also here we are adding `dcounts` as remember this is the second iteration of it, we had calculated one more value of it at the top\n",
|
531 |
+
"\n",
|
532 |
+
"cmp('counts', dcounts, counts)"
|
533 |
+
]
|
534 |
+
},
|
535 |
+
{
|
536 |
+
"cell_type": "markdown",
|
537 |
+
"metadata": {},
|
538 |
+
"source": [
|
539 |
+
"[32:27](https://youtu.be/q8SA3rM6ckI?si=nsXvTD-8RWvUAubq&t=1947) to [33:13](https://youtu.be/q8SA3rM6ckI?si=Ydk-b_pmKybrrnxe&t=1994) `cmp('norm_logits', dnorm_logits, norm_logits)`"
|
540 |
+
]
|
541 |
+
},
|
542 |
+
{
|
543 |
+
"cell_type": "code",
|
544 |
+
"execution_count": 35,
|
545 |
+
"metadata": {},
|
546 |
+
"outputs": [
|
547 |
+
{
|
548 |
+
"name": "stdout",
|
549 |
+
"output_type": "stream",
|
550 |
+
"text": [
|
551 |
+
"norm_logits | exact: True | approximate: True | maxdiff: 0.0\n"
|
552 |
+
]
|
553 |
+
}
|
554 |
+
],
|
555 |
+
"source": [
|
556 |
+
"# counts = norm_logits.exp()\n",
|
557 |
+
"\n",
|
558 |
+
"# Now here, the derivative of `norm_logits.exp()`, now the derivate of e^x is (famously) e^x, so its just `norm_logits.exp()` itself\n",
|
559 |
+
"# so we can also just write it as `counts` directly as it holds that value\n",
|
560 |
+
"\n",
|
561 |
+
"dnorm_logits = counts * dcounts\n",
|
562 |
+
"\n",
|
563 |
+
"cmp('norm_logits', dnorm_logits, norm_logits)"
|
564 |
+
]
|
565 |
+
},
|
566 |
+
{
|
567 |
+
"cell_type": "markdown",
|
568 |
+
"metadata": {},
|
569 |
+
"source": [
|
570 |
+
"[33:14](https://youtu.be/q8SA3rM6ckI?si=GIbBvHKGbW0RvlWf&t=1994) to [36:20](https://youtu.be/q8SA3rM6ckI?si=LGenDRNCeOVsWIkY&t=2180) `cmp('logit_maxes', dlogit_maxes, logit_maxes)`"
|
571 |
+
]
|
572 |
+
},
|
573 |
+
{
|
574 |
+
"cell_type": "code",
|
575 |
+
"execution_count": 36,
|
576 |
+
"metadata": {},
|
577 |
+
"outputs": [
|
578 |
+
{
|
579 |
+
"name": "stdout",
|
580 |
+
"output_type": "stream",
|
581 |
+
"text": [
|
582 |
+
"logit_maxes | exact: True | approximate: True | maxdiff: 0.0\n"
|
583 |
+
]
|
584 |
+
}
|
585 |
+
],
|
586 |
+
"source": [
|
587 |
+
"# norm_logits = logits - logit_maxes\n",
|
588 |
+
"\n",
|
589 |
+
"# Now here if you would look at the shape of all these variables, you would notice that there is internal broadcasting happening here (logit_maxes)\n",
|
590 |
+
"\n",
|
591 |
+
"dlogits = dnorm_logits.clone()\n",
|
592 |
+
"dlogit_maxes = (-dnorm_logits).sum(1, keepdim=True) #WILL HAVE TO REWATCH THIS PART AGAIN, DIDN'T COMPLETELY GET IT\n",
|
593 |
+
"\n",
|
594 |
+
"cmp('logit_maxes', dlogit_maxes, logit_maxes)"
|
595 |
+
]
|
596 |
+
},
|
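Since the cell above is flagged for a rewatch, here is a small standalone check of the same step (my own toy shapes, not the notebook's tensors): `logit_maxes` enters with a minus sign and is broadcast across the columns, so its gradient is the negated row-wise sum of `dnorm_logits`. In the full network it should also come out (close to) zero, because shifting all logits by a constant does not change the softmax.

```python
import torch

logits_toy = torch.randn(4, 5, requires_grad=True)
maxes = logits_toy.max(1, keepdim=True).values
maxes.retain_grad()                            # keep the gradient of this intermediate
norm = logits_toy - maxes                      # maxes is broadcast across the 5 columns
upstream = torch.randn(4, 5)                   # stand-in for dnorm_logits
norm.backward(upstream)

# manual: each row's max was used once per column, with a minus sign
dmaxes_manual = (-upstream).sum(1, keepdim=True)
print(torch.allclose(dmaxes_manual, maxes.grad))   # True
```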
597 |
+
{
|
598 |
+
"cell_type": "markdown",
|
599 |
+
"metadata": {},
|
600 |
+
"source": [
|
601 |
+
"[38:27](https://youtu.be/q8SA3rM6ckI?si=sVCg29V84Ua56x3H&t=2307) to [41:44](https://youtu.be/q8SA3rM6ckI?si=yHhzlWlaR9J4VBo_&t=2504) `cmp('logits', dlogits, logits)`"
|
602 |
+
]
|
603 |
+
},
|
604 |
+
{
|
605 |
+
"cell_type": "code",
|
606 |
+
"execution_count": 37,
|
607 |
+
"metadata": {},
|
608 |
+
"outputs": [
|
609 |
+
{
|
610 |
+
"name": "stdout",
|
611 |
+
"output_type": "stream",
|
612 |
+
"text": [
|
613 |
+
"logits | exact: True | approximate: True | maxdiff: 0.0\n"
|
614 |
+
]
|
615 |
+
}
|
616 |
+
],
|
617 |
+
"source": [
|
618 |
+
"# logit_maxes = logits.max(1, keepdim=True).values\n",
|
619 |
+
"\n",
|
620 |
+
"# Here, this step is similar to that of the first one in `dlogprobs` where we used torch.zeros_like() function\n",
|
621 |
+
"# So we are doing another alternative way of doing that\n",
|
622 |
+
"\n",
|
623 |
+
"dlogits += F.one_hot(logits.max(1).indices, num_classes=logits.shape[1]) * dlogit_maxes #Just remember the += here as we already have one dlogits above\n",
|
624 |
+
"\n",
|
625 |
+
"cmp('logits', dlogits, logits)"
|
626 |
+
]
|
627 |
+
},
|
628 |
+
{
|
629 |
+
"cell_type": "markdown",
|
630 |
+
"metadata": {},
|
631 |
+
"source": [
|
632 |
+
"[41:45](https://youtu.be/q8SA3rM6ckI?si=wJvhK8v1Hj2sEhc6&t=2505) to [53:25](https://youtu.be/q8SA3rM6ckI?si=xg15htmnJE03afh5&t=3216) `cmp('h', dh, h)`, `cmp('W2', dW2, W2)` and `cmp('b2', db2, b2)` - Bckpropagation through a linear layer\n",
|
633 |
+
"\n",
|
634 |
+
"( Till [49:56](https://youtu.be/q8SA3rM6ckI?si=nX-tCDJWXFHTgqi3&t=2996) had theoritical proofs on the matrix multiplication )"
|
635 |
+
]
|
636 |
+
},
|
637 |
+
{
|
638 |
+
"cell_type": "code",
|
639 |
+
"execution_count": null,
|
640 |
+
"metadata": {},
|
641 |
+
"outputs": [],
|
642 |
+
"source": [
|
643 |
+
"# # Linear layer 2\n",
|
644 |
+
"# logits = h @ W2 + b2 # output layer\n",
|
645 |
+
"\n",
|
646 |
+
"# in `b2` broadcasting is happening"
|
647 |
+
]
|
648 |
+
},
|
649 |
+
{
|
650 |
+
"cell_type": "code",
|
651 |
+
"execution_count": 19,
|
652 |
+
"metadata": {},
|
653 |
+
"outputs": [
|
654 |
+
{
|
655 |
+
"data": {
|
656 |
+
"text/plain": [
|
657 |
+
"(torch.Size([32, 27]),\n",
|
658 |
+
" torch.Size([32, 64]),\n",
|
659 |
+
" torch.Size([64, 27]),\n",
|
660 |
+
" torch.Size([27]))"
|
661 |
+
]
|
662 |
+
},
|
663 |
+
"execution_count": 19,
|
664 |
+
"metadata": {},
|
665 |
+
"output_type": "execute_result"
|
666 |
+
}
|
667 |
+
],
|
668 |
+
"source": [
|
669 |
+
"# Need these for understanding the matrix mulitplication why we are multiplying with what\n",
|
670 |
+
"dlogits.shape, h.shape, W2.shape, b2.shape"
|
671 |
+
]
|
672 |
+
},
|
673 |
+
{
|
674 |
+
"cell_type": "code",
|
675 |
+
"execution_count": 38,
|
676 |
+
"metadata": {},
|
677 |
+
"outputs": [
|
678 |
+
{
|
679 |
+
"name": "stdout",
|
680 |
+
"output_type": "stream",
|
681 |
+
"text": [
|
682 |
+
"h | exact: True | approximate: True | maxdiff: 0.0\n",
|
683 |
+
"W2 | exact: True | approximate: True | maxdiff: 0.0\n",
|
684 |
+
"b2 | exact: True | approximate: True | maxdiff: 0.0\n"
|
685 |
+
]
|
686 |
+
}
|
687 |
+
],
|
688 |
+
"source": [
|
689 |
+
"# watch the last few minutes, probably from 51 to see how he broke down this based on the matrix sizes\n",
|
690 |
+
"dh = dlogits @ W2.T\n",
|
691 |
+
"dW2 = h.T @ dlogits\n",
|
692 |
+
"db2 = dlogits.sum(0)\n",
|
693 |
+
"\n",
|
694 |
+
"cmp('h', dh, h)\n",
|
695 |
+
"cmp('W2', dW2, W2)\n",
|
696 |
+
"cmp('b2', db2, b2)"
|
697 |
+
]
|
698 |
+
},
|
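The three formulas above can be verified in isolation on a tiny linear layer (again my own toy shapes, not the notebook's): for `Y = X @ W + b`, `dX = dY @ W.T`, `dW = X.T @ dY`, and `db = dY.sum(0)`.

```python
import torch

# a tiny linear layer standing in for logits = h @ W2 + b2
X = torch.randn(8, 4, requires_grad=True)
W = torch.randn(4, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
Y = X @ W + b
dY = torch.randn(8, 3)                        # stand-in for dlogits
Y.backward(dY)

with torch.no_grad():
    print(torch.allclose(dY @ W.T, X.grad))   # dX = dY @ W.T
    print(torch.allclose(X.T @ dY, W.grad))   # dW = X.T @ dY
    print(torch.allclose(dY.sum(0), b.grad))  # db = dY.sum(0)
```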
699 |
+
{
|
700 |
+
"cell_type": "markdown",
|
701 |
+
"metadata": {},
|
702 |
+
"source": [
|
703 |
+
"[53:37](https://youtu.be/q8SA3rM6ckI?si=xASEEmeuBmpZwd6B&t=3217) to 55:12 `cmp('hpreact', dhpreact, hpreact)`"
|
704 |
+
]
|
705 |
+
},
|
706 |
+
{
|
707 |
+
"cell_type": "code",
|
708 |
+
"execution_count": 39,
|
709 |
+
"metadata": {},
|
710 |
+
"outputs": [
|
711 |
+
{
|
712 |
+
"name": "stdout",
|
713 |
+
"output_type": "stream",
|
714 |
+
"text": [
|
715 |
+
"hpreact | exact: True | approximate: True | maxdiff: 0.0\n"
|
716 |
+
]
|
717 |
+
}
|
718 |
+
],
|
719 |
+
"source": [
|
720 |
+
"# h = torch.tanh(hpreact) # hidden layer\n",
|
721 |
+
"\n",
|
722 |
+
"dhpreact = (1.0 - h**2)*dh #we saw that the derivative of tanh is also (1-a^2) where a was the external variable `a`, not the input `z` to tanh i.e. a = tanh(z)\n",
|
723 |
+
"\n",
|
724 |
+
"cmp('hpreact', dhpreact, hpreact)"
|
725 |
+
]
|
726 |
+
},
|
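A quick standalone check of the tanh derivative used above (a sketch of mine, not notebook code): with `a = tanh(z)`, `da/dz = 1 - a^2`, so the derivative can be written purely in terms of the output.

```python
import torch

z = torch.randn(6, requires_grad=True)
a = torch.tanh(z)
upstream = torch.randn(6)
a.backward(upstream)

dz_manual = (1.0 - a.detach()**2) * upstream   # (1 - a^2) * upstream
print(torch.allclose(dz_manual, z.grad))       # True
```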
727 |
+
{
|
728 |
+
"cell_type": "markdown",
|
729 |
+
"metadata": {},
|
730 |
+
"source": [
|
731 |
+
"[55:13](https://youtu.be/q8SA3rM6ckI?si=7v0ZQ9alRi52gD9s&t=3313) to 59:38 `cmp('bngain', dbngain, bngain)`"
|
732 |
+
]
|
733 |
+
},
|
734 |
+
{
|
735 |
+
"cell_type": "code",
|
736 |
+
"execution_count": 22,
|
737 |
+
"metadata": {},
|
738 |
+
"outputs": [
|
739 |
+
{
|
740 |
+
"data": {
|
741 |
+
"text/plain": [
|
742 |
+
"(torch.Size([32, 64]),\n",
|
743 |
+
" torch.Size([1, 64]),\n",
|
744 |
+
" torch.Size([1, 64]),\n",
|
745 |
+
" torch.Size([32, 64]))"
|
746 |
+
]
|
747 |
+
},
|
748 |
+
"execution_count": 22,
|
749 |
+
"metadata": {},
|
750 |
+
"output_type": "execute_result"
|
751 |
+
}
|
752 |
+
],
|
753 |
+
"source": [
|
754 |
+
"bnraw.shape, bngain.shape, bnbias.shape, dhpreact.shape"
|
755 |
+
]
|
756 |
+
},
|
757 |
+
{
|
758 |
+
"cell_type": "code",
|
759 |
+
"execution_count": 40,
|
760 |
+
"metadata": {},
|
761 |
+
"outputs": [
|
762 |
+
{
|
763 |
+
"name": "stdout",
|
764 |
+
"output_type": "stream",
|
765 |
+
"text": [
|
766 |
+
"bngain | exact: True | approximate: True | maxdiff: 0.0\n",
|
767 |
+
"bnbias | exact: True | approximate: True | maxdiff: 0.0\n",
|
768 |
+
"bnraw | exact: True | approximate: True | maxdiff: 0.0\n"
|
769 |
+
]
|
770 |
+
}
|
771 |
+
],
|
772 |
+
"source": [
|
773 |
+
"# hpreact = bngain * bnraw + bnbias\n",
|
774 |
+
"\n",
|
775 |
+
"dbngain = (bnraw * dhpreact).sum(0, keepdim=True) #because dbraw and dhpreact are 32by64, but dbngain expects 1by64 (we also keep the dimension)\n",
|
776 |
+
"dbnraw = (bngain * dhpreact)\n",
|
777 |
+
"dbnbias = (dhpreact).sum(0, keepdim=True) #because dhpreact is 32by64 but the dbnbias expects 1by64 (we also keep the dimension)\n",
|
778 |
+
"\n",
|
779 |
+
"cmp('bngain', dbngain, bngain)\n",
|
780 |
+
"cmp('bnbias', dbnbias, bnbias)\n",
|
781 |
+
"cmp('bnraw', dbnraw, bnraw)"
|
782 |
+
]
|
783 |
+
},
|
784 |
+
{
|
785 |
+
"cell_type": "markdown",
|
786 |
+
"metadata": {},
|
787 |
+
"source": [
|
788 |
+
"[59:40](https://youtu.be/q8SA3rM6ckI?si=RNb8T5WGla37958Q&t=3580) to 1:04:1 `cmp('bnvar_inv', dbnvar_inv, bnvar_inv)`"
|
789 |
+
]
|
790 |
+
},
|
791 |
+
{
|
792 |
+
"cell_type": "code",
|
793 |
+
"execution_count": null,
|
794 |
+
"metadata": {},
|
795 |
+
"outputs": [],
|
796 |
+
"source": [
|
797 |
+
"# From here we are working on the batch norm layer\n",
|
798 |
+
"# the code has been spread out and broken down to different parts (based on the equations on the \"bottom right corner box\" in the paper for batch norm - See prev lecture) inorder to perform manual backprop more easily"
|
799 |
+
]
|
800 |
+
},
|
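For reference while reading the next few cells, these are the batch-norm equations that the forward pass above implements, written out in the paper's notation (as I read the code: μ is `bnmeani`, σ² is `bnvar` with Bessel's correction, x̂ is `bnraw`, γ is `bngain`, β is `bnbias`, and ε = 1e-5):

```latex
\mu = \frac{1}{n}\sum_{i=1}^{n} x_i, \qquad
\sigma^2 = \frac{1}{n-1}\sum_{i=1}^{n} (x_i - \mu)^2, \qquad
\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \varepsilon}}, \qquad
y_i = \gamma \hat{x}_i + \beta
```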
801 |
+
{
|
802 |
+
"cell_type": "code",
|
803 |
+
"execution_count": 21,
|
804 |
+
"metadata": {},
|
805 |
+
"outputs": [
|
806 |
+
{
|
807 |
+
"data": {
|
808 |
+
"text/plain": [
|
809 |
+
"(torch.Size([32, 64]), torch.Size([32, 64]), torch.Size([1, 64]))"
|
810 |
+
]
|
811 |
+
},
|
812 |
+
"execution_count": 21,
|
813 |
+
"metadata": {},
|
814 |
+
"output_type": "execute_result"
|
815 |
+
}
|
816 |
+
],
|
817 |
+
"source": [
|
818 |
+
"bnraw.shape, bndiff.shape, bnvar_inv.shape"
|
819 |
+
]
|
820 |
+
},
|
821 |
+
{
|
822 |
+
"cell_type": "code",
|
823 |
+
"execution_count": 41,
|
824 |
+
"metadata": {},
|
825 |
+
"outputs": [
|
826 |
+
{
|
827 |
+
"name": "stdout",
|
828 |
+
"output_type": "stream",
|
829 |
+
"text": [
|
830 |
+
"bnvar_inv | exact: True | approximate: True | maxdiff: 0.0\n"
|
831 |
+
]
|
832 |
+
}
|
833 |
+
],
|
834 |
+
"source": [
|
835 |
+
"# bnraw = bndiff * bnvar_inv\n",
|
836 |
+
"\n",
|
837 |
+
"dbnvar_inv = (bndiff * dbnraw).sum(0, keepdim=True)\n",
|
838 |
+
"dbndiff = bnvar_inv * dbnraw #We will come back to this in 1:12:43 - (1)\n",
|
839 |
+
"\n",
|
840 |
+
"cmp('bnvar_inv', dbnvar_inv, bnvar_inv)"
|
841 |
+
]
|
842 |
+
},
|
843 |
+
{
|
844 |
+
"cell_type": "markdown",
|
845 |
+
"metadata": {},
|
846 |
+
"source": [
|
847 |
+
"[1:04:15](https://youtu.be/q8SA3rM6ckI?si=Mj6mc99YFmqYxo_l&t=3855) to 1:05:16 `cmp('bnvar', dbnvar, bnvar)`"
|
848 |
+
]
|
849 |
+
},
|
850 |
+
{
|
851 |
+
"cell_type": "code",
|
852 |
+
"execution_count": 42,
|
853 |
+
"metadata": {},
|
854 |
+
"outputs": [
|
855 |
+
{
|
856 |
+
"name": "stdout",
|
857 |
+
"output_type": "stream",
|
858 |
+
"text": [
|
859 |
+
"bnvar | exact: True | approximate: True | maxdiff: 0.0\n"
|
860 |
+
]
|
861 |
+
}
|
862 |
+
],
|
863 |
+
"source": [
|
864 |
+
"# bnvar_inv = (bnvar + 1e-5)**-0.5\n",
|
865 |
+
"#This is a direct equation of derivative of x^n so the output should be n*x^n-1\n",
|
866 |
+
"\n",
|
867 |
+
"dbnvar = (-0.5 * ((bnvar + 1e-5) ** (-1.5))) * dbnvar_inv\n",
|
868 |
+
"\n",
|
869 |
+
"cmp('bnvar', dbnvar, bnvar)"
|
870 |
+
]
|
871 |
+
},
|
872 |
+
{
|
873 |
+
"cell_type": "markdown",
|
874 |
+
"metadata": {},
|
875 |
+
"source": [
|
876 |
+
"[1:05:17](https://youtu.be/q8SA3rM6ckI?si=vjAXVF6w3BoZMC04&t=3917) to 1:09:01 - Why he implemented the bessel's correction (as there seem to be some problem/issue in the paper. Using Bias during training time and Unbiased during testing). But we prefer to use Unbiased during both training and testing and that is what we went ahead with."
|
877 |
+
]
|
878 |
+
},
|
879 |
+
{
|
880 |
+
"cell_type": "markdown",
|
881 |
+
"metadata": {},
|
882 |
+
"source": [
|
883 |
+
"[1:09:02](https://youtu.be/q8SA3rM6ckI?si=WxOg7f0S-mqLiZfD&t=4142) to 1:12:42 `cmp('bndiff2', dbndiff2, bndiff2)`"
|
884 |
+
]
|
885 |
+
},
|
886 |
+
{
|
887 |
+
"cell_type": "code",
|
888 |
+
"execution_count": 43,
|
889 |
+
"metadata": {},
|
890 |
+
"outputs": [
|
891 |
+
{
|
892 |
+
"name": "stdout",
|
893 |
+
"output_type": "stream",
|
894 |
+
"text": [
|
895 |
+
"bndiff2 | exact: True | approximate: True | maxdiff: 0.0\n"
|
896 |
+
]
|
897 |
+
}
|
898 |
+
],
|
899 |
+
"source": [
|
900 |
+
"# bnvar = 1/(n-1)*(bndiff2).sum(0, keepdim=True)\n",
|
901 |
+
"\n",
|
902 |
+
"dbndiff2 = 1/(n-1) * torch.ones_like(bndiff2) * dbnvar\n",
|
903 |
+
"\n",
|
904 |
+
"cmp('bndiff2', dbndiff2, bndiff2)"
|
905 |
+
]
|
906 |
+
},
|
907 |
+
{
|
908 |
+
"cell_type": "markdown",
|
909 |
+
"metadata": {},
|
910 |
+
"source": [
|
911 |
+
"[1:12:43](https://youtu.be/q8SA3rM6ckI?si=HkT46KjpcZoit33H&t=4363) to 1:13:58 `cmp('bndiff', dbndiff, bndiff)`"
|
912 |
+
]
|
913 |
+
},
|
914 |
+
{
|
915 |
+
"cell_type": "code",
|
916 |
+
"execution_count": 44,
|
917 |
+
"metadata": {},
|
918 |
+
"outputs": [
|
919 |
+
{
|
920 |
+
"name": "stdout",
|
921 |
+
"output_type": "stream",
|
922 |
+
"text": [
|
923 |
+
"bndiff | exact: True | approximate: True | maxdiff: 0.0\n"
|
924 |
+
]
|
925 |
+
}
|
926 |
+
],
|
927 |
+
"source": [
|
928 |
+
"# bndiff2 = bndiff**2\n",
|
929 |
+
"\n",
|
930 |
+
"dbndiff += 2*bndiff * dbndiff2 #This is the (2)nd occurance of dbndiff - 59:40 so, we add it here\n",
|
931 |
+
"\n",
|
932 |
+
"cmp('bndiff', dbndiff, bndiff)"
|
933 |
+
]
|
934 |
+
},
|
935 |
+
{
|
936 |
+
"cell_type": "markdown",
|
937 |
+
"metadata": {},
|
938 |
+
"source": [
|
939 |
+
"[1:13:59](https://youtu.be/q8SA3rM6ckI?si=t03BQ_sro2n6X0a2&t=4439) to 1:18:35 `cmp('bnmeani', dbnmeani, bnmeani)` and `cmp('hprebn', dhprebn, hprebn)`"
|
940 |
+
]
|
941 |
+
},
|
942 |
+
{
|
943 |
+
"cell_type": "code",
|
944 |
+
"execution_count": 45,
|
945 |
+
"metadata": {},
|
946 |
+
"outputs": [
|
947 |
+
{
|
948 |
+
"name": "stdout",
|
949 |
+
"output_type": "stream",
|
950 |
+
"text": [
|
951 |
+
"bnmeani | exact: True | approximate: True | maxdiff: 0.0\n",
|
952 |
+
"hprebn | exact: True | approximate: True | maxdiff: 0.0\n"
|
953 |
+
]
|
954 |
+
}
|
955 |
+
],
|
956 |
+
"source": [
|
957 |
+
"## Please go thorugh this one again, i didnt completely get it\n",
|
958 |
+
"\n",
|
959 |
+
"# bnmeani = 1/n*hprebn.sum(0, keepdim=True)\n",
|
960 |
+
"# bndiff = hprebn - bnmeani\n",
|
961 |
+
"\n",
|
962 |
+
"dhprebn = dbndiff.clone() #we are making a copy of it\n",
|
963 |
+
"dbnmeani = (-dbndiff).sum(0)\n",
|
964 |
+
"\n",
|
965 |
+
"dhprebn += (1.0/n)*(torch.ones_like(hprebn) * dbnmeani)\n",
|
966 |
+
"\n",
|
967 |
+
"cmp('bnmeani', dbnmeani, bnmeani)\n",
|
968 |
+
"cmp('hprebn', dhprebn, hprebn)"
|
969 |
+
]
|
970 |
+
},
|
971 |
+
{
|
972 |
+
"cell_type": "markdown",
|
973 |
+
"metadata": {},
|
974 |
+
"source": [
|
975 |
+
"[1:18:36](https://youtu.be/q8SA3rM6ckI?si=j_uFOOB3AsbrkbwM&t=4716) to 1:20:34 `cmp('embcat', dembcat, embcat)`, `cmp('W1', dW1, W1)` and `cmp('b1', db1, b1)`"
|
976 |
+
]
|
977 |
+
},
|
978 |
+
{
|
979 |
+
"cell_type": "code",
|
980 |
+
"execution_count": 47,
|
981 |
+
"metadata": {},
|
982 |
+
"outputs": [
|
983 |
+
{
|
984 |
+
"data": {
|
985 |
+
"text/plain": [
|
986 |
+
"(torch.Size([32, 64]),\n",
|
987 |
+
" torch.Size([32, 30]),\n",
|
988 |
+
" torch.Size([30, 64]),\n",
|
989 |
+
" torch.Size([64]))"
|
990 |
+
]
|
991 |
+
},
|
992 |
+
"execution_count": 47,
|
993 |
+
"metadata": {},
|
994 |
+
"output_type": "execute_result"
|
995 |
+
}
|
996 |
+
],
|
997 |
+
"source": [
|
998 |
+
"hprebn.shape, embcat.shape, W1.shape, b1.shape"
|
999 |
+
]
|
1000 |
+
},
|
1001 |
+
{
|
1002 |
+
"cell_type": "code",
|
1003 |
+
"execution_count": 49,
|
1004 |
+
"metadata": {},
|
1005 |
+
"outputs": [
|
1006 |
+
{
|
1007 |
+
"name": "stdout",
|
1008 |
+
"output_type": "stream",
|
1009 |
+
"text": [
|
1010 |
+
"embcat | exact: True | approximate: True | maxdiff: 0.0\n",
|
1011 |
+
"W1 | exact: True | approximate: True | maxdiff: 0.0\n",
|
1012 |
+
"b1 | exact: True | approximate: True | maxdiff: 0.0\n"
|
1013 |
+
]
|
1014 |
+
}
|
1015 |
+
],
|
1016 |
+
"source": [
|
1017 |
+
"# Forward pass: hprebn = embcat @ W1 + b1\n",
|
1018 |
+
"\n",
|
1019 |
+
"dembcat = dhprebn @ W1.T\n",
|
1020 |
+
"dW1 = embcat.T @ dhprebn\n",
|
1021 |
+
"db1 = dhprebn.sum(0)\n",
|
1022 |
+
"\n",
|
1023 |
+
"cmp('embcat', dembcat, embcat)\n",
|
1024 |
+
"cmp('W1', dW1, W1)\n",
|
1025 |
+
"cmp('b1', db1, b1)"
|
1026 |
+
]
|
1027 |
+
},
|
1028 |
+
{
|
1029 |
+
"cell_type": "markdown",
|
1030 |
+
"metadata": {},
|
1031 |
+
"source": [
|
1032 |
+
"[1:20:35](https://youtu.be/q8SA3rM6ckI?si=F8arFi8ee8a9eAvv&t=4835) to 1:21:58 `cmp('emb', demb, emb)`"
|
1033 |
+
]
|
1034 |
+
},
|
1035 |
+
{
|
1036 |
+
"cell_type": "code",
|
1037 |
+
"execution_count": 50,
|
1038 |
+
"metadata": {},
|
1039 |
+
"outputs": [
|
1040 |
+
{
|
1041 |
+
"name": "stdout",
|
1042 |
+
"output_type": "stream",
|
1043 |
+
"text": [
|
1044 |
+
"emb | exact: True | approximate: True | maxdiff: 0.0\n"
|
1045 |
+
]
|
1046 |
+
}
|
1047 |
+
],
|
1048 |
+
"source": [
|
1049 |
+
"## Please rewatch this as well\n",
|
1050 |
+
"\n",
|
1051 |
+
"# embcat = emb.view(emb.shape[0], -1)\n",
|
1052 |
+
"\n",
|
1053 |
+
"demb = dembcat.view(emb.shape)\n",
|
1054 |
+
"\n",
|
1055 |
+
"cmp('emb', demb, emb)"
|
1056 |
+
]
|
1057 |
+
},
|
1058 |
+
{
|
1059 |
+
"cell_type": "markdown",
|
1060 |
+
"metadata": {},
|
1061 |
+
"source": [
|
1062 |
+
"[1:21:59](https://youtu.be/q8SA3rM6ckI?si=cPimgFWzBgjrkpAr&t=4919) to `cmp('C', dC, C)`"
|
1063 |
+
]
|
1064 |
+
},
|
1065 |
+
{
|
1066 |
+
"cell_type": "code",
|
1067 |
+
"execution_count": 51,
|
1068 |
+
"metadata": {},
|
1069 |
+
"outputs": [
|
1070 |
+
{
|
1071 |
+
"name": "stdout",
|
1072 |
+
"output_type": "stream",
|
1073 |
+
"text": [
|
1074 |
+
"C | exact: True | approximate: True | maxdiff: 0.0\n"
|
1075 |
+
]
|
1076 |
+
}
|
1077 |
+
],
|
1078 |
+
"source": [
|
1079 |
+
"## Please rewatch this as well\n",
|
1080 |
+
"# emb = C[Xb]\n",
|
1081 |
+
"\n",
|
1082 |
+
"dC = torch.zeros_like(C)\n",
|
1083 |
+
"for k in range(Xb.shape[0]):\n",
|
1084 |
+
" for j in range(Xb.shape[1]):\n",
|
1085 |
+
" ix = Xb[k,j]\n",
|
1086 |
+
" dC[ix] += demb[k,j]\n",
|
1087 |
+
"\n",
|
1088 |
+
"cmp('C', dC, C)"
|
1089 |
+
]
|
1090 |
+
},
|
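The double loop above is a scatter-add: every occurrence of a character index in `Xb` accumulates its embedding gradient into the corresponding row of `C`. A hypothetical vectorized equivalent (my own sketch on toy tensors, not the notebook's code) uses `index_add_`:

```python
import torch

# toy shapes standing in for the notebook's C [vocab, emb], Xb [batch, block], demb
vocab, emb, batch, block = 27, 10, 32, 3
C_toy = torch.zeros(vocab, emb)
Xb_toy = torch.randint(0, vocab, (batch, block))
demb_toy = torch.randn(batch, block, emb)

# explicit double loop (what the cell above does)
dC_loop = torch.zeros_like(C_toy)
for k in range(batch):
    for j in range(block):
        dC_loop[Xb_toy[k, j]] += demb_toy[k, j]

# vectorized scatter-add: same result
dC_vec = torch.zeros_like(C_toy)
dC_vec.index_add_(0, Xb_toy.view(-1), demb_toy.view(-1, emb))
print(torch.allclose(dC_loop, dC_vec))   # True
```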
1091 |
+
{
|
1092 |
+
"cell_type": "markdown",
|
1093 |
+
"metadata": {},
|
1094 |
+
"source": [
|
1095 |
+
"And we are done with the first exercise!!"
|
1096 |
+
]
|
1097 |
+
},
|
1098 |
+
{
|
1099 |
+
"cell_type": "code",
|
1100 |
+
"execution_count": null,
|
1101 |
+
"metadata": {
|
1102 |
+
"id": "mO-8aqxK8PPw"
|
1103 |
+
},
|
1104 |
+
"outputs": [],
|
1105 |
+
"source": [
|
1106 |
+
"# Exercise 1: backprop through the whole thing manually,\n",
|
1107 |
+
"# backpropagating through exactly all of the variables\n",
|
1108 |
+
"# as they are defined in the forward pass above, one by one\n",
|
1109 |
+
"\n",
|
1110 |
+
"# -----------------\n",
|
1111 |
+
"# YOUR CODE HERE :)\n",
|
1112 |
+
"# -----------------\n",
|
1113 |
+
"\n",
|
1114 |
+
"# cmp('logprobs', dlogprobs, logprobs)\n",
|
1115 |
+
"# cmp('probs', dprobs, probs)\n",
|
1116 |
+
"# cmp('counts_sum_inv', dcounts_sum_inv, counts_sum_inv)\n",
|
1117 |
+
"# cmp('counts_sum', dcounts_sum, counts_sum)\n",
|
1118 |
+
"# cmp('counts', dcounts, counts)\n",
|
1119 |
+
"# cmp('norm_logits', dnorm_logits, norm_logits)\n",
|
1120 |
+
"# cmp('logit_maxes', dlogit_maxes, logit_maxes)\n",
|
1121 |
+
"# cmp('logits', dlogits, logits)\n",
|
1122 |
+
"# cmp('h', dh, h)\n",
|
1123 |
+
"# cmp('W2', dW2, W2)\n",
|
1124 |
+
"# cmp('b2', db2, b2)\n",
|
1125 |
+
"# cmp('hpreact', dhpreact, hpreact)\n",
|
1126 |
+
"# cmp('bngain', dbngain, bngain)\n",
|
1127 |
+
"# cmp('bnbias', dbnbias, bnbias)\n",
|
1128 |
+
"# cmp('bnraw', dbnraw, bnraw)\n",
|
1129 |
+
"# cmp('bnvar_inv', dbnvar_inv, bnvar_inv)\n",
|
1130 |
+
"# cmp('bnvar', dbnvar, bnvar)\n",
|
1131 |
+
"# cmp('bndiff2', dbndiff2, bndiff2)\n",
|
1132 |
+
"# cmp('bndiff', dbndiff, bndiff)\n",
|
1133 |
+
"# cmp('bnmeani', dbnmeani, bnmeani)\n",
|
1134 |
+
"# cmp('hprebn', dhprebn, hprebn)\n",
|
1135 |
+
"# cmp('embcat', dembcat, embcat)\n",
|
1136 |
+
"# cmp('W1', dW1, W1)\n",
|
1137 |
+
"# cmp('b1', db1, b1)\n",
|
1138 |
+
"# cmp('emb', demb, emb)\n",
|
1139 |
+
"# cmp('C', dC, C)"
|
1140 |
+
]
|
1141 |
+
}
|
1142 |
+
],
|
1143 |
+
"metadata": {
|
1144 |
+
"colab": {
|
1145 |
+
"provenance": []
|
1146 |
+
},
|
1147 |
+
"kernelspec": {
|
1148 |
+
"display_name": "venv",
|
1149 |
+
"language": "python",
|
1150 |
+
"name": "python3"
|
1151 |
+
},
|
1152 |
+
"language_info": {
|
1153 |
+
"codemirror_mode": {
|
1154 |
+
"name": "ipython",
|
1155 |
+
"version": 3
|
1156 |
+
},
|
1157 |
+
"file_extension": ".py",
|
1158 |
+
"mimetype": "text/x-python",
|
1159 |
+
"name": "python",
|
1160 |
+
"nbconvert_exporter": "python",
|
1161 |
+
"pygments_lexer": "ipython3",
|
1162 |
+
"version": "3.10.0"
|
1163 |
+
}
|
1164 |
+
},
|
1165 |
+
"nbformat": 4,
|
1166 |
+
"nbformat_minor": 0
|
1167 |
+
}
|
exercise-2.ipynb
ADDED
@@ -0,0 +1,432 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"# This exercise wasn't exactly smooth sailing for me, but I did try to understand most of it. Will try to come back to this whenever I can"
|
10 |
+
]
|
11 |
+
},
|
12 |
+
{
|
13 |
+
"cell_type": "code",
|
14 |
+
"execution_count": null,
|
15 |
+
"metadata": {},
|
16 |
+
"outputs": [],
|
17 |
+
"source": [
|
18 |
+
"# there no change change in the first several cells from last lecture\n",
|
19 |
+
"\n",
|
20 |
+
"import torch\n",
|
21 |
+
"import torch.nn.functional as F\n",
|
22 |
+
"import matplotlib.pyplot as plt # for making figures\n",
|
23 |
+
"%matplotlib inline"
|
24 |
+
]
|
25 |
+
},
|
26 |
+
{
|
27 |
+
"cell_type": "code",
|
28 |
+
"execution_count": null,
|
29 |
+
"metadata": {},
|
30 |
+
"outputs": [],
|
31 |
+
"source": [
|
32 |
+
"# download the names.txt file from github\n",
|
33 |
+
"!wget https://raw.githubusercontent.com/karpathy/makemore/master/names.txt"
|
34 |
+
]
|
35 |
+
},
|
36 |
+
{
|
37 |
+
"cell_type": "code",
|
38 |
+
"execution_count": null,
|
39 |
+
"metadata": {},
|
40 |
+
"outputs": [],
|
41 |
+
"source": [
|
42 |
+
"# read in all the words\n",
|
43 |
+
"words = open('names.txt', 'r').read().splitlines()\n",
|
44 |
+
"# print(len(words))\n",
|
45 |
+
"# print(max(len(w) for w in words))\n",
|
46 |
+
"# print(words[:8])\n",
|
47 |
+
"\n",
|
48 |
+
"# build the vocabulary of characters and mappings to/from integers\n",
|
49 |
+
"chars = sorted(list(set(''.join(words))))\n",
|
50 |
+
"stoi = {s:i+1 for i,s in enumerate(chars)}\n",
|
51 |
+
"stoi['.'] = 0\n",
|
52 |
+
"itos = {i:s for s,i in stoi.items()}\n",
|
53 |
+
"vocab_size = len(itos)\n",
|
54 |
+
"# print(itos)\n",
|
55 |
+
"# print(vocab_size)\n",
|
56 |
+
"\n",
|
57 |
+
"# build the dataset\n",
|
58 |
+
"block_size = 3 # context length: how many characters do we take to predict the next one?\n",
|
59 |
+
"\n",
|
60 |
+
"def build_dataset(words):\n",
|
61 |
+
" X, Y = [], []\n",
|
62 |
+
"\n",
|
63 |
+
" for w in words:\n",
|
64 |
+
" context = [0] * block_size\n",
|
65 |
+
" for ch in w + '.':\n",
|
66 |
+
" ix = stoi[ch]\n",
|
67 |
+
" X.append(context)\n",
|
68 |
+
" Y.append(ix)\n",
|
69 |
+
" context = context[1:] + [ix] # crop and append\n",
|
70 |
+
"\n",
|
71 |
+
" X = torch.tensor(X)\n",
|
72 |
+
" Y = torch.tensor(Y)\n",
|
73 |
+
" # print(X.shape, Y.shape)\n",
|
74 |
+
" return X, Y\n",
|
75 |
+
"\n",
|
76 |
+
"import random\n",
|
77 |
+
"random.seed(42)\n",
|
78 |
+
"random.shuffle(words)\n",
|
79 |
+
"n1 = int(0.8*len(words))\n",
|
80 |
+
"n2 = int(0.9*len(words))\n",
|
81 |
+
"\n",
|
82 |
+
"Xtr, Ytr = build_dataset(words[:n1]) # 80%\n",
|
83 |
+
"Xdev, Ydev = build_dataset(words[n1:n2]) # 10%\n",
|
84 |
+
"Xte, Yte = build_dataset(words[n2:]) # 10%"
|
85 |
+
]
|
86 |
+
},
|
87 |
+
{
|
88 |
+
"cell_type": "code",
|
89 |
+
"execution_count": 3,
|
90 |
+
"metadata": {},
|
91 |
+
"outputs": [],
|
92 |
+
"source": [
|
93 |
+
"# utility function we will use later when comparing manual gradients to PyTorch gradients\n",
|
94 |
+
"def cmp(s, dt, t):\n",
|
95 |
+
" ex = torch.all(dt == t.grad).item()\n",
|
96 |
+
" app = torch.allclose(dt, t.grad)\n",
|
97 |
+
" maxdiff = (dt - t.grad).abs().max().item()\n",
|
98 |
+
" print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')"
|
99 |
+
]
|
100 |
+
},
|
101 |
+
{
|
102 |
+
"cell_type": "code",
|
103 |
+
"execution_count": 4,
|
104 |
+
"metadata": {},
|
105 |
+
"outputs": [
|
106 |
+
{
|
107 |
+
"name": "stdout",
|
108 |
+
"output_type": "stream",
|
109 |
+
"text": [
|
110 |
+
"4137\n"
|
111 |
+
]
|
112 |
+
}
|
113 |
+
],
|
114 |
+
"source": [
|
115 |
+
"n_embd = 10 # the dimensionality of the character embedding vectors\n",
|
116 |
+
"n_hidden = 64 # the number of neurons in the hidden layer of the MLP\n",
|
117 |
+
"\n",
|
118 |
+
"g = torch.Generator().manual_seed(2147483647) # for reproducibility\n",
|
119 |
+
"C = torch.randn((vocab_size, n_embd), generator=g)\n",
|
120 |
+
"# Layer 1\n",
|
121 |
+
"W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5/3)/((n_embd * block_size)**0.5)\n",
|
122 |
+
"b1 = torch.randn(n_hidden, generator=g) * 0.1 # using b1 just for fun, it's useless because of BN\n",
|
123 |
+
"# Layer 2\n",
|
124 |
+
"W2 = torch.randn((n_hidden, vocab_size), generator=g) * 0.1\n",
|
125 |
+
"b2 = torch.randn(vocab_size, generator=g) * 0.1\n",
|
126 |
+
"# BatchNorm parameters\n",
|
127 |
+
"bngain = torch.randn((1, n_hidden))*0.1 + 1.0\n",
|
128 |
+
"bnbias = torch.randn((1, n_hidden))*0.1\n",
|
129 |
+
"\n",
|
130 |
+
"# Note: I am initializating many of these parameters in non-standard ways\n",
|
131 |
+
"# because sometimes initializating with e.g. all zeros could mask an incorrect\n",
|
132 |
+
"# implementation of the backward pass.\n",
|
133 |
+
"\n",
|
134 |
+
"parameters = [C, W1, b1, W2, b2, bngain, bnbias]\n",
|
135 |
+
"print(sum(p.nelement() for p in parameters)) # number of parameters in total\n",
|
136 |
+
"for p in parameters:\n",
|
137 |
+
" p.requires_grad = True"
|
138 |
+
]
|
139 |
+
},
|
140 |
+
{
|
141 |
+
"cell_type": "code",
|
142 |
+
"execution_count": 5,
|
143 |
+
"metadata": {},
|
144 |
+
"outputs": [],
|
145 |
+
"source": [
|
146 |
+
"batch_size = 32\n",
|
147 |
+
"n = batch_size # a shorter variable also, for convenience\n",
|
148 |
+
"# construct a minibatch\n",
|
149 |
+
"ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)\n",
|
150 |
+
"Xb, Yb = Xtr[ix], Ytr[ix] # batch X,Y"
|
151 |
+
]
|
152 |
+
},
|
153 |
+
{
|
154 |
+
"cell_type": "code",
|
155 |
+
"execution_count": 6,
|
156 |
+
"metadata": {},
|
157 |
+
"outputs": [
|
158 |
+
{
|
159 |
+
"data": {
|
160 |
+
"text/plain": [
|
161 |
+
"tensor(3.3596, grad_fn=<NegBackward0>)"
|
162 |
+
]
|
163 |
+
},
|
164 |
+
"execution_count": 6,
|
165 |
+
"metadata": {},
|
166 |
+
"output_type": "execute_result"
|
167 |
+
}
|
168 |
+
],
|
169 |
+
"source": [
|
170 |
+
"# forward pass, \"chunkated\" into smaller steps that are possible to backward one at a time\n",
|
171 |
+
"\n",
|
172 |
+
"emb = C[Xb] # embed the characters into vectors\n",
|
173 |
+
"embcat = emb.view(emb.shape[0], -1) # concatenate the vectors\n",
|
174 |
+
"# Linear layer 1\n",
|
175 |
+
"hprebn = embcat @ W1 + b1 # hidden layer pre-activation\n",
|
176 |
+
"# BatchNorm layer\n",
|
177 |
+
"bnmeani = 1/n*hprebn.sum(0, keepdim=True)\n",
|
178 |
+
"bndiff = hprebn - bnmeani\n",
|
179 |
+
"bndiff2 = bndiff**2\n",
|
180 |
+
"bnvar = 1/(n-1)*(bndiff2).sum(0, keepdim=True) # note: Bessel's correction (dividing by n-1, not n)\n",
|
181 |
+
"bnvar_inv = (bnvar + 1e-5)**-0.5\n",
|
182 |
+
"bnraw = bndiff * bnvar_inv\n",
|
183 |
+
"hpreact = bngain * bnraw + bnbias\n",
|
184 |
+
"# Non-linearity\n",
|
185 |
+
"h = torch.tanh(hpreact) # hidden layer\n",
|
186 |
+
"# Linear layer 2\n",
|
187 |
+
"logits = h @ W2 + b2 # output layer\n",
|
188 |
+
"# cross entropy loss (same as F.cross_entropy(logits, Yb))\n",
|
189 |
+
"logit_maxes = logits.max(1, keepdim=True).values\n",
|
190 |
+
"norm_logits = logits - logit_maxes # subtract max for numerical stability\n",
|
191 |
+
"counts = norm_logits.exp()\n",
|
192 |
+
"counts_sum = counts.sum(1, keepdims=True)\n",
|
193 |
+
"counts_sum_inv = counts_sum**-1 # if I use (1.0 / counts_sum) instead then I can't get backprop to be bit exact...\n",
|
194 |
+
"probs = counts * counts_sum_inv\n",
|
195 |
+
"logprobs = probs.log()\n",
|
196 |
+
"loss = -logprobs[range(n), Yb].mean()\n",
|
197 |
+
"\n",
|
198 |
+
"# PyTorch backward pass\n",
|
199 |
+
"for p in parameters:\n",
|
200 |
+
" p.grad = None\n",
|
201 |
+
"for t in [logprobs, probs, counts, counts_sum, counts_sum_inv, # afaik there is no cleaner way\n",
|
202 |
+
" norm_logits, logit_maxes, logits, h, hpreact, bnraw,\n",
|
203 |
+
" bnvar_inv, bnvar, bndiff2, bndiff, hprebn, bnmeani,\n",
|
204 |
+
" embcat, emb]:\n",
|
205 |
+
" t.retain_grad()\n",
|
206 |
+
"loss.backward()\n",
|
207 |
+
"loss"
|
208 |
+
]
|
209 |
+
},
|
210 |
+
{
|
211 |
+
"cell_type": "markdown",
|
212 |
+
"metadata": {},
|
213 |
+
"source": [
|
214 |
+
"Similar boiler plate codes as done in the prev exercise and provided in the starter code^\n",
|
215 |
+
"\n",
|
216 |
+
"------------"
|
217 |
+
]
|
218 |
+
},
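Written out in math, the chunked forward pass above is just the usual minibatch cross entropy computed through a numerically stabilised softmax (my own summary of the code, in its notation):

$$
p_{ij} = \frac{e^{\,\ell_{ij} - m_i}}{\sum_{k} e^{\,\ell_{ik} - m_i}},
\qquad
\text{loss} = -\frac{1}{n}\sum_{i=1}^{n}\log p_{i,\,y_i}
$$

where $\ell$ = `logits`, $m_i$ = `logit_maxes`, $p$ = `probs`, $y$ = `Yb` and $n$ is the batch size.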
|
219 |
+
{
|
220 |
+
"cell_type": "code",
|
221 |
+
"execution_count": null,
|
222 |
+
"metadata": {},
|
223 |
+
"outputs": [],
|
224 |
+
"source": [
|
225 |
+
"# Exercise 2: backprop through cross_entropy but all in one go\n",
|
226 |
+
"# to complete this challenge look at the mathematical expression of the loss,\n",
|
227 |
+
"# take the derivative, simplify the expression, and just write it out\n",
|
228 |
+
"\n",
|
229 |
+
"# forward pass\n",
|
230 |
+
"\n",
|
231 |
+
"# before:\n",
|
232 |
+
"# logit_maxes = logits.max(1, keepdim=True).values\n",
|
233 |
+
"# norm_logits = logits - logit_maxes # subtract max for numerical stability\n",
|
234 |
+
"# counts = norm_logits.exp()\n",
|
235 |
+
"# counts_sum = counts.sum(1, keepdims=True)\n",
|
236 |
+
"# counts_sum_inv = counts_sum**-1 # if I use (1.0 / counts_sum) instead then I can't get backprop to be bit exact...\n",
|
237 |
+
"# probs = counts * counts_sum_inv\n",
|
238 |
+
"# logprobs = probs.log()\n",
|
239 |
+
"# loss = -logprobs[range(n), Yb].mean()\n",
|
240 |
+
"\n",
|
241 |
+
"# now:\n",
|
242 |
+
"# loss_fast = F.cross_entropy(logits, Yb)\n",
|
243 |
+
"# print(loss_fast.item(), 'diff:', (loss_fast - loss).item())"
|
244 |
+
]
|
245 |
+
},
|
246 |
+
{
|
247 |
+
"cell_type": "markdown",
|
248 |
+
"metadata": {},
|
249 |
+
"source": [
|
250 |
+
"In the above example we are seeing how the forward pass is broken down if we do the manual breakdown of calculation vs just directly using PyTorch"
|
251 |
+
]
|
252 |
+
},
|
253 |
+
{
|
254 |
+
"cell_type": "markdown",
|
255 |
+
"metadata": {},
|
256 |
+
"source": [
|
257 |
+
"[1:28:34](https://youtu.be/q8SA3rM6ckI?si=O-RCp2YO7QbSbUIW&t=5314) to 1:32:48 - Andrej sensei gives us an hint followed with an explaination of solving the equation and convert that to code"
|
258 |
+
]
|
259 |
+
},
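For reference, the end result of that derivation (in the code's notation): differentiating the loss above with respect to the logits gives

$$
\frac{\partial\,\text{loss}}{\partial \ell_{ij}}
= \frac{1}{n}\Big(\operatorname{softmax}(\ell)_{ij} - \mathbf{1}[\,j = y_i\,]\Big)
$$

i.e. take the softmax of the logits, subtract 1 at each example's correct label, and divide by $n$, which is exactly what the next cell implements.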
|
260 |
+
{
|
261 |
+
"cell_type": "code",
|
262 |
+
"execution_count": 7,
|
263 |
+
"metadata": {},
|
264 |
+
"outputs": [
|
265 |
+
{
|
266 |
+
"name": "stdout",
|
267 |
+
"output_type": "stream",
|
268 |
+
"text": [
|
269 |
+
"logits | exact: False | approximate: True | maxdiff: 8.381903171539307e-09\n"
|
270 |
+
]
|
271 |
+
}
|
272 |
+
],
|
273 |
+
"source": [
|
274 |
+
"# backward pass\n",
|
275 |
+
"\n",
|
276 |
+
"dlogits = F.softmax(logits, 1)\n",
|
277 |
+
"dlogits[range(n), Yb] -= 1\n",
|
278 |
+
"dlogits /= n\n",
|
279 |
+
"\n",
|
280 |
+
"cmp('logits', dlogits, logits)\n",
|
281 |
+
"\n",
|
282 |
+
"#This wasnt exactly very clear to me, but i will come back to this\n",
|
283 |
+
"#Also my output came slightly bigger than sensei's though"
|
284 |
+
]
|
285 |
+
},
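A tiny worked example of my own (hypothetical numbers, not from the lecture) that helped me see the "softmax minus one-hot" structure: the gradient is slightly positive at every wrong class (nudging that logit down under gradient descent) and negative at the correct class (nudging it up), and each row sums to zero.

```python
# Toy 1x3 example (hypothetical numbers) of dloss/dlogits for cross entropy.
import torch
import torch.nn.functional as F

toy_logits = torch.tensor([[1.0, 2.0, 3.0]], requires_grad=True)
toy_target = torch.tensor([0])

loss = F.cross_entropy(toy_logits, toy_target)
loss.backward()

manual = F.softmax(toy_logits.detach(), dim=1)
manual[0, toy_target] -= 1.0   # subtract the one-hot target
# batch size is 1 here, so no division by n is needed
print(toy_logits.grad)          # ~[-0.9100, 0.2447, 0.6652]
print(manual)                   # matches the autograd gradient
print(manual.sum())             # ~0, each row sums to zero
```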
|
286 |
+
{
|
287 |
+
"cell_type": "markdown",
|
288 |
+
"metadata": {},
|
289 |
+
"source": [
|
290 |
+
"[1:32:49](https://youtu.be/q8SA3rM6ckI?si=-204uFZWpJPaT9oU&t=5569) to 1:36:36 - Breakdown of what `dlogits` actually is by taking one row and representing it dynamically"
|
291 |
+
]
|
292 |
+
},
|
293 |
+
{
|
294 |
+
"cell_type": "code",
|
295 |
+
"execution_count": 8,
|
296 |
+
"metadata": {},
|
297 |
+
"outputs": [
|
298 |
+
{
|
299 |
+
"data": {
|
300 |
+
"text/plain": [
|
301 |
+
"(torch.Size([32, 27]), torch.Size([32]))"
|
302 |
+
]
|
303 |
+
},
|
304 |
+
"execution_count": 8,
|
305 |
+
"metadata": {},
|
306 |
+
"output_type": "execute_result"
|
307 |
+
}
|
308 |
+
],
|
309 |
+
"source": [
|
310 |
+
"logits.shape, Yb.shape"
|
311 |
+
]
|
312 |
+
},
|
313 |
+
{
|
314 |
+
"cell_type": "code",
|
315 |
+
"execution_count": 9,
|
316 |
+
"metadata": {},
|
317 |
+
"outputs": [
|
318 |
+
{
|
319 |
+
"data": {
|
320 |
+
"text/plain": [
|
321 |
+
"tensor([0.0727, 0.0823, 0.0164, 0.0532, 0.0213, 0.0895, 0.0218, 0.0357, 0.0174,\n",
|
322 |
+
" 0.0327, 0.0371, 0.0337, 0.0347, 0.0311, 0.0346, 0.0131, 0.0086, 0.0178,\n",
|
323 |
+
" 0.0161, 0.0499, 0.0532, 0.0226, 0.0259, 0.0712, 0.0607, 0.0274, 0.0192],\n",
|
324 |
+
" grad_fn=<SelectBackward0>)"
|
325 |
+
]
|
326 |
+
},
|
327 |
+
"execution_count": 9,
|
328 |
+
"metadata": {},
|
329 |
+
"output_type": "execute_result"
|
330 |
+
}
|
331 |
+
],
|
332 |
+
"source": [
|
333 |
+
"F.softmax(logits, 1)[0]"
|
334 |
+
]
|
335 |
+
},
|
336 |
+
{
|
337 |
+
"cell_type": "code",
|
338 |
+
"execution_count": 10,
|
339 |
+
"metadata": {},
|
340 |
+
"outputs": [
|
341 |
+
{
|
342 |
+
"data": {
|
343 |
+
"text/plain": [
|
344 |
+
"tensor([ 0.0727, 0.0823, 0.0164, 0.0532, 0.0213, 0.0895, 0.0218, 0.0357,\n",
|
345 |
+
" -0.9826, 0.0327, 0.0371, 0.0337, 0.0347, 0.0311, 0.0346, 0.0131,\n",
|
346 |
+
" 0.0086, 0.0178, 0.0161, 0.0499, 0.0532, 0.0226, 0.0259, 0.0712,\n",
|
347 |
+
" 0.0607, 0.0274, 0.0192], grad_fn=<MulBackward0>)"
|
348 |
+
]
|
349 |
+
},
|
350 |
+
"execution_count": 10,
|
351 |
+
"metadata": {},
|
352 |
+
"output_type": "execute_result"
|
353 |
+
}
|
354 |
+
],
|
355 |
+
"source": [
|
356 |
+
"dlogits[0] * n"
|
357 |
+
]
|
358 |
+
},
|
359 |
+
{
|
360 |
+
"cell_type": "code",
|
361 |
+
"execution_count": 11,
|
362 |
+
"metadata": {},
|
363 |
+
"outputs": [
|
364 |
+
{
|
365 |
+
"data": {
|
366 |
+
"text/plain": [
|
367 |
+
"tensor(2.0955e-09, grad_fn=<SumBackward0>)"
|
368 |
+
]
|
369 |
+
},
|
370 |
+
"execution_count": 11,
|
371 |
+
"metadata": {},
|
372 |
+
"output_type": "execute_result"
|
373 |
+
}
|
374 |
+
],
|
375 |
+
"source": [
|
376 |
+
"dlogits[0].sum()"
|
377 |
+
]
|
378 |
+
},
|
379 |
+
{
|
380 |
+
"cell_type": "code",
|
381 |
+
"execution_count": 13,
|
382 |
+
"metadata": {},
|
383 |
+
"outputs": [
|
384 |
+
{
|
385 |
+
"data": {
|
386 |
+
"text/plain": [
|
387 |
+
"<matplotlib.image.AxesImage at 0x1b1aabfa7a0>"
|
388 |
+
]
|
389 |
+
},
|
390 |
+
"execution_count": 13,
|
391 |
+
"metadata": {},
|
392 |
+
"output_type": "execute_result"
|
393 |
+
},
|
394 |
+
{
|
395 |
+
"data": {
|
396 |
+
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjcAAAKTCAYAAADlpSlWAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAMVBJREFUeJzt3X2MXXWdP/DPnTuPpdPB8tCHpS0FhKrQmiDURmVRupS6ISI1wYdkwRCMbiELjavpRkVck+5iov7cIP6zC2ti1WUjGEgWo1VKzBZcakhFobSlLmWhZYW00047T/ee3x9NZx1pgWk/5Q7fvl7JTTozt+/53HPP98x7zsycW6uqqgoAgEK0tXoAAIBMyg0AUBTlBgAoinIDABRFuQEAiqLcAABFUW4AgKK0t3qAP9VsNuP555+P3t7eqNVqrR4HAJgEqqqKvXv3xuzZs6Ot7dXPzUy6cvP888/HnDlzWj0GADAJ7dixI84444xXvc+kKze9vb0REfHrX/967N/H4rXa3UT09/enZUVEdHV1pWUNDw+nZfX19aVlRUTs2bMnLater6dlXXDBBWlZv/nNb9KyIiL1rOVkvQh59lyZ22xkZCQtK9NkPpudeTzLlHlsjMh9DqZMmZKW1Ww207IGBwfTsiLy1vq+ffvive997+vqBpOu3BzacXp7eydduck+GE/WcpOx3f9Y5qLLLDeZB6nsbabcTNxkLTeZjzPzeJZNuZm4yVpuOjo60rIiWrPWJ+9KAQA4CsoNAFAU5QYAKMpxKzd33HFHnHnmmdHd3R2LFy+OX/3qV8frUwEAjDku5eaHP/xhrFq1Km699db49a9/HYsWLYply5bFiy++eDw+HQDAmONSbr7+9a/HDTfcEJ/85Cfj7W9/e3znO9+JKVOmxL/8y78cj08HADAmvdwMDw/Hxo0bY+nSpf/3SdraYunSpbFhw4ZX3H9oaCj6+/vH3QAAjlZ6ufnDH/4QjUYjZsyYMe79M2bMiJ07d77i/mvWrIm+vr6xm6sTAwDHouV/LbV69erYs2fP2G3Hjh2tHgkAeBNLv0LxqaeeGvV6PXbt2jXu/bt27YqZM2e+4v5dXV2T9sqWAMCbT/qZm87Ozrjwwgtj3bp1Y+9rNpuxbt26WLJkSfanAwAY57i8ttSqVavi2muvjXe9611x8cUXxze/+c0YGBiIT37yk8fj0wEAjDku5eaaa66J//3f/40vfelLsXPnznjnO98ZDz744Ct+yRgAINtxe1XwG2+8MW688cbjFQ8AcFgt/2spAIBMyg0AUJTj9mOpYzUyMhIjIyOtHmOc6dOnp+YNDg6mZbW35z2Ve/fuTcvKVq/X07K2bNmSltVsNtOyIiI6OjrSsqqqSsuq1WppWY1GIy0rIuLcc89Ny8rcNzIfZ/Z+Nlmfz8yszMcYkfscZGZlfj1pa8s975H1OCfyXDpzAwAURbkBAIqi3AAARVFuAICiKDcAQFGUGwCgKMoNAFAU5QYAKIpyAwAURbkBAIqi3AAARVFuAICiKDcAQFGUGwCgKMoNAFAU5QYAKIpyAwAURbkBAIrS3uoBjmRoaCg6OzuPOaetLa+/vfzyy2lZERFVVaVlZT7Ojo6OtKyISHkeD8l8nJlZw8PDaVnZeZmPs16vT8qsiIgnn3wyLeuss85Ky3rqqafSsrLXZuYxqK+vLy1rcHAwLSt7bWbutyMjI2lZmXONjo6mZUVE1Gq11LzXw5kbAKAoyg0AUBTlBgAoinIDABRFuQEAiqLcAABFUW4AgKIoNwBAUZQbAKAoyg0AUBTlBgAoinIDABRFuQEAiqLcAABFUW4AgKIoNwBAUZQbAKAoyg0AUBTlBgAoSnurBziSer0e9Xr9mHOazWbCNAd1dHSkZUVEtLfnbf6MbXXIgQMH0rKyNRqNVo9wWJn7WUTuvjE6OpqWlSlzn42I6OzsTMv6n//5n7Ss/fv3p2Vl7/9VVaVl7du3Ly1reHg4LatWq6VlRUQsWLAgLevJJ59My2pryztXkbmWIvL2s4kcF525AQCKotwAAEVRbgCAoig3AEBRlBsAoCjKDQBQFOUGACiKcgMAFEW5AQCKotwAAEVRbgCAoig3AEBRlBsAoCjKDQBQFOUGACiKcgMAFEW5AQCKotwAAEVpb/UAR7Jw4cKUnK1bt6bkREQ0Go20rOy8ZrOZltXR0ZGWFZE72+joaFpWV1dXWlZ7e+5SqqoqLStz+2c+zsznMiJ3thkzZqRlPfPMM2lZnZ2daVnZ2tryvlfOfJxDQ0NpWRERTz75ZFrWZD1uj4yMpGVF5B8fXw9nbgCAoig3AEBRlBsAoCjKDQBQFOUGACiKcgMAFEW5AQCKotwAAEVRbgCAoig3AEBRlBsAoCjKDQBQFOUGACiKcgMAFEW5AQCKotwAAEVRbgCAoig3AEBR2ls9wJFs2rQpent7Wz3GOO3tuZsrM6+tLa+nDgwMpGVl6+npScsaGhpKy2o2m2lZERGdnZ1pWbVaLS2r0WikZXV0dKRlRUTU6/W0rOeeey4tq6qqtKzMfTYid7YFCxakZW3bti0tK/u4nXmsHR4enpRZ06ZNS8uKiBgcHEzNez2cuQEAiqLcAABFUW4AgKIoNwBAUZQbAKAo6eXmy1/+ctRqtXG3zN+iBwB4NcflT8Hf8Y53xM9+9rP/+yTJf4oHAHAkx6V1tLe3x8yZM49HNADAqzouv3OzZcuWmD17dpx11lnxiU98Ip599tkj3ndoaCj6+/vH3QAAjlZ6uVm8eHHcfffd8eCDD8add94Z27dvj/e9732xd+/ew95/zZo10dfXN3abM2dO9kgAwAmkVmVef/swdu/eHfPmzYuvf/3rcf3117/i40NDQ+MuKd7f3x9z5szx8gsTdKK8/EJXV1da1ony8guZl2XP3M8m88svZL7MRPZLJmQ6EV5+IdtkffmFTJP15Rf27t0b73jHO2LPnj2vOeNx/03fk08+Oc4999zYunXrYT/e1dWV+gULADixHffr3Ozbty+2bdsWs2bNOt6fCgAgv9x89rOfjfXr18fvf//7+M///M/48Ic/HPV6PT72sY9lfyoAgFdI/7HUc889Fx/72MfipZdeitNOOy3e+973xiOPPBKnnXZa9qcCAHiF9HLzgx/8IDsSAOB189pSAEBRlBsAoCiT9kWf2tvbU64Dk/X39RG511iJyL2eTOY1PrIvfdTT05OalyXzOitnn312WlZExO9+97u0rMm6b0zm679kXmOru7s7LWv//v1pWRG511nJvDZN5n42mV/bsFarpWVlPs59+/alZUXkPc6JXH/KmRsAoCjKDQBQFOUGACiKcgMAFEW5AQCKotwAAEVRbgCAoig3AEBRlBsAoCjKDQBQFOUGACiKcgMAFEW5AQCKotwAAEVRbgCAoig3AEBRlBsAoCjKDQBQFOUGAChKe6sHOJJmsxnNZvOYc9rb8x7i4OBgWlZExKxZs9KyXnzxxbSszs7OtKyI3O3W29ubljU8PJyW9cQTT6RlRUS0teV93zE6OpqWlTnXl
ClT0rIiImbPnp2WtXXr1rSsqqrSsrLVarW0rKlTp6ZlDQwMpGVlGxkZScuq1+tpWY1GIy2rq6srLSsib5tNZH915gYAKIpyAwAURbkBAIqi3AAARVFuAICiKDcAQFGUGwCgKMoNAFAU5QYAKIpyAwAURbkBAIqi3AAARVFuAICiKDcAQFGUGwCgKMoNAFAU5QYAKIpyAwAUpb3VAxxJW1tbtLUde/dqNBoJ0xzUbDbTsiIiXnrppbSs0dHRtKwFCxakZUVEbN26NS2rVqulZWU+n/V6PS0rIvdxtrfnLfOMNXnI4OBgWlZExJYtW9KyMrd/Zlb2fjaZj49Zenp6UvOqqkrNy5K5Ng8cOJCWFZG3305k2ztzAwAURbkBAIqi3AAARVFuAICiKDcAQFGUGwCgKMoNAFAU5QYAKIpyAwAURbkBAIqi3AAARVFuAICiKDcAQFGUGwCgKMoNAFAU5QYAKIpyAwAURbkBAIrS3uoBjmRkZCRGRkaOOeess85KmOag7du3p2VFRMrjO6SjoyMta/PmzWlZERGNRiMta9++fWlZvb29aVmZz2VExMDAQFpW5r6Rqb099/BTVVVaVltb3vd9XV1daVmjo6NpWRERtVotLStzbU6ZMiUtq7+/Py0rIqKnpyctK3Od1+v1tKzsY0bW8XEiX0ucuQEAiqLcAABFUW4AgKIoNwBAUZQbAKAoyg0AUBTlBgAoinIDABRFuQEAiqLcAABFUW4AgKIoNwBAUZQbAKAoyg0AUBTlBgAoinIDABRFuQEAiqLcAABFaW/1AEfSbDaj2Wwec87TTz+dMM1BbW25XbBer6dlNRqNtKxarZaWFRExOjqalpWxTxwPk3nfyNz+PT09aVnDw8NpWRER7e15h7NZs2alZe3atSstK3ttdnd3p2Xt378/LWvevHlpWU888URaVkTEvn370rIyjxuZWZlfTyLyZptIjjM3AEBRlBsAoCjKDQBQFOUGACiKcgMAFEW5AQCKMuFy8/DDD8eVV14Zs2fPjlqtFvfdd9+4j1dVFV/60pdi1qxZ0dPTE0uXLo0tW7ZkzQsA8KomXG4GBgZi0aJFcccddxz247fffnt861vfiu985zvx6KOPxkknnRTLli2LwcHBYx4WAOC1TPiqV8uXL4/ly5cf9mNVVcU3v/nN+MIXvhAf+tCHIiLiu9/9bsyYMSPuu++++OhHP/qK/zM0NBRDQ0Njb/f39090JACAMam/c7N9+/bYuXNnLF26dOx9fX19sXjx4tiwYcNh/8+aNWuir69v7DZnzpzMkQCAE0xqudm5c2dERMyYMWPc+2fMmDH2sT+1evXq2LNnz9htx44dmSMBACeYlr+2VFdXV3R1dbV6DACgEKlnbmbOnBkRr3yhuF27do19DADgeEotN/Pnz4+ZM2fGunXrxt7X398fjz76aCxZsiTzUwEAHNaEfyy1b9++2Lp169jb27dvj8cffzymT58ec+fOjZtvvjm++tWvxlvf+taYP39+fPGLX4zZs2fHVVddlTk3AMBhTbjcPPbYY/H+979/7O1Vq1ZFRMS1114bd999d3zuc5+LgYGB+NSnPhW7d++O9773vfHggw9Gd3d33tQAAEcw4XJz6aWXRlVVR/x4rVaLr3zlK/GVr3zlmAYDADgaXlsKACiKcgMAFKXl17k5klqtFrVa7ZhzOjo6EqY5aHR0NC0rIuIv//Iv07Luv//+tKzs34/q7OxMy2o0GmlZr/bj1YnKnCsiotlspuZlyXyNuIz1/cf++GVcjtXvf//7tKx6vT4psyJyX+4m83plzzzzTFpW9tocGRlJy8p8PjPXU3t7bjUYHh5OyZnIcdGZGwCgKMoNAFAU5QYAKIpyAwAURbkBAIqi3AAARVFuAICiKDcAQFGUGwCgKMoNAFAU5QYAKIpyAwAURbkBAIqi3AAARVFuAICiKDcAQFGUGwCgKMoNAFCU9lYPcCRVVUVVVcecMzo6mjDNQd3d3WlZEREPPPBAWlZbW15PHRoaSsuKiOjt7U3Lynw+FyxYkJa1efPmtKyIiEajkZbV3p63zDP3s8zHGBFRq9XSsrq6utKyOjs707KGh4fTsiIiOjo60rIyZ8vcZtlOPvnktKyXX345LStzbWaupYiIer3+huc4cwMAFEW5AQCKotwAAEVRbgCAoig3AEBRlBsAoCjKDQBQFOUGACiKcgMAFEW5AQCKotwAAEVRbgCAoig3AEBRlBsAoCjKDQBQFOUGACiKcgMAFEW5AQCK0t7qAY6kVqtFrVY75py2tsnb3zIe3yHNZjMtq7e3Ny0rImJgYCAtK/NxPvHEE2lZ2er1elpWVVVpWZ2dnWlZQ0NDaVkREeeff35a1tatW9OyDhw4kJaVbcqUKWlZ/f39aVkdHR1pWfv27UvLiogYHh5Oy8p8nJkyvzZF5B2DJjLX5P3KDwBwFJQbAKAoyg0AUBTlBgAoinIDABRFuQEAiqLcAABFUW4AgKIoNwBAUZQbAKAoyg0AUBTlBgAoinIDABRFuQEAiqLcAABFUW4AgKIoNwBAUZQbAKAo7a0e4Eg6Ojqio6PjmHNGR0cTpsnPiojo6elJy9q/f39a1oEDB9KyIiJqtVpa1pQpU9KyGo1GWtZk1taW9z3MmWeemZb19NNPp2VFRDz55JNpWSMjI2lZVVWlZXV2dqZlReQeN7q7u9OyMtdmV1dXWlZE7r6R6UQ4nk3kMTpzAwAURbkBAIqi3AAARVFuAICiKDcAQFGUGwCgKMoNAFAU5QYAKIpyAwAURbkBAIqi3AAARVFuAICiKDcAQFGUGwCgKMoNAFAU5QYAKIpyAwAURbkBAIqi3AAARWlv9QBH8s53vjNqtdox52zfvj1hmoOGh4fTsiIiDhw4kJaVsa0OmTJlSlpWRMTAwEBa1uDgYFpWpo6OjtS8zOczM+uZZ55Jy8rc/yNyH2ez2UzLytw3hoaG0rIiIrq6utKyMmfL3GaNRiMtKyKirS3vnEDm9q+qKi0rez/LXE+vlzM3AEBRlBsAoCjKDQBQFOUGACiKcgMAFGXC5ebhhx+OK6+8MmbPnh21Wi3uu+++cR+/7rrrolarjbtdccUVWfMCALyqCZebgYGBWLRoUdxxxx1HvM8VV1wRL7zwwtjt+9///jENCQDwek34OjfLly+P5cuXv+p9urq6YubMmUc9FADA0Touv3Pz0EMPxemnnx7nnXdefOYzn4mXXnrpiPcdGhqK/v7+cTcAgKOVXm6uuOKK+O53vxvr1q2Lf/zHf4z169fH8uXLj3iVyDVr1kRfX9/Ybc6cOdkjAQAnkPSXX/joRz869u8LLrggFi5cGGeffXY89NBDcdlll73i/qtXr45Vq1aNvd3f36/gAABH7bj/KfhZZ50Vp556amzduvWwH+/q6opp06aNuwEAHK3jXm6ee+65eOmll2LWrFnH+1MBAEz8x1L79u0bdxZm+/bt8fjjj8f06dNj+vTpcdttt8WKFSti5syZsW3btvjc5z4X55xzTixbtix1cACAw5lwuXnsscfi/e9//9jbh35f
5tprr40777wzNm3aFP/6r/8au3fvjtmzZ8fll18ef//3f5/60u4AAEcy4XJz6aWXRlVVR/z4T37yk2MaCADgWHhtKQCgKMoNAFCU9OvcZHnssceit7f3mHMGBwcTpjkoY54/ljlbZ2dnWtbQ0FBaVkQc8QKOR6Ner6dlZc41PDyclhWR+3zOnTs3Lev3v/99WlZPT09aVrZarZaWNTAwkJaVLXO/7ejoSMvKXJvNZjMtKyJ3tra2vPMLo6OjaVnt7bnVICtvIsdFZ24AgKIoNwBAUZQbAKAoyg0AUBTlBgAoinIDABRFuQEAiqLcAABFUW4AgKIoNwBAUZQbAKAoyg0AUBTlBgAoinIDABRFuQEAiqLcAABFUW4AgKIoNwBAUdpbPcCRXHTRRVGr1Y4557//+78TpjloaGgoLSsiol6vp2UNDw+nZTUajbSsiEh5Hg+ZMmVKWtbAwEBaVrPZTMuKiOjs7EzLevrpp9OyMveN0dHRtKyIiPb2vMNZ5vNZVVVaVuYxIyL3cba15X2vnLlvdHV1pWVFRIyMjKRlZR63M7d/tqz9diI5k3drAAAcBeUGACiKcgMAFEW5AQCKotwAAEVRbgCAoig3AEBRlBsAoCjKDQBQFOUGACiKcgMAFEW5AQCKotwAAEVRbgCAoig3AEBRlBsAoCjKDQBQFOUGAChKe6sHOJJHH300ent7jzln7969CdMc1N3dnZYVEXHgwIG0rLa2vJ7abDbTsiIi5Xk8ZGhoKC2rq6srLSt7m2Xutx0dHWlZmftZo9FIy4qIGB4eTsvK3Dd6enrSsjIfY0REVVVpWZmzdXZ2pmVlPsaIiL6+vrSsl19+OS0r83GOjo6mZUVEzJ8/PyVnIo/RmRsAoCjKDQBQFOUGACiKcgMAFEW5AQCKotwAAEVRbgCAoig3AEBRlBsAoCjKDQBQFOUGACiKcgMAFEW5AQCKotwAAEVRbgCAoig3AEBRlBsAoCjKDQBQFOUGAChKe6sHOJJarRa1Wi0lJ0uj0UjLylav19Oy2tpyO+/o6GhaVubzOTw8nJZ17rnnpmVFRGzZsiUtK/P5bG/PO2RkZkXk7meTNavZbKZlReQeN6ZOnZqWlbk2s7fZwMBAWlZ3d3daVuZ+VlVVWlZE3vFs7969sWjRotd1X2duAICiKDcAQFGUGwCgKMoNAFAU5QYAKIpyAwAURbkBAIqi3AAARVFuAICiKDcAQFGUGwCgKMoNAFAU5QYAKIpyAwAURbkBAIqi3AAARVFuAICiKDcAQFHaWz3AkXR2dkZnZ+cx5wwODiZMc1BVVWlZEREdHR1pWZmz1Wq1tKyIiKGhobSs9va8XTYz66mnnkrLiojo6upKy8rc/pn7WebajIiU48Uh3d3daVl79+5Ny8pem5l5w8PDkzKrXq+nZUVEjI6OpuZlyXycb3vb29KyIiKefPLJlJyJPEZnbgCAoig3AEBRlBsAoCjKDQBQFOUGACjKhMrNmjVr4qKLLore3t44/fTT46qrrorNmzePu8/g4GCsXLkyTjnllJg6dWqsWLEidu3alTo0AMCRTKjcrF+/PlauXBmPPPJI/PSnP42RkZG4/PLLY2BgYOw+t9xyS9x///1xzz33xPr16+P555+Pq6++On1wAIDDmdCFPh588MFxb999991x+umnx8aNG+OSSy6JPXv2xD//8z/H2rVr4wMf+EBERNx1113xtre9LR555JF497vfnTc5AMBhHNPv3OzZsyciIqZPnx4RERs3boyRkZFYunTp2H0WLFgQc+fOjQ0bNhw2Y2hoKPr7+8fdAACO1lGXm2azGTfffHO85z3vifPPPz8iInbu3BmdnZ1x8sknj7vvjBkzYufOnYfNWbNmTfT19Y3d5syZc7QjAQAcfblZuXJlPPHEE/GDH/zgmAZYvXp17NmzZ+y2Y8eOY8oDAE5sR/XiOjfeeGM88MAD8fDDD8cZZ5wx9v6ZM2fG8PBw7N69e9zZm127dsXMmTMPm9XV1ZX6OjoAwIltQmduqqqKG2+8Me699974+c9/HvPnzx/38QsvvDA6Ojpi3bp1Y+/bvHlzPPvss7FkyZKciQEAXsWEztysXLky1q5dGz/+8Y+jt7d37Pdo+vr6oqenJ/r6+uL666+PVatWxfTp02PatGlx0003xZIlS/ylFADwhphQubnzzjsjIuLSSy8d9/677rorrrvuuoiI+MY3vhFtbW2xYsWKGBoaimXLlsW3v/3tlGEBAF7LhMpNVVWveZ/u7u6444474o477jjqoQAAjpbXlgIAiqLcAABFOao/BX8jvPOd74xarXbMOdu3b0+Y5qBGo5GWlW1kZCQtK/tP85vNZlpWxj5xyPDwcFrW6/mR7URkbrPMrKGhobSser2elpUtc9/IlL3NMo9pJ510UlrW4OBgWlb2NstcT5N1Dfz2t79Nzcs6Pk4kx5kbAKAoyg0AUBTlBgAoinIDABRFuQEAiqLcAABFUW4AgKIoNwBAUZQbAKAoyg0AUBTlBgAoinIDABRFuQEAiqLcAABFUW4AgKIoNwBAUZQbAKAoyg0AUJT2Vg9wJI8++mj09vYec87s2bMTpjlox44daVkREUNDQ2lZbW15PfXAgQNpWRGR8jwekrnNuru707KazWZaVkTuc9DenrfMM/ezRqORlhURMTIykpbV1dWVljV16tS0rOHh4bSsiIharZaW1d/fn5aVuTarqkrLioh4y1vekpb18ssvp2Vlrs3M/SLTRJ5LZ24AgKIoNwBAUZQbAKAoyg0AUBTlBgAoinIDABRFuQEAiqLcAABFUW4AgKIoNwBAUZQbAKAoyg0AUBTlBgAoinIDABRFuQEAiqLcAABFUW4AgKIoNwBAUZQbAKAo7a0e4Eg6Ozujs7PzmHNqtVrCNAeNjIykZUVEVFWVltXV1ZWWNTQ0lJYVEdFsNidl1ujoaFpWvV5Py4qIaG+fnEszc/tnrs2IiI6OjrSszNkm8zEoc7/NPJ4NDw+nZWXvZ5lrM3O27u7utKzM7R8R0Wg03vAcZ24AgKIoNwBAUZQbAKAoyg0AUBTlBgAoinIDABRFuQEAiqLcAABFUW4AgKIoNwBAUZQbAKAoyg0AUBTlBgAoinIDABRFuQEAiqLcAABFUW4AgKIoNwBAUdpbPcCRNBqNaDQax5zzwgsvJExz0MDAQFpWRER3d3da1vDwcFpWT09PWlZExP79+9OyzjvvvLSsp59+Oi2r2WymZUVEnHzyyWlZL730UlpWvV5PyxodHU3Liojo6OhIy8pcT0NDQ2lZVVWlZUXk7reZ+0bGsf+QzLkiIl588cW0rHnz5qVl7dq1Ky0rez/r6upKyZnIunTmBgAoinIDABRFuQEAiqLcAABFUW4AgKIoNwBAUZQbAKAoyg0AUBTlBgAoinIDABRFuQEAiqLcAABFUW4AgKIoNwBAUZQbAKAoyg0AUBTlBgAoinIDABSlvdUDHElXV1d0dXUdc87+/fsTpjmoqqq0rIiI4eHhtKx6vT4psyIi2tvzdrMtW7akZWU+n7VaLS0rImLPnj1
pWRnr6JC2trzvh7K32ejoaGpelsz11Gw207IiIs4///y0rE2bNqVlZR4zsrfZ1KlT07J27dqVltXR0ZGWlf21bnBw8A3PceYGACiKcgMAFEW5AQCKotwAAEVRbgCAokyo3KxZsyYuuuii6O3tjdNPPz2uuuqq2Lx587j7XHrppVGr1cbdPv3pT6cODQBwJBMqN+vXr4+VK1fGI488Ej/96U9jZGQkLr/88hgYGBh3vxtuuCFeeOGFsdvtt9+eOjQAwJFM6GICDz744Li377777jj99NNj48aNcckll4y9f8qUKTFz5sycCQEAJuCYfufm0IXGpk+fPu793/ve9+LUU0+N888/P1avXv2qF9IbGhqK/v7+cTcAgKN11JeBbDabcfPNN8d73vOecVe5/PjHPx7z5s2L2bNnx6ZNm+Lzn/98bN68OX70ox8dNmfNmjVx2223He0YAADjHHW5WblyZTzxxBPxy1/+ctz7P/WpT439+4ILLohZs2bFZZddFtu2bYuzzz77FTmrV6+OVatWjb3d398fc+bMOdqxAIAT3FGVmxtvvDEeeOCBePjhh+OMM8541fsuXrw4IiK2bt162HKT9RpSAAAREyw3VVXFTTfdFPfee2889NBDMX/+/Nf8P48//nhERMyaNeuoBgQAmIgJlZuVK1fG2rVr48c//nH09vbGzp07IyKir68venp6Ytu2bbF27dr44Ac/GKecckps2rQpbrnllrjkkkti4cKFx+UBAAD8sQmVmzvvvDMiDl6o74/dddddcd1110VnZ2f87Gc/i29+85sxMDAQc+bMiRUrVsQXvvCFtIEBAF7NhH8s9WrmzJkT69evP6aBAACOhdeWAgCKotwAAEU56uvcHG8jIyMxMjLS6jHGqdVqqXnNZjMtq7OzMy1r7969aVkREdOmTUvLOnDgQFpWo9FIy1qwYEFaVkTEb3/727Sstra872Ey18Br/Zh7orLXZ5bMtTk0NJSWFRHxu9/9LjUvS+a+Ua/X07IiIqZOnZqWtWvXrrSszK8nmcfGVnHmBgAoinIDABRFuQEAiqLcAABFUW4AgKIoNwBAUZQbAKAoyg0AUBTlBgAoinIDABRFuQEAiqLcAABFUW4AgKIoNwBAUZQbAKAoyg0AUBTlBgAoinIDABRFuQEAitLe6gGOpNFoRKPRaPUY43R0dKTmnXnmmWlZzzzzTFpWtoGBgbSsqqrSstra8rr9li1b0rIiIoaHh9OyRkdH07Im6/bPzuvs7EzLytz+9Xo9LSsid5sNDQ2lZZ166qlpWX/4wx/SsiIidu/enZaVuZ5GRkbSstrbc6tB1tfOiTxGZ24AgKIoNwBAUZQbAKAoyg0AUBTlBgAoinIDABRFuQEAiqLcAABFUW4AgKIoNwBAUZQbAKAoyg0AUBTlBgAoinIDABRFuQEAiqLcAABFUW4AgKIoNwBAUdpbPcCRdHd3R3d39zHnjIyMJExz0PDwcFpWRMSWLVtS87IsXLgwNe93v/tdWlatVkvLynw+6/V6WlZEREdHR1rW6OhoWlaj0UjLmsyGhobSsnp6etKyBgYG0rIicveztra875V3796dltXenvtlrqqqtKwpU6akZXV2dqZlZW7/iLxj0ETWpTM3AEBRlBsAoCjKDQBQFOUGACiKcgMAFEW5AQCKotwAAEVRbgCAoig3AEBRlBsAoCjKDQBQFOUGACiKcgMAFEW5AQCKotwAAEVRbgCAoig3AEBRlBsAoCjtrR7gSA4cOBDt7cc+XlVVCdMcVK/X07IiImq1WlpW5mybNm1Ky4qI6OjoSMsaHh5Oy5o6dWpa1pw5c9KyIiK2bNmSlpW5n2Vqa8v93qrZbKZldXd3p2Xt378/LStb5nrKlHk8Gx0dTcuKyJ0tc9/IfC4z9/+IvLU5ka8lztwAAEVRbgCAoig3AEBRlBsAoCjKDQBQFOUGACiKcgMAFEW5AQCKotwAAEVRbgCAoig3AEBRlBsAoCjKDQBQFOUGACiKcgMAFEW5AQCKotwAAEVRbgCAorS3eoAjufDCC6NWqx1zzjPPPJMwzUGjo6NpWRER3d3daVmZs3V2dqZlRUQMDQ2l5mU5cOBAWtbmzZvTsiIi2tryvu/I3DeqqkrLypb5OCfrPpstcz+brFnZBgcH07J6e3vTstrb876c79mzJy0rIlK+lkdENBqN133fybsHAQAcBeUGACiKcgMAFEW5AQCKotwAAEVRbgCAokyo3Nx5552xcOHCmDZtWkybNi2WLFkS//Ef/zH28cHBwVi5cmWccsopMXXq1FixYkXs2rUrfWgAgCOZULk544wz4h/+4R9i48aN8dhjj8UHPvCB+NCHPhS//e1vIyLilltuifvvvz/uueeeWL9+fTz//PNx9dVXH5fBAQAOp1Yd41W5pk+fHl/72tfiIx/5SJx22mmxdu3a+MhHPhIREU899VS87W1viw0bNsS73/3uw/7/oaGhcRfM6u/vjzlz5kS9Xi/+In49PT1pWdmzZZqss2VekC774naZF+SarBfxq9fraVkRk/dCls1mMy1rMu9nmTL3jeHh4bSsiNzn86STTkrLOhEu4rd3795YtGhR7NmzJ6ZNm/aq9z3q37lpNBrxgx/8IAYGBmLJkiWxcePGGBkZiaVLl47dZ8GCBTF37tzYsGHDEXPWrFkTfX19Y7c5c+Yc7UgAABMvN7/5zW9i6tSp0dXVFZ/+9Kfj3nvvjbe//e2xc+fO6OzsjJNPPnnc/WfMmBE7d+48Yt7q1atjz549Y7cdO3ZM+EEAABwy4fNY5513Xjz++OOxZ8+e+Pd///e49tprY/369Uc9QFdXV3R1dR31/wcA+GMTLjednZ1xzjnnRMTBF7f8r//6r/h//+//xTXXXBPDw8Oxe/fucWdvdu3aFTNnzkwbGADg1RzzdW6azWYMDQ3FhRdeGB0dHbFu3bqxj23evDmeffbZWLJkybF+GgCA12VCZ25Wr14dy5cvj7lz58bevXtj7dq18dBDD8VPfvKT6Ovri+uvvz5WrVoV06dPj2nTpsVNN90US5YsOeJfSgEAZJtQuXnxxRfjr/7qr+KFF16Ivr6+WLhwYfzkJz+Jv/iLv4iIiG984xvR1tYWK1asiKGhoVi2bFl8+9vfPi6DAwAczjFf5yZbf39/9PX1uc7NBE3Wa8lETN7ZXOdm4lznZuJc52biXOdm4lznZjyvLQUAFEW5AQCKMjnPSUbEpk2bore395hzRkZGEqY5aMqUKWlZERH79+9Py8rYVofs27cvLSsi9zRuW1teH280GmlZ3d3daVkRuT9iyTolHJH744LMU/IRuespc5tl/rggcy1FxNhlPTIceo3BDJn7RvaPxTNnGxgYSMvKlP3jyqyvwxPZ/525AQCKotwAAEVRbgCAoig3AEBRlBsAoCjKDQBQFOUGACiKcgMAFEW5AQCKotwAAEVRbgCAoig3AEBRlBsAoCjKDQBQFOUGACiKcgMAFEW5AQCK0t7qAf5UVVUREbFv376UvJGRkZSciIhGo5GWFRGxf//+1LwsWdv+kG
azmZbV1pbXxzOfz8z9LCJidHQ0LevQmsqQuf0z94uI3PVUq9UmZVb2NsvcN/bu3ZuWlfk4BwYG0rIicmc7cOBAWlam9vbcapB1fDz0ten17Le1KnPvTvDcc8/FnDlzWj0GADAJ7dixI84444xXvc+kKzfNZjOef/756O3tfdXvePr7+2POnDmxY8eOmDZt2hs4IRG2f6vZ/q3nOWgt27+1WrH9q6qKvXv3xuzZs1/zLPKk+7FUW1vbazayPzZt2jQ7dgvZ/q1l+7ee56C1bP/WeqO3f19f3+u6n18oBgCKotwAAEV505abrq6uuPXWW6Orq6vVo5yQbP/Wsv1bz3PQWrZ/a0327T/pfqEYAOBYvGnP3AAAHI5yAwAURbkBAIqi3AAARVFuAICivCnLzR133BFnnnlmdHd3x+LFi+NXv/pVq0c6YXz5y1+OWq027rZgwYJWj1Wshx9+OK688sqYPXt21Gq1uO+++8Z9vKqq+NKXvhSzZs2Knp6eWLp0aWzZsqU1wxbotbb/dddd94r1cMUVV7Rm2AKtWbMmLrrooujt7Y3TTz89rrrqqti8efO4+wwODsbKlSvjlFNOialTp8aKFSti165dLZq4LK9n+1966aWvWAOf/vSnWzTx/3nTlZsf/vCHsWrVqrj11lvj17/+dSxatCiWLVsWL774YqtHO2G84x3viBdeeGHs9stf/rLVIxVrYGAgFi1aFHfcccdhP3777bfHt771rfjOd74Tjz76aJx00kmxbNmyGBwcfIMnLdNrbf+IiCuuuGLcevj+97//Bk5YtvXr18fKlSvjkUceiZ/+9KcxMjISl19++bhX+r7lllvi/vvvj3vuuSfWr18fzz//fFx99dUtnLocr2f7R0TccMMN49bA7bff3qKJ/0j1JnPxxRdXK1euHHu70WhUs2fPrtasWdPCqU4ct956a7Vo0aJWj3FCiojq3nvvHXu72WxWM2fOrL72ta+NvW/37t1VV1dX9f3vf78FE5btT7d/VVXVtddeW33oQx9qyTwnohdffLGKiGr9+vVVVR3c3zs6Oqp77rln7D5PPvlkFRHVhg0bWjVmsf50+1dVVf35n/959Td/8zetG+oI3lRnboaHh2Pjxo2xdOnSsfe1tbXF0qVLY8OGDS2c7MSyZcuWmD17dpx11lnxiU98Ip599tlWj3RC2r59e+zcuXPceujr64vFixdbD2+ghx56KE4//fQ477zz4jOf+Uy89NJLrR6pWHv27ImIiOnTp0dExMaNG2NkZGTcGliwYEHMnTvXGjgO/nT7H/K9730vTj311Dj//PNj9erVsX///laMN86ke1XwV/OHP/whGo1GzJgxY9z7Z8yYEU899VSLpjqxLF68OO6+++4477zz4oUXXojbbrst3ve+98UTTzwRvb29rR7vhLJz586IiMOuh0Mf4/i64oor4uqrr4758+fHtm3b4u/+7u9i+fLlsWHDhqjX660eryjNZjNuvvnmeM973hPnn39+RBxcA52dnXHyySePu681kO9w2z8i4uMf/3jMmzcvZs+eHZs2bYrPf/7zsXnz5vjRj37UwmnfZOWG1lu+fPnYvxcuXBiLFy+OefPmxb/927/F9ddf38LJ4I330Y9+dOzfF1xwQSxcuDDOPvvseOihh+Kyyy5r4WTlWblyZTzxxBN+x69FjrT9P/WpT439+4ILLohZs2bFZZddFtu2bYuzzz77jR5zzJvqx1Knnnpq1Ov1V/wm/K5du2LmzJktmurEdvLJJ8e5554bW7dubfUoJ5xD+7z1MHmcddZZceqpp1oPyW688cZ44IEH4he/+EWcccYZY++fOXNmDA8Px+7du8fd3xrIdaTtfziLFy+OiGj5GnhTlZvOzs648MILY926dWPvazabsW7duliyZEkLJztx7du3L7Zt2xazZs1q9SgnnPnz58fMmTPHrYf+/v549NFHrYcWee655+Kll16yHpJUVRU33nhj3HvvvfHzn/885s+fP+7jF154YXR0dIxbA5s3b45nn33WGkjwWtv/cB5//PGIiJavgTfdj6VWrVoV1157bbzrXe+Kiy++OL75zW/GwMBAfPKTn2z1aCeEz372s3HllVfGvHnz4vnnn49bb7016vV6fOxjH2v1aEXat2/fuO+Atm/fHo8//nhMnz495s6dGzfffHN89atfjbe+9a0xf/78+OIXvxizZ8+Oq666qnVDF+TVtv/06dPjtttuixUrVsTMmTNj27Zt8bnPfS7OOeecWLZsWQunLsfKlStj7dq18eMf/zh6e3vHfo+mr68venp6oq+vL66//vpYtWpVTJ8+PaZNmxY33XRTLFmyJN797ne3ePo3v9fa/tu2bYu1a9fGBz/4wTjllFNi06ZNccstt8Qll1wSCxcubO3wrf5zraPxT//0T9XcuXOrzs7O6uKLL64eeeSRVo90wrjmmmuqWbNmVZ2dndWf/dmfVddcc021devWVo9VrF/84hdVRLzidu2111ZVdfDPwb/4xS9WM2bMqLq6uqrLLrus2rx5c2uHLsirbf/9+/dXl19+eXXaaadVHR0d1bx586obbrih2rlzZ6vHLsbhtn1EVHfdddfYfQ4cOFD99V//dfWWt7ylmjJlSvXhD3+4euGFF1o3dEFea/s/++yz1SWXXFJNnz696urqqs4555zqb//2b6s9e/a0dvCqqmpVVVVvZJkCADie3lS/cwMA8FqUGwCgKMoNAFAU5QYAKIpyAwAURbkBAIqi3AAARVFuAICiKDcAQFGUGwCgKMoNAFCU/w//ZE9Dt1OyrAAAAABJRU5ErkJggg==",
|
397 |
+
"text/plain": [
|
398 |
+
"<Figure size 800x800 with 1 Axes>"
|
399 |
+
]
|
400 |
+
},
|
401 |
+
"metadata": {},
|
402 |
+
"output_type": "display_data"
|
403 |
+
}
|
404 |
+
],
|
405 |
+
"source": [
|
406 |
+
"plt.figure(figsize=(8,8))\n",
|
407 |
+
"plt.imshow(dlogits.detach(), cmap='gray')"
|
408 |
+
]
|
409 |
+
}
|
410 |
+
],
|
411 |
+
"metadata": {
|
412 |
+
"kernelspec": {
|
413 |
+
"display_name": "venv",
|
414 |
+
"language": "python",
|
415 |
+
"name": "python3"
|
416 |
+
},
|
417 |
+
"language_info": {
|
418 |
+
"codemirror_mode": {
|
419 |
+
"name": "ipython",
|
420 |
+
"version": 3
|
421 |
+
},
|
422 |
+
"file_extension": ".py",
|
423 |
+
"mimetype": "text/x-python",
|
424 |
+
"name": "python",
|
425 |
+
"nbconvert_exporter": "python",
|
426 |
+
"pygments_lexer": "ipython3",
|
427 |
+
"version": "3.10.0"
|
428 |
+
}
|
429 |
+
},
|
430 |
+
"nbformat": 4,
|
431 |
+
"nbformat_minor": 2
|
432 |
+
}
|
exercise-3.ipynb
ADDED
@@ -0,0 +1,372 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"# This exercise wasn't exactly smooth sailing for me, but I did try to understand most of it. Will try to come back to this whenever I can"
|
10 |
+
]
|
11 |
+
},
|
12 |
+
{
|
13 |
+
"cell_type": "code",
|
14 |
+
"execution_count": 1,
|
15 |
+
"metadata": {},
|
16 |
+
"outputs": [],
|
17 |
+
"source": [
|
18 |
+
"# there no change change in the first several cells from last lecture\n",
|
19 |
+
"\n",
|
20 |
+
"import torch\n",
|
21 |
+
"import torch.nn.functional as F\n",
|
22 |
+
"import matplotlib.pyplot as plt # for making figures\n",
|
23 |
+
"%matplotlib inline"
|
24 |
+
]
|
25 |
+
},
|
26 |
+
{
|
27 |
+
"cell_type": "code",
|
28 |
+
"execution_count": null,
|
29 |
+
"metadata": {},
|
30 |
+
"outputs": [],
|
31 |
+
"source": [
|
32 |
+
"# download the names.txt file from github\n",
|
33 |
+
"!wget https://raw.githubusercontent.com/karpathy/makemore/master/names.txt"
|
34 |
+
]
|
35 |
+
},
|
36 |
+
{
|
37 |
+
"cell_type": "code",
|
38 |
+
"execution_count": 3,
|
39 |
+
"metadata": {},
|
40 |
+
"outputs": [],
|
41 |
+
"source": [
|
42 |
+
"# read in all the words\n",
|
43 |
+
"words = open('names.txt', 'r').read().splitlines()\n",
|
44 |
+
"# print(len(words))\n",
|
45 |
+
"# print(max(len(w) for w in words))\n",
|
46 |
+
"# print(words[:8])\n",
|
47 |
+
"\n",
|
48 |
+
"# build the vocabulary of characters and mappings to/from integers\n",
|
49 |
+
"chars = sorted(list(set(''.join(words))))\n",
|
50 |
+
"stoi = {s:i+1 for i,s in enumerate(chars)}\n",
|
51 |
+
"stoi['.'] = 0\n",
|
52 |
+
"itos = {i:s for s,i in stoi.items()}\n",
|
53 |
+
"vocab_size = len(itos)\n",
|
54 |
+
"# print(itos)\n",
|
55 |
+
"# print(vocab_size)\n",
|
56 |
+
"\n",
|
57 |
+
"# build the dataset\n",
|
58 |
+
"block_size = 3 # context length: how many characters do we take to predict the next one?\n",
|
59 |
+
"\n",
|
60 |
+
"def build_dataset(words):\n",
|
61 |
+
" X, Y = [], []\n",
|
62 |
+
"\n",
|
63 |
+
" for w in words:\n",
|
64 |
+
" context = [0] * block_size\n",
|
65 |
+
" for ch in w + '.':\n",
|
66 |
+
" ix = stoi[ch]\n",
|
67 |
+
" X.append(context)\n",
|
68 |
+
" Y.append(ix)\n",
|
69 |
+
" context = context[1:] + [ix] # crop and append\n",
|
70 |
+
"\n",
|
71 |
+
" X = torch.tensor(X)\n",
|
72 |
+
" Y = torch.tensor(Y)\n",
|
73 |
+
" # print(X.shape, Y.shape)\n",
|
74 |
+
" return X, Y\n",
|
75 |
+
"\n",
|
76 |
+
"import random\n",
|
77 |
+
"random.seed(42)\n",
|
78 |
+
"random.shuffle(words)\n",
|
79 |
+
"n1 = int(0.8*len(words))\n",
|
80 |
+
"n2 = int(0.9*len(words))\n",
|
81 |
+
"\n",
|
82 |
+
"Xtr, Ytr = build_dataset(words[:n1]) # 80%\n",
|
83 |
+
"Xdev, Ydev = build_dataset(words[n1:n2]) # 10%\n",
|
84 |
+
"Xte, Yte = build_dataset(words[n2:]) # 10%"
|
85 |
+
]
|
86 |
+
},
|
87 |
+
{
|
88 |
+
"cell_type": "code",
|
89 |
+
"execution_count": 4,
|
90 |
+
"metadata": {},
|
91 |
+
"outputs": [],
|
92 |
+
"source": [
|
93 |
+
"# utility function we will use later when comparing manual gradients to PyTorch gradients\n",
|
94 |
+
"def cmp(s, dt, t):\n",
|
95 |
+
" ex = torch.all(dt == t.grad).item()\n",
|
96 |
+
" app = torch.allclose(dt, t.grad)\n",
|
97 |
+
" maxdiff = (dt - t.grad).abs().max().item()\n",
|
98 |
+
" print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')"
|
99 |
+
]
|
100 |
+
},
|
101 |
+
{
|
102 |
+
"cell_type": "code",
|
103 |
+
"execution_count": 9,
|
104 |
+
"metadata": {},
|
105 |
+
"outputs": [
|
106 |
+
{
|
107 |
+
"name": "stdout",
|
108 |
+
"output_type": "stream",
|
109 |
+
"text": [
|
110 |
+
"4137\n"
|
111 |
+
]
|
112 |
+
}
|
113 |
+
],
|
114 |
+
"source": [
|
115 |
+
"n_embd = 10 # the dimensionality of the character embedding vectors\n",
|
116 |
+
"n_hidden = 64 # the number of neurons in the hidden layer of the MLP\n",
|
117 |
+
"\n",
|
118 |
+
"g = torch.Generator().manual_seed(2147483647) # for reproducibility\n",
|
119 |
+
"C = torch.randn((vocab_size, n_embd), generator=g)\n",
|
120 |
+
"# Layer 1\n",
|
121 |
+
"W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5/3)/((n_embd * block_size)**0.5)\n",
|
122 |
+
"b1 = torch.randn(n_hidden, generator=g) * 0.1 # using b1 just for fun, it's useless because of BN\n",
|
123 |
+
"# Layer 2\n",
|
124 |
+
"W2 = torch.randn((n_hidden, vocab_size), generator=g) * 0.1\n",
|
125 |
+
"b2 = torch.randn(vocab_size, generator=g) * 0.1\n",
|
126 |
+
"# BatchNorm parameters\n",
|
127 |
+
"bngain = torch.randn((1, n_hidden))*0.1 + 1.0\n",
|
128 |
+
"bnbias = torch.randn((1, n_hidden))*0.1\n",
|
129 |
+
"\n",
|
130 |
+
"# Note: I am initializating many of these parameters in non-standard ways\n",
|
131 |
+
"# because sometimes initializating with e.g. all zeros could mask an incorrect\n",
|
132 |
+
"# implementation of the backward pass.\n",
|
133 |
+
"\n",
|
134 |
+
"parameters = [C, W1, b1, W2, b2, bngain, bnbias]\n",
|
135 |
+
"print(sum(p.nelement() for p in parameters)) # number of parameters in total\n",
|
136 |
+
"for p in parameters:\n",
|
137 |
+
" p.requires_grad = True"
|
138 |
+
]
|
139 |
+
},
|
140 |
+
{
|
141 |
+
"cell_type": "code",
|
142 |
+
"execution_count": 10,
|
143 |
+
"metadata": {},
|
144 |
+
"outputs": [],
|
145 |
+
"source": [
|
146 |
+
"batch_size = 32\n",
|
147 |
+
"n = batch_size # a shorter variable also, for convenience\n",
|
148 |
+
"# construct a minibatch\n",
|
149 |
+
"ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)\n",
|
150 |
+
"Xb, Yb = Xtr[ix], Ytr[ix] # batch X,Y"
|
151 |
+
]
|
152 |
+
},
|
153 |
+
{
|
154 |
+
"cell_type": "code",
|
155 |
+
"execution_count": 11,
|
156 |
+
"metadata": {},
|
157 |
+
"outputs": [
|
158 |
+
{
|
159 |
+
"data": {
|
160 |
+
"text/plain": [
|
161 |
+
"tensor(3.3479, grad_fn=<NegBackward0>)"
|
162 |
+
]
|
163 |
+
},
|
164 |
+
"execution_count": 11,
|
165 |
+
"metadata": {},
|
166 |
+
"output_type": "execute_result"
|
167 |
+
}
|
168 |
+
],
|
169 |
+
"source": [
|
170 |
+
"# forward pass, \"chunkated\" into smaller steps that are possible to backward one at a time\n",
|
171 |
+
"\n",
|
172 |
+
"emb = C[Xb] # embed the characters into vectors\n",
|
173 |
+
"embcat = emb.view(emb.shape[0], -1) # concatenate the vectors\n",
|
174 |
+
"# Linear layer 1\n",
|
175 |
+
"hprebn = embcat @ W1 + b1 # hidden layer pre-activation\n",
|
176 |
+
"# BatchNorm layer\n",
|
177 |
+
"bnmeani = 1/n*hprebn.sum(0, keepdim=True)\n",
|
178 |
+
"bndiff = hprebn - bnmeani\n",
|
179 |
+
"bndiff2 = bndiff**2\n",
|
180 |
+
"bnvar = 1/(n-1)*(bndiff2).sum(0, keepdim=True) # note: Bessel's correction (dividing by n-1, not n)\n",
|
181 |
+
"bnvar_inv = (bnvar + 1e-5)**-0.5\n",
|
182 |
+
"bnraw = bndiff * bnvar_inv\n",
|
183 |
+
"hpreact = bngain * bnraw + bnbias\n",
|
184 |
+
"# Non-linearity\n",
|
185 |
+
"h = torch.tanh(hpreact) # hidden layer\n",
|
186 |
+
"# Linear layer 2\n",
|
187 |
+
"logits = h @ W2 + b2 # output layer\n",
|
188 |
+
"# cross entropy loss (same as F.cross_entropy(logits, Yb))\n",
|
189 |
+
"logit_maxes = logits.max(1, keepdim=True).values\n",
|
190 |
+
"norm_logits = logits - logit_maxes # subtract max for numerical stability\n",
|
191 |
+
"counts = norm_logits.exp()\n",
|
192 |
+
"counts_sum = counts.sum(1, keepdims=True)\n",
|
193 |
+
"counts_sum_inv = counts_sum**-1 # if I use (1.0 / counts_sum) instead then I can't get backprop to be bit exact...\n",
|
194 |
+
"probs = counts * counts_sum_inv\n",
|
195 |
+
"logprobs = probs.log()\n",
|
196 |
+
"loss = -logprobs[range(n), Yb].mean()\n",
|
197 |
+
"\n",
|
198 |
+
"# PyTorch backward pass\n",
|
199 |
+
"for p in parameters:\n",
|
200 |
+
" p.grad = None\n",
|
201 |
+
"for t in [logprobs, probs, counts, counts_sum, counts_sum_inv, # afaik there is no cleaner way\n",
|
202 |
+
" norm_logits, logit_maxes, logits, h, hpreact, bnraw,\n",
|
203 |
+
" bnvar_inv, bnvar, bndiff2, bndiff, hprebn, bnmeani,\n",
|
204 |
+
" embcat, emb]:\n",
|
205 |
+
" t.retain_grad()\n",
|
206 |
+
"loss.backward()\n",
|
207 |
+
"loss"
|
208 |
+
]
|
209 |
+
},
|
210 |
+
{
|
211 |
+
"cell_type": "code",
|
212 |
+
"execution_count": 12,
|
213 |
+
"metadata": {},
|
214 |
+
"outputs": [],
|
215 |
+
"source": [
|
216 |
+
"#The entire Exercise 1 implementation combined\n",
|
217 |
+
"\n",
|
218 |
+
"dlogprobs = torch.zeros_like(logprobs)\n",
|
219 |
+
"dlogprobs[range(n), Yb] = -1.0/n\n",
|
220 |
+
"dprobs = (1.0 / probs) * dlogprobs\n",
|
221 |
+
"dcounts_sum_inv = (counts * dprobs).sum(1, keepdim=True)\n",
|
222 |
+
"dcounts = counts_sum_inv * dprobs\n",
|
223 |
+
"dcounts_sum = (-counts_sum**-2) * dcounts_sum_inv\n",
|
224 |
+
"dcounts += torch.ones_like(counts) * dcounts_sum\n",
|
225 |
+
"dnorm_logits = counts * dcounts\n",
|
226 |
+
"dlogits = dnorm_logits.clone()\n",
|
227 |
+
"dlogit_maxes = (-dnorm_logits).sum(1, keepdim=True)\n",
|
228 |
+
"dlogits += F.one_hot(logits.max(1).indices, num_classes=logits.shape[1]) * dlogit_maxes\n",
|
229 |
+
"dh = dlogits @ W2.T\n",
|
230 |
+
"dW2 = h.T @ dlogits\n",
|
231 |
+
"db2 = dlogits.sum(0)\n",
|
232 |
+
"dhpreact = (1.0 - h**2) * dh\n",
|
233 |
+
"dbngain = (bnraw * dhpreact).sum(0, keepdim=True)\n",
|
234 |
+
"dbnraw = bngain * dhpreact\n",
|
235 |
+
"dbnbias = dhpreact.sum(0, keepdim=True)\n",
|
236 |
+
"dbndiff = bnvar_inv * dbnraw\n",
|
237 |
+
"dbnvar_inv = (bndiff * dbnraw).sum(0, keepdim=True)\n",
|
238 |
+
"dbnvar = (-0.5*(bnvar + 1e-5)**-1.5) * dbnvar_inv\n",
|
239 |
+
"dbndiff2 = (1.0/(n-1))*torch.ones_like(bndiff2) * dbnvar\n",
|
240 |
+
"dbndiff += (2*bndiff) * dbndiff2\n",
|
241 |
+
"dhprebn = dbndiff.clone()\n",
|
242 |
+
"dbnmeani = (-dbndiff).sum(0)\n",
|
243 |
+
"dhprebn += 1.0/n * (torch.ones_like(hprebn) * dbnmeani)\n",
|
244 |
+
"dembcat = dhprebn @ W1.T\n",
|
245 |
+
"dW1 = embcat.T @ dhprebn\n",
|
246 |
+
"db1 = dhprebn.sum(0)\n",
|
247 |
+
"demb = dembcat.view(emb.shape)\n",
|
248 |
+
"dC = torch.zeros_like(C)\n",
|
249 |
+
"for k in range(Xb.shape[0]):\n",
|
250 |
+
" for j in range(Xb.shape[1]):\n",
|
251 |
+
" ix = Xb[k,j]\n",
|
252 |
+
" dC[ix] += demb[k,j]"
|
253 |
+
]
|
254 |
+
},
|
255 |
+
{
|
256 |
+
"cell_type": "markdown",
|
257 |
+
"metadata": {},
|
258 |
+
"source": [
|
259 |
+
"Similar boiler plate codes as done in the prev exercise and provided in the starter code^\n",
|
260 |
+
"\n",
|
261 |
+
"------------"
|
262 |
+
]
|
263 |
+
},
|
264 |
+
{
|
265 |
+
"cell_type": "markdown",
|
266 |
+
"metadata": {},
|
267 |
+
"source": [
|
268 |
+
"[1:36:38](https://youtu.be/q8SA3rM6ckI?si=Lo5Ly5jApvwIBfy9&t=6516) to 1:48:35 - Pen and Paper derivation"
|
269 |
+
]
|
270 |
+
},
|
271 |
+
{
|
272 |
+
"cell_type": "markdown",
|
273 |
+
"metadata": {},
|
274 |
+
"source": [
|
275 |
+
"[1:48:36](https://youtu.be/q8SA3rM6ckI?si=Lo5Ly5jApvwIBfy9&t=6516) to - Implementation of the derivation in code"
|
276 |
+
]
|
277 |
+
},
|
278 |
+
{
|
279 |
+
"cell_type": "code",
|
280 |
+
"execution_count": 13,
|
281 |
+
"metadata": {},
|
282 |
+
"outputs": [
|
283 |
+
{
|
284 |
+
"name": "stdout",
|
285 |
+
"output_type": "stream",
|
286 |
+
"text": [
|
287 |
+
"max diff: tensor(7.1526e-07, grad_fn=<MaxBackward1>)\n"
|
288 |
+
]
|
289 |
+
}
|
290 |
+
],
|
291 |
+
"source": [
|
292 |
+
"# Exercise 3: backprop through batchnorm but all in one go\n",
|
293 |
+
"# to complete this challenge look at the mathematical expression of the output of batchnorm,\n",
|
294 |
+
"# take the derivative w.r.t. its input, simplify the expression, and just write it out\n",
|
295 |
+
"# BatchNorm paper: https://arxiv.org/abs/1502.03167\n",
|
296 |
+
"\n",
|
297 |
+
"# forward pass\n",
|
298 |
+
"\n",
|
299 |
+
"# before:\n",
|
300 |
+
"# bnmeani = 1/n*hprebn.sum(0, keepdim=True)\n",
|
301 |
+
"# bndiff = hprebn - bnmeani\n",
|
302 |
+
"# bndiff2 = bndiff**2\n",
|
303 |
+
"# bnvar = 1/(n-1)*(bndiff2).sum(0, keepdim=True) # note: Bessel's correction (dividing by n-1, not n)\n",
|
304 |
+
"# bnvar_inv = (bnvar + 1e-5)**-0.5\n",
|
305 |
+
"# bnraw = bndiff * bnvar_inv\n",
|
306 |
+
"# hpreact = bngain * bnraw + bnbias\n",
|
307 |
+
"\n",
|
308 |
+
"# now:\n",
|
309 |
+
"hpreact_fast = bngain * (hprebn - hprebn.mean(0, keepdim=True)) / torch.sqrt(hprebn.var(0, keepdim=True, unbiased=True) + 1e-5) + bnbias\n",
|
310 |
+
"print('max diff:', (hpreact_fast - hpreact).abs().max())"
|
311 |
+
]
|
312 |
+
},
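In equation form, the one-liner above is the BatchNorm forward pass from the paper, with the Bessel-corrected (unbiased) variance so it matches the chunked version (my own summary, in the code's notation):

$$
\mu = \frac{1}{n}\sum_i h_i,\qquad
\sigma^2 = \frac{1}{n-1}\sum_i (h_i-\mu)^2,\qquad
\hat h_i = \frac{h_i-\mu}{\sqrt{\sigma^2+\varepsilon}},\qquad
\text{hpreact}_i = \gamma\,\hat h_i + \beta
$$

with $h$ = `hprebn`, $\gamma$ = `bngain`, $\beta$ = `bnbias` and $\varepsilon = 10^{-5}$.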
|
313 |
+
{
|
314 |
+
"cell_type": "code",
|
315 |
+
"execution_count": 14,
|
316 |
+
"metadata": {},
|
317 |
+
"outputs": [
|
318 |
+
{
|
319 |
+
"name": "stdout",
|
320 |
+
"output_type": "stream",
|
321 |
+
"text": [
|
322 |
+
"hprebn | exact: False | approximate: True | maxdiff: 9.313225746154785e-10\n"
|
323 |
+
]
|
324 |
+
}
|
325 |
+
],
|
326 |
+
"source": [
|
327 |
+
"# backward pass\n",
|
328 |
+
"\n",
|
329 |
+
"# before we had:\n",
|
330 |
+
"# dbnraw = bngain * dhpreact\n",
|
331 |
+
"# dbndiff = bnvar_inv * dbnraw\n",
|
332 |
+
"# dbnvar_inv = (bndiff * dbnraw).sum(0, keepdim=True)\n",
|
333 |
+
"# dbnvar = (-0.5*(bnvar + 1e-5)**-1.5) * dbnvar_inv\n",
|
334 |
+
"# dbndiff2 = (1.0/(n-1))*torch.ones_like(bndiff2) * dbnvar\n",
|
335 |
+
"# dbndiff += (2*bndiff) * dbndiff2\n",
|
336 |
+
"# dhprebn = dbndiff.clone()\n",
|
337 |
+
"# dbnmeani = (-dbndiff).sum(0)\n",
|
338 |
+
"# dhprebn += 1.0/n * (torch.ones_like(hprebn) * dbnmeani)\n",
|
339 |
+
"\n",
|
340 |
+
"# calculate dhprebn given dhpreact (i.e. backprop through the batchnorm)\n",
|
341 |
+
"# (you'll also need to use some of the variables from the forward pass up above)\n",
|
342 |
+
"\n",
|
343 |
+
"#This is a direct implementation of what sensei did, as he said in the video this equation itself has a lot of breakdown steps to be considered\n",
|
344 |
+
"#But this is what we come up with at the end\n",
|
345 |
+
"dhprebn = bngain*bnvar_inv/n * (n*dhpreact - dhpreact.sum(0) - n/(n-1)*bnraw*(dhpreact*bnraw).sum(0))\n",
|
346 |
+
"\n",
|
347 |
+
"cmp('hprebn', dhprebn, hprebn) # I can only get approximate to be true, my maxdiff is 9e-10"
|
348 |
+
]
|
349 |
+
}
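For my own reference, the final expression from the pen-and-paper derivation that the `dhprebn` one-liner encodes, per column, with $\hat h$ = `bnraw` and $g_i$ = `dhpreact` for example $i$:

$$
\frac{\partial\,\text{loss}}{\partial h_i}
= \frac{\gamma\,(\sigma^2+\varepsilon)^{-1/2}}{n}
\left( n\,g_i - \sum_{j} g_j - \frac{n}{n-1}\,\hat h_i \sum_{j} g_j\,\hat h_j \right)
$$

Reading it off term by term gives `bngain*bnvar_inv/n * (n*dhpreact - dhpreact.sum(0) - n/(n-1)*bnraw*(dhpreact*bnraw).sum(0))`.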
|
350 |
+
],
|
351 |
+
"metadata": {
|
352 |
+
"kernelspec": {
|
353 |
+
"display_name": "venv",
|
354 |
+
"language": "python",
|
355 |
+
"name": "python3"
|
356 |
+
},
|
357 |
+
"language_info": {
|
358 |
+
"codemirror_mode": {
|
359 |
+
"name": "ipython",
|
360 |
+
"version": 3
|
361 |
+
},
|
362 |
+
"file_extension": ".py",
|
363 |
+
"mimetype": "text/x-python",
|
364 |
+
"name": "python",
|
365 |
+
"nbconvert_exporter": "python",
|
366 |
+
"pygments_lexer": "ipython3",
|
367 |
+
"version": "3.10.0"
|
368 |
+
}
|
369 |
+
},
|
370 |
+
"nbformat": 4,
|
371 |
+
"nbformat_minor": 2
|
372 |
+
}
|
final-code.ipynb
ADDED
@@ -0,0 +1,482 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"import torch\n",
|
10 |
+
"import torch.nn.functional as F\n",
|
11 |
+
"import matplotlib.pyplot as plt # for making figures\n",
|
12 |
+
"%matplotlib inline"
|
13 |
+
]
|
14 |
+
},
|
15 |
+
{
|
16 |
+
"cell_type": "code",
|
17 |
+
"execution_count": 2,
|
18 |
+
"metadata": {},
|
19 |
+
"outputs": [
|
20 |
+
{
|
21 |
+
"name": "stderr",
|
22 |
+
"output_type": "stream",
|
23 |
+
"text": [
|
24 |
+
"'wget' is not recognized as an internal or external command,\n",
|
25 |
+
"operable program or batch file.\n"
|
26 |
+
]
|
27 |
+
}
|
28 |
+
],
|
29 |
+
"source": [
|
30 |
+
"# download the names.txt file from github\n",
|
31 |
+
"!wget https://raw.githubusercontent.com/karpathy/makemore/master/names.txt"
|
32 |
+
]
|
33 |
+
},
|
34 |
+
{
|
35 |
+
"cell_type": "code",
|
36 |
+
"execution_count": 3,
|
37 |
+
"metadata": {},
|
38 |
+
"outputs": [
|
39 |
+
{
|
40 |
+
"name": "stdout",
|
41 |
+
"output_type": "stream",
|
42 |
+
"text": [
|
43 |
+
"32033\n",
|
44 |
+
"15\n",
|
45 |
+
"['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']\n"
|
46 |
+
]
|
47 |
+
}
|
48 |
+
],
|
49 |
+
"source": [
|
50 |
+
"# read in all the words\n",
|
51 |
+
"words = open('names.txt', 'r').read().splitlines()\n",
|
52 |
+
"print(len(words))\n",
|
53 |
+
"print(max(len(w) for w in words))\n",
|
54 |
+
"print(words[:8])"
|
55 |
+
]
|
56 |
+
},
|
57 |
+
{
|
58 |
+
"cell_type": "code",
|
59 |
+
"execution_count": 4,
|
60 |
+
"metadata": {},
|
61 |
+
"outputs": [
|
62 |
+
{
|
63 |
+
"name": "stdout",
|
64 |
+
"output_type": "stream",
|
65 |
+
"text": [
|
66 |
+
"{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}\n",
|
67 |
+
"27\n"
|
68 |
+
]
|
69 |
+
}
|
70 |
+
],
|
71 |
+
"source": [
|
72 |
+
"# build the vocabulary of characters and mappings to/from integers\n",
|
73 |
+
"chars = sorted(list(set(''.join(words))))\n",
|
74 |
+
"stoi = {s:i+1 for i,s in enumerate(chars)}\n",
|
75 |
+
"stoi['.'] = 0\n",
|
76 |
+
"itos = {i:s for s,i in stoi.items()}\n",
|
77 |
+
"vocab_size = len(itos)\n",
|
78 |
+
"print(itos)\n",
|
79 |
+
"print(vocab_size)"
|
80 |
+
]
|
81 |
+
},
|
82 |
+
{
|
83 |
+
"cell_type": "code",
|
84 |
+
"execution_count": 5,
|
85 |
+
"metadata": {},
|
86 |
+
"outputs": [
|
87 |
+
{
|
88 |
+
"name": "stdout",
|
89 |
+
"output_type": "stream",
|
90 |
+
"text": [
|
91 |
+
"torch.Size([182625, 3]) torch.Size([182625])\n",
|
92 |
+
"torch.Size([22655, 3]) torch.Size([22655])\n",
|
93 |
+
"torch.Size([22866, 3]) torch.Size([22866])\n"
|
94 |
+
]
|
95 |
+
}
|
96 |
+
],
|
97 |
+
"source": [
|
98 |
+
"# build the dataset\n",
|
99 |
+
"block_size = 3 # context length: how many characters do we take to predict the next one?\n",
|
100 |
+
"\n",
|
101 |
+
"def build_dataset(words):\n",
|
102 |
+
" X, Y = [], []\n",
|
103 |
+
"\n",
|
104 |
+
" for w in words:\n",
|
105 |
+
" context = [0] * block_size\n",
|
106 |
+
" for ch in w + '.':\n",
|
107 |
+
" ix = stoi[ch]\n",
|
108 |
+
" X.append(context)\n",
|
109 |
+
" Y.append(ix)\n",
|
110 |
+
" context = context[1:] + [ix] # crop and append\n",
|
111 |
+
"\n",
|
112 |
+
" X = torch.tensor(X)\n",
|
113 |
+
" Y = torch.tensor(Y)\n",
|
114 |
+
" print(X.shape, Y.shape)\n",
|
115 |
+
" return X, Y\n",
|
116 |
+
"\n",
|
117 |
+
"import random\n",
|
118 |
+
"random.seed(42)\n",
|
119 |
+
"random.shuffle(words)\n",
|
120 |
+
"n1 = int(0.8*len(words))\n",
|
121 |
+
"n2 = int(0.9*len(words))\n",
|
122 |
+
"\n",
|
123 |
+
"Xtr, Ytr = build_dataset(words[:n1]) # 80%\n",
|
124 |
+
"Xdev, Ydev = build_dataset(words[n1:n2]) # 10%\n",
|
125 |
+
"Xte, Yte = build_dataset(words[n2:]) # 10%"
|
126 |
+
]
|
127 |
+
},
|
128 |
+
{
|
129 |
+
"cell_type": "code",
|
130 |
+
"execution_count": 6,
|
131 |
+
"metadata": {},
|
132 |
+
"outputs": [],
|
133 |
+
"source": [
|
134 |
+
"# utility function we will use later when comparing manual gradients to PyTorch gradients\n",
|
135 |
+
"def cmp(s, dt, t):\n",
|
136 |
+
" ex = torch.all(dt == t.grad).item()\n",
|
137 |
+
" app = torch.allclose(dt, t.grad)\n",
|
138 |
+
" maxdiff = (dt - t.grad).abs().max().item()\n",
|
139 |
+
" print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')"
|
140 |
+
]
|
141 |
+
},
|
142 |
+
{
|
143 |
+
"cell_type": "markdown",
|
144 |
+
"metadata": {},
|
145 |
+
"source": [
|
146 |
+
"---------"
|
147 |
+
]
|
148 |
+
},
|
149 |
+
{
|
150 |
+
"cell_type": "markdown",
|
151 |
+
"metadata": {},
|
152 |
+
"source": [
|
153 |
+
"[1:50:03](https://youtu.be/q8SA3rM6ckI?si=HV1qfgMFKATlDxk1&t=6603) to 1:54:25 - Putting all of the codes together to form a Neural Net, but by commenting out the `loss.backward()` :)"
|
154 |
+
]
|
155 |
+
},
|
156 |
+
{
|
157 |
+
"cell_type": "code",
|
158 |
+
"execution_count": 7,
|
159 |
+
"metadata": {},
|
160 |
+
"outputs": [
|
161 |
+
{
|
162 |
+
"name": "stdout",
|
163 |
+
"output_type": "stream",
|
164 |
+
"text": [
|
165 |
+
"12297\n",
|
166 |
+
" 0/ 200000: 3.8069\n",
|
167 |
+
" 10000/ 200000: 2.1598\n",
|
168 |
+
" 20000/ 200000: 2.4110\n",
|
169 |
+
" 30000/ 200000: 2.4295\n",
|
170 |
+
" 40000/ 200000: 2.0158\n",
|
171 |
+
" 50000/ 200000: 2.4050\n",
|
172 |
+
" 60000/ 200000: 2.3825\n",
|
173 |
+
" 70000/ 200000: 2.0596\n",
|
174 |
+
" 80000/ 200000: 2.3024\n",
|
175 |
+
" 90000/ 200000: 2.2073\n",
|
176 |
+
" 100000/ 200000: 2.0443\n",
|
177 |
+
" 110000/ 200000: 2.2937\n",
|
178 |
+
" 120000/ 200000: 2.0340\n",
|
179 |
+
" 130000/ 200000: 2.4557\n",
|
180 |
+
" 140000/ 200000: 2.2876\n",
|
181 |
+
" 150000/ 200000: 2.2016\n",
|
182 |
+
" 160000/ 200000: 1.9720\n",
|
183 |
+
" 170000/ 200000: 1.8015\n",
|
184 |
+
" 180000/ 200000: 2.0065\n",
|
185 |
+
" 190000/ 200000: 1.9932\n"
|
186 |
+
]
|
187 |
+
}
|
188 |
+
],
|
189 |
+
"source": [
|
190 |
+
"# Exercise 4: putting it all together!\n",
|
191 |
+
"# Train the MLP neural net with your own backward pass\n",
|
192 |
+
"\n",
|
193 |
+
"# init\n",
|
194 |
+
"n_embd = 10 # the dimensionality of the character embedding vectors\n",
|
195 |
+
"n_hidden = 200 # the number of neurons in the hidden layer of the MLP\n",
|
196 |
+
"\n",
|
197 |
+
"g = torch.Generator().manual_seed(2147483647) # for reproducibility\n",
|
198 |
+
"C = torch.randn((vocab_size, n_embd), generator=g)\n",
|
199 |
+
"# Layer 1\n",
|
200 |
+
"W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5/3)/((n_embd * block_size)**0.5)\n",
|
201 |
+
"b1 = torch.randn(n_hidden, generator=g) * 0.1\n",
|
202 |
+
"# Layer 2\n",
|
203 |
+
"W2 = torch.randn((n_hidden, vocab_size), generator=g) * 0.1\n",
|
204 |
+
"b2 = torch.randn(vocab_size, generator=g) * 0.1\n",
|
205 |
+
"# BatchNorm parameters\n",
|
206 |
+
"bngain = torch.randn((1, n_hidden))*0.1 + 1.0\n",
|
207 |
+
"bnbias = torch.randn((1, n_hidden))*0.1\n",
|
208 |
+
"\n",
|
209 |
+
"parameters = [C, W1, b1, W2, b2, bngain, bnbias]\n",
|
210 |
+
"print(sum(p.nelement() for p in parameters)) # number of parameters in total\n",
|
211 |
+
"for p in parameters:\n",
|
212 |
+
" p.requires_grad = True\n",
|
213 |
+
"\n",
|
214 |
+
"# same optimization as last time\n",
|
215 |
+
"max_steps = 200000\n",
|
216 |
+
"batch_size = 32\n",
|
217 |
+
"n = batch_size # convenience\n",
|
218 |
+
"lossi = []\n",
|
219 |
+
"\n",
|
220 |
+
"# use this context manager for efficiency once your backward pass is written (TODO)\n",
|
221 |
+
"with torch.no_grad():\n",
|
222 |
+
"\n",
|
223 |
+
" # kick off optimization\n",
|
224 |
+
" for i in range(max_steps):\n",
|
225 |
+
"\n",
|
226 |
+
" # minibatch construct\n",
|
227 |
+
" ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)\n",
|
228 |
+
" Xb, Yb = Xtr[ix], Ytr[ix] # batch X,Y\n",
|
229 |
+
"\n",
|
230 |
+
" # forward pass\n",
|
231 |
+
" emb = C[Xb] # embed the characters into vectors\n",
|
232 |
+
" embcat = emb.view(emb.shape[0], -1) # concatenate the vectors\n",
|
233 |
+
" # Linear layer\n",
|
234 |
+
" hprebn = embcat @ W1 + b1 # hidden layer pre-activation\n",
|
235 |
+
" # BatchNorm layer\n",
|
236 |
+
" # -------------------------------------------------------------\n",
|
237 |
+
" bnmean = hprebn.mean(0, keepdim=True)\n",
|
238 |
+
" bnvar = hprebn.var(0, keepdim=True, unbiased=True)\n",
|
239 |
+
" bnvar_inv = (bnvar + 1e-5)**-0.5\n",
|
240 |
+
" bnraw = (hprebn - bnmean) * bnvar_inv\n",
|
241 |
+
" hpreact = bngain * bnraw + bnbias\n",
|
242 |
+
" # -------------------------------------------------------------\n",
|
243 |
+
" # Non-linearity\n",
|
244 |
+
" h = torch.tanh(hpreact) # hidden layer\n",
|
245 |
+
" logits = h @ W2 + b2 # output layer\n",
|
246 |
+
" loss = F.cross_entropy(logits, Yb) # loss function\n",
|
247 |
+
"\n",
|
248 |
+
" # backward pass\n",
|
249 |
+
" for p in parameters:\n",
|
250 |
+
" p.grad = None\n",
|
251 |
+
" #loss.backward() # use this for correctness comparisons, delete it later!\n",
|
252 |
+
"\n",
|
253 |
+
" # manual backprop! #swole_doge_meme\n",
|
254 |
+
" # -----------------\n",
|
255 |
+
" dlogits = F.softmax(logits, 1)\n",
|
256 |
+
" dlogits[range(n), Yb] -= 1\n",
|
257 |
+
" dlogits /= n\n",
|
258 |
+
" # 2nd layer backprop\n",
|
259 |
+
" dh = dlogits @ W2.T\n",
|
260 |
+
" dW2 = h.T @ dlogits\n",
|
261 |
+
" db2 = dlogits.sum(0)\n",
|
262 |
+
" # tanh\n",
|
263 |
+
" dhpreact = (1.0 - h**2) * dh\n",
|
264 |
+
" # batchnorm backprop\n",
|
265 |
+
" dbngain = (bnraw * dhpreact).sum(0, keepdim=True)\n",
|
266 |
+
" dbnbias = dhpreact.sum(0, keepdim=True)\n",
|
267 |
+
" dhprebn = bngain*bnvar_inv/n * (n*dhpreact - dhpreact.sum(0) - n/(n-1)*bnraw*(dhpreact*bnraw).sum(0))\n",
|
268 |
+
" # 1st layer\n",
|
269 |
+
" dembcat = dhprebn @ W1.T\n",
|
270 |
+
" dW1 = embcat.T @ dhprebn\n",
|
271 |
+
" db1 = dhprebn.sum(0)\n",
|
272 |
+
" # embedding\n",
|
273 |
+
" demb = dembcat.view(emb.shape)\n",
|
274 |
+
" dC = torch.zeros_like(C)\n",
|
275 |
+
" for k in range(Xb.shape[0]):\n",
|
276 |
+
" for j in range(Xb.shape[1]):\n",
|
277 |
+
" ix = Xb[k,j]\n",
|
278 |
+
" dC[ix] += demb[k,j]\n",
|
279 |
+
" grads = [dC, dW1, db1, dW2, db2, dbngain, dbnbias]\n",
|
280 |
+
" # -----------------\n",
|
281 |
+
"\n",
|
282 |
+
" # update\n",
|
283 |
+
" lr = 0.1 if i < 100000 else 0.01 # step learning rate decay\n",
|
284 |
+
" for p, grad in zip(parameters, grads):\n",
|
285 |
+
" #p.data += -lr * p.grad # old way of cheems doge (using PyTorch grad from .backward())\n",
|
286 |
+
" p.data += -lr * grad # new way of swole doge TODO: enable\n",
|
287 |
+
"\n",
|
288 |
+
" # track stats\n",
|
289 |
+
" if i % 10000 == 0: # print every once in a while\n",
|
290 |
+
" print(f'{i:7d}/{max_steps:7d}: {loss.item():.4f}')\n",
|
291 |
+
" lossi.append(loss.log10().item())\n",
|
292 |
+
"\n",
|
293 |
+
" # if i >= 100: # TODO: delete early breaking when you're ready to train the full net\n",
|
294 |
+
" # break"
|
295 |
+
]
|
296 |
+
},
|
297 |
+
{
|
298 |
+
"cell_type": "code",
|
299 |
+
"execution_count": null,
|
300 |
+
"metadata": {},
|
301 |
+
"outputs": [],
|
302 |
+
"source": [
|
303 |
+
"# Looking at this, probably the batch norm layer backward pass was the most complicated one\n",
|
304 |
+
"# Otherwise, the rest of them were pretty straight forward :)"
|
305 |
+
]
|
306 |
+
},
|
307 |
+
{
|
308 |
+
"cell_type": "code",
|
309 |
+
"execution_count": null,
|
310 |
+
"metadata": {},
|
311 |
+
"outputs": [],
|
312 |
+
"source": [
|
313 |
+
"# useful for checking your gradients\n",
|
314 |
+
"# for p,g in zip(parameters, grads):\n",
|
315 |
+
"# cmp(str(tuple(p.shape)), g, p)"
|
316 |
+
]
|
317 |
+
},
|
318 |
+
{
|
319 |
+
"cell_type": "code",
|
320 |
+
"execution_count": 8,
|
321 |
+
"metadata": {},
|
322 |
+
"outputs": [],
|
323 |
+
"source": [
|
324 |
+
"# calibrate the batch norm at the end of training\n",
|
325 |
+
"\n",
|
326 |
+
"with torch.no_grad():\n",
|
327 |
+
" # pass the training set through\n",
|
328 |
+
" emb = C[Xtr]\n",
|
329 |
+
" embcat = emb.view(emb.shape[0], -1)\n",
|
330 |
+
" hpreact = embcat @ W1 + b1\n",
|
331 |
+
" # measure the mean/std over the entire training set\n",
|
332 |
+
" bnmean = hpreact.mean(0, keepdim=True)\n",
|
333 |
+
" bnvar = hpreact.var(0, keepdim=True, unbiased=True)"
|
334 |
+
]
|
335 |
+
},
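The cell above calibrates the BatchNorm statistics with one extra full pass over the training set after training. A common alternative, sketched below as my own note rather than something this notebook does, is to keep exponential moving averages of the batch statistics inside the training loop (roughly what `nn.BatchNorm1d` does with its `momentum` parameter); the `momentum` value here is an assumed choice.

```python
# Hypothetical alternative (not used in this notebook): track running BatchNorm
# statistics during training instead of doing a separate calibration pass.
momentum = 0.001  # assumed smoothing factor

bnmean_running = torch.zeros((1, n_hidden))
bnvar_running = torch.ones((1, n_hidden))

# inside the training loop, right after bnmean/bnvar are computed for the batch:
with torch.no_grad():
    bnmean_running = (1 - momentum) * bnmean_running + momentum * bnmean
    bnvar_running = (1 - momentum) * bnvar_running + momentum * bnvar

# at evaluation/sampling time, use bnmean_running and bnvar_running in place of
# the statistics measured over the full training set.
```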
|
336 |
+
{
|
337 |
+
"cell_type": "code",
|
338 |
+
"execution_count": 9,
|
339 |
+
"metadata": {},
|
340 |
+
"outputs": [
|
341 |
+
{
|
342 |
+
"name": "stdout",
|
343 |
+
"output_type": "stream",
|
344 |
+
"text": [
|
345 |
+
"train 2.0708959102630615\n",
|
346 |
+
"val 2.1080715656280518\n"
|
347 |
+
]
|
348 |
+
}
|
349 |
+
],
|
350 |
+
"source": [
|
351 |
+
"# evaluate train and val loss\n",
|
352 |
+
"\n",
|
353 |
+
"@torch.no_grad() # this decorator disables gradient tracking\n",
|
354 |
+
"def split_loss(split):\n",
|
355 |
+
" x,y = {\n",
|
356 |
+
" 'train': (Xtr, Ytr),\n",
|
357 |
+
" 'val': (Xdev, Ydev),\n",
|
358 |
+
" 'test': (Xte, Yte),\n",
|
359 |
+
" }[split]\n",
|
360 |
+
" emb = C[x] # (N, block_size, n_embd)\n",
|
361 |
+
" embcat = emb.view(emb.shape[0], -1) # concat into (N, block_size * n_embd)\n",
|
362 |
+
" hpreact = embcat @ W1 + b1\n",
|
363 |
+
" hpreact = bngain * (hpreact - bnmean) * (bnvar + 1e-5)**-0.5 + bnbias\n",
|
364 |
+
" h = torch.tanh(hpreact) # (N, n_hidden)\n",
|
365 |
+
" logits = h @ W2 + b2 # (N, vocab_size)\n",
|
366 |
+
" loss = F.cross_entropy(logits, y)\n",
|
367 |
+
" print(split, loss.item())\n",
|
368 |
+
"\n",
|
369 |
+
"split_loss('train')\n",
|
370 |
+
"split_loss('val')"
|
371 |
+
]
|
372 |
+
},
|
373 |
+
{
|
374 |
+
"cell_type": "code",
|
375 |
+
"execution_count": null,
|
376 |
+
"metadata": {},
|
377 |
+
"outputs": [],
|
378 |
+
"source": [
|
379 |
+
"# Okay probably relatively slightly lower but thats cool"
|
380 |
+
]
|
381 |
+
},
|
382 |
+
{
|
383 |
+
"cell_type": "code",
|
384 |
+
"execution_count": 10,
|
385 |
+
"metadata": {},
|
386 |
+
"outputs": [
|
387 |
+
{
|
388 |
+
"name": "stdout",
|
389 |
+
"output_type": "stream",
|
390 |
+
"text": [
|
391 |
+
"mora.\n",
|
392 |
+
"mayah.\n",
|
393 |
+
"see.\n",
|
394 |
+
"mad.\n",
|
395 |
+
"ryla.\n",
|
396 |
+
"reisha.\n",
|
397 |
+
"endraegan.\n",
|
398 |
+
"chedielin.\n",
|
399 |
+
"shi.\n",
|
400 |
+
"jen.\n",
|
401 |
+
"eden.\n",
|
402 |
+
"sana.\n",
|
403 |
+
"arleigh.\n",
|
404 |
+
"malaia.\n",
|
405 |
+
"noshubergshira.\n",
|
406 |
+
"sten.\n",
|
407 |
+
"joselle.\n",
|
408 |
+
"jose.\n",
|
409 |
+
"casubenteda.\n",
|
410 |
+
"jamell.\n"
|
411 |
+
]
|
412 |
+
}
|
413 |
+
],
|
414 |
+
"source": [
|
415 |
+
"# sample from the model\n",
|
416 |
+
"g = torch.Generator().manual_seed(2147483647 + 10)\n",
|
417 |
+
"\n",
|
418 |
+
"for _ in range(20):\n",
|
419 |
+
" \n",
|
420 |
+
" out = []\n",
|
421 |
+
" context = [0] * block_size # initialize with all ...\n",
|
422 |
+
" while True:\n",
|
423 |
+
" # ------------\n",
|
424 |
+
" # forward pass:\n",
|
425 |
+
" # Embedding\n",
|
426 |
+
" emb = C[torch.tensor([context])] # (1,block_size,d) \n",
|
427 |
+
" embcat = emb.view(emb.shape[0], -1) # concat into (N, block_size * n_embd)\n",
|
428 |
+
" hpreact = embcat @ W1 + b1\n",
|
429 |
+
" hpreact = bngain * (hpreact - bnmean) * (bnvar + 1e-5)**-0.5 + bnbias\n",
|
430 |
+
" h = torch.tanh(hpreact) # (N, n_hidden)\n",
|
431 |
+
" logits = h @ W2 + b2 # (N, vocab_size)\n",
|
432 |
+
" # ------------\n",
|
433 |
+
" # Sample\n",
|
434 |
+
" probs = F.softmax(logits, dim=1)\n",
|
435 |
+
" ix = torch.multinomial(probs, num_samples=1, generator=g).item()\n",
|
436 |
+
" context = context[1:] + [ix]\n",
|
437 |
+
" out.append(ix)\n",
|
438 |
+
" if ix == 0:\n",
|
439 |
+
" break\n",
|
440 |
+
" \n",
|
441 |
+
" print(''.join(itos[i] for i in out))"
|
442 |
+
]
|
443 |
+
},
|
444 |
+
{
|
445 |
+
"cell_type": "code",
|
446 |
+
"execution_count": null,
|
447 |
+
"metadata": {},
|
448 |
+
"outputs": [],
|
449 |
+
"source": [
|
450 |
+
"# I've definetly got some wayyy better names here through most are gibberish xD"
|
451 |
+
]
|
452 |
+
},
|
453 |
+
{
|
454 |
+
"cell_type": "markdown",
|
455 |
+
"metadata": {},
|
456 |
+
"source": [
|
457 |
+
"And that marks the end of exploring the basic understanding of the 'intuition' of training NN using (traditional) methods. We will be moving on to more complex ones from here on - like RNN etc. So looking forward to that :)"
|
458 |
+
]
|
459 |
+
}
|
460 |
+
],
|
461 |
+
"metadata": {
|
462 |
+
"kernelspec": {
|
463 |
+
"display_name": "venv",
|
464 |
+
"language": "python",
|
465 |
+
"name": "python3"
|
466 |
+
},
|
467 |
+
"language_info": {
|
468 |
+
"codemirror_mode": {
|
469 |
+
"name": "ipython",
|
470 |
+
"version": 3
|
471 |
+
},
|
472 |
+
"file_extension": ".py",
|
473 |
+
"mimetype": "text/x-python",
|
474 |
+
"name": "python",
|
475 |
+
"nbconvert_exporter": "python",
|
476 |
+
"pygments_lexer": "ipython3",
|
477 |
+
"version": "3.10.0"
|
478 |
+
}
|
479 |
+
},
|
480 |
+
"nbformat": 4,
|
481 |
+
"nbformat_minor": 2
|
482 |
+
}
|
names.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
starter-code.ipynb
ADDED
@@ -0,0 +1,623 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {
|
6 |
+
"id": "rToK0Tku8PPn"
|
7 |
+
},
|
8 |
+
"source": [
|
9 |
+
"## makemore: becoming a backprop ninja"
|
10 |
+
]
|
11 |
+
},
|
12 |
+
{
|
13 |
+
"cell_type": "code",
|
14 |
+
"execution_count": null,
|
15 |
+
"metadata": {
|
16 |
+
"id": "8sFElPqq8PPp"
|
17 |
+
},
|
18 |
+
"outputs": [],
|
19 |
+
"source": [
|
20 |
+
"# there no change change in the first several cells from last lecture"
|
21 |
+
]
|
22 |
+
},
|
23 |
+
{
|
24 |
+
"cell_type": "code",
|
25 |
+
"execution_count": null,
|
26 |
+
"metadata": {
|
27 |
+
"id": "ChBbac4y8PPq"
|
28 |
+
},
|
29 |
+
"outputs": [],
|
30 |
+
"source": [
|
31 |
+
"import torch\n",
|
32 |
+
"import torch.nn.functional as F\n",
|
33 |
+
"import matplotlib.pyplot as plt # for making figures\n",
|
34 |
+
"%matplotlib inline"
|
35 |
+
]
|
36 |
+
},
|
37 |
+
{
|
38 |
+
"cell_type": "code",
|
39 |
+
"source": [
|
40 |
+
"# download the names.txt file from github\n",
|
41 |
+
"!wget https://raw.githubusercontent.com/karpathy/makemore/master/names.txt"
|
42 |
+
],
|
43 |
+
"metadata": {
|
44 |
+
"id": "x6GhEWW18aCS"
|
45 |
+
},
|
46 |
+
"execution_count": null,
|
47 |
+
"outputs": []
|
48 |
+
},
|
49 |
+
{
|
50 |
+
"cell_type": "code",
|
51 |
+
"execution_count": null,
|
52 |
+
"metadata": {
|
53 |
+
"id": "klmu3ZG08PPr"
|
54 |
+
},
|
55 |
+
"outputs": [],
|
56 |
+
"source": [
|
57 |
+
"# read in all the words\n",
|
58 |
+
"words = open('names.txt', 'r').read().splitlines()\n",
|
59 |
+
"print(len(words))\n",
|
60 |
+
"print(max(len(w) for w in words))\n",
|
61 |
+
"print(words[:8])"
|
62 |
+
]
|
63 |
+
},
|
64 |
+
{
|
65 |
+
"cell_type": "code",
|
66 |
+
"execution_count": null,
|
67 |
+
"metadata": {
|
68 |
+
"id": "BCQomLE_8PPs"
|
69 |
+
},
|
70 |
+
"outputs": [],
|
71 |
+
"source": [
|
72 |
+
"# build the vocabulary of characters and mappings to/from integers\n",
|
73 |
+
"chars = sorted(list(set(''.join(words))))\n",
|
74 |
+
"stoi = {s:i+1 for i,s in enumerate(chars)}\n",
|
75 |
+
"stoi['.'] = 0\n",
|
76 |
+
"itos = {i:s for s,i in stoi.items()}\n",
|
77 |
+
"vocab_size = len(itos)\n",
|
78 |
+
"print(itos)\n",
|
79 |
+
"print(vocab_size)"
|
80 |
+
]
|
81 |
+
},
|
82 |
+
{
|
83 |
+
"cell_type": "code",
|
84 |
+
"execution_count": null,
|
85 |
+
"metadata": {
|
86 |
+
"id": "V_zt2QHr8PPs"
|
87 |
+
},
|
88 |
+
"outputs": [],
|
89 |
+
"source": [
|
90 |
+
"# build the dataset\n",
|
91 |
+
"block_size = 3 # context length: how many characters do we take to predict the next one?\n",
|
92 |
+
"\n",
|
93 |
+
"def build_dataset(words):\n",
|
94 |
+
" X, Y = [], []\n",
|
95 |
+
"\n",
|
96 |
+
" for w in words:\n",
|
97 |
+
" context = [0] * block_size\n",
|
98 |
+
" for ch in w + '.':\n",
|
99 |
+
" ix = stoi[ch]\n",
|
100 |
+
" X.append(context)\n",
|
101 |
+
" Y.append(ix)\n",
|
102 |
+
" context = context[1:] + [ix] # crop and append\n",
|
103 |
+
"\n",
|
104 |
+
" X = torch.tensor(X)\n",
|
105 |
+
" Y = torch.tensor(Y)\n",
|
106 |
+
" print(X.shape, Y.shape)\n",
|
107 |
+
" return X, Y\n",
|
108 |
+
"\n",
|
109 |
+
"import random\n",
|
110 |
+
"random.seed(42)\n",
|
111 |
+
"random.shuffle(words)\n",
|
112 |
+
"n1 = int(0.8*len(words))\n",
|
113 |
+
"n2 = int(0.9*len(words))\n",
|
114 |
+
"\n",
|
115 |
+
"Xtr, Ytr = build_dataset(words[:n1]) # 80%\n",
|
116 |
+
"Xdev, Ydev = build_dataset(words[n1:n2]) # 10%\n",
|
117 |
+
"Xte, Yte = build_dataset(words[n2:]) # 10%"
|
118 |
+
]
|
119 |
+
},
|
120 |
+
{
|
121 |
+
"cell_type": "code",
|
122 |
+
"execution_count": null,
|
123 |
+
"metadata": {
|
124 |
+
"id": "eg20-vsg8PPt"
|
125 |
+
},
|
126 |
+
"outputs": [],
|
127 |
+
"source": [
|
128 |
+
"# ok biolerplate done, now we get to the action:"
|
129 |
+
]
|
130 |
+
},
|
131 |
+
{
|
132 |
+
"cell_type": "code",
|
133 |
+
"execution_count": null,
|
134 |
+
"metadata": {
|
135 |
+
"id": "MJPU8HT08PPu"
|
136 |
+
},
|
137 |
+
"outputs": [],
|
138 |
+
"source": [
|
139 |
+
"# utility function we will use later when comparing manual gradients to PyTorch gradients\n",
|
140 |
+
"def cmp(s, dt, t):\n",
|
141 |
+
" ex = torch.all(dt == t.grad).item()\n",
|
142 |
+
" app = torch.allclose(dt, t.grad)\n",
|
143 |
+
" maxdiff = (dt - t.grad).abs().max().item()\n",
|
144 |
+
" print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')"
|
145 |
+
]
|
146 |
+
},
|
147 |
+
{
|
148 |
+
"cell_type": "code",
|
149 |
+
"execution_count": null,
|
150 |
+
"metadata": {
|
151 |
+
"id": "ZlFLjQyT8PPu"
|
152 |
+
},
|
153 |
+
"outputs": [],
|
154 |
+
"source": [
|
155 |
+
"n_embd = 10 # the dimensionality of the character embedding vectors\n",
|
156 |
+
"n_hidden = 64 # the number of neurons in the hidden layer of the MLP\n",
|
157 |
+
"\n",
|
158 |
+
"g = torch.Generator().manual_seed(2147483647) # for reproducibility\n",
|
159 |
+
"C = torch.randn((vocab_size, n_embd), generator=g)\n",
|
160 |
+
"# Layer 1\n",
|
161 |
+
"W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5/3)/((n_embd * block_size)**0.5)\n",
|
162 |
+
"b1 = torch.randn(n_hidden, generator=g) * 0.1 # using b1 just for fun, it's useless because of BN\n",
|
163 |
+
"# Layer 2\n",
|
164 |
+
"W2 = torch.randn((n_hidden, vocab_size), generator=g) * 0.1\n",
|
165 |
+
"b2 = torch.randn(vocab_size, generator=g) * 0.1\n",
|
166 |
+
"# BatchNorm parameters\n",
|
167 |
+
"bngain = torch.randn((1, n_hidden))*0.1 + 1.0\n",
|
168 |
+
"bnbias = torch.randn((1, n_hidden))*0.1\n",
|
169 |
+
"\n",
|
170 |
+
"# Note: I am initializating many of these parameters in non-standard ways\n",
|
171 |
+
"# because sometimes initializating with e.g. all zeros could mask an incorrect\n",
|
172 |
+
"# implementation of the backward pass.\n",
|
173 |
+
"\n",
|
174 |
+
"parameters = [C, W1, b1, W2, b2, bngain, bnbias]\n",
|
175 |
+
"print(sum(p.nelement() for p in parameters)) # number of parameters in total\n",
|
176 |
+
"for p in parameters:\n",
|
177 |
+
" p.requires_grad = True"
|
178 |
+
]
|
179 |
+
},
|
180 |
+
{
|
181 |
+
"cell_type": "code",
|
182 |
+
"execution_count": null,
|
183 |
+
"metadata": {
|
184 |
+
"id": "QY-y96Y48PPv"
|
185 |
+
},
|
186 |
+
"outputs": [],
|
187 |
+
"source": [
|
188 |
+
"batch_size = 32\n",
|
189 |
+
"n = batch_size # a shorter variable also, for convenience\n",
|
190 |
+
"# construct a minibatch\n",
|
191 |
+
"ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)\n",
|
192 |
+
"Xb, Yb = Xtr[ix], Ytr[ix] # batch X,Y"
|
193 |
+
]
|
194 |
+
},
|
195 |
+
{
|
196 |
+
"cell_type": "code",
|
197 |
+
"execution_count": null,
|
198 |
+
"metadata": {
|
199 |
+
"id": "8ofj1s6d8PPv"
|
200 |
+
},
|
201 |
+
"outputs": [],
|
202 |
+
"source": [
|
203 |
+
"# forward pass, \"chunkated\" into smaller steps that are possible to backward one at a time\n",
|
204 |
+
"\n",
|
205 |
+
"emb = C[Xb] # embed the characters into vectors\n",
|
206 |
+
"embcat = emb.view(emb.shape[0], -1) # concatenate the vectors\n",
|
207 |
+
"# Linear layer 1\n",
|
208 |
+
"hprebn = embcat @ W1 + b1 # hidden layer pre-activation\n",
|
209 |
+
"# BatchNorm layer\n",
|
210 |
+
"bnmeani = 1/n*hprebn.sum(0, keepdim=True)\n",
|
211 |
+
"bndiff = hprebn - bnmeani\n",
|
212 |
+
"bndiff2 = bndiff**2\n",
|
213 |
+
"bnvar = 1/(n-1)*(bndiff2).sum(0, keepdim=True) # note: Bessel's correction (dividing by n-1, not n)\n",
|
214 |
+
"bnvar_inv = (bnvar + 1e-5)**-0.5\n",
|
215 |
+
"bnraw = bndiff * bnvar_inv\n",
|
216 |
+
"hpreact = bngain * bnraw + bnbias\n",
|
217 |
+
"# Non-linearity\n",
|
218 |
+
"h = torch.tanh(hpreact) # hidden layer\n",
|
219 |
+
"# Linear layer 2\n",
|
220 |
+
"logits = h @ W2 + b2 # output layer\n",
|
221 |
+
"# cross entropy loss (same as F.cross_entropy(logits, Yb))\n",
|
222 |
+
"logit_maxes = logits.max(1, keepdim=True).values\n",
|
223 |
+
"norm_logits = logits - logit_maxes # subtract max for numerical stability\n",
|
224 |
+
"counts = norm_logits.exp()\n",
|
225 |
+
"counts_sum = counts.sum(1, keepdims=True)\n",
|
226 |
+
"counts_sum_inv = counts_sum**-1 # if I use (1.0 / counts_sum) instead then I can't get backprop to be bit exact...\n",
|
227 |
+
"probs = counts * counts_sum_inv\n",
|
228 |
+
"logprobs = probs.log()\n",
|
229 |
+
"loss = -logprobs[range(n), Yb].mean()\n",
|
230 |
+
"\n",
|
231 |
+
"# PyTorch backward pass\n",
|
232 |
+
"for p in parameters:\n",
|
233 |
+
" p.grad = None\n",
|
234 |
+
"for t in [logprobs, probs, counts, counts_sum, counts_sum_inv, # afaik there is no cleaner way\n",
|
235 |
+
" norm_logits, logit_maxes, logits, h, hpreact, bnraw,\n",
|
236 |
+
" bnvar_inv, bnvar, bndiff2, bndiff, hprebn, bnmeani,\n",
|
237 |
+
" embcat, emb]:\n",
|
238 |
+
" t.retain_grad()\n",
|
239 |
+
"loss.backward()\n",
|
240 |
+
"loss"
|
241 |
+
]
|
242 |
+
},
|
243 |
+
{
|
244 |
+
"cell_type": "code",
|
245 |
+
"execution_count": null,
|
246 |
+
"metadata": {
|
247 |
+
"id": "mO-8aqxK8PPw"
|
248 |
+
},
|
249 |
+
"outputs": [],
|
250 |
+
"source": [
|
251 |
+
"# Exercise 1: backprop through the whole thing manually,\n",
|
252 |
+
"# backpropagating through exactly all of the variables\n",
|
253 |
+
"# as they are defined in the forward pass above, one by one\n",
|
254 |
+
"\n",
|
255 |
+
"# -----------------\n",
|
256 |
+
"# YOUR CODE HERE :)\n",
|
257 |
+
"# -----------------\n",
|
258 |
+
"\n",
|
259 |
+
"# cmp('logprobs', dlogprobs, logprobs)\n",
|
260 |
+
"# cmp('probs', dprobs, probs)\n",
|
261 |
+
"# cmp('counts_sum_inv', dcounts_sum_inv, counts_sum_inv)\n",
|
262 |
+
"# cmp('counts_sum', dcounts_sum, counts_sum)\n",
|
263 |
+
"# cmp('counts', dcounts, counts)\n",
|
264 |
+
"# cmp('norm_logits', dnorm_logits, norm_logits)\n",
|
265 |
+
"# cmp('logit_maxes', dlogit_maxes, logit_maxes)\n",
|
266 |
+
"# cmp('logits', dlogits, logits)\n",
|
267 |
+
"# cmp('h', dh, h)\n",
|
268 |
+
"# cmp('W2', dW2, W2)\n",
|
269 |
+
"# cmp('b2', db2, b2)\n",
|
270 |
+
"# cmp('hpreact', dhpreact, hpreact)\n",
|
271 |
+
"# cmp('bngain', dbngain, bngain)\n",
|
272 |
+
"# cmp('bnbias', dbnbias, bnbias)\n",
|
273 |
+
"# cmp('bnraw', dbnraw, bnraw)\n",
|
274 |
+
"# cmp('bnvar_inv', dbnvar_inv, bnvar_inv)\n",
|
275 |
+
"# cmp('bnvar', dbnvar, bnvar)\n",
|
276 |
+
"# cmp('bndiff2', dbndiff2, bndiff2)\n",
|
277 |
+
"# cmp('bndiff', dbndiff, bndiff)\n",
|
278 |
+
"# cmp('bnmeani', dbnmeani, bnmeani)\n",
|
279 |
+
"# cmp('hprebn', dhprebn, hprebn)\n",
|
280 |
+
"# cmp('embcat', dembcat, embcat)\n",
|
281 |
+
"# cmp('W1', dW1, W1)\n",
|
282 |
+
"# cmp('b1', db1, b1)\n",
|
283 |
+
"# cmp('emb', demb, emb)\n",
|
284 |
+
"# cmp('C', dC, C)"
|
285 |
+
]
|
286 |
+
},
|
287 |
+
{
|
288 |
+
"cell_type": "code",
|
289 |
+
"execution_count": null,
|
290 |
+
"metadata": {
|
291 |
+
"id": "ebLtYji_8PPw"
|
292 |
+
},
|
293 |
+
"outputs": [],
|
294 |
+
"source": [
|
295 |
+
"# Exercise 2: backprop through cross_entropy but all in one go\n",
|
296 |
+
"# to complete this challenge look at the mathematical expression of the loss,\n",
|
297 |
+
"# take the derivative, simplify the expression, and just write it out\n",
|
298 |
+
"\n",
|
299 |
+
"# forward pass\n",
|
300 |
+
"\n",
|
301 |
+
"# before:\n",
|
302 |
+
"# logit_maxes = logits.max(1, keepdim=True).values\n",
|
303 |
+
"# norm_logits = logits - logit_maxes # subtract max for numerical stability\n",
|
304 |
+
"# counts = norm_logits.exp()\n",
|
305 |
+
"# counts_sum = counts.sum(1, keepdims=True)\n",
|
306 |
+
"# counts_sum_inv = counts_sum**-1 # if I use (1.0 / counts_sum) instead then I can't get backprop to be bit exact...\n",
|
307 |
+
"# probs = counts * counts_sum_inv\n",
|
308 |
+
"# logprobs = probs.log()\n",
|
309 |
+
"# loss = -logprobs[range(n), Yb].mean()\n",
|
310 |
+
"\n",
|
311 |
+
"# now:\n",
|
312 |
+
"loss_fast = F.cross_entropy(logits, Yb)\n",
|
313 |
+
"print(loss_fast.item(), 'diff:', (loss_fast - loss).item())"
|
314 |
+
]
|
315 |
+
},
|
316 |
+
{
|
317 |
+
"cell_type": "code",
|
318 |
+
"execution_count": null,
|
319 |
+
"metadata": {
|
320 |
+
"id": "-gCXbB4C8PPx"
|
321 |
+
},
|
322 |
+
"outputs": [],
|
323 |
+
"source": [
|
324 |
+
"# backward pass\n",
|
325 |
+
"\n",
|
326 |
+
"# -----------------\n",
|
327 |
+
"# YOUR CODE HERE :)\n",
|
328 |
+
"dlogits = None # TODO. my solution is 3 lines\n",
|
329 |
+
"# -----------------\n",
|
330 |
+
"\n",
|
331 |
+
"#cmp('logits', dlogits, logits) # I can only get approximate to be true, my maxdiff is 6e-9"
|
332 |
+
]
|
333 |
+
},
|
334 |
+
{
|
335 |
+
"cell_type": "code",
|
336 |
+
"execution_count": null,
|
337 |
+
"metadata": {
|
338 |
+
"id": "hd-MkhB68PPy"
|
339 |
+
},
|
340 |
+
"outputs": [],
|
341 |
+
"source": [
|
342 |
+
"# Exercise 3: backprop through batchnorm but all in one go\n",
|
343 |
+
"# to complete this challenge look at the mathematical expression of the output of batchnorm,\n",
|
344 |
+
"# take the derivative w.r.t. its input, simplify the expression, and just write it out\n",
|
345 |
+
"# BatchNorm paper: https://arxiv.org/abs/1502.03167\n",
|
346 |
+
"\n",
|
347 |
+
"# forward pass\n",
|
348 |
+
"\n",
|
349 |
+
"# before:\n",
|
350 |
+
"# bnmeani = 1/n*hprebn.sum(0, keepdim=True)\n",
|
351 |
+
"# bndiff = hprebn - bnmeani\n",
|
352 |
+
"# bndiff2 = bndiff**2\n",
|
353 |
+
"# bnvar = 1/(n-1)*(bndiff2).sum(0, keepdim=True) # note: Bessel's correction (dividing by n-1, not n)\n",
|
354 |
+
"# bnvar_inv = (bnvar + 1e-5)**-0.5\n",
|
355 |
+
"# bnraw = bndiff * bnvar_inv\n",
|
356 |
+
"# hpreact = bngain * bnraw + bnbias\n",
|
357 |
+
"\n",
|
358 |
+
"# now:\n",
|
359 |
+
"hpreact_fast = bngain * (hprebn - hprebn.mean(0, keepdim=True)) / torch.sqrt(hprebn.var(0, keepdim=True, unbiased=True) + 1e-5) + bnbias\n",
|
360 |
+
"print('max diff:', (hpreact_fast - hpreact).abs().max())"
|
361 |
+
]
|
362 |
+
},
|
363 |
+
{
|
364 |
+
"cell_type": "code",
|
365 |
+
"execution_count": null,
|
366 |
+
"metadata": {
|
367 |
+
"id": "POdeZSKT8PPy"
|
368 |
+
},
|
369 |
+
"outputs": [],
|
370 |
+
"source": [
|
371 |
+
"# backward pass\n",
|
372 |
+
"\n",
|
373 |
+
"# before we had:\n",
|
374 |
+
"# dbnraw = bngain * dhpreact\n",
|
375 |
+
"# dbndiff = bnvar_inv * dbnraw\n",
|
376 |
+
"# dbnvar_inv = (bndiff * dbnraw).sum(0, keepdim=True)\n",
|
377 |
+
"# dbnvar = (-0.5*(bnvar + 1e-5)**-1.5) * dbnvar_inv\n",
|
378 |
+
"# dbndiff2 = (1.0/(n-1))*torch.ones_like(bndiff2) * dbnvar\n",
|
379 |
+
"# dbndiff += (2*bndiff) * dbndiff2\n",
|
380 |
+
"# dhprebn = dbndiff.clone()\n",
|
381 |
+
"# dbnmeani = (-dbndiff).sum(0)\n",
|
382 |
+
"# dhprebn += 1.0/n * (torch.ones_like(hprebn) * dbnmeani)\n",
|
383 |
+
"\n",
|
384 |
+
"# calculate dhprebn given dhpreact (i.e. backprop through the batchnorm)\n",
|
385 |
+
"# (you'll also need to use some of the variables from the forward pass up above)\n",
|
386 |
+
"\n",
|
387 |
+
"# -----------------\n",
|
388 |
+
"# YOUR CODE HERE :)\n",
|
389 |
+
"dhprebn = None # TODO. my solution is 1 (long) line\n",
|
390 |
+
"# -----------------\n",
|
391 |
+
"\n",
|
392 |
+
"cmp('hprebn', dhprebn, hprebn) # I can only get approximate to be true, my maxdiff is 9e-10"
|
393 |
+
]
|
394 |
+
},
|
395 |
+
{
|
396 |
+
"cell_type": "code",
|
397 |
+
"execution_count": null,
|
398 |
+
"metadata": {
|
399 |
+
"id": "wPy8DhqB8PPz"
|
400 |
+
},
|
401 |
+
"outputs": [],
|
402 |
+
"source": [
|
403 |
+
"# Exercise 4: putting it all together!\n",
|
404 |
+
"# Train the MLP neural net with your own backward pass\n",
|
405 |
+
"\n",
|
406 |
+
"# init\n",
|
407 |
+
"n_embd = 10 # the dimensionality of the character embedding vectors\n",
|
408 |
+
"n_hidden = 200 # the number of neurons in the hidden layer of the MLP\n",
|
409 |
+
"\n",
|
410 |
+
"g = torch.Generator().manual_seed(2147483647) # for reproducibility\n",
|
411 |
+
"C = torch.randn((vocab_size, n_embd), generator=g)\n",
|
412 |
+
"# Layer 1\n",
|
413 |
+
"W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5/3)/((n_embd * block_size)**0.5)\n",
|
414 |
+
"b1 = torch.randn(n_hidden, generator=g) * 0.1\n",
|
415 |
+
"# Layer 2\n",
|
416 |
+
"W2 = torch.randn((n_hidden, vocab_size), generator=g) * 0.1\n",
|
417 |
+
"b2 = torch.randn(vocab_size, generator=g) * 0.1\n",
|
418 |
+
"# BatchNorm parameters\n",
|
419 |
+
"bngain = torch.randn((1, n_hidden))*0.1 + 1.0\n",
|
420 |
+
"bnbias = torch.randn((1, n_hidden))*0.1\n",
|
421 |
+
"\n",
|
422 |
+
"parameters = [C, W1, b1, W2, b2, bngain, bnbias]\n",
|
423 |
+
"print(sum(p.nelement() for p in parameters)) # number of parameters in total\n",
|
424 |
+
"for p in parameters:\n",
|
425 |
+
" p.requires_grad = True\n",
|
426 |
+
"\n",
|
427 |
+
"# same optimization as last time\n",
|
428 |
+
"max_steps = 200000\n",
|
429 |
+
"batch_size = 32\n",
|
430 |
+
"n = batch_size # convenience\n",
|
431 |
+
"lossi = []\n",
|
432 |
+
"\n",
|
433 |
+
"# use this context manager for efficiency once your backward pass is written (TODO)\n",
|
434 |
+
"#with torch.no_grad():\n",
|
435 |
+
"\n",
|
436 |
+
"# kick off optimization\n",
|
437 |
+
"for i in range(max_steps):\n",
|
438 |
+
"\n",
|
439 |
+
" # minibatch construct\n",
|
440 |
+
" ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)\n",
|
441 |
+
" Xb, Yb = Xtr[ix], Ytr[ix] # batch X,Y\n",
|
442 |
+
"\n",
|
443 |
+
" # forward pass\n",
|
444 |
+
" emb = C[Xb] # embed the characters into vectors\n",
|
445 |
+
" embcat = emb.view(emb.shape[0], -1) # concatenate the vectors\n",
|
446 |
+
" # Linear layer\n",
|
447 |
+
" hprebn = embcat @ W1 + b1 # hidden layer pre-activation\n",
|
448 |
+
" # BatchNorm layer\n",
|
449 |
+
" # -------------------------------------------------------------\n",
|
450 |
+
" bnmean = hprebn.mean(0, keepdim=True)\n",
|
451 |
+
" bnvar = hprebn.var(0, keepdim=True, unbiased=True)\n",
|
452 |
+
" bnvar_inv = (bnvar + 1e-5)**-0.5\n",
|
453 |
+
" bnraw = (hprebn - bnmean) * bnvar_inv\n",
|
454 |
+
" hpreact = bngain * bnraw + bnbias\n",
|
455 |
+
" # -------------------------------------------------------------\n",
|
456 |
+
" # Non-linearity\n",
|
457 |
+
" h = torch.tanh(hpreact) # hidden layer\n",
|
458 |
+
" logits = h @ W2 + b2 # output layer\n",
|
459 |
+
" loss = F.cross_entropy(logits, Yb) # loss function\n",
|
460 |
+
"\n",
|
461 |
+
" # backward pass\n",
|
462 |
+
" for p in parameters:\n",
|
463 |
+
" p.grad = None\n",
|
464 |
+
" loss.backward() # use this for correctness comparisons, delete it later!\n",
|
465 |
+
"\n",
|
466 |
+
" # manual backprop! #swole_doge_meme\n",
|
467 |
+
" # -----------------\n",
|
468 |
+
" # YOUR CODE HERE :)\n",
|
469 |
+
" dC, dW1, db1, dW2, db2, dbngain, dbnbias = None, None, None, None, None, None, None\n",
|
470 |
+
" grads = [dC, dW1, db1, dW2, db2, dbngain, dbnbias]\n",
|
471 |
+
" # -----------------\n",
|
472 |
+
"\n",
|
473 |
+
" # update\n",
|
474 |
+
" lr = 0.1 if i < 100000 else 0.01 # step learning rate decay\n",
|
475 |
+
" for p, grad in zip(parameters, grads):\n",
|
476 |
+
" p.data += -lr * p.grad # old way of cheems doge (using PyTorch grad from .backward())\n",
|
477 |
+
" #p.data += -lr * grad # new way of swole doge TODO: enable\n",
|
478 |
+
"\n",
|
479 |
+
" # track stats\n",
|
480 |
+
" if i % 10000 == 0: # print every once in a while\n",
|
481 |
+
" print(f'{i:7d}/{max_steps:7d}: {loss.item():.4f}')\n",
|
482 |
+
" lossi.append(loss.log10().item())\n",
|
483 |
+
"\n",
|
484 |
+
" if i >= 100: # TODO: delete early breaking when you're ready to train the full net\n",
|
485 |
+
" break"
|
486 |
+
]
|
487 |
+
},
|
488 |
+
{
|
489 |
+
"cell_type": "code",
|
490 |
+
"execution_count": null,
|
491 |
+
"metadata": {
|
492 |
+
"id": "ZEpI0hMW8PPz"
|
493 |
+
},
|
494 |
+
"outputs": [],
|
495 |
+
"source": [
|
496 |
+
"# useful for checking your gradients\n",
|
497 |
+
"# for p,g in zip(parameters, grads):\n",
|
498 |
+
"# cmp(str(tuple(p.shape)), g, p)"
|
499 |
+
]
|
500 |
+
},
|
501 |
+
{
|
502 |
+
"cell_type": "code",
|
503 |
+
"execution_count": null,
|
504 |
+
"metadata": {
|
505 |
+
"id": "KImLWNoh8PP0"
|
506 |
+
},
|
507 |
+
"outputs": [],
|
508 |
+
"source": [
|
509 |
+
"# calibrate the batch norm at the end of training\n",
|
510 |
+
"\n",
|
511 |
+
"with torch.no_grad():\n",
|
512 |
+
" # pass the training set through\n",
|
513 |
+
" emb = C[Xtr]\n",
|
514 |
+
" embcat = emb.view(emb.shape[0], -1)\n",
|
515 |
+
" hpreact = embcat @ W1 + b1\n",
|
516 |
+
" # measure the mean/std over the entire training set\n",
|
517 |
+
" bnmean = hpreact.mean(0, keepdim=True)\n",
|
518 |
+
" bnvar = hpreact.var(0, keepdim=True, unbiased=True)\n"
|
519 |
+
]
|
520 |
+
},
|
521 |
+
{
|
522 |
+
"cell_type": "code",
|
523 |
+
"execution_count": null,
|
524 |
+
"metadata": {
|
525 |
+
"id": "6aFnP_Zc8PP0"
|
526 |
+
},
|
527 |
+
"outputs": [],
|
528 |
+
"source": [
|
529 |
+
"# evaluate train and val loss\n",
|
530 |
+
"\n",
|
531 |
+
"@torch.no_grad() # this decorator disables gradient tracking\n",
|
532 |
+
"def split_loss(split):\n",
|
533 |
+
" x,y = {\n",
|
534 |
+
" 'train': (Xtr, Ytr),\n",
|
535 |
+
" 'val': (Xdev, Ydev),\n",
|
536 |
+
" 'test': (Xte, Yte),\n",
|
537 |
+
" }[split]\n",
|
538 |
+
" emb = C[x] # (N, block_size, n_embd)\n",
|
539 |
+
" embcat = emb.view(emb.shape[0], -1) # concat into (N, block_size * n_embd)\n",
|
540 |
+
" hpreact = embcat @ W1 + b1\n",
|
541 |
+
" hpreact = bngain * (hpreact - bnmean) * (bnvar + 1e-5)**-0.5 + bnbias\n",
|
542 |
+
" h = torch.tanh(hpreact) # (N, n_hidden)\n",
|
543 |
+
" logits = h @ W2 + b2 # (N, vocab_size)\n",
|
544 |
+
" loss = F.cross_entropy(logits, y)\n",
|
545 |
+
" print(split, loss.item())\n",
|
546 |
+
"\n",
|
547 |
+
"split_loss('train')\n",
|
548 |
+
"split_loss('val')"
|
549 |
+
]
|
550 |
+
},
|
551 |
+
{
|
552 |
+
"cell_type": "code",
|
553 |
+
"execution_count": null,
|
554 |
+
"metadata": {
|
555 |
+
"id": "esWqmhyj8PP1"
|
556 |
+
},
|
557 |
+
"outputs": [],
|
558 |
+
"source": [
|
559 |
+
"# I achieved:\n",
|
560 |
+
"# train 2.0718822479248047\n",
|
561 |
+
"# val 2.1162495613098145"
|
562 |
+
]
|
563 |
+
},
|
564 |
+
{
|
565 |
+
"cell_type": "code",
|
566 |
+
"execution_count": null,
|
567 |
+
"metadata": {
|
568 |
+
"id": "xHeQNv3s8PP1"
|
569 |
+
},
|
570 |
+
"outputs": [],
|
571 |
+
"source": [
|
572 |
+
"# sample from the model\n",
|
573 |
+
"g = torch.Generator().manual_seed(2147483647 + 10)\n",
|
574 |
+
"\n",
|
575 |
+
"for _ in range(20):\n",
|
576 |
+
"\n",
|
577 |
+
" out = []\n",
|
578 |
+
" context = [0] * block_size # initialize with all ...\n",
|
579 |
+
" while True:\n",
|
580 |
+
" # forward pass\n",
|
581 |
+
" emb = C[torch.tensor([context])] # (1,block_size,d)\n",
|
582 |
+
" embcat = emb.view(emb.shape[0], -1) # concat into (N, block_size * n_embd)\n",
|
583 |
+
" hpreact = embcat @ W1 + b1\n",
|
584 |
+
" hpreact = bngain * (hpreact - bnmean) * (bnvar + 1e-5)**-0.5 + bnbias\n",
|
585 |
+
" h = torch.tanh(hpreact) # (N, n_hidden)\n",
|
586 |
+
" logits = h @ W2 + b2 # (N, vocab_size)\n",
|
587 |
+
" # sample\n",
|
588 |
+
" probs = F.softmax(logits, dim=1)\n",
|
589 |
+
" ix = torch.multinomial(probs, num_samples=1, generator=g).item()\n",
|
590 |
+
" context = context[1:] + [ix]\n",
|
591 |
+
" out.append(ix)\n",
|
592 |
+
" if ix == 0:\n",
|
593 |
+
" break\n",
|
594 |
+
"\n",
|
595 |
+
" print(''.join(itos[i] for i in out))"
|
596 |
+
]
|
597 |
+
}
|
598 |
+
],
|
599 |
+
"metadata": {
|
600 |
+
"kernelspec": {
|
601 |
+
"display_name": "Python 3",
|
602 |
+
"language": "python",
|
603 |
+
"name": "python3"
|
604 |
+
},
|
605 |
+
"language_info": {
|
606 |
+
"codemirror_mode": {
|
607 |
+
"name": "ipython",
|
608 |
+
"version": 3
|
609 |
+
},
|
610 |
+
"file_extension": ".py",
|
611 |
+
"mimetype": "text/x-python",
|
612 |
+
"name": "python",
|
613 |
+
"nbconvert_exporter": "python",
|
614 |
+
"pygments_lexer": "ipython3",
|
615 |
+
"version": "3.8.5"
|
616 |
+
},
|
617 |
+
"colab": {
|
618 |
+
"provenance": []
|
619 |
+
}
|
620 |
+
},
|
621 |
+
"nbformat": 4,
|
622 |
+
"nbformat_minor": 0
|
623 |
+
}
|