Full title: "Google Cloud demonstrates the world’s largest distributed training job for large language models across 50000+ TPU v5e chips"
Summary from Bard: "This article is about training large language models (LLMs) on Google Cloud TPUs. It discusses the challenges of training LLMs at scale, and how Google Cloud TPU Multislice Training addresses these challenges. The article also details the results of a recent experiment in which Google trained a 128B parameter LLM on 50,944 TPU v5e chips. This experiment is the largest publicly disclosed LLM distributed training job to date."
Great questions! A slice is a set of TPU chips that share a fast, private inter-chip interconnect (ICI). Unlike the current GPU generation in clouds, TPUs on different machines can communicate over this private network. Multislice means we're using a hierarchical network: ICI within each slice, plus normal data-center networking between slices.
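If it helps to make that concrete, here's a rough sketch of how a multislice job might describe that two-level network in JAX. The slice count and axis sizes are made up for illustration (not the ones from the blog post), and this only runs in a real multi-slice TPU environment:

    # Hypothetical 2-slice setup: 256 chips per slice on ICI, slices joined over DCN.
    import jax
    from jax.experimental import mesh_utils
    from jax.sharding import Mesh

    devices = mesh_utils.create_hybrid_device_mesh(
        mesh_shape=(1, 256),    # inner mesh: chips within a slice (ICI)
        dcn_mesh_shape=(2, 1),  # outer mesh: slices connected over data-center networking
    )
    mesh = Mesh(devices, axis_names=("data", "model"))
    # Sharding along "data" crosses the slower DCN hops; "model" stays on ICI.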
Also, I should point out that a set of machines hosting TPUs is referred to as a "pod", which is not the same thing as a Kubernetes pod (also referenced in this doc).
Kubernetes chose "pod" to represent a set of co-scheduled containers, like a "pod of whales". Other systems like Mesos and Google's Borg https://storage.googleapis.com/pub-tools-public-publication-... use "task" to refer to a single container, but didn't have a concept for heterogeneous co-scheduled tasks at the time.
Somewhat ironically, it now means TPUs on GKE are confusing, because we have TPU hosts organized into "pods" and "pods" for the software using the TPUs.
A Kubernetes pod using a TPU lands on a host which is part of a slice of a TPU pod.
True. The pod’s monotonic and atomic lifecycle across containers is a significant difference, but you can broadly accomplish similar behaviors with an alloc for sharing resources.
Unlikely. One reason Google Cloud is so terrible is that nobody in Google actually uses Google Cloud. It used to be that every time I mentioned this, somebody would jump in and say, "Well actually, Google Domains runs on Google Cloud," and we'd discuss whether Google Domains was a business critical part of Google. https://support.google.com/domains/answer/13689670?hl=en
> Unlikely. One reason Google Cloud is so terrible is that nobody in Google actually uses Google Cloud.
Well, actually, Google Cloud is just an abstraction on top of internal Google infra, so this isn't the right question. It depends on what you want to infer/compare.
> Well, actually, Google Cloud is just an abstraction on top of internal Google infra
I didn't say otherwise. Of course Google Cloud runs on internal Google infrastructure. They wouldn't have an entirely different stack to build Google Cloud on. The problem is that Googlers don't use Google Cloud.
It is the right question, because Google doesn't dogfood Google Cloud the way it should/could. Dogfooding a bunch of stuff at a lower level of abstraction isn't the same thing.
Xoogler here. GCP was not an abstraction on Borg when I was there. GKE isn't either.
So up until late 2018 when I left, very little of Cloud ran on "proper" google3 infra. This may have shifted slightly (Cloud has been fishing for good internal infra to externalize), but in general Cloud != google3 infra.
If anything important ran on Google Cloud, you can bet we'd see a blog post from Google Cloud marketing about that. Yes, many of the money losing side bets from the non-Google companies under the Alphabet umbrella use Google Cloud. That's only because they want the optionality to spin them off if by some miracle any of them are ever worth anything. If they were part of Google, they would use internal infrastructure. If they weren't under Alphabet, they would use AWS or Azure like everyone else.
Question for rwitten or anyone else involved in this project:
I see a per-device batch size of 6 for the 16B model. With 256x199 = 50944 TPUs and a sequence length of 2048, this works out to 104M tokens per batch. This is much larger than typical for training runs of dense LMs of this size, which are usually closer to ~4M tokens per batch.
Was your critical batch size really this large? In other words, did you really see a benefit as compared to a much smaller batch size (and probably many fewer TPUs)? Did you use some special learning rate schedule or optimizer to achieve this?
Token count. But tokens per batch is a bad metric. I learned this the hard way through experience.
It turns out that what matters is step count. Increasing batch size makes the model train faster, but increasing seq length from 1024 to 2048 doesn’t make it train twice as fast. So saying 104M tokens rather than batch size 50k x 6 is misleading to yourself. (One of the most surprising aspects of learning ML was how easy it is for me to trick myself in various silly ways like that.)
The mental model to make this easy to remember: progress happens in discrete quantities called steps. Training on 104M tokens per step means it's embedding 104M tokens of knowledge every step. This isn't the same thing as requiring a special-case optimizer due to large batch sizes; sequence length adds "knowledge bandwidth", but otherwise doesn't mess with the training dynamics.
As far as large-batch optimizers go, there's LARS, which Google used for their MLPerf results. I imagine they stuck with that. It creates a per-layer confidence metric, so that when the massive batch size produces a massive change, it dampens the change to smooth out the effect across the network. And since it's a multiply, the shape of the gradient (by "shape" I mean in 3D space, where the Z axis is the intensity of the gradient) remains the same, so it doesn't harm any knowledge transfer. It's purely a stabilization aid.
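Roughly, the layer-wise trust ratio LARS applies looks like this sketch (momentum and the exact MLPerf hyperparameters omitted; the coefficient values below are made-up defaults):

    import jax.numpy as jnp

    def lars_update(param, grad, lr, weight_decay=1e-4, trust_coeff=0.001, eps=1e-9):
        # Per-layer "confidence": shrink the step when the gradient is large
        # relative to the weights, so a huge batch can't blow up a single layer.
        w_norm = jnp.linalg.norm(param)
        g_norm = jnp.linalg.norm(grad)
        trust = trust_coeff * w_norm / (g_norm + weight_decay * w_norm + eps)
        return param - lr * trust * (grad + weight_decay * param)

Because the scaling is a scalar multiply per layer, the direction of the update is unchanged; only its magnitude is tamed.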
Whether you count tokens or sequences, it’s about 25x the usual batch size. My guess is it makes for a fancy benchmark but isn’t actually useful. Would be interested in being proven otherwise.
This is a great series of questions, and it isn't our goal to prove otherwise!
We work with customers interested in training models who run their own ablations, including batch size and learning rates.
Based on that, we demonstrate workloads that we think will be interesting to potential customers! Absolutely agreed that this workload has a larger batch size than the public literature suggests.
Wait, this actually proves their point. It sounds like you didn’t train a model, but rather ran a bunch of ops. That’s fine, but the answer to their question would be "we didn’t actually measure loss or try to get a useful result, because the goal was to demonstrate raw throughput."
I agree that raw throughput is the metric to aim for, since figuring out how to put it to use is an exercise for the user. But it's probably best to be straightforward about that. The reason MLPerf measures "time until loss reaches X for a resnet classifier on imagenet" is precisely because it gives information about performance at scale: if you didn't train anything, you haven't actually achieved "fastest training run". You've achieved largest throughput, which is similar but not the same.
And I don’t think this is a pedantic distinction. Just throw LARS on it (the MLPerf code you used in 2019 is at https://github.com/shawwn/daxx-lightning fwiw, and it runs on pods last time I tried) and see how it performs in practice.
EDIT: reading over https://github.com/google/maxtext, it looks pretty delightful. I was in the TPU scene back in 2020, and there was no way to do ahead of time compilation. Restarting training runs was a major pain point once LLMs became the focus, and I kept pestering James Bradbury to please add it. Happy to see that it finally made its way in.
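(For anyone who hasn't seen it, the AOT path in current JAX looks roughly like this; the shapes and the toy train_step below are placeholders, not MaxText's API:)

    import jax
    import jax.numpy as jnp

    def train_step(params, batch):
        # stand-in for a real training step
        return params - 1e-3 * jnp.mean(batch) * params

    # Compile against abstract shapes, before committing any data or TPU time.
    abstract_params = jax.ShapeDtypeStruct((1024, 1024), jnp.float32)
    abstract_batch = jax.ShapeDtypeStruct((8, 2048), jnp.float32)

    lowered = jax.jit(train_step).lower(abstract_params, abstract_batch)
    compiled = lowered.compile()  # reusable executable; no re-tracing at job start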
It sounds like MaxText is the right approach, but until you try to actually train a model (to achieve a low loss on a specific dataset) you can't know whether the code works. This isn't theoretical. I spent over a year debugging Google's public BigGAN code (compare_gan) and discovered why it never worked: the batch norm gamma parameter was initialized to zero instead of one, so everything was being multiplied by zero to start off, which severely crippled the model.
A bug like that could easily be lurking in MaxText. You can’t know until you try to train a useful LLM. Note that compare_gan seemed to work; the authors noted that they couldn’t replicate the performance of the official BigGAN paper, but the samples looked sort of reasonable. But the model was screwy, and no one knew why until the rigorous debugging process.
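As a toy illustration of that class of bug (not the actual compare_gan code), a batch-norm scale initialized to zero silences the layer no matter what flows in:

    import jax.numpy as jnp

    def batch_norm(x, gamma, beta, eps=1e-5):
        x_hat = (x - x.mean(axis=0)) / jnp.sqrt(x.var(axis=0) + eps)
        return gamma * x_hat + beta

    x = jnp.arange(32.0).reshape(4, 8)
    bad = batch_norm(x, gamma=jnp.zeros(8), beta=jnp.zeros(8))  # everything zeros out
    good = batch_norm(x, gamma=jnp.ones(8), beta=jnp.zeros(8))  # signal survives

The loss curve is what catches this; throughput numbers never will.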
If you need help with this, let me know. There are challenges when training an actual LLM that aren’t present in theoretical runs like these. For example, you need a big dataset. The Pile is a good starting point for that, and it gives a nice comparison baseline, e.g. to GPT-J.
Alternatively, post a link to a tensorboard.dev showing the loss curves for your training runs. I suspect the reason you didn’t is because you didn’t have a real dataset. That’s ok, but it doesn’t prove that MaxText works until there’s empirical evidence.
In other words, DavidSJ was precisely right: it's an impressive-looking benchmark, which doesn't actually help your customers train LLMs in practice. They'll need to solve this problem eventually, and the optimizer is certainly one aspect. The other is the quantized INT8 training. It may sound impressive to say it gives a 1.4x step time speedup, but that's useless if it harms loss convergence. How do you know it doesn't? This isn't an easy question to answer unless you run MLPerf or some other known stable baseline, which I'm a little shocked no one has done yet.
Totally agreed with all your comments! MaxText mainline is right now a reference implementation for users who have their own scientific opinions on model architecture and convergence. We're additionally hoping to let MaxText run compatibly with some open-source models for customers who want to use known good configurations.
You can measure it either way, and you’ll see it both ways in the literature. In this case it doesn’t matter much how you measure since 2048 is a typical pretraining sequence length. 300k sequences per batch is huge compared to typical batch sizes in the literature, which are closer to 2048, for about 4M tokens total.
Ok, so they claim in the article that 50,000 TPUs are equivalent to 10 exaFLOPs of floating point compute. That is equivalent to ~2,512 NVIDIA H100s, which is really small. Just shows the difference between TPUs and GPUs, I guess. Inflection, a new LLM company, created a 20,000-H100 cluster; I'm positive OpenAI, Tesla, Meta, etc. have orchestrated a job on more than 2,500 H100 GPUs.
Hey! I'm a contributor on this (Rafi Witten), all opinions my own.
You're asking the right question, but I think the math is off by a bit. The equivalent number for the H100 is 989 TFLOP/s per chip, so the equivalent job is ~10K H100s = (10 * 10^18) / (989 * 10^12). (Both chips also have 8-bit acceleration!)
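Spelling out the back-of-envelope math (peak numbers, not utilization):

    cluster_flops = 10e18         # 50,944 TPU v5e chips at 16-bit, per the blog post
    h100_bf16_flops = 989e12      # per-chip dense (no sparsity) figure used above
    print(cluster_flops / h100_bf16_flops)  # ~10,111 H100s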
I believe this is the largest ML job ever demonstrated, both by exaflops and by number of chips. Other companies own more chips or exaflops than we show in this job, but getting all the hardware working at once on a single job is a different matter! :-)
I think your math is also slightly off. The Google article claims
“that is capable of achieving 10 exa-FLOPs (16-bit)”, so you should be comparing with 16-bit operations on an H100.
989 is the TF32 number; for 16-bit it is 1,979, so I guess around 5,000 H100s in a single training job would be equivalent to the training job mentioned in this article.
Either way, I actually would not be surprised if OpenAI has launched a single job on more than 10k GPUs, but I'm also not very knowledgeable about practical scaling. Congrats on the feat!
No, it is not. That's the with-sparsity number; you need to ignore sparsity and compare dense bf16 flops, not fp8 flops, for the comparison the ancestor post is making.
It's worth noting that just because an H100 has a higher flops number doesn't mean your program is actually hitting that number of flops. Modern TPUs are surprisingly competitive with Nvidia on a perf/$ metric, if you're doing cloud ML they are absolutely worth a look. We have been keeping costs down by racking our own GPUs but TPUs are so cost effective that we need to do some thinking about changing our approach.
I'm not certain but I think part of this is that XLA (for example) is a mountain of chip-specific optimizations between your code and the actual operations. So comparing your throughput between GPU and TPU is not just flops-to-flops.
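For what it's worth, the comparison that usually settles this is model FLOPs utilization (MFU) rather than datasheet peaks. A rough sketch, using the standard ~6 FLOPs per parameter per token approximation for dense transformers; all the concrete numbers in the example call are made up:

    def mfu(n_params, tokens_per_step, step_time_s, peak_flops_per_chip, n_chips):
        achieved = 6 * n_params * tokens_per_step / step_time_s  # model FLOP/s actually delivered
        return achieved / (peak_flops_per_chip * n_chips)        # fraction of theoretical peak

    # e.g. a hypothetical 16B-parameter run on 256 chips at ~197 bf16 TFLOP/s peak each:
    print(mfu(16e9, 4_194_304, 16.0, 197e12, 256))  # ~0.5 with these made-up numbers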
This is a blog post from Google Cloud marketing. It's saying that you, too, could train an LLM on Google Cloud if you hand them enough money. You can't do that on Inflection's or Tesla's clusters. Similar marketing blog post from last year: https://cloud.google.com/blog/products/compute/calculating-1...
The PaLM paper linked in the blog post is about how to get something actually useful out of that compute.
Something that doesn't seem worth bragging about is that the startup time increases linearly with the cluster size. Wouldn't you want it to be constant? What's the issue there?
Disclaimer: I work with the team associated with this, but didn't write or review the blog post.
The article stated that it was the throughput of scheduling the pods on the clusters (from unrelated benchmarks, that's usually ~300 pods/sec for the kube scheduler today), plus doing XLA compilation at pod launch rather than amortizing it once across all jobs.
Optimizing kube scheduler throughput is a good general opportunity and something I believe we would like to see.
I believe AOT compilation just wasn't a critical optimization for this test; for large, long-running training jobs we would recommend AOT compiling to keep pod start latency low across hardware failures and job restarts (from checkpoints).
> The start times we observed were impressive, but we believe we can improve these even further. We are working on areas such as optimizing scheduling in GKE to increase throughput and enabling ahead-of-time compilation in MaxText to avoid just-in-time compilations on the full cluster.
Thanks for the context. I remember recently reading a paper, from Baidu I think, claiming a container arrival rate in the millions per second; consequently, it was practical to operate their whole site in the style of lambda/cloud functions.
Actually now that I am searching for that it seems Baidu has a number of papers on workload orchestration at scale specifically for learning.
I will note a trend I have observed with recent ML: as we increasingly use accelerators and models correspondingly grow in size, we are returning to a "one machine, one workload" paradigm for the biggest training and inference jobs. You might have 8k accelerators but only 1,000 machines, and if you have one container per host, 300 schedules/second is fast.
At the same time, as you note, we have functional models for container execution approaching millions of dispatches for highly partitionable work, especially in data engineering and ETL.
(Contributor on the blog post, all opinions my own)
Agreed with you and we definitely weren't trying to brag! This is fast compared to people's expectations in the space but slow compared to what we should be able to accomplish and will accomplish in the future.