C4 compute_attention_heads_closure misbehavior

Hi, C4 seems to have some strange behavior. The batch size of the input should be computed as

x.shape[0] // n_heads

but when reshaping x to shape (batch_size, n_heads, seqlen, d_head), I got a shape mismatch error:

InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (6, 2, 3) and (2, 3, 2, 2)

This error is resolved if, instead of computing batch_size explicitly, we set it to -1. But what's the difference?

It seems the shape is implicitly inferred when it is set to -1. Could you try printing out the shapes of all intermediate variables and see if they match in each operation?
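For reference, here is a minimal sketch of the difference using plain jax.numpy (nothing course-specific): an explicit size has to multiply out exactly to the array's total element count, while -1 tells reshape to infer that one dimension from the element count and the remaining sizes.

import jax.numpy as jnp

x = jnp.arange(36).reshape(6, 2, 3)   # 36 elements, same shape as the input above

# Explicit first dimension: works only because 3 * 2 * 2 * 3 == 36
a = jnp.reshape(x, (3, 2, 2, 3))
print(a.shape)                        # (3, 2, 2, 3)

# -1: the first dimension is inferred as 36 // (2 * 2 * 3) == 3
b = jnp.reshape(x, (-1, 2, 2, 3))
print(b.shape)                        # (3, 2, 2, 3)

# If the explicit sizes do not multiply out to 36, reshape raises an error;
# with -1 it silently picks whichever first dimension makes the counts match.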


Hi @jackliu333, this is my implementation:

import jax.numpy as jnp   # the assignment notebook provides jnp; plain jax.numpy behaves the same here

def compute_attention_output_closure(n_heads, d_head):
    """ Function that simulates environment inside CausalAttention function.
    Args:
        d_head (int):  dimensionality of heads
        n_heads (int): number of attention heads
    Returns:
        function: compute_attention_output function
    """

    def compute_attention_output(x):
        """ Compute the attention output.
        Args:
            x (jax.interpreters.xla.DeviceArray): tensor with shape (n_batch X n_heads, seqlen, d_head).
        Returns:
            jax.interpreters.xla.DeviceArray: reshaped tensor with shape (n_batch, seqlen, n_heads X d_head).
        """
        ### START CODE HERE (REPLACE INSTANCES OF 'None' WITH YOUR CODE) ###

        # Length of the sequence
        # Should be size of x's first dimension without counting the batch dim
        seqlen = x.shape[1]
        # Reshape x using jnp.reshape() to shape (n_batch, n_heads, seqlen, d_head)
        # batch_size = x.shape[0] // n_heads

        print('shape of input before reshape')
        print(x.shape)

        print("n_heads:", n_heads, "d_head", d_head)
        # assert batch_size * n_heads == x.shape[0]
        x = jnp.reshape(x, (-1, n_heads, seqlen, d_head))
        # Transpose x using jnp.transpose() to shape (n_batch, seqlen, n_heads, d_head)
        x = jnp.transpose(x, (0, 2, 1, 3))

        ### END CODE HERE ###
        print('shape of the output')
        print(x.shape)
        # Reshape to allow concatenating the heads
        return jnp.reshape(x, (-1, seqlen, n_heads * d_head))
    return compute_attention_output

The printed shapes and inputs are:

# case 1
shape of input before reshape
(6, 2, 3)
n_heads: 2 d_head 3
shape of the output
(3, 2, 2, 3)

# case 2
shape of input before reshape
(6, 2, 3)
n_heads: 3 d_head 2
shape of the output
(3, 2, 3, 2)
 All tests passed

I think case 2 is problematic. The input has shape (6, 2, 3), and the last dimension is supposed to be d_head, but the given d_head argument is 2.
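The element counts explain why -1 makes the error go away (a quick check, using only the shapes printed above):

import jax.numpy as jnp

x = jnp.ones((6, 2, 3))              # 6 * 2 * 3 = 36 elements
n_heads, seqlen, d_head = 3, 2, 2    # case 2's arguments

# Explicit batch size: 6 // 3 = 2, but (2, 3, 2, 2) only holds 24 elements,
# which is the size mismatch reported at the top of the thread:
# jnp.reshape(x, (x.shape[0] // n_heads, n_heads, seqlen, d_head))

# With -1 the first dimension is inferred as 36 // (3 * 2 * 2) = 3, so the
# reshape gives (3, 3, 2, 2); after the (0, 2, 1, 3) transpose this is the
# (3, 2, 3, 2) printed above -- the test passes, but 3 is not n_batch.
y = jnp.reshape(x, (-1, n_heads, seqlen, d_head))
print(jnp.transpose(y, (0, 2, 1, 3)).shape)   # (3, 2, 3, 2)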


I don’t really understand this either.
x is supposed to have dimensions (n_batch X n_heads, seqlen, d_head).
So I assume that to get n_batch you have to use: n_batch = int(x.shape[0] / n_heads)
Then:

#Reshape x using jnp.reshape() to shape (n_batch, n_heads, seqlen, d_head)
x = jnp.reshape(x, (n_batch, n_heads, seqlen, d_head))

But this does not work for the unit test.
Instead I have to use:

x = jnp.reshape(x, (-1, n_heads, seqlen, d_head))


Note that in compute_attention_heads, the docstring indicates that x is a tensor with shape (n_batch, seqlen, n_heads X d_head).
In compute_attention_output, x is a tensor with shape (n_batch X n_heads, seqlen, d_head). To concatenate the heads correctly, the first dimension needs to be split into n_batch and n_heads first; since n_batch is not available locally (and should not be pulled in via a global variable), you have to use -1 for the first dimension. After that, the dimensions can be transposed and the heads concatenated.
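For completeness, here is a minimal sketch of the opposite direction (the compute_attention_heads step mentioned above), assuming only the shapes from the docstrings; it is not the graded solution, just an illustration of the split → transpose → merge pattern:

import jax.numpy as jnp

def compute_attention_heads(x, n_heads, d_head):
    # x: (n_batch, seqlen, n_heads * d_head)
    seqlen = x.shape[1]
    # Split the last dimension into heads: (n_batch, seqlen, n_heads, d_head)
    x = jnp.reshape(x, (-1, seqlen, n_heads, d_head))
    # Move heads next to the batch dimension: (n_batch, n_heads, seqlen, d_head)
    x = jnp.transpose(x, (0, 2, 1, 3))
    # Fold heads into the batch dimension: (n_batch * n_heads, seqlen, d_head)
    return jnp.reshape(x, (-1, seqlen, d_head))

x = jnp.ones((2, 4, 6))   # n_batch = 2, seqlen = 4, n_heads * d_head = 6
print(compute_attention_heads(x, n_heads=3, d_head=2).shape)   # (6, 4, 2)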
I hope this helps.


The second test is wrong.

It has these inputs:

{
    "x": jnp.array(
        [
            [[1, 0, 0], [0, 1, 0]],
            [[1, 0, 0], [0, 1, 0]],
            [[1, 0, 0], [0, 1, 0]],
            [[1, 0, 0], [0, 1, 0]],
            [[1, 0, 0], [0, 1, 0]],
            [[1, 0, 0], [0, 1, 0]],
        ]
    )
    , "n_heads": 3
    , "d_head": 2
}

But "d_head": 2 is inconsistent with x.shape[-1] == 3, given that x has shape (n_batch X n_heads, seqlen, d_head).

You can get it to pass by setting n_batch=-1 in the reshape, but you shouldn’t have to.
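One way to make the inconsistency visible would be a sanity check on the documented contract; this is a hypothetical guard for illustration, not part of the assignment or its tests:

import jax.numpy as jnp

def check_output_input(x, n_heads, d_head):
    # Hypothetical checks reflecting the documented shape (n_batch X n_heads, seqlen, d_head)
    assert x.shape[-1] == d_head, f"d_head={d_head} but x.shape[-1]={x.shape[-1]}"
    assert x.shape[0] % n_heads == 0, "x.shape[0] must be a multiple of n_heads"

x = jnp.ones((6, 2, 3))                        # the second test's input shape
check_output_input(x, n_heads=3, d_head=2)     # raises: d_head=2 but x.shape[-1]=3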