Make Demux serializable with lambda function (#71311)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/71311

Test Plan: Imported from OSS

Reviewed By: NivekT

Differential Revision: D33584552

Pulled By: ejguan

fbshipit-source-id: 52324faf5547f9f77582ec170ec91ce3114cfc61
This commit is contained in:
Erjia Guan 2022-01-18 06:46:40 -08:00 committed by Facebook GitHub Bot
parent f0db15122f
commit fd9e08df5d
4 changed files with 132 additions and 65 deletions

View File

@ -428,33 +428,39 @@ def _worker_init_fn(worker_id):
class TestFunctionalIterDataPipe(TestCase):
# TODO(VitalyFedyunin): If dill installed this test fails
def _test_picklable(self):
arr = range(10)
picklable_datapipes: List[Tuple[Type[IterDataPipe], IterDataPipe, Tuple, Dict[str, Any]]] = [
(dp.iter.Mapper, dp.iter.IterableWrapper(arr), (), {}),
(dp.iter.Mapper, dp.iter.IterableWrapper(arr), (_fake_fn, (0, )), {}),
(dp.iter.Mapper, dp.iter.IterableWrapper(arr), (partial(_fake_add, 1), (0,)), {}),
(dp.iter.Collator, dp.iter.IterableWrapper(arr), (), {}),
(dp.iter.Collator, dp.iter.IterableWrapper(arr), (_fake_fn, (0, )), {}),
(dp.iter.Filter, dp.iter.IterableWrapper(arr), (_fake_filter_fn, (0, )), {}),
(dp.iter.Filter, dp.iter.IterableWrapper(arr), (partial(_fake_filter_fn, 5), (0,)), {}),
def test_serializable(self):
input_dp = dp.iter.IterableWrapper(range(10))
picklable_datapipes: List[Tuple[Type[IterDataPipe], Tuple, Dict[str, Any]]] = [
(dp.iter.Mapper, (_fake_fn, ), {}),
(dp.iter.Mapper, (partial(_fake_add, 1), ), {}),
(dp.iter.Collator, (_fake_fn, ), {}),
(dp.iter.Filter, (_fake_filter_fn, ), {}),
(dp.iter.Filter, (partial(_fake_filter_fn_constant, 5), ), {}),
(dp.iter.Demultiplexer, (2, _fake_filter_fn), {}),
]
for dpipe, input_dp, dp_args, dp_kwargs in picklable_datapipes:
p = pickle.dumps(dpipe(input_dp, *dp_args, **dp_kwargs)) # type: ignore[call-arg]
for dpipe, dp_args, dp_kwargs in picklable_datapipes:
print(dpipe)
_ = pickle.dumps(dpipe(input_dp, *dp_args, **dp_kwargs)) # type: ignore[call-arg]
unpicklable_datapipes: List[Tuple[Type[IterDataPipe], IterDataPipe, Tuple, Dict[str, Any]]] = [
(dp.iter.Mapper, dp.iter.IterableWrapper(arr), (lambda x: x, ), {}),
(dp.iter.Collator, dp.iter.IterableWrapper(arr), (lambda x: x, ), {}),
(dp.iter.Filter, dp.iter.IterableWrapper(arr), (lambda x: x >= 5, ), {}),
def test_serializable_with_dill(self):
input_dp = dp.iter.IterableWrapper(range(10))
unpicklable_datapipes: List[Tuple[Type[IterDataPipe], Tuple, Dict[str, Any]]] = [
(dp.iter.Mapper, (lambda x: x, ), {}),
(dp.iter.Collator, (lambda x: x, ), {}),
(dp.iter.Filter, (lambda x: x >= 5, ), {}),
(dp.iter.Demultiplexer, (2, lambda x: x % 2, ), {})
]
for dpipe, input_dp, dp_args, dp_kwargs in unpicklable_datapipes:
with warnings.catch_warnings(record=True) as wa:
datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg]
self.assertEqual(len(wa), 1)
self.assertRegex(str(wa[0].message), r"^Lambda function is not supported for pickle")
with self.assertRaises(AttributeError):
p = pickle.dumps(datapipe)
if HAS_DILL:
for dpipe, dp_args, dp_kwargs in unpicklable_datapipes:
_ = pickle.dumps(dpipe(input_dp, *dp_args, **dp_kwargs)) # type: ignore[call-arg]
else:
for dpipe, dp_args, dp_kwargs in unpicklable_datapipes:
with warnings.catch_warnings(record=True) as wa:
datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg]
self.assertEqual(len(wa), 1)
self.assertRegex(str(wa[0].message), r"^Lambda function is not supported for pickle")
with self.assertRaises(AttributeError):
p = pickle.dumps(datapipe)
def test_iterable_wrapper_datapipe(self):
@ -711,8 +717,9 @@ class TestFunctionalIterDataPipe(TestCase):
classifier_fn=lambda x: 0 if x >= 5 else 1,
buffer_size=-1
)
self.assertEqual(len(wa), 1)
self.assertRegex(str(wa[0].message), r"Unlimited buffer size is set")
exp_l = 1 if HAS_DILL else 2
self.assertEqual(len(wa), exp_l)
self.assertRegex(str(wa[-1].message), r"Unlimited buffer size is set")
output1, output2 = list(dp1), list(dp2)
self.assertEqual(list(range(5, 10)), output1)
self.assertEqual(list(range(0, 5)), output2)
@ -1152,33 +1159,38 @@ class TestFunctionalIterDataPipe(TestCase):
class TestFunctionalMapDataPipe(TestCase):
# TODO(VitalyFedyunin): If dill installed this test fails
def _test_picklable(self):
arr = range(10)
def test_serializable(self):
input_dp = dp.map.SequenceWrapper(range(10))
picklable_datapipes: List[
Tuple[Type[MapDataPipe], MapDataPipe, Tuple, Dict[str, Any]]
Tuple[Type[MapDataPipe], Tuple, Dict[str, Any]]
] = [
(dp.map.Mapper, dp.map.SequenceWrapper(arr), (), {}),
(dp.map.Mapper, dp.map.SequenceWrapper(arr), (_fake_fn, (0,)), {}),
(dp.map.Mapper, dp.map.SequenceWrapper(arr), (partial(_fake_add, 1), (0,)), {}),
(dp.map.Mapper, (), {}),
(dp.map.Mapper, (_fake_fn, ), {}),
(dp.map.Mapper, (partial(_fake_add, 1), ), {}),
]
for dpipe, input_dp, dp_args, dp_kwargs in picklable_datapipes:
p = pickle.dumps(dpipe(input_dp, *dp_args, **dp_kwargs)) # type: ignore[call-arg]
for dpipe, dp_args, dp_kwargs in picklable_datapipes:
_ = pickle.dumps(dpipe(input_dp, *dp_args, **dp_kwargs)) # type: ignore[call-arg]
def test_serializable_with_dill(self):
input_dp = dp.map.SequenceWrapper(range(10))
unpicklable_datapipes: List[
Tuple[Type[MapDataPipe], MapDataPipe, Tuple, Dict[str, Any]]
Tuple[Type[MapDataPipe], Tuple, Dict[str, Any]]
] = [
(dp.map.Mapper, dp.map.SequenceWrapper(arr), (lambda x: x,), {}),
(dp.map.Mapper, (lambda x: x,), {}),
]
for dpipe, input_dp, dp_args, dp_kwargs in unpicklable_datapipes:
with warnings.catch_warnings(record=True) as wa:
datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg]
self.assertEqual(len(wa), 1)
self.assertRegex(
str(wa[0].message), r"^Lambda function is not supported for pickle"
)
with self.assertRaises(AttributeError):
p = pickle.dumps(datapipe)
if HAS_DILL:
for dpipe, dp_args, dp_kwargs in unpicklable_datapipes:
_ = pickle.dumps(dpipe(input_dp, *dp_args, **dp_kwargs)) # type: ignore[call-arg]
else:
for dpipe, dp_args, dp_kwargs in unpicklable_datapipes:
with warnings.catch_warnings(record=True) as wa:
datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg]
self.assertEqual(len(wa), 1)
self.assertRegex(
str(wa[0].message), r"^Lambda function is not supported for pickle"
)
with self.assertRaises(AttributeError):
p = pickle.dumps(datapipe)
def test_sequence_wrapper_datapipe(self):
seq = list(range(10))

