Problems in UNQ_C1: matmul, where or logsumexp?

I have been working on this question for a long time now, and the sample DotProductAttention call underneath runs and reproduces the Expected Output; nevertheless, the tests are not passing.

Are there any other steps I should be mindful of for UNQ_C1 besides the following?

```python
dots = jnp.matmul(query, jnp.swapaxes(key, -1, -2)) / jnp.sqrt(depth)
dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))  # inside the mask conditional
logsumexp = trax.fastmath.logsumexp(dots, axis=-2)
```

When I check the testing file, the masks are not of Bool type, but have zeros (0) and -1e9 as values… could this be the source of my error?
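On the non-boolean masks: note that `where` treats any nonzero value as true, so a mask holding 0 and -1e9 selects the opposite positions from a 0/1 mask (-1e9 is nonzero, hence "true"). A quick check, using plain NumPy as a stand-in for jnp, with illustrative values only:

```python
import numpy as np

mask = np.array([0.0, -1e9, 0.0])   # additive-style mask values, not booleans
dots = np.array([1.0, 2.0, 3.0])

# np.where treats any nonzero entry as True, so -1e9 counts as "keep dots here"
out = np.where(mask, dots, np.full_like(dots, -1e9))
print(out)  # [-1.e+09  2.e+00 -1.e+09]
```

Masks in that 0/-1e9 form are usually *added* to the scores rather than passed to `where`.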



The problem is with calculating logsumexp

While I try to fix it (I'll get to it in a couple of hours), could you help me a bit more explicitly, please? I don't find the trax docs very helpful. Is the axis the problem? I initially figured it should be axis=-1, for the last dimension, but the tests are not passing that way either.


Certainly. For starters, please click my name and send your notebook as an attachment.

I sent it! Thanks in advance for your help!

I was stuck there until I stumbled on this solution:


It has passed the local tests and the autograder :muscle:


Thanks, it works! How did you think that through, if I can ask?

To share what went on:

  1. When you use trax.fastmath.logsumexp, call it with keepdims=True, i.e. trax.fastmath.logsumexp(dots, axis=-1, keepdims=True).
  2. There's no need to transpose logsumexp in the next step; it's just dots - logsumexp. The code comment should be helpful there.
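Putting both fixes together, the whole computation can be sketched end to end. This uses plain NumPy and scipy.special.logsumexp as stand-ins for jnp and trax.fastmath.logsumexp; the function name and the boolean-mask convention here are illustrative, not the assignment's exact code:

```python
import numpy as np
from scipy.special import logsumexp  # stand-in for trax.fastmath.logsumexp

def dot_product_attention(query, key, value, mask=None):
    # Scaled dot-product scores, shape (..., seq_q, seq_k)
    depth = query.shape[-1]
    dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)

    # Boolean mask: True = attend, False = block with a large negative value
    if mask is not None:
        dots = np.where(mask, dots, np.full_like(dots, -1e9))

    # Numerically stable softmax via logsumexp over the last axis;
    # keepdims=True keeps the reduced axis so it broadcasts against dots
    dots = np.exp(dots - logsumexp(dots, axis=-1, keepdims=True))

    return np.matmul(dots, value)
```

With a causal (lower-triangular) mask, the first query position can only attend to the first key, so its output equals the first value row, which is an easy sanity check.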

I actually found it by accident. I searched for logsumexp and the first hit was scipy. That documentation includes this description of the keepdims parameter:

> keepdims : bool, optional
> If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the original array.

I didn't find anything useful in the trax documentation :frowning_face: and the comment above that line of code in my notebook just says Note: softmax = None, which is rather unhelpful. Glad this worked for you!
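For anyone else landing here, that keepdims behaviour is easy to see in isolation. A minimal sketch with scipy.special.logsumexp (used purely for illustration; trax.fastmath.logsumexp behaves the same way):

```python
import numpy as np
from scipy.special import logsumexp

x = np.arange(6.0).reshape(2, 3)

lse = logsumexp(x, axis=-1, keepdims=True)  # shape (2, 1) instead of (2,)

# Because the reduced axis is kept with size one, lse broadcasts against x:
log_softmax = x - lse
probs = np.exp(log_softmax)
print(probs.sum(axis=-1))  # each row sums to 1
```

Without keepdims=True the result has shape (2,), and `x - lse` either errors or broadcasts along the wrong axis, which is exactly the bug in the original attempt.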