mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add option in data loader for out of order data (#141833)
Fixes #105203 Facing a similar problem to the linked issue, where variable sized input data can mean that a handful of slow to process samples holds up smaller and faster to process samples from being used. This also leads to lower GPU utilization as well. In certain cases, e.g. evaluation epochs, inference pipelines or other cases where reproducibility isn't important, this can bring significant speed ups. This PR adds an `allow_out_of_order` bool input to the `DataLoader` class, defaulting to `false`, which when set to `true` will returning data from workers in whatever order they are ready/processed in, rather in the strict index order. Instead of storing data that was returned out of order, it is passed directly to the main thread and the entry in `_task_info` is deleted. The main changes are they to check that an entry in `_task_info` does exist, and only increasing `self._rcvd_idx` when the lowest index remaining gets returned. Two tests are added to test this for iterable type datasets and index type datasets. Pull Request resolved: https://github.com/pytorch/pytorch/pull/141833 Approved by: https://github.com/andrewkho
This commit is contained in:
parent
61a7c83c64
commit
18ef3a09cd
|
|
@ -3501,6 +3501,99 @@ class TestConvAfterFork(TestCase):
|
|||
self.assertEqual(x.shape, (1, 1, 1, 23999))
|
||||
|
||||
|
||||
class TestSlowIndexDataset(Dataset):
|
||||
def __init__(self, end: int, slow_index: int):
|
||||
self.end = end
|
||||
self.slow_index = slow_index
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if idx == self.slow_index:
|
||||
time.sleep(0.5)
|
||||
return idx
|
||||
|
||||
def __len__(self):
|
||||
return self.end
|
||||
|
||||
|
||||
class TestSlowIterableDataset(IterableDataset):
|
||||
def __init__(self, start: int, end: int):
|
||||
self.start = start
|
||||
self.end = end
|
||||
self.mid = math.ceil((self.end - self.start) / 2)
|
||||
|
||||
def give_data(self, iter_start, iter_end):
|
||||
for i in range(iter_start, iter_end):
|
||||
if i >= self.mid:
|
||||
time.sleep(0.5)
|
||||
yield i
|
||||
|
||||
def __iter__(self):
|
||||
worker_info = torch.utils.data.get_worker_info()
|
||||
per_worker = int(
|
||||
math.ceil((self.end - self.start) / float(worker_info.num_workers))
|
||||
)
|
||||
worker_id = worker_info.id
|
||||
iter_start = self.start + worker_id * per_worker
|
||||
iter_end = min(iter_start + per_worker, self.end)
|
||||
return self.give_data(iter_start, iter_end)
|
||||
|
||||
|
||||
class TestOutOfOrderDataLoader(TestCase):
|
||||
def test_in_order_index_ds(self):
|
||||
dataset = TestSlowIndexDataset(end=10, slow_index=2)
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=2,
|
||||
in_order=True,
|
||||
)
|
||||
|
||||
expected_order = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
output = [sample.item() for sample in dataloader]
|
||||
self.assertEqual(expected_order, output)
|
||||
|
||||
def test_out_of_order_index_ds(self):
|
||||
dataset = TestSlowIndexDataset(end=10, slow_index=2)
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=2,
|
||||
in_order=False,
|
||||
)
|
||||
|
||||
# normally, this should be [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
expected_order = [0, 1, 3, 5, 7, 2, 4, 6, 8, 9]
|
||||
output = [sample.item() for sample in dataloader]
|
||||
self.assertEqual(expected_order, output)
|
||||
|
||||
def test_in_order_iterable_ds(self):
|
||||
dataset = TestSlowIterableDataset(start=0, end=10)
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=2,
|
||||
in_order=True,
|
||||
)
|
||||
|
||||
expected_order = [0, 5, 1, 6, 2, 7, 3, 8, 4, 9]
|
||||
output = [sample.item() for sample in dataloader]
|
||||
self.assertEqual(expected_order, output)
|
||||
|
||||
def test_out_of_order_iterable_ds(self):
|
||||
dataset = TestSlowIterableDataset(start=0, end=10)
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=2,
|
||||
in_order=False,
|
||||
)
|
||||
|
||||
# normally, this should be [0, 5, 1, 6, 2, 7, 3, 8, 4, 9]
|
||||
expected_order = [0, 1, 2, 3, 5, 4, 6, 7, 8, 9]
|
||||
output = [sample.item() for sample in dataloader]
|
||||
self.assertEqual(expected_order, output)
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestDataLoaderDeviceType, globals())
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -185,6 +185,8 @@ class DataLoader(Generic[_T_co]):
|
|||
maintain the workers `Dataset` instances alive. (default: ``False``)
|
||||
pin_memory_device (str, optional): the device to :attr:`pin_memory` to if ``pin_memory`` is
|
||||
``True``.
|
||||
in_order (bool, optional): If ``False``, the data loader will not enforce that batches
|
||||
are returned in a first-in, first-out order. Only applies when ``num_workers > 0``. (default: ``True``)
|
||||
|
||||
|
||||
.. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn`
|
||||
|
|
@ -213,6 +215,9 @@ class DataLoader(Generic[_T_co]):
|
|||
.. warning:: See :ref:`reproducibility`, and :ref:`dataloader-workers-random-seed`, and
|
||||
:ref:`data-loading-randomness` notes for random seed related questions.
|
||||
|
||||
.. warning:: Setting `in_order` to `False` can harm reproducibility and may lead to a skewed data
|
||||
distribution being fed to the trainer in cases with imbalanced data.
|
||||
|
||||
.. _multiprocessing context:
|
||||
https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
|
||||
"""
|
||||
|
|
@ -248,6 +253,7 @@ class DataLoader(Generic[_T_co]):
|
|||
prefetch_factor: Optional[int] = None,
|
||||
persistent_workers: bool = False,
|
||||
pin_memory_device: str = "",
|
||||
in_order: bool = True,
|
||||
):
|
||||
torch._C._log_api_usage_once("python.data_loader")
|
||||
|
||||
|
|
@ -281,6 +287,7 @@ class DataLoader(Generic[_T_co]):
|
|||
self.timeout = timeout
|
||||
self.worker_init_fn = worker_init_fn
|
||||
self.multiprocessing_context = multiprocessing_context
|
||||
self.in_order = in_order
|
||||
|
||||
# Adds forward compatibilities so classic DataLoader can work with DataPipes:
|
||||
# _DataPipeSerializationWrapper container makes it easier to serialize without redefining pickler
|
||||
|
|
@ -1074,6 +1081,7 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
|
|||
super().__init__(loader)
|
||||
|
||||
self._prefetch_factor = loader.prefetch_factor
|
||||
self._in_order = loader.in_order
|
||||
|
||||
assert self._num_workers > 0
|
||||
assert self._prefetch_factor > 0
|
||||
|
|
@ -1423,7 +1431,8 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
|
|||
# call and `_IterableDatasetStopIteration` check below can mark
|
||||
# extra worker(s) as dead.
|
||||
while self._rcvd_idx < self._send_idx:
|
||||
info = self._task_info[self._rcvd_idx]
|
||||
info = self._task_info.get(self._rcvd_idx, None)
|
||||
if info:
|
||||
worker_id = info[0]
|
||||
if (
|
||||
len(info) == 2 or self._workers_status[worker_id]
|
||||
|
|
@ -1442,6 +1451,7 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
|
|||
# Check if the next sample has already been generated
|
||||
if len(self._task_info[self._rcvd_idx]) == 2:
|
||||
data = self._task_info.pop(self._rcvd_idx)[1]
|
||||
self._rcvd_idx += 1
|
||||
return self._process_data(data)
|
||||
|
||||
assert not self._shutdown and self._tasks_outstanding > 0
|
||||
|
|
@ -1458,10 +1468,15 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
|
|||
continue
|
||||
|
||||
if idx != self._rcvd_idx:
|
||||
if not self._in_order:
|
||||
# don't store it for later, process now
|
||||
del self._task_info[idx]
|
||||
return self._process_data(data)
|
||||
# store out-of-order samples
|
||||
self._task_info[idx] += (data,)
|
||||
else:
|
||||
del self._task_info[idx]
|
||||
self._rcvd_idx += 1
|
||||
return self._process_data(data)
|
||||
|
||||
def _try_put_index(self):
|
||||
|
|
@ -1485,7 +1500,6 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
|
|||
self._send_idx += 1
|
||||
|
||||
def _process_data(self, data):
|
||||
self._rcvd_idx += 1
|
||||
self._try_put_index()
|
||||
if isinstance(data, ExceptionWrapper):
|
||||
data.reraise()
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user