[pytree][dynamo] trace on native optree functions for community pytree support (#165860)

Resolves #164972

- #164972

All `torch.utils._cxx_pytree` functions are based on `optree` functions with hardcoded `none_is_leaf=True` and `namespace="torch"`. This PR changes the polyfills to generic `optree` functions with those arguments unhardcoded. This means `torch.utils._cxx_pytree` functions are still traceable while the community `optree` usages can get dynamo support additionally.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165860
Approved by: https://github.com/Lucaskabela
This commit is contained in:
Xuehai Pan 2025-10-21 14:13:08 +00:00 committed by PyTorch MergeBot
parent 410e6a4321
commit 1009790ad8
4 changed files with 154 additions and 47 deletions

View File

@ -424,7 +424,7 @@ from user code:
@torch.compile(backend="eager") @torch.compile(backend="eager")
def fn(x): def fn(x):
d = {"a": 1} d = {"a": 1}
optree.tree_flatten(d) optree.tree_flatten_with_path(d)
return torch.sin(x) return torch.sin(x)
fn(torch.randn(4)) fn(torch.randn(4))
@ -434,10 +434,10 @@ from user code:
first_graph_break, first_graph_break,
"""\ """\
Attempted to call function marked as skipped Attempted to call function marked as skipped
Explanation: Dynamo cannot trace optree C/C++ function optree._C.PyCapsule.flatten. Explanation: Dynamo cannot trace optree C/C++ function optree._C.PyCapsule.flatten_with_path.
Hint: Consider using torch.utils._pytree - https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py Hint: Consider using torch.utils._pytree - https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py
Developer debug context: module: optree._C, qualname: PyCapsule.flatten, skip reason: <missing reason> Developer debug context: module: optree._C, qualname: PyCapsule.flatten_with_path, skip reason: <missing reason>
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html""", For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html""",
) )

View File

@ -110,6 +110,7 @@ if python_pytree._cxx_pytree_dynamo_traceable:
import torch.utils._cxx_pytree as cxx_pytree import torch.utils._cxx_pytree as cxx_pytree
pytree_modules["cxx"] = cxx_pytree pytree_modules["cxx"] = cxx_pytree
pytree_modules["native_optree"] = cxx_pytree.optree
else: else:
cxx_pytree = None cxx_pytree = None
@ -12862,6 +12863,9 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
def fn(xs): def fn(xs):
flat_xs, spec = pytree.tree_flatten(xs) flat_xs, spec = pytree.tree_flatten(xs)
res = [x.clone() for x in flat_xs] res = [x.clone() for x in flat_xs]
if pytree.__name__ == "optree":
# The treespec argument comes first in OpTree / JAX PyTree
return pytree.tree_unflatten(spec, res)
return pytree.tree_unflatten(res, spec) return pytree.tree_unflatten(res, spec)
xs = [torch.tensor(i) for i in range(3)] xs = [torch.tensor(i) for i in range(3)]
@ -12876,6 +12880,9 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
def fn(xs): def fn(xs):
flat_xs, spec = pytree.tree_flatten(xs) flat_xs, spec = pytree.tree_flatten(xs)
res = [x.clone() for x in flat_xs] res = [x.clone() for x in flat_xs]
if pytree.__name__ == "optree":
# The treespec argument comes first in OpTree / JAX PyTree
return pytree.tree_unflatten(spec, res)
return pytree.tree_unflatten(res, spec) return pytree.tree_unflatten(res, spec)
xs = [torch.tensor(i) for i in range(3)] xs = [torch.tensor(i) for i in range(3)]
@ -12893,6 +12900,9 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
def fn(xs): def fn(xs):
flat_xs, spec = pytree.tree_flatten(xs) flat_xs, spec = pytree.tree_flatten(xs)
res = [x.clone() for x in flat_xs] res = [x.clone() for x in flat_xs]
if pytree.__name__ == "optree":
# The treespec argument comes first in OpTree / JAX PyTree
return pytree.tree_unflatten(spec, res)
return pytree.tree_unflatten(res, spec) return pytree.tree_unflatten(res, spec)
xs = [torch.tensor(i) for i in range(3)] xs = [torch.tensor(i) for i in range(3)]
@ -12910,6 +12920,9 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
def fn(xs): def fn(xs):
flat_xs, spec = pytree.tree_flatten(xs) flat_xs, spec = pytree.tree_flatten(xs)
res = [x.clone() for x in flat_xs] res = [x.clone() for x in flat_xs]
if pytree.__name__ == "optree":
# The treespec argument comes first in OpTree / JAX PyTree
return pytree.tree_unflatten(spec, res)
return pytree.tree_unflatten(res, spec) return pytree.tree_unflatten(res, spec)
xs = [torch.tensor(i) for i in range(3)] xs = [torch.tensor(i) for i in range(3)]
@ -12931,6 +12944,9 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
def fn(xs): def fn(xs):
flat_xs, spec = pytree.tree_flatten(xs) flat_xs, spec = pytree.tree_flatten(xs)
res = [x.clone() for x in flat_xs] res = [x.clone() for x in flat_xs]
if pytree.__name__ == "optree":
# The treespec argument comes first in OpTree / JAX PyTree
return pytree.tree_unflatten(spec, res)
return pytree.tree_unflatten(res, spec) return pytree.tree_unflatten(res, spec)
xs = [torch.tensor(i) for i in range(3)] xs = [torch.tensor(i) for i in range(3)]
@ -13032,7 +13048,13 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
torch.ones(3, 2), torch.ones(3, 2),
1, 1,
] ]
new_tree = pytree.tree_unflatten(new_leaves, treespec) if pytree.__name__ == "optree":
# `None` is a internal node rather than leaf in default OpTree / JAX PyTree
new_leaves.pop()
# The treespec argument comes first in OpTree / JAX PyTree
new_tree = pytree.tree_unflatten(treespec, new_leaves)
else:
new_tree = pytree.tree_unflatten(new_leaves, treespec)
return leaves, new_tree return leaves, new_tree
x = torch.randn(3, 2) x = torch.randn(3, 2)
@ -13087,6 +13109,10 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
@parametrize_pytree_module @parametrize_pytree_module
def test_pytree_tree_map_only(self, pytree): def test_pytree_tree_map_only(self, pytree):
if not callable(getattr(pytree, "tree_map_only", None)):
# OpTree and JAX PyTree do not have `tree_map_only`
return
def fn(xs): def fn(xs):
def mapper(x): def mapper(x):
return x.clone() return x.clone()

