Raise warning for unpicklable local function (#80140)

Fixes https://github.com/pytorch/data/issues/538

- Improve the validation function to raise a warning about an unpicklable function when either a lambda or a local function is provided to a `DataPipe`.
- The inner function of a `functools.partial` object is also extracted for validation.
- Mimic the behavior of the `pickle` module for a local lambda function: `pickle` raises the error because the function is local, not because it is a `lambda`. So, for a local lambda, we raise the warning about a local function rather than the one about a lambda function.
```py
>>> import pickle
>>> def fn():
...     lf = lambda x: x
...     pickle.dumps(lf)
>>> fn()
AttributeError: Can't pickle local object 'fn.<locals>.<lambda>'
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80140
Approved by: https://github.com/VitalyFedyunin, https://github.com/NivekT
This commit is contained in:
erjia 2022-06-24 13:50:51 +00:00 committed by PyTorch MergeBot
parent 3afc802c5a
commit 4b75b7d3c1
7 changed files with 169 additions and 91 deletions

View File

@ -603,6 +603,11 @@ def _worker_init_fn(worker_id):
torch.utils.data.graph_settings.apply_sharding(datapipe, num_workers, worker_id) torch.utils.data.graph_settings.apply_sharding(datapipe, num_workers, worker_id)
lambda_fn1 = lambda x: x # noqa: E731
lambda_fn2 = lambda x: x % 2 # noqa: E731
lambda_fn3 = lambda x: x >= 5 # noqa: E731
class TestFunctionalIterDataPipe(TestCase): class TestFunctionalIterDataPipe(TestCase):
def _serialization_test_helper(self, datapipe, use_dill): def _serialization_test_helper(self, datapipe, use_dill):
@ -702,16 +707,41 @@ class TestFunctionalIterDataPipe(TestCase):
def test_serializable_with_dill(self): def test_serializable_with_dill(self):
"""Only for DataPipes that take in a function as argument""" """Only for DataPipes that take in a function as argument"""
input_dp = dp.iter.IterableWrapper(range(10)) input_dp = dp.iter.IterableWrapper(range(10))
unpicklable_datapipes: List[Tuple[Type[IterDataPipe], Tuple, Dict[str, Any]]] = [
(dp.iter.Collator, (lambda x: x,), {}), datapipes_with_lambda_fn: List[Tuple[Type[IterDataPipe], Tuple, Dict[str, Any]]] = [
(dp.iter.Demultiplexer, (2, lambda x: x % 2,), {}), (dp.iter.Collator, (lambda_fn1,), {}),
(dp.iter.Filter, (lambda x: x >= 5,), {}), (dp.iter.Demultiplexer, (2, lambda_fn2,), {}),
(dp.iter.Grouper, (lambda x: x >= 5,), {}), (dp.iter.Filter, (lambda_fn3,), {}),
(dp.iter.Mapper, (lambda x: x,), {}), (dp.iter.Grouper, (lambda_fn3,), {}),
(dp.iter.Mapper, (lambda_fn1,), {}),
] ]
def _local_fns():
def _fn1(x):
return x
def _fn2(x):
return x % 2
def _fn3(x):
return x >= 5
return _fn1, _fn2, _fn3
fn1, fn2, fn3 = _local_fns()
datapipes_with_local_fn: List[Tuple[Type[IterDataPipe], Tuple, Dict[str, Any]]] = [
(dp.iter.Collator, (fn1,), {}),
(dp.iter.Demultiplexer, (2, fn2,), {}),
(dp.iter.Filter, (fn3,), {}),
(dp.iter.Grouper, (fn3,), {}),
(dp.iter.Mapper, (fn1,), {}),
]
dp_compare_children = {dp.iter.Demultiplexer} dp_compare_children = {dp.iter.Demultiplexer}
if HAS_DILL: if HAS_DILL:
for dpipe, dp_args, dp_kwargs in unpicklable_datapipes: for dpipe, dp_args, dp_kwargs in datapipes_with_lambda_fn + datapipes_with_local_fn:
if dpipe in dp_compare_children: if dpipe in dp_compare_children:
dp1, dp2 = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg] dp1, dp2 = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg]
self._serialization_test_for_dp_with_children(dp1, dp2, use_dill=True) self._serialization_test_for_dp_with_children(dp1, dp2, use_dill=True)
@ -719,13 +749,16 @@ class TestFunctionalIterDataPipe(TestCase):
datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg] datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg]
self._serialization_test_for_single_dp(datapipe, use_dill=True) self._serialization_test_for_single_dp(datapipe, use_dill=True)
else: else:
for dpipe, dp_args, dp_kwargs in unpicklable_datapipes: msgs = (
with warnings.catch_warnings(record=True) as wa: r"^Lambda function is not supported by pickle",
datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg] r"^Local function is not supported by pickle"
self.assertEqual(len(wa), 1) )
self.assertRegex(str(wa[0].message), r"^Lambda function is not supported for pickle") for dps, msg in zip((datapipes_with_lambda_fn, datapipes_with_local_fn), msgs):
with self.assertRaises(AttributeError): for dpipe, dp_args, dp_kwargs in dps:
p = pickle.dumps(datapipe) with self.assertWarnsRegex(UserWarning, msg):
datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg]
with self.assertRaises((pickle.PicklingError, AttributeError)):
pickle.dumps(datapipe)
def test_iterable_wrapper_datapipe(self): def test_iterable_wrapper_datapipe(self):
@ -1150,14 +1183,19 @@ class TestFunctionalIterDataPipe(TestCase):
def fn_nn(d0, d1): def fn_nn(d0, d1):
return -d0, -d1, d0 + d1 return -d0, -d1, d0 + d1
def _helper(ref_fn, fn, input_col=None, output_col=None): def _helper(ref_fn, fn, input_col=None, output_col=None, error=None):
for constr in (list, tuple): for constr in (list, tuple):
datapipe = dp.iter.IterableWrapper([constr((0, 1, 2)), constr((3, 4, 5)), constr((6, 7, 8))]) datapipe = dp.iter.IterableWrapper([constr((0, 1, 2)), constr((3, 4, 5)), constr((6, 7, 8))])
res_dp = datapipe.map(fn, input_col, output_col) if ref_fn is None:
ref_dp = datapipe.map(ref_fn) if ref_fn is not None else datapipe with self.assertRaises(error):
self.assertEqual(list(res_dp), list(ref_dp)) res_dp = datapipe.map(fn, input_col, output_col)
# Reset list(res_dp)
self.assertEqual(list(res_dp), list(ref_dp)) else:
res_dp = datapipe.map(fn, input_col, output_col)
ref_dp = datapipe.map(ref_fn)
self.assertEqual(list(res_dp), list(ref_dp))
# Reset
self.assertEqual(list(res_dp), list(ref_dp))
# Replacing with one input column and default output column # Replacing with one input column and default output column
_helper(lambda data: (data[0], -data[1], data[2]), fn_11, 1) _helper(lambda data: (data[0], -data[1], data[2]), fn_11, 1)
@ -1166,20 +1204,17 @@ class TestFunctionalIterDataPipe(TestCase):
_helper(lambda data: (data[0], data[1], 1 + data[1]), fn_n1_def, 1, 2) _helper(lambda data: (data[0], data[1], 1 + data[1]), fn_n1_def, 1, 2)
# The index of input column is out of range # The index of input column is out of range
with self.assertRaises(IndexError): _helper(None, fn_1n, 3, error=IndexError)
_helper(None, fn_1n, 3)
# Unmatched input columns with fn arguments # Unmatched input columns with fn arguments
with self.assertRaises(ValueError): _helper(None, fn_n1, 1, error=ValueError)
_helper(None, fn_n1, 1) _helper(None, lambda d0, d1: d0 + d1, 0, error=ValueError)
_helper(None, lambda d0, d1: d0 + d1, 0) _helper(None, p_fn_n1, (0, 1, 3), error=ValueError)
_helper(None, p_fn_n1, (0, 1))
# Function takes fewer parameters than input col # Function takes fewer parameters than input col
with self.assertRaises(ValueError): def zero_args():
def zero_args(): return
return _helper(None, zero_args, 0, error=ValueError)
_helper(None, zero_args, 0)
# Replacing with multiple input columns and default output column (the left-most input column) # Replacing with multiple input columns and default output column (the left-most input column)
_helper(lambda data: (data[1], data[2] + data[0]), fn_n1, [2, 0]) _helper(lambda data: (data[1], data[2] + data[0]), fn_n1, [2, 0])
@ -1190,19 +1225,16 @@ class TestFunctionalIterDataPipe(TestCase):
2) 2)
# output_col can only be specified when input_col is not None # output_col can only be specified when input_col is not None
with self.assertRaises(ValueError): _helper(None, fn_n1, None, 1, error=ValueError)
_helper(None, fn_n1, None, 1)
# output_col can only be single-element list or tuple # output_col can only be single-element list or tuple
with self.assertRaises(ValueError): _helper(None, fn_n1, None, [0, 1], error=ValueError)
_helper(None, fn_n1, None, [0, 1])
# Single-element list as output_col # Single-element list as output_col
_helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, [0]) _helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, [0])
# Replacing with one input column and single specified output column # Replacing with one input column and single specified output column
_helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, 0) _helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, 0)
_helper(lambda data: (data[0], data[1], (-data[1], data[1])), fn_1n, 1, 2) _helper(lambda data: (data[0], data[1], (-data[1], data[1])), fn_1n, 1, 2)
# The index of output column is out of range # The index of output column is out of range
with self.assertRaises(IndexError): _helper(None, fn_1n, 1, 3, error=IndexError)
_helper(None, fn_1n, 1, 3)
_helper(lambda data: (data[0], data[0] + data[2], data[2]), fn_n1, [0, 2], 1) _helper(lambda data: (data[0], data[0] + data[2], data[2]), fn_n1, [0, 2], 1)
_helper(lambda data: ((-data[1], -data[2], data[1] + data[2]), data[1], data[2]), fn_nn, [1, 2], 0) _helper(lambda data: ((-data[1], -data[2], data[1] + data[2]), data[1], data[2]), fn_nn, [1, 2], 0)
@ -1238,22 +1270,26 @@ class TestFunctionalIterDataPipe(TestCase):
del _data[idx] del _data[idx]
return _data return _data
def _helper(ref_fn, fn, input_col=None, output_col=None): def _helper(ref_fn, fn, input_col=None, output_col=None, error=None):
datapipe = dp.iter.IterableWrapper( datapipe = dp.iter.IterableWrapper(
[{"x": 0, "y": 1, "z": 2}, [{"x": 0, "y": 1, "z": 2},
{"x": 3, "y": 4, "z": 5}, {"x": 3, "y": 4, "z": 5},
{"x": 6, "y": 7, "z": 8}] {"x": 6, "y": 7, "z": 8}]
) )
res_dp = datapipe.map(fn, input_col, output_col) if ref_fn is None:
ref_dp = datapipe.map(ref_fn) if ref_fn is not None else datapipe with self.assertRaises(error):
self.assertEqual(list(res_dp), list(ref_dp)) res_dp = datapipe.map(fn, input_col, output_col)
# Reset list(res_dp)
self.assertEqual(list(res_dp), list(ref_dp)) else:
res_dp = datapipe.map(fn, input_col, output_col)
ref_dp = datapipe.map(ref_fn)
self.assertEqual(list(res_dp), list(ref_dp))
# Reset
self.assertEqual(list(res_dp), list(ref_dp))
# Replacing with one input column and default output column # Replacing with one input column and default output column
_helper(lambda data: _dict_update(data, {"y": -data["y"]}), fn_11, "y") _helper(lambda data: _dict_update(data, {"y": -data["y"]}), fn_11, "y")
_helper(lambda data: _dict_update(data, {"y": (-data["y"], data["y"])}), fn_1n, "y") _helper(lambda data: _dict_update(data, {"y": (-data["y"], data["y"])}), fn_1n, "y")
_helper(lambda data: _dict_update(data, {"z": data["x"] + data["y"]}), _helper(lambda data: _dict_update(data, {"z": data["x"] + data["y"]}),
lambda x, y: x + y, ("x", "y"), "z") lambda x, y: x + y, ("x", "y"), "z")
_helper(lambda data: _dict_update(data, {"x": 1 + data["y"]}), fn_n1_def, "y", _helper(lambda data: _dict_update(data, {"x": 1 + data["y"]}), fn_n1_def, "y",
@ -1262,19 +1298,16 @@ class TestFunctionalIterDataPipe(TestCase):
p_fn_n1 = partial(fn_n1, d1=1) p_fn_n1 = partial(fn_n1, d1=1)
_helper(lambda data: _dict_update(data, {"x": 1 + data["y"]}), p_fn_n1, "y", "x") _helper(lambda data: _dict_update(data, {"x": 1 + data["y"]}), p_fn_n1, "y", "x")
# The key of input column is not in dict # The key of input column is not in dict
with self.assertRaises(KeyError): _helper(None, fn_1n, "a", error=KeyError)
_helper(None, fn_1n, "a")
# Unmatched input columns with fn arguments # Unmatched input columns with fn arguments
with self.assertRaises(ValueError): _helper(None, fn_n1, "y", error=ValueError)
_helper(None, fn_n1, "y") _helper(None, lambda x, y: x + y, "x", error=ValueError)
_helper(None, lambda x, y: x + y, "x") _helper(None, p_fn_n1, ("x", "y", "z"), error=ValueError)
_helper(None, p_fn_n1, ("x", "y"))
# Function takes fewer parameters than input col # Function takes fewer parameters than input col
with self.assertRaises(ValueError): def zero_args():
def zero_args(): return
return _helper(None, zero_args, "x", error=ValueError)
_helper(None, zero_args, "x")
# Replacing with multiple input columns and default output column (the left-most input column) # Replacing with multiple input columns and default output column (the left-most input column)
_helper(lambda data: _dict_update(data, {"z": data["x"] + data["z"]}, ["x"]), fn_n1, ["z", "x"]) _helper(lambda data: _dict_update(data, {"z": data["x"] + data["z"]}, ["x"]), fn_n1, ["z", "x"])
@ -1283,11 +1316,9 @@ class TestFunctionalIterDataPipe(TestCase):
_helper(lambda data: _dict_update(data, {"x": data["x"] + data["y"]}), fn_n1_def, ("x", "y"), "x") _helper(lambda data: _dict_update(data, {"x": data["x"] + data["y"]}), fn_n1_def, ("x", "y"), "x")
# output_col can only be specified when input_col is not None # output_col can only be specified when input_col is not None
with self.assertRaises(ValueError): _helper(None, fn_n1, None, "x", error=ValueError)
_helper(None, fn_n1, None, "x")
# output_col can only be single-element list or tuple # output_col can only be single-element list or tuple
with self.assertRaises(ValueError): _helper(None, fn_n1, None, ["x", "y"], error=ValueError)
_helper(None, fn_n1, None, ["x", "y"])
# Single-element list as output_col # Single-element list as output_col
_helper(lambda data: _dict_update(data, {"x": -data["y"]}), fn_11, "y", ["x"]) _helper(lambda data: _dict_update(data, {"x": -data["y"]}), fn_11, "y", ["x"])
# Replacing with one input column and single specified output column # Replacing with one input column and single specified output column
@ -1677,24 +1708,41 @@ class TestFunctionalMapDataPipe(TestCase):
def test_serializable_with_dill(self): def test_serializable_with_dill(self):
"""Only for DataPipes that take in a function as argument""" """Only for DataPipes that take in a function as argument"""
input_dp = dp.map.SequenceWrapper(range(10)) input_dp = dp.map.SequenceWrapper(range(10))
unpicklable_datapipes: List[
datapipes_with_lambda_fn: List[
Tuple[Type[MapDataPipe], Tuple, Dict[str, Any]] Tuple[Type[MapDataPipe], Tuple, Dict[str, Any]]
] = [ ] = [
(dp.map.Mapper, (lambda x: x,), {}), (dp.map.Mapper, (lambda_fn1,), {}),
] ]
def _local_fns():
def _fn1(x):
return x
return _fn1
fn1 = _local_fns()
datapipes_with_local_fn: List[
Tuple[Type[MapDataPipe], Tuple, Dict[str, Any]]
] = [
(dp.map.Mapper, (fn1,), {}),
]
if HAS_DILL: if HAS_DILL:
for dpipe, dp_args, dp_kwargs in unpicklable_datapipes: for dpipe, dp_args, dp_kwargs in datapipes_with_lambda_fn + datapipes_with_local_fn:
_ = dill.dumps(dpipe(input_dp, *dp_args, **dp_kwargs)) # type: ignore[call-arg] _ = dill.dumps(dpipe(input_dp, *dp_args, **dp_kwargs)) # type: ignore[call-arg]
else: else:
for dpipe, dp_args, dp_kwargs in unpicklable_datapipes: msgs = (
with warnings.catch_warnings(record=True) as wa: r"^Lambda function is not supported by pickle",
datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg] r"^Local function is not supported by pickle"
self.assertEqual(len(wa), 1) )
self.assertRegex( for dps, msg in zip((datapipes_with_lambda_fn, datapipes_with_local_fn), msgs):
str(wa[0].message), r"^Lambda function is not supported for pickle" for dpipe, dp_args, dp_kwargs in dps:
) with self.assertWarnsRegex(UserWarning, msg):
with self.assertRaises(AttributeError): datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg]
p = pickle.dumps(datapipe) with self.assertRaises((pickle.PicklingError, AttributeError)):
pickle.dumps(datapipe)
def test_sequence_wrapper_datapipe(self): def test_sequence_wrapper_datapipe(self):
seq = list(range(10)) seq = list(range(10))

View File

@ -4,8 +4,9 @@ from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data._utils.collate import default_collate from torch.utils.data._utils.collate import default_collate
from torch.utils.data.datapipes.datapipe import IterDataPipe from torch.utils.data.datapipes.datapipe import IterDataPipe
from torch.utils.data.datapipes.utils.common import ( from torch.utils.data.datapipes.utils.common import (
_check_lambda_fn, _check_unpickable_fn,
validate_input_col) validate_input_col
)
__all__ = [ __all__ = [
"CollatorIterDataPipe", "CollatorIterDataPipe",
@ -66,7 +67,7 @@ class MapperIterDataPipe(IterDataPipe[T_co]):
super().__init__() super().__init__()
self.datapipe = datapipe self.datapipe = datapipe
_check_lambda_fn(fn) _check_unpickable_fn(fn)
self.fn = fn # type: ignore[assignment] self.fn = fn # type: ignore[assignment]
self.input_col = input_col self.input_col = input_col

View File

@ -5,7 +5,7 @@ from typing import Any, Callable, Iterator, List, Optional, Sized, Tuple, TypeVa
from torch.utils.data.datapipes._decorator import functional_datapipe from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import IterDataPipe from torch.utils.data.datapipes.datapipe import IterDataPipe
from torch.utils.data.datapipes.utils.common import _check_lambda_fn from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
__all__ = [ __all__ = [
"ConcaterIterDataPipe", "ConcaterIterDataPipe",
@ -300,7 +300,7 @@ class DemultiplexerIterDataPipe(IterDataPipe):
if num_instances < 1: if num_instances < 1:
raise ValueError(f"Expected `num_instaces` larger than 0, but {num_instances} is found") raise ValueError(f"Expected `num_instaces` larger than 0, but {num_instances} is found")
_check_lambda_fn(classifier_fn) _check_unpickable_fn(classifier_fn)
# When num_instances == 1, demux can be replaced by filter, # When num_instances == 1, demux can be replaced by filter,
# but keep it as Demultiplexer for the sake of consistency # but keep it as Demultiplexer for the sake of consistency

View File

@ -2,7 +2,7 @@ from collections import defaultdict
from torch.utils.data.datapipes._decorator import functional_datapipe from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import IterDataPipe, DataChunk from torch.utils.data.datapipes.datapipe import IterDataPipe, DataChunk
from torch.utils.data.datapipes.utils.common import _check_lambda_fn from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
from typing import Any, Callable, DefaultDict, Iterator, List, Optional, Sized, TypeVar from typing import Any, Callable, DefaultDict, Iterator, List, Optional, Sized, TypeVar
__all__ = [ __all__ = [
@ -215,7 +215,7 @@ class GrouperIterDataPipe(IterDataPipe[DataChunk]):
group_size: Optional[int] = None, group_size: Optional[int] = None,
guaranteed_group_size: Optional[int] = None, guaranteed_group_size: Optional[int] = None,
drop_remaining: bool = False): drop_remaining: bool = False):
_check_lambda_fn(group_key_fn) _check_unpickable_fn(group_key_fn)
self.datapipe = datapipe self.datapipe = datapipe
self.group_key_fn = group_key_fn self.group_key_fn = group_key_fn

View File

@ -4,9 +4,10 @@ from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import IterDataPipe from torch.utils.data.datapipes.datapipe import IterDataPipe
from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper
from torch.utils.data.datapipes.utils.common import ( from torch.utils.data.datapipes.utils.common import (
_check_lambda_fn, _check_unpickable_fn,
_deprecation_warning, _deprecation_warning,
validate_input_col) validate_input_col
)
__all__ = ["FilterIterDataPipe", ] __all__ = ["FilterIterDataPipe", ]
@ -51,7 +52,7 @@ class FilterIterDataPipe(IterDataPipe[T_co]):
super().__init__() super().__init__()
self.datapipe = datapipe self.datapipe = datapipe
_check_lambda_fn(filter_fn) _check_unpickable_fn(filter_fn)
self.filter_fn = filter_fn # type: ignore[assignment] self.filter_fn = filter_fn # type: ignore[assignment]
if drop_empty_batches is None: if drop_empty_batches is None:

View File

@ -1,4 +1,4 @@
from torch.utils.data.datapipes.utils.common import _check_lambda_fn from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
from typing import Callable, TypeVar from typing import Callable, TypeVar
from torch.utils.data.datapipes._decorator import functional_datapipe from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import MapDataPipe from torch.utils.data.datapipes.datapipe import MapDataPipe
@ -48,7 +48,7 @@ class MapperMapDataPipe(MapDataPipe[T_co]):
) -> None: ) -> None:
super().__init__() super().__init__()
self.datapipe = datapipe self.datapipe = datapipe
_check_lambda_fn(fn) _check_unpickable_fn(fn)
self.fn = fn # type: ignore[assignment] self.fn = fn # type: ignore[assignment]
def __len__(self) -> int: def __len__(self) -> int:

View File

@ -1,10 +1,11 @@
import os
import fnmatch import fnmatch
import warnings
import inspect import inspect
import os
import warnings
from io import IOBase from io import IOBase
from typing import Dict, Iterable, List, Tuple, Union, Optional, Callable from functools import partial
from typing import Callable, Dict, Iterable, List, Tuple, Union, Optional
from torch.utils.data._utils.serialization import DILL_AVAILABLE from torch.utils.data._utils.serialization import DILL_AVAILABLE
@ -28,7 +29,7 @@ def validate_input_col(fn: Callable, input_col: Optional[Union[int, tuple, list]
Returns: Returns:
None. None.
Raises: Raises:
TypeError: If the function is not compatible with the input column. ValueError: If the function is not compatible with the input column.
""" """
sig = inspect.signature(fn) sig = inspect.signature(fn)
if isinstance(input_col, (list, tuple)): if isinstance(input_col, (list, tuple)):
@ -39,25 +40,52 @@ def validate_input_col(fn: Callable, input_col: Optional[Union[int, tuple, list]
if len(sig.parameters) > sz: if len(sig.parameters) > sz:
non_default_params = [p for p in sig.parameters.values() if p.default is p.empty] non_default_params = [p for p in sig.parameters.values() if p.default is p.empty]
if len(non_default_params) > sz: if len(non_default_params) > sz:
fn_name = fn.__name__ if hasattr(fn, "__name__") else str(fn)
raise ValueError( raise ValueError(
f"The function {fn.__name__} takes {len(non_default_params)} " f"The function {fn_name} takes {len(non_default_params)} "
f"non-default parameters, but {sz} are required for the given `input_col`." f"non-default parameters, but {sz} are required for the given `input_col`."
) )
if len(sig.parameters) < sz: if len(sig.parameters) < sz:
fn_name = fn.__name__ if hasattr(fn, "__name__") else str(fn)
raise ValueError( raise ValueError(
f"The function {fn.__name__} takes {len(sig.parameters)} " f"The function {fn_name} takes {len(sig.parameters)} "
f"parameters, but {sz} are required for the given `input_col`." f"parameters, but {sz} are required for the given `input_col`."
) )
def _check_lambda_fn(fn): def _is_local_fn(fn):
# Partial object has no attribute '__name__', but can be pickled return fn.__code__.co_flags & inspect.CO_NESTED
if hasattr(fn, "__name__") and fn.__name__ == "<lambda>" and not DILL_AVAILABLE:
def _check_unpickable_fn(fn: Callable):
"""
Checks function is pickable or not. If it is a lambda or local function, a UserWarning
will be raised. If it's not a callable function, a TypeError will be raised.
"""
if not callable(fn):
raise TypeError(f"A callable function is expected, but {type(fn)} is provided.")
# Extract function from partial object
# Nested partial function is automatically expanded as a single partial object
if isinstance(fn, partial):
fn = fn.func
# Local function
if _is_local_fn(fn) and not DILL_AVAILABLE:
warnings.warn( warnings.warn(
"Lambda function is not supported for pickle, please use " "Local function is not supported by pickle, please use "
"regular python function or functools.partial instead." "regular python function or functools.partial instead."
) )
return
# Lambda function
if hasattr(fn, "__name__") and fn.__name__ == "<lambda>" and not DILL_AVAILABLE:
warnings.warn(
"Lambda function is not supported by pickle, please use "
"regular python function or functools.partial instead."
)
return
def match_masks(name : str, masks : Union[str, List[str]]) -> bool: def match_masks(name : str, masks : Union[str, List[str]]) -> bool: