This is quite useful, as this should make training this type of LLM much more efficient.
So this is a ternary-weight LLM using quantization-aware training (QAT). The activations are quantized to 8 bits. The matmul is still there, but it multiplies the 8-bit activations by ternary (about 1.58-bit) values.
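To make the "multiply by ternary values" point concrete: when every weight is -1, 0 or +1, each dot product collapses into adding some activations and subtracting others, followed by one floating-point rescale per output. Below is a minimal pure-Python sketch, assuming a BitNet-b1.58-style absmean scale for the weights and an absmax scale for the 8-bit activations. The function names are illustrative and this is not the paper's code.

```python
def quantize_weights_ternary(w, eps=1e-5):
    """Absmean ternarization: divide by mean |w|, round, clip to {-1, 0, 1}."""
    flat = [abs(x) for row in w for x in row]
    gamma = sum(flat) / len(flat) + eps
    q = [[max(-1, min(1, round(x / gamma))) for x in row] for row in w]
    return q, gamma

def quantize_activations_int8(x, eps=1e-5):
    """Absmax quantization of a vector into the int8 range [-127, 127]."""
    scale = 127.0 / (max(abs(v) for v in x) + eps)
    q = [max(-127, min(127, round(v * scale))) for v in x]
    return q, scale

def ternary_matvec(q_w, q_x):
    """Matrix-vector product with ternary weights: only adds and subtracts."""
    out = []
    for row in q_w:
        acc = 0
        for wij, xj in zip(row, q_x):
            if wij == 1:
                acc += xj
            elif wij == -1:
                acc -= xj
            # wij == 0: the activation is skipped entirely
        out.append(acc)
    return out

def bitlinear_forward(w, x):
    """Quantize weights and activations, then rescale the integer result."""
    q_w, gamma = quantize_weights_ternary(w)
    q_x, scale = quantize_activations_int8(x)
    acc = ternary_matvec(q_w, q_x)
    # One floating-point multiply per output undoes both quantization scales.
    return [a * gamma / scale for a in acc]
```

For example, `ternary_matvec([[1, 0, -1], [0, 1, 1]], [10, 20, 30])` gives `[-20, 50]` without a single multiplication in the inner loop.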
Quantization-aware training with low-bit weights seems to reduce overfitting through an intrinsic regularizing effect. However, the model capacity should also be reduced compared to a model with the same number of weights and more bits per weight. It's quite possible that this only becomes apparent after the models have been trained on a significant number of tokens, as LLMs seem to be quite sparse.
Edit: In addition to the QAT they also changed the model architecture to use a linear transformer to reduce reliance on multiplications in the attention mechanism. Thanks to logicchains for pointing this out.
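The point of a linear token mixer is that softmax attention compares every token with every other one (quadratic in sequence length), while a gated recurrence carries a fixed-size state that is updated with elementwise operations only. A pure-Python sketch of such a gated elementwise recurrence, under my own simplifying assumptions: the per-channel gate logits and candidate inputs are taken as given (in a real layer they would come from ternary projections of the tokens), and this is not the paper's exact layer.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def silu(z):
    return z * sigmoid(z)

def gated_linear_recurrence(f_logits, c_inputs):
    """Elementwise gated recurrence over a sequence.

    f_logits, c_inputs: lists of length T, each a list of d floats
    (assumed to be precomputed from the input tokens).
    h_t = f_t * h_{t-1} + (1 - f_t) * c_t, all elementwise, so the cost
    is O(T * d) instead of the O(T^2 * d) of softmax attention.
    """
    d = len(c_inputs[0])
    h = [0.0] * d
    outputs = []
    for f_row, c_row in zip(f_logits, c_inputs):
        f = [sigmoid(v) for v in f_row]  # forget gate in (0, 1)
        c = [silu(v) for v in c_row]     # candidate state
        h = [fi * hi + (1.0 - fi) * ci for fi, hi, ci in zip(f, h, c)]
        outputs.append(list(h))
    return outputs
```

With gate logits of 0 (f = 0.5) and a constant input, the state after two steps is 0.75 times the candidate value, i.e. an exponentially decaying average of the past, with no token-to-token score matrix anywhere.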
So it's a mishmash of ideas from other papers (not to downplay the results). These are exciting times of hackery: basically taking existing puzzle pieces and piecing things together.
This is the kind of stuff that can only be done so quickly by bringing more and more people into the field to try these ideas out. The more people, the more permutations of ideas.
Am I reading the paper right, or is there any reason not to be cynical about these points?
1) It’s weird to choose linear attention for their implementation because that’s not what their paper is about and they claim no insights relevant to attention mechanisms.
2) Benchmarking all models this way (linear vs. linear) likely inflated their numbers compared to testing the removal of matmul in a quadratic vs. quadratic setting.
3) This claim implies a comparison to the state of the art in language models where the standard is quadratic attention, and is therefore a flawed comparison:
“We processed billion-parameter scale models at 13W beyond human readable throughput, moving LLMs closer to brain-like efficiency.”
4) Those types of brain comparisons fall apart under scrutiny, are not standard in ML research and don’t mean much anyway.
5) Right up front in the abstract they make specific performance claims and imply they come from removing matmul, but don’t mention linear attention until section 4 on experiments.
Wow, this seems at first read to be really impressive work. They got scaling laws up to a reasonable size (2.7B) and also ran a few downstream tasks. It would be interesting to see how a comparable model trained by someone else does, to check their scores against those.
They get real (61%!?) memory savings during training, and inference too.
On top of all that, they then went and built an FPGA core programmed with a custom assembler. And their code is posted and works seamlessly with huggingface transformers?! Absolutely going to test this out.
So far pretty awesome. As a simpler first pass, I integrated their BitLinear layer into my training setup and did a run with just the MLPs swapped to see how things are: at comparable parameter counts (3.2B), the BitLinear MLP model has slightly lower loss!
> cheap piecewise affine approximation that is achieved by adding the bit representation of the floating point numbers together as integers ... with little to no performance impact
> we show that we can eliminate all multiplications in the entire training process, including operations in the forward pass, backward pass and optimizer update, demonstrating the first successful training of modern neural network architectures in a fully multiplication-free fashion.
This reaches demoscene levels of crazy/impressive!
>This reaches demoscene levels of crazy/impressive!
The exp/log trick for multiplying with addition does indeed look very familiar. I know that a number of demos used it in the '90s to simplify matrix multiplications for 3D graphics.
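The bit-level version of that trick: the IEEE-754 bit pattern of a positive float, read as an integer, is roughly a scaled and shifted log2 of its value, so adding two bit patterns and subtracting the bit pattern of 1.0 gives a piecewise affine approximation of the product. A minimal sketch for positive float32 values, assuming no overflow or underflow; the quoted paper's exact scheme (signs, exponent handling, any error correction) may well differ.

```python
import struct

ONE_BITS = 0x3F800000  # bit pattern of 1.0f: exponent bias 127 shifted left by 23

def f32_to_bits(x):
    """Reinterpret a float32 value as an unsigned 32-bit integer."""
    return struct.unpack("<I", struct.pack("<f", x))[0]

def bits_to_f32(b):
    """Reinterpret an unsigned 32-bit integer as a float32 value."""
    return struct.unpack("<f", struct.pack("<I", b))[0]

def approx_mul(a, b):
    """Approximate a*b for positive float32 a, b with one integer add and subtract."""
    return bits_to_f32(f32_to_bits(a) + f32_to_bits(b) - ONE_BITS)
```

`approx_mul(2.0, 4.0)` is exact (8.0) because both mantissas are zero, while `approx_mul(1.5, 1.5)` returns 2.0 instead of 2.25, which is near the worst case of this approximation (about 11% relative error, always underestimating).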
I feel like all of these transformer reductions to binary or ternary values are basically constructing an implicit decision tree, where each stage of the process answers a question with yes/no/I don't know, and "I don't know" invokes a continuation for further processing with more context.
It is super easy to try it out: the 2.7B, 1.3B and 0.37B models are on Hugging Face, and the generate.py example just works if you have Triton 2.2 installed.
One thing I didn’t figure out from just the paper: how does one train these parameters that are not even approximately real numbers? Specifically, most of the parameters are ternary (i.e. -1, 0, or 1). The approximate gradient discussed in the paper will (I think) give some real gradient on each parameter, and that can be further processed by the learning rate schedule, but the result is still a real number g_i for each parameter a_i. Normally one would update a_i to a_i + g_i, but with these ternary parameters, a_i + g_i isn’t ternary!
So what’s the extra trick to make the model stay quantized? Does one evaluate the gradients on a whole bunch of training inputs, add them up, apply some randomness, and then re-quantize the model? Or is it something else?
>To train our 1-bit model, we employ the straight-through estimator (STE)[BLC13] to approximate the gradient during backpropagation. This method bypasses the nondifferentiable functions, such as the Sign (Eq. 2) and Clip (Eq. 5) functions, during the backward pass. STE allows gradients to flow through the network without being affected by these non-differentiable functions, making it possible to train our quantized model.
>While the weights and the activations are quantized to low precision, the gradients and the optimizer states are stored in high precision to ensure training stability and accuracy. Following the previous work [LSL+21], we maintain a latent weight in a high-precision format for the learnable parameters to accumulate the parameter updates. The latent weights are binarized on the fly during the forward pass and never used for the inference process.
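In other words, the answer to the "a_i + g_i isn't ternary" question is that the update goes into a real-valued latent weight, and the ternary value is recomputed from it on every forward pass, so a weight only flips once its latent value crosses a rounding threshold. A toy pure-Python sketch of that loop, assuming a fixed quantization scale of 1, squared loss and plain SGD (the actual recipe uses scaled quantization and a proper optimizer):

```python
def ternarize(w):
    """Round-and-clip to {-1, 0, 1} (fixed scale of 1, for simplicity)."""
    return max(-1, min(1, round(w)))

def train_ste(latent, xs, targets, lr=0.05, steps=200):
    """Fit y = sum(ternarize(w_i) * x_i) with squared loss via the STE.

    Forward uses the ternarized weights; the backward pass treats
    ternarize() as the identity (straight-through estimator), so the
    gradient w.r.t. each quantized weight is applied to its latent weight.
    """
    latent = list(latent)
    for _ in range(steps):
        for x, t in zip(xs, targets):
            q = [ternarize(w) for w in latent]        # quantize on the fly
            y = sum(qi * xi for qi, xi in zip(q, x))  # forward with ternary weights
            dy = 2.0 * (y - t)                        # d(loss)/dy for (y - t)^2
            # STE: d(loss)/d(latent_i) is approximated by d(loss)/d(q_i) = dy * x_i
            latent = [w - lr * dy * xi for w, xi in zip(latent, x)]
    return latent, [ternarize(w) for w in latent]
```

Starting from all-zero latent weights with targets generated by the ternary weights [1, 0, -1], the latent values drift by 0.1 per pass until they cross ±0.5, at which point the ternarized weights become [1, 0, -1] and the gradient drops to zero. The high-precision latent copy is exactly the extra training-time state that is thrown away for inference.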
This seems a bit unfortunate — the training process will end up using a whole lot more memory than inference. I wonder whether one could get away with storing the high precision weights in slow host memory and using the quantized weights for the backward pass, thus keeping them out of GPU memory.
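A back-of-the-envelope sketch of how big that gap is, ignoring activations. The byte counts are assumptions, not figures from the paper: fp32 latent weights and gradients, two fp32 optimizer moments (Adam-like), and ternary inference weights packed at 2 bits each (log2(3), about 1.58 bits, is the theoretical minimum).

```python
def training_vs_inference_bytes(n_params, latent_bytes=4, grad_bytes=4,
                                optimizer_bytes=8, ternary_bits=2):
    """Rough per-parameter memory accounting, ignoring activations.

    All byte sizes are assumptions: fp32 latent weights and gradients,
    two fp32 optimizer moments, and packed ternary weights for inference.
    """
    training = n_params * (latent_bytes + grad_bytes + optimizer_bytes)
    inference = n_params * ternary_bits / 8
    # Latent weights plus optimizer state: the part that could, in principle,
    # live in host RAM if only the quantized weights are needed on the GPU.
    offloadable = n_params * (latent_bytes + optimizer_bytes)
    return {"training": training, "inference": inference,
            "offloadable_to_host": offloadable}
```

For 2.7B parameters this gives about 43.2 GB of training state versus about 0.68 GB of packed ternary weights, with roughly 32.4 GB of that training state being latent weights and optimizer moments that the offloading idea would move off the GPU.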
The FPGA is only there to prove the transistor efficiency. If this works, people will eventually do ASICs (possibly inside GPUs a few generations ahead).
Follow the github author to this project and you get to https://ruijie-zhu.github.io/ which has in the bio "I am deeply fascinated by large language models and had the privilege to be one of the authors of the RWKV large language model"
geohot is shilling his product, which aims to capitalize on making accelerator manufacturers compete with each other on FLOP / tensor core throughput.
the OP outlines what could be an entirely different compute paradigm for LLMs, hence the FPGA study. they just happen to also get impressive performance on GPUs making the most of the available interface.
https://arxiv.org/abs/2402.17764
The main addition of the new paper seems to be the implementation of optimized and fused kernels using triton, as seen here:
https://github.com/ridgerchu/matmulfreellm/blob/master/mmfre...