Hi all. For the C4_W2 compute_attention_heads assignment, I don't understand how to use jnp.transpose.
When I call x = jnp.transpose(x, (-1, n_heads, seqlen, d_head)), I receive the error:
transpose permutation isn't a permutation of operand dimensions, got permutation (-1, 2, 2, 3) for operand shape (3, 2, 2, 3).
Before I call reshape, the shape of x is (3, 2, 6). After reshape, the shape of x is (3, 2, 2, 3). Should this be the shape of x?
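
For reference, here is a minimal sketch of how the two calls differ, using the shapes from the error message. It is not the assignment solution. The variable values (batch of 3, seqlen 2, n_heads 2, d_head 3) are read off the error message, and the axis order (0, 2, 1, 3) is an assumption about the intended result, namely that the heads axis should move ahead of the sequence axis as the size tuple in the call above suggests. The point is that jnp.reshape takes the new dimension sizes, while jnp.transpose takes a permutation of the axis indices 0..ndim-1.

```python
import jax.numpy as jnp

# Sizes inferred from the error message in the post (assumed, not from the notebook).
batch, seqlen, n_heads, d_head = 3, 2, 2, 3

# Dummy input with the shape reported before reshape.
x = jnp.arange(batch * seqlen * n_heads * d_head).reshape(batch, seqlen, n_heads * d_head)
print(x.shape)  # (3, 2, 6)

# reshape takes dimension SIZES; -1 asks it to infer the batch size.
x = jnp.reshape(x, (-1, seqlen, n_heads, d_head))
print(x.shape)  # (3, 2, 2, 3)

# transpose takes axis INDICES: each of 0, 1, 2, 3 exactly once.
# (-1, 2, 2, 3) repeats 2 and omits 0 and 1, so it is not a valid permutation.
# Assumed goal: swap the seqlen axis (1) and the n_heads axis (2).
y = jnp.transpose(x, (0, 2, 1, 3))
print(y.shape)  # (3, 2, 2, 3)
```

Because seqlen and n_heads are both 2 in this test case, the shape looks unchanged after the transpose, but the elements have moved: y[b, h, s] now holds what x[b, s, h] held.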