This PR is doing a few interrelated things, all of which are necessary to get correctness. Read the comment in torch/fx/experimental/proxy_tensor.py for the high level overview.
Let's break down the parts of this PR:
* Bug fix where `enable_torch_dispatch_mode` with `None` doesn't work. This make `enable_torch_dispatch_mode(current_mode.inner)` work which is the basis for how we temporarily disable fake tensor mode.
* Bug fix for when fake tensor mode is combined with a non-mode tensor subclass. This actually could be ablated from this PR but it affects where the logic for allowing non fake tensor inputs with lift goes, so it's all in here in one go. There are some relevant tests for the fix in fake tensor, but it turns out I didn't need this because I'm always using proxy tensors as a mode (which ensures the ordering is right.)
* New `lift_fresh` view operator. Note that like lift, we have to manually write the functionalize kernel for these functions.
* The actual change, which is to save constants when we see them in the proxy tensor mode, and then propagate them as we go (because otherwise you'll handle mutations on constants incorrectly--see test.)
This is mildly BC-breaking if anyone was previously interposing on
at::lift, but this operator was relatively new and I checked
functorch which has no explicit reference to lift. So I think it
should not be too disruptive.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81192
Approved by: https://github.com/samdow, https://github.com/bdhirsh
This should fix the last issue that @anijain2305 hit when running ResNet with TorchDynamo <> functionalization.
Today if you try to call an `OpOverloadPacket` from python with some arguments, we will use the types of those arguments to perform overload resolution. With some functional variants of ops, this can be ambiguous.
Today this affects just one op: `_fused_moving_avg_obs_fq_helper`, although it would potentially affect e.g. `native_batch_norm` in the future.
Example:
```
# There are technically two overloads:
# torch.ops.aten._fused_moving_avg_obs_fq_helper.default (returns 2 argument, mutates 4 of its inputs inplace)
# torch.ops.aten._fused_moving_avg_obs_fq_helper.functional (returns 6 argument, mutates none of its inputs)
# We pick the wrong one - no way to know that we should pick the functional one, just from the call site.
outs = torch.ops.aten._fused_moving_avg_obs_fq_helper(a, a, a, a, a, a, a, 1.0, 0, 1, 0)
# raises an error - tries to call the overload with only 2 returns
return _fused_moving_avg_obs_fq_helper_functional[5]
```
Specifically, functionalization will bake `_fused_moving_avg_obs_fq_helper.functional` into the graph, but when AOTAutograd tries to compile with TorchScript, it needs to remove the overload name (TS doesn't know how to parse overload names directly, so we need to remove the overload name and let it infer the right overload at runtime later- so it picks the wrong one).
The situation is pretty similar to inplace; `ops.aten.add` and `ops.aten.add_` represent two different `OverloadPacket` objects; they can't be overloads of the same op, because their schemas would be ambiguous - the alias annotations are different, but that isn't enough to disambiguate).
In this PR, I try to fix the situation in a pretty similar way to how we handle `inplace` in the data model: `inplace` ops get their own base operator name, but they are represented as a flag inside of `BaseOperatorName` in the data model.
Two other important changes that I made as part of this PR:
(1) Originally, there were ~100 different `*_functional` operators: e.g. we had operators named `resize.functional` and `zero.functional`. The `_functional` bit isn't actually necessary in most cases: it's only necessary for operators that **also** have a `SchemaKind.mutable` variant, where `_fused_moving_avg_obs_fq_helper` is the only op that fits that description today. So I removed the unnecessary notion of "functional" from those other ops. I also added a bunch of assertions to force this restriction.
I think that makes more sense in the long run, because it eliminates an unnecessary difference in the model. E.g. we don't have `add_.Tensor` and `add.Tensor_functional`. We just have `add_.Tensor` and `add.Tensor`.
(2) I noticed that we actually still weren't pairing up a bunch of `_foreach` operators correctly, because their input arguments were different (`self` vs. `tensors`). Since they're private API's, I went ahead and changed the argument names directly so they get matched up. Before this PR, we were generating a separate `_foreach_add` and `_foreach_add.functional` variant in a bunch of cases, that really did the same thing (but happened to have a different name for the first argument).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80556
Approved by: https://github.com/ezyang, https://github.com/albanD
Previously, we had disallowed any reused storages in fallback kernels because the correct aliasing relationships weren't set up correctly. However, we can support the case of a reused TensorImpl by just returning the input FakeTensor it corresponds to. Additionally, we have `inplace_view` as a tag which will prevent us from running fallback kernels for kernels that mutate metadata. (thanks @ezyang who originally suggested I do this).
Fix for https://github.com/pytorch/torchdynamo/issues/467
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80545
Approved by: https://github.com/ezyang
Make `MetaConverter` and `FakeTensorConverter` hold weak references to their memoized tensors, and also have `MetaConverter` hold weak reference to Tensor storage. Otherwise it can be tricky for users to make sure all existing FakeTensors or FakeTensorModes are deleted which otherwise will leak memory.
I ran into https://github.com/pytorch/pytorch/issues/7733 which I was able to get around with the following (see comment from code):
```
# torch.Tensors cannot be used as a key in a dictionary
# because they define a custom __eq__ function which when used
# to resolve hash collisions will throw when comparing tensors:
# "RuntimeError: bool value of Tensor with more than one value is ambiguous."
# To avoid that, we use an object which will hold a Tensor and use
# its id for both hashing and equality.
# In order to use this as a weak key reference, we cannot
# simply use weakref.WeakKeyDictionary because the newly constructed
# WeakTensorRefKey only use would be a dictionary so it would have no strong
# references.
# To get around this issue, we can use it as a normal key, and then set
# `weakref.finalize` to delete the key when its contained tensor dies.
```
While for the tensor memo we can set a `weakref.finalize` callback that will remove the corresponding `WeakTensorRefKey` from the tensor memo, at the point that this callback is invoked the tensor storage is not yet deallocated.. See comment from code:
```
# [expired-storages]
# NB: even though the tensor has died,
# the deallocation of its storage can take longer,
# even when the storage has no other uses/views.
# In this case, the StorageWeakRef object will be kept alive
# longer than it needs to be, however the storage itself
# will be deallocated. We retain the possibly dead storages
# and periodically check if any of them are expired and
# can be freed.
```
partial fix for https://github.com/pytorch/torchdynamo/issues/468
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80544
Approved by: https://github.com/ezyang