Writing an LLM from scratch, part 12 -- multi-head attention
In this post, I'm wrapping up chapter 3 of Sebastian Raschka's "Build a Large Language Model (from Scratch)". Last time I covered batches, which -- somewhat to my disappointment -- didn't involve completely new (to me) high-order tensor multiplication, but instead relied on batched and broadcast matrix multiplication. That was still interesting on its own, however, and at least was easy enough to grasp that I didn't disappear down a mathematical rabbit hole.
The last section of chapter 3 is about multi-head attention, and while it wasn't too hard to understand, there were a couple of oddities that I want to write down -- as always, primarily to get it all straight in my own head, but also just in case it's useful for anyone else.
So, the first question is, what is multi-head attention?
What multi-head attention means
What we've gone through so far in the book -- a full attention mechanism that generates context vectors from a batch of input sequences (each of which is a list of input embeddings), by using the basic attention mechanism calculations plus dropout and a causal mask -- is a single attention head. Over its training, it will learn certain ways of working out those context vectors that lead to "good" results. In a way, it has "opinions" about how to pay attention to other tokens when looking at a specific token. It will also learn values for the weights that lead to "good" output representations.
If having one of these is good, then running a bunch in parallel clearly has benefits. Each one can learn different ways to pay attention to the inputs, and then their resulting context vectors can be combined to get a richer high-level result.
For example, in
The fat cat sat on the mat
One attention head might learn how articles like "the" relate to the nouns they are near to, another might learn how the verb "sat" interacts with the subject "cat" and the indirect object "mat", another might learn how "fat" modifies "cat", and so on. Their combined results would be a richer representation of what the different parts of the sentence mean in context.
So, what multi-head attention means is that we are running our input sequence through a bunch of separate heads, each with its own trainable attention weights -- $W_q$, $W_k$ and $W_v$ -- then we're somehow munging the resulting sets of context vectors together to get a combined result, and that's what is fed forward to the next part of the LLM.
Now, it's pretty clear that something involving batched matrix multiplication is likely to be the best way to do this kind of thing, to run all of the heads in parallel, but Raschka, sensibly enough, starts off with an easier-to-understand example.
The simple first example
Just like he did with the original attention mechanism calculations, he starts off by doing the multi-head attention "by hand" in Python code, without trying to use complex maths. In a way I think he's aiming to kick-start an intuition about how it all fits together.
He writes
a PyTorch nn.Module that has a list of CausalAttention objects -- each a single attention head -- as a field,
and then in its forward method just runs through them one after another to get
a set of context vectors.
He then uses torch.cat to concatenate them on the
last dimension -- that is, the columns, which makes sense. The output of a single
attention head is $b \times n \times d_{out}$ (batch size $b$, number of tokens $n$, output dimension $d_{out}$), and that last dimension is
the one we're concatenating them on.
That means that the result is $b \times n \times (d_{out} \cdot h)$, where $h$ is the number of heads, and the aggregate context vector for a particular token is just the vector produced by the first head in the first $d_{out}$ columns, then the vector from the second head in the second $d_{out}$ columns, and so on. Not a particularly sophisticated way to combine them, but as a starting point it makes sense.
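For reference, the wrapper looks roughly like this -- my reconstruction from the description above rather than a verbatim copy of the book's listing, and it leans on the CausalAttention class quoted a little further down:

import torch
import torch.nn as nn

class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        # one independent single-head attention module per head
        self.heads = nn.ModuleList(
            [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
             for _ in range(num_heads)]
        )

    def forward(self, x):
        # run every head over the same input, then glue the results together column-wise
        return torch.cat([head(x) for head in self.heads], dim=-1)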
So the next step is to start leaning on PyTorch's optimised batched matrix multiplication so that we can run all of the heads in parallel.
Multi-head attention with batched matrix maths -- but with a twist!
Before I looked at the code, I had a pretty clear idea in my head what it would look like. Unfortunately it wasn't quite right, but I think it's actually a pretty good starting point for understanding what the actual code does.
Our single-head attention class looked like this:
import torch
import torch.nn as nn

class CausalAttention(nn.Module):
def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
super().__init__()
self.d_out = d_out
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
self.dropout = nn.Dropout(dropout)
self.register_buffer(
'mask',
torch.triu(torch.ones(context_length, context_length), diagonal=1)
)
def forward(self, x):
b, num_tokens, d_in = x.shape
keys = self.W_key(x)
queries = self.W_query(x)
values = self.W_value(x)
attn_scores = queries @ keys.transpose(1, 2)
attn_scores.masked_fill_(
self.mask.bool()[:num_tokens, :num_tokens],
-torch.inf
)
attn_weights = torch.softmax(
attn_scores / keys.shape[-1] ** 0.5,
dim=-1
)
attn_weights = self.dropout(attn_weights)
context_vec = attn_weights @ values
return context_vec
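Just to make the shapes concrete, here's an illustrative way to run that class and check what comes out -- the sizes are made up, not the book's:

import torch

torch.manual_seed(123)
batch = torch.randn(2, 6, 3)     # b=2 sequences, 6 tokens each, d_in=3
ca = CausalAttention(d_in=3, d_out=4, context_length=6, dropout=0.0)
print(ca(batch).shape)           # torch.Size([2, 6, 4]) -- that is, b x n x d_out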
Now, all we needed for multi-head attention -- or at least, so it seemed to me
-- was to add an extra dimension to W_query, W_key, and W_value. They're
$d_{in} \times d_{out}$ matrices, so this would just make them order-3 tensors; a stack of matrices, with
one matrix in the stack for each head.
The rest of the code should just work as-is, and by running it we'd wind up with a four-dimensional tensor of results -- one item per head in one dimension, one per batch item in another, and then the remaining two making up the per-head context-vector matrices.
Or in other words, if we have a batch size $b$, a number of attention heads $h$, $n$ input tokens in each batch, and a desired output context vector length of $d_{out}$, then we'd have an order-4 tensor like this:
$b \times h \times n \times d_{out}$
We'd then do something with that to collapse the per-head dimension -- perhaps just a concatenation like in the previous example -- and we'd be done.
The code changes would be pretty trivial, just pass in a num_heads to the __init__, then
initialise the weights something like this:
self.W_query = nn.Linear(num_heads, d_in, d_out, bias=qkv_bias)
...and do whatever we wanted for the collapse at the end of the forward method.
Now, if we were still using nn.Parameter for our weights, that would have worked!
The combining of the context vectors at the end might have been a bit fiddlier in
reality than it was in my head, but I think it would have been reasonably easy to
understand as an extension of the last, single-head version.
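For what it's worth, here's roughly the shape of what I had been imagining -- untested sketch code, not anything from the book, and with deliberately naive torch.randn initialisation, which is exactly the weakness that nn.Linear's init avoids:

import torch
import torch.nn as nn

class StackedWeightsCausalAttention(nn.Module):
    # Hypothetical sketch: one (d_in, d_out) weight matrix per head, stacked into
    # an order-3 tensor and applied to the whole batch via broadcasting.
    def __init__(self, d_in, d_out, context_length, dropout, num_heads):
        super().__init__()
        self.W_query = nn.Parameter(torch.randn(num_heads, d_in, d_out))
        self.W_key = nn.Parameter(torch.randn(num_heads, d_in, d_out))
        self.W_value = nn.Parameter(torch.randn(num_heads, d_in, d_out))
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        xs = x.unsqueeze(1)                     # (b, 1, num_tokens, d_in)
        # (b, 1, num_tokens, d_in) @ (num_heads, d_in, d_out) broadcasts
        # to (b, num_heads, num_tokens, d_out)
        queries = xs @ self.W_query
        keys = xs @ self.W_key
        values = xs @ self.W_value
        attn_scores = queries @ keys.transpose(2, 3)   # (b, num_heads, num_tokens, num_tokens)
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf
        )
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)
        # one set of context vectors per head: (b, num_heads, num_tokens, d_out)
        return attn_weights @ values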
However, we're not using nn.Parameter: we're using nn.Linear, because -- Raschka explained earlier -- the
weight initialisation is better for our purposes.
This causes a problem, because nn.Linear is designed to be a layer in a regular
neural network. Unlike nn.Parameter, which contains a data Tensor of any dimensionality,
nn.Linear objects can only have two dimensions -- the
number of input features and the number of output features. So that line above
in my imaginary code where we created an nn.Linear with three dimensions would
be an error.
So while I still think the idea was quite elegant, it doesn’t work in practice
because of our use of nn.Linear.
The way that the solution to this works is quite clever, and (for all I know) is
actually more efficient than the solution I came up with -- I think it would be
a mistake to see it as just a workaround for the choice of nn.Linear, though
I'll need to dig in a bit further to see if that's true.
Let's work through it step by step.
We do indeed have a new num_heads parameter:
def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
Now, the d_out is the dimensionality we want for the context vectors this multi-head
attention class will spit out at the end. So if we were to just concatenate the context vectors
from each head like the original simple implementation did, we would need each one to provide a
context vector of length d_out / num_heads, and while we're not just doing concatenation
(I'll come back to that later), we do size them appropriately for that:
self.head_dim = d_out // num_heads
Note the explicit integer division, which makes sense.¹
But that's not used when constructing the weights; instead, we just do this:
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
...and likewise for W_key and W_value.
For me, the easiest way to visualise these
is that for each of them, we have the weights for all of our heads side-by-side,
in adjacent groups of columns. The first $d_{head}$ columns are for the first head,
the next $d_{head}$ for the second, and so on.
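To make that "side-by-side columns" picture concrete, here's a tiny illustration with made-up sizes (nothing to do with the book's numbers):

import torch
import torch.nn as nn

d_in, d_out, num_heads = 3, 6, 2
head_dim = d_out // num_heads        # 3 columns per head

W_query = nn.Linear(d_in, d_out, bias=False)
x = torch.randn(1, 4, d_in)          # one sequence of 4 tokens
queries = W_query(x)                 # shape (1, 4, 6)

q_head_0 = queries[..., :head_dim]   # columns 0-2: the queries "belonging" to head 0
q_head_1 = queries[..., head_dim:]   # columns 3-5: the queries "belonging" to head 1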
Moving on to the forward method, we just project our inputs into key, query and
value space exactly like we did before, for example:
keys = self.W_key(x)
That makes sense -- because of the way that matrix multiplication works, the columns
in the second argument map directly to the columns in the result. (Remember that that Python code is equivalent to $x W_k$, because
nn.Linear swaps round the order, so the second argument to the underlying matrix
multiplication is $W_k$.)
So because we can see W_query as being the weights for all of the attention heads
in a single matrix side-by-side, the result of that multiplication is the projection
matrices of the inputs into key space for each head, also in a single matrix side-by-side.
This single matrix multiplication feels to me like it might possibly be more efficient than a batched one would be -- the one I was imagining in my original model -- though I'm not 100% sure.
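If you want to convince yourself of that "swapped order" point: nn.Linear stores its weight as an (out_features, in_features) matrix and computes x @ weight.T, so with bias switched off the layer call and the explicit multiplication give the same numbers. A quick check with illustrative sizes:

import torch
import torch.nn as nn

torch.manual_seed(0)
lin = nn.Linear(3, 4, bias=False)        # standing in for W_key
x = torch.randn(2, 6, 3)                 # b=2, 6 tokens, d_in=3

# nn.Linear keeps its weight as (out_features, in_features) -- (4, 3) here --
# and computes x @ weight.T, so the inputs end up as the first operand.
W_k = lin.weight.T                       # the conceptual d_in x d_out matrix
print(torch.allclose(lin(x), x @ W_k))   # True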
The next bit, however, is a move in the same direction as the idea that I had before
I looked at the code. Let's keep using the keys as the example. Firstly, we use
PyTorch's Tensor.view
method to reshape it:
keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
What that's doing is keeping the same underlying data, but creating a representation
of it with different dimensionality -- presumably the name view was chosen because
it's very similar to the way you can create a view in SQL to provide a different...
well, view, of your data without having to copy stuff into a new table.
So we're starting with keys as the result of the matrix multiplication, so
its shape is $b \times n \times d_{out}$. We're adding in an extra dimension;
if my reading of the docs is correct, this is done by splitting an existing dimension
up into slices, and which dimension to split is inferred from the parameters. In
this case, the output we want has the same first two dimensions as the input, so
it's the $d_{out}$ dimension that is split up into two: one for the number of heads $h$, and
one for the size of each head's output, $d_{head}$ (head_dim in the code).
So we wind up (as the parameters to view say) with something shaped like this:
$b \times n \times h \times d_{head}$
Next, we do a transpose of dimensions 1 and 2:
keys = keys.transpose(1, 2)
-- remember that dimensions are zero-indexed, so we get something shaped like this:
$b \times h \times n \times d_{head}$
That's exactly what we would have got if we had been using stacks of per-head weights for the parameters in the first place! So the next few steps of the code are almost exactly as they were with single-head attention.
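Here's a quick sanity check of that view-then-transpose dance, using made-up sizes rather than anything from the book:

import torch

b, num_tokens, num_heads, head_dim = 2, 6, 4, 3
d_out = num_heads * head_dim                      # 12

keys = torch.randn(b, num_tokens, d_out)          # shaped like the output of W_key(x)
keys = keys.view(b, num_tokens, num_heads, head_dim)
print(keys.shape)                                 # torch.Size([2, 6, 4, 3])
keys = keys.transpose(1, 2)
print(keys.shape)                                 # torch.Size([2, 4, 6, 3]) -- heads now act like a batch dimension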
Firstly we work out the dot product for all of the heads -- the same code as before, except that now that we have an extra dimension near the start for the $h$ heads, we need to transpose dimensions 2 and 3 rather than 1 and 2:
attn_scores = queries @ keys.transpose(2, 3)
I think it's worth keeping track of the shape of what we're working with here.
We just did a batched matrix multiplication of
- queries, which is $b \times h \times n \times d_{head}$, and
- the transposed keys, which is $b \times h \times d_{head} \times n$.
So what we have is going to be $b \times h \times n \times n$.
That makes sense -- it's a batch of results, each of which has a bunch of results for the heads, each of which is an $n \times n$ matrix saying how much attention we need to pay -- when looking at a particular token -- to each of the other tokens, from the perspective of that particular attention head.
Next, we do our mask for causal attention -- the only change here is that we've broken the mask out into a variable:
mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
attn_scores.masked_fill_(mask_bool, -torch.inf)
That doesn't change the shape of attn_scores, of course.
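In case a picture helps, here's what the boolean mask looks like for a four-token sequence -- True marks the future positions that get filled with -inf:

import torch

mask = torch.triu(torch.ones(4, 4), diagonal=1).bool()
print(mask)
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])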
Finally, we do our softmax and dropout exactly as before:
attn_weights = torch.softmax(
attn_scores / keys.shape[-1] ** 0.5,
dim=-1
)
attn_weights = self.dropout(attn_weights)
Again the shape is unchanged: attn_weights is still $b \times h \times n \times n$. (One detail worth noticing: keys.shape[-1] is now $d_{head}$ rather than $d_{out}$, so each head's scores are scaled by the square root of its own dimension.)
Now it gets a little more involved. Previously we calculated our single-head context vector like this:
context_vec = attn_weights @ values
We do the same thing here, but need to do a transpose afterwards:
context_vec = (attn_weights @ values).transpose(1, 2)
So let's think about that batched matrix multiplication.
We have attn_weights,
which has the same shape as attn_scores: $b \times h \times n \times n$.
We're multiplying that by values, which is $b \times h \times n \times d_{head}$.
So our result is going to be $b \times h \times n \times d_{head}$.
What we're doing is shuffling it around so that it's $b \times n \times h \times d_{head}$.
Why? It's because we're about to reshape the tensor to combine the outputs from
the heads.
Earlier on, we had a keys matrix that we wanted to split
up, and we did that with a view. The dimension that we split was turned into
two new dimensions that were next to each other in the shape, and because we wanted
to have them separated a bit, we had to do a transpose.
So we can see this as the opposite of the transpose we did back there, getting stuff back into a shape where we can combine things again; we're going to want to merge the last two dimensions together to get a single one of $d_{out}$.
That's what we do with our next step:
context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
The call to contiguous
isn't explained, but the docs say that it's to produce a "contiguous in memory tensor"
with the same data. It feels like a carefully-positioned thing to copy data
around so that it's in an efficient shape in memory for future use -- kind of tidying
up the mess that might have been left around the underlying in-memory representation
by all of this splitting and transposing and so on.
But once that's done, that call to view just reshapes it to $b \times n \times d_{out}$,
essentially merging those last two dimensions -- which is basically the same as the
concatenation that we did in the original, simple code.
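There's actually a concrete reason that call has to be there: view needs the tensor's underlying storage to be contiguous, and transpose only changes the stride metadata, leaving the data non-contiguous -- so without contiguous() the view call raises an error. A quick demonstration with made-up sizes:

import torch

b, num_heads, num_tokens, head_dim = 2, 4, 6, 3
context_vec = torch.randn(b, num_heads, num_tokens, head_dim).transpose(1, 2)
print(context_vec.is_contiguous())   # False -- transpose only rearranged the strides

try:
    context_vec.view(b, num_tokens, num_heads * head_dim)
except RuntimeError as err:
    print(err)                       # complains that the size and stride are incompatible with view

flat = context_vec.contiguous().view(b, num_tokens, num_heads * head_dim)
print(flat.shape)                    # torch.Size([2, 6, 12])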
A linear layer after the combination
There's one extra thing that was added to this beyond the original, simple implementation
of multi-head attention -- after we've essentially created context vectors that are
the different context vectors from each head concatenated, we run them through
a single linear layer. It's defined in __init__:
self.out_proj = nn.Linear(d_out, d_out)
So, a single neural network layer with bias (the bias kwarg is True by default),
with both input and output dimensions the same -- our desired size for the output
context vectors.
Then we run our concatenated context vectors through it at the end:
context_vec = self.out_proj(context_vec)
Raschka doesn't go into details as to why we do this, but TBF just whacking the context vectors together into a larger one felt a bit kludgy, and adding an extra layer where the model can learn better ways of doing it that make sense in the context of what it's doing seems like a sensible idea.
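Pulling all of the fragments above together, here's my reconstruction of the complete module, plus a quick shape check. It should be very close to Raschka's listing, but it's assembled from the pieces in this post rather than copied, so treat it as a sketch and check the book (or its GitHub repo) for the canonical version:

import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        # project into query/key/value space for all heads at once: (b, num_tokens, d_out)
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # split the last dimension into per-head chunks, then move the head
        # dimension next to the batch one: (b, num_heads, num_tokens, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)

        # attention scores per head: (b, num_heads, num_tokens, num_tokens)
        attn_scores = queries @ keys.transpose(2, 3)
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1] ** 0.5, dim=-1
        )
        attn_weights = self.dropout(attn_weights)

        # per-head context vectors, moved back to (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)
        # merge the heads back into a single d_out-wide dimension
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        return self.out_proj(context_vec)

# quick shape check with illustrative sizes
torch.manual_seed(123)
batch = torch.randn(2, 6, 3)     # b=2, 6 tokens, d_in=3
mha = MultiHeadAttention(d_in=3, d_out=4, context_length=6, dropout=0.0, num_heads=2)
print(mha(batch).shape)          # torch.Size([2, 6, 4])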
Phew!
That was quite a lot of work to go through, but I don't think there's anything hugely tricky in there. I do think that it would have been easier to get a handle on it if we'd been using order-3 weights rather than the single large matrices that needed to be split apart -- but on the other hand, that didn't require much work to understand, and we’d still have needed the reshaping at the end to combine the context vectors.
So, it's a small price to pay, and we do get better weight initialisation.
Perhaps also the single (large) matrix multiplication it requires is more efficient
than a batched one over per-head weights -- though there is also the cost of those
reshapes and transpositions to balance off against that. As I understand it,
PyTorch does a lot of that kind of thing just by changing metadata about the tensor
rather than by shuffling bytes around in memory, but the call to contiguous near
the end suggests that there's a limit to how much can be achieved with that.
There's also that qkv_bias kwarg in the __init__ for the module, which if set
to True would add bias terms to the three weight matrices. I'm sure it will be used
at some point later in the book, and it would be a bit messy (though not hugely
tricky) to add it to a setup where we were using nn.Parameter.
But anyway: having reached this point, we have a fully-functioning multi-head attention module, with configurable input and output dimension counts, context length, dropout, and number of heads. After four months (!), Chapter 3 is done.
Raschka tweeted that Chapter 3 "might be the most technical one (like building the engine of a car) but it gets easier from here!" And that's good to know. Because now we have this powerful thing that can turn input vectors into context vectors and can learn the best way to do that.
But how do we get from that to something that achieves the underlying goal of our LLM: to predict the next token in a sequence given the tokens so far?
Here's a link to the next post in this series.
¹ I also noticed that the previous simple example didn't do the same thing -- instead, each attention head had an output with context vectors that were d_out in length, so the combined concatenated ones were d_out * num_heads. Still, it was just illustrative code to kickstart intuition into multi-head attention, so that doesn't really matter. ↩