一文弄懂Pytorch的DataLoader,Dataset,Sampler之间的关系

这篇具有很好参考价值的文章主要介绍了一文弄懂Pytorch的DataLoader,Dataset,Sampler之间的关系。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

以下内容都是针对Pytorch 1.0-1.1介绍。

很多文章都是从Dataset等对象自下往上进行介绍,但是对于初学者而言,其实这并不好理解,因为有的时候会不自觉地陷入到一些细枝末节中去,而不能把握重点,所以本文将会自上而下地对Pytorch数据读取方法进行介绍。

自上而下理解三者关系

首先我们看一下DataLoader.next的源代码长什么样,为方便理解我只选取了num_works为0的情况(num_works简单理解就是能够并行化地读取数据)。

class DataLoader(object):
	...
	
    def __next__(self):
        if self.num_workers == 0:  
            indices = next(self.sample_iter)  # Sampler
            batch = self.collate_fn([self.dataset[i] for i in indices]) # Dataset
            if self.pin_memory:
                batch = _utils.pin_memory.pin_memory_batch(batch)
            return batch

在阅读上面代码前,我们可以假设我们的数据是一组图像,每一张图像对应一个index,那么如果我们要读取数据就只需要对应的index即可,即上面代码中的indices,而选取index的方式有多种,有按顺序的,也有乱序的,所以这个工作需要Sampler完成,现在你不需要具体的细节,后面会介绍,你只需要知道DataLoader和Sampler在这里产生关系。

那么Dataset和DataLoader在什么时候产生关系呢?没错就是下面一行。我们已经拿到了indices,那么下一步我们只需要根据index对数据进行读取即可了。

再下面的if语句的作用简单理解就是,如果pin_memory=True,那么Pytorch会采取一系列操作把数据拷贝到GPU,总之就是为了加速。

综上可以知道DataLoader,Sampler和Dataset三者关系如下:

一文弄懂Pytorch的DataLoader,Dataset,Sampler之间的关系

在阅读后文的过程中,你始终需要将上面的关系记在心里,这样能帮助你更好地理解。

Sampler

参数传递

要更加细致地理解Sampler原理,我们需要先阅读一下DataLoader 的源代码,如下:

class DataLoader(object):
    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=default_collate,
                 pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None)

可以看到初始化参数里有两种sampler:samplerbatch_sampler,都默认为None。前者的作用是生成一系列的index,而batch_sampler则是将sampler生成的indices打包分组,得到一个又一个batch的index。例如下面示例中,BatchSamplerSequentialSampler生成的index按照指定的batch size分组。

