I would like to understand the difference between these two lines in compute_attention_output in the assignment:
x = jnp.reshape(x, (-1, n_heads, seqlen, d_head))
x = jnp.reshape(x, (x.shape[0] / n_heads, n_heads, seqlen, d_head))
The first one produces no error, while the second one raises an error. Is there a difference between them in terms of implementation? Thank you!
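For reference, here is a minimal sketch of the Python-level difference between the two lines. It assumes x arrives with shape (n_batch * n_heads, seqlen, d_head), and the sizes below are made up for illustration. In Python 3, x.shape[0] / n_heads is true division and returns a float (e.g. 2.0). Recent JAX versions reject non-integer dimension sizes in jnp.reshape with a TypeError. Floor division (//) or -1 keeps every dimension an int.

```python
import jax.numpy as jnp

# Hypothetical sizes, chosen only for illustration.
n_batch, n_heads, seqlen, d_head = 2, 3, 4, 5

# Assumed input layout: heads folded into the leading axis,
# i.e. shape (n_batch * n_heads, seqlen, d_head).
x = jnp.arange(n_batch * n_heads * seqlen * d_head, dtype=jnp.float32)
x = x.reshape(n_batch * n_heads, seqlen, d_head)

# Variant 1: let reshape infer the leading axis from the total size.
a = jnp.reshape(x, (-1, n_heads, seqlen, d_head))

# In Python 3, `/` is true division and always returns a float.
print(x.shape[0] / n_heads)   # 2.0
print(x.shape[0] // n_heads)  # 2

# Variant 2 as written: the float 2.0 is not a valid dimension size.
try:
    jnp.reshape(x, (x.shape[0] / n_heads, n_heads, seqlen, d_head))
except TypeError as err:
    print("TypeError:", err)

# Variant 2 with floor division `//`, which keeps the size an int.
b = jnp.reshape(x, (x.shape[0] // n_heads, n_heads, seqlen, d_head))

print(a.shape, b.shape)             # (2, 3, 4, 5) (2, 3, 4, 5)
print(bool(jnp.array_equal(a, b)))  # True
```

With -1, reshape infers the leading size from the total element count. Both working variants therefore produce the same array whenever x.shape[0] is divisible by n_heads.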
Hi @roger.lee
There is a bug in a test case:
I think you are referring to this error (an error in the unit test case, not in your code)?
Cheers