Writing an LLM from scratch, part 11 -- batches
I'm still working through chapter 3 of Sebastian Raschka's "Build a Large Language Model (from Scratch)". Last time I covered dropout, which was nice and easy.
This time I'm moving on to batches. Batches allow you to run a bunch of different input sequences through an LLM at the same time, generating outputs for each in parallel, which can make training and inference more efficient -- if you've read my series on fine-tuning LLMs you'll probably remember I spent a lot of time trying to find exactly the right batch sizes for speed and for the memory I had available.
This was something I was originally planning to go into in some depth, because there's some fundamental maths there that I really wanted to understand better. But the more time I spent reading into it, the more of a rabbit hole it became -- and I had decided on a strict "no side quests" rule when working through this book.
So in this post I'll just present the basic stuff, the stuff that was necessary for me to feel comfortable with the code and the operations described in the book. A full treatment of linear algebra and higher-order tensor operations will, sadly, have to wait for another day...
Let's start off with the fundamental problem of why batches are a bit tricky in an LLM.
Batches in normal neural networks
If you're not familiar with batches in normal neural networks, check out my blog post on basic matrix maths for neural networks. This bit is just a summary of part of that post.
With a traditional neural network, we have a bunch of inputs, we feed them through a number of layers, and get a number of outputs. That means that we're starting with an input vector, the first layer takes that vector and transforms it into a vector of activations, the second takes that activation vector and transforms it into a different activation vector, and so on until we have the last activation vector, which is the outputs.
The fact that a single input "run" of the network operates only on vectors means that we can "stack" those vectors into a matrix, and run the network on a bunch of different sets of inputs in parallel. That's batching.
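As a concrete sketch of that stacking -- the "layer" here is just a single weight matrix, and the sizes are made-up example values, not anything from the book:

import torch

# A "layer" is just a weight matrix: 4 inputs -> 3 outputs (sizes made up).
W = torch.randn(4, 3)

# One input is a vector of 4 numbers...
x_single = torch.randn(4)
out_single = x_single @ W     # shape: (3,)

# ...and a batch of 8 inputs is just those vectors stacked into a matrix.
x_batch = torch.randn(8, 4)
out_batch = x_batch @ W       # shape: (8, 3) -- one output row per input

print(out_single.shape, out_batch.shape)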
Batches in LLMs
Now, with LLMs we don't have that shortcut, or at least so it looked to me initially. Remember, we take our input tokens, and convert each one into an input embedding, which is a vector. The data -- for one single input sequence, no batches -- that is fed into the attention mechanism is a matrix already. It has one row per token, and the columns in each row make up the token's input embedding.
We then perform matrix maths to work out our attention weights -- another matrix with one row per input token, with each column saying how much attention to pay to each other token -- and then use that to generate our context vectors, which are yet another matrix with one row per token, each one somehow encoding the meaning of that token in the current context.
So if we're already using matrices for a single input sequence, we clearly need some kind of higher-order thing to handle batches, where we're handling multiple input sequences -- multiple matrices -- at the same time.
Tensors to the rescue
Tensors are the answer1. They're not a big stretch conceptually, especially if you're a coder, because they're pretty much like arrays. Mathematically, a scalar -- ie. a number -- is an order-zero tensor, a vector is an order-one tensor (a list, or a one-dimensional array), a matrix is an order-two tensor (a list of lists, each of which is the same length -- or in other words, a two-dimensional array) -- and it's not hard to see that you could also have order-three tensors, order-four tensors, and so on, just like you could keep on adding dimensions to an array.
I've seen code -- rarely good code, but sometimes -- with lines like this:
this_item = records[i][j][k][l][m];
...and there's no deep conceptual problem with that, even if the actual meaning in context might be a tad tricky. It's just indexing into a multi-dimensional array.
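In PyTorch terms, the progression looks like this (the shapes are arbitrary, just for illustration):

import torch

scalar = torch.tensor(3.14)             # order-0 tensor: shape ()
vector = torch.tensor([1.0, 2.0, 3.0])  # order-1 tensor: shape (3,)
matrix = torch.ones(2, 3)               # order-2 tensor: shape (2, 3)
cube   = torch.zeros(4, 2, 3)           # order-3 tensor: shape (4, 2, 3)

print(scalar.ndim, vector.ndim, matrix.ndim, cube.ndim)  # 0 1 2 3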
The blocker for me was what operations like multiplication mean on higher-order tensors. There is a clear definition of matrix multiplication:
If you have a matrix $A$ which has $m$ rows and $n$ columns -- that is, it's an $m \times n$ matrix -- you can multiply it by any other matrix $B$, so long as $B$ has $n$ rows -- that is, the number of columns in the first matrix in a matrix multiplication has to equal the number of rows in the second one. If we say that $B$ is an $n \times p$ matrix, the result of the matrix multiplication is a new $m \times p$ matrix, with $m$ rows and $p$ columns -- that is, it has the same number of rows as the first one, and the same number of columns as the second one.
To work out the values in the resulting matrix $C = AB$, we say that $c_{ij}$ -- that is, the element at row $i$, column $j$ -- is the dot product of row $i$ in the first matrix, taken as a vector, with column $j$ in the second matrix, also considered as a vector.
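Written out as a formula, using the names above, that's:

$$c_{ij} = \sum_{k=1}^{n} a_{ik} \, b_{kj} \qquad \text{for } i = 1, \dots, m \text{ and } j = 1, \dots, p$$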
But what happens if you want to multiply an order-three tensor by (say) an order-two one? What rules apply (like the columns-in-the-first, rows-in-the-second one for matrix multiplication) to say which dimensions have to match?
This was where I fell into a rabbit hole; Claude was an appalling enabler for this, giving me loads of fascinating information. However, after a while I realised that googling about Einstein summation notation, while great fun, was exactly the kind of side quest I'd sworn not to fall into with this book.
So it was time to focus on what the book says rather than reading around.
Follow the code
I think that the best way to approach it is to start off by showing Raschka's code for self-attention without batching, and then to describe just the differences.
Unfortunately, the most recent code listing in the book came before the sections on causal attention and dropout, so it's pretty simple:
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1] ** 0.5,
            dim=-1
        )
        context_vec = attn_weights @ values
        return context_vec
It's pretty clearly an expression in PyTorch of the maths that came out of the section on self-attention. The parameter d_in is the length of the input vectors, and d_out is the length of the context vectors we want it to produce.
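Just to make the shapes concrete, here's the kind of way you might call it -- the sizes are example values I've picked for illustration, not anything specific from the book:

import torch

torch.manual_seed(42)
sa = SelfAttention(d_in=3, d_out=2)

x = torch.randn(6, 3)       # one sequence: 6 tokens, each a 3-dimensional input embedding
context_vecs = sa(x)
print(context_vecs.shape)   # torch.Size([6, 2]) -- one 2-dimensional context vector per token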
There are two oddities; both are due to the use of the nn.Linear class, which I kind of glossed over in my earlier post, and probably should not have -- I suspect I was just a bit overwhelmed by finally getting my head around self-attention! Raschka's initial implementation used the nn.Parameter class for the three weight matrices. He then explained that nn.Linear has a better weight-initialisation algorithm for our purposes, and switched to using it, so I think it's worth looking at the changes that required.
Firstly, there's that qkv_bias kwarg that has crept in; the reason for that is that the nn.Linear class is designed for use in "classical" neural networks. In that use case it would control whether or not we have a bias for each neuron (which we normally would, but might not want to). Here we're just trying to mirror the nn.Parameter that Raschka used previously, so we're setting it to False by default. It's kind of interesting that he's keeping the option open to set it to True later by putting it into a kwarg for the __init__, but if that's relevant I'm sure it will pop up later.
The other oddity is that we're doing this:
keys = self.W_key(x)
...and likewise for the queries and values. But in the mathematical treatment we had $\text{keys} = x W_k$ -- note that they're in the opposite order. That's just a quirk of PyTorch; the actual calculation being performed is exactly what we have in the maths. Indeed, in the previous code listing, where we used nn.Parameter, we did indeed have the code
keys = x @ self.W_key
The important thing is that the PyTorch syntax for nn.Linear reverses the order, but it's the same matrix multiplication happening2.
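A quick way to convince yourself of that, if you want one (a throwaway sketch, not code from the book):

import torch
import torch.nn as nn

torch.manual_seed(0)
linear = nn.Linear(3, 2, bias=False)
x = torch.randn(6, 3)

# nn.Linear stores its weights as a (d_out, d_in) matrix, so calling it
# computes x @ W.T -- the same multiplication as in the maths, with PyTorch
# just holding the weights transposed internally (see footnote 2).
assert torch.allclose(linear(x), x @ linear.weight.T)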
So, that's the class with basic self-attention. Now let's look at the equivalent class with causal attention, dropout and batching:
class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.transpose(1, 2)
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens],
            -torch.inf
        )
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1] ** 0.5,
            dim=-1
        )
        attn_weights = self.dropout(attn_weights)
        context_vec = attn_weights @ values
        return context_vec
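Before breaking down the differences, here's the kind of way you might exercise it, just to fix the shapes in our heads -- again, the sizes are example values of mine rather than the book's:

import torch

torch.manual_seed(123)

# A batch of 2 sequences, each of 6 tokens with 3-dimensional embeddings.
batch = torch.randn(2, 6, 3)

ca = CausalAttention(d_in=3, d_out=2, context_length=6, dropout=0.0)
context_vecs = ca(batch)
print(context_vecs.shape)   # torch.Size([2, 6, 2]) -- batch x num_tokens x d_out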
Although Raschka introduces the concepts of causal attention, dropout and batching in that order3, when breaking down the differences between the old code and the new, I think it's worth tweaking that. Let's look at dropout first.
Dropout
This one is really easy -- we construct an nn.Dropout object, passing in the dropout level (from 0 to 1) that we got as a parameter when we create the CausalAttention object:
self.dropout = nn.Dropout(dropout)
...and then we apply it to the attention weights once we've calculated them:
attn_weights = self.dropout(attn_weights)
Simple enough -- it's really nifty that PyTorch makes that so easy (though, as I noted in my post on dropout, applying it to the weights is unusual and applying it to the attention scores pre-softmax is apparently more common).
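If you want to see what nn.Dropout actually does to a matrix, here's a standalone sketch (not from the book):

import torch
import torch.nn as nn

torch.manual_seed(0)
dropout = nn.Dropout(0.5)
weights = torch.ones(4, 4)

dropout.train()            # dropout is only active in training mode
print(dropout(weights))    # roughly half the entries zeroed, the survivors scaled up by 1 / (1 - 0.5) = 2

dropout.eval()
print(dropout(weights))    # in eval mode it's a no-op: still all ones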
Batches
Now we're into the core of it. In our unbatched example, the first bit of code, the input to the forward function, x, was a matrix; its size was num_tokens rows by d_in (the length of the input vectors) columns.
Now we have batches.
If we look at the start of the forward function, we're extracting x's dimensions:
b, num_tokens, d_in = x.shape
So that makes it clear that it's a $b \times \text{num\_tokens} \times d_{in}$ order-three tensor (where $b$ is our batch size).
My mental model of this (which will of course become useless as soon as we need to deal with order-four tensors) is that it's a stack of matrices, one on top of the other.
So what other changes do we need to handle this higher-order tensor? Well, as far as I can understand it, almost nothing! The multiplications to project the inputs into key, query and value space are exactly the same as they were in the unbatched version, eg.
keys = self.W_key(x)
Then there's the transpose, where there is a small change -- instead of having
attn_scores = queries @ keys.T
...we have
attn_scores = queries @ keys.transpose(1, 2)
...but the only obvious change there is that we're saying "transpose the tensor in the 1st and 2nd dimensions, leaving the 0th untouched" -- which makes sense, we're just transposing every matrix in the "stack" and leaving things unchanged between stacked matrices.
So what's going on in the multiplications? Let's start with the easy ones:
keys = self.W_key(x)
queries = self.W_query(x)
values = self.W_value(x)
Taking the first one, we're doing $\text{keys} = x W_k$ -- remember that the ordering in the PyTorch code is reversed compared to mathematical notation because we're using nn.Linear, so self.W_key(x) is $x W_k$. So $x$ is a third-order tensor, and $W_k$ is a matrix.
What does it mean to multiply two things of such different shapes -- a third-order and a second-order tensor? I was puzzling about this for a bit until I realised that it's just a broadcast.
Just like we can use broadcasting to add a vector of bias terms to our matrix of batched activations in a traditional neural network (see my post on matrix maths for neural networks in practice for the details), if we multiply a 3-tensor by a 2-tensor, we're broadcasting the matrix multiplication across the stack of matrices that makes up the 3-tensor. The last dimension of the 3-tensor (the number of columns in each matrix in the "stack") is $d_{in}$, which is the same as the number of rows in the 2-tensor, so it's compatible.
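Here's a quick sketch of that broadcast, with made-up sizes -- multiplying the whole stack in one go gives the same answer as multiplying each matrix in the stack separately:

import torch

b, num_tokens, d_in, d_out = 2, 6, 3, 4    # example sizes
x = torch.randn(b, num_tokens, d_in)       # a "stack" of b matrices
W = torch.randn(d_in, d_out)               # one ordinary matrix

batched = x @ W                            # shape: (b, num_tokens, d_out)
looped = torch.stack([x[i] @ W for i in range(b)])   # same thing, one matrix at a time

assert torch.allclose(batched, looped)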
So, keys, queries and values are all 3-tensors, with dimensions $b \times \text{num\_tokens} \times d_{out}$.
Now how about the attention scores? Here we're working with 3-tensors on both sides of the multiplication:
attn_scores = queries @ keys.transpose(1, 2)
Let's look at the dimensionality. queries is, as we said, $b \times \text{num\_tokens} \times d_{out}$. keys.transpose(1, 2) means that we've transposed dimensions 1 and 2 of keys (dimensions are zero-indexed), so that means that we're changing it from $b \times \text{num\_tokens} \times d_{out}$ to $b \times d_{out} \times \text{num\_tokens}$.
It was while I was trying to get my head around exactly what a multiplication of two 3-tensors like that would mean that I fell down my rabbit hole and started reading about Einstein summation notation -- but the answer to what is going on turns out to be much simpler than all of that.
Let's think about what that @ means in a Pythonic sense. Python allows you to overload mathematical operators; for example, if you define __add__ on a class A, create an instance of it a, and then call
a + b
...that will be interpreted as a.__add__(b).
Python 3.5 made it possible to overload the @ operator too, and defined it as meaning matrix multiplication. So, if you have a class that defines __matmul__, then it will be called if you do
a @ b
(For completeness -- if the class of a doesn't define __matmul__, Python will check whether b's class defines __rmatmul__ and call b.__rmatmul__(a) if it does.)
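For instance, here's a toy class (nothing to do with PyTorch) that hooks into @:

class Scaler:
    """Toy example: overload @ to mean 'scale a list of numbers'."""
    def __init__(self, factor):
        self.factor = factor

    def __matmul__(self, other):
        return [self.factor * value for value in other]

doubler = Scaler(2)
print(doubler @ [1, 2, 3])   # calls doubler.__matmul__([1, 2, 3]) -> [2, 4, 6]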
So, what is going on here? When we do
queries @ keys.transpose(1, 2)
...what we're really saying is:
queries.__matmul__(keys.transpose(1, 2))
Although I couldn't find it expressed explicitly in the PyTorch docs for torch.Tensor, I'm pretty sure this winds up being the same as
queries.matmul(keys.transpose(1, 2))
-- that is, the dunder (double-underscore) method is an alias for the non-dunder one. This is in turn an alias for torch.matmul, so we're actually doing this:
torch.matmul(queries, keys.transpose(1, 2))
The docs for that method show that it's essentially a whole bunch of special cases to make it generally useful for the kind of stuff that people do in ML models (and let mathematical purity be damned ;-). The important bit is this, because it covers not only this case but also the broadcast case above:
If both arguments are at least 1-dimensional and at least one argument is N-dimensional (where N > 2), then a batched matrix multiply is returned.
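So for our case it's doing something equivalent to this (a sketch with example sizes, not the book's code) -- the leading batch dimension is carried along, and an ordinary matrix multiplication is done for each matrix in the stack:

import torch

b, num_tokens, d_out = 2, 6, 4
queries = torch.randn(b, num_tokens, d_out)
keys = torch.randn(b, num_tokens, d_out)

scores = queries @ keys.transpose(1, 2)    # shape: (b, num_tokens, num_tokens)

# Same thing, done one sequence at a time:
looped = torch.stack([queries[i] @ keys[i].T for i in range(b)])
assert torch.allclose(scores, looped)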
So, no need for any complicated tensor maths. PyTorch knows that people want to
do batched matrix multiplications for its normal use cases, and the @
operator
is designed for exactly that. We just make sure that our batch dimension (or
indeed, per those docs, our batch dimensions) are at the start and the matrix
dimensions are the last two, and it will happily broadcast the whole thing for us.
In a way I was a little disappointed when I worked that out -- it's just a special-case utility function rather than some kind of deep new maths. But never mind; deep knowledge of tensors can wait for another day.
Let's get on to the last part of the differences between the old and the new code: the causal attention.
Causal attention
The two remaining differences between Raschka's first code example and the second are both to do with the masked self-attention. The first is this bit from the end of the __init__ method:
self.register_buffer(
    'mask',
    torch.triu(torch.ones(context_length, context_length), diagonal=1)
)
That torch.triu call is exactly what we had in the code in the section on causal attention -- it builds the mask that will be applied to the attention scores (before the softmax) to blank out the ones that would imply that a token was paying attention to tokens that came after it. Raschka somewhat glosses over what the whole register_buffer thing is all about, though he does mention that it means that the mask will automatically be moved to the same device as our model, so that if (say) the model is on the GPU, the mask will be too, and we won't get errors when trying to use it.
The docs didn't make things a whole lot clearer for me, but the impression I get is that the CausalAttention class that we're creating here is regarded as a single "thing" by PyTorch, an nn.Module, and will be moved between CPU and GPU as a unit. A variable that was a field of the module -- like, say, W_query -- would be treated as part of it, but would also be treated as a set of trainable parameters by default, and we certainly don't want this mask to be adjusted as part of training -- it should always be the same upper-triangular mask that we use to put negative infinities in the top right of the attention scores. So we register it as a buffer to make sure that it is part of the module, but isn't treated as a set of trainable parameters.
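Here's a little sketch of the difference -- the module is a made-up example, not anything from the book:

import torch
import torch.nn as nn

class WithBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(3, 3))                            # trainable
        self.register_buffer('mask', torch.triu(torch.ones(3, 3), diagonal=1))   # not trainable

m = WithBuffer()
print([name for name, _ in m.named_parameters()])  # ['weight'] -- the mask isn't here
print([name for name, _ in m.named_buffers()])     # ['mask'] -- but it is part of the module...
# ...so moving the module to another device (eg. m.to("cuda")) moves the mask along with the weight.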
I'm sure there's more depth there, but I think that's enough for now.
Let's move on to the second bit of causal attention code, which is this from the forward method:
attn_scores.masked_fill_(
    self.mask.bool()[:num_tokens, :num_tokens],
    -torch.inf
)
In the original causal attention code we had this:
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
Raschka notes that the trailing underscore on masked_fill_ in the new code is just PyTorch shorthand to say "do this operation in place", which explains why we don't need to bind it to a new variable -- attn_scores is being changed directly.
The next bit is this [:num_tokens, :num_tokens] at the end of the mask; this is just because the original code was assuming that the mask was the right size for the attn_scores matrix that we were working with -- that is, all of our inputs were the full context_length in length, whereas here we're allowing for variable lengths -- so we just take the sub-matrix of the mask that is required for this scores matrix.
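Putting those two pieces together in isolation, for a single sequence (a sketch with made-up sizes, not the book's code):

import torch

context_length, num_tokens = 6, 4                  # example sizes: a shorter input than the full context
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)

scores = torch.randn(num_tokens, num_tokens)       # pretend attention scores for one sequence
scores.masked_fill_(                               # in place: note the trailing underscore
    mask.bool()[:num_tokens, :num_tokens],         # take just the top-left corner of the mask
    -torch.inf
)
print(scores)   # everything above the diagonal is now -inf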
The elephant in the room is, of course: what about the batches? The docs indicate what's going on -- it's just another broadcast operation, so it will do exactly what we expect.
And that's it!
Wrapping it up
So, as it turned out, this was actually all pretty simple. There's a certain amount of depth there -- I still need to learn a bit more about PyTorch buffers -- but instead of the deep linear algebra that I thought I was going to have to learn to work out what multiplying order >2 tensors together actually means, we just have what amount to utility functions that know that dealing with batches of matrices is a common pattern, and allow us to treat them essentially the same as matrices, doing the sensible thing with them on that basis.
This is a good thing! You want the tools you're using to be optimised for the job that you're doing.
I think that for me the biggest lesson from working through this section is actually to look out for AIs as an enabler for excessive curiosity. The reason I took so long to work through something so relatively simple is that I could see that there were operations with higher-order tensors going on, and started chatting about what that means, and this stuff is fascinating (for very geeky values of fascinating). There's a ton to learn there, and I suspect that much of it is useful as well as interesting.
But as it turned out, just RTFMing was the best way forward. What we're looking at here isn't really advanced maths, it's just engineering. The docs are your friends, and are probably the best first port of call, especially if you want to keep yourself grounded and avoid side quests.
Conclusion
So, that was batching. Not as simple as I thought it might be, but -- having avoided a full-on diversion into deep linear algebra -- not a huge lift either. Next step: multi-head attention. My suspicion is that even more tensors and batched matrix multiplication helper functions will be involved somewhere...
Here's a link to the next post in this series.
1. In the interest of being maximally annoying, here's a niche joke that will infuriate those that get it, because they'll be stuck with the world's worst earworm:
Eight, sir; seven, sir; Six, sir; five, sir; Four, sir; three, sir; Two, sir, one! Tenser, said the Tensor. Tenser, said the Tensor. Tension, apprehension, And dissension have begun. <RIFF> Tension, apprehension, And dissension have begun <RIFF> Tension... ↩
2. To be more precise: nn.Linear is actually designed to do a forward pass of a normal NN, so what's going on internally is something more like $\text{keys} = x W_k^T + b$, where $b$ is all zeros (because we set bias to False) and PyTorch is internally transposing $W_k$ as part of its initialisation. But that's kind of a nit at this point, I think. ↩
3. Correctly, IMO ↩