What does key_dim do in tf.keras.layers.MultiHeadAttention?

According to the Keras documentation (tf.keras.layers.MultiHeadAttention | TensorFlow Core v2.9.1), the tf.keras.layers.MultiHeadAttention constructor requires two positional arguments: num_heads and key_dim, where key_dim is the 'size of each attention head for query and key'.

To me, it is unclear what key_dim means. When we call the layer, we pass in the call arguments query and value, with shapes (B, T, dim) and (B, S, dim), respectively.

I think what matters is the dim we pass in when calling the layer, not the key_dim we pass in when instantiating it. Do different values of key_dim force different internal processing?
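For concreteness, here is a minimal shape experiment comparing two layers that differ only in key_dim (the tensor sizes are arbitrary, chosen just for illustration, and inspecting layer.weights is only one way to peek at the internals):

```python
import tensorflow as tf

B, T, S, dim = 2, 4, 6, 16          # batch, target length, source length, model dim
query = tf.random.normal((B, T, dim))
value = tf.random.normal((B, S, dim))

# Two layers that differ only in key_dim.
mha_8 = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=8)
mha_32 = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=32)

out_8 = mha_8(query, value)
out_32 = mha_32(query, value)

# The output shape follows the query in both cases: (B, T, dim).
print(out_8.shape, out_32.shape)    # (2, 4, 16) (2, 4, 16)

# But the internal projection weights differ, e.g. the query/key/value
# projection kernels have shape (dim, num_heads, key_dim).
for w in mha_8.weights:
    print(w.name, w.shape)
```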

Welcome to the community.

Here is an overview of the Transformer encoder, focusing on MultiHeadAttention.

The most important dimension is the so-called embedding dimension, d_{model}, which is used throughout the Transformer processing.

Initially, the input word sequence is converted into word vectors, and then positional encoding is added.
This result then goes into the MHA. (Note that in the case of self-attention, we use the same X for Q, K, and V. In the decoder, Q comes from the decoder's self-attention, and K/V come from the encoder.)
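As a minimal sketch of that input path (a learned positional embedding is used here purely for brevity; the original Transformer uses sinusoidal encodings, and all sizes are arbitrary):

```python
import tensorflow as tf

vocab_size, max_len, d_model = 1000, 5, 16
tokens = tf.constant([[3, 14, 15, 92, 6]])                  # (B=1, T=5) toy token ids

word_emb = tf.keras.layers.Embedding(vocab_size, d_model)   # word vectors
pos_emb = tf.keras.layers.Embedding(max_len, d_model)       # positional encoding (learned here)

positions = tf.range(max_len)                               # [0, 1, 2, 3, 4]
x = word_emb(tokens) + pos_emb(positions)                   # (1, 5, d_model) -> goes into MHA
print(x.shape)                                              # (1, 5, 16)
```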

Inside the MHA there is a linear projection with trainable weights and biases. The projected Q, K, and V are then split into chunks of size d_k = \frac{d_{model}}{\text{number of heads}} and dispatched to the individual heads.
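A rough sketch of that projection-and-split step, written with plain tensor ops rather than the actual Keras internals (this per-head size d_k is what the Keras layer calls key_dim; all sizes here are arbitrary):

```python
import tensorflow as tf

B, T = 2, 5
d_model, num_heads = 16, 4
d_k = d_model // num_heads                      # per-head size (key_dim in Keras terms)

x = tf.random.normal((B, T, d_model))           # embeddings + positional encoding
w_q = tf.random.normal((d_model, d_model))      # trainable query projection (bias omitted)

q = tf.einsum('btd,de->bte', x, w_q)            # (B, T, d_model)

# Split the projected queries into num_heads chunks of size d_k and move the
# head axis forward so each head can attend independently.
q_heads = tf.reshape(q, (B, T, num_heads, d_k))     # (B, T, H, d_k)
q_heads = tf.transpose(q_heads, (0, 2, 1, 3))       # (B, H, T, d_k)
print(q_heads.shape)                                # (2, 4, 5, 4)
```

The same projection and split are applied to K and V.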

After scaled dot-product attention in each head, the outputs from all heads are concatenated into a single tensor and fed into a fully connected (output projection) layer.
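And a minimal sketch of that last step (again illustrative, not the Keras implementation; the sizes are arbitrary):

```python
import tensorflow as tf

def scaled_dot_product_attention(q, k, v):
    """q: (B, H, T, d_k), k/v: (B, H, S, d_k) -> (B, H, T, d_k)."""
    d_k = tf.cast(tf.shape(k)[-1], q.dtype)
    scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(d_k)   # (B, H, T, S)
    weights = tf.nn.softmax(scores, axis=-1)                    # attention weights
    return tf.matmul(weights, v)

B, H, T, S, d_k, d_model = 2, 4, 5, 7, 4, 16
q = tf.random.normal((B, H, T, d_k))
k = tf.random.normal((B, H, S, d_k))
v = tf.random.normal((B, H, S, d_k))

heads = scaled_dot_product_attention(q, k, v)                   # (B, H, T, d_k)

# Concatenate the heads back into d_model and apply the output projection.
concat = tf.reshape(tf.transpose(heads, (0, 2, 1, 3)), (B, T, H * d_k))  # (B, T, d_model)
out = tf.keras.layers.Dense(d_model)(concat)                    # (B, T, d_model)
print(out.shape)                                                # (2, 5, 16)
```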

Hope this helps.

Hi Nobu,
Even though your answer did not exactly relate to my question, it was a great help in getting me to better understand the Transformer architecture. I was able to find the answer to my question after some thought by leveraging your masterpiece visualization. I really appreciate your support. My sincere thanks to you.