In the NMT training code with LSTM and cross-attention, the entire right-shifted target sequence is passed to the decoder, whose output is then passed to the cross-attention layer. At prediction time, by contrast, the same decoder is called once per predicted word.
Does this mean that during training multiple decoder hidden states are passed to cross-attention at once, and that cross-attention correctly computes the attention for each decoder hidden state in one go?
Is this a question about one of the courses? Or about your own project?
Which course are you taking? Your question relates to the assignment in NLP course 4 week 1, so I have moved it there. Hopefully the mentors for that course can answer you.
You are right about how the decoder is used at training time versus prediction time.
In the code, during training, the internal hidden states of the pre-attention decoder's LSTM are not what is passed to the attention mechanism. Because we use the shifted-right target sequence (teacher forcing), every target position is available at once: the whole sequence passes through the pre-attention decoder, and only the LSTM's output sequence, together with the encoder's output, is fed into the attention mechanism, as the queries and the keys/values respectively.
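To see why attending with all decoder steps at once gives the same result as attending one step at a time, here is a minimal NumPy sketch (not the course's actual TF code; the shapes and the plain scaled dot-product attention are my own illustrative assumptions). Each row of the query matrix is one decoder time step, so the batched matrix computation is just the per-step computation stacked:

```python
import numpy as np

def scaled_dot_product_attention(q, k, v):
    # q: (T_q, d) queries, k/v: (T_kv, d) keys and values.
    scores = q @ k.T / np.sqrt(q.shape[-1])
    # Softmax over the key dimension, row by row.
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

rng = np.random.default_rng(0)
d = 8
enc_out = rng.normal(size=(5, d))      # encoder outputs: keys and values
dec_out = rng.normal(size=(4, d))      # pre-attention decoder LSTM outputs
                                       # (training: all 4 target steps at once)

# Training: all decoder steps attend to the encoder in one matrix call.
ctx_all = scaled_dot_product_attention(dec_out, enc_out, enc_out)

# Prediction: the same function is called with a single decoder step.
ctx_step = scaled_dot_product_attention(dec_out[2:3], enc_out, enc_out)

# Row 2 of the batched result matches the single-step result.
assert np.allclose(ctx_all[2], ctx_step[0])
```

Since the softmax is applied independently per query row, processing all decoder hidden states "in one go" during training is mathematically identical to the step-by-step calls made at prediction time.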
I am referring to a previous discussion on this. FYI, that thread discusses the implementation in the paper, which differs slightly from the TF code.