In my opinion, the `plot_confusion_matrix` method provided in the C1W1_Assignment notebook is overcomplicated. Also, its `title` parameter is never used in the method body; the plot title is hard-coded as 'Confusion matrix of the classifier'.
```python
# Imports the notebook already provides
import itertools

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix


def plot_confusion_matrix(y_true, y_pred, title='', labels=[0,1]):
    cm = confusion_matrix(y_true, y_pred)
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(cm)
    plt.title('Confusion matrix of the classifier')  # the `title` argument is ignored
    fig.colorbar(cax)
    ax.set_xticklabels([''] + labels)
    ax.set_yticklabels([''] + labels)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    fmt = 'd'
    thresh = cm.max() / 2.
    # Write each count into its cell, switching text colour on bright cells
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="black" if cm[i, j] > thresh else "white")
    plt.show()


plot_confusion_matrix(test_Y[1], np.round(type_pred), title='Wine Type', labels = [0, 1])
```
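If you would rather keep a hand-rolled version, here is a minimal sketch of how it could be tidied. This is my own assumption, not the notebook's code: it actually uses `title`, passes `labels` through to `confusion_matrix` so the matrix order is fixed, sets tick positions explicitly instead of padding the labels with `''`, and returns the `Axes` so the caller decides when to call `plt.show()`.

```python
import itertools

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix


def plot_confusion_matrix(y_true, y_pred, title='', labels=(0, 1)):
    """Sketch of a hand-rolled confusion-matrix plot that honours `title`."""
    labels = list(labels)
    # Passing `labels` fixes the row/column order (rows = true, columns = predicted)
    cm = confusion_matrix(y_true, y_pred, labels=labels)

    fig, ax = plt.subplots()
    cax = ax.matshow(cm)
    fig.colorbar(cax)
    # Fall back to the notebook's hard-coded title only when none is given
    ax.set_title(title or 'Confusion matrix of the classifier')

    # Explicit tick positions, one per label, instead of a leading '' label
    ax.set_xticks(range(len(labels)))
    ax.set_yticks(range(len(labels)))
    ax.set_xticklabels([str(label) for label in labels])
    ax.set_yticklabels([str(label) for label in labels])
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')

    # Annotate every cell with its count; dark text on bright cells
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        ax.text(j, i, format(cm[i, j], 'd'),
                horizontalalignment='center',
                color='black' if cm[i, j] > thresh else 'white')
    return ax
```

In the notebook it would be called as `plot_confusion_matrix(test_Y[1], np.round(type_pred), title='Wine Type', labels=[0, 1])` followed by `plt.show()`.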
The original method can be reduced to the following using `ConfusionMatrixDisplay`:
```python
# Only the import missing from the notebook is listed here
from sklearn.metrics import ConfusionMatrixDisplay

cm = confusion_matrix(test_Y[1], np.round(type_pred), labels=[0, 1])
disp = ConfusionMatrixDisplay(confusion_matrix=cm,
                              display_labels=[0, 1])
disp.plot(values_format='d');  # trailing ';' hides the returned object's repr in Jupyter
```
which also produces a slightly better-looking plot.

Please note that this function is already provided to learners, so I'm not revealing any answers.