Masked Multi-Head Attention

Hi Mentor,

During the training phase, what input are we passing to the decoder block in the Transformer network?

Also, can someone explain masked multi-head attention to help me understand it better?

Hey @Anbu,

Coming to your first question, I have attached the relevant slide above for your reference. As can be seen, the Multi-Head Attention (MHA) block requires 3 matrices: Q (Query), K (Key) and V (Value). In this transformer, the decoder block has 2 MHA blocks. To the first block, we feed Q, K and V derived from part of the target sequence, along with the positional encodings, for instance, the first 6 words and their positional encodings when predicting the 7th word. To the second MHA block, K and V are the output of the corresponding encoder block, and Q is the output of the previous Add & Norm block in the decoder block itself. This should answer your first question.
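
To make the wiring concrete, here is a minimal sketch of the two attention blocks inside one decoder layer. It assumes PyTorch and its `nn.MultiheadAttention`; the sizes and names are illustrative, not the course notebook's code.

```python
import torch
import torch.nn as nn

class DecoderLayerSketch(nn.Module):
    """Only the two attention blocks and their Add & Norm steps;
    the feed-forward sub-layer is omitted for brevity."""
    def __init__(self, d_model=512, num_heads=8):
        super().__init__()
        # 1st MHA block: masked self-attention over the target words seen so far
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        # 2nd MHA block: cross-attention with the encoder output
        self.cross_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, tgt, enc_out, causal_mask):
        # Q, K and V all come from the (embedded + positionally encoded) target
        attn1, _ = self.self_attn(tgt, tgt, tgt, attn_mask=causal_mask)
        x = self.norm1(tgt + attn1)                      # Add & Norm
        # Q is the previous Add & Norm output; K and V are the encoder output
        attn2, _ = self.cross_attn(x, enc_out, enc_out)
        return self.norm2(x + attn2)                     # Add & Norm
```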

In this example, we are doing a translation task (French to English). In both the training and inference phases, the decoder block receives the target sequence only (the first few words, for instance). What differs is that in the training phase we have the actual target sequence, so we use it as the input, whereas in the inference phase we don't have the actual target sequence, so we use the words the transformer has already predicted as the input to the decoder block.

As for your second question, you can find a great explanation of Masked Multi-Head Attention in this nicely written blog. I hope this helps.
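
Until you get to the blog, the core idea of the mask is small enough to show directly. Here is a minimal sketch (again assuming PyTorch) of the look-ahead (causal) mask that stops each position from attending to the positions after it:

```python
import torch

def look_ahead_mask(seq_len):
    # True above the diagonal means "this query may NOT attend to that (future) key"
    return torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

print(look_ahead_mask(4))
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])
```

This boolean matrix is what gets passed as the attention mask in the first MHA block of the decoder, so that while predicting the 7th word the decoder can only look at words 1 to 6.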

Regards,
Elemento

Hi Elemento, first of all, thanks for replying to my questions, but I still need some clarity on your answers.

  1. Regarding the first 6 words you mentioned in the first paragraph, where are they actually coming from?

  2. During the training phase, are we passing the actual entire sequence of words Y to the decoder? The whole actual sequence, like Jane visits Africa in September, all of the words?

For part 1: the 6 words are merely an example of the input. It is the target words that form the sentence, along with their details, that are passed further down into the network.

For part 2: you are right.

Seconded. The answer is quite reasonable. Thanks for sharing.

Hey @Anbu, just to elaborate on what @tushar31093 has said for your second query, let's understand it with a simple example. Consider FS as the French sentence (Jane visite l’Afrique en Septembre) and ES as its English translation (Jane visits Africa in September). In the training phase, we have both FS and ES; in the inference phase, we have FS only.

Now, in the training phase, we pass the entire FS to the encoder. Since the decoder has a softmax layer at the end, it is modelled so that it predicts only a single word at a time. So, we train the decoder like this:

  • Feeding in SOS to the decoder and making it predict Jane
  • Feeding in SOS Jane to the decoder and making it predict visits
  • Feeding in SOS Jane visits to the decoder and making it predict Africa

and so on, until it predicts EOS. Note that each time we feed the decoder a slightly different input, but the encoder is fed the entire FS every time. This is how we can train the encoder-decoder (the entire network) multiple times using a single example, and then we iterate over the other examples.
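
Here is a toy sketch of those training steps; the tokens and the model call are purely illustrative, not the assignment's code:

```python
fs = ["Jane", "visite", "l'Afrique", "en", "Septembre"]                  # fed to the encoder every time
es = ["<SOS>", "Jane", "visits", "Africa", "in", "September", "<EOS>"]   # known during training

for t in range(1, len(es)):
    decoder_input = es[:t]   # e.g. ["<SOS>", "Jane", "visits"]
    target_word = es[t]      # e.g. "Africa"
    # model(fs, decoder_input) would output a softmax over the vocabulary,
    # and the loss compares that prediction against target_word
    print(decoder_input, "->", target_word)
```

In practice all of these steps are computed in a single forward pass, because the look-ahead mask stops each position from seeing the words to its right, but conceptually it is the word-by-word process above.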

Now, in the training phase, since we have the ES, we can simply feed in the correct English translation, extending it word by word. In the inference phase, we don't have the ES with us. In that case, we simply feed in what the network predicted in the previous iteration. It goes on like this:

  • Feeding in SOS to the decoder and let’s say the decoder predicts the next word as Jane
  • Now, we will feed SOS Jane to the decoder, and let’s say that the decoder predicts the next word as September. Here, note that the next word should have been visits, but in this scenario, the network makes a wrong prediction.
  • So, now we will feed SOS Jane September to the decoder and so on

We will continue this until the network predicts EOS. Note that since we have the FS, we still feed the entire FS to the encoder in every iteration. I hope this helps.
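
A toy sketch of that inference loop (greedy decoding; predict_next_word is a hypothetical stand-in for one forward pass of the trained transformer):

```python
def greedy_decode(fs, predict_next_word, max_len=20):
    decoder_input = ["<SOS>"]
    while len(decoder_input) < max_len:
        # the encoder still sees the full French sentence fs at every step
        next_word = predict_next_word(fs, decoder_input)
        decoder_input.append(next_word)   # the prediction, right or wrong, is fed back in
        if next_word == "<EOS>":
            break
    return decoder_input[1:]              # drop the <SOS> token
```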

Regards,
Elemento