Raise warning for unpickable local function (#80140)

Fixes https://github.com/pytorch/data/issues/538

- Improve the validation function to raise warning about unpickable function when either lambda or local function is provided to `DataPipe`.
- The inner function from `functools.partial` object is extracted as well for validation
- Mimic the behavior of `pickle` module for local lambda function: It would only raise Error for the local function rather than `lambda` function. So, we will raise warning about local function not lambda function.
```py
>>> import pickle
>>> def fn():
...     lf = lambda x: x
...     pickle.dumps(lf)
>>> pickle.dumps(fn)
AttributeError: Can't pickle local object 'fn.<locals>.<lambda>'
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80140
Approved by: https://github.com/VitalyFedyunin, https://github.com/NivekT
This commit is contained in:
erjia 2022-06-24 13:50:51 +00:00 committed by PyTorch MergeBot
parent 3afc802c5a
commit 4b75b7d3c1
7 changed files with 169 additions and 91 deletions

View File

@ -603,6 +603,11 @@ def _worker_init_fn(worker_id):
torch.utils.data.graph_settings.apply_sharding(datapipe, num_workers, worker_id)
lambda_fn1 = lambda x: x # noqa: E731
lambda_fn2 = lambda x: x % 2 # noqa: E731
lambda_fn3 = lambda x: x >= 5 # noqa: E731
class TestFunctionalIterDataPipe(TestCase):
def _serialization_test_helper(self, datapipe, use_dill):
@ -702,16 +707,41 @@ class TestFunctionalIterDataPipe(TestCase):
def test_serializable_with_dill(self):
"""Only for DataPipes that take in a function as argument"""
input_dp = dp.iter.IterableWrapper(range(10))
unpicklable_datapipes: List[Tuple[Type[IterDataPipe], Tuple, Dict[str, Any]]] = [
(dp.iter.Collator, (lambda x: x,), {}),
(dp.iter.Demultiplexer, (2, lambda x: x % 2,), {}),
(dp.iter.Filter, (lambda x: x >= 5,), {}),
(dp.iter.Grouper, (lambda x: x >= 5,), {}),
(dp.iter.Mapper, (lambda x: x,), {}),
datapipes_with_lambda_fn: List[Tuple[Type[IterDataPipe], Tuple, Dict[str, Any]]] = [
(dp.iter.Collator, (lambda_fn1,), {}),
(dp.iter.Demultiplexer, (2, lambda_fn2,), {}),
(dp.iter.Filter, (lambda_fn3,), {}),
(dp.iter.Grouper, (lambda_fn3,), {}),
(dp.iter.Mapper, (lambda_fn1,), {}),
]
def _local_fns():
def _fn1(x):
return x
def _fn2(x):
return x % 2
def _fn3(x):
return x >= 5
return _fn1, _fn2, _fn3
fn1, fn2, fn3 = _local_fns()
datapipes_with_local_fn: List[Tuple[Type[IterDataPipe], Tuple, Dict[str, Any]]] = [
(dp.iter.Collator, (fn1,), {}),
(dp.iter.Demultiplexer, (2, fn2,), {}),
(dp.iter.Filter, (fn3,), {}),
(dp.iter.Grouper, (fn3,), {}),
(dp.iter.Mapper, (fn1,), {}),
]
dp_compare_children = {dp.iter.Demultiplexer}
if HAS_DILL:
for dpipe, dp_args, dp_kwargs in unpicklable_datapipes:
for dpipe, dp_args, dp_kwargs in datapipes_with_lambda_fn + datapipes_with_local_fn:
if dpipe in dp_compare_children:
dp1, dp2 = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg]
self._serialization_test_for_dp_with_children(dp1, dp2, use_dill=True)
@ -719,13 +749,16 @@ class TestFunctionalIterDataPipe(TestCase):
datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg]
self._serialization_test_for_single_dp(datapipe, use_dill=True)
else:
for dpipe, dp_args, dp_kwargs in unpicklable_datapipes:
with warnings.catch_warnings(record=True) as wa:
datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg]
self.assertEqual(len(wa), 1)
self.assertRegex(str(wa[0].message), r"^Lambda function is not supported for pickle")
with self.assertRaises(AttributeError):
p = pickle.dumps(datapipe)
msgs = (
r"^Lambda function is not supported by pickle",
r"^Local function is not supported by pickle"
)
for dps, msg in zip((datapipes_with_lambda_fn, datapipes_with_local_fn), msgs):
for dpipe, dp_args, dp_kwargs in dps:
with self.assertWarnsRegex(UserWarning, msg):
datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg]
with self.assertRaises((pickle.PicklingError, AttributeError)):
pickle.dumps(datapipe)
def test_iterable_wrapper_datapipe(self):
@ -1150,14 +1183,19 @@ class TestFunctionalIterDataPipe(TestCase):
def fn_nn(d0, d1):
return -d0, -d1, d0 + d1
def _helper(ref_fn, fn, input_col=None, output_col=None):
def _helper(ref_fn, fn, input_col=None, output_col=None, error=None):
for constr in (list, tuple):
datapipe = dp.iter.IterableWrapper([constr((0, 1, 2)), constr((3, 4, 5)), constr((6, 7, 8))])
res_dp = datapipe.map(fn, input_col, output_col)
ref_dp = datapipe.map(ref_fn) if ref_fn is not None else datapipe
self.assertEqual(list(res_dp), list(ref_dp))
# Reset
self.assertEqual(list(res_dp), list(ref_dp))
if ref_fn is None:
with self.assertRaises(error):
res_dp = datapipe.map(fn, input_col, output_col)
list(res_dp)
else:
res_dp = datapipe.map(fn, input_col, output_col)
ref_dp = datapipe.map(ref_fn)
self.assertEqual(list(res_dp), list(ref_dp))
# Reset
self.assertEqual(list(res_dp), list(ref_dp))
# Replacing with one input column and default output column
_helper(lambda data: (data[0], -data[1], data[2]), fn_11, 1)
@ -1166,20 +1204,17 @@ class TestFunctionalIterDataPipe(TestCase):
_helper(lambda data: (data[0], data[1], 1 + data[1]), fn_n1_def, 1, 2)
# The index of input column is out of range
with self.assertRaises(IndexError):
_helper(None, fn_1n, 3)
_helper(None, fn_1n, 3, error=IndexError)
# Unmatched input columns with fn arguments
with self.assertRaises(ValueError):
_helper(None, fn_n1, 1)
_helper(None, lambda d0, d1: d0 + d1, 0)
_helper(None, p_fn_n1, (0, 1))
_helper(None, fn_n1, 1, error=ValueError)
_helper(None, lambda d0, d1: d0 + d1, 0, error=ValueError)
_helper(None, p_fn_n1, (0, 1, 3), error=ValueError)
# Function takes fewer parameters than input col
with self.assertRaises(ValueError):
def zero_args():
return
_helper(None, zero_args, 0)
def zero_args():
return
_helper(None, zero_args, 0, error=ValueError)
# Replacing with multiple input columns and default output column (the left-most input column)
_helper(lambda data: (data[1], data[2] + data[0]), fn_n1, [2, 0])
@ -1190,19 +1225,16 @@ class TestFunctionalIterDataPipe(TestCase):
2)
# output_col can only be specified when input_col is not None
with self.assertRaises(ValueError):
_helper(None, fn_n1, None, 1)
_helper(None, fn_n1, None, 1, error=ValueError)
# output_col can only be single-element list or tuple
with self.assertRaises(ValueError):
_helper(None, fn_n1, None, [0, 1])
_helper(None, fn_n1, None, [0, 1], error=ValueError)
# Single-element list as output_col
_helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, [0])
# Replacing with one input column and single specified output column
_helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, 0)
_helper(lambda data: (data[0], data[1], (-data[1], data[1])), fn_1n, 1, 2)
# The index of output column is out of range
with self.assertRaises(IndexError):
_helper(None, fn_1n, 1, 3)
_helper(None, fn_1n, 1, 3, error=IndexError)
_helper(lambda data: (data[0], data[0] + data[2], data[2]), fn_n1, [0, 2], 1)
_helper(lambda data: ((-data[1], -data[2], data[1] + data[2]), data[1], data[2]), fn_nn, [1, 2], 0)
@ -1238,22 +1270,26 @@ class TestFunctionalIterDataPipe(TestCase):
del _data[idx]
return _data
def _helper(ref_fn, fn, input_col=None, output_col=None):
def _helper(ref_fn, fn, input_col=None, output_col=None, error=None):
datapipe = dp.iter.IterableWrapper(
[{"x": 0, "y": 1, "z": 2},
{"x": 3, "y": 4, "z": 5},
{"x": 6, "y": 7, "z": 8}]
)
res_dp = datapipe.map(fn, input_col, output_col)
ref_dp = datapipe.map(ref_fn) if ref_fn is not None else datapipe
self.assertEqual(list(res_dp), list(ref_dp))
# Reset
self.assertEqual(list(res_dp), list(ref_dp))
if ref_fn is None:
with self.assertRaises(error):
res_dp = datapipe.map(fn, input_col, output_col)
list(res_dp)
else:
res_dp = datapipe.map(fn, input_col, output_col)
ref_dp = datapipe.map(ref_fn)
self.assertEqual(list(res_dp), list(ref_dp))
# Reset
self.assertEqual(list(res_dp), list(ref_dp))
# Replacing with one input column and default output column
_helper(lambda data: _dict_update(data, {"y": -data["y"]}), fn_11, "y")
_helper(lambda data: _dict_update(data, {"y": (-data["y"], data["y"])}), fn_1n, "y")
_helper(lambda data: _dict_update(data, {"z": data["x"] + data["y"]}),
lambda x, y: x + y, ("x", "y"), "z")
_helper(lambda data: _dict_update(data, {"x": 1 + data["y"]}), fn_n1_def, "y",
@ -1262,19 +1298,16 @@ class TestFunctionalIterDataPipe(TestCase):
p_fn_n1 = partial(fn_n1, d1=1)
_helper(lambda data: _dict_update(data, {"x": 1 + data["y"]}), p_fn_n1, "y", "x")
# The key of input column is not in dict
with self.assertRaises(KeyError):
_helper(None, fn_1n, "a")
_helper(None, fn_1n, "a", error=KeyError)
# Unmatched input columns with fn arguments
with self.assertRaises(ValueError):
_helper(None, fn_n1, "y")
_helper(None, lambda x, y: x + y, "x")
_helper(None, p_fn_n1, ("x", "y"))
_helper(None, fn_n1, "y", error=ValueError)
_helper(None, lambda x, y: x + y, "x", error=ValueError)
_helper(None, p_fn_n1, ("x", "y", "z"), error=ValueError)
# Function takes fewer parameters than input col
with self.assertRaises(ValueError):
def zero_args():
return
_helper(None, zero_args, "x")
def zero_args():
return
_helper(None, zero_args, "x", error=ValueError)
# Replacing with multiple input columns and default output column (the left-most input column)
_helper(lambda data: _dict_update(data, {"z": data["x"] + data["z"]}, ["x"]), fn_n1, ["z", "x"])
@ -1283,11 +1316,9 @@ class TestFunctionalIterDataPipe(TestCase):
_helper(lambda data: _dict_update(data, {"x": data["x"] + data["y"]}), fn_n1_def, ("x", "y"), "x")
# output_col can only be specified when input_col is not None
with self.assertRaises(ValueError):
_helper(None, fn_n1, None, "x")
_helper(None, fn_n1, None, "x", error=ValueError)
# output_col can only be single-element list or tuple
with self.assertRaises(ValueError):
_helper(None, fn_n1, None, ["x", "y"])
_helper(None, fn_n1, None, ["x", "y"], error=ValueError)
# Single-element list as output_col
_helper(lambda data: _dict_update(data, {"x": -data["y"]}), fn_11, "y", ["x"])
# Replacing with one input column and single specified output column
@ -1677,24 +1708,41 @@ class TestFunctionalMapDataPipe(TestCase):
def test_serializable_with_dill(self):
"""Only for DataPipes that take in a function as argument"""
input_dp = dp.map.SequenceWrapper(range(10))
unpicklable_datapipes: List[
datapipes_with_lambda_fn: List[
Tuple[Type[MapDataPipe], Tuple, Dict[str, Any]]
] = [
(dp.map.Mapper, (lambda x: x,), {}),
(dp.map.Mapper, (lambda_fn1,), {}),
]
def _local_fns():
def _fn1(x):
return x
return _fn1
fn1 = _local_fns()
datapipes_with_local_fn: List[
Tuple[Type[MapDataPipe], Tuple, Dict[str, Any]]
] = [
(dp.map.Mapper, (fn1,), {}),
]
if HAS_DILL:
for dpipe, dp_args, dp_kwargs in unpicklable_datapipes:
for dpipe, dp_args, dp_kwargs in datapipes_with_lambda_fn + datapipes_with_local_fn:
_ = dill.dumps(dpipe(input_dp, *dp_args, **dp_kwargs)) # type: ignore[call-arg]
else:
for dpipe, dp_args, dp_kwargs in unpicklable_datapipes:
with warnings.catch_warnings(record=True) as wa:
datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg]
self.assertEqual(len(wa), 1)
self.assertRegex(
str(wa[0].message), r"^Lambda function is not supported for pickle"
)
with self.assertRaises(AttributeError):
p = pickle.dumps(datapipe)
msgs = (
r"^Lambda function is not supported by pickle",
r"^Local function is not supported by pickle"
)
for dps, msg in zip((datapipes_with_lambda_fn, datapipes_with_local_fn), msgs):
for dpipe, dp_args, dp_kwargs in dps:
with self.assertWarnsRegex(UserWarning, msg):
datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg]
with self.assertRaises((pickle.PicklingError, AttributeError)):
pickle.dumps(datapipe)
def test_sequence_wrapper_datapipe(self):
seq = list(range(10))

View File

@ -4,8 +4,9 @@ from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data._utils.collate import default_collate
from torch.utils.data.datapipes.datapipe import IterDataPipe
from torch.utils.data.datapipes.utils.common import (
_check_lambda_fn,
validate_input_col)
_check_unpickable_fn,
validate_input_col
)
__all__ = [
"CollatorIterDataPipe",
@ -66,7 +67,7 @@ class MapperIterDataPipe(IterDataPipe[T_co]):
super().__init__()
self.datapipe = datapipe
_check_lambda_fn(fn)
_check_unpickable_fn(fn)
self.fn = fn # type: ignore[assignment]
self.input_col = input_col

View File

@ -5,7 +5,7 @@ from typing import Any, Callable, Iterator, List, Optional, Sized, Tuple, TypeVa
from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import IterDataPipe
from torch.utils.data.datapipes.utils.common import _check_lambda_fn
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
__all__ = [
"ConcaterIterDataPipe",
@ -300,7 +300,7 @@ class DemultiplexerIterDataPipe(IterDataPipe):
if num_instances < 1:
raise ValueError(f"Expected `num_instaces` larger than 0, but {num_instances} is found")
_check_lambda_fn(classifier_fn)
_check_unpickable_fn(classifier_fn)
# When num_instances == 1, demux can be replaced by filter,
# but keep it as Demultiplexer for the sake of consistency

View File

@ -2,7 +2,7 @@ from collections import defaultdict
from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import IterDataPipe, DataChunk
from torch.utils.data.datapipes.utils.common import _check_lambda_fn
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
from typing import Any, Callable, DefaultDict, Iterator, List, Optional, Sized, TypeVar
__all__ = [
@ -215,7 +215,7 @@ class GrouperIterDataPipe(IterDataPipe[DataChunk]):
group_size: Optional[int] = None,
guaranteed_group_size: Optional[int] = None,
drop_remaining: bool = False):
_check_lambda_fn(group_key_fn)
_check_unpickable_fn(group_key_fn)
self.datapipe = datapipe
self.group_key_fn = group_key_fn

View File

@ -4,9 +4,10 @@ from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import IterDataPipe
from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper
from torch.utils.data.datapipes.utils.common import (
_check_lambda_fn,
_check_unpickable_fn,
_deprecation_warning,
validate_input_col)
validate_input_col
)
__all__ = ["FilterIterDataPipe", ]
@ -51,7 +52,7 @@ class FilterIterDataPipe(IterDataPipe[T_co]):
super().__init__()
self.datapipe = datapipe
_check_lambda_fn(filter_fn)
_check_unpickable_fn(filter_fn)
self.filter_fn = filter_fn # type: ignore[assignment]
if drop_empty_batches is None:

View File

@ -1,4 +1,4 @@
from torch.utils.data.datapipes.utils.common import _check_lambda_fn
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
from typing import Callable, TypeVar
from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import MapDataPipe
@ -48,7 +48,7 @@ class MapperMapDataPipe(MapDataPipe[T_co]):
) -> None:
super().__init__()
self.datapipe = datapipe
_check_lambda_fn(fn)
_check_unpickable_fn(fn)
self.fn = fn # type: ignore[assignment]
def __len__(self) -> int:

View File

@ -1,10 +1,11 @@
import os
import fnmatch
import warnings
import inspect
import os
import warnings
from io import IOBase
from typing import Dict, Iterable, List, Tuple, Union, Optional, Callable
from functools import partial
from typing import Callable, Dict, Iterable, List, Tuple, Union, Optional
from torch.utils.data._utils.serialization import DILL_AVAILABLE
@ -28,7 +29,7 @@ def validate_input_col(fn: Callable, input_col: Optional[Union[int, tuple, list]
Returns:
None.
Raises:
TypeError: If the function is not compatible with the input column.
ValueError: If the function is not compatible with the input column.
"""
sig = inspect.signature(fn)
if isinstance(input_col, (list, tuple)):
@ -39,25 +40,52 @@ def validate_input_col(fn: Callable, input_col: Optional[Union[int, tuple, list]
if len(sig.parameters) > sz:
non_default_params = [p for p in sig.parameters.values() if p.default is p.empty]
if len(non_default_params) > sz:
fn_name = fn.__name__ if hasattr(fn, "__name__") else str(fn)
raise ValueError(
f"The function {fn.__name__} takes {len(non_default_params)} "
f"The function {fn_name} takes {len(non_default_params)} "
f"non-default parameters, but {sz} are required for the given `input_col`."
)
if len(sig.parameters) < sz:
fn_name = fn.__name__ if hasattr(fn, "__name__") else str(fn)
raise ValueError(
f"The function {fn.__name__} takes {len(sig.parameters)} "
f"The function {fn_name} takes {len(sig.parameters)} "
f"parameters, but {sz} are required for the given `input_col`."
)
def _check_lambda_fn(fn):
# Partial object has no attribute '__name__', but can be pickled
if hasattr(fn, "__name__") and fn.__name__ == "<lambda>" and not DILL_AVAILABLE:
def _is_local_fn(fn):
return fn.__code__.co_flags & inspect.CO_NESTED
def _check_unpickable_fn(fn: Callable):
"""
Checks function is pickable or not. If it is a lambda or local function, a UserWarning
will be raised. If it's not a callable function, a TypeError will be raised.
"""
if not callable(fn):
raise TypeError(f"A callable function is expected, but {type(fn)} is provided.")
# Extract function from partial object
# Nested partial function is automatically expanded as a single partial object
if isinstance(fn, partial):
fn = fn.func
# Local function
if _is_local_fn(fn) and not DILL_AVAILABLE:
warnings.warn(
"Lambda function is not supported for pickle, please use "
"Local function is not supported by pickle, please use "
"regular python function or functools.partial instead."
)
return
# Lambda function
if hasattr(fn, "__name__") and fn.__name__ == "<lambda>" and not DILL_AVAILABLE:
warnings.warn(
"Lambda function is not supported by pickle, please use "
"regular python function or functools.partial instead."
)
return
def match_masks(name : str, masks : Union[str, List[str]]) -> bool: