pytorch/torch/utils/_pytree.py
Richard Zou 52d1ffb789 Teach pytrees about namedtuple (#62292)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62292

This PR adds pytree support for namedtuples. The challenge about namedtuple
is that each namedtuple class is actually different. This PR does the
following:
- it adds a namedtuple flatten/unflatten. The flatten function returns
a context that is the actual type of the namedtuple subclass. The
unflatten function uses that type to reconstruct the namedtuple
- Special cases all pytree logic to consider all namedtuples the same.
This is done by creating a `_get_node_type(pytree)` helper function that
returns `namedtuple` if `pytree` is any namedtuple subclass. The effect
of this is that all namedtuple subclasses will go through the namedtuple
flatten/unflatten functions
- Adds a `_namedtuple_flatten_spec` function for FX pytrees. This function
flattens the namedtuple based on the spec and is equivalent to the
`_tuple_flatten_spec`.

Test Plan
- new tests in test/test_pytree.py and test/test_fx.py

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D29947302

Pulled By: zou3519

fbshipit-source-id: 19c00665b13546642c315df0f243ad99b8e7ff7c
2021-07-28 06:27:44 -07:00

217 lines
8.0 KiB
Python

from typing import NamedTuple, Callable, Any, Tuple, List, Dict, Type, cast, Optional
from collections import namedtuple
"""
Contains utility functions for working with nested python data structures.
A *pytree* is a nested Python data structure. It is a tree in the sense that
nodes are Python collections (e.g., list, tuple, dict) and the leaves are
Python values. Furthermore, a pytree should not contain reference cycles.
pytrees are useful for working with nested collections of Tensors. For example,
one can use `tree_map` to map a function over all Tensors inside some nested
collection of Tensors and `tree_flatten` to get a flat list of all Tensors
inside some nested collection. pytrees are helpful for implementing nested
collection support for PyTorch APIs.
This pytree implementation is not very performant due to Python overhead.
To improve the performance we can move parts of the implementation to C++.
"""
# A NodeDef holds two callables:
# - flatten_fn should take the collection and return a flat list of values.
# It can also return some context that is used in reconstructing the
# collection.
# - unflatten_fn should take a flat list of values and some context
# (returned by flatten_fn). It returns the collection by reconstructing
# it from the list and the context.
# Context: the extra data a flatten function returns next to the child values;
# it is handed back to the matching unflatten function.
Context = Any
# PyTree: any Python value (typed as Any).
PyTree = Any
# FlattenFunc: takes a collection and returns (list of children, context).
FlattenFunc = Callable[[PyTree], Tuple[List, Context]]
# UnflattenFunc: takes (list of children, context) and returns the collection.
UnflattenFunc = Callable[[List, Context], PyTree]
class NodeDef(NamedTuple):
    """Pair of callables describing how one container type is (un)flattened."""
    # Splits a collection into (list of children, context).
    flatten_fn: FlattenFunc
    # Rebuilds the collection from (list of children, context).
    unflatten_fn: UnflattenFunc
# Registry mapping a node type to the NodeDef used to flatten/unflatten it.
# Populated through `_register_pytree_node`.
SUPPORTED_NODES: Dict[Type[Any], NodeDef] = {}
def _register_pytree_node(typ: Any, flatten_fn: FlattenFunc, unflatten_fn: UnflattenFunc) -> None:
    """Register `typ` as a pytree node type.

    Stores a NodeDef built from the two callables in SUPPORTED_NODES under
    the key `typ`, replacing any entry already stored for that key.
    """
    node_def = NodeDef(flatten_fn=flatten_fn, unflatten_fn=unflatten_fn)
    SUPPORTED_NODES[typ] = node_def
def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]:
return list(d.values()), list(d.keys())
def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]:
return {key: value for key, value in zip(context, values)}
def _list_flatten(d: List[Any]) -> Tuple[List[Any], Context]:
return d, None
def _list_unflatten(values: List[Any], context: Context) -> List[Any]:
return list(values)
def _tuple_flatten(d: Tuple[Any, ...]) -> Tuple[List[Any], Context]:
return list(d), None
def _tuple_unflatten(values: List[Any], context: Context) -> Tuple[Any, ...]:
return tuple(values)
def _namedtuple_flatten(d: NamedTuple) -> Tuple[List[Any], Context]:
return list(d), type(d)
def _namedtuple_unflatten(values: List[Any], context: Context) -> NamedTuple:
return cast(NamedTuple, context(*values))
# Register the built-in container types. Every namedtuple class shares the
# single entry keyed by the `collections.namedtuple` factory function; see
# `_get_node_type`, which maps any namedtuple instance to that key.
_register_pytree_node(dict, _dict_flatten, _dict_unflatten)
_register_pytree_node(list, _list_flatten, _list_unflatten)
_register_pytree_node(tuple, _tuple_flatten, _tuple_unflatten)
_register_pytree_node(namedtuple, _namedtuple_flatten, _namedtuple_unflatten)
# h/t https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple
def _is_namedtuple_instance(pytree: Any) -> bool:
typ = type(pytree)
bases = typ.__bases__
if len(bases) != 1 or bases[0] != tuple:
return False
fields = getattr(typ, '_fields', None)
if not isinstance(fields, tuple):
return False
return all(type(entry) == str for entry in fields)
def _get_node_type(pytree: Any) -> Any:
    """Return the key under which `pytree` is looked up in SUPPORTED_NODES.

    Any value recognised by `_is_namedtuple_instance` maps to the
    `namedtuple` factory function; every other value maps to its own type.
    """
    return namedtuple if _is_namedtuple_instance(pytree) else type(pytree)
