I didn’t quite understand why we have the compute_attention_heads and compute_attention_output functions. Why are we splitting and reshaping, etc.?
Hi Yeshwant_Drattatreya,
compute_attention_heads creates the separate heads used in the multi-head attention mechanism: it splits the feature dimension into n_heads heads of size d_head = d_feature / n_heads and reshapes the tensor so each head can be attended to independently. compute_attention_output does the reverse: it stacks the per-head results back together along the feature dimension, so that n_heads * d_head = d_feature again.
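In case it helps, here is a minimal sketch (in jax.numpy, with hypothetical shapes and signatures — the assignment's exact code may differ) of what the two reshapes accomplish:

```python
import jax.numpy as jnp

def compute_attention_heads(x, n_heads, d_head):
    """Split the feature dimension into heads.

    (batch, seq_len, n_heads * d_head) -> (batch * n_heads, seq_len, d_head)
    so each head is processed as if it were its own batch element.
    """
    batch, seq_len, _ = x.shape
    x = jnp.reshape(x, (batch, seq_len, n_heads, d_head))     # split features into heads
    x = jnp.transpose(x, (0, 2, 1, 3))                        # (batch, n_heads, seq_len, d_head)
    return jnp.reshape(x, (batch * n_heads, seq_len, d_head))

def compute_attention_output(x, n_heads, d_head):
    """Undo the split: stack the heads back along the feature dimension.

    (batch * n_heads, seq_len, d_head) -> (batch, seq_len, n_heads * d_head)
    """
    batch = x.shape[0] // n_heads
    seq_len = x.shape[1]
    x = jnp.reshape(x, (batch, n_heads, seq_len, d_head))
    x = jnp.transpose(x, (0, 2, 1, 3))                        # (batch, seq_len, n_heads, d_head)
    return jnp.reshape(x, (batch, seq_len, n_heads * d_head))

# Round trip: splitting into heads and stacking back recovers the original tensor.
x = jnp.arange(2 * 3 * 8, dtype=jnp.float32).reshape(2, 3, 8)  # d_feature = 8
heads = compute_attention_heads(x, n_heads=4, d_head=2)        # shape (8, 3, 2)
assert jnp.allclose(compute_attention_output(heads, n_heads=4, d_head=2), x)
```

The reason for folding the heads into the batch dimension is that the same dot-product attention code can then be applied to every head in parallel, without any head-specific logic.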