JAX backends and devices
There's nothing like writing your own code with a framework to clarify how things
fit together! Continuing with my port of my PyTorch LLM code to
JAX, I wanted to load up a large dataset:
the 10,248,871,837 16-bit unsigned integers in the train split of
gpjt/fineweb-gpt2-tokens.
That's just over 19 GiB of data.
from safetensors.flax import load_file
...
full_dataset = load_file(dataset_dir / "train.safetensors")["tokens"]
When I ran that, I got a CUDA out-of-memory error:
jax.errors.JaxRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 19.09GiB.
That makes sense! The allocation it was trying to do is exactly the size of the data I was trying to load. I have an RTX 3090 with 24 GiB, but some is already used up by the OS, various apps, and a model that the code creates earlier on.
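As a quick sanity check on that number, here's the arithmetic as a small sketch; the variable names are just for illustration, and the only inputs are the token count and the two bytes per 16-bit integer mentioned above:

```python
# Token count from the train split, as quoted above
num_tokens = 10_248_871_837
# Each token is a 16-bit unsigned integer, so two bytes
bytes_per_token = 2

total_bytes = num_tokens * bytes_per_token
# Convert to GiB (2**30 bytes), the unit used in the error message
size_gib = total_bytes / 2**30

print(f"{size_gib:.2f} GiB")  # prints "19.09 GiB"
```

That matches the 19.09GiB in the error message, so the failed allocation really was the whole dataset in one go.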
But in PyTorch land, I was used to things being loaded into RAM by default, and only moved over to the GPU when I asked it to do that. JAX was clearly loading to the GPU by default. How could I stop it from doing that for this case? The load into the GPU was happening inside Safetensors, in code I couldn't directly control.
Understanding how to do it helped me understand a little bit more about JAX.
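One way to keep newly created arrays in host RAM is JAX's `jax.default_device` context manager. This is a minimal sketch and not necessarily the approach this post ends up with. It uses a small `jnp.arange` array as a stand-in for the real dataset, and it assumes that a loader like `load_file` builds its arrays through ordinary `jax.numpy` calls, so that it would pick up the same default:

```python
import jax
import jax.numpy as jnp

# The CPU backend is available even when a GPU is the default backend
cpu = jax.devices("cpu")[0]

# Arrays created inside this block are placed on the CPU device
# rather than on the default accelerator
with jax.default_device(cpu):
    host_tokens = jnp.arange(8, dtype=jnp.uint16)  # stand-in for the real load

print(host_tokens.devices())

# Later, move just the slice you need onto the default device
batch = jax.device_put(host_tokens[:4], jax.devices()[0])
print(batch.devices())
```

The point of the design is that the big array stays in RAM, and only batch-sized pieces are transferred with `jax.device_put` when they're needed.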
Using Safetensors with Flax
I'm porting my PyTorch LLM code to JAX, using Flax as the neural network layer. For various reasons I wanted to use Safetensors to store checkpoints of the model. It took a little while to get it working; here's the trick I learned.
On first looking into JAX
Much have I travell'd in the realms of gold,
And many goodly states and kingdoms seen;
Round many western islands have I been
Which bards in fealty to Apollo hold.
Oft of one wide expanse had I been told
That deep-brow'd Homer ruled as his demesne;
Yet did I never breathe its pure serene
Till I heard Chapman speak out loud and bold:
Then felt I like some watcher of the skies
When a new planet swims into his ken;
Or like stout Cortez when with eagle eyes
He star'd at the Pacific -- and all his men
Look'd at each other with a wild surmise --
Silent, upon a peak in Darien.
John Keats, On First Looking into Chapman's Homer
I've been working with PyTorch quite a lot for the last couple of years, and feel like I've come to a reasonably solid understanding of how it all fits together. Working through Sebastian Raschka's book "Build a Large Language Model (from Scratch)", training my own LLMs locally and in the cloud, rebuilding Andrej Karpathy's 2015-vintage RNNs -- over time, it all adds up!
But, of course, there are other frameworks, and one I kept hearing about was JAX. While it's less dominant than PyTorch, it has a reputation for a certain cleanliness, a certain purity. And having spent time over the last couple of weeks working through the tutorials, and translating small PyTorch examples into it, I've been really impressed.
In this post I want to give an overview -- to report back to beginners like me, still living in PyTorch-land, on my new discovery. Less like Herschel discovering Uranus, and more like a 16th-century European coming back after having discovered something that the people who lived there were perfectly well aware of. What is this JAX thing, and how does it differ from PyTorch?
10Gb/s Ethernet: using mini-heatsinks with a 10GBASE-T SFP+ module
In my last post I showed the somewhat-scary
temperatures I was getting on the MikroTik 10GBASE-T SFP+ module I have plugged
into nigel, the 10Gb/s switch I have in my study.
As I mentioned then, the plan was to try using some of the mini-heatsinks that
people use on Raspberry Pis, to see if that would help.
Here's how it went.
10Gb/s Ethernet: what I actually did to get it working in my home
Having learned enough about 10Gb/s Ethernet to be comfortable about setting it up in my house, it was time to bite the bullet: order it from the ISP, buy some kit, and get started.
I already had 2.5Gb/s working. The apartment has structured cabling -- each room has one or more RJ45 sockets in the wall, and there's a patch panel downstairs by our front door that has a matching patch socket for each wall socket. So when we moved in, I simply set things up so that there was a 2.5Gb/s switch down by the patch panel, and wired everything together there. Most of our stuff works over WiFi, of course, but I needed a wired backbone to connect the excessive number of computers in my study both to each other, and to the outside world.
What did I need to do?
10Gb/s Ethernet: what I had to (re)learn
My ISP recently started offering a 10Gb/s option, and my "shiny new thing!" Pavlovian response immediately kicked in. So of course, I had to upgrade the wired networking in my home -- which meant I had to learn a few things to get it all working, and relearn a bunch of stuff I'd forgotten over the years.
Wired networking for home and small offices hasn't really moved forward that much in the last 20-odd years. Back in 2006, gigabit Ethernet was standard for businesses, and most home users moved to it not long after. Perhaps due to the rise of WiFi for most "last few metres" connections, it's pretty much stagnated there, apart from a bit of a push towards 2.5Gb/s more recently.
But with faster ISP connections arriving, I think things are starting to become a bit more interesting. Even the fastest WiFi 7 connections are only able to get up to around 6Gb/s to a single device -- and that's in an ideal "super-fast machine sitting right next to the AP in a shielded lab" setup.
Here's what I had to drag up from my memory, and the new stuff I had to learn, in order to get this all working. I'll write about the background in this post, and then tomorrow I'll post about what I actually put in place.
Writing an LLM from scratch, part 33 -- what I learned from finally getting round to the appendices
After finishing the main body of "Build a Large Language Model (from Scratch)", I set myself three follow-on goals.
The first was training a full GPT-2-small-style base model myself. That was reasonably easy to do but unlocked a bunch of irresistible side quests; having finally got to the end of those, it's time to move on to the others: reading through the book's appendices, and building my own GPT-2 style model in JAX.
This post is about the appendices. The TL;DR: there was stuff in there that could have saved me time in my side-questing, but I think that having to work those things out from scratch probably helped me learn them better.
Writing an LLM from scratch, part 32m -- Interventions: conclusion
Last November, when I finished the main body of "Build a Large Language Model (from Scratch)", I set myself a number of follow-on goals. One was "training the full GPT-2 base model myself".
I've reached the end of that journey, with a model that is almost -- if not quite -- as good as GPT-2 small, trained in 44 hours on my own machine, so I thought it would be worth summarising how it went.
Writing an LLM from scratch, part 32l -- Interventions: updated instruction fine-tuning results
I've been working on a GPT-2-small-style LLM based on Sebastian Raschka's book "Build a Large Language Model (from Scratch)", and have tried a bunch of different things to see if I could get it to approach the quality of the original OpenAI GPT-2-small, measured in terms of loss on a held-back test dataset. After working through them, in my last post, I managed to train one that was almost (if not quite) there.
Now, back before I started digging into these interventions, I was doing three evals for each model I built: a smoke test (to see if it could give a coherent completion to "Every effort moves you"), a test for that test set loss, and an instruction-following test that fine-tuned the model on the Alpaca dataset, got it to generate results for a test set of instructions, and then used an LLM as a judge to score them.
The idea behind this was that the loss on the test set was an interesting technical measure of the quality of a model, but it didn't really tell us much about how useful it might be in reality.
Unfortunately, in January, I realised that my methodology was bad; because I was asking the LLM to score a model in isolation, the LLM's natural randomness would mean that results were not really comparable, at least for models that were reasonably close in quality.
For example, if two models both replied to
Name the author of 'Pride and Prejudice'.
with:
The author of 'Pride and Prejudice' is Sarah Palin.
...then one run of the instruction-following test might "find the judge LLM in a good mood" and get, say, 5% -- after all, the model tried to answer, and actually used a real person's name, even if the answer was totally wrong. But in another run, the judge might be in a "worse mood" and score it at 0%.
My fix was to have two scripts:
- One that fine-tuned the model, got it to generate responses, and saved those responses in a file.
- One that took a bunch of files generated by the above, one for each of a set of different models, and presented them to the LLM together, so that it would (hopefully) be consistent in how it rated them relative to each other.
The details are here.
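The idea behind the second script can be sketched roughly like this. It's an illustration only, not the actual script: the function name `build_comparison_prompt`, the model labels, the prompt wording, and the 0 to 100 scale are all assumptions made for the example.

```python
def build_comparison_prompt(instruction, responses_by_model):
    """Put every model's answer to one instruction into a single judge prompt."""
    lines = [
        f"Instruction: {instruction}",
        # Hypothetical wording; the real judge prompt may differ
        "Score each response below from 0 to 100, judging them relative to each other.",
        "",
    ]
    # Sort by label so the prompt is the same from run to run
    for label in sorted(responses_by_model):
        lines.append(f"Response from {label}: {responses_by_model[label]}")
    return "\n".join(lines)


# Hypothetical contents of two saved response files, one per model
responses = {
    "model_a": "The author of 'Pride and Prejudice' is Sarah Palin.",
    "model_b": "The author of 'Pride and Prejudice' is Sarah Palin.",
}

prompt = build_comparison_prompt("Name the author of 'Pride and Prejudice'.", responses)
print(prompt)
```

Because the judge sees both identical answers in the same prompt, it has much less room to give them different scores than it would if each were scored in a separate run.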
Because doing it that way was significantly more work, I've not been doing these tests as part of the interventions mini-series. I felt it would make more sense to wait until I'd tried a bunch of interventions and got a number of models to try.
Now I have those, so let's give it a go!
How an LLM becomes more coherent as we train it
I remember finding it interesting when, back in 2015, Andrej Karpathy posted about RNNs and gave an example of how their output improves over the course of a training run. What might that look like for a (relatively) modern transformer-based LLM?
I recently trained a GPT-2-small-style LLM, with 163 million parameters, on about 3.2 billion tokens (that's about 12.8 GiB of text) from the Hugging Face FineWeb dataset, and over the course of that training run, I saved the current model periodically -- 57 checkpoints over two days.
Here's what it looked like -- the start, the end, and some interesting waypoints in between.