It took me a while to realize it, but Jax is actually a huge opportunity for a lot of scientific computing. Jax was originally developed as a more flexible platform for doing machine learning research. But Jax's real superpower is that it bundles XLA and makes it really easy to run computations on GPU or TPU. And huge swathes of scientific computing are basically large-scale vectorized computations.
When I was in astronomy (about a decade ago) I did large scale simulations of gravitational interactions. But at the time all these simulations were done on CPU. Some of the really big efforts used more specialized chips, but it was a huge effort to write the code for it.
But today with Jax, if you want to write an N-body simulation of a globular cluster, you can just code it up in numpy and it'll run on a GPU for free and be about 1000x faster. From what I can tell though, very few people in the sciences have caught on yet.
Luckily I discovered JAX at the beginning of my PhD four years ago. It has made our data processing (biomedical imaging) so much easier and more readable, albeit with a slight learning curve due to JAX being functional/pure.
I am also continuously surprised how little adoption JIT and autodiff libraries have gotten in scientific computing. A lot of my colleagues somehow really like coding cost function gradients and fine-tuned GPU code by hand. I guess using something like JAX can reduce your standing in the group, because it can make it seem like coding algorithms is pretty easy.
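To make that concrete, a minimal sketch (my own toy example, not anyone's actual cost function) of what replacing a hand-derived gradient with jax.grad looks like:

```python
import jax
import jax.numpy as jnp

def cost(params, x, y):
    # any differentiable forward model would do; here a toy linear least-squares
    pred = x @ params
    return jnp.mean((pred - y) ** 2)

grad_cost = jax.grad(cost)            # gradient w.r.t. params, no derivation by hand
params = jnp.array([0.1, -0.2, 0.3])
x, y = jnp.ones((100, 3)), jnp.zeros(100)
print(grad_cost(params, x, y))        # same shape as params
```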
I feel like I see the opposite -- that everything in scientific computing is getting rewritten in something autodifferentiable! Whether that's JAX or something else.
My experience might be biased though: shameless advert for Equinox (https://github.com/patrick-kidger/equinox, 1.4k GitHub stars), which is now the foundation of quite a lot of SciComp in JAX. (Both internal and open-source.)
...if that SciComp uses machine learning, I guess? In my "physics of biomedical imaging" bubble, people are hardly doing state-of-the-art ML, but rather expensive forward models for which computing a gradient is cumbersome.
But I know that e.g. Stephan Hoyer is a physicist and you are a mathematician originally -- I have read a lot of your JAX issues and libraries ;-) maybe it just depends on the "mini-bubble", a.k.a. the individual research group, and not only the field of science.
Not necessarily! It's perfectly possible (and quite common) to e.g. write down a traditional parameterised ODE, and then optimise its parameters via gradient descent. Compute the gradients wrt parameters using autodiff through the numerical ODE solver. All without a single neural network in sight! ;)
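A minimal sketch of what that looks like in practice (a hand-rolled Euler solver and a made-up decay model, purely for illustration):

```python
import jax
import jax.numpy as jnp

def euler_solve(k, x0, dt=0.01, steps=100):
    # integrate dx/dt = -k * x with forward Euler
    def step(x, _):
        x_next = x + dt * (-k * x)
        return x_next, x_next
    _, traj = jax.lax.scan(step, x0, None, length=steps)
    return traj

def loss(k, x0, observed):
    return jnp.mean((euler_solve(k, x0) - observed) ** 2)

observed = euler_solve(2.0, 1.0)       # pretend data generated with k = 2
k = 0.5
for _ in range(200):                   # plain gradient descent on the ODE parameter,
    k -= 0.1 * jax.grad(loss)(k, 1.0, observed)   # differentiating through the solver
```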
My usual spiel is that autodiff+autoparallel are really useful for any kind of numerical computation -- of which ML is a (popular, well funded) special case.
At least in my mini bubble, these kinds of "scipy but autodifferentiable" use-cases are fairly common.
> I have read a lot of your JAX issues and libraries ;-)
Haha, that's fun to hear though! Thank you for sharing that.
Ah, I think I was unclear. I specifically meant your reference to Equinox, because that seemed to me to be somewhat ML specific.
In general, I very much agree that "autodiff+autoparallel are really useful for any kind of numerical computation". And the use cases are also really common in my bubble. It's just that (imho) most people have not realized this.
Ah right! Actually it's a good point, the Equinox readme/etc do tend to emphasise the ML use cases -- partly this is deliberate (go where the money is)! But I should probably tweak it to emphasise more general parameterised models.
> In my "physics of biomedical imaging" bubble, people are hardly doing state-of-the-art ML, but rather expensive forward models for which computing a gradient is cumbersome.
I’d appreciate any pointers to the literature; curious to see the kinds of models people work with. Thanks!
I don't have didactic examples at hand, but e.g. [1] or [2]. IIRC [1] uses the Laplace operator (second-order spatial derivative) and [2] uses a linear solve inside the forward model, through which differentiation is certainly possible but pretty cumbersome in practice.
> "It took me a while to realize it, but Jax is actually a huge opportunity for a lot of scientific computing."
At all conferences like NeurIPS, at Google ML Community days, etc., whenever there is a JAX workshop/tutorial/talk, it is always touted as a numerical computation library. And it was developed as such. Sure, the focus is on ML, but everyone involved in it has always said that this is a general-purpose scientific computing library.
Unfortunately, none of these libraries makes good use of CPU vector units. XLA might produce some SIMD code, but it pales in performance compared to routines written explicitly for SIMD; ISPC is a good example of this.
The issue here is that if your ideal algorithm isn't simply expressible in numpy (which many aren't), you're pretty much out of luck. As a result, imo the better approach is to use a fast language that also compiles to GPU (e.g. Julia).
Having used JAX quite a bit for numerical computing (and having lectured on this use-case) I would say that a surprisingly large number of algorithms can be expressed as array[0] operations (even if it sometimes takes a bit of thinking).
And, more importantly, things that cannot be expressed that way tend to not be a good fit for GPU computing anyway (independently of the language / framework you are using).
[0]: `array` is a shortcut here, JAX is not limited to operations on arrays.
Agreed. I've done a fair amount of reworking signal processing algorithms to run on GPU/TPU, and it's a different beast. You often have to really rebuild the algorithm from the ground up to take advantage of parallelization. But often you /can/ rework the algorithm, and end up with much higher throughput than the crusty old serial algorithm: there's typically nothing fundamentally stopping you from finding a good implementation, just that the original devs were working in the 70s and hadn't thought that far ahead.
> things that cannot be expressed that way tend to not be a good fit for GPU computing anyway
I'll have to disagree with you a little bit here. The SIMT model of GPUs is quite a bit more expressive than numpy's SIMD model. As an obvious example, you have to manually maintain a mask to implement if/else, i.e. code-path divergence, in SIMD. GPUs do this automatically, and much more, to make your life easier. And frankly, I find it a lot easier to reason about what should happen to one data point than about a bunch of them together.
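To illustrate the masking point (a trivial example, not meant as a benchmark): in scalar/SIMT style you write the branch per data point, while in numpy/JAX array style both branches are computed and a mask selects between them.

```python
import jax.numpy as jnp

def per_element(x):
    # how you'd reason in SIMT/CUDA: one data point, a real branch
    # (only works on one concrete value at a time)
    return x * 2 if x > 0 else -x

def array_style(x):
    # numpy/JAX SIMD style: both sides evaluated, the mask does the "branching"
    return jnp.where(x > 0, x * 2, -x)

array_style(jnp.array([-1.0, 2.0, -3.0]))   # -> [1., 4., 3.]
```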
Convenient authoring doesn't necessarily make it a good fit for the hardware. Add in enough divergence and your GPU code is going to be matched or outperformed by a competent CPU implementation (on a chip of comparable size). Branchless code can result in substantial speedups on either.
To be fair though, modern GPUs are pretty good at branching and latency hiding, while numpy-style code has poor data locality unless you have a magic compiler.
To spell out what the linked ISPC post implies: most of the difference, as ISPC shows, comes down to differences between GPU languages and compilers and their CPU-side equivalents.
Pretty sure I got multiple 1000x speed-ups when I vectorized my algo trader from a dumb Python loop to a dumb numba-compiled thing, and when I benchmarked Jax, the performance blew away the numba thing (which was already a million times faster than the naive version) because Jax performance stayed perfectly flat as the scale went up whereas numba slowed down. Might have been my approach for each, but it was enlightening and funny to watch.
I have at least one complaint with the numpy model:
When you chain a sequence of vectorized operations on arrays, loop fusion would save you from allocating memory for each intermediate variable, and the round trip time of moving it from RAM to CPU multiple times. I don’t know how good JAX’s JITted loop fusion is on CPU, but I’ve been very very impressed by Julia.
Eg: I had some Numpy code that took hours (and needed terabyte RAM) that was very straightforward to code in Julia, and needed only a few GB to finish in a few seconds — on my laptop.
I want to be able to think in arrays, but also to avoid materializing the arrays as much as possible.
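For what it's worth, on the JAX side the usual answer is to put jax.jit around the whole chain and let XLA try to fuse it; how well that works is workload-dependent, but a sketch of the pattern:

```python
import jax
import jax.numpy as jnp

def pipeline(x):
    # un-jitted, each of these lines materialises a temporary array
    y = jnp.sin(x)
    z = y * y + 1.0
    return jnp.sqrt(z).sum()

fused = jax.jit(pipeline)              # XLA can fuse the elementwise ops into one pass
x = jnp.linspace(0.0, 1.0, 1_000_000)
fused(x)
```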
I found this to be not true in practice when working with graphs.
Having access to high performance explicit loops and ifs/masks allows one to focus on the hard parts of the algorithms, rather than on the purely incidental puzzle how to best avoid spending time in the Python runtime.
An alternative is to write most of the program in Python + JAX + implement a few custom XLA ops in CUDA / Triton. That way, the program is very readable and can interoperate with the larger ecosystem, while still being fast to run.
A key difference is that each iteration of scan is called by the host. Put differently, JAX can't fuse scan into a single GPU kernel, but launches a kernel for each iteration.
Depending on the workload this is no problem, but if you have many cheap iterations, you will notice the overhead.
I am not sure if they are working on fusing scan and what's the current status.
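For reference, a minimal sketch of the kind of scan being discussed (many cheap sequential iterations), including the `unroll` knob that comes up in the reply below; numbers are made up:

```python
import jax
import jax.numpy as jnp

def step(carry, x):
    return carry * 0.99 + x, carry     # a deliberately cheap per-iteration update

@jax.jit
def run(xs):
    # unroll=8 trades compile time / code size for fewer loop iterations
    final, _history = jax.lax.scan(step, 0.0, xs, unroll=8)
    return final

run(jnp.ones(10_000))
```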
Yes, but is it really the same? Afaik the `unroll=n` parameter translates `n` iterations into a vanilla `for` loop which is then unrolled into sequential statements (in contrast to a JAX `fori` loop). There still is no loop on the accelerator, strictly speaking?
I think this is up to XLA to handle, not Jax. The whole selling point of the tf.function decorator in TF (which uses XLA underneath as well) is that it fuses arithmetic to lower the launch count.
Jax is super useful for scientific computing. Although nbody sims might not be the best application. A naive nbody sim is very easy to implement and accelerate in jax (here’s my version: https://github.com/PWhiddy/jax-experiments/blob/main/nbody.i...), but it can be tricky to scale it. This is because efficient nbody sims usually either rely on trees or spatial hashing/sorting which are tricky to efficiently implement with jax.
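For context, a minimal sketch of the naive all-pairs acceleration (not the linked code, just an illustration of why the O(n^2) version is so easy to write):

```python
import jax
import jax.numpy as jnp

@jax.jit
def accelerations(pos, mass, eps=1e-3):
    # pos: (n, 3), mass: (n,); all-pairs gravity with a small softening length eps
    diff = pos[None, :, :] - pos[:, None, :]                    # (n, n, 3)
    dist3 = (jnp.sum(diff ** 2, axis=-1) + eps ** 2) ** 1.5     # (n, n)
    return jnp.sum(mass[None, :, None] * diff / dist3[:, :, None], axis=1)

pos = jax.random.normal(jax.random.PRNGKey(0), (1000, 3))
accelerations(pos, jnp.ones(1000))
```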
Last time I looked at JAX MD it didn't support most of the force field terms necessary for simulating proteins and DNA. For example, it could do n-body simulations with some potential, but not the bonds/torsions between atoms. It's unclear if they added support, but that's a huge gap in functionality compared to other systems.
The open source release of XLA predates Lattner's tenure at Google by 7 months, and it definitely existed before that -- the codebase was already 66k SLOC at that point. During his tenure it went from 100k SLOC to 250k SLOC. It's now 700k SLOC. He also has, as far as I can tell, zero commits in the XLA codebase. "Of LLVM fame" would be more accurate I think.
my bad - i guess i was just saying he led that team but didn't mean to imply he originated it
you seem to have very precise knowledge of the SLOC at a point in time - just curious is there any tooling you used to do that? that can be pretty nifty to pull out on occasion
I git cloned the repo and then ran sloccount after checking out various commits (just did `git log | grep -C3 'Jan 1 [0-9:]* 2017'` or similar to find the relevant commits)
> large scale simulations of gravitational interactions
I'm guessing this was mostly Fast Multipole Method? I don't think it ports that easily to GPU since there's so much communication involved and the leaves don't do a whole lot
Yes. But some of the algorithms cannot benefit that much from the GPU. In my field -- mathematical optimization -- lots of algorithms rely on sparse matrix operations and take many iterations until convergence.
Have you investigated why? I know that many projects have an "implement first, optimize later" approach, and the lesser used functions might be far from optimal.
Back in the tensorflow days, I had this issue and submitted a patch that gave a ~50x speedup for my usecase. It's always better to optimize the base function rather than have 100 people all manually working around the same performance issue.
Because they use a funny format (BCOO). I'm not mocking, it must be a solid choice for some reasons, like sparsification or other fancy stuff. But for large matrices, even with batches (i.e. multiplying with a tall dense matrix), it doesn't match an equivalent scatter (x.at[idx].add(vals)). Which itself is several times slower than equivalent OpenCL (on an A40).
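For anyone curious, a rough sketch of the two formulations being compared (shapes and values are made up; I'm assuming the BCOO constructor that takes (data, indices) plus a shape):

```python
import jax.numpy as jnp
from jax.experimental import sparse

rows = jnp.array([0, 0, 2])                      # COO triplets of a 3x3 sparse matrix
cols = jnp.array([1, 2, 0])
vals = jnp.array([1.0, 2.0, 3.0])
B = jnp.ones((3, 4))                             # dense right-hand side

A = sparse.BCOO((vals, jnp.stack([rows, cols], axis=1)), shape=(3, 3))
Y_bcoo = A @ B                                   # BCOO sparse @ dense

# the equivalent scatter-add mentioned above
Y_scatter = jnp.zeros((3, 4)).at[rows].add(vals[:, None] * B[cols])
```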
That was just from some quick benchmarks I did a few months back on some 10,000 particle N-body simulations. The performance boost will depend on the task, though. For the kinds of computations I did in grad school it would have been less effective since I was only looking at 3--5 objects, so there's just less parallelism to take advantage of.
Would you happen to have sources on the three orders of magnitude speedup coming at no cost? I'd assume porting + data movement considerations making this task non-trivial.
Yep. I’ve been writing a molecular dynamics simulator in PyTorch and the reduction in scope is wild because of heterogenous operators and automatic differentiation.
Porting existing codes is still a massive effort and there is low faith in long-term support from Google based software and hardware. I’m not aware on much (any) TPU use in scientific HPC.
There is not yet. But there is huge pressure on HPC centers, at least in Europe, to also make resources available for ML. Already many scientific supercomputers have GPUs as accelerators. So we might have the other situation: HPC users are faced more and more with machines with spare accelerators, and it will make sense to use them. Actually, it would totally make sense if JAX development were in this case also partly publicly financed (e.g. through EuroHPC).
It's so awkward that these truly fantastic tools for fast numerical computation (NumPy and JAX) have to be accessed through Python, which is a truly terrible language for fast numerical computation.
Is anyone making any serious progress in fast GPU based computational tools for other faster languages? I'm looking for something that also works on the GPU on windows (unlike JAX)
It’s because most of the people doing these computations don’t have the capacity to become experts in multiple fields. They understand the math and analytics very well, and they expend all their time thinking about that, not about type systems, memory management, etc. Python lets them code without having to think about a lot of that stuff so they can focus on the things they care about. These aren’t computer scientists or programmers, they’re meteorologists, astronomers, oil and gas analysts, investment bankers etc. That’s why some truly great computer scientists and programmers invested their time into building these tools for python vs other languages.
I think JAX is cool, but I do find it slightly disingenuous when it claims to be "numpy but on the GPU" (as opposed to PyTorch), when actually there's a fundamental difference; it's functional. So if I have an array `x` and want to set index 0 to 10, I can't do:
x[0] = 10
Instead I have to do:
y = x.at[0].set(10)
Of course this has advantages, but you can't then go and claim that JAX is a drop-in replacement for numpy, because this is such a fundamental change to how numpy developers think (and in this regard, PyTorch is closer to numpy than JAX).
Agree, though I wouldn’t call PyTorch close to a drop-in for NumPy either; there are quite a few mismatches in their APIs. CuPy is the drop-in. Excepting some corner cases, you can use the same code for both. E.g. Thinc’s ops work with both NumPy and CuPy:
Though I guess the question is why one would still use NumPy when there are good libraries for CPU and GPU. Maybe for interop with other libraries, but DLPack works pretty well for converting arrays.
On the other hand, if I wanted some scientific NumPy code to run on the GPU, I think rewriting it in JAX would probably be a better choice than PyTorch.
In my experience, the answer comes down to "does your code use classes liberally?"
If no, you're just passing things between functions, then go ahead with Jax! But converting larger codebases with classes is just significantly better with PyTorch even if they use different method names etc.
I'm going to disagree here! Classes and functional programming can go very well together, just don't expect to do in-place mutation. (I.e. OO-style programming.)
Sorry for my potentially VERY ignorant question, I only know functional programming at average joe level.
Why can't you do the first in functional programming (not in this specific case because it's just how it is, but in general)?
And even if you can't do so for some reasonable reason in functional programming (again, in general), what stops us from just adding syntactic sugar to make the first equivalent to the second, to make programmers' lives easier?
There are two different aspects people mean when they call something functional programming:
- higher order functions (lambdas, currying, closures, etc.)
- pure functions, immutability by default, side effects are pushed to the top level and marked clearly
The first aspect of functional programming has been already accepted by most OOP languages (even C++ has lambdas and closures).
The second aspect of functional programming is what makes it useful on GPU (because the GPU architecture that makes it so powerful requires no interactions between code fragments that run in parallel on 1000s of cores). So you can easily run pure functional code on GPU, but you can't easily run imperative code on GPU.
You can introduce side effects to functional programming, but then it ceases to be any more useful for GPU (and other parallel programming) than imperative/OOP.
It's my understanding that, at least in Python, you can't change an immutable data type, but you can just assign new data to the same variable and thereby overwrite it, right? So even if JAX makes the array type immutable, you can still just re-use `x` to hold the new, modified array.
doesn't `[] =` just call a method on the object in python?
e.g., `x[0] = 10` is the same as `x.__setitem__(0, 10)`, so there shouldn't be any technical limitation to using `x[0]` (says the guy who never even imported jax)
You could do `y = x.__setitem__(0, 10)`, but you cannot assign `x[0] = 10` to a new variable. If `__setitem__` were overridden, you would not be able to distinguish between these cases and raise an error in the second one.
I love JAX. It can be a great replacement for Numba or whatever; the @jit works really well. And vmap is amazing... I often get lost in the matrix sauce when batching, so with JAX I develop as if it's just a single instance, then vmap that shit. The ecosystem I was using was: jax, optax, haiku.
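The "write it for one instance, then vmap it" workflow, roughly (toy model, made-up shapes):

```python
import jax
import jax.numpy as jnp

def predict(params, x):                            # written for a single x of shape (3,)
    w, b = params
    return jnp.tanh(w @ x + b)

batched_predict = jax.vmap(predict, in_axes=(None, 0))   # params shared, x batched

params = (jnp.ones((2, 3)), jnp.zeros(2))
xs = jnp.ones((32, 3))
batched_predict(params, xs).shape                  # (32, 2), no manual batch dims
```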
The big issue I had was: I was developing on the CPU, then moved to running it on a GPU, and it wasn't as fast as I expected -- I started debugging and saw there was still lots of communication between the CPU and GPU even though it was all jit'd. I think PyTorch is more user-friendly for writing high-performance models if you're not straying too far from the beaten path. But I really love JAX and would like to play around with it more to understand these pits I'm falling into.
And another complaint is I can't run it on my Macbook M1 GPU... but I'm seeing this page now, so maybe that's not true anymore: https://developer.apple.com/metal/jax/
I’ve used Numpy with Numba primarily (on the CPU) and it has been a game changer for my data science workloads.
Naturally, I pay close attention to Jax and give it a look every now and then. So, I’ll focus my observations below on Jax’s Numpy API support.
At a glance, Jax code looks like regular Python, but it’s a very different style of programming. Two big differences I’ve found are:
- All Jax functions must be pure. You can’t pass references.
- ndarrays cannot be created with dynamic shapes. You have to hardcode the shape tuples. One possible workaround is to create a buffer much bigger than you need and return that along with the actual shape.
Then there are many small things that are very well documented[1] by the Jax team.
Overall, if you are training ML models, the trouble might be worth it (Autograd). But for accelerating Numpy alone, it is no Numba replacement - which will happily work in the above mentioned use-cases.
This is not true. Rather, all shapes have to be known at compile time. That means that output shapes must not depend on input values, but may depend on input shapes -- even explicitly.
Furthermore, there are two useful additions:
1. You can use vanilla numpy for compile-time computations. An example would be computing an array of indices for some moving-window filter, depending on the input shape and a stride parameter.
2. You can mark function arguments as "static". Then their values may change shapes of the output, but accordingly the function is compiled for each value of those arguments.
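A small sketch of both points (the stride/window numbers are arbitrary):

```python
from functools import partial
import numpy as np
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=1)     # point 2: `stride` is static, recompiled per value
def strided_sum(x, stride):
    # point 1: vanilla numpy at trace time, using the statically known shape of x
    idx = np.arange(0, x.shape[0], stride)
    return jnp.sum(x[idx])

x = jnp.arange(10.0)
strided_sum(x, 2)   # compiles for stride=2
strided_sum(x, 3)   # triggers a fresh compilation for stride=3
```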
You’re right, I apologize for wording it incorrectly. It might have been a restriction with the # of dimensions. Either way, it wasn’t cut out for my use case.
Apple supports JAX[0] along with PyTorch[1] and Tensorflow[2] on macOS with both Apple Silicon and AMD GPUs (on x86 Macs). Although, the perf isn't great. I write most of my experimental ML code in JAX on an M2 Macbook Air and then move to a proper multi-GPU Linux box for full training runs.
We don't support Windows GPU because we haven't had the engineer bandwidth to support it well.
We recommend WSL2 for GPU on Windows at the moment because that is a compromise: it allows CUDA support, without us having to support another release variant.
We don't release Windows GPU wheels at the moment, but that's because we're a small team and none of us use Windows personally. We welcome contributions!
(I verified that the Windows CUDA GPU support built as recently as two weeks ago, but I don't have the ability to test that it works.)
We recommend WSL2 because that's just using our existing Linux CUDA release.
Yes, we made this more formally supported recently.
We felt that Windows CPU support was important so everyone can run JAX, even if it's not always the most-accelerated version of JAX. And we got some great PRs from the community that helped fix a few open issues.
Very nice! I just installed it, and I hope to eventually contribute down the line, especially in terms of custom operators. They weren't even documented until recently, and there's still quite a bit of work involved in adding them.
I have been working on my DNN model using TensorFlow even though ML is not my main research. But it is a substantial part of my research, so I have had to figure things out on my own, and I have done so over the past 3 years. However, I spend so much time figuring out how any of the TF methods work and debugging them. I have never used JAX, but I am not sure if this sort of grinding is normal when you use JAX as well (I always hear great things about it). I have built so many things using TF that I don't think it is wise for me to learn JAX and migrate my work into a JAX code base.
Is there a reason you are limited to JAX/tensorflow and can't use pytorch?
For a lot of people I know whose main job was not to write code, switching from tensorflow to pytorch was something that saved them ten to hundreds of hours in the long run, even accounting for the initial learning time.
Does it support arrays of variable lengths now? Last time I looked, I think this was not supported. So it means, for every variable dimension, you need to use an upper bound, and then use masking properly, and hope that it would not waste computation too much on the unused part (e.g. when running a loop over it).
I'm working with sequences, e.g. speech recognition, machine translation, language modeling. This is a quite fundamental property for this type of models, that we have variable lengths sequences.
In those cases, for some example code, I have seen that training also used only fixed size dimensions. And at inference time, they had some non-JAX code for the loop over the sequence around the JAX code with fixed-size dimensions.
This seems like a quite fundamental issue to me? I wonder a bit that this is not an issue for others.
For JIT-ing you need to know the sizes upfront. There was an experimental branch for introducing jagged tensors, but as far as I know, it has been abandoned.
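The upper-bound-plus-masking pattern mentioned above looks roughly like this (a toy loss, just to show the shape bookkeeping):

```python
import jax.numpy as jnp

MAX_LEN = 8                                        # fixed upper bound, known at compile time

def masked_mean_loss(per_token_loss, lengths):
    # per_token_loss: (batch, MAX_LEN), lengths: (batch,) actual sequence lengths
    mask = jnp.arange(MAX_LEN)[None, :] < lengths[:, None]
    return jnp.sum(per_token_loss * mask) / jnp.sum(mask)
```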
For large scale ML Jax is pretty nice. It makes the multi-host computation feel like a first class citizen and your programs tend to be designed with that in mind.
Nothing happened. It was an informal technical interview with the program manager at JAX. 1 hour call and the interview was remote but describing him as opinionated and entitled is an understatement. Best of luck to them.
I went through the same hiring process and had a positive experience at every stage. I had a strong competing offer but went with the JAX team at NVIDIA.
Jax has a much nicer handling of higher order differentiation. PyTorch has functions to compute Hessians and there are libraries to keep differentiability through optimizers, but going out of their standard use-cases becomes tricky very fast. In contrast, JAX can compute nth-derivatives of things very easily.
The main benefit in my experience is that it’s much easier to do distributed computations in JAX. It has a much nicer API. For single device computing there’s no advantage either way.
I'll remember this always: when DeepMind solved a subset of the protein structure prediction problem, they used Jax as the framework.
PSPP was a long-standing issue and to see a fairly new computational tool used to significantly aid in the process of "solving" it speaks greatly towards its general utility in the sciences.
Imho the project should put more emphasis on the potential for simple and ubiquitous multi-core acceleration of vector compute that is available, by definition, to anybody with a modern CPU.
Nvidia dropped CUDA support for perfectly good GPUs, showing the perils and waste of being locked into a profit-maximizing monopoly.
Yea exactly - Python for training, Java/.NET for inference in production. I looked at approaches like gRPC and the like, but my case is a bit more time-sensitive and the latency added by going over a network layer was too much.
For now I'm happy with Pytorch->ONNX and then running the ONNX model directly. But as I said, that means I can't easily train using JAX :-(
> With its updated version of Autograd, JAX can automatically differentiate native Python and NumPy code.
Nice. When did they make this change?
Here is the old way in the docs, where you needed to define functions for the if-true branch and the if-false branch, and feed them to a conditional function, to get the normal if-then-else conditional.
The constraints on control flow expressions come from jax.jit (because Python control flow can't be staged out) and jax.vmap (because we can't take multiple branches of Python control flow, which we might need to do for different batch elements). But autodiff of Python-native control flow works fine!
Actually, that's never been a constraint for JAX autodiff. JAX grew out of the original Autograd (https://github.com/hips/autograd), so differentiating through Python control flow always worked. It's jax.jit and jax.vmap which place constraints on control flow, requiring structured control flow combinators like those.
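Both halves of that, sketched (toy function, values arbitrary):

```python
import jax
from jax import lax

def f(x):
    if x > 0:                 # native Python branch: fine for grad, not for jit/vmap
        return x ** 2
    return -x

print(jax.grad(f)(3.0))       # 6.0 -- autodiff through the Python `if` just works

@jax.jit
def g(x):                     # under jit the branch must be staged out, e.g. with lax.cond
    return lax.cond(x > 0, lambda x: x ** 2, lambda x: -x, x)

print(jax.grad(g)(3.0))       # also 6.0
```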
Just going over the docs and there’s something I don’t understand, hoping someone here can explain:
What’s the point of having to explicitly call grad(jit(f))? Why doesn’t grad just call jit internally? Is there a use case where you want the grad without jit?
Indeed, it is much better to use jit(grad(f)) in general.
Supporting the opposite composition is still useful in some edge cases -- for example when debugging, you want to step through a computation without jit, and simply not crash when differentiating any inner functions also decorated with jit.
You can do that too! That will disable every JIT though. In practice you might only want to disable just one.
To add some colour to my answer: when writing a library, it's typical to put a JIT statement on everything in the public API. This means you get the benefits of JIT compilation even when you're just hacking around in the REPL, and it mitigates the new-user footgun in which they forget to use JIT themselves.
Meanwhile, good practice is always to JIT your whole computation.
Combined, this means that it's fairly common to go jit (at the top level) -> grad (of your operation) -> jit (of some library call).
When debugging your code, the JIT'd library call is _probably_ not the culprit. So you only want to disable the top-level JIT when stepping through, and still take advantage of JIT compilation where you can. Overall one obtains a composition of the form grad(jit(...)).
TL;DR: even if the use case doesn't come up super frequently, it's more user-friendly to support grad(jit(...)) than it is to just crash.
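For concreteness, the two compositions being discussed (a made-up "library" function):

```python
import jax
import jax.numpy as jnp

@jax.jit                                   # the library-level jit described above
def library_fn(x):
    return jnp.sum(jnp.sin(x) ** 2)

fast_grad = jax.jit(jax.grad(library_fn))  # usual best practice: jit the whole thing
debug_grad = jax.grad(library_fn)          # grad(jit(...)): still works, handy when
                                           # stepping through un-jitted outer code
x = jnp.arange(4.0)
fast_grad(x), debug_grad(x)
```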
I honestly had no idea JAX was useful outside autograd, due to my own tunnel vision. I even used it with other libraries to do this kind of work. Is there a term for this type of mistake?
I’m a researcher, not using anything in production, but I find jax more usable as a general GPU-accelerated tensor math library. PyTorch is more specifically targeted at the neural network use case. It can be shoehorned into other use cases, but is clearly designed & documented for NN training & inference.
Not a fair comparison IMO. Jax is a low-level library used to make ML frameworks, while PyTorch is a full-blown ML framework.
In terms of whether it's worth using - that depends on what you're doing. If you just want to start with ML training, probably not. If you have something already and you want to take it to the next level (e.g. influence how training and inference work), then it's a good choice. You might be interested in looking into flax or haiku instead of using vanilla Jax. These are closer to pytorch.
If you like PyTorch then you might like Equinox, by the way. (https://github.com/patrick-kidger/equinox ; 1.4k GitHub stars now!)
Basically designed to offer PyTorch-like syntax for working with JAX. The latter is excellent for the reasons the sibling replies have stated, but PyTorch absolutely got the usability story correct.
We recently switched to Jax to boost performance as we scale up our algorithmic core. The nice thing is that it presents only a minor jump in capabilities to get developers working with it, if they have prior exposure to numpy of course. Quite nice :)
Although not terribly complicated, you can compute fractals like the Mandelbrot set or Julia sets. They look really cool and you can play with visualization.
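For example, a small, fully vectorised Mandelbrot escape-time sketch (resolution and iteration count are arbitrary):

```python
import jax
import jax.numpy as jnp

@jax.jit
def mandelbrot(re, im, iters=50):
    c = re[None, :] + 1j * im[:, None]                 # grid of complex points
    def step(carry, _):
        z, count = carry
        z = jnp.where(jnp.abs(z) < 2.0, z * z + c, z)  # freeze points that escaped
        count = count + (jnp.abs(z) < 2.0)
        return (z, count), None
    init = (c * 0, jnp.zeros(c.shape, jnp.int32))
    (_, counts), _ = jax.lax.scan(step, init, None, length=iters)
    return counts                                       # escape-time image

img = mandelbrot(jnp.linspace(-2, 1, 300), jnp.linspace(-1.5, 1.5, 300))
```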