Hi!

I have been working on this question for a long time now. The sample DotProductAttention call below runs and reproduces the Expected Output; nevertheless, the tests are not passing.

Are there any other steps I should be mindful of for UNQ_C1 besides the following?

```
dots = jnp.matmul(query, jnp.swapaxes(key, -1, -2)) / jnp.sqrt(depth)

dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))  # (with conditional)

logsumexp = trax.fastmath.logsumexp(dots, axis=-2)
```

When I check the testing file, the masks are not of Bool type; instead they contain zeros (0) and -1e9 as values… could this be the source of my error?
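Regarding those 0 / -1e9 values: that looks like an *additive* mask convention rather than a boolean one. A minimal NumPy sketch (the toy arrays are my own, and I'm assuming `np` mirrors `jnp` broadcasting here) contrasting the two conventions:

```python
import numpy as np

# A toy 1x3 row of attention scores.
dots = np.array([[2.0, 1.0, 0.5]])

# Convention 1: a boolean mask, combined via where().
bool_mask = np.array([[True, True, False]])
masked_a = np.where(bool_mask, dots, np.full_like(dots, -1e9))

# Convention 2: an additive mask holding 0 (keep) and -1e9 (drop),
# which is simply added to the scores.
add_mask = np.array([[0.0, 0.0, -1e9]])
masked_b = dots + add_mask

# Both drive the masked position to (roughly) -1e9,
# so its weight vanishes after the softmax.
print(masked_a)
print(masked_b)
```

So if the test masks are additive, they would be applied with `dots + mask` rather than `jnp.where(mask, ...)`; both end up suppressing the same positions.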

Thanks,

Gabriel


The problem is with calculating `logsumexp`

While I try to fix it (I will try in a couple of hours), could you help me a bit more explicitly, please? I don't find the trax docs very helpful. Is the axis the problem? I initially figured it should be `axis=-1`, for the last dimension, but the tests don't pass with that either.

Thanks,

Gabriel

Certainly. For starters, please click my name and send your notebook as an attachment.

I sent it! Thanks in advance for your help!

I was stuck there until I stumbled on this solution:

```
trax.fastmath.logsumexp(dots, axis=-1, keepdims=True)
```

It passed the local tests and the autograder.
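For anyone wondering *why* `keepdims=True` matters here, this is a sketch of the whole computation in plain NumPy/SciPy (not the assignment's exact code; `scipy.special.logsumexp` stands in for `trax.fastmath.logsumexp`, and the shapes are toy values of mine):

```python
import numpy as np
from scipy.special import logsumexp

def dot_product_attention(query, key, value):
    """Scaled dot-product attention using the log-sum-exp trick
    for a numerically stable softmax."""
    depth = query.shape[-1]
    dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
    # keepdims=True leaves the reduced axis in place with size 1,
    # so the (..., L_q, 1) result broadcasts against (..., L_q, L_k).
    lse = logsumexp(dots, axis=-1, keepdims=True)
    attention = np.exp(dots - lse)   # softmax over the key axis
    return np.matmul(attention, value)

q = np.random.randn(2, 4, 8)   # (batch, seq_len, depth)
out = dot_product_attention(q, q, q)
print(out.shape)  # (2, 4, 8)
```

Without `keepdims=True` the log-sum-exp would come back with shape `(2, 4)` and the subtraction `dots - lse` would either fail or broadcast along the wrong axis, which is consistent with the failing tests described above.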


Thanks, it works! How did you think that through, if I may ask?

I actually found it by accident. I searched for `logsumexp` and the first hit was `scipy`. That documentation includes this description for the parameter `keepdims`:

**keepdims** bool, optional

*If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the original array.*
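That description can be seen directly in the shapes (a quick check of my own, using the SciPy function the docs describe):

```python
import numpy as np
from scipy.special import logsumexp

dots = np.ones((2, 3, 4))  # a toy batch of attention scores

print(logsumexp(dots, axis=-1).shape)                 # (2, 3)
print(logsumexp(dots, axis=-1, keepdims=True).shape)  # (2, 3, 1)

# Only the keepdims version broadcasts cleanly against dots:
stable = dots - logsumexp(dots, axis=-1, keepdims=True)
print(stable.shape)                                   # (2, 3, 4)
```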

I didn’t find anything useful in the `trax` documentation, and the comment above that line of code in my notebook just says *Note: softmax = None*, which is rather unhelpful. Glad this worked for you!

https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.logsumexp.html
