谈谈Pytorch中的dataset

这篇具有很好参考价值的文章主要介绍了谈谈Pytorch中的dataset。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

关注B站查看更多手把手教学:

肆十二-的个人空间-肆十二-个人主页-哔哩哔哩视频 (bilibili.com)

基本用法

torch.utils.data.Dataset 是 PyTorch 中一个非常重要的抽象类,它用于表示数据集,方便数据加载和预处理。通过实现这个类的两个方法 __len____getitem__,你可以自定义自己的数据集类。__len__ 方法应返回数据集的大小(即样本数),而 __getitem__ 方法则根据给定的索引返回一个样本。

以下是一个简单的示例,说明如何使用 torch.utils.data.Dataset 创建一个自定义的数据集类:

import torch  
from torch.utils.data import Dataset  
  
class MyCustomDataset(Dataset):  
    def __init__(self, data, targets):  
        """  
        参数:  
            data: 样本数据, 形状为 [num_samples, ...] (例如 [num_samples, num_channels, height, width])  
            targets: 样本标签, 形状为 [num_samples, ...] (例如 [num_samples])  
        """  
        self.data = data  
        self.targets = targets  
  
    def __len__(self):  
        # 返回数据集的样本数  
        return len(self.data)  
  
    def __getitem__(self, idx):  
        # 根据索引 idx 返回一个样本 (数据和标签)  
        return self.data[idx], self.targets[idx]  
  
# 示例数据和标签  
X = torch.randn(100, 3, 32, 32)  # 假设有 100 个 3x32x32 的样本  
y = torch.randint(0, 10, (100,))  # 假设有 100 个对应的标签 (0-9)  
  
# 创建数据集实例  
dataset = MyCustomDataset(X, y)  
  
# 可以使用 len() 获取数据集大小  
print(len(dataset))  # 输出: 100  
  
# 可以使用索引获取样本  
sample, label = dataset[0]  # 获取第一个样本和标签  
print(sample.shape)  # 输出: torch.Size([3, 32, 32])  
print(label)  # 输出: 一个整数 (0-9)

在上面的示例中,我们创建了一个名为 MyCustomDataset 的自定义数据集类,该类继承自 torch.utils.data.Dataset。在类的构造函数中,我们接收样本数据和标签,并将它们存储在类的实例变量中。我们还实现了 __len____getitem__ 方法,分别用于返回数据集的大小和根据索引获取样本。最后,我们创建了一个数据集实例,并展示了如何使用它来获取数据集的大小和样本。

标准数据集

在PyTorch的torchvision.datasets模块中,包含了多个标准的数据集,这些数据集在计算机视觉领域非常流行。以下是一些常用的标准数据集:

  1. MNIST:手写数字识别数据集,包含了大量的手写数字图片和对应的标签。
  2. CIFAR:包含CIFAR-10和CIFAR-100两个数据集,分别用于10类和100类的小图片分类任务。
  3. ImageNet:一个大规模的图片分类数据集,包含了上千万张标注过的图片,通常用于训练深度神经网络。在torchvision.datasets中,可以通过ImageFolder类来加载按文件夹组织的ImageNet风格的数据集。虽然完整的ImageNet数据集很大并不直接包含在torchvision.datasets中,但PyTorch提供了处理这种数据集的工具。
  4. COCO (Common Objects in Context):用于图像标注、目标检测和语义分割的大型数据集。它包含了图片、物体的标注框、分割掩码以及关键点等信息。
  5. LSUN (Large-scale Scene UNderstanding):场景理解的大型数据集,包含了不同类别的场景图片。
  6. FashionMNIST:类似于MNIST,但是用于时尚服装和配饰的图片分类。
  7. SVHN (Street View House Numbers):从谷歌街景图片中提取的门牌号识别数据集。
  8. PhotoTour:用于图像匹配的数据集,包含了从不同角度拍摄的同一景点的图片对。
  9. STL10:一个用于无监督学习和半监督学习的图像数据集,包含了少量的标注数据和大量的无标注数据。
  10. Kinetics:用于视频动作识别的大型数据集。
  11. CelebA (CelebFaces Attributes):用于人脸检测和属性识别的大型人脸数据集。

这些标准数据集可以通过简单地调用torchvision.datasets中的相应类来加载和预处理。例如,加载MNIST数据集可以通过以下代码实现:

import torchvision.datasets as dsets  
  
# 加载MNIST训练集  
train_dataset = dsets.MNIST(root='./data',  
                            train=True,  
                            transform=transforms.ToTensor(),  
                            download=True)  
  
