Hacker News

JAX's JIT compilation of scan is fairly good, so loops aren't as slow as you'd expect.
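A minimal sketch of what's being discussed (the cumulative-sum example is my own illustration, not from the thread): a loop expressed with `jax.lax.scan` and compiled with `jax.jit`.

```python
import jax
import jax.numpy as jnp

def cumsum_scan(xs):
    # step: (carry, x) -> (new_carry, per-step output)
    def step(carry, x):
        carry = carry + x
        return carry, carry
    _, ys = jax.lax.scan(step, 0.0, xs)
    return ys

# jit compiles the whole scan once, instead of tracing a Python loop.
cumsum_jit = jax.jit(cumsum_scan)
print(cumsum_jit(jnp.arange(1.0, 5.0)))  # cumulative sums: [1., 3., 6., 10.]
```

Because `scan` keeps the loop as a single compiled construct, compile time stays constant in the number of iterations, unlike a Python `for` loop that would be traced and inlined step by step.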



A key difference is that each iteration of scan is launched from the host. Put differently, JAX can't fuse scan into a single GPU kernel; it launches a kernel for each iteration.

Depending on the workload this may be no problem, but if you have many cheap iterations, you will notice the launch overhead.

I am not sure whether they are working on fusing scan, or what the current status is.


There is an `unroll` parameter in scan that lets you control how many iterations of the loop body are fused into a single kernel.
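A hedged sketch of that parameter (the step function here is my own example): `unroll=n` repeats the scan body `n` times per compiled loop trip, so the launch overhead is amortized over several iterations. The numerical result is unchanged.

```python
import jax
import jax.numpy as jnp

def step(carry, x):
    new = carry + x
    return new, new  # (new carry, per-step output)

xs = jnp.ones(16)

# Default: one body evaluation per loop iteration.
_, ys_rolled = jax.lax.scan(step, 0.0, xs)

# unroll=4: the body is duplicated 4x inside each loop trip,
# so the compiled loop runs 4 iterations at a time.
_, ys_unrolled = jax.lax.scan(step, 0.0, xs, unroll=4)

# Same values either way; only the compiled loop structure differs.
print(jnp.allclose(ys_rolled, ys_unrolled))  # True
```

Note this trades compile time and code size for fewer loop trips; it does not remove the loop entirely unless you unroll the full length.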


Yes, but is it really the same? AFAIK the `unroll=n` parameter translates `n` iterations into a vanilla `for` loop, which is then unrolled into sequential statements (in contrast to a JAX `fori` loop). Strictly speaking, there is still a loop on the accelerator, just with fewer trips?


I think this is up to XLA to handle, not JAX. The whole selling point of the tf.function decorator in TF (which also uses XLA underneath) is that it fuses arithmetic to lower the kernel-launch count.


There's a paper about it that I just found; enjoy: https://arxiv.org/pdf/2301.13062



