Pytorch代码入门学习之分类任务(二):定义数据集

这篇具有很好参考价值的文章主要介绍了Pytorch代码入门学习之分类任务(二):定义数据集。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

一、导包

import torch
import torchvision
import torchvision.transforms as transforms

二、下载数据集

2.1 代码展示

# 定义数据加载进来后的初始化操作:
transform = transforms.Compose([
    # 张量转换:
    transforms.ToTensor(),
    # 归一化操作:
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
])
 
trainset = torchvision.datasets.CIFAR10(root='./data',train=True,download=True,transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data',train=False,download=True,transform=transform)
trainloader = torch.utils.data.DataLoader(trainset,batch_size=4,shuffle=True,num_workers=0)
testloader = torch.utils.data.DataLoader(testset,batch_size=4,shuffle=False,num_workers=0)

Pytorch代码入门学习之分类任务(二):定义数据集,Python相关知识,pytorch,学习,分类,深度学习,人工智能

2.2 数据集介绍与下载方式

       (1)数据集介绍:

        CIFAR10数据集共60000个样本,其中有50000个训练样本和10000个测试样本。每个样本都是一张32*32像素的RGB图像(彩色图像),每个图像分为3个通道(R通道、G通道与B通道)。 

        CIFAR10数据集用来进行监督学习训练,每个样本包含一个标签值,其中有10类物体,标签值按照0~9来区分,分别是飞机( airplane )、汽车( automobile )、鸟( bird )、猫( cat )、鹿( deer )、狗( dog )、青蛙( frog )、马( horse )、船( ship )和卡车( truck )。

        CIFAR10数据集的内容,如下图所示。

Pytorch代码入门学习之分类任务(二):定义数据集,Python相关知识,pytorch,学习,分类,深度学习,人工智能

        官网介绍链接:CIFAR-10 and CIFAR-100 datasets (toronto.edu)

        (2)下载方式:

        ①下载文件:下载地址:https://pan.baidu.com/s/1Nh28RyfwPNNfe_sS8NBNUA 

        提取码:1h4x

        ②将下载好的文件重命名为cifar-10-batches-py.tar.gz

        ③将文件保存至相应地址下即可

2.3 transforms.Compose

        transforms.Compose:相当于将所有需要的操作进行打包;

        transforms.ToTensor:完成张量转换,pytorch处理的都是tensor数据,需要将读入的图片转换为tensor,其中tensor比普通图片的三通道多了一个通道—batch;

       transforms.Normalize:归一化操作,对这一批次的数据进行归一,可以加速网络的收敛、放置梯度消失与梯度爆炸。

2.4 Dataset

        Dataset是指定义好数据的格式和数据变换的形式,完成一些初始化的变化,然后送给网络(相当于将数据读入进去)。

       torchvision.datasets.CIFAR10(调用数据集):第一个参数为数据集加载的地址、第二个参数为是否是训练数据或测试数据(训练数据为True,测试数据为False)、第三个为download-指该数据集是否本地下载,最后一个参数为要做哪些变化(transform是指数据变换格式,torchvision会将我们需要的数据进行格式变换)。

2.4 Dataloader

        Dataloader:负责用iterative迭代的方式不断读入批次数据,一批次一批次将数据进行打包,送入网络进行训练、学习、测试。

        torch.utils.data.DataLoader:第一个参数为数据,第二个参数为batch_size(代表Dataloader一次从这么多数据中拿多少个数据走),第三个参数为是否将数据打乱(训练的时候将数据打乱会让数据变得复杂,测试的时候不需要变得复杂),第四个参数为需要几线程进行读取数据(对于windows,默认为0就可以)

三、定义元组

        定义元组进行类别名的中文转换:

classes = ('airplane','automobile','bird','car','deer','dog','frog','horse','ship','truck')

四、定义显示函数与运行数据加载器

4.1 代码展示

import matplotlib.pyplot as plt
import numpy as np  # 用这个包中的根据将tensor数据转换成矩阵数据
 
def imshow(img):
    img = img / 2 + 0.5
    # tensor数据转换为numpy
    npimg = img.numpy()
    # 使用transpose进行数据转换-通道转换
    plt.imshow(np.transpose(npimg,(1,2,0)))
    plt.show()
 
dataiter = iter(trainloader)
images,labels = dataiter.next()
 
imshow(torchvision.utils.make_grid(images))
 
print(labels)
print(labels[0],classes[labels[0]])
print(' '.join(classes[labels[j]] for j in range(4)))

Pytorch代码入门学习之分类任务(二):定义数据集,Python相关知识,pytorch,学习,分类,深度学习,人工智能

4.2 定义显示函数

        tensor[batch,channel,H,W],而正常显示图片的顺序为H、W、channel,因此需要定义显示函数,通过反归一化才能变成正常的图片去显示。

4.3 定义迭代器

        iter(trainloader) : 定义迭代器,读一次迭代的数据(batch_size=4,所以迭代器一次会读取四张图片);

        torchvision.utils.make_grid:将多张图片拼接为一张图片。

参考:003 第一个分类任务1_哔哩哔哩_bilibili文章来源地址https://www.toymoban.com/news/detail-718398.html

到了这里,关于Pytorch代码入门学习之分类任务(二):定义数据集的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • Pytorch实现鸟类品种分类识别(含训练代码和鸟类数据集)

    目录 Pytorch实现鸟类识别(含训练代码和鸟类数据集) 1. 前言 2. 鸟类数据集 (1)Bird-Dataset26 (2)自定义数据集 3. 鸟类分类识别模型训练 (1)项目安装 (2)准备Train和Test数据 (3)配置文件:​config.yaml​ (4)开始训练 (5)可视化训练过程 (6)一些优化建议 (7) 一些运

    2024年02月09日
    浏览(63)
  • pytorch-神经网络-手写数字分类任务

    Mnist分类任务: 网络基本构建与训练方法,常用函数解析 torch.nn.functional模块 nn.Module模块 读取Mnist数据集 会自动进行下载 注意数据需转换成tensor才能参与后续建模训练 torch.nn.functional 很多层和函数在这里都会见到 torch.nn.functional中有很多功能,后续会常用的。那什么时候使

    2024年02月09日
    浏览(37)
  • 【Python机器学习】sklearn.datasets分类任务数据集

    如何选择合适的数据集进行机器学习的分类任务? 选择合适的数据集是进行任何机器学习项目的第一步,特别是分类任务。数据集是机器学习任务成功的基础。没有数据,最先进的算法也无从谈起。 本文将专注于 sklearn.datasets 模块中用于分类任务的数据集。这些数据集覆盖

    2024年02月07日
    浏览(45)
  • Pytorch实现中药材(中草药)分类识别(含训练代码和数据集)

    目录 Pytorch实现中药材(中草药)分类识别(含训练代码和数据集) 1. 前言 2. 中药材(中草药)数据集说明 (1)中药材(中草药)数据集:Chinese-Medicine-163 (2)自定义数据集 3. 中草药分类识别模型训练 (1)项目安装 (2)准备Train和Test数据 (3)配置文件: config.yaml (4)开始训练 (

    2023年04月13日
    浏览(40)
  • 基于PyTorch使用LSTM实现新闻文本分类任务

    PyTorch深度学习项目实战100例 https://weibaohang.blog.csdn.net/article/details/127154284?spm=1001.2014.3001.5501 基于PyTorch使用LSTM实现新闻文本分类任务的概况如下: 任务描述:新闻文本分类是一种常见的自然语言处理任务,旨在将新闻文章分为不同的类别,如政治、体育、科技等。 方法:使

    2024年02月09日
    浏览(43)
  • Python深度学习实战-基于tensorflow原生代码搭建BP神经网络实现分类任务(附源码和实现效果)

            前面两篇文章分别介绍了两种搭建神经网络模型的方法,一种是基于tensorflow的keras框架,另一种是继承父类自定义class类,本篇文章将编写原生代码搭建BP神经网络。 本人读研期间发表5篇SCI数据挖掘相关论文,现在某研究院从事数据挖掘相关科研工作,对数据挖掘

    2024年02月08日
    浏览(55)
  • Pytorch:搭建卷积神经网络完成MNIST分类任务:

    2023.7.18 MNIST百科: MNIST数据集简介与使用_bwqiang的博客-CSDN博客 数据集官网:MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges 数据集将按以图片和文件夹名为标签的形式保存:  代码:下载mnist数据集并转还为图片  训练代码: 测试代码: 分类正确率不错:

    2024年02月17日
    浏览(46)
  • 度学习pytorch实战六:ResNet50网络图像分类篇自建花数据集图像分类(5类)超详细代码

    1.数据集简介、训练集与测试集划分 2.模型相关知识 3.model.py——定义ResNet50网络模型 4.train.py——加载数据集并训练,训练集计算损失值loss,测试集计算accuracy,保存训练好的网络参数 5.predict.py——利用训练好的网络参数后,用自己找的图像进行分类测试 1.自建数据文件夹

    2024年02月09日
    浏览(41)
  • 分类任务使用Pytorch实现Grad-CAM绘制热力图

    对于深度学习网络,在我们指定数据集类别的情况下,Grad-CAM能够绘制出相应的热力图,让我们能够非常直观的看出网络关注的主要区域与特征是什么。本文主要记录在绘制热力图过程中,自己碰到的一些实际问题,希望能对小伙伴们有所帮助。 以下是本文的参考视频和代码

    2024年02月04日
    浏览(50)
  • 深度学习pytorch实战五:基于ResNet34迁移学习的方法图像分类篇自建花数据集图像分类(5类)超详细代码

    1.数据集简介 2.模型相关知识 3.split_data.py——训练集与测试集划分 4.model.py——定义ResNet34网络模型 5.train.py——加载数据集并训练,训练集计算损失值loss,测试集计算accuracy,保存训练好的网络参数 6.predict.py——利用训练好的网络参数后,用自己找的图像进行分类测试 1.自建

    2024年02月09日
    浏览(60)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包