Hey everyone,
After finishing the week 3 assignment, I looked at other code examples to see how people applied 3D U-Net to the Decathlon challenge and how I could improve the model further. One thing that confused me is that people use different loss and activation functions.
There are 4 classes in the dataset, so we need a multiclass loss function such as the one we used in the assignment (soft dice loss). At the same time, we are predicting a probability for each pixel (voxel, in 3D), which can be done with the logistic (sigmoid) function in the last layer. However, I saw that some people used binary cross-entropy as the loss and a softmax function in the last layer. What would that change?
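To see what swapping the logistic function for softmax changes at the output layer, here is a minimal pure-Python sketch for a single voxel. The three channel names and the logit values are made up for illustration. In a real 3D U-Net the same computation would be applied at every voxel, typically with a framework's built-in activation rather than hand-written functions.

```python
import math


def sigmoid(z):
    # Logistic function: squashes each logit independently into (0, 1).
    return 1.0 / (1.0 + math.exp(-z))


def softmax(logits):
    # Subtract the largest logit for numerical stability; outputs sum to 1.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical raw network outputs (logits) for ONE voxel, one value per
# output channel. The channel meanings are illustrative assumptions.
channels = ["edema", "non-enhancing tumor", "enhancing tumor"]
voxel_logits = [2.0, 1.5, -1.0]

sig = [sigmoid(z) for z in voxel_logits]
soft = softmax(voxel_logits)

for name, s, p in zip(channels, sig, soft):
    print(f"{name:>20}: sigmoid={s:.3f}  softmax={p:.3f}")
print(f"{'sum':>20}: sigmoid={sum(sig):.3f}  softmax={sum(soft):.3f}")
```

With these made-up logits, the sigmoid outputs sum to well above 1, because each channel answers its own yes/no question ("is this voxel in region c?"), whereas the softmax outputs always sum to 1 and compete with each other, so each voxel is effectively assigned to exactly one class. That is why softmax is usually combined with an explicit background channel and a one-hot target (paired with categorical cross-entropy), while sigmoid outputs are usually paired with per-channel binary cross-entropy or a per-channel soft dice loss.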
My intuitive understanding is this: by using the logistic function in the last layer, we predict the probability of each pixel being 0 or 1 and produce a mask. Since we have 4 different classes, the masks differ from class to class, so by using soft dice loss we take these differences into account and look for an optimum where the logistic outputs produce the correct mask for each class. In this case, we could also simply use categorical cross-entropy, which would also take all the classes into account.
Binary cross-entropy, on the other hand, would not take the differences between the classes into account; it would look for an optimum as if all the instances were drawn from the same class. In practice, I think this would be a problem for multiclass segmentation, because the class with the most training instances (voxels) would dominate the loss function.
Lastly, my intuition about using softmax in the last layer is much blurrier. As long as we account for the differences between the classes with a multiclass loss function (categorical cross-entropy, soft dice loss, or others), the model just needs to output 0 or 1 (black or white points) to produce a mask. Softmax might be handy for producing a separate output for each class, but I am not sure we need multiple outputs here.
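To make the point about the largest class dominating binary cross-entropy concrete, here is a toy pure-Python comparison on two hypothetical class channels of 100 voxels each: a large region (40 voxels) that the model finds, and a small region (2 voxels) that it either finds or misses completely. The soft dice loss below uses one common formulation (squared terms in the denominator plus a small epsilon), averaged over classes; the assignment's exact version may differ. The `channel` helper, the region sizes, and the prediction values are all made up for illustration, and "pooled BCE" here means binary cross-entropy averaged over every voxel of every channel at once.

```python
import math

PROB_EPS = 1e-7  # clip probabilities so log() never sees exactly 0 or 1


def soft_dice_loss(pred, truth, eps=1e-5):
    # pred, truth: one flat list of per-voxel values per class channel.
    # Dice is computed for each class separately and then averaged, so every
    # class carries the same weight no matter how many voxels it covers.
    dice_per_class = []
    for p_c, g_c in zip(pred, truth):
        numerator = 2.0 * sum(p * g for p, g in zip(p_c, g_c)) + eps
        denominator = sum(p * p for p in p_c) + sum(g * g for g in g_c) + eps
        dice_per_class.append(numerator / denominator)
    return 1.0 - sum(dice_per_class) / len(dice_per_class)


def pooled_bce(pred, truth):
    # Binary cross-entropy averaged over every voxel of every channel at once,
    # so the many easy background voxels dominate the mean.
    total, count = 0.0, 0
    for p_c, g_c in zip(pred, truth):
        for p, g in zip(p_c, g_c):
            p = min(max(p, PROB_EPS), 1.0 - PROB_EPS)
            total += -(g * math.log(p) + (1 - g) * math.log(1 - p))
            count += 1
    return total / count


def channel(n_voxels, positives, p_pos, p_neg):
    # Hypothetical helper: ground-truth mask with `positives` foreground voxels
    # and a prediction of p_pos on them and p_neg everywhere else.
    truth = [1.0] * positives + [0.0] * (n_voxels - positives)
    pred = [p_pos] * positives + [p_neg] * (n_voxels - positives)
    return truth, pred


N = 100
big_truth, big_pred = channel(N, 40, 0.9, 0.1)    # large region, found
small_truth, small_hit = channel(N, 2, 0.9, 0.1)  # small region, found
_, small_miss = channel(N, 2, 0.1, 0.1)           # small region, missed

truth = [big_truth, small_truth]
found = [big_pred, small_hit]
missed = [big_pred, small_miss]

dice_found = soft_dice_loss(found, truth)
dice_missed = soft_dice_loss(missed, truth)
bce_found = pooled_bce(found, truth)
bce_missed = pooled_bce(missed, truth)

print(f"soft dice loss: found={dice_found:.3f}  missed={dice_missed:.3f}")
print(f"pooled BCE:     found={bce_found:.3f}  missed={bce_missed:.3f}")
```

In this toy setup, missing the small region makes the soft dice loss grow almost fourfold (about 0.12 to 0.44), while the pooled BCE rises by only about 20% (about 0.11 to 0.13): the two missed voxels are averaged together with 198 easy ones, whereas Dice normalizes each class by its own region size. It is only an illustration, not a statement about how either loss behaves on real data.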
Could anyone kindly correct me if I am missing something in my intuition, or provide further information? I think this concept can be confusing for beginners, and many people may ask this question at some point on their journey to learn segmentation.
Kind regards.