Commit Graph

191 Commits

Author SHA1 Message Date
Edward Z. Yang
6f70d22277 Extend torch.utils._sympy.symbol for more Inductor symbols (#125419)
I'm still missing a few, cdzq at least

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125419
Approved by: https://github.com/lezcano
ghstack dependencies: #125395
2024-05-04 09:05:00 +00:00
Edward Z. Yang
5503c29357 Introduce torch.utils._sympy.symbol (#125395)
This provides utilities for creating and querying properties on
sympy.Symbol.  I want to use this refactor to get a better handle on how
the 's' prefix is being used in Inductor.  To start, I only do
symbolic_shapes code because that's what I'm familiar with.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125395
Approved by: https://github.com/Skylion007
2024-05-03 21:24:23 +00:00
Edward Z. Yang
dae574c713 Don't make replacements for i variables (#125398)
This was introduced in https://github.com/pytorch/pytorch/pull/110262
but actually it looks like they were trying to hit unbacked SymInt.
Now that unbacked SymInt is renamed to u, this code is no longer
necessary

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125398
Approved by: https://github.com/lezcano, https://github.com/Skylion007
2024-05-02 20:38:09 +00:00
haozhe.zhu
c5b1a4c269 [inductor] share more cse cache during swap buffer (#124921)
`swap_buffer` will make the `cse_cache` cannot be shared inside/outside of the lambda function scope.
For example,

```
auto tmp8 = -std::numeric_limits<float>::infinity();
auto tmp9 = [&]
{
    auto tmp12 = -std::numeric_limits<float>::infinity();
    return tmp12;
}
```
`tmp12` should not be created since it is same with `tmp8`.

We make the `cse_cache` as a read only cache inside the scope (because it is unsafe to expose cache inside the scope,the outside scope cannot use it.)

**Test Plan**
```
python test/inductor/test_torchinductor.py -k test_AllenaiLongformerBase_repro_cpu
```
the `static_cast<int>(256)` will only occur once after this PR since the inside scope can share the cse buffer outside the scope.

Before this PR,
```
cpp_fused_copy_full_like_0 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_root/ub/cub6x5nmhqhp7xapkb3dlgjxef3t2bnkx7y7n4z2f4z5obnecxpy.h"
extern "C" void kernel(const float* in_ptr0,
                       float* out_ptr1)
{
    #pragma omp parallel num_threads(128)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for collapse(2)
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(1L))
                {
                    #pragma GCC ivdep
                    for(long x2=static_cast<long>(0L); x2<static_cast<long>(12L); x2+=static_cast<long>(1L))
                    {
                        for(long x3=static_cast<long>(0L); x3<static_cast<long>(512L); x3+=static_cast<long>(16L))
                        {
                            auto tmp0 = c10::convert<int>(x1);
                            auto tmp1 = static_cast<int>(256);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 = c10::convert<int>(x3);
                                auto tmp5 = at::vec::Vectorized<int>::arange(tmp4, 1);
                                auto tmp6 = static_cast<int>(257);
                                auto tmp7 = at::vec::Vectorized<int>(tmp6);
                                auto tmp8 = at::vec::VecMask<int,1>(tmp5 < tmp7);
                                auto tmp10 = at::vec::VecMask<float,1>::from(tmp2);
                                auto tmp11 = tmp8 & tmp10;
                                auto tmp9 = [&]
                                {
                                    auto tmp12 = -std::numeric_limits<float>::infinity();
                                    return tmp12;
                                }
                                ;
                                auto tmp13 =
                                [&]
                                {
                                    if (tmp11.all_zero())
                                    {
                                        return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                    }
                                    else
                                    {
                                        return decltype(at::vec::Vectorized<float>(tmp9()))::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), at::vec::Vectorized<float>(tmp9()), tmp11.template cast<float,1>());
                                    }
                                }
                                ()
                                ;
                                auto tmp14 = c10::convert<int>(c10::div_floor_integer(x1, 256L));
                                auto tmp15 = static_cast<int>(3);
                                auto tmp16 = tmp14 < tmp15;
                                auto tmp18 = tmp16 & tmp2;
                                auto tmp17 = [&]
                                {
                                    auto tmp19 = c10::convert<int>(x3);
                                    auto tmp20 = at::vec::Vectorized<int>::arange(tmp19, 1);
                                    auto tmp21 = static_cast<int>(256);
                                    auto tmp22 = at::vec::Vectorized<int>(tmp21);
                                    auto tmp23 = at::vec::VecMask<int,1>(tmp20 >= tmp22);
                                    auto tmp25 = at::vec::VecMask<float,1>::from(tmp18);
                                    auto tmp26 = tmp23 & tmp25;
                                    auto tmp24 = [&]
                                    {
                                        auto tmp27 = tmp26.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0)));
                                        return tmp27;
                                    }
                                    ;
                                    auto tmp28 =
                                    [&]
                                    {
                                        if (tmp26.all_zero())
                                        {
                                            return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                        }
                                        else
                                        {
                                            return decltype(tmp24())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp24(), tmp26.template cast<float,1>());
                                        }
                                    }
                                    ()
                                    ;
                                    auto tmp29 = static_cast<float>(0.0);
                                    auto tmp30 = at::vec::Vectorized<float>(tmp29);
                                    auto tmp31 = decltype(tmp28)::blendv(tmp30, tmp28, tmp23.template cast<float,1>());
                                    return tmp31;
                                }
                                ;
                                auto tmp32 = tmp16 ? tmp17() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                                auto tmp33 = static_cast<float>(0.0);
                                auto tmp34 = at::vec::VecMask<float,1>::from(tmp16);
                                auto tmp35 = at::vec::Vectorized<float>(tmp33);
                                auto tmp36 = decltype(tmp32)::blendv(tmp35, tmp32, tmp34.template cast<float,1>());
                                auto tmp37 = decltype(tmp13)::blendv(tmp36, tmp13, tmp8.template cast<float,1>());
                                return tmp37;
                            }
                            ;
                            auto tmp38 = tmp2 ? tmp3() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                            auto tmp39 = c10::convert<int>(c10::div_floor_integer(x1, 256L));
                            auto tmp40 = static_cast<int>(3);
                            auto tmp41 = tmp39 < tmp40;
                            auto tmp42 = [&]
                            {
                                auto tmp43 = c10::convert<int>(x3);
                                auto tmp44 = at::vec::Vectorized<int>::arange(tmp43, 1);
                                auto tmp45 = static_cast<int>(256);
                                auto tmp46 = at::vec::Vectorized<int>(tmp45);
                                auto tmp47 = at::vec::VecMask<int,1>(tmp44 >= tmp46);
                                auto tmp49 = at::vec::VecMask<float,1>::from(tmp41);
                                auto tmp50 = tmp47 & tmp49;
                                auto tmp48 = [&]
                                {
                                    auto tmp51 = tmp50.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0)));
                                    return tmp51;
                                }
                                ;
                                auto tmp52 =
                                [&]
                                {
                                    if (tmp50.all_zero())
                                    {
                                        return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                    }
                                    else
                                    {
                                        return decltype(tmp48())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp48(), tmp50.template cast<float,1>());
                                    }
                                }
                                ()
                                ;
                                auto tmp53 = static_cast<float>(0.0);
                                auto tmp54 = at::vec::Vectorized<float>(tmp53);
                                auto tmp55 = decltype(tmp52)::blendv(tmp54, tmp52, tmp47.template cast<float,1>());
                                return tmp55;
                            }
                            ;
                            auto tmp56 = tmp41 ? tmp42() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                            auto tmp57 = static_cast<float>(0.0);
                            auto tmp58 = at::vec::VecMask<float,1>::from(tmp41);
                            auto tmp59 = at::vec::Vectorized<float>(tmp57);
                            auto tmp60 = decltype(tmp56)::blendv(tmp59, tmp56, tmp58.template cast<float,1>());
                            auto tmp61 = at::vec::VecMask<float,1>::from(tmp2);
                            auto tmp62 = decltype(tmp38)::blendv(tmp60, tmp38, tmp61.template cast<float,1>());
                            tmp62.store(out_ptr1 + static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0)));
                        }
                        #pragma omp simd simdlen(8)
                        for(long x3=static_cast<long>(512L); x3<static_cast<long>(513L); x3+=static_cast<long>(1L))
                        {
                            auto tmp0 = c10::convert<int64_t>(x1);
                            auto tmp1 = static_cast<int64_t>(256);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 = c10::convert<int64_t>(x3);
                                auto tmp5 = static_cast<int64_t>(257);
                                auto tmp6 = tmp4 < tmp5;
                                auto tmp7 = [&]
                                {
                                    auto tmp8 = -std::numeric_limits<float>::infinity();
                                    return tmp8;
                                }
                                ;
                                auto tmp9 = tmp6 ? tmp7() : static_cast<decltype(tmp7())>(0.0);
                                auto tmp10 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L));
                                auto tmp11 = static_cast<int64_t>(3);
                                auto tmp12 = tmp10 < tmp11;
                                auto tmp13 = [&]
                                {
                                    auto tmp14 = c10::convert<int64_t>(x3);
                                    auto tmp15 = static_cast<int64_t>(256);
                                    auto tmp16 = tmp14 >= tmp15;
                                    auto tmp17 = [&]
                                    {
                                        auto tmp18 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))];
                                        return tmp18;
                                    }
                                    ;
                                    auto tmp19 = tmp16 ? tmp17() : static_cast<decltype(tmp17())>(0.0);
                                    auto tmp20 = static_cast<float>(0.0);
                                    auto tmp21 = tmp16 ? tmp19 : tmp20;
                                    return tmp21;
                                }
                                ;
                                auto tmp22 = tmp12 ? tmp13() : static_cast<decltype(tmp13())>(0.0);
                                auto tmp23 = static_cast<float>(0.0);
                                auto tmp24 = tmp12 ? tmp22 : tmp23;
                                auto tmp25 = tmp6 ? tmp9 : tmp24;
                                return tmp25;
                            }
                            ;
                            auto tmp26 = tmp2 ? tmp3() : static_cast<decltype(tmp3())>(0.0);
                            auto tmp27 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L));
                            auto tmp28 = static_cast<int64_t>(3);
                            auto tmp29 = tmp27 < tmp28;
                            auto tmp30 = [&]
                            {
                                auto tmp31 = c10::convert<int64_t>(x3);
                                auto tmp32 = static_cast<int64_t>(256);
                                auto tmp33 = tmp31 >= tmp32;
                                auto tmp34 = [&]
                                {
                                    auto tmp35 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))];
                                    return tmp35;
                                }
                                ;
                                auto tmp36 = tmp33 ? tmp34() : static_cast<decltype(tmp34())>(0.0);
                                auto tmp37 = static_cast<float>(0.0);
                                auto tmp38 = tmp33 ? tmp36 : tmp37;
                                return tmp38;
                            }
                            ;
                            auto tmp39 = tmp29 ? tmp30() : static_cast<decltype(tmp30())>(0.0);
                            auto tmp40 = static_cast<float>(0.0);
                            auto tmp41 = tmp29 ? tmp39 : tmp40;
                            auto tmp42 = tmp2 ? tmp26 : tmp41;
                            out_ptr1[static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0))] = tmp42;
                        }
                    }
                }
            }
        }
    }
}
''')
```
After this PR,
```
cpp_fused_copy_full_like_0 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_root/ub/cub6x5nmhqhp7xapkb3dlgjxef3t2bnkx7y7n4z2f4z5obnecxpy.h"
extern "C" void kernel(const float* in_ptr0,
                       float* out_ptr1)
{
    #pragma omp parallel num_threads(128)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for collapse(2)
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(1L))
                {
                    #pragma GCC ivdep
                    for(long x2=static_cast<long>(0L); x2<static_cast<long>(12L); x2+=static_cast<long>(1L))
                    {
                        for(long x3=static_cast<long>(0L); x3<static_cast<long>(512L); x3+=static_cast<long>(16L))
                        {
                            auto tmp0 = c10::convert<int>(x1);
                            auto tmp1 = static_cast<int>(256);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 = c10::convert<int>(x3);
                                auto tmp5 = at::vec::Vectorized<int>::arange(tmp4, 1);
                                auto tmp6 = static_cast<int>(257);
                                auto tmp7 = at::vec::Vectorized<int>(tmp6);
                                auto tmp8 = at::vec::VecMask<int,1>(tmp5 < tmp7);
                                auto tmp10 = at::vec::VecMask<float,1>::from(tmp2);
                                auto tmp11 = tmp8 & tmp10;
                                auto tmp9 = [&]
                                {
                                    auto tmp12 = -std::numeric_limits<float>::infinity();
                                    return tmp12;
                                }
                                ;
                                auto tmp13 =
                                [&]
                                {
                                    if (tmp11.all_zero())
                                    {
                                        return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                    }
                                    else
                                    {
                                        return decltype(at::vec::Vectorized<float>(tmp9()))::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), at::vec::Vectorized<float>(tmp9()), tmp11.template cast<float,1>());
                                    }
                                }
                                ()
                                ;
                                auto tmp14 = c10::convert<int>(c10::div_floor_integer(x1, 256L));
                                auto tmp15 = static_cast<int>(3);
                                auto tmp16 = tmp14 < tmp15;
                                auto tmp18 = tmp16 & tmp2;
                                auto tmp17 = [&]
                                {
                                    auto tmp19 = at::vec::Vectorized<int>(tmp1);
                                    auto tmp20 = at::vec::VecMask<int,1>(tmp5 >= tmp19);
                                    auto tmp22 = at::vec::VecMask<float,1>::from(tmp18);
                                    auto tmp23 = tmp20 & tmp22;
                                    auto tmp21 = [&]
                                    {
                                        auto tmp24 = tmp23.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0)));
                                        return tmp24;
                                    }
                                    ;
                                    auto tmp25 =
                                    [&]
                                    {
                                        if (tmp23.all_zero())
                                        {
                                            return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                        }
                                        else
                                        {
                                            return decltype(tmp21())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp21(), tmp23.template cast<float,1>());
                                        }
                                    }
                                    ()
                                    ;
                                    auto tmp26 = static_cast<float>(0.0);
                                    auto tmp27 = at::vec::Vectorized<float>(tmp26);
                                    auto tmp28 = decltype(tmp25)::blendv(tmp27, tmp25, tmp20.template cast<float,1>());
                                    return tmp28;
                                }
                                ;
                                auto tmp29 = tmp16 ? tmp17() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                                auto tmp30 = static_cast<float>(0.0);
                                auto tmp31 = at::vec::VecMask<float,1>::from(tmp16);
                                auto tmp32 = at::vec::Vectorized<float>(tmp30);
                                auto tmp33 = decltype(tmp29)::blendv(tmp32, tmp29, tmp31.template cast<float,1>());
                                auto tmp34 = decltype(tmp13)::blendv(tmp33, tmp13, tmp8.template cast<float,1>());
                                return tmp34;
                            }
                            ;
                            auto tmp35 = tmp2 ? tmp3() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                            auto tmp36 = c10::convert<int>(c10::div_floor_integer(x1, 256L));
                            auto tmp37 = static_cast<int>(3);
                            auto tmp38 = tmp36 < tmp37;
                            auto tmp39 = [&]
                            {
                                auto tmp40 = c10::convert<int>(x3);
                                auto tmp41 = at::vec::Vectorized<int>::arange(tmp40, 1);
                                auto tmp42 = at::vec::Vectorized<int>(tmp1);
                                auto tmp43 = at::vec::VecMask<int,1>(tmp41 >= tmp42);
                                auto tmp45 = at::vec::VecMask<float,1>::from(tmp38);
                                auto tmp46 = tmp43 & tmp45;
                                auto tmp44 = [&]
                                {
                                    auto tmp47 = tmp46.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0)));
                                    return tmp47;
                                }
                                ;
                                auto tmp48 =
                                [&]
                                {
                                    if (tmp46.all_zero())
                                    {
                                        return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                    }
                                    else
                                    {
                                        return decltype(tmp44())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp44(), tmp46.template cast<float,1>());
                                    }
                                }
                                ()
                                ;
                                auto tmp49 = static_cast<float>(0.0);
                                auto tmp50 = at::vec::Vectorized<float>(tmp49);
                                auto tmp51 = decltype(tmp48)::blendv(tmp50, tmp48, tmp43.template cast<float,1>());
                                return tmp51;
                            }
                            ;
                            auto tmp52 = tmp38 ? tmp39() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                            auto tmp53 = static_cast<float>(0.0);
                            auto tmp54 = at::vec::VecMask<float,1>::from(tmp38);
                            auto tmp55 = at::vec::Vectorized<float>(tmp53);
                            auto tmp56 = decltype(tmp52)::blendv(tmp55, tmp52, tmp54.template cast<float,1>());
                            auto tmp57 = at::vec::VecMask<float,1>::from(tmp2);
                            auto tmp58 = decltype(tmp35)::blendv(tmp56, tmp35, tmp57.template cast<float,1>());
                            tmp58.store(out_ptr1 + static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0)));
                        }
                        #pragma omp simd simdlen(8)
                        for(long x3=static_cast<long>(512L); x3<static_cast<long>(513L); x3+=static_cast<long>(1L))
                        {
                            auto tmp0 = c10::convert<int64_t>(x1);
                            auto tmp1 = static_cast<int64_t>(256);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 = c10::convert<int64_t>(x3);
                                auto tmp5 = static_cast<int64_t>(257);
                                auto tmp6 = tmp4 < tmp5;
                                auto tmp7 = [&]
                                {
                                    auto tmp8 = -std::numeric_limits<float>::infinity();
                                    return tmp8;
                                }
                                ;
                                auto tmp9 = tmp6 ? tmp7() : static_cast<decltype(tmp7())>(0.0);
                                auto tmp10 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L));
                                auto tmp11 = static_cast<int64_t>(3);
                                auto tmp12 = tmp10 < tmp11;
                                auto tmp13 = [&]
                                {
                                    auto tmp14 = tmp4 >= tmp1;
                                    auto tmp15 = [&]
                                    {
                                        auto tmp16 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))];
                                        return tmp16;
                                    }
                                    ;
                                    auto tmp17 = tmp14 ? tmp15() : static_cast<decltype(tmp15())>(0.0);
                                    auto tmp18 = static_cast<float>(0.0);
                                    auto tmp19 = tmp14 ? tmp17 : tmp18;
                                    return tmp19;
                                }
                                ;
                                auto tmp20 = tmp12 ? tmp13() : static_cast<decltype(tmp13())>(0.0);
                                auto tmp21 = static_cast<float>(0.0);
                                auto tmp22 = tmp12 ? tmp20 : tmp21;
                                auto tmp23 = tmp6 ? tmp9 : tmp22;
                                return tmp23;
                            }
                            ;
                            auto tmp24 = tmp2 ? tmp3() : static_cast<decltype(tmp3())>(0.0);
                            auto tmp25 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L));
                            auto tmp26 = static_cast<int64_t>(3);
                            auto tmp27 = tmp25 < tmp26;
                            auto tmp28 = [&]
                            {
                                auto tmp29 = c10::convert<int64_t>(x3);
                                auto tmp30 = tmp29 >= tmp1;
                                auto tmp31 = [&]
                                {
                                    auto tmp32 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))];
                                    return tmp32;
                                }
                                ;
                                auto tmp33 = tmp30 ? tmp31() : static_cast<decltype(tmp31())>(0.0);
                                auto tmp34 = static_cast<float>(0.0);
                                auto tmp35 = tmp30 ? tmp33 : tmp34;
                                return tmp35;
                            }
                            ;
                            auto tmp36 = tmp27 ? tmp28() : static_cast<decltype(tmp28())>(0.0);
                            auto tmp37 = static_cast<float>(0.0);
                            auto tmp38 = tmp27 ? tmp36 : tmp37;
                            auto tmp39 = tmp2 ? tmp24 : tmp38;
                            out_ptr1[static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0))] = tmp39;
                        }
                    }
                }
            }
        }
    }
}
''')
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124921
Approved by: https://github.com/jgong5, https://github.com/jansel
ghstack dependencies: #124597
2024-04-28 04:33:25 +00:00
Sergii Dymchenko
f0f7452e31 Do not propogate (#124769)
Fix the propogate typos.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124769
Approved by: https://github.com/Skylion007
2024-04-24 02:18:18 +00:00
lezcano
9a5b4d2403 Do not forward parent's value range to CSE variable for variables created within codegen. (#123099)
Consider we are generating code for `ops.gt`, and within it we call
`ops.to_dtype`. Before, we would forward the bounds from `gt` to the
to the result of `to_dtype`, which is wrong.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123099
Approved by: https://github.com/jgong5, https://github.com/peterbell10
2024-04-23 06:26:39 +00:00
Xuehai Pan
93e249969b [BE] enable ruff rule RSE and remove useless parentheses in raise statements (#124261)
Remove useless parentheses in `raise` statements if the exception type is raised with no argument.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124261
Approved by: https://github.com/albanD
2024-04-17 19:29:34 +00:00
Peter Bell
bd225189f1 [inductor] Change OverridesData to take callables instead of strings (#123397)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123397
Approved by: https://github.com/lezcano
2024-04-11 22:22:54 +00:00
Edward Z. Yang
efa36ef092 Natively support int truncation, don't guard on positive/negative (#122827)
This doesn't entirely fix the original problem that prompted this, but
it seems to just be getting stuck in export constraint formatting now
which seems like progress to me.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122827
Approved by: https://github.com/avikchaudhuri
2024-04-11 15:22:32 +00:00
Peter Bell
9189d04cb1 [inductor] Add explicit ops.fma and use it in softmax_backward (#122518)
This allows us to generate an fma even when fp-fusion is disabled
in the compiler.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122518
Approved by: https://github.com/lezcano, https://github.com/Chillee
2024-04-06 02:15:16 +00:00
drisspg
f4e2a226aa ScoreMod API (#121845)
# Summary

This PR adds a new higher-order_op: `templated_attention`.  This op is designed to extend the functionality of torch.nn.fucntional.scaled_dot_product_attention.  PyTorch has efficient pre-written fused-attention kernels. However, users want to modify how scores are computed (a substep inside attention) -- this traditionally requires the user to write their own attention kernel. One such modification to attention scores that is not currently supported by the top level SDPA op is:[ Attention with Linear Biases (ALiBi](https://arxiv.org/abs/2108.12409)).

This higher-order op will instead accept a callable( 'score_mod') function that is through torch.compile will be used to create an efficient attention kernel instantiation.

### Details

This HOP utilizes the existing fx and HOP infra to capture and convert the User `score-mod` function and convert to an FX graph module. Inductor then consumes this HOP that has a `ir.Subgraph` input. It will inline this lowered subgraph into a triton kernel which performs fused attention with the modification to the scores matrix inlined.

### API

The API for a score_mod function should be as follows:

```Python
def score_mod(score: torch.Tensor, batch: torch.Tensor, head: torch.Tensor, token_1: torch.Tensor, token_kv: torch.Tensor) -> torch.Tensor
```

This function receives five parameters:

- `score`: A scalar tensor representing the attention score, with the same data type and device as the query, key, and value tensors.
- `batch`, `head`, `seq_len_q`, `seq_len_kv`: Scalar tensors indicating the batch index, head index, query index, and key/value index, respectively, with torch.int data type and located on the same device as the score tensor.

Consider inputs query, key, value of shapes (2, 4, 16, 8), leading to an intermediate attention score matrix of shape (2, 4, 16, 16)

The score_mod function will be vectorized over each element of this matrix. For instance, modifying the score at the position corresponding to the 0th batch, 2nd head, between the 8th query and the 9th key element, would be invoked as:

```Python
score_mod(score[0,2,8,9], torch.tensor(0), torch.tensor(2), torch.tensor(8), torch.tensor(9))
```

### Examples
```Python
import torch
from torch.nn.attention.templated_attention import templated_attention

torch.manual_seed(0)

# Lets create some input tensors
# The input tensor has shape (batch_size, num_heads, seq_len, head_dim)
query = torch.randn(8, 8, 2048, 64, device="cuda", dtype=torch.float32)
key = torch.randn(8, 8, 2048, 64, device="cuda", dtype=torch.float32)
value = torch.randn(8, 8, 2048, 64, device="cuda", dtype=torch.float32)

# Lets create a fun new score_modification! I will call this
# Checkerboard. It will reduce the score for neighboring tokens (1 step apart)
# in the sequence. And increase the score for tokens 2 steps apart. For everything
# else, the score will remain the same.

def checkerboard(score, batch, head, token_q, token_kv):
    score = torch.where(torch.abs(token_kv - token_q) == 1, score * 0.5, score)
    score = torch.where(torch.abs(token_kv - token_q) == 2, score * 2.0, score)
    return score

# Lets call templated_attention with this new score modification
output = templated_attention(query, key, value, score_mod=checkerboard)

compiled_templated_attention = torch.compile(templated_attention)
out_compiled = compiled_templated_attention(query, key, value, score_mod=checkerboard)

torch.testing.assert_close(output, out_compiled, atol=2e-2, rtol=2e-2)
```

### Future Work
- This PR is currently only forward only. However the triton kernel for backwards where score_modifications to not rely on external buffers has been explored here: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/flash/flash_attention.py
- Kernel Improvements; There are has been some larger updates to the fused attention implementation that Triton uses in its tutorials. The implementation of this kernel is based on a prior version and should be updated.
- We may want to unify this API under the top level SDPA API and leave that as a follow up once this is more stable
- Should we error on CPU?
- There are some issues with dynamic shapes
- Capturing of free variables and lifting to inputs to the subgraph is not working correctly today

### Performance
Comparisons generated by this benchmark:

| Type    |   Speedup |   batch_size |   num_heads |   q_seq_len |   k_seq_len |   head_dim | score_mod     | dtype          |
|---------|-----------|--------------|-------------|-------------|-------------|------------|---------------|----------------|
| Average |     5.412 |              |             |             |             |            |               |                |
| Max     |     8.882 |           16 |          16 |        4096 |        4096 |         64 | relative_bias | torch.bfloat16 |
| Min     |     3.645 |            8 |          16 |         512 |         512 |         64 | causal_mask   | torch.bfloat16 |
| Min     |     0.345 |            1 |          16 |        1024 |        1024 |         64 | pathological  | torch.bfloat16 |

For reference

| Configuration                                 | Forward Time (µ seconds) | Backend          | Speedup |
|-----------------------------------------------|--------------------------|------------------|---------|
| Fastest Config in Sweep (`8 16 4096 4096 64 relative_bias torch.bfloat16`) | 3608                   | Templated Attention                | 1.0  |
| Compiled SDPA (No Mask)                       | 9928                   | Math             | 2.75x   |
| Compiled SDPA (With Mask)                     | 11898                    | Math             | 3.29x   |
| Compiled SDPA (With Mask) | 8704                      | Memory Efficient Attention | 2.42x   |
| Compiled SDPA (No Mask) | 2548                     | FlashAttention2 | 0.706x   |

The speedups are measuring compiled templated attention speed versus different calls to torch.nn.functional.sdpa

<details>

<summary> FULL PERFORMANCE SWEEP NUMBERS </summary>

|   batch_size |   num_heads |   q_seq_len |   k_seq_len |   head_dim | score_mod     | dtype          |   eager_time |   compiled_time |   speedup |
|--------------|-------------|-------------|-------------|------------|---------------|----------------|--------------|-----------------|-----------|
|            1 |          16 |         512 |         512 |         64 | causal_mask   | torch.bfloat16 |      331.444 |          67.221 |     4.931 |
|            1 |          16 |         512 |         512 |         64 | relative_bias | torch.bfloat16 |      335.300 |          64.187 |     5.224 |
|            1 |          16 |         512 |         512 |         64 | head_bias     | torch.bfloat16 |      352.039 |          63.806 |     5.517 |
|            1 |          16 |         512 |         512 |         64 | pathological  | torch.bfloat16 |      371.699 |         711.349 |     0.523 |
|            1 |          16 |        1024 |        1024 |         64 | causal_mask   | torch.bfloat16 |      333.488 |          86.455 |     3.857 |
|            1 |          16 |        1024 |        1024 |         64 | relative_bias | torch.bfloat16 |      322.363 |          82.469 |     3.909 |
|            1 |          16 |        1024 |        1024 |         64 | head_bias     | torch.bfloat16 |      349.967 |          82.233 |     4.256 |
|            1 |          16 |        1024 |        1024 |         64 | pathological  | torch.bfloat16 |      486.359 |        1412.453 |     0.344 |
|            1 |          16 |        4096 |        4096 |         64 | causal_mask   | torch.bfloat16 |     2794.597 |         551.188 |     5.070 |
|            1 |          16 |        4096 |        4096 |         64 | relative_bias | torch.bfloat16 |     3965.150 |         513.101 |     7.728 |
|            1 |          16 |        4096 |        4096 |         64 | head_bias     | torch.bfloat16 |     2408.013 |         504.759 |     4.771 |
|            1 |          16 |        4096 |        4096 |         64 | pathological  | torch.bfloat16 |     6850.531 |       16733.675 |     0.409 |
|            8 |          16 |         512 |         512 |         64 | causal_mask   | torch.bfloat16 |      441.939 |         123.576 |     3.576 |
|            8 |          16 |         512 |         512 |         64 | relative_bias | torch.bfloat16 |      560.379 |         116.710 |     4.801 |
|            8 |          16 |         512 |         512 |         64 | head_bias     | torch.bfloat16 |      421.172 |         115.825 |     3.636 |
|            8 |          16 |         512 |         512 |         64 | pathological  | torch.bfloat16 |      994.492 |        2132.806 |     0.466 |
|            8 |          16 |        1024 |        1024 |         64 | causal_mask   | torch.bfloat16 |     1436.430 |         309.495 |     4.641 |
|            8 |          16 |        1024 |        1024 |         64 | relative_bias | torch.bfloat16 |     1892.216 |         290.186 |     6.521 |
|            8 |          16 |        1024 |        1024 |         64 | head_bias     | torch.bfloat16 |     1360.665 |         282.956 |     4.809 |
|            8 |          16 |        1024 |        1024 |         64 | pathological  | torch.bfloat16 |     3525.532 |        8359.702 |     0.422 |
|            8 |          16 |        4096 |        4096 |         64 | causal_mask   | torch.bfloat16 |    22026.839 |        3864.604 |     5.700 |
|            8 |          16 |        4096 |        4096 |         64 | relative_bias | torch.bfloat16 |    31262.746 |        3609.551 |     8.661 |
|            8 |          16 |        4096 |        4096 |         64 | head_bias     | torch.bfloat16 |    20219.079 |        3480.402 |     5.809 |
|            8 |          16 |        4096 |        4096 |         64 | pathological  | torch.bfloat16 |    54654.647 |      116652.357 |     0.469 |
|           16 |          16 |         512 |         512 |         64 | causal_mask   | torch.bfloat16 |      820.606 |         188.683 |     4.349 |
|           16 |          16 |         512 |         512 |         64 | relative_bias | torch.bfloat16 |     1058.362 |         179.295 |     5.903 |
|           16 |          16 |         512 |         512 |         64 | head_bias     | torch.bfloat16 |      784.372 |         175.714 |     4.464 |
|           16 |          16 |         512 |         512 |         64 | pathological  | torch.bfloat16 |     1890.792 |        4212.877 |     0.449 |
|           16 |          16 |        1024 |        1024 |         64 | causal_mask   | torch.bfloat16 |     2781.830 |         557.017 |     4.994 |
|           16 |          16 |        1024 |        1024 |         64 | relative_bias | torch.bfloat16 |     3694.050 |         525.249 |     7.033 |
|           16 |          16 |        1024 |        1024 |         64 | head_bias     | torch.bfloat16 |     2634.164 |         507.613 |     5.189 |
|           16 |          16 |        1024 |        1024 |         64 | pathological  | torch.bfloat16 |     6959.917 |       15331.116 |     0.454 |
|           16 |          16 |        4096 |        4096 |         64 | causal_mask   | torch.bfloat16 |    43889.096 |        7582.018 |     5.789 |
|           16 |          16 |        4096 |        4096 |         64 | relative_bias | torch.bfloat16 |    62784.293 |        7075.846 |     8.873 |
|           16 |          16 |        4096 |        4096 |         64 | head_bias     | torch.bfloat16 |    40308.606 |        6829.587 |     5.902 |
|           16 |          16 |        4096 |        4096 |         64 | pathological  | torch.bfloat16 |   108892.137 |      233090.953 |     0.467 |
</details>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121845
Approved by: https://github.com/Chillee, https://github.com/zou3519
2024-04-06 01:10:44 +00:00
xinan.lin
9743e3a19c [Inductor Intel GPU backend Upstream] Add Inductor Intel GPU backend. (#121895)
As the design in RFC https://github.com/pytorch/pytorch/issues/114856, this PR implemented Intel GPU Inductor backend by:
- Reuse WrapperCodegen and TritonScheduling for python wrapper and kernel code generation. And implenented device-specific code generation in XPUDeviceOpOverrides
- Reuse fx_pass, lowering, codecache, triton kernel auto-tuning, and compilation.

For the test case, this PR provided test/inductor/test_xpu_basic.py for basic inductor backend functionality testing.
We'll reuse all the existing Inductor test case in the next PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121895
Approved by: https://github.com/EikanWang, https://github.com/jansel, https://github.com/desertfire
2024-04-05 09:05:11 +00:00
PyTorch MergeBot
16cb5d48dd Revert "[inductor] Add explicit ops.fma and use it in softmax_backward (#122518)"
This reverts commit 05984e642b.

Reverted https://github.com/pytorch/pytorch/pull/122518 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it starts failing in trunk 05984e642b ([comment](https://github.com/pytorch/pytorch/pull/122518#issuecomment-2038631010))
2024-04-05 02:09:32 +00:00
Peter Bell
05984e642b [inductor] Add explicit ops.fma and use it in softmax_backward (#122518)
This allows us to generate an fma even when fp-fusion is disabled
in the compiler.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122518
Approved by: https://github.com/lezcano, https://github.com/Chillee
ghstack dependencies: #121924
2024-04-04 20:53:14 +00:00
Peter Bell
09c72eaa3f [inductor] Remove identity from ops.scan (#119727)
Currently scan has an `init` argument which must be the identity of the
combine function. This isn't strictly necessary if we are more careful about
keeping track of the first element and avoid combining it with anything.

This does additionally require that there are no active load masks, since we can't
do the `where_cond` any more. However, this shouldn't be possible anyway since
scans are always realized and only fused via the scheduler.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119727
Approved by: https://github.com/lezcano
2024-04-01 22:47:26 +00:00
Edward Z. Yang
3178ba0dc9 Don't use sympy Float functions, use an opaque one with no reasoning (#122823)
Sympy simplifications don't obey floating point semantics, so don't
use Sympy for this.  Keep them as is, only evaluate with the reference
implementations when all arguments are known.

This may end up getting subsumed by some other changes later, but I
wanted to understand if this was easy and it seems to be easy.

This doesn't actually depend on the earlier diffs on the stack and I can detach it.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122823
Approved by: https://github.com/lezcano
2024-03-29 19:13:55 +00:00
Jiong Gong
105381ea11 [inductor][cpp] simplify CppVecKernelChecker (remove bool/int8 load as mask and load as float flags) (#119734)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119734
Approved by: https://github.com/leslie-fang-intel, https://github.com/jansel
ghstack dependencies: #119654, #119655
2024-03-27 11:20:35 +00:00
Jiong Gong
49121603ab [inductor][cpp] support vectorized indirect indexing (#119655)
This PR adds the vectorized indirect indexing so that we can further simplify the `CppVecKernelChecker` (done in the later PR #119734) and remove the check that throws `CppVecUnsupportedError`. A boundary assertion check is added on vectorized indices and via the new `indirect_assert` method on `Kernel` - the base implementation is for scalar indices, overridden in `CppVecKernel` for vectorized indices.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119655
Approved by: https://github.com/jansel
ghstack dependencies: #119654
2024-03-27 10:25:45 +00:00
Jiong Gong
367ec62ae3 [inductor][cpp] generalize vector mask for dtypes (#119654)
Vectorized boolean values in CPU Inductor were modeled with `Vectorized<float>` which cannot work for operations with other data types. This PR generalizes it with the new `VecMask` template class that can work for masks on any vectorized data types. The intrinsics implementation in `cpp_prefix.h` for mask conversion, cast and masked load are now implemented as the specialization for `VecMask` and moved to corresponding header files.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119654
Approved by: https://github.com/leslie-fang-intel, https://github.com/jansel
2024-03-27 05:33:53 +00:00
Wang, Eikan
f8eeae7aaa Enable CPP wrapper codegen registration (#121296)
Extend code gen registration for `CppWrapper`. W/ this PR, an new backend can register its specific `CppWrapper` at runtime.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121296
Approved by: https://github.com/jansel, https://github.com/desertfire
2024-03-26 06:51:03 +00:00
eellison
cbbed46377 Defer selection of triton template (#120275)
Our prior approach to epilogue fusion was to select from a choice from a set of triton templates and extern calls based on benchmarking inputs, then unconditionally fuse epilogues. This can be sub-optimal in following ways:

- We select an extern kernel, however an epilogue like relu() exists such that choosing a triton template + relu would have been faster
- We select a triton template, epilogue fuse, and register spilling occurs causing it to be slower than not epilogue fusing.

In this PR we wait to select either the Triton Template or Extern Kernel based on benchmarking results from the kernel itself and its epilogue. As soon as a successful fusion occurs where a fused Triton Template + epilogue is faster than the unfused choice we finalize the MultiTemplateBuffer as a specific template. If no fusion occurs we'll finalize the MultiTemplateBuffer after fusion.

Note: if there are multiple epilogue fusions (not super likely), even though we select a template after the first fusion, we will still benchmark to see if subsequent epilogue are worth fusing. We could potentially defer choosing template in this case in a follow up at expense of compile time.

Gives 4% HF training win, 10% TIMM inference win. Increases compilation time which I will be trying to address more in follow up prs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120275
Approved by: https://github.com/jansel
ghstack dependencies: #121996
2024-03-20 01:40:33 +00:00
Isuru Fernando
409b1a6081 Add lowering for cummax, cummin (#120429)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120429
Approved by: https://github.com/peterbell10
2024-03-15 19:04:38 +00:00
Peter Bell
168a04e752 [inductor] Changes to support newer triton pin (#121267)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121267
Approved by: https://github.com/lezcano
ghstack dependencies: #121438
2024-03-09 18:17:36 +00:00
Kai Londenberg
96eff4ef70 [inductor max autotune] Detailed autotuning result logs ( machine-readable ) (#119004)
This diff introduces a new separate logging of autotuning results,
with the intention of making the results analyzable, specifically
those for the new experimental Cutlass backend.

Results are logged as text files with one JSON document corresponding to a single benchmark result per line.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119004
Approved by: https://github.com/jansel
ghstack dependencies: #120620
2024-02-29 18:24:13 +00:00
Isuru Fernando
b7df3bba62 add decomposition for frexp (#119217)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119217
Approved by: https://github.com/peterbell10
ghstack dependencies: #119284, #120027
2024-02-23 21:52:42 +00:00
wangjiangben-hw
b4cef25a1e add register_device_op_overrides (#119268)
Fixes #119267

Currently https://github.com/pytorch/pytorch/blob/main/torch/_inductor/codegen/common.py#L106 only supports built-in device function, I'm going to add a register function to get overrides class.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119268
Approved by: https://github.com/jansel
2024-02-21 04:53:07 +00:00
PyTorch MergeBot
0bdeaad936 Revert "add register_device_op_overrides (#119268)"
This reverts commit 2864a7e161.

Reverted https://github.com/pytorch/pytorch/pull/119268 on behalf of https://github.com/malfet due to Broke lint ([comment](https://github.com/pytorch/pytorch/pull/119268#issuecomment-1953231324))
2024-02-19 22:31:32 +00:00
PyTorch MergeBot
f1fbba8f35 Revert "Fix lint after #119268 (#120207)"
This reverts commit d9d0f1dccc.

Reverted https://github.com/pytorch/pytorch/pull/120207 on behalf of https://github.com/atalman due to Broke inductor tests ([comment](https://github.com/pytorch/pytorch/pull/120207#issuecomment-1953170249))
2024-02-19 21:21:12 +00:00
atalman
d9d0f1dccc Fix lint after #119268 (#120207)
Fixes lint after: https://github.com/pytorch/pytorch/issues/119268

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120207
Approved by: https://github.com/davidberard98
2024-02-19 20:01:45 +00:00
wangjiangben-hw
2864a7e161 add register_device_op_overrides (#119268)
Fixes #119267

Currently https://github.com/pytorch/pytorch/blob/main/torch/_inductor/codegen/common.py#L106 only supports built-in device function, I'm going to add a register function to get overrides class.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119268
Approved by: https://github.com/jansel
2024-02-18 06:11:54 +00:00
Adnan Akhundov
e5f46a1d35 Check alignment of ReinterpretView args of custom Triton kernels (#119649)
Summary: Currently, when a custom (user-written) Triton kernel has a ReinterpretView argument in IR, we're always skipping the alignment checking for this argument when preparing the `signature_of` for the AOT compilation of the Triton kernel (via setting `TensorArg.check_alignment` to `False`). This is problematic for user-written kernels where, albeit reinterpreted, the argument of the Triton kernel (the data pointer) can still be aligned to 16. When we skip alignment checking, the performance of the AOT-compiled internal Triton kernels can degrade 2x--3x.

In this PR, we replace `TensorArg.check_alignment` by `TensorArg.offset`, in which we specify the offset of the `ReinterpretView.layout` relative to the underlying `ir.Buffer` (corresponding to the data pointer before reinterpretation). As the size and stride of the layout don't change the alignment properties, those can be skipped. Importantly, for `ReinterpretView` arguments of custom Triton kernels, we use `arg.data.get_name()` as the buffer name. That, together with the offset, is used to check the alignment.

Bonus: the namedtuples in `codegen/common.py` are refactored as `dataclass`es, with nicer type hints and default values (for the newly added `TensorArg.offset`).

Test Plan:

```
$ python test/inductor/test_aot_inductor.py -k test_triton_kernel_reinterpret_view
...
----------------------------------------------------------------------
Ran 6 tests in 27.952s

OK (skipped=4)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119649
Approved by: https://github.com/oulgen
2024-02-11 20:21:17 +00:00
Pearu Peterson
2c91e13afc Add lowerings to special functions (#119187)
As in the title.

In addition, the PR introduces infrastructure for lowerings of pointwise functions that have both cpp and triton implementations available.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119187
Approved by: https://github.com/peterbell10
2024-02-11 16:35:40 +00:00
Jiong Gong
a050d146b7 [Inductor] Add Int8 data type into Inductor CPP backend vectorized code generation (#119179)
**Summary**
Part 1 of fixing https://github.com/pytorch/pytorch/issues/119141 which needs vectorized code generation of per channel quant and int8 data type.
In the current implementation for quantization, the vectorized code generation only supports the `uint8` data type. In this PR, we introduce support for the `int8` data type within the vectorized code generation.

**TestPlan**
```
python -u -m pytest -s -v test_cpu_repro.py -k test_decomposed_dequant_relu_quant_int8
python -u -m pytest -s -v test_cpu_repro.py -k test_dequant_quant_lowering_int8
python -u -m pytest -s -v test_cpu_repro.py -k test_dequant_maxpool2d_lowering_int8
python -u -m pytest -s -v test_cpu_repro.py -k test_tile2d_load_decomposed_dequant_add_relu_quant_int8
python -u -m pytest -s -v test_cpu_repro.py -k test_per_tensor_fake_quant_int8
python -u -m pytest -s -v test_cpu_repro.py -k test_non_contiguous_load_buf_quant_int8
python -u -m pytest -s -v test_cpu_repro.py -k test_tile2d_store_channel_shuffle_cl_quant_output_int8
python -u -m pytest -s -v test_cpu_repro.py -k test_dequant_relu_quant_dequant_relu_quant_lowering_int8
```

Co-authored-by: Jiong Gong <jiong.gong@intel.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119179
Approved by: https://github.com/peterbell10, https://github.com/jgong5, https://github.com/jansel
2024-02-09 07:33:12 +00:00
Peter Bell
88429a8084 [inductor] Add split scan kernel (#117992)
This PR adds a new type of triton kernel in which data is persistent but the
reduction dimension is split over multiple blocks (up to the entire kernel).
though this is called a reduction dimension, in actuality we only support scans.
because of this limitation, i have to be able to block fusions of split scan
operations with reductions so chose to add a new `ir.SplitScan` node which
is identical but allows for differentiation in the scheduler.

The split scan kernel is also the first to require an additional workspace buffer
which is used to communicate between cuda blocks. this is slightly tricky as we
the exact scratch space requirement isn't known until the grid size is calculated.
here i workaround the issue by setting a minimum rblock size and always allocating
to the maximum possible grid size for a given input tensor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117992
Approved by: https://github.com/jansel
ghstack dependencies: #117991
2024-02-09 01:56:00 +00:00
PyTorch MergeBot
088d538a8d Revert "[Inductor] GEMM shape padding improvements (#118522)"
This reverts commit cc46829f96.

Reverted https://github.com/pytorch/pytorch/pull/118522 on behalf of https://github.com/eellison due to regresses HF ~4/5% ([comment](https://github.com/pytorch/pytorch/pull/118522#issuecomment-1932557670))
2024-02-07 17:42:14 +00:00
Bin Bao
e868a7fedd [AOTI] Rename config.aot_inductor.abi_compatible (#119065)
Summary: Rename config.aot_inductor.abi_compatible to config.abi_compatible, since the cpp_wrapper mode in JIT Inductor will share the same flag.

Differential Revision: [D53478752](https://our.internmc.facebook.com/intern/diff/D53478752)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119065
Approved by: https://github.com/khabinov
2024-02-07 00:14:33 +00:00
Edward Z. Yang
abc09b27b9 Some minor type stub improvements (#118529)
I was just playing around with improving the typing of symbolic_shapes. The PR is not "complete" but I in particular wanted to get feedback on whether or not people liked making ValueRanges Generic; it seems that distinguishing if you have an Expr ValueRange or a SympyBoolean ValueRange is a lot of trouble for downstream. Using TypeGuard, we can perform refinements on the generic parameter inside methods, although we still have to cast back to ValueRange[T] due to https://github.com/python/mypy/issues/14425#issuecomment-1914852707

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118529
Approved by: https://github.com/Skylion007
2024-02-04 00:19:00 +00:00
Kai Londenberg
cc46829f96 [Inductor] GEMM shape padding improvements (#118522)
Improvements to shape padding logic in torch/_inductor/pad_mm.py

These changes could lead up to 14% perf improvement for certain Meta internal models in experiments.

Most notably:
  * 1.) Use aten.const_pad_nd operation to pad Tensors in a single op instead of using multiple steps involving intermediate buffers. This appears to be more performant than the previous logic, confirmed by Profiling & Benchmarking results ( Meta internal )
 * 2.) Make many paddings unneccessary using explicitly transposed GEMM when either M or N dimension is properly aligned but the other is not, configurable via config.shape_pad_use_transpose (default: True).
  * 3.) Enable shape padding for the Inductor CUDA  /  Cutlass backend for all GEMM ops where Cutlass would be enabled, without benchmarking in that case.
  * Add config flag to always pad shapes (without benchmarking first), configurable via config.force_shape_pad (default: False )
  * Added several new unit tests to ensure tensors are padded such that they meet all alignment requirements after padding.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118522
Approved by: https://github.com/jansel, https://github.com/eellison
2024-02-02 08:50:06 +00:00
PyTorch MergeBot
dbba1d4bf5 Revert "Some minor type stub improvements (#118529)"
This reverts commit c978f38bd4.

Reverted https://github.com/pytorch/pytorch/pull/118529 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/118529#issuecomment-1922362331))
2024-02-01 22:18:36 +00:00
Edward Z. Yang
c978f38bd4 Some minor type stub improvements (#118529)
I was just playing around with improving the typing of symbolic_shapes. The PR is not "complete" but I in particular wanted to get feedback on whether or not people liked making ValueRanges Generic; it seems that distinguishing if you have an Expr ValueRange or a SympyBoolean ValueRange is a lot of trouble for downstream. Using TypeGuard, we can perform refinements on the generic parameter inside methods, although we still have to cast back to ValueRange[T] due to https://github.com/python/mypy/issues/14425#issuecomment-1914852707

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118529
Approved by: https://github.com/Skylion007
2024-01-31 20:56:56 +00:00
Edward Z. Yang
cad79bd0bb Remove follow_imports = skip from sympy (#118469)
dmypy silently ignores follow_imports = skip, so to get parity between
dmypy and mypy we have to suck it up and type: ignore all of the sympy
typing problems.

The suppressions were added automatically with the following script generated by GPT-4:

```
import re

# Read the error file
with open("error_file.txt", "r") as f:
    errors = f.readlines()

# Parse the lines with errors and error types
error_lines = {}
for error in errors:
    match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
    if match:
        file_path, line_number, error_type = match.groups()
        if file_path not in error_lines:
            error_lines[file_path] = {}
        error_lines[file_path][int(line_number)] = error_type

# Insert ignore comments in the source files
for file_path, lines in error_lines.items():
    with open(file_path, "r") as f:
        code = f.readlines()
    for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
        code[line_number - 1] = code[line_number - 1].rstrip() + f"  # type: ignore[{error_type}]\n"
    with open(file_path, "w") as f:
        f.writelines(code)
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118469
Approved by: https://github.com/Skylion007
ghstack dependencies: #118414, #118418, #118432, #118467, #118468
2024-01-28 13:38:38 +00:00
Edward Z. Yang
46712b019d Enable local_partial_types (#118467)
When using dmypy, this setting is enabled and cannot be turned off. Force it for regular mypy too.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118467
Approved by: https://github.com/Skylion007
ghstack dependencies: #118414, #118418, #118432
2024-01-28 13:38:22 +00:00
Edward Z. Yang
d03173e88c Unify MYPYINDUCTOR and MYPY (#118432)
The original motivation for MYPYINDUCTOR was a faster type checking configuration that only checked a subset of files. With the removal of `follow_imports = ignore`, we are now able to use dmypy to do fast incremental typechecking, eliminating the need for this.

Perhaps erroneously, when I tee'ed up this PR I elected to delete the `follow_imports = skip` designations in the mypy-inductor.ini. This lead to a number of extra type error suppressions that I manually edited. You will need to review.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118432
Approved by: https://github.com/Skylion007
ghstack dependencies: #118414, #118418
2024-01-27 17:23:20 +00:00
laith sakka
708e6241ed Fix sympy_subs to preserve integer and non-negative properties. (#118150)
This diff introduce the following changes:
1. Fix sympy_subs to preserve integer and non-negative properties of replaced symbol when replacement is string
why is this needed?
I was compiling an expression:
 x*abs(y)  where y =-2
  what happens is that this expression is passed as ``s1*abs(s0)`` then s0 is replaced to ks0 with a call to sympy_subs.
 but sympy_subs used to replace s0 (integer=false, nonegative=false) with ks0(inetegr=true, nonegative = true)
 resulting in ``x*abs(ks0) = x*ks0`` which is wrong

2. rename sympy_symbol to sympy_index_symbol to make it explicit.
3. add assertion that replaced expression is not passed as string but always a sympy expression.

Fixes https://github.com/pytorch/pytorch/issues/117757

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118150
Approved by: https://github.com/ezyang
2024-01-25 20:54:55 +00:00
Edward Z. Yang
903e1913ff Rename unbacked SymInt prefix to u (#117859)
Currently, it conflicts with Inductor's naming convention for index
variables

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117859
Approved by: https://github.com/lezcano, https://github.com/jansel, https://github.com/avikchaudhuri
2024-01-22 20:53:47 +00:00
Edward Z. Yang
df4e3d9d08 Document OpsHandler protocol (#117790)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117790
Approved by: https://github.com/jansel
2024-01-21 07:20:53 +00:00
Shunting Zhang
e432b2e607 [inductor] multi-kernel support (#103469)
For a persistent reduction, we generate 2 flavor of 'equivalant' kernels at the same time
- persistent reduction
- regular reduction

A MultiKernel wraps these 2 kernels and pick the one with better performance at runtime.

Here I talk more about implementation details:
- Inductor maintains states for generating kernels. E.g. the wrapper code.  After we generate code for one kernel, we need restore the inductor state before we can generate the counterpart.

***There is one thing I need some comments from others***:
There is one tricky thing about kernel arguments. In general, inductor removes a buffer from the argument list if it's only used inside the kernel.  But somehow a buffer removed by persistent reduction kernel may still be kept by the regular (non-persistent) reduction kernel because of some CSE invalidation rule. My current implementation avoid removing buffers if multi_kernel is enabled. This makes sure both flavors of reduction has consistent argument list.  Another idea I have is, we generate the multi-kernel definition with the union of arguments from both sub-kernels. Let each sub-kernel pick the subset of arguments it wants. But this will make the code-gen or multi-kernel much complex.

I'm not sure if there is some easy and clean way to resolve this.

Testing command:
```

TORCHINDUCTOR_MULTI_KERNEL=1 TORCH_LOGS=+torch._inductor.graph TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1 python benchmarks/dynamo/huggingface.py --backend inductor --amp --performance --only BertForMaskedLM --training

```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103469
Approved by: https://github.com/jansel
2024-01-18 23:16:31 +00:00
Jason Ansel
a669319450 [inductor] Faster C++ kernel python bindings (#117500)
Calling C++ from Python via ctypes is notoriously slow.  This switches to generating our own C++ bindings directly, which is a >5x speedup on this kernel-launch-bound microbenchmark:
```python
from ctypes import c_void_p
import torch
from torch import empty
from torch._inductor.codecache import AsyncCompile
from torch._dynamo.testing import rand_strided
from torch._inductor.utils import print_performance
from torch._inductor.wrapper_benchmark import compiled_module_main

async_compile = AsyncCompile()

src = '''
#include "/tmp/torchinductor_jansel/gb/cgbau5vlj6cetmcjbjbtw6x4rrivaln6f45s5d72gy2bfx5foz3k.h"
extern "C" void kernel(const float* in_ptr0,
                       float* out_ptr0)
{
    {
        auto tmp0 = in_ptr0[static_cast<long>(0L)];
        auto tmp1 = static_cast<float>(1.0);
        auto tmp2 = decltype(tmp0)(tmp0 + tmp1);
        out_ptr0[static_cast<long>(0L)] = tmp2;
    }
}
'''

cpp_fused_add_ctypes = async_compile.cpp(src)
cpp_fused_add_cpython = async_compile.cpp_pybinding(["const float*", "float*"], src)

async_compile.wait(globals())
del async_compile

def call(arg0_1):
    buf0 = empty((1,), device='cpu', dtype=torch.float32)
    if use_ctypes:
        for _ in range(100):
            cpp_fused_add_ctypes(c_void_p(arg0_1.data_ptr()), c_void_p(buf0.data_ptr()))
    else:
        for _ in range(100):
            cpp_fused_add_cpython(arg0_1, buf0)
    del arg0_1
    return (buf0,)

def benchmark_compiled_module(times=1000, repeat=100):
    arg0_1 = rand_strided((1,), (1,), device='cpu', dtype=torch.float32)
    return print_performance(lambda: call(arg0_1), times=times, repeat=repeat)

print("old ctypes bindings: ", end='')
use_ctypes = True
compiled_module_main('None', benchmark_compiled_module)
print("new bindings:        ", end='')
use_ctypes = False
compiled_module_main('None', benchmark_compiled_module)
```
Output:
```
old ctypes bindings: 0.000073
new bindings:        0.000013
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117500
Approved by: https://github.com/desertfire
2024-01-18 16:20:12 +00:00
Nikita Shulga
a1afd1b195 Revert "[inductor] Faster C++ kernel python bindings (#117500)"
It should have never been landed, but was landed again, thanks to
ghstack grafting/ungrafting see discussion on https://github.com/pytorch/pytorch/pull/116910

This reverts commit e457b6fb18.
2024-01-17 17:06:32 -08:00
titaiwangms
e457b6fb18 [inductor] Faster C++ kernel python bindings (#117500)
Calling C++ from Python via ctypes is notoriously slow.  This switches to generating our own C++ bindings directly, which is a >5x speedup on this kernel-launch-bound microbenchmark:
```python
from ctypes import c_void_p
import torch
from torch import empty
from torch._inductor.codecache import AsyncCompile
from torch._dynamo.testing import rand_strided
from torch._inductor.utils import print_performance
from torch._inductor.wrapper_benchmark import compiled_module_main

async_compile = AsyncCompile()

src = '''
#include "/tmp/torchinductor_jansel/gb/cgbau5vlj6cetmcjbjbtw6x4rrivaln6f45s5d72gy2bfx5foz3k.h"
extern "C" void kernel(const float* in_ptr0,
                       float* out_ptr0)
{
    {
        auto tmp0 = in_ptr0[static_cast<long>(0L)];
        auto tmp1 = static_cast<float>(1.0);
        auto tmp2 = decltype(tmp0)(tmp0 + tmp1);
        out_ptr0[static_cast<long>(0L)] = tmp2;
    }
}
'''

cpp_fused_add_ctypes = async_compile.cpp(src)
cpp_fused_add_cpython = async_compile.cpp_pybinding(["const float*", "float*"], src)

async_compile.wait(globals())
del async_compile

def call(arg0_1):
    buf0 = empty((1,), device='cpu', dtype=torch.float32)
    if use_ctypes:
        for _ in range(100):
            cpp_fused_add_ctypes(c_void_p(arg0_1.data_ptr()), c_void_p(buf0.data_ptr()))
    else:
        for _ in range(100):
            cpp_fused_add_cpython(arg0_1, buf0)
    del arg0_1
    return (buf0,)

def benchmark_compiled_module(times=1000, repeat=100):
    arg0_1 = rand_strided((1,), (1,), device='cpu', dtype=torch.float32)
    return print_performance(lambda: call(arg0_1), times=times, repeat=repeat)

print("old ctypes bindings: ", end='')
use_ctypes = True
compiled_module_main('None', benchmark_compiled_module)
print("new bindings:        ", end='')
use_ctypes = False
compiled_module_main('None', benchmark_compiled_module)
```
Output:
```
old ctypes bindings: 0.000073
new bindings:        0.000013
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117500
Approved by: https://github.com/desertfire
ghstack dependencies: #117409, #116667, #117591
2024-01-17 23:03:15 +00:00