Week 2 Assignment: How are the multiple heads in the multi-head attention created?

I have finished the assignment, but I was confused about how the multiple heads were created. I noticed that one of the inputs for compute_attention_heads_closure was the number of heads. However, the input “x” for compute_attention_heads is already a tensor where one of the axes has a length equal to the number of heads.

Where and how are the heads created? Where does tensor “x” get its shape from?

Hi @Harvey_Wang

However, the input “x” for compute_attention_heads is already a tensor where one of the axes has a length equal to the number of heads.

No, the input x is a tensor with shape (n_batch, seqlen, n_heads*d_head) - note that the last dimension is the product of n_heads and d_head.

compute_attention_heads_closure needs to know how to split the last dimension into the different heads.

For example, if you have an x with shape (3, 2, 6), how would you split the 6? 2x3 or 3x2? Once you know n_heads (and d_head) you can reshape x:

n_batch, seqlen, n_heads*d_head → n_batch, seqlen, n_heads, d_head
(3, 2, 6) → (3, 2, 2, 3)

later you rearrange (transpose) the dimensions:
n_batch, seqlen, n_heads, d_head → n_batch, n_heads, seqlen, d_head
(3, 2, 2, 3) → (3, 2, 2, 3)  (the printed shape happens to look unchanged only because seqlen = n_heads = 2 here; the axes have been swapped)

and reshape to (n_batch*n_heads, seqlen, d_head):
n_batch, n_heads, seqlen, d_head → n_batch*n_heads, seqlen, d_head
(3, 2, 2, 3) → (6, 2, 3)

This reshaping and rearranging is done for each of the Queries, Keys and Values.
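In code, the whole AttnHeads step is roughly this - a minimal sketch in plain numpy of the toy example above (the assignment itself uses Trax's backend numpy, but the reshape/transpose logic is the same):

import numpy as np

n_batch, seqlen, n_heads, d_head = 3, 2, 2, 3

# x comes out of a Dense layer with shape (n_batch, seqlen, n_heads * d_head)
x = np.arange(n_batch * seqlen * n_heads * d_head).reshape(n_batch, seqlen, n_heads * d_head)
print(x.shape)                                   # (3, 2, 6)

# 1) split the last dimension into (n_heads, d_head)
x = x.reshape(n_batch, seqlen, n_heads, d_head)  # (3, 2, 2, 3)

# 2) move the heads axis next to the batch axis
x = x.transpose(0, 2, 1, 3)                      # (n_batch, n_heads, seqlen, d_head)

# 3) merge batch and heads so each head looks like its own "batch" element
x = x.reshape(n_batch * n_heads, seqlen, d_head)
print(x.shape)                                   # (6, 2, 3)

So the heads are not created anywhere as separate tensors - the single Dense output is simply reinterpreted as n_heads chunks of size d_head and moved into the batch dimension, so that the attention computation runs over every head independently.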

For a more realistic example, let's pretend there is one sentence (batch_size=1) and that, after padding, it has length 8. Then here are the dimensions:

Input: (1, 8)

Then the first part of TransformerLM (with comments; you are asking about the AttnHeads part):

Serial[
  Serial[
    ShiftRight(1)  # (1, 8) still
  ]
  Embedding_33300_512  # (1, 8, 512)
  Dropout  # (1, 8, 512)
  PositionalEncoding  # (1, 8, 512)
  Serial[
    Branch_out2[  # makes two copies
      None  # does nothing to the first copy
      Serial[
        LayerNorm  # (1, 8, 512)
        Serial[
          Branch_out3[  # makes three copies
            # Note: Dense_512 with W(512, 512), b(512,) results in no shape change -> (1, 8, 512)
            # These three Dense_512 layers create the Queries, Keys and Values matrices of the same shape (1, 8, 512)
            # but with different values.
            # AttnHeads (compute_attention_heads_closure) reshapes and rearranges these values.
            [Dense_512, AttnHeads]  # (1, 8, 512) -> (8, 8, 64)  # queries  # n_heads=8, d_head=64
            [Dense_512, AttnHeads]  # (1, 8, 512) -> (8, 8, 64)  # keys
            [Dense_512, AttnHeads]  # (1, 8, 512) -> (8, 8, 64)  # values
          ]
          # Dot-product self-attention (applies DotProductAttention(query, key, value, mask))
          DotProductAttn_in3  # (8, 8, 64)  # output has the same shape as the queries; the attention weights inside are (L_q x L_k)
          AttnOutput  # (8, 8, 64) -> (1, 8, 512)  # reshape and rearrange back (compute_attention_output_closure)
          Dense_512  # (1, 8, 512)  # linear transformation applied to the re-assembled attention output
        ]
        Dropout  # (1, 8, 512)
      ]
    ]
    Add_in2  # adds the first copy to the output (residual connection)
  ]
...
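And AttnOutput (compute_attention_output_closure) just undoes the head split so that the final Dense_512 can mix the heads again. Again a plain-numpy sketch (not the graded code) with the shapes from this example:

import numpy as np

n_batch, seqlen, n_heads, d_head = 1, 8, 8, 64

# attention result: one "batch" row per (sentence, head) pair
attn = np.zeros((n_batch * n_heads, seqlen, d_head))   # (8, 8, 64)

# 1) un-merge batch and heads
out = attn.reshape(n_batch, n_heads, seqlen, d_head)   # (1, 8, 8, 64)

# 2) move the heads axis back next to d_head
out = out.transpose(0, 2, 1, 3)                        # (n_batch, seqlen, n_heads, d_head)

# 3) concatenate the heads back into one 512-dim vector per position
out = out.reshape(n_batch, seqlen, n_heads * d_head)
print(out.shape)                                       # (1, 8, 512)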

Cheers
