pytorch/test/test_pytree.py
angelayi ff35e1e45b [pytree] Add custom treespec fqn field (#112428)
Custom classes that are serialized with pytree are serialized by default with `f"{class.__module__}.{class.__name__}"`. This creates a dependency from our serialized program directly into the outer Python environment: if a user moves the class to a different directory, the serialized program can no longer be loaded. So, we will require users to pass in an FQN (fully qualified name) if they want to serialize their custom treespec type.

Differential Revision: [D50886366](https://our.internmc.facebook.com/intern/diff/D50886366)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112428
Approved by: https://github.com/suo
2023-11-02 00:26:41 +00:00

767 lines
26 KiB
Python

# Owner(s): ["module: pytree"]
import unittest
from collections import namedtuple, OrderedDict
import torch
import torch.utils._cxx_pytree as cxx_pytree
import torch.utils._pytree as py_pytree
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
subtest,
TEST_WITH_TORCHDYNAMO,
TestCase,
)
# Module-level (importable) namedtuple type with fields ``x`` and ``y``;
# TestCxxPytree.test_pytree_serialize_namedtuple round-trips its treespec.
GlobalPoint = namedtuple(typename="GlobalPoint", field_names=("x", "y"))
class GlobalDummyType:
    """Minimal two-attribute container defined at module scope.

    TestCxxPytree.test_pytree_custom_type_serialize registers this class as a
    custom pytree node and flattens it into ``[x, y]``.
    """

    def __init__(self, x, y):
        # Stored as-is; no validation or conversion is performed.
        self.x = x
        self.y = y
class TestGenericPytree(TestCase):
    """Pytree behaviour that both pytree implementations are expected to share.

    Each test is parametrized over ``pytree_impl`` so the same assertions run
    against ``torch.utils._pytree`` (named "py") and
    ``torch.utils._cxx_pytree`` (named "cxx").
    """

    @parametrize(
        "pytree_impl",
        [
            subtest(py_pytree, name="py"),
            subtest(cxx_pytree, name="cxx"),
        ],
    )
    def test_flatten_unflatten_leaf(self, pytree_impl):
        """A leaf flattens to ``[leaf]`` with a ``LeafSpec`` and round-trips."""

        def run_test_with_leaf(leaf):
            values, treespec = pytree_impl.tree_flatten(leaf)
            self.assertEqual(values, [leaf])
            self.assertEqual(treespec, pytree_impl.LeafSpec())

            unflattened = pytree_impl.tree_unflatten(values, treespec)
            self.assertEqual(unflattened, leaf)

        run_test_with_leaf(1)
        run_test_with_leaf(1.0)
        run_test_with_leaf(None)
        run_test_with_leaf(bool)
        run_test_with_leaf(torch.randn(3, 3))

    @parametrize(
        "pytree_impl,gen_expected_fn",
        [
            subtest(
                (
                    py_pytree,
                    lambda lst: py_pytree.TreeSpec(
                        list, None, [py_pytree.LeafSpec() for _ in lst]
                    ),
                ),
                name="py",
            ),
            subtest(
                (cxx_pytree, lambda lst: cxx_pytree.tree_structure([0] * len(lst))),
                name="cxx",
            ),
        ],
    )
    def test_flatten_unflatten_list(self, pytree_impl, gen_expected_fn):
        """Lists flatten to their elements and unflatten back to a ``list``.

        ``gen_expected_fn`` builds the implementation-specific expected spec
        for a given input list.
        """

        def run_test(lst):
            expected_spec = gen_expected_fn(lst)
            values, treespec = pytree_impl.tree_flatten(lst)
            self.assertIsInstance(values, list)
            self.assertEqual(values, lst)
            self.assertEqual(treespec, expected_spec)

            unflattened = pytree_impl.tree_unflatten(values, treespec)
            self.assertEqual(unflattened, lst)
            self.assertIsInstance(unflattened, list)

        run_test([])
        run_test([1.0, 2])
        run_test([torch.tensor([1.0, 2]), 2, 10, 9, 11])

    @parametrize(
        "pytree_impl,gen_expected_fn",
        [
            subtest(
                (
                    py_pytree,
                    lambda tup: py_pytree.TreeSpec(
                        tuple, None, [py_pytree.LeafSpec() for _ in tup]
                    ),
                ),
                name="py",
            ),
            subtest(
                (cxx_pytree, lambda tup: cxx_pytree.tree_structure((0,) * len(tup))),
                name="cxx",
            ),
        ],
    )
    def test_flatten_unflatten_tuple(self, pytree_impl, gen_expected_fn):
        """Tuples flatten to a list of elements and unflatten to a ``tuple``."""

        def run_test(tup):
            expected_spec = gen_expected_fn(tup)
            values, treespec = pytree_impl.tree_flatten(tup)
            self.assertIsInstance(values, list)
            self.assertEqual(values, list(tup))
            self.assertEqual(treespec, expected_spec)

            unflattened = pytree_impl.tree_unflatten(values, treespec)
            self.assertEqual(unflattened, tup)
            self.assertIsInstance(unflattened, tuple)

        run_test(())
        run_test((1.0,))
        run_test((1.0, 2))
        run_test((torch.tensor([1.0, 2]), 2, 10, 9, 11))

    @parametrize(
        "pytree_impl,gen_expected_fn",
        [
            subtest(
                (
                    py_pytree,
                    lambda dct: py_pytree.TreeSpec(
                        dict,
                        list(dct.keys()),
                        [py_pytree.LeafSpec() for _ in dct.values()],
                    ),
                ),
                name="py",
            ),
            subtest(
                (
                    cxx_pytree,
                    lambda dct: cxx_pytree.tree_structure(dict.fromkeys(dct, 0)),
                ),
                name="cxx",
            ),
        ],
    )
    def test_flatten_unflatten_dict(self, pytree_impl, gen_expected_fn):
        """Dicts flatten to their values and unflatten back to a ``dict``."""

        def run_test(dct):
            expected_spec = gen_expected_fn(dct)
            values, treespec = pytree_impl.tree_flatten(dct)
            self.assertIsInstance(values, list)
            self.assertEqual(values, list(dct.values()))
            self.assertEqual(treespec, expected_spec)

            unflattened = pytree_impl.tree_unflatten(values, treespec)
            self.assertEqual(unflattened, dct)
            self.assertIsInstance(unflattened, dict)

        run_test({})
        run_test({"a": 1})
        run_test({"abcdefg": torch.randn(2, 3)})
        run_test({1: torch.randn(2, 3)})
        run_test({"a": 1, "b": 2, "c": torch.randn(2, 3)})

    @parametrize(
        "pytree_impl,gen_expected_fn",
        [
            subtest(
                (
                    py_pytree,
                    lambda odict: py_pytree.TreeSpec(
                        OrderedDict,
                        list(odict.keys()),
                        [py_pytree.LeafSpec() for _ in odict.values()],
                    ),
                ),
                name="py",
            ),
            subtest(
                (
                    cxx_pytree,
                    lambda odict: cxx_pytree.tree_structure(
                        OrderedDict.fromkeys(odict, 0)
                    ),
                ),
                name="cxx",
            ),
        ],
    )
    def test_flatten_unflatten_odict(self, pytree_impl, gen_expected_fn):
        """OrderedDicts flatten to their values and keep their type on unflatten."""

        def run_test(odict):
            expected_spec = gen_expected_fn(odict)
            values, treespec = pytree_impl.tree_flatten(odict)
            self.assertIsInstance(values, list)
            self.assertEqual(values, list(odict.values()))
            self.assertEqual(treespec, expected_spec)

            unflattened = pytree_impl.tree_unflatten(values, treespec)
            self.assertEqual(unflattened, odict)
            self.assertIsInstance(unflattened, OrderedDict)

        od = OrderedDict()
        run_test(od)

        # Insert "b" before "a" so insertion order differs from sorted order.
        od["b"] = 1
        od["a"] = torch.tensor(3.14)
        run_test(od)

    @parametrize(
        "pytree_impl",
        [
            subtest(py_pytree, name="py"),
            subtest(cxx_pytree, name="cxx"),
        ],
    )
    def test_flatten_unflatten_namedtuple(self, pytree_impl):
        """Namedtuples flatten to their fields and unflatten to the same type."""
        Point = namedtuple("Point", ["x", "y"])

        def run_test(tup):
            # The two implementations represent the expected spec differently.
            if pytree_impl is py_pytree:
                expected_spec = py_pytree.TreeSpec(
                    namedtuple, Point, [py_pytree.LeafSpec() for _ in tup]
                )
            else:
                expected_spec = cxx_pytree.tree_structure(Point(0, 1))
            values, treespec = pytree_impl.tree_flatten(tup)
            self.assertIsInstance(values, list)
            self.assertEqual(values, list(tup))
            self.assertEqual(treespec, expected_spec)

            unflattened = pytree_impl.tree_unflatten(values, treespec)
            self.assertEqual(unflattened, tup)
            self.assertIsInstance(unflattened, Point)

        run_test(Point(1.0, 2))
        run_test(Point(torch.tensor(1.0), 2))

    @parametrize(
        "op",
        [
            subtest(torch.max, name="max"),
            subtest(torch.min, name="min"),
        ],
    )
    @parametrize(
        "pytree_impl",
        [
            subtest(py_pytree, name="py"),
            subtest(cxx_pytree, name="cxx"),
        ],
    )
    def test_flatten_unflatten_return_type(self, pytree_impl, op):
        """Structured return values of ``op(x, dim=0)`` round-trip with their type."""
        x = torch.randn(3, 3)
        expected = op(x, dim=0)

        values, spec = pytree_impl.tree_flatten(expected)
        # Check that values is actually List[Tensor] and not (ReturnType(...),)
        for value in values:
            self.assertIsInstance(value, torch.Tensor)
        result = pytree_impl.tree_unflatten(values, spec)

        self.assertEqual(type(result), type(expected))
        self.assertEqual(result, expected)

    @parametrize(
        "pytree_impl",
        [
            subtest(py_pytree, name="py"),
            subtest(cxx_pytree, name="cxx"),
        ],
    )
    def test_flatten_unflatten_nested(self, pytree_impl):
        """Nested containers round-trip and report a matching ``num_leaves``."""

        def run_test(pytree):
            values, treespec = pytree_impl.tree_flatten(pytree)
            self.assertIsInstance(values, list)
            self.assertEqual(len(values), treespec.num_leaves)

            # NB: python basic data structures (dict list tuple) all have
            # contents equality defined on them, so the following works for them.
            unflattened = pytree_impl.tree_unflatten(values, treespec)
            self.assertEqual(unflattened, pytree)

        cases = [
            [()],
            ([],),
            {"a": ()},
            {"a": 0, "b": [{"c": 1}]},
            {"a": 0, "b": [1, {"c": 2}, torch.randn(3)], "c": (torch.randn(2, 3), 1)},
        ]
        for case in cases:
            run_test(case)

    @parametrize(
        "pytree_impl",
        [
            subtest(py_pytree, name="py"),
            subtest(cxx_pytree, name="cxx"),
        ],
    )
    def test_treemap(self, pytree_impl):
        """``tree_map`` applies a function to every leaf and keeps the structure."""

        def run_test(pytree):
            def f(x):
                return x * 3

            # Mapping over the leaves directly must agree with tree_map.
            sm1 = sum(map(f, pytree_impl.tree_leaves(pytree)))
            sm2 = sum(pytree_impl.tree_leaves(pytree_impl.tree_map(f, pytree)))
            self.assertEqual(sm1, sm2)

            def invf(x):
                return x // 3

            # invf undoes f on the integer leaves used below.
            self.assertEqual(
                pytree_impl.tree_map(invf, pytree_impl.tree_map(f, pytree)),
                pytree,
            )

        cases = [
            [()],
            ([],),
            {"a": ()},
            {"a": 1, "b": [{"c": 2}]},
            {"a": 0, "b": [2, {"c": 3}, 4], "c": (5, 6)},
        ]
        for case in cases:
            run_test(case)

    @parametrize(
        "pytree_impl",
        [
            subtest(py_pytree, name="py"),
            subtest(cxx_pytree, name="cxx"),
        ],
    )
    def test_tree_only(self, pytree_impl):
        """``tree_map_only`` transforms leaves of the given type and leaves others."""
        self.assertEqual(
            pytree_impl.tree_map_only(int, lambda x: x + 2, [0, "a"]), [2, "a"]
        )

    @parametrize(
        "pytree_impl",
        [
            subtest(py_pytree, name="py"),
            subtest(cxx_pytree, name="cxx"),
        ],
    )
    def test_tree_all_any(self, pytree_impl):
        """``tree_all``/``tree_any`` and their ``*_only`` variants over leaves."""
        self.assertTrue(pytree_impl.tree_all(lambda x: x % 2, [1, 3]))
        self.assertFalse(pytree_impl.tree_all(lambda x: x % 2, [0, 1]))
        self.assertTrue(pytree_impl.tree_any(lambda x: x % 2, [0, 1]))
        self.assertFalse(pytree_impl.tree_any(lambda x: x % 2, [0, 2]))
        # The "*_only" variants are given a str leaf that ``x % 2`` is not
        # expected to be evaluated on.
        self.assertTrue(pytree_impl.tree_all_only(int, lambda x: x % 2, [1, 3, "a"]))
        self.assertFalse(pytree_impl.tree_all_only(int, lambda x: x % 2, [0, 1, "a"]))
        self.assertTrue(pytree_impl.tree_any_only(int, lambda x: x % 2, [0, 1, "a"]))
        self.assertFalse(pytree_impl.tree_any_only(int, lambda x: x % 2, [0, 2, "a"]))

    @parametrize(
        "pytree_impl",
        [
            subtest(py_pytree, name="py"),
            subtest(cxx_pytree, name="cxx"),
        ],
    )
    def test_broadcast_to_and_flatten(self, pytree_impl):
        """``_broadcast_to_and_flatten`` against a table of cases.

        Each case is ``(pytree, to_pytree, expected)``; ``expected`` is the
        flattened result, or ``None`` when the structures do not match.
        """
        cases = [
            (1, (), []),
            # Same (flat) structures
            ((1,), (0,), [1]),
            ([1], [0], [1]),
            ((1, 2, 3), (0, 0, 0), [1, 2, 3]),
            ({"a": 1, "b": 2}, {"a": 0, "b": 0}, [1, 2]),
            # Mismatched (flat) structures
            ([1], (0,), None),
            ((1,), [0], None),
            ((1, 2, 3), (0, 0), None),
            ({"a": 1, "b": 2}, {"a": 0}, None),
            ({"a": 1, "b": 2}, {"a": 0, "c": 0}, None),
            ({"a": 1, "b": 2}, {"a": 0, "b": 0, "c": 0}, None),
            # Same (nested) structures
            ((1, [2, 3]), (0, [0, 0]), [1, 2, 3]),
            ((1, [(2, 3), 4]), (0, [(0, 0), 0]), [1, 2, 3, 4]),
            # Mismatched (nested) structures
            ((1, [2, 3]), (0, (0, 0)), None),
            ((1, [2, 3]), (0, [0, 0, 0]), None),
            # Broadcasting single value
            (1, (0, 0, 0), [1, 1, 1]),
            (1, [0, 0, 0], [1, 1, 1]),
            (1, {"a": 0, "b": 0}, [1, 1]),
            (1, (0, [0, [0]], 0), [1, 1, 1, 1]),
            (1, (0, [0, [0, [], [[[0]]]]], 0), [1, 1, 1, 1, 1]),
            # Broadcast multiple things
            ((1, 2), ([0, 0, 0], [0, 0]), [1, 1, 1, 2, 2]),
            ((1, 2), ([0, [0, 0], 0], [0, 0]), [1, 1, 1, 1, 2, 2]),
            (([1, 2, 3], 4), ([0, [0, 0], 0], [0, 0]), [1, 2, 2, 3, 4, 4]),
        ]
        for pytree, to_pytree, expected in cases:
            _, to_spec = pytree_impl.tree_flatten(to_pytree)
            result = pytree_impl._broadcast_to_and_flatten(pytree, to_spec)
            self.assertEqual(result, expected, msg=str([pytree, to_spec, expected]))

    @parametrize(
        "pytree_impl",
        [
            subtest(py_pytree, name="py"),
            subtest(cxx_pytree, name="cxx"),
        ],
    )
    def test_pytree_serialize_bad_input(self, pytree_impl):
        """``treespec_dumps`` raises ``TypeError`` when given a plain string."""
        with self.assertRaises(TypeError):
            pytree_impl.treespec_dumps("random_blurb")
class TestPythonPytree(TestCase):
    """Tests specific to the pure-Python implementation, ``torch.utils._pytree``.

    Covers ``TreeSpec`` equality and repr, and treespec (de)serialization via
    ``treespec_dumps`` / ``treespec_loads``, including custom registered types.
    """

    def test_treespec_equality(self):
        """``==`` and ``!=`` on ``LeafSpec``/``TreeSpec`` compare type and children."""
        # The operators are spelled out on purpose: they are what is under test.
        self.assertTrue(
            py_pytree.LeafSpec() == py_pytree.LeafSpec(),
        )
        self.assertTrue(
            py_pytree.TreeSpec(list, None, []) == py_pytree.TreeSpec(list, None, []),
        )
        self.assertTrue(
            py_pytree.TreeSpec(list, None, [py_pytree.LeafSpec()])
            == py_pytree.TreeSpec(list, None, [py_pytree.LeafSpec()]),
        )
        self.assertFalse(
            py_pytree.TreeSpec(tuple, None, []) == py_pytree.TreeSpec(list, None, []),
        )
        self.assertTrue(
            py_pytree.TreeSpec(tuple, None, []) != py_pytree.TreeSpec(list, None, []),
        )

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "Dynamo test in test_treespec_repr_dynamo.")
    def test_treespec_repr(self):
        """``repr`` of a nested spec matches the expected multi-line layout (eager)."""
        # Check that it looks sane
        pytree = (0, [0, 0, [0]])
        _, spec = py_pytree.tree_flatten(pytree)
        # NOTE(review): continuation lines are indented two spaces per nesting
        # level to mirror TreeSpec.__repr__ -- confirm against the implementation.
        self.assertEqual(
            repr(spec),
            (
                "TreeSpec(tuple, None, [*,\n"
                "  TreeSpec(list, None, [*,\n"
                "    *,\n"
                "    TreeSpec(list, None, [*])])])"
            ),
        )

    @unittest.skipIf(not TEST_WITH_TORCHDYNAMO, "Eager test in test_treespec_repr.")
    def test_treespec_repr_dynamo(self):
        """Same repr check as ``test_treespec_repr``, run only under TorchDynamo."""
        # Check that it looks sane
        pytree = (0, [0, 0, [0]])
        _, spec = py_pytree.tree_flatten(pytree)
        # NOTE(review): continuation lines are indented two spaces per nesting
        # level to mirror TreeSpec.__repr__ -- confirm against the implementation.
        self.assertExpectedInline(
            repr(spec),
            """\
TreeSpec(tuple, None, [*,
  TreeSpec(list, None, [*,
    *,
    TreeSpec(list, None, [*])])])""",
        )

    @parametrize(
        "spec",
        [
            py_pytree.TreeSpec(list, None, []),
            py_pytree.TreeSpec(tuple, None, []),
            py_pytree.TreeSpec(dict, [], []),
            py_pytree.TreeSpec(list, None, [py_pytree.LeafSpec()]),
            py_pytree.TreeSpec(
                list, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
            ),
            py_pytree.TreeSpec(
                tuple,
                None,
                [py_pytree.LeafSpec(), py_pytree.LeafSpec(), py_pytree.LeafSpec()],
            ),
            py_pytree.TreeSpec(
                dict,
                ["a", "b", "c"],
                [py_pytree.LeafSpec(), py_pytree.LeafSpec(), py_pytree.LeafSpec()],
            ),
            py_pytree.TreeSpec(
                OrderedDict,
                ["a", "b", "c"],
                [
                    py_pytree.TreeSpec(
                        tuple, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
                    ),
                    py_pytree.LeafSpec(),
                    py_pytree.TreeSpec(
                        dict,
                        ["a", "b", "c"],
                        [
                            py_pytree.LeafSpec(),
                            py_pytree.LeafSpec(),
                            py_pytree.LeafSpec(),
                        ],
                    ),
                ],
            ),
            py_pytree.TreeSpec(
                list,
                None,
                [
                    py_pytree.TreeSpec(
                        tuple,
                        None,
                        [
                            py_pytree.LeafSpec(),
                            py_pytree.LeafSpec(),
                            py_pytree.TreeSpec(
                                list,
                                None,
                                [
                                    py_pytree.LeafSpec(),
                                    py_pytree.LeafSpec(),
                                ],
                            ),
                        ],
                    ),
                ],
            ),
        ],
    )
    def test_pytree_serialize(self, spec):
        """Built-in container specs serialize to ``str`` and round-trip equal."""
        serialized_spec = py_pytree.treespec_dumps(spec)
        self.assertIsInstance(serialized_spec, str)
        self.assertEqual(spec, py_pytree.treespec_loads(serialized_spec))

    def test_pytree_serialize_namedtuple(self):
        """A namedtuple spec round-trips with the same field names."""
        Point = namedtuple("Point", ["x", "y"])
        spec = py_pytree.TreeSpec(
            namedtuple, Point, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
        )

        roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec))
        # The context in the namedtuple is different now because we recreated
        # the namedtuple type.
        self.assertEqual(spec.context._fields, roundtrip_spec.context._fields)

    def test_pytree_custom_type_serialize_bad(self):
        """Dumping a custom type registered without a serialized name fails."""

        class DummyType:
            def __init__(self, x, y):
                self.x = x
                self.y = y

        py_pytree._register_pytree_node(
            DummyType,
            lambda dummy: ([dummy.x, dummy.y], None),
            lambda xs, _: DummyType(*xs),
        )
        spec = py_pytree.TreeSpec(
            DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
        )
        with self.assertRaisesRegex(
            NotImplementedError, "No registered serialization name"
        ):
            # The call is expected to raise, so its result is not bound.
            py_pytree.treespec_dumps(spec)

    def test_pytree_custom_type_serialize(self):
        """A custom type with a serialized name and context hooks round-trips."""

        class DummyType:
            def __init__(self, x, y):
                self.x = x
                self.y = y

        py_pytree._register_pytree_node(
            DummyType,
            lambda dummy: ([dummy.x, dummy.y], None),
            lambda xs, _: DummyType(*xs),
            serialized_type_name="test_pytree_custom_type_serialize.DummyType",
            to_dumpable_context=lambda context: "moo",
            from_dumpable_context=lambda dumpable_context: None,
        )
        spec = py_pytree.TreeSpec(
            DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
        )
        serialized_spec = py_pytree.treespec_dumps(spec, 1)
        # "moo" is what to_dumpable_context returns, so it must appear verbatim.
        self.assertIn("moo", serialized_spec)
        roundtrip_spec = py_pytree.treespec_loads(serialized_spec)
        self.assertEqual(roundtrip_spec, spec)

    def test_pytree_serialize_register_bad(self):
        """Registering only ``to_dumpable_context`` (no ``from_``) is rejected."""

        class DummyType:
            def __init__(self, x, y):
                self.x = x
                self.y = y

        with self.assertRaisesRegex(
            ValueError, "Both to_dumpable_context and from_dumpable_context"
        ):
            py_pytree._register_pytree_node(
                DummyType,
                lambda dummy: ([dummy.x, dummy.y], None),
                lambda xs, _: DummyType(*xs),
                serialized_type_name="test_pytree_serialize_register_bad.DummyType",
                to_dumpable_context=lambda context: "moo",
            )

    def test_pytree_context_serialize_bad(self):
        """A dumpable context that is not JSON serializable raises ``TypeError``."""

        class DummyType:
            def __init__(self, x, y):
                self.x = x
                self.y = y

        py_pytree._register_pytree_node(
            DummyType,
            lambda dummy: ([dummy.x, dummy.y], None),
            lambda xs, _: DummyType(*xs),
            serialized_type_name="test_pytree_serialize_serialize_bad.DummyType",
            # Returning the class object itself makes the context non-serializable.
            to_dumpable_context=lambda context: DummyType,
            from_dumpable_context=lambda dumpable_context: None,
        )
        spec = py_pytree.TreeSpec(
            DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
        )

        with self.assertRaisesRegex(
            TypeError, "Object of type type is not JSON serializable"
        ):
            py_pytree.treespec_dumps(spec)

    def test_pytree_serialize_bad_protocol(self):
        """Unknown protocol numbers are rejected by both dumps and loads."""
        import json

        Point = namedtuple("Point", ["x", "y"])
        spec = py_pytree.TreeSpec(
            namedtuple, Point, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
        )

        with self.assertRaisesRegex(ValueError, "Unknown protocol"):
            py_pytree.treespec_dumps(spec, -1)

        # Re-encode a valid payload with a bogus protocol number in front.
        serialized_spec = py_pytree.treespec_dumps(spec)
        _, data = json.loads(serialized_spec)
        bad_protocol_serialized_spec = json.dumps((-1, data))

        with self.assertRaisesRegex(ValueError, "Unknown protocol"):
            py_pytree.treespec_loads(bad_protocol_serialized_spec)

    def test_saved_serialized(self):
        """The serialized form matches a saved string and loads back equal."""
        complicated_spec = py_pytree.TreeSpec(
            OrderedDict,
            [1, 2, 3],
            [
                py_pytree.TreeSpec(
                    tuple, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
                ),
                py_pytree.LeafSpec(),
                py_pytree.TreeSpec(
                    dict,
                    [4, 5, 6],
                    [
                        py_pytree.LeafSpec(),
                        py_pytree.LeafSpec(),
                        py_pytree.LeafSpec(),
                    ],
                ),
            ],
        )

        serialized_spec = py_pytree.treespec_dumps(complicated_spec)
        # Pinned output: a change here means the on-disk format changed.
        saved_spec = (
            '[1, {"type": "collections.OrderedDict", "context": "[1, 2, 3]", '
            '"children_spec": [{"type": "builtins.tuple", "context": "null", '
            '"children_spec": [{"type": null, "context": null, '
            '"children_spec": []}, {"type": null, "context": null, '
            '"children_spec": []}]}, {"type": null, "context": null, '
            '"children_spec": []}, {"type": "builtins.dict", "context": '
            '"[4, 5, 6]", "children_spec": [{"type": null, "context": null, '
            '"children_spec": []}, {"type": null, "context": null, "children_spec": '
            '[]}, {"type": null, "context": null, "children_spec": []}]}]}]'
        )
        self.assertEqual(serialized_spec, saved_spec)
        self.assertEqual(complicated_spec, py_pytree.treespec_loads(saved_spec))
