I’m surprised it’s only 30% cheaper vs NVIDIA. How come? This seems to indicate that the NVIDIA premium isn’t as high as everybody makes it out to be.
Also, calculating GPU costs is getting quite nuanced, with a wide range of prices (https://cloud-gpus.com/) and other variables that make it harder to do an apples-to-apples comparison.
I think you slightly misunderstood, and I wasn't clear enough, sorry! It's not a 30-70% speedup; it's 30-70% more cost-efficient. This is mainly due to non-NVIDIA chipsets (e.g., Google TPUs) being cheaper, with some additional efficiency gains from JAX being more tightly integrated with the XLA compiler.
No, we haven't run our JAX + XLA code on NVIDIA chipsets yet. I'm not sure whether NVIDIA has good XLA backend support.
At the bottom, it walks through the calculations behind the 30% cost efficiency of TPU vs GPU.
Our 30-70% range is based on numbers we collected from fine-tuning runs on TPUs, compared to similar runs on NVIDIA GPUs (though those runs used other OSS libraries, not our code). A rough sketch of the arithmetic is below.
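For concreteness, here's a back-of-envelope sketch of how a cost-efficiency number like that falls out of comparing run costs. The prices and run times are placeholders I made up for illustration, not our measured data:

```python
# Illustrative cost-efficiency comparison (placeholder numbers,
# not the actual measurements behind the 30-70% range).
gpu_price_per_hr = 4.44   # e.g. A100 on-demand on Google Cloud
tpu_price_per_hr = 3.22   # assumed hourly rate for a comparable TPU slice
gpu_hours = 10.0          # hours for the fine-tuning run on GPU
tpu_hours = 9.0           # assumed hours for a similar run on TPU

gpu_cost = gpu_price_per_hr * gpu_hours
tpu_cost = tpu_price_per_hr * tpu_hours
savings = 1 - tpu_cost / gpu_cost

print(f"GPU run: ${gpu_cost:.2f}, TPU run: ${tpu_cost:.2f}, "
      f"cost-efficiency gain: {savings:.0%}")  # ~35% with these numbers
```

The gain depends on both the price gap and the relative run time, which is why the range is wide.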
It would be a lot more convincing if you actually ran it yourself and did a proper apples-to-apples comparison, especially considering that’s the whole idea behind your project.
It's also comparing prices on Google Cloud, which has its own markup and is a lot more expensive than, say, RunPod. RunPod is $1.64/hr for an A100 on Secure Cloud, while an A100 on Google Cloud is $4.44/hr. So in that context, a 30% price beat is actually a huge loss overall.
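To put rough numbers on that (assuming the TPU works out to ~30% below Google Cloud's A100 rate, which is my reading of the claim, not their stated price):

```python
# Illustrative arithmetic: a 30% beat over Google Cloud's A100 price
# still lands well above RunPod's A100 rate.
gcp_a100 = 4.44      # $/hr, A100 on Google Cloud (from above)
runpod_a100 = 1.64   # $/hr, A100 on RunPod Secure Cloud (from above)

tpu_effective = gcp_a100 * (1 - 0.30)  # assumed: 30% cheaper than GCP A100

print(f"Effective TPU rate: ${tpu_effective:.2f}/hr")           # ~$3.11/hr
print(f"vs RunPod A100:     ${runpod_a100:.2f}/hr "
      f"({tpu_effective / runpod_a100:.1f}x more expensive)")   # ~1.9x
```

So under that assumption the "30% cheaper" TPU is still roughly twice RunPod's A100 price per hour.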