mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[pytrees] Allow tree_map_only to support predicate function as filter (#119974)
In many places in the code we use `tree_map_only((SymInt, SymBool, SymFloat), foo)` but with nested ints, it is possible to have SymInts that are non-symbolic, so we may want to do something like `tree_map_only(is_symbolic, foo)` instead. Alternative: wrap nested int SymNodes with something other than SymInt. Pull Request resolved: https://github.com/pytorch/pytorch/pull/119974 Approved by: https://github.com/zou3519 ghstack dependencies: #119661
This commit is contained in:
parent
722e87899a
commit
2e77629b9f
|
|
@ -629,6 +629,18 @@ class TestGenericPytree(TestCase):
|
|||
pytree_impl.tree_map_only(int, lambda x: x + 2, [0, "a"]), [2, "a"]
|
||||
)
|
||||
|
||||
@parametrize(
|
||||
"pytree_impl",
|
||||
[
|
||||
subtest(py_pytree, name="py"),
|
||||
# cxx tree_map_only does not support passing predicate fn as filter
|
||||
],
|
||||
)
|
||||
def test_tree_map_only_predicate_fn(self, pytree_impl):
|
||||
self.assertEqual(
|
||||
pytree_impl.tree_map_only(lambda x: x == 0, lambda x: x + 2, [0, 1]), [2, 1]
|
||||
)
|
||||
|
||||
@parametrize(
|
||||
"pytree_impl",
|
||||
[
|
||||
|
|
|
|||
|
|
@ -86,7 +86,11 @@ class ShapeEnvEvent:
|
|||
|
||||
# Replay itself, but using shape_env as self.
|
||||
def run(self, shape_env=None) -> Any:
|
||||
from torch.fx.experimental.symbolic_shapes import ShapeEnv, SymTypes
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
is_symbolic,
|
||||
ShapeEnv,
|
||||
SymTypes,
|
||||
)
|
||||
|
||||
# Special handling for the constructor event.
|
||||
if self.f is ShapeEnv:
|
||||
|
|
@ -105,7 +109,7 @@ class ShapeEnvEvent:
|
|||
# Replace any argument of type SymTypes by a new instance,
|
||||
# replacing its ShapeEnv reference.
|
||||
args, kwargs = pytree.tree_map_only(
|
||||
SymTypes,
|
||||
lambda x: isinstance(x, SymTypes) and is_symbolic(x),
|
||||
lambda a: type(a)(a.node.with_shape_env(shape_env)),
|
||||
(args, kwargs),
|
||||
)
|
||||
|
|
@ -172,7 +176,7 @@ class ShapeEnvEvent:
|
|||
# If we find more than one object of any of the above types, we
|
||||
# also check that the ShapeEnv instance is the same for all of them.
|
||||
def _extract_shape_env_and_assert_equal(args, kwargs):
|
||||
from torch.fx.experimental.symbolic_shapes import ShapeEnv, SymTypes
|
||||
from torch.fx.experimental.symbolic_shapes import is_symbolic, ShapeEnv, SymTypes
|
||||
|
||||
def assert_equal(old: Optional[ShapeEnv], new: ShapeEnv) -> ShapeEnv:
|
||||
if old is not None:
|
||||
|
|
@ -183,7 +187,7 @@ def _extract_shape_env_and_assert_equal(args, kwargs):
|
|||
for val in itertools.chain(args, kwargs.values()):
|
||||
if isinstance(val, ShapeEnv):
|
||||
shape_env = assert_equal(shape_env, val)
|
||||
if isinstance(val, SymTypes):
|
||||
if isinstance(val, SymTypes) and is_symbolic(val):
|
||||
shape_env = assert_equal(shape_env, val.node.shape_env)
|
||||
|
||||
return shape_env
|
||||
|
|
|
|||
|
|
@ -491,27 +491,34 @@ MapOnlyFn = Callable[[T], Callable[[Any], Any]]
|
|||
# These specializations help with type inference on the lambda passed to this
|
||||
# function
|
||||
@overload
|
||||
def map_only(__type_or_types: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]:
|
||||
def map_only(__type_or_types_or_pred: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def map_only(__type_or_types: Type3[T, S, U]) -> MapOnlyFn[Fn3[T, S, U, Any]]:
|
||||
def map_only(__type_or_types_or_pred: Type3[T, S, U]) -> MapOnlyFn[Fn3[T, S, U, Any]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def map_only(__type_or_types: Type[T]) -> MapOnlyFn[Fn[T, Any]]:
|
||||
def map_only(__type_or_types_or_pred: Type[T]) -> MapOnlyFn[Fn[T, Any]]:
|
||||
...
|
||||
|
||||
|
||||
# This specialization is needed for the implementations below that call
|
||||
@overload
|
||||
def map_only(__type_or_types: TypeAny) -> MapOnlyFn[FnAny[Any]]:
|
||||
def map_only(__type_or_types_or_pred: TypeAny) -> MapOnlyFn[FnAny[Any]]:
|
||||
...
|
||||
|
||||
|
||||
def map_only(__type_or_types: TypeAny) -> MapOnlyFn[FnAny[Any]]:
|
||||
@overload
|
||||
def map_only(__type_or_types_or_pred: Callable[[Any], bool]) -> MapOnlyFn[FnAny[Any]]:
|
||||
...
|
||||
|
||||
|
||||
def map_only(
|
||||
__type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]]
|
||||
) -> MapOnlyFn[FnAny[Any]]:
|
||||
"""
|
||||
Suppose you are writing a tree_map over tensors, leaving everything
|
||||
else unchanged. Ordinarily you would have to write:
|
||||
|
|
@ -530,11 +537,15 @@ def map_only(__type_or_types: TypeAny) -> MapOnlyFn[FnAny[Any]]:
|
|||
|
||||
You can also directly use 'tree_map_only'
|
||||
"""
|
||||
if not isinstance(__type_or_types_or_pred, (tuple, type)):
|
||||
raise ValueError(
|
||||
"cxx_pytree map_only currently only accepts type or tuple of types"
|
||||
)
|
||||
|
||||
def wrapper(func: Callable[[T], Any]) -> Callable[[Any], Any]:
|
||||
@functools.wraps(func)
|
||||
def wrapped(x: T) -> Any:
|
||||
if isinstance(x, __type_or_types):
|
||||
if isinstance(x, __type_or_types_or_pred):
|
||||
return func(x)
|
||||
return x
|
||||
|
||||
|
|
@ -545,7 +556,7 @@ def map_only(__type_or_types: TypeAny) -> MapOnlyFn[FnAny[Any]]:
|
|||
|
||||
@overload
|
||||
def tree_map_only(
|
||||
__type_or_types: Type[T],
|
||||
__type_or_types_or_pred: Type[T],
|
||||
func: Fn[T, Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
|
|
@ -555,7 +566,7 @@ def tree_map_only(
|
|||
|
||||
@overload
|
||||
def tree_map_only(
|
||||
__type_or_types: Type2[T, S],
|
||||
__type_or_types_or_pred: Type2[T, S],
|
||||
func: Fn2[T, S, Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
|
|
@ -565,7 +576,17 @@ def tree_map_only(
|
|||
|
||||
@overload
|
||||
def tree_map_only(
|
||||
__type_or_types: Type3[T, S, U],
|
||||
__type_or_types_or_pred: Type3[T, S, U],
|
||||
func: Fn3[T, S, U, Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> PyTree:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def tree_map_only(
|
||||
__type_or_types_or_pred: Callable[[Any], bool],
|
||||
func: Fn3[T, S, U, Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
|
|
@ -574,17 +595,17 @@ def tree_map_only(
|
|||
|
||||
|
||||
def tree_map_only(
|
||||
__type_or_types: TypeAny,
|
||||
__type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
|
||||
func: FnAny[Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> PyTree:
|
||||
return tree_map(map_only(__type_or_types)(func), tree, is_leaf=is_leaf)
|
||||
return tree_map(map_only(__type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
|
||||
|
||||
|
||||
@overload
|
||||
def tree_map_only_(
|
||||
__type_or_types: Type[T],
|
||||
__type_or_types_or_pred: Type[T],
|
||||
func: Fn[T, Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
|
|
@ -594,7 +615,7 @@ def tree_map_only_(
|
|||
|
||||
@overload
|
||||
def tree_map_only_(
|
||||
__type_or_types: Type2[T, S],
|
||||
__type_or_types_or_pred: Type2[T, S],
|
||||
func: Fn2[T, S, Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
|
|
@ -604,7 +625,17 @@ def tree_map_only_(
|
|||
|
||||
@overload
|
||||
def tree_map_only_(
|
||||
__type_or_types: Type3[T, S, U],
|
||||
__type_or_types_or_pred: Type3[T, S, U],
|
||||
func: Fn3[T, S, U, Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> PyTree:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def tree_map_only_(
|
||||
__type_or_types_or_pred: Callable[[Any], bool],
|
||||
func: Fn3[T, S, U, Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
|
|
@ -613,12 +644,12 @@ def tree_map_only_(
|
|||
|
||||
|
||||
def tree_map_only_(
|
||||
__type_or_types: TypeAny,
|
||||
__type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
|
||||
func: FnAny[Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> PyTree:
|
||||
return tree_map_(map_only(__type_or_types)(func), tree, is_leaf=is_leaf)
|
||||
return tree_map_(map_only(__type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
|
||||
|
||||
|
||||
def tree_all(
|
||||
|
|
|
|||
|
|
@ -947,27 +947,34 @@ MapOnlyFn = Callable[[T], Callable[[Any], Any]]
|
|||
# These specializations help with type inference on the lambda passed to this
|
||||
# function
|
||||
@overload
|
||||
def map_only(__type_or_types: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]:
|
||||
def map_only(__type_or_types_or_pred: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def map_only(__type_or_types: Type3[T, S, U]) -> MapOnlyFn[Fn3[T, S, U, Any]]:
|
||||
def map_only(__type_or_types_or_pred: Type3[T, S, U]) -> MapOnlyFn[Fn3[T, S, U, Any]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def map_only(__type_or_types: Type[T]) -> MapOnlyFn[Fn[T, Any]]:
|
||||
def map_only(__type_or_types_or_pred: Type[T]) -> MapOnlyFn[Fn[T, Any]]:
|
||||
...
|
||||
|
||||
|
||||
# This specialization is needed for the implementations below that call
|
||||
@overload
|
||||
def map_only(__type_or_types: TypeAny) -> MapOnlyFn[FnAny[Any]]:
|
||||
def map_only(__type_or_types_or_pred: TypeAny) -> MapOnlyFn[FnAny[Any]]:
|
||||
...
|
||||
|
||||
|
||||
def map_only(__type_or_types: TypeAny) -> MapOnlyFn[FnAny[Any]]:
|
||||
@overload
|
||||
def map_only(__type_or_types_or_pred: Callable[[Any], bool]) -> MapOnlyFn[FnAny[Any]]:
|
||||
...
|
||||
|
||||
|
||||
def map_only(
|
||||
__type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]]
|
||||
) -> MapOnlyFn[FnAny[Any]]:
|
||||
"""
|
||||
Suppose you are writing a tree_map over tensors, leaving everything
|
||||
else unchanged. Ordinarily you would have to write:
|
||||
|
|
@ -986,11 +993,19 @@ def map_only(__type_or_types: TypeAny) -> MapOnlyFn[FnAny[Any]]:
|
|||
|
||||
You can also directly use 'tree_map_only'
|
||||
"""
|
||||
if isinstance(__type_or_types_or_pred, (tuple, type)):
|
||||
return _map_only(lambda x: isinstance(x, __type_or_types_or_pred))
|
||||
elif callable(__type_or_types_or_pred):
|
||||
return _map_only(__type_or_types_or_pred)
|
||||
else:
|
||||
raise TypeError("Argument must be a type, a tuple of types, or a callable.")
|
||||
|
||||
|
||||
def _map_only(pred: Callable[[Any], bool]) -> MapOnlyFn[FnAny[Any]]:
|
||||
def wrapper(func: Callable[[T], Any]) -> Callable[[Any], Any]:
|
||||
# @functools.wraps(func) # torch dynamo doesn't support this yet
|
||||
def wrapped(x: T) -> Any:
|
||||
if isinstance(x, __type_or_types):
|
||||
if pred(x):
|
||||
return func(x)
|
||||
return x
|
||||
|
||||
|
|
@ -1001,7 +1016,7 @@ def map_only(__type_or_types: TypeAny) -> MapOnlyFn[FnAny[Any]]:
|
|||
|
||||
@overload
|
||||
def tree_map_only(
|
||||
__type_or_types: Type[T],
|
||||
__type_or_types_or_pred: Type[T],
|
||||
func: Fn[T, Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
|
|
@ -1011,7 +1026,7 @@ def tree_map_only(
|
|||
|
||||
@overload
|
||||
def tree_map_only(
|
||||
__type_or_types: Type2[T, S],
|
||||
__type_or_types_or_pred: Type2[T, S],
|
||||
func: Fn2[T, S, Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
|
|
@ -1021,7 +1036,7 @@ def tree_map_only(
|
|||
|
||||
@overload
|
||||
def tree_map_only(
|
||||
__type_or_types: Type3[T, S, U],
|
||||
__type_or_types_or_pred: Type3[T, S, U],
|
||||
func: Fn3[T, S, U, Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
|
|
@ -1029,18 +1044,28 @@ def tree_map_only(
|
|||
...
|
||||
|
||||
|
||||
@overload
|
||||
def tree_map_only(
|
||||
__type_or_types: TypeAny,
|
||||
__type_or_types_or_pred: Callable[[Any], bool],
|
||||
func: FnAny[Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> PyTree:
|
||||
return tree_map(map_only(__type_or_types)(func), tree, is_leaf=is_leaf)
|
||||
...
|
||||
|
||||
|
||||
def tree_map_only(
|
||||
__type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
|
||||
func: FnAny[Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> PyTree:
|
||||
return tree_map(map_only(__type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
|
||||
|
||||
|
||||
@overload
|
||||
def tree_map_only_(
|
||||
__type_or_types: Type[T],
|
||||
__type_or_types_or_pred: Type[T],
|
||||
func: Fn[T, Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
|
|
@ -1050,7 +1075,7 @@ def tree_map_only_(
|
|||
|
||||
@overload
|
||||
def tree_map_only_(
|
||||
__type_or_types: Type2[T, S],
|
||||
__type_or_types_or_pred: Type2[T, S],
|
||||
func: Fn2[T, S, Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
|
|
@ -1060,7 +1085,7 @@ def tree_map_only_(
|
|||
|
||||
@overload
|
||||
def tree_map_only_(
|
||||
__type_or_types: Type3[T, S, U],
|
||||
__type_or_types_or_pred: Type3[T, S, U],
|
||||
func: Fn3[T, S, U, Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
|
|
@ -1069,12 +1094,12 @@ def tree_map_only_(
|
|||
|
||||
|
||||
def tree_map_only_(
|
||||
__type_or_types: TypeAny,
|
||||
__type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
|
||||
func: FnAny[Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> PyTree:
|
||||
return tree_map_(map_only(__type_or_types)(func), tree, is_leaf=is_leaf)
|
||||
return tree_map_(map_only(__type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
|
||||
|
||||
|
||||
def tree_all(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user