pytorch 计算混淆矩阵

这篇具有很好参考价值的文章主要介绍了pytorch 计算混淆矩阵。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

混淆矩阵是评估模型结果的一种指标 用来判断分类模型的好坏

pytorch混淆矩阵代码,CV基础知识,矩阵,pytorch,机器学习,人工智能,python

 预测对了 为对角线 

还可以通过矩阵的上下角发现哪些容易出错

从这个 矩阵出发 可以得到 acc != precision recall  特异度?

pytorch混淆矩阵代码,CV基础知识,矩阵,pytorch,机器学习,人工智能,python

 pytorch混淆矩阵代码,CV基础知识,矩阵,pytorch,机器学习,人工智能,python

 目标检测01笔记AP mAP recall precision是什么 查全率是什么 查准率是什么 什么是准确率 什么是召回率_:)�东东要拼命的博客-CSDN博客

pytorch混淆矩阵代码,CV基础知识,矩阵,pytorch,机器学习,人工智能,python

 acc  是对所有类别来说的

其他三个都是 对于类别来说的

pytorch混淆矩阵代码,CV基础知识,矩阵,pytorch,机器学习,人工智能,python

pytorch混淆矩阵代码,CV基础知识,矩阵,pytorch,机器学习,人工智能,python

下面给出源码 

import json
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from prettytable import PrettyTable
from torchvision import datasets
from torchvision.models import MobileNetV2
from torchvision.transforms import transforms


class ConfusionMatrix(object):
    """
    注意版本问题,使用numpy来进行数值计算的
    """

    def __init__(self, num_classes: int, labels: list):
            self.matrix = np.zeros((num_classes, num_classes))
            self.num_classes = num_classes
            self.labels = labels

    def update(self, preds, labels):
        for p, t in zip(preds, labels):
            self.matrix[t, p] += 1

# 行代表预测标签 列表示真实标签




    def summary(self):
        # calculate accuracy
        sum_TP = 0
        for i in range(self.num_classes):
            sum_TP += self.matrix[i, i]
        acc = sum_TP / np.sum(self.matrix)
        print("acc is", acc)

        # precision, recall, specificity
        table = PrettyTable()
        table.fields_names = ["", "pre", "recall", "spec"]
        for i in range(self.num_classes):
            TP = self.matrix[i, i]
            FP = np.sum(self.matrix[i, :]) - TP
            FN = np.sum(self.matrix[:, i]) - TP
            TN = np.sum(self.matrix) - TP - FP - FN
            pre = round(TP / (TP + FP), 3)    # round 保留三位小数
            recall = round(TP / (TP + FN), 3)
            spec = round(TN / (FP + FN), 3)
            table.add_row([self.labels[i], pre, recall, spec])
        print(table)


    def plot(self):
        matrix = self.matrix
        print(matrix)
        plt.imshow(matrix, cmap=plt.cm.Blues)  # 颜色变化从白色到蓝色

        # 设置 x  轴坐标 label
        plt.xticks(range(self.num_classes), self.labels, rotation=45)
        # 将原来的 x 轴的数字替换成我们想要的信息 self.num_classes  x 轴旋转45度
        # 设置 y  轴坐标 label
        plt.yticks(range(self.num_classes), self.labels)

        # 显示 color bar  可以通过颜色的密度看出数值的分布
        plt.colorbar()
        plt.xlabel("true_label")
        plt.ylabel("Predicted_label")
        plt.title("ConfusionMatrix")

        # 在图中标注数量 概率信息
        thresh = matrix.max() / 2
        # 设定阈值来设定数值文本的颜色 开始遍历图像的时候一般是图像的左上角
        for x in range(self.num_classes):
            for y in range(self.num_classes):
                # 这里矩阵的行列交换,因为遍历的方向 第y行 第x列
                info = int(matrix[y, x])
                plt.text(x, y, info,
                         verticalalignment='center',
                         horizontalalignment='center',
                         color="white" if info > thresh else "black")
        plt.tight_layout()
        # 图形显示更加的紧凑
        plt.show()



