mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[DataPipe] Count number of successful yields for IterDataPipe (#79657)
This PR adds an attribute and logic to count the number of successful yields from `IterDataPipe`. This information can be useful to fast-forward a DataPipe (or the entire graph) back to a certain state. Pull Request resolved: https://github.com/pytorch/pytorch/pull/79657 Approved by: https://github.com/VitalyFedyunin
This commit is contained in:
parent
7850a328b4
commit
b8e50f512f
|
|
@ -277,7 +277,7 @@ class TestIterableDataPipeBasic(TestCase):
|
|||
|
||||
def test_listdirfiles_iterable_datapipe(self):
|
||||
temp_dir = self.temp_dir.name
|
||||
datapipe = dp.iter.FileLister(temp_dir, '')
|
||||
datapipe: IterDataPipe = dp.iter.FileLister(temp_dir, '')
|
||||
|
||||
count = 0
|
||||
for pathname in datapipe:
|
||||
|
|
@ -2640,5 +2640,120 @@ class TestIterDataPipeSingletonConstraint(TestCase):
|
|||
next(it1)
|
||||
self.assertEqual(1, next(it3))
|
||||
|
||||
class TestIterDataPipeCountSampleYielded(TestCase):
    """Tests for the `_number_of_samples_yielded` counter kept on an `IterDataPipe`."""

    def _yield_count_test_helper(self, datapipe, n_expected_samples):
        """Compare `datapipe._number_of_samples_yielded` against what was actually read.

        Three passes are made over ``datapipe``: a full read, a partial read that stops
        after ``n_expected_samples`` items, and a full read through a fresh iterator.
        """
        # Functional Test: after a full pass, the counter matches the number of samples read
        samples = list(datapipe)
        self.assertEqual(len(samples), datapipe._number_of_samples_yielded)

        # Functional Test: the counter is correct when the DataPipe is only partially read
        iterator = iter(datapipe)
        samples = []
        for n_read, sample in enumerate(iterator, start=1):
            samples.append(sample)
            if n_read == n_expected_samples:
                break
        self.assertEqual(n_expected_samples, datapipe._number_of_samples_yielded)

        # Functional Test: reset behavior, and reading through an explicit iterator also works
        iterator = iter(datapipe)  # reset the DataPipe
        samples = list(iterator)
        self.assertEqual(len(samples), datapipe._number_of_samples_yielded)

    def test_iterdatapipe_sample_yielded_generator_function(self):
        # Functional Test: `__iter__` is a generator function
        wrapped_range: IterDataPipe = dp.iter.IterableWrapper(range(10))
        self._yield_count_test_helper(wrapped_range, n_expected_samples=5)

    def test_iterdatapipe_sample_yielded_generator_function_exception(self):
        # Functional Test: `__iter__` is a custom generator function that raises part-way through
        class _CustomGeneratorFnDataPipe(IterDataPipe):
            # `__iter__` yields three elements and then raises a RuntimeError
            def __iter__(self):
                yield 0
                yield 1
                yield 2
                raise RuntimeError("Custom test error after yielding 3 elements")
                yield 3

        error_pattern = "Custom test error after yielding 3 elements"

        # Functional Test: the count is still correct when an exception is raised
        datapipe: IterDataPipe = _CustomGeneratorFnDataPipe()
        with self.assertRaisesRegex(RuntimeError, error_pattern):
            list(datapipe)
        self.assertEqual(3, datapipe._number_of_samples_yielded)

        # Functional Test: reset behavior, and reading through an explicit iterator also works
        iterator = iter(datapipe)  # reset the DataPipe
        with self.assertRaisesRegex(RuntimeError, error_pattern):
            list(iterator)
        self.assertEqual(3, datapipe._number_of_samples_yielded)

    def test_iterdatapipe_sample_yielded_return_self(self):
        class _CustomGeneratorDataPipe(IterDataPipe):
            # `__iter__` here is not a generator function; it hands back a stored iterator
            def __init__(self):
                self.source = iter(range(10))

            def __iter__(self):
                return self.source

            def reset(self):
                self.source = iter(range(10))

        stored_iter_dp: IterDataPipe = _CustomGeneratorDataPipe()
        self._yield_count_test_helper(stored_iter_dp, n_expected_samples=5)

    def test_iterdatapipe_sample_yielded_next(self):
        class _CustomNextDataPipe(IterDataPipe):
            # `__iter__` returns `self`, and the class defines its own `__next__`
            def __init__(self):
                self.source = iter(range(10))

            def __iter__(self):
                return self

            def __next__(self):
                return next(self.source)

            def reset(self):
                self.source = iter(range(10))

        self_iter_dp: IterDataPipe = _CustomNextDataPipe()
        self._yield_count_test_helper(self_iter_dp, n_expected_samples=5)

    def test_iterdatapipe_sample_yielded_next_exception(self):
        class _CustomNextDataPipe(IterDataPipe):
            # `__iter__` returns `self`; `__next__` raises once three elements were returned
            def __init__(self):
                self.source = iter(range(10))
                self.count = 0

            def __iter__(self):
                return self

            def __next__(self):
                if self.count != 3:
                    self.count += 1
                    return next(self.source)
                raise RuntimeError("Custom test error after yielding 3 elements")

            def reset(self):
                self.count = 0
                self.source = iter(range(10))

        error_pattern = "Custom test error after yielding 3 elements"

        # Functional Test: the count is still correct when an exception is raised
        datapipe: IterDataPipe = _CustomNextDataPipe()
        with self.assertRaisesRegex(RuntimeError, error_pattern):
            list(datapipe)
        self.assertEqual(3, datapipe._number_of_samples_yielded)

        # Functional Test: reset behavior, and reading through an explicit iterator also works
        iterator = iter(datapipe)  # reset the DataPipe
        with self.assertRaisesRegex(RuntimeError, error_pattern):
            list(iterator)
        self.assertEqual(3, datapipe._number_of_samples_yielded)
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -94,26 +94,40 @@ def hook_iterator(namespace, profile_name):
|
|||
return torch.autograd.profiler.record_function(profile_name)
|
||||
|
||||
class IteratorDecorator:
    r"""
    Wrap the iterator and modify its `__next__` method.

    This decorator is applied to DataPipes whose `__iter__` method is NOT a generator
    function. Such an `__iter__` method commonly returns `self`, but not necessarily.

    Args:
        iterator: the object returned by the DataPipe's original `__iter__`.
        source_dp: the DataPipe that ``iterator`` was obtained from.
        iterator_id: ID handed to `_check_iterator_valid` on every `__next__` call.
        has_next_method: whether the DataPipe class defines its own `__next__`.
    """
    def __init__(self, iterator, source_dp, iterator_id, has_next_method):
        self.iterator = iterator
        self.source_dp = source_dp
        self.iterator_id = iterator_id
        # Sampled once at construction; later profiler state changes are not observed.
        self._profiler_enabled = torch.autograd._profiler_enabled()
        # Check if `__iter__` returns `self` and `DataPipe` has `__next__`
        self.self_and_has_next_method = self.iterator is self.source_dp and has_next_method

    def __iter__(self):
        return self

    def _get_next(self):
        r"""
        Return next with logic related to iterator validity, profiler, and incrementation of samples yielded.
        """
        _check_iterator_valid(self.source_dp, self.iterator_id)
        result = next(self.iterator)
        # When the iterator is the DataPipe itself and it has its own `__next__`, the
        # `next` call above goes through that `__next__`; the counter is not touched
        # here so the sample is not counted a second time.
        if not self.self_and_has_next_method:
            self.source_dp._number_of_samples_yielded += 1
        return result

    def __next__(self):
        # TODO: Add try-except to in-place reduce traceback from the Exception
        # See: https://github.com/pytorch/data/issues/284
        if self._profiler_enabled:
            with profiler_record_fn_context():
                return self._get_next()
        else:  # Decided against using `contextlib.nullcontext` for performance reasons
            return self._get_next()

    def __getattr__(self, name):
        # Only reached when normal lookup fails; forward to the wrapped iterator.
        return getattr(self.iterator, name)
|
||||
|
|
@ -136,6 +150,7 @@ def hook_iterator(namespace, profile_name):
|
|||
response = gen.send(None)
|
||||
|
||||
while True:
|
||||
datapipe._number_of_samples_yielded += 1
|
||||
request = yield response
|
||||
# Pass through here every time `__next__` is called
|
||||
if _profiler_enabled:
|
||||
|
|
@ -172,14 +187,19 @@ def hook_iterator(namespace, profile_name):
|
|||
def wrap_next(*args, **kwargs):
    # Call the original `__next__`, under the profiler context when profiling is on.
    if torch.autograd._profiler_enabled():
        with profiler_record_fn_context():
            result = next_func(*args, **kwargs)
    else:
        result = next_func(*args, **kwargs)
    # `args[0]` is the DataPipe instance (`self`). The increment runs only after
    # `next_func` returned, so a call that raises is not counted.
    datapipe = args[0]
    datapipe._number_of_samples_yielded += 1
    return result
|
||||
|
||||
namespace['__next__'] = wrap_next
|
||||
|
||||
# Note that if the `__next__` and `__iter__` do something completely unrelated? It may cause issue but
|
||||
# the user will be violating the iterator protocol
|
||||
# Note: if `__next__` and `__iter__` do something completely unrelated, it may cause issues, but
|
||||
# the user would be violating the iterator protocol. Potential issues:
|
||||
# 1. The valid iterator ID may not be updated or checked properly
|
||||
# 2. The number of samples yielded will be miscounted
|
||||
|
||||
# Regardless of whether `__next__` exists, `__iter__` needs a wrapper to track the number of valid iterators
|
||||
@functools.wraps(func)
|
||||
|
|
@ -187,6 +207,6 @@ def hook_iterator(namespace, profile_name):
|
|||
iter_ret = func(*args, **kwargs)
|
||||
datapipe = args[0]
|
||||
iterator_id = _set_datapipe_valid_iterator_id(datapipe) # This ID is tied to each created iterator
|
||||
return IteratorDecorator(iter_ret, datapipe, iterator_id)
|
||||
return IteratorDecorator(iter_ret, datapipe, iterator_id, '__next__' in namespace)
|
||||
|
||||
namespace['__iter__'] = wrap_iter
|
||||
|
|
|
|||
|
|
@ -357,6 +357,7 @@ class _IterDataPipeMeta(_DataPipeMeta):
|
|||
if datapipe._restored is True:
|
||||
datapipe._restored = False
|
||||
else:
|
||||
datapipe._number_of_samples_yielded = 0
|
||||
reset_func(*args, **kwargs)
|
||||
|
||||
namespace['reset'] = conditional_reset
|
||||
|
|
|
|||
|
|
@ -110,6 +110,7 @@ class IterDataPipe(IterableDataset[T_co], metaclass=_IterDataPipeMeta):
|
|||
str_hook: Optional[Callable] = None
|
||||
repr_hook: Optional[Callable] = None
|
||||
_valid_iterator_id: Optional[int] = None
|
||||
_number_of_samples_yielded: int = 0
|
||||
_restored: bool = False
|
||||
|
||||
def __getattr__(self, attribute_name):
|
||||
|
|
|
|||
|
|
@ -38,6 +38,8 @@ class IterDataPipe(IterableDataset[T_co], metaclass=_IterDataPipeMeta):
|
|||
getstate_hook: Optional[Callable] = ...
|
||||
str_hook: Optional[Callable] = ...
|
||||
repr_hook: Optional[Callable] = ...
|
||||
_number_of_samples_yielded: int = ...
|
||||
_restored: bool = False
|
||||
def __getattr__(self, attribute_name: Any): ...
|
||||
@classmethod
|
||||
def register_function(cls, function_name: Any, function: Any) -> None: ...
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user