We fine-tuned Llama 405B on AMD GPUs (publish.obsidian.md)
495 points by felarof 66 days ago | 101 comments



Hey HN, we recently fine-tuned the llama3.1 405B model on 8xAMD MI300x GPUs using JAX instead of PyTorch. JAX's advanced sharding APIs allowed us to achieve great performance. Check out our blog post to learn about the cool sharding tricks we used. We've also open-sourced the code: https://github.com/felafax/felafax

We're a small startup building AI infra for fine-tuning and serving LLMs on non-NVIDIA hardware (TPUs, AMD, Trainium).

Problem: Many companies are trying to get PyTorch working on AMD GPUs, but we believe this is a treacherous path. PyTorch is deeply intertwined with the NVIDIA ecosystem in a lot of ways (e.g., `torch.cuda` or scaled_dot_product_attention is an NVIDIA CUDA kernel exposed as a PyTorch function). So, to get PyTorch code running on non-NVIDIA hardware, there's a lot of "de-NVIDIAfying" that needs to be done.

Solution: We believe JAX is a better fit for non-NVIDIA hardware. In JAX, ML model code compiles to hardware-independent HLO graphs, which are then optimized by the XLA compiler before hardware-specific optimization. This clean separation allowed us to run the same LLaMA3 JAX code both on Google TPUs and AMD GPUs with no changes.

Our strategy as a company is to invest upfront in porting models to JAX, then leverage its framework and XLA kernels to extract maximum performance from non-NVIDIA backends. This is why we first ported Llama 3.1 from PyTorch to JAX, and now the same JAX model works great on TPUs and runs perfectly on AMD GPUs.

We'd love to hear your thoughts on our vision and repo!


I, and several others, had no problem running PyTorch on AMD GPUs, with no code changes from CUDA. Check out MosaicML's blog posts: https://www.databricks.com/blog/training-llms-scale-amd-mi25...


Again, the problem is custom kernels in CUDA. It’s not straightforward for many applications (LLMs are probably the most straightforward).


Ahh, interesting, will take a look!

Curious what are the steps to run PyTorch on AMD (does it work out-of-box with PyTorch+rocm docker image)? Does torch.compile work smoothly?


While your project is neat and I'd like to see how the performance compares, for LLM training, PyTorch, including torch.compile, works completely OOTB on AMD.

All you have to do is pip install the ROCm version of PyTorch (or run the docker image) and it's seamless (the ROCm version just treats torch.cuda as calling ROCm).

I've used axolotl (trl/accelerate based), torchtune, and LLaMA-Factory, which are all PyTorch-based without any issues for training.


Yeah I would suggest taking a look at PyTorch on AMD before saying stuff like "scaled_dot_product_attention is an NVIDIA CUDA kernel exposed as a PyTorch function", because that is demonstrably false.

Also, FWIW, I would suggest getting a small Llama 3.1 model training fast before trying to do a big 405B model -- faster to iterate and almost everything you'll learn on the small models will scale to the 405B.


Thanks for the feedback! I appreciate you pointing that out. My understanding was based on the PyTorch documentation for scaled_dot_product_attention (https://pytorch.org/docs/stable/generated/torch.nn.functiona...): "The function may call optimized kernels for improved performance when using the CUDA backend. For all other backends, the PyTorch implementation will be used."

And was trying to make a broader point about the lack of transparency (in performance, lower-level impl) in PyTorch when running on NVIDIA vs. non-NVIDIA hardware.


> And was trying to make a broader point about the lack of transparency (in performance, lower-level impl) in PyTorch when running on NVIDIA vs. non-NVIDIA hardware.

I don't quite understand this argument. Lack of transparency from running PyTorch so instead we're gonna leave it all to XLA? How does this solve the "transparency" issue?


Having a common library function that is either lightning fast or dog slow depending on the hardware is not a great position to be in.

Moreover, this will get worse as more CUDA specific features are added to PyTorch with ad-hoc fallback functions.

I guess OP is saying that XLA is more transparent in this regard, because it wouldn’t use functions like these and the generated comparable code would be on par performance wise?


> it wouldn’t use functions like these and the generated comparable code would be on par performance wise

Perhaps if XLA generated all functions from scratch, this would be more compelling. But XLA relies very heavily on pattern-matching to common library functions (e.g. CuDNN), and these patterns will certainly work better on Nvidia GPUs than AMD GPUs.

In this way, I think explicitly calling the common library functions is actually much more transparent.


[flagged]


are you at all confident that this isn't hallucinated? I'd never trust an answer like this from an LLM


Did you verify everything else it said is true?


How are you verifying accuracy for your JAX port of Llama 3.1?

IMHO, the main reason to use pytorch is actually that the original model used pytorch. What can seem to be identical logic between different model versions may actually cause model drift when infinitesimal floating point errors accumulate due to the huge scale of the data. My experience is that debugging accuracy mismatches like this in a big model is a torturous ordeal beyond the 10th circle of hell.
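
As a rough illustration of the accumulation point (a minimal plain-Python sketch, not code from either framework): floating-point addition is not associative, so two implementations that reduce the same values in a different order can disagree in the last bits. The variable names here are made up for the example.

```python
import math

# Floating-point addition is not associative: grouping changes the rounding.
left_to_right = (0.1 + 0.2) + 0.3
right_to_left = 0.1 + (0.2 + 0.3)

# Mathematically equal, but the two groupings round differently.
order_changes_result = left_to_right != right_to_left

# A port is therefore better compared with a tolerance than bit-for-bit.
close_enough = math.isclose(left_to_right, right_to_left, rel_tol=1e-9)

print(order_changes_result, close_enough)
```

In a real model the same effect shows up whenever a kernel or compiler picks a different reduction order, which is why a tolerance-based comparison is usually the practical check.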


Good question. We used a new AI+math-based testing tool (benchify.com) to run comparison tests, but we are working on building more robust infrastructure for this. Translating models from PyTorch to JAX is core to our strategy.

That said, this path is not uncommon (translating from one framework to another). HuggingFace translates Google's Gemma family models from JAX to PyTorch, and a ton of people use it.


When you say "model versions", do you mean different quantizations of the model? Then it's not floating point errors that accumulate. Different quantizations of the model are different models. People will call such a model something like Meta-Llama-3.1-8B-Instruct--q4_0, claiming that it's just a "version" of the Meta-Llama-3.1-8B-Instruct. But it's just a lie. It's not the same model, and you should not expect the same results. There is no reason to debug the differences, what exactly would you expect to find, and what action would you envision to take once you find what you are looking for? However, is the quantized version still a useful LLM? Absolutely. Most people don't have an A100 to run the original model, so a quantized version is better than nothing.


Very fascinating, can you explain more about a time when this happened?

Like what area was affected by fp errors, why were they introduced (was it like refactoring of pytorch code?), how was this determined to be the cause?


Does JAX have its own implementations of matmul, flash attention etc? Or does it use the ROCm implementations like PyTorch does? (e.g., hipblaslt, Composable Kernel FA etc)

Not too familiar with JAX, but the abysmal PyTorch training perf on MI300x is in large part attributable to the slow perf of the ROCm libraries it is using under the hood.


JAX has a sub-system called Pallas[1] with a Triton-like programming model and an example implementation of Flash Attention [2]. It is quite fast. On TPUs I've heard that the XLA compiler already emits a flash-attention-like computation graph for a regular JAX implementation of attention so there's no need to have some specialized kernel in that case.

1. https://jax.readthedocs.io/en/latest/pallas/index.html

2. https://github.com/jax-ml/jax/blob/main/jax/experimental/pal...


Does this work on the consumer grade cards like the 7900 XTX?

And by work I don't mean: spend two weeks trying to get the drivers set up and never update the server again.


A couple months ago I did some testing on some consumer cards. [1] I think you should be able to use torchtune or axolotl without anything besides installing the ROCm version of PyTorch.

[1] https://wandb.ai/augmxnt/train-bench/reports/Trainer-perform...


Am I reading that right that the 7900 XTX is on a par with 3090, and 4090 is twice as fast?


Yeah, those numbers are correct as of their testing (in June) although people who are really interested should check out the linked repo and do their own runs as software/optimizations have continued to change a lot and the RDNA3 side has a lot of untapped potential. Eg, the 7900 XTX has a huge theoretical FLOPS advantage over the 3090 but the results totally don't reflect that. One example of this hobbling is that RDNA3 only recently got backpass FA via a still under-optimized aotriton implementation: https://github.com/ROCm/aotriton/pull/39

There are also still ongoing optimizations on the Nvidia side as well. In the beginning of the year the 7900 XTX and 3090 were pretty close on llama.cpp inference performance, but a few months ago llama.cpp got CUDA graph and FA support implemented that boosted perf significantly for both my 3090 and 4090.

(For AI/ML, a used 3090 remains I think the best bang/buck for both inference and small training runs. You can pay twice as much for the twice as fast 4090, but at the end of the day you'll still wish you had more VRAM, so it's hard to really recommend unless you're going to use mixed precision. The RDNA3 cards are not as bad to work with as the Internet would have you believe, but they'd have to be a lot cheaper if your main use case was AI/ML for both the PITA factor and just from pure real-world performance.)


Damn, ML/AI performance is that different? In games, 4090 -> 7900 XTX is more like -20%


i've been running inference on the 7900xtx using pytorch and rocm (installed directly from package managers, no manual fiddling) with great performance. no problem running the full flux1.dev model, for instance. haven't looked at training or fine-tuning yet.


Given it is a migration, is there an actual comparison of the same model on PyTorch vs your version? The comparison table there seems to be on the technical side.

Also any technical issues encountered?


We have a few technical issues that we still need to address:

1) This entire fine-tuning run was done in JAX eager mode. I kept running out of memory (OOM) when trying to `jax.jit` the entire training step. Even gradual `jax.jit` didn't work.

2) The current version doesn't have gradient accumulation, and with a batch size of just 16, that’s not ideal. I'm working on implementing gradient accumulation next.

3) We still haven't found a good way to load large sequence-length data (like 32k sequence length). Currently, before sharding the training batch across GPUs, it ends up loading the entire batch onto a single GPU’s VRAM and causes OOM issues.
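
For readers unfamiliar with item 2, here is a toy sketch of gradient accumulation in plain Python. It is not the Felafax implementation; the one-parameter model, the data, and every name in it are hypothetical. It only shows that averaging gradients over equal-sized micro-batches reproduces the gradient of the larger batch, which is what lets a small per-step batch behave like a bigger one.

```python
# Toy gradient accumulation for a 1-parameter model y = w * x with MSE loss.
# All names, data, and sizes are hypothetical and chosen for illustration.

def mse_grad(w, xs, ys):
    """Gradient of mean((w*x - y)^2) with respect to w over one batch."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

w = 0.5
xs = [float(i) for i in range(1, 17)]   # one "full" batch of 16 examples
ys = [3.0 * x for x in xs]

full_batch_grad = mse_grad(w, xs, ys)

# Split into 4 equal micro-batches, accumulate their gradients,
# and average before a single optimizer step would be applied.
micro_batch_size = 4
num_micro_batches = len(xs) // micro_batch_size
accumulated = 0.0
for start in range(0, len(xs), micro_batch_size):
    stop = start + micro_batch_size
    accumulated += mse_grad(w, xs[start:stop], ys[start:stop])
accumulated_grad = accumulated / num_micro_batches

print(full_batch_grad, accumulated_grad)
```

The equivalence relies on the micro-batches being the same size; with uneven sizes the accumulated gradients would need to be weighted by example count.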


> I kept running out of memory (OOM) when trying to `jax.jit` the entire training step. Even gradual `jax.jit` didn't work.

Were you using activation checkpointing? https://jax.readthedocs.io/en/latest/_autosummary/jax.checkp... is very important for keeping memory usage reasonable when training large models.


I'm glad to see a full implementation on AMD hardware.

I'm not familiar with JAX, but the idea of providing an abstraction layer to more easily get to work on what hardware is available seems really valuable. Bringing back some competitiveness to the ecosystem will be a big win for workload mobility.

I suspect that price/performance across implementations will be highly dependent on contract details, but do you intend to publish some comparisons in the future?


Any direct comparisons to 8xH100? 2 toks/sec seems very slow!

I haven't done any LoRA training on MI300x myself, but I have done Llama 3.1 full training on 8xMI300x and got pretty close to 8xH100 performance with my own kernels (ROCm is just too slow).


Oops, my calculation was wrong. Let me add an edit to the blog, thanks for pointing it out!

My train step was taking 30s.

And I was using a batch size of 16 and seq length of 64, making the training speed as (16*64/30) tokens per sec == 35 tokens per second (for fine-tuning in JAX eager mode).
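
The arithmetic can be checked with a few lines of Python using only the figures quoted above (the variable names are just for illustration). Computed exactly it comes to a little over 34 tokens per second, in line with the rounded figure.

```python
# Throughput from the figures quoted in this comment.
batch_size = 16
seq_len = 64
step_seconds = 30.0

tokens_per_step = batch_size * seq_len
tokens_per_second = tokens_per_step / step_seconds

print(f"{tokens_per_step} tokens/step, {tokens_per_second:.1f} tokens/s")
```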

(I haven't done comparison with 8XH100)


That’s approximately 0.8% MFU. An H100 would get more like 30% or 40% MFU if well tuned.

405e9 parameters

2 flops per matrix multiply per parameter

3 matrix multiplies for (forward, backward param, and backward activation) passes

batch size 16

seq length 64

1.3 petaflops per second per GPU in bfloat16

8 GPUs

30 seconds

So that’s 0.8% = (405e9 * 2 * 3 * 16 * 64 / 30) / (1.3e15 * 8)

Note that I’m ignoring the attention flops in this simplified calculation, but they would be a second order effect at this sequence length

Also note that I’m assuming full weight training, not LoRA. The result would be lower MFU if using LoRA.

These MI300X results are promising functionally (it's tough to get any model this big running) but they have a long way to go on perf. It's also single node. The biggest issues I've seen on MI300X are related to scaling to multiple nodes.

EDIT: The blog seems to indicate it is using LoRA. So we should remove the backward param pass from the equation above. Backward param only applies to adaptor weights, which are much more than 10x smaller, so we set it to 0 in the approximation. So we get

0.53% = (405e9 * 2 * 2 * 16 * 64 / 30) / (1.3e15 * 8)
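
Both estimates can be reproduced with a short Python sketch of the same simplified formula. The function and parameter names are mine, and like the calculation above it ignores attention FLOPs and treats the adapter-weight backward pass as zero for LoRA.

```python
def mfu(params, matmuls_per_param, batch_size, seq_len, step_seconds,
        peak_flops_per_gpu, num_gpus, flops_per_matmul=2):
    """Fraction of peak FLOP/s achieved, ignoring attention FLOPs."""
    flops_per_step = (params * flops_per_matmul * matmuls_per_param
                      * batch_size * seq_len)
    achieved_flops_per_second = flops_per_step / step_seconds
    return achieved_flops_per_second / (peak_flops_per_gpu * num_gpus)

# Figures taken from the comment above.
common = dict(params=405e9, batch_size=16, seq_len=64, step_seconds=30,
              peak_flops_per_gpu=1.3e15, num_gpus=8)

# Full training: forward + backward-param + backward-activation matmuls.
full_training = mfu(matmuls_per_param=3, **common)
# LoRA approximation: drop the backward-param pass for the frozen weights.
lora = mfu(matmuls_per_param=2, **common)

print(f"full: {full_training:.2%}, LoRA: {lora:.2%}")
```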


We've significantly optimized multinode on AMD MI300X for both stability and performance at TensorWave. There were certainly a lot of challenges, but we've become experts at multinode on AMD. We'd be happy to show you the latest results! They are quite compelling


Promises... promises... https://x.com/DarrickHorton/status/1828902979875017003

Maybe challenging for Darrick, but in reality, you just need to follow the directions and configure things correctly...

https://techdocs.broadcom.com/us/en/storage-and-ethernet-con...


If I feed your reply to an llm will it be able to tell me what you’re talking about?

/edit I think it did alright…


I tried with Llama 3.1 and it hallucinated MFU. It also completely misinterpreted multi-node scaling.

==========================================================

Let's break down the results described in the post.

Context: The post is discussing the performance of a large language model on a MI300X GPU, which is a high-performance computing (HPC) system. The model has approximately 405 billion parameters and is trained using a batch size of 16 and sequence length of 64.

Key metrics:

MFU (Million Floating-Point Operations per Second): This is a measure of the model's performance, specifically the number of floating-point operations (FLOPs) it can perform per second.

FLOPs: The number of floating-point operations required to perform a matrix multiplication, which is a fundamental operation in deep learning.

GPU performance: The MI300X GPU is capable of 1.3 petaflops (1.3 x 10^15 FLOPs) per second in bfloat16 (a 16-bit floating-point format).

Calculations:

The author provides two calculations to estimate the MFU of the model:

Initial calculation: Assuming full weight training (not LoRA), the author estimates the MFU as:

405 billion parameters

2 FLOPs per matrix multiply per parameter

3 matrix multiplies (forward, backward parameter, and backward activation)

Batch size 16

Sequence length 64

30 seconds to complete the calculation

1.3 petaflops per second per GPU

8 GPUs

The calculation yields an MFU of approximately 0.8%.

Revised calculation: After correcting the assumption to use LoRA (a technique that reduces the number of FLOPs), the author revises the calculation by removing the backward parameter pass, which is only applied to adaptor weights (much smaller than the main weights). This yields an MFU of approximately 0.53%.

Interpretation:

The results indicate that the MI300X GPU is not yet optimized for this large language model, with an MFU of only 0.53% (or 0.8% in the initial calculation). This is a relatively low performance compared to the theoretical maximum of 1.3 petaflops per second per GPU. The author notes that the biggest issues are related to scaling to multiple nodes, suggesting that the performance may improve when running on a larger cluster.

The revised calculation using LoRA reduces the MFU by about 33%, indicating that using this technique can lead to a more efficient use of the GPU resources.

==========================================================

MFU means model flops utilization. It is a measure of efficiency from 0% to 100%. 100% means that the model is running at maximum possible efficiency, i.e. 1.3 petaflops per gpu.

==========================================================

In that case, the results indicate that the MI300X GPU is running the large language model at a relatively low efficiency, with an MFU of 0.53% (or 0.8% in the initial calculation).

This means that the model is only utilizing a tiny fraction of the GPU's maximum theoretical performance of 1.3 petaflops per second. In other words, the model is not fully utilizing the GPU's capabilities, and there is a significant amount of headroom for optimization.

To put this into perspective, an MFU of 100% would mean that the model is running at the maximum possible efficiency, using 1.3 petaflops per second per GPU. An MFU of 0.53% or 0.8% is extremely low, indicating that the model is running at a tiny fraction of its potential performance.

The author's comment that the MI300X results are "promising functionally" suggests that the model is able to run, but the low MFU indicates that there are significant opportunities for optimization and performance improvement.


MFU being maximal floating point utilization right? I used chatgpt 4o and it otherwise looks very similar to your llama result


> The blog seems to indicate it is using LoRA. So we should remove the backward param pass from the equation above. Backward param only applies to adaptor weights

Backward pass still runs on the non adapter weights. But yeah 10 TFlops/GPU, especially on tiny sequence size, is very bad compared to what you can get on Nvidia. And I believe the difference would be even higher with large sequence length.


Backward activations do, but typically not backward weight gradients.

Why compute gradients with regards to weights that aren't going to be updated?


Is there some cost rule of thumb to compare Nvidia, AMD, and Google TPU?


Good question. There's no good metric, given that performance depends on the software stack (JAX vs PyTorch) + optimizations.

But my take: performance per dollar is TPU > AMD > NVIDIA.


TPUs are slow but efficient and AMD has bugs but for some things works quite well. Nvidia is obviously the gold standard.


Haha, TPUs are not slow :) All of Google's training (including Gemini models) is done on TPUs.

There are good 1p [a] and 3p [b] benchmarks comparing TPUs vs NVIDIA GPUs.

[a] - https://github.com/GoogleCloudPlatform/vertex-ai-samples/blo...

[b] - https://arxiv.org/pdf/2309.07181


Did you consider using https://github.com/AI-Hypercomputer/maxtext ? It has a Jax llama implementation, and gets decent MFU on TPU and GPU (I've only tried it on NVidia GPU, not AMD).


Could you share performance so we could compare?


Do you see tinygrad as a useful lower level abstraction or is JAX sufficient to get perf out of AMD GPUs?


Tinygrad is great, but still in early stages I believe.

JAX has matured a lot over the last 6 years, and XLA has been around for a lot longer. We believe we can extract good perf from AMD with JAX + XLA kernels.


scaled_dot_product_attention isn’t CUDA specific, it even works on TPUs.


To be clear, this performance is quite bad (presumably because you didn't manage to get compilation working).

You're getting 35 tokens/s for a 405B model, which comes out to about 85 Teraflops. 8 MI300x GPUs comes out to 10.4 Petaflops, so you're getting about 0.8% MFU (which is about 40-50x worse than decent training performance of 30-40% MFU).

For AMD's sake, I hope that it's your software stack that's limiting perf.


That's exactly what I wanted to ask:

Their github page claims that it is possible to "tune LLaMa3.1 on Google Cloud TPUs for 30% lower cost", but they don't mention performance.


Firstly great work! I dabbled with AMD GPUs and ROCm support a year ago, and it was obvious AMD was still a long way from catching up with Nvidia. While opting for JAX is an interesting approach, what were the challenges for you deviating from pytorch (being the standard library for ML)?


A few weeks ago, I did a Show HN explaining our journey: https://news.ycombinator.com/item?id=41512142.

We initially started with the goal of fine-tuning LLaMA 3 on TPUs, but PyTorch XLA was clunky, so we decided to rewrite the model in JAX. That said, as mentioned earlier in the thread, we also believe JAX is a better platform for non-NVIDIA GPUs and want to build on JAX+openXLA for building infra for non-NVIDIA GPUs.


I cannot get AMD ROCm running on my debian 12 system which is what I think is causing Ollama to use CPU instead of GPU. So I guess there is still a long way to go.


At the risk of pissing people off, I think you may be better served by a distribution that provides a more up-to-date kernel. Debian 12 will give you Linux 6.1 LTS, which is probably OK if you're using an older Radeon card, but I've heard support for the 7900 XT/X series is a bit dicey and beyond that (e.g. Radeon 890M) non-existent.

If there were improvements on the AMDGPU DRM driver side, you would not see them in Debian any time soon, as the 6.1 LTS kernel will be stuck with roughly whatever shipped January of last year. This is just a shortcoming in the Linux kernel, due to its lack of any kind of stable ABI for drivers.

Of course it is possible this would help nothing or even hurt. My experience running stable (or even newer) kernels has been quite good, though. I run stable or newer across a few devices and run into hiccups not more than once every few years, which is definitely worth it to be able to get new driver improvements years in advance.

(FWIW Debian is not even supported by ROCm[1]... although distros with even older kernels are. But, even if ROCm works, I can't imagine you will get ideal hardware support when running older kernels. I am not sure if ROCm has some workaround for enterprise Linux distributions specifically, but it feels like they must, given how many of their customers in the datacenter are likely to want to use them.)

[1]: https://rocm.docs.amd.com/en/latest/compatibility/compatibil...


> I've heard support for the 7900 XT/X series is a bit dicey

The firmware-amd-graphics package in stable is too old to properly support RDNA 3. It kind of works, but it is quite buggy. All RDNA 3 users on Debian 12 should be sure to install the kernel and firmware from bookworm-backports.

There is full support for RDNA 3 hardware enabled on Debian Testing (both in the drivers and runtime libraries). The Debian ROCm Team intended to backport all the ROCm packages from Testing into Bookworm, but have been held up as LLVM 17 is not available in bookworm-backports (yet?).

> FWIW Debian is not even supported by ROCm

ROCm does not support Debian, but Debian supports ROCm. Most of the libraries that comprise ROCm have been directly packaged by the distribution.


You'd probably have a lot better luck using Vulkan acceleration (not ROCm) of llama.cpp as backend to ollama. It is incomparably easier to set up and maintain compared to ROCm. You can actually do it on your computer's normal OS instead of inside a bunch of container/vms where the system libs are entirely customized to running just that one application.

AMD's support of consumer cards is very, very short. By the time it's stable enough for a new card to run the card is no longer supported. In 2021 I bought an AMD GPU that came out 3 years before and 1 year after I bought it (4 years since release) they dropped ROCm support.


ROCm is not even worth the effort for inference workloads. Vulkan is much more convenient and performs fine.

llama.cpp and stable-diffusion.cpp offer Vulkan backends but generally you can run most models on Vulkan if you use IREE[1].

[1] <https://iree.dev/guides/ml-frameworks/>


While Vulkan can be a good fallback, for LLM inference at least, the performance difference is not as insignificant as you believe. I just ran a test on the latest pull just to make sure this is still the case on llama.cpp HEAD, but text generation is +44% faster and prompt processing is +202% (~3X) faster with ROCm vs Vulkan.

Note: if you're building llama.cpp, all you have to do is swap GGML_HIPBLAS=1 and GGML_VULKAN=1 so the extra effort is just installing ROCm? (vs the Vulkan devtools)

ROCm:

  CUDA_VISIBLE_DEVICES=1 ./llama-bench -m /models/gguf/llama-2-7b.Q4_0.gguf
  ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
  ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
  ggml_cuda_init: found 1 ROCm devices:
  Device 0: Radeon RX 7900 XTX, compute capability 11.0, VMM: no
  | model                          |       size |     params | backend    | ngl |          test |                  t/s |
  | ------------------------------ | ---------: | ---------: | ---------- | --: | ------------: | -------------------: |
  | llama 7B Q4_0                  |   3.56 GiB |     6.74 B | ROCm       |  99 |         pp512 |      3258.67 ± 29.23 |
  | llama 7B Q4_0                  |   3.56 GiB |     6.74 B | ROCm       |  99 |         tg128 |        103.31 ± 0.03 |

  build: 31ac5834 (3818)

Vulkan:

  GGML_VK_VISIBLE_DEVICES=1 ./llama-bench -m /models/gguf/llama-2-7b.Q4_0.gguf
  | model                          |       size |     params | backend    | ngl |          test |                  t/s |
  | ------------------------------ | ---------: | ---------: | ---------- | --: | ------------: | -------------------: |
  ggml_vulkan: Found 1 Vulkan devices:
  Vulkan0: Radeon RX 7900 XTX (RADV NAVI31) (radv) | uma: 0 | fp16: 1 | warp size: 64
  | llama 7B Q4_0                  |   3.56 GiB |     6.74 B | Vulkan     |  99 |         pp512 |       1077.49 ± 2.00 |
  | llama 7B Q4_0                  |   3.56 GiB |     6.74 B | Vulkan     |  99 |         tg128 |         71.83 ± 0.06 |

  build: 31ac5834 (3818)

EDIT: HN should really support markdown...
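
For anyone who wants to check the percentages, here is a small Python sketch that recomputes them from the t/s columns in the two tables above (the dictionary names are just for illustration):

```python
# Throughput (tokens/s) copied from the llama-bench tables above.
rocm = {"pp512": 3258.67, "tg128": 103.31}
vulkan = {"pp512": 1077.49, "tg128": 71.83}

# Ratio of ROCm to Vulkan throughput for each test.
speedup = {test: rocm[test] / vulkan[test] for test in rocm}

for test, ratio in speedup.items():
    print(f"{test}: ROCm is {ratio:.2f}x Vulkan ({ratio - 1:+.0%})")
```

Here pp512 is prompt processing and tg128 is text generation, so the ratios correspond to the roughly 3X and +44% figures quoted above.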


> ...so the extra effort is just installing ROCm? (vs the Vulkan devtools)

The problem with ROCm is that for non-bleeding edge AMD cards you have to install an out of date unsupported version of it because the $current version does not support your card. And that means containerization woes. If you're going to spend $800 on a top of the line current generation video card anyway then you'll have fewer problems (for a few years).

Also, the vulkan vs. rocm performance difference for non-bleeding edge non-top of the line cards is smaller.


Radeon RX 7900 XTX is RDNA3 but I wonder if llama.cpp is using the Vulkan matrix instructions wmma and mfma.

I have not noticed any remarkable differences between Vulkan and ROCm when using IREE but it's not a turnkey solution yet[1].

[1] <https://github.com/nod-ai/sharktank/blob/main/docs/model_coo...>


Any chance we might see Vulkan extensions to close this performance gap? Was really hoping Intel and AMD would team up to create an open standard that we could all have installed by default, but instead we get these clumsy vendor-specific solutions...


I think that it is very unlikely that the performance difference is caused by anything that could be solved with a Vulkan extension.

Vulkan only exposes the raw compute capabilities of the hardware and any well optimized Vulkan application can reach the full performance, but you need to write such optimized code.

On the other hand, ROCm, like CUDA, includes optimized libraries for certain applications, like rocBLAS.

It is likely that here the ROCm backend uses optimized library functions, perhaps from rocBLAS, while the Vulkan backend might use some generic functions for linear algebra, which are not optimized for the AMD GPUs.


Thank you, I did not know there is an alternate route to achieve GPU usage. Will look into this.


ROCm versions are fairly closely tied to kernel versions. Debian 12 should run fine with the ROCm that was released around the time of whatever kernel you're running, but it's going to be a bad experience mixing the latest ROCm with an elderly kernel or vice versa.

Old kernel + old kernel driver + new rocm => the driver doesn't really know what the userspace is doing and you get the bugs which have been fixed since

Old kernel + new kernel driver => very ymmv, the internal kernel API is not stable

New kernel + matching driver, old rocm is probably OK, unless you're using upstream clang in which case it's all bad once more

ROCm was designed and implemented in the HPC environment, where you know the exact kernel in use and the whole stack is deployed as one self consistent lump. Driver, compiler, libraries and so forth. It's not having such a good time in the Linux world of mix and match because the aggressive internal testing structure assumes you're using a consistent system. Backwards/forwards ABI and API compatibility is difficult, expensive and slow so it's not where the money is being spent. Rightly so, probably.


I've had more luck with the ROCm docker container. I run it via k8s. It was pretty painless to set up and has been mostly painless since. Prior to that it was nearly impossible to get Jax running reliably on ROCm.

Even with the container, you have to be careful installing Python libraries because they can still break things.


I just recently went down the AMD GPU + ROCm rabbit hole as well. ROCm 6.2 was just released in August of this year and introduces a lot better support, though as the above poster mentioned, isn't merged into most recent OSes.

This Github repo is good for tracking the latest Ubuntu + ROCm install process: https://github.com/nktice/AMD-AI


That's a nice repo of random installation notes. Very helpful, thanks!


Like everything in machine learning it only really runs on Ubuntu 22.04. Anything else is unsupported and you need to spend weeks tinkering to get it to work, then never upgrade.


Nice work! I was just playing with the inference side of things with 405B myself this weekend [0].

I'm not convinced that 'torch.cuda' is really that bad since the AMD version of PyTorch just translates that for you. More like a naming problem, than anything. Fact is that it is just as easy to grab the rocm:pytorch container, as it is the rocm:jax container.

I don't see very many numbers posted. What MFU did you get?

[0] https://x.com/HotAisle/status/1837580046732874026


Nice!

I need to calculate MFU. GPU, VRAM details can be found in the repo: https://dub.sh/amd-405b-res.

I plan to reattempt the training run next weekend and JIT the entire training step to calculate MFU then


We (ZML) measured MI300X at 30% faster than H100. These are great chips!


Does any Cloud provider have a 8xAMD MI300 host that one can rent? I use AWS for a lot of my professional work, and was hoping to try out an AMD GPU.


Disclosure: my company rents 8xMI300x, contact us.


Oracle do. Others will probably follow, though I expect the smaller players are more reasonable to interact with.


Has this changed? Previously, my understanding was that they only deal with large quantities.


That's a very good point. I know they've got boxes because I've used them, but didn't deal with any of the setup myself and we're probably a large quantities case.


Still hard to tell, but pushing 16k gpus on people doesn't seem to be encouraging in this regard.

https://ir.amd.com/news-events/press-releases/detail/1217/am...

I'm happy to take the stragglers though. ;-)


Where is the performance data?


(author here, sorry for the delay in replying, was stuck in back-to-back meetings)

I updated our GitHub repo to include GPU and VRAM utilization data (https://github.com/felafax/felafax?tab=readme-ov-file#amd-40...)

Note: we couldn't run the JIT-compiled version of the 405B model due to our code/VRAM constraints (we need to investigate this further). The entire training run was executed in JAX eager mode, so there is significant potential for performance improvements.

GPU utilization across the board was still ~30-40% even with eager mode, which is quite good! With JIT, I think the GPU util can easily shoot up to ~50-60%.


I’d be interested to hear how long it takes, wall clock, to train the same model to the same loss with same number of batches. I don’t trust utilization to say anything useful.
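A minimal sketch of the measurement suggested above: time how long it takes to reach a fixed target loss. The `train_step` callable is hypothetical and stands in for whatever runs one optimizer step. It is assumed to return the loss as a Python float, which also forces any asynchronous device work to finish before the clock is read.

```python
import time
from typing import Callable, Optional, Tuple


def time_to_target_loss(
    train_step: Callable[[int], float],
    target_loss: float,
    max_steps: int,
) -> Tuple[Optional[int], float]:
    """Run steps until loss <= target_loss; return (steps taken or None, seconds)."""
    start = time.perf_counter()
    for step in range(1, max_steps + 1):
        loss = train_step(step)
        if loss <= target_loss:
            return step, time.perf_counter() - start
    # Target never reached within the step budget.
    return None, time.perf_counter() - start


# Toy stand-in for a real training step: loss decays as 1/step.
steps, seconds = time_to_target_loss(lambda step: 1.0 / step, target_loss=0.25, max_steps=100)
print(f"reached target after {steps} steps in {seconds:.6f}s")
```

Running the same harness with the same data order, batch size, and target on two machines gives a number that is directly comparable, which utilization percentages are not.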


If possible, it would be interesting to explore ways to overcome the memory constraints and run a JIT-compiled version. This could potentially lead to further performance improvements.


+1, we still have a lot of performance we can extract! JIT-compiled train steps, more optimized data loading and sharding, gradient accumulation, and activation checkpointing. We will continue building and will do another blog soon after implementing all the improvements!
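Of the items listed above, gradient accumulation is the easiest to show in a few lines. This is a toy sketch in plain Python, not the felafax implementation: a one-parameter linear model with mean squared error, where averaging gradients over equal-sized micro-batches reproduces the full-batch gradient while only one micro-batch is processed at a time. All names here are made up for the example.

```python
from typing import Sequence, Tuple

Example = Tuple[float, float]  # (input x, target y)


def mse_grad(w: float, batch: Sequence[Example]) -> float:
    """Gradient of mean((w*x - y)^2) with respect to w."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_grad(w: float, batch: Sequence[Example], micro_batch_size: int) -> float:
    """Average per-micro-batch gradients instead of processing the whole batch at once."""
    if len(batch) % micro_batch_size != 0:
        raise ValueError("batch size must be a multiple of micro_batch_size")
    n_micro = len(batch) // micro_batch_size
    total = 0.0
    for i in range(n_micro):
        micro = batch[i * micro_batch_size:(i + 1) * micro_batch_size]
        total += mse_grad(w, micro)  # only this micro-batch is "in memory"
    return total / n_micro


batch = [(1.0, 2.0), (2.0, 4.0), (3.0, 5.0), (4.0, 9.0)]
full = mse_grad(0.5, batch)
accumulated = accumulated_grad(0.5, batch, micro_batch_size=2)
print(full, accumulated)
```

The trade is more sequential steps per optimizer update in exchange for lower peak activation memory, which is the kind of VRAM headroom a JIT-compiled 405B train step would need.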


Is AMD any closer to extracting value from this with large orders of their GPUs causing a shortage?

I’m getting the impression of “no”


I get your sarcasm. But at this point, unless you are willing to give up all of AI to a single source for hardware and software, we've got to start to work towards alternative solutions. People are working against a massive head start and there is clearly a lot to work on on the software side. Give it time.


Why is obsidian (a note-taking app) doing this?


They aren't. This company is using obsidian publish to publish documents.


Weird.


It's a very easy way to write some markdown that renders as a website. If the URL didn't have obsidian in the title I wouldn't have guessed it was involved.


How do you buy such a GPU or is it still only reserved to the rich so they can get ahead of the game once the pleb gets their unwashed hands on these cards?


Buying them is pretty easy. You call up Dell and place an order. Let me know if you want to talk to a sales person, or we can help source them for you.

The hard parts are that you have to buy 8 of them at a time, you need cooling, each server draws around 10-11 kW, and the fans sound like a 747 taking off. Oh, and they cost a small fortune.

Disclosure: I handle the hard parts and you can rent them from me.


Thought this was a post from Obsidian at first. Why haven't they done the GitHub.com vs GitHub.io thing yet?


Looking at the URL has me thinking that this confusion would be resolved if HN added a small piece of logic to treat the domain publish.obsidian.md specially, just like how HN already does for pages served under forbes.com/sites, which are not written by the Forbes staff themselves.

So instead of showing the domain as obsidian.md, HN would show the domain for this link as publish.obsidian.md.

Maybe something for dang to consider if he sees this comment?
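The kind of special-casing described above could look roughly like this. It is a hypothetical sketch, not HN's actual code: the two exception tables, the function name, and the choice to show only the first path segment for forbes.com are all assumptions, and the "last two labels" rule is deliberately naive (it ignores suffixes like co.uk).

```python
from urllib.parse import urlsplit

# Hypothetical exception tables, not taken from HN.
FULL_HOST_SITES = {"publish.obsidian.md"}        # show the whole hostname
PATH_PREFIX_SITES = {"forbes.com": ("sites",)}   # show domain plus first path segment


def display_domain(url: str) -> str:
    """Return the short site label to show next to a submission title."""
    parts = urlsplit(url)
    host = (parts.hostname or "").lower()
    if host.startswith("www."):
        host = host[4:]

    # User-generated hosting: keep the subdomain so it is not mistaken for the vendor.
    if host in FULL_HOST_SITES:
        return host

    # Default: collapse to the last two labels of the hostname.
    base = ".".join(host.split(".")[-2:])

    # Some sites host third-party content under a path prefix instead of a subdomain.
    segments = [s for s in parts.path.split("/") if s]
    if segments and segments[0] in PATH_PREFIX_SITES.get(base, ()):
        return f"{base}/{segments[0]}"
    return base


# Example URLs below are made up for illustration.
print(display_domain("https://publish.obsidian.md/some-user/some-post"))
print(display_domain("https://help.obsidian.md/some-page"))
print(display_domain("https://www.forbes.com/sites/some-writer/some-article"))
```

An allowlist keeps the default short label for ordinary sites and only expands the ones known to host user-generated pages.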


Same thought here. Why would Obsidian bother with AI? Oh wait, this is publish? So this is what $8 per month gets you? I am amazed, as I would have at least expected a subhost: [username].publish.obsidian.md


Yeah, used Obsidian Publish.

But struggling to get custom domain to work with it (have emailed support).


Resolved?


@dang: could we get the URL to include the username, since this isn't about Obsidian itself but rather a user-generated blog?


It's strange that HN didn't include the full domain "publish.obsidian.md".


That's not turned on by default but I've done it for this domain now.


Thanks dang!


That's something obsidian should fix if they care about not looking like they are being impersonated on HN.


Obsidian can't do anything about it. It's HN chopping up the URL.


[flagged]


I guess it doesn't matter if this is human or bot, an advert's an advert.



