Commit Graph

130 Commits

Author SHA1 Message Date
Elias Ellison
6a2b12dd65 Turn on aliasing tests for fake backwards, Fix Batch norm running mean/var decomp aliasing (#85471)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85471
Approved by: https://github.com/ezyang
2022-09-28 23:06:59 +00:00
Animesh Jain
796da4df4d Return contiguous tensor from softmax decomposition (#85788)
Fixes https://github.com/pytorch/torchdynamo/issues/1135

The softmax decomp's output stride does not match the aten softmax output stride. Not sure if that's desirable; opening a PR for now.
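For illustration, a minimal sketch of the kind of fix the title describes, using a hand-written softmax decomposition (hypothetical, not the one in `torch/_decomp`):

```py
import torch

def softmax_decomp(x, dim):
    # numerically stable softmax built from elementwise ops, the way a decomposition would do it
    x_max = torch.amax(x, dim, keepdim=True)
    unnorm = torch.exp(x - x_max)
    out = unnorm / torch.sum(unnorm, dim, keepdim=True)
    # the elementwise ops may propagate the input's memory layout, so the result's strides
    # can differ from what aten softmax returns; .contiguous() forces them to match
    return out.contiguous()

x = torch.randn(5, 4).t()                      # non-contiguous input
print(softmax_decomp(x, -1).is_contiguous())   # True
```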

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85788
Approved by: https://github.com/ngimel, https://github.com/ezyang
2022-09-28 20:52:45 +00:00
Nikita Karetnikov
8dd45424ea [primTorch] Add ref for huber_loss and error inputs (#85041)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85041
Approved by: https://github.com/lezcano, https://github.com/mruberry
2022-09-28 19:56:17 +00:00
Edward Z. Yang
793488cda2 Revert "Revert "Symintifying slice ops (#85196)"" (#85746)
This reverts commit 3a171dfb0c.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85746
Approved by: https://github.com/albanD
2022-09-28 04:37:35 +00:00
PyTorch MergeBot
3a171dfb0c Revert "Symintifying slice ops (#85196)"
This reverts commit 4c01c51266.

Reverted https://github.com/pytorch/pytorch/pull/85196 on behalf of https://github.com/atalman due to breaking the internal Exutorch build
2022-09-27 18:01:27 +00:00
Fabio Rocha
d5ce2bbed2 [primTorch] decompositions for upsample_bicubic2d (#85403)
FYI, this decomposition seems to be significantly slower than the lowering in torchinductor:

```
------------------------------------- upsample_bicubic2d -------------------------------------]
                                                              |  lowering  |  Inductor  |  Eager
32 threads: ------------------------------------------------------------------------------------
      (torch.Size([16, 4, 128, 256]),), ((512, 1024), True)   |    1.8     |   3.880    |   1.4
      (torch.Size([16, 4, 128, 256]),), ((512, 1024), False)  |    1.9     |   3.887    |   1.4
```

This seems related to the fact that in the lowering we can use int32s as the indices and in the decomp we can only use int64s (see https://github.com/pytorch/torchdynamo/issues/1293).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85403
Approved by: https://github.com/ngimel
2022-09-26 20:11:23 +00:00
Elias Ellison
bcc544e9d7 Add FakeCrossRef tests for backwards, Fix Layer Norm Backward Decomp (#85417)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85417
Approved by: https://github.com/ezyang
2022-09-26 17:08:14 +00:00
Fabio Rocha
ffaff8896a Removed None arg check in test/test_decomp.py (#85402)
Not sure why this check was necessary; tests seem to run fine without it.
There were definitely tests it was skipping that it shouldn't have,
e.g., pretty much all of the tests for `torch.nn.functional.interpolate`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85402
Approved by: https://github.com/ezyang
2022-09-24 11:37:27 +00:00
Edward Z. Yang
4c01c51266 Symintifying slice ops (#85196)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85196
Approved by: https://github.com/ezyang
2022-09-23 22:01:32 +00:00
PyTorch MergeBot
d10de31cc8 Revert "Add FakeCrossRef tests for backwards, Fix Layer Norm Backward Decomp (#85417)"
This reverts commit 78afa0cf0c.

Reverted https://github.com/pytorch/pytorch/pull/85417 on behalf of https://github.com/clee2000 due to breaking tests on trunk (78afa0cf0c)
2022-09-23 17:21:43 +00:00
PyTorch MergeBot
3b195fd33e Revert "Turn on aliasing tests for fake backwards, Fix Batch norm running mean/var decomp aliasing (#85471)"
This reverts commit 1e92eb8068.

Reverted https://github.com/pytorch/pytorch/pull/85471 on behalf of https://github.com/clee2000 due to the stacked PRs https://github.com/pytorch/pytorch/pull/85417 and https://github.com/pytorch/pytorch/pull/85434 breaking trunk; reverting this so I can revert the others
2022-09-23 17:13:35 +00:00
Elias Ellison
1e92eb8068 Turn on aliasing tests for fake backwards, Fix Batch norm running mean/var decomp aliasing (#85471)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85471
Approved by: https://github.com/ezyang
2022-09-23 16:02:15 +00:00
Elias Ellison
78afa0cf0c Add FakeCrossRef tests for backwards, Fix Layer Norm Backward Decomp (#85417)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85417
Approved by: https://github.com/ezyang
2022-09-23 15:50:03 +00:00
Ryan Spring
71dddec6ea Cast grad_input to half when input_dtype is half in _softmax_backward_data aten decomposition (#85497)
Fixes #85504

`_softmax_backward_data` and `_log_softmax_backward_data` now cast `grad_input` to half when the `input_dtype` is half.
When running under AMP without the cast, consumer ops can trigger `RuntimeError: expected scalar type Float but found Half`.

https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/SoftMax.cpp#L70-L83
https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/SoftMax.cpp#L102-L113
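For illustration, a hedged sketch of the casting behavior described (simplified; not the exact decomposition in `torch/_decomp`):

```py
import torch

def softmax_backward_data_sketch(grad_output, output, dim, input_dtype):
    # standard softmax backward: dL/dx = y * (dL/dy - sum(dL/dy * y, dim))
    grad_input = output * (grad_output - (grad_output * output).sum(dim, keepdim=True))
    if grad_input.dtype != input_dtype:
        # under AMP the math may run in float32 even though the original input was half;
        # cast back so consumer ops see the dtype they expect
        grad_input = grad_input.to(input_dtype)
    return grad_input

y = torch.softmax(torch.randn(2, 3), dim=-1)
g = torch.randn(2, 3)
print(softmax_backward_data_sketch(g, y, -1, torch.half).dtype)  # torch.float16
```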

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85497
Approved by: https://github.com/ngimel
2022-09-23 06:52:38 +00:00
PyTorch MergeBot
5043457a8e Revert "Add FakeCrossRef tests for backwards, Fix Layer Norm Backward Decomp (#85417)"
This reverts commit 9c77083965.

Reverted https://github.com/pytorch/pytorch/pull/85417 on behalf of https://github.com/clee2000 due to breaking tests on trunk (and somehow pull) 9c77083965
2022-09-22 15:44:38 +00:00
Elias Ellison
9c77083965 Add FakeCrossRef tests for backwards, Fix Layer Norm Backward Decomp (#85417)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85417
Approved by: https://github.com/ezyang
2022-09-22 13:03:57 +00:00
Horace He
2f4a517d67 Ported matmul compositeimplicitautograd impl into core (#85239)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85239
Approved by: https://github.com/ezyang, https://github.com/lezcano
2022-09-21 09:25:24 +00:00
lezcano
d17b144e65 Adding multigammaln ref and fix arange (#85153)
Partially based on https://github.com/pytorch/pytorch/pull/83662.

I'll help land this one, as Rob does not work on the PyTorch project anymore.

I removed the data-dependent check for the args, as data dependencies
are bad for many reasons (and it was failing when the input had NaNs).

It also registers arange as a decomposition and fixes the naming of its
args.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85153
Approved by: https://github.com/mruberry, https://github.com/ngimel
2022-09-20 17:52:56 +00:00
lezcano
5dd9610e9d Refs and decompositions for index_{add,copy,select,fill} (#85002)
As per title
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85002
Approved by: https://github.com/ngimel
2022-09-17 19:57:34 +00:00
PyTorch MergeBot
e33b464ffc Revert "Refs and decompositions for index_{add,copy,select,fill} (#85002)"
This reverts commit 2f0b3de443.

Reverted https://github.com/pytorch/pytorch/pull/85002 on behalf of https://github.com/huydhn due to breaking trunk slow tests
2022-09-17 04:26:04 +00:00
lezcano
2f0b3de443 Refs and decompositions for index_{add,copy,select,fill} (#85002)
As per title
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85002
Approved by: https://github.com/ngimel
2022-09-16 23:59:35 +00:00
Sherlock Huang
29eba319b4 Use alias for nop decomp (#84727)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84727
Approved by: https://github.com/Chillee
2022-09-16 18:50:56 +00:00
Natalia Gimelshein
6162a04364 fix half_to_float arg in *softmax decomp (#85120)
Fixes https://github.com/pytorch/torchdynamo/issues/1239

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85120
Approved by: https://github.com/Chillee
2022-09-16 15:54:50 +00:00
soulitzer
7f88934a8f [reland 2] Call jit decomp in VariableType to improve forward AD coverage (#84976)
Reland of https://github.com/pytorch/pytorch/pull/84675
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84976
Approved by: https://github.com/zou3519
2022-09-15 22:46:19 +00:00
PyTorch MergeBot
36d79143ce Revert "[reland] Call jit decomposition in VariableType to increase forward AD coverage (#84151) (#84675)"
This reverts commit bb4e96c964.

Reverted https://github.com/pytorch/pytorch/pull/84675 on behalf of https://github.com/osalpekar due to causing asan xplat link-time errors like ld.lld: error: undefined symbol: torch::jit::has_jit_decomposition(c10::FunctionSchema const&)
2022-09-13 22:54:54 +00:00
soulitzer
bb4e96c964 [reland] Call jit decomposition in VariableType to increase forward AD coverage (#84151) (#84675)
This reverts commit acb4a09628.

In addition, we also fix a memory leak in layer norm.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84675
Approved by: https://github.com/zou3519
2022-09-12 20:33:14 +00:00
Horace He
1459a909b4 Added mv, mm, and binary_cross_entropy_with_logits decomps (#84451)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84451
Approved by: https://github.com/ngimel
2022-09-08 17:56:18 +00:00
soulitzer
e31ad1c2d3 [reland] Move decompositions and helpers for jvp from functorch into core (#84581)
Reland of https://github.com/pytorch/pytorch/pull/84358
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84581
Approved by: https://github.com/samdow
2022-09-07 15:31:46 +00:00
Ivan Yashchuk
6363b1b358 Add nvFuser support for aten.native_batch_norm_backward (#84546)
Replacing `tensor.reshape(broadcast_mask)` with unsqueezes makes the implementation of `batch_norm_backward` more friendly for PrimTorch+nvFuser.
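As a rough sketch of the rewrite described (the real decomposition differs; the shapes here are illustrative):

```py
import torch

mean = torch.randn(8)            # per-channel statistic, C = 8

# reshape to a full broadcast mask such as [1, C, 1, 1]
broadcast_mask = [1, 8, 1, 1]
a = mean.reshape(broadcast_mask)

# the same result built from unsqueezes, which maps onto simpler prims
b = mean
for d in (0, 2, 3):
    b = b.unsqueeze(d)

print(a.shape, b.shape, torch.equal(a, b))  # torch.Size([1, 8, 1, 1]) torch.Size([1, 8, 1, 1]) True
```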
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84546
Approved by: https://github.com/Chillee
2022-09-06 19:56:17 +00:00
Fabio Rocha
91a5f52f51 Decomp for nn.functional.grid_sampler_2d (#84350)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84350
Approved by: https://github.com/jansel, https://github.com/Lezcano
2022-09-05 21:33:26 +00:00
lezcano
3dfbf09afe Optimise the decomposition for adaptive_avg_pool2d wrt. TorchInductor (#84483)
This fixes the parts of the implementation that did not work with
TorchInductor (e.g. indices in TorchInductor need to be `int64`s,
while in PyTorch we can have `int32`s).

It also brings the performance of the kernel up to numbers similar to
those of the lowering (benchmarks below).
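A tiny sketch of the index-dtype constraint mentioned above (illustrative only, not the actual decomposition):

```py
import torch

def window_starts(out_size, in_size):
    # TorchInductor-traced code wants int64 indices; eager PyTorch would also accept int32
    i = torch.arange(out_size, dtype=torch.int64)
    return (i * in_size) // out_size   # start index of each adaptive pooling window

print(window_starts(3, 7))  # tensor([0, 2, 4])
```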
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84483
Approved by: https://github.com/jansel
2022-09-02 22:25:09 +00:00
PyTorch MergeBot
375d6cd5b7 Revert "Move decompositions and helpers for jvp from functorch into core (#84358)"
This reverts commit a3c60a4db4.

Reverted https://github.com/pytorch/pytorch/pull/84358 on behalf of https://github.com/malfet due to breaking lint
2022-09-01 23:42:48 +00:00
soulitzer
a3c60a4db4 Move decompositions and helpers for jvp from functorch into core (#84358)
This refactor shouldn't change any behavior. At this point functorch still relies on the mechanism in DynamicLayerFront; we just moved some parts of it into core.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84358
Approved by: https://github.com/samdow
2022-09-01 22:39:15 +00:00
Sherlock Huang
ef3ab31f1c Decomp for aten.im2col (#84303)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84303
Approved by: https://github.com/jansel, https://github.com/ngimel
2022-09-01 00:06:35 +00:00
Nikita Karetnikov
71ce9cd072 [primTorch] Add decomp for soft_margin_loss (#83804)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83804
Approved by: https://github.com/Lezcano, https://github.com/ngimel
2022-08-31 17:39:34 +00:00
Nikita Shulga
b8e1c54f53 [Prim] Implement group_norm_backward (#84037)
Test plan: CI, i.e. `python3 test_decomp.py -v -k test_comprehensive_nn_functional_group_norm` plus:
```
#!/usr/bin/env python3.8
import torch

func = torch.ops.aten.native_group_norm_backward.default
decomp = torch._decomp.decomposition_table[func]
for args in (
        (torch.rand(1, 6, 3), torch.rand(1, 6, 3), torch.rand(1, 2), torch.rand(1, 2), torch.rand(6), 1, 6, 3, 2, [True, True, True]),
        (torch.rand(64, 768, 7, 7), torch.rand(64, 768, 7, 7), torch.rand(64, 1), torch.rand(64, 1), torch.rand(768), 64, 768, 49, 1, [True, True, True])):
    nrc = func(*args)
    drc = decomp(*args)
    for i in range(len(nrc)):
        print(i, torch.max(nrc[i] - drc[i]))
    print(all(torch.allclose(x, y) for (x, y) in zip(nrc, drc)))
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84037
Approved by: https://github.com/Chillee, https://github.com/ngimel
2022-08-29 09:29:30 +00:00
Natalia Gimelshein
533203f5aa _to_copy decomp (#84108)
Per title
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84108
Approved by: https://github.com/Chillee
2022-08-29 02:25:02 +00:00
lezcano
9fc02f6bc5 Decomposition for adaptive_avg_pool2d (#84062)
This was already implemented as a lowering in https://github.com/pytorch/torchdynamo/pull/962. I'm putting the idea up here ~(I haven't even run this code, so it surely has *many* issues, but I reckon the general idea should hopefully be alright).~ The tests now pass and I corrected the issues that the first implementation had.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84062
Approved by: https://github.com/jansel
2022-08-29 01:38:51 +00:00
PyTorch MergeBot
33db5da4c1 Revert "[Prim] Implement group_norm_backward (#84037)"
This reverts commit bed85cce8b.

Reverted https://github.com/pytorch/pytorch/pull/84037 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally
2022-08-28 17:30:50 +00:00
PyTorch MergeBot
ff23f3ac1c Revert "_to_copy decomp (#84108)"
This reverts commit e33897cb99.

Reverted https://github.com/pytorch/pytorch/pull/84108 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally
2022-08-28 13:27:49 +00:00
Natalia Gimelshein
e33897cb99 _to_copy decomp (#84108)
Per title
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84108
Approved by: https://github.com/Chillee
2022-08-27 03:51:03 +00:00
Nikita Shulga
bed85cce8b [Prim] Implement group_norm_backward (#84037)
Test plan: CI, i.e. `python3 test_decomp.py -v -k test_comprehensive_nn_functional_group_norm` plus:
```
#!/usr/bin/env python3.8
import torch

func = torch.ops.aten.native_group_norm_backward.default
decomp = torch._decomp.decomposition_table[func]
for args in (
        (torch.rand(1, 6, 3), torch.rand(1, 6, 3), torch.rand(1, 2), torch.rand(1, 2), torch.rand(6), 1, 6, 3, 2, [True, True, True]),
        (torch.rand(64, 768, 7, 7), torch.rand(64, 768, 7, 7), torch.rand(64, 1), torch.rand(64, 1), torch.rand(768), 64, 768, 49, 1, [True, True, True])):
    nrc = func(*args)
    drc = decomp(*args)
    for i in range(len(nrc)):
        print(i, torch.max(nrc[i] - drc[i]))
    print(all(torch.allclose(x, y) for (x, y) in zip(nrc, drc)))
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84037
Approved by: https://github.com/Chillee, https://github.com/ngimel
2022-08-27 01:10:27 +00:00
Horace He
9a236c7ab4 Made some minor cleanups to decompositions (#83814)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83814
Approved by: https://github.com/ngimel
2022-08-26 10:55:31 +00:00
Animesh Jain
e2f75d63d4 Decomposition - batch_norm, save_mean and save_variance always float32 (#84013)
AMP error shown here - https://github.com/pytorch/torchdynamo/issues/835

Test missing
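A hedged sketch of the dtype convention in the title (illustrative, not the actual decomposition):

```py
import torch

def batch_norm_stats_sketch(x):
    # even when x is float16 under AMP, the saved statistics stay float32
    x32 = x.float()
    save_mean = x32.mean(dim=(0, 2, 3))
    save_var = x32.var(dim=(0, 2, 3), unbiased=False)
    return save_mean, save_var

m, v = batch_norm_stats_sketch(torch.randn(2, 3, 4, 4).half())
print(m.dtype, v.dtype)  # torch.float32 torch.float32
```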
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84013
Approved by: https://github.com/ezyang
2022-08-25 16:09:52 +00:00
Ivan Yashchuk
473b733bae Replace .new_zeros(()) with 0.0 in torch/_decomp/decompositions (#83734)
`new_zeros` is decomposed into `prims.empty_strided`+`prims.fill`+`prims.copy_to`, and none of these is currently supported by the prims+nvFuser executor.
Replacing it with 0.0 makes these backward decompositions nvFuser-friendly.
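A sketch of the kind of source-level change this implies, using a hardsigmoid_backward-style decomposition (simplified, not the exact PyTorch source):

```py
import torch

def hardsigmoid_backward_before(grad, a):
    mask = (a > -3.0) & (a < 3.0)
    # a.new_zeros(()) decomposes into empty_strided + fill + copy_to,
    # which the prims+nvFuser executor could not consume
    return torch.where(mask, grad * (1.0 / 6.0), a.new_zeros(()))

def hardsigmoid_backward_after(grad, a):
    mask = (a > -3.0) & (a < 3.0)
    # a plain Python 0.0 keeps the trace down to gt/lt/bitwise_and/mul/where
    return torch.where(mask, grad * (1.0 / 6.0), 0.0)

a, g = torch.randn(3, 2), torch.randn(3, 2)
print(torch.equal(hardsigmoid_backward_before(g, a), hardsigmoid_backward_after(g, a)))  # True
```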

Example with `torch.ops.aten.hardsigmoid_backward.default`:
```py
# Before this PR
opcode         name                      target                            args                                                          kwargs
-------------  ------------------------  --------------------------------  ------------------------------------------------------------  ----------------------------------------------------------------------------------------
placeholder    a_1                       a_1                               ()                                                            {}
placeholder    g_1                       g_1                               ()                                                            {}
call_function  gt_default                nvprims.gt.default                (a_1, -3.0)                                                   {}
call_function  lt_default                nvprims.lt.default                (a_1, 3.0)                                                    {}
call_function  bitwise_and_default       nvprims.bitwise_and.default       (gt_default, lt_default)                                      {}
call_function  mul_default               nvprims.mul.default               (g_1, 0.16666666666666666)                                    {}
call_function  empty_strided             prims.empty_strided.default       ([], [])                                                      {'dtype': torch.float32, 'device': device(type='cuda', index=0), 'requires_grad': False}
call_function  fill_default              prims.fill.default                (empty_strided, 0)                                            {}
call_function  copy_to_default           prims.copy_to.default             (empty_strided, fill_default)                                 {}
call_function  broadcast_in_dim_default  nvprims.broadcast_in_dim.default  (copy_to_default, [3, 2], [])                                 {}
call_function  where_default             nvprims.where.default             (bitwise_and_default, mul_default, broadcast_in_dim_default)  {}
output         output                    output                            (where_default,)                                              {}

# After this PR
opcode         name                 target                       args                                     kwargs
-------------  -------------------  ---------------------------  ---------------------------------------  --------
placeholder    a_1                  a_1                          ()                                       {}
placeholder    g_1                  g_1                          ()                                       {}
call_function  gt_default           nvprims.gt.default           (a_1, -3.0)                              {}
call_function  lt_default           nvprims.lt.default           (a_1, 3.0)                               {}
call_function  bitwise_and_default  nvprims.bitwise_and.default  (gt_default, lt_default)                 {}
call_function  mul_default          nvprims.mul.default          (g_1, 0.16666666666666666)               {}
call_function  where_default        nvprims.where.default        (bitwise_and_default, mul_default, 0.0)  {}
output         output               output                       (where_default,)                         {}
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83734
Approved by: https://github.com/Chillee
2022-08-22 09:12:13 +00:00
Edward Z. Yang
02581f053b Address CR comments for "Delete ProxyTensor wrapper subclass" (#83646)
CR is on https://github.com/pytorch/pytorch/pull/83330

- Factor proxy slot getters/setters into helper functions
- Use a weak map for storing proxies, so they go away when
  tracing is done
- More documentation on SymDispatchMode

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83646
Approved by: https://github.com/Chillee
2022-08-18 22:18:09 +00:00
Edward Z. Yang
817a82704f Delete ProxyTensor wrapper subclass (#83330)
I was working on https://github.com/pytorch/torchdynamo/issues/80 and my
working hypothesis for what was causing the error was that proxy tensor
was not advertising correct dispatch keys, causing AMP to operate
differently when you traced.  I could have fixed this directly by
replicating fake tensor's fix for setting dispatch keys to also apply to
proxy tensor, but I was like, "Why must I repeat myself."

This PR is the result.  It completely deletes the ProxyTensor wrapper
subclass, so that when we are tracing, the tensors flowing through the
program are the *original* real or fake tensors, depending on what the
user requested in the top-level API.  There is no more wrapping.  To
store the Proxy objects necessary for actually doing tracing, I store
the property directly on the tensors. (Note: I never clean up old
entries from the map at the moment; this is easily fixed by using a
weak map.)

Benefits of doing this:

* No more tip-toeing around no_dispatch() creation of new ProxyTensors;
  we never create new tensors (except when we call the underlying func),
  so you don't have to worry about accidentally tracing them.

* No more syncing up metadata from in place operators.  In particular
  https://github.com/pytorch/pytorch/issues/81526 is mooted

* This fixes https://github.com/pytorch/torchdynamo/issues/519 as we no longer need to teach proxy tensor to support sparse tensor.

* No more schlepping symbolic integers from the inner fake tensor to the
  outer proxy tensor.  If you can make a fake tensor with symbolic ints,
  you're done, nothing else to do.

To avoid having to rewrite all of the guts, when I get to the actual
proxy tensor handler, I first "fetch" the stored ProxyTensor data from
the weakmap via a tree_map, and then operate on the consequent data as
before.  A more optimized implementation is possible.
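A rough sketch of the "fetch via tree_map" step described above; the names and the plain dict are illustrative, not the actual `proxy_tensor.py` internals:

```py
import torch
from torch.utils._pytree import tree_map

proxy_store = {}  # the real code keys this off tensors weakly so entries die with tracing

def fetch(x):
    # swap each traced tensor for its stashed proxy data; leave ints, None, etc. alone
    return proxy_store.get(x, x) if isinstance(x, torch.Tensor) else x

def proxy_handler(func, args, kwargs):
    proxy_args = tree_map(fetch, args)       # what gets recorded into the FX graph
    proxy_kwargs = tree_map(fetch, kwargs)
    real_out = func(*args, **kwargs)         # the original real/fake tensors run the op
    return real_out, (proxy_args, proxy_kwargs)
```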

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83330
Approved by: https://github.com/Chillee
2022-08-18 01:56:07 +00:00
Nikita Karetnikov
cd86d25515 [primTorch] Move addcdiv from decompositions -> refs (#80842)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80842
Approved by: https://github.com/Lezcano, https://github.com/ngimel
2022-08-16 17:23:00 +00:00
Horace He
f02f304657 Added nll_loss_forward decomposition + some other minor decomps (#83235)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83235
Approved by: https://github.com/ngimel
2022-08-13 10:24:58 +00:00
Brian Hirsh
1a51efd8bb dispatch API for checking computed table, use it in prim decomps (#82358)
Fixes https://github.com/pytorch/pytorch/issues/82331

Expose a `torch._C._dispatch_has_computed_kernel_for_dispatch_key` to check if an operator has a kernel registered to the given dispatch key in the **computed table**.

Use it in the prim registration logic, making it more accurate and robust (so that it e.g. picks up `CompositeExplicitAutograd` kernels).

It looks like before this change we'd register 134 prim ops to the meta key, and after we only register 62. So that's 72 ops that now use an existing C++ decomp to get meta working, instead of going directly through the prim decomp.
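A hedged usage sketch of the new check (the binding is private, so its exact behavior may vary across versions):

```py
import torch

# does aten::add.Tensor have a kernel in the computed dispatch table for the Meta key?
# (e.g. via a CompositeExplicitAutograd kernel, not just a direct Meta registration)
print(torch._C._dispatch_has_computed_kernel_for_dispatch_key("aten::add.Tensor", "Meta"))
```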

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82358
Approved by: https://github.com/ezyang
2022-08-10 23:42:02 +00:00