Question about MultiHeadAttention layer

I have some confusion regarding the MultiHeadAttention layer. When we call an MHA layer with multiple heads and query, key, and value all equal to X, is it (ignoring the mask and batches, and using the lecture's notation) equivalent to computing, for each head,

Attention(W_i^Q X, W_i^K X, W_i^V X) = softmax\left(\frac{W_i^Q X (W_i^K X)^T}{\sqrt{d_k}}\right) W_i^V X,

where the weights W_i vary depending on the head?
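
For concreteness, here is the call pattern I have in mind (a minimal sketch assuming the Keras MultiHeadAttention layer; the sizes are made up for illustration):

```python
import tensorflow as tf

# Made-up sizes, just for illustration
batch_size, timesteps, d_model, num_heads = 2, 10, 512, 8

mha = tf.keras.layers.MultiHeadAttention(num_heads=num_heads,
                                         key_dim=d_model // num_heads)
X = tf.random.normal((batch_size, timesteps, d_model))

# Self-attention: query, key and value are all X
out = mha(query=X, value=X, key=X)
print(out.shape)  # (2, 10, 512)
```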

Let me think about this one more time :slight_smile:

From the paper:

Now, with a little bit of ingenuity, for the query, we can construct

\begin{bmatrix} Q W_1^Q & Q W_2^Q & \cdots & Q W_h^Q \end{bmatrix} = Q \begin{bmatrix} W_1^Q & W_2^Q & \cdots & W_h^Q \end{bmatrix} = Q W^Q

where Q \in \mathbb{R}^{m \times \text{seq_len_q} \times d_{\text{model}}} and W_i^Q \in \mathbb{R}^{d_\text{model} \times d_k} with d_k = d_{\text{model}} / h according to the paper.

Consequently, we see that W^Q must have the shape (d_model, d_k * h) = (d_model, d_model / h * h) = (d_model, d_model).

If we split the output of the linear transformation into h heads, it has the same effect as calculating h heads independently and then proceeding. However, computation-wise, it is more efficient to use only one dense layer and split into h heads after the linear transformation, instead of splitting first and doing h separate linear transformations afterward.

Q W^Q has shape (m, seq_len_q, d_model) * (d_model, d_model) = (m, seq_len_q, d_model), i.e., the shape of Q again.
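
A small numerical check of this (a sketch with made-up sizes, not code from the assignment): projecting with the full W^Q and then splitting the last axis into h heads gives the same values as projecting with each W_i^Q separately.

```python
import tensorflow as tf

m, seq_len_q, d_model, h = 2, 5, 16, 4
d_k = d_model // h

Q = tf.random.normal((m, seq_len_q, d_model))
W_q = tf.random.normal((d_model, d_model))        # concatenation of all W_i^Q

# One projection with the full W^Q, then split the last axis into h heads
full = tf.einsum('bqd,de->bqe', Q, W_q)           # (m, seq_len_q, d_model)
split = tf.reshape(full, (m, seq_len_q, h, d_k))  # (m, seq_len_q, h, d_k)

# h separate projections, one per head, using the column slices W_i^Q of W^Q
per_head = tf.stack(
    [tf.einsum('bqd,dk->bqk', Q, W_q[:, i * d_k:(i + 1) * d_k]) for i in range(h)],
    axis=2)                                       # (m, seq_len_q, h, d_k)

print(bool(tf.reduce_all(tf.abs(split - per_head) < 1e-5)))  # True
```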

1 Like

So using the paper’s notation instead, head_i = softmax(\frac{QW_i^Q(KW_i^K)^T}{\sqrt{d_k}})VW_i^V, and if Q = K = V = X, the shape of X should be (batch_size, timesteps, features)?

Regarding d_k: in the assignment, K had shape (..., seq_len_k, depth); should d_k = depth?

1 Like

Yes :slight_smile:

d_k is your last dimension. In this case it is features / heads.

If you look at scaled_dot_product_attention, it is not multi-head attention, so in this case it may be easier to think that you only operate on one head. In that case we are not passed (batch_size, timesteps, features), but (batch_size, timesteps, d_k), as you noticed.
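
For reference, the usual reshape-and-transpose step that turns (batch_size, timesteps, features) into per-head inputs of shape (batch_size, num_heads, timesteps, d_k) looks roughly like this (a sketch, not the assignment code):

```python
import tensorflow as tf

def split_heads(x, num_heads):
    """Sketch: (batch, timesteps, features) -> (batch, num_heads, timesteps, d_k)."""
    batch_size, timesteps, features = x.shape
    d_k = features // num_heads                        # features must divide evenly
    x = tf.reshape(x, (batch_size, timesteps, num_heads, d_k))
    return tf.transpose(x, perm=[0, 2, 1, 3])          # heads become a batch-like axis

x = tf.random.normal((2, 10, 64))                      # features = 64
print(split_heads(x, num_heads=8).shape)               # (2, 8, 10, 8), so d_k = 64 / 8 = 8
```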

1 Like

So if a tensor has shape (dim_1, dim_2, dim_3), in the case of softmax(\frac{QK^T}{\sqrt{d_k}})V, d_k = dim_3 of K, and in the case of softmax(\frac{XW_i^Q(XW_i^K)^T}{\sqrt{d_k}})XW_i^V, d_k = dim_3 of X?

@LuBinLiu, the above statement is actually wrong :sweat_smile: I have had more time to look at this issue today.

The code works both for multi-head attention and for the case where you don’t use multiple heads. Matrix multiplication is performed on the last two dimensions only, so the total number of dimensions can be 3 or 4; the function works the same either way and is happy to process your inputs. I have created my own version below in which you can see the shapes explicitly:

solution removed

If you don’t use tf.matmul(..., transpose_b=True) but tf.matmul(Q, tf.transpose(K, perm=(0, 1, 3, 2))) or tf.matmul(Q, tf.transpose(K, perm=(0, 2, 1))), then you actually fix it to work only for multi-head attention in the first case, and only for a single head in the second case. The key piece of code that makes the function generalize is matmul with transpose_b=True.
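
Here is a generic sketch of scaled dot-product attention (not the removed solution; the masking convention is an assumption) showing why tf.matmul(..., transpose_b=True) lets the same code handle both 3-D and 4-D inputs:

```python
import tensorflow as tf

def scaled_dot_product_attention_sketch(q, k, v, mask=None):
    """Generic sketch: works for inputs of shape (..., seq_len, d_k) with 3 or 4 dims."""
    # (..., seq_len_q, d_k) x (..., d_k, seq_len_k) -> (..., seq_len_q, seq_len_k)
    scores = tf.matmul(q, k, transpose_b=True)
    d_k = tf.cast(tf.shape(k)[-1], tf.float32)         # last dimension of K
    scaled = scores / tf.math.sqrt(d_k)
    if mask is not None:
        scaled += (1.0 - mask) * -1e9                  # assumed masking convention
    weights = tf.nn.softmax(scaled, axis=-1)           # (..., seq_len_q, seq_len_k)
    return tf.matmul(weights, v), weights              # (..., seq_len_q, d_v)

# Single head: (batch, timesteps, d_k)
x3 = tf.random.normal((2, 5, 8))
out3, _ = scaled_dot_product_attention_sketch(x3, x3, x3)

# Multi-head: (batch, num_heads, timesteps, d_k)
x4 = tf.random.normal((2, 4, 5, 8))
out4, _ = scaled_dot_product_attention_sketch(x4, x4, x4)
print(out3.shape, out4.shape)  # (2, 5, 8) (2, 4, 5, 8)
```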

1 Like

Yes. Your X acts as Q, K, and V, and the last dimension is the embedding-size dimension. Depending on how you use the function, the last dimension will be d_model or d_k. A separate d_v could be used, but the Transformer paper sets d_v = d_k, so it is not used as a separate value.

1 Like

Is there a reason that in the case of Attention(XW_i^Q, XW_i^K, XW_i^V), d_k is the last dimension of X as opposed to the last dimension of the product XW_i^K?

1 Like

The last dimension of X in your formula is d_model. The dimensions of the weights are (d_model, d_k), so you end up with a slice of the embedding dimension, which has size d_k.

The authors of the paper illustrate that you use different heads, and the weight matrices help you slice out a head.
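
A quick shape check of that slicing, using the paper’s sizes (d_model = 512, h = 8, so d_k = 64):

```python
import tensorflow as tf

d_model, h = 512, 8
d_k = d_model // h                            # 64, per the paper

X = tf.random.normal((2, 10, d_model))        # (batch, seq_len, d_model)
W_k_i = tf.random.normal((d_model, d_k))      # one head's weight matrix W_i^K

K_i = tf.einsum('bsd,dk->bsk', X, W_k_i)      # (2, 10, 64): a d_k-sized slice
print(K_i.shape)
```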

So if X has shape (..., d_{model}) and XW_i^K has shape (..., d_k), why not scale XW_i^Q(XW_i^K)^T inside the softmax of Attention(...) by \frac{1}{\sqrt{d_k}} as opposed to \frac{1}{\sqrt{d_{model}}}?

1 Like

In that case you should scale by 1/sqrt(d_k). Where do you see d_model? In the code, I reference the last dimension of K, which probably is d_k. But if you run attention without the linear transformation, i.e. on X and not XW_i^K, then you actually end up with d_k = d_model in the code.

1 Like