[dynamo][pytree][1/N] make CXX pytree traceable: tree_iter / tree_leaves (#137397)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137397
Approved by: https://github.com/jansel
ghstack dependencies: #141360
commit: 07850bb2c1
parent: cdde73033e
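In short, this change teaches Dynamo to trace the optree-backed C++ pytree API through Python polyfills instead of graph-breaking on it. A minimal sketch of the behavior being targeted, assuming optree is installed so that torch.utils._cxx_pytree imports (the actual tests in the suite are structured differently):

import torch
import torch.utils._cxx_pytree as cxx_pytree

def fn(tree):
    # cxx_pytree.tree_leaves is backed by optree's C extension; the polyfills
    # added here let Dynamo trace a pure-Python substitute instead of falling
    # back to eager with a graph break.
    return [leaf * 2 for leaf in cxx_pytree.tree_leaves(tree)]

tree = {"a": torch.ones(2), "b": (torch.zeros(3), torch.arange(4))}
out = torch.compile(fn, backend="eager", fullgraph=True)(tree)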
@@ -32,7 +32,7 @@ import torch
 import torch._dynamo.testing
 import torch._inductor.test_case
 import torch.onnx.operators
-import torch.utils._pytree as pytree
+import torch.utils._pytree as python_pytree
 import torch.utils.cpp_extension
 from torch import Tensor
 from torch._C import FileCheck
@@ -89,9 +89,11 @@ from torch.testing._internal.jit_utils import JitTestCase
 from torch.testing._internal.logging_utils import logs_to_string
 
 
-HAS_OPTREE = importlib.util.find_spec("optree")
+HAS_OPTREE = python_pytree._cxx_pytree_exists
 if HAS_OPTREE:
-    import optree
+    import torch.utils._cxx_pytree as cxx_pytree
+else:
+    cxx_pytree = None
 
 MyTuple = collections.namedtuple("MyTuple", ["a", "b", "ab"])
 T = typing.TypeVar("T")
@@ -293,9 +295,9 @@ class MiscTests(torch._inductor.test_case.TestCase):
 
     @unittest.skipIf(not HAS_OPTREE, "missing optree package")
     def test_optree_graph_break_message(self):
-        @torch.compile(
-            backend="eager",
-        )
+        import optree
+
+        @torch.compile(backend="eager")
         def fn(x):
             d = {"a": 1}
             optree.tree_flatten(d)
@@ -8666,9 +8668,9 @@ def ___make_guard_fn():
 
     def test_tracing_py_tree(self):
        def fn(xs):
-            flat_xs, spec = pytree.tree_flatten(xs)
+            flat_xs, spec = python_pytree.tree_flatten(xs)
             res = [x.clone() for x in flat_xs]
-            return pytree.tree_unflatten(res, spec)
+            return python_pytree.tree_unflatten(res, spec)
 
         xs = [torch.tensor(i) for i in range(3)]
 
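For context, these tracing tests drive fn through a CompileCounter backend and assert on the number of captured ops. Roughly (the exact harness in the test suite differs in details):

import torch
import torch._dynamo.testing
import torch.utils._pytree as python_pytree

def fn(xs):
    flat_xs, spec = python_pytree.tree_flatten(xs)
    res = [x.clone() for x in flat_xs]
    return python_pytree.tree_unflatten(res, spec)

counter = torch._dynamo.testing.CompileCounter()
xs = [torch.tensor(i) for i in range(3)]
torch.compile(fn, backend=counter, fullgraph=True)(xs)
print(counter.frame_count, counter.op_count)  # the test expects op_count == 3: one clone per leaf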
@@ -8678,12 +8680,10 @@ def ___make_guard_fn():
         self.assertEqual(counter.op_count, 3)
 
     def test_tracing_nested_py_tree(self):
-        import torch.utils._pytree as pytree
-
         def fn(xs):
-            flat_xs, spec = pytree.tree_flatten(xs)
+            flat_xs, spec = python_pytree.tree_flatten(xs)
             res = [x.clone() for x in flat_xs]
-            return pytree.tree_unflatten(res, spec)
+            return python_pytree.tree_unflatten(res, spec)
 
         xs = [torch.tensor(i) for i in range(3)]
         xsl = [xs, xs, xs, xs]
@@ -8696,12 +8696,10 @@ def ___make_guard_fn():
         self.assertEqual(counter.op_count, 12)
 
     def test_tracing_nested_py_tree_tuples(self):
-        import torch.utils._pytree as pytree
-
         def fn(xs):
-            flat_xs, spec = pytree.tree_flatten(xs)
+            flat_xs, spec = python_pytree.tree_flatten(xs)
             res = [x.clone() for x in flat_xs]
-            return pytree.tree_unflatten(res, spec)
+            return python_pytree.tree_unflatten(res, spec)
 
         xs = [torch.tensor(i) for i in range(3)]
         xsl = (xs, xs, xs, xs)
@@ -8714,12 +8712,10 @@ def ___make_guard_fn():
         self.assertEqual(counter.op_count, 12)
 
     def test_tracing_nested_py_tree_dicts(self):
-        import torch.utils._pytree as pytree
-
         def fn(xs):
-            flat_xs, spec = pytree.tree_flatten(xs)
+            flat_xs, spec = python_pytree.tree_flatten(xs)
             res = [x.clone() for x in flat_xs]
-            return pytree.tree_unflatten(res, spec)
+            return python_pytree.tree_unflatten(res, spec)
 
         xs = [torch.tensor(i) for i in range(3)]
         xsl = {
@@ -8752,12 +8748,10 @@ def ___make_guard_fn():
         self.assertEqual(counter.op_count, 2)
 
     def test_tracing_nested_py_tree_mixed_all(self):
-        import torch.utils._pytree as pytree
-
         def fn(xs):
-            flat_xs, spec = pytree.tree_flatten(xs)
+            flat_xs, spec = python_pytree.tree_flatten(xs)
             res = [x.clone() for x in flat_xs]
-            return pytree.tree_unflatten(res, spec)
+            return python_pytree.tree_unflatten(res, spec)
 
         xs = [torch.tensor(i) for i in range(3)]
         xsa = (xs, xs)
@@ -8802,13 +8796,12 @@ def ___make_guard_fn():
         self.assertEqual(cnt.frame_count, 2)
 
     def test_tracing_py_tree_tensor_subclass(self):
-        import torch.utils._pytree as pytree
         from torch.testing._internal.two_tensor import TwoTensor
         from torch.utils.checkpoint import checkpoint
 
         def fn(xs):
             nested_xs = [[xs]]
-            flat_xs, spec = pytree.tree_flatten(xs)
+            flat_xs, spec = python_pytree.tree_flatten(xs)
             return flat_xs[0].clone()
 
         # use checkpoint to trigger a "sourceless" tensor subclass
@@ -8823,13 +8816,11 @@ def ___make_guard_fn():
         self.assertEqual(counter.op_count, 2)
 
     def test_tracing_tree_map_only(self):
-        import torch.utils._pytree as pytree
-
         def fn(xs):
             def mapper(x):
                 return x.clone()
 
-            y = pytree.tree_map_only(torch.Tensor, mapper, xs)
+            y = python_pytree.tree_map_only(torch.Tensor, mapper, xs)
             return y
 
         xs = [torch.tensor(i) for i in range(3)] + ["hi"]
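tree_map_only, exercised above, applies the mapper only to leaves of the given type and passes all other leaves through unchanged. A quick eager illustration:

import torch
import torch.utils._pytree as python_pytree

xs = [torch.tensor(0), torch.tensor(1), "hi"]
ys = python_pytree.tree_map_only(torch.Tensor, lambda t: t + 1, xs)
print(ys)  # tensor leaves are incremented; the string leaf "hi" is returned untouched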
@@ -10183,7 +10174,9 @@ def ___make_guard_fn():
         self.assertEqual(actual, expected)
 
     def test_pytree_tree_leaves(self):
-        implemtations = [("python", pytree)]
+        implemtations = [("python", python_pytree)]
+        if cxx_pytree is not None:
+            implemtations.append(("cxx", cxx_pytree))
 
         for name, module in implemtations:
             with self.subTest(f"pytree implement: {name}"):
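Each implementation entry is then run under torch.compile and compared against eager. A condensed sketch of that pattern, with python_pytree standing in for the module under test (the actual trees and assertions in the suite differ):

import torch
import torch.utils._pytree as python_pytree

def fn(tree):
    # double every leaf via tree_leaves from the implementation under test
    return [leaf * 2 for leaf in python_pytree.tree_leaves(tree)]

tree = {"a": [torch.ones(2), torch.zeros(2)], "b": (torch.arange(3),)}
expected = fn(tree)
actual = torch.compile(fn, backend="eager", fullgraph=True)(tree)
for a, e in zip(actual, expected):
    torch.testing.assert_close(a, e)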
@@ -10215,7 +10208,7 @@ def ___make_guard_fn():
         self.assertEqual(actual, expected)
 
     def test_pytree_tree_flatten_unflatten(self):
-        implemtations = [("python", pytree)]
+        implemtations = [("python", python_pytree)]
 
         for name, module in implemtations:
             with self.subTest(f"pytree implement: {name}"):
@@ -10264,7 +10257,7 @@ def ___make_guard_fn():
         self.assertEqual(actual, expected)
 
     def test_pytree_tree_map(self):
-        implemtations = [("python", pytree)]
+        implemtations = [("python", python_pytree)]
 
         for name, module in implemtations:
             with self.subTest(f"pytree implement: {name}"):
@@ -2080,10 +2080,11 @@ class GuardBuilder(GuardBuilderBase):
             obj_ref = None
             # Not necessary to have weakref for Enum type, but there is a bug that
             # makes hasattr(guarded_object.__class__, "__weakref__") return True.
+            supports_weakref = (
+                getattr(guarded_object.__class__, "__weakrefoffset__", 0) != 0
+            )
             # See D64140537 for why we are checking for tuple.
-            if hasattr(guarded_object.__class__, "__weakref__") and not isinstance(
-                guarded_object, (enum.Enum, tuple)
-            ):
+            if supports_weakref and not isinstance(guarded_object, (enum.Enum, tuple)):
                 obj_ref = weakref.ref(guarded_object)
 
             guard.set_export_info(
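The guard change replaces the hasattr(cls, "__weakref__") probe with the CPython-level __weakrefoffset__ slot, which reports directly whether instances of a class can be weak-referenced. A standalone illustration (plain CPython, independent of PyTorch):

import weakref

class Plain:
    pass

print(Plain.__weakrefoffset__ != 0)  # True: Plain instances carry a weakref slot
print(int.__weakrefoffset__)         # 0: ints cannot be weak-referenced

weakref.ref(Plain())                 # fine
try:
    weakref.ref(1)
except TypeError as exc:
    print(exc)                       # cannot create weak reference to 'int' object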
@@ -23,6 +23,7 @@ if TYPE_CHECKING:
         itertools as itertools,
         operator as operator,
         os as os,
+        pytree as pytree,
         sys as sys,
     )
 
@@ -18,6 +18,7 @@ POLYFILLED_MODULE_NAMES: Tuple[str, ...] = (
     "itertools",
     "operator",
     "os",
+    "pytree",
     "sys",
 )
 POLYFILLED_MODULES: Tuple["ModuleType", ...] = tuple(
torch/_dynamo/polyfills/pytree.py (new file, 89 lines)
@@ -0,0 +1,89 @@
+"""
+Python polyfills for torch.utils.pytree
+"""
+
+from __future__ import annotations
+
+from typing import Any, Callable, Iterable, TYPE_CHECKING
+
+import torch.utils._pytree as python_pytree
+
+from ..decorators import substitute_in_graph
+
+
+if TYPE_CHECKING:
+    from torch.utils._cxx_pytree import PyTree
+
+
+__all__: list[str] = []
+
+
+if python_pytree._cxx_pytree_exists:
+    import optree
+    import optree._C
+
+    import torch.utils._cxx_pytree as cxx_pytree
+
+    @substitute_in_graph(
+        optree._C.is_dict_insertion_ordered,
+        can_constant_fold_through=True,
+    )
+    def _(*args: Any, **kwargs: Any) -> bool:
+        # In namespace 'torch', the dictionary is always traversed in insertion order.
+        # This function returns True.
+        raise ValueError(
+            "Should not be called directly "
+            "because the original function will be called in the constant fold path."
+        )
+
+    __name = ""
+    for __name in (
+        "is_namedtuple",
+        "is_namedtuple_class",
+        "is_namedtuple_instance",
+        "is_structseq",
+        "is_structseq_class",
+        "is_structseq_instance",
+        "namedtuple_fields",
+        "structseq_fields",
+    ):
+        __func = getattr(optree, __name)
+        substitute_in_graph(__func, can_constant_fold_through=True)(
+            __func.__python_implementation__
+        )
+        del __func
+    del __name
+
+    @substitute_in_graph(cxx_pytree.tree_iter, can_constant_fold_through=False)
+    def tree_iter(
+        tree: PyTree,
+        is_leaf: Callable[[PyTree], bool] | None = None,
+    ) -> Iterable[Any]:
+        stack = [tree]
+        while stack:
+            node = stack.pop()
+            if node is None or (is_leaf is not None and is_leaf(node)):
+                yield node
+                continue
+            if optree.register_pytree_node.get(type(node), namespace="torch") is None:  # type: ignore[attr-defined]
+                yield node
+                continue
+
+            children, *_ = optree.tree_flatten_one_level(
+                node,
+                is_leaf=is_leaf,
+                none_is_leaf=True,
+                namespace="torch",
+            )
+            stack.extend(reversed(children))
+
+    __all__ += ["tree_iter"]
+
+    @substitute_in_graph(cxx_pytree.tree_leaves, can_constant_fold_through=True)
+    def tree_leaves(
+        tree: PyTree,
+        is_leaf: Callable[[PyTree], bool] | None = None,
+    ) -> list[Any]:
+        return list(tree_iter(tree, is_leaf=is_leaf))
+
+    __all__ += ["tree_leaves"]
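Outside of Dynamo, cxx_pytree.tree_iter / tree_leaves themselves show the traversal semantics the polyfill mirrors (None treated as a leaf, node types looked up in the "torch" registry namespace). A small eager check, assuming optree is installed:

import torch
import torch.utils._cxx_pytree as cxx_pytree

tree = {"a": [torch.tensor(1), None], "b": (torch.tensor(2), {"c": torch.tensor(3)})}

# Leaves are produced left-to-right; None counts as a leaf, matching the
# none_is_leaf=True, namespace="torch" settings used by the polyfilled tree_iter.
print(list(cxx_pytree.tree_iter(tree)))
print(cxx_pytree.tree_leaves(tree))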
@@ -3310,6 +3310,7 @@ MOD_INLINELIST = [
     "torch.testing",
     "torch.utils._content_store",
     "torch.utils._contextlib",
+    "torch.utils._cxx_pytree",
     "torch.utils._device",
     "torch.utils._foreach_utils",
     "torch.utils._python_dispatch",
@@ -30,10 +30,10 @@ from typing import (
 from typing_extensions import deprecated
 
 import optree
-from optree import PyTreeSpec  # direct import for type annotations
+from optree import PyTreeSpec as TreeSpec  # direct import for type annotations
 
-import torch.utils._pytree as _pytree
-from torch.utils._pytree import KeyEntry
+import torch.utils._pytree as python_pytree
+from torch.utils._pytree import KeyEntry as KeyEntry
 
 
 __all__ = [
@@ -79,7 +79,6 @@ R = TypeVar("R")
 
 Context = Any
 PyTree = Any
-TreeSpec = PyTreeSpec
 FlattenFunc = Callable[[PyTree], Tuple[List[Any], Context]]
 UnflattenFunc = Callable[[Iterable[Any], Context], PyTree]
 OpTreeUnflattenFunc = Callable[[Context, Iterable[Any]], PyTree]
@@ -151,9 +150,7 @@ def register_pytree_node(
         from_dumpable_context=from_dumpable_context,
     )
 
-    from . import _pytree as python
-
-    python._private_register_pytree_node(
+    python_pytree._private_register_pytree_node(
         cls,
         flatten_fn,
         unflatten_fn,
@@ -871,24 +868,19 @@ def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
             f"treespec_dumps(spec): Expected `spec` to be instance of "
             f"TreeSpec but got item of type {type(treespec)}."
         )
-    from ._pytree import (
-        tree_structure as _tree_structure,
-        treespec_dumps as _treespec_dumps,
-    )
-
-    orig_treespec = _tree_structure(tree_unflatten([0] * treespec.num_leaves, treespec))
-    return _treespec_dumps(orig_treespec, protocol=protocol)
+    dummy_tree = tree_unflatten([0] * treespec.num_leaves, treespec)
+    orig_treespec = python_pytree.tree_structure(dummy_tree)
+    return python_pytree.treespec_dumps(orig_treespec, protocol=protocol)
 
 
 def treespec_loads(serialized: str) -> TreeSpec:
     """Deserialize a treespec from a JSON string."""
-    from ._pytree import (
-        tree_unflatten as _tree_unflatten,
-        treespec_loads as _treespec_loads,
-    )
-
-    orig_treespec = _treespec_loads(serialized)
-    dummy_tree = _tree_unflatten([0] * orig_treespec.num_leaves, orig_treespec)
+    orig_treespec = python_pytree.treespec_loads(serialized)
+    dummy_tree = python_pytree.tree_unflatten(
+        [0] * orig_treespec.num_leaves,
+        orig_treespec,
+    )
     treespec = tree_structure(dummy_tree)
     return treespec
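treespec_dumps/treespec_loads now route through python_pytree explicitly: a dummy tree is rebuilt from the spec, its Python-side structure is serialized, and loading reverses the trip. A short round-trip sketch, assuming optree is installed:

import torch.utils._cxx_pytree as cxx_pytree

spec = cxx_pytree.tree_structure({"a": [1, 2], "b": (3, None)})
payload = cxx_pytree.treespec_dumps(spec)      # JSON string, produced via python_pytree
restored = cxx_pytree.treespec_loads(payload)  # back to an optree PyTreeSpec
print(restored == spec)                        # expected to round-trip for standard containers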
@@ -1002,6 +994,10 @@ def key_get(obj: Any, kp: KeyPath) -> Any:
     raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.")
 
 
-_pytree._cxx_pytree_imported = True
-for args, kwargs in _pytree._cxx_pytree_pending_imports:
-    _private_register_pytree_node(*args, **kwargs)
+with python_pytree._NODE_REGISTRY_LOCK:
+    python_pytree._cxx_pytree_imported = True
+    args, kwargs = (), {}  # type: ignore[var-annotated]
+    for args, kwargs in python_pytree._cxx_pytree_pending_imports:
+        _private_register_pytree_node(*args, **kwargs)
+    python_pytree._cxx_pytree_pending_imports.clear()
+    del args, kwargs
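The module tail now replays pending registrations under python_pytree's registry lock, so nodes registered through the Python API before torch.utils._cxx_pytree was imported are flushed into optree exactly once. For context, the kind of registration that ends up in the pending list (Point is a hypothetical example class, not part of this change):

import torch.utils._pytree as python_pytree

class Point:
    def __init__(self, x, y):
        self.x, self.y = x, y

# If this runs before torch.utils._cxx_pytree is imported, the registration is
# queued in _cxx_pytree_pending_imports and replayed by the block above on import.
python_pytree.register_pytree_node(
    Point,
    lambda p: ((p.x, p.y), None),
    lambda children, _: Point(*children),
)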