pytorch/torch/utils/_pytree.py
Edward Z. Yang 5c6f5439b7 Implement SymBool (#92149)
We have known for a while that we should in principle support SymBool as a separate concept from SymInt and SymFloat (in particular, every distinct numeric type should get its own API). However, recent work with unbacked SymInts in, e.g., https://github.com/pytorch/pytorch/pull/90985 has made this a priority to implement. The essential problem is that our logic for computing the contiguity of tensors performs branches on the passed-in input sizes, and this causes us to require guards when constructing tensors from unbacked SymInts. Morally, this should not be a big deal because we only really care about the regular (non-channels-last) contiguity of the tensor, which should be guaranteed since most people aren't calling `empty_strided` on the tensor. However, because we store a bool (not a SymBool; prior to this PR it didn't exist) on TensorImpl, we are forced to *immediately* compute these values, even if the value ends up not being used at all. In particular, even when a user allocates a contiguous tensor, we still must compute channels-last contiguity (as some contiguous tensors are also channels-last contiguous, but others are not).

This PR implements SymBool, and makes TensorImpl use SymBool to store the contiguity information in ExtraMeta. There are a number of knock-on effects, which I discuss below.

* I introduce a new C++ type SymBool, analogous to SymInt and SymFloat. This type supports logical and, logical or, and logical negation. I support the bitwise operations on this class (but not the conventional logic operators) to make it clear that logical operations on SymBool are NOT short-circuiting. I also, for now, do NOT support implicit conversion of SymBool to bool (which would create a guard). This does not matter too much in practice, as in this PR I did not modify the equality operations (e.g., `==` on SymInt) to return SymBool, so all preexisting implicit guards did not need to be changed. I also introduced symbolic comparison functions `sym_eq`, etc. on SymInt to make it possible to create SymBool. Unfortunately, the current implementation of comparison functions makes it easy to accidentally introduce guards when you do not mean to (as both `s0 == s1` and `s0.sym_eq(s1)` are valid spellings of the equality operation); in the short term, I intend to prevent excess guarding in this situation by unit testing; in the long term, making the equality operators return SymBool is probably the correct fix.
* ~~I modify TensorImpl to store SymBool for the `is_contiguous` fields and friends on `ExtraMeta`. In practice, this essentially meant reverting most of the changes from https://github.com/pytorch/pytorch/pull/85936 . In particular, the fields on ExtraMeta are no longer strongly typed; at the time I was particularly concerned about the giant lambda I was using as the setter getting a desynchronized argument order, but now that I have individual setters for each field the only "big list" of boolean arguments is in the constructor of ExtraMeta, which seems like an acceptable risk. The semantics of TensorImpl are now that we guard only when you actually attempt to access the contiguity of the tensor via, e.g., `is_contiguous`. By and large, the contiguity calculation in the implementations now needs to be duplicated (as the boolean version can short-circuit, but the SymBool version cannot); you should carefully review the duplicate new implementations. I typically use the `identity` template to disambiguate which version of the function I need, and rely on overloading to allow for implementation sharing. The changes to the `compute_` functions are particularly interesting; for most of the functions, I preserved their original non-symbolic implementation, and then introduced a new symbolic implementation that is branch-less (making use of our new SymBool operations). However, `compute_non_overlapping_and_dense` is special, see next bullet.~~ This appears to cause performance problems, so I am leaving this to an update PR.
* (Update: the Python-side pieces for this are still in this PR, but they are not wired up until later PRs.) While the contiguity calculations are relatively easy to write in a branch-free way, `compute_non_overlapping_and_dense` is not: it involves a sort on the strides. While in principle we can still make it go through by using a data-oblivious sorting network, this seems like too much complication for a field that is likely never used (because typically, it will be obvious that a tensor is non-overlapping and dense, because the tensor is contiguous). So we take a different approach: instead of trying to trace through the logical computation of non-overlapping and dense, we introduce a new opaque operator IsNonOverlappingAndDenseIndicator which represents all of the computation that would have been done here. For technical reasons (Sympy does not easily allow defining custom functions that return booleans), this function returns the integer 0 if `is_non_overlapping_and_dense` would have returned `False`, and the integer 1 otherwise. The function itself only knows how to evaluate itself if all of its arguments are integers; otherwise it is left unevaluated. This means we can always guard on it (as `size_hint` will always be able to evaluate through it), but otherwise its insides are left a black box. We typically do NOT expect this custom function to show up in actual boolean expressions, because we will typically shortcut it due to the tensor being contiguous. It's possible we should apply this treatment to all of the other `compute_` operations; more investigation is necessary. As a technical note, because this operator takes a pair of lists of SymInts, we need to support converting `ArrayRef<SymNode>` to Python, and I also unpack the pair of lists into a single list because I don't know if Sympy operations can actually validly take lists of Sympy expressions as inputs. See, for example, `_make_node_sizes_strides`.
* On the Python side, we also introduce a SymBool class, and update SymNode to track bool as a valid pytype. There is some subtlety here: bool is a subclass of int, so one has to be careful about `isinstance` checks (in fact, in most cases I replaced `isinstance(x, int)` with `type(x) is int` for expressly this reason). Additionally, unlike C++, I do NOT define bitwise inverse on SymBool, because it does not do the correct thing when run on booleans, e.g., `~True` is `-2`. (For that matter, they don't do the right thing in C++ either, but at least in principle the compiler can warn you about it with `-Wbool-operation`, and so the rule is simple in C++: only use logical operations if the types are statically known to be SymBool.) Alas, logical negation is not overridable, so we have to introduce `sym_not`, which must be used in place of `not` whenever a SymBool can turn up. To avoid confusion with `__not__`, which may imply that `operator.__not__` might be acceptable to use (it isn't), our magic method is called `__sym_not__`. The other bitwise operators `&` and `|` do the right thing with booleans and are acceptable to use.
* There is some annoyance working with booleans in Sympy. Unlike int and float, booleans live in their own algebra and they support fewer operations than regular numbers. In particular, `sympy.expand` does not work on them. To get around this, I introduce `safe_expand`, which only calls expand on operations which are known to be expandable.
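The Python-side pitfalls described in the bullets above can be reproduced in plain Python. The sketch below is not PyTorch's `SymBool`: `ToySymBool` and this `sym_not` are hypothetical toys that build expression strings, assuming only the protocol the bullets describe (`&`/`|` as non-short-circuiting logical operators, and a `__sym_not__` magic method used in place of `not`).

```python
from typing import Union


class ToySymBool:
    """Hypothetical stand-in for a symbolic boolean: it only records an expression string."""

    def __init__(self, expr: str) -> None:
        self.expr = expr

    # `&` and `|` act as logical and/or. They cannot short-circuit: Python
    # evaluates both operands before calling these methods.
    def __and__(self, other: "BoolLike") -> "ToySymBool":
        return ToySymBool(f"({self.expr} & {_expr(other)})")

    def __rand__(self, other: "BoolLike") -> "ToySymBool":
        return ToySymBool(f"({_expr(other)} & {self.expr})")

    def __or__(self, other: "BoolLike") -> "ToySymBool":
        return ToySymBool(f"({self.expr} | {_expr(other)})")

    def __ror__(self, other: "BoolLike") -> "ToySymBool":
        return ToySymBool(f"({_expr(other)} | {self.expr})")

    # `not` always produces a plain bool, so negation gets its own magic method.
    def __sym_not__(self) -> "ToySymBool":
        return ToySymBool(f"Not({self.expr})")

    def __bool__(self) -> bool:
        # A real implementation might guard here; this toy refuses instead.
        raise TypeError("ToySymBool cannot be converted to bool implicitly")


BoolLike = Union[bool, ToySymBool]


def _expr(x: BoolLike) -> str:
    return x.expr if isinstance(x, ToySymBool) else repr(x)


def sym_not(x: BoolLike) -> BoolLike:
    """Simplified stand-in: use __sym_not__ when the type defines it, else plain `not`."""
    if hasattr(type(x), "__sym_not__"):
        return x.__sym_not__()
    return not x


# Plain-bool pitfalls: bool is an int subclass, and ~ is integer inversion.
assert isinstance(True, int) and type(True) is not int
assert ~True == -2
# & and | are safe on plain bools.
assert (True & False) is False and (True | False) is True
```

Because `sym_not` falls back to `not` for plain values, code that may receive either a `bool` or a symbolic boolean can call it uniformly, whereas `~` would silently turn `True` into `-2`.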
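To make the IsNonOverlappingAndDenseIndicator idea concrete, here is a plain-Python sketch. It assumes the common sort-by-stride formulation of "non-overlapping and dense" (dimensions of size < 2 are ignored, and the remaining dimensions, sorted by stride, must have strides equal to the running product of sizes); it is not copied from PyTorch's C++ `compute_non_overlapping_and_dense`. The names `is_non_overlapping_and_dense` and `non_overlapping_and_dense_indicator` are hypothetical, and returning `None` merely stands in for Sympy leaving an expression unevaluated. As in the bullet, the indicator takes sizes and strides flattened into one argument list.

```python
from typing import Any, Optional, Sequence


def is_non_overlapping_and_dense(sizes: Sequence[int], strides: Sequence[int]) -> bool:
    """Sketch: after sorting dims by stride, each stride must equal the product of
    the sizes of the dims before it (dims of size < 2 are ignored)."""
    dim = len(sizes)
    if dim == 1:
        return sizes[0] < 2 or strides[0] == 1
    # Sort dims by stride; dims of size < 2 go last because their stride never matters.
    perm = sorted(range(dim), key=lambda i: (sizes[i] < 2, strides[i]))
    require_stride = 1
    for i in perm:
        if sizes[i] < 2:
            return True  # only size-0/1 dims remain
        if strides[i] != require_stride:
            return False
        require_stride *= sizes[i]
    return True


def non_overlapping_and_dense_indicator(*args: Any) -> Optional[int]:
    """Hypothetical indicator taking sizes followed by strides as one flat argument list.

    Returns 1 or 0 when every argument is a concrete int, and None otherwise
    (standing in for an expression that Sympy would leave unevaluated).
    """
    if len(args) % 2 != 0:
        raise ValueError("expected sizes and strides of equal length")
    # `type(a) is int` (not isinstance) so that bools are not mistaken for ints.
    if not all(type(a) is int for a in args):
        return None
    n = len(args) // 2
    sizes, strides = args[:n], args[n:]
    return 1 if is_non_overlapping_and_dense(sizes, strides) else 0
```

For example, a contiguous 2x3 layout (`sizes=(2, 3)`, `strides=(3, 1)`) and its transpose both yield 1, while taking every other column (`strides=(6, 2)`) yields 0; any non-integer argument keeps the indicator "unevaluated".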

TODO: this PR appears to greatly regress performance of symbolic reasoning. In particular, `python test/functorch/test_aotdispatch.py -k max_pool2d` performs really poorly with these changes. Need to investigate.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92149
Approved by: https://github.com/albanD, https://github.com/Skylion007
2023-01-21 02:21:56 +00:00


from typing import NamedTuple, Callable, Any, Tuple, List, Dict, Type, cast, Optional, TypeVar, overload, Union
import functools
from collections import namedtuple, OrderedDict
from dataclasses import dataclass
T = TypeVar('T')
S = TypeVar('S')
U = TypeVar('U')
R = TypeVar('R')
"""
Contains utility functions for working with nested python data structures.

A *pytree* is a Python nested data structure. It is a tree in the sense that
nodes are Python collections (e.g., list, tuple, dict) and the leaves are
Python values. Furthermore, a pytree should not contain reference cycles.

pytrees are useful for working with nested collections of Tensors. For example,
one can use `tree_map` to map a function over all Tensors inside some nested
collection of Tensors and `tree_flatten` to get a flat list of all Tensors
inside some nested collection. pytrees are helpful for implementing nested
collection support for PyTorch APIs.

This pytree implementation is not very performant due to Python overhead.
To improve the performance we can move parts of the implementation to C++.
"""
# A NodeDef holds two callables:
# - flatten_fn should take the collection and return a flat list of values.
#   It can also return some context that is used in reconstructing the
#   collection.
# - unflatten_fn should take a flat list of values and some context
#   (returned by flatten_fn). It returns the collection by reconstructing
#   it from the list and the context.
Context = Any
PyTree = Any
FlattenFunc = Callable[[PyTree], Tuple[List, Context]]
UnflattenFunc = Callable[[List, Context], PyTree]

class NodeDef(NamedTuple):
    flatten_fn: FlattenFunc
    unflatten_fn: UnflattenFunc

SUPPORTED_NODES: Dict[Type[Any], NodeDef] = {}

def _register_pytree_node(typ: Any, flatten_fn: FlattenFunc, unflatten_fn: UnflattenFunc) -> None:
    SUPPORTED_NODES[typ] = NodeDef(flatten_fn, unflatten_fn)
def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]:
    return list(d.values()), list(d.keys())

def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]:
    return {key: value for key, value in zip(context, values)}

def _list_flatten(d: List[Any]) -> Tuple[List[Any], Context]:
    return d, None

def _list_unflatten(values: List[Any], context: Context) -> List[Any]:
    return list(values)

def _tuple_flatten(d: Tuple[Any, ...]) -> Tuple[List[Any], Context]:
    return list(d), None

def _tuple_unflatten(values: List[Any], context: Context) -> Tuple[Any, ...]:
    return tuple(values)

def _namedtuple_flatten(d: NamedTuple) -> Tuple[List[Any], Context]:
    return list(d), type(d)

def _namedtuple_unflatten(values: List[Any], context: Context) -> NamedTuple:
    return cast(NamedTuple, context(*values))

def _odict_flatten(d: 'OrderedDict[Any, Any]') -> Tuple[List[Any], Context]:
    return list(d.values()), list(d.keys())

def _odict_unflatten(values: List[Any], context: Context) -> 'OrderedDict[Any, Any]':
    return OrderedDict((key, value) for key, value in zip(context, values))

_register_pytree_node(dict, _dict_flatten, _dict_unflatten)
_register_pytree_node(list, _list_flatten, _list_unflatten)
_register_pytree_node(tuple, _tuple_flatten, _tuple_unflatten)
_register_pytree_node(namedtuple, _namedtuple_flatten, _namedtuple_unflatten)
_register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)
# h/t https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple
def _is_namedtuple_instance(pytree: Any) -> bool:
    typ = type(pytree)
    bases = typ.__bases__
    if len(bases) != 1 or bases[0] != tuple:
        return False
    fields = getattr(typ, '_fields', None)
    if not isinstance(fields, tuple):
        return False
    return all(type(entry) == str for entry in fields)

def _get_node_type(pytree: Any) -> Any:
    if _is_namedtuple_instance(pytree):
        return namedtuple
    return type(pytree)

# A leaf is defined as anything that is not a Node.
def _is_leaf(pytree: PyTree) -> bool:
    return _get_node_type(pytree) not in SUPPORTED_NODES.keys()
# A TreeSpec represents the structure of a pytree. It holds:
# "type": the type of the root Node of the pytree
# context: some context that is useful in unflattening the pytree
# children_specs: specs for each child of the root Node
# num_leaves: the number of leaves
@dataclass
class TreeSpec:
    type: Any
    context: Context
    children_specs: List['TreeSpec']

    def __post_init__(self) -> None:
        self.num_leaves: int = sum([spec.num_leaves for spec in self.children_specs])

    def __repr__(self, indent: int = 0) -> str:
        repr_prefix: str = f'TreeSpec({self.type.__name__}, {self.context}, ['
        children_specs_str: str = ''
        if len(self.children_specs):
            indent += len(repr_prefix)
            children_specs_str += self.children_specs[0].__repr__(indent)
            children_specs_str += ',' if len(self.children_specs) > 1 else ''
            children_specs_str += ','.join(['\n' + ' ' * indent + child.__repr__(indent) for child in self.children_specs[1:]])
        repr_suffix: str = f'{children_specs_str}])'
        return repr_prefix + repr_suffix

class LeafSpec(TreeSpec):
    def __init__(self) -> None:
        super().__init__(None, None, [])
        self.num_leaves = 1

    def __repr__(self, indent: int = 0) -> str:
        return '*'
def tree_flatten(pytree: PyTree) -> Tuple[List[Any], TreeSpec]:
    """Flattens a pytree into a list of values and a TreeSpec that can be used
    to reconstruct the pytree.
    """
    if _is_leaf(pytree):
        return [pytree], LeafSpec()

    node_type = _get_node_type(pytree)
    flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
    child_pytrees, context = flatten_fn(pytree)

    # Recursively flatten the children
    result: List[Any] = []
    children_specs: List['TreeSpec'] = []
    for child in child_pytrees:
        flat, child_spec = tree_flatten(child)
        result += flat
        children_specs.append(child_spec)

    return result, TreeSpec(node_type, context, children_specs)
def tree_unflatten(values: List[Any], spec: TreeSpec) -> PyTree:
    """Given a list of values and a TreeSpec, builds a pytree.
    This is the inverse operation of `tree_flatten`.
    """
    if not isinstance(spec, TreeSpec):
        raise ValueError(
            f'tree_unflatten(values, spec): Expected `spec` to be instance of '
            f'TreeSpec but got item of type {type(spec)}.')
    if len(values) != spec.num_leaves:
        raise ValueError(
            f'tree_unflatten(values, spec): `values` has length {len(values)} '
            f'but the spec refers to a pytree that holds {spec.num_leaves} '
            f'items ({spec}).')
    if isinstance(spec, LeafSpec):
        return values[0]

    unflatten_fn = SUPPORTED_NODES[spec.type].unflatten_fn

    # Recursively unflatten the children
    start = 0
    end = 0
    child_pytrees = []
    for child_spec in spec.children_specs:
        end += child_spec.num_leaves
        child_pytrees.append(tree_unflatten(values[start:end], child_spec))
        start = end

    return unflatten_fn(child_pytrees, spec.context)

def tree_map(fn: Any, pytree: PyTree) -> PyTree:
    flat_args, spec = tree_flatten(pytree)
    return tree_unflatten([fn(i) for i in flat_args], spec)
Type2 = Tuple[Type[T], Type[S]]
Type3 = Tuple[Type[T], Type[S], Type[U]]
TypeAny = Union[Type[Any], Tuple[Type[Any], ...]]

Fn3 = Callable[[Union[T, S, U]], R]
Fn2 = Callable[[Union[T, S]], R]
Fn = Callable[[T], R]
FnAny = Callable[[Any], R]

MapOnlyFn = Callable[[T], Callable[[Any], Any]]

# These specializations help with type inference on the lambda passed to this
# function
@overload
def map_only(ty: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]:
    ...

@overload
def map_only(ty: Type[T]) -> MapOnlyFn[Fn[T, Any]]:
    ...

# This specialization is needed for the implementations below that call
# map_only with an arbitrary TypeAny (e.g., tree_map_only)
@overload
def map_only(ty: TypeAny) -> MapOnlyFn[FnAny[Any]]:
    ...

def map_only(ty: TypeAny) -> MapOnlyFn[FnAny[Any]]:
    """
    Suppose you are writing a tree_map over tensors, leaving everything
    else unchanged. Ordinarily you would have to write:

        def go(t):
            if isinstance(t, Tensor):
                return ...
            else:
                return t

    With this function, you only need to write:

        @map_only(Tensor)
        def go(t):
            return ...

    You can also use `tree_map_only` directly.
    """
    def deco(f: Callable[[T], Any]) -> Callable[[Any], Any]:
        @functools.wraps(f)
        def inner(x: T) -> Any:
            if isinstance(x, ty):
                return f(x)
            else:
                return x
        return inner
    return deco

@overload
def tree_map_only(ty: Type[T], fn: Fn[T, Any], pytree: PyTree) -> PyTree:
    ...

@overload
def tree_map_only(ty: Type2[T, S], fn: Fn2[T, S, Any], pytree: PyTree) -> PyTree:
    ...

@overload
def tree_map_only(ty: Type3[T, S, U], fn: Fn3[T, S, U, Any], pytree: PyTree) -> PyTree:
    ...

def tree_map_only(ty: TypeAny, fn: FnAny[Any], pytree: PyTree) -> PyTree:
    return tree_map(map_only(ty)(fn), pytree)
def tree_all(pred: Callable[[Any], bool], pytree: PyTree) -> bool:
    flat_args, _ = tree_flatten(pytree)
    return all(map(pred, flat_args))

def tree_any(pred: Callable[[Any], bool], pytree: PyTree) -> bool:
    flat_args, _ = tree_flatten(pytree)
    return any(map(pred, flat_args))

@overload
def tree_all_only(ty: Type[T], pred: Fn[T, bool], pytree: PyTree) -> bool:
    ...

@overload
def tree_all_only(ty: Type2[T, S], pred: Fn2[T, S, bool], pytree: PyTree) -> bool:
    ...

@overload
def tree_all_only(ty: Type3[T, S, U], pred: Fn3[T, S, U, bool], pytree: PyTree) -> bool:
    ...

def tree_all_only(ty: TypeAny, pred: FnAny[bool], pytree: PyTree) -> bool:
    flat_args, _ = tree_flatten(pytree)
    return all(pred(x) for x in flat_args if isinstance(x, ty))

@overload
def tree_any_only(ty: Type[T], pred: Fn[T, bool], pytree: PyTree) -> bool:
    ...

@overload
def tree_any_only(ty: Type2[T, S], pred: Fn2[T, S, bool], pytree: PyTree) -> bool:
    ...

def tree_any_only(ty: TypeAny, pred: FnAny[bool], pytree: PyTree) -> bool:
    flat_args, _ = tree_flatten(pytree)
    return any(pred(x) for x in flat_args if isinstance(x, ty))
# Broadcasts a pytree to the provided TreeSpec and returns the flattened
# values. If this is not possible, then this function returns None.
#
# For example, given pytree=0 and spec=TreeSpec(list, None, [LeafSpec(), LeafSpec()]),
# this function would return [0, 0]. This is useful for part of the vmap implementation:
# a user can pass in vmap(fn, in_dims)(*inputs). `in_dims` should be
# broadcastable to the tree structure of `inputs` and we use
# _broadcast_to_and_flatten to check this.
def _broadcast_to_and_flatten(pytree: PyTree, spec: TreeSpec) -> Optional[List[Any]]:
    assert isinstance(spec, TreeSpec)

    if _is_leaf(pytree):
        return [pytree] * spec.num_leaves
    if isinstance(spec, LeafSpec):
        return None
    node_type = _get_node_type(pytree)
    if node_type != spec.type:
        return None

    flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
    child_pytrees, ctx = flatten_fn(pytree)

    # Check if the Node is different from the spec
    if len(child_pytrees) != len(spec.children_specs) or ctx != spec.context:
        return None

    # Recursively flatten the children
    result: List[Any] = []
    for child, child_spec in zip(child_pytrees, spec.children_specs):
        flat = _broadcast_to_and_flatten(child, child_spec)
        if flat is not None:
            result += flat
        else:
            return None

    return result
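The `NodeDef` contract (a `flatten_fn` returning children plus context, and an `unflatten_fn` rebuilding from both) and the recursive `tree_flatten`/`tree_unflatten` pair can be condensed into a standalone sketch. Everything here is hypothetical and written for illustration only: `NODES`, `register`, `flatten`, `unflatten`, and the `Point` class are not part of `torch.utils._pytree`, and specs are plain tuples rather than `TreeSpec` objects. Registering `Point` mirrors what the private `_register_pytree_node` does for the built-in containers.

```python
from typing import Any, Callable, Dict, List, Tuple

FlattenFn = Callable[[Any], Tuple[List[Any], Any]]
UnflattenFn = Callable[[List[Any], Any], Any]

# Hypothetical registry mapping a node type to its (flatten_fn, unflatten_fn) pair.
NODES: Dict[type, Tuple[FlattenFn, UnflattenFn]] = {}


class _Leaf:
    def __repr__(self) -> str:
        return "*"


LEAF = _Leaf()  # spec for a leaf, printed like LeafSpec


def register(typ: type, flatten_fn: FlattenFn, unflatten_fn: UnflattenFn) -> None:
    NODES[typ] = (flatten_fn, unflatten_fn)


register(list, lambda xs: (list(xs), None), lambda vals, ctx: list(vals))
register(tuple, lambda xs: (list(xs), None), lambda vals, ctx: tuple(vals))
register(dict, lambda d: (list(d.values()), list(d.keys())),
         lambda vals, keys: dict(zip(keys, vals)))


def flatten(tree: Any) -> Tuple[List[Any], Any]:
    """Return (leaves, spec); a node spec is (type, context, child_specs, num_leaves)."""
    if type(tree) not in NODES:
        return [tree], LEAF
    flatten_fn, _ = NODES[type(tree)]
    children, ctx = flatten_fn(tree)
    leaves: List[Any] = []
    child_specs: List[Any] = []
    for child in children:
        child_leaves, child_spec = flatten(child)
        leaves += child_leaves
        child_specs.append(child_spec)
    return leaves, (type(tree), ctx, child_specs, len(leaves))


def num_leaves(spec: Any) -> int:
    return 1 if spec is LEAF else spec[3]


def unflatten(leaves: List[Any], spec: Any) -> Any:
    if spec is LEAF:
        return leaves[0]
    typ, ctx, child_specs, _ = spec
    children: List[Any] = []
    start = 0
    for child_spec in child_specs:
        end = start + num_leaves(child_spec)  # each child consumes its own leaf count
        children.append(unflatten(leaves[start:end], child_spec))
        start = end
    return NODES[typ][1](children, ctx)


class Point:
    """Hypothetical user-defined type, registered below as a pytree node."""

    def __init__(self, x: Any, y: Any) -> None:
        self.x, self.y = x, y

    def __eq__(self, other: object) -> bool:
        return isinstance(other, Point) and (self.x, self.y) == (other.x, other.y)


register(Point, lambda p: ([p.x, p.y], None), lambda vals, ctx: Point(*vals))

tree = {"a": [1, 2], "b": (Point(3, 4), 5)}
leaves, spec = flatten(tree)
```

Unlike the module, this sketch has no namedtuple special case, so a namedtuple would be treated as a leaf here; the per-child `num_leaves` slicing is the same bookkeeping `tree_unflatten` does with `start` and `end`.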
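`_is_namedtuple_instance` cannot rely on `isinstance(x, tuple)` alone, because every namedtuple is also a tuple; it instead requires exactly one base class equal to `tuple` and a `_fields` tuple of strings. The standalone check below restates that heuristic under a hypothetical name to show which objects it accepts, and why `_namedtuple_flatten` can store `type(d)` as its context: calling that type on the flattened values rebuilds an equal namedtuple.

```python
from collections import namedtuple
from typing import Any, NamedTuple


def looks_like_namedtuple(obj: Any) -> bool:
    """Same heuristic as _is_namedtuple_instance, restated standalone."""
    typ = type(obj)
    bases = typ.__bases__
    if len(bases) != 1 or bases[0] != tuple:
        return False
    fields = getattr(typ, "_fields", None)
    if not isinstance(fields, tuple):
        return False
    return all(type(entry) == str for entry in fields)


Pair = namedtuple("Pair", ["first", "second"])


class Typed(NamedTuple):
    x: int
    y: int


class SubPair(Pair):
    """A subclass of a namedtuple: its only base is Pair, not tuple."""


p = Pair(1, 2)
# Round trip in the style of _namedtuple_flatten / _namedtuple_unflatten:
values, context = list(p), type(p)
rebuilt = context(*values)
```

Under this heuristic a namedtuple subclass such as `SubPair` is not recognized; since `_get_node_type` then returns the subclass itself, which is not in `SUPPORTED_NODES`, the module would treat it as a leaf.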
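The decorator pattern in `map_only`'s docstring can be exercised without tensors. The sketch below restates `map_only` standalone (same logic as the module's version, as a hypothetical copy) and applies it over a flat list in place of a full `tree_map`. It also shows a pitfall that ties back to the bool-is-a-subclass-of-int point in the commit message: `map_only(int)` dispatches on `isinstance`, so it transforms `True` and `False` as well.

```python
import functools
from typing import Any, Callable, Tuple, Type, Union

TypeAny = Union[Type[Any], Tuple[Type[Any], ...]]


def map_only(ty: TypeAny) -> Callable[[Callable[[Any], Any]], Callable[[Any], Any]]:
    """Standalone copy of the map_only decorator: apply f only to instances of ty."""
    def deco(f: Callable[[Any], Any]) -> Callable[[Any], Any]:
        @functools.wraps(f)
        def inner(x: Any) -> Any:
            return f(x) if isinstance(x, ty) else x
        return inner
    return deco


@map_only(int)
def double(x: int) -> int:
    return x * 2


@map_only((int, float))
def negate(x: Union[int, float]) -> Union[int, float]:
    return -x


values = [1, "a", 2.5, True, None]
doubled = [double(v) for v in values]   # True is an int, so it becomes 2
negated = [negate(v) for v in values]   # a tuple of types matches either one
```

If bools must be left alone, the predicate has to use an exact check such as `type(x) is int`, which is the same substitution the commit message describes for the SymNode code.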
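`_broadcast_to_and_flatten` treats its `pytree` argument as a prefix of the spec's structure: a leaf broadcasts over an entire subtree, while any interior node must match the spec's type, context, and arity exactly. The standalone sketch below restates that rule with plain tuples as specs (the names `spec_of`, `count_leaves`, and `broadcast_to_and_flatten` are hypothetical, and only list, tuple, and dict are treated as nodes), then replays the vmap `in_dims` scenario described in the comment above the function.

```python
from typing import Any, List, Optional, Tuple


class _Leaf:
    def __repr__(self) -> str:
        return "*"


LEAF = _Leaf()  # hypothetical spec marker for a leaf


def _node(tree: Any) -> Optional[Tuple[type, Any, List[Any]]]:
    """Return (node_type, context, children) for list/tuple/dict, or None for a leaf."""
    if type(tree) in (list, tuple):
        return type(tree), None, list(tree)
    if type(tree) is dict:
        return dict, list(tree.keys()), list(tree.values())
    return None


def spec_of(tree: Any) -> Any:
    node = _node(tree)
    if node is None:
        return LEAF
    typ, ctx, children = node
    return (typ, ctx, [spec_of(child) for child in children])


def count_leaves(spec: Any) -> int:
    return 1 if spec is LEAF else sum(count_leaves(child) for child in spec[2])


def broadcast_to_and_flatten(prefix: Any, spec: Any) -> Optional[List[Any]]:
    """A leaf in `prefix` is repeated once per leaf of the matching subtree of `spec`;
    any structural mismatch (type, context, or arity) returns None."""
    node = _node(prefix)
    if node is None:
        return [prefix] * count_leaves(spec)
    if spec is LEAF:
        return None  # prefix has structure where spec has a single leaf
    typ, ctx, children = node
    spec_typ, spec_ctx, child_specs = spec
    if typ is not spec_typ or ctx != spec_ctx or len(children) != len(child_specs):
        return None
    result: List[Any] = []
    for child, child_spec in zip(children, child_specs):
        flat = broadcast_to_and_flatten(child, child_spec)
        if flat is None:
            return None
        result += flat
    return result


# vmap-style usage: in_dims is broadcast against the structure of the inputs.
inputs = ([1.0, 2.0], {"w": 3.0})
inputs_spec = spec_of(inputs)
```

With these inputs, `in_dims=0` expands to one entry per leaf, `(0, None)` maps the whole list to 0 and the dict to `None`, and a prefix with the wrong arity or container type is rejected with `None`.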