View File

@ -6,7 +6,7 @@ from __future__ import annotations
from collections import deque from collections import deque
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Callable, Literal, TYPE_CHECKING from typing import Any, Callable, TYPE_CHECKING
from typing_extensions import TypeIs from typing_extensions import TypeIs
import torch.utils._pytree as python_pytree import torch.utils._pytree as python_pytree
@ -28,7 +28,7 @@ if python_pytree._cxx_pytree_dynamo_traceable:
import optree import optree
import optree._C import optree._C
import torch.utils._cxx_pytree as cxx_pytree import torch.utils._cxx_pytree as cxx_pytree # noqa: F401
if TYPE_CHECKING: if TYPE_CHECKING:
from torch.utils._cxx_pytree import PyTree from torch.utils._cxx_pytree import PyTree
@ -64,45 +64,69 @@ if python_pytree._cxx_pytree_dynamo_traceable:
del __func del __func
del __name del __name
@substitute_in_graph(cxx_pytree.tree_is_leaf, can_constant_fold_through=True) @substitute_in_graph(optree.tree_is_leaf, can_constant_fold_through=True)
def tree_is_leaf( def tree_is_leaf(
tree: PyTree, tree: PyTree,
/,
is_leaf: Callable[[PyTree], bool] | None = None, is_leaf: Callable[[PyTree], bool] | None = None,
*,
none_is_leaf: bool = False,
namespace: str = "",
) -> bool: ) -> bool:
if tree is None or (is_leaf is not None and is_leaf(tree)): if (tree is None and none_is_leaf) or (is_leaf is not None and is_leaf(tree)):
return True return True
if optree.register_pytree_node.get(type(tree), namespace="torch") is None: # type: ignore[attr-defined] if optree.register_pytree_node.get(type(tree), namespace=namespace) is None: # type: ignore[attr-defined]
return True return True
return False return False
@substitute_in_graph(cxx_pytree.tree_iter, can_constant_fold_through=False) @substitute_in_graph(optree.tree_iter, can_constant_fold_through=False)
def tree_iter( def tree_iter(
tree: PyTree, tree: PyTree,
/,
is_leaf: Callable[[PyTree], bool] | None = None, is_leaf: Callable[[PyTree], bool] | None = None,
*,
none_is_leaf: bool = False,
namespace: str = "",
) -> Iterable[Any]: ) -> Iterable[Any]:
stack = [tree] stack = [tree]
while stack: while stack:
node = stack.pop() node = stack.pop()
if tree_is_leaf(node, is_leaf=is_leaf): if tree_is_leaf(
node,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
):
yield node yield node
continue continue
children, *_ = optree.tree_flatten_one_level( children, *_ = optree.tree_flatten_one_level(
node, node,
is_leaf=is_leaf, is_leaf=is_leaf,
none_is_leaf=True, none_is_leaf=none_is_leaf,
namespace="torch", namespace=namespace,
) )
stack.extend(reversed(children)) stack.extend(reversed(children))
__all__ += ["tree_iter"] __all__ += ["tree_iter"]
@substitute_in_graph(cxx_pytree.tree_leaves, can_constant_fold_through=True) @substitute_in_graph(optree.tree_leaves, can_constant_fold_through=True)
def tree_leaves( def tree_leaves(
tree: PyTree, tree: PyTree,
/,
is_leaf: Callable[[PyTree], bool] | None = None, is_leaf: Callable[[PyTree], bool] | None = None,
*,
none_is_leaf: bool = False,
namespace: str = "",
) -> list[Any]: ) -> list[Any]:
return list(tree_iter(tree, is_leaf=is_leaf)) return list(
tree_iter(
tree,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
)
__all__ += ["tree_leaves"] __all__ += ["tree_leaves"]
@ -127,12 +151,12 @@ if python_pytree._cxx_pytree_dynamo_traceable:
_metadata: Any _metadata: Any
_entries: tuple[Any, ...] _entries: tuple[Any, ...]
_unflatten_func: Callable[[Any | None, Iterable[PyTree]], PyTree] | None _unflatten_func: Callable[[Any | None, Iterable[PyTree]], PyTree] | None
none_is_leaf: bool
namespace: str
num_nodes: int = field(init=False) num_nodes: int = field(init=False)
num_leaves: int = field(init=False) num_leaves: int = field(init=False)
num_children: int = field(init=False) num_children: int = field(init=False)
none_is_leaf: Literal[True] = field(init=False)
namespace: Literal["torch"] = field(init=False)
def __post_init__(self) -> None: def __post_init__(self) -> None:
if self._type is None: if self._type is None:
@ -152,8 +176,6 @@ if python_pytree._cxx_pytree_dynamo_traceable:
object.__setattr__(self, "num_nodes", num_nodes) object.__setattr__(self, "num_nodes", num_nodes)
object.__setattr__(self, "num_leaves", num_leaves) object.__setattr__(self, "num_leaves", num_leaves)
object.__setattr__(self, "num_children", num_children) object.__setattr__(self, "num_children", num_children)
object.__setattr__(self, "none_is_leaf", True)
object.__setattr__(self, "namespace", "torch")
def __repr__(self) -> str: def __repr__(self) -> str:
def helper(treespec: PyTreeSpec) -> str: def helper(treespec: PyTreeSpec) -> str:
@ -168,6 +190,7 @@ if python_pytree._cxx_pytree_dynamo_traceable:
] ]
if ( if (
treespec.type in BUILTIN_TYPES treespec.type in BUILTIN_TYPES
or (treespec.type is type(None) and not self.none_is_leaf)
or optree.is_namedtuple_class(treespec.type) or optree.is_namedtuple_class(treespec.type)
or optree.is_structseq_class(treespec.type) or optree.is_structseq_class(treespec.type)
): ):
@ -181,9 +204,12 @@ if python_pytree._cxx_pytree_dynamo_traceable:
f"[{', '.join(children_representations)}])" f"[{', '.join(children_representations)}])"
) )
return ( inner = [
f"PyTreeSpec({helper(self)}, NoneIsLeaf, namespace={self.namespace!r})" str(helper(self)),
) *(["NoneIsLeaf"] if self.none_is_leaf else []),
f"namespace={self.namespace!r}",
]
return f"PyTreeSpec({', '.join(inner)})"
def __len__(self) -> int: def __len__(self) -> int:
return self.num_leaves return self.num_leaves
@ -228,8 +254,8 @@ if python_pytree._cxx_pytree_dynamo_traceable:
children, metadata, *_ = optree.tree_flatten_one_level( children, metadata, *_ = optree.tree_flatten_one_level(
node, node,
none_is_leaf=True, none_is_leaf=self.none_is_leaf,
namespace="torch", namespace=self.namespace,
) )
if len(children) != treespec.num_children: if len(children) != treespec.num_children:
raise ValueError( raise ValueError(
@ -277,8 +303,8 @@ if python_pytree._cxx_pytree_dynamo_traceable:
# node_type is treespec.type # node_type is treespec.type
children, metadata, *_ = optree.tree_flatten_one_level( children, metadata, *_ = optree.tree_flatten_one_level(
node, node,
none_is_leaf=True, none_is_leaf=self.none_is_leaf,
namespace="torch", namespace=self.namespace,
) )
if ( if (
node_type node_type
@ -320,25 +346,40 @@ if python_pytree._cxx_pytree_dynamo_traceable:
assert callable(self._unflatten_func) assert callable(self._unflatten_func)
return self._unflatten_func(self._metadata, subtrees) return self._unflatten_func(self._metadata, subtrees)
_LEAF_SPEC = PyTreeSpec((), None, None, (), None)
def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]: def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]:
return isinstance(obj, PyTreeSpec) return isinstance(obj, PyTreeSpec)
@substitute_in_graph( # type: ignore[arg-type] @substitute_in_graph( # type: ignore[arg-type]
cxx_pytree.tree_flatten, optree.tree_flatten,
# We need to disable constant folding here because we want the function to reference the # We need to disable constant folding here because we want the function to reference the
# PyTreeSpec class defined above, not the one in the C++ module. # PyTreeSpec class defined above, not the one in the C++ module.
can_constant_fold_through=False, can_constant_fold_through=False,
) )
def tree_flatten( def tree_flatten(
tree: PyTree, tree: PyTree,
/,
is_leaf: Callable[[PyTree], bool] | None = None, is_leaf: Callable[[PyTree], bool] | None = None,
*,
none_is_leaf: bool = False,
namespace: str = "",
) -> tuple[list[Any], PyTreeSpec]: ) -> tuple[list[Any], PyTreeSpec]:
def helper(node: PyTree, leaves: list[Any]) -> PyTreeSpec: def helper(node: PyTree, leaves: list[Any]) -> PyTreeSpec:
if tree_is_leaf(node, is_leaf=is_leaf): if tree_is_leaf(
node,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
):
leaves.append(node) leaves.append(node)
return _LEAF_SPEC return PyTreeSpec(
(),
None,
None,
(),
None,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
( (
children, children,
@ -348,13 +389,21 @@ if python_pytree._cxx_pytree_dynamo_traceable:
) = optree.tree_flatten_one_level( ) = optree.tree_flatten_one_level(
node, node,
is_leaf=is_leaf, is_leaf=is_leaf,
none_is_leaf=True, none_is_leaf=none_is_leaf,
namespace="torch", namespace=namespace,
) )
# Recursively flatten the children # Recursively flatten the children
subspecs = tuple(helper(child, leaves) for child in children) subspecs = tuple(helper(child, leaves) for child in children)
return PyTreeSpec(subspecs, type(node), metadata, entries, unflatten_func) # type: ignore[arg-type] return PyTreeSpec(
subspecs,
type(node),
metadata,
entries,
unflatten_func,
none_is_leaf=none_is_leaf,
namespace=namespace,
) # type: ignore[arg-type]
leaves: list[Any] = [] leaves: list[Any] = []
treespec = helper(tree, leaves) treespec = helper(tree, leaves)
@ -363,26 +412,35 @@ if python_pytree._cxx_pytree_dynamo_traceable:
__all__ += ["tree_flatten"] __all__ += ["tree_flatten"]
@substitute_in_graph( # type: ignore[arg-type] @substitute_in_graph( # type: ignore[arg-type]
cxx_pytree.tree_structure, optree.tree_structure,
# We need to disable constant folding here because we want the function to reference the # We need to disable constant folding here because we want the function to reference the
# PyTreeSpec class defined above, not the one in the C++ module. # PyTreeSpec class defined above, not the one in the C++ module.
can_constant_fold_through=False, can_constant_fold_through=False,
) )
def tree_structure( def tree_structure(
tree: PyTree, tree: PyTree,
/,
is_leaf: Callable[[PyTree], bool] | None = None, is_leaf: Callable[[PyTree], bool] | None = None,
*,
none_is_leaf: bool = False,
namespace: str = "",
) -> PyTreeSpec: ) -> PyTreeSpec:
return tree_flatten(tree, is_leaf=is_leaf)[1] # type: ignore[return-value] return tree_flatten( # type: ignore[return-value]
tree,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
)[1]
__all__ += ["tree_structure"] __all__ += ["tree_structure"]
@substitute_in_graph( # type: ignore[arg-type] @substitute_in_graph( # type: ignore[arg-type]
cxx_pytree.tree_unflatten, optree.tree_unflatten,
# We need to disable constant folding here because we want the function to reference the # We need to disable constant folding here because we want the function to reference the
# PyTreeSpec class defined above, not the one in the C++ module. # PyTreeSpec class defined above, not the one in the C++ module.
can_constant_fold_through=False, can_constant_fold_through=False,
) )
def tree_unflatten(leaves: Iterable[Any], treespec: PyTreeSpec) -> PyTree: def tree_unflatten(treespec: PyTreeSpec, leaves: Iterable[Any]) -> PyTree:
if not _is_pytreespec_instance(treespec): if not _is_pytreespec_instance(treespec):
raise TypeError( raise TypeError(
f"tree_unflatten(leaves, treespec): Expected `treespec` to be instance of " f"tree_unflatten(leaves, treespec): Expected `treespec` to be instance of "
@ -392,29 +450,57 @@ if python_pytree._cxx_pytree_dynamo_traceable:
__all__ += ["tree_unflatten"] __all__ += ["tree_unflatten"]
@substitute_in_graph(cxx_pytree.tree_map, can_constant_fold_through=True) @substitute_in_graph(optree.tree_map, can_constant_fold_through=True)
def tree_map( def tree_map(
func: Callable[..., Any], func: Callable[..., Any],
tree: PyTree, tree: PyTree,
/,
*rests: PyTree, *rests: PyTree,
is_leaf: Callable[[PyTree], bool] | None = None, is_leaf: Callable[[PyTree], bool] | None = None,
none_is_leaf: bool = False,
namespace: str = "",
) -> PyTree: ) -> PyTree:
leaves, treespec = tree_flatten(tree, is_leaf=is_leaf) leaves, treespec = tree_flatten(
tree,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests] flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
return treespec.unflatten(map(func, *flat_args)) return treespec.unflatten(map(func, *flat_args))
__all__ += ["tree_map"] __all__ += ["tree_map"]
@substitute_in_graph(cxx_pytree.tree_map_, can_constant_fold_through=True) @substitute_in_graph(optree.tree_map_, can_constant_fold_through=True)
def tree_map_( def tree_map_(
func: Callable[..., Any], func: Callable[..., Any],
tree: PyTree, tree: PyTree,
/,
*rests: PyTree, *rests: PyTree,
is_leaf: Callable[[PyTree], bool] | None = None, is_leaf: Callable[[PyTree], bool] | None = None,
none_is_leaf: bool = False,
namespace: str = "",
) -> PyTree: ) -> PyTree:
leaves, treespec = tree_flatten(tree, is_leaf=is_leaf) leaves, treespec = tree_flatten(
tree,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests] flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
deque(map(func, *flat_args), maxlen=0) # consume and exhaust the iterable deque(map(func, *flat_args), maxlen=0) # consume and exhaust the iterable
return tree return tree
__all__ += ["tree_map_"] __all__ += ["tree_map_"]
_none_unflatten = optree.register_pytree_node.get(type(None)).unflatten_func # type: ignore[union-attr]
@substitute_in_graph( # type: ignore[arg-type]
_none_unflatten,
can_constant_fold_through=True,
skip_signature_check=True,
)
def none_unflatten(_: None, children: Iterable[Any], /) -> None:
if len(list(children)) != 0:
raise ValueError("Expected no children.")
return None

View File

@ -367,11 +367,6 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
The reconstructed pytree, containing the ``leaves`` placed in the structure described by The reconstructed pytree, containing the ``leaves`` placed in the structure described by
``treespec``. ``treespec``.
""" """
if not _is_pytreespec_instance(treespec):
raise TypeError(
f"tree_unflatten(leaves, treespec): Expected `treespec` to be instance of "
f"PyTreeSpec but got item of type {type(treespec)}."
)
return optree.tree_unflatten(treespec, leaves) # type: ignore[arg-type] return optree.tree_unflatten(treespec, leaves) # type: ignore[arg-type]