【代码笔记】Pytorch学习 DataLoader模块详解

这篇具有很好参考价值的文章主要介绍了【代码笔记】Pytorch学习 DataLoader模块详解。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

dataloader整体结构

dataloader主要有6个class构成(可见下图)

  • _DatasetKind:
  • _InfiniteConstantSampler:
  • DataLoader:
  • _BaseDataLoaderIter:
  • _SingleProcessDataLoaderIter:
  • _MultiProcessingDataLoaderIter:
    【代码笔记】Pytorch学习 DataLoader模块详解

DataLoader

我们首先看一下DataLoader的整体结构:

  • init:
  • _get_iterator:
  • multiprocessing_context:
  • multiprocessing_context:
  • setattr:
  • iter:
  • _auto_collation:
  • _index_sampler:
  • len:
  • check_worker_number_rationality:

init 初始化

参数解释

这里会把参数全部列出,这里列出的目的是让大家知道各个参数的意义。实际上很多是用不到的,我用加粗字体表示一些常用的参数。

  • self:代之Dataset这个类本身
  • dataset: Dataset[T_co]是默认值,是你要处理的数据集
  • batch_size: Optional[int] = 1, 可选,默认是1。每个batch可以加载batct_size个数据。
  • shuffle: bool = False, 每轮训练后是否将数据集打乱
  • sampler: Optional[Sampler] = None, 默认是None 自定义方法(某种顺序)从Dataset中取样本,指定这个参数就不能设置shuffle。因为shuffle是打乱数据集的顺序,而sample是以某种顺序取数据,所以二者互斥!sampler可能是获取一整个数据集的数据,是对一整个数据集进行操作,而不是一个batch_size。
  • batch_sampler: Optional[Sampler[Sequence]] = None, 返回一个batch的索引,与batch_size, shuffle, sampler, drop_last互斥
    传入了batch_sampler,相当于已经告诉了PyTorch如何从Dataset取多少数据,怎么取数据去组成一个mini batch,所以不需要以上参数。可以理解为batch_sampler是batch_size和sampler的结合,所以不需要batch_size, sampler, shuffle, drop_last(因为drop_last也是怎么取数据)。
  • num_workers: int = 0, 多进程加载数据,默认为0,即采用主进程加载数据
  • collate_fn: Optional[_collate_fn_t] = None, 聚集函数,用来对一个batch进行后处理,拿到一个batch的数据后进行什么处理,返回处理后的batch数据。默认源码中进行了若干逻辑判断,仅将数据组合起来返回,没有实质性工作。默认collate_fn的声明是:def default_collate(batch): 所以自定义collate_fn需要以batch为输入,以处理后的batch为输出。类似于transform,transform是对单个数据处理,而collate_fn是对单个batch做处理。
  • pin_memory: bool = False, 用于将tensor加载到GPU中进行运算
  • drop_last: bool = False, 是否保存最后一个mini batch,样本数量可能不支持被batch size整除,所以drop_last参数决定是否保留最后一个可能批量较小的batch
  • timeout: float = 0, 控制从进程中获取一个batch数据的时延
  • worker_init_fn: Optional[_worker_init_fn_t] = None, 初始化子进程
  • multiprocessing_context=None,
  • generator=None,
  • prefetch_factor: int = 2, 控制样本在每个进程里的预加载,默认为2
  • persistent_workers: bool = False 控制加载完一次Dataset是否保留进程,默认为False
def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
                 shuffle: bool = False, sampler: Optional[Sampler] = None,
                 batch_sampler: Optional[Sampler[Sequence]] = None,
                 num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None,
                 pin_memory: bool = False, drop_last: bool = False,
                 timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None,
                 multiprocessing_context=None, generator=None,
                 *, prefetch_factor: int = 2,
                 persistent_workers: bool = False):

代码解析

在DataLoader的__init__函数里,我们可以看到,它实现了:

  1. 判断是否是IterableDataset类型,如果是需要进一步判断参数是否正确
  2. 构建Sampler,单样本
  3. 构建BatchSampler,
  4. 组建batch 构建collate
  5. 其他的一些逻辑判断
IterableDataset 判断
  • IterableDataset应用于数据集非常大,将其完全加载进内存不现实(例如高达几个TB的数据),这时就需要IterableDataset构建可迭代的Dataset类,自定义的Dataset需要继承自torch.util.data.IterableDataset,重写__iter__方法,返回可迭代对象(通常是yield生成器)
  • 对于IterableDataset来说,就没有构建采样器Sampler的需求,因为样本是通过调用__iter__一个个读取出来的。执行封装的DataLoader传进去的batch_size次__iter__方法,就获取到一个mini batch
