Pytorch的torch.utils.data中Dataset以及DataLoader等详解

这篇具有很好参考价值的文章主要介绍了Pytorch的torch.utils.data中Dataset以及DataLoader等详解。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

在我们进行深度学习的过程中,不免要用到数据集,那么数据集是如何加载到我们的模型中进行训练的呢?以往我们大多数初学者肯定都是拿网上的代码直接用,但是它底层的原理到底是什么还是不太清楚。所以今天就从内置的Dataset函数和自定义的Dataset函数做一个详细的解析。

前言

torch.utils.dataPyTorch提供的一个模块,用于处理和加载数据。该模块提供了一系列工具类和函数,用于创建、操作和批量加载数据集。

下面是 torch.utils.data 模块中一些常用的类和函数:

  • Dataset: 定义了抽象的数据集类,用户可以通过继承该类来构建自己的数据集。Dataset 类提供了两个必须实现的方法:__getitem__ 用于访问单个样本,__len__ 用于返回数据集的大小。
  • TensorDataset: 继承自 Dataset 类,用于将张量数据打包成数据集。它接受多个张量作为输入,并按照第一个输入张量的大小来确定数据集的大小。
  • DataLoader: 数据加载器类,用于批量加载数据集。它接受一个数据集对象作为输入,并提供多种数据加载和预处理的功能,如设置批量大小、多线程数据加载和数据打乱等。
  • Subset: 数据集的子集类,用于从数据集中选择指定的样本。
  • random_split: 将一个数据集随机划分为多个子集,可以指定划分的比例或指定每个子集的大小。
  • ConcatDataset: 将多个数据集连接在一起形成一个更大的数据集。
  • get_worker_info: 获取当前数据加载器所在的进程信息。

除了上述的类和函数之外,torch.utils.data 还提供了一些常用的数据预处理的工具,如随机裁剪、随机旋转、标准化等。

通过 torch.utils.data 模块提供的类和函数,可以方便地加载、处理和批量加载数据,为模型训练和验证提供了便利。但是,我们最常用的两个类还是DatasetDataLoader类。

1、自定义Dataset类

torch.utils.data.Dataset是 PyTorch 中用于表示数据集的抽象类,用于定义数据集的访问方式和样本数量。

Dataset 类是一个基类,我们可以通过继承该类并实现下面两个方法来创建自定义的数据集类:

getitem(self, index): 根据给定的索引 index,返回对应的样本数据。索引可以是一个整数,表示按顺序获取样本,也可以是其他方式,如通过文件名获取样本等。
len(self): 返回数据集中样本的数量。

import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __getitem__(self, index):
        # 根据索引获取样本
        return self.data[index]

    def __len__(self):
        # 返回数据集大小
        return len(self.data)

# 创建数据集对象
data = [1, 2, 3, 4, 5]
dataset = MyDataset(data)

# 根据索引获取样本
sample = dataset[2]
print(sample)
# 3

上面的代码样例主要实现的是一个自定义Dataset数据集类的方法,这一般都是在我们需要训练自己的数据时候需要定义的。但是一般我们作为深度学习初学者来讲,使用的都是MNIST、CIFAR-10等内置数据集,这时候就不需要再自己定义Dataset类了。至于为什么,我们下面进行详解。

2、torchvision.datasets

如果要使用PyTorch中的内置数据集,通常是通过torchvision.datasets模块来实现。torchvision.datasets模块提供了许多常用的计算机视觉数据集,如MNIST、CIFAR10、ImageNet等。

下面是使用内置数据集的示例代码:

import torch
from torchvision import datasets, transforms

# 定义数据转换
transform = transforms.Compose([
    transforms.ToTensor(),  # 将图像转换为张量
    transforms.Normalize((0.5,), (0.5,))  # 标准化图像
])

# 加载MNIST数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

在上述代码中,我们实现的便是一个内置MNIST(手写数字)数据集的加载和使用。可以看到,我们在这里面并未用到上面所提到的torch.utils.data.Dataset类,这是为什么呢?

这是因为在 torchvision.datasets 模块中,内置的数据集类已经实现了torch.utils.data.Dataset 接口,并直接返回一个可用的数据集对象。因此,在使用内置数据集时,我们可以直接实例化内置数据集类,而不需要显式地继承 torch.utils.data.Dataset 类。

