mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
We have known for a while that we should, in principle, support SymBool as a separate concept from SymInt and SymFloat (in particular, every distinct numeric type should get its own API). However, recent work with unbacked SymInts in, e.g., https://github.com/pytorch/pytorch/pull/90985 has made this a priority to implement. The essential problem is that our logic for computing the contiguity of tensors branches on the passed-in input sizes, and this causes us to require guards when constructing tensors from unbacked SymInts. Morally, this should not be a big deal, because we only really care about the regular (non-channels-last) contiguity of the tensor, which should be guaranteed since most people aren't calling `empty_strided` on the tensor. However, because we store a bool on TensorImpl (not a SymBool; prior to this PR it did not exist), we are forced to *immediately* compute these values, even if the value ends up not being used at all. In particular, even when a user allocates a contiguous tensor, we still must compute channels-last contiguity (as some contiguous tensors are also channels-last contiguous, but others are not).

This PR implements SymBool, and makes TensorImpl use SymBool to store the contiguity information in ExtraMeta. There are a number of knock-on effects, which I now discuss below.

* I introduce a new C++ type SymBool, analogous to SymInt and SymFloat. This type supports logical and, logical or, and logical negation. I support the bitwise operations on this class (but not the conventional logical operators) to make it clear that logical operations on SymBool are NOT short-circuiting. I also, for now, do NOT support implicit conversion of SymBool to bool (which would create a guard). This does not matter too much in practice, as in this PR I did not modify the equality operations (e.g., `==` on SymInt) to return SymBool, so all preexisting implicit guards did not need to be changed. I also introduced symbolic comparison functions `sym_eq`, etc. on SymInt to make it possible to create SymBool. The current implementation of comparison functions unfortunately makes it easy to accidentally introduce guards when you do not mean to (as both `s0 == s1` and `s0.sym_eq(s1)` are valid spellings of the equality operation); in the short term, I intend to prevent excess guarding in this situation by unit testing; in the long term, making the equality operators return SymBool is probably the correct fix.
* ~~I modify TensorImpl to store SymBool for the `is_contiguous` fields and friends on `ExtraMeta`. In practice, this essentially meant reverting most of the changes from https://github.com/pytorch/pytorch/pull/85936. In particular, the fields on ExtraMeta are no longer strongly typed; at the time I was particularly concerned about the giant lambda I was using as the setter getting a desynchronized argument order, but now that I have individual setters for each field, the only "big list" of boolean arguments is in the constructor of ExtraMeta, which seems like an acceptable risk. The semantics of TensorImpl are now that we guard only when you actually attempt to access the contiguity of the tensor via, e.g., `is_contiguous`. By and large, the contiguity calculation in the implementations now needs to be duplicated (as the boolean version can short-circuit, but the SymBool version cannot); you should carefully review the duplicate new implementations. I typically use the `identity` template to disambiguate which version of the function I need, and rely on overloading to allow for implementation sharing. The changes to the `compute_` functions are particularly interesting; for most of the functions, I preserved their original non-symbolic implementation, and then introduced a new symbolic implementation that is branch-less (making use of our new SymBool operations). However, `compute_non_overlapping_and_dense` is special; see the next bullet.~~ This appears to cause performance problems, so I am leaving this to a follow-up PR.
* (Update: the Python-side pieces for this are still in this PR, but they are not wired up until later PRs.) While the contiguity calculations are relatively easy to write in a branch-free way, `compute_non_overlapping_and_dense` is not: it involves a sort on the strides. While in principle we can still make it go through by using a data-oblivious sorting network, this seems like too much complication for a field that is likely never used (because typically, it will be obvious that a tensor is non-overlapping and dense, because the tensor is contiguous). So we take a different approach: instead of trying to trace through the computation of non-overlapping and dense, we introduce a new opaque operator IsNonOverlappingAndDenseIndicator, which represents all of the compute that would have been done here. This function returns the integer 0 if `is_non_overlapping_and_dense` would have returned `False`, and the integer 1 otherwise; it returns an integer rather than a boolean for technical reasons (Sympy does not easily allow defining custom functions that return booleans). The function itself only knows how to evaluate itself if all of its arguments are integers; otherwise it is left unevaluated. This means we can always guard on it (as `size_hint` will always be able to evaluate through it), but otherwise its insides are left a black box. We typically do NOT expect this custom function to show up in actual boolean expressions, because we will typically shortcut it due to the tensor being contiguous. It's possible we should apply this treatment to all of the other `compute_` operations; more investigation is necessary. As a technical note, because this operator takes a pair of lists of SymInts, we need to support converting `ArrayRef<SymNode>` to Python, and I also unpack the pair of lists into a single list because I don't know if Sympy operations can actually validly take lists of Sympy expressions as inputs. See, for example, `_make_node_sizes_strides`.
* On the Python side, we also introduce a SymBool class, and update SymNode to track bool as a valid pytype. There is some subtlety here: bool is a subclass of int, so one has to be careful about `isinstance` checks (in fact, in most cases I replaced `isinstance(x, int)` with `type(x) is int` for precisely this reason). Additionally, unlike C++, I do NOT define bitwise inverse on SymBool, because it does not do the correct thing when run on booleans, e.g., `~True` is `-2`. (For that matter, it doesn't do the right thing in C++ either, but at least in principle the compiler can warn you about it with `-Wbool-operation`, and so the rule is simple in C++: only use logical operations if the types are statically known to be SymBool.) Alas, logical negation is not overridable, so we have to introduce `sym_not`, which must be used in place of `not` whenever a SymBool can turn up. To avoid confusion with `__not__`, which may imply that `operator.__not__` might be acceptable to use (it isn't), our magic method is called `__sym_not__`. The other bitwise operators `&` and `|` do the right thing with booleans and are acceptable to use.
* There is some annoyance working with booleans in Sympy. Unlike int and float, booleans live in their own algebra and they support fewer operations than regular numbers. In particular, `sympy.expand` does not work on them. To get around this, I introduce `safe_expand`, which only calls expand on operations that are known to be expandable.

TODO: this PR appears to greatly regress the performance of symbolic reasoning. In particular, `python test/functorch/test_aotdispatch.py -k max_pool2d` performs really poorly with these changes. Need to investigate.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92149
Approved by: https://github.com/albanD, https://github.com/Skylion007
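The description makes two Python-level points. First, `&` and `|` never short-circuit, so a contiguity check written only with them folds every condition into its result instead of branching on it. Second, plain `bool` is a subclass of `int`, so `isinstance` and `~` behave surprisingly. The minimal sketch below illustrates both with ordinary ints and bools, not SymInt/SymBool. The names `contiguous_short_circuit`, `contiguous_branch_free` and `sym_not_sketch` are hypothetical. The contiguity rule is assumed to be the usual row-major one: size-1 dimensions may have any stride, and a tensor with zero elements counts as contiguous.

```python
from typing import Any, Sequence


def contiguous_short_circuit(sizes: Sequence[int], strides: Sequence[int]) -> bool:
    # Plain-bool version: free to branch and return early.
    if any(size == 0 for size in sizes):
        return True  # a tensor with zero elements counts as contiguous
    expected = 1
    for size, stride in zip(reversed(sizes), reversed(strides)):
        if size != 1:
            if stride != expected:
                return False
            expected *= size
    return True


def contiguous_branch_free(sizes: Sequence[int], strides: Sequence[int]) -> bool:
    # Branch-free version: only ==, *, & and |, so every condition is folded
    # into the result instead of being decided (guarded on) along the way.
    is_empty = False
    for size in sizes:
        is_empty = is_empty | (size == 0)
    ok = True
    expected = 1
    for size, stride in zip(reversed(sizes), reversed(strides)):
        # A size-1 dimension may have any stride; otherwise the stride must match.
        ok = ok & ((size == 1) | (stride == expected))
        # Multiplying by a size of 1 leaves `expected` unchanged, so no branch is needed.
        expected = expected * size
    return is_empty | ok


def sym_not_sketch(x: Any) -> Any:
    # Hypothetical stand-in for a `sym_not`-style helper: `not` cannot be
    # overridden, so defer to a `__sym_not__` method when the type defines one.
    method = getattr(type(x), "__sym_not__", None)
    return method(x) if method is not None else not x


# bool is a subclass of int: isinstance cannot tell them apart, type() can.
assert isinstance(True, int) and type(True) is not int
# Bitwise inverse does the wrong thing on bools (Python 3.12+ also warns here).
assert ~True == -2
# & and | on two bools return bools, so they work as non-short-circuiting and/or.
assert (True & False) is False and (True | False) is True

# Row-major, transposed, empty, and size-1-with-odd-stride layouts agree.
cases = [((2, 3), (3, 1)), ((3, 2), (1, 3)), ((0, 4), (4, 1)), ((2, 1, 3), (3, 99, 1))]
for sizes, strides in cases:
    assert contiguous_short_circuit(sizes, strides) == contiguous_branch_free(sizes, strides)
```

The branch-free form suits symbolic values because an `if` on them would force a guard. The cost, as the description notes, is maintaining two versions of each check, since only the plain-bool one may return early.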
340 lines
11 KiB
Python
from typing import NamedTuple, Callable, Any, Tuple, List, Dict, Type, cast, Optional, TypeVar, overload, Union
import functools
from collections import namedtuple, OrderedDict
from dataclasses import dataclass


T = TypeVar('T')
S = TypeVar('S')
U = TypeVar('U')
R = TypeVar('R')

"""
|
|
Contains utility functions for working with nested python data structures.
|
|
|
|
A *pytree* is Python nested data structure. It is a tree in the sense that
|
|
nodes are Python collections (e.g., list, tuple, dict) and the leaves are
|
|
Python values. Furthermore, a pytree should not contain reference cycles.
|
|
|
|
pytrees are useful for working with nested collections of Tensors. For example,
|
|
one can use `tree_map` to map a function over all Tensors inside some nested
|
|
collection of Tensors and `tree_unflatten` to get a flat list of all Tensors
|
|
inside some nested collection. pytrees are helpful for implementing nested
|
|
collection support for PyTorch APIs.
|
|
|
|
This pytree implementation is not very performant due to Python overhead
|
|
To improve the performance we can move parts of the implementation to C++.
|
|
"""
|
|
|
|
# A NodeDef holds two callables:
# - flatten_fn should take the collection and return a flat list of values.
#   It can also return some context that is used in reconstructing the
#   collection.
# - unflatten_fn should take a flat list of values and some context
#   (returned by flatten_fn). It returns the collection by reconstructing
#   it from the list and the context.
Context = Any
PyTree = Any
FlattenFunc = Callable[[PyTree], Tuple[List, Context]]
UnflattenFunc = Callable[[List, Context], PyTree]

class NodeDef(NamedTuple):
    flatten_fn: FlattenFunc
    unflatten_fn: UnflattenFunc

SUPPORTED_NODES: Dict[Type[Any], NodeDef] = {}

def _register_pytree_node(typ: Any, flatten_fn: FlattenFunc, unflatten_fn: UnflattenFunc) -> None:
    SUPPORTED_NODES[typ] = NodeDef(flatten_fn, unflatten_fn)

def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]:
    return list(d.values()), list(d.keys())

def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]:
    return {key: value for key, value in zip(context, values)}

def _list_flatten(d: List[Any]) -> Tuple[List[Any], Context]:
    return d, None

def _list_unflatten(values: List[Any], context: Context) -> List[Any]:
    return list(values)

def _tuple_flatten(d: Tuple[Any, ...]) -> Tuple[List[Any], Context]:
    return list(d), None

def _tuple_unflatten(values: List[Any], context: Context) -> Tuple[Any, ...]:
    return tuple(values)

def _namedtuple_flatten(d: NamedTuple) -> Tuple[List[Any], Context]:
    return list(d), type(d)

def _namedtuple_unflatten(values: List[Any], context: Context) -> NamedTuple:
    return cast(NamedTuple, context(*values))

def _odict_flatten(d: 'OrderedDict[Any, Any]') -> Tuple[List[Any], Context]:
    return list(d.values()), list(d.keys())

def _odict_unflatten(values: List[Any], context: Context) -> 'OrderedDict[Any, Any]':
    return OrderedDict((key, value) for key, value in zip(context, values))


_register_pytree_node(dict, _dict_flatten, _dict_unflatten)
_register_pytree_node(list, _list_flatten, _list_unflatten)
_register_pytree_node(tuple, _tuple_flatten, _tuple_unflatten)
_register_pytree_node(namedtuple, _namedtuple_flatten, _namedtuple_unflatten)
_register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)


# h/t https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple
def _is_namedtuple_instance(pytree: Any) -> bool:
    typ = type(pytree)
    bases = typ.__bases__
    if len(bases) != 1 or bases[0] != tuple:
        return False
    fields = getattr(typ, '_fields', None)
    if not isinstance(fields, tuple):
        return False
    return all(type(entry) == str for entry in fields)

def _get_node_type(pytree: Any) -> Any:
    if _is_namedtuple_instance(pytree):
        return namedtuple
    return type(pytree)

# A leaf is defined as anything that is not a Node.
def _is_leaf(pytree: PyTree) -> bool:
    return _get_node_type(pytree) not in SUPPORTED_NODES.keys()


# A TreeSpec represents the structure of a pytree. It holds:
# "type": the type of the root Node of the pytree
# context: some context that is useful in unflattening the pytree
# children_specs: specs for each child of the root Node
# num_leaves: the number of leaves
@dataclass
class TreeSpec:
    type: Any
    context: Context
    children_specs: List['TreeSpec']

    def __post_init__(self) -> None:
        self.num_leaves: int = sum([spec.num_leaves for spec in self.children_specs])

    def __repr__(self, indent: int = 0) -> str:
        repr_prefix: str = f'TreeSpec({self.type.__name__}, {self.context}, ['
        children_specs_str: str = ''
        if len(self.children_specs):
            indent += len(repr_prefix)
            children_specs_str += self.children_specs[0].__repr__(indent)
            children_specs_str += ',' if len(self.children_specs) > 1 else ''
            children_specs_str += ','.join(['\n' + ' ' * indent + child.__repr__(indent) for child in self.children_specs[1:]])
        repr_suffix: str = f'{children_specs_str}])'
        return repr_prefix + repr_suffix


class LeafSpec(TreeSpec):
    def __init__(self) -> None:
        super().__init__(None, None, [])
        self.num_leaves = 1

    def __repr__(self, indent: int = 0) -> str:
        return '*'

def tree_flatten(pytree: PyTree) -> Tuple[List[Any], TreeSpec]:
    """Flattens a pytree into a list of values and a TreeSpec that can be used
    to reconstruct the pytree.
    """
    if _is_leaf(pytree):
        return [pytree], LeafSpec()

    node_type = _get_node_type(pytree)
    flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
    child_pytrees, context = flatten_fn(pytree)

    # Recursively flatten the children
    result: List[Any] = []
    children_specs: List['TreeSpec'] = []
    for child in child_pytrees:
        flat, child_spec = tree_flatten(child)
        result += flat
        children_specs.append(child_spec)

    return result, TreeSpec(node_type, context, children_specs)


def tree_unflatten(values: List[Any], spec: TreeSpec) -> PyTree:
    """Given a list of values and a TreeSpec, builds a pytree.
    This is the inverse operation of `tree_flatten`.
    """
    if not isinstance(spec, TreeSpec):
        raise ValueError(
            f'tree_unflatten(values, spec): Expected `spec` to be instance of '
            f'TreeSpec but got item of type {type(spec)}.')
    if len(values) != spec.num_leaves:
        raise ValueError(
            f'tree_unflatten(values, spec): `values` has length {len(values)} '
            f'but the spec refers to a pytree that holds {spec.num_leaves} '
            f'items ({spec}).')
    if isinstance(spec, LeafSpec):
        return values[0]

    unflatten_fn = SUPPORTED_NODES[spec.type].unflatten_fn

    # Recursively unflatten the children
    start = 0
    end = 0
    child_pytrees = []
    for child_spec in spec.children_specs:
        end += child_spec.num_leaves
        child_pytrees.append(tree_unflatten(values[start:end], child_spec))
        start = end

    return unflatten_fn(child_pytrees, spec.context)

def tree_map(fn: Any, pytree: PyTree) -> PyTree:
    flat_args, spec = tree_flatten(pytree)
    return tree_unflatten([fn(i) for i in flat_args], spec)

Type2 = Tuple[Type[T], Type[S]]
Type3 = Tuple[Type[T], Type[S], Type[U]]
TypeAny = Union[Type[Any], Tuple[Type[Any], ...]]

Fn3 = Callable[[Union[T, S, U]], R]
Fn2 = Callable[[Union[T, S]], R]
Fn = Callable[[T], R]
FnAny = Callable[[Any], R]

MapOnlyFn = Callable[[T], Callable[[Any], Any]]

# These specializations help with type inference on the lambda passed to this
# function.
@overload
def map_only(ty: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]:
    ...

@overload
def map_only(ty: Type[T]) -> MapOnlyFn[Fn[T, Any]]:
    ...

# This specialization is needed for the implementations below that call
# `map_only` with a `TypeAny` argument.
@overload
def map_only(ty: TypeAny) -> MapOnlyFn[FnAny[Any]]:
    ...

def map_only(ty: TypeAny) -> MapOnlyFn[FnAny[Any]]:
    """
    Suppose you are writing a tree_map over tensors, leaving everything
    else unchanged. Ordinarily you would have to write:

        def go(t):
            if isinstance(t, Tensor):
                return ...
            else:
                return t

    With this function, you only need to write:

        @map_only(Tensor)
        def go(t):
            return ...

    You can also directly use 'tree_map_only'.
    """
    def deco(f: Callable[[T], Any]) -> Callable[[Any], Any]:
        @functools.wraps(f)
        def inner(x: T) -> Any:
            if isinstance(x, ty):
                return f(x)
            else:
                return x
        return inner
    return deco

@overload
def tree_map_only(ty: Type[T], fn: Fn[T, Any], pytree: PyTree) -> PyTree:
    ...

@overload
def tree_map_only(ty: Type2[T, S], fn: Fn2[T, S, Any], pytree: PyTree) -> PyTree:
    ...

@overload
def tree_map_only(ty: Type3[T, S, U], fn: Fn3[T, S, U, Any], pytree: PyTree) -> PyTree:
    ...

def tree_map_only(ty: TypeAny, fn: FnAny[Any], pytree: PyTree) -> PyTree:
    return tree_map(map_only(ty)(fn), pytree)

def tree_all(pred: Callable[[Any], bool], pytree: PyTree) -> bool:
    flat_args, _ = tree_flatten(pytree)
    return all(map(pred, flat_args))

def tree_any(pred: Callable[[Any], bool], pytree: PyTree) -> bool:
    flat_args, _ = tree_flatten(pytree)
    return any(map(pred, flat_args))

@overload
def tree_all_only(ty: Type[T], pred: Fn[T, bool], pytree: PyTree) -> bool:
    ...

@overload
def tree_all_only(ty: Type2[T, S], pred: Fn2[T, S, bool], pytree: PyTree) -> bool:
    ...

@overload
def tree_all_only(ty: Type3[T, S, U], pred: Fn3[T, S, U, bool], pytree: PyTree) -> bool:
    ...

def tree_all_only(ty: TypeAny, pred: FnAny[bool], pytree: PyTree) -> bool:
    flat_args, _ = tree_flatten(pytree)
    return all(pred(x) for x in flat_args if isinstance(x, ty))

@overload
def tree_any_only(ty: Type[T], pred: Fn[T, bool], pytree: PyTree) -> bool:
    ...

@overload
def tree_any_only(ty: Type2[T, S], pred: Fn2[T, S, bool], pytree: PyTree) -> bool:
    ...

def tree_any_only(ty: TypeAny, pred: FnAny[bool], pytree: PyTree) -> bool:
    flat_args, _ = tree_flatten(pytree)
    return any(pred(x) for x in flat_args if isinstance(x, ty))

# Broadcasts a pytree to the provided TreeSpec and returns the flattened
# values. If this is not possible, then this function returns None.
#
# For example, given pytree=0 and spec=TreeSpec(list, None, [LeafSpec(), LeafSpec()]),
# this function would return [0, 0]. This is useful for part of the vmap implementation:
# a user can pass in vmap(fn, in_dims)(*inputs). `in_dims` should be
# broadcastable to the tree structure of `inputs` and we use
# _broadcast_to_and_flatten to check this.
def _broadcast_to_and_flatten(pytree: PyTree, spec: TreeSpec) -> Optional[List[Any]]:
    assert isinstance(spec, TreeSpec)

    if _is_leaf(pytree):
        return [pytree] * spec.num_leaves
    if isinstance(spec, LeafSpec):
        return None
    node_type = _get_node_type(pytree)
    if node_type != spec.type:
        return None

    flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
    child_pytrees, ctx = flatten_fn(pytree)

    # Check if the Node is different from the spec
    if len(child_pytrees) != len(spec.children_specs) or ctx != spec.context:
        return None

    # Recursively flatten the children
    result: List[Any] = []
    for child, child_spec in zip(child_pytrees, spec.children_specs):
        flat = _broadcast_to_and_flatten(child, child_spec)
        if flat is not None:
            result += flat
        else:
            return None

    return result
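
# Illustrative self-check: a sketch that only runs when this file is executed
# directly, so importing the module is unaffected. It exercises the flatten /
# unflatten round trip, namedtuple handling, tree_map_only, a custom node
# registered through _register_pytree_node, and _broadcast_to_and_flatten.
# The `Pair` class is hypothetical and exists only for this check.
if __name__ == "__main__":
    tree = {"a": [1, 2.0], "b": (3, "x")}
    leaves, spec = tree_flatten(tree)
    assert leaves == [1, 2.0, 3, "x"]  # dict values are visited in insertion order
    assert spec.num_leaves == 4
    assert tree_unflatten(leaves, spec) == tree

    # namedtuples are detected structurally and rebuilt with their own type.
    Point = namedtuple("Point", ["x", "y"])
    leaves, spec = tree_flatten(Point(1, [2, 3]))
    assert leaves == [1, 2, 3]
    assert tree_unflatten(leaves, spec) == Point(1, [2, 3])

    # Only int leaves are transformed; the float and the str pass through.
    assert tree_map_only(int, lambda v: v * 10, tree) == {"a": [10, 2.0], "b": (30, "x")}

    class Pair:
        def __init__(self, first: Any, second: Any) -> None:
            self.first = first
            self.second = second

    # A custom node: flatten to two children with no context, rebuild from them.
    _register_pytree_node(
        Pair,
        lambda p: ([p.first, p.second], None),
        lambda values, context: Pair(*values),
    )
    mapped = tree_map(lambda v: v * 2, [Pair(1, 2), 3])
    assert isinstance(mapped[0], Pair)
    assert (mapped[0].first, mapped[0].second, mapped[1]) == (2, 4, 6)

    # vmap-style in_dims broadcasting against a spec with three leaves.
    _, in_spec = tree_flatten(("x", ["y", "z"]))
    assert _broadcast_to_and_flatten(0, in_spec) == [0, 0, 0]
    assert _broadcast_to_and_flatten((0, 1), in_spec) == [0, 1, 1]
    assert _broadcast_to_and_flatten([0, 1], in_spec) is None  # list vs. tuple root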