Commit Graph

292 Commits

Author SHA1 Message Date
Edward Z. Yang
6f70d22277 Extend torch.utils._sympy.symbol for more Inductor symbols (#125419)
I'm still missing a few, cdzq at least

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125419
Approved by: https://github.com/lezcano
ghstack dependencies: #125395
2024-05-04 09:05:00 +00:00
haozhe.zhu
57790fd088 [inductor] share cse cache during vectorized indirect load (#124597)
Fix https://github.com/pytorch/pytorch/issues/123502

`swap_buffer` in not needed in vectorized indirect load, remove it to share cse buffer.
```
auto tmp8 =
[&]
{
    __at_align__ std::array<int64_t, 16> tmpbuf;
    tmp7.store(tmpbuf.data());
    return tmpbuf;
}
()
;
//
// other codes
//
// also store tmp7 here (redundant tmp16)
auto tmp16 =
[&]
{
    __at_align__ std::array<int64_t, 16> tmpbuf;
    tmp7.store(tmpbuf.data());
    return tmpbuf;
}
()
;
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124597
Approved by: https://github.com/jgong5, https://github.com/jansel
2024-04-28 01:02:48 +00:00
leslie-fang-intel
2d7f709752 [Inductor] Force the parallel depth as outer loop fusion depth (#123899)
**Summary**
Fix issue: https://github.com/pytorch/pytorch/issues/123801 which brings performance regression of `pyhpc_turbulent_kinetic_energy` after outer loop fusion.

**Root Cause**

- [Generated Kernel before Outer Loop Fusion](https://gist.github.com/leslie-fang-intel/54fe21ac8871fc63b9bf20fdb6edf209)
  - Taking below 2 kernels as example:
    - [Kernel 0](https://gist.github.com/leslie-fang-intel/54fe21ac8871fc63b9bf20fdb6edf209#file-pyhpc_turbulent_kinetic_energy-before-outer-loop-fusion-py-L255-L305) has 2 loop levels with size [200, 200]. Parallelization is not feasible due to the inefficient number of elements determined by [`decide_parallel_depth`](aaec97a403/torch/_inductor/codegen/cpp.py (L2145-L2164)). Therefore, the loop code will be generated with the `#pragma omp single` directive.
    - [Kernel 1](https://gist.github.com/leslie-fang-intel/54fe21ac8871fc63b9bf20fdb6edf209#file-pyhpc_turbulent_kinetic_energy-before-outer-loop-fusion-py-L306-L316) has 3 loop levels with size [200, 200, 26] which has enough number of elements to be parallelized.
- [Generated Kernel after Outer Loop Fusion](https://gist.github.com/leslie-fang-intel/57a497b9d9c6aa82b1c6a686292fc887)
  - After outer loop fusion, `Kernel0` and `Kernel1` has been fused into one [OuterLoopFusedKernel](https://gist.github.com/leslie-fang-intel/57a497b9d9c6aa82b1c6a686292fc887#file-pyhpc_turbulent_kinetic_energy-after-outer-loop-fusion-py-L261-L497), the outer loop size is [200, 200] which does not contain enough number of elements to do parallelization.

In this PR, we propose a fix for `loop_nest` involving `OuterLoopFusedKernel`. The fix entails adding a specific heuristic for `OuterLoopFusedKernel` to determine the parallel depth by combining `outer_loop_fusion_depth` with the internal kernels' parallel depth.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123899
Approved by: https://github.com/jgong5, https://github.com/lezcano
2024-04-25 09:50:46 +00:00
leslie-fang-intel
bffecb5aff [Inductor] Enable VecMask store (#123710)
**Summary**
Enable the vectorization of store with `bool` dtype.

**Test Plan**
```
python -u -m pytest -s -v inductor/test_cpu_repro.py -k test_decomposed_fake_quant_per_channel
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123710
Approved by: https://github.com/jgong5, https://github.com/lezcano
ghstack dependencies: #123512
2024-04-23 00:29:47 +00:00
Aaron Gokaslan
29cc293725 [BE]: FURB142 - Remove set mutations. Use set update (#124551)
Uses set mutation methods instead of manually reimplementing (update, set_difference etc).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124551
Approved by: https://github.com/ezyang
2024-04-21 14:12:33 +00:00
Xuehai Pan
93e249969b [BE] enable ruff rule RSE and remove useless parentheses in raise statements (#124261)
Remove useless parentheses in `raise` statements if the exception type is raised with no argument.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124261
Approved by: https://github.com/albanD
2024-04-17 19:29:34 +00:00
Edward Z. Yang
efa36ef092 Natively support int truncation, don't guard on positive/negative (#122827)
This doesn't entirely fix the original problem that prompted this, but
it seems to just be getting stuck in export constraint formatting now
which seems like progress to me.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122827
Approved by: https://github.com/avikchaudhuri
2024-04-11 15:22:32 +00:00
vfdev-5
6b7741546b Fixed arange decomp for float dtype (#121013)
## Description:

- [x] Fixed arange decomp for float dtype
- [x] Added a test

## Current state

Arange graph and C++ generated code are not optimal when arange is created directly using float32 dtype:
```python
import torch

def func(x):
    s = x.shape[-1]
    a = torch.arange(s, dtype=torch.float32)
    return s + a

c_func = torch.compile(func)
out = c_func(torch.rand(10))
```

Graph on `main`:
```
 ===== Forward graph 0 =====
 /pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
    def forward(self):
        # File: check_arange_decomp.py:8 in func, code: a = torch.arange(s, dtype=torch.float32)
        iota: "i64[10]" = torch.ops.prims.iota.default(10, start = 0, step = 1, dtype = torch.int64, device = device(type='cpu'), requires_grad = False)
        convert_element_type: "f64[10]" = torch.ops.prims.convert_element_type.default(iota, torch.float64);  iota = None
        mul: "f64[10]" = torch.ops.aten.mul.Tensor(convert_element_type, 1);  convert_element_type = None
        add: "f64[10]" = torch.ops.aten.add.Tensor(mul, 0);  mul = None
        convert_element_type_1: "f32[10]" = torch.ops.prims.convert_element_type.default(add, torch.float32);  add = None

        # File: check_arange_decomp.py:9 in func, code: return s + a
        add_1: "f32[10]" = torch.ops.aten.add.Tensor(convert_element_type_1, 10);  convert_element_type_1 = None
        return (add_1,)

 ===== AFTER POST GRAD =====
 /pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
    def forward(self):
        # File: check_arange_decomp.py:15 in func, code: a = torch.arange(s, dtype=torch.float32)
        iota: "i64[10]" = torch.ops.prims.iota.default(10, start = 0, step = 1, dtype = torch.int64, device = device(type='cpu'), requires_grad = False)
        convert_element_type: "f64[10]" = torch.ops.prims.convert_element_type.default(iota, torch.float64);  iota = None
        mul: "f64[10]" = torch.ops.aten.mul.Tensor(convert_element_type, 1);  convert_element_type = None
        add: "f64[10]" = torch.ops.aten.add.Tensor(mul, 0);  mul = None
        convert_element_type_1: "f32[10]" = torch.ops.prims.convert_element_type.default(add, torch.float32);  add = None

        # File: check_arange_decomp.py:16 in func, code: return s + a
        add_1: "f32[10]" = torch.ops.aten.add.Tensor(convert_element_type_1, 10);  convert_element_type_1 = None
        return (add_1,)
```
and C++
```c++
extern "C" void kernel(float* out_ptr0)
{
    {
        #pragma GCC ivdep
        for(long x0=static_cast<long>(0L); x0<static_cast<long>(10L); x0+=static_cast<long>(1L))
        {
            auto tmp0 = c10::convert<long>(x0);
            auto tmp1 = c10::convert<double>(tmp0);   // <---- useless ops
            auto tmp2 = static_cast<double>(1.0);     // <----
            auto tmp3 = decltype(tmp1)(tmp1 * tmp2);  // <----
            auto tmp4 = static_cast<double>(0.0);     // <----
            auto tmp5 = decltype(tmp3)(tmp3 + tmp4);  // <----
            auto tmp6 = c10::convert<float>(tmp5);
            auto tmp7 = static_cast<float>(10.0);
            auto tmp8 = decltype(tmp6)(tmp6 + tmp7);
            out_ptr0[static_cast<long>(x0)] = tmp8;
        }
    }
}
```

However, if we manually create arange on i64 and then put to float32, generated graph and C++ code are more natural and benefit of a speed-up.
```python
import torch

def func(x):
    s = x.shape[-1]
    a = torch.arange(s).to(dtype=torch.float32)
    return s + a

c_func = torch.compile(func)
out = c_func(torch.rand(10))
```

Graph on `main`:
```
 ===== Forward graph 0 =====
 /pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
    def forward(self):
        # File: check_arange_decomp.py:14 in func, code: a = torch.arange(s).to(dtype=torch.float32)
        iota: "i64[10]" = torch.ops.prims.iota.default(10, start = 0, step = 1, dtype = torch.int64, device = device(type='cpu'), requires_grad = False)
        convert_element_type: "f32[10]" = torch.ops.prims.convert_element_type.default(iota, torch.float32);  iota = None

        # File: check_arange_decomp.py:15 in func, code: return s + a
        add: "f32[10]" = torch.ops.aten.add.Tensor(convert_element_type, 10);  convert_element_type = None
        return (add,)

 ===== AFTER POST GRAD =====
 /pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
    def forward(self):
        # File: check_arange_decomp.py:21 in func, code: a = torch.arange(s).to(dtype=torch.float32)
        iota: "i64[10]" = torch.ops.prims.iota.default(10, start = 0, step = 1, dtype = torch.int64, device = device(type='cpu'), requires_grad = False)
        convert_element_type: "f32[10]" = torch.ops.prims.convert_element_type.default(iota, torch.float32);  iota = None

        # File: check_arange_decomp.py:22 in func, code: return s + a
        add: "f32[10]" = torch.ops.aten.add.Tensor(convert_element_type, 10);  convert_element_type = None
        return (add,)
```

C++ on `main`
```c++
extern "C" void kernel(float* out_ptr0)
{
    {
        #pragma GCC ivdep
        for(long x0=static_cast<long>(0L); x0<static_cast<long>(10L); x0+=static_cast<long>(1L))
        {
            auto tmp0 = c10::convert<long>(x0);
            auto tmp1 = c10::convert<float>(tmp0);
            auto tmp2 = static_cast<float>(10.0);
            auto tmp3 = decltype(tmp1)(tmp1 + tmp2);
            out_ptr0[static_cast<long>(x0)] = tmp3;
        }
    }
}
```

For example, the speed-up seen on upsample_nearest2d on cpu:
```
[----------------------------------------------------------------------------------------------------------------------------------------------- Interpolate, cpu ----------------------------------------------------------------------------------------------------------------------------------------------]
                                                                                                                                |  Eager (2.3.0a0+gitb4324ed) PR  |  Compiled (2.3.0a0+gitb4324ed) PR  |  Compiled (2.3.0a0+git0d1e705) Nightly  |  speed-up PR vs Nightly  |  Eager (2.3.0a0+git0d1e705) Nightly
1 threads: ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      Input (1, 3, 500, 400), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (256, 256)      |        287.988 (+-10.399)       |         200.034 (+-8.630)          |            285.143 (+-8.412)            |     1.425 (+-0.000)      |          287.991 (+-11.302)
      Input (1, 3, 500, 400), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (256, 256)          |        697.206 (+-27.033)       |         171.650 (+-7.381)          |            193.280 (+-5.840)            |     1.126 (+-0.000)      |          701.642 (+-26.461)
      Input (1, 3, 500, 400), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (256, 256)    |        149.149 (+-6.045)        |         222.780 (+-6.852)          |            299.968 (+-12.354)           |     1.346 (+-0.000)      |          145.055 (+-7.232)
      Input (1, 3, 500, 400), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (256, 256)        |        596.741 (+-27.970)       |         205.923 (+-8.648)          |            233.912 (+-7.742)            |     1.136 (+-0.000)      |          598.000 (+-25.630)
      Input (4, 3, 500, 400), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (256, 256)      |       1095.734 (+-51.658)       |         700.850 (+-24.852)         |           1044.255 (+-38.216)           |     1.490 (+-0.000)      |         1097.977 (+-35.521)
      Input (4, 3, 500, 400), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (256, 256)          |       2741.813 (+-122.917)      |         583.073 (+-16.998)         |            665.029 (+-36.331)           |     1.141 (+-0.000)      |         2722.388 (+-116.263)
      Input (4, 3, 500, 400), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (256, 256)    |        578.183 (+-37.266)       |         833.295 (+-42.264)         |           1131.341 (+-54.710)           |     1.358 (+-0.000)      |          584.953 (+-45.549)
      Input (4, 3, 500, 400), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (256, 256)        |       2332.508 (+-103.556)      |         840.194 (+-47.664)         |            935.625 (+-47.467)           |     1.114 (+-0.000)      |         2334.314 (+-91.644)
      Input (1, 3, 1200, 1300), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (200, 300)    |        272.631 (+-11.348)       |         195.988 (+-5.748)          |            274.021 (+-9.475)            |     1.398 (+-0.000)      |          272.752 (+-12.716)
      Input (1, 3, 1200, 1300), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (200, 300)        |        640.409 (+-25.465)       |         164.773 (+-7.372)          |            185.018 (+-8.349)            |     1.123 (+-0.000)      |          639.390 (+-30.761)
      Input (1, 3, 1200, 1300), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (200, 300)  |        158.602 (+-6.593)        |         220.478 (+-6.809)          |            286.376 (+-8.981)            |     1.299 (+-0.000)      |          158.557 (+-6.143)
      Input (1, 3, 1200, 1300), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (200, 300)      |        548.903 (+-22.889)       |         202.788 (+-9.158)          |            227.404 (+-8.995)            |     1.121 (+-0.000)      |          554.096 (+-21.330)
      Input (4, 3, 1200, 1300), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (200, 300)    |       1036.061 (+-35.285)       |         680.728 (+-30.925)         |            986.254 (+-42.732)           |     1.449 (+-0.000)      |         1038.718 (+-43.070)
      Input (4, 3, 1200, 1300), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (200, 300)        |       2504.520 (+-125.805)      |         550.067 (+-21.383)         |            628.000 (+-27.589)           |     1.142 (+-0.000)      |         2523.134 (+-113.336)
      Input (4, 3, 1200, 1300), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (200, 300)  |       1058.188 (+-57.853)       |        1216.427 (+-76.160)         |           1380.231 (+-98.939)           |     1.135 (+-0.000)      |         1057.031 (+-66.075)
      Input (4, 3, 1200, 1300), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (200, 300)      |       2305.911 (+-116.864)      |        1080.189 (+-79.934)         |           1141.561 (+-67.959)           |     1.057 (+-0.000)      |         2306.606 (+-121.544)
      Input (1, 3, 300, 400), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (600, 700)      |       1689.489 (+-60.579)       |        1077.401 (+-44.948)         |           1634.264 (+-64.340)           |     1.517 (+-0.000)      |         1693.945 (+-67.998)
      Input (1, 3, 300, 400), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (600, 700)          |       4198.368 (+-179.096)      |         886.656 (+-30.355)         |           1028.568 (+-46.310)           |     1.160 (+-0.000)      |         4174.351 (+-141.020)
      Input (1, 3, 300, 400), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (600, 700)    |        716.572 (+-51.954)       |        1175.864 (+-52.191)         |           1674.373 (+-51.815)           |     1.424 (+-0.000)      |          715.724 (+-41.104)
      Input (1, 3, 300, 400), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (600, 700)        |       3604.989 (+-132.489)      |        1096.933 (+-54.290)         |           1270.347 (+-60.932)           |     1.158 (+-0.000)      |         3601.864 (+-140.218)
      Input (4, 3, 300, 400), torch.uint8, torch.contiguous_format | mode: nearest, align_corners: None, osize: (600, 700)      |       6721.610 (+-355.997)      |        4203.213 (+-134.362)        |           6423.763 (+-225.311)          |     1.528 (+-0.000)      |         6715.626 (+-288.233)
      Input (4, 3, 300, 400), torch.uint8, torch.channels_last | mode: nearest, align_corners: None, osize: (600, 700)          |      16695.467 (+-709.620)      |        3460.013 (+-149.456)        |           4001.810 (+-218.093)          |     1.157 (+-0.000)      |        16621.138 (+-713.320)
      Input (4, 3, 300, 400), torch.float32, torch.contiguous_format | mode: nearest, align_corners: None, osize: (600, 700)    |       3020.017 (+-147.314)      |        4743.164 (+-135.850)        |           6709.494 (+-281.025)          |     1.415 (+-0.000)      |         3015.602 (+-105.852)
      Input (4, 3, 300, 400), torch.float32, torch.channels_last | mode: nearest, align_corners: None, osize: (600, 700)        |      14456.688 (+-752.839)      |        5150.893 (+-201.571)        |           5737.315 (+-138.011)          |     1.114 (+-0.000)      |        14464.472 (+-720.027)

Times are in microseconds (us).
```

## PR

This PR improves arange decomp such that `arange(s, dtype=torch.float32)` removing extra dtype conversion to double:

Code:
```python
import torch

def func(x):
    s = x.shape[-1]
    a = torch.arange(s, dtype=torch.float32)
    return s + a

c_func = torch.compile(func)
out = c_func(torch.rand(10))
```

Graph on this PR:
```
 ===== Forward graph 0 =====
 /pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
    def forward(self):
        # File: check_arange_decomp.py:15 in func, code: a = torch.arange(s, dtype=torch.float32)
        iota: "i64[10]" = torch.ops.prims.iota.default(10, start = 0, step = 1, dtype = torch.int64, device = device(type='cpu'), requires_grad = False)
        mul: "i64[10]" = torch.ops.aten.mul.Tensor(iota, 1);  iota = None
        add: "i64[10]" = torch.ops.aten.add.Tensor(mul, 0);  mul = None
        convert_element_type: "f32[10]" = torch.ops.prims.convert_element_type.default(add, torch.float32);  add = None

        # File: check_arange_decomp.py:16 in func, code: return s + a
        add_1: "f32[10]" = torch.ops.aten.add.Tensor(convert_element_type, 10);  convert_element_type = None
        return (add_1,)

 ===== AFTER POST GRAD =====
 /pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
    def forward(self):
        # File: check_arange_decomp.py:16 in func, code: a = torch.arange(s, dtype=torch.float32)
        iota: "i64[10]" = torch.ops.prims.iota.default(10, start = 0, step = 1, dtype = torch.int64, device = device(type='cpu'), requires_grad = False)
        mul: "i64[10]" = torch.ops.aten.mul.Tensor(iota, 1);  iota = None
        add: "i64[10]" = torch.ops.aten.add.Tensor(mul, 0);  mul = None
        convert_element_type: "f32[10]" = torch.ops.prims.convert_element_type.default(add, torch.float32);  add = None

        # File: check_arange_decomp.py:17 in func, code: return s + a
        add_1: "f32[10]" = torch.ops.aten.add.Tensor(convert_element_type, 10);  convert_element_type = None
        return (add_1,)
```
and C++ on this PR:
```c++
extern "C" void kernel(float* out_ptr0)
{
    {
        #pragma GCC ivdep
        for(long x0=static_cast<long>(0L); x0<static_cast<long>(10L); x0+=static_cast<long>(1L))
        {
            auto tmp0 = c10::convert<long>(x0);
            auto tmp1 = c10::convert<float>(tmp0);
            auto tmp2 = static_cast<float>(10.0);
            auto tmp3 = decltype(tmp1)(tmp1 + tmp2);
            out_ptr0[static_cast<long>(x0)] = tmp3;
        }
    }
}
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121013
Approved by: https://github.com/peterbell10
2024-04-11 09:02:31 +00:00
Jiong Gong
cacc8e27a5 [inductor][cpp] refactor code to use define_kernel and call_kernel similar to CUDA (#123704)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123704
Approved by: https://github.com/jansel, https://github.com/desertfire
2024-04-11 06:34:44 +00:00
leslie-fang-intel
9078191666 [Inductor] Add the possible fusions group by priority (#123067)
**Summary**

Refactor the `Scheduler.fuse_nodes` changes in https://github.com/pytorch/pytorch/pull/121625. In the previous implementation of `Scheduler.fuse_nodes` in https://github.com/pytorch/pytorch/pull/121625, we use the `enable_outer_loop_fusion` context to ensure `OuterLoopFusion` happens after all the norm fusions.

And there is a discussion in https://github.com/pytorch/pytorch/pull/121625/files#r1527177141 to reuse current `score_fusion` mechanism. However, given that [fuse_nodes](f4ff063c33/torch/_inductor/scheduler.py (L1679-L1698)) will invoke `fuse_nodes_once` 10 times. We are concerned that the score approach may potentially disrupt pairs of regular fusion nodes in the 2rd invocation of `fuse_nodes_once` if they have been pick up by the outer loop fusion in the 1st invocation of `fuse_nodes_once`.

In this PR, we propose adding an abstract of `filter_possible_fusions_by_priority`. In each invoking of `fuse_nodes_once`, the possible fusions will be grouped by their priority from the backend. And only the group of possible fusions with highest priority will be fused in this invocation. In this way, we can ensure `OuterLoopFusion` happens after all the norm fusions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123067
Approved by: https://github.com/lezcano, https://github.com/jgong5
ghstack dependencies: #121625
2024-04-05 06:30:41 +00:00
leslie-fang-intel
bac2a39aee [Inductor] [ReImplement] Outer Loop Fusion for CPP Backend (#121625)
**Summary**
Re-implement of https://github.com/pytorch/pytorch/pull/121064

**Test Plan**
```
python -u -m pytest -s -v test_cpu_repro.py -k test_outer_loop_fusion
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121625
Approved by: https://github.com/lezcano, https://github.com/jgong5
2024-04-05 06:24:57 +00:00
Gao Tianlin
aaef246c74 remove log2 decomposition; add log2 lowering (#123112)
Same reason as `log10`. `log2` is a core aten op, we should not decompose it. As https://github.com/pytorch/pytorch/pull/110882 suggested, it often maps to a hardware intrinsic; Furthermore, decomposing it will negatively impact the numerical precision of the output.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123112
Approved by: https://github.com/peterbell10
2024-04-02 16:16:26 +00:00
Jiong Gong
6f4ed57b8a [inductor][cpp] unified the vectorized conversion with at::vec::convert for all data types (#119979)
This PR unified the vectorized conversion with `at::vec::convert` for all vectorized data types. The intrinsics implementations are implemented as a specialization and moved to their own arch-specific files. The vectorized conversion logic in cpp Inductor is simplified.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119979
Approved by: https://github.com/jansel, https://github.com/malfet
2024-03-29 21:48:29 +00:00
Edward Z. Yang
3178ba0dc9 Don't use sympy Float functions, use an opaque one with no reasoning (#122823)
Sympy simplifications don't obey floating point semantics, so don't
use Sympy for this.  Keep them as is, only evaluate with the reference
implementations when all arguments are known.

This may end up getting subsumed by some other changes later, but I
wanted to understand if this was easy and it seems to be easy.

This doesn't actually depend on the earlier diffs on the stack and I can detach it.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122823
Approved by: https://github.com/lezcano
2024-03-29 19:13:55 +00:00
Jiong Gong
105381ea11 [inductor][cpp] simplify CppVecKernelChecker (remove bool/int8 load as mask and load as float flags) (#119734)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119734
Approved by: https://github.com/leslie-fang-intel, https://github.com/jansel
ghstack dependencies: #119654, #119655
2024-03-27 11:20:35 +00:00
Jiong Gong
49121603ab [inductor][cpp] support vectorized indirect indexing (#119655)
This PR adds the vectorized indirect indexing so that we can further simplify the `CppVecKernelChecker` (done in the later PR #119734) and remove the check that throws `CppVecUnsupportedError`. A boundary assertion check is added on vectorized indices and via the new `indirect_assert` method on `Kernel` - the base implementation is for scalar indices, overridden in `CppVecKernel` for vectorized indices.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119655
Approved by: https://github.com/jansel
ghstack dependencies: #119654
2024-03-27 10:25:45 +00:00
Jiong Gong
367ec62ae3 [inductor][cpp] generalize vector mask for dtypes (#119654)
Vectorized boolean values in CPU Inductor were modeled with `Vectorized<float>` which cannot work for operations with other data types. This PR generalizes it with the new `VecMask` template class that can work for masks on any vectorized data types. The intrinsics implementation in `cpp_prefix.h` for mask conversion, cast and masked load are now implemented as the specialization for `VecMask` and moved to corresponding header files.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119654
Approved by: https://github.com/leslie-fang-intel, https://github.com/jansel
2024-03-27 05:33:53 +00:00
Nikita Shulga
dd3f2cb53a [Inductor] Add NEON ISA support on arm64 Macs (#122217)
This started as a re-land of https://github.com/pytorch/pytorch/pull/105590 but focusing on enabling it on MacOS, but quickly turned into landing very limited platform-specific acceleration at this time (I.e. this PR does not add any NEON accelerated code at all, just enables vectorized compilation for the existing abstractions)

Enabling the test harness, uncovered number of latent issues in CPU inductor that were fixed in the following PRS:
- https://github.com/pytorch/pytorch/pull/122511
- https://github.com/pytorch/pytorch/pull/122513
- https://github.com/pytorch/pytorch/pull/122580
- https://github.com/pytorch/pytorch/pull/122608

Following was added/changed to enable vectorization code to work on MacOS
 - Added VecNEON class to `_inductor/codecache.py`  that is supported on all AppleSilicon Macs
 - Added `Vectorized::loadu_one_fourth` to `vec_base.h`, and limit it to 8-bit types
 - Change 64-bit integral types mapping to `int64_t`/`uint64_t` to align with the rest of the code, as on MacOS, `int64_t` is a `long long` rather than `long` (see https://github.com/pytorch/pytorch/pull/118149 for more details)

See table below for perf changes with and without torch.compile using [gpt-fast](https://github.com/pytorch-labs/gpt-fast) running `stories15M` on M2 Pro:
| dtype  | Eager | Compile (before) | Compile (after) |
| ------ | ------ | --------- | --------- |
| bfloat16  | 120 tokens/sec  | 130 tokens/sec | 156 tokens/sec |
| float32  | 158 tokens/sec  | 140 tokens/sec | 236 tokens/sec |
| float16  | 235 tokens/sec  | 81 tokens/sec | 58 tokens/sec |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122217
Approved by: https://github.com/jansel
2024-03-26 05:07:30 +00:00
Nikita Shulga
cf06189a2d [CPPInductor] Fix another out-of-bounds access (#122580)
Not sure what was the idea behind `{self.tiling_factor}*sizeof(float)/sizeof({DTYPE_TO_CPP[dtype]})` size calculation (perhaps copy-n-paste error during the refactor made by https://github.com/pytorch/pytorch/pull/97626  ) , but `Vectorized::store(ptr, tiling_factor)` needs at least `tiling_factor` elements, not `tiling_factor/2` (which would be the case with the original calculation if data type is 64-bit value such as int64)
Discovered while trying to enable arch64 vectorized inductor.
Minimal reproducer (reproducible on ARMv8 or any  x86_64 machine that does not support AVX512):
```python
import torch
def do_ds(x, y):
    return torch.diagonal_scatter(x, y)

x=torch.ones(10, 10, dtype=torch.int64)
y=torch.tensor([ 1,  2, -8,  8,  5,  5, -7, -8,  7,  0])
dsc = torch.compile(do_ds)
assert torch.allclose(torch.diagonal_scatter(x, y), dsc(x, y))
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122580
Approved by: https://github.com/Skylion007, https://github.com/jansel
2024-03-25 04:49:20 +00:00
vfdev-5
cdc7f0fd3b Fixed failing pyhpc_equation_of_state due to cpp nodes fusion with compatible ranges (#122420)
Fixes #122283

Description:

PR https://github.com/pytorch/pytorch/pull/120077 introduced cpp nodes fusion with compatible ranges with an assumption that all scheduler nodes inside the fused nodes are the same, however, it appeared that snodes can have different indexing expressions. This PR fixes the incorrect assumption.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122420
Approved by: https://github.com/lezcano
2024-03-24 00:40:31 +00:00
Adnan Akhundov
456b112dca [inductor] Support non-Tensor predicate in torch.cond (#122378)
Summary: Previously, we only supported torch.Tensor boolean scalar predicate in `torch.cond` in Inductor. This PR adds support for SymBool and Python bool predicate, to match the `torch.cond` [sematics](https://pytorch.org/docs/stable/generated/torch.cond.html) in Dynamo / Export.

Test Plan:

```
$ python test/inductor/test_control_flow.py
...
----------------------------------------------------------------------
Ran 34 tests in 56.980s

OK

$ python test/inductor/test_aot_inductor.py -k test_cond
...
----------------------------------------------------------------------
Ran 54 tests in 460.093s

OK (skipped=4)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122378
Approved by: https://github.com/jansel, https://github.com/chenyang78
2024-03-21 14:35:01 +00:00
haozhe.zhu
3bc2bb6781 use two pass reduction for deterministic reduction order (#115620)
## Motivation
Address the [non-deterministic reduction order](https://github.com/pytorch/pytorch/issues/93542#issuecomment-1411294181) issue for `omp parallel reduction`.

## Latest update on 1.15:
55d81901bc.
Do not reduce to arr in loops. Instead, reduce to a local scaler and write it to arr after local reduction is done. This will allow the compiler to optimize the reduction variable in register instead read/write from memory. If the `working set` of `loop body` is quite large, `read/write from register/memory` will have a large gap.
```
vaddss (%xmm0, %xmm11, %xmm11) -> accumulate in register %xmm0
vaddssl ((%rdx, %rdi, 4), %xmm0, %xmm0) -> accumulate in memory address (%rdx, %rdi, 4)
```
Examples code:
```
tmp0_acc_arr[64];
#pragma omp parallel num_threads(64)
{
    auto tid = omp_get_thread_num();
    #pragma omp for
    for(...){
        ....
        tmp0_acc_arr[tid] = tmp0_acc_arr[tid] + tmp_x;  // access array will always from memory
    }
}
```
will be changed to
```
tmp0_acc_arr[64];
#pragma omp parallel num_threads(64)
{
    auto tid = omp_get_thread_num();
    **auto tmp0_acc_local = 0;**
    #pragma omp for
    for(...){
        ....
        **tmp0_acc_local**  = tmp0_acc_local + tmp_x;
    }
    **tmp0_acc_arr[tid] = tmp0_acc_local;**
}
```

## Descriptions
Following aten to use `two pass reduction` with `omp parallel` for deterministic reduction order.
9c3ae37fc4/aten/src/ATen/Parallel-inl.h (L39)
9c3ae37fc4/aten/src/ATen/native/TensorIteratorReduce.cpp (L24)
```
            float tmp_acc0 = 0;
            at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0);
            // init reduction buffer per thread
            float tmp_acc0_arr[64];
            at::vec::Vectorized<float> tmp_acc0_vec_arr[64];
            for (int tid = 0; tid < 64; tid++)
            {
                tmp_acc0_arr[tid] = 0;
                tmp_acc0_vec_arr[tid] = at::vec::Vectorized<float>(0);
            }
            #pragma omp parallel num_threads(64)
            {
                int tid = omp_get_thread_num();
                #pragma omp for
                for(long x0=static_cast<long>(0L); x0<static_cast<long>(3964928L); x0+=static_cast<long>(16L))
                {
                    auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x0));
                    auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<long>(x0));
                    auto tmp2 = tmp0 - tmp1;
                    auto tmp3 = tmp2 * tmp2;
                    // reduce to per thread buffers
                    tmp_acc0_vec_arr[tid] = tmp_acc0_vec_arr[tid] + tmp3;
                }
            }
            // second pass reduce
            for (int tid = 0; tid < 64; tid++)
            {
                tmp_acc0 = tmp_acc0 + tmp_acc0_arr[tid];
                tmp_acc0_vec = tmp_acc0_vec + tmp_acc0_vec_arr[tid];
            }
            tmp_acc0 = tmp_acc0 + at::vec::vec_reduce_all<float>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>& y) { return x + y; }, tmp_acc0_vec);
            out_ptr0[static_cast<long>(0L)] = static_cast<float>(tmp_acc0);
```

## Test results
I test this PR with dynamo benchmark on 32-core ICX system,
Result (avg speed up):
| |  before this PR   | after this PR  |
| ---- |  ----  | ----  |
| torchbench | 1.303  | 1.301 |
| hugginface | 1.346  | 1.343 |
| timms | 1.971 | 1.970 |

```
export LD_PRELOAD=${CONDA_PREFIX:-"$(dirname $(which conda))/../"}/lib/libiomp5.so:${CONDA_PREFIX:-"$(dirname $(which conda))/../"}/lib/libjemalloc.so
export MALLOC_CONF="oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:-1,muzzy_decay_ms:-1"
export KMP_AFFINITY=granularity=fine,compact,1,0
export KMP_BLOCKTIME=1

multi_threads_test() {
    CORES=$(lscpu | grep Core | awk '{print $4}')
    export OMP_NUM_THREADS=$CORES
    end_core=$(expr $CORES - 1)
    numactl -C 0-${end_core} --membind=0 python benchmarks/dynamo/${SUITE}.py --${SCENARIO} --${DT} -dcpu -n50 --no-skip --dashboard --only "${MODEL}" ${Channels_extra} ${BS_extra} ${Shape_extra} ${Mode_extra} ${Wrapper_extra} ${Flag_extra} --timeout 9000 --backend=inductor --output=${LOG_BASE}/${SUITE}.csv
}

SCENARIO=performance
DT=float32
export TORCHINDUCTOR_FREEZING=1
Flag_extra="--freezing"
Mode_extra="--inference"

for suite in timm_models huggingface torchbench
do
  export SUITE=$suite
  echo $SUITE
  export LOG_BASE=`date +%m%d%H%M%S`
  mkdir $LOG_BASE
  multi_threads_test
done
```
System info
```
ubuntu@ip-172-31-18-205:~/hz/pytorch$ lscpu
Architecture:            x86_64
  CPU op-mode(s):        32-bit, 64-bit
  Address sizes:         46 bits physical, 48 bits virtual
  Byte Order:            Little Endian
CPU(s):                  64
  On-line CPU(s) list:   0-63
Vendor ID:               GenuineIntel
  Model name:            Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz
    CPU family:          6
    Model:               106
    Thread(s) per core:  2
    Core(s) per socket:  32
    Socket(s):           1
    Stepping:            6
    BogoMIPS:            5800.00
    Flags:               fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic mo
                         vbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xs
                         aveopt xsavec xgetbv1 xsaves wbnoinvd ida arat avx512vbmi pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq rdpid md_clear flush_l1d arch_capabilities
Virtualization features:
  Hypervisor vendor:     KVM
  Virtualization type:   full
Caches (sum of all):
  L1d:                   1.5 MiB (32 instances)
  L1i:                   1 MiB (32 instances)
  L2:                    40 MiB (32 instances)
  L3:                    54 MiB (1 instance)
NUMA:
  NUMA node(s):          1
  NUMA node0 CPU(s):     0-63
Vulnerabilities:
  Gather data sampling:  Unknown: Dependent on hypervisor status
  Itlb multihit:         Not affected
  L1tf:                  Not affected
  Mds:                   Not affected
  Meltdown:              Not affected
  Mmio stale data:       Mitigation; Clear CPU buffers; SMT Host state unknown
  Retbleed:              Not affected
  Spec rstack overflow:  Not affected
  Spec store bypass:     Mitigation; Speculative Store Bypass disabled via prctl
  Spectre v1:            Mitigation; usercopy/swapgs barriers and __user pointer sanitization
  Spectre v2:            Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
  Srbds:                 Not affected
  Tsx async abort:       Not affected
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115620
Approved by: https://github.com/jgong5, https://github.com/jansel
2024-03-15 02:03:10 +00:00
Xia, Weiwen
f848e9c646 [Quant][Inductor] Fix q/dq per channel lowering with 64-bit qparams (#120984)
Fixes #120869

Fix lowering of `quantize_per_channel` and `dequantize_per_channel` with float64 scale and int64 zero point.
Generated codes are incorrect without explicit type conversion. Add type conversion to the lowering pass, i.e., float64 (double) -> float32 and int64 -> int32.

**Test plan**
python test/inductor/test_cpu_repro.py -k test_per_channel_fake_quant_module_uint8

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120984
Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5, https://github.com/jerryzh168
2024-03-07 06:23:52 +00:00
vfdev-5
49d1fd31cf Fuse nodes with sizes (s0*s1*...,) and (s0, s1, s2, ...) (#120077)
Description:
- PR tries to fuse nodes with compatible sizes, for example `node1: (s0, s1, s2)` and `node2: (s0 * s1 * s2)`. On `main` these two nodes can be fused due to different sizes. With this PR we can recompute node2 size, body etc using node1 indexing constraint and thus be able to fuse two nodes.
- this should influence only cpu device

Example:
```python
from unittest.mock import patch
import torch
from torch._inductor.graph import GraphLowering
from torch._inductor import config

# Force multple scheduler nodes creation to fuse them
config.realize_opcount_threshold = 1

@torch.compile(fullgraph=True, dynamic=True)
def fn(x: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor) -> torch.Tensor:
    o1 = x * w1.view(1, 1, 1, -1)
    o2 = x * w2.view(1, 1, 1, -1)
    output = o1 + o2
    return output

in_nodes = []
outputs = []
run_node = GraphLowering.run_node

graph_lowering_obj = None

def run_node_alt(self, n):
    global graph_lowering_obj

    graph_lowering_obj = self
    in_nodes.append(n)
    output = run_node(self, n)
    outputs.append(output)

    return output

x = torch.rand(1, 3, 32, 32)
w1 = torch.randn(32)
w2 = torch.randn(32)

with patch.object(GraphLowering, "run_node", run_node_alt):
    fn(x, w1, w2)

print("graph_lowering_obj.buffers:", graph_lowering_obj.buffers)
print("graph_lowering_obj.scheduler:", graph_lowering_obj.scheduler.nodes)
```

Output on `main`:
```
graph_lowering_obj.buffers: [ComputedBuffer(name='buf0', layout=FixedLayout('cpu', torch.float32, size=[1, s1, s0, s0], stride=[s0**2*s1, s0**2, s0, 1]), data=Pointwise(
  'cpu',
  torch.float32,
  def inner_fn(index):
      _, i1, i2, i3 = index
      tmp0 = ops.load(arg3_1, i3 + i1 * s0**2 + i2 * s0)
      tmp1 = ops.load(arg1_1, i3)
      tmp2 = tmp0 * tmp1
      return tmp2
  ,
  ranges=[1, s1, s0, s0],
  origin_node=mul,
  origins={mul}
)), ComputedBuffer(name='buf1', layout=FixedLayout('cpu', torch.float32, size=[1, s1, s0, s0], stride=[s0**2*s1, s0**2, s0, 1]), data=Pointwise(
  'cpu',
  torch.float32,
  def inner_fn(index):
      _, i1, i2, i3 = index
      tmp0 = ops.load(arg3_1, i3 + i1 * s0**2 + i2 * s0)
      tmp1 = ops.load(arg4_1, i3)
      tmp2 = tmp0 * tmp1
      return tmp2
  ,
  ranges=[1, s1, s0, s0],
  origin_node=mul_1,
  origins={mul_1}
)), ComputedBuffer(name='buf2', layout=FixedLayout('cpu', torch.float32, size=[1, s1, s0, s0], stride=[s0**2*s1, s0**2, s0, 1]), data=Pointwise(
  'cpu',
  torch.float32,
  def inner_fn(index):
      _, i1, i2, i3 = index
      tmp0 = ops.load(buf0, i3 + i1 * s0**2 + i2 * s0)
      tmp1 = ops.load(buf1, i3 + i1 * s0**2 + i2 * s0)
      tmp2 = tmp0 + tmp1
      return tmp2
  ,
  ranges=[1, s1, s0, s0],
  origin_node=add,
  origins={add}
))]
graph_lowering_obj.scheduler: [FusedSchedulerNode(nodes=buf0_buf1), SchedulerNode(name='buf2')]
```
Output on this PR:
```
graph_lowering_obj.buffers: [ComputedBuffer(name='buf0', layout=FixedLayout('cpu', torch.float32, size=[1, s1, s0, s0], stride=[s0**2*s1, s0**2, s0, 1]), data=Pointwise(
  'cpu',
  torch.float32,
  def inner_fn(index):
      _, i1, i2, i3 = index
      tmp0 = ops.load(arg3_1, i3 + i1 * s0**2 + i2 * s0)
      tmp1 = ops.load(arg1_1, i3)
      tmp2 = tmp0 * tmp1
      return tmp2
  ,
  ranges=[1, s1, s0, s0],
  origin_node=mul,
  origins={mul}
)), ComputedBuffer(name='buf1', layout=FixedLayout('cpu', torch.float32, size=[1, s1, s0, s0], stride=[s0**2*s1, s0**2, s0, 1]), data=Pointwise(
  'cpu',
  torch.float32,
  def inner_fn(index):
      _, i1, i2, i3 = index
      tmp0 = ops.load(arg3_1, i3 + i1 * s0**2 + i2 * s0)
      tmp1 = ops.load(arg4_1, i3)
      tmp2 = tmp0 * tmp1
      return tmp2
  ,
  ranges=[1, s1, s0, s0],
  origin_node=mul_1,
  origins={mul_1}
)), ComputedBuffer(name='buf2', layout=FixedLayout('cpu', torch.float32, size=[1, s1, s0, s0], stride=[s0**2*s1, s0**2, s0, 1]), data=Pointwise(
  'cpu',
  torch.float32,
  def inner_fn(index):
      _, i1, i2, i3 = index
      tmp0 = ops.load(buf0, i3 + i1 * s0**2 + i2 * s0)
      tmp1 = ops.load(buf1, i3 + i1 * s0**2 + i2 * s0)
      tmp2 = tmp0 + tmp1
      return tmp2
  ,
  ranges=[1, s1, s0, s0],
  origin_node=add,
  origins={add}
))]
graph_lowering_obj.scheduler: [FusedSchedulerNode(nodes=buf0_buf1_buf2)]
```

Context:
While working on https://github.com/pytorch/pytorch/pull/120411, upsampling bicubic decomposition, I saw an extra for-loop in C++ generated code summing up two buffers. Exploring the cause, it happend due to buffer number of ops goes beyond `config.realize_opcount_threshold`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120077
Approved by: https://github.com/jgong5, https://github.com/lezcano, https://github.com/peterbell10
2024-03-06 12:19:45 +00:00
Pearu Peterson
c06499981d Add a decomposition for torch.put, 2. (#120179)
As in the title. It is an updated copy of https://github.com/pytorch/pytorch/pull/115306 .

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120179
Approved by: https://github.com/lezcano, https://github.com/peterbell10, https://github.com/jgong5
2024-03-04 14:37:30 +00:00
Jiong Gong
1c7b0e7cd1 [inductor][cpp] disable masked load for non-fp data types (#120558)
Fix https://github.com/pytorch/pytorch/issues/120377. We disable the masked load for non-fp data types for now. The complete support of masks will be added in https://github.com/pytorch/pytorch/pull/119654.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120558
Approved by: https://github.com/lezcano, https://github.com/jansel
2024-02-26 04:12:22 +00:00
Isuru Fernando
b7df3bba62 add decomposition for frexp (#119217)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119217
Approved by: https://github.com/peterbell10
ghstack dependencies: #119284, #120027
2024-02-23 21:52:42 +00:00
Yang Chen
b96ea097ee [aotinductor] rename CppWrapperCodeGen and CudaWrapperCodeGen (#120391)
make WrapperCodeGen subclass names consistent with the
file names:

CppWrapperCodeGen -> CppWrapperCpu
CudaWrapperCodeGen -> CppWrapperCuda

Differential Revision: [D54074938](https://our.internmc.facebook.com/intern/diff/D54074938)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120391
Approved by: https://github.com/aakhundov
2024-02-23 10:41:50 +00:00
atalman
be8ba5ef2d Revert "use two pass reduction for deterministic reduction order (#11… (#120243)
This reverts commit cc7ef43423.

Manual revert because of the conflict in: test/inductor/test_cpu_repro.py , conflict with this PR: https://github.com/pytorch/pytorch/pull/118365

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120243
Approved by: https://github.com/malfet, https://github.com/huydhn
2024-02-20 20:50:29 +00:00
haozhe.zhu
b4b1480b06 remove redundant to_dtype in Fused Schedular Nodes (#118365)
Fix https://github.com/pytorch/pytorch/issues/115260.
This issue is triggered by `FusedSchedularNodes` cases.
We always store `lowp buffer` to `store_cache` then load `lowp buffer` from `store_cache` and `convert it to float` before `compute ops`.
Now we will generate a `{key: to(float32)_expr, value: the float32 cse var before to_dtype and store}` in `cse.cache`.
Then the `to_dtype(float32)` after `load` will hit this cache and not generate a new var with cast codes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118365
Approved by: https://github.com/jgong5, https://github.com/jansel
2024-02-20 13:35:03 +00:00
haozhe.zhu
cc7ef43423 use two pass reduction for deterministic reduction order (#115620)
## Motivation
Address the [non-deterministic reduction order](https://github.com/pytorch/pytorch/issues/93542#issuecomment-1411294181) issue for `omp parallel reduction`.

## Latest update on 1.15:
55d81901bc.
Do not reduce to arr in loops. Instead, reduce to a local scaler and write it to arr after local reduction is done. This will allow the compiler to optimize the reduction variable in register instead read/write from memory. If the `working set` of `loop body` is quite large, `read/write from register/memory` will have a large gap.
```
vaddss (%xmm0, %xmm11, %xmm11) -> accumulate in register %xmm0
vaddssl ((%rdx, %rdi, 4), %xmm0, %xmm0) -> accumulate in memory address (%rdx, %rdi, 4)
```
Examples code:
```
tmp0_acc_arr[64];
#pragma omp parallel num_threads(64)
{
    auto tid = omp_get_thread_num();
    #pragma omp for
    for(...){
        ....
        tmp0_acc_arr[tid] = tmp0_acc_arr[tid] + tmp_x;  // access array will always from memory
    }
}
```
will be changed to
```
tmp0_acc_arr[64];
#pragma omp parallel num_threads(64)
{
    auto tid = omp_get_thread_num();
    **auto tmp0_acc_local = 0;**
    #pragma omp for
    for(...){
        ....
        **tmp0_acc_local**  = tmp0_acc_local + tmp_x;
    }
    **tmp0_acc_arr[tid] = tmp0_acc_local;**
}
```

## Descriptions
Following aten to use `two pass reduction` with `omp parallel` for deterministic reduction order.
9c3ae37fc4/aten/src/ATen/Parallel-inl.h (L39)
9c3ae37fc4/aten/src/ATen/native/TensorIteratorReduce.cpp (L24)
```
            float tmp_acc0 = 0;
            at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0);
            // init reduction buffer per thread
            float tmp_acc0_arr[64];
            at::vec::Vectorized<float> tmp_acc0_vec_arr[64];
            for (int tid = 0; tid < 64; tid++)
            {
                tmp_acc0_arr[tid] = 0;
                tmp_acc0_vec_arr[tid] = at::vec::Vectorized<float>(0);
            }
            #pragma omp parallel num_threads(64)
            {
                int tid = omp_get_thread_num();
                #pragma omp for
                for(long x0=static_cast<long>(0L); x0<static_cast<long>(3964928L); x0+=static_cast<long>(16L))
                {
                    auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x0));
                    auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<long>(x0));
                    auto tmp2 = tmp0 - tmp1;
                    auto tmp3 = tmp2 * tmp2;
                    // reduce to per thread buffers
                    tmp_acc0_vec_arr[tid] = tmp_acc0_vec_arr[tid] + tmp3;
                }
            }
            // second pass reduce
            for (int tid = 0; tid < 64; tid++)
            {
                tmp_acc0 = tmp_acc0 + tmp_acc0_arr[tid];
                tmp_acc0_vec = tmp_acc0_vec + tmp_acc0_vec_arr[tid];
            }
            tmp_acc0 = tmp_acc0 + at::vec::vec_reduce_all<float>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>& y) { return x + y; }, tmp_acc0_vec);
            out_ptr0[static_cast<long>(0L)] = static_cast<float>(tmp_acc0);
```

## Test results
I test this PR with dynamo benchmark on 32-core ICX system,
Result (avg speed up):
| |  before this PR   | after this PR  |
| ---- |  ----  | ----  |
| torchbench | 1.303  | 1.301 |
| hugginface | 1.346  | 1.343 |
| timms | 1.971 | 1.970 |

```
export LD_PRELOAD=${CONDA_PREFIX:-"$(dirname $(which conda))/../"}/lib/libiomp5.so:${CONDA_PREFIX:-"$(dirname $(which conda))/../"}/lib/libjemalloc.so
export MALLOC_CONF="oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:-1,muzzy_decay_ms:-1"
export KMP_AFFINITY=granularity=fine,compact,1,0
export KMP_BLOCKTIME=1

multi_threads_test() {
    CORES=$(lscpu | grep Core | awk '{print $4}')
    export OMP_NUM_THREADS=$CORES
    end_core=$(expr $CORES - 1)
    numactl -C 0-${end_core} --membind=0 python benchmarks/dynamo/${SUITE}.py --${SCENARIO} --${DT} -dcpu -n50 --no-skip --dashboard --only "${MODEL}" ${Channels_extra} ${BS_extra} ${Shape_extra} ${Mode_extra} ${Wrapper_extra} ${Flag_extra} --timeout 9000 --backend=inductor --output=${LOG_BASE}/${SUITE}.csv
}

SCENARIO=performance
DT=float32
export TORCHINDUCTOR_FREEZING=1
Flag_extra="--freezing"
Mode_extra="--inference"

for suite in timm_models huggingface torchbench
do
  export SUITE=$suite
  echo $SUITE
  export LOG_BASE=`date +%m%d%H%M%S`
  mkdir $LOG_BASE
  multi_threads_test
done
```
System info
```
ubuntu@ip-172-31-18-205:~/hz/pytorch$ lscpu
Architecture:            x86_64
  CPU op-mode(s):        32-bit, 64-bit
  Address sizes:         46 bits physical, 48 bits virtual
  Byte Order:            Little Endian
CPU(s):                  64
  On-line CPU(s) list:   0-63
Vendor ID:               GenuineIntel
  Model name:            Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz
    CPU family:          6
    Model:               106
    Thread(s) per core:  2
    Core(s) per socket:  32
    Socket(s):           1
    Stepping:            6
    BogoMIPS:            5800.00
    Flags:               fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic mo
                         vbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xs
                         aveopt xsavec xgetbv1 xsaves wbnoinvd ida arat avx512vbmi pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq rdpid md_clear flush_l1d arch_capabilities
Virtualization features:
  Hypervisor vendor:     KVM
  Virtualization type:   full
Caches (sum of all):
  L1d:                   1.5 MiB (32 instances)
  L1i:                   1 MiB (32 instances)
  L2:                    40 MiB (32 instances)
  L3:                    54 MiB (1 instance)
NUMA:
  NUMA node(s):          1
  NUMA node0 CPU(s):     0-63
Vulnerabilities:
  Gather data sampling:  Unknown: Dependent on hypervisor status
  Itlb multihit:         Not affected
  L1tf:                  Not affected
  Mds:                   Not affected
  Meltdown:              Not affected
  Mmio stale data:       Mitigation; Clear CPU buffers; SMT Host state unknown
  Retbleed:              Not affected
  Spec rstack overflow:  Not affected
  Spec store bypass:     Mitigation; Speculative Store Bypass disabled via prctl
  Spectre v1:            Mitigation; usercopy/swapgs barriers and __user pointer sanitization
  Spectre v2:            Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
  Srbds:                 Not affected
  Tsx async abort:       Not affected
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115620
Approved by: https://github.com/jgong5, https://github.com/jansel
2024-02-20 00:46:59 +00:00
Yang Chen
bc7f3efb09 [aot_inductor] move CppWrapperCodeGen into a separate file (#119871)
This reverts commit d8e319a961.

Differential Revision: [D53817853](https://our.internmc.facebook.com/intern/diff/D53817853)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119871
Approved by: https://github.com/albanD, https://github.com/khabinov
ghstack dependencies: #119870
2024-02-16 08:14:20 +00:00
Pearu Peterson
2c91e13afc Add lowerings to special functions (#119187)
As in the title.

In addition, the PR introduces infrastructure for lowerings of pointwise functions that have both cpp and triton implementations available.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119187
Approved by: https://github.com/peterbell10
2024-02-11 16:35:40 +00:00
PyTorch MergeBot
d8e319a961 Revert "[aot_inductor] move CppWrapperCodeGen into a separate file (#119491)"
This reverts commit 760056bbdc.

Reverted https://github.com/pytorch/pytorch/pull/119491 on behalf of https://github.com/DanilBaibak due to Reverted as a dependency for #119448 ([comment](https://github.com/pytorch/pytorch/pull/119491#issuecomment-1937344548))
2024-02-10 23:02:05 +00:00
Yang Chen
760056bbdc [aot_inductor] move CppWrapperCodeGen into a separate file (#119491)
This PR moved CppWrapperCodeGen class into a seperate file,
cpp_wrapper.py, to simplify wrapper.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119491
Approved by: https://github.com/desertfire, https://github.com/albanD
2024-02-10 02:15:56 +00:00
Jiong Gong
a050d146b7 [Inductor] Add Int8 data type into Inductor CPP backend vectorized code generation (#119179)
**Summary**
Part 1 of fixing https://github.com/pytorch/pytorch/issues/119141 which needs vectorized code generation of per channel quant and int8 data type.
In the current implementation for quantization, the vectorized code generation only supports the `uint8` data type. In this PR, we introduce support for the `int8` data type within the vectorized code generation.

**TestPlan**
```
python -u -m pytest -s -v test_cpu_repro.py -k test_decomposed_dequant_relu_quant_int8
python -u -m pytest -s -v test_cpu_repro.py -k test_dequant_quant_lowering_int8
python -u -m pytest -s -v test_cpu_repro.py -k test_dequant_maxpool2d_lowering_int8
python -u -m pytest -s -v test_cpu_repro.py -k test_tile2d_load_decomposed_dequant_add_relu_quant_int8
python -u -m pytest -s -v test_cpu_repro.py -k test_per_tensor_fake_quant_int8
python -u -m pytest -s -v test_cpu_repro.py -k test_non_contiguous_load_buf_quant_int8
python -u -m pytest -s -v test_cpu_repro.py -k test_tile2d_store_channel_shuffle_cl_quant_output_int8
python -u -m pytest -s -v test_cpu_repro.py -k test_dequant_relu_quant_dequant_relu_quant_lowering_int8
```

Co-authored-by: Jiong Gong <jiong.gong@intel.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119179
Approved by: https://github.com/peterbell10, https://github.com/jgong5, https://github.com/jansel
2024-02-09 07:33:12 +00:00
Yang Chen
9f8ade04cc [aot_inductor] replace TORCH_CHECK with AOTI_CHECK in the generate cpp code (#119220)
In some cases where we have TORCH_CHECK in loops, it may cause the host
compiler to spend hours optimizing the run_impl function. This PR
mitigated the issue by replacing TORCH_CHECK with a custom AOTI_CHECK,
where we force the underneath assert function to be noinline.

If forcing noinline caused any serious perf regression, we could
either add an option to turn on/off enable noinline. Or, we could
another an option to just turn AOTI_CHECK into a no-op, similar
to the ```assert``` macro from cassert.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119220
Approved by: https://github.com/hl475, https://github.com/desertfire
2024-02-08 21:57:27 +00:00
Pearu Peterson
7ec6ac89e8 Add lowering to special.modified_bessel_i0 (#118993)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118993
Approved by: https://github.com/peterbell10
2024-02-08 18:42:40 +00:00
Jiong Gong
896cf9d1ce [inductor][cpp] vectorization support for int32/int64 (#119001)
This pull request aims to complete most of the support for vectorizing int32 and int64 data types except for indirect indexing and masks. The basic data type support for uint32 and uint64 is also added but without vectorization. More vectorized conversion functions are added between integer and float. In order to support int64 vectors, a new VectorizedN class to handle vectors of arbitrary length. Below are the details:
1. Complete most of the int32 and int64 vectorization support including load, store, reduction, constant and conversion. The indirect indexing and masks will be addressed in follow-up PRs, after which, the legality checking logic in `CppVecKernelChecker` can be further simplified.
2. Util functions for conversion between integer and float vectors (in cpp_prefix.h and ATen vec). Ideally, we'd better move them from cpp_prefix.h to ATen vec to simplify cpp_prefix.h, will be addressed in follow-up PRs.
3. Introduced a new template class VectorizedN, designed to handle vectors of arbitrary length by encapsulating multiple Vectorized<T> instances. This class supports most of the operations of `Vectorized<T>`. It makes the support of int64 vectorization simpler. I will also apply it to bf16/fp16/int8 in the follow-up PRs for better efficiency. For example, bf16 currently only uses half of the vector lanes. With `VectorizedN`, we can use full of the lanes and map bf16 vector to `VectorizedN<float,2>` on conversion.
4. Basic data type support is added for uint32 and uint64 (in graph.py). Vectorization support will be added later but not of high priority due to fewer usages.

Next steps:

- [ ] Refactor the vector mask handling to support data types other than float. Currently vector masks are implemented with float vectors.
- [ ] Fully utilize vector lanes for bfloat16/float16/int8.
- [ ] Support indirect indexing with vectorized index via scalarization.
- [ ] Clean up `CppVecKernelChecker`.
- [ ] Simplify `cpp_prefix.h` including refactoring vector conversion logic.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119001
Approved by: https://github.com/peterbell10, https://github.com/jansel
2024-02-08 17:38:49 +00:00
Yang Chen
b2e0f8d82d [mypy] added type annotations to codegen_nodes methods (#119080)
added correct type annotations to scheduler and backends'
codegen_nodes methods

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119080
Approved by: https://github.com/eellison
2024-02-05 18:33:52 +00:00
Edward Z. Yang
abc09b27b9 Some minor type stub improvements (#118529)
I was just playing around with improving the typing of symbolic_shapes. The PR is not "complete" but I in particular wanted to get feedback on whether or not people liked making ValueRanges Generic; it seems that distinguishing if you have an Expr ValueRange or a SympyBoolean ValueRange is a lot of trouble for downstream. Using TypeGuard, we can perform refinements on the generic parameter inside methods, although we still have to cast back to ValueRange[T] due to https://github.com/python/mypy/issues/14425#issuecomment-1914852707

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118529
Approved by: https://github.com/Skylion007
2024-02-04 00:19:00 +00:00
Pearu Peterson
a69016a741 Add lowering to special.bessel_j1 (#118992)
As in the title.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118992
Approved by: https://github.com/peterbell10
2024-02-02 20:16:08 +00:00
Bin Bao
c7ba5f6c6f [AOTI] Fix a cpp kernel missing arg type issue (#119021)
Summary: The current way of fetching the kernel arg types only works for tensors, not symbols.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119021
Approved by: https://github.com/aakhundov, https://github.com/hl475, https://github.com/khabinov
2024-02-02 20:11:58 +00:00
PyTorch MergeBot
dbba1d4bf5 Revert "Some minor type stub improvements (#118529)"
This reverts commit c978f38bd4.

Reverted https://github.com/pytorch/pytorch/pull/118529 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/118529#issuecomment-1922362331))
2024-02-01 22:18:36 +00:00
Edward Z. Yang
c978f38bd4 Some minor type stub improvements (#118529)
I was just playing around with improving the typing of symbolic_shapes. The PR is not "complete" but I in particular wanted to get feedback on whether or not people liked making ValueRanges Generic; it seems that distinguishing if you have an Expr ValueRange or a SympyBoolean ValueRange is a lot of trouble for downstream. Using TypeGuard, we can perform refinements on the generic parameter inside methods, although we still have to cast back to ValueRange[T] due to https://github.com/python/mypy/issues/14425#issuecomment-1914852707

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118529
Approved by: https://github.com/Skylion007
2024-01-31 20:56:56 +00:00
hodavand
8026534a2f Add torch.complex128 and torch.complex32 to DTYPE_TO_ATEN dictionary. (#117929)
Fixes #117370

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117929
Approved by: https://github.com/Skylion007, https://github.com/desertfire
2024-01-31 19:34:58 +00:00
Pearu Peterson
2327879fb6 Add lowering to special.bessel_j0 (2nd try) (#118565)
This PR is a copy of https://github.com/pytorch/pytorch/pull/118464 that was merged without using pytorchbot. Sorry for the noise!

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118565
Approved by: https://github.com/peterbell10
2024-01-30 15:26:59 +00:00
Jiong Gong
e5bb527d3e [inductor][cpp] support scalar value in vec reduction (#118511)
Fix https://github.com/pytorch/pytorch/issues/118379

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118511
Approved by: https://github.com/leslie-fang-intel, https://github.com/lezcano, https://github.com/jansel
2024-01-30 13:07:43 +00:00
Jiong Gong
04c1df651a [inductor][cpp] enable vectorization with constant bool (#118380)
Related model DebertaForQuestionAnswering etc. For DebertaForQuestionAnswering, single thread, measured on ICX:
Before: 0.990x, After: 1.043x

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118380
Approved by: https://github.com/leslie-fang-intel, https://github.com/jansel
2024-01-29 13:31:22 +00:00
leslie-fang-intel
ee3dfbbe47 [Inductor] Fix Argmax codegen with Nan input (#118358)
**Summary**
Fix issue https://github.com/pytorch/pytorch/issues/118266, current `torch.argmax` and `torch.argmin` has different return values with eager and Inductor cpp backend when inputs has `Nan` value. Align cpp backend results to eager by reusing the compare function.

**Test Plan**
```
python -u -m pytest -s -v test_cpu_repro.py -k test_argmin_cpu_only
python -u -m pytest -s -v test_cpu_repro.py -k test_argmax_argmin_with_nan_value
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118358
Approved by: https://github.com/lezcano, https://github.com/jgong5, https://github.com/jansel
2024-01-29 09:09:46 +00:00