混淆矩阵Confusion Matrix(resnet34 基于 CIFAR10)

这篇具有很好参考价值的文章主要介绍了混淆矩阵Confusion Matrix(resnet34 基于 CIFAR10)。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

目录

1. Confusion Matrix

2. 其他的性能指标

3. example

4. 代码实现混淆矩阵

5.  测试,计算混淆矩阵

6. show

7. 代码


1. Confusion Matrix

混淆矩阵可以将真实标签和预测标签的结果以矩阵的形式表示出来,相比于之前计算的正确率acc更加的直观。

如下,是花分类的混淆矩阵:

之前计算的acc = 预测正确的个数 / 总个数 = 对角线的和 / 矩阵的总和

混淆矩阵Confusion Matrix(resnet34 基于 CIFAR10)

 

2. 其他的性能指标

除了准确率之外,还有别的指标可能更加方便的知道每一个类别的预测情况。

在介绍下面的内容之前,需要了解一些名词

混淆矩阵Confusion Matrix(resnet34 基于 CIFAR10)

其中,T都是True预测正确的,F都是False预测错误的。P是正确的label,N是错误的label

TP和TN都是是预测正确的类别。两者说明网络都可以正常分类,TP是真实值比如是猫,预测也是猫。TN是真实值为非猫,预测的结果也是非猫

FP和FN都是预测错误的。两者说明网络都不能正常分类,FN是说,真实值是猫,预测为非猫,FP是说真实值为非猫,预测为猫

方便的记法,T就是网络正确预测,P就是正确的类别。

例如:

TP,就是网络预测是对的,标签也是对的(猫)。

FP就是网络预测错的,标签是对的类别(也就是label是猫,网络预测是非猫,因为F代表错误的)。

FN就是,预测是错误的,N代表不是真正的标签,所以预测出来的是错误的正样本

TN就是,预测是对的,N代表不是正确的类别,所有预测出来也不是正确的类别

常见的有下面几种性能指标:除了准确率,其余的都是针对特定的类别计算的

混淆矩阵Confusion Matrix(resnet34 基于 CIFAR10)

 

3. example

比如,下面的为三分类的混淆矩阵

混淆矩阵Confusion Matrix(resnet34 基于 CIFAR10)

准确率 = 预测正确的 / 样本的总数 = (TP + TN) / (TP+TN+FP+FN) = (10+15+20)/66=0.68

下面都是针对于猫的其三个指标:

精确率 = TP / (TP+FP) = 10 / (10+1+2) = 0.77

精确度也叫查准率Precision,也就是预测为正样本中,真正正样本的比率

召回率 = TP/ (TP + FN) = 10 / (10 +3+5) = 0.56

召回率是说真正正样本中,预测为正样本的比率

特异度 = TN / (TN+FP) = (15+4+20+6) / (15+4+20+6+1+2) = 0.94

4. 代码实现混淆矩阵

首先,实现一个混淆矩阵类

混淆矩阵Confusion Matrix(resnet34 基于 CIFAR10)

 

然后更新混淆矩阵的值,传入预测和真正的标签,横坐标是真实值,纵坐标是预测值

p代表矩阵的行,也就是预测,t代表矩阵的列,就是真实

混淆矩阵Confusion Matrix(resnet34 基于 CIFAR10)

 

各项指标的计算

混淆矩阵Confusion Matrix(resnet34 基于 CIFAR10)

 

接着打印混淆矩阵

混淆矩阵Confusion Matrix(resnet34 基于 CIFAR10)

 

5.  测试,计算混淆矩阵

这里用的是之前的resnet34的迁移学习模型,数据是CIFAR10数据集

首先创建混淆矩阵类,上面注释的是手动编写的类别,下面是json文件提取的

注意这里混淆矩阵类,传入的第一个参数是混淆矩阵的size,也就是分类的个数。labels是一个list列表,存放不同的类名

混淆矩阵Confusion Matrix(resnet34 基于 CIFAR10)

 

更新打印混淆矩阵

混淆矩阵Confusion Matrix(resnet34 基于 CIFAR10)

 

6. show

混淆矩阵:

混淆矩阵Confusion Matrix(resnet34 基于 CIFAR10)

 

输出控制台:

混淆矩阵Confusion Matrix(resnet34 基于 CIFAR10)

观察可以发现召回率recall,就是对应对角线的值 / 1000

