A key difference is that each iteration of `scan` is driven by the host. Put differently, JAX can't fuse `scan` into a single GPU kernel; instead, it launches a kernel for each iteration.
Depending on the workload this may not be a problem, but if you have many cheap iterations, you will notice the launch overhead.
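To illustrate the point (a minimal sketch, not a benchmark; `scan_sum` and `fused_sum` are just made-up names): a scan whose body is a single add computes the same thing as one fused reduction, but it is lowered to a sequential loop, so on GPU each trip goes through the dispatch machinery separately.

```python
import jax
import jax.numpy as jnp

def step(carry, x):
    # Deliberately cheap body: a single add per iteration.
    return carry + x, None

@jax.jit
def scan_sum(xs):
    # Lowered to a sequential loop; on GPU each iteration is dispatched
    # separately rather than being fused into one kernel.
    total, _ = jax.lax.scan(step, 0.0, xs)
    return total

@jax.jit
def fused_sum(xs):
    # Same result via one reduction that XLA can emit as a single kernel.
    return jnp.sum(xs)

xs = jnp.ones(100_000)
# Identical values; the scan version typically issues far more kernel launches.
print(scan_sum(xs), fused_sum(xs))
```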
I am not sure whether they are working on fusing scan, or what the current status is.
Yes, but is it really the same? AFAIK the `unroll=n` parameter translates `n` iterations into a vanilla `for` loop, which is then unrolled into sequential statements (in contrast to a JAX `fori` loop). Strictly speaking, there is still no loop on the accelerator?
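For what it's worth, a small sketch of what `unroll` changes as I understand it (the step function is made up for illustration): it coarsens the loop body, and the sequential loop itself only disappears if you unroll the full length.

```python
import jax
import jax.numpy as jnp

def step(carry, x):
    new_carry = carry * 0.5 + x
    return new_carry, new_carry

xs = jnp.arange(16.0)

# Default: the compiled loop runs one step body per trip.
last, ys = jax.lax.scan(step, 0.0, xs)

# unroll=4: four step bodies become straight-line code inside each loop
# trip, so the loop has len(xs) / 4 trips; unroll=True (or unroll=len(xs))
# removes the loop entirely and emits only sequential statements.
last_u, ys_u = jax.lax.scan(step, 0.0, xs, unroll=4)

assert jnp.allclose(ys, ys_u)  # numerics are unchanged, only the lowering differs
```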
I think this is up to XLA to handle, not JAX. The whole selling point of the `tf.function` decorator in TF (which can use XLA underneath as well) is that it fuses arithmetic to lower the kernel launch count.
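To make the launch-count argument concrete on the JAX side (a minimal sketch; `activation` is just an illustrative name): under `jax.jit`, XLA's fusion pass typically merges a chain of elementwise ops into a single kernel.

```python
import jax
import jax.numpy as jnp

@jax.jit
def activation(x):
    # Three elementwise ops; XLA's fusion pass usually combines them into
    # one kernel instead of launching one kernel per op.
    return jnp.tanh(x) * 2.0 + 1.0

x = jnp.ones((1024, 1024))
print(activation(x).shape)  # (1024, 1024)
```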