
For those curious what the big deal is here: PyTrees make it wildly easier to take derivatives with respect to parameters involving a complex structure. This makes it much easier to organize code for non-trivial models.

As an example: if you want to implement logistic regression in JAX, you need to optimize the weights. This is easy enough, since the weights can be modeled as a single value, a matrix. If you want to model a two-layer MLP, you now have to use two matrices of weights (at least). You could treat these as two parameters to your function (which makes the derivative more complicated to manage), or you could concatenate the weights and split them up, etc. Annoying, but manageable.
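To make the contrast concrete, here is a minimal sketch of both cases. The function names, shapes, and toy data are made up for illustration; only `jax.grad` and its `argnums` argument are real JAX API.

```python
import jax
import jax.numpy as jnp

def bce(logits, y):
    # Binary cross-entropy, written with log-sigmoid for numerical stability.
    return -jnp.mean(y * jax.nn.log_sigmoid(logits)
                     + (1 - y) * jax.nn.log_sigmoid(-logits))

def logistic_loss(w, x, y):
    # One parameter: a single weight vector.
    return bce(x @ w, y)

def mlp_loss(w1, w2, x, y):
    # Two parameters: one weight matrix per layer.
    hidden = jnp.tanh(x @ w1)
    return bce((hidden @ w2).squeeze(-1), y)

# Toy data (hypothetical): 4 examples, 3 features.
x = jnp.ones((4, 3))
y = jnp.array([0.0, 1.0, 1.0, 0.0])

# Logistic regression: grad differentiates w.r.t. the first argument by default.
w = jnp.zeros(3)
g = jax.grad(logistic_loss)(w, x, y)

# Two-layer MLP: you now have to say which arguments are parameters,
# and you get back a tuple that you must keep in the same order.
w1 = jnp.full((3, 5), 0.1)
w2 = jnp.full((5, 1), 0.1)
g1, g2 = jax.grad(mlp_loss, argnums=(0, 1))(w1, w2, x, y)
```

Every extra layer means another positional argument and another entry in `argnums`, which is the bookkeeping that gets tedious.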

When you get to something like a diffusion model, you need to manage parameters for a variety of different, quite complex, models. It really helps if you can keep track of all these parameters in whatever data structure you like, and still trivially call "grad" on it to get your model's derivative with respect to its parameters.
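A rough sketch of what that looks like, assuming a toy noise-prediction model whose parameters live in a nested dict. This is not a real diffusion model; the structure, names, and shapes are hypothetical. The JAX calls used (`jax.grad`, `jax.random.split`, `jax.tree_util.tree_map`) are real.

```python
import jax
import jax.numpy as jnp

def init_params(key):
    # Parameters organized however is convenient: dicts, lists, nesting.
    k1, k2, k3 = jax.random.split(key, 3)
    return {
        "time_embed": {"w": jax.random.normal(k1, (1, 8)) * 0.1},
        "denoiser": {
            "layers": [
                {"w": jax.random.normal(k2, (12, 16)) * 0.1, "b": jnp.zeros(16)},
                {"w": jax.random.normal(k3, (16, 4)) * 0.1, "b": jnp.zeros(4)},
            ]
        },
    }

def predict_noise(params, x, t):
    # Embed the timestep, concatenate it with the input, run a small MLP.
    t_feat = jnp.tanh(t[:, None] @ params["time_embed"]["w"])
    h = jnp.concatenate([x, t_feat], axis=-1)
    first, second = params["denoiser"]["layers"]
    h = jnp.tanh(h @ first["w"] + first["b"])
    return h @ second["w"] + second["b"]

def loss(params, x, t, noise):
    return jnp.mean((predict_noise(params, x, t) - noise) ** 2)

params = init_params(jax.random.PRNGKey(0))
x = jnp.ones((2, 4))
t = jnp.array([0.1, 0.9])
noise = jnp.zeros((2, 4))

# One call: the gradient comes back with the same nested structure as params.
grads = jax.grad(loss)(params, x, t, noise)

# A plain SGD step applied leaf by leaf across the whole structure.
new_params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)
```

Here `grads["denoiser"]["layers"][0]["w"]` lines up with the parameter at the same path, so there is no flattening, splitting, or `argnums` to maintain.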

PyTrees make this incredibly simple, and are a major quality-of-life improvement in automatic differentiation.


