In this PR, we are implementing Functionalization on pre-dispatch graph. Today, every dispatch key except for Dispatchkey.Python has a dedicated mode stack in python. PreDispatch tracing relies on this behaviour by pushing ProxyTorchDispatchMode to Dispatchkey.PreDispatch mode stack and handle the dispatching logic in python. To make pre-dispatch functionalization work, we now need to push FunctionalTensorMode on DispatchKey.PreDispatch mode stack and make sure it runs before ProxyTorchDispatchMode. (this is very similar to how post-dispatch tracing work). Here are some design decisions we made for this flow to work:
1. FunctionalTensorMode internally calls C++ functionalize key. Since C++ functionalization goes after PreDispatch, if we are not careful, we will keep re-entering into PreDispatch key. We solve this by directly dispatching to C++ Functionalize key.
2. We delete mode_stack_per_key logic because the only realistic time it is exercised is for PreDispatch and it is in general not safe to have a plain list because FunctionalTensorMode and ProxyTorchDispatchMode ordering matter and it is hard to enforce it on plain list. Instead, now we have a private class that tracks PreDispatch mode stack.
3. We will still run CompositeImplicitAutograd decomps in this PR, and disable this logic later as a followup.
Some missing bits after this PR:
1. Preserving autograd ops in a functional form. Right now they still show up in the graph but in a "non-functional" way.
2. Turn off CompositeImplicitAutograd decomps
3. Functionalizing HOO
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113728
Approved by: https://github.com/bdhirsh
We can auto-functionalize operators that mutate their inputs as long as
the outputs of the operator do not alias their inputs. The user needs
to provide an abstract impl for the operator if it has non-trivial
returns.
- We update can_auto_functionalize(op) to include ops that return (but
do not alias) Tensors
- We update auto_functionalized(op, mutated_args_names, kwargs) to
return (out, mutated_args), where `out = op(**kwargs)` and
`mutated_args` are the new values of the inputs that would have been
mutated.
Test Plan:
- new test
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115135
Approved by: https://github.com/bdhirsh
ghstack dependencies: #114955, #114956, #115134
In preparation for the next PR up in the stack, which is going to update
"can_auto_functionalize" to support more operators than just ones that
return nothing. We are unable to auto-generate FakeTensor kernels for
operators that do not return nothing, but we are able to generate
functionalization kernels for operators that return something.
Test Plan:
Existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115134
Approved by: https://github.com/bdhirsh
ghstack dependencies: #114955, #114956
Continuation of #112185, following the design in this [doc](https://docs.google.com/document/d/1ipSxcTzEMMOAPvxP-YJlD5JBZZmIGgh8Q34ixtOUCRo).
Summary:
* Introduce `SubclassSymbolicPolicy` containing separate dynamic dim / constraint policies for the outer and inner tensors
* Expand the automatic dynamic algorithm to recurse into inner tensors and produce one of these for a subclass instance
* Maintain legacy behavior for subclasses by recursively calling `mark_dynamic()` on inner tensors *of the same dim as outer* when `mark_dynamic(outer, ...)` is called
* Addresses this: 6a86cf00ad/torch/_dynamo/variables/builder.py (L1750)
* Add `outer_size` and `outer_stride` arguments to `__tensor_unflatten__()` so that you can find out what symbols were allocated for the outer size / stride (you are expected to return a tensor that compares equal to the outer symbols)
* Signatures now:
```python
# attrs is a list of inner tensor attributes on x; inner_tensor = getattr(x, attr)
# ctx is anything useful for rebuilding the class we want to guard on
attrs, ctx = x.__tensor_flatten__()
...
# inner_tensors is a dict of {attr -> tensor}
# ctx is taken unmodified from flattening and (eventually) guarded on
# outer_size is the expected size of the output; possibly symbolic
# outer_stride is the expected strides of the output; possibly symbolic
y = MySubclass.__tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride)
# at the __tensor_unflatten__() call-site in PT2, we assert y.shape == outer_size and y.stride() == outer_stride
# the assert simplifies symbols when there are relationships between outer and inner symbols
```
* Size info needed for `NestedTensor` at least, stride info needed for `DTensor` at least
* Punting on `outer_storage_offset` because storage_offset handling is horribly broken in PT2 right now
* ~~Add new `__tensor_mark_dynamic__()` to allow overriding the behavior of mark_dynamic on a per-subclass basis~~ (booted to future work)
* ~~Add guards for tensor subclasses by calling `__tensor_flatten__()` in the guard to test equality on `ctx`~~
* Now handled in #114469
* Next PR: add TENSOR_MATCH guards on inner tensors
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114311
Approved by: https://github.com/ezyang, https://github.com/drisspg, https://github.com/voznesenskym, https://github.com/bdhirsh
Users may wish to torch.compile custom ops that mutate their inputs
and return nothing (this is a common class of operators).
torch.compile will automatically support this op without anyone needing
to provide a functionalization kernel for it. Here's how.
Let's say we have a hypothetical mylib::sin_(Tensor(a!) x) -> ()
op. First, when FakeTensor sees this op, it can just return None.
This is the case because custom ops are not allowed to mutate input
metadata, so the FakeTensor rule for one that returns nothing is trivial.
Next, when Python FunctionalTensor sees the op, it will functionalize
it by emitting a call to an auto_functionalize(op, ["x"], {"x": ...})
HOP and replacing the mutated inputs with the outputs of this HOP.
This HOP effectively runs the functional version of the op when
called: it clones inputs that will be mutated, runs the op, and
then returns Tensors with the new values.
In the future we can teach Inductor how to do re-inplacing when it sees
this HOP (like how triton kernels do it) but this isn't urgent (and is
more of a performance problem).
Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114955
Approved by: https://github.com/bdhirsh
Summary:
The primary problem we are setting out to solve here is fake tensor freshness. Before this PR, fake tensors after dynamo represented fake tensors *at the end* of trace, so subsequent retraces like aot_autograd would start off with fake tensors in the wrong (end result) state, rather than their expected fresh state. The solution here is to start a fresh fake mode, and re-fakify the tensors. The nuance comes from ensuring that symbols are uniformly created for the symbolic sizes and strides of the tensor.
This PR is the result of *a lot* of back and forth with ezyang and eellison. Initially, the first pass at this was not super different from what we have in the PR - the broad strokes were the same:
1) We cache source->symbol in shape_env
2) We pass policy objects around, stored at dynamo fakificaiton time, and reused for later fakification
3) We create a new fake mode for backends
(from https://github.com/pytorch/pytorch/pull/113605/files)
This is ugly, and has some layering violations. We detoured our decision making through a few other alternatives. Immutable/mutable fake tensor mode was the most interesting alternative, https://github.com/pytorch/pytorch/pull/113653, and was struck down on concerns of complexity in fake mode combined with it not covering all edge cases. We also detoured on what to do about tensor memoization returning back potentially different tensors than requested, and if that was an anti pattern (it is) we want to hack in with the symbol cache (we don't).
We went back to the drawing board here, but with a few concessions:
1) the cache for source->symbol must live outside of shape_env, for both lifecycle, and layering reasons
2) A good amount of work needs to be done to pipe policy around fake_mode and meta_utils correctly, to cover all the cases (ezyang did this)
cc penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 aakhundov kadeng
imported-using-ghimport
Test Plan: Imported from OSS
Reviewed By: huydhn, Chillee
Differential Revision: D51566250
Pulled By: voznesenskym
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114526
Approved by: https://github.com/Chillee, https://github.com/huydhn
The primary problem we are setting out to solve here is fake tensor freshness. Before this PR, fake tensors after dynamo represented fake tensors *at the end* of trace, so subsequent retraces like aot_autograd would start off with fake tensors in the wrong (end result) state, rather than their expected fresh state. The solution here is to start a fresh fake mode, and re-fakify the tensors. The nuance comes from ensuring that symbols are uniformly created for the symbolic sizes and strides of the tensor.
This PR is the result of *a lot* of back and forth with @ezyang and @eellison. Initially, the first pass at this was not super different from what we have in the PR - the broad strokes were the same:
1) We cache source->symbol in shape_env
2) We pass policy objects around, stored at dynamo fakificaiton time, and reused for later fakification
3) We create a new fake mode for backends
(from https://github.com/pytorch/pytorch/pull/113605/files)
This is ugly, and has some layering violations. We detoured our decision making through a few other alternatives. Immutable/mutable fake tensor mode was the most interesting alternative, https://github.com/pytorch/pytorch/pull/113653, and was struck down on concerns of complexity in fake mode combined with it not covering all edge cases. We also detoured on what to do about tensor memoization returning back potentially different tensors than requested, and if that was an anti pattern (it is) we want to hack in with the symbol cache (we don't).
We went back to the drawing board here, but with a few concessions:
1) the cache for source->symbol must live outside of shape_env, for both lifecycle, and layering reasons
2) A good amount of work needs to be done to pipe policy around fake_mode and meta_utils correctly, to cover all the cases (@ezyang did this)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113926
Approved by: https://github.com/ezyang, https://github.com/eellison
Subsumes half of https://github.com/pytorch/pytorch/pull/113605
We support fakeifying an already fake tensor, which will give you a new fake tensor mirroring the same structure as the original fake tensor, which is what is needed by https://github.com/pytorch/pytorch/issues/113643 . However, when this refakeification happens, we will naively reallocate all new sizes for all of the fake tensor. This is the right thing to do if you are re-fakeifying on a fresh ShapeEnv (because you're reparametrizing the sizes or something), but if you have two fake tensor modes which are sharing a shape environment, you would actually rather just reuse the original sizes/strides/offset from the original fake tensor. This ends up being pretty simple. I recommend viewing with whitespace diff turned off.
There's some fuzz around jagged tensor handling; that code is probably not quite right, but I fixed it for this particular case in the most straightforward way.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113651
Approved by: https://github.com/albanD, https://github.com/eellison, https://github.com/bdhirsh
We spend somewhere on the order 1% in `sympy.Expr.free_symbols` as it is called millions of times.
Most of the time we actually just want to know "is this a constant", however `e.is_constant()` is
horribly slow. It turns out though that there is another propery `is_number` that does what we want.
> property is_number:
>
> Returns True if self has no free symbols and no undefined functions (AppliedUndef, to be precise). It will be faster
> than if not self.free_symbols, however, since is_number will fail as soon as it hits a free symbol or undefined
> function.
Even further, we also avoid the overhead of building the unnecessary set object.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112688
Approved by: https://github.com/lezcano
Fix: #111506
This PR skips aliasing correction on `lift_fresh` calls. Reasoning is: although unlifted and lifted tensors are technically aliases, they are from different levels of abstraction (`FunctionalTensorWrapper` and `XLATensor`).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112202
Approved by: https://github.com/bdhirsh
To do this, there is a little detour to remove hint caching for unbacked
SymInts; now, we just always attempt to update the hint (using
maybe_evaluate_static; this is much better than the replace we were
doing before) if we don't think we know it.
With this change, we now can generally infer that i0 == 1 is false for
a size-like unbacked SymInt. So if we write the size match /
broadcasting test very carefully (see comment), we will eventually
end up expect_true(sizeA == sizeB), which is good enough to cause
refinement. Phew!
I think I still want to setup a replacement if you do i0 == s0, but I'm
going to do that in a follow up.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112155
Approved by: https://github.com/aakhundov, https://github.com/voznesenskym
This function repeatedly flattens and unflattens the `args, kwargs` pair so we
get a quite significant perf improvement from saving the `flat_args` and
operating directly on those. I see a 15% improvement in dispatch for
`empty_strided`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112418
Approved by: https://github.com/lezcano
Currently meta_utils relies on as_strided when handling the view case (recursively meta-ify the base, and then do as_strided to simulate the view), but NestedTensor does not support as_strided today (though maybe it could?), so what we want to do instead is call Tensor. _view_func. Conveniently, _view_func IS always available for nested tensors.
A detail to note is that _view_func actually incurs a guard because it needs to perform some metadata checks to make sure the view is still valid. This PR adds Tensor._unsafe_view_func which can avoid that.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112205
Approved by: https://github.com/jbschlosser
This should be the last of the "it used to work with static shapes but
it doesn't work with dynamic shapes" hard errors. Now we will just
specialize if you hit it from C++.
The strategy here is a bit clever. We shunt the size() call to Python
binding if an error would have occurred. Importantly, we already have
logic to make sure the newly allocated ints stay live for the duration
of the ArrayRef access.
storage_offset is intentionally omitted because there are some problems
with it. I will fix them next.
This should let us get rid of the aotautograd_static test configuration.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111935
Approved by: https://github.com/zou3519
This is kind of hard to test, but I can try to add a test case if requested.
I noticed locally that we now end up logging to the ProxyTensorMode and FakeTensorMode `not_implemented` logs in very simple compile examples: https://github.com/pytorch/pytorch/blob/main/torch/fx/experimental/proxy_tensor.py#L269
It was because `_mirror_autograd_meta_to()` indirectly queries sizes, and since modes have higher priority than subclasses, `aten::sym_sizes()` was getting dispatched to our modes before going to `FunctionalTensor.__torch_dispatch__`.
This works out fine (they return NotImplemented and we eventually get to `FunctionalTensor`) but I figured we want to avoid cluttering up the logs. So I wrapped the calls with `FunctionalTensorMode`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111040
Approved by: https://github.com/ezyang
In this PR:
- Adds support for strides for jagged tensor (design doc for this coming soon)
- NestedTensor skips automatic dynamic
- Make use of @bdhirsh's subclass fakification logic by adding the __tensor_{un,}flatten__ functions.
- Additional logic for fakification: since existing subclass fakification logic does not handle the case where the outer tensor has an additional dimension. We insert one-off logic to (1) insert an extra SingletonSymInt onto the fakified NestedTensor. (2) make sure we call track_symint on both the sizes on the inner and outer tensor during guard creation.
Remaining things that are weird:
- Still need to skip some logic in meta utils for some reason (I was going to write this up more, but decided not to since we're not able to do this anyway for a immediate reason: we cannot arbitrarily compare singleton ints. For now I'm just following Brian's advise from [here](https://github.com/pytorch/pytorch/pull/109171#discussion_r1328137070) )
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109171
Approved by: https://github.com/ezyang, https://github.com/bdhirsh
The first reland broke internal (failing diff: D49617462).
The major error looks like it's because there's an internal-only higher order op that needs a new functionalization rule. I'm going to land an internal diff for that and confirm tests pass before relanding this PR.
Also confirmed that the issue from https://github.com/pytorch/pytorch/issues/110121 is fixed, and added a test.
This reverts commit 1b90f07f5a.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110079
Approved by: https://github.com/ezyang
Changelog:
- torch.library.impl_abstract optionally accepts a torch.library.Library
object. If passed in, then the lifetime of the registration is tied to
the Library object.
- we've also changed torch.library.impl_abstract to work on all
operators, including overloads.
- we refactored the `torch._custom_ops.*` and `torch._custom_op.*`
impl_abstract APIs and put them under torch._library. This is the
final resting place for them. I will follow-up with deleting
all the `torch._custom_ops.*` stuff later.
- There is a new "SimpleOperatorRegistry" where we actually collect the
abstract_impl. We will expand this to also hold the other
torch._custom_ops.* APIs when we move those to torch.library
NB: Previously we had designed
`impl_abstract` assuming a very high-level Python-only custom op API.
We've revisited that since; now, impl_abstract works for all custom ops,
no matter python or C++, no matter the schema. The new refactored design
reflects this better.
Test Plan:
- existing and new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109912
Approved by: https://github.com/ezyang
I added some tests for Conj, Neg and ZeroTensor for both python and C++ functionalization. This also fixes a nasty segfult when running a functorch `jacfwd` test with `torch.compile`, once AOTAutograd is using `FunctionalTensor`.
Changes:
(1) I use Jeffrey's `make_wrapper_subclass(extra_dispatch_keys)` kwarg to plumb extra dispatch keys ontoto the wrapper, mirroring what C++ functionalization does (C++ functionalization will mirror all dispatch keys from the inner tensor to the wrapper, except for python and functorch keys).
(2) FunctionalTensorMode will decompose CompositeImplicitAutograd ops, since (for example) ZeroTensor kernels can send ops like `.to()` directly to the Python key. We'll need a way to toggle this later for pre-dispatch functionalization
(3) Bound `_ForceDispatchKeyGuard` and BatchedTensorImpl's dispatch keyset to python
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109023
Approved by: https://github.com/zou3519
ghstack dependencies: #108654, #109662, #109632
We want users to be able to define custom ops in C++ but put the
abstract impl in Python (since it is easier to write them in Python and
the abstract impl better models device semantics and data-dependent
operators).
`m.impl_abstract_pystub(opname, python_module, context)` declares the
abstract_impl of the operator to exist in the given python module.
When the abstract_impl needs to be accessed (either via FakeTensor or
Meta), and it does not exist, the PyTorch Dispatcher will yell
with a descriptive error message.
Some details:
- We construct a new global AbstractImplPyStub mapping in
Dispatcher.cpp. Read/write to this map is protected by the Dispatcher
lock.
- We add a new Meta Tensor fallback kernel. The fallback errors out if there is
no meta kernel, but also offers a nicer error message if we see that there is
a pystub.
- We create a `torch._utils_internal.throw_abstract_impl_not_imported_error`
helper function to throw errors. This way, we can throw different error
messages in OSS PyTorch vs internal PyTorch. To invoke this from C++, we
added a PyInterpreter::throw_abstract_impl_not_imported_error.
Differential Revision: [D49464753](https://our.internmc.facebook.com/intern/diff/D49464753/)
Differential Revision: [D49464753](https://our.internmc.facebook.com/intern/diff/D49464753)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109529
Approved by: https://github.com/ezyang, https://github.com/bdhirsh
We now have two types of functionalization, C++ Functionalization (through the `Functionalize` dispatch key), and python functionalization (through the `FunctionalTensorMode` torch_dispatch mode).
This means that all higher order ops need custom functionalization rules for the python variant too. I added them here, as well as a helper function `dispatch_functionalize()` - equivalent to `torch.func.functionalize()`, except that it uses `FunctionalTensorMode`.
In theory we could have secretly switched `torch.func.functionalize` to use `FunctionalTensorMode`. This would be BC-breaking, though, since `FunctionalTensorMode` isn't composable with the other functorch transforms (the functorch layer-mode stack doesn't know how to re-order torch_dispatch modes arbitrarily).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108656
Approved by: https://github.com/zou3519
ghstack dependencies: #109024, #109248