Hi @mats!
Are you passing the axis argument to tf.math.argmax? If you do not, axis defaults to 0, so the argmax is taken along the first dimension (down each column, across the rows) instead of along the last dimension (across each row's class scores), and that is not what you want.
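Here is a small sketch of the difference. The logits values are made up for illustration, assuming a 2-D tensor with one row per example and one column per class:

```python
import tensorflow as tf

# Hypothetical model output: 2 examples (rows) x 3 classes (columns)
logits = tf.constant([[0.1, 0.7, 0.2],
                      [0.8, 0.1, 0.1]])

# No axis given: TensorFlow reduces over axis 0, giving one index per column
wrong = tf.math.argmax(logits)            # shape (3,)

# axis=-1: one predicted class per row (per example), which is what we want
preds = tf.math.argmax(logits, axis=-1)   # shape (2,)

print(wrong.numpy())  # [1 0 0]
print(preds.numpy())  # [1 0]
```

Note that the wrong version even has a different shape, so checking the shape of your predictions is a quick way to spot this.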
Furthermore, I would just like to point out that the mask should mark the values that are different from -1, not the values equal to -1; otherwise, after masking you will be discarding the real (non-padded) values and keeping only the padding.
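A minimal sketch of why the direction of the mask matters, assuming the labels are padded with -1 and you compare them with the predicted class indices. The masked_accuracy helper and the tensors below are hypothetical, just for illustration:

```python
import tensorflow as tf

# Hypothetical true labels, padded with -1 to a common length
y_true = tf.constant([[2, 0, -1, -1],
                      [1, 1, 0, -1]])
# Hypothetical predicted class indices (e.g. from argmax with axis=-1)
y_pred = tf.constant([[2, 1, 0, 0],
                      [1, 1, 0, 2]])

def masked_accuracy(y_true, y_pred, mask):
    # Hypothetical helper: accuracy computed only over positions where mask is True
    mask = tf.cast(mask, tf.float32)
    correct = tf.cast(tf.equal(y_true, y_pred), tf.float32) * mask
    return tf.reduce_sum(correct) / tf.reduce_sum(mask)

# Correct: keep the positions that are NOT padding (4 correct out of 5 real tokens)
good = masked_accuracy(y_true, y_pred, tf.not_equal(y_true, -1))
# Wrong: keeps only the padding positions, so the real tokens are thrown away
bad = masked_accuracy(y_true, y_pred, tf.equal(y_true, -1))

print(good.numpy())  # 0.8
print(bad.numpy())   # 0.0
```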
Please tell me if this helps.
Cheers,
Lucas