不难理解,因为recall = TP / (TP+FN),而分母就是label的个数,CIFAR10的测试集有1W张图像,共有10个类别,刚好每个是1k张图像,所有recall的分母都是1k

召回率,真正正样本中预测为正样本的个数

 将混淆矩阵输出的图关闭后,会打印性能指标

混淆矩阵Confusion Matrix(resnet34 基于 CIFAR10)

 

7. 代码

混淆矩阵放在utils中,utils代码:

import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

import matplotlib.pyplot as plt
import numpy as np
from prettytable import PrettyTable


# 计算混淆矩阵
class ConfusionMatrix(object):
    def __init__(self, num_classes: int, labels: list):
        self.matrix = np.zeros((num_classes, num_classes))  # 初始化混淆矩阵
        self.num_classes = num_classes
        self.labels = labels

    def update(self, preds, labels):    # 计算混淆矩阵的值
        for p, t in zip(preds, labels):
            self.matrix[p, t] += 1

    def summary(self):          # 计算各项指标
        # calculate accuracy
        sum_TP = 0
        for i in range(self.num_classes):
            sum_TP += self.matrix[i, i]        # 对角线的和
        acc = sum_TP / np.sum(self.matrix)     # 混淆矩阵的和
        print("the model accuracy is ", acc)

        # precision, recall, specificity
        table = PrettyTable()
        table.field_names = ["", "Precision", "Recall", "Specificity"]  # 表格的tittle
        for i in range(self.num_classes):
            TP = self.matrix[i, i]                      # label为真,预测为真
            FP = np.sum(self.matrix[i, :]) - TP         # label为假,预测为真
            FN = np.sum(self.matrix[:, i]) - TP         # label为假,预测为真
            TN = np.sum(self.matrix) - TP - FP - FN     # label为假,预测为假
            Precision = round(TP / (TP + FP), 3) if TP + FP != 0 else 0.
            Recall = round(TP / (TP + FN), 3) if TP + FN != 0 else 0.
            Specificity = round(TN / (TN + FP), 3) if TN + FP != 0 else 0.
            table.add_row([self.labels[i], Precision, Recall, Specificity])
        print(table)

    def plot(self):
        matrix = self.matrix
        print(matrix)
        plt.imshow(matrix, cmap=plt.cm.Blues)

        plt.xticks(range(self.num_classes), self.labels, rotation=45)       # 设置x轴坐标label
        plt.yticks(range(self.num_classes), self.labels)        # 设置y轴坐标label
        plt.colorbar()      # 显示 colorbar

        plt.xlabel('True Labels')
        plt.ylabel('Predicted Labels')
        plt.title('Confusion matrix')

        thresh = matrix.max() / 2   # 在图中标注数量/概率信息
        for x in range(self.num_classes):
            for y in range(self.num_classes):
                # 注意这里的matrix[y, x]不是matrix[x, y]
                info = int(matrix[y, x])
                plt.text(x, y, info,
                         verticalalignment='center',
                         horizontalalignment='center',
                         color="white" if info > thresh else "black")
        plt.tight_layout()
        plt.show()

网络model:这里是resnet的代码

import torch
import torch.nn as nn


# residual block
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self,in_channel,out_channel,stride=1,downsample=None):
        super(BasicBlock,self).__init__()
        self.conv1 = nn.Conv2d(in_channel,out_channel,kernel_size=3,stride=stride,padding=1,bias=False) # 第一层的话,可能会缩小size,这时候 stride = 2
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()

        self.conv2 = nn.Conv2d(out_channel,out_channel,kernel_size=3,stride=1,padding=1,bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self,x):
        identity = x
        if self.downsample is not None:     # 有下采样,意味着需要1*1进行降维,同时channel翻倍,residual block虚线部分
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out


