mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[pytree][reland] Require pytree serialized_type_name (#120636)
Relanding https://github.com/pytorch/pytorch/pull/119718 as the diff which prevents breakages of torchrec [D53857843](https://www.internalfb.com/diff/D53857843) has landed Pull Request resolved: https://github.com/pytorch/pytorch/pull/120636 Approved by: https://github.com/avikchaudhuri
This commit is contained in:
parent
12f724c779
commit
cbbc309cae
|
|
@ -761,7 +761,10 @@ class TestExport(TestCase):
|
|||
a: Tensor
|
||||
b: Tensor
|
||||
|
||||
register_dataclass_as_pytree_node(DataClass)
|
||||
register_dataclass_as_pytree_node(
|
||||
DataClass,
|
||||
serialized_type_name="test_export_api_with_dynamic_shapes.DataClass",
|
||||
)
|
||||
|
||||
class Foo(torch.nn.Module):
|
||||
def forward(self, inputs):
|
||||
|
|
@ -947,7 +950,7 @@ class TestExport(TestCase):
|
|||
self.assertEqual(
|
||||
spec,
|
||||
TreeSpec(
|
||||
MyDataClass, (MyDataClass, ["x", "y"], ["z"]), [LeafSpec(), LeafSpec()]
|
||||
MyDataClass, [["x", "y"], ["z"]], [LeafSpec(), LeafSpec()]
|
||||
),
|
||||
)
|
||||
self.assertEqual(flat, [3, 4])
|
||||
|
|
@ -980,11 +983,7 @@ class TestExport(TestCase):
|
|||
spec,
|
||||
TreeSpec(
|
||||
MyOtherDataClass,
|
||||
(
|
||||
MyOtherDataClass,
|
||||
["x", "y", "z"],
|
||||
[],
|
||||
),
|
||||
[["x", "y", "z"], []],
|
||||
[LeafSpec(), LeafSpec(), LeafSpec()],
|
||||
),
|
||||
)
|
||||
|
|
@ -1987,7 +1986,10 @@ def forward(self, arg_0):
|
|||
f: torch.Tensor
|
||||
p: torch.Tensor
|
||||
|
||||
torch._export.utils.register_dataclass_as_pytree_node(Input)
|
||||
torch._export.utils.register_dataclass_as_pytree_node(
|
||||
Input,
|
||||
serialized_type_name="test_preserve_shape_dynamism_for_unused_inputs.Input"
|
||||
)
|
||||
|
||||
class Module(torch.nn.Module):
|
||||
def forward(self, x: Input):
|
||||
|
|
|
|||
|
|
@ -925,7 +925,6 @@ TreeSpec(tuple, None, [*,
|
|||
# the namedtuple type.
|
||||
self.assertEqual(spec.context._fields, roundtrip_spec.context._fields)
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_pytree_custom_type_serialize_bad(self):
|
||||
class DummyType:
|
||||
def __init__(self, x, y):
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ from torch.export import ExportedProgram
|
|||
from torch.utils._pytree import (
|
||||
_register_pytree_node,
|
||||
Context,
|
||||
DumpableContext,
|
||||
FlattenFunc,
|
||||
FromDumpableContextFn,
|
||||
KeyPath,
|
||||
|
|
@ -22,9 +21,6 @@ from torch.utils._pytree import (
|
|||
)
|
||||
|
||||
|
||||
SERIALIZED_DATACLASS_TO_PYTHON_DATACLASS: Dict[str, Type[Any]] = {}
|
||||
|
||||
|
||||
def _check_input_constraints_for_graph(
|
||||
input_placeholders: List[torch.fx.Node], flat_args_with_path, range_constraints
|
||||
):
|
||||
|
|
@ -130,9 +126,6 @@ def register_dataclass_as_pytree_node(
|
|||
cls
|
||||
), f"Only dataclasses can be registered with this function: {cls}"
|
||||
|
||||
serialized_type = f"{cls.__module__}.{cls.__qualname__}"
|
||||
SERIALIZED_DATACLASS_TO_PYTHON_DATACLASS[serialized_type] = cls
|
||||
|
||||
def default_flatten_fn(obj: Any) -> Tuple[List[Any], Context]:
|
||||
flattened = []
|
||||
flat_names = []
|
||||
|
|
@ -144,21 +137,11 @@ def register_dataclass_as_pytree_node(
|
|||
flat_names.append(name)
|
||||
else:
|
||||
none_names.append(name)
|
||||
return flattened, (cls, flat_names, none_names)
|
||||
return flattened, [flat_names, none_names]
|
||||
|
||||
def default_unflatten_fn(values: Iterable[Any], context: Context) -> Any:
|
||||
typ, flat_names, none_names = context
|
||||
return typ(**dict(zip(flat_names, values)), **dict.fromkeys(none_names))
|
||||
|
||||
def default_to_dumpable_context(context: Context) -> DumpableContext:
|
||||
return (serialized_type, context[1], context[2])
|
||||
|
||||
def default_from_dumpable_context(dumpable_context: DumpableContext) -> Context:
|
||||
return (
|
||||
SERIALIZED_DATACLASS_TO_PYTHON_DATACLASS[dumpable_context[0]],
|
||||
dumpable_context[1],
|
||||
dumpable_context[2],
|
||||
)
|
||||
flat_names, none_names = context
|
||||
return cls(**dict(zip(flat_names, values)), **dict.fromkeys(none_names))
|
||||
|
||||
flatten_fn = flatten_fn if flatten_fn is not None else default_flatten_fn
|
||||
unflatten_fn = unflatten_fn if unflatten_fn is not None else default_unflatten_fn
|
||||
|
|
@ -169,17 +152,6 @@ def register_dataclass_as_pytree_node(
|
|||
"be None or registered."
|
||||
)
|
||||
|
||||
to_dumpable_context = (
|
||||
to_dumpable_context
|
||||
if to_dumpable_context is not None
|
||||
else default_to_dumpable_context
|
||||
)
|
||||
from_dumpable_context = (
|
||||
from_dumpable_context
|
||||
if from_dumpable_context is not None
|
||||
else default_from_dumpable_context
|
||||
)
|
||||
|
||||
_register_pytree_node(
|
||||
cls,
|
||||
flatten_fn,
|
||||
|
|
|
|||
|
|
@ -314,12 +314,19 @@ def load(
|
|||
)
|
||||
|
||||
|
||||
def register_dataclass(cls: Type[Any]) -> None:
|
||||
def register_dataclass(
|
||||
cls: Type[Any],
|
||||
*,
|
||||
serialized_type_name: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Registers a dataclass as a valid input/output type for :func:`torch.export.export`.
|
||||
|
||||
Args:
|
||||
cls: the dataclass type to register
|
||||
serialized_type_name: The serialized name for the dataclass. This is
|
||||
required if you want to serialize the pytree TreeSpec containing this
|
||||
dataclass.
|
||||
|
||||
Example::
|
||||
|
||||
|
|
@ -345,4 +352,6 @@ def register_dataclass(cls: Type[Any]) -> None:
|
|||
|
||||
from torch._export.utils import register_dataclass_as_pytree_node
|
||||
|
||||
return register_dataclass_as_pytree_node(cls)
|
||||
return register_dataclass_as_pytree_node(
|
||||
cls, serialized_type_name=serialized_type_name
|
||||
)
|
||||
|
|
|
|||
|
|
@ -312,7 +312,7 @@ def _private_register_pytree_node(
|
|||
)
|
||||
|
||||
if serialized_type_name is None:
|
||||
serialized_type_name = f"{cls.__module__}.{cls.__qualname__}"
|
||||
serialized_type_name = NO_SERIALIZED_TYPE_NAME_FOUND
|
||||
|
||||
serialize_node_def = _SerializeNodeDef(
|
||||
cls,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user