Writing an LLM from scratch, part 17 -- the feed-forward network
I'm still working through chapter 4 of Sebastian Raschka's book "Build a Large Language Model (from Scratch)". This chapter not only puts together the pieces that the previous ones covered, but adds on a few extra steps. I'd previously been thinking of these steps as just useful engineering techniques ("folding, spindling and mutilating" the context vectors) to take a model that would work in theory, but not in practice, and make it something trainable and usable -- but in this post I'll explain why that was wrong.
Last time I covered layer normalisation, which I managed to get a satisfactory (but not great) handle on -- it is a trick to constrain the outputs of a layer in the LLM so that one token doesn't "drown out" the signals from the others and cause problems like exploding or vanishing gradients during training (and, I would imagine, to some degree during inference). That definitely is one of those engineering techniques to ensure trainability.
This time I want to go through the feed-forward layer, which is a different kind of beast. It is covered in just four and a half pages in the book, and the implementation is really simple -- indeed, two of the pages are an in-depth look at GELU, the specific activation function that the GPT-2 style model that we're building uses, and the rest is just making concrete exactly how to write a simple neural network with one hidden layer that uses GELU.
So, the "how" was simple enough. It was the "why" that surprised me. The more I thought about it, the more I realised that this part of the LLM is just as important as the attention mechanism itself. In my current working model, at least, attention tells the LLM what to think about -- gathering the "meaning" of an input vector in the context of those to its left -- but it's the linear layers that actually do that thinking and allow the LLM as a whole to make next-token predictions. Indeed, there are more parameters in a normal LLM for these networks than there are for the attention mechanism itself! They're clearly super-important.
Let's dig in.
What is happening
The steps that Raschka shows are simple enough:
- We run the context vectors that we got from the normalisation layer that came after the attention mechanism through a simple biased linear layer, expanding them from 768 dimensions to four times that -- 3072.
- We run those through the GELU activation function -- I won't go into that in detail because the book covers it perfectly well.1 It doesn't seem too complicated in principle and its advantages (no sharp corner like ReLU has at zero, and only one point where the gradient is zero, whereas ReLU has a gradient of zero for all negative numbers) make intuitive sense.
- After the activation function, we have another biased linear layer, bringing the dimensionality back down to 768 so that it's compatible with the input "format" expected by the next layer.
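The activation function in the middle can be sketched in a few lines of plain Python (this is just an illustration, not the book's code): the exact GELU in terms of the Gaussian CDF, and the tanh approximation that GPT-2 actually uses, side by side with ReLU to show the difference at negative inputs.

```python
import math

def relu(x):
    # ReLU: zero for all negative inputs -- so zero gradient there too
    return max(x, 0.0)

def gelu(x):
    # exact GELU: x * Phi(x), where Phi is the standard normal CDF
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x):
    # the tanh approximation used in GPT-2
    return 0.5 * x * (1.0 + math.tanh(
        math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

# unlike ReLU, GELU is smooth and non-zero for small negative inputs,
# so gradients can still flow through slightly-negative activations
print(relu(-0.5), round(gelu(-0.5), 4))
```

The two GELU variants agree to a few decimal places; the approximation exists because `erf` used to be expensive to compute on GPUs.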
Note that this is done on a per-context-vector basis -- that is, each vector is run through the linear layer independently. All of the "cross-talk" between the tokens has happened during the attention mechanism. So the linear layer sees each context vector individually -- though, of course, thanks to attention, the context vectors have not only information from their respective tokens, but from other tokens that came before them in the sequence.
So, the "what" is pretty simple. It's a simple neural network with a single hidden layer, processing the vectors in the sequence in parallel like you would normally feed a batch through any other neural network.
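To make that concrete, here's a NumPy sketch of the whole feed-forward block (the book's implementation uses PyTorch; the weights here are random stand-ins for trained parameters). It also checks the per-vector independence: running one context vector through on its own gives the same result as running the whole sequence as a batch.

```python
import numpy as np

rng = np.random.default_rng(0)
emb_dim, hidden_dim, seq_len = 768, 4 * 768, 6

# illustrative random parameters standing in for trained weights
W1, b1 = rng.normal(0, 0.02, (emb_dim, hidden_dim)), np.zeros(hidden_dim)
W2, b2 = rng.normal(0, 0.02, (hidden_dim, emb_dim)), np.zeros(emb_dim)

def gelu(x):
    # tanh approximation of GELU, as used in GPT-2
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

def feed_forward(x):
    # expand to 3,072 dimensions, apply GELU, project back down to 768
    return gelu(x @ W1 + b1) @ W2 + b2

context_vectors = rng.normal(size=(seq_len, emb_dim))
out = feed_forward(context_vectors)
print(out.shape)  # (6, 768) -- same "format" in as out

# each position is processed independently: one vector at a time
# gives the same answer as the whole sequence at once
single = feed_forward(context_vectors[2])
assert np.allclose(out[2], single)
```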
Why do we do it?
Why inject non-linearity?
When I first thought about it, it seemed to just be a textbook neural network, rather arbitrarily dumped into the middle of this system. I won't go into the details of how those work -- there are great explanations out there (I love Anil Ananthaswamy's "Why Machines Learn").
But at a high level, my intuition is that neural networks are pattern-matchers. If we think of a network as dividing some high-dimensional space into sections, then we can see that projecting it into a higher-dimensional space -- where that projection is learned during training -- can make it easier to divide different regions, and that having non-linearity in the dividers that we use to divide those regions makes it easier to make complex shapes with them -- easier than it would be with straight lines, planes or hyperplanes.2
So, we have a simple neural network with one hidden layer tacked on after attention, doing some kind of pattern-matching.
But isn't attention "all you need"?
But my understanding initially was that we already had all of this clever attention stuff that was doing the work necessary to make the context vector for each token contain all of the relevant information from all tokens to its left. We were then building this up over multiple layers, so that it would eventually point in the direction of the predicted next token given that information.
Why would we have a random classifier thrown in there?
Like most people confronted with something they don't understand these days, I started chats with a bunch of different LLMs to talk around the problem and get their input. I think it was OpenAI's o3 that made it click for me.
Attention itself is just combining information from input vectors. The first layer, for example, given "the fat cat sat on the" might blend some information from every token in the input into that last "the". Later layers would blend increasing amounts of information from the previous tokens into it.
But combining information isn't really "thinking". LLMs can produce output that shows some kind of reasoning. We can, of course, have long philosophical debates about whether it's "real" thinking, and throw around words like "stochastic parrot". But something is going on in there that is not just simple statistical pattern matching. I think it was Dijkstra who said something like "the question of whether a computer can think is no more interesting than the question of whether a submarine can swim" -- it doesn't really matter what you call it, something is happening in there.
Back when ChatGPT 4 came out, I was delighted to see that one of my trick prompts no longer worked. When I asked version 3.5 this question:
Imagine that I have a cup, and inside the cup is a diamond. The cup is on a table. I pick the cup up, and take it to a bed, where I put it down. Then I turn the cup upside down. I wait for five seconds. Then I turn the cup back the right way up. Then I pick up the cup again, and take it back to the table. Where is the diamond, and why?
...it told me that the diamond was still in the cup, its reasoning being something like "you never took it out of there".
But ChatGPT 4 gave the correct answer, that the diamond would have fallen onto the bed when I inverted the cup. It "knew" that cups are open at the top.
Intuitively, it's hard if not impossible to imagine how attention alone would have a knowledge base like that. All it is doing is agglomerating information about tokens. In my new intuition, it seems to me that attention is building up something like the hidden state from one of those encoder RNNs that were used with decoders for translation -- it's working out, for each token, what the input sequence up to that point means.
In order to predict a next token, the LLM needs to think about that meaning -- and that is what these linear layers are doing. It's inside its linear layers (maybe in one place, but probably spread across a bunch of places) that ChatGPT 4 "knew" that cups are open at the top, and what that implies.
I'm sure this is a gross simplification, but I think that for now it's a useful working model. Attention is how the LLM works out what to think about, and the feed-forward layers are where it does its thinking.
Parameter counts
Another indication of how important this all is showed up when I started thinking about parameter counts.
Let's think about how many parameters the attention weights have. We have three attention weight matrices per multi-head attention layer -- they have one row per dimension of the input embeddings -- 768 for our 124M GPT-2 model -- and then one column per dimension of the output embeddings -- 768, again.3 So, there are 768 × 768 = 589,824 parameters for each of those three matrices. We also have a linear layer with 768 inputs and 768 outputs to combine the different heads' results into a single coherent context vector.
So that's four weight matrices, each 768 × 768, for each layer's attention mechanism, giving us 4 × 589,824 = 2,359,296 parameters for that part.
Now let's think about this feed-forward network. We have two layers:
- The first one projects from 768 dimensions to 3,072 dimensions, so that one has 768 × 3,072 = 2,359,296 weights.
- The second one projects from 3,072 dimensions back down to 768, so it has 3,072 × 768 = 2,359,296 weights.
That's 2 × 2,359,296 = 4,718,592 parameters in total.
So that means we have twice as many parameters in the feed-forward network as we do in the attention mechanism itself.4
It must be important! Otherwise why would we spend two-thirds of each layer's parameters on it?
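The arithmetic above is easy to sanity-check in a few lines (biases ignored, as throughout this section):

```python
emb_dim = 768          # GPT-2 124M embedding dimension
ffn_dim = 4 * emb_dim  # 3,072

# attention: W_q, W_k, W_v plus the combining linear layer,
# each emb_dim x emb_dim
attn_params = 4 * emb_dim * emb_dim

# feed-forward: expand to 3,072 dimensions, then project back down
ffn_params = emb_dim * ffn_dim + ffn_dim * emb_dim

print(attn_params)  # 2359296
print(ffn_params)   # 4718592
print(ffn_params / (attn_params + ffn_params))  # ~0.667
```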
Why just two layers?
My next thought was, why do we only have two layers? If they're that important, maybe we can have more?
Now, as I understand it, a sufficiently-large neural network with one hidden layer -- which is what we have here -- is a universal approximator. My maths on this is more than a little shaky, but I believe that means that any network with more than one hidden layer can, in theory, be approximated to arbitrary accuracy by an equivalent network with just one -- though that layer might need to be pretty huge.
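As a toy illustration of what a single hidden layer can do (my example, and nothing like a proof of the approximation theorem): the absolute-value function, which no purely linear layer can compute, is represented *exactly* by a hidden layer of just two ReLU neurons.

```python
def relu(z):
    return max(z, 0.0)

def abs_net(x):
    # hidden layer: two neurons with weights +1 and -1, no biases
    h1, h2 = relu(x), relu(-x)
    # output layer: sum the two hidden activations, each with weight 1
    return h1 + h2

print(abs_net(-3.5), abs_net(2.0))  # 3.5 2.0
```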
But people do, of course, build much deeper networks -- because adding on layers can make the network smaller and easier to train -- though the depth does of course come with its own problems, like vanishing/exploding gradients and so on.
From a bit of reading around, it looks like people do use deeper networks here -- indeed, I was a bit annoyed when I realised that my "isn't attention all you need?" gag from the heading above (something that came to my mind quite early in this writeup) had already been made by an interesting-looking paper that studies exactly that.
At some point when I've finished this book, I'm considering doing posts where I work through papers, and that one is definitely going on the list.
So that's the "why"
When I was writing about layer normalisation, I said this about a multi-layered set of simple multi-head attention mechanisms, without what I saw as the "folding, spindling and mutilating" of the context vectors that came out of them:
That actually feels like a pretty complete system, and I rather suspect that in principle, an LLM might work just with those calculations.
I think I was wrong about that. Skipping layer normalisation sounds like it could in theory work (though for the reasons I covered then, it probably wouldn't be trainable in practice). But missing out the feed-forward layer sounds like it would not. It would lead to an LLM that might in some sense "understand" or at least gather knowledge about the meaning and structure of an input sequence, by building up context vectors over multiple layers (like I keep saying, similarly to how a CNN builds up increasingly complex representations of what is in an image over its multiple layers) -- but it wouldn't be able to do anything with that "understanding". It would not be able to predict the next token.
And that wouldn't be very helpful.
I think that pretty much wraps things up for this post. I hope it was as useful to read as it was for me to write!
The next one also promises to be interesting -- what Raschka calls shortcut connections. They're another thing that Raschka explains quickly and simply, because the implementation really is simple -- but they're also something where the why is much deeper than it looks. It will use the Talmud as a metaphor... So see you then!
Here's a link to the next post in this series.
- While reading around this, I noticed that "Attention Is All You Need" used the normal, much simpler ReLU as its activation function (section 3.3). I believe that GELU was brought in by the original paper introducing the GPT architecture. ↩
- Perhaps it would be worth continuing the series I started back in February with "Basic matrix maths for neural networks: the theory" and "Basic matrix maths for neural networks: in practice", expanding it into a more general "basic neural networks for techies" series -- let me know in the comments if you'd be interested. ↩
- Remember that the columns are kind of split into per-head stripes, so with 12 heads, the first 64 columns would "belong to" the first head, the next 64 to the second head, and so on. ↩
- Eagle-eyed readers will note that I'm ignoring biases in the various weight matrices here. They're tiny by comparison to the other numbers, and don't affect the result meaningfully. ↩