pytorch/torch/fx/_pytree.py
Richard Zou 52d1ffb789 Teach pytrees about namedtuple (#62292)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62292

This PR adds pytree support for namedtuples. The challenge with namedtuples
is that every namedtuple class is a distinct type. This PR does the
following:
- Adds a namedtuple flatten/unflatten pair. The flatten function returns
a context that is the concrete type of the namedtuple subclass. The
unflatten function uses that type to reconstruct the namedtuple.
- Special-cases all pytree logic to treat all namedtuples the same.
This is done by creating a `_get_node_type(pytree)` helper function that
returns `namedtuple` if `pytree` is an instance of any namedtuple subclass.
As a result, all namedtuple subclasses go through the namedtuple
flatten/unflatten functions.
- Adds a `_namedtuple_flatten_spec` function for FX pytrees. This function
flattens the namedtuple based on the spec and is equivalent to
`_tuple_flatten_spec`.

Test Plan:
- New tests in test/test_pytree.py and test/test_fx.py

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D29947302

Pulled By: zou3519

fbshipit-source-id: 19c00665b13546642c315df0f243ad99b8e7ff7c
2021-07-28 06:27:44 -07:00


from typing import Callable, Any, Tuple, List, Dict, Type, NamedTuple
from torch.utils._pytree import PyTree, TreeSpec, LeafSpec
from collections import namedtuple

# A flatten-spec function maps (pytree, spec) to the list of the node's children.
FlattenFuncSpec = Callable[[PyTree, TreeSpec], List]

# Registry from node type to its flatten-spec function.
SUPPORTED_NODES: Dict[Type[Any], Any] = {}


def register_pytree_flatten_spec(typ: Any, flatten_fn_spec: FlattenFuncSpec) -> None:
    SUPPORTED_NODES[typ] = flatten_fn_spec


# Flatten `pytree` into a list of leaves, following the structure recorded in
# `spec` rather than the structure of `pytree` itself.
def tree_flatten_spec(pytree: PyTree, spec: TreeSpec) -> List[Any]:
    if isinstance(spec, LeafSpec):
        return [pytree]
    if spec.type not in SUPPORTED_NODES:
        raise RuntimeError(
            f"{type(pytree)} does not have a flatten_fn_spec associated with it. Please register one with "
            "torch.fx._pytree.register_pytree_flatten_spec. If you have serialized your model, make "
            "sure that any custom pytrees have been registered before loading it.")
    flatten_fn_spec = SUPPORTED_NODES[spec.type]
    child_pytrees = flatten_fn_spec(pytree, spec)
    result = []
    # Recurse into each child with its matching child spec.
    for child, child_spec in zip(child_pytrees, spec.children_specs):
        flat = tree_flatten_spec(child, child_spec)
        result += flat
    return result


# Dicts are flattened in the key order stored in the spec's context.
def _dict_flatten_spec(d: Dict[Any, Any], spec: TreeSpec) -> List[Any]:
    return [d[k] for k in spec.context]


def _list_flatten_spec(d: List[Any], spec: TreeSpec) -> List[Any]:
    return [d[i] for i in range(len(spec.children_specs))]


def _tuple_flatten_spec(d: Tuple[Any, ...], spec: TreeSpec) -> List[Any]:
    return [d[i] for i in range(len(spec.children_specs))]


def _namedtuple_flatten_spec(d: NamedTuple, spec: TreeSpec) -> List[Any]:
    return [d[i] for i in range(len(spec.children_specs))]


register_pytree_flatten_spec(dict, _dict_flatten_spec)
register_pytree_flatten_spec(list, _list_flatten_spec)
register_pytree_flatten_spec(tuple, _tuple_flatten_spec)
# All namedtuple subclasses share the single `namedtuple` key.
register_pytree_flatten_spec(namedtuple, _namedtuple_flatten_spec)