Understanding multi-headed attention - C5W4A1

This is a question about understanding the inputs and outputs of the MHA layer calls in the decoder blocks that form the core of the Transformer model in the final assignment of the DLS.

The second multi-head attention (MHA) block in the decoder portion of the transformer involves two inputs, according to the lectures and the figure below.

  • Two matrices of key and value vectors derived from the encoder, i.e. K and V (which are ultimately derived from the input sentence)
  • A matrix of query vectors Q derived from the first MHA block of the decoder (which are, in turn, derived from the outputs, i.e. the target sentence); see the sketch just after this list
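For concreteness, here is a minimal sketch of that second call using tf.keras.layers.MultiHeadAttention (the sizes and variable names below are made up for illustration, not the assignment’s actual code):

import tensorflow as tf

d_model, num_heads = 128, 8  # toy sizes
mha2 = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model // num_heads)

enc_output = tf.random.normal((1, 10, d_model))  # encoder output: 10 source tokens
out_block1 = tf.random.normal((1, 7, d_model))   # output of the decoder's first MHA + Add & Norm: 7 target tokens

# Q is derived from the decoder side; K and V are derived from the encoder output
attn_out = mha2(query=out_block1, value=enc_output, key=enc_output)
print(attn_out.shape)  # (1, 7, 128): one attention vector per target position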

My problem with this in the code is that the MHA layer in Keras returns, well, the multi-head attention - i.e. the A’s that are described in the lectures (the more feature-rich alternatives to the simple embeddings we’ve used in our other sequence models). This is the result of combining Q, K and V using the scaled dot-product softmax function that is described at length in the lectures.
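For reference, the scaled dot-product attention computed inside each head is

\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V

where d_k is the dimension of the key vectors; the MHA layer computes this once per head and then concatenates and linearly projects the per-head results.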

Neither the MHA layer nor the encoder as a whole outputs the specific Q, K and V matrices, however. So when we say that we are inputting K and V from the encoder into the decoder’s second MHA, what we actually seem to be doing is utilising the full encoder output (i.e. the output of the encoder’s MHA plus normalisations and a feed-forward neural network, run N times) and then using this representation of the input sentence as the basis of the key and value vectors within the decoder’s second MHA.

Specifically, where we set key = enc_output in the decoder’s MHA we do not literally mean ‘use enc_output as K’. Instead, ‘enc_output’ enters a calculation buried inside the MHA along the lines of K = \text{enc\_output} \times W_K, and the parameter matrix W_K for the keys is then fit during training.
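A quick way to convince yourself of this (a sketch with made-up sizes, not the assignment code) is to build a Keras MHA layer, call it once so it creates its variables, and then list its weights - you should see separate query, key and value kernels, which are exactly the parameter matrices that are fit during training:

import tensorflow as tf

mha = tf.keras.layers.MultiHeadAttention(num_heads=8, key_dim=16)

q = tf.random.normal((1, 7, 128))    # stand-in for the decoder-side input
kv = tf.random.normal((1, 10, 128))  # stand-in for enc_output

_ = mha(query=q, value=kv, key=kv)   # the first call builds the layer's variables

for w in mha.weights:
    print(w.name, w.shape)
# Expect entries along the lines of query/kernel, key/kernel, value/kernel
# (plus an output projection), confirming that Q, K and V are learned
# projections of the inputs rather than the inputs themselves.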

Possibly because of the slightly tricky meaning of key= within the function calls to self.mha(), I really struggled to grasp what was going on. Only by looking at the MHA code on GitHub did I (hopefully) work this out - but I wanted to post this to:

  1. Check whether other people agree with me
  2. Hopefully provide a bit of insight into the Transformer model code, as the descriptions in the code comments are a bit lacking.

PS: for the really keen, the salient bit of code in the linked GitHub page above seems to be where the input key, value and query vectors are multiplied by weight matrices (the kernels) - meaning K, Q and V are derived from the inputs and not literally equal to them.

# Linear transformations
query = tf.einsum("...NI , HIO -> ...NHO", query, self.query_kernel)
key = tf.einsum("...MI , HIO -> ...MHO", key, self.key_kernel)
value = tf.einsum("...MI , HIO -> ...MHO", value, self.value_kernel)

Where:

  • H = number of heads
  • I = embedding dimension; this is summed over in the einsum as we are going from embedding to a different representation (MHA) of the words
  • N = number of queries (i.e. the number of query positions)
  • M = number of keys/values (the numbers of keys and values have to be equal)
  • O = number of dimensions of the output (the so-called head size)

But in our case we have N = M = O: we have equal numbers of queries and keys/values, and because the blocks are stacked (the output of one pass feeds the next), the output dimension has to equal the input dimension.
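To make the quoted snippet concrete, here is a standalone toy version with made-up sizes (the kernels are random stand-ins for the learned weights, and the (H, I, O) kernel layout follows the quoted repository rather than Keras’s own layer):

import tensorflow as tf

N, M, I, H, O = 7, 10, 128, 8, 16            # toy sizes for illustration
query = tf.random.normal((1, N, I))           # from the decoder's first block
enc_out = tf.random.normal((1, M, I))         # encoder output, reused for keys and values

query_kernel = tf.random.normal((H, I, O))    # random stand-ins for the learned kernels
key_kernel = tf.random.normal((H, I, O))
value_kernel = tf.random.normal((H, I, O))

Q = tf.einsum("...NI,HIO->...NHO", query, query_kernel)
K = tf.einsum("...MI,HIO->...MHO", enc_out, key_kernel)
V = tf.einsum("...MI,HIO->...MHO", enc_out, value_kernel)
print(Q.shape, K.shape, V.shape)  # (1, 7, 8, 16) (1, 10, 8, 16) (1, 10, 8, 16)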


Hi Alastair_Heffernan,

I agree with your understanding of the process.

Here are my two cents that may provide some further clarification.

Attention can be seen as a feature extractor, extracting relevant meaning features from value vectors. The output tensor resulting from the pass through multiple MHA and FF layers combines feature extractions with increasing levels of complexity and abstraction, conceptually comparable to the extraction of image features by deeper layers in CNNs.

As you note, the full output of the encoder is fed into the decoder. This output is first linearly transformed into (number of heads) sets of key and value pairs within the multi-head attention layer. Then, the respective (number of heads) queries (which also result from a linear transformation within the multi-head attention, after the first Add & Norm of the decoder) are used to extract features to be mapped to the output that is to be generated. During calculation, the respective sets of keys, values, and queries are concatenated into matrices K, V, and Q.
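To put that in symbols (using the standard notation from Vaswani et al., with E standing for the encoder output and D for the decoder-side input to this block - my labels, not the assignment’s):

Q_h = D W_h^Q, \quad K_h = E W_h^K, \quad V_h = E W_h^V \quad \text{for each head } h

\text{head}_h = \text{softmax}\left(\frac{Q_h K_h^\top}{\sqrt{d_k}}\right) V_h, \qquad \text{MultiHead}(D, E) = \text{Concat}(\text{head}_1, \ldots, \text{head}_H)\, W^O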