C3W2, assignment - Exercise 4 - masked_loss()

In the above image we can see that, for the given true_labels and predicted_logits, the cross-entropy loss is 1.0508.
Above is the formula for cross-entropy loss. If we apply it to the given “true_labels” and “predicted_logits” we get:
(-1 / 4) * [ln(0.1) + ln(0.7) + ln(0.4) + ln(0.4)]
(-1 / 4) * [-2.302 -0.356 -0.916 -0.916]
= 1.1225
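For reference, this arithmetic can be checked with a short numpy sketch (keeping only the probabilities of the true classes, as in the formula above):

```python
import numpy as np

# Probabilities assigned to the true class of each of the 4 examples,
# as picked out in the hand computation above.
picked_probs = np.array([0.1, 0.7, 0.4, 0.4])

# Cross entropy over already-softmaxed values: mean of -log(p_true).
loss = -np.mean(np.log(picked_probs))
print(loss)  # ≈ 1.12296
```

With unrounded logarithms the value is ≈ 1.1230 rather than 1.1225; the small gap comes from truncating the logs to three decimals above.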

But the code expects the value to be 1.0508.

I even checked the TensorFlow documentation for tf.keras.losses.SparseCategoricalCrossentropy(), and the formula I mentioned is the one they use, because for the image below the loss is:
(-1/2) * [ln(0.95) + ln(0.1)]
(-1/2) * [-0.0513 - 2.3026]
= 1.177, which matches the output in the image below.
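Again, a quick numpy check of that arithmetic (using the two true-class probabilities from the documentation example):

```python
import numpy as np

# Probabilities of the true classes in the documentation example.
picked_probs = np.array([0.95, 0.1])

# Mean of -log(p_true) over the 2 examples.
loss = -np.mean(np.log(picked_probs))
print(loss)  # ≈ 1.17694
```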

Can you tell me the reason for this discrepancy?

This discrepancy is simply due to two different values of the from_logits argument of the tf.keras.losses.SparseCategoricalCrossentropy function.

I assume that in the assignment you used from_logits=True, since you got the expected output correctly.

This is indeed required in the assignment, as the instructions state:

But the default in tf.keras.losses.SparseCategoricalCrossentropy() is from_logits=False:

So if you use the default arguments with the values provided in the assignment you get:

This is close enough to the 1.1225 you calculated above; I assume the difference in the fourth decimal place is due to rounding of the intermediate logarithms.

Similarly for the other set of values you provided:

Regarding the difference of from_logits=False & from_logits=True:

  • from_logits=False is for use with probabilities, i.e. values that have already gone through a softmax activation function
  • from_logits=True is for use with raw logits, i.e. values before the softmax activation is applied
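
To make the difference concrete, here is a numpy sketch that mimics what tf.keras.losses.SparseCategoricalCrossentropy computes in each mode, run on the assignment's predicted_logits. The true_labels below are my assumption (chosen so that the picked probabilities are the 0.1, 0.7, 0.4, 0.4 from the hand computation above), not necessarily the assignment's actual labels:

```python
import numpy as np

predicted_logits = np.array([[0.1, 0.6, 0.3],
                             [0.2, 0.7, 0.1],
                             [0.1, 0.5, 0.4],
                             [0.4, 0.4, 0.2]])
# Assumed labels, consistent with the picked probabilities 0.1, 0.7, 0.4, 0.4.
true_labels = np.array([0, 1, 2, 0])

def sparse_ce(y_true, y_pred, from_logits):
    """Mean sparse categorical cross entropy, mimicking the Keras loss."""
    if from_logits:
        # Raw scores: apply a (numerically stable) softmax first.
        e = np.exp(y_pred - y_pred.max(axis=1, keepdims=True))
        y_pred = e / e.sum(axis=1, keepdims=True)
    # Pick the predicted probability of the true class for each example.
    picked = y_pred[np.arange(len(y_true)), y_true]
    return -np.mean(np.log(picked))

print(sparse_ce(true_labels, predicted_logits, from_logits=True))
# ≈ 1.05086: the 1.0508 the assignment expects
print(sparse_ce(true_labels, predicted_logits, from_logits=False))
# ≈ 1.12296: the value of the hand computation above
```

Note that each row of predicted_logits already sums to 1, so from_logits=False simply reads the probabilities off, while from_logits=True applies a softmax to them a second time; that second softmax is exactly where the 1.0508 comes from.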

Given the above, I would say that the use of predicted_logits = [[0.1,0.6,0.3], [0.2,0.7,0.1], [0.1,0.5,0.4], [0.4,0.4,0.2]] in the testing snippet of the assignment could be confusing with regard to the use of from_logits=True: the values in predicted_logits are clearly probabilities (they are all positive and each row sums to 1), yet the masked_loss function itself is written for use with raw logits, as the instructions state.