This is a question about understanding the inputs/outputs of the MHA layer calls in the decoder blocks that form the core of the Transformer model in the final assignment of the DLS.
The second multi-head attention (MHA) block in the decoder portion of the transformer involves two inputs, according to the lectures and the figure below.
- Two matrices of key and value vectors derived from the encoder, i.e. K and V (which are ultimately derived from the input sentence)
- A matrix of query values Q derived from the first MHA block of the decoder (which are, in turn, derived from the outputs, i.e. the target sentence)
My problem with this in the code is that the MHA layer in Keras returns, well, the multi-head attention itself - i.e. the A’s that are described in the lectures (the more feature-rich alternatives to the simple embeddings we’ve used in our other sequence models). This is the result of combining Q, K and V using the scaled dot-product softmax function that is described at length in the lectures.
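For reference, that combination is the scaled dot-product attention from the lectures: \text{Attention}(Q, K, V) = \text{softmax}\left( \frac{QK^T}{\sqrt{d_k}} \right) V, where d_k is the dimension of the key vectors.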
Neither the MHA layer nor the encoder as a whole outputs the specific Q, K and V matrices, however. So when we say that we are inputting K and V from the encoder into the decoder’s second MHA, what we actually seem to be doing is taking the full encoder output (i.e. the output of the encoder’s MHA plus normalisations and a feed-forward neural network, run N times) and using this representation of the input sentence as the basis for the key and value vectors within the decoder’s second MHA.
Specifically, when we set key = enc_output in the decoder’s MHA, we do not literally mean ‘use enc_output as K’. Instead, ‘enc_output’ is used within a calculation buried inside the MHA layer along the lines of K = \text{enc\_output} \times W_K, where W_K is a parameter matrix for the keys that is then fit during training.
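To make this concrete, here is a minimal sketch of the decoder’s second attention call using the stock tf.keras.layers.MultiHeadAttention layer, with made-up shapes; the assignment’s own decoder class wraps an equivalent call, so the names and sizes below are just for illustration:

import tensorflow as tf

mha = tf.keras.layers.MultiHeadAttention(num_heads=8, key_dim=64)

# made-up shapes: batch of 2, target length 10, input length 12, d_model 512
x = tf.random.normal((2, 10, 512))           # output of the decoder's first MHA block
enc_output = tf.random.normal((2, 12, 512))  # full encoder output (after N encoder blocks)

# Internally the layer first projects these inputs with learned weight matrices,
# roughly Q = x @ W_Q, K = enc_output @ W_K, V = enc_output @ W_V,
# and only then applies the scaled dot-product attention.
out = mha(query=x, value=enc_output, key=enc_output)
print(out.shape)  # (2, 10, 512): one attention vector per query (target) position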
Possibly because of the slightly tricky meaning of key= within the function calls to self.mha(), I really struggled to grasp what was going on. Only by looking within the code on GitHub for MHA did I (hopefully) work this out, but I wanted to post this to:
- Check whether other people agree with me
- Hopefully provide a bit of insight into the Transformer model code, as the descriptions in the code comments are a bit lacking.
PS: for the really keen, the salient bit of code in the linked GitHub page above seems to be where the input key, value and query vectors are multiplied by weight matrices (the kernels) - meaning K, Q and V are derived from the inputs and not literally equal to them.
# Linear transformations: project the raw query/key/value inputs with learned kernels (one per head)
query = tf.einsum("...NI , HIO -> ...NHO", query, self.query_kernel)
key = tf.einsum("...MI , HIO -> ...MHO", key, self.key_kernel)
value = tf.einsum("...MI , HIO -> ...MHO", value, self.value_kernel)
Where:
- H = number of heads
- I = embedding dimension; this is summed over in the einsum, as we are going from the embedding to a different representation (MHA) of the words
- N = number of query vectors
- M = number of key/value vectors (the numbers of keys and values have to be equal)
- O = number of dimensions of the output (so-called head-size)
But in our case we have N = M = O, since we have equal numbers of queries and keys/values, and we make the output equal in size to the number of input features because the block is applied repeatedly, meaning the outputs have to be equal in dimension to the inputs.
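As a quick sanity check on these shapes, here is a tiny standalone version of the key projection with made-up sizes; the random key_kernel below just stands in for the learned self.key_kernel:

import tensorflow as tf

# made-up sizes: batch = 2, M = 10 key/value positions, I = 512 embedding dims,
# H = 8 heads, O = 64 dims per head
enc_output = tf.random.normal((2, 10, 512))
key_kernel = tf.random.normal((8, 512, 64))  # stand-in for the learned self.key_kernel

# the same einsum as in the snippet above: K is a learned projection of enc_output,
# not enc_output itself
key = tf.einsum("...MI,HIO->...MHO", enc_output, key_kernel)
print(key.shape)  # (2, 10, 8, 64): one O-dimensional key per position and per head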