diff --git a/test/test_datapipe.py b/test/test_datapipe.py index 28550c27283..8941a52c5df 100644 --- a/test/test_datapipe.py +++ b/test/test_datapipe.py @@ -428,33 +428,39 @@ def _worker_init_fn(worker_id): class TestFunctionalIterDataPipe(TestCase): - # TODO(VitalyFedyunin): If dill installed this test fails - def _test_picklable(self): - arr = range(10) - picklable_datapipes: List[Tuple[Type[IterDataPipe], IterDataPipe, Tuple, Dict[str, Any]]] = [ - (dp.iter.Mapper, dp.iter.IterableWrapper(arr), (), {}), - (dp.iter.Mapper, dp.iter.IterableWrapper(arr), (_fake_fn, (0, )), {}), - (dp.iter.Mapper, dp.iter.IterableWrapper(arr), (partial(_fake_add, 1), (0,)), {}), - (dp.iter.Collator, dp.iter.IterableWrapper(arr), (), {}), - (dp.iter.Collator, dp.iter.IterableWrapper(arr), (_fake_fn, (0, )), {}), - (dp.iter.Filter, dp.iter.IterableWrapper(arr), (_fake_filter_fn, (0, )), {}), - (dp.iter.Filter, dp.iter.IterableWrapper(arr), (partial(_fake_filter_fn, 5), (0,)), {}), + def test_serializable(self): + input_dp = dp.iter.IterableWrapper(range(10)) + picklable_datapipes: List[Tuple[Type[IterDataPipe], Tuple, Dict[str, Any]]] = [ + (dp.iter.Mapper, (_fake_fn, ), {}), + (dp.iter.Mapper, (partial(_fake_add, 1), ), {}), + (dp.iter.Collator, (_fake_fn, ), {}), + (dp.iter.Filter, (_fake_filter_fn, ), {}), + (dp.iter.Filter, (partial(_fake_filter_fn_constant, 5), ), {}), + (dp.iter.Demultiplexer, (2, _fake_filter_fn), {}), ] - for dpipe, input_dp, dp_args, dp_kwargs in picklable_datapipes: - p = pickle.dumps(dpipe(input_dp, *dp_args, **dp_kwargs)) # type: ignore[call-arg] + for dpipe, dp_args, dp_kwargs in picklable_datapipes: + print(dpipe) + _ = pickle.dumps(dpipe(input_dp, *dp_args, **dp_kwargs)) # type: ignore[call-arg] - unpicklable_datapipes: List[Tuple[Type[IterDataPipe], IterDataPipe, Tuple, Dict[str, Any]]] = [ - (dp.iter.Mapper, dp.iter.IterableWrapper(arr), (lambda x: x, ), {}), - (dp.iter.Collator, dp.iter.IterableWrapper(arr), (lambda x: x, ), {}), - (dp.iter.Filter, 
dp.iter.IterableWrapper(arr), (lambda x: x >= 5, ), {}), + def test_serializable_with_dill(self): + input_dp = dp.iter.IterableWrapper(range(10)) + unpicklable_datapipes: List[Tuple[Type[IterDataPipe], Tuple, Dict[str, Any]]] = [ + (dp.iter.Mapper, (lambda x: x, ), {}), + (dp.iter.Collator, (lambda x: x, ), {}), + (dp.iter.Filter, (lambda x: x >= 5, ), {}), + (dp.iter.Demultiplexer, (2, lambda x: x % 2, ), {}) ] - for dpipe, input_dp, dp_args, dp_kwargs in unpicklable_datapipes: - with warnings.catch_warnings(record=True) as wa: - datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg] - self.assertEqual(len(wa), 1) - self.assertRegex(str(wa[0].message), r"^Lambda function is not supported for pickle") - with self.assertRaises(AttributeError): - p = pickle.dumps(datapipe) + if HAS_DILL: + for dpipe, dp_args, dp_kwargs in unpicklable_datapipes: + _ = pickle.dumps(dpipe(input_dp, *dp_args, **dp_kwargs)) # type: ignore[call-arg] + else: + for dpipe, dp_args, dp_kwargs in unpicklable_datapipes: + with warnings.catch_warnings(record=True) as wa: + datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg] + self.assertEqual(len(wa), 1) + self.assertRegex(str(wa[0].message), r"^Lambda function is not supported for pickle") + with self.assertRaises(AttributeError): + p = pickle.dumps(datapipe) def test_iterable_wrapper_datapipe(self): @@ -711,8 +717,9 @@ class TestFunctionalIterDataPipe(TestCase): classifier_fn=lambda x: 0 if x >= 5 else 1, buffer_size=-1 ) - self.assertEqual(len(wa), 1) - self.assertRegex(str(wa[0].message), r"Unlimited buffer size is set") + exp_l = 1 if HAS_DILL else 2 + self.assertEqual(len(wa), exp_l) + self.assertRegex(str(wa[-1].message), r"Unlimited buffer size is set") output1, output2 = list(dp1), list(dp2) self.assertEqual(list(range(5, 10)), output1) self.assertEqual(list(range(0, 5)), output2) @@ -1152,33 +1159,38 @@ class TestFunctionalIterDataPipe(TestCase): class TestFunctionalMapDataPipe(TestCase): - # 
TODO(VitalyFedyunin): If dill installed this test fails - def _test_picklable(self): - arr = range(10) + def test_serializable(self): + input_dp = dp.map.SequenceWrapper(range(10)) picklable_datapipes: List[ - Tuple[Type[MapDataPipe], MapDataPipe, Tuple, Dict[str, Any]] + Tuple[Type[MapDataPipe], Tuple, Dict[str, Any]] ] = [ - (dp.map.Mapper, dp.map.SequenceWrapper(arr), (), {}), - (dp.map.Mapper, dp.map.SequenceWrapper(arr), (_fake_fn, (0,)), {}), - (dp.map.Mapper, dp.map.SequenceWrapper(arr), (partial(_fake_add, 1), (0,)), {}), + (dp.map.Mapper, (), {}), + (dp.map.Mapper, (_fake_fn, ), {}), + (dp.map.Mapper, (partial(_fake_add, 1), ), {}), ] - for dpipe, input_dp, dp_args, dp_kwargs in picklable_datapipes: - p = pickle.dumps(dpipe(input_dp, *dp_args, **dp_kwargs)) # type: ignore[call-arg] + for dpipe, dp_args, dp_kwargs in picklable_datapipes: + _ = pickle.dumps(dpipe(input_dp, *dp_args, **dp_kwargs)) # type: ignore[call-arg] + def test_serializable_with_dill(self): + input_dp = dp.map.SequenceWrapper(range(10)) unpicklable_datapipes: List[ - Tuple[Type[MapDataPipe], MapDataPipe, Tuple, Dict[str, Any]] + Tuple[Type[MapDataPipe], Tuple, Dict[str, Any]] ] = [ - (dp.map.Mapper, dp.map.SequenceWrapper(arr), (lambda x: x,), {}), + (dp.map.Mapper, (lambda x: x,), {}), ] - for dpipe, input_dp, dp_args, dp_kwargs in unpicklable_datapipes: - with warnings.catch_warnings(record=True) as wa: - datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg] - self.assertEqual(len(wa), 1) - self.assertRegex( - str(wa[0].message), r"^Lambda function is not supported for pickle" - ) - with self.assertRaises(AttributeError): - p = pickle.dumps(datapipe) + if HAS_DILL: + for dpipe, dp_args, dp_kwargs in unpicklable_datapipes: + _ = pickle.dumps(dpipe(input_dp, *dp_args, **dp_kwargs)) # type: ignore[call-arg] + else: + for dpipe, dp_args, dp_kwargs in unpicklable_datapipes: + with warnings.catch_warnings(record=True) as wa: + datapipe = dpipe(input_dp, *dp_args, 
**dp_kwargs) # type: ignore[call-arg] + self.assertEqual(len(wa), 1) + self.assertRegex( + str(wa[0].message), r"^Lambda function is not supported for pickle" + ) + with self.assertRaises(AttributeError): + p = pickle.dumps(datapipe) def test_sequence_wrapper_datapipe(self): seq = list(range(10)) diff --git a/torch/utils/data/datapipes/iter/callable.py b/torch/utils/data/datapipes/iter/callable.py index b63b6d6e835..1d939265256 100644 --- a/torch/utils/data/datapipes/iter/callable.py +++ b/torch/utils/data/datapipes/iter/callable.py @@ -1,18 +1,11 @@ -import warnings -from torch.utils.data import IterDataPipe, _utils, functional_datapipe from typing import Callable, Iterator, Sized, TypeVar -try: - import dill +from torch.utils.data import IterDataPipe, _utils, functional_datapipe +from torch.utils.data.datapipes.utils.common import DILL_AVAILABLE, check_lambda_fn - # XXX: By default, dill writes the Pickler dispatch table to inject its - # own logic there. This globally affects the behavior of the standard library - # pickler for any user who transitively depends on this module! - # Undo this extension to avoid altering the behavior of the pickler globally. +if DILL_AVAILABLE: + import dill dill.extend(use_dill=False) - DILL_AVAILABLE = True -except ImportError: - DILL_AVAILABLE = False T_co = TypeVar("T_co", covariant=True) @@ -50,13 +43,10 @@ class MapperIterDataPipe(IterDataPipe[T_co]): ) -> None: super().__init__() self.datapipe = datapipe - # Partial object has no attribute '__name__', but can be pickled - if hasattr(fn, "__name__") and fn.__name__ == "<lambda>" and not DILL_AVAILABLE: - warnings.warn( - "Lambda function is not supported for pickle, please use " - "regular python function or functools.partial instead." 
- ) + + check_lambda_fn(fn) self.fn = fn # type: ignore[assignment] + self.input_col = input_col if input_col is None and output_col is not None: raise ValueError("`output_col` must be None when `input_col` is None.") diff --git a/torch/utils/data/datapipes/iter/combining.py b/torch/utils/data/datapipes/iter/combining.py index 1e886547bb8..55788565e70 100644 --- a/torch/utils/data/datapipes/iter/combining.py +++ b/torch/utils/data/datapipes/iter/combining.py @@ -1,8 +1,14 @@ import warnings -from torch.utils.data import IterDataPipe, functional_datapipe -from typing import Any, Callable, Iterator, List, Optional, Set, Sized, Tuple, TypeVar, Deque from collections import deque +from typing import Any, Callable, Iterator, List, Optional, Set, Sized, Tuple, TypeVar, Deque + +from torch.utils.data import IterDataPipe, functional_datapipe +from torch.utils.data.datapipes.utils.common import DILL_AVAILABLE, check_lambda_fn + +if DILL_AVAILABLE: + import dill + dill.extend(use_dill=False) T_co = TypeVar('T_co', covariant=True) @@ -190,6 +196,9 @@ class DemultiplexerIterDataPipe(IterDataPipe): classifier_fn: Callable[[T_co], Optional[int]], drop_none: bool = False, buffer_size: int = 1000): if num_instances < 1: raise ValueError(f"Expected `num_instaces` larger than 0, but {num_instances} is found") + + check_lambda_fn(classifier_fn) + # When num_instances == 1, demux can be replaced by filter, # but keep it as Demultiplexer for the sake of consistency # like throwing Error when classification result is out of o range @@ -277,6 +286,41 @@ class _DemultiplexerIterDataPipe(IterDataPipe): self.instance_started = [False] * self.num_instances self.main_datapipe_exhausted = False + def __getstate__(self): + if IterDataPipe.getstate_hook is not None: + return IterDataPipe.getstate_hook(self) + + if DILL_AVAILABLE: + dill_function = dill.dumps(self.classifier_fn) + else: + dill_function = self.classifier_fn + state = ( + self.main_datapipe, + self.num_instances, + 
self.buffer_size, + dill_function, + self.drop_none, + ) + return state + + def __setstate__(self, state): + ( + self.main_datapipe, + self.num_instances, + self.buffer_size, + dill_function, + self.drop_none, + ) = state + if DILL_AVAILABLE: + self.classifier_fn = dill.loads(dill_function) # type: ignore[assignment] + else: + self.classifier_fn = dill_function # type: ignore[assignment] + self._datapipe_iterator = None + self.current_buffer_usage = 0 + self.child_buffers = [deque() for _ in range(self.num_instances)] + self.instance_started = [False] * self.num_instances + self.main_datapipe_exhausted = False + @functional_datapipe('mux') class MultiplexerIterDataPipe(IterDataPipe): diff --git a/torch/utils/data/datapipes/utils/common.py b/torch/utils/data/datapipes/utils/common.py index 2f390c34340..16bd5d227b8 100644 --- a/torch/utils/data/datapipes/utils/common.py +++ b/torch/utils/data/datapipes/utils/common.py @@ -5,6 +5,27 @@ import warnings from io import IOBase from typing import Iterable, List, Tuple, Union +try: + import dill + + # XXX: By default, dill writes the Pickler dispatch table to inject its + # own logic there. This globally affects the behavior of the standard library + # pickler for any user who transitively depends on this module! + # Undo this extension to avoid altering the behavior of the pickler globally. + dill.extend(use_dill=False) + DILL_AVAILABLE = True +except ImportError: + DILL_AVAILABLE = False + + +def check_lambda_fn(fn): + # Partial object has no attribute '__name__', but can be pickled + if hasattr(fn, "__name__") and fn.__name__ == "<lambda>" and not DILL_AVAILABLE: + warnings.warn( + "Lambda function is not supported for pickle, please use " + "regular python function or functools.partial instead." + ) + def match_masks(name : str, masks : Union[str, List[str]]) -> bool: # empty mask matches any input name