I have some confusion regarding the MultiHeadAttention layer. When we call an MHA layer with multiple heads and query, key, and value all equal to X, using the lecture's notation, is this equivalent to something of this form (ignoring the mask and batches):
\text{head}_i = \text{Attention}(W_i^Q X, W_i^K X, W_i^V X) = \text{softmax}\left(\frac{W_i^Q X (W_i^K X)^T}{\sqrt{d_k}}\right) W_i^V X,
where the weights W_i vary depending on the head?
From the paper: \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) W^O, where \text{head}_i = \text{Attention}(Q W_i^Q, K W_i^K, V W_i^V).
Now, with a little bit of ingenuity, for the query, we can construct
\begin{bmatrix} Q W_1^Q & Q W_2^Q & \cdots & Q W_h^Q \end{bmatrix} = Q \begin{bmatrix} W_1^Q & W_2^Q & \cdots & W_h^Q \end{bmatrix} = Q W^Q
where Q \in \mathbb{R}^{m \times \text{seq_len_q} \times d_{\text{model}}} and W_i^Q \in \mathbb{R}^{d_\text{model} \times d_k} with d_k = d_{\text{model}} / h according to the paper.
Consequently, we see that W^Q must have the shape (d_model, d_k * h) = (d_model, d_model / h * h) = (d_model, d_model).
If we split the output of the linear transformation into h heads, it has the same effect as calculating the h heads independently and then proceeding. However, computation-wise it is more efficient to use a single dense layer and split into h heads after the linear transformation, instead of splitting first and doing h separate linear transformations afterward.
Q W^Q has shape (m, seq_len_q, d_model) * (d_model, d_model) = (m, seq_len_q, d_model), i.e., the shape of Q again.
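To make the equivalence concrete, here is a minimal sketch (all shapes and variable names are illustrative, not the assignment's): one (d_model, d_model) projection whose output is split into h heads matches h independent (d_model, d_k) projections using the corresponding column slices of the big weight matrix.

```python
import tensorflow as tf

# Illustrative shapes only.
batch, seq_len, d_model, h = 2, 5, 8, 4
d_k = d_model // h

X = tf.random.normal((batch, seq_len, d_model))
W_Q = tf.random.normal((d_model, d_model))  # = [W_1^Q  W_2^Q ... W_h^Q] concatenated column-wise

# One dense projection, then split the last dimension into h heads of size d_k.
projected = tf.matmul(X, W_Q)                                  # (batch, seq_len, d_model)
split_heads = tf.reshape(projected, (batch, seq_len, h, d_k))  # (batch, seq_len, h, d_k)

# h independent projections with the corresponding column slices W_i^Q.
per_head = tf.stack(
    [tf.matmul(X, W_Q[:, i * d_k:(i + 1) * d_k]) for i in range(h)],
    axis=2,
)                                                              # (batch, seq_len, h, d_k)

# Identical up to floating-point error.
print(tf.reduce_max(tf.abs(split_heads - per_head)).numpy())
```

In an actual multi-head attention implementation you would additionally transpose to (batch, h, seq_len, d_k) before calling the attention function, but that does not change the equivalence.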
So using the paper’s notation instead, head_i = softmax(\frac{QW_i^Q(KW_i^K)^T}{\sqrt{d_k}})VW_i^V, and if Q = K = V = X, the shape of X should be (batch_size, timesteps, features)?
Regarding d_k: in the assignment, the key had shape (..., seq_len_k, depth). Should d_k = depth?
Yes
d_k is your last dimension. In this case it is features / heads.
If you look at scaled_dot_product_attention, it is not multi-head attention. So in this case maybe it is easier to think that you only operate on one head. In that case the function is not passed (batch_size, timesteps, features) but (batch_size, timesteps, d_k), as you noticed.
So if a tensor has shape (dim_1, dim_2, dim_3), in the case of softmax(\frac{QK^T}{\sqrt{d_k}})V, d_k = dim_3 of K, and in the case of softmax(\frac{XW_i^Q(XW_i^K)^T}{\sqrt{d_k}})XW_i^V, d_k = dim_3 of X?
@LuBinLiu, the above statement is actually wrong; I have had more time to look at this issue today.
The code works for both multi-head attention and when you don’t supply multiple heads. Matrix multiplication is performed on the last 2 dimensions only. The total number of dims can be 3 or 4. The function works the same and is happy to process your inputs. I have created my own version below in which you can see the shapes explicitly:
solution removed
If you don’t use tf.matmul(..., transpose_b=True) but tf.matmul(Q, tf.transpose(K, perm=(0, 1, 3, 2))) or tf.matmul(Q, tf.transpose(K, perm=(0, 2, 1))), then you actually fix it to work only for multi-head attention in the first case and only for a single head in the second case. The key piece of code that makes the function generalize is matmul with transpose_b=True.
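To illustrate that point, here is a minimal generic sketch (my own names and shapes, not the removed code above): tf.matmul only touches the last two dimensions, so the same function accepts 3-D (single-head) and 4-D (multi-head) inputs unchanged.

```python
import tensorflow as tf

def scaled_dot_product_attention(q, k, v, mask=None):
    """Generic sketch: works for (..., seq_len, d_k) with any number of leading dims."""
    # (..., seq_len_q, d_k) x (..., seq_len_k, d_k)^T -> (..., seq_len_q, seq_len_k)
    scores = tf.matmul(q, k, transpose_b=True)
    d_k = tf.cast(tf.shape(k)[-1], tf.float32)
    scaled = scores / tf.math.sqrt(d_k)
    if mask is not None:
        scaled += mask * -1e9                    # masked positions -> ~0 after softmax
    weights = tf.nn.softmax(scaled, axis=-1)     # (..., seq_len_q, seq_len_k)
    return tf.matmul(weights, v)                 # (..., seq_len_q, d_v)

x3 = tf.random.normal((2, 5, 16))        # (batch, timesteps, d_k): single head
x4 = tf.random.normal((2, 4, 5, 16))     # (batch, heads, timesteps, d_k): multi-head
print(scaled_dot_product_attention(x3, x3, x3).shape)   # (2, 5, 16)
print(scaled_dot_product_attention(x4, x4, x4).shape)   # (2, 4, 5, 16)
```

A hard-coded perm in tf.transpose pins the function to one specific rank, which is exactly what breaks the single-head / multi-head generality.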
Yes. Your X acts as Q, K, and V, and the last dimension is the embedding-size dimension. Depending on how you use the function, the last dimension will be d_model or d_k. d_v could be different, but it is not used as a separate value in the Transformer paper (there d_v = d_k).
Is there a reason that in the case of Attention(XW_i^Q, XW_i^K, XW_i^V), d_k is the last dimension of X as opposed to the last dimension of the product XW_i^K?
The last dimension of X in your formula is d_model. The weights have shape (d_model, d_k), so you end up with a slice of the embedding dimension, which is d_k. The authors of the paper illustrate that you use different heads, and the weight matrices help you slice out a head.
So if X has shape (..., d_{model}) and XW_i^K has shape (..., d_k), why not scale XW_i^Q(XW_i^K)^T inside the softmax of Attention(...) by \frac{1}{\sqrt{d_k}} as opposed to \frac{1}{\sqrt{d_{model}}}?
In that case you should scale by 1/\sqrt{d_k}. Where do you see d_model? In the code, I reference the last dimension of K, which probably is d_k. But if you run attention without the linear transformation, i.e. on X and not on XW_i^K, then you actually end up with d_k = d_model in the code.
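A quick sketch of that last point (illustrative shapes only): the scale is always the square root of the last dimension of whatever you pass in as K, so skipping the projection simply makes that dimension d_model.

```python
import tensorflow as tf

# Illustrative shapes only.
d_model, h = 8, 4
d_k = d_model // h
X = tf.random.normal((2, 5, d_model))
W_K = tf.random.normal((d_model, d_k))

# Scale used when attending over X directly (no projection): sqrt(d_model).
scale_without_projection = tf.math.sqrt(tf.cast(tf.shape(X)[-1], tf.float32))
# Scale used when attending over X W_i^K: sqrt(d_k).
scale_with_projection = tf.math.sqrt(tf.cast(tf.shape(tf.matmul(X, W_K))[-1], tf.float32))

print(scale_without_projection.numpy())  # sqrt(8)
print(scale_with_projection.numpy())     # sqrt(2)
```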