# 判断dataset是否是IterableDataset类型
 if isinstance(dataset, IterableDataset):
     self._dataset_kind = _DatasetKind.Iterable
     # 按照__iter__获取数据,所以不需要打乱
     if shuffle is not False:
         raise ValueError(
             "DataLoader with IterableDataset: expected unspecified ""shuffle option, but got shuffle={}".format(shuffle))
     elif sampler is not None:
         # 按照__iter__获取数据,也不再需要sampler获取数据
         raise ValueError("DataLoader with IterableDataset: expected unspecified ""sampler option, but got sampler={}".format(sampler))
     elif batch_sampler is not None:
         # 按照__iter__获取数据,也不再需要batch_sampler获取数据索引
         raise ValueError("DataLoader with IterableDataset: expected unspecified " "batch_sampler option, but got batch_sampler={}".format(batch_sampler))
 else:
     self._dataset_kind = _DatasetKind.Map

构建Sampler,单样本
if sampler is None:  # give default samplers
     if self._dataset_kind == _DatasetKind.Iterable:
         # 如果是Iterable的Dataset,就采用迭代的方式获取sampler
         sampler = _InfiniteConstantSampler()
     else:  # 否则判断是否使用shuffle,使用则随机产生sampler,不使用就按照顺序产生sampler
         if shuffle:
             sampler = RandomSampler(dataset, generator=generator)
         else:
             sampler = SequentialSampler(dataset)
构建BatchSampler,组建batch
  • 注意,上面说batch_sampler不能和batch_size、sampler、drop_last同时使用是指:如果已经定义了batch_sampler则与batch_size和sampler互斥!!!前提是已经定义了batch_sampler!!!但是如果没有定义batch_sampler,则可以通过batch_size,sampler,dorp_last来组建batch!!!
# 要取batch_size个sampler,但是还没有取,即batch_sampler==None
if batch_size is not None and batch_sampler is None:
     # 获取batch_size个sampler个索引
     batch_sampler = BatchSampler(sampler, batch_size, drop_last)
构建collate_fn 对获取的batch进行处理
if collate_fn is None:
     if self._auto_collation:
     # 默认的实际上什么也没干
         collate_fn = _utils.collate.default_collate
     else:
         collate_fn = _utils.collate.default_convert
其他的一些逻辑判断
# sampler 不能和 shuffle 同时出现
# 因为shuffle是将数据打乱,而sampler是按照某一顺序获取数据
 if sampler is not None and shuffle:
     raise ValueError('sampler option is mutually exclusive with ''shuffle')

 if batch_sampler is not None:
     # batch_sampler不能和batch_size,shuffle,sampler,drop_last同时使用。
     # batch_sampler可以理解为batch_size和sampler的结合
     if batch_size != 1 or shuffle or sampler is not None or drop_last:
         raise ValueError('batch_sampler option is mutually exclusive ''with batch_size, shuffle, sampler, and '                          'drop_last')
     batch_size = None
     drop_last = False
 elif batch_size is None:
     # batch_size为None,默认是1,如果drop_last为True就会舍弃最后一个,这样数据就会减少。(构成了一个batch但是仍然舍弃掉)
     if drop_last:
         raise ValueError('batch_size=None option disables auto-batching ''and is mutually exclusive with drop_last')
 
 self.collate_fn = collate_fn
 self.persistent_workers = persistent_workers

 self.__initialized = True
 self._IterableDataset_len_called = None  # See NOTE [ IterableDataset and __len__ ]

 self._iterator = None

 self.check_worker_number_rationality()

 torch.set_vital('Dataloader', 'enabled', 'True')  # type: ignore[attr-defined]

_get_iterator

代码解析

def _get_iterator(self) -> '_BaseDataLoaderIter':
    if self.num_workers == 0:
    # 单线程
        return _SingleProcessDataLoaderIter(self)
    else:
    # 多线程
        self.check_worker_number_rationality()
        return _MultiProcessingDataLoaderIter(self)

multiprocessing_context

multiprocessing_context

setattr

iter

代码解释

 # 其中 -> '_BaseDataLoaderIter' 是函数注释,运行时跟没有加注解之前的效果也没有任何差距。
 # 主要作用是提醒程序猿这里应该是 '_BaseDataLoaderIter'的数据类型
 def __iter__(self) -> '_BaseDataLoaderIter':
        if self.persistent_workers and self.num_workers > 0:
            if self._iterator is None:
                self._iterator = self._get_iterator()
            else:
                self._iterator._reset(self)
            return self._iterator
        else:
            return self._get_iterator()

_auto_collation

代码解析

	@property
    def _auto_collation(self):
    # 根据batch_sampler判断是否设置_auto_collation
        return self.batch_sampler is not None

_index_sampler

len

check_worker_number_rationality

_SingleProcessDataLoaderIter

代码解析

