[DataPipe] Count number of successful yields for IterDataPipe (#79657)

This PR adds an attribute and logic to count the number of successful yields from `IterDataPipe`. This information can be useful to fast-forward a DataPipe (or the entire graph) back to a certain state.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79657
Approved by: https://github.com/VitalyFedyunin
This commit is contained in:
Kevin Tse 2022-06-27 20:24:50 -04:00 committed by PyTorch MergeBot
parent 7850a328b4
commit b8e50f512f
5 changed files with 151 additions and 12 deletions

View File

@ -277,7 +277,7 @@ class TestIterableDataPipeBasic(TestCase):
def test_listdirfiles_iterable_datapipe(self): def test_listdirfiles_iterable_datapipe(self):
temp_dir = self.temp_dir.name temp_dir = self.temp_dir.name
datapipe = dp.iter.FileLister(temp_dir, '') datapipe: IterDataPipe = dp.iter.FileLister(temp_dir, '')
count = 0 count = 0
for pathname in datapipe: for pathname in datapipe:
@ -2640,5 +2640,120 @@ class TestIterDataPipeSingletonConstraint(TestCase):
next(it1) next(it1)
self.assertEqual(1, next(it3)) self.assertEqual(1, next(it3))
class TestIterDataPipeCountSampleYielded(TestCase):
    r"""
    Verify that ``IterDataPipe._number_of_samples_yielded`` correctly tracks the
    number of successful yields across full reads, partial reads, resets, and
    mid-iteration exceptions, for each supported ``__iter__`` flavor:
    a generator function, a method returning a separate iterator object, and
    ``return self`` paired with ``__next__``.
    """

    def _yield_count_test_helper(self, datapipe, n_expected_samples):
        # `datapipe` is assumed to yield at least `n_expected_samples` elements
        # and to reset its yield counter when a new iterator is created.

        # Functional Test: Check if number of samples yielded is as expected
        res = list(datapipe)
        self.assertEqual(len(res), datapipe._number_of_samples_yielded)

        # Functional Test: Check if the count is correct when DataPipe is partially read
        it = iter(datapipe)
        res = []
        for i, value in enumerate(it):
            res.append(value)
            if i == n_expected_samples - 1:
                break
        self.assertEqual(n_expected_samples, datapipe._number_of_samples_yielded)

        # Functional Test: Check for reset behavior and if iterator also works
        it = iter(datapipe)  # reset the DataPipe
        res = list(it)
        self.assertEqual(len(res), datapipe._number_of_samples_yielded)

    def test_iterdatapipe_sample_yielded_generator_function(self):
        # Functional Test: `__iter__` is a generator function
        datapipe: IterDataPipe = dp.iter.IterableWrapper(range(10))
        self._yield_count_test_helper(datapipe, n_expected_samples=5)

    def test_iterdatapipe_sample_yielded_generator_function_exception(self):
        # Functional Test: `__iter__` is a custom generator function with exception
        class _CustomGeneratorFnDataPipe(IterDataPipe):
            # This class's `__iter__` has a Runtime Error
            def __iter__(self):
                yield 0
                yield 1
                yield 2
                raise RuntimeError("Custom test error after yielding 3 elements")
                yield 3  # intentionally unreachable: the count must stop at 3

        # Functional Test: Ensure the count is correct even when exception is raised
        datapipe: IterDataPipe = _CustomGeneratorFnDataPipe()
        with self.assertRaisesRegex(RuntimeError, "Custom test error after yielding 3 elements"):
            list(datapipe)
        self.assertEqual(3, datapipe._number_of_samples_yielded)

        # Functional Test: Check for reset behavior and if iterator also works
        it = iter(datapipe)  # reset the DataPipe
        with self.assertRaisesRegex(RuntimeError, "Custom test error after yielding 3 elements"):
            list(it)
        self.assertEqual(3, datapipe._number_of_samples_yielded)

    def test_iterdatapipe_sample_yielded_return_self(self):
        class _CustomGeneratorDataPipe(IterDataPipe):
            # This class's `__iter__` is not a generator function
            def __init__(self):
                self.source = iter(range(10))

            def __iter__(self):
                return self.source

            def reset(self):
                self.source = iter(range(10))

        datapipe: IterDataPipe = _CustomGeneratorDataPipe()
        self._yield_count_test_helper(datapipe, n_expected_samples=5)

    def test_iterdatapipe_sample_yielded_next(self):
        class _CustomNextDataPipe(IterDataPipe):
            # This class's `__iter__` returns `self` and has a `__next__`
            def __init__(self):
                self.source = iter(range(10))

            def __iter__(self):
                return self

            def __next__(self):
                return next(self.source)

            def reset(self):
                self.source = iter(range(10))

        datapipe: IterDataPipe = _CustomNextDataPipe()
        self._yield_count_test_helper(datapipe, n_expected_samples=5)

    def test_iterdatapipe_sample_yielded_next_exception(self):
        class _CustomNextDataPipe(IterDataPipe):
            # This class's `__iter__` returns `self` and has a `__next__`
            def __init__(self):
                self.source = iter(range(10))
                self.count = 0

            def __iter__(self):
                return self

            def __next__(self):
                # Raise on the 4th call so exactly 3 samples are yielded first
                if self.count == 3:
                    raise RuntimeError("Custom test error after yielding 3 elements")
                self.count += 1
                return next(self.source)

            def reset(self):
                self.count = 0
                self.source = iter(range(10))

        # Functional Test: Ensure the count is correct even when exception is raised
        datapipe: IterDataPipe = _CustomNextDataPipe()
        with self.assertRaisesRegex(RuntimeError, "Custom test error after yielding 3 elements"):
            list(datapipe)
        self.assertEqual(3, datapipe._number_of_samples_yielded)

        # Functional Test: Check for reset behavior and if iterator also works
        it = iter(datapipe)  # reset the DataPipe
        with self.assertRaisesRegex(RuntimeError, "Custom test error after yielding 3 elements"):
            list(it)
        self.assertEqual(3, datapipe._number_of_samples_yielded)
if __name__ == '__main__': if __name__ == '__main__':
run_tests() run_tests()

View File

@ -94,26 +94,40 @@ def hook_iterator(namespace, profile_name):
return torch.autograd.profiler.record_function(profile_name) return torch.autograd.profiler.record_function(profile_name)
class IteratorDecorator: class IteratorDecorator:
"""Wrap the iterator and modifying its `__next__` method""" r"""
def __init__(self, iterator, source_dp, iterator_id): Wrap the iterator and modifying its `__next__` method. This decorator is applied to
DataPipes of which `__iter__` method is NOT a generator function. Those `__iter__`
method commonly returns `self` but not necessarily.
"""
def __init__(self, iterator, source_dp, iterator_id, has_next_method):
self.iterator = iterator self.iterator = iterator
self.source_dp = source_dp self.source_dp = source_dp
self.iterator_id = iterator_id self.iterator_id = iterator_id
self._profiler_enabled = torch.autograd._profiler_enabled() self._profiler_enabled = torch.autograd._profiler_enabled()
# Check if `__iter__` returns `self` and `DataPipe` has `__next__`
self.self_and_has_next_method = self.iterator is self.source_dp and has_next_method
def __iter__(self): def __iter__(self):
return self return self
def _get_next(self):
r"""
Return next with logic related to iterator validity, profiler, and incrementation of samples yielded.
"""
_check_iterator_valid(self.source_dp, self.iterator_id)
result = next(self.iterator)
if not self.self_and_has_next_method:
self.source_dp._number_of_samples_yielded += 1
return result
def __next__(self): def __next__(self):
# TODO: Add try-except to in-place reduce traceback from the Exception # TODO: Add try-except to in-place reduce traceback from the Exception
# See: https://github.com/pytorch/data/issues/284 # See: https://github.com/pytorch/data/issues/284
if self._profiler_enabled: if self._profiler_enabled:
with profiler_record_fn_context(): with profiler_record_fn_context():
_check_iterator_valid(self.source_dp, self.iterator_id) return self._get_next()
return next(self.iterator)
else: # Decided against using `contextlib.nullcontext` for performance reasons else: # Decided against using `contextlib.nullcontext` for performance reasons
_check_iterator_valid(self.source_dp, self.iterator_id) return self._get_next()
return next(self.iterator)
def __getattr__(self, name): def __getattr__(self, name):
return getattr(self.iterator, name) return getattr(self.iterator, name)
@ -136,6 +150,7 @@ def hook_iterator(namespace, profile_name):
response = gen.send(None) response = gen.send(None)
while True: while True:
datapipe._number_of_samples_yielded += 1
request = yield response request = yield response
# Pass through here every time `__next__` is called # Pass through here every time `__next__` is called
if _profiler_enabled: if _profiler_enabled:
@ -172,14 +187,19 @@ def hook_iterator(namespace, profile_name):
def wrap_next(*args, **kwargs): def wrap_next(*args, **kwargs):
if torch.autograd._profiler_enabled(): if torch.autograd._profiler_enabled():
with profiler_record_fn_context(): with profiler_record_fn_context():
return next_func(*args, **kwargs) result = next_func(*args, **kwargs)
else: else:
return next_func(*args, **kwargs) result = next_func(*args, **kwargs)
datapipe = args[0]
datapipe._number_of_samples_yielded += 1
return result
namespace['__next__'] = wrap_next namespace['__next__'] = wrap_next
# Note that if the `__next__` and `__iter__` do something completely unrelated? It may cause issue but # Note that if the `__next__` and `__iter__` do something completely unrelated. It may cause issue but
# the user will be violating the iterator protocol # the user will be violating the iterator protocol. Potential issue:
# 1. Valid iterator ID may not update or checked properly
# 2. The number of samples yielded will be miscounted
# Regardless if `__next__` exists or not, `__iter__` needs a wrapper to track the number of valid iterators # Regardless if `__next__` exists or not, `__iter__` needs a wrapper to track the number of valid iterators
@functools.wraps(func) @functools.wraps(func)
@ -187,6 +207,6 @@ def hook_iterator(namespace, profile_name):
iter_ret = func(*args, **kwargs) iter_ret = func(*args, **kwargs)
datapipe = args[0] datapipe = args[0]
iterator_id = _set_datapipe_valid_iterator_id(datapipe) # This ID is tied to each created iterator iterator_id = _set_datapipe_valid_iterator_id(datapipe) # This ID is tied to each created iterator
return IteratorDecorator(iter_ret, datapipe, iterator_id) return IteratorDecorator(iter_ret, datapipe, iterator_id, '__next__' in namespace)
namespace['__iter__'] = wrap_iter namespace['__iter__'] = wrap_iter

View File

@ -357,6 +357,7 @@ class _IterDataPipeMeta(_DataPipeMeta):
if datapipe._restored is True: if datapipe._restored is True:
datapipe._restored = False datapipe._restored = False
else: else:
datapipe._number_of_samples_yielded = 0
reset_func(*args, **kwargs) reset_func(*args, **kwargs)
namespace['reset'] = conditional_reset namespace['reset'] = conditional_reset

View File

@ -110,6 +110,7 @@ class IterDataPipe(IterableDataset[T_co], metaclass=_IterDataPipeMeta):
str_hook: Optional[Callable] = None str_hook: Optional[Callable] = None
repr_hook: Optional[Callable] = None repr_hook: Optional[Callable] = None
_valid_iterator_id: Optional[int] = None _valid_iterator_id: Optional[int] = None
_number_of_samples_yielded: int = 0
_restored: bool = False _restored: bool = False
def __getattr__(self, attribute_name): def __getattr__(self, attribute_name):

View File

@ -38,6 +38,8 @@ class IterDataPipe(IterableDataset[T_co], metaclass=_IterDataPipeMeta):
getstate_hook: Optional[Callable] = ... getstate_hook: Optional[Callable] = ...
str_hook: Optional[Callable] = ... str_hook: Optional[Callable] = ...
repr_hook: Optional[Callable] = ... repr_hook: Optional[Callable] = ...
_number_of_samples_yielded: int = ...
_restored: bool = False
def __getattr__(self, attribute_name: Any): ... def __getattr__(self, attribute_name: Any): ...
@classmethod @classmethod
def register_function(cls, function_name: Any, function: Any) -> None: ... def register_function(cls, function_name: Any, function: Any) -> None: ...