pytorch/torch/utils/data/_utils/fetch.py
Aryan Gupta 92e7f79609 Doc: Add and Fix docstrings for torch.utils.data files (#112817)
Fixes #112635

Fix docstrings for `torch.utils.data` files.

```
Before:
> pydocstyle torch/utils/data/graph.py --count
Before: 5
After: 1

> pydocstyle torch/utils/data/graph_settings.py --count
Before: 8
After: 3

> pydocstyle torch/utils/data/dataloader.py --count
Before: 12
After: 6

> pydocstyle torch/utils/data/dataset.py --count
Before: 28
After: 23

> pydocstyle torch/utils/data/sampler.py --count
Before: 24
After: 19

> pydocstyle torch/utils/data/_utils/signal_handling.py --count
Before: 1
After: 0

> pydocstyle torch/utils/data/_utils/__init__.py --count
Before: 2
After: 0

> pydocstyle torch/utils/data/_utils/collate.py --count
Before: 20
After: 6

> pydocstyle torch/utils/data/_utils/fetch.py --count
Before: 3
After: 0

> pydocstyle torch/utils/data/_utils/pin_memory.py --count
Before: 4
After: 1

> pydocstyle torch/utils/data/datapipes/_decorator.py --count
Before: 19
After: 16

> pydocstyle torch/utils/data/datapipes/_hook_iterator.py --count
Before: 13
After: 0

> pydocstyle torch/utils/data/datapipes/_typing.py --count
Before: 17
After: 4

> pydocstyle torch/utils/data/datapipes/gen_pyi.py --count
Before: 19
After: 4
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112817
Approved by: https://github.com/kit1980
2023-11-07 17:59:56 +00:00

55 lines
1.9 KiB
Python

r"""Contains definitions of the methods used by the _BaseDataLoaderIter to fetch data from an iterable-style or map-style dataset.
This logic is shared in both single- and multi-processing data loading.
"""
class _BaseDatasetFetcher:
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
self.dataset = dataset
self.auto_collation = auto_collation
self.collate_fn = collate_fn
self.drop_last = drop_last
def fetch(self, possibly_batched_index):
raise NotImplementedError()
class _IterableDatasetFetcher(_BaseDatasetFetcher):
    """Fetcher that reads samples sequentially from an iterator over the dataset.

    The values inside the index passed to :meth:`fetch` are never used to
    address the dataset; in auto-collation mode only the number of entries
    matters, because it determines how many samples are drawn for one batch.
    """

    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        """Store the configuration and open a single iterator over ``dataset``."""
        super().__init__(dataset, auto_collation, collate_fn, drop_last)
        self.dataset_iter = iter(dataset)
        # Becomes True once the iterator runs dry while a batch is being built,
        # so that later calls stop immediately without touching the iterator.
        self.ended = False

    def fetch(self, possibly_batched_index):
        """Return the next collated sample or batch from the dataset iterator.

        Args:
            possibly_batched_index: With auto-collation enabled, an iterable
                whose length is the requested batch size. Ignored otherwise.

        Returns:
            The result of ``collate_fn`` applied to the fetched data (a list
            of samples with auto-collation, a single sample without).

        Raises:
            StopIteration: If the iterator is exhausted, if no sample could be
                drawn for the batch, or if ``drop_last`` is set and the batch
                came out shorter than requested.
        """
        if self.ended:
            raise StopIteration

        if not self.auto_collation:
            # Single-sample mode: exhaustion surfaces directly from next().
            return self.collate_fn(next(self.dataset_iter))

        batch = []
        for _ in possibly_batched_index:
            try:
                sample = next(self.dataset_iter)
            except StopIteration:
                self.ended = True
                break
            batch.append(sample)

        # Nothing was drawn at all: there is no batch to hand back.
        if not batch:
            raise StopIteration
        # A short final batch is discarded when drop_last is requested.
        if self.drop_last and len(batch) < len(possibly_batched_index):
            raise StopIteration

        return self.collate_fn(batch)
class _MapDatasetFetcher(_BaseDatasetFetcher):
    """Fetcher that reads samples from the dataset by index."""

    def fetch(self, possibly_batched_index):
        """Return the collated data addressed by ``possibly_batched_index``.

        Args:
            possibly_batched_index: With auto-collation enabled, an iterable
                of indices making up one batch; otherwise a single index that
                is passed to the dataset as is.

        Returns:
            The result of ``collate_fn`` applied to the fetched data.
        """
        if not self.auto_collation:
            return self.collate_fn(self.dataset[possibly_batched_index])

        # Use the dataset's bulk accessor when it exists and is truthy;
        # otherwise fall back to indexing the samples one at a time.
        bulk_getter = getattr(self.dataset, "__getitems__", None)
        if bulk_getter:
            samples = bulk_getter(possibly_batched_index)
        else:
            samples = [self.dataset[index] for index in possibly_batched_index]
        return self.collate_fn(samples)