# Which video of week 1 is the reference for the `next_symbol` function?

There is no video that explains every step of the `next_symbol` function. I’ve seen learners struggle with it, so let me try to explain it with an example.

First, let’s get familiar with the inputs and the expected outputs, which are described in the docstring:

```python
    """Returns the index of the next token.

    Args:
        NMTAttn (tl.Serial): An LSTM sequence-to-sequence model with attention.
        input_tokens (np.ndarray 1 x n_tokens): tokenized representation of the input sentence
        cur_output_tokens (list): tokenized representation of previously translated words
        temperature (float): parameter for sampling ranging from 0.0 to 1.0.
            0.0: same as argmax, always pick the most probable token
            1.0: sampling from the distribution (can sometimes say random things)

    Returns:
        int: index of the next token in the translated sentence
        float: log probability of the next symbol
    """
```

For a concrete illustration, let’s take the example of translating the sentence “I love languages.” to German, at the point where the model has already translated the first two words.

So, for example, the tokenizer might have converted the English sentence “I love languages.” to:

• [[46, 5013, 4115, 3, 1]] - or concretely: “I” → 46, “love” → 5013, “languages” → 4115, “.” → 3, <eos> → 1

Further, let’s assume the translation so far is “Ich liebe” (just two words), which the tokenizer might have tokenized to:

• [161, 12202] - or concretely: 161 → “Ich”, 12202 → “liebe”

So, the `next_symbol` function should output the next (third) word of the translation. Here are the proposed steps to get it:

```python
    # set the length of the current output tokens
    token_length = ...
```

Here your task is to get the current output length: how many tokens have been translated so far (in our example, 2).
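As a minimal sketch, assuming `cur_output_tokens` is the plain Python list described in the docstring, this step is just its length:

```python
# Example values from this walkthrough ("Ich liebe" tokenized)
cur_output_tokens = [161, 12202]

# number of tokens translated so far
token_length = len(cur_output_tokens)
print(token_length)  # 2
```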

```python
    # calculate next power of 2 for padding length
```

Here your task is to determine how long the padded output should be (one of 1, 2, 4, 8, 16, 32, 64, … etc.). As the hint suggests, use 2^ceil(log2(token_length + 1)), which in this case is 4 (concretely: log2(2 + 1) ≈ 1.585; roundup(1.585) = 2; 2^2 = 4).
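The arithmetic above can be sketched with NumPy (the `+ 1` leaves room for the token we are about to predict):

```python
import numpy as np

token_length = 2  # from the previous step

# smallest power of 2 that can hold token_length + 1 tokens
padded_length = int(2 ** np.ceil(np.log2(token_length + 1)))
print(padded_length)  # 4
```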

```python
    # pad cur_output_tokens up to the padded_length
```

This should pad the current output tokens with zeros, e.g. [161, 12202, 0, 0].
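One way to sketch the padding, using the example values from above:

```python
cur_output_tokens = [161, 12202]
token_length = len(cur_output_tokens)
padded_length = 4  # from the previous step

# append zeros until the list has padded_length entries
padded = cur_output_tokens + [0] * (padded_length - token_length)
print(padded)  # [161, 12202, 0, 0]
```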

```python
    # model expects the output to have an axis for the batch size in front so
    # convert `padded` list to a numpy array with shape (1, <padded_length>)
```

Here you just need to add a batch dimension, e.g. array([[161, 12202, 0, 0]]).
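Adding the batch axis can be done with NumPy, for example:

```python
import numpy as np

padded = [161, 12202, 0, 0]

# add a leading batch dimension: shape (4,) -> shape (1, 4)
padded_with_batch = np.array(padded)[None, :]
print(padded_with_batch.shape)  # (1, 4)
```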

```python
    # get the model prediction
    output, _ = ...
```

Here you should make use of the passed `NMTAttn` model, which outputs all the log probabilities for the translation as an array of shape (1, 4, 33300). For example:

```python
DeviceArray([[[-16.068834 , -12.784518 , -10.474875 , ..., -16.060408 ,
               -16.08047  , -16.057804 ],
              [-18.527576 , -16.917854 , -10.098618 , ..., -18.50307  ,
               -18.522917 , -18.516605 ],
              [-18.028234 , -13.601976 ,  -7.3021574, ..., -18.02614  ,
               -17.998901 , -18.003044 ],
              [-14.261356 , -10.109584 ,  -4.4620075, ..., -14.248788 ,
               -14.245432 , -14.233846 ]]], dtype=float32)
```

```python
    # get log probabilities from the last token output
    log_probs = output[...]
```

Here you need to get the log probabilities just for the third word, indexing with [batch = 0, third word position = token_length, all the vocabulary entries = `:`]. This way you pluck out only the third word’s log probs, which in this example are:

```python
DeviceArray([-18.028234 , -13.601976 ,  -7.3021574, ..., -18.02614  ,
             -17.998901 , -18.003044 ], dtype=float32)
```

```python
    # get the next symbol by getting a logsoftmax sample (*hint: cast to an int)
    symbol = ...
```

Here you need to get (or sample) the actual word for this translation with the help of trax’s `tl.logsoftmax_sample(...)`, converting the result to an `int`. In this example, the max log prob is -0.5954628, at index 5112 (not visible in the truncated array above), which is the word “Sprachen”.

So out of 33300 tokens the most probable one is at index 5112, with a log prob of -0.5954628 (~55% probability). If the `temperature` parameter is 0, your prediction would be exactly this token.
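As a sanity check on the numbers above, and a sketch of the `temperature = 0` case (where sampling degenerates to argmax; the actual assignment uses `tl.logsoftmax_sample(log_probs, temperature)` from trax):

```python
import math
import numpy as np

# a log prob of -0.5954628 corresponds to roughly 55% probability
print(math.exp(-0.5954628))  # ~0.551

# with temperature 0, sampling reduces to argmax over the log probs;
# toy log-probs for illustration (not the real 33300-entry vector)
log_probs = np.array([-18.0, -13.6, -0.5954628, -18.0])
symbol = int(np.argmax(log_probs))
print(symbol)  # 2
```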

Cheers!
