Writing an LLM from scratch, part 21 -- perplexed by perplexity

Posted on 7 October 2025 in AI, LLM from scratch, TIL deep dives

I'm continuing through chapter 5 of Sebastian Raschka's book "Build a Large Language Model (from Scratch)", which covers training the LLM. Last time I wrote about cross entropy loss. Before moving on to the next section, I wanted to post about something that the book only covers briefly in a sidebar: perplexity.

Back in May, I thought I had understood it:

Just as I was finishing this off, I found myself thinking that logits were interesting because you could take some measure of how certain the LLM was about the next token from them. For example, if all of the logits were the same number, it would mean that the LLM has absolutely no idea what token might come back -- it's giving an equal chance to all of them. If all of them were zero apart from one, which was a positive number, then it would be 100% sure about what the next one was going to be. If you could represent that in a single number -- let's say, 0 means that it has only one candidate and 1 means that it hasn't even the slightest idea what is most likely -- then it would be an interesting measure of how certain the LLM was about its choice.

Turns out (unsurprisingly) that I'd re-invented something that's been around for a long time. That number is called perplexity, and I imagine that's why the largest AI-enabled web search engine borrowed that name.

I'd misunderstood. From the post on cross entropy, you can see that the measure that I was talking about in May was something more like the simple Shannon entropy of the LLM's output probabilities. That's a useful number, but perplexity is something different.

Its actual calculation is really simple -- you just raise the base of the logarithms you were using in your cross entropy loss to the power of that loss. So if you were using the natural logarithm to work out your loss L, perplexity would be e^L; if you were using the base-2 logarithm log_2, it would be 2^L; and so on. PyTorch uses the natural logarithm, so you'd use the matching torch.exp function.
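
If you want to see that in code, here's a minimal sketch -- the loss value is made up purely for illustration:

```python
import torch

# A made-up cross entropy loss value; in practice this would come from
# something like torch.nn.functional.cross_entropy during training.
loss = torch.tensor(2.7234)

# PyTorch's cross entropy uses natural logarithms, so the matching
# conversion to perplexity is e raised to the power of the loss.
perplexity = torch.exp(loss)
print(perplexity)  # roughly e^2.7234, that is about 15.23
```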

Raschka says that perplexity "measures how well the probability distribution predicted by the model matches the actual distribution of the words in the dataset", and that it "is often considered more interpretable than the raw [cross entropy] loss value because it signifies the effective vocabulary size about which the model is uncertain at each step."

This felt like something I would like to dig into a bit.

When we're training an LLM, the cross entropy loss we get after one step is based on a whole bunch of different sequences and their associated targets. A sequence like "The fat cat sat on the", with its target " fat cat sat on the mat", will have one prefix sequence for each token in the input sequence, and each one will map to one target token:

"The" -> " fat"
"The fat" -> " cat"
"The fat cat" -> " sat"
"The fat cat sat" -> " on"

...and so on. And every other input sequence/target sequence in the batch will contribute the same kind of prefix/target pairs.

Our cross entropy loss is the arithmetic mean of all of those per-sequence/target losses.
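
In PyTorch terms, that mean over all of the per-pair losses looks something like this -- a sketch where random tensors stand in for real model outputs and targets:

```python
import torch
import torch.nn.functional as F

batch_size, seq_len, vocab_size = 2, 6, 50257  # toy shapes for illustration

# Stand-ins for the model's output and the target token IDs: one row of
# logits per prefix sequence, and one target token for each prefix.
logits = torch.randn(batch_size, seq_len, vocab_size)
targets = torch.randint(0, vocab_size, (batch_size, seq_len))

# cross_entropy averages the per-pair losses -- the arithmetic mean
# described above -- over all batch_size * seq_len sequence/target pairs.
loss = F.cross_entropy(logits.flatten(0, 1), targets.flatten())
perplexity = torch.exp(loss)
print(loss, perplexity)
```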

To unpick perplexity, I found it easiest to start thinking about what it meant on a per-sequence/target level -- as always, I'll use

"The fat cat sat on the" -> " mat"

Now, in the last post, we found that the loss for this specific pair was this:

L = -\log p_\text{correct}

...where p_correct was the probability that the model assigned to " mat" -- that is, the correct answer -- in the distribution we got by softmaxing the logits. This is what the full cross entropy formula reduces to if the real distribution that we're comparing our model to is just a one-hot distribution with a probability of one for the target token, and zero for everything else.

Let's work out the perplexity for that; since we're using the natural logarithm ln, we need to raise e to the power of that loss, so we get this:

\text{Perplexity} = e^{-\ln p_\text{correct}}

I had to refresh my knowledge of logarithms a bit for this next bit, but we can simplify that. (Apologies to readers who are more current in this area for the step-by-step derivation.)

To start: log 1 in any base is 0, so we can rewrite the perplexity equation in a slightly more complex form as:

\text{Perplexity} = e^{\ln 1 - \ln p_\text{correct}}

There is an identity for all logarithms:

\log \frac{x}{y} = \log x - \log y

So we can apply that rule and get:

\text{Perplexity} = e^{\ln \frac{1}{p_\text{correct}}}

By the definition of logarithms,

b^{\log_b x} = x

...so we can finally simplify to:

\text{Perplexity} = \frac{1}{p_\text{correct}}
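
As a quick sanity check with a concrete (made-up) number: if the model gave the correct token a probability of 0.25, the loss would be -ln 0.25, and exponentiating gets us straight back to 1/0.25:

```python
import torch

p_correct = torch.tensor(0.25)   # made-up probability for the correct token

loss = -p_correct.log()          # -ln(0.25), about 1.386
perplexity = torch.exp(loss)
print(perplexity)                # tensor(4.) -- that is, 1 / 0.25
```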

Phew. That exercised a few brain cells that had been sleeping since the early 90s. Let's think about what it means in some specific cases.

Firstly, imagine that the LLM is 100% certain that the next token is " mat". That means that the probability vector it has output (logits post-softmax) has a one in the position for " mat" (and therefore zeros everywhere else). That means that p_correct = 1, so perplexity is 1/1 = 1.

Now imagine that the LLM had no idea which token might come next. We'll assume that its output is a probability distribution where every number is 1/N, where N is the number of tokens in the vocabulary. That means that p_correct is 1/N, so our perplexity is 1/(1/N) = N.

This gives a hint as to what Raschka meant by "signifies the effective vocabulary size about which the model is uncertain at each step". In the first case, it was certain -- that is, the subset of the vocabulary that it was considering as options had a size of 1, and that's the number we have for perplexity. In the second case, it was considering all N tokens in the vocab as possibilities, so that's the number we have for perplexity.

We can extend that -- let's say that it was dithering between four of the N tokens in the vocabulary, thinking they were all equally likely (and the others had zero probability). Each of those would have a probability of 1/4, and so perplexity would be 4.

Or would it? There is a wrinkle.

Remember, we're only considering the probability it assigns to the correct token -- our target for the training. Let's say that we have a four-token vocab, and the target in the training item we're considering is the second token (index 1 in the lists that follow).

We might get this probability distribution:

[0.25, 0.25, 0.25, 0.25]

But we might also get this one:

[0.0, 0.25, 0.0, 0.75]

In both of those cases, p_correct is 0.25, so our perplexity will be 4. But while in the first case we could reasonably say that it was uncertain about which option to pick from a vocab size of four, in the second case it's really only looking at a vocab size of two at the most -- indeed, a number between one and two would seem reasonable.
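
To make that wrinkle concrete, here's a quick sketch comparing the perplexity we get from the correct token's probability with e raised to the Shannon entropy of each distribution -- the latter being a measure of how many options the model is genuinely spread across:

```python
import torch

def perplexity_for_target(probs, target_index):
    # Perplexity based only on the probability given to the correct token.
    return 1.0 / probs[target_index]

def effective_options(probs):
    # e to the power of the Shannon entropy: roughly, how many options the
    # distribution is actually spread over (zero-probability terms are
    # dropped, since 0 * ln(0) is taken to be 0).
    p = probs[probs > 0]
    return torch.exp(-(p * p.log()).sum())

uniform = torch.tensor([0.25, 0.25, 0.25, 0.25])
skewed = torch.tensor([0.0, 0.25, 0.0, 0.75])

# The correct token is index 1 in both cases.
print(perplexity_for_target(uniform, 1), effective_options(uniform))  # 4.0, 4.0
print(perplexity_for_target(skewed, 1), effective_options(skewed))    # 4.0, ~1.75
```

Both give a perplexity of 4, but the entropy-based number captures how differently spread out the two distributions are.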

The important thing, though, is that it was wrong about the choice it ascribed the high probability to in that second case. So I think we can refine the description a little. Perplexity doesn't measure how many vocab items the model was choosing between (which, again, would be something like the Shannon entropy of the probability distribution) -- that's an oversimplification of what Raschka was saying. It has to include some kind of reference to what the right answer is.

Let's go back to something from the last post, the section on "Certainty". Our training for LLMs is slightly artificial, at least in the case of a single sequence-target pair. In our example where we're feeding in "The fat cat sat on the" and telling it that the target is " mat", we're using a one-hot vector for the target distribution -- that is, we're saying that the sequence must continue with " mat". That's obviously incorrect -- it could be any of a number of possible tokens, like " lap" or " dog" or " desk".

We don't do that kind of "label smoothing" training with LLMs, because it would be hard to set up and we're training on so much data that the different gradient updates average out to the real distribution over time. But imagine if we did. We have a probability distribution p which is the real distribution of the next tokens, and our LLM's prediction q, so cross-entropy loss (using natural logarithms) is:

H(p, q) = -\sum_x p(x) \cdot \ln q(x)

So our perplexity is:

\text{Perplexity} = \exp\left(-\sum_x p(x) \cdot \ln q(x)\right)

Pulling the negation back inside the sum we get:

\text{Perplexity} = \exp\left(\sum_x p(x) \cdot \left(-\ln q(x)\right)\right)

Now we can apply that -ln a = ln(1/a) rule we went through step-by-step above:

\text{Perplexity} = \exp\left(\sum_x p(x) \cdot \ln \frac{1}{q(x)}\right)

Now another high-school rule; if you take something -- say, z -- to the power of a+b+c -- that is, z^(a+b+c) -- it's the same as multiplying z^a, z^b and z^c. For example:

10^{1+2+3} = 10^1 \times 10^2 \times 10^3 = 10 \times 100 \times 1000 = 1000000 = 10^6

I don't remember seeing it at school, but there is a big-pi operator for doing the product of a series to match the familiar big-sigma for sum, so we can express that (for the exp operator that we're using to mean e to the power of something) as:

\exp\left(\sum_i x_i\right) = \prod_i \exp(x_i)

So we can apply that to our perplexity equation to get this:

\text{Perplexity} = \prod_x \exp\left(p(x) \cdot \ln \frac{1}{q(x)}\right)

Yet one more high-school rule:

z^{a \cdot b} = (z^a)^b

So let's swap around our terms in the exp function:

\text{Perplexity} = \prod_x \exp\left(\ln \frac{1}{q(x)} \cdot p(x)\right)

...then apply that:

\text{Perplexity} = \prod_x \left(\exp\left(\ln \frac{1}{q(x)}\right)\right)^{p(x)}

Now, clearly exp(ln x) = x, so we can simplify:

\text{Perplexity} = \prod_x \left(\frac{1}{q(x)}\right)^{p(x)}
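
Before moving on, here's a quick numerical check that this product form really is the same number as e raised to the full cross entropy -- the p and q values are made up, over a four-token vocab:

```python
import torch

# Made-up "real" distribution p and model prediction q over a 4-token vocab.
p = torch.tensor([0.6, 0.3, 0.1, 0.0])
q = torch.tensor([0.5, 0.2, 0.2, 0.1])

mask = p > 0  # tokens with p(x) = 0 contribute nothing either way

# Route 1: e to the power of the full cross entropy H(p, q).
via_loss = torch.exp(-(p[mask] * q[mask].log()).sum())

# Route 2: the product of (1 / q(x)) raised to the power p(x).
via_product = ((1.0 / q[mask]) ** p[mask]).prod()

print(via_loss, via_product)  # the same number, roughly 2.89
```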

I hope you found that as fun as I did :-)

Now, the 1/q(x) that we have there is the equivalent of the perplexity measure that we were using for our original one-hot calculation -- the inverse of the probability that the model assigned to a given token. What we're doing is iterating over all of the predictions, working out that number, and taking it to the power of the real-world probability of it happening. Those results are then multiplied together to get our overall perplexity.

Let's think of how that works with real numbers.

Imagine that a token doesn't exist at all as an option in the real-world probability distribution p. That means that p(x) will be zero, so the contribution it makes to the overall perplexity will be its per-token perplexity to the power of zero, which is one -- as we're multiplying numbers together, that means that it will have no effect.

If it is 0.5 in the real-world data, then the contribution of the perplexity will be scaled down by being square-rooted (that is, raised to the power of 1/2).

And if it is 1 in the real-world data, it will be raised to the power of 1 -- that is, it will be fed through unchanged.

So we've got a setup where we're taking the per-token perplexity and we're using the real-world probabilities to scale how much it contributes -- the more likely a token is in the real world, the higher the power we're raising it to (though these powers are all less than one, being probabilities, so each item will contribute less than its per-token perplexity).

And, of course, if we feed in a one-hot vector for p then all of the terms but one will have p(x) of zero, so they’ll be raised to the power 0 (i.e. 1). Then for that case where p(x) is one, it will be passed through unchanged -- that is, it will collapse to the original equation:

\text{Perplexity} = \frac{1}{q_\text{correct}}

OK, so at this point we've shown that perplexity, if used against a cross entropy loss that compares how the LLM is doing with respect to the real world probabilities, will give us a number that combines how "confused" the model was about each possible output token, scaled by the probability of that token in the real world. It's not actually all that different from cross entropy loss (which is unsurprising, given that it's just a number raised to the power of that loss).

How does that help in our world where we're using one-hot vectors?

Let's start by bringing back in something we put aside at the start of this post. For a given training run, we have a batch of sequences, with targets for each prefix sequence. So for that run we have b·n sequence/target pairs for a batch size of b and a sequence length of n. Let's call that number T.

The cross entropy loss we have for the ith of those pairs, using our one-hot equation, is:

L_i = -\ln q(t_\text{correct})

Now, we're just taking the arithmetic mean of those different per-pair losses to work out the cross entropy of the whole training batch. Let's write that as:

L = \frac{1}{T} \sum_{t=1}^{T} -\ln q(t_\text{correct})

Now we work out the perplexity:

\text{Perplexity} = \exp\left(\frac{1}{T} \sum_{t=1}^{T} -\ln q(t_\text{correct})\right)

Using -log a = log(1/a) again, we get:

\text{Perplexity} = \exp\left(\frac{1}{T} \sum_{t=1}^{T} \ln \frac{1}{q(t_\text{correct})}\right)

Moving the division inside the sum:

\text{Perplexity} = \exp\left(\sum_{t=1}^{T} \frac{1}{T} \ln \frac{1}{q(t_\text{correct})}\right)

Back to high school maths:

a \cdot \log x = \log x^a

So we can go to:

\text{Perplexity} = \exp\left(\sum_{t=1}^{T} \ln \left(\frac{1}{q(t_\text{correct})}\right)^{\frac{1}{T}}\right)

Using the rule we used above for converting the exponential of a sum to the product of exponentials:

\text{Perplexity} = \prod_{t=1}^{T} \exp\left(\ln \left(\frac{1}{q(t_\text{correct})}\right)^{\frac{1}{T}}\right)

...and then using our b^(log_b x) = x rule, we get:

\text{Perplexity} = \prod_{t=1}^{T} \left(\frac{1}{q(t_\text{correct})}\right)^{\frac{1}{T}}
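
Here's a quick numerical check of that form against exponentiating the mean loss, with made-up probabilities for the correct token across five sequence/target pairs:

```python
import torch

# Made-up probabilities that the model assigned to the correct token for
# each of T sequence/target pairs.
q_correct = torch.tensor([0.60, 0.25, 0.80, 0.05, 0.40])
T = q_correct.numel()

# Route 1: exponentiate the mean cross entropy loss, just like torch.exp(loss).
via_loss = torch.exp((-q_correct.log()).mean())

# Route 2: the product of the per-pair perplexities, each raised to 1/T.
via_product = ((1.0 / q_correct) ** (1.0 / T)).prod()

print(via_loss, via_product)  # the same number, roughly 3.34
```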

Now, we could refactor that per-pair product one step further, like this:

\text{Perplexity} = \left(\prod_{t=1}^{T} \frac{1}{q(t_\text{correct})}\right)^{\frac{1}{T}}

...which shows that the combined perplexity over all sequence/target pairs is the geometric average of the per-token perplexities (just as the arithmetic average of n numbers is their sum divided by n, the geometric average of n numbers is the nth root of their product). And that's interesting, but let's take another look at the first version:

\text{Perplexity} = \prod_{t=1}^{T} \left(\frac{1}{q(t_\text{correct})}\right)^{\frac{1}{T}}

...and then compare it to the formula we had for perplexity where we're using the full cross entropy, comparing our LLM's output to a probability distribution p that captured the real distribution of next tokens:

\text{Perplexity} = \prod_x \left(\frac{1}{q(x)}\right)^{p(x)}

They're tantalisingly close! But those exponents are the difference -- in the first one, we have a constant 1/T exponent, but in the second we have p(x).

Now let's imagine that we have done our training run on a really large and diverse dataset, so large and diverse that the distribution of targets matches what language as a whole has, and we want to work out the perplexity over that big dataset.

As a toy example with essentially the same properties, let's say that all of our training sequences are "the fat cat sat on the", and that in reality that sequence is completed by " mat" 60% of the time, " lap" 30% of the time, and " dog" 10% of the time. Conveniently enough, we have ten training sequences, six of which have " mat" as the target, three have " lap", and one has " dog".

Once again, our perplexity formula is this:

\text{Perplexity} = \prod_{t=1}^{T} \left(\frac{1}{q(t_\text{correct})}\right)^{\frac{1}{T}}

...and what that means is that we have this bit:

\left(\frac{1}{q(t_\text{correct})}\right)^{\frac{1}{T}}

...multiplied ten times, once for each of our sequence/target pairs. Six of those times will be for the " mat" case, so we can just multiply those together (and substitute 10 for T, since we know that's what it is):

\left(\frac{1}{q(\text{mat})}\right)^{\frac{6}{10}}

We can do the same for the " lap" and " dog" cases too, so the whole formula reduces to:

\left(\frac{1}{q(\text{mat})}\right)^{\frac{6}{10}} \times \left(\frac{1}{q(\text{lap})}\right)^{\frac{3}{10}} \times \left(\frac{1}{q(\text{dog})}\right)^{\frac{1}{10}}

Those exponents look familiar! We said up-front that the real world distribution (which our dataset matched) was " mat" 60% of the time, " lap" 30% of the time, and " dog" 10% of the time.

The equation is the product of the per-token perplexities, each taken to the power of its respective probability -- and that is exactly:

\text{Perplexity} = \prod_x \left(\frac{1}{q(x)}\right)^{p(x)}

What we've shown is that using the one-hot probability distribution for the targets doesn't actually break anything; perplexity is always calculated relative to the actual diversity of the possible next tokens in the dataset that we've provided.
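
Here's the toy example in code, with made-up model probabilities q for the three possible next tokens; the perplexity computed from the ten one-hot pairs comes out identical to the perplexity computed against the 60/30/10 distribution directly:

```python
import math

# Made-up model probabilities for " mat", " lap" and " dog" after
# "the fat cat sat on the" (the rest of the vocab gets what's left over).
q = {"mat": 0.5, "lap": 0.3, "dog": 0.05}

# The ten one-hot training targets: six " mat", three " lap", one " dog".
targets = ["mat"] * 6 + ["lap"] * 3 + ["dog"]

# Route 1: perplexity from the mean one-hot cross entropy over the 10 pairs.
mean_loss = sum(-math.log(q[t]) for t in targets) / len(targets)
via_one_hot = math.exp(mean_loss)

# Route 2: perplexity against the "real" distribution p directly.
p = {"mat": 0.6, "lap": 0.3, "dog": 0.1}
via_real_dist = math.prod((1 / q[t]) ** p[t] for t in p)

print(via_one_hot, via_real_dist)  # the same number, roughly 2.93
```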

So when we calculate perplexity across a large enough dataset, we're effectively estimating the same number we'd get if we compared the model directly against the true probabilities of the language itself.

Conclusion

I hope that was an interesting read -- for me, at least, it was a valuable reminder of some basic maths from long ago.

And if there's one thing to take away from this, I think it's this:

Perplexity is, like loss, a metric you use on a model against a particular dataset (aka corpus). It measures how many plausible next tokens the model was effectively choosing between for each prediction it made, weighted by how likely each one is in the dataset. In practice, you often see benchmarks saying things like:

Our model achieved a perplexity of 15.2 on our validation set, and 18.5 on WikiText-103.

What that means is that it was effectively choosing between about 15.2 plausible next tokens on the validation set, and 18.5 on WikiText -- presumably the latter differed from the training data more than the former.

And finally, while you might intuitively think that our use of one-hot probability distributions would cause problems with the maths, for a well-balanced (which in practice just means "large and diverse") corpus it all works out: the more common a particular next-token choice is in the dataset, the more times it contributes its per-token perplexity to the total, which weights each token by its real-world probability.

That's all for now! For the next post, let's see if I can wrap up training :-)

Here's a link to the next post in this series.