【Python机器学习】实验15 将Lenet5应用于Cifar10数据集

这篇具有很好参考价值的文章主要介绍了【Python机器学习】实验15 将Lenet5应用于Cifar10数据集。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

CIFAR10数据集介绍

CIFAR-10 数据集由10个类别的60000张32x32彩色图像组成,每类6000张图像。有50000张训练图像和10000张测试图像。数据集分为五个训练批次
和一个测试批次,每个批次有10000张图像。测试批次包含从每个类别中随机选择的1000张图像。训练批次包含随机顺序的剩余图像,但一些训练批次
可能包含比另一个类别更多的图像。在它们之间训练批次包含来自每个类的5000张图像。以下是数据集中的类,以及每个类中的10张随机图像:
【Python机器学习】实验15 将Lenet5应用于Cifar10数据集,《 机器学习 》,python,机器学习,pytorch

因为CIFAR10数据集颜色通道有3个,所以卷积层L1的输入通道数量(in_channels)需要设为3。全连接层fc1的输入维度设为400,这与上例设为256有所不同,原因是初始输入数据的形状不一样,经过卷积池化后,输出的数据形状是不一样的。如果是采用动态图开发模型,那么有一种便捷的方式查看中间结果的形状,即在forward()方法中,用print函数把中间结果的形状打印出来。根据中间结果的形状,决定接下来各网络层的参数。
【Python机器学习】实验15 将Lenet5应用于Cifar10数据集,《 机器学习 》,python,机器学习,pytorch

1. 数据的下载

import torch
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
train_dataset = CIFAR10(root="./data/CIFAR10",train=True,transform=transforms.ToTensor(),download=True)
test_dataset = CIFAR10(root="./data/CIFAR10", train=False,transform=transforms.ToTensor())
Files already downloaded and verified
train_dataset[0][0].shape
torch.Size([3, 32, 32])
train_dataset[0][1]
6

2.修改模型与前面的参数设置保持一致

from torch import nn
class Lenet5(nn.Module):
    def __init__(self):
        super(Lenet5,self).__init__()
        #1+ 32-5/(1)==28
        self.features=nn.Sequential(
        #定义第一个卷积层
        nn.Conv2d(in_channels=3,out_channels=6,kernel_size=(5,5),stride=1),
        nn.ReLU(),
        nn.AvgPool2d(kernel_size=2,stride=2),
        #定义第二个卷积层
        nn.Conv2d(in_channels=6,out_channels=16,kernel_size=(5,5),stride=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2,stride=2),
        )
        
        #定义全连接层
        self.classfier=nn.Sequential(nn.Linear(in_features=400,out_features=120),
        nn.ReLU(),
        nn.Linear(in_features=120,out_features=84),
        nn.ReLU(),
        nn.Linear(in_features=84,out_features=10),  
        )
        
    def forward(self,x):
        x=self.features(x)
        x=torch.flatten(x,1)
        result=self.classfier(x)
        return result    

3. 新建模型

model=Lenet5()
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model=model.to(device)

4. 从数据集中分批量读取数据

#加载数据集
batch_size=32
train_loader= torch.utils.data.DataLoader(train_dataset, batch_size, shuffle=True)
test_loader= torch.utils.data.DataLoader(test_dataset, batch_size, shuffle=False)
# 类别信息也是需要我们给定的
classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')

5. 定义损失函数

from torch import optim
loss_fun=nn.CrossEntropyLoss()
loss_lst=[]

6. 定义优化器

optimizer=optim.SGD(params=model.parameters(),lr=0.001,momentum=0.9)

7. 开始训练

import time
start_time=time.time()
#训练的迭代次数
for epoch in range(10):
    loss_i=0
    for i,(batch_data,batch_label) in enumerate(train_loader):
        #清空优化器的梯度
        optimizer.zero_grad()
        #模型前向预测
        pred=model(batch_data)
        loss=loss_fun(pred,batch_label)
        loss_i+=loss
        loss.backward()
        optimizer.step()
        if (i+1)%200==0:
            print("第%d次训练,第%d批次,损失为%.2f"%(epoch,i,loss_i/200))
            loss_i=0
