Yeah, exactly - Python for training, Java/.NET for inference in production. I looked at approaches like gRPC, but my case is fairly time-sensitive and the latency added by going over a network layer was too much.
For now I'm happy with PyTorch -> ONNX and then running the ONNX model directly. But as I said, that means I can't easily train using JAX :-(