View File

@ -1,18 +1,11 @@
import warnings
from torch.utils.data import IterDataPipe, _utils, functional_datapipe
from typing import Callable, Iterator, Sized, TypeVar
try:
import dill
from torch.utils.data import IterDataPipe, _utils, functional_datapipe
from torch.utils.data.datapipes.utils.common import DILL_AVAILABLE, check_lambda_fn
# XXX: By default, dill writes the Pickler dispatch table to inject its
# own logic there. This globally affects the behavior of the standard library
# pickler for any user who transitively depends on this module!
# Undo this extension to avoid altering the behavior of the pickler globally.
if DILL_AVAILABLE:
import dill
dill.extend(use_dill=False)
DILL_AVAILABLE = True
except ImportError:
DILL_AVAILABLE = False
T_co = TypeVar("T_co", covariant=True)
@ -50,13 +43,10 @@ class MapperIterDataPipe(IterDataPipe[T_co]):
) -> None:
super().__init__()
self.datapipe = datapipe
# Partial object has no attribute '__name__', but can be pickled
if hasattr(fn, "__name__") and fn.__name__ == "<lambda>" and not DILL_AVAILABLE:
warnings.warn(
"Lambda function is not supported for pickle, please use "
"regular python function or functools.partial instead."
)
check_lambda_fn(fn)
self.fn = fn # type: ignore[assignment]
self.input_col = input_col
if input_col is None and output_col is not None:
raise ValueError("`output_col` must be None when `input_col` is None.")