end_time=time.time()
print("共训练了%d 秒"%(end_time-start_time))
第0次训练,第199批次,损失为2.30
第0次训练,第399批次,损失为2.30
第0次训练,第599批次,损失为2.30
第0次训练,第799批次,损失为2.30
第0次训练,第999批次,损失为2.30
第0次训练,第1199批次,损失为2.30
第0次训练,第1399批次,损失为2.30
第1次训练,第199批次,损失为2.30
第1次训练,第399批次,损失为2.30
第1次训练,第599批次,损失为2.30
第1次训练,第799批次,损失为2.30
第1次训练,第999批次,损失为2.29
第1次训练,第1199批次,损失为2.27
第1次训练,第1399批次,损失为2.18
第2次训练,第199批次,损失为2.07
第2次训练,第399批次,损失为2.04
第2次训练,第599批次,损失为2.03
第2次训练,第799批次,损失为2.00
第2次训练,第999批次,损失为1.98
第2次训练,第1199批次,损失为1.96
第2次训练,第1399批次,损失为1.95
第3次训练,第199批次,损失为1.89
第3次训练,第399批次,损失为1.86
第3次训练,第599批次,损失为1.84
第3次训练,第799批次,损失为1.80
第3次训练,第999批次,损失为1.75
第3次训练,第1199批次,损失为1.71
第3次训练,第1399批次,损失为1.71
第4次训练,第199批次,损失为1.66
第4次训练,第399批次,损失为1.65
第4次训练,第599批次,损失为1.63
第4次训练,第799批次,损失为1.61
第4次训练,第999批次,损失为1.62
第4次训练,第1199批次,损失为1.60
第4次训练,第1399批次,损失为1.59
第5次训练,第199批次,损失为1.56
第5次训练,第399批次,损失为1.56
第5次训练,第599批次,损失为1.54
第5次训练,第799批次,损失为1.55
第5次训练,第999批次,损失为1.52
第5次训练,第1199批次,损失为1.52
第5次训练,第1399批次,损失为1.49
第6次训练,第199批次,损失为1.50
第6次训练,第399批次,损失为1.47
第6次训练,第599批次,损失为1.46
第6次训练,第799批次,损失为1.47
第6次训练,第999批次,损失为1.46
第6次训练,第1199批次,损失为1.43
第6次训练,第1399批次,损失为1.45
第7次训练,第199批次,损失为1.42
第7次训练,第399批次,损失为1.42
第7次训练,第599批次,损失为1.39
第7次训练,第799批次,损失为1.39
第7次训练,第999批次,损失为1.40
第7次训练,第1199批次,损失为1.40
第7次训练,第1399批次,损失为1.40
第8次训练,第199批次,损失为1.36
第8次训练,第399批次,损失为1.37
第8次训练,第599批次,损失为1.38
第8次训练,第799批次,损失为1.37
第8次训练,第999批次,损失为1.34
第8次训练,第1199批次,损失为1.37
第8次训练,第1399批次,损失为1.35
第9次训练,第199批次,损失为1.31
第9次训练,第399批次,损失为1.31
第9次训练,第599批次,损失为1.31
第9次训练,第799批次,损失为1.31
第9次训练,第999批次,损失为1.34
第9次训练,第1199批次,损失为1.32
第9次训练,第1399批次,损失为1.31
共训练了156 秒

8.测试模型

len(test_dataset)
10000
correct=0
for batch_data,batch_label in test_loader:
    pred_test=model(batch_data)
    pred_result=torch.max(pred_test.data,1)[1]
    correct+=(pred_result==batch_label).sum()
print("准确率为:%.2f%%"%(correct/len(test_dataset)))
准确率为:0.53%

9. 手写体图片的可视化

from torchvision import transforms as T
import torch
len(train_dataset)
50000
train_dataset[0][0].shape
torch.Size([3, 32, 32])
import matplotlib.pyplot as plt
plt.imshow(train_dataset[0][0][0],cmap="gray")
plt.axis('off')
(-0.5, 31.5, 31.5, -0.5)

【Python机器学习】实验15 将Lenet5应用于Cifar10数据集,《 机器学习 》,python,机器学习,pytorch

plt.imshow(train_dataset[0][0][0])
plt.axis('off')
(-0.5, 31.5, 31.5, -0.5)

【Python机器学习】实验15 将Lenet5应用于Cifar10数据集,《 机器学习 》,python,机器学习,pytorch

10. 多幅图片的可视化

from matplotlib import pyplot as plt
plt.figure(figsize=(20,15))
cols=10
rows=10
for i in range(0,rows):
    for j in range(0,cols):
        idx=j+i*cols
        plt.subplot(rows,cols,idx+1) 
        plt.imshow(train_dataset[idx][0][0])
        plt.axis('off')

【Python机器学习】实验15 将Lenet5应用于Cifar10数据集,《 机器学习 》,python,机器学习,pytorch

import numpy as np
img10 = np.stack(list(train_dataset[i][0][0] for i in range(10)), axis=1).reshape(32,320)
plt.imshow(img10)
plt.axis('off')
(-0.5, 319.5, 31.5, -0.5)

【Python机器学习】实验15 将Lenet5应用于Cifar10数据集,《 机器学习 》,python,机器学习,pytorch

img100 = np.stack( 
            tuple( np.stack(
                tuple( train_dataset[j*10+i][0][0] for i in range(10) ), 
                axis=1).reshape(32,320) for j in range(10)),
            axis=0).reshape(320,320)
plt.imshow(img100)
plt.axis('off')
(-0.5, 319.5, 319.5, -0.5)

【Python机器学习】实验15 将Lenet5应用于Cifar10数据集,《 机器学习 》,python,机器学习,pytorch

思考题

  • 测试集中有哪些识别错误的手写数字图片? 汇集整理并分析原因?

11. 读取测试集的图片预测值(神经网络的输出为10)

pre_result=torch.zeros(len(test_dataset),10)
for i in range(len(test_dataset)):
    pre_result[i,:]=model(torch.reshape(test_dataset[i][0],(-1,3,32,32)))
pre_result    
tensor([[-0.4934, -1.0982,  0.4072,  ..., -0.4038, -1.1655, -0.8201],
        [ 4.0154,  4.4736, -0.2921,  ..., -2.3925,  4.3176,  4.1910],
        [ 1.3858,  3.2022, -0.7004,  ..., -2.2767,  3.0923,  2.3740],
        ...,
        [-1.9551, -3.8085,  1.7917,  ...,  2.1104, -2.9573, -1.7387],
        [ 0.6681, -0.5328,  0.3059,  ...,  0.1170, -2.5236, -0.5746],
        [-0.5194, -2.6185,  1.1929,  ...,  3.7749, -2.3134, -1.5123]],
       grad_fn=<CopySlices>)
pre_result.shape
torch.Size([10000, 10])
pre_result[:5]
tensor([[-0.4934, -1.0982,  0.4072,  1.7331, -0.4456,  1.6433,  0.1721, -0.4038,
         -1.1655, -0.8201],
        [ 4.0154,  4.4736, -0.2921, -3.2882, -1.6234, -4.4814, -3.1241, -2.3925,
          4.3176,  4.1910],
        [ 1.3858,  3.2022, -0.7004, -1.0123, -1.7394, -1.6657, -3.2578, -2.2767,
          3.0923,  2.3740],
        [ 2.1151,  0.8262,  0.0071, -1.1410, -0.3051, -2.0239, -2.3023, -0.3573,
          2.9400,  0.5595],
        [-2.3524, -2.7907,  1.9834,  2.1088,  2.7645,  1.1118,  2.9782, -0.3876,
         -3.2325, -2.3916]], grad_fn=<SliceBackward0>)
