After passing all the tests, I am getting the following error while running the code for the training loop:
It's not easy to trace this error, but if you look closely, it says there is something wrong with `eval_shape`. I would start with the encoder and decoder architecture, since there might be a mistake there, but that doesn't rule out the other parts above it either.
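`jax.eval_shape` traces a function with abstract inputs and returns only the output shapes and dtypes, without doing any real computation. That makes it a cheap way to check one layer at a time for shape mismatches. Here is a minimal sketch, assuming a hypothetical `attention_scores` function standing in for one piece of the model (it is not the assignment's code):

```python
import jax
import jax.numpy as jnp

def attention_scores(q, k):
    # Hypothetical layer: q is (batch, len_q, d), k is (batch, len_k, d).
    # Returns scaled dot-product scores of shape (batch, len_q, len_k).
    return jnp.einsum("bqd,bkd->bqk", q, k) / jnp.sqrt(q.shape[-1])

# Abstract inputs: only shape and dtype, no actual data is allocated.
q = jax.ShapeDtypeStruct((2, 5, 8), jnp.float32)
k = jax.ShapeDtypeStruct((2, 7, 8), jnp.float32)

out = jax.eval_shape(attention_scores, q, k)
print(out.shape, out.dtype)  # (2, 5, 7) float32

# A mismatched feature dimension fails at trace time, which points at the culprit.
bad_k = jax.ShapeDtypeStruct((2, 7, 4), jnp.float32)
try:
    jax.eval_shape(attention_scores, q, bad_k)
except (TypeError, ValueError) as err:
    print("shape error:", err)
```

Running each sub-layer (encoder block, decoder block, mask construction) through a check like this with the shapes the training loop actually uses can narrow down which part breaks, even when the unit tests use different shapes.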
I don't have any knowledge of JAX, and since all the unit tests have passed, do you have any suggestions on how I should start debugging this problem?
The above is my suggestion at this point!
The problem was with the masking code, and I was able to solve it. Thanks!
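For readers hitting something similar: the thread doesn't say what the masking bug was, but two common ones are an inverted mask (mixing up whether `True`/1 means "keep" or "block") and a mask whose shape doesn't broadcast against the attention scores. A minimal sketch of a causal mask applied before the softmax, using hypothetical helper names `causal_mask` and `masked_softmax` (not the assignment's functions):

```python
import jax
import jax.numpy as jnp

def causal_mask(length):
    # True where query position i may attend to key position j (j <= i).
    return jnp.tril(jnp.ones((length, length), dtype=bool))

def masked_softmax(scores, mask):
    # Blocked positions get a very large negative logit so softmax gives them ~0 weight.
    # The mask must broadcast against scores of shape (batch, len_q, len_k).
    scores = jnp.where(mask, scores, -1e9)
    return jax.nn.softmax(scores, axis=-1)

scores = jnp.zeros((1, 4, 4))                 # (batch, len_q, len_k), all equal logits
mask = causal_mask(4)[None, :, :]             # add a batch axis: (1, 4, 4)
weights = masked_softmax(scores, mask)
print(weights[0])
```

With equal logits, each row spreads its weight evenly over the positions it is allowed to see (row 1 becomes `[0.5, 0.5, 0, 0]`), which makes an inverted or misshaped mask easy to spot.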
