【深度学习】自定义数据集对象mydataset |继承torch.utils.data.Dataset类

这篇具有很好参考价值的文章主要介绍了【深度学习】自定义数据集对象mydataset |继承torch.utils.data.Dataset类。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档


前言

与datasets.ImageFolder类似,深度学习课题中还有一种很常用的自定义数据集的方法:继承torch.utils.data.Dataset类
可以参考我之前的博客:【深度学习】datasets.ImageFolder 使用方法

datasets.ImageFolder返回的对象和继承torch.utils.data.Dataset的自定义数据集(如MyDataset)生成的对象类型是一样的的吗?:
是的。它们都是torch.utils.data.Dataset类的实例,都实现了__len__和__getitem__方法,可以被传递给torch.utils.data.DataLoader用于数据的迭代和批处理等操作。虽然它们的实现方式不同,但是它们都符合了torch.utils.data.Dataset的接口规范,因此可以被视为同一类型的对象。

一、自定义mydataset的例子

比如我要从指定文件夹里读取图片生成数据集:

import os
from PIL import Image
from torch.utils.data import Dataset

class ImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_filenames = os.listdir(root_dir)
    
    def __len__(self):
        return len(self.image_filenames)
    
    def __getitem__(self, index):
        # 读取图像
        image_path = os.path.join(self.root_dir, self.image_filenames[index])
        image = Image.open(image_path).convert('RGB')
        
        # 对图像进行变换(如果有)
        if self.transform is not None:
            image = self.transform(image)
        
        return image

ImageDataset就是继承的dataset
最重要的就是这三部分:构造函数,两个魔术方法(len,getitem)

如果我对图像还有预处理的话,代码举例如下:

import os
from PIL import Image
import torchvision.transforms as transforms
from torch.utils.data import Dataset

class ImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_filenames = os.listdir(root_dir)
    
    def __len__(self):
        return len(self.image_filenames)
    
    def __getitem__(self, index):
        # 读取图像
        image_path = os.path.join(self.root_dir, self.image_filenames[index])
        image = Image.open(image_path).convert('RGB')
        
        # 对图像进行变换(如果有)
        if self.transform is not None:
            image = self.transform(image)
        
        return image

