In my experience, numba's killer feature is fast looping over numpy arrays. It does this extremely well. I work in an applied physics research lab, where we eschew matlab wherever possible. So we rely a lot on python, numpy and scipy. Even if we wanted to move to something like julia, it wouldn’t be practical. Every time we buy a new piece of equipment that has a python api and outputs a lot of data (we do a lot of work with picosecond resolution event timing), it’s going to use numpy.
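To make that concrete, here is a minimal sketch of the kind of loop numba handles well; the coincidence-counting function and the random timestamps are made up for illustration, not code from our lab:

```python
import numpy as np
from numba import njit

@njit
def count_coincidences(t1, t2, window):
    # Naive O(n*m) double loop over two arrays of event timestamps.
    # In pure Python this would be painfully slow; under @njit it runs
    # at roughly C speed.
    hits = 0
    for a in t1:
        for b in t2:
            if abs(a - b) < window:
                hits += 1
    return hits

t1 = np.sort(np.random.rand(5000))
t2 = np.sort(np.random.rand(5000))
print(count_coincidences(t1, t2, 1e-4))
```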
Though if any numba developers come across this, I’d advise them to plan their upgrade process a little more carefully. The current numba conversion for python lists is due to be deprecated in favor of typed lists. But the typed list (numba.typed.List) implementation is still too buggy: I experienced huge delays when popping and appending elements. Please flesh out the new implementation before adding a bunch of deprecation warnings to the current builds.
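For reference, this is roughly what the typed-list API looks like; numba.typed.List is the real class, but the toy function below is only an illustration:

```python
from numba import njit
from numba.typed import List

@njit
def keep_above(values, threshold):
    # Build a typed list inside nopython mode; the element type is
    # inferred from the first append.
    kept = List()
    for v in values:
        if v > threshold:
            kept.append(v)
    return kept

src = List()
for v in (0.1, 0.7, 0.3, 0.9):
    src.append(v)

print(list(keep_above(src, 0.5)))
```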
JAX recompiles a function every time you call it with an array of a new shape, i.e. if it is called with an array of shape (7, 10) and then one of shape (17, 12), the function is compiled twice. This is generally fine for deep learning and some other numerical applications where you run the same computations again and again over arrays with the same shape. But for data exploration of the kind you’d do in Pandas, in my experience your data shapes are different with each call, so the repeated recompilations make it unattractive.
Numba, by contrast, only recompiles when the number of dimensions (or the dtype) changes, e.g. if the shape goes from (7, 10) to (2, 10, 12). However, numba has no automatic differentiation, so if you need that, JAX is probably your best bet, unless Enzyme matures and you want to look into integrating it with numba.
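A quick way to see the difference for yourself (the functions are made up; the print inside the jitted function only fires while JAX traces it, so it marks each compilation):

```python
import numpy as np
import jax
import jax.numpy as jnp
from numba import njit

@jax.jit
def normalise(x):
    print("tracing for shape", x.shape)  # side effect -> runs only at trace time
    return (x - x.mean()) / x.std()

normalise(jnp.ones((7, 10)))    # prints: tracing for shape (7, 10)
normalise(jnp.ones((7, 10)))    # cached, no print
normalise(jnp.ones((17, 12)))   # new shape -> traces and compiles again

@njit
def total(x):
    s = 0.0
    for v in x.ravel():
        s += v
    return s

total(np.ones((7, 10)))      # compiles once for 2-D float64
total(np.ones((17, 12)))     # same ndim/dtype, no recompilation
total(np.ones((2, 10, 12)))  # 3-D -> new signature, recompiles
```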
JAX is a numerics library combined with autograd, aimed at machine learning. You string together operations in Python in a functional style, and the result is handed to XLA, Google's optimizing compiler, which targets CPUs, GPUs and TPUs and generates optimized machine code for those architectures.
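A minimal sketch of that functional style; the toy model is made up, and nothing beyond jax.jit and jax.grad is assumed:

```python
import jax
import jax.numpy as jnp

def predict(params, x):
    w, b = params
    return jnp.tanh(x @ w + b)

def loss(params, x, y):
    return jnp.mean((predict(params, x) - y) ** 2)

# Compose pure functions, then let XLA compile the whole gradient computation.
grad_loss = jax.jit(jax.grad(loss))

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (32, 4))
y = jnp.zeros((32, 1))
params = (jnp.zeros((4, 1)), jnp.zeros(1))
print(grad_loss(params, x, y)[0].shape)  # (4, 1)
```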
numba lets you write functions in a subset of python which is then compiled with llvm, producing something very similar to what you would get if you applied a bunch of regexes to that subset of python to convert it to C (with special support for numpy arrays, since they have special importance in these domains).
numba is specifically targeted at things like core numerical algorithms that are typically coded in C and fortran and consist almost entirely of for loops and basic arithmetic. JAX is targeted more at high-level machine learning applications where the end user is stringing together higher-level numerical algorithms.
i suspect that JAX would be a bad fit for custom computer vision or numerical algorithms used outside of neural network work.
JAX actually sits at a lower level than deep learning (despite including some specialized constructs), which makes it an almost drop-in replacement for numpy with the ability to jit your python code.
I am currently doing some tests introducing JAX in a large numerical code base (one that previously used C++ extensions); we are not using autograd nor any deep-learning-specific functionality.
Having seen actual numbers, I can tell you that JAX on CPU is competitive with C++ but produces more readable code, with the added benefit of also running on GPU. However, it does introduce some constraints (array sizes cannot be too dynamic), so if you are not also planning on targeting GPU, I would probably focus on numba.
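To illustrate the drop-in feel (the function and variable names here are invented): the same array code can run against either numpy or jax.numpy, and then be jitted through XLA:

```python
import numpy as np
import jax
import jax.numpy as jnp

def gram_matrix(xp, x):
    # xp is either the numpy module or the jax.numpy module
    x = x - xp.mean(x, axis=0)
    return x.T @ x / x.shape[0]

x_np = np.random.rand(100, 3)
print(gram_matrix(np, x_np))                    # plain numpy
print(gram_matrix(jnp, jnp.asarray(x_np)))      # same code on JAX arrays
fast = jax.jit(lambda x: gram_matrix(jnp, x))   # and compiled via XLA
print(fast(jnp.asarray(x_np)))
```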
i actually poked around a bit as a contributor a few years ago (before i had to start a real job) and remember it being a thin layer on top of XLA among a few other things. interesting to learn that it is growing into something that people are using as a fully fledged numerical computing library.
also a little bit surprising to see how immature and fragmented the python gpu numerical computing ecosystem is. everybody bags on matlab, but it has been automatically shipping relevant operations over to available gpus for years.
There are also tricks to get around dynamic array shapes, like padding your shapes up to a few common sizes: everything between 6 and 10 becomes 10, everything from 11 to 20 pads to 20, etc.
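A sketch of that bucketing trick, with arbitrary bucket sizes: pad each input up to the next bucket so jit only compiles once per bucket, and mask out the padding inside the function:

```python
import numpy as np
import jax
import jax.numpy as jnp

BUCKETS = (10, 20, 50, 100)

def pad_to_bucket(x):
    # Pad a 1-D array up to the next bucket size; return the true length too.
    n = x.shape[0]
    size = next(b for b in BUCKETS if b >= n)
    padded = np.zeros(size, dtype=x.dtype)
    padded[:n] = x
    return padded, n

@jax.jit
def masked_mean(x, n):
    # n is a traced scalar, so different lengths in the same bucket
    # reuse the same compiled function.
    mask = jnp.arange(x.shape[0]) < n
    return jnp.sum(jnp.where(mask, x, 0.0)) / n

for length in (7, 9, 13, 18):   # 7 and 9 share the size-10 bucket, 13 and 18 the size-20 one
    x, n = pad_to_bucket(np.random.rand(length))
    print(masked_mean(x, n))
```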
Jax is a great general-purpose numerical computing library.
Jax offers much lower-level control, it's almost bare metal, and it can be used for all sorts of things besides deep learning. I am currently using it to implement a better `scipy.optimize`-style library.
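Not the actual library, but as a sketch of the kind of optimizer you can build on jax.grad in a few lines:

```python
import jax
import jax.numpy as jnp

def minimize_gd(f, x0, lr=1e-3, steps=1000):
    """Plain gradient descent on a scalar function of an array."""
    grad_f = jax.jit(jax.grad(f))   # compiled gradient of the objective
    x = x0
    for _ in range(steps):
        x = x - lr * grad_f(x)
    return x

# Rosenbrock function: minimum at (1, 1)
rosenbrock = lambda p: (1 - p[0]) ** 2 + 100 * (p[1] - p[0] ** 2) ** 2

print(minimize_gd(rosenbrock, jnp.array([-1.0, 1.0]), lr=1e-3, steps=5000))
```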