My solution to # UNQ_C3 passes all the tests for the prepare_attention_input function.
However, when I run the Loop code cell, I get this error:
The numpy.ndarray conversion method __array__() was called on the JAX Tracer object. The error is raised from prepare_attention_input. Searching for fixes, I found that replacing numpy functions with their JAX equivalents fixes it. However, I haven't used any numpy functions.
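A likely reason the unit tests pass while the Loop cell fails: the tests call the function directly on concrete arrays, whereas the training loop probably compiles the model with jax.jit, so inside the function every array is a JAX Tracer. Any call that forces a Tracer into a real numpy.ndarray then fails. Below is a minimal sketch of that behavior using plain jax.jit; the function names are hypothetical and not part of the assignment.

```python
import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def mask_with_numpy(x):
    # Under jit, x is a JAX Tracer; np.array() tries to turn it into a
    # concrete numpy.ndarray and raises TracerArrayConversionError.
    return np.array(x > 0)

@jax.jit
def mask_with_jnp(x):
    # The same computation with jax.numpy stays a traced value, so it compiles.
    return jnp.asarray(x > 0)

x = jnp.array([-1.0, 2.0, 3.0])
print(mask_with_jnp(x).tolist())  # [False, True, True]

try:
    mask_with_numpy(x)
except jax.errors.TracerArrayConversionError as err:
    print(type(err).__name__)  # TracerArrayConversionError
```

Calling mask_with_numpy without jax.jit would work fine, which mirrors why the standalone tests did not catch the problem.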
I’m unable to progress through the assignment.
@Mubsi could you please take a look?
I was able to figure it out. I was using a NumPy array for the mask and had to use a fastnp array instead. The JAX implementation of numpy.nonzero is a bit contrived. For anyone else facing a similar issue, here is the link to the docs
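A rough sketch of both points, assuming Trax's default JAX backend, where fastnp (trax.fastmath.numpy) should dispatch to jax.numpy, so plain jax.numpy is used here to keep it self-contained. PAD_ID = 0 and the (batch, 1, 1, length) mask shape are assumptions for illustration, not the assignment's required solution. The nonzero part shows why jnp.nonzero feels contrived: under jit the output shape must be static, so it needs an explicit size and pads the unused slots with fill_value.

```python
import jax
import jax.numpy as jnp

PAD_ID = 0  # hypothetical: assumes token id 0 marks padding

@jax.jit
def padding_mask(inputs):
    # inputs: (batch, length) integer token ids.
    # The comparison runs on a traced jax array, so no numpy conversion happens.
    mask = inputs != PAD_ID
    # Insert broadcastable axes -> (batch, 1, 1, length), a common attention-mask layout.
    return jnp.reshape(mask, (mask.shape[0], 1, 1, mask.shape[1]))

@jax.jit
def nonpad_positions(row):
    # Under jit, jnp.nonzero needs a static `size`; unused slots get fill_value.
    (idx,) = jnp.nonzero(row != PAD_ID, size=row.shape[0], fill_value=-1)
    return idx

tokens = jnp.array([[5, 7, 0, 0], [3, 0, 0, 0]])
print(padding_mask(tokens).shape)                # (2, 1, 1, 4)
print(nonpad_positions(tokens[0]).tolist())      # [0, 1, -1, -1]
```

Often a boolean mask from a comparison is all you need, which sidesteps nonzero entirely.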
For anyone struggling with this function: I used JAX's where function, which apparently works even if there are 'if' statements in the function itself.
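A minimal sketch of why jnp.where helps, with hypothetical function names: under jit, a Python if on an array value needs a concrete bool and fails, while jnp.where selects elementwise and stays traceable. An if on a shape is still fine, because shapes are static Python ints during tracing, which may be why 'if' statements in the function did not cause trouble.

```python
import jax
import jax.numpy as jnp

@jax.jit
def relu_with_if(x):
    # A Python `if` needs a concrete bool; under jit `x < 0` is a tracer,
    # so tracing this raises a ConcretizationTypeError (or a subclass of it).
    if x < 0:
        return jnp.zeros_like(x)
    return x

@jax.jit
def relu_with_where(x):
    # jnp.where chooses elementwise between the two branches and stays traceable.
    return jnp.where(x < 0, 0.0, x)

@jax.jit
def pad_if_short(x):
    # An `if` on a shape is fine: x.shape[0] is a plain Python int while tracing.
    if x.shape[0] < 4:
        return jnp.pad(x, (0, 4 - x.shape[0]))
    return x

print(relu_with_where(jnp.array([-1.0, 2.0])).tolist())  # [0.0, 2.0]
print(pad_if_short(jnp.array([1.0, 2.0])).shape)  # (4,)
```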