View File

@ -1,8 +1,14 @@
import warnings
from torch.utils.data import IterDataPipe, functional_datapipe
from typing import Any, Callable, Iterator, List, Optional, Set, Sized, Tuple, TypeVar, Deque
from collections import deque
from typing import Any, Callable, Iterator, List, Optional, Set, Sized, Tuple, TypeVar, Deque
from torch.utils.data import IterDataPipe, functional_datapipe
from torch.utils.data.datapipes.utils.common import DILL_AVAILABLE, check_lambda_fn
if DILL_AVAILABLE:
import dill
dill.extend(use_dill=False)
T_co = TypeVar('T_co', covariant=True)
@ -190,6 +196,9 @@ class DemultiplexerIterDataPipe(IterDataPipe):
classifier_fn: Callable[[T_co], Optional[int]], drop_none: bool = False, buffer_size: int = 1000):
if num_instances < 1:
raise ValueError(f"Expected `num_instaces` larger than 0, but {num_instances} is found")
check_lambda_fn(classifier_fn)
# When num_instances == 1, demux can be replaced by filter,
# but keep it as Demultiplexer for the sake of consistency
# like throwing Error when classification result is out of range
@ -277,6 +286,41 @@ class _DemultiplexerIterDataPipe(IterDataPipe):
self.instance_started = [False] * self.num_instances
self.main_datapipe_exhausted = False
def __getstate__(self):
    """Build the picklable state for this datapipe.

    If ``IterDataPipe.getstate_hook`` is set, its result for ``self`` is
    returned as-is. Otherwise the state is a 5-tuple of
    ``(main_datapipe, num_instances, buffer_size, classifier_fn, drop_none)``,
    where ``classifier_fn`` is replaced by its ``dill.dumps`` payload when
    ``DILL_AVAILABLE`` is true.
    """
    hook = IterDataPipe.getstate_hook
    if hook is not None:
        return hook(self)

    # With dill available the classifier is pre-serialized here; without it
    # the function object is handed to the pickler unchanged.
    serialized_fn = dill.dumps(self.classifier_fn) if DILL_AVAILABLE else self.classifier_fn

    return (
        self.main_datapipe,
        self.num_instances,
        self.buffer_size,
        serialized_fn,
        self.drop_none,
    )
def __setstate__(self, state):
    """Restore the datapipe from a state tuple and reset iteration state.

    ``state`` is the 5-tuple
    ``(main_datapipe, num_instances, buffer_size, classifier_fn, drop_none)``.
    When ``DILL_AVAILABLE`` is true the fourth item is decoded with
    ``dill.loads``; otherwise it is used directly as the classifier.
    The iterator, buffers and progress flags are re-created empty.
    """
    main_datapipe, num_instances, buffer_size, packed_fn, drop_none = state

    self.main_datapipe = main_datapipe
    self.num_instances = num_instances
    self.buffer_size = buffer_size
    self.drop_none = drop_none

    if not DILL_AVAILABLE:
        self.classifier_fn = packed_fn  # type: ignore[assignment]
    else:
        self.classifier_fn = dill.loads(packed_fn)  # type: ignore[assignment]

    # Runtime-only bookkeeping is not part of the state; start from scratch.
    self._datapipe_iterator = None
    self.current_buffer_usage = 0
    self.child_buffers = [deque() for _ in range(num_instances)]
    self.instance_started = [False] * num_instances
    self.main_datapipe_exhausted = False
@functional_datapipe('mux')
class MultiplexerIterDataPipe(IterDataPipe):

View File

@ -5,6 +5,27 @@ import warnings
from io import IOBase
from typing import Iterable, List, Tuple, Union
try:
import dill
# XXX: By default, dill writes the Pickler dispatch table to inject its
# own logic there. This globally affects the behavior of the standard library
# pickler for any user who transitively depends on this module!
# Undo this extension to avoid altering the behavior of the pickler globally.
dill.extend(use_dill=False)
DILL_AVAILABLE = True
except ImportError:
DILL_AVAILABLE = False
def check_lambda_fn(fn):
    """Warn when ``fn`` is a lambda and dill is not available.

    Objects without a ``__name__`` attribute (e.g. ``functools.partial``)
    never trigger the warning. Returns ``None``.
    """
    # A missing ``__name__`` yields None here, which can never equal "<lambda>".
    fn_name = getattr(fn, "__name__", None)
    if fn_name != "<lambda>":
        return
    if DILL_AVAILABLE:
        return
    warnings.warn(
        "Lambda function is not supported for pickle, please use "
        "regular python function or functools.partial instead."
    )
def match_masks(name : str, masks : Union[str, List[str]]) -> bool:
# empty mask matches any input name