Hi @Harvey_Wang

> However, the input “x” for compute_attention_heads is already a tensor where one of the axes has a length equal to the number of heads.

No, the input `x` is a tensor with shape **(n_batch, seqlen, n_heads*d_head)**. Note that the last dimension is the product of n_heads and d_head.

`compute_attention_heads_closure` needs to know how to split the last dimension across the different heads.

For example, if you have an `x` with shape (3, 2, 6), how would you split the 6? As 2x3 or 3x2? When you know `n_heads` (and `d_head`) you can reshape `x`:

n_batch, seqlen, n_heads*d_head → n_batch, seqlen, n_heads, d_head

(3, 2, 6) → (3, 2, 2, 3)

then you rearrange (transpose) the dimensions:

n_batch, seqlen, n_heads, d_head → n_batch, n_heads, seqlen, d_head

(3, 2, 2, 3) → (3, 2, 2, 3) (the tuple happens to look the same here only because seqlen and n_heads are both 2)

and reshape to (n_batch*n_heads, seqlen, d_head):

n_batch, n_heads, seqlen, d_head → n_batch*n_heads, seqlen, d_head

(3, 2, 2, 3) → (6, 2, 3)
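
In NumPy terms the three steps on the same toy shapes look like this (the variable names are just mine, for illustration):

```
import numpy as np

n_batch, seqlen, n_heads, d_head = 3, 2, 2, 3

x = np.arange(n_batch * seqlen * n_heads * d_head).reshape(n_batch, seqlen, n_heads * d_head)
print(x.shape)  # (3, 2, 6)

# 1) split the last dimension into (n_heads, d_head)
x = x.reshape(n_batch, seqlen, n_heads, d_head)
print(x.shape)  # (3, 2, 2, 3)

# 2) swap the seqlen and n_heads axes
x = x.transpose(0, 2, 1, 3)
print(x.shape)  # (3, 2, 2, 3) - unchanged tuple only because seqlen == n_heads == 2

# 3) merge n_batch and n_heads into a single leading axis
x = x.reshape(n_batch * n_heads, seqlen, d_head)
print(x.shape)  # (6, 2, 3)
```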

The same reshaping and rearranging is done to the Queries, Keys and Values.
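
To make the closure idea concrete, here is a rough sketch of what a `compute_attention_heads_closure`-style function can look like. I am reconstructing it from the steps above (the course uses trax's `jnp`, but plain NumPy behaves the same for these shape operations), so treat it as an illustration rather than the exact graded code:

```
import numpy as np

def compute_attention_heads_closure(n_heads, d_head):
    def compute_attention_heads(x):
        # x: (n_batch, seqlen, n_heads*d_head)
        n_batch, seqlen = x.shape[0], x.shape[1]
        x = np.reshape(x, (n_batch, seqlen, n_heads, d_head))
        x = np.transpose(x, (0, 2, 1, 3))  # (n_batch, n_heads, seqlen, d_head)
        return np.reshape(x, (n_batch * n_heads, seqlen, d_head))
    return compute_attention_heads

# The same function is applied to each of the Queries, Keys and Values:
attn_heads = compute_attention_heads_closure(n_heads=2, d_head=3)
q = attn_heads(np.ones((3, 2, 6)))  # (6, 2, 3)
k = attn_heads(np.ones((3, 2, 6)))  # (6, 2, 3)
v = attn_heads(np.ones((3, 2, 6)))  # (6, 2, 3)
```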

For a more realistic example, let's say there is one sentence (batch_size=1) and that, after padding, it has length 8. Then the dimensions are:

Input: (1, 8)

Then the first part of `TransformerLM` (with comments; you are asking about the `AttnHeads` part):

```
Serial[
  Serial[
    ShiftRight(1)         # (1, 8) still
  ]
  Embedding_33300_512     # (1, 8, 512)
  Dropout                 # (1, 8, 512)
  PositionalEncoding      # (1, 8, 512)
  Serial[
    Branch_out2[          # makes two copies
      None                # does nothing to the first copy
      Serial[
        LayerNorm         # (1, 8, 512)
        Serial[
          Branch_out3[    # makes three copies
            # Note: Dense_512 with W(512, 512), b(512,) results in no shape change -> (1, 8, 512).
            # These three Dense_512 layers create the Queries, Keys and Values matrices of the
            # same shape (1, 8, 512) but with different values.
            # AttnHeads (compute_attention_heads_closure) reshapes and rearranges these values.
            [Dense_512, AttnHeads]  # (1, 8, 512) -> (8, 8, 64)  # queries, n_heads=8, d_head=64
            [Dense_512, AttnHeads]  # (1, 8, 512) -> (8, 8, 64)  # keys
            [Dense_512, AttnHeads]  # (1, 8, 512) -> (8, 8, 64)  # values
          ]
          # Dot-product self-attention: applies DotProductAttention(query, key, value, mask);
          # the attention weights computed inside have shape (L_q, L_k).
          DotProductAttn_in3  # (8, 8, 64)
          AttnOutput          # (8, 8, 64) -> (1, 8, 512)  # reshape and rearrange back (compute_attention_output_closure)
          Dense_512           # (1, 8, 512)  # linear transformation on the merged attention output
        ]
        Dropout         # (1, 8, 512)
      ]
    ]
    Add_in2             # adds the first copy to the output (residual connection)
  ]
  ...
```
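
And for completeness, a matching sketch of the inverse step that `AttnOutput` (`compute_attention_output_closure`) performs, again reconstructed from the shapes above rather than copied from the assignment (the round trip below reuses the `compute_attention_heads_closure` sketch from earlier):

```
import numpy as np

def compute_attention_output_closure(n_heads, d_head):
    def compute_attention_output(x):
        # x: (n_batch*n_heads, seqlen, d_head), e.g. (8, 8, 64)
        seqlen = x.shape[1]
        x = np.reshape(x, (-1, n_heads, seqlen, d_head))  # (n_batch, n_heads, seqlen, d_head)
        x = np.transpose(x, (0, 2, 1, 3))                 # (n_batch, seqlen, n_heads, d_head)
        return np.reshape(x, (-1, seqlen, n_heads * d_head))
    return compute_attention_output

# Round trip with the (1, 8, 512) example: split into heads, then merge back.
x = np.ones((1, 8, 512))
heads = compute_attention_heads_closure(n_heads=8, d_head=64)(x)      # (8, 8, 64)
out = compute_attention_output_closure(n_heads=8, d_head=64)(heads)   # (1, 8, 512)
```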

Cheers