C4_W2_Assignment UNQ_C4 unit test has a bug: the test data has the wrong shape

According to the docstring, the input x has shape (n_batch X n_heads, seqlen, d_head).

So when reshaping x to shape (n_batch, n_heads, seqlen, d_head), I use:
n_batch = x.shape[0] // n_heads
x = jnp.reshape(x, (n_batch, n_heads, seqlen, d_head))
However, the unit test fails with a stack trace. If I instead use
x = jnp.reshape(x, (-1, n_heads, seqlen, d_head))
then it passes.

Looking into the failed test case, I found:
x.shape = (6, 2, 3)
However, n_heads = 3, seqlen = 2, d_head = 2.
So n_batch should be 2.
Using x = jnp.reshape(x, (-1, n_heads, seqlen, d_head)), the shape after reshaping is
x.shape = (3, 3, 2, 2). This passes the test, but it implies n_batch = 3.
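The mismatch comes down to element counts. Below is a minimal sketch in NumPy (jax.numpy.reshape follows the same rules, including inferring a -1 dimension); n_heads = 3, seqlen = 2, d_head = 2 and the input shape (6, 2, 3) come from the failing test case, while the array contents are arbitrary:

```python
import numpy as np

# Values observed in the failing test case
n_heads, seqlen, d_head = 3, 2, 2
x = np.arange(6 * 2 * 3).reshape(6, 2, 3)  # buggy test input: 36 elements

# Batch size implied by the docstring: dim 0 is n_batch * n_heads
n_batch = x.shape[0] // n_heads  # 6 // 3 == 2

# The explicit target shape only holds 2 * 3 * 2 * 2 == 24 elements
print(x.size, n_batch * n_heads * seqlen * d_head)  # 36 24

try:
    np.reshape(x, (n_batch, n_heads, seqlen, d_head))
except ValueError as err:
    print("explicit reshape failed:", err)

# With -1, the batch dim is inferred as 36 // (3 * 2 * 2) == 3
y = np.reshape(x, (-1, n_heads, seqlen, d_head))
print(y.shape)  # (3, 3, 2, 2)
```

So the -1 version only "works" because it silently absorbs the extra elements into a larger batch dimension.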

Hi @larryleguo

Could you be more specific about which exercise you are talking about? Is it “UNQ_C4”?

Yes, it is:

# UNQ_C4
# GRADED FUNCTION: compute_attention_output_closure
def compute_attention_output_closure(n_heads, d_head):
    ...

    def compute_attention_output(x):
        ...

If you have a tensor x of shape (6, 2, 3) and you reshape it with
jnp.reshape(x, (-1, 3, 2, 3)), you get an output of shape (2, 3, 2, 3), the same as with
jnp.reshape(x, (2, 3, 2, 3)).

These are equivalent.
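A quick check of that equivalence, sketched with NumPy (jnp.reshape behaves the same for these calls); the array contents are arbitrary:

```python
import numpy as np

# A (6, 2, 3) tensor, as in the example above (contents are arbitrary)
x = np.arange(6 * 2 * 3).reshape(6, 2, 3)

# -1 asks reshape to infer that dimension from the total size (36)
a = np.reshape(x, (-1, 3, 2, 3))
b = np.reshape(x, (2, 3, 2, 3))

print(a.shape, b.shape)      # (2, 3, 2, 3) (2, 3, 2, 3)
print(np.array_equal(a, b))  # True
```

The two calls agree whenever the explicit dimensions multiply out to the array's total size; -1 only changes the result when they do not.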

Yes, but the issue is that the result shape must be (3, 3, 2, 2) to pass the test; (2, 3, 2, 3) won’t pass. The reason is that the x in the test case is invalid: it does not follow the (n_batch X n_heads, seqlen, d_head) shape described in the docstring.
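To make "invalid" concrete, here is a sketch of a hypothetical helper, check_input_shape (not part of the assignment or its tests), that compares an input against the docstring shape. The buggy shape (6, 2, 3) is from the test case; the "valid" (6, 2, 2) input is an assumed example of what a conforming input with n_batch = 2, n_heads = 3, seqlen = 2, d_head = 2 would look like:

```python
import numpy as np

def check_input_shape(x, n_heads, d_head):
    """Hypothetical helper: list the ways x violates the docstring
    shape (n_batch * n_heads, seqlen, d_head)."""
    problems = []
    if x.ndim != 3:
        problems.append(f"expected 3 dims, got {x.ndim}")
        return problems
    if x.shape[0] % n_heads != 0:
        problems.append(f"dim 0 ({x.shape[0]}) is not a multiple of n_heads ({n_heads})")
    if x.shape[-1] != d_head:
        problems.append(f"last dim ({x.shape[-1]}) != d_head ({d_head})")
    return problems

buggy = np.zeros((6, 2, 3))  # shape used in the failing test case
valid = np.zeros((6, 2, 2))  # assumed conforming input

print(check_input_shape(buggy, n_heads=3, d_head=2))  # ['last dim (3) != d_head (2)']
print(check_input_shape(valid, n_heads=3, d_head=2))  # []

# For a conforming input, the explicit and -1 reshapes agree
n_batch = valid.shape[0] // 3
print(np.reshape(valid, (n_batch, 3, 2, 2)).shape)  # (2, 3, 2, 2)
print(np.reshape(valid, (-1, 3, 2, 2)).shape)       # (2, 3, 2, 2)
```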

Hi,

I am having the same issue:

w2_tests.test_compute_attention_output_closure(compute_attention_output_closure)
input shape is : (6, 2, 3)
Reshape dimensions are: 3 2 2 3
Shape after reshape: (3, 2, 2, 3)
Transpose dimensions are: 3 2 2 3
# The middle 2s get swapped by swapping dimensions (0,1,2,3) to (0,2,1,3)
Shape after transpose: (3, 2, 2, 3)
input shape is : (6, 2, 3)
Reshape dimensions are: 2 3 2 2

InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (6, 2, 3) and (2, 3, 2, 2)
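The transpose step in the log above reorders axes (0, 1, 2, 3) to (0, 2, 1, 3), which swaps the head and sequence axes. A small NumPy sketch; the sizes are hypothetical and chosen to be all different so the swap shows up in the shape (with both middle sizes equal to 2, as in the log, the shape looks unchanged):

```python
import numpy as np

# Hypothetical sizes, all distinct so the axis swap is visible in the shape
n_batch, n_heads, seqlen, d_head = 2, 3, 4, 5
x = np.arange(n_batch * n_heads * seqlen * d_head).reshape(n_batch, n_heads, seqlen, d_head)

# (n_batch, n_heads, seqlen, d_head) -> (n_batch, seqlen, n_heads, d_head)
t = np.transpose(x, (0, 2, 1, 3))
print(t.shape)  # (2, 4, 3, 5)

# Element [b, h, s, d] of x ends up at [b, s, h, d] of t
print(x[1, 2, 3, 4] == t[1, 3, 2, 4])  # True

# When the two middle sizes are equal (both 2 in the log), the shape is unchanged
same = np.zeros((3, 2, 2, 3))
print(np.transpose(same, (0, 2, 1, 3)).shape)  # (3, 2, 2, 3)
```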

Nice catch! You are absolutely correct. The second part of the unit test function test_compute_attention_output_closure (“test dummy tensors 2”) is implemented incorrectly. I will submit an issue to get it fixed.

Cheers


Is there any update on this? I am having the same issue (still getting this error). Please let us know.


Hi @Davit_Khachatryan

The test is being worked on but is not fixed yet. For now, just use -1 for the batch dimension:

x = jnp.reshape(x, (-1, n_heads, seqlen, d_head))

It will be fixed soon.

Cheers

@arvyzukai ,

Thank you!

@arvyzukai - This may be only a semi-relevant question, but is there an inconsistency between the values for d_head and the dimensions of x?

x has (or should have) shape (n_batch X n_heads, seqlen, d_head).

So, both
x = jnp.reshape(x,(-1, n_heads, seqlen, d_head)), and
x = jnp.reshape(x,(-1, n_heads, seqlen, x.shape[-1]))

should be equally valid. However, the first option passes all 4 tests, while the second passes only 3 of 4. Is this inconsistency the issue, or have I misunderstood something?
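The two options diverge only when x.shape[-1] != d_head, which is exactly the case for the buggy test input. A minimal NumPy sketch (jnp.reshape infers -1 the same way), using the test case's values n_heads = 3, seqlen = 2, d_head = 2 and input shape (6, 2, 3):

```python
import numpy as np

n_heads, seqlen, d_head = 3, 2, 2  # values from the failing test case
x = np.zeros((6, 2, 3))            # buggy input: last dim is 3, not d_head

# Option 1: last dim fixed to d_head, batch inferred as 36 // 12 == 3
opt1 = np.reshape(x, (-1, n_heads, seqlen, d_head))
# Option 2: last dim taken from x, batch inferred as 36 // 18 == 2
opt2 = np.reshape(x, (-1, n_heads, seqlen, x.shape[-1]))

print(opt1.shape)  # (3, 3, 2, 2)
print(opt2.shape)  # (2, 3, 2, 3)

# For a conforming input (last dim == d_head), both options coincide
good = np.zeros((6, seqlen, d_head))
print(np.reshape(good, (-1, n_heads, seqlen, d_head)).shape ==
      np.reshape(good, (-1, n_heads, seqlen, good.shape[-1])).shape)  # True
```

This is consistent with the thread's observation that only the (3, 3, 2, 2) result passes the buggy test.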

Hi @Steven1

The test case has one buggy array (its shape is wrong), which is why some valid solutions do not pass.

As far as I understand, it is not easy to fix, so the fix has been postponed to a later date.

Cheers

Hi @arvyzukai,

I understand this has been an ongoing and difficult issue to solve. I was just thinking that since d_head and x.shape[-1] should always be the same, their inconsistency might point to the problem … or not. You know this code and topic far better than I do. Anyway, thanks for responding.

Using x.shape[-1] instead of d_head doesn’t raise an exception, but the test still fails: 3 passed, 1 failed.

I am having the same problem. Has the unit test been fixed? I took a look at it, and it seems it hasn’t been.


Same error. It looks like it’s not fixed yet.

As of 09.05.2023, still not fixed :-/ I wasted half an hour debugging my code, but the issue is elsewhere.

Thank you for your note :+1:

New instructions have been added so that future learners don’t run into this problem until the test case is fixed.