mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
tree_flatten_spec is bad; it isn't synced up with `register_pytree_node`, so it will not handle arbitrary custom pytrees. It's also not really maintained.

We only use it for two purposes:
- To retain kwarg ordering stability, so that if the user passes in kwargs in a different order things will still work.
- To do "structural" checks that ignore types.

In both cases, tree_flatten_spec is probably *not* the ideal way to implement the desired behavior.

## kwargs ordering

- tree_flatten_spec overrides the behavior of ALL dictionaries, not just kwargs. This is not correct: dictionary ordering is meaningful in Python, and it's pretty trivial to write a program that relies on dict ordering.
- For kwargs, we do sort of expect that the order in which arguments are passed shouldn't matter. BUT there is one exception: `**kwargs`. In fact, [PEP 468](https://peps.python.org/pep-0468/) was introduced specifically to clarify that ordering does matter when the function being called uses `**kwargs`.

In this diff I introduce a utility function that *only* reorders kwargs. This gets us most of the way to correct: dicts are no longer reordered, but kwargs can be passed in any order.

A "fully correct" solution would need to fix the corner case from PEP 468. We could detect whether the top-level fn being traced uses `**kwargs` (via `inspect`), then serialize a flag for it. In ExportedProgram, we would check that flag and only reorder if `**kwargs` was unused; otherwise we would error if the key order doesn't match. This is a super corner case though, so I'll file it as a followup task.

## structural equivalence checking

This is another use case where, again, `tree_flatten_spec` is too broad. Generally we want to treat two specific types as the same, not override the behavior of comparison in general. So I introduce an `is_equivalent` util for this purpose.
Differential Revision: [D53168420](https://our.internmc.facebook.com/intern/diff/D53168420/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118608
Approved by: https://github.com/zhxchen17
ghstack dependencies: #118607
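The followup described in the commit message (detect `**kwargs` via `inspect`, then reorder only when it is safe) could look roughly like the sketch below. This is plain Python, not the code in this PR: `uses_var_kwargs` and `normalize_kwargs` are hypothetical names, and the expected key order is passed as a plain list rather than read from a serialized spec.

```python
import inspect
from typing import Any, Callable, Dict, Sequence


def uses_var_kwargs(fn: Callable[..., Any]) -> bool:
    """Return True if fn declares a **kwargs parameter."""
    params = inspect.signature(fn).parameters.values()
    return any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params)


def normalize_kwargs(
    fn: Callable[..., Any],
    user_kwargs: Dict[str, Any],
    expected_keys: Sequence[str],
) -> Dict[str, Any]:
    """Hypothetical helper: reorder kwargs unless fn can observe their order."""
    if set(user_kwargs) != set(expected_keys):
        raise ValueError("kwargs keys do not match the expected keys")
    if uses_var_kwargs(fn):
        # With **kwargs the callee sees the call order (PEP 468),
        # so a mismatch is an error rather than something to fix silently.
        if list(user_kwargs) != list(expected_keys):
            raise ValueError("kwargs order differs and fn uses **kwargs")
        return dict(user_kwargs)
    # Named parameters only: order is unobservable, so reordering is safe.
    return {key: user_kwargs[key] for key in expected_keys}


def fixed(a, b):
    return a, b


def variadic(**kwargs):
    return list(kwargs)


print(list(normalize_kwargs(fixed, {"b": 2, "a": 1}, ["a", "b"])))  # ['a', 'b']
print(variadic(b=1, a=2))  # ['b', 'a']
```

Raising on an order mismatch, instead of reordering, keeps the `**kwargs` case loud; whether an error or a warning is the right behavior there is left open by the commit message.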
51 lines
1.8 KiB
Python
# Owner(s): ["oncall: export"]
from collections import OrderedDict

import torch
from torch._dynamo.test_case import TestCase

from torch.export._tree_utils import is_equivalent, reorder_kwargs
from torch.testing._internal.common_utils import run_tests
from torch.utils._pytree import tree_structure


class TestTreeUtils(TestCase):
    def test_reorder_kwargs(self):
        original_kwargs = {"a": torch.tensor(0), "b": torch.tensor(1)}
        user_kwargs = {"b": torch.tensor(2), "a": torch.tensor(3)}
        orig_spec = tree_structure(((), original_kwargs))

        reordered_kwargs = reorder_kwargs(user_kwargs, orig_spec)

        # Key ordering should be the same.
        # popitem() removes the last-inserted pair, so keys are compared back to front.
        self.assertEqual(reordered_kwargs.popitem()[0], original_kwargs.popitem()[0])
        self.assertEqual(reordered_kwargs.popitem()[0], original_kwargs.popitem()[0])

    def test_equivalence_check(self):
        tree1 = {"a": torch.tensor(0), "b": torch.tensor(1), "c": None}
        tree2 = OrderedDict(a=torch.tensor(0), b=torch.tensor(1), c=None)
        spec1 = tree_structure(tree1)
        spec2 = tree_structure(tree2)

        # Treat dict and OrderedDict as the same node type; all other types must match exactly.
        def dict_ordered_dict_eq(type1, context1, type2, context2):
            if type1 is None or type2 is None:
                return type1 is type2 and context1 == context2

            if issubclass(type1, (dict, OrderedDict)) and issubclass(
                type2, (dict, OrderedDict)
            ):
                return context1 == context2

            return type1 is type2 and context1 == context2

        self.assertTrue(is_equivalent(spec1, spec2, dict_ordered_dict_eq))

        # Wrong ordering should still fail
        tree3 = OrderedDict(b=torch.tensor(1), a=torch.tensor(0))
        spec3 = tree_structure(tree3)
        self.assertFalse(is_equivalent(spec1, spec3, dict_ordered_dict_eq))


if __name__ == "__main__":
    run_tests()
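The idea exercised by `test_equivalence_check` (walk two specs in lockstep and let a caller-supplied comparator decide whether each pair of nodes matches) can be sketched without torch. Everything below is an illustrative assumption, not the implementation in `torch.export._tree_utils`: `Spec`, `spec_of`, and `is_equiv` are hypothetical stand-ins, and only dicts, lists, and tuples are treated as containers.

```python
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional


@dataclass
class Spec:
    """Hypothetical stand-in for a pytree spec."""
    node_type: Optional[type]  # None marks a leaf
    context: Any               # e.g. the ordered keys of a dict
    children: List["Spec"] = field(default_factory=list)


def spec_of(tree: Any) -> Spec:
    """Build a Spec for nested dicts, lists, and tuples; anything else is a leaf."""
    if isinstance(tree, dict):  # also matches OrderedDict
        return Spec(type(tree), list(tree.keys()), [spec_of(v) for v in tree.values()])
    if isinstance(tree, (list, tuple)):
        return Spec(type(tree), None, [spec_of(v) for v in tree])
    return Spec(None, None)


def is_equiv(s1: Spec, s2: Spec, eq_fn: Callable[[Any, Any, Any, Any], bool]) -> bool:
    """Compare two specs node by node, delegating each node check to eq_fn."""
    if not eq_fn(s1.node_type, s1.context, s2.node_type, s2.context):
        return False
    if len(s1.children) != len(s2.children):
        return False
    return all(is_equiv(c1, c2, eq_fn) for c1, c2 in zip(s1.children, s2.children))


def dict_like_eq(type1, context1, type2, context2) -> bool:
    """Treat any two dict subclasses as the same type; otherwise require identity."""
    if type1 is None or type2 is None:
        return type1 is type2 and context1 == context2
    if issubclass(type1, dict) and issubclass(type2, dict):
        return context1 == context2
    return type1 is type2 and context1 == context2


plain = spec_of({"a": 0, "b": [1, 2]})
ordered = spec_of(OrderedDict(a=0, b=[1, 2]))
swapped = spec_of(OrderedDict(b=[1, 2], a=0))
```

Because the comparator is passed in, the relaxation applies only to the types the caller names; a strict comparator such as `lambda t1, c1, t2, c2: t1 is t2 and c1 == c2` would still distinguish `dict` from `OrderedDict`.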