From 89a1fe69667c16b76de1cd87b8dda8ffd77762d3 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Tue, 28 Nov 2023 15:27:07 +0800 Subject: [PATCH] [pytree] register pytree node type in both C++ pytree and Python pytree (#112111) Changes: 1. Add `_private_register_pytree_node` API in both C++ and Python pytree. In C++ pytree, the API will only register pytree node for C++ pytree. In Python pytree, the API will only register pytree node for Python pytree. 2. Do not allow registering a type as pytree node twice in the Python pytree. 3. Add thread lock to the Python pytree node register API. 4. The old `_register_pytree_node` API will call the `_private_register_pytree_node` API and raise a deprecation warning. 5. Add a new `register_pytree_node` API to register node type in both C++ and Python implementations. 6. Add tests to ensure a warning will be raised when the old private function is called. Pull Request resolved: https://github.com/pytorch/pytorch/pull/112111 Approved by: https://github.com/zou3519 --- test/export/test_export.py | 17 +- test/test_fx.py | 2 +- test/test_pytree.py | 71 +++++- torch/_export/utils.py | 4 +- torch/_functorch/aot_autograd.py | 15 +- torch/fx/experimental/proxy_tensor.py | 2 +- torch/fx/immutable_collections.py | 6 +- .../_internal/fx/dynamo_graph_extractor.py | 13 +- torch/return_types.py | 2 +- torch/utils/_cxx_pytree.py | 203 +++++++++++++++++- torch/utils/_pytree.py | 150 ++++++++++--- 11 files changed, 415 insertions(+), 70 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index caa576bfa98..27e44f27aea 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -624,16 +624,23 @@ class TestExport(TestCase): roundtrip_spec = treespec_loads(treespec_dumps(spec)) self.assertEqual(roundtrip_spec, spec) - # Override the registration with keep none fields - register_dataclass_as_pytree_node(MyDataClass, return_none_fields=True, serialized_type_name="test_pytree_regster_data_class.MyDataClass") + @dataclass + class MyOtherDataClass: # the pytree registration don't allow registering the same class twice + x: int + y: int + z: int = None + # Override the registration with keep none fields + register_dataclass_as_pytree_node(MyOtherDataClass, return_none_fields=True, serialized_type_name="test_pytree_regster_data_class.MyOtherDataClass") + + dt = MyOtherDataClass(x=3, y=4) flat, spec = tree_flatten(dt) self.assertEqual( spec, TreeSpec( - MyDataClass, + MyOtherDataClass, ( - MyDataClass, + MyOtherDataClass, ['x', 'y', 'z'], [], ), @@ -643,7 +650,7 @@ class TestExport(TestCase): self.assertEqual(flat, [3, 4, None]) orig_dt = tree_unflatten(flat, spec) - self.assertTrue(isinstance(orig_dt, MyDataClass)) + self.assertTrue(isinstance(orig_dt, MyOtherDataClass)) self.assertEqual(orig_dt.x, 3) self.assertEqual(orig_dt.y, 4) self.assertEqual(orig_dt.z, None) diff --git a/test/test_fx.py b/test/test_fx.py index 8de7c3dd6a9..fa63e79cb46 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -3529,7 +3529,7 @@ class TestFX(JitTestCase): def f_namedtuple_add(x): return x.x + x.y - pytree._register_pytree_node( + pytree.register_pytree_node( Foo, lambda x: ([x.a, x.b], None), lambda x, _: Foo(x[0], x[1]), diff --git a/test/test_pytree.py b/test/test_pytree.py index 0c0120397ee..ab96a9e1f3c 100644 --- a/test/test_pytree.py +++ b/test/test_pytree.py @@ -1,7 +1,7 @@ # Owner(s): ["module: pytree"] import unittest -from collections import namedtuple, OrderedDict +from collections import namedtuple, OrderedDict, UserDict import torch import torch.utils._cxx_pytree as cxx_pytree @@ -26,6 +26,45 @@ class GlobalDummyType: class TestGenericPytree(TestCase): + @parametrize( + "pytree_impl", + [ + subtest(py_pytree, name="py"), + subtest(cxx_pytree, name="cxx"), + ], + ) + def test_register_pytree_node(self, pytree_impl): + class MyDict(UserDict): + pass + + d = MyDict(a=1, b=2, c=3) + + # Custom types are leaf nodes by default + values, spec = pytree_impl.tree_flatten(d) + self.assertEqual(values, [d]) + self.assertIs(values[0], d) + self.assertEqual(d, pytree_impl.tree_unflatten(values, spec)) + self.assertTrue(spec.is_leaf()) + + # Register MyDict as a pytree node + pytree_impl.register_pytree_node( + MyDict, + lambda d: (list(d.values()), list(d.keys())), + lambda values, keys: MyDict(zip(keys, values)), + ) + + values, spec = pytree_impl.tree_flatten(d) + self.assertEqual(values, [1, 2, 3]) + self.assertEqual(d, pytree_impl.tree_unflatten(values, spec)) + + # Do not allow registering the same type twice + with self.assertRaisesRegex(ValueError, "already registered"): + pytree_impl.register_pytree_node( + MyDict, + lambda d: (list(d.values()), list(d.keys())), + lambda values, keys: MyDict(zip(keys, values)), + ) + @parametrize( "pytree_impl", [ @@ -407,6 +446,28 @@ class TestGenericPytree(TestCase): class TestPythonPytree(TestCase): + def test_deprecated_register_pytree_node(self): + class DummyType: + def __init__(self, x, y): + self.x = x + self.y = y + + with self.assertWarnsRegex( + UserWarning, "torch.utils._pytree._register_pytree_node" + ): + py_pytree._register_pytree_node( + DummyType, + lambda dummy: ([dummy.x, dummy.y], None), + lambda xs, _: DummyType(*xs), + ) + + with self.assertWarnsRegex(UserWarning, "already registered"): + py_pytree._register_pytree_node( + DummyType, + lambda dummy: ([dummy.x, dummy.y], None), + lambda xs, _: DummyType(*xs), + ) + def test_treespec_equality(self): self.assertTrue( py_pytree.LeafSpec() == py_pytree.LeafSpec(), @@ -540,7 +601,7 @@ TreeSpec(tuple, None, [*, self.x = x self.y = y - py_pytree._register_pytree_node( + py_pytree.register_pytree_node( DummyType, lambda dummy: ([dummy.x, dummy.y], None), lambda xs, _: DummyType(*xs), @@ -560,7 +621,7 @@ TreeSpec(tuple, None, [*, self.x = x self.y = y - py_pytree._register_pytree_node( + py_pytree.register_pytree_node( DummyType, lambda dummy: ([dummy.x, dummy.y], None), lambda xs, _: DummyType(*xs), @@ -585,7 +646,7 @@ TreeSpec(tuple, None, [*, with self.assertRaisesRegex( ValueError, "Both to_dumpable_context and from_dumpable_context" ): - py_pytree._register_pytree_node( + py_pytree.register_pytree_node( DummyType, lambda dummy: ([dummy.x, dummy.y], None), lambda xs, _: DummyType(*xs), @@ -599,7 +660,7 @@ TreeSpec(tuple, None, [*, self.x = x self.y = y - py_pytree._register_pytree_node( + py_pytree.register_pytree_node( DummyType, lambda dummy: ([dummy.x, dummy.y], None), lambda xs, _: DummyType(*xs), diff --git a/torch/_export/utils.py b/torch/_export/utils.py index afee8efc594..d8344783a0a 100644 --- a/torch/_export/utils.py +++ b/torch/_export/utils.py @@ -63,16 +63,16 @@ def register_dataclass_as_pytree_node( flatten_fn: Optional[FlattenFunc] = None, unflatten_fn: Optional[UnflattenFunc] = None, *, + serialized_type_name: Optional[str] = None, to_dumpable_context: Optional[ToDumpableContextFn] = None, from_dumpable_context: Optional[FromDumpableContextFn] = None, - serialized_type_name: Optional[str] = None, return_none_fields: bool = False, ) -> None: assert dataclasses.is_dataclass( cls ), f"Only dataclasses can be registered with this function: {cls}" - serialized_type = f"{cls.__module__}.{cls.__name__}" + serialized_type = f"{cls.__module__}.{cls.__qualname__}" SERIALIZED_DATACLASS_TO_PYTHON_DATACLASS[serialized_type] = cls def default_flatten_fn(obj: Any) -> Tuple[List[Any], Context]: diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index db83c84e8a6..ff48fd2bb1b 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -29,7 +29,7 @@ from torch._logging import getArtifactLogger from torch._subclasses import FakeTensor, FakeTensorMode from torch._subclasses.fake_tensor import is_fake from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode -from torch.fx import immutable_collections, Interpreter +from torch.fx import Interpreter from torch.fx.experimental.proxy_tensor import is_sym_node, py_sym_types from torch.fx.experimental.symbolic_shapes import ( ShapeEnv, is_concrete_int, fx_placeholder_vals, definitely_true, definitely_false, sym_eq @@ -95,19 +95,6 @@ OutputType = Enum( ) ) -pytree._register_pytree_node( - immutable_collections.immutable_list, - lambda x: (list(x), None), - lambda x, c: immutable_collections.immutable_list(x), -) -pytree._register_pytree_node( - immutable_collections.immutable_dict, - lambda x: (list(x.values()), list(x.keys())), - lambda x, c: immutable_collections.immutable_dict( - dict(zip(c, x)) - ), -) - def partial_asdict(obj: Any) -> Any: if dataclasses.is_dataclass(obj): return {field.name: getattr(obj, field.name) for field in dataclasses.fields(obj)} diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index dd3520f541a..e3d8bd673a4 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -49,7 +49,7 @@ CONSTANT_NUMEL_LIMIT = 1 # We currently convert all SymInt to proxies before we use them. # This could plausibly be handled at the Dynamo level. -pytree._register_pytree_node(torch.Size, lambda x: (list(x), None), lambda xs, _: tuple(xs)) +pytree.register_pytree_node(torch.Size, lambda x: (list(x), None), lambda xs, _: tuple(xs)) def fake_signature(fn, nargs): """FX gets confused by varargs, de-confuse it""" diff --git a/torch/fx/immutable_collections.py b/torch/fx/immutable_collections.py index 616555015f0..a359335f6ec 100644 --- a/torch/fx/immutable_collections.py +++ b/torch/fx/immutable_collections.py @@ -1,7 +1,7 @@ from typing import Any, Dict, Iterable, List, Tuple from ._compatibility import compatibility -from torch.utils._pytree import Context, _register_pytree_node +from torch.utils._pytree import Context, register_pytree_node __all__ = ["immutable_list", "immutable_dict"] @@ -50,5 +50,5 @@ def _immutable_list_unflatten(values: Iterable[Any], context: Context) -> List[A return immutable_list(values) -_register_pytree_node(immutable_dict, _immutable_dict_flatten, _immutable_dict_unflatten) -_register_pytree_node(immutable_list, _immutable_list_flatten, _immutable_list_unflatten) +register_pytree_node(immutable_dict, _immutable_dict_flatten, _immutable_dict_unflatten) +register_pytree_node(immutable_list, _immutable_list_flatten, _immutable_list_unflatten) diff --git a/torch/onnx/_internal/fx/dynamo_graph_extractor.py b/torch/onnx/_internal/fx/dynamo_graph_extractor.py index f55afefd1bb..79a690f5f48 100644 --- a/torch/onnx/_internal/fx/dynamo_graph_extractor.py +++ b/torch/onnx/_internal/fx/dynamo_graph_extractor.py @@ -40,7 +40,11 @@ class _PyTreeExtensionContext: def __enter__(self): for class_type, (flatten_func, unflatten_func) in self._extensions.items(): - pytree._register_pytree_node(class_type, flatten_func, unflatten_func) + pytree._private_register_pytree_node( + class_type, + flatten_func, + unflatten_func, + ) return self def __exit__(self, exc_type, exc_val, exc_tb): @@ -93,8 +97,11 @@ class _PyTreeExtensionContext: # All 'ModelOutput' subclasses are defined under module 'modeling_outputs'. named_model_output_classes = inspect.getmembers( modeling_outputs, - lambda x: inspect.isclass(x) - and issubclass(x, modeling_outputs.ModelOutput), + lambda x: ( + inspect.isclass(x) + and issubclass(x, modeling_outputs.ModelOutput) + and x is not modeling_outputs.ModelOutput + ), ) for _, class_type in named_model_output_classes: diff --git a/torch/return_types.py b/torch/return_types.py index 9f8c8528527..b1284c81338 100644 --- a/torch/return_types.py +++ b/torch/return_types.py @@ -13,7 +13,7 @@ def pytree_register_structseq(cls): def structseq_unflatten(values, context): return cls(values) - torch.utils._pytree._register_pytree_node(cls, structseq_flatten, structseq_unflatten) + torch.utils._pytree.register_pytree_node(cls, structseq_flatten, structseq_unflatten) for name in dir(return_types): if name.startswith('__'): diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py index 06309499ec4..6e55c21a511 100644 --- a/torch/utils/_cxx_pytree.py +++ b/torch/utils/_cxx_pytree.py @@ -13,6 +13,7 @@ collection support for PyTorch APIs. """ import functools +import warnings from typing import ( Any, Callable, @@ -26,6 +27,11 @@ from typing import ( Union, ) +import torch + +if torch._running_with_deploy(): + raise ImportError("C++ pytree utilities do not work with torch::deploy.") + import optree from optree import PyTreeSpec # direct import for type annotations @@ -35,6 +41,9 @@ __all__ = [ "Context", "FlattenFunc", "UnflattenFunc", + "DumpableContext", + "ToDumpableContextFn", + "FromDumpableContextFn", "TreeSpec", "LeafSpec", "register_pytree_node", @@ -68,6 +77,9 @@ TreeSpec = PyTreeSpec FlattenFunc = Callable[[PyTree], Tuple[List[Any], Context]] UnflattenFunc = Callable[[Iterable[Any], Context], PyTree] OpTreeUnflattenFunc = Callable[[Context, Iterable[Any]], PyTree] +DumpableContext = Any # Any json dumpable text +ToDumpableContextFn = Callable[[Context], DumpableContext] +FromDumpableContextFn = Callable[[DumpableContext], Context] def _reverse_args(func: UnflattenFunc) -> OpTreeUnflattenFunc: @@ -84,9 +96,11 @@ def register_pytree_node( unflatten_fn: UnflattenFunc, *, serialized_type_name: Optional[str] = None, + to_dumpable_context: Optional[ToDumpableContextFn] = None, + from_dumpable_context: Optional[FromDumpableContextFn] = None, namespace: str = "torch", ) -> None: - """Extend the set of types that are considered internal nodes in pytrees. + """Register a container-like type as pytree node. The ``namespace`` argument is used to avoid collisions that occur when different libraries register the same Python type with different behaviors. It is recommended to add a unique prefix @@ -109,6 +123,13 @@ def register_pytree_node( The function should return an instance of ``cls``. serialized_type_name (str, optional): A keyword argument used to specify the fully qualified name used when serializing the tree spec. + to_dumpable_context (callable, optional): An optional keyword argument to custom specify how + to convert the context of the pytree to a custom json dumpable representation. This is + used for json serialization, which is being used in :mod:`torch.export` right now. + from_dumpable_context (callable, optional): An optional keyword argument to custom specify + how to convert the custom json dumpable representation of the context back to the + original context. This is used for json deserialization, which is being used in + :mod:`torch.export` right now. namespace (str, optional): A non-empty string that uniquely identifies the namespace of the type registry. This is used to isolate the registry from other modules that might register a different custom behavior for the same type. (default: :const:`"torch"`) @@ -193,24 +214,192 @@ def register_pytree_node( ) ) """ - from ._pytree import _register_pytree_node - - _register_pytree_node( + _private_register_pytree_node( cls, flatten_fn, unflatten_fn, serialized_type_name=serialized_type_name, + to_dumpable_context=to_dumpable_context, + from_dumpable_context=from_dumpable_context, + namespace=namespace, ) - optree.register_pytree_node( + from . import _pytree as python + + python._private_register_pytree_node( cls, flatten_fn, - _reverse_args(unflatten_fn), + unflatten_fn, + serialized_type_name=serialized_type_name, + to_dumpable_context=to_dumpable_context, + from_dumpable_context=from_dumpable_context, + ) + + +def _register_pytree_node( + cls: Type[Any], + flatten_fn: FlattenFunc, + unflatten_fn: UnflattenFunc, + *, + serialized_type_name: Optional[str] = None, + to_dumpable_context: Optional[ToDumpableContextFn] = None, + from_dumpable_context: Optional[FromDumpableContextFn] = None, + namespace: str = "torch", +) -> None: + """Register a container-like type as pytree node for the C++ pytree only. + + The ``namespace`` argument is used to avoid collisions that occur when different libraries + register the same Python type with different behaviors. It is recommended to add a unique prefix + to the namespace to avoid conflicts with other libraries. Namespaces can also be used to specify + the same class in different namespaces for different use cases. + + .. warning:: + For safety reasons, a ``namespace`` must be specified while registering a custom type. It is + used to isolate the behavior of flattening and unflattening a pytree node type. This is to + prevent accidental collisions between different libraries that may register the same type. + + Args: + cls (type): A Python type to treat as an internal pytree node. + flatten_fn (callable): A function to be used during flattening, taking an instance of + ``cls`` and returning a pair, with (1) an iterable for the children to be flattened + recursively, and (2) some hashable auxiliary data to be stored in the treespec and to be + passed to the ``unflatten_fn``. + unflatten_fn (callable): A function taking two arguments: the auxiliary data that was + returned by ``flatten_fn`` and stored in the treespec, and the unflattened children. + The function should return an instance of ``cls``. + serialized_type_name (str, optional): A keyword argument used to specify the fully + qualified name used when serializing the tree spec. + to_dumpable_context (callable, optional): An optional keyword argument to custom specify how + to convert the context of the pytree to a custom json dumpable representation. This is + used for json serialization, which is being used in :mod:`torch.export` right now. + from_dumpable_context (callable, optional): An optional keyword argument to custom specify + how to convert the custom json dumpable representation of the context back to the + original context. This is used for json deserialization, which is being used in + :mod:`torch.export` right now. + namespace (str, optional): A non-empty string that uniquely identifies the namespace of the + type registry. This is used to isolate the registry from other modules that might + register a different custom behavior for the same type. (default: :const:`"torch"`) + + Example:: + + >>> # xdoctest: +SKIP + >>> # Registry a Python type with lambda functions + >>> register_pytree_node( + ... set, + ... lambda s: (sorted(s), None, None), + ... lambda children, _: set(children), + ... namespace='set', + ... ) + + >>> # xdoctest: +SKIP + >>> # Register a Python type into a namespace + >>> import torch + >>> register_pytree_node( + ... torch.Tensor, + ... flatten_func=lambda tensor: ( + ... (tensor.cpu().detach().numpy(),), + ... {'dtype': tensor.dtype, 'device': tensor.device, 'requires_grad': tensor.requires_grad}, + ... ), + ... unflatten_func=lambda children, metadata: torch.tensor(children[0], **metadata), + ... namespace='torch2numpy', + ... ) + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> tree = {'weight': torch.ones(size=(1, 2)).cuda(), 'bias': torch.zeros(size=(2,))} + >>> tree + {'weight': tensor([[1., 1.]], device='cuda:0'), 'bias': tensor([0., 0.])} + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> # Flatten without specifying the namespace + >>> tree_flatten(tree) # `torch.Tensor`s are leaf nodes # xdoctest: +SKIP + ([tensor([0., 0.]), tensor([[1., 1.]], device='cuda:0')], PyTreeSpec({'bias': *, 'weight': *})) + + >>> # xdoctest: +SKIP + >>> # Flatten with the namespace + >>> tree_flatten(tree, namespace='torch2numpy') # xdoctest: +SKIP + ( + [array([0., 0.], dtype=float32), array([[1., 1.]], dtype=float32)], + PyTreeSpec( + { + 'bias': CustomTreeNode(Tensor[{'dtype': torch.float32, ...}], [*]), + 'weight': CustomTreeNode(Tensor[{'dtype': torch.float32, ...}], [*]) + }, + namespace='torch2numpy' + ) + ) + + >>> # xdoctest: +SKIP + >>> # Register the same type with a different namespace for different behaviors + >>> def tensor2flatparam(tensor): + ... return [torch.nn.Parameter(tensor.reshape(-1))], tensor.shape, None + ... + >>> def flatparam2tensor(children, metadata): + ... return children[0].reshape(metadata) + ... + >>> register_pytree_node( + ... torch.Tensor, + ... flatten_func=tensor2flatparam, + ... unflatten_func=flatparam2tensor, + ... namespace='tensor2flatparam', + ... ) + + >>> # xdoctest: +SKIP + >>> # Flatten with the new namespace + >>> tree_flatten(tree, namespace='tensor2flatparam') # xdoctest: +SKIP + ( + [ + Parameter containing: tensor([0., 0.], requires_grad=True), + Parameter containing: tensor([1., 1.], device='cuda:0', requires_grad=True) + ], + PyTreeSpec( + { + 'bias': CustomTreeNode(Tensor[torch.Size([2])], [*]), + 'weight': CustomTreeNode(Tensor[torch.Size([1, 2])], [*]) + }, + namespace='tensor2flatparam' + ) + ) + """ + warnings.warn( + "torch.utils._cxx_pytree._register_pytree_node is deprecated. " + "Please use torch.utils._cxx_pytree.register_pytree_node instead.", + stacklevel=2, + ) + + _private_register_pytree_node( + cls, + flatten_fn, + unflatten_fn, + serialized_type_name=serialized_type_name, + to_dumpable_context=to_dumpable_context, + from_dumpable_context=from_dumpable_context, namespace=namespace, ) -_register_pytree_node = register_pytree_node +def _private_register_pytree_node( + cls: Type[Any], + flatten_fn: FlattenFunc, + unflatten_fn: UnflattenFunc, + *, + serialized_type_name: Optional[str] = None, + to_dumpable_context: Optional[ToDumpableContextFn] = None, + from_dumpable_context: Optional[FromDumpableContextFn] = None, + namespace: str = "torch", +) -> None: + """This is an internal function that is used to register a pytree node type + for the C++ pytree only. End-users should use :func:`register_pytree_node` + instead. + """ + # TODO(XuehaiPan): remove this condition when we make Python pytree out-of-box support + # PyStructSequence types + if not optree.is_structseq_class(cls): + optree.register_pytree_node( + cls, + flatten_fn, + _reverse_args(unflatten_fn), + namespace=namespace, + ) def tree_flatten( diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index 6821a3acb49..4e085121ef4 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -17,6 +17,7 @@ To improve the performance we can move parts of the implementation to C++. import dataclasses import json +import threading import warnings from collections import deque, namedtuple, OrderedDict from typing import ( @@ -99,6 +100,7 @@ class NodeDef(NamedTuple): unflatten_fn: UnflattenFunc +_NODE_REGISTRY_LOCK = threading.Lock() SUPPORTED_NODES: Dict[Type[Any], NodeDef] = {} @@ -120,18 +122,17 @@ SUPPORTED_SERIALIZED_TYPES: Dict[Type[Any], _SerializeNodeDef] = {} SERIALIZED_TYPE_TO_PYTHON_TYPE: Dict[str, Type[Any]] = {} -def _register_pytree_node( +def register_pytree_node( cls: Any, flatten_fn: FlattenFunc, unflatten_fn: UnflattenFunc, - to_str_fn: Optional[ToStrFunc] = None, # deprecated - maybe_from_str_fn: Optional[MaybeFromStrFunc] = None, # deprecated *, serialized_type_name: Optional[str] = None, to_dumpable_context: Optional[ToDumpableContextFn] = None, from_dumpable_context: Optional[FromDumpableContextFn] = None, ) -> None: - """ + """Register a container-like type as pytree node. + Args: cls: the type to register flatten_fn: A callable that takes a pytree and returns a flattened @@ -150,39 +151,132 @@ def _register_pytree_node( back to the original context. This is used for json deserialization, which is being used in torch.export right now. """ + with _NODE_REGISTRY_LOCK: + if cls in SUPPORTED_NODES: + raise ValueError(f"{cls} is already registered as pytree node.") + + _private_register_pytree_node( + cls, + flatten_fn, + unflatten_fn, + serialized_type_name=serialized_type_name, + to_dumpable_context=to_dumpable_context, + from_dumpable_context=from_dumpable_context, + ) + + try: + from . import _cxx_pytree as cxx + except ImportError: + pass + else: + cxx._private_register_pytree_node( + cls, + flatten_fn, + unflatten_fn, + serialized_type_name=serialized_type_name, + to_dumpable_context=to_dumpable_context, + from_dumpable_context=from_dumpable_context, + ) + + +def _register_pytree_node( + cls: Any, + flatten_fn: FlattenFunc, + unflatten_fn: UnflattenFunc, + to_str_fn: Optional[ToStrFunc] = None, # deprecated + maybe_from_str_fn: Optional[MaybeFromStrFunc] = None, # deprecated + *, + serialized_type_name: Optional[str] = None, + to_dumpable_context: Optional[ToDumpableContextFn] = None, + from_dumpable_context: Optional[FromDumpableContextFn] = None, +) -> None: + """Register a container-like type as pytree node for the Python pytree only. + + Args: + cls: the type to register + flatten_fn: A callable that takes a pytree and returns a flattened + representation of the pytree and additional context to represent the + flattened pytree. + unflatten_fn: A callable that takes a flattened version of the pytree, + additional context, and returns an unflattened pytree. + serialized_type_name: A keyword argument used to specify the fully qualified + name used when serializing the tree spec. + to_dumpable_context: An optional keyword argument to custom specify how + to convert the context of the pytree to a custom json dumpable + representation. This is used for json serialization, which is being + used in torch.export right now. + from_dumpable_context: An optional keyword argument to custom specify how + to convert the custom json dumpable representation of the context + back to the original context. This is used for json deserialization, + which is being used in torch.export right now. + """ + warnings.warn( + "torch.utils._pytree._register_pytree_node is deprecated. " + "Please use torch.utils._pytree.register_pytree_node instead.", + stacklevel=2, + ) + if to_str_fn is not None or maybe_from_str_fn is not None: warnings.warn( "to_str_fn and maybe_from_str_fn is deprecated. " "Please use to_dumpable_context and from_dumpable_context instead." ) - node_def = NodeDef( + _private_register_pytree_node( cls, flatten_fn, unflatten_fn, + serialized_type_name=serialized_type_name, + to_dumpable_context=to_dumpable_context, + from_dumpable_context=from_dumpable_context, ) - SUPPORTED_NODES[cls] = node_def - if (to_dumpable_context is None) ^ (from_dumpable_context is None): - raise ValueError( - f"Both to_dumpable_context and from_dumpable_context for {cls} must " - "be None or registered." + +def _private_register_pytree_node( + cls: Any, + flatten_fn: FlattenFunc, + unflatten_fn: UnflattenFunc, + *, + serialized_type_name: Optional[str] = None, + to_dumpable_context: Optional[ToDumpableContextFn] = None, + from_dumpable_context: Optional[FromDumpableContextFn] = None, +) -> None: + """This is an internal function that is used to register a pytree node type + for the Python pytree only. End-users should use :func:`register_pytree_node` + instead. + """ + with _NODE_REGISTRY_LOCK: + if cls in SUPPORTED_NODES: + # TODO: change this warning to an error after OSS/internal stabilize + warnings.warn( + f"{cls} is already registered as pytree node. " + "Overwriting the previous registration.", + ) + + node_def = NodeDef( + cls, + flatten_fn, + unflatten_fn, ) + SUPPORTED_NODES[cls] = node_def - if serialized_type_name is None: - serialized_type_name = f"{cls.__module__}.{cls.__name__}" + if (to_dumpable_context is None) ^ (from_dumpable_context is None): + raise ValueError( + f"Both to_dumpable_context and from_dumpable_context for {cls} must " + "be None or registered." + ) - serialize_node_def = _SerializeNodeDef( - cls, - serialized_type_name, - to_dumpable_context, - from_dumpable_context, - ) - SUPPORTED_SERIALIZED_TYPES[cls] = serialize_node_def - SERIALIZED_TYPE_TO_PYTHON_TYPE[serialized_type_name] = cls + if serialized_type_name is None: + serialized_type_name = f"{cls.__module__}.{cls.__qualname__}" - -register_pytree_node = _register_pytree_node + serialize_node_def = _SerializeNodeDef( + cls, + serialized_type_name, + to_dumpable_context, + from_dumpable_context, + ) + SUPPORTED_SERIALIZED_TYPES[cls] = serialize_node_def + SERIALIZED_TYPE_TO_PYTHON_TYPE[serialized_type_name] = cls def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]: @@ -243,25 +337,25 @@ def _odict_unflatten( return OrderedDict((key, value) for key, value in zip(context, values)) -_register_pytree_node( +_private_register_pytree_node( dict, _dict_flatten, _dict_unflatten, serialized_type_name="builtins.dict", ) -_register_pytree_node( +_private_register_pytree_node( list, _list_flatten, _list_unflatten, serialized_type_name="builtins.list", ) -_register_pytree_node( +_private_register_pytree_node( tuple, _tuple_flatten, _tuple_unflatten, serialized_type_name="builtins.tuple", ) -_register_pytree_node( +_private_register_pytree_node( namedtuple, _namedtuple_flatten, _namedtuple_unflatten, @@ -269,7 +363,7 @@ _register_pytree_node( from_dumpable_context=_namedtuple_deserialize, serialized_type_name="collections.namedtuple", ) -_register_pytree_node( +_private_register_pytree_node( OrderedDict, _odict_flatten, _odict_unflatten, @@ -729,7 +823,7 @@ def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema: if treespec.type not in SUPPORTED_SERIALIZED_TYPES: raise NotImplementedError( - f"Serializing {treespec.type} in pytree is not registered." + f"Serializing {treespec.type} in pytree is not registered.", ) serialize_node_def = SUPPORTED_SERIALIZED_TYPES[treespec.type]