# 加载MNIST测试集  
test_dataset = dsets.MNIST(root='./data',  
                           train=False,  
                           transform=transforms.ToTensor())

注意,上面的代码中使用了transforms.ToTensor()来对图片进行预处理,将其转换为PyTorch的Tensor格式。在实际使用中,你可能还需要根据具体任务添加其他的预处理步骤,比如裁剪、归一化等。这些都可以通过组合torchvision.transforms中的不同变换来实现。文章来源地址https://www.toymoban.com/news/detail-838561.html

到了这里,关于谈谈Pytorch中的dataset的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 学习pytorch 2 导入查看dataset

    B站小土堆视频 https://download.pytorch.org/tutorial/hymenoptera_data.zip

    2024年02月13日
    浏览(32)
  • 【深度学习】Pytorch 系列教程(十二):PyTorch数据结构:4、数据集(Dataset)

             目录 一、前言 二、实验环境 三、PyTorch数据结构 0、分类 1、张量(Tensor) 2、张量操作(Tensor Operations) 3、变量(Variable) 4、数据集(Dataset) 随机洗牌           ChatGPT:         PyTorch是一个开源的机器学习框架,广泛应用于深度学习领域。它提供了丰富

    2024年02月07日
    浏览(43)
  • Pytorch Dataset类的使用(个人学习笔记)

    训练模型一般都是先处理 数据的输入问题 和 预处理问题 。 Pytorch提供了几个有用的工具: torch.utils.data.Dataset类 和 torch.utils.data.DataLoader类。 流程是先把 原始数据 转变成 torch.utils.data.Dataset类 , 随后再把得到 torch.utils.data.Dataset类 当作一个参数传递给 torch.utils.data.DataLoader类

    2024年02月05日
    浏览(38)
  • Pytorch中Dataset和dadaloader的理解

    不同的数据集在形式上千差万别,为了能够统一用于模型的训练,Pytorch框架下定义了一个dataset类和一个dataloader类。 dataset用于获取数据集中的样本,dataloader 用于抽取部分样本用于训练。比如说一个用于分割任务的图像数据集的结构如图1所示,一个样本由原图像和对应的m

    2024年01月25日
    浏览(29)
  • PyTorch翻译官网教程3-DATASETS & DATALOADERS

    Datasets DataLoaders — PyTorch Tutorials 2.0.1+cu117 documentation 处理样本数据的代码可能会变得混乱并且难以维护。理想情况下,我们希望我们的数据集代码与模型训练代码解耦,以获得更好的可读性和模块化。PyTorch提供了两个数据源:torch.utils.data.DataLoader和torch.utils.data.Dataset,它们允

    2024年02月11日
    浏览(47)
  • 一文弄懂Pytorch的DataLoader,Dataset,Sampler之间的关系

    以下内容都是针对Pytorch 1.0-1.1介绍。 很多文章都是从Dataset等对象自下往上进行介绍,但是对于初学者而言,其实这并不好理解,因为有的时候会不自觉地陷入到一些细枝末节中去,而不能把握重点,所以本文将会 自上而下 地对Pytorch数据读取方法进行介绍。 首先我们看一下

    2024年02月01日
    浏览(44)
  • PyTorch 深度学习之加载数据集Dataset and DataLoader(七)

    全部Batch:计算速度,性能有问题 1 个 :跨越鞍点 mini-Batch:均衡速度与性能 两种处理数据的方式 linux 与 windows 多线程不一样 torchvision 内置数据集 MINIST Dataset

    2024年02月07日
    浏览(41)
  • Pytorch的torch.utils.data中Dataset以及DataLoader等详解

    在我们进行深度学习的过程中,不免要用到数据集,那么数据集是如何加载到我们的模型中进行训练的呢?以往我们大多数初学者肯定都是拿网上的代码直接用,但是它底层的原理到底是什么还是不太清楚。所以今天就从内置的Dataset函数和自定义的Dataset函数做一个详细的解

    2024年02月11日
    浏览(48)
  • pytorch 训练过程内存泄露/显存泄露debug记录:dataloader和dataset导致的泄露

    微调 mask-rcnn 代码,用的是 torchvision.models.detection.maskrcnn_resnet50_fpn 代码,根据该代码的注释,输入应该是: images, targets=None (List[Tensor], Optional[List[Dict[str, Tensor]]]) - Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]] 所以我写的 dataset 是这样的: 大概思路是: 先把所有的标注信息读入内存

    2024年02月14日
    浏览(48)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包