if __name__ ==' __main__':
    device = torch.device("cuda:0" if torch.cuda.is_available()else "cpu")
    print(device)
    # 使用验证集的预处理方式
    data_transform = transforms.Compose([transforms.Resize(256),
                                         transforms.CenterCrop(224),
                                         transforms.ToTensor()
                                         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    data_loot = os.path.abspath(os.path.join(os.getcwd(), "../.."))
    # get data root path
    image_path = data_loot + "/data_set/flower_data/"
    # flower data set path

    validate_dataset = datasets.ImageFolder(root=image_path +"val",
                                            transform=data_transform)

    batch_size = 16
    validate_loader = torch.utils.data.DataLoder(validate_dataset,
                                                 batch_size=batch_size,
                                                 shuffle=False,
                                                 num_workers=2)

    net = MobileNetV2(num_classes=5)
    #加载预训练的权重
    model_weight_path = "./MobileNetV2.pth"
    net.load_state_dict(torch.load(model_weight_path, map_location=device))
    net.to(device)

    #read class_indict
    try:
        json_file = open('./class_indicts.json', 'r')
        class_indict = json.load(json_file)
    except Exception as e:
        print(e)
        exit(-1)


    labels = [label for _, label in class_indict.item()]
    # 通过json文件读出来的label
    confusion = ConfusionMatrix(num_classes=5, labels=labels)
    net.eval()
    # 启动验证模式
    # 通过上下文管理器  no_grad  来停止pytorch的变量对梯度的跟踪
    with torch.no_grad():
        for val_data in validate_loader:
            val_images, val_labels = val_data
            outputs = net(val_images.to(device))
            outputs = torch.softmax(outputs, dim=1)
            outputs = torch.argmax(outputs, dim=1)
            # 获取概率最大的元素
            confusion.update(outputs.numpy(), val_labels.numpy())
            # 预测值和标签值
    confusion.plot()
    # 绘制混淆矩阵
    confusion.summary()
    # 来打印各个指标信息
































是这样的 这篇算是一个学习笔记,其中的基础图都源于我的导师

pytorch混淆矩阵代码,CV基础知识,矩阵,pytorch,机器学习,人工智能,python

 霹雳吧啦Wz的个人空间_哔哩哔哩_bilibili

欢迎无依无靠的CV同学加入 

讲的非常好 代码其实也是导师给的 

我能做的就是读懂每一行加点注释

给不想看视频的同学留点时间文章来源地址https://www.toymoban.com/news/detail-781621.html

到了这里,关于pytorch 计算混淆矩阵的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • Pytorch基础知识点复习

    本篇博客是本人对pytorch使用的查漏补缺,参考资料来自 深入浅出PyTorch,本文主要以提问的方式对知识点进行回顾,小伙伴们不记得的知识点可以查一下前面的教程哦。   现在并行计算的策略是 不同的数据分布到不同的设备中,执行相同的任务(Data parallelism) 。   它的逻

    2024年01月20日
    浏览(41)
  • 矩阵的基础知识

    一、矩阵的定义  矩阵:一个由m×n个元素排成的m行n列的表。 矩阵的常规存储:将矩阵描述成一个二维数组。 矩阵的常规存储的特点:1.可以对其元素进行随机存取 2.矩阵的运算非常简单 3.存储密度为1  矩阵的压缩存储:1.为多个相同的非零元素只分配一个存储空间 2.对零元

    2024年02月06日
    浏览(44)
  • MATLAB:矩阵(基础知识)

    1.矩阵的输入 2.调用矩阵 3.子数组的赋值 1. 矩阵的构造与操作 zeros 生成元素全为0的矩阵 ones 生成元素全为1的矩阵 eye 生成单位矩阵 rand 生成随机矩阵 fliplr 矩阵左右翻转 flipud 矩阵上下翻转 triu  矩阵的上三角部分 tril 矩阵的下三角部分 diag 对角矩阵 full 将稀疏矩阵化为普通

    2023年04月08日
    浏览(38)
  • MATLAB矩阵基础知识(一)

            MATLAB即Matrix Laboratory(矩阵实验室),可见MATLAB在矩阵问题上的优势,本次内容主要关于矩阵的生成调用。         矩阵是由m*n个数组成的m行n列的数表,也可以看做m个n维向量组成。若m=n则矩阵为n阶仿真。 矩阵的生成  1、直接通过键盘输入生成矩阵是最常用的

    2024年02月10日
    浏览(57)
  • 深度学习基础知识-pytorch数据基本操作

    1.1.1 数据结构 机器学习和神经网络的主要数据结构,例如                 0维:叫标量,代表一个类别,如1.0                 1维:代表一个特征向量。如  [1.0,2,7,3.4]                 2维:就是矩阵,一个样本-特征矩阵,如: [[1.0,2,7,3.4 ]                   

    2024年02月11日
    浏览(48)
  • 知识储备--基础算法篇-矩阵

    第一题上来就跪了,看了官方答案感觉不是很好理解,找了一个比较容易理解的。 还有一个暴力方法,其中有几个知识点, list的[]中有三个参数,用冒号分割 list[param1:param2:param3] param1,相当于start_index,可以为空,默认是0 param2,相当于end_index,可以为空,默认是list.size p

    2024年02月10日
    浏览(31)
  • MATLAB基础知识之数组与矩阵

    本文是参考书籍《MATLAB R2020a完全自学一本通 》自己整理的一些笔记和一些练习,希望会给大家带来一些帮助。 目录 1、数组创建与运算 1.1数组的创建 1.2数组的运算 1.2.1 算术运算  1.2.2关系运算与逻辑运算  2、矩阵的构造与操作 2.1矩阵的构造 2.2矩阵的操作 2.3矩阵索引  2

    2024年02月07日
    浏览(41)
  • matlab基础知识加矩阵运算初步

    ** matlab(matrix laboratory)** 功能符号 1.分号(;) 不让matlab显示运算结果,抑制输出 2.续行号(…) 某行命令太长,指令行必须多行书写时,使用“…\\\"处理,表示下一行是上一行的连续 常用指令 1.cd 显示或改变工作目录 2.clc 清空命令行窗口 3.clear 清除所有变量 clear+变量名 清除一

    2024年02月10日
    浏览(34)
  • 【Pytorch基础知识】数据的归一化和反归一化

    一张正常的图,或者说是人眼习惯的图是这样的: 但是,为了 神经网络更快收敛 ,我们在深度学习网络过程中 通常需要将读取的图片转为tensor并归一化 (此处的归一化指 transforms .Normalize()操作)输入到网络中进行系列操作。 如果将转成的tensor再直接转为图片,就会变成下

    2023年04月09日
    浏览(84)
  • Python库第一课:基础Numpy知识(下):矩阵

            好的,我们今天继续来学习Numpy的基础,昨天,已经介绍完Numpy的成员之一——数组,今天,在接着介绍其另一大成员——矩阵,也是应用非常广泛的成员。         矩阵,在线性代数中是几乎贯穿全文的成员,因此,这里需要较高的线性代数的基础。在这里,默认

    2024年02月03日
    浏览(54)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包