Revert "[pytree] register pytree node type in both C++ pytree and Python pytree (#112111)"

This reverts commit a0d00349ed.

Reverted https://github.com/pytorch/pytorch/pull/112111 on behalf of https://github.com/PaliC due to _private_register_pytree_node now checks for duplicate registering, unfortunately, this breaks composability with torchrec internally :(  ([comment](https://github.com/pytorch/pytorch/pull/112111#issuecomment-1806130993))
This commit is contained in:
PyTorch MergeBot 2023-11-10 17:24:40 +00:00
parent 6e714d7315
commit 2a271a3efa
6 changed files with 35 additions and 129 deletions

View File

@ -622,23 +622,16 @@ class TestExport(TestCase):
roundtrip_spec = treespec_loads(treespec_dumps(spec))
self.assertEqual(roundtrip_spec, spec)
@dataclass
class MyOtherDataClass: # the pytree registration don't allow registering the same class twice
x: int
y: int
z: int = None
# Override the registration with keep none fields
register_dataclass_as_pytree_node(MyOtherDataClass, return_none_fields=True, serialized_type_name="test_pytree_regster_data_class.MyOtherDataClass")
register_dataclass_as_pytree_node(MyDataClass, return_none_fields=True, serialized_type_name="test_pytree_regster_data_class.MyDataClass")
dt = MyOtherDataClass(x=3, y=4)
flat, spec = tree_flatten(dt)
self.assertEqual(
spec,
TreeSpec(
MyOtherDataClass,
MyDataClass,
(
MyOtherDataClass,
MyDataClass,
['x', 'y', 'z'],
[],
),
@ -648,7 +641,7 @@ class TestExport(TestCase):
self.assertEqual(flat, [3, 4, None])
orig_dt = tree_unflatten(flat, spec)
self.assertTrue(isinstance(orig_dt, MyOtherDataClass))
self.assertTrue(isinstance(orig_dt, MyDataClass))
self.assertEqual(orig_dt.x, 3)
self.assertEqual(orig_dt.y, 4)
self.assertEqual(orig_dt.z, None)

View File

@ -25,16 +25,16 @@ def register_dataclass_as_pytree_node(
flatten_fn: Optional[FlattenFunc] = None,
unflatten_fn: Optional[UnflattenFunc] = None,
*,
serialized_type_name: Optional[str] = None,
to_dumpable_context: Optional[ToDumpableContextFn] = None,
from_dumpable_context: Optional[FromDumpableContextFn] = None,
serialized_type_name: Optional[str] = None,
return_none_fields: bool = False,
) -> None:
assert dataclasses.is_dataclass(
cls
), f"Only dataclasses can be registered with this function: {cls}"
serialized_type = f"{cls.__module__}.{cls.__qualname__}"
serialized_type = f"{cls.__module__}.{cls.__name__}"
SERIALIZED_DATACLASS_TO_PYTHON_DATACLASS[serialized_type] = cls
def default_flatten_fn(obj: Any) -> Tuple[List[Any], Context]:

View File

@ -29,7 +29,7 @@ from torch._logging import getArtifactLogger
from torch._subclasses import FakeTensor, FakeTensorMode
from torch._subclasses.fake_tensor import is_fake
from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode
from torch.fx import Interpreter
from torch.fx import immutable_collections, Interpreter
from torch.fx.experimental.proxy_tensor import is_sym_node, py_sym_types
from torch.fx.experimental.symbolic_shapes import (
ShapeEnv, is_concrete_int, fx_placeholder_vals, definitely_true, definitely_false, sym_eq
@ -95,6 +95,19 @@ OutputType = Enum(
)
)
pytree._register_pytree_node(
immutable_collections.immutable_list,
lambda x: (list(x), None),
lambda x, c: immutable_collections.immutable_list(x),
)
pytree._register_pytree_node(
immutable_collections.immutable_dict,
lambda x: (list(x.values()), list(x.keys())),
lambda x, c: immutable_collections.immutable_dict(
dict(zip(c, x))
),
)
def partial_asdict(obj: Any) -> Any:
if dataclasses.is_dataclass(obj):
return {field.name: getattr(obj, field.name) for field in dataclasses.fields(obj)}

View File

@ -93,11 +93,8 @@ class _PyTreeExtensionContext:
# All 'ModelOutput' subclasses are defined under module 'modeling_outputs'.
named_model_output_classes = inspect.getmembers(
modeling_outputs,
lambda x: (
inspect.isclass(x)
and issubclass(x, modeling_outputs.ModelOutput)
and x is not modeling_outputs.ModelOutput
),
lambda x: inspect.isclass(x)
and issubclass(x, modeling_outputs.ModelOutput),
)
for _, class_type in named_model_output_classes:

View File

@ -26,11 +26,6 @@ from typing import (
Union,
)
import torch
if torch._running_with_deploy():
raise ImportError("C++ pytree utilities do not work with torch::deploy.")
import optree
from optree import PyTreeSpec # direct import for type annotations
@ -40,9 +35,6 @@ __all__ = [
"Context",
"FlattenFunc",
"UnflattenFunc",
"DumpableContext",
"ToDumpableContextFn",
"FromDumpableContextFn",
"TreeSpec",
"LeafSpec",
"register_pytree_node",
@ -76,9 +68,6 @@ TreeSpec = PyTreeSpec
FlattenFunc = Callable[[PyTree], Tuple[List, Context]]
UnflattenFunc = Callable[[Iterable, Context], PyTree]
OpTreeUnflattenFunc = Callable[[Context, Iterable], PyTree]
DumpableContext = Any # Any json dumpable text
ToDumpableContextFn = Callable[[Context], DumpableContext]
FromDumpableContextFn = Callable[[DumpableContext], Context]
def _reverse_args(func: UnflattenFunc) -> OpTreeUnflattenFunc:
@ -95,8 +84,6 @@ def register_pytree_node(
unflatten_fn: UnflattenFunc,
*,
serialized_type_name: Optional[str] = None,
to_dumpable_context: Optional[ToDumpableContextFn] = None,
from_dumpable_context: Optional[FromDumpableContextFn] = None,
namespace: str = "torch",
) -> None:
"""Extend the set of types that are considered internal nodes in pytrees.
@ -122,13 +109,6 @@ def register_pytree_node(
The function should return an instance of ``cls``.
serialized_type_name (str, optional): A keyword argument used to specify the fully
qualified name used when serializing the tree spec.
to_dumpable_context (callable, optional): An optional keyword argument to custom specify how
to convert the context of the pytree to a custom json dumpable representation. This is
used for json serialization, which is being used in :mod:`torch.export` right now.
from_dumpable_context (callable, optional): An optional keyword argument to custom specify
how to convert the custom json dumpable representation of the context back to the
original context. This is used for json deserialization, which is being used in
:mod:`torch.export` right now.
namespace (str, optional): A non-empty string that uniquely identifies the namespace of the
type registry. This is used to isolate the registry from other modules that might
register a different custom behavior for the same type. (default: :const:`"torch"`)
@ -213,56 +193,26 @@ def register_pytree_node(
)
)
"""
_private_register_pytree_node(
from ._pytree import _register_pytree_node
_register_pytree_node(
cls,
flatten_fn,
unflatten_fn,
serialized_type_name=serialized_type_name,
to_dumpable_context=to_dumpable_context,
from_dumpable_context=from_dumpable_context,
namespace=namespace,
)
from . import _pytree as python
python._private_register_pytree_node(
optree.register_pytree_node(
cls,
flatten_fn,
unflatten_fn,
serialized_type_name=serialized_type_name,
to_dumpable_context=to_dumpable_context,
from_dumpable_context=from_dumpable_context,
_reverse_args(unflatten_fn),
namespace=namespace,
)
_register_pytree_node = register_pytree_node
def _private_register_pytree_node(
cls: Type[Any],
flatten_fn: FlattenFunc,
unflatten_fn: UnflattenFunc,
*,
serialized_type_name: Optional[str] = None,
to_dumpable_context: Optional[ToDumpableContextFn] = None,
from_dumpable_context: Optional[FromDumpableContextFn] = None,
namespace: str = "torch",
) -> None:
"""This is an internal function that is used to register a pytree node type
for the C++ pytree only. End-users should use :func:`register_pytree_node`
instead.
"""
# TODO(XuehaiPan): remove this condition when we make Python pytree out-of-box support
# PyStructSequence types
if not optree.is_structseq_class(cls):
optree.register_pytree_node(
cls,
flatten_fn,
_reverse_args(unflatten_fn),
namespace=namespace,
)
def tree_flatten(
tree: PyTree,
*,

View File

@ -150,59 +150,12 @@ def _register_pytree_node(
back to the original context. This is used for json deserialization,
which is being used in torch.export right now.
"""
_private_register_pytree_node(
cls,
flatten_fn,
unflatten_fn,
to_str_fn=to_str_fn, # deprecated
maybe_from_str_fn=maybe_from_str_fn, # deprecated
serialized_type_name=serialized_type_name,
to_dumpable_context=to_dumpable_context,
from_dumpable_context=from_dumpable_context,
)
try:
from . import _cxx_pytree as cxx
except ImportError:
pass
else:
cxx._private_register_pytree_node(
cls,
flatten_fn,
unflatten_fn,
serialized_type_name=serialized_type_name,
to_dumpable_context=to_dumpable_context,
from_dumpable_context=from_dumpable_context,
)
register_pytree_node = _register_pytree_node
def _private_register_pytree_node(
cls: Any,
flatten_fn: FlattenFunc,
unflatten_fn: UnflattenFunc,
to_str_fn: Optional[ToStrFunc] = None, # deprecated
maybe_from_str_fn: Optional[MaybeFromStrFunc] = None, # deprecated
*,
serialized_type_name: Optional[str] = None,
to_dumpable_context: Optional[ToDumpableContextFn] = None,
from_dumpable_context: Optional[FromDumpableContextFn] = None,
) -> None:
"""This is an internal function that is used to register a pytree node type
for the Python pytree only. End-users should use :func:`register_pytree_node`
instead.
"""
if to_str_fn is not None or maybe_from_str_fn is not None:
warnings.warn(
"to_str_fn and maybe_from_str_fn is deprecated. "
"Please use to_dumpable_context and from_dumpable_context instead."
)
if cls in SUPPORTED_NODES:
raise ValueError(f"{cls} is already registered as pytree node.")
node_def = NodeDef(
cls,
flatten_fn,
@ -217,7 +170,7 @@ def _private_register_pytree_node(
)
if serialized_type_name is None:
serialized_type_name = f"{cls.__module__}.{cls.__qualname__}"
serialized_type_name = f"{cls.__module__}.{cls.__name__}"
serialize_node_def = _SerializeNodeDef(
cls,
@ -290,25 +243,25 @@ def _odict_unflatten(
return OrderedDict((key, value) for key, value in zip(context, values))
_private_register_pytree_node(
_register_pytree_node(
dict,
_dict_flatten,
_dict_unflatten,
serialized_type_name="builtins.dict",
)
_private_register_pytree_node(
_register_pytree_node(
list,
_list_flatten,
_list_unflatten,
serialized_type_name="builtins.list",
)
_private_register_pytree_node(
_register_pytree_node(
tuple,
_tuple_flatten,
_tuple_unflatten,
serialized_type_name="builtins.tuple",
)
_private_register_pytree_node(
_register_pytree_node(
namedtuple,
_namedtuple_flatten,
_namedtuple_unflatten,
@ -316,7 +269,7 @@ _private_register_pytree_node(
from_dumpable_context=_namedtuple_deserialize,
serialized_type_name="collections.namedtuple",
)
_private_register_pytree_node(
_register_pytree_node(
OrderedDict,
_odict_flatten,
_odict_unflatten,
@ -776,7 +729,7 @@ def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema:
if treespec.type not in SUPPORTED_SERIALIZED_TYPES:
raise NotImplementedError(
f"Serializing {treespec.type} in pytree is not registered.",
f"Serializing {treespec.type} in pytree is not registered."
)
serialize_node_def = SUPPORTED_SERIALIZED_TYPES[treespec.type]