Hi @jackliu333, this is my implementation:
import jax.numpy as jnp

def compute_attention_output_closure(n_heads, d_head):
    """ Function that simulates environment inside CausalAttention function.
    Args:
        d_head (int): dimensionality of heads
        n_heads (int): number of attention heads
    Returns:
        function: compute_attention_output function
    """
    def compute_attention_output(x):
        """ Compute the attention output.
        Args:
            x (jax.interpreters.xla.DeviceArray): tensor with shape (n_batch X n_heads, seqlen, d_head).
        Returns:
            jax.interpreters.xla.DeviceArray: reshaped tensor with shape (n_batch, seqlen, n_heads X d_head).
        """
        ### START CODE HERE (REPLACE INSTANCES OF 'None' WITH YOUR CODE) ###
        # Length of the sequence
        # Should be size of x's first dimension without counting the batch dim
        seqlen = x.shape[1]
        # Reshape x using jnp.reshape() to shape (n_batch, n_heads, seqlen, d_head)
        # batch_size = x.shape[0] // n_heads
        print('shape of input before reshape')
        print(x.shape)
        print("n_heads:", n_heads, "d_head", d_head)
        # assert batch_size * n_heads == x.shape[0]
        x = jnp.reshape(x, (-1, n_heads, seqlen, d_head))
        # Transpose x using jnp.transpose() to shape (n_batch, seqlen, n_heads, d_head)
        x = jnp.transpose(x, (0, 2, 1, 3))
        ### END CODE HERE ###
        # Note: this prints the shape after the transpose, before the final merge reshape
        print('shape of the output')
        print(x.shape)
        # Reshape to allow to concatenate the heads
        return jnp.reshape(x, (-1, seqlen, n_heads * d_head))
    return compute_attention_output
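For reference, here is a minimal sketch of how I would exercise the closure on a toy tensor (the shapes mirror case 1 below; the input values are just placeholders I made up for illustration):

import jax.numpy as jnp

# Toy input with shape (n_batch * n_heads, seqlen, d_head) = (6, 2, 3),
# i.e. n_batch = 3, n_heads = 2, d_head = 3.
x = jnp.ones((6, 2, 3))

attention_output_fn = compute_attention_output_closure(n_heads=2, d_head=3)
out = attention_output_fn(x)
print(out.shape)  # (3, 2, 6) -> (n_batch, seqlen, n_heads * d_head)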
The printed shapes and inputs from the two test cases are:
# case 1
shape of input before reshape
(6, 2, 3)
n_heads: 2 d_head 3
shape of the output
(3, 2, 2, 3)
# case 2
shape of input before reshape
(6, 2, 3)
n_heads: 3 d_head 2
shape of the output
(3, 2, 3, 2)
All tests passed
I think case 2 is problematic: the input has shape (6, 2, 3), so its last dimension is supposed to be d_head, but the d_head argument passed in is 2, not 3.
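A minimal sketch of why that mismatch worries me (the toy tensor here is made up just to show the effect; it is not the test's actual input):

import jax.numpy as jnp

# Toy tensor mimicking case 2's input shape (6, 2, 3), which would
# require d_head == 3 if the last axis really is the head dimension.
x = jnp.arange(6 * 2 * 3).reshape(6, 2, 3)

n_heads, d_head = 3, 2        # arguments supplied by the test in case 2
seqlen = x.shape[1]

# The reshape still succeeds because the total element count matches...
reshaped = jnp.reshape(x, (-1, n_heads, seqlen, d_head))
print(reshaped.shape)         # (3, 3, 2, 2)

# ...but the last axis of the input (size 3) no longer maps cleanly onto
# d_head, so elements are silently repartitioned across the head,
# seqlen, and feature axes instead of raising an error.
print(x[0])                   # [[0 1 2], [3 4 5]]
print(reshaped[0, 0])         # [[0 1], [2 3]] -> mixes feature and position values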