Add option in data loader for out of order data (#141833)

Fixes #105203

This addresses a problem similar to the linked issue: with variable-sized input data, a handful of slow-to-process samples can hold up smaller, faster-to-process samples from being consumed, which also lowers GPU utilization. In cases where reproducibility isn't important, e.g. evaluation epochs or inference pipelines, allowing out-of-order data can bring significant speed-ups.

This PR adds an `in_order` bool argument to the `DataLoader` class, defaulting to `True`. When set to `False`, the loader returns data from workers in whatever order it is ready/processed, rather than in strict index order.
Instead of storing data that was returned out of order, it is passed directly to the main thread and its entry in `_task_info` is deleted. The main changes are then to check that an entry in `_task_info` actually exists, and to increase `self._rcvd_idx` only when the lowest remaining index is returned.
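
As a usage sketch (assuming a build that includes this change; the toy dataset below is illustrative, not from the PR):

```python
import time

import torch
from torch.utils.data import DataLoader, Dataset


class VariableCostDataset(Dataset):
    """Toy dataset where one sample is much slower to produce."""

    def __len__(self):
        return 10

    def __getitem__(self, idx):
        if idx == 2:
            time.sleep(0.5)  # simulate an expensive sample
        return idx


if __name__ == "__main__":
    # in_order=False lets fast samples overtake the stalled index 2
    loader = DataLoader(VariableCostDataset(), num_workers=2, in_order=False)
    print([x.item() for x in loader])  # e.g. [0, 1, 3, 5, 7, 2, 4, 6, 8, 9]
```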

Two tests are added to exercise this for index-style and iterable-style datasets.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141833
Approved by: https://github.com/andrewkho
Michael Diggin 2024-12-06 19:55:58 +00:00 committed by PyTorch MergeBot
parent 61a7c83c64
commit 18ef3a09cd
2 changed files with 115 additions and 8 deletions

test/test_dataloader.py

@@ -3501,6 +3501,99 @@ class TestConvAfterFork(TestCase):
         self.assertEqual(x.shape, (1, 1, 1, 23999))


+class TestSlowIndexDataset(Dataset):
+    def __init__(self, end: int, slow_index: int):
+        self.end = end
+        self.slow_index = slow_index
+
+    def __getitem__(self, idx):
+        if idx == self.slow_index:
+            time.sleep(0.5)
+        return idx
+
+    def __len__(self):
+        return self.end
+
+
+class TestSlowIterableDataset(IterableDataset):
+    def __init__(self, start: int, end: int):
+        self.start = start
+        self.end = end
+        self.mid = math.ceil((self.end - self.start) / 2)
+
+    def give_data(self, iter_start, iter_end):
+        for i in range(iter_start, iter_end):
+            if i >= self.mid:
+                time.sleep(0.5)
+            yield i
+
+    def __iter__(self):
+        worker_info = torch.utils.data.get_worker_info()
+        per_worker = int(
+            math.ceil((self.end - self.start) / float(worker_info.num_workers))
+        )
+        worker_id = worker_info.id
+        iter_start = self.start + worker_id * per_worker
+        iter_end = min(iter_start + per_worker, self.end)
+        return self.give_data(iter_start, iter_end)
+
+
+class TestOutOfOrderDataLoader(TestCase):
+    def test_in_order_index_ds(self):
+        dataset = TestSlowIndexDataset(end=10, slow_index=2)
+        dataloader = torch.utils.data.DataLoader(
+            dataset,
+            num_workers=2,
+            in_order=True,
+        )
+        expected_order = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
+        output = [sample.item() for sample in dataloader]
+        self.assertEqual(expected_order, output)
+
+    def test_out_of_order_index_ds(self):
+        dataset = TestSlowIndexDataset(end=10, slow_index=2)
+        dataloader = torch.utils.data.DataLoader(
+            dataset,
+            num_workers=2,
+            in_order=False,
+        )
+        # normally, this should be [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
+        expected_order = [0, 1, 3, 5, 7, 2, 4, 6, 8, 9]
+        output = [sample.item() for sample in dataloader]
+        self.assertEqual(expected_order, output)
+
+    def test_in_order_iterable_ds(self):
+        dataset = TestSlowIterableDataset(start=0, end=10)
+        dataloader = torch.utils.data.DataLoader(
+            dataset,
+            num_workers=2,
+            in_order=True,
+        )
+        expected_order = [0, 5, 1, 6, 2, 7, 3, 8, 4, 9]
+        output = [sample.item() for sample in dataloader]
+        self.assertEqual(expected_order, output)
+
+    def test_out_of_order_iterable_ds(self):
+        dataset = TestSlowIterableDataset(start=0, end=10)
+        dataloader = torch.utils.data.DataLoader(
+            dataset,
+            num_workers=2,
+            in_order=False,
+        )
+        # normally, this should be [0, 5, 1, 6, 2, 7, 3, 8, 4, 9]
+        expected_order = [0, 1, 2, 3, 5, 4, 6, 7, 8, 9]
+        output = [sample.item() for sample in dataloader]
+        self.assertEqual(expected_order, output)
+
+
 instantiate_device_type_tests(TestDataLoaderDeviceType, globals())
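
The expected orders in the out-of-order tests follow from how indices are dispatched to workers (assuming the loader's default round-robin assignment): with two workers, worker 0 receives the even indices and worker 1 the odd ones. A back-of-the-envelope sketch for the index test:

```python
# Why [0, 1, 3, 5, 7, 2, 4, 6, 8, 9] is the expected out-of-order result,
# assuming round-robin dispatch of indices to workers.
num_workers = 2
assignment = {
    w: [i for i in range(10) if i % num_workers == w] for w in range(num_workers)
}
print(assignment)  # {0: [0, 2, 4, 6, 8], 1: [1, 3, 5, 7, 9]}
# Worker 0 sleeps 0.5 s on index 2, so worker 1 streams 1, 3, 5, 7 in the
# meantime; once worker 0 wakes, 2 (then 4, 6, 8) drain, and 9 arrives last.
```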

torch/utils/data/dataloader.py

@@ -185,6 +185,8 @@ class DataLoader(Generic[_T_co]):
             maintain the workers `Dataset` instances alive. (default: ``False``)
         pin_memory_device (str, optional): the device to :attr:`pin_memory` to if ``pin_memory`` is
             ``True``.
+        in_order (bool, optional): If ``False``, the data loader will not enforce that batches
+            are returned in a first-in, first-out order. Only applies when ``num_workers > 0``. (default: ``True``)

     .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn`
@@ -213,6 +215,9 @@ class DataLoader(Generic[_T_co]):
     .. warning:: See :ref:`reproducibility`, and :ref:`dataloader-workers-random-seed`, and
                  :ref:`data-loading-randomness` notes for random seed related questions.

+    .. warning:: Setting `in_order` to `False` can harm reproducibility and may lead to a skewed data
+                 distribution being fed to the trainer in cases with imbalanced data.
+
     .. _multiprocessing context:
         https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
     """
@@ -248,6 +253,7 @@ class DataLoader(Generic[_T_co]):
         prefetch_factor: Optional[int] = None,
         persistent_workers: bool = False,
         pin_memory_device: str = "",
+        in_order: bool = True,
     ):
         torch._C._log_api_usage_once("python.data_loader")
@@ -281,6 +287,7 @@ class DataLoader(Generic[_T_co]):
         self.timeout = timeout
         self.worker_init_fn = worker_init_fn
         self.multiprocessing_context = multiprocessing_context
+        self.in_order = in_order

         # Adds forward compatibilities so classic DataLoader can work with DataPipes:
         #   _DataPipeSerializationWrapper container makes it easier to serialize without redefining pickler
@@ -1074,6 +1081,7 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
         super().__init__(loader)

         self._prefetch_factor = loader.prefetch_factor
+        self._in_order = loader.in_order

         assert self._num_workers > 0
         assert self._prefetch_factor > 0
@@ -1423,13 +1431,14 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
             #   call and `_IterableDatasetStopIteration` check below can mark
             #   extra worker(s) as dead.
             while self._rcvd_idx < self._send_idx:
-                info = self._task_info[self._rcvd_idx]
-                worker_id = info[0]
-                if (
-                    len(info) == 2 or self._workers_status[worker_id]
-                ):  # has data or is still active
-                    break
-                del self._task_info[self._rcvd_idx]
+                info = self._task_info.get(self._rcvd_idx, None)
+                if info:
+                    worker_id = info[0]
+                    if (
+                        len(info) == 2 or self._workers_status[worker_id]
+                    ):  # has data or is still active
+                        break
+                    del self._task_info[self._rcvd_idx]
                 self._rcvd_idx += 1
             else:
                 # no valid `self._rcvd_idx` is found (i.e., didn't break)
@@ -1442,6 +1451,7 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
             # Check if the next sample has already been generated
             if len(self._task_info[self._rcvd_idx]) == 2:
                 data = self._task_info.pop(self._rcvd_idx)[1]
+                self._rcvd_idx += 1
                 return self._process_data(data)

             assert not self._shutdown and self._tasks_outstanding > 0
@@ -1458,10 +1468,15 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
                     continue

                 if idx != self._rcvd_idx:
+                    if not self._in_order:
+                        # don't store it for later, process now
+                        del self._task_info[idx]
+                        return self._process_data(data)
                     # store out-of-order samples
                     self._task_info[idx] += (data,)
                 else:
                     del self._task_info[idx]
+                    self._rcvd_idx += 1
                     return self._process_data(data)

     def _try_put_index(self):
@@ -1485,7 +1500,6 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
             self._send_idx += 1

     def _process_data(self, data):
-        self._rcvd_idx += 1
         self._try_put_index()
         if isinstance(data, ExceptionWrapper):
             data.reraise()
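
Taken together, the dataloader changes reduce to a small set of bookkeeping rules. A toy, single-threaded model (hypothetical names, not part of the diff), assuming `_task_info` maps each outstanding index to `(worker_id,)` until its result arrives and to `(worker_id, data)` afterwards:

```python
# Toy, single-threaded model of the new in_order bookkeeping (hypothetical).
task_info: dict[int, tuple] = {}  # idx -> (worker_id,) or (worker_id, data)
rcvd_idx = 0  # lowest index not yet handed to the caller
in_order = False


def on_result(idx: int, data):
    """Mirrors the idx != rcvd_idx branch of _next_data."""
    global rcvd_idx
    if idx != rcvd_idx:
        if not in_order:
            # Out-of-order mode: hand the sample straight to the caller and
            # drop its entry instead of buffering it until its turn comes.
            del task_info[idx]
            return data
        task_info[idx] += (data,)  # in-order mode: buffer for later
        return None
    # The lowest outstanding index arrived; only now does rcvd_idx advance
    # (hence the increment moved out of _process_data).
    del task_info[idx]
    rcvd_idx += 1
    return data
```

Because entries consumed out of order are deleted immediately, the skip loop at the top of `_next_data` can now encounter missing keys, which is why the lookup becomes `self._task_info.get(self._rcvd_idx, None)`.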