JAX – NumPy on the CPU, GPU, and TPU (jax.readthedocs.io)
276 points by peter_d_sherman on Sept 29, 2023 | 143 comments



It took me a while to realize it, but Jax is actually a huge opportunity for a lot of scientific computing. Jax was originally developed as a more flexible platform for doing machine learning research. But Jax's real superpower is that it bundles XLA and makes it really easy to run computations on GPU or TPU. And huge swathes of scientific computation are basically large-scale vectorized computations.

When I was in astronomy (about a decade ago) I did large scale simulations of gravitational interactions. But at the time all these simulations were done on CPU. Some of the really big efforts used more specialized chips, but it was a huge effort to write the code for it.

But today with Jax, if you want to write an N-body simulation of a globular cluster, you can just code it up in numpy and it'll run on a GPU for free and be about 1000x faster. From what I can tell though, very few people in the sciences have caught on yet.
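To give a flavour of what I mean by "just code it up in numpy", here is a minimal sketch (mine, untested beyond small sizes, not anyone's production code) of pairwise gravitational accelerations written in plain jax.numpy; jit it and JAX dispatches it to whatever accelerator is available:

    import jax
    import jax.numpy as jnp

    def accelerations(pos, mass, G=1.0, eps=1e-3):
        # pos: (N, 3) positions, mass: (N,) masses
        diff = pos[None, :, :] - pos[:, None, :]      # (N, N, 3) pairwise separations
        dist2 = jnp.sum(diff**2, axis=-1) + eps**2    # softened squared distances
        inv_d3 = dist2 ** -1.5
        return G * jnp.sum(diff * (mass[None, :, None] * inv_d3[:, :, None]), axis=1)

    @jax.jit
    def euler_step(pos, vel, mass, dt):
        # one explicit-Euler step; the whole thing compiles to GPU/TPU kernels
        return pos + dt * vel, vel + dt * accelerations(pos, mass)

    key = jax.random.PRNGKey(0)
    pos = jax.random.normal(key, (4096, 3))
    vel = jnp.zeros_like(pos)
    mass = jnp.ones(4096)
    pos, vel = euler_step(pos, vel, mass, 0.01)

The O(N^2) all-pairs force is exactly the kind of thing a GPU eats for breakfast; the tree-code caveats mentioned elsewhere in the thread still apply for really big N.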


Luckily I discovered JAX at the beginning of my PhD four years ago. It has made our data processing (biomedical imaging) so much easier and more readable, albeit with a slight learning curve due to JAX being functional/pure.

I am also continuously surprised by how little adoption JIT and autodiff libraries have gotten in scientific computing. A lot of my colleagues somehow really like coding cost function gradients and fine-tuned GPU code by hand. I guess using something like JAX can reduce your standing in the group, because it can make it seem like coding algorithms is pretty easy.


I feel like I see the opposite -- that everything scientific computing is getting rewritten in something autodifferentiable! Whether that's JAX or something else.

My experience might be biased though: shameless advert for Equinox (https://github.com/patrick-kidger/equinox, 1.4k GitHub stars), which is now the foundation of quite a lot of SciComp in JAX. (Both internal and open-source.)


> foundation of quite a lot of SciComp in JAX

...if that SciComp uses machine learning, I guess? In my "physics of biomedical imaging" bubble, people are hardly doing state-of-the-art ML, but rather expensive forward models for which computing a gradient is cumbersome.

But I know that e.g. Stephan Hoyer is a physicist and you are a mathematician originally -- I have read a lot of your JAX issues and libraries ;-) Maybe it just depends on the "mini-bubble", i.e. the individual research group, and not only the field of science.


> if that SciComp uses machine learning, I guess?

Not necessarily! It's perfectly possible (and quite common) to e.g. write down a traditional parameterised ODE, and then optimise its parameters via gradient descent. Compute the gradients wrt parameters using autodiff through the numerical ODE solver. All without a single neural network in sight! ;)

My usual spiel is that autodiff+autoparallel are really useful for any kind of numerical computation -- of which ML is a (popular, well funded) special case.

At least in my mini bubble, these kinds of "scipy but autodifferentiable" use-cases are fairly common.
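Concretely, that pattern looks something like the sketch below -- a toy exponential-decay ODE integrated with explicit Euler via lax.scan, with the gradient taken straight through the solver (names like `decay_rate` are made up for illustration):

    import jax
    import jax.numpy as jnp

    def solve(decay_rate, y0=1.0, dt=0.01, n_steps=500):
        # dy/dt = -decay_rate * y, integrated with explicit Euler via lax.scan
        def step(y, _):
            y_next = y + dt * (-decay_rate * y)
            return y_next, y_next
        _, trajectory = jax.lax.scan(step, jnp.asarray(y0), None, length=n_steps)
        return trajectory

    def loss(decay_rate, observed):
        return jnp.mean((solve(decay_rate) - observed) ** 2)

    observed = solve(0.7)                    # synthetic "data"
    grad_fn = jax.jit(jax.grad(loss))        # autodiff through the whole solver
    rate = 0.1
    for _ in range(200):                     # plain gradient descent, no NN anywhere
        rate = rate - 0.25 * grad_fn(rate, observed)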

> I have read a lot of your JAX issues and libraries ;-)

Haha, that's fun to hear though! Thank you for sharing that.


Ah, I think I was unclear. I specifically meant your reference to Equinox, because that seemed to me to be somewhat ML specific.

In general, I very much agree that "autodiff+autoparallel are really useful for any kind of numerical computation". And the use cases are also really common in my bubble. It's just that (imho) most people have not realized this.


Ah right! Actually it's a good point, the Equinox readme/etc do tend to emphasise the ML use cases -- partly this is deliberate (go where the money is)! But I should probably tweak it to emphasise more general parameterised models.


> In my "physics of biomedical imaging" bubble, people are hardly doing state-of-the-art ML, but rather expensive forward models for which computing a gradient is cumbersome.

I’d appreciate any pointers to the literature; curious to see the kinds of models people work with. Thanks!


I don't have didactic examples at hand, but e.g. [1] or [2]. IIRC [1] uses the Laplace operator (second-order spatial derivative) and [2] uses a linear solve inside the forward model, through which differentiation is certainly possible but pretty cumbersome in practice.

[1] https://www.nature.com/articles/s41598-019-52283-6

[2] https://doi.org/10.1117/1.JMI.4.3.034005


> "It took me a while to realize it, but Jax is actually a huge opportunity for a lot of scientific computing."

At conferences like NeurIPS, at Google ML Community Days, etc., whenever there is a JAX workshop/tutorial/talk, it is always touted as a numerical computation library. And it was developed as such. Sure, the focus is on ML, but everyone involved has always said that this is a general-purpose scientific computing library.

Flax, Haiku, etc. are Deep Learning libraries.


Meanwhile, the first sentence in their readme is this:

> JAX is Autograd and XLA, brought together for high-performance machine learning research.

That does not really convey the generality of it that well.


You're right! Maybe we should revise that... I made https://github.com/google/jax/pull/17851, comments welcome!


So was tensorflow... And yet it's pretty much dead.


Unfortunately, none of these libraries makes good use of the CPU's vector units. XLA might produce some SIMD code, but it pales in performance compared to routines written explicitly for SIMD the way GPU code is. ISPC is a good example of this.


The issue here is that if your ideal algorithm isn't simply expressible in numpy (and many aren't), you're pretty much out of luck. As a result, IMO the better approach is to use a fast language that also compiles to GPU (e.g. Julia).


Having used JAX quite a bit for numerical computing (and having lectured on this use-case) I would say that a surprisingly large number of algorithms can be expressed as array[0] operations (even if it sometimes takes a bit of thinking).

And, more importantly, things that cannot be expressed that way tend to not be a good fit for GPU computing anyway (independently of the language / framework you are using).

[0]: `array` is a shortcut here, JAX is not limited to operations on arrays.


Agreed. I've done a fair amount of reworking signal processing algorithms to run on GPU/TPU, and it's a different beast. You often have to really rebuild the algorithm from the ground up to take advantage of parallelization. But often you /can/ rework the algorithm, and end up with much higher throughput than the crusty old serial algorithm: there's typically nothing fundamentally stopping you from finding a good implementation, just that the original devs were working in the 70s and hadn't thought that far ahead.


> things that cannot be expressed that way tend to not be a good fit for GPU computing anyway

I'll have to disagree with you a little bit here. The SIMT model of GPUs is quite a bit more expressive than NumPy's SIMD model. As an obvious example, you have to manually maintain a mask to implement if/else, i.e. code-path divergence, in SIMD. GPUs do this and much more automatically to make your life easier. And frankly, I find it a lot easier to reason about what should happen to one data point than to a bunch of them together.
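To make the contrast concrete, a small sketch of the two authoring styles (in JAX terms; both ultimately lower to masked/select code):

    import jax
    import jax.numpy as jnp

    x = jnp.linspace(-1.0, 1.0, 8)

    # NumPy/SIMD style: think about the whole array, carry the mask yourself
    mask = x > 0
    y_simd = jnp.where(mask, jnp.sqrt(jnp.where(mask, x, 1.0)), -x)

    # SIMT-ish style: write the logic for one element, let vmap handle the batch
    def per_element(xi):
        return jax.lax.cond(xi > 0, jnp.sqrt, lambda v: -v, xi)

    y_simt = jax.vmap(per_element)(x)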

An interesting article I read recently that has some relevance to this discussion. https://pharr.org/matt/blog/2018/04/18/ispc-origins


Convenient authoring doesn't necessarily make it a good fit for the hardware. Add in enough divergence and your GPU code is going to be matched or outperformed by a competent CPU implementation (on a chip of comparable size). Branchless code can result in substantial speedups on either.

To be fair though, modern GPUs are pretty good at branching and latency hiding, while numpy-style code has poor data locality unless you have a magic compiler.


To spell out what the linked ISPC post implies: as ISPC shows, most of the difference comes down to differences between GPU languages and compilers and their CPU-side equivalents.


Pretty sure I got multiple 1000x speed-ups when I vectorized my algo trader from a dumb Python loop to a dumb Numba-compiled thing, and when I benchmarked JAX, the performance blew away the Numba thing (which was already a million times faster than the naive version) because JAX performance stayed perfectly flat as the scale went up whereas Numba slowed down. Might have been my approach for each, but it was enlightening and funny to watch.


I have at least one complaint with the numpy model:

When you chain a sequence of vectorized operations on arrays, loop fusion saves you from allocating memory for each intermediate variable, and from the round-trip cost of moving it between RAM and CPU multiple times. I don't know how good JAX's JITted loop fusion is on CPU, but I've been very very impressed by Julia.

Eg: I had some Numpy code that took hours (and needed terabyte RAM) that was very straightforward to code in Julia, and needed only a few GB to finish in a few seconds — on my laptop.

I want to be able to think in arrays, but to also not have to materialize the arrays as much as possible.
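For what it's worth, the JAX side of that comparison looks like the sketch below (untested against Julia, just illustrating where fusion can kick in): a chain of elementwise ops handed to XLA as one jitted function, rather than materialised step by step.

    import jax
    import jax.numpy as jnp

    def pipeline(x):
        # in NumPy each of these lines would allocate a full intermediate array
        y = jnp.sin(x)
        y = y * 2.0 + 1.0
        y = jnp.exp(-y)
        return jnp.sum(y)

    pipeline_fused = jax.jit(pipeline)   # XLA sees the whole chain and can fuse it

    x = jnp.ones((10_000_000,))
    print(pipeline_fused(x))             # first call compiles; later calls reuse the kernel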


I found this to be not true in practice when working with graphs.

Having access to high performance explicit loops and ifs/masks allows one to focus on the hard parts of the algorithms, rather than on the purely incidental puzzle how to best avoid spending time in the Python runtime.


An alternative is to write most of the program in Python + JAX and implement a few custom XLA ops in CUDA / Triton. That way, the program is very readable and can interoperate with the larger ecosystem, while still being fast to run.


Jax JIT of scan is fairly good, so loops aren't as slow as you'd expect.


A key difference is that each iteration of scan is called by the host. Put differently, JAX can't fuse scan into a single GPU kernel, but launches a kernel for each iteration.

Depending on the workload this is no problem, but if you have many cheap iterations, you will notice the overhead.

I am not sure if they are working on fusing scan and what's the current status.


There is an ‘unroll’ parameter in scan that lets you control how many iterations of the loop are fused into a single kernel.
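A quick sketch (the body is just a toy; `unroll` is the real jax.lax.scan parameter):

    import jax
    import jax.numpy as jnp

    def body(carry, x):
        return carry + jnp.sin(x), carry

    @jax.jit
    def run(xs):
        # unroll=8 emits 8 iterations per loop trip: fewer launches, more code
        final, _ = jax.lax.scan(body, jnp.zeros(()), xs, unroll=8)
        return final

    print(run(jnp.arange(1024.0)))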


Yes, but is it really the same? Afaik the `unroll=n` parameter translates `n` iterations into a vanilla `for` loop which is then unrolled into sequential statements (in contrast to a JAX `fori` loop). There still is no loop on the accelerator, strictly speaking?


I think this is up to XLA to handle, not JAX. The whole selling point of the tf.function decorator in TF (which can use XLA underneath as well) is that it fuses arithmetic to lower the launch count.


there's a paper about it that I just found, enjoy https://arxiv.org/pdf/2301.13062


Jax is super useful for scientific computing. Although nbody sims might not be the best application. A naive nbody sim is very easy to implement and accelerate in jax (here’s my version: https://github.com/PWhiddy/jax-experiments/blob/main/nbody.i...), but it can be tricky to scale it. This is because efficient nbody sims usually either rely on trees or spatial hashing/sorting which are tricky to efficiently implement with jax.


Have you seen JAX MD? https://github.com/jax-md/jax-md


I've seen it although haven't dived deep into it. It looks like they have some interesting support for particle cell data structures, but is fairly complicated and carries limitations: https://jax-md.readthedocs.io/en/main/_modules/jax_md/partit...


Last time I looked at JAX MD it didn't support most of the force field terms necessary for simulating proteins and DNA. For example, it could do n-body simulations with some potential, but not the bonds/torsions between atoms. It's unclear if they added support, but that's a huge gap in functionality compared to other systems.


we did an interview with Chris Lattner of XLA fame where he also similarly had nice things to say about JAX: https://www.latent.space/p/modular

just sharing for those who want to learn more


The open source release of XLA predates Lattner's tenure at Google by 7 months, and it definitely existed before that -- the codebase was already 66k SLOC at that point. During his tenure it went from 100k SLOC to 250k SLOC. It's now 700k SLOC. He also has, as far as I can tell, zero commits in the XLA codebase. "Of LLVM fame" would be more accurate I think.


my bad - i guess i was just saying he led that team, but didn't mean to imply he originated it

you seem to have very precise knowledge of the SLOC at a point in time - just curious is there any tooling you used to do that? that can be pretty nifty to pull out on occasion


I git cloned the repo and then ran sloccount after checking out various commits (just did `git log | grep -C3 'Jan 1 [0-9:]* 2017'` or similar to find the relevant commits)


ha, simple enough. thx


> large scale simulations of gravitational interactions

I'm guessing this was mostly Fast Multipole Method? I don't think it ports that easily to GPU, since there's so much communication involved and the leaves don't do a whole lot


Yes. But some of the algorithms cannot benefit that much from the GPU. In my field -- mathematical optimization -- lots of algorithms rely on sparse matrix operations and take many iterations until convergence.



Nope, it's super slow for large sparse matrices. It's even faster to use generic scatter/gather to implement some of them, instead of the built-in thing.


Have you investigated why? I know that many projects have an "implement first, optimize later" approach, and the lesser used functions might be far from optimal.

Back in the tensorflow days, I had this issue and submitted a patch that gave a ~50x speedup for my usecase. It's always better to optimize the base function rather than have 100 people all manually working around the same performance issue.


Because they use a funny format (BCOO). I'm not mocking, it must be a solid choice for some reasons, like sparsification or other fancy stuff. But for large matrices, even with batches (i.e. multiplying with a tall dense matrix), it doesn't match an equivalent scatter (x.at[idx].add(vals)), which itself is several times slower than the equivalent OpenCL (on an A40).


> it'll run on a GPU for free and be about 1000x faster.

Are there any benchmarks for that? Running on GPU never comes for free. You have to transfer data back and forth which has a cost, for instance.


That was just from some quick benchmarks I did a few months back on some 10,000 particle N-body simulations. The performance boost will depend on the task, though. For the kinds of computations I did in grad school it would have been less effective since I was only looking at 3--5 objects, so there's just less parallelism to take advantage of.


Would you happen to have sources on the three orders of magnitude speedup coming at no cost? I'd assume porting + data movement considerations making this task non-trivial.


Yep. I've been writing a molecular dynamics simulator in PyTorch and the reduction in scope is wild, because of heterogeneous operators and automatic differentiation.


Porting existing codes is still a massive effort, and there is low faith in long-term support for Google-based software and hardware. I'm not aware of much (any) TPU use in scientific HPC.


There is not yet. But there is huge pressure on HPC centers, at least in Europe, to also make resources available for ML. Many scientific supercomputers already have GPUs as accelerators. So we might end up with the opposite situation: HPC users are faced more and more with machines with spare accelerators, and it will make sense to use them. Actually, it would totally make sense if JAX development were in part also publicly financed in this case (e.g. through EuroHPC).


It's so awkward that these truly fantastic tools for fast numerical computation (NumPy and JAX) have to be accessed through Python, which is a truly terrible language for fast numerical computation.

Is anyone making any serious progress on fast GPU-based computational tools for other, faster languages? I'm looking for something that also works on the GPU on Windows (unlike JAX).


> have to be accessed through python

It’s because most of the people doing these computations don’t have the capacity to become experts in multiple fields. They understand the math and analytics very well, and they expend all their time thinking about that, not about type systems, memory management, etc. Python lets them code without having to think about a lot of that stuff so they can focus on the things they care about. These aren’t computer scientists or programmers, they’re meteorologists, astronomers, oil and gas analysts, investment bankers etc. That’s why some truly great computer scientists and programmers invested their time into building these tools for python vs other languages.


You might want to look into Futhark https://futhark-lang.org/


This type of code isn’t executed by the Python interpreter.


I think JAX is cool, but I do find it slightly disingenuous when it claims to be "numpy on the GPU" (as opposed to PyTorch), when actually there's a fundamental difference: it's functional. So if I have an array `x` and want to set index 0 to 10, I can't do:

  x[0] = 10
Instead I have to do:

  y = x.at[0].set(10)
Of course this has advantages, but you can't then go and claim that JAX is a drop-in replacement for numpy, because this is such a fundamental change to how numpy developers think (and in this regard, PyTorch is closer to numpy than JAX).


Agree, though I wouldn't call PyTorch close to a drop-in for NumPy either; there are quite a few mismatches in their APIs. CuPy is the drop-in. Excepting some corner cases, you can use the same code for both. E.g. Thinc's ops work with both NumPy and CuPy:

https://github.com/explosion/thinc/blob/master/thinc/backend...

Though I guess the question is why one would still use NumPy when there are good libraries for CPU and GPU. Maybe for interop with other libraries, but DLPack works pretty well for converting arrays.


Why is that? Why doesn't Jax just do something like

    class JaxWrapper:
        def __init__(self, arr):
            self.arr = arr
        def __setitem__(self, key, val):
            return self.arr.at[key].set(val)
        ....


On the other hand, if I wanted some scientific NumPy code to run on the GPU, I think rewriting it in JAX would probably be a better choice than PyTorch.


In my experience, the answer comes down to "does your code use classes liberally?"

If no, you're just passing things between functions, then go ahead with Jax! But converting larger codebases with classes is just significantly better with PyTorch even if they use different method names etc.


I'm going to disagree here! Classes and functional programming can go very well together, just don't expect to do in-place mutation. (I.e. OO-style programming.)

You might like Equinox (https://github.com/patrick-kidger/equinox ; 1.4k GitHub stars) which deliberately offers a very PyTorch-like feel for JAX.

Regarding speed, I would strongly recommend JAX over PyTorch for SciComp. The XLA compiler seems to be much more effective for such use cases.


Sorry for my potentially VERY ignorant question, I only know functional programming at average joe level.

Why can't you do the first in functional programming (not in this specific case because it's just how it is, but in general)?

And even if you can't do so for any reasonable reason in functional programming (again, in general), what stops us from just adding syntactic sugar to make the first equivalent to the second, to make the programmer's life easier?


There are two different aspects people mean when they call something functional programming:

- higher order functions (lambdas, currying, closures, etc.)

- pure functions, immutability by default, side effects are pushed to the top level and marked clearly

The first aspect of functional programming has been already accepted by most OOP languages (even C++ has lambdas and closures).

The second aspect of functional programming is what makes it useful on GPUs (because the GPU architecture that makes them so powerful requires no interaction between code fragments that run in parallel on thousands of cores). So you can easily run pure functional code on a GPU, but you can't easily run imperative code on a GPU.

You can introduce side effects to functional programming, but then it ceases to be any more useful for GPU (and other parallel programming) than imperative/OOP.


The fundamental reason why many functional languages won't allow you to do the first is that they use immutable data structures.

We could indeed introduce syntactic sugar (`y= (x[0]:=10)` maybe), but you'll still need to introduce a new variable to hold the modified list.


It's my understanding that, at least in Python, you can't change an immutable data type, but you can just assign new data to the same variable and thereby overwrite it, right? So even if JAX makes the list type immutable, you can still just re-use `x` to hold the new modified list.


doesn't `[] =` just call a method on the object in python?

e.g., `x[0] = 10` is the same as `x.__setitem__(0, 10)`, so there shouldn't be any technical limitation to using `x[0]` (says the guy who never even imported jax)


You could do `y = x.__setitem__(0, 10)`, but you cannot assign `x[0] = 10` to a new variable. If `__setitem__` was overridden, you would not be able to distinguish between these cases and raise an error in the second one.


Yes, that makes perfect sense.

I somehow completely missed the assignment part of the second example.

Thank you for the clarification.


Also, conditionals can be tricky (greater-than comparisons, if/else) and often need rewriting.


I love JAX. It can be a great replacement for Numba or whatever; the @jit works really well. And vmap is amazing... I often get lost in the matrix sauce when batching; using JAX I develop as if it's just a single instance, then vmap that shit. The ecosystem I was using was: jax, optax, haiku.
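The single-instance-then-vmap workflow looks roughly like this (toy sketch, names made up):

    import jax
    import jax.numpy as jnp

    def predict_one(params, x):          # written for a single feature vector
        w, b = params
        return jnp.tanh(w @ x + b)

    # the batch dimension only appears here, never inside predict_one
    predict_batch = jax.jit(jax.vmap(predict_one, in_axes=(None, 0)))

    params = (jnp.ones((4, 8)), jnp.zeros(4))
    batch = jnp.ones((32, 8))
    out = predict_batch(params, batch)    # shape (32, 4)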

The big issue I had was: I was developing on the CPU, then moved to running it on a GPU, and it wasn't as fast as I expected -- I started debugging and saw there was still lots of communication between the CPU and GPU even though it was all jit'd. I think PyTorch is more user-friendly for writing high-performance models if you're not straying too far from the beaten path. But I really love JAX and would like to play around with it more to understand these pits I'm falling into.

And another complaint is I can't run it on my Macbook M1 GPU... but I'm seeing this page now, so maybe that's not true anymore: https://developer.apple.com/metal/jax/


I've used Numpy with Numba primarily (on the CPU) and it has been a game changer for my data science workloads.

Naturally, I pay close attention to Jax and give it a look every now and then. So, I’ll focus my observations below on Jax’s Numpy API support.

At a glance, Jax code looks like regular Python, but it’s a very different style of programming. Two big differences I’ve found are:

- All Jax functions must be pure. You can't pass references.

- ndarrays cannot be created with dynamic shapes. You have to hardcode the shape tuples. One possible workaround can be to create a buffer much bigger than you need and return that along with the actual shape.
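For the second point, the padding workaround looks roughly like this (a sketch; `MAX_LEN` and the function names are mine):

    import jax
    import jax.numpy as jnp

    MAX_LEN = 1024   # fixed upper bound chosen by the caller

    @jax.jit
    def masked_mean(buffer, length):
        mask = jnp.arange(MAX_LEN) < length            # True for the "real" entries
        return jnp.sum(jnp.where(mask, buffer, 0.0)) / length

    def pack(x):
        # pad to the fixed size and carry the true length alongside
        buffer = jnp.zeros(MAX_LEN).at[: x.shape[0]].set(x)
        return buffer, x.shape[0]

    buf, n = pack(jnp.arange(10.0))
    print(masked_mean(buf, n))                         # 4.5, no recompile for other lengths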

Then there are many small things that are very well documented[1] by the Jax team.

Overall, if you are training ML models, the trouble might be worth it (Autograd). But for accelerating Numpy alone, it is no Numba replacement -- Numba will happily work in the above-mentioned use cases.

[1] https://jax.readthedocs.io/en/latest/notebooks/Common_Gotcha...


> You have to hardcode the shape tuples.

This is not true. Rather, all shapes have to be known at compile time. That means that output shapes must not depend on input values, but may depend on input shapes -- even explicitly.

Furthermore, there are two useful additions:

1. You can use vanilla numpy for compile-time computations. An example would be computing an array of indices for some moving-window filter, depending on the input shape and a stride parameter.

2. You can mark function arguments as "static". Then their values may change shapes of the output, but accordingly the function is compiled for each value of those arguments.
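A small sketch of both points together (the function and the `stride` argument are made up for illustration):

    from functools import partial
    import numpy as np
    import jax
    import jax.numpy as jnp

    @partial(jax.jit, static_argnames="stride")
    def strided_means(x, stride):
        # compile-time NumPy: x.shape[0] and stride are plain Python ints here,
        # so this index array is an ordinary np.ndarray baked into the program
        starts = np.arange(0, x.shape[0] - stride + 1, stride)
        windows = x[starts[:, None] + np.arange(stride)[None, :]]
        return jnp.mean(windows, axis=1)

    x = jnp.arange(12.0)
    print(strided_means(x, stride=3))   # means of [0..2], [3..5], [6..8], [9..11]

A new value of `stride` triggers a recompile, exactly as described in point 2.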


You’re right, I apologize for wording it incorrectly. It might have been a restriction with the # of dimensions. Either way, it wasn’t cut out for my use case.


JAX GPU support is limited to Linux only. Even the WSL2 support is experimental. https://jax.readthedocs.io/en/latest/installation.html#suppo...


Apple supports JAX[0] along with PyTorch[1] and Tensorflow[2] on macOS with both Apple Silicon and AMD GPUs (on x86 Macs). Although, the perf isn't great. I write most of my experimental ML code in JAX on an M2 Macbook Air and then move to a proper multi-GPU Linux box for full training runs.

[0]: https://developer.apple.com/metal/jax/

[1]: https://developer.apple.com/metal/pytorch/

[2]: https://developer.apple.com/metal/tensorflow-plugin/


PyTorch on my M2 Max using the MPS backend has pretty decent performance, to be honest?

It's significantly faster than CPU. Something like 100x using sheet


Is there a specific reason why Windows is not supported?


Presumably because the Google cloud doesn't run on Windows. Well, nothing HPC-related runs Windows.


Life science industry uses plenty of Windows, including HPC workloads.


We ship Windows CPU only at the moment.

We don't support Windows GPU because we haven't had the engineer bandwidth to support it well.

We recommend WSL2 for GPU on Windows at the moment because that is a compromise: it allows CUDA support, without us having to support another release variant.

But we welcome community contributions!


Because no one has done the work to add it... Could be you!


They don't build on Windows at all, either.


Not true!

We release Windows CPU wheels (https://pypi.org/project/jaxlib/#files). So JAX on CPU works great on Windows.

We don't release Windows GPU wheels at the moment, but that's because we're a small team and none of us use Windows personally. We welcome contributions!

(I verified that the Windows CUDA GPU support built as recently as two weeks ago, but I don't have the ability to test that it works.)

We recommend WSL2 because that's just using our existing Linux CUDA release.


Oops, so sorry. But this is recent, isn't it? I thought it was actually due to XLA/Bazel not supporting it?


Yes, we made this more formally supported recently.

We felt that Windows CPU support was important so everyone can run JAX, even if it's not always the most-accelerated version of JAX. And we got some great PRs from the community that helped fix a few open issues.


Very nice! I just installed it, and I hope to eventually contribute down the line, especially in terms of custom operators. They weren't even documented until recently, and there's still quite some work needed to add them.


I wrote a Notebook [0] that gets you introduced to JAX in a very gentle manner.

It also covers things like functional purity in Deep Learning, and handling of random numbers in JAX.

[0]: Learn JAX: From Linear Regression to Neural Networks - https://www.kaggle.com/code/truthr/jax-0


I have been working on my DNN model using TensorFlow even though ML is not my main research. But it is a substantial part of my research, so I have had to figure things out on my own, and I have done so over the past 3 years. However, I spend so much time figuring out how any of the TF methods work and debugging them. I have never used JAX, but I am not sure if this sort of grinding is normal when you use JAX as well (I always hear great things about JAX). I have built so many things using TF that I don't think it is wise for me to learn JAX and migrate my work to a JAX code base.


Is there a reason you are limited to JAX/tensorflow and can't use pytorch?

For a lot of people I know whose main job was not to write code, switching from tensorflow to pytorch was something that saved them ten to hundreds of hours in the long run, even accounting for the initial learning time.


I always feel like I am a prisoner of the sunk cost fallacy


You should definitely switch to PyTorch. Or even JAX.

TF is not worth it anymore.


Does it support arrays of variable lengths now? Last time I looked, I think this was not supported. So it means, for every variable dimension, you need to use an upper bound, and then use masking properly, and hope that it would not waste computation too much on the unused part (e.g. when running a loop over it).

I'm working with sequences, e.g. speech recognition, machine translation, language modeling. This is a quite fundamental property for this type of models, that we have variable lengths sequences.

In those cases, for some example code, I have seen that training also used only fixed size dimensions. And at inference time, they had some non-JAX code for the loop over the sequence around the JAX code with fixed-size dimensions.

This seems like a quite fundamental issue to me? I wonder a bit that this is not an issue for others.


For JIT-ing you need to know the sizes upfront. There was an experimental branch for introducing jagged tensors, but as far as I know, it has been abandoned.


For large scale ML Jax is pretty nice. It makes the multi-host computation feel like a first class citizen and your programs tend to be designed with that in mind.


Very unrelated, but I did a job interview with the Nvidia JAX team for a compiler engineer role some time ago; not very friendly and very opinionated.


What happened?


Nothing happened. It was an informal technical interview with the program manager at JAX. A one-hour call, and the interview was remote, but describing him as opinionated and entitled is an understatement. Best of luck to them.


Asking because I am on that team :)

I went through the same hiring process and had a positive experience at every stage. I had a strong competing offer but went with the JAX team at NVIDIA.

I'll pass it along as feedback.


Would you mind sharing some details? It sounds like an interesting peek behind the curtain.


I'm a huge fan of Jax. The Jax team is incredibly strong!

Just want to share that Ray (an open source project we're developing at Anyscale) can be used to scale Jax (e.g., across TPUs).

Some docs from Google on how to do this

https://cloud.google.com/tpu/docs/ray-guide

Alpa is an open source project scaling Jax on 1000+ GPUs

https://www.anyscale.com/blog/training-175b-parameter-langua...

Cohere uses Ray + Jax + TPUs to build their LLMs

https://www.youtube.com/watch?v=For8yLkZP5w

A demo from Matt Johnson on the Jax team

https://www.youtube.com/watch?v=hyQ-tgD5sgc


Is there any benefit using it instead of pytorch?


Jax has a much nicer handling of higher order differentiation. PyTorch has functions to compute Hessians and there are libraries to keep differentiability through optimizers, but going out of their standard use-cases becomes tricky very fast. In contrast, JAX can compute nth-derivatives of things very easily.
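E.g. a third derivative is literally just three nested calls (toy sketch):

    import jax

    f = lambda x: x**4 + 2.0 * x
    d3f = jax.grad(jax.grad(jax.grad(f)))
    print(d3f(1.5))   # f'''(x) = 24x, so 36.0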


The main benefit in my experience is that it’s much easier to do distributed computations in JAX. It has a much nicer API. For single device computing there’s no advantage either way.


If you like functional languages, then Jax will fit you better. It provides a bunch of function transformations to implement e.g. grad, JIT, etc.


I'll remember this always: when DeepMind solved a subset of the protein structure prediction problem, they used Jax as the framework.

PSPP was a long-standing issue and to see a fairly new computational tool used to significantly aid in the process of "solving" it speaks greatly towards its general utility in the sciences.


IMHO the project should put more emphasis on the potential for simple and ubiquitous multi-core acceleration of vector compute that is, by definition, available to anybody with a modern CPU.

Nvidia dropped CUDA support for perfectly good GPUs, showing the perils and waste of being locked into a profit-maximizing monopoly.


Now if only it supported ONNX export or was cross-platform so I could run it from Java and .NET land



I need the opposite - JAX to ONNX


I've been looking in to this for the java world. What's your use case? Deployment in to existing applications?


Yeah, exactly - Python for training, Java/.NET for inference in production. I looked at approaches like gRPC and such, but my case is a bit more time-sensitive and the latency added by going over a network layer was too much.

For now I'm happy with Pytorch->ONNX and then running the ONNX model directly. But as I said, that means I can't easily train using JAX :-(



Ohh, I'll check that out!


> With its updated version of Autograd, JAX can automatically differentiate native Python and NumPy code.

Nice. When did they make this change?

Here is the old way in the docs, where you needed to define functions for the if-true branch and the if-false branch, and feed them to a conditional function, to get the normal if-then-else conditional.

https://jax.readthedocs.io/en/latest/notebooks/Common_Gotcha...


Actually that never changed. The README has always had an example of differentiating through native Python control flow:

https://github.com/google/jax/commit/948a8db0adf233f333f3e5f...

The constraints on control flow expressions come from jax.jit (because Python control flow can't be staged out) and jax.vmap (because we can't take multiple branches of Python control flow, which we might need to do for different batch elements). But autodiff of Python-native control flow works fine!
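A minimal illustration of that split (hedged sketch):

    import jax

    def f(x):
        if x > 0:            # ordinary Python control flow
            return x ** 2
        return -x

    print(jax.grad(f)(3.0))  # 6.0 -- autodiff just follows the branch that was taken

    # jax.jit(f)(3.0) would fail with a tracer/concretization error, because the
    # comparison on a traced value can't be resolved at trace time; under jit you
    # reach for lax.cond instead:
    g = jax.jit(lambda x: jax.lax.cond(x > 0, lambda v: v ** 2, lambda v: -v, x))
    print(g(3.0))            # 9.0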


This is still the case afaik.

For vanilla "if", the condition must be known at compile time. For runtime, you have to use "cond", "where", or "select" (which may be analogous).


Actually, that's never been a constraint for JAX autodiff. JAX grew out of the original Autograd (https://github.com/hips/autograd), so differentiating through Python control flow always worked. It's jax.jit and jax.vmap which place constraints on control flow, requiring structured control flow combinators like those.


What happened to Jax? Is it still alive?


Very much alive. From what I can tell it has more or less replaced Tensorflow for research purposes. (A lot of researchers use PyTorch though.)


Development seems not to have dropped at all from the contributions page: https://github.com/google/jax/graphs/contributors

Don’t know about usage and uptake though.


Why would it be dead?


Just going over the docs and there’s something I don’t understand, hoping someone here can explain:

What's the point of having to explicitly call grad(jit(f))? Why doesn't grad just call jit internally? Is there a use case where you want the grad without jit?


Indeed, it is much better to use jit(grad(f)) in general.

Supporting the opposite composition is still useful in some edge cases -- for example when debugging, you want to step through a computation without jit, and simply not crash when differentiating any inner functions also decorated with jit.


Wouldn't you use jax.disable_jit() for that?


You can do that too! That will disable every JIT though. In practice you might only want to disable just one.

To add some colour to my answer: when writing a library, it's typical to put a JIT statement on everything in the public API. This means you get the benefits of JIT compilation even when you're just hacking around in the REPL, and it mitigates the new-user footgun in which they forget to use JIT themselves.

Meanwhile, good practice is always to JIT your whole computation.

Combined, this means that it's fairly common to go jit (at the top level) -> grad (of your operation) -> jit (of some library call).

When debugging your code, the JIT'd library call is _probably_ not the culprit. So you only want to disable the top-level JIT when stepping through, and still take advantage of JIT compilation where you can. Overall one obtains a composition of the form grad(jit(...)).

TL;DR: even if the use case doesn't come up super frequently, it's more user-friendly to support grad(jit(...)) than it is to just crash.
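A toy version of that situation (names made up):

    import jax
    import jax.numpy as jnp

    @jax.jit                       # "library" code: shipped pre-jitted
    def library_kernel(x):
        return jnp.tanh(x) * x

    def loss(x):                   # "user" code currently being debugged
        print("tracing with", x)   # plain Python poking around, no top-level jit
        return jnp.sum(library_kernel(x) ** 2)

    g = jax.grad(loss)(jnp.arange(4.0))   # the grad(jit(...)) composition, works fine
    # once the bug is found: fast_grad = jax.jit(jax.grad(loss))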


I am running the JAX version of "ctoraman/hate-speech-bert" over here https://news.ycombinator.com/item?id=37696033 and it's pretty efficient



Here's mine [0]. It's a very gentle introduction.

[0]: https://www.kaggle.com/code/truthr/jax-0


Is there a path to compile and deploy JAX models on Android apps locally?


I haven't tried it myself, but perhaps it's worth looking into the jax2tf -> tflite route.


I honestly had no idea JAX was useful outside autograd, due to my own tunnel vision. I even used it with other libraries to do this kind of work. Is there a term for this type of mistake?


Anybody using it in production? Is it, or are its derivatives like Flax, worth using over PyTorch for anything?

edit: Made comparison more fair.


I’m a researcher, not using anything in production, but I find jax more usable as a general GPU-accelerated tensor math library. PyTorch is more specifically targeted at the neural network use case. It can be shoehorned into other use cases, but is clearly designed & documented for NN training & inference.


Agreed. I used Jax about a year ago to estimate some diode parameters for a side project of mine.


Not a fair comparison IMO. Jax is a low-level library used to make ML frameworks, while PyTorch is a full-blown ML framework.

In terms of whether it is worth using: that depends on what you're doing. If you just want to start with ML training, probably not. If you have something already and you want to take it to the next level (e.g. influence how training and inference work), then it's a good choice. You might be interested in looking into Flax or Haiku instead of using vanilla Jax. These are closer to PyTorch.


If you like PyTorch then you might like Equinox, by the way. (https://github.com/patrick-kidger/equinox ; 1.4k GitHub stars now!) Basically designed to offer PyTorch-like syntax for working with JAX. The latter is excellent for the reasons the sibling replies have stated, but PyTorch absolutely got the usability story correct.


We recently switched to Jax to boost performance as we scale up our algorithmic core. The nice thing is that it presents only a minor jump in capabilities to get developers working with it, if they have prior exposure to numpy of course. Quite nice :)


A number of large AI companies use it to train their large models; Midjourney, Stability, Anthropic, DeepMind, among others.


This is pretty well known, so I don't know why it would get so popular.


Jax is Ceres, but for Python!


Looking for some numerical problems to practice JAX, any suggestion or resource?


Although not terribly complicated, you can compute fractals like the Mandelbrot set or Julia sets. They look really cool and you can play with visualization.


There's a comment above: search for "kaggle" and you'll find some.

also:

https://jax.readthedocs.io/en/latest/advanced_guide.html



