Hi! In Course 4 → Week 1 → Video: “NMT Model with Attention” → at 2:25, Younes says that the computation of the encoder and pre-attention decoder can be done in parallel. Note that we are not talking about the actual Attention and the Transformer yet.
When you look at the paper that I believe (though I may be wrong here) is the source of this lecture, "Neural Machine Translation by Jointly Learning to Align and Translate", Bahdanau, D. et al. (2014), they explicitly mention that to calculate the alignment weights (not yet called "Attention" at that point) you have to use the previous RNN hidden state:
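From memory, the relevant equations in the paper's notation are roughly:

e_{ij} = a(s_{i-1}, h_j)
\alpha_{ij} = \exp(e_{ij}) / \sum_k \exp(e_{ik})
c_i = \sum_j \alpha_{ij} h_j

so the alignment score e_{ij} behind the context vector c_i explicitly depends on the previous decoder hidden state s_{i-1}.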
An LSTM layer requires a context (state) + an input vector at each step. Do you initialize the context as zeros for the Decoder's first step, or for every Decoder timestep? If so, that's not how they describe it in the paper.
P.S. Yes, in the Transformer you calculate the Encoder/Decoder in parallel, but only because there are no RNN layers.
P.P.S. I did this Assignment quite a while ago using Trax, so I am not sure what happens under the hood there, and I don't have access to check out the TensorFlow implementation of this.
Hi @leggard, those are some really good questions!
You are referring to the correct paper; it was the first time (at least to my knowledge) that the concept of Attention was used, although it was not called Attention at that time: the model was called RNNsearch. (Thank god that name was dropped.)
What Younes mentions in the video is that we are doing the computation of the Pre-attention Decoder and the Encoder in parallel. Remember, in the paper there is no Pre or Post Decoder, and so they use the previous state of the RNN (as you point out). At 0:45 of that video Younes mentions the same fact that you mention. He further adds that we break it down into two decoders since the way the paper describes it is not easy to implement. Hence, we can run the Pre-attention Decoder and the Encoder in parallel. The post-attention decoder will only work after the pre-attention decoder has provided its hidden states (via the Q/K/V and the mask).
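To make the wiring concrete, here is a minimal sketch of that structure (Keras layers; the names, sizes and the use of a plain Attention layer are my own simplification, not the exact assignment code, and masks are omitted):

import tensorflow as tf

VOCAB, UNITS = 12000, 256                                     # made-up sizes

# Encoder and pre-attention decoder take different inputs and neither needs the
# other's output, so their forward passes can run in parallel.
source_ids = tf.keras.Input(shape=(None,), dtype=tf.int32)    # "input" sequence
target_in = tf.keras.Input(shape=(None,), dtype=tf.int32)     # shifted "target" sequence

keys_values = tf.keras.layers.LSTM(UNITS, return_sequences=True)(
    tf.keras.layers.Embedding(VOCAB, UNITS)(source_ids))      # encoder states (K, V)
queries = tf.keras.layers.LSTM(UNITS, return_sequences=True)(
    tf.keras.layers.Embedding(VOCAB, UNITS)(target_in))       # pre-attention decoder states (Q)

# Attention: queries from the pre-attention decoder attend over the encoder states.
context = tf.keras.layers.Attention()([queries, keys_values])

# Post-attention decoder only runs once attention has produced its output.
post = tf.keras.layers.LSTM(UNITS, return_sequences=True)(
    tf.keras.layers.Concatenate()([context, queries]))
logits = tf.keras.layers.Dense(VOCAB)(post)

model = tf.keras.Model([source_ids, target_in], logits)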
PS: I could not find the coding assignment you are referring to in your second question; can you please provide me the link to that assignment?
He further adds that we break it down into two decoders since the way the paper describes it is not easy to implement.
Yes, I understood this part. We calculate the encoder + pre-attention decoder together, then “attention”, then the actual decoder. That’s quite a neat change in architecture, making it sequential.
My question was specifically about the encoder + pre-attention decoder. Younes says those can be calculated in parallel, meaning that the pre-attention decoder, consisting of LSTMs, won’t get the cell state (context vector) from the encoder.
So, the pre-attention decoder doesn't know anything about the input sentence. Isn't that quite important? And it could still be done pretty easily: it would be like a regular Seq2Seq (encoder → decoder), and then everything else goes on top (attention, the actual decoder).
PS: I could not find the coding assignment you are referring to in your second question; can you please provide me the link to that assignment?
I was referring to Course 4 → First Week's Assignment, but the older version of it. It was written with Trax. At that time the pre-attention decoder was part of Exercise 2. And if I assume that Trax's LSTM works the same way as TensorFlow's, then the pre-attention decoder gets initialized with random weights.
In our implementation, we are using teacher forcing to train the NMT model. That means the hidden states are supplied by an LSTM that takes the "target" sequences as input (instead of the "input" sequences). This way we make sure that the hidden states carry the correct translation (but shifted right).
We later use the K and V generated by the encoder from the "input" sequences, combined with the hidden states from the "target" sequences, in the post-attention decoder.
Does it make sense now why it can be parallelized?
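As a concrete illustration of the teacher-forcing setup (hypothetical token ids, assuming the start/end tokens are already in the sequences):

import tensorflow as tf

# Hypothetical padded batches of token ids.
source_ids = tf.constant([[7, 42, 13, 2, 0]])    # "input" sequence, goes to the encoder
target_ids = tf.constant([[1, 55, 91, 30, 2]])   # start ... end "target" sequence

# Teacher forcing: the pre-attention decoder reads the right-shifted target,
# and the labels are the target shifted left.
decoder_in = target_ids[:, :-1]                  # fed to the pre-attention LSTM
labels = target_ids[:, 1:]                       # compared against the model's predictions

# encoder(source_ids) and pre_attention_decoder(decoder_in) depend on different
# inputs, which is why those two forward passes can be computed in parallel.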
The teacher’s forcing is a separate thing. Let’s imagine the inference of the proposed NMT model, not the training.
1. You pass the input sentence to the encoder and store all the hidden states, including the last one.
2. You pass the <EOS> token to the decoder, because you must start your output with something. It starts in parallel with the encoder, as explained in the video. Now let's break it down a bit more:
2.1. The <EOS> token gets converted to a word embedding
2.2. The LSTM layer's state gets initialized with random values, because having a "cell state" (context vector) is a must for an LSTM even on its first step. And we do that even though we could wait for the encoder to produce its final hidden state (why we don't wait for the encoder and don't use its last hidden state was my actual question).
2.3. Only now do we pass the word embedding and the initialized context vector into the LSTM layer.
2.4. Bam! We got our first output hidden state of the pre-attention decoder!
3. We do the attention and get the context vector for the <EOS> token.
4. Pass the context vector + word embedding to the actual decoder (you see, this one takes both context + embedding right away, compared to the pre-attention decoder, which gets only the embedding on its first step). A rough code sketch of these steps is below.
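In rough code, the way I understand the current design (just a sketch; the layer objects, sizes and token ids are made up, and zeros stand in for whatever initialization is actually used):

import tensorflow as tf

UNITS = 256                                                   # made-up size
embed = tf.keras.layers.Embedding(12000, UNITS)               # decoder embedding
pre_attention_lstm = tf.keras.layers.LSTM(UNITS, return_sequences=True, return_state=True)
attention = tf.keras.layers.Attention()
post_attention_lstm = tf.keras.layers.LSTM(UNITS, return_sequences=True)

# 1. The encoder has produced its hidden states for the input sentence (batch x T_x x UNITS).
encoder_states = tf.random.normal([1, 7, UNITS])              # stand-in for the real encoder output

# 2.1 The start (<EOS>) token gets converted to a word embedding.
x = embed(tf.constant([[1]]))                                 # made-up token id
# 2.2 The LSTM state is initialized (zeros here) instead of waiting for the encoder's last state.
h0 = c0 = tf.zeros([1, UNITS])
# 2.3 / 2.4 The embedding + that state go into the LSTM; out comes the first pre-attention hidden state.
q, h1, c1 = pre_attention_lstm(x, initial_state=[h0, c0])

# 3. Attention over the encoder states gives the context vector for this step.
context = attention([q, encoder_states])

# 4. The actual (post-attention) decoder takes context + the decoder state right away.
y = post_attention_lstm(tf.concat([context, q], axis=-1))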
So, I thought that ideally, we would do it like this:
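Roughly (just a sketch of the idea, with made-up names and ids, assuming a Keras-style LSTM that accepts an initial_state argument):

import tensorflow as tf

UNITS = 256
src_embed = tf.keras.layers.Embedding(12000, UNITS)
tgt_embed = tf.keras.layers.Embedding(12000, UNITS)
encoder_lstm = tf.keras.layers.LSTM(UNITS, return_sequences=True, return_state=True)
pre_attention_lstm = tf.keras.layers.LSTM(UNITS, return_sequences=True, return_state=True)

# Run the encoder first and keep its final hidden/cell state as well as all its states.
encoder_states, enc_h, enc_c = encoder_lstm(src_embed(tf.constant([[7, 42, 13, 2]])))

# Start the pre-attention decoder from the encoder's final state (plain Seq2Seq style),
# so even its first step already knows something about the input sentence.
x = tgt_embed(tf.constant([[1]]))                              # start token (made-up id)
q, h1, c1 = pre_attention_lstm(x, initial_state=[enc_h, enc_c])
# ...attention and the post-attention decoder then proceed exactly as before.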
It doesn't look like too complex a change; that's exactly how a default Seq2Seq model works, except that here it's only the first step of the whole model. And that's how the model is described in the paper I mentioned before as well.
However, now that I have spent a decent amount of time thinking about it, it looks like a minor change. In addition, I think that by mentioning the parallelism of the encoder and decoder they tried to make a smoother introduction to the Transformer, where that really does happen in parallel. This whole background with Seq2Seq and that ancestor of the actual Attention mechanism can be enough to confuse what the actual Attention is, where the LSTM ends and what the Transformer is.
I hope now you get what I mean. Thanks for the time and quick answers once again!
Thanks for focusing on the inference part only. It gives me more clarity on which parts are confusing for you. I like how the discussion is progressing, as it is also allowing me to re-read the code, the paper and related articles.
From what I understand, the training and inference parts might be confusing, as they work slightly differently: we keep track of states during inference, but during training, since we already know the next word (again, teacher forcing), we do not need to keep track of states.
Regarding your point 2.2 (during inference for the <EOS> token): even though we initialize states with zeros, once they are forward passed through the decoder's pre-attention LSTM, it will return updated states (since this LSTM has been trained and its weights have been updated according to correct translations). These states (s_{i-1} in the paper) are then used in the attention mechanism to attend to the parts of the input sequence. You can refer to Appendix A of the paper for the maths behind this.
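A tiny illustration of what I mean by "return updated states" (a bare Keras LSTM, not the assignment code):

import tensorflow as tf

lstm = tf.keras.layers.LSTM(4, return_sequences=True, return_state=True)

step_input = tf.random.normal([1, 1, 8])     # one token's embedding
h0 = c0 = tf.zeros([1, 4])                   # states initialized with zeros

outputs, h1, c1 = lstm(step_input, initial_state=[h0, c0])
# h1 and c1 are computed from the layer's weights and the input, so after training they
# are meaningful; they play the role of s_{i-1} that the attention mechanism uses next.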
Now about the context vector and why it is not fed into the pre-attention LSTM:
Here I believe you are confused about what the context vector actually means. The context vector comes out only after the attention mechanism. In the following snippet from the assignment, we use the word 'context' as an alias for the input sequence. Furthermore, the "cell state" that you are referring to is also not a context vector.
# Vectorize it and pass it through the encoder
context = english_vectorizer(texts).to_tensor()
context = encoder(context)
The only thing that is passed from the encoder to the decoder is the context vector after the attention mechanism. The pre-attention LSTM's job is to give hidden states from the target sequence during training (teacher forcing, so there is no need to track states) or from the predicted sequence during inference (we predict one word at a time, so we need to track states).
I am mentioning some notation that might help clarify the paper and the code better (with a small numerical sketch at the end):
s_i are the hidden states of the decoder RNN. In our implementation, s_i comes from the post-attention LSTM and s_{i-1} comes from the pre-attention LSTM.
h_i are the annotations/hidden states from the encoder RNN, which in the paper are the concatenation of \overrightarrow{h_i} and \overleftarrow{h_i}. In our implementation we have only a unidirectional LSTM.
The context vector is the output of the attention mechanism that tells the decoder where to attend.
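To tie the notation together, here is a small numerical sketch of the attention step (dot-product scores stand in for the paper's learned alignment model a(s_{i-1}, h_j)):

import tensorflow as tf

T_x, UNITS = 5, 8
h = tf.random.normal([1, T_x, UNITS])        # h_j: encoder annotations
s_prev = tf.random.normal([1, 1, UNITS])     # s_{i-1}: pre-attention decoder hidden state

e = tf.matmul(s_prev, h, transpose_b=True)   # e_{ij}: alignment scores, shape (1, 1, T_x)
alpha = tf.nn.softmax(e, axis=-1)            # \alpha_{ij}: alignment/attention weights
context_vector = tf.matmul(alpha, h)         # c_i = \sum_j \alpha_{ij} h_j, fed to the post-attention decoder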