一、混淆矩阵介绍
混淆矩阵的每一列代表了预测类别,每一列的总数表示预测为该类别的数据的数目;每一行代表了数据的真实归属类别,每一行的数据总数表示该类别的数据实例的数目。每一列中的数值表示真实数据被预测为该类的数目。
以下图为例,第一行的数值总和为2+0+0=2,表示ant类别共有2个样本,其中,有2个样本被预测为ant类别,0个样本被预测为bird类别,0个样本被预测为cat类别,即ant类别的图像全预测正确了。其他行同理。
上面这个混淆矩阵并没有归一化,对其进行归一化后的结果如下。以第三行为例进行解释:0.33表示有33%的cat图像被预测为了ant,0%的cat图像被预测为bird,也即没有cat图像被预测为bird,67%的cat图像被预测为cat。
上面说这么多,主要是想让大家直观地理解混淆矩阵到底是怎么一回事。总是,混淆矩阵可以让我们清晰地看到网络的错分情况。
二、绘制混淆矩阵
在下面这个代码中,主要用到的两个函数分别是:库函数confusion_matrix 和 自定义函数plot_confusion_matrix。其中,库函数只需要安装【scikit】包,具体安装命令如下。
自定义函数plot_confusion_matrix大家直接粘贴下面的代码就行。
实际应用时,大家只需要 改一下 下述代码中的 真实标签y_true 和 预测标签y_pred ,及 标签名称label_name 即可。需要注意的是 label_name的顺序是按0,1,2的顺序排的 ,即因为ant的数字标签为0,因此它在第一位,bird的数字标签为1,因此它在第二位,cat的数字标签为2,因此它在第三位。以此类推。
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
# 绘制混淆矩阵的函数
def plot_confusion_matrix(cm, labels_name, title="Confusion Matrix", is_norm=True, colorbar=True, cmap=plt.cm.Blues):
if is_norm==True:
cm = np.around(cm.astype('float') / cm.sum(axis=1)[:, np.newaxis],2) # 横轴归一化并保留2位小数
plt.imshow(cm, interpolation='nearest', cmap=cmap) # 在特定的窗口上显示图像
for i in range(len(cm)):
for j in range(len(cm)):
plt.annotate(cm[j, i], xy=(i, j), horizontalalignment='center', verticalalignment='center') # 默认所有值均为黑色
# plt.annotate(cm[j, i], xy=(i, j), horizontalalignment='center', color="white" if i==j else "black", verticalalignment='center') # 将对角线值设为白色
if colorbar:
plt.colorbar() # 创建颜色条
num_local = np.array(range(len(labels_name)))
plt.xticks(num_local, labels_name) # 将标签印在x轴坐标上
plt.yticks(num_local, labels_name) # 将标签印在y轴坐标上
plt.title(title) # 图像标题
plt.ylabel('True label')
plt.xlabel('Predicted label')
if is_norm==True:
plt.savefig(r'.\cm_norm_' + '.png', format='png')
else:
plt.savefig(r'.\cm_' + '.png', format='png')
plt.show() # plt.show()在plt.savefig()之后
plt.close()
y_true = [2, 0, 2, 2, 0, 1] # 真实标签
y_pred = [0, 0, 2, 2, 0, 2] # 预测标签
label_name = ['ant', 'bird', 'cat']
cm = confusion_matrix(y_true, y_pred) # 调用库函数confusion_matrix
plot_confusion_matrix(cm, label_name, "Confusion Matrix", is_norm=False) # 调用上面编写的自定义函数
plot_confusion_matrix(cm, label_name, "Confusion Matrix", is_norm=True) # 经过归一化的混淆矩阵
三、在深度学习代码中添加绘制混淆矩阵模块
在上述代码中,真实标签和预测标签都给定好了,那么如何在深度学习中根据图像真实标签和预测标签,从而对每个Epoch的错分情况进行绘制呢?具体做法如下,只需要在模型主函数的测试模块中,加入下述几行代码,即可。(注:笔者是做表情识别方向的,因此类别数总共有7种。)
文章来源:https://www.toymoban.com/news/detail-423585.html
至此,本博文就结束了。如果本文对你有所帮助的话,欢迎订阅本专栏。永远相信美好的事情即将发生。文章来源地址https://www.toymoban.com/news/detail-423585.html
到了这里,关于【论文必用】Python绘制混淆矩阵的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!