Commit Graph

580 Commits

Author SHA1 Message Date
Horace He
a5fb41e3d3 Revert "Revert "Refactored prim utils into _prims_utils folder"" (#81746)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81746
Approved by: https://github.com/anijain2305, https://github.com/Krovatkin
2022-07-20 23:43:57 +00:00
PyTorch MergeBot
e43a02c314 Revert "Refactored prim utils into _prims_utils folder (#81088)"
This reverts commit 80231d0a72.

Reverted https://github.com/pytorch/pytorch/pull/81088 on behalf of https://github.com/jeanschmidt due to breaking internal tests
2022-07-19 19:56:41 +00:00
Horace He
80231d0a72 Refactored prim utils into _prims_utils folder (#81088)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81088
Approved by: https://github.com/ngimel
2022-07-19 03:55:51 +00:00
Edward Z. Yang
fca03eeec1 Make proxy tensor support item() calls on torch.tensor constants (#81192)
This PR is doing a few interrelated things, all of which are necessary to get correctness. Read the comment in torch/fx/experimental/proxy_tensor.py for the high level overview.

Let's break down the parts of this PR:

* Bug fix where `enable_torch_dispatch_mode` with `None` doesn't work. This makes `enable_torch_dispatch_mode(current_mode.inner)` work, which is the basis for how we temporarily disable fake tensor mode.
* Bug fix for when fake tensor mode is combined with a non-mode tensor subclass. This could actually be split out from this PR, but it affects where the logic for allowing non-fake tensor inputs with lift goes, so it's all in here in one go. There are some relevant tests for the fix in fake tensor, but it turns out I didn't need this because I'm always using proxy tensors as a mode (which ensures the ordering is right).
* New `lift_fresh` view operator. Note that, like `lift`, we have to manually write the functionalize kernel for these functions.
* The actual change, which is to save constants when we see them in the proxy tensor mode, and then propagate them as we go (because otherwise you'll handle mutations on constants incorrectly; see the test). A sketch of this idea follows the list.
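
A minimal, hypothetical sketch of that constant-saving idea (illustrative only, not the actual proxy_tensor.py implementation; `ConstantTracker` and its methods are invented names):

```
import torch

class ConstantTracker:
    # Illustrative stand-in for the tracing mode's table of tensor constants.
    def __init__(self):
        # Maps a proxy handle to the concrete tensor constant it traces.
        self._constants = {}

    def lift(self, t):
        # Analogous in spirit to lift_fresh: record a freshly-seen tensor
        # constant so later data-dependent calls can consult its value.
        proxy = object()  # stand-in for a traced proxy value
        self._constants[proxy] = t.detach().clone()
        return proxy

    def item(self, proxy):
        # item() on a tracked constant resolves to a real Python number
        # instead of failing during tracing.
        const = self._constants.get(proxy)
        if const is None:
            raise RuntimeError("item() called on a non-constant traced value")
        return const.item()

tracker = ConstantTracker()
p = tracker.lift(torch.tensor(3))
assert tracker.item(p) == 3
```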

This is mildly BC-breaking if anyone was previously interposing on
at::lift, but this operator was relatively new, and I checked
functorch, which has no explicit reference to lift. So I think it
should not be too disruptive.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81192
Approved by: https://github.com/samdow, https://github.com/bdhirsh
2022-07-15 03:53:40 +00:00
Brian Hirsh
960758b0b7 fix overload ambiguity with functional ops; fix _foreach op grouping (#80556)
This should fix the last issue that @anijain2305 hit when running ResNet with TorchDynamo <> functionalization.

Today, if you try to call an `OpOverloadPacket` from Python with some arguments, we use the types of those arguments to perform overload resolution. With some functional variants of ops, this can be ambiguous.

Today this affects just one op: `_fused_moving_avg_obs_fq_helper`, although it would potentially affect e.g. `native_batch_norm` in the future.

Example:
```
# There are technically two overloads:
# torch.ops.aten._fused_moving_avg_obs_fq_helper.default (returns 2 outputs, mutates 4 of its inputs in place)
# torch.ops.aten._fused_moving_avg_obs_fq_helper.functional (returns 6 outputs, mutates none of its inputs)

# We pick the wrong one - there's no way to know, just from the call site,
# that we should pick the functional one.
outs = torch.ops.aten._fused_moving_avg_obs_fq_helper(a, a, a, a, a, a, a, 1.0, 0, 1, 0)
# Raises an error - the resolved overload only returns 2 outputs,
# so indexing output 5 fails.
return _fused_moving_avg_obs_fq_helper_functional[5]
```

Specifically, functionalization will bake `_fused_moving_avg_obs_fq_helper.functional` into the graph, but when AOTAutograd tries to compile with TorchScript, it needs to remove the overload name (TS doesn't know how to parse overload names directly, so we have to drop the overload name and let it infer the right overload at runtime later), and at that point TS picks the wrong one.

The situation is pretty similar to inplace: `ops.aten.add` and `ops.aten.add_` represent two different `OverloadPacket` objects; they can't be overloads of the same op, because their schemas would be ambiguous (the alias annotations are different, but that isn't enough to disambiguate).

In this PR, I try to fix the situation in a pretty similar way to how we handle `inplace` in the data model: `inplace` ops get their own base operator name, but they are represented as a flag inside of `BaseOperatorName` in the data model.
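
As a rough illustration of that data-model change (a hedged sketch, not torchgen's actual classes; the field and class layout here is simplified):

```
from dataclasses import dataclass

@dataclass(frozen=True)
class BaseOperatorName:
    # Simplified model: inplace was already a flag on the base name;
    # the functional variant now gets an analogous flag instead of an
    # ambiguous overload.
    base: str
    inplace: bool = False
    functional: bool = False

    def __str__(self):
        if self.inplace:
            return f"{self.base}_"
        if self.functional:
            return f"{self.base}_functional"
        return self.base

# The mutable and functional variants get distinct base operator names,
# so Python-side overload resolution never has to choose between
# ambiguous overloads of a single packet.
mutable = BaseOperatorName("_fused_moving_avg_obs_fq_helper")
functional = BaseOperatorName("_fused_moving_avg_obs_fq_helper", functional=True)
print(mutable, functional)
```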

Two other important changes that I made as part of this PR:

(1) Originally, there were ~100 different `*_functional` operators: e.g. we had operators named `resize.functional` and `zero.functional`. The `_functional` bit isn't actually necessary in most cases: it's only necessary for operators that **also** have a `SchemaKind.mutable` variant, where `_fused_moving_avg_obs_fq_helper` is the only op that fits that description today. So I removed the unnecessary notion of "functional" from those other ops. I also added a bunch of assertions to force this restriction.

I think that makes more sense in the long run, because it eliminates an unnecessary difference in the model. E.g. we don't have `add_.Tensor` and `add.Tensor_functional`. We just have `add_.Tensor` and `add.Tensor`.

(2) I noticed that we actually still weren't pairing up a bunch of `_foreach` operators correctly, because their input argument names were different (`self` vs. `tensors`). Since they're private APIs, I went ahead and changed the argument names directly so they get matched up. Before this PR, we were generating a separate `_foreach_add` and `_foreach_add.functional` variant in a bunch of cases that really did the same thing (but happened to have a different name for the first argument); a toy sketch of the matching problem follows.
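
To illustrate the pairing problem in (2), here's a toy, hedged sketch (`signature_key` is invented; the real matching happens in torchgen and is more involved):

```
def signature_key(name, arg_names):
    # Toy stand-in for schema matching: strip the trailing underscore
    # from an inplace op name and compare argument names positionally.
    return (name.rstrip("_"), tuple(arg_names))

# Before this PR, the inplace and functional variants failed to pair up
# because the name of the first argument differed:
inplace = signature_key("_foreach_add_", ("self", "other"))
functional = signature_key("_foreach_add", ("tensors", "other"))
assert inplace != functional  # treated as unrelated ops

# Once the first argument names agree, the variants match up:
functional = signature_key("_foreach_add", ("self", "other"))
assert inplace == functional
```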

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80556
Approved by: https://github.com/ezyang, https://github.com/albanD
2022-07-06 12:45:11 +00:00
PyTorch MergeBot
d8cd15b00c Revert "Allow returned tensor impls in fallback (#80545)"
This reverts commit d7e4520d1d.

Reverted https://github.com/pytorch/pytorch/pull/80545 on behalf of https://github.com/malfet due to New test broke rocm, see d7e4520d1d
2022-07-01 01:30:40 +00:00
Elias Ellison
d7e4520d1d Allow returned tensor impls in fallback (#80545)
Previously, we disallowed any reused storages in fallback kernels because the correct aliasing relationships weren't being set up. However, we can support the case of a reused TensorImpl by just returning the input FakeTensor it corresponds to. Additionally, we have `inplace_view` as a tag, which will prevent us from running fallback kernels for ops that mutate metadata (thanks to @ezyang, who originally suggested I do this).
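
A hedged sketch of that idea (illustrative only, not the actual FakeTensor fallback code; `run_fallback` and `to_fake` are invented names):

```
import torch

def run_fallback(op, fake_inputs, real_inputs, to_fake):
    # Illustrative fallback wrapper: run the real kernel, then map any
    # output that reuses an input's TensorImpl back to the corresponding
    # input FakeTensor, preserving the aliasing relationship.
    real_outputs = op(*real_inputs)
    if isinstance(real_outputs, torch.Tensor):
        real_outputs = (real_outputs,)
    impl_to_fake = {id(r): f for r, f in zip(real_inputs, fake_inputs)}
    results = []
    for out in real_outputs:
        fake = impl_to_fake.get(id(out))
        if fake is not None:
            # Output aliases an input: return that input's FakeTensor.
            results.append(fake)
        else:
            # Fresh output: convert it to a new FakeTensor.
            results.append(to_fake(out))
    return tuple(results)

# Usage with stand-ins for the FakeTensors and the converter:
a = torch.randn(3)
outs = run_fallback(torch.relu_, ("fake_a",), (a,), to_fake=lambda t: "new_fake")
assert outs == ("fake_a",)  # relu_ returns its input, so its fake is reused
```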

Fix for https://github.com/pytorch/torchdynamo/issues/467
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80545
Approved by: https://github.com/ezyang
2022-06-30 21:57:54 +00:00
Elias Ellison
1058b47562 Weak-ref-ify MetaConverter and FakeTensorConverter (#80544)
Make `MetaConverter` and `FakeTensorConverter` hold weak references to their memoized tensors, and also have `MetaConverter` hold a weak reference to tensor storage. Without this, it can be tricky for users to make sure all existing FakeTensors and FakeTensorModes are deleted, and failing to do so leaks memory.

I ran into https://github.com/pytorch/pytorch/issues/7733 which I was able to get around with the following (see comment from code):

```
# torch.Tensors cannot be used as a key in a dictionary
# because they define a custom __eq__ function which when used
# to resolve hash collisions will throw when comparing tensors:
# "RuntimeError: bool value of Tensor with more than one value is ambiguous."
# To avoid that, we use an object which will hold a Tensor and use
# its id for both hashing and equality.
# In order to use this as a weak key reference, we cannot
# simply use weakref.WeakKeyDictionary, because the dictionary would
# hold the only reference to the newly constructed WeakTensorRefKey,
# so the key would have no strong references and would be collected
# immediately.
# To get around this issue, we use it as a normal key, and then set up
# `weakref.finalize` to delete the key when its contained tensor dies.
```
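
A minimal runnable sketch of that `WeakTensorRefKey` pattern (simplified from the actual helper; the details here are illustrative):

```
import weakref
import torch

class WeakTensorRefKey:
    # Hashes and compares by the tensor's id while holding only a weak
    # reference to the tensor itself.
    def __init__(self, t):
        self.ref = weakref.ref(t)
        self._id = id(t)

    def __hash__(self):
        return self._id

    def __eq__(self, other):
        return isinstance(other, WeakTensorRefKey) and self._id == other._id

memo = {}

def memoize(t, value):
    key = WeakTensorRefKey(t)
    memo[key] = value
    # Evict the memo entry once the tensor itself is garbage collected.
    weakref.finalize(t, memo.pop, key, None)

x = torch.randn(2)
memoize(x, "converted")
assert len(memo) == 1
del x  # the finalize callback fires, removing the stale entry
assert len(memo) == 0
```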

While for the tensor memo we can set a `weakref.finalize` callback that removes the corresponding `WeakTensorRefKey` from the tensor memo, at the point this callback is invoked the tensor's storage may not yet be deallocated. See the comment from the code:

```
# [expired-storages]
# NB: even though the tensor has died,
# the deallocation of its storage can take longer,
# even when the storage has no other uses/views.
# In this case, the StorageWeakRef object will be kept alive
# longer than it needs to be, however the storage itself
# will be deallocated. We retain the possibly dead storages
# and periodically check if any of them are expired and
# can be freed.
```
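
A generic, hedged sketch of that periodic-pruning pattern (illustrative; `MaybeDeadRef` and friends are invented, not the actual StorageWeakRef machinery):

```
import weakref

class MaybeDeadRef:
    # Wraps a weak reference whose referent may outlive its owner.
    def __init__(self, obj):
        self._ref = weakref.ref(obj)

    def expired(self):
        return self._ref() is None

pending = []  # possibly-dead references we retain

def track(obj):
    pending.append(MaybeDeadRef(obj))

def prune_expired():
    # Periodically drop references whose targets have finally died.
    pending[:] = [r for r in pending if not r.expired()]

class Payload:  # stand-in for a storage-like object
    pass

p = Payload()
track(p)
prune_expired()
assert len(pending) == 1
del p
prune_expired()
assert len(pending) == 0
```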

partial fix for https://github.com/pytorch/torchdynamo/issues/468
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80544
Approved by: https://github.com/ezyang
2022-06-29 23:36:35 +00:00
Elias Ellison
c1d11f40c9 [FakeTensor] Use the device of the meta tensor for fallback kernel (#80193)
Otherwise, I would run into issues such as "fp16 not supported for Conv Cpu" or "cudnn_rnn not implemented for cpu".
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80193
Approved by: https://github.com/Chillee
2022-06-24 20:00:07 +00:00
PyTorch MergeBot
35268bdc2a Revert "[FakeTensor] Use the device of the meta tensor for fallback kernel (#80193)"
This reverts commit 93e70c5973.

Reverted https://github.com/pytorch/pytorch/pull/80193 on behalf of https://github.com/b0noI due to broken test: https://github.com/pytorch/pytorch/runs/7035945243?check_suite_focus=true
2022-06-24 14:53:37 +00:00
Elias Ellison
93e70c5973 [FakeTensor] Use the device of the meta tensor for fallback kernel (#80193)
Otherwise, I would run into issues such as "fp16 not supported for Conv Cpu" or "cudnn_rnn not implemented for cpu".
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80193
Approved by: https://github.com/Chillee
2022-06-24 03:53:21 +00:00
Elias Ellison
7e54959a8a Add support for indexing cuda tensors with cpu (#80115)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80115
Approved by: https://github.com/Chillee
2022-06-23 22:46:40 +00:00
Elias Ellison
a6b783e714 Refine conditions under which index.Tensor has a dynamic shape
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79809

Approved by: https://github.com/davidberard98
2022-06-20 22:50:03 +00:00
Elias Ellison
9705fb03b3 Add support for a couple ops
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79581

Approved by: https://github.com/Chillee
2022-06-20 22:25:39 +00:00
Elias Ellison
268bbecf1c Add option for allowing non-fake inputs, add deepcopy impl
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79580

Approved by: https://github.com/samdow
2022-06-17 19:36:26 +00:00
Elias Ellison
13a8867c01 Add Dynamic Output Shape Tag for data-dependent ops, handle in FakeTensor
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79170

Approved by: https://github.com/ezyang
2022-06-09 22:16:16 +00:00
Elias Ellison
3c5a3ca9e8 Make FakeTensors return meta within kernel invocation, add FakeTensor op tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78972

Approved by: https://github.com/ezyang
2022-06-09 01:39:27 +00:00
Elias Ellison
d6ecdf1605 refactor op handling to use register pattern
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78523

Approved by: https://github.com/ezyang
2022-06-08 22:37:50 +00:00
Elias Ellison
fe7a13496e Add CPU Fallback
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78522

Approved by: https://github.com/ezyang
2022-06-08 22:35:13 +00:00
Elias Ellison
290d0979f1 Migrate FakeTensors to always call into FakeTensorMode and have them hold a reference
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78677

Approved by: https://github.com/ezyang
2022-06-08 22:30:34 +00:00
Elias Ellison
a711ce4a2a add non-kwarg device and _like constructors
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78536

Approved by: https://github.com/ezyang
2022-06-07 20:36:41 +00:00
PyTorch MergeBot
0df77f320f Revert "add non-kwarg device and _like constructors"
This reverts commit 44937da6db.

Reverted https://github.com/pytorch/pytorch/pull/78536 on behalf of https://github.com/janeyx99 due to Broke meta tests on trunk and on PR https://github.com/pytorch/pytorch/runs/6765692797?check_suite_focus=true
2022-06-07 04:11:05 +00:00
Elias Ellison
44937da6db add non-kwarg device and _like constructors
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78536

Approved by: https://github.com/ezyang
2022-06-07 00:28:59 +00:00
Edward Z. Yang
587efdb5fa Replace TensorMeta with FakeTensor
Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78836

Approved by: https://github.com/albanD, https://github.com/mruberry
2022-06-05 11:51:27 +00:00
Elias Ellison
26d273959c Add Caching of Conversion to Fake/Meta tensors in FakeTensorMode
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78090

Approved by: https://github.com/ezyang
2022-06-03 13:56:00 +00:00
Elias Ellison
6671b504f7 Modernize FakeTensorMode, throw on non-fake inputs
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78516

Approved by: https://github.com/samdow
2022-06-01 21:43:59 +00:00
Elias Ellison
cea7dd1646 Add FakeTensorMode
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77972

Approved by: https://github.com/ezyang
2022-05-31 16:27:06 +00:00
Elias Ellison
4c18f362a9 add support for type_as/_to_copy
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77971

Approved by: https://github.com/ezyang
2022-05-31 16:25:11 +00:00
Elias Ellison
98e0816986 Extend __new__ on subclasses to set custom_device and custom_strides
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77970

Approved by: https://github.com/Chillee
2022-05-31 16:23:18 +00:00
Elias Ellison
678213ead2 Fake Tensor Part 1
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77969

Approved by: https://github.com/ezyang
2022-05-31 16:20:35 +00:00