mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Remove useless parentheses in `raise` statements if the exception type is raised with no argument. Pull Request resolved: https://github.com/pytorch/pytorch/pull/124261 Approved by: https://github.com/albanD
55 lines
1.9 KiB
Python
55 lines
1.9 KiB
Python
r"""Contains definitions of the methods used by the _BaseDataLoaderIter to fetch data from an iterable-style or map-style dataset.
|
|
|
|
This logic is shared in both single- and multi-processing data loading.
|
|
"""
|
|
|
|
|
|
class _BaseDatasetFetcher:
|
|
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
|
|
self.dataset = dataset
|
|
self.auto_collation = auto_collation
|
|
self.collate_fn = collate_fn
|
|
self.drop_last = drop_last
|
|
|
|
def fetch(self, possibly_batched_index):
|
|
raise NotImplementedError
|
|
|
|
|
|
class _IterableDatasetFetcher(_BaseDatasetFetcher):
    """Fetch samples sequentially from an iterable-style dataset.

    A single iterator is created from the dataset at construction time and
    every :meth:`fetch` call pulls the next sample (or the next batch of
    samples) from it. Once that iterator signals exhaustion the fetcher stays
    exhausted: every later call raises ``StopIteration`` without touching the
    iterator again.
    """

    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super().__init__(dataset, auto_collation, collate_fn, drop_last)
        self.dataset_iter = iter(dataset)
        # Latched to True as soon as ``dataset_iter`` raises StopIteration.
        self.ended = False

    def fetch(self, possibly_batched_index):
        """Return the next collated sample or batch.

        Args:
            possibly_batched_index: With auto-collation, an iterable whose
                number of elements is the number of samples to pull (the
                element values themselves are ignored). Without
                auto-collation it is not used.

        Returns:
            ``self.collate_fn`` applied to the pulled sample, or to the list
            of pulled samples when auto-collation is enabled.

        Raises:
            StopIteration: If the dataset iterator was already exhausted, if
                no samples could be pulled, or if ``drop_last`` is set and
                only a partial batch was available.
        """
        if self.ended:
            raise StopIteration

        if not self.auto_collation:
            try:
                data = next(self.dataset_iter)
            except StopIteration:
                # Latch the end state on this path too, so later calls do not
                # poke an iterator that has already signalled exhaustion.
                self.ended = True
                raise
            return self.collate_fn(data)

        data = []
        for _ in possibly_batched_index:
            try:
                data.append(next(self.dataset_iter))
            except StopIteration:
                self.ended = True
                break
        # ``self.ended`` was False on entry, so it is True here only if the
        # iterator ran out before the batch was filled, i.e. the batch is
        # partial. Checking it instead of ``len(possibly_batched_index)`` also
        # supports index iterables that do not define ``__len__``.
        if not data or (self.drop_last and self.ended):
            raise StopIteration
        return self.collate_fn(data)
|
|
|
|
|
|
class _MapDatasetFetcher(_BaseDatasetFetcher):
    """Fetch samples from a map-style dataset by indexing into it."""

    def fetch(self, possibly_batched_index):
        """Look up the given index (or indices) and collate the result.

        Without auto-collation, ``possibly_batched_index`` is a single key
        passed straight to ``self.dataset[...]``. With auto-collation it is an
        iterable of keys: if the dataset exposes a truthy ``__getitems__``,
        the whole batch is passed to it in one call; otherwise each key is
        looked up individually.

        Returns:
            ``self.collate_fn`` applied to the fetched data.
        """
        if not self.auto_collation:
            return self.collate_fn(self.dataset[possibly_batched_index])

        # A dataset may define ``__getitems__`` (or set it to a falsy value
        # such as None to opt out) to provide a batched lookup.
        batch_getter = getattr(self.dataset, "__getitems__", None)
        if batch_getter:
            samples = batch_getter(possibly_batched_index)
        else:
            samples = [self.dataset[key] for key in possibly_batched_index]
        return self.collate_fn(samples)
|