> transformer memory requirements scale with the square of sequence lengths
Not true, see: Flash Attention. You can losslessly calculate the attention in blocks using a little math trick: essentially each subsequent block "corrects" the denominator of the previous block's softmax calculation, so at the end you have a perfectly* accurate softmax. Since you never need to materialize the full attention score matrix to perform the softmax, memory now scales linearly with sequence length, and thanks to the lower memory bandwidth requirements and increased kernel fusion the operation also tends to be faster.
* While mathematically the calculation ends up exactly the same, in practice the result differs slightly due to the whims of F32 and F16 inaccuracies, and because the "max" used to keep the softmax numerically stable is computed on a per-block basis. It doesn't significantly affect training or validation loss, though.
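To make the denominator-correction idea concrete, here's a minimal NumPy sketch of the block-wise ("online") softmax. The function name, `block_size`, and the two-pass structure are mine for illustration; the real Flash Attention kernel also carries a running, rescaled output accumulator so it never needs a second pass over the scores.

```python
import numpy as np

def blockwise_softmax(x, block_size=4):
    """Compute softmax(x) one block at a time, keeping only a running
    max and a running denominator instead of the full vector of exps."""
    running_max = -np.inf
    running_denom = 0.0
    # Stream over blocks, "correcting" the denominator whenever a new,
    # larger max is seen in the current block.
    for start in range(0, len(x), block_size):
        block = x[start:start + block_size]
        new_max = max(running_max, block.max())
        # Rescale the old denominator to the new max, then add this block.
        running_denom = running_denom * np.exp(running_max - new_max) \
                        + np.exp(block - new_max).sum()
        running_max = new_max
    # Normalise with the final max/denominator (illustrative second pass;
    # the real kernel rescales its output on the fly instead).
    return np.exp(x - running_max) / running_denom

x = np.random.randn(16).astype(np.float32)
reference = np.exp(x - x.max()) / np.exp(x - x.max()).sum()
assert np.allclose(blockwise_softmax(x), reference, atol=1e-6)
```

The rescaling step is why the result is exact in exact arithmetic but can drift by a few ULPs in F16/F32, as noted in the footnote above.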