diff --git a/test/test_datapipe.py b/test/test_datapipe.py index c951b0c53e7..c3736edae27 100644 --- a/test/test_datapipe.py +++ b/test/test_datapipe.py @@ -112,7 +112,8 @@ def create_temp_dir_and_files(): # Then, reset the DataPipe and return a tuple of two lists # 1. A list of elements yielded before the reset # 2. A list of all elements of the DataPipe after the reset -def reset_after_n_next_calls(datapipe: IterDataPipe[T_co], n: int) -> Tuple[List[T_co], List[T_co]]: +def reset_after_n_next_calls(datapipe: Union[IterDataPipe[T_co], MapDataPipe[T_co]], + n: int) -> Tuple[List[T_co], List[T_co]]: it = iter(datapipe) res_before_reset = [] for _ in range(n): @@ -644,19 +645,6 @@ class IDP_NoLen(IterDataPipe): yield i -class MDP(MapDataPipe): - def __init__(self, input_dp): - super().__init__() - self.input_dp = input_dp - self.length = len(input_dp) - - def __getitem__(self, index): - return self.input_dp[index] - - def __len__(self) -> int: - return self.length - - def _fake_fn(data, *args, **kwargs): return data @@ -1443,8 +1431,8 @@ class TestFunctionalMapDataPipe(TestCase): picklable_datapipes: List[ Tuple[Type[MapDataPipe], MapDataPipe, Tuple, Dict[str, Any]] ] = [ - (dp.map.Mapper, MDP(arr), (), {}), - (dp.map.Mapper, MDP(arr), (_fake_fn, (0,), {'test': True}), {}), + (dp.map.Mapper, dp.map.SequenceWrapper(arr), (), {}), + (dp.map.Mapper, dp.map.SequenceWrapper(arr), (_fake_fn, (0,), {'test': True}), {}), ] for dpipe, input_dp, dp_args, dp_kwargs in picklable_datapipes: p = pickle.dumps(dpipe(input_dp, *dp_args, **dp_kwargs)) # type: ignore[call-arg] @@ -1452,7 +1440,7 @@ class TestFunctionalMapDataPipe(TestCase): unpicklable_datapipes: List[ Tuple[Type[MapDataPipe], MapDataPipe, Tuple, Dict[str, Any]] ] = [ - (dp.map.Mapper, MDP(arr), (lambda x: x,), {}), + (dp.map.Mapper, dp.map.SequenceWrapper(arr), (lambda x: x,), {}), ] for dpipe, input_dp, dp_args, dp_kwargs in unpicklable_datapipes: with warnings.catch_warnings(record=True) as wa: @@ -1464,9 +1452,36 @@ class TestFunctionalMapDataPipe(TestCase): with self.assertRaises(AttributeError): p = pickle.dumps(datapipe) + def test_sequence_wrapper_datapipe(self): + seq = list(range(10)) + input_dp = dp.map.SequenceWrapper(seq) + + # Functional Test: all elements are equal in the same order + self.assertEqual(seq, list(input_dp)) + + # Functional Test: confirm deepcopy works by default + seq.append(11) + self.assertEqual(list(range(10)), list(input_dp)) # input_dp shouldn't have 11 + + # Functional Test: non-deepcopy version is working + seq2 = [1, 2, 3] + input_dp_non_deep = dp.map.SequenceWrapper(seq2, deepcopy=False) + seq2.append(4) + self.assertEqual(list(seq2), list(input_dp_non_deep)) # should have 4 + + # Reset Test: reset the DataPipe + seq = list(range(10)) + n_elements_before_reset = 5 + res_before_reset, res_after_reset = reset_after_n_next_calls(input_dp, n_elements_before_reset) + self.assertEqual(list(range(5)), res_before_reset) + self.assertEqual(seq, res_after_reset) + + # __len__ Test: inherits length from sequence + self.assertEqual(len(seq), len(input_dp)) + def test_concat_datapipe(self): - input_dp1 = MDP(range(10)) - input_dp2 = MDP(range(5)) + input_dp1 = dp.map.SequenceWrapper(range(10)) + input_dp2 = dp.map.SequenceWrapper(range(5)) with self.assertRaisesRegex(ValueError, r"Expected at least one DataPipe"): dp.map.Concater() @@ -1482,7 +1497,7 @@ class TestFunctionalMapDataPipe(TestCase): def test_map_datapipe(self): arr = range(10) - input_dp = MDP(arr) + input_dp = dp.map.SequenceWrapper(arr) def fn(item, dtype=torch.float, *, sum=False): data = torch.tensor(item, dtype=dtype) diff --git a/torch/utils/data/datapipes/map/__init__.py b/torch/utils/data/datapipes/map/__init__.py index 5879165aff2..0b356dc7e21 100644 --- a/torch/utils/data/datapipes/map/__init__.py +++ b/torch/utils/data/datapipes/map/__init__.py @@ -1,6 +1,7 @@ # Functional DataPipe from torch.utils.data.datapipes.map.callable import MapperMapDataPipe as Mapper from torch.utils.data.datapipes.map.combining import ConcaterMapDataPipe as Concater +from torch.utils.data.datapipes.map.utils import SequenceWrapperMapDataPipe as SequenceWrapper -__all__ = ['Concater', 'Mapper'] +__all__ = ['Concater', 'Mapper', 'SequenceWrapper'] diff --git a/torch/utils/data/datapipes/map/utils.py b/torch/utils/data/datapipes/map/utils.py new file mode 100644 index 00000000000..2dc938f6255 --- /dev/null +++ b/torch/utils/data/datapipes/map/utils.py @@ -0,0 +1,38 @@ +import copy +import warnings +from torch.utils.data import MapDataPipe + + +class SequenceWrapperMapDataPipe(MapDataPipe): + r""":class:`SequenceWrapperMapDataPipe`. + + Map DataPipe that wraps a sequence object. + + Args: + sequence: Sequence object to be wrapped into an IterDataPipe + deepcopy: Option to deepcopy input sequence object + + .. note:: + If `deepcopy` is set to False explicitly, users should ensure + that data pipeline doesn't contain any in-place operations over + the iterable instance, in order to prevent data inconsistency + across iterations. + """ + def __init__(self, sequence, deepcopy=True): + if deepcopy: + try: + self.sequence = copy.deepcopy(sequence) + except TypeError: + warnings.warn( + "The input sequence can not be deepcopied, " + "please be aware of in-place modification would affect source data" + ) + self.sequence = sequence + else: + self.sequence = sequence + + def __getitem__(self, index): + return self.sequence[index] + + def __len__(self): + return len(self.sequence)