class TestCxxPytree(TestCase):
    """Tests specific to the C++ implementation, ``torch.utils._cxx_pytree``.

    Covers spec equality and repr, and treespec (de)serialization for built-in
    containers, namedtuples and custom registered types.
    """

    def test_treespec_equality(self):
        """Two freshly created ``LeafSpec`` objects compare equal with ``==``."""
        self.assertTrue(cxx_pytree.LeafSpec() == cxx_pytree.LeafSpec())

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "Dynamo test in test_treespec_repr_dynamo.")
    def test_treespec_repr(self):
        """``repr`` of a nested spec matches the expected string (eager)."""
        # Check that it looks sane
        pytree = (0, [0, 0, [0]])
        _, spec = cxx_pytree.tree_flatten(pytree)
        self.assertEqual(
            repr(spec),
            ("PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf)"),
        )

    @unittest.skipIf(not TEST_WITH_TORCHDYNAMO, "Eager test in test_treespec_repr.")
    def test_treespec_repr_dynamo(self):
        """Same repr check as ``test_treespec_repr``, run only under TorchDynamo."""
        # Check that it looks sane
        pytree = (0, [0, 0, [0]])
        _, spec = cxx_pytree.tree_flatten(pytree)
        self.assertExpectedInline(
            repr(spec),
            "PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf)",
        )

    @parametrize(
        "spec",
        [
            cxx_pytree.tree_structure([]),
            cxx_pytree.tree_structure(()),
            cxx_pytree.tree_structure({}),
            cxx_pytree.tree_structure([0]),
            cxx_pytree.tree_structure([0, 1]),
            cxx_pytree.tree_structure((0, 1, 2)),
            cxx_pytree.tree_structure({"a": 0, "b": 1, "c": 2}),
            cxx_pytree.tree_structure(
                OrderedDict([("a", (0, 1)), ("b", 2), ("c", {"a": 3, "b": 4, "c": 5})])
            ),
            cxx_pytree.tree_structure([(0, 1, [2, 3])]),
        ],
    )
    def test_pytree_serialize(self, spec):
        """Built-in container specs serialize to ``str`` and round-trip equal."""
        serialized_spec = cxx_pytree.treespec_dumps(spec)
        self.assertIsInstance(serialized_spec, str)
        self.assertEqual(spec, cxx_pytree.treespec_loads(serialized_spec))

    def test_pytree_serialize_namedtuple(self):
        """Namedtuple specs keep their field names across a dumps/loads round-trip.

        Checked for both a module-level and a function-local namedtuple type.
        """
        spec = cxx_pytree.tree_structure(GlobalPoint(0, 1))

        roundtrip_spec = cxx_pytree.treespec_loads(cxx_pytree.treespec_dumps(spec))
        self.assertEqual(roundtrip_spec.type._fields, spec.type._fields)

        LocalPoint = namedtuple("LocalPoint", ["x", "y"])
        spec = cxx_pytree.tree_structure(LocalPoint(0, 1))

        roundtrip_spec = cxx_pytree.treespec_loads(cxx_pytree.treespec_dumps(spec))
        self.assertEqual(roundtrip_spec.type._fields, spec.type._fields)

    def test_pytree_custom_type_serialize(self):
        """Custom registered types round-trip through dumps/loads.

        Checked for both a module-level and a function-local class.
        """
        cxx_pytree.register_pytree_node(
            GlobalDummyType,
            lambda dummy: ([dummy.x, dummy.y], None),
            lambda xs, _: GlobalDummyType(*xs),
            serialized_type_name="GlobalDummyType",
        )
        spec = cxx_pytree.tree_structure(GlobalDummyType(0, 1))
        serialized_spec = cxx_pytree.treespec_dumps(spec)
        roundtrip_spec = cxx_pytree.treespec_loads(serialized_spec)
        self.assertEqual(roundtrip_spec, spec)

        class LocalDummyType:
            def __init__(self, x, y):
                self.x = x
                self.y = y

        cxx_pytree.register_pytree_node(
            LocalDummyType,
            lambda dummy: ([dummy.x, dummy.y], None),
            lambda xs, _: LocalDummyType(*xs),
            serialized_type_name="LocalDummyType",
        )
        spec = cxx_pytree.tree_structure(LocalDummyType(0, 1))
        serialized_spec = cxx_pytree.treespec_dumps(spec)
        roundtrip_spec = cxx_pytree.treespec_loads(serialized_spec)
        self.assertEqual(roundtrip_spec, spec)
# Instantiate the @parametrize-decorated tests of each test class.
instantiate_parametrized_tests(TestGenericPytree)
instantiate_parametrized_tests(TestPythonPytree)
instantiate_parametrized_tests(TestCxxPytree)


# Run the suite when this file is executed directly as a script.
if __name__ == "__main__":
    run_tests()