# A leaf is defined as anything that is not a Node.
def _is_leaf(pytree: PyTree) -> bool:
    """Return True when `pytree`'s node type has no entry in SUPPORTED_NODES."""
    node_type = _get_node_type(pytree)
    return node_type not in SUPPORTED_NODES
# A TreeSpec represents the structure of a pytree. It holds:
# "type": the type of root Node of the pytree
# context: some context that is useful in unflattening the pytree
# children_specs: specs for each child of the root Node
# num_leaves: the number of leaves
class TreeSpec:
    """Structure of a pytree, without its leaf values.

    Attributes:
        type: the node type of the root of the pytree.
        context: the context produced when the root was flattened; it is
            needed to unflatten the pytree.
        children_specs: one TreeSpec per child of the root.
        num_leaves: total number of leaves, i.e. the sum of the children's
            `num_leaves`.
    """

    def __init__(self, typ: Any, context: Context, children_specs: List['TreeSpec']) -> None:
        self.type = typ
        self.context = context
        self.children_specs = children_specs
        self.num_leaves: int = sum(spec.num_leaves for spec in children_specs)

    def __repr__(self) -> str:
        return f'TreeSpec({self.type.__name__}, {self.context}, {self.children_specs})'

    def __eq__(self, other: Any) -> bool:
        """Compare type, context, children specs and leaf count.

        Returns NotImplemented for operands that are not TreeSpec instances,
        so that `spec == None` evaluates to False instead of failing on a
        missing attribute.
        """
        if not isinstance(other, TreeSpec):
            return NotImplemented
        result = self.type == other.type and self.context == other.context \
            and self.children_specs == other.children_specs \
            and self.num_leaves == other.num_leaves
        # This should really not be necessary, but mypy errors out without it.
        return cast(bool, result)

    def __ne__(self, other: Any) -> bool:
        """Inverse of `__eq__`; NotImplemented is passed through, not negated."""
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result
class LeafSpec(TreeSpec):
    """TreeSpec of a single leaf: no type, no context, no children, one leaf."""

    def __init__(self) -> None:
        super().__init__(typ=None, context=None, children_specs=[])
        # The base initialiser sums over the (empty) children list, so the
        # leaf count is set explicitly here.
        self.num_leaves = 1

    def __repr__(self) -> str:
        return '*'