def __init__(self, loader):
    super(_SingleProcessDataLoaderIter, self).__init__(loader)
    assert self._timeout == 0
    assert self._num_workers == 0

    self._dataset_fetcher = _DatasetKind.create_fetcher(
        self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

 def _next_data(self):
 	 # 获取索引
     index = self._next_index()  # may raise StopIteration
     # 获取数据
     data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
     if self._pin_memory:
         data = _utils.pin_memory.pin_memory(data)
     # 返回数据
     return data

_BaseDataLoaderIter

__next__方法会调用_next_data,_next_data获取一个batch的数据文章来源地址https://www.toymoban.com/news/detail-410386.html

到了这里,关于【代码笔记】Pytorch学习 DataLoader模块详解的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • Pytorch的torch.utils.data中Dataset以及DataLoader等详解

    在我们进行深度学习的过程中,不免要用到数据集,那么数据集是如何加载到我们的模型中进行训练的呢?以往我们大多数初学者肯定都是拿网上的代码直接用,但是它底层的原理到底是什么还是不太清楚。所以今天就从内置的Dataset函数和自定义的Dataset函数做一个详细的解

    2024年02月11日
    浏览(51)
  • pytorch进阶学习(二):使用DataLoader读取自己的数据集

    上一节使用的是官方数据集fashionminist进行训练,这节课使用自己搜集的数据集来进行数据的获取和训练。 教学视频:https://www.bilibili.com/video/BV1by4y1b7hX/?spm_id_from=333.1007.top_right_bar_window_history.content.clickvd_source=e482aea0f5ebf492c0b0220fb64f98d3 pytorch进阶学习(一):https://blog.csdn.net/w

    2024年02月09日
    浏览(44)
  • 【图像分割】【深度学习】SAM官方Pytorch代码-各模块的功能解析

    Segment Anything:建立了迄今为止最大的分割数据集,在1100万张图像上有超过1亿个掩码,模型的设计和训练是灵活的,其重要的特点是Zero-shot(零样本迁移性)转移到新的图像分布和任务,一个图像分割新的任务、模型和数据集。SAM由三个部分组成:一个强大的图像编码器(Image

    2024年02月11日
    浏览(39)
  • pytorch中的DataLoader

    通常在训练时我们会将数据集分成若干小的、随机的批(batch),这个操作当然可以手动操作,但是pytorch里面为我们提供了API让我们方便地从dataset中获得batch,DataLoader就是来解决这个问题的。 DataLoader的本质是一个可迭代对象,即经过DataLoader的返回值为一个可迭代的对象,一

    2024年01月18日
    浏览(42)
  • Apollo规划模块代码学习(1): 算法架构原理、运行机制一文详解

    Apollo开源自动驾驶平台中,高清地图模块提供了每个在线模块都可以访问的高清地图。感知和定位模块提供了必要的动态环境信息,可以在预测模块中进一步用于预测未来的环境状态。运动规划模块考虑所有信息,以生成安全平滑的轨迹,并将其输入车辆控制模块。 目前Ap

    2024年01月25日
    浏览(47)
  • DataLoader PyTorch 主要参数的含义

    定义: DataLoader类是一个用于从数据集(dataset)中加载数据,并以迭代器(iterator)的形式返回数据样本(data samples)的工具¹²。您给出的两个字典(dictionary)分别是训练集(train set)和测试集(test set)的数据加载参数,下面我会逐一解释它们的含义和默认值:   举例演示

    2024年02月11日
    浏览(41)
  • FPGA学习笔记:verilog基础代码与modelsim仿真(六)——vga显示模块

    VGA显示 目标:实现屏幕红、橙、黄、绿、青、蓝、紫、黑、白、灰条形显示 1. 模块框图与波形图 vga_colorbar是实现目标功能的总体模块框图,为了实现对应的输出,我们使用三个具体功能模块实现功能。 (1) clk_gen——使用pll锁相环实现时钟分频 (2)vga_ctrl——图像控制与输出模

    2024年02月04日
    浏览(42)
  • pytorch实战5——DataLoader数据集制作

    目录 1.如何自定义数据集: 咱们以花朵数据集为例: 任务1:读取txt文件中的路径和标签 任务2:通过上面字典返回数据,分别把数据和标签都存在list里 任务3:图像数据路径得完整 任务4:把上面那几个事得写在一起,整合到一个类。 任务5:数据预处理(transform)¶ 任务6:根据

    2024年02月04日
    浏览(42)
  • 解决pytorch中Dataloader读取数据太慢的问题

    提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 最近在使用pytorch框架进行模型训练时遇到一个性能问题,即数据读取的速度远远大于GPU训练的速度,导致整个训练流程中有大部分时间都在等待数据发送到GPU,在资源管理器中呈现出CUDA使用率周期性波

    2023年04月11日
    浏览(51)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包