Commit Graph

135 Commits

Author SHA1 Message Date
drisspg
1434e0b121 Add a private _safe_softmax (#131060)
# Summary
Changes the stance of SDPA on what to do for fully masked out rows

## Current Behavior
Several PyTorch users have expressed frustration over this issue:
- https://github.com/pytorch/pytorch/issues/41508
- https://github.com/pytorch/pytorch/issues/103749
- https://github.com/pytorch/pytorch/issues/103963

These are significant issues with extensive discussion but no satisfactory resolution. The PyTorch team's consensus, as stated here:
https://github.com/pytorch/pytorch/issues/24816#issuecomment-524415617

Can be paraphrased as follows:

When passing in fully masked out rows, attention becomes ambiguous. We have two main options:

1. Uniformly attend to all values:
   ```python
   scores[masked_out_rows] = 1 / len(row)
   out[masked_out_rows] = 1 / len(row) * value
   ```

2. Decide that attention between no queries (masked) and no keys (masked) is meaningless:
   ```python
   output[fully_masked_rows] = NaN
   ```

We went with option 2. Partially because it was easier to implement, but also people argued that users can slice the output to remove the NaNs:
``` Python
>fill_value = -float("inf")
>row0 = torch.randn(4)
>row1 = torch.tensor([(fill_value for _ in range(4)])
>matrix = torch.stack([row0, row1]).requires_grad_(True)
>out = torch.softmax(matrix, 1)
>out = out[0]
>print(out)
tensor([0.5377, 0.2729, 0.0692, 0.1201])
```
Cool, problem solved. But what happends when you call backwards..
```Python
>out.backward(torch.ones_like(out))
>print(matrix.grad)
tensor([[3.0957e-08, 1.4157e-08, 7.7802e-10, 1.3713e-08],
        [       nan,        nan,        nan,        nan]])
```
Those pesky NaNs are back!

## Why do we see NaNs today?

The core of the problem revolves around using softmax function in sdpa:

```python
> row = torch.tensor([(-float("inf")) for _ in range(4)])
> torch.softmax(row, 0)
tensor([nan, nan, nan, nan])
```

## Quick Aside: Masking in Attention

Attention itself doesn't have a concept of masking. The `sdpa` function has an argument called `attn_mask`, which would be more accurately named `attn_bias`. This is because we don't actually "mask" entries when computing attention. Instead, due to implementation details([performance](https://github.com/pytorch/pytorch/issues/25110#issuecomment-524519087)), we add a value to the masked-out query/key pairs.

We use a large negative number (typically -inf) to decrease the attention weight, as softmax assigns more weight to larger values.

## Alternative Approaches

If we use a very large negative number instead of -inf:

```python
> row = torch.tensor([(-1e6) for _ in range(4)])
> torch.softmax(row, 0)
tensor([0.2500, 0.2500, 0.2500, 0.2500])
```
However if users always remembered to "slice" out their outputs i.e.:
```Python
>fill_value = -1e6
>...
>out.backward(torch.ones_like(out))
>print(matrix.grad)
tensor([[-0.0563, -0.0564,  0.1613, -0.0486],
        [ 0.0000,  0.0000,  0.0000,  0.0000]])
```
This would bring us back into a better state.

## A Third Option

We don't necessarily need to alter the behavior of softmax for -inf or very large negative numbers. The fundamental goal is to exclude certain query/key pairs from attention, regardless of the underlying implementation.

This PR implements the new semantic for masking w/ attention in fully masked-out rows:
```python
out[masked_out_rows] = 0
```

**Important Note**: This idea isn't entirely new. The [MaskedTensor](https://pytorch.org/tutorials/prototype/maskedtensor_overview#safe-softmax) prototype, a tensor subclass, was designed to handle such cases. However, it remains a prototype feature and hasn't gained widespread adoption.

## Details
This PR stack does 3 things:
1. Adds a PRIVATE _safe_softmax op
2. Updates semantic for flash_cpu fused kernel
3. Updates semantic for efficient_cuda fused kernel

_safe_softmax is not supposed to be used generically and is only meant to be used within the context of SDPA. Due to this fact instead of decomposing softmax and checking for -inf rows we instead "cheat" and use nan_to_num.

Why I think this is okay? (please find a counter point if avail)
There are multiple ways NaNs can emerge. For the fully masked out rows case nan_to_num works. But what if there were other NaNs, wouldn't this silently remove them?

The only case that this can happen is if the input itself had a NaN or an Inf
For example:
```Python
a = torch.ones([4], requires_grad=False, dtype=torch.float16)
a[1] = torch.finfo(torch.float16).max
print(a.softmax(-1))
```
Will return
`tensor([0., 1., 0., 0.], dtype=torch.float16)`

Where
```Python
a = torch.ones([4], requires_grad=False, dtype=torch.float16)
a[1] = float("inf")
a.softmax(-1)
```
returns:
`tensor([nan, nan, nan, nan], dtype=torch.float16)`

If we dont want to even allow for the possibility of "inf" or "NaN" attention scores to be converted to 0 then we can implemented it something like this

```Python
max = torch.max(a, dim=-1, keepdim=True)
exp = torch.exp(a - max.values)
denom = torch.sum(exp, dim=-1, keepdim=True)
softmax = exp / denom
softmax = torch.where(max.values == float('-inf'), 0.0, softmax)
```
however we would be paying for this in math performance.

## Why Now
I think one point that has substantially changed where PyTorch should lie on this argument is the fact that we have fused implementations for SDPA now. And these fused implementations allow us to easily and performantly support this new semantic.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131060
Approved by: https://github.com/jbschlosser
2024-08-08 23:09:38 +00:00
Jianyu Huang
c7cfa51721 Always use high precision for SDPA math backend (#128922)
Summary:
feikou observed the big numerical gaps when using math backend on AMD and NV GPUs. It's mainly because we are not using higher precision FP32 for the intermediate accumulated/materialized parts.

Since math backend is expected to be slower anyways, and we expect math backend to generate the correct reference result, I think it should be worth to upcast FP16/BF16 input to FP32, and do FP32/TF32 computations, and then downcast FP32 output back to FP16/BF16.

Differential Revision: D58710805

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128922
Approved by: https://github.com/xw285cornell, https://github.com/drisspg
2024-08-04 23:58:14 +00:00
Xuehai Pan
4226ed1585 [BE] Format uncategorized Python files with ruff format (#132576)
Remove patterns `**`, `test/**`, and `torch/**` in `tools/linter/adapters/pyfmt_linter.py` and run `lintrunner`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132576
Approved by: https://github.com/ezyang, https://github.com/Skylion007
ghstack dependencies: #132574
2024-08-04 17:13:31 +00:00
PyTorch MergeBot
59b73079a0 Revert "Always use high precision for SDPA math backend (#128922)"
This reverts commit fbf3bc0a60.

Reverted https://github.com/pytorch/pytorch/pull/128922 on behalf of https://github.com/ZainRizvi due to Sorry, but this PR has a dependency on another PR (https://github.com/pytorch/pytorch/pull/128898) that has to be reverted ([comment](https://github.com/pytorch/pytorch/pull/128922#issuecomment-2265949958))
2024-08-02 18:46:50 +00:00
Jianyu Huang
fbf3bc0a60 Always use high precision for SDPA math backend (#128922)
Summary:
feikou observed the big numerical gaps when using math backend on AMD and NV GPUs. It's mainly because we are not using higher precision FP32 for the intermediate accumulated/materialized parts.

Since math backend is expected to be slower anyways, and we expect math backend to generate the correct reference result, I think it should be worth to upcast FP16/BF16 input to FP32, and do FP32/TF32 computations, and then downcast FP32 output back to FP16/BF16.

Differential Revision: D58710805

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128922
Approved by: https://github.com/xw285cornell, https://github.com/drisspg
2024-08-01 18:55:48 +00:00
Oguz Ulgen
221350e3a4 Add None return type to init -- tests (#132352)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132352
Approved by: https://github.com/ezyang
ghstack dependencies: #132335, #132351
2024-08-01 15:44:51 +00:00
eellison
baa4c9ca46 Optimize aten.cat calls of a repeated element (#132081)
This was a particular problem for a model I saw which would have a large number of repeats, making compilation slow.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132081
Approved by: https://github.com/shunting314
2024-07-30 02:56:00 +00:00
Isuru Fernando
43a6d20883 Add decomposition for reflection_pad{1,2,3}d_backward (#130299)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130299
Approved by: https://github.com/lezcano
ghstack dependencies: #130130
2024-07-17 21:56:00 +00:00
Xuehai Pan
ba48cf6535 [BE][Easy][6/19] enforce style for empty lines in import segments in test/ (#129757)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129757
Approved by: https://github.com/ezyang
2024-07-17 06:42:37 +00:00
Joel Schlosser
c8ab2e8b63 Set seed per sample for OpInfo tests + support for restricting to a single sample input (#128238)
This PR:
* Sets a random seed before generating each sample for an OpInfo test. It does this by intercepting the sample input iterator via `TrackedInputIter`, optionally setting the seed to a test name specific seed before each iterator call (default is to set the seed).
    * Some quick and dirty benchmarking shows (hopefully) negligible overhead from setting the random seed before each sample input generation. For a trivial (single assert) test that uses `@ops`:
* Uncovered a bunch of test issues:
    * Test breakdown (>100 total)
        * A lot of tolerance issues (tweaked tolerance values to fix)
        * 1 broken OpInfo (`sample_inputs_masked_fill` was generating a sample of the wrong dtype)
        * 3 actually broken semantics (for masked tensor; added xfails)
        * 4 Jacobian mismatches (added xfails)
        * 2 nan results (skip for now, need fixing)
        * 3 results too far from reference result (add xfails)
* Skips MPS tests for now (there are so many failures!). Those will default to the old behavior.

**before (no seed setting):**
```
real	0m21.306s
user	0m19.053s
sys	0m5.192s
```

**after (with seed setting):**
```
real	0m21.905s
user	0m19.578s
sys	0m5.390s
```

* Utilizing the above for reproducible sample input generation, adds support for restricting the iterator to a single sample input. This is done via an env var `PYTORCH_OPINFO_SAMPLE_INPUT_INDEX` and its usage is included in the repro command.

```
======================================================================
ERROR: test_bar_add_cuda_uint8 (__main__.TestFooCUDA.test_bar_add_cuda_uint8)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/jbschlosser/branches/testing_updates/torch/testing/_internal/common_device_type.py", line 971, in test_wrapper
    return test(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/jbschlosser/branches/testing_updates/test/test_ops.py", line 2671, in test_bar
    self.assertFalse(True)
AssertionError: True is not false

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/jbschlosser/branches/testing_updates/torch/testing/_internal/common_utils.py", line 2816, in wrapper
    method(*args, **kwargs)
  File "/home/jbschlosser/branches/testing_updates/torch/testing/_internal/common_utils.py", line 2816, in wrapper
    method(*args, **kwargs)
  File "/home/jbschlosser/branches/testing_updates/torch/testing/_internal/common_device_type.py", line 419, in instantiated_test
    result = test(self, **param_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jbschlosser/branches/testing_updates/torch/testing/_internal/common_utils.py", line 1426, in wrapper
    fn(*args, **kwargs)
  File "/home/jbschlosser/branches/testing_updates/torch/testing/_internal/common_device_type.py", line 982, in test_wrapper
    raise new_e from e
Exception: Caused by sample input at index 3: SampleInput(input=Tensor[size=(10, 5), device="cuda:0", dtype=torch.uint8], args=TensorList[Tensor[size=(), device="cuda:0", dtype=torch.uint8]], kwargs={}, broadcasts_input=False, name='')

To execute this test, run the following from the base repo dir:
    PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=3 python test/test_ops.py -k TestFooCUDA.test_bar_add_cuda_uint8

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

----------------------------------------------------------------------
Ran 1 test in 0.037s

FAILED (errors=1)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128238
Approved by: https://github.com/janeyx99, https://github.com/justinchuby
2024-07-08 16:06:38 +00:00
eellison
8cd9b10456 Fix exp decomp numerics (#129154)
Our previous implementation would sometimes generate `inf` because we did not do the same numerics tricks as in eager:

See comment / [link](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/core/TransformationHelper.h#L123-L144) :
```
    # curand_uniform has (0,1] bounds. log(1) is 0 and exponential excludes 0.
    # we need log to be not 0, and not underflow when converted to half
    # fast __logf approximation can underflow, so set log to -epsilon/2 for 1 or close to 1 args
```

Fix for https://github.com/pytorch/pytorch/issues/127749.

Added a test for non-inf, but it would be great to have more robust decomp distribution tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129154
Approved by: https://github.com/bdhirsh, https://github.com/zou3519
2024-06-21 03:21:30 +00:00
Peter Bell
39de62845a [decomp] Fix default values missing from inplace rrelu decomposition (#126978)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126978
Approved by: https://github.com/lezcano
2024-05-26 23:49:40 +00:00
Yuanhao Ji
c165a8e71d Enable UFMT on test_decomp.py, test_expanded_weights.py and some files (#125117)
Part of: #123062

Ran lintrunner on:

- test/test_decomp.py
- test/test_deploy.py
- test/test_determination.py
- test/test_dlpack.py
- test/test_dynamic_shapes.py
- test/test_expanded_weights.py

Detail:

```bash
$ lintrunner -a --take UFMT --all-files
ok No lint issues.
Successfully applied all patches.
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125117
Approved by: https://github.com/jansel
2024-05-07 02:36:40 +00:00
mashaobin
af67704dcc [privateuse1] _refs.masked_fill support privateuse1 when value.device.type is cpu (#124835)
_refs.masked_fill support privateuse1 when value.device.type is cpu.

1. maybe I should consider whether this modification meets the expectations of other privateuse1 devices,
2. add TestCase

Fixes #124693

Co-authored-by: albanD <desmaison.alban@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124835
Approved by: https://github.com/albanD
2024-05-01 18:57:14 +00:00
Isuru Fernando
97ccfad915 Fix test_decomp test for ops with py_impl(CompositeImplicitAutograd) (#116832)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116832
Approved by: https://github.com/lezcano
2024-04-20 11:10:38 +00:00
vfdev-5
6b7741546b Fixed arange decomp for float dtype (#121013)
## Description:

- [x] Fixed arange decomp for float dtype
- [x] Added a test

## Current state

Arange graph and C++ generated code are not optimal when arange is created directly using float32 dtype:
```python
import torch

def func(x):
    s = x.shape[-1]
    a = torch.arange(s, dtype=torch.float32)
    return s + a

c_func = torch.compile(func)
out = c_func(torch.rand(10))
```

Graph on `main`:
```
 ===== Forward graph 0 =====
 /pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
    def forward(self):
        # File: check_arange_decomp.py:8 in func, code: a = torch.arange(s, dtype=torch.float32)
        iota: "i64[10]" = torch.ops.prims.iota.default(10, start = 0, step = 1, dtype = torch.int64, device = device(type='cpu'), requires_grad = False)
        convert_element_type: "f64[10]" = torch.ops.prims.convert_element_type.default(iota, torch.float64);  iota = None
        mul: "f64[10]" = torch.ops.aten.mul.Tensor(convert_element_type, 1);  convert_element_type = None
        add: "f64[10]" = torch.ops.aten.add.Tensor(mul, 0);  mul = None
        convert_element_type_1: "f32[10]" = torch.ops.prims.convert_element_type.default(add, torch.float32);  add = None

        # File: check_arange_decomp.py:9 in func, code: return s + a
        add_1: "f32[10]" = torch.ops.aten.add.Tensor(convert_element_type_1, 10);  convert_element_type_1 = None
        return (add_1,)

 ===== AFTER POST GRAD =====
 /pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
    def forward(self):
        # File: check_arange_decomp.py:15 in func, code: a = torch.arange(s, dtype=torch.float32)
        iota: "i64[10]" = torch.ops.prims.iota.default(10, start = 0, step = 1, dtype = torch.int64, device = device(type='cpu'), requires_grad = False)
        convert_element_type: "f64[10]" = torch.ops.prims.convert_element_type.default(iota, torch.float64);  iota = None
        mul: "f64[10]" = torch.ops.aten.mul.Tensor(convert_element_type, 1);  convert_element_type = None
        add: "f64[10]" = torch.ops.aten.add.Tensor(mul, 0);  mul = None
        convert_element_type_1: "f32[10]" = torch.ops.prims.convert_element_type.default(add, torch.float32);  add = None

        # File: check_arange_decomp.py:16 in func, code: return s + a
        add_1: "f32[10]" = torch.ops.aten.add.Tensor(convert_element_type_1, 10);  convert_element_type_1 = None
        return (add_1,)
```
and C++
```c++
extern "C" void kernel(float* out_ptr0)
{
    {
        #pragma GCC ivdep
        for(long x0=static_cast<long>(0L); x0<static_cast<long>(10L); x0+=static_cast<long>(1L))
        {
            auto tmp0 = c10::convert<long>(x0);
            auto tmp1 = c10::convert<double>(tmp0);   // <---- useless ops
            auto tmp2 = static_cast<double>(1.0);     // <----
            auto tmp3 = decltype(tmp1)(tmp1 * tmp2);  // <----
            auto tmp4 = static_cast<double>(0.0);     // <----
            auto tmp5 = decltype(tmp3)(tmp3 + tmp4);  // <----
            auto tmp6 = c10::convert<float>(tmp5);
            auto tmp7 = static_cast<float>(10.0);
            auto tmp8 = decltype(tmp6)(tmp6 + tmp7);
            out_ptr0[static_cast<long>(x0)] = tmp8;
        }
    }
}
```

However, if we manually create arange on i64 and then put to float32, generated graph and C++ code are more natural and benefit of a speed-up.
```python
import torch

def func(x):
    s = x.shape[-1]
    a = torch.arange(s).to(dtype=torch.float32)
    return s + a

c_func = torch.compile(func)
out = c_func(torch.rand(10))
```

Graph on `main`:
```
 ===== Forward graph 0 =====
 /pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
    def forward(self):
        # File: check_arange_decomp.py:14 in func, code: a = torch.arange(s).to(dtype=torch.float32)
        iota: "i64[10]" = torch.ops.prims.iota.default(10, start = 0, step = 1, dtype = torch.int64, device = device(type='cpu'), requires_grad = False)
        convert_element_type: "f32[10]" = torch.ops.prims.convert_element_type.default(iota, torch.float32);  iota = None

        # File: check_arange_decomp.py:15 in func, code: return s + a
        add: "f32[10]" = torch.ops.aten.add.Tensor(convert_element_type, 10);  convert_element_type = None
        return (add,)

 ===== AFTER POST GRAD =====
 /pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
    def forward(self):
        # File: check_arange_decomp.py:21 in func, code: a = torch.arange(s).to(dtype=torch.float32)
        iota: "i64[10]" = torch.ops.prims.iota.default(10, start = 0, step = 1, dtype = torch.int64, device = device(type='cpu'), requires_grad = False)
        convert_element_type: "f32[10]" = torch.ops.prims.convert_element_type.default(iota, torch.float32);  iota = None

        # File: check_arange_decomp.py:22 in func, code: return s + a
        add: "f32[10]" = torch.ops.aten.add.Tensor(convert_element_type, 10);  convert_element_type = None
        return (add,)
```

C++ on `main`
```c++
extern "C" void kernel(float* out_ptr0)
{
    {
        #pragma GCC ivdep
        for(long x0=static_cast<long>(0L); x0<static_cast<long>(10L); x0+=static_cast<long>(1L))
        {
            auto tmp0 = c10::convert<long>(x0);
            auto tmp1 = c10::convert<float>(tmp0);
            auto tmp2 = static_cast<float>(10.0);
            auto tmp3 = decltype(tmp1)(tmp1 + tmp2);
            out_ptr0[static_cast<long>(x0)] = tmp3;
        }
    }
}
```

For example, the speed-up seen on upsample_nearest2d on cpu:
```
[----------------------------------------------------------------------------------------------------------------------------------------------- Interpolate, cpu ----------------------------------------------------------------------------------------------------------------------------------------------]
                                                                                                                                |  Eager (2.3.0a0+gitb4324ed) PR  |  Compiled (2.3.0a0+gitb4324ed) PR  |  Compiled (2.3.0a0+git0d1e705) Nightly  |  speed-up PR vs Nightly  |  Eager (2.3.0a0+git0d1e705) Nightly
1 threads: ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      Input (1, 3, 500, 400), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (256, 256)      |        287.988 (+-10.399)       |         200.034 (+-8.630)          |            285.143 (+-8.412)            |     1.425 (+-0.000)      |          287.991 (+-11.302)
      Input (1, 3, 500, 400), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (256, 256)          |        697.206 (+-27.033)       |         171.650 (+-7.381)          |            193.280 (+-5.840)            |     1.126 (+-0.000)      |          701.642 (+-26.461)
      Input (1, 3, 500, 400), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (256, 256)    |        149.149 (+-6.045)        |         222.780 (+-6.852)          |            299.968 (+-12.354)           |     1.346 (+-0.000)      |          145.055 (+-7.232)
      Input (1, 3, 500, 400), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (256, 256)        |        596.741 (+-27.970)       |         205.923 (+-8.648)          |            233.912 (+-7.742)            |     1.136 (+-0.000)      |          598.000 (+-25.630)
      Input (4, 3, 500, 400), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (256, 256)      |       1095.734 (+-51.658)       |         700.850 (+-24.852)         |           1044.255 (+-38.216)           |     1.490 (+-0.000)      |         1097.977 (+-35.521)
      Input (4, 3, 500, 400), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (256, 256)          |       2741.813 (+-122.917)      |         583.073 (+-16.998)         |            665.029 (+-36.331)           |     1.141 (+-0.000)      |         2722.388 (+-116.263)
      Input (4, 3, 500, 400), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (256, 256)    |        578.183 (+-37.266)       |         833.295 (+-42.264)         |           1131.341 (+-54.710)           |     1.358 (+-0.000)      |          584.953 (+-45.549)
      Input (4, 3, 500, 400), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (256, 256)        |       2332.508 (+-103.556)      |         840.194 (+-47.664)         |            935.625 (+-47.467)           |     1.114 (+-0.000)      |         2334.314 (+-91.644)
      Input (1, 3, 1200, 1300), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (200, 300)    |        272.631 (+-11.348)       |         195.988 (+-5.748)          |            274.021 (+-9.475)            |     1.398 (+-0.000)      |          272.752 (+-12.716)
      Input (1, 3, 1200, 1300), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (200, 300)        |        640.409 (+-25.465)       |         164.773 (+-7.372)          |            185.018 (+-8.349)            |     1.123 (+-0.000)      |          639.390 (+-30.761)
      Input (1, 3, 1200, 1300), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (200, 300)  |        158.602 (+-6.593)        |         220.478 (+-6.809)          |            286.376 (+-8.981)            |     1.299 (+-0.000)      |          158.557 (+-6.143)
      Input (1, 3, 1200, 1300), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (200, 300)      |        548.903 (+-22.889)       |         202.788 (+-9.158)          |            227.404 (+-8.995)            |     1.121 (+-0.000)      |          554.096 (+-21.330)
      Input (4, 3, 1200, 1300), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (200, 300)    |       1036.061 (+-35.285)       |         680.728 (+-30.925)         |            986.254 (+-42.732)           |     1.449 (+-0.000)      |         1038.718 (+-43.070)
      Input (4, 3, 1200, 1300), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (200, 300)        |       2504.520 (+-125.805)      |         550.067 (+-21.383)         |            628.000 (+-27.589)           |     1.142 (+-0.000)      |         2523.134 (+-113.336)
      Input (4, 3, 1200, 1300), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (200, 300)  |       1058.188 (+-57.853)       |        1216.427 (+-76.160)         |           1380.231 (+-98.939)           |     1.135 (+-0.000)      |         1057.031 (+-66.075)
      Input (4, 3, 1200, 1300), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (200, 300)      |       2305.911 (+-116.864)      |        1080.189 (+-79.934)         |           1141.561 (+-67.959)           |     1.057 (+-0.000)      |         2306.606 (+-121.544)
      Input (1, 3, 300, 400), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (600, 700)      |       1689.489 (+-60.579)       |        1077.401 (+-44.948)         |           1634.264 (+-64.340)           |     1.517 (+-0.000)      |         1693.945 (+-67.998)
      Input (1, 3, 300, 400), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (600, 700)          |       4198.368 (+-179.096)      |         886.656 (+-30.355)         |           1028.568 (+-46.310)           |     1.160 (+-0.000)      |         4174.351 (+-141.020)
      Input (1, 3, 300, 400), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (600, 700)    |        716.572 (+-51.954)       |        1175.864 (+-52.191)         |           1674.373 (+-51.815)           |     1.424 (+-0.000)      |          715.724 (+-41.104)
      Input (1, 3, 300, 400), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (600, 700)        |       3604.989 (+-132.489)      |        1096.933 (+-54.290)         |           1270.347 (+-60.932)           |     1.158 (+-0.000)      |         3601.864 (+-140.218)
      Input (4, 3, 300, 400), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (600, 700)      |       6721.610 (+-355.997)      |        4203.213 (+-134.362)        |           6423.763 (+-225.311)          |     1.528 (+-0.000)      |         6715.626 (+-288.233)
      Input (4, 3, 300, 400), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (600, 700)          |      16695.467 (+-709.620)      |        3460.013 (+-149.456)        |           4001.810 (+-218.093)          |     1.157 (+-0.000)      |        16621.138 (+-713.320)
      Input (4, 3, 300, 400), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (600, 700)    |       3020.017 (+-147.314)      |        4743.164 (+-135.850)        |           6709.494 (+-281.025)          |     1.415 (+-0.000)      |         3015.602 (+-105.852)
      Input (4, 3, 300, 400), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (600, 700)        |      14456.688 (+-752.839)      |        5150.893 (+-201.571)        |           5737.315 (+-138.011)          |     1.114 (+-0.000)      |        14464.472 (+-720.027)

Times are in microseconds (us).
```

## PR

This PR improves arange decomp such that `arange(s, dtype=torch.float32)` removing extra dtype conversion to double:

Code:
```python
import torch

def func(x):
    s = x.shape[-1]
    a = torch.arange(s, dtype=torch.float32)
    return s + a

c_func = torch.compile(func)
out = c_func(torch.rand(10))
```

Graph on this PR:
```
 ===== Forward graph 0 =====
 /pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
    def forward(self):
        # File: check_arange_decomp.py:15 in func, code: a = torch.arange(s, dtype=torch.float32)
        iota: "i64[10]" = torch.ops.prims.iota.default(10, start = 0, step = 1, dtype = torch.int64, device = device(type='cpu'), requires_grad = False)
        mul: "i64[10]" = torch.ops.aten.mul.Tensor(iota, 1);  iota = None
        add: "i64[10]" = torch.ops.aten.add.Tensor(mul, 0);  mul = None
        convert_element_type: "f32[10]" = torch.ops.prims.convert_element_type.default(add, torch.float32);  add = None

        # File: check_arange_decomp.py:16 in func, code: return s + a
        add_1: "f32[10]" = torch.ops.aten.add.Tensor(convert_element_type, 10);  convert_element_type = None
        return (add_1,)

 ===== AFTER POST GRAD =====
 /pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
    def forward(self):
        # File: check_arange_decomp.py:16 in func, code: a = torch.arange(s, dtype=torch.float32)
        iota: "i64[10]" = torch.ops.prims.iota.default(10, start = 0, step = 1, dtype = torch.int64, device = device(type='cpu'), requires_grad = False)
        mul: "i64[10]" = torch.ops.aten.mul.Tensor(iota, 1);  iota = None
        add: "i64[10]" = torch.ops.aten.add.Tensor(mul, 0);  mul = None
        convert_element_type: "f32[10]" = torch.ops.prims.convert_element_type.default(add, torch.float32);  add = None

        # File: check_arange_decomp.py:17 in func, code: return s + a
        add_1: "f32[10]" = torch.ops.aten.add.Tensor(convert_element_type, 10);  convert_element_type = None
        return (add_1,)
```
and C++ on this PR:
```c++
extern "C" void kernel(float* out_ptr0)
{
    {
        #pragma GCC ivdep
        for(long x0=static_cast<long>(0L); x0<static_cast<long>(10L); x0+=static_cast<long>(1L))
        {
            auto tmp0 = c10::convert<long>(x0);
            auto tmp1 = c10::convert<float>(tmp0);
            auto tmp2 = static_cast<float>(10.0);
            auto tmp3 = decltype(tmp1)(tmp1 + tmp2);
            out_ptr0[static_cast<long>(x0)] = tmp3;
        }
    }
}
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121013
Approved by: https://github.com/peterbell10
2024-04-11 09:02:31 +00:00
William Wen
cbde0f048b [dynamo, 3.12] enable tests disabled due to missing dynamo 3.12 support (#123300)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123300
Approved by: https://github.com/jansel, https://github.com/malfet, https://github.com/zou3519
2024-04-05 20:13:17 +00:00
atalman
244b124bb8 Add linux cpu test for 3.12 (#117853)
This is continuation of work: https://github.com/pytorch/pytorch/pull/113987

Co-authored-by: albanD <desmaison.alban@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117853
Approved by: https://github.com/albanD
2024-02-14 20:52:23 +00:00
Mengwei Liu
1e4b408b02 [decomp] Add tests for different dtypes to SDPA decomposition (#119239)
Summary: As titled. Skipping torch.bfloat16 because for some reason the
difference is 0.01.

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119239
Approved by: https://github.com/drisspg
2024-02-06 11:17:07 +00:00
Elias Ellison
e87ac82c98 Fix missing default dim param in weight norm interface decomp (#118762)
Fix for https://github.com/pytorch/pytorch/issues/118742

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118762
Approved by: https://github.com/ezyang, https://github.com/shunting314
2024-01-31 22:10:10 +00:00
Digant Desai
e2830e6328 [PyTorch] SDPA decomp: actually use attn_mask (#117579)
Summary: Need to pass this along

Test Plan:
```
cd ~/fbsource/fbcode/executorch/backends/xnnpack/test
buck test fbcode//mode/dev-nosan :test_xnnpack_ops -- test_fp32_sdpa
buck run fbcode//mode/dev-nosan :test_xnnpack_models -- executorch.backends.xnnpack.test.models.llama2_et_example.TestLlama2ETExample.test_fp32
```

Reviewed By: larryliu0820

Differential Revision: D52812369

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117579
Approved by: https://github.com/larryliu0820
2024-01-17 10:26:43 +00:00
Aaron Orenstein
638f85fd67 Add default parameters to rrelu_with_noise() (#117141)
Summary:
rrelu_with_noise() was listed as having default parameters in the schema but the
actual code definition didn't have them.

The failing example was calling rrelu() which DOES have default parameters and
it passes those defaulted values to C++. Under the covers the C code was calling
the python version of rrelu_with_noise().

Although the C++ code was passing all the values to the python version of
rrelu_with_noise() the pytorch C++ -> Python dispatch code looks at the schema
and strips any parameters which match the schema's listed defaults so if the
schema shows defaults that aren't in the code it will be a problem.

Test Plan:
I added a unit test for this specific case. It would probably be better to write
a more general one to validate all the ops against their schemas - but I haven't
learned enough about the test harness to do that yet.

Fixes #115811

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117141
Approved by: https://github.com/yanboliang, https://github.com/oulgen
2024-01-12 05:32:13 +00:00
Mengwei Liu
8783fe9cf3 [export] Modify SDPA decomposition to decompose _scaled_dot_product_flash_attention_for_cpu (#117097)
Summary: As titled. #115913 added
`_scaled_dot_product_flash_attention_for_cpu` and the export result of
`scaled_dot_product_attention` includes this op. Adding this
decomposition so that it's being decomposed the same way as
`_scaled_dot_product_attention_math`.

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117097
Approved by: https://github.com/lezcano
2024-01-10 23:46:14 +00:00
Elias Ellison
d6540038c0 Fix 0-dim Index in Index Copy decomp (#117065)
Fix for https://github.com/pytorch/pytorch/issues/115931

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117065
Approved by: https://github.com/jansel, https://github.com/shunting314
2024-01-10 22:13:43 +00:00
rzou
3477a2ee03 unMarkDynamoStrictTest on OpInfo-based tests (#115856)
These take too long to run under strict mode. We'll worry about them
later. Note that these decorators don't do anything yet (unless we flip
the default from non-strict to strict).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115856
Approved by: https://github.com/voznesenskym
ghstack dependencies: #115845, #115855
2023-12-15 01:22:31 +00:00
atalman
ba4285bd9e Deprecate primTorch module, replace it with decompositions in module Owners (#114754)
Context: pt2 oncall is revamping its labeling system. One of the guidelines is to remove duplicate labeling in our system. Both primTorch and decomposition labels are referring to the same thing. primTorch was the legacy name (and we no longer have a primTorch project), so using decomposition as the label name makes more sense.

Right now, the only open issues that use "module: primTorch" are the ones generated by the DISABLED bots. Once we replace the label in the bot, we can safely remove the primTorch label.

Here an example of the issue that has primTorch label :
https://github.com/pytorch/pytorch/issues/112719

Torchbot uses following logic to auto extract module owners:
https://github.com/pytorch/test-infra/blob/main/torchci/pages/api/flaky-tests/disable.ts#L391

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114754
Approved by: https://github.com/huydhn
2023-11-29 18:27:20 +00:00
Mengwei Liu
5506b9db43 [decomp] Fix _scaled_dot_product_flash_attention decomposition bug (#113102)
For `_scaled_dot_product_flash_attention` we don't have

`Tensor? attn_mask=None`

but `scaled_dot_product_attention` has. In the original decomp there's a
mixup where I added this argument to
`_scaled_dot_product_flash_attention`.

Fix it so that `_scaled_dot_product_flash_attention` is being decomposed correctly.

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113102
Approved by: https://github.com/ezyang
2023-11-08 21:47:37 +00:00
Han Qi
5a6f8014c4 Add a decomposition for _weight_norm_interface. (#112193)
Fixes #112086

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112193
Approved by: https://github.com/ezyang
2023-11-01 19:51:11 +00:00
Peter Bell
66c32d099a Use pytree.arg_tree_leaves everywhere (#112394)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112394
Approved by: https://github.com/lezcano
ghstack dependencies: #112391, #112392, #112393
2023-10-31 15:57:06 +00:00
Peter Bell
bbd5b935e4 Use pytree.tree_leaves everywhere (#112324)
This changes all the instances I could find of `tree_flatten(...)[0]` or
`x, _ = tree_flatten` to use `tree_leaves`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112324
Approved by: https://github.com/lezcano
ghstack dependencies: #112327, #112323
2023-10-30 03:39:04 +00:00
Nikita Shulga
4f0cf1e1ff Mark more decomp tests as slow (#111524)
Something is broken with automatic slow detection, so let's do it manually

Those tests were previously classified as slow, see:
```
test_decomp.py::TestDecompCUDA::test_quick_core_backward_baddbmm_cuda_float64 SKIPPED [0.0003s] (test is slow; run with PYTORCH_TEST_WITH_SLOW to enable test) [ 53%]
test_decomp.py::TestDecompCUDA::test_quick_core_backward_clamp_max_cuda_float64 SKIPPED [0.0002s] (test is slow; run with PYTORCH_TEST_WITH_SLOW to enable test) [ 53%]
test_decomp.py::TestDecompCUDA::test_quick_core_backward_clamp_min_cuda_float64 SKIPPED [0.0002s] (test is slow; run with PYTORCH_TEST_WITH_SLOW to enable test) [ 53%]
```
from https://ossci-raw-job-status.s3.amazonaws.com/log/17792633247

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111524
Approved by: https://github.com/kit1980, https://github.com/izaitsevfb, https://github.com/huydhn
2023-10-19 02:29:59 +00:00
Nikita Shulga
16cb3bdd57 Skip test_quick_core_backward_baddbmm_cuda_float64 (#111493)
As its painfully slow (10+ min on A100):
```shell
$ time python3 test_decomp.py -v -k test_quick_core_backward_baddbmm_cuda_float64
Fail to import hypothesis in common_utils, tests are not derandomized
test_quick_core_backward_baddbmm_cuda_float64 (__main__.TestDecompCUDA) ... ok

----------------------------------------------------------------------
Ran 1 test in 897.523s

OK

real	15m4.773s
user	15m0.207s
sys	0m6.492s
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111493
Approved by: https://github.com/clee2000, https://github.com/huydhn
2023-10-18 20:09:14 +00:00
PyTorch MergeBot
98c329b19e Revert "[core ATen IR] Add decompositions for max, min, var_mean (#110906)"
This reverts commit 9606cda64e.

Reverted https://github.com/pytorch/pytorch/pull/110906 on behalf of https://github.com/SS-JIA due to Breaks internal CI ([comment](https://github.com/pytorch/pytorch/pull/110906#issuecomment-1757490740))
2023-10-11 11:41:21 +00:00
SS-JIA
9606cda64e [core ATen IR] Add decompositions for max, min, var_mean (#110906)
## Context

Add decompositions for `aten.max`, `aten.min`, and `aten.var_mean`. These operators follow a pattern of returning a tuple of outputs from two component operators:

```
aten.max(x) -> return aten.amax(x), aten.argmax(x)
aten.min(x) -> return aten.amin(x), aten.argmin(x)
aten.var_mean(x) -> return aten.var(x), aten.mean(x)
```

For `var_mean`, the `refs` implementation was doing something similar, so I changed it to call `torch.` ops instead like was done for other `refs` implementations previously. cc: @peterbell10 @lezcano

Note that Inductor lowers all these directly, so they are excluded from the Inductor decomp table.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110906
Approved by: https://github.com/manuelcandales
2023-10-11 00:06:24 +00:00
cdzhan
7cc0020a80 [decomp] Fix different return type in threshold_backward vs. eager (#110689)
due to type promotion with floating point scalar in decompositions.py

Fixes part of #100838

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110689
Approved by: https://github.com/ezyang
2023-10-06 20:59:58 +00:00
SS-JIA
9928c10e71 [core IR] Add glu as a core decomposition (#110043)
## Context

Add the decomposition for `aten.glu` as a decomposition in the core ATen decomposition table. Don't use it in the Inductor decomposition table since Inductor has a lowering for it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110043
Approved by: https://github.com/peterbell10, https://github.com/lezcano
ghstack dependencies: #110046
2023-09-27 00:23:05 +00:00
Li-Huai (Allan) Lin
b2cba439b4 Introduce Tensor overload to linspace and logspace (#104889)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104889
Approved by: https://github.com/zou3519
ghstack dependencies: #107958
2023-09-11 23:30:40 +00:00
PyTorch MergeBot
a7f5abeade Revert "Introduce Tensor overload to linspace and logspace (#104889)"
This reverts commit 57e5239321.

Reverted https://github.com/pytorch/pytorch/pull/104889 on behalf of https://github.com/clee2000 due to sorry have to revert this to revert https://github.com/pytorch/pytorch/pull/107958 ([comment](https://github.com/pytorch/pytorch/pull/104889#issuecomment-1714305768))
2023-09-11 17:33:48 +00:00
Li-Huai (Allan) Lin
57e5239321 Introduce Tensor overload to linspace and logspace (#104889)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104889
Approved by: https://github.com/zou3519
ghstack dependencies: #107958
2023-09-11 15:29:39 +00:00
rzou
0e4752bafc Allow registering decomps for HigherOrderOp; add decomp for out_dtype (#108080)
We allow registering decomps for HigherOrderOp via the existing decomp
mechanisms:
- I refactored those APIs to accept torch._ops.OperatorBase, which is the base
  class for torch.ops.HigherOrderOperator and torch.ops.OpOverload
- HigherOrderOps must directly call maybe_handle_decomp in their
  ProxyTorchDispatchMode handling in order to resolve decompositions. We
  can change this in the future so that they do not need to do this.

Next, we add an inductor decomp for out_dtype. This decomp shouldn't be
generally available because we want to preserve out_dtype to the backend
for other use cases (i.e. executorch).

Test Plan:
- new tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108080
Approved by: https://github.com/HDCharles
2023-08-31 03:15:38 +00:00
Nikita Karetnikov
77f080ee29 [pt2] test if core decomps are differentiable (#107241)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107241
Approved by: https://github.com/ezyang
2023-08-18 20:47:58 +00:00
lezcano
2c5f96deac [Inductor] Make softshrink composite implicit (#107052)
The backward is pretty much equivalent to the one we had written

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107052
Approved by: https://github.com/peterbell10
ghstack dependencies: #107038, #107039, #107051
2023-08-14 21:01:50 +00:00
lezcano
3b1254e800 Make hardshrink's decomp composite implicit (#107039)
The generated code is the same
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107039
Approved by: https://github.com/peterbell10
ghstack dependencies: #107038
2023-08-14 21:01:50 +00:00
Sam Larsen
e165938853 Implement decomposition for aten.rrelu_with_noise (#106812)
Test Plan:
* Primarily, added new test in test/test_decomp.py
* Updated existing tests, e.g., to NOT expect failure

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106812
Approved by: https://github.com/eellison
2023-08-11 19:18:29 +00:00
Kshiteej K
a899333ffc fix: nll_loss batch rule with negative ignore_idx (#106118)
We use python decompositions instead of writing our own for batching rules.

Fixes https://github.com/pytorch/pytorch/issues/105736

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106118
Approved by: https://github.com/lezcano, https://github.com/zou3519
2023-08-04 07:43:02 +00:00
Peter Bell
5c580a9846 [decomp] Add test tracking core ATen operators (#104262)
This adds an expect-test that finds the set of core ATen operators by
subtracting the operators with decomposition in core_aten_decompositions from the
set of all operators that have decompositions and could be decomposed.

This is useful because if you add a new decomposition but forget to add it to
the list of core decompositions, it will appear in the PR diff.

Also, by going through this list I have identified some operators where the
functional variant is decomposed, but not the inplace variant which must be an
oversight.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104262
Approved by: https://github.com/lezcano
2023-07-04 16:41:44 +00:00
Fuzzkatt
d805a53f1f disable tf32 for rnn tests and norm tests (#102005)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102005
Approved by: https://github.com/ngimel
2023-05-24 02:22:58 +00:00
Khushi
1aaf0396eb [reland][opinfo] empty_strided (#101782)
Follows #100223

Previous PR: #100890

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101782
Approved by: https://github.com/ezyang
2023-05-19 03:06:29 +00:00
PyTorch MergeBot
dfac4364c4 Revert "[opinfo] empty_strided (#100890)"
This reverts commit 01c7106580.

Reverted https://github.com/pytorch/pytorch/pull/100890 on behalf of https://github.com/PaliC due to broke test_ops.py slow test ([comment](https://github.com/pytorch/pytorch/pull/100890#issuecomment-1551903975))
2023-05-17 19:00:15 +00:00
Jiong Gong
788ff0623b [decomp] fix decomp of batch_norm when weight/bias is not flattened (#101059)
Fix https://github.com/pytorch/pytorch/issues/100970
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101059
Approved by: https://github.com/ezyang
2023-05-16 00:00:34 +00:00