OT, but instead of quadratic attention, can we not have n^10 or something crazier? I feel like we are limiting the intelligence just to save cost. But I can imagine there are some questions that would be worth paying a higher cost for.
I feel like n^10 attention can capture patterns that lower-complexity attention may not. So it seems arbitrary that we have n^2 attention.
You can find papers discussing "cubic" attention, i.e. each token gets to interact with each pair of other tokens, but always in very theoretical settings with single-layer transformers on contrived synthetic tasks.
Keep in mind that LLMs have many, many layers, so they have plenty of opportunity to model higher-order interactions without needing to brute-force every possible combination of 10 previous tokens, the vast majority of which will be useless. Empirically, even full "quadratic" attention is not always necessary, as evidenced by the existence of linear/sparse attention variants that perform almost as well.
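To make the second point concrete, here's a toy numpy sketch (sizes and feature map invented for illustration, not any particular paper's method) contrasting full softmax attention, which materializes the n x n score matrix, with a kernelized "linear attention"-style variant that reassociates the matmuls so it never does:

    import numpy as np

    n, d = 512, 64                        # sequence length and head dimension (made-up sizes)
    rng = np.random.default_rng(0)
    Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))

    # Full "quadratic" attention: the n x n score matrix exists explicitly.
    S = Q @ K.T / np.sqrt(d)                               # shape (n, n)
    A = np.exp(S - S.max(axis=-1, keepdims=True))
    out_full = (A / A.sum(axis=-1, keepdims=True)) @ V     # cost ~ O(n^2 * d)

    # Toy linear-attention-style variant: pick a feature map phi and reassociate
    # the matmuls so only d x d intermediates ever exist.  The output differs from
    # the softmax version; the point is the cost, ~ O(n * d^2) instead of O(n^2 * d).
    phi = lambda x: np.maximum(x, 0.0) + 1e-6              # assumed feature map, illustration only
    KV = phi(K).T @ V                                      # shape (d, d)
    Z = phi(Q) @ phi(K).sum(axis=0)                        # shape (n,), normalizer
    out_linear = (phi(Q) @ KV) / Z[:, None]

    print(out_full.shape, out_linear.shape)                # both (n, d)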
Yes, and it works in theory.
Less so in practice. You saturate the memory of a B200 with a few dozen tokens once attention goes above order 4. Training is even worse.
To paraphrase Knuth: high-order polynomials can be more unimaginably large than mere infinity.
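Back-of-the-envelope only (my numbers, not a measurement): an order-k attention has to score every k-tuple of tokens, so the score tensor alone holds n^k entries per head per layer, and that already tells the story:

    # Size of an order-k attention score tensor in fp16, for a single head in a
    # single layer -- no activations, no gradients, no batch dimension.
    def score_tensor_gib(n_tokens, order, bytes_per_entry=2):
        return n_tokens ** order * bytes_per_entry / 2**30

    for n, k in [(1024, 2), (1024, 3), (64, 5), (64, 6)]:
        print(f"n={n:>4}, order={k}: {score_tensor_gib(n, k):8.2f} GiB")
    # n=1024, order=2:     0.00 GiB   <- ordinary attention, a rounding error
    # n=1024, order=3:     2.00 GiB
    # n=  64, order=5:     2.00 GiB
    # n=  64, order=6:   128.00 GiB   <- a few dozen tokens vs. a B200's ~192 GB of HBM,
    #                                    before counting heads, layers, batch, or training state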
This is a common way of thinking. In practice this kind of thing is more about optimizing FLOP allocation. Surely, with an infinite compute and parameter budget, you could have a better model with more intensive operations.
Another thing to consider is that transformers are very general computers. You can encode many, many more complex architectures in simpler, multi-layer transformers.
Aren't layers basically doing n^k attention? The attention block is n^2 because it allows 1 number per input/output pair. But nothing prevents you from stacking these on top of each other and getting a k-th order of "attentioness", with each layer encoding a different order.
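A toy sketch of what I mean (my simplification: the softmax attention weights are replaced by fixed row-stochastic matrices, and value/output projections are dropped):

    import numpy as np

    n = 5
    rng = np.random.default_rng(1)
    # Stand-ins for the attention weights of two stacked layers (row-stochastic n x n
    # matrices; in a real model they come from learned Q/K projections, and the second
    # layer's weights also depend on the first layer's output).
    A1 = rng.random((n, n)); A1 /= A1.sum(axis=-1, keepdims=True)
    A2 = rng.random((n, n)); A2 /= A2.sum(axis=-1, keepdims=True)
    X = rng.standard_normal((n, 3))        # toy token representations

    H2 = A2 @ (A1 @ X)                     # two rounds of mixing
    # Unrolled: H2[i] = sum over pairs (j, k) of A2[i, j] * A1[j, k] * X[k],
    # i.e. a third-order interaction expressed with two n^2 matrices, never an n^3 tensor.
    assert np.allclose(H2, (A2 @ A1) @ X)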
n^2 isn't a setting someone chose; it's a mathematical consequence of what attention is.
Here's what attention does: every token looks at every other token to decide what's relevant. If you have n tokens, and each one looks at n others, you get n * n = n^2 operations.
Put another way: n^2 is when every token gets to look at every other token. What would n^3 be? n^10?
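Concretely (toy numpy; the cubic scoring rule below is invented purely to show the shape of the object you would have to build):

    import numpy as np

    n, d = 6, 4
    rng = np.random.default_rng(2)
    Q = rng.standard_normal((n, d))        # one query vector per token
    K = rng.standard_normal((n, d))        # one key vector per token

    pairwise = Q @ K.T                     # one score per (query token, key token) pair
    print(pairwise.shape)                  # (6, 6): n * n = n^2 numbers, by definition

    # A hypothetical "cubic" attention would need one score per ordered triple of tokens.
    triplewise = np.einsum('id,jd,kd->ijk', Q, K, K)
    print(triplewise.shape)                # (6, 6, 6): n^3 numbers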
(A sibling comment has the same interpretation as you, then handwaves that transformers can emulate more complex systems.)
What you're missing is that there's no need to do extra work in the kernel smoothing step (what attention essentially is) because all the fancy transformation work is already happening in learning the kernel.
The feedforward networks prior to the attention layer are effectively learning sophisticated kernels. For anyone unfamiliar: a kernel is just a generalization of the dot product, which is the most fundamental way of defining "similarity" between two points.
By learning a kernel, the transformer is learning the best way to define what "similar" means for the task at hand, and then we simply apply some basic smoothing over the data. This handles all sorts of interesting ways to compare points, and that comparison allows every point to contribute a little bit of information.
Anything you could hope to achieve by performing more comparisons would be better solved by a better similarity function.
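A toy numpy rendering of that picture (my sketch: random matrices stand in for learned projections, and the value projection is omitted), where the attention output is just a Nadaraya-Watson-style weighted average under a learned kernel:

    import numpy as np

    n, d = 16, 8
    rng = np.random.default_rng(3)
    X = rng.standard_normal((n, d))                  # token representations
    Wq, Wk = rng.standard_normal((2, d, d))          # stand-ins for the learned projections

    def learned_kernel(xi, xj):
        # Similarity under the learned geometry -- a generalization of the plain dot product.
        return np.exp((xi @ Wq) @ (xj @ Wk) / np.sqrt(d))

    # The "attention output" for token i is kernel smoothing: a weighted average over
    # all tokens, with weights given by the learned kernel.
    i = 0
    w = np.array([learned_kernel(X[i], X[j]) for j in range(n)])
    smoothed_i = (w / w.sum()) @ X                   # Nadaraya-Watson-style estimate at token i
    print(smoothed_i.shape)                          # (d,)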