This is very surprising to me and I don't completely understand it.
Let me ask: does this crucially rely on the deterministic behavior of the transformer?
The only way I can imagine what you say is that the transformer can see the past inputs, and, if it is deterministic, infer what (internal) chain of thought it must have had, and hence what its plan (the factorization) must have been when it generated the output (the first prime number).
In this case, I have the impression that any amount of non-determinism, i.e. randomness, breaks this mechanism, and hence transformers with noise are weaker than RNNs with noise. Would you agree?
Furthermore, if I understand correctly, it seems strange for a model to learn to use knowledge of its own deterministic behavior to predict what the only strategy it could have chosen must have been. I just can't imagine a transformer solving the prime-product generation task with this method, even if generating large primes and multiplying them were easy for it. (And I can easily imagine an RNN solving the task.)
It reminds me of solutions to infinite prisoner hat problems where you agree on some abstract well-ordering of possibilities...
No, it doesn't crucially rely on deterministic behavior (you can add some dropout layers if you want).
The transformer does see the past inputs, but it won't really use them, except to "serialize the answer" that it forms on its first time-step: from the random initial context vector, by splitting that space into regions, it will generate two "numbers" as features, multiply them as features, and from this feature vector containing the answer it will just serialize it out. At no time will the network ever factorize anything.
As an exercise you can probably manually specify weights that do exactly the task I describe, by treating each feature as an individual bit and specifying layer weights so that they behave as a multiplier logic circuit (with OR gates, AND gates, and NOT gates ...).
(Therefore the architecture can at least express the solution, even though the network will probably have a really hard time reaching it.)
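To make the exercise concrete, here is a minimal sketch of a multiplier built purely from logic gates, with each feature treated as one bit — the kind of circuit the hand-specified layer weights would have to emulate. All function names are illustrative, not transformer code:

```python
# Bit-level shift-and-add multiplier built only from AND/OR/NOT
# (XOR is derived from them), over little-endian bit lists.
def AND(a, b): return a & b
def OR(a, b):  return a | b
def NOT(a):    return 1 - a
def XOR(a, b): return OR(AND(a, NOT(b)), AND(NOT(a), b))

def full_adder(a, b, cin):
    s = XOR(XOR(a, b), cin)
    cout = OR(AND(a, b), AND(cin, XOR(a, b)))
    return s, cout

def multiply_bits(x_bits, y_bits):
    """Accumulate AND-gated partial products with ripple-carry adders."""
    acc = [0] * (len(x_bits) + len(y_bits))
    for i, yb in enumerate(y_bits):
        carry = 0
        for j, xb in enumerate(x_bits):
            acc[i + j], carry = full_adder(acc[i + j], AND(xb, yb), carry)
        k = i + len(x_bits)
        while carry:                      # propagate leftover carry
            acc[k], carry = full_adder(acc[k], 0, carry)
            k += 1
    return acc

def to_bits(v, n): return [(v >> i) & 1 for i in range(n)]
def from_bits(bits): return sum(b << i for i, b in enumerate(bits))

print(from_bits(multiply_bits(to_bits(13, 4), to_bits(11, 4))))  # → 143
```

Each gate maps onto a small linear layer plus nonlinearity, so a fixed-depth stack of layers can express a fixed-width multiplier.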
Alternatively, a transformer is capable of over-fitting a dataset exactly: if you give it the dataset ab,a,b it is perfectly capable of memorizing it, so it will be able to produce the output you want even though it will not have understood anything.
Alternatively, if you want the network to implicitly discover chain-of-thought reasoning, then you will have to train the transformer to produce ????????ab,a,b. This is a much harder task to learn, because the gradient has to flow across various time-steps and through discrete sampling (instead of just predicting by imitation from the known previous characters). This is a reinforcement learning problem that has to deal with credit assignment and a long time horizon. You can use a decision transformer. Or use another, larger LLM that has access to the task description to help sample intelligent ???????? prefixes, train a standard transformer on these sampled inputs, define a reward model, and build datasets of increasing reward.
Even in these reinforcement learning settings, you can reach the optimal state without having to "plan". You can learn an optimal controller that plans implicitly by just solving the Bellman equation. Such a controller lives in the present and does not anticipate future outcomes. But you can also build controllers like Model Predictive Control that, at each time-step, make a plan, project it forward, and solve for the best next move — though this is much more computationally expensive.
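As a toy illustration of the MPC loop described above (plan, project, execute only the first move, re-plan), here is a sketch on a made-up 1-D problem; the dynamics, goal, and planner (exhaustive search over a tiny discrete action space) are all assumptions for illustration:

```python
from itertools import product

GOAL = 7  # target state for the toy 1-D problem

def model(state, action):
    """Known dynamics model: a 1-D integrator."""
    return state + action

def trajectory_cost(state, plan):
    """Accumulated distance to the goal along a rolled-out plan."""
    total = 0
    for a in plan:
        state = model(state, a)
        total += abs(state - GOAL)
    return total

def mpc_action(state, horizon=3):
    """At each time-step: enumerate all plans over the horizon,
    roll them out with the model, pick the cheapest plan, and
    execute only its first action (then re-plan next step)."""
    best = min(product([-1, 0, 1], repeat=horizon),
               key=lambda plan: trajectory_cost(state, plan))
    return best[0]

state = 0
for _ in range(20):
    state = model(state, mpc_action(state))
print(state)  # → 7
```

The contrast with the Bellman-style controller is that here the "plan" is recomputed and thrown away at every step, at a cost exponential in the horizon.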
Hmm, I would love for someone else to give their opinion, because I find this very interesting but don't quite understand it yet. The way I see it, at the very beginning the transformer has a large number of choices of numbers a, b that allow it to solve the problem. If there is randomness present, then it will (pseudo-)randomly choose a pair a, b, with the intention of writing ab,a,b.
After writing the first digit of ab, as I understand it, it wants to recover its features by doing the same operations as before on the past sequence. But since the computation of the features is non-deterministic, it can't arrive at the same pair a, b.
Let me try to specify a more difficult task for a transformer with randomness:
I want you to generate exactly three numbers in the form c,a,b where a and b are prime numbers with 250-300 digits and c=ab. I want these numbers to be randomly chosen with a distribution such that the range of primes with 250-300 digits is approximately uniformly covered.
Suppose now the transformer has uniformly picked a and b and generated the first digit of ab. Let's in fact say it has already generated all the digits of ab. If the transformer has weights that make it now successfully print a, b (say with probability > 0.99), then you have constructed a method of factoring products of primes in the 250-300 digit range: you just initialize the context window with the desired 500-600 digit number ab to be factorized and let the transformer do its work.
I.e. such a transformer has to be computationally powerful enough to factorize large prime products with probability > 0.99.
On the other hand, an RNN or a human, both with randomness, can solve this task without being computationally powerful enough to do the factorization.
>you just initialize the context window with the desired 500-600 digit number ab to be factorized
You also have to initialize the initial random context (before the prompt, aka the initial state) correctly, correlated with your ab prompt, since that is what the network used to generate this number ab. For example, give it a feature vector corresponding to a, b, and ab written in binary (in whatever encoding the network has learned). It won't ever learn to invert the one-way function ab -> a,b from only being shown ab.
In practice, the learning signal will be quite weak, because useful signal only arrives once the network has seen the full ab,a,b, i.e. at the last character, and must be back-propagated through to the initial time-step. In training, all the earlier character-level predictions will amount to noise that the network first has to learn to model in order to ignore (a "Benford's law"-like distortion?). Or you can just give a weight of zero to these intermediate character predictions that only add noise.
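The "weight of zero" idea amounts to masking the loss on the scratch (????????) positions so that only the answer tokens produce gradient. A minimal sketch, with a plain-Python per-token loss instead of a real tensor library, and all values illustrative:

```python
def masked_loss(token_logprobs, answer_mask):
    """Average negative log-likelihood over answer positions only.
    token_logprobs: log p(correct token) at each position.
    answer_mask: 1.0 on answer positions, 0.0 on scratch positions."""
    num = sum(-lp * w for lp, w in zip(token_logprobs, answer_mask))
    den = sum(answer_mask)
    return num / den

# 4 scratch tokens (masked out) followed by 3 answer tokens.
logprobs = [-2.3, -1.9, -2.5, -2.0, -0.1, -0.2, -0.05]
mask     = [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0]
print(round(masked_loss(logprobs, mask), 4))  # → 0.1167
```

The noisy scratch predictions then contribute nothing to the gradient, and only the ab,a,b suffix is supervised.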
About determinism/randomness: the task is fundamentally deterministic, since it has a single answer, so noise won't help; but a non-deterministic network can and will learn to produce deterministic output without problem. In fact, if at any point the network outputs a wrong digit, it won't be able to recover, and the answer will be wrong: the output will look something like ab,a,b where a and b are real primes but with some wrong digits (as if corrupted by a noisy channel). (The exception is if you give it a character like backspace: each time it outputs a wrong digit, one not corresponding to the answer it intended from its initial state, it can output a backspace as the next character and try its luck again.)