Auto-regressive decoder in Transformer

  1. How do we select the values of n_q and n_k?

  2. Why, for the auto-regressive decoder, is the mask on the upper triangle instead of the lower triangle? In other words, why mask off w0, w1 instead of w1, w0?

  3. What exactly is meant by "Teacher forcing refers to the technique of also allowing the decoder access to the input sentence, but in an autoregressive fashion."?

  4. How does Positional Encoding (the original sin() and cos() idea) work for the auto-regressive decoder?

Hi user342,

With respect to your first question, n_q and n_k depend on the choice of embedding size of the language model and on the number of heads used in multi-head attention. For an explanation see this blog post.
For a discussion of the choice of embedding size, see, e.g., this thread.
To choose the number of heads in multi-head attention, it may be useful to consider the diversity of representations the heads are supposed to capture, as argued in this paper.
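
To make that concrete, here is a minimal sketch (my own illustration, with assumed values for the embedding size and head count, not anything from the course) of how n_q and n_k typically fall out of those two choices in the original Transformer:

```python
# Minimal sketch (assumed values) of how n_q and n_k follow from the
# embedding size and the number of attention heads.
d_model = 512   # embedding size of the language model (assumed value)
n_heads = 8     # number of attention heads (assumed value)

# In the original Transformer, queries and keys are projected to
# d_model / n_heads dimensions per head, so n_q = n_k = d_k.
d_k = d_model // n_heads
print(d_k)  # 64
```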

With regard to your second question, the masked dot-product attention is multiplied with the value matrix, leading to an output matrix (after concatenation) with dimensions n_seq x d_model (see again this post). In causal attention each position may only attend to itself and the positions before it: the first row of the attention matrix sees only the first token, the second row sees the first and second tokens, and so on, expanding the number of visible tokens along the sequence axis. To arrive at that result, the upper triangle of the dot-product attention matrix has to be masked out (set to minus infinity before the softmax).
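
As a small illustration of that masking (a sketch in PyTorch, which is my choice; the tensor names and sizes are made up), the upper triangle of the score matrix is set to minus infinity before the softmax, so that row i only attends to tokens 0..i:

```python
import torch
import torch.nn.functional as F

# Minimal causal (masked) dot-product attention sketch; shapes are assumptions.
n_seq, d_k = 4, 8
q = torch.randn(n_seq, d_k)   # queries
k = torch.randn(n_seq, d_k)   # keys
v = torch.randn(n_seq, d_k)   # values

scores = q @ k.T / d_k ** 0.5                       # (n_seq, n_seq)
mask = torch.triu(torch.ones(n_seq, n_seq), diagonal=1).bool()
scores = scores.masked_fill(mask, float("-inf"))    # mask out the upper triangle
attn = F.softmax(scores, dim=-1)                    # row i attends only to tokens 0..i
out = attn @ v                                      # (n_seq, d_k)
```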

Teacher forcing with autoregression means that the decoder does get access to the input sequence, but token for token: at each position the decoder sees the correct tokens up to that point and has to predict the next one, so a loss can be computed for every position, and the correct token (rather than the decoder's own prediction) is passed to the decoder in the next step, where the task is to predict the token after that.
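
A rough sketch of what "token for token" looks like in practice (the example sentence and variable names are hypothetical, not from the course): the decoder input is the ground-truth sequence shifted right, and the label at each position is the next ground-truth token.

```python
# Rough teacher-forcing sketch (tokens and names are hypothetical).
# The decoder always receives the *correct* previous tokens as input,
# regardless of what it actually predicted.
target = ["<bos>", "the", "cat", "sat", "<eos>"]

decoder_input  = target[:-1]   # ["<bos>", "the", "cat", "sat"]
decoder_labels = target[1:]    # ["the", "cat", "sat", "<eos>"]

# At position i the decoder sees decoder_input[:i+1] (enforced by the causal
# mask) and is trained to predict decoder_labels[i]; the loss at every
# position is computed against the ground-truth token, not the model's own
# previous prediction.
for i, label in enumerate(decoder_labels):
    print(f"step {i}: input so far = {decoder_input[:i + 1]}, predict -> {label}")
```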

Because a mask is applied in causal attention, the positional encoding can be used as is: any position the decoder is not yet allowed to see is masked out inside attention, and that includes the contribution of its positional encoding.
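
For completeness, here is the standard sinusoidal positional encoding from "Attention Is All You Need", written out as a small NumPy sketch (the shapes passed at the end are assumed values): it is simply added to the token embeddings before the decoder, and the causal mask then hides the positions the decoder may not see yet.

```python
import numpy as np

def sinusoidal_positional_encoding(n_seq: int, d_model: int) -> np.ndarray:
    """Standard sin/cos positional encoding from the original Transformer paper."""
    positions = np.arange(n_seq)[:, None]                  # (n_seq, 1)
    dims = np.arange(0, d_model, 2)[None, :]               # (1, d_model/2)
    angles = positions / np.power(10000, dims / d_model)   # (n_seq, d_model/2)

    pe = np.zeros((n_seq, d_model))
    pe[:, 0::2] = np.sin(angles)   # even dimensions: sin
    pe[:, 1::2] = np.cos(angles)   # odd dimensions: cos
    return pe

# Added to the embeddings of the decoder input; no change is needed for
# the auto-regressive case.
pe = sinusoidal_positional_encoding(n_seq=4, d_model=8)
```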
