I do not understand what Part 3.4 is doing. I particularly don’t understand Step 5. How was the softmax over one hash when the hashes are chunked together into one tensor? And what is the purpose of using the softmax again? I just don’t understand this step in general.

Hey @Harvey_Wang,

Apologies for the delay. I started doing the same lab myself. Undoubtedly, this lab has been the most difficult for me throughout this specialization. So, allow me to try to help you out.

Just to be on the same page, can you please let me know if you are facing trouble in understanding the concept or in understanding the code?

Cheers,

Elemento

No worries at all about the delay. Thanks for your response!

I mainly have trouble understanding the code. However, it’s possible that my trouble with understanding the code may also be because I don’t understand the concepts. So, I think an explanation of both the code and concepts would help.

Hey @Harvey_Wang,

Sure, I will try my best. I am assuming that we are clear on the concept of Chunked Dot Product Attention and that, by the time we reach Part 3.4, the queries have already been sorted so that similar queries sit in the same chunk. All that remains is to compute the attention within each individual chunk. For simplicity, the implementation assumes that each bucket spans exactly one chunk, which is not the case in practice; the lab mentions this explicitly towards the end.

Furthermore, throughout my explanation I will draw analogies to word embeddings, since I believe that makes it easier to relate to and understand.

Now, let’s begin with the shape of `sq`, which as per the lab is `(16, 3)`, i.e., 16 word embeddings, each spanning 3 dimensions. Here `kv_chunk_len = t_kv_chunk_len = 2`, i.e., each chunk holds 2 embeddings, which gives 16 / 2 = 8 chunks. So, the very first step is to divide the embeddings into these 8 chunks, and hence the shape of `rsq` is `(8, 2, 3)`. The shape of `rsqt` is `(8, 3, 2)`, which is simply the transpose of `rsq` over its last two axes. So, essentially, `rsq` is `Q` and `rsqt` is `K.T`, as in standard attention. `dotlike` then denotes `np.dot(Q, K.T)` computed within each chunk, and has a shape of `(8, 2, 2)`.
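To make the shapes in Step 1 concrete, here is a minimal sketch in plain NumPy (the lab itself uses Trax, so this is only an approximation). The random `sq` is a stand-in for the lab's sorted queries, and the sizes are the ones quoted above.

```python
import numpy as np

rng = np.random.default_rng(0)
n_rows, d_qk, chunk_len = 16, 3, 2       # sizes quoted in the explanation above

sq = rng.normal(size=(n_rows, d_qk))     # stand-in for the sorted queries
rsq = sq.reshape(-1, chunk_len, d_qk)    # split into chunks      -> (8, 2, 3)
rsqt = np.swapaxes(rsq, -1, -2)          # transpose last 2 axes  -> (8, 3, 2)
dotlike = np.matmul(rsq, rsqt)           # per-chunk Q @ K.T      -> (8, 2, 2)

print(rsq.shape, rsqt.shape, dotlike.shape)
```

Each of the 8 slices `dotlike[i]` is an ordinary 2 x 2 score matrix for the two embeddings that landed in chunk `i`.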

Step 2 is pretty simple as well, but somewhat interesting. In the code cell following the one we are discussing, two examples are given: one with softmax disabled, i.e., `passthrough = True`, and the other with softmax enabled, i.e., `passthrough = False`. I am not sure why softmax is disabled in the first one, but I guess it is just to show the difference between the two outputs and how `our_softmax` affects the final output. That being said, let us talk about the second example, in which softmax is enabled.

So, after Step 2, `dotlike` contains `Softmax(Q @ K.T)`, and `slogits` contains, for each row, the log of the summed exponentials of the scores (the logsumexp), which is a conventional by-product of computing a softmax.
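As a rough illustration of Step 2, the sketch below computes a softmax together with its logsumexp in plain NumPy. The function `softmax_with_logsumexp` is a hypothetical stand-in for the lab's `our_softmax` with `passthrough = False`; I am assuming that is what it returns, so please check it against the lab code.

```python
import numpy as np

def softmax_with_logsumexp(x):
    """Hypothetical stand-in for our_softmax(x, passthrough=False)."""
    m = x.max(axis=-1, keepdims=True)                        # for numerical stability
    lse = m + np.log(np.exp(x - m).sum(axis=-1, keepdims=True))
    probs = np.exp(x - lse)                                  # each row sums to 1
    return probs, lse.squeeze(-1)                            # drop the trailing axis

rng = np.random.default_rng(0)
scores = rng.normal(size=(8, 2, 2))      # same shape as dotlike after Step 1

dotlike, slogits = softmax_with_logsumexp(scores)
print(dotlike.shape, slogits.shape)
```

So each of the 16 rows gets a probability distribution over the keys in its chunk, plus a single scalar in `slogits`.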

P.S. - I will continue this answer in my next comment.

Now, coming to Step 3, which is pretty simple as well. In Step 3, we compute `so`, which is `Softmax(Q @ K.T) @ V`. It initially has a shape of `(8, 2, 5)`, since `dotlike` has a shape of `(8, 2, 2)` and `vr` has a shape of `(8, 2, 5)`. After reshaping once more, `so` has a shape of `(16, 5)` and `slogits` is reshaped to `(16,)`. So, now, we have one row in `so` and one entry in `slogits` for each embedding.
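Here is a minimal NumPy sketch of Step 3, just to confirm the shapes. The arrays are random stand-ins (not the lab's actual values), with `dotlike` normalised so that its rows sum to 1, as they would after the softmax.

```python
import numpy as np

rng = np.random.default_rng(0)

dotlike = rng.random(size=(8, 2, 2))
dotlike /= dotlike.sum(axis=-1, keepdims=True)   # rows behave like softmax output
vr = rng.normal(size=(8, 2, 5))                  # stand-in for the chunked values
slogits = rng.normal(size=(8, 2))                # stand-in for the per-row logsumexp

so = np.matmul(dotlike, vr)                      # per-chunk weighted sum -> (8, 2, 5)
so = so.reshape(-1, so.shape[-1])                # flatten the chunks     -> (16, 5)
slogits = slogits.reshape(-1)                    # one scalar per row     -> (16,)

print(so.shape, slogits.shape)
```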

Now, let’s come to Step 4, which is perhaps the simplest. It just re-arranges `so` and `slogits` so that they follow the original order in which the embeddings were fed as inputs. Hence, you can see that `o` has the same shape as `so` and `logits` has the same shape as `slogits`.
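The un-sorting in Step 4 can be sketched as below. The names `sort_idx` and `undo_sort` are assumptions for illustration: a random permutation plays the role of the sort order produced by bucketing, and its `argsort` gives the inverse permutation.

```python
import numpy as np

rng = np.random.default_rng(0)
n_rows, d_v = 16, 5

x = rng.normal(size=(n_rows, d_v))       # rows in their original order
sort_idx = rng.permutation(n_rows)       # hypothetical sort order from bucketing
undo_sort = np.argsort(sort_idx)         # inverse permutation

so = x[sort_idx]                         # rows as they appear after sorting
slogits = rng.normal(size=(n_rows,))     # one scalar per sorted row

o = np.take(so, undo_sort, axis=0)       # back to the original order
logits = np.take(slogits, undo_sort, axis=0)

print(np.array_equal(o, x), o.shape, logits.shape)
```

Only the order of the rows changes, which is why the shapes stay the same.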

Now comes Step 5, which is the most interesting part, kind of like the “eureka” moment. If you scroll a bit upwards, you will find the sequence length, denoted by `t_n_seq = t_seqlen = 8`, and the number of hash tables for LSH, denoted by `t_n_hashes = 2`. However, we know that `o` has a shape of `(16, 5)`, i.e., `t_n_hashes * t_seqlen = 16` rows. So, essentially, for every word embedding, we have `t_n_hashes = 2` different sums of weighted value vectors. How do we combine these so that we have only one sum of weighted value vectors per word embedding? The answer is **Softmax**.

We reshape `o` to have a shape of `(2, 8, 5)` and `logits` to have a shape of `(2, 8, 1)`. Now, we take `logits` as the weights to combine these “sums of weighted value vectors”. For that, we apply softmax once again, this time on `logits` across the hashes, to get `probs`, and we use `probs` to transform `o` into the final output.
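Step 5 can be sketched in plain NumPy as follows. The inputs are random stand-ins, and I am assuming the softmax is taken over the hash axis (axis 0), which is what the shapes above suggest.

```python
import numpy as np

rng = np.random.default_rng(0)
n_hashes, seqlen, d_v = 2, 8, 5

o = rng.normal(size=(n_hashes * seqlen, d_v))    # (16, 5), one row per hash round and position
logits = rng.normal(size=(n_hashes * seqlen,))   # (16,), one scalar per row

o = o.reshape(n_hashes, seqlen, d_v)             # (2, 8, 5)
logits = logits.reshape(n_hashes, seqlen, 1)     # (2, 8, 1)

# Softmax over the hash axis: for each position, the 2 weights sum to 1.
m = logits.max(axis=0, keepdims=True)
lse = m + np.log(np.exp(logits - m).sum(axis=0, keepdims=True))
probs = np.exp(logits - lse)                     # (2, 8, 1)

out = (o * probs).sum(axis=0)                    # weighted average -> (8, 5)
print(probs.shape, out.shape)
```

So this second softmax does not act within one hash; it only decides how much each hash round contributes to the final vector at every position. If `logits` holds the logsumexp of the scores, as I believe it does, then the softmax over it weights each round in proportion to the total un-normalised attention mass that round saw.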

I really hope this helps you, and if you face any queries, feel free to let us know.

Cheers,

Elemento

Hey,

Only one thing is not clear: why are we taking the logits, which are already a log-sum, as the weights?

Hey @Aaditya1,

Apologies for the delayed response. Can you please confirm whether you are referring to Part 3.4? And if yes, can you please specify which step in 3.4 you are referring to?

Cheers,

Elemento