I’m working on scaled_dot_product_attention_test(). I got this error and traced it to my dk, which I set with dk = k.shape[-2]. I got the code to work by using tf.cast() to convert dk to tf.float32, but that seems unnecessarily complicated, so I’m wondering whether I missed something or am not doing this the way it’s intended.
Also, since dk is supposed to be the dimension of the keys, that means it should be ‘seq_len_k’ and not ‘depth’, right? I’m very confused, because switching between k.shape[-1] and k.shape[-2] didn’t change anything (the test still passed either way).
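To make the k.shape[-1] versus k.shape[-2] question concrete, here is a minimal NumPy sketch (not the assignment's TensorFlow code). It assumes the usual definition of scaled dot-product attention, in which the scores are divided by the square root of the key vector size, i.e. the last axis (‘depth’). The shapes, the random inputs, and the helper names softmax and attention are hypothetical, chosen only for illustration. If the test's k happens to be square (an assumption, not something confirmed here), both indices return the same number, which would explain why either choice passes.

```python
import numpy as np

def softmax(x):
    # Numerically stable softmax over the last axis.
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attention(q, k, v, dk):
    # Scaled dot-product attention with an explicitly supplied scale factor dk.
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(dk)
    return softmax(scores) @ v

rng = np.random.default_rng(0)
q = rng.normal(size=(3, 4))          # (seq_len_q, depth)

# Square k: seq_len_k == depth, so both indices give the same dk.
k_square = rng.normal(size=(4, 4))   # (seq_len_k, depth)
v_square = rng.normal(size=(4, 2))   # (seq_len_k, depth_v)
same = np.allclose(
    attention(q, k_square, v_square, k_square.shape[-1]),
    attention(q, k_square, v_square, k_square.shape[-2]),
)

# Rectangular k: seq_len_k != depth, so the two choices now disagree.
k_rect = rng.normal(size=(6, 4))
v_rect = rng.normal(size=(6, 2))
differs = not np.allclose(
    attention(q, k_rect, v_rect, k_rect.shape[-1]),
    attention(q, k_rect, v_rect, k_rect.shape[-2]),
)

print(same, differs)
```

Under that assumption, a test input with seq_len_k different from depth would be needed to tell the two choices apart.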
The reason I indexed into k.shape was that I was under the impression we don’t know how many elements are in the shape of k (the hint said: key shape == (…, seq_len_k, depth)). Wouldn’t (dk, col) = np.shape(k) only work under the assumption that .shape returns exactly 2 numbers?
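The concern about tuple unpacking can be checked with a small NumPy sketch. The arrays k_2d and k_3d and the two helper functions are hypothetical, made up only to compare the two ways of reading the shape. Negative indexing works for any number of leading axes, while unpacking into two names only works when the array has exactly two axes.

```python
import numpy as np

# Hypothetical key arrays: one without and one with a leading batch axis.
k_2d = np.zeros((4, 8))       # (seq_len_k, depth)
k_3d = np.zeros((2, 4, 8))    # (batch, seq_len_k, depth)

def depth_by_index(k):
    # Reads the last axis, whatever the rank of k is.
    return k.shape[-1]

def depth_by_unpacking(k):
    # Unpacks exactly two axes; raises ValueError for any other rank.
    seq_len_k, depth = np.shape(k)
    return depth

print(depth_by_index(k_2d), depth_by_index(k_3d))
print(depth_by_unpacking(k_2d))

try:
    depth_by_unpacking(k_3d)
    unpack_failed = False
except ValueError:
    # Three values cannot be unpacked into two names.
    unpack_failed = True

print(unpack_failed)
```

So with a shape like (…, seq_len_k, depth), indexing from the end is the option that does not depend on how many leading axes there are.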
I’m terribly confused, so I went back to my original code with just dk = k.shape[-2] and ran it to reproduce the error, so that I could update my picture to show the entire stack trace. But apparently it now works without having to tf.cast() to float?
I tried restarting the kernel and clearing all output, and yes, it still works. Wow. I have absolutely no idea why it suddenly works without tf.cast(), so any insight would be amazing.