Hi, I’m stuck on the next_symbol function.
I have tried multiple variants of log_probs, but I keep getting the same (or a very similar) error:
AssertionError Traceback (most recent call last)
in
1 # UNIT TEST
2 # test_next_symbol
----> 3 w1_unittest.test_next_symbol(next_symbol, NMTAttn)
~/work/w1_unittest.py in test_next_symbol(target, model)
557 next_de_tokens = target(the_model, tokens_en, [18477], 0.0)
558 # print('next_de_tokens', next_de_tokens)
--> 559 assert np.allclose([next_de_tokens[0], next_de_tokens[1]], [7283, -9.929085731506348]), f"Expected output: [{7283}, {-9.929085731506348}], your output: [{next_de_tokens[0]}, {next_de_tokens[1]}]"
560
561 print("\033[92m All tests passed")
AssertionError: Expected output: [7283, -9.929085731506348], your output: [7283, -2.295468330383301]
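For reference, here is a minimal sketch of what the failing assertion compares, using only the two value pairs from the error message. The names `expected` and `actual` are my own and do not come from w1_unittest.py. It shows that the predicted token matches and only the returned log probability is off:

```python
import numpy as np

# Values copied from the AssertionError message (not recomputed from any model).
expected = [7283, -9.929085731506348]
actual = [7283, -2.295468330383301]

# The same kind of comparison the unit test makes, with np.allclose defaults.
print(np.allclose(actual, expected))  # False

# Element-wise view: the token agrees, the log probability does not.
token_matches = actual[0] == expected[0]
log_prob_matches = bool(np.isclose(actual[1], expected[1]))
print(token_matches, log_prob_matches)  # True False
```

So the sampled symbol looks right, and the problem seems to be limited to which log_probs value I return alongside it.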