Hmm… I’m not sure. Let me elaborate:
If we are talking strictly about the `next_symbol()` function and `log_probs`, then that is not entirely true.
For example, if the input sentence for `next_symbol()` is “It’s time for tea”, then the output is (symbol, symbol log probability), in other words (next word, next word’s log probability). The next word could be, say, “now”.
Put differently, the model outputs predictions for every position up to the sequence length: for the tokens we already have, for the next token, the next-next token, the next-next-next token and so on. We do not care about the model’s predictions for the positions inside “It’s time for tea”; we only care about the prediction for the next word.
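A minimal sketch of that selection step, assuming the model returns a `(sequence_length, vocab_size)` array of log probabilities in which row `i` is the distribution over the token that follows position `i`. The toy vocabulary, the hypothetical `pick_next_symbol` helper and the random "model output" are illustrations, not the assignment's actual code.

```python
import numpy as np

def pick_next_symbol(log_probs_per_position, prefix_length):
    """Return (symbol, log probability) for the word after the prefix.

    Hypothetical helper: assumes row i holds the log probabilities of the
    token that follows position i, so the word after a prefix of
    `prefix_length` tokens is predicted by row prefix_length - 1.
    """
    next_word_log_probs = log_probs_per_position[prefix_length - 1]
    symbol = int(np.argmax(next_word_log_probs))
    return symbol, float(next_word_log_probs[symbol])

vocab = ["<pad>", "<sos>", "It's", "time", "for", "tea", "now"]
prefix = ["<sos>", "It's", "time", "for", "tea"]

# Toy stand-in for the model output: 8 positions (prefix plus padding),
# one row of scores per position.
rng = np.random.default_rng(0)
logits = rng.normal(size=(8, len(vocab)))
logits[len(prefix) - 1, vocab.index("now")] += 10.0  # make "now" the clear winner

# Numerically stable log-softmax over the vocabulary, row by row
shifted = logits - logits.max(axis=1, keepdims=True)
log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

symbol, log_prob = pick_next_symbol(log_probs, len(prefix))
print(vocab[symbol], log_prob)  # all other rows are computed but ignored
```

The model still computes a row for every position; the helper just reads one of them, which mirrors the “we only care about the next word” point.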
But if we are talking about the details of what happens inside the function at this line:
```python
# get the model prediction
output, _ = ...
```
then it’s more like what you said: the model gets a tuple of inputs, one part being `input_tokens` and the other `padded_with_batch`. The part you were talking about is the one that goes to `input_encoder`, which embeds the tokens, after which LSTM layers encode the information for the decoder to consume. In other words, the encoder tries to express (compress) the information that is “inside” these words and their order, so that the decoder can do its best to decompress what was in the input and continue the sequence.
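To make “embed, then LSTM layers encode” concrete, here is a minimal NumPy sketch. It is an illustration under stated assumptions, not the assignment’s `input_encoder`: a single LSTM layer, random untrained weights, made-up sizes and token ids. Each output row is a fixed-size vector summarizing the tokens up to that position, which is the kind of representation the decoder then consumes.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_encoder(token_ids, embedding, W, U, b):
    """Embed token ids, then run one LSTM layer over them (hypothetical sketch).

    Returns one hidden-state vector per input token, shape (seq_len, hidden).
    W: (embed_dim, 4*hidden), U: (hidden, 4*hidden), b: (4*hidden,)
    Gate order in the stacked weights: input, forget, candidate, output.
    """
    hidden = U.shape[0]
    h = np.zeros(hidden)
    c = np.zeros(hidden)
    states = []
    for x in embedding[token_ids]:              # embedding lookup: one row per token
        z = x @ W + h @ U + b                   # all four gate pre-activations at once
        i = sigmoid(z[:hidden])                 # input gate
        f = sigmoid(z[hidden:2 * hidden])       # forget gate
        g = np.tanh(z[2 * hidden:3 * hidden])   # candidate cell state
        o = sigmoid(z[3 * hidden:])             # output gate
        c = f * c + i * g                       # update the memory cell
        h = o * np.tanh(c)                      # new hidden state
        states.append(h)
    return np.stack(states)

rng = np.random.default_rng(0)
vocab_size, embed_dim, hidden = 10, 4, 3       # hypothetical sizes
embedding = rng.normal(size=(vocab_size, embed_dim))
W = rng.normal(size=(embed_dim, 4 * hidden))
U = rng.normal(size=(hidden, 4 * hidden))
b = np.zeros(4 * hidden)

input_tokens = np.array([2, 3, 4, 5])          # hypothetical ids for "It's time for tea"
encoded = lstm_encoder(input_tokens, embedding, W, U, b)
print(encoded.shape)  # (4, 3): one 3-dim vector per input token
```

In a trained model the weights are learned, so these vectors end up carrying what the decoder needs; here they are random and only show the data flow.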
In general, when we train language models, what happens is what you described. For example, with a sentence like “It’s time for tea”, we train the model like this:
- [<sos>] → what is the next word? → well, this time it’s [“It’s”]
- [<sos>, “It’s”] → what is the next word? → well, this time it’s [“time”]
- [<sos>, “It’s”, “time”] → what is the next word? → well, this time it’s [“for”]
- [<sos>, “It’s”, “time”, “for”] → what is the next word? → well, this time it’s [“tea”]
and we update the weights accordingly.
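The list above can be generated mechanically from a sentence. A minimal Python sketch; the `<sos>` marker and the hypothetical `next_word_training_pairs` helper are illustrative assumptions:

```python
def next_word_training_pairs(sentence_tokens, sos="<sos>"):
    """Turn one sentence into (prefix, next word) training examples."""
    tokens = [sos] + sentence_tokens
    # Each prefix tokens[:i] is paired with the word that follows it
    return [(tokens[:i], tokens[i]) for i in range(1, len(tokens))]

pairs = next_word_training_pairs(["It's", "time", "for", "tea"])
for prefix, target in pairs:
    print(prefix, "->", target)
```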
We do this not only for efficiency (we already have the whole sentence, so we should get the most out of it now rather than reload it later) but also so that the model can predict from a prefix of any length, even a zero-length one (just <sos>).
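Why one pass is enough: if the model is causal (each position only sees the tokens up to itself), feeding the sentence shifted right by one gives a prediction for every prefix at once, including the prefix that is just <sos>. A minimal NumPy sketch under that assumption; the token ids, the random stand-in for the model output and the helper names are hypothetical:

```python
import numpy as np

def shifted_inputs_and_targets(token_ids, sos_id):
    """Inputs are the sentence shifted right by one; targets are the sentence itself."""
    inputs = np.array([sos_id] + token_ids[:-1])
    targets = np.array(token_ids)
    return inputs, targets

def cross_entropy_all_positions(log_probs, targets):
    """Average negative log likelihood over every position in one pass.

    log_probs: (seq_len, vocab), row i = distribution for the token after
    inputs[:i + 1]; targets: (seq_len,) true next tokens.
    """
    picked = log_probs[np.arange(len(targets)), targets]
    return -picked.mean()

sos_id = 1
sentence = [2, 3, 4, 5]            # hypothetical ids for "It's time for tea"
inputs, targets = shifted_inputs_and_targets(sentence, sos_id)

# Stand-in for a causal model's output: one row of log probabilities per input position
rng = np.random.default_rng(0)
logits = rng.normal(size=(len(inputs), 6))
log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))

loss = cross_entropy_all_positions(log_probs, targets)
print(inputs, targets, loss)
```

Position 0 only sees <sos> and must predict “It’s”, position 1 sees <sos> “It’s” and must predict “time”, and so on: the four bullet points above, computed together.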
So to be clear, the model predicts every word, but the `next_symbol` function picks out only the one we care about.