[dynamo][pytree][1/N] make CXX pytree traceable: tree_iter / tree_leaves (#137397)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137397
Approved by: https://github.com/jansel
ghstack dependencies: #141360
Xuehai Pan 2024-11-27 02:54:50 +08:00 committed by PyTorch MergeBot
parent cdde73033e
commit 07850bb2c1
7 changed files with 139 additions and 57 deletions
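For orientation, a minimal sketch (not part of this commit) of the behavior the polyfills below are meant to enable: once tree_iter/tree_leaves are registered as substitutable, Dynamo can trace calls into torch.utils._cxx_pytree instead of falling back with a graph break. This assumes the optional optree package is installed so the C++ pytree is importable; fullgraph=True is expected to succeed once the substitution is in place.

import torch
import torch.utils._cxx_pytree as cxx_pytree  # requires optree

@torch.compile(backend="eager", fullgraph=True)  # expect no graph break on tree_leaves
def fn(tree):
    # tree_leaves is traced through the Python polyfill added in this commit
    return [leaf + 1 for leaf in cxx_pytree.tree_leaves(tree)]

fn({"a": torch.zeros(3), "b": (torch.ones(3), torch.full((3,), 2.0))})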

View File

@@ -32,7 +32,7 @@ import torch
 import torch._dynamo.testing
 import torch._inductor.test_case
 import torch.onnx.operators
-import torch.utils._pytree as pytree
+import torch.utils._pytree as python_pytree
 import torch.utils.cpp_extension
 from torch import Tensor
 from torch._C import FileCheck
@@ -89,9 +89,11 @@ from torch.testing._internal.jit_utils import JitTestCase
 from torch.testing._internal.logging_utils import logs_to_string

-HAS_OPTREE = importlib.util.find_spec("optree")
+HAS_OPTREE = python_pytree._cxx_pytree_exists
 if HAS_OPTREE:
-    import optree
+    import torch.utils._cxx_pytree as cxx_pytree
+else:
+    cxx_pytree = None

 MyTuple = collections.namedtuple("MyTuple", ["a", "b", "ab"])
 T = typing.TypeVar("T")
@@ -293,9 +295,9 @@ class MiscTests(torch._inductor.test_case.TestCase):
     @unittest.skipIf(not HAS_OPTREE, "missing optree package")
     def test_optree_graph_break_message(self):
-        @torch.compile(
-            backend="eager",
-        )
+        import optree
+
+        @torch.compile(backend="eager")
         def fn(x):
             d = {"a": 1}
             optree.tree_flatten(d)
@@ -8666,9 +8668,9 @@ def ___make_guard_fn():
     def test_tracing_py_tree(self):
         def fn(xs):
-            flat_xs, spec = pytree.tree_flatten(xs)
+            flat_xs, spec = python_pytree.tree_flatten(xs)
             res = [x.clone() for x in flat_xs]
-            return pytree.tree_unflatten(res, spec)
+            return python_pytree.tree_unflatten(res, spec)

         xs = [torch.tensor(i) for i in range(3)]
@@ -8678,12 +8680,10 @@ def ___make_guard_fn():
         self.assertEqual(counter.op_count, 3)

     def test_tracing_nested_py_tree(self):
-        import torch.utils._pytree as pytree
-
         def fn(xs):
-            flat_xs, spec = pytree.tree_flatten(xs)
+            flat_xs, spec = python_pytree.tree_flatten(xs)
             res = [x.clone() for x in flat_xs]
-            return pytree.tree_unflatten(res, spec)
+            return python_pytree.tree_unflatten(res, spec)

         xs = [torch.tensor(i) for i in range(3)]
         xsl = [xs, xs, xs, xs]
@@ -8696,12 +8696,10 @@ def ___make_guard_fn():
         self.assertEqual(counter.op_count, 12)

     def test_tracing_nested_py_tree_tuples(self):
-        import torch.utils._pytree as pytree
-
         def fn(xs):
-            flat_xs, spec = pytree.tree_flatten(xs)
+            flat_xs, spec = python_pytree.tree_flatten(xs)
             res = [x.clone() for x in flat_xs]
-            return pytree.tree_unflatten(res, spec)
+            return python_pytree.tree_unflatten(res, spec)

         xs = [torch.tensor(i) for i in range(3)]
         xsl = (xs, xs, xs, xs)
@@ -8714,12 +8712,10 @@ def ___make_guard_fn():
         self.assertEqual(counter.op_count, 12)

     def test_tracing_nested_py_tree_dicts(self):
-        import torch.utils._pytree as pytree
-
         def fn(xs):
-            flat_xs, spec = pytree.tree_flatten(xs)
+            flat_xs, spec = python_pytree.tree_flatten(xs)
             res = [x.clone() for x in flat_xs]
-            return pytree.tree_unflatten(res, spec)
+            return python_pytree.tree_unflatten(res, spec)

         xs = [torch.tensor(i) for i in range(3)]
         xsl = {
@@ -8752,12 +8748,10 @@ def ___make_guard_fn():
         self.assertEqual(counter.op_count, 2)

     def test_tracing_nested_py_tree_mixed_all(self):
-        import torch.utils._pytree as pytree
-
         def fn(xs):
-            flat_xs, spec = pytree.tree_flatten(xs)
+            flat_xs, spec = python_pytree.tree_flatten(xs)
             res = [x.clone() for x in flat_xs]
-            return pytree.tree_unflatten(res, spec)
+            return python_pytree.tree_unflatten(res, spec)

         xs = [torch.tensor(i) for i in range(3)]
         xsa = (xs, xs)
@@ -8802,13 +8796,12 @@ def ___make_guard_fn():
         self.assertEqual(cnt.frame_count, 2)

     def test_tracing_py_tree_tensor_subclass(self):
-        import torch.utils._pytree as pytree
         from torch.testing._internal.two_tensor import TwoTensor
         from torch.utils.checkpoint import checkpoint

         def fn(xs):
             nested_xs = [[xs]]
-            flat_xs, spec = pytree.tree_flatten(xs)
+            flat_xs, spec = python_pytree.tree_flatten(xs)
             return flat_xs[0].clone()

         # use checkpoint to trigger a "sourceless" tensor subclass
@@ -8823,13 +8816,11 @@ def ___make_guard_fn():
         self.assertEqual(counter.op_count, 2)

     def test_tracing_tree_map_only(self):
-        import torch.utils._pytree as pytree
-
         def fn(xs):
             def mapper(x):
                 return x.clone()

-            y = pytree.tree_map_only(torch.Tensor, mapper, xs)
+            y = python_pytree.tree_map_only(torch.Tensor, mapper, xs)
             return y

         xs = [torch.tensor(i) for i in range(3)] + ["hi"]
@@ -10183,7 +10174,9 @@ def ___make_guard_fn():
         self.assertEqual(actual, expected)

     def test_pytree_tree_leaves(self):
-        implemtations = [("python", pytree)]
+        implemtations = [("python", python_pytree)]
+        if cxx_pytree is not None:
+            implemtations.append(("cxx", cxx_pytree))

         for name, module in implemtations:
             with self.subTest(f"pytree implement: {name}"):
@@ -10215,7 +10208,7 @@ def ___make_guard_fn():
         self.assertEqual(actual, expected)

     def test_pytree_tree_flatten_unflatten(self):
-        implemtations = [("python", pytree)]
+        implemtations = [("python", python_pytree)]

         for name, module in implemtations:
             with self.subTest(f"pytree implement: {name}"):
@@ -10264,7 +10257,7 @@ def ___make_guard_fn():
         self.assertEqual(actual, expected)

     def test_pytree_tree_map(self):
-        implemtations = [("python", pytree)]
+        implemtations = [("python", python_pytree)]

         for name, module in implemtations:
             with self.subTest(f"pytree implement: {name}"):

View File

@@ -2080,10 +2080,11 @@ class GuardBuilder(GuardBuilderBase):
         obj_ref = None
         # Not necessary to have weakref for Enum type, but there is a bug that
         # makes hasattr(guarded_object.__class__, "__weakref__") return True.
+        supports_weakref = (
+            getattr(guarded_object.__class__, "__weakrefoffset__", 0) != 0
+        )
         # See D64140537 for why we are checking for tuple.
-        if hasattr(guarded_object.__class__, "__weakref__") and not isinstance(
-            guarded_object, (enum.Enum, tuple)
-        ):
+        if supports_weakref and not isinstance(guarded_object, (enum.Enum, tuple)):
             obj_ref = weakref.ref(guarded_object)

         guard.set_export_info(
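A quick sketch of why the check moves to __weakrefoffset__ (assuming standard CPython semantics, illustration only): the offset is nonzero only when instances actually reserve a weak-reference slot, so it cannot be fooled by attribute-lookup quirks the way hasattr(cls, "__weakref__") can.

import weakref

class Plain:
    pass

assert Plain.__weakrefoffset__ != 0  # instances carry a weakref slot
weakref.ref(Plain())                 # works

assert tuple.__weakrefoffset__ == 0  # tuple instances have no weakref slot
try:
    weakref.ref((1, 2))
except TypeError:
    pass  # "cannot create weak reference to 'tuple' object"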

View File

@@ -23,6 +23,7 @@ if TYPE_CHECKING:
         itertools as itertools,
         operator as operator,
         os as os,
+        pytree as pytree,
         sys as sys,
     )

View File

@@ -18,6 +18,7 @@ POLYFILLED_MODULE_NAMES: Tuple[str, ...] = (
     "itertools",
     "operator",
     "os",
+    "pytree",
     "sys",
 )
 POLYFILLED_MODULES: Tuple["ModuleType", ...] = tuple(

View File

@@ -0,0 +1,89 @@
+"""
+Python polyfills for torch.utils.pytree
+"""
+
+from __future__ import annotations
+
+from typing import Any, Callable, Iterable, TYPE_CHECKING
+
+import torch.utils._pytree as python_pytree
+
+from ..decorators import substitute_in_graph
+
+
+if TYPE_CHECKING:
+    from torch.utils._cxx_pytree import PyTree
+
+
+__all__: list[str] = []
+
+
+if python_pytree._cxx_pytree_exists:
+    import optree
+    import optree._C
+
+    import torch.utils._cxx_pytree as cxx_pytree
+
+    @substitute_in_graph(
+        optree._C.is_dict_insertion_ordered,
+        can_constant_fold_through=True,
+    )
+    def _(*args: Any, **kwargs: Any) -> bool:
+        # In namespace 'torch', the dictionary is always traversed in insertion order.
+        # This function returns True.
+        raise ValueError(
+            "Should not be called directly "
+            "because the original function will be called in the constant fold path."
+        )
+
+    __name = ""
+    for __name in (
+        "is_namedtuple",
+        "is_namedtuple_class",
+        "is_namedtuple_instance",
+        "is_structseq",
+        "is_structseq_class",
+        "is_structseq_instance",
+        "namedtuple_fields",
+        "structseq_fields",
+    ):
+        __func = getattr(optree, __name)
+        substitute_in_graph(__func, can_constant_fold_through=True)(
+            __func.__python_implementation__
+        )
+        del __func
+    del __name
+
+    @substitute_in_graph(cxx_pytree.tree_iter, can_constant_fold_through=False)
+    def tree_iter(
+        tree: PyTree,
+        is_leaf: Callable[[PyTree], bool] | None = None,
+    ) -> Iterable[Any]:
+        stack = [tree]
+        while stack:
+            node = stack.pop()
+            if node is None or (is_leaf is not None and is_leaf(node)):
+                yield node
+                continue
+            if optree.register_pytree_node.get(type(node), namespace="torch") is None:  # type: ignore[attr-defined]
+                yield node
+                continue
+
+            children, *_ = optree.tree_flatten_one_level(
+                node,
+                is_leaf=is_leaf,
+                none_is_leaf=True,
+                namespace="torch",
+            )
+            stack.extend(reversed(children))
+
+    __all__ += ["tree_iter"]
+
+    @substitute_in_graph(cxx_pytree.tree_leaves, can_constant_fold_through=True)
+    def tree_leaves(
+        tree: PyTree,
+        is_leaf: Callable[[PyTree], bool] | None = None,
+    ) -> list[Any]:
+        return list(tree_iter(tree, is_leaf=is_leaf))
+
+    __all__ += ["tree_leaves"]

View File

@@ -3310,6 +3310,7 @@ MOD_INLINELIST = [
     "torch.testing",
     "torch.utils._content_store",
     "torch.utils._contextlib",
+    "torch.utils._cxx_pytree",
     "torch.utils._device",
     "torch.utils._foreach_utils",
     "torch.utils._python_dispatch",

View File

@@ -30,10 +30,10 @@ from typing import (
 from typing_extensions import deprecated

 import optree
-from optree import PyTreeSpec  # direct import for type annotations
+from optree import PyTreeSpec as TreeSpec  # direct import for type annotations

-import torch.utils._pytree as _pytree
-from torch.utils._pytree import KeyEntry
+import torch.utils._pytree as python_pytree
+from torch.utils._pytree import KeyEntry as KeyEntry

 __all__ = [
@@ -79,7 +79,6 @@ R = TypeVar("R")

 Context = Any
 PyTree = Any
-TreeSpec = PyTreeSpec
 FlattenFunc = Callable[[PyTree], Tuple[List[Any], Context]]
 UnflattenFunc = Callable[[Iterable[Any], Context], PyTree]
 OpTreeUnflattenFunc = Callable[[Context, Iterable[Any]], PyTree]
@@ -151,9 +150,7 @@ def register_pytree_node(
         from_dumpable_context=from_dumpable_context,
     )

-    from . import _pytree as python
-
-    python._private_register_pytree_node(
+    python_pytree._private_register_pytree_node(
         cls,
         flatten_fn,
         unflatten_fn,
@@ -871,24 +868,19 @@ def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
             f"treespec_dumps(spec): Expected `spec` to be instance of "
             f"TreeSpec but got item of type {type(treespec)}."
         )
-    from ._pytree import (
-        tree_structure as _tree_structure,
-        treespec_dumps as _treespec_dumps,
-    )
-
-    orig_treespec = _tree_structure(tree_unflatten([0] * treespec.num_leaves, treespec))
-    return _treespec_dumps(orig_treespec, protocol=protocol)
+
+    dummy_tree = tree_unflatten([0] * treespec.num_leaves, treespec)
+    orig_treespec = python_pytree.tree_structure(dummy_tree)
+    return python_pytree.treespec_dumps(orig_treespec, protocol=protocol)


 def treespec_loads(serialized: str) -> TreeSpec:
     """Deserialize a treespec from a JSON string."""
-    from ._pytree import (
-        tree_unflatten as _tree_unflatten,
-        treespec_loads as _treespec_loads,
-    )
-
-    orig_treespec = _treespec_loads(serialized)
-    dummy_tree = _tree_unflatten([0] * orig_treespec.num_leaves, orig_treespec)
+    orig_treespec = python_pytree.treespec_loads(serialized)
+    dummy_tree = python_pytree.tree_unflatten(
+        [0] * orig_treespec.num_leaves,
+        orig_treespec,
+    )
     treespec = tree_structure(dummy_tree)
     return treespec
@@ -1002,6 +994,10 @@ def key_get(obj: Any, kp: KeyPath) -> Any:
     raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.")


-_pytree._cxx_pytree_imported = True
-for args, kwargs in _pytree._cxx_pytree_pending_imports:
-    _private_register_pytree_node(*args, **kwargs)
+with python_pytree._NODE_REGISTRY_LOCK:
+    python_pytree._cxx_pytree_imported = True
+    args, kwargs = (), {}  # type: ignore[var-annotated]
+    for args, kwargs in python_pytree._cxx_pytree_pending_imports:
+        _private_register_pytree_node(*args, **kwargs)
+    python_pytree._cxx_pytree_pending_imports.clear()
+    del args, kwargs
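Finally, a hedged round-trip sketch for the rewritten treespec_dumps/treespec_loads (illustration only; assumes optree is installed and standard containers): serialization routes through the Python pytree's treespec via a dummy tree, and loading rebuilds an equivalent C++ treespec the same way.

import torch.utils._cxx_pytree as cxx_pytree

spec = cxx_pytree.tree_structure({"a": 1, "b": (2, [3, 4])})
serialized = cxx_pytree.treespec_dumps(spec)  # JSON string produced via python_pytree
assert cxx_pytree.treespec_loads(serialized) == spec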