The fixed-length bottleneck and the feed-forward network
This post is a kind of note-to-self of a hitch I'm having in my understanding of the mechanics of LLMs at this point in my journey. Please treat it as the musings of a learner, and if you have suggestions on ways around this minor roadblock, comments below would be very welcome!
Having read about and come to the seeds of a working understanding of the role of the feed-forward network in a GPT-style LLM, something has come to mind that I'm still working my way through. It's likely due to a bug in at least one of the mental models I've constructed so far, so what I'd like to do in this post is express the issue as clearly as I can. Hopefully having done that I'll be able to work my way through it in the future, and will be able to post about the solution.
The core of the issue is that the feed-forward network operates on a per-context-vector basis -- that is, the context vectors for each and every token are processed by the same one-hidden-layer neural network in parallel, with no crosstalk between them -- the inter-token communication is all happening in the attention mechanism.
But this means that the amount of data that the FFN is handling is fixed -- it's a vector of numbers, with a dimensionality determined by the LLM's architecture -- 768 for the 124M parameter GPT-2 model I'm studying.
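The per-token, no-crosstalk nature of the FFN can be sketched in a few lines of NumPy. This is an illustrative toy, not real GPT-2 code -- the weights are random placeholders, though the dimensions (768 expanded to 3,072 and back) match the 124M model:

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_ff = 768, 3072  # GPT-2 124M: model width, and 4x expansion inside the FFN
W1, b1 = rng.normal(0, 0.02, (d_model, d_ff)), np.zeros(d_ff)
W2, b2 = rng.normal(0, 0.02, (d_ff, d_model)), np.zeros(d_model)

def gelu(x):
    # The tanh approximation of GELU, as used in GPT-2
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

def ffn(context_vectors):
    # context_vectors: (seq_len, d_model). Each row is one token's context
    # vector; the same weights hit every row independently.
    return gelu(context_vectors @ W1 + b1) @ W2 + b2

seq = rng.normal(size=(10, d_model))  # context vectors for 10 tokens
out = ffn(seq)

# Running one token on its own gives exactly the same result as running it
# as part of the batch -- there is no crosstalk between positions.
assert np.allclose(ffn(seq[3:4]), out[3:4])
```

However many tokens are in the sequence, each one passes through the same fixed-size pipe on its own.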
Here's the issue: in my mental model of the LLM, the attention mechanism is working out what to think about, but the FFN is what's doing the thinking (for hand-wavy values of "thinking"). So, given that it's thinking about one context vector at a time, there's a limit to how much it can think about -- just whatever can be represented in those 768 dimensions for this size of GPT-2.
This reminds me very much of the fixed-length bottleneck that plagued early encoder-decoder translation systems. There's a limit to how much data you can jam into a single vector.
Now, this is an error of some kind on my side -- I'm nowhere near knowledgeable enough about LLMs or AI in general to have spotted a real problem that the whole field has missed. And I'm pretty sure that the answer lies in one of my mental models being erroneous.
It seems likely that it's related to the interplay between the attention mechanism and the FFNs; that's certainly what's come through in my discussions with various AIs about it. But none of the explanations I've read has been quite enough to gel for me, so in this post I'll detail the issue as well as I can, so that later on I can explain the error in my ways :-)
The fixed-length bottleneck
A good starting point is to go through what this fixed-length bottleneck is, or at least was. In early machine translation systems -- way back when in the mid-2010s -- a popular architecture was the encoder/decoder pair. Each of the two parts was a recurrent neural network, which is like a normal NN but has an internal hidden state. You feed in inputs one at a time, and it both produces an output (like a normal NN) and updates its hidden state.
So, you would feed the text you wanted to translate into the encoder, token by token, and ignore its outputs, but just let it update its hidden state to something.
You'd then take that hidden state (perhaps massaging it a bit first) and dump it into the hidden state of the decoder. You'd then feed null inputs into the decoder, one by one, and expect it to output the translated sentence.
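That encode-then-decode handoff can be made concrete with a toy RNN. This is a sketch with made-up dimensions and random weights, not a trainable translation system -- the point is just that everything the decoder knows arrives through the single fixed-size vector `h`:

```python
import numpy as np

rng = np.random.default_rng(1)
d_in, d_h = 32, 64  # arbitrary toy sizes: input vector width, hidden state width
W_xh = rng.normal(0, 0.1, (d_in, d_h))
W_hh = rng.normal(0, 0.1, (d_h, d_h))
W_hy = rng.normal(0, 0.1, (d_h, d_in))

def encode(inputs):
    h = np.zeros(d_h)
    for x in inputs:                      # feed the tokens in one at a time...
        h = np.tanh(x @ W_xh + h @ W_hh)  # ...only the hidden state survives
    return h                              # fixed size, however long the input

def decode(h, n_steps):
    outputs = []
    x = np.zeros(d_in)                    # null input at each step
    for _ in range(n_steps):
        h = np.tanh(x @ W_xh + h @ W_hh)
        x = h @ W_hy                      # emit an output vector
        outputs.append(x)
    return outputs

source = rng.normal(size=(100, d_in))     # a "long" input sequence
h = encode(source)                        # still just 64 numbers
translated = decode(h, n_steps=5)
```

Whether `source` is seven tokens or seven hundred, `h` is always the same 64 numbers -- which is exactly the bottleneck described below.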
In order to get that all working, of course, you'd need to train the encoder and decoder as a pair on a vast database of translated sentences in your source and target languages. But it worked pretty well! For short inputs.
The reason it didn't work for longer sequences was that the hidden state that was transferred from encoder to decoder had a fixed size. Essentially, it had to store a complete representation of the input text -- kind of like a sentence embedding. The encoder would create this, and the decoder would use it to know what the meaning of the output was meant to be.
The hidden state had a fixed size -- it was a vector (or more specifically a set of vectors) of fixed length. And obviously there's a limit, just due to simple information theory, to how much information you can fit into a fixed set of numbers. "The fat cat sat on the mat" -- no problem at all. The Wikipedia page on the Assyrian Empire -- not so much.
So -- the longer the input text, the lossier the "compression" into the hidden state vectors, and the less likely the output would be to accurately represent the input.1
Attention to the rescue!
As I understand it, attention was originally invented as a way to solve this. Think of how a human translator would translate something. They might read it, remember what it means -- but while they were writing their translation, they might refer back to the original text to make sure that it was an accurate representation, and also refer to what they had written so far to make sure that it was coherent.
Likewise, with these attention mechanisms, the decoder would have cross-attention to the hidden states that the encoder had for each token (kind of like the human looking back at the original text) and self-attention to what it had emitted so far.2
The metaphor of the human translator is useful, but even at a kind of mechanical level, you can see how this works around the fixed-length bottleneck -- the decoder's ability to look back at the encoder's hidden states means that it's not limited to the fixed-length hidden state at the end -- instead, it has one of those for every input token -- that is, the length is no longer fixed, it scales linearly with the length of the input sequence. Problem solved! Maybe.
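A minimal sketch of that workaround, with illustrative dimensions and random vectors: instead of one final hidden state, the decoder's query can draw on every per-token encoder state, so the "memory" it reaches back into grows with the input.

```python
import numpy as np

rng = np.random.default_rng(2)
d = 64  # toy hidden-state width

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def attend(query, encoder_states):
    # encoder_states: (seq_len, d) -- one row per input token, so the
    # information available here scales linearly with input length.
    scores = encoder_states @ query / np.sqrt(d)
    weights = softmax(scores)        # how much to look at each input token
    return weights @ encoder_states  # a weighted mix of all the states

short = rng.normal(size=(7, d))      # encoder states for a 7-token input
long = rng.normal(size=(500, d))     # ...and for a 500-token input
q = rng.normal(size=d)
```

The same fixed-size query can pull from 7 states or 500 -- though note that the *output* of `attend` is still a single vector of `d` numbers, which is the seed of the worry in the rest of this post.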
The context vectors as a variable-length hidden state
Next, along came the "Attention is all you need" paper, and then GPT and so on, and people eventually threw away the encoder and the hidden state entirely, and left just a decoder with self-attention. You could give that decoder a text (if you see it as a decoder that was hijacked from a translation system, perhaps you can imagine that you're lying to it and saying "this is what you've written so far") and it would use its attention mechanism to work out what to say next -- the next-token prediction that we use LLMs for.
The expanding of the "hidden state" to scale with the number of tokens in the input sequence still holds. Each attention layer in the LLM takes in a set of input vectors, does its magic, and spits out context vectors, which in turn become the input vectors for the next layer. Taken together these context vectors, at each layer, act as a hidden state, and there's one context vector for every token in the input sequence.
That was where my mental model stood until I learned about the mechanics of the feed-forward network, and then everything came to a screeching halt.
Per-token processing in the feed-forward network
Like I said earlier, the model I came up with to understand the role of the feed-forward networks in each attention layer was that attention works out what to think about, but the feed-forward network does the actual thinking.
And this is where the doubts crept in. The feed-forward network operates on each context vector independently -- and, of course, each context vector has a fixed length; the 768 dimensions for 124M GPT-2.3
Now, of course, the attention mechanism's job is to shuffle information from context vector to context vector, so that each one can gain information from the relevant tokens to its left. And that's all very well -- but the amount of information that can be contained in the context vector remains constant. If you try to jam in too much, it will be lossy, just like if you tried to translate a long text with one of those encoder-decoder systems without attention.
So: if attention works out what to think about, and the FFN does the thinking, the amount of "stuff" that can be thought about is actually limited. It's a different fixed-length bottleneck to the encoder-decoder one, of course -- the system as a whole has a variable length hidden state made up of all of those context vectors taken together -- but each "unit" of thinking has a very limited amount of information it can think about.
To put it another way -- remember that information only flows from left to right across the sequence during causal attention. The "final decision" for each token's context vector -- and thus the resulting logits -- comes from the FFN in the last layer, which has only whatever happens to have been jammed into the incoming context vector, plus whatever that layer's attention mechanism contributed (and, of course, whatever happened in the previous layers). Importantly, the most important prediction during inference -- the predicted next token for the last one in the sequence -- comes from the FFN in the last layer processing the context vector for that last token, and it's limited by the size of that context vector. How can it make such an important choice -- especially if it's the last token of an input sequence long enough that only a tiny portion of the sequence's meaning could have been represented in those 768 numbers?
Attempt 1: a bureaucracy
A while back I was thinking about using metaphors from human organisations to try to unlock intuition about what's going on inside LLMs. Now, organisations make decisions that are too complex for individual humans all of the time. The core idea is that different decisions are made by different people, with the right information being provided to them for their specific part of the whole.
So, let's try that one here: we can see each of the FFN blocks as a worker in a giant bureaucracy. They receive memos of a fixed length -- the context vectors coming in -- and have to make decisions based on them.
Perhaps they don't have enough information to make their decision, and they know it -- that's fine, they pass on a summary of what they received with a note saying "more information of this sort is needed to make a final decision here" -- in mechanistic terms inside an LLM, they'll emit a context vector that at the next round of attention will produce a query vector that will match up with the kind of other information that is needed.
With this metaphor, the answer to the question of how the "final decision maker" can get things right boils down nicely: the previous workers who've been handling this case (the FFNs in the previous layers) have summarised things, and the final decision maker doesn't need to know all of the details.
That sorta-kinda works, but I'm working through getting planning permission for some building works on my home right now and not feeling that well-disposed towards bureaucracies.4 Is there something more market-oriented?
Attempt 2: maybe Hayek? (Not the actor)
Market enthusiasts don't tend to think that business people are geniuses (if we ignore Ayn Rand) -- they're normal humans reacting to price signals that the market generates. They receive a limited amount of information about the world, see an opportunity to make money, and with their normal human intelligence manage to succeed in that -- or to fail in it, thus providing more information for the next entrepreneur to come along.
So in this model, the decision to buy "mat" as the next token -- made by the FFN that is processing the last token in "the fat cat sat on the" -- is based on the prices of the different possible tokens: a fixed amount of information coming from the other FFNs, which manufacture the inputs that set those prices.
Hmm. In my mind, I rather like that metaphor, but written out like that it doesn't seem quite so clear. Maybe I'm just not expressing it well, or maybe I just like it because of my politics ;-)
Still not quite there
Either way, neither of those metaphors fits well for me right now. Perhaps I've not found the right metaphor -- or perhaps the problem is with the models of how LLMs work that I'm building on.
Hopefully as I learn more about the field, and get to a solid soup-to-nuts understanding of how LLMs work -- or perhaps later -- I'll be able to work out which, and revisit it.
But for now, that's where I am -- it feels like there's a problem with LLMs where the information going in to work out the next token is still constrained by a fixed-length bottleneck. I'm sure that no such problem exists, so that points to a gap in my understanding I need to address.
If and when I come to a solution, I'll link to it from here.
1. It's actually even worse than that, because after emitting some output words, the decoder might have to somehow represent the original concept plus the fact that part of it had already been "said". See my earlier post for an example. ↩
2. The encoder had its own self-attention, and one day I'm sure I'll get around to finding out why... ↩
3. Of course it's expanded to 4x that, to 3,072 dimensions, inside the FFN, but that doesn't change the amount of information in it, even if it makes that information easier to work with. ↩
4. Dreams of RLHF-ing bureaucrats ↩