Writing an LLM from scratch, part 16 -- layer normalisation
I'm now working through chapter 4 of Sebastian Raschka's book "Build a Large Language Model (from Scratch)". It's the chapter where we put together the pieces that we've covered so far and wind up with a fully-trainable LLM, so there's a lot in there -- nothing quite as daunting as fully trainable self-attention was, but still stuff that needs thought.
Last time around I covered something that for me seemed absurdly simplistic for what it did -- the step that converts the context vectors our attention layers come up with into output logits. Grasping how a single matrix multiplication could possibly do that took a bit of thought, but it was all clear in the end.
This time, layer normalisation.
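To set the scene before the detail, here's a minimal sketch of the standard formulation as I understand it (not necessarily the book's exact listing): each token's embedding vector is normalised to zero mean and unit variance across its dimensions, then adjusted by a learnable scale and shift.

import torch
import torch.nn as nn

class LayerNorm(nn.Module):
    # Normalise each token's embedding to zero mean and unit variance,
    # then apply a learnable scale and shift.
    def __init__(self, emb_dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        return self.scale * (x - mean) / torch.sqrt(var + self.eps) + self.shift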
Writing an LLM from scratch, part 15 -- from context vectors to logits; or, can it really be that simple?!
Having worked through chapter 3 of Sebastian Raschka's book "Build a Large Language Model (from Scratch)", and spent some time digesting the concepts it introduced (most recently in my post on the complexity of self-attention at scale), it's time for chapter 4.
I've read it through in its entirety, and rather than working through it section-by-section in order, like I did with the last one, I think I'm going to jump around a bit, covering each new concept and how I wrapped my head around it separately. This chapter is a lot easier conceptually than the last, but there were still some "yes, but why do we do that?" moments.
The first of those is the answer to a question I'd been wondering about since at least part 6 in this series, and probably before. The attention mechanism is working through the (tokenised, embedded) input sequence and generating these rich context vectors, each of which expresses the "meaning" of its respective token in the context of the words that came before it. How do we go from there to predicting the next word in the sequence?
The answer, at least in the form of code showing how it happens, leaped out at me the first time I looked at the first listing in this chapter, for the initial DummyGPTModel that will be filled in as we go through it.
In its __init__, we create our token and position embedding mappings, and an object to handle dropout, then the multiple layers of attention heads (which are a bit more complex than the heads we've been working with so far, but more on that later), then some kind of normalisation layer, then:
self.out_head = nn.Linear(
    cfg["emb_dim"], cfg["vocab_size"], bias=False
)
...and then in the forward method, we run our tokens through all of that and then:
logits = self.out_head(x)
return logits
The x in that second bit of code is our context vectors from all of that hard work the attention layers did -- folded, spindled and mutilated a little by things like layer normalisation and being run through feed-forward networks with GELU (about both of which I'll go into in future posts) -- but ultimately just the context vectors.
And all we do to convert it into these logits, the output of the LLM, is run it through a single neural network layer. There's not even a bias, or an activation function -- it's basically just a single matrix multiplication!
My initial response was, essentially, WTF. Possibly WTFF. Gradient descent over neural networks is amazingly capable at learning things, but this seemed quite a heavy lift. Why would something so simple work? (And also, what are "logits"?)
Unpicking that took a bit of thought, and that's what I'll cover in this post.
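For concreteness, here's a minimal sketch of that output head in isolation, using GPT-2-ish sizes (768-dimensional embeddings, a 50,257-token vocabulary) purely for illustration:

import torch
import torch.nn as nn

emb_dim, vocab_size = 768, 50257          # illustrative GPT-2-style sizes

out_head = nn.Linear(emb_dim, vocab_size, bias=False)

# A batch of 2 sequences, 4 tokens each, already turned into context vectors
x = torch.randn(2, 4, emb_dim)
logits = out_head(x)                      # shape: (2, 4, 50257)

# With no bias and no activation, this really is a single matrix multiplication
print(torch.allclose(logits, x @ out_head.weight.T))  # True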
Writing an LLM from scratch, part 14 -- the complexity of self-attention at scale
Between reading chapters 3 and 4 of Sebastian Raschka's book "Build a Large Language Model (from Scratch)", I'm taking a break to solidify a few things that have been buzzing through my head as I've worked through it. Last time I posted about how I currently understand the "why" of the calculations we do for self-attention. This time, I want to start working through my budding intuition on how this algorithm behaves as we scale up context length. As always, this is to try to get my own thoughts clear in my head, with the potential benefit of helping out anyone else at the same stage as me -- if you want expert explanations, I'm afraid you'll need to look elsewhere :-)
The particular itch I want to scratch is around the incredible increases in context lengths over the last few years. When ChatGPT first came out in late 2022, it was pretty clear that it had a context length of a couple of thousand tokens; conversations longer than that became increasingly surreal. But now it's much better -- OpenAI's GPT-4.1 model has a context window of 1,047,576 tokens, and Google's Gemini 1.5 Pro is double that. Long conversations just work -- and the only downside is that you hit rate limits faster if they get too long.
It's pretty clear that there's been some impressive engineering going into achieving that. And while understanding those enhancements to the basic LLM recipe is one of the side quests I'm trying to avoid while reading this book, I think it's important to make sure I'm clear in my head what the problems are, even if I don't look into the solutions.
So: why is context length a problem?
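I'll save the full reasoning for the post itself, but the quick back-of-the-envelope version is that every token attends to every other token, so the attention-score matrix for a sequence of n tokens has n-by-n entries. A rough sketch:

import torch

n, d = 1_024, 768                # sequence length and embedding size, purely for illustration
queries = torch.randn(n, d)
keys = torch.randn(n, d)

scores = queries @ keys.T        # one n-by-n score matrix per attention head
print(scores.shape)              # torch.Size([1024, 1024])

# Doubling the context length quadruples the number of scores to compute and store
print((2 * n) ** 2 / n ** 2)     # 4.0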
Writing an LLM from scratch, part 13 -- the 'why' of attention, or: attention heads are dumb
Now that I've finished chapter 3 of Sebastian Raschka's book "Build a Large Language Model (from Scratch)" -- having worked my way through multi-head attention in the last post -- I thought it would be worth pausing to take stock before moving on to Chapter 4.
There are two things I want to cover, the "why" of self-attention, and some thoughts on context lengths. This post is on the "why" -- that is, why do the particular set of matrix multiplications described in the book do what we want them to do?
As always, this is something I'm doing primarily to get things clear in my own head -- with the possible extra benefit of it being of use to other people out there. I will, of course, run it past multiple LLMs to make sure I'm not posting total nonsense, but caveat lector!
Let's get into it. As I wrote in part 8 of this series:
I think it's also worth noting that [what's in the book is] very much a "mechanistic" explanation -- it says how we do these calculations without saying why. I think that the "why" is actually out of scope for this book, but it's something that fascinates me, and I'll blog about it soon.
That "soon" is now :-)
Writing an LLM from scratch, part 12 -- multi-head attention
In this post, I'm wrapping up chapter 3 of Sebastian Raschka's "Build a Large Language Model (from Scratch)". Last time I covered batches, which -- somewhat to my disappointment -- didn't involve completely new (to me) high-order tensor multiplication, but instead relied on batched and broadcast matrix multiplication. That was still interesting on its own, however, and at least was easy enough to grasp that I didn't disappear down a mathematical rabbit hole.
The last section of chapter 3 is about multi-head attention, and while it wasn't too hard to understand, there were a couple of oddities that I want to write down -- as always, primarily to get it all straight in my own head, but also just in case it's useful for anyone else.
So, the first question is, what is multi-head attention?
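As a rough preview (using PyTorch's built-in module rather than the book's from-scratch implementation, so take it as a sketch), the idea is several attention heads running in parallel over the same input, with their outputs combined:

import torch
import torch.nn as nn

emb_dim, num_heads = 768, 12
mha = nn.MultiheadAttention(emb_dim, num_heads, batch_first=True)

x = torch.randn(2, 4, emb_dim)    # a batch of 2 sequences, 4 tokens each
out, attn_weights = mha(x, x, x)  # queries, keys and values all come from x
print(out.shape)                  # torch.Size([2, 4, 768])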
Writing an LLM from scratch, part 11 -- batches
I'm still working through chapter 3 of Sebastian Raschka's "Build a Large Language Model (from Scratch)". Last time I covered dropout, which was nice and easy.
This time I'm moving on to batches. Batches allow you to run a bunch of different input sequences through an LLM at the same time, generating outputs for each in parallel, which can make training and inference more efficient -- if you've read my series on fine-tuning LLMs you'll probably remember I spent a lot of time trying to find exactly the right batch sizes for speed and for the memory I had available.
This was something I was originally planning to go into in some depth, because there's some fundamental maths there that I really wanted to understand better. But the more time I spent reading into it, the more of a rabbit hole it became -- and I had decided on a strict "no side quests" rule when working through this book.
So in this post I'll just present the basic stuff, the stuff that was necessary for me to feel comfortable with the code and the operations described in the book. A full treatment of linear algebra and higher-order tensor operations will, sadly, have to wait for another day...
Let's start off with the fundamental problem of why batches are a bit tricky in an LLM.
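To give a flavour of where this goes, here's a minimal sketch of the batched matrix multiplication that makes it work: PyTorch's @ operator broadcasts over a leading batch dimension, multiplying each sequence's matrices independently.

import torch

batch_size, num_tokens, emb_dim = 8, 4, 16   # small sizes, purely for illustration

# One (num_tokens, emb_dim) matrix of embeddings per sequence in the batch
inputs = torch.randn(batch_size, num_tokens, emb_dim)

# Attention scores for every sequence at once: the matmul is applied
# separately to each item along the leading batch dimension
scores = inputs @ inputs.transpose(1, 2)
print(scores.shape)                          # torch.Size([8, 4, 4])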
Writing an LLM from scratch, part 10 -- dropout
I'm still chugging through chapter 3 of Sebastian Raschka's "Build a Large Language Model (from Scratch)". Last time I covered causal attention, which was pretty simple when it came down to it. Today it's another quick and easy one -- dropout.
The concept is pretty simple: you want knowledge to be spread broadly across your model, not concentrated in a few places. Doing that means that all of your parameters are pulling their weight, and you don't have a bunch of them sitting there doing nothing.
So, while you're training (but, importantly, not during inference) you randomly ignore certain parts -- neurons, weights, whatever -- each time around, so that their "knowledge" gets spread over to other bits.
Simple enough! But the implementation is a little more fun, and there were a couple of oddities that I needed to think through.
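As a minimal sketch of what that looks like in PyTorch (just the general mechanism, not the book's exact code): elements are randomly zeroed during training, the survivors are scaled up to compensate, and nothing happens at inference time.

import torch
import torch.nn as nn

torch.manual_seed(123)
dropout = nn.Dropout(p=0.5)    # drop each element with probability 0.5
x = torch.ones(6)

dropout.train()                # training mode: some elements are zeroed...
print(dropout(x))              # ...and the survivors are scaled by 1 / (1 - p), i.e. 2.0

dropout.eval()                 # inference mode: dropout is a no-op
print(dropout(x))              # tensor([1., 1., 1., 1., 1., 1.])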
Writing an LLM from scratch, part 9 -- causal attention
My trek through Sebastian Raschka's "Build a Large Language Model (from Scratch)" continues... Self-attention was a hard nut to crack, but now things feel a bit smoother -- fingers crossed that lasts! So, for today: a quick dive into causal attention, the next part of chapter 3.
Causal attention sounds complicated, but it just means that when we're looking at a word, we don't pay any attention to the words that come later. That's pretty natural -- after all, when we're reading something, we don't need to look at words later on to understand what the word we're reading now means (unless the text is really badly written).
It's "causal" in the sense of causality -- something can't have an effect on something that came before it, just like causes can't come after effects in reality. One big plus about getting that is that I finally now understand why I was using a class called AutoModelForCausalLM for my earlier experiments in fine-tuning LLMs.
Let's take a look at how it's done.
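As a minimal sketch of one common way to do it (and roughly the approach the chapter takes, if memory serves): set the scores for later tokens to minus infinity before the softmax, so they end up with zero attention weight.

import torch

num_tokens = 4
attn_scores = torch.randn(num_tokens, num_tokens)

# Mask everything above the diagonal, i.e. attention to tokens that come later
mask = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1).bool()
masked = attn_scores.masked_fill(mask, float("-inf"))

# After softmax, the -inf entries become attention weights of exactly zero
attn_weights = torch.softmax(masked, dim=-1)
print(attn_weights)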
Writing an LLM from scratch, part 8 -- trainable self-attention
This is the eighth post in my trek through Sebastian Raschka's book "Build a Large Language Model (from Scratch)". I'm blogging about bits that grab my interest, and things I had to rack my brains over, as a way to get things straight in my own head -- and perhaps to help anyone else that is working through it too. It's been almost a month since my last update -- and if you were suspecting that I was blogging about blogging and spending time getting LaTeX working on this site as procrastination because this next section was always going to be a hard one, then you were 100% right! The good news is that -- as so often happens with these things -- it turned out to not be all that tough when I really got down to it. Momentum regained.
If you found this blog through the blogging-about-blogging, welcome! Those posts were not all that typical, though, and I hope you'll enjoy this return to my normal form.
This time I'm covering section 3.4, "Implementing self-attention with trainable weights". How do we create a system that can learn how to interpret how much attention to pay to words in a sentence, when looking at other words -- for example, that learns that in "the fat cat sat on the mat", when you're looking at "cat", the word "fat" is important, but when you're looking at "mat", "fat" doesn't matter as much?
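To give a rough idea of where the section ends up (a sketch of the general recipe rather than the book's exact class): each token is projected into query, key and value vectors by learned weight matrices, the query-key dot products become attention weights, and those weights mix the values.

import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    # Learned query/key/value projections, scaled dot-product scores,
    # softmax, then a weighted sum of the values.
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=False)
        self.W_key = nn.Linear(d_in, d_out, bias=False)
        self.W_value = nn.Linear(d_in, d_out, bias=False)

    def forward(self, x):
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)
        scores = queries @ keys.transpose(-2, -1)
        weights = torch.softmax(scores / keys.shape[-1] ** 0.5, dim=-1)
        return weights @ values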
Writing an LLM from scratch, part 7 -- wrapping up non-trainable self-attention
This is the seventh post in my series of notes on Sebastian Raschka's book "Build a Large Language Model (from Scratch)". Each time I read part of it, I'm posting about what I found interesting or needed to think hard about, as a way to help get things straight in my own head -- and perhaps to help anyone else that is working through it too.
This post is a quick one, covering just section 3.3.2, "Computing attention weights for all input tokens". I'm covering it in a post on its own because it gets things in place for what feels like the hardest part to grasp at an intuitive level -- how we actually design a system that can learn how to generate attention weights, which is the subject of the next section, 3.4. My linear algebra is super-rusty, and while going through this one, I needed to relearn some stuff that I think I must have forgotten sometime late last century...
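For reference, here's a minimal sketch of the non-trainable version that this section builds up to, as I understand it: dot products between every pair of token embeddings, softmaxed row by row, then used to form a weighted sum.

import torch

inputs = torch.randn(6, 3)               # 6 tokens with 3-dimensional embeddings, made up

attn_scores = inputs @ inputs.T          # dot product of every token with every other token
attn_weights = torch.softmax(attn_scores, dim=-1)   # each row sums to 1

context_vectors = attn_weights @ inputs  # each row is a weighted sum of all the inputs
print(attn_weights.shape, context_vectors.shape)    # torch.Size([6, 6]) torch.Size([6, 3])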