The choice of loss function and activation function

Hey everyone,

When I finished the week 3 assignment, I wanted to look at other code examples of how people used 3D U-Net on the Decathlon Challenge, to see how I could improve the model further. One of the key points that confused me is that people tend to use different loss/activation functions.

There are 4 classes in the dataset, so we need a multiclass loss function such as the one we used in the assignment (soft dice loss), and at the same time we are predicting per-pixel labels, which can be done using the logistic (sigmoid) function. However, I saw that some people used binary cross entropy, and the softmax function in the last layer. What would that change?

My intuitive understanding is this: we are trying to predict the probability of each pixel being 0 or 1 by using the logistic function in the last layer, producing a mask. Since we have 4 different classes, there are class-dependent differences in the masks, so by using soft dice loss we take these differences into account and try to find an optimal point where our logistic function produces the correct mask for each class. In this case, we could also simply use categorical cross entropy, which likewise takes all the classes into account.

Binary cross entropy, on the other hand, would not take the class differences into account and would try to find an optimal point as if all the instances were drawn from the same class. (In practice, I think this would be a problem for multiclass segmentation, because the class with the most training instances would dominate the loss function.)

Lastly, my intuition about using softmax in the last layer is much blurrier: as long as we account for the class differences by using a multi-class loss (categorical cross entropy, soft dice loss, or other losses), our model just needs to output 0 or 1 (black or white points) to produce a mask. Softmax would perhaps be handy for producing a different output for each class, but I am not sure we need multiple outputs here.
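To make the comparison between the two multi-class losses concrete, here is a minimal NumPy sketch (toy numbers and function names are my own, not from the assignment) of one common soft dice variant and of categorical cross entropy on a tiny 4-class example:

```python
import numpy as np

eps = 1e-7
# Toy example: 6 pixels, 4 classes, one-hot ground truth.
y_true = np.eye(4)[[0, 0, 0, 1, 2, 3]]          # shape (6, 4)
# Predicted class probabilities (each row sums to 1).
y_pred = np.array([[0.7, 0.1, 0.1, 0.1],
                   [0.6, 0.2, 0.1, 0.1],
                   [0.5, 0.2, 0.2, 0.1],
                   [0.2, 0.6, 0.1, 0.1],
                   [0.1, 0.1, 0.7, 0.1],
                   [0.1, 0.1, 0.1, 0.7]])

def soft_dice_loss(y_true, y_pred):
    # Per-class soft dice (squared-denominator variant), averaged over classes.
    num = 2.0 * (y_true * y_pred).sum(axis=0) + eps
    den = (y_true ** 2).sum(axis=0) + (y_pred ** 2).sum(axis=0) + eps
    return float(1.0 - (num / den).mean())

def categorical_cross_entropy(y_true, y_pred):
    # Average negative log-probability of the true class, per pixel.
    return float(-(y_true * np.log(y_pred + eps)).sum(axis=1).mean())

print(soft_dice_loss(y_true, y_pred))
print(categorical_cross_entropy(y_true, y_pred))
```

Both losses look at all 4 class channels, but dice does it per class (each class's overlap counts equally), while cross entropy averages per pixel.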

Could anyone kindly correct me if I am missing something in my intuition, or provide further information? I think this concept can be complex for beginners, and many people probably need to ask this question at some point on their journey to learn segmentation.

Kind regards.

Hi @Neurojedi,

Your analysis of the problem is interesting and comprehensive. Let me try to add a bit of value on top of it.

1- However, I saw that some people used binary cross entropy and softmax function in the last layer. What would that change?

  • Using binary cross entropy for a multi-class problem is not a good approach, because binary cross entropy only distinguishes 2 classes. You could, however, split the problem into one-vs-rest pairs, apply binary cross entropy to each, and average the results. This is still a weak approach, though, because it doesn't capture the relationships among all the classes.
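A small NumPy sketch of that one-vs-rest averaging (toy numbers of my own) also shows why the channels don't compete: each binary term looks only at its own channel, so every class can be confidently "on" at once, which softmax over shared logits would never allow:

```python
import numpy as np

eps = 1e-7
# One pixel, 4 classes, true class 0, with independent sigmoid outputs.
target = np.array([1.0, 0.0, 0.0, 0.0])
sigmoid_out = np.array([0.9, 0.8, 0.85, 0.9])   # all channels nearly "on"

# One-vs-rest binary cross entropy, computed per channel, then averaged.
# Each term ignores every other channel entirely.
bce = -(target * np.log(sigmoid_out + eps)
        + (1 - target) * np.log(1 - sigmoid_out + eps))
print(bce)         # four independent per-class losses
print(bce.mean())  # the averaged "multi-label" loss
```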

2- In this case, we can simply also use categorical cross entropy which would also take into account all the classes.

  • Yes, categorical cross entropy is appropriate to use here, because it takes all the classes into account in a single loss. Note, though, that plain categorical cross entropy does not by itself correct for class imbalance.

3- Softmax would be perhaps handy for producing different outputs for each class, however, I am not sure we need multi-outputs here.

  • Using softmax here is not a must, but it is helpful. It turns the final outputs into a probability distribution over the classes that sums to 1, so each pixel gets exactly one most-likely class. You could say it normalizes the output for a better outcome.
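A tiny sketch of that normalization (hypothetical raw scores for a single pixel):

```python
import numpy as np

def softmax(logits):
    # Shift by the max for numerical stability, then normalize.
    z = logits - logits.max()
    e = np.exp(z)
    return e / e.sum()

# Hypothetical raw per-class scores for one pixel over the 4 classes.
logits = np.array([2.0, 0.5, -1.0, 0.1])
probs = softmax(logits)
print(probs)        # all positive, largest where the logit was largest
print(probs.sum())  # sums to 1 (up to floating-point rounding)
```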

Hope this answers your queries.

Best regards

Thank you for your answer!! I hadn't thought about softmax being useful for its normalization effect, so you do have a point there.

Also, about the loss functions: we can add that one of the main reasons we choose dice loss over cross entropy is that there is already a big class imbalance in the images (background versus all the other classes). Categorical cross entropy may not be a good choice in this case either, but I think using binary cross entropy would make things even worse and let the background dominate the loss.
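As a rough illustration of that imbalance point, here is a toy 1-D example of my own, with a "lazy" model that predicts background almost everywhere:

```python
import numpy as np

eps = 1e-7
# Toy 1-D "image": 97 background pixels, 3 tumour pixels (one-hot; 2 classes
# here for brevity, but the same argument holds for 4).
y_true = np.zeros((100, 2))
y_true[:97, 0] = 1.0   # background
y_true[97:, 1] = 1.0   # tumour
# A lazy model that predicts "background" with 99% confidence everywhere.
y_pred = np.tile([0.99, 0.01], (100, 1))

# Pixel-averaged categorical cross entropy: the 97 background pixels
# dominate, so the lazy model still looks reasonably good.
cce = float(-(y_true * np.log(y_pred + eps)).sum(axis=1).mean())

# Class-averaged soft dice loss: the tumour channel has almost no overlap,
# which drags the loss up no matter how many background pixels there are.
num = 2.0 * (y_true * y_pred).sum(axis=0) + eps
den = (y_true ** 2).sum(axis=0) + (y_pred ** 2).sum(axis=0) + eps
dice_loss = float(1.0 - (num / den).mean())

print(cce)        # roughly 0.15: low, despite missing the tumour
print(dice_loss)  # roughly 0.5: the missed class is exposed
```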