[pytrees] Allow tree_map_only to support predicate function as filter (#119974)

In many places in the code we use `tree_map_only((SymInt, SymBool, SymFloat), foo)` but with nested ints, it is possible to have SymInts that are non-symbolic, so we may want to do something like `tree_map_only(is_symbolic, foo)` instead.

Alternative: wrap nested int SymNodes with something other than SymInt.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119974
Approved by: https://github.com/zou3519
ghstack dependencies: #119661
This commit is contained in:
soulitzer 2024-02-16 11:23:28 -05:00 committed by PyTorch MergeBot
parent 722e87899a
commit 2e77629b9f
4 changed files with 108 additions and 36 deletions

View File

@ -629,6 +629,18 @@ class TestGenericPytree(TestCase):
pytree_impl.tree_map_only(int, lambda x: x + 2, [0, "a"]), [2, "a"]
)
@parametrize(
"pytree_impl",
[
subtest(py_pytree, name="py"),
# cxx tree_map_only does not support passing predicate fn as filter
],
)
def test_tree_map_only_predicate_fn(self, pytree_impl):
self.assertEqual(
pytree_impl.tree_map_only(lambda x: x == 0, lambda x: x + 2, [0, 1]), [2, 1]
)
@parametrize(
"pytree_impl",
[

View File

@ -86,7 +86,11 @@ class ShapeEnvEvent:
# Replay itself, but using shape_env as self.
def run(self, shape_env=None) -> Any:
from torch.fx.experimental.symbolic_shapes import ShapeEnv, SymTypes
from torch.fx.experimental.symbolic_shapes import (
is_symbolic,
ShapeEnv,
SymTypes,
)
# Special handling for the constructor event.
if self.f is ShapeEnv:
@ -105,7 +109,7 @@ class ShapeEnvEvent:
# Replace any argument of type SymTypes by a new instance,
# replacing its ShapeEnv reference.
args, kwargs = pytree.tree_map_only(
SymTypes,
lambda x: isinstance(x, SymTypes) and is_symbolic(x),
lambda a: type(a)(a.node.with_shape_env(shape_env)),
(args, kwargs),
)
@ -172,7 +176,7 @@ class ShapeEnvEvent:
# If we find more than one object of any of the above types, we
# also check that the ShapeEnv instance is the same for all of them.
def _extract_shape_env_and_assert_equal(args, kwargs):
from torch.fx.experimental.symbolic_shapes import ShapeEnv, SymTypes
from torch.fx.experimental.symbolic_shapes import is_symbolic, ShapeEnv, SymTypes
def assert_equal(old: Optional[ShapeEnv], new: ShapeEnv) -> ShapeEnv:
if old is not None:
@ -183,7 +187,7 @@ def _extract_shape_env_and_assert_equal(args, kwargs):
for val in itertools.chain(args, kwargs.values()):
if isinstance(val, ShapeEnv):
shape_env = assert_equal(shape_env, val)
if isinstance(val, SymTypes):
if isinstance(val, SymTypes) and is_symbolic(val):
shape_env = assert_equal(shape_env, val.node.shape_env)
return shape_env

View File

@ -491,27 +491,34 @@ MapOnlyFn = Callable[[T], Callable[[Any], Any]]
# These specializations help with type inference on the lambda passed to this
# function
@overload
def map_only(__type_or_types: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]:
def map_only(__type_or_types_or_pred: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]:
...
@overload
def map_only(__type_or_types: Type3[T, S, U]) -> MapOnlyFn[Fn3[T, S, U, Any]]:
def map_only(__type_or_types_or_pred: Type3[T, S, U]) -> MapOnlyFn[Fn3[T, S, U, Any]]:
...
@overload
def map_only(__type_or_types: Type[T]) -> MapOnlyFn[Fn[T, Any]]:
def map_only(__type_or_types_or_pred: Type[T]) -> MapOnlyFn[Fn[T, Any]]:
...
# This specialization is needed for the implementations below that call
@overload
def map_only(__type_or_types: TypeAny) -> MapOnlyFn[FnAny[Any]]:
def map_only(__type_or_types_or_pred: TypeAny) -> MapOnlyFn[FnAny[Any]]:
...
def map_only(__type_or_types: TypeAny) -> MapOnlyFn[FnAny[Any]]:
@overload
def map_only(__type_or_types_or_pred: Callable[[Any], bool]) -> MapOnlyFn[FnAny[Any]]:
...
def map_only(
__type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]]
) -> MapOnlyFn[FnAny[Any]]:
"""
Suppose you are writing a tree_map over tensors, leaving everything
else unchanged. Ordinarily you would have to write:
@ -530,11 +537,15 @@ def map_only(__type_or_types: TypeAny) -> MapOnlyFn[FnAny[Any]]:
You can also directly use 'tree_map_only'
"""
if not isinstance(__type_or_types_or_pred, (tuple, type)):
raise ValueError(
"cxx_pytree map_only currently only accepts type or tuple of types"
)
def wrapper(func: Callable[[T], Any]) -> Callable[[Any], Any]:
@functools.wraps(func)
def wrapped(x: T) -> Any:
if isinstance(x, __type_or_types):
if isinstance(x, __type_or_types_or_pred):
return func(x)
return x
@ -545,7 +556,7 @@ def map_only(__type_or_types: TypeAny) -> MapOnlyFn[FnAny[Any]]:
@overload
def tree_map_only(
__type_or_types: Type[T],
__type_or_types_or_pred: Type[T],
func: Fn[T, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
@ -555,7 +566,7 @@ def tree_map_only(
@overload
def tree_map_only(
__type_or_types: Type2[T, S],
__type_or_types_or_pred: Type2[T, S],
func: Fn2[T, S, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
@ -565,7 +576,17 @@ def tree_map_only(
@overload
def tree_map_only(
__type_or_types: Type3[T, S, U],
__type_or_types_or_pred: Type3[T, S, U],
func: Fn3[T, S, U, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
...
@overload
def tree_map_only(
__type_or_types_or_pred: Callable[[Any], bool],
func: Fn3[T, S, U, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
@ -574,17 +595,17 @@ def tree_map_only(
def tree_map_only(
__type_or_types: TypeAny,
__type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
func: FnAny[Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
return tree_map(map_only(__type_or_types)(func), tree, is_leaf=is_leaf)
return tree_map(map_only(__type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
@overload
def tree_map_only_(
__type_or_types: Type[T],
__type_or_types_or_pred: Type[T],
func: Fn[T, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
@ -594,7 +615,7 @@ def tree_map_only_(
@overload
def tree_map_only_(
__type_or_types: Type2[T, S],
__type_or_types_or_pred: Type2[T, S],
func: Fn2[T, S, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
@ -604,7 +625,17 @@ def tree_map_only_(
@overload
def tree_map_only_(
__type_or_types: Type3[T, S, U],
__type_or_types_or_pred: Type3[T, S, U],
func: Fn3[T, S, U, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
...
@overload
def tree_map_only_(
__type_or_types_or_pred: Callable[[Any], bool],
func: Fn3[T, S, U, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
@ -613,12 +644,12 @@ def tree_map_only_(
def tree_map_only_(
__type_or_types: TypeAny,
__type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
func: FnAny[Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
return tree_map_(map_only(__type_or_types)(func), tree, is_leaf=is_leaf)
return tree_map_(map_only(__type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
def tree_all(

View File

@ -947,27 +947,34 @@ MapOnlyFn = Callable[[T], Callable[[Any], Any]]
# These specializations help with type inference on the lambda passed to this
# function
@overload
def map_only(__type_or_types: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]:
def map_only(__type_or_types_or_pred: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]:
...
@overload
def map_only(__type_or_types: Type3[T, S, U]) -> MapOnlyFn[Fn3[T, S, U, Any]]:
def map_only(__type_or_types_or_pred: Type3[T, S, U]) -> MapOnlyFn[Fn3[T, S, U, Any]]:
...
@overload
def map_only(__type_or_types: Type[T]) -> MapOnlyFn[Fn[T, Any]]:
def map_only(__type_or_types_or_pred: Type[T]) -> MapOnlyFn[Fn[T, Any]]:
...
# This specialization is needed for the implementations below that call
@overload
def map_only(__type_or_types: TypeAny) -> MapOnlyFn[FnAny[Any]]:
def map_only(__type_or_types_or_pred: TypeAny) -> MapOnlyFn[FnAny[Any]]:
...
def map_only(__type_or_types: TypeAny) -> MapOnlyFn[FnAny[Any]]:
@overload
def map_only(__type_or_types_or_pred: Callable[[Any], bool]) -> MapOnlyFn[FnAny[Any]]:
...
def map_only(
__type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]]
) -> MapOnlyFn[FnAny[Any]]:
"""
Suppose you are writing a tree_map over tensors, leaving everything
else unchanged. Ordinarily you would have to write:
@ -986,11 +993,19 @@ def map_only(__type_or_types: TypeAny) -> MapOnlyFn[FnAny[Any]]:
You can also directly use 'tree_map_only'
"""
if isinstance(__type_or_types_or_pred, (tuple, type)):
return _map_only(lambda x: isinstance(x, __type_or_types_or_pred))
elif callable(__type_or_types_or_pred):
return _map_only(__type_or_types_or_pred)
else:
raise TypeError("Argument must be a type, a tuple of types, or a callable.")
def _map_only(pred: Callable[[Any], bool]) -> MapOnlyFn[FnAny[Any]]:
def wrapper(func: Callable[[T], Any]) -> Callable[[Any], Any]:
# @functools.wraps(func) # torch dynamo doesn't support this yet
def wrapped(x: T) -> Any:
if isinstance(x, __type_or_types):
if pred(x):
return func(x)
return x
@ -1001,7 +1016,7 @@ def map_only(__type_or_types: TypeAny) -> MapOnlyFn[FnAny[Any]]:
@overload
def tree_map_only(
__type_or_types: Type[T],
__type_or_types_or_pred: Type[T],
func: Fn[T, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
@ -1011,7 +1026,7 @@ def tree_map_only(
@overload
def tree_map_only(
__type_or_types: Type2[T, S],
__type_or_types_or_pred: Type2[T, S],
func: Fn2[T, S, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
@ -1021,7 +1036,7 @@ def tree_map_only(
@overload
def tree_map_only(
__type_or_types: Type3[T, S, U],
__type_or_types_or_pred: Type3[T, S, U],
func: Fn3[T, S, U, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
@ -1029,18 +1044,28 @@ def tree_map_only(
...
@overload
def tree_map_only(
__type_or_types: TypeAny,
__type_or_types_or_pred: Callable[[Any], bool],
func: FnAny[Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
return tree_map(map_only(__type_or_types)(func), tree, is_leaf=is_leaf)
...
def tree_map_only(
__type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
func: FnAny[Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
return tree_map(map_only(__type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
@overload
def tree_map_only_(
__type_or_types: Type[T],
__type_or_types_or_pred: Type[T],
func: Fn[T, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
@ -1050,7 +1075,7 @@ def tree_map_only_(
@overload
def tree_map_only_(
__type_or_types: Type2[T, S],
__type_or_types_or_pred: Type2[T, S],
func: Fn2[T, S, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
@ -1060,7 +1085,7 @@ def tree_map_only_(
@overload
def tree_map_only_(
__type_or_types: Type3[T, S, U],
__type_or_types_or_pred: Type3[T, S, U],
func: Fn3[T, S, U, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
@ -1069,12 +1094,12 @@ def tree_map_only_(
def tree_map_only_(
__type_or_types: TypeAny,
__type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
func: FnAny[Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
return tree_map_(map_only(__type_or_types)(func), tree, is_leaf=is_leaf)
return tree_map_(map_only(__type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
def tree_all(