Does anyone understand why this transpose, i.e. swapping axis 0 and axis 1, is required prior to the concat? In the next step we concatenate only over axis 2 anyway:

Step 2d is needed for the else branch to be equivalent to 'tf,fhb->htb' (check this post).

Step 3 just concatenates rotated_vecs with the same values but with the opposite sign (see the same linked post).
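Step 3 can be sketched like this. The shapes (n_hashes=3, seq_len=8, rot_size//2=4) are my assumption based on the shapes mentioned later in this thread, not the actual course values:

```python
import numpy as np

# Hypothetical shapes: (n_hashes, seq_len, rot_size // 2) = (3, 8, 4).
rng = np.random.default_rng(0)
rotated_vecs = rng.normal(size=(3, 8, 4))

# Step 3: append the same values with the opposite sign along the last
# axis, so a later argmax can pick among rot_size buckets covering both
# half-spaces.
full = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1)

assert full.shape == (3, 8, 8)
assert np.allclose(full[..., 4:], -rotated_vecs)
```

Note that the concat axis is the last one, which is why the earlier transpose of axes 0 and 1 doesn't interfere with it.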

In other words, steps 2a, 2b, 2c and 2d are the implementation of the 'tf,fhb->htb' matrix multiplication, and the Step 3 concatenation on the “irrelevant” axis is logical and in no way at odds with the previous steps.
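The equivalence can be checked directly in numpy. The sizes below (t=8, f=5, h=3, b=4) are placeholders I chose for illustration, and the transpose-plus-batched-matmul is one way to reproduce the einsum, not necessarily the exact sequence of ops the assignment uses:

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.normal(size=(8, 5))             # (t, f): sequence of feature vectors
rotations = rng.normal(size=(5, 3, 4))  # (f, h, b): random rotations per hash

# Reference: the einsum from this thread.
ref = np.einsum('tf,fhb->htb', x, rotations)

# Manual version: move the hash axis to the front, then a broadcasted
# matmul does (t, f) @ (f, b) independently for each of the h hashes,
# giving (h, t, b).
manual = np.matmul(x, rotations.transpose(1, 0, 2))

assert manual.shape == (3, 8, 4)
assert np.allclose(ref, manual)
```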

That explains it well; thanks a bunch @arvyzukai! My earlier point was that, looking at what happens in 2d, one might be able to concatenate without the transpose and end up with an (8, 3, 4) shape in Step 3.

I believe one could technically adjust the upstream steps to work with (8, 3, 4). However, I think I now get the mental convenience of having it as (3, 8, 4). Basically, “n_hashes count of (seq_len, rot_size) matrices” is easier to think and reason about than “seq_len count of (n_hashes, rot_size) matrices”.
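Just to make that mental model concrete: the two layouts hold exactly the same numbers, differing only in which axis you group by (this is a toy illustration with the shapes from this thread):

```python
import numpy as np

rng = np.random.default_rng(2)
# (n_hashes, seq_len, rot_size) = (3, 8, 4):
# three matrices of shape (8, 4), one per hash round.
per_hash = rng.normal(size=(3, 8, 4))

# The alternative (seq_len, n_hashes, rot_size) layout is just a
# transpose: eight matrices of shape (3, 4), one per sequence position.
per_position = per_hash.transpose(1, 0, 2)
assert per_position.shape == (8, 3, 4)

# Same data either way; only the grouping you reason about changes.
assert np.allclose(per_hash[1], per_position[:, 1, :])
```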