快速入门Torch读取自定义图像数据集

这篇具有很好参考价值的文章主要介绍了快速入门Torch读取自定义图像数据集。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

学习新技术当然首先要看官网了

所有数据集都是torch.utils.data.Dataset的子类,即实现了__getitem__和__len__方法。因此,它们都可以传递给torch.utils.data. dataloader,它可以使用torch并行加载多个样本。多处理工人。例如:

imagenet_data = torchvision.datasets.ImageNet('path/to/imagenet_root/')
data_loader = torch.utils.data.DataLoader(imagenet_data,
                                          batch_size=4,
                                          shuffle=True,
                                          num_workers=args.nThreads)

就这???官方提供了许多内置好的数据集,但是我需要自定义啊!!!

还好官方上面文字说需要继承Dataset这个抽象类,实现__getitem__和__len__方法就ok了。

class CatDogDataSet(Dataset):
	def __init__(self):
		pass
		
    def __getitem__(self, index):
    	pass

    def __len__(self):
    	pass

我是谁?我在哪?我在干什么?完全不知道如何实现好吧

我知道ImageNet是从网上拉下来zip包解压后处理图片读取图片的,不妨看看ImageNet是如何实现的class ImageNet(ImageFolder):ImageFolder!!!这个类让我有预感我很快就可以copy了。果然datasets.ImageFolder(root)传入数据根目录且符合下面的格式就可以读取自定义数据集。

class ImageFolder(DatasetFolder):
    """A generic data loader where the images are arranged in this way by default: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/[...]/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/[...]/asd932_.png

完结撒花?我的数据集格式和ImageFolder需要的格式不一样

快速入门Torch读取自定义图像数据集,python,pytorch,深度学习,人工智能

最简单的方法当然是写个脚本整理为官方需求的格式,但是我不忘初心,说自定义就是自定义,copy99%也要自定义,而且移动数据的成本高,改改代码读取逻辑就能完成当然要改代码了

源码中find_classes方法,根据目录名定义classes变量改为classes = list(frozenset([i.split('.')[0] for i in os.listdir(directory)]))就可以了

def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Finds the class folders in a dataset.

    See :class:`DatasetFolder` for details.
    """
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx

再看数据集部分文章来源地址https://www.toymoban.com/news/detail-809094.html

    for target_class in sorted(class_to_idx.keys()):
        class_index = class_to_idx[target_class]
        """
        源码是判断目录是否与当前target一致,一直则读取这一目录
        target_dir = os.path.join(directory, target_class)
		if not os.path.isdir(target_dir):
			continue
        """
        for root, _, fnames in sorted(os.walk(directory, followlinks=True)):
            for fname in sorted(fnames):
            	# TODO: 在此处添加判断,当前文件名是否包含target
                if target_class in fname:
                    path = os.path.join(root, fname)
                    if is_valid_file(path):
                        item = path, class_index
                        instances.append(item)

                        if target_class not in available_classes:
                            available_classes.add(target_class)

献上完整自定义数据集代码

import os
from typing import Dict, Optional, Tuple, Callable, List, Union, cast

from torchvision.datasets import ImageFolder
from torchvision.datasets.folder import IMG_EXTENSIONS, has_file_allowed_extension


def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    classes = list(frozenset([i.split('.')[0] for i in os.listdir(directory)]))
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx


def make_dataset(
        directory: str,
        class_to_idx: Optional[Dict[str, int]] = None,
        extensions: Optional[Union[str, Tuple[str, ...]]] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]:
    """Generates a list of samples of a form (path_to_sample, class).

    See :class:`DatasetFolder` for details.

    Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` function
    by default.
    """
    directory = os.path.expanduser(directory)

    if class_to_idx is None:
        _, class_to_idx = find_classes(directory)
    elif not class_to_idx:
        raise ValueError("'class_to_index' must have at least one entry to collect any samples.")

    both_none = extensions is None and is_valid_file is None
    both_something = extensions is not None and is_valid_file is not None
    if both_none or both_something:
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")

    if extensions is not None:
        def is_valid_file(x: str) -> bool:
            return has_file_allowed_extension(x, extensions)  # type: ignore[arg-type]

    is_valid_file = cast(Callable[[str], bool], is_valid_file)

    instances = []
    available_classes = set()
    for target_class in sorted(class_to_idx.keys()):
        class_index = class_to_idx[target_class]
        # target_dir = os.path.join(directory, target_class)
        # if not os.path.isdir(target_dir):
        #     continue
        for root, _, fnames in sorted(os.walk(directory, followlinks=True)):
            for fname in sorted(fnames):
                if target_class in fname:
                    path = os.path.join(root, fname)
                    if is_valid_file(path):
                        item = path, class_index
                        instances.append(item)

                        if target_class not in available_classes:
                            available_classes.add(target_class)

    empty_classes = set(class_to_idx.keys()) - available_classes
    if empty_classes:
        msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
        if extensions is not None:
            msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}"
        raise FileNotFoundError(msg)

    return instances


