Suppose we are working on a multi-class classification problem with 4 classes. When we did our error analysis, I found that the model makes more wrong predictions for class 3 than for the other classes. What should I do to improve performance on that class specifically?
Maybe a dumb question: is there any way I can change the neural network architecture specifically for class 3?
You can start by checking whether class 3 is under-represented in the training data. If that is the case, you could try to include more class 3 samples in the training set. In the worst case, if you cannot find new samples for class 3, take some of the existing class 3 examples and add copies of them back to the training set.
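As a rough illustration of that suggestion, here is a minimal sketch that counts the examples per class and then duplicates randomly chosen examples of the weak class until it matches the largest class. The function name `oversample_class` and the toy data are made up for this example, and duplicating up to the size of the largest class is just one possible choice.

```python
import random
from collections import Counter

def oversample_class(samples, labels, target_class, seed=0):
    """Duplicate random examples of target_class until it matches the largest class."""
    rng = random.Random(seed)
    counts = Counter(labels)
    deficit = max(counts.values()) - counts[target_class]
    pool = [i for i, y in enumerate(labels) if y == target_class]
    if not pool or deficit <= 0:
        # Nothing to duplicate, or the class is already the largest.
        return list(samples), list(labels)
    extra = [rng.choice(pool) for _ in range(deficit)]
    new_samples = list(samples) + [samples[i] for i in extra]
    new_labels = list(labels) + [labels[i] for i in extra]
    return new_samples, new_labels

# Hypothetical toy data: class 3 has far fewer examples than the others.
labels = [1] * 50 + [2] * 40 + [3] * 5 + [4] * 45
samples = [[float(i)] for i in range(len(labels))]

print("before:", sorted(Counter(labels).items()))
new_samples, new_labels = oversample_class(samples, labels, target_class=3)
print("after: ", sorted(Counter(new_labels).items()))
```

If you try something like this, apply it to the training set only, so that the validation and test sets still reflect the real class distribution.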
You are asking your question in C2 W3, so I suppose you have gone through Andrew’s videos on fighting underfitting, which seems to be the situation in your question. Shanup has shared suggestions in that direction, and personally, examining your data would also be the first thing I would do.
If I go further in that direction but focus on the architecture, let’s first take a step back and think about what a multi-class neural network is giving us. We have some hidden layers and an output layer. In the output layer, we have 4 neurons, one for each class. What they have in common is that all 4 neurons take the same set of inputs from the hidden layers. Therefore, we are optimizing those hidden layers in a way that is good for all 4 classes overall. In other words, the hidden layers should learn features that distinguish one class from another.
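To make the shared-input point concrete, here is a small forward pass in plain Python. The layer sizes, the random weights, and the helper names (`dense`, `relu`, `softmax`) are assumptions for illustration only; in the course you would build this with a framework rather than by hand.

```python
import math
import random

def dense(inputs, weights, biases):
    """One fully connected layer; weights[j] holds the weights of neuron j."""
    return [sum(w * x for w, x in zip(row, inputs)) + b
            for row, b in zip(weights, biases)]

def relu(values):
    return [max(0.0, v) for v in values]

def softmax(values):
    peak = max(values)  # subtract the max for numerical stability
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

rng = random.Random(0)
n_inputs, n_hidden, n_classes = 3, 5, 4  # assumed sizes

# Untrained random weights, only to show the data flow.
W_hidden = [[rng.uniform(-1, 1) for _ in range(n_inputs)] for _ in range(n_hidden)]
b_hidden = [0.0] * n_hidden
W_out = [[rng.uniform(-1, 1) for _ in range(n_hidden)] for _ in range(n_classes)]
b_out = [0.0] * n_classes

x = [0.5, -1.2, 2.0]
hidden = relu(dense(x, W_hidden, b_hidden))   # features shared by all classes
probs = softmax(dense(hidden, W_out, b_out))  # each of the 4 output neurons reads the same `hidden`
print(probs)
```

Each row of `W_out` belongs to one class, but every row multiplies the same `hidden` vector, so there is no part of the hidden layers that belongs to a single class.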
Now, you say class 3 is performing badly, so it could be that the hidden layers have not learnt enough features to distinguish class 3 from the other classes. In this case, I would try increasing the size of my hidden layers and see whether the network learns something additional and useful. There is no guarantee it will learn anything new if you add, let’s say, 2 or 5 neurons. You might need to add many more than that before it has a chance to learn such features, and before you see an improvement in class 3, you might see a performance drop due to overfitting. This is where you need (to gain) skills and experience to get to your desired model.
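One way to watch for that trade-off is to compare per-class recall on the training and validation sets each time you change the hidden layer size. The sketch below is only an illustration: `per_class_recall` is a made-up helper, and the labels and predictions are toy values, not results from a real model.

```python
from collections import Counter

def per_class_recall(y_true, y_pred):
    """Fraction of the examples of each true class that were predicted correctly."""
    totals = Counter(y_true)
    correct = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    return {c: correct[c] / totals[c] for c in sorted(totals)}

# Hypothetical labels and predictions for one model size.
y_true_train = [1, 1, 2, 2, 3, 3, 3, 3, 4, 4]
y_pred_train = [1, 1, 2, 2, 3, 3, 3, 1, 4, 4]
y_true_val = [1, 1, 2, 2, 3, 3, 3, 3, 4, 4]
y_pred_val = [1, 1, 2, 2, 3, 1, 2, 4, 4, 4]

train_recall = per_class_recall(y_true_train, y_pred_train)
val_recall = per_class_recall(y_true_val, y_pred_val)

# A large train/validation gap for one class hints at overfitting on that class.
gap = {c: train_recall[c] - val_recall[c] for c in train_recall}
print("train:", train_recall)
print("val:  ", val_recall)
print("gap:  ", gap)
```

If the training recall of the weak class climbs as you add neurons while its validation recall stays flat or drops, the extra capacity is probably being spent on memorizing rather than on learning useful features.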
Having said that, please don’t just focus on the model architecture. I think overlooking the data can sometimes be a fatal mistake.
Happy new year, and cheers,