>>>in : list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
>>>out: [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

Pytorch中已经实现的Sampler有如下几种:

  • SequentialSampler
  • RandomSampler
  • WeightedSampler
  • SubsetRandomSampler

需要注意的是DataLoader的部分初始化参数之间存在互斥关系,这个你可以通过阅读源码更深地理解,这里只做总结:

  • 如果你自定义了batch_sampler,那么这些参数都必须使用默认值:batch_sizeshuffle,sampler,drop_last.
  • 如果你自定义了sampler,那么shuffle需要设置为False
  • 如果samplerbatch_sampler都为None,那么batch_sampler使用Pytorch已经实现好的BatchSampler,而sampler分两种情况:
    • shuffle=True,则sampler=RandomSampler(dataset)
    • shuffle=False,则sampler=SequentialSampler(dataset)

如何自定义Sampler和BatchSampler?

仔细查看源代码其实可以发现,所有采样器其实都继承自同一个父类,即Sampler,其代码定义如下:

class Sampler(object):
    r"""Base class for all Samplers.
    Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
    way to iterate over indices of dataset elements, and a :meth:`__len__` method
    that returns the length of the returned iterators.
    .. note:: The :meth:`__len__` method isn't strictly required by
              :class:`~torch.utils.data.DataLoader`, but is expected in any
              calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
    """

    def __init__(self, data_source):
        pass

    def __iter__(self):
        raise NotImplementedError
		
    def __len__(self):
        return len(self.data_source)

所以你要做的就是定义好__iter__(self)函数,不过要注意的是该函数的返回值需要是可迭代的。例如SequentialSampler返回的是iter(range(len(self.data_source)))

另外BatchSampler与其他Sampler的主要区别是它需要将Sampler作为参数进行打包,进而每次迭代返回以batch size为大小的index列表。也就是说在后面的读取数据过程中使用的都是batch sampler。

Dataset

Dataset定义方式如下:

class Dataset(object):
	def __init__(self):
		...
		
	def __getitem__(self, index):
		return ...
	
	def __len__(self):
		return ...

上面三个方法是最基本的,其中__getitem__是最主要的方法,它规定了如何读取数据。但是它又不同于一般的方法,因为它是python built-in方法,其主要作用是能让该类可以像list一样通过索引值对数据进行访问。假如你定义好了一个dataset,那么你可以直接通过dataset[0]来访问第一个数据。在此之前我一直没弄清楚__getitem__是什么作用,所以一直不知道该怎么进入到这个函数进行调试。现在如果你想对__getitem__方法进行调试,你可以写一个for循环遍历dataset来进行调试了,而不用构建dataloader等一大堆东西了,建议学会使用ipdb这个库,非常实用!!!以后有时间再写一篇ipdb的使用教程。另外,其实我们通过最前面的Dataloader的__next__函数可以看到DataLoader对数据的读取其实就是用了for循环来遍历数据,不用往上翻了,我直接复制了一遍,如下:

class DataLoader(object): 
    ... 
     
    def __next__(self): 
        if self.num_workers == 0:   
            indices = next(self.sample_iter)  
            batch = self.collate_fn([self.dataset[i] for i in indices]) # this line 
            if self.pin_memory: 
                batch = _utils.pin_memory.pin_memory_batch(batch) 
            return batch

我们仔细看可以发现,前面还有一个self.collate_fn方法,这个是干嘛用的呢?在介绍前我们需要知道每个参数的意义:

  • indices: 表示每一个iteration,sampler返回的indices,即一个batch size大小的索引列表
  • self.dataset[i]: 前面已经介绍了,这里就是对第i个数据进行读取操作,一般来说self.dataset[i]=(img, label)

看到这不难猜出collate_fn的作用就是将一个batch的数据进行合并操作。默认的collate_fn是将img和label分别合并成imgs和labels,所以如果你的__getitem__方法只是返回 img, label,那么你可以使用默认的collate_fn方法,但是如果你每次读取的数据有img, box, label等等,那么你就需要自定义collate_fn来将对应的数据合并成一个batch数据,这样方便后续的训练步骤。

 文章来源地址https://www.toymoban.com/news/detail-429871.html

到了这里,关于一文弄懂Pytorch的DataLoader,Dataset,Sampler之间的关系的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 【Python从入门到人工智能】详解 PyTorch数据读取机制 DataLoader & Dataset(以人民币-RMB二分类实战 为例讲解,含完整源代码+问题解决)| 附:文心一言测试

      我想此后只要能以工作赚得生活费,不受意外的气,又有一点自己玩玩的余暇,就可以算是万分幸福了。                                                              ———《两地书》   🎯作者主页: 追光者♂🔥          🌸个人简介:

    2024年02月11日
    浏览(56)
  • 关于Dataset和DataLoader的概念

    在机器学习中,Dataset和DataLoader是两个很重要的概念,它们通常用于训练和测试模型时的数据处理。 Dataset是指用于存储和管理数据的类。在深度学习中,通常将数据存储在Dataset中,并使用Dataset提供的方法读取和处理数据。Dataset可以是各种类型的数据,例如图像、文本、音频

    2023年04月12日
    浏览(65)
  • dataset dataloader tensor list情况

    如上面代码所示,getitem期望返回一个tensor list 一个tensor,但是调用dataloader时只能接收到一个list,从打印内容中可以看到,dataloader中将getitem中返回的两个值都合并了。 从上述验证代码中可以看到,即使是相同类型的返回值,也不能分开来接收,返回的具体值为: 程序也会自

    2024年02月11日
    浏览(39)
  • K8s:一文认知 CRI,OCI,容器运行时,Pod 之间的关系

    博文内容整体结构为结合 华为云云原生课程 整理而来,部分内容做了补充 课程是免费的,有华为云账户就可以看,适合理论认知,感觉很不错。 有需要的小伙伴可以看看,链接在文末 理解不足小伙伴帮忙指正 对每个人而言,真正的职责只有一个:找到自我。然后在心中坚守

    2024年02月09日
    浏览(43)
  • 通俗易懂解释python和anaconda和pytorch以及pycharm之间的关系

    Python :Python 就像是一门编程语言的工具箱,你可以把它看作是一种通用的编程语言,就像是一把多功能的工具刀。你可以使用 Python 来编写各种类型的程序,就像使用工具刀来制作各种不同的手工艺品一样。 Anaconda :Anaconda 就像是一个装有不同种类工具的大工具箱。这个工

    2024年01月20日
    浏览(73)
  • cuda、cuDNN、深度学习框架、pytorch、tentsorflow、keras这些概念之间的关系

    当讨论CUDA、cuDNN、深度学习框架、pytorch、tensorflow、keras这些概念的时候,我们讨论的是与GPU加速深度学习相关的技术和工具。 CUDA(Compute Unified Device Architecture) : CUDA是由NVIDIA开发的一种并行计算平台和编程模型,旨在利用GPU(图形处理单元)进行通用目的的高性能计算。

    2024年02月12日
    浏览(42)
  • 【容器运行时】一文理解 OCI、runc、containerd、docker、shim进程、cri、kubelet 之间的关系

    docker,containerd,runc,docker-shim 之间的关系 Containerd shim 进程 PPID 之谜 内核大神教你从 Linux 进程的角度看 Docker RunC 简介 OCI和runC Containerd 简介 从 docker 到 runC Dockershim究竟是什么 技术干货|Docker和 Containerd 的区别,看这一篇就够了 Docker,containerd,CRI,CRI-O,OCI,runc 分不清?

    2024年02月05日
    浏览(70)
  • 一文搞懂 x64 IA-64 AMD64 Inte64 IA-32e 架构之间的关系

    想要搞清楚 x64、IA64、AMD64 指令集之间的关系,就要先了解 Intel 和 AMD 这两家公司在生产处理器上的发展历史。 1978年 Intel 生产了它的第一款 16bit 处理器8086,之后几款处理器名字也都以86结尾,包括80186,80286, 80386,80486,这些处理器的架构被统一称为 x86 架构。其中8086、

    2024年02月02日
    浏览(57)
  • PyTorch框架中torch、torchvision、torchaudio与python之间的版本对应关系(9月最新版)

    随着python语言和pytorch框架的更新,torchtorchvisiontorchaudio与python之间的版本对应关系也在不断地更新。 最新版本 torch与torchvision 对应关系如下: 稍旧版本 torch与torchvision 对应关系如下: 最新版本 torch与torchaudio 对应关系如下:

    2024年02月21日
    浏览(49)
  • pytorch中的DataLoader

    通常在训练时我们会将数据集分成若干小的、随机的批(batch),这个操作当然可以手动操作,但是pytorch里面为我们提供了API让我们方便地从dataset中获得batch,DataLoader就是来解决这个问题的。 DataLoader的本质是一个可迭代对象,即经过DataLoader的返回值为一个可迭代的对象,一

    2024年01月18日
    浏览(43)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包