I ran into the same error. It turned out I hadn't transposed the matrix k. This fixed it for me: matmul_qk = tf.matmul(q, k.transpose())
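
In case it helps, here is a minimal sketch of the same fix using tf.matmul's transpose_b argument. It assumes q and k are batched TensorFlow tensors; the shapes below are made up for illustration and are not from the original question. One caveat: k.transpose() with no arguments is a NumPy array method that reverses all axes, so on arrays with more than two dimensions it may not give the result you expect, whereas transpose_b=True swaps only the last two axes.

```python
import tensorflow as tf

# Hypothetical shapes for illustration: (batch, seq_len, depth)
q = tf.random.normal((2, 4, 8))
k = tf.random.normal((2, 6, 8))

# Without a transpose, the inner dimensions of q and k (8 and 6) do not
# line up, so tf.matmul(q, k) would fail on these shapes.

# transpose_b=True swaps only the last two axes of k before multiplying,
# so k is treated as shape (2, 8, 6) and the batch axis is left alone.
matmul_qk = tf.matmul(q, k, transpose_b=True)

print(matmul_qk.shape)  # (2, 4, 6)
```

For plain 2-D matrices, k.transpose() and transpose_b=True should give the same product, so either works there.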