class CatDogLoader(ImageFolder):
    def __init__(
            self,
            root: str,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        super().__init__(root,
                         transform,
                         target_transform,
                         is_valid_file=is_valid_file)
        classes, class_to_idx = self.find_classes(self.root)
        self.samples = self.make_dataset(self.root, class_to_idx, IMG_EXTENSIONS if is_valid_file is None else None,
                                         is_valid_file)

    def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
        return find_classes(directory)

    def make_dataset(
            self,
            directory: str,
            class_to_idx: Dict[str, int],
            extensions: Optional[Tuple[str, ...]] = None,
            is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> List[Tuple[str, int]]:
        if class_to_idx is None:
            raise ValueError("The class_to_idx parameter cannot be None.")
        return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file)

到了这里,关于快速入门Torch读取自定义图像数据集的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 无脑入门pytorch系列(二)—— torch.mean

    本系列教程适用于没有任何pytorch的同学(简单的python语法还是要的),从代码的表层出发挖掘代码的深层含义,理解具体的意思和内涵。pytorch的很多函数看着非常简单,但是其中包含了很多内容,不了解其中的意思就只能【看懂代码】,无法【理解代码】。 顾名思义,tor

    2024年02月14日
    浏览(39)
  • Python Qt6快速入门-自定义对话框和标准对话框

    对话框是有用的 GUI 组件,可以与用户进行交流(因此得名对话框)。 它们通常用于文件打开/保存、设置、首选项或不适合应用程序主 UI 的功能。 它们是位于主应用程序前面的小模态(或阻塞)窗口,直到它们被关闭。 Qt 为最常见的用例提供

    2024年02月03日
    浏览(53)
  • Pytorch数据类型转换(torch.tensor,torch.FloatTensor)

    之前遇到转为tensor转化为浮点型的问题,今天整理下,我只讲几个我常用的,如果有更好的方法,欢迎补充 1.首先讲下torch.tensor,默认整型数据类型为torch.int64,浮点型为torch.float32 2.这是我认为平常最爱用的转数据类型的方法,可以用dtype去定义数据类型 1.这个函数不要乱用

    2024年02月11日
    浏览(51)
  • PyTorch中torch、torchtext、torchvision、torchaudio与Python版本兼容性

    torch与torchtext,Python对应关系,来源:https://pypi.org/project/torchtext/ 截止发文,最新版本:torch 2.0.0,torchtext 0.15.1 安装方法: 或 torch与torchvision,Python对应关系,来源:https://github.com/pytorch/vision 截止发文,最新版本:torch 2.0.0,torchvision 0.15.1 安装方法: 或 torch与torchaudio,Pyt

    2024年02月04日
    浏览(81)
  • OpenCV读取图像时按照BGR的顺序HWC排列,PyTorch按照RGB的顺序CHW排列

    在OpenCV中,读取的图片默认是HWC格式,即按照高度、宽度和通道数的顺序排列图像尺寸的格式。我们看最后一个维度是C,因此最小颗粒度是C。 例如,一张形状为256×256×3的RGB图像,在OpenCV中读取后的格式为[256, 256, 3],其中最后一个维度表示图像的通道数。在OpenCV中,可以通

    2024年02月04日
    浏览(41)
  • OpenCV 入门教程:图像读取和显示

    2023年07月08日
    浏览(59)
  • pytorch快速入门中文——03

    原文:https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html#sphx-glr-beginner-blitz-neural-networks-tutorial-py 可以使用 torch.nn 包构建神经网络。 现在您已经了解了 autograd , nn 依赖于 autograd 来定义模型并对其进行微分。 nn.Module 包含层,以及返回 output 的方法 forward(input) 。 例如,

    2024年02月11日
    浏览(40)
  • pytorch快速入门中文——02

    原文:https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#sphx-glr-beginner-blitz-autograd-tutorial-py torch.autograd 是 PyTorch 的自动差分引擎,可为神经网络训练提供支持。 在本节中,您将获得有关 Autograd 如何帮助神经网络训练的概念性理解。 神经网络(NN)是在某些输入数据上执行的

    2024年02月11日
    浏览(35)
  • pytorch快速入门中文——01

    原文:https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html 作者 : Soumith Chintala https://www.youtube.com/embed/u7x8RXwLKcA PyTorch 是基于以下两个目的而打造的python科学计算框架: 无缝替换NumPy,并且通过利用GPU的算力来实现神经网络的加速。 通过自动微分机制,来让神经网络的实现

    2024年02月11日
    浏览(42)
  • 【深度学习】pytorch——快速入门

    笔记为自我总结整理的学习笔记,若有错误欢迎指出哟~ PyTorch是一个开源的机器学习框架,它提供了丰富的工具和库,用于构建和训练深度学习模型。下面是一些关于PyTorch的基本信息: 张量(Tensor)操作 :PyTorch中的核心对象是张量,它是一个多维数组。PyTorch提供了广泛的

    2024年02月06日
    浏览(44)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包