mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fixes #112635

Fix docstrings for `torch.utils.data` files.

```
> pydocstyle torch/utils/data/graph.py --count
Before: 5
After: 1
> pydocstyle torch/utils/data/graph_settings.py --count
Before: 8
After: 3
> pydocstyle torch/utils/data/dataloader.py --count
Before: 12
After: 6
> pydocstyle torch/utils/data/dataset.py --count
Before: 28
After: 23
> pydocstyle torch/utils/data/sampler.py --count
Before: 24
After: 19
> pydocstyle torch/utils/data/_utils/signal_handling.py --count
Before: 1
After: 0
> pydocstyle torch/utils/data/_utils/__init__.py --count
Before: 2
After: 0
> pydocstyle torch/utils/data/_utils/collate.py --count
Before: 20
After: 6
> pydocstyle torch/utils/data/_utils/fetch.py --count
Before: 3
After: 0
> pydocstyle torch/utils/data/_utils/pin_memory.py --count
Before: 4
After: 1
> pydocstyle torch/utils/data/datapipes/_decorator.py --count
Before: 19
After: 16
> pydocstyle torch/utils/data/datapipes/_hook_iterator.py --count
Before: 13
After: 0
> pydocstyle torch/utils/data/datapipes/_typing.py --count
Before: 17
After: 4
> pydocstyle torch/utils/data/datapipes/gen_pyi.py --count
Before: 19
After: 4
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112817
Approved by: https://github.com/kit1980
55 lines
1.9 KiB
Python
55 lines
1.9 KiB
Python
r"""Contains definitions of the methods used by the _BaseDataLoaderIter to fetch data from an iterable-style or map-style dataset.
|
|
|
|
This logic is shared in both single- and multi-processing data loading.
|
|
"""
|
|
|
|
|
|
class _BaseDatasetFetcher:
|
|
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
|
|
self.dataset = dataset
|
|
self.auto_collation = auto_collation
|
|
self.collate_fn = collate_fn
|
|
self.drop_last = drop_last
|
|
|
|
def fetch(self, possibly_batched_index):
|
|
raise NotImplementedError()
|
|
|
|
|
|
class _IterableDatasetFetcher(_BaseDatasetFetcher):
    """Fetch data from an iterable-style dataset by advancing one iterator.

    Samples are drawn in order from ``iter(dataset)``. The index handed to
    :meth:`fetch` only determines how many samples to draw (when
    ``auto_collation`` is enabled); its values are never used for lookup.
    """

    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super().__init__(dataset, auto_collation, collate_fn, drop_last)
        # A single iterator is created up front and consumed across calls.
        self.dataset_iter = iter(dataset)
        # Becomes True once the iterator has raised StopIteration, so later
        # calls stop immediately instead of touching the exhausted iterator.
        self.ended = False

    def fetch(self, possibly_batched_index):
        """Return the next collated sample or batch from the dataset iterator.

        Args:
            possibly_batched_index: With ``auto_collation``, a sized iterable
                whose length is the number of samples to draw (its elements
                are ignored). Unused otherwise.

        Returns:
            ``collate_fn`` applied to a list of drawn samples
            (``auto_collation``) or to a single drawn sample.

        Raises:
            StopIteration: If the iterator was already exhausted, if no
                samples could be drawn, or if ``drop_last`` is set and fewer
                samples than requested remained.
        """
        if self.ended:
            raise StopIteration

        if self.auto_collation:
            data = []
            for _ in possibly_batched_index:
                try:
                    data.append(next(self.dataset_iter))
                except StopIteration:
                    self.ended = True
                    break
            # An empty batch, or a short final batch under ``drop_last``,
            # ends iteration instead of being returned.
            if not data or (
                self.drop_last and len(data) < len(possibly_batched_index)
            ):
                raise StopIteration
        else:
            try:
                data = next(self.dataset_iter)
            except StopIteration:
                # Record exhaustion in this mode too, so ``ended`` has the
                # same meaning regardless of ``auto_collation``.
                self.ended = True
                raise
        return self.collate_fn(data)
|
|
|
|
|
|
class _MapDatasetFetcher(_BaseDatasetFetcher):
    """Fetch data from a map-style dataset by indexing into it."""

    def fetch(self, possibly_batched_index):
        """Look up ``possibly_batched_index`` in the dataset and collate it.

        Without ``auto_collation`` the argument is a single key passed to
        ``dataset[...]``. With it, the argument is a batch of keys: the
        dataset's ``__getitems__`` handles the whole batch when that attribute
        exists and is truthy; otherwise each key is looked up individually.

        Returns:
            ``collate_fn`` applied to the looked-up sample(s).
        """
        dataset = self.dataset
        if not self.auto_collation:
            data = dataset[possibly_batched_index]
        elif hasattr(dataset, "__getitems__") and dataset.__getitems__:
            # Batched lookup hook provided by the dataset.
            data = dataset.__getitems__(possibly_batched_index)
        else:
            data = [dataset[key] for key in possibly_batched_index]
        return self.collate_fn(data)
|