UNQ_C3 Mask Implementation

I am getting an inexplicable error in the mask implementation.

    mask = np.ones(inputs.shape)
    mask[inputs==0]=0

This gives me a dimension mismatch error.

How is this possible? Please help.

You might want to look at numpy.where(), or rather the trax fastmath version of it:
https://numpy.org/doc/stable/reference/generated/numpy.where.html

Its arguments are (condition, value where condition is True, value where condition is False).
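Here is a minimal sketch of the where() approach. It uses plain NumPy so it runs standalone. I'm assuming that in the assignment np would be trax.fastmath.numpy, which follows the same where(condition, x, y) signature. The helper name make_padding_mask and the sample batch are hypothetical. Because where() returns a new array instead of assigning into an existing one, the same pattern also works on JAX arrays, which do not support in-place item assignment.

```python
import numpy as np  # stand-in here for trax.fastmath.numpy (assumed to mirror where())


def make_padding_mask(inputs):
    """Hypothetical helper: 0.0 where inputs is padding (0), 1.0 elsewhere.

    Builds a new array with where() instead of assigning into an existing one.
    """
    # where(condition, value_if_true, value_if_false), applied elementwise;
    # the scalars 0.0 and 1.0 broadcast to the shape of inputs
    return np.where(inputs == 0, 0.0, 1.0)


# Hypothetical batch of token ids, right-padded with 0
batch = np.array([[5, 7, 2, 0, 0],
                  [3, 1, 0, 0, 0]])
mask = make_padding_mask(batch)
print(mask)
```

The resulting mask has the same shape as batch, with 1.0 at real tokens and 0.0 at padding positions.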