mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Make Demux serializable with lambda function (#71311)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/71311 Test Plan: Imported from OSS Reviewed By: NivekT Differential Revision: D33584552 Pulled By: ejguan fbshipit-source-id: 52324faf5547f9f77582ec170ec91ce3114cfc61
This commit is contained in:
parent
f0db15122f
commit
fd9e08df5d
|
|
@ -428,33 +428,39 @@ def _worker_init_fn(worker_id):
|
|||
|
||||
class TestFunctionalIterDataPipe(TestCase):
|
||||
|
||||
# TODO(VitalyFedyunin): If dill installed this test fails
|
||||
def _test_picklable(self):
|
||||
arr = range(10)
|
||||
picklable_datapipes: List[Tuple[Type[IterDataPipe], IterDataPipe, Tuple, Dict[str, Any]]] = [
|
||||
(dp.iter.Mapper, dp.iter.IterableWrapper(arr), (), {}),
|
||||
(dp.iter.Mapper, dp.iter.IterableWrapper(arr), (_fake_fn, (0, )), {}),
|
||||
(dp.iter.Mapper, dp.iter.IterableWrapper(arr), (partial(_fake_add, 1), (0,)), {}),
|
||||
(dp.iter.Collator, dp.iter.IterableWrapper(arr), (), {}),
|
||||
(dp.iter.Collator, dp.iter.IterableWrapper(arr), (_fake_fn, (0, )), {}),
|
||||
(dp.iter.Filter, dp.iter.IterableWrapper(arr), (_fake_filter_fn, (0, )), {}),
|
||||
(dp.iter.Filter, dp.iter.IterableWrapper(arr), (partial(_fake_filter_fn, 5), (0,)), {}),
|
||||
def test_serializable(self):
|
||||
input_dp = dp.iter.IterableWrapper(range(10))
|
||||
picklable_datapipes: List[Tuple[Type[IterDataPipe], Tuple, Dict[str, Any]]] = [
|
||||
(dp.iter.Mapper, (_fake_fn, ), {}),
|
||||
(dp.iter.Mapper, (partial(_fake_add, 1), ), {}),
|
||||
(dp.iter.Collator, (_fake_fn, ), {}),
|
||||
(dp.iter.Filter, (_fake_filter_fn, ), {}),
|
||||
(dp.iter.Filter, (partial(_fake_filter_fn_constant, 5), ), {}),
|
||||
(dp.iter.Demultiplexer, (2, _fake_filter_fn), {}),
|
||||
]
|
||||
for dpipe, input_dp, dp_args, dp_kwargs in picklable_datapipes:
|
||||
p = pickle.dumps(dpipe(input_dp, *dp_args, **dp_kwargs)) # type: ignore[call-arg]
|
||||
for dpipe, dp_args, dp_kwargs in picklable_datapipes:
|
||||
print(dpipe)
|
||||
_ = pickle.dumps(dpipe(input_dp, *dp_args, **dp_kwargs)) # type: ignore[call-arg]
|
||||
|
||||
unpicklable_datapipes: List[Tuple[Type[IterDataPipe], IterDataPipe, Tuple, Dict[str, Any]]] = [
|
||||
(dp.iter.Mapper, dp.iter.IterableWrapper(arr), (lambda x: x, ), {}),
|
||||
(dp.iter.Collator, dp.iter.IterableWrapper(arr), (lambda x: x, ), {}),
|
||||
(dp.iter.Filter, dp.iter.IterableWrapper(arr), (lambda x: x >= 5, ), {}),
|
||||
def test_serializable_with_dill(self):
|
||||
input_dp = dp.iter.IterableWrapper(range(10))
|
||||
unpicklable_datapipes: List[Tuple[Type[IterDataPipe], Tuple, Dict[str, Any]]] = [
|
||||
(dp.iter.Mapper, (lambda x: x, ), {}),
|
||||
(dp.iter.Collator, (lambda x: x, ), {}),
|
||||
(dp.iter.Filter, (lambda x: x >= 5, ), {}),
|
||||
(dp.iter.Demultiplexer, (2, lambda x: x % 2, ), {})
|
||||
]
|
||||
for dpipe, input_dp, dp_args, dp_kwargs in unpicklable_datapipes:
|
||||
with warnings.catch_warnings(record=True) as wa:
|
||||
datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg]
|
||||
self.assertEqual(len(wa), 1)
|
||||
self.assertRegex(str(wa[0].message), r"^Lambda function is not supported for pickle")
|
||||
with self.assertRaises(AttributeError):
|
||||
p = pickle.dumps(datapipe)
|
||||
if HAS_DILL:
|
||||
for dpipe, dp_args, dp_kwargs in unpicklable_datapipes:
|
||||
_ = pickle.dumps(dpipe(input_dp, *dp_args, **dp_kwargs)) # type: ignore[call-arg]
|
||||
else:
|
||||
for dpipe, dp_args, dp_kwargs in unpicklable_datapipes:
|
||||
with warnings.catch_warnings(record=True) as wa:
|
||||
datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg]
|
||||
self.assertEqual(len(wa), 1)
|
||||
self.assertRegex(str(wa[0].message), r"^Lambda function is not supported for pickle")
|
||||
with self.assertRaises(AttributeError):
|
||||
p = pickle.dumps(datapipe)
|
||||
|
||||
def test_iterable_wrapper_datapipe(self):
|
||||
|
||||
|
|
@ -711,8 +717,9 @@ class TestFunctionalIterDataPipe(TestCase):
|
|||
classifier_fn=lambda x: 0 if x >= 5 else 1,
|
||||
buffer_size=-1
|
||||
)
|
||||
self.assertEqual(len(wa), 1)
|
||||
self.assertRegex(str(wa[0].message), r"Unlimited buffer size is set")
|
||||
exp_l = 1 if HAS_DILL else 2
|
||||
self.assertEqual(len(wa), exp_l)
|
||||
self.assertRegex(str(wa[-1].message), r"Unlimited buffer size is set")
|
||||
output1, output2 = list(dp1), list(dp2)
|
||||
self.assertEqual(list(range(5, 10)), output1)
|
||||
self.assertEqual(list(range(0, 5)), output2)
|
||||
|
|
@ -1152,33 +1159,38 @@ class TestFunctionalIterDataPipe(TestCase):
|
|||
|
||||
|
||||
class TestFunctionalMapDataPipe(TestCase):
|
||||
# TODO(VitalyFedyunin): If dill installed this test fails
|
||||
def _test_picklable(self):
|
||||
arr = range(10)
|
||||
def test_serializable(self):
|
||||
input_dp = dp.map.SequenceWrapper(range(10))
|
||||
picklable_datapipes: List[
|
||||
Tuple[Type[MapDataPipe], MapDataPipe, Tuple, Dict[str, Any]]
|
||||
Tuple[Type[MapDataPipe], Tuple, Dict[str, Any]]
|
||||
] = [
|
||||
(dp.map.Mapper, dp.map.SequenceWrapper(arr), (), {}),
|
||||
(dp.map.Mapper, dp.map.SequenceWrapper(arr), (_fake_fn, (0,)), {}),
|
||||
(dp.map.Mapper, dp.map.SequenceWrapper(arr), (partial(_fake_add, 1), (0,)), {}),
|
||||
(dp.map.Mapper, (), {}),
|
||||
(dp.map.Mapper, (_fake_fn, ), {}),
|
||||
(dp.map.Mapper, (partial(_fake_add, 1), ), {}),
|
||||
]
|
||||
for dpipe, input_dp, dp_args, dp_kwargs in picklable_datapipes:
|
||||
p = pickle.dumps(dpipe(input_dp, *dp_args, **dp_kwargs)) # type: ignore[call-arg]
|
||||
for dpipe, dp_args, dp_kwargs in picklable_datapipes:
|
||||
_ = pickle.dumps(dpipe(input_dp, *dp_args, **dp_kwargs)) # type: ignore[call-arg]
|
||||
|
||||
def test_serializable_with_dill(self):
|
||||
input_dp = dp.map.SequenceWrapper(range(10))
|
||||
unpicklable_datapipes: List[
|
||||
Tuple[Type[MapDataPipe], MapDataPipe, Tuple, Dict[str, Any]]
|
||||
Tuple[Type[MapDataPipe], Tuple, Dict[str, Any]]
|
||||
] = [
|
||||
(dp.map.Mapper, dp.map.SequenceWrapper(arr), (lambda x: x,), {}),
|
||||
(dp.map.Mapper, (lambda x: x,), {}),
|
||||
]
|
||||
for dpipe, input_dp, dp_args, dp_kwargs in unpicklable_datapipes:
|
||||
with warnings.catch_warnings(record=True) as wa:
|
||||
datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg]
|
||||
self.assertEqual(len(wa), 1)
|
||||
self.assertRegex(
|
||||
str(wa[0].message), r"^Lambda function is not supported for pickle"
|
||||
)
|
||||
with self.assertRaises(AttributeError):
|
||||
p = pickle.dumps(datapipe)
|
||||
if HAS_DILL:
|
||||
for dpipe, dp_args, dp_kwargs in unpicklable_datapipes:
|
||||
_ = pickle.dumps(dpipe(input_dp, *dp_args, **dp_kwargs)) # type: ignore[call-arg]
|
||||
else:
|
||||
for dpipe, dp_args, dp_kwargs in unpicklable_datapipes:
|
||||
with warnings.catch_warnings(record=True) as wa:
|
||||
datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg]
|
||||
self.assertEqual(len(wa), 1)
|
||||
self.assertRegex(
|
||||
str(wa[0].message), r"^Lambda function is not supported for pickle"
|
||||
)
|
||||
with self.assertRaises(AttributeError):
|
||||
p = pickle.dumps(datapipe)
|
||||
|
||||
def test_sequence_wrapper_datapipe(self):
|
||||
seq = list(range(10))
|
||||
|
|
|
|||
|
|
@ -1,18 +1,11 @@
|
|||
import warnings
|
||||
from torch.utils.data import IterDataPipe, _utils, functional_datapipe
|
||||
from typing import Callable, Iterator, Sized, TypeVar
|
||||
|
||||
try:
|
||||
import dill
|
||||
from torch.utils.data import IterDataPipe, _utils, functional_datapipe
|
||||
from torch.utils.data.datapipes.utils.common import DILL_AVAILABLE, check_lambda_fn
|
||||
|
||||
# XXX: By default, dill writes the Pickler dispatch table to inject its
|
||||
# own logic there. This globally affects the behavior of the standard library
|
||||
# pickler for any user who transitively depends on this module!
|
||||
# Undo this extension to avoid altering the behavior of the pickler globally.
|
||||
if DILL_AVAILABLE:
|
||||
import dill
|
||||
dill.extend(use_dill=False)
|
||||
DILL_AVAILABLE = True
|
||||
except ImportError:
|
||||
DILL_AVAILABLE = False
|
||||
|
||||
T_co = TypeVar("T_co", covariant=True)
|
||||
|
||||
|
|
@ -50,13 +43,10 @@ class MapperIterDataPipe(IterDataPipe[T_co]):
|
|||
) -> None:
|
||||
super().__init__()
|
||||
self.datapipe = datapipe
|
||||
# Partial object has no attribute '__name__', but can be pickled
|
||||
if hasattr(fn, "__name__") and fn.__name__ == "<lambda>" and not DILL_AVAILABLE:
|
||||
warnings.warn(
|
||||
"Lambda function is not supported for pickle, please use "
|
||||
"regular python function or functools.partial instead."
|
||||
)
|
||||
|
||||
check_lambda_fn(fn)
|
||||
self.fn = fn # type: ignore[assignment]
|
||||
|
||||
self.input_col = input_col
|
||||
if input_col is None and output_col is not None:
|
||||
raise ValueError("`output_col` must be None when `input_col` is None.")
|
||||
|
|
|
|||
|
|
@ -1,8 +1,14 @@
|
|||
import warnings
|
||||
|
||||
from torch.utils.data import IterDataPipe, functional_datapipe
|
||||
from typing import Any, Callable, Iterator, List, Optional, Set, Sized, Tuple, TypeVar, Deque
|
||||
from collections import deque
|
||||
from typing import Any, Callable, Iterator, List, Optional, Set, Sized, Tuple, TypeVar, Deque
|
||||
|
||||
from torch.utils.data import IterDataPipe, functional_datapipe
|
||||
from torch.utils.data.datapipes.utils.common import DILL_AVAILABLE, check_lambda_fn
|
||||
|
||||
if DILL_AVAILABLE:
|
||||
import dill
|
||||
dill.extend(use_dill=False)
|
||||
|
||||
T_co = TypeVar('T_co', covariant=True)
|
||||
|
||||
|
|
@ -190,6 +196,9 @@ class DemultiplexerIterDataPipe(IterDataPipe):
|
|||
classifier_fn: Callable[[T_co], Optional[int]], drop_none: bool = False, buffer_size: int = 1000):
|
||||
if num_instances < 1:
|
||||
raise ValueError(f"Expected `num_instaces` larger than 0, but {num_instances} is found")
|
||||
|
||||
check_lambda_fn(classifier_fn)
|
||||
|
||||
# When num_instances == 1, demux can be replaced by filter,
|
||||
# but keep it as Demultiplexer for the sake of consistency
|
||||
# like throwing Error when classification result is out of range
|
||||
|
|
@ -277,6 +286,41 @@ class _DemultiplexerIterDataPipe(IterDataPipe):
|
|||
self.instance_started = [False] * self.num_instances
|
||||
self.main_datapipe_exhausted = False
|
||||
|
||||
def __getstate__(self):
    """Build the picklable state tuple for this datapipe.

    When ``IterDataPipe.getstate_hook`` is set, its result for this
    instance is returned and nothing else here runs.

    Otherwise the state is the 5-tuple
    ``(main_datapipe, num_instances, buffer_size, classifier_fn, drop_none)``.
    ``classifier_fn`` is stored as ``dill.dumps`` output when dill is
    available, and as the callable itself when it is not. Iteration-time
    attributes are not part of the state; ``__setstate__`` re-creates them.
    """
    hook = IterDataPipe.getstate_hook
    if hook is not None:
        return hook(self)

    # Pre-serialize the classifier with dill when possible, so the outer
    # pickler only ever sees the dumped payload rather than the callable.
    serialized_fn = dill.dumps(self.classifier_fn) if DILL_AVAILABLE else self.classifier_fn

    return (
        self.main_datapipe,
        self.num_instances,
        self.buffer_size,
        serialized_fn,
        self.drop_none,
    )
