Question about the dimensions of Q V K in lab1

In lab 1, Q, K, and V are defined in the LSH code. What do these dimension variables correspond to?

In the code below, emb_len is not used. Does seq_len refer to the length of the embedding instead? It is the dimension shared by Q and V, and I suppose they should have the same embedding length.

# Imports assumed from the lab notebook; our_simple_attend is defined earlier in the lab.
import jax
import numpy as np
from trax import fastmath

seq_len = 8
emb_len = 5
d_qk = 3
d_v = 4
with fastmath.use_backend("jax"):  # specify the backend for consistency
    rng_attend = fastmath.random.get_prng(1)
    q = k = jax.random.uniform(rng_attend, (seq_len, d_qk), dtype=np.float32)
    v = jax.random.uniform(rng_attend, (seq_len, d_v), dtype=np.float32)
    o, logits = our_simple_attend(
        q,
        k,
        v,
        mask_fn=None,
        q_info=None,
        kv_info=None,
        dropout=0.0,
        rng=rng_attend,
        verbose=True,
    )
print(o, "\n", logits)

Hi,

  • seq_len is the number of tokens, so Q, K, and V each have one row per token. (If you treat one word as one token for simplicity, it is the length of the input sentence.)
  • Each token of the input sentence is transformed into an 'embedding', which is simply a vector. The length (or dimension) of that vector is called emb_len.
  • Each embedding is then converted into a Query, a Key, and a Value, which are all vectors. Their lengths are called d_q, d_k, and d_v, respectively.
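  • We often omit d_q because in dot-product attention it equals d_k: each attention score is a dot product between a query and a key, so the two vectors need the same length. The lab's d_qk is that shared length.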
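
To make the shapes in the list above concrete, here is a minimal NumPy sketch of the full path from embeddings to attention output. It is only an illustration, not the lab's implementation: the projection matrices w_q, w_k, and w_v are hypothetical names filled with random values (they would be learned in a real model), and the scaled softmax is the standard dot-product form rather than anything taken from our_simple_attend. The lab snippet appears to sample q, k, and v directly instead of projecting them from embeddings, which would explain why emb_len goes unused there.

```python
import numpy as np

seq_len, emb_len, d_qk, d_v = 8, 5, 3, 4
rng = np.random.default_rng(1)

# One embedding vector of length emb_len per token.
x = rng.uniform(size=(seq_len, emb_len)).astype(np.float32)

# Hypothetical projection matrices (random here, learned in a real model).
w_q = rng.uniform(size=(emb_len, d_qk)).astype(np.float32)
w_k = rng.uniform(size=(emb_len, d_qk)).astype(np.float32)
w_v = rng.uniform(size=(emb_len, d_v)).astype(np.float32)

q = x @ w_q  # (seq_len, d_qk)
k = x @ w_k  # (seq_len, d_qk)
v = x @ w_v  # (seq_len, d_v)

# One score per (query token, key token) pair; needs len(q_i) == len(k_j).
logits = q @ k.T / np.sqrt(d_qk)  # (seq_len, seq_len)

# Row-wise softmax, shifted by the row maximum for numerical stability.
weights = np.exp(logits - logits.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)

# Each output row is a weighted average of the value vectors.
out = weights @ v  # (seq_len, d_v)

print(q.shape, k.shape, v.shape)  # (8, 3) (8, 3) (8, 4)
print(logits.shape, out.shape)    # (8, 8) (8, 4)
```

Note that emb_len only appears in the projection step and disappears from every later shape, while seq_len is the row count throughout.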