Commit Graph

985 Commits

Jason Ansel
637074983e [inductor] Make load_mask() codegen deterministic (#126017)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126017
Approved by: https://github.com/shunting314
2024-05-13 17:36:52 +00:00
Jiong Gong
037615b989 [inductor][cpp] GEMM template (infra and fp32) (#124021)
This PR adds the Cpp template infrastructure and the initial FP32 gemm template. See RFC https://github.com/pytorch/pytorch/issues/125683 for more background info.
1. Cpp template infrastructure
Similar template abstractions to the CUTLASS template, i.e., `CppTemplate`, `CppTemplateKernel`, and `CppTemplateBuffer`, plus a MicroGemm micro-kernel abstraction that can be used by Cpp GEMM templates.
2. Initial FP32 gemm template
This involves a GEMM template implementation, `CppPackedGemmTemplate`, that supports GEMM with a constant weight (`B`), requiring `N` to be a multiple of the register blocking size while allowing static or dynamic sizes for the `M` (batch dim) of `A`. The `B` matrix is prepacked. This is a typical setting for inference workloads. The template handles thread decomposition (via `thread_blocking`) and cache blocking (via `cache_blocking`), then invokes `CppMicroGemm`, which handles register blocking, instruction selection, and other CPU architecture-specific optimizations. A `CppMicroGemmFP32Vec` micro-kernel implementation is provided for fp32 matmuls built on the ATen vec abstraction.
3. Correctness and performance
The changes have been validated with fp32 inference on the three benchmark suites (torchbench, huggingface and timm_models) with both static and dynamic shapes. Since this is an initial implementation, we are still working on further performance improvements in follow-up PRs, including optimizations in the kernels as well as fusions. Perf gains are only observed on a selective number of models compared to the ATen kernels, which are implemented with MKL. The gains are more obvious with dynamic shapes, since MKL only supports packed gemm for static shapes. Details are below, followed by a brief usage sketch after the tables.

Static shapes
| Benchmark | torchbench | huggingface | timm_models |
|------------|-------------|--------------|--------------|
| Multi-threaded (baseline) | 1.47x | 1.36x | 1.91x |
| Multi-threaded (max-autotune) | 1.47x | 1.36x | 1.92x |
| Single-threaded (baseline) | 1.56x | 1.19x | 1.51x |
| Single-threaded (max-autotune) | 1.56x | 1.19x | 1.52x |

Key models being sped up:
drq: 1.14x
soft_act: 1.12x
cait_m36_384: 1.18x

Dynamic shapes
| Benchmark | torchbench | huggingface | timm_models |
| --- | --- | --- | --- |
| Multi-threaded (baseline) | 1.43x | 1.28x | 1.85x |
| Multi-threaded (max-autotune) | 1.47x | 1.28x | 1.85x |
| Single-threaded (baseline) | 1.55x | 1.20x | 1.51x |
| Single-threaded (max-autotune) | 1.56x | 1.19x | 1.53x |

Key models being sped up:
BERT_pytorch: 1.22x
pyhpc_turbulent: 1.13x
soft_actor_critic: 1.77x
BlenderbotForCausalLM: 1.09x
cait_m36_384: 1.17x
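
A minimal usage sketch for the template described above, assuming the standard `torch.compile` max-autotune flow; the module, shapes, and settings here are illustrative assumptions, not taken from the PR:

```
import torch

# Hypothetical example: a constant-weight linear layer, the typical inference
# setting the template targets (constant B, static or dynamic M).
class TinyLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(512, 1024, bias=False)  # weight plays the role of B

    def forward(self, x):
        return self.fc(x)

model = TinyLinear().eval()
x = torch.randn(128, 512)  # M = 128 is the batch dim of A

with torch.no_grad():
    compiled = torch.compile(model, mode="max-autotune")  # autotuning picks the GEMM implementation
    out = compiled(x)
```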

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124021
Approved by: https://github.com/jansel
2024-05-12 07:46:44 +00:00
Bin Bao
2df114e6be [AOTI] Fix 'int' object is not subscriptable (#125731)
Summary: for https://github.com/pytorch/pytorch/issues/117369

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125731
Approved by: https://github.com/chenyang78
ghstack dependencies: #125291, #125730
2024-05-11 18:12:46 +00:00
Bin Bao
538877d204 [AOTI] Fix convolution_backward (#125730)
Summary: for https://github.com/pytorch/pytorch/issues/125922

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125730
Approved by: https://github.com/chenyang78
ghstack dependencies: #125291
2024-05-10 20:13:34 +00:00
Jiong Gong
3267814d53 [inductor] refactor: device dispatch inside do_bench (#125736)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125736
Approved by: https://github.com/shunting314
2024-05-09 23:50:02 +00:00
Jez Ng
affd7a9789 Get PT2 Cutlass backend working under fbcode [take 2] (#125688)
Differential Revision: D57051232

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125688
Approved by: https://github.com/chenyang78
2024-05-08 16:44:49 +00:00
lezcano
320af5eaa6 Compute bounds for the variables created during codegen (#123100)
Before we would just bail out on these bounds for all variables that did
not come from the FX graph. Now we propagate the bounds whenever we have
a rule for that op.
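
As a rough illustration of what propagating bounds through an op rule means (a standalone sketch with made-up names, not Inductor's actual value-range machinery):

```
import math
from dataclasses import dataclass

@dataclass(frozen=True)
class ValueRange:
    lower: float
    upper: float

def unknown() -> ValueRange:
    # The old behavior for codegen-created variables: bail out to (-inf, inf).
    return ValueRange(-math.inf, math.inf)

def mul_rule(a: ValueRange, b: ValueRange) -> ValueRange:
    # A per-op rule: the product's bounds come from the extremes of the inputs.
    candidates = [a.lower * b.lower, a.lower * b.upper, a.upper * b.lower, a.upper * b.upper]
    return ValueRange(min(candidates), max(candidates))

x = ValueRange(0, 127)   # e.g. an index variable created during codegen
two = ValueRange(2, 2)   # a constant
print(mul_rule(x, two))  # ValueRange(lower=0, upper=254) instead of unknown()
```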

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123100
Approved by: https://github.com/jgong5, https://github.com/peterbell10
2024-05-08 08:14:06 +00:00
Jez Ng
7863e04615 Back out "Get cutlass_library import working under fbcode" (#125606)
Summary: Original commit changeset: de79f6bfe348

Differential Revision: D57002294

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125606
Approved by: https://github.com/chenyang78
2024-05-07 16:55:11 +00:00
Jiong Gong
058e28108f [inductor][cpp] support int64 vertical vec reduction (fix #124821) (#125563)
Fix #124821

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125563
Approved by: https://github.com/desertfire
2024-05-07 03:56:22 +00:00
PyTorch MergeBot
2a42c40791 Revert "Compute bounds for the variables created during codegen (#123100)"
This reverts commit bb668c6468.

Reverted https://github.com/pytorch/pytorch/pull/123100 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is failing inductor tests bb668c6468 ([comment](https://github.com/pytorch/pytorch/pull/123100#issuecomment-2096837821))
2024-05-06 20:23:39 +00:00
PyTorch MergeBot
7ffa5558ee Revert "[FX] Update type hints in torch.fx._compatibility.py (#125469)"
This reverts commit 235b4d6ec2.

Reverted https://github.com/pytorch/pytorch/pull/125469 on behalf of https://github.com/izaitsevfb due to breaks pyre in dependent projects (internal: see D56986361) ([comment](https://github.com/pytorch/pytorch/pull/125469#issuecomment-2096665396))
2024-05-06 18:36:43 +00:00
lezcano
bb668c6468 Compute bounds for the variables created during codegen (#123100)
Before we would just bail out on these bounds for all variables that did
not come from the FX graph. Now we propagate the bounds whenever we have
a rule for that op.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123100
Approved by: https://github.com/jgong5, https://github.com/peterbell10
2024-05-06 18:12:15 +00:00
Jiong Gong
68a1f787c8 [inductor][cpp] move some common cpp utils to cpp_utils.py (#125152)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125152
Approved by: https://github.com/desertfire, https://github.com/jansel
2024-05-06 04:30:30 +00:00
Aaron Gokaslan
1dd42e42c4 [BE]: Try TCH autofixes on torch/ (#125536)
Tries TCH autofixes to see what breaks

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125536
Approved by: https://github.com/ezyang
2024-05-05 23:13:59 +00:00
Xuehai Pan
235b4d6ec2 [FX] Update type hints in torch.fx._compatibility.py (#125469)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125469
Approved by: https://github.com/Skylion007
ghstack dependencies: #125468
2024-05-05 19:30:22 +00:00
Kai Londenberg
10f673541e [Inductor cutlass backend] Enabled nonzero workspace and Cutlass StreamK (#125406)
Enable nonzero workspace and Cutlass StreamK for Inductor Cutlass GEMM ops.

This is a simpler rewrite of my original version (#119005), using @peterbell10's workspace allocation mechanism from #117992
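
For context, a hedged sketch of how the Cutlass backend is typically exercised so that StreamK kernels (with their workspace requirement) can participate in autotuning; the config knob names here are assumptions about this area of the codebase and may differ:

```
import torch
import torch._inductor.config as config

# Assumed knobs: restrict GEMM autotuning to the Cutlass backend.
config.max_autotune = True
config.max_autotune_gemm_backends = "CUTLASS"

a = torch.randn(2048, 2048, device="cuda", dtype=torch.float16)
b = torch.randn(2048, 2048, device="cuda", dtype=torch.float16)

compiled_mm = torch.compile(torch.mm, mode="max-autotune")
out = compiled_mm(a, b)  # StreamK candidates are considered if they win autotuning
```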

Test Plan:
 - Additional unit test in test_cutlass_backend.py which specifically tests StreamK GEMM with workspace requirement
 - CI

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125406
Approved by: https://github.com/jansel
2024-05-05 15:28:45 +00:00
Edward Z. Yang
6f70d22277 Extend torch.utils._sympy.symbol for more Inductor symbols (#125419)
I'm still missing a few, cdzq at least

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125419
Approved by: https://github.com/lezcano
ghstack dependencies: #125395
2024-05-04 09:05:00 +00:00
Edward Z. Yang
5503c29357 Introduce torch.utils._sympy.symbol (#125395)
This provides utilities for creating and querying properties on
sympy.Symbol.  I want to use this refactor to get a better handle on how
the 's' prefix is being used in Inductor.  To start, I only do
symbolic_shapes code because that's what I'm familiar with.
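
A small sketch of the kind of utility this provides (the helper names are assumptions, not the actual API): create sympy symbols with a known prefix and query which family a symbol belongs to.

```
import sympy

PREFIX_SIZE = "s"      # backed sizes
PREFIX_UNBACKED = "u"  # unbacked SymInts

def make_symbol(prefix: str, idx: int) -> sympy.Symbol:
    return sympy.Symbol(f"{prefix}{idx}", integer=True, nonnegative=True)

def is_unbacked(sym: sympy.Symbol) -> bool:
    return sym.name.startswith(PREFIX_UNBACKED)

s0 = make_symbol(PREFIX_SIZE, 0)
u0 = make_symbol(PREFIX_UNBACKED, 0)
print(is_unbacked(s0), is_unbacked(u0))  # False True
```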

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125395
Approved by: https://github.com/Skylion007
2024-05-03 21:24:23 +00:00
Kai Londenberg
b1a7455b99 [Inductor cutlass backend] Fix cutlass_utils.get_max_alignment() for strided layouts. (#124930)
Fixes cutlass_utils.get_max_alignment(), which previously did not check the alignment properly: the method assumed that the passed layout is contiguous and row-major, which does not have to be true.

Test Plan:
CI - test_cutlass_backend.py to prevent regressions
Added unit test

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124930
Approved by: https://github.com/int3
ghstack dependencies: #124929
2024-05-03 20:50:26 +00:00
Bin Bao
e84a5b6cc0 [AOTI] Add missing std::move for constant args (#125329)
Summary: fix https://github.com/pytorch/pytorch/issues/123187

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125329
Approved by: https://github.com/angelayi, https://github.com/chenyang78
2024-05-03 18:35:42 +00:00
Edward Z. Yang
dae574c713 Don't make replacements for i variables (#125398)
This was introduced in https://github.com/pytorch/pytorch/pull/110262,
but it looks like it was actually trying to hit unbacked SymInts.
Now that unbacked SymInts are renamed to u, this code is no longer
necessary.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125398
Approved by: https://github.com/lezcano, https://github.com/Skylion007
2024-05-02 20:38:09 +00:00
Jez Ng
fb1bfe1156 Get cutlass_library import working under fbcode (#125257)
Differential Revision: D56764089

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125257
Approved by: https://github.com/chenyang78
2024-05-02 15:17:10 +00:00
Kai Londenberg
8046de3512 [Inductor cutlass backend] Remove epilogue nodes from Kernel call (#124929)
Minor refactoring:
Remove unused "fused epilogue node" arguments from some Kernel method call signatures.

Test Plan:
Covered by current tests in test_cutlass_backend.py - no functional change.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124929
Approved by: https://github.com/eellison
2024-05-02 13:02:31 +00:00
Kazuaki Ishizaki
9fec26e231 Fix typo under torch/_inductor directory (#119658)
This PR fixes typos in comments and messages under the `torch/_inductor` directory, and also changes the corresponding test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119658
Approved by: https://github.com/colesbury
2024-04-30 22:28:56 +00:00
Sam Larsen
254128c16e [inductor] Remove usage of device_interface from _inductor.runtime (#124592)
Differential Revision: [D56723770](https://our.internmc.facebook.com/intern/diff/D56723770)
Co-authored-by: Sam Larsen <slarsen@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124592
Approved by: https://github.com/masnesral
2024-04-30 16:54:16 +00:00
Shunting Zhang
dc514df2af [inductor] add triton code to SchedulerNode.debug_str (#125091)
Here is an example print: https://gist.github.com/shunting314/75c161368a833a535bd0d240b8099d7e

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125091
Approved by: https://github.com/jansel
ghstack dependencies: #125090
2024-04-30 08:27:53 +00:00
Edward Z. Yang
e5e623af4b Codegen runtime asserts in Inductor (#124874)
This completely subsumes https://github.com/pytorch/pytorch/pull/120816

This makes use of the unbacked binding machinery to teach Inductor how to generate deferred runtime asserts directly. There is some back story about why I did it this way, let me explain.

Previously, our strategy for generating runtime asserts was that Dynamo would insert them into the FX graph after finishing tracing, and we would attempt to code generate them based on the FX graph. This is a good strategy for export, where we immediately export the graph. However, this strategy was afflicted by problems in eager, where we reuse the same ShapeEnv as before. In particular, on subsequent graph passes, we would immediately turn all of these assertions into noops, because when we evaluated their expressions, we would see that because we had a deferred runtime assert in the ShapeEnv, we know "oh, of course this expression is True" already. Oops!

So, with this PR, we take the attitude that as long as the ShapeEnv sticks around, the ShapeEnv's list of deferred runtime asserts is the source of truth, and we don't put anything in the graph. So we just need to decide when to actually generate asserts, and the place I picked was Inductor lowering, since we already have an AssertScalar buffer concept, and so I just need to insert them at this point. AssertScalar also uses raw sympy.Expr rather than SymInt/Bool, so it is easier to prevent unrestricted simplification at this point.
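
A conceptual sketch of this strategy with made-up names (not the real ShapeEnv or AssertScalar API): the ShapeEnv keeps deferred runtime asserts as sympy expressions, asserts are frozen before lowering, and codegen turns each expression into a concrete check.

```
import sympy

class ShapeEnvSketch:
    """Toy stand-in for the ShapeEnv's list of deferred runtime asserts."""

    def __init__(self):
        self.deferred_runtime_asserts = []
        self.frozen = False

    def defer_runtime_assert(self, expr):
        if self.frozen:
            raise RuntimeError("runtime asserts are frozen; cannot add more after lowering")
        self.deferred_runtime_asserts.append(expr)

    def freeze_runtime_asserts(self):
        self.frozen = True

def codegen_asserts(env, bindings):
    # At lowering time, turn each recorded sympy expression into a concrete check.
    for expr in env.deferred_runtime_asserts:
        assert bool(expr.subs(bindings)), f"runtime assert failed: {expr}"

u0 = sympy.Symbol("u0", integer=True)
env = ShapeEnvSketch()
env.defer_runtime_assert(sympy.Ge(u0, 0))
env.freeze_runtime_asserts()
codegen_asserts(env, {u0: 5})  # passes; {u0: -1} would trip the assert
```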

There are a few things jumbled together in this PR. I can split them if you want, but some of the changes are before I changed my strategy, but they're useful changes anyway.

**torch/_dynamo/output_graph.py** and **torch/_inductor/lowering.py** - Here, we stop putting deferred runtime asserts in the graph. I also have to make sure we don't DCE unused symbol arguments; we're going to get some goofy graph arguments this way, will be good to restore that optimization eventually. We also just disable codegen for `_assert_scalar`  entirely; we assume that ShapeEnv will be good enough to capture all of these.

**torch/_inductor/codegen/wrapper.py** and **torch/_inductor/ir.py** - Add a way to codegen sizevars without forcing simplification

**torch/_inductor/graph.py** - The main logic. Our strategy is to interpose in the same place we are testing that unbacked SymInts are properly showing up in lowered code. The logic is directly analogous to the logic in the existing insert deferred runtime asserts FX pass, but it's simpler because sympy expressions can be directly stored on inductor IR nodes.

**torch/fx/experimental/symbolic_shapes.py** - For extra safety, we have a way of freezing runtime asserts, so that if you try to add more we error. This prevents us from adding runtime asserts after we've done lowering. There's a funny interaction with backwards which there's a comment for in graph.py

**torch/fx/passes/runtime_assert.py** - This is not really needed in this PR, but I rewrote the runtime assert logic to use unbacked_bindings rather than inferring it by looking for unbacked SymInts. Now, keypaths are translated into FX node accessors. Unfortunately, I couldn't delete the old inference code, because you still need it to find backed SymInts from arguments (as this pass may be used on graphs which don't explicitly bind all their shape variables as arguments). There are some new tests exercising this.

TODO: I think we need to generate asserts for replacements too. This is a preexisting problem that the old FX pass had too.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124874
Approved by: https://github.com/jansel
ghstack dependencies: #124864
2024-04-29 10:19:29 +00:00
Aaron Gokaslan
49ca2b3429 [BE]: Apply RUF025 perf fixups (#125104)
Uses `dict.fromkeys()` for more efficient dict construction. Automatically generated by RUF025 (a preview rule).
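
For reference, the pattern this rule rewrites looks like the following (illustrative example, not taken from the diff):

```
keys = ["a", "b", "c"]
before = {k: None for k in keys}  # dict comprehension with a constant value
after = dict.fromkeys(keys)       # equivalent, and what the RUF025 autofix produces
assert before == after
```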

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125104
Approved by: https://github.com/ezyang
2024-04-28 15:09:21 +00:00
haozhe.zhu
c5b1a4c269 [inductor] share more cse cache during swap buffer (#124921)
`swap_buffer` makes it impossible to share the `cse_cache` between the inside and outside of the lambda function scope.
For example,

```
auto tmp8 = -std::numeric_limits<float>::infinity();
auto tmp9 = [&]
{
    auto tmp12 = -std::numeric_limits<float>::infinity();
    return tmp12;
}
```
`tmp12` should not be created since it is the same as `tmp8`.

We make the `cse_cache` a read-only cache inside the scope (it is unsafe to expose entries created inside the scope, so the outside scope cannot use them).
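
A minimal sketch of this read-only sharing idea using a Python `ChainMap` (illustrative only; Inductor's CSE cache is not implemented this way): reads inside the scope can reuse outer entries, while entries created inside stay local.

```
from collections import ChainMap

outer_cache = {"-std::numeric_limits<float>::infinity()": "tmp8"}

def enter_scope(cache):
    # Writes land in the fresh front dict; reads fall back to the outer cache.
    return ChainMap({}, cache)

inner_cache = enter_scope(outer_cache)
print(inner_cache["-std::numeric_limits<float>::infinity()"])  # "tmp8": reused, no tmp12
inner_cache["static_cast<int>(256)"] = "tmp_inner"
print("static_cast<int>(256)" in outer_cache)                  # False: inner entries do not escape
```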

**Test Plan**
```
python test/inductor/test_torchinductor.py -k test_AllenaiLongformerBase_repro_cpu
```
The `static_cast<int>(256)` will only occur once after this PR, since the inner scope can reuse CSE entries from the outer scope.

Before this PR,
```
cpp_fused_copy_full_like_0 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_root/ub/cub6x5nmhqhp7xapkb3dlgjxef3t2bnkx7y7n4z2f4z5obnecxpy.h"
extern "C" void kernel(const float* in_ptr0,
                       float* out_ptr1)
{
    #pragma omp parallel num_threads(128)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for collapse(2)
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(1L))
                {
                    #pragma GCC ivdep
                    for(long x2=static_cast<long>(0L); x2<static_cast<long>(12L); x2+=static_cast<long>(1L))
                    {
                        for(long x3=static_cast<long>(0L); x3<static_cast<long>(512L); x3+=static_cast<long>(16L))
                        {
                            auto tmp0 = c10::convert<int>(x1);
                            auto tmp1 = static_cast<int>(256);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 = c10::convert<int>(x3);
                                auto tmp5 = at::vec::Vectorized<int>::arange(tmp4, 1);
                                auto tmp6 = static_cast<int>(257);
                                auto tmp7 = at::vec::Vectorized<int>(tmp6);
                                auto tmp8 = at::vec::VecMask<int,1>(tmp5 < tmp7);
                                auto tmp10 = at::vec::VecMask<float,1>::from(tmp2);
                                auto tmp11 = tmp8 & tmp10;
                                auto tmp9 = [&]
                                {
                                    auto tmp12 = -std::numeric_limits<float>::infinity();
                                    return tmp12;
                                }
                                ;
                                auto tmp13 =
                                [&]
                                {
                                    if (tmp11.all_zero())
                                    {
                                        return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                    }
                                    else
                                    {
                                        return decltype(at::vec::Vectorized<float>(tmp9()))::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), at::vec::Vectorized<float>(tmp9()), tmp11.template cast<float,1>());
                                    }
                                }
                                ()
                                ;
                                auto tmp14 = c10::convert<int>(c10::div_floor_integer(x1, 256L));
                                auto tmp15 = static_cast<int>(3);
                                auto tmp16 = tmp14 < tmp15;
                                auto tmp18 = tmp16 & tmp2;
                                auto tmp17 = [&]
                                {
                                    auto tmp19 = c10::convert<int>(x3);
                                    auto tmp20 = at::vec::Vectorized<int>::arange(tmp19, 1);
                                    auto tmp21 = static_cast<int>(256);
                                    auto tmp22 = at::vec::Vectorized<int>(tmp21);
                                    auto tmp23 = at::vec::VecMask<int,1>(tmp20 >= tmp22);
                                    auto tmp25 = at::vec::VecMask<float,1>::from(tmp18);
                                    auto tmp26 = tmp23 & tmp25;
                                    auto tmp24 = [&]
                                    {
                                        auto tmp27 = tmp26.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0)));
                                        return tmp27;
                                    }
                                    ;
                                    auto tmp28 =
                                    [&]
                                    {
                                        if (tmp26.all_zero())
                                        {
                                            return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                        }
                                        else
                                        {
                                            return decltype(tmp24())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp24(), tmp26.template cast<float,1>());
                                        }
                                    }
                                    ()
                                    ;
                                    auto tmp29 = static_cast<float>(0.0);
                                    auto tmp30 = at::vec::Vectorized<float>(tmp29);
                                    auto tmp31 = decltype(tmp28)::blendv(tmp30, tmp28, tmp23.template cast<float,1>());
                                    return tmp31;
                                }
                                ;
                                auto tmp32 = tmp16 ? tmp17() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                                auto tmp33 = static_cast<float>(0.0);
                                auto tmp34 = at::vec::VecMask<float,1>::from(tmp16);
                                auto tmp35 = at::vec::Vectorized<float>(tmp33);
                                auto tmp36 = decltype(tmp32)::blendv(tmp35, tmp32, tmp34.template cast<float,1>());
                                auto tmp37 = decltype(tmp13)::blendv(tmp36, tmp13, tmp8.template cast<float,1>());
                                return tmp37;
                            }
                            ;
                            auto tmp38 = tmp2 ? tmp3() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                            auto tmp39 = c10::convert<int>(c10::div_floor_integer(x1, 256L));
                            auto tmp40 = static_cast<int>(3);
                            auto tmp41 = tmp39 < tmp40;
                            auto tmp42 = [&]
                            {
                                auto tmp43 = c10::convert<int>(x3);
                                auto tmp44 = at::vec::Vectorized<int>::arange(tmp43, 1);
                                auto tmp45 = static_cast<int>(256);
                                auto tmp46 = at::vec::Vectorized<int>(tmp45);
                                auto tmp47 = at::vec::VecMask<int,1>(tmp44 >= tmp46);
                                auto tmp49 = at::vec::VecMask<float,1>::from(tmp41);
                                auto tmp50 = tmp47 & tmp49;
                                auto tmp48 = [&]
                                {
                                    auto tmp51 = tmp50.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0)));
                                    return tmp51;
                                }
                                ;
                                auto tmp52 =
                                [&]
                                {
                                    if (tmp50.all_zero())
                                    {
                                        return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                    }
                                    else
                                    {
                                        return decltype(tmp48())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp48(), tmp50.template cast<float,1>());
                                    }
                                }
                                ()
                                ;
                                auto tmp53 = static_cast<float>(0.0);
                                auto tmp54 = at::vec::Vectorized<float>(tmp53);
                                auto tmp55 = decltype(tmp52)::blendv(tmp54, tmp52, tmp47.template cast<float,1>());
                                return tmp55;
                            }
                            ;
                            auto tmp56 = tmp41 ? tmp42() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                            auto tmp57 = static_cast<float>(0.0);
                            auto tmp58 = at::vec::VecMask<float,1>::from(tmp41);
                            auto tmp59 = at::vec::Vectorized<float>(tmp57);
                            auto tmp60 = decltype(tmp56)::blendv(tmp59, tmp56, tmp58.template cast<float,1>());
                            auto tmp61 = at::vec::VecMask<float,1>::from(tmp2);
                            auto tmp62 = decltype(tmp38)::blendv(tmp60, tmp38, tmp61.template cast<float,1>());
                            tmp62.store(out_ptr1 + static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0)));
                        }
                        #pragma omp simd simdlen(8)
                        for(long x3=static_cast<long>(512L); x3<static_cast<long>(513L); x3+=static_cast<long>(1L))
                        {
                            auto tmp0 = c10::convert<int64_t>(x1);
                            auto tmp1 = static_cast<int64_t>(256);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 = c10::convert<int64_t>(x3);
                                auto tmp5 = static_cast<int64_t>(257);
                                auto tmp6 = tmp4 < tmp5;
                                auto tmp7 = [&]
                                {
                                    auto tmp8 = -std::numeric_limits<float>::infinity();
                                    return tmp8;
                                }
                                ;
                                auto tmp9 = tmp6 ? tmp7() : static_cast<decltype(tmp7())>(0.0);
                                auto tmp10 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L));
                                auto tmp11 = static_cast<int64_t>(3);
                                auto tmp12 = tmp10 < tmp11;
                                auto tmp13 = [&]
                                {
                                    auto tmp14 = c10::convert<int64_t>(x3);
                                    auto tmp15 = static_cast<int64_t>(256);
                                    auto tmp16 = tmp14 >= tmp15;
                                    auto tmp17 = [&]
                                    {
                                        auto tmp18 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))];
                                        return tmp18;
                                    }
                                    ;
                                    auto tmp19 = tmp16 ? tmp17() : static_cast<decltype(tmp17())>(0.0);
                                    auto tmp20 = static_cast<float>(0.0);
                                    auto tmp21 = tmp16 ? tmp19 : tmp20;
                                    return tmp21;
                                }
                                ;
                                auto tmp22 = tmp12 ? tmp13() : static_cast<decltype(tmp13())>(0.0);
                                auto tmp23 = static_cast<float>(0.0);
                                auto tmp24 = tmp12 ? tmp22 : tmp23;
                                auto tmp25 = tmp6 ? tmp9 : tmp24;
                                return tmp25;
                            }
                            ;
                            auto tmp26 = tmp2 ? tmp3() : static_cast<decltype(tmp3())>(0.0);
                            auto tmp27 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L));
                            auto tmp28 = static_cast<int64_t>(3);
                            auto tmp29 = tmp27 < tmp28;
                            auto tmp30 = [&]
                            {
                                auto tmp31 = c10::convert<int64_t>(x3);
                                auto tmp32 = static_cast<int64_t>(256);
                                auto tmp33 = tmp31 >= tmp32;
                                auto tmp34 = [&]
                                {
                                    auto tmp35 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))];
                                    return tmp35;
                                }
                                ;
                                auto tmp36 = tmp33 ? tmp34() : static_cast<decltype(tmp34())>(0.0);
                                auto tmp37 = static_cast<float>(0.0);
                                auto tmp38 = tmp33 ? tmp36 : tmp37;
                                return tmp38;
                            }
                            ;
                            auto tmp39 = tmp29 ? tmp30() : static_cast<decltype(tmp30())>(0.0);
                            auto tmp40 = static_cast<float>(0.0);
                            auto tmp41 = tmp29 ? tmp39 : tmp40;
                            auto tmp42 = tmp2 ? tmp26 : tmp41;
                            out_ptr1[static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0))] = tmp42;
                        }
                    }
                }
            }
        }
    }
}
''')
```
After this PR,
```
cpp_fused_copy_full_like_0 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_root/ub/cub6x5nmhqhp7xapkb3dlgjxef3t2bnkx7y7n4z2f4z5obnecxpy.h"
extern "C" void kernel(const float* in_ptr0,
                       float* out_ptr1)
{
    #pragma omp parallel num_threads(128)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for collapse(2)
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(1L))
                {
                    #pragma GCC ivdep
                    for(long x2=static_cast<long>(0L); x2<static_cast<long>(12L); x2+=static_cast<long>(1L))
                    {
                        for(long x3=static_cast<long>(0L); x3<static_cast<long>(512L); x3+=static_cast<long>(16L))
                        {
                            auto tmp0 = c10::convert<int>(x1);
                            auto tmp1 = static_cast<int>(256);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 = c10::convert<int>(x3);
                                auto tmp5 = at::vec::Vectorized<int>::arange(tmp4, 1);
                                auto tmp6 = static_cast<int>(257);
                                auto tmp7 = at::vec::Vectorized<int>(tmp6);
                                auto tmp8 = at::vec::VecMask<int,1>(tmp5 < tmp7);
                                auto tmp10 = at::vec::VecMask<float,1>::from(tmp2);
                                auto tmp11 = tmp8 & tmp10;
                                auto tmp9 = [&]
                                {
                                    auto tmp12 = -std::numeric_limits<float>::infinity();
                                    return tmp12;
                                }
                                ;
                                auto tmp13 =
                                [&]
                                {
                                    if (tmp11.all_zero())
                                    {
                                        return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                    }
                                    else
                                    {
                                        return decltype(at::vec::Vectorized<float>(tmp9()))::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), at::vec::Vectorized<float>(tmp9()), tmp11.template cast<float,1>());
                                    }
                                }
                                ()
                                ;
                                auto tmp14 = c10::convert<int>(c10::div_floor_integer(x1, 256L));
                                auto tmp15 = static_cast<int>(3);
                                auto tmp16 = tmp14 < tmp15;
                                auto tmp18 = tmp16 & tmp2;
                                auto tmp17 = [&]
                                {
                                    auto tmp19 = at::vec::Vectorized<int>(tmp1);
                                    auto tmp20 = at::vec::VecMask<int,1>(tmp5 >= tmp19);
                                    auto tmp22 = at::vec::VecMask<float,1>::from(tmp18);
                                    auto tmp23 = tmp20 & tmp22;
                                    auto tmp21 = [&]
                                    {
                                        auto tmp24 = tmp23.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0)));
                                        return tmp24;
                                    }
                                    ;
                                    auto tmp25 =
                                    [&]
                                    {
                                        if (tmp23.all_zero())
                                        {
                                            return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                        }
                                        else
                                        {
                                            return decltype(tmp21())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp21(), tmp23.template cast<float,1>());
                                        }
                                    }
                                    ()
                                    ;
                                    auto tmp26 = static_cast<float>(0.0);
                                    auto tmp27 = at::vec::Vectorized<float>(tmp26);
                                    auto tmp28 = decltype(tmp25)::blendv(tmp27, tmp25, tmp20.template cast<float,1>());
                                    return tmp28;
                                }
                                ;
                                auto tmp29 = tmp16 ? tmp17() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                                auto tmp30 = static_cast<float>(0.0);
                                auto tmp31 = at::vec::VecMask<float,1>::from(tmp16);
                                auto tmp32 = at::vec::Vectorized<float>(tmp30);
                                auto tmp33 = decltype(tmp29)::blendv(tmp32, tmp29, tmp31.template cast<float,1>());
                                auto tmp34 = decltype(tmp13)::blendv(tmp33, tmp13, tmp8.template cast<float,1>());
                                return tmp34;
                            }
                            ;
                            auto tmp35 = tmp2 ? tmp3() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                            auto tmp36 = c10::convert<int>(c10::div_floor_integer(x1, 256L));
                            auto tmp37 = static_cast<int>(3);
                            auto tmp38 = tmp36 < tmp37;
                            auto tmp39 = [&]
                            {
                                auto tmp40 = c10::convert<int>(x3);
                                auto tmp41 = at::vec::Vectorized<int>::arange(tmp40, 1);
                                auto tmp42 = at::vec::Vectorized<int>(tmp1);
                                auto tmp43 = at::vec::VecMask<int,1>(tmp41 >= tmp42);
                                auto tmp45 = at::vec::VecMask<float,1>::from(tmp38);
                                auto tmp46 = tmp43 & tmp45;
                                auto tmp44 = [&]
                                {
                                    auto tmp47 = tmp46.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0)));
                                    return tmp47;
                                }
                                ;
                                auto tmp48 =
                                [&]
                                {
                                    if (tmp46.all_zero())
                                    {
                                        return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                    }
                                    else
                                    {
                                        return decltype(tmp44())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp44(), tmp46.template cast<float,1>());
                                    }
                                }
                                ()
                                ;
                                auto tmp49 = static_cast<float>(0.0);
                                auto tmp50 = at::vec::Vectorized<float>(tmp49);
                                auto tmp51 = decltype(tmp48)::blendv(tmp50, tmp48, tmp43.template cast<float,1>());
                                return tmp51;
                            }
                            ;
                            auto tmp52 = tmp38 ? tmp39() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                            auto tmp53 = static_cast<float>(0.0);
                            auto tmp54 = at::vec::VecMask<float,1>::from(tmp38);
                            auto tmp55 = at::vec::Vectorized<float>(tmp53);
                            auto tmp56 = decltype(tmp52)::blendv(tmp55, tmp52, tmp54.template cast<float,1>());
                            auto tmp57 = at::vec::VecMask<float,1>::from(tmp2);
                            auto tmp58 = decltype(tmp35)::blendv(tmp56, tmp35, tmp57.template cast<float,1>());
                            tmp58.store(out_ptr1 + static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0)));
                        }
                        #pragma omp simd simdlen(8)
                        for(long x3=static_cast<long>(512L); x3<static_cast<long>(513L); x3+=static_cast<long>(1L))
                        {
                            auto tmp0 = c10::convert<int64_t>(x1);
                            auto tmp1 = static_cast<int64_t>(256);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 = c10::convert<int64_t>(x3);
                                auto tmp5 = static_cast<int64_t>(257);
                                auto tmp6 = tmp4 < tmp5;
                                auto tmp7 = [&]
                                {
                                    auto tmp8 = -std::numeric_limits<float>::infinity();
                                    return tmp8;
                                }
                                ;
                                auto tmp9 = tmp6 ? tmp7() : static_cast<decltype(tmp7())>(0.0);
                                auto tmp10 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L));
                                auto tmp11 = static_cast<int64_t>(3);
                                auto tmp12 = tmp10 < tmp11;
                                auto tmp13 = [&]
                                {
                                    auto tmp14 = tmp4 >= tmp1;
                                    auto tmp15 = [&]
                                    {
                                        auto tmp16 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))];
                                        return tmp16;
                                    }
                                    ;
                                    auto tmp17 = tmp14 ? tmp15() : static_cast<decltype(tmp15())>(0.0);
                                    auto tmp18 = static_cast<float>(0.0);
                                    auto tmp19 = tmp14 ? tmp17 : tmp18;
                                    return tmp19;
                                }
                                ;
                                auto tmp20 = tmp12 ? tmp13() : static_cast<decltype(tmp13())>(0.0);
                                auto tmp21 = static_cast<float>(0.0);
                                auto tmp22 = tmp12 ? tmp20 : tmp21;
                                auto tmp23 = tmp6 ? tmp9 : tmp22;
                                return tmp23;
                            }
                            ;
                            auto tmp24 = tmp2 ? tmp3() : static_cast<decltype(tmp3())>(0.0);
                            auto tmp25 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L));
                            auto tmp26 = static_cast<int64_t>(3);
                            auto tmp27 = tmp25 < tmp26;
                            auto tmp28 = [&]
                            {
                                auto tmp29 = c10::convert<int64_t>(x3);
                                auto tmp30 = tmp29 >= tmp1;
                                auto tmp31 = [&]
                                {
                                    auto tmp32 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))];
                                    return tmp32;
                                }
                                ;
                                auto tmp33 = tmp30 ? tmp31() : static_cast<decltype(tmp31())>(0.0);
                                auto tmp34 = static_cast<float>(0.0);
                                auto tmp35 = tmp30 ? tmp33 : tmp34;
                                return tmp35;
                            }
                            ;
                            auto tmp36 = tmp27 ? tmp28() : static_cast<decltype(tmp28())>(0.0);
                            auto tmp37 = static_cast<float>(0.0);
                            auto tmp38 = tmp27 ? tmp36 : tmp37;
                            auto tmp39 = tmp2 ? tmp24 : tmp38;
                            out_ptr1[static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0))] = tmp39;
                        }
                    }
                }
            }
        }
    }
}
''')
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124921
Approved by: https://github.com/jgong5, https://github.com/jansel
ghstack dependencies: #124597
2024-04-28 04:33:25 +00:00
haozhe.zhu
57790fd088 [inductor] share cse cache during vectorized indirect load (#124597)
Fix https://github.com/pytorch/pytorch/issues/123502

`swap_buffer` is not needed in the vectorized indirect load; remove it to share the CSE buffer.
```
auto tmp8 =
[&]
{
    __at_align__ std::array<int64_t, 16> tmpbuf;
    tmp7.store(tmpbuf.data());
    return tmpbuf;
}
()
;
//
// other codes
//
// also store tmp7 here (redundant tmp16)
auto tmp16 =
[&]
{
    __at_align__ std::array<int64_t, 16> tmpbuf;
    tmp7.store(tmpbuf.data());
    return tmpbuf;
}
()
;
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124597
Approved by: https://github.com/jgong5, https://github.com/jansel
2024-04-28 01:02:48 +00:00
Kai Londenberg
cd06c73cbd [Inductor Cutlass backend] Improved GEMM template (#124577)
Improves the Cutlass backend GEMM template:

 * Adds code for creating stand-alone test runners for Cutlass GEMM kernels, which allows (manual) debugging of, for example, CUDA IMA errors or similar problems that occur in practice. Includes some utility code and tests to actually compile and run these standalone tests.
 * Cleans up the GEMM template code through various refactorings.
 * Eliminates code sections and options that are unnecessary now that epilogue fusions are being removed.
 * Limits the scope of a workaround for (flaky) Cutlass issues with bias broadcasting to the necessary cases.
 * Puts some CPU runtime checks into #if / #endif blocks, such that it's possible to compile CUTLASS kernels with lower CPU overhead.
 * Adds documentation comments.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124577
Approved by: https://github.com/jansel
ghstack dependencies: #124576
2024-04-26 20:03:20 +00:00
Aaron Orenstein
a8574a9719 Fix global flake8 issues (#124771)
Prior to this, `lintrunner --all-files --take FLAKE8` failed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124771
Approved by: https://github.com/Skylion007
ghstack dependencies: #124428
2024-04-26 15:35:53 +00:00
PyTorch MergeBot
6b54f9d3e1 Revert "fix Invalid call to aoti_torch_tensor_copy_ #123039 (#124037)"
This reverts commit f9379ebbbf.

Reverted https://github.com/pytorch/pytorch/pull/124037 on behalf of https://github.com/jeanschmidt due to introducing regressions in benchmark, see D56623194 for more details ([comment](https://github.com/pytorch/pytorch/pull/124037#issuecomment-2079574308))
2024-04-26 15:07:09 +00:00
PyTorch MergeBot
1ac60484c1 Revert "Fix global flake8 issues (#124771)"
This reverts commit f01275934b.

Reverted https://github.com/pytorch/pytorch/pull/124771 on behalf of https://github.com/jeanschmidt due to Unfortunately, I needed to revert #123735 and this one depends on it. So please check if there are no merge conflicts or breakages and feel free to merge this PR again ([comment](https://github.com/pytorch/pytorch/pull/124428#issuecomment-2078699836))
2024-04-26 06:15:17 +00:00
chilli
7321005dd8 Add support for capturing tensors with score_mod (#124444)
```
import torch
from torch import nn
import torch.nn.functional as F
import torch._inductor.config as config
# torch.set_default_device('cuda')

import torch
from torch.nn.attention._templated_attention import _templated_attention as templated_attention
from triton.testing import do_bench
from torch.nn.attention import SDPBackend, sdpa_kernel

index = torch.ops.aten
torch.manual_seed(0)

B = 16
H = 16
S = 2048
D = 64

head_scale = torch.randn(H, device='cuda')
def alibi(score, batch, head, token_q, token_kv):
    return score + torch.ops.aten.index(head_scale, [head]) * (token_q - token_kv)
bias = torch.randn(H, S, S, dtype=torch.float16, device='cuda')

query = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
key = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
value = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)

compiled = torch.compile(templated_attention)
out = compiled(query, key, value, score_mod=alibi)
out2 = templated_attention(query, key, value,score_mod=alibi)
print((out - out2).abs().mean())
assert (out - out2).abs().mean() < 1e-3
print("Flash (no mask): ", do_bench(lambda: F.scaled_dot_product_attention(query, key, value)))
print("Flash (mask): ", do_bench(lambda: F.scaled_dot_product_attention(query, key, value, attn_mask=bias)))
print("flexattention: ", do_bench(lambda: compiled(query, key, value, score_mod=alibi)))
```
<img width="324" alt="image" src="https://github.com/pytorch/pytorch/assets/6355099/18c175d0-2720-4dfd-8747-85b8a8f609f5">

Differential Revision: [D56583900](https://our.internmc.facebook.com/intern/diff/D56583900)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124444
Approved by: https://github.com/jansel, https://github.com/drisspg
2024-04-26 01:02:28 +00:00
Tuan Trieu
f9379ebbbf fix Invalid call to aoti_torch_tensor_copy_ #123039 (#124037)
fixes #123039

In ABI mode, ExternKernelSchedulerNode generates code using `aoti_torch_tensor_copy_`, which requires an `AtenTensorHandle`, but the allocation generates an `ArrayRefTensor` to allocate memory on the stack. To fix this issue, this PR prevents ExternKernelSchedulerNode from using stack memory allocation in ABI mode and creates an `AtenTensorHandle` instead.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124037
Approved by: https://github.com/desertfire
2024-04-26 00:16:16 +00:00
David Berard
4259e5d0e0 [inductor] Specialize on unguarded alignment of example inputs (#123319)
When inductor generates triton code, the triton code can either assume that the inputs given to it are aligned or unaligned. If they are aligned, triton can use more efficient instructions (like vectorized loads or tensor cores). However, if we generate "aligned" code and pass in unaligned inputs, the triton code will error out; to fix this, we clone unaligned inputs that are passed to triton kernels that expect aligned inputs. This can lead to excessive clones if we have inputs that are not expected to be aligned.

In this PR, we use the example input to decide whether the generated triton code should assume alignment or not. If the example input is aligned, then we will generate triton code that assumes alignment; if at runtime we receive an unaligned input, we'll make a clone. Meanwhile, if the example input is not aligned, the generated triton code will not assume inputs are aligned and we won't ever need to clone.

Note that the alignment of the inputs is not guarded on; we found that adding guards on tensor offsets (a) was slow in cases where we do a lot of comparisons on tensor offsets, and (b) led to a lot of recompilations.
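
A conceptual sketch of this policy with made-up helpers (not Inductor's actual code; the alignment value is an assumption):

```
import torch

ALIGNMENT = 16  # bytes; illustrative only, not necessarily the value Inductor uses

def is_aligned(t: torch.Tensor) -> bool:
    return t.data_ptr() % ALIGNMENT == 0

def toy_kernel(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for a generated Triton kernel compiled with the "aligned" assumption.
    return x * 2

def compile_with_example(example: torch.Tensor):
    assume_aligned = is_aligned(example)  # specialize on the example input, no guard added

    def run(x: torch.Tensor) -> torch.Tensor:
        if assume_aligned and not is_aligned(x):
            x = x.clone()  # fix up unaligned inputs instead of letting the aligned kernel error
        return toy_kernel(x)

    return run

base = torch.randn(1024)
run = compile_with_example(base)  # example is aligned -> generate "aligned" code
out = run(base[1:])               # unaligned view gets cloned at runtime
```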

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123319
Approved by: https://github.com/eellison
2024-04-25 22:28:15 +00:00
Kai Londenberg
9aeeb8e925 [Inductor Cutlass backend] Improve GEMM op filtering (#124576)
Add configurable allowlist / denylist regular expressions to make it possible to exclude certain
CUTLASS GEMM implementations (for example, "pingpong" kernels, due to undesired numerical behavior).
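
A minimal sketch of the allowlist / denylist idea (only the filtering logic is shown; the real Inductor config option names and the candidate kernel names below are illustrative):

```
import re

def filter_ops(op_names, allow_regex=None, deny_regex=None):
    # Keep ops matching the allowlist (if given) and not matching the denylist.
    kept = []
    for name in op_names:
        if allow_regex and not re.search(allow_regex, name):
            continue
        if deny_regex and re.search(deny_regex, name):
            continue
        kept.append(name)
    return kept

ops = [
    "cutlass3x_gemm_tma_warpspecialized_pingpong",
    "cutlass3x_gemm_tma_warpspecialized_cooperative",
]
print(filter_ops(ops, deny_regex="pingpong"))  # drops the pingpong kernel
```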

Remove usage of old 2.x Cutlass Kernels entirely.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124576
Approved by: https://github.com/jansel, https://github.com/eellison
2024-04-25 16:33:54 +00:00
Aaron Orenstein
f01275934b Fix global flake8 issues (#124771)
Prior to this, `lintrunner --all-files --take FLAKE8` failed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124771
Approved by: https://github.com/Skylion007
ghstack dependencies: #124428
2024-04-25 14:25:00 +00:00
Edward Z. Yang
4c44e2b236 Improved unbacked SymInt input support in Inductor (#124739)
This is a subset of changes extracted from https://github.com/pytorch/pytorch/pull/124683/

This PR contains modifications to make Inductor work with unbacked symbol inputs, which can occur when a data-dependent sized tensor is saved for backwards. The problems to be fixed:

* When binding initial symbols, we unconditionally bind unbacked symbols (instead of computing if they are needed, which only looks at backed symbols)
* Benchmark generation code doesn't work with unbacked symints, as we have no hints to feed in real values. So I pick a random number, and you are expected to fix it if it doesn't work.
* Need to make sure we don't install dependencies on unbacked SymInt inputs; that puts us down the "promptly deallocate the input" path, which is pointless for unbacked SymInts.

Fixes https://github.com/pytorch/pytorch/issues/124652

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124739
Approved by: https://github.com/jansel
ghstack dependencies: #124310, #124314, #124316, #124394
2024-04-25 13:29:53 +00:00
PyTorch MergeBot
f6ce94dca5 Revert "[inductor] Remove usage of device_interface from _inductor.runtime (#124592)"
This reverts commit 5d45eb77f1.

Reverted https://github.com/pytorch/pytorch/pull/124592 on behalf of https://github.com/jeanschmidt due to breaking internal tests, check D56522594 ([comment](https://github.com/pytorch/pytorch/pull/124592#issuecomment-2076957668))
2024-04-25 11:28:23 +00:00
PyTorch MergeBot
0ca1ff3dce Revert "Add support for capturing tensors with score_mod (#124444)"
This reverts commit 7c253a7776.

Reverted https://github.com/pytorch/pytorch/pull/124444 on behalf of https://github.com/jeanschmidt due to Breaking internal tests, check D56522566 ([comment](https://github.com/pytorch/pytorch/pull/124444#issuecomment-2076908582))
2024-04-25 10:56:38 +00:00
leslie-fang-intel
2d7f709752 [Inductor] Force the parallel depth as outer loop fusion depth (#123899)
**Summary**
Fix issue https://github.com/pytorch/pytorch/issues/123801, which brings a performance regression of `pyhpc_turbulent_kinetic_energy` after outer loop fusion.

**Root Cause**

- [Generated Kernel before Outer Loop Fusion](https://gist.github.com/leslie-fang-intel/54fe21ac8871fc63b9bf20fdb6edf209)
  - Taking the two kernels below as an example:
    - [Kernel 0](https://gist.github.com/leslie-fang-intel/54fe21ac8871fc63b9bf20fdb6edf209#file-pyhpc_turbulent_kinetic_energy-before-outer-loop-fusion-py-L255-L305) has 2 loop levels with size [200, 200]. Parallelization is not feasible due to an insufficient number of elements, as determined by [`decide_parallel_depth`](aaec97a403/torch/_inductor/codegen/cpp.py (L2145-L2164)). Therefore, the loop code will be generated with the `#pragma omp single` directive.
    - [Kernel 1](https://gist.github.com/leslie-fang-intel/54fe21ac8871fc63b9bf20fdb6edf209#file-pyhpc_turbulent_kinetic_energy-before-outer-loop-fusion-py-L306-L316) has 3 loop levels with size [200, 200, 26], which has enough elements to be parallelized.
- [Generated Kernel after Outer Loop Fusion](https://gist.github.com/leslie-fang-intel/57a497b9d9c6aa82b1c6a686292fc887)
  - After outer loop fusion, `Kernel0` and `Kernel1` have been fused into one [OuterLoopFusedKernel](https://gist.github.com/leslie-fang-intel/57a497b9d9c6aa82b1c6a686292fc887#file-pyhpc_turbulent_kinetic_energy-after-outer-loop-fusion-py-L261-L497); the outer loop size is [200, 200], which does not contain enough elements for parallelization.

In this PR, we propose a fix for `loop_nest` involving `OuterLoopFusedKernel`. The fix entails adding a specific heuristic for `OuterLoopFusedKernel` to determine the parallel depth by combining `outer_loop_fusion_depth` with the internal kernels' parallel depth.
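
A toy illustration of the failure mode (the thread count and threshold are assumptions, not the real `decide_parallel_depth` logic): the fused outer loops alone may not meet the per-thread work threshold that the pre-fusion inner kernel met.

```
def has_enough_work(loop_sizes, threads=128, min_work_per_thread=512):
    # Rough stand-in for the kind of element-count check used to decide parallelization.
    total = 1
    for size in loop_sizes:
        total *= size
    return total >= threads * min_work_per_thread

print(has_enough_work([200, 200]))      # False: the fused outer loops alone look too small
print(has_enough_work([200, 200, 26]))  # True: pre-fusion Kernel 1 had enough elements
```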

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123899
Approved by: https://github.com/jgong5, https://github.com/lezcano
2024-04-25 09:50:46 +00:00
Edward Z. Yang
13ab24f192 Reimplement unbacked symbol bindings in Inductor (#124394)
This PR has a lot of "draw the rest of the fucking owl" energy. Here's how to break it down.

1. **torch/_inductor/graph.py** - We start by tightening unbacked symbol invariants. Specifically, as we lower FX nodes, we check whether or not every unbacked_binding recorded on the FX node meta actually ends up getting bound (according to get_unbacked_symbol_defs) in all the buffers generated by the lowering. Hopefully this invariant is self-evident. This leads to a lot of failures.
2. **torch/_inductor/ir.py** - Problem 1: There is softness in how Inductor computes defs of unbacked symbols in IR node. Previously, we tried to infer it by looking at the output sizes/strides/etc and see if new unbacked symbols popped up that we hadn't seen in the inputs. I don't know exactly what was buggy about the old code, but sometimes we would fail to notice an unbacked symbol had been bound, or rebind an unbacked symbol multiple times. Fortunately, thanks to the earlier PRs in our stack, we now have a nice list of unbacked symbol bindings from FX, so we now just store it directly on ExternKernel and use it directly to report defs. This has to be done twice: once for FallbackKernel (e.g., nonzero) and once for DynamicScalar (e.g., item) (see also **torch/_inductor/lowering.py**, **torch/_inductor/codegen/wrapper.py** and  **torch/_inductor/codegen/cpp_wrapper_cpu.py** for the lowering and codegen changes for item)
   * **process_kernel** - Sidequest! It turns out that Inductor lowering can reallocate unbacked symbols. This happens specifically when we repropagate fake tensors through the operator in `process_kernel`. This repropagation process is necessary because Inductor may have changed the strides of input tensors, and it must now recompute the strides so that it can continue to appropriately plan the rest of the lowering process. This is fine: we just make sure we do the rebind unbacked + compute_unbacked_bindings dance we've been doing previously in the PR stack. But instead of putting unbacked_bindings on a new FX node, they go straight into our unbacked_bindings on the Inductor IR node.
    * **codegen_unbacked_symbol_defs** - Sidequest! FallbackKernel lowering is done in two steps. First, you emit the FallbackKernel buffer. Then, you emit MultiOutput buffers which actually give access to the individual outputs of FallbackKernel, which may have been multi-output. There is a design decision here: does the FallbackKernel bind the unbacked symbols, or the MultiOutput buffer? Historically, we put the binding on MultiOutput buffer, because it's more convenient: the FallbackKernel buffer is fake, in fact, it doesn't even get a name in C++ codegen. But it's kind of inconsistent with the keypath model that we've been tracking unbacked bindings with: if you have a multi-output node, you'd expect a keypath like `[0].size()[0]` representing the first output's first dimension size. That suggests that it's the FallbackKernel that should define the things. So that was my first implementation. Unfortunately, the C++ codegen is too cursed and I could not understand how to make it work in that case. So now we just unsoundly assume you cannot have multi-output data dependent output, and do the codegen in MultiOutput. There are some comments explaining exactly what we are improperly assuming.
3. **_rename_unbacked_to** in **torch/fx/experimental/symbolic_shapes.py** - Previously, when we renamed unbacked symbols, we clobbered any facts we previously knew about them. So for example, if we had a replacement `u0 -> s0` but then we renamed u0 to u1, we would now set up the replacement `u0 -> u1`, clobbering the old replacement. This apparently didn't matter in earlier PRs in the stack, but with Inductor now on the ball, there were some tests that indicated this was a problem. The solution is easy: if u0 had a preexisting replacement, reapply it to u1 (see the first sketch after this list). However...
    * **torch/_functorch/_aot_autograd/collect_metadata_analysis.py** - When we run forward analysis, this triggers fake tensor repropagation and fresh allocations. Previously, we just cleared out the pending symbols when finished the analysis. But with the change above, this would also migrate replacements to the new symbols... which are now dead. So now we explicitly suppress generation of these symbols with `ignore_fresh_unbacked_symbols` so that no rebinding happens at all.
    * **torch/_dynamo/eval_frame.py** - same deal; I just searched for all sites we called clear() on pending
4. The last step is fixing the long tail of extra problems that show up, now that unbacked_bindings are load-bearing in Inductor:
    * **torch/_dynamo/eval_frame.py** - Some of the exports are making copies of nodes without repropagating fake tensors, so in this case, it is important to also copy the `unbacked_bindings` (apparently this didn't matter before without the Inductor changes)
    * **torch/_export/pass_base.py** - I discover that this is doing fake tensor repropagation via a test suite failure. Do the same playbook as AOTAutograd: PropagateUnbackedSymInts too!  Actually, they also have implemented their own tracer as well, so do the same playbook as proxy_tensor: record unbacked_bindings on the newly traced nodes. UGH code duplication.
    * **torch/_subclasses/fake_tensor.py**, **torch/_subclasses/fake_impls.py** (with call site updates at **torch/_functorch/_aot_autograd/traced_function_transforms.py** and **torch/fx/passes/fake_tensor_prop.py**) - What's this new epoch thing? I noticed that sometimes I would be retracing, call nonzero() on a fake tensor, and not allocate a new unbacked symbol. This is actually bad, because if I don't get a new unbacked symbol, I don't know there's a binding site, and `unbacked_bindings` is now missing a binding. The reason for this is memoization: if I reuse the exact same fake tensor on my retrace, it will already have an unbacked symint memoized on it and we will short-circuit allocation. Well, that's no good. So I associate the memos with a fake tensor epoch, and every time you start a new fake tensor propagation from scratch, you bump the epoch so that I clear all the memos (see the second sketch after this list).
    * **torch/_inductor/scheduler.py** - I notice in unit tests that V.current_node is not always set when we call process_kernel. So I save it into the IR node and restore it when we are running `get_estimated_runtime`.
    * **torch/fx/experimental/symbolic_shapes.py** - A few things
      * **rebind_unbacked** (re **_tensor_version**). Ordinarily, when you have an unbacked SymInt, you persistently have it all the way to the end of the program. `_tensor_version` violates this: this generates an unbacked SymInt (for reasons I don't quite understand?) and then gets rid of it later. This triggered an assert violation. I think this op is kind of misusing unbacked SymInt, but I didn't know how to refactor it, so it gets a special case.
      * **rebind_unbacked** (re **Simplify SymBool binding**). Ugh, SymBool, what a pain in the butt. I have an assert that you can only rebind unbacked symbol to another unbacked symbol. This assert fails when a boolean is involved, because the result of running keypath on the result is not `u1`, it's `sympy.Piecewise(... sympy.Eq(u1, 1) ...)`. This is actually just `u1`, but Sympy doesn't know it because it doesn't know that `u1` value range is `[0, 1]`. So we manually implement the simplification needed to get the assert to pass.
      * **compute_unbacked_bindings** (re **This is pretty fragile**). There is a really funny disaster involving memoization and Inductor process kernel. Ordinarily when I retrace, if there was a memo hit in the old trace, there will be a memo hit in the new trace. However, Inductor process kernel breaks this, because it recreates fake tensor inputs to the operator call from scratch (since they might have different strides), and obviously these tensor inputs don't have the memo from the old one. I tried briefly to manually transplant the memo to the new fake tensor but it seemed hopeless, so I just let the fresh symbol ride, allocating a new unbacked symbol. However, in one of our tests, we rely on knowing that the first nonzero call is equal to the second (memoized) nonzero call. The equality test looked pretty easy to discharge, so I just went ahead and added a deferred runtime assert to this effect and it worked.
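
Two of the moving parts above are small enough to illustrate outside of PyTorch. First, the replacement-preserving rename from item 3: a minimal, hypothetical sketch where a bare dict stands in for the ShapeEnv replacement table (this is not the real `_rename_unbacked_to`):

```
import sympy

u0, u1, s0 = sympy.symbols("u0 u1 s0", integer=True)
replacements = {u0: s0}   # a previously learned fact: u0 is really s0

def rename_unbacked_to(old, new, replacements):
    prior = replacements.get(old)
    replacements[old] = new              # the old symbol now forwards to its renamed successor
    if prior is not None and prior != new:
        replacements[new] = prior        # ...and the fact we knew about `old` survives on `new`

rename_unbacked_to(u0, u1, replacements)
assert replacements == {u0: u1, u1: s0}  # instead of clobbering the old fact with just u0 -> u1
```

Second, the fake tensor epoch from item 4: a self-contained toy (the class and method names are made up, not the real fake tensor machinery) showing why bumping an epoch forces a fresh unbacked symbol, and hence a binding site, on retrace:

```
import itertools

_symbol_counter = itertools.count()

class FakeTensorModeToy:
    def __init__(self):
        self.epoch = 0          # bumped whenever a fresh fake tensor propagation starts

    def new_propagation(self):
        self.epoch += 1

class FakeTensorToy:
    def __init__(self, mode):
        self.mode = mode
        self._nonzero_memo = None
        self._nonzero_memo_epoch = None

    def nonzero_size(self):
        # Only reuse the memo if it was recorded during the current propagation.
        if self._nonzero_memo is not None and self._nonzero_memo_epoch == self.mode.epoch:
            return self._nonzero_memo
        self._nonzero_memo = f"u{next(_symbol_counter)}"   # stand-in for a fresh unbacked SymInt
        self._nonzero_memo_epoch = self.mode.epoch
        return self._nonzero_memo

mode = FakeTensorModeToy()
t = FakeTensorToy(mode)
assert t.nonzero_size() == t.nonzero_size() == "u0"   # memo hit within one propagation
mode.new_propagation()                                # retracing from scratch
assert t.nonzero_size() == "u1"                       # stale memo dropped, new binding site
```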

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124394
Approved by: https://github.com/jansel
ghstack dependencies: #124310, #124314, #124316
2024-04-25 02:08:59 +00:00
Kai Londenberg
7d94f52a8a [Inductor Cutlass backend] clean up CUTLASSGemmTemplate.add_cutlass_gemm_choices (#124575)
Clean up CUTLASSGemmTemplate.add_cutlass_gemm_choices, removing code that became unnecessary after the removal of EVT-based epilogue fusion.

Test Plan:
Already covered by test_cutlass_backend.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124575
Approved by: https://github.com/jansel
ghstack dependencies: #121497, #123930, #123932, #121734, #124107, #124574
2024-04-24 14:00:12 +00:00
Kai Londenberg
e76b5e3cc8 [Inductor Cutlass backend] Disable epilogue fusions (#124107)
This diff disables the Cutlass backend's EVT epilogue fusions. It does not yet remove most of the underlying implementation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124107
Approved by: https://github.com/jansel
ghstack dependencies: #121497, #123930, #123932, #121734
2024-04-24 13:56:44 +00:00
Sergii Dymchenko
f0f7452e31 Do not propogate (#124769)
Fix occurrences of the "propogate" typo (should be "propagate").

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124769
Approved by: https://github.com/Skylion007
2024-04-24 02:18:18 +00:00
chilli
7c253a7776 Add support for capturing tensors with score_mod (#124444)
```
import torch
from torch import nn
import torch.nn.functional as F
import torch._inductor.config as config
# torch.set_default_device('cuda')

from torch.nn.attention._templated_attention import _templated_attention as templated_attention
from triton.testing import do_bench
from torch.nn.attention import SDPBackend, sdpa_kernel

torch.manual_seed(0)

B = 16
H = 16
S = 2048
D = 64

head_scale = torch.randn(H, device='cuda')
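# score_mod hook: add an ALiBi-style per-head bias proportional to the query/key distance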
def alibi(score, batch, head, token_q, token_kv):
    return score + torch.ops.aten.index(head_scale, [head]) * (token_q - token_kv)
bias = torch.randn(H, S, S, dtype=torch.float16, device='cuda')

query = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
key = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
value = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)

compiled = torch.compile(templated_attention)
out = compiled(query, key, value, score_mod=alibi)
out2 = templated_attention(query, key, value, score_mod=alibi)
print((out - out2).abs().mean())
assert (out - out2).abs().mean() < 1e-3
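# Benchmark SDPA (with and without an explicit additive bias) against the compiled templated attention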
print("Flash (no mask): ", do_bench(lambda: F.scaled_dot_product_attention(query, key, value)))
print("Flash (mask): ", do_bench(lambda: F.scaled_dot_product_attention(query, key, value, attn_mask=bias)))
print("flexattention: ", do_bench(lambda: compiled(query, key, value, score_mod=alibi)))
```
<img width="324" alt="image" src="https://github.com/pytorch/pytorch/assets/6355099/18c175d0-2720-4dfd-8747-85b8a8f609f5">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124444
Approved by: https://github.com/jansel, https://github.com/drisspg
2024-04-23 17:54:08 +00:00
Jason Ansel
5d45eb77f1 [inductor] Remove usage of device_interface from _inductor.runtime (#124592)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124592
Approved by: https://github.com/masnesral
2024-04-23 17:51:25 +00:00
Bin Bao
b2fd224f27 [AOTI] Add more ABI-compatiblity unit test (#123900)
Summary: Follow https://github.com/pytorch/pytorch/pull/123848, and test more c10 util functions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123900
Approved by: https://github.com/chenyang78
2024-04-23 16:06:40 +00:00