Explanation for attention_axes parameter in Keras' MultiHeadAttention layer?

In Keras’ MultiHeadAttention, there is a parameter called attention_axes, described as:

axes over which the attention is applied. None means attention over all axes, but batch, heads, and features.

I’m confused as to what this means. For instance, if I call the MHA layer with q=k=v=x having shape (batch, seq_len, d_model), how would I specify the attention_axes argument if I wanted to pass a value (I believe this is left as None in the programming assignment)? Which axis is the feature axis?

Hi @LuBinLiu ,

Instead of a sentence, let’s take an image as an example. Suppose we’re dealing with 2D images, so there are (height * width) pixels per image.

import tensorflow as tf
from tensorflow.keras.layers import MultiHeadAttention

height, width, features = 32, 64, 4
inputs = tf.keras.Input(shape=[height, width, features])
# inputs.shape: (None, 32, 64, 4)

Case 1: Suppose every pixel is related to every other pixel, so each pixel has to pay attention to all pixels in the image. The number of attention weights per pixel (per head) is then (height * width), and the total per image is num_heads * (height * width) * (height * width).

layer1 = MultiHeadAttention(num_heads=2, key_dim=128)
output1, weights1 = layer1(inputs, inputs, return_attention_scores=True)
print(output1.shape, weights1.shape)
# (None, 32, 64, 4) (None, 2, 32, 64, 32, 64)
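
As a quick sanity check on the counts in Case 1, here is a minimal arithmetic sketch. It only multiplies the sizes used above; the variable names are mine and it does not call Keras.

```python
# Sizes taken from the example above.
height, width, num_heads = 32, 64, 2

# With attention over both spatial axes, every pixel attends to every pixel.
positions = height * width

# One weight per (query pixel, key pixel) pair, per head, per image.
weights_per_pixel_per_head = positions
total_per_image = num_heads * positions * positions

print(weights_per_pixel_per_head)  # 2048
print(total_per_image)             # 8388608
```

The total equals the product of the non-batch dimensions of weights1.shape, i.e. 2 * 32 * 64 * 32 * 64.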

Case 2: Suppose each pixel is related only to the pixels in the same column, i.e. pixels with the same width index but different height indices. So we set attention_axes=(1). The number of attention weights per pixel (per head) is then height, and the total per image is num_heads * width * (height * height).

layer2 = MultiHeadAttention(num_heads=2, key_dim=128, attention_axes=(1))
output2, weights2 = layer2(inputs, inputs, return_attention_scores=True)
print(output2.shape, weights2.shape)
# (None, 32, 64, 4) (None, 64, 2, 32, 32)
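
To make the pattern in the two printed shapes explicit, here is a small pure-Python sketch that predicts the attention-scores shape for self-attention. The helper predicted_scores_shape is hypothetical (it is not a Keras function); it assumes the layout seen in the outputs above, namely the non-attended axes first, then heads, then the attended axes twice (once for the query, once for the key). It also assumes non-negative axis indices.

```python
from math import prod


def predicted_scores_shape(input_shape, num_heads, attention_axes=None):
    """Predict the attention-scores shape when query and value are the same tensor.

    input_shape includes the batch axis (None for an unknown batch size) and
    ends with the feature axis. Hypothetical helper, not part of Keras.
    """
    rank = len(input_shape)
    if attention_axes is None:
        # Default: every axis except batch (first) and features (last).
        attention_axes = tuple(range(1, rank - 1))
    elif isinstance(attention_axes, int):
        # In Python, (1) is just the integer 1; treat it like the tuple (1,).
        attention_axes = (attention_axes,)
    attended = [input_shape[a] for a in attention_axes]
    # Axes that are not attended over (batch included) behave like batch axes.
    kept = [input_shape[a] for a in range(rank - 1) if a not in attention_axes]
    # Assumed layout: kept axes, heads, query attention axes, key attention axes.
    return tuple(kept + [num_heads] + attended + attended)


shape = (None, 32, 64, 4)
case1 = predicted_scores_shape(shape, num_heads=2)
case2 = predicted_scores_shape(shape, num_heads=2, attention_axes=1)
print(case1)            # (None, 2, 32, 64, 32, 64)
print(case2)            # (None, 64, 2, 32, 32)
print(prod(case2[1:]))  # 131072 weights per image in Case 2
```

In Case 2 the width axis moves in front of heads because it is no longer attended over: each of the 64 columns gets its own independent 32 x 32 attention map per head.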

If the inputs are sentences, as in the Programming Assignment, should setting attention_axes=(1) (corresponding to the seq_len dimension of the input) produce the same result as attention_axes=None (the default)?

You’re right, it’s the same, because a sentence has only one dimension (seq_len) besides the batch axis (axis 0) and the feature axis (the last axis, d_model). Setting attention_axes=(1) (attention over axis 1) is therefore the same as attention_axes=None (attention over all axes except batch and features).
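
A minimal sketch of that rule, assuming the default simply means "every axis except the first (batch) and the last (features)". The function name default_attention_axes is hypothetical and only illustrates the indexing; it is not a Keras API.

```python
def default_attention_axes(input_rank):
    """Axes attended over when attention_axes=None (assumed rule, see above)."""
    # Skip axis 0 (batch) and the last axis (features).
    return tuple(range(1, input_rank - 1))


# Sentence input: (batch, seq_len, d_model) has rank 3.
print(default_attention_axes(3))  # (1,)

# Image input: (batch, height, width, features) has rank 4.
print(default_attention_axes(4))  # (1, 2)
```

For the rank-3 sentence input the default resolves to axis 1 alone, which is why passing attention_axes=(1) changes nothing there, while for images it narrows attention from both spatial axes to one.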
