Overview
Dataset + collate_fn + Sampler + DataLoader have to be used together; only the combination is equivalent to tf's dataset.
- DataLoader: the class that serves the outside world. Its `_get_iterator()` method returns an iterator; calling `next()` on it yields tensors.
- Sampler: the sampling strategy over the dataset; for each step it gives the indices of the data to use, denoted `possibly_batched_index`.
- Fetcher: fetches the raw data from the dataset object according to `possibly_batched_index`.
- collate_fn: once the Fetcher has the raw data, collate_fn is called on it to produce the tensor objects that are fed to the model.

Conceptually, the cooperation of these four parts reduces to the loop sketched below.
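This is a mental model only (map-style case), not the actual PyTorch control flow:

```python
def conceptual_loader(dataset, batch_sampler, collate_fn):
    # Mental model only -- not the real PyTorch implementation.
    for possibly_batched_index in batch_sampler:                # Sampler: which samples this step
        samples = [dataset[i] for i in possibly_batched_index]  # Fetcher: raw samples
        yield collate_fn(samples)                               # collate_fn: samples -> batch tensors
```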
1. Dataset
`torch.utils.data.Dataset` is an abstract class; you implement your own subclass.
```python
class Dataset(Generic[T_co]):
    def __getitem__(self, index) -> T_co:
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError
```
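A minimal map-style subclass might look like this (the `text`/`label` fields are made up for illustration):

```python
from typing import Dict, Any
from torch.utils.data import Dataset

class MyMapDataset(Dataset):
    """Minimal map-style dataset; field names are illustrative."""
    def __init__(self, texts, labels):
        self.texts = texts
        self.labels = labels

    def __getitem__(self, index) -> Dict[str, Any]:
        # Return one raw sample; batching/collation happens later.
        return {"text": self.texts[index], "label": self.labels[index]}

    def __len__(self):
        return len(self.texts)
```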
2. Sampler
`torch.utils.data.Sampler` is also an abstract class.
The default pairing is SequentialSampler + BatchSampler.
```python
class Sampler(Generic[T_co]):
    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError

class SequentialSampler(Sampler[int]):
    r"""Samples elements sequentially, always in the same order.

    Args:
        data_source (Dataset): dataset to sample from
    """
    data_source: Sized

    def __init__(self, data_source: Sized) -> None:
        self.data_source = data_source

    def __iter__(self) -> Iterator[int]:
        return iter(range(len(self.data_source)))

    def __len__(self) -> int:
        return len(self.data_source)

class BatchSampler(Sampler[List[int]]):
    def __init__(self, sampler: Union[Sampler[int], Iterable[int]], batch_size: int, drop_last: bool) -> None:
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self) -> Iterator[List[int]]:
        # Simplified: this is essentially the drop_last=True branch of the real source.
        sampler_iter = iter(self.sampler)
        while True:
            try:
                batch = [next(sampler_iter) for _ in range(self.batch_size)]
                yield batch
            except StopIteration:
                break
```
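A quick demo of the pairing (expected output in the comments):

```python
from torch.utils.data import SequentialSampler, BatchSampler

sampler = SequentialSampler(range(7))
print(list(BatchSampler(sampler, batch_size=3, drop_last=False)))
# [[0, 1, 2], [3, 4, 5], [6]]
print(list(BatchSampler(sampler, batch_size=3, drop_last=True)))
# [[0, 1, 2], [3, 4, 5]]
```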
3. collate_fn
An interface that turns samples of basic types into a batch of tensors. The signature:
- `def my_collate_fn(feature_dict_list: List[Dict[str, Union[str, Tensor]]]) -> Dict[str, Tensor]`
  - feature_dict_list: batch_size elements; each element is a Dict[str, Any], usually holding basic data types.
  - return: Dict[str, Tensor]; each tensor's shape[0] is normally the corresponding batch_size.
Q: How do we pass extra arguments?
The signature leaves no room for extra parameters, so if we want to pass in some configuration (say, different feature-processing rules) and keep things generic, what do we do?
A: Pass a callable object. Define a custom MyCollator class whose `__init__` takes the configuration and which implements `__call__(self, ...)` with the same signature as collate_fn, as sketched below.
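A minimal sketch of such a collator (the `label_dtype` knob is a made-up example of configuration):

```python
from typing import Any, Dict, List
import torch

class MyCollator:
    """Callable collator: configuration is captured in __init__."""
    def __init__(self, label_dtype: torch.dtype = torch.long):
        self.label_dtype = label_dtype  # illustrative config knob

    def __call__(self, feature_dict_list: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        # Same signature as collate_fn: a list of raw samples in, a batch dict out.
        return {"label": torch.tensor([d["label"] for d in feature_dict_list],
                                      dtype=self.label_dtype)}

# Usage: DataLoader(dataset, batch_size=8, collate_fn=MyCollator(torch.float32))
```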
4. DataLoader
Wraps a dataset and layers on batching, shuffling, and similar capabilities; what it hands back are Tensor objects.
- `DataLoader#__init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1, ..., collate_fn, ...)`
- collate_fn (callable, optional): merges a list of samples to form a mini-batch of Tensor(s). If not passed, it is set internally to `torch.utils.data._utils.collate.default_collate`, which is already good enough for most cases; see the example in section 6. Arbitrary processing can be done inside collate_fn, making it the equivalent of tf.dataset.map(map_fn).
```python
class DataLoader(Generic[T_co]):
    def __init__(self, dataset: Dataset[T_co],
                 batch_size: Optional[int] = 1,
                 num_workers: int = 0,
                 collate_fn: Optional[_collate_fn_t] = None,
                 worker_init_fn: Optional[_worker_init_fn_t] = None,
                 ):
        ...
        if sampler is None:  # give default samplers
            if self._dataset_kind == _DatasetKind.Iterable:
                # See NOTE [ Custom Samplers and IterableDataset ]
                sampler = _InfiniteConstantSampler()
            else:  # map-style
                if shuffle:
                    sampler = RandomSampler(dataset, generator=generator)  # type: ignore[arg-type]
                else:
                    sampler = SequentialSampler(dataset)  # type: ignore[arg-type]

        if batch_size is not None and batch_sampler is None:
            # auto_collation without custom batch_sampler
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

    def _next_index(self):
        return next(self._sampler_iter)

    def _get_iterator(self) -> '_BaseDataLoaderIter':
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            self.check_worker_number_rationality()
            return _MultiProcessingDataLoaderIter(self)
```
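A quick check of the default-sampler logic above (map-style dataset, no custom sampler passed):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(10))
print(type(DataLoader(ds).sampler))                # SequentialSampler
print(type(DataLoader(ds, shuffle=True).sampler))  # RandomSampler
```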
4.1 _DatasetKind
```python
class _DatasetKind(object):
    Map = 0
    Iterable = 1

    @staticmethod
    def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
        if kind == _DatasetKind.Map:
            return _utils.fetch._MapDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
        else:
            return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
```
4.2 Fetcher
```python
class _IterableDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_IterableDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
        self.dataset_iter = iter(dataset)
        self.ended = False

    def fetch(self, possibly_batched_index):
        if self.ended:
            raise StopIteration

        if self.auto_collation:
            data = []
            for _ in possibly_batched_index:
                try:
                    data.append(next(self.dataset_iter))
                except StopIteration:
                    self.ended = True
                    break
            if len(data) == 0 or (self.drop_last and len(data) < len(possibly_batched_index)):
                raise StopIteration
        else:
            data = next(self.dataset_iter)

        return self.collate_fn(data)
```
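For comparison, the map-style counterpart is much simpler: it just indexes into the dataset. Roughly, simplified from the same `_utils/fetch.py` (newer versions add a `__getitems__` fast path):

```python
class _MapDatasetFetcher(_BaseDatasetFetcher):
    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            # Batched indices: pull each sample via Dataset.__getitem__.
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)
```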
4.3 _SingleProcessDataLoaderIter
```python
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0
        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

    def _next_data(self):
        index = self._next_index()  # may raise StopIteration
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
        return data
```
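Putting sections 1-4 together: a single-process run (num_workers=0) exercises exactly the path above, Sampler -> Fetcher -> collate_fn. A minimal, self-contained demo (expected output in the comments):

```python
import torch
from torch.utils.data import DataLoader, Dataset

class TinyDataset(Dataset):
    def __len__(self):
        return 5
    def __getitem__(self, i):
        return {"x": i, "y": 2 * i}

# num_workers=0 -> _SingleProcessDataLoaderIter; default_collate batches the dicts.
loader = DataLoader(TinyDataset(), batch_size=2)
for batch in loader:
    print(batch)
# {'x': tensor([0, 1]), 'y': tensor([0, 2])}
# {'x': tensor([2, 3]), 'y': tensor([4, 6])}
# {'x': tensor([4]), 'y': tensor([8])}
```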
5. DataLoader multiprocessing
The main process no longer pulls data from its own dataset object; instead, worker subprocesses put data into a queue and the main process takes it off the queue: a "multi-producer, single-consumer" handoff.
5.1 Main process
The relevant pieces of `_MultiProcessingDataLoaderIter` (and its `_BaseDataLoaderIter` base):

```python
class _BaseDataLoaderIter(object):
    def __iter__(self) -> '_BaseDataLoaderIter':
        return self

    def _next_index(self):
        return next(self._sampler_iter)  # may raise StopIteration

    def _next_data(self):
        raise NotImplementedError

class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        self._worker_init_fn = loader.worker_init_fn
        self._worker_result_queue = multiprocessing_context.Queue()
        self._index_queues = []
        self._workers = []
        for i in range(self._num_workers):
            index_queue = multiprocessing_context.Queue()
            w = multiprocessing_context.Process(
                target=_utils.worker._worker_loop,
                args=(self._dataset_kind, self._dataset, index_queue,
                      self._worker_result_queue, self._workers_done_event,
                      self._auto_collation, self._collate_fn, self._drop_last,
                      self._base_seed, self._worker_init_fn, i, self._num_workers,
                      self._persistent_workers, self._shared_seed))
            w.daemon = True
            w.start()
            self._index_queues.append(index_queue)
            self._workers.append(w)
        self._reset(loader, first_iter=True)

    def _next_data(self):
        # Simplified: the real code also reorders out-of-order results
        # and tracks worker status.
        while True:
            idx, data = self._get_data()
            return self._process_data(data)

    def _get_data(self):
        while True:
            success, data = self._try_get_data()
            if success:
                return data

    def _process_data(self, data):
        self._rcvd_idx += 1
        self._try_put_index()
        if isinstance(data, ExceptionWrapper):
            data.reraise()
        return data

    def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
        # Tries to fetch data from `self._data_queue` once for a given timeout.
        # This can also be used as inner loop of fetching without timeout, with
        # the sender status as the loop condition.
        #
        # This raises a `RuntimeError` if any worker died unexpectedly. This error
        # can come from either the SIGCHLD handler in `_utils/signal_handling.py`
        # (only for non-Windows platforms), or the manual check below on errors
        # and timeouts.
        #
        # Returns a 2-tuple:
        #   (bool: whether successfully get data, any: data if successful else None)
        try:
            data = self._data_queue.get(timeout=timeout)
            return (True, data)
        except Exception:
            ...  # error handling / failed-worker checks elided
```
5.2 Worker processes
The worker's target logic lives in worker.py.
The dataset object is serialized (pickled) by the main process, handed to the worker through a pipe, and deserialized there.
```python
def _worker_loop(dataset_kind, dataset, index_queue, data_queue, done_event,
                 auto_collation, collate_fn, drop_last, base_seed, init_fn, worker_id,
                 num_workers, persistent_workers, shared_seed):
    try:
        seed = base_seed + worker_id
        global _worker_info
        _worker_info = WorkerInfo(id=worker_id, num_workers=num_workers,
                                  seed=seed, dataset=dataset)
        if init_fn is not None:
            init_fn(worker_id)
        fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)

        watchdog = ManagerWatchdog()
        while watchdog.is_alive():
            r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
            idx, index = r  # simplified; control messages elided
            data = fetcher.fetch(index)
            data_queue.put((idx, data))
            del data, idx, index, r  # save memory
    except KeyboardInterrupt:
        # Main process will raise KeyboardInterrupt anyways.
        pass
```
Handling reader handles
For a streaming dataset, handles such as file or database readers must be opened lazily; opening them eagerly fails because such handles cannot be pickled along with the dataset, as shown below.
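A sketch of the lazy-open pattern (the JSONL file here is just an example):

```python
import json
from torch.utils.data import IterableDataset

class JsonlDataset(IterableDataset):
    def __init__(self, path: str):
        self.path = path  # store only the path: strings pickle fine
        # NOT: self.reader = open(path)  # an open file handle cannot be pickled

    def __iter__(self):
        # Runs inside the worker process, after unpickling: safe to open here.
        with open(self.path) as reader:
            for line in reader:
                yield json.loads(line)
```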
数据切片与 WorkerInfo
dataset 对象从主进程序列化而来, 直接用的话, 读的数据必然是重复的.
子进程启动后, 会维护全局变量 _worker_info, _DatasetKind.create_fetcher()
方法内会触发 iter(dataset)
, 需要 dataset 通过 torch.utils.data.get_worker_info()
主动感知子进程的环境, 调整数据分片.
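A minimal sharding sketch using `itertools.islice` (round-robin by worker id; the sharding strategy itself is up to you):

```python
import itertools
from torch.utils.data import IterableDataset, get_worker_info

class ShardedStream(IterableDataset):
    def __init__(self, n: int = 10):
        self.n = n

    def __iter__(self):
        worker_info = get_worker_info()
        if worker_info is None:  # single-process loading: yield everything
            return iter(range(self.n))
        # Each worker takes every num_workers-th sample, offset by its id.
        return itertools.islice(range(self.n), worker_info.id, None, worker_info.num_workers)
```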
6. IterableDataset: streaming datasets
See reference [1] for details.
The typical scenario is data that does not fit in memory, running as a stream: network/database -> dataset -> model feed.
An example follows.
```python
import numpy as np
from torch.utils.data import IterableDataset, DataLoader

class StreamingDataset(IterableDataset):
    def generator(self):
        i = 0
        data = np.arange(0, 10).reshape((5, 2))
        while True:
            if i == len(data):
                break
            yield {'sample_id': i, 'value': data[i]}
            i += 1

    def __iter__(self):
        return iter(self.generator())

def dataset_test():
    it = iter(StreamingDataset())
    print(next(it))
    print(next(it))

def loader_test():
    loader = DataLoader(StreamingDataset(), batch_size=2)
    it = iter(loader)
    print(next(it), next(it))

if __name__ == '__main__':
    loader_test()

"""
dataset_test()
{'sample_id': 0, 'value': array([0, 1])}
{'sample_id': 1, 'value': array([2, 3])}

loader_test()
{'sample_id': tensor([0, 1]), 'value': tensor([[0, 1],
        [2, 3]], dtype=torch.int32)}
{'sample_id': tensor([2, 3]), 'value': tensor([[4, 5],
        [6, 7]], dtype=torch.int32)}
"""
```
References
[1] todo