I am stuck: a single mbr_decode call takes more than 5 or even 10 minutes to run. Is there something wrong with it? Even if I just run generate_samples(), it succeeds, but sometimes it returns within a minute and sometimes it takes an extremely long time.
I have tried running C10 overnight; it didn't return any errors (until I interrupted it), it just kept running.
One common mistake that can make the code run extremely long is in the # UNQ_C6 next_symbol function definition: when calculating log_probs, do not use -1 in the second dimension:
# get log probabilities from the last token output
log_probs = output[0, -1, :]  # wrong: with padding, -1 is not the last real token
The hint in the notebook:
The log probabilities output will have the shape: (batch size, decoder length, vocab size). It will contain log probabilities for each token in the cur_output_tokens plus 1 for the start symbol introduced by the ShiftRight in the preattention decoder. For example, if cur_output_tokens is [1, 2, 5], the model will output an array of log probabilities each for tokens 0 (start symbol), 1, 2, and 5. To generate the next symbol, you just want to get the log probabilities associated with the last token (i.e. token 5 at index 3). You can slice the model output at [0, 3, :] to get this. It will be up to you to generalize this for any length of cur_output_tokens.
In other words, do not use -1 to select the decoder length dimension; use the length of the current output tokens, which you stored in an appropriate variable. Because the sequence fed to the model is padded, -1 lands on a padding position instead of the last real token.
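To make this concrete, here is a minimal sketch of the slicing logic. The power-of-two padding, the (input, target) call signature, the (output, _) unpacking, and names like next_symbol_sketch, token_length, and padded_with_batch are my assumptions based on the notebook's setup, not necessarily its exact code:

import numpy as np

def next_symbol_sketch(model, input_tokens, cur_output_tokens):
    # Sketch only: `model` is assumed to take (input, target) arrays with a
    # batch dimension, to return a (log probs, target) pair, and the log
    # probs are assumed to have shape (batch size, decoder length,
    # vocab size), as the hint describes.
    token_length = len(cur_output_tokens)
    # assumed: pad the current output to the next power of two
    padded_length = 2 ** int(np.ceil(np.log2(token_length + 1)))
    padded = cur_output_tokens + [0] * (padded_length - token_length)
    padded_with_batch = np.array(padded)[None, :]  # shape (1, padded_length)
    output, _ = model((input_tokens, padded_with_batch))
    # ShiftRight prepends a start symbol, so the last real token's log
    # probabilities sit at index token_length, not at -1 (with padding,
    # -1 would point at a pad position).
    log_probs = output[0, token_length, :]
    return int(np.argmax(log_probs))

Slicing at the wrong position tends to produce symbols that never reach the end-of-sentence token, which would explain why decoding keeps running for so long.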