内置数据集类(如 torchvision.datasets.MNIST)的实现已经包含了对 __getitem____len__ 方法的定义,这使得我们可以直接从内置数据集对象中获取样本和确定数据集的大小。这样,我们在使用内置数据集时可以直接将内置数据集对象传递给 torch.utils.data.DataLoader 进行数据加载和批量处理。

在内置数据集的背后,它们仍然是基于 torch.utils.data.Dataset 类进行实现,只是为了方便使用和提供更多功能,PyTorch 将这些常用数据集封装成了内置的数据集类。

为此,我专门到pytorch官网去查看了该内置数据集的加载代码,如下图所示:
Pytorch的torch.utils.data中Dataset以及DataLoader等详解,深度学习基础,零基础深度学习项目实战,pytorch,人工智能,python
可以看出,确实以及内置了Dataset数据集类。

3、DataLoader

torch.utils.data.DataLoader 是 PyTorch 中用于批量加载数据的工具类。它接受一个数据集对象(如 torch.utils.data.Dataset 的子类)并提供多种功能,如数据加载、批量处理、数据打乱等。

以下是 torch.utils.data.DataLoader 的常用参数和功能:

  • dataset: 数据集对象,可以是 torch.utils.data.Dataset 的子类对象。
  • batch_size: 每个批次的样本数量,默认为 1。
  • shuffle: 是否对数据进行打乱,默认为 False。在每个 epoch 时会重新打乱数据。
  • num_workers: 使用多少个子进程加载数据,默认为 0,表示在主进程中加载数据。其实在Windows系统里面都设置为0,但是在Linux中可以设置成大于0的数。
  • collate_fn: 在返回批次数据之前,对每个样本进行处理的函数。如果为 None,默认使用 torch.utils.data._utils.collate.default_collate 函数进行处理。
  • drop_last: 是否丢弃最后一个样本数量不足一个批次的数据,默认为 False
  • pin_memory: 是否将加载的数据存放在 CUDA 对应的固定内存中,默认为 False
  • prefetch_factor: 预取因子,用于预取数据到设备,默认为 2。
  • persistent_workers: 如果为 True,则在每个 epoch 中使用持久的子进程进行数据加载,默认为 False

示例代码如下:

import torch
from torchvision import datasets, transforms

# 定义数据转换
transform = transforms.Compose([
    transforms.ToTensor(),  # 将图像转换为张量
    transforms.Normalize((0.5,), (0.5,))  # 标准化图像
])

# 加载MNIST数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# 创建数据加载器
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)

# 使用数据加载器迭代样本
for images, labels in train_loader:
    # 训练模型的代码
    ...

4、torchvision.transforms

torchvision.transforms模块是PyTorch中用于图像数据预处理的功能模块。它提供了一系列的转换函数,用于在加载、训练或推断图像数据时进行各种常见的数据变换和增强操作。下面是一些常用的转换函数的详细解释:

  1. Resize:调整图像大小

    • Resize(size):将图像调整为给定的尺寸。可以接受一个整数作为较短边的大小,也可以接受一个元组或列表作为图像的目标大小。
  2. ToTensor:将图像转换为张量

    • ToTensor():将图像转换为张量,像素值范围从0-255映射到0-1。适用于将图像数据传递给深度学习模型。
  3. Normalize:标准化图像数据

    • Normalize(mean, std):对图像数据进行标准化处理。传入的mean和std是用于像素值归一化的均值和标准差。需要注意的是,mean和std需要与之前使用的数据集相对应。
  4. RandomHorizontalFlip:随机水平翻转图像

    • RandomHorizontalFlip(p=0.5):以给定的概率对图像进行随机水平翻转。概率p控制翻转的概率,默认为0.5。
  5. RandomCrop:随机裁剪图像

    • RandomCrop(size, padding=None):随机裁剪图像为给定的尺寸。可以提供一个元组或整数作为目标尺寸,并可选地提供填充值。
  6. ColorJitter:颜色调整

    • ColorJitter(brightness=0, contrast=0, saturation=0, hue=0):随机调整图像的亮度、对比度、饱和度和色调。可以通过设置不同的参数来调整图像的样貌。

在使用的时候,我们常常通过transforms.Compose来对这些数据处理操作进行一个组合,使用的时候,直接调用该组合即可。

示例代码如下:

from torchvision import transforms

