In preparation for the next PR up in the stack, which is going to update
"can_auto_functionalize" to support more operators than just ones that
return nothing. We are unable to auto-generate FakeTensor kernels for
operators that do not return nothing, but we are able to generate
functionalization kernels for operators that return something.
Test Plan:
Existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115134
Approved by: https://github.com/bdhirsh
ghstack dependencies: #114955, #114956
Users may wish to torch.compile custom ops that mutate their inputs
and return nothing (this is a common class of operators).
torch.compile will automatically support this op without anyone needing
to provide a functionalization kernel for it. Here's how.
Let's say we have a hypothetical mylib::sin_(Tensor(a!) x) -> ()
op. First, when FakeTensor sees this op, it can just return None.
This is the case because custom ops are not allowed to mutate input
metadata, so the FakeTensor rule for one that returns nothing is trivial.
Next, when Python FunctionalTensor sees the op, it will functionalize
it by emitting a call to an auto_functionalize(op, ["x"], {"x": ...})
HOP and replacing the mutated inputs with the outputs of this HOP.
This HOP effectively runs the functional version of the op when
called: it clones inputs that will be mutated, runs the op, and
then returns Tensors with the new values.
In the future we can teach Inductor how to do re-inplacing when it sees
this HOP (like how triton kernels do it) but this isn't urgent (and is
more of a performance problem).
Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114955
Approved by: https://github.com/bdhirsh
Summary:
The primary problem we are setting out to solve here is fake tensor freshness. Before this PR, fake tensors after dynamo represented fake tensors *at the end* of trace, so subsequent retraces like aot_autograd would start off with fake tensors in the wrong (end result) state, rather than their expected fresh state. The solution here is to start a fresh fake mode, and re-fakify the tensors. The nuance comes from ensuring that symbols are uniformly created for the symbolic sizes and strides of the tensor.
This PR is the result of *a lot* of back and forth with ezyang and eellison. Initially, the first pass at this was not super different from what we have in the PR - the broad strokes were the same:
1) We cache source->symbol in shape_env
2) We pass policy objects around, stored at dynamo fakificaiton time, and reused for later fakification
3) We create a new fake mode for backends
(from https://github.com/pytorch/pytorch/pull/113605/files)
This is ugly, and has some layering violations. We detoured our decision making through a few other alternatives. Immutable/mutable fake tensor mode was the most interesting alternative, https://github.com/pytorch/pytorch/pull/113653, and was struck down on concerns of complexity in fake mode combined with it not covering all edge cases. We also detoured on what to do about tensor memoization returning back potentially different tensors than requested, and if that was an anti pattern (it is) we want to hack in with the symbol cache (we don't).
We went back to the drawing board here, but with a few concessions:
1) the cache for source->symbol must live outside of shape_env, for both lifecycle, and layering reasons
2) A good amount of work needs to be done to pipe policy around fake_mode and meta_utils correctly, to cover all the cases (ezyang did this)
cc penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 aakhundov kadeng
imported-using-ghimport
Test Plan: Imported from OSS
Reviewed By: huydhn, Chillee
Differential Revision: D51566250
Pulled By: voznesenskym
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114526
Approved by: https://github.com/Chillee, https://github.com/huydhn
The primary problem we are setting out to solve here is fake tensor freshness. Before this PR, fake tensors after dynamo represented fake tensors *at the end* of trace, so subsequent retraces like aot_autograd would start off with fake tensors in the wrong (end result) state, rather than their expected fresh state. The solution here is to start a fresh fake mode, and re-fakify the tensors. The nuance comes from ensuring that symbols are uniformly created for the symbolic sizes and strides of the tensor.
This PR is the result of *a lot* of back and forth with @ezyang and @eellison. Initially, the first pass at this was not super different from what we have in the PR - the broad strokes were the same:
1) We cache source->symbol in shape_env
2) We pass policy objects around, stored at dynamo fakificaiton time, and reused for later fakification
3) We create a new fake mode for backends
(from https://github.com/pytorch/pytorch/pull/113605/files)
This is ugly, and has some layering violations. We detoured our decision making through a few other alternatives. Immutable/mutable fake tensor mode was the most interesting alternative, https://github.com/pytorch/pytorch/pull/113653, and was struck down on concerns of complexity in fake mode combined with it not covering all edge cases. We also detoured on what to do about tensor memoization returning back potentially different tensors than requested, and if that was an anti pattern (it is) we want to hack in with the symbol cache (we don't).
We went back to the drawing board here, but with a few concessions:
1) the cache for source->symbol must live outside of shape_env, for both lifecycle, and layering reasons
2) A good amount of work needs to be done to pipe policy around fake_mode and meta_utils correctly, to cover all the cases (@ezyang did this)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113926
Approved by: https://github.com/ezyang, https://github.com/eellison
We spend somewhere on the order 1% in `sympy.Expr.free_symbols` as it is called millions of times.
Most of the time we actually just want to know "is this a constant", however `e.is_constant()` is
horribly slow. It turns out though that there is another propery `is_number` that does what we want.
> property is_number:
>
> Returns True if self has no free symbols and no undefined functions (AppliedUndef, to be precise). It will be faster
> than if not self.free_symbols, however, since is_number will fail as soon as it hits a free symbol or undefined
> function.
Even further, we also avoid the overhead of building the unnecessary set object.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112688
Approved by: https://github.com/lezcano
To do this, there is a little detour to remove hint caching for unbacked
SymInts; now, we just always attempt to update the hint (using
maybe_evaluate_static; this is much better than the replace we were
doing before) if we don't think we know it.
With this change, we now can generally infer that i0 == 1 is false for
a size-like unbacked SymInt. So if we write the size match /
broadcasting test very carefully (see comment), we will eventually
end up expect_true(sizeA == sizeB), which is good enough to cause
refinement. Phew!
I think I still want to setup a replacement if you do i0 == s0, but I'm
going to do that in a follow up.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112155
Approved by: https://github.com/aakhundov, https://github.com/voznesenskym
This function repeatedly flattens and unflattens the `args, kwargs` pair so we
get a quite significant perf improvement from saving the `flat_args` and
operating directly on those. I see a 15% improvement in dispatch for
`empty_strided`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112418
Approved by: https://github.com/lezcano
This should be the last of the "it used to work with static shapes but
it doesn't work with dynamic shapes" hard errors. Now we will just
specialize if you hit it from C++.
The strategy here is a bit clever. We shunt the size() call to Python
binding if an error would have occurred. Importantly, we already have
logic to make sure the newly allocated ints stay live for the duration
of the ArrayRef access.
storage_offset is intentionally omitted because there are some problems
with it. I will fix them next.
This should let us get rid of the aotautograd_static test configuration.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111935
Approved by: https://github.com/zou3519
Changelog:
- torch.library.impl_abstract optionally accepts a torch.library.Library
object. If passed in, then the lifetime of the registration is tied to
the Library object.
- we've also changed torch.library.impl_abstract to work on all
operators, including overloads.
- we refactored the `torch._custom_ops.*` and `torch._custom_op.*`
impl_abstract APIs and put them under torch._library. This is the
final resting place for them. I will follow-up with deleting
all the `torch._custom_ops.*` stuff later.
- There is a new "SimpleOperatorRegistry" where we actually collect the
abstract_impl. We will expand this to also hold the other
torch._custom_ops.* APIs when we move those to torch.library
NB: Previously we had designed
`impl_abstract` assuming a very high-level Python-only custom op API.
We've revisited that since; now, impl_abstract works for all custom ops,
no matter python or C++, no matter the schema. The new refactored design
reflects this better.
Test Plan:
- existing and new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109912
Approved by: https://github.com/ezyang
We want users to be able to define custom ops in C++ but put the
abstract impl in Python (since it is easier to write them in Python and
the abstract impl better models device semantics and data-dependent
operators).
`m.impl_abstract_pystub(opname, python_module, context)` declares the
abstract_impl of the operator to exist in the given python module.
When the abstract_impl needs to be accessed (either via FakeTensor or
Meta), and it does not exist, the PyTorch Dispatcher will yell
with a descriptive error message.
Some details:
- We construct a new global AbstractImplPyStub mapping in
Dispatcher.cpp. Read/write to this map is protected by the Dispatcher
lock.
- We add a new Meta Tensor fallback kernel. The fallback errors out if there is
no meta kernel, but also offers a nicer error message if we see that there is
a pystub.
- We create a `torch._utils_internal.throw_abstract_impl_not_imported_error`
helper function to throw errors. This way, we can throw different error
messages in OSS PyTorch vs internal PyTorch. To invoke this from C++, we
added a PyInterpreter::throw_abstract_impl_not_imported_error.
Differential Revision: [D49464753](https://our.internmc.facebook.com/intern/diff/D49464753/)
Differential Revision: [D49464753](https://our.internmc.facebook.com/intern/diff/D49464753)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109529
Approved by: https://github.com/ezyang, https://github.com/bdhirsh
- Update cross-ref FakeMode test to use ShapeEnv. Dynamic ops can now
return an unbacked SymInt. We always accept this as equal to whatever
the real value was.
- Relax test so it works on all classes, not just unittest.TestCase
- Properly wrap the original method, so things like
pytree.mark.parametrize are carried over
- Support dynamic shapes by default for make_fx `tracing_mode="fake"` without symbolifying everything else
Fixes https://github.com/pytorch/pytorch/issues/108927
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108929
Approved by: https://github.com/zou3519
Fixes https://github.com/pytorch/pytorch/issues/101939
Several fixes bundled together:
1. When we valueToTensor, we only handled non-symbolic inputs and not symbolic inputs. We support symbolic Scalar, so also handle symbolic values.
2. In the symbolic case, we MUST NOT lift_fresh, as you're not going to inline a constant into the graph, it's going to be from a `scalar_tensor` call (so no need to clone it to avoid mutations)
3. In indexing scalarToTensor, must not do the static, directly read out the scalar contents logic with the scalar is symbolic
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108873
Approved by: https://github.com/jansel
Summary:
Enables dynamo eager mode tracing for the following situation:
1. we have a torch.autograd.Function
2. the input to that function is a tensor subclass which is an intermediary
This is useful for float8 training UX.
Test Plan:
```
python test/dynamo/test_autograd_function.py -k intermediary_input
```
Reviewers:
Subscribers:
Tasks:
Tags:
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108093
Approved by: https://github.com/bdhirsh, https://github.com/wanchaol
This PR adds a `return_and_correct_aliasing()` utility, that wrapper subclasses can use to get correct aliasing. I updated `TwoTensor` to use it, and added some testing that the aliasing of my `TwoTensor` subclass now matches the aliasing behavior of normal tensors.
Right now my test just uses a few hand-picked opinfos (that have varying aliasing behavior). I thought all op infos might be overkill (does that take a while to run?), but I'm happy to add them all if people prefer.
One more general question about this PR: eventually, proper aliasing will be a **requirement** in order for AOTAutograd to handle aliasing/mutations on subclasses properly during compilation. How can we make sure that wrapper subclasses use this API? A few options (from talking to Richard):
(1) Yolo require subclasses to use the API and hope users do as well (what this PR does)
(2) Yolo require subclasses to use the API, but add a kwarg to `_make_wrapper_subclass`, e.g. `manual_aliasing=True`, that torch.compile checks for before allowing the subclass to be used in compilation
(3) Automatically run this API in our python fallback, for **every** tensor subclass that currently implements `__tensor_flatten__` (aka only the "traceable" subclasses)
(4) Automatically run this API in our python fallback, for **every** tensor subclass. This would be a bit higher blast radius, since it would change the existing aliasing behavior of wrapper subclasses. Maybe.. this is the right thing to do though?
Either way, my tentative plan is to do (1) to unblock, and revisit this later once we want to come up with public docs + a more general "tensor subclass in PT2 requirements" plan
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107915
Approved by: https://github.com/ezyang
**Update:** Made refactor of the original PR. See the original description below, but here I'll describe the updates:
(1) TLS changes in `TorchDispatchModeTLS.h/cpp`.
I added a `TorchDispatchModeKey` enum, that (for now) just contains PROXY and FAKE. The ModeTLS used to just contain a `std::vector<std::shared_ptr<c10::SafePyObject>>` corresponding to the mode stack. It now **also** contains a separate array of "infra modes", indexed by mode key (PROXY and FAKE, with a new addition, FUNCTIONAL, coming later in the stack).
`TorchDispatchModeTLS::push_onto_stack` and `TorchDispatchModeTLS::pop_stack` are now a bit more complicated. Pushing accepts an optional mode_key, which if set, tells us to add the given mode directly to our "infra_modes" array. Popping will first check the "user mode" stack, before trying to pop anything from the infra mode stack. It also optionally returns the mode key of the mode we popped if there was one - that way if we push that same mode back onto the TLS later, we know where it goes.
`TorchDispatchModeTLS::dispatch_mode_enabled()` now accepts an optional `skip_infra_modes` param, so you can separately query if there are "any modes at all", or if there are "any user modes".
`TorchDispatchModeTLS::get/set/unset_mode()` all take in a mode key, and get/set/unset the mode at that particular mode key (meaning they are only meant to be used for infra modes).
There were also some mild codegen changes to support the new enum
(2) `fake_tensor.py/proxy_tensor.py/_python_dispatch.py`
The way I tell the infra that certain subclasses/modes are "infra" is through the enum: I gave `FakeTensor` and `FakeTensorMode` a `self._mode_key = torch._C.TorchDispatchModeKey.FAKE`. `TorchDispatchMode.__enter/exit__()` (in `_python_dispatch.py` now check if the current mode has a mode key, and if so they plumb it into any `push_onto_stack()` calls (which eventually instructs `TorchDispatchModeTLS` where to put the mode). Same thing for `ProxyTorchDispatchMode`.
I also had to change both of these mode's enter/exit, to handle the fact that there can no longer be multiple proxy/fake modes on the mode stack at once. I updated them both to have a `self.enter_stack: List[Optional[TorchDispatchMode]]` - whenever we push a given mode in `__enter__`, we remove the current ambient fake/proxy mode from the mode stack, and save it in `enter_stack`, so that on exit we can reset the state properly.
(2) dispatching logic in `python_arg_parser.cpp`
This is where the core dispatching logic changes are. I added two helpers, `dispatch_on_subclass()` and `dispatch_on_mode()`. The overall dispatching order is now:
```
(a) dispatch_on_mode() # try user modes first (where the mode stack automatically considers infra modes last)
(b) dispatch_on_subclass() # try user subclasses next (skipping infra subclasses)
(c) dispatch_on_subclass() # try infra subclasses next (skipping user subclasses)
```
Note that we still want "user subclasses" to run before "infra modes". As Ed helped me realize, this will work today: If proxy/fake modes in step 1, they'll return NotImplemented if they see a user subclass, allowing us to redispatch to the user subclass.
How do (b) and (c) distinguish between user and infra subclasses? Infra subclasses (FakeTensor, and later FunctionalTensor) are required to have a `_mode_key` hidden on the subclass - so we filter via arguments that do/don't have the _mode_key.
(3) I also changed `DoubleTensor` to `TwoTensor` to minimize confusion (@albanD pointed out that DoubleTensor would be easily confused with `torch.FloatTensor` and friends).
----- original description below -----
The main purpose of this PR is to fix the "ordering problem" between torch_dispatch modes, where we want to ensure that our Fake and Proxy dispatch modes always run **after** any dispatch modes created by the user, regardless of where they are in the stack. See this doc for more details: https://docs.google.com/document/d/1COQ291nOZvtFnzGTQMJqoYZ3sttEYFw_7HbfSyL8gcA/edit
Full set of changes below. I ended up including a few semi-related changes in this PR that I documented - but if folks would rather I separate them out, happy to try to do that.
**(1) Add dedicated TLS slots for FakeTensorMode and ProxyTensorMode**
This is the main component of this PR. There are two new slots, `TorchDispatchModeTLS.fake_mode_` and `TorchDispatchModeTLS.proxy_mode_`, which correspond to a single "global" fake and proxy mode. There is now an invariant that `torchDispatchModeState.stack_` can never contain either of these modes.
I also added a `TorchDispatchModeTLS::maybe_highest_mode()` helper that consults the `stack_` as well as both the proxy and fake slots, and returns the highest priority mode - this is because there are a few places in the codebase where we legitimately want to get the highest priority mode, *including* fake or proxy, if one is set.
This also made the implementations of the existing `disable_proxy_modes_tracing()` and `get_innermost_proxy_mode()` marginally simpler.
**(2) Updated the dispatching logic in handle_torch_function_no_python_arg_parser()**
This is the function that actually figures out which torch_dispatch implementation to call, given the current mode stack and tensor subclass inputs. This function got marginally more complicated as part of the refactor: First we inspect the mode stack and any non-fake subclass inputs. Then we check for the proxy mode slot. Then we check for the Fake mode slot, before finally checking for any fake subclass inputs.
**(3) new python `_get_fake_tensor_mode()` and `_get_proxy_tensor_mode()` API's**
Before, if you wanted to see if proxy or fake modes were active in python, you would have to consult the mode stack. Since these two modes are no longer part of the actual mode stack, I added two new API's to directly check if either proxy or fake modes are active.
**(4) Allow traceable tensor subclasses to access storages from python**
This is convenient later in the stack, where AOTAutograd needs to detect aliasing of inputs and outputs, where those inputs and outputs might be tensor subclasses. Previously, `x.untyped_storage()` would raise an error if `x` was a subclass. In this PR, I tried to relax this constraint as little as possible: `THPVariable_storage()` will only try to return a storage to python if the tensor subclass that you are passing in is "traceable"
**(5) Fixed subclass fakeification**
@wanchaol recently added support to be able to fakeify tensor subclasses. That fakeification logic works in most cases, but there is one case it doesn't handle: autograd metadata. In particular, since autograd sees our tensor subclasses and not their desugared tensors, we need to make sure that our fakeified subclass has the same autograd metadata as the original subclass. I updated `meta_utils.py` to make sure that the autograd metadata is correct.
**(6) make tensor subclasses resizeable**
Previously we didn't allow tensor subclasses to be resizeable. I ran into an issue where fakeifying a tensor subclass occasionally requires swapping out its storage, which can involve resizing the tensor. Mechanically, this required updating `at::for_blob()` to expose a way to request that the tensor that you create has resizeable storage, and then using this new API in `_make_wrapper_tensor()`.
**(7) Added a basic DoubleTensor subclass for testing**
I use this subclass more later in this stack in my AOTAutograd tests - but it serves as a simple subclass example to test the dispatch ordering in this PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104482
Approved by: https://github.com/ezyang
ghstack dependencies: #107415
There is already some support for plumbing `__torch_dispatch__` tensor subclasses through dynamo, but this PR beefs it up a bit and adds a test. In particular:
(1) Fakeifying tensor subclasses didn't properly set autograd metadata (requires_grad, is_leaf) on the newly fakeified wrapper subclass. I don't actually have a test for this in this PR, but it's tested pretty heavily later in my aot autograd tests
(2) Fakeifying tensor subclasses didn't properly track source information for dynamic shapes on the inner tensors. I added a new `WrapperSubclassFieldSource` subclass, that represents a source coming from a tensor field on a wrapper subclass, which I use in the fakeifying logic, and again in symbolic_shapes.py to generate proper guards.
(3) `_make_wrapper_subclass()` marginally updated this code to work better with dynamic shapes. One thing that's a bit weird about `_make_wrapper_subclass`: it has two overloads, and the first explicitly does not support dynamic shapes (and the second.. does not support kwargs). I think that later we probably want to consolidate / at least make the first overload work with dynamic shapes, but I didn't want to handle that in this PR (so these smaller changes seemed like a strict improvement).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107415
Approved by: https://github.com/ezyang
This PR stops `SymNode` from mutating (i.e. simplifying) its expression. Instead, the
simplification (without mutation) is deferred to the `SymNode.maybe_as_int` method.
```python
- FakeTensor(size=(s0,), ...)
- FakeTensor(size=(s1, s2, s3), ...)
- Eq(s0, s1 + s2 + s3)
- FakeTensor(size=(s0,), ...)
- FakeTensor(size=(s1, s2, s3), ...)
```
In summary, this PR:
- Replaces `SymNode._expr` by `SymNode.expr`, removing the old property function
- This makes it so `SymNode` instances never update their expression
- Creates `SymNode.simplified_expr()` method for actually calling `ShapeEnv.replace` on
its expression. Note that this doesn't updates `SymNode.expr`
- Changes how `tensor.size()` gets converted to its Python `torch.Size` type
- Instead of calling `SymInt::maybe_as_int()` method, we create a new
`SymInt::is_symbolic()` method for checking whether it is actually a symbolic value
- This is needed so that when we call `tensor.size()` in the Python side, the returned
sequence is faithful to the actual data, instead of possibly simplifying it and
returning an integer
- 2 files needs this modification:
- _torch/csrc/Size.cpp_: for handling `torch.Tensor.size` Python calls
- _torch/csrc/utils/pybind.cpp_: for handling `symint.cast()` C++ calls
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107492
Approved by: https://github.com/ezyang
ghstack dependencies: #107523
This PR allows dynamo to fakify FunctionalTensorWrapper by unwrapping, replacing and wrapping again for FunctionalTensorWrapper so that FunctionalTensorWrapper can be passed in as input for dynamo.optimize and we can support something like this
```python
ff = torch.func.functionalize(f)
torch.compile(ff)(x)
```
This PR didn't follow the \_\_tensor_flatten\_\_ and \_\_tensor_unflatten\_\_ protocol right now because we're not sure the plan of doing that for FunctionalTensorWrapper (it's implemented in C++).
**Test Plan:**
Add a new test.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107062
Approved by: https://github.com/zou3519
ghstack dependencies: #107042
This PR allows dynamo to fakify FunctionalTensorWrapper by unwrapping, replacing and wrapping again for FunctionalTensorWrapper so that FunctionalTensorWrapper can be passed in as input for dynamo.optimize and we can support something like this
```python
ff = torch.func.functionalize(f)
torch.compile(ff)(x)
```
This PR didn't follow the \_\_tensor_flatten\_\_ and \_\_tensor_unflatten\_\_ protocol right now because we're not sure the plan of doing that for FunctionalTensorWrapper (it's implemented in C++).
**Test Plan:**
Add a new test.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107062
Approved by: https://github.com/zou3519
ghstack dependencies: #107042
Currently there are FFT operators which raise `UnsupportedOperatorException`
because their meta implementations sometimes give incorrect strides. This works
around the problem for static shapes by falling back to eager. Though we still
don't support calls with dynamic shapes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106319
Approved by: https://github.com/ezyang