Actually, that's never been a constraint for JAX autodiff. JAX grew out of the original Autograd (https://github.com/hips/autograd), so differentiating through Python control flow has always worked. It's jax.jit and jax.vmap that place constraints on control flow, requiring structured control-flow combinators like those.
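A minimal sketch of the difference, assuming a recent JAX install (the function `f` is a made-up example). jax.grad traces with concrete input values, so a plain Python `if` on the input works. jax.jit traces with abstract values, so the same `if` has no concrete truth value and raises a concretization error.

```python
import jax

def f(x):
    # Plain Python branching on the *value* of x (hypothetical example function).
    if x > 0:
        return x ** 2
    else:
        return -x

# grad sees concrete values, so the Python `if` picks a branch normally.
grad_pos = jax.grad(f)(2.0)   # d/dx x**2 at x=2.0
grad_neg = jax.grad(f)(-3.0)  # d/dx (-x)

# jit traces with abstract values; `x > 0` can't be turned into a Python bool.
jit_failed = False
try:
    jax.jit(f)(2.0)
except jax.errors.ConcretizationTypeError:
    jit_failed = True

print(grad_pos, grad_neg, jit_failed)
```

Here grad_pos is 4.0, grad_neg is -1.0, and jit_failed is True.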
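For a vanilla Python "if" under jax.jit, the condition must be known at trace (compile) time, i.e. it can't depend on the values of traced arrays. For a condition that is only known at runtime, you have to use lax.cond, jnp.where, or lax.select (the last two are closely analogous: both compute both branches and then select elementwise).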
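A sketch of the same branch written with each runtime combinator, all under jit (the function names are hypothetical). lax.cond takes a scalar predicate and two branch functions; jnp.where and lax.select work elementwise on arrays, so they also handle a batch directly, while lax.cond is batched here via jax.vmap.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def f_cond(x):
    # lax.cond(pred, true_fun, false_fun, operand): scalar predicate.
    # Under jit only the chosen branch runs; under vmap it becomes a select.
    return lax.cond(x > 0, lambda v: v ** 2, lambda v: -v, x)

@jax.jit
def f_where(x):
    # jnp.where computes both branches, then picks elementwise.
    return jnp.where(x > 0, x ** 2, -x)

@jax.jit
def f_select(x):
    # lax.select is the lower-level primitive: pred and both branches
    # must have the same shape.
    return lax.select(x > 0, x ** 2, -x)

xs = jnp.array([-3.0, 2.0])
out_cond = jax.vmap(f_cond)(xs)  # batch a scalar-predicate cond with vmap
out_where = f_where(xs)
out_select = f_select(xs)
g = jax.grad(f_cond)(2.0)        # differentiating through cond also works
print(out_cond, out_where, out_select, g)
```

All three produce [3., 4.] for this input, and g is 4.0.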