# 定义图像预处理操作
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # 缩放图像大小为 (256, 256)
    transforms.RandomCrop((224, 224)),  # 随机裁剪图像为 (224, 224)
    transforms.RandomHorizontalFlip(),  # 随机水平翻转图像
    transforms.ToTensor(),  # 将图像转换为张量
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 标准化图像
])

# 对图像进行预处理
image = transform(image)

5、图像分类中Dataset数据集类的定义

就拿眼疾数据集来说(详细可看深度学习实战基础案例——卷积神经网络(CNN)基于SqueezeNet的眼疾识别|第1例),其中我们对数据集进行标签划分以后,生成了train.txt以及valid.txt文件,该文件中分别为两列,第一列为数据集的路径,第二列为数据集的标签(也就是类别),具体如下:
Pytorch的torch.utils.data中Dataset以及DataLoader等详解,深度学习基础,零基础深度学习项目实战,pytorch,人工智能,python
这时候我们就可以定义自己的数据集读取类,具体代码如下:

import os.path
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import transforms

transform_BZ = transforms.Normalize(
    mean=[0.5, 0.5, 0.5],
    std=[0.5, 0.5, 0.5]
)


class MyDataset(Dataset):
    def __init__(self, txt_path, train_flag=True):
        self.imgs_info = self.get_images(txt_path)
        self.train_flag = train_flag

        self.train_tf = transforms.Compose([
            transforms.Resize(224),  # 调整图像大小为224x224
            transforms.RandomHorizontalFlip(),  # 随机左右翻转图像
            transforms.RandomVerticalFlip(),  # 随机上下翻转图像
            transforms.ToTensor(),  # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间
            transform_BZ  # 执行某些复杂变换操作
        ])
        self.val_tf = transforms.Compose([
            transforms.Resize(224),  # 调整图像大小为224x224
            transforms.ToTensor(),  # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间
            transform_BZ  # 执行某些复杂变换操作
        ])

    def get_images(self, txt_path):
        with open(txt_path, 'r', encoding='utf-8') as f:
            imgs_info = f.readlines()
            imgs_info = list(map(lambda x: x.strip().split(' '), imgs_info))
        return imgs_info

    def __getitem__(self, index):
        img_path, label = self.imgs_info[index]

        img_path = os.path.join('', img_path)
        img = Image.open(img_path)
        img = img.convert("RGB")
        if self.train_flag:
            img = self.train_tf(img)
        else:
            img = self.val_tf(img)
        label = int(label)
        return img, label

    def __len__(self):
        return len(self.imgs_info)

定义完我们自己的数据集读取类以后,就可以将我们的txt文件传入进行数据集的预处理以及读取工作。在我们的自定义dataset类里面,最重要的三个方法是__init__()、getitem()以及__len__(),这三个缺一不可。同时,transforms的数据增强操作也不是必须的,这不过是提高模型性能的一个方法而已,但是我们现在的模型训练过程一般都会加上数据增强操作。

# 加载训练集和验证集
train_data = MyDataset(r"F:\SqueezeNet\train.txt", True)
train_dl = torch.utils.data.DataLoader(train_data, batch_size=16, pin_memory=True,
                                           shuffle=True, num_workers=0)
test_data = MyDataset(r"F:\SqueezeNet\valid.txt", False)
test_dl = torch.utils.data.DataLoader(test_data, batch_size=16, pin_memory=True,
                                           shuffle=True, num_workers=0)

上面,我们通过自定义的MyDataset类,分别加载了我们的train.txt文件以及valid.txt文件(后面的True参数代表我们要进行训练集的数据增强,而False代表进行验证集的数据增强)。然后,我们再通过我们的DataLoader来进行数据集的批量加载,之后就可以直接把加载好的 train_dl test_dl扔进模型里面训练。


具体实例可参考:文章来源地址https://www.toymoban.com/news/detail-667114.html

  • 深度学习实战基础案例——卷积神经网络(CNN)基于SqueezeNet的眼疾识别|第1例
  • Xception算法解析-鸟类识别实战-Paddle实战

