mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[pytree] expand tree_map to accept multi-inputs (#115642)
Fixes #115419 Fixes #91323 Closes #115549 - #115419 - #91323 Pull Request resolved: https://github.com/pytorch/pytorch/pull/115642 Approved by: https://github.com/vmoens, https://github.com/zou3519
This commit is contained in:
parent
7e1542b938
commit
36c6c0c7dc
|
|
@ -502,7 +502,7 @@ class TestGenericPytree(TestCase):
|
|||
subtest(cxx_pytree, name="cxx"),
|
||||
],
|
||||
)
|
||||
def test_treemap(self, pytree_impl):
|
||||
def test_tree_map(self, pytree_impl):
|
||||
def run_test(pytree):
|
||||
def f(x):
|
||||
return x * 3
|
||||
|
|
@ -536,7 +536,40 @@ class TestGenericPytree(TestCase):
|
|||
subtest(cxx_pytree, name="cxx"),
|
||||
],
|
||||
)
|
||||
def test_tree_only(self, pytree_impl):
|
||||
def test_tree_map_multi_inputs(self, pytree_impl):
|
||||
def run_test(pytree):
|
||||
def f(x, y, z):
|
||||
return x, [y, (z, 0)]
|
||||
|
||||
pytree_x = pytree
|
||||
pytree_y = pytree_impl.tree_map(lambda x: (x + 1,), pytree)
|
||||
pytree_z = pytree_impl.tree_map(lambda x: {"a": x * 2, "b": 2}, pytree)
|
||||
|
||||
self.assertEqual(
|
||||
pytree_impl.tree_map(f, pytree_x, pytree_y, pytree_z),
|
||||
pytree_impl.tree_map(
|
||||
lambda x: f(x, (x + 1,), {"a": x * 2, "b": 2}), pytree
|
||||
),
|
||||
)
|
||||
|
||||
cases = [
|
||||
[()],
|
||||
([],),
|
||||
{"a": ()},
|
||||
{"a": 1, "b": [{"c": 2}]},
|
||||
{"a": 0, "b": [2, {"c": 3}, 4], "c": (5, 6)},
|
||||
]
|
||||
for case in cases:
|
||||
run_test(case)
|
||||
|
||||
@parametrize(
|
||||
"pytree_impl",
|
||||
[
|
||||
subtest(py_pytree, name="py"),
|
||||
subtest(cxx_pytree, name="cxx"),
|
||||
],
|
||||
)
|
||||
def test_tree_map_only(self, pytree_impl):
|
||||
self.assertEqual(
|
||||
pytree_impl.tree_map_only(int, lambda x: x + 2, [0, "a"]), [2, "a"]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -345,8 +345,8 @@ def tree_structure(tree: PyTree) -> TreeSpec:
|
|||
)
|
||||
|
||||
|
||||
def tree_map(func: Callable[..., Any], tree: PyTree) -> PyTree:
|
||||
"""Map a function over leaves in a pytree to produce a new pytree.
|
||||
def tree_map(func: Callable[..., Any], tree: PyTree, *rests: PyTree) -> PyTree:
|
||||
"""Map a multi-input function over pytree args to produce a new pytree.
|
||||
|
||||
See also :func:`tree_map_`.
|
||||
|
||||
|
|
@ -355,43 +355,56 @@ def tree_map(func: Callable[..., Any], tree: PyTree) -> PyTree:
|
|||
>>> tree_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None})
|
||||
{'x': False, 'y': (False, False), 'z': True}
|
||||
|
||||
If multiple inputs are given, the structure of the tree is taken from the first input;
|
||||
subsequent inputs need only have ``tree`` as a prefix:
|
||||
|
||||
>>> tree_map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]])
|
||||
[[5, 7, 9], [6, 1, 2]]
|
||||
|
||||
Args:
|
||||
func (callable): A function that takes a single argument, to be applied at the corresponding
|
||||
leaves of the pytree.
|
||||
tree (pytree): A pytree to be mapped over, with each leaf providing the argument to function
|
||||
``func``.
|
||||
func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the
|
||||
corresponding leaves of the pytrees.
|
||||
tree (pytree): A pytree to be mapped over, with each leaf providing the first positional
|
||||
argument to function ``func``.
|
||||
rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
|
||||
``tree`` or has ``tree`` as a prefix.
|
||||
|
||||
Returns:
|
||||
A new pytree with the same structure as ``tree`` but with the value at each leaf given by
|
||||
``func(x)`` where ``x`` is the value at the corresponding leaf in ``tree``.
|
||||
``func(x, *xs)`` where ``x`` is the value at the corresponding leaf in ``tree`` and ``xs``
|
||||
is the tuple of values at corresponding nodes in ``rests``.
|
||||
"""
|
||||
return optree.tree_map(
|
||||
func,
|
||||
tree,
|
||||
*rests,
|
||||
none_is_leaf=True,
|
||||
namespace="torch",
|
||||
)
|
||||
|
||||
|
||||
def tree_map_(func: Callable[..., Any], tree: PyTree) -> PyTree:
|
||||
def tree_map_(func: Callable[..., Any], tree: PyTree, *rests: PyTree) -> PyTree:
|
||||
"""Like :func:`tree_map`, but do an inplace call on each leaf and return the original tree.
|
||||
|
||||
See also :func:`tree_map`.
|
||||
|
||||
Args:
|
||||
func (callable): A function that takes a single argument, to be applied at the corresponding
|
||||
leaves of the pytree.
|
||||
tree (pytree): A pytree to be mapped over, with each leaf providing the argument to function
|
||||
``func``.
|
||||
func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the
|
||||
corresponding leaves of the pytrees.
|
||||
tree (pytree): A pytree to be mapped over, with each leaf providing the first positional
|
||||
argument to function ``func``.
|
||||
rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
|
||||
``tree`` or has ``tree`` as a prefix.
|
||||
|
||||
Returns:
|
||||
The original ``tree`` with the value at each leaf is given by the side-effect of function
|
||||
``func(x)`` (not the return value) where ``x`` is the value at the corresponding leaf in
|
||||
``tree``.
|
||||
``func(x, *xs)`` (not the return value) where ``x`` is the value at the corresponding leaf
|
||||
in ``tree`` and ``xs`` is the tuple of values at values at corresponding nodes in ``rests``.
|
||||
"""
|
||||
return optree.tree_map_(
|
||||
func,
|
||||
tree,
|
||||
*rests,
|
||||
none_is_leaf=True,
|
||||
namespace="torch",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ from typing import (
|
|||
DefaultDict,
|
||||
Deque,
|
||||
Dict,
|
||||
FrozenSet,
|
||||
Iterable,
|
||||
List,
|
||||
NamedTuple,
|
||||
|
|
@ -442,6 +443,14 @@ _private_register_pytree_node(
|
|||
)
|
||||
|
||||
|
||||
STANDARD_DICT_TYPES: FrozenSet[type] = frozenset(
|
||||
{dict, OrderedDict, defaultdict},
|
||||
)
|
||||
BUILTIN_TYPES: FrozenSet[type] = frozenset(
|
||||
{tuple, list, dict, namedtuple, OrderedDict, defaultdict, deque}, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
|
||||
# h/t https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple
|
||||
def _is_namedtuple_instance(tree: Any) -> bool:
|
||||
typ = type(tree)
|
||||
|
|
@ -476,8 +485,14 @@ class TreeSpec:
|
|||
context: Context
|
||||
children_specs: List["TreeSpec"]
|
||||
|
||||
num_nodes: int = dataclasses.field(init=False)
|
||||
num_leaves: int = dataclasses.field(init=False)
|
||||
num_children: int = dataclasses.field(init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.num_leaves: int = sum([spec.num_leaves for spec in self.children_specs])
|
||||
self.num_nodes = 1 + sum(spec.num_nodes for spec in self.children_specs)
|
||||
self.num_leaves = sum(spec.num_leaves for spec in self.children_specs)
|
||||
self.num_children = len(self.children_specs)
|
||||
|
||||
def __repr__(self, indent: int = 0) -> str:
|
||||
repr_prefix: str = f"TreeSpec({self.type.__name__}, {self.context}, ["
|
||||
|
|
@ -496,13 +511,121 @@ class TreeSpec:
|
|||
return repr_prefix + repr_suffix
|
||||
|
||||
def is_leaf(self) -> bool:
|
||||
return isinstance(self, LeafSpec)
|
||||
return self.num_nodes == 1 and self.num_leaves == 1
|
||||
|
||||
def _flatten_up_to_helper(self, tree: PyTree, subtrees: List[PyTree]) -> None:
|
||||
if self.is_leaf():
|
||||
subtrees.append(tree)
|
||||
return
|
||||
|
||||
node_type = _get_node_type(tree)
|
||||
if self.type not in BUILTIN_TYPES:
|
||||
# Always require custom node types to match exactly
|
||||
if node_type != self.type:
|
||||
raise ValueError(
|
||||
f"Type mismatch; "
|
||||
f"expected {self.type!r}, but got {node_type!r}.",
|
||||
)
|
||||
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
|
||||
child_pytrees, context = flatten_fn(tree)
|
||||
if len(child_pytrees) != self.num_children:
|
||||
raise ValueError(
|
||||
f"Node arity mismatch; "
|
||||
f"expected {self.num_children}, but got {len(child_pytrees)}.",
|
||||
)
|
||||
if context != self.context:
|
||||
raise ValueError(
|
||||
f"Node context mismatch for custom node type {self.type!r}.",
|
||||
)
|
||||
else:
|
||||
# For builtin dictionary types, we allow some flexibility
|
||||
# Otherwise, we require exact matches
|
||||
both_standard_dict = (
|
||||
self.type in STANDARD_DICT_TYPES and node_type in STANDARD_DICT_TYPES
|
||||
)
|
||||
if node_type != self.type and not both_standard_dict:
|
||||
raise ValueError(
|
||||
f"Node type mismatch; "
|
||||
f"expected {self.type!r}, but got {node_type!r}.",
|
||||
)
|
||||
if len(tree) != self.num_children:
|
||||
raise ValueError(
|
||||
f"Node arity mismatch; "
|
||||
f"expected {self.num_children}, but got {len(tree)}.",
|
||||
)
|
||||
|
||||
if both_standard_dict: # dictionary types are compatible with each other
|
||||
dict_context = (
|
||||
self.context
|
||||
if self.type is not defaultdict
|
||||
# ignore mismatch of `default_factory` for defaultdict
|
||||
else self.context[1]
|
||||
)
|
||||
expected_keys = dict_context
|
||||
got_key_set = set(tree)
|
||||
expected_key_set = set(expected_keys)
|
||||
if got_key_set != expected_key_set:
|
||||
missing_keys = expected_key_set.difference(got_key_set)
|
||||
extra_keys = got_key_set.difference(expected_key_set)
|
||||
message = ""
|
||||
if missing_keys:
|
||||
message += f"; missing key(s): {missing_keys}"
|
||||
if extra_keys:
|
||||
message += f"; extra key(s): {extra_keys}"
|
||||
raise ValueError(f"Node keys mismatch{message}.")
|
||||
child_pytrees = [tree[key] for key in expected_keys]
|
||||
else:
|
||||
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
|
||||
child_pytrees, context = flatten_fn(tree)
|
||||
if (
|
||||
context != self.context
|
||||
and self.type is not deque # ignore mismatch of `maxlen` for deque
|
||||
):
|
||||
raise ValueError(
|
||||
f"Node context mismatch for node type {self.type!r}; "
|
||||
f"expected {self.context!r}, but got {context!r}.", # namedtuple type mismatch
|
||||
)
|
||||
|
||||
for child_pytree, child_spec in zip(child_pytrees, self.children_specs):
|
||||
child_spec._flatten_up_to_helper(child_pytree, subtrees)
|
||||
|
||||
def flatten_up_to(self, tree: PyTree) -> List[PyTree]:
|
||||
subtrees: List[PyTree] = []
|
||||
self._flatten_up_to_helper(tree, subtrees)
|
||||
return subtrees
|
||||
|
||||
def unflatten(self, leaves: Iterable[Any]) -> PyTree:
|
||||
if not isinstance(leaves, (list, tuple)):
|
||||
leaves = list(leaves)
|
||||
if len(leaves) != self.num_leaves:
|
||||
raise ValueError(
|
||||
f"treespec.unflatten(leaves): `leaves` has length {len(leaves)} "
|
||||
f"but the spec refers to a pytree that holds {self.num_leaves} "
|
||||
f"items ({self}).",
|
||||
)
|
||||
if self.is_leaf():
|
||||
return leaves[0]
|
||||
|
||||
unflatten_fn = SUPPORTED_NODES[self.type].unflatten_fn
|
||||
|
||||
# Recursively unflatten the children
|
||||
start = 0
|
||||
end = 0
|
||||
child_pytrees = []
|
||||
for child_spec in self.children_specs:
|
||||
end += child_spec.num_leaves
|
||||
child_pytrees.append(child_spec.unflatten(leaves[start:end]))
|
||||
start = end
|
||||
|
||||
return unflatten_fn(child_pytrees, self.context)
|
||||
|
||||
|
||||
class LeafSpec(TreeSpec):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(None, None, [])
|
||||
self.num_nodes = 1
|
||||
self.num_leaves = 1
|
||||
self.num_children = 0
|
||||
|
||||
def __repr__(self, indent: int = 0) -> str:
|
||||
return "*"
|
||||
|
|
@ -546,29 +669,7 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
|
|||
f"tree_unflatten(leaves, treespec): Expected `treespec` to be "
|
||||
f"instance of TreeSpec but got item of type {type(treespec)}.",
|
||||
)
|
||||
if not isinstance(leaves, (list, tuple)):
|
||||
leaves = list(leaves)
|
||||
if len(leaves) != treespec.num_leaves:
|
||||
raise ValueError(
|
||||
f"tree_unflatten(leaves, treespec): `leaves` has length {len(leaves)} "
|
||||
f"but the spec refers to a pytree that holds {treespec.num_leaves} "
|
||||
f"items ({treespec}).",
|
||||
)
|
||||
if isinstance(treespec, LeafSpec):
|
||||
return leaves[0]
|
||||
|
||||
unflatten_fn = SUPPORTED_NODES[treespec.type].unflatten_fn
|
||||
|
||||
# Recursively unflatten the children
|
||||
start = 0
|
||||
end = 0
|
||||
child_pytrees = []
|
||||
for child_spec in treespec.children_specs:
|
||||
end += child_spec.num_leaves
|
||||
child_pytrees.append(tree_unflatten(leaves[start:end], child_spec))
|
||||
start = end
|
||||
|
||||
return unflatten_fn(child_pytrees, treespec.context)
|
||||
return treespec.unflatten(leaves)
|
||||
|
||||
|
||||
def _tree_leaves_helper(tree: PyTree, leaves: List[Any]) -> None:
|
||||
|
|
@ -597,14 +698,61 @@ def tree_structure(tree: PyTree) -> TreeSpec:
|
|||
return tree_flatten(tree)[1]
|
||||
|
||||
|
||||
def tree_map(func: Callable[..., Any], tree: PyTree) -> PyTree:
|
||||
flat_args, spec = tree_flatten(tree)
|
||||
return tree_unflatten([func(i) for i in flat_args], spec)
|
||||
def tree_map(func: Callable[..., Any], tree: PyTree, *rests: PyTree) -> PyTree:
|
||||
"""Map a multi-input function over pytree args to produce a new pytree.
|
||||
|
||||
See also :func:`tree_map_`.
|
||||
|
||||
>>> tree_map(lambda x: x + 1, {'x': 7, 'y': (42, 64)})
|
||||
{'x': 8, 'y': (43, 65)}
|
||||
>>> tree_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None})
|
||||
{'x': False, 'y': (False, False), 'z': True}
|
||||
|
||||
If multiple inputs are given, the structure of the tree is taken from the first input;
|
||||
subsequent inputs need only have ``tree`` as a prefix:
|
||||
|
||||
>>> tree_map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]])
|
||||
[[5, 7, 9], [6, 1, 2]]
|
||||
|
||||
Args:
|
||||
func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the
|
||||
corresponding leaves of the pytrees.
|
||||
tree (pytree): A pytree to be mapped over, with each leaf providing the first positional
|
||||
argument to function ``func``.
|
||||
rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
|
||||
``tree`` or has ``tree`` as a prefix.
|
||||
|
||||
Returns:
|
||||
A new pytree with the same structure as ``tree`` but with the value at each leaf given by
|
||||
``func(x, *xs)`` where ``x`` is the value at the corresponding leaf in ``tree`` and ``xs``
|
||||
is the tuple of values at corresponding nodes in ``rests``.
|
||||
"""
|
||||
leaves, treespec = tree_flatten(tree)
|
||||
flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
|
||||
return treespec.unflatten(map(func, *flat_args))
|
||||
|
||||
|
||||
def tree_map_(func: Callable[..., Any], tree: PyTree) -> PyTree:
|
||||
flat_args = tree_leaves(tree)
|
||||
deque(map(func, flat_args), maxlen=0) # consume and exhaust the iterable
|
||||
def tree_map_(func: Callable[..., Any], tree: PyTree, *rests: PyTree) -> PyTree:
|
||||
"""Like :func:`tree_map`, but do an inplace call on each leaf and return the original tree.
|
||||
|
||||
See also :func:`tree_map`.
|
||||
|
||||
Args:
|
||||
func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the
|
||||
corresponding leaves of the pytrees.
|
||||
tree (pytree): A pytree to be mapped over, with each leaf providing the first positional
|
||||
argument to function ``func``.
|
||||
rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
|
||||
``tree`` or has ``tree`` as a prefix.
|
||||
|
||||
Returns:
|
||||
The original ``tree`` with the value at each leaf is given by the side-effect of function
|
||||
``func(x, *xs)`` (not the return value) where ``x`` is the value at the corresponding leaf
|
||||
in ``tree`` and ``xs`` is the tuple of values at values at corresponding nodes in ``rests``.
|
||||
"""
|
||||
leaves, treespec = tree_flatten(tree)
|
||||
flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
|
||||
deque(map(func, *flat_args), maxlen=0) # consume and exhaust the iterable
|
||||
return tree
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user