mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[pytree][dynamo] trace on native optree functions for community pytree support (#165860)
Resolves #164972 - #164972 All `torch.utils._cxx_pytree` functions are based on `optree` functions with hardcoded `none_is_leaf=True` and `namespace="torch"`. This PR changes the polyfills to generic `optree` functions with those arguments unhardcoded. This means `torch.utils._cxx_pytree` functions are still traceable while the community `optree` usages can get dynamo support additionally. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165860 Approved by: https://github.com/Lucaskabela
This commit is contained in:
parent
410e6a4321
commit
1009790ad8
|
|
@ -424,7 +424,7 @@ from user code:
|
|||
@torch.compile(backend="eager")
|
||||
def fn(x):
|
||||
d = {"a": 1}
|
||||
optree.tree_flatten(d)
|
||||
optree.tree_flatten_with_path(d)
|
||||
return torch.sin(x)
|
||||
|
||||
fn(torch.randn(4))
|
||||
|
|
@ -434,10 +434,10 @@ from user code:
|
|||
first_graph_break,
|
||||
"""\
|
||||
Attempted to call function marked as skipped
|
||||
Explanation: Dynamo cannot trace optree C/C++ function optree._C.PyCapsule.flatten.
|
||||
Explanation: Dynamo cannot trace optree C/C++ function optree._C.PyCapsule.flatten_with_path.
|
||||
Hint: Consider using torch.utils._pytree - https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py
|
||||
|
||||
Developer debug context: module: optree._C, qualname: PyCapsule.flatten, skip reason: <missing reason>
|
||||
Developer debug context: module: optree._C, qualname: PyCapsule.flatten_with_path, skip reason: <missing reason>
|
||||
|
||||
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html""",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -110,6 +110,7 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
|||
import torch.utils._cxx_pytree as cxx_pytree
|
||||
|
||||
pytree_modules["cxx"] = cxx_pytree
|
||||
pytree_modules["native_optree"] = cxx_pytree.optree
|
||||
else:
|
||||
cxx_pytree = None
|
||||
|
||||
|
|
@ -12862,6 +12863,9 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
|
|||
def fn(xs):
|
||||
flat_xs, spec = pytree.tree_flatten(xs)
|
||||
res = [x.clone() for x in flat_xs]
|
||||
if pytree.__name__ == "optree":
|
||||
# The treespec argument comes first in OpTree / JAX PyTree
|
||||
return pytree.tree_unflatten(spec, res)
|
||||
return pytree.tree_unflatten(res, spec)
|
||||
|
||||
xs = [torch.tensor(i) for i in range(3)]
|
||||
|
|
@ -12876,6 +12880,9 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
|
|||
def fn(xs):
|
||||
flat_xs, spec = pytree.tree_flatten(xs)
|
||||
res = [x.clone() for x in flat_xs]
|
||||
if pytree.__name__ == "optree":
|
||||
# The treespec argument comes first in OpTree / JAX PyTree
|
||||
return pytree.tree_unflatten(spec, res)
|
||||
return pytree.tree_unflatten(res, spec)
|
||||
|
||||
xs = [torch.tensor(i) for i in range(3)]
|
||||
|
|
@ -12893,6 +12900,9 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
|
|||
def fn(xs):
|
||||
flat_xs, spec = pytree.tree_flatten(xs)
|
||||
res = [x.clone() for x in flat_xs]
|
||||
if pytree.__name__ == "optree":
|
||||
# The treespec argument comes first in OpTree / JAX PyTree
|
||||
return pytree.tree_unflatten(spec, res)
|
||||
return pytree.tree_unflatten(res, spec)
|
||||
|
||||
xs = [torch.tensor(i) for i in range(3)]
|
||||
|
|
@ -12910,6 +12920,9 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
|
|||
def fn(xs):
|
||||
flat_xs, spec = pytree.tree_flatten(xs)
|
||||
res = [x.clone() for x in flat_xs]
|
||||
if pytree.__name__ == "optree":
|
||||
# The treespec argument comes first in OpTree / JAX PyTree
|
||||
return pytree.tree_unflatten(spec, res)
|
||||
return pytree.tree_unflatten(res, spec)
|
||||
|
||||
xs = [torch.tensor(i) for i in range(3)]
|
||||
|
|
@ -12931,6 +12944,9 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
|
|||
def fn(xs):
|
||||
flat_xs, spec = pytree.tree_flatten(xs)
|
||||
res = [x.clone() for x in flat_xs]
|
||||
if pytree.__name__ == "optree":
|
||||
# The treespec argument comes first in OpTree / JAX PyTree
|
||||
return pytree.tree_unflatten(spec, res)
|
||||
return pytree.tree_unflatten(res, spec)
|
||||
|
||||
xs = [torch.tensor(i) for i in range(3)]
|
||||
|
|
@ -13032,7 +13048,13 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
|
|||
torch.ones(3, 2),
|
||||
1,
|
||||
]
|
||||
new_tree = pytree.tree_unflatten(new_leaves, treespec)
|
||||
if pytree.__name__ == "optree":
|
||||
# `None` is a internal node rather than leaf in default OpTree / JAX PyTree
|
||||
new_leaves.pop()
|
||||
# The treespec argument comes first in OpTree / JAX PyTree
|
||||
new_tree = pytree.tree_unflatten(treespec, new_leaves)
|
||||
else:
|
||||
new_tree = pytree.tree_unflatten(new_leaves, treespec)
|
||||
return leaves, new_tree
|
||||
|
||||
x = torch.randn(3, 2)
|
||||
|
|
@ -13087,6 +13109,10 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
|
|||
|
||||
@parametrize_pytree_module
|
||||
def test_pytree_tree_map_only(self, pytree):
|
||||
if not callable(getattr(pytree, "tree_map_only", None)):
|
||||
# OpTree and JAX PyTree do not have `tree_map_only`
|
||||
return
|
||||
|
||||
def fn(xs):
|
||||
def mapper(x):
|
||||
return x.clone()
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from __future__ import annotations
|
|||
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Literal, TYPE_CHECKING
|
||||
from typing import Any, Callable, TYPE_CHECKING
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
import torch.utils._pytree as python_pytree
|
||||
|
|
@ -28,7 +28,7 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
|||
import optree
|
||||
import optree._C
|
||||
|
||||
import torch.utils._cxx_pytree as cxx_pytree
|
||||
import torch.utils._cxx_pytree as cxx_pytree # noqa: F401
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch.utils._cxx_pytree import PyTree
|
||||
|
|
@ -64,45 +64,69 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
|||
del __func
|
||||
del __name
|
||||
|
||||
@substitute_in_graph(cxx_pytree.tree_is_leaf, can_constant_fold_through=True)
|
||||
@substitute_in_graph(optree.tree_is_leaf, can_constant_fold_through=True)
|
||||
def tree_is_leaf(
|
||||
tree: PyTree,
|
||||
/,
|
||||
is_leaf: Callable[[PyTree], bool] | None = None,
|
||||
*,
|
||||
none_is_leaf: bool = False,
|
||||
namespace: str = "",
|
||||
) -> bool:
|
||||
if tree is None or (is_leaf is not None and is_leaf(tree)):
|
||||
if (tree is None and none_is_leaf) or (is_leaf is not None and is_leaf(tree)):
|
||||
return True
|
||||
if optree.register_pytree_node.get(type(tree), namespace="torch") is None: # type: ignore[attr-defined]
|
||||
if optree.register_pytree_node.get(type(tree), namespace=namespace) is None: # type: ignore[attr-defined]
|
||||
return True
|
||||
return False
|
||||
|
||||
@substitute_in_graph(cxx_pytree.tree_iter, can_constant_fold_through=False)
|
||||
@substitute_in_graph(optree.tree_iter, can_constant_fold_through=False)
|
||||
def tree_iter(
|
||||
tree: PyTree,
|
||||
/,
|
||||
is_leaf: Callable[[PyTree], bool] | None = None,
|
||||
*,
|
||||
none_is_leaf: bool = False,
|
||||
namespace: str = "",
|
||||
) -> Iterable[Any]:
|
||||
stack = [tree]
|
||||
while stack:
|
||||
node = stack.pop()
|
||||
if tree_is_leaf(node, is_leaf=is_leaf):
|
||||
if tree_is_leaf(
|
||||
node,
|
||||
is_leaf=is_leaf,
|
||||
none_is_leaf=none_is_leaf,
|
||||
namespace=namespace,
|
||||
):
|
||||
yield node
|
||||
continue
|
||||
|
||||
children, *_ = optree.tree_flatten_one_level(
|
||||
node,
|
||||
is_leaf=is_leaf,
|
||||
none_is_leaf=True,
|
||||
namespace="torch",
|
||||
none_is_leaf=none_is_leaf,
|
||||
namespace=namespace,
|
||||
)
|
||||
stack.extend(reversed(children))
|
||||
|
||||
__all__ += ["tree_iter"]
|
||||
|
||||
@substitute_in_graph(cxx_pytree.tree_leaves, can_constant_fold_through=True)
|
||||
@substitute_in_graph(optree.tree_leaves, can_constant_fold_through=True)
|
||||
def tree_leaves(
|
||||
tree: PyTree,
|
||||
/,
|
||||
is_leaf: Callable[[PyTree], bool] | None = None,
|
||||
*,
|
||||
none_is_leaf: bool = False,
|
||||
namespace: str = "",
|
||||
) -> list[Any]:
|
||||
return list(tree_iter(tree, is_leaf=is_leaf))
|
||||
return list(
|
||||
tree_iter(
|
||||
tree,
|
||||
is_leaf=is_leaf,
|
||||
none_is_leaf=none_is_leaf,
|
||||
namespace=namespace,
|
||||
)
|
||||
)
|
||||
|
||||
__all__ += ["tree_leaves"]
|
||||
|
||||
|
|
@ -127,12 +151,12 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
|||
_metadata: Any
|
||||
_entries: tuple[Any, ...]
|
||||
_unflatten_func: Callable[[Any | None, Iterable[PyTree]], PyTree] | None
|
||||
none_is_leaf: bool
|
||||
namespace: str
|
||||
|
||||
num_nodes: int = field(init=False)
|
||||
num_leaves: int = field(init=False)
|
||||
num_children: int = field(init=False)
|
||||
none_is_leaf: Literal[True] = field(init=False)
|
||||
namespace: Literal["torch"] = field(init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self._type is None:
|
||||
|
|
@ -152,8 +176,6 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
|||
object.__setattr__(self, "num_nodes", num_nodes)
|
||||
object.__setattr__(self, "num_leaves", num_leaves)
|
||||
object.__setattr__(self, "num_children", num_children)
|
||||
object.__setattr__(self, "none_is_leaf", True)
|
||||
object.__setattr__(self, "namespace", "torch")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
def helper(treespec: PyTreeSpec) -> str:
|
||||
|
|
@ -168,6 +190,7 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
|||
]
|
||||
if (
|
||||
treespec.type in BUILTIN_TYPES
|
||||
or (treespec.type is type(None) and not self.none_is_leaf)
|
||||
or optree.is_namedtuple_class(treespec.type)
|
||||
or optree.is_structseq_class(treespec.type)
|
||||
):
|
||||
|
|
@ -181,9 +204,12 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
|||
f"[{', '.join(children_representations)}])"
|
||||
)
|
||||
|
||||
return (
|
||||
f"PyTreeSpec({helper(self)}, NoneIsLeaf, namespace={self.namespace!r})"
|
||||
)
|
||||
inner = [
|
||||
str(helper(self)),
|
||||
*(["NoneIsLeaf"] if self.none_is_leaf else []),
|
||||
f"namespace={self.namespace!r}",
|
||||
]
|
||||
return f"PyTreeSpec({', '.join(inner)})"
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.num_leaves
|
||||
|
|
@ -228,8 +254,8 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
|||
|
||||
children, metadata, *_ = optree.tree_flatten_one_level(
|
||||
node,
|
||||
none_is_leaf=True,
|
||||
namespace="torch",
|
||||
none_is_leaf=self.none_is_leaf,
|
||||
namespace=self.namespace,
|
||||
)
|
||||
if len(children) != treespec.num_children:
|
||||
raise ValueError(
|
||||
|
|
@ -277,8 +303,8 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
|||
# node_type is treespec.type
|
||||
children, metadata, *_ = optree.tree_flatten_one_level(
|
||||
node,
|
||||
none_is_leaf=True,
|
||||
namespace="torch",
|
||||
none_is_leaf=self.none_is_leaf,
|
||||
namespace=self.namespace,
|
||||
)
|
||||
if (
|
||||
node_type
|
||||
|
|
@ -320,25 +346,40 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
|||
assert callable(self._unflatten_func)
|
||||
return self._unflatten_func(self._metadata, subtrees)
|
||||
|
||||
_LEAF_SPEC = PyTreeSpec((), None, None, (), None)
|
||||
|
||||
def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]:
|
||||
return isinstance(obj, PyTreeSpec)
|
||||
|
||||
@substitute_in_graph( # type: ignore[arg-type]
|
||||
cxx_pytree.tree_flatten,
|
||||
optree.tree_flatten,
|
||||
# We need to disable constant folding here because we want the function to reference the
|
||||
# PyTreeSpec class defined above, not the one in the C++ module.
|
||||
can_constant_fold_through=False,
|
||||
)
|
||||
def tree_flatten(
|
||||
tree: PyTree,
|
||||
/,
|
||||
is_leaf: Callable[[PyTree], bool] | None = None,
|
||||
*,
|
||||
none_is_leaf: bool = False,
|
||||
namespace: str = "",
|
||||
) -> tuple[list[Any], PyTreeSpec]:
|
||||
def helper(node: PyTree, leaves: list[Any]) -> PyTreeSpec:
|
||||
if tree_is_leaf(node, is_leaf=is_leaf):
|
||||
if tree_is_leaf(
|
||||
node,
|
||||
is_leaf=is_leaf,
|
||||
none_is_leaf=none_is_leaf,
|
||||
namespace=namespace,
|
||||
):
|
||||
leaves.append(node)
|
||||
return _LEAF_SPEC
|
||||
return PyTreeSpec(
|
||||
(),
|
||||
None,
|
||||
None,
|
||||
(),
|
||||
None,
|
||||
none_is_leaf=none_is_leaf,
|
||||
namespace=namespace,
|
||||
)
|
||||
|
||||
(
|
||||
children,
|
||||
|
|
@ -348,13 +389,21 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
|||
) = optree.tree_flatten_one_level(
|
||||
node,
|
||||
is_leaf=is_leaf,
|
||||
none_is_leaf=True,
|
||||
namespace="torch",
|
||||
none_is_leaf=none_is_leaf,
|
||||
namespace=namespace,
|
||||
)
|
||||
|
||||
# Recursively flatten the children
|
||||
subspecs = tuple(helper(child, leaves) for child in children)
|
||||
return PyTreeSpec(subspecs, type(node), metadata, entries, unflatten_func) # type: ignore[arg-type]
|
||||
return PyTreeSpec(
|
||||
subspecs,
|
||||
type(node),
|
||||
metadata,
|
||||
entries,
|
||||
unflatten_func,
|
||||
none_is_leaf=none_is_leaf,
|
||||
namespace=namespace,
|
||||
) # type: ignore[arg-type]
|
||||
|
||||
leaves: list[Any] = []
|
||||
treespec = helper(tree, leaves)
|
||||
|
|
@ -363,26 +412,35 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
|||
__all__ += ["tree_flatten"]
|
||||
|
||||
@substitute_in_graph( # type: ignore[arg-type]
|
||||
cxx_pytree.tree_structure,
|
||||
optree.tree_structure,
|
||||
# We need to disable constant folding here because we want the function to reference the
|
||||
# PyTreeSpec class defined above, not the one in the C++ module.
|
||||
can_constant_fold_through=False,
|
||||
)
|
||||
def tree_structure(
|
||||
tree: PyTree,
|
||||
/,
|
||||
is_leaf: Callable[[PyTree], bool] | None = None,
|
||||
*,
|
||||
none_is_leaf: bool = False,
|
||||
namespace: str = "",
|
||||
) -> PyTreeSpec:
|
||||
return tree_flatten(tree, is_leaf=is_leaf)[1] # type: ignore[return-value]
|
||||
return tree_flatten( # type: ignore[return-value]
|
||||
tree,
|
||||
is_leaf=is_leaf,
|
||||
none_is_leaf=none_is_leaf,
|
||||
namespace=namespace,
|
||||
)[1]
|
||||
|
||||
__all__ += ["tree_structure"]
|
||||
|
||||
@substitute_in_graph( # type: ignore[arg-type]
|
||||
cxx_pytree.tree_unflatten,
|
||||
optree.tree_unflatten,
|
||||
# We need to disable constant folding here because we want the function to reference the
|
||||
# PyTreeSpec class defined above, not the one in the C++ module.
|
||||
can_constant_fold_through=False,
|
||||
)
|
||||
def tree_unflatten(leaves: Iterable[Any], treespec: PyTreeSpec) -> PyTree:
|
||||
def tree_unflatten(treespec: PyTreeSpec, leaves: Iterable[Any]) -> PyTree:
|
||||
if not _is_pytreespec_instance(treespec):
|
||||
raise TypeError(
|
||||
f"tree_unflatten(leaves, treespec): Expected `treespec` to be instance of "
|
||||
|
|
@ -392,29 +450,57 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
|||
|
||||
__all__ += ["tree_unflatten"]
|
||||
|
||||
@substitute_in_graph(cxx_pytree.tree_map, can_constant_fold_through=True)
|
||||
@substitute_in_graph(optree.tree_map, can_constant_fold_through=True)
|
||||
def tree_map(
|
||||
func: Callable[..., Any],
|
||||
tree: PyTree,
|
||||
/,
|
||||
*rests: PyTree,
|
||||
is_leaf: Callable[[PyTree], bool] | None = None,
|
||||
none_is_leaf: bool = False,
|
||||
namespace: str = "",
|
||||
) -> PyTree:
|
||||
leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
|
||||
leaves, treespec = tree_flatten(
|
||||
tree,
|
||||
is_leaf=is_leaf,
|
||||
none_is_leaf=none_is_leaf,
|
||||
namespace=namespace,
|
||||
)
|
||||
flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
|
||||
return treespec.unflatten(map(func, *flat_args))
|
||||
|
||||
__all__ += ["tree_map"]
|
||||
|
||||
@substitute_in_graph(cxx_pytree.tree_map_, can_constant_fold_through=True)
|
||||
@substitute_in_graph(optree.tree_map_, can_constant_fold_through=True)
|
||||
def tree_map_(
|
||||
func: Callable[..., Any],
|
||||
tree: PyTree,
|
||||
/,
|
||||
*rests: PyTree,
|
||||
is_leaf: Callable[[PyTree], bool] | None = None,
|
||||
none_is_leaf: bool = False,
|
||||
namespace: str = "",
|
||||
) -> PyTree:
|
||||
leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
|
||||
leaves, treespec = tree_flatten(
|
||||
tree,
|
||||
is_leaf=is_leaf,
|
||||
none_is_leaf=none_is_leaf,
|
||||
namespace=namespace,
|
||||
)
|
||||
flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
|
||||
deque(map(func, *flat_args), maxlen=0) # consume and exhaust the iterable
|
||||
return tree
|
||||
|
||||
__all__ += ["tree_map_"]
|
||||
|
||||
_none_unflatten = optree.register_pytree_node.get(type(None)).unflatten_func # type: ignore[union-attr]
|
||||
|
||||
@substitute_in_graph( # type: ignore[arg-type]
|
||||
_none_unflatten,
|
||||
can_constant_fold_through=True,
|
||||
skip_signature_check=True,
|
||||
)
|
||||
def none_unflatten(_: None, children: Iterable[Any], /) -> None:
|
||||
if len(list(children)) != 0:
|
||||
raise ValueError("Expected no children.")
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -367,11 +367,6 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
|
|||
The reconstructed pytree, containing the ``leaves`` placed in the structure described by
|
||||
``treespec``.
|
||||
"""
|
||||
if not _is_pytreespec_instance(treespec):
|
||||
raise TypeError(
|
||||
f"tree_unflatten(leaves, treespec): Expected `treespec` to be instance of "
|
||||
f"PyTreeSpec but got item of type {type(treespec)}."
|
||||
)
|
||||
return optree.tree_unflatten(treespec, leaves) # type: ignore[arg-type]
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user