mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[pytree] Fix namedtuple serialization (#123388)
Summary:
Previously we were serializing namedtuple treespecs incorrectly:
```python
Point = namedtuple("Point", ["x", "y"])
p = Point(1, 2)
flat, spec = pytree.tree_flatten(p)
print(flat) # [1, 2]
print(spec) # TreeSpec(type=namedtuple, context=Point, children=[*, *])
dumped_spec = pytree.treespec_dumps(spec)
print(dumped_spec)
"""
We only serialize the name of the class and the fields of the namedtuple:
TreeSpec {
type='collections.namedtuple',
context={class_name='Point', class_fields={'x', 'y'}},
children=[Leaf, Leaf]
}
"""
reconstructed_spec = pytree.treespec_loads(dumped_spec)
print(reconstructed_spec)
"""
When we load, we create a new namedtuple class containing the same fields as before,
but the class is now a completely different class from the original one:
TreeSpec(type=namedtuple, context=torch.utils._pytree.Point, children=[*, *])
"""
spec == reconstructed_spec # False
```
So, we introduce a new API called `pytree._register_namedtuple` where users can pass in the serialized name for each namedtuple class:
```python
Point = namedtuple("Point", ["x", "y"])
pytree._register_namedtuple(Point, serialized_type_name="Point")
p = Point(1, 2)
flat, spec = pytree.tree_flatten(p)
print(flat) # [1, 2]
print(spec) # TreeSpec(type=namedtuple, context=Point, children=[*, *])
dumped_spec = pytree.treespec_dumps(spec)
print(dumped_spec)
"""
TreeSpec {
type='collections.namedtuple',
context='Point',
children=[Leaf, Leaf]
}
"""
reconstructed_spec = pytree.treespec_loads(dumped_spec)
print(reconstructed_spec) # TreeSpec(type=namedtuple, context=Point, children=[*, *])
spec == reconstructed_spec # True
```
Test Plan: `python test/test_pytree.py`
Differential Revision: D55771058
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123388
Approved by: https://github.com/zou3519
This commit is contained in:
parent
c797fbc4e1
commit
1be2126ff6
|
|
@ -915,15 +915,44 @@ TreeSpec(tuple, None, [*,
|
|||
self.assertEqual(spec, py_pytree.treespec_loads(serialized_spec))
|
||||
|
||||
def test_pytree_serialize_namedtuple(self):
|
||||
Point = namedtuple("Point", ["x", "y"])
|
||||
spec = py_pytree.TreeSpec(
|
||||
namedtuple, Point, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
|
||||
Point1 = namedtuple("Point1", ["x", "y"])
|
||||
py_pytree._register_namedtuple(
|
||||
Point1,
|
||||
serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point1",
|
||||
)
|
||||
|
||||
spec = py_pytree.TreeSpec(
|
||||
namedtuple, Point1, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
|
||||
)
|
||||
roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec))
|
||||
# The context in the namedtuple is different now because we recreated
|
||||
# the namedtuple type.
|
||||
self.assertEqual(spec.context._fields, roundtrip_spec.context._fields)
|
||||
self.assertEqual(spec, roundtrip_spec)
|
||||
|
||||
class Point2(NamedTuple):
|
||||
x: int
|
||||
y: int
|
||||
|
||||
py_pytree._register_namedtuple(
|
||||
Point2,
|
||||
serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point2",
|
||||
)
|
||||
|
||||
spec = py_pytree.TreeSpec(
|
||||
namedtuple, Point2, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
|
||||
)
|
||||
roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec))
|
||||
self.assertEqual(spec, roundtrip_spec)
|
||||
|
||||
def test_pytree_serialize_namedtuple_bad(self):
|
||||
DummyType = namedtuple("DummyType", ["x", "y"])
|
||||
|
||||
spec = py_pytree.TreeSpec(
|
||||
namedtuple, DummyType, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
NotImplementedError, "Please register using `_register_namedtuple`"
|
||||
):
|
||||
py_pytree.treespec_dumps(spec)
|
||||
|
||||
def test_pytree_custom_type_serialize_bad(self):
|
||||
class DummyType:
|
||||
|
|
@ -1015,6 +1044,10 @@ TreeSpec(tuple, None, [*,
|
|||
spec = py_pytree.TreeSpec(
|
||||
namedtuple, Point, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
|
||||
)
|
||||
py_pytree._register_namedtuple(
|
||||
Point,
|
||||
serialized_type_name="test_pytree.test_pytree_serialize_bad_protocol.Point",
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "Unknown protocol"):
|
||||
py_pytree.treespec_dumps(spec, -1)
|
||||
|
|
@ -1296,12 +1329,20 @@ class TestCxxPytree(TestCase):
|
|||
self.assertEqual(spec, cxx_pytree.treespec_loads(serialized_spec))
|
||||
|
||||
def test_pytree_serialize_namedtuple(self):
|
||||
py_pytree._register_namedtuple(
|
||||
GlobalPoint,
|
||||
serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.GlobalPoint",
|
||||
)
|
||||
spec = cxx_pytree.tree_structure(GlobalPoint(0, 1))
|
||||
|
||||
roundtrip_spec = cxx_pytree.treespec_loads(cxx_pytree.treespec_dumps(spec))
|
||||
self.assertEqual(roundtrip_spec.type._fields, spec.type._fields)
|
||||
|
||||
LocalPoint = namedtuple("LocalPoint", ["x", "y"])
|
||||
py_pytree._register_namedtuple(
|
||||
LocalPoint,
|
||||
serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.LocalPoint",
|
||||
)
|
||||
spec = cxx_pytree.tree_structure(LocalPoint(0, 1))
|
||||
|
||||
roundtrip_spec = cxx_pytree.treespec_loads(cxx_pytree.treespec_dumps(spec))
|
||||
|
|
|
|||
|
|
@ -221,6 +221,34 @@ def register_pytree_node(
|
|||
)
|
||||
|
||||
|
||||
def _register_namedtuple(
|
||||
cls: Type[Any],
|
||||
*,
|
||||
serialized_type_name: str,
|
||||
) -> None:
|
||||
"""
|
||||
Registers a namedtuple as a valid pytree node. By default namedtuples are
|
||||
valid pytree nodes, but they are not serializable. This API provides the
|
||||
argument `serialized_type_name` which allows these namedtuples to be
|
||||
serialized.
|
||||
|
||||
Args:
|
||||
cls: the namedtuple type to register
|
||||
serialized_type_name: The serialized name for the namedtuple. This is
|
||||
required if you want to serialize the pytree TreeSpec containing this
|
||||
namedtuple.
|
||||
"""
|
||||
_private_register_pytree_node(
|
||||
cls,
|
||||
_namedtuple_flatten,
|
||||
_namedtuple_unflatten,
|
||||
serialized_type_name=serialized_type_name,
|
||||
to_dumpable_context=_namedtuple_serialize,
|
||||
from_dumpable_context=_namedtuple_deserialize,
|
||||
flatten_with_keys_fn=_namedtuple_flatten_with_keys,
|
||||
)
|
||||
|
||||
|
||||
def _register_pytree_node(
|
||||
cls: Type[Any],
|
||||
flatten_fn: FlattenFunc,
|
||||
|
|
@ -422,18 +450,34 @@ def _namedtuple_unflatten(values: Iterable[Any], context: Context) -> NamedTuple
|
|||
|
||||
|
||||
def _namedtuple_serialize(context: Context) -> DumpableContext:
|
||||
json_namedtuple = {
|
||||
"class_name": context.__name__,
|
||||
"fields": context._fields,
|
||||
}
|
||||
return json_namedtuple
|
||||
if context not in SUPPORTED_SERIALIZED_TYPES:
|
||||
raise NotImplementedError(
|
||||
f"Can't serialize TreeSpec of namedtuple class {context} because we "
|
||||
"didn't register a serializated_type_name. Please register using "
|
||||
"`_register_namedtuple`."
|
||||
)
|
||||
|
||||
serialize_node_def = SUPPORTED_SERIALIZED_TYPES[context]
|
||||
serialized_type_name = serialize_node_def.serialized_type_name
|
||||
|
||||
if serialized_type_name == NO_SERIALIZED_TYPE_NAME_FOUND:
|
||||
raise NotImplementedError(
|
||||
f"Can't serialize TreeSpec of namedtuple class {context} because we "
|
||||
"couldn't find a serializated_type_name. Please register using "
|
||||
"`_register_namedtuple`."
|
||||
)
|
||||
return serialized_type_name
|
||||
|
||||
|
||||
def _namedtuple_deserialize(dumpable_context: DumpableContext) -> Context:
|
||||
class_name = dumpable_context["class_name"]
|
||||
assert isinstance(class_name, str)
|
||||
context = namedtuple(class_name, dumpable_context["fields"]) # type: ignore[misc]
|
||||
return context
|
||||
if dumpable_context not in SERIALIZED_TYPE_TO_PYTHON_TYPE:
|
||||
raise NotImplementedError(
|
||||
f"Can't deserialize TreeSpec of namedtuple class {dumpable_context} "
|
||||
"because we couldn't find a serializated name."
|
||||
)
|
||||
|
||||
typ = SERIALIZED_TYPE_TO_PYTHON_TYPE[dumpable_context]
|
||||
return typ
|
||||
|
||||
|
||||
def _ordereddict_flatten(d: GenericOrderedDict[Any, Any]) -> Tuple[List[Any], Context]:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user