mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Raise warning for unpickable local function (#80140)
Fixes https://github.com/pytorch/data/issues/538 - Improve the validation function to raise warning about unpickable function when either lambda or local function is provided to `DataPipe`. - The inner function from `functools.partial` object is extracted as well for validation - Mimic the behavior of `pickle` module for local lambda function: It would only raise Error for the local function rather than `lambda` function. So, we will raise warning about local function not lambda function. ```py >>> import pickle >>> def fn(): ... lf = lambda x: x ... pickle.dumps(lf) >>> pickle.dumps(fn) AttributeError: Can't pickle local object 'fn.<locals>.<lambda>' ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/80140 Approved by: https://github.com/VitalyFedyunin, https://github.com/NivekT
This commit is contained in:
parent
3afc802c5a
commit
4b75b7d3c1
|
|
@ -603,6 +603,11 @@ def _worker_init_fn(worker_id):
|
|||
torch.utils.data.graph_settings.apply_sharding(datapipe, num_workers, worker_id)
|
||||
|
||||
|
||||
lambda_fn1 = lambda x: x # noqa: E731
|
||||
lambda_fn2 = lambda x: x % 2 # noqa: E731
|
||||
lambda_fn3 = lambda x: x >= 5 # noqa: E731
|
||||
|
||||
|
||||
class TestFunctionalIterDataPipe(TestCase):
|
||||
|
||||
def _serialization_test_helper(self, datapipe, use_dill):
|
||||
|
|
@ -702,16 +707,41 @@ class TestFunctionalIterDataPipe(TestCase):
|
|||
def test_serializable_with_dill(self):
|
||||
"""Only for DataPipes that take in a function as argument"""
|
||||
input_dp = dp.iter.IterableWrapper(range(10))
|
||||
unpicklable_datapipes: List[Tuple[Type[IterDataPipe], Tuple, Dict[str, Any]]] = [
|
||||
(dp.iter.Collator, (lambda x: x,), {}),
|
||||
(dp.iter.Demultiplexer, (2, lambda x: x % 2,), {}),
|
||||
(dp.iter.Filter, (lambda x: x >= 5,), {}),
|
||||
(dp.iter.Grouper, (lambda x: x >= 5,), {}),
|
||||
(dp.iter.Mapper, (lambda x: x,), {}),
|
||||
|
||||
datapipes_with_lambda_fn: List[Tuple[Type[IterDataPipe], Tuple, Dict[str, Any]]] = [
|
||||
(dp.iter.Collator, (lambda_fn1,), {}),
|
||||
(dp.iter.Demultiplexer, (2, lambda_fn2,), {}),
|
||||
(dp.iter.Filter, (lambda_fn3,), {}),
|
||||
(dp.iter.Grouper, (lambda_fn3,), {}),
|
||||
(dp.iter.Mapper, (lambda_fn1,), {}),
|
||||
]
|
||||
|
||||
def _local_fns():
|
||||
def _fn1(x):
|
||||
return x
|
||||
|
||||
def _fn2(x):
|
||||
return x % 2
|
||||
|
||||
def _fn3(x):
|
||||
return x >= 5
|
||||
|
||||
return _fn1, _fn2, _fn3
|
||||
|
||||
fn1, fn2, fn3 = _local_fns()
|
||||
|
||||
datapipes_with_local_fn: List[Tuple[Type[IterDataPipe], Tuple, Dict[str, Any]]] = [
|
||||
(dp.iter.Collator, (fn1,), {}),
|
||||
(dp.iter.Demultiplexer, (2, fn2,), {}),
|
||||
(dp.iter.Filter, (fn3,), {}),
|
||||
(dp.iter.Grouper, (fn3,), {}),
|
||||
(dp.iter.Mapper, (fn1,), {}),
|
||||
]
|
||||
|
||||
dp_compare_children = {dp.iter.Demultiplexer}
|
||||
|
||||
if HAS_DILL:
|
||||
for dpipe, dp_args, dp_kwargs in unpicklable_datapipes:
|
||||
for dpipe, dp_args, dp_kwargs in datapipes_with_lambda_fn + datapipes_with_local_fn:
|
||||
if dpipe in dp_compare_children:
|
||||
dp1, dp2 = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg]
|
||||
self._serialization_test_for_dp_with_children(dp1, dp2, use_dill=True)
|
||||
|
|
@ -719,13 +749,16 @@ class TestFunctionalIterDataPipe(TestCase):
|
|||
datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg]
|
||||
self._serialization_test_for_single_dp(datapipe, use_dill=True)
|
||||
else:
|
||||
for dpipe, dp_args, dp_kwargs in unpicklable_datapipes:
|
||||
with warnings.catch_warnings(record=True) as wa:
|
||||
msgs = (
|
||||
r"^Lambda function is not supported by pickle",
|
||||
r"^Local function is not supported by pickle"
|
||||
)
|
||||
for dps, msg in zip((datapipes_with_lambda_fn, datapipes_with_local_fn), msgs):
|
||||
for dpipe, dp_args, dp_kwargs in dps:
|
||||
with self.assertWarnsRegex(UserWarning, msg):
|
||||
datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg]
|
||||
self.assertEqual(len(wa), 1)
|
||||
self.assertRegex(str(wa[0].message), r"^Lambda function is not supported for pickle")
|
||||
with self.assertRaises(AttributeError):
|
||||
p = pickle.dumps(datapipe)
|
||||
with self.assertRaises((pickle.PicklingError, AttributeError)):
|
||||
pickle.dumps(datapipe)
|
||||
|
||||
def test_iterable_wrapper_datapipe(self):
|
||||
|
||||
|
|
@ -1150,11 +1183,16 @@ class TestFunctionalIterDataPipe(TestCase):
|
|||
def fn_nn(d0, d1):
|
||||
return -d0, -d1, d0 + d1
|
||||
|
||||
def _helper(ref_fn, fn, input_col=None, output_col=None):
|
||||
def _helper(ref_fn, fn, input_col=None, output_col=None, error=None):
|
||||
for constr in (list, tuple):
|
||||
datapipe = dp.iter.IterableWrapper([constr((0, 1, 2)), constr((3, 4, 5)), constr((6, 7, 8))])
|
||||
if ref_fn is None:
|
||||
with self.assertRaises(error):
|
||||
res_dp = datapipe.map(fn, input_col, output_col)
|
||||
ref_dp = datapipe.map(ref_fn) if ref_fn is not None else datapipe
|
||||
list(res_dp)
|
||||
else:
|
||||
res_dp = datapipe.map(fn, input_col, output_col)
|
||||
ref_dp = datapipe.map(ref_fn)
|
||||
self.assertEqual(list(res_dp), list(ref_dp))
|
||||
# Reset
|
||||
self.assertEqual(list(res_dp), list(ref_dp))
|
||||
|
|
@ -1166,20 +1204,17 @@ class TestFunctionalIterDataPipe(TestCase):
|
|||
_helper(lambda data: (data[0], data[1], 1 + data[1]), fn_n1_def, 1, 2)
|
||||
|
||||
# The index of input column is out of range
|
||||
with self.assertRaises(IndexError):
|
||||
_helper(None, fn_1n, 3)
|
||||
_helper(None, fn_1n, 3, error=IndexError)
|
||||
|
||||
# Unmatched input columns with fn arguments
|
||||
with self.assertRaises(ValueError):
|
||||
_helper(None, fn_n1, 1)
|
||||
_helper(None, lambda d0, d1: d0 + d1, 0)
|
||||
_helper(None, p_fn_n1, (0, 1))
|
||||
_helper(None, fn_n1, 1, error=ValueError)
|
||||
_helper(None, lambda d0, d1: d0 + d1, 0, error=ValueError)
|
||||
_helper(None, p_fn_n1, (0, 1, 3), error=ValueError)
|
||||
|
||||
# Function takes fewer parameters than input col
|
||||
with self.assertRaises(ValueError):
|
||||
def zero_args():
|
||||
return
|
||||
_helper(None, zero_args, 0)
|
||||
_helper(None, zero_args, 0, error=ValueError)
|
||||
|
||||
# Replacing with multiple input columns and default output column (the left-most input column)
|
||||
_helper(lambda data: (data[1], data[2] + data[0]), fn_n1, [2, 0])
|
||||
|
|
@ -1190,19 +1225,16 @@ class TestFunctionalIterDataPipe(TestCase):
|
|||
2)
|
||||
|
||||
# output_col can only be specified when input_col is not None
|
||||
with self.assertRaises(ValueError):
|
||||
_helper(None, fn_n1, None, 1)
|
||||
_helper(None, fn_n1, None, 1, error=ValueError)
|
||||
# output_col can only be single-element list or tuple
|
||||
with self.assertRaises(ValueError):
|
||||
_helper(None, fn_n1, None, [0, 1])
|
||||
_helper(None, fn_n1, None, [0, 1], error=ValueError)
|
||||
# Single-element list as output_col
|
||||
_helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, [0])
|
||||
# Replacing with one input column and single specified output column
|
||||
_helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, 0)
|
||||
_helper(lambda data: (data[0], data[1], (-data[1], data[1])), fn_1n, 1, 2)
|
||||
# The index of output column is out of range
|
||||
with self.assertRaises(IndexError):
|
||||
_helper(None, fn_1n, 1, 3)
|
||||
_helper(None, fn_1n, 1, 3, error=IndexError)
|
||||
_helper(lambda data: (data[0], data[0] + data[2], data[2]), fn_n1, [0, 2], 1)
|
||||
_helper(lambda data: ((-data[1], -data[2], data[1] + data[2]), data[1], data[2]), fn_nn, [1, 2], 0)
|
||||
|
||||
|
|
@ -1238,14 +1270,19 @@ class TestFunctionalIterDataPipe(TestCase):
|
|||
del _data[idx]
|
||||
return _data
|
||||
|
||||
def _helper(ref_fn, fn, input_col=None, output_col=None):
|
||||
def _helper(ref_fn, fn, input_col=None, output_col=None, error=None):
|
||||
datapipe = dp.iter.IterableWrapper(
|
||||
[{"x": 0, "y": 1, "z": 2},
|
||||
{"x": 3, "y": 4, "z": 5},
|
||||
{"x": 6, "y": 7, "z": 8}]
|
||||
)
|
||||
if ref_fn is None:
|
||||
with self.assertRaises(error):
|
||||
res_dp = datapipe.map(fn, input_col, output_col)
|
||||
ref_dp = datapipe.map(ref_fn) if ref_fn is not None else datapipe
|
||||
list(res_dp)
|
||||
else:
|
||||
res_dp = datapipe.map(fn, input_col, output_col)
|
||||
ref_dp = datapipe.map(ref_fn)
|
||||
self.assertEqual(list(res_dp), list(ref_dp))
|
||||
# Reset
|
||||
self.assertEqual(list(res_dp), list(ref_dp))
|
||||
|
|
@ -1253,7 +1290,6 @@ class TestFunctionalIterDataPipe(TestCase):
|
|||
# Replacing with one input column and default output column
|
||||
_helper(lambda data: _dict_update(data, {"y": -data["y"]}), fn_11, "y")
|
||||
_helper(lambda data: _dict_update(data, {"y": (-data["y"], data["y"])}), fn_1n, "y")
|
||||
|
||||
_helper(lambda data: _dict_update(data, {"z": data["x"] + data["y"]}),
|
||||
lambda x, y: x + y, ("x", "y"), "z")
|
||||
_helper(lambda data: _dict_update(data, {"x": 1 + data["y"]}), fn_n1_def, "y",
|
||||
|
|
@ -1262,19 +1298,16 @@ class TestFunctionalIterDataPipe(TestCase):
|
|||
p_fn_n1 = partial(fn_n1, d1=1)
|
||||
_helper(lambda data: _dict_update(data, {"x": 1 + data["y"]}), p_fn_n1, "y", "x")
|
||||
# The key of input column is not in dict
|
||||
with self.assertRaises(KeyError):
|
||||
_helper(None, fn_1n, "a")
|
||||
_helper(None, fn_1n, "a", error=KeyError)
|
||||
# Unmatched input columns with fn arguments
|
||||
with self.assertRaises(ValueError):
|
||||
_helper(None, fn_n1, "y")
|
||||
_helper(None, lambda x, y: x + y, "x")
|
||||
_helper(None, p_fn_n1, ("x", "y"))
|
||||
_helper(None, fn_n1, "y", error=ValueError)
|
||||
_helper(None, lambda x, y: x + y, "x", error=ValueError)
|
||||
_helper(None, p_fn_n1, ("x", "y", "z"), error=ValueError)
|
||||
|
||||
# Function takes fewer parameters than input col
|
||||
with self.assertRaises(ValueError):
|
||||
def zero_args():
|
||||
return
|
||||
_helper(None, zero_args, "x")
|
||||
_helper(None, zero_args, "x", error=ValueError)
|
||||
|
||||
# Replacing with multiple input columns and default output column (the left-most input column)
|
||||
_helper(lambda data: _dict_update(data, {"z": data["x"] + data["z"]}, ["x"]), fn_n1, ["z", "x"])
|
||||
|
|
@ -1283,11 +1316,9 @@ class TestFunctionalIterDataPipe(TestCase):
|
|||
_helper(lambda data: _dict_update(data, {"x": data["x"] + data["y"]}), fn_n1_def, ("x", "y"), "x")
|
||||
|
||||
# output_col can only be specified when input_col is not None
|
||||
with self.assertRaises(ValueError):
|
||||
_helper(None, fn_n1, None, "x")
|
||||
_helper(None, fn_n1, None, "x", error=ValueError)
|
||||
# output_col can only be single-element list or tuple
|
||||
with self.assertRaises(ValueError):
|
||||
_helper(None, fn_n1, None, ["x", "y"])
|
||||
_helper(None, fn_n1, None, ["x", "y"], error=ValueError)
|
||||
# Single-element list as output_col
|
||||
_helper(lambda data: _dict_update(data, {"x": -data["y"]}), fn_11, "y", ["x"])
|
||||
# Replacing with one input column and single specified output column
|
||||
|
|
@ -1677,24 +1708,41 @@ class TestFunctionalMapDataPipe(TestCase):
|
|||
def test_serializable_with_dill(self):
|
||||
"""Only for DataPipes that take in a function as argument"""
|
||||
input_dp = dp.map.SequenceWrapper(range(10))
|
||||
unpicklable_datapipes: List[
|
||||
|
||||
datapipes_with_lambda_fn: List[
|
||||
Tuple[Type[MapDataPipe], Tuple, Dict[str, Any]]
|
||||
] = [
|
||||
(dp.map.Mapper, (lambda x: x,), {}),
|
||||
(dp.map.Mapper, (lambda_fn1,), {}),
|
||||
]
|
||||
|
||||
def _local_fns():
|
||||
def _fn1(x):
|
||||
return x
|
||||
|
||||
return _fn1
|
||||
|
||||
fn1 = _local_fns()
|
||||
|
||||
datapipes_with_local_fn: List[
|
||||
Tuple[Type[MapDataPipe], Tuple, Dict[str, Any]]
|
||||
] = [
|
||||
(dp.map.Mapper, (fn1,), {}),
|
||||
]
|
||||
|
||||
if HAS_DILL:
|
||||
for dpipe, dp_args, dp_kwargs in unpicklable_datapipes:
|
||||
for dpipe, dp_args, dp_kwargs in datapipes_with_lambda_fn + datapipes_with_local_fn:
|
||||
_ = dill.dumps(dpipe(input_dp, *dp_args, **dp_kwargs)) # type: ignore[call-arg]
|
||||
else:
|
||||
for dpipe, dp_args, dp_kwargs in unpicklable_datapipes:
|
||||
with warnings.catch_warnings(record=True) as wa:
|
||||
datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg]
|
||||
self.assertEqual(len(wa), 1)
|
||||
self.assertRegex(
|
||||
str(wa[0].message), r"^Lambda function is not supported for pickle"
|
||||
msgs = (
|
||||
r"^Lambda function is not supported by pickle",
|
||||
r"^Local function is not supported by pickle"
|
||||
)
|
||||
with self.assertRaises(AttributeError):
|
||||
p = pickle.dumps(datapipe)
|
||||
for dps, msg in zip((datapipes_with_lambda_fn, datapipes_with_local_fn), msgs):
|
||||
for dpipe, dp_args, dp_kwargs in dps:
|
||||
with self.assertWarnsRegex(UserWarning, msg):
|
||||
datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg]
|
||||
with self.assertRaises((pickle.PicklingError, AttributeError)):
|
||||
pickle.dumps(datapipe)
|
||||
|
||||
def test_sequence_wrapper_datapipe(self):
|
||||
seq = list(range(10))
|
||||
|
|
|
|||
|
|
@ -4,8 +4,9 @@ from torch.utils.data.datapipes._decorator import functional_datapipe
|
|||
from torch.utils.data._utils.collate import default_collate
|
||||
from torch.utils.data.datapipes.datapipe import IterDataPipe
|
||||
from torch.utils.data.datapipes.utils.common import (
|
||||
_check_lambda_fn,
|
||||
validate_input_col)
|
||||
_check_unpickable_fn,
|
||||
validate_input_col
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"CollatorIterDataPipe",
|
||||
|
|
@ -66,7 +67,7 @@ class MapperIterDataPipe(IterDataPipe[T_co]):
|
|||
super().__init__()
|
||||
self.datapipe = datapipe
|
||||
|
||||
_check_lambda_fn(fn)
|
||||
_check_unpickable_fn(fn)
|
||||
self.fn = fn # type: ignore[assignment]
|
||||
|
||||
self.input_col = input_col
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from typing import Any, Callable, Iterator, List, Optional, Sized, Tuple, TypeVa
|
|||
|
||||
from torch.utils.data.datapipes._decorator import functional_datapipe
|
||||
from torch.utils.data.datapipes.datapipe import IterDataPipe
|
||||
from torch.utils.data.datapipes.utils.common import _check_lambda_fn
|
||||
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
|
||||
|
||||
__all__ = [
|
||||
"ConcaterIterDataPipe",
|
||||
|
|
@ -300,7 +300,7 @@ class DemultiplexerIterDataPipe(IterDataPipe):
|
|||
if num_instances < 1:
|
||||
raise ValueError(f"Expected `num_instaces` larger than 0, but {num_instances} is found")
|
||||
|
||||
_check_lambda_fn(classifier_fn)
|
||||
_check_unpickable_fn(classifier_fn)
|
||||
|
||||
# When num_instances == 1, demux can be replaced by filter,
|
||||
# but keep it as Demultiplexer for the sake of consistency
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from collections import defaultdict
|
|||
|
||||
from torch.utils.data.datapipes._decorator import functional_datapipe
|
||||
from torch.utils.data.datapipes.datapipe import IterDataPipe, DataChunk
|
||||
from torch.utils.data.datapipes.utils.common import _check_lambda_fn
|
||||
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
|
||||
from typing import Any, Callable, DefaultDict, Iterator, List, Optional, Sized, TypeVar
|
||||
|
||||
__all__ = [
|
||||
|
|
@ -215,7 +215,7 @@ class GrouperIterDataPipe(IterDataPipe[DataChunk]):
|
|||
group_size: Optional[int] = None,
|
||||
guaranteed_group_size: Optional[int] = None,
|
||||
drop_remaining: bool = False):
|
||||
_check_lambda_fn(group_key_fn)
|
||||
_check_unpickable_fn(group_key_fn)
|
||||
self.datapipe = datapipe
|
||||
self.group_key_fn = group_key_fn
|
||||
|
||||
|
|
|
|||
|
|
@ -4,9 +4,10 @@ from torch.utils.data.datapipes._decorator import functional_datapipe
|
|||
from torch.utils.data.datapipes.datapipe import IterDataPipe
|
||||
from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper
|
||||
from torch.utils.data.datapipes.utils.common import (
|
||||
_check_lambda_fn,
|
||||
_check_unpickable_fn,
|
||||
_deprecation_warning,
|
||||
validate_input_col)
|
||||
validate_input_col
|
||||
)
|
||||
|
||||
__all__ = ["FilterIterDataPipe", ]
|
||||
|
||||
|
|
@ -51,7 +52,7 @@ class FilterIterDataPipe(IterDataPipe[T_co]):
|
|||
super().__init__()
|
||||
self.datapipe = datapipe
|
||||
|
||||
_check_lambda_fn(filter_fn)
|
||||
_check_unpickable_fn(filter_fn)
|
||||
self.filter_fn = filter_fn # type: ignore[assignment]
|
||||
|
||||
if drop_empty_batches is None:
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from torch.utils.data.datapipes.utils.common import _check_lambda_fn
|
||||
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
|
||||
from typing import Callable, TypeVar
|
||||
from torch.utils.data.datapipes._decorator import functional_datapipe
|
||||
from torch.utils.data.datapipes.datapipe import MapDataPipe
|
||||
|
|
@ -48,7 +48,7 @@ class MapperMapDataPipe(MapDataPipe[T_co]):
|
|||
) -> None:
|
||||
super().__init__()
|
||||
self.datapipe = datapipe
|
||||
_check_lambda_fn(fn)
|
||||
_check_unpickable_fn(fn)
|
||||
self.fn = fn # type: ignore[assignment]
|
||||
|
||||
def __len__(self) -> int:
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
import os
|
||||
import fnmatch
|
||||
import warnings
|
||||
import inspect
|
||||
import os
|
||||
import warnings
|
||||
|
||||
from io import IOBase
|
||||
from typing import Dict, Iterable, List, Tuple, Union, Optional, Callable
|
||||
from functools import partial
|
||||
from typing import Callable, Dict, Iterable, List, Tuple, Union, Optional
|
||||
|
||||
from torch.utils.data._utils.serialization import DILL_AVAILABLE
|
||||
|
||||
|
|
@ -28,7 +29,7 @@ def validate_input_col(fn: Callable, input_col: Optional[Union[int, tuple, list]
|
|||
Returns:
|
||||
None.
|
||||
Raises:
|
||||
TypeError: If the function is not compatible with the input column.
|
||||
ValueError: If the function is not compatible with the input column.
|
||||
"""
|
||||
sig = inspect.signature(fn)
|
||||
if isinstance(input_col, (list, tuple)):
|
||||
|
|
@ -39,25 +40,52 @@ def validate_input_col(fn: Callable, input_col: Optional[Union[int, tuple, list]
|
|||
if len(sig.parameters) > sz:
|
||||
non_default_params = [p for p in sig.parameters.values() if p.default is p.empty]
|
||||
if len(non_default_params) > sz:
|
||||
fn_name = fn.__name__ if hasattr(fn, "__name__") else str(fn)
|
||||
raise ValueError(
|
||||
f"The function {fn.__name__} takes {len(non_default_params)} "
|
||||
f"The function {fn_name} takes {len(non_default_params)} "
|
||||
f"non-default parameters, but {sz} are required for the given `input_col`."
|
||||
)
|
||||
|
||||
if len(sig.parameters) < sz:
|
||||
fn_name = fn.__name__ if hasattr(fn, "__name__") else str(fn)
|
||||
raise ValueError(
|
||||
f"The function {fn.__name__} takes {len(sig.parameters)} "
|
||||
f"The function {fn_name} takes {len(sig.parameters)} "
|
||||
f"parameters, but {sz} are required for the given `input_col`."
|
||||
)
|
||||
|
||||
|
||||
def _check_lambda_fn(fn):
|
||||
# Partial object has no attribute '__name__', but can be pickled
|
||||
if hasattr(fn, "__name__") and fn.__name__ == "<lambda>" and not DILL_AVAILABLE:
|
||||
def _is_local_fn(fn):
|
||||
return fn.__code__.co_flags & inspect.CO_NESTED
|
||||
|
||||
|
||||
def _check_unpickable_fn(fn: Callable):
|
||||
"""
|
||||
Checks function is pickable or not. If it is a lambda or local function, a UserWarning
|
||||
will be raised. If it's not a callable function, a TypeError will be raised.
|
||||
"""
|
||||
if not callable(fn):
|
||||
raise TypeError(f"A callable function is expected, but {type(fn)} is provided.")
|
||||
|
||||
# Extract function from partial object
|
||||
# Nested partial function is automatically expanded as a single partial object
|
||||
if isinstance(fn, partial):
|
||||
fn = fn.func
|
||||
|
||||
# Local function
|
||||
if _is_local_fn(fn) and not DILL_AVAILABLE:
|
||||
warnings.warn(
|
||||
"Lambda function is not supported for pickle, please use "
|
||||
"Local function is not supported by pickle, please use "
|
||||
"regular python function or functools.partial instead."
|
||||
)
|
||||
return
|
||||
|
||||
# Lambda function
|
||||
if hasattr(fn, "__name__") and fn.__name__ == "<lambda>" and not DILL_AVAILABLE:
|
||||
warnings.warn(
|
||||
"Lambda function is not supported by pickle, please use "
|
||||
"regular python function or functools.partial instead."
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
def match_masks(name : str, masks : Union[str, List[str]]) -> bool:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user