mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[pytree] register pytree node type in both C++ pytree and Python pytree (#112111)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112111 Approved by: https://github.com/zou3519
This commit is contained in:
parent
5e2adc8650
commit
a0d00349ed
|
|
@ -622,16 +622,23 @@ class TestExport(TestCase):
|
|||
roundtrip_spec = treespec_loads(treespec_dumps(spec))
|
||||
self.assertEqual(roundtrip_spec, spec)
|
||||
|
||||
# Override the registration with keep none fields
|
||||
register_dataclass_as_pytree_node(MyDataClass, return_none_fields=True, serialized_type_name="test_pytree_regster_data_class.MyDataClass")
|
||||
@dataclass
|
||||
class MyOtherDataClass: # the pytree registration doesn't allow registering the same class twice
|
||||
x: int
|
||||
y: int
|
||||
z: int = None
|
||||
|
||||
# Override the registration with keep none fields
|
||||
register_dataclass_as_pytree_node(MyOtherDataClass, return_none_fields=True, serialized_type_name="test_pytree_regster_data_class.MyOtherDataClass")
|
||||
|
||||
dt = MyOtherDataClass(x=3, y=4)
|
||||
flat, spec = tree_flatten(dt)
|
||||
self.assertEqual(
|
||||
spec,
|
||||
TreeSpec(
|
||||
MyDataClass,
|
||||
MyOtherDataClass,
|
||||
(
|
||||
MyDataClass,
|
||||
MyOtherDataClass,
|
||||
['x', 'y', 'z'],
|
||||
[],
|
||||
),
|
||||
|
|
@ -641,7 +648,7 @@ class TestExport(TestCase):
|
|||
self.assertEqual(flat, [3, 4, None])
|
||||
|
||||
orig_dt = tree_unflatten(flat, spec)
|
||||
self.assertTrue(isinstance(orig_dt, MyDataClass))
|
||||
self.assertTrue(isinstance(orig_dt, MyOtherDataClass))
|
||||
self.assertEqual(orig_dt.x, 3)
|
||||
self.assertEqual(orig_dt.y, 4)
|
||||
self.assertEqual(orig_dt.z, None)
|
||||
|
|
|
|||
|
|
@ -25,16 +25,16 @@ def register_dataclass_as_pytree_node(
|
|||
flatten_fn: Optional[FlattenFunc] = None,
|
||||
unflatten_fn: Optional[UnflattenFunc] = None,
|
||||
*,
|
||||
serialized_type_name: Optional[str] = None,
|
||||
to_dumpable_context: Optional[ToDumpableContextFn] = None,
|
||||
from_dumpable_context: Optional[FromDumpableContextFn] = None,
|
||||
serialized_type_name: Optional[str] = None,
|
||||
return_none_fields: bool = False,
|
||||
) -> None:
|
||||
assert dataclasses.is_dataclass(
|
||||
cls
|
||||
), f"Only dataclasses can be registered with this function: {cls}"
|
||||
|
||||
serialized_type = f"{cls.__module__}.{cls.__name__}"
|
||||
serialized_type = f"{cls.__module__}.{cls.__qualname__}"
|
||||
SERIALIZED_DATACLASS_TO_PYTHON_DATACLASS[serialized_type] = cls
|
||||
|
||||
def default_flatten_fn(obj: Any) -> Tuple[List[Any], Context]:
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ from torch._logging import getArtifactLogger
|
|||
from torch._subclasses import FakeTensor, FakeTensorMode
|
||||
from torch._subclasses.fake_tensor import is_fake
|
||||
from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode
|
||||
from torch.fx import immutable_collections, Interpreter
|
||||
from torch.fx import Interpreter
|
||||
from torch.fx.experimental.proxy_tensor import is_sym_node, py_sym_types
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
ShapeEnv, is_concrete_int, fx_placeholder_vals, definitely_true, definitely_false, sym_eq
|
||||
|
|
@ -95,19 +95,6 @@ OutputType = Enum(
|
|||
)
|
||||
)
|
||||
|
||||
pytree._register_pytree_node(
|
||||
immutable_collections.immutable_list,
|
||||
lambda x: (list(x), None),
|
||||
lambda x, c: immutable_collections.immutable_list(x),
|
||||
)
|
||||
pytree._register_pytree_node(
|
||||
immutable_collections.immutable_dict,
|
||||
lambda x: (list(x.values()), list(x.keys())),
|
||||
lambda x, c: immutable_collections.immutable_dict(
|
||||
dict(zip(c, x))
|
||||
),
|
||||
)
|
||||
|
||||
def partial_asdict(obj: Any) -> Any:
|
||||
if dataclasses.is_dataclass(obj):
|
||||
return {field.name: getattr(obj, field.name) for field in dataclasses.fields(obj)}
|
||||
|
|
|
|||
|
|
@ -93,8 +93,11 @@ class _PyTreeExtensionContext:
|
|||
# All 'ModelOutput' subclasses are defined under module 'modeling_outputs'.
|
||||
named_model_output_classes = inspect.getmembers(
|
||||
modeling_outputs,
|
||||
lambda x: inspect.isclass(x)
|
||||
and issubclass(x, modeling_outputs.ModelOutput),
|
||||
lambda x: (
|
||||
inspect.isclass(x)
|
||||
and issubclass(x, modeling_outputs.ModelOutput)
|
||||
and x is not modeling_outputs.ModelOutput
|
||||
),
|
||||
)
|
||||
|
||||
for _, class_type in named_model_output_classes:
|
||||
|
|
|
|||
|
|
@ -26,6 +26,11 @@ from typing import (
|
|||
Union,
|
||||
)
|
||||
|
||||
import torch
|
||||
|
||||
if torch._running_with_deploy():
|
||||
raise ImportError("C++ pytree utilities do not work with torch::deploy.")
|
||||
|
||||
import optree
|
||||
from optree import PyTreeSpec # direct import for type annotations
|
||||
|
||||
|
|
@ -35,6 +40,9 @@ __all__ = [
|
|||
"Context",
|
||||
"FlattenFunc",
|
||||
"UnflattenFunc",
|
||||
"DumpableContext",
|
||||
"ToDumpableContextFn",
|
||||
"FromDumpableContextFn",
|
||||
"TreeSpec",
|
||||
"LeafSpec",
|
||||
"register_pytree_node",
|
||||
|
|
@ -68,6 +76,9 @@ TreeSpec = PyTreeSpec
|
|||
FlattenFunc = Callable[[PyTree], Tuple[List, Context]]
|
||||
UnflattenFunc = Callable[[Iterable, Context], PyTree]
|
||||
OpTreeUnflattenFunc = Callable[[Context, Iterable], PyTree]
|
||||
DumpableContext = Any # Any json dumpable text
|
||||
ToDumpableContextFn = Callable[[Context], DumpableContext]
|
||||
FromDumpableContextFn = Callable[[DumpableContext], Context]
|
||||
|
||||
|
||||
def _reverse_args(func: UnflattenFunc) -> OpTreeUnflattenFunc:
|
||||
|
|
@ -84,6 +95,8 @@ def register_pytree_node(
|
|||
unflatten_fn: UnflattenFunc,
|
||||
*,
|
||||
serialized_type_name: Optional[str] = None,
|
||||
to_dumpable_context: Optional[ToDumpableContextFn] = None,
|
||||
from_dumpable_context: Optional[FromDumpableContextFn] = None,
|
||||
namespace: str = "torch",
|
||||
) -> None:
|
||||
"""Extend the set of types that are considered internal nodes in pytrees.
|
||||
|
|
@ -109,6 +122,13 @@ def register_pytree_node(
|
|||
The function should return an instance of ``cls``.
|
||||
serialized_type_name (str, optional): A keyword argument used to specify the fully
|
||||
qualified name used when serializing the tree spec.
|
||||
to_dumpable_context (callable, optional): An optional keyword argument to specify how
|
||||
to convert the context of the pytree to a custom json dumpable representation. This is
|
||||
used for json serialization, which is being used in :mod:`torch.export` right now.
|
||||
from_dumpable_context (callable, optional): An optional keyword argument to specify
|
||||
how to convert the custom json dumpable representation of the context back to the
|
||||
original context. This is used for json deserialization, which is being used in
|
||||
:mod:`torch.export` right now.
|
||||
namespace (str, optional): A non-empty string that uniquely identifies the namespace of the
|
||||
type registry. This is used to isolate the registry from other modules that might
|
||||
register a different custom behavior for the same type. (default: :const:`"torch"`)
|
||||
|
|
@ -193,26 +213,56 @@ def register_pytree_node(
|
|||
)
|
||||
)
|
||||
"""
|
||||
from ._pytree import _register_pytree_node
|
||||
|
||||
_register_pytree_node(
|
||||
_private_register_pytree_node(
|
||||
cls,
|
||||
flatten_fn,
|
||||
unflatten_fn,
|
||||
serialized_type_name=serialized_type_name,
|
||||
to_dumpable_context=to_dumpable_context,
|
||||
from_dumpable_context=from_dumpable_context,
|
||||
namespace=namespace,
|
||||
)
|
||||
|
||||
optree.register_pytree_node(
|
||||
from . import _pytree as python
|
||||
|
||||
python._private_register_pytree_node(
|
||||
cls,
|
||||
flatten_fn,
|
||||
_reverse_args(unflatten_fn),
|
||||
namespace=namespace,
|
||||
unflatten_fn,
|
||||
serialized_type_name=serialized_type_name,
|
||||
to_dumpable_context=to_dumpable_context,
|
||||
from_dumpable_context=from_dumpable_context,
|
||||
)
|
||||
|
||||
|
||||
_register_pytree_node = register_pytree_node
|
||||
|
||||
|
||||
def _private_register_pytree_node(
    cls: Type[Any],
    flatten_fn: FlattenFunc,
    unflatten_fn: UnflattenFunc,
    *,
    serialized_type_name: Optional[str] = None,
    to_dumpable_context: Optional[ToDumpableContextFn] = None,
    from_dumpable_context: Optional[FromDumpableContextFn] = None,
    namespace: str = "torch",
) -> None:
    """Register ``cls`` as a pytree node type for the C++ pytree only.

    This is an internal helper; end-users should call
    :func:`register_pytree_node` instead.

    Args:
        cls: The type to register as an internal pytree node.
        flatten_fn: Callable handed to ``optree`` unchanged.
        unflatten_fn: Callable taking ``(children, context)``; it is wrapped
            with ``_reverse_args`` before being handed to ``optree``.
        serialized_type_name: Accepted for signature parity with the Python
            pytree registration; not used by this function.
        to_dumpable_context: Accepted for signature parity; not used here.
        from_dumpable_context: Accepted for signature parity; not used here.
        namespace: Registry namespace forwarded to ``optree``.
    """
    # TODO(XuehaiPan): remove this condition when we make Python pytree out-of-box support
    # PyStructSequence types
    if optree.is_structseq_class(cls):
        # PyStructSequence classes are deliberately not registered here.
        return

    optree_unflatten_fn = _reverse_args(unflatten_fn)
    optree.register_pytree_node(
        cls,
        flatten_fn,
        optree_unflatten_fn,
        namespace=namespace,
    )
|
||||
|
||||
|
||||
def tree_flatten(
|
||||
tree: PyTree,
|
||||
*,
|
||||
|
|
|
|||
|
|
@ -150,12 +150,59 @@ def _register_pytree_node(
|
|||
back to the original context. This is used for json deserialization,
|
||||
which is being used in torch.export right now.
|
||||
"""
|
||||
_private_register_pytree_node(
|
||||
cls,
|
||||
flatten_fn,
|
||||
unflatten_fn,
|
||||
to_str_fn=to_str_fn, # deprecated
|
||||
maybe_from_str_fn=maybe_from_str_fn, # deprecated
|
||||
serialized_type_name=serialized_type_name,
|
||||
to_dumpable_context=to_dumpable_context,
|
||||
from_dumpable_context=from_dumpable_context,
|
||||
)
|
||||
|
||||
try:
|
||||
from . import _cxx_pytree as cxx
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
cxx._private_register_pytree_node(
|
||||
cls,
|
||||
flatten_fn,
|
||||
unflatten_fn,
|
||||
serialized_type_name=serialized_type_name,
|
||||
to_dumpable_context=to_dumpable_context,
|
||||
from_dumpable_context=from_dumpable_context,
|
||||
)
|
||||
|
||||
|
||||
register_pytree_node = _register_pytree_node
|
||||
|
||||
|
||||
def _private_register_pytree_node(
|
||||
cls: Any,
|
||||
flatten_fn: FlattenFunc,
|
||||
unflatten_fn: UnflattenFunc,
|
||||
to_str_fn: Optional[ToStrFunc] = None, # deprecated
|
||||
maybe_from_str_fn: Optional[MaybeFromStrFunc] = None, # deprecated
|
||||
*,
|
||||
serialized_type_name: Optional[str] = None,
|
||||
to_dumpable_context: Optional[ToDumpableContextFn] = None,
|
||||
from_dumpable_context: Optional[FromDumpableContextFn] = None,
|
||||
) -> None:
|
||||
"""This is an internal function that is used to register a pytree node type
|
||||
for the Python pytree only. End-users should use :func:`register_pytree_node`
|
||||
instead.
|
||||
"""
|
||||
if to_str_fn is not None or maybe_from_str_fn is not None:
|
||||
warnings.warn(
|
||||
"to_str_fn and maybe_from_str_fn is deprecated. "
|
||||
"Please use to_dumpable_context and from_dumpable_context instead."
|
||||
)
|
||||
|
||||
if cls in SUPPORTED_NODES:
|
||||
raise ValueError(f"{cls} is already registered as pytree node.")
|
||||
|
||||
node_def = NodeDef(
|
||||
cls,
|
||||
flatten_fn,
|
||||
|
|
@ -170,7 +217,7 @@ def _register_pytree_node(
|
|||
)
|
||||
|
||||
if serialized_type_name is None:
|
||||
serialized_type_name = f"{cls.__module__}.{cls.__name__}"
|
||||
serialized_type_name = f"{cls.__module__}.{cls.__qualname__}"
|
||||
|
||||
serialize_node_def = _SerializeNodeDef(
|
||||
cls,
|
||||
|
|
@ -243,25 +290,25 @@ def _odict_unflatten(
|
|||
return OrderedDict((key, value) for key, value in zip(context, values))
|
||||
|
||||
|
||||
_register_pytree_node(
|
||||
_private_register_pytree_node(
|
||||
dict,
|
||||
_dict_flatten,
|
||||
_dict_unflatten,
|
||||
serialized_type_name="builtins.dict",
|
||||
)
|
||||
_register_pytree_node(
|
||||
_private_register_pytree_node(
|
||||
list,
|
||||
_list_flatten,
|
||||
_list_unflatten,
|
||||
serialized_type_name="builtins.list",
|
||||
)
|
||||
_register_pytree_node(
|
||||
_private_register_pytree_node(
|
||||
tuple,
|
||||
_tuple_flatten,
|
||||
_tuple_unflatten,
|
||||
serialized_type_name="builtins.tuple",
|
||||
)
|
||||
_register_pytree_node(
|
||||
_private_register_pytree_node(
|
||||
namedtuple,
|
||||
_namedtuple_flatten,
|
||||
_namedtuple_unflatten,
|
||||
|
|
@ -269,7 +316,7 @@ _register_pytree_node(
|
|||
from_dumpable_context=_namedtuple_deserialize,
|
||||
serialized_type_name="collections.namedtuple",
|
||||
)
|
||||
_register_pytree_node(
|
||||
_private_register_pytree_node(
|
||||
OrderedDict,
|
||||
_odict_flatten,
|
||||
_odict_unflatten,
|
||||
|
|
@ -729,7 +776,7 @@ def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema:
|
|||
|
||||
if treespec.type not in SUPPORTED_SERIALIZED_TYPES:
|
||||
raise NotImplementedError(
|
||||
f"Serializing {treespec.type} in pytree is not registered."
|
||||
f"Serializing {treespec.type} in pytree is not registered.",
|
||||
)
|
||||
|
||||
serialize_node_def = SUPPORTED_SERIALIZED_TYPES[treespec.type]
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user