Writing an LLM from scratch, part 32f -- Interventions: weight decay

Posted on 23 March 2026 in AI, LLM from scratch, TIL deep dives, Python

I'm still working on improving the test loss for a from-scratch GPT-2 small base model, trained using code based on Sebastian Raschka's book "Build a Large Language Model (from Scratch)".

In my training code, I have this code to create the optimiser:

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=0.0004, weight_decay=0.1
    )

In my last post I looked into the learning rate, the lr parameter in that code, and found a value for that, plus some extra code to schedule it -- that is, to vary it over time -- which gave better training results.

This time I want to go into the weight decay. What is it, what is it for, and is 0.1 really the best value?

I was a little concerned going into this that in order to understand this hyperparameter, I'd need to have a good understanding of how the optimiser works; I've been building what I think is a solid mental model of optimisers, but I don't think I understand them well enough to explain them yet, and I've been hoping to save them for a separate blog post series after this one.

The good news is that while weight decay is an important aspect of how optimisers work -- the "W" in AdamW, the thing that makes it different to the older Adam optimiser, is a nod to its different treatment of weight decay -- you don't need to know how the optimiser itself works to understand what weight decay is.

Instead, you just need to consider an older and more fundamental aspect of building ML systems -- regularisation. In order to dig into that, let's start with overfitting.

Overfitting

Let's imagine a simple classification task: we want to build a model that can -- for any point on this chart -- predict whether a cross or a circle should go there, training it using the sample data points that we already have:

Overfitting example: the training data

Let's say that we train a powerful model on this dataset, and it comes up with this:

Overfitting example: the powerful model's solution

Now, ab initio we don't know whether that's a good result or not; we need to use our validation set to evaluate it. Let's say that the validation points are these blue ones:

Overfitting example: validation samples show that it's overfit

We can see that it looks like our powerful model has overfit. The training set is all nicely split by the boundary, but the validation points are not.

A common solution to that kind of issue that you might see in introductory ML courses is to try using a less powerful model. A less powerful model in this case might come up with a less "wiggly" line to separate the two categories -- perhaps because it didn't have enough parameters to make it wiggle so much -- so you might find that it came up with a classifier that looked more like this:

Overfitting example: a simpler model fits perfectly

So: we use our validation set to detect overfitting, and we can adjust the complexity of our model to try to avoid it.

Now, this is all very well, but it does require manual intervention. We had to do a training run, identify that we were overfitting, and then decide on parameters for the new simpler model (how many parameters should it have?). We could, perhaps, have gone too far and wound up with something like this:

Overfitting example: a too-simple model underfits

...and underfit.

There's no way of knowing, when we start out, what the right number of parameters is, so we need to try various values and then try to work out the optimum balance.

Regularisation techniques are designed to try to automate this -- to prevent overfitting without all that tedious mucking about with the model.

Regularisation

We've already looked at Dropout, which is one of the standard ways to do that. Although my own mental model of what it does goes some way beyond just helping to prevent overfitting, I may well be wrong -- and given that our LLM training run never sees the same training data twice (it's a single-epoch run), removing it turned out to improve our model.

Another technique is just stopping the training run when you start seeing the validation loss rise, also known as "early stopping". That's such an obvious thing to do that I came up with it independently back when I was doing my early experiments with fine-tuning. Now, we don't have a separate validation set for these training runs, but because we're doing a single epoch, the training data it sees is just as "new to it" as a held-back validation set would be, so we could use a similar trick and treat "train loss starts rising", rather than validation loss rising, as a reason to stop the train early. It's not exactly the same thing, but perhaps it would be close enough.
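For concreteness, here's a minimal sketch of what that kind of early-stopping check might look like. The class name, window size and patience values are all made up for illustration -- this isn't from my actual training code:

```python
from collections import deque

class EarlyStopper:
    """Track a moving average of the loss, and signal a stop when it
    has risen for `patience` consecutive updates."""

    def __init__(self, window=100, patience=3):
        self.losses = deque(maxlen=window)
        self.best_avg = float("inf")
        self.bad_checks = 0
        self.patience = patience

    def update(self, loss):
        """Record a new loss; return True if it's time to stop training."""
        self.losses.append(loss)
        avg = sum(self.losses) / len(self.losses)
        if avg < self.best_avg:
            self.best_avg = avg
            self.bad_checks = 0
        else:
            self.bad_checks += 1
        return self.bad_checks >= self.patience
```

In a training loop you'd call `update` with each batch's loss (or with the validation loss, for classic early stopping) and break out of the loop when it returns `True`.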

But in all of the trains in this series, that's never happened -- while sometimes the train loss blips up for a bit, in the longer term it keeps going down.

But there are other techniques that rely on a neat trick. Let's think back to the manual, boring way of trying to find how many parameters are appropriate for a modelling task. We tried one number, found that it overfit, then we might try a lower one, find that it underfit, then try something in the middle and find that it's better but still not perfect one way or the other, and rinse and repeat until we find something we're happy with. This kind of searching through a solution space to find an optimum is exactly what we're doing when training a model. It would be really nice to automate it in the same way.

One trick is: if we want to minimise the complexity of our model so that it doesn't overfit, we can try adding a measure of the model's complexity to the loss function -- and then our normal process of gradient descent will try to minimise that, just like it will try to minimise the loss from the training results themselves. And that brings us on to weight decay.

Weight decay (finally!)

Regularisation by weight decay starts off with the hypothesis that the "size" of all of the model's weights, taken together, is a measure of the model's complexity. If the model's weights are small, then it's a simpler model than if they're large. 1

The "size" in this sense is the square of the L2 norm -- that's something we came across in gradient clipping. The L2 norm is basically all of the weights squared, added together, and then the resulting sum square-rooted. You can think of it as the length of the vector that the weights represent -- that is, for our 163M-parameter model, it would be the length of the model's weights considered as a vector in 163-million-dimensional space. 2 And by using its square, we get something that penalises larger values more (and we also save the time spent calculating a square root).
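As a quick concrete check of those definitions, here's the L2 norm and its square in plain Python, using a toy list of weights rather than real model parameters:

```python
import math

def l2_norm(weights):
    # Square each weight, sum them, then take the square root.
    return math.sqrt(sum(w * w for w in weights))

def squared_l2_norm(weights):
    # The regularisation penalty skips the square root.
    return sum(w * w for w in weights)

# A toy "model" whose three weights form the vector (3, 4, 12):
weights = [3.0, 4.0, 12.0]
print(l2_norm(weights))          # 13.0 -- the length of that vector
print(squared_l2_norm(weights))  # 169.0
```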

To me, it's not intuitively obvious that that measure really does express the complexity of the model in any clear sense. After all, you'd think that doubling all parameters would leave it no more complex than it was before, but it would double the L2 norm. 3 But I imagine there is solid maths behind it to say that it does work in a more general way, so in the interests of not disappearing down a mathematical rabbit hole at this stage, I'll take it as given.

So: we're using the squared L2 norm as a measure of model complexity, and we're going to add that on to the training loss as a way to try to minimise both. The next question is, how do we balance between the two -- the training loss and the model complexity penalty?

This is, in a somewhat hand-wavy way, similar to the decision of how much of the current loss function's gradient to use when adjusting the weights. For that, we use η, the learning rate, to scale the gradients before applying them:

$$w_{\text{new}} = w - \eta \times \text{gradient}$$

And the balance between the "real" loss and the model complexity penalty is done in a similar way -- we have a number, the weight decay, normally represented by a lower-case lambda, λ, and we multiply the squared L2 norm by that, something like this:

$$L' = L + \lambda \cdot N^2$$

...where I'm using $L$ for the normal loss on the training inputs vs the targets, $N^2$ for the squared L2 norm of the weights, and $L'$ for the combined loss. And $L'$ is what we -- in theory -- actually try to minimise using our optimiser.

But there's actually a neat simplification that we can apply to make this even easier.

Firstly, let's make one small change to the equation above: we'll halve the squared L2 norm before multiplying it by λ. That obviously doesn't change the underlying maths, it just means that we'd need to use larger values for λ to get the same effect. You'll see why that's useful in a bit.

$$L' = L + \lambda \cdot \frac{N^2}{2}$$

Now let's think about normal gradient descent. Again, we work out the gradient of the loss function for each weight, and subtract that times the learning rate η from the weight's value to update it:

$$w_{\text{new}} = w - \eta \times \text{gradient}$$

Let's reformulate that a bit. The gradient of the loss function for the weight is its partial derivative with respect to that weight, so we can write the above like this for the version of the loss function that includes weight decay, $L'$:

$$w_{\text{new}} = w - \eta \frac{\partial L'}{\partial w}$$

Now, we defined $L'$ above as $L + \lambda \cdot \frac{N^2}{2}$, so we can substitute that in there:

$$w_{\text{new}} = w - \eta \frac{\partial}{\partial w}\left(L + \lambda \cdot \frac{N^2}{2}\right)$$

Now, let's think about that L2 norm, N. It's the square root of the sum of all of the weights squared, or equivalently we can square it (like we do in the formula above) and say:

$$N^2 = w_0^2 + w_1^2 + \dots + w_n^2$$

Let's drop that in:

$$w_{\text{new}} = w - \eta \frac{\partial}{\partial w}\left(L + \lambda \cdot \frac{w_0^2 + w_1^2 + \dots + w_n^2}{2}\right)$$

Now, the derivative of a bunch of things added together is just each of them differentiated separately and then added together. Let's apply that to the two terms in the brackets:

$$w_{\text{new}} = w - \eta \left(\frac{\partial L}{\partial w} + \frac{\partial}{\partial w}\left(\lambda \cdot \frac{w_0^2 + w_1^2 + \dots + w_n^2}{2}\right)\right)$$

...and now pull the constant λ and the 2 out of the second partial derivative:

$$w_{\text{new}} = w - \eta \left(\frac{\partial L}{\partial w} + \frac{\lambda}{2} \cdot \frac{\partial}{\partial w}\left(w_0^2 + w_1^2 + \dots + w_n^2\right)\right)$$

Then we apply the rule for the derivative of a bunch of things added together again:

$$w_{\text{new}} = w - \eta \left(\frac{\partial L}{\partial w} + \frac{\lambda}{2} \cdot \left(\frac{\partial w_0^2}{\partial w} + \frac{\partial w_1^2}{\partial w} + \dots + \frac{\partial w_n^2}{\partial w}\right)\right)$$

Now, we're doing a partial derivative versus one specific weight, $w$, which is one of the $w_0$, $w_1$, and so on in there. From that perspective, all of the other weights are constant -- which means that their derivative with respect to $w$ is zero. So we can just get rid of all of them apart from the one that actually is $w$, and we wind up with this:

$$w_{\text{new}} = w - \eta \left(\frac{\partial L}{\partial w} + \frac{\lambda}{2} \cdot \frac{\partial w^2}{\partial w}\right)$$

The derivative of $w^2$ with respect to $w$ is just $2w$. Thanks to that crafty halving of the $N^2$ earlier, that means that we can go to this:

$$w_{\text{new}} = w - \eta \left(\frac{\partial L}{\partial w} + \lambda \cdot w\right)$$

Multiplying that η across the bracketed terms, we get:

$$w_{\text{new}} = w - \eta \frac{\partial L}{\partial w} - \eta \lambda w$$

That's exactly the same as the normal gradient descent update, using the unmodified loss function without weight decay -- except that we're additionally subtracting the weight's original value scaled down by both the learning rate η and the weight decay value λ.

Much simpler :-)
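To make that final form concrete, here's a toy single-weight update in plain Python with made-up numbers -- note that this is the simple-gradient-descent form of the update, not what AdamW actually does:

```python
def update_with_weight_decay(w, grad, lr, wd):
    # w_new = w - lr * grad - lr * wd * w
    return w - lr * grad - lr * wd * w

# The decay term pulls the weight a little towards zero on every step,
# on top of the normal gradient step:
print(update_with_weight_decay(w=2.0, grad=0.5, lr=0.1, wd=0.1))  # ≈ 1.93
```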

(As an aside: the description above is correct for "traditional" simple gradient descent and -- loosely -- for Adam, but AdamW's trick is to do things somewhat differently. That's something I'll go into in more detail when I get round to writing my post on optimisers.)

So: weight decay is a regularisation technique that tries to prevent our model from getting any more complex than it needs to be. We have one number, λ, which determines how much to weight complexity against the normal training loss. And, as we can see from the code:

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=0.0004, weight_decay=0.1
    )

...right now we're setting λ to 0.1. Is that the right value?

The literature

Best guesses for GPT-2

As usual, the GPT-2 paper is light on the details of the hyperparameters they used, but nostalgebraist wrote a really nice post on Tumblr where they dug into what the number might have been. As they say:

It does say it follows the first GPT paper in most respects, and that paper used weight decay of 0.01.

Their link for the paper appears to be mistaken, as it's a different (albeit very interesting) paper from 2020, a year after the GPT-2 one, but I believe this is the paper normally called the GPT-1 one. They do indeed use 0.01 there:

We also employed a modified version of L2 regularization proposed in [37], with w = 0.01 on all non bias or gain weights.

The link to the GPT-3 paper looks right, though, and as they say, it uses a weight decay of 0.1:

All models use weight decay of 0.1 to provide a small amount of regularization

They then do a bit of maths to work out whether the GPT-2 weights are likely to have been regularised by something like weight decay, and come to the conclusion that they probably used 0.01, just like the GPT-1 paper. It seems plausible, but of course not certain.

But: tentatively, GPT-2 used 0.01, while we're using 0.1, perhaps because the GPT-3 paper does. What other data points do we have?

The Hugging Face "Smol training playbook" has some interesting stuff (including not using weight decay on embeddings, which they say they found helped), but the value that they use is 0.1, which they call "a very vanilla setting". And:

Interestingly, over the last few years the AdamW hyperparameters have barely moved:

The same triplet is reused in Llama 1, 2, and 3 and DeepSeek-V1, V2, and V3-671B, with no changes.

Anyway, assuming they're right about the weight decay value for the models they mention (and I assume they've done the research -- I had the link to the DeepSeek paper to hand, and that one certainly says 0.1), it looks like 0.1 is pretty much standard these days. And as a quick double-check of what a typical value would be, I asked ChatGPT, Claude, Gemini and Grok -- they all recommend 0.1 as a solid sensible default with AdamW (though they all also say that values between 0.01 and 0.1 are reasonable).

So on that basis, I think we can say that 0.1 is a reasonable default, and has pretty much become the standard, but it might be worth trying 0.01 just to see if it does help with tiny models like ours.

Are there any dissenting voices to the 0.1 orthodoxy?

The Cerebras paper

I came across a paper from a team at Cerebras Systems, "Power Lines: Scaling Laws for Weight Decay and Batch Size in LLM Pre-training".

It's essentially a Chinchilla-like attempt to get scaling laws, but rather than looking just at optimal tokens per parameter in order to work out what you should scale up when adding on more compute, they're trying to find optimal batch sizes and values for weight decay. That's certainly relevant to our interests :-)

However, it is very dense and in-depth, and fully understanding it at this stage would need quite a lot of work -- very much a side quest. Definitely something to come back to later, but for now, I'll just try to extract the stuff we need.

Let's start off with the optimal batch size, as they have it right there on the first page. We're not going to use it, but it will be interesting to compare with what we're using, and what the DeepSeek paper that I looked at in the last post suggested.

They fit this formula:

$$B_{\text{opt}} = 0.0306 \cdot D^{0.383}$$

...where D is the total number of tokens that you're training on. That's quite different to the formula in the DeepSeek paper, which was:

$$B_{\text{opt}} = 0.2920 \cdot C^{0.3271}$$

...where C is the number of FLOPs 4. C scales up linearly with the number of tokens D, but also with the number of parameters in the model N, so you can see the DeepSeek formula as a function of N and D -- as your model gets bigger, so does $B_{\text{opt}}$ -- whereas this Cerebras paper is saying that it's just a function of D, unaffected by model size. They did train over a number of different sizes (from 111M parameters up to 1.7B) and their formula seems to hold, so it's not just that they didn't treat model size as relevant.

Well, let's see what their formula comes up with. We have 3,260,252,160 tokens in our train, so their formula for $B_{\text{opt}}$ comes out as:

$$0.0306 \cdot 3{,}260{,}252{,}160^{0.383} \approx 0.0306 \cdot 4401.24 \approx 134.68$$

That's much closer to the 97-or-so sequences that appeared to be optimal when I did some rough-and-ready curve-fitting than the 373 that the DeepSeek formula gave for our setup :-)
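Here's that arithmetic as a quick script, with the constants from the Cerebras fit, so it's easy to rerun with other token counts:

```python
D = 3_260_252_160  # total training tokens for our run

# Cerebras fitted formula for optimal batch size (in sequences)
b_opt = 0.0306 * D ** 0.383
print(round(b_opt, 1))  # ≈ 134.7
```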

OK, so what about the weight decay? They don't give a direct formula for that, but they do give a formula for the optimal τ, the AdamW timescale. Without going into exactly what that means right now (that's one for my optimisers post later), they relate it to other numbers that we do know with this formula:

$$\tau = \frac{B}{\eta \lambda D}$$

...where B is the batch size, D is the amount of data, and of course λ and η are weight decay and learning rate respectively. So if we know the optimal τ we can work out the optimal λ for our training run; solving for λ, we get:

$$\lambda_{\text{opt}} = \frac{B}{\eta \tau_{\text{opt}} D}$$

So let's work out the $\tau_{\text{opt}}$. Their fitted formula is this:

$$\tau_{\text{opt}} = 1.084 \cdot (\text{TPP})^{-0.527}$$

...where TPP is tokens-per-parameter. For us, with our Chinchilla-optimal TPP of 20, we get:

$$\tau_{\text{opt}} = 1.084 \cdot 20^{-0.527} \approx 1.084 \times 0.2062 \approx 0.22352$$

Now, we're using a batch size B of 96, and (as before) D is 3,260,252,160. Our learning rate η is 0.0004 for this train -- remember, although in the last post we found that a scheduled learning rate with a peak at 0.0014 was better, in this post we're testing changing weight decay in isolation. 5

So, we just need to plug our $\tau_{\text{opt}}$ into this:

$$\lambda_{\text{opt}} = \frac{B}{\eta \tau_{\text{opt}} D}$$

...right?

Before we do: having a batch size and a number of tokens in the same formula feels like a unit mismatch.

In particular, as part of the explanation of that formula, they tie it back to a value S, the total number of optimisation steps, which they define as D/B. For that to work, either both need to be in terms of tokens, or both need to be in terms of sequences.

They clearly say that "B is reported in units of sequences". I'm not sure how to explain this, except by saying that perhaps the D here is meant to be in terms of sequences too, even though I'm pretty sure that it's meant to be in terms of tokens in the equation for the batch size. 6

Well, let's assume that is the case, and plug in numbers for sequences. We have 3,260,252,160 training tokens split into 1,024-token sequences, which is 3,183,840 sequences, so that comes out as:

$$\lambda_{\text{opt}} = \frac{96}{0.0004 \cdot 0.22352 \cdot 3{,}183{,}840}$$

(Note that we'd get the same numbers if we plugged in numbers for tokens in both cases, as it would just multiply the top and the bottom by 1,024.)

That comes out as 0.33724. Wow! That's even higher than the "traditional" 0.1, never mind the 0.01 that is the best guess we have for GPT-2. Even if I'm missing something here (I certainly can't say I've read the paper in as much detail as it deserves), that actually gives us a nice number to try out as an experiment.
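Here's that whole chain of arithmetic as a quick script, so it's easy to rerun with different numbers:

```python
B = 96                           # batch size, in sequences
eta = 0.0004                     # (fixed) learning rate for this train
D_seq = 3_260_252_160 // 1_024   # 3,183,840 sequences

# Cerebras fitted formula for the optimal AdamW timescale at 20 TPP...
tau_opt = 1.084 * 20 ** -0.527
# ...then solve tau = B / (eta * lambda * D) for lambda
lam_opt = B / (eta * tau_opt * D_seq)
print(round(tau_opt, 3))  # ≈ 0.224
print(round(lam_opt, 3))  # ≈ 0.337
```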

We already have a loss on our test set for a model trained with a weight decay of 0.1, as that was what we used in our baseline train. It looks like it might be worth doing two more, one with the GPT-2 estimate of 0.01, and one with this Cerebras-inspired 0.33724, neatly bracketing it. Let's give them a go!

Training with the estimated GPT-2 weight decay, 0.01

Firstly, the training run with λ=0.01:

Loss for training run with estimated GPT-2 weight decay, 0.01

Looks like a nice smooth train -- one small loss spike near the start but it quickly recovered. The output was:

    Training complete in 12,244.461 seconds
    Tokens seen: 3,260,252,160
    Throughput: 266,263 tokens/second
    Final train loss: 3.697

That's not a bad final train loss (and a low train loss does tend to indicate a good model).

Let's look at the evals; firstly, the smoke test -- how would it complete "Every effort moves you"?

Every effort moves you and any of the other staff at a time by using different techniques from the other members or staff who

Passably coherent. Let's take a look at the loss it gets on our test set:

Loss against our test dataset: 3.643

Not bad at all! Time to upload it to Hugging Face and to add it to the table so that we can compare it to the other interventions we've tried so far.

| | Test set loss | Improvement vs baseline |
| --- | --- | --- |
| 8xa100m40-baseline | 3.692 | - |
| 8xa100m40-gradient-clipping | 3.678 | 0.014 |
| 8xa100m40-qkv-bias | 3.669 | 0.023 |
| 8xa100m40-weight-decay-gpt2 | 3.643 | 0.049 |
| 8xa100m40-remove-dropout | 3.641 | 0.051 |
| 8xa100m40-schedule-learning-rate | 3.602 | 0.09 |

So, it's better than gradient clipping and the QKV bias, but slightly worse than removing dropout and much worse than scheduling (and increasing) the learning rate.

Now, that suggests to me that the much-higher Cerebras-inspired weight decay will be worse. My logic is this: if both decreasing it and increasing it improved loss, that would suggest that we have an inverted-U loss curve for weight decay like this:

Intuition on weight decay changes, part 1: an inverted U-shaped curve

Now, it seems vanishingly unlikely that those downward trends on either side would continue so that you could get arbitrarily low loss by increasing or decreasing weight decay even more. So the curve would perhaps look a bit more like this W-shaped one:

Intuition on weight decay changes, part 2: a W-shaped curve

My intuition is that having multiple minima -- especially ones that just happen to be on either side of the "standard" value for weight decay -- seems less likely than the alternative -- that the higher number will be worse because we're actually on a U-shaped curve more like this:

Intuition on weight decay changes, part 3: a U-shaped curve

Of course, my intuition could be completely off on this, and it's definitely still worth doing the test!

Training with the Cerebras-inspired weight decay, 0.33724

Here's the loss chart with that:

Loss for training run with Cerebras-inspired weight decay, 0.33724

You can see right away that it was a much choppier train, with quite a few loss spikes, some quite late on. The output at the end reflected this:

    Training complete in 12,254.884 seconds
    Tokens seen: 3,260,252,160
    Throughput: 266,037 tokens/second
    Final train loss: 3.867

...a significantly worse loss at the end. Still, we should do the evals. Firstly the smoke test:

Every effort moves you so far away (but to no surprise, a real change was the only goal I did in 2009

Not too bad, but the loss test is the important one:

Loss against our test dataset: 3.814

That's terrible! Our first result for loss on the test set for an intervention that is actually worse than the baseline. Much worse:

| | Test set loss | Improvement vs baseline |
| --- | --- | --- |
| 8xa100m40-weight-decay-cerebras | 3.814 | -0.122 |
| 8xa100m40-baseline | 3.692 | - |
| 8xa100m40-gradient-clipping | 3.678 | 0.014 |
| 8xa100m40-qkv-bias | 3.669 | 0.023 |
| 8xa100m40-weight-decay-gpt2 | 3.643 | 0.049 |
| 8xa100m40-remove-dropout | 3.641 | 0.051 |
| 8xa100m40-schedule-learning-rate | 3.602 | 0.09 |

However, at this point I started wondering. When I was looking at the learning rate, the number I selected based on the DeepSeek paper worked well with learning rate scheduling, but failed to converge without. The weight decay number is multiplied by the current learning rate before it's used to reduce weights' values, so will be affected by both scheduling and η.

It seemed likely that Cerebras used a learning rate schedule, and double-checking the paper:

We present results with a single (standard) learning rate schedule ... For a given TPP, all models have the exact same warmup phase: a linear warmup of the learning rate from 0 to the maximum value. ... We use the µP-tuned and adjusted peak η, for 111M models. The learning rate increases linearly to the peak for the first 10% of steps, then decreases from the peak to 0 for the remainder of steps.

Seems pretty certain.

Now, I've been following a fairly strict rule of testing interventions in isolation; however, the learning rate and the weight decay parameters are so intertwined that perhaps that's just not reasonable here. I decided to do two more trains, both with learning rate scheduling. I'd use the same schedule as in the last blog post -- a warmup from pretty-much zero to the peak over 10% of the run, followed by a cosine decay to 10% of the peak. In the first, I'd use the same learning rate as our baseline model, 0.0004. In the second, I'd use the one we got from the DeepSeek paper, which did really well when scheduled: 0.0014.
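As a sketch, the schedule I'm describing looks something like this -- the function and parameter names are mine, and the exact warmup shape is illustrative rather than a copy of the real training code:

```python
import math

def scheduled_lr(step, total_steps, peak_lr=0.0004,
                 warmup_frac=0.1, floor_frac=0.1):
    """Linear warmup from (nearly) zero to peak_lr over the first 10%
    of steps, then cosine decay down to 10% of the peak."""
    warmup_steps = int(total_steps * warmup_frac)
    floor = peak_lr * floor_frac
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return floor + (peak_lr - floor) * 0.5 * (1 + math.cos(math.pi * progress))
```

In the real training loop you'd recompute this each step and write it into the optimiser's `param_group["lr"]` values before calling `step()`.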

Training with the Cerebras-inspired weight decay, scheduling learning rate with a peak of 0.0004

Loss for training run with Cerebras-inspired weight decay, 0.33724, scheduled learning rate peaking at 0.0004

Well, that's less choppy, at least -- the scheduling calmed down the later parts of the run, as you'd expect given that the learning rate was dropping. The output:

    Training complete in 12,275.101 seconds
    Tokens seen: 3,260,252,160
    Throughput: 265,599 tokens/second
    Final train loss: 3.801

Still a fairly high training loss at the end, though. The smoke test:

Every effort moves you toward the most beneficial method for the best outcomes is for your child, whether you are an expert,

Not too bad, and the test set loss:

Loss against our test dataset: 3.739

Unfortunately still worse than the baseline of 3.692, albeit better than the one without learning rate scheduling. I'm not going to add it to the table, as this was more in the way of an exploratory training run.

Let's see how we do with the larger DeepSeek-suggested learning rate.

Training with the Cerebras-inspired weight decay, scheduling learning rate with a peak of 0.0014

For this one, I kept the weight decay at 0.33724. (This was an error, as I realised later -- more on that shortly.)

Loss for training run with Cerebras-inspired weight decay, 0.33724, scheduled learning rate peaking at 0.0014

Ouch, super-choppy loss -- and the loss at the end of the train isn't promising either:

    Training complete in 12,258.767 seconds
    Tokens seen: 3,260,252,160
    Throughput: 265,953 tokens/second
    Final train loss: 3.911

Terrible loss at the end. The smoke test gives this:

Every effort moves you forward, it doesn´t take time for the next chapter in our history as much time or a

...which is not too bad, but the test set loss:

Loss against our test dataset: 3.847

...is still pretty terrible (though still a tad better than the one without the learning rate scheduling).

Another one to throw away, I think.

But then something occurred to me: the formula to go from the optimal AdamW time horizon $\tau_{\text{opt}}$ to the optimal weight decay $\lambda_{\text{opt}}$ is this:

$$\lambda_{\text{opt}} = \frac{B}{\eta \tau_{\text{opt}} D}$$

It has the learning rate η in it -- I even made a footnote saying that I was going to have to remember to recalculate the weight decay value when that changed :-S

Luckily, though, running the real numbers through that:

$$\lambda_{\text{opt}} = \frac{96}{0.0014 \times 0.22352 \times 3{,}183{,}840} \approx 0.096355$$

...which is almost exactly the same as the 0.1 that we've been using for all of our other experiments.
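As code, that recalculation is just the earlier formula with the new η plugged in:

```python
B = 96
eta = 0.0014                    # peak of the scheduled learning rate
D_seq = 3_183_840               # training data, in sequences
tau_opt = 1.084 * 20 ** -0.527  # Cerebras fit at 20 tokens-per-parameter
lam_opt = B / (eta * tau_opt * D_seq)
print(round(lam_opt, 4))  # ≈ 0.0963 -- near enough to the usual 0.1
```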

So that actually suggests that the Cerebras equations come up with a reasonably usable number for weight decay if you use the DeepSeek-optimal level for the learning rate, and schedule it in a normal warmup-cosine decay manner. But it's still not as good -- for this model -- as using the GPT-2 number. 7

With that, I think it's time to wrap this intervention up!

Conclusion

Let's look at our results table again:

| | Test set loss | Improvement vs baseline |
| --- | --- | --- |
| 8xa100m40-weight-decay-cerebras | 3.814 | -0.122 |
| 8xa100m40-baseline | 3.692 | - |
| 8xa100m40-gradient-clipping | 3.678 | 0.014 |
| 8xa100m40-qkv-bias | 3.669 | 0.023 |
| 8xa100m40-weight-decay-gpt2 | 3.643 | 0.049 |
| 8xa100m40-remove-dropout | 3.641 | 0.051 |
| 8xa100m40-schedule-learning-rate | 3.602 | 0.09 |

We've found that reducing the weight decay from the now-standard 0.1 to a GPT-2-inspired 0.01 improves the loss our model gets on the test set; it's the third-best intervention so far, after getting rid of dropout and updating our learning rate -- and the difference between it and the dropout intervention is pretty small.

It did surprise me that the Cerebras-inspired number did so badly, though. To recap: used raw, 0.33724 gave us our first result worse than the baseline; adding learning rate scheduling at the 0.0004 peak helped, but still left it worse than the baseline; and once we plug the higher scheduled learning rate of 0.0014 into their formula, it essentially just recommends the standard 0.1 anyway.

I think that for now, I should not head any further down this rabbit hole and just take the win -- we have a weight decay parameter that works better than the one we had, and so that's something that can go into our set of working interventions. I can revisit the Cerebras paper later when I've spent more time studying optimisers.

As to why this old-fashioned GPT-2 value might work better than the current default of 0.1: I think that could plausibly be due to scale. The 0.1 value appears to come from the GPT-3 paper, which essentially was an experiment in scaling up GPT-2. Perhaps larger models need larger weight decays? And the model we're working with here is really small, at 163M parameters.

So, that's weight decay done! Of the list of planned interventions I wanted to try, only training in full-fat 32 bits (rather than AMP) and weight-tying remain. I think I'll look into the second of those next. Stay tuned!

Here's a link to the next post in this series.


  1. More precisely, from Deep Learning:

    Minimizing J(w) results in a choice of weights that make a tradeoff between fitting the training data and being small. This gives us solutions that have a smaller slope, or that put weight on fewer of the features.

    ...where J(w) is the loss function we're trying to minimise in our training run, combining the "real" loss and a measure of the model's size. 

  2. I can't decide whether that makes it easier or harder to understand ;-) 

  3. Wild speculation: how about something using the Shannon entropy of the weights...? 

  4. Specifically the non-embedding training FLOPs. 

  5. Note to self: don't forget to adjust it if we do decide to combine this with the learning rate update. Also: I'm pretty sure from reading the paper that the η that they're using in these formulae is the peak -- they certainly are using learning rate scheduling, albeit with a decay-to-zero rather than the decay-to-10% we used. 

  6. Plugging in the number of sequences into the batch size formula gives us an optimal value of 9.47, which definitely doesn't look right based on the trains I've done. 

  7. Assuming that the GPT-2 value for weight decay "stacks up" well with the learning rate update and the scheduling from the last post. There may be some useful tests to do when we try to put this all together.