C4W1 Assignment prepare_attention_input jax error

Hello all,

My solution to # UNQ_C3 passes all the tests for the prepare_attention_input function.

However, when I run the Loop code cell, I get this error:
The numpy.ndarray conversion method __array__() was called on the JAX Tracer object.
It comes from prepare_attention_input. Searching for fixes, I found that replacing NumPy functions with their JAX equivalents fixes it. However, I haven't used any NumPy functions.
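For context, a minimal sketch of where this error comes from, assuming (as I understand it) that the training Loop compiles the model with jax.jit, so inside the layer the inputs are JAX tracers rather than concrete arrays. Any NumPy call that tries to turn a tracer into a numpy.ndarray raises the error, even if it is just np.array(...) around a mask. The function names make_mask_numpy and make_mask_jax are hypothetical, not from the assignment.

```python
import numpy as np
import jax
import jax.numpy as jnp

def make_mask_numpy(tokens):
    # np.array() calls tracer.__array__(), which JAX forbids under jit:
    # this raises TracerArrayConversionError during tracing.
    return np.array(tokens != 0)

def make_mask_jax(tokens):
    # jnp operations stay inside JAX, so they work on traced values.
    return jnp.asarray(tokens != 0)

tokens = jnp.array([[5, 3, 0, 0], [7, 0, 0, 0]])

try:
    jax.jit(make_mask_numpy)(tokens)
except jax.errors.TracerArrayConversionError as err:
    print("NumPy version failed:", type(err).__name__)

mask = jax.jit(make_mask_jax)(tokens)
print(mask)
```

The same code runs fine without jax.jit, which would explain why the unit tests pass while the Loop cell fails.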

I’m unable to progress through the assignment.

@Mubsi, could you please take a look?

Thank you

I was able to figure it out. I was using a NumPy array for the mask and had to use a fastnp array instead. JAX's implementation of numpy.nonzero is a bit contrived: its output shape depends on the data, so under jit it needs a static size argument. For anyone else facing a similar issue, here is the link to the docs
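A small sketch of the nonzero point, written with jax.numpy directly. Assumption: in the notebook, fastnp is trax.fastmath.numpy, which dispatches to jax.numpy under the default JAX backend. The function name padding_positions is hypothetical.

```python
import jax
import jax.numpy as jnp

@jax.jit
def padding_positions(tokens):
    # Under jit, jnp.nonzero cannot size its output from the data, so a
    # static `size` is passed; unused slots are filled with fill_value.
    (idx,) = jnp.nonzero(tokens == 0, size=tokens.shape[0], fill_value=-1)
    return idx

positions = padding_positions(jnp.array([4, 9, 0, 0]))
print(positions)
```

tokens.shape[0] is a plain Python int even under jit, which is why it can serve as the static size.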

For anyone struggling with this function: I used JAX's where function, which apparently works even though the function itself contains 'if' statements.
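A minimal sketch of the where approach, assuming 0 is the padding token ID. A Python if on a traced value needs a concrete True/False, which a tracer cannot provide under jit; jnp.where does the selection elementwise instead. The function name padding_mask is hypothetical.

```python
import jax
import jax.numpy as jnp

@jax.jit
def padding_mask(tokens):
    # Elementwise "if token != 0 then 1.0 else 0.0", valid on tracers.
    return jnp.where(tokens != 0, 1.0, 0.0)

mask = padding_mask(jnp.array([[5, 3, 0], [7, 0, 0]]))
print(mask)
```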

Same issue here.
I solved it by using ones and minimum.
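One way to read the ones/minimum approach, as a sketch only: assuming token IDs are non-negative integers with 0 as padding, min(id, 1) is 1 for every real token and 0 for padding, so no comparison or branching is needed. The function name padding_mask_min is hypothetical.

```python
import jax
import jax.numpy as jnp

@jax.jit
def padding_mask_min(tokens):
    # Assumes non-negative IDs with 0 = padding: min(id, 1) -> 1 or 0.
    return jnp.minimum(tokens, jnp.ones_like(tokens))

mask = padding_mask_min(jnp.array([[12, 4, 0, 0], [9, 0, 0, 0]]))
print(mask)
```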

At first I couldn't spot the cause, but after googling it I figured it out. Thanks.