When calling OurSelfAttention we used n_heads=3, but from the way the function is defined it is not clear how n_heads is used. Also, the input x has shape 8×5, and n_heads=3 divides neither 8 nor 5. How is n_heads used here?

```python
# Imports as in trax / the lab (our_simple_attend is defined earlier in the lab):
import functools

from trax import fastmath
from trax import layers as tl
from trax.fastmath import numpy as np
from trax.layers.research.efficient_attention import mask_self_attention, tie_in


class OurSelfAttention(tl.SelfAttention):
    """Our self-attention. Just the forward function."""

    def forward_unbatched(
        self, x, mask=None, *, weights, state, rng, update_state, verbose=False
    ):
        print("ourSelfAttention:forward_unbatched")
        del update_state
        attend_rng, output_rng = fastmath.random.split(rng)
        if self._bias:
            if self._share_qk:
                w_q, w_v, w_o, b_q, b_v = weights
            else:
                w_q, w_k, w_v, w_o, b_q, b_k, b_v = weights
        else:
            if self._share_qk:
                w_q, w_v, w_o = weights
            else:
                w_q, w_k, w_v, w_o = weights

        print("x.shape, w_q.shape", x.shape, w_q.shape)
        q = np.matmul(x, w_q)
        k = None
        if not self._share_qk:
            k = np.matmul(x, w_k)
        v = np.matmul(x, w_v)

        if self._bias:
            q = q + b_q
            if not self._share_qk:
                k = k + b_k
            v = v + b_v

        # mask_fn setup, as in trax's SelfAttention (it is not passed to
        # our_simple_attend below)
        mask_fn = functools.partial(
            mask_self_attention,
            causal=self._causal,
            exclude_self=self._share_qk,
        )
        q_info = kv_info = tie_in(x, np.arange(q.shape[-2], dtype=np.int32))

        # mask is a boolean array (True means "is valid token")

        # Notice, we are calling our version of attend
        o, _ = our_simple_attend(
            q,
            k,
            v,
            q_info=q_info,
            kv_info=kv_info,
            dropout=self._attention_dropout,
            rng=attend_rng,
            verbose=True,
        )

        # Notice, the w_o weight matrix is applied to the output of attend
        # in forward_unbatched
        out = np.matmul(o, w_o)
        return out, state
```

ChatGPT

Hi @PZ2004

Looking at the code, I see this lab introduces a customized attention class. In the lab:

```python
n_heads = 3
d_qk = 3
d_v = 4
seq_len = 8
emb_len = 5
batch_size = 1
```

Notice, the `d_qk` specifies the dimensionality for Q and K, and `d_v` the dimensionality for V.
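To see why n_heads=3 does not need to divide 8 or 5: each head owns its *own* w_q, w_k and w_v, so every head projects the full 5-dim embedding independently; nothing gets split. A minimal numpy sketch (made-up zero weights, not the lab's trained ones; only the shapes matter):

```python
import numpy as np

# Shapes taken from the lab's config above.
seq_len, emb_len, d_qk, d_v, n_heads = 8, 5, 3, 4, 3

x = np.ones((seq_len, emb_len))  # one unbatched example, (8, 5)
# Each of the 3 heads has its own projection matrices:
head_qs = [np.matmul(x, np.zeros((emb_len, d_qk))) for _ in range(n_heads)]
head_vs = [np.matmul(x, np.zeros((emb_len, d_v))) for _ in range(n_heads)]

print([q.shape for q in head_qs])  # [(8, 3), (8, 3), (8, 3)]
print([v.shape for v in head_vs])  # [(8, 4), (8, 4), (8, 4)]
```

So n_heads only controls how many independent (w_q, w_k, w_v, w_o) sets exist, never a split of seq_len or emb_len.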

If “I love learning” is a concrete example sentence, then:

• the tokenizer could tokenize it to [23, 565, 2332], and the data generator would then pad the sentence to [23, 565, 2332, 0, 0, 0, 0, 0] (notice seq_len = 8, and batch_size here is 1 since we have 1 sentence). So the shape of the input would be (1, 8);
• the embedding layer would embed this input to shape (1, 8, 5) (the embedding layer has a weight matrix of shape (vocab_size, 5), so (1, 8) → (1, 8, 5));
• now OurSelfAttention has 3 heads, and each head’s weights w_q, w_k and w_v have shapes (5, 3), (5, 3) and (5, 4) respectively. Also, what is not explicitly well explained, w_o has shape (4, 5) so that the output shape matches the input shape. So, step by step (with code lines), when each head receives one example of shape (8, 5) (forward_unbatched drops the batch dimension):
– Q becomes (8, 3), since (8, 5) dot (5, 3) → (8, 3); `q = np.matmul(x, w_q)`;
– K becomes (8, 3), since (8, 5) dot (5, 3) → (8, 3); `k = np.matmul(x, w_k)`;
– V becomes (8, 4), since (8, 5) dot (5, 4) → (8, 4); `v = np.matmul(x, w_v)`;
• now the attention matrix becomes (8, 8), since (8, 3) dot (3, 8) → (8, 8); `dots = np.matmul(q, np.swapaxes(k, -1, -2))`;
• applying the attention to V results in (8, 4), since (8, 8) dot (8, 4) → (8, 4); `out = np.matmul(dots, v)`;
• the outputs are then transformed back to the input shape (8, 5) by applying the w_o weight matrix: (8, 4) dot (4, 5) → (8, 5); `np.matmul(o, w_o)`;
• repeat the same steps for the remaining two heads;
• in this lab, the results of all 3 heads are summed up (in the code). I don’t remember exactly why that is the case here, but probably for LSHAttention.
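The whole walkthrough above can be checked with a small numpy sketch. This is not the lab's code: the weights are random stand-ins for the trained w_q/w_k/w_v/w_o, and plain softmax stands in for our_simple_attend (no masking, no dropout), so only the shapes are meaningful:

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, emb_len, d_qk, d_v, n_heads = 8, 5, 3, 4, 3

x = rng.normal(size=(seq_len, emb_len))      # one embedded sentence, (8, 5)

def one_head(x):
    # Hypothetical random weights standing in for a head's trained projections.
    w_q = rng.normal(size=(emb_len, d_qk))   # (5, 3)
    w_k = rng.normal(size=(emb_len, d_qk))   # (5, 3)
    w_v = rng.normal(size=(emb_len, d_v))    # (5, 4)
    w_o = rng.normal(size=(d_v, emb_len))    # (4, 5)
    q, k, v = x @ w_q, x @ w_k, x @ w_v      # (8, 3), (8, 3), (8, 4)
    dots = q @ k.T / np.sqrt(d_qk)           # attention logits, (8, 8)
    dots = np.exp(dots - dots.max(axis=-1, keepdims=True))
    dots /= dots.sum(axis=-1, keepdims=True) # softmax over the key axis
    o = dots @ v                             # attention applied to V, (8, 4)
    return o @ w_o                           # back to the input shape, (8, 5)

# As in the lab, the per-head results are summed:
out = sum(one_head(x) for _ in range(n_heads))
print(out.shape)  # (8, 5)
```

Each head maps (8, 5) → (8, 5) on its own, and the sum keeps that shape, which is why n_heads=3 never has to divide 8 or 5.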

Cheers
