How to Improve Confusion Matrix Visualization in Python for Large-Scale Classification?

How can I plot a confusion matrix in Python more effectively?

I am using scikit-learn for text classification, classifying 22,000 documents into 100 classes. I use scikit-learn’s confusion_matrix function to compute the confusion matrix:

from sklearn.linear_model import LogisticRegression
from sklearn import metrics
import matplotlib.pyplot as plt

model1 = LogisticRegression()
model1 = model1.fit(matrix, labels)
pred = model1.predict(test_matrix)
cm = metrics.confusion_matrix(test_labels, pred)
print(cm)
plt.imshow(cm, cmap='binary')
plt.show()

This generates a confusion matrix like the following:

[[3962  325    0 ...,    0    0    0]
 [ 250 2765    0 ...,    0    0    0]
 [   2    8   17 ...,    0    0    0]
 ..., 
 [   1    6    0 ...,    5    0    0]
 [   1    1    0 ...,    0    0    0]
 [   9    0    0 ...,    0    0    9]]

The problem is that the plot is neither clear nor legible. Is there a better way to plot a confusion matrix in Python so that it stays readable with this many classes?

In my experience, one of the most common issues when plotting confusion matrices is that Matplotlib’s default imshow() output makes it hard to tell values apart, especially when there are many classes. A small improvement is to add a color bar and use a larger figure, which makes the plot clearer. Here’s how you can do that:

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

def plot_confusion_matrix(cm):
    plt.figure(figsize=(10, 8))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title('Confusion Matrix')
    plt.colorbar()  # Adds a color bar for better clarity
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.show()

# Compute confusion matrix
cm = confusion_matrix(test_labels, pred)
plot_confusion_matrix(cm)

This approach uses imshow with a blue colormap, which makes differences in magnitude easy to see. The color bar maps color intensity back to counts, so you can read approximate values straight off the plot. It’s a simple tweak, but it makes the confusion matrix noticeably more readable.
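Because the class counts in the question are very imbalanced (3962 in one diagonal cell, 17 in another), a linear color scale on raw counts tends to wash out the smaller classes. Below is a minimal sketch, assuming scikit-learn 0.22 or later (where confusion_matrix accepts normalize='true'), that plots row-normalized rates instead. The helper name plot_normalized_confusion_matrix and the toy labels are hypothetical. The toy example also shows the orientation: rows are true labels, columns are predicted labels.

```python
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix


def plot_normalized_confusion_matrix(y_true, y_pred):
    """Hypothetical helper: plot a row-normalized confusion matrix, return (cm, fig)."""
    # normalize='true' divides each row by its total, so cell (i, j) is the
    # fraction of samples with true class i that were predicted as class j.
    cm = confusion_matrix(y_true, y_pred, normalize='true')
    fig, ax = plt.subplots(figsize=(10, 8))
    # Fixing vmin/vmax to [0, 1] keeps the color scale comparable across runs.
    im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues, vmin=0, vmax=1)
    fig.colorbar(im, ax=ax)
    ax.set_title('Normalized Confusion Matrix')
    ax.set_xlabel('Predicted Label')
    ax.set_ylabel('True Label')
    return cm, fig


# Toy labels (hypothetical): 4 samples of class 0, 2 of class 1, 2 of class 2.
y_true = [0, 0, 0, 0, 1, 1, 2, 2]
y_pred = [0, 0, 0, 1, 1, 1, 2, 0]
cm, fig = plot_normalized_confusion_matrix(y_true, y_pred)
print(cm)
```

Call plt.show() afterwards to display the figure. Each row sums to 1, so a pale cell in a small class means the same thing as a pale cell in a large one.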

I’ve been working with data visualizations for a while, and honestly, one of the most elegant solutions I’ve found is seaborn’s heatmap. It looks more polished than a bare imshow() and supports direct cell annotation, which really helps when you need to present your results clearly. It also gives you more flexibility with color schemes and fine-tuning.

import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

# Compute confusion matrix
cm = confusion_matrix(test_labels, pred)

# Plot using seaborn heatmap
plt.figure(figsize=(12, 10))
sns.heatmap(cm, annot=True, fmt='d', cmap='YlGnBu')  # Adds annotation in the cells
plt.title('Confusion Matrix')
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.show()

What’s great about this is that it automatically annotates each cell with its count (fmt='d' formats the counts as integers), and you can easily swap the colormap to suit your needs. Keep in mind that annotations only stay readable for a modest number of classes: with 100 classes there are 10,000 cells, and the numbers become too small to read at this figure size. For a matrix that large, you will likely want annot=False and rely on the color scale instead.
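Another option for 100 classes (not part of the answer above, just a sketch) is to skip annotating all 10,000 cells and instead list the largest off-diagonal counts, i.e. the class pairs the model confuses most. The helper top_confusions is hypothetical and uses only NumPy; the toy matrix reuses the visible top-left 3x3 corner of the question’s output as if it were a standalone 3-class matrix.

```python
import numpy as np


def top_confusions(cm, k=10, class_names=None):
    """Hypothetical helper: return the k largest off-diagonal cells
    as (true_class, predicted_class, count) tuples, largest first."""
    cm = np.asarray(cm)
    off_diag = cm.copy()
    np.fill_diagonal(off_diag, 0)  # ignore correct predictions
    # axis=None sorts the flattened matrix; reverse for descending order
    flat_idx = np.argsort(off_diag, axis=None)[::-1][:k]
    rows, cols = np.unravel_index(flat_idx, off_diag.shape)
    names = class_names if class_names is not None else list(range(cm.shape[0]))
    # Drop empty cells so small matrices don't return zero-count pairs
    return [(names[r], names[c], int(off_diag[r, c]))
            for r, c in zip(rows, cols) if off_diag[r, c] > 0]


# Top-left corner of the matrix printed in the question, used as a toy example
corner = np.array([[3962, 325, 0],
                   [250, 2765, 0],
                   [2, 8, 17]])
pairs = top_confusions(corner, k=3)
print(pairs)  # [(0, 1, 325), (1, 0, 250), (2, 1, 8)]
```

Once you know which classes are most often confused, you can plot or annotate just that sub-matrix, where cell labels stay legible.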

I’ve been using scikit-learn for quite a while now, and I’ve found that ConfusionMatrixDisplay is a great built-in way to plot a confusion matrix in Python without much customization. It gives you a standardized plot that looks professional and is easy to interpret.

import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix

# Compute confusion matrix
cm = confusion_matrix(test_labels, pred)

# Plot using ConfusionMatrixDisplay; plot() has no figsize argument,
# so create the Axes at the desired size and pass it in via ax
fig, ax = plt.subplots(figsize=(10, 8))
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot(cmap='viridis', xticks_rotation='vertical', ax=ax)
ax.set_title('Confusion Matrix')
plt.show()

This method automatically handles the tick labels and axis titles, and you can fine-tune the colormap and, through the ax argument, the figure size. It’s especially helpful when you want a quick, standardized way to plot a large confusion matrix, and the automatic axis labeling makes it less error-prone. With 100 classes, passing include_values=False to plot() avoids drawing thousands of overlapping numbers.