Error in training.loop even after passing all tests

After passing all the tests, I am getting the following error while running the code for the training loop:

It's not easy to trace this error, but the traceback says something is wrong in eval_shape. I would start by checking the encoder and decoder architecture, since there may be a mistake there, but that doesn't rule out the other parts above it either.
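In JAX, `jax.eval_shape` traces a function with abstract inputs and returns only the output shapes and dtypes, without computing anything. A shape mismatch anywhere inside the traced function surfaces as an error at that point. One way to narrow down which part is at fault is to call `jax.eval_shape` on each stage separately. The sketch below assumes a toy model: `encoder`, `decoder`, the parameter names and all sizes are hypothetical stand-ins, not the assignment's actual layers.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy encoder/decoder standing in for the real model layers.
def encoder(params, x):
    # (batch, seq_len, d_in) -> (batch, seq_len, d_model)
    return jnp.tanh(x @ params["w_enc"])

def decoder(params, h):
    # (batch, seq_len, d_model) -> (batch, seq_len, vocab)
    return h @ params["w_dec"]

def model(params, x):
    return decoder(params, encoder(params, x))

batch, seq_len, d_in, d_model, vocab = 2, 5, 8, 16, 10  # assumed sizes
params = {
    "w_enc": jax.ShapeDtypeStruct((d_in, d_model), jnp.float32),
    "w_dec": jax.ShapeDtypeStruct((d_model, vocab), jnp.float32),
}
x = jax.ShapeDtypeStruct((batch, seq_len, d_in), jnp.float32)

# eval_shape traces abstractly: no real arrays are allocated or computed.
enc_out = jax.eval_shape(encoder, params, x)
out = jax.eval_shape(model, params, x)
print(enc_out.shape, out.shape)  # (2, 5, 16) (2, 5, 10)

# A deliberately wrong decoder weight: contraction dim 32 != d_model 16.
bad_params = dict(params, w_dec=jax.ShapeDtypeStruct((32, vocab), jnp.float32))
raised = False
try:
    jax.eval_shape(model, bad_params, x)
except Exception as e:  # JAX reports the mismatched dimensions in the message
    raised = True
    print(type(e).__name__, e)
```

Checking each stage in isolation like this shows which layer produces an output shape that the next layer does not expect.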

I don't know much about JAX, and since all the unit tests have passed, do you have any suggestions on how I should start debugging this problem?

The above is my suggestion at this point!

The problem was in the masking code. I was able to solve it, thanks!
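The thread doesn't say what the masking bug was. A mask with the wrong shape can fail to broadcast against the attention scores, and JAX then reports a shape error during tracing rather than in a unit test that only checks one function. The following is a generic sketch, assuming Transformer-style attention with scores shaped `(batch, heads, q_len, k_len)` and padding id `0`. The helpers `padding_mask`, `causal_mask` and `apply_mask` are hypothetical, not the assignment's code.

```python
import jax
import jax.numpy as jnp

def padding_mask(token_ids, pad_id=0):
    # (batch, seq_len) -> (batch, 1, 1, seq_len); True where the key is a real token.
    return (token_ids != pad_id)[:, None, None, :]

def causal_mask(seq_len):
    # (1, 1, seq_len, seq_len); query i may attend only to keys j <= i.
    return jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))[None, None, :, :]

def apply_mask(logits, mask):
    # logits: (batch, heads, q_len, k_len). Masked positions get a large negative
    # value so softmax gives them (numerically) zero weight.
    return jnp.where(mask, logits, -1e9)

tokens = jnp.array([[5, 7, 9, 0]])            # one sequence; last position is padding
mask = padding_mask(tokens) & causal_mask(4)  # broadcasts to (1, 1, 4, 4)
logits = jnp.zeros((1, 2, 4, 4))              # hypothetical attention scores
weights = jax.nn.softmax(apply_mask(logits, mask), axis=-1)
print(mask.shape)        # (1, 1, 4, 4)
print(weights[0, 0, 1])  # query 1 attends to keys 0 and 1 only
```

The singleton axes in each mask are what let it broadcast against the per-head scores. Using a large negative number instead of `-inf` keeps a fully masked row from turning into NaNs in the softmax.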
