Writing an LLM from scratch, part 12 -- multi-head attention

Posted on 21 April 2025 in AI, Python, LLM from scratch, TIL deep dives

In this post, I'm wrapping up chapter 3 of Sebastian Raschka's "Build a Large Language Model (from Scratch)". Last time I covered batches, which -- somewhat to my disappointment -- didn't involve completely new (to me) high-order tensor multiplication, but instead relied on batched and broadcast matrix multiplication. That was still interesting on its own, however, and at least was easy enough to grasp that I didn't disappear down a mathematical rabbit hole.

The last section of chapter 3 is about multi-head attention, and while it wasn't too hard to understand, there were a couple of oddities that I want to write down -- as always, primarily to get it all straight in my own head, but also just in case it's useful for anyone else.

So, the first question is, what is multi-head attention?

What multi-head attention means

What we've gone through so far in the book -- a full attention mechanism that generates context vectors from a batch of input sequences (each of which is a list of input embeddings), by using the basic attention mechanism calculations plus dropout and a causal mask -- is a single attention head. Over its training, it will learn certain ways of working out those context vectors that lead to "good" results. In a way, it has "opinions" about how to pay attention to other tokens when looking at a specific token. It will also learn values for the Wv weights that lead to "good" output representations.

If having one of these is good, then running a bunch in parallel clearly has benefits. Each one can learn different ways to pay attention to the inputs, and then their resulting context vectors can be combined to get a richer high-level result.

For example, in

The fat cat sat on the mat

One attention head might learn how articles like "the" relate to the nouns they are near to, another might learn how the verb "sat" interacts with the subject "cat" and the prepositional phrase "on the mat", another might learn how "fat" modifies "cat", and so on. Their combined results would be a richer representation of what the different parts of the sentence mean in context.

So, what multi-head attention means is that we are running our input sequence through a bunch of n separate heads, each with its own trainable attention weights -- Wk, Wq and Wv -- then we're somehow munging the resulting n sets of context vectors together to get a combined result, and that's what is fed forward to the next part of the LLM.

Now, it's pretty clear that something involving batched matrix multiplication is likely to be the best way to do this kind of thing, to run all of the heads in parallel, but Raschka, sensibly enough, starts off with an easier-to-understand example.

The simple first example

Just like he did with the original attention mechanism calculations, he starts off by doing the multi-head attention "by hand" in Python code, without trying to use complex maths. In a way I think he's aiming to kick-start an intuition about how it all fits together.

He writes a PyTorch nn.Module that has a list of CausalAttention objects -- each a single attention head -- as a field, and then in its forward method just runs through them one after another to get a set of context vectors.

He then uses torch.cat to concatenate them on the last dimension -- that is, the columns, which makes sense. The output of a single attention head is b×ntokens×dout, and that last dimension is the one we're concatenating them on.

That means that the result is b×ntokens×(dout·nheads), and the aggregate context vector for a particular token is just the vector produced by the first head in the first dout columns, then the vector from the second head in the second dout columns and so on. Not a particularly sophisticated way to combine them, but as a starting point it makes sense.
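
A minimal sketch of what that wrapper might look like (my own reconstruction of the idea, reusing the CausalAttention class shown further down -- the class name here is made up):

class SimpleMultiHeadAttention(nn.Module):
    # One independent CausalAttention module per head; their context vectors
    # are simply concatenated on the last dimension.
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
             for _ in range(num_heads)]
        )

    def forward(self, x):
        # Each head produces b x n_tokens x d_out, so the result is
        # b x n_tokens x (d_out * num_heads).
        return torch.cat([head(x) for head in self.heads], dim=-1)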

So the next step is to start leaning on PyTorch's optimised batched matrix multiplication so that we can run all of the heads in parallel.

Multi-head attention with batched matrix maths -- but with a twist!

Before I looked at the code, I had a pretty clear idea in my head what it would look like. Unfortunately it wasn't quite right, but I think it's actually a pretty good starting point for understanding what the actual code does.

Our single-head attention class looked like this:

import torch
import torch.nn as nn


class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose(1, 2)
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens],
            -torch.inf
        )
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1] ** 0.5,
            dim=-1
        )
        attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ values
        return context_vec
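
To see the shapes in action, here's a quick usage example of my own (the numbers are arbitrary):

torch.manual_seed(123)
batch = torch.randn(2, 6, 3)    # b=2, n_tokens=6, d_in=3
ca = CausalAttention(d_in=3, d_out=4, context_length=6, dropout=0.0)
print(ca(batch).shape)          # torch.Size([2, 6, 4]) -- b x n_tokens x d_out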

Now, all we needed for multi-head attention -- or at least, so it seemed to me -- was to add an extra dimension to W_query, W_key, and W_value. They're matrices, so this would just make them order-3 tensors; a stack of matrices, with one matrix in the stack for each head.

The rest of the code should just work as-is, and by running it we'd wind up with a four-dimensional tensor of results -- one item per head in one dimension, one per batch item in another, and then the remaining two making up the matrices.

Or in other words, if we have a batch size b, a number of attention heads nheads, ntokens input tokens in each sequence, and a desired output context vector length of dout, then we'd have an order-4 tensor like this:

b×nheads×ntokens×dout

We'd then do something with that to collapse the per-head dimension -- perhaps just a concatenation like in the previous example -- and we'd be done.

The code changes would be pretty trivial, just pass in a num_heads to the __init__, then initialise the weights something like this:

self.W_query = nn.Linear(num_heads, d_in, d_out, bias=qkv_bias)

...and do whatever we wanted for the collapse at the end of the forward method.

Now, if we were still using nn.Parameter for our weights, that would have worked! The combining of the context vectors at the end might have been a bit fiddlier in reality than it was in my head, but I think it would have been reasonably easy to understand as an extension of the last, single-head version.
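
Concretely, here's roughly what I had in mind -- very much a sketch of my own, not anything from the book, with the causal mask and dropout left out for brevity:

class ImaginaryMultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, num_heads):
        super().__init__()
        # One d_in x d_out weight matrix per head, stacked into order-3 tensors.
        self.W_query = nn.Parameter(torch.rand(num_heads, d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(num_heads, d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(num_heads, d_in, d_out))

    def forward(self, x):
        # x is b x n_tokens x d_in; adding a singleton head dimension lets it
        # broadcast against the num_heads x d_in x d_out weight stacks.
        x = x.unsqueeze(1)
        queries = x @ self.W_query    # b x num_heads x n_tokens x d_out
        keys = x @ self.W_key
        values = x @ self.W_value

        attn_scores = queries @ keys.transpose(2, 3)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        return attn_weights @ values  # b x num_heads x n_tokens x d_out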

However, we're not using nn.Parameter: we're using nn.Linear, because -- Raschka explained earlier -- the weight initialisation is better for our purposes.

This causes a problem, because nn.Linear is designed to be a layer in a regular neural network. Unlike nn.Parameter, which wraps a data Tensor of any dimensionality, an nn.Linear only has two dimensions -- the number of input features and the number of output features. So the line above in my imaginary code, where we tried to create an nn.Linear with three dimensions, would be an error.
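
You can see this by poking at an nn.Linear directly; its weight is always a single two-dimensional matrix (stored as out_features x in_features):

layer = nn.Linear(3, 4, bias=False)
print(layer.weight.shape)    # torch.Size([4, 3]) -- out_features x in_features, and that's it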

So while I still think the idea was quite elegant, it doesn’t work in practice because of our use of nn.Linear.

The way that the solution to this works is quite clever, and (for all I know) is actually more efficient than the solution I came up with -- I think it would be a mistake to see it as just a workaround for the choice of nn.Linear, though I'll need to dig in a bit further to see if that's true.

Let's work through it step by step.

We do indeed have a new num_heads parameter:

def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):

Now, d_out is the dimensionality we want for the context vectors that this multi-head attention class will spit out at the end. So if we were to just concatenate the context vectors from each head, like the original simple implementation did, each head would need to provide a context vector of length d_out / num_heads -- and while we're not just doing concatenation (I'll come back to that later), we do size them appropriately for that:

self.head_dim = d_out // num_heads

Note the explicit integer division, which makes sense.1
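
Just to put numbers on that (mine, not the book's): with a d_out of 12 and four heads, each head gets three columns -- and it only works out cleanly if d_out divides exactly, which is the kind of thing I'd be tempted to assert:

d_out, num_heads = 12, 4
head_dim = d_out // num_heads
print(head_dim)                       # 3 -- four heads of three columns each
assert d_out % num_heads == 0, "d_out must be divisible by num_heads"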

But that's not used when constructing the weights; instead, we just do this:

self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)

...and likewise for W_key and W_value.

For me, the easiest way to visualise these is that for each of them, we have the weights for all of our heads side-by-side, in adjacent groups of columns. The first head_dim columns are for the first head, the next head_dim for the second, and so on.

Moving on to the forward method, we just project our inputs into key, query and value space exactly like we did before, for example:

keys = self.W_key(x)

That makes sense -- because of the way that matrix multiplication works, the columns in the second argument map directly to the columns in the result. (Remember that this Python code is equivalent to XWkey, because nn.Linear swaps round the order, so the second argument to the underlying matrix multiplication is Wkey.)

So because we can see W_key as being the weights for all of the attention heads in a single matrix side-by-side, the result of that multiplication is the projection of the inputs into key space for each head, also in a single matrix side-by-side.
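
One way to convince yourself of that -- a little check of my own, not from the book, assuming qkv_bias=False -- is that the first head_dim columns of the combined projection are exactly what you'd get by projecting with just the first head's slice of the weight matrix:

torch.manual_seed(123)
x = torch.randn(2, 6, 3)              # b=2, n_tokens=6, d_in=3
W_key = nn.Linear(3, 4, bias=False)   # d_out=4; with two heads, head_dim=2
keys = W_key(x)                       # 2 x 6 x 4 -- both heads' keys side by side

head_dim = 2
# nn.Linear stores its weight as out_features x in_features, so the first
# head's columns are the first head_dim rows of W_key.weight, transposed.
first_head_keys = x @ W_key.weight[:head_dim, :].T
print(torch.allclose(keys[..., :head_dim], first_head_keys))   # True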

This single matrix multiplication feels to me like it might possibly be more efficient than a batched one would be -- the one I was imagining in my original model -- though I'm not 100% sure.

The next bit, however, is a move in the same direction as the idea that I had before I looked at the code. Let's keep using the keys as the example. Firstly, we use PyTorch's Tensor.view method to reshape it:

keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)

What that's doing is keeping the same underlying data, but creating a representation of it with different dimensionality -- presumably the name view was chosen because it's very similar to the way you can create a view in SQL to provide a different... well, view, of your data without having to copy stuff into a new table.

We're starting with keys as the result of that matrix multiplication, so its shape is b×ntokens×dout. We're adding in an extra dimension; if my reading of the docs is correct, this is done by splitting an existing dimension up into slices, and which dimension to split is inferred from the parameters. In this case, the output we want has the same first two dimensions as the input, so it's the dout one that is split into two: one for the number of heads, nheads, and one for the size of each head's output, dhead.

So we wind up (as the parameters say) with something shaped like this:

b×ntokens×nheads×dhead

Next, we do a transpose of dimensions 1 and 2:

keys = keys.transpose(1, 2)

-- remember that dimensions are zero-indexed, so we get something shaped like this:

b×nheads×ntokens×dhead
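
(Checking that with some throwaway numbers of my own:)

b, num_tokens, d_out = 2, 6, 4
num_heads, head_dim = 2, 2

keys = torch.randn(b, num_tokens, d_out)               # b x n_tokens x d_out
keys = keys.view(b, num_tokens, num_heads, head_dim)   # b x n_tokens x num_heads x head_dim
keys = keys.transpose(1, 2)
print(keys.shape)    # torch.Size([2, 2, 6, 2]) -- b x num_heads x n_tokens x head_dim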

That's exactly what we would have got if we had been using stacks of per-head weights for the Wkey parameters in the first place! So the next few steps of the code are almost exactly as they were with single-head attention.

Firstly we work out the dot products for all of the heads -- the same code as before, except that, because there's now an extra nheads dimension after the batch one, we need to transpose dimensions 2 and 3 rather than 1 and 2:

attn_scores = queries @ keys.transpose(2, 3)

I think it's worth keeping track of the shape of what we're working with here.

We just did a batched matrix multiplication of queries, which is b×nheads×ntokens×dhead, by keys.transpose(2, 3), which is b×nheads×dhead×ntokens.

So what we have is going to be b×nheads×ntokens×ntokens.

That makes sense -- it's a batch of results, each of which has a bunch of results for the heads, each of which is an ntokens×ntokens matrix saying how much attention we need to pay -- when looking at a particular token -- to each of the other tokens, from the perspective of that particular attention head.
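
(Again, with throwaway tensors of my own, just to check the shapes:)

b, num_heads, num_tokens, head_dim = 2, 2, 6, 2
queries = torch.randn(b, num_heads, num_tokens, head_dim)
keys = torch.randn(b, num_heads, num_tokens, head_dim)

attn_scores = queries @ keys.transpose(2, 3)
print(attn_scores.shape)    # torch.Size([2, 2, 6, 6]) -- b x num_heads x n_tokens x n_tokens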

Next, we do our mask for causal attention -- the only change here is that we've broken the mask out into a variable:

mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
attn_scores.masked_fill_(mask_bool, -torch.inf)

That doesn't change the shape of attn_scores, of course.

Finally, we do our softmax and dropout exactly as before:

attn_weights = torch.softmax(
    attn_scores / keys.shape[-1] ** 0.5,
    dim=-1
)
attn_weights = self.dropout(attn_weights)

Again, the shape is unchanged -- attn_weights has the same shape as attn_scores.

Now it gets a little more involved. Previously we calculated our single-head context vector like this:

context_vec = attn_weights @ values

We do the same thing here, but need to do a transpose afterwards:

context_vec = (attn_weights @ values).transpose(1, 2)

So let's think about that batched matrix multiplication.

We have attn_weights, which has the same shape as attn_scores: b×nheads×ntokens×ntokens.

We're multiplying that by values, which is b×nheads×ntokens×dhead.

So our result is going to be b×nheads×ntokens×dhead.

What we're doing is shuffling it around so that it's b×ntokens×nheads×dhead.

Why? It's because we're about to reshape the tensor to combine the outputs from the heads. Earlier on, we had a keys matrix that we wanted to split up, and we did that with a view. The dimension that we split was turned into two new dimensions that were next to each other in the shape, and because we wanted the nheads dimension up next to the batch one (so that the batched matrix multiplications would treat each head like an extra batch item), we had to do a transpose.

So we can see this as the opposite of the transpose we did back there, getting stuff back into a shape where we can combine things again; we're going to want to merge the last two dimensions together to get a single one of dout=nheads·dhead.

That's what we do with our next step:

context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)

The call to contiguous isn't explained, but the docs say that it's to produce a "contiguous in memory tensor" with the same data. It feels like a carefully-positioned thing to copy data around so that it's in an efficient shape in memory for future use -- kind of tidying up the mess that might have been left around the underlying in-memory representation by all of this splitting and transposing and so on.

But once that's done, that call to view just reshapes it to b×ntokens×dout, essentially merging those last two dimensions -- which is basically the same as the concatenation that we did in the original, simple code.
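
(Tracing those last few lines with throwaway tensors again, and checking the contiguity point along the way:)

b, num_heads, num_tokens, head_dim = 2, 2, 6, 2
d_out = num_heads * head_dim

context_vec = torch.randn(b, num_heads, num_tokens, head_dim)
context_vec = context_vec.transpose(1, 2)   # b x n_tokens x num_heads x head_dim
print(context_vec.is_contiguous())          # False -- the transpose only changed metadata
# Without the .contiguous() call, the .view below raises a RuntimeError here.
context_vec = context_vec.contiguous().view(b, num_tokens, d_out)
print(context_vec.shape)                    # torch.Size([2, 6, 4]) -- b x n_tokens x d_out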

A linear layer after the combination

There's one extra thing that was added to this beyond the original, simple implementation of multi-head attention -- after we've essentially created context vectors that are the different context vectors from each head concatenated, we run them through a single linear layer. It's defined in __init__:

self.out_proj = nn.Linear(d_out, d_out)

So, a single neural network layer with bias (the bias kwarg is True by default), with both input and output dimensions the same -- our desired size for the output context vectors.

Then we run our concatenated context vectors through it at the end:

context_vec = self.out_proj(context_vec)

Raschka doesn't go into details as to why we do this, but TBF just whacking the context vectors together into a larger one felt a bit kludgy, and adding an extra layer where the model can learn better ways of doing it that make sense in the context of what it's doing seems like a sensible idea.
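
Putting all of the pieces from this post together, the whole module looks something like this -- my own assembly of the snippets above rather than the book's exact listing, so treat it as a sketch:

class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        # My addition: make sure the heads divide d_out evenly.
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        # Project into key/query/value space for all heads at once...
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # ...then split out a per-head dimension and move it next to the batch one.
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)

        # Attention scores, causal mask, softmax and dropout, per head.
        attn_scores = queries @ keys.transpose(2, 3)
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask_bool, -torch.inf)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Per-head context vectors, merged back into b x n_tokens x d_out.
        context_vec = (attn_weights @ values).transpose(1, 2)
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        return self.out_proj(context_vec)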

Phew!

That was quite a lot of work to go through, but I don't think there's anything hugely tricky in there. I do think that it would have been easier to get a handle on it if we'd been using order-3 weights rather than the single large matrices that needed to be split apart -- but on the other hand, that didn't require much work to understand, and we’d still have needed the reshaping at the end to combine the context vectors.

So, it's a small price to pay, and we do get better weight initialisation. Perhaps also the single (large) matrix multiplication it requires is more efficient than a batched one over per-head weights -- though there is also the cost of those reshapes and transpositions to balance off against that. As I understand it, PyTorch does a lot of that kind of thing just by changing metadata about the tensor rather than by shuffling bytes around in memory, but the call to contiguous near the end suggests that there's a limit to how much can be achieved with that.

There's also that qkv_bias kwarg in the __init__ for the module, which if set to True would add bias terms to the three weight matrices. I'm sure it will be used at some point later in the book, and it would be a bit messy (though not hugely tricky) to add it to a setup where we were using nn.Parameter.

But anyway: having reached this point, we have a fully-functioning multi-head attention module, with configurable input and output dimension counts, context length, dropout, and number of heads. After four months (!), Chapter 3 is done.

Raschka tweeted that Chapter 3 "might be the most technical one (like building the engine of a car) but it gets easier from here!" And that's good to know. Because now we have this powerful thing that can turn input vectors into context vectors and can learn the best way to do that.

But how do we get from that to something that achieves the underlying goal of our LLM: to predict the next token in a sequence given the tokens so far?

Here's a link to the next post in this series.


  1. I also noticed that the previous simple example didn't do the same thing -- instead, each attention head had an output with context vectors that were d_out in length, so the combined concatenated ones were d_out * num_heads. Still, it was just illustrative code to kickstart intuition into multi-head attention, so that doesn't really matter.