Intuition regarding why the output of "scaled dot-product" attention represents similarity between tokens

Here is what I think about why the output of "scaled dot-product" attention reflects how much each key token affects each query token:
Q matrix has shape (m, n) where m is the number of tokens and n is the embedding dimension.
K matrix has shape (m, n) where m is the number of tokens and n is the embedding dimension.
V matrix has shape (m, n) where m is the number of tokens and n is the embedding dimension.

So when we do Q * K.T, we are effectively taking the dot product of each token's query with every token's key: each row of Q is the query for one token, and each column of K.T (length n) holds all the features of one token's key. The result is an (m, m) matrix.
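A minimal NumPy sketch of this step (the sizes and variable names are illustrative; the division by sqrt(n) is the "scaled" part of scaled dot-product attention):

```python
import numpy as np

m, n = 4, 8                       # m tokens, n embedding dimensions
rng = np.random.default_rng(0)

Q = rng.normal(size=(m, n))       # one query vector per token (rows)
K = rng.normal(size=(m, n))       # one key vector per token (rows)

# Row i of `scores` holds the dot product of query i with every key.
scores = Q @ K.T / np.sqrt(n)
print(scores.shape)               # (4, 4) -> (m, m)
```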

Then we do another matrix multiplication between (Q * K.T), which is of shape (m, m), and V, which is of shape (m, n), to get an (m, n) matrix.
In this matrix each row represents a token, and each value represents how much the corresponding embedding feature contributes to that token's output.
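A self-contained sketch of both steps (again with made-up sizes; note that standard scaled dot-product attention applies a row-wise softmax to the (m, m) scores before multiplying by V, which the description above leaves implicit):

```python
import numpy as np

m, n = 4, 8
rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(m, n)) for _ in range(3))

scores = Q @ K.T / np.sqrt(n)                  # (m, m) alignment scores
weights = np.exp(scores)                       # row-wise softmax turns each row
weights /= weights.sum(axis=1, keepdims=True)  # into weights that sum to 1
output = weights @ V                           # (m, n) weighted mix of value rows
print(output.shape)                            # (4, 8)
```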

So now we can use the final (m, n) matrix to know how much each token affects every other token, simply by calculating some sort of similarity between the rows, since tokens that are similar would have similar values for their embedding features.

Is this way of thinking correct?

Hi @God_of_Calamity

I believe your overall understanding is correct. I have some doubts about your wording, so I will elaborate on that.

(Q * K.T), the (m, m) matrix, is the "alignment" score, so that is the matrix that tells you how much each token "influences" every other token (your term; I would say it "weighs" how much of each value each token aggregates).

When this "alignment" score is matrix multiplied by V, we have the final "aggregated" value of the attention. (If the first token attends 100% to itself, then the first row of the final (m, n) matrix is an exact copy of the first row of V.) This is pretty much "attention".
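A tiny sketch of that parenthetical case (the weight values are made up): if the first row of the already-normalized attention weights is one-hot on the first token, the first output row is an exact copy of the first row of V.

```python
import numpy as np

rng = np.random.default_rng(0)
V = rng.normal(size=(3, 4))                 # 3 tokens, 4 features

weights = np.array([[1.0, 0.0, 0.0],        # token 0 attends 100% to itself
                    [0.2, 0.5, 0.3],
                    [0.1, 0.1, 0.8]])

output = weights @ V
print(np.allclose(output[0], V[0]))         # True: row 0 copies V's first row
```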

Usually things don’t end here: there is a residual connection and a normalization layer, after which a feed-forward layer "acts" on these outputs (or "thinks" about what to do with what the attention "added" to the original inputs).
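For completeness, a rough sketch of those extra sub-layers (a simplified post-norm style block with a ReLU feed-forward; the function names, sizes, and exact norm placement are illustrative assumptions, not a particular library's API):

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize each token's features to zero mean and unit variance.
    mean = x.mean(axis=-1, keepdims=True)
    std = x.std(axis=-1, keepdims=True)
    return (x - mean) / (std + eps)

def post_attention_sublayers(x, attn_out, W1, b1, W2, b2):
    # Residual connection + normalization around the attention output...
    h = layer_norm(x + attn_out)
    # ...then a position-wise feed-forward layer that "acts" on the result,
    # again wrapped in a residual connection + normalization.
    ff = np.maximum(0.0, h @ W1 + b1) @ W2 + b2
    return layer_norm(h + ff)

m, n, d_ff = 4, 8, 32
rng = np.random.default_rng(0)
x = rng.normal(size=(m, n))                  # original token embeddings
attn_out = rng.normal(size=(m, n))           # stand-in for the attention output
W1, b1 = rng.normal(size=(n, d_ff)), np.zeros(d_ff)
W2, b2 = rng.normal(size=(d_ff, n)), np.zeros(n)
print(post_attention_sublayers(x, attn_out, W1, b1, W2, b2).shape)   # (4, 8)
```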

Cheers