Hi,
I need some help understanding the loss function used with multiclass classifiers.
If the loss function for a multiclass classifier (using softmax) does not take into account the predictions for the non-true classes of a training example, how do the values for those non-true classes get tuned to lower values?
Thanks!
Hi @Abhi08,
I'm not sure I'm understanding the question, so bear with me, and of course feel free to respond!
Each training label is a vector of n classes with a 1 only in the class represented by the example. Even though only one class is marked in the example, the error is "global", in the sense that whatever loss there is gets backpropagated and affects all weights in all layers. This in turn affects the values of the non-true classes, which are the inputs to the softmax.
So, overall, the non-true values are not targeted individually: loss + backprop increases the probability of the true class as part of the learning process, and since the softmax outputs sum to 1, the probability left for the non-true classes goes down as a consequence.
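To make the backprop point concrete, here is a small pure-Python sketch; the logits and learning rate are hypothetical. It relies on the standard result that, for softmax followed by categorical cross-entropy, the gradient of the loss with respect to logit z_j is p_j - y_j. For every non-true class y_j = 0, so that gradient is just p_j > 0 and a gradient-descent step lowers that logit, while the true class's logit goes up. In this run the cat probability rises and the combined probability of the other two classes falls.

```python
import math

def softmax(z):
    # Subtract the max logit for numerical stability; this does not change the result.
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy_grad_wrt_logits(z, y):
    # For softmax followed by categorical cross-entropy, dL/dz_j = p_j - y_j.
    p = softmax(z)
    return [pj - yj for pj, yj in zip(p, y)]

# Hypothetical logits for a 3-class example (cat, dog, nota) whose true class is cat.
z = [1.0, 0.5, 2.0]
y = [1, 0, 0]

grad = cross_entropy_grad_wrt_logits(z, y)

# One plain gradient-descent step with a hypothetical learning rate.
lr = 0.5
z_new = [zj - lr * gj for zj, gj in zip(z, grad)]

print("gradient:", [round(g, 3) for g in grad])
print("probs before:", [round(p, 3) for p in softmax(z)])
print("probs after: ", [round(p, 3) for p in softmax(z_new)])
```

The gradient entries for dog and nota are positive even though their labels are 0, which is exactly how the wrong classes get pushed down without appearing as separate terms in the loss.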
Hope that helped.
Also, I just noticed that you talk about softmax as a loss function, and that is not quite right.
Softmax is an activation function. It takes a vector of n values (one per class) and transforms it into a probability distribution (the outputs sum to 1) over the output classes.
For TF/Keras, the loss function used in these cases is categorical_crossentropy, which calculates the loss across the classes. You can read more about it here: tf.keras.losses.categorical_crossentropy | TensorFlow Core v2.4.1
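A minimal sketch of softmax in plain Python, using hypothetical logits for a cat/dog/nota example. Subtracting the maximum logit before exponentiating is a common numerical-stability trick and does not change the result.

```python
import math

def softmax(z):
    """Turn a list of n real-valued scores (logits) into a probability distribution."""
    m = max(z)  # shift by the max for numerical stability
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, -1.0, 0.5]  # hypothetical scores for cat, dog, nota
probs = softmax(logits)
print(probs, sum(probs))
```

Every output is strictly between 0 and 1 and they sum to 1, so raising one class's probability necessarily lowers the total of the others.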
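As a rough illustration of what that loss computes (this is not the Keras implementation itself, and it does not call it), here is categorical cross-entropy written out in plain Python as L = -sum_j y_j * log(p_j). With a one-hot label, every term except the true class's is multiplied by 0, so the loss reduces to -log(p_true). The eps clipping is an assumption added here so the sketch never takes log(0).

```python
import math

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """L = -sum_j y_j * log(p_j) for a one-hot (or soft) label and predicted probabilities.

    The eps clipping is an assumption of this sketch, used only to avoid log(0).
    """
    return -sum(
        t * math.log(min(max(p, eps), 1 - eps))
        for t, p in zip(y_true, y_pred)
    )

y_true = [1, 0, 0]          # cat
y_pred = [0.7, 0.2, 0.1]    # hypothetical softmax output
loss = categorical_crossentropy(y_true, y_pred)
print(loss, -math.log(0.7))  # the two numbers match
```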
Thank you for your reply.
Apologies for not being clear with the question.
I meant to ask how the non-true values are tuned to lower values.
For example, suppose we need to classify a dataset into 3 classes: cat, dog, and none of the above (nota).
So we have 3 possible labels: y = [1 0 0] (cat), [0 1 0] (dog), and [0 0 1] (nota).
Suppose the output for a cat's image is:
(case 1)
y_pred=[0.92 0.04 0.03] - This can be confidently classified as a cat
(case 2)
y_pred=[0.92 0.04 0.94] - This cannot be classified as a cat.
However, when training on cat images, the value of the 3rd element of the output vector is not part of the cost function. How do we make sure that training does not end up producing case 2?
In binary classification, this is avoided by including the non-true element in the cost function, which ensures it has a low value.
Thanks! Yes, I didn't mean softmax itself was the loss function; I meant the loss function associated with a multiclass classifier.
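Writing binary cross-entropy out may help here. For any single example only one of its two terms is active: when y = 1, the (1 - y) * log(1 - p) term is multiplied by 0, just as the non-true terms are in categorical cross-entropy. What keeps the "other" value low is that it is tied to the true one: in the binary case it is 1 - p by construction, and in the softmax case the outputs sum to 1. The sketch below (with a hypothetical p = 0.8) shows that the binary loss equals the categorical loss on the 2-class distribution [p, 1 - p].

```python
import math

def binary_crossentropy(y, p):
    # y in {0, 1}; p = predicted probability of the positive class (e.g. a sigmoid output).
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

def categorical_crossentropy(y_true, y_pred):
    # L = -sum_j y_j * log(p_j)
    return -sum(t * math.log(q) for t, q in zip(y_true, y_pred))

p = 0.8  # hypothetical predicted probability for a positive example
bce = binary_crossentropy(1, p)
# The same prediction written as a 2-class distribution [p, 1 - p] with a one-hot label.
cce = categorical_crossentropy([1, 0], [p, 1 - p])
print(bce, cce)  # both equal -log(0.8)
```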
It's a good question. The key point is that all the outputs of softmax add up to 1, right? All we really need to care about is the value that gets assigned to the correct label. How the "error" (wrong) probability gets distributed among the possible wrong answers doesn't really matter. Say the model is still pretty bad and only gives 0.4 for the correct label: we don't really care where the remaining 0.6 goes. We just know it's pretty bad that the correct answer only got 0.4. If you think about it, that's what the loss function is measuring. This also means case 2 cannot happen with softmax: if cat gets 0.92, at most 0.08 is left for the other two classes combined.
To make this a bit more concrete, suppose that we have 4 classes: cat, dog, kangaroo and "none of the above". If a given sample is labelled as a cat and the model predicts 0.4 for cat, then we don't really care whether the other 3 values are (0.6, 0.0, 0.0) or (0.2, 0.2, 0.2) or (0.1, 0.1, 0.4) or some permutation of those values. Those are all equally bad answers from our point of view.
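The 4-class example can be checked directly. This sketch computes categorical cross-entropy for the three candidate outputs above and for one hypothetical better prediction. Terms with y_j = 0 are skipped so that the 0.0 entries never hit log(0); with a one-hot label that leaves only the true class's term anyway.

```python
import math

def categorical_crossentropy(y_true, y_pred):
    # L = -sum_j y_j * log(p_j); terms with y_j == 0 contribute nothing, so skip them.
    return -sum(t * math.log(p) for t, p in zip(y_true, y_pred) if t != 0)

y_true = [1, 0, 0, 0]  # cat, dog, kangaroo, none of the above
candidates = [
    [0.4, 0.6, 0.0, 0.0],
    [0.4, 0.2, 0.2, 0.2],
    [0.4, 0.1, 0.1, 0.4],
]
losses = [categorical_crossentropy(y_true, p) for p in candidates]
print(losses)  # each entry equals -log(0.4), about 0.916

# A hypothetical better prediction: more probability on the true class means lower loss.
better = categorical_crossentropy(y_true, [0.92, 0.04, 0.02, 0.02])
print(better)
```

All three "equally bad" answers get exactly the same loss, and the only way to reduce it is to move probability onto the cat entry.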
OK, thank you for the reply!