Can anyone explain why Q and K are both needed? They only ever appear as a pair, so why can't you just define a matrix A = QK^T and learn that directly?
Because the attention matrix is N×N, where N is the number of tokens, so its size changes with every input (this is also what makes attention O(N^2)).
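Here is a minimal NumPy sketch of that point. It assumes single-head scaled dot-product attention, and the sizes (`d_model`, `d_head`) and names (`W_q`, `W_k`, `attention_weights`) are made up for illustration. The learned parameters keep a fixed shape, but the attention matrix they produce is N×N and gets recomputed from whatever tokens come in.

```python
import numpy as np

# Hypothetical sizes, chosen only for illustration.
rng = np.random.default_rng(0)
d_model, d_head = 8, 4
W_q = rng.standard_normal((d_model, d_head))  # learned query projection (fixed shape)
W_k = rng.standard_normal((d_model, d_head))  # learned key projection (fixed shape)

def attention_weights(X):
    """Softmax attention weights for tokens X of shape (N, d_model); returns (N, N)."""
    Q = X @ W_q                               # (N, d_head)
    K = X @ W_k                               # (N, d_head)
    scores = Q @ K.T / np.sqrt(d_head)        # (N, N): one score per token pair
    scores -= scores.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    weights = np.exp(scores)
    return weights / weights.sum(axis=1, keepdims=True)

for n_tokens in (3, 10):
    X = rng.standard_normal((n_tokens, d_model))
    print(n_tokens, attention_weights(X).shape)
# prints "3 (3, 3)" and then "10 (10, 10)"
```

The same `W_q` and `W_k` handle 3 tokens or 10, so there is no single N×N matrix you could learn in their place.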
If you don't need to support a variable number of input tokens (e.g. in image processing, where a fixed resolution gives a fixed number of patches), you can learn a fixed routing matrix directly. This is what an MLP-Mixer does (https://arxiv.org/pdf/2105.01601): one layer processes each token in isolation (a "vertical MLP") and ignores inter-token connections, and another layer mixes information across tokens (a "horizontal MLP") while applying the same weights to every feature channel of the tokens.
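A simplified sketch of one Mixer-style block in NumPy follows. It is not the paper's reference code: LayerNorm is omitted, ReLU stands in for GELU, and the sizes and helper names (`init_mlp`, `mlp`, `mixer_block`) are hypothetical. Within each block the paper puts the token-mixing MLP first, and this sketch does the same. Note that the token-mixing weights have `N_TOKENS` baked into their shape.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes: the token count is fixed into the weight shapes.
N_TOKENS, CHANNELS, HIDDEN = 16, 8, 32

def init_mlp(d_in, d_hidden):
    """Weights for a two-layer MLP mapping d_in -> d_hidden -> d_in."""
    return (rng.standard_normal((d_in, d_hidden)) * 0.1,
            rng.standard_normal((d_hidden, d_in)) * 0.1)

def mlp(x, params):
    """Apply the MLP to each row of x (ReLU used instead of GELU for brevity)."""
    w1, w2 = params
    return np.maximum(x @ w1, 0.0) @ w2

token_mlp = init_mlp(N_TOKENS, HIDDEN)    # "horizontal": mixes across tokens
channel_mlp = init_mlp(CHANNELS, HIDDEN)  # "vertical": acts on each token alone

def mixer_block(X):
    """One simplified Mixer block; X has shape (N_TOKENS, CHANNELS)."""
    # Token mixing: transpose so each channel becomes a length-N_TOKENS vector;
    # the same token_mlp weights are shared by every channel.
    X = X + mlp(X.T, token_mlp).T
    # Channel mixing: the same channel_mlp is applied to every token row.
    X = X + mlp(X, channel_mlp)
    return X

X = rng.standard_normal((N_TOKENS, CHANNELS))
print(mixer_block(X).shape)  # (16, 8)
```

The trade-off shows up right away: feeding 20 tokens into `mixer_block` raises a shape error. `token_mlp` was built for exactly 16 tokens, whereas the attention sketch's weights never mention N at all.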