@Mubsi @balaji.ambresh I am also having problems with the unit_tests: I get 24/6 tests passed. Initially, I get:
(DeviceArray([[49, 50, 51, 52, 53, 54, 55, 56, 57, 1],
[50, 51, 52, 53, 54, 55, 56, 57, 48, 1]], dtype=int32),
DeviceArray([[49, 50, 51, 52, 53, 54, 55, 56, 57, 1],
[50, 51, 52, 53, 54, 55, 56, 57, 48, 1]], dtype=int32),
DeviceArray([1, 1], dtype=int32))
Expected output:
(DeviceArray([[49, 50, 51, 52, 53, 54, 55, 56, 57, 1],
[50, 51, 52, 53, 54, 55, 56, 57, 48, 1]], dtype=int32),
DeviceArray([[49, 50, 51, 52, 53, 54, 55, 56, 57, 1],
[50, 51, 52, 53, 54, 55, 56, 57, 48, 1]], dtype=int32),
DeviceArray([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=int32))
Element with index 2 in the output tuple has incorrect shape. It should be (batch_size, max_length).
Expected (2, 10).
Got (2,).
I assume the element with index 2 is mask_np_arr?
I use np.where() to create example_mask.
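For reference, here is a minimal sketch of the shape the test expects. It assumes that each character is encoded with ord(), that an EOS id of 1 is appended, and that padding uses 0. These details are inferred from the printed arrays above, not taken from the lab code. pad_and_mask is a hypothetical helper, not the assignment's function. The key point is that np.where is called in its three-argument form on the whole 2-D padded batch, so the mask has shape (batch_size, max_length) instead of one value per example.

```python
import numpy as np

def pad_and_mask(lines, max_length, pad_id=0, eos_id=1):
    """Hypothetical helper (not the lab's function): encode each line as
    character codes plus EOS, pad to max_length, and build a matching mask."""
    batch = []
    for line in lines:
        # Assumed encoding: ord() per character, then the EOS id.
        ids = [ord(c) for c in line] + [eos_id]
        # Right-pad with pad_id up to max_length.
        ids = ids + [pad_id] * (max_length - len(ids))
        batch.append(ids)
    batch_np_arr = np.array(batch, dtype=np.int32)
    # Three-argument np.where works elementwise on the full 2-D array, so the
    # mask keeps shape (batch_size, max_length): 1 for real tokens, 0 for padding.
    mask_np_arr = np.where(batch_np_arr != pad_id, 1, 0).astype(np.int32)
    return batch_np_arr, batch_np_arr, mask_np_arr

inputs, targets, mask = pad_and_mask(["123456789", "234567890"], max_length=10)
print(mask.shape)
```

With the two lines "123456789" and "234567890" and max_length=10 there is no padding, so the mask is all ones with shape (2, 10), which matches the expected output above.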
My lab ID is phpssgbp.
Thanks,
Drew