In the function "def NMTAttn", inside the tl.Serial combinator, we have the following line:
# Step 2: copy input tokens and target tokens as they will be needed later
Following that hint/instruction, we use the tl.Select layer in a way that ignores or drops the mask weights, which were added to the input stream by "trax.data.AddLossWeights(id_to_mask=0)(train_batch_stream)".
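To make the question concrete, here is a minimal plain-Python sketch of how I understand tl.Select's stack semantics (assuming Trax's documented behavior: it picks the given indices off the top of the data stack, with n_in defaulting to max(indices) + 1, and leaves deeper stack items untouched). This is an illustrative re-implementation, not Trax's actual code:

```python
def select(stack, indices, n_in=None):
    """Mimic tl.Select: pick `indices` from the top n_in stack items,
    then pass the rest of the stack through unchanged."""
    if n_in is None:
        n_in = max(indices) + 1  # Trax's documented default
    picked = [stack[i] for i in indices]
    return picked + list(stack[n_in:])

# Hypothetical stream: (input_tokens, target_tokens, mask_weights)
stream = ["input_tokens", "target_tokens", "mask_weights"]
print(select(stream, [0, 1, 0, 1]))
# -> ['input_tokens', 'target_tokens', 'input_tokens', 'target_tokens', 'mask_weights']
```

Interestingly, under these semantics a Select([0, 1, 0, 1]) only consumes the top two items, so whether the mask weights are truly deleted depends on whether they are on the model's stack at all at that point, which is part of what I am asking.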
After deleting that from the input stream, we add the same mask (1s for non-pad tokens and 0s for pad tokens) back via the function "prepare_attention_input" later on in the tl.Serial.
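By "the same mask" I mean a padding mask built directly from the token ids, roughly like the sketch below (the function name make_padding_mask and pad id 0 are my assumptions; the assignment's prepare_attention_input builds something equivalent for the attention layer):

```python
import numpy as np

def make_padding_mask(token_ids, pad_id=0):
    """Return 1.0 for real tokens and 0.0 for padding tokens."""
    return (np.asarray(token_ids) != pad_id).astype(np.float32)

print(make_padding_mask([5, 9, 3, 0, 0]))
# -> [1. 1. 1. 0. 0.]
```

Since id_to_mask=0 in AddLossWeights also zeroes out pad positions, the weights dropped earlier and the mask rebuilt here look identical to me.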
Two questions:
- Am I right in assuming that we are actually deleting the mask weights and then adding the same weights back later in the tl.Serial?
- If so, is that just a pedagogical device to teach more about Trax, or is there an implementation reason why the mask is deleted/ignored and then rebuilt?