Lab 1 question about n_heads in OurSelfAttention

When calling OurSelfAttention we used n_heads=3, but from how the function is defined it is not clear how n_heads is used. Also, the input x has shape 8×5, and 3 does not divide either 8 or 5. How is n_heads used here?

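# Note: this snippet relies on names defined elsewhere in the lab notebook
# (our_simple_attend, mask_self_attention, apply_broadcasted_dropout, tie_in)
# and on the lab's imports (functools, fastmath, np, tl).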
class OurSelfAttention(tl.SelfAttention):
    """Our self-attention. Just the Forward Function."""

    def forward_unbatched(
        self, x, mask=None, *, weights, state, rng, update_state, verbose=False
    ):
        print("ourSelfAttention:forward_unbatched")
        del update_state
        attend_rng, output_rng = fastmath.random.split(rng)
        if self._bias:
            if self._share_qk:
                w_q, w_v, w_o, b_q, b_v = weights
            else:
                w_q, w_k, w_v, w_o, b_q, b_k, b_v = weights
        else:
            if self._share_qk:
                w_q, w_v, w_o = weights
            else:
                w_q, w_k, w_v, w_o = weights

        print("x.shape,w_q.shape", x.shape, w_q.shape)
        q = np.matmul(x, w_q)
        k = None
        if not self._share_qk:
            k = np.matmul(x, w_k)
        v = np.matmul(x, w_v)

        if self._bias:
            q = q + b_q
            if not self._share_qk:
                k = k + b_k
            v = v + b_v

        mask_fn = functools.partial(
            mask_self_attention,
            causal=self._causal,
            exclude_self=self._share_qk,
            masked=self._masked,
        )
        q_info = kv_info = tie_in(x, np.arange(q.shape[-2], dtype=np.int32))

        assert (mask is not None) == self._masked
        if self._masked:
            # mask is a boolean array (True means "is valid token")
            ones_like_mask = tie_in(x, np.ones_like(mask, dtype=np.int32))
            kv_info = kv_info * np.where(mask, ones_like_mask, -ones_like_mask)

        # Notice, we are calling our version of attend
        o, _ = our_simple_attend(
            q,
            k,
            v,
            mask_fn=mask_fn,
            q_info=q_info,
            kv_info=kv_info,
            dropout=self._attention_dropout,
            rng=attend_rng,
            verbose=True,
        )

        # Notice, the w_o weight matrix is applied to the output of attend in forward_unbatched
        out = np.matmul(o, w_o)
        out = apply_broadcasted_dropout(out, self._output_dropout, output_rng)
        return out, state


Hi @PZ2004

Looking at the code, I see this lab introduces a customized attention class. In the lab:

n_heads = 3
d_qk = 3
d_v = 4
seq_len = 8
emb_len = 5
batch_size = 1

Notice that d_qk specifies the dimensionality of Q and K, and d_v the dimensionality of V.
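
For reference, these settings plug into the constructor roughly like this. This is only a sketch: the keyword names (n_heads, d_qk, d_v, causal) come from tl.SelfAttention, which OurSelfAttention subclasses; the variable name and causal=True are my assumptions, and the actual call in the lab may pass additional arguments.

causal_attention = OurSelfAttention(
    n_heads=3,    # 3 independent heads, each with its own w_q, w_k, w_v and w_o
    d_qk=3,       # per-head dimensionality of Q and K
    d_v=4,        # per-head dimensionality of V
    causal=True,  # assumption: causal masking, matching self._causal in forward_unbatched
)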

If “I love learning” is a concrete example sentence, then:

  • the tokenizer could tokenize it to [23, 565, 2332], then the data generator would pad this sentence to [23, 565, 2332, 0, 0, 0, 0, 0] (notice seq_len = 8 and batch_size in this case is 1, since we have 1 sentence). So the shape of input would be (1, 8);
  • the embedding layer would embed this input to shape (1, 8, 5) (the embedding weight matrix has shape (vocab_size, 5), so (1, 8) → (1, 8, 5));
  • now OurSelfAttention has 3 heads, and each head’s weights w_q, w_k and w_v have shapes (5, 3), (5, 3) and (5, 4) respectively. Also, something that is not explained explicitly: w_o has shape (4, 5) so that the output shape matches the input shape. This is why n_heads does not have to divide 8 or 5: every head gets its own full set of projection matrices from the 5-dimensional embedding. So, step by step (with code lines), when each head receives an input of shape (1, 8, 5):
    – Q becomes (8, 3); since (8, 5) dot (5, 3) → (8, 3); q = np.matmul(x, w_q);
    – K becomes (8, 3); since (8, 5) dot (5, 3) → (8, 3); k = np.matmul(x, w_k);
    – V becomes (8, 4); since (8, 5) dot (5, 4) → (8, 4); v = np.matmul(x, w_v);
  • now the attention matrix becomes (8, 8) since (8, 3) dot (3, 8) → (8, 8); dots = np.matmul(q, k.T);
  • applied attention on V results in (8, 4); since (8, 8) dot (8, 4) → (8, 4); out = np.matmul(dots, v);
  • the outputs are then transformed to be the same shape as inputs (8, 5) (by applying w_o weight matrix); (8, 4) dot (4, 5) → (8, 5); np.matmul(o, w_o);
  • repeat the same steps for the remaining two heads;
  • in this lab, the results of all 3 heads are summed up (code). I don’t remember exactly why it is done this way here, but probably to match LSHAttention. A small NumPy sketch of this whole shape walk-through follows below.
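
To make the shapes concrete, here is a minimal, self-contained NumPy sketch of the walk-through above for one unbatched sentence. It only illustrates the shapes: the causal mask and dropout are omitted, the scaling and softmax are a generic stand-in for whatever our_simple_attend does exactly, and the weights are random rather than trained.

import numpy as np

n_heads, d_qk, d_v, seq_len, emb_len = 3, 3, 4, 8, 5
rng = np.random.default_rng(0)

x = rng.normal(size=(seq_len, emb_len))            # (8, 5): one embedded, unbatched sentence
out = np.zeros((seq_len, emb_len))                 # (8, 5): accumulator for the head outputs

for _ in range(n_heads):                           # each head has its own full set of weights
    w_q = rng.normal(size=(emb_len, d_qk))         # (5, 3)
    w_k = rng.normal(size=(emb_len, d_qk))         # (5, 3)
    w_v = rng.normal(size=(emb_len, d_v))          # (5, 4)
    w_o = rng.normal(size=(d_v, emb_len))          # (4, 5)

    q = np.matmul(x, w_q)                          # (8, 3)
    k = np.matmul(x, w_k)                          # (8, 3)
    v = np.matmul(x, w_v)                          # (8, 4)

    dots = np.matmul(q, k.T) / np.sqrt(d_qk)       # (8, 8) attention scores
    dots = np.exp(dots - dots.max(axis=-1, keepdims=True))
    dots = dots / dots.sum(axis=-1, keepdims=True) # softmax over the key axis
    o = np.matmul(dots, v)                         # (8, 4)

    out = out + np.matmul(o, w_o)                  # (8, 5): head outputs are summed, not concatenated

print(out.shape)                                   # (8, 5), the same shape as the input

Because every head projects the full 5-dimensional embedding into its own d_qk/d_v spaces, and the (8, 5) head outputs are simply summed, n_heads = 3 never needs to divide 8 or 5.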

Cheers
