C5 W4 A1: Question about MultiHeadAttention

In the Week 4 Transformers programming assignment, we create an EncoderLayer class. As the first step in its call() method, it passes the input sentence to the instantiated MultiHeadAttention object, i.e.:

# calculate self-attention using mha(~1 line)
self_attn_output = self.mha(x, x, attention_mask=mask)

As I understand it, x (the input sentence) appears twice because we are doing self-attention in this step: x is paying attention to x.
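To make the self-attention step concrete, here is a minimal NumPy sketch of single-head scaled dot-product self-attention. It is an illustration, not the Keras implementation: `W_q`, `W_k` and `W_v` are hypothetical names for projection matrices, randomly initialized here to stand in for the weights the layer learns. The same `x` goes in three times, but each copy is multiplied by a different matrix, so Q, K and V come out different.

```python
import numpy as np

def softmax(z, axis=-1):
    # subtract the row max for numerical stability
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(x, W_q, W_k, W_v):
    """Single-head scaled dot-product self-attention (illustrative sketch)."""
    q = x @ W_q                                  # (seq_len, d_k)
    k = x @ W_k                                  # (seq_len, d_k)
    v = x @ W_v                                  # (seq_len, d_v)
    scores = q @ k.T / np.sqrt(k.shape[-1])      # (seq_len, seq_len)
    weights = softmax(scores, axis=-1)           # each row sums to 1
    return weights @ v, weights                  # (seq_len, d_v), (seq_len, seq_len)

rng = np.random.default_rng(0)
seq_len, d_model, d_k = 4, 8, 8
x = rng.normal(size=(seq_len, d_model))          # one "sentence" of 4 token embeddings
# hypothetical learned weights, random here for illustration
W_q, W_k, W_v = (rng.normal(size=(d_model, d_k)) for _ in range(3))

out, weights = self_attention(x, W_q, W_k, W_v)
print(out.shape, weights.shape)                  # (4, 8) (4, 4)
print(np.allclose(x @ W_q, x @ W_k))             # False: same input, different projections
```

So "x attends to x" means every token's query is compared against every token's key, where queries and keys are two different learned views of the same sentence.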

What I do not understand is how this corresponds to the q, k, v that are meant to be fed into the Attention layer. When we come to the decoder layer, we use this statement:

attn1, attn_weights_block1 = self.mha1(x, x, x, look_ahead_mask, return_attention_scores=True)

Here, the same input seems to be used for q, k and v. And in the second decoder block, we have:

attn2, attn_weights_block2 = self.mha2(out1, enc_output, enc_output, padding_mask, return_attention_scores=True)

This implies q is out1, k is enc_output and v is also enc_output. How can the same thing be k and v?
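The same idea carries over to the second decoder block, just with two different sources. In the NumPy sketch below (single head, hypothetical random weights, not the Keras internals), the queries come from the decoder side (`out1`) and both keys and values come from `enc_output`. Keys and values still end up as different tensors because `enc_output` is multiplied by two different matrices: roughly, the keys decide where each decoder position looks, and the values decide what gets pulled from there. The decoder and encoder sequence lengths also do not have to match.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention(query_src, kv_src, W_q, W_k, W_v):
    """Queries from one sequence, keys/values from another (illustrative sketch)."""
    q = query_src @ W_q                                  # (T_dec, d_k)
    k = kv_src @ W_k                                     # (T_enc, d_k)
    v = kv_src @ W_v                                     # (T_enc, d_v): same source as k, different weights
    weights = softmax(q @ k.T / np.sqrt(k.shape[-1]))    # (T_dec, T_enc)
    return weights @ v, weights                          # (T_dec, d_v)

rng = np.random.default_rng(1)
d_model = 8
out1 = rng.normal(size=(3, d_model))          # decoder side: 3 target positions
enc_output = rng.normal(size=(5, d_model))    # encoder side: 5 source positions
# hypothetical learned weights, random here for illustration
W_q, W_k, W_v = (rng.normal(size=(d_model, d_model)) for _ in range(3))

attn2, weights = cross_attention(out1, enc_output, W_q, W_k, W_v)
print(attn2.shape, weights.shape)             # (3, 8) (3, 5)
```

Each of the 3 decoder positions gets a weighted mix of the 5 encoder positions, so the output has the decoder's length but carries information from the encoder.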

What am I missing here? How does the need for q, k and v correspond to what we are actually feeding to these MultiHeadAttention layers?

Thanks for any enlightenment!

Julian


I edited the thread title to include the week and assignment numbers. It helps the mentors find the right notebook to use for reference.

That’s not correct. For self-attention, you need to use ‘x’ three times (for each of k, q, and v), then the mask.
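The mask decides which key positions each query is allowed to attend to. For the look-ahead mask in the decoder, a minimal NumPy sketch might look like the following. It assumes the convention described for Keras' `attention_mask` (1 means "may attend", 0 means "blocked") and uses the common trick of replacing blocked scores with a large negative number before the softmax; `look_ahead_mask` and `masked_softmax` are hypothetical helper names, not functions from the assignment.

```python
import numpy as np

def look_ahead_mask(size):
    # lower-triangular ones: position i may attend to positions 0..i only
    return np.tril(np.ones((size, size)))

def masked_softmax(scores, mask):
    # blocked positions (mask == 0) get a huge negative score, hence ~0 weight
    scores = np.where(mask == 1, scores, -1e9)
    scores = scores - scores.max(axis=-1, keepdims=True)
    e = np.exp(scores)
    return e / e.sum(axis=-1, keepdims=True)

mask = look_ahead_mask(4)
scores = np.zeros((4, 4))            # equal raw scores, to isolate the mask's effect
weights = masked_softmax(scores, mask)
# row i spreads its weight evenly over positions 0..i and gives 0 to later positions
print(np.round(weights, 3))
```

This is why the decoder's first block cannot "peek" at future tokens even though x is passed in as query, key and value.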

Basically, I don't understand why passing the same tensor as query, key, and value makes this function perform self-attention. As far as I understand, x in self.mha(x, x, x, mask) should be a placeholder for the arguments fed in at runtime, right? However, q and v cannot be the same kind of array, since they will have different dimensions.
I also don't follow what happens when the Encoder calls the EncoderLayer:

x = self.enc_layers[i](x, training, mask)

Supposedly, the input x should be composed of q, k, and v, right?
How does it figure out which parts of x represent q, k, and v when we just pass the same argument for all of them?
I looked at the MultiHeadAttention documentation but couldn't really figure out how it works.
Any hints, or a good source to read?
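On "which parts of x represent q, k and v": in the standard formulation, x is not carved into three pieces. All of x goes through each of three projections, and what gets split is the projected feature axis, into heads. The NumPy sketch below is a conceptual illustration, not the Keras code; `W_q`, `W_k`, `W_v`, `W_o` and `split_heads` are hypothetical names. The projection widths play roughly the role of the `key_dim` and `value_dim` arguments of `tf.keras.layers.MultiHeadAttention`, which also shows that queries/keys and values can have different sizes even when they all come from the same x.

```python
import numpy as np

def split_heads(t, num_heads):
    # (seq_len, num_heads * head_dim) -> (num_heads, seq_len, head_dim)
    seq_len, width = t.shape
    return t.reshape(seq_len, num_heads, width // num_heads).transpose(1, 0, 2)

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_self_attention(x, W_q, W_k, W_v, W_o, num_heads):
    # x is NOT split into q/k/v parts: the whole of x goes through each projection
    q = split_heads(x @ W_q, num_heads)                         # (h, T, key_dim)
    k = split_heads(x @ W_k, num_heads)                         # (h, T, key_dim)
    v = split_heads(x @ W_v, num_heads)                         # (h, T, value_dim)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])    # (h, T, T)
    heads = softmax(scores) @ v                                 # (h, T, value_dim)
    # concatenate the heads along the feature axis, then mix them with W_o
    concat = heads.transpose(1, 0, 2).reshape(x.shape[0], -1)   # (T, h * value_dim)
    return concat @ W_o                                         # (T, d_model)

rng = np.random.default_rng(2)
T, d_model, h, key_dim, value_dim = 5, 16, 4, 8, 6
x = rng.normal(size=(T, d_model))
# hypothetical learned weights, random here for illustration
W_q = rng.normal(size=(d_model, h * key_dim))
W_k = rng.normal(size=(d_model, h * key_dim))
W_v = rng.normal(size=(d_model, h * value_dim))   # values may use a different width
W_o = rng.normal(size=(h * value_dim, d_model))

out = multi_head_self_attention(x, W_q, W_k, W_v, W_o, h)
print(out.shape)                                  # (5, 16)
```

The output has the same shape as x, which is what lets the Encoder feed each EncoderLayer's output straight into the next one in its loop.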

Aff
