Asking because I have worked extensively on training a large model on a TPU cluster, and started with Levanter, then tried MaxText, and finally ended up on EasyLM. My thoughts are:
- Levanter is well-intentioned but unproven and lacking in features. For instance, its sharding is odd in that it requires the embedding dimension to be a multiple of the number of devices, so I can't test a model with embedding dimension 768 on a 512-device pod (the first sketch below illustrates the divisibility constraint). I lost confidence in Levanter after finding some glaring correctness bugs (and helping get them fixed). Also, while I'm a huge fan of Equinox's approach, it's sadly underdeveloped: for instance, there's no way to specify non-default weight initialization strategies without manually doing model surgery to set weights (second sketch below).
- MaxText was just very difficult to work with. We felt like we were fighting against it every time we needed to change something, because we would be digging through numerous needless layers of abstraction. My favorite was after one long day of debugging, when I found a function whose only purpose was to pass its arguments, untouched, to another function; that function's only purpose was to pass its arguments, untouched, to a third function, which then slightly changed them and passed them to a fourth function that did the work.
- EasyLM is, as the name says, easy. But on a deeper dive, the sharding functionality seems underdeveloped. What they call "FSDP" is not necessarily true FSDP; it's literally just one axis of the JAX mesh that happens to shard some data axes and some model weight axes (see the mesh sketch below).
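To illustrate the Levanter divisibility constraint: sharding an array axis across a mesh dimension requires the axis size to divide evenly by the number of shards. A minimal sketch (emulating a small "pod" on CPU; the device count and sizes are just for illustration):

    import os
    # Emulate an 8-device pod on CPU, purely for illustration.
    os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

    import jax
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    mesh = Mesh(np.array(jax.devices()), ("model",))
    x = np.zeros((6, 4))  # size-6 axis, like embed_dim=768 on 512 devices
    # 6 is not a multiple of 8, so this uneven sharding is rejected:
    jax.device_put(x, NamedSharding(mesh, P("model", None)))  # raises an error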
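And the Equinox point: custom initialization currently means swapping weight leaves by hand, e.g. with eqx.tree_at. A sketch (the truncated-normal init and 0.02 scale are arbitrary choices, not anything Equinox prescribes):

    import equinox as eqx
    import jax

    key = jax.random.PRNGKey(0)
    linear = eqx.nn.Linear(768, 768, key=key)
    # Replace the default init with a truncated normal via model surgery:
    new_w = 0.02 * jax.random.truncated_normal(key, -2.0, 2.0, linear.weight.shape)
    linear = eqx.tree_at(lambda m: m.weight, linear, new_w)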
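For the EasyLM point, the gist is that the "fsdp" mesh axis is just a name reused in both the data and weight partition specs; nothing enforces an FSDP-style gather/scatter schedule. A rough sketch of the pattern (the specs are illustrative, not copied from EasyLM):

    import jax
    import numpy as np
    from jax.sharding import Mesh, PartitionSpec as P

    devices = np.array(jax.devices()).reshape(1, -1, 1)
    mesh = Mesh(devices, ("dp", "fsdp", "mp"))

    batch_spec  = P(("dp", "fsdp"))  # batch axis sharded along dp and fsdp
    weight_spec = P("fsdp", "mp")    # the same fsdp axis also shards weights
    # Whether this behaves like true FSDP (per-layer all-gather of params,
    # reduce-scatter of grads) is up to the GSPMD compiler, not the spec.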
I'm still searching for a "perfect" JAX LLM codebase - any pointers?
>MaxText was just very difficult to work with. We felt like we were fighting against it every time we needed to change something, because we would be digging through numerous needless layers of abstraction. My favorite was after one long day of debugging, when I found a function whose only purpose was to pass its arguments, untouched, to another function; that function's only purpose was to pass its arguments, untouched, to a third function, which then slightly changed them and passed them to a fourth function that did the work
Some of this complexity may be necessary for achieving optimal performance in JAX - e.g. extra indirection to avoid the compiler making a bad fusion decision, or multiple calls so something can be marked as static for the jit in the outer call (see the sketch below). As far as I'm aware, MaxText is the only public JAX codebase that has demonstrated scaling to models with hundreds of billions of weights. I've just started evaluating it, and it seems to scale better than the Torch implementation I was using previously (even on GPU). Most of the abstraction seems to have a reason behind it (at least for me, since I'm making some modifications to the vanilla model, which is easier when the components are less tightly coupled).
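For example, a wrapper whose only job is to mark an argument static (a hypothetical sketch, but it shows why an extra call layer can exist):

    import jax
    from functools import partial

    @partial(jax.jit, static_argnames=("num_layers",))
    def forward(params, x, num_layers):
        # `num_layers` is compile-time static, so this loop unrolls at trace
        # time; an outer wrapper often exists only to pass it through as such.
        for i in range(num_layers):
            x = x @ params[i]
        return x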
> Some of this complexity may be necessary for achieving optimal performance in JAX - e.g. extra indirection to avoid the compiler making a bad fusion decision, or multiple calls so something can be marked as static for the jit in the outer call
Certainly some of it is, but not the lion's share - I have a much simpler (private) codebase which scales pretty similarly, AFAICT.
The complexity of MaxText feels more Serious Engineering™-flavored, following Best Practices.
T5 is an architecture; T5X is a framework for training models that was created with that architecture in mind, but it can be used to train other architectures, including decoder-only ones (there is one in the examples).
This might be a tangent, but why does JAX only support saving/serializing AOT-compiled executables for TPU [1]? It would be great to be able to save compiled functions and not have to JIT-compile everything every time you restart a session.
(Julia has had this problem too, but they've made great progress on caching JIT compiled functions to reduce latency.)
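For reference, the AOT path itself is platform-agnostic; it's persisting the executable that's restricted. A sketch, assuming the experimental serialization API I've seen (jax.experimental.serialize_executable) - treat the exact signatures as unverified:

    import jax
    import jax.numpy as jnp

    def f(x):
        return jnp.sin(x) * 2.0

    # Ahead-of-time lower + compile (the jax.stages API):
    compiled = jax.jit(f).lower(jnp.zeros((1024,))).compile()
    y = compiled(jnp.ones((1024,)))

    # Persisting the executable - TPU-only today, per the limitation above:
    from jax.experimental.serialize_executable import serialize, deserialize_and_load
    payload, in_tree, out_tree = serialize(compiled)
    restored = deserialize_and_load(payload, in_tree, out_tree)

There's also the persistent compilation cache (jax.config.update("jax_compilation_cache_dir", ...)), which sidesteps the restart problem without serializing anything yourself.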
- EasyLM [1]
- Levanter [2]
- T5X [3]
- and more?
[1]: https://github.com/young-geng/EasyLM
[2]: https://github.com/stanford-crfm/levanter
[3]: https://github.com/google-research/t5x