pytorch/torch/export/_tree_utils.py
suo ca090b2c77 [export] do not use tree_flatten_spec (#118608)
`tree_flatten_spec` is problematic: it isn't kept in sync with `register_pytree_node`, so it will not handle arbitrary custom pytrees. It's also not really maintained.

We only use it for two purposes:
- To retain kwarg ordering stability, so that if the user passes in kwargs in a different order things will still work.
- To do "structural" checks that ignore types.

In both cases, tree_flatten_spec is probably *not* the ideal way to implement the desired behavior.

## kwargs ordering
- `tree_flatten_spec` overrides the behavior of ALL dictionaries, not just kwargs. This is not correct: dictionary ordering is meaningful in Python, and it's pretty trivial to write a program that relies on dict ordering.
- For kwargs, we do sort of expect that the order in which arguments are passed shouldn't matter. BUT there is one exception: `**kwargs`. In fact, [PEP 468](https://peps.python.org/pep-0468/) was introduced specifically to clarify that ordering does matter when the function being called uses `**kwargs`.
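
To make the `**kwargs` exception concrete, here is a small plain-Python illustration. The function names `collect` and `named` are hypothetical and are not part of this change; the sketch only shows that call-site order is observable through `**kwargs` but not through named parameters.

```python
def collect(**kwargs):
    # PEP 468: `kwargs` preserves the order in which the caller passed
    # the keyword arguments, so the callee can observe that order.
    return list(kwargs)


def named(a, b):
    # With named parameters, the call-site order is not observable.
    return (a, b)


order_ab = collect(a=1, b=2)
order_ba = collect(b=2, a=1)
named_ab = named(a=1, b=2)
named_ba = named(b=2, a=1)
```

Here `order_ab` and `order_ba` differ, while `named_ab` and `named_ba` are equal, which is why blindly reordering kwargs is only safe when the traced function does not take `**kwargs`.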

In this diff I introduce a utility function that *only* reorders kwargs. This gets us most of the way to correct: dicts are no longer reordered, but kwargs can be passed in any order.
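
A minimal sketch of the intended behavior, using plain dicts instead of a `TreeSpec`. The helper name `reorder_top_level` and the `expected_keys` list are assumptions made for illustration; the actual utility reads the expected order from the exported program's spec.

```python
from typing import Any, Dict, List


def reorder_top_level(
    user_kwargs: Dict[str, Any], expected_keys: List[str]
) -> Dict[str, Any]:
    """Hypothetical helper: reorder only the top-level kwargs dict."""
    if set(user_kwargs) != set(expected_keys):
        raise ValueError(
            f"kwarg key mismatch: got {list(user_kwargs)} "
            f"but expected {expected_keys}"
        )
    # Values are passed through untouched, so nested dicts keep their order.
    return {key: user_kwargs[key] for key in expected_keys}


traced_order = ["a", "b"]
user_kwargs = {"b": {"z": 1, "y": 2}, "a": 0}
reordered = reorder_top_level(user_kwargs, traced_order)
```

The top-level keys come back in traced order (`a`, then `b`), while the nested dict under `b` still iterates as `z`, then `y`.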

A "fully correct" solution would need to fix the corner case from PEP 468. We could detect whether the top-level fn being traced uses `**kwargs` (via `inspect`), then serialize a flag for it. In ExportedProgram, we would check that flag and only reorder if `**kwargs` was unused; otherwise error if the key order doesn't match. This is a super corner case though, so I'll file it as a followup task.
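
The detection step could look roughly like the following. This is a sketch of the proposed followup, not something this diff implements; the helper name `accepts_var_keyword` is hypothetical, and how the resulting flag would be serialized is left open.

```python
import inspect
from typing import Any, Callable


def accepts_var_keyword(fn: Callable[..., Any]) -> bool:
    """Hypothetical helper: True if `fn` declares a `**kwargs` parameter."""
    params = inspect.signature(fn).parameters.values()
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params)


def fixed(a, b):
    return a + b


def variadic(a, **kwargs):
    return list(kwargs)


fixed_uses_kwargs = accepts_var_keyword(fixed)
variadic_uses_kwargs = accepts_var_keyword(variadic)
```

With a flag like this available, reordering would be applied only when it is `False`, and a key-order mismatch would raise an error when it is `True`.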

## structural equivalence checking

This is another use case where `tree_flatten_spec` is too broad. Generally we want to treat two specific types as the same, not override the behavior of comparison in general. So I introduce an `is_equivalent` util for this purpose.
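
As an illustration of "treat two specific types as the same", the sketch below runs the same recursive check over a stand-in spec class. `FakeSpec` and `list_tuple_same` are hypothetical names; `FakeSpec` only mimics the three attributes (`type`, `context`, `children_specs`) that the check reads, so the example runs without importing torch.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional


@dataclass
class FakeSpec:
    """Hypothetical stand-in for TreeSpec with only the attributes used here."""

    type: Optional[type]
    context: Any = None
    children_specs: List["FakeSpec"] = field(default_factory=list)


def is_equivalent(spec1: FakeSpec, spec2: FakeSpec, equivalence_fn: Callable) -> bool:
    # Same shape as the util in this file: check this node, then recurse.
    if not equivalence_fn(spec1.type, spec1.context, spec2.type, spec2.context):
        return False
    if len(spec1.children_specs) != len(spec2.children_specs):
        return False
    return all(
        is_equivalent(c1, c2, equivalence_fn)
        for c1, c2 in zip(spec1.children_specs, spec2.children_specs)
    )


def list_tuple_same(type1, context1, type2, context2) -> bool:
    # Treat list and tuple as interchangeable; everything else must match exactly.
    sequence_types = (list, tuple)
    if type1 in sequence_types and type2 in sequence_types:
        return True
    return type1 is type2 and context1 == context2


leaf = FakeSpec(type=None)
as_list = FakeSpec(list, None, [leaf, leaf])
as_tuple = FakeSpec(tuple, None, [leaf, leaf])
as_dict = FakeSpec(dict, ["a", "b"], [leaf, leaf])

list_vs_tuple = is_equivalent(as_list, as_tuple, list_tuple_same)
list_vs_dict = is_equivalent(as_list, as_dict, list_tuple_same)
```

The list and tuple specs compare as equivalent, while the list and dict specs do not. The relaxation is scoped to the one `equivalence_fn` passed in, rather than changing how all containers are compared.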

Differential Revision: [D53168420](https://our.internmc.facebook.com/intern/diff/D53168420/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118608
Approved by: https://github.com/zhxchen17
ghstack dependencies: #118607
2024-01-30 19:14:04 +00:00

65 lines
2.2 KiB
Python

from typing import Any, Callable, Dict, Optional

from torch.utils._pytree import Context, TreeSpec


def reorder_kwargs(user_kwargs: Dict[str, Any], spec: TreeSpec) -> Dict[str, Any]:
    """Reorder user-provided kwargs to match the order in `spec`. `spec` is
    expected to be the in_spec of an exported program, i.e. the spec that
    results from flattening `(args, kwargs)`.

    We need this to provide consistent input ordering, so that users can
    pass in foo(a=a, b=b) OR foo(b=b, a=a) and receive the same result.
    """
    # Make sure that the spec is actually shaped like (args, kwargs)
    assert spec.type is tuple
    assert spec.num_children == 2
    kwargs_spec = spec.children_specs[1]
    assert kwargs_spec.type is dict

    if set(user_kwargs) != set(kwargs_spec.context):
        raise ValueError(
            f"kwarg key mismatch: "
            f"Got {list(user_kwargs)} but expected {kwargs_spec.context}"
        )

    reordered_kwargs = {}
    for kw in kwargs_spec.context:
        reordered_kwargs[kw] = user_kwargs[kw]

    return reordered_kwargs


def is_equivalent(
    spec1: TreeSpec,
    spec2: TreeSpec,
    equivalence_fn: Callable[[Optional[type], Context, Optional[type], Context], bool],
) -> bool:
    """Customizable equivalence check for two TreeSpecs.

    Arguments:
        spec1: The first TreeSpec to compare
        spec2: The second TreeSpec to compare
        equivalence_fn: A function to determine the equivalence of two
            TreeSpecs by examining their types and contexts. It will be called like:

                equivalence_fn(spec1.type, spec1.context, spec2.type, spec2.context)

            This function will be applied recursively to all children.

    Returns:
        True if the two TreeSpecs are equivalent, False otherwise.
    """
    if not equivalence_fn(spec1.type, spec1.context, spec2.type, spec2.context):
        return False

    # Recurse on children
    if len(spec1.children_specs) != len(spec2.children_specs):
        return False

    for child_spec1, child_spec2 in zip(spec1.children_specs, spec2.children_specs):
        if not is_equivalent(child_spec1, child_spec2, equivalence_fn):
            return False

    return True