# Week 2 Assignment: How are the multiple heads in the multi-head attention created?

I have finished the assignment, but I am confused about how the multiple heads are created. I noticed that one of the inputs to `compute_attention_heads_closure` is the number of heads. However, the input `x` to `compute_attention_heads` is already a tensor where one of the axes has a length equal to the number of heads.

Where and how are the heads created? Where does the tensor `x` get its shape from?

> However, the input `x` for `compute_attention_heads` is already a tensor where one of the axes has a length equal to the number of heads.

No, the input `x` is a tensor with shape (n_batch, seqlen, n_heads × d_head); note that the last dimension is the product of n_heads and d_head.

`compute_attention_heads_closure` needs to know how to split that last dimension into the different heads.

For example, if you have an `x` with shape (3, 2, 6), how would you split the 6? Into 2×3 or 3×2? Once you know `n_heads` (and `d_head`) you can reshape `x`:

(3, 2, 6) → (3, 2, 2, 3)  # (batch, seqlen, n_heads, d_head) with n_heads=2, d_head=3

Later you rearrange the dimensions, swapping the seqlen and n_heads axes (both happen to be 2 in this example, so the shape looks unchanged):

(3, 2, 2, 3) → (3, 2, 2, 3)  # (batch, n_heads, seqlen, d_head)

and finally merge the batch and heads axes:

(3, 2, 2, 3) → (6, 2, 3)  # (batch × n_heads, seqlen, d_head)

This reshaping and rearranging is done to the Queries, Keys and Values, as the sketch below shows.
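For concreteness, here is a minimal sketch of those steps using `jax.numpy` (the assignment uses `trax.fastmath.numpy`, which has the same interface). The function name `split_into_heads` and the (3, 2, 6) example are just for illustration, not the assignment's exact code:

```python
import jax.numpy as jnp

def split_into_heads(x, n_heads):
    """Reshape (batch, seqlen, n_heads * d_head) -> (batch * n_heads, seqlen, d_head)."""
    batch, seqlen, d_feature = x.shape
    d_head = d_feature // n_heads
    x = jnp.reshape(x, (batch, seqlen, n_heads, d_head))       # (3, 2, 6) -> (3, 2, 2, 3)
    x = jnp.transpose(x, (0, 2, 1, 3))                         # swap the seqlen and n_heads axes
    return jnp.reshape(x, (batch * n_heads, seqlen, d_head))   # (3, 2, 2, 3) -> (6, 2, 3)

x = jnp.ones((3, 2, 6))
print(split_into_heads(x, n_heads=2).shape)  # (6, 2, 3)
```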

For a more realistic example, let's say there is one sentence (batch_size=1) and this sentence, after padding, has length 8. Then the dimensions are:

Input: (1, 8)

Here is the first part of `TransformerLM` with shape comments (you are asking about the `AttnHeads` part):

```
Serial[
  Serial[
    ShiftRight(1)  # (1, 8) still
  ]
  Embedding_33300_512  # (1, 8, 512)
  Dropout  # (1, 8, 512)
  PositionalEncoding  # (1, 8, 512)
  Serial[
    Branch_out2[  # makes two copies
      None  # does nothing to the first copy
      Serial[
        LayerNorm  # (1, 8, 512)
        Serial[
          Branch_out3[  # makes three copies
            # Note: Dense_512 with W (512, 512) and b (512,) causes no shape change -> (1, 8, 512).
            # These three Dense_512 layers create the Queries, Keys and Values matrices of the same
            # shape (1, 8, 512) but with different values.
            [Dense_512, AttnHeads]  # (1, 8, 512) -> (8, 8, 64)  # queries; reshaped to (1, 8, 8, 64), rearranged, then batch and heads merged (n_heads=8, d_head=64)
            [Dense_512, AttnHeads]  # (1, 8, 512) -> (8, 8, 64)  # keys
            [Dense_512, AttnHeads]  # (1, 8, 512) -> (8, 8, 64)  # values
          ]
          # Dot-product self-attention (applies DotProductAttention(query, key, value, mask)).
          DotProductAttn_in3  # (8, 8, 64)  # the attention weights inside have shape (L_q, L_k)
          AttnOutput  # (8, 8, 64) -> (1, 8, 512)  # reshape and rearrange back (compute_attention_output_closure)
          Dense_512  # (1, 8, 512)  # apply a linear transformation to the recombined attention output
        ]
      ]
      Dropout  # (1, 8, 512)
    ]
  ]
```
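For the `AttnOutput` step, here is a minimal sketch of the reverse operation, again with `jax.numpy`; the name `merge_heads` is illustrative, not the assignment's exact code:

```python
import jax.numpy as jnp

def merge_heads(x, n_heads):
    """Reshape (batch * n_heads, seqlen, d_head) -> (batch, seqlen, n_heads * d_head)."""
    seqlen, d_head = x.shape[1], x.shape[2]
    x = jnp.reshape(x, (-1, n_heads, seqlen, d_head))        # (8, 8, 64) -> (1, 8, 8, 64)
    x = jnp.transpose(x, (0, 2, 1, 3))                       # put seqlen back before the heads axis
    return jnp.reshape(x, (-1, seqlen, n_heads * d_head))    # (1, 8, 8, 64) -> (1, 8, 512)

attn = jnp.ones((8, 8, 64))  # output of DotProductAttn for batch_size=1, n_heads=8
print(merge_heads(attn, n_heads=8).shape)  # (1, 8, 512)
```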