Pytorch中Dataset和dadaloader的理解

这篇具有很好参考价值的文章主要介绍了Pytorch中Dataset和dadaloader的理解。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

不同的数据集在形式上千差万别,为了能够统一用于模型的训练,Pytorch框架下定义了一个dataset类和一个dataloader类。

dataset用于获取数据集中的样本,dataloader 用于抽取部分样本用于训练。比如说一个用于分割任务的图像数据集的结构如图1所示,一个样本由原图像和对应的mask组成。

Pytorch中Dataset和dadaloader的理解,深度学习(PyTorch),pytorch,人工智能,python

图1 典型数据集的结构

为了获取数据集,典型的代码如下

from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
import os
from torchvision import transforms

# 定义数据集
train_data_dir = 'dataset/train'
train_GT_dir = 'dataset/train_GT'

class MyData(Dataset):
    def __init__(self, imgdir, maskdir,transform):
        self.imgdir = imgdir
        self.maskdir = maskdir
        self.transform = transform
        self.img_list = os.listdir(self.imgdir)
        self.mask_list= os.listdir(self.maskdir)
        self.img_list.sort()
        self.mask_list.sort()

    def __getitem__(self, idx):
        img_name = self.img_list[idx]
        mask_name =self.mask_list[idx]
        img_item_path = os.path.join(self.imgdir, img_name)
        mask_item_path =os.path.join(self.maskdir,mask_name)

        img =Image.open(img_item_path)
        mask =Image.open(mask_item_path)

        img = self.transform(img)
        mask = self.transform(mask)

        return img, mask

    def __len__(self):
        assert len(self.img_list) == len(self.mask_list)
        return len(self.img_list)

if __name__ == '__main__':
    transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
    train_data_dir = 'dataset/train'
    train_GT_dir = 'dataset/train_GT'
    dataset = MyData(train_data_dir, train_GT_dir ,transform)
    dataloader = DataLoader(dataset, batch_size=4, num_workers=0)
    for step, (img,mask) in enumerate(dataloader):
        print(step)
        print(img.shape)
        print(mask.shape)
        if step>0:
            break

程序运行的结果如下:

Pytorch中Dataset和dadaloader的理解,深度学习(PyTorch),pytorch,人工智能,python

返回了一个batch的img 和mask 的尺寸,说明数据集抽取成功了.

在建立数据集的过程中需用重写__getitem()__和__len()__方法即可。文章来源地址https://www.toymoban.com/news/detail-823266.html

到了这里,关于Pytorch中Dataset和dadaloader的理解的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 深度学习之PyTorch实战(5)——对CrossEntropyLoss损失函数的理解与学习

      其实这个笔记起源于一个报错,报错内容也很简单,希望传入一个三维的tensor,但是得到了一个四维。 查看代码报错点,是出现在pytorch计算交叉熵损失的代码。其实在自己手写写语义分割的代码之前,我一直以为自己是对交叉熵损失完全了解的。但是实际上还是有一些些

    2023年04月09日
    浏览(44)
  • 谈谈Pytorch中的dataset

    关注B站查看更多手把手教学: 肆十二-的个人空间-肆十二-个人主页-哔哩哔哩视频 (bilibili.com) torch.utils.data.Dataset 是 PyTorch 中一个非常重要的抽象类,它用于表示数据集,方便数据加载和预处理。通过实现这个类的两个方法 __len__ 和 __getitem__ ,你可以自定义自己的数据集类。

    2024年03月11日
    浏览(36)
  • PyTorch深度学习实战(2)——PyTorch基础

    PyTorch 是广泛应用于机器学习领域中的强大开源框架,因其易用性和高效性备受青睐。在本节中,将介绍使用 PyTorch 构建神经网络的基础知识。首先了解 PyTorch 的核心数据类型——张量对象。然后,我们将深入研究用于张量对象的各种操作。 PyTorch 提供了许多帮助构建神经网

    2024年02月09日
    浏览(41)
  • 01_pytorch中的DataSet

    在pytorch 中, Dataset : 用于数据集的创建; DataLoader : 用于在训练过程中,传递获取一个batch的数据; 这里先介绍 pytorch 中的 Dataset 这个类, torch.utils.data. dataset.py 是一个表示数据集的抽象类。任何自定义的数据集都需要继承这个类并覆写相关方法。 数据集,其实就是一个负责

    2024年02月08日
    浏览(38)
  • 如何使用pytorch的Dataset, 来定义自己的Dataset

    Dataset与DataLoader的关系 Dataset: 构建一个数据集,其中含有所有的数据样本 DataLoader:将构建好的Dataset,通过shuffle、划分batch、多线程num_workers运行的方式,加载到可训练的迭代容器。 实战1:CSV数据集(结构化数据集) 实战2:图片数据集 ├── flower_data —├── flower_photo

    2024年01月22日
    浏览(51)
  • 33- PyTorch实现分类和线性回归 (PyTorch系列) (深度学习)

    知识要点  pytorch 最常见的创建模型 的方式, 子类 读取数据: data = pd.read_csv (\\\'./dataset/credit-a.csv\\\', header=None) 数据转换为tensor: X = torch .from_numpy(X.values).type(torch.FloatTensor) 创建简单模型: 定义损失函数: loss_fn = nn.BCELoss () 定义优化器: opt = torch.optim.SGD (model.parameters(), lr=0.00001) 把梯度

    2024年02月06日
    浏览(50)
  • PyTorch深度学习实战(3)——使用PyTorch构建神经网络

    我们已经学习了如何从零开始构建神经网络,神经网络通常包括输入层、隐藏层、输出层、激活函数、损失函数和学习率等基本组件。在本节中,我们将学习如何在简单数据集上使用 PyTorch 构建神经网络,利用张量对象操作和梯度值计算更新网络权重。 1.1 使用 PyTorch 构建神

    2024年02月08日
    浏览(47)
  • 【PyTorch与深度学习】2、PyTorch张量的运算API(上)

    课程地址 最近做实验发现自己还是基础框架上掌握得不好,于是开始重学一遍PyTorch框架,这个是课程笔记,这个课还是讲的简略,我半小时的课听了一个半小时。 (1) chunk :将一个张量分割为特定数目的张量,每个块都是输入张量的视图。 按维度0分割: 运行结果: b=

    2024年04月29日
    浏览(47)
  • Pytorch深度学习 - 学习笔记

    dir() :打开,看见包含什么 help() :说明书 pytorch中读取数据主要涉及到两个类 Dataset 和 Dataloader 。 Dataset可以将可以使用的数据提取出来,并且可以对数据完成编号。即提供一种方式获取数据及其对应真实的label值。 Dataloader为网络提供不同的数据形式。 Dataset Dataset是一个抽

    2024年02月07日
    浏览(45)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包