Is it fair to say JAX is what Google made when they looked at PyTorch/Autograd and thought, "oh damn, that's what we should have done"?
If so, is this the beginning of the end of TensorFlow? I know TensorFlow is still top for production, but it is certainly losing followers rapidly in the research field, and PyTorch is now starting to focus on deployment, since they know this is their weakness.
Yes, but it's not mature enough to kill TensorFlow in the short term.
I still see people preferring TensorFlow over PyTorch because they feel it is more mature for production use.
Meanwhile, JAX has not converged on a recommended deep neural network framework (it has the low-level pieces).
At the moment it's a great building block that researchers should probably know.
It sounds like JAX necessarily stores your whole calculation in memory, so it will use more memory for automatic differentiation of heavily iterative calculations, while other implementations of reverse-mode automatic differentiation can instead restart your calculation from checkpoints to avoid storing the whole thing. For some calculations this could be an advantage of several orders of magnitude: using twice the CPU or GPU time in exchange for one thousandth or one ten-thousandth of the memory.
Rematerialization in autodiff is super interesting! XLA does rematerialization optimizations, so you get those automatically under jax.jit. There's also the jax.checkpoint decorator (https://github.com/google/jax/pull/1749), which lets you control reverse-mode checkpointing yourself; you can use it recursively to implement sophisticated checkpointing strategies (see Example 5 in that PR, which implements the classic strategy for making memory cost scale like log(N) in the iteration count N, at the cost of log(N) times as much computational work). It'd be interesting to experiment with heuristics for deploying those strategies automatically (e.g. given a program in JAX's jaxpr IR), but one of JAX's core philosophies is to keep things explicit and give users control through composable APIs. Automatic heuristics can be built on top.
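To make the memory/compute trade-off concrete, here is a minimal pure-Python sketch of reverse-mode differentiation through N steps of an iterated scalar map, once storing every intermediate state and once storing only every k-th state and recomputing each segment during the backward pass. It is not jax.checkpoint and not the recursive log(N) scheme from the PR, just the simplest single-level version; the map step() (an Euler step of dx/dt = sin(x)) and the function names are hypothetical choices for illustration.

```python
import math

def step(x):
    # Hypothetical iteration: one explicit Euler step (h = 0.001) of dx/dt = sin(x)
    return x + 0.001 * math.sin(x)

def step_deriv(x):
    # d(step)/dx evaluated at x
    return 1.0 + 0.001 * math.cos(x)

def grad_store_all(x0, n):
    """Reverse mode that keeps every intermediate state: O(n) memory."""
    xs = [x0]
    for _ in range(n):
        xs.append(step(xs[-1]))
    g = 1.0  # d(x_n)/d(x_n)
    for x in reversed(xs[:-1]):  # walk back from x_{n-1} to x_0
        g *= step_deriv(x)
    return xs[-1], g, len(xs)

def grad_checkpointed(x0, n, k):
    """Reverse mode that keeps only every k-th state and recomputes the rest.

    Holds about n/k checkpoints plus one k-step segment at a time,
    at the cost of running the forward map roughly twice.
    """
    checkpoints = {0: x0}
    x = x0
    for i in range(1, n + 1):
        x = step(x)
        if i % k == 0 and i < n:
            checkpoints[i] = x
    g, peak = 1.0, 0
    for start in sorted(checkpoints, reverse=True):
        end = min(start + k, n)
        # Recompute x_start .. x_{end-1} from the stored checkpoint
        seg = [checkpoints[start]]
        for _ in range(end - start - 1):
            seg.append(step(seg[-1]))
        peak = max(peak, len(checkpoints) + len(seg))
        for xi in reversed(seg):
            g *= step_deriv(xi)
    return x, g, peak
```

With N = 1000 and k = 32, the store-everything version holds 1001 states while the checkpointed one never holds more than 64, and both produce the same gradient. Choosing k near sqrt(N) gives memory that grows like sqrt(N); applying the idea recursively, as in the PR's Example 5, pushes that down to log(N).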
Another goal is to make JAX a great system for playing with things like this!
No worries! I didn't mean it as a correction so much as just a discussion; I'm sure it's true that other autodiff systems have very sophisticated automatic remat (like https://openreview.net/forum?id=BkYYXJ9i-). I'm hoping as users push JAX on new applications, especially in simulation and scientific computing, we'll learn a lot and be able to improve!
There's also "cross-country optimization" (https://www-sop.inria.fr/tropics/slides/EdfCea05.pdf) for mixing some forward-mode into reverse-mode to improve memory efficiency. Analogously to jax.checkpoint, we've only experimented with exposing that manually (in jax.jarrett, named because of https://arxiv.org/abs/1810.08297), and even then only for a special case. There's a lot to learn about, experiment with, and build!
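As a rough illustration of the special case that the linked paper targets (an elementwise, or broadcast, function inside a larger reverse-mode computation), here is a pure-Python sketch that runs forward mode with a dual number on each element during the forward pass. The reverse pass then needs only one stored derivative per element instead of a tape of the elementwise function's internal operations. The Dual class, g, and loss_and_grad are hypothetical names for this sketch; it shows the idea, not how jax.jarrett is implemented.

```python
import math
from dataclasses import dataclass

@dataclass
class Dual:
    """Forward-mode dual number: a value and its tangent (derivative)."""
    val: float
    tan: float

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other, 0.0)
        return Dual(self.val + other.val, self.tan + other.tan)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other, 0.0)
        return Dual(self.val * other.val, self.val * other.tan + self.tan * other.val)

    __rmul__ = __mul__

def dsin(d):
    return Dual(math.sin(d.val), math.cos(d.val) * d.tan)

def dexp(d):
    return Dual(math.exp(d.val), math.exp(d.val) * d.tan)

def g(x):
    # Hypothetical elementwise function; takes and returns Dual numbers
    return dsin(x) * dexp(x) + 3.0 * x * x

def loss_and_grad(xs, weights):
    """loss = sum_i weights[i] * g(xs[i]), differentiated forward-inside-reverse."""
    ys, dgs = [], []
    for x in xs:
        # One forward-mode pass per element yields g(x_i) and g'(x_i);
        # no tape of g's internal operations is kept.
        d = g(Dual(x, 1.0))
        ys.append(d.val)
        dgs.append(d.tan)
    loss = sum(w * y for w, y in zip(weights, ys))
    # Reverse pass: the cotangent of y_i is weights[i]; the Jacobian of the
    # elementwise map is diagonal, so x_i's cotangent is weights[i] * g'(x_i).
    grads = [w * dg for w, dg in zip(weights, dgs)]
    return loss, grads
```

The trick works here because an elementwise map has a diagonal Jacobian, so a single forward-mode pass per element recovers all of it.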
Yes, I think JAX is indeed the nail in the coffin for TensorFlow. It's not there yet, but the part of the research community that did not go to PyTorch is now going to JAX.
This is a very interesting article, with much better illustrations and deeper investigation than the note on separable image filters I wrote in Dercuano. And I didn't know about JAX, and it's very valuable to know about it now. But the article does have some errors.
The article says:
> Optimization of arbitrary functions is generally a NP-hard problem (there are no solutions other than exploring every possible value, which is impossible in the case of continuous functions)
It is true that optimization of arbitrary functions, or even many interesting classes of functions, is NP-hard. However, the definition given of NP-hard is incorrect, and in fact, on modern hardware, existing SMT solvers such as Z3 can solve substantial instances of many interesting NP-hard optimization problems, precisely because they do not explore every possible value. Moreover, it is in general possible (but again NP-hard) to use interval arithmetic to rigorously optimize functions on continuous domains (which seems to be what is meant), as long as they are not too discontinuous; the answer you get is only an approximation of the true optimum, but you can calculate it to any desired precision.
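Here is a minimal sketch of the interval-arithmetic approach for a one-dimensional objective built only from +, - and *. It is a best-first branch-and-bound search that evaluates f on whole intervals to get guaranteed lower bounds, discards subintervals that cannot contain the minimum, and splits the rest. The Interval class, minimize(), and the example objective are hypothetical. A truly rigorous version would also round interval endpoints outward, which plain Python floats do not do.

```python
import heapq
from dataclasses import dataclass

@dataclass(frozen=True)
class Interval:
    lo: float
    hi: float

    def __add__(self, o):
        o = o if isinstance(o, Interval) else Interval(o, o)
        return Interval(self.lo + o.lo, self.hi + o.hi)

    __radd__ = __add__

    def __sub__(self, o):
        o = o if isinstance(o, Interval) else Interval(o, o)
        return Interval(self.lo - o.hi, self.hi - o.lo)

    def __mul__(self, o):
        o = o if isinstance(o, Interval) else Interval(o, o)
        ps = (self.lo * o.lo, self.lo * o.hi, self.hi * o.lo, self.hi * o.hi)
        return Interval(min(ps), max(ps))

    __rmul__ = __mul__

def f(x):
    # Hypothetical objective with two local minima; works on floats and Intervals
    return x * x * x * x - 3 * x * x + x

def minimize(func, lo, hi, tol=1e-6):
    """Return (lower, upper) bounds bracketing the minimum of func on [lo, hi]."""
    best_upper = min(func(lo), func(hi), func((lo + hi) / 2))  # sampled values
    heap = [(func(Interval(lo, hi)).lo, lo, hi)]  # boxes ordered by lower bound
    while heap:
        bound, a, b = heapq.heappop(heap)
        if bound > best_upper:
            continue  # this box cannot contain the minimum
        if b - a < tol:
            # The popped box has the smallest lower bound of all live boxes
            return bound, best_upper
        m = (a + b) / 2
        best_upper = min(best_upper, func(m))
        for c, d in ((a, m), (m, b)):
            lb = func(Interval(c, d)).lo
            if lb <= best_upper:
                heapq.heappush(heap, (lb, c, d))
    return best_upper, best_upper
```

On [-2, 2] this brackets the global minimum near x ≈ -1.30 (value about -3.51) and never gets stuck at the shallower local minimum near x ≈ 1.13. The search skips every candidate value outside boxes it cannot prune, which is why the method does not need to "explore every possible value".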
One particularly interesting class of optimization problems — because they are not NP-hard — are continuous linear optimization problems, which can be solved in guaranteed polynomial time using interior-point methods or usually in polynomial time using the "simplex method". Contrary to what you'd think from the quote from the article, going from continuous to discrete makes the problem NP-hard again. There is also a note in Dercuano surveying the landscape of existing software and methods for solving linear optimization problems; there's a lot of very powerful stuff out there.
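To see the continuous/discrete contrast in the smallest possible linear program, take a knapsack with a single capacity constraint (a hypothetical example, not from the article). The continuous relaxation lets each item be taken fractionally; that is a linear program, and a greedy pass by value-per-weight solves it exactly in O(n log n). Requiring each item to be taken wholly or not at all turns it into the NP-hard 0/1 knapsack problem, solved here by brute force over all 2^n subsets.

```python
from itertools import product

def fractional_knapsack(values, weights, capacity):
    """Optimal value of the LP: 0 <= x_i <= 1, sum(w_i * x_i) <= capacity."""
    order = sorted(range(len(values)), key=lambda i: values[i] / weights[i], reverse=True)
    total, remaining = 0.0, capacity
    for i in order:
        take = min(1.0, remaining / weights[i])  # fraction of item i that fits
        if take <= 0:
            break
        total += take * values[i]
        remaining -= take * weights[i]
    return total

def knapsack_01_bruteforce(values, weights, capacity):
    """Optimal value when every x_i must be 0 or 1 (exponential in n)."""
    best = 0
    for choice in product((0, 1), repeat=len(values)):
        weight = sum(c * w for c, w in zip(choice, weights))
        if weight <= capacity:
            best = max(best, sum(c * v for c, v in zip(choice, values)))
    return best
```

For values [60, 100, 120], weights [10, 20, 30], and capacity 50, the relaxation gives 240 by taking two thirds of the last item. The best 0/1 choice is only 220.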
It turns out that you can efficiently solve an enormous range of practical optimization problems by introducing a small number of discrete variables into a linear optimization problem, thus gaining most of the performance benefit of using a linear optimizer. I don't know if there's a way to get a reasonable perceptual result in a case like this with a linear optimizer, though.
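One common way to handle a few discrete variables on top of a linear optimizer is branch and bound. Fix some 0/1 variables, solve the continuous relaxation of the rest to get an upper bound, and prune any branch whose bound cannot beat the best integer solution found so far. Below is a minimal sketch for a hypothetical knapsack problem, where the relaxation happens to be solvable greedily. Real MILP solvers compute the bound with a general LP solver and use far more sophisticated branching and cutting.

```python
def lp_bound(items, capacity):
    """Value of the continuous relaxation over items sorted by value density."""
    total = 0.0
    for value, weight in items:
        if weight <= capacity:
            total += value
            capacity -= weight
        else:
            total += value * capacity / weight  # take a fraction of this item
            break
    return total

def knapsack_branch_and_bound(values, weights, capacity):
    """Exact 0/1 knapsack; returns (best value, number of search nodes visited)."""
    items = sorted(zip(values, weights), key=lambda vw: vw[0] / vw[1], reverse=True)
    best = 0
    nodes = 0

    def search(i, value, room):
        nonlocal best, nodes
        nodes += 1
        best = max(best, value)  # the current partial choice is feasible
        if i == len(items):
            return
        # Prune: even the relaxation of the remaining items cannot beat best
        if value + lp_bound(items[i:], room) <= best:
            return
        v, w = items[i]
        if w <= room:
            search(i + 1, value + v, room - w)  # branch x_i = 1
        search(i + 1, value, room)  # branch x_i = 0

    search(0, 0, capacity)
    return best, nodes
```

The relaxation's bound is what keeps the search from visiting all 2^n assignments. That is the sense in which a linear optimizer carries most of the work even when some variables are discrete.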
> This is where various auto-differentiation libraries can help us. Given some function, we can compute its gradient / derivative with regards to some variables completely automatically! This can be achieved either symbolically, or in some cases even numerically if closed-form gradient would be impossible to compute.
Automatic differentiation is a specific approach to differentiation which is an alternative to symbolic differentiation and the older kind of numerical differentiation. What JAX does is automatic differentiation.
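Here is the three-way distinction in code. Symbolic differentiation manipulates expressions. Numerical differentiation approximates a derivative from function values, for example by finite differences, with truncation and rounding error. Automatic differentiation propagates exact local derivatives through the operations the program actually executes. Below is a toy reverse-mode sketch using operator overloading on a hypothetical Var class, compared with a central difference. JAX's own implementation is different; this only shows the idea.

```python
import math

class Var:
    """Scalar node in a computation graph for toy reverse-mode AD."""

    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents  # pairs of (parent Var, local derivative)
        self.grad = 0.0

    def __add__(self, other):
        other = other if isinstance(other, Var) else Var(other)
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Var) else Var(other)
        return Var(self.value * other.value, ((self, other.value), (other, self.value)))

    __rmul__ = __mul__

def sin(x):
    # Works on plain floats (for finite differences) and on Vars (for AD)
    if not isinstance(x, Var):
        return math.sin(x)
    return Var(math.sin(x.value), ((x, math.cos(x.value)),))

def backward(out):
    """Accumulate d(out)/d(node) into node.grad, visiting consumers first."""
    order, seen = [], set()

    def visit(node):
        if id(node) not in seen:
            seen.add(id(node))
            for parent, _ in node.parents:
                visit(parent)
            order.append(node)  # post-order: parents before consumers

    visit(out)
    out.grad = 1.0
    for node in reversed(order):
        for parent, local in node.parents:
            parent.grad += local * node.grad

def f(x):
    return x * x * sin(x) + 3 * x
```

For x = 1.3, calling backward(f(Var(1.3))) leaves in x.grad the exact derivative 2x sin(x) + x^2 cos(x) + 3, up to floating-point rounding. A central difference with step 1e-5 only agrees to about six significant digits.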