+1, we still have a lot of performance we can extract! JIT-compiled train steps, more optimized data loading and sharding, gradient accumulation, and activation checkpointing. We will continue building and will do another blog soon after implementing all the improvements!
We initially started with the goal of fine-tuning LLaMA 3 on TPUs, but PyTorch XLA was clunky, so we decided to rewrite the model in JAX. That said, as mentioned earlier in the thread, we also believe JAX is a better platform for non-NVIDIA GPUs, and we want to build our infra for non-NVIDIA GPUs on JAX+OpenXLA.
Note: we couldn't run the JIT-compiled version of the 405B model due to our code/VRAM constraints (we need to investigate this further). The entire training run was executed in JAX eager mode, so there is significant potential for performance improvements.
GPU utilization across the board was still ~30-40% even with eager mode, which is quite good! With JIT, I think the GPU util can easily shoot up to ~50-60%.
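For readers unfamiliar with the eager vs. JIT distinction mentioned here, a minimal sketch follows. It uses a tiny linear model as a stand-in (the model, `loss_fn`, and `train_step` are hypothetical and not taken from the Felafax code); the only point is that the same Python function can run op by op or be traced once and compiled by XLA with `jax.jit`.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Tiny linear model standing in for the real network (hypothetical).
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

def train_step(params, x, y, lr=0.1):
    # One plain SGD step: compute loss and gradients, then update parameters.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

# Eager mode dispatches each op separately; jax.jit traces the whole step
# once and hands it to XLA as a single compiled program.
jit_train_step = jax.jit(train_step)

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (16, 4))
y = jnp.sum(x, axis=1, keepdims=True)
params = {"w": jnp.zeros((4, 1)), "b": jnp.zeros((1,))}

eager_params, eager_loss = train_step(params, x, y)      # eager
jit_params, jit_loss = jit_train_step(params, x, y)      # compiled
```

Both calls should produce the same numbers up to floating-point tolerance; the compiled version mainly changes how the work is scheduled and how much memory the compiler needs, which is where the OOM described in the thread would show up.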
I’d be interested to hear how long it takes, wall clock, to train the same model to the same loss with the same number of batches. I don’t trust utilization to say anything useful.
Hey HN, we recently fine-tuned the llama3.1 405B model on 8xAMD MI300x GPUs using JAX instead of PyTorch. JAX's advanced sharding APIs allowed us to achieve great performance. Check out our blog post to learn about the cool sharding tricks we used. We've also open-sourced the code: https://github.com/felafax/felafax
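As a rough illustration of the kind of sharding API being referred to (not the specific tricks from the blog post), the sketch below builds a batch that is sharded along a `data` mesh axis. It assumes a hypothetical `load_shard` callback and made-up shapes; `jax.make_array_from_callback` asks the callback for each device's slice separately, so the full batch never has to sit on one device.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One mesh axis named "data" over all visible devices (a single CPU also works).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))
sharding = NamedSharding(mesh, P("data", None))  # shard rows, replicate columns

n_dev = len(devices)
global_shape = (4 * n_dev, 8)  # (batch, seq_len); hypothetical sizes

def load_shard(index):
    # `index` is a tuple of slices describing one device's part of the batch.
    rows = range(*index[0].indices(global_shape[0]))
    cols = range(*index[1].indices(global_shape[1]))
    # Stand-in for reading just these rows from disk.
    return np.array(
        [[r * global_shape[1] + c for c in cols] for r in rows], dtype=np.int32
    )

# Each device's shard is produced by its own callback invocation.
batch = jax.make_array_from_callback(global_shape, sharding, load_shard)
print(batch.shape, len(batch.addressable_shards))
```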
We're a small startup building AI infra for fine-tuning and serving LLMs on non-NVIDIA hardware (TPUs, AMD, Trainium).
Problem: Many companies are trying to get PyTorch working on AMD GPUs, but we believe this is a treacherous path. PyTorch is deeply intertwined with the NVIDIA ecosystem in a lot of ways (e.g., `torch.cuda` or scaled_dot_product_attention is an NVIDIA CUDA kernel exposed as a PyTorch function). So, to get PyTorch code running on non-NVIDIA hardware, there's a lot of "de-NVIDIAfying" that needs to be done.
Solution: We believe JAX is a better fit for non-NVIDIA hardware. In JAX, ML model code compiles to hardware-independent HLO graphs, which are then optimized by the XLA compiler before hardware-specific optimization. This clean separation allowed us to run the same LLaMA3 JAX code both on Google TPUs and AMD GPUs with no changes.
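To make the "model code compiles to HLO" step concrete, here is a small sketch that lowers a toy function and prints the compiler IR that JAX hands to XLA. The `mlp` function is a hypothetical example; the exact IR dialect shown (StableHLO in recent releases) depends on the JAX version.

```python
import jax
import jax.numpy as jnp

def mlp(x, w):
    # Toy layer standing in for real model code (hypothetical).
    return jax.nn.relu(x @ w)

x = jnp.ones((2, 4))
w = jnp.ones((4, 3))

# Lower the jitted function for these input shapes without running it.
lowered = jax.jit(mlp).lower(x, w)
hlo_text = lowered.as_text()  # hardware-independent IR as text

print(hlo_text)
print("backend:", jax.devices()[0].platform)  # where XLA would compile it
```

The printed IR does not mention the target device; the backend-specific work happens later, when XLA compiles it for whatever platform `jax.devices()` reports.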
Our strategy as a company is to invest upfront in porting models to JAX, then leverage its framework and XLA kernels to extract maximum performance from non-NVIDIA backends. This is why we first ported Llama 3.1 from PyTorch to JAX, and now the same JAX model works great on TPUs and runs perfectly on AMD GPUs.
We'd love to hear your thoughts on our vision and repo!
While your project is neat and I'd like to see how the performance compares, for LLM training, PyTorch, including torch.compile, works completely OOTB on AMD.
All you have to do is pip install the ROCm version of PyTorch (or run the docker image) and it's seamless (the ROCm version just treats torch.cuda as calling ROCm).
I've used axolotl (trl/accelerate based), torchtune, and LLaMA-Factory, which are all PyTorch-based without any issues for training.
Yeah I would suggest taking a look at PyTorch on AMD before saying stuff like "scaled_dot_product_attention is an NVIDIA CUDA kernel exposed as a PyTorch function", because that is demonstrably false.
Also, FWIW, I would suggest getting a small Llama 3.1 model training fast before trying to do a big 405B model -- faster to iterate and almost everything you'll learn on the small models will scale to the 405B.
Thanks for the feedback! I appreciate you pointing that out. My understanding was based on the PyTorch documentation for scaled_dot_product_attention (https://pytorch.org/docs/stable/generated/torch.nn.functiona...).
- "The function may call optimized kernels for improved performance when using the CUDA backend. For all other backends, the PyTorch implementation will be used."
And was trying to make a broader point about the lack of transparency (in performance, lower-level impl) in PyTorch when running on NVIDIA vs. non-NVIDIA hardware.
> And was trying to make a broader point about the lack of transparency (in performance, lower-level impl) in PyTorch when running on NVIDIA vs. non-NVIDIA hardware.
I don't quite understand this argument. Lack of transparency from running PyTorch so instead we're gonna leave it all to XLA? How does this solve the "transparency" issue?
Having a common library function that is either lightning fast or dog slow depending on the hardware is not a great position to be in.
Moreover, this will get worse as more CUDA specific features are added to PyTorch with ad-hoc fallback functions.
I guess OP is saying that XLA is more transparent in this regard, because it wouldn’t use functions like these and the generated comparable code would be on par performance-wise?
> it wouldn’t use functions like these and the generated comparable code would be on par performance-wise
Perhaps if XLA generated all functions from scratch, this would be more compelling. But XLA relies very heavily on pattern-matching to common library functions (e.g. CuDNN), and these patterns will certainly work better on Nvidia GPUs than AMD GPUs.
In this way, I think explicitly calling the common library functions is actually much more transparent.
How are you verifying accuracy for your JAX port of Llama 3.1?
IMHO, the main reason to use PyTorch is actually that the original model used PyTorch. What can seem to be identical logic between different model versions may actually cause model drift when infinitesimal floating point errors accumulate due to the huge scale of the data. My experience is that debugging accuracy mismatches like this in a big model is a torturous ordeal beyond the 10th circle of hell.
Good question. We used a new AI+math-based testing tool (benchify.com) to run comparison tests, but we are working on building more robust infrastructure for this. Translating models from PyTorch to JAX is core to our strategy.
That said, this path is not uncommon (translating from one framework to another). HuggingFace translates Google's Gemma family models from JAX to PyTorch, and a ton of people use it.
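For the accuracy question above, one common generic check (this is not the benchify-based workflow mentioned, just a sketch under stated assumptions) is to feed identical inputs to the reference model and the port and compare their logits within a tolerance. The function name `compare_outputs`, the tolerances, and the synthetic logits below are all hypothetical.

```python
import numpy as np

def compare_outputs(reference, candidate, rtol=1e-3, atol=1e-3):
    """Summarize how closely two logit tensors agree."""
    reference = np.asarray(reference, dtype=np.float64)
    candidate = np.asarray(candidate, dtype=np.float64)
    abs_err = np.abs(reference - candidate)
    return {
        # Largest element-wise deviation between the two outputs.
        "max_abs_err": float(abs_err.max()),
        # Whether every element is within the chosen tolerances.
        "allclose": bool(np.allclose(candidate, reference, rtol=rtol, atol=atol)),
        # Fraction of positions where both models pick the same top token.
        "argmax_agreement": float(
            np.mean(reference.argmax(-1) == candidate.argmax(-1))
        ),
    }

# Synthetic stand-ins for (batch, seq, vocab) logits from the two frameworks.
rng = np.random.default_rng(0)
ref_logits = rng.normal(size=(2, 5, 32))
port_logits = ref_logits + rng.normal(scale=1e-5, size=ref_logits.shape)

report = compare_outputs(ref_logits, port_logits)
print(report)
```

Running the same comparison layer by layer, rather than only on final logits, tends to make it easier to see where a drift first appears.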
When you say "model versions", do you mean different quantizations of the model? Then it's not floating point errors that accumulate. Different quantizations of the model are different models. People will call such a model something like Meta-Llama-3.1-8B-Instruct--q4_0, claiming that it's just a "version" of the Meta-Llama-3.1-8B-Instruct. But it's just a lie. It's not the same model, and you should not expect the same results. There is no reason to debug the differences. What exactly would you expect to find, and what action would you take once you find what you are looking for? However, is the quantized version still a useful LLM? Absolutely. Most people don't have an A100 to run the original model, so a quantized version is better than nothing.
Very fascinating, can you explain more about a time when this happened?
Like what area was affected by fp errors, why were they introduced (was it like refactoring of pytorch code?), how was this determined to be the cause?
Does JAX have its own implementations of matmul, flash attention, etc.? Or does it use the ROCm implementations like PyTorch does (e.g., hipBLASLt, Composable Kernel FA, etc.)?
Not too familiar with JAX, but the abysmal PyTorch training perf on MI300x is in large part attributable to the slow perf of the ROCm libraries it is using under the hood.
JAX has a sub-system called Pallas[1] with a Triton-like programming model and an example implementation of Flash Attention [2]. It is quite fast. On TPUs I've heard that the XLA compiler already emits a flash-attention-like computation graph for a regular JAX implementation of attention so there's no need to have some specialized kernel in that case.
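For context, a "regular JAX implementation of attention" looks roughly like the sketch below: plain matmuls and a softmax with no hand-written kernel. This is not a Pallas kernel, and whether XLA turns it into a flash-attention-like schedule on a given backend is the commenter's claim, not something this example demonstrates.

```python
import jax
import jax.numpy as jnp

def attention(q, k, v):
    # q, k, v: (seq_len, head_dim) for a single head, no masking.
    scale = 1.0 / jnp.sqrt(q.shape[-1])
    scores = (q @ k.T) * scale                 # (seq_len, seq_len)
    weights = jax.nn.softmax(scores, axis=-1)  # each row sums to 1
    return weights @ v                         # (seq_len, head_dim)

key = jax.random.PRNGKey(0)
kq, kk, kv = jax.random.split(key, 3)
q = jax.random.normal(kq, (8, 16))
k = jax.random.normal(kk, (8, 16))
v = jax.random.normal(kv, (8, 16))

# Under jit, XLA sees the whole computation and is free to fuse it.
out = jax.jit(attention)(q, k, v)
print(out.shape)
```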
A couple months ago I did some testing on some consumer cards. [1] I think you should be able to use torchtune or axolotl without anything besides installing the ROCm version of PyTorch.
Yeah, those numbers are correct as of their testing (in June), although people who are really interested should check out the linked repo and do their own runs, as software/optimizations have continued to change a lot and the RDNA3 side has a lot of untapped potential. E.g., the 7900 XTX has a huge theoretical FLOPS advantage over the 3090, but the results totally don't reflect that. One example of this hobbling is that RDNA3 only recently got backward-pass FA via a still under-optimized aotriton implementation: https://github.com/ROCm/aotriton/pull/39
There are also still ongoing optimizations on the Nvidia side as well. In the beginning of the year the 7900 XTX and 3090 were pretty close on llama.cpp inference performance, but a few months ago llama.cpp got CUDA graph and FA support implemented that boosted perf significantly for both my 3090 and 4090.
(For AI/ML, a used 3090 remains I think the best bang/buck for both inference and small training runs. You can pay twice as much for the twice as fast 4090, but at the end of the day you'll still wish you had more VRAM, so it's hard to really recommend unless you're going to use mixed precision. The RDNA3 cards are not as bad to work with as the Internet would have you believe, but they'd have to be a lot cheaper if your main use case was AI/ML for both the PITA factor and just from pure real-world performance.)
i've been running inference on the 7900xtx using pytorch and rocm (installed directly from package managers, no manual fiddling) with great performance. no problem running the full flux1.dev model, for instance. haven't looked at training or fine-tuning yet.
Given it is a migration, is there an actual comparison of the same model on PyTorch vs. your version? The comparison table there seems to be on the technical side.
We have a few technical issues that we still need to address:
1) This entire fine-tuning run was done in JAX eager mode. I kept running out of memory (OOM) when trying to `jax.jit` the entire training step. Even gradual `jax.jit` didn't work.
2) The current version doesn't have gradient accumulation, and with a batch size of just 16, that’s not ideal. I'm working on implementing gradient accumulation next.
3) We still haven't found a good way to load large sequence-length data (like 32k sequence length). Currently, before sharding the training batch across GPUs, it ends up loading the entire batch onto a single GPU’s VRAM and causes OOM issues.
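As a sketch of what item 2 could look like (an assumption about one possible approach, not the planned Felafax implementation), gradient accumulation in JAX can be written as a `jax.lax.scan` over micro-batches that sums gradients and averages them at the end. The toy `loss_fn` and the helper name `accumulated_grads` are hypothetical.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Toy linear model standing in for the real network (hypothetical).
    pred = x @ params["w"]
    return jnp.mean((pred - y) ** 2)

def accumulated_grads(params, x, y, num_micro_batches):
    # Split the batch into equal micro-batches along the leading axis.
    xs = x.reshape(num_micro_batches, -1, *x.shape[1:])
    ys = y.reshape(num_micro_batches, -1, *y.shape[1:])
    zero = jax.tree_util.tree_map(jnp.zeros_like, params)

    def body(acc, micro):
        mx, my = micro
        g = jax.grad(loss_fn)(params, mx, my)
        # Only one micro-batch of activations is live at a time.
        return jax.tree_util.tree_map(jnp.add, acc, g), None

    total, _ = jax.lax.scan(body, zero, (xs, ys))
    # Average so the result matches the gradient of the full-batch mean loss.
    return jax.tree_util.tree_map(lambda g: g / num_micro_batches, total)

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (16, 4))
y = jnp.sum(x, axis=1, keepdims=True)
params = {"w": jnp.full((4, 1), 0.5)}

acc_grads = jax.jit(accumulated_grads, static_argnums=3)(params, x, y, 4)
full_grads = jax.grad(loss_fn)(params, x, y)
```

With equal-sized micro-batches and a mean loss, the accumulated gradient should match the full-batch gradient up to floating-point error, so the effective batch size can grow without a matching growth in activation memory.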
I'm glad to see a full implementation on AMD hardware.
I'm not familiar with JAX, but the idea of providing an abstraction layer to more easily get to work on what hardware is available seems really valuable. Bringing back some competitiveness to the ecosystem will be a big win for workload mobility.
I suspect that price/performance across implementations will be highly dependent on contract details, but do you intend to publish some comparisons in the future?
Any direct comparisons to 8xH100? 2 toks/sec seems very slow!
I haven't done any LoRA training on MI300x myself, but I have done Llama 3.1 full training on 8xMI300x and got pretty close to 8xH100 performance with my own kernels (ROCm is just too slow).
Oops, my calculation was wrong. Let me add an edit to the blog, thanks for pointing it out!
My train step was taking 30s.
And I was using a batch size of 16 and a sequence length of 64, which makes the training speed (16*64/30) ≈ 34 tokens per second (for fine-tuning in JAX eager mode).
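The throughput figure can be rechecked in a few lines, using only the numbers quoted in this comment (batch size 16, sequence length 64, 30 s per step):

```python
batch_size = 16
seq_len = 64
step_seconds = 30.0

tokens_per_step = batch_size * seq_len            # tokens processed per train step
tokens_per_second = tokens_per_step / step_seconds

print(f"{tokens_per_step} tokens/step, {tokens_per_second:.1f} tokens/s")
```

This comes out at roughly 34 tokens per second across all 8 GPUs.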
Note that I’m ignoring the attention FLOPs in this simplified calculation, but they would be a second-order effect at this sequence length.
Also note that I’m assuming full weight training, not LoRA. The result would be lower MFU if using LoRA.
These MI300X results are promising functionally (it's tough to get any model this big running) but they have a long way to go on perf. It's also single node. The biggest issues I've seen on MI300X are related to scaling to multiple nodes.
EDIT: The blog seems to indicate it is using LoRA. So we should remove the backward param pass from the equation above. Backward param only applies to adaptor weights, which are much more than 10x smaller, so we set it to 0 in the approximation. So we get an MFU of roughly 0.53%.
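The MFU estimates discussed in this comment can be reproduced with the sketch below. The inputs are the figures quoted elsewhere in the thread (405B parameters, 2 FLOPs per parameter per matmul pass, batch size 16, sequence length 64, 30 s per step, 8 GPUs at 1.3 PFLOP/s bf16 peak each); the helper name `mfu` and the `matmul_passes` parameter are made up for illustration, and attention FLOPs are ignored as in the original estimate.

```python
def mfu(params, batch_size, seq_len, step_seconds,
        num_gpus, peak_flops_per_gpu, matmul_passes):
    """Model FLOPs utilization: achieved FLOP/s divided by peak FLOP/s."""
    # 2 FLOPs (multiply + add) per parameter per token per matmul pass.
    flops_per_step = params * 2 * matmul_passes * batch_size * seq_len
    achieved_flops_per_s = flops_per_step / step_seconds
    return achieved_flops_per_s / (num_gpus * peak_flops_per_gpu)

common = dict(
    params=405e9,
    batch_size=16,
    seq_len=64,
    step_seconds=30.0,
    num_gpus=8,
    peak_flops_per_gpu=1.3e15,
)

# Full-weight training: forward + backward-activation + backward-parameter.
full = mfu(**common, matmul_passes=3)
# LoRA approximation: backward-parameter pass on base weights treated as 0.
lora = mfu(**common, matmul_passes=2)

print(f"full-weight MFU ~ {full:.2%}, LoRA MFU ~ {lora:.2%}")
```

Under these assumptions the full-weight estimate is about 0.8% and the LoRA estimate about 0.53%, which corresponds to roughly 7 to 10 TFLOP/s per GPU.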
We've significantly optimized multinode on AMD MI300X for both stability and performance at TensorWave. There were certainly a lot of challenges, but we've become experts at multinode on AMD. We'd be happy to show you the latest results! They are quite compelling.
Let's break down the results described in the post.
Context: The post is discussing the performance of a large language model on a MI300X GPU, which is a high-performance computing (HPC) system. The model has approximately 405 billion parameters and is trained using a batch size of 16 and sequence length of 64.
Key metrics:
MFU (Million Floating-Point Operations per Second): This is a measure of the model's performance, specifically the number of floating-point operations (FLOPs) it can perform per second.
FLOPs: The number of floating-point operations required to perform a matrix multiplication, which is a fundamental operation in deep learning.
GPU performance: The MI300X GPU is capable of 1.3 petaflops (1.3 x 10^15 FLOPs) per second in bfloat16 (a 16-bit floating-point format).
Calculations:
The author provides two calculations to estimate the MFU of the model:
Initial calculation: Assuming full weight training (not LoRA), the author estimates the MFU as:
405 billion parameters
2 FLOPs per matrix multiply per parameter
3 matrix multiplies (forward, backward parameter, and backward activation)
Batch size 16
Sequence length 64
30 seconds to complete the calculation
1.3 petaflops per second per GPU
8 GPUs
The calculation yields an MFU of approximately 0.8%.
Revised calculation: After correcting the assumption to use LoRA (a technique that reduces the number of FLOPs), the author revises the calculation by removing the backward parameter pass, which is only applied to adaptor weights (much smaller than the main weights). This yields an MFU of approximately 0.53%.
Interpretation:
The results indicate that the MI300X GPU is not yet optimized for this large language model, with an MFU of only 0.53% (or 0.8% in the initial calculation). This is a relatively low performance compared to the theoretical maximum of 1.3 petaflops per second per GPU. The author notes that the biggest issues are related to scaling to multiple nodes, suggesting that the performance may improve when running on a larger cluster.
The revised calculation using LoRA reduces the MFU by about 33%, indicating that using this technique can lead to a more efficient use of the GPU resources.
MFU means model FLOPs utilization. It is a measure of efficiency from 0% to 100%. 100% means that the model is running at maximum possible efficiency, i.e. 1.3 petaflops per GPU.
In that case, the results indicate that the MI300X GPU is running the large language model at a relatively low efficiency, with an MFU of 0.53% (or 0.8% in the initial calculation).
This means that the model is only utilizing a tiny fraction of the GPU's maximum theoretical performance of 1.3 petaflops per second. In other words, the model is not fully utilizing the GPU's capabilities, and there is a significant amount of headroom for optimization.
To put this into perspective, an MFU of 100% would mean that the model is running at the maximum possible efficiency, using 1.3 petaflops per second per GPU. An MFU of 0.53% or 0.8% is extremely low, indicating that the model is running at a tiny fraction of its potential performance.
The author's comment that the MI300X results are "promising functionally" suggests that the model is able to run, but the low MFU indicates that there are significant opportunities for optimization and performance improvement.
> The blog seems to indicate it is using LoRA. So we should remove the backward param pass from the equation above. Backward param only applies to adaptor weights
The backward pass still runs on the non-adapter weights. But yeah, 10 TFLOPs/GPU, especially at a tiny sequence size, is very bad compared to what you can get on Nvidia. And I believe the difference would be even higher with a large sequence length.
Did you consider using https://github.com/AI-Hypercomputer/maxtext ? It has a Jax llama implementation, and gets decent MFU on TPU and GPU (I've only tried it on NVidia GPU, not AMD).
Thank you for the detailed analysis. We need to spend some time thinking and coming up with a price comparison like this. We’ll use this as inspiration!
Unsloth is great! They focus on single-GPU and LoRA fine-tuning on NVIDIA GPUs. We are initially trying to target multi-node, multi-TPU, full-precision training use cases.
That said, in terms of single-GPU speed, we believe we would be behind but not too far off, thanks to JAX+TPU's more performant stack. Additionally, we can do larger-scale multi-node training on TPUs.
There are still more optimizations we need to do for Llama 3.1, such as adding Pallas memory attention kernels, etc.