[pytree][reland] Require pytree serialized_type_name (#120636)

Relanding https://github.com/pytorch/pytorch/pull/119718 now that the diff that prevents torchrec breakages, [D53857843](https://www.internalfb.com/diff/D53857843), has landed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120636
Approved by: https://github.com/avikchaudhuri
This commit is contained in:
angelayi 2024-02-27 06:53:33 +00:00 committed by PyTorch MergeBot
parent 12f724c779
commit cbbc309cae
5 changed files with 25 additions and 43 deletions

View File

@ -761,7 +761,10 @@ class TestExport(TestCase):
a: Tensor
b: Tensor
register_dataclass_as_pytree_node(DataClass)
register_dataclass_as_pytree_node(
DataClass,
serialized_type_name="test_export_api_with_dynamic_shapes.DataClass",
)
class Foo(torch.nn.Module):
def forward(self, inputs):
@ -947,7 +950,7 @@ class TestExport(TestCase):
self.assertEqual(
spec,
TreeSpec(
MyDataClass, (MyDataClass, ["x", "y"], ["z"]), [LeafSpec(), LeafSpec()]
MyDataClass, [["x", "y"], ["z"]], [LeafSpec(), LeafSpec()]
),
)
self.assertEqual(flat, [3, 4])
@ -980,11 +983,7 @@ class TestExport(TestCase):
spec,
TreeSpec(
MyOtherDataClass,
(
MyOtherDataClass,
["x", "y", "z"],
[],
),
[["x", "y", "z"], []],
[LeafSpec(), LeafSpec(), LeafSpec()],
),
)
@ -1987,7 +1986,10 @@ def forward(self, arg_0):
f: torch.Tensor
p: torch.Tensor
torch._export.utils.register_dataclass_as_pytree_node(Input)
torch._export.utils.register_dataclass_as_pytree_node(
Input,
serialized_type_name="test_preserve_shape_dynamism_for_unused_inputs.Input"
)
class Module(torch.nn.Module):
def forward(self, x: Input):

View File

@ -925,7 +925,6 @@ TreeSpec(tuple, None, [*,
# the namedtuple type.
self.assertEqual(spec.context._fields, roundtrip_spec.context._fields)
@unittest.expectedFailure
def test_pytree_custom_type_serialize_bad(self):
class DummyType:
def __init__(self, x, y):

View File

@ -10,7 +10,6 @@ from torch.export import ExportedProgram
from torch.utils._pytree import (
_register_pytree_node,
Context,
DumpableContext,
FlattenFunc,
FromDumpableContextFn,
KeyPath,
@ -22,9 +21,6 @@ from torch.utils._pytree import (
)
SERIALIZED_DATACLASS_TO_PYTHON_DATACLASS: Dict[str, Type[Any]] = {}
def _check_input_constraints_for_graph(
input_placeholders: List[torch.fx.Node], flat_args_with_path, range_constraints
):
@ -130,9 +126,6 @@ def register_dataclass_as_pytree_node(
cls
), f"Only dataclasses can be registered with this function: {cls}"
serialized_type = f"{cls.__module__}.{cls.__qualname__}"
SERIALIZED_DATACLASS_TO_PYTHON_DATACLASS[serialized_type] = cls
def default_flatten_fn(obj: Any) -> Tuple[List[Any], Context]:
flattened = []
flat_names = []
@ -144,21 +137,11 @@ def register_dataclass_as_pytree_node(
flat_names.append(name)
else:
none_names.append(name)
return flattened, (cls, flat_names, none_names)
return flattened, [flat_names, none_names]
def default_unflatten_fn(values: Iterable[Any], context: Context) -> Any:
typ, flat_names, none_names = context
return typ(**dict(zip(flat_names, values)), **dict.fromkeys(none_names))
def default_to_dumpable_context(context: Context) -> DumpableContext:
return (serialized_type, context[1], context[2])
def default_from_dumpable_context(dumpable_context: DumpableContext) -> Context:
return (
SERIALIZED_DATACLASS_TO_PYTHON_DATACLASS[dumpable_context[0]],
dumpable_context[1],
dumpable_context[2],
)
flat_names, none_names = context
return cls(**dict(zip(flat_names, values)), **dict.fromkeys(none_names))
flatten_fn = flatten_fn if flatten_fn is not None else default_flatten_fn
unflatten_fn = unflatten_fn if unflatten_fn is not None else default_unflatten_fn
@ -169,17 +152,6 @@ def register_dataclass_as_pytree_node(
"be None or registered."
)
to_dumpable_context = (
to_dumpable_context
if to_dumpable_context is not None
else default_to_dumpable_context
)
from_dumpable_context = (
from_dumpable_context
if from_dumpable_context is not None
else default_from_dumpable_context
)
_register_pytree_node(
cls,
flatten_fn,

View File

@ -314,12 +314,19 @@ def load(
)
def register_dataclass(cls: Type[Any]) -> None:
def register_dataclass(
cls: Type[Any],
*,
serialized_type_name: Optional[str] = None,
) -> None:
"""
Registers a dataclass as a valid input/output type for :func:`torch.export.export`.
Args:
cls: the dataclass type to register
serialized_type_name: The serialized name for the dataclass. This is
required if you want to serialize the pytree TreeSpec containing this
dataclass.
Example::
@ -345,4 +352,6 @@ def register_dataclass(cls: Type[Any]) -> None:
from torch._export.utils import register_dataclass_as_pytree_node
return register_dataclass_as_pytree_node(cls)
return register_dataclass_as_pytree_node(
cls, serialized_type_name=serialized_type_name
)

View File

@ -312,7 +312,7 @@ def _private_register_pytree_node(
)
if serialized_type_name is None:
serialized_type_name = f"{cls.__module__}.{cls.__qualname__}"
serialized_type_name = NO_SERIALIZED_TYPE_NAME_FOUND
serialize_node_def = _SerializeNodeDef(
cls,