混淆矩阵的生成

这篇具有很好参考价值的文章主要介绍了混淆矩阵的生成。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

混淆矩阵简介

混淆矩阵(Confusion Matrix)是一个二维表格,常用于评价分类模型的性能。在混淆矩阵中,每一列代表了预测值,每一行代表了真实值。因此,混淆矩阵中的每一个元素表示了一个样本被预测为某一类别的次数。混淆矩阵的构成如下:

预测值=正例 预测值=反例
真实值=正例 TP FN
真实值=反例 FP TN

其中,TP表示真正例(True Positive),FN表示假反例(False Negative),FP表示假正例(False Positive),TN表示真反例(True Negative)。

解释如下:

TP:真正例,指的是模型将正例预测为正例的次数;
FN:假反例,指的是模型将正例预测为反例的次数;
FP:假正例,指的是模型将反例预测为正例的次数;
TN:真反例,指的是模型将反例预测为反例的次数。
混淆矩阵的重要性在于,可以通过计算其中的四个元素,得到各种评价指标,如精确度(Accuracy)、召回率(Recall)、准确率(Precision)和 F1 值等。

精确度(Accuracy):表示模型预测正确的样本数与总样本数之比,即 A c c u r a c y = T P + T N T P + F P + F N + T N Accuracy = \frac{TP+TN}{TP+FP+FN+TN} Accuracy=TP+FP+FN+TNTP+TN
召回率(Recall):表示模型正确预测正例样本的比例,即 R e c a l l = T P T P + F N Recall = \frac{TP}{TP+FN} Recall=TP+FNTP
准确率(Precision):表示模型预测为正例的样本中,真正例的比例,即 P r e c i s i o n = T P T P + F P Precision = \frac{TP}{TP+FP} Precision=TP+FPTP
F1 值:综合了准确率和召回率,即 F 1 = 2 × P r e c i s i o n × R e c a l l P r e c i s i o n + R e c a l l F1 = \frac{2\times Precision\times Recall}{Precision+Recall} F1=Precision+Recall2×Precision×Recall
混淆矩阵也可以可视化,可以使用热力图等图形来展示混淆矩阵中每个元素的数值大小,以便更加直观地理解分类模型的性能。

混淆矩阵的主要作用和意义如下:

评估分类器的性能:混淆矩阵可以帮助我们计算分类器的准确率、召回率、精确率、F1分数等指标,从而评估分类器的性能。

比较不同分类器的性能:混淆矩阵可以帮助我们比较不同分类器的性能,找出最优的分类器。

识别分类器的错误类型:混淆矩阵可以帮助我们了解分类器在哪些情况下容易出错,识别出分类器的错误类型,从而针对性地改进分类器。

优化分类器的阈值:混淆矩阵可以帮助我们优化分类器的阈值,从而提高分类器的性能。

可视化分类器的性能:混淆矩阵可以将分类器的性能可视化,从而更直观地了解分类器的性能。

混淆矩阵可视化代码:

import os
from matplotlib.font_manager import FontProperties
import itertools
import matplotlib.pyplot as plt
import numpy as np


# 绘制混淆矩阵
def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
    """
    - cm : 计算出的混淆矩阵的值
    - classes : 混淆矩阵中每一行每一列对应的列
    - normalize : True:显示百分比, False:显示个数
    """
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("显示百分比:")
        np.set_printoptions(formatter={'float': '{: 0.2f}'.format})
        print(cm)
    else:
        print('显示具体数字:')
        print(cm)
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)
    # matplotlib版本问题,如果不加下面这行代码,则绘制的混淆矩阵上下只能显示一半,有的版本的matplotlib不需要下面的代码,分别试一下即可
    plt.ylim(len(classes) - 0.5, -0.5)
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")
    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.show()


cnf_matrix = np.array([[151, 64, 731, 164, 45],
                       [821, 653, 79, 0, 28],
                       [266, 167, 423, 4, 2],
                       [691, 0, 107, 776, 26],
                       [30, 0, 111, 17, 42]])
attack_types = ['Normal', 'DoS', 'Probe', 'R2L', 'U2R']
# 归一化
# plot_confusion_matrix(cnf_matrix, classes=attack_types, normalize=True, title='Confusion matrix')
# 不归一化
plot_confusion_matrix(cnf_matrix, classes=attack_types, normalize=True, title='Confusion matrix')

其中上述有两种方式可以选择,即一种是归一化,一种是不归一化
归一化设置 normalize=True
结果为:
混淆矩阵的生成
不归一化设置 normalize=False
结果为:

