An attempt at a summary: they use a sigmoid function to make differentiable "soft" branches and stack them to construct a binary tree, with the goal of taking only one branch at inference time (while training the whole tree), giving log(W) instead of W inference cost. They gradually harden the branches so they become hard branches by the end of training.
A branch is computed as branch(input, N): a small neural network N computes a scalar c = N(input), and a sigmoid turns that into a soft branch by returning the weighted sum of the two recursive calls, s(c) * branch(input, N_left) + (1 - s(c)) * branch(input, N_right) (the two weights s(c) and 1 - s(c) sum to 1). Only the leaf nodes do the "proper processing".
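A minimal PyTorch-style sketch of that recursion as I read it (all names and shapes are mine, not the paper's code):

    import torch
    import torch.nn as nn

    class SoftBranchTree(nn.Module):
        # depth-d binary tree: inner nodes emit a scalar gate, leaves do the "proper processing"
        def __init__(self, dim, depth):
            super().__init__()
            self.depth = depth
            self.gates = nn.ModuleList([nn.Linear(dim, 1) for _ in range(2**depth - 1)])
            self.leaves = nn.ModuleList([nn.Linear(dim, dim) for _ in range(2**depth)])

        def forward(self, x, node=0, level=0):
            if level == self.depth:                     # leaf node
                return self.leaves[node - (2**self.depth - 1)](x)
            w = torch.sigmoid(self.gates[node](x))      # s(c), shape (batch, 1)
            left = self.forward(x, 2 * node + 1, level + 1)
            right = self.forward(x, 2 * node + 2, level + 1)
            return w * left + (1 - w) * right           # soft mixture of the two subtrees

At inference, once the gates have hardened, the mixture collapses to an if/else on w > 0.5, so only one leaf gets evaluated per input, which is where the log(W) cost comes from.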
Then they add a new loss term that encourages hard decisions by minimising the entropy of the Bernoulli distribution defined by the two weights, pushing them toward 0 and 1, at which point only one branch needs to be taken at inference. They also state that this hardening often happens automatically, though.
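The hardening term would then be something like the Bernoulli entropy of each gate, averaged over the batch (my own sketch of what they describe, not their exact formulation):

    def bernoulli_entropy(tree, x, eps=1e-8):
        # entropy of Bernoulli(s(c)) at every inner node; driving it to zero pushes
        # the pair of weights (s(c), 1 - s(c)) toward (0, 1) or (1, 0)
        total = x.new_zeros(())
        for gate in tree.gates:
            p = torch.sigmoid(gate(x))
            total = total - (p * (p + eps).log() + (1 - p) * (1 - p + eps).log()).mean()
        return total / len(tree.gates)

    # total loss would be something like: task_loss + lam * bernoulli_entropy(tree, x),
    # with lam a hyperparameter I'm making up here.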
It's a simple idea, but the loss formulation is nice; you usually want your loss terms to be a measure of information.
From the previous paper you cited:
>Pushing FFFs to the limit, we show that they can use as little as 1% of layer neurons for inference in vision transformers while preserving 94.2% of predictive performance.
This feels like that often-misinterpreted Einstein meme/quote about humans only using a fraction of their brain power.
Is this only for inference, though? Could it boost training?
That's an interesting question. It actually suggests a nice way to parallelize training: pretrain e.g. the first 3 branch levels, which effectively fragments the model into 8 separate parts that you can continue training across 8 independent servers/nodes with no further communication between them. A central server would run the first 3 levels and mark which parts of the training set each node has to train on. Maybe you could do this for the whole network and distribute the training SETI@home-style all over the world.
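As a toy sketch of the routing step (purely hypothetical, reusing the SoftBranchTree sketch from the summary above):

    import torch

    def route(tree, x, levels=3):
        # central server: follow hard decisions through the first `levels` frozen gates
        # for a single example x of shape (1, dim) -> shard id in [0, 2**levels)
        node, shard = 0, 0
        for _ in range(levels):
            go_left = torch.sigmoid(tree.gates[node](x)).item() > 0.5
            shard = shard * 2 + (0 if go_left else 1)
            node = 2 * node + 1 if go_left else 2 * node + 2
        return shard

    # each of the 2**levels workers then trains only its own subtree on the examples
    # routed to its shard, with no cross-worker communication.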
Hold on, you don't even need to freeze the branches completely: each node could train one branch on the path to its leaf node and communicate any change in that branch node to a central server, so you could distribute training without having to pre-freeze the branches. It would still need some pre-training though, the splits would change slowly, and the attention mechanism could complicate things.
Currently, distributed neural network training SETI@home-style looks like a complete pipe dream that nobody is taking seriously. But a smart branching mechanism like this could suddenly make it possible. Folding@home reached 1.5 exaflops, which made it the world's largest supercomputer at the time. Imagine the models we could train that way; they would far surpass whatever OpenAI or Google could train, and they would be public.
You should check out Hivemind[1]. It is very similar to what you described, except it uses MoE for "fragmentation". They have a couple of examples of pre-training in their repo. Hivemind was also used to build Petals[2], but that only supports fine-tuning and inference[3] afaik.
Apologies for the layman question: how many tera/peta/exaflops do current models use to train?
Well, I'm assuming they'd use whatever they're given, so maybe the question should be "how much less time would training take on a 1.5 exaflops computer?"
A lot of clusters are totally homogeneous, at least within some very large domains, so for a given interconnect and GPU generation you know the maximum message latency, the peak sustained PFLOP rate, and so on. But what often matters is some combination of the depreciation cost per unit time and the watt-hours per unit time, both of which you can roughly approximate if you ignore the unfortunate realities, which then act as a multiplier.
For example, one problem is network issues, and not just at scale: training often involves billions of cycles of short, bursty compute-sync sequences (e.g., all-to-all, barrier, compute, barrier, all-to-all, ...) between which there isn't enough time to engage low-power modes, so you're burning money on slack and waste. This is true in different ways for a lot of training approaches.
You can approximate this, but it's so sensitive to data set size, specific training schedule, etc. that you won't be able to get the most important answer.
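For a feel of what that approximation looks like, a back-of-envelope with entirely invented numbers (and ignoring exactly those realities):

    # every number below is made up for illustration
    flops_needed = 3e23      # total training FLOPs for some hypothetical model
    peak_flops = 1.5e18      # 1.5 exaFLOP/s aggregate peak
    utilisation = 0.35       # fraction of peak actually sustained (sync slack, stragglers, ...)
    power_mw = 30            # cluster draw in MW, paid whether computing or waiting at a barrier

    seconds = flops_needed / (peak_flops * utilisation)
    mwh = power_mw * seconds / 3600
    print(f"~{seconds / 86400:.1f} days, ~{mwh:,.0f} MWh")  # the slack multiplier hits both numbers

Change the utilisation or the schedule and the answer moves by a large factor, which is the point being made.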
It's mentioned briefly in the paper (1), but I'm more interested in the interpretability implications of this approach. In some respects, this marries the interpretability/editability of a small decision tree with the expressive power of a large neural network. Usually you see those two on extreme opposite ends of a tradeoff spectrum - but this approach, if it scales, might shift the Pareto frontier.
(1): As a byproduct, the learned regions can also be used as a partition of the input space for interpretability, surgical model editing, catastrophic forgetting mitigation, reduction of replay data budget, etc..
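A toy illustration of what using the learned regions could look like once the branches are hard (my own sketch again, building on the SoftBranchTree sketch above, not the paper's code):

    def region_id(tree, x):
        # x: a single example of shape (1, dim); returns the index of the leaf it falls
        # into, i.e. which of the 2**depth learned regions of input space it belongs to
        node = 0
        for _ in range(tree.depth):
            go_left = torch.sigmoid(tree.gates[node](x)).item() > 0.5
            node = 2 * node + 1 if go_left else 2 * node + 2
        return node - (2**tree.depth - 1)

    # e.g. surgical editing could mean retraining only tree.leaves[region_id(tree, x)]
    # to change behaviour on x's region without touching the others.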
…ETH Zurich is an illustrious research university that often cooperates with Deepmind and other hyped groups, they're right there at the frontier too, and have been for a very long time. They don't have massive training runs on their own but pound for pound I'd say they have better papers.
ETH Zurich is one of the top labs in the world. Disney Research also works with them a lot. Another "sleeper" is the University of Amsterdam, which has rockstars like Max Welling and his students Kingma, Salimans, van den Berg, and Hoogeboom.
It's easy to get hyped up about the big tech labs because they have the most compute, but the best papers come from smaller labs, which unfortunately have lately faced larger challenges in getting published. It's the smaller works that create the foundations that end up in these giant models. ML is in a really weird place right now.
From the first author on Twitter:
"It could quite a big deal for people who don't have access to a colocated cluster of GPUs:
e.g. with DiLoCo you could train your model, with data-parallelism, across all GPU providers, looking in real-time for the cheapest price, even if pre-emptable, even across continents"
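For context, a rough sketch of the DiLoCo idea as I understand it (local training with infrequent synchronisation of parameter deltas); names like worker.local_loss are hypothetical and this is not the authors' code:

    import copy
    import torch

    def diloco_round(global_model, workers, steps_per_round, outer_lr=1.0):
        # each worker copies the global model, takes `steps_per_round` local optimizer
        # steps on its own data, and only the averaged parameter delta is communicated
        start = [p.detach().clone() for p in global_model.parameters()]
        deltas = [torch.zeros_like(p) for p in start]
        for worker in workers:                      # conceptually, separate machines/providers
            local = copy.deepcopy(global_model)
            opt = torch.optim.AdamW(local.parameters(), lr=1e-4)
            for _ in range(steps_per_round):
                loss = worker.local_loss(local)     # hypothetical: loss on a locally held batch
                opt.zero_grad(); loss.backward(); opt.step()
            for d, p0, p in zip(deltas, start, local.parameters()):
                d += (p0 - p.detach()) / len(workers)
        with torch.no_grad():                       # outer update (DiLoCo itself uses Nesterov momentum here)
            for p, p0, d in zip(global_model.parameters(), start, deltas):
                p.copy_(p0 - outer_lr * d)

The point is that communication happens only once per round rather than every step, which is what makes the "across providers, even across continents" part plausible.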
It is not surprising. The assumption is that they have the best people, but the idea that you can objectively search 8 billion people and find the best ones globally is folly, of course. There are geniuses without US citizenship / visas / green cards, and those outside brains are going to figure this out. Mix in the fact that the GDP of $rest_of_world dwarfs any single company's resources, plus the luck-driven nature of making AI discoveries, and I reckon most progress will happen outside of OpenAI etc., driven by a problem the big guys don't need to solve: how do I avoid buying a $5k graphics card.
I wouldn't be so quick to call conspiracy. I'm the author of a work and a famous blog post that trains a particular common architecture much faster (I don't want to dox myself too much) and with far fewer parameters, but it has been rejected several times and is now arXiv-only. The most common complaint we got was "who would use this? Why not just take a large model and tune it?" That question alone held us back a year (the paper had over a hundred citations by then and remains my most cited work), until the complaints switched to "use more datasets" and "not novel" (by that time true, since others had built on us, cited us, and published in top venues).
I don't think this was some conspiracy by big labs to push back against us (we're nobodies), but rather that people get caught up in hype and reviewers are lazy and incentivized to reject. You're trained to be critical of works, and especially to consider that post hoc most solutions appear far simpler than they actually are. But context matters: if you don't approach every paper with nuance, it's easy to say "oh, it's just x." If those ideas really were so simple and obvious, they would also be prolific.
I see a lot of small labs suffer the same fate simply due to lack of compute. If you don't make your new technique work on many datasets, that becomes the easiest thing to reject a paper over. ACs aren't checking that reviews are reasonable. I've even argued with fellow reviewers about papers in workshops -- papers I would have accepted in the main conference -- that get brushed off, where the reviewers admit in their reviews that they do not work on these topics. I don't understand what's going on, but at times it feels like a collective madness. A 10-page paper with 4 very different datasets that solves a problem, is clearly written, has no major flaws, and is useful to the community should not need defending when submitted to a workshop just because the reviewers aren't qualified to review the work (this paper got in, btw).
We are moving into a "pay to play" ecosystem, and that will only create bad science due to groupthink. (Another aspect of "pay to play" is the tuning: spending $1M to tune your model to be the best doesn't mean it is better than a model that could not afford the search. Often more than half of resources are spent on tuning now.)
Is there a place where you guys discuss... things? I'm a layman interested in this topic in the same way as pop physics/maths, but I have no chance of just reading papers and "getting it". On the other hand, the immediately available resources focus more on the how-to part rather than on what's going on overall. Also, do you have something like 3b1b/pbs/nph for it? Content you can watch and say "well, yep, good job".
I don't have any great recommendations, and unfortunately my advice may not be what you want to hear. What I tell my students is "You don't need to know math to build good models, but you need to know math to know why your models are wrong." Even that is a contentious statement within the community. (Personally I'm more interested in exploring what we can build and understand rather than focusing on throwing more compute and data at problems. There's a lot of work to be done that does not require significant compute, but it isn't flashy and you'll get little fame. Every famous model you know has some unsung hero(s) who built the foundation before compute was thrown at the problem.)
I was previously a physicist, and we similarly say that you do not know the material unless you can do the math. Physicists are trained in generating analogies, since they help communication, but this sometimes leads people to convince themselves that they understand things far better than they actually do. They say the devil is in the details, and boy are there a lot of details. (Of the science communicators, I'm happy those are the ones you mention, though!)
But do not take this as gatekeeping! These groups are often happy to help with the math and recommend readings. ML is kind of a wild west and you can honestly pick a subdomain of math and probably find it useful, but I would start by making sure you have a foundation in multivariate calculus and linear algebra.
As to paper reading, my suggestion is to just start. This is a fear I faced when I began grad school: it feels overwhelming, like everyone is leagues ahead of you and you have no idea where to begin. I promise that is not the case. Start anywhere; it is okay, as where you end up will not depend much on where you begin. Mentors help, but they aren't necessary if you have dedication. As you read you will become accustomed to the language and start to understand the "lore." I highly suggest following topics you find interesting backwards through time, as this has been one of the most beneficial practices in my learning. I still find that revisiting some old works reveals many hidden gems that were forgotten. Plus, they'll be easier to read! Yes, you will have to reread many of those works later as your knowledge matures, but that is not a bad thing; you will come to them with newer eyes.
Your goal should be to first understand the motivation/lore, so do not worry if you do not understand all the details. You will learn a lot through immersion. It is perfectly okay if you barely understand a work when first starting, because a mistake many people make (including a lot of researchers!) is assuming a paper is, or even can be, self-contained. You cannot truthfully read a work without understanding its history, and that only comes with time and experience. Never forget this aspect; it is all too easy to deceive yourself that things are simpler than they are (the curse of hindsight).
I'd also suggest just getting building. To learn physics you must do physics problems; to learn ML you must build ML systems. There are no shortcuts, but progress is faster than it looks. There are hundreds of tutorials out there and most are absolute garbage, but I also don't have something I can point to that's comprehensive. Just keep in mind that you're always learning, and so are the people writing tutorials. I'm going to kinda just dump some links; they aren't in any particular order, sorry haha. It's far from comprehensive, but it should help you get started, and nothing in here is too advanced. If it looks complicated, spend more time and you'll get it. It's normal if it doesn't click right away, and there's nothing wrong with that.
Unless they were very confident of acceptance, a top research prof would rewrite and resubmit before posting to arXiv, since putting it up lets others "build on it" (i.e., scoop you at a top conference).
Welcome to ML. And idk, I'd feel pretty confident that a paper with so many citations would get accepted. The review system is like a slot machine if you aren't a big tech lab.
They certainly have an incentive to keep these kinds of improvements in-house and not publish them, since they are commercial entities and this represents a competitive advantage.
Nvidia can't make GPUs fast enough. I doubt 10xing training and/or inference efficiency would result in a decrease in demand; I would be surprised if it didn't instead increase demand. Mind you, Nvidia is pushing hard on TensorRT, which optimizes models at inference time and results in major increases in throughput (not 10x though, lol).
But if things get too efficient for individual users, you won't need an Nvidia GPU anymore. People will use cheaper hardware instead. I'm looking forward to running good models at decent speed on a low-end CPU or whatever crappy GPU is in my phone.
I had the same thought this morning and was debating selling my nvda stock when I saw this - feels like they are well-positioned right now, as with crypto a few years ago, but if there were an efficiency breakthrough that allowed commodity CPUs to do the inference instead, this advantage could vanish quickly.
Many labs doing foundational work like this and making progress don't have anything near the budget or compute to implement it at scale. In other words, they don't have a Sam and his backers or a Zuck and his budget.
https://arxiv.org/abs/2308.14711