According to the doc string, the input x has shape (n_batch X n_heads, seqlen, d_head).
So to reshape x to (n_batch, n_heads, seqlen, d_head), I use
n_batch = x.shape[0] // n_heads
x = jnp.reshape(x, (n_batch, n_heads, seqlen, d_head))
However, the unit test failed with a stack trace. If I use
x = jnp.reshape(x, (-1, n_heads, seqlen, d_head))
then it passes.
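For reference, a minimal sketch of the two approaches on a well-formed input (dummy zeros tensor; the n_batch = 2 value is assumed purely for illustration):

import jax.numpy as jnp

n_heads, seqlen, d_head = 3, 2, 2
x = jnp.zeros((6, seqlen, d_head))  # well-formed: leading dim = n_batch * n_heads = 2 * 3

n_batch = x.shape[0] // n_heads     # 6 // 3 = 2
print(jnp.reshape(x, (n_batch, n_heads, seqlen, d_head)).shape)  # (2, 3, 2, 2)
print(jnp.reshape(x, (-1, n_heads, seqlen, d_head)).shape)       # (2, 3, 2, 2), identical

On a well-formed input the two reshapes agree, so the divergence below comes from the input itself.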
Looking into the failed test case, I found:
x.shape = (6, 2, 3)
However, n_heads = 3, seqlen = 2, d_head = 2.
So, per the doc string, n_batch should be 6 // n_heads = 2.
Using x = jnp.reshape(x, (-1, n_heads, seqlen, d_head)), we have
x.shape = (3, 3, 2, 2) after reshaping. It passes the test, but that means n_batch is 3, not the expected 2.
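The element counts explain why the -1 gets inferred as 3 rather than 2 (a quick check using the shapes reported above):

import jax.numpy as jnp

x = jnp.zeros((6, 2, 3))           # the test tensor: 6 * 2 * 3 = 36 elements
y = jnp.reshape(x, (-1, 3, 2, 2))  # -1 is inferred as 36 // (3 * 2 * 2) = 3
print(y.shape)                     # (3, 3, 2, 2)
# jnp.reshape(x, (2, 3, 2, 2)) would raise an error: 2 * 3 * 2 * 2 = 24 != 36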
If you have a tensor x of shape (6, 2, 3) and you reshape it with
jnp.reshape(x, (-1, 3, 2, 3))
you get an output of shape (2, 3, 2, 3), the same as with
jnp.reshape(x, (2, 3, 2, 3))
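That checks out numerically (a quick sketch with a dummy tensor):

import jax.numpy as jnp

x = jnp.zeros((6, 2, 3))                    # 36 elements
print(jnp.reshape(x, (-1, 3, 2, 3)).shape)  # (2, 3, 2, 3): -1 inferred as 36 // 18 = 2
print(jnp.reshape(x, (2, 3, 2, 3)).shape)   # (2, 3, 2, 3), the same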
Yes, but the issue is that the result shape must be (3, 3, 2, 2) to pass the test; (2, 3, 2, 3) won't pass. The reason is that the x in the test case is invalid: with n_heads = 3, seqlen = 2, d_head = 2, it should have shape (n_batch X 3, 2, 2), but its last dimension is 3, so it does not follow the (n_batch X n_heads, seqlen, d_head) shape from the doc string.
Nice catch! You are absolutely correct. The second part ("test dummy tensors 2") of the unit test function test_compute_attention_output_closure is implemented wrong. I will submit an issue to get it fixed.
@arvyzukai - This may be only a semi-relevant question, but is there an inconsistency between the values for d_head and the dimensions of x?
x has (or should have) shape (n_batch X n_heads, seqlen, d_head).
So, both
x = jnp.reshape(x, (-1, n_heads, seqlen, d_head)), and
x = jnp.reshape(x, (-1, n_heads, seqlen, x.shape[-1]))
should be equally valid. However, the first option passes all 4 tests, while the second passes only 3 of the 4. Is this inconsistency the issue, or have I misunderstood something? (The sketch at the end of this post illustrates where they diverge.)
I know and understand this has been an ongoing and difficult issue to solve. I was just thinking that, since d_head and x.shape[-1] should always be the same, their inconsistency might point toward the problem … or not. You guys know this code and topic far better than I do. Anyway, thanks for responding.
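To make the suspected inconsistency concrete, here is a small sketch using the shapes reported earlier in the thread (dummy values):

import jax.numpy as jnp

n_heads, seqlen, d_head = 3, 2, 2
x = jnp.zeros((6, 2, 3))  # x.shape[-1] = 3, which differs from d_head = 2

print(jnp.reshape(x, (-1, n_heads, seqlen, d_head)).shape)       # (3, 3, 2, 2): what the test expects
print(jnp.reshape(x, (-1, n_heads, seqlen, x.shape[-1])).shape)  # (2, 3, 2, 3): fails the test

The two options agree whenever x.shape[-1] == d_head, so the discrepancy only surfaces in the one test whose tensor violates the doc-string shape.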