pytorch/torch/utils/data/_utils/fetch.py
Junteng Jia 335033f718 asyncio increase throughput (pytorch change) (#84301)
Summary: This diff adds a check in the fetcher: if the dataset to be fetched has a method "__getitems__", it is used to fetch a whole batch of elements at once, as opposed to one by one. This is beneficial for I/O-bound usage.

Differential Revision: D39145980

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84301
Approved by: https://github.com/VitalyFedyunin
2022-09-08 17:00:45 +00:00

62 lines
2.1 KiB
Python

r"""Contains definitions of the methods used by the _BaseDataLoaderIter to fetch
data from an iterable-style or map-style dataset. This logic is shared in both
single- and multi-processing data loading.
"""
class _BaseDatasetFetcher(object):
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
self.dataset = dataset
self.auto_collation = auto_collation
self.collate_fn = collate_fn
self.drop_last = drop_last
def fetch(self, possibly_batched_index):
raise NotImplementedError()
class _IterableDatasetFetcher(_BaseDatasetFetcher):
    """Fetcher that draws samples from an iterator over the dataset.

    An iterator is created from ``dataset`` once, at construction time, and
    consumed across successive :meth:`fetch` calls.
    """

    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        """Store the settings and open an iterator over ``dataset``."""
        super(_IterableDatasetFetcher, self).__init__(
            dataset, auto_collation, collate_fn, drop_last
        )
        self.dataset_iter = iter(dataset)
        # Set once the iterator is found exhausted while filling a batch.
        self.ended = False

    def fetch(self, possibly_batched_index):
        """Return the collated next sample or next batch of samples.

        With auto-collation, one sample is pulled per entry of
        ``possibly_batched_index`` (the entries themselves are not used).
        Without it, exactly one sample is pulled.

        Raises:
            StopIteration: if the iterator was already marked as ended, if
                no sample could be collected, if ``drop_last`` is set and
                the batch is shorter than requested, or (without
                auto-collation) if the iterator itself is exhausted.
        """
        if self.ended:
            raise StopIteration

        if not self.auto_collation:
            # Single-sample mode: exhaustion propagates straight from next().
            return self.collate_fn(next(self.dataset_iter))

        batch = []
        for _ in possibly_batched_index:
            try:
                sample = next(self.dataset_iter)
            except StopIteration:
                # Remember exhaustion so that later calls stop immediately.
                self.ended = True
                break
            batch.append(sample)

        if not batch:
            raise StopIteration
        if self.drop_last and len(batch) < len(possibly_batched_index):
            # Incomplete trailing batch is discarded on request.
            raise StopIteration
        return self.collate_fn(batch)
class _MapDatasetFetcher(_BaseDatasetFetcher):
    """Fetcher that retrieves samples from the dataset by index."""

    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        """Store the dataset and the collation settings."""
        super(_MapDatasetFetcher, self).__init__(
            dataset, auto_collation, collate_fn, drop_last
        )

    def fetch(self, possibly_batched_index):
        """Return the collated data for an index or a batch of indices.

        Without auto-collation, ``possibly_batched_index`` is passed as-is
        to ``dataset[...]``. With auto-collation it is treated as a batch of
        indices: if the dataset exposes a callable ``__getitems__``, the
        whole batch is requested in one call; otherwise each index is looked
        up individually.

        Args:
            possibly_batched_index: a single index, or an iterable of
                indices when auto-collation is enabled.

        Returns:
            Whatever ``collate_fn`` returns for the fetched data.
        """
        if not self.auto_collation:
            return self.collate_fn(self.dataset[possibly_batched_index])

        # Resolve the attribute once. A dataset may define
        # ``__getitems__ = None`` to opt out of an inherited batched getter;
        # a bare ``hasattr`` check would then try to call ``None``.
        batched_getter = getattr(self.dataset, "__getitems__", None)
        if callable(batched_getter):
            data = batched_getter(possibly_batched_index)
        else:
            data = [self.dataset[idx] for idx in possibly_batched_index]
        return self.collate_fn(data)