Writing an LLM from scratch, part 14 -- the complexity of self-attention at scale
Between reading chapters 3 and 4 of Sebastian Raschka's book "Build a Large Language Model (from Scratch)", I'm taking a break to solidify a few things that have been buzzing through my head as I've worked through it. Last time I posted about how I currently understand the "why" of the calculations we do for self-attention. This time, I want to start working through my budding intuition on how this algorithm behaves as we scale up context length. As always, this is to try to get my own thoughts clear in my head, with the potential benefit of helping out anyone else at the same stage as me -- if you want expert explanations, I'm afraid you'll need to look elsewhere :-)
The particular itch I want to scratch is around the incredible increases in context lengths over the last few years. When ChatGPT first came out in late 2022, it was pretty clear that it had a context length of a couple of thousand tokens; conversations longer than that became increasingly surreal. But now it's much better -- OpenAI's GPT-4.1 model has a context window of 1,047,576 tokens, and Google's Gemini 1.5 Pro is double that. Long conversations just work -- and the only downside is that you hit rate limits faster if they get too long.
It's pretty clear that there's been some impressive engineering going into achieving that. And while understanding those enhancements to the basic LLM recipe is one of the side quests I'm trying to avoid while reading this book, I think it's important to make sure I'm clear in my head what the problems are, even if I don't look into the solutions.
So: why is context length a problem?
This post makes use of big O notation quite a lot -- that is, measures of the complexity of algorithms expressed like this: $O(n^2)$. If you're not familiar with that, check out this post on freeCodeCamp. Whatever you do, don't go to the Wikipedia page, which is -- frankly -- terrifying.
Complexity in terms of space
You don't need to spend long reading about LLMs before you hear that attention is $O(n^2)$ with respect to the context length. And, looking at the attention score matrices we had earlier, it's pretty clear why, at least in terms of space:
Token | ω("The") | ω("fat") | ω("cat") | ω("sat") | ω("on") | ω("the") | ω("mat") |
---|---|---|---|---|---|---|---|
The | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
fat | 0.2 | 1 | 0 | 0 | 0 | 0 | 0 |
cat | 0.6 | 0.8 | 1 | 0 | 0 | 0 | 0 |
sat | 0.1 | 0 | 0.85 | 1 | 0 | 0 | 0 |
on | 0 | 0.1 | 0.4 | 0.6 | 1 | 0 | 0 |
the | 0 | 0 | 0 | 0 | 0.1 | 1 | 0 |
mat | 0 | 0 | 0.2 | 0.8 | 0.7 | 0.6 | 1 |
For a sequence of length $n$, we have $n^2$ attention scores. With the causal mask, just less than half of them are ignored, but even if we have some kind of clever data structure that stores those "triangular" matrices efficiently, the algorithm is $O(n^2/2)$, which is (by the definition of big O notation) the same as $O(n^2)$.
Eagle-eyed readers will note that I'm ignoring batches in my notes here. I believe that all of the calculations are linear in batch size, so we can safely ignore them.
Let's take some hard numbers. With a context length of 1,024, and assuming 2 bytes per attention score, we get $1{,}024 \times 1{,}024 \times 2 = 2{,}097{,}152$ bytes for our attention scores. 2 MiB -- not too bad! Now, of course, we need one of those per attention head (meaning that attention is linear in the number of heads), but with even consumer GPUs having gigabytes of VRAM, that's no big deal.
But now let's see what would happen if we naively scaled the same algorithm up to that "binary million" ($2^{20} = 1{,}048{,}576$) context length that the new models have: $2^{20} \times 2^{20} \times 2$ bytes. That's 2 TiB of attention scores. I'm sure Nvidia would be delighted if you needed 26 H100 cards with 80 GiB each just to hold the attention scores for one head, but realistically that's not what is happening.
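To make that arithmetic easy to play with, here's a tiny back-of-the-envelope helper in Python -- just my own sketch, assuming (as above) 2 bytes per score and one full $n \times n$ score matrix per head:

```python
def attention_score_bytes(context_length: int, bytes_per_score: int = 2) -> int:
    """Bytes needed to hold one head's full n x n matrix of attention scores."""
    return context_length * context_length * bytes_per_score

print(attention_score_bytes(1_024) / 2**20)   # 2.0 -- MiB for a 1,024-token context
print(attention_score_bytes(2**20) / 2**40)   # 2.0 -- TiB for a "binary million" context
```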
Complexity in terms of time
So that's complexity in terms of space; how about time? Let's look at the calculations that we do to get to that attention score matrix. Firstly, we project our inputs into query space:

$$Q = X W_Q$$
So, with a sequence length of $n$, an input embedding length of $d_{in}$, and a key/query dimensionality of $d_k$, we're multiplying an $n \times d_{in}$ matrix by a $d_{in} \times d_k$ one to get a new one of $n \times d_k$. So we're using $O(n \times d_k)$ memory for that result matrix -- not too bad. (In big O notation, given that we're working relative to $n$, we can treat that as $O(n)$, of course, but bear with me).
Let's think about the calculations. For each of those $n \times d_k$ numbers, we need to take the dot product of a row from the first matrix -- with $d_{in}$ items -- and a column from the second -- also $d_{in}$ items, of course. The dot product is calculated by multiplying the vectors element-wise and summing the results, so that's $d_{in}$ multiplications followed by $d_{in}$ additions, making it $O(d_{in})$.
So: that matrix multiplication, an $n \times d_{in}$ matrix by a $d_{in} \times d_k$ one, is an $O(d_{in})$ calculation, the dot product, performed $n \times d_k$ times. That makes it $O(n \times d_{in} \times d_k)$ -- and that's a general formula for a matrix multiplication, which we can re-use later. It's the rows from the first matrix times the columns from the first (which is the same as the rows in the second) times the columns in the second -- or in other words, the three dimensions (treating the "shared" one as just one) multiplied together.
But in this case, for $Q = X W_Q$, as we're working out complexity with respect to $n$, we can just treat $O(n \times d_{in} \times d_k)$ as $O(n)$.
Let's move on to the next steps in our calculations, the projections into key and value space:

$$K = X W_K \qquad V = X W_V$$
Both of these are the same complexity as the calculation of $Q$: $O(n \times d_{in} \times d_k)$ -- the dimensionality of value space, $d_v$, might be different to $d_k$, but again, we can ignore that because we're working out complexity with respect to the sequence length $n$.
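Here's a minimal PyTorch sketch of those three projections, just to make the shapes (and therefore the multiplication counts) concrete -- the sizes and the names `d_in` and `d_k` are my own made-up placeholders for illustration:

```python
import torch

n, d_in, d_k = 1_024, 768, 64    # sequence length, input embedding size, key/query size (made up)

X = torch.randn(n, d_in)         # input embeddings: n x d_in
W_q = torch.randn(d_in, d_k)     # query projection weights
W_k = torch.randn(d_in, d_k)     # key projection weights
W_v = torch.randn(d_in, d_k)     # value projection weights (using d_v = d_k for simplicity)

Q, K, V = X @ W_q, X @ W_k, X @ W_v   # each result is n x d_k

print(Q.shape, K.shape, V.shape)      # torch.Size([1024, 64]), three times
print(n * d_in * d_k)                 # 50331648 multiplications per projection -- linear in n
```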
The next step in the calculation is to work out our attention scores:

$$\text{scores} = Q K^T$$
So we're taking an $n \times d_k$ matrix and multiplying it by a $d_k \times n$ one. I'm going to ignore the transpose, as from what I've read, libraries like PyTorch just do that by modifying metadata associated with the tensor, so it's done in a small constant time.
Taking that general equation for the big O of a matrix multiplication above, that comes to $O(n \times d_k \times n)$, which, as $d_k$ is a constant with respect to context length, is $O(n^2)$.
Next we convert the attention scores to attention weights with our softmax of the scaled scores:

$$\text{weights} = \mathrm{softmax}\left(\frac{\text{scores}}{\sqrt{d_k}}\right)$$
We're dividing every number in an $n \times n$ matrix by $\sqrt{d_k}$, so that's going to be $O(n^2)$, then working out $n$ softmaxes, each of which appears to be $O(n)$, so we've got another $O(n^2)$ there too.
Finally, we work out the context vectors by applying our attention weights:

$$\text{context vectors} = \text{weights} \cdot V$$
We're multiplying an $n \times n$ matrix by an $n \times d_v$ one, so that is (again, using the matrix multiplication complexity from above) $O(n \times n \times d_v)$, so yet another $O(n^2)$.
So: having ground through the maths, we've found that the hardest bits were $O(n^2)$, and of course that means that the algorithm as a whole is $O(n^2)$.
And that means that the attention calculations we have are $O(n^2)$ in terms of both space and time.
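Pulling those steps together, here's a minimal single-head, non-causal sketch in PyTorch (again with made-up sizes), with comments marking where the $n \times n$ matrices -- the $O(n^2)$ parts -- show up:

```python
import math
import torch

n, d_k = 1_024, 64                 # made-up sequence length and key/query/value size
Q = torch.randn(n, d_k)            # queries
K = torch.randn(n, d_k)            # keys
V = torch.randn(n, d_k)            # values

scores = Q @ K.T                                            # n x n: O(n^2) in both space and time
weights = torch.softmax(scores / math.sqrt(d_k), dim=-1)    # scale, then row-wise softmax: still n x n
context = weights @ V                                       # (n x n) @ (n x d_k): another O(n^2) step

print(scores.shape, weights.shape, context.shape)
# torch.Size([1024, 1024]) torch.Size([1024, 1024]) torch.Size([1024, 64])
```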
How does that make it scale?
So, we previously worked out that a 1M-token context length made our attention score matrix balloon from 2 MiB to 2 TiB. What does the $O(n^2)$ mean for the time?
Well, let's say that our 1k token model takes 0.1s to work out the next token -- 10 tokens/second doesn't sound unreasonable.
Ignoring the fact that communicating between the 26 GPUs we need to store this monster matrix will take up quite a lot of time -- imagine we somehow have a single GPU with 2 TiB of VRAM -- then our inference will take a million times longer, because our sequence is a thousand times longer and $1{,}000^2 = 1{,}000{,}000$.
So that's 100,000 seconds, or about 28 hours per token -- not ideal for a chatbot.
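The arithmetic behind that, spelled out (assuming, purely for illustration, that the time per token scales exactly with $n^2$):

```python
base_time_s = 0.1          # assumed time per token with a ~1k context
length_ratio = 1_000       # the context is roughly a thousand times longer
slowdown = length_ratio ** 2        # time grows with n^2: a million times slower
seconds_per_token = base_time_s * slowdown
print(seconds_per_token)            # 100000.0 seconds per token...
print(seconds_per_token / 3600)     # ...which is about 27.8 hours
```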
Sidebar: fine-tuning
If you've been reading this blog for a while, you might remember that when I was messing around with fine-tuning LLMs, it appeared to be $O(n)$ in both space and time versus sequence length. My guess is that this is because I was only using context lengths up to about 2,000, and the various other parts of the training process beyond the attention scores -- the gradients for the billions of parameters, the optimiser states, and so on -- were so large that they drowned out the signal from the $O(n^2)$ parts of the LLMs. Those larger uses of memory and compute were growing linearly.
Can we do better with something simple?
Is there any low-hanging fruit that we can pick here to optimise this algorithm? Realistically, no -- people wouldn't have used scaled dot-product attention for research systems (and indeed, early versions of the GPT models) if there were any simple improvements.
It does look like we could make it linear in terms of space; there's no real need to calculate all of the tokens' attention scores in parallel. You could work out the scores for the first token, generate its context vector, then go on to the next one, and so on. This would be $O(n)$ in space, so no problem at all!
But it doesn't help in terms of time. Working out each token's context vector would be $O(n)$ in time, but we'd need to do that $n$ times, keeping our overall time complexity at $O(n^2)$. And, of course, doing it all serially like that would mean that we'd not be able to take advantage of parallelism on the GPUs, so we'd take an additional performance hit there (which is likely constant and wouldn't show up in the big O notation, but it's real).
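Here's roughly what that row-at-a-time version might look like -- a sketch of my own, not anything from the book. At any moment we only hold one token's scores (a length-$n$ vector), so space drops to $O(n)$, but we still do $n$ dot products for each of the $n$ tokens, so time stays $O(n^2)$:

```python
import math
import torch

n, d_k = 1_024, 64                 # made-up sequence length and key/query/value size
Q = torch.randn(n, d_k)
K = torch.randn(n, d_k)
V = torch.randn(n, d_k)

context = torch.empty(n, d_k)
for i in range(n):                                               # n iterations...
    scores_i = Q[i] @ K.T                                        # ...of n dot products each: O(n) space, O(n^2) time overall
    weights_i = torch.softmax(scores_i / math.sqrt(d_k), dim=-1)
    context[i] = weights_i @ V                                   # one context vector at a time

print(context.shape)               # torch.Size([1024, 64])
```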
So: not particularly helpful. 1
Long-distance dependencies considered harmful
There's something else that seems important here, too. Let's consider that million-token context length again. What on earth is a token at position 983,232 doing looking at token 15? If you're reading War and Peace, you don't need to look up specific words on the first page when you're reading the last chapter 2.
At least in the first layer of your LLM, you'd expect tokens to attend quite closely to ones next to them, but not to anything really far away. Ideally, the kind of global connections from stuff at the start to stuff at the end would be built up from layer to layer, so that by the end, that token from near the end would have some kind of summary of what happened earlier to reference -- the same way as we would when reading a book.
So: at least as I understand it, it's both impractical and unnecessary for a million-token context length LLM to have every token attending to every other token. In modern AIs where the context length is longer, something more is going on -- perhaps something hierarchical, even if it's also something that is learned. What might that be?
A wildly speculative conclusion
Here's where I should probably stop, because I'm heading off on a side quest. From what I've read around the subject, sparse attention and local attention are techniques that address exactly this kind of problem -- certainly their names are evocative of the kind of thing that might work.
And of course there are other models outside a pure LLM that can be used to extend the effective context length. Even the first version of ChatGPT could have conversations longer than its context length. As far as I understand it, when the conversation got too long, it would create a "story so far" summary for all but the last couple of messages, and then replace everything but those last few with the summary. The reason longer conversations got weird was that its summaries (and summaries of summaries, and summaries of summaries of summaries) would gradually lose their connection with reality, and it would be responding to a conversation that wasn't actually the one you were having.
So you can imagine an AI that, while it was "reading" War and Peace, would write its own summary of chapter 1, stash that away somewhere -- perhaps indexed by an embedding for easier later retrieval using a technique like Retrieval Augmented Generation. Then later on, when it was reading chapter 35,326, the bit where Prince Vladimir proposes to Countess Olga, it would be able to find appropriate bits of context to pull into the "real" context window and "remember" that in chapter 27 it was mentioned that they were cousins. 3
(I suspect that ChatGPT's "memory" feature is using something along these lines.)
That example is outside the LLM itself, of course, but maybe something along those lines could be built into the model directly without the external systems that a RAG system would require?
Anyway, I'm speculating wildly now. What's certain is that the LLM we're building will, like early versions of ChatGPT, be $O(n^2)$ in both space and time complexity with respect to context length. And finding out about the ways around that looks like a fantastic thing to add on to my list of next steps once I've finished the book.
So it's time to get back to it and get some more foundational knowledge in place. Up next: chapter 4, where we start putting the pieces together to actually build an LLM from the self-attention system we've built. Looking forward to it :-)
Here's a link to the next post in this series.
-
That said, from what little I understand, FlashAttention does something not entirely dissimilar to that at its core, but (of course) with lots of improvements. It's still $O(n^2)$ in time, but there's a substantial constant-factor speedup. ↩
-
Maybe War and Peace isn't the world's best example -- from what I remember, by the end you might forget who Prince Somethingorother is and have to check the family tree that the translator helpfully put in on page viii. But you get the point I'm making, I hope. ↩
-
It's been a while since I read it, sorry. ↩