#显示这10000张图片的标签
label_10000=[test_dataset[i][1] for i in range(10000)]
label_10000
[3,
 8,
 8,
 0,
 6,
 6,
 1,
 6,
 3,
 1,
 0,
 9,
 5,
 7,
 9,
 8,
 5,
 7,
 8,
 6,
 7,
 0,
 4,
 9,
 5,
 2,
 4,
 0,
 9,
 6,
 6,
 5,
 4,
 5,
 9,
 2,
 4,
 1,
 9,
 5,
 4,
 6,
 5,
 6,
 0,
 9,
 3,
 9,
 7,
 6,
 9,
 8,
 0,
 3,
 8,
 8,
 7,
 7,
 4,
 6,
 7,
 3,
 6,
 3,
 6,
 2,
 1,
 2,
 3,
 7,
 2,
 6,
 8,
 8,
 0,
 2,
 9,
 3,
 3,
 8,
 8,
 1,
 1,
 7,
 2,
 5,
 2,
 7,
 8,
 9,
 0,
 3,
 8,
 6,
 4,
 6,
 6,
 0,
 0,
 7,
 4,
 5,
 6,
 3,
 1,
 1,
 3,
 6,
 8,
 7,
 4,
 0,
 6,
 2,
 1,
 3,
 0,
 4,
 2,
 7,
 8,
 3,
 1,
 2,
 8,
 0,
 8,
 3,
 5,
 2,
 4,
 1,
 8,
 9,
 1,
 2,
 9,
 7,
 2,
 9,
 6,
 5,
 6,
 3,
 8,
 7,
 6,
 2,
 5,
 2,
 8,
 9,
 6,
 0,
 0,
 5,
 2,
 9,
 5,
 4,
 2,
 1,
 6,
 6,
 8,
 4,
 8,
 4,
 5,
 0,
 9,
 9,
 9,
 8,
 9,
 9,
 3,
 7,
 5,
 0,
 0,
 5,
 2,
 2,
 3,
 8,
 6,
 3,
 4,
 0,
 5,
 8,
 0,
 1,
 7,
 2,
 8,
 8,
 7,
 8,
 5,
 1,
 8,
 7,
 1,
 3,
 0,
 5,
 7,
 9,
 7,
 4,
 5,
 9,
 8,
 0,
 7,
 9,
 8,
 2,
 7,
 6,
 9,
 4,
 3,
 9,
 6,
 4,
 7,
 6,
 5,
 1,
 5,
 8,
 8,
 0,
 4,
 0,
 5,
 5,
 1,
 1,
 8,
 9,
 0,
 3,
 1,
 9,
 2,
 2,
 5,
 3,
 9,
 9,
 4,
 0,
 3,
 0,
 0,
 9,
 8,
 1,
 5,
 7,
 0,
 8,
 2,
 4,
 7,
 0,
 2,
 3,
 6,
 3,
 8,
 5,
 0,
 3,
 4,
 3,
 9,
 0,
 6,
 1,
 0,
 9,
 1,
 0,
 7,
 9,
 1,
 2,
 6,
 9,
 3,
 4,
 6,
 0,
 0,
 6,
 6,
 6,
 3,
 2,
 6,
 1,
 8,
 2,
 1,
 6,
 8,
 6,
 8,
 0,
 4,
 0,
 7,
 7,
 5,
 5,
 3,
 5,
 2,
 3,
 4,
 1,
 7,
 5,
 4,
 6,
 1,
 9,
 3,
 6,
 6,
 9,
 3,
 8,
 0,
 7,
 2,
 6,
 2,
 5,
 8,
 5,
 4,
 6,
 8,
 9,
 9,
 1,
 0,
 2,
 2,
 7,
 3,
 2,
 8,
 0,
 9,
 5,
 8,
 1,
 9,
 4,
 1,
 3,
 8,
 1,
 4,
 7,
 9,
 4,
 2,
 7,
 0,
 7,
 0,
 6,
 6,
 9,
 0,
 9,
 2,
 8,
 7,
 2,
 2,
 5,
 1,
 2,
 6,
 2,
 9,
 6,
 2,
 3,
 0,
 3,
 9,
 8,
 7,
 8,
 8,
 4,
 0,
 1,
 8,
 2,
 7,
 9,
 3,
 6,
 1,
 9,
 0,
 7,
 3,
 7,
 4,
 5,
 0,
 0,
 2,
 9,
 3,
 4,
 0,
 6,
 2,
 5,
 3,
 7,
 3,
 7,
 2,
 5,
 3,
 1,
 1,
 4,
 9,
 9,
 5,
 7,
 5,
 0,
 2,
 2,
 2,
 9,
 7,
 3,
 9,
 4,
 3,
 5,
 4,
 6,
 5,
 6,
 1,
 4,
 3,
 4,
 4,
 3,
 7,
 8,
 3,
 7,
 8,
 0,
 5,
 7,
 6,
 0,
 5,
 4,
 8,
 6,
 8,
 5,
 5,
 9,
 9,
 9,
 5,
 0,
 1,
 0,
 8,
 1,
 1,
 8,
 0,
 2,
 2,
 0,
 4,
 6,
 5,
 4,
 9,
 4,
 7,
 9,
 9,
 4,
 5,
 6,
 6,
 1,
 5,
 3,
 8,
 9,
 5,
 8,
 5,
 7,
 0,
 7,
 0,
 5,
 0,
 0,
 4,
 6,
 9,
 0,
 9,
 5,
 6,
 6,
 6,
 2,
 9,
 0,
 1,
 7,
 6,
 7,
 5,
 9,
 1,
 6,
 2,
 5,
 5,
 5,
 8,
 5,
 9,
 4,
 6,
 4,
 3,
 2,
 0,
 7,
 6,
 2,
 2,
 3,
 9,
 7,
 9,
 2,
 6,
 7,
 1,
 3,
 6,
 6,
 8,
 9,
 7,
 5,
 4,
 0,
 8,
 4,
 0,
 9,
 3,
 4,
 8,
 9,
 6,
 9,
 2,
 6,
 1,
 4,
 7,
 3,
 5,
 3,
 8,
 5,
 0,
 2,
 1,
 6,
 4,
 3,
 3,
 9,
 6,
 9,
 8,
 8,
 5,
 8,
 6,
 6,
 2,
 1,
 7,
 7,
 1,
 2,
 7,
 9,
 9,
 4,
 4,
 1,
 2,
 5,
 6,
 8,
 7,
 6,
 8,
 3,
 0,
 5,
 5,
 3,
 0,
 7,
 9,
 1,
 3,
 4,
 4,
 5,
 3,
 9,
 5,
 6,
 9,
 2,
 1,
 1,
 4,
 1,
 9,
 4,
 7,
 6,
 3,
 8,
 9,
 0,
 1,
 3,
 6,
 3,
 6,
 3,
 2,
 0,
 3,
 1,
 0,
 5,
 9,
 6,
 4,
 8,
 9,
 6,
 9,
 6,
 3,
 0,
 3,
 2,
 2,
 7,
 8,
 3,
 8,
 2,
 7,
 5,
 7,
 2,
 4,
 8,
 7,
 4,
 2,
 9,
 8,
 8,
 6,
 8,
 8,
 7,
 4,
 3,
 3,
 8,
 4,
 9,
 4,
 8,
 8,
 1,
 8,
 2,
 1,
 3,
 6,
 5,
 4,
 2,
 7,
 9,
 9,
 4,
 1,
 4,
 1,
 3,
 2,
 7,
 0,
 7,
 9,
 7,
 6,
 6,
 2,
 5,
 9,
 2,
 9,
 1,
 2,
 2,
 6,
 8,
 2,
 1,
 3,
 6,
 6,
 0,
 1,
 2,
 7,
 0,
 5,
 4,
 6,
 1,
 6,
 4,
 0,
 2,
 2,
 6,
 0,
 5,
 9,
 1,
 7,
 6,
 7,
 0,
 3,
 9,
 6,
 8,
 3,
 0,
 3,
 4,
 7,
 7,
 1,
 4,
 7,
 2,
 7,
 1,
 4,
 7,
 4,
 4,
 8,
 4,
 7,
 7,
 5,
 3,
 7,
 2,
 0,
 8,
 9,
 5,
 8,
 3,
 6,
 2,
 0,
 8,
 7,
 3,
 7,
 6,
 5,
 3,
 1,
 3,
 2,
 2,
 5,
 4,
 1,
 2,
 9,
 2,
 7,
 0,
 7,
 2,
 1,
 3,
 2,
 0,
 2,
 4,
 7,
 9,
 8,
 9,
 0,
 7,
 7,
 0,
 7,
 8,
 4,
 6,
 3,
 3,
 0,
 1,
 3,
 7,
 0,
 1,
 3,
 1,
 4,
 2,
 3,
 8,
 4,
 2,
 3,
 7,
 8,
 4,
 3,
 0,
 9,
 0,
 0,
 1,
 0,
 4,
 4,
 6,
 7,
 6,
 1,
 1,
 3,
 7,
 3,
 5,
 2,
 6,
 6,
 5,
 8,
 7,
 1,
 6,
 8,
 8,
 5,
 3,
 0,
 4,
 0,
 1,
 3,
 8,
 8,
 0,
 6,
 9,
 9,
 9,
 5,
 5,
 8,
 6,
 0,
 0,
 4,
 2,
 3,
 2,
 7,
 2,
 2,
 5,
 9,
 8,
 9,
 1,
 7,
 4,
 0,
 3,
 0,
 1,
 3,
 8,
 3,
 9,
 6,
 1,
 4,
 7,
 0,
 3,
 7,
 8,
 9,
 1,
 1,
 6,
 6,
 6,
 6,
 9,
 1,
 9,
 9,
 4,
 2,
 1,
 7,
 0,
 6,
 8,
 1,
 9,
 2,
 9,
 0,
 4,
 7,
 8,
 3,
 1,
 2,
 0,
 1,
 5,
 8,
 4,
 6,
 3,
 8,
 1,
 3,
 8,
 ...]
