混淆矩阵的绘制

这篇具有很好参考价值的文章主要介绍了混淆矩阵的绘制。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档


前言

主要展示在分类算法预测的过程中,加入混淆矩阵的绘制。


具体步骤

1.引入库

代码如下(示例):

import argparse

import torch
from torch.backends import cudnn
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
import torch.nn as nn
import numpy as np

from data_loaders import Plain_Dataset, eval_data_dataloader
from model import ResidualNet  # 引入模型

import matplotlib.pyplot as plt

2.设置参数

代码如下(示例):

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

parser = argparse.ArgumentParser(description="Configuration of testing process")
parser.add_argument('-m', '--model', type=str,default='./model/RestNet18.pt')
parser.add_argument('-depth', default=18, type=int)
parser.add_argument('-d', '--data', type=str, default='')
parser.add_argument('-att_type', default='se', choices=['cbam', 'se'], type=str)
args = parser.parse_args()

transformation = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
test_path = args.data + '/' + 'test'
dataset = Plain_Dataset( img_dir=test_path, datatype='test',transform=transformation)
test_loader =  DataLoader(dataset,batch_size=64,num_workers=0)

# 加载模型
net = ResidualNet('CIFAR10', args.depth, 7, args.att_type)
net.load_state_dict(torch.load(args.model))
net.to(device)

3.混淆矩阵定义

代码如下(示例):

# 混淆矩阵定义
def confusion_matrix(preds,labels,conf_matrix):
    for p,t in zip(preds,labels):
        conf_matrix[p,t] += 1
    return conf_matrix

def plot_maxtrix(maxtrix,per_kinds):
 	# 分类标签
    lables = ['Angry', 'Disgust', 'Fear', 'Happy', 'Neutral', 'Sad', 'Surprise']
    
    Maxt = np.empty(shape=[0,7])

    m = 0
    for i in range(7):
        print('row sum:',per_kinds[m])
        f = (maxtrix[m,:]*100)/per_kinds[m]
        Maxt = np.vstack((Maxt,f))
        m = m+1

    thresh = Maxt.max()/1

    plt.imshow(Maxt, cmap=plt.cm.Blues)

    for x in range(7):
        for y in range(7):
            info = float(format('%.1f' % F[y,x]))
            print('info:',info)
            plt.text(x,y,info,verticalalignment='center',horizontalalignment='center')
    plt.tight_layout()
    plt.yticks(range(7),lables)  # y轴标签
    plt.xticks(range(7),lables,rotation=45)  # x轴标签
    plt.savefig('./test.png',bbox_inches='tight')  # bbox_inches='tight'可确保标签信息显示全
    plt.show()

4.计算准确率及绘制混淆矩阵

代码如下(示例):

if __name__ == '__main__':
	with torch.no_grad():
        for data, labels in test_loader:
            data, labels = data.to(device), labels.to(device)

            outputs = net(data)
            pred = F.softmax(outputs,dim=1)
            classs = torch.argmax(pred,1)

            conf_maxtri = confusion_matrix(classs,labels,conf_maxtri)
            conf_maxtri = conf_maxtri.cpu()

            wrong = torch.where(classs != labels,torch.tensor([1.]).cuda(),torch.tensor([0.]).cuda())
            acc = 1- (torch.sum(wrong) / 64)  # 64为batch size
            total.append(acc.item())

    print('测试集的准确率为: %f %%' % (100 * np.mean(total)))
   
    # 绘制混淆矩阵
    conf_maxtri = np.array(conf_maxtri.cpu())
    corrects = conf_maxtri.diagonal(offset=0)
    per_kinds = conf_maxtri.sum(axis=1)
    plot_maxtrix(conf_maxtri,per_kinds)

绘制结果

混淆矩阵怎么画,图像处理,pytorch文章来源地址https://www.toymoban.com/news/detail-607860.html

