Writing an LLM from scratch, part 15 -- from context vectors to logits; or, can it really be that simple?!
Having worked through chapter 3 of Sebastian Raschka's book "Build a Large Language Model (from Scratch)", and spent some time digesting the concepts it introduced (most recently in my post on the complexity of self-attention at scale), it's time for chapter 4.
I've read it through in its entirety, and rather than working through it section-by-section in order, like I did with the last one, I think I'm going to jump around a bit, covering each new concept, and how I wrapped my head around it, separately. This chapter is a lot easier conceptually than the last, but there were still some "yes, but why do we do that?" moments.
The first of those is the answer to a question I'd been wondering about since at least part 6 in this series, and probably before. The attention mechanism is working through the (tokenised, embedded) input sequence and generating these rich context vectors, each of which expresses the "meaning" of its respective token in the context of the words that came before it. How do we go from there to predicting the next word in the sequence?
The answer, at least in the form of code showing how it happens, leaped out at me the first time I looked at the first listing in this chapter, for the initial DummyGPTModel that will be filled in as we go through it. In its __init__, we create our token and position embedding mappings, and an object to handle dropout, then the multiple layers of attention heads (which are a bit more complex than the heads we've been working with so far, but more on that later), then some kind of normalisation layer, then:
    self.out_head = nn.Linear(
        cfg["emb_dim"], cfg["vocab_size"], bias=False
    )
...and then in the forward method, we run our tokens through all of that, finishing with:
    logits = self.out_head(x)
    return logits
The x in that second bit of code is our context vectors from all of that hard work the attention layers did -- folded, spindled and mutilated a little by things like layer normalisation and being run through feed-forward networks with GELU (about both of which I'll go into in future posts) -- but ultimately just the context vectors.
And all we do to convert it into these logits, the output of the LLM, is run it through a single neural network layer. There's not even a bias or an activation function -- it's basically just a single matrix multiplication!
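To make that concrete, here's a minimal sketch of the shapes involved -- the sizes and names here are toy ones of my own choosing, not the config from the book:

    import torch
    import torch.nn as nn

    # Toy sizes, purely for illustration
    batch, seq_len, emb_dim, vocab_size = 2, 5, 8, 100

    # Stand-in for the context vectors coming out of the attention layers
    x = torch.randn(batch, seq_len, emb_dim)

    # The output head: one bias-free linear layer -- a single matrix multiplication
    out_head = nn.Linear(emb_dim, vocab_size, bias=False)
    logits = out_head(x)

    print(logits.shape)  # torch.Size([2, 5, 100]): one score per vocab entry, per position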
My initial response was, essentially, WTF. Possibly WTFF. Gradient descent over neural networks is amazingly good at learning things, but this seemed like quite a heavy lift. Why would something so simple work? (And also, what are "logits"?)
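(Spoiler for later in the post: the logits are just raw, unnormalised scores, one per entry in the vocabulary, and a softmax turns them into next-token probabilities. Continuing the toy sketch above:)

    # Logits for the last position in the first sequence of the batch...
    last_logits = logits[0, -1]            # shape: (vocab_size,)

    # ...softmaxed into a probability distribution over the whole vocabulary
    probs = torch.softmax(last_logits, dim=-1)

    print(probs.sum())     # sums to (approximately) 1 -- a proper distribution
    print(probs.argmax())  # index of the most likely next token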
Unpicking that took a bit of thought, and that's what I'll cover in this post.
Writing an LLM from scratch, part 14 -- the complexity of self-attention at scale
Between reading chapters 3 and 4 of Sebastian Raschka's book "Build a Large Language Model (from Scratch)", I'm taking a break to solidify a few things that have been buzzing through my head as I've worked through it. Last time I posted about how I currently understand the "why" of the calculations we do for self-attention. This time, I want to start working through my budding intuition on how this algorithm behaves as we scale up context length. As always, this is to try to get my own thoughts clear in my head, with the potential benefit of helping out anyone else at the same stage as me -- if you want expert explanations, I'm afraid you'll need to look elsewhere :-)
The particular itch I want to scratch is around the incredible increases in context lengths over the last few years. When ChatGPT first came out in late 2022, it was pretty clear that it had a context length of a couple of thousand tokens; conversations longer than that became increasingly surreal. But now it's much better -- OpenAI's GPT-4.1 model has a context window of 1,047,576 tokens, and Google's Gemini 1.5 Pro is double that. Long conversations just work -- and the only downside is that you hit rate limits faster if they get too long.
It's pretty clear that there's been some impressive engineering going into achieving that. And while understanding those enhancements to the basic LLM recipe is one of the side quests I'm trying to avoid while reading this book, I think it's important to make sure I'm clear in my head what the problems are, even if I don't look into the solutions.
So: why is context length a problem?
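(A rough back-of-envelope sketch of my own, to show the shape of the problem: the attention-score matrix has one entry per pair of token positions, so it grows quadratically with context length. Assuming float32 scores and ignoring everything else:)

    # Memory for one head's attention-score matrix at context length n,
    # on crude back-of-envelope assumptions: n * n float32 values.
    BYTES_PER_FLOAT32 = 4

    for n in [2_048, 32_768, 1_047_576]:
        gib = n * n * BYTES_PER_FLOAT32 / 1024**3
        print(f"context length {n:>9,}: {gib:>10,.2f} GiB")

On these crude numbers, that's around 0.02 GiB at a 2,048-token context, but roughly 4,000 GiB at GPT-4.1's window -- so the naive recipe clearly can't be what's running in production.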
Writing an LLM from scratch, part 13 -- the 'why' of attention, or: attention heads are dumb
Now that I've finished chapter 3 of Sebastian Raschka's book "Build a Large Language Model (from Scratch)" -- having worked my way through multi-head attention in the last post -- I thought it would be worth pausing to take stock before moving on to Chapter 4.
There are two things I want to cover, the "why" of self-attention, and some thoughts on context lengths. This post is on the "why" -- that is, why do the particular set of matrix multiplications described in the book do what we want them to do?
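(As a reminder, this is the calculation in question -- a minimal single-head sketch with made-up sizes, and without the causal masking or trainable nn.Parameter wrapping from the book's version:)

    import torch

    d_in, d_out, seq_len = 8, 4, 5
    x = torch.randn(seq_len, d_in)        # embedded input tokens

    # Projection matrices (randomly initialised here; trainable in the real thing)
    W_q = torch.randn(d_in, d_out)
    W_k = torch.randn(d_in, d_out)
    W_v = torch.randn(d_in, d_out)

    queries, keys, values = x @ W_q, x @ W_k, x @ W_v

    # Attention scores, scaled and softmaxed into weights...
    scores = queries @ keys.T                             # (seq_len, seq_len)
    weights = torch.softmax(scores / d_out ** 0.5, dim=-1)

    # ...then used to blend the values into one context vector per token
    context_vecs = weights @ values                       # (seq_len, d_out)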
As always, this is something I'm doing primarily to get things clear in my own head -- with the possible extra benefit of it being of use to other people out there. I will, of course, run it past multiple LLMs to make sure I'm not posting total nonsense, but caveat lector!
Let's get into it. As I wrote in part 8 of this series:
I think it's also worth noting that [what's in the book is] very much a "mechanistic" explanation -- it says how we do these calculations without saying why. I think that the "why" is actually out of scope for this book, but it's something that fascinates me, and I'll blog about it soon.
That "soon" is now :-)