C5W4 Transformers Assignment / MultiHeadAttention & Concern About Q, K and V dimensions

Hello,
First, I really liked the last assignment of the newly added module 4 about Transformers (maybe the best in the specialization), especially working with the classes (…): really excellent!
But I have a concern about Q, K and V:

First, since the assignment did not ask us to implement MultiHeadAttention ourselves (I would have liked a part of the assignment dedicated to implementing MultiHeadAttention, because it is the core of Transformers: this is where the magic happens, i.e. the parallelization of the computations …), I searched on my own to understand how the parallelization is done, by splitting Q, K and V… (I ended up understanding it.)

Second, by examining the function scaled_dot_product_attention(q, k, v, mask) of the assignment, I found what looks like a contradiction:

  • In fact, by examining the computations within this function, we see that the equality seq_len_k = seq_len_v must hold for the computations to work in all cases: attention_weights has shape (..., seq_len_q, seq_len_k) and v has shape (..., seq_len_v, depth_v), so the final multiplication between attention_weights and v is impossible if seq_len_k is different from seq_len_v (see the first sketch below).
    ==> So only two sequence-length arguments, seq_len_q and a common seq_len_k_v, are enough (not 3 different parameters).

  • Then, by examining a MultiHeadAttention implementation (found by searching), we can immediately see that the parallelization of the computations (which characterizes Transformers) is achieved, roughly, with two key functions, tf.reshape and tf.transpose, the last dimension of each head being depth = original_depth_before_splitting // number_of_heads. It turns out that Q, K and V then have the same dimensions: the same depth and the same sequence length. In other words, depth_v (in our case, in the function scaled_dot_product_attention(q, k, v, mask)) should not be different from depth, so we can merge them into a single parameter depth (not two params); similarly for the sequence lengths, we can merge them into a single parameter seq_len (not 3 different parameters, which give the impression that they can be different). See the second sketch below. If I am wrong, please enlighten me.
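
To make the first point concrete, here is a minimal sketch of scaled_dot_product_attention with the shapes annotated (this follows the standard TensorFlow tutorial version, not necessarily the assignment's exact code): the final matmul is only defined when seq_len_k == seq_len_v.

```python
import tensorflow as tf

def scaled_dot_product_attention(q, k, v, mask=None):
    """Sketch of scaled dot-product attention, shapes annotated.

    q: (..., seq_len_q, depth)
    k: (..., seq_len_k, depth)
    v: (..., seq_len_v, depth_v)
    """
    # (..., seq_len_q, depth) @ (..., depth, seq_len_k) -> (..., seq_len_q, seq_len_k)
    matmul_qk = tf.matmul(q, k, transpose_b=True)
    dk = tf.cast(tf.shape(k)[-1], tf.float32)
    scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
    if mask is not None:
        scaled_attention_logits += (mask * -1e9)
    # attention_weights: (..., seq_len_q, seq_len_k)
    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
    # (..., seq_len_q, seq_len_k) @ (..., seq_len_v, depth_v):
    # this matmul only works if seq_len_k == seq_len_v,
    # so in practice there is really one common length seq_len_k_v.
    output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)
    return output, attention_weights
```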
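And for the second point, a rough sketch of the split into heads, assuming the usual names (d_model, num_heads, a split_heads helper) from the TensorFlow tutorial the assignment seems to be based on: after the three dense projections to d_model, q, k and v are split the same way, so each head gets the same depth = d_model // num_heads.

```python
import tensorflow as tf

d_model, num_heads = 512, 8
depth = d_model // num_heads  # per-head depth, identical for q, k and v

wq = tf.keras.layers.Dense(d_model)
wk = tf.keras.layers.Dense(d_model)
wv = tf.keras.layers.Dense(d_model)

def split_heads(x, batch_size):
    # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, depth)
    x = tf.reshape(x, (batch_size, -1, num_heads, depth))
    return tf.transpose(x, perm=[0, 2, 1, 3])

# Toy input: in self-attention q, k and v all come from the same sequence.
batch_size, seq_len = 2, 10
x = tf.random.uniform((batch_size, seq_len, d_model))

q = split_heads(wq(x), batch_size)  # (2, 8, 10, 64)
k = split_heads(wk(x), batch_size)  # (2, 8, 10, 64)
v = split_heads(wv(x), batch_size)  # (2, 8, 10, 64)
# All three end up with the same per-head depth (64 here),
# and all heads are processed in parallel by one batched matmul.
```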
