Writing an LLM from scratch, part 32k -- Interventions: training a better model locally with gradient accumulation
I've been working on a GPT-2-small-style LLM based on Sebastian Raschka's book "Build a Large Language Model (from Scratch)". I've trained various versions of it in the cloud to work out which interventions to the model and training code had the best effects on the loss it gets on a specific test dataset, and now I wanted to do a training run locally to match the best of those. That meant matching the batch size I had been using for the cloud training runs.
When I first started learning this stuff, batching seemed like a performance thing -- with highly parallel systems like GPUs, it generally turned out that you could run a batch of (say) two inputs through a model in less than twice the time you could run one, so it made sense to batch them up.
For inference, that is exactly the advantage you get, but when training, it's become increasingly clear to me that you can also get an improvement in the quality of the model from batching. The best intuitive model I have is that if you run inputs through one-by-one, adjusting parameters after each, then it's easy for the model to "overcorrect" each time. With batches, you get an average set of gradients across all of the items -- which smooths things out and stabilises the training.
Of course, it's possible to overdo it. As an extreme example, imagine that you were somehow able to fit your whole training set into one batch -- then you could train by running that single batch through, doing a single backward pass, and then adjusting the parameters once. It's pretty clear that that would not work very well -- just one single update of the initially-random parameters.
When training on my local machine, I could fit a batch of six sequences into my RTX 3090. When I moved to cloud machines, I'd found that increasing the batch size had a very positive effect on the loss I got out of the models when I tested them. From a quick-and-dirty bit of curve-fitting, I estimated that the optimal batch size for this model, with that training run, was somewhere around 97. Conveniently, that was close to the maximum I could fit onto an 8x A100 40 GiB/GPU machine, so I used a batch size of 96 to test the different interventions I was trying.
And when I finally put all of the interventions that helped with training together, I found (somewhat to my surprise) that their combined effect -- an improvement in loss of 0.113765 -- was less than half of the loss improvement of 0.252474 that I had got from increasing the batch size.
What that all made clear was that if I wanted to do a local training run that matched the quality of the cloud-trained model, I'd need to not only add on the interventions that I'd been testing in detail, but I'd need to match the cloud batch size. And for that, I needed to learn about gradient accumulation.
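I'll save the details for the rest of this post, but the core trick is simple enough to sketch up front. This is a minimal, hypothetical training-loop fragment rather than the code I actually ended up with: it assumes a model, optimizer and train_loader already exist, with each batch from the loader holding six sequences, and it simulates a batch of 96 by accumulating gradients over 16 micro-batches before each optimiser step.

```python
import torch

# Minimal sketch of gradient accumulation (not the code from this post).
# Assumes `model`, `optimizer` and `train_loader` are already set up, with
# each batch from the loader holding 6 sequences.
accumulation_steps = 16  # 16 micro-batches of 6 ~ an effective batch size of 96

optimizer.zero_grad()
for step, (input_batch, target_batch) in enumerate(train_loader):
    logits = model(input_batch)
    loss = torch.nn.functional.cross_entropy(
        logits.flatten(0, 1), target_batch.flatten()
    )
    # Divide by the number of accumulation steps so the summed gradients
    # match the average we'd get from one big batch of 96.
    (loss / accumulation_steps).backward()

    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```

Memory usage stays at micro-batch level, because the only thing kept between micro-batches is the accumulated gradients, which are the same size as the parameters.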
Writing an LLM from scratch, part 32j -- Interventions: trying to train a better model in the cloud
Since early February, I've been trying various interventions on a 163M-parameter GPT-2-style model that I trained from scratch on my local RTX 3090, using code based on Sebastian Raschka's book "Build a Large Language Model (from Scratch)".
My original model got a loss of 3.944 on my test set, while the original GPT-2 weights got 3.500 on the same dataset. I wanted to see if I could close that gap, and had a list of potential changes to the training setup, and to the model itself. Which of them would help?
I found a list of solid-looking interventions, and in my last post I came to the conclusion that the improvements in loss I had seen with all of them -- with two possible exceptions -- seemed unlikely to be in the noise. What would happen if I tried to put them into a new model?
Writing an LLM from scratch, part 32i -- Interventions: what is in the noise?
Towards the end of last year, I trained a 163M-parameter GPT-2-style model from scratch on my local RTX 3090, using code based on Sebastian Raschka's book "Build a Large Language Model (from Scratch)".
The result was a pretty decent little model, but it wasn't as good as the original GPT-2-small, despite having more parameters (because it wasn't using weight-tying). Specifically: on a particular test set, my model gave a loss of 3.944 -- quite a lot more than the original GPT-2's 3.500 on the same dataset.
I wanted to see whether I could train a model on my own hardware (or on something that didn't cost too much to rent in the cloud) that got closer to the original model's performance. So over the last few months, I've done a bunch of further training runs, each one testing a specific intervention -- a stand-alone change that I expected to change the loss, either for better or for worse. Specifically:
- I trained a baseline model on an 8x A100 40 GiB per GPU machine on Lambda (which was better than my original locally-trained model, I believe due to the larger batch size that the larger machine made possible).
- I tried adding gradient clipping to see if that would help by limiting the effects of loss spikes.
- I tried removing dropout, given that these days people tend not to use it (because we're doing single-epoch training runs).
- I tried adding bias to the attention weight matrices -- something that was popular back in the GPT-2 era, and was used by the original weights, but which my code did not use.
- Instead of just using the learning rate of 0.0004 that was used in the code from the book, I looked into what values people use these days, and learned how to schedule it over the course of the training run.
- Similarly, I learned more about weight decay and tried some alternative values.
- Then I tried making my model more like the original GPT-2 one by introducing weight tying to see if that would help.
- Finally, I decided to try training in "full-fat" float32 instead of using PyTorch's AMP and TF32 matrix multiplication performance enhancements.
At the end of all of that, I had this table showing the effect of each intervention in terms of loss on the test set. They're sorted from least-effective to most-effective, and you can see the baseline in there too:
| Training run | Test set loss | Improvement vs baseline |
|---|---|---|
| 8xa100m40-weight-tying | 3.874 | -0.182 |
| 8xa100m40-weight-decay-cerebras | 3.867 | -0.175 |
| 8xa100m40-baseline | 3.692 | - |
| 8xa100m80-no-amp | 3.679 | 0.013 |
| 8xa100m40-gradient-clipping | 3.678 | 0.014 |
| 8xa100m40-qkv-bias | 3.669 | 0.023 |
| 8xa100m40-weight-decay-gpt2 | 3.643 | 0.049 |
| 8xa100m40-remove-dropout | 3.641 | 0.051 |
| 8xa100m40-schedule-learning-rate | 3.602 | 0.090 |
Winners and losers are reasonably clear:
- Weight tying and the number for weight decay I derived from a paper by Cerebras Research (probably without understanding it properly) were negatives.
- Full-fat float32, gradient clipping, attention biases, the GPT-2 weight decay parameter, removing dropout, and scheduling (and updating) the learning rate were positives.
So, for an optimal train, we'd just use the effective interventions, right? Well, not quite.
I decided full-fat float32 wasn't worth the effort, as it meant that the train took more than twice as long and, because it required a larger machine, cost more than three times as much.
The others did look like solid changes, but there was one concern: the effect of each intervention was actually pretty small. For example, gradient clipping reduced the loss by 0.014, from 3.692 to 3.678. That's roughly a 0.4% improvement. Even the best intervention, scheduling the learning rate, only improved things by about 2.4%.
Could it be that some or all of these improvements were not real, but just a result of the random nature of training deep neural networks? Could the differences just be in the noise? They seemed small enough for that to be possible.
I've trained seven more models over the last few days to try to get a feel for how big an effect noise has on this kind of training run. The results appear to show that variations in the initial weights matter quite a lot, but randomness in the training loop (given the same initial weights) actually has a fairly minimal impact. That surprised me a bit!
Let's go through the details.
Writing an LLM from scratch, part 32h -- Interventions: full fat float32
This is the last of the interventions I'm trying out to see if I can improve the test loss for a from-scratch GPT-2 small base model, trained on code based on Sebastian Raschka's book "Build a Large Language Model (from Scratch)".
Back when I did my first training run for a base model, on my local RTX 3090, I used two optimisations:
- Setting the 32-bit floating point matrix multiplication precision to "high" rather than to "highest", which means that it uses lower-precision (but still technically 32-bit) TF32 for those operations rather than normal float32.
- Using PyTorch's Automated Mixed Precision (AMP), which allows it to use 16-bit calculations rather than 32-bit in places where it makes sense to do so.
The first of those boosted training speed from 12,599 tokens per second to 15,402 in my test harness, while AMP on its own boosted it to 19,921 tps (and also allowed me to increase the batch size from 5 to 6). Doing both appeared to hit some kind of diminishing returns -- it maxed out at 19,997 tps, only a little better than AMP on its own.
But intuitively, you'd expect that might come at a cost. While I'm sure the PyTorch developers have a solid understanding of where switching to 16-bit will have a minimal impact on training quality, it seems too good to be true that it would have no impact at all.
Let's see what happens if we switch both of these optimisations off!
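For reference, here's roughly what those two optimisations look like in code. This is the standard PyTorch pattern rather than my actual training loop, and it assumes a CUDA device with model, optimizer and train_loader already defined; switching the optimisations off just means leaving the matmul precision at its default "highest" and dropping the autocast/scaler wrapping.

```python
import torch

# Optimisation 1: let float32 matmuls use TF32 ("high" instead of "highest").
torch.set_float32_matmul_precision("high")

# Optimisation 2: automatic mixed precision. The forward pass runs in
# float16 where PyTorch considers it safe; the GradScaler scales the loss
# up so the float16 gradients don't underflow, then unscales before stepping.
scaler = torch.cuda.amp.GradScaler()

for input_batch, target_batch in train_loader:
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        logits = model(input_batch)
        loss = torch.nn.functional.cross_entropy(
            logits.flatten(0, 1), target_batch.flatten()
        )
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```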
Writing an LLM from scratch, part 32g -- Interventions: weight tying
In Sebastian Raschka's book "Build a Large Language Model (from Scratch)", he writes that weight tying, while it reduces the parameter count of a model, in his experience makes it worse. As such, apparently people don't use it in modern LLMs. Intuitively, that makes sense -- I'll explain why in this post.
But as I'm trying various interventions to see if I can get my model -- based on Raschka's code, but trained for a fraction of the time that the original GPT-2 model was -- to perform as well as the original in terms of the loss it gets on a test set, I thought it would be worth seeing if it really is a negative for this particular tiny model of 163M parameters.
After all, the original weights use weight tying, and I did find that QKV bias appeared to help -- and that's another old-school technique that they used, which has since dropped out of fashion. Might this one help too?
Worth a try! Let's give it a go.
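To be concrete about what the intervention is: assuming the attribute names from the book-style GPTModel (a token embedding tok_emb and a final projection out_head back to the vocabulary; yours may differ), weight tying is a one-line change that makes the two layers share a single weight matrix.

```python
# Share one (vocab_size, emb_dim) matrix between the token embedding and
# the output head, instead of learning two separate ones.
model.out_head.weight = model.tok_emb.weight
```

For a GPT-2-small-sized model that shared matrix is 50,257 x 768, roughly 38.6M parameters, which accounts for most of the difference between my model's 163M parameters and the original GPT-2 small's 124M.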
Writing an LLM from scratch, part 32f -- Interventions: weight decay
I'm still working on improving the test loss for a from-scratch GPT-2 small base model, trained on code based on Sebastian Raschka's book "Build a Large Language Model (from Scratch)".
In my training code, I have this code to create the optimiser:
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=0.0004, weight_decay=0.1
)
In my last post I looked into the learning rate, the lr parameter in that code, and found a value for that, plus some extra code to schedule it -- that is, to vary it over time -- which gave better training results.
This time I want to go into the weight decay. What is it, what is it for, and is 0.1 really the best value?
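Before digging in, here's a tiny, self-contained illustration (not from my training code) of what that weight_decay parameter does inside AdamW. With the gradient forced to zero, the Adam part of the update vanishes, and the only thing moving the parameter is the decoupled decay term, which shrinks every weight slightly towards zero on each step:

```python
import torch

lr, weight_decay = 0.0004, 0.1
param = torch.nn.Parameter(torch.ones(3))
optimizer = torch.optim.AdamW([param], lr=lr, weight_decay=weight_decay)

param.grad = torch.zeros_like(param)  # zero gradient: only the decay acts
optimizer.step()

# AdamW multiplies each parameter by (1 - lr * weight_decay) before the
# Adam update, so each element is now 1 * (1 - 0.0004 * 0.1) = 0.99996.
print(param)
```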
Writing an LLM from scratch, part 32e -- Interventions: the learning rate
I'm still working on improving the test loss for a from-scratch GPT-2 small base model, trained on code based on Sebastian Raschka's book "Build a Large Language Model (from Scratch)".
In my training code, I have this code to create the optimiser:
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=0.0004, weight_decay=0.1
)
The values in there -- 0.0004 for the learning rate, and 0.1 for the weight decay -- were just copied from the tiny training run that we do in section 5.2 of the book.
What do those values actually mean, and are those really the right values for them?
I felt I had a good handle on the learning rate, at least -- it's one of the first things you learn when you start looking at machine learning of any kind -- but how would you go about working out what the correct value for it was? On top of that, when I was reading the Chinchilla paper a while back, I noticed they repeatedly referred to a "cosine cycle" for the learning rate, which didn't fit into anything I'd learned about before.
The weight decay was pretty much an unknown for me -- I know it is a parameter controlling the behaviour of the optimiser, but I don't know how it does that.
In this post I want to look into the learning rate, and these mysterious cosines; I'll write a follow-up about the weight decay later.
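As a taster for where this ends up, here's a sketch of the kind of warmup-plus-cosine schedule that papers like Chinchilla describe: the learning rate ramps up linearly over a short warmup, then follows a single half-cosine down to a small minimum over the rest of the run. The numbers here are made up for illustration, not the ones I settled on:

```python
import math

def lr_at_step(step, max_lr=6e-4, min_lr=6e-5, warmup_steps=100, total_steps=5000):
    """Linear warmup to max_lr, then one cosine decay down to min_lr."""
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))

# In the training loop, before each optimizer.step():
#     for group in optimizer.param_groups:
#         group["lr"] = lr_at_step(step)
```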
Writing an LLM from scratch, part 32d -- Interventions: adding attention bias
I'm still seeing what I can do to improve the test loss for a from-scratch GPT-2 small base model, trained on code based on Sebastian Raschka's book "Build a Large Language Model (from Scratch)". This is the third intervention I'm trying: adding bias to the attention weight matrices.
In the code from the book, we have this:
class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        d_in, d_out,
        context_length,
        dropout,
        num_heads,
        qkv_bias=False
    ):
        ...
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        ...

    def forward(self, x):
        ...
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
So: we initialise the weights W_query, W_key and W_value as linear layers rather than simple matrices of weights, and have a parameter qkv_bias to say whether or not we should add bias to those. In all of our trains so far we've set that to False.
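Concretely, all the flag does is tell each of those nn.Linear layers to learn an extra bias vector alongside its weight matrix, so that the keys become x @ W_key.weight.T + W_key.bias rather than just the matrix product. A quick illustration of the parameter-count difference, using GPT-2 small's embedding dimension of 768 (this is just a throwaway example, not code from the model):

```python
import torch.nn as nn

d_in, d_out = 768, 768
proj_no_bias = nn.Linear(d_in, d_out, bias=False)    # weight matrix only
proj_with_bias = nn.Linear(d_in, d_out, bias=True)   # weight matrix + bias vector

print(sum(p.numel() for p in proj_no_bias.parameters()))    # 589,824
print(sum(p.numel() for p in proj_with_bias.parameters()))  # 590,592 (768 extra)
```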
Why do we have this parameter, and where did it come from?
Writing an LLM from scratch, part 32c -- Interventions: removing dropout
This is the second in my series of attempts to improve the loss on my test dataset -- interventions, as I'm calling them -- for a from-scratch GPT-2 small base model, trained on code based on Sebastian Raschka's book "Build a Large Language Model (from Scratch)".
Last time around I saw what gradient clipping can do -- it improved loss over the baseline by 0.014, bringing it down from 3.692 to 3.678. Not much, but it's something!
This time, I wanted to see what happened if we trained without dropout. Would removing it make the test loss worse, or better?
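For context, the simplest way to do that with book-style code is not to delete the nn.Dropout layers but to set the dropout rate to zero in the model config, which turns them into no-ops. A hypothetical sketch of the config change (the key names follow the book's GPT_CONFIG_124M; the other values are the usual GPT-2 small ones):

```python
GPT_CONFIG = {
    "vocab_size": 50257,     # BPE vocabulary size
    "context_length": 1024,  # maximum sequence length
    "emb_dim": 768,          # embedding dimension
    "n_heads": 12,           # attention heads per layer
    "n_layers": 12,          # transformer blocks
    "drop_rate": 0.0,        # was 0.1; zero makes every nn.Dropout a no-op
    "qkv_bias": False,       # attention bias (a separate intervention)
}
```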
Writing an LLM from scratch, part 32b -- Interventions: gradient clipping
I'm still working on training the best GPT-2-small-sized base model that I can with a FLOPs budget roughly equal to two days of training on my own machine -- my "extra credit" exercise after having worked through Sebastian Raschka's book "Build a Large Language Model (from Scratch)".
In the last post I trained a baseline model -- one with the same architecture and almost the same training code as in the minimal training run in the book, just modified to run using DDP on an 8x A100 40 GiB/GPU machine in the cloud. There are a bunch of "interventions" I want to try to see if they'll make it better, as measured by the loss they get on a test set. I'll do a post for each intervention, and this is the first: gradient clipping.
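As a preview of what the change looks like in code: gradient clipping is a one-line addition between the backward pass and the optimiser step. This is the standard PyTorch pattern rather than my exact training loop, and the max_norm of 1.0 is just the conventional default:

```python
import torch

# Inside the training loop, with `model`, `optimizer` and `loss` as usual:
loss.backward()
# If the overall L2 norm of all gradients exceeds 1.0, rescale them so it
# equals 1.0; this stops one bad batch from causing a huge parameter update.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
optimizer.zero_grad()
```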