Programming Assignment Section 3.3 Loop: JAX Tracer Array Conversion Error

Hi, I am having the same issue. My Lab ID is iiipujei. When I run the following code I get the error below:

Cell/Code:

training_loop = training.Loop(NMTAttn(mode='train'),
                              train_task,
                              eval_tasks=[eval_task],
                              output_dir=output_dir)

Error message:

jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(bool[32])>with<DynamicJaxprTrace(level=1/0)> (see the JAX Errors page in the JAX documentation)

Would love any help you could offer. All other tests pass and I’ve updated my workbook to the latest version.

Hi, I faced a similar error. You can try the approach that worked for me.
In UNQ_C3, I was building the mask with mask=np.array(), and that was causing the issue.
I changed it to mask=fastnp.array() and it worked for me.

Hi Mohit:

I ran into the same problem. Using fastnp.where() instead of np.where() solved the problem!
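
In case it helps anyone else, here is a minimal sketch of why the swap matters. It is not the assignment code: the function names and the toy data are made up for illustration, and I am assuming that fastnp behaves like jax.numpy when Trax runs on the JAX backend, so jax.numpy stands in for it here. Plain NumPy tries to turn the traced array into a concrete numpy.ndarray, which JAX cannot do while it is tracing the function.

```python
import numpy as np
import jax
import jax.numpy as jnp  # assumption: stands in for Trax's fastnp on the JAX backend


def masked_mean_numpy(values, targets):
    # Hypothetical helper: np.where asks the traced boolean array for a
    # concrete numpy.ndarray, which is not available during tracing.
    mask = np.where(targets != 0, 1.0, 0.0)
    return (values * mask).sum() / mask.sum()


def masked_mean_jax(values, targets):
    # Same logic, but jnp.where operates on traced arrays directly.
    mask = jnp.where(targets != 0, 1.0, 0.0)
    return (values * mask).sum() / mask.sum()


# Toy data: zeros in `targets` play the role of padding.
values = jnp.array([1.0, 2.0, 3.0, 4.0])
targets = jnp.array([5, 0, 7, 0])

try:
    jax.jit(masked_mean_numpy)(values, targets)
    numpy_version_failed = False
except jax.errors.TracerArrayConversionError:
    numpy_version_failed = True

result = float(jax.jit(masked_mean_jax)(values, targets))
print(numpy_version_failed, result)
```

The same reasoning should apply to np.array() and other plain NumPy calls on values that come from the model or loss inputs: inside the training loop those values are tracers, so the fastnp version is the one to reach for.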

Cheers,
Drew