|
||||
|
||||
def __setstate__(self, state):
    """Restore the datapipe from the 5-tuple built by ``__getstate__``.

    ``state`` is unpacked as
    ``(main_datapipe, num_instances, buffer_size, classifier_fn, drop_none)``.
    The classifier entry is passed through ``dill.loads`` when dill is
    available and used as-is otherwise. All iteration bookkeeping is reset
    to its initial values.
    """
    main_datapipe, num_instances, buffer_size, serialized_fn, drop_none = state

    self.main_datapipe = main_datapipe
    self.num_instances = num_instances
    self.buffer_size = buffer_size
    self.drop_none = drop_none

    # Mirror of __getstate__: undo the dill dump only if dill is usable here.
    self.classifier_fn = (  # type: ignore[assignment]
        dill.loads(serialized_fn) if DILL_AVAILABLE else serialized_fn
    )

    # Iteration bookkeeping is not carried in the state tuple; start fresh.
    self._datapipe_iterator = None
    self.current_buffer_usage = 0
    self.child_buffers = [deque() for _ in range(self.num_instances)]
    self.instance_started = [False] * self.num_instances
    self.main_datapipe_exhausted = False
|
||||
|
||||
|
||||
@functional_datapipe('mux')
|
||||
class MultiplexerIterDataPipe(IterDataPipe):
|
||||
|
|
|
|||
|
|
@ -5,6 +5,27 @@ import warnings
|
|||
from io import IOBase
|
||||
from typing import Iterable, List, Tuple, Union
|
||||
|
||||
try:
|
||||
import dill
|
||||
|
||||
# XXX: By default, dill writes the Pickler dispatch table to inject its
|
||||
# own logic there. This globally affects the behavior of the standard library
|
||||
# pickler for any user who transitively depends on this module!
|
||||
# Undo this extension to avoid altering the behavior of the pickler globally.
|
||||
dill.extend(use_dill=False)
|
||||
DILL_AVAILABLE = True
|
||||
except ImportError:
|
||||
DILL_AVAILABLE = False
|
||||
|
||||
|
||||
def check_lambda_fn(fn):
    """Warn if ``fn`` is a lambda and dill is not available to serialize it.

    The standard pickler serializes functions by qualified name, so it
    cannot handle lambdas. A ``functools.partial`` is only as picklable as
    the callable it wraps, so a partial is unwrapped first and the wrapped
    callable is inspected.

    Args:
        fn: The callable handed to a DataPipe (function, lambda, partial or
            any other object).

    Returns:
        None. A ``UserWarning`` is emitted when a lambda is detected and
        ``DILL_AVAILABLE`` is false; otherwise nothing happens.
    """
    import functools  # local import so this block does not depend on file-level imports

    # A partial object has no '__name__' of its own; look at what it wraps.
    # (functools flattens nested partials, so one unwrap is enough.)
    target = fn.func if isinstance(fn, functools.partial) else fn

    # Objects without '__name__' (e.g. callable instances) are left alone.
    if getattr(target, "__name__", None) == "<lambda>" and not DILL_AVAILABLE:
        warnings.warn(
            "Lambda function is not supported for pickle, please use "
            "regular python function or functools.partial instead."
        )
|
||||
|
||||
|
||||
def match_masks(name : str, masks : Union[str, List[str]]) -> bool:
|
||||
# empty mask matches any input name
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user