C1W1 Assignment confusion matrix code simplification (using ConfusionMatrixDisplay)

In my opinion, the plot_confusion_matrix function provided in the C1W1_Assignment notebook is overcomplicated.
Also, the title parameter is never used in the function body: the plot title is hard-coded instead.

def plot_confusion_matrix(y_true, y_pred, title='', labels=[0,1]):
    cm = confusion_matrix(y_true, y_pred)
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(cm)
    plt.title('Confusion matrix of the classifier')
    fig.colorbar(cax)
    ax.set_xticklabels([''] + labels)
    ax.set_yticklabels([''] + labels)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    fmt = 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
          plt.text(j, i, format(cm[i, j], fmt),
                  horizontalalignment="center",
                  color="black" if cm[i, j] > thresh else "white")
    plt.show()

plot_confusion_matrix(test_Y[1], np.round(type_pred), title='Wine Type', labels = [0, 1])

It can be reduced to the following using ConfusionMatrixDisplay:

# Below I only list the missing imports
from sklearn.metrics import ConfusionMatrixDisplay

cm = confusion_matrix(test_Y[1], np.round(type_pred), labels=[0, 1])
disp = ConfusionMatrixDisplay(confusion_matrix=cm,
                              display_labels=[0, 1])
disp.plot(values_format='d');

which also results in a slightly better picture

Please note that this function is already given to the learners, so I’m not revealing any answers.
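
If the notebook needs to keep the plot_confusion_matrix name and signature, a wrapper along these lines might work. This is only a sketch, not the course's official code: the y_true and y_prob arrays are made-up stand-ins for test_Y[1] and type_pred, and returning the display object is my own addition so the result can be inspected.

```python
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix


def plot_confusion_matrix(y_true, y_pred, title="", labels=(0, 1)):
    """Plot a confusion matrix with ConfusionMatrixDisplay and return the display."""
    labels = list(labels)  # tuple default avoids a mutable default argument
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    fig, ax = plt.subplots()
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
    disp.plot(values_format="d", ax=ax)
    ax.set_title(title)  # unlike the original, the title argument is used
    return disp


# Made-up data standing in for test_Y[1] and type_pred (assumption)
y_true = np.array([0, 0, 1, 1, 1, 0])
y_prob = np.array([0.1, 0.6, 0.8, 0.4, 0.9, 0.2])
disp = plot_confusion_matrix(y_true, np.round(y_prob).astype(int), title="Wine Type")
# In a script (outside a notebook), call plt.show() to display the figure.
```

The call site stays the same as in the notebook, and the computed matrix is available afterwards as disp.confusion_matrix.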

Nice!
I’ll pass the suggestion on to the developers.
