It is a very bad idea to handle the KV cache in JAX naively like that. JAX's jit requires static shapes. You're creating dynamic shapes there, so every new shape triggers another compile, and that adds up to a ton of recompilation.
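For illustration, here is a minimal sketch of the static-shape alternative: preallocate the cache at a maximum length and write each new entry in place with `jax.lax.dynamic_update_slice`. The names `MAX_LEN`, `HEAD_DIM` and `append_kv` are made up for this example and the cache is reduced to a single 2-D buffer, so treat it as a sketch under those assumptions rather than anyone's actual implementation.

```python
import jax
import jax.numpy as jnp

MAX_LEN, HEAD_DIM = 16, 4  # hypothetical sizes, fixed up front

trace_count = 0  # bumped only when JAX (re)traces append_kv


@jax.jit
def append_kv(cache, new_kv, pos):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time only
    # Write one row at index `pos`; the cache keeps its (MAX_LEN, HEAD_DIM) shape.
    return jax.lax.dynamic_update_slice(cache, new_kv[None, :], (pos, 0))


cache = jnp.zeros((MAX_LEN, HEAD_DIM))
for step in range(5):
    new_kv = jnp.full((HEAD_DIM,), float(step + 1))
    cache = append_kv(cache, new_kv, step)

print(cache.shape, trace_count)
```

Because the buffer shape never changes and `pos` is passed as a traced value, the function is traced once and reused on every step. The attention computation then has to mask out the rows that have not been filled yet.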
I used this to see whether something is being compiled repeatedly. I run the code in a loop, and you can immediately see whether something is compiled only once or on every iteration (it produces a lot of output). I'm not saying this is the best way to do it, though; it just worked for me.
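One way to run this kind of check, not necessarily the tool meant above, is to count traces with a Python side effect inside the jitted function, since that code only runs when JAX traces it again. The sketch below compares a call whose input grows each iteration with one that keeps a fixed-size buffer plus a length; `sum_grow` and `sum_fixed` are hypothetical names for this example. The commented-out `jax_log_compiles` flag makes JAX log each compilation itself.

```python
import jax
import jax.numpy as jnp

# Optional: have JAX log every compilation (verbose).
# jax.config.update("jax_log_compiles", True)

traces = {"grow": 0, "fixed": 0}


@jax.jit
def sum_grow(x):
    traces["grow"] += 1  # runs only while tracing, not on cached calls
    return x.sum()


@jax.jit
def sum_fixed(x, length):
    traces["fixed"] += 1
    # Same buffer shape every call; mask out entries past `length`.
    mask = jnp.arange(x.shape[0]) < length
    return jnp.where(mask, x, 0.0).sum()


buf = jnp.ones(8)
for n in range(1, 5):
    sum_grow(buf[:n])   # new input shape on every iteration
    sum_fixed(buf, n)   # constant shape, length passed as a value

print(traces)
```

The `grow` counter goes up on every iteration, while `fixed` stays at one after the first call.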
Just don't use jit during generation and it will be fine. Of course there is some performance penalty, but in my experience jit is oversold and the difference is something like ~10-30%.
Also, in any case, to get optimized code you need flash attention and many other tricks.