Intuition about the application of padding masks and look-ahead masks in Transformer's encoder/decoder

From the TensorFlow tutorial, the shape of the padding mask is (batch_size, 1, 1, seq_len) and the shape of the look-ahead mask is (batch_size, 1, seq_len, seq_len); both are fed into the scaled_dot_product_attention function along with q, k, and v.
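For reference, this is roughly how I understand the two masks are built (a sketch following the tutorial's convention that 1 marks a position to be masked out; the exact code in the tutorial may differ slightly):

import tensorflow as tf

def create_padding_mask(seq):
    # seq: (batch_size, seq_len) of integer tokens, where 0 is the pad id
    mask = tf.cast(tf.math.equal(seq, 0), tf.float32)
    # add broadcast axes for num_heads and seq_len_q
    return mask[:, tf.newaxis, tf.newaxis, :]  # (batch_size, 1, 1, seq_len)

def create_look_ahead_mask(size):
    # 1s strictly above the diagonal: position i may not attend to j > i
    return 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)  # (seq_len, seq_len)

The look-ahead mask itself is (seq_len, seq_len); as far as I can tell, the (batch_size, 1, seq_len, seq_len) shape comes from combining it with the target padding mask, which broadcasts in the batch dimension.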

For an input X with shape (batch_size, seq_len) of integer tokens that is then converted to an embedding, the resulting tensor, let's call it X', has shape (batch_size, seq_len, d_model). In the call method of the MultiHeadAttention class of the tutorial we have:

def call(self, v, k, q, mask):
   batch_size = tf.shape(q)[0]

   q = self.wq(q)  # (batch_size, seq_len, d_model)
   k = self.wk(k)  # (batch_size, seq_len, d_model)
   v = self.wv(v)  # (batch_size, seq_len, d_model)

   q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
   k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
   v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)

   scaled_attention, attention_weights = scaled_dot_product_attention(
    q, k, v, mask)

   ....

If q = k = v = X' are input into the above call, then the inputs to scaled_dot_product_attention in the last line would be tensors of shape (batch_size, num_heads, seq_len, depth). Then, in the first line of scaled_dot_product_attention:

matmul_qk = tf.matmul(q, k, transpose_b=True)

would result in matmul_qk having a shape of (batch_size, num_heads, seq_len, seq_len). This tensor is scaled by the square root of depth to give scaled_attention_logits, which then has the mask applied via:

if mask is not None:
   scaled_attention_logits += (mask * -1e9)
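Just to make the broadcasting concrete, here is a quick shape check I put together (arbitrary shapes, purely illustrative):

import tensorflow as tf

batch_size, num_heads, seq_len = 2, 8, 10
scaled_attention_logits = tf.zeros((batch_size, num_heads, seq_len, seq_len))
padding_mask = tf.zeros((batch_size, 1, 1, seq_len))  # placeholder with the padding-mask shape
masked = scaled_attention_logits + padding_mask * -1e9
print(masked.shape)  # (2, 8, 10, 10): the mask broadcasts over heads and query positions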

I understand that the shape of the masks matches (via broadcasting) the shape of matmul_qk above, but after all of the transformations of the original X, I'm having a hard time visualizing how the padding and look-ahead masks do what they are intended to do. For instance, how does the (batch_size, 1, 1, seq_len) padding mask, created from the padded 0s in the original input X of integer tokens, end up masking the padded positions in matmul_qk?

Hi @LuBinLiu ,
To make it easy to understand, let's simplify the problem: say seq_len_q = 5, seq_len_k = seq_len_v = 6, batch_size = num_heads = 1, and the example sentence has 4 real tokens, e.g., [This] [is] [source] [seq] [pad] [pad].
Note: in the tutorial's convention a mask value of 1 means padding (masked out), which is the opposite of the Keras MultiHeadAttention API, where 1 means the position is attended to.
So the masking operation is as below:

[image: the 5×6 matrix of attention logits, with -1e9 added to the two [pad] key columns]

After taking softmax, it becomes:

[image: the 5×6 attention-weight matrix P, with zeros in the two [pad] columns]
P_ij represents how much Q_i pays attention to K_j. You can see the attention weights of [pad] are zero.
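Here is the same demonstration as a small runnable sketch (the logit values are random and only illustrative; what matters is the padding pattern):

import tensorflow as tf

key_tokens = tf.constant([[7, 12, 3, 9, 0, 0]])            # (1, 6); 0 is the pad id
padding_mask = tf.cast(tf.math.equal(key_tokens, 0), tf.float32)
padding_mask = padding_mask[:, tf.newaxis, tf.newaxis, :]  # (1, 1, 1, 6)

logits = tf.random.normal((1, 1, 5, 6))                    # (batch, heads, seq_len_q, seq_len_k)
masked_logits = logits + padding_mask * -1e9               # broadcasts over every query row
weights = tf.nn.softmax(masked_logits, axis=-1)            # (1, 1, 5, 6)

print(weights[0, 0])  # the last two columns (the [pad] keys) are ~0 in every row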


So the last dimension (which has size seq_len_k = the length of each sentence) can be interpreted as the sentence dimension?

I would say that seq_len_k is the sentence length you want to pay attention to.
For instance, in the 2nd MultiHeadAttention layer, mha2, of the Transformer DecoderLayer, seq_len_k = (encoder sentence length) and seq_len_q = (decoder sentence length). Its mask is created based on the encoder sentence.
Take a look at the create_masks function in the Transformer class in the tutorial: you will see that dec_padding_mask is the same as enc_padding_mask; both are created based on the encoder input sentence. That's because enc_padding_mask is used when the encoder pays attention to itself, while dec_padding_mask is used when the decoder pays attention to the encoder output.
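Roughly, create_masks does something like this (a paraphrased sketch using the tutorial's create_padding_mask and create_look_ahead_mask helpers, not the exact code):

def create_masks(inp, tar):
    # Both padding masks come from the encoder input: one for encoder
    # self-attention, one for the decoder's cross-attention (mha2) over
    # the encoder output.
    enc_padding_mask = create_padding_mask(inp)
    dec_padding_mask = create_padding_mask(inp)

    # The decoder's self-attention combines the look-ahead mask with the
    # padding mask of the decoder input.
    look_ahead_mask = create_look_ahead_mask(tf.shape(tar)[1])
    dec_target_padding_mask = create_padding_mask(tar)
    combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)
    # combined_mask: (batch_size, 1, seq_len_tar, seq_len_tar)

    return enc_padding_mask, combined_mask, dec_padding_mask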