# 定义变换函数
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # 调整图像大小
    transforms.ToTensor(),         # 将图像转换为张量
    transforms.Normalize(          # 归一化图像
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

# 创建数据集实例
dataset = ImageDataset(root_dir='path/to/images', transform=transform)

# 获取第一张图像数据
image = dataset[0]

由于 transforms.ToTensor() 能够将 PIL.Image.Image 对象直接转换为张量,因此在这里可以直接使用 transforms.ToTensor() 进行转换,就不用再把PIL.Image.Image单独把转为ndarray了。具体来说,在 transforms.ToTensor() 中,会先将 PIL.Image.Image 对象转换为 numpy.ndarray 对象,然后再将其转换为张量。

二、torch.utils.data.Dataset长啥样

我们打开dataset类函数进去看看:

class Dataset(Generic[T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs a index
      sampler that yields integral indices.  To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """
    functions: Dict[str, Callable] = {}

    def __getitem__(self, index) -> T_co:
        raise NotImplementedError

    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
        return ConcatDataset([self, other])

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py

    def __getattr__(self, attribute_name):
        if attribute_name in Dataset.functions:
            function = functools.partial(Dataset.functions[attribute_name], self)
            return function
        else:
            raise AttributeError

    @classmethod
    def register_function(cls, function_name, function):
        cls.functions[function_name] = function

    @classmethod
    def register_datapipe_as_function(cls, function_name, cls_to_register, enable_df_api_tracing=False):
        if function_name in cls.functions:
            raise Exception("Unable to add DataPipe function name {} as it is already taken".format(function_name))

        def class_function(cls, enable_df_api_tracing, source_dp, *args, **kwargs):
            result_pipe = cls(source_dp, *args, **kwargs)
            if isinstance(result_pipe, Dataset):
                if enable_df_api_tracing or isinstance(source_dp, DFIterDataPipe):
                    if function_name not in UNTRACABLE_DATAFRAME_PIPES:
                        result_pipe = result_pipe.trace_as_dataframe()

            return result_pipe

        function = functools.partial(class_function, cls_to_register, enable_df_api_tracing)
        cls.functions[function_name] = function

chatgpt的解读为:Dataset 类的定义,它是一个抽象类,所有表示从键到数据样本的数据集都应该继承它。所有的子类都应该覆盖 getitem 方法,支持根据给定的键获取数据样本。子类还可以选择性地覆盖 len 方法,它被许多 torch.utils.data.Sampler 实现和 torch.utils.data.DataLoader 的默认选项所使用,用于返回数据集的大小。如果数据集的键不是整数类型,需要提供一个自定义的采样器(sampler)来使其与 torch.utils.data.DataLoader 兼容。此外,Dataset 类还提供了一些方法和属性,如 add 方法、getattr 方法等。

raise NotImplementedError 表示该方法还没有被实现,需要在子类中进行实现。在 Python 中,使用 NotImplementedError 异常可以方便地提示开发者该方法还未被实现,这也是一种规范的实现方式。当然,你也可以直接在子类中实现这两个方法,而不是使用 NotImplementedError。

三.一些使用过的继承dataset类总结

3.1.在图像去噪任务中,使用patch将单张图片分割为多个子图训练

import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import os

class PatchDataset(Dataset):
    def __init__(self, noisy_image_folder, clean_image_folder, patch_size=64):
        self.noisy_image_folder = noisy_image_folder
        self.clean_image_folder = clean_image_folder
        self.patch_size = patch_size
        self.transform = transforms.Compose([
            transforms.Resize(patch_size + 16),
            transforms.RandomCrop(patch_size),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.5],
                std=[0.5]
            )
        ])
        self.noisy_image_paths = [os.path.join(noisy_image_folder, x) for x in os.listdir(noisy_image_folder)]
        self.clean_image_paths = [os.path.join(clean_image_folder, x) for x in os.listdir(clean_image_folder)]
    
    def __len__(self):
        return len(self.noisy_image_paths)
    
    def __getitem__(self, idx):
        # 读取图像
        noisy_image = Image.open(self.noisy_image_paths[idx]).convert('L')
        clean_image = Image.open(self.clean_image_paths[idx]).convert('L')
        # 对图像进行 patch 操作
        patches = []
        for i in range(4):  # 每个图像分割成 4 个 patch
            for j in range(4):
                x = j * self.patch_size
                y = i * self.patch_size
                noisy_patch = noisy_image.crop((x, y, x + self.patch_size, y + self.patch_size))
                clean_patch = clean_image.crop((x, y, x + self.patch_size, y + self.patch_size))
                noisy_patch = self.transform(noisy_patch)
                clean_patch = self.transform(clean_patch)
                patches.append((noisy_patch, clean_patch))
        return patches

在这个示例中,我们定义了一个名为 PatchDataset 的自定义数据集类,它继承自 PyTorch 的 Dataset 类。在 init 函数中,我们传入了有噪点图像文件夹路径 noisy_image_folder、无噪点图像文件夹路径 clean_image_folder 和 patch 的大小 patch_size,并定义了变换函数 self.transform。在 getitem 函数中,我们读取有噪点图像和无噪点图像并对它们进行 patch 操作,将得到的 16 个 patch 组成一个列表并返回,其中每个元素是一个包含有噪点 patch 与对应的无噪点 patch 的元组。
需要注意的是,在这个示例中,我们将有噪点 patch 和无噪点 patch 都进行了归一化。这是因为在图像去噪任务中,我们需要将有噪点图像输入模型进行训练,同时需要使用无噪点图像作为标签进行监督学习。因此,对有噪点图像和无噪点图像进行相同的归一化操作可以简化代码并提高训练效果。

3.2.在HDR图像重建任务中,dataset类中的transform应该是神马样的(.hdr.exr文件和jpg打开方式不太一样)

import torch
from torch.utils.data import Dataset
from torchvision import transforms
import cv2
import os

class PatchDataset(Dataset):
    def __init__(self, ldr_image_folder, hdr_image_folder, patch_size=64):
        self.ldr_image_folder = ldr_image_folder
        self.hdr_image_folder = hdr_image_folder
        self.patch_size = patch_size
        self.transform = transforms.Compose([
            transforms.Resize(patch_size + 16),
            transforms.RandomCrop(patch_size),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.5, 0.5, 0.5],
                std=[0.5, 0.5, 0.5]
            )
        ])
        self.ldr_image_paths = [os.path.join(ldr_image_folder, x) for x in os.listdir(ldr_image_folder)]
        self.hdr_image_paths = [os.path.join(hdr_image_folder, x) for x in os.listdir(hdr_image_folder)]
    
    def __len__(self):
        return len(self.ldr_image_paths)
    
    def read_hdr_image(self, path):
        # 读取 HDR 图像
        hdr_image = cv2.imread(path, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR)
        # 将像素值恢复到原始范围
        hdr_image = hdr_image / 65535.0 * 100.0  # 假设原始范围为 [0, 100]
        return hdr_image
    
    def __getitem__(self, idx):
        # 读取图像
        ldr_image = cv2.imread(self.ldr_image_paths[idx], cv2.IMREAD_COLOR)
        hdr_image = self.read_hdr_image(self.hdr_image_paths[idx])
        # 对图像进行 patch 操作
        patches = []
        for i in range(4):  # 每个图像分割成 4 个 patch
            for j in range(4):
                x = j * self.patch_size
                y = i * self.patch_size
                ldr_patch = ldr_image[y:y+self.patch_size, x:x+self.patch_size, :]
                hdr_patch = hdr_image[y:y+self.patch_size, x:x+self.patch_size, :]
                ldr_patch = self.transform(ldr_patch)
                hdr_patch = self.transform(hdr_patch)
                patches.append((ldr_patch, hdr_patch))
        return patches

我们首先定义了一个 read_hdr_image() 函数,它使用 OpenCV 的 imread() 函数来读取 HDR 图像,并将像素值恢复到原始范围。在 getitem 函数中,我们读取 LDR 图像和 HDR 图像并对它们进行 patch 操作,将得到的 16 个 patch 组成一个列表并返回,其中每个元素是一个包含 LDR patch 与对应的 HDR patch 的元组。

因为是chatgpt生成的代码 ,是否对错需要上电脑验证:
.hdr 或 .exr 格式的 HDR 图像的像素值通常是浮点数(比如在 OpenEXR 中,像素值的数据类型为 FLOAT),而且通常不是在 0 到 255 的范围内,而是在一个更大的范围内。具体的范围取决于采集设备和图像处理过程中所使用的参数。
在使用 OpenCV 来读取 HDR 图像时,它会将像素值缩放到 0 到 255 的范围内,因此我们需要手动将像素值恢复到原始范围。在这个示例中,我们假设原始范围为 [0, 100],因此我们将原始的浮点数像素值除以 65535(即 2**16-1,因为像素值在 OpenCV 中被存储为 16 位整数)并乘以 100,以将像素值恢复到原始范围。文章来源地址https://www.toymoban.com/news/detail-440507.html

到了这里,关于【深度学习】自定义数据集对象mydataset |继承torch.utils.data.Dataset类的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 深度学习——划分自定义数据集

    以人脸表情数据集raf_db为例,初始目录如下: 需要经过处理后返回 train_images, train_label, val_images, val_label 定义 read_split_data(root: str, val_rate: float = 0.2) 方法来解决,代码如下: 此时可通过以下代码获得训练集和测试集数据: 完结撒花。

    2024年02月14日
    浏览(39)
  • 【深度学习框架-torch】torch.norm函数详解用法

    torch版本 1.6 dim是matrix norm 如果 input 是 matrix norm ,也就是维度大于等于2维,则 P值默认为 fro , Frobenius norm 可认为是与计算向量的欧氏距离类似 有时候为了比较真实的矩阵和估计的矩阵值之间的误差 或者说比较真实矩阵和估计矩阵之间的相似性,我们可以采用 Frobenius 范数。

    2024年02月10日
    浏览(51)
  • 机器学习&&深度学习——torch.nn模块

    torch.nn模块包含着torch已经准备好的层,方便使用者调用构建网络。 卷积就是 输入和卷积核之间的内积运算 ,如下图: 容易发现,卷积神经网络中通过输入卷积核来进行卷积操作,使输入单元(图像或特征映射)和输出单元(特征映射)之间的连接时稀疏的,能够减少需要

    2024年02月15日
    浏览(42)
  • 深度学习torch基础知识

    detach是截断反向传播的梯度流 将某个node变成不需要梯度的Varibale。因此当反向传播经过这个node时,梯度就不会从这个node往前面传播。 拼接:将多个维度参数相同的张量连接成一个张量 torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0) module即表示你定义的模型,devic

    2024年02月13日
    浏览(48)
  • Pytorch目标分类深度学习自定义数据集训练

    目录 一,Pytorch简介; 二,环境配置; 三,自定义数据集; 四,模型训练; 五,模型验证;         PyTorch是一个开源的Python机器学习库,基于Torch,用于自然语言处理等应用程序。PyTorch 基于 Python: PyTorch 以 Python 为中心或“pythonic”,旨在深度集成 Python 代码,而不是

    2024年02月07日
    浏览(60)
  • Windows配置深度学习环境——torch+CUDA

    这里基于读者已经有使用Python的相关经验,就不介绍Python的安装过程。 win10+mx350+Python3.7.4+CUDA11.4.0+cudnn11.4 torch 1.11.0+cu113 torchaudio 0.11.0 torchvision 0.12.0+cu113 一般来说在命令行界面输入python就可以了解python版本。 也可以使用如下代码查询python版本。 以下是torch与Python版本的对应关

    2024年01月25日
    浏览(58)
  • 【深度学习】PyTorch的dataloader制作自定义数据集

    PyTorch的dataloader是用于读取训练数据的工具,它可以自动将数据分割成小batch,并在训练过程中进行数据预处理。以下是制作PyTorch的dataloader的简单步骤: 导入必要的库 定义数据集类 需要自定义一个继承自 torch.utils.data.Dataset 的类,在该类中实现 __len__ 和 __getitem__ 方法。 创建

    2024年02月10日
    浏览(55)
  • 深度学习—Python、Cuda、Cudnn、Torch环境配置搭建

    近期由于毕设需要使用Yolo,于是经过两天捣腾,加上看了CSDN上各位大佬的经验帖后,成功搭建好了GPU环境,并能成功使用。因而在此写下这次搭建的历程。 万事开头难,搭建环境很费时间,如果一开始版本不对应,到后面就要改来改去,很麻烦。首先要注意以下事项: 1.

    2024年02月11日
    浏览(211)
  • Anaconda配置深度学习环境并安装GPU版torch

    本人属于刚入学的小白,因为任务需要,所以得从零开始安装深度学习环境。对于从未接触过深度学习的人来讲,光配置环境就花费了我好久好久的时间,中间心态炸裂好几次,索性还是安装成功了。现在就从0开始复盘一下我的安装过程。不喜勿喷,出门右转不送。爷又不靠

    2024年02月06日
    浏览(67)
  • 【深度学习】多卡训练__单机多GPU方法详解(torch.nn.DataParallel、torch.distributed)

    多GPU训练能够加快模型的训练速度,而且在单卡上不能训练的模型可以使用多个小卡达到训练的目的。 多GPU训练可以分为单机多卡和多机多卡这两种,后面一种也就是分布式训练——训练方式比较麻烦,而且要关注的性能问题也有很多,据网上的资料有人建议能单机训练最好

    2024年02月02日
    浏览(36)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包