import numpy
pre_10000=pre_result.detach()
pre_10000
tensor([[-0.4934, -1.0982,  0.4072,  ..., -0.4038, -1.1655, -0.8201],
        [ 4.0154,  4.4736, -0.2921,  ..., -2.3925,  4.3176,  4.1910],
        [ 1.3858,  3.2022, -0.7004,  ..., -2.2767,  3.0923,  2.3740],
        ...,
        [-1.9551, -3.8085,  1.7917,  ...,  2.1104, -2.9573, -1.7387],
        [ 0.6681, -0.5328,  0.3059,  ...,  0.1170, -2.5236, -0.5746],
        [-0.5194, -2.6185,  1.1929,  ...,  3.7749, -2.3134, -1.5123]])
pre_10000=numpy.array(pre_10000)
pre_10000
array([[-0.49338394, -1.098238  ,  0.40724754, ..., -0.40375623,
        -1.165497  , -0.820113  ],
       [ 4.0153656 ,  4.4736323 , -0.29209492, ..., -2.392501  ,
         4.317573  ,  4.190993  ],
       [ 1.3858219 ,  3.2021556 , -0.70040375, ..., -2.2767155 ,
         3.092283  ,  2.373978  ],
       ...,
       [-1.9550545 , -3.808494  ,  1.7917161 , ...,  2.110389  ,
        -2.9572597 , -1.7386926 ],
       [ 0.66809845, -0.5327946 ,  0.30590305, ...,  0.11701592,
        -2.5236375 , -0.5746133 ],
       [-0.51935434, -2.6184506 ,  1.1929085 , ...,  3.7748828 ,
        -2.3134274 , -1.5123445 ]], dtype=float32)

