mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: This is a modified version of https://github.com/pytorch/pytorch/pull/14705 since commit structure for that PR is quite messy. 1. Add `IterableDataset`. 3. So we have 2 data loader mods: `Iterable` and `Map`. 1. `Iterable` if the `dataset` is an instance of `IterableDataset` 2. `Map` o.w. 3. Add better support for non-batch loading (i.e., `batch_size=None` and `batch_sampler=None`). This is useful in doing things like bulk loading. 3. Refactor `DataLoaderIter` into two classes, `_SingleProcessDataLoaderIter` and `_MultiProcessingDataLoaderIter`. Rename some methods to be more generic, e.g., `get_batch` -> `get_data`. 4. Add `torch.utils.data.get_worker_info` which returns worker information in a worker proc (e.g., worker id, dataset obj copy, etc.) and can be used in `IterableDataset.__iter__` and `worker_init_fn` to do per-worker configuration. 5. Add `ChainDataset`, which is the analog of `ConcatDataset` for `IterableDataset`. 7. Import torch.utils.data in `torch/__init__.py` 9. data loader examples and documentations 10. Use `get_worker_info` to detect whether we are in a worker process in `default_collate` Closes https://github.com/pytorch/pytorch/issues/17909, https://github.com/pytorch/pytorch/issues/18096, https://github.com/pytorch/pytorch/issues/19946, and some of https://github.com/pytorch/pytorch/issues/13023 Pull Request resolved: https://github.com/pytorch/pytorch/pull/19228 Reviewed By: bddppq Differential Revision: D15058152 fbshipit-source-id: 9e081a901a071d7e4502b88054a34b450ab5ddde
83 lines
3.3 KiB
Python
r"""Contains definitions of the methods used by the _DataLoaderIter workers to
collate samples fetched from dataset into Tensor(s).

These **need** to be in global scope since Py2 doesn't support serializing
static methods.
"""
|
|
|
|
import torch
|
|
import re
|
|
from torch._six import container_abcs, string_classes, int_classes
|
|
|
|
np_str_obj_array_pattern = re.compile(r'[SaUO]')
|
|
|
|
|
|
def default_convert(data):
    r"""Recursively converts each NumPy array field in ``data`` to a tensor.

    Tensors are passed through unchanged.  NumPy scalars and numeric arrays
    are converted with :func:`torch.as_tensor`; string/object arrays (whose
    dtypes have no tensor equivalent) are returned as-is.  Mappings,
    namedtuples and sequences are traversed recursively (namedtuple types are
    preserved); any other object is returned unchanged.

    Args:
        data: an arbitrary, possibly nested, sample fetched from a dataset.

    Returns:
        ``data`` with every convertible NumPy field replaced by a tensor.
    """
    elem_type = type(data)
    if isinstance(data, torch.Tensor):
        return data
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        # array of string classes and object: no tensor dtype, pass through
        if elem_type.__name__ == 'ndarray' \
                and np_str_obj_array_pattern.search(data.dtype.str) is not None:
            return data
        return torch.as_tensor(data)
    elif isinstance(data, container_abcs.Mapping):
        # Iterate items() rather than key-then-index to avoid a second lookup
        # per key on Mapping implementations.
        return {key: default_convert(value) for key, value in data.items()}
    elif isinstance(data, tuple) and hasattr(data, '_fields'):  # namedtuple
        return elem_type(default_convert(d) for d in data)
    elif isinstance(data, container_abcs.Sequence) and not isinstance(data, string_classes):
        return [default_convert(d) for d in data]
    else:
        return data
|
|
|
|
|
|
# Message template raised by `default_collate` for unsupported element types;
# the placeholder receives the offending type (or NumPy dtype).
default_collate_err_msg_format = (
    "default_collate: batch must contain tensors, numpy arrays, "
    "numbers, dicts or lists; found {}")
|
|
|
|
|
|
def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size.

    ``batch`` is a non-empty sequence of samples; the type of the first
    element decides how the batch is collated:

    * tensors are stacked along a new leading dimension (directly into a
      shared-memory tensor when running inside a DataLoader worker process);
    * numeric NumPy arrays/scalars, floats and ints become tensors;
    * strings are returned as the batch list itself;
    * mappings, namedtuples and sequences are collated recursively per field
      (namedtuple types are preserved).

    Raises:
        TypeError: if the element type is unsupported, including NumPy
            arrays of strings or objects.
    """
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object: no tensor dtype, reject
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))

            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int_classes):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, container_abcs.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, container_abcs.Sequence):
        # Transpose batch-of-sequences into sequence-of-batches, then collate
        # each field independently.
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))
|