According to the docstring, the input x has shape (n_batch X n_heads, seqlen, d_head)

So when reshaping x to shape (n_batch, n_heads, seqlen, d_head), I use
n_batch = x.shape[0] // n_heads
x = jnp.reshape(x, (n_batch, n_heads, seqlen, d_head))
However, the unit test failed with a stack trace. If I use
x = jnp.reshape(x, (-1, n_heads, seqlen, d_head))
Then it passes.

Looking into the failed test case, I found:
x.shape = (6, 2, 3)
However, n_heads = 3, seqlen = 2, d_head = 2.
So n_batch should be 2.
Using x = jnp.reshape(x, (-1, n_heads, seqlen, d_head)), we have
x.shape = (3, 3, 2, 2) after reshaping. It passes the test, but that means n_batch is 3.
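To make the failure concrete, here is a minimal sketch of both reshapes on the failing test's shapes. It uses NumPy as a stand-in, on the assumption that `np.reshape` and `jnp.reshape` behave identically here:

```python
import numpy as np  # stand-in: np.reshape mirrors jnp.reshape for this case

x = np.zeros((6, 2, 3))              # the failing test's tensor
n_heads, seqlen, d_head = 3, 2, 2    # values reported by the test

# Explicit batch size: 6 // 3 = 2, but 2 * 3 * 2 * 2 = 24 != 36 elements,
# so this reshape raises a ValueError.
n_batch = x.shape[0] // n_heads
try:
    np.reshape(x, (n_batch, n_heads, seqlen, d_head))
except ValueError:
    print("explicit n_batch = 2 fails")

# With -1, the leading dim is inferred from the element count:
# 36 / (3 * 2 * 2) = 3, so "n_batch" silently becomes 3.
y = np.reshape(x, (-1, n_heads, seqlen, d_head))
print(y.shape)  # (3, 3, 2, 2)
```

So the -1 version "passes" only because reshape quietly absorbs the extra elements into the batch dimension.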

If you have a tensor x of shape (6, 2, 3) and reshape it with
jnp.reshape(x, (-1, 3, 2, 3)), you get an output of shape (2, 3, 2, 3), the same as with
jnp.reshape(x, (2, 3, 2, 3))
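A quick check of that claim, sketched with NumPy on the assumption that its reshape semantics match `jnp.reshape`:

```python
import numpy as np

x = np.arange(6 * 2 * 3).reshape(6, 2, 3)  # 36 elements

a = np.reshape(x, (-1, 3, 2, 3))  # leading dim inferred: 36 / (3*2*3) = 2
b = np.reshape(x, (2, 3, 2, 3))   # leading dim given explicitly

print(a.shape)               # (2, 3, 2, 3)
print(np.array_equal(a, b))  # True: -1 and 2 produce the same result
```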

Yes, but the issue is that the result shape must be (3, 3, 2, 2) to pass the test; (2, 3, 2, 3) won't pass. The reason is that the "x" in the test case is invalid: it does not follow the (n_batch X n_heads, seqlen, d_head) shape given in the docstring.

Nice catch! You are absolutely correct. The second part ("test dummy tensors 2") of the unit test function test_compute_attention_output_closure is implemented incorrectly. I will submit an issue to get it fixed.

@arvyzukai - This may be only a semi-relevant question, but is there an inconsistency between the values for d_head and the dimensions of x?

x has (or should have) shape (n_batch X n_heads, seqlen, d_head).

So, both
x = jnp.reshape(x,(-1, n_heads, seqlen, d_head)), and
x = jnp.reshape(x,(-1, n_heads, seqlen, x.shape[-1]))

should be equally valid. However, the first option passes all 4 tests, while the second passes only 3 of 4. Is this inconsistency the issue, or have I misunderstood something?
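The two options diverge only on the invalid test tensor, where d_head (2) disagrees with x.shape[-1] (3). A small sketch of the divergence, using NumPy on the assumption that it matches `jnp.reshape`:

```python
import numpy as np

x = np.zeros((6, 2, 3))            # the problematic test tensor
n_heads, seqlen, d_head = 3, 2, 2  # note: d_head = 2, yet x.shape[-1] = 3

a = np.reshape(x, (-1, n_heads, seqlen, d_head))       # -> (3, 3, 2, 2)
b = np.reshape(x, (-1, n_heads, seqlen, x.shape[-1]))  # -> (2, 3, 2, 3)

# On a valid tensor (last dim equal to d_head) the two would agree;
# here they differ, which is exactly the inconsistency described above.
print(a.shape, b.shape)
```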

I know and understand this has been an ongoing and difficult issue to solve. I was just thinking that since d_head and x.shape[-1] should always be the same, their inconsistency might point towards the problem … or not. You guys know this code and topic far better than I. Anyway, thx for responding.