pytorch/torch/utils/data/_utils/fetch.py
Tongzhou Wang 058beae411 Add IterableDataset (#19228)
Summary:
This is a modified version of https://github.com/pytorch/pytorch/pull/14705 since commit structure for that PR is quite messy.

1. Add `IterableDataset`.
3. So we have 2 data loader mods: `Iterable` and `Map`.

    1. `Iterable` if the `dataset` is an instance of `IterableDataset`
    2. `Map` o.w.

3. Add better support for non-batch loading (i.e., `batch_size=None` and `batch_sampler=None`). This is useful in doing things like bulk loading.
3. Refactor `DataLoaderIter` into two classes, `_SingleProcessDataLoaderIter` and `_MultiProcessingDataLoaderIter`. Rename some methods to be more generic, e.g., `get_batch` -> `get_data`.
4. Add `torch.utils.data.get_worker_info` which returns worker information in a worker proc (e.g., worker id, dataset obj copy, etc.) and can be used in `IterableDataset.__iter__` and `worker_init_fn` to do per-worker configuration.
5. Add `ChainDataset`, which is the analog of `ConcatDataset` for `IterableDataset`.
7. Import torch.utils.data in `torch/__init__.py`
9. data loader examples and documentations
10. Use `get_worker_info` to detect whether we are in a worker process in `default_collate`

Closes https://github.com/pytorch/pytorch/issues/17909, https://github.com/pytorch/pytorch/issues/18096, https://github.com/pytorch/pytorch/issues/19946, and some of https://github.com/pytorch/pytorch/issues/13023
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19228

Reviewed By: bddppq

Differential Revision: D15058152

fbshipit-source-id: 9e081a901a071d7e4502b88054a34b450ab5ddde
2019-06-20 20:12:44 -07:00

48 lines
1.8 KiB
Python

r""""Contains definitions of the methods used by the _DataLoaderIter to fetch
data from an iterable-style or map-style dataset. This logic is shared in both
single- and multi-processing data loading.
"""
class _BaseDatasetFetcher(object):
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
self.dataset = dataset
self.auto_collation = auto_collation
self.collate_fn = collate_fn
self.drop_last = drop_last
def fetch(self, possibly_batched_index):
raise NotImplementedError()
class _IterableDatasetFetcher(_BaseDatasetFetcher):
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
super(_IterableDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
self.dataset_iter = iter(dataset)
def fetch(self, possibly_batched_index):
if self.auto_collation:
data = []
for _ in possibly_batched_index:
try:
data.append(next(self.dataset_iter))
except StopIteration:
break
if len(data) == 0 or (self.drop_last and len(data) < len(possibly_batched_index)):
raise StopIteration
else:
data = next(self.dataset_iter)
return self.collate_fn(data)
class _MapDatasetFetcher(_BaseDatasetFetcher):
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
def fetch(self, possibly_batched_index):
if self.auto_collation:
data = [self.dataset[idx] for idx in possibly_batched_index]
else:
data = self.dataset[possibly_batched_index]
return self.collate_fn(data)