pytorch 中的数据集与多进程并发

这篇具有很好参考价值的文章主要介绍了pytorch 中的数据集与多进程并发。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

总述

需要 Dataset + collate_fn + Sampler + DataLoader 联用, 才等价于 tf 的 dataset.

  • DataLoader, 对外服务的类. 通过 _get_iterator() 方法返回 iterator, 对其调用 next() 得到 tensor.
  • Sampler, 数据集的采样策略, 给出每个 step 要使用的数据的索引, 记为 possibly_batched_index.
  • Fetcher, 根据 possibly_batched_index, 从 dataset 对象中拿数据
  • collate_fn, Fetcher 对象拿到原始数据后, 调用 collate_fn 得到 tensor 对象, 送往模型.

一. Dataset

torch.utils.data.Dataset, 这是一个抽象类, 自己需要实现它的子类.

class Dataset(Generic[T_co]):
    def __getitem__(self, index) -> T_co:
        raise NotImplementedError
    def __len__(self):
        raise NotImplementedError

二. Sampler

torch.utils.data.Sampler, 也是一个抽象类.
默认的是 SequentialSampler + BatchSampler 的搭配.

class Sampler(Generic[T_co]):
    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError

class SequentialSampler(Sampler[int]):
    r"""Samples elements sequentially, always in the same order.

    Args:
        data_source (Dataset): dataset to sample from
    """
    data_source: Sized

    def __init__(self, data_source: Sized) -> None:
        self.data_source = data_source

    def __iter__(self) -> Iterator[int]:
        return iter(range(len(self.data_source)))

    def __len__(self) -> int:
        return len(self.data_source)

class BatchSampler(Sampler[List[int]]):
	def __init__(self, sampler: Union[Sampler[int], Iterable[int]], batch_size: int, drop_last: bool) -> None:
		pass

    def __iter__(self) -> Iterator[List[int]]:
	    sampler_iter = iter(self.sampler)
	       while True:
	           try:
	               batch = [next(sampler_iter) for _ in range(self.batch_size)]
	               yield batch
	           except StopIteration:
	               break

三. collate_fn

一个接口, 完成基本类型数据到 batch tensor 的处理. 方法签名见下:

  • def my_collate_fn(feature_dict_list: List[Dict[str, Union[str, Tensor]]]) -> Dict[str, Tensor]
    • feature_dict_list. 元素个数为 batch_size, 元素为 Dict[str, Any], 通常为基本数据类型.
    • return: Dict[str, Tensor], tensor_.shape[0] 通常为相应的 batch_size.

Q: 如何额外传参?
方法签名中看到没有额外的传参设计, 那么我们想传一些参数配置(比如不同的特征处理规则), 想做到通用化, 要怎么办呢?
A: 传一个 callable 对象即可. 做法为自定义 MyCollator 类, init 方法传入配置, 并实现 __call__(self, xxx)方法, 签名与 collate_fn 保持一致即可.

四. DataLoader

依赖 dataset, 负责 batch, shuffle 等能力增强, 返回是 Tensor 对象.

  • DataLoader#__init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,...,collate_fn, ...)
    • collate_fn (callable, optional): merges a list of samples to form a
      mini-batch of Tensor(s).
      如果不传, 内部会赋值为torch.utils.data._utils.collate.default_collate, 已足够好用, 见下节例子.
      在 collate_fn 中可做灵活处理, 等价于 tf.dataset.map(map_fn).
class DataLoader(Generic[T_co]):
	def __init__(self, dataset: Dataset[T_co], 
					batch_size: Optional[int] = 1,  
					num_workers: int = 0, 
					collate_fn: Optional[_collate_fn_t] = None, 
					worker_init_fn: Optional[_worker_init_fn_t] = None,
					)
        if sampler is None:  # give default samplers
            if self._dataset_kind == _DatasetKind.Iterable:
                # See NOTE [ Custom Samplers and IterableDataset ]
                sampler = _InfiniteConstantSampler()
            else:  # map-style
                if shuffle:
                    sampler = RandomSampler(dataset, generator=generator)  # type: ignore[arg-type]
                else:
                    sampler = SequentialSampler(dataset)  # type: ignore[arg-type]
        if batch_size is not None and batch_sampler is None:
            # auto_collation without custom batch_sampler
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

    def _next_index(self):
        return next(self._sampler_iter)	
        	
    def _get_iterator(self) -> '_BaseDataLoaderIter':
       if self.num_workers == 0:
           return _SingleProcessDataLoaderIter(self)
       else:
           self.check_worker_number_rationality()
           return _MultiProcessingDataLoaderIter(self)

4.1 _DatasetKind

class _DatasetKind(object):
    Map = 0
    Iterable = 1

    @staticmethod
    def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
        if kind == _DatasetKind.Map:
            return _utils.fetch._MapDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
        else:
            return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)

4.2 Fetcher

