[dynamo][pytree][1/N] make CXX pytree traceable: tree_iter / tree_leaves (#137397)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137397
Approved by: https://github.com/jansel
ghstack dependencies: #141360
parent cdde73033e
commit 07850bb2c1
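As a rough sketch of what this change enables (illustrative only, assuming optree is installed so that torch.utils._cxx_pytree is importable), a compiled function can now call the CXX pytree tree_leaves without graph-breaking:

    # Sketch: exercising the newly traceable CXX pytree helpers under torch.compile.
    import torch
    import torch.utils._cxx_pytree as cxx_pytree

    @torch.compile(backend="eager", fullgraph=True)
    def sum_leaves(tree):
        # tree_leaves is polyfilled in-graph by torch/_dynamo/polyfills/pytree.py.
        return sum(cxx_pytree.tree_leaves(tree))

    tree = {"a": torch.ones(2), "b": [torch.zeros(2), torch.full((2,), 3.0)]}
    print(sum_leaves(tree))  # tensor([4., 4.])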
@@ -32,7 +32,7 @@ import torch
 import torch._dynamo.testing
 import torch._inductor.test_case
 import torch.onnx.operators
-import torch.utils._pytree as pytree
+import torch.utils._pytree as python_pytree
 import torch.utils.cpp_extension
 from torch import Tensor
 from torch._C import FileCheck

@@ -89,9 +89,11 @@ from torch.testing._internal.jit_utils import JitTestCase
 from torch.testing._internal.logging_utils import logs_to_string
 
 
-HAS_OPTREE = importlib.util.find_spec("optree")
+HAS_OPTREE = python_pytree._cxx_pytree_exists
 if HAS_OPTREE:
     import optree
+    import torch.utils._cxx_pytree as cxx_pytree
+else:
+    cxx_pytree = None
 
 MyTuple = collections.namedtuple("MyTuple", ["a", "b", "ab"])
 T = typing.TypeVar("T")

@@ -293,9 +295,9 @@ class MiscTests(torch._inductor.test_case.TestCase):
 
     @unittest.skipIf(not HAS_OPTREE, "missing optree package")
     def test_optree_graph_break_message(self):
-        @torch.compile(
-            backend="eager",
-        )
+        import optree
+
+        @torch.compile(backend="eager")
         def fn(x):
             d = {"a": 1}
             optree.tree_flatten(d)

@@ -8666,9 +8668,9 @@ def ___make_guard_fn():
 
     def test_tracing_py_tree(self):
        def fn(xs):
-            flat_xs, spec = pytree.tree_flatten(xs)
+            flat_xs, spec = python_pytree.tree_flatten(xs)
             res = [x.clone() for x in flat_xs]
-            return pytree.tree_unflatten(res, spec)
+            return python_pytree.tree_unflatten(res, spec)
 
         xs = [torch.tensor(i) for i in range(3)]
 

@@ -8678,12 +8680,10 @@ def ___make_guard_fn():
         self.assertEqual(counter.op_count, 3)
 
     def test_tracing_nested_py_tree(self):
-        import torch.utils._pytree as pytree
-
         def fn(xs):
-            flat_xs, spec = pytree.tree_flatten(xs)
+            flat_xs, spec = python_pytree.tree_flatten(xs)
             res = [x.clone() for x in flat_xs]
-            return pytree.tree_unflatten(res, spec)
+            return python_pytree.tree_unflatten(res, spec)
 
         xs = [torch.tensor(i) for i in range(3)]
         xsl = [xs, xs, xs, xs]

@@ -8696,12 +8696,10 @@ def ___make_guard_fn():
         self.assertEqual(counter.op_count, 12)
 
     def test_tracing_nested_py_tree_tuples(self):
-        import torch.utils._pytree as pytree
-
         def fn(xs):
-            flat_xs, spec = pytree.tree_flatten(xs)
+            flat_xs, spec = python_pytree.tree_flatten(xs)
             res = [x.clone() for x in flat_xs]
-            return pytree.tree_unflatten(res, spec)
+            return python_pytree.tree_unflatten(res, spec)
 
         xs = [torch.tensor(i) for i in range(3)]
         xsl = (xs, xs, xs, xs)

@@ -8714,12 +8712,10 @@ def ___make_guard_fn():
         self.assertEqual(counter.op_count, 12)
 
     def test_tracing_nested_py_tree_dicts(self):
-        import torch.utils._pytree as pytree
-
         def fn(xs):
-            flat_xs, spec = pytree.tree_flatten(xs)
+            flat_xs, spec = python_pytree.tree_flatten(xs)
             res = [x.clone() for x in flat_xs]
-            return pytree.tree_unflatten(res, spec)
+            return python_pytree.tree_unflatten(res, spec)
 
         xs = [torch.tensor(i) for i in range(3)]
         xsl = {

@@ -8752,12 +8748,10 @@ def ___make_guard_fn():
         self.assertEqual(counter.op_count, 2)
 
     def test_tracing_nested_py_tree_mixed_all(self):
-        import torch.utils._pytree as pytree
-
         def fn(xs):
-            flat_xs, spec = pytree.tree_flatten(xs)
+            flat_xs, spec = python_pytree.tree_flatten(xs)
             res = [x.clone() for x in flat_xs]
-            return pytree.tree_unflatten(res, spec)
+            return python_pytree.tree_unflatten(res, spec)
 
         xs = [torch.tensor(i) for i in range(3)]
         xsa = (xs, xs)

@@ -8802,13 +8796,12 @@ def ___make_guard_fn():
         self.assertEqual(cnt.frame_count, 2)
 
     def test_tracing_py_tree_tensor_subclass(self):
-        import torch.utils._pytree as pytree
         from torch.testing._internal.two_tensor import TwoTensor
         from torch.utils.checkpoint import checkpoint
 
         def fn(xs):
             nested_xs = [[xs]]
-            flat_xs, spec = pytree.tree_flatten(xs)
+            flat_xs, spec = python_pytree.tree_flatten(xs)
             return flat_xs[0].clone()
 
         # use checkpoint to trigger a "sourceless" tensor subclass

@@ -8823,13 +8816,11 @@ def ___make_guard_fn():
         self.assertEqual(counter.op_count, 2)
 
     def test_tracing_tree_map_only(self):
-        import torch.utils._pytree as pytree
-
         def fn(xs):
             def mapper(x):
                 return x.clone()
 
-            y = pytree.tree_map_only(torch.Tensor, mapper, xs)
+            y = python_pytree.tree_map_only(torch.Tensor, mapper, xs)
             return y
 
         xs = [torch.tensor(i) for i in range(3)] + ["hi"]

@@ -10183,7 +10174,9 @@ def ___make_guard_fn():
         self.assertEqual(actual, expected)
 
     def test_pytree_tree_leaves(self):
-        implemtations = [("python", pytree)]
+        implemtations = [("python", python_pytree)]
+        if cxx_pytree is not None:
+            implemtations.append(("cxx", cxx_pytree))
 
         for name, module in implemtations:
             with self.subTest(f"pytree implement: {name}"):

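The loop bodies of these tests are truncated in this view. A hypothetical sketch of the per-implementation pattern (not the actual test body; run_tree_leaves_check is an invented name) would compile a function that calls module.tree_leaves and compare it against the eager result:

    import torch
    import torch.utils._pytree as python_pytree

    def run_tree_leaves_check(module):
        def fn(tree):
            # Order-insensitive reduction over the flattened leaves.
            return torch.cat(module.tree_leaves(tree)).sum()

        tree = {"x": [torch.ones(2), torch.ones(2)], "y": (torch.ones(2),)}
        expected = fn(tree)
        actual = torch.compile(fn, backend="eager", fullgraph=True)(tree)
        assert torch.equal(actual, expected)

    run_tree_leaves_check(python_pytree)  # also applicable to cxx_pytree when optree is available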
@@ -10215,7 +10208,7 @@ def ___make_guard_fn():
         self.assertEqual(actual, expected)
 
     def test_pytree_tree_flatten_unflatten(self):
-        implemtations = [("python", pytree)]
+        implemtations = [("python", python_pytree)]
 
         for name, module in implemtations:
             with self.subTest(f"pytree implement: {name}"):

@@ -10264,7 +10257,7 @@ def ___make_guard_fn():
         self.assertEqual(actual, expected)
 
     def test_pytree_tree_map(self):
-        implemtations = [("python", pytree)]
+        implemtations = [("python", python_pytree)]
 
         for name, module in implemtations:
             with self.subTest(f"pytree implement: {name}"):

@@ -2080,10 +2080,11 @@ class GuardBuilder(GuardBuilderBase):
         obj_ref = None
         # Not necessary to have weakref for Enum type, but there is a bug that
         # makes hasattr(guarded_object.__class__, "__weakref__") return True.
+        supports_weakref = (
+            getattr(guarded_object.__class__, "__weakrefoffset__", 0) != 0
+        )
         # See D64140537 for why we are checking for tuple.
-        if hasattr(guarded_object.__class__, "__weakref__") and not isinstance(
-            guarded_object, (enum.Enum, tuple)
-        ):
+        if supports_weakref and not isinstance(guarded_object, (enum.Enum, tuple)):
             obj_ref = weakref.ref(guarded_object)
 
         guard.set_export_info(

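The comments above refer to classes where hasattr(cls, "__weakref__") reports True even though instances cannot actually be weak-referenced; checking the CPython __weakrefoffset__ slot avoids that. A standalone illustration of the distinction (not part of the diff):

    # A nonzero __weakrefoffset__ means instances of the class carry a weakref slot.
    class Plain:
        pass

    def supports_weakref(obj) -> bool:
        return getattr(type(obj), "__weakrefoffset__", 0) != 0

    print(supports_weakref(Plain()))  # True  -> weakref.ref(Plain()) works
    print(supports_weakref((1, 2)))   # False -> weakref.ref((1, 2)) raises TypeError
    print(supports_weakref(42))       # False -> ints cannot be weak-referenced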
@@ -23,6 +23,7 @@ if TYPE_CHECKING:
         itertools as itertools,
         operator as operator,
         os as os,
+        pytree as pytree,
         sys as sys,
     )
 

@@ -18,6 +18,7 @@ POLYFILLED_MODULE_NAMES: Tuple[str, ...] = (
     "itertools",
     "operator",
     "os",
+    "pytree",
     "sys",
 )
 POLYFILLED_MODULES: Tuple["ModuleType", ...] = tuple(

torch/_dynamo/polyfills/pytree.py (new file, 89 lines)
@@ -0,0 +1,89 @@
+"""
+Python polyfills for torch.utils.pytree
+"""
+
+from __future__ import annotations
+
+from typing import Any, Callable, Iterable, TYPE_CHECKING
+
+import torch.utils._pytree as python_pytree
+
+from ..decorators import substitute_in_graph
+
+
+if TYPE_CHECKING:
+    from torch.utils._cxx_pytree import PyTree
+
+
+__all__: list[str] = []
+
+
+if python_pytree._cxx_pytree_exists:
+    import optree
+    import optree._C
+
+    import torch.utils._cxx_pytree as cxx_pytree
+
+    @substitute_in_graph(
+        optree._C.is_dict_insertion_ordered,
+        can_constant_fold_through=True,
+    )
+    def _(*args: Any, **kwargs: Any) -> bool:
+        # In namespace 'torch', the dictionary is always traversed in insertion order.
+        # This function returns True.
+        raise ValueError(
+            "Should not be called directly "
+            "because the original function will be called in the constant fold path."
+        )
+
+    __name = ""
+    for __name in (
+        "is_namedtuple",
+        "is_namedtuple_class",
+        "is_namedtuple_instance",
+        "is_structseq",
+        "is_structseq_class",
+        "is_structseq_instance",
+        "namedtuple_fields",
+        "structseq_fields",
+    ):
+        __func = getattr(optree, __name)
+        substitute_in_graph(__func, can_constant_fold_through=True)(
+            __func.__python_implementation__
+        )
+        del __func
+    del __name
+
+    @substitute_in_graph(cxx_pytree.tree_iter, can_constant_fold_through=False)
+    def tree_iter(
+        tree: PyTree,
+        is_leaf: Callable[[PyTree], bool] | None = None,
+    ) -> Iterable[Any]:
+        stack = [tree]
+        while stack:
+            node = stack.pop()
+            if node is None or (is_leaf is not None and is_leaf(node)):
+                yield node
+                continue
+            if optree.register_pytree_node.get(type(node), namespace="torch") is None:  # type: ignore[attr-defined]
+                yield node
+                continue
+
+            children, *_ = optree.tree_flatten_one_level(
+                node,
+                is_leaf=is_leaf,
+                none_is_leaf=True,
+                namespace="torch",
+            )
+            stack.extend(reversed(children))
+
+    __all__ += ["tree_iter"]
+
+    @substitute_in_graph(cxx_pytree.tree_leaves, can_constant_fold_through=True)
+    def tree_leaves(
+        tree: PyTree,
+        is_leaf: Callable[[PyTree], bool] | None = None,
+    ) -> list[Any]:
+        return list(tree_iter(tree, is_leaf=is_leaf))
+
+    __all__ += ["tree_leaves"]

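A minimal eager-mode sketch (assuming optree is installed) of what the polyfill above is meant to guarantee: the pure-Python tree_iter visits the same leaf objects, in the same order, as the C++-backed implementation:

    import torch
    import torch.utils._cxx_pytree as cxx_pytree
    from torch._dynamo.polyfills.pytree import tree_iter as polyfill_tree_iter

    tree = {"a": [torch.zeros(1), (torch.ones(1), None)], "b": torch.full((1,), 2.0)}

    cxx_leaves = list(cxx_pytree.tree_iter(tree))
    py_leaves = list(polyfill_tree_iter(tree))
    assert len(cxx_leaves) == len(py_leaves)
    assert all(a is b for a, b in zip(cxx_leaves, py_leaves))  # same objects, same order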
@@ -3310,6 +3310,7 @@ MOD_INLINELIST = [
     "torch.testing",
     "torch.utils._content_store",
     "torch.utils._contextlib",
+    "torch.utils._cxx_pytree",
     "torch.utils._device",
     "torch.utils._foreach_utils",
     "torch.utils._python_dispatch",

@@ -30,10 +30,10 @@ from typing import (
 from typing_extensions import deprecated
 
 import optree
-from optree import PyTreeSpec  # direct import for type annotations
+from optree import PyTreeSpec as TreeSpec  # direct import for type annotations
 
-import torch.utils._pytree as _pytree
-from torch.utils._pytree import KeyEntry
+import torch.utils._pytree as python_pytree
+from torch.utils._pytree import KeyEntry as KeyEntry
 
 
 __all__ = [

@@ -79,7 +79,6 @@ R = TypeVar("R")
 
 Context = Any
 PyTree = Any
-TreeSpec = PyTreeSpec
 FlattenFunc = Callable[[PyTree], Tuple[List[Any], Context]]
 UnflattenFunc = Callable[[Iterable[Any], Context], PyTree]
 OpTreeUnflattenFunc = Callable[[Context, Iterable[Any]], PyTree]

@@ -151,9 +150,7 @@ def register_pytree_node(
         from_dumpable_context=from_dumpable_context,
     )
 
-    from . import _pytree as python
-
-    python._private_register_pytree_node(
+    python_pytree._private_register_pytree_node(
         cls,
         flatten_fn,
         unflatten_fn,

@@ -871,24 +868,19 @@ def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
             f"treespec_dumps(spec): Expected `spec` to be instance of "
             f"TreeSpec but got item of type {type(treespec)}."
         )
-    from ._pytree import (
-        tree_structure as _tree_structure,
-        treespec_dumps as _treespec_dumps,
-    )
-
-    orig_treespec = _tree_structure(tree_unflatten([0] * treespec.num_leaves, treespec))
-    return _treespec_dumps(orig_treespec, protocol=protocol)
+    dummy_tree = tree_unflatten([0] * treespec.num_leaves, treespec)
+    orig_treespec = python_pytree.tree_structure(dummy_tree)
+    return python_pytree.treespec_dumps(orig_treespec, protocol=protocol)
 
 
 def treespec_loads(serialized: str) -> TreeSpec:
     """Deserialize a treespec from a JSON string."""
-    from ._pytree import (
-        tree_unflatten as _tree_unflatten,
-        treespec_loads as _treespec_loads,
+    orig_treespec = python_pytree.treespec_loads(serialized)
+    dummy_tree = python_pytree.tree_unflatten(
+        [0] * orig_treespec.num_leaves,
+        orig_treespec,
     )
 
-    orig_treespec = _treespec_loads(serialized)
-    dummy_tree = _tree_unflatten([0] * orig_treespec.num_leaves, orig_treespec)
     treespec = tree_structure(dummy_tree)
     return treespec

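The rewritten serialization path round-trips a CXX treespec through the Python pytree implementation: treespec_dumps unflattens a dummy tree and re-derives a Python treespec, and treespec_loads does the reverse. A small round-trip sketch (assuming optree is installed):

    import torch.utils._cxx_pytree as cxx_pytree

    leaves, spec = cxx_pytree.tree_flatten({"a": (1, 2), "b": [3]})
    serialized = cxx_pytree.treespec_dumps(spec)      # JSON string via python_pytree
    restored = cxx_pytree.treespec_loads(serialized)  # back to an optree PyTreeSpec
    assert cxx_pytree.tree_unflatten(leaves, restored) == {"a": (1, 2), "b": [3]}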
@@ -1002,6 +994,10 @@ def key_get(obj: Any, kp: KeyPath) -> Any:
     raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.")
 
 
-_pytree._cxx_pytree_imported = True
-for args, kwargs in _pytree._cxx_pytree_pending_imports:
-    _private_register_pytree_node(*args, **kwargs)
+with python_pytree._NODE_REGISTRY_LOCK:
+    python_pytree._cxx_pytree_imported = True
+    args, kwargs = (), {}  # type: ignore[var-annotated]
+    for args, kwargs in python_pytree._cxx_pytree_pending_imports:
+        _private_register_pytree_node(*args, **kwargs)
+    python_pytree._cxx_pytree_pending_imports.clear()
+    del args, kwargs

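This block flushes node registrations that were queued before torch.utils._cxx_pytree could be imported. A sketch of the deferred-registration behavior it supports (illustrative; Point is an invented example class, and the sketch assumes optree is installed):

    import torch
    import torch.utils._pytree as python_pytree

    class Point:
        def __init__(self, x, y):
            self.x, self.y = x, y

    # Registered before the CXX module is imported: queued in
    # _cxx_pytree_pending_imports and flushed (under _NODE_REGISTRY_LOCK)
    # into the optree "torch" namespace on first import of _cxx_pytree.
    python_pytree.register_pytree_node(
        Point,
        lambda p: ([p.x, p.y], None),
        lambda children, _: Point(*children),
    )

    import torch.utils._cxx_pytree as cxx_pytree  # triggers the flush above

    print(cxx_pytree.tree_leaves(Point(torch.ones(1), torch.zeros(1))))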