The shapes are correct, even though `x_tmp.shape[-1] != da_tmp.shape[-1]`.
You're right about the general observation. In the test, we're performing the calculations only for the first 4 time steps. If you look at the comment `#Retrieve the dimensions from da's and x1's shape`, you'll see that `T_x` is taken from `da.shape[-1]`, not from `x`.
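Here is a minimal sketch of why the mismatch is harmless. It assumes numpy arrays laid out as (features, batch, time). The helper `backward_time_steps` and the example shapes are hypothetical and are not the test's actual values. The sketch only shows how the loop bound is chosen, not the full backward pass.

```python
import numpy as np

def backward_time_steps(da, x):
    """Hypothetical helper: returns the time steps a backward loop would visit
    when T_x is read from da's shape rather than x's."""
    # Retrieve the dimensions from da's and x1's shape
    n_a, m, T_x = da.shape
    x1 = x[:, :, 0]
    n_x, m_x = x1.shape
    assert m == m_x  # the batch sizes must still agree
    # Loop backward over only the T_x steps covered by da;
    # any extra time steps in x are simply never visited.
    return list(reversed(range(T_x)))

# Illustrative shapes (assumed): x has 7 time steps, da only 4
x_tmp = np.random.randn(3, 10, 7)
da_tmp = np.random.randn(5, 10, 4)
steps = backward_time_steps(da_tmp, x_tmp)
print(steps)  # [3, 2, 1, 0]
```

The loop bound comes from `da`, so only the time steps that actually received a gradient are processed. Extra time steps in `x` do not cause an error.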