[pytree] Fix namedtuple serialization (#123388)

Summary:
Previously we were serializing namedtuple treespecs incorrectly:
```python
Point = namedtuple("Point", ["x", "y"])
p = Point(1, 2)
flat, spec = pytree.tree_flatten(p)

print(flat)  # [1, 2]
print(spec)  # TreeSpec(type=namedtuple, context=Point, children=[*, *])

dumped_spec = pytree.treespec_dumps(spec)
print(dumped_spec)
"""
We only serialize the name of the class and the fields of the namedtuple:

TreeSpec {
  type='collections.namedtuple',
  context={class_name='Point', class_fields={'x', 'y'}},
  children=[Leaf, Leaf]
}
"""

reconstructed_spec = pytree.treespec_loads(dumped_spec)
print(reconstructed_spec)
"""
When we load, we create a new namedtuple class containing the same fields as before,
but the class is now a completely different class from the original one:

TreeSpec(type=namedtuple, context=torch.utils._pytree.Point, children=[*, *])
"""

spec == reconstructed_spec  # False
```

So, we introduce a new API called `pytree._register_namedtuple` where users can pass in the serialized name for each namedtuple class:
```python
Point = namedtuple("Point", ["x", "y"])
pytree._register_namedtuple(Point, serialized_type_name="Point")

p = Point(1, 2)
flat, spec = pytree.tree_flatten(p)

print(flat)  # [1, 2]
print(spec)  # TreeSpec(type=namedtuple, context=Point, children=[*, *])

dumped_spec = pytree.treespec_dumps(spec)
print(dumped_spec)
"""
TreeSpec {
  type='collections.namedtuple',
  context='Point',
  children=[Leaf, Leaf]
}
"""

reconstructed_spec = pytree.treespec_loads(dumped_spec)
print(reconstructed_spec)  # TreeSpec(type=namedtuple, context=Point, children=[*, *])

spec == reconstructed_spec  # True
```

Test Plan: `python test/test_pytree.py`

Differential Revision: D55771058

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123388
Approved by: https://github.com/zou3519
This commit is contained in:
Angela Yi 2024-04-08 20:55:19 +00:00 committed by PyTorch MergeBot
parent c797fbc4e1
commit 1be2126ff6
2 changed files with 100 additions and 15 deletions

View File

@ -915,15 +915,44 @@ TreeSpec(tuple, None, [*,
self.assertEqual(spec, py_pytree.treespec_loads(serialized_spec))
def test_pytree_serialize_namedtuple(self):
Point = namedtuple("Point", ["x", "y"])
spec = py_pytree.TreeSpec(
namedtuple, Point, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
Point1 = namedtuple("Point1", ["x", "y"])
py_pytree._register_namedtuple(
Point1,
serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point1",
)
spec = py_pytree.TreeSpec(
namedtuple, Point1, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
)
roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec))
# The context in the namedtuple is different now because we recreated
# the namedtuple type.
self.assertEqual(spec.context._fields, roundtrip_spec.context._fields)
self.assertEqual(spec, roundtrip_spec)
class Point2(NamedTuple):
x: int
y: int
py_pytree._register_namedtuple(
Point2,
serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point2",
)
spec = py_pytree.TreeSpec(
namedtuple, Point2, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
)
roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec))
self.assertEqual(spec, roundtrip_spec)
def test_pytree_serialize_namedtuple_bad(self):
# DummyType is deliberately never passed to `_register_namedtuple`, so
# dumping a TreeSpec that uses it as context must fail with an error that
# points the user at the registration API.
DummyType = namedtuple("DummyType", ["x", "y"])
spec = py_pytree.TreeSpec(
namedtuple, DummyType, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
)
with self.assertRaisesRegex(
NotImplementedError, "Please register using `_register_namedtuple`"
):
py_pytree.treespec_dumps(spec)
def test_pytree_custom_type_serialize_bad(self):
class DummyType:
@ -1015,6 +1044,10 @@ TreeSpec(tuple, None, [*,
spec = py_pytree.TreeSpec(
namedtuple, Point, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
)
py_pytree._register_namedtuple(
Point,
serialized_type_name="test_pytree.test_pytree_serialize_bad_protocol.Point",
)
with self.assertRaisesRegex(ValueError, "Unknown protocol"):
py_pytree.treespec_dumps(spec, -1)
@ -1296,12 +1329,20 @@ class TestCxxPytree(TestCase):
self.assertEqual(spec, cxx_pytree.treespec_loads(serialized_spec))
def test_pytree_serialize_namedtuple(self):
py_pytree._register_namedtuple(
GlobalPoint,
serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.GlobalPoint",
)
spec = cxx_pytree.tree_structure(GlobalPoint(0, 1))
roundtrip_spec = cxx_pytree.treespec_loads(cxx_pytree.treespec_dumps(spec))
self.assertEqual(roundtrip_spec.type._fields, spec.type._fields)
LocalPoint = namedtuple("LocalPoint", ["x", "y"])
py_pytree._register_namedtuple(
LocalPoint,
serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.LocalPoint",
)
spec = cxx_pytree.tree_structure(LocalPoint(0, 1))
roundtrip_spec = cxx_pytree.treespec_loads(cxx_pytree.treespec_dumps(spec))

View File

@ -221,6 +221,34 @@ def register_pytree_node(
)
def _register_namedtuple(
cls: Type[Any],
*,
serialized_type_name: str,
) -> None:
"""
Registers a namedtuple as a valid pytree node. By default namedtuples are
valid pytree nodes, but they are not serializable. This API provides the
argument `serialized_type_name` which allows these namedtuples to be
serialized.
Args:
cls: the namedtuple type to register
serialized_type_name: The serialized name for the namedtuple. This is
required if you want to serialize the pytree TreeSpec containing this
namedtuple. It is keyword-only.
"""
# Every namedtuple class shares the same namedtuple helpers for
# flattening, unflattening and (de)serialization; only the serialized
# name passed by the caller is specific to `cls`.
_private_register_pytree_node(
cls,
_namedtuple_flatten,
_namedtuple_unflatten,
serialized_type_name=serialized_type_name,
to_dumpable_context=_namedtuple_serialize,
from_dumpable_context=_namedtuple_deserialize,
flatten_with_keys_fn=_namedtuple_flatten_with_keys,
)
def _register_pytree_node(
cls: Type[Any],
flatten_fn: FlattenFunc,
@ -422,18 +450,34 @@ def _namedtuple_unflatten(values: Iterable[Any], context: Context) -> NamedTuple
def _namedtuple_serialize(context: Context) -> DumpableContext:
json_namedtuple = {
"class_name": context.__name__,
"fields": context._fields,
}
return json_namedtuple
if context not in SUPPORTED_SERIALIZED_TYPES:
raise NotImplementedError(
f"Can't serialize TreeSpec of namedtuple class {context} because we "
"didn't register a serializated_type_name. Please register using "
"`_register_namedtuple`."
)
serialize_node_def = SUPPORTED_SERIALIZED_TYPES[context]
serialized_type_name = serialize_node_def.serialized_type_name
if serialized_type_name == NO_SERIALIZED_TYPE_NAME_FOUND:
raise NotImplementedError(
f"Can't serialize TreeSpec of namedtuple class {context} because we "
"couldn't find a serializated_type_name. Please register using "
"`_register_namedtuple`."
)
return serialized_type_name
def _namedtuple_deserialize(dumpable_context: DumpableContext) -> Context:
class_name = dumpable_context["class_name"]
assert isinstance(class_name, str)
context = namedtuple(class_name, dumpable_context["fields"]) # type: ignore[misc]
return context
if dumpable_context not in SERIALIZED_TYPE_TO_PYTHON_TYPE:
raise NotImplementedError(
f"Can't deserialize TreeSpec of namedtuple class {dumpable_context} "
"because we couldn't find a serializated name."
)
typ = SERIALIZED_TYPE_TO_PYTHON_TYPE[dumpable_context]
return typ
def _ordereddict_flatten(d: GenericOrderedDict[Any, Any]) -> Tuple[List[Any], Context]: