Commit Graph

261 Commits

xinan.lin
16b37b309f [Inductor] Rename cpp_wrapper_cuda.py as cpp_wrapper_gpu.py (#135313)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135313
Approved by: https://github.com/jansel, https://github.com/desertfire
ghstack dependencies: #135312
2024-09-11 23:59:54 +00:00
xinan.lin
13ee85ca5e [Inductor] Generalize cuda cpp wrapper as common triton based GPU cpp wrapper, will be reused by xpu in next PR. (#135312)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135312
Approved by: https://github.com/jansel, https://github.com/desertfire, https://github.com/eellison
2024-09-11 23:59:54 +00:00
xinan.lin
ca16956b20 [Inductor] Generalize device guard codegen for cpp_wrapper mode. (#134761)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134761
Approved by: https://github.com/jansel, https://github.com/EikanWang
ghstack dependencies: #134693
2024-09-10 10:11:52 +00:00
Jason Ansel
eac5e12548 [inductor] Move LoopBody to its own file (#135257)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135257
Approved by: https://github.com/oulgen
2024-09-07 16:29:15 +00:00
leslie-fang-intel
2c7e314803 [Inductor][CPP] Fix the issue of view dtype (#135301)
**Summary**
Fixes https://github.com/pytorch/pytorch/issues/135160, a regression introduced by https://github.com/pytorch/pytorch/pull/134569 where the dtype of `to_dtype_bitcast` was incorrectly handled when using the scalarized implementation.

**Test Plan**
```
python -u -m pytest -s -v test/inductor/test_cpu_repro.py -k test_view_dtype
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135301
Approved by: https://github.com/jgong5, https://github.com/jansel
2024-09-06 23:36:44 +00:00
haozhe.zhu
f4641ca481 [Inductor] Remove VecChecker and fallback non-supported Vec op to Scalar impl with a for loop (#134569)
Fall back ops without vectorized support to a scalar implementation inside a for loop.

Example code:
```
cpp_fused_igammac_0 = async_compile.cpp_pybinding(['const double*', 'const double*', 'double*'], '''
#include "/tmp/torchinductor_root/z4/cz4j2mmotlx3z2b7u4fbjtdt4x6plhd67ljwzg5bk7ekv4xz6y7q.h"
extern "C"  void kernel(const double* in_ptr0,
                       const double* in_ptr1,
                       double* out_ptr0)
{
    {
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(48L); x0+=static_cast<int64_t>(8L))
        {
            auto tmp0 = at::vec::VectorizedN<double,2>::loadu(in_ptr0 + static_cast<int64_t>(x0), 8);
            auto tmp1 = in_ptr1[static_cast<int64_t>(0L)];
            auto tmp2 = at::vec::VectorizedN<double,2>(tmp1);
            auto tmp3 =
            [&]()
            {
                __at_align__ std::array<double, 8> tmpbuf0;
                tmp0.store(tmpbuf0.data(), 8);
                __at_align__ std::array<double, 8> tmpbuf1;
                tmp2.store(tmpbuf1.data(), 8);
                __at_align__ std::array<double, 8> tmpbuf_out;
                for (int i = 0; i < 8; i++)
                {
                    tmpbuf_out[i] = calc_igammac(tmpbuf0[i], tmpbuf1[i]);
                }
                return at::vec::VectorizedN<double, 2>::loadu(tmpbuf_out.data(), 8);
            }
            ()
            ;
            tmp3.store(out_ptr0 + static_cast<int64_t>(x0), 8);
        }
        #pragma omp simd simdlen(4)
        for(int64_t x0=static_cast<int64_t>(48L); x0<static_cast<int64_t>(50L); x0+=static_cast<int64_t>(1L))
        {
            auto tmp0 = in_ptr0[static_cast<int64_t>(x0)];
            auto tmp1 = in_ptr1[static_cast<int64_t>(0L)];
            auto tmp2 = calc_igammac(tmp0, tmp1);
            out_ptr0[static_cast<int64_t>(x0)] = tmp2;
        }
    }
}
''')

```

`frexp` is difficult to handle with the common `fallback` since it returns two `cse_var`s (see 2ba60a1618/torch/_inductor/codegen/cpp.py (L752-L766)), so we added a special function for it.
```
cpp_fused_frexp_0 = async_compile.cpp_pybinding(['const double*', 'double*', 'int32_t*'], '''
#include "/tmp/torchinductor_root/z4/cz4j2mmotlx3z2b7u4fbjtdt4x6plhd67ljwzg5bk7ekv4xz6y7q.h"
extern "C"  void kernel(const double* in_ptr0,
                       double* out_ptr0,
                       int32_t* out_ptr1)
{
    {
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(16L); x0+=static_cast<int64_t>(8L))
        {
            auto tmp0 = at::vec::VectorizedN<double,2>::loadu(in_ptr0 + static_cast<int64_t>(x0), 8);
            at::vec::Vectorized<int32_t> tmp1;
            at::vec::VectorizedN<double, 2> tmp2;
            [&]()
            {
                __at_align__ std::array<double, 8> tmpbuf;
                tmp0.store(tmpbuf.data(), 8);
                __at_align__ std::array<int32_t, 8> tmpbuf_exponent;
                __at_align__ std::array<double, 8> tmpbuf_mantissa;
                for (int i = 0; i < 8; i++)
                {
                    tmpbuf_mantissa[i] = std::frexp(tmpbuf[i], &tmpbuf_exponent[i]);
                }
                tmp1 = at::vec::Vectorized<int32_t>::loadu(tmpbuf_exponent.data(), 8);
                tmp2 = at::vec::VectorizedN<double, 2>::loadu(tmpbuf_mantissa.data(), 8);
            }
            ();
            tmp2.store(out_ptr0 + static_cast<int64_t>(x0), 8);
            tmp1.store(out_ptr1 + static_cast<int64_t>(x0), 8);
        }
        #pragma omp simd simdlen(4)
        for(int64_t x0=static_cast<int64_t>(16L); x0<static_cast<int64_t>(20L); x0+=static_cast<int64_t>(1L))
        {
            auto tmp0 = in_ptr0[static_cast<int64_t>(x0)];
            int32_t tmp1;
            auto tmp2 = std::frexp(tmp0, &tmp1);
            out_ptr0[static_cast<int64_t>(x0)] = tmp2;
            out_ptr1[static_cast<int64_t>(x0)] = tmp1;
        }
    }
}
''')
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134569
Approved by: https://github.com/jgong5, https://github.com/jansel
2024-08-31 11:19:57 +00:00
Rachel Guo
3965f11837 Minor type annotation updates following up D60954888 (#133382)
Summary: As title.

Test Plan:
CI

Ran lintrunner locally, but we may need to keep an eye out for more OSS linting issues if they come up.

Differential Revision: D61240900

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133382
Approved by: https://github.com/ColinPeppler
2024-08-14 21:36:42 +00:00
Oguz Ulgen
72d2dba992 Add None return type to init (#132335)
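For reference, the annotation pattern being applied looks like this (illustrative snippet, not code from the PR):

```python
class Widget:
    def __init__(self) -> None:  # explicit "-> None" return annotation on __init__
        self.size = 0
```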
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132335
Approved by: https://github.com/albanD
2024-08-01 15:26:45 +00:00
eellison
f32ab3b9e3 Migrate Inductor scheduler, dependencies, ir, and codegen/common to use OrderedSet (#130004)
Python's built-in set has non-deterministic iteration order. We recently ran into an internal failure that did not fail consistently.

See the repro here: P1453035092.

With these changes, it now fails consistently. As a follow-up, we could also consider adding a lint rule for uses of either set() or set literals.
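A small repro of the underlying non-determinism (`dict.fromkeys` stands in for an insertion-ordered set here; the PR's `OrderedSet` is Inductor's own utility):

```python
# Run this twice: with hash randomization enabled (the default for str keys),
# the plain set's iteration order can differ between interpreter runs, while
# the insertion-ordered container keeps a stable order.
items = ["scheduler", "dependencies", "ir", "codegen"]

print(list(set(items)))            # order may change across runs
print(list(dict.fromkeys(items)))  # always insertion order
```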

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130004
Approved by: https://github.com/oulgen
2024-08-01 04:37:15 +00:00
PyTorch MergeBot
784a6ec5a3 Revert "Migrate Inductor scheduler, dependencies, ir, and codegen/common to use OrderedSet (#130004)"
This reverts commit 13d744464f.

Reverted https://github.com/pytorch/pytorch/pull/130004 on behalf of https://github.com/clee2000 due to broke lint [GH job link](https://github.com/pytorch/pytorch/actions/runs/10183945999/job/28170099930) [HUD commit link](13d744464f) probably a landrace, the base is 21 hours old ([comment](https://github.com/pytorch/pytorch/pull/130004#issuecomment-2260946562))
2024-07-31 16:49:21 +00:00
eellison
13d744464f Migrate Inductor scheduler, dependencies, ir, and codegen/common to use OrderedSet (#130004)
Python's built-in set has non-deterministic iteration order. We recently ran into an internal failure that did not fail consistently.

See the repro here: P1453035092.

With these changes, it now fails consistently. As a follow-up, we could also consider adding a lint rule for uses of either set() or set literals.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130004
Approved by: https://github.com/oulgen
2024-07-31 16:22:11 +00:00
leslie-fang-intel
f8e4060484 [Inductor][CPP] Enhance cppcsevar data type deduce (#130827)
**Summary**
Previously, we used `data_type_propagation` at the start of `codegen` to deduce the data type of each node and saved this information in `node.meta[OptimizationContext.key]`. Then, we used this node metadata to update the cppcsevar data type in `update_on_args`. However, this method is not always correct. For example, in the codegen of `indirect_indexing` (see [here](096dc444ce/torch/_inductor/codegen/common.py (L1844))), we insert nodes on the fly and reuse the node of `indirect_indexing` to set the `cppcsevar` data type. In this PR, we enhance the `cppcsevar` data type deduction:

- We will deduce the `cppcsevar` data type in `update_on_args` by reusing the code in `data_type_propagation`.

- To align the data type of scalar and vector variables, we previously always cast the scalar to the vector's data type. This caused a data type misalignment between `codegen` and `data_type_propagation`. We should use the same data type promotion logic to align the data types of scalar and vector variables.
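A rough sketch of the promotion idea in the second bullet, using `torch.promote_types` as a stand-in for the shared promotion logic (this is not the actual cppcsevar code):

```python
import torch

def aligned_dtype(scalar_dtype: torch.dtype, vector_dtype: torch.dtype) -> torch.dtype:
    # Instead of always casting the scalar to the vector's dtype, apply the
    # same promotion table that data type propagation uses, so codegen and
    # propagation agree on the result type.
    return torch.promote_types(scalar_dtype, vector_dtype)

print(aligned_dtype(torch.int64, torch.float32))  # torch.float32
print(aligned_dtype(torch.int32, torch.int64))    # torch.int64
```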

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130827
Approved by: https://github.com/jgong5, https://github.com/jansel
2024-07-30 02:51:31 +00:00
eellison
5772c13f56 Dont wrap negative indexing in scatter reduce (#131503)
Fix for https://github.com/pytorch/pytorch/issues/131321

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131503
Approved by: https://github.com/shunting314
2024-07-24 04:01:32 +00:00
eellison
16a2a1aad3 Annotate graph.py (#131400)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131400
Approved by: https://github.com/shunting314
2024-07-23 07:04:12 +00:00
Peter Bell
27c2a0d63b [inductor] Separate Buffer and Operation into two concepts (#130831)
Resubmit of #128893

Currently a buffer represents both a tensor with physical storage and a
computation that produces the tensor as a result.

This PR attempts to split these into two different concepts in the scheduler.
This should allow us to have multiple outputs from a single operation.

Differential Revision: [D59876059](https://our.internmc.facebook.com/intern/diff/D59876059)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130831
Approved by: https://github.com/lezcano
2024-07-20 02:05:07 +00:00
Xuehai Pan
f0075c179b Pin sympy >= 1.13.0 (#130895)
------

The opposite of #130836. Pin `sympy >= 1.13.0` for Python >= 3.9 and `sympy == 1.12.1` for Python 3.8.

- #130836

See the PR description of #130836 for more details.

`sympy` 1.13.0 introduces some breaking changes which break our tests. More specifically:

- Ref [Backwards compatibility breaks and deprecations](https://github.com/sympy/sympy/wiki/release-notes-for-1.13.0#backwards-compatibility-breaks-and-deprecations)

> BREAKING CHANGE: Float and Integer/Rational no longer compare equal with a == b. From now on Float(2.0) != Integer(2). Previously expressions involving Float would compare unequal e.g. x*2.0 != x*2 but an individual Float would compare equal to an Integer. In SymPy 1.7 a Float will always compare unequal to an Integer even if they have the same "value". Use sympy.numbers.int_valued(number) to test if a number is a concrete number with no decimal part. ([#25614](https://github.com/sympy/sympy/pull/25614) by [@smichr](https://github.com/smichr))
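A minimal check of the behavior change quoted above (assuming sympy >= 1.13 is installed):

```python
from sympy import Float, Integer

# On sympy 1.12.x this printed True; from 1.13.0 a Float never compares
# equal to an Integer, even when the values match.
print(Float(2.0) == Integer(2))  # False on sympy >= 1.13
```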

`sympy >= 1.13.0` is required to enable Python 3.13 support. This should be part of #130689.

- #130689

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130895
Approved by: https://github.com/ezyang
2024-07-20 00:59:24 +00:00
Isuru Fernando
b7d2abd766 Fix vectorized ops.masked (#130130)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130130
Approved by: https://github.com/jgong5, https://github.com/lezcano
2024-07-17 14:55:11 +00:00
chilli
f9f85bfc0b [Inductor] FlexAttention supports partial masking (#130415) (#130626)
This is the new version of https://github.com/pytorch/pytorch/pull/130415

Updated test script: https://gist.github.com/yanboliang/7c34a82df611d4ea8869cb9e041bfbfc
Updated perf numbers:
```
(pt) [ybliang@devgpu002.ash8 ~/local/debug]$ CUDA_VISIBLE_DEVICES=4 python debug7.py
fwd speedup: 0.7166695598192317
bwd speedup: 0.7142133867805904
(pt) [ybliang@devgpu002.ash8 ~/local/debug]$ CUDA_VISIBLE_DEVICES=4 python debug7.py --partial-mask
fwd speedup: 0.8428246087169973
bwd speedup: 0.8486261278030254
```
Approved by: https://github.com/Chillee

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130626
Approved by: https://github.com/drisspg, https://github.com/yanboliang
2024-07-14 00:37:26 +00:00
Xuehai Pan
973037be6a [BE][Easy] apply autofix for ruff rules unnecessary-collection-call (C408): list() / tuple() / dict() (#130199)
This PR changes the empty collection factory call to Python literals:

- `list()` -> `[]`
- `tuple()` -> `()`
- `dict()` -> `{}`

The Python literals are more performant and safer. For example, the bytecode for building an empty dictionary:

```bash
$ python3 -m dis - <<EOS
import collections

d1 = {}
d2 = dict()

dict = collections.OrderedDict
d3 = dict()
EOS
```

```text
  0           0 RESUME                   0

  1           2 LOAD_CONST               0 (0)
              4 LOAD_CONST               1 (None)
              6 IMPORT_NAME              0 (collections)
              8 STORE_NAME               0 (collections)

  3          10 BUILD_MAP                0
             12 STORE_NAME               1 (d1)

  4          14 PUSH_NULL
             16 LOAD_NAME                2 (dict)
             18 CALL                     0
             26 STORE_NAME               3 (d2)

  6          28 LOAD_NAME                0 (collections)
             30 LOAD_ATTR                8 (OrderedDict)
             50 STORE_NAME               2 (dict)

  7          52 PUSH_NULL
             54 LOAD_NAME                2 (dict)
             56 CALL                     0
             64 STORE_NAME               5 (d3)
             66 RETURN_CONST             1 (None)
```

The dict literal `{}` only has one bytecode `BUILD_MAP`, while the factory call `dict()` has three `PUSH_NULL + LOAD_NAME + CALL`. Also, the factory call is not safe if users override the `dict` name in `locals` or `globals` (see the example of replacing with `OrderedDict` above).
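A quick illustration of that shadowing hazard (a made-up snippet, not from the PR):

```python
import collections

d1 = dict()                      # builds a plain dict
dict = collections.OrderedDict   # shadowing the builtin name...
d2 = dict()                      # ...now silently builds an OrderedDict
d3 = {}                          # the literal always builds a plain dict

print(type(d1).__name__, type(d2).__name__, type(d3).__name__)
# dict OrderedDict dict
```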

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130199
Approved by: https://github.com/malfet
2024-07-11 17:30:28 +00:00
Richard Zou
edf273edf4 Revert some PRs (#130303)
Summary:
Revert https://github.com/pytorch/pytorch/pull/129346 through
https://github.com/pytorch/pytorch/pull/128893

For S430832

Test Plan: Tests

Differential Revision: D59503843

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130303
Approved by: https://github.com/bdhirsh
2024-07-09 14:46:00 +00:00
chilli
cd683212a2 Fix indexing twice with score_mod (#130224)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130224
Approved by: https://github.com/yanboliang
ghstack dependencies: #130160, #130106
2024-07-08 18:15:35 +00:00
peaceorwell
9983242c8e [inductor] support adding a new inductor backend using PrivateUse1 (#129953)
Add handling of custom devices registered via PrivateUse1 in the init_backend_registration() function.

Fixes #129952

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129953
Approved by: https://github.com/jansel
2024-07-06 21:15:40 +00:00
Jason Ansel
4fc9157e90 [halide-backend] Disable split reductions for Halide (#129320)
In theory Halide doesn't need the split reduction stuff we do for Triton since it can generate multiple kernels.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129320
Approved by: https://github.com/shunting314, https://github.com/eellison
ghstack dependencies: #129321
2024-07-03 05:56:40 +00:00
Peter Bell
fb078c20c1 [inductor] Separate Buffer and Operation into two concepts (#128893)
Currently a buffer represents both a tensor with physical storage and a
computation that produces the tensor as a result.

This PR attempts to split these into two different concepts in the scheduler.
This should allow us to have multiple outputs from a single operation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128893
Approved by: https://github.com/lezcano
2024-07-02 23:49:57 +00:00
Aaron Gokaslan
6c2a8b6b38 [Ez][BE]: Enable new stable ruff rules (#129825)
Applies a batch of new ruff lint rules that are now stable. Some of these improve efficiency or readability. Since I already did passes on the codebase for these rules while they were in preview, there should be relatively few changes; this is mostly future hardening.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129825
Approved by: https://github.com/XuehaiPan, https://github.com/jansel, https://github.com/malfet
2024-07-02 14:47:10 +00:00
PyTorch MergeBot
e385bf8ef8 Revert "[halide-backend] Disable split reductions for Halide (#129320)"
This reverts commit a18eb651d3.

Reverted https://github.com/pytorch/pytorch/pull/129320 on behalf of https://github.com/jeanschmidt due to This PR is breaking internal builds, please check comments on it D59204360 ([comment](https://github.com/pytorch/pytorch/pull/129320#issuecomment-2200351678))
2024-07-01 14:44:35 +00:00
leslie-fang-intel
3fec0efd34 [Inductor][CPP] Support vectorization of bitwise fn (#129733)
**Summary**
When checking the vectorization status across three test suites, we found that some operators disabled vectorization with the message `Disabled vectorization: op: bitwise_and`. In this PR, we add vectorization support for six bitwise functions.

In this PR, we also remove `bitwise_xor` from the `ops_to_bool` list, which sets the output data type to bool during data type propagation. This appears to be wrong: according to https://pytorch.org/docs/stable/generated/torch.bitwise_xor.html, `bitwise_xor` should return the same integral data type as its inputs, and the test case `test_bitwise3` failed because of this issue.
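As a quick sanity check of the expected dtype behavior (plain PyTorch, independent of Inductor):

```python
import torch

a = torch.tensor([0b1010, 0b0110], dtype=torch.int32)
b = torch.tensor([0b0011, 0b0101], dtype=torch.int32)

# bitwise_xor on integral inputs keeps the integral dtype rather than
# returning bool, which is why forcing bool in data type propagation is wrong.
print(torch.bitwise_xor(a, b).dtype)  # torch.int32
```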

**Test Plan**
```
python -u -m pytest -s -v test/inductor/test_cpu_repro.py -k test_vec_bitwise
python -u -m pytest -s -v test/inductor/test_torchinductor.py -k test_bitwise3
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129733
Approved by: https://github.com/jgong5, https://github.com/Skylion007
2024-06-29 17:25:27 +00:00
Jason Ansel
a18eb651d3 [halide-backend] Disable split reductions for Halide (#129320)
In theory Halide doesn't need the split reduction stuff we do for Triton since it can generate multiple kernels.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129320
Approved by: https://github.com/shunting314, https://github.com/eellison
ghstack dependencies: #126417, #129025, #129026, #127506, #129036
2024-06-29 14:06:28 +00:00
Jason Ansel
b93bf55b6a [halide-backend] Add GPU support (#127506)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127506
Approved by: https://github.com/shunting314, https://github.com/eellison
ghstack dependencies: #126417, #129025, #129026
2024-06-29 14:06:21 +00:00
Jason Ansel
da5f37515e [halide-backend] Generate standalone runtime (#129025)
This puts the Halide runtime in a global shared object rather than copying it into each kernel. Having many copies of the runtime causes many issues with CUDA.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129025
Approved by: https://github.com/shunting314, https://github.com/eellison
ghstack dependencies: #126417
2024-06-29 14:06:12 +00:00
Jason Ansel
e34b7e6af3 [halide-backend] Initial implementation of HalideKernel and HalideScheduling (#126417)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126417
Approved by: https://github.com/shunting314, https://github.com/eellison
2024-06-29 14:06:08 +00:00
Peter Bell
90d5a6f001 [inductor] Add lowering and codegen for aten.sort (#128458)
Closes #125633

Benchmarks:
| Shape       | dim | stable | compiled | eager   | speedup |
|-------------|-----|--------|----------|---------|---------|
| (256, 4096) | 0   | False  | 0.73 ms  | 1.26 ms | 1.7     |
| (256, 4096) | 0   | True   | 0.75 ms  | 1.27 ms | 1.7     |
| (4096, 256) | 1   | False  | 0.20 ms  | 0.73 ms | 3.7     |
| (4096, 256) | 1   | True   | 0.21 ms  | 0.73 ms | 3.5     |
| (255, 4096) | 0   | False  | 1.05 ms  | 1.48 ms | 1.4     |
| (255, 4096) | 0   | True   | 1.03 ms  | 1.47 ms | 1.4     |
| (4096, 255) | 1   | False  | 0.52 ms  | 0.98 ms | 1.9     |
| (4096, 255) | 1   | True   | 0.54 ms  | 1.00 ms | 1.9     |
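A minimal timing sketch for reproducing numbers like these (my own harness, assuming a CUDA device; the PR's actual benchmark script is not shown):

```python
import torch

def bench(fn, x, iters=10):
    # Crude CUDA timing helper.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    fn(x)  # warm-up (and compilation, for the compiled variant)
    start.record()
    for _ in range(iters):
        fn(x)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

x = torch.randn(256, 4096, device="cuda")
eager = lambda t: torch.sort(t, dim=0, stable=False)
compiled = torch.compile(eager)

print(f"eager:    {bench(eager, x):.2f} ms")
print(f"compiled: {bench(compiled, x):.2f} ms")
```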

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128458
Approved by: https://github.com/lezcano, https://github.com/eellison
2024-06-26 01:36:39 +00:00
Jiong Gong
533c4190f9 [inductor][cpp] support nested kernel with indirect indexing (#129223)
This PR makes sure the current kernel is used for generating CSE variables when nested kernel codegen is involved, e.g., when a nested CppKernel is used to generate the epilogue of a CppTemplateKernel. Without the fix, an epilogue with indirect indexing would fail to run.

```
pytest -k test_linear_with_embedding_bias_False_cpu test_cpu_select_algorithm.py
```

Epilogue code Before:
```c++
                {
                    #pragma GCC ivdep
                    for(long x0=static_cast<long>(0L); x0<static_cast<long>(m_end + ((-1L)*m_start)); x0+=static_cast<long>(1L))
                    {
                        for(long x1=static_cast<long>(0L); x1<static_cast<long>(16L*(c10::div_floor_integer(N0, 16L))); x1+=static_cast<long>(16L))
                        {
                            auto tmp0 = in_ptr2[static_cast<long>(m_start + x0)];
                            auto tmp11 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<long>(x1 + (N0*x0)), 16);
                            auto tmp1 = 64L;
                            auto tmp2 = c10::convert<int64_t>(tmp1);
                            auto tmp3 = decltype(tmp0)(tmp0 + tmp2);
                            auto tmp4 = tmp0 ? tmp3 : tmp0;
                            auto tmp5 = decltype(tmp4)(tmp4 + tmp2);
                            auto tmp6 = tmp1 ? tmp5 : tmp4;
                            auto tmp7 = tmp6;
                            auto tmp8 = c10::convert<int64_t>(tmp7);
                            TORCH_CHECK((0 <= tmp8) & (tmp8 < 64L), "index out of bounds: 0 <= tmp8 < 64L");
                            auto tmp10 = at::vec::Vectorized<float>::loadu(in_ptr3 + static_cast<long>(n_start + x1 + (384L*tmp6)), 16);
                            auto tmp12 = (tmp11);
                            auto tmp13 = tmp10 + tmp12;
                            tmp13.store(Y + static_cast<long>(n_start + x1 + (384L*m_start) + (384L*x0)));
                        }
                        #pragma omp simd simdlen(8)
                        for(long x1=static_cast<long>(16L*(c10::div_floor_integer(N0, 16L))); x1<static_cast<long>(N0); x1+=static_cast<long>(1L))
                        {
                            auto tmp0 = in_ptr2[static_cast<long>(m_start + x0)];
                            auto tmp11 = local_acc_buf[static_cast<long>(x1 + (N0*x0))];
                            auto tmp1 = 64L;
                            auto tmp2 = c10::convert<int64_t>(tmp1);
                            auto tmp3 = decltype(tmp0)(tmp0 + tmp2);
                            auto tmp4 = tmp0 ? tmp3 : tmp0;
                            auto tmp5 = decltype(tmp4)(tmp4 + tmp2);
                            auto tmp6 = tmp1 ? tmp5 : tmp4;
                            auto tmp7 = tmp6;
                            auto tmp8 = c10::convert<int64_t>(tmp7);
                            TORCH_CHECK((0 <= tmp8) & (tmp8 < 64L), "index out of bounds: 0 <= tmp8 < 64L");
                            TORCH_CHECK((0 <= tmp8) & (tmp8 < 64L), "index out of bounds: 0 <= tmp8 < 64L");
                            auto tmp10 = in_ptr3[static_cast<long>(n_start + x1 + (384L*tmp6))];
                            auto tmp12 = c10::convert<float>(tmp11);
                            auto tmp13 = decltype(tmp10)(tmp10 + tmp12);
                            Y[static_cast<long>(n_start + x1 + (384L*m_start) + (384L*x0))] = tmp13;
                        }
                    }
                }
```

Epilogue code After:
```c++
                {
                    #pragma GCC ivdep
                    for(long x0=static_cast<long>(0L); x0<static_cast<long>(m_end + ((-1L)*m_start)); x0+=static_cast<long>(1L))
                    {
                        for(long x1=static_cast<long>(0L); x1<static_cast<long>(16L*(c10::div_floor_integer(N0, 16L))); x1+=static_cast<long>(16L))
                        {
                            auto tmp0 = in_ptr2[static_cast<long>(m_start + x0)];
                            auto tmp13 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<long>(x1 + (N0*x0)), 16);
                            auto tmp1 = 64L;
                            auto tmp2 = c10::convert<int64_t>(tmp1);
                            auto tmp3 = decltype(tmp0)(tmp0 + tmp2);
                            auto tmp4 = tmp0 < 0;
                            auto tmp5 = tmp4 ? tmp3 : tmp0;
                            auto tmp6 = decltype(tmp5)(tmp5 + tmp2);
                            auto tmp7 = tmp5 < 0;
                            auto tmp8 = tmp7 ? tmp6 : tmp5;
                            auto tmp9 = tmp8;
                            auto tmp10 = c10::convert<int64_t>(tmp9);
                            TORCH_CHECK((0 <= tmp10) & (tmp10 < 64L), "index out of bounds: 0 <= tmp10 < 64L");
                            auto tmp12 = at::vec::Vectorized<float>::loadu(in_ptr3 + static_cast<long>(n_start + x1 + (384L*tmp8)), 16);
                            auto tmp14 = (tmp13);
                            auto tmp15 = tmp12 + tmp14;
                            tmp15.store(Y + static_cast<long>(n_start + x1 + (384L*m_start) + (384L*x0)));
                        }
                        #pragma omp simd simdlen(8)
                        for(long x1=static_cast<long>(16L*(c10::div_floor_integer(N0, 16L))); x1<static_cast<long>(N0); x1+=static_cast<long>(1L))
                        {
                            auto tmp0 = in_ptr2[static_cast<long>(m_start + x0)];
                            auto tmp13 = local_acc_buf[static_cast<long>(x1 + (N0*x0))];
                            auto tmp1 = 64L;
                            auto tmp2 = c10::convert<int64_t>(tmp1);
                            auto tmp3 = decltype(tmp0)(tmp0 + tmp2);
                            auto tmp4 = tmp0 < 0;
                            auto tmp5 = tmp4 ? tmp3 : tmp0;
                            auto tmp6 = decltype(tmp5)(tmp5 + tmp2);
                            auto tmp7 = tmp5 < 0;
                            auto tmp8 = tmp7 ? tmp6 : tmp5;
                            auto tmp9 = tmp8;
                            auto tmp10 = c10::convert<int64_t>(tmp9);
                            TORCH_CHECK((0 <= tmp10) & (tmp10 < 64L), "index out of bounds: 0 <= tmp10 < 64L");
                            TORCH_CHECK((0 <= tmp10) & (tmp10 < 64L), "index out of bounds: 0 <= tmp10 < 64L");
                            auto tmp12 = in_ptr3[static_cast<long>(n_start + x1 + (384L*tmp8))];
                            auto tmp14 = c10::convert<float>(tmp13);
                            auto tmp15 = decltype(tmp12)(tmp12 + tmp14);
                            Y[static_cast<long>(n_start + x1 + (384L*m_start) + (384L*x0))] = tmp15;
                        }
                    }
                }
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129223
Approved by: https://github.com/leslie-fang-intel, https://github.com/jansel
2024-06-25 05:21:00 +00:00
PyTorch MergeBot
1a54bb0f96 Revert "[halide-backend] Initial implementation of HalideKernel and HalideScheduling (#126417)"
This reverts commit 4f9399bd0d.

Reverted https://github.com/pytorch/pytorch/pull/126417 on behalf of https://github.com/fbgheith due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/126417#issuecomment-2186999121))
2024-06-24 16:50:15 +00:00
PyTorch MergeBot
063facf352 Revert "[halide-backend] Generate standalone runtime (#129025)"
This reverts commit 10c64c3b49.

Reverted https://github.com/pytorch/pytorch/pull/129025 on behalf of https://github.com/fbgheith due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/129025#issuecomment-2186995467))
2024-06-24 16:47:25 +00:00
Jason Ansel
10c64c3b49 [halide-backend] Generate standalone runtime (#129025)
This puts the Halide runtime in a global shared object rather than copying it into each kernel. Having many copies of the runtime causes many issues with CUDA.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129025
Approved by: https://github.com/shunting314, https://github.com/eellison
ghstack dependencies: #126417
2024-06-22 17:39:52 +00:00
Jason Ansel
4f9399bd0d [halide-backend] Initial implementation of HalideKernel and HalideScheduling (#126417)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126417
Approved by: https://github.com/shunting314, https://github.com/eellison
2024-06-22 17:39:52 +00:00
Jason Ansel
feb3f3ad77 [inductor] Refactors for Halide backend (#129024)
Pulling these inductor-related refactors out of the larger Halide
backend PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129024
Approved by: https://github.com/shunting314, https://github.com/eellison
2024-06-21 16:53:35 +00:00
eellison
c187593418 Prevent expansion of cat indexing to avoid int64 intermediate (#127815)
Fix for https://github.com/pytorch/pytorch/issues/127652

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127815
Approved by: https://github.com/shunting314, https://github.com/peterbell10
2024-06-14 15:42:08 +00:00
Isuru Fernando
e397ad6883 Improve codegen for ops.masked in triton (#128054)
Fixes https://github.com/pytorch/pytorch/issues/127930
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128054
Approved by: https://github.com/peterbell10, https://github.com/lezcano
2024-06-14 11:52:56 +00:00
Jason Ansel
c897651392 [inductor] Add BackendFeature gating (#128266)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128266
Approved by: https://github.com/shunting314
2024-06-13 07:31:51 +00:00
PyTorch MergeBot
f2dcbe89d6 Revert "Prevent expansion of cat indexing to avoid int64 intermediate (#127815)"
This reverts commit 793df7b7cb.

Reverted https://github.com/pytorch/pytorch/pull/127815 on behalf of https://github.com/clee2000 due to the newly added test is failing internally D58444153.  Test exists in opensource and passed in OSS CI, maybe env difference? ([comment](https://github.com/pytorch/pytorch/pull/127815#issuecomment-2163421968))
2024-06-12 16:09:22 +00:00
eellison
793df7b7cb Prevent expansion of cat indexing to avoid int64 intermediate (#127815)
Fix for https://github.com/pytorch/pytorch/issues/127652

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127815
Approved by: https://github.com/shunting314, https://github.com/peterbell10
2024-06-11 02:41:07 +00:00
Edward Z. Yang
3964a3ec73 Complete revamp of float/promotion sympy handling (#126905)
At a high level, the idea behind this PR is:

* Make it clearer what the promotion and int/float rules for various Sympy operations are. Operators that previously were polymorphic over int/float are now split into separate operators for clarity. We never do mixed int/float addition/multiplication etc in sympy, instead, we always promote to the appropriate operator. (However, equality is currently not done correctly.)
* Enforce strict typing on ValueRanges: if you have a ValueRange for a float, the lower and upper MUST be floats, and so forth for integers.

The story begins in **torch/utils/_sympy/functions.py**. Here, I make some changes to how we represent certain operations in sympy expressions:

* FloorDiv now only supports integer inputs; to do float floor division, do a truediv and then a trunc. Additionally, we remove the divide out addition by gcd optimization, because sympy gcd is over fields and is willing to generate rationals (but rationals are bad for ValueRange strict typing).
* ModularIndexing, LShift, RShift now assert they are given integer inputs.
* Mod only supports integer inputs; eventually we will support FloatMod (left for later work, when we build out Sympy support for floating operations). Unfortunately, I couldn't assert integer inputs here, because of a bad interaction with sympy's inequality solver that is used by the offline solver
* TrueDiv is split into FloatTrueDiv and IntTrueDiv. This allows us to eventually generate accurate code for Python-semantics IntTrueDiv, which is written in a special way to preserve precision when the inputs are >= 2**53, beyond what you get by first coercing the integers to floats and then doing true division.
* Trunc is split to TruncToFloat and TruncToInt.
* Round is updated to return a float, not an int, making it consistent with the round op handler in Inductor. To get Python-style conversion to int, we call TruncToInt on the result.
* RoundDecimal updated to consistently only ever return a float
* Add ToFloat for explicit coercion to float (required so we can enforce strict ValueRanges typing)

In **torch/__init__.py**, we modify SymInt and SymFloat to appropriately call into new bindings that route to these refined sympy operations.  Also, we modify `torch.sym_min` and `torch.sym_max` to have promotion semantics (if one argument is a float, the return result is always a float), making them inconsistent with builtins.min/max, but possible to do type analysis without runtime information.
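A tiny illustration of the promotion semantics described above (my own example; it assumes, per the description, that the promotion also applies to plain Python numbers):

```python
import torch

# builtins.max returns whichever argument wins, preserving its type...
print(max(3, 2.0))            # 3   (an int)
# ...while torch.sym_max promotes to float whenever either argument is a
# float, so the result type can be determined without runtime values.
print(torch.sym_max(3, 2.0))  # 3.0 (a float)
```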

We also need to introduce some new op handlers in **torch/_inductor/ops_handler.py**:

* `to_int` for truncation to int64, directly corresponding to TruncToInt; this can be implemented by trunc and dtype, but with a dedicated handler it is more convenient for roundtripping in Sympy
* `int_truediv` for Python-style integer true division, which has higher precision than casting to floats and then running `truediv`

These changes have consequences. First, we need to make some administrative changes:

* Actually wire up these Sympy functions from SymInt/SymFloat in **torch/fx/experimental/sym_node.py**, including the new promotion rules (promote2)
* Add support for new Sympy functions in **torch/utils/_sympy/interp.py**, **torch/utils/_sympy/reference.py**
  * In particular, in torch.utils._sympy.reference, we have a strong preference to NOT do nontrivial compute, instead, everything in ops handler should map to a singular sympy function
  * TODO: I chose to roundtrip mod back to our Mod function, but I think I'm going to have to deal with the C/Python inconsistency here to fix tests
* Add printer support for the Sympy functions in **torch/_inductor/codegen/common.py**, **torch/_inductor/codegen/cpp_utils.py**, **torch/_inductor/codegen/triton.py**. `int_truediv` and mixed precision equality is currently not implemented soundly, so we will lose precision in codegen for large values. TODO: The additions here are not exhaustive yet
* Update ValueRanges logic to use new sympy functions in **torch/utils/_sympy/value_ranges.py**. In general, we prefer to use the new Sympy function rather than try to roll things by hand, which is what was done previously for many VR analysis functions.

In **torch/fx/experimental/symbolic_shapes.py** we need to make some symbolic reasoning adjustments:

* Avoid generation of rational subexpressions by removing simplification of `x // y` into `floor(x / y)`. This simplification then triggers an addition simplification rule `(x + y) / c --> x / c + y / c` which is bad because x / c is a rational number now
* `_assert_bound_is_rational` is no more, we no longer generate rational bounds
* Don't intersect non-int value ranges with the `int_range`
* Support more sympy Functions for guard SYMPY_INTERP
* Assert the type of value range is consistent with the variable type

The new asserts uncovered necessary bug fixes:

* **torch/_inductor/codegen/cpp.py**, **torch/_inductor/select_algorithm.py**, **torch/_inductor/sizevars.py** - Ensure Wild/Symbol manually allocated in Inductor is marked `is_integer` so it's accepted to build expressions
* **torch/_inductor/utils.py** - make sure you actually pass in sympy.Expr to these functions
* **torch/_inductor/ir.py** - make_contiguous_strides_for takes int/SymInt, not sympy.Expr!
* **torch/export/dynamic_shapes.py** - don't use infinity to represent int ranges, instead use sys.maxsize - 1

Because of the removal of some symbolic reasoning that produced rationals, some of our symbolic reasoning has gotten worse and we are unable to simplify some guards. Check the TODO at **test/test_proxy_tensor.py**

**Reland notes.** This requires this internal fbcode diff https://www.internalfb.com/phabricator/paste/view/P1403322587 but I cannot prepare the diff codev due to https://fb.workplace.com/groups/osssupport/posts/26343544518600814/

It also requires this Executorch PR https://github.com/pytorch/executorch/pull/3911 but the ET PR can be landed prior to this landing.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126905
Approved by: https://github.com/xadupre, https://github.com/lezcano
2024-06-09 06:20:25 +00:00
Aaron Orenstein
ea614fb2b1 Flip default value for mypy disallow_untyped_defs [2/11] (#127839)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127839
Approved by: https://github.com/oulgen
2024-06-08 18:23:08 +00:00
PyTorch MergeBot
ac51f782fe Revert "Complete revamp of float/promotion sympy handling (#126905)"
This reverts commit 2f7cfecd86.

Reverted https://github.com/pytorch/pytorch/pull/126905 on behalf of https://github.com/atalman due to Sorry need to revert - failing internally ([comment](https://github.com/pytorch/pytorch/pull/126905#issuecomment-2155118778))
2024-06-07 16:01:46 +00:00
Edward Z. Yang
2f7cfecd86 Complete revamp of float/promotion sympy handling (#126905)
At a high level, the idea behind this PR is:

* Make it clearer what the promotion and int/float rules for various Sympy operations are. Operators that previously were polymorphic over int/float are now split into separate operators for clarity. We never do mixed int/float addition/multiplication etc in sympy, instead, we always promote to the appropriate operator. (However, equality is currently not done correctly.)
* Enforce strict typing on ValueRanges: if you have a ValueRange for a float, the lower and upper MUST be floats, and so forth for integers.

The story begins in **torch/utils/_sympy/functions.py**. Here, I make some changes to how we represent certain operations in sympy expressions:

* FloorDiv now only supports integer inputs; to do float floor division, do a truediv and then a trunc. Additionally, we remove the divide out addition by gcd optimization, because sympy gcd is over fields and is willing to generate rationals (but rationals are bad for ValueRange strict typing).
* ModularIndexing, LShift, RShift now assert they are given integer inputs.
* Mod only supports integer inputs; eventually we will support FloatMod (left for later work, when we build out Sympy support for floating operations). Unfortunately, I couldn't assert integer inputs here, because of a bad interaction with sympy's inequality solver that is used by the offline solver
* TrueDiv is split into FloatTrueDiv and IntTrueDiv. This allows us to eventually generate accurate code for Python-semantics IntTrueDiv, which is written in a special way to preserve precision when the inputs are >= 2**53, beyond what you get by first coercing the integers to floats and then doing true division.
* Trunc is split to TruncToFloat and TruncToInt.
* Round is updated to return a float, not an int, making it consistent with the round op handler in Inductor. To get Python-style conversion to int, we call TruncToInt on the result.
* RoundDecimal updated to consistently only ever return a float
* Add ToFloat for explicit coercion to float (required so we can enforce strict ValueRanges typing)

In **torch/__init__.py**, we modify SymInt and SymFloat to appropriately call into new bindings that route to these refined sympy operations.  Also, we modify `torch.sym_min` and `torch.sym_max` to have promotion semantics (if one argument is a float, the return result is always a float), making them inconsistent with builtins.min/max, but possible to do type analysis without runtime information.

We also need to introduce some new op handlers in **torch/_inductor/ops_handler.py**:

* `to_int` for truncation to int64, directly corresponding to TruncToInt; this can be implemented by trunc and dtype, but with a dedicated handler it is more convenient for roundtripping in Sympy
* `int_truediv` for Python-style integer true division, which has higher precision than casting to floats and then running `truediv`

These changes have consequences. First, we need to make some administrative changes:

* Actually wire up these Sympy functions from SymInt/SymFloat in **torch/fx/experimental/sym_node.py**, including the new promotion rules (promote2)
* Add support for new Sympy functions in **torch/utils/_sympy/interp.py**, **torch/utils/_sympy/reference.py**
  * In particular, in torch.utils._sympy.reference, we have a strong preference to NOT do nontrivial compute, instead, everything in ops handler should map to a singular sympy function
  * TODO: I chose to roundtrip mod back to our Mod function, but I think I'm going to have to deal with the C/Python inconsistency here to fix tests
* Add printer support for the Sympy functions in **torch/_inductor/codegen/common.py**, **torch/_inductor/codegen/cpp_utils.py**, **torch/_inductor/codegen/triton.py**. `int_truediv` and mixed precision equality is currently not implemented soundly, so we will lose precision in codegen for large values. TODO: The additions here are not exhaustive yet
* Update ValueRanges logic to use new sympy functions in **torch/utils/_sympy/value_ranges.py**. In general, we prefer to use the new Sympy function rather than try to roll things by hand, which is what was done previously for many VR analysis functions.

In **torch/fx/experimental/symbolic_shapes.py** we need to make some symbolic reasoning adjustments:

* Avoid generation of rational subexpressions by removing simplification of `x // y` into `floor(x / y)`. This simplification then triggers an addition simplification rule `(x + y) / c --> x / c + y / c` which is bad because x / c is a rational number now
* `_assert_bound_is_rational` is no more, we no longer generate rational bounds
* Don't intersect non-int value ranges with the `int_range`
* Support more sympy Functions for guard SYMPY_INTERP
* Assert the type of value range is consistent with the variable type

The new asserts uncovered necessary bug fixes:

* **torch/_inductor/codegen/cpp.py**, **torch/_inductor/select_algorithm.py**, **torch/_inductor/sizevars.py** - Ensure Wild/Symbol manually allocated in Inductor is marked `is_integer` so it's accepted to build expressions
* **torch/_inductor/utils.py** - make sure you actually pass in sympy.Expr to these functions
* **torch/_inductor/ir.py** - make_contiguous_strides_for takes int/SymInt, not sympy.Expr!
* **torch/export/dynamic_shapes.py** - don't use infinity to represent int ranges, instead use sys.maxsize - 1

Because of the removal of some symbolic reasoning that produced rationals, some of our symbolic reasoning has gotten worse and we are unable to simplify some guards. Check the TODO at **test/test_proxy_tensor.py**

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126905
Approved by: https://github.com/xadupre, https://github.com/lezcano
2024-06-06 02:29:45 +00:00
PyTorch MergeBot
d5cb5d623a Revert "Complete revamp of float/promotion sympy handling (#126905)"
This reverts commit fb696ef3aa.

Reverted https://github.com/pytorch/pytorch/pull/126905 on behalf of https://github.com/ezyang due to internal user reported ceiling equality simplification problem, I have a plan ([comment](https://github.com/pytorch/pytorch/pull/126905#issuecomment-2148805840))
2024-06-05 03:57:58 +00:00
Edward Z. Yang
fb696ef3aa Complete revamp of float/promotion sympy handling (#126905)
At a high level, the idea behind this PR is:

* Make it clearer what the promotion and int/float rules for various Sympy operations are. Operators that previously were polymorphic over int/float are now split into separate operators for clarity. We never do mixed int/float addition/multiplication etc in sympy, instead, we always promote to the appropriate operator. (However, equality is currently not done correctly.)
* Enforce strict typing on ValueRanges: if you have a ValueRange for a float, the lower and upper MUST be floats, and so forth for integers.

The story begins in **torch/utils/_sympy/functions.py**. Here, I make some changes to how we represent certain operations in sympy expressions:

* FloorDiv now only supports integer inputs; to do float floor division, do a truediv and then a trunc. Additionally, we remove the divide out addition by gcd optimization, because sympy gcd is over fields and is willing to generate rationals (but rationals are bad for ValueRange strict typing).
* ModularIndexing, LShift, RShift now assert they are given integer inputs.
* Mod only supports integer inputs; eventually we will support FloatMod (left for later work, when we build out Sympy support for floating operations). Unfortunately, I couldn't assert integer inputs here, because of a bad interaction with sympy's inequality solver that is used by the offline solver
* TrueDiv is split into FloatTrueDiv and IntTrueDiv. This allows us to eventually generate accurate code for Python-semantics IntTrueDiv, which is written in a special way to preserve precision when the inputs are >= 2**53, beyond what you get by first coercing the integers to floats and then doing true division.
* Trunc is split to TruncToFloat and TruncToInt.
* Round is updated to return a float, not an int, making it consistent with the round op handler in Inductor. To get Python-style conversion to int, we call TruncToInt on the result.
* RoundDecimal updated to consistently only ever return a float
* Add ToFloat for explicit coercion to float (required so we can enforce strict ValueRanges typing)

In **torch/__init__.py**, we modify SymInt and SymFloat to appropriately call into new bindings that route to these refined sympy operations.  Also, we modify `torch.sym_min` and `torch.sym_max` to have promotion semantics (if one argument is a float, the return result is always a float), making them inconsistent with builtins.min/max, but possible to do type analysis without runtime information.

We also need to introduce some new op handlers in **torch/_inductor/ops_handler.py**:

* `to_int` for truncation to int64, directly corresponding to TruncToInt; this can be implemented by trunc and dtype, but with a dedicated handler it is more convenient for roundtripping in Sympy
* `int_truediv` for Python-style integer true division, which has higher precision than casting to floats and then running `truediv`

These changes have consequences. First, we need to make some administrative changes:

* Actually wire up these Sympy functions from SymInt/SymFloat in **torch/fx/experimental/sym_node.py**, including the new promotion rules (promote2)
* Add support for new Sympy functions in **torch/utils/_sympy/interp.py**, **torch/utils/_sympy/reference.py**
  * In particular, in torch.utils._sympy.reference, we have a strong preference to NOT do nontrivial compute, instead, everything in ops handler should map to a singular sympy function
  * TODO: I chose to roundtrip mod back to our Mod function, but I think I'm going to have to deal with the C/Python inconsistency here to fix tests
* Add printer support for the Sympy functions in **torch/_inductor/codegen/common.py**, **torch/_inductor/codegen/cpp_utils.py**, **torch/_inductor/codegen/triton.py**. `int_truediv` and mixed precision equality is currently not implemented soundly, so we will lose precision in codegen for large values. TODO: The additions here are not exhaustive yet
* Update ValueRanges logic to use new sympy functions in **torch/utils/_sympy/value_ranges.py**. In general, we prefer to use the new Sympy function rather than try to roll things by hand, which is what was done previously for many VR analysis functions.

In **torch/fx/experimental/symbolic_shapes.py** we need to make some symbolic reasoning adjustments:

* Avoid generation of rational subexpressions by removing simplification of `x // y` into `floor(x / y)`. This simplification then triggers an addition simplification rule `(x + y) / c --> x / c + y / c` which is bad because x / c is a rational number now
* `_assert_bound_is_rational` is no more, we no longer generate rational bounds
* Don't intersect non-int value ranges with the `int_range`
* Support more sympy Functions for guard SYMPY_INTERP
* Assert the type of value range is consistent with the variable type

The new asserts uncovered necessary bug fixes:

* **torch/_inductor/codegen/cpp.py**, **torch/_inductor/select_algorithm.py**, **torch/_inductor/sizevars.py** - Ensure Wild/Symbol manually allocated in Inductor is marked `is_integer` so it's accepted to build expressions
* **torch/_inductor/utils.py** - make sure you actually pass in sympy.Expr to these functions
* **torch/_inductor/ir.py** - make_contiguous_strides_for takes int/SymInt, not sympy.Expr!
* **torch/export/dynamic_shapes.py** - don't use infinity to represent int ranges, instead use sys.maxsize - 1

Because of the removal of some symbolic reasoning that produced rationals, some of our symbolic reasoning has gotten worse and we are unable to simplify some guards. Check the TODO at **test/test_proxy_tensor.py**

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126905
Approved by: https://github.com/xadupre, https://github.com/lezcano
2024-06-04 11:47:32 +00:00
Edward Z. Yang
a4064da8ca Always simplify sympy expressions before printing. (#127543)
This is important because if a replacement has happened during inductor lowering, we may have stale symbols in sympy expressions that we need to replace away.  Do this at the very end.
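A toy illustration of the kind of stale-symbol cleanup this enables (plain sympy, not Inductor code):

```python
import sympy

s0, s1 = sympy.symbols("s0 s1", integer=True)
expr = 2 * s0 + 1

# Suppose a later replacement decided s0 == s1; without simplifying/substituting
# at print time, the printed expression would still mention the stale symbol s0.
print(sympy.simplify(expr.xreplace({s0: s1})))  # 2*s1 + 1
```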

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127543
Approved by: https://github.com/lezcano
2024-06-03 20:36:14 +00:00