How does np.squeeze work?

Hi,

I am trying to understand why
classes[np.squeeze(train_set_y[:, index])]
outputs b'non-cat' or b'cat'.

Specifically, why do we need np.squeeze here?

Thanks


np.squeeze removes axes of length one.
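A minimal sketch of that behavior, using a made-up array `a` (not from the assignment). The optional `axis` argument of np.squeeze restricts squeezing to specific length-one axes:

```python
import numpy as np

# Hypothetical array with two length-one axes (axes 0 and 2).
a = np.zeros((1, 3, 1))

print(np.squeeze(a).shape)          # (3,): every length-one axis removed
print(np.squeeze(a, axis=0).shape)  # (3, 1): only axis 0 removed
print(np.squeeze(a, axis=2).shape)  # (1, 3): only axis 2 removed

# Axes whose length is not one are never removed; asking for one raises.
try:
    np.squeeze(a, axis=1)
except ValueError as err:
    print("ValueError:", err)
```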

train_set_y has shape (1, m), so train_set_y[:, index].shape is (1,): the : slice keeps the first axis.
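Indexing classes with that shape-(1,) array returns a one-element array such as [b'cat'] rather than the value b'cat' itself. To concatenate strings we need the value, not the array; hence, we squeeze the index down to the scalar (a 0-d array) it contains.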
Alternatively, you can access the scalar directly with train_set_y[0, index], in which case np.squeeze is unnecessary.
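To see the difference concretely, here is a small sketch. `train_set_y` and `classes` below are hypothetical stand-ins (five made-up labels in a (1, m) row and a two-entry byte-string array), not the real dataset, and the final concatenation is only an illustration:

```python
import numpy as np

# Hypothetical stand-ins: labels stored as a (1, m) row, class names as byte strings.
train_set_y = np.array([[0, 1, 1, 0, 1]])
classes = np.array([b'non-cat', b'cat'])
index = 2

row = train_set_y[:, index]
print(row.shape)                        # (1,): the ':' keeps axis 0

# Indexing with a shape-(1,) array returns a one-element array.
print(classes[row])                     # [b'cat']

# Squeezing gives a 0-d index, so classes[...] returns the bytes value itself.
label = classes[np.squeeze(row)]
print(label)                            # b'cat'
print("it's a '" + label.decode("utf-8") + "' picture.")

# Integer indices on both axes give the scalar directly; no squeeze needed.
direct = classes[train_set_y[0, index]]
print(direct == label)                  # True
```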

Example:

```python
import numpy as np

two = np.array([[[[[[2]]]]]])
print(two.shape, np.squeeze(two).shape)
```

yields

```
(1, 1, 1, 1, 1, 1) ()
```

i.e., a zero-dimensional array holding a single scalar in the last case.
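One nuance worth checking: the squeezed result is still a NumPy array, just with zero dimensions. A small sketch (redefining `two` so it runs on its own) showing that it behaves like a scalar and how `.item()` extracts a plain Python number:

```python
import numpy as np

two = np.array([[[[[[2]]]]]])
s = np.squeeze(two)

print(type(s), s.ndim)   # <class 'numpy.ndarray'> 0
print(s.item())          # 2, as a plain Python int
print(s + 1)             # 3: a 0-d array behaves like a scalar in arithmetic
```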


Thanks for the clear explanation!
