mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: This is a modified version of https://github.com/pytorch/pytorch/pull/14705, since the commit structure of that PR is quite messy. 1. Add `IterableDataset`. 2. There are now two data loader modes, `Iterable` and `Map`: (a) `Iterable` if the `dataset` is an instance of `IterableDataset`; (b) `Map` otherwise. 3. Add better support for non-batch loading (i.e., `batch_size=None` and `batch_sampler=None`). This is useful for things like bulk loading. 4. Refactor `DataLoaderIter` into two classes, `_SingleProcessDataLoaderIter` and `_MultiProcessingDataLoaderIter`. Rename some methods to be more generic, e.g., `get_batch` -> `get_data`. 5. Add `torch.utils.data.get_worker_info`, which returns worker information in a worker process (e.g., worker id, dataset object copy, etc.) and can be used in `IterableDataset.__iter__` and `worker_init_fn` to do per-worker configuration. 6. Add `ChainDataset`, which is the analog of `ConcatDataset` for `IterableDataset`. 7. Import `torch.utils.data` in `torch/__init__.py`. 8. Add data loader examples and documentation. 9. Use `get_worker_info` to detect whether we are in a worker process in `default_collate`. Closes https://github.com/pytorch/pytorch/issues/17909, https://github.com/pytorch/pytorch/issues/18096, https://github.com/pytorch/pytorch/issues/19946, and some of https://github.com/pytorch/pytorch/issues/13023 Pull Request resolved: https://github.com/pytorch/pytorch/pull/19228 Reviewed By: bddppq Differential Revision: D15058152 fbshipit-source-id: 9e081a901a071d7e4502b88054a34b450ab5ddde
48 lines
1.8 KiB
Python
48 lines
1.8 KiB
Python
r"""Contains definitions of the methods used by the _DataLoaderIter to fetch
|
|
data from an iterable-style or map-style dataset. This logic is shared in both
|
|
single- and multi-processing data loading.
|
|
"""
|
|
|
|
|
|
class _BaseDatasetFetcher(object):
|
|
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
|
|
self.dataset = dataset
|
|
self.auto_collation = auto_collation
|
|
self.collate_fn = collate_fn
|
|
self.drop_last = drop_last
|
|
|
|
def fetch(self, possibly_batched_index):
|
|
raise NotImplementedError()
|
|
|
|
|
|
class _IterableDatasetFetcher(_BaseDatasetFetcher):
    """Fetcher that draws samples from ``iter(dataset)``.

    The iterator is created once, at construction time, and consumed across
    successive :meth:`fetch` calls. Once that iterator has raised
    ``StopIteration`` it is never advanced again: every later ``fetch`` call
    raises ``StopIteration`` immediately.
    """

    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_IterableDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
        self.dataset_iter = iter(dataset)
        # Becomes True the first time ``dataset_iter`` raises StopIteration.
        self.ended = False

    def fetch(self, possibly_batched_index):
        """Return the next collated sample or batch from the dataset iterator.

        Args:
            possibly_batched_index: when ``auto_collation`` is set, a sized
                iterable whose length is the requested batch size (its
                elements are ignored); otherwise unused.

        Returns:
            ``collate_fn`` applied to a list of samples (auto-collation) or
            to a single sample (no auto-collation).

        Raises:
            StopIteration: if the iterator is exhausted and no sample was
                collected, if ``drop_last`` is set and the batch came up
                short, or if the iterator already ended in an earlier call.
        """
        # Do not touch an iterator that has already signalled exhaustion:
        # calling next() on it again is at best redundant and, for iterators
        # that resume after StopIteration, would yield spurious extra data.
        if self.ended:
            raise StopIteration

        if self.auto_collation:
            data = []
            for _ in possibly_batched_index:
                try:
                    data.append(next(self.dataset_iter))
                except StopIteration:
                    self.ended = True
                    break
            # An empty batch, or a short batch under drop_last, ends iteration.
            if not data or (self.drop_last and len(data) < len(possibly_batched_index)):
                raise StopIteration
        else:
            try:
                data = next(self.dataset_iter)
            except StopIteration:
                self.ended = True
                raise
        return self.collate_fn(data)
|
|
|
|
|
|
class _MapDatasetFetcher(_BaseDatasetFetcher):
    """Fetcher that reads samples by indexing into the dataset."""

    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)

    def fetch(self, possibly_batched_index):
        """Look up the requested item(s) and pass them through ``collate_fn``.

        With ``auto_collation`` set, ``possibly_batched_index`` is iterated
        and each element is used as a separate key into the dataset;
        otherwise it is used as a single key.
        """
        if not self.auto_collation:
            # Non-batched mode: the argument itself is the dataset key.
            return self.collate_fn(self.dataset[possibly_batched_index])

        # Batched mode: one lookup per index, in the order given.
        samples = []
        for idx in possibly_batched_index:
            samples.append(self.dataset[idx])
        return self.collate_fn(samples)
|