I wish there was more information as to how this differs from or improves on Jax.
Flax+Jax+OpenXLA seems to finally be building some momentum, so when a big player launches yet another competitor, it would be good to see the justification for it.
What was “not good enough” with Jax? Why did it make sense to put this human time and energy there instead of doubling down on Flax/Jax/OpenXLA? How will this move the needle at all against Nvidia?
I guess the fact that Google is pushing OpenXLA is making the other giants not want to truly lean in?
I don’t want 10 competing “choices”. I want one clear, open, competitor to Cuda that works on all the competing hardware.
The only way to beat CUDA, as in many other API cases, is with middleware.
Many keep forgetting that CUDA has been a polyglot platform for years: C, C++, Fortran, plus anything that targets PTX, some of which also target OpenCL, meaning Haskell, Java, C#, Julia, Futhark, or Python bindings.
Then there are the libraries and the GPGPU graphical debugging tools.
By the way, Modular just announced partnerships with AWS and NVidia for Mojo and related tooling.
The fact that I can run this on mobile (iOS) devices using the C++ interface makes all the difference for me. I find that extremely refreshing among all those other Python/server/PC-only frameworks.
Running non-trivial ML workloads on the edge has been on my wishlist for years and it sounds like Apple has just the thing.
> I wish there was more information as to how this differs from or improves on Jax.
The quick start guide has an overview.
> The Python API closely follows NumPy with a few exceptions. MLX also has a fully featured C++ API which closely follows the Python API.
The main differences between MLX and NumPy are:
Composable function transformations: MLX has composable function transformations for automatic differentiation, automatic vectorization, and computation graph optimization.
Lazy computation: Computations in MLX are lazy. Arrays are only materialized when needed.
Multi-device: Operations can run on any of the supported devices (CPU, GPU, …)
The design of MLX is inspired by frameworks like PyTorch, Jax, and ArrayFire. A notable difference between these frameworks and MLX is the unified memory model. Arrays in MLX live in shared memory. Operations on MLX arrays can be performed on any of the supported device types without performing data copies. Currently supported device types are the CPU and GPU.
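For a concrete feel for those three points, here is a minimal sketch based on the quick start (assuming the `mlx.core` Python API; the names and shapes are just illustrative):

```python
import mlx.core as mx

def loss(w, x):
    return mx.sum((x @ w) ** 2)

x = mx.random.normal((8, 4))
w = mx.random.normal((4, 1))

grad_fn = mx.grad(loss)   # composable transformation: autodiff w.r.t. the first argument
g = grad_fn(w, x)         # lazy: this builds a graph, nothing is computed yet
mx.eval(g)                # arrays are only materialized when evaluated

# Because of unified memory, the same arrays can be used on CPU or GPU
# without explicit copies between devices.
```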
I was excited about JAX, but I think the developers missed a trick when they decided it should be entirely immutable. It sounds silly, but I think if I have an array `x` and want to set index 0 to 10, it's a big mistake if I can't do:
x[0] = 10
And instead I have to do:
y = x.at[0].set(10)
Of course this has advantages, and I know it sounds lame, but as someone whose brain works in numpy, this is really off-putting.
But you can automatically convert mutable code into functional code if that makes things easier. That's what Haskell's `do` notation does, and PyTorch even has `torch.func.functionalize` for that. Immutable should be the default, but not compulsory.
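As a rough illustration of that conversion (a sketch only, not MLX-specific: `torch.func.functionalize` rewrites in-place updates on intermediate tensors into out-of-place equivalents):

```python
import torch
from torch.func import functionalize
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    y = x.clone()
    y[0] = 10.0      # in-place update on an intermediate tensor
    return y * 2

x = torch.arange(4, dtype=torch.float32)
print(f(x))                                # eager, mutable version
print(functionalize(f)(x))                 # same result, with the mutation rewritten away
print(make_fx(functionalize(f))(x).code)   # traced graph contains only out-of-place ops
```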
Haskell's `do` still doesn't allow mutability; it just allows syntax that looks a bit more imperative than usual. The problem with all the "convert to pure function" magic is that, for example, this piece of code
arr[1000000] = 1
has to clone the entire array if you want it to be pure, leading to very unpredictable performance. There are also some algorithms that are straight up impossible to (efficiently) implement without mutability. Often, it's exactly those algorithms that are hard to optimize for optimizers like JAX.
Specifically in JAX, code that is slow due to copying will often be optimized into mutable code before running, for performance reasons. But because JAX still has the guarantees of no mutability, it can do many optimizations such as caching or dead-code elimination.
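Roughly, the pattern looks like this in JAX (a small sketch; the exact lowering is up to XLA):

```python
import jax
import jax.numpy as jnp

@jax.jit
def bump(arr):
    # Semantically this is "a whole new array with one element changed",
    # but under jit the compiler is free to lower it to an in-place update.
    return arr.at[1000].set(1.0)

arr = jnp.zeros(1_000_000)
out = bump(arr)   # arr itself is never mutated, so caching and DCE stay sound
```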
Of course `do` itself is much more capable, but for some monads it has the effect of that conversion, which is what I wanted to say.
You are correct about in-depth mutations and the resulting complications, but that only strengthens my assertion: immutable should be the default, but not compulsory (because sometimes you absolutely need mutation). And mutability doesn't preclude caching or dead-code elimination; you just have to be more careful. Often you can convert mutable code into an immutable form just for the purpose of analysis, which is definitely harder than having immutable code in the first place, but not impossible. Scalar compilers have used SSA (an immutable description of mutable programs) for a long time, after all.
I've found the process of porting custom ML models to iOS extremely difficult.
AFAIK the only way to leverage the Apple Neural Engine (and get the best performance) is to use CoreML. The only documented way to use CoreML is via coremltools, which takes a trace of a PyTorch model and attempts to translate it into a protobuf graph understood by CoreML.
This process often fails and requires model changes, or worse, "succeeds" but gives you the wrong output when you run the model. Additionally, you have to play detective to figure out why some operations run on the CPU or GPU instead of the ANE.
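For reference, the happy path with coremltools looks roughly like this (a minimal sketch with a hypothetical toy model; real custom models are where the failures show up):

```python
import torch
import coremltools as ct

# Hypothetical toy model standing in for a custom PyTorch model.
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU()).eval()
example = torch.rand(1, 16)
traced = torch.jit.trace(model, example)

mlmodel = ct.convert(
    traced,
    inputs=[ct.TensorType(name="x", shape=(1, 16))],
    convert_to="mlprogram",
    compute_units=ct.ComputeUnit.ALL,  # let CoreML schedule across CPU/GPU/ANE
)
mlmodel.save("Model.mlpackage")
```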
It's exciting to see more tools like this for working with tensor-like objects, but I really wish Apple would make porting custom models in a high performance manner easier.
I don't understand why Apple isn't trying to integrate better with the standard tools in this field. I guess it makes sense to lock in app devs, but ML engineers?
That said I've had good success with onnxruntime recently [0].
The project probably at least partially serves as documentation for other frameworks that want to integrate Apple Silicon acceleration. It basically demonstrates how to use the macOS Accelerate framework and Metal Performance Shaders (MPS) from C++ for machine learning and training optimization.
Thus other frameworks can simply take this backend code and integrate it (PyTorch basically did that already, with Apple's help).
Nailed it.
I think more than partially. What happens in this repo will spread to the other major frameworks, and over time, clever ideas that spawn on other projects will be reimplemented, with Apple's adjustments, back into the repo. It's a brilliant and efficient way to interact with the community, one that can likely be measured in more sales of their hardware over time.
I've found that converting a custom PyTorch model to CoreML can be quite complex. I had to modify certain data types to facilitate the conversion process, which was not straightforward.
The process becomes particularly frustrating when the model appears to convert successfully, but then fails to produce any output or loses layers entirely. Additionally, the debug information provided by the conversion tool isn't very helpful, adding to the challenge.
As an iOS developer with no prior experience in Python, I found myself in a unique position. I needed to build a custom model for one of my keyboard apps to handle tasks like spellchecking, grammar correction, next-word prediction, and autocompletion. This necessity pushed me to learn Python and PyTorch. After mastering these, I then had to convert my knowledge back to Swift and CoreML.
Ideally, I would have preferred to build my model directly in Swift & CoreML, but the current tools and resources for this approach are limited. This limitation is particularly evident in terms of the ease of use and flexibility that Python and PyTorch offer.
That's true; one has to design the model for the deployment target. In particular, avoiding in-place tensor operations and Python control flow helps with tracing.
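A hypothetical before/after of the kind of rewrite that helps (just an illustration, not taken from any particular model):

```python
import torch

class TraceUnfriendly(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:      # Python control flow: the branch gets frozen at trace time
            x[:, 0] = 0.0    # in-place update: a common source of conversion failures
        return x

class TraceFriendly(torch.nn.Module):
    def forward(self, x):
        mask = (x.sum() > 0).to(x.dtype)        # branch expressed as tensor math
        head = x[:, :1] * (1.0 - mask)          # out-of-place equivalent of zeroing column 0
        return torch.cat([head, x[:, 1:]], dim=1)
```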
I'm not certain how many other researchers would be willing to swap away from PyTorch (or TF or Jax) to yet another framework. Amazon pushed MXNet for a while, but the only time I used it was while interning there. From what I understand, the advantage here is tight integration with Apple hardware and shared memory?
That being said, I'm going to spend some time getting familiar with it by reimplementing Llama-2 and trying to make it fast: https://github.com/jbarrow/mlxllama
I see that the focus is Apple Silicon, but wouldn't it make sense to also make this available to other hardware (e.g. AMD64 + CUDA) and other platforms (e.g. Linux)?
Otherwise I wonder whether this will really find much adoption. You don't want to lock yourself in when you have all those other choices. (OK, to be fair, as discussed here, PyTorch etc. might not work optimally yet on Apple Silicon, but I guess that is just a matter of time.)
> You don't want to lock yourself in when you have all those other choices.
I already have this nice, powerful, and expensive hardware sitting on my desk. Why not make the most optimal use of it? I can worry about lock-in once it becomes inadequate for the task.
> PyTorch etc. might not work optimally yet on Apple Silicon
So now we can take MLX apart and see how we can use it to improve PyTorch.
Given the API is similar to existing libraries, I’m curious as to whether the performance is better with this one. And if so, what’s stopping existing libraries from being as fast. IIRC, PyTorch at least has a Metal backend.
The README mentions unified memory, but what stops other frameworks from modeling copies as no-ops? I wonder if MLX makes larger architectural decisions based on GPU/CPU communication being cheap.
It seems like it's matching PyTorch's API very closely, which is great. Part of me wishes they took it a step further and just made it completely API-compatible, such that code written for PyTorch could run out-of-the-box with MLX, that would be killer.
This is (partly) outdated. MPS (Metal Performance Shaders) support is now (since torch 2.x) fully integrated into standard PyTorch releases; no external backends or special torch versions are needed.
There are few limitations left compared with other backends. Instead of the 'cuda' device, one simply uses 'mps' as the device.
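That is, something along these lines:

```python
import torch

# Use the Metal backend when available, otherwise fall back to CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

x = torch.randn(1024, 1024, device=device)
y = x @ x          # matmul runs on the Apple GPU via Metal Performance Shaders
print(y.device)    # mps:0 on Apple Silicon with a recent torch build
```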
What remains is that the optimizations PyTorch provides (especially compile() as of 2.1) focus on CUDA and its historic restrictions, which stem from CUDA _not_ having unified memory. Lots of energy goes into developing architectural workarounds to limit the copying between graphics hardware and CPU memory, resulting in specialized compilers (like Triton) that move parts of the Python code onto the hardware.
Apple's unified memory would make all of those super complicated architectural workarounds mostly unnecessary (which they demonstrate with their project).
Getting current DL platforms to support both paradigms (unified / non-unified) will be a lot of work. One possible avenue is the MLIR project currently leveraged by Mojo.
> This is (partly) outdated. MPS (Metal Performance Shaders) support is now (since torch 2.x) fully integrated into standard PyTorch releases; no external backends or special torch versions are needed.
Not sure what you're referring to, the link I provided shows how to use the "mps" backend / device from the official PyTorch release.
> Lots of energy goes into developing architectural workarounds to limit the copying between graphics hardware and CPU memory
Does this remark apply to PyTorch running on NVidia's platforms with unified memory like the Jetsons?
It looks like this is still missing many matrix operations like QR, SVD, einsum, etc. Is there a clear route to using these on the GPU in Python on Apple Silicon? Last I checked the PyTorch backend was still missing at least QR...
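For what it's worth, PyTorch does have a CPU-fallback escape hatch for ops without MPS kernels, though that's a workaround rather than a route to the GPU (a sketch, assuming a recent torch build):

```python
# Must be set before importing torch; ops without an MPS kernel then fall back to CPU.
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch

x = torch.randn(64, 64, device="mps")
q, r = torch.linalg.qr(x)   # may silently run on the CPU if the MPS kernel is missing
```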
Factorization methods are somewhat uncommonly used in deep learning (the likely target of this framework) and have compute properties (such as approximate outputs and a non-deterministic number of iterations) that make them unlike the standard BLAS++ APIs.
einsum seems like a reasonable thing to request, but it's hard to be performant across the entire surface exposed by the operation.
Exactly right that this targets a narrower surface to enable many deep learning models. I wonder how uncommon it is to hit some operation that is not included, though? It seems pretty common from a PyTorch MPS tracking issue:
NVIDIA's moat is not just in providing BLAS++ operations, but in extending this to a wider range of libraries: cuSPARSE, cuSOLVER, cuTENSOR, etc. Without these, it feels like Apple is just playing catch-up with whatever is popular and unsupported...
Going in I was really worried there would be a bunch of wonkiness with unique apis different from torch or tf. But it seems pretty close to torch, which is a pleasant surprise. No need to reinvent the wheel.
This is really cool. I wonder how long it will be till we have GPT-4 quality models that run locally (if we ever will). Would open up a lot of possibilities.
We aren't quite there yet, but the last year has been an incredibly exciting time for the open source LLM community. If your computer is decently powerful, you might be really surprised by what's already possible. LM Studio on Apple Silicon supports GGUF quants on the GPU out of the box.