Google/Trax – Understand and explore advanced deep learning (github.com/google)
359 points by Bella-Xiang on Feb 17, 2020 | 50 comments



Word is that internally at Google, among a few teams, and also externally, Trax/Jax are putting up real competition to Tensorflow. Some teams have moved off of Tensorflow entirely. Combined with the better research capabilities of PyTorch, the future of Tensorflow is not bright. That said, Tensorflow still provides the highest performance for production usage, and has tons of legacy code strewn throughout the web.

I would argue that this is not the fault of Tensorflow, but rather the hazard of being the first implementation in an extremely complex space. It seems like there usually needs to be some sacrificial lamb in software domains. Somewhat like how Map/Reduce was quickly replaced by Spark, which has no real competitors.


Overall, I've seen similar movement away from Tensorflow in my social circle of research scientists/engineers.

One area I'd push back on is the claim that "this is not the fault of Tensorflow." A weakness of Tensorflow is that it solves each of a number of DL problems with its own specialized API call. That's not an asset; that's a liability.

LSTMs were always a pain point. So much so that for Tensorflow projects, I gave up and insisted on traditional feedforward approaches like CNNs + MLPs or ResNets even when LSTMs would have been viable. Mostly identical performance, with decent speed boosts from avoiding recurrence, and the simpler code reduced the maintenance burden for non-ML engineers.

As soon as you branch out of the standard DL bread-and-butter models, you spend frustratingly long periods of time tracking down obscure solutions in a part of the API space that has its own hard-to-follow logic.

Every time I'd point out that it's hard to do something, either in forums or on HN directly, I'd get a response that it's easy to do with [insert-random-api] function call.

In the end, it's my opinion that Tensorflow will lose out to JAX and Pytorch, by no fault other than its own complicated construction.


I agree with this, although I think it was a conscious and deliberate choice with TF 2.0. We have given up on TF for all future work, which is sad since I really appreciate a number of the pieces that surround the core. I think their choice to emphasize support for already-developed models and make the experience great for novices is a decision they will come to regret. We found so many issues when we tried to port some of our existing models to TF 2.0. The sad part was that there were GitHub issues open for all of them.

Personally I think Tensorflow has already lost and we just need to let it play out over the next few years. One interesting wrinkle is that since Trax, Jax and Flax utilize pieces of Tensorflow the TF team can probably claim good internal adoption numbers depending on how they count.


Every obscure solution in Tensorflow also has a chance to break on an upgrade. I'm glad I moved to PyTorch.


I thought Theano was the first real deep learning implementation. What is the difference between that and Tensorflow?


True, I was more so referencing the first implementation to see wide-scale adoption. But maybe there are some holes in my narrative...


Actually, Torch (2002) is older than Theano (2007). Moreover, both were already widely used in the research community. Tensorflow was actually late in the grand scheme of things, but it replaced Theano and, to some extent, Torch (which used Lua for its public API).

I think it's fairer to say that Tensorflow is the first implementation that had wide adoption outside academia.


The first would be SN/Lush from circa 1987, of which Torch (from NEC Research), and then PyTorch are direct descendants in a lot of ways.

https://leon.bottou.org/projects/lush


> Somewhat like Map/Reduce was quickly replaced by Spark, which has no real competitors.

I wouldn't say that - what about Flink? (https://flink.apache.org/)


Actually Dask on the python side.

Much higher performance than PySpark, because of the serialisation cost.


That's a shame, I liked Tensorflow much better than the imperative alternatives. I guess it comes with the territory: scientific software has to be dumbed down on the CS side because the users (grad students) typically have never programmed seriously.


>I would argue that this is not the fault of Tensorflow, but rather the hazard of being the first implementation in an extremely complex space. Seems like usually there needs to be some sacrificial lamb in software domains.

I'd agree if Google didn't have a history of building things with (arguably) unnecessarily complex APIs, like Angular 1. I remember when Angular and React were new, seeing an Angular "cheatsheet" that was around 14 pages; the equivalent React cheatsheet was only 2 pages. Now, I do love the idea of Tensorflow 1, essentially a functional DSL for explicitly constructing computation graphs, but Google's implementation of that idea was suboptimal: hard-to-follow error messages, unintuitive behaviour, multiple APIs to do the same thing (and continual API breakage as new APIs were introduced), and difficult debugging. And even if the graph "compiled" correctly, it could still fail at execution time.
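
To make that last point concrete, here's a toy TF1-style sketch (hypothetical shapes and names, nothing from a real project) where the graph builds without complaint but the error only appears once the session runs:

    import tensorflow.compat.v1 as tf  # TF1-style graph mode
    tf.disable_v2_behavior()

    x = tf.placeholder(tf.float32, shape=[None, 10], name="x")
    w = tf.Variable(tf.ones([10, 1]))
    y = tf.matmul(x, w)  # graph construction succeeds without complaint

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Only here, at execution time, does TF complain that the feed
        # has 8 features while the placeholder expects 10.
        sess.run(y, feed_dict={x: [[1.0] * 8]})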

It's like they were building a programming language but lacked anyone with a language design or PL theory background. Which makes sense, given that anyone passionate about language design might prefer to work somewhere closer to the cutting edge, like Microsoft (C#, F#, Typescript, F*...), Facebook (Bucklescript, Hack), Apple (Swift) or Mozilla (Rust). Google does have its own languages, Dart and Go, but they're notable for ignoring and rejecting, respectively, cutting-edge PL theory (e.g. not disallowing null pointers, and in Go's case not even supporting parametric polymorphism). The day-to-day languages used at Google are also not particularly appealing to a PL enthusiast: Python, Java and non-modern C++.

Google software often also seems to care more about enforcing its idea of "best practices" on the user than about user experience. Tensorflow's C++ support is an example of this: it requires using Bazel. As Bazel doesn't easily support integration into an existing C++ project, this essentially means you have to switch your whole project over to Bazel just to use Tensorflow C++, which is a huge amount of effort to go to just to use a library. Especially when it's probably quicker to just rewrite the model in PyTorch, which provides simple header files and a library to link against, the standard way of distributing C++ libraries. PyTorch also provides a nicer C++ API, because it's not confined to the ancient C++ standards that Google enforces, so it can provide a modern API that's not full of macros (macros in modern C++ are considered bad practice and should only be used when there's absolutely no alternative).

To be fair, Google is really good at engineering language runtimes. Dart, Go and Tensorflow are all impressive works of engineering. They just seem to lack the organisational DNA for making them really nice to use, which maybe makes sense given that the main source of their revenue is search/AdWords, whose success is primarily driven by superior science/engineering. Compare that to Facebook and Microsoft, which were/are in the business of making pretty things that users like to use (operating system, word processor, website). Or even to Apple: people pay a huge premium for Apple phones over Android phones, in spite of their worse hardware, because of their more appealing design.


I would actually argue that the very first version of Tensorflow was great. I don't quite remember what version number that was, but it was well before 1.0. It was basically just a better version of Theano with a single way of doing things and very explicit graphs. I remember I ported all my Theano code without issues and loved it.

From then on it all went downhill. Lots of duplicate APIs were added and the code grew increasingly complex and opaque because everyone at Google wanted in. Everyone wanted a slice of the pie to pad their resumes and internal performance reviews, to be able to say "Look, I contributed this to TF! Promote me!" Now it's a typical example of Conway's Law: a big mess that mirrors how messed up the internal incentive structure at Google is.


Makes sense. I only started using it after it'd already been through tf.nn, tf.layers, and tf.estimators (which eventually got replaced by tf.keras, around the time I moved to PyTorch, and the time dropout was broken for two releases: https://github.com/tensorflow/tensorflow/issues/25175). I remember once spending half a day just trying to figure out why a graph that built correctly kept failing at execution time: I ended up having to binary-search the code changes because I couldn't find the cause of the error in the stack trace, something I'd only previously experienced in template-heavy C++ code. I don't imagine that situation was much better in 1.0.


The error messages have always been an issue, even with early releases. That being said, it was much easier to trace errors when there were fewer high-level functions and you had to build the layers yourself. You fully understood what the graph looked like, and it was mostly obvious where something was likely to go wrong.

Sure, it's a bit more work, but in the long run it saves time because it avoids the situation where you need to spend half a day trying to understand opaque TF code touched and extended by 100 different developers.

EDIT: Also, I LOVE the GitHub issue you posted. It's such a beautiful example of the complexity cost when a large number of people are working on a project and there's no longer a single person who fully understands how the code interacts across modules.


I'd love to know what the solution is to prevent that eventually happening to a project, if one even exists. PyTorch seems to be doing better, but maybe that's just because it's a younger project. I'd be interested to see a blog post comparing Facebook's incentive structure for maintaining products to Google's.


It's a mistake to imagine that Google has some kind of centralized "programming languages" organization[1]. The people who designed Angular have probably never even met the people who designed Go, much less been influenced by them or shared influences. Same for TensorFlow. There may be "PL enthusiasts" at Google but the things they do at work are the same things everybody does at Google: moving data from one protobuf to another.

1: There is a "language platforms and tools" group (I used to work there) but its mission is stewardship of existing languages, not development of new ones.


This is exactly what I call the open-source API trap. The moment I saw how Tensorflow was built out and published, I knew almost instantly that it was a funnel for capturing innovation and would eventually be replaced when they found a better mousetrap. I did not spend much time learning that open-source wheel. Hopefully this new one will be smaller.


>> highest performance with regards to production usage

Not really, no. I've been using TensorRT for that quite successfully. If you can work around its limitations, I don't think anything can compete, at least not on modern NVIDIA GPUs. And yes, I'm aware TF can use TensorRT as well. But why drag in all of TF if you just want inference?


Any idea about flax? And how do flax and Trax interact?


This is an example of Peter Thiel's "the last mover advantage" principle at work.


Is it just me or is there zero explanation to what this actually is?

It somehow "helps" me understand deep learning but its tutorial / doc is one python notebook with three cells where some nondescript unknown API is called to train a transformer.

Huh?


You're not alone. Even as an ML researcher at ex-FAANG I have no idea what this is. Is it a collection of well-documented models built on top of jax?

I could probably figure out what exactly this is if I spent an hour looking through the code, but it should be made clear in the README.

It's kind of funny, I think it's completely opaque what this actually is to 99% of HN users, but for some reason it's being upvoted because it has Google and Deep Learning in the name.


After 15 minutes of looking through the code, my best guess right now is that it's a reimplementation of higher-level primitives, such as optimizers and layers, found in Tensorflow/Pytorch/etc., but based on a swappable backend (you can pick Jax or TF), together with a collection of models and training loops. I think the idea is that most of the code is simpler and more modular than what you would find in TF, which makes the models easier to read.

However, I don't yet understand what the use case is, or how it helps you to "learn" anything.


Trax is meant as a successor to Tensor2Tensor.

From my understanding of what Tensor2Tensor was (and thus what Trax is):

Basically, at its core, it's a library of models, datasets, and utilities that are meant to be clear and work together nicely.

This is meant to be good for education, as you have an easy and unified way to run SOTA trained models on a wide variety of datasets.

This is also meant to be good for research, as much of research involves building on top of other people's models, and having a library of models with shared API (as opposed to the fragmented wilderness that is public research code) is meant to facilitate that.

I believe that it also ended up being used in industry as well, simply as a reliable source of pre-trained models.

PS: As Jax did not have an "official" neural network library until very recently, this also served as a neural network library for Jax.
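
For flavor, defining a model in Trax looks roughly like this (a sketch pieced together from the README; exact layer names and arguments may differ):

    from trax import layers as tl

    # A tiny text classifier built by composing Trax layers (illustrative only).
    model = tl.Serial(
        tl.Embedding(vocab_size=8192, d_feature=256),
        tl.Mean(axis=1),    # average over the sequence dimension
        tl.Dense(2),
        tl.LogSoftmax(),
    )
    print(model)  # prints the layer structure

As far as I can tell, the bundled models (Transformer, Reformer, etc.) are built from these same layer objects, which is part of what makes the code readable end to end.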


JAX still doesn't have an "official" neural network library yet.

Source: I work closely with the JAX team at Google.


Poor wording perhaps. I meant that there was no project that "officially" called itself a neural network library for Jax, not that Jax had "officially" chosen a neural network library.

I'm talking, of course, about Flax.


From the README:

"Trax code is structured in a way that allows you to understand deep learning from scratch. We start with basic maths and go through layers, models, supervised and reinforcement learning."


But what is it? Is it an API or a module? Is it a DL framework? Is it an extension of TF/Jax with new models? Is it a set of tutorials? Are we meant to work through the code? Is it part of the docs for Jax? Is it a set of implementations of concrete models in Jax?

All these things come to mind when going through this Github.

By now I think I have figured out that it is a module that implements some DL models using Jax. So it is like an extension to Jax. It "helps" you understand how to build the models in Jax if you go through the code. You could also load these abstractions as a Python module and use them in code, but I doubt that this would "help you understand deep learning". So the real intention is to go through the code and see how, for example, a transformer is implemented in Jax.

Alternatively, this could be an early stage of some PyTorch-type framework over Jax and TF. So essentially, another Keras.


It's the successor of tensor2tensor, so it's a high-level library that implements the different algorithms. Unlike other libraries, the code is very readable.


Thank you


Note that, in this space, there is also Flax[0], which is also built on top of Jax and brings more deep-learning-specific primitives (while, unlike Trax, not trying to be Tensorflow compatible, if I understand correctly).

[0]: https://github.com/google-research/flax/tree/prerelease


Is this like a layer on top of TensorFlow to make it easier to get started? Is it meant to compete with PyTorch in that respect?

I wish the title and description were more clear. They make it sound like a course but it is a library/command-line tool.


So Tensorflow 2 is built around the Keras API.

It's supposed to be better UX, but PyTorch's UX really is far superior.


Trax is built on top of Jax, which grew out of Autograd, which is a PyTorch-like thing.


I was recently surprised to discover that Jax can't use a TPU's CPU, and that there are no plans to add this to Jax. https://github.com/google/jax/issues/2108#issuecomment-58154...

A TPU's CPU is the only reason that TPUs are able to get such high performance on MLPerf benchmarks like imagenet resnet training. https://mlperf.org/training-results-0-6

They do infeed processing (image transforms, etc) on the TPU's CPU. Then the results are fed to each TPU core.

Without this capability, I don't know how you'd feed the TPUs with data in a timely fashion. It seems like your input will be starved.

Hopefully they'll bring jax to parity with tensorflow in this regard soon. Otherwise, given that jax is a serious tensorflow competitor, I'm not sure how the future of TPUs will play out.

(If it sounds like this is just a minor feature, consider how it would sound to say "We're selling this car, and it can go fast, but it has no seats." Kind of a crucial feature of a car.)

Still, I think this is just a passing issue. There's no way that Google is going to let their TPU fleet languish. Not when they bring in >$1M/yr per TPU pod commitment.


You can use tensorflow's tf.data, I think.


Sort of. The way you do it is to scope operations to tf.device(None), which selects the TPU's CPU. (I think it's equivalent to using the first device in sess.list_devices(), which is the CPU.)

You can scope operations in tf.data, sure, but you can also execute arbitrary training operations. I use this technique to train GPT-2 117M with a 25k context window, which requires about 300GB of memory. Only a TPU's CPU has that much.
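
A rough sketch of what that scoping looks like (toy stand-in ops, not my actual pipeline):

    import numpy as np
    import tensorflow.compat.v1 as tf  # TF1-style graph mode
    tf.disable_v2_behavior()

    def augment(img):
        # stand-in for real input transforms
        return tf.image.random_flip_left_right(img)

    # Clearing the device scope so the input pipeline runs on the host CPU
    # rather than on the TPU cores themselves.
    with tf.device(None):
        images = np.zeros([64, 224, 224, 3], dtype=np.float32)
        ds = tf.data.Dataset.from_tensor_slices(images).map(augment).batch(8)
        batch = tf.data.make_one_shot_iterator(ds).get_next()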

That's why it was surprising to hear Jax can't do that. It's one of the best features of a TPU.


Not sure why one would bother with this. This is a less mature version of PyTorch. And I know there's XLA and stuff, but I've yet to see any major benefit from that for research in particular. A ton of time in DL frameworks is spent in the kernels (which in most practical cases means CUDA/cuDNN) which are hand-optimized far better than anything we'll ever get out of any optimizer.


If you're talking about Jax, there are a couple of different reasons to bother, for research:

1. Full numpy compatibility.

2. More efficient higher-order gradients (because of forward-mode autodiff). Naively it's an asymptotic improvement, but I believe PyTorch uses some autodiff tricks to perform higher-order gradients with backwards mode, at the cost of a decently high constant factor.

3. Some cool transformations like vmap.

4. Full code gen, which is neat especially for scientific computing purposes.

5. A neat API for higher order gradients.

2. and 5. are the most appealing for DL research; 1., 3. and 4. are appealing for those in the stats/scientific computing communities. (A rough sketch of 2., 3. and 5. follows below.)

PyTorch is working on all of these, to various degrees of effort, but Jax currently has an advantage in these points (and may have a fundamental advantage in design in some).
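
A minimal sketch of what 2., 3. and 5. look like in Jax (toy function, just to show the transformations; nothing here is specific to any model):

    import jax
    import jax.numpy as jnp

    def f(x):
        return jnp.sum(jnp.sin(x) * x)

    x = jnp.arange(3.0)

    # 5. grad is a function transformation: it returns a new function.
    df = jax.grad(f)
    print(df(x))

    # 2. Higher-order derivatives compose; forward-over-reverse gives a Hessian.
    hess = jax.jacfwd(jax.grad(f))
    print(hess(x))           # 3x3 Hessian of f at x

    # 3. vmap vectorises a per-example function over a batch dimension.
    xs = jnp.stack([x, x + 1.0])
    print(jax.vmap(df)(xs))  # one gradient for each row of xs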


From a practitioner:

1. Meh. PyTorch is close enough to not worry about it, and is better in some places.

2. Meh. All the methods people use in practice for deep learning in particular do not use higher order gradients. Most higher order methods are prohibitively memory expensive, and memory is at a premium in acceleration hardware (and so is the bus bandwidth - so you can't "swap to RAM"). I do agree that higher order gradients are the next frontier in optimization though - current optimizer research seems to have stalled, so people focus on training with huge batches and stuff like that. Most SOTA models in my field are trained with SGD+momentum - super primitive stuff. I don't see how Jax would solve the memory problem though. You still have to store those Hessians somewhere, at least partially.

3. Do agree, that's cool if it actually parallelizes nontrivial stuff which e.g. tf.vectorized_map barfs on. Although in a lot of cases you can "vectorize" by concatenating input tensors into a higher dimensional tensor.

4. Meh. Not sure why I'd want that if I already have tracing and JIT.

5. This is #2

With PyTorch though, you get close enough to Numpy to feel at home in both, and there's so much code written for it already that you can usually find a good starting point for your research pretty easily on github and then build on top of that.

If you need to deploy, there's also tracing and jit, which lets you load and serve models with libtorch.

I see what you're saying regarding "advantages", I'm just pointing out that PyTorch might be "good enough" for most people. If I were on that team, I'd focus on providing comfortable transition from TF 2.x which is a dumpster fire (with the exception of TensorFlow Lite which is excellent). That, IMO, would be the only way for this project to achieve mainstream success unless PyTorch disintegrates over time.


I agree from a practitioner standpoint - but you were talking about research :)

1. Much of the scientific computing/stats community is stuck in the past. Many are still using Matlab! Unlike the CS community, who are used to learning new frameworks, for them the ability to "import jax.numpy as np" and have their scripts just run is valuable, as is having an API they're already familiar with (and which has far more documentation).

2. Once again, this is true for practitioners, but not for research. Hessian-vector products show up in a decent number of places. For example, if you have an inner optimization loop (a la most meta-learning approaches or Deep Set Prediction Networks), you have a Hessian-vector product! Perhaps not prevalent in models that practitioners run, but definitely something to keep an eye on in research (see the sketch after this list).

3. My understanding is that it actually does a pretty decent job. Enough that it's useful in the prototyping phase.

4. The PyTorch JIT is neat, and is what I meant by the PyTorch team "working" on it. However, the JIT doesn't do full code gen (thus there's significant operator overhead for, say, scalar networks) and has had significantly fewer engineering hours poured into it compared to XLA.

5. I was specifically talking about how you call grad on a function to get a function that returns its gradient. It's a cleaner API than PyTorch's autograd.
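
For the Hessian-vector products mentioned in 2., the usual Jax recipe is forward-over-reverse, which never materialises the Hessian (a generic sketch with a toy loss, not tied to any particular meta-learning setup):

    import jax
    import jax.numpy as jnp

    def loss(w):
        # toy scalar loss
        return jnp.sum(jnp.tanh(w) ** 2)

    def hvp(f, w, v):
        # Hessian-vector product: jvp of the gradient (forward over reverse).
        return jax.jvp(jax.grad(f), (w,), (v,))[1]

    w = jnp.ones(5)
    v = jnp.arange(5.0)
    print(hvp(loss, w, v))  # same shape as w; the full Hessian is never formed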

Jax is definitely not meant for deployment or industry usage, and I believe their developers hope they'll never be pushed along that direction :^)

I definitely agree that PyTorch is "good enough" for most people. However, among researchers, there's a decent amount of subgroups it could gain favor in.

You'd be surprised how many papers get submitted to ICML/NeurIPS that don't use PyTorch or TensorFlow at all, in favor of raw numpy, C++, or even Matlab! I think the numbers I saw were that something like 30% of papers don't use any ML framework. Jax could easily gain favor in this crowd.

There's also the crowd that cares a lot about higher-order gradients. Admittedly a specific subgroup, but a growing one. Meta-learning people care a lot; so do Neural ODE people. All it takes is for one of these subfields to blow up for higher-order gradients to suddenly become a lot more appealing.

And finally, you have Google. Google researchers are never going to use PyTorch en masse (probably). If researchers at Google want to switch from TF, their only option is Jax. This is a pretty big subgroup of researchers :)

I definitely agree that Jax has a difficult hill to climb. But, they have a solid foothold within Google, and several subfields very amenable to their advantages.

PyTorch seems like the predominant research framework currently, but if any framework is going to erode its lead, I'd place my bets on Jax.


Autodiff is certainly one of the strengths of JAX. See the JAX Autodiff cookbook for a flavor of what JAX can do: https://colab.research.google.com/github/google/jax/blob/mas...

You might also like the per-example gradients example that appears first on the JAX GitHub page: it's only one line of code, but important for research areas such as differential privacy.
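
Roughly, the pattern is this (toy loss and made-up names, purely for illustration):

    import jax
    import jax.numpy as jnp

    def loss(w, x):
        # toy per-example squared-error loss
        return (jnp.dot(x, w) - 1.0) ** 2

    # vmap over the example axis only: one gradient per example in the batch.
    per_example_grads = jax.vmap(jax.grad(loss), in_axes=(None, 0))

    w = jnp.ones(3)
    xs = jnp.arange(12.0).reshape(4, 3)    # batch of 4 examples
    print(per_example_grads(w, xs).shape)  # (4, 3)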


>> You'd be surprised how many papers get submitted to ICML/Neurips that don't use PyTorch or TensorFlow at all

I do keep up with the literature and I do some applied research as well, so yeah, I see such things from time to time. The volume of papers is so intense, though, that unless there are other redeeming qualities, if a paper does not use the frameworks I already know (TF and PyTorch) I ignore it entirely. I wouldn't say I've missed much that could help me in practice. One exception is Leslie Smith's work on cyclic learning rates and momentum modulation: he did it on some ridiculous setup, but it works for what I do.

I'm more surprised by how many papers are written for tiny little datasets that you'd never use in practice, especially optimization papers. I mean, come on guys, I get that it's fast to train on CIFAR or Fashion-MNIST, but those results rarely translate to anything practical. And some papers are just plain not reproducible at all.

>> Google researchers are never going to use PyTorch en masse

IMO they should. It would easily double their productivity, and if Karpathy is to be believed their skin and eyesight would improve too: https://twitter.com/karpathy/status/868178954032513024?lang=...

>> I'd place my bets on Jax

As an ex-Googler, I'd place my bets in something else TBH. Google projects that aren't critical to Google's bottom line tend to deteriorate over time. Just look at TF. I'm not cruel enough to suggest it to my clients anymore, even though I could charge twice as much (because it would take twice as long to get the same result).


Certainly if you ignore those papers, you'd likely have no issue in practice - I suspect many of them are about more theoretical concerns. Perhaps I'll take a look at/post a list tomorrow.

Either way, I believe that our original discussion was on why somebody should bother. I provided a list of (admittedly) somewhat niche reasons. My personal opinion is that Jax will stick around, and at the very least, provide some neat ideas for Pytorch to ... independently come up with :)

>>> I'd place my bets on Jax

Hey hey hey, context! PyTorch is currently dominant in research, so who could supplant it? Anecdotally, since I published my article (https://thegradient.pub/state-of-ml-frameworks-2019-pytorch-...) there has been even more momentum towards PyTorch (Preferred Networks and OpenAI have moved to it).

So if not Tensorflow, then who? I think PyTorch represents a local optimum and is "good enough" for most people. So any newcomer framework needs to bring something new to the table, even if it's niche. I think Jax looks the most promising.


I'd like to see something based on a proper, high-performance, statically typed programming language, such that I could have a modicum of certainty that things would work when someone changes something. With Python, sadly, you don't know until you run things and hit error conditions dynamically. This is unacceptable in larger codebases.


Even then, shape errors would require a dependent type system, which is not found in most static languages.
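
For illustration, this is the kind of error that only surfaces at run time in Python (a trivial numpy example, nothing framework-specific):

    import numpy as np

    a = np.ones((3, 4))
    b = np.ones((5, 4))

    # A conventional static type checker is happy with this line; only at
    # run time does numpy raise a ValueError because the inner dimensions
    # (4 and 5) don't match.
    c = a @ b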


There are levels of survival I'm prepared to accept.


You said from the beginning:

>> "Not sure why one would bother with this".

Then someone provided reasons for that, and you brought out your own opinions to defend not wanting to use Jax (without even trying it)!? With that mindset, a thousand more reasons would not satisfy you.

From the perspective of a person doing DL research with a math background, I find Jax way more intuitive than any of the other frameworks, PyTorch included. In math, you just don't "zero the gradient" every iteration. In the same way, you just don't have to "forward" first to get the gradient. The more one experiments with Jax, the more one finds it faster, with fewer steps, and more intuitive for testing new ideas. And I don't like the object-oriented design of PyTorch; functional seems more intuitive to me. That's why I would bother with Jax.
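
To make the contrast concrete, here is a rough side-by-side with a toy loss (not a benchmark; the PyTorch half is just the standard zero_grad/backward dance):

    import torch
    import jax
    import jax.numpy as jnp

    # Imperative (PyTorch): gradients accumulate in .grad as hidden state,
    # and you must remember to reset them between iterations.
    w_t = torch.ones(3, requires_grad=True)
    loss_t = (w_t ** 2).sum()
    loss_t.backward()         # populates w_t.grad
    grad_t = w_t.grad.clone()
    w_t.grad.zero_()          # reset before the next step

    # Functional (Jax): the gradient is just another function, no hidden state.
    def loss_fn(w):
        return jnp.sum(w ** 2)

    grad_j = jax.grad(loss_fn)(jnp.ones(3))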


Looking forward to a README that is properly filled out, and some documentation as well. Looks promising.



