Commit Graph

113 Commits

Author SHA1 Message Date
Aaron Gokaslan
5471621497 [BE] Remove unnecessary dict comprehensions (#97116)
Removes unnecessary dict comprehensions that optimize creation of dicts from iterables

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97116
Approved by: https://github.com/kit1980
2023-03-20 00:56:57 +00:00
Edward Z. Yang
5c6f5439b7 Implement SymBool (#92149)
We have known for a while that we should in principle support SymBool as a separate concept from SymInt and SymFloat ( in particular, every distinct numeric type should get its own API). However, recent work with unbacked SymInts in, e.g., https://github.com/pytorch/pytorch/pull/90985 have made this a priority to implement. The essential problem is that our logic for computing the contiguity of tensors performs branches on the passed in input sizes, and this causes us to require guards when constructing tensors from unbacked SymInts. Morally, this should not be a big deal because, we only really care about the regular (non-channels-last) contiguity of the tensor, which should be guaranteed since most people aren't calling `empty_strided` on the tensor, however, because we store a bool (not a SymBool, prior to this PR it doesn't exist) on TensorImpl, we are forced to *immediately* compute these values, even if the value ends up not being used at all. In particular, even when a user allocates a contiguous tensor, we still must compute channels-last contiguity (as some contiguous tensors are also channels-last contiguous, but others are not.)

This PR implements SymBool, and makes TensorImpl use SymBool to store the contiguity information in ExtraMeta. There are a number of knock on effects, which I now discuss below.

* I introduce a new C++ type SymBool, analogous to SymInt and SymFloat. This type supports logical and, logical or and logical negation. I support the bitwise operations on this class (but not the conventional logic operators) to make it clear that logical operations on SymBool are NOT short-circuiting. I also, for now, do NOT support implicit conversion of SymBool to bool (creating a guard in this case). This does matter too much in practice, as in this PR I did not modify the equality operations (e.g., `==` on SymInt) to return SymBool, so all preexisting implicit guards did not need to be changed. I also introduced symbolic comparison functions `sym_eq`, etc. on SymInt to make it possible to create SymBool. The current implementation of comparison functions makes it unfortunately easy to accidentally introduce guards when you do not mean to (as both `s0 == s1` and `s0.sym_eq(s1)` are valid spellings of equality operation); in the short term, I intend to prevent excess guarding in this situation by unit testing; in the long term making the equality operators return SymBool is probably the correct fix.
* ~~I modify TensorImpl to store SymBool for the `is_contiguous` fields and friends on `ExtraMeta`. In practice, this essentially meant reverting most of the changes from https://github.com/pytorch/pytorch/pull/85936 . In particular, the fields on ExtraMeta are no longer strongly typed; at the time I was particularly concerned about the giant lambda I was using as the setter getting a desynchronized argument order, but now that I have individual setters for each field the only "big list" of boolean arguments is in the constructor of ExtraMeta, which seems like an acceptable risk. The semantics of TensorImpl are now that we guard only when you actually attempt to access the contiguity of the tensor via, e.g., `is_contiguous`. By in large, the contiguity calculation in the implementations now needs to be duplicated (as the boolean version can short circuit, but the SymBool version cannot); you should carefully review the duplicate new implementations. I typically use the `identity` template to disambiguate which version of the function I need, and rely on overloading to allow for implementation sharing. The changes to the `compute_` functions are particularly interesting; for most of the functions, I preserved their original non-symbolic implementation, and then introduce a new symbolic implementation that is branch-less (making use of our new SymBool operations). However, `compute_non_overlapping_and_dense` is special, see next bullet.~~ This appears to cause performance problems, so I am leaving this to an update PR.
* (Update: the Python side pieces for this are still in this PR, but they are not wired up until later PRs.) While the contiguity calculations are relatively easy to write in a branch-free way, `compute_non_overlapping_and_dense` is not: it involves a sort on the strides. While in principle we can still make it go through by using a data oblivious sorting network, this seems like too much complication for a field that is likely never used (because typically, it will be obvious that a tensor is non overlapping and dense, because the tensor is contiguous.) So we take a different approach: instead of trying to trace through the logic computation of non-overlapping and dense, we instead introduce a new opaque operator IsNonOverlappingAndDenseIndicator which represents all of the compute that would have been done here. This function returns an integer 0 if `is_non_overlapping_and_dense` would have returned `False`, and an integer 1 otherwise, for technical reasons (Sympy does not easily allow defining custom functions that return booleans). The function itself only knows how to evaluate itself if all of its arguments are integers; otherwise it is left unevaluated. This means we can always guard on it (as `size_hint` will always be able to evaluate through it), but otherwise its insides are left a black box. We typically do NOT expect this custom function to show up in actual boolean expressions, because we will typically shortcut it due to the tensor being contiguous. It's possible we should apply this treatment to all of the other `compute_` operations, more investigation necessary. As a technical note, because this operator takes a pair of a list of SymInts, we need to support converting `ArrayRef<SymNode>` to Python, and I also unpack the pair of lists into a single list because I don't know if Sympy operations can actually validly take lists of Sympy expressions as inputs. See for example `_make_node_sizes_strides`
* On the Python side, we also introduce a SymBool class, and update SymNode to track bool as a valid pytype. There is some subtlety here: bool is a subclass of int, so one has to be careful about `isinstance` checks (in fact, in most cases I replaced `isinstance(x, int)` with `type(x) is int` for expressly this reason.) Additionally, unlike, C++, I do NOT define bitwise inverse on SymBool, because it does not do the correct thing when run on booleans, e.g., `~True` is `-2`. (For that matter, they don't do the right thing in C++ either, but at least in principle the compiler can warn you about it with `-Wbool-operation`, and so the rule is simple in C++; only use logical operations if the types are statically known to be SymBool). Alas, logical negation is not overrideable, so we have to introduce `sym_not` which must be used in place of `not` whenever a SymBool can turn up. To avoid confusion with `__not__` which may imply that `operators.__not__` might be acceptable to use (it isn't), our magic method is called `__sym_not__`. The other bitwise operators `&` and `|` do the right thing with booleans and are acceptable to use.
* There is some annoyance working with booleans in Sympy. Unlike int and float, booleans live in their own algebra and they support less operations than regular numbers. In particular, `sympy.expand` does not work on them. To get around this, I introduce `safe_expand` which only calls expand on operations which are known to be expandable.

TODO: this PR appears to greatly regress performance of symbolic reasoning. In particular, `python test/functorch/test_aotdispatch.py -k max_pool2d` performs really poorly with these changes. Need to investigate.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92149
Approved by: https://github.com/albanD, https://github.com/Skylion007
2023-01-21 02:21:56 +00:00
hxu296
09a967d6c9 Make nested TreeSpec printing nicer (#46538) (#86546)
1. Made TreeSpec into a dataclass.
2. In `__repr__`, recursively transformed TreeSpec into dictionaries and then pretty-printed it.

Fixes #46538. Hi, @ezyang. this PR is for the TreeSpec `__repr__` refactor we discussed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86546
Approved by: https://github.com/ezyang
2022-10-18 16:50:39 +00:00
Edward Z. Yang
2a332afbf4 Add SymFloat, support SymInt to SymFloat conversion (#84284)
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84284
Approved by: https://github.com/albanD
2022-09-03 01:30:32 +00:00
Edward Z. Yang
b8b54eccd2 Add *_only and all/any pytree utilities (#83316)
With a sample usage in proxy tensor to show how they can shorten
your code dramatically.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83316
Approved by: https://github.com/zou3519, https://github.com/albanD, https://github.com/bdhirsh
2022-08-12 17:31:55 +00:00
Richard Zou
6700a78504 Move vmap's OrderedDict pytree support to torch.utils._pytree (#83073)
There's no reason why it should just apply when you import vmap

Test Plan:
- added a new test
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83073
Approved by: https://github.com/Chillee
2022-08-11 03:00:55 +00:00
Richard Zou
52d1ffb789 Teach pytrees about namedtuple (#62292)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62292

This PR adds pytree support for namedtuples. The challenge about namedtuple
is that each namedtuple class is actually different. This PR does the
following:
- it adds a namedtuple flatten/unflatten. The flatten function returns
a context that is the actual type of the namedtuple subclass. The
unflatten function uses that type to reconstruct the namedtuple
- Special cases all pytree logic to consider all namedtuples the same.
This is done by creating a `_get_node_type(pytree)` helper function that
returns `namedtuple` if `pytree` is any namedtuple subclass. The effect
of this is that all namedtuple subclasses will go through the namedtuple
flatten/unflatten functions
- Adds a `_namedtuple_flatten_spec` function for FX pytrees. This function
flattens the namedtuple based on the spec and is equivalent to the
`_tuple_flatten_spec`.

Test Plan
- new tests in test/test_pytree.py and test/test_fx.py

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D29947302

Pulled By: zou3519

fbshipit-source-id: 19c00665b13546642c315df0f243ad99b8e7ff7c
2021-07-28 06:27:44 -07:00
Horace He
8d363d37da [FX] Adds PyTree support to FX through concrete_args (#55888)
Summary:
```
class Foo(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, y, x):
        for k in x:
            for v in x[k]:
                v += y
        return x

example_dict = {'x': {'a': [fx.HOLE], 'z': [fx.HOLE, fx.HOLE]}}
new_f = fx.symbolic_trace(Foo(), concrete_args=example_dict)
print(new_f.code)
new_f(torch.randn(5), {'x': {'a': [torch.randn(5)], 'z': [torch.randn(5), torch.randn(5)]}})

fx.symbolic_trace(new_f, concrete_args=example_dict)
```

prints out
```
def forward(self, y, x):
    y, tree_2, tree_3, tree_4 = pytree.tree_flatten([y, x])[0]
    add = tree_2 + y
    add_1 = tree_3 + y
    add_2 = tree_4 + y;  y = None
    return {'a': [tree_2], 'z': [tree_3, tree_4]}
```

Currently, I store `in_spec` as an extra attribute on `fx.Graph`, and then include it when we do the codegen. I'm not sure if this is the right approach - it introduces a divergence between what's in `fx.Graph` and what's in the python code.

Perhaps the best API is something explicit like `fx.Graph.flatten_args`, but that does make calling things a bit ... more verbose.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/55888

Reviewed By: jamesr66a

Differential Revision: D27884694

Pulled By: Chillee

fbshipit-source-id: f9e8a70c63a8df63c9f9bd0a6459255daa5a8df8
2021-05-07 04:48:35 -07:00
Jerry Zhang
b8d98f05e7 [reland][quant][docs] Add fx graph mode quantization to quantization docs (#49211) (#49515)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49515

Test Plan:
Imported from OSS

Imported from OSS

Reviewed By: vkuzo

Differential Revision: D25601061

fbshipit-source-id: 74e917d57895e9b4131a01fdcea8df3e94322bec
2020-12-17 10:30:10 -08:00
Mike Ruberry
676bfa6dbd Revert D25507480: [quant][docs] Add fx graph mode quantization to quantization docs
Test Plan: revert-hammer

Differential Revision:
D25507480 (7729581414)

Original commit changeset: 9e9e4b5fef97

fbshipit-source-id: fdb08d824209b97defaba2e207d1a914575a6ae7
2020-12-16 14:26:18 -08:00
Jerry Zhang
7729581414 [quant][docs] Add fx graph mode quantization to quantization docs (#49211)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49211

Test Plan: Imported from OSS

Reviewed By: raghuramank100

Differential Revision: D25507480

fbshipit-source-id: 9e9e4b5fef979f5621c1bbd1b49e9cc6830da617
2020-12-16 12:40:02 -08:00
Richard Zou
6025f8148a Implement _broadcast_to_and_flatten(pytree, spec) (#46288)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46288

This "broadcasts" `pytree` to have the same structure as `spec`
and then flattens it.
I find it hard to describe what that does in words, so here's an example:

- Broadcasting 1 to have the same structure as [0, [0, 0]] would
return [1, [1, 1]]. Further flattening it gives us [1, 1, 1].
- Broadcasting [1, 2] to have the same structure as [0, [0, 0]] would
return [1, [2, 2]]. Further flattening it gives us [1, 2, 2].

What is this used for?
----------------------
The next PR up in the stack uses this helper function to allow vmap to
accept nested data structures. `vmap(fn, in_dims)(*inputs)` allows the
user to specify in_dims with a tree structure that is a sub-graph of
that of `inputs` (where both contain the root of the tree).

For example, one can do `vmap(fn, in_dims=0)(x, y, z)`. `in_dims` is 0
and inputs is (x, y, z). We would like to broadcast in_dims up to the
structure of inputs to get (0, 0, 0).

Another example, is `vmap(fn, in_dims=(0, 1))(x, [y, z])`. `in_dims` is
(0, 1) and inputs is (x, [y, z]). We would like to broadcast in_dims up
to the structure of inputs to get (0, [1, 1]); this value of in_dims is
used to say "let's vmap over dim 0 for x and dim 1 for y and z".

Test Plan
---------
New tests.

Test Plan: Imported from OSS

Reviewed By: heitorschueroff

Differential Revision: D24392891

Pulled By: zou3519

fbshipit-source-id: 6f494d8b6359582f1b4ab6b8dd6a956d8bfe8ed4
2020-10-20 07:52:14 -07:00
Richard Zou
0285618a11 Add utilities to support handling of nested python data structures (#46287)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46287

This adds a lightweight `pytree` implementation that is similar to and
inspired by JAX pytrees, tensorflow.nest, deepmind/tree,
TorchBeast's TensorNest, etc.

A *pytree* is Python nested data structure. It is a tree in the sense
that nodes are Python collections (e.g., list, tuple, dict) and the leaves
are Python values. Furthermore, a pytree should not contain reference
cycles.

This PR:
- adds support for flattening and unflattening nested Python list/dict/tuples

Context: nested Tensor inputs for vmap
--------------------------------------
Right now, vmap is restricted to taking in flat lists of tensors. This
is because vmap needs to be able to convert every tensor in the input
that is being vmapped over into a BatchedTensor.

With a pytree library, we can simply flatten the input data structure
(returning the leaves), map all of the Tensors in the flat input to
BatchedTensors, and unflatten the flat list of BatchedTensors into a new
input. Or equivalently, with a `tree_map` function, we can map a nested
python data structure containing Tensors into one containing
BatchedTensors.

Future work
-----------
In some future PRs, we'll add nested input support for vmap. The
prerequisites for that are:
- a `broadcast_to(small, big)` that broadcasts `small` up to `big`.
  This is for handling the in_dims to vmap: the in_dims structure must
  be compatible with the structure of the inputs.

Test Plan
---------
- New tests in test/test_pytree.py

Test Plan: Imported from OSS

Reviewed By: heitorschueroff

Differential Revision: D24392890

Pulled By: zou3519

fbshipit-source-id: 7daf7430c5a38354e7d203a72882bd7a9b24cfb1
2020-10-20 07:45:45 -07:00