Hi everyone. So by now you have probably heard of ChatGPT. It has taken the world and the AI community by storm, and it is a system that allows you to interact with an AI and give it text-based tasks. So for example, we can ask ChatGPT to write us a small haiku about how important it is that people understand AI, and then they can use it to improve the world and make it more prosperous. So when we run this: "AI knowledge brings prosperity for all to see, embrace its power." Okay, not bad. And so you could see that ChatGPT went from left to right and generated all these words sort of sequentially.

Now, I asked it the exact same prompt a little bit earlier, and it generated a slightly different outcome: "AI's power to grow, ignorance holds us back, learn, prosperity waits." So pretty good in both cases, and slightly different. So you can see that ChatGPT is a probabilistic system, and for any one prompt it can give us multiple answers in reply.

Now, this is just one example of a prompt. People have come up with many, many examples, and there are entire websites that index interactions with ChatGPT. Many of them are quite humorous: explain HTML to me like I'm a dog, write release notes for Chess 2, write a note about Elon Musk buying Twitter, and so on. So as an example: please write a breaking news article about a leaf falling from a tree.
In a shocking turn of events, a leaf has fallen from a tree in the local park. Witnesses report that the leaf, which was previously attached to a branch of a tree, detached itself and fell to the ground. Very dramatic. So you can see that this is a pretty remarkable system, and it is what we call a language model, because it models the sequence of words, or characters, or tokens more generally, and it knows how certain words follow each other in the English language. So from its perspective, what it is doing is completing the sequence: I give it the start of a sequence, and it completes the sequence with the outcome. It's a language model in that sense.

Now, I would like to focus on the under-the-hood components of what makes ChatGPT work. So what is the neural network under the hood that models the sequence of these words? That comes from a landmark 2017 paper in AI called "Attention is All You Need," which proposed the Transformer architecture. GPT is short for Generatively Pretrained Transformer, and the Transformer is the neural net that actually does all the heavy lifting under the hood. It comes from this paper in 2017.
Now, if you read this paper, it reads like a pretty random machine translation paper, and that's because I think the authors didn't fully anticipate the impact that the Transformer would have on the field. This architecture that they produced in the context of machine translation actually ended up taking over the rest of AI in the next five years. And so this architecture, with minor changes, was copy-pasted into a huge amount of applications in AI in more recent years, and that includes the core of ChatGPT.

Now, what I'd like to do is build out something like ChatGPT, but of course we're not going to be able to reproduce ChatGPT. It is a very serious production-grade system. It is trained on a good chunk of the internet, and then there are a lot of pre-training and fine-tuning stages to it, so it's very complicated. What I'd like to focus on is just training a Transformer-based language model, and in our case it's going to be a character-level language model. I still think that is very educational with respect to how these systems work. And since I don't want to train on a chunk of the internet, we need a smaller dataset. In this case, I propose that we work with my favorite toy dataset. It's called Tiny Shakespeare, and let me show you what it looks like.
And what it is is basically a concatenation of all of the works of Shakespeare, to my understanding. So this is all of Shakespeare in a single file, and the file is about one megabyte. What we are going to do now is model how these characters follow each other. So for example, given a chunk of these characters, some context of characters in the past, the Transformer neural network will look at the characters that I've highlighted and predict that "g" is likely to come next in the sequence. It's going to do that because we're going to train the Transformer on Shakespeare, and it's just going to try to produce character sequences that look like this. In that process it's going to model all the patterns inside this data. Once we've trained the system, I'd just like to give you a preview: we can generate infinite Shakespeare. And of course it's a fake thing that looks kind of like Shakespeare (apologies for some jank in here that I'm not able to resolve). You can see how this is going character by character, and it's kind of predicting Shakespeare-like language: "verily my lord the sights have left the again the king coming with my curses with precious pale," and then Tranio says something else, etc. And this is just coming out of the Transformer in a very similar manner as it would come out of ChatGPT. In our case it's coming out character by character; in ChatGPT it's coming out on the token-by-token level, and tokens are these sort of little sub-word pieces, so they're not word level, they're kind of word-chunk level.

Now, I've already written the entire code to train these Transformers, and it is in a GitHub repository called nanoGPT. So nanoGPT is a repository that you can find on my GitHub, and it's a repository for training Transformers on any given text. What I think is interesting about it (because there are many ways to train Transformers) is that this is a very simple implementation: it's just two files of 300 lines of code each. One file defines the GPT model, the Transformer, and one file trains it on some given text dataset. And here I'm showing that if you train it on the OpenWebText dataset, a bunch of web pages, then I reproduce the performance of GPT-2. GPT-2 is an early version of OpenAI's GPT, from 2019 if I recall correctly, and I've only so far reproduced the smallest, 124-million-parameter model. But basically this is just proving that the codebase is correctly arranged, and that I'm able to load the neural network weights that OpenAI released. So you can take a look at the finished code here in nanoGPT.

What I would like to do in this lecture is basically write this repository from scratch. We're going to begin with an empty file and define a Transformer piece by piece. We're going to train it on the Tiny Shakespeare dataset, and we'll see how we can then generate infinite Shakespeare. Of course you can copy-paste this to any arbitrary text dataset that you like, but my goal really here is to just make you understand and appreciate how, under the hood, ChatGPT works. Really all that's required is a proficiency in Python and some basic understanding of calculus and statistics. It would also help if you've seen my previous videos on the same YouTube channel, in particular my makemore series, where I define smaller and simpler neural network language models (multi-layer perceptrons and so on). That series really introduces the language modeling framework, and here in this video we're going to focus on the Transformer. So let's look at the general structure of the neural network itself.

Okay, so I created a new Google Colab Jupyter notebook here, and this will allow me to later easily share this code that we're going to develop together, so you can follow along; the link will be in the video description. Now, here I've just done some preliminaries. I downloaded the dataset, the Tiny Shakespeare dataset, at this URL, and you can see that it's about a one-megabyte file. Then here I open the input.txt file and just read in all the text as a string, and you can see that we are working with roughly one million characters. And the first 1,000 characters, if we just print them out, are basically what you would expect:
this is the first 1,000 characters of the Tiny Shakespeare dataset, roughly up to here. So, so far so good. Next, we're going to take this text, and the text is a sequence of characters in Python, so when I call the set constructor on it I'm just going to get the set of all the characters that occur in this text. Then I'm going to sort that to create a list of those characters instead of just a set, so that I have an ordering (an arbitrary but consistent one). So basically we get all the characters that occur in the entire dataset, and they're sorted. Now, the number of them is going to be our vocabulary size: these are the possible elements of our sequences. We see that when I print the characters here, there are 65 of them in total: there's a space character, then all kinds of special characters, and then uppercase and lowercase letters. So that's our vocabulary, and those are the possible characters that the model can see or emit. Okay, so next we would like to develop some strategy to tokenize the input text. Now, when people say "tokenize," they mean convert the raw text, as a string, to some sequence of integers according to some vocabulary of possible elements. As an example, here we are going to be building a character-level language model, so we're simply going to be translating individual characters into integers.
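To make the character-level scheme concrete, here is a minimal sketch of such a tokenizer. It is illustrative rather than the exact notebook code, and it uses a tiny stand-in string in place of the full Tiny Shakespeare text:

```python
# A minimal sketch of a character-level tokenizer: build the sorted
# vocabulary, then map characters to integers and back.
text = "hi there"                      # stand-in for the full Tiny Shakespeare text
chars = sorted(set(text))              # unique characters, in a fixed order
vocab_size = len(chars)

stoi = {ch: i for i, ch in enumerate(chars)}   # string -> integer
itos = {i: ch for i, ch in enumerate(chars)}   # integer -> string

encode = lambda s: [stoi[c] for c in s]         # string -> list of ints
decode = lambda l: "".join(itos[i] for i in l)  # list of ints -> string

print(encode("hi there"))
print(decode(encode("hi there")))  # round-trips back to "hi there"
```

The same two lookup tables cover both directions, so decoding an encoded string always returns the original string unchanged.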
So let me show you a chunk of code that does that for us. We're building both the encoder and the decoder, and let me just talk through what's happening here. When we encode an arbitrary text like "hi there," we're going to receive a list of integers that represents that string, for example 46, 47, etc. And then we also have the reverse mapping, so we can take this list and decode it to get back the exact same string. So it's really just like a translation to integers and back for an arbitrary string, and for us it is done on a character level. Now, the way this was achieved is that we just iterate over all the characters here and create a lookup table from the character to the integer and vice versa. Then, to encode some string, we simply translate all the characters individually, and to decode it back, we use the reverse mapping and concatenate all of it. Now, this is only one of many possible encodings, or many possible tokenizers, and it's a very simple one, but there are many other schemes that people have come up with in practice. For example, Google uses SentencePiece. SentencePiece will also encode text into integers, but in a different scheme and using a different vocabulary. SentencePiece is a sub-word sort of tokenizer, and what that means is that you're not encoding entire words, but you're also not encoding individual characters;
it's a sub-word unit level, and that's usually what's adopted in practice. For example, OpenAI also has a library called tiktoken that uses a byte-pair encoding tokenizer, and that's what GPT uses; you can encode arbitrary strings, like "hello world," into lists of integers. So as an example, I'm using the tiktoken library here, and I'm getting the encoding for GPT-2, or that was used for GPT-2. Instead of just having 65 possible characters or tokens, they have roughly 50,000 tokens, and so when they encode the exact same string "hi there," we only get a list of three integers, but those integers are not between 0 and 64; they are between 0 and 50,256. So basically you can trade off the codebook size and the sequence length: you can have very long sequences of integers with very small vocabularies, or you can have short sequences of integers with very large vocabularies. And so typically in practice people use these sub-word encodings, but I'd like to keep our tokenizer very simple, so we're using a character-level tokenizer. That means that we have a very small codebook and very simple encode and decode functions, but we do get very long sequences as a result. But that's the level at
which we're going to stick with for this lecture, because it's the simplest thing. Okay, so now that we have an encoder and a decoder, effectively a tokenizer, we can tokenize the entire training set of Shakespeare. Here's a chunk of code that does that, and I'm going to start to use the PyTorch library, and specifically torch.tensor from the PyTorch library. We're going to take all of the text in Tiny Shakespeare, encode it, and then wrap it in a torch.tensor to get the data tensor. Here's what the data tensor looks like when I look at just the first 1,000 elements of it: we see that we have a massive sequence of integers, and this sequence of integers is basically an identical translation of the first 1,000 characters. I believe, for example, that 0 is the newline character and maybe 1 is a space; I'm not 100% sure. But from now on, the entire dataset of text is re-represented, just stretched out, as a single very large sequence of integers. Let me do one more thing before we move on: we're going to separate our dataset into a train and a validation split. In particular, we're going to take the first 90% of the dataset and consider that to be the training data for the Transformer, and we're going to withhold the last 10% at the end to be the validation data. This will help us understand to what extent our model is overfitting. So we're going to basically
hide the validation data and keep it on the side, because we don't want just a perfect memorization of this exact Shakespeare; we want a neural network that creates Shakespeare-like text. And so we're going to use the stowed-away true Shakespeare text to get a sense of the overfitting. Okay, so now we would like to start plugging these text sequences, or integer sequences, into the Transformer so that it can train and learn those patterns. Now, the important thing to realize is that we're never going to actually feed the entire text into the Transformer all at once; that would be computationally very expensive and prohibitive. So when we actually train a Transformer on a lot of these datasets, we only work with chunks of the dataset: we basically sample random little chunks out of the training set and train on just chunks at a time. And these chunks have some maximum length. Now, the maximum length, at least in the code I usually write, is called block size; you can find it under different names, like context length or something like that. Let's start with a block size of just eight, and let me look at the first train-data characters, the first block size plus one characters (I'll explain why plus one in a second). So this is the first nine characters in the sequence, in the training set. Now, what I'd like to point out
is that when you sample a chunk of data like this (say these nine characters out of the training set), this actually has multiple examples packed into it, and that's because all of these characters follow each other. So what this thing is going to say, when we plug it into a Transformer, is that we're going to actually simultaneously train it to make a prediction at every one of these positions. In a chunk of nine characters, there are actually eight individual examples packed in there. There's the example that, in the context of 18, 47 likely comes next; in the context of 18 and 47, 56 comes next; in the context of 18, 47, and 56, 57 can come next; and so on. So those are the eight individual examples. Let me actually spell it out with code. So here's a chunk of code to illustrate: x are the inputs to the Transformer, and they will just be the first block size characters. y will be the next block size characters, so it's offset by one, and that's because y are the targets for each position in the input. And then here I'm iterating over all the block size of eight, and the context is always all the characters in x up to t and including t, and the target is always the t-th character, but in the targets array y. So let me just run this.
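Spelled out in code, the eight examples hidden in that chunk look something like this. This is a sketch: the nine token values here are a stand-in matching the numbers quoted above, not the full data tensor:

```python
import torch

# The chunking idea described above, on a stand-in for the data tensor.
data = torch.tensor([18, 47, 56, 57, 58, 1, 15, 47, 58])
block_size = 8

x = data[:block_size]        # inputs: the first block_size tokens
y = data[1:block_size + 1]   # targets: the same chunk, offset by one

for t in range(block_size):
    context = x[:t + 1]      # everything up to and including position t
    target = y[t]            # the character that follows that context
    print(f"when input is {context.tolist()} the target is {int(target)}")
```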
And basically it spells out what I said in words: these are the eight examples hidden in a chunk of nine characters that we sampled from the training set. Let me just mention one more thing. We train on all eight examples here, with context between one all the way up to a context of block size, and we do that not just for computational reasons, not just because we happen to have the sequence already, and not just for efficiency. It's also done to make the Transformer network used to seeing contexts all the way from as little as one up to block size; we'd like the Transformer to be used to seeing everything in between. And that's going to be useful later during inference, because while we're sampling, we can start the sampling generation with as little as one character of context, and the Transformer knows how to predict the next character even with a context of just one. Then it can predict everything up to block size, and after block size we have to start truncating, because the Transformer will never receive more than block size inputs when it's predicting the next character.
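During generation, that truncation can be as simple as cropping the running context before each prediction. A tiny sketch (the variable names here are illustrative, not the lecture's code):

```python
import torch

block_size = 8
# Pretend we've already sampled 12 tokens into a running sequence `idx`
# of shape (batch, time).
idx = torch.zeros((1, 12), dtype=torch.long)

# Crop the context to at most the last block_size tokens before the
# next forward pass, since the model never sees more than block_size.
idx_cond = idx[:, -block_size:]
print(idx_cond.shape)  # torch.Size([1, 8])
```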
Okay, so we've looked at the time dimension of the tensors that are going to be feeding into the Transformer. There's one more dimension to care about, and that is the batch dimension. As we're sampling these chunks of text, every time we feed them into a Transformer we're going to have batches of multiple chunks of text, all stacked up in a single tensor. And that's done just for efficiency, so that we can keep the GPUs busy, because they are very good at parallel processing of data, and we just want to process multiple chunks all at the same time. But those chunks are processed completely independently; they don't talk to each other, and so on. So let me basically just generalize this and introduce a batch dimension. Here's a chunk of code; let me just run it, and then I'm going to explain what it does. So here, because we're going to start sampling random locations in the dataset to pull chunks from, I am setting the seed of the random number generator, so that the numbers I see here are going to be the same numbers you see later if you try to reproduce this. Now, the batch size here is how many independent sequences we will process in every forward/backward pass of the Transformer, and the block size, as I explained, is the maximum context length to make those predictions.
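Putting the batch and time dimensions together, a batch-sampling function along the lines described here might look like this. This is a sketch: the `train_data` and `val_data` tensors are random stand-ins for the real encoded Shakespeare split from earlier:

```python
import torch

torch.manual_seed(1337)  # illustrative seed, for reproducibility

# Stand-ins for the encoded train/validation split from earlier.
train_data = torch.randint(0, 65, (1000,))
val_data = torch.randint(0, 65, (100,))
batch_size = 4   # independent sequences per forward/backward pass
block_size = 8   # maximum context length for predictions

def get_batch(split):
    data = train_data if split == "train" else val_data
    # batch_size random offsets, leaving room for a full block plus target
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i + block_size] for i in ix])          # inputs
    y = torch.stack([data[i + 1:i + block_size + 1] for i in ix])  # targets, offset by one
    return x, y

xb, yb = get_batch("train")
print(xb.shape, yb.shape)  # torch.Size([4, 8]) torch.Size([4, 8])
```

Each row of `xb` is an independent chunk of the data, and the matching row of `yb` is that same chunk shifted one position to the right.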
So let's say batch size four, block size eight. And then here's how we get a batch for an arbitrary split: if the split is a training split, then we're going to look at train data, otherwise at val data, and that gives us the data array. Then I generate random positions to grab a chunk out of: ix, via torch.randint, will generate batch-size-many random offsets. So because the batch size is four, ix is going to be four numbers that are randomly generated between 0 and len(data) minus block size; it's just random offsets into the training set. Then the x's, as I explained, are the first block size characters starting at i, and the y's are the offset-by-one version of that, so we just add plus one. And then we're going to
get those chunks for every one of the integers i in ix, and use torch.stack to take all those one-dimensional tensors, as we saw here, and stack them up as rows, so they all become a row in a 4-by-8 tensor. So here's where I'm printing them. When I sample a batch, xb and yb, the inputs to the Transformer are in x, the 4-by-8 tensor: four rows of eight columns, and each one of these is a chunk of the training set. Then the targets are in the associated array y, and they will come into the Transformer all the way at the end, to create the loss function; they give us the correct answer for every single position inside x. And these are four independent rows, so, spelled out as we did before, this 4-by-8 array contains a total of 32 examples, and they're completely independent as far as the Transformer is concerned. So when the input is 24, the target is 43 (or rather, 43 here in the y array); when the input is 24 and 43,
when the input is 2443 58 the target is 5 etc or like when it is a 52581 the 263 00:21:39,220 --> 00:21:46,580 target is 58 right so you can sort of see this spelled out these are the 32 independent examples 264 00:21:46,580 --> 00:21:52,740 packed in to a single batch of the input x and then the desired targets are in y 265 00:21:53,780 --> 00:22:00,580 and so now this integer tensor of x is going to feed into the transformer 266 00:22:01,300 --> 00:22:05,860 and that transformer is going to simultaneously process all these examples and then look up the 267 00:22:05,860 --> 00:22:12,260 correct integers to predict in every one of these positions in the tensor y okay so now 268 00:22:12,260 --> 00:22:16,660 that we have our batch of input that we'd like to feed into a transformer let's start basically 269 00:22:16,660 --> 00:22:21,060 feeding this into neural networks now we're going to start off with the simplest possible 270 00:22:21,060 --> 00:22:24,660 neural network which in the case of language modeling in my opinion is the bigram language 271 00:22:24,660 --> 00:22:30,100 model and we've covered the bigram language model in my make more series in a lot of depth and so 272 00:22:30,100 --> 00:22:34,980 here i'm going to sort of go faster and let's just implement the pytorch module directly that 273 00:22:34,980 --> 00:22:40,740 implements the bigram language model so i'm importing the pytorch nn module 274 00:22:42,180 --> 00:22:47,060 for reproducibility and then here i'm constructing a bigram language model which is a subclass of 275 00:22:47,060 --> 00:22:52,740 nn module and then i'm calling it and i'm passing in the inputs and the targets 276 00:22:53,700 --> 00:22:58,420 and i'm just printing now when the inputs and targets come here you see that i'm just taking the 277 00:22:58,420 --> 00:23:00,020 index the inputs and targets and then i'm just printing the inputs and targets and then i'm just 278 00:23:00,020 --> 00:23:04,420 printing the inputs x 
here, which i rename to idx, and i'm just passing them into this token 279 00:23:04,420 --> 00:23:10,580 embedding table. so what's going on here is that here in the constructor we are creating a token 280 00:23:10,580 --> 00:23:17,940 embedding table, and it is of size vocab size by vocab size, and we're using nn.Embedding, 281 00:23:17,940 --> 00:23:23,300 which is a very thin wrapper around basically a tensor of shape vocab size by vocab size, 282 00:23:24,020 --> 00:23:30,020 and what's happening here is that when we pass idx here, every single integer in our input is going to 283 00:23:30,020 --> 00:23:35,380 refer to this embedding table and is going to pluck out a row of that embedding table corresponding 284 00:23:35,380 --> 00:23:42,740 to its index. so 24 here will go to the embedding table and we'll pluck out the 24th row, and then 43 285 00:23:42,740 --> 00:23:48,420 will go here and pluck out the 43rd row, etc., and then pytorch is going to arrange all of this into 286 00:23:48,420 --> 00:23:58,500 a batch by time by channel tensor. in this case batch is 4, time is 8, and c, which is the channels, is vocab size, 287 00:24:00,740 --> 00:24:04,900 and so we're just going to pluck out all those rows, arrange them in a b by t by c, 288 00:24:05,700 --> 00:24:09,540 and now we're going to interpret this as the logits, which are basically the scores 289 00:24:10,100 --> 00:24:14,260 for the next character in a sequence. and so what's happening here is 290 00:24:14,260 --> 00:24:19,300 we are predicting what comes next based on just the individual identity of a single token, 291 00:24:19,940 --> 00:24:24,740 and you can do that because currently the tokens are not talking to each other and 292 00:24:24,740 --> 00:24:29,940 they're not seeing any context, except that they're just seeing themselves. so i'm token number 293 00:24:30,020 --> 00:24:35,220 five, and then i can actually make pretty decent predictions about what comes next just by knowing 294 00:24:35,220 -->
00:24:42,580 that i'm token 5, because some characters, you know, follow other characters in typical scenarios. 295 00:24:42,580 --> 00:24:47,300 so we saw a lot of this in a lot more depth in the makemore series, and here if i just run this, 296 00:24:48,020 --> 00:24:54,420 then we currently get the predictions, the scores, the logits, for every one of the four by eight 297 00:24:54,420 --> 00:24:58,740 positions. now that we've made predictions about what comes next, we'd like to evaluate the loss 298 00:24:58,740 --> 00:25:00,000 function, and so in the 299 00:25:00,020 --> 00:25:05,680 makemore series we saw that a good way to measure a loss, or like the quality of the predictions, is to 300 00:25:05,680 --> 00:25:10,420 use the negative log likelihood loss, which is also implemented in PyTorch under the name cross 301 00:25:10,420 --> 00:25:17,800 entropy. So what we'd like to do here is loss is the cross entropy on the predictions and the 302 00:25:17,800 --> 00:25:23,260 targets, and so this measures the quality of the logits with respect to the targets. In other words, 303 00:25:23,260 --> 00:25:28,960 we have the identity of the next character, so how well are we predicting the next character based 304 00:25:28,960 --> 00:25:36,880 on the logits? And intuitively, the correct dimension of logits, depending on whatever 305 00:25:36,880 --> 00:25:41,040 the target is, should have a very high number, and all the other dimensions should be a very low number, 306 00:25:41,040 --> 00:25:46,920 right.
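The row-plucking behavior being described can be sketched in isolation (vocab size 65 as in the video; the random `idx` batch is a placeholder for real data):

```python
import torch
import torch.nn as nn

torch.manual_seed(1337)
vocab_size = 65

# a (vocab_size, vocab_size) lookup table: row i holds the logits for
# "what character follows token i" in a bigram model
token_embedding_table = nn.Embedding(vocab_size, vocab_size)

idx = torch.randint(0, vocab_size, (4, 8))  # a (B, T) batch of token indices
logits = token_embedding_table(idx)         # every integer plucks out its row
print(logits.shape)                         # (B, T, C) = (4, 8, 65)
```

Each position's logits depend only on that position's token, which is exactly why the bigram model needs no context.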
Now the issue is that this won't actually run. This is what we want: we want to basically output 307 00:25:46,920 --> 00:25:56,400 the logits and the loss, but unfortunately this won't actually run; we get an 308 00:25:56,400 --> 00:25:58,540 error message, but intuitively we want to 310 00:25:58,940 --> 00:26:05,640 measure this. Now when we go to the PyTorch cross entropy documentation here, 311 00:26:05,640 --> 00:26:11,760 we're trying to call the cross entropy in its functional form, so that means we don't have to 312 00:26:11,760 --> 00:26:17,500 create like a module for it, but here when we go to the documentation, you have to look into the 313 00:26:17,500 --> 00:26:23,140 details of how PyTorch expects these inputs, and basically the issue here is that PyTorch expects, 314 00:26:23,140 --> 00:26:27,820 if you have multi-dimensional input, which we do because we have a b by t by c tensor, 316 00:26:28,520 --> 00:26:35,220 then it actually really wants the channels to be the second dimension here. 317 00:26:36,620 --> 00:26:44,620 So basically it wants a b by c by t instead of a b by t by c, and those are just the details of how PyTorch 318 00:26:44,620 --> 00:26:51,380 treats these kinds of inputs. And so we don't actually want to deal with that, so what we're 319 00:26:51,380 --> 00:26:55,520 going to do instead is we need to basically reshape our logits. So here's what I like to do: I 320 00:26:55,520 --> 00:26:58,420 like to basically give names to the dimensions. 321 00:26:58,420 --> 00:27:02,660 So logits.shape is B by T by C, and unpack those numbers. 322 00:27:02,660 --> 00:27:07,240 And then let's say that logits equals logits.view. 323 00:27:07,240 --> 00:27:10,780 And we want it to be B times T by C.
324 00:27:10,780 --> 00:27:14,000 So just a two-dimensional array, right? 325 00:27:14,000 --> 00:27:16,120 So we're going to take, 326 00:27:16,120 --> 00:27:19,840 we're going to take all of these positions here 327 00:27:19,840 --> 00:27:21,480 and we're going to stretch them out 328 00:27:21,480 --> 00:27:23,780 in a one-dimensional sequence 329 00:27:23,780 --> 00:27:26,840 and preserve the channel dimension as the second dimension. 330 00:27:26,840 --> 00:27:29,460 So we're just kind of like stretching out the array 331 00:27:29,460 --> 00:27:30,840 so it's two-dimensional. 332 00:27:30,840 --> 00:27:32,880 And in that case, it's going to better conform 333 00:27:32,880 --> 00:27:36,380 to what PyTorch sort of expects in its dimensions. 334 00:27:36,380 --> 00:27:38,600 Now we have to do the same to targets 335 00:27:38,600 --> 00:27:43,600 because currently targets are of shape B by T 336 00:27:44,720 --> 00:27:46,980 and we want it to be just B times T. 337 00:27:46,980 --> 00:27:48,620 So one-dimensional. 338 00:27:48,620 --> 00:27:51,860 Now, alternatively, you could always still just do minus one 339 00:27:51,860 --> 00:27:54,080 because PyTorch will guess what this should be 340 00:27:54,080 --> 00:27:55,380 if you want to lay it out that way. 341 00:27:55,380 --> 00:27:56,500 But let me just be explicit 342 00:27:56,500 --> 00:27:56,820 and say 343 00:27:56,820 --> 00:27:58,200 that it's going to be B times T. 344 00:27:58,200 --> 00:27:59,720 Once we reshape this, 345 00:27:59,720 --> 00:28:02,900 it will match what cross-entropy expects, 346 00:28:02,900 --> 00:28:05,200 and then we should be able to evaluate our loss. 347 00:28:07,100 --> 00:28:11,000 Okay, so that runs now, and we can print the loss. 348 00:28:11,000 --> 00:28:14,620 And so currently we see that the loss is 4.87.
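Putting those pieces together, here is a minimal sketch of the bigram forward pass with the reshape that makes `F.cross_entropy` happy (names follow the video; the random data is a stand-in, so the exact loss value will differ from the 4.87 seen on screen):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(1337)
vocab_size = 65

class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        logits = self.token_embedding_table(idx)  # (B, T, C)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            # cross_entropy wants (N, C) logits and (N,) targets,
            # so stretch out the batch and time dimensions together
            loss = F.cross_entropy(logits.view(B*T, C), targets.view(B*T))
        return logits, loss

m = BigramLanguageModel(vocab_size)
xb = torch.randint(0, vocab_size, (4, 8))
yb = torch.randint(0, vocab_size, (4, 8))
logits, loss = m(xb, yb)
print(logits.shape, loss.item())  # (4, 8, 65) and a loss near -ln(1/65) ≈ 4.17
```

Using `view(-1, C)` and `view(-1)` would work equally well, since PyTorch can infer the flattened dimension.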
349 00:28:14,620 --> 00:28:19,100 Now, because we have 65 possible vocabulary elements, 350 00:28:19,100 --> 00:28:21,680 we can actually guess at what the loss should be. 351 00:28:21,680 --> 00:28:23,200 And in particular, 352 00:28:23,200 --> 00:28:25,920 we covered negative log likelihood in a lot of detail. 353 00:28:25,920 --> 00:28:26,480 We are expecting 354 00:28:26,480 --> 00:28:31,480 the log, or ln, of one over 65, 355 00:28:32,380 --> 00:28:33,940 and the negative of that. 356 00:28:33,940 --> 00:28:37,520 So we're expecting the loss to be about 4.17, 357 00:28:37,520 --> 00:28:39,260 but we're getting 4.87. 358 00:28:39,260 --> 00:28:41,200 And so that's telling us that the initial predictions 359 00:28:41,200 --> 00:28:43,060 are not super diffuse. 360 00:28:43,060 --> 00:28:44,760 They've got a little bit of entropy, 361 00:28:44,760 --> 00:28:46,160 and so we're guessing wrong. 362 00:28:47,200 --> 00:28:52,200 So, yes, but we are able to evaluate the loss. 363 00:28:52,800 --> 00:28:55,780 Okay, so now that we can evaluate the quality of the model 364 00:28:55,780 --> 00:28:57,180 on some data, 365 00:28:57,180 --> 00:28:59,300 we'd like to also be able to generate from the model. 366 00:28:59,300 --> 00:29:01,200 So let's do the generation. 367 00:29:01,200 --> 00:29:03,040 Now I'm going to go again a little bit faster here 368 00:29:03,040 --> 00:29:06,680 because I covered all this already in the previous videos. 369 00:29:06,680 --> 00:29:10,240 So here's a generate function for the model. 370 00:29:12,420 --> 00:29:13,800 So we take 371 00:29:13,800 --> 00:29:17,460 the same kind of input IDX here. 372 00:29:17,460 --> 00:29:22,460 And basically this is the current context of some characters 373 00:29:23,240 --> 00:29:25,180 in some batch. 374 00:29:25,780 --> 00:29:27,480 So it's also B by T.
375 00:29:27,480 --> 00:29:30,820 And the job of generate is to basically take this B by T 376 00:29:30,820 --> 00:29:33,640 and extend it to be B by T plus one, plus two, plus three. 377 00:29:33,640 --> 00:29:34,700 And so it's just basically, 378 00:29:34,700 --> 00:29:37,540 it continues the generation in all the batch dimensions 379 00:29:37,540 --> 00:29:39,320 in the time dimension. 380 00:29:39,320 --> 00:29:40,500 So that's its job. 381 00:29:40,500 --> 00:29:43,080 And it will do that for max new tokens. 382 00:29:43,080 --> 00:29:44,500 So you can see here on the bottom, 383 00:29:44,500 --> 00:29:45,920 there's going to be some stuff here, 384 00:29:45,920 --> 00:29:46,820 but on the bottom, 385 00:29:46,820 --> 00:29:49,760 whatever is predicted is concatenated 386 00:29:49,760 --> 00:29:53,080 on top of the previous IDX along the first dimension, 387 00:29:53,080 --> 00:29:55,700 which is the time dimension to create a B by T plus one. 388 00:29:55,780 --> 00:29:58,280 So that becomes a new IDX. 389 00:29:58,280 --> 00:30:00,500 So the job of generate is to take a B by T 390 00:30:00,500 --> 00:30:03,720 and make it a B by T plus one, plus two, plus three, 391 00:30:03,720 --> 00:30:05,820 as many as we want max new tokens. 392 00:30:05,820 --> 00:30:08,320 So this is the generation from the model. 393 00:30:08,320 --> 00:30:10,900 Now inside the generation, what are we doing? 394 00:30:10,900 --> 00:30:12,740 We're taking the current indices. 395 00:30:12,740 --> 00:30:14,420 We're getting the predictions. 396 00:30:14,420 --> 00:30:18,080 So we get those are in the logits. 397 00:30:18,080 --> 00:30:20,020 And then the loss here is going to be ignored 398 00:30:20,020 --> 00:30:22,240 because we're not using that. 399 00:30:22,240 --> 00:30:25,460 And we have no targets that are sort of ground truth targets 400 00:30:25,460 --> 00:30:27,360 that we're going to be comparing with. 
401 00:30:28,620 --> 00:30:30,020 Then once we get the logits, 402 00:30:30,020 --> 00:30:32,600 we are only focusing on the last step. 403 00:30:32,600 --> 00:30:34,940 So instead of a B by T by C, 404 00:30:34,940 --> 00:30:37,640 we're going to pluck out the negative one, 405 00:30:37,640 --> 00:30:40,160 the last element in the time dimension, 406 00:30:40,160 --> 00:30:42,680 because those are the predictions for what comes next. 407 00:30:42,680 --> 00:30:44,100 So that gives us the logits, 408 00:30:44,100 --> 00:30:47,600 which we then convert to probabilities via softmax. 409 00:30:47,600 --> 00:30:48,960 And then we use torch.multinomial 410 00:30:48,960 --> 00:30:50,760 to sample from those probabilities. 411 00:30:50,760 --> 00:30:53,900 And we ask PyTorch to give us one sample. 412 00:30:53,900 --> 00:30:55,000 And so IDX next 413 00:30:55,000 --> 00:30:56,980 will become a B by one, 414 00:30:56,980 --> 00:31:00,100 because in each one of the batch dimensions, 415 00:31:00,100 --> 00:31:02,420 we're going to have a single prediction for what comes next. 416 00:31:02,420 --> 00:31:04,500 So this num_samples equals one 417 00:31:04,500 --> 00:31:06,600 will make this be a one. 418 00:31:06,600 --> 00:31:09,100 And then we're going to take those integers 419 00:31:09,100 --> 00:31:10,820 that come from the sampling process 420 00:31:10,820 --> 00:31:13,340 according to the probability distribution given here. 421 00:31:13,340 --> 00:31:15,340 And those integers get just concatenated 422 00:31:15,340 --> 00:31:19,020 on top of the current sort of like running stream of integers. 423 00:31:19,020 --> 00:31:21,620 And this gives us a B by T plus one. 424 00:31:21,620 --> 00:31:23,160 And then we can return that.
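The sampling loop just walked through can be sketched on its own. In the video, `generate` is a method on the model; here a fixed random logit table stands in for the model's forward pass, which is an assumption of this sketch:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(1337)
vocab_size = 65

# stand-in for a bigram model's forward: a fixed random logit table
table = torch.randn(vocab_size, vocab_size)
def forward(idx):
    return table[idx], None  # (B, T, C) logits, no loss

def generate(idx, max_new_tokens):
    for _ in range(max_new_tokens):
        logits, _ = forward(idx)
        logits = logits[:, -1, :]                           # focus on the last time step: (B, C)
        probs = F.softmax(logits, dim=-1)                   # convert scores to probabilities
        idx_next = torch.multinomial(probs, num_samples=1)  # sample one token per batch row: (B, 1)
        idx = torch.cat((idx, idx_next), dim=1)             # append along time: (B, T+1)
    return idx

idx = torch.zeros((1, 1), dtype=torch.long)  # kick off with token 0
out = generate(idx, max_new_tokens=10)
print(out.shape)  # (1, 11)
```

Each iteration grows the context by exactly one token, which is why B by T becomes B by T plus one, plus two, and so on.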
425 00:31:23,160 --> 00:31:24,160 Now, one thing here is, 426 00:31:24,160 --> 00:31:30,320 here you see how i'm calling self of idx, which will end up going to the forward function. 427 00:31:30,960 --> 00:31:36,560 i'm not providing any targets, so currently this would give an error, because targets is 428 00:31:36,560 --> 00:31:42,240 sort of like not given. so targets has to be optional, so targets is none by default, and 429 00:31:42,240 --> 00:31:49,200 then if targets is none, then there's no loss to create, so it's just loss is none, but else 430 00:31:49,200 --> 00:31:56,320 all of this happens and we can create a loss. so this will make it so that if we have the targets, 431 00:31:56,320 --> 00:32:02,160 we provide them and get a loss; if we have no targets, we'll just get the logits. so this here 432 00:32:02,160 --> 00:32:11,760 will generate from the model, and let's take that for a ride now. oops. so i have another code chunk 433 00:32:11,760 --> 00:32:16,640 here which will generate from the model, and okay, this is kind of crazy, so maybe let 434 00:32:16,640 --> 00:32:18,400 me break this down. 435 00:32:19,200 --> 00:32:20,960 so these are the idx, right? 436 00:32:24,720 --> 00:32:28,720 i'm creating a batch: batch will be just one, time will be just one. 437 00:32:29,600 --> 00:32:35,520 so i'm creating a little one by one tensor, and it's holding a zero, and the dtype, the data type, 438 00:32:35,520 --> 00:32:42,720 is integer. so zero is going to be how we kick off the generation, and remember that zero 439 00:32:42,720 --> 00:32:47,360 is the element standing for a new line character, so it's kind of like a reasonable thing to 440 00:32:47,360 --> 00:32:48,960 feed in as the very first character of the 441 00:32:49,200 --> 00:32:56,000 sequence: the new line. so it's going to be idx, which we're going to feed in here, 442 00:32:56,000 --> 00:33:00,800 then we're going to ask for 100 tokens, and then generate will continue that. 443
00:33:01,680 --> 00:33:07,840 now, because generate works on the level of batches, we then have to index into the 444 00:33:07,840 --> 00:33:16,800 zeroth row to basically unplug the single batch dimension that exists, and then that gives us 445 00:33:16,800 --> 00:33:17,600 the 446 00:33:17,600 --> 00:33:24,320 time steps, just a one-dimensional array of all the indices, which we will convert to a simple python 447 00:33:24,320 --> 00:33:32,720 list from a pytorch tensor, so that that can feed into our decode function and convert those integers 448 00:33:32,720 --> 00:33:40,080 into text. so let me bring this back, and we're generating 100 tokens. let's run, and here's 449 00:33:40,080 --> 00:33:44,560 the generation that we achieved. so obviously it's garbage, and the reason it's garbage is because 450 00:33:44,560 --> 00:33:47,120 this is a totally random model. so next up, 451 00:33:47,120 --> 00:33:50,880 we're going to want to train this model. now one more thing i wanted to point out here is 452 00:33:52,080 --> 00:33:56,400 this function is written to be general, but it's kind of like ridiculous right now, because 453 00:33:57,920 --> 00:34:02,720 we're feeding in all this, we're building out this context, and we're concatenating it all, 454 00:34:02,720 --> 00:34:08,480 and we're always feeding it all into the model. but that's kind of ridiculous, because this is 455 00:34:08,480 --> 00:34:14,240 just a simple bigram model. so to make, for example, this prediction about k, we only needed this w, 456 00:34:14,240 --> 00:34:17,040 but actually what we fed into the model is we fed the entire sequence, 457 00:34:17,520 --> 00:34:23,440 and then we only looked at the very last piece and predicted k. so the only reason i'm writing 458 00:34:23,440 --> 00:34:28,320 it in this way is because right now this is a bigram model, but i'd like to keep this function 459 00:34:28,320 --> 00:34:36,000 fixed, and i'd like it to work
later, when our characters actually basically look further in 460 00:34:36,000 --> 00:34:41,040 the history. and so right now the history is not used, so this looks silly, but eventually 461 00:34:41,040 --> 00:34:46,720 the history will be used, and so that's why we want to do it this way. so just a quick comment on that. 462 00:34:47,520 --> 00:34:53,600 so now we see that this is random, so let's train the model so it becomes a bit less random. okay, 463 00:34:53,600 --> 00:34:58,480 let's now train the model. so first what i'm going to do is i'm going to create a pytorch optimization 464 00:34:58,480 --> 00:35:06,240 object. so here we are using the optimizer AdamW. now in the makemore series we've only ever used 465 00:35:06,240 --> 00:35:11,440 stochastic gradient descent, the simplest possible optimizer, which you can get using SGD instead, 466 00:35:11,440 --> 00:35:15,440 but i want to use Adam, which is a much more advanced and popular optimizer, and it works 467 00:35:15,440 --> 00:35:16,000 extremely well. 468 00:35:17,680 --> 00:35:23,040 a typical good setting for the learning rate is roughly 3e-4, but for very, very small 469 00:35:23,040 --> 00:35:26,720 networks, like is the case here, you can get away with much, much higher learning rates, 470 00:35:26,720 --> 00:35:33,280 1e-3 or even higher probably. but let me create the optimizer object, which will basically take 471 00:35:33,280 --> 00:35:40,480 the gradients and update the parameters using the gradients. and then here, our batch size up above 472 00:35:40,480 --> 00:35:45,360 was only 4, so let me actually use something bigger, let's say 32, and then for some number of steps 473 00:35:47,120 --> 00:35:53,600 we're sampling a new batch of data, we're evaluating the loss, we're zeroing out all the gradients from 474 00:35:53,600 --> 00:35:58,240 the previous step, getting the
gradients for all the parameters, and then using those gradients to 475 00:35:58,240 --> 00:36:04,240 update our parameters. so, typical training loop, as we saw in the makemore series. so let me now 476 00:36:04,240 --> 00:36:09,200 run this for, say, 100 iterations, and let's see what kind of losses we're going to get 477 00:36:11,440 --> 00:36:16,720 so we started around 4.7 and now we're getting down to like 4.6 478 00:36:17,040 --> 00:36:23,680 so the optimization is definitely happening, but let's sort of try to increase the number 479 00:36:23,680 --> 00:36:28,320 of iterations and only print at the end, because we probably want to train for longer 480 00:36:30,240 --> 00:36:32,240 okay, so we're down to 3.6 roughly 481 00:36:35,520 --> 00:36:36,480 roughly down to three 482 00:36:41,440 --> 00:36:43,040 this is the most janky optimization 484 00:36:54,320 --> 00:36:56,080 okay, it's working, let's just do 10,000 485 00:36:57,520 --> 00:37:00,640 and then from here we want to copy this 486 00:37:01,760 --> 00:37:04,640 and hopefully we're going to get something reasonable, and of course it's not going to 487 00:37:04,640 --> 00:37:08,320 be shakespeare from a bigram model, but at least we see that the loss is improving 488 00:37:08,880 --> 00:37:11,680 and hopefully we're expecting something a bit more reasonable 489 00:37:12,960 --> 00:37:15,600 so we're down at about 2.5-ish, let's see what we get 490 00:37:15,600 --> 00:37:16,240 okay 491 00:37:16,240 --> 00:37:18,220 Let me just increase the number of tokens. 492 00:37:19,160 --> 00:37:23,740 Okay, so we see that we're starting to get something at least like reasonable-ish. 493 00:37:26,540 --> 00:37:30,280 Certainly not Shakespeare, but the model is making progress. 494 00:37:30,640 --> 00:37:32,640 So that is the simplest possible model.
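The training loop just assembled can be sketched end to end. Here random stand-in data replaces Tiny Shakespeare, so the loss will not drop toward 2.5 the way it does in the video; the structure of the loop is the point:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(1337)
vocab_size = 65
data = torch.randint(0, vocab_size, (1000,))  # placeholder for encoded text
block_size, batch_size = 8, 32

def get_batch():
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

model = nn.Embedding(vocab_size, vocab_size)  # the bigram table is the whole model

# AdamW; 1e-3 is fine for a network this tiny
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

for step in range(100):
    xb, yb = get_batch()
    logits = model(xb)                        # (B, T, C)
    B, T, C = logits.shape
    loss = F.cross_entropy(logits.view(B*T, C), yb.view(B*T))
    optimizer.zero_grad(set_to_none=True)     # clear gradients from the previous step
    loss.backward()                           # populate fresh gradients
    optimizer.step()                          # use the gradients to update the parameters
```

The three optimizer calls per iteration (zero, backward, step) are the same pattern for every PyTorch model, no matter how large.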
495 00:37:33,860 --> 00:37:41,360 So now what I'd like to do is, obviously, this is a very simple model, because the tokens are not talking to each other. 496 00:37:41,360 --> 00:37:48,260 So given the previous context of whatever was generated, we're only looking at the very last character to make the predictions about what comes next. 497 00:37:48,880 --> 00:37:57,140 So now these tokens have to start talking to each other and figuring out what is in the context, so that they can make better predictions for what comes next. 498 00:37:57,520 --> 00:37:59,840 And this is how we're going to kick off the transformer. 499 00:38:00,500 --> 00:38:04,880 Okay, so next, I took the code that we developed in this Jupyter notebook and I converted it to be a script. 500 00:38:05,340 --> 00:38:11,320 And I'm doing this because I just want to simplify our intermediate work into just the final product that we have 501 00:38:11,360 --> 00:38:16,520 at this point. So at the top here, I put all the hyperparameters that we've defined. 502 00:38:16,760 --> 00:38:19,520 I introduced a few, and I'm going to speak to that in a little bit. 503 00:38:20,120 --> 00:38:34,880 Otherwise, a lot of this should be recognizable: reproducibility, read data, get the encoder and decoder, create the train and test splits, use the kind of like data loader that gets a batch of the inputs and targets. 504 00:38:35,840 --> 00:38:37,960 This is new, and I'll talk about it in a second. 505 00:38:39,020 --> 00:38:41,000 Now, this is the bigram language model that we developed. 506 00:38:41,720 --> 00:38:44,900 And it can forward and give us the logits and loss, and it can generate. 507 00:38:46,800 --> 00:38:49,980 And then here we are creating the optimizer, and this is the training loop. 508 00:38:51,960 --> 00:38:54,040 So everything here should look pretty familiar. 509 00:38:54,160 --> 00:38:56,080 Now, some of the small things that I added.
510 00:38:56,200 --> 00:39:00,280 Number one, I added the ability to run on a GPU if you have it. 511 00:39:00,760 --> 00:39:06,600 So if you have a GPU, then this will use CUDA instead of just CPU, and everything will be a lot faster. 512 00:39:07,220 --> 00:39:11,340 Now, when device becomes CUDA, then we need to make sure that when we load the data, 513 00:39:11,460 --> 00:39:12,960 we move it to device. 514 00:39:13,960 --> 00:39:18,460 When we create the model, we want to move the model parameters to device. 515 00:39:18,960 --> 00:39:26,960 So as an example, here we have the nn.Embedding table, and it's got a .weight inside it, which stores the sort of lookup table. 516 00:39:27,160 --> 00:39:33,160 So that would be moved to the GPU, so that all the calculations here happen on the GPU and they can be a lot faster. 517 00:39:33,960 --> 00:39:39,460 And then finally here, when I'm creating the context that feeds into generate, I have to make sure that I create it on the device. 518 00:39:40,360 --> 00:39:41,260 Number two that I 519 00:39:41,460 --> 00:39:45,960 introduced is the fact that here in the training loop, 520 00:39:47,660 --> 00:39:52,960 I was just printing loss.item() inside the training loop. 521 00:39:53,160 --> 00:39:57,960 But this is a very noisy measurement of the current loss, because every batch will be more or less lucky. 522 00:39:58,660 --> 00:40:11,160 And so what I want to do usually is I have an estimate loss function, and the estimate loss basically then goes up here, and it averages up 523 00:40:11,160 --> 00:40:13,060 the loss over multiple batches. 524 00:40:13,560 --> 00:40:22,160 So in particular, we're going to iterate eval_iters times, and we're going to basically get our loss, and then we're going to get the average loss for both splits. 525 00:40:22,560 --> 00:40:24,160 And so this will be a lot less noisy.
526 00:40:25,060 --> 00:40:30,760 So here when we call the estimate loss, we're going to report the pretty accurate train and validation loss. 527 00:40:31,960 --> 00:40:34,560 Now when we come back up, you'll notice a few things here. 528 00:40:34,760 --> 00:40:38,260 I'm setting the model to evaluation phase, and down here 529 00:40:38,260 --> 00:40:40,260 I'm resetting it back to training phase. 530 00:40:40,260 --> 00:40:56,060 Now right now, for our model as is, this doesn't actually do anything, because the only thing inside this model is this nn.Embedding, and this network would behave the same in both evaluation mode and training mode. 531 00:40:56,460 --> 00:40:57,660 We have no dropout layers. 532 00:40:57,660 --> 00:40:59,160 We have no batch norm layers, etc. 533 00:40:59,660 --> 00:41:08,460 But it is a good practice to think through what mode your neural network is in, because some layers will have different behavior at inference time or training time. 534 00:41:08,460 --> 00:41:19,360 And there's also this context manager, torch.no_grad, and this is just telling PyTorch that everything that happens inside this function, we will not call .backward on. 535 00:41:20,060 --> 00:41:28,160 And so PyTorch can be a lot more efficient with its memory use, because it doesn't have to store all the intermediate variables, because we're never going to call backward. 536 00:41:28,660 --> 00:41:31,560 And so it can be a lot more efficient in that way. 537 00:41:31,860 --> 00:41:36,560 So it's also a good practice to tell PyTorch when we don't intend to do backpropagation. 538 00:41:37,660 --> 00:41:38,160 So, 539 00:41:38,460 --> 00:41:38,960 right now, 540 00:41:38,960 --> 00:41:44,560 this script is about 120 lines of code, and that's kind of our starter code. 541 00:41:45,360 --> 00:41:48,360 I'm calling it bigram.py and I'm going to release it later.
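The less-noisy evaluation being described can be sketched like this. The stand-in data, model, and `get_batch` here are assumptions replacing the script's own; `eval_iters` and the eval/train toggling follow the video:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(1337)
vocab_size, eval_iters = 65, 20
splits = {
    'train': torch.randint(0, vocab_size, (1000,)),  # placeholder data
    'val':   torch.randint(0, vocab_size, (200,)),
}
model = nn.Embedding(vocab_size, vocab_size)

def get_batch(split, block_size=8, batch_size=4):
    data = splits[split]
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

@torch.no_grad()  # we never call .backward() here, so skip storing intermediates
def estimate_loss():
    out = {}
    model.eval()  # a no-op for this model, but good practice
    for split in ('train', 'val'):
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            xb, yb = get_batch(split)
            logits = model(xb)
            B, T, C = logits.shape
            losses[k] = F.cross_entropy(logits.view(B*T, C), yb.view(B*T))
        out[split] = losses.mean().item()  # average over many batches: less noisy
    model.train()  # reset back to training phase
    return out

losses = estimate_loss()
```

Averaging over `eval_iters` batches is what smooths out the "lucky batch" noise of printing a single batch's loss.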
542 00:41:48,960 --> 00:41:53,860 Now running this script gives us output in the terminal, and it looks something like this. 543 00:41:54,860 --> 00:41:59,560 Basically, as I ran this code, it was giving me the train loss and val loss. 544 00:41:59,760 --> 00:42:03,960 And we see that we converge to somewhere around 2.5 with the bigram model. 545 00:42:04,460 --> 00:42:06,960 And then here's the sample that we produced at the end. 546 00:42:08,460 --> 00:42:13,060 And so we have everything packaged up in the script, and we're in a good position now to iterate on this. 547 00:42:13,460 --> 00:42:13,660 Okay, 548 00:42:13,660 --> 00:42:20,460 so we are almost ready to start writing our very first self-attention block for processing these tokens. 549 00:42:21,160 --> 00:42:21,660 Now, 550 00:42:22,260 --> 00:42:23,360 before we actually get there, 551 00:42:23,560 --> 00:42:33,960 I want to get you used to a mathematical trick that is used in the self-attention inside a transformer and is really just like at the heart of an efficient implementation of self-attention. 552 00:42:34,660 --> 00:42:38,160 And so I want to work with this toy example to just get you used to this operation. 553 00:42:38,460 --> 00:42:44,160 And then it's going to make it much more clear once we actually get to it in the script again. 554 00:42:45,460 --> 00:42:50,460 So let's create a B by T by C tensor, where B, T and C are just 4, 8 and 2 in this toy example. 555 00:42:51,260 --> 00:42:59,660 And these are basically channels, and we have batches, and we have the time component, and we have some information at each point in the sequence, 556 00:42:59,960 --> 00:43:00,560 so C. 557 00:43:02,060 --> 00:43:05,660 Now what we would like to do is we would like these tokens, 558 00:43:05,760 --> 00:43:08,360 so we have up to eight tokens here in a batch, 559 00:43:08,660 --> 00:43:12,660 and these eight tokens are currently not talking to each other, and we would like them to talk to each other.
560 00:43:12,660 --> 00:43:13,660 We'd like to couple them. 561 00:43:14,860 --> 00:43:16,260 And in particular, 562 00:43:16,660 --> 00:43:19,460 we want to couple them in this very specific way. 563 00:43:19,960 --> 00:43:20,960 So the token, 564 00:43:20,960 --> 00:43:21,360 for example, 565 00:43:21,360 --> 00:43:22,560 at the fifth location, 566 00:43:23,060 --> 00:43:30,460 should not communicate with tokens in the sixth, seventh, and eighth locations, because those are future tokens in the sequence. 567 00:43:31,060 --> 00:43:35,560 The token at the fifth location should only talk to the ones in the fourth, third, second, and first. 568 00:43:36,060 --> 00:43:38,460 So information only flows 569 00:43:38,560 --> 00:43:45,260 from previous context to the current time step, and we cannot get any information from the future, because we are about to try to predict the future. 570 00:43:46,460 --> 00:43:50,560 So what is the easiest way for tokens to communicate? 571 00:43:50,960 --> 00:43:53,860 Okay, the easiest way I would say is this: 572 00:43:53,860 --> 00:44:05,860 if I'm the fifth token and I'd like to communicate with my past, the simplest way we can do that is to just do an average of all the preceding elements. 573 00:44:06,160 --> 00:44:06,760 So for example, 574 00:44:06,760 --> 00:44:07,660 if I'm the fifth token, 575 00:44:07,760 --> 00:44:13,260 I would like to take the channels that are the information at my step, 576 00:44:13,660 --> 00:44:17,460 but then also the channels from the fourth step, third step, second step, and first step. 577 00:44:17,660 --> 00:44:24,960 I'd like to average those up, and then that would become sort of like a feature vector that summarizes me in the context of my history.
578 00:44:25,660 --> 00:44:25,860 Now, 579 00:44:25,860 --> 00:44:26,160 of course, 580 00:44:26,160 --> 00:44:30,260 just doing a sum, or like an average, is an extremely weak form of interaction. 581 00:44:30,260 --> 00:44:32,460 Like, this communication is extremely lossy. 582 00:44:32,660 --> 00:44:36,160 We've lost a ton of information about the spatial arrangements of all those tokens, 583 00:44:36,960 --> 00:44:37,560 but that's okay. 584 00:44:37,660 --> 00:44:38,160 For now, 585 00:44:38,160 --> 00:44:40,260 we'll see how we can bring that information back later. 586 00:44:41,060 --> 00:44:41,360 For now, 587 00:44:41,360 --> 00:44:48,760 what we would like to do is, for every single batch element independently, for every t-th token in that sequence, 588 00:44:49,160 --> 00:44:56,560 we'd like to now calculate the average of all the vectors in all the previous tokens and also at this token. 589 00:44:57,460 --> 00:44:58,460 So let's write that out. 590 00:44:59,960 --> 00:45:02,960 I have a small snippet here, and instead of just fumbling around, 591 00:45:03,560 --> 00:45:05,160 let me just copy paste it and talk through it. 592 00:45:06,560 --> 00:45:07,460 So in other words, 593 00:45:08,160 --> 00:45:18,960 we're going to create xbow, and bow is short for bag of words, because bag of words is kind of like a term that people use when you are just averaging up things. 594 00:45:18,960 --> 00:45:20,360 So this is just a bag of words. 595 00:45:20,660 --> 00:45:20,960 Basically, 596 00:45:20,960 --> 00:45:26,260 there's a word stored on every one of these eight locations, and we're doing a bag of words, just averaging. 597 00:45:27,460 --> 00:45:28,260 So in the beginning, 598 00:45:28,260 --> 00:45:31,960 we're going to say that it's just initialized at zero, and then I'm doing a for loop here. 599 00:45:31,960 --> 00:45:33,260 So we're not being efficient yet. 600 00:45:33,260 --> 00:45:33,960 That's coming.
601 00:45:34,560 --> 00:45:34,960 But for now, 602 00:45:34,960 --> 00:45:37,460 we're just iterating over all the batch dimensions independently, 603 00:45:38,060 --> 00:45:39,460 iterating over time, 604 00:45:40,160 --> 00:45:45,860 and then the previous tokens are at this batch dimension, 605 00:45:46,360 --> 00:45:49,360 and then everything up to and including the t-th token. 606 00:45:49,860 --> 00:45:50,260 Okay. 607 00:45:50,960 --> 00:45:53,060 So when we slice out x in this way, 608 00:45:53,560 --> 00:45:55,660 xprev becomes of shape: 609 00:45:56,860 --> 00:46:00,960 however many t elements there were in the past, and then of course C, 610 00:46:00,960 --> 00:46:04,260 so all the two-dimensional information from these little tokens. 611 00:46:05,260 --> 00:46:07,460 So that's the previous sort of chunk of 612 00:46:07,560 --> 00:46:11,260 tokens from my current sequence. 613 00:46:11,960 --> 00:46:15,560 And then I'm just doing the average, or the mean, over the zeroth dimension. 614 00:46:15,560 --> 00:46:21,660 So I'm averaging out the time here, and I'm just going to get a little C, a one-dimensional vector, 615 00:46:21,660 --> 00:46:24,360 which I'm going to store in x bag of words. 616 00:46:25,160 --> 00:46:31,560 So I can run this, and this is not going to be very informative because, let's see — 617 00:46:31,560 --> 00:46:32,560 so this is x[0], 618 00:46:32,560 --> 00:46:36,660 so this is the zeroth batch element, and then xbow[0]. 619 00:46:37,160 --> 00:46:37,660 Now, 620 00:46:38,560 --> 00:46:39,460 you see how 621 00:46:39,960 --> 00:46:41,260 at the first location here, 622 00:46:41,660 --> 00:46:43,360 you see that the two are equal, 623 00:46:43,860 --> 00:46:46,860 and that's because we're just doing an average of this one token. 624 00:46:47,660 --> 00:46:50,660 But here, this one is now an average of these two. 625 00:46:51,860 --> 00:46:54,860 And now this one is an average of these three. 
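The loop just walked through can be sketched like this (a minimal reconstruction; the shapes B=4, T=8, C=2 match the toy setup used here):

```python
import torch

torch.manual_seed(1337)
B, T, C = 4, 8, 2  # batch, time, channels
x = torch.randn(B, T, C)

# xbow[b, t] holds the average of x[b, 0..t]:
# each position mixes itself with all the preceding tokens
xbow = torch.zeros((B, T, C))
for b in range(B):        # iterate over batch elements independently
    for t in range(T):    # iterate over time
        xprev = x[b, :t + 1]               # (t+1, C) — everything up to and including t
        xbow[b, t] = torch.mean(xprev, 0)  # average over time -> (C,)
```

At t=0 the "average" is just the token itself, which is why the first rows of x[0] and xbow[0] come out equal.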
626 00:46:55,960 --> 00:46:56,560 And so on. 627 00:46:57,860 --> 00:47:02,460 And this last one is the average of all of these elements. 628 00:47:02,460 --> 00:47:06,360 So a vertical average, just averaging up all the tokens, now gives this outcome 629 00:47:06,660 --> 00:47:07,160 here. 630 00:47:08,360 --> 00:47:09,760 So this is all well and good, 631 00:47:10,160 --> 00:47:11,560 but this is very inefficient. 632 00:47:11,960 --> 00:47:12,160 Now, 633 00:47:12,160 --> 00:47:16,660 the trick is that we can be very, very efficient about doing this using matrix multiplication. 634 00:47:17,360 --> 00:47:19,060 So that's the mathematical trick. 635 00:47:19,060 --> 00:47:20,160 And let me show you what I mean. 636 00:47:20,560 --> 00:47:22,160 Let's work with a toy example here. 637 00:47:23,060 --> 00:47:24,260 You run it and I'll explain. 638 00:47:25,460 --> 00:47:27,560 I have a simple matrix a here 639 00:47:27,560 --> 00:47:34,160 that is three by three, of all ones; a matrix b of just random numbers, which is three by two; and a matrix c, 640 00:47:34,160 --> 00:47:36,560 which will be three by three multiplied by three by two, 641 00:47:36,860 --> 00:47:38,560 which will give out a three by two. 642 00:47:39,460 --> 00:47:40,560 So here we're just using 643 00:47:41,860 --> 00:47:42,860 matrix multiplication. 644 00:47:43,460 --> 00:47:45,260 So a multiplied by b gives us c. 645 00:47:47,060 --> 00:47:51,860 Okay, so how are these numbers in c achieved? 646 00:47:51,860 --> 00:47:52,160 Right? 647 00:47:52,160 --> 00:47:59,460 So this number in the top left is the first row of a dot product with the first column of b. 648 00:48:00,160 --> 00:48:06,460 And since the row of a right now is all just ones, the dot product here with this column of 649 00:48:06,460 --> 00:48:06,560 b 650 00:48:06,860 --> 00:48:09,860 is just going to do a sum of this column. 651 00:48:10,060 --> 00:48:12,460 So two plus six plus six is 14. 
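This toy example can be reproduced directly (a sketch; the entries of b are chosen to match the numbers read out here, two plus six plus six giving 14):

```python
import torch

a = torch.ones(3, 3)
b = torch.tensor([[2., 7.],
                  [6., 4.],
                  [6., 5.]])
c = a @ b  # every row of a is all ones, so every row of c is the column sums of b
# c is [[14., 16.], [14., 16.], [14., 16.]]
```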
652 00:48:13,460 --> 00:48:17,060 This element here in the output c is 653 00:48:17,060 --> 00:48:21,060 the first row of a multiplied now with the second column of b. 654 00:48:21,460 --> 00:48:23,860 So seven plus four plus five is 16. 655 00:48:24,760 --> 00:48:26,260 Now you see that there are repeating elements here. 656 00:48:26,260 --> 00:48:31,460 So this 14 again is because this row is again all ones, and it's multiplying the first column of b. 657 00:48:31,460 --> 00:48:34,760 So we get 14, and so on. 658 00:48:34,760 --> 00:48:36,560 So this last number here is the 659 00:48:36,660 --> 00:48:39,360 last row dot product with the last column. 660 00:48:40,460 --> 00:48:42,660 Now, the trick here is the following. 661 00:48:43,360 --> 00:48:53,860 This is just a boring array of all ones, but torch has this function called tril, which is short for triangular, 662 00:48:55,360 --> 00:49:01,860 something like that, and you can wrap torch.ones in it, and it will just return the lower triangular portion of it. 663 00:49:02,660 --> 00:49:03,060 Okay. 664 00:49:04,760 --> 00:49:06,460 So now it will basically zero out 665 00:49:06,660 --> 00:49:07,460 these elements here. 666 00:49:07,560 --> 00:49:09,460 So we just get the lower triangular part. 667 00:49:09,760 --> 00:49:11,960 Well, what happens if we do that? 668 00:49:15,160 --> 00:49:18,160 So now we'll have a like this and b like this. 669 00:49:18,160 --> 00:49:19,660 And now what are we getting here in c? 670 00:49:20,360 --> 00:49:21,560 Well, what is this number? 671 00:49:21,860 --> 00:49:26,860 Well, this is the first row times the first column, and because these are zeros, 672 00:49:28,960 --> 00:49:30,660 these elements here are now ignored. 673 00:49:30,660 --> 00:49:36,160 So we just get a two, and then this number here is the first row times the second column. 
674 00:49:36,660 --> 00:49:39,660 And because these are zeros, they get ignored, and it's just seven. 675 00:49:40,160 --> 00:49:41,460 The one multiplies just this seven. 676 00:49:42,460 --> 00:49:45,260 But look what happened here: because this is a one and then zeros, 677 00:49:45,560 --> 00:49:50,860 what ended up happening is we're just plucking out this row of b, and that's what we got. 678 00:49:52,160 --> 00:49:54,760 Now here we have one, one, zero. 679 00:49:55,360 --> 00:50:02,660 So one, one, zero dot product with these two columns will now give us two plus six, which is eight, and seven plus four, which is 11. 680 00:50:03,360 --> 00:50:06,160 And because this is one, one, one, we ended up with 681 00:50:06,660 --> 00:50:07,860 the addition of all of them. 682 00:50:08,860 --> 00:50:11,860 And so basically, depending on how many ones and zeros we have here, 683 00:50:12,260 --> 00:50:20,260 we are basically doing a sum of a variable number of these rows, and that gets deposited into c. 684 00:50:21,760 --> 00:50:32,760 So currently we're doing sums because these are ones, but we can also do an average, right? And you can start to see how we could do an average of the rows of b in sort of an incremental fashion. 685 00:50:33,560 --> 00:50:36,360 Because we can basically normalize 686 00:50:36,360 --> 00:50:39,660 these rows so that they sum to one, and then we're going to get an average. 687 00:50:40,360 --> 00:50:47,060 So if we took a, and then we did a = a / torch.sum(a, 1, keepdim=True) — 688 00:50:48,660 --> 00:50:50,460 that's the sum of a in the 689 00:50:51,360 --> 00:50:51,760 first 690 00:50:52,860 --> 00:50:55,860 dimension, with keepdim being true, 691 00:50:56,460 --> 00:50:58,260 so that the broadcasting will work out. 692 00:50:58,860 --> 00:51:03,460 So if I rerun this, you see now that these rows sum to one. 693 00:51:03,760 --> 00:51:06,160 So this row is one; this row is 0.5, 0.5, 0. 
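With torch.tril and the row normalization just described, the same toy matrices now produce running averages (a sketch, reusing the b from the earlier example):

```python
import torch

a = torch.tril(torch.ones(3, 3))       # lower triangular ones
a = a / torch.sum(a, 1, keepdim=True)  # each row now sums to one -> averaging weights
b = torch.tensor([[2., 7.],
                  [6., 4.],
                  [6., 5.]])
c = a @ b  # row i of c is the average of rows 0..i of b
# row 0: [2.0, 7.0]; row 1: [4.0, 5.5]; row 2: the mean of all three rows
```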
694 00:51:06,860 --> 00:51:08,360 And here we get one-thirds. 695 00:51:08,960 --> 00:51:11,760 And now when we do a multiplied by b, what are we getting? 696 00:51:12,560 --> 00:51:14,860 Here we are just getting the first row of b. 697 00:51:15,760 --> 00:51:19,160 Here, now we are getting the average of the first two rows. 698 00:51:21,060 --> 00:51:25,160 Okay, so two and six average to four, and four and seven average to 5.5. 699 00:51:26,060 --> 00:51:30,760 And on the bottom here, we are now getting the average of these three rows. 700 00:51:31,560 --> 00:51:35,560 So the averages of all the elements of b are now deposited here. 701 00:51:36,460 --> 00:51:45,060 And so you can see that by manipulating the elements of this multiplying matrix, and then multiplying it with any given matrix, 702 00:51:45,360 --> 00:51:49,860 we can do these averages in this incremental fashion, 703 00:51:51,460 --> 00:51:54,060 and we can manipulate that based on the elements of a. 704 00:51:54,660 --> 00:52:01,460 Okay, so that's very convenient. So let's swing back up here and see how we can vectorize this and make it much more efficient using what we've learned. 705 00:52:02,360 --> 00:52:03,060 So in particular, 706 00:52:04,660 --> 00:52:06,160 we are going to produce an array, 707 00:52:06,860 --> 00:52:09,360 but here I'm going to call it wei, short for weights. 708 00:52:10,160 --> 00:52:11,060 But this is our a. 709 00:52:12,660 --> 00:52:19,460 And this is how much of every row we want to average up, and it's going to be an average because you can see that these rows sum to one. 710 00:52:21,060 --> 00:52:25,360 So this is our a, and then our b in this example of course is 711 00:52:26,160 --> 00:52:26,560 x. 712 00:52:27,860 --> 00:52:30,960 So what's going to happen here now is that we are going to have an xbow2, 713 00:52:32,760 --> 00:52:35,660 and this xbow2 is going to be wei 714 00:52:36,460 --> 00:52:37,160 multiplying 
715 00:52:38,060 --> 00:52:38,560 our x. 716 00:52:39,860 --> 00:52:47,160 So let's think this through: wei is T by T, and this is matrix-multiplying, in PyTorch, a B by T by C. 717 00:52:48,660 --> 00:52:49,460 And it's giving us 718 00:52:51,060 --> 00:52:51,560 what shape? 719 00:52:52,160 --> 00:53:00,260 So PyTorch will come here, and it will see that these shapes are not the same, so it will create a batch dimension here, and this is a batched matrix multiply. 720 00:53:01,460 --> 00:53:04,860 And so it will apply this matrix multiplication to all the batch elements 721 00:53:05,460 --> 00:53:05,960 in parallel 722 00:53:06,460 --> 00:53:14,160 and individually, and then for each batch element there will be a T by T multiplying T by C, exactly as we had below. 723 00:53:16,660 --> 00:53:17,860 So this will now create 724 00:53:18,760 --> 00:53:20,060 B by T by C, 725 00:53:21,360 --> 00:53:25,160 and xbow2 will now become identical to xbow. 726 00:53:26,360 --> 00:53:26,960 So 727 00:53:28,960 --> 00:53:30,960 we can see that torch.allclose 728 00:53:31,860 --> 00:53:35,560 of xbow and xbow2 should be true now. 729 00:53:37,360 --> 00:53:41,960 So this kind of reassures us that these are in fact the same. 730 00:53:42,760 --> 00:53:45,960 So xbow and xbow2, if I just print them — 731 00:53:48,260 --> 00:53:49,960 okay, we're not going to be able to — 732 00:53:50,460 --> 00:53:52,960 okay, we're not going to be able to just stare it down, but 733 00:53:55,160 --> 00:53:58,960 well, let me try xbow basically just at the zeroth element and xbow2 at the zeroth element. 734 00:53:58,960 --> 00:54:04,160 So just the first batch, and we should see that this and that should be identical, which they are. 735 00:54:05,360 --> 00:54:05,760 Right. 736 00:54:05,860 --> 00:54:06,860 So what happened here? 
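Putting the two versions side by side, the vectorized form can be checked against the loop (a sketch with the same B, T, C as before):

```python
import torch

torch.manual_seed(1337)
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

# version 1: explicit loops
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xbow[b, t] = x[b, :t + 1].mean(0)

# version 2: batched matrix multiply with normalized triangular weights
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True)  # (T, T), rows sum to one
xbow2 = wei @ x  # (T, T) @ (B, T, C) -> (B, T, C): PyTorch broadcasts over the batch
```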
737 00:54:07,260 --> 00:54:20,560 The trick is we were able to use batched matrix multiply to do this aggregation, really — and it's a weighted aggregation, and the weights are specified in this T by T array. 738 00:54:21,460 --> 00:54:31,360 And we're basically doing weighted sums, and these weighted sums are according to the weights inside here, which take on sort of this triangular form. 739 00:54:32,160 --> 00:54:35,660 And so that means that a token at the t-th dimension will only get 740 00:54:35,860 --> 00:54:40,360 sort of information from the tokens preceding it. 741 00:54:40,660 --> 00:54:41,860 So that's exactly what we want. 742 00:54:42,260 --> 00:54:44,760 And finally, I would like to rewrite it in one more way. 743 00:54:45,460 --> 00:54:47,160 And we're going to see why that's useful. 744 00:54:48,060 --> 00:54:53,760 So this is the third version, and it's also identical to the first and second, but let me talk through it. 745 00:54:53,760 --> 00:54:54,860 It uses softmax. 746 00:54:55,660 --> 00:55:04,860 So tril here is this matrix of lower triangular ones; wei begins as all zeros. 747 00:55:05,860 --> 00:55:13,560 Okay, so if I just print wei in the beginning, it's all zeros. Then I use masked_fill. 748 00:55:14,260 --> 00:55:18,260 So what this is doing is: wei.masked_fill — wei is all zeros, 749 00:55:18,260 --> 00:55:24,760 and I'm saying, for all the elements where tril is == 0, make them be negative infinity. 750 00:55:25,460 --> 00:55:29,160 So all the elements where tril is zero will become negative infinity now. 751 00:55:30,260 --> 00:55:31,160 So this is what we get. 752 00:55:32,260 --> 00:55:34,960 And then the final line here is softmax. 753 00:55:36,760 --> 00:55:44,860 So if I take a softmax along every single row — so dim is negative one, so along every single row — what is that going to do? 754 00:55:47,060 --> 00:55:53,160 Well, softmax is also like a normalization operation, right? 
755 00:55:54,160 --> 00:55:56,860 And so, spoiler alert, you get the exact same matrix. 756 00:55:58,560 --> 00:56:02,160 Let me bring back the softmax, and recall that in softmax 757 00:56:02,160 --> 00:56:04,560 we're going to exponentiate every single one of these, 758 00:56:05,760 --> 00:56:07,460 and then we're going to divide by the sum. 759 00:56:08,260 --> 00:56:16,460 And so if we exponentiate every single element here, we're going to get a one, and here we're going to get basically zero, zero, zero, zero everywhere else. 760 00:56:17,060 --> 00:56:20,260 And then when we normalize, we just get one here. 761 00:56:20,260 --> 00:56:27,460 Here we're going to get one, one, and then zeros, and then softmax will again divide, and this will give us 0.5, 0.5, and so on. 762 00:56:28,160 --> 00:56:32,460 And so this is also the same way to produce this mask. 763 00:56:33,360 --> 00:56:35,260 Now, the reason that this is a bit more interesting, 764 00:56:35,360 --> 00:56:47,660 and the reason we're going to end up using it in self-attention, is that these weights here begin at zero, and you can think of this as like an interaction strength, or like an affinity. 765 00:56:48,260 --> 00:57:03,760 So basically it's telling us how much of each token from the past we want to aggregate and average up, and then this line is saying tokens from the future cannot communicate: by setting them to negative infinity, 766 00:57:03,860 --> 00:57:05,060 we're saying that we will not 767 00:57:05,260 --> 00:57:07,060 aggregate anything from those tokens. 768 00:57:08,360 --> 00:57:13,760 And so basically this then goes through softmax, and then the weighted aggregation happens through matrix multiplication. 769 00:57:15,060 --> 00:57:18,060 And so what this is now is — you can think of these as: 770 00:57:18,860 --> 00:57:28,960 the zeros are currently just set by us to be zero, but a quick preview is that these affinities between the tokens are not going to be just constant at zero. 
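This third, softmax-based version can be sketched as:

```python
import torch
import torch.nn.functional as F

T = 8
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T, T))  # affinities start at zero (later they become data dependent)
wei = wei.masked_fill(tril == 0, float('-inf'))  # future positions are clamped off
wei = F.softmax(wei, dim=-1)  # exponentiate and normalize every row
# row 0 is [1, 0, ...], row 1 is [0.5, 0.5, 0, ...]: the same averaging matrix as before
```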
771 00:57:29,260 --> 00:57:30,960 They're going to be data dependent. 772 00:57:31,260 --> 00:57:35,160 These tokens are going to start looking at each other, and some tokens will find other tokens 773 00:57:35,360 --> 00:57:44,160 more or less interesting, and depending on what their values are, they're going to find each other interesting to different amounts, and I'm going to call those affinities, 774 00:57:44,160 --> 00:57:48,460 I think. And then here we are saying the future cannot communicate with the past. 775 00:57:48,960 --> 00:57:50,060 We're going to clamp them. 776 00:57:51,160 --> 00:57:57,860 And then when we normalize and sum, we're going to aggregate sort of their values, depending on how interesting they find each other. 777 00:57:58,460 --> 00:58:04,760 And so that's the preview for self-attention. And basically, long story short from this entire section is that 778 00:58:05,360 --> 00:58:15,060 you can do weighted aggregations of your past elements by using matrix multiplication with a lower-triangular matrix. 779 00:58:15,760 --> 00:58:22,960 And then the elements here in the lower triangular part are telling you how much of each element fuses into this position. 780 00:58:23,660 --> 00:58:26,760 So we're going to use this trick now to develop the self-attention block. 781 00:58:27,260 --> 00:58:29,460 So first, let's get some quick preliminaries out of the way. 782 00:58:30,760 --> 00:58:34,960 First, the thing I'm kind of bothered by is that you see how we're passing in vocab_size into the constructor? 783 00:58:35,360 --> 00:58:39,560 You don't need to do that, because vocab_size is already defined up top as a global variable. 784 00:58:39,560 --> 00:58:41,360 So there's no need to pass this stuff around. 785 00:58:42,860 --> 00:58:46,060 Next, what I want to do is 786 00:58:46,360 --> 00:58:52,260 create a level of indirection here, where we don't directly go to the embedding for the logits 
787 00:58:52,460 --> 00:58:56,960 but instead we go through this intermediate phase, because we're going to start making that bigger. 788 00:58:57,560 --> 00:59:03,360 So let me introduce a new variable, n_embd, short for number of embedding dimensions. 789 00:59:03,960 --> 00:59:04,760 So n_embd 790 00:59:05,360 --> 00:59:08,360 here will be, say, 32. 791 00:59:09,060 --> 00:59:11,260 That was a suggestion from GitHub Copilot, by the way. 792 00:59:11,960 --> 00:59:14,160 It also suggested 32, which is a good number. 793 00:59:15,460 --> 00:59:18,860 So this is an embedding table with only 32-dimensional embeddings. 794 00:59:19,960 --> 00:59:22,660 So then here, this is not going to give us logits directly. 795 00:59:23,160 --> 00:59:25,360 Instead, this is going to give us token embeddings. 796 00:59:25,960 --> 00:59:26,860 That's what I'm going to call it. 797 00:59:27,260 --> 00:59:31,060 And then to go from the token embeddings to the logits, we're going to need a linear layer. 798 00:59:31,560 --> 00:59:35,160 So self.lm_head, let's call it, short for language modeling head, 799 00:59:35,960 --> 00:59:38,760 is an nn.Linear from n_embd up to vocab_size. 800 00:59:39,660 --> 00:59:40,860 And then we swing over here. 801 00:59:40,960 --> 00:59:44,660 We're actually going to get the logits by exactly what Copilot says. 802 00:59:45,760 --> 00:59:50,260 Now we have to be careful here, because this C and this C are not equal. 803 00:59:51,360 --> 00:59:53,760 This one is n_embd, C, and this one is vocab_size. 804 00:59:54,760 --> 00:59:57,260 So let's just say that n_embd is equal to C. 805 00:59:58,660 --> 01:00:02,860 And then this just creates one spurious layer of indirection through a linear layer. 806 01:00:03,360 --> 01:00:05,060 But this should basically run. 807 01:00:05,260 --> 01:00:16,360 So we see that this runs, and this currently looks kind of spurious. 808 01:00:16,360 --> 01:00:18,160 But we're going to build on top of this. 
809 01:00:18,860 --> 01:00:20,260 Now, next up: so far, 810 01:00:20,260 --> 01:00:27,460 we've taken these indices and we've encoded them based on the identity of the tokens inside idx. 811 01:00:28,160 --> 01:00:34,260 The next thing that people very often do is not just encode the identity of these tokens, but also their position. 812 01:00:34,760 --> 01:00:37,860 So we're going to have a second position embedding table here. 813 01:00:38,260 --> 01:00:43,460 So self.position_embedding_table is an embedding of block_size by n_embd. 814 01:00:43,960 --> 01:00:48,360 And so each position from zero to block_size minus one will also get its own embedding vector. 815 01:00:49,460 --> 01:00:53,760 And then here, first let me decode B, T from idx.shape. 816 01:00:55,360 --> 01:00:59,160 And then here we're also going to have pos_emb, which is the positional embedding, 817 01:00:59,260 --> 01:01:01,060 and this is torch.arange. 818 01:01:01,560 --> 01:01:04,060 So this will be basically just integers from zero 819 01:01:04,260 --> 01:01:05,160 to T minus one. 820 01:01:06,160 --> 01:01:11,060 And all of those integers from zero to T minus one get embedded through the table to create a T by C. 821 01:01:12,360 --> 01:01:20,760 And then here this gets renamed to just say x, and x will be the addition of the token embeddings with the positional embeddings. 822 01:01:21,960 --> 01:01:23,860 And here the broadcasting will work out. 823 01:01:23,860 --> 01:01:25,960 So B by T by C plus T by C: 824 01:01:26,460 --> 01:01:31,560 this gets right-aligned, a new dimension of one gets added, and it gets broadcast across batch. 825 01:01:32,760 --> 01:01:34,160 So at this point, x 826 01:01:34,260 --> 01:01:38,860 holds not just the token identities, but the positions at which these tokens occur. 
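The wiring just described can be sketched as follows (a minimal sketch; n_embd=32 and block_size=8 follow the values used here, while vocab_size=65 is an assumed placeholder for the character vocabulary):

```python
import torch
import torch.nn as nn

vocab_size, block_size, n_embd = 65, 8, 32  # vocab_size is an assumption here

token_embedding_table = nn.Embedding(vocab_size, n_embd)
position_embedding_table = nn.Embedding(block_size, n_embd)
lm_head = nn.Linear(n_embd, vocab_size)  # language modeling head

idx = torch.randint(0, vocab_size, (4, block_size))  # (B, T) batch of token indices
B, T = idx.shape
tok_emb = token_embedding_table(idx)                 # (B, T, n_embd)
pos_emb = position_embedding_table(torch.arange(T))  # (T, n_embd)
x = tok_emb + pos_emb   # broadcast: (B, T, C) + (T, C) -> (B, T, C)
logits = lm_head(x)     # (B, T, vocab_size)
```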
827 01:01:39,660 --> 01:01:43,360 And this is currently not that useful, because of course we just have a simple bigram model. 828 01:01:43,460 --> 01:01:49,160 So it doesn't matter if you're in the fifth position, the second position, or wherever — it's all translation invariant at this stage. 829 01:01:49,660 --> 01:01:51,460 So this information currently wouldn't help. 830 01:01:52,060 --> 01:01:55,660 But as we work on the self-attention block, we'll see that this starts to matter. 831 01:01:59,860 --> 01:02:01,960 Okay, so now we get to the crux of self-attention. 832 01:02:02,260 --> 01:02:04,160 So this is probably the most important part of this video 833 01:02:04,360 --> 01:02:05,360 to understand. 834 01:02:06,460 --> 01:02:10,560 We're going to implement a small self-attention for a single individual head, as they're called. 835 01:02:11,160 --> 01:02:12,960 So we start off with where we were. 836 01:02:13,160 --> 01:02:14,560 So all of this code is familiar. 837 01:02:15,360 --> 01:02:20,060 So right now I'm working with an example where I changed the number of channels from 2 to 32. 838 01:02:20,060 --> 01:02:28,160 So we have a 4 by 8 arrangement of tokens, and the information at each token is currently 32-dimensional. 839 01:02:28,260 --> 01:02:30,060 But we are just working with random numbers. 840 01:02:31,360 --> 01:02:32,960 Now, we saw here that 841 01:02:32,960 --> 01:02:42,260 the code as we had it before does a simple average of all the past tokens and the current token. 842 01:02:42,460 --> 01:02:46,660 So previous information and current information are just being mixed together in an average. 843 01:02:47,360 --> 01:02:49,260 And that's what this code currently achieves. 844 01:02:49,560 --> 01:02:57,260 And it does so by creating this lower triangular structure, which allows us to mask out this weight matrix that we create. 
845 01:02:57,760 --> 01:03:01,660 So we mask it out and then we normalize it. And currently, 846 01:03:01,660 --> 01:03:09,760 when we initialize the affinities between all the different sort of tokens, or nodes — I'm going to use those terms interchangeably — 847 01:03:10,660 --> 01:03:19,660 so when we initialize the affinities between all the different tokens to be zero, then we see that wei gives us this structure where every single row has these 848 01:03:21,160 --> 01:03:22,060 uniform numbers. 849 01:03:22,460 --> 01:03:28,660 And so that's what, in this matrix multiply, makes it so that we're doing a simple average. 850 01:03:29,660 --> 01:03:30,160 Now, 851 01:03:30,860 --> 01:03:31,560 we don't actually want 852 01:03:31,860 --> 01:03:32,660 this to be 853 01:03:33,360 --> 01:03:34,260 all uniform, 854 01:03:34,660 --> 01:03:41,260 because different tokens will find different other tokens more or less interesting, and we want that to be data dependent. 855 01:03:41,460 --> 01:03:50,060 So for example, if I'm a vowel, then maybe I'm looking for consonants in my past, and maybe I want to know what those consonants are, and I want that information to flow to me. 856 01:03:51,160 --> 01:03:55,960 And so I want to now gather information from the past, but I want to do it in a data-dependent way. 857 01:03:56,260 --> 01:03:58,160 And this is the problem that self-attention solves. 858 01:03:58,960 --> 01:04:00,860 Now, the way self-attention solves this 859 01:04:01,060 --> 01:04:01,560 is the following. 860 01:04:02,260 --> 01:04:07,760 Every single node, or every single token, at each position will emit two vectors. 861 01:04:08,460 --> 01:04:11,860 It will emit a query, and it will emit a key. 862 01:04:13,360 --> 01:04:17,360 Now the query vector, roughly speaking, is: what am I looking for? 863 01:04:18,260 --> 01:04:21,160 And the key vector, roughly speaking, is: what do I contain? 
864 01:04:22,560 --> 01:04:31,460 And then the way we get affinities between these tokens in a sequence is we basically just do a dot product between the keys and the queries. 865 01:04:32,260 --> 01:04:41,260 So my query dot products with all the keys of all the other tokens, and that dot product now becomes wei. 866 01:04:42,260 --> 01:04:56,260 And so if the key and the query are sort of aligned, they will interact to a very high amount, and then I will get to learn more about that specific token, as opposed to any other token in the sequence. 867 01:04:56,460 --> 01:04:57,460 So let's implement this now. 868 01:05:01,660 --> 01:05:03,060 We're going to implement a single 869 01:05:04,660 --> 01:05:06,760 what's called head of self-attention. 870 01:05:08,060 --> 01:05:09,260 So this is just one head. 871 01:05:09,660 --> 01:05:12,760 There's a hyperparameter involved with these heads, which is the head size. 872 01:05:13,460 --> 01:05:18,060 And then here I'm initializing linear modules, and I'm using bias=False. 873 01:05:18,060 --> 01:05:21,560 So these are just going to apply a matrix multiply with some fixed weights. 874 01:05:22,760 --> 01:05:29,860 And now let me produce a key and a query, k and q, by forwarding these modules on x. 875 01:05:30,960 --> 01:05:31,560 So the size of this 876 01:05:31,760 --> 01:05:32,860 will now become 877 01:05:33,760 --> 01:05:40,160 B by T by 16, because that is the head size, and the same here: B by T by 16, 878 01:05:45,860 --> 01:05:47,060 this being the head size. 879 01:05:47,660 --> 01:05:59,660 So you see here that when I forward this linear on top of my x, all the tokens in all the positions in the B by T arrangement, all of them, in parallel and independently, produce a key and a query. 880 01:05:59,660 --> 01:06:01,260 So no communication has happened yet. 881 01:06:02,660 --> 01:06:03,960 But the communication comes now. 882 01:06:04,060 --> 01:06:07,260 All the queries will dot product with all the keys. 
883 01:06:08,560 --> 01:06:15,860 So basically what we want is we want wei now, or the affinities between these, to be query multiplying key, 884 01:06:16,560 --> 01:06:19,160 but we have to be careful — we can't matrix-multiply this directly. 885 01:06:19,160 --> 01:06:26,460 We actually need to transpose k, but we have to also be careful, because these have the batch dimension. 886 01:06:26,860 --> 01:06:31,160 So in particular, we want to transpose the last two dimensions: 887 01:06:31,260 --> 01:06:31,560 dimension 888 01:06:31,660 --> 01:06:33,360 negative one and dimension negative two. 889 01:06:34,060 --> 01:06:36,260 So negative two, negative one. 890 01:06:37,560 --> 01:06:43,060 And so this matrix multiply now will basically do the following: B by T by 16 891 01:06:45,260 --> 01:06:52,360 matrix-multiplies B by 16 by T, to give us B by T by T. 892 01:06:54,460 --> 01:06:54,860 Right? 893 01:06:56,060 --> 01:06:57,860 So for every row of B, 894 01:06:58,160 --> 01:07:01,460 we're now going to have a T-square matrix 895 01:07:01,560 --> 01:07:07,360 giving us the affinities, and these are now the wei, so they're not zeros. 896 01:07:07,560 --> 01:07:11,360 They are now coming from this dot product between the keys and the queries. 897 01:07:11,560 --> 01:07:21,360 So I can now run this, and the weighted aggregation now is a function, in a data-dependent manner, of the keys and queries of these nodes. 898 01:07:21,560 --> 01:07:27,160 So just inspecting what happened here: wei takes on this form. 899 01:07:27,360 --> 01:07:31,160 And you see that before, wei was just a constant; 900 01:07:31,160 --> 01:07:33,060 it was the same for all the batch elements. 901 01:07:33,260 --> 01:07:41,660 But now every single batch element will have a different sort of wei, because every single batch element contains different tokens at different positions. 
902 01:07:41,860 --> 01:07:43,360 And so this is now data dependent. 903 01:07:44,160 --> 01:07:47,060 So when we look at just the zeroth row, 904 01:07:47,260 --> 01:07:48,460 for example, in the input, 905 01:07:49,060 --> 01:07:50,860 these are the weights that came out. 906 01:07:51,260 --> 01:07:53,460 And so you can see now that they're not exactly uniform. 907 01:07:55,260 --> 01:07:57,460 And in particular, as an example here, for the last row: 908 01:07:57,860 --> 01:08:01,060 this was the eighth token, and the eighth token knows what content 909 01:08:01,160 --> 01:08:03,460 it has, and it knows at what position it's in. 910 01:08:04,160 --> 01:08:08,260 And now the eighth token, based on that, creates a query: 911 01:08:08,560 --> 01:08:09,960 hey, I'm looking for this kind of stuff. 912 01:08:11,060 --> 01:08:11,660 I'm a vowel, 913 01:08:11,860 --> 01:08:12,660 I'm at the eighth position, 914 01:08:12,860 --> 01:08:15,360 I'm looking for any consonants at positions up to four. 915 01:08:16,460 --> 01:08:24,460 And then all the nodes get to emit keys, and maybe one of the channels could be: I am a consonant, and I am at a position up to four. 916 01:08:25,160 --> 01:08:28,860 And that key would have a high number in that specific channel. 917 01:08:29,360 --> 01:08:30,960 And that's how the query and the key, when they 918 01:08:31,060 --> 01:08:31,660 dot product, 919 01:08:31,660 --> 01:08:33,760 can find each other and create a high affinity. 920 01:08:34,760 --> 01:08:35,860 And when they have a high affinity — 921 01:08:35,860 --> 01:08:41,060 like, say, this token was pretty interesting to this eighth token — 922 01:08:42,360 --> 01:08:43,560 when they have a high affinity, 923 01:08:43,860 --> 01:08:45,060 then, through the softmax, 924 01:08:45,260 --> 01:08:48,660 I will end up aggregating a lot of its information into my position, 925 01:08:49,260 --> 01:08:51,160 and so I'll get to learn a lot about it. 
926 01:08:52,660 --> 01:08:57,460 Now, this was looking at wei after all of this has already happened. 927 01:08:59,360 --> 01:09:00,860 Let me erase this operation as well — 928 01:09:00,960 --> 01:09:05,760 let me erase the masking and the softmax, just to show you the under-the-hood internals and how that works. 929 01:09:06,560 --> 01:09:10,560 So without the masking and the softmax, wei comes out like this, 930 01:09:10,560 --> 01:09:10,860 right? 931 01:09:10,860 --> 01:09:12,560 These are the outputs of the dot products, 932 01:09:13,760 --> 01:09:16,360 and these are the raw outputs, and they take on values from negative, 933 01:09:16,560 --> 01:09:16,960 you know, 934 01:09:17,060 --> 01:09:18,460 two to positive two, etc. 935 01:09:19,760 --> 01:09:23,660 So those are the raw interactions and raw affinities between all the nodes. 936 01:09:24,360 --> 01:09:26,560 But now, if I'm the fifth node, 937 01:09:26,660 --> 01:09:30,860 I will not want to aggregate anything from the sixth node, seventh node, and eighth node. 938 01:09:31,360 --> 01:09:34,460 So actually, we mask out the upper triangular part, 939 01:09:34,960 --> 01:09:36,560 so those are not allowed to communicate. 940 01:09:38,360 --> 01:09:41,860 And now we actually want to have a nice distribution. 941 01:09:42,460 --> 01:09:45,660 We don't want to aggregate, like, negative 0.11 of this node — 942 01:09:45,660 --> 01:09:46,260 that's crazy. 943 01:09:46,660 --> 01:09:48,460 So instead, we exponentiate and normalize, 944 01:09:49,060 --> 01:09:50,860 and now we get a nice distribution that sums to one. 945 01:09:51,660 --> 01:09:53,660 And this is telling us now, in a data-dependent manner, 946 01:09:53,660 --> 01:09:57,560 how much information to aggregate from any of these tokens in the past. 947 01:09:59,560 --> 01:10:00,860 So that's wei. 948 01:10:01,260 --> 01:10:02,260 And it's not zeros anymore — 949 01:10:02,260 --> 01:10:04,460 it's calculated in this way. 
950 01:10:05,060 --> 01:10:05,160 Now, 951 01:10:05,160 --> 01:10:09,760 there's one more part to a single self-attention head. 952 01:10:10,260 --> 01:10:11,960 And that is that when we do the aggregation, 953 01:10:11,960 --> 01:10:13,860 we don't actually aggregate the tokens 954 01:10:13,860 --> 01:10:14,360 exactly. 955 01:10:14,860 --> 01:10:15,760 We aggregate, 956 01:10:15,760 --> 01:10:19,160 well, we produce one more value here, and we call that the value. 957 01:10:21,160 --> 01:10:23,060 So in the same way that we produced key and query, 958 01:10:23,060 --> 01:10:24,860 we're also going to create a value. 959 01:10:26,060 --> 01:10:30,160 And then here, we don't aggregate 960 01:10:31,260 --> 01:10:33,860 X; we calculate a V, 961 01:10:33,860 --> 01:10:38,560 which is just achieved by propagating this linear on top of X again. 962 01:10:39,060 --> 01:10:42,760 And then we output wei multiplied by V. 963 01:10:43,160 --> 01:10:48,660 So V is the elements that we aggregate, or the vector that we aggregate, instead of the raw X. 964 01:10:49,860 --> 01:10:50,960 And now, of course, 965 01:10:50,960 --> 01:10:56,860 this will make it so that the output here of the single head will be 16-dimensional, because that is the head size. 966 01:10:58,360 --> 01:11:00,860 So you can think of X as kind of like private information 968 01:11:00,860 --> 01:11:01,660 to this token, 969 01:11:01,660 --> 01:11:03,460 if you think about it that way. 970 01:11:03,660 --> 01:11:05,660 So X is kind of private to this token. 971 01:11:05,860 --> 01:11:12,560 So I'm the fifth token, and I have some identity, and my information is kept in vector X. 972 01:11:13,160 --> 01:11:15,360 And now, for the purposes of this single head, 973 01:11:15,560 --> 01:11:17,060 here's what I'm interested in, 974 01:11:17,460 --> 01:11:18,960 here's what I have.
975 01:11:19,560 --> 01:11:21,260 And if you find me interesting, 976 01:11:21,260 --> 01:11:22,760 here's what I will communicate to you. 977 01:11:23,160 --> 01:11:24,460 And that's stored in V. 978 01:11:25,060 --> 01:11:30,660 And so V is the thing that gets aggregated for the purposes of this single head between the different nodes. 979 01:11:31,660 --> 01:11:35,260 And that's basically the self-attention mechanism. 980 01:11:35,260 --> 01:11:37,360 This is what it does. 981 01:11:38,160 --> 01:11:41,260 There are a few notes that I would like to make about attention. 982 01:11:41,760 --> 01:11:44,860 Number one: attention is a communication mechanism. 983 01:11:45,260 --> 01:11:47,660 You can really think about it as a communication mechanism 984 01:11:47,960 --> 01:11:50,560 where you have a number of nodes in a directed graph, 985 01:11:50,960 --> 01:11:53,560 where basically you have edges pointing between nodes like this. 986 01:11:54,760 --> 01:11:58,060 And what happens is every node has some vector of information, 987 01:11:58,460 --> 01:12:00,260 and it gets to aggregate information 988 01:12:00,260 --> 01:12:03,760 via a weighted sum from all of the nodes that point to it. 989 01:12:04,760 --> 01:12:06,560 And this is done in a data-dependent manner, 990 01:12:06,560 --> 01:12:10,260 so depending on whatever data is actually stored at each node at any point in time. 991 01:12:11,160 --> 01:12:11,760 Now, 992 01:12:12,660 --> 01:12:13,960 our graph doesn't look like this. 993 01:12:13,960 --> 01:12:15,460 Our graph has a different structure. 994 01:12:15,760 --> 01:12:20,260 We have eight nodes because the block size is eight and there are always eight tokens. 995 01:12:21,160 --> 01:12:24,260 And the first node is only pointed to by itself.
996 01:12:24,660 --> 01:12:27,560 The second node is pointed to by the first node and itself, 997 01:12:27,860 --> 01:12:29,560 all the way up to the eighth node, 998 01:12:29,660 --> 01:12:30,060 which is pointed to by 999 01:12:30,060 --> 01:12:32,860 all the previous nodes and itself. 1000 01:12:33,860 --> 01:12:37,860 And so that's the structure that our directed graph happens to have in 1001 01:12:37,860 --> 01:12:40,560 an autoregressive sort of scenario like language modeling. 1002 01:12:41,260 --> 01:12:44,360 But in principle, attention can be applied to any arbitrary directed graph; 1003 01:12:44,360 --> 01:12:46,560 it's just a communication mechanism between the nodes. 1004 01:12:47,260 --> 01:12:50,460 The second note is: notice that there is no notion of space. 1005 01:12:50,760 --> 01:12:55,060 So attention simply acts over a set of vectors in this graph. 1006 01:12:55,460 --> 01:12:59,160 And so by default, these nodes have no idea where they are positioned in space. 1007 01:12:59,360 --> 01:12:59,960 And that's why we need to 1008 01:13:00,160 --> 01:13:04,860 encode them positionally, and sort of give them some information that is anchored to a specific 1009 01:13:04,860 --> 01:13:08,260 position, so that they sort of know where they are. 1010 01:13:08,660 --> 01:13:12,160 And this is different from, for example, convolution, because if you run, for example, 1011 01:13:12,160 --> 01:13:17,460 a convolution operation over some input, there is a very specific layout of the information 1012 01:13:17,460 --> 01:13:21,160 in space, and the convolutional filters sort of act in space. 1013 01:13:21,460 --> 01:13:27,360 And so it's not like that in attention; attention is just a set of vectors out there in space.
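To recap the full single-head mechanics described above in code, here is a minimal sketch including the value: X is each token's private information, key/query/value are linear projections of X, and what actually gets aggregated is V, weighted by wei. Sizes and names follow the lecture's running example but are assumptions here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(1337)
B, T, C, head_size = 4, 8, 32, 16
x = torch.randn(B, T, C)  # each token's private information

key   = nn.Linear(C, head_size, bias=False)   # "here's what I have"
query = nn.Linear(C, head_size, bias=False)   # "here's what I'm interested in"
value = nn.Linear(C, head_size, bias=False)   # "here's what I will communicate"

k, q, v = key(x), query(x), value(x)
wei = q @ k.transpose(-2, -1)                           # affinities
tril = torch.tril(torch.ones(T, T))
wei = F.softmax(wei.masked_fill(tril == 0, float('-inf')), dim=-1)
out = wei @ v   # we aggregate v, not the raw x -> (B, T, head_size)
```

The output is 16-dimensional per token, because that is the head size, exactly as described.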
1014 01:13:27,660 --> 01:13:29,860 They communicate, and if you want them to have a 1015 01:13:29,860 --> 01:13:34,660 notion of space, you need to specifically add it, which is what we've done when we calculated 1016 01:13:35,460 --> 01:13:40,160 the positional encodings and added that information to the vectors. 1017 01:13:40,360 --> 01:13:44,860 The next thing that I hope is very clear is that the elements across the batch dimension, which are 1018 01:13:44,860 --> 01:13:46,860 independent examples, never talk to each other. 1019 01:13:46,860 --> 01:13:51,460 They're always processed independently, and this is a batched matrix multiply that applies basically a 1020 01:13:51,460 --> 01:13:54,860 matrix multiplication kind of in parallel across the batch dimension. 1021 01:13:55,360 --> 01:13:59,260 So maybe it would be more accurate to say that in this analogy of a directed graph, 1022 01:13:59,860 --> 01:14:05,060 we really have, because the batch size is four, four separate pools of eight 1023 01:14:05,060 --> 01:14:09,260 nodes, and those eight nodes only talk to each other; in total there are 32 nodes being 1024 01:14:09,260 --> 01:14:14,660 processed, but there are sort of four separate pools of eight, you can look at it that way. 1025 01:14:15,460 --> 01:14:21,860 The next note is that here, in the case of language modeling, we have this specific structure of 1026 01:14:21,860 --> 01:14:27,860 directed graph where the future tokens will not communicate with the past tokens, but this doesn't 1027 01:14:27,860 --> 01:14:29,760 necessarily have to be the constraint in the general case. 1028 01:14:30,460 --> 01:14:35,960 And in fact, in many cases you may want to have all of the nodes talk to each other fully.
1029 01:14:36,560 --> 01:14:40,660 So as an example, if you're doing sentiment analysis or something like that with a transformer, you might 1030 01:14:40,660 --> 01:14:46,060 have a number of tokens, and you may want to have them all talk to each other fully, because later you 1031 01:14:46,060 --> 01:14:50,860 are predicting, for example, the sentiment of the sentence, and so it's okay for these nodes to talk to 1032 01:14:50,860 --> 01:14:58,060 each other. And so in those cases you will use an encoder block of self-attention, and all it means 1033 01:14:58,060 --> 01:14:59,260 for it to be an encoder block 1034 01:14:59,360 --> 01:15:03,960 is that you will delete this line of code, allowing all the nodes to completely talk to each other. 1035 01:15:04,560 --> 01:15:10,660 What we're implementing here is sometimes called a decoder block, and it's called a decoder because 1036 01:15:10,660 --> 01:15:17,460 it is sort of like decoding language, and it's got this autoregressive format where you have to mask 1037 01:15:17,460 --> 01:15:23,760 with the triangular matrix so that nodes from the future never talk to the past, because they would 1038 01:15:23,760 --> 01:15:24,560 give away the answer. 1039 01:15:25,360 --> 01:15:29,160 And so basically, in encoder blocks you would delete this, allowing all the nodes to talk to each other. 1040 01:15:29,860 --> 01:15:35,360 In decoder blocks, this will always be present so that you have this triangular structure, but both are 1041 01:15:35,360 --> 01:15:36,860 allowed, and attention doesn't care. 1042 01:15:36,860 --> 01:15:39,260 Attention supports arbitrary connectivity between nodes. 1043 01:15:39,860 --> 01:15:44,460 The next thing I wanted to comment on is: you keep hearing me say attention, self-attention, 1044 01:15:44,460 --> 01:15:45,060 etc. 1045 01:15:45,060 --> 01:15:46,960 There's actually also something called cross-attention. 1046 01:15:46,960 --> 01:15:47,760 What is the difference?
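The encoder-versus-decoder point above really is just one line of code. A small sketch, with made-up raw affinities:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(1337)
T = 8
wei = torch.randn(T, T)  # raw affinities between the eight nodes

# decoder block: mask the future with the triangular matrix, then softmax
tril = torch.tril(torch.ones(T, T))
wei_decoder = F.softmax(wei.masked_fill(tril == 0, float('-inf')), dim=-1)

# encoder block: simply delete the masking line -- all nodes talk to all nodes
wei_encoder = F.softmax(wei, dim=-1)
```

In the decoder version the first node can only attend to itself; in the encoder version every entry is positive.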
1047 01:15:48,760 --> 01:15:58,160 So basically, the reason this attention is self-attention is because the keys, queries, and values are all coming 1048 01:15:58,160 --> 01:15:59,160 from the same source, 1049 01:15:59,260 --> 01:16:03,860 from X. So the same source, X, produces keys, queries, and values. 1050 01:16:03,860 --> 01:16:09,560 So these nodes are self-attending, but in principle, attention is much more general than that. 1051 01:16:09,560 --> 01:16:15,660 So for example, in encoder-decoder transformers, you can have a case where the queries are produced from 1052 01:16:15,660 --> 01:16:22,260 X, but the keys and the values come from a whole separate external source, sometimes from encoder blocks 1053 01:16:22,260 --> 01:16:27,160 that encode some context that we'd like to condition on. And so the keys and the values will actually come 1054 01:16:27,160 --> 01:16:28,660 from a whole separate source. 1055 01:16:28,660 --> 01:16:30,960 Those are nodes on the side, 1056 01:16:30,960 --> 01:16:34,260 and here we're just producing queries and we're reading off information from the side. 1057 01:16:34,960 --> 01:16:39,960 So cross-attention is used when there's a separate source of nodes 1058 01:16:40,160 --> 01:16:44,560 we'd like to pull information from into our nodes, and it's self-attention 1059 01:16:44,560 --> 01:16:47,260 if we just have nodes that would like to look at each other and talk to each other. 1060 01:16:48,060 --> 01:16:50,960 So this attention here happens to be self-attention. 1061 01:16:52,560 --> 01:16:56,160 But in principle, attention is a lot more general. 1062 01:16:56,660 --> 01:16:58,460 Okay, and the last note at this stage is: 1063 01:16:58,660 --> 01:17:00,960 if we come to the Attention is All You Need paper here, 1064 01:17:00,960 --> 01:17:02,760 we've already implemented attention. 1065 01:17:02,760 --> 01:17:07,160 So given query, key, and value, we've multiplied the query and the key.
1066 01:17:07,160 --> 01:17:10,860 We've softmaxed it, and then we are aggregating the values. 1067 01:17:10,860 --> 01:17:16,260 There's one more thing that we're missing here, which is the dividing by the square root of the head size. 1068 01:17:16,260 --> 01:17:17,860 The dk here is the head size. 1069 01:17:17,860 --> 01:17:19,260 Why are they doing this? 1070 01:17:19,260 --> 01:17:20,060 Why is this important? 1071 01:17:20,060 --> 01:17:26,560 So they call it scaled attention, and it's kind of like an important normalization to basically have. 1072 01:17:26,560 --> 01:17:28,560 The problem is: 1073 01:17:28,660 --> 01:17:33,760 if you have unit Gaussian inputs, so zero mean, unit variance, K and Q are unit Gaussian, 1074 01:17:33,760 --> 01:17:40,160 and if you just compute wei naively, then you see that the variance of wei will actually be on the order of head size, 1075 01:17:40,160 --> 01:17:41,360 which in our case is 16. 1076 01:17:42,460 --> 01:17:45,460 But if you multiply by 1 over the square root of head size, 1077 01:17:45,460 --> 01:17:50,660 so this is the square root and this is 1 over it, then the variance of wei will be 1, 1078 01:17:50,660 --> 01:17:51,660 so it will be preserved. 1079 01:17:52,960 --> 01:17:54,160 Now, why is this important? 1080 01:17:54,560 --> 01:17:58,660 You'll notice that wei here will feed into softmax. 1081 01:17:59,660 --> 01:18:04,660 And so it's really important, especially at initialization, that wei be fairly diffuse. 1082 01:18:04,660 --> 01:18:11,660 So in our case here, we sort of lucked out, and wei had fairly diffuse numbers here, 1083 01:18:11,660 --> 01:18:14,460 like this. 1084 01:18:14,460 --> 01:18:16,460 Now, the problem is that, because of softmax, 1085 01:18:16,460 --> 01:18:20,160 if wei takes on very positive and very negative numbers inside it, 1086 01:18:20,160 --> 01:18:24,360 softmax will actually converge towards one-hot vectors.
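Both claims are easy to check numerically. A sketch with made-up sizes: first, that with unit-Gaussian K and Q the raw dot products have variance on the order of head size, while multiplying by 1 over the square root of head size brings the variance back to roughly 1; second, that softmax over large-magnitude logits converges toward a one-hot vector.

```python
import torch

torch.manual_seed(1337)
B, T, head_size = 32, 8, 16
k = torch.randn(B, T, head_size)  # unit Gaussian
q = torch.randn(B, T, head_size)  # unit Gaussian

wei_naive  = q @ k.transpose(-2, -1)            # variance ~ head_size (16)
wei_scaled = wei_naive * head_size**-0.5        # variance ~ 1, preserved

# sharpening: the same logits, scaled up, push softmax toward the max
logits  = torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])
diffuse = torch.softmax(logits, dim=-1)         # fairly flat distribution
sharp   = torch.softmax(logits * 8, dim=-1)     # peaks at the largest entry
```

So without the scaling, wei feeds extreme values into the softmax and every node ends up aggregating from essentially a single other node.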
1087 01:18:24,360 --> 01:18:25,760 And so I can illustrate that here. 1088 01:18:28,060 --> 01:18:28,360 Say 1089 01:18:28,360 --> 01:18:32,060 we are applying softmax to a tensor of values that are very close to zero. 1090 01:18:32,360 --> 01:18:34,660 Then we're going to get a diffuse thing out of softmax. 1091 01:18:35,560 --> 01:18:38,360 But the moment I take the exact same thing and I start sharpening it, 1092 01:18:38,360 --> 01:18:40,760 making it bigger by multiplying these numbers by 8, 1093 01:18:40,760 --> 01:18:41,460 for example, 1094 01:18:41,860 --> 01:18:43,860 you'll see that the softmax will start to sharpen. 1095 01:18:44,260 --> 01:18:46,660 And in fact, it will sharpen towards the max. 1096 01:18:46,660 --> 01:18:49,560 So it will sharpen towards whatever number here is the highest. 1097 01:18:50,160 --> 01:18:53,360 And so basically we don't want these values to be too extreme, 1098 01:18:53,360 --> 01:18:54,560 especially at initialization. 1099 01:18:54,560 --> 01:18:57,360 Otherwise, softmax will be way too peaky, and 1101 01:18:57,760 --> 01:19:01,760 you're basically aggregating information from like a single node; 1102 01:19:01,760 --> 01:19:04,460 every node just aggregates information from a single other node. 1103 01:19:04,460 --> 01:19:05,360 That's not what we want, 1104 01:19:05,360 --> 01:19:06,760 especially at initialization. 1105 01:19:06,760 --> 01:19:11,260 And so the scaling is used just to control the variance at initialization. 1106 01:19:11,260 --> 01:19:11,960 Okay. 1107 01:19:11,960 --> 01:19:13,060 So having said all that, 1108 01:19:13,060 --> 01:19:16,660 let's now take our self-attention knowledge and let's take it for a spin. 1109 01:19:16,660 --> 01:19:18,760 So here in the code, 1110 01:19:18,760 --> 01:19:23,060 I've created this Head module, and it implements a single head of self-attention.
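A sketch of what such a Head module might look like, following the description that comes next (hyperparameters like n_embd and block_size follow the lecture's conventions but are fixed here for illustration):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

block_size, n_embd = 8, 32

class Head(nn.Module):
    """One head of self-attention."""

    def __init__(self, head_size):
        super().__init__()
        # typically people don't use biases in these linear projections
        self.key   = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # tril is not a parameter, so assign it via register_buffer
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        wei = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5  # scaled attention
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # decoder block
        wei = F.softmax(wei, dim=-1)
        v = self.value(x)
        return wei @ v  # (B, T, head_size)
```

Calling `Head(16)` on a `(4, 8, 32)` input produces a `(4, 8, 16)` output, the head size.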
1111 01:19:23,060 --> 01:19:27,160 So you give it a head size, and then here it creates the key, query, and value 1112 01:19:27,160 --> 01:19:28,960 linear layers. 1113 01:19:29,360 --> 01:19:31,260 Typically, people don't use biases in these. 1114 01:19:32,360 --> 01:19:35,560 So those are the linear projections that we're going to apply to all of our nodes. 1115 01:19:36,360 --> 01:19:37,060 Now here, 1116 01:19:37,060 --> 01:19:38,960 I'm creating this tril variable. 1117 01:19:39,260 --> 01:19:41,160 tril is not a parameter of the module. 1118 01:19:41,160 --> 01:19:43,160 So in sort of PyTorch naming conventions, 1119 01:19:43,560 --> 01:19:44,660 this is called a buffer. 1120 01:19:44,860 --> 01:19:46,960 It's not a parameter, 1121 01:19:46,960 --> 01:19:49,260 and you have to assign it to the module using register_buffer. 1122 01:19:49,660 --> 01:19:50,760 So that creates the tril, 1123 01:19:51,660 --> 01:19:53,660 the lower triangular matrix. 1124 01:19:54,460 --> 01:19:55,760 And when we're given the input X, 1125 01:19:55,760 --> 01:19:57,160 this should look very familiar now. 1126 01:19:57,360 --> 01:19:58,460 We calculate the keys, 1127 01:19:58,460 --> 01:19:59,260 the queries, 1128 01:19:59,460 --> 01:20:02,160 we calculate the attention scores inside wei, 1129 01:20:02,960 --> 01:20:03,760 we normalize it, 1130 01:20:03,760 --> 01:20:05,560 so we're using scaled attention here. 1131 01:20:06,260 --> 01:20:09,360 Then we make sure that the future doesn't communicate with the past, 1132 01:20:09,560 --> 01:20:11,260 so this makes it a decoder block, 1133 01:20:12,060 --> 01:20:15,060 and then softmax, and then aggregate the value and output. 1134 01:20:16,560 --> 01:20:17,560 Then here in the language model, 1135 01:20:17,560 --> 01:20:22,060 I'm creating a head in the constructor and I'm calling it the self-attention head, 1136 01:20:22,460 --> 01:20:23,760 and the head size
1137 01:20:23,760 --> 01:20:26,960 I'm going to keep the same as n_embd, just for now. 1138 01:20:27,160 --> 01:20:32,960 And then here, once we've encoded the information with the token embeddings 1139 01:20:32,960 --> 01:20:34,060 and the position embeddings, 1140 01:20:34,460 --> 01:20:36,760 we're simply going to feed it into the self-attention head, 1141 01:20:37,160 --> 01:20:42,560 and then the output of that is going to go into the decoder language modeling 1142 01:20:42,560 --> 01:20:44,160 head and create the logits. 1143 01:20:44,560 --> 01:20:48,260 So this is sort of the simplest way to plug a self-attention component 1144 01:20:49,060 --> 01:20:50,360 into our network right now. 1145 01:20:51,160 --> 01:20:52,460 I had to make one more change, 1146 01:20:52,860 --> 01:20:55,960 which is that here in generate, 1147 01:20:55,960 --> 01:21:00,360 we have to make sure that the idx that we feed into the model, 1148 01:21:00,960 --> 01:21:02,660 because now we're using positional embeddings, 1149 01:21:02,960 --> 01:21:05,760 can never have more than block size coming in, 1150 01:21:06,160 --> 01:21:08,660 because if idx is more than block size, 1151 01:21:08,960 --> 01:21:11,460 then our position embedding table is going to run out of scope, 1152 01:21:11,460 --> 01:21:13,560 because it only has embeddings for up to block size. 1153 01:21:14,460 --> 01:21:17,660 And so therefore I added some code here to crop the context 1154 01:21:18,260 --> 01:21:19,960 that we're going to feed into the model, 1155 01:21:21,660 --> 01:21:24,660 so that we never pass in more than block size elements. 1156 01:21:24,660 --> 01:21:27,760 So those are the changes, and let's now train the network. 1157 01:21:28,060 --> 01:21:31,560 Okay, so I also went up to the script here and I decreased the learning rate, 1158 01:21:31,560 --> 01:21:34,960 because the self-attention can't tolerate very high learning rates.
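The cropping change to generate mentioned above amounts to one slicing line. A sketch (the tensor here is a stand-in for the generated token indices):

```python
import torch

block_size = 8

# pretend we've already generated 20 tokens in this (B=1, T=20) sequence
idx = torch.zeros((1, 20), dtype=torch.long)

# crop the context to the last block_size tokens, so the position
# embedding table (which only has block_size rows) never runs out of scope
idx_cond = idx[:, -block_size:]
```

The model then only ever sees at most `block_size` positions, no matter how long the generated sequence grows.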
1159 01:21:35,560 --> 01:21:37,660 And then I also increased the number of iterations, 1160 01:21:37,660 --> 01:21:40,160 because the learning rate is lower, and then I trained it. 1161 01:21:40,160 --> 01:21:42,960 And previously we were only able to get down to about 2.5, 1162 01:21:43,260 --> 01:21:44,860 and now we are down to 2.4. 1163 01:21:45,260 --> 01:21:49,260 So we definitely see a little bit of improvement, from 2.5 to 2.4 roughly, 1164 01:21:49,860 --> 01:21:51,460 but the text is still not amazing. 1165 01:21:51,960 --> 01:21:54,460 So clearly the self-attention head is doing 1166 01:21:54,660 --> 01:21:55,960 some useful communication, 1167 01:21:56,460 --> 01:21:59,060 but we still have a long way to go. 1168 01:21:59,360 --> 01:21:59,560 Okay. 1169 01:21:59,560 --> 01:22:01,860 So now we've implemented the scaled dot-product attention. 1170 01:22:02,160 --> 01:22:04,560 Now, next up in the Attention is All You Need paper, 1171 01:22:05,060 --> 01:22:06,660 there's something called multi-head attention. 1172 01:22:07,060 --> 01:22:08,460 And what is multi-head attention? 1173 01:22:08,860 --> 01:22:11,760 It's just applying multiple attentions in parallel 1174 01:22:11,960 --> 01:22:13,460 and concatenating the results. 1175 01:22:14,060 --> 01:22:15,960 So they have a little bit of a diagram here. 1176 01:22:16,260 --> 01:22:17,760 I don't know if this is super clear. 1177 01:22:18,360 --> 01:22:21,060 It's really just multiple attentions in parallel. 1178 01:22:21,760 --> 01:22:24,260 So let's implement that; it's fairly straightforward. 1179 01:22:25,360 --> 01:22:26,960 If we want multi-head attention, 1180 01:22:27,260 --> 01:22:29,960 then we want multiple heads of self-attention running in parallel. 1181 01:22:30,860 --> 01:22:34,960 So in PyTorch, we can do this by simply creating multiple heads.
1182 01:22:36,160 --> 01:22:38,960 So: however many heads you want, 1183 01:22:39,160 --> 01:22:40,760 and then what is the head size of each. 1184 01:22:41,560 --> 01:22:45,260 And then we run all of them in parallel into a list 1185 01:22:45,560 --> 01:22:47,760 and simply concatenate all of the outputs, 1186 01:22:48,360 --> 01:22:50,560 and we're concatenating over the channel dimension. 1187 01:22:51,660 --> 01:22:54,460 So the way this looks now is: we don't have just a single attention 1188 01:22:54,960 --> 01:23:00,560 that has a head size of 32, because remember, n_embd is 32. 1189 01:23:01,660 --> 01:23:04,160 Instead of having one communication channel, 1190 01:23:04,460 --> 01:23:07,860 we now have four communication channels in parallel, 1191 01:23:08,160 --> 01:23:10,460 and each one of these communication channels typically 1192 01:23:10,960 --> 01:23:14,160 will be smaller correspondingly. 1193 01:23:14,560 --> 01:23:16,460 So because we have four communication channels, 1194 01:23:16,760 --> 01:23:18,560 we want eight-dimensional self-attention. 1195 01:23:19,160 --> 01:23:20,860 And so from each communication channel, 1196 01:23:20,860 --> 01:23:23,060 we're gathering eight-dimensional vectors, 1197 01:23:23,360 --> 01:23:24,560 and then we have four of them, 1198 01:23:24,660 --> 01:23:26,760 and that concatenates to give us 32, 1199 01:23:26,860 --> 01:23:28,260 which is the original n_embd. 1200 01:23:29,160 --> 01:23:32,160 And so this is kind of similar to, if you're familiar with convolutions, 1201 01:23:32,160 --> 01:23:33,860 this is kind of like a group convolution, 1202 01:23:34,460 --> 01:23:37,160 because basically, instead of having one large convolution, 1203 01:23:37,260 --> 01:23:41,560 we do convolution in groups, and that's multi-headed self-attention. 1204 01:23:42,560 --> 01:23:45,360 And so then here, we just use sa_heads, 1205 01:23:45,460 --> 01:23:46,860 the self-attention heads, instead.
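A sketch of multi-head attention as just described: multiple heads running in parallel, with their outputs concatenated over the channel dimension. The Head here is a compact re-statement of the single head from earlier, included so the snippet is self-contained (sizes follow the lecture's 4 heads of size 8, concatenating back to n_embd = 32):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

n_embd, block_size = 32, 8

class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key   = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        k, q = self.key(x), self.query(x)
        wei = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5
        wei = F.softmax(wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')), dim=-1)
        return wei @ self.value(x)

class MultiHeadAttention(nn.Module):
    """Multiple heads of self-attention in parallel."""

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList(Head(head_size) for _ in range(num_heads))

    def forward(self, x):
        # concatenate over the channel dimension: 4 heads of size 8 -> 32
        return torch.cat([h(x) for h in self.heads], dim=-1)
```

So `MultiHeadAttention(4, 8)` maps a `(B, T, 32)` input back to `(B, T, 32)`.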
1206 01:23:47,760 --> 01:23:51,060 Now, I actually ran it, and scrolling down, 1207 01:23:52,260 --> 01:23:54,160 I ran the same thing, and we now get down 1208 01:23:54,660 --> 01:23:58,360 to 2.28 roughly, and the output, 1209 01:23:58,360 --> 01:23:59,760 the generation, is still not amazing, 1210 01:23:59,960 --> 01:24:01,760 but clearly the validation loss is improving, 1211 01:24:01,760 --> 01:24:03,960 because we were at 2.4 just now. 1212 01:24:04,860 --> 01:24:07,160 And so it helps to have multiple communication channels, 1213 01:24:07,160 --> 01:24:10,260 because obviously these tokens have a lot to talk about. 1214 01:24:10,760 --> 01:24:12,660 They want to find the consonants, the vowels, 1215 01:24:12,660 --> 01:24:14,960 they want to find the vowels just from certain positions, 1216 01:24:15,460 --> 01:24:18,060 they want to find all kinds of different things. 1217 01:24:18,460 --> 01:24:21,460 And so it helps to create multiple independent channels of communication, 1218 01:24:21,560 --> 01:24:23,360 gather lots of different types of data, 1219 01:24:23,760 --> 01:24:24,360 and then 1220 01:24:24,660 --> 01:24:25,760 decode the output. 1221 01:24:26,160 --> 01:24:27,660 Now, going back to the paper for a second, 1222 01:24:27,860 --> 01:24:28,260 of course, 1223 01:24:28,260 --> 01:24:29,960 I didn't explain this figure in full detail, 1224 01:24:29,960 --> 01:24:33,160 but we are starting to see some components of what we've already implemented. 1225 01:24:33,360 --> 01:24:34,660 We have the positional encodings, 1226 01:24:34,660 --> 01:24:36,160 the token encodings that we add, 1227 01:24:36,560 --> 01:24:39,260 and we have the masked multi-headed attention implemented.
1228 01:24:39,960 --> 01:24:42,060 Now, here's another multi-headed attention, 1229 01:24:42,060 --> 01:24:44,360 which is a cross-attention to an encoder, 1230 01:24:44,360 --> 01:24:45,260 which we haven't, 1231 01:24:45,260 --> 01:24:46,860 we're not going to implement in this case. 1232 01:24:47,160 --> 01:24:48,460 I'm going to come back to that later. 1233 01:24:49,460 --> 01:24:51,960 But I want you to notice that there's a feed-forward part here, 1234 01:24:52,160 --> 01:24:54,560 and then this is grouped into a block that gets repeated. 1235 01:24:54,860 --> 01:24:55,360 And again, 1236 01:24:55,960 --> 01:24:59,360 the feed-forward part here is just a simple multi-layer perceptron. 1237 01:25:01,760 --> 01:25:02,660 So here, 1238 01:25:03,060 --> 01:25:07,460 "position-wise feed-forward networks" is just a simple little MLP. 1239 01:25:08,160 --> 01:25:10,360 So I want to now, basically in a similar fashion, 1240 01:25:10,360 --> 01:25:12,760 also add computation into the network, 1241 01:25:13,360 --> 01:25:15,560 and this computation is on the per-node level. 1242 01:25:16,060 --> 01:25:16,560 So 1243 01:25:17,460 --> 01:25:20,860 I've already implemented it, and you can see the diff highlighted on the left here, 1244 01:25:20,860 --> 01:25:22,260 where I've added or changed things. 1245 01:25:22,960 --> 01:25:24,160 Now, before, we had the 1246 01:25:24,660 --> 01:25:26,960 multi-headed self-attention that did the communication, 1247 01:25:27,360 --> 01:25:30,260 but we went way too fast to calculate the logits. 1248 01:25:30,660 --> 01:25:32,460 So the tokens looked at each other, 1249 01:25:32,460 --> 01:25:36,760 but didn't really have a lot of time to think on what they found from the other tokens.
1250 01:25:37,560 --> 01:25:38,160 And so 1251 01:25:38,760 --> 01:25:41,860 what I've implemented here is a little feed-forward single layer, 1252 01:25:42,360 --> 01:25:45,960 and this little layer is just a linear followed by a ReLU non-linearity, 1253 01:25:46,160 --> 01:25:46,960 and that's it. 1254 01:25:47,960 --> 01:25:49,560 So it's just a little layer, 1255 01:25:50,160 --> 01:25:51,960 and then I call it feed forward 1256 01:25:53,560 --> 01:25:54,160 with n_embd. 1257 01:25:54,660 --> 01:25:58,560 And then this feed-forward is just called sequentially, right after the self-attention. 1258 01:25:58,960 --> 01:26:01,460 So we self-attend, then we feed forward, 1259 01:26:01,960 --> 01:26:04,660 and you'll notice that the feed-forward here, when it's applying the linear, 1260 01:26:04,860 --> 01:26:06,360 this is on a per-token level. 1261 01:26:06,460 --> 01:26:08,260 All the tokens do this independently. 1262 01:26:08,660 --> 01:26:11,060 So the self-attention is the communication, 1263 01:26:11,460 --> 01:26:15,260 and then once they've gathered all the data, now they need to think on that data individually. 1264 01:26:16,160 --> 01:26:17,660 And so that's what feed-forward is doing, 1265 01:26:18,060 --> 01:26:19,560 and that's why I've added it here. 1266 01:26:20,260 --> 01:26:24,360 Now, when I train this, the validation loss actually continues to go down, now to 2.24, 1267 01:26:25,260 --> 01:26:26,960 which is down from 2.28. 1268 01:26:27,660 --> 01:26:29,360 The outputs still look kind of terrible, 1269 01:26:29,660 --> 01:26:31,360 but at least we've improved the situation.
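A sketch of the little feed-forward layer at this stage: a single linear followed by a ReLU, applied to every token independently (the 4x expansion from the paper comes later):

```python
import torch
import torch.nn as nn

n_embd = 32

class FeedForward(nn.Module):
    """A simple linear layer followed by a non-linearity."""

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, n_embd),
            nn.ReLU(),
        )

    def forward(self, x):
        # applied on a per-token level: all tokens do this independently
        return self.net(x)
```

Because the linear acts only on the last dimension, each of the B*T tokens is transformed independently, with no communication between positions.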
1270 01:26:32,060 --> 01:26:33,160 And so as a preview, 1271 01:26:34,160 --> 01:26:36,060 we're going to now start to intersperse 1272 01:26:36,560 --> 01:26:39,360 the communication with the computation, 1273 01:26:39,660 --> 01:26:41,760 and that's also what the transformer does 1274 01:26:42,160 --> 01:26:45,060 when it has blocks that communicate and then compute, 1275 01:26:45,360 --> 01:26:47,560 and it groups them and replicates them. 1276 01:26:48,760 --> 01:26:50,760 Okay, so let me show you what we'd like to do. 1277 01:26:51,360 --> 01:26:52,460 We'd like to do something like this. 1278 01:26:52,460 --> 01:26:53,260 We have a block, 1279 01:26:53,660 --> 01:26:54,560 and this block is basically 1280 01:26:54,760 --> 01:26:55,560 this part here, 1281 01:26:56,160 --> 01:26:57,460 except for the cross-attention. 1282 01:26:58,660 --> 01:27:02,160 Now, the block basically intersperses communication and then computation. 1283 01:27:02,660 --> 01:27:05,960 The communication is done using multi-headed self-attention, 1284 01:27:06,560 --> 01:27:09,260 and then the computation is done using a feed-forward network 1285 01:27:09,660 --> 01:27:10,960 on all the tokens independently. 1286 01:27:12,560 --> 01:27:15,560 Now, what I've added here also is, you'll notice, 1287 01:27:17,260 --> 01:27:19,560 this takes the embedding dimension 1288 01:27:19,560 --> 01:27:21,060 and the number of heads that we would like, 1289 01:27:21,060 --> 01:27:23,560 which is kind of like the group size in group convolution. 1290 01:27:24,060 --> 01:27:26,460 And I'm saying that the number of heads we'd like is four. 1291 01:27:26,860 --> 01:27:28,560 And so because this is 32, 1292 01:27:28,960 --> 01:27:30,660 we calculate that, because this is 32 1293 01:27:30,860 --> 01:27:32,260 and the number of heads is four, 1294 01:27:34,060 --> 01:27:35,460 the head size should be eight, 1295 01:27:35,560 --> 01:27:37,760 so that everything sort of works out channel-wise.
1296 01:27:38,960 --> 01:27:40,860 So this is how the transformer structures 1297 01:27:41,160 --> 01:27:43,560 sort of the sizes, typically. 1298 01:27:44,360 --> 01:27:45,560 So the head size will become eight, 1299 01:27:45,660 --> 01:27:47,360 and then this is how we want to intersperse them. 1300 01:27:47,860 --> 01:27:50,060 And then here I'm trying to create blocks, 1301 01:27:50,160 --> 01:27:53,360 which is just a sequential application of block, block, block, 1302 01:27:53,660 --> 01:27:56,860 so that we're interspersing communication and feed-forward many, many times, 1303 01:27:57,060 --> 01:27:58,860 and then finally we decode. 1304 01:27:59,460 --> 01:28:01,460 Now, I actually tried to run this, 1305 01:28:01,760 --> 01:28:04,760 and the problem is this doesn't actually give a very good 1306 01:28:05,360 --> 01:28:06,660 result. 1307 01:28:06,860 --> 01:28:10,660 And the reason for that is that we're starting to actually get a pretty deep neural net, 1308 01:28:11,060 --> 01:28:13,760 and deep neural nets suffer from optimization issues, 1309 01:28:13,760 --> 01:28:16,460 and I think that's what we're kind of slightly starting to run into. 1310 01:28:16,760 --> 01:28:19,560 So we need one more idea that we can borrow from the 1311 01:28:20,360 --> 01:28:22,560 transformer paper to resolve those difficulties. 1312 01:28:22,560 --> 01:28:25,660 Now, there are two optimizations that dramatically help 1313 01:28:25,760 --> 01:28:27,060 with the depth of these networks 1314 01:28:27,360 --> 01:28:29,960 and make sure that the networks remain optimizable. 1315 01:28:30,260 --> 01:28:31,260 Let's talk about the first one. 1316 01:28:31,960 --> 01:28:34,560 The first one, in this diagram, is: you see this arrow here, 1317 01:28:35,360 --> 01:28:37,360 and then this arrow, and this arrow. 1318 01:28:37,760 --> 01:28:40,760 Those are skip connections, or sometimes called residual connections.
1319 01:28:41,560 --> 01:28:42,560 They come from this paper, 1320 01:28:43,460 --> 01:28:46,860 Deep Residual Learning for Image Recognition, from about 2015, 1321 01:28:47,760 --> 01:28:49,160 that introduced the concept. 1322 01:28:49,960 --> 01:28:52,360 Now, what it basically means 1323 01:28:52,560 --> 01:28:54,260 is you transform the data, 1324 01:28:54,460 --> 01:28:56,860 but then you have a skip connection, with addition, 1325 01:28:57,460 --> 01:28:58,760 from the previous features. 1326 01:28:59,360 --> 01:29:00,960 Now, the way I like to visualize it, 1327 01:29:01,660 --> 01:29:02,460 that I prefer, 1328 01:29:02,960 --> 01:29:03,760 is the following. 1329 01:29:04,260 --> 01:29:06,960 Here, the computation happens from the top to the bottom, 1330 01:29:07,560 --> 01:29:10,460 and basically you have this residual pathway, 1331 01:29:11,060 --> 01:29:13,460 and you are free to fork off from the residual pathway, 1332 01:29:13,460 --> 01:29:14,660 perform some computation, 1333 01:29:14,960 --> 01:29:17,760 and then project back to the residual pathway via addition. 1334 01:29:18,560 --> 01:29:19,760 And so you go from the 1335 01:29:20,560 --> 01:29:22,460 inputs to the targets 1336 01:29:22,660 --> 01:29:24,560 only via plus, and plus, and plus. 1337 01:29:25,460 --> 01:29:27,760 And the reason this is useful is because during backpropagation, 1338 01:29:27,760 --> 01:29:30,660 remember from our micrograd video earlier, 1339 01:29:31,060 --> 01:29:34,660 addition distributes gradients equally to both of its branches 1340 01:29:35,260 --> 01:29:36,560 that fed in as the input. 1341 01:29:37,060 --> 01:29:40,960 And so the supervision, or the gradients from the loss, 1342 01:29:41,360 --> 01:29:44,460 basically hop through every addition node 1343 01:29:44,760 --> 01:29:46,260 all the way to the input, 1344 01:29:46,760 --> 01:29:49,960 and then also fork off into the residual blocks.
1345 01:29:51,260 --> 01:29:52,360 But basically you have this 1346 01:29:52,360 --> 01:29:55,560 gradient superhighway that goes directly from the supervision 1347 01:29:55,760 --> 01:29:57,660 all the way to the input, unimpeded. 1348 01:29:58,360 --> 01:30:01,060 And these residual blocks are usually initialized, in the beginning, 1349 01:30:01,360 --> 01:30:04,360 so they contribute very, very little, if anything, to the residual pathway. 1350 01:30:04,760 --> 01:30:06,360 They are initialized that way. 1351 01:30:06,760 --> 01:30:09,760 So in the beginning they are sort of almost kind of like not there. 1352 01:30:10,160 --> 01:30:13,560 But then during the optimization they come online over time, 1353 01:30:14,160 --> 01:30:15,760 and they start to contribute. 1354 01:30:16,360 --> 01:30:18,060 But at least at initialization 1355 01:30:18,260 --> 01:30:20,460 you can go directly from the supervision to the input: 1356 01:30:20,960 --> 01:30:22,260 the gradient is unimpeded and just flows. 1357 01:30:22,860 --> 01:30:25,160 And then the blocks kick in over time. 1358 01:30:25,760 --> 01:30:28,360 And so that dramatically helps with the optimization. 1359 01:30:28,660 --> 01:30:29,560 So let's implement this. 1360 01:30:29,860 --> 01:30:31,160 So coming back to our block here, 1361 01:30:31,560 --> 01:30:32,960 basically what we want to do is 1362 01:30:33,560 --> 01:30:35,560 we want to do x equals x plus 1363 01:30:36,560 --> 01:30:39,560 self-attention, and x equals x plus self.feedforward. 1364 01:30:40,760 --> 01:30:45,460 So this is x, and then we fork off and do some communication and come back, 1365 01:30:45,760 --> 01:30:48,160 and we fork off and we do some computation and come back. 1366 01:30:48,960 --> 01:30:50,360 So those are residual connections. 1367 01:30:50,860 --> 01:30:52,260 And then, swinging back up here, 1368 01:30:52,460 --> 01:30:55,060 we also have to introduce this projection.
1369 01:30:55,960 --> 01:30:57,060 So nn.Linear, 1370 01:30:58,560 --> 01:31:00,860 and this is going to be, 1371 01:31:01,860 --> 01:31:03,060 after we concatenate, this: 1372 01:31:03,060 --> 01:31:04,460 this is the size n_embd. 1373 01:31:04,960 --> 01:31:07,260 So this is the output of the self-attention itself. 1374 01:31:07,960 --> 01:31:11,260 But then we actually want to apply the projection, 1375 01:31:12,260 --> 01:31:13,160 and that's the result. 1376 01:31:14,360 --> 01:31:17,260 So the projection is just a linear transformation of the outcome of this layer. 1377 01:31:18,860 --> 01:31:21,160 So that's the projection back into the residual pathway. 1378 01:31:21,860 --> 01:31:23,160 And then here in the feedforward, 1379 01:31:23,260 --> 01:31:24,460 it's going to be the same thing. 1380 01:31:24,960 --> 01:31:27,260 I could have a self.projection here as well, 1381 01:31:27,560 --> 01:31:28,960 but let me just simplify it 1382 01:31:29,660 --> 01:31:30,660 and let me 1383 01:31:32,060 --> 01:31:34,060 couple it inside the same sequential container. 1384 01:31:34,760 --> 01:31:37,860 And so this is the projection layer going back into the residual pathway. 1385 01:31:39,160 --> 01:31:40,160 And so 1386 01:31:40,960 --> 01:31:42,360 that's, well, that's it. 1387 01:31:42,660 --> 01:31:43,560 So now we can train this. 1388 01:31:43,760 --> 01:31:45,360 So I implemented one more small change. 1389 01:31:45,960 --> 01:31:48,160 When you look into the paper again, 1390 01:31:48,360 --> 01:31:50,960 you see that the dimensionality of input and output 1391 01:31:51,160 --> 01:31:52,460 is 512 for them, 1392 01:31:52,760 --> 01:31:55,160 and they're saying that the inner layer here in the feedforward 1393 01:31:55,160 --> 01:31:56,860 has dimensionality of 2048. 1394 01:31:57,160 --> 01:31:58,660 So there's a multiplier of 4.
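A sketch of that feed-forward sub-layer, with the projection coupled into the same sequential container (assuming PyTorch; n_embd = 32 mirrors the lecture's small model, and the 4x inner width mirrors the paper's 512 to 2048):

```python
import torch
import torch.nn as nn

n_embd = 32  # the lecture's small embedding dimension

# the inner layer is 4x wider, then a linear layer projects the
# result back down to n_embd, back into the residual pathway
ffwd = nn.Sequential(
    nn.Linear(n_embd, 4 * n_embd),
    nn.ReLU(),
    nn.Linear(4 * n_embd, n_embd),  # projection back into the residual pathway
)

x = torch.randn(4, 8, n_embd)  # (batch, time, channels)
print(ffwd(x).shape)  # torch.Size([4, 8, 32])
```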
1395 01:31:59,360 --> 01:32:02,060 And so the inner layer of the feedforward network 1396 01:32:02,760 --> 01:32:04,960 should be multiplied by 4 in terms of channel sizes. 1397 01:32:05,160 --> 01:32:07,560 So I came here and I multiplied 4 times n_embd 1398 01:32:07,860 --> 01:32:09,260 here for the feedforward, 1399 01:32:09,660 --> 01:32:12,660 and then from 4 times n_embd coming back down to n_embd 1400 01:32:12,860 --> 01:32:15,060 when we go back to the projection. 1401 01:32:15,360 --> 01:32:18,460 So adding a bit of computation here and growing that layer 1402 01:32:18,660 --> 01:32:20,660 that is in the residual block, on the side 1403 01:32:20,660 --> 01:32:22,060 of the residual pathway. 1404 01:32:23,160 --> 01:32:25,960 And then I train this, and we actually get down all the way to 1405 01:32:26,260 --> 01:32:28,060 2.08 validation loss. 1406 01:32:28,260 --> 01:32:30,460 And we also see that the network is starting to get big enough 1407 01:32:30,760 --> 01:32:33,060 that our train loss is getting ahead of the validation loss, 1408 01:32:33,060 --> 01:32:35,160 so we started to see a little bit of overfitting. 1409 01:32:36,160 --> 01:32:36,660 And 1410 01:32:37,060 --> 01:32:37,360 our 1411 01:32:37,660 --> 01:32:38,060 1412 01:32:40,060 --> 01:32:41,660 generations here are still not amazing, 1413 01:32:41,660 --> 01:32:45,860 but at least you see things here, like "this now grief sync": 1414 01:32:46,560 --> 01:32:48,660 this starts to almost look like English. 1415 01:32:48,960 --> 01:32:49,460 So, 1416 01:32:50,060 --> 01:32:50,160 yeah, 1417 01:32:50,160 --> 01:32:51,160 we're starting to really get there. 1418 01:32:51,660 --> 01:32:51,860 Okay. 1419 01:32:51,860 --> 01:32:54,860 And the second innovation that is very helpful for optimizing very deep 1420 01:32:54,860 --> 01:32:56,360 neural networks is right here. 1421 01:32:56,860 --> 01:32:59,060 So we have this addition now; that's the residual part.
1422 01:32:59,260 --> 01:33:01,760 But this norm is referring to something called layer norm. 1423 01:33:02,460 --> 01:33:04,360 So layer norm is implemented in PyTorch. 1424 01:33:04,360 --> 01:33:07,360 It's a paper that came out a while back here. 1425 01:33:10,160 --> 01:33:12,260 And layer norm is very, very similar to batch norm. 1426 01:33:12,660 --> 01:33:15,860 So remember back to our makemore series, part three: 1427 01:33:16,260 --> 01:33:20,060 we implemented batch normalization, and batch normalization basically just 1428 01:33:20,060 --> 01:33:23,960 made sure that, across the batch dimension, 1429 01:33:24,160 --> 01:33:29,960 any individual neuron had a unit Gaussian distribution. 1430 01:33:30,260 --> 01:33:34,260 So it was a zero mean and one standard deviation 1431 01:33:34,460 --> 01:33:34,960 output. 1432 01:33:35,860 --> 01:33:39,260 So what I did here is I'm copy-pasting the BatchNorm1d that we developed 1433 01:33:39,260 --> 01:33:43,960 in our makemore series, and see, here we can initialize, for example, this 1434 01:33:43,960 --> 01:33:48,660 module, and we can have a batch of 32 100-dimensional vectors feeding through 1435 01:33:48,660 --> 01:33:49,460 the batch norm layer. 1436 01:33:50,160 --> 01:33:55,760 So what this does is it guarantees that when we look at just the 0th column, 1437 01:33:56,360 --> 01:33:59,260 it's a zero mean, one standard deviation. 1438 01:33:59,760 --> 01:34:02,960 So it's normalizing every single column of this input. 1439 01:34:03,860 --> 01:34:08,060 Now the rows are not going to be normalized by default, because we're just 1440 01:34:08,060 --> 01:34:09,060 normalizing columns. 1441 01:34:09,660 --> 01:34:11,060 So let's now implement layer norm. 1442 01:34:11,960 --> 01:34:13,060 It's very complicated. 1443 01:34:13,160 --> 01:34:14,860 Look, we come here. 1444 01:34:15,060 --> 01:34:19,560 We change this from 0 to 1, so we don't normalize the columns; 1445 01:34:19,560 --> 01:34:19,960 we normalize
1446 01:34:20,160 --> 01:34:23,660 the rows. And now we've implemented layer norm. 1447 01:34:25,060 --> 01:34:28,260 So now the columns are not going to be normalized, 1448 01:34:29,960 --> 01:34:33,760 but the rows are going to be normalized: for every individual example, 1449 01:34:33,760 --> 01:34:38,460 its 100-dimensional vector is normalized in this way. And because our 1450 01:34:38,460 --> 01:34:43,360 computation now does not span across examples, we can delete all of this 1451 01:34:43,360 --> 01:34:48,960 buffers stuff, because we can always apply this operation and don't need to 1452 01:34:48,960 --> 01:34:49,960 maintain any running buffers. 1453 01:34:50,660 --> 01:34:52,360 So we don't need the buffers. 1454 01:34:53,360 --> 01:34:57,460 There's no distinction between training and test time. 1455 01:34:59,460 --> 01:35:01,460 And we don't need these running buffers. 1456 01:35:01,760 --> 01:35:03,360 We do keep gamma and beta. 1457 01:35:03,660 --> 01:35:04,860 We don't need the momentum. 1458 01:35:04,860 --> 01:35:06,560 We don't care if it's training or not. 1459 01:35:07,360 --> 01:35:13,160 And this is now a layer norm, and it normalizes the rows instead of the 1460 01:35:13,160 --> 01:35:18,460 columns, and this here is identical to basically this here. 1461 01:35:19,460 --> 01:35:19,960 So let's 1462 01:35:19,960 --> 01:35:23,360 now implement layer norm in our transformer. Before I incorporate the 1463 01:35:23,360 --> 01:35:23,760 layer norm, 1464 01:35:23,760 --> 01:35:27,660 I just wanted to note that, as I said, very few details about the transformer 1465 01:35:27,660 --> 01:35:30,460 have changed in the last five years, but this is actually something that 1466 01:35:30,460 --> 01:35:32,060 slightly departs from the original paper. 1467 01:35:32,560 --> 01:35:36,360 You see that the Add & Norm is applied after the transformation.
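The dim-0 to dim-1 change can be illustrated directly (a NumPy sketch of the idea, not the lecture's actual BatchNorm1d class):

```python
import numpy as np

x = np.random.randn(32, 100)  # a batch of 32 hundred-dimensional vectors
eps = 1e-5

# batch norm: normalize over dim 0, so every COLUMN (each neuron,
# across the batch) becomes zero mean, unit std
bn = (x - x.mean(0, keepdims=True)) / np.sqrt(x.var(0, keepdims=True) + eps)

# layer norm: normalize over dim 1, so every ROW (each example's
# feature vector) becomes zero mean, unit std; no running buffers needed
ln = (x - x.mean(1, keepdims=True)) / np.sqrt(x.var(1, keepdims=True) + eps)

print(round(float(bn[:, 0].std()), 3), round(float(ln[0, :].std()), 3))  # ~1.0 each
```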
1468 01:35:37,360 --> 01:35:43,560 But it is now a bit more common to apply the layer norm before 1469 01:35:43,560 --> 01:35:44,360 the transformation. 1470 01:35:44,360 --> 01:35:46,160 So there's a reshuffling of the layer norms. 1471 01:35:46,960 --> 01:35:49,560 This is called the pre-norm formulation, and that's the one that we're 1472 01:35:49,560 --> 01:35:50,660 going to implement as well. 1473 01:35:50,660 --> 01:35:52,460 So, a slight deviation from the original paper. 1474 01:35:53,260 --> 01:35:55,360 Basically, we need two layer norms. Layer norm 1475 01:35:55,360 --> 01:36:01,360 one is an nn.LayerNorm, and we tell it the embedding 1476 01:36:01,360 --> 01:36:04,360 dimension, and we need the second layer norm. 1477 01:36:05,260 --> 01:36:08,660 And then here the layer norms are applied immediately on x. 1478 01:36:09,360 --> 01:36:13,760 So self.ln1 is applied on x, and self.ln2 1479 01:36:13,760 --> 01:36:19,460 is applied on x, before it goes into self-attention and feed forward, and 1480 01:36:19,560 --> 01:36:22,460 the size of the layer norm here is n_embd, so 32. 1481 01:36:23,060 --> 01:36:27,960 So when the layer norm is normalizing our features, the normalization 1482 01:36:27,960 --> 01:36:33,860 here happens over 32 numbers: the mean and the variance are taken over 32 numbers. 1483 01:36:34,160 --> 01:36:37,760 So the batch and the time act as batch dimensions, both of them. 1484 01:36:38,360 --> 01:36:42,860 So this is kind of like a per-token transformation that just normalizes 1485 01:36:42,860 --> 01:36:48,560 the features and makes them unit Gaussian (zero mean, unit standard deviation) at initialization.
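The pre-norm arrangement can be sketched like this (assuming PyTorch; the self-attention and feed-forward modules are replaced by plain linear stand-ins, since the real ones live elsewhere in the lecture's code):

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    """Pre-norm Transformer block: layer norm is applied BEFORE each sub-layer."""
    def __init__(self, n_embd):
        super().__init__()
        self.sa = nn.Linear(n_embd, n_embd)    # stand-in for multi-head attention
        self.ffwd = nn.Linear(n_embd, n_embd)  # stand-in for the feed-forward
        self.ln1 = nn.LayerNorm(n_embd)        # normalizes the last (feature) dim
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))    # communicate
        x = x + self.ffwd(self.ln2(x))  # compute
        return x

# B and T both act as batch dimensions; only the 32 features are normalized
x = torch.randn(4, 8, 32)
print(Block(32)(x).shape)  # torch.Size([4, 8, 32])
```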
1486 01:36:48,560 --> 01:36:53,360 But of course, because these layer norms have these gamma and beta 1487 01:36:53,360 --> 01:36:59,360 trainable parameters inside, the layer norm will eventually create outputs that might 1488 01:36:59,360 --> 01:37:03,860 not be unit Gaussian, but the optimization will determine that. So for 1489 01:37:03,860 --> 01:37:07,660 now, this is incorporating the layer norms, and let's train them 1490 01:37:07,660 --> 01:37:07,960 up. 1491 01:37:08,560 --> 01:37:12,660 Okay, so I let it run, and we see that we get down to 2.06, which is better 1492 01:37:12,660 --> 01:37:14,060 than the previous 2.08. 1493 01:37:14,360 --> 01:37:17,760 So a slight improvement by adding the layer norms, and I'd expect that they'd 1494 01:37:17,760 --> 01:37:18,260 help 1495 01:37:18,260 --> 01:37:20,460 even more if we had a bigger and deeper network. 1496 01:37:21,060 --> 01:37:21,560 One more thing 1497 01:37:21,560 --> 01:37:24,460 I forgot to add is that there should be a layer norm here 1498 01:37:24,460 --> 01:37:29,360 as well, typically, at the end of the transformer and right before the final 1499 01:37:29,660 --> 01:37:32,260 linear layer that decodes into the vocabulary. 1500 01:37:32,760 --> 01:37:34,060 So I added that as well. 1501 01:37:34,760 --> 01:37:37,960 So at this stage we actually have a pretty complete transformer according to 1502 01:37:37,960 --> 01:37:40,860 the original paper, and it's a decoder-only transformer. 1503 01:37:40,960 --> 01:37:45,260 I'll talk about that in a second. But at this stage the major pieces 1504 01:37:45,260 --> 01:37:48,160 are in place, so we can try to scale this up and see how well we can push 1505 01:37:48,160 --> 01:37:50,760 this number. Now, in order to scale up the model, 1506 01:37:50,760 --> 01:37:54,360 I had to perform some cosmetic changes here to make it nicer.
1507 01:37:54,660 --> 01:37:57,660 So I introduced this variable called n_layer, which just specifies how 1508 01:37:57,660 --> 01:37:59,860 many layers of blocks 1509 01:37:59,860 --> 01:38:03,260 we're going to have. I create a bunch of blocks, and we have a new variable, 1510 01:38:03,260 --> 01:38:04,460 number of heads, as well. 1511 01:38:05,460 --> 01:38:06,960 I pulled out the layer norm here. 1512 01:38:07,160 --> 01:38:08,560 And so this is identical. 1513 01:38:09,160 --> 01:38:12,660 Now one thing that I did briefly change is I added dropout. 1514 01:38:13,160 --> 01:38:17,760 So dropout is something that you can add right before the residual connection, 1515 01:38:17,760 --> 01:38:21,060 right before the connection back into the residual pathway. 1516 01:38:21,660 --> 01:38:24,160 So we can drop out that as the last layer here. 1517 01:38:24,760 --> 01:38:28,760 We can drop out here, at the end of the multi-headed attention, as well. 1518 01:38:29,460 --> 01:38:35,360 And we can also drop out here, when we calculate basically the affinities: 1519 01:38:35,360 --> 01:38:39,260 after the softmax, we can drop out some of those, so we can randomly 1520 01:38:39,260 --> 01:38:41,260 prevent some of the nodes from communicating. 1521 01:38:42,060 --> 01:38:46,860 And so dropout comes from this paper from 2014 or so. 1522 01:38:46,860 --> 01:38:53,760 And basically it takes your neural net, and it randomly, every forward-backward 1523 01:38:53,760 --> 01:39:00,860 pass, shuts off some subset of neurons, so it randomly drops them to zero, and 1524 01:39:00,860 --> 01:39:05,460 trains without them. And what this does effectively is, because the mask of 1525 01:39:05,460 --> 01:39:08,760 what's being dropped out changes every single forward-backward pass, it 1526 01:39:08,760 --> 01:39:14,060 ends up kind of training an ensemble of sub-networks, and then at test time 1527 01:39:14,060 --> 01:39:16,760 everything is fully enabled, and kind of all those sub-networks
1528 01:39:16,760 --> 01:39:18,560 are merged into a single ensemble, 1529 01:39:18,560 --> 01:39:20,260 if you want to think about it that way. 1530 01:39:20,960 --> 01:39:23,960 So I would read the paper to get the full detail. For now, 1531 01:39:23,960 --> 01:39:27,360 we're just going to stay at the level of: this is a regularization technique, 1532 01:39:27,660 --> 01:39:30,860 and I added it because I'm about to scale up the model quite a bit and I 1533 01:39:30,860 --> 01:39:32,060 was concerned about overfitting. 1534 01:39:33,360 --> 01:39:37,460 So now when we scroll up to the top, we'll see that I changed a number of 1535 01:39:37,460 --> 01:39:39,360 hyperparameters here about our neural net. 1536 01:39:39,760 --> 01:39:42,460 So I made the batch size be much larger, now 64. 1537 01:39:43,160 --> 01:39:45,260 I changed the block size to be 256. 1538 01:39:45,460 --> 01:39:46,660 So previously it was just eight 1539 01:39:46,960 --> 01:39:48,160 characters of context; 1540 01:39:48,260 --> 01:39:52,760 now it is 256 characters of context to predict the 257th. 1541 01:39:54,360 --> 01:39:57,160 I brought down the learning rate a little bit, because the neural net is 1542 01:39:57,160 --> 01:39:57,960 now much bigger, 1543 01:39:57,960 --> 01:39:59,360 so I brought down the learning rate. 1544 01:40:00,260 --> 01:40:03,460 The embedding dimension is now 384, and there are six heads. 1545 01:40:03,960 --> 01:40:10,260 So 384 divided by 6 means that every head is 64-dimensional, as a standard. 1546 01:40:11,060 --> 01:40:14,860 And then there are going to be six layers of that, and the dropout is 1547 01:40:14,860 --> 01:40:16,660 0.2, so every forward-backward pass, 1548 01:40:16,960 --> 01:40:22,760 20% of all these intermediate calculations are disabled and dropped to 1549 01:40:22,760 --> 01:40:26,160 zero. And then I already trained this and I ran it. 1550 01:40:26,160 --> 01:40:28,960 So, drumroll: how does it perform?
1551 01:40:29,760 --> 01:40:30,860 So let me just scroll up here. 1552 01:40:32,760 --> 01:40:37,260 We get a validation loss of 1.48, which is actually quite a bit of an 1553 01:40:37,260 --> 01:40:40,160 improvement on what we had before, which I think was 2.07. 1554 01:40:40,660 --> 01:40:44,160 So we went from 2.07 all the way down to 1.48, just by scaling up this 1555 01:40:44,160 --> 01:40:45,860 neural net with the code that we have. 1556 01:40:46,360 --> 01:40:48,060 And this of course ran for a lot longer. 1557 01:40:48,060 --> 01:40:52,860 This maybe trained for, I want to say, about 15 minutes on my A100 GPU. 1558 01:40:52,860 --> 01:40:55,960 So that's a pretty good GPU, and if you don't have a GPU, you're not going to 1559 01:40:55,960 --> 01:40:58,460 be able to reproduce this. 1560 01:40:58,460 --> 01:41:02,360 I would not run this on a CPU or a MacBook or something like 1561 01:41:02,360 --> 01:41:02,860 that. 1562 01:41:02,860 --> 01:41:06,660 You'll have to bring down the number of layers and the embedding dimension 1563 01:41:06,660 --> 01:41:07,260 and so on. 1564 01:41:08,460 --> 01:41:13,360 But in about 15 minutes we can get this kind of a result, and I'm printing 1565 01:41:14,060 --> 01:41:15,060 some of the Shakespeare here. 1566 01:41:15,060 --> 01:41:17,860 But what I did also is I printed 10,000 characters, 1567 01:41:17,860 --> 01:41:19,760 so a lot more, and I wrote them to a file. 1568 01:41:20,560 --> 01:41:21,960 And so here we see some of the outputs. 1569 01:41:24,260 --> 01:41:27,860 So it's a lot more recognizable as the input text file. 1570 01:41:28,260 --> 01:41:30,860 So the input text file, just for reference, looked like this: 1571 01:41:31,760 --> 01:41:37,260 there's always someone speaking in this manner, and our predictions 1572 01:41:37,260 --> 01:41:41,560 now take on that form, except of course they're nonsensical when you 1573 01:41:41,560 --> 01:41:42,260 actually read them.
1574 01:41:42,860 --> 01:41:43,360 So: 1575 01:41:43,360 --> 01:41:46,960 "it is every crimp to be a house. 1576 01:41:46,960 --> 01:41:50,560 Oh, those probation we give heed. 1577 01:41:52,560 --> 01:41:53,160 You know. 1578 01:41:55,960 --> 01:41:58,060 Oh ho, sent me you mighty Lord." 1579 01:42:00,560 --> 01:42:02,160 Anyway, so you can read through this. 1580 01:42:02,160 --> 01:42:06,460 It's nonsensical, of course, but this is just a transformer trained on the 1581 01:42:06,460 --> 01:42:10,060 character level for 1 million characters that come from Shakespeare. 1582 01:42:10,060 --> 01:42:13,160 So it sort of blabbers on in a Shakespeare-like 1583 01:42:13,360 --> 01:42:16,160 manner, but it doesn't of course make sense at this scale. 1584 01:42:17,060 --> 01:42:20,560 But I think it's still a pretty good demonstration of what's possible. 1585 01:42:21,760 --> 01:42:23,160 So now 1586 01:42:24,560 --> 01:42:28,360 I think that kind of concludes the programming section of this video. 1587 01:42:28,560 --> 01:42:32,860 We basically kind of did a pretty good job of implementing this 1588 01:42:32,860 --> 01:42:37,160 transformer, but the picture doesn't exactly match up to what we've done. 1589 01:42:37,360 --> 01:42:39,560 So what's going on with all these additional parts here? 1590 01:42:40,060 --> 01:42:43,260 So let me finish explaining this architecture and why it looks so funky. 1591 01:42:44,060 --> 01:42:47,960 Basically, what's happening here is that what we implemented is a decoder-only 1592 01:42:47,960 --> 01:42:48,860 transformer. 1593 01:42:49,360 --> 01:42:51,160 So there's no component here: 1594 01:42:51,160 --> 01:42:55,260 this part is called the encoder, and there's no cross-attention block here. 1595 01:42:55,760 --> 01:42:58,960 Our block only has the self-attention and the feed forward. 1596 01:42:58,960 --> 01:43:02,560 So it is missing this third, in-between piece here.
1597 01:43:02,960 --> 01:43:04,260 This piece does cross-attention. 1598 01:43:04,560 --> 01:43:06,660 So we don't have it, and we don't have the encoder. 1599 01:43:06,760 --> 01:43:11,860 We just have the decoder, and the reason we have a decoder only is because we are 1600 01:43:11,860 --> 01:43:13,060 just generating text. 1601 01:43:13,060 --> 01:43:16,660 It's unconditioned on anything; we're just blabbering on according 1602 01:43:16,660 --> 01:43:17,560 to a given data set. 1603 01:43:18,360 --> 01:43:22,660 What makes it a decoder is that we are using the triangular mask in our 1604 01:43:22,860 --> 01:43:23,460 transformer. 1605 01:43:23,760 --> 01:43:27,660 So it has this autoregressive property where we can just go and sample from it. 1606 01:43:28,560 --> 01:43:32,460 So the fact that it's using the triangular mask to mask out the 1607 01:43:32,460 --> 01:43:35,960 attention makes it a decoder, and it can be used for language modeling. 1608 01:43:36,660 --> 01:43:41,060 Now, the reason that the original paper had an encoder-decoder architecture is 1609 01:43:41,060 --> 01:43:42,760 because it is a machine translation paper. 1610 01:43:43,160 --> 01:43:46,260 So it is concerned with a different setting in particular: 1611 01:43:46,860 --> 01:43:53,260 it expects some tokens that encode, say for example, French, and then it is expected to 1612 01:43:53,260 --> 01:43:55,260 decode the translation in English. 1613 01:43:55,860 --> 01:43:58,960 So typically, these here are special tokens. 1614 01:43:59,360 --> 01:44:04,360 So you are expected to read in this and condition on it, and then you start off the 1615 01:44:04,360 --> 01:44:06,460 generation with a special token called start. 1616 01:44:06,760 --> 01:44:12,860 So this is a special new token that you introduce and always place in the beginning, and then 1617 01:44:12,860 --> 01:44:12,960 the
1618 01:44:13,160 --> 01:44:18,760 network is expected to output "neural networks are awesome", and then a special end token to 1619 01:44:18,760 --> 01:44:19,660 finish the generation. 1620 01:44:21,060 --> 01:44:26,260 So this part here will be decoded exactly as we've done it: "neural networks are 1621 01:44:26,260 --> 01:44:32,360 awesome" will be identical to what we did. But unlike what we did, they want to condition the 1622 01:44:32,360 --> 01:44:35,560 generation on some additional information. 1623 01:44:35,660 --> 01:44:39,160 And in that case, this additional information is the French sentence that they should be 1624 01:44:39,160 --> 01:44:39,660 translating. 1625 01:44:40,660 --> 01:44:42,860 So what they do now is they 1626 01:44:42,960 --> 01:44:47,260 bring in the encoder. Now, the encoder reads this part here. 1627 01:44:47,760 --> 01:44:53,260 So we're going to take the part in French, and we're going to create tokens from it, exactly as 1628 01:44:53,260 --> 01:44:58,060 we've seen in our video, and we're going to put a transformer on it. But there's going to be no 1629 01:44:58,060 --> 01:44:59,460 triangular mask, 1630 01:44:59,560 --> 01:45:03,360 and so all the tokens are allowed to talk to each other as much as they want, and they're just 1631 01:45:03,360 --> 01:45:11,160 encoding whatever the content of this French sentence is. Once they've encoded it, they 1632 01:45:11,160 --> 01:45:12,760 basically come out at the top here. 1633 01:45:13,360 --> 01:45:18,260 And then what happens here is, in our decoder, which does the language modeling, 1634 01:45:18,660 --> 01:45:24,560 there's an additional connection here to the outputs of the encoder, and that is brought in 1635 01:45:24,560 --> 01:45:26,460 through a cross-attention.
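A sketch of what that cross-attention computes (assuming PyTorch; a simplified single head with made-up sizes, where queries come from the decoder's x while keys and values come from the encoder output):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

n_embd, head_size = 32, 16
query = nn.Linear(n_embd, head_size, bias=False)
key = nn.Linear(n_embd, head_size, bias=False)
value = nn.Linear(n_embd, head_size, bias=False)

x = torch.randn(1, 8, n_embd)         # decoder tokens (the English generated so far)
enc_out = torch.randn(1, 12, n_embd)  # encoder output (the encoded French sentence)

q = query(x)        # queries are still generated from x
k = key(enc_out)    # but the keys...
v = value(enc_out)  # ...and values come from the encoder, fed in on the side

wei = q @ k.transpose(-2, -1) * head_size**-0.5  # (1, 8, 12) affinities
wei = F.softmax(wei, dim=-1)  # no triangular mask: attend to the full encoding
out = wei @ v
print(out.shape)  # torch.Size([1, 8, 16])
```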
1636 01:45:27,060 --> 01:45:31,660 So the queries are still generated from x, but now the keys and the values are coming from the 1637 01:45:31,660 --> 01:45:37,260 side: the keys and the values are coming from the top, generated by the nodes that came out 1638 01:45:37,260 --> 01:45:42,760 of the encoder, and those keys and values 1639 01:45:43,360 --> 01:45:48,660 feed in on the side into every single block of the decoder. And so that's why there's an additional 1640 01:45:48,660 --> 01:45:54,860 cross-attention. And really what it's doing is it's conditioning the decoding not just on the past of 1641 01:45:54,860 --> 01:46:05,060 this current decoding, but also on having seen the fully encoded French prompt. And so 1642 01:46:05,060 --> 01:46:08,960 it's an encoder-decoder model, which is why we have those two transformers and the additional block 1643 01:46:09,260 --> 01:46:09,860 and so on. 1644 01:46:10,160 --> 01:46:12,660 So we did not do this, because we have 1645 01:46:12,660 --> 01:46:13,460 nothing to encode. 1646 01:46:13,460 --> 01:46:14,460 There's no conditioning. 1647 01:46:14,460 --> 01:46:18,960 We just have a text file and we just want to imitate it, and that's why we are using a decoder-only 1648 01:46:18,960 --> 01:46:21,660 transformer, exactly as done in GPT. 1649 01:46:22,860 --> 01:46:23,060 Okay. 1650 01:46:23,060 --> 01:46:28,960 So now I wanted to do a very brief walkthrough of nanoGPT, which you can find on my GitHub. And nanoGPT 1651 01:46:28,960 --> 01:46:31,060 is basically two files of interest: 1652 01:46:31,260 --> 01:46:36,860 there's train.py and model.py. train.py is all the boilerplate code for training the network. 1653 01:46:37,060 --> 01:46:41,060 It is basically all the stuff that we had here: it is the training loop. 1654 01:46:41,960 --> 01:46:42,460 It's just that
1655 01:46:42,460 --> 01:46:46,660 it's a lot more complicated, because we're saving and loading checkpoints and pre-trained weights, 1656 01:46:46,660 --> 01:46:51,660 and we are decaying the learning rate, and compiling the model, and using distributed training across 1657 01:46:51,660 --> 01:46:53,260 multiple nodes or GPUs. 1658 01:46:53,760 --> 01:46:57,060 So train.py gets a little bit more hairy and complicated. 1659 01:46:57,460 --> 01:47:03,660 There are more options, etc. But model.py should look very, very similar to what we've done 1660 01:47:03,660 --> 01:47:04,060 here. 1661 01:47:04,260 --> 01:47:06,460 In fact, the model is almost identical. 1662 01:47:07,260 --> 01:47:12,360 So first, here we have the causal self-attention block, and all of this should look very, 1663 01:47:12,360 --> 01:47:12,460 very similar and 1664 01:47:12,460 --> 01:47:17,980 recognizable to you: we're producing queries, keys, values; we're doing dot products; we're masking; 1665 01:47:17,980 --> 01:47:25,260 applying softmax; optionally dropping out; and here we are pooling the values. What is different here 1666 01:47:25,260 --> 01:47:32,100 is that in our code I have separated out the multi-headed attention into just a single 1667 01:47:32,100 --> 01:47:38,100 individual head, and then here I have multiple heads and I explicitly concatenate them, whereas 1668 01:47:38,100 --> 01:47:43,180 here all of it is implemented in a batched manner inside a single causal self-attention, 1669 01:47:43,180 --> 01:47:48,120 and so we don't just have a B and a T and a C dimension, we also end up with a fourth dimension, 1670 01:47:48,120 --> 01:47:53,320 which is the heads, and so it just gets a lot more sort of hairy, because we have four-dimensional 1671 01:47:53,320 --> 01:47:59,360 tensors now. But it is equivalent mathematically, so the exact same thing is 1672 01:47:59,360 --> 01:48:03,740 happening as what we have; it's just a bit more efficient, because all the heads are now 1673
01:48:03,740 --> 01:48:08,000 treated as a batch dimension as well. Then we have the multi-layer perceptron. 1674 01:48:08,000 --> 01:48:08,080 1675 01:48:08,100 --> 01:48:13,560 It's using the GELU non-linearity, which is defined here, instead of ReLU, and this 1676 01:48:13,560 --> 01:48:16,520 is done just because OpenAI used it, and I want to be able to load their checkpoints. 1677 01:48:16,520 --> 01:48:22,560 The blocks of the transformer are identical: the communicate and the compute phase, as we saw. 1678 01:48:22,560 --> 01:48:27,180 And then the GPT will be identical: we have the position encodings, token encodings, 1679 01:48:27,180 --> 01:48:33,600 the blocks, the layer norm at the end, the final linear layer. And this should all look very 1680 01:48:33,600 --> 01:48:38,000 recognizable. And there's a bit more here, because I'm loading checkpoints and stuff like that, 1681 01:48:38,000 --> 01:48:42,160 and I'm separating out the parameters into those that should be weight-decayed and those that 1682 01:48:42,160 --> 01:48:48,240 shouldn't. But the generate function should also be very, very similar. So a few details are different, 1683 01:48:48,240 --> 01:48:53,160 but you should definitely be able to look at this file and be able to understand a lot of the pieces 1684 01:48:53,160 --> 01:48:58,440 now. So let's now bring things back to ChatGPT. What would it look like if we wanted to train 1685 01:48:58,440 --> 01:49:03,680 ChatGPT ourselves, and how does it relate to what we learned today? Well, to train ChatGPT there 1686 01:49:03,680 --> 01:49:07,980 are roughly two stages: first is the pre-training stage and then the fine-tuning stage. 1687 01:49:08,000 --> 01:49:13,920 In the pre-training stage, we are training on a large chunk of internet and just 1688 01:49:13,920 --> 01:49:20,400 trying to get a first decoder-only transformer to babble text. So it's
very, very similar to what 1689 01:49:20,400 --> 01:49:27,920 we've done ourselves, except we've done like a tiny little baby pre-training step. And so in our case, 1690 01:49:29,360 --> 01:49:33,440 this is how you print the number of parameters. I printed it, and it's about 10 million. 1691 01:49:33,440 --> 01:49:37,120 So this transformer that I created here, to create a little Shakespeare 1692 01:49:38,000 --> 01:49:45,140 transformer, was about 10 million parameters. Our data set is roughly 1 million characters, so 1693 01:49:45,140 --> 01:49:49,900 roughly 1 million tokens. But you have to remember that OpenAI uses a different vocabulary. They're not 1694 01:49:49,900 --> 01:49:55,480 on the character level; they use these subword chunks of words, and so they have a vocabulary 1695 01:49:55,480 --> 01:50:02,640 of roughly 50,000 elements, and so their sequences are a bit more condensed. So our data set, the 1696 01:50:02,640 --> 01:50:07,840 Shakespeare data set, would be probably around 300,000 tokens in the OpenAI vocabulary, roughly. 1697 01:50:08,480 --> 01:50:15,520 So we trained about a 10 million parameter model on roughly 300,000 tokens. Now when you go to the GPT-3 paper 1698 01:50:16,960 --> 01:50:22,400 and you look at the transformers that they trained, they trained a number of transformers 1699 01:50:22,400 --> 01:50:28,640 of different sizes, but the biggest transformer here has 175 billion parameters, so ours is, 1700 01:50:28,640 --> 01:50:35,040 again, 10 million. They used this number of layers in the transformer; this is the n_embd; this is 1701 01:50:35,040 --> 01:50:37,620 the number of heads; and this is the head size. 1702 01:50:38,000 --> 01:50:45,440 And then this is the batch size, so ours was 64, and the learning rate is similar. 1703 01:50:45,440 --> 01:50:50,600 Now when they trained this transformer, they trained on 300 billion tokens. So 1704 01:50:50,600 --> 01:50:56,320 again, remember, ours is about 300,000, so this is about a million-fold increase, 1705 01:50:56,320 -->
01:50:59,580 and this number would not be even that large by today's standards; you'd be 1706 01:50:59,580 --> 01:51:05,660 going up to 1 trillion and above. So they are training a significantly larger 1707 01:51:05,660 --> 01:51:12,020 model on a good chunk of the internet, and that is the pre-training stage. But 1708 01:51:12,020 --> 01:51:15,380 otherwise, these hyperparameters should be fairly recognizable to you, and the 1709 01:51:15,380 --> 01:51:18,420 architecture is actually nearly identical to what we implemented 1710 01:51:18,420 --> 01:51:22,460 ourselves. But of course, it's a massive infrastructure challenge to train this. 1711 01:51:22,460 --> 01:51:27,140 You're talking about typically thousands of GPUs having to, you know, talk to each 1712 01:51:27,140 --> 01:51:31,880 other to train models of this size. So that's just the pre-training stage. Now, 1713 01:51:31,880 --> 01:51:35,540 after you complete the pre-training stage, you don't get something that 1714 01:51:35,540 --> 01:51:35,640 gives 1715 01:51:35,640 --> 01:51:40,240 a response to your questions with answers, and it's not helpful, etc. You 1716 01:51:40,240 --> 01:51:45,640 get a document completer, right? So it babbles, but it doesn't babble Shakespeare; 1717 01:51:45,640 --> 01:51:49,560 it babbles internet. It will create arbitrary news articles and documents, 1718 01:51:49,560 --> 01:51:52,080 and it will try to complete documents, because that's what it's trained for. 1719 01:51:52,080 --> 01:51:55,740 It's trying to complete the sequence. So when you give it a question, it would 1720 01:51:55,740 --> 01:51:59,520 potentially just give you more questions. It would follow with more 1721 01:51:59,520 --> 01:52:04,240 questions. It will do whatever it looks like some document would do 1722 01:52:04,240 --> 01:52:05,520 in the training data 1723 01:52:05,520 --> 01:52:08,840 on the internet. And so, who knows, you're getting kind of 
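The "document completer" behavior described above is just autoregressive sampling: the model only ever extends whatever context it is given. A minimal sketch, where `next_token_probs` is a hypothetical stand-in for a trained model's forward pass (a real GPT would run the transformer over the context here):

```python
import random

def next_token_probs(context):
    # Hypothetical stand-in for a trained model: a toy distribution over a
    # 4-token vocabulary that mostly favors repeating the last token seen.
    vocab_size = 4
    probs = [0.1] * vocab_size
    probs[context[-1] % vocab_size] = 0.7
    return probs

def complete(context, n_new_tokens, seed=0):
    rng = random.Random(seed)
    out = list(context)
    for _ in range(n_new_tokens):
        probs = next_token_probs(out)
        # sample the next token and append it: there is no notion of
        # "answering" anywhere, only extending the sequence
        nxt = rng.choices(range(len(probs)), weights=probs)[0]
        out.append(nxt)
    return out

print(complete([2, 3], n_new_tokens=5))
```

A question fed in as context goes through exactly the same loop as any other prefix, which is why a purely pre-trained model may well respond to a question with more questions.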
like undefined behavior. 1724 01:52:08,840 --> 01:52:13,320 It might basically answer your questions with other questions, it might 1725 01:52:13,320 --> 01:52:16,980 ignore your question, it might just try to complete some news article. It's 1726 01:52:16,980 --> 01:52:22,020 totally unaligned, as we say. So the second, fine-tuning stage is to actually 1727 01:52:22,020 --> 01:52:28,280 align it to be an assistant, and this is the second stage. And so this ChatGPT 1728 01:52:28,280 --> 01:52:32,480 blog post from OpenAI talks a little bit about how this stage is achieved. There are 1729 01:52:32,480 --> 01:52:34,920 basically 1730 01:52:35,520 --> 01:52:39,540 roughly three steps to this stage. So what they do here is they start to 1731 01:52:39,540 --> 01:52:43,260 collect training data that looks specifically like what an assistant 1732 01:52:43,260 --> 01:52:46,560 would do. So there are documents that have the format where the question is on 1733 01:52:46,560 --> 01:52:50,700 top and then an answer is below, and they have a large number of these, but 1734 01:52:50,700 --> 01:52:53,760 probably not on the order of the internet. This is probably on the order 1735 01:52:53,760 --> 01:53:00,600 of maybe thousands of examples. And so they then fine-tune the model to 1736 01:53:00,600 --> 01:53:05,220 basically only focus on documents that look like that. And so you're starting to 1737 01:53:05,220 --> 01:53:08,720 slowly align it, so it's going to expect a question at the top and it's going to 1738 01:53:08,720 --> 01:53:13,980 expect to complete the answer. And these very, very large models are very sample 1739 01:53:13,980 --> 01:53:18,120 efficient during their fine-tuning, so this actually somehow works. But that's 1740 01:53:18,120 --> 01:53:21,780 just step one; that's just fine-tuning. So then they actually have more steps, where, 1741 01:53:21,780 --> 01:53:26,340 okay, the second step is you let the model respond, and then different raters 1742 01:53:26,340 --> 
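The supervised fine-tuning step described above amounts to continuing next-token-prediction training on documents with a fixed question-on-top, answer-below shape. The template below is hypothetical (OpenAI's actual formatting markers are not public); the point is that each example is rendered into one flat document, usually with the loss restricted to the answer portion:

```python
def format_sft_example(question, answer):
    # Hypothetical template; the real markers OpenAI uses are not public.
    prompt = f"Question: {question}\nAnswer: "
    full = prompt + answer
    # Loss mask: 0 over the prompt, 1 over the answer, so gradients only
    # teach the model to produce answers, not to regenerate questions.
    # (Shown per character for simplicity; real systems mask per token.)
    mask = [0] * len(prompt) + [1] * len(answer)
    return full, mask

doc, mask = format_sft_example("What is a transformer?",
                               "A neural network architecture.")
print(doc)
```

Training then proceeds exactly as in pre-training, just on this much smaller, assistant-shaped corpus, which is what slowly biases the completer toward answering.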
01:53:30,540 look at the different responses and rank them by their preferences as to which one 1743 01:53:30,540 --> 01:53:35,000 is better than the other. They use that to train a reward model, so they can predict, 1744 01:53:35,000 --> 01:53:42,040 basically using a different network, how desirable any candidate response would be. 1745 01:53:42,720 --> 01:53:47,900 And then once they have a reward model, they run PPO, which is a form of policy gradient 1746 01:53:47,900 --> 01:53:55,640 reinforcement learning optimizer, to fine-tune the sampling policy so that the answers that 1747 01:53:55,640 --> 01:54:02,700 ChatGPT now generates are expected to score a high reward according to the reward model. 1748 01:54:02,700 --> 01:54:08,160 And so basically there's a whole aligning stage here, or fine-tuning stage. It's got multiple 1749 01:54:08,160 --> 01:54:13,900 steps in between there as well, and it takes the model from being a document completer to a 1750 01:54:13,900 --> 01:54:19,520 question answerer, and that's like a whole separate stage. A lot of this data is not 1751 01:54:19,520 --> 01:54:25,120 available publicly. It is internal to OpenAI, and it's much harder to replicate this stage. 1752 01:54:26,320 --> 01:54:32,180 And so that's roughly what would give you a ChatGPT. And nanoGPT focuses on the pre-training stage. 1753 01:54:32,180 --> 01:54:32,580 Okay. 1754 01:54:32,700 --> 01:54:38,920 And that's everything that I wanted to cover today. So we trained, to summarize, a decoder-only 1755 01:54:38,920 --> 01:54:45,460 transformer following this famous paper, Attention Is All You Need, from 2017. And so that's 1756 01:54:45,460 --> 01:54:53,420 basically a GPT. We trained it on Tiny Shakespeare and got sensible results. All of the training 1757 01:54:53,420 --> 01:55:02,140 code is roughly 200 lines of code. I will be releasing this code base. So also, it comes with 1758 01:55:02,140 --> 01:55:02,420 all the... 
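The reward-model step above is commonly trained with a pairwise ranking loss: the reward assigned to the rater-preferred response should beat the reward of the rejected one. A minimal sketch of that loss, with scalar rewards standing in for the outputs of the reward network (which in reality is itself a large transformer):

```python
import math

def pairwise_ranking_loss(reward_chosen, reward_rejected):
    # -log sigmoid(r_chosen - r_rejected): small when the preferred response
    # already scores higher, large when the ranking is inverted.
    margin = reward_chosen - reward_rejected
    return -math.log(1.0 / (1.0 + math.exp(-margin)))

print(pairwise_ranking_loss(2.0, 0.5))  # ranking agrees with the rater: low loss
print(pairwise_ranking_loss(0.5, 2.0))  # ranking inverted: higher loss
```

Minimizing this pushes the reward network to reproduce the human rankings; PPO then tunes the sampling policy against that learned reward, as described above.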
1759 01:55:02,700 --> 01:55:08,780 Git log commits along the way, as we built it up. In addition to this code, I'm going to release 1760 01:55:08,780 --> 01:55:14,540 the notebook, of course, the Google Colab. And I hope that gave you a sense for how you can train 1761 01:55:16,140 --> 01:55:21,100 these models, like, say, GPT-3, that will be architecturally basically identical to what we 1762 01:55:21,100 --> 01:55:25,580 have, but they are somewhere between 10,000 and 1 million times bigger, depending on how you count. 1763 01:55:26,620 --> 01:55:32,540 And so that's all I have for now. We did not talk about any of the fine-tuning stages that would 1764 01:55:32,540 --> 01:55:36,460 typically go on top of this. So if you're interested in something that's not just language 1765 01:55:36,460 --> 01:55:41,580 modeling, but you actually want to, you know, say, perform tasks, or you want them to be aligned in 1766 01:55:41,580 --> 01:55:47,180 a specific way, or you want to detect sentiment or anything like that, basically, any time you 1767 01:55:47,180 --> 01:55:51,340 don't want something that's just a document completer, you have to complete further stages 1768 01:55:51,340 --> 01:55:56,380 of fine-tuning, which we did not cover. And that could be simple supervised fine-tuning, 1769 01:55:56,380 --> 01:56:00,780 or it can be something more fancy, like we see in ChatGPT, where we actually train a reward model 1770 01:56:00,780 --> 01:56:02,540 and then do rounds of PPO to 1771 01:56:02,540 --> 01:56:07,340 align it with respect to the reward model. So there's a lot more that can be done on top of it. 1772 01:56:07,340 --> 01:56:12,620 I think for now, we're starting to get to about the two-hour mark. So I'm going to kind of finish 1773 01:56:12,620 --> 01:56:19,660 here. I hope you enjoyed the lecture. And yeah, go forth and transform. See you later.