混淆矩阵的生成
如果想要配合模型生成混淆矩阵,则需要让神经生成一个混淆矩阵的矩阵序列代码为:

import os
import json

import torch
from torchvision import transforms, datasets
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from prettytable import PrettyTable

from model import MobileNetV2


class ConfusionMatrix(object):
    """
    注意,如果显示的图像不全,是matplotlib版本问题
    本例程使用matplotlib-3.2.1(windows and ubuntu)绘制正常
    需要额外安装prettytable库
    """

    def __init__(self, num_classes: int, labels: list):
        self.matrix = np.zeros((num_classes, num_classes))
        self.num_classes = num_classes
        self.labels = labels

    def update(self, preds, labels):
        for p, t in zip(preds, labels):
            self.matrix[p, t] += 1

    def plot(self, normalize=False):
        if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("显示百分比:")
        np.set_printoptions(formatter={'float': '{: 0.2f}'.format})
        print(cm)
   		else:
        print('显示具体数字:')
        print(cm)
        
        matrix = self.matrix
	    plt.imshow(matrix , interpolation='nearest', cmap=cmap)
	    plt.title(title)
	    plt.colorbar()
	    tick_marks = np.arange(len(classes))
	    plt.xticks(tick_marks, classes, rotation=45)
	    plt.yticks(tick_marks, classes)
	    # matplotlib版本问题,如果不加下面这行代码,则绘制的混淆矩阵上下只能显示一半,有的版本的matplotlib不需要下面的代码,分别试一下即可
	    plt.ylim(len(classes) - 0.5, -0.5)
	    fmt = '.2f' if normalize else 'd'
	    thresh = cm.max() / 2.
	    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
	        plt.text(j, i, format(cm[i, j], fmt),
	                 horizontalalignment="center",
	                 color="white" if cm[i, j] > thresh else "black")
	    plt.tight_layout()
	    plt.ylabel('True label')
	    plt.xlabel('Predicted label')
	    plt.show()


