Hacker News

soVeryTired today at 4:53 PM

Can anyone explain to me why Q and K are both needed? They only ever appear as a pair, so why can’t you just define a matrix A = QK and learn that directly?


Replies

mattalex today at 5:19 PM

Because the size of the attention matrix depends on the number of tokens (this is what makes attention N^2). If you don't care about supporting a flexible number of input tokens (e.g. in image processing), you can learn a fixed routing matrix instead. This is known as an MLP-Mixer (https://arxiv.org/pdf/2105.01601): one layer processes each token in isolation ("vertical MLP") and ignores the inter-token connections, followed by a layer that mixes across tokens ("horizontal MLP") and treats the internals of every token identically.
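To make the shape argument concrete, here is a rough sketch in plain Python (no libraries). All the names (`attention_scores`, `fixed_token_mixing`, `w_q`, `w_k`, `routing`) and the toy dimensions are made up for illustration, and softmax and scaling are left out. The point is only that the score matrix is recomputed from the input, so it comes out N x N for whatever N you feed in, while a learned routing matrix is tied to a single token count.

```python
import random


def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    """Swap rows and columns."""
    return [list(col) for col in zip(*m)]


def rand_matrix(rows, cols, rng):
    """Random matrix standing in for learned weights or token embeddings."""
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def attention_scores(x, w_q, w_k):
    """Unnormalized scores (X W_q)(X W_k)^T; the result has shape (N, N)."""
    q = matmul(x, w_q)
    k = matmul(x, w_k)
    return matmul(q, transpose(k))


def fixed_token_mixing(x, routing):
    """Mix tokens with a fixed (N, N) matrix that does not look at the content."""
    n = len(x)
    if len(routing) != n or any(len(row) != n for row in routing):
        raise ValueError("routing matrix was built for a different token count")
    return matmul(routing, x)


rng = random.Random(0)
d_model, d_head = 4, 2                   # toy sizes, chosen arbitrarily
w_q = rand_matrix(d_model, d_head, rng)  # shape does not depend on N
w_k = rand_matrix(d_model, d_head, rng)  # shape does not depend on N
routing = rand_matrix(3, 3, rng)         # only valid for exactly 3 tokens

short_seq = rand_matrix(3, d_model, rng)  # 3 tokens
long_seq = rand_matrix(5, d_model, rng)   # 5 tokens

scores_short = attention_scores(short_seq, w_q, w_k)  # 3 x 3
scores_long = attention_scores(long_seq, w_q, w_k)    # 5 x 5, same weights
mixed_short = fixed_token_mixing(short_seq, routing)  # 3 x 4
```

The same `w_q` and `w_k` handle both sequence lengths, whereas calling `fixed_token_mixing(long_seq, routing)` raises a `ValueError` because `routing` was sized for 3 tokens.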