Commit Graph

76 Commits

Author SHA1 Message Date
Brian Hirsh
440a3f2398 fix set_() with functionalization (#90722)
This should fix https://github.com/pytorch/pytorch/issues/90573
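
As a hedged sketch (the real repro is in the linked issue; this shape is an assumption), the kind of program this makes work: `set_()` repoints one tensor at another tensor's storage, which functionalization must now model instead of miscomputing.

```
import torch
# functionalize lived under functorch.experimental around the time of this commit
from functorch.experimental import functionalize

def f(x):
    y = torch.empty(0)
    y.set_(x)        # y now shares x's storage, sizes, and strides
    return y.add(1)

out = functionalize(f)(torch.zeros(3))
```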

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90722
Approved by: https://github.com/ezyang
2022-12-19 16:11:06 +00:00
Richard Zou
4068c5467d [Reland] Move functorch/_src to torch/_functorch (#88756) (#90091)
This will be the last disruptive functorch internals change.

Why are we moving these files?
- As a part of rationalizing functorch we are moving the code in
functorch/_src to torch/_functorch
- This is so that we can offer the functorch APIs as native PyTorch APIs
(coming soon) and resolve some internal build issues.

Why are we moving all of these files at once?
- It's better to break developers all at once rather than many times

Test Plan:
- wait for tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90091
Approved by: https://github.com/anijain2305, https://github.com/ezyang
2022-12-03 14:17:15 +00:00
Jane Xu
76e869c911 [BE] Beef up test_functionalization to test functionalizing multi-parameter functions (#89798)
Previously, `assert_functionalization` only accepted functions with a single Tensor parameter. This PR beefs up the check to allow functions that take multiple parameters.

This PR also changes the test_instance_norm test to check that the multi-parameter change works.
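
As a hedged sketch of the shape of check this enables (illustrative names, not the actual test helpers): a two-Tensor-parameter function with an input mutation, compared with and without functionalization.

```
import torch
from functorch.experimental import functionalize

# A function taking two Tensor parameters and mutating one of them;
# the kind of signature assert_functionalization previously couldn't accept.
def f(x, running_mean):
    running_mean.mul_(0.9).add_(x.mean())
    return x - x.mean()

x, m = torch.randn(4), torch.zeros(1)
ref = f(x.clone(), m.clone())
out = functionalize(f)(x.clone(), m.clone())
assert torch.allclose(ref, out)
```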

## Test plan
Locally tested, CI should also pass.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89798
Approved by: https://github.com/samdow
2022-11-30 20:46:16 +00:00
Jane Xu
fcb5d6e771 Enable instance norm running mean test (#89793)
Followup action to https://github.com/pytorch/pytorch/pull/88697
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89793
Approved by: https://github.com/bdhirsh
2022-11-29 23:45:56 +00:00
PyTorch MergeBot
218d9c6e09 Revert "Move functorch/_src to torch/_functorch (#88756)"
This reverts commit 52bc5c1cfe.

Reverted https://github.com/pytorch/pytorch/pull/88756 on behalf of https://github.com/clee2000 due to broken imports in tests (52bc5c1cfe, https://github.com/pytorch/pytorch/actions/runs/3574742513/jobs/6010814968); probably a landrace
2022-11-29 17:17:11 +00:00
Richard Zou
52bc5c1cfe Move functorch/_src to torch/_functorch (#88756)
This will be the last disruptive functorch internals change.

Why are we moving these files?
- As a part of rationalizing functorch we are moving the code in
functorch/_src to torch/_functorch
- This is so that we can offer the functorch APIs as native PyTorch APIs
(coming soon) and resolve some internal build issues.

Why are we moving all of these files at once?
- It's better to break developers all at once rather than many times

Test Plan:
- wait for tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88756
Approved by: https://github.com/ezyang
2022-11-29 13:55:42 +00:00
Jane Xu
8695f0cced Rectify native_batch_norm schema by splitting it into two legit schemas (#88697)
Using the same repro from the issue (but with BatchNorm2D), this rectifies the native_batch_norm schema by splitting it into two:
1. one will have NON-optional alias-able running_mean and running_var inputs
2. the other will just not have those parameters at all (no_stats variation)
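
Sketched below as schema strings (the names follow the PR; the exact signatures are assumptions based on `native_batch_norm`'s existing schema):

```
# 1) stats variant: running_mean/running_var are required and mutably aliased
#    native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias,
#        Tensor(a!) running_mean, Tensor(b!) running_var,
#        bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)
# 2) no-stats variant: the running-stat parameters are dropped entirely
#    native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias,
#        bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)
```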

**Calling for name suggestions!**

## test plan
I've added tests in test_functionalization.py as well as an entry in common_method_invocations.py for `native_batch_norm_legit`
CI should pass.

## next steps
For backward/forward compatibility (BC/FC) reasons, we reroute native_batch_norm to call our new schemas ONLY through the Python dispatcher, but in two weeks or so, we should make `native_batch_norm_legit` the official batch_norm.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88697
Approved by: https://github.com/albanD
2022-11-23 23:23:17 +00:00
Edward Z. Yang
5266953443 Add crossref debug mode for functionalization, catches stride errors (#89498)
The idea is to add a custom handler to the Functionalize key in the Python
dispatcher that runs the functionalized version alongside a non-functionalized
version, and checks that their outputs agree in the end.  (Technically, for
metadata mutation we should also check the inputs, but for now we're relying
on those functions returning self.)
I turned this on for test_functionalize.py (new TestCrossRefFunctionalize)
and found a bunch of failures that look legit.
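
A minimal sketch of the crossref idea (not the actual Python-dispatcher handler): run a function with and without functionalization and compare output metadata and values. It is exactly this metadata comparison that surfaces stride errors.

```
import torch
from functorch.experimental import functionalize

def crossref_check(fn, *args):
    ref = fn(*[a.clone() for a in args])                 # plain eager run
    out = functionalize(fn)(*[a.clone() for a in args])  # functionalized run
    assert ref.shape == out.shape and ref.stride() == out.stride(), \
        "functionalization changed output metadata (e.g. strides)"
    assert torch.allclose(ref, out), "functionalization changed output values"
    return out

crossref_check(lambda x: x.mul(2).add(1), torch.randn(2, 3))
```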

This probably doesn't interact that nicely if you're also tracing at
the same time, probably need more special logic for that (directly,
just disabling tracing for when we create the nested fake tensor mode,
but IDK if there's a more principled way to organize this.)

There are some misc fixups which I can split if people really want.

- xfail_inherited_tests moved to test common_utils
- Bindings for _dispatch_tls_set_dispatch_key_included,
  _dispatch_tls_is_dispatch_key_included and _functionalization_reapply_views_tls
- Type stubs for _enable_functionalization, _disable_functionalization
- all_known_overloads utility to let you iterate over all OpOverloads
  in all namespaces.  Iterator support on all torch._ops objects to let
  you iterate over their members.
- suspend_functionalization lets you temporarily disable functionalization mode
  in a context
- check_metadata_matches for easily comparing outputs of functions and see
  if they match (TODO: there are a few copies of this logic, consolidate!)
- _fmt for easily printing the metadata of a tensor without its data
- _uncache_dispatch for removing a particular dispatch key from the cache,
  so that we force it to regenerate
- check_significant_strides new kwarg only_cuda to let you also do stride
  test even when inputs are not CUDA
- Functionalize in torch._C.DispatchKey

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89498
Approved by: https://github.com/malfet
2022-11-23 04:18:25 +00:00
Edward Z. Yang
d9cbe7764e Make aten.copy preserve strides (hf_Longformer) (#89464)
Fixes https://github.com/pytorch/torchdynamo/issues/1888
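
A hedged illustration of the invariant (the eager `copy_` semantics that the functional `aten.copy` now matches): the destination's strides survive the copy.

```
import torch

dst = torch.empty(3, 4).t()   # shape (4, 3), strides (1, 4): non-contiguous
src = torch.randn(4, 3)
dst.copy_(src)                # copies values; dst's layout is unchanged
assert dst.stride() == (1, 4)
```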

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Differential Revision: [D41460986](https://our.internmc.facebook.com/intern/diff/D41460986)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89464
Approved by: https://github.com/bdhirsh
2022-11-22 13:06:43 +00:00
Brian Hirsh
ec4eadac5b reland "Do not use unsafe restriding for subclasses (#87610)" (#88343)
This reverts commit 5b75b19f51.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88343
Approved by: https://github.com/ezyang
2022-11-14 13:42:51 +00:00
Edward Z. Yang
0e3031f7e7 Functionalize and compute joint simultaneously. (#88063)
This also comes with some bug fixes that were uncovered from doing
this:

- Forward device calls to inner tensor in FunctionalTensorWrapper

- Make legacyExtractDispatchKey exclude Functionalize, so that
  it can get at the real device type key.  This is noncontroversial.

- Stop stripping dense from key set.  The reason for this is
  FunctionalTensorWrapper may be used in contexts where people
  query if it is dense or not.  If it doesn't report this correctly
  (from the dispatch key), it will cause errors.  This caused some
  torchbench models to fail when I did one-pass tracing.

- Save and restore reapply views TLS correctly

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88063
Approved by: https://github.com/bdhirsh
2022-11-05 03:52:40 +00:00
PyTorch MergeBot
5b75b19f51 Revert "Do not use unsafe restriding for subclasses (#87610)"
This reverts commit 73379acaf3.

Reverted https://github.com/pytorch/pytorch/pull/87610 on behalf of https://github.com/mehtanirav due to [Internal breakages](https://www.internalfb.com/intern/sandcastle/job/36028797828925790/insights)
2022-11-02 16:59:02 +00:00
Edward Z. Yang
bb7e6254e4 Add ability to freeze storages inside functionalization (#88141)
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88141
Approved by: https://github.com/albanD, https://github.com/bdhirsh
2022-11-01 16:00:33 +00:00
Brian Hirsh
73379acaf3 Do not use unsafe restriding for subclasses (#87610)
This helps convert some accuracy errors into runtime errors,
which makes it easier to debug.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87610
Approved by: https://github.com/albanD
2022-10-31 20:49:15 +00:00
Brian Hirsh
23ff47ccc5 functionalization: fix detach() (#87750)
`.detach()` worked in basic cases previously, but didn't properly preserve view relationships between the base and the output. This wasn't heavily tested, because autograd doesn't normally encounter `FunctionalTensorWrapper` directly, but could become more common if we fuse functionalization and autograd into a single tracing pass.
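
A hedged sketch of the eager aliasing behavior that `FunctionalTensorWrapper` must reproduce: `detach()` returns a view of its base, so (outside of autograd tracking) mutations through the detached tensor are visible in the base.

```
import torch

base = torch.zeros(3)
d = base.detach()   # a view: shares storage with base
d.add_(1)
assert torch.equal(base, torch.ones(3))
```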

This will also be a bug fix for LTC (and XLA when they use functionalization)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87750
Approved by: https://github.com/ezyang
2022-10-27 15:47:56 +00:00
Brian Hirsh
9ad1659b17 functionalization: make view_copy outputs always contiguous (#85747)
This fixes an issue with mobile: the output of view_copy ops should always be contiguous.
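
A hedged illustration using one of the generated `*_copy` ops (assuming the behavior after this change):

```
import torch

base = torch.randn(4, 4)
v = base.transpose(0, 1)                        # a view; not contiguous
vc = torch.ops.aten.transpose_copy(base, 0, 1)  # its view_copy counterpart
assert not v.is_contiguous() and vc.is_contiguous()
```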

Later, we can consider adding optional arguments to the `view_copy()` functions to let you explicitly say what the contiguity of the output can be (e.g. channels_last)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85747
Approved by: https://github.com/ezyang
2022-10-21 17:42:02 +00:00
Horace He
a27a4a02fe Refactored proxytensor to clean up separate branches (#84325)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84325
Approved by: https://github.com/ezyang
2022-08-31 09:37:55 +00:00
Brian Hirsh
e9e7363854 reinplacing pass fixes for torchbench + huggingface (#83626)
I'm testing out turning on re-inplacing + functionalization by default with the AOTAutograd + eager backend on torchbench + huggingface models. This PR contains a few bug fixes from turning re-inplacing on:

(1) Handle more gracefully when FakeTensorMode is already turned on when you call reinplace

(2) More robust detection of when an inplace variant of an op exists (the dumb bug was that `pow.Scalar` doesn't have an inplace variant, even though there are several overloads of `pow_`; none of them are eligible).

(3) Avoid re-inplacing when it would require resizing the input buffer. This isn't allowed, because inplace ops aren't allowed to resize their inputs.

For the last one, I gave the two main examples in more detail in the comments. Important cases are:
```
# This should not be re-inplaced at all; the op broadcasts, so this would require resizing the self tensor
torch.add(tensor[1, 4], tensor[4, 4])

# This should not be re-inplaced, because the inplace and out-of-place variants of the op return different dtypes
torch.ge(a, b)
# However, this means that today when functionalization functionalizes a `torch.ge_(a, b)` call, reinplacing won't properly de-functionalize it. I noted in the comments that this optimization is worth adding later.
```

(4) There's some logic around keeping `storage_to_nodes` up to date when we see a view op: if we re-inplace `out = a.add(...)` and later in the program encounter a node `out.view(...)`, we need to replace it with `a.view(...)` and update some metadata structures. I had to fix that logic: specifically, when the later node isn't a dispatcher op (e.g. it's an FX output node), I wasn't properly handling the case where the node's fake_meta info was not a tensor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83626
Approved by: https://github.com/ezyang
2022-08-19 23:30:45 +00:00
Edward Z. Yang
988bd0173c Add OpOverload.decompose API (#83075)
This allows you to directly call into the CompositeImplicitAutograd
implementation of an operator, *without* changing any aspects of the
dispatcher state.  In particular, you can use this to recursively call
into a decomposition, dispatching back to your tensor subclass/mode
as desired.
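
A hedged usage sketch (assuming `matmul` keeps a CompositeImplicitAutograd implementation that decomposes to `mm`/`bmm`):

```
import torch

a, b = torch.randn(3, 4), torch.randn(4, 5)
# Run matmul's CompositeImplicitAutograd body directly, without
# changing any dispatcher state.
out = torch.ops.aten.matmul.default.decompose(a, b)
assert torch.allclose(out, a @ b)
```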

Hypothetically, we should also make these available in the
decompositions dictionary, but I'm leaving this as future work as
enumerating these decompositions is annoying (as operators are lazily
registered.)

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83075
Approved by: https://github.com/albanD
2022-08-09 18:53:19 +00:00
Peter Bell
4f255dbfb3 Remove manual bindings for arange (#81380)
The functional variant of one of the `arange` overloads has a schema mismatch with the out variant. The functional one has `Scalar step`, but the corresponding out variant has `Scalar step=1`. This isn't allowed, so it had to be special-cased in the python codegen and manually bound. This adds the default `step` value to the functional overload and removes the special-casing.
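
For reference, the mismatch sketched as schema strings (overload names as they appear in `native_functions.yaml`; treat the exact spellings as assumptions):

```
# functional overload: step had no default
#   arange.start_step(Scalar start, Scalar end, Scalar step, ...) -> Tensor
# out overload: step defaulted to 1
#   arange.start_out(Scalar start, Scalar end, Scalar step=1, *, Tensor(a!) out) -> Tensor(a!)
# The fix adds `step=1` to the functional overload so the pair matches.
```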

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81380
Approved by: https://github.com/ngimel
2022-08-07 00:10:27 +00:00
Peter Bell
2c2278a960 Make python TensorOption signatures consistent with JIT schemas (#82241)
Fixes #81774

`TensorOptions` arguments in the JIT schema are optional, but in the Python API these were being translated to non-optional but with a default value. This change makes the arguments accept `None` for consistency with the JIT schema. However, it also means that `dtype=c10::nullopt` was previously completely untested so this also fixes several related bugs.
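
A hedged example of the user-visible effect: passing `None` for a TensorOptions argument now behaves like omitting it, matching the JIT schema's optional types.

```
import torch

t = torch.ones(2, dtype=None, device=None)  # previously an untested path
assert t.dtype == torch.get_default_dtype()
```
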
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82241
Approved by: https://github.com/ngimel
2022-08-07 00:10:27 +00:00
Brian Hirsh
d362b8e9e6 reland "add a reinplacing FX pass (#80897)" (#82407)
fixes #81457
fixes #81216
fixes #81212
fixes #81207
fixes #81206
fixes #81218
fixes #81203
fixes #81202
fixes #81214
fixes #81220
fixes #81205
fixes #81200
fixes #81204
fixes #81221
fixes #81209
fixes #81210
fixes #81215
fixes #81217
fixes #81222
fixes #81211
fixes #81201
fixes #81208

As part of this PR I'm also re-enabling all of the functionalization tests that got marked as flaky in CI (they're not actually flaky - I think they got marked because a PR that should have changed their expect-test output made it to master without the changes. I'll let CI run on this PR to confirm though).

reland of https://github.com/pytorch/pytorch/pull/80897
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82407
Approved by: https://github.com/ezyang
2022-08-02 18:03:29 +00:00
Brian Hirsh
7eed83e016 fix functionalization handling for mixed functional/nonfunctional tensorlists (#82326)
There's an existing assert in functionalization that's probably too restrictive: when you pass an op a list of tensors that mixes functional and nonfunctional tensors, we should just selectively unwrap the functional tensors and call the op rather than erroring.

I added a test for it in `test_functionalization.py` - it looks like this behavior can also show up when tracing with `make_fx()`, when constants get baked in as module properties, which don't get wrapped up when you try to functionalize the module's forward function.
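
A hedged repro sketch of the `make_fx` case described above (a module constant baked in as a plain tensor):

```
import torch
from functorch.experimental import functionalize

const = torch.ones(2)  # stands in for a constant baked into a traced module

def f(x):
    # cat() receives one functionalized tensor (x) and one plain tensor (const)
    return torch.cat([x, const])

out = functionalize(f)(torch.zeros(2))
```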

Should fix the last of https://github.com/pytorch/torchdynamo/issues/88#issuecomment-1193059940

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82326
Approved by: https://github.com/ezyang
2022-07-29 17:28:19 +00:00
Edward Z. Yang
2f95b61cea Revert "Revert "Make factory functions CompositeExplicitAutograd (#82251)"" (#82470)
This reverts commit 1df307f334.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82470
Approved by: https://github.com/zou3519
2022-07-29 17:06:07 +00:00
PyTorch MergeBot
e3243203b0 Revert "Add Python to CompositeImplicitAutograd (#82333)"
This reverts commit 1a20c69385.

Reverted https://github.com/pytorch/pytorch/pull/82333 on behalf of https://github.com/osalpekar due to Failing executorch tests internally D38252636 due to changes in graph tracing
2022-07-29 00:46:27 +00:00
Edward Z. Yang
1a20c69385 Add Python to CompositeImplicitAutograd (#82333)
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82333
Approved by: https://github.com/zou3519
2022-07-28 18:18:51 +00:00
PyTorch MergeBot
df36ccbd81 Revert "add a reinplacing FX pass (#80897)"
This reverts commit 3ef7a6921d.

Reverted https://github.com/pytorch/pytorch/pull/80897 on behalf of https://github.com/malfet due to broke windows trunk tests, see 3ef7a6921d
2022-07-27 22:32:03 +00:00
Brian Hirsh
3ef7a6921d add a reinplacing FX pass (#80897)
Adds a "reinplacing" FX transform, that goes through an FX graph and tries to convert out-of-place op calls into inplace calls whenever possible.

Followups from this PR include:
- Set up torchbench, and run the whole torchbench suite using AOTAutograd + functionalize + re-inplacing transforms to surface any issues (this is what I'm currently working on). Right now, I have some basic unit tests just to sanity check that the general logic makes sense.
- Add any missing inplace ops. This is mostly the `*_scatter*` ops, e.g. `diagonal_scatter_`, because these ops will commonly show up in an FX graph after running functionalization.

The criteria for when you can swap an op `b = a.add(...)` with `a.add_(...)` are:
(1) An inplace variant of the operator with the same schema needs to exist (`aten.add` -> `aten.add_`)
(2) `a` (**or any of its aliases**) can't be used as an input to any other operators later on in the graph
(3) `a` can't be one of the inputs to the entire graph. It also can't be an **alias** of any of the inputs ***

*** One thing to note: (3) means that we can't technically guarantee that we'll get back **all** memory usage that we lost from functionalization. Functionalization converts input mutations into out-of-place calls, and then adds a `copy_()` to the end of the graph to preserve semantics.

I added logic to handle `copy_()` in this PR because it's a pretty important optimization in the context of functionalization: any program that performs input mutations will have a `copy_()` in it after running functionalization.

There are some examples in the test file, but I think staring at an example of where re-inplacing is/isn't allowed to run is helpful:
```
# Before functionalization
def foo(a):
    tmp1 = a.add_(1)
    tmp2 = a.add(2)

# After functionalization
def foo(a):
    tmp1 = a.add(1)
    tmp2 = a.add(2)
    ...
    a.copy_(tmp1)

# After re-inplacing
def foo(a):
    # The first add() is safe to re-inplace even though a is a program input,
    # because a's data is overwritten later by a copy_().
    tmp1 = a.add_(1)
    # The second add() is NOT safe to re-inplace, because:
    # (1) a and tmp1 are aliased. Note that they weren't aliased in the
    #     original program, but they are now that we've done some re-inplacing.
    # (2) tmp1 is used as an input later in the program.
    tmp2 = a.add(2)
    ...
    a.copy_(tmp1)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80897
Approved by: https://github.com/ezyang
2022-07-27 19:11:15 +00:00
Nikolay Korovaiko
d2c47d559c Revert "Revert "Enabling SymInt in autograd; take 3 (#81145)"" ; make sure is_intlist checks for symintnodes (#82189)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82189
Approved by: https://github.com/ezyang
2022-07-26 20:47:11 +00:00
PyTorch MergeBot
c078476eb0 Revert "Enabling SymInt in autograd; take 3 (#81145)"
This reverts commit 032facd6e6.

Reverted https://github.com/pytorch/pytorch/pull/81145 on behalf of https://github.com/jeanschmidt due to breaking internal builds
2022-07-22 11:15:20 +00:00
Nikolay Korovaiko
032facd6e6 Enabling SymInt in autograd; take 3 (#81145)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81145
Approved by: https://github.com/ezyang
2022-07-22 00:14:50 +00:00
Brian Hirsh
ec77d35bda remove backend keys from FunctionalTensorWrapper, update TensorImpl::is_device methods (#81471)
It's kinda annoying to have wrapper subclass tensors (like `FunctionalTensorWrapper`) include backend dispatch keys in their keyset, because when we occasionally write something buggy, we'll send the wrapper tensor to the backend kernel (which usually segfaults). By ensuring that wrapper tensors don't get backend keys, we'll get a nicer error when that happens.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81471
Approved by: https://github.com/ezyang
2022-07-21 21:47:29 +00:00
Brian Hirsh
5bd7abf281 functionalization: fix for mutable ops with different type promotion rules (#81702)
fixes https://github.com/pytorch/pytorch/issues/81618

At some point it looks like this became broken (you can see the updated expect test looks better now, and the original was just returning a constant).

I also got a repro that was failing with an assert, that I confirmed now passes:

```
import torch
from torch.fx.experimental.proxy_tensor import make_fx
# functionalize lived under functorch.experimental around the time of this PR
from functorch.experimental import functionalize

def foo(t, y):
    out_1 = torch.ones(1)
    return torch.add(t, y, out=out_1)

g = make_fx(functionalize(foo))(torch.tensor([1]), torch.tensor([1]))
print(g.code)
out1 = functionalize(foo)(torch.tensor([1]), torch.tensor([1]))
out2 = foo(torch.tensor([1]), torch.tensor([1]))
print(out1 == out2)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81702
Approved by: https://github.com/ezyang
2022-07-19 18:52:35 +00:00
Edward Z. Yang
c09617f98f Revert "Revert "Remove python key when setting functional tensor metadata (#81401)"" (#81456)
This reverts commit f2bb25a758.

For the gory story see https://github.com/pytorch/pytorch/issues/73537
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81456
Approved by: https://github.com/Chillee
2022-07-15 03:53:40 +00:00
Edward Z. Yang
cce2f0d0e4 Disable test_functionalization.py under torchdynamo (#81458)
Tracked at https://github.com/pytorch/pytorch/issues/81457

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81458
Approved by: https://github.com/anijain2305
2022-07-14 16:56:56 +00:00
PyTorch MergeBot
f2bb25a758 Revert "Remove python key when setting functional tensor metadata (#81401)"
This reverts commit b0199c06f6.

Reverted https://github.com/pytorch/pytorch/pull/81401 on behalf of https://github.com/clee2000 due to broke trunk win force_on_cpu tests https://github.com/pytorch/pytorch/runs/7329017706?check_suite_focus=true
2022-07-13 21:55:47 +00:00
Edward Z. Yang
b0199c06f6 Remove python key when setting functional tensor metadata (#81401)
Fixes https://github.com/pytorch/pytorch/issues/81365

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81401
Approved by: https://github.com/bdhirsh
2022-07-13 19:57:40 +00:00
Brian Hirsh
f2dcb11bac basic SymInt test for functionalization (#80418)
`expand` is one of a handful of ops with SymInt support today, so this PR gives a basic test that shows functionalization properly mapping `expand.SymInt` -> `expand_copy.SymInt`. I added the logic to handle this properly in https://github.com/pytorch/pytorch/pull/80251, but didn't add a test for it. (see the [code](https://github.com/pytorch/pytorch/pull/80251/files#diff-da7d91d9e59774e3ee8d120a0f97e52058b73125fd7edd55b5c2e71d4ce5629dR330))
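
A hedged sketch of the mapping under test, with concrete shapes rather than SymInts (the symbolic case exercises the `expand.SymInt` overload):

```
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from functorch.experimental import functionalize

def f(x):
    return x.expand(4, 3) + 1

g = make_fx(functionalize(f))(torch.randn(1, 3))
print(g.code)  # the traced graph calls expand_copy rather than expand
```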

I want to add a more comprehensive test that also shows something more E2E (using `PySymInt`'s to avoid baking in shapes, running functionalization, and fx-tracing the output to show that functionalization ran properly), but I think it's currently blocked on some other work.

At least today, `FakeSymbolicTensor` doesn't play well with `make_fx` (but, as @Chillee asked, should it?)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80418
Approved by: https://github.com/ezyang, https://github.com/albanD
2022-07-12 01:46:16 +00:00
Brian Hirsh
f84b30f790 fix functionalization regression introduced by ProxyTorchDispatchMode, migrate testing to make_fx (#80416)
`ProxyTorchDispatchMode` was added recently as part of `make_fx`, which was secretly causing the meta tensor calls used inside of functionalization to get baked into the graph. It also wasn't caught because the functionalization tests in core don't use `make_fx`, and the tests in functorch aren't as comprehensive.

Now that `make_fx` is in core, I also ported the functionalization test infra over to use it, which would have caught the regression. This also makes the tests cleaner, since mode-based tracing lets us pick up factory functions in the trace output.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80416
Approved by: https://github.com/ezyang, https://github.com/albanD
2022-07-12 01:46:16 +00:00
Animesh Jain
1d90d6ee60 Setup for running PyTorch tests with TorchDynamo and skips for known failing tests (#80106)
@ezyang I am going to keep adding more skips in this PR for now. And once we have the CI running, I will replace with the appropriate decorators.

cc @mlazos , we should add those tests in test_ops.py in this PR as well

cc @jansel
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80106
Approved by: https://github.com/ezyang, https://github.com/jansel
2022-07-07 18:57:33 +00:00
Brian Hirsh
960758b0b7 fix overload ambiguity with functional ops; fix _foreach op grouping (#80556)
This should fix the last issue that @anijain2305 hit when running ResNet with TorchDynamo <> functionalization.

Today if you try to call an `OpOverloadPacket` from python with some arguments, we will use the types of those arguments to perform overload resolution. With some functional variants of ops, this can be ambiguous.

Today this affects just one op: `_fused_moving_avg_obs_fq_helper`, although it would potentially affect e.g. `native_batch_norm` in the future.

Example:
```
# There are technically two overloads:
# torch.ops.aten._fused_moving_avg_obs_fq_helper.default (returns 2 arguments, mutates 4 of its inputs in place)
# torch.ops.aten._fused_moving_avg_obs_fq_helper.functional (returns 6 arguments, mutates none of its inputs)

# We pick the wrong one - no way to know that we should pick the functional one, just from the call site.
outs = torch.ops.aten._fused_moving_avg_obs_fq_helper(a, a, a, a, a, a, a, 1.0, 0, 1, 0)
# raises an error - tries to call the overload with only 2 returns
return _fused_moving_avg_obs_fq_helper_functional[5]
```

Specifically, functionalization will bake `_fused_moving_avg_obs_fq_helper.functional` into the graph, but when AOTAutograd tries to compile with TorchScript, it needs to remove the overload name (TS doesn't know how to parse overload names directly, so we need to remove the overload name and let it infer the right overload at runtime later), and it picks the wrong one.

The situation is pretty similar to inplace: `ops.aten.add` and `ops.aten.add_` represent two different `OverloadPacket` objects; they can't be overloads of the same op, because their schemas would be ambiguous (the alias annotations are different, but that isn't enough to disambiguate).
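
A hedged illustration of that data-model point: `add` and `add_` are separate `OpOverloadPacket`s, not overloads of one packet.

```
import torch

print(torch.ops.aten.add.Tensor)   # aten::add.Tensor (functional packet)
print(torch.ops.aten.add_.Tensor)  # aten::add_.Tensor (separate inplace packet)
# Overload resolution happens within a packet, so the inplace/functional
# split never has to be disambiguated from argument types alone.
```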

In this PR, I try to fix the situation in a pretty similar way to how we handle `inplace` in the data model: `inplace` ops get their own base operator name, but they are represented as a flag inside of `BaseOperatorName` in the data model.

Two other important changes that I made as part of this PR:

(1) Originally, there were ~100 different `*_functional` operators: e.g. we had operators named `resize.functional` and `zero.functional`. The `_functional` bit isn't actually necessary in most cases: it's only necessary for operators that **also** have a `SchemaKind.mutable` variant, where `_fused_moving_avg_obs_fq_helper` is the only op that fits that description today. So I removed the unnecessary notion of "functional" from those other ops. I also added a bunch of assertions to force this restriction.

I think that makes more sense in the long run, because it eliminates an unnecessary difference in the model. E.g. we don't have `add_.Tensor` and `add.Tensor_functional`. We just have `add_.Tensor` and `add.Tensor`.

(2) I noticed that we actually still weren't pairing up a bunch of `_foreach` operators correctly, because their input arguments were different (`self` vs. `tensors`). Since they're private API's, I went ahead and changed the argument names directly so they get matched up. Before this PR, we were generating a separate `_foreach_add` and `_foreach_add.functional` variant in a bunch of cases, that really did the same thing (but happened to have a different name for the first argument).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80556
Approved by: https://github.com/ezyang, https://github.com/albanD
2022-07-06 12:45:11 +00:00
Brian Hirsh
adf8060600 add a new alias key for functional to view op decompositions
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79615

Approved by: https://github.com/zou3519
2022-06-15 23:18:09 +00:00
PyTorch MergeBot
d2200e38f7 Revert "fix _unsafe_view schema to work with functionalization"
This reverts commit 46234df5f1.

Reverted https://github.com/pytorch/pytorch/pull/79148 on behalf of https://github.com/janeyx99 due to Broke 11.3 tests on trunk and on PR, see 46234df5f1
2022-06-10 13:09:00 +00:00
Brian Hirsh
46234df5f1 fix _unsafe_view schema to work with functionalization
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79148

Approved by: https://github.com/albanD
2022-06-10 01:45:04 +00:00
Brian Hirsh
92229adf0c add special handling for resize_() in functionalization pass
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77714

Approved by: https://github.com/ezyang
2022-05-26 16:15:44 +00:00
Brian Hirsh
e9c54ae1c2 functionalization: remove some unnecessary view_copies in inplace views
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77713

Approved by: https://github.com/ezyang
2022-05-26 16:15:44 +00:00
Brian Hirsh
7ff091fc4e move Functionalize dispatch key closer to backends
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77132

Approved by: https://github.com/ezyang, https://github.com/zou3519
2022-05-26 16:15:43 +00:00
Brian Hirsh
5cc258ec9e make block_diag composite compliant
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77716

Approved by: https://github.com/zou3519
2022-05-26 16:15:42 +00:00
Brian Hirsh
07e4533403 reland of as_strided support for functionalization; introduce as_strided_scatter
This reverts commit a95f1edd85.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78199

Approved by: https://github.com/ezyang
2022-05-24 22:40:44 +00:00
PyTorch MergeBot
a95f1edd85 Revert "as_strided support for functionalization; introduce as_strided_scatter"
This reverts commit 3a921f2d26.

Reverted https://github.com/pytorch/pytorch/pull/77128 on behalf of https://github.com/suo due to This broke rocm tests on master 3a921f2d26. rocm tests are no longer run on PRs, you should add a `ciflow/trunk` label if you want to run them
2022-05-24 20:19:12 +00:00