Token count. But tokens per batch is a bad metric; I learned this the hard way.
It turns out that what matters is step count. Increasing batch size makes the model train faster, but increasing sequence length from 1024 to 2048 doesn’t make it train twice as fast. So quoting 104M tokens rather than batch size 50k x 6 only misleads yourself. (One of the most surprising parts of learning ML was how easily I could trick myself in silly ways like that.)
A mental model that makes this easy to remember: progress happens in discrete units called steps. Training on 104M tokens per step means the model absorbs 104M tokens of knowledge every step. That isn’t the same as needing a special-case optimizer for large batch sizes: sequence length adds "knowledge bandwidth", but otherwise doesn’t disturb the training dynamics.
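To make the step-count point concrete, here’s a minimal sketch with hypothetical numbers (the token budget and batch size are made up for illustration). With a fixed token budget, doubling sequence length doubles the tokens per step but halves the number of steps, which under the mental model above means half as many discrete chances to make progress.

```python
def tokens_per_step(batch_size: int, seq_len: int) -> int:
    """Tokens consumed by one optimizer step."""
    return batch_size * seq_len


def steps_for_budget(token_budget: int, batch_size: int, seq_len: int) -> int:
    """Whole optimizer steps that fit in a fixed token budget."""
    return token_budget // tokens_per_step(batch_size, seq_len)


# Hypothetical budget of 2**33 (~8.6B) tokens and a batch of 512 sequences.
TOKEN_BUDGET = 2**33
BATCH_SIZE = 512

for seq_len in (1024, 2048):
    print(
        f"seq_len={seq_len}: "
        f"{tokens_per_step(BATCH_SIZE, seq_len):,} tokens/step, "
        f"{steps_for_budget(TOKEN_BUDGET, BATCH_SIZE, seq_len):,} steps"
    )
```

Same tokens, but 16,384 steps at length 1024 versus 8,192 steps at length 2048.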
As for large-batch optimizers, there’s LARS, which Google used for their MLPerf results. I imagine they stuck with it. It computes a per-layer confidence metric, so that when the massive batch size produces a massive change, it dampens that change to smooth out the effect across the network. And since it’s a multiplication, the shape of the gradient (by "shape" I mean in 3D space, where the Z axis is the intensity of the gradient) stays the same, so it doesn’t harm any knowledge transfer. It’s purely a stabilization aid.
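Here’s a minimal, momentum-free sketch of the LARS idea in plain Python, to show why it’s "purely a stabilization aid". The per-layer "confidence metric" is usually called a trust ratio, roughly ||w|| / (||g|| + weight_decay * ||w||). The parameter names, default values, and the fallback for all-zero layers are my assumptions for the example, not Google’s MLPerf code. Because the trust ratio is one scalar per layer, each layer’s update keeps its direction and only its size changes.

```python
import math
from typing import Dict, List

Vector = List[float]


def l2_norm(v: Vector) -> float:
    return math.sqrt(sum(x * x for x in v))


def lars_step(
    params: Dict[str, Vector],
    grads: Dict[str, Vector],
    lr: float = 0.1,
    trust_coef: float = 0.001,
    weight_decay: float = 0.0,
    eps: float = 1e-9,
) -> Dict[str, Vector]:
    """One momentum-free LARS-style update, applied layer by layer."""
    new_params: Dict[str, Vector] = {}
    for name, w in params.items():
        g = grads[name]
        # Update direction: the gradient plus (optional) weight decay.
        d = [gi + weight_decay * wi for gi, wi in zip(g, w)]
        w_norm, g_norm = l2_norm(w), l2_norm(g)
        if w_norm > 0 and g_norm > 0:
            # Per-layer trust ratio: gradients that are large relative to the weights get damped.
            local_lr = trust_coef * w_norm / (g_norm + weight_decay * w_norm + eps)
        else:
            # Assumed fallback for all-zero layers: use the global learning rate unscaled.
            local_lr = 1.0
        # A single scalar per layer, so the update is d rescaled, never rotated.
        new_params[name] = [wi - lr * local_lr * di for wi, di in zip(w, d)]
    return new_params


# Hypothetical layers: one with a huge gradient, one with a tiny gradient.
params = {"layer1": [1.0, 2.0, 2.0], "layer2": [0.5, 0.5]}
grads = {"layer1": [30.0, 0.0, 40.0], "layer2": [0.001, 0.002]}
new_params = lars_step(params, grads)

for name in params:
    update = [w - n for w, n in zip(params[name], new_params[name])]
    print(name, "grad norm:", l2_norm(grads[name]), "update norm:", l2_norm(update))
```

With weight decay off, both update norms come out to about lr * trust_coef * ||w|| (roughly 3e-4 and 7.1e-5 here), even though the gradient norms differ by more than four orders of magnitude.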
Whether you count tokens or sequences, it’s about 25x the usual batch size. My guess is it makes for a fancy benchmark but isn’t actually useful. Would be interested in being proven otherwise.
This is a great series of questions, and it isn't our goal to prove you wrong!
We work with customers interested in training models, and they run their own ablations, including on batch size and learning rate.
Based on that, we demonstrate workloads that we think will be interesting to potential customers! Absolutely agreed that this workload has a larger batch size than the public literature suggests.
Wait, this actually proves their point. It sounds like you didn’t train a model, but rather ran a bunch of ops. That’s fine, but the answer to their question would be "we didn’t actually measure loss or try to get a useful result, because the goal was to demonstrate raw throughput."
I agree that raw throughput is the metric to aim for, since figuring out how to put it to use is an exercise for the user. But it’s probably best to be straightforward about that. MLPerf measures "time until a ResNet classifier reaches a target accuracy on ImageNet" precisely because that gives information about performance at scale. If you didn’t train anything, you haven’t actually achieved "fastest training run". You’ve achieved the highest throughput, which is similar but not the same.
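A minimal sketch of the difference between the two metrics, assuming a run logs wall-clock time, cumulative tokens, and eval loss (the LogEntry record, its fields, and both runs’ numbers are hypothetical). Throughput only needs a timer; a time-to-train number needs the model to actually reach a quality threshold, which is the kind of metric MLPerf reports.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class LogEntry:
    wall_clock_s: float  # seconds since the run started
    tokens_seen: int     # cumulative tokens processed
    eval_loss: float     # evaluation loss measured at this point


def throughput(log: List[LogEntry]) -> float:
    """Average tokens per second over the run. Says nothing about model quality."""
    last = log[-1]
    return last.tokens_seen / last.wall_clock_s


def time_to_target(log: List[LogEntry], target_loss: float) -> Optional[float]:
    """Time-to-train: first wall-clock time at which eval loss reaches the target,
    or None if the run never gets there."""
    for entry in log:
        if entry.eval_loss <= target_loss:
            return entry.wall_clock_s
    return None


# Two hypothetical runs: A has twice the throughput, but only B reaches the target.
run_a = [LogEntry(100.0, 4_000_000_000, 3.9), LogEntry(200.0, 8_000_000_000, 3.6)]
run_b = [LogEntry(100.0, 2_000_000_000, 3.5), LogEntry(200.0, 4_000_000_000, 2.9)]

for name, log in (("A", run_a), ("B", run_b)):
    print(name, f"{throughput(log):,.0f} tokens/s", time_to_target(log, target_loss=3.0))
```

Run A wins on throughput and has no time-to-train at all, which is the gap between "largest throughput" and "fastest training run".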
And I don’t think this is a pedantic distinction. Just throw LARS on it (the MLPerf code you used in 2019 is at https://github.com/shawwn/daxx-lightning fwiw, and it ran on pods the last time I tried) and see how it performs in practice.
EDIT: reading over https://github.com/google/maxtext, it looks pretty delightful. I was in the TPU scene back in 2020, and there was no way to do ahead-of-time compilation. Restarting training runs was a major pain point once LLMs became the focus, and I kept pestering James Bradbury to please add it. Happy to see that it finally made its way in.
It sounds like MaxText is the right approach, but until you try to actually train a model (to achieve a low loss on a specific dataset), you can’t know whether the code works. This isn’t theoretical. I spent over a year debugging Google’s public BigGAN code (compare_gan) and discovered why it never worked: the batch norm gamma parameter was initialized to zero instead of one, so everything was multiplied by zero at the start, which severely crippled the model.
A bug like that could easily be lurking in MaxText. You can’t know until you try to train a useful LLM. Note that compare_gan seemed to work: the authors acknowledged that they couldn’t replicate the results of the official BigGAN paper, but the samples looked sort of reasonable. The model was still screwy, and no one knew why until it was rigorously debugged.
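A tiny plain-Python sketch of the textbook batch norm forward pass over a single feature (not compare_gan’s code) shows why a zero-initialized gamma is so damaging: whatever the inputs are, every output collapses to beta. The gradient flowing back to the layer’s inputs is also multiplied by gamma, so earlier layers get no learning signal through it until gamma moves away from zero (gamma’s own gradient isn’t zero, so it can still grow).

```python
import math
from typing import List


def batch_norm(x: List[float], gamma: float, beta: float, eps: float = 1e-5) -> List[float]:
    """Training-mode batch norm for one feature across a batch."""
    mean = sum(x) / len(x)
    var = sum((xi - mean) ** 2 for xi in x) / len(x)
    # Normalize to roughly zero mean and unit variance...
    x_hat = [(xi - mean) / math.sqrt(var + eps) for xi in x]
    # ...then apply the learned scale (gamma) and shift (beta).
    return [gamma * xh + beta for xh in x_hat]


activations = [0.3, -1.2, 2.5, 0.8]

# Standard init (gamma=1, beta=0): the normalized signal passes through.
print(batch_norm(activations, gamma=1.0, beta=0.0))

# Buggy init (gamma=0): the input is erased and every output equals beta.
print(batch_norm(activations, gamma=0.0, beta=0.0))  # [0.0, 0.0, 0.0, 0.0]
```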
If you need help with this, let me know. There are challenges when training an actual LLM that aren’t present in theoretical runs like these. For example, you need a big dataset. The Pile is a good starting point for that, and it gives you a nice baseline for comparison, e.g. against GPT-J.
Alternatively, post a link to a tensorboard.dev page showing the loss curves for your training runs. I suspect you didn’t because you didn’t have a real dataset. That’s ok, but there’s no proof that MaxText works until there’s empirical evidence.
In other words, DavidSJ was precisely right: it’s an impressive-looking benchmark that doesn’t actually help your customers train LLMs in practice. They’ll need to solve this problem eventually, and the optimizer is certainly one aspect. The other is quantized INT8 training. It may sound impressive to say it gives a 1.4x step-count speedup, but that’s useless if it harms loss convergence. How do you know it doesn’t? That’s hard to answer unless you run MLPerf or some other known stable baseline, which I’m a little shocked no one has done yet.
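To illustrate why an INT8 speedup doesn’t settle the convergence question, here’s a sketch of one common scheme, symmetric per-tensor abs-max quantization. I’m not claiming this is what MaxText does; the scheme and the example tensor are assumptions for illustration. One large value sets the scale for the whole tensor, so small values can round to zero, and that lost precision is exactly the kind of thing that only shows up in a loss curve.

```python
from typing import List, Tuple


def quantize_int8(x: List[float]) -> Tuple[List[int], float]:
    """Symmetric per-tensor quantization: map [-max|x|, max|x|] onto integers in [-127, 127]."""
    max_abs = max(abs(v) for v in x)
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    # Round to the nearest integer step and clamp to the symmetric int8 range.
    q = [max(-127, min(127, round(v / scale))) for v in x]
    return q, scale


def dequantize_int8(q: List[int], scale: float) -> List[float]:
    """Map int8 codes back to floats."""
    return [qi * scale for qi in q]


# Hypothetical gradient-like tensor: several small entries and one outlier.
tensor = [0.001, -0.002, 0.0015, 5.0]
q, scale = quantize_int8(tensor)
print(q)                          # [0, 0, 0, 127]
print(dequantize_int8(q, scale))  # the small entries come back as 0.0
```

Without the outlier, the same three small values survive with an error of at most half a quantization step, so whether INT8 hurts depends on the actual value distributions, which is why it has to be measured.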
Totally agreed with all your comments! Right now, MaxText mainline is a reference implementation for users who have their own scientific opinions on model architecture and convergence. We're also hoping to make MaxText compatible with some open-source models for customers who want to use known good configurations.
Kind of weird I remember that after three years.