到了这里,关于混淆矩阵的绘制的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 图像处理 边缘检测 绘制金字塔 模板匹配

    Canny边缘检测器是一种多步算法,用于检测任何输入图像的边缘。 边缘检测步骤: 1.应用 高斯滤波器 ,以平滑图像,滤除噪声( 降噪 ) 2.计算图像中每个像素点的梯度大小(边缘两侧和卷积之间的像素差值和方向(x和y方向)(梯度Scole算子检测边缘) 3.使用非极大值抑制,

    2024年02月06日
    浏览(53)
  • 图像处理中,采用极线约束准则来约束特征点匹配搜索空间,理论上在极线上进行搜索。这里的极线是什么线,怎么定义的?基本矩阵F和本质矩阵E有什么区别?

    问题描述:图像处理中,采用极线约束准则来约束特征点匹配搜索空间,理论上在极线上进行搜索。这里的极线是什么线,怎么定义的?基本矩阵F和本质矩阵E有什么区别? 问题1解答: 极线是通过极线几何学的原理定义的。在摄影测量学和计算机视觉中,极线是由两个相机

    2024年01月19日
    浏览(40)
  • 《数字图像处理-OpenCV/Python》连载(22)绘制直线与线段

    本书京东优惠购书链接:https://item.jd.com/14098452.html 本书CSDN独家连载专栏:https://blog.csdn.net/youcans/category_12418787.html 本章介绍OpenCV的绘图功能和简单的鼠标交互处理方法。与Excel或Matplotlib中的可视化数据图不同,OpenCV中的绘图功能主要用于在图像的指定位置绘制几何图形。 本

    2024年02月02日
    浏览(99)
  • 《数字图像处理-OpenCV/Python》连载(26)绘制椭圆和椭圆弧

    本书京东优惠购书链接:https://item.jd.com/14098452.html 本书CSDN独家连载专栏:https://blog.csdn.net/youcans/category_12418787.html 本章介绍OpenCV的绘图功能和简单的鼠标交互处理方法。与Excel或Matplotlib中的可视化数据图不同,OpenCV中的绘图功能主要用于在图像的指定位置绘制几何图形。 本

    2024年02月06日
    浏览(75)
  • 【opencv+图像处理】Image Processing in OpenCV 1-2基本图形绘制

    🍉 博主微信 cvxiayixiao 🍓 【Segment Anything Model】计算机视觉检测分割任务专栏。 链接 🍑 【公开数据集预处理】特别是医疗公开数据集的接受和预处理,提供代码讲解。链接 🍈 【opencv+图像处理】opencv代码库讲解,结合图像处理知识,不仅仅是调库。链接 本专栏代码地址

    2024年02月08日
    浏览(70)
  • 矩阵迹与图像处理的关联

    矩阵迹与图像处理的关联是一个重要的研究领域,它涉及到计算机视觉、图像处理、数字信号处理等多个领域。在这篇文章中,我们将从以下几个方面进行深入探讨: 背景介绍 核心概念与联系 核心算法原理和具体操作步骤以及数学模型公式详细讲解 具体代码实例和详细解释

    2024年02月20日
    浏览(48)
  • 【3D 图像分割】基于 Pytorch 的 VNet 3D 图像分割7(数据预处理)

    在上一节:【3D 图像分割】基于 Pytorch 的 VNet 3D 图像分割6(数据预处理) 中,我们已经得到了与 mhd 图像同 seriesUID 名称的 mask nrrd 数据文件了,可以说是一一对应了。 并且, mask 的文件,还根据结节被多少人同时标注,区分成了4个文件夹,分别是标注了一、二、三、四次,

    2024年02月07日
    浏览(52)
  • 课程大纲:图像处理中的矩阵计算

    课程名称:《图像处理中的矩阵计算》 课程简介: 图像处理中的矩阵计算是图像分析与处理的核心部分。本课程旨在教授学员如何应用线性代数中的矩阵计算,以实现各种图像处理技术。我们将通过强调实际应用和实践活动来确保学员能够理解和掌握这些概念。 第1章:矩阵

    2024年02月20日
    浏览(37)
  • Pytorch学习笔记(3):图像的预处理(transforms)

      目录  一、torchvision:计算机视觉工具包  二、transforms的运行机制 (1)torchvision.transforms:常用的图像预处理方法 (2)transforms运行原理   三、数据标准化 transforms.Normalize() 四、数据增强  4.1 transforms—数据裁剪 (1)transforms.CentorCrop (2)transforms.RandomCrop (3)RandomResiz

    2023年04月13日
    浏览(47)
  • 图像处理基础——视觉感知要素及空间变化矩阵

    本章重点掌握内容 视觉感知要素 像素间的一些基本关系,计算 空间变换的坐标公式与应用 灰度值插值小结  视网膜感受的颠倒信号,在通过视神经传导到大脑皮层的视觉中枢后,在视觉中枢实现自动翻转。 目录 视觉感知要素 亮度适应与辨别 像素间的一些基本关系 相邻像

    2024年02月03日
    浏览(33)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包