Commit Graph

148 Commits

Author SHA1 Message Date
Brian Hirsh
2efe4d809f [hotfix inductor test] disable cpp vectorization codegen in fbcode for inductor (#104560)
Summary:
After D46364355 landed, a few inductor internal tests started failing. When I ran this locally:
```
buck2 test fbcode//mode/dev-nosan fbcode//caffe2/test/inductor:config
```

The test appeared to hang with this output, until it would fail with a timeout after 10 minutes passed:
```
Test caffe2/test/inductor:config -- discovering tests [local_execute]
```

Eventually, I realized that inductor has a value `HAS_CPU` (https://www.internalfb.com/code/fbsource/[6cc47fa5eb77a93d91a519d3eb3df67ceddb8faa]/fbcode/caffe2/torch/testing/_internal/inductor_utils.py?lines=23) that is implemented lazily. Part of that implementation involves inspecting `/proc/cpuinfo` to figure out what vectorized instructions are available, and that call appeared to hang (https://www.internalfb.com/code/fbsource/[6cc47fa5eb77a93d91a519d3eb3df67ceddb8faa]/fbcode/caffe2/torch/_inductor/codecache.py?lines=568).

Since vectorized codegen for inductor CPU already isn't working internally, I hardcoded that check to fail for now in fbcode.
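For reference, a minimal sketch (hypothetical helper names; the real logic lives in `torch/_inductor/codecache.py` and `inductor_utils.py`) of the kind of lazy capability check involved and the fbcode short-circuit:

```python
import functools

def is_fbcode() -> bool:
    # hypothetical stand-in for the real internal-build check
    return False

@functools.lru_cache(None)
def supports_cpp_vectorization() -> bool:
    # The /proc/cpuinfo probe is what appeared to hang internally, so in
    # fbcode we short-circuit it and report no vectorization support.
    if is_fbcode():
        return False
    with open("/proc/cpuinfo") as f:
        return "avx2" in f.read()
```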

Test Plan:
Confirmed that this passes:
`buck2 test fbcode//mode/dev-nosan fbcode//caffe2/test/inductor:config`

Differential Revision: D47199912

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104560
Approved by: https://github.com/desertfire, https://github.com/bertmaher
2023-07-06 19:00:13 +00:00
XiaobingSuper
c4cf90aad1 inductor: fix assert error when load a bfloat16 inf constant (#104614)
Fix ```nanogpt_generate``` bfloat16 path error.
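A hypothetical minimal repro (assumed from the commit title, not the original `nanogpt_generate` graph):

```python
import torch

def fn(x):
    # materializes a bfloat16 -inf constant in the generated cpp kernel
    return x.masked_fill(x > 0, float("-inf"))

compiled = torch.compile(fn)
compiled(torch.randn(8, dtype=torch.bfloat16))
```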

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104614
Approved by: https://github.com/jgong5, https://github.com/desertfire
2023-07-06 17:01:04 +00:00
Peter Bell
59b8d5be74 [inductor] Split ops.reduction into reduction and store_reduction (#102737)
This is intended as a first step towards reductions with multiple outputs. This
also incidentally improves CSE of reductions under C++ codegen. For example,
```python
def fn(x):
    return torch.argmin(x, dim=-1), torch.argmin(x, dim=-1)
```

Currently this generates two reductions, where the common load is CSEd
```cpp
for(long i1=static_cast<long>(0L); i1<static_cast<long>(10); i1+=static_cast<long>(1L))
{
    auto tmp0 = in_ptr0[static_cast<long>(i1 + (10L*i0))];
    if (tmp_acc0.value > tmp0) {
        tmp_acc0.index = i1; tmp_acc0.value = tmp0;
    }
    if (tmp_acc1.value > tmp0) {
        tmp_acc1.index = i1; tmp_acc1.value = tmp0;
    }
}
auto tmp1 = tmp_acc0.index;
out_ptr0[static_cast<long>(i0)] = tmp1;
auto tmp2 = tmp_acc1.index;
out_ptr1[static_cast<long>(i0)] = tmp2;
```

but with this change it gets CSEd to a single accumulator

```cpp
for(long i1=static_cast<long>(0L); i1<static_cast<long>(10L); i1+=static_cast<long>(1L))
{
    auto tmp0 = in_ptr0[static_cast<long>(i1 + (10L*i0))];
    if (tmp_acc0.value > tmp0) {
        tmp_acc0.index = i1; tmp_acc0.value = tmp0;
    }
}
auto tmp1 = tmp_acc0.index;
out_ptr0[static_cast<long>(i0)] = tmp1;
out_ptr1[static_cast<long>(i0)] = tmp1;
```
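A toy sketch of the mechanism (simplified; not inductor's actual ops interface): once the accumulator is a value returned by `reduction`, it can be cached like any other CSE value and reused by multiple `store_reduction` calls.

```python
class ToyKernel:
    def __init__(self):
        self.cse = {}    # (reduction_type, src) -> accumulator name
        self.lines = []

    def reduction(self, reduction_type, src):
        key = (reduction_type, src)
        if key not in self.cse:                       # CSE: identical reductions share one accumulator
            acc = f"tmp_acc{len(self.cse)}"
            self.lines.append(f"{acc} = {reduction_type}({src})")
            self.cse[key] = acc
        return self.cse[key]

    def store_reduction(self, out, acc):
        self.lines.append(f"{out}[i0] = {acc}")

k = ToyKernel()
k.store_reduction("out_ptr0", k.reduction("argmin", "in_ptr0"))
k.store_reduction("out_ptr1", k.reduction("argmin", "in_ptr0"))  # hits the cache
print("\n".join(k.lines))  # only one accumulator is emitted
```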

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102737
Approved by: https://github.com/jgong5, https://github.com/lezcano
2023-07-06 16:22:19 +00:00
Peter Bell
7e098f9559 [inductor] Add single pass "var_unnormalized" reduction_type (#102486)
This is a bit inefficient because it computes the mean and throws it
away since ir.Reduction nodes only have 1 output. However, the mean
can at least be scheduled into the same loop as the variance now since
there is no data dependency. Thus we can take fewer passes over the
data.
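For context, a Welford-style recurrence is the standard way to get an unnormalized variance (and the mean) in a single pass; a minimal scalar sketch, not the code this lowering generates:

```python
def var_unnormalized(xs):
    mean, m2, n = 0.0, 0.0, 0
    for x in xs:
        n += 1
        delta = x - mean
        mean += delta / n
        m2 += delta * (x - mean)   # running sum of squared deviations
    return m2, mean                # m2 is the "unnormalized" variance

print(var_unnormalized([1.0, 2.0, 3.0, 4.0]))  # (5.0, 2.5)
```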

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102486
Approved by: https://github.com/lezcano, https://github.com/jansel
2023-07-06 00:00:59 +00:00
leslie-fang-intel
ea4d5c4538 [Quant][PT2E] Enable vec code gen for pair of quant/dequant (#104503)
**Summary**
We already support vectorized code gen for the `dequant-relu-quant` pattern, in which `to_uint8` is the last node of the quant pattern before the store into memory. However, there is another case, `dequant1-relu-quant2-dequant2-relu-quant3`, where `quant2` sits in the middle of the fusion pattern; this PR enables vectorized code gen for the `quant2-dequant2` pair.
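For reference, the elementwise math of such a quant/dequant pair (a per-tensor affine eager sketch, not the generated vectorized code; the rounding/clamping order mirrors the generated kernels shown in the related commits):

```python
import torch

def quant(x: torch.Tensor, scale: float, zero_point: int) -> torch.Tensor:
    # float32 -> uint8: scale, round, shift by zero point, clamp to [0, 255]
    return torch.clamp(torch.round(x / scale) + zero_point, 0, 255).to(torch.uint8)

def dequant(q: torch.Tensor, scale: float, zero_point: int) -> torch.Tensor:
    # uint8 -> float32: shift back and rescale
    return (q.to(torch.float32) - zero_point) * scale

x = torch.randn(16)
torch.testing.assert_close(dequant(quant(x, 0.1, 128), 0.1, 128), x, atol=0.06, rtol=0)
```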

**Test Plan**
```
python -u -m pytest -s -v test_cpu_repro.py  -k test_dequant_relu_quant_dequant_relu_quant_lowering
```

**Next Step**
* For better performance, we can add another pass to eliminate pairs of `float_to_uint8` and `uint8_to_float` nodes.
* For better performance, we should annotate `dequant1` and `quant2` as sharing an observer in the quantization recipe. Then we can lower `dequant1-relu-quant2` into a QReLU node to fully eliminate the calculation of `dequant1` and `quant2`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104503
Approved by: https://github.com/jgong5, https://github.com/jansel
2023-07-05 01:59:00 +00:00
lezcano
7ae100628e Move most SymPy functions to their own file (#104556)
All of these are standalone implementations of some functions and they don't depend on anything else, so we are better off having them under the `_sympy/` folder on their own.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104556
Approved by: https://github.com/ezyang
2023-07-04 03:53:48 +00:00
leslie-fang-intel
707d265db2 [Inductor][Quant]Refactor load and store vectorization code generation with uint8 data type (#104075)
**Summary**
Refactor the vectorized code generation for the uint8 input data type. Previously, we combined the uint8 data load and the uint8-to-float conversion into one step as `load_uint8_as_float`, and similarly `store_float_as_uint8` for the store. After the refactor, we split them into two steps, load/store and data type conversion, to make the behavior the same as for the BFloat16 data type.

The previous generated code is:
```
#pragma omp for
for(long i0=static_cast<long>(0L); i0<static_cast<long>(432L); i0+=static_cast<long>(16L))
{
    auto tmp0 = at::vec::load_uint8_as_float(in_ptr0 + static_cast<long>(i0));
    auto tmp1 = (tmp0);
    auto tmp2 = at::vec::Vectorized<float>(static_cast<float>(100.0));
    auto tmp3 = tmp1 - tmp2;
    auto tmp4 = at::vec::Vectorized<float>(static_cast<float>(0.01));
    auto tmp5 = tmp3 * tmp4;
    auto tmp6 = at::vec::clamp_min(tmp5, decltype(tmp5)(0));
    auto tmp7 = tmp6 * tmp2;
    auto tmp8 = tmp7.round();
    auto tmp9 = tmp8 + tmp2;
    auto tmp10 = at::vec::Vectorized<float>(static_cast<float>(0.0));
    auto tmp11 = at::vec::maximum(tmp9, tmp10);
    auto tmp12 = at::vec::Vectorized<float>(static_cast<float>(255.0));
    auto tmp13 = at::vec::minimum(tmp11, tmp12);
    auto tmp14 = (tmp13);
    at::vec::store_float_as_uint8(tmp14, out_ptr0 + static_cast<long>(i0));
}
```

After this PR, the generated code is:
```
#pragma omp for
for(long i0=static_cast<long>(0L); i0<static_cast<long>(432L); i0+=static_cast<long>(16L))
{
    auto tmp0 = at::vec::Vectorized<uint8_t>::loadu(in_ptr0 + static_cast<long>(i0), 16);
    auto tmp1 = cvt_uint8_to_fp32_with_same_elem_num(tmp0);
    auto tmp2 = at::vec::Vectorized<float>(static_cast<float>(100.0));
    auto tmp3 = tmp1 - tmp2;
    auto tmp4 = at::vec::Vectorized<float>(static_cast<float>(0.01));
    auto tmp5 = tmp3 * tmp4;
    auto tmp6 = at::vec::clamp_min(tmp5, decltype(tmp5)(0));
    auto tmp7 = tmp6 * tmp2;
    auto tmp8 = tmp7.round();
    auto tmp9 = tmp8 + tmp2;
    auto tmp10 = at::vec::Vectorized<float>(static_cast<float>(0.0));
    auto tmp11 = at::vec::maximum(tmp9, tmp10);
    auto tmp12 = at::vec::Vectorized<float>(static_cast<float>(255.0));
    auto tmp13 = at::vec::minimum(tmp11, tmp12);
    auto tmp14 = cvt_fp32_to_uint8(tmp13);
    tmp14.store(out_ptr0 + static_cast<long>(i0), 16);
}
```

**Test Plan**
```
python -m pytest test_cpu_repro.py -k test_decomposed_dequant_relu_quant
python -m pytest test_cpu_repro.py -k test_tile2d_load_decomposed_dequant_add_relu_quant
python -m pytest test_cpu_repro.py -k test_tile2d_store_channel_shuffle_cl_quant_output
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104075
Approved by: https://github.com/jgong5, https://github.com/jansel
2023-07-01 23:12:43 +00:00
Brian Hirsh
624d20c3de kill inductor.config.disable_cpp_codegen in internal (#104351)
Summary:
This diff adds a path in inductor to invoke gcc through Remote Execution, when run from within fbcode.

This should (hopefully) let us kill the `inductor.disable_cpp_codegen` flag, since we should now be able to invoke clang at runtime from within fbcode to compile c++ code. This was preventing https://github.com/pytorch/pytorch/pull/100115 from landing, which fixed one of the last remaining models in torchbench that was failing with `torch.compile` (hf_Longformer).

Enumeration of changes:

- updated inductor to invoke `_run_build_command()` when in fbcode, which hooks into Remote Execution
- When inductor invokes g++ normally, it includes a bunch of absolute paths, to stuff like the pytorch header paths, and the input and output path. I changed these all to relative paths when in fbcode, and copied everything we needed into a temp dir that we send to Remote Execution.
- updated `triton/fb/make_build_paths.py` to let us grab paths to openmp, sleef, and ld from within the Remote Execution environment. I'm not sure if there's a better way to do this (but this way appeared to work, thanks to Bert's suggestion from https://www.internalfb.com/diff/D46482550?dst_version_fbid=231706286239076&transaction_fbid=229345569847706)
- factored `triton/fb/build.py` (it had a function to create a triton build command and run it all in one go, I separated the bit that takes in an arbitrary command (our clang command), and runs it with RE)
- a few tweaks to the include paths that inductor uses: it adds those two extra paths (sleef and openmp), and it also does not manually include the `-ltorch`,`-lc10`,`-ltorch_python`,`-ltorch_cpu` libs - the linker was complaining that it couldn't find those libs, and not including those flags ends up working
- I added a few more missing headers. Maybe with D46527002 this won't be necessary?
- I had a basic manual test in `scripts/hirsheybar/tmp2.py`. We probably want to try running an actual job in MAST to make sure this works.

Test Plan: `scripts/hirsheybar/pt2/tmp2.py` has a basic test, but I'm also planning on testing by kicking off a MAST job with cmf_10x (thanks to a bunch of help from Bert)

Reviewed By: bertmaher

Differential Revision: D46364355

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104351
Approved by: https://github.com/bertmaher
2023-06-30 13:32:16 +00:00
XiaobingSuper
a704251628 inductor: fix compile error of bfloat16 broadcast operation (#104319)
For bfloat16 broadcast operations, there is always a compile error:
```
error: could not convert ‘tmp2’ from ‘Vectorized<float>’ to ‘Vectorized<c10::BFloat16>
```

This PR will fix this issue.
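A hypothetical repro sketch (an assumption, not the original failing model) of the kind of bfloat16 broadcast that used to trigger the conversion error above:

```python
import torch

def fn(x, y):
    return x + y  # y broadcasts across the leading dimension

compiled = torch.compile(fn)
compiled(torch.randn(4, 16, dtype=torch.bfloat16),
         torch.randn(16, dtype=torch.bfloat16))
```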

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104319
Approved by: https://github.com/jgong5, https://github.com/jansel
2023-06-30 04:14:38 +00:00
leslie-fang-intel
f8ac569365 [Inductor][Quant]Fix tile2d code generation issue with uint8 data type (#104074)
**Summary**
The previous vectorized code generation for tile2d doesn't support the uint8 input data type; it still treats it as float and generates a wrong result. This PR fixes the issue. Take the UT `test_tile2d_load_decomposed_dequant_add_relu_quant` in this PR as an example:
The previous generated code is:
```
#pragma GCC ivdep
for(long i1=static_cast<long>(0L); i1<static_cast<long>(192L); i1+=static_cast<long>(16L))
{
    unsigned char tmp0[16*16] __attribute__ ((aligned (16)));
    at::vec::transpose_mxn<unsigned char,16,16>(in_ptr0 + static_cast<long>(i0 + (1024L*i1)), static_cast<long>(1024L), tmp0, 16);
    unsigned char tmp7[16*16] __attribute__ ((aligned (16)));
    at::vec::transpose_mxn<unsigned char,16,16>(in_ptr1 + static_cast<long>(i0 + (1024L*i1)), static_cast<long>(1024L), tmp7, 16);
    for (long i0_inner = 0; i0_inner < 16; i0_inner++)
    {
        auto tmp1 = at::vec::Vectorized<float>::loadu(tmp0 + static_cast<long>(16L*i0_inner));
        auto tmp8 = at::vec::Vectorized<float>::loadu(tmp7 + static_cast<long>(16L*i0_inner));
        auto tmp2 = (tmp1);
        auto tmp3 = at::vec::Vectorized<float>(static_cast<float>(1.0));
        auto tmp4 = tmp2 - tmp3;
        auto tmp5 = at::vec::Vectorized<float>(static_cast<float>(0.01));
        auto tmp6 = tmp4 * tmp5;
        auto tmp9 = (tmp8);
        auto tmp10 = at::vec::Vectorized<float>(static_cast<float>(2.0));
        auto tmp11 = tmp9 - tmp10;
        auto tmp12 = at::vec::Vectorized<float>(static_cast<float>(0.02));
        auto tmp13 = tmp11 * tmp12;
        auto tmp14 = tmp6 + tmp13;
        auto tmp15 = at::vec::clamp_min(tmp14, decltype(tmp14)(0));
        auto tmp16 = at::vec::Vectorized<float>(static_cast<float>(33.333333333333336));
        auto tmp17 = tmp15 * tmp16;
        auto tmp18 = tmp17.round();
        auto tmp19 = at::vec::Vectorized<float>(static_cast<float>(3.0));
        auto tmp20 = tmp18 + tmp19;
        auto tmp21 = at::vec::Vectorized<float>(static_cast<float>(0.0));
        auto tmp22 = at::vec::maximum(tmp20, tmp21);
        auto tmp23 = at::vec::Vectorized<float>(static_cast<float>(255.0));
        auto tmp24 = at::vec::minimum(tmp22, tmp23);
        auto tmp25 = (tmp24);
        at::vec::store_float_as_uint8(tmp25, out_ptr0 + static_cast<long>(i1 + (196L*i0) + (196L*i0_inner)));
    }
}
```

After this PR, the generated code is:
```
#pragma GCC ivdep
for(long i1=static_cast<long>(0L); i1<static_cast<long>(192L); i1+=static_cast<long>(16L))
{
    unsigned char tmp0[16*16] __attribute__ ((aligned (16)));
    at::vec::transpose_mxn<unsigned char,16,16>(in_ptr0 + static_cast<long>(i0 + (1024L*i1)), static_cast<long>(1024L), tmp0, 16);
    unsigned char tmp7[16*16] __attribute__ ((aligned (16)));
    at::vec::transpose_mxn<unsigned char,16,16>(in_ptr1 + static_cast<long>(i0 + (1024L*i1)), static_cast<long>(1024L), tmp7, 16);
    for (long i0_inner = 0; i0_inner < 16; i0_inner++)
    {
        auto tmp1 = at::vec::load_uint8_as_float(tmp0 + static_cast<long>(16L*i0_inner));
        auto tmp8 = at::vec::load_uint8_as_float(tmp7 + static_cast<long>(16L*i0_inner));
        auto tmp2 = (tmp1);
        auto tmp3 = at::vec::Vectorized<float>(static_cast<float>(1.0));
        auto tmp4 = tmp2 - tmp3;
        auto tmp5 = at::vec::Vectorized<float>(static_cast<float>(0.01));
        auto tmp6 = tmp4 * tmp5;
        auto tmp9 = (tmp8);
        auto tmp10 = at::vec::Vectorized<float>(static_cast<float>(2.0));
        auto tmp11 = tmp9 - tmp10;
        auto tmp12 = at::vec::Vectorized<float>(static_cast<float>(0.02));
        auto tmp13 = tmp11 * tmp12;
        auto tmp14 = tmp6 + tmp13;
        auto tmp15 = at::vec::clamp_min(tmp14, decltype(tmp14)(0));
        auto tmp16 = at::vec::Vectorized<float>(static_cast<float>(33.333333333333336));
        auto tmp17 = tmp15 * tmp16;
        auto tmp18 = tmp17.round();
        auto tmp19 = at::vec::Vectorized<float>(static_cast<float>(3.0));
        auto tmp20 = tmp18 + tmp19;
        auto tmp21 = at::vec::Vectorized<float>(static_cast<float>(0.0));
        auto tmp22 = at::vec::maximum(tmp20, tmp21);
        auto tmp23 = at::vec::Vectorized<float>(static_cast<float>(255.0));
        auto tmp24 = at::vec::minimum(tmp22, tmp23);
        auto tmp25 = (tmp24);
        at::vec::store_float_as_uint8(tmp25, out_ptr0 + static_cast<long>(i1 + (196L*i0) + (196L*i0_inner)));
    }
}
```

**Test Plan**
```
python -m pytest test_cpu_repro.py -k test_tile2d_load_decomposed_dequant_add_relu_quant
python -m pytest test_cpu_repro.py -k test_tile2d_store_channel_shuffle_cl_quant_output
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104074
Approved by: https://github.com/jgong5, https://github.com/jansel
2023-06-27 00:59:05 +00:00
Antoni Viros i Martin
0d653730ce Refactory bits for the codegen cache (#103452)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103452
Approved by: https://github.com/ezyang
2023-06-22 13:04:22 +00:00
XiaobingSuper
01abccf63f inductor: fix CppTile2D bf16 store complier error for cpp backend (#103659)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103659
Approved by: https://github.com/jgong5, https://github.com/jansel
2023-06-19 00:46:30 +00:00
XiaobingSuper
b287cb816c inductor: make the vec_transpose's tiling stride doesn't depend on out_idx and tiling_idex (#103651)
For the TIMM swin_base_patch4_window7_224 dynamic shape path, there is an accuracy issue with a horizontal reduction combined with vec_transpose:
```
#pragma omp for
for(long i0=static_cast<long>(0L); i0<static_cast<long>(ks0); i0+=static_cast<long>(1L))
{
    #pragma GCC ivdep
    for(long i1=static_cast<long>(0L); i1<static_cast<long>(3136L); i1+=static_cast<long>(16L))
    {
        {
            #pragma omp declare reduction(+:at::vec::Vectorized<float>:omp_out = omp_out + omp_in) initializer(omp_priv={{0}})
            float tmp_acc0 = 0;
            auto tmp_acc0_vec = at::vec::Vectorized<float>(tmp_acc0);
            for(long i2=static_cast<long>(0L); i2<static_cast<long>(128L); i2+=static_cast<long>(16L))
            {
                float tmp1[16*16] __attribute__ ((aligned (16)));
                at::vec::transpose_mxn<float,16,16>(in_ptr1 + static_cast<long>(i2 + (128L*(static_cast<long>((static_cast<long>(i1) % static_cast<long>(56L))) % static_cast<long>(7L))) + (896L*(static_cast<long>(at::native::div_floor_integer(i1, 56L)) % static_cast<long>(7L))) + (6272L*(at::native::div_floor_integer((static_cast<long>(i1) % static_cast<long>(56L)), 7L))) + (50176L*(at::native::div_floor_integer(i1, 392L))) + (401408L*i0)), static_cast<long>(((-50176L)*(at::native::div_floor_integer(i1, 392L))) + ((-6272L)*(at::native::div_floor_integer((static_cast<long>(i1) % static_cast<long>(56L)), 7L))) + ((-896L)*(static_cast<long>(at::native::div_floor_integer(i1, 56L)) % static_cast<long>(7L))) + ((-128L)*(static_cast<long>((static_cast<long>(i1) % static_cast<long>(56L))) % static_cast<long>(7L))) + (128L*(static_cast<long>((static_cast<long>((1L + i1)) % static_cast<long>(56L))) % static_cast<long>(7L))) + (896L*(static_cast<long>(at::native::div_floor_integer((1L + i1), 56L)) % static_cast<long>(7L))) + (6272L*(at::native::div_floor_integer((static_cast<long>((1L + i1)) % static_cast<long>(56L)), 7L))) + (50176L*(at::native::div_floor_integer((1L + i1), 392L)))), tmp1, 16);
                for (long i2_inner = 0; i2_inner < 16; i2_inner++)
                {
                    auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(i1 + (3136L*i2) + (3136L*i2_inner) + (401408L*i0)));
                    auto tmp2 = at::vec::Vectorized<float>::loadu(tmp1 + static_cast<long>(16L*i2_inner));
                    auto tmp3 = tmp0 + tmp2;
                    tmp_acc0_vec = tmp_acc0_vec + tmp3;
                }
            }
            tmp_acc0_vec.store(out_ptr0 + static_cast<long>(i1 + (3136L*i0)));
        }
    }
}
```

The ```transpose_mxn```'s ```ld_src``` depends on ```i1```, which is not expected. This PR adds a check to make sure the tiling stride doesn't depend on the out_idx (```i2```) or the tiling_idex (```i1```).

After this PR, the generated code will be like this:
```
#pragma omp for
for(long i0=static_cast<long>(0L); i0<static_cast<long>(ks0); i0+=static_cast<long>(1L))
{
    #pragma GCC ivdep
    for(long i1=static_cast<long>(0L); i1<static_cast<long>(3136L); i1+=static_cast<long>(16L))
    {
        {
            #pragma omp declare reduction(+:at::vec::Vectorized<float>:omp_out = omp_out + omp_in) initializer(omp_priv={{0}})
            float tmp_acc0 = 0;
            auto tmp_acc0_vec = at::vec::Vectorized<float>(tmp_acc0);
            for(long i2=static_cast<long>(0L); i2<static_cast<long>(128L); i2+=static_cast<long>(16L))
            {
                for (long i2_inner = 0; i2_inner < 16; i2_inner++)
                {
                    auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(i1 + (3136L*i2) + (3136L*i2_inner) + (401408L*i0)));
                    auto tmp1 = ([&]() { __at_align__ float tmpbuf[16]; for (long i1_inner = 0; i1_inner < 16; i1_inner++) tmpbuf[i1_inner] = in_ptr1[static_cast<long>(i2 + i2_inner + (128L*(static_cast<long>((static_cast<long>((i1 + i1_inner)) % static_cast<long>(56L))) % static_cast<long>(7L))) + (896L*(static_cast<long>(at::native::div_floor_integer((i1 + i1_inner), 56L)) % static_cast<long>(7L))) + (6272L*(at::native::div_floor_integer((static_cast<long>((i1 + i1_inner)) % static_cast<long>(56L)), 7L))) + (50176L*(at::native::div_floor_integer((i1 + i1_inner), 392L))) + (401408L*i0))]; return at::vec::Vectorized<float>::loadu(tmpbuf); })();
                    auto tmp2 = tmp0 + tmp1;
                    tmp_acc0_vec = tmp_acc0_vec + tmp2;
                }
            }
            tmp_acc0_vec.store(out_ptr0 + static_cast<long>(i1 + (3136L*i0)));
        }
    }
}
```

How to reproduce this issue:
```
python -m torch.backends.xeon.run_cpu --node_id 0 benchmarks/dynamo/timm_models.py --accuracy --float32 -dcpu --inference -n5 --inductor --dynamic-shapes --only swin_base_patch4_window7_224
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103651
Approved by: https://github.com/jgong5, https://github.com/jansel
2023-06-16 03:56:39 +00:00
XiaobingSuper
da21273ad5 inductor: support rsqrt for dynamic shape (#103579)
Fix compiler error for HF hf_BigBird dynamic shape path.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103579
Approved by: https://github.com/jgong5, https://github.com/jansel
2023-06-15 07:02:18 +00:00
Nikita Shulga
5c252f2c7c [Inductor/cpp] Fix reduction on pre clang-10 (#103347)
`#pragma omp declare reduction` is not supported before clang-10 and results in a misleading compiler error in the following example:
```c++

template<typename T>
T max_propagate_nan(T, T);

extern "C" void cpp_fused_argmax_max_sum_0(const float* in_ptr0,
                       float* out_ptr0,
                       float* out_ptr1,
                       long* out_ptr2)
{
    float tmp_acc0 = 0;
    float tmp_acc1 = -std::numeric_limits<float>::infinity();
    float tmp_acc2 = std::numeric_limits<float>::infinity();
    struct IndexValue_7 {size_t index; float value;};
    IndexValue_7 tmp_acc3{0, -std::numeric_limits<float>::infinity()};
    #pragma omp declare reduction(argmax : IndexValue_7 :                omp_out.value = omp_in.value < omp_out.value ? omp_out.value : omp_in.value,                omp_out.index = omp_in.value < omp_out.value ? omp_out.index : omp_in.index)               initializer(omp_priv = {0, -std::numeric_limits<float>::infinity()})
    for(long i0=static_cast<long>(0L); i0<static_cast<long>(3L); i0+=static_cast<long>(1L))
    {
        auto tmp0 = in_ptr0[static_cast<long>(i0)];
        tmp_acc0 = tmp_acc0 + tmp0;
        tmp_acc1 = max_propagate_nan(tmp_acc1, tmp0);
        if (tmp_acc3.value < tmp0) {
            tmp_acc3.index = i0; tmp_acc3.value = tmp0;
        }
    }
    out_ptr0[static_cast<long>(0L)] = tmp_acc0;
    out_ptr1[static_cast<long>(0L)] = tmp_acc1;
    out_ptr2[static_cast<long>(0L)] = tmp_acc3.index;
}
```

```
% clang++-10 -std=c++17 -fopenmp bar.cpp  -c -O3
% clang++-9 -std=c++17 -fopenmp bar.cpp  -c -O3
bar.cpp:17:149: error: expected ')'
    #pragma omp declare reduction(argmax : IndexValue_7 :                omp_out.value = omp_in.value < omp_out.value ? omp_out.value : omp_in.value,                omp_out.index = omp_in.value < omp_out.value ? omp_out.index : omp_in.index)               initializer(omp_priv = {0, -std::numeric_limits<float>::infinity()})
                                                                                                                                                    ^
bar.cpp:17:34: note: to match this '('
    #pragma omp declare reduction(argmax : IndexValue_7 :                omp_out.value = omp_in.value < omp_out.value ? omp_out.value : omp_in.value,                omp_out.index = omp_in.value < omp_out.value ? omp_out.index : omp_in.index)               initializer(omp_priv = {0, -std::numeric_limits<float>::infinity()})
                                 ^
1 error generated.
```

Also, remove the unnecessary `struct` keyword in front of the type, as the C++ compiler already assumes it (and again, it causes problems with the clang++-10 implementation).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103347
Approved by: https://github.com/voznesenskym
2023-06-10 02:53:37 +00:00
David Berard
cde4657284 [inductor] Support complex fallback for convert_element_type, _fft_c2c, view_as_real to support GoogleFnet with cpp wrapper (#103183)
Fixes #102752

These 3 fallback kernels appear in GoogleFnet because they take complex arguments - i.e., usually they aren't fallback kernels. To support this model, we added support for these 3 ops.

Details:
1. Add these 3 ops to the allowlist. I assume that we eventually want to support all fallback kernels, but for now we just add these 3 ops to the allowlist.
2. Support complex64 in cpp codegen
3. Support List[] arguments and ScalarType arguments in cpp codegen
4. Allow alias_info in schema arguments. In the original PR supporting fallback kernels for cpp wrapper, ops with schemas with non-null alias_info for any of the arguments were disallowed; but I don't think there's any reason we need to disallow these in cpp wrapper code.

Caveats:
* This has not added support for complex32 or complex128
* It only works with static shapes, not dynamic shapes. It seems like the dynamic shapes issue is unrelated to cpp wrapper, since it fails in the test_torchinductor_dynamic_shapes.py test. I checked these `test_fft_.*` tests, which I added in this PR, and verified that they were broken with dynamic shapes before any of the code changes from this PR.
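A hypothetical repro sketch exercising the three ops named above with the cpp wrapper (the shapes are made up, the config flag is assumed to be the switch behind the benchmark's `--cpp-wrapper` option, and GoogleFnet's real graph is much larger):

```python
import torch
import torch._inductor.config as config

config.cpp_wrapper = True  # assumption: same switch toggled by --cpp-wrapper

def fn(x):
    z = x.to(torch.complex64)        # convert_element_type with a complex dtype
    z = torch.fft.fft(z, dim=-1)     # lowers to the _fft_c2c fallback
    return torch.view_as_real(z)     # view_as_real fallback

compiled = torch.compile(fn)
out = compiled(torch.randn(4, 8))
```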

**Test**:

```
benchmarks/dynamo/huggingface.py --inductor --amp --accuracy --inference --device cuda   --cpp-wrapper --only GoogleFnet
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103183
Approved by: https://github.com/desertfire, https://github.com/jgong5, https://github.com/chunyuan-w
2023-06-09 21:12:41 +00:00
XiaobingSuper
8e5b7ce5db inductor: fix bf16 legalization issue for fp32 load with to bf16 case (#103080)
Given the following IR:

```
    def body(self, ops):
        get_index = self.get_index('index0')
        index_expr = ops.index_expr(get_index, torch.int32)
        constant = ops.constant(4, torch.int32)
        lt = ops.lt(index_expr, constant)
        masked_subblock1 = self.masked_subblock1(lt, 0.0)
        get_index_1 = self.get_index('index3')
        load = ops.load('arg2_1', get_index_1)
        to_dtype = ops.to_dtype(load, torch.bfloat16)
        where = ops.where(lt, masked_subblock1, to_dtype)
        get_index_2 = self.get_index('index3')
        store = ops.store('buf0', get_index_2, where, None)
        return store
    def masked_subblock2(self, ops):
        get_index = self.get_index('index2')
        load = ops.load('arg1_1', get_index)
        return load
    def masked_subblock1(self, ops):
        get_index = self.get_index('index1')
        index_expr = ops.index_expr(get_index, torch.int32)
        constant = ops.constant(1, torch.int32)
        ge = ops.ge(index_expr, constant)
        get_index_1 = self.get_index('index1')
        index_expr_1 = ops.index_expr(get_index_1, torch.int32)
        constant_1 = ops.constant(3, torch.int32)
        lt = ops.lt(index_expr_1, constant_1)
        and_ = ops.and_(ge, lt)
        masked_subblock2 = self.masked_subblock2(and_, 0.0)
        get_index_2 = self.get_index('index3')
        load = ops.load('arg2_1', get_index_2)
        to_dtype = ops.to_dtype(load, torch.bfloat16)
        where = ops.where(and_, masked_subblock2, to_dtype)
        return where
```

Before this PR, ```masked_subblock2``` was legalized as ```load_bf16+to_fp32```, so ```masked_subblock2```'s output type is ```fp32```. But for ```load = ops.load('arg2_1', get_index_2); to_dtype = ops.to_dtype(load, torch.bfloat16)```, we didn't convert the ```to_bf16``` into ```to_fp32```, so ```ops.where``` does a mixed-type computation and hits a compiler error: ```error: operands to ?: have different types ‘float’ and ‘c10::BFloat16’```.

This PR always converts ```to_bf16``` to ```to_fp32``` to fix this issue.
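A minimal sketch of the rule (simplified; not the actual legalization pass in the cpp backend): any bfloat16 compute dtype is promoted to float32, and only the final store converts back.

```python
import torch

def legalize_compute_dtype(dtype: torch.dtype) -> torch.dtype:
    # bf16 arithmetic is kept in fp32 so ops like `where` never mix
    # c10::BFloat16 and float operands in the generated C++
    return torch.float32 if dtype == torch.bfloat16 else dtype

assert legalize_compute_dtype(torch.bfloat16) is torch.float32
assert legalize_compute_dtype(torch.float32) is torch.float32
```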

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103080
Approved by: https://github.com/jgong5, https://github.com/desertfire
2023-06-09 00:33:10 +00:00
Yanbo Liang
686d7e4c48 [Inductor] Fix x.view(dtype) decomp and make inductor support it (#102920)
Fixes #99804

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102920
Approved by: https://github.com/jansel, https://github.com/ngimel
2023-06-07 17:10:54 +00:00
haozhe.zhu
adcefcb378 insert to dtype for fused mem copy scheduler node (#101042)
Fix https://github.com/pytorch/pytorch/issues/100830.

For the inplace node, a `copy_` will be generated and the `copy_` will be `realized` as a `scheduler buffer` since it is a mutation. This `scheduler buffer` is a memory copy, but after being fused with the previous buffer it is no longer a memory-copy-only buffer.
This PR solves the issue by removing `load_bf16_as_fp32` and `store_bf16_from_fp32`. Instead, it enables fp32/bf16 vec conversion in `to_dtype`. Then we always store bf16.

```python
import torch
import torch.nn as nn
torch.manual_seed(420)
from torch._inductor import config

x = torch.randn(1, 18, dtype=torch.bfloat16)

class ExampleModel(nn.Module):

    def __init__(self):
        super(ExampleModel, self).__init__()
        self.relu = nn.ReLU(inplace=True) # nn.ReLU(inplace=False)

    def forward(self, input1):
        out = self.relu(input1)
        # input1.copy_(out)
        return out

func = ExampleModel()

with torch.no_grad():
    func.train(False)
    res1 = func(x) # without jit
    print(res1)

    jit_func = torch.compile(func)
    res2 = jit_func(x)
    print(res2)
```

Generated code without this PR (the `tmp3` store is wrong: `tmp3` is `float` while `out_ptr1` is `bf16`):
```
            auto tmp0 = load_bf16_as_float(out_ptr1 + static_cast<long>(i0));
            auto tmp1 = (tmp0);
            auto tmp2 = at::vec::clamp_min(tmp1, decltype(tmp1)(0));
            auto tmp3 = (tmp2);
            store_float_as_bf16(out_ptr0 + static_cast<long>(i0), tmp3);
            tmp3.store(out_ptr1 + static_cast<long>(i0), 16);
```

Generated code with this PR:
```
            auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(out_ptr1 + static_cast<long>(i0), 16);
            auto tmp1 = cvt_bf16_to_fp32(tmp0);
            auto tmp2 = at::vec::clamp_min(tmp1, decltype(tmp1)(0));
            auto tmp3 = cvt_fp32_to_bf16(tmp2);
            tmp3.store(out_ptr0 + static_cast<long>(i0), 16);
            tmp3.store(out_ptr1 + static_cast<long>(i0), 16);
```

This PR also fixes the data type propagation for `masked_subblock`.
Previously, the masked_subblock's dtype was propagated from its input, which is wrong.
```
opcode       name       target     args                        kwargs
-----------  ---------  ---------  --------------------------  --------
call_module  masked_subblock1  masked_subblock1  (and__2, -inf)
```
Now we propagate it from the subblock with the same name:

```
# graph for body.subblocks['masked_subblock1']
opcode       name       target     args                        kwargs
-----------  ---------  ---------  --------------------------  --------
placeholder  ops        ops        ()                          {}
call_module  get_index  get_index  ('index2',)                 {}
call_method  load       load       (ops, 'arg0_1', get_index)  {}
call_method  to_dtype   to_dtype   (ops, load, torch.float32)  {}
output       output     output     (to_dtype,)                 {}
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101042
Approved by: https://github.com/jgong5, https://github.com/jansel
2023-06-07 15:55:25 +00:00
Aleksandar Samardžić
51e0f9e858 Add missing decompositons/lowerings for logical/bitwise operators (#102566)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102566
Approved by: https://github.com/lezcano, https://github.com/alexsio27444, https://github.com/jgong5
2023-06-02 14:27:17 +00:00
XiaobingSuper
1204463bd0 inductor: fix bfloat16 reduction crash issue which store float value to bfloat16 (#102719)
For bfloat16 reductions, there is a wrong-store issue where a float value is stored as bfloat16:

Before:

```

extern "C" void kernel(const bfloat16* in_ptr0,
                       bfloat16* out_ptr0,
                       float* out_ptr1)
{
    #pragma omp parallel num_threads(40)
    {
        {
            #pragma omp for
            for(long i0=static_cast<long>(0L); i0<static_cast<long>(16L); i0+=static_cast<long>(16L))
            {
                {
                    #pragma omp declare reduction(max:at::vec::Vectorized<float>:omp_out = at::vec::maximum(omp_out, omp_in)) initializer(omp_priv={{-std::numeric_limits<float>::infinity()}})
                    float tmp_acc0 = -std::numeric_limits<float>::infinity();
                    auto tmp_acc0_vec = at::vec::Vectorized<float>(tmp_acc0);
                    for(long i1=static_cast<long>(0L); i1<static_cast<long>(32L); i1+=static_cast<long>(1L))
                    {
                        auto tmp0 = load_bf16_as_float(in_ptr0 + static_cast<long>(i0 + (16L*i1)));
                        auto tmp1 = (tmp0);
                        tmp_acc0_vec = at::vec::maximum(tmp_acc0_vec, tmp1);
                    }
                    tmp_acc0_vec.store(out_ptr0 + static_cast<long>(i0));
                }
            }
        }
        #pragma omp single
        {
            {
                for(long i0=static_cast<long>(0L); i0<static_cast<long>(16L); i0+=static_cast<long>(16L))
                {
                    auto tmp0 = load_bf16_as_float(out_ptr0 + static_cast<long>(i0));
                    auto tmp1 = (tmp0);
                    tmp1.store(out_ptr1 + static_cast<long>(i0));
                }
            }
        }
    }
}
''')

```

after:

```
extern "C" void kernel(const bfloat16* in_ptr0,
                       bfloat16* out_ptr0,
                       float* out_ptr1)
{
    #pragma omp parallel num_threads(40)
    {
        {
            #pragma omp for
            for(long i0=static_cast<long>(0L); i0<static_cast<long>(16L); i0+=static_cast<long>(16L))
            {
                {
                    #pragma omp declare reduction(max:at::vec::Vectorized<float>:omp_out = at::vec::maximum(omp_out, omp_in)) initializer(omp_priv={{-std::numeric_limits<float>::infinity()}})
                    float tmp_acc0 = -std::numeric_limits<float>::infinity();
                    auto tmp_acc0_vec = at::vec::Vectorized<float>(tmp_acc0);
                    for(long i1=static_cast<long>(0L); i1<static_cast<long>(32L); i1+=static_cast<long>(1L))
                    {
                        auto tmp0 = load_bf16_as_float(in_ptr0 + static_cast<long>(i0 + (16L*i1)));
                        auto tmp1 = (tmp0);
                        tmp_acc0_vec = at::vec::maximum(tmp_acc0_vec, tmp1);
                    }
                    store_float_as_bf16(out_ptr0 + static_cast<long>(i0), tmp_acc0_vec);
                }
            }
        }
        #pragma omp single
        {
            {
                for(long i0=static_cast<long>(0L); i0<static_cast<long>(16L); i0+=static_cast<long>(16L))
                {
                    auto tmp0 = load_bf16_as_float(out_ptr0 + static_cast<long>(i0));
                    auto tmp1 = (tmp0);
                    tmp1.store(out_ptr1 + static_cast<long>(i0));
                }
            }
        }
    }
}
''')

```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102719
Approved by: https://github.com/jansel, https://github.com/jgong5
2023-06-02 08:34:29 +00:00
Peter Bell
2f96981e5a [inductor] Reduce duplication of reduction combine functions (#99661)
Currently reduction bodies are duplicated in several different places.
This reduces duplication by factoring out the `combine_fn` definition used in
`_unroll_reduction_fn` and using it in the triton codegen. For cpp
this also makes better use of `reduction_combine{,_vec}` by using them
to generate the `omp declare reduction` line and the `vec_reduce_all`
call.
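A minimal sketch of the shared combine-function idea (simplified names; not the actual inductor source):

```python
def combine_fn(reduction_type):
    return {
        "sum":  lambda a, b: a + b,
        "prod": lambda a, b: a * b,
        "max":  lambda a, b: a if a >= b else b,
        "min":  lambda a, b: a if a <= b else b,
    }[reduction_type]

def unroll_reduction(reduction_type, values):
    # the unrolled fallback and the backend codegen can both reuse combine_fn
    acc = values[0]
    for v in values[1:]:
        acc = combine_fn(reduction_type)(acc, v)
    return acc

print(unroll_reduction("max", [3, 1, 4, 1, 5]))  # 5
```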

For triton the only change is that the combine step gets spread
over two lines, e.g. instead of:
```python
_tmp1 = tl.where(rmask & xmask, triton_helpers.maximum(_tmp1, tmp0), _tmp1)
```
we get
```python
tmp2 = triton_helpers.maximum(_tmp1, tmp0)
_tmp1 = tl.where(rmask & xmask, tmp2, _tmp1)
```

For cpp the only change is that inplace reduction operations are now written as
an out-of-place operation and an assignment, e.g. instead of
```cpp
omp_out += omp_in
```
we generate
```cpp
omp_out = omp_out + omp_in
```

This is a purely cosmetic change.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99661
Approved by: https://github.com/lezcano, https://github.com/ngimel
2023-06-01 18:02:17 +00:00
XiaobingSuper
49cd184f89 inductor: improve the index range check for index_expr vec check (#102263)
Fix https://github.com/pytorch/pytorch/issues/102065.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102263
Approved by: https://github.com/lezcano, https://github.com/peterbell10, https://github.com/jgong5
2023-06-01 03:07:14 +00:00
kshitij12345
b1bc8aecf5 [inductor] erfinv: CPU/CUDA lowering (#101863)
Add `erfinv` lowering for CUDA. On CPU, we just fall back to the aten operator.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101863
Approved by: https://github.com/lezcano, https://github.com/peterbell10
2023-05-29 15:31:54 +00:00
Wang, Eikan
ce41faa2ae Add cpp.max_horizontal_fusion_size to control the granularity of horizontal fusion (#99828)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99828
Approved by: https://github.com/jansel, https://github.com/jgong5
2023-05-26 05:20:49 +00:00
Wang, Eikan
6f464e0cf8 Invoke the bf16 load w/o #elements to bypass the temporary buffer allocation from the performance perspective. (#99822)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99822
Approved by: https://github.com/jgong5, https://github.com/jansel
2023-05-26 02:10:41 +00:00
Wang, Eikan
c3550d8376 Add fast path for BF16 kernel if all the operations within the kernel support bf16 (#99814)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99814
Approved by: https://github.com/jgong5, https://github.com/jansel
2023-05-26 02:08:53 +00:00
XiaobingSuper
4882cd0801 inductor: align cpp floordiv with python floordiv for dyanmic shape path (#102068)
This PR does the following things:

- Align the C++ behavior with Python for FloorDiv.
- Always return the expr dtype for ops that don't use the expr's dtype to do the computation.

After this PR, the TIMM ```levit_128``` and ```volo_d1_224``` accuracy tests pass for the dynamic shape path.
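A small sketch of the first point (illustrative only): Python's ```//``` floors, while C++'s integer ```/``` truncates toward zero, so the two disagree for negative operands unless the generated C++ compensates.

```python
def cpp_trunc_div(a: int, b: int) -> int:
    # mimics C++ integer division, which truncates toward zero
    q = abs(a) // abs(b)
    return q if (a >= 0) == (b >= 0) else -q

print(-7 // 2)               # -4  (Python floor division)
print(cpp_trunc_div(-7, 2))  # -3  (C++-style truncation)
```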

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102068
Approved by: https://github.com/jgong5, https://github.com/ngimel
2023-05-25 10:18:45 +00:00
Bin Bao
431344f2d0 [inductor] Refactor generate_kernel_call (#102018)
Summary: Refactor generate_kernel_call to support codegen call to Triton
kernel

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102018
Approved by: https://github.com/jansel, https://github.com/jgong5
2023-05-23 15:54:49 +00:00
Jason Ansel
0c6f409cda [inductor] Refactor RNG operators (#100064)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100064
Approved by: https://github.com/ngimel
2023-05-20 03:43:33 +00:00
XiaobingSuper
350f0cd78c inductor: fix bfloat16 store complier issue (#101856)
Fix the bfloat16 compiler error:
```
/tmp/torchinductor_xiaobing/ez/cezrraw7rtu5vkxcfd544i53crqaobycprf5twyvf7b62jrgi75p.cpp: In function ‘void kernel(const bfloat16*, bfloat16*)’:
/tmp/torchinductor_xiaobing/ez/cezrraw7rtu5vkxcfd544i53crqaobycprf5twyvf7b62jrgi75p.cpp:20:79: error: expected ‘;’ before ‘}’ token
   20 |                         tmp0.store(tmp1 + static_cast<long>(16L*i1_inner), 16)
      |                                                                               ^
      |                                                                               ;
   21 |                     }

```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101856
Approved by: https://github.com/jgong5, https://github.com/EikanWang, https://github.com/desertfire
2023-05-20 01:41:41 +00:00
XiaobingSuper
61b6b038b0 inductor: fix FloorDiv issue for dynamic shape path (#101793)
For the TIMM ```tf_mixnet_l``` CPU dynamic shape path, we always get a wrong result compared with eager mode; the root cause is that we compute a wrong index when doing vectorization:

```
for(long i2=static_cast<long>(0L); i2<static_cast<long>(16L*(((std::ceil((1.0/2.0)*(std::ceil((1.0/2.0)*(std::ceil((1.0/2.0)*(std::ceil((1.0/2.0)*ks1))))))))*(std::ceil((1.0/2.0)*(std::ceil((1.0/2.0)*(std::ceil((1.0/2.0)*(std::ceil((1.0/2.0)*ks1))))))))) / 16L)); i2+=static_cast<long>(16L))
```
The main loop's index uses ```/``` rather than ```//```. After this PR, the ```tf_mixnet_l``` accuracy test passes.

How to reproduce this issue?

```
python -m torch.backends.xeon.run_cpu --node_id 0 benchmarks/dynamo/timm_models.py --accuracy --float32 -dcpu --inference -n5 --inductor --dynamic-shapes --only tf_mixnet_l
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101793
Approved by: https://github.com/jgong5, https://github.com/EikanWang, https://github.com/ezyang
2023-05-19 12:39:27 +00:00
Peter Bell
66e398951a [inductor/decomp] Add aten._unsafe_index to disable range checks (#101602)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101602
Approved by: https://github.com/lezcano, https://github.com/ngimel
2023-05-17 23:36:24 +00:00
Peter Bell
b256091c7b [inductor] Generate indirect_indexing checks even if optimized out (#100895)
Fixes #100831, fixes #100878

Previously `gen_assert_indirect_indexing` was only called on the index
expressions passed to `ops.load` and `ops.store` which means if the
variable is optimized out during lowering, we never generate the
assert. This instead makes `ops.indirect_indexing` eagerly generate
the assert statement, whether or not it will be used.
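A hypothetical repro sketch (an assumption loosely based on the linked issues): the indirect load is optimized away during lowering, but the out-of-range index should still trip the generated assert.

```python
import torch

def fn(x, idx):
    return x[idx] * 0  # the gather result is dead, but idx is still out of range

compiled = torch.compile(fn)
compiled(torch.randn(8), torch.tensor([100]))  # expected to raise an index assert
```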

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100895
Approved by: https://github.com/lezcano, https://github.com/ngimel
2023-05-17 23:36:24 +00:00
PyTorch MergeBot
5f07c589b0 Revert "[inductor] Refactor RNG operators (#100064)"
This reverts commit 3bbf0683a1.

Reverted https://github.com/pytorch/pytorch/pull/100064 on behalf of https://github.com/izaitsevfb due to breaks inductor tests, see D45936056 ([comment](https://github.com/pytorch/pytorch/pull/100064#issuecomment-1552093728))
2023-05-17 21:16:41 +00:00
Liao, Xuan
6261aa5c8d [inductor][cpp] support non contiguous vectorization codegen (#99966)
Currently, cpp vectorization is supported only when the node has at least one contiguous index. This PR enables cpp vectorization when all indices in the node are non-contiguous. Specifically, the innermost index is selected as the tiling index.

### Validation
For the E2E performance and functionality, both inference and training model suites for data type float32 and bfloat16 are validated. All the results show that there is no performance regression and no new failures compared with baseline.

### Code
The modification could help certain kernels in GPT-J do vectorization. Here is a snippet of output code change.

**Before**
```
{
        #pragma GCC ivdep
        for(long i0=static_cast<long>(0L); i0<static_cast<long>(16L*ks0); i0+=static_cast<long>(1L))
        {
            #pragma GCC ivdep
            for(long i1=static_cast<long>(0L); i1<static_cast<long>(32L); i1+=static_cast<long>(1L))
            {
                auto tmp0 = in_ptr0[static_cast<long>(1L + (2L*i1) + (256L*i0))];
                auto tmp1 = static_cast<float>(tmp0);
                auto tmp2 = decltype(tmp1)(-tmp1);
                auto tmp3 = static_cast<bfloat16>(tmp2);
                out_ptr0[static_cast<long>((2L*i1) + (64L*i0))] = tmp3;
            }
        }
    }
    {
        #pragma GCC ivdep
        for(long i0=static_cast<long>(0L); i0<static_cast<long>(16L*ks0); i0+=static_cast<long>(1L))
        {
            #pragma GCC ivdep
            for(long i1=static_cast<long>(0L); i1<static_cast<long>(32L); i1+=static_cast<long>(1L))
            {
                auto tmp0 = in_ptr0[static_cast<long>((2L*i1) + (256L*i0))];
                out_ptr1[static_cast<long>((2L*i1) + (64L*i0))] = tmp0;
            }
        }
    }
```
**After**
```
{
        #pragma GCC ivdep
        for(long i0=static_cast<long>(0L); i0<static_cast<long>(16L*ks0); i0+=static_cast<long>(1L))
        {
            for(long i1=static_cast<long>(0L); i1<static_cast<long>(32L); i1+=static_cast<long>(16L))
            {
                auto tmp0 = ([&]() { __at_align__ bfloat16 tmpbuf[16 * 2]; for (long i1_inner = 0; i1_inner < 16; i1_inner++) tmpbuf[i1_inner] = in_ptr0[static_cast<long>(1L + (2L*i1_inner) + (2L*i1) + (256L*i0))]; return load_bf16_as_float(tmpbuf); })();
                auto tmp1 = (tmp0);
                auto tmp2 = tmp1.neg();
                auto tmp3 = (tmp2);
                { __at_align__ bfloat16 tmpbuf[16*sizeof(float)/sizeof(bfloat16)]; store_float_as_bf16(tmpbuf, tmp3); for (long i1_inner = 0; i1_inner < 16; i1_inner++) out_ptr0[static_cast<long>((2L*i1_inner) + (2L*i1) + (64L*i0))] = tmpbuf[i1_inner]; }
            }
        }
    }
    {
        #pragma GCC ivdep
        for(long i0=static_cast<long>(0L); i0<static_cast<long>(16L*ks0); i0+=static_cast<long>(1L))
        {
            for(long i1=static_cast<long>(0L); i1<static_cast<long>(32L); i1+=static_cast<long>(16L))
            {
                auto tmp0 = ([&]() { __at_align__ bfloat16 tmpbuf[16 * 2]; for (long i1_inner = 0; i1_inner < 16; i1_inner++) tmpbuf[i1_inner] = in_ptr0[static_cast<long>((2L*i1_inner) + (2L*i1) + (256L*i0))]; return at::vec::Vectorized<bfloat16>::loadu(tmpbuf, 16); })();
                { __at_align__ bfloat16 tmpbuf[16*sizeof(float)/sizeof(bfloat16)]; tmp0.store(tmpbuf, 16); for (long i1_inner = 0; i1_inner < 16; i1_inner++) out_ptr1[static_cast<long>((2L*i1_inner) + (2L*i1) + (64L*i0))] = tmpbuf[i1_inner]; }
            }
        }
    }
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99966
Approved by: https://github.com/jgong5, https://github.com/jansel
2023-05-17 05:19:22 +00:00
Jason Ansel
3bbf0683a1 [inductor] Refactor RNG operators (#100064)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100064
Approved by: https://github.com/ngimel
2023-05-17 01:29:31 +00:00
kshitij12345
2b2a717f19 [inductor] erfc: lowering (#101416)
Codegen support was already present. This PR just removes the fallback.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101416
Approved by: https://github.com/lezcano
2023-05-16 14:31:13 +00:00
XiaobingSuper
88b6a4577b inductor: fix sign gets wrong result dtype issue (#101377)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101377
Approved by: https://github.com/EikanWang, https://github.com/jgong5, https://github.com/jansel
2023-05-16 08:01:06 +00:00
Jiong Gong
6f7ebcdcd8 [inductor] enable descriptive name for cpp kernels (#101330)
This PR enables descriptive names for cpp kernels, similar to the triton kernel names. A new configuration `config.cpp.descriptive_names` is added, similar to that of triton. The kernel name follows the format `cpp_<fused_name>_<id>`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101330
Approved by: https://github.com/XiaobingSuper, https://github.com/jansel
2023-05-16 06:48:11 +00:00
chunyuan
cc54da4877 Inductor cpp wrapper: fix FallbackKernel support (#100788)
Fixes cpp wrapper support for kernels that are not exposed in `torch.ops.aten`. The current PR limits the support scope to `repeat_interleave.Tensor`; follow-up PRs will extend support to more OPs.

The PR maps the python schema of the kernel to the cpp schema and uses `c10::Dispatcher::singleton().findSchemaOrThrow` to find the corresponding cpp OP.

The current support is limited and will raise an `AssertionError` for unsupported cases.
The limitations include:
- only kernels that are not aliases are supported
- only kernels whose args and returns don't have `alias_info` are supported
- output args must be a `Tensor`
- input args must be `Tensor`, `Optional[int]`, `Optional[float]` or `Optional[bool]`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100788
Approved by: https://github.com/jgong5, https://github.com/desertfire
2023-05-15 00:45:44 +00:00
XiaobingSuper
b1a8a10a73 inductor(CPU): fix masked_fill issue when filled value is nan (#101058)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101058
Approved by: https://github.com/jgong5, https://github.com/desertfire
2023-05-11 00:57:04 +00:00
Jiong Gong
b33c9c7c9f [inductor] support vec type conversion between float and bool (#100950)
Fix https://github.com/pytorch/pytorch/issues/100466
Fix https://github.com/pytorch/pytorch/issues/100800

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100950
Approved by: https://github.com/EikanWang, https://github.com/jansel
2023-05-10 02:16:06 +00:00
Peter Bell
4918940184 [inductor] Fix nan-handling of max and min reductions (#100572)
This adds helpers that replace triton's `minimum`, `maximum`, `min` and
`max` with the correct NaN propagation. I also removed
`ops.int_minimum` in favor of `ops.minimum` because we can just omit
the nan-checks by checking the dtype.
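A scalar sketch of the NaN-propagating behavior (illustrative only; the real helpers are written in terms of triton ops):

```python
import math

def nan_propagating_maximum(a: float, b: float) -> float:
    # propagate NaN from either input instead of silently dropping it
    if math.isnan(a) or math.isnan(b):
        return float("nan")
    return a if a > b else b

print(nan_propagating_maximum(1.0, float("nan")))  # nan
print(max(1.0, float("nan")))                      # 1.0 -- NaN is lost
```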

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100572
Approved by: https://github.com/ngimel
2023-05-04 13:07:27 +00:00
Edward Z. Yang
c7e9f40653 Misc accuracy improvements on minifier (#100447)
The changes:

* Add config knob `same_two_models_use_fp64` for toggling whether or not to use fp64
* Add a test showing that RMSE is superior to atol/rtol (see the sketch after this list)
* Add a `--strict-accuracy` option, which allows testing against integral/boolean accuracy; by default only regular (floating-point) accuracy is checked now. There's a test which exercises this; it's a little delicate, but I had trouble thinking of a good test otherwise.
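A sketch of the RMSE-vs-atol/rtol point above (illustrative only, not the test added in this PR): a small amount of broadly distributed noise fails a strict elementwise check but keeps the RMSE tiny.

```python
import torch

ref = torch.randn(10_000, dtype=torch.float64)
res = ref + torch.randn(10_000, dtype=torch.float64) * 1e-3  # mild, benign noise

print(torch.allclose(res, ref, atol=1e-4, rtol=1e-4))   # False: elementwise check trips
rmse = ((res - ref) ** 2).mean().sqrt()
print(bool(rmse < 1e-2))                                 # True: RMSE stays small
```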

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100447
Approved by: https://github.com/voznesenskym
2023-05-04 02:51:26 +00:00
Edward Z. Yang
db4572dbf1
Revert tl.reduce usage (#100521)
Test Plan: sandcastle

Reviewed By: bertmaher

Differential Revision: D45513572

fbshipit-source-id: a03df851503f72313dfb50238e7d6db9239bf42e
2023-05-03 12:20:33 -04:00
Edward Z. Yang
0a479d9b9c Simplify minifier testing by incorporating fault injection in prod code (#100357)
Previously, minifier testing injected faults by injecting extra code
into the repro scripts, and then ensuring this code got propagated to
all subsequent subprocess calls.  This was not only quite complicated,
but also induced a big slowdown on the minifier, because to inject the
faults, you had to import torch._inductor, which would cause the
compilation threads to immediately get initialized before you even got
to do anything else in the repro script.

This new approach fixes this problem by incorporating the fault
injection into "prod" code.  Essentially, for inductor fault injection
we introduce some new config flags that let you "configure" Inductor to
be buggy; for Dynamo fault injection we just permanently keep the buggy
testing backends registered.  This is MUCH simpler: we only have to
propagate the buggy config (which is something we're already doing),
and it saves the minifier scripts from having to immediately initialize
inductor on entry.

Also, I enable the test for Triton runtime errors, now that tl.assert_device is here.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100357
Approved by: https://github.com/voznesenskym
2023-05-02 11:44:06 +00:00
XiaobingSuper
76bcc87277 fix TIMM mobilevit_s complier issue for dynamic CPU path (#100230)
For the TIMM ```mobilevit_s``` dynamic path, there is a compiler issue (reproduced with ```python -m torch.backends.xeon.run_cpu --node_id 0 benchmarks/dynamo/timm_models.py --performance --float32 -dcpu -n2 --inductor --no-skip --dashboard --only mobilevit_s --inference --dynamic-shapes```):

```
/tmp/torchinductor_xiaobing/xy/cxyslqzcsxkco4ieph7t63kn5q74ka35ak75lwfon32nlalxmru5.cpp:29:130: error: invalid operands of types ‘long int’ and ‘double’ to binary ‘operator%’
   29 |                             auto tmp0 = in_ptr0[static_cast<long>((((((-1L) + ks1) / 8L)*(((-1L) + ks1) / 8L))*((((2L*((i2 / 1L) % (std::ceil((1.0/2.0) + ((1.0/2.0)*(((-1L) + ks1)
```

There is a modulo of ```long % double```; this PR converts the inputs to long before doing this operation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100230
Approved by: https://github.com/jansel
2023-04-29 12:05:47 +00:00
Peter Bell
f9c3fcd1df [inductor] Fix nan-handling of max and min reductions (#99881)
This adds helpers that replace triton's `minimum`, `maximum`, `min` and
`max` with the correct NaN propagation. I also removed
`ops.int_minimum` in favor of `ops.minimum` because we can just omit
the nan-checks by checking the dtype.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99881
Approved by: https://github.com/ngimel
2023-04-27 15:10:50 +00:00
Edward Z. Yang
3a5427baf4 Add torch.utils._content_store (#99809)
Implements a simple content-addressable store for storages (with tensors implemented as cheap references on top), enabling incremental serialization of tensors to disk, which I intend to use in the accuracy repro extractor.  Check the comment at the top of torch/utils/_content_store.py for more details on the intended use case.

One major piece of this PR is implementing the content hash for tensors. For our prospective use case, we may need to repeatedly hash up to 80 GB of tensor data every time we snapshot (and we may snapshot multiple times). Using a conventional cryptographic hash and hashing each snapshot would likely take on the order of minutes, which seemed too slow to me. So instead, I implemented a crappy hash function that can be run on GPU. It is at least somewhat theoretically grounded: using random parameters generated by Philox, we use the standard shift-multiply and xor-sum universal hash family. The hash function is a bit dorky though; instead of properly doing 160-bit math, it just runs a 32-bit hash five times and cats the results together. By the way, this sets the first precedent for a kernel in the PyTorch library which MUST be torch.compile'd to be run (in fact, this kernel does not run in eager mode because of the use of xor_sum, which doesn't actually exist in ATen.)
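A toy sketch of the shift-multiply / xor-sum idea (illustrative only; the real implementation runs as a compiled GPU kernel with Philox-generated parameters and a dedicated `xor_sum` prim):

```python
import torch

def weak_content_hash(data: torch.Tensor, seed: int = 0) -> list[int]:
    """Five 32-bit shift-multiply lanes, each reduced with an xor-sum."""
    g = torch.Generator().manual_seed(seed)
    x = data.contiguous().view(torch.int32).to(torch.int64)
    lanes = []
    for _ in range(5):
        a = int(torch.randint(1, 2**31 - 1, (1,), generator=g)) * 2 + 1  # random odd multiplier
        h = ((x * a) >> 16) & 0xFFFFFFFF
        acc = 0
        for v in h.tolist():   # plain Python xor-sum; the real kernel reduces on device
            acc ^= v
        lanes.append(acc)
    return lanes

print(weak_content_hash(torch.randn(16)))
```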

I had to add a few more primitives to inductor, namely randint (over the entire int range) and xor_sum.  Fortunately, these primitives are natively supported by Triton/C++, and so they were very easy to plumb through.  xor_sum is exposed as a prim, while randint special cases on when low/high span the entire 32-bit signed integer range.

Thanks to Jeff Johnson for letting me bounce ideas off him on a Saturday morning lol.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99809
Approved by: https://github.com/voznesenskym
2023-04-26 18:02:59 +00:00