# bottleneck
class Bottleneck(nn.Module):
    expansion = 4       # 卷积核的变化

    def __init__(self,in_channel,out_channel,stride=1,downsample=None):
        super(Bottleneck,self).__init__()
        # 1*1 降维度 --------> padding默认为 0,size不变,channel被降低
        self.conv1 = nn.Conv2d(in_channel,out_channel,kernel_size=1,stride=1,bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        # 3*3 卷积
        self.conv2 = nn.Conv2d(out_channel,out_channel,kernel_size=3,stride=stride,bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        # 1*1 还原维度 --------> padding默认为 0,size不变,channel被还原
        self.conv3 = nn.Conv2d(out_channel,out_channel*self.expansion,kernel_size=1,stride=1,bias=False)
        self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)
        # other
        self.relu = nn.ReLU(inplace=True)
        self.downsample =downsample

    def forward(self,x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)

        return out


# resnet
class ResNet(nn.Module):
    def __init__(self,block,block_num,num_classes=1000,include_top=True):
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channel = 64        # max pool 之后的 depth
        # 网络最开始的部分,输入是RGB图像,经过卷积,图像size减半,通道变为64
        self.conv1 = nn.Conv2d(3,self.in_channel,kernel_size=7,stride=2,padding=3,bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3,stride=2,padding=1)   # size减半,padding = 1

        self.layer1 = self.__make_layer(block,64,block_num[0])                # conv2_x
        self.layer2 = self.__make_layer(block,128,block_num[1],stride=2)      # conv3_x
        self.layer3 = self.__make_layer(block,256,block_num[2],stride=2)      # conv4_X
        self.layer4 = self.__make_layer(block,512,block_num[3],stride=2)      # conv5_x

        if self.include_top:    # 分类部分
            self.avgpool = nn.AdaptiveAvgPool2d((1,1))      # out_size = 1*1
            self.fc = nn.Linear(512*block.expansion,num_classes)

    def __make_layer(self,block,channel,block_num,stride=1):
        downsample =None
        if stride != 1 or self.in_channel != channel*block.expansion:     # shortcut 部分,1*1 进行升维
            downsample=nn.Sequential(
                nn.Conv2d(self.in_channel,channel*block.expansion,kernel_size=1,stride=stride,bias=False),
                nn.BatchNorm2d(channel*block.expansion)
            )

        layers =[]
        layers.append(block(self.in_channel, channel, downsample =downsample, stride=stride))
        self.in_channel = channel * block.expansion

        for _ in range(1,block_num):    # residual 实线的部分
            layers.append(block(self.in_channel,channel))

        return nn.Sequential(*layers)

    def forward(self,x):
        # resnet 前面的卷积部分
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        # residual 特征提取层
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # 分类
        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x,start_dim=1)
            x = self.fc(x)

        return x


# 定义网络
def resnet34(num_classes=1000,include_top=True):
    return ResNet(BasicBlock,[3,4,6,3],num_classes=num_classes,include_top=include_top)


def resnet101(num_classes=1000,include_top=True):
    return ResNet(Bottleneck,[3,4,23,3],num_classes=num_classes,include_top=include_top)

主函数main:文章来源地址https://www.toymoban.com/news/detail-427565.html

import torch
from torchvision import transforms, datasets
from tqdm import tqdm
from model import resnet34
from utils import ConfusionMatrix
import json