class _IterableDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_IterableDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
        self.dataset_iter = iter(dataset)
        self.ended = False

    def fetch(self, possibly_batched_index):
        if self.ended:
            raise StopIteration

        if self.auto_collation:
            data = []
            for _ in possibly_batched_index:
                try:
                    data.append(next(self.dataset_iter))
                except StopIteration:
                    self.ended = True
                    break
            if len(data) == 0 or (self.drop_last and len(data) < len(possibly_batched_index)):
                raise StopIteration
        else:
            data = next(self.dataset_iter)
        return self.collate_fn(data)

4.3 _SingleProcessDataLoaderIter

class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0

        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

    def _next_data(self):
        index = self._next_index()  # may raise StopIteration
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
        return data

五. dataloader 多进程

主进程不再从自己的 dataset 对象中拿数据, 而是全靠子进程往队列里放, 自己从队列里拿, 是一个 “多生产者, 单消费者” 数据传递.

5.1 主进程

_MultiProcessingDataLoaderIter

class _BaseDataLoaderIter(object):
	def __iter__(self) -> '_BaseDataLoaderIter':
        return self
        
    def _next_index(self):
        return next(self._sampler_iter)  # may raise StopIteration

    def _next_data(self):
        raise NotImplementedError


class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
	def __init__(self, loader):
		self._worker_init_fn = loader.worker_init_fn
		self._worker_result_queue = multiprocessing_context.Queue()
		self._index_queues = []
		self._workers = []
		for i in range(self._num_workers):
			index_queue = multiprocessing_context.Queue()
            w = multiprocessing_context.Process(
                target=_utils.worker._worker_loop,
                args=(self._dataset_kind, self._dataset, index_queue,
                      self._worker_result_queue, self._workers_done_event,
                      self._auto_collation, self._collate_fn, self._drop_last,
                      self._base_seed, self._worker_init_fn, i, self._num_workers,
                      self._persistent_workers, self._shared_seed))
            w.daemon = True
            w.start()
            self._index_queues.append(index_queue)
            self._workers.append(w)
            self._reset(loader, first_iter=True)
		
    def _next_data(self):
        while True:
        	idx, data = self._get_data()
        	return self._process_data(data)

	def _get_data(self):
	     while True:
	         success, data = self._try_get_data()
	         if success:
	             return data

    def _process_data(self, data):
        self._rcvd_idx += 1
        self._try_put_index()
        if isinstance(data, ExceptionWrapper):
            data.reraise()
        return data
        
    def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
        # Tries to fetch data from `self._data_queue` once for a given timeout.
        # This can also be used as inner loop of fetching without timeout, with
        # the sender status as the loop condition.
        #
        # This raises a `RuntimeError` if any worker died expectedly. This error
        # can come from either the SIGCHLD handler in `_utils/signal_handling.py`
        # (only for non-Windows platforms), or the manual check below on errors
        # and timeouts.
        #
        # Returns a 2-tuple:
        #   (bool: whether successfully get data, any: data if successful else None)
        try:
            data = self._data_queue.get(timeout=timeout)
            return (True, data)

5.2 子进程

子进程的 target 逻辑在 worker.py 中.
dataset 对象由主进程序列化后通过管道传递给子进程, 然后反序列化处理.

def _worker_loop(dataset_kind, dataset, index_queue, data_queue, done_event,
                 auto_collation, collate_fn, drop_last, base_seed, init_fn, worker_id,
                 num_workers, persistent_workers, shared_seed):
    try:
        global _worker_info
        _worker_info = WorkerInfo(id=worker_id, num_workers=num_workers,
                                    seed=seed, dataset=dataset)
        if init_fn is not None:
            init_fn(worker_id)
            
        fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
        watchdog = ManagerWatchdog()
        while watchdog.is_alive():
            r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
            data = fetcher.fetch(index)
            data_queue.put((idx, data))
            del data, idx, index, r  # save memory
    except KeyboardInterrupt:
        # Main process will raise KeyboardInterrupt anyways.
        pass

reader 句柄处理

如果是流式数据集, reader 等句柄需要延迟打开, 否则会因被支持序列化而报错.

数据切片与 WorkerInfo

dataset 对象从主进程序列化而来, 直接用的话, 读的数据必然是重复的.
子进程启动后, 会维护全局变量 _worker_info, _DatasetKind.create_fetcher()方法内会触发 iter(dataset) , 需要 dataset 通过 torch.utils.data.get_worker_info() 主动感知子进程的环境, 调整数据分片.

六. IterableDataset 流式数据集

详见 参考[1].
典型场景是内存盛不下, 网络数据库 -> dataset -> model feed 流式运作.
例子见下.

import numpy as np
from torch.utils.data import IterableDataset, DataLoader


class StreamingDataset(IterableDataset):

    def generator(self):
        i = 0
        data = np.arange(0,10).reshape((5, 2))
        while True:
            if i == len(data):
                break
            yield {'sample_id': i, 'value': data[i]}
            i += 1

    def __iter__(self):
        return iter(self.generator())


def dataset_test():
    it = iter(StreamingDataset())
    print(next(it))
    print(next(it))


def loader_test():
    loader = DataLoader(StreamingDataset(), batch_size=2)
    it = iter(loader)
    print(next(it), next(it)) 

