C4_W2 compute_attention_heads transpose

Hi all. For the C4_W2 compute_attention_heads assignment, I don't understand how to use jnp.transpose.

When I call x = jnp.transpose(x, (-1, n_heads, seqlen, d_head)), I receive the error:

transpose permutation isn't a permutation of operand dimensions, got permutation (-1, 2, 2, 3) for operand shape (3, 2, 2, 3).

Before I call reshape, the shape of x is (3, 2, 6). After reshape, the shape of x is (3, 2, 2, 3). Should this be the shape of x?
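Here is a minimal NumPy sketch of what goes wrong. It assumes the goal of this step is to move the n_heads axis ahead of seqlen, i.e. (batch_size, seqlen, n_heads, d_head) → (batch_size, n_heads, seqlen, d_head). The sizes are hypothetical, chosen to match the shapes in the post. Under that assumption, (3, 2, 2, 3) is a reasonable shape after the reshape; the problem is the transpose call, which was given target sizes where it expects axis indices. jnp.transpose takes the same axes argument. NumPy reports the mistake as a ValueError, while JAX gives the message quoted above.

```python
import numpy as np

# Hypothetical sizes matching the post: batch_size=3, seqlen=2, n_heads=2, d_head=3,
# so the feature dimension is n_heads * d_head = 6.
batch_size, seqlen, n_heads, d_head = 3, 2, 2, 3

x = np.arange(batch_size * seqlen * n_heads * d_head).reshape(batch_size, seqlen, n_heads * d_head)
print(x.shape)  # (3, 2, 6)

# Split the feature axis into (n_heads, d_head).
x = x.reshape(batch_size, seqlen, n_heads, d_head)
print(x.shape)  # (3, 2, 2, 3)

# Wrong: passing sizes. axes must be a permutation of the axis indices 0..ndim-1,
# and (-1, 2, 2, 3) normalizes to (3, 2, 2, 3), which repeats axes.
try:
    np.transpose(x, (-1, n_heads, seqlen, d_head))
    failed = False
except ValueError:
    failed = True
print(failed)  # True

# Right: passing axis indices. Swap axis 1 (seqlen) with axis 2 (n_heads).
y = np.transpose(x, (0, 2, 1, 3))
print(y.shape)  # (3, 2, 2, 3); unchanged only because seqlen == n_heads here

# The data did move: y[b, h, s, d] == x[b, s, h, d].
print(y[0, 1, 0, 0] == x[0, 0, 1, 0])  # True
```

Because seqlen and n_heads happen to be equal in this example, the shape alone cannot confirm that the transpose is correct. Comparing individual elements, as in the last line, can.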

I think I figured this out.

For others seeing this: I had to look at NumPy's documentation:
https://numpy.org/doc/stable/reference/generated/numpy.transpose.html

Specifically, look at this example. The second argument is a permutation of axis indices, not a list of the sizes you want:

import numpy as np
a = np.ones((1, 2, 3))
np.transpose(a, (1, 0, 2)).shape  # (2, 1, 3)
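To read that example: entry i of the axes tuple names the input axis that becomes output axis i, so each of 0..ndim-1 must appear exactly once. A small NumPy check, using made-up distinct sizes so every axis is easy to track:

```python
import numpy as np

# Hypothetical distinct sizes so each axis is recognizable by its length.
a = np.zeros((4, 5, 6, 7))
axes = (0, 2, 1, 3)  # swap axes 1 and 2, keep 0 and 3 in place

b = np.transpose(a, axes)
print(b.shape)  # (4, 6, 5, 7)

# Rule: output axis i is input axis axes[i].
print(all(b.shape[i] == a.shape[axes[i]] for i in range(a.ndim)))  # True
```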