mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[pytree][dynamo] trace on native optree functions for community pytree support (#165860)
Resolves #164972 - #164972 All `torch.utils._cxx_pytree` functions are based on `optree` functions with hardcoded `none_is_leaf=True` and `namespace="torch"`. This PR changes the polyfills to generic `optree` functions with those arguments unhardcoded. This means `torch.utils._cxx_pytree` functions are still traceable while the community `optree` usages can get dynamo support additionally. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165860 Approved by: https://github.com/Lucaskabela
This commit is contained in:
parent
410e6a4321
commit
1009790ad8
|
|
@ -424,7 +424,7 @@ from user code:
|
||||||
@torch.compile(backend="eager")
|
@torch.compile(backend="eager")
|
||||||
def fn(x):
|
def fn(x):
|
||||||
d = {"a": 1}
|
d = {"a": 1}
|
||||||
optree.tree_flatten(d)
|
optree.tree_flatten_with_path(d)
|
||||||
return torch.sin(x)
|
return torch.sin(x)
|
||||||
|
|
||||||
fn(torch.randn(4))
|
fn(torch.randn(4))
|
||||||
|
|
@ -434,10 +434,10 @@ from user code:
|
||||||
first_graph_break,
|
first_graph_break,
|
||||||
"""\
|
"""\
|
||||||
Attempted to call function marked as skipped
|
Attempted to call function marked as skipped
|
||||||
Explanation: Dynamo cannot trace optree C/C++ function optree._C.PyCapsule.flatten.
|
Explanation: Dynamo cannot trace optree C/C++ function optree._C.PyCapsule.flatten_with_path.
|
||||||
Hint: Consider using torch.utils._pytree - https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py
|
Hint: Consider using torch.utils._pytree - https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py
|
||||||
|
|
||||||
Developer debug context: module: optree._C, qualname: PyCapsule.flatten, skip reason: <missing reason>
|
Developer debug context: module: optree._C, qualname: PyCapsule.flatten_with_path, skip reason: <missing reason>
|
||||||
|
|
||||||
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html""",
|
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html""",
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -110,6 +110,7 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
||||||
import torch.utils._cxx_pytree as cxx_pytree
|
import torch.utils._cxx_pytree as cxx_pytree
|
||||||
|
|
||||||
pytree_modules["cxx"] = cxx_pytree
|
pytree_modules["cxx"] = cxx_pytree
|
||||||
|
pytree_modules["native_optree"] = cxx_pytree.optree
|
||||||
else:
|
else:
|
||||||
cxx_pytree = None
|
cxx_pytree = None
|
||||||
|
|
||||||
|
|
@ -12862,6 +12863,9 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
|
||||||
def fn(xs):
|
def fn(xs):
|
||||||
flat_xs, spec = pytree.tree_flatten(xs)
|
flat_xs, spec = pytree.tree_flatten(xs)
|
||||||
res = [x.clone() for x in flat_xs]
|
res = [x.clone() for x in flat_xs]
|
||||||
|
if pytree.__name__ == "optree":
|
||||||
|
# The treespec argument comes first in OpTree / JAX PyTree
|
||||||
|
return pytree.tree_unflatten(spec, res)
|
||||||
return pytree.tree_unflatten(res, spec)
|
return pytree.tree_unflatten(res, spec)
|
||||||
|
|
||||||
xs = [torch.tensor(i) for i in range(3)]
|
xs = [torch.tensor(i) for i in range(3)]
|
||||||
|
|
@ -12876,6 +12880,9 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
|
||||||
def fn(xs):
|
def fn(xs):
|
||||||
flat_xs, spec = pytree.tree_flatten(xs)
|
flat_xs, spec = pytree.tree_flatten(xs)
|
||||||
res = [x.clone() for x in flat_xs]
|
res = [x.clone() for x in flat_xs]
|
||||||
|
if pytree.__name__ == "optree":
|
||||||
|
# The treespec argument comes first in OpTree / JAX PyTree
|
||||||
|
return pytree.tree_unflatten(spec, res)
|
||||||
return pytree.tree_unflatten(res, spec)
|
return pytree.tree_unflatten(res, spec)
|
||||||
|
|
||||||
xs = [torch.tensor(i) for i in range(3)]
|
xs = [torch.tensor(i) for i in range(3)]
|
||||||
|
|
@ -12893,6 +12900,9 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
|
||||||
def fn(xs):
|
def fn(xs):
|
||||||
flat_xs, spec = pytree.tree_flatten(xs)
|
flat_xs, spec = pytree.tree_flatten(xs)
|
||||||
res = [x.clone() for x in flat_xs]
|
res = [x.clone() for x in flat_xs]
|
||||||
|
if pytree.__name__ == "optree":
|
||||||
|
# The treespec argument comes first in OpTree / JAX PyTree
|
||||||
|
return pytree.tree_unflatten(spec, res)
|
||||||
return pytree.tree_unflatten(res, spec)
|
return pytree.tree_unflatten(res, spec)
|
||||||
|
|
||||||
xs = [torch.tensor(i) for i in range(3)]
|
xs = [torch.tensor(i) for i in range(3)]
|
||||||
|
|
@ -12910,6 +12920,9 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
|
||||||
def fn(xs):
|
def fn(xs):
|
||||||
flat_xs, spec = pytree.tree_flatten(xs)
|
flat_xs, spec = pytree.tree_flatten(xs)
|
||||||
res = [x.clone() for x in flat_xs]
|
res = [x.clone() for x in flat_xs]
|
||||||
|
if pytree.__name__ == "optree":
|
||||||
|
# The treespec argument comes first in OpTree / JAX PyTree
|
||||||
|
return pytree.tree_unflatten(spec, res)
|
||||||
return pytree.tree_unflatten(res, spec)
|
return pytree.tree_unflatten(res, spec)
|
||||||
|
|
||||||
xs = [torch.tensor(i) for i in range(3)]
|
xs = [torch.tensor(i) for i in range(3)]
|
||||||
|
|
@ -12931,6 +12944,9 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
|
||||||
def fn(xs):
|
def fn(xs):
|
||||||
flat_xs, spec = pytree.tree_flatten(xs)
|
flat_xs, spec = pytree.tree_flatten(xs)
|
||||||
res = [x.clone() for x in flat_xs]
|
res = [x.clone() for x in flat_xs]
|
||||||
|
if pytree.__name__ == "optree":
|
||||||
|
# The treespec argument comes first in OpTree / JAX PyTree
|
||||||
|
return pytree.tree_unflatten(spec, res)
|
||||||
return pytree.tree_unflatten(res, spec)
|
return pytree.tree_unflatten(res, spec)
|
||||||
|
|
||||||
xs = [torch.tensor(i) for i in range(3)]
|
xs = [torch.tensor(i) for i in range(3)]
|
||||||
|
|
@ -13032,7 +13048,13 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
|
||||||
torch.ones(3, 2),
|
torch.ones(3, 2),
|
||||||
1,
|
1,
|
||||||
]
|
]
|
||||||
new_tree = pytree.tree_unflatten(new_leaves, treespec)
|
if pytree.__name__ == "optree":
|
||||||
|
# `None` is a internal node rather than leaf in default OpTree / JAX PyTree
|
||||||
|
new_leaves.pop()
|
||||||
|
# The treespec argument comes first in OpTree / JAX PyTree
|
||||||
|
new_tree = pytree.tree_unflatten(treespec, new_leaves)
|
||||||
|
else:
|
||||||
|
new_tree = pytree.tree_unflatten(new_leaves, treespec)
|
||||||
return leaves, new_tree
|
return leaves, new_tree
|
||||||
|
|
||||||
x = torch.randn(3, 2)
|
x = torch.randn(3, 2)
|
||||||
|
|
@ -13087,6 +13109,10 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
|
||||||
|
|
||||||
@parametrize_pytree_module
|
@parametrize_pytree_module
|
||||||
def test_pytree_tree_map_only(self, pytree):
|
def test_pytree_tree_map_only(self, pytree):
|
||||||
|
if not callable(getattr(pytree, "tree_map_only", None)):
|
||||||
|
# OpTree and JAX PyTree do not have `tree_map_only`
|
||||||
|
return
|
||||||
|
|
||||||
def fn(xs):
|
def fn(xs):
|
||||||
def mapper(x):
|
def mapper(x):
|
||||||
return x.clone()
|
return x.clone()
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Callable, Literal, TYPE_CHECKING
|
from typing import Any, Callable, TYPE_CHECKING
|
||||||
from typing_extensions import TypeIs
|
from typing_extensions import TypeIs
|
||||||
|
|
||||||
import torch.utils._pytree as python_pytree
|
import torch.utils._pytree as python_pytree
|
||||||
|
|
@ -28,7 +28,7 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
||||||
import optree
|
import optree
|
||||||
import optree._C
|
import optree._C
|
||||||
|
|
||||||
import torch.utils._cxx_pytree as cxx_pytree
|
import torch.utils._cxx_pytree as cxx_pytree # noqa: F401
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from torch.utils._cxx_pytree import PyTree
|
from torch.utils._cxx_pytree import PyTree
|
||||||
|
|
@ -64,45 +64,69 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
||||||
del __func
|
del __func
|
||||||
del __name
|
del __name
|
||||||
|
|
||||||
@substitute_in_graph(cxx_pytree.tree_is_leaf, can_constant_fold_through=True)
|
@substitute_in_graph(optree.tree_is_leaf, can_constant_fold_through=True)
|
||||||
def tree_is_leaf(
|
def tree_is_leaf(
|
||||||
tree: PyTree,
|
tree: PyTree,
|
||||||
|
/,
|
||||||
is_leaf: Callable[[PyTree], bool] | None = None,
|
is_leaf: Callable[[PyTree], bool] | None = None,
|
||||||
|
*,
|
||||||
|
none_is_leaf: bool = False,
|
||||||
|
namespace: str = "",
|
||||||
) -> bool:
|
) -> bool:
|
||||||
if tree is None or (is_leaf is not None and is_leaf(tree)):
|
if (tree is None and none_is_leaf) or (is_leaf is not None and is_leaf(tree)):
|
||||||
return True
|
return True
|
||||||
if optree.register_pytree_node.get(type(tree), namespace="torch") is None: # type: ignore[attr-defined]
|
if optree.register_pytree_node.get(type(tree), namespace=namespace) is None: # type: ignore[attr-defined]
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@substitute_in_graph(cxx_pytree.tree_iter, can_constant_fold_through=False)
|
@substitute_in_graph(optree.tree_iter, can_constant_fold_through=False)
|
||||||
def tree_iter(
|
def tree_iter(
|
||||||
tree: PyTree,
|
tree: PyTree,
|
||||||
|
/,
|
||||||
is_leaf: Callable[[PyTree], bool] | None = None,
|
is_leaf: Callable[[PyTree], bool] | None = None,
|
||||||
|
*,
|
||||||
|
none_is_leaf: bool = False,
|
||||||
|
namespace: str = "",
|
||||||
) -> Iterable[Any]:
|
) -> Iterable[Any]:
|
||||||
stack = [tree]
|
stack = [tree]
|
||||||
while stack:
|
while stack:
|
||||||
node = stack.pop()
|
node = stack.pop()
|
||||||
if tree_is_leaf(node, is_leaf=is_leaf):
|
if tree_is_leaf(
|
||||||
|
node,
|
||||||
|
is_leaf=is_leaf,
|
||||||
|
none_is_leaf=none_is_leaf,
|
||||||
|
namespace=namespace,
|
||||||
|
):
|
||||||
yield node
|
yield node
|
||||||
continue
|
continue
|
||||||
|
|
||||||
children, *_ = optree.tree_flatten_one_level(
|
children, *_ = optree.tree_flatten_one_level(
|
||||||
node,
|
node,
|
||||||
is_leaf=is_leaf,
|
is_leaf=is_leaf,
|
||||||
none_is_leaf=True,
|
none_is_leaf=none_is_leaf,
|
||||||
namespace="torch",
|
namespace=namespace,
|
||||||
)
|
)
|
||||||
stack.extend(reversed(children))
|
stack.extend(reversed(children))
|
||||||
|
|
||||||
__all__ += ["tree_iter"]
|
__all__ += ["tree_iter"]
|
||||||
|
|
||||||
@substitute_in_graph(cxx_pytree.tree_leaves, can_constant_fold_through=True)
|
@substitute_in_graph(optree.tree_leaves, can_constant_fold_through=True)
|
||||||
def tree_leaves(
|
def tree_leaves(
|
||||||
tree: PyTree,
|
tree: PyTree,
|
||||||
|
/,
|
||||||
is_leaf: Callable[[PyTree], bool] | None = None,
|
is_leaf: Callable[[PyTree], bool] | None = None,
|
||||||
|
*,
|
||||||
|
none_is_leaf: bool = False,
|
||||||
|
namespace: str = "",
|
||||||
) -> list[Any]:
|
) -> list[Any]:
|
||||||
return list(tree_iter(tree, is_leaf=is_leaf))
|
return list(
|
||||||
|
tree_iter(
|
||||||
|
tree,
|
||||||
|
is_leaf=is_leaf,
|
||||||
|
none_is_leaf=none_is_leaf,
|
||||||
|
namespace=namespace,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
__all__ += ["tree_leaves"]
|
__all__ += ["tree_leaves"]
|
||||||
|
|
||||||
|
|
@ -127,12 +151,12 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
||||||
_metadata: Any
|
_metadata: Any
|
||||||
_entries: tuple[Any, ...]
|
_entries: tuple[Any, ...]
|
||||||
_unflatten_func: Callable[[Any | None, Iterable[PyTree]], PyTree] | None
|
_unflatten_func: Callable[[Any | None, Iterable[PyTree]], PyTree] | None
|
||||||
|
none_is_leaf: bool
|
||||||
|
namespace: str
|
||||||
|
|
||||||
num_nodes: int = field(init=False)
|
num_nodes: int = field(init=False)
|
||||||
num_leaves: int = field(init=False)
|
num_leaves: int = field(init=False)
|
||||||
num_children: int = field(init=False)
|
num_children: int = field(init=False)
|
||||||
none_is_leaf: Literal[True] = field(init=False)
|
|
||||||
namespace: Literal["torch"] = field(init=False)
|
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
if self._type is None:
|
if self._type is None:
|
||||||
|
|
@ -152,8 +176,6 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
||||||
object.__setattr__(self, "num_nodes", num_nodes)
|
object.__setattr__(self, "num_nodes", num_nodes)
|
||||||
object.__setattr__(self, "num_leaves", num_leaves)
|
object.__setattr__(self, "num_leaves", num_leaves)
|
||||||
object.__setattr__(self, "num_children", num_children)
|
object.__setattr__(self, "num_children", num_children)
|
||||||
object.__setattr__(self, "none_is_leaf", True)
|
|
||||||
object.__setattr__(self, "namespace", "torch")
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
def helper(treespec: PyTreeSpec) -> str:
|
def helper(treespec: PyTreeSpec) -> str:
|
||||||
|
|
@ -168,6 +190,7 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
||||||
]
|
]
|
||||||
if (
|
if (
|
||||||
treespec.type in BUILTIN_TYPES
|
treespec.type in BUILTIN_TYPES
|
||||||
|
or (treespec.type is type(None) and not self.none_is_leaf)
|
||||||
or optree.is_namedtuple_class(treespec.type)
|
or optree.is_namedtuple_class(treespec.type)
|
||||||
or optree.is_structseq_class(treespec.type)
|
or optree.is_structseq_class(treespec.type)
|
||||||
):
|
):
|
||||||
|
|
@ -181,9 +204,12 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
||||||
f"[{', '.join(children_representations)}])"
|
f"[{', '.join(children_representations)}])"
|
||||||
)
|
)
|
||||||
|
|
||||||
return (
|
inner = [
|
||||||
f"PyTreeSpec({helper(self)}, NoneIsLeaf, namespace={self.namespace!r})"
|
str(helper(self)),
|
||||||
)
|
*(["NoneIsLeaf"] if self.none_is_leaf else []),
|
||||||
|
f"namespace={self.namespace!r}",
|
||||||
|
]
|
||||||
|
return f"PyTreeSpec({', '.join(inner)})"
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return self.num_leaves
|
return self.num_leaves
|
||||||
|
|
@ -228,8 +254,8 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
||||||
|
|
||||||
children, metadata, *_ = optree.tree_flatten_one_level(
|
children, metadata, *_ = optree.tree_flatten_one_level(
|
||||||
node,
|
node,
|
||||||
none_is_leaf=True,
|
none_is_leaf=self.none_is_leaf,
|
||||||
namespace="torch",
|
namespace=self.namespace,
|
||||||
)
|
)
|
||||||
if len(children) != treespec.num_children:
|
if len(children) != treespec.num_children:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|
@ -277,8 +303,8 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
||||||
# node_type is treespec.type
|
# node_type is treespec.type
|
||||||
children, metadata, *_ = optree.tree_flatten_one_level(
|
children, metadata, *_ = optree.tree_flatten_one_level(
|
||||||
node,
|
node,
|
||||||
none_is_leaf=True,
|
none_is_leaf=self.none_is_leaf,
|
||||||
namespace="torch",
|
namespace=self.namespace,
|
||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
node_type
|
node_type
|
||||||
|
|
@ -320,25 +346,40 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
||||||
assert callable(self._unflatten_func)
|
assert callable(self._unflatten_func)
|
||||||
return self._unflatten_func(self._metadata, subtrees)
|
return self._unflatten_func(self._metadata, subtrees)
|
||||||
|
|
||||||
_LEAF_SPEC = PyTreeSpec((), None, None, (), None)
|
|
||||||
|
|
||||||
def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]:
|
def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]:
|
||||||
return isinstance(obj, PyTreeSpec)
|
return isinstance(obj, PyTreeSpec)
|
||||||
|
|
||||||
@substitute_in_graph( # type: ignore[arg-type]
|
@substitute_in_graph( # type: ignore[arg-type]
|
||||||
cxx_pytree.tree_flatten,
|
optree.tree_flatten,
|
||||||
# We need to disable constant folding here because we want the function to reference the
|
# We need to disable constant folding here because we want the function to reference the
|
||||||
# PyTreeSpec class defined above, not the one in the C++ module.
|
# PyTreeSpec class defined above, not the one in the C++ module.
|
||||||
can_constant_fold_through=False,
|
can_constant_fold_through=False,
|
||||||
)
|
)
|
||||||
def tree_flatten(
|
def tree_flatten(
|
||||||
tree: PyTree,
|
tree: PyTree,
|
||||||
|
/,
|
||||||
is_leaf: Callable[[PyTree], bool] | None = None,
|
is_leaf: Callable[[PyTree], bool] | None = None,
|
||||||
|
*,
|
||||||
|
none_is_leaf: bool = False,
|
||||||
|
namespace: str = "",
|
||||||
) -> tuple[list[Any], PyTreeSpec]:
|
) -> tuple[list[Any], PyTreeSpec]:
|
||||||
def helper(node: PyTree, leaves: list[Any]) -> PyTreeSpec:
|
def helper(node: PyTree, leaves: list[Any]) -> PyTreeSpec:
|
||||||
if tree_is_leaf(node, is_leaf=is_leaf):
|
if tree_is_leaf(
|
||||||
|
node,
|
||||||
|
is_leaf=is_leaf,
|
||||||
|
none_is_leaf=none_is_leaf,
|
||||||
|
namespace=namespace,
|
||||||
|
):
|
||||||
leaves.append(node)
|
leaves.append(node)
|
||||||
return _LEAF_SPEC
|
return PyTreeSpec(
|
||||||
|
(),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
(),
|
||||||
|
None,
|
||||||
|
none_is_leaf=none_is_leaf,
|
||||||
|
namespace=namespace,
|
||||||
|
)
|
||||||
|
|
||||||
(
|
(
|
||||||
children,
|
children,
|
||||||
|
|
@ -348,13 +389,21 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
||||||
) = optree.tree_flatten_one_level(
|
) = optree.tree_flatten_one_level(
|
||||||
node,
|
node,
|
||||||
is_leaf=is_leaf,
|
is_leaf=is_leaf,
|
||||||
none_is_leaf=True,
|
none_is_leaf=none_is_leaf,
|
||||||
namespace="torch",
|
namespace=namespace,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Recursively flatten the children
|
# Recursively flatten the children
|
||||||
subspecs = tuple(helper(child, leaves) for child in children)
|
subspecs = tuple(helper(child, leaves) for child in children)
|
||||||
return PyTreeSpec(subspecs, type(node), metadata, entries, unflatten_func) # type: ignore[arg-type]
|
return PyTreeSpec(
|
||||||
|
subspecs,
|
||||||
|
type(node),
|
||||||
|
metadata,
|
||||||
|
entries,
|
||||||
|
unflatten_func,
|
||||||
|
none_is_leaf=none_is_leaf,
|
||||||
|
namespace=namespace,
|
||||||
|
) # type: ignore[arg-type]
|
||||||
|
|
||||||
leaves: list[Any] = []
|
leaves: list[Any] = []
|
||||||
treespec = helper(tree, leaves)
|
treespec = helper(tree, leaves)
|
||||||
|
|
@ -363,26 +412,35 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
||||||
__all__ += ["tree_flatten"]
|
__all__ += ["tree_flatten"]
|
||||||
|
|
||||||
@substitute_in_graph( # type: ignore[arg-type]
|
@substitute_in_graph( # type: ignore[arg-type]
|
||||||
cxx_pytree.tree_structure,
|
optree.tree_structure,
|
||||||
# We need to disable constant folding here because we want the function to reference the
|
# We need to disable constant folding here because we want the function to reference the
|
||||||
# PyTreeSpec class defined above, not the one in the C++ module.
|
# PyTreeSpec class defined above, not the one in the C++ module.
|
||||||
can_constant_fold_through=False,
|
can_constant_fold_through=False,
|
||||||
)
|
)
|
||||||
def tree_structure(
|
def tree_structure(
|
||||||
tree: PyTree,
|
tree: PyTree,
|
||||||
|
/,
|
||||||
is_leaf: Callable[[PyTree], bool] | None = None,
|
is_leaf: Callable[[PyTree], bool] | None = None,
|
||||||
|
*,
|
||||||
|
none_is_leaf: bool = False,
|
||||||
|
namespace: str = "",
|
||||||
) -> PyTreeSpec:
|
) -> PyTreeSpec:
|
||||||
return tree_flatten(tree, is_leaf=is_leaf)[1] # type: ignore[return-value]
|
return tree_flatten( # type: ignore[return-value]
|
||||||
|
tree,
|
||||||
|
is_leaf=is_leaf,
|
||||||
|
none_is_leaf=none_is_leaf,
|
||||||
|
namespace=namespace,
|
||||||
|
)[1]
|
||||||
|
|
||||||
__all__ += ["tree_structure"]
|
__all__ += ["tree_structure"]
|
||||||
|
|
||||||
@substitute_in_graph( # type: ignore[arg-type]
|
@substitute_in_graph( # type: ignore[arg-type]
|
||||||
cxx_pytree.tree_unflatten,
|
optree.tree_unflatten,
|
||||||
# We need to disable constant folding here because we want the function to reference the
|
# We need to disable constant folding here because we want the function to reference the
|
||||||
# PyTreeSpec class defined above, not the one in the C++ module.
|
# PyTreeSpec class defined above, not the one in the C++ module.
|
||||||
can_constant_fold_through=False,
|
can_constant_fold_through=False,
|
||||||
)
|
)
|
||||||
def tree_unflatten(leaves: Iterable[Any], treespec: PyTreeSpec) -> PyTree:
|
def tree_unflatten(treespec: PyTreeSpec, leaves: Iterable[Any]) -> PyTree:
|
||||||
if not _is_pytreespec_instance(treespec):
|
if not _is_pytreespec_instance(treespec):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"tree_unflatten(leaves, treespec): Expected `treespec` to be instance of "
|
f"tree_unflatten(leaves, treespec): Expected `treespec` to be instance of "
|
||||||
|
|
@ -392,29 +450,57 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
||||||
|
|
||||||
__all__ += ["tree_unflatten"]
|
__all__ += ["tree_unflatten"]
|
||||||
|
|
||||||
@substitute_in_graph(cxx_pytree.tree_map, can_constant_fold_through=True)
|
@substitute_in_graph(optree.tree_map, can_constant_fold_through=True)
|
||||||
def tree_map(
|
def tree_map(
|
||||||
func: Callable[..., Any],
|
func: Callable[..., Any],
|
||||||
tree: PyTree,
|
tree: PyTree,
|
||||||
|
/,
|
||||||
*rests: PyTree,
|
*rests: PyTree,
|
||||||
is_leaf: Callable[[PyTree], bool] | None = None,
|
is_leaf: Callable[[PyTree], bool] | None = None,
|
||||||
|
none_is_leaf: bool = False,
|
||||||
|
namespace: str = "",
|
||||||
) -> PyTree:
|
) -> PyTree:
|
||||||
leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
|
leaves, treespec = tree_flatten(
|
||||||
|
tree,
|
||||||
|
is_leaf=is_leaf,
|
||||||
|
none_is_leaf=none_is_leaf,
|
||||||
|
namespace=namespace,
|
||||||
|
)
|
||||||
flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
|
flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
|
||||||
return treespec.unflatten(map(func, *flat_args))
|
return treespec.unflatten(map(func, *flat_args))
|
||||||
|
|
||||||
__all__ += ["tree_map"]
|
__all__ += ["tree_map"]
|
||||||
|
|
||||||
@substitute_in_graph(cxx_pytree.tree_map_, can_constant_fold_through=True)
|
@substitute_in_graph(optree.tree_map_, can_constant_fold_through=True)
|
||||||
def tree_map_(
|
def tree_map_(
|
||||||
func: Callable[..., Any],
|
func: Callable[..., Any],
|
||||||
tree: PyTree,
|
tree: PyTree,
|
||||||
|
/,
|
||||||
*rests: PyTree,
|
*rests: PyTree,
|
||||||
is_leaf: Callable[[PyTree], bool] | None = None,
|
is_leaf: Callable[[PyTree], bool] | None = None,
|
||||||
|
none_is_leaf: bool = False,
|
||||||
|
namespace: str = "",
|
||||||
) -> PyTree:
|
) -> PyTree:
|
||||||
leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
|
leaves, treespec = tree_flatten(
|
||||||
|
tree,
|
||||||
|
is_leaf=is_leaf,
|
||||||
|
none_is_leaf=none_is_leaf,
|
||||||
|
namespace=namespace,
|
||||||
|
)
|
||||||
flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
|
flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
|
||||||
deque(map(func, *flat_args), maxlen=0) # consume and exhaust the iterable
|
deque(map(func, *flat_args), maxlen=0) # consume and exhaust the iterable
|
||||||
return tree
|
return tree
|
||||||
|
|
||||||
__all__ += ["tree_map_"]
|
__all__ += ["tree_map_"]
|
||||||
|
|
||||||
|
_none_unflatten = optree.register_pytree_node.get(type(None)).unflatten_func # type: ignore[union-attr]
|
||||||
|
|
||||||
|
@substitute_in_graph( # type: ignore[arg-type]
|
||||||
|
_none_unflatten,
|
||||||
|
can_constant_fold_through=True,
|
||||||
|
skip_signature_check=True,
|
||||||
|
)
|
||||||
|
def none_unflatten(_: None, children: Iterable[Any], /) -> None:
|
||||||
|
if len(list(children)) != 0:
|
||||||
|
raise ValueError("Expected no children.")
|
||||||
|
return None
|
||||||
|
|
|
||||||
|
|
@ -367,11 +367,6 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
|
||||||
The reconstructed pytree, containing the ``leaves`` placed in the structure described by
|
The reconstructed pytree, containing the ``leaves`` placed in the structure described by
|
||||||
``treespec``.
|
``treespec``.
|
||||||
"""
|
"""
|
||||||
if not _is_pytreespec_instance(treespec):
|
|
||||||
raise TypeError(
|
|
||||||
f"tree_unflatten(leaves, treespec): Expected `treespec` to be instance of "
|
|
||||||
f"PyTreeSpec but got item of type {type(treespec)}."
|
|
||||||
)
|
|
||||||
return optree.tree_unflatten(treespec, leaves) # type: ignore[arg-type]
|
return optree.tree_unflatten(treespec, leaves) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user