if __name__ == '__main__':
    mylabel = {"4": "4", "5": "5", "6": "6"}
    num_classes=3 #################################
    os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
    ROOT_DATA = r'D:/other/ClassicalModel/data/flower_datas'  #################################
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    data_transform = transforms.Compose([transforms.Resize(256),
                                         transforms.CenterCrop(224),
                                         transforms.ToTensor(),
                                         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    validate_dataset = datasets.ImageFolder(root=os.path.join(ROOT_DATA, "val"),
                                            transform=data_transform)

    batch_size = 16
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=2)

    net = MobileNetV2(num_classes=num_classes)  ###########################
    # load pretrain weights
    model_weight_path = r"D:/other/ClassicalModel/MobileNet/runs1/mobilenet_v2.pth"  #########################
    assert os.path.exists(model_weight_path), "cannot find {} file".format(model_weight_path)
    net.load_state_dict(torch.load(model_weight_path, map_location=device))
    net.to(device)
    labels = [label for _, label in mylabel.items()]
    confusion = ConfusionMatrix(num_classes=num_classes, labels=labels) 
    net.eval()
    with torch.no_grad():
        for val_data in tqdm(validate_loader):
            val_images, val_labels = val_data
            outputs = net(val_images.to(device))
            outputs = torch.softmax(outputs, dim=1)
            outputs = torch.argmax(outputs, dim=1)
            # print('outputs++'+str(outputs.to("cpu").numpy())+'val_labels++'+str(val_labels.numpy()))
            confusion.update(outputs.to("cpu").numpy(), val_labels.to("cpu").numpy())
    confusion.plot()

其中*多的地方需要自行修改,例如

    ROOT_DATA = r'D:/other/ClassicalModel/data/flower_datas'  #################################

在这里进行数据集的修改

    mylabel = {"4": "4", "5": "5", "6": "6"}

进行标签的修改

    net = MobileNetV2(num_classes=3)  ###########################

在这里进行网络修改

    model_weight_path = r"D:/other/ClassicalModel/MobileNet/runs1/mobilenet_v2.pth"  #########################

在这里进行本地模型权重的修改文章来源地址https://www.toymoban.com/news/detail-417218.html

到了这里,关于混淆矩阵的生成的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • python:多分类-计算混淆矩阵confusion_matrix、precision、recall、f1-score分数

    多分类,计算混淆矩阵confusion_matrix,以及accuracy、precision、recall、f1-score分数。 1)使用sklearn计算并画出 混淆矩阵(confusion_matrix) ; 2)使用sklearn计算 accuracy(accuracy_score) ; 3)使用sklearn计算多分类的 precision、recall、f1-score分数 。以及计算每个类别的precision、recall、f1-

    2024年02月06日
    浏览(39)
  • Three.js矩阵`Matrix4` 简介

    参考资料:threejs中文网 threejs qq交流群:814702116 前面两节课,给大家介绍了模型矩阵的数学基础理论,下面给大家介绍Three.js的一个矩阵相关类 Matrix4 (4x4矩阵),并用 Matrix4 创建平移矩阵、旋转矩阵、缩放矩阵。 查看4x4矩阵 Matrix4 文档,你可以看到很多相关矩阵相关的数学几

    2024年04月25日
    浏览(37)
  • 混淆矩阵的生成

    混淆矩阵(Confusion Matrix)是一个二维表格,常用于评价分类模型的性能。在混淆矩阵中,每一列代表了预测值,每一行代表了真实值。因此,混淆矩阵中的每一个元素表示了一个样本被预测为某一类别的次数。混淆矩阵的构成如下: 预测值=正例 预测值=反例 真实值=正例 T

    2023年04月18日
    浏览(38)
  • 【np.bincount】np.bincount()用在分割领域生成混淆矩阵

    混淆矩阵:Confusion Matrix,用于直观展示每个类别的预测情况,能从中计算准确率(Accuracy)、精度(Precision)、召回率(Recall)、交并比(IoU)。 混淆矩阵是 n*n 的矩阵(n是类别),对角线上的是正确预测的数量。 每一行之和是该类的真实样本数量,每一列之和是预测为该类的样本数量

    2023年04月10日
    浏览(49)
  • 矩阵补充(matrix completion)

    这篇文章介绍矩阵补充(matrix completion),它是一种向量召回通道。矩阵补充的本质是对用户 ID 和物品 ID 做 embedding,并用两个 embedding 向量的內积预估用户对物品的兴趣。值得注意的是,矩阵补充存在诸多缺点,在实践中效果远不及双塔模型。 上篇文章介绍了embedding,它可

    2024年01月19日
    浏览(41)
  • 对角矩阵(diagonal matrix)

    对角矩阵(英语:diagonal matrix)是一个 主对角线之外的元素皆为 0 的矩阵。对角线上的元素可以为 0 或其他值。 对角矩阵参与矩阵乘法 矩阵 A 左乘一个对角矩阵 D,是分别用 D 的对角线元素分别作用于矩阵 A 的每一行; 相似地,矩阵 A 右乘一个对角矩阵 D,是分别将 D 的对

    2024年02月11日
    浏览(45)
  • Eigen-Matrix矩阵

    在Eigen中,所有矩阵和向量都是矩阵模板类的对象。向量只是矩阵的一种特殊情况,要么有一行,要么有一列。矩阵就是一个二维数表,可以有多行多列。 Matrix类有六个模板参数,但现在只需要了解前三个参数就足够了。剩下的三个参数都有默认值,我们暂时不碰它们,我们

    2024年03月09日
    浏览(64)
  • 混淆矩阵——矩阵可视化

    相关文章 混淆矩阵——评估指标计算 混淆矩阵——评估指标可视化 正例是指在分类问题中,被标记为目标类别的样本。在二分类问题中, 正例(Positive) 代表我们感兴趣的目标,而另一个类别定义为 反例(Negative) 举个栗子🌰,我们要区分苹果🍎和凤梨🍐。我们 想要

    2024年02月04日
    浏览(55)
  • leetcode 542. 01 Matrix(01矩阵)

    矩阵中只有0,1值,返回每个cell到最近的0的距离。 思路: 0元素到它自己的距离是0, 只需考虑1到最近的0是多少距离。 BFS. 先把元素1处的距离更新为无穷大。 0的位置装入queue。 从每个0出发,走上下左右4个方向,遇到0不需要处理,遇到1,距离为当前距离+1. 如果当前距离

    2024年02月12日
    浏览(36)
  • Eigen 矩阵Matrix及其简单操作

    在Eigen,所有的矩阵和向量都是Matrix模板类的对象,Vector只是一种特殊的矩阵(一行或者一列)。 Matrix有6个模板参数,主要使用前三个参数,剩下的有默认值。 Scalar是表示元素的类型,RowsAtCompileTime为矩阵的行,ColsAtCompileTime为矩阵的列。 库中提供了一些类型便于使用,比如

    2024年02月12日
    浏览(33)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包