I do not understand what Part 3.4 is doing. I particularly don’t understand Step 5. How is the softmax taken over a single hash when the hashes are chunked together into one tensor? And what is the purpose of using the softmax again? I just don’t understand this step in general.
Hey @Harvey_Wang,
Apologies for the delay. I started doing the same lab myself. Undoubtedly, this lab has been the most difficult for me throughout this specialization. So, allow me to try to help you out.
Just to be on the same page, can you please let me know if you are facing trouble in understanding the concept or in understanding the code?
Cheers,
Elemento
No worries at all about the delay. Thanks for your response!
I mainly have trouble understanding the code. However, it’s possible that my trouble with understanding the code may also be because I don’t understand the concepts. So, I think an explanation of both the code and concepts would help.
Hey @Harvey_Wang,
Sure, I will try my best. I am assuming here that we are clear on the concept of Chunked Dot Product Attention, and that up to 3.4 we have sorted our queries so that similar queries land in the same chunk; now we simply need to compute the attention within these individual chunks. For simplicity, our implementation very conveniently assumes that each bucket spans a single chunk only, which is not the case in practice, and this is explicitly mentioned towards the end of the lab.
Furthermore, throughout my explanation I will be bringing up analogies to word embeddings, since I believe it is easier to relate to and understand things that way.
Now, let’s begin with the shape of sq, which as per the lab is (16, 3), i.e., 16 word embeddings, each embedding spanning 3 dimensions. We also have kv_chunk_len = t_kv_chunk_len = 2, i.e., each chunk holds 2 embeddings, so the very first step is to divide the 16 embeddings into 8 chunks of 2, and hence the shape of rsq is (8, 2, 3). The shape of rsqt is (8, 3, 2), which is basically the transpose of rsq taken within each chunk. So, essentially rsq is Q and rsqt is K.T, as per our standard attention, and dotlike denotes np.dot(Q, K.T) computed chunk by chunk, which has a shape of (8, 2, 2).
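To make Step 1 concrete, here is a minimal numpy sketch of the chunking and the chunk-wise dot product. The names (sq, rsq, rsqt, dotlike) follow the lab, but the values are random placeholders and the exact calls in the lab may differ slightly:

```python
import numpy as np

n_hashes, seqlen, d_qk, d_v = 2, 8, 3, 5
kv_chunk_len = 2

# Sorted queries/keys (shared-QK attention): 16 rows = n_hashes * seqlen
sq = np.random.randn(n_hashes * seqlen, d_qk)   # (16, 3)

# Step 1: split the 16 rows into chunks of length 2 -> 8 chunks
rsq  = sq.reshape(-1, kv_chunk_len, d_qk)       # (8, 2, 3), plays the role of Q
rsqt = np.swapaxes(rsq, -1, -2)                 # (8, 3, 2), plays the role of K.T

# Chunk-wise Q @ K.T: each chunk attends only within itself
dotlike = np.matmul(rsq, rsqt)                  # (8, 2, 2)
print(dotlike.shape)
```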
Step 2 is pretty simple as well, but somewhat interesting. You will find that in the code cell following the one that we are discussing, 2 examples have been given: one with softmax disabled, i.e., passthrough = True, and the other with softmax enabled, i.e., passthrough = False. I am not sure why softmax has been disabled per se, but I guess it’s just to show the difference between the 2 outputs and how our_softmax affects the final output. That being said, let us talk about the second example, in which softmax is enabled. So, dotlike after Step 2 contains softmax(Q @ K.T), computed chunk by chunk, and slogits contains the log of the sum of exponentials of the scores (the logsumexp), which is the conventional by-product of a numerically stable softmax.
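Here is a rough sketch of what I understand our_softmax to be doing when passthrough = False: it returns both the softmax of the scores and their logsumexp along the last axis. This is just my reading, not the lab’s exact implementation (in particular, what it returns in the passthrough case is an assumption):

```python
import numpy as np

def our_softmax_sketch(x, passthrough=False):
    """Softmax along the last axis, plus the logsumexp of the scores."""
    if passthrough:
        # Assumed behaviour: return the raw scores and zero logits
        return x, np.zeros(x.shape[:-1] + (1,))
    m = np.max(x, axis=-1, keepdims=True)   # subtract the max for numerical stability
    e = np.exp(x - m)
    s = np.sum(e, axis=-1, keepdims=True)
    slogits = m + np.log(s)                 # logsumexp of each row of scores
    return e / s, slogits                   # softmax probabilities, log-sum

# dotlike: (8, 2, 2) chunk-wise attention scores from Step 1
dotlike = np.random.randn(8, 2, 2)
probs, slogits = our_softmax_sketch(dotlike)
print(probs.shape, slogits.shape)           # (8, 2, 2) (8, 2, 1)
```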
P.S. - I will continue this answer in my next comment.
Now, coming to Step 3, which is pretty simple as well. In Step 3, we compute so, which is softmax(Q @ K.T) @ V. It initially has a shape of (8, 2, 5), since dotlike has a shape of (8, 2, 2) and vr has a shape of (8, 2, 5). After reshaping once more, so has a shape of (16, 5) and slogits is reshaped to (16,), i.e., one logsumexp value for each embedding. So, now, we have one row for each embedding in so and one logsumexp value for each embedding.
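Continuing the same sketch for Step 3, again with placeholder arrays and the lab’s names so, vr and slogits:

```python
import numpy as np

dotlike = np.random.rand(8, 2, 2)     # chunk-wise softmax(Q @ K.T) from Step 2
vr      = np.random.randn(8, 2, 5)    # values, chunked the same way as the queries
slogits = np.random.randn(8, 2, 1)    # chunk-wise logsumexp from Step 2

# Step 3: weighted sum of the values within each chunk
so = np.matmul(dotlike, vr)           # (8, 2, 5)

# Flatten the chunk dimension out: one row / one logit per (hash, position) pair
so      = so.reshape(-1, so.shape[-1])   # (16, 5)
slogits = slogits.reshape(-1)            # (16,)
print(so.shape, slogits.shape)
```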
Now, let’s come to Step 4, which is perhaps the simplest. It just re-arranges so and slogits so that they follow the original order in which the embeddings were fed as inputs, i.e., it undoes the sorting we applied before chunking. Hence, you can see that o has the same shape as so and logits has the same shape as slogits.
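In code, Step 4 is just an index take that undoes the earlier sort. In the lab this uses the stored sorting information; in the sketch below, sort_idx and undo_sort are hypothetical stand-ins for it:

```python
import numpy as np

seq = 16                                      # n_hashes * seqlen
so      = np.random.randn(seq, 5)             # sorted outputs from Step 3
slogits = np.random.randn(seq)                # sorted logsumexps from Step 3

sort_idx  = np.random.permutation(seq)        # stands in for the LSH bucket sort
undo_sort = np.argsort(sort_idx)              # inverse permutation

# Step 4: put the rows back into the original input order
o      = np.take(so, undo_sort, axis=0)       # same shape as so: (16, 5)
logits = np.take(slogits, undo_sort, axis=0)  # same shape as slogits: (16,)
```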
Now comes Step 5, which is the most interesting part, kind of like the “eureka” moment. If you scroll a bit upwards, you will find the sequence length, denoted by t_n_seq = t_seqlen = 8, and the number of hash-tables for LSH, denoted by t_n_hashes = 2. However, we know that o has a shape of (16, 5), where 16 = t_n_hashes * t_seqlen. So, essentially, for every word embedding, we have t_n_hashes = 2 different sums of weighted value vectors, one from each hash round. So, how do we combine these, so that we have only one sum of weighted value vectors per word embedding? The answer is softmax. We reshape o to have a shape of (2, 8, 5) and logits to have a shape of (2, 8, 1). Now, we take logits as the weights to combine these “sums of weighted value vectors”: we apply softmax once again, this time on logits across the hash dimension, to get probs, and use probs to take a weighted average of o, which gives us the final output.
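Here is Step 5 as a small sketch with the shapes from the lab (the lab itself may use trax/jax helpers instead of plain numpy, but the idea is the same): the logits from the two hash rounds are turned into weights with a softmax across the hash dimension, and the final output is the weighted sum of the per-round outputs.

```python
import numpy as np

n_hashes, seqlen, d_v = 2, 8, 5
o      = np.random.randn(n_hashes * seqlen, d_v)   # (16, 5), one output per (hash, position)
logits = np.random.randn(n_hashes * seqlen)        # (16,), logsumexp per (hash, position)

# Step 5: group the two hash rounds together for every position
o      = o.reshape(n_hashes, seqlen, d_v)          # (2, 8, 5)
logits = logits.reshape(n_hashes, seqlen, 1)       # (2, 8, 1)

# Softmax over the hash dimension: rounds with a larger logsumexp get more weight
m     = np.max(logits, axis=0, keepdims=True)
probs = np.exp(logits - m)
probs = probs / np.sum(probs, axis=0, keepdims=True)   # (2, 8, 1)

out = np.sum(o * probs, axis=0)                    # (8, 5): one row per original position
print(out.shape)
```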
I really hope this helps you, and if you face any queries, feel free to let us know.
Cheers,
Elemento
Hey,
Only one thing is not clear: why are we taking logits, which are already log sums, as the weights?
Hey @Aaditya1,
Apologies for the delayed response. Can you please specify if you are referring to Part 3.4? And if yes, can you please specify the step in 3.4 that you are referring to?
Cheers,
Elemento
I understood it @Elemento , thanks for the help!!