12. 采用pandas可视化数据

import pandas as pd 
table=pd.DataFrame(zip(pre_10000,label_10000))
table
0 1
0 [-0.49338394, -1.098238, 0.40724754, 1.7330961... 3
1 [4.0153656, 4.4736323, -0.29209492, -3.2882178... 8
2 [1.3858219, 3.2021556, -0.70040375, -1.0123051... 8
3 [2.11508, 0.82618773, 0.007076204, -1.1409527,... 0
4 [-2.352432, -2.7906854, 1.9833877, 2.1087575, ... 6
... ... ...
9995 [-0.55809855, -4.3891077, -0.3040389, 3.001731... 8
9996 [-2.7151718, -4.1596007, 1.2393914, 2.8491826,... 3
9997 [-1.9550545, -3.808494, 1.7917161, 2.6365147, ... 5
9998 [0.66809845, -0.5327946, 0.30590305, -0.182045... 1
9999 [-0.51935434, -2.6184506, 1.1929085, 0.1288419... 7

10000 rows × 2 columns

table[0].values
array([array([-0.49338394, -1.098238  ,  0.40724754,  1.7330961 , -0.4455951 ,
               1.6433077 ,  0.1720748 , -0.40375623, -1.165497  , -0.820113  ],
             dtype=float32)                                                    ,
       array([ 4.0153656 ,  4.4736323 , -0.29209492, -3.2882178 , -1.6234205 ,
              -4.481386  , -3.1240807 , -2.392501  ,  4.317573  ,  4.190993  ],
             dtype=float32)                                                    ,
       array([ 1.3858219 ,  3.2021556 , -0.70040375, -1.0123051 , -1.7393746 ,
              -1.6656632 , -3.2578242 , -2.2767155 ,  3.092283  ,  2.373978  ],
             dtype=float32)                                                    ,
       ...,
       array([-1.9550545 , -3.808494  ,  1.7917161 ,  2.6365147 ,  0.37311587,
               3.545672  , -0.43889195,  2.110389  , -2.9572597 , -1.7386926 ],
             dtype=float32)                                                    ,
       array([ 0.66809845, -0.5327946 ,  0.30590305, -0.18204585,  2.0045712 ,
               0.47369143, -0.3122899 ,  0.11701592, -2.5236375 , -0.5746133 ],
             dtype=float32)                                                    ,
       array([-0.51935434, -2.6184506 ,  1.1929085 ,  0.1288419 ,  1.8770852 ,
               0.4296908 , -0.22015049,  3.7748828 , -2.3134274 , -1.5123445 ],
             dtype=float32)                                                    ],
      dtype=object)
table["pred"]=[np.argmax(table[0][i]) for i in range(table.shape[0])]
table
0 1 pred
0 [-0.49338394, -1.098238, 0.40724754, 1.7330961... 3 3
1 [4.0153656, 4.4736323, -0.29209492, -3.2882178... 8 1
2 [1.3858219, 3.2021556, -0.70040375, -1.0123051... 8 1
3 [2.11508, 0.82618773, 0.007076204, -1.1409527,... 0 8
4 [-2.352432, -2.7906854, 1.9833877, 2.1087575, ... 6 6
... ... ... ...
9995 [-0.55809855, -4.3891077, -0.3040389, 3.001731... 8 5
9996 [-2.7151718, -4.1596007, 1.2393914, 2.8491826,... 3 3
9997 [-1.9550545, -3.808494, 1.7917161, 2.6365147, ... 5 5
9998 [0.66809845, -0.5327946, 0.30590305, -0.182045... 1 4
9999 [-0.51935434, -2.6184506, 1.1929085, 0.1288419... 7 7

10000 rows × 3 columns

13. 对预测错误的样本点进行可视化

mismatch=table[table[1]!=table["pred"]]
mismatch
0 1 pred
1 [4.0153656, 4.4736323, -0.29209492, -3.2882178... 8 1
2 [1.3858219, 3.2021556, -0.70040375, -1.0123051... 8 1
3 [2.11508, 0.82618773, 0.007076204, -1.1409527,... 0 8
8 [0.02641207, -3.6653092, 2.294829, 2.2884543, ... 3 5
12 [-1.4556388, -1.7955011, -0.6100754, 1.169481,... 5 6
... ... ... ...
9989 [-0.2553262, -2.8777533, 3.4579017, 0.3079242,... 2 4
9993 [-0.077826336, -3.14616, 0.8994149, 3.5604722,... 5 3
9994 [-1.2543154, -2.4472265, 0.6754027, 2.0582433,... 3 6
9995 [-0.55809855, -4.3891077, -0.3040389, 3.001731... 8 5
9998 [0.66809845, -0.5327946, 0.30590305, -0.182045... 1 4

4657 rows × 3 columns

from matplotlib import pyplot as plt
plt.scatter(mismatch[1],mismatch["pred"])
<matplotlib.collections.PathCollection at 0x1b3a92ef910>

【Python机器学习】实验15 将Lenet5应用于Cifar10数据集,《 机器学习 》,python,机器学习,pytorch

14. 看看错误样本被预测为哪些数据?

mismatch[mismatch[1]==9].sort_values("pred").index
Int64Index([2129, 1465, 2907,  787, 2902, 2307, 4588, 5737, 8276, 8225,
            ...
            7635, 7553, 7526, 3999, 1626, 1639, 4193, 7198, 3957, 3344],
           dtype='int64', length=396)
idx_lst=mismatch[mismatch[1]==9].sort_values("pred").index.values
idx_lst,len(idx_lst)
(array([2129, 1465, 2907,  787, 2902, 2307, 4588, 5737, 8276, 8225, 8148,
        4836, 1155, 7218, 8034, 7412, 5069, 1629, 5094, 5109, 7685, 5397,
        1427, 5308, 8727, 2960, 2491, 6795, 1997, 6686, 9449, 6545, 8985,
        9401, 3564, 6034,  383, 9583, 9673,  507, 3288, 6868, 9133, 9085,
         577, 4261, 6974,  411, 6290, 5416, 5350, 5950, 5455, 5498, 6143,
        5964, 5864, 5877, 6188, 5939,   14, 5300, 3501, 3676, 3770, 3800,
        3850, 3893, 3902, 4233, 4252, 4253, 4276, 5335, 4297, 4418, 4445,
        4536, 4681, 6381, 4929, 4945, 5067, 5087, 5166, 5192, 4364, 4928,
        7024, 6542, 8144, 8312, 8385, 8406, 8453, 8465, 8521, 8585, 8673,
        8763, 8946, 9067, 9069, 9199, 9209, 9217, 9280, 9403, 9463, 9518,
        9692, 9743, 9871, 9875, 9881, 8066, 6509, 8057, 7826, 6741, 6811,
        6814, 6840, 6983, 7007, 3492, 7028, 7075, 7121, 7232, 7270, 7424,
        7431, 7444, 7492, 7499, 7501, 7578, 7639, 7729, 7767, 7792, 7818,
        7824, 7942, 3459, 4872, 1834, 1487, 1668, 1727, 1732, 1734, 1808,
        1814, 1815, 1831, 1927, 2111, 2126, 2190, 2246, 2290, 2433, 2596,
        2700, 2714, 1439, 1424, 1376, 1359,   28,  151,  172,  253,  259,
         335,  350,  591,  625, 2754,  734,  940,  951,  970, 1066, 1136,
        1177, 1199, 1222, 1231,  853, 2789, 9958, 2946, 3314, 3307, 2876,
        3208, 3166, 2944, 2817, 2305, 7522, 7155, 7220, 4590, 2899, 2446,
        2186, 7799, 9492, 3163, 4449, 2027, 2387, 1064, 3557, 2177,  654,
        9791, 2670, 2514, 2495, 3450, 8972, 3210, 3755, 2756, 7967, 3970,
        4550, 6017,  938,  744, 6951, 3397, 4852, 3133, 7931,  707, 3312,
        7470, 6871, 8292, 7100, 9529, 9100, 3853, 9060, 9732, 2521, 3789,
        2974, 5311, 3218, 5736, 3055, 7076, 1220, 9147, 1344,  532, 8218,
        3569, 1008, 8475, 8877, 1582, 8936, 4758, 1837, 9517,  252, 5832,
        1916, 6369, 4979, 9324, 6218, 9777, 7923, 4521, 2868,  213, 8083,
        5952, 5579, 4508, 5488, 2460, 5332, 5180, 8323, 8345, 3776, 2568,
        5151, 4570, 2854, 8488, 4874,  680, 2810, 1285, 6136, 3339, 9143,
        6852, 1906, 7067, 7073, 2975, 1924, 6804, 6755, 9299, 2019, 9445,
        9560,  360, 1601, 7297, 9122, 6377, 9214, 6167, 3980,  394, 7491,
        7581, 9349, 8953,  222,  139,  530, 3577, 9868,  247, 9099, 9026,
         209,  538, 3229, 9258,  585, 9204, 9643, 1492, 3609, 6570, 6561,
        6469, 6435, 6419, 2155, 6275, 4481, 2202, 1987, 2271, 2355, 2366,
        2432, 5400, 2497, 2727, 4931, 4619, 9884, 5902, 8796, 6848, 6960,
        8575, 8413,  981, 8272, 8145, 3172, 1221, 3168, 1256, 1889, 1291,
        3964, 7635, 7553, 7526, 3999, 1626, 1639, 4193, 7198, 3957, 3344],
       dtype=int64),
 396)
import numpy as np
img=np.stack(list(test_dataset[idx_lst[i]][0][0] for i in range(5)),axis=1).reshape(32,32*5)
plt.imshow(img)
plt.axis('off')
(-0.5, 159.5, 31.5, -0.5)

【Python机器学习】实验15 将Lenet5应用于Cifar10数据集,《 机器学习 》,python,机器学习,pytorch

#显示4行
import numpy as np
img20=np.stack(
    tuple(np.stack(
            tuple(test_dataset[idx_lst[i+j*5]][0][0] for i in range(5)),
        axis=1).reshape(32,32*5) for j in range(4)),axis=0).reshape(32*4,32*5)
plt.imshow(img20)
plt.axis('off')
(-0.5, 159.5, 127.5, -0.5)

【Python机器学习】实验15 将Lenet5应用于Cifar10数据集,《 机器学习 》,python,机器学习,pytorch文章来源地址https://www.toymoban.com/news/detail-671020.html

15.输出错误的模型类别

idx_lst=mismatch[mismatch[1]==9].index.values
table.iloc[idx_lst[:], 2].values
array([1, 1, 8, 1, 1, 8, 7, 8, 8, 6, 1, 1, 1, 1, 7, 0, 7, 0, 0, 8, 6, 8,
       0, 8, 1, 1, 3, 7, 5, 1, 4, 0, 1, 4, 1, 1, 1, 8, 6, 3, 1, 1, 0, 1,
       1, 6, 8, 1, 1, 8, 7, 8, 6, 1, 1, 1, 0, 1, 0, 1, 8, 6, 7, 8, 0, 8,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 6, 8, 7, 6, 7, 1, 8, 0, 7, 3, 1, 1, 0,
       8, 3, 3, 1, 8, 1, 8, 1, 2, 0, 8, 8, 3, 8, 1, 3, 7, 0, 3, 8, 3, 5,
       7, 1, 3, 1, 1, 8, 1, 3, 1, 7, 1, 7, 7, 1, 3, 0, 0, 1, 1, 0, 5, 7,
       6, 4, 3, 1, 8, 8, 1, 3, 5, 8, 0, 1, 5, 1, 7, 8, 4, 3, 1, 1, 1, 3,
       0, 6, 8, 8, 1, 3, 1, 7, 5, 1, 1, 5, 1, 1, 8, 8, 4, 7, 8, 8, 1, 1,
       1, 0, 1, 1, 1, 1, 1, 3, 8, 7, 7, 1, 4, 7, 0, 2, 8, 1, 6, 0, 4, 1,
       7, 1, 1, 8, 1, 6, 1, 0, 1, 0, 0, 7, 1, 7, 1, 1, 0, 5, 7, 1, 1, 0,
       8, 1, 1, 7, 1, 7, 5, 0, 6, 1, 1, 8, 1, 1, 7, 1, 4, 0, 7, 1, 7, 1,
       6, 8, 1, 6, 7, 1, 8, 8, 8, 1, 1, 0, 8, 8, 0, 1, 7, 0, 7, 1, 1, 1,
       8, 7, 0, 5, 4, 8, 0, 1, 1, 1, 1, 7, 7, 1, 6, 5, 1, 2, 8, 0, 2, 1,
       1, 7, 0, 1, 1, 1, 5, 7, 1, 1, 1, 2, 8, 8, 1, 7, 8, 1, 0, 1, 1, 1,
       3, 1, 1, 1, 7, 4, 1, 4, 0, 1, 1, 7, 1, 8, 0, 6, 0, 8, 0, 5, 1, 7,
       7, 1, 1, 8, 1, 1, 6, 7, 1, 8, 1, 1, 0, 1, 8, 6, 6, 1, 8, 3, 0, 8,
       5, 1, 1, 0, 8, 5, 7, 0, 7, 6, 1, 8, 1, 7, 1, 8, 1, 7, 6, 8, 0, 1,
       7, 0, 1, 3, 6, 1, 5, 7, 0, 8, 0, 1, 5, 1, 6, 3, 8, 1, 1, 1, 8, 1],
      dtype=int64)
arr2=table.iloc[idx_lst[:], 2].values
print('错误模型共' + str(len(arr2)) + '个')
for i in range(33):
    for j in range(12):
        print(classes[arr2[j+i*12]],end=" ")
    print()
错误模型共396个
car car ship car car ship horse ship ship frog car car 
car car horse plane horse plane plane ship frog ship plane ship 
car car cat horse dog car deer plane car deer car car 
car ship frog cat car car plane car car frog ship car 
car ship horse ship frog car car car plane car plane car 
ship frog horse ship plane ship car car car car car car 
car car car frog ship horse frog horse car ship plane horse 
cat car car plane ship cat cat car ship car ship car 
bird plane ship ship cat ship car cat horse plane cat ship 
cat dog horse car cat car car ship car cat car horse 
car horse horse car cat plane plane car car plane dog horse 
frog deer cat car ship ship car cat dog ship plane car 
dog car horse ship deer cat car car car cat plane frog 
ship ship car cat car horse dog car car dog car car 
ship ship deer horse ship ship car car car plane car car 
car car car cat ship horse horse car deer horse plane bird 
ship car frog plane deer car horse car car ship car frog 
car plane car plane plane horse car horse car car plane dog 
horse car car plane ship car car horse car horse dog plane 
frog car car ship car car horse car deer plane horse car 
horse car frog ship car frog horse car ship ship ship car 
car plane ship ship plane car horse plane horse car car car 
ship horse plane dog deer ship plane car car car car horse 
horse car frog dog car bird ship plane bird car car horse 
plane car car car dog horse car car car bird ship ship 
car horse ship car plane car car car cat car car car 
horse deer car deer plane car car horse car ship plane frog 
plane ship plane dog car horse horse car car ship car car 
frog horse car ship car car plane car ship frog frog car 
ship cat plane ship dog car car plane ship dog horse plane 
horse frog car ship car horse car ship car horse frog ship 
plane car horse plane car cat frog car dog horse plane ship 
plane car dog car frog cat ship car car car ship car 

附:系列文章

实验 目录 直达链接
1 Numpy以及可视化回顾 https://want595.blog.csdn.net/article/details/131891689
2 线性回归 https://want595.blog.csdn.net/article/details/131892463
3 逻辑回归 https://want595.blog.csdn.net/article/details/131912053
4 多分类实践(基于逻辑回归) https://want595.blog.csdn.net/article/details/131913690
5 机器学习应用实践-手动调参 https://want595.blog.csdn.net/article/details/131934812
6 贝叶斯推理 https://want595.blog.csdn.net/article/details/131947040
7 KNN最近邻算法 https://want595.blog.csdn.net/article/details/131947885
8 K-means无监督聚类 https://want595.blog.csdn.net/article/details/131952371
9 决策树 https://want595.blog.csdn.net/article/details/131991014
10 随机森林和集成学习 https://want595.blog.csdn.net/article/details/132003451
11 支持向量机 https://want595.blog.csdn.net/article/details/132010861
12 神经网络-感知器 https://want595.blog.csdn.net/article/details/132014769
13 基于神经网络的回归-分类实验 https://want595.blog.csdn.net/article/details/132127413
14 手写体卷积神经网络 https://want595.blog.csdn.net/article/details/132223494
15 将Lenet5应用于Cifar10数据集 https://want595.blog.csdn.net/article/details/132223751
16 卷积、下采样、经典卷积网络 https://want595.blog.csdn.net/article/details/132223985

到了这里,关于【Python机器学习】实验15 将Lenet5应用于Cifar10数据集的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • Python机器学习实验 Python 数据分析

    掌握常见数据预处理方法,熟练运用数据分析方法,并掌握 Python 中的 Numpy、 Pandas 模块提供的数据分析方法。 1.   P a nd a s   基本数据处理 使用 Pandas 模块,完成以下操作。 (1)创建一个由 0 到 50 之间的整数组成的 10 行 5 列的 dataframe。如下: (2)汇总每一列的

    2024年04月09日
    浏览(40)
  • 【Python机器学习】实验03 逻辑回归

    在这一次练习中,我们将要实现逻辑回归并且应用到一个分类任务。我们还将通过将正则化加入训练算法,来提高算法的鲁棒性,并用更复杂的情形来测试它。 本实验的数据包含两个变量(评分1和评分2,可以看作是特征),某大学的管理者,想通过申请学生两次测试的评分,来

    2024年02月11日
    浏览(43)
  • 【Python机器学习】实验09 决策树

    1.分类决策树模型是表示基于特征对实例进行分类的树形结构。决策树可以转换成一个 if-then 规则的集合,也可以看作是定义在特征空间划分上的类的条件概率分布。 2.决策树学习旨在构建一个与训练数据拟合很好,并且复杂度小的决策树。因为从可能的决策树中直接选取

    2024年02月12日
    浏览(35)
  • 【Python机器学习】实验08 决策树

    1.分类决策树模型是表示基于特征对实例进行分类的树形结构。决策树可以转换成一个 if-then 规则的集合,也可以看作是定义在特征空间划分上的类的条件概率分布。 2.决策树学习旨在构建一个与训练数据拟合很好,并且复杂度小的决策树。因为从可能的决策树中直接选取

    2024年02月14日
    浏览(35)
  • python与深度学习(九):CNN和cifar10

    本篇文章是CNN的另外一个例子,cifar10,是彩色的十分类数据集。 可以搭建卷积神经网络来训练模型。 以下第三方库是python专门用于深度学习的库 把cifar10数据集进行加载 x_train是fashion_mnist训练集图片,大小的32 32 3的,y_train是对应的标签是数字。 x_test是fashion_mnist测试集图片,大

    2024年02月15日
    浏览(39)
  • 【Python机器学习】实验06 KNN最近邻算法

    1. k k k 近邻法是基本且简单的分类与回归方法。 k k k 近邻法的基本做法是:对给定的训练实例点和输入实例点,首先确定输入实例点的 k k k 个最近邻训练实例点,然后利用这 k k k 个训练实例点的类的多数来预测输入实例点的类。 2. k k k 近邻模型对应于基于训练数据集对

    2024年02月15日
    浏览(51)
  • 【Python机器学习】实验11 神经网络-感知器

    1.感知机是根据输入实例的特征向量 x x x 对其进行二类分类的线性分类模型: f ( x ) = sign ⁡ ( w ⋅ x + b ) f(x)=operatorname{sign}(w cdot x+b) f ( x ) = sign ( w ⋅ x + b ) 感知机模型对应于输入空间(特征空间)中的分离超平面 w ⋅ x + b = 0 w cdot x+b=0 w ⋅ x + b = 0 。 2.感知机学习的策略

    2024年02月13日
    浏览(40)
  • 【Python机器学习】实验04 多分类实践(基于逻辑回归)

    Iris数据集是常用的分类实验数据集,由Fisher, 1936收集整理。Iris也称鸢尾花卉数据集,是一类多重变量分析的数据集。数据集包含150个数据样本,分为3类,每类50个数据,每个数据包含4个属性。可通过花萼长度,花萼宽度,花瓣长度,花瓣宽度4个属性预测鸢尾花卉属于(Se

    2024年02月11日
    浏览(47)
  • 【Python机器学习】实验16 卷积、下采样、经典卷积网络

    tf.Tensor: id=8263, shape=(2, 8), dtype=float32, numpy= array([[0.16938789, 0. , 0.08883161, 0.14095941, 0.34751543, 0.353898 , 0. , 0.13356908], [0. , 0. , 0.48546872, 0.37623546, 0.5447475 , 0.21755993, 0.40121362, 0. ]], dtype=float32) tf.Tensor: id=8296, shape=(4, 5), dtype=float32, numpy= array([[0. , 0. , 0. , 0.14286758, 0. ], [0. , 2.2727172 , 0. , 0

    2024年02月11日
    浏览(44)
  • 【Python机器学习】实验04(1) 多分类(基于逻辑回归)实践

    Iris数据集是常用的分类实验数据集,由Fisher, 1936收集整理。Iris也称鸢尾花卉数据集,是一类多重变量分析的数据集。数据集包含150个数据样本,分为3类,每类50个数据,每个数据包含4个属性。可通过花萼长度,花萼宽度,花瓣长度,花瓣宽度4个属性预测鸢尾花卉属于(Se

    2024年02月14日
    浏览(47)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包