Note the last part of Mubsi’s comment: using the wrong function there (mean instead of sum) is not the only issue with your code. Have another careful look at the dimensions of the inputs. Also note that you’ll need the from_logits
parameter there, as discussed on this thread.
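To make the three points above concrete, here is a minimal pure-Python sketch rather than the assignment's actual code. It assumes the data arrive as (num_classes, num_examples) and have to be transposed so that each row is one example; that layout, the toy numbers, and the helper names (`softmax`, `cross_entropy`, `transpose`) are all hypothetical. The `from_logits` flag in the sketch only mimics the idea behind the Keras parameter: when it is true, softmax is applied inside the loss, so raw logits can be passed in directly.

```python
import math

def softmax(logits):
    # Subtract the max for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(y_true, y_pred, from_logits=False):
    # Categorical cross-entropy for ONE example with a one-hot label.
    # With from_logits=True, y_pred holds raw scores and softmax is applied here.
    probs = softmax(y_pred) if from_logits else y_pred
    return -sum(t * math.log(p) for t, p in zip(y_true, probs))

def transpose(matrix):
    return [list(row) for row in zip(*matrix)]

# Hypothetical toy data laid out as (num_classes=3, num_examples=2).
logits_cf = [[2.0, 0.5],
             [1.0, 2.5],
             [0.1, 0.3]]
labels_cf = [[1, 0],
             [0, 1],
             [0, 0]]

# Reorient to (num_examples, num_classes) so each row is one example.
logits = transpose(logits_cf)
labels = transpose(labels_cf)

per_example = [cross_entropy(t, z, from_logits=True)
               for t, z in zip(labels, logits)]

total_loss = sum(per_example)                # sum over the examples
mean_loss = total_loss / len(per_example)    # what a mean would give instead

print(per_example, total_loss, mean_loss)
```

The sum and the mean differ by a factor of the number of examples, so picking the wrong reduction scales the result; passing logits without the flag, or with the axes swapped, changes the per-example values themselves.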