I have been working on this question for a long time now. The sample DotProductAttention code below runs and reproduces the Expected Output; nevertheless, the tests are not passing.
Are there any other steps I should be mindful of for UNQ_C1 besides the following?
dots = jnp.matmul(query, jnp.swapaxes(key, -1, -2)) / jnp.sqrt(depth)
dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9)) (with conditional)
logsumexp = trax.fastmath.logsumexp(dots, axis=-2)
When I check the testing file, the masks are not of Bool type but have zeros (0) and -1e9 as values… could this be the source of my error?
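(For what it's worth, the two mask conventions are interchangeable: a boolean mask applied with jnp.where, and an additive mask of 0/-1e9 simply added to the dots, produce effectively the same masked scores. A small NumPy sketch of my own, not the course's test code:)

```python
import numpy as np

dots = np.array([[0.5, 1.0, -0.2]])
bool_mask = np.array([[True, True, False]])

# Boolean convention: keep scores where True, replace with -1e9 where False.
masked_a = np.where(bool_mask, dots, np.full_like(dots, -1e9))

# Additive convention: a mask holding 0 (visible) or -1e9 (hidden), added in.
add_mask = np.where(bool_mask, 0.0, -1e9)
masked_b = dots + add_mask

# The masked positions differ only by the tiny original score (-0.2),
# which is negligible next to -1e9 once softmax is applied.
print(np.allclose(masked_a, masked_b))  # True
```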
The problem is with calculating the logsumexp.
While I try to fix it (I will give it another go in a couple of hours), could you help me a bit more explicitly, please? I do not find the Trax docs very helpful. Is the axis the problem? I initially figured it should be axis=-1, for the last dimension, but the tests do not pass that way either.
Certainly. For starters, please click my name and send your notebook as an attachment.
I sent it! Thanks in advance for your help!
I was stuck there until I stumbled on this solution:
it has passed local tests and passed the autograder
Thanks, it works! How did you think that through, if I may ask?
I actually found it by accident. I searched for logsumexp and the first hit was SciPy. That documentation includes this description for the parameter:

keepdims : bool, optional
If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the original array.
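(To illustrate the quoted behavior, here is a toy NumPy/SciPy example of my own, not from the thread. With keepdims=True, the reduced axis stays as size one, so the result broadcasts cleanly when subtracted back from the scores:)

```python
import numpy as np
from scipy.special import logsumexp

dots = np.zeros((2, 3, 4))  # toy scores: (batch, len_q, len_k)

# keepdims=True leaves the reduced axis in place with size 1.
lse = logsumexp(dots, axis=-1, keepdims=True)
print(lse.shape)  # (2, 3, 1) — broadcasts against dots

# Softmax via the log-sum-exp trick; each row of weights sums to 1.
weights = np.exp(dots - lse)
print(weights.sum(axis=-1))  # all ones
```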
I didn’t find anything useful in the Trax documentation, and the comment above that line of code in my notebook just says Note: softmax = None, which is rather unhelpful. Glad this worked for you!
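(Pulling the thread together, here is a minimal NumPy/SciPy sketch of the whole computation as I understand it — scaled dots, optional masking, then softmax over the last axis using logsumexp with keepdims=True. This is my own reconstruction mirroring the jnp calls quoted above, not the course's reference solution:)

```python
import numpy as np
from scipy.special import logsumexp

def dot_product_attention(query, key, value, mask=None):
    # Scaled scores: (batch, len_q, len_k).
    depth = query.shape[-1]
    dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
    if mask is not None:
        dots = np.where(mask, dots, np.full_like(dots, -1e9))
    # Softmax over the last axis (len_k) via log-sum-exp;
    # keepdims=True keeps shape (batch, len_q, 1) so it broadcasts.
    lse = logsumexp(dots, axis=-1, keepdims=True)
    weights = np.exp(dots - lse)
    return np.matmul(weights, value)
```

With axis=-1 and keepdims=True, the normalizer lines up with each query position's row of scores, which is exactly the broadcast the quoted SciPy doc describes.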