if __name__ == '__main__':
    loader_test()

"""
dataset_test()
{'sample_id': 0, 'value': array([0, 1])}
{'sample_id': 1, 'value': array([2, 3])}


loader_test()
{'sample_id': tensor([0, 1]), 'value': tensor([[0, 1],
        [2, 3]], dtype=torch.int32)}
{'sample_id': tensor([2, 3]), 'value': tensor([[4, 5],
        [6, 7]], dtype=torch.int32)}
"""

参考

todo文章来源地址https://www.toymoban.com/news/detail-501451.html

到了这里,关于pytorch 中的数据集与多进程并发的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 分布式集群与多线程高并发

      后台数据的处理语言有很多,Java 是对前端采集的数据的一种比较常见的开发语言。互联网移动客户端的用户量特别大,大量的数据处理需求应运而生。可移动嵌入式设备的表现形式   很多,如 PC 端,手机移动端,智能手表,Google  眼镜等。Server2client 的互联网开发模式比

    2024年02月08日
    浏览(48)
  • C++11并发与多线程笔记 (1)

    指在一个时间段内有多个进程在执行 两个或者更多的任务(独立的活动)同时发生(进行):一个程序同时执行多个独立的任务; 以往计算机,单核cpu(中央处理器):某一个时刻只能执行一个任务,由操作系统调度,每秒钟进行多次所谓的“任务切换”。并发的假象( 不

    2024年02月12日
    浏览(45)
  • python多进程与多线程

    1.1 GIL 全局解释器锁 其他语言,CPU是多核时是支持多个线程同时执行。但在Python中,无论是单核还是多核,同时只能由一个线程在执行。其根源是GIL的存在。GIL的全称是Global Interpreter Lock(全局解释器锁),来源是Python设计之初的考虑,为了数据安全所做的决定。某个线程想要执

    2024年02月05日
    浏览(41)
  • 一文掌握Python多线程与多进程

    并发是今天计算机编程中的一项重要能力,尤其是在面对需要大量计算或I/O操作的任务时。Python 提供了多种并发的处理方式,本篇文章将深入探讨其中的两种:多线程与多进程,解析其使用场景、优点、缺点,并结合代码例子深入解读。 Python中的线程是利用 threading 模块实现

    2024年02月09日
    浏览(45)
  • C++11并发与多线程笔记(6) unique_lock(类模板)

    unique_lock 是一个类模板。 unique_lock 比 lock_guard 灵活很多 ( 多出来很多用法 ),效率差一点,内存占用多一些。 使用: unique_lockmutex myUniLock(myMutex); std::adopt_lock:标记作用,表示这个互斥量已经被lock()(方便记忆:已经被lock()收养了,不需要再次lock() ),即 不需要在构造函

    2024年02月12日
    浏览(43)
  • 【神行百里】python开启多线程(threading)与多进程(multiprocessing)运行

      由于处理数据过多,程序运行很慢,就学习了一下python开启多线程与多进程的方法,虽然最后也没用上,但还是记录总结一下,以备不时之需。   传送门:进程与线程认识,进程与线程通俗理解   简言之, 进程为资源分配的最小单元,线程为程序执行的最小单元

    2024年02月02日
    浏览(43)
  • C++并发与多线程笔记八:async、future、packaged_task、promise

    本文接上文 C++并发与多线程笔记七:condition_variable、wait、notify_one/all 的内容,主要记录 async、future、packaged_task、promise 概念以及用法。 2.1 基本用法 std::async 是个函数模板,用来启动一个异步任务,启动一个异步任务后,它返回一个 std::future 类模板对象。 上述\\\"启动一个异步

    2023年04月13日
    浏览(46)
  • 【pytorch实用小技巧】单gpu与多gpu训练与评估

    1、单gpu 首先检查GPU是否可用,并将模型、输入数据和目标标签移动到GPU上。 然后,定义损失函数和优化器。在训练循环中,将模型设置为训练模式,进行前向传播、计算损失、反向传播和参数更新。 在测试阶段,将模型设置为评估模式,并在测试数据上进行推断。 2、多

    2024年02月12日
    浏览(36)
  • C++11并发与多线程笔记(9) async、future、packaged_task、promise

    std::async : 是一个函数模板,用来启动一个异步任务,启动起来一个异步任务之后,它返回一个std::future对象,这个对象是个类模板。 什么叫“ 启动一个异步任务 ”?就是自动创建一个线程,并开始执行对应的线程入口函数,它返回一个std::future对象,这个std::future对象中就

    2024年02月12日
    浏览(41)
  • C++11并发与多线程笔记(10) future其他成员函数、shared_future、atomic

    status = result.wait_for(std::chrono::seconds(几秒)); 卡住当前流程,等待std::async()的异步任务运 行一段时间,然后返回其状态std::future_status 。如果std::async()的参数是std::launch::deferred(延迟执行),则不会卡住主流程。 std::future_status是枚举类型,表示异步任务的执行状态。类型的取值

    2024年02月12日
    浏览(40)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包