def tree_flatten(pytree: PyTree) -> Tuple[List[Any], TreeSpec]:
    """Flatten a pytree into a list of leaf values plus a TreeSpec.

    A leaf yields `([pytree], LeafSpec())`. Otherwise the node is split into
    children by its registered flatten function, each child is flattened
    recursively, and the leaves are concatenated in child order. The returned
    TreeSpec can be passed to `tree_unflatten` to reconstruct the pytree.
    """
    if _is_leaf(pytree):
        return [pytree], LeafSpec()

    node_type = _get_node_type(pytree)
    children, context = SUPPORTED_NODES[node_type].flatten_fn(pytree)

    # Walk the children in order, collecting their leaves and their specs.
    leaves: List[Any] = []
    specs: List['TreeSpec'] = []
    for child_leaves, child_spec in map(tree_flatten, children):
        leaves.extend(child_leaves)
        specs.append(child_spec)

    return leaves, TreeSpec(node_type, context, specs)
def tree_unflatten(values: List[Any], spec: TreeSpec) -> PyTree:
    """Build a pytree from a flat list of values and a TreeSpec.

    This is the inverse operation of `tree_flatten`.

    Raises:
        ValueError: if `spec` is not a TreeSpec, or if `len(values)` differs
            from `spec.num_leaves`.
    """
    if not isinstance(spec, TreeSpec):
        message = (
            f'tree_unflatten(values, spec): Expected `spec` to be instance of '
            f'TreeSpec but got item of type {type(spec)}.')
        raise ValueError(message)

    if len(values) != spec.num_leaves:
        message = (
            f'tree_unflatten(values, spec): `values` has length {len(values)} '
            f'but the spec refers to a pytree that holds {spec.num_leaves} '
            f'items ({spec}).')
        raise ValueError(message)

    if isinstance(spec, LeafSpec):
        return values[0]

    rebuild = SUPPORTED_NODES[spec.type].unflatten_fn

    # Hand each child the slice of `values` that holds its leaves.
    children = []
    offset = 0
    for child_spec in spec.children_specs:
        stop = offset + child_spec.num_leaves
        children.append(tree_unflatten(values[offset:stop], child_spec))
        offset = stop

    return rebuild(children, spec.context)
def tree_map(fn: Any, pytree: PyTree) -> PyTree:
    """Apply `fn` to every leaf of `pytree` and rebuild the same structure.

    The pytree is flattened, `fn` is called on each leaf in order, and the
    results are unflattened with the original spec.
    """
    leaves, spec = tree_flatten(pytree)
    mapped = list(map(fn, leaves))
    return tree_unflatten(mapped, spec)
# Broadcasts a pytree to the provided TreeSpec and returns the flattened
# values. If this is not possible, then this function returns None.
#
# For example, given pytree=0 and spec=TreeSpec(list, None, [LeafSpec(), LeafSpec()]),
# would return [0, 0]. This is useful for part of the vmap implementation:
# a user can pass in vmap(fn, in_dims)(*inputs). `in_dims` should be
# broadcastable to the tree structure of `inputs` and we use
# _broadcast_to_and_flatten to check this.
def _broadcast_to_and_flatten(pytree: PyTree, spec: TreeSpec) -> Optional[List[Any]]:
    """Broadcast `pytree` to the structure described by `spec` and flatten it.

    A leaf `pytree` is repeated `spec.num_leaves` times. A non-leaf `pytree`
    must match `spec` in node type, number of children and context; its
    children are then broadcast recursively against the child specs.

    Returns:
        The flat list of values, or None when `pytree` cannot be broadcast
        to `spec`.
    """
    assert isinstance(spec, TreeSpec)

    if _is_leaf(pytree):
        return [pytree] * spec.num_leaves
    if isinstance(spec, LeafSpec):
        # A container cannot be broadcast onto a single leaf.
        return None

    node_type = _get_node_type(pytree)
    if node_type != spec.type:
        return None

    children, context = SUPPORTED_NODES[node_type].flatten_fn(pytree)

    # The node must have the same number of children and the same context.
    same_arity = len(children) == len(spec.children_specs)
    if not same_arity or context != spec.context:
        return None

    # Broadcast each child; a single failure makes the whole result None.
    flattened: List[Any] = []
    for child, child_spec in zip(children, spec.children_specs):
        piece = _broadcast_to_and_flatten(child, child_spec)
        if piece is None:
            return None
        flattened.extend(piece)
    return flattened