
It is a very bad idea to handle the KV cache in JAX naively like that. Jitted JAX code requires static shapes. You're creating dynamic shapes there, so the function gets recompiled every time the cache grows.
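To make the point concrete, here is a minimal sketch rather than the blog's actual code. It compares a cache grown with `jnp.concatenate` against a preallocated, fixed-shape buffer written with `jax.lax.dynamic_update_slice_in_dim`. The names `MAX_LEN`, `HEAD_DIM`, `naive_append`, `static_append` and the `traces` counter are hypothetical and exist only for illustration. Python side effects inside a jitted function run only while it is being traced, so the counter approximates how often each version is compiled.

```python
import jax
import jax.numpy as jnp
from jax import lax

MAX_LEN, HEAD_DIM = 16, 4  # hypothetical sizes, for illustration only

# Incremented only while JAX traces a function, i.e. once per new input shape.
traces = {"naive": 0, "static": 0}

@jax.jit
def naive_append(cache, new_kv):
    traces["naive"] += 1
    # The output has one more row than the input, so the next call sees a new shape.
    return jnp.concatenate([cache, new_kv], axis=0)

@jax.jit
def static_append(cache, new_kv, pos):
    traces["static"] += 1
    # The cache keeps shape (MAX_LEN, HEAD_DIM); only the row at `pos` is overwritten.
    return lax.dynamic_update_slice_in_dim(cache, new_kv, pos, axis=0)

naive_cache = jnp.zeros((0, HEAD_DIM))
static_cache = jnp.zeros((MAX_LEN, HEAD_DIM))

for step in range(8):
    new_kv = jnp.full((1, HEAD_DIM), step + 1.0)
    naive_cache = naive_append(naive_cache, new_kv)
    static_cache = static_append(static_cache, new_kv, step)

print(traces)
```

With the fixed-shape buffer, attention has to mask out the rows past the current position, since they hold padding rather than real keys and values.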


The blog mentions it's not for production use. This sounds like one thing you'd want to change.

I was curious what else made it not fit for production. Anything fundamental or just minor issues like this?


Is there any automatic way to get warned against these antipatterns?


You can see each compilation if you set the JAX_LOG_COMPILES environment variable or use a low enough logging level.


Sorry, not to belabor this point.

Would that suggest to you what you did wrong? Or purely show you what you got right? How chatty is this variable?


I used this to see whether something is compiled repeatedly. I.e., I have code that runs in a loop, and you immediately see whether something is compiled only once or every time (and it produces a lot of output). I'm not saying this is the best way to do it, though; it just worked for me.
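A small sketch of that workflow, assuming the `jax_log_compiles` config option, which is the in-process counterpart of the JAX_LOG_COMPILES environment variable. The function `double` and the input sizes are made up for illustration. The exact log text varies between JAX versions, and internal helper functions may log their own compilations too.

```python
import jax
import jax.numpy as jnp

# Turn on compilation logging from inside the script.
jax.config.update("jax_log_compiles", True)

@jax.jit
def double(x):
    return x * 2

results = []
# The second call reuses the compiled code for shape (4,).
# The third call has a new shape, so `double` is compiled again.
for n in (4, 4, 5):
    results.append(double(jnp.ones(n)))
```

If a log entry for the same function shows up on every loop iteration, something in its inputs (usually a shape) is changing from call to call.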


Just don't use jit in generation and it will be fine. Of course there is some performance penalty, but in my experience jit is oversold and the difference is roughly 10-30%.

Also, in any case, to get optimized code you need flash attention and many other tricks.