到了这里,关于Pytorch的torch.utils.data中Dataset以及DataLoader等详解的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 一文弄懂Pytorch的DataLoader,Dataset,Sampler之间的关系

    以下内容都是针对Pytorch 1.0-1.1介绍。 很多文章都是从Dataset等对象自下往上进行介绍,但是对于初学者而言,其实这并不好理解,因为有的时候会不自觉地陷入到一些细枝末节中去,而不能把握重点,所以本文将会 自上而下 地对Pytorch数据读取方法进行介绍。 首先我们看一下

    2024年02月01日
    浏览(52)
  • 【Python从入门到人工智能】详解 PyTorch数据读取机制 DataLoader & Dataset(以人民币-RMB二分类实战 为例讲解,含完整源代码+问题解决)| 附:文心一言测试

      我想此后只要能以工作赚得生活费,不受意外的气,又有一点自己玩玩的余暇,就可以算是万分幸福了。                                                              ———《两地书》   🎯作者主页: 追光者♂🔥          🌸个人简介:

    2024年02月11日
    浏览(56)
  • pytorch 训练过程内存泄露/显存泄露debug记录:dataloader和dataset导致的泄露

    微调 mask-rcnn 代码,用的是 torchvision.models.detection.maskrcnn_resnet50_fpn 代码,根据该代码的注释,输入应该是: images, targets=None (List[Tensor], Optional[List[Dict[str, Tensor]]]) - Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]] 所以我写的 dataset 是这样的: 大概思路是: 先把所有的标注信息读入内存

    2024年02月14日
    浏览(51)
  • 完美解决 AttributeError: module ‘torch.utils‘ has no attribute ‘data‘

    完美解决 AttributeError: module ‘torch.utils’ has no attribute ‘data’ 下滑查看解决方法 AttributeError: module ‘torch.utils‘ has no attribute ‘data‘ 这个错误通常是由于使用了过时的torch版本导致的。在旧的torch版本中,torch.utils.data模块是存在的,但在新版的torch中已经被移除,因此会出现

    2024年02月07日
    浏览(48)
  • 深度学习技术栈 —— Pytorch之TensorDataset、DataLoader

    简单来说, TensorDataset 与 DataLoader 这两个类的作用, 就是将数据读入并做整合,以便交给模型处理。就像石油加工厂一样,你不关心石油是如何采集与加工的,你关心的是自己去哪加油,油价是多少,对于一个模型而言,DataLoader就是这样的一个予取予求的数据服务商。 参考

    2024年01月24日
    浏览(44)
  • 【深度学习】PyTorch的dataloader制作自定义数据集

    PyTorch的dataloader是用于读取训练数据的工具,它可以自动将数据分割成小batch,并在训练过程中进行数据预处理。以下是制作PyTorch的dataloader的简单步骤: 导入必要的库 定义数据集类 需要自定义一个继承自 torch.utils.data.Dataset 的类,在该类中实现 __len__ 和 __getitem__ 方法。 创建

    2024年02月10日
    浏览(55)
  • 【代码笔记】Pytorch学习 DataLoader模块详解

    dataloader主要有6个class构成(可见下图) _DatasetKind: _InfiniteConstantSampler: DataLoader: _BaseDataLoaderIter: _SingleProcessDataLoaderIter: _MultiProcessingDataLoaderIter: 我们首先看一下DataLoader的整体结构: init : _get_iterator: multiprocessing_context: multiprocessing_context: setattr : iter : _auto_collation: _ind

    2023年04月11日
    浏览(39)
  • 关于Dataset和DataLoader的概念

    在机器学习中,Dataset和DataLoader是两个很重要的概念,它们通常用于训练和测试模型时的数据处理。 Dataset是指用于存储和管理数据的类。在深度学习中,通常将数据存储在Dataset中,并使用Dataset提供的方法读取和处理数据。Dataset可以是各种类型的数据,例如图像、文本、音频

    2023年04月12日
    浏览(65)
  • 【深度学习】Pytorch 系列教程(十二):PyTorch数据结构:4、数据集(Dataset)

             目录 一、前言 二、实验环境 三、PyTorch数据结构 0、分类 1、张量(Tensor) 2、张量操作(Tensor Operations) 3、变量(Variable) 4、数据集(Dataset) 随机洗牌           ChatGPT:         PyTorch是一个开源的机器学习框架,广泛应用于深度学习领域。它提供了丰富

    2024年02月07日
    浏览(44)
  • dataset dataloader tensor list情况

    如上面代码所示,getitem期望返回一个tensor list 一个tensor,但是调用dataloader时只能接收到一个list,从打印内容中可以看到,dataloader中将getitem中返回的两个值都合并了。 从上述验证代码中可以看到,即使是相同类型的返回值,也不能分开来接收,返回的具体值为: 程序也会自

    2024年02月11日
    浏览(39)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包