if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    data_transform = transforms.Compose([transforms.Resize(256),
                                         transforms.CenterCrop(224),
                                         transforms.ToTensor(),
                                         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    # 加载数据
    validate_dataset = datasets.CIFAR10(root='./data',train=False,transform=data_transform)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,batch_size=16, shuffle=True)

    # 加载网络
    net = resnet34(num_classes=10)
    model_weight_path = "./resnet.pth"
    net.load_state_dict(torch.load(model_weight_path, map_location=device))
    net.to(device)

    # 类别
    # classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    # labels = [label for label in classes]
    # confusion = ConfusionMatrix(num_classes=10, labels=labels)

    # 类别
    json_label_path = './class_indices.json'
    json_file = open(json_label_path, 'r')
    class_indict = json.load(json_file)

    labels = [label for _, label in class_indict.items()]
    confusion = ConfusionMatrix(num_classes=10, labels=labels)

    net.eval()
    with torch.no_grad():
        for val_data in tqdm(validate_loader):
            val_images, val_labels = val_data
            outputs = net(val_images.to(device))
            outputs = torch.softmax(outputs, dim=1)
            outputs = torch.argmax(outputs, dim=1)
            confusion.update(outputs.to("cpu").numpy(), val_labels.to("cpu").numpy())   # 更新混淆矩阵的值
    confusion.plot()         # 绘制混淆矩阵
    confusion.summary()      # 计算指标

到了这里,关于混淆矩阵Confusion Matrix(resnet34 基于 CIFAR10)的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • python:多分类-计算混淆矩阵confusion_matrix、precision、recall、f1-score分数

    多分类,计算混淆矩阵confusion_matrix,以及accuracy、precision、recall、f1-score分数。 1)使用sklearn计算并画出 混淆矩阵(confusion_matrix) ; 2)使用sklearn计算 accuracy(accuracy_score) ; 3)使用sklearn计算多分类的 precision、recall、f1-score分数 。以及计算每个类别的precision、recall、f1-

    2024年02月06日
    浏览(41)
  • 基于ResNet-18实现Cifar-10图像分类

    安耀辉,男,西安工程大学电子信息学院,22级研究生 研究方向:小样本图像分类算法 电子邮箱:1349975181@qq.com 张思怡,女,西安工程大学电子信息学院,2022级研究生,张宏伟人工智能课题组 研究方向:机器视觉与人工智能 电子邮件:981664791@qq.com CIFAR-10 数据集由 60000张图

    2024年02月06日
    浏览(45)
  • Resnet实现CIFAR-10图像分类 —— Mindspore实践

            计算机视觉是当前深度学习研究最广泛、落地最成熟的技术领域,在手机拍照、智能安防、自动驾驶等场景有广泛应用。从2012年AlexNet在ImageNet比赛夺冠以来,深度学习深刻推动了计算机视觉领域的发展,当前最先进的计算机视觉算法几乎都是深度学习相关的。深

    2024年02月07日
    浏览(39)
  • Resnet18训练CIFAR10 准确率95%

    准确率 95.31% 几个关键点: 1、改模型:原始的resnet18首层使用的7x7的卷积核,CIFAR10图片太小不适合,要改成3x3的,步长和padding都要一并改成1。因为图太小,最大池化层也同样没用,删掉。最后一个全连接层输出改成10。 2、图片增强不要太多,只要训练集和验证集结果没有出

    2024年02月02日
    浏览(39)
  • 【神经网络】(10) Resnet18、34 残差网络复现,附python完整代码

    各位同学好,今天和大家分享一下 TensorFlow 深度学习 中如何搭载 Resnet18 和 Resnet34 残差神经网络,残差网络 利用 shotcut 的方法成功解决了网络退化的问题 ,在训练集和校验集上,都证明了的更深的网络错误率越小。 论文中给出的具体的网络结构如下: Resnet50 网络结构 我已

    2023年04月08日
    浏览(40)
  • 基于ResNet34的花朵分类

    新建一个项目文件夹ResNet,并在里面建立data_set文件夹用来保存数据集,在data_set文件夹下创建新文件夹\\\"flower_data\\\",点击链接下载花分类数据集https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz,会下载一个压缩包,将它解压到flower_data文件夹下,执行\\\"split_d

    2024年02月07日
    浏览(45)
  • 实验记录resnet20/cifar100

    Cifar100 / resnet20: 1、 Baseline Namespace:(batch_size=128, decay=0.0003, epoch=200, gammas=[0.1, 0.1, 0.5],  learning_rate=0.1, momentum=0.9, optimizer=\\\'SGD\\\',  schedule=[80, 120, 160]) Best acc: 68.85% 80 和 120 是拐点 2、batch_size, gammas Namespace(batch_size=512,  decay=0.0003, epoch=200, gammas=[0.1, 0.1, 0.1],  learning_rate=0.1, momentum=

    2024年02月12日
    浏览(32)
  • 基于 PyTorch 的 cifar-10 图像分类

    本文的主要内容是基于 PyTorch 的 cifar-10 图像分类,文中包括 cifar-10 数据集介绍、环境配置、实验代码、运行结果以及遇到的问题这几个部分,本实验采用了基本网络和VGG加深网络模型,其中VGG加深网络模型的识别准确率是要优于基本网络模型的。 cifar-10 数据集由 60000 张分辨

    2023年04月24日
    浏览(42)
  • 深度学习pytorch实战五:基于ResNet34迁移学习的方法图像分类篇自建花数据集图像分类(5类)超详细代码

    1.数据集简介 2.模型相关知识 3.split_data.py——训练集与测试集划分 4.model.py——定义ResNet34网络模型 5.train.py——加载数据集并训练,训练集计算损失值loss,测试集计算accuracy,保存训练好的网络参数 6.predict.py——利用训练好的网络参数后,用自己找的图像进行分类测试 1.自建

    2024年02月09日
    浏览(55)
  • CNN实现与训练--------------以cifar10数据集为例进行演示(基于Tensorflow)

    本文以cifar10数据集为例进行演示 (cifar10数据集有5万张32 32像素点的彩色图片,用于训练有1万张32 32像素点的彩色图片,用于测试)

    2024年02月08日
    浏览(35)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包