Which dimension of k.shape should dk be? Is it correct to use k.shape[-1]?
The assignment is a bit confusing in this regard when it says:
“dk is the dimension of the keys”
and (in the code)
k – key shape == (…, seq_len_k, depth)
The shape of k (as well as q and v) is (batch_size, num_heads, seq_len, depth), where the leading batch_size and num_heads (the number of attention heads) dimensions are optional.
“dk is the dimension of the keys” means its depth.
So, you’re right, it’s k.shape[-1].
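To make that concrete, here is a minimal sketch in NumPy rather than the assignment's own framework. The function body, the tensor sizes, and the random inputs are assumptions for illustration only, and masking is left out. It is not the assignment's reference solution; it only shows that reading dk from the last axis works whether or not the leading batch and head axes are present.

```python
import numpy as np


def scaled_dot_product_attention(q, k, v):
    """Illustrative attention without masking (not the assignment's solution)."""
    # dk is the depth of the keys: the last axis of k, no matter how many
    # leading (batch, head) axes come before it.
    dk = k.shape[-1]
    # (..., seq_len_q, depth) @ (..., depth, seq_len_k) -> (..., seq_len_q, seq_len_k)
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(dk)
    # Numerically stable softmax over the key axis.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights


# Hypothetical sizes, chosen only for this example.
batch_size, num_heads, seq_len_q, seq_len_k, depth = 2, 4, 3, 5, 8
rng = np.random.default_rng(0)
q = rng.normal(size=(batch_size, num_heads, seq_len_q, depth))
k = rng.normal(size=(batch_size, num_heads, seq_len_k, depth))
v = rng.normal(size=(batch_size, num_heads, seq_len_k, depth))

# 4-D input: batch and head axes present.
out, weights = scaled_dot_product_attention(q, k, v)
print(k.shape[-1], out.shape)

# 2-D input: a single head of a single example, same code path.
out_2d, weights_2d = scaled_dot_product_attention(q[0, 0], k[0, 0], v[0, 0])
print(k[0, 0].shape[-1], out_2d.shape)
```

Indexing with -1 instead of a fixed position such as 3 is what keeps the same function usable for both the 4-D and the 2-D case.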