pytorch/torch/utils/data/_utils/collate.py
Tongzhou Wang 058beae411 Add IterableDataset (#19228)
Summary:
This is a modified version of https://github.com/pytorch/pytorch/pull/14705, since the commit structure of that PR is quite messy.

1. Add `IterableDataset`.
2. As a result, there are two data loader modes: `Iterable` and `Map`.

    1. `Iterable` if the `dataset` is an instance of `IterableDataset`
    2. `Map` otherwise

3. Add better support for non-batch loading (i.e., `batch_size=None` and `batch_sampler=None`). This is useful for things like bulk loading.
4. Refactor `DataLoaderIter` into two classes, `_SingleProcessDataLoaderIter` and `_MultiProcessingDataLoaderIter`. Rename some methods to be more generic, e.g., `get_batch` -> `get_data`.
5. Add `torch.utils.data.get_worker_info`, which returns worker information in a worker process (e.g., worker id, a copy of the dataset object, etc.) and can be used in `IterableDataset.__iter__` and `worker_init_fn` to do per-worker configuration.
6. Add `ChainDataset`, which is the analog of `ConcatDataset` for `IterableDataset`.
7. Import `torch.utils.data` in `torch/__init__.py`.
8. Add data loader examples and documentation.
9. Use `get_worker_info` to detect whether we are in a worker process in `default_collate`.
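The per-worker configuration that `get_worker_info` enables often amounts to giving each worker a disjoint slice of the stream. Below is a minimal, torch-free sketch of that splitting arithmetic, assuming the stream is a contiguous integer range; `worker_range` is a hypothetical helper, not part of `torch.utils.data`. Inside an `IterableDataset.__iter__`, the `worker_id` and `num_workers` arguments would come from the `id` and `num_workers` attributes of the object returned by `get_worker_info()`, which returns `None` in the main process.

```python
def worker_range(start, end, worker_id, num_workers):
    """Hypothetical helper: return the half-open slice (lo, hi) of
    range(start, end) that worker `worker_id` out of `num_workers` yields.

    Assumption: in a real IterableDataset, worker_id / num_workers would be
    read from torch.utils.data.get_worker_info().id / .num_workers.
    """
    if end < start or num_workers < 1 or not 0 <= worker_id < num_workers:
        raise ValueError("invalid range or worker_id / num_workers")
    # Ceiling division so that every item is assigned to some worker.
    per_worker = (end - start + num_workers - 1) // num_workers
    # Clamp both bounds so trailing workers get an empty slice, never a reversed one.
    lo = min(start + worker_id * per_worker, end)
    hi = min(lo + per_worker, end)
    return lo, hi


# Example: 10 items split across 3 workers.
shards = [worker_range(0, 10, w, 3) for w in range(3)]
print(shards)  # [(0, 4), (4, 8), (8, 10)]
```

Ceiling division keeps every item covered: earlier workers get full slices and the last worker takes the remainder (possibly an empty slice when there are more workers than items). Without some split like this, each worker process iterates over its own copy of the dataset and yields duplicate samples.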

Closes https://github.com/pytorch/pytorch/issues/17909, https://github.com/pytorch/pytorch/issues/18096, https://github.com/pytorch/pytorch/issues/19946, and some of https://github.com/pytorch/pytorch/issues/13023
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19228

Reviewed By: bddppq

Differential Revision: D15058152

fbshipit-source-id: 9e081a901a071d7e4502b88054a34b450ab5ddde
2019-06-20 20:12:44 -07:00


r"""Contains definitions of the methods used by the data loader iterators and
their worker processes to collate samples fetched from the dataset into Tensor(s).

These **need** to be in global scope since Py2 doesn't support serializing
static methods.
"""

import torch
import re
from torch._six import container_abcs, string_classes, int_classes

# dtype.str kinds for byte-string ('S', 'a'), unicode ('U') and object ('O')
# arrays, none of which can be converted into a tensor
np_str_obj_array_pattern = re.compile(r'[SaUO]')


def default_convert(data):
    r"""Converts each NumPy array data field into a tensor"""
    elem_type = type(data)
    if isinstance(data, torch.Tensor):
        return data
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        # array of string classes and object
        if elem_type.__name__ == 'ndarray' \
                and np_str_obj_array_pattern.search(data.dtype.str) is not None:
            return data
        return torch.as_tensor(data)
    elif isinstance(data, container_abcs.Mapping):
        return {key: default_convert(data[key]) for key in data}
    elif isinstance(data, tuple) and hasattr(data, '_fields'):  # namedtuple
        # Unpack so each converted field is passed as a positional argument
        return elem_type(*(default_convert(d) for d in data))
    elif isinstance(data, container_abcs.Sequence) and not isinstance(data, string_classes):
        return [default_convert(d) for d in data]
    else:
        return data


default_collate_err_msg_format = (
    "default_collate: batch must contain tensors, numpy arrays, numbers, "
    "dicts or lists; found {}")


def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""
    # The type of the first sample decides how the whole batch is collated
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        elem = batch[0]
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))

            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int_classes):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, container_abcs.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, container_abcs.Sequence):
        # Transpose the batch: one group of samples per field position
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))
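`default_collate` dispatches on the type of the first sample and recurses through mappings, namedtuples, and sequences, using `zip(*batch)` to transpose a batch of samples into one group per field. The sketch below mirrors only that recursion in plain Python so it runs without torch or NumPy; `collate_sketch` is a hypothetical name, and it returns plain lists at the leaves where `default_collate` would build tensors with `torch.stack` or `torch.tensor`.

```python
from collections import namedtuple
from collections.abc import Mapping, Sequence


def collate_sketch(batch):
    """Hypothetical, torch-free mirror of default_collate's recursion.

    Leaves (numbers and strings) come back as a plain list of values; the
    real default_collate turns numbers into a tensor whose outer dimension
    is the batch size and keeps strings as a list.
    """
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, (int, float, str)):
        # Leaf: default_collate would call torch.tensor(batch) for numbers.
        return list(batch)
    elif isinstance(elem, Mapping):
        # Collate each key separately, using the first sample's keys.
        return {key: collate_sketch([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        # Rebuild the namedtuple with one collated value per field.
        return elem_type(*(collate_sketch(samples) for samples in zip(*batch)))
    elif isinstance(elem, Sequence):
        # zip(*batch) transposes: one group of samples per field position.
        return [collate_sketch(list(samples)) for samples in zip(*batch)]
    raise TypeError("collate_sketch: unsupported element type {}".format(elem_type))


Point = namedtuple('Point', ['x', 'y'])
batch = [
    {'id': 1, 'label': 'cat', 'pos': Point(0.0, 1.0)},
    {'id': 2, 'label': 'dog', 'pos': Point(2.0, 3.0)},
]
print(collate_sketch(batch))
# {'id': [1, 2], 'label': ['cat', 'dog'], 'pos': Point(x=[0.0, 2.0], y=[1.0, 3.0])}
```

The `*` in the namedtuple branch is what lets each collated field become a positional argument; calling `Point(generator)` instead would bind the whole generator to `x` and fail with a `TypeError` for the missing `y`.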