Question on reshape

I would like to understand the difference between these two lines in the compute_attention_output function of the assignment:

x = jnp.reshape(x, (-1, n_heads, seqlen, d_head))

x = jnp.reshape(x, (x.shape[0] / n_heads, n_heads, seqlen, d_head))

The first one produces no error, while the second one raises an error. Is there a difference in how they are implemented? Thank you!
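One likely cause, independent of the test case issue mentioned in the reply: in Python 3, `/` is true division and always returns a float (6 / 3 is 2.0, not 2). `jnp.reshape` only accepts integer dimensions, so the second line fails on that float. Passing `-1` instead tells reshape to infer that axis, and floor division `//` keeps the computed size an int. The sketch assumes x has shape (batch * n_heads, seqlen, d_head) and uses made-up sizes. It illustrates the arithmetic and is not the assignment's code.

```python
import jax.numpy as jnp

# Hypothetical sizes for illustration; the assignment's real values may differ.
batch, n_heads, seqlen, d_head = 2, 3, 4, 5

# Assumed input layout: the heads are folded into the leading axis.
x = jnp.zeros((batch * n_heads, seqlen, d_head))

# -1 asks reshape to infer this axis: (6 * 4 * 5) / (3 * 4 * 5) = 2.
a = jnp.reshape(x, (-1, n_heads, seqlen, d_head))

# `/` is true division in Python 3, so this is the float 2.0, not the int 2.
print(x.shape[0] / n_heads)  # 2.0

# Floor division `//` keeps the result an int, so this behaves like -1.
b = jnp.reshape(x, (x.shape[0] // n_heads, n_heads, seqlen, d_head))

# The float version fails because reshape requires integer dimensions.
try:
    jnp.reshape(x, (x.shape[0] / n_heads, n_heads, seqlen, d_head))
except TypeError as err:
    print("TypeError:", err)
```

Here `a` and `b` end up with the same shape. The `-1` form has the small advantage that you never compute the leading size by hand, so it cannot pick up a float by accident.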

Hi @roger.lee

There is a bug in a test case:

I think you are referring to this error (an error in the unit test case, not in your code)?