Commit Graph

552 Commits

Author SHA1 Message Date
Karthick Panner Selvam
cddcaa1903 [Inductor] Add DeviceAssert op to enable device-side assertion in torch.compile (#160677)
This PR introduces a device_assert op to trigger device-side assertions within torch.compile. This implementation is based on the suggestion in [this comment](https://github.com/pytorch/pytorch/issues/147282#issuecomment-2756056084).

Changes Included

- Implemented device_assert op and overrides has_side_effect to return True to avoid removal by dead code elimination.
- Commented out the assert_async_msg_decomp and functional_assert_async_msg_decomp decompositions to disable the default assert decomposition inside Inductor.
- Added lowering for torch.ops.aten._assert_async.msg to convert assert calls into the ops_handler.
- Implemented the codegen method for the device_assert op. This supports generating C++ and Triton code.
- Added test cases to verify both "should throw" and "should not throw" scenarios.

Fixes #147282

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160677
Approved by: https://github.com/mlazos
2025-08-26 22:33:23 +00:00
CaoE
23b033452f [Inductor][CPP] Fix layout for local buf in outer loop fusion (#160857)
Fixes #159154

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160857
Approved by: https://github.com/leslie-fang-intel, https://github.com/jansel
2025-08-21 06:00:04 +00:00
Isuru Fernando
f305019377 [inductor] propagate shapes in CSEVariable (#152198)
Fixes #149905

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152198
Approved by: https://github.com/eellison
2025-08-19 16:46:38 +00:00
Mu-Chu Lee
40311e2ec1 [AOTInductor] ABI-Compatibility for RecordFunction. (#159842)
Summary:
Previous our implementation for RecordFunction injects Aten into
codegen, which is breaking the ABI contract for AOTInductor.

C10::IValue is aded to call the full record function. The extension of
more profiling info will come in later PRs.

Test Plan:
Included in commit.

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D79622071](https://our.internmc.facebook.com/intern/diff/D79622071)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159842
Approved by: https://github.com/desertfire
2025-08-15 21:45:47 +00:00
Shangdi Yu
aa99e0958f Separate provenance tracking to different levels (#160383)
Summary: as title. We've got request from various parties who are interested in turning on the provenance tracking by default. In this PR, we prepare to turn on part of the provenance tracking that doesn't have too much overhead by default.

- Change `provenance_tracking` config to `provenance_tracking_level`
- turn on the following provenance tracking by default when `basic_provenance_tracking`=True
    - `set_kernel_post_grad_provenance_tracing` for kernels, this add mapping between triton kernels and post_grad nodes
    - `dump_inductor_provenance_info` if we're dumping tlparse log
    - `get_graph_provenance_json` and dump `reate_mapping_pre_post_grad_nodes`. This creates mapping between pre_grad and post_grad nodes. Since we're not turning on the provenance tracking in GraphTransformObserver by default, the mapping here maybe incomplete/limited.
    - add stack trace from post grad nodes to inductor IR nodes
    - add exception swallowing for all functions above

Test Plan:
CI

Rollback Plan:

Differential Revision: D80031559

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160383
Approved by: https://github.com/angelayi
2025-08-15 04:59:35 +00:00
thenumberouscode
24f43d0da7 [inductor] [cpu] fix the dype hardcoded to int64 in store_reduction (#157904)
## Fixes https://github.com/pytorch/pytorch/issues/157683

## mini repro
* Just copy the code from the issue to reproduce it.
```python
import torch

device = "cpu"

# Input tensors
v2_0 = torch.randn(16, 24, 59, dtype=torch.complex64, device=device)
v3_0 = torch.randn(16, 24, 59, dtype=torch.complex64, device=device)

def my_model(v2_0, v3_0):
    v6_0 = -v3_0
    v4_0 = v2_0 * v3_0
    v1_0 = v4_0.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
    v0_0 = v2_0.to(torch.int32)
    v5_0 = v0_0.amax(dim=0)

    return v6_0, v4_0, v1_0, v0_0, v5_0

v6_0, v4_0, v1_0, v0_0, v5_0 = my_model(v2_0, v3_0)
print("v6_0", v6_0.shape)
print("v4_0", v4_0.shape)

compiled_model = torch.compile(my_model, backend="inductor")

v6_0, v4_0, v1_0, v0_0, v5_0 = compiled_model(v2_0, v3_0)

print("v6_0", v6_0.shape)
print("v4_0", v4_0.shape)
print("v1_0", v1_0.shape)
print("v0_0", v0_0.shape)
print("v5_0", v5_0.shape)

```
error_stack
```
/home/admin/pytorch/pytorch/torch/include/ATen/cpu/vec/vec_convert.h:41:1: 附注:candidate: ‘template<class dst_t, class src_t> std::enable_if_t<(! is_same_v<dst_t, src_t>), at::vec::CPU_CAPABILITY::Vectorized<T> > at::vec::CPU_CAPABILITY::convert(const at::vec::CPU_CAPABILITY::Vectorized<T>&)’
   41 | convert(const Vectorized<src_t>& src) {
      | ^~~~~~~
/home/admin/pytorch/pytorch/torch/include/ATen/cpu/vec/vec_convert.h:41:1: 附注:  template argument deduction/substitution failed:
/tmp/torchinductor_admin/6k/c6kr65o43rlmp2cmkpn5ezewhe5bla4w72hpcrg5biyelrs4skyw.main.cpp:37:99: 错误:模板参数数目不对(不应是 4 个而应是 2 个)
   37 |                     auto int32_t_tmp_acc0_vec = at::vec::convert<int32_t,1,int64_t,2>(tmp_acc0_vec);
```
## summary
**The C++ kernel generated by the Inductor had the wrong data type for the output variable; it should be int32_t instead of int64_t. This incorrect data type led to an incompatible data type conversion, which caused the g++ compilation to fail.**
The original code that caused the problem.
```
def my_model(v2_0, v3_0):
    v6_0 = -v3_0
    v4_0 = v2_0 * v3_0
    v1_0 = v4_0.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
    v0_0 = v2_0.to(torch.int32)
    // The original code that caused the problem.
    v5_0 = v0_0.amax(dim=0)
```

## proof procedure
The c++ kernel generated by inductor:
```c++
#include <torch/csrc/inductor/cpp_prefix.h>
extern "C"  void kernel(const int32_t* in_ptr0,
                       int32_t* out_ptr0)
{
    {
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(1416L); x0+=static_cast<int64_t>(16L))
        {
            {
                int32_t tmp_acc0_arr[16];
                for (int i = 0; i < 16; i++)
                {
                    tmp_acc0_arr[i] = std::numeric_limits<int32_t>::min();
                }
                int32_t tmp_acc0 = std::numeric_limits<int32_t>::min();
                at::vec::Vectorized<int32_t> tmp_acc0_vec = at::vec::Vectorized<int32_t>(std::numeric_limits<int32_t>::min());
                for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(16L); x1+=static_cast<int64_t>(1L))
                {
                    {
                        if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(1408L)))
                        {
                            auto tmp0 = at::vec::Vectorized<int32_t>::loadu(in_ptr0 + static_cast<int64_t>(x0 + 1416L*x1), static_cast<int64_t>(16));
                            tmp_acc0_vec = at::vec::maximum(tmp_acc0_vec, tmp0);
                        }
                        if(C10_UNLIKELY(x0 >= static_cast<int64_t>(1408L) && x0 < static_cast<int64_t>(1416L)))
                        {
                            for (int64_t x0_tail = static_cast<int64_t>(1408L);x0_tail < static_cast<int64_t>(1416L); x0_tail++)
                            {
                                auto tmp0 = in_ptr0[static_cast<int64_t>(x0_tail + 1416L*x1)];
                                tmp_acc0_arr[x0_tail - static_cast<int64_t>(1408L)] = max_propagate_nan(tmp_acc0_arr[x0_tail - static_cast<int64_t>(1408L)], tmp0);
                            }
                        }
                    }
                }
                if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(1408L)))
                {
                   // impossible data type conversion which would caused the g++ compilation to fail.
                    auto int32_t_tmp_acc0_vec = at::vec::convert<int32_t,1,int64_t,2>(tmp_acc0_vec);
                    int32_t_tmp_acc0_vec.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                }
                if(C10_UNLIKELY(x0 >= static_cast<int64_t>(1408L) && x0 < static_cast<int64_t>(1416L)))
                {
                    for (int64_t x0_tail = static_cast<int64_t>(1408L);x0_tail < static_cast<int64_t>(1416L); x0_tail++)
                    {
                        out_ptr0[static_cast<int64_t>(x0_tail)] = tmp_acc0_arr[x0_tail - static_cast<int64_t>(1408L)];
                    }
                }
            }
        }
    }
}
```
the compilers complains
```text
/home/admin/pytorch/pytorch/torch/include/ATen/cpu/vec/vec_convert.h:41:1: 附注:candidate: ‘template<class dst_t, class src_t> std::enable_if_t<(! is_same_v<dst_t, src_t>), at::vec::CPU_CAPABILITY::Vectorized<T> > at::vec::CPU_CAPABILITY::convert(const at::vec::CPU_CAPABILITY::Vectorized<T>&)’
   41 | convert(const Vectorized<src_t>& src) {
      | ^~~~~~~
/home/admin/pytorch/pytorch/torch/include/ATen/cpu/vec/vec_convert.h:41:1: 附注:  template argument deduction/substitution failed:
/tmp/torchinductor_admin/6k/c6kr65o43rlmp2cmkpn5ezewhe5bla4w72hpcrg5biyelrs4skyw.main.cpp:37:99: 错误:模板参数数目不对(不应是 4 个而应是 2 个)
   37 |                     auto int32_t_tmp_acc0_vec = at::vec::convert<int32_t,1,int64_t,2>(tmp_acc0_vec);
```
so the following line have problem
```c++
    // this line means that tmp_acc0_vec should be Vectorized<int64_t>, and it will convert it to Vectorized<int32_t>.
    auto int32_t_tmp_acc0_vec = at::vec::convert<int32_t,1,int64_t,2>(tmp_acc0_vec);
```
The issue is that tmp_acc0_vec is of type Vectorized<int32_t>, but the template parameters expect it to be Vectorized<int64_t>.  and it will convert it to a Vectorized<int32_t>. this is conflict. the conversion should not be exist for tmp_acc0_vec is already Vectorized<int32_t>.The following line hardcodes the output variable type to int64, which causes unnecessary and incorrect type conversions.
d89f30ad45/torch/_inductor/codegen/cpp.py (L2985-L2993)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157904
Approved by: https://github.com/jgong5
2025-08-07 08:03:05 +00:00
CaoE
efc4b460b3 Add cascade sum support for Inductor CPP backend (#156296)
Fixes #154703

Add cascade summation support for Inductor CPP backend to improve precision for large size summation.

Currently, Inductor CPP directly do reduction for sum. As shown in #154703, when the size of the sum is large and the number of parallel is small, direct reduction will cause an intolerable precision loss:
```
extern "C"  void kernel(float* in_out_ptr0,
                       const float* in_ptr0)
{
    auto out_ptr0 = in_out_ptr0;
    {
        {
            float tmp_acc0 = 0;
            at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0);
            for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(3000000000L); x0+=static_cast<int64_t>(16L))
            {
                {
                    if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(3000000000L)))
                    {
                        auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                        tmp_acc0_vec = tmp_acc0_vec + tmp0;
                    }
                }
            }
            tmp_acc0 = tmp_acc0 + at::vec::vec_reduce_all<float, 1>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>& y) { return x + y; }, tmp_acc0_vec);
            out_ptr0[static_cast<int64_t>(0L)] = static_cast<float>(tmp_acc0);
        }
    }
    {
        {
            {
                auto tmp0 = out_ptr0[static_cast<int64_t>(0L)];
                auto tmp1 = static_cast<float>(3000000000.0);
                auto tmp2 = tmp0 / tmp1;
                in_out_ptr0[static_cast<int64_t>(0L)] = tmp2;
            }
        }
    }
}
```

After adding cascade sum support:

```
extern "C"  void kernel(float* in_out_ptr0,
                       const float* in_ptr0)
{
    auto out_ptr0 = in_out_ptr0;
    {
        {
            float tmp_acc0 = 0;
            at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0);
            at::vec::Vectorized<float> masked_tmp_acc0_vec = at::vec::Vectorized<float>(0);
            CascadeSumHelper<float, 65536> scalar_cascade_helper0(static_cast<int64_t>(3000000000L));
            CascadeSumHelper<at::vec::Vectorized<float>, 65536> cascade_helper0(static_cast<int64_t>(187500000L));
            CascadeSumHelper<at::vec::Vectorized<float>, 65536> masked_cascade_helper0(static_cast<int64_t>(0L));
            for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(3000000000L); x0+=static_cast<int64_t>(16L))
            {
                {
                    if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(3000000000L)))
                    {
                        auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                        tmp_acc0_vec = cascade_sum_combine(tmp0, &cascade_helper0);
                    }
                }
            }
            tmp_acc0 = cascade_sum_final(&scalar_cascade_helper0);
            tmp_acc0_vec = cascade_sum_final(&cascade_helper0);
            masked_tmp_acc0_vec = cascade_sum_final(&masked_cascade_helper0);
            tmp_acc0 = tmp_acc0 + at::vec::vec_reduce_all<float, 1>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>& y) { return x + y; }, tmp_acc0_vec + masked_tmp_acc0_vec);
            out_ptr0[static_cast<int64_t>(0L)] = static_cast<float>(tmp_acc0);
        }
    }
    {
        {
            {
                auto tmp0 = out_ptr0[static_cast<int64_t>(0L)];
                auto tmp1 = static_cast<float>(3000000000.0);
                auto tmp2 = tmp0 / tmp1;
                in_out_ptr0[static_cast<int64_t>(0L)] = tmp2;
            }
        }
    }
}
```
This will inevitably reduce performance when cascade sum is turned on.
For the case shown in #154703: performance reduced by ~3%.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156296
Approved by: https://github.com/leslie-fang-intel, https://github.com/jansel
2025-08-05 02:54:32 +00:00
Shangdi Yu
82a1ee1135 Refactor Provenance Tracking (#158399)
Summary:
As inductor provenance tracking is getting more use cases, we want to separate the inductor provenance tracking guarding flag from the general `trace.enabled`, so we can enable provenance tracking without all the overhead of `trace.enabled`

- change the guard flag from `trace.enabled` to `trace.provenance_tracking`.  It is turned on by either `TORCH_COMPILE_DEBUG=1` or `INDUCTOR_PROVENANCE=1`.
- Move the provenance tracking logic and variables out of DebugContext, because DebugContext is only enabled with `trace.enabled`. Since the variables are now global variables, added `reset_provenance_globals()` context manager to reset them for each `compile_fx()` call.
- Move `set_kernel_post_grad_provenance_tracing` from `util.py` to `debug.py` so now all provenance related logic is in `debug.py`.

In the future, if we want to enable it further, we can change the provenance tracking flag to be enabled when `TORCH_TRACE` is set. I think we should do that in a separate PR, so it's easier to revert if this flag change creates any problem.

See more motivation in internal Diff

Test Plan:
```
buck2 run mode/dev-nosan fbcode//caffe2/test:fx -- -r test_graph_transform_observer
buck run mode/dev-nosan  fbcode//caffe2/test:fx -- -r graph_provenance
buck2 run mode/dev-nosan fbcode//caffe2/test/inductor:provenance_tracing
```

Differential Revision: D78287976

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158399
Approved by: https://github.com/angelayi
2025-07-17 00:23:00 +00:00
Xiangyang (Mark) Guo
156a377f4c [AOTI][CPP] add flag TORCHINDUCTOR_CPP_FORCE_INLINE_KERNEL (#157949)
Summary: Add flag TORCHINDUCTOR_CPP_FORCE_INLINE_KERNEL to force inline the kernel function when TORCHINDUCTOR_CPP_FORCE_INLINE_KERNEL=1. It's disabled by default because force inlining may increase the build time.

Differential Revision: D77915987

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157949
Approved by: https://github.com/desertfire
2025-07-15 10:51:43 +00:00
Tom Ritchford
e3afbb0362 [inductor] Add typing to _inductor/ir.py (#149958)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149958
Approved by: https://github.com/Skylion007
2025-06-30 15:56:35 +00:00
leslie-fang-intel
1eea2c4fe3 [Inductor][CPP] Fix perf regression of functorch_maml_omniglot (#156526)
**Summary**
Fix the performance regression of `functorch_maml_omniglot` in TorchBench. The issue reported in [#151523](https://github.com/pytorch/pytorch/issues/151523) occurs only when a parallel reduction is performed under the vectorized loop and a scalar kernel is used for the tail loop. Previously, we addressed this regression in [#151887](https://github.com/pytorch/pytorch/pull/151887) by disabling all cases where a parallel reduction occurs under the vectorized  loop. However, for `functorch_maml_omniglot`, we found that a masked vector kernel is used in the tail loop instead of the scalar kernel in the job of `inductor_torchbench_cpu_smoketest_perf`. In this PR, we refine the fix by excluding the cases where a masked vector kernel is used in the tail loop, rather than disabling all such scenarios.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156526
Approved by: https://github.com/CaoE
2025-06-27 03:09:24 +00:00
Laith Sakka
26f7ca3972 Unify dynamic shapes APIs naming 2 (expect_true and check) attempt2 (#156518)
Summary:
The functions guard_lt, guard_equals, and guard_leq work similarly to torch.check and expect_true, but they operate on SymPy expressions. Notably, guard_equals applies local replacements before comparison, which might be better extracted into a separate function.

This pull request standardizes naming conventions to match symbolic_shapes.py. Specifically,
-  it introduces size_vars.expect_true and size_vars.check.
- guard_lt becomes check_lt
- guard_leq becomes check_leq
- guard_equals becomes check_equals

I am also seeing a couple of wrong usages !! that i will fix  in the next PR

Test Plan:
OSS and cont

Rollback Plan:

Differential Revision: D77054177

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156518
Approved by: https://github.com/bobrenjc93
2025-06-24 21:01:38 +00:00
Pandya, Vivek Vasudevbhai
e0ae4ecca8 Refactor cpp codegen to support overridable class attributes. (#155553)
- Refactored CppKernelProxy and CppScheduling to use class-level attributes (kernel_cls, kernel_proxy_cls) for backend-specific kernel customization.
 - Avoids method duplication (e.g., codegen_functions, codegen_node) for backend-specific overrides thus reduces downstream maintenance when upgrading Torch.
 - Ensures type safety with annotations while keeping core logic centralized and extensible.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155553
Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5
2025-06-23 07:36:30 +00:00
Xuehai Pan
6ff6630375 [BE][3/16] fix typos in torch/ (torch/_inductor/) (#156313)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156313
Approved by: https://github.com/jingsh
2025-06-23 02:57:12 +00:00
PyTorch MergeBot
f1331f3f1b Revert "[BE][3/16] fix typos in torch/ (torch/_inductor/) (#156313)"
This reverts commit 3627270bdf.

Reverted https://github.com/pytorch/pytorch/pull/156313 on behalf of https://github.com/atalman due to export/test_torchbind.py::TestCompileTorchbind::test_compile_error_on_input_aliasing_contents_backend_aot_eager [GH job link](https://github.com/pytorch/pytorch/actions/runs/15804799771/job/44548489912) [HUD commit link](c95f7fa874) ([comment](https://github.com/pytorch/pytorch/pull/156313#issuecomment-2994171213))
2025-06-22 12:31:57 +00:00
Xuehai Pan
3627270bdf [BE][3/16] fix typos in torch/ (torch/_inductor/) (#156313)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156313
Approved by: https://github.com/jingsh
2025-06-22 08:43:09 +00:00
PyTorch MergeBot
503362d019 Revert "Unify dynamic shapes APIs naming 2 (expect_true and check) (#155776)"
This reverts commit 603a54a9b3.

Reverted https://github.com/pytorch/pytorch/pull/155776 on behalf of https://github.com/atalman due to failing internal build ([comment](https://github.com/pytorch/pytorch/pull/155776#issuecomment-2977041192))
2025-06-16 15:13:53 +00:00
Aaron Orenstein
e95e8eed0a mypy 1.16.0 (#155821)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155821
Approved by: https://github.com/ezyang, https://github.com/zou3519
2025-06-14 18:18:43 +00:00
Laith Sakka
603a54a9b3 Unify dynamic shapes APIs naming 2 (expect_true and check) (#155776)
The functions guard_lt, guard_equals, and guard_leq work similarly to torch.check and expect_true, but they operate on SymPy expressions. Notably, guard_equals applies local replacements before comparison, which might be better extracted into a separate function.

This pull request standardizes naming conventions to match symbolic_shapes.py. Specifically,
-  it introduces size_vars.expect_true and size_vars.check.
- guard_lt becomes check_lt
- guard_leq becomes check_leq
- guard_equals becomes check_equals

I am also seeing a couple of wrong usages !! that i will fix  in the next PR

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155776
Approved by: https://github.com/bobrenjc93
ghstack dependencies: #154774
2025-06-14 17:13:53 +00:00
Oguz Ulgen
d1947a8707 Migrate from lru_cache to cache (#155613)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155613
Approved by: https://github.com/ezyang
ghstack dependencies: #155612
2025-06-11 19:44:18 +00:00
PyTorch MergeBot
7e4c097b07 Revert "[inductor] Add typing to _inductor/ir.py (#149958)"
This reverts commit 529e0357c6.

Reverted https://github.com/pytorch/pytorch/pull/149958 on behalf of https://github.com/malfet due to Looks like it broke inductor_torchbind tests, due to more graphbreaks, see b0fbbef136/1 ([comment](https://github.com/pytorch/pytorch/pull/149958#issuecomment-2949583209))
2025-06-06 15:19:16 +00:00
Tom Ritchford
529e0357c6 [inductor] Add typing to _inductor/ir.py (#149958)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149958
Approved by: https://github.com/Skylion007
2025-06-06 14:15:01 +00:00
eellison
0827464002 Replace runtime type parameterization (#155221)
See:

```
>>> import timeit; print(f"OrderedSet[str](): {timeit.timeit('OrderedSet[str]()', setup='from torch.utils._ordered_set import OrderedSet', number=1000000):.6f}s, OrderedSet(): {timeit.timeit('OrderedSet()', setup='from torch.utils._ordered_set import OrderedSet', number=1000000):.6f}s")
```
> `OrderedSet[str]()`: 0.354622s, OrderedSet(): 0.095376s

Type parameterization should be on type hint, not in runtime.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155221
Approved by: https://github.com/Skylion007, https://github.com/jansel
2025-06-05 21:43:54 +00:00
Aaron Gokaslan
bbda22e648 [BE][Ez]: Optimize unnecessary lambda with operator (#154722)
Automated edits performed by FURB118. Operator is implemented in C and way faster when passed to another C method like sorted, max etc as a `key=`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154722
Approved by: https://github.com/jansel
2025-05-30 23:47:10 +00:00
leslie-fang-intel
7ba6fb69e6 [Inductor][CPP] Enable vectorized fp8 E5M2 quant dequant (#153365)
**Summary**
This PR enables the vectorization codegen with Inductor CPP backend for `FP8_E5M2` `quant` from `float32` and `dequant` to `float32`.

**Test Plan**
```
python test/inductor/test_cpu_repro.py -k test_dequant_quant_lowering_fp8_e5m2
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153365
Approved by: https://github.com/jansel, https://github.com/jgong5
ghstack dependencies: #152417, #152418, #153364
2025-05-23 23:20:02 +00:00
leslie-fang-intel
b77a6504fa [Inductor][CPP] Enable vectorized fp8 quant dequant (#152418)
**Summary**
This PR enables the vectorization codegen with Inductor CPP backend for `FP8_E4M3` `quant` from `float32` and `dequant` to `float32`.

**Test Plan**
```
python test/inductor/test_cpu_repro.py -k test_dequant_quant_lowering_fp8_e4m3
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152418
Approved by: https://github.com/jansel, https://github.com/jgong5, https://github.com/CaoE
ghstack dependencies: #152417
2025-05-23 23:05:17 +00:00
Benjamin Glass
cda572b053 codecache: Remove cpp_prefix.h duplication per build, then precompile it (#144293)
Prior to this PR, `_inductor/codegen/cpp_prefix.h` was copied into a new temporary directory on every inductor run utilizing the CPP backend (i.e. CPU-only), then included in the output source code. Instead, this PR puts it in an appropriate place in the torch includes, and includes it from there. This allows us to precompile it in cpp_wrapper and AOT inductor mode, saving significant compilation time.

Due to difficulties getting this to work in FBCode, the precompilation itself is only enabled in OSS PyTorch.

Differential Revision: [D69420620](https://our.internmc.facebook.com/intern/diff/D69420620)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144293
Approved by: https://github.com/desertfire
2025-05-16 17:41:36 +00:00
Bin Bao
33a5179269 [AOTI][reland2] Remove typedef for half and bfloat16 (#153467)
Summary:
Reland https://github.com/pytorch/pytorch/pull/151109 after fixing cutlass AOTI build issues.

typedef is prone to name collision. Explicitly spell out the actual aten types, needed for the standalone AOTI codegen.

Differential Revision: D74398762

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153467
Approved by: https://github.com/jingsh, https://github.com/henrylhtsang, https://github.com/cyyever
2025-05-14 02:37:18 +00:00
yuchengliu1
13bdfe6577 get right function declaration on windows inductor (#152939)
Fixes #152251

`get_export_declaration` introduced one more ')' in Windows platform, which cause this pattern of function declaration different with Linux.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152939
Approved by: https://github.com/xuhancn, https://github.com/jansel
2025-05-08 14:28:33 +00:00
Huamin Li
6f6acb4128 [AOTI][CPU] Introduce config.cpp.use_decompose_tanh (#152542)
Summary: Previously D70489427 changed tanh impl to `.tanh()`, and this is causing some meta internal workload perf regression. This diff will introduce a config so we can set it based on need.

Differential Revision: D73909371

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152542
Approved by: https://github.com/desertfire
2025-05-01 10:25:31 +00:00
PyTorch MergeBot
471025c489 Revert "[AOTI][reland] Remove typedef for half and bfloat16 (#151109)"
This reverts commit a0d440a26a.

Reverted https://github.com/pytorch/pytorch/pull/151109 on behalf of https://github.com/wdvr due to causing AOTI test failures - discussed with author ([comment](https://github.com/pytorch/pytorch/pull/151109#issuecomment-2840386483))
2025-04-29 22:37:16 +00:00
Anthony Shoumikhin
e2f9759bd0 Fix broken URLs (#152237)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152237
Approved by: https://github.com/huydhn, https://github.com/malfet
2025-04-27 09:56:42 +00:00
Bin Bao
a0d440a26a [AOTI][reland] Remove typedef for half and bfloat16 (#151109)
Summary: Reland https://github.com/pytorch/pytorch/pull/150657

typedef is prone to name collision. Explicitly spell out the actual aten types, needed for the libtorch-free codegen.

Differential Revision: [D72878456](https://our.internmc.facebook.com/intern/diff/D72878456)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151109
Approved by: https://github.com/angelayi
2025-04-26 23:17:35 +00:00
leslie-fang-intel
68a7501dab [Inductor][CPP] Fix Codegen Issue when Parallel Reduction under the vectorization (#151887)
**Summary**
Fixes [#151290](https://github.com/pytorch/pytorch/issues/151290) and [#151523](https://github.com/pytorch/pytorch/issues/151523), which are regressions introduced by [#144020](https://github.com/pytorch/pytorch/pull/144020). That PR enabled parallelization at the inner loop level.

However, a currently unsupported case arises when parallel reduction occurs under the vectorization loop level, specifically in patterns like:
```
for vec_loop_level:
    do_parallel_reduction
```
In such cases, a temporary buffer `tmp_acc_array` is allocated for tail scalar kernels, and another temporary buffer `tmp_acc_array` is also defined for parallel reduction. This results in a conflict due to overlapping temporary buffers. This PR disables the problematic case to avoid the conflict until proper support is implemented.

**Test Plan**
```
python test/inductor/test_flex_attention.py -k test_make_block_mask_cpu
python test/inductor/test_cpu_repro.py -k test_parallel_reduction_vectorization
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151887
Approved by: https://github.com/jansel
2025-04-23 00:41:14 +00:00
PyTorch MergeBot
31162214d8 Revert "[AOTI] Remove typedef for half and bfloat16 (#150657)"
This reverts commit 357814c85c.

Reverted https://github.com/pytorch/pytorch/pull/150657 on behalf of https://github.com/atalman due to failing internally ([comment](https://github.com/pytorch/pytorch/pull/150657#issuecomment-2795042772))
2025-04-10 20:08:03 +00:00
Bin Bao
357814c85c [AOTI] Remove typedef for half and bfloat16 (#150657)
Summary: typedef is prone to name collision. Explicitly spell out the actual aten types, needed for the libtorch-free codegen.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150657
Approved by: https://github.com/malfet
2025-04-09 21:21:17 +00:00
Sun, Jiayi
5cb5675f13 [Inductor] optimize the heuristics of parallel reduction (#149614)
Fix https://github.com/pytorch/pytorch/issues/148639.

Summary:
Optimize the heuristics of parallel reduction: When the number of steps of the first inner loop beyond the maximum parallel depth is much larger than the number of steps of all outer loops within the maximum parallel depth, change the starting depth of parallelism to the first inner loop and recalculate the maximum parallel depth. I ran the Inductor benchmark with this PR on CPU. A timm model poolformer_m36 BF16 has about 25% performance improvement, and no performance regression is seen.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149614
Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jansel
2025-04-01 01:31:00 +00:00
Ding, Yi1
f7d1b966c2 [Inductor] Unify the data type propagation between Triton and CPP Backend (#146970)
Fixes #144246

Use `DtypePropagationOpsHandler` for CSE variables of CPP backend. In addition, add static type checking for the generated CPP code similar to the `config.test_configs.runtime_triton_dtype_assert`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146970
Approved by: https://github.com/jgong5, https://github.com/eellison, https://github.com/leslie-fang-intel
2025-03-21 17:52:51 +00:00
Rachel Guo
b8f91bcb14 [pt2_provenance_tracking] add support for cpp kernel (#149185)
Summary:
As title.

Add inductor cpp kernel to post grad graph node mapping
& UT.

Context:
Raised as a feature request for AOTI CPU case.

https://fb.workplace.com/groups/1028545332188949/permalink/1169020841474730/

Differential Revision: D71181284

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149185
Approved by: https://github.com/jingsh
2025-03-18 04:43:07 +00:00
Sun, Jiayi
c36ac16da1 [Inductor] optimize welford reduction (#145061)
Fix https://github.com/pytorch/pytorch/issues/141541.
Fix https://github.com/pytorch/pytorch/issues/142839.
Fix https://github.com/pytorch/pytorch/issues/143182.

**Summary:**
In order to fix the issue that the accuracy of welford reduction is not good enough, we refer to the eager implementation, combine Welford algorithm with cascade sum to improve numerical stability. Specifically:
1. Use Welford algorithm to compute mean and variance.
2. Use cascade summation when computing sum over input for both mean and variance.

I tested Inductor benchmark with this PR on CPU, no performance gains or regressions were seen.

**Example:**
Take https://github.com/pytorch/pytorch/issues/141541 as an example:
```
import torch
import torch.nn as nn
torch.manual_seed(0)

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.gn = nn.GroupNorm(num_groups=32, num_channels=32)

    def forward(self, x):
        return self.gn(x)

model = Model().eval()
c_model = torch.compile(model)
x = torch.randn(1, 32, 128, 128, 128)

with torch.no_grad():
    output = model(x)
    c_output = c_model(x)

print(torch.max(torch.abs(output - c_output)))
print(torch.allclose(output, c_output, 1.3e-6, 1e-5))
```
**logs**

- before
```
tensor(7.0095e-05)
False
```
- After
```
tensor(9.5367e-07)
True
```

- on CUDA
```
tensor(1.4305e-06, device='cuda:0', grad_fn=<MaxBackward1>)
True
```

**Generated code:**
- before
```
cpp_fused_native_group_norm_0 = async_compile.cpp_pybinding(['const float*', 'const float*', 'const float*', 'float*', 'float*', 'float*'], '''
#include "/tmp/torchinductor_jiayisun/pi/cpicxudqmdsjh5cm4klbtbrvy2cxwr7whxl3md2zzdjdf3orvfdf.h"
extern "C"  void kernel(const float* in_ptr0,
                       const float* in_ptr1,
                       const float* in_ptr2,
                       float* out_ptr0,
                       float* out_ptr1,
                       float* out_ptr2)
{
    {
        #pragma GCC ivdep
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(32L); x0+=static_cast<int64_t>(1L))
        {
            {
                Welford<float> tmp_acc0 = Welford<float>();
                Welford<at::vec::Vectorized<float>> tmp_acc0_vec = Welford<at::vec::Vectorized<float>>();
                Welford<at::vec::Vectorized<float>> masked_tmp_acc0_vec = Welford<at::vec::Vectorized<float>>();
                static WeightRecp<at::vec::Vectorized<float>> wrecps0(static_cast<int64_t>(131072L));
                for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(2097152L); x1+=static_cast<int64_t>(16L))
                {
                    {
                        if(C10_LIKELY(x1 >= static_cast<int64_t>(0) && x1 < static_cast<int64_t>(2097152L)))
                        {
                            auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x1 + 2097152L*x0), static_cast<int64_t>(16));
                            tmp_acc0_vec = welford_combine(tmp_acc0_vec, tmp0, &wrecps0);
                        }
                    }
                }
                tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(masked_tmp_acc0_vec));
                tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(tmp_acc0_vec));
                out_ptr0[static_cast<int64_t>(x0)] = static_cast<float>(tmp_acc0.mean);
                out_ptr1[static_cast<int64_t>(x0)] = static_cast<float>(tmp_acc0.m2);
            }
        }
    }
    {
        #pragma GCC ivdep
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(32L); x0+=static_cast<int64_t>(1L))
        {
            for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(2097152L); x1+=static_cast<int64_t>(16L))
            {
                {
                    if(C10_LIKELY(x1 >= static_cast<int64_t>(0) && x1 < static_cast<int64_t>(2097152L)))
                    {
                        auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x1 + 2097152L*x0), static_cast<int64_t>(16));
                        auto tmp1 = out_ptr0[static_cast<int64_t>(x0)];
                        auto tmp4 = out_ptr1[static_cast<int64_t>(x0)];
                        auto tmp12 = in_ptr1[static_cast<int64_t>(x0)];
                        auto tmp15 = in_ptr2[static_cast<int64_t>(x0)];
                        auto tmp2 = at::vec::Vectorized<float>(tmp1);
                        auto tmp3 = tmp0 - tmp2;
                        auto tmp5 = static_cast<float>(2097152.0);
                        auto tmp6 = tmp4 / tmp5;
                        auto tmp7 = static_cast<float>(1e-05);
                        auto tmp8 = decltype(tmp6)(tmp6 + tmp7);
                        auto tmp9 = 1 / std::sqrt(tmp8);
                        auto tmp10 = at::vec::Vectorized<float>(tmp9);
                        auto tmp11 = tmp3 * tmp10;
                        auto tmp13 = at::vec::Vectorized<float>(tmp12);
                        auto tmp14 = tmp11 * tmp13;
                        auto tmp16 = at::vec::Vectorized<float>(tmp15);
                        auto tmp17 = tmp14 + tmp16;
                        tmp17.store(out_ptr2 + static_cast<int64_t>(x1 + 2097152L*x0));
                    }
                }
            }
        }
    }
}
''')
```
- After
```
cpp_fused_native_group_norm_0 = async_compile.cpp_pybinding(['const float*', 'const float*', 'const float*', 'float*', 'float*', 'float*'], '''
#include "/tmp/torchinductor_jiayisun/ln/clnlak27xpvmq3klpqyj6xzyq2thf4ecrezve5ddy4f4xaz4sb7w.h"
extern "C"  void kernel(const float* in_ptr0,
                       const float* in_ptr1,
                       const float* in_ptr2,
                       float* out_ptr0,
                       float* out_ptr1,
                       float* out_ptr2)
{
    {
        #pragma GCC ivdep
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(32L); x0+=static_cast<int64_t>(1L))
        {
            {
                Welford<float> tmp_acc0 = Welford<float>();
                Welford<at::vec::Vectorized<float>> tmp_acc0_vec = Welford<at::vec::Vectorized<float>>();
                Welford<at::vec::Vectorized<float>> masked_tmp_acc0_vec = Welford<at::vec::Vectorized<float>>();
                WelfordHelper<at::vec::Vectorized<float>> welford_helper0(static_cast<int64_t>(131072L));
                static WelfordHelper<at::vec::Vectorized<float>> masked_welford_helper0(static_cast<int64_t>(0L));
                for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(2097152L); x1+=static_cast<int64_t>(16L))
                {
                    {
                        if(C10_LIKELY(x1 >= static_cast<int64_t>(0) && x1 < static_cast<int64_t>(2097152L)))
                        {
                            auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x1 + 2097152L*x0), static_cast<int64_t>(16));
                            tmp_acc0_vec = welford_combine(tmp_acc0_vec, tmp0, &welford_helper0);
                        }
                    }
                }
                tmp_acc0_vec = welford_combine(tmp_acc0_vec, &welford_helper0);
                masked_tmp_acc0_vec = welford_combine(masked_tmp_acc0_vec, &masked_welford_helper0);
                tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(masked_tmp_acc0_vec));
                tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(tmp_acc0_vec));
                out_ptr0[static_cast<int64_t>(x0)] = static_cast<float>(tmp_acc0.mean);
                out_ptr1[static_cast<int64_t>(x0)] = static_cast<float>(tmp_acc0.m2);
            }
        }
    }
    {
        #pragma GCC ivdep
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(32L); x0+=static_cast<int64_t>(1L))
        {
            for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(2097152L); x1+=static_cast<int64_t>(16L))
            {
                {
                    if(C10_LIKELY(x1 >= static_cast<int64_t>(0) && x1 < static_cast<int64_t>(2097152L)))
                    {
                        auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x1 + 2097152L*x0), static_cast<int64_t>(16));
                        auto tmp1 = out_ptr0[static_cast<int64_t>(x0)];
                        auto tmp4 = out_ptr1[static_cast<int64_t>(x0)];
                        auto tmp12 = in_ptr1[static_cast<int64_t>(x0)];
                        auto tmp15 = in_ptr2[static_cast<int64_t>(x0)];
                        auto tmp2 = at::vec::Vectorized<float>(tmp1);
                        auto tmp3 = tmp0 - tmp2;
                        auto tmp5 = static_cast<float>(2097152.0);
                        auto tmp6 = tmp4 / tmp5;
                        auto tmp7 = static_cast<float>(1e-05);
                        auto tmp8 = decltype(tmp6)(tmp6 + tmp7);
                        auto tmp9 = 1 / std::sqrt(tmp8);
                        auto tmp10 = at::vec::Vectorized<float>(tmp9);
                        auto tmp11 = tmp3 * tmp10;
                        auto tmp13 = at::vec::Vectorized<float>(tmp12);
                        auto tmp14 = tmp11 * tmp13;
                        auto tmp16 = at::vec::Vectorized<float>(tmp15);
                        auto tmp17 = tmp14 + tmp16;
                        tmp17.store(out_ptr2 + static_cast<int64_t>(x1 + 2097152L*x0));
                    }
                }
            }
        }
    }
}
''')
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145061
Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5, https://github.com/jansel
2025-03-18 02:05:35 +00:00
Jason Ansel
b040dc3a53 Reland: [inductor] Simplify grid handling (#148305)
Summary:
Relands D69965761 / https://github.com/pytorch/pytorch/pull/147583

Before this PR, calling a triton kernel would look like:
```py
kernel.run(a, b, xnumel, grid=grid(xnumel), stream=stream0)
```
where the `grid=` was passed as a callable (function closure) arg.  This PR removes the grid arg:
```py
kernel.run(a, b, xnumel, stream=stream0)
```
instead now the grid computation is included in the kernel launcher, with something like:
```py
def launcher(in_ptr0, out_ptr0, xnumel, stream):
    grid_0 = ((xnumel + 1023) >> 10)
    grid_1 = 1
    grid_2 = 1
    runner(grid_0, grid_1, grid_2, stream, function, metadata, None, launch_enter_hook, launch_exit_hook, in_ptr0, out_ptr0, xnumel)
```

This should be faster, since we remove multiple function/dict calls and are able to specialize the grid computation for each `triton.Config`.

It also allows us to unify the handling of grids between the Python and C++ wrapper code.  Before this, C++ wrapper code didn't actually support dynamic grid sizes and instead burned in a static grid.

This unification allows this PR to be a net deletion of code.

Differential [disconnected] Revision: D70471332

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148305
Approved by: https://github.com/shunting314, https://github.com/eellison
2025-03-12 15:52:16 +00:00
PyTorch MergeBot
5ada4e6a53 Revert "Reland: [inductor] Simplify grid handling (#148305)"
This reverts commit 8d08b49015.

Reverted https://github.com/pytorch/pytorch/pull/148305 on behalf of https://github.com/jithunnair-amd due to Broke ROCm CI ([comment](https://github.com/pytorch/pytorch/pull/148305#issuecomment-2718177044))
2025-03-12 14:58:43 +00:00
leslie-fang-intel
f349304c08 [Inductor][CPP] Fix expr issue in loop split (#148882)
**Summary**
Fix issue: https://github.com/pytorch/pytorch/issues/148058. In this case, there is an `indexing_expr` as an integer which doesn't have the method of `find`.

**Test Plan**
```
python -u -m pytest -s -v test/inductor/test_cpu_repro.py -k test_issue_148058
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148882
Approved by: https://github.com/jgong5
2025-03-12 11:08:07 +00:00
Jason Ansel
8d08b49015 Reland: [inductor] Simplify grid handling (#148305)
Summary:
Relands D69965761 / https://github.com/pytorch/pytorch/pull/147583

Before this PR, calling a triton kernel would look like:
```py
kernel.run(a, b, xnumel, grid=grid(xnumel), stream=stream0)
```
where the `grid=` was passed as a callable (function closure) arg.  This PR removes the grid arg:
```py
kernel.run(a, b, xnumel, stream=stream0)
```
instead now the grid computation is included in the kernel launcher, with something like:
```py
def launcher(in_ptr0, out_ptr0, xnumel, stream):
    grid_0 = ((xnumel + 1023) >> 10)
    grid_1 = 1
    grid_2 = 1
    runner(grid_0, grid_1, grid_2, stream, function, metadata, None, launch_enter_hook, launch_exit_hook, in_ptr0, out_ptr0, xnumel)
```

This should be faster, since we remove multiple function/dict calls and are able to specialize the grid computation for each `triton.Config`.

It also allows us to unify the handling of grids between the Python and C++ wrapper code.  Before this, C++ wrapper code didn't actually support dynamic grid sizes and instead burned in a static grid.

This unification allows this PR to be a net deletion of code.

Differential Revision: D70471332

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148305
Approved by: https://github.com/shunting314, https://github.com/eellison
2025-03-11 18:51:06 +00:00
Aaron Gokaslan
edd640a95a [BE][Ez]: Use itertools.chain.from_iterable when possible (#148190)
Often makes the code more readable, more efficient, and adds support for infinite iterables.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148190
Approved by: https://github.com/jansel, https://github.com/malfet
2025-03-06 20:37:06 +00:00
Benjamin Glass
d6d670ab4d [AOTI] build CPU CPP kernels at O3, and all other code at O1 (#148587)
In the future, we may also want to add LTO linking to further optimize the results (while still hopefully netting compile time benefits).

Differential Revision: [D70641543](https://our.internmc.facebook.com/intern/diff/D70641543)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148587
Approved by: https://github.com/desertfire
2025-03-05 22:47:46 +00:00
leslie-fang-intel
165e33531c [Inductor][CPP] Fix the vec codegen for tanh (#148254)
**Summary**
Fix https://github.com/pytorch/pytorch/issues/148241, The previous vectorized code generation for `tanh` used a decomposed implementation, leading to numerical differences that were further amplified by `atan2`. For example, in the given test case after `tanh`, the eager output at `[0,0,11,47]` was `-5.820766091346741e-10`, while the compiled output was `1.4319084584712982e-08`, resulting in different `atan2` outputs of `-2.3561` and `0.7853`. This issue is fixed by switching to the Sleef implementation.

**Test Plan**
```
python -u -m pytest -s -v test/inductor/test_cpu_repro.py -k test_tanh_atan2
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148254
Approved by: https://github.com/malfet, https://github.com/jgong5
2025-03-03 11:46:57 +00:00
PyTorch MergeBot
608377d341 Revert "[import][inductor] Simplify grid handling (#147583)"
This reverts commit b59776d857.

Reverted https://github.com/pytorch/pytorch/pull/147583 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/147583#issuecomment-2693016036))
2025-03-03 00:49:32 +00:00
Jason Ansel
b59776d857 [import][inductor] Simplify grid handling (#147583)
Before this PR, calling a triton kernel would look like:
```py
kernel.run(a, b, xnumel, grid=grid(xnumel), stream=stream0)
```
where the `grid=` was passed as a callable (function closure) arg.  This PR removes the grid arg:
```py
kernel.run(a, b, xnumel, stream=stream0)
```
instead now the grid computation is included in the kernel launcher, with something like:
```py
def launcher(in_ptr0, out_ptr0, xnumel, stream):
    grid_0 = ((xnumel + 1023) >> 10)
    grid_1 = 1
    grid_2 = 1
    runner(grid_0, grid_1, grid_2, stream, function, metadata, None, launch_enter_hook, launch_exit_hook, in_ptr0, out_ptr0, xnumel)
```

This should be faster, since we remove multiple function/dict calls and are able to specialize the grid computation for each `triton.Config`.

It also allows us to unify the handling of grids between the Python and C++ wrapper code.  Before this, C++ wrapper code didn't actually support dynamic grid sizes and instead burned in a static grid.

This unification allows this PR to be a net deletion of code.

Note the attached diff contains some minor fbcode-only changes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147583
Approved by: https://github.com/eellison, https://github.com/shunting314
2025-03-02 07:31:07 +00:00
Sun, Jiayi
d23051f29b [Inductor] Support parallel reduction for GroupNorm (#144020)
Summary:
Support parallel reduction for GroupNorm by optimizing the parallelization heuristics: When the range of the first inner loop is much larger than the range of all outer loops, change the starting depth of parallelization to the first inner loop.
I tested the Inductor benchmark with this PR on CPU. One torchbench model(pytorch_CycleGAN_and_pix2pix) achieved ~45% performance improvement, and two diffusion models(Stable Diffusion and Latent Consistency Model(LCM)) achieved ~2% performance improvement.

Example:
```
import torch
import torch.nn as nn

class GN(nn.Module):
    def __init__(self, num_groups, num_channels):
        super(GN, self).__init__()
        self.gn = nn.GroupNorm(num_groups, num_channels)

    def forward(self, x):
        return self.gn(x)

x = torch.randn(2, 64, 168, 168).to(memory_format=torch.channels_last)
m = GN(2, 64).eval()
compiled_m = torch.compile(m)

with torch.no_grad():
    out = compiled_m(x)
```

Generated code:

- Before:
```
cpp_fused_native_group_norm_0 = async_compile.cpp_pybinding(['const float*', 'const float*', 'const float*', 'float*', 'float*', 'float*', 'float*', 'float*'], '''
#include "/tmp/torchinductor_jiayisun/pi/cpicxudqmdsjh5cm4klbtbrvy2cxwr7whxl3md2zzdjdf3orvfdf.h"
extern "C"  void kernel(const float* in_ptr0,
                       const float* in_ptr1,
                       const float* in_ptr2,
                       float* out_ptr0,
                       float* out_ptr1,
                       float* out_ptr2,
                       float* out_ptr3,
                       float* out_ptr4)
{
    #pragma omp parallel num_threads(56)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for collapse(2)
            for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(2L); x0+=static_cast<int64_t>(1L))
            {
                for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(2L); x1+=static_cast<int64_t>(1L))
                {
                    {
                        Welford<float> tmp_acc0 = Welford<float>();
                        Welford<at::vec::Vectorized<float>> tmp_acc0_vec = Welford<at::vec::Vectorized<float>>();
                        Welford<at::vec::Vectorized<float>> masked_tmp_acc0_vec = Welford<at::vec::Vectorized<float>>();
                        static WeightRecp<at::vec::Vectorized<float>> wrecps0(static_cast<int64_t>(56448L));
                        for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(28224L); x2+=static_cast<int64_t>(1L))
                        {
                            for(int64_t x3=static_cast<int64_t>(0L); x3<static_cast<int64_t>(32L); x3+=static_cast<int64_t>(16L))
                            {
                                {
                                    if(C10_LIKELY(x3 >= static_cast<int64_t>(0) && x3 < static_cast<int64_t>(32L)))
                                    {
                                        auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x3 + 32L*x1 + 64L*x2 + 1806336L*x0), static_cast<int64_t>(16));
                                        tmp_acc0_vec = welford_combine(tmp_acc0_vec, tmp0, &wrecps0);
                                    }
                                }
                            }
                        }
                        tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(masked_tmp_acc0_vec));
                        tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(tmp_acc0_vec));
                        out_ptr0[static_cast<int64_t>(x1 + 2L*x0)] = static_cast<float>(tmp_acc0.mean);
                        out_ptr1[static_cast<int64_t>(x1 + 2L*x0)] = static_cast<float>(tmp_acc0.m2);
                    }
                }
            }
        }
        #pragma omp single
        {
            {
                #pragma GCC ivdep
                for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(2L); x0+=static_cast<int64_t>(1L))
                {
                    #pragma GCC ivdep
                    for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(2L); x1+=static_cast<int64_t>(1L))
                    {
                        for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(32L); x2+=static_cast<int64_t>(16L))
                        {
                            {
                                if(C10_LIKELY(x2 >= static_cast<int64_t>(0) && x2 < static_cast<int64_t>(32L)))
                                {
                                    auto tmp0 = out_ptr1[static_cast<int64_t>(x1 + 2L*x0)];
                                    auto tmp6 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<int64_t>(x2 + 32L*x1), static_cast<int64_t>(16));
                                    auto tmp9 = out_ptr0[static_cast<int64_t>(x1 + 2L*x0)];
                                    auto tmp13 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<int64_t>(x2 + 32L*x1), static_cast<int64_t>(16));
                                    auto tmp1 = static_cast<float>(903168.0);
                                    auto tmp2 = tmp0 / tmp1;
                                    auto tmp3 = static_cast<float>(1e-05);
                                    auto tmp4 = decltype(tmp2)(tmp2 + tmp3);
                                    auto tmp5 = 1 / std::sqrt(tmp4);
                                    auto tmp7 = at::vec::Vectorized<float>(tmp5);
                                    auto tmp8 = tmp7 * tmp6;
                                    auto tmp10 = decltype(tmp9)(-tmp9);
                                    auto tmp11 = at::vec::Vectorized<float>(tmp10);
                                    auto tmp12 = tmp11 * tmp8;
                                    auto tmp14 = tmp12 + tmp13;
                                    tmp8.store(out_ptr2 + static_cast<int64_t>(x2 + 32L*x1 + 64L*x0));
                                    tmp14.store(out_ptr3 + static_cast<int64_t>(x2 + 32L*x1 + 64L*x0));
                                }
                            }
                        }
                    }
                }
            }
        }
        {
            #pragma omp for collapse(2)
            for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(2L); x0+=static_cast<int64_t>(1L))
            {
                for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(28224L); x1+=static_cast<int64_t>(1L))
                {
                    for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(64L); x2+=static_cast<int64_t>(16L))
                    {
                        {
                            if(C10_LIKELY(x2 >= static_cast<int64_t>(0) && x2 < static_cast<int64_t>(64L)))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x2 + 64L*x1 + 1806336L*x0), static_cast<int64_t>(16));
                                auto tmp1 = at::vec::Vectorized<float>::loadu(out_ptr2 + static_cast<int64_t>(x2 + 64L*x0), static_cast<int64_t>(16));
                                auto tmp3 = at::vec::Vectorized<float>::loadu(out_ptr3 + static_cast<int64_t>(x2 + 64L*x0), static_cast<int64_t>(16));
                                auto tmp2 = tmp0 * tmp1;
                                auto tmp4 = tmp2 + tmp3;
                                tmp4.store(out_ptr4 + static_cast<int64_t>(x2 + 64L*x1 + 1806336L*x0));
                            }
                        }
                    }
                }
            }
        }
    }
}
''')
```

- After:
```
cpp_fused_native_group_norm_0 = async_compile.cpp_pybinding(['const float*', 'const float*', 'const float*', 'float*', 'float*', 'float*', 'float*', 'float*'], '''
#include "/tmp/torchinductor_jiayisun/pi/cpicxudqmdsjh5cm4klbtbrvy2cxwr7whxl3md2zzdjdf3orvfdf.h"
extern "C"  void kernel(const float* in_ptr0,
                       const float* in_ptr1,
                       const float* in_ptr2,
                       float* out_ptr0,
                       float* out_ptr1,
                       float* out_ptr2,
                       float* out_ptr3,
                       float* out_ptr4)
{
    {
        #pragma GCC ivdep
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(2L); x0+=static_cast<int64_t>(1L))
        {
            #pragma GCC ivdep
            for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(2L); x1+=static_cast<int64_t>(1L))
            {
                {
                    Welford<float> tmp_acc0 = Welford<float>();
                    Welford<at::vec::Vectorized<float>> tmp_acc0_vec = Welford<at::vec::Vectorized<float>>();
                    Welford<at::vec::Vectorized<float>> masked_tmp_acc0_vec = Welford<at::vec::Vectorized<float>>();
                    Welford<at::vec::Vectorized<float>> tmp_acc0_vec_arr[56];
                    for (int i = 0; i < 56; i++)
                    {
                        tmp_acc0_vec_arr[i] = Welford<at::vec::Vectorized<float>>();
                    }
                    Welford<float> tmp_acc0_arr[56];
                    for (int i = 0; i < 56; i++)
                    {
                        tmp_acc0_arr[i] = Welford<float>();
                    }
                    Welford<at::vec::Vectorized<float>> masked_tmp_acc0_vec_arr[56];
                    for (int i = 0; i < 56; i++)
                    {
                        masked_tmp_acc0_vec_arr[i] = Welford<at::vec::Vectorized<float>>();
                    }
                    #pragma omp parallel num_threads(56)
                    {
                        int tid = omp_get_thread_num();
                        static WeightRecp<at::vec::Vectorized<float>> wrecps0(static_cast<int64_t>(1008L));
                        Welford<at::vec::Vectorized<float>> tmp_acc0_vec_local = Welford<at::vec::Vectorized<float>>();
                        Welford<float> tmp_acc0_local = Welford<float>();
                        Welford<at::vec::Vectorized<float>> masked_tmp_acc0_vec_local = Welford<at::vec::Vectorized<float>>();
                        #pragma omp for
                        for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(28224L); x2+=static_cast<int64_t>(1L))
                        {
                            for(int64_t x3=static_cast<int64_t>(0L); x3<static_cast<int64_t>(32L); x3+=static_cast<int64_t>(16L))
                            {
                                {
                                    if(C10_LIKELY(x3 >= static_cast<int64_t>(0) && x3 < static_cast<int64_t>(32L)))
                                    {
                                        auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x3 + 32L*x1 + 64L*x2 + 1806336L*x0), static_cast<int64_t>(16));
                                        tmp_acc0_vec_local = welford_combine(tmp_acc0_vec_local, tmp0, &wrecps0);
                                    }
                                }
                            }
                        }
                        tmp_acc0_vec_arr[tid] = tmp_acc0_vec_local;
                        tmp_acc0_arr[tid] = tmp_acc0_local;
                        masked_tmp_acc0_vec_arr[tid] = masked_tmp_acc0_vec_local;
                    }
                    for (int tid = 0; tid < 56; tid++)
                    {
                        tmp_acc0_vec = welford_combine(tmp_acc0_vec, tmp_acc0_vec_arr[tid]);
                    }
                    for (int tid = 0; tid < 56; tid++)
                    {
                        tmp_acc0 = welford_combine(tmp_acc0, tmp_acc0_arr[tid]);
                    }
                    for (int tid = 0; tid < 56; tid++)
                    {
                        masked_tmp_acc0_vec = welford_combine(masked_tmp_acc0_vec, masked_tmp_acc0_vec_arr[tid]);
                    }
                    tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(masked_tmp_acc0_vec));
                    tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(tmp_acc0_vec));
                    out_ptr0[static_cast<int64_t>(x1 + 2L*x0)] = static_cast<float>(tmp_acc0.mean);
                    out_ptr1[static_cast<int64_t>(x1 + 2L*x0)] = static_cast<float>(tmp_acc0.m2);
                }
            }
        }
    }
    {
        #pragma GCC ivdep
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(2L); x0+=static_cast<int64_t>(1L))
        {
            #pragma GCC ivdep
            for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(2L); x1+=static_cast<int64_t>(1L))
            {
                for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(32L); x2+=static_cast<int64_t>(16L))
                {
                    {
                        if(C10_LIKELY(x2 >= static_cast<int64_t>(0) && x2 < static_cast<int64_t>(32L)))
                        {
                            auto tmp0 = out_ptr1[static_cast<int64_t>(x1 + 2L*x0)];
                            auto tmp6 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<int64_t>(x2 + 32L*x1), static_cast<int64_t>(16));
                            auto tmp9 = out_ptr0[static_cast<int64_t>(x1 + 2L*x0)];
                            auto tmp13 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<int64_t>(x2 + 32L*x1), static_cast<int64_t>(16));
                            auto tmp1 = static_cast<float>(903168.0);
                            auto tmp2 = tmp0 / tmp1;
                            auto tmp3 = static_cast<float>(1e-05);
                            auto tmp4 = decltype(tmp2)(tmp2 + tmp3);
                            auto tmp5 = 1 / std::sqrt(tmp4);
                            auto tmp7 = at::vec::Vectorized<float>(tmp5);
                            auto tmp8 = tmp7 * tmp6;
                            auto tmp10 = decltype(tmp9)(-tmp9);
                            auto tmp11 = at::vec::Vectorized<float>(tmp10);
                            auto tmp12 = tmp11 * tmp8;
                            auto tmp14 = tmp12 + tmp13;
                            tmp8.store(out_ptr2 + static_cast<int64_t>(x2 + 32L*x1 + 64L*x0));
                            tmp14.store(out_ptr3 + static_cast<int64_t>(x2 + 32L*x1 + 64L*x0));
                        }
                    }
                }
            }
        }
    }
    #pragma omp parallel num_threads(56)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for collapse(2)
            for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(2L); x0+=static_cast<int64_t>(1L))
            {
                for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(28224L); x1+=static_cast<int64_t>(1L))
                {
                    for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(64L); x2+=static_cast<int64_t>(16L))
                    {
                        {
                            if(C10_LIKELY(x2 >= static_cast<int64_t>(0) && x2 < static_cast<int64_t>(64L)))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x2 + 64L*x1 + 1806336L*x0), static_cast<int64_t>(16));
                                auto tmp1 = at::vec::Vectorized<float>::loadu(out_ptr2 + static_cast<int64_t>(x2 + 64L*x0), static_cast<int64_t>(16));
                                auto tmp3 = at::vec::Vectorized<float>::loadu(out_ptr3 + static_cast<int64_t>(x2 + 64L*x0), static_cast<int64_t>(16));
                                auto tmp2 = tmp0 * tmp1;
                                auto tmp4 = tmp2 + tmp3;
                                tmp4.store(out_ptr4 + static_cast<int64_t>(x2 + 64L*x1 + 1806336L*x0));
                            }
                        }
                    }
                }
            }
        }
    }
}
''')
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144020
Approved by: https://github.com/leslie-fang-intel, https://github.com/jansel, https://github.com/jgong5
2025-03-01 17:11:50 +00:00