基于MMdetection框架的目标检测研究-6.混淆矩阵绘制

这篇具有很好参考价值的文章主要介绍了基于MMdetection框架的目标检测研究-6.混淆矩阵绘制。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

文章背景:

当我们训练完模型后,我们需要用训练后的模型对正负样本图片进行目标检测测试,这时候我们需要算模型在新的数据集上的检测效果(精度、过杀率、漏检率,准确度等),这时候使用测试后的结果绘制成混淆矩阵,可以很方便的帮助我们呈现和理解模型的泛化能力。

核心代码:

# -*- coding=utf-8 -*-
'''
功能说明:根据已有的分类数据,绘制相应的混淆矩阵,便于统计过杀率和漏检率
'''
import numpy as np
import matplotlib.pyplot as plt
# 修改类别列表中的数据和矩阵中数据可以绘制多类混淆矩阵
classes = ['OK ','NG']
confusion_matrix = np.array([(20,5),(5,55)],dtype=np.float64)
plt.imshow(confusion_matrix, interpolation='nearest', cmap=plt.cm.Oranges)  #按照像素显示出矩阵
plt.title('confusion_matrix')
plt.colorbar()
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes)
plt.yticks(tick_marks, classes)
thresh = confusion_matrix.max() / 2.
#iters = [[i,j] for i in range(len(classes)) for j in range((classes))]
#ij配对,遍历矩阵迭代器
iters = np.reshape([[[i,j] for j in range(len(classes))] for i in range(len(classes))],(confusion_matrix.size,2))
for i, j in iters:
    plt.text(j, i, format(confusion_matrix[i, j]))   #显示对应的数字
plt.ylabel('Real label')
plt.xlabel('Prediction')
plt.tight_layout()
#plt.show()
# 保存每次生成的图像
f = plt.gcf()  #获取当前图像
f.savefig(r'./{}.png'.format('result'))# 一定要放到plt.show()前面,否则保存图像为空白
plt.show()#plt.show() 后实际上已经创建了一个新的空白的图片
#f.clear()  #释放内存,迭代保存的时候,plt.plot()会出现多根线在一张图叠加,可以加这句话
print('混淆矩阵图像绘制结束并保存在当前路径下。')

结果显示如下,并在代码路径下保存生成结果:

mmdetection混淆矩阵,MMdetection,python,目标检测,混淆矩阵,python,MMdetection

混淆矩阵图分析: 

该混淆矩阵结果图表示的是,OK实际测试样本有25个,预测为OK的样本有20个,预测为NG的样本有5个。NG实际测试样本有60个,预测为NG的有55个,预测为OK的样本有5个。

绘制多类分类矩阵:

#confusion_matrix
import numpy as np
import matplotlib.pyplot as plt
classes = ['A','B','C','D','E']
confusion_matrix = np.array([(9,1,3,4,0),(2,13,1,3,4),(1,4,10,0,13),(3,1,1,17,0),(0,0,0,1,14)],dtype=np.float64)
 
plt.imshow(confusion_matrix, interpolation='nearest', cmap=plt.cm.Oranges)  #按照像素显示出矩阵
plt.title('confusion_matrix')
plt.colorbar()
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes)
plt.yticks(tick_marks, classes)
 
thresh = confusion_matrix.max() / 2.
#iters = [[i,j] for i in range(len(classes)) for j in range((classes))]
#ij配对,遍历矩阵迭代器
iters = np.reshape([[[i,j] for j in range(5)] for i in range(5)],(confusion_matrix.size,2))
for i, j in iters:
    plt.text(j, i, format(confusion_matrix[i, j]))   #显示对应的数字
 
plt.ylabel('Real label')
plt.xlabel('Prediction')
plt.tight_layout()
plt.show()

结果如下:

mmdetection混淆矩阵,MMdetection,python,目标检测,混淆矩阵,python,MMdetection

封装成函数绘制矩阵: 

from __future__ import division
import  numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator
 
def plotCM(classes, matrix, savname):
    """classes: a list of class names"""
    # Normalize by row
    matrix = matrix.astype(np.float)
    linesum = matrix.sum(1)
    linesum = np.dot(linesum.reshape(-1, 1), np.ones((1, matrix.shape[1])))
    matrix /= linesum
    # plot
    plt.switch_backend('agg')
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(matrix)
    fig.colorbar(cax)
    ax.xaxis.set_major_locator(MultipleLocator(1))
    ax.yaxis.set_major_locator(MultipleLocator(1))
    for i in range(matrix.shape[0]):
        ax.text(i, i, str('%.2f' % (matrix[i, i] * 100)), va='center', ha='center')
    
    ax.set_xticklabels([''] + classes, rotation=90)
    ax.set_yticklabels([''] + classes)
    #save
    plt.savefig(savname)
    
classes = ['A','B','C','D','E']
matrix = np.array(([9,1,3,4,0],[2,13,1,3,4],[1,4,10,0,13],[3,1,1,17,0],[0,0,0,1,14]),dtype=int)
savname = 'test'
plotCM(classes, matrix, savname)

结果展示:

mmdetection混淆矩阵,MMdetection,python,目标检测,混淆矩阵,python,MMdetection

如果要改变绘制矩阵的颜色,在代码中cmap=plt.cm.Oranges按照如下修改即可:

mmdetection混淆矩阵,MMdetection,python,目标检测,混淆矩阵,python,MMdetection

mmdetection混淆矩阵,MMdetection,python,目标检测,混淆矩阵,python,MMdetection

 mmdetection混淆矩阵,MMdetection,python,目标检测,混淆矩阵,python,MMdetection

 mmdetection混淆矩阵,MMdetection,python,目标检测,混淆矩阵,python,MMdetection

有时候我们需要将混淆矩阵的标签显示为中文,这时候需要我们进行适当的修改才可以 ,否则会出现乱码,代码和效果如下:

#coding=utf-8
import matplotlib.pyplot as plt
import numpy as np

confusion = np.array(([91,0,0],[0,92,1],[0,0,95]))
# 热度图,后面是指定的颜色块,可设置其他的不同颜色
plt.imshow(confusion, cmap=plt.cm.Blues)
# ticks 坐标轴的坐标点
# label 坐标轴标签说明
indices = range(len(confusion))
# 第一个是迭代对象,表示坐标的显示顺序,第二个参数是坐标轴显示列表
#plt.xticks(indices, [0, 1, 2])
#plt.yticks(indices, [0, 1, 2])
plt.xticks(indices, ['圆形', '三角形', '方形'])
plt.yticks(indices, ['圆形', '三角形', '方形'])

plt.colorbar()

plt.xlabel('预测值')
plt.ylabel('真实值')
plt.title('混淆矩阵')

# plt.rcParams两行是用于解决标签不能显示汉字的问题
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus'] = False

# 显示数据
for first_index in range(len(confusion)):    #第几行
    for second_index in range(len(confusion[first_index])):    #第几列
        plt.text(first_index, second_index, confusion[first_index][second_index])
# 在matlab里面可以对矩阵直接imagesc(confusion)
# 显示
plt.show()

效果如下:

mmdetection混淆矩阵,MMdetection,python,目标检测,混淆矩阵,python,MMdetection

 文章来源地址https://www.toymoban.com/news/detail-585561.html

到了这里,关于基于MMdetection框架的目标检测研究-6.混淆矩阵绘制的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包