mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[inductor] [cpu] fix the dtype hardcoded to int64 in store_reduction (#157904)
## Fixes https://github.com/pytorch/pytorch/issues/157683
## mini repro
* Just copy the code from the issue to reproduce it.
```python
import torch
device = "cpu"
# Input tensors
v2_0 = torch.randn(16, 24, 59, dtype=torch.complex64, device=device)
v3_0 = torch.randn(16, 24, 59, dtype=torch.complex64, device=device)
def my_model(v2_0, v3_0):
v6_0 = -v3_0
v4_0 = v2_0 * v3_0
v1_0 = v4_0.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
v0_0 = v2_0.to(torch.int32)
v5_0 = v0_0.amax(dim=0)
return v6_0, v4_0, v1_0, v0_0, v5_0
v6_0, v4_0, v1_0, v0_0, v5_0 = my_model(v2_0, v3_0)
print("v6_0", v6_0.shape)
print("v4_0", v4_0.shape)
compiled_model = torch.compile(my_model, backend="inductor")
v6_0, v4_0, v1_0, v0_0, v5_0 = compiled_model(v2_0, v3_0)
print("v6_0", v6_0.shape)
print("v4_0", v4_0.shape)
print("v1_0", v1_0.shape)
print("v0_0", v0_0.shape)
print("v5_0", v5_0.shape)
```
Error stack (g++ output):
```
/home/admin/pytorch/pytorch/torch/include/ATen/cpu/vec/vec_convert.h:41:1: 附注:candidate: ‘template<class dst_t, class src_t> std::enable_if_t<(! is_same_v<dst_t, src_t>), at::vec::CPU_CAPABILITY::Vectorized<T> > at::vec::CPU_CAPABILITY::convert(const at::vec::CPU_CAPABILITY::Vectorized<T>&)’
41 | convert(const Vectorized<src_t>& src) {
| ^~~~~~~
/home/admin/pytorch/pytorch/torch/include/ATen/cpu/vec/vec_convert.h:41:1: 附注: template argument deduction/substitution failed:
/tmp/torchinductor_admin/6k/c6kr65o43rlmp2cmkpn5ezewhe5bla4w72hpcrg5biyelrs4skyw.main.cpp:37:99: 错误:模板参数数目不对(不应是 4 个而应是 2 个)
37 | auto int32_t_tmp_acc0_vec = at::vec::convert<int32_t,1,int64_t,2>(tmp_acc0_vec);
```
## summary
**The C++ kernel generated by the Inductor had the wrong data type for the output variable; it should be int32_t instead of int64_t. This incorrect data type led to an incompatible data type conversion, which caused the g++ compilation to fail.**
The original code that caused the problem.
```
def my_model(v2_0, v3_0):
v6_0 = -v3_0
v4_0 = v2_0 * v3_0
v1_0 = v4_0.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
v0_0 = v2_0.to(torch.int32)
# The line below (amax over the int32 tensor) is what triggers the problem.
v5_0 = v0_0.amax(dim=0)
```
## proof procedure
The c++ kernel generated by inductor:
```c++
#include <torch/csrc/inductor/cpp_prefix.h>
extern "C" void kernel(const int32_t* in_ptr0,
int32_t* out_ptr0)
{
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(1416L); x0+=static_cast<int64_t>(16L))
{
{
int32_t tmp_acc0_arr[16];
for (int i = 0; i < 16; i++)
{
tmp_acc0_arr[i] = std::numeric_limits<int32_t>::min();
}
int32_t tmp_acc0 = std::numeric_limits<int32_t>::min();
at::vec::Vectorized<int32_t> tmp_acc0_vec = at::vec::Vectorized<int32_t>(std::numeric_limits<int32_t>::min());
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(16L); x1+=static_cast<int64_t>(1L))
{
{
if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(1408L)))
{
auto tmp0 = at::vec::Vectorized<int32_t>::loadu(in_ptr0 + static_cast<int64_t>(x0 + 1416L*x1), static_cast<int64_t>(16));
tmp_acc0_vec = at::vec::maximum(tmp_acc0_vec, tmp0);
}
if(C10_UNLIKELY(x0 >= static_cast<int64_t>(1408L) && x0 < static_cast<int64_t>(1416L)))
{
for (int64_t x0_tail = static_cast<int64_t>(1408L);x0_tail < static_cast<int64_t>(1416L); x0_tail++)
{
auto tmp0 = in_ptr0[static_cast<int64_t>(x0_tail + 1416L*x1)];
tmp_acc0_arr[x0_tail - static_cast<int64_t>(1408L)] = max_propagate_nan(tmp_acc0_arr[x0_tail - static_cast<int64_t>(1408L)], tmp0);
}
}
}
}
if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(1408L)))
{
// Invalid conversion: tmp_acc0_vec is already Vectorized<int32_t>, so this call makes the g++ compilation fail.
auto int32_t_tmp_acc0_vec = at::vec::convert<int32_t,1,int64_t,2>(tmp_acc0_vec);
int32_t_tmp_acc0_vec.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
}
if(C10_UNLIKELY(x0 >= static_cast<int64_t>(1408L) && x0 < static_cast<int64_t>(1416L)))
{
for (int64_t x0_tail = static_cast<int64_t>(1408L);x0_tail < static_cast<int64_t>(1416L); x0_tail++)
{
out_ptr0[static_cast<int64_t>(x0_tail)] = tmp_acc0_arr[x0_tail - static_cast<int64_t>(1408L)];
}
}
}
}
}
}
```
The compiler complains (the localized g++ error reads: "wrong number of template arguments (4, should be 2)"):
```text
/home/admin/pytorch/pytorch/torch/include/ATen/cpu/vec/vec_convert.h:41:1: 附注:candidate: ‘template<class dst_t, class src_t> std::enable_if_t<(! is_same_v<dst_t, src_t>), at::vec::CPU_CAPABILITY::Vectorized<T> > at::vec::CPU_CAPABILITY::convert(const at::vec::CPU_CAPABILITY::Vectorized<T>&)’
41 | convert(const Vectorized<src_t>& src) {
| ^~~~~~~
/home/admin/pytorch/pytorch/torch/include/ATen/cpu/vec/vec_convert.h:41:1: 附注: template argument deduction/substitution failed:
/tmp/torchinductor_admin/6k/c6kr65o43rlmp2cmkpn5ezewhe5bla4w72hpcrg5biyelrs4skyw.main.cpp:37:99: 错误:模板参数数目不对(不应是 4 个而应是 2 个)
37 | auto int32_t_tmp_acc0_vec = at::vec::convert<int32_t,1,int64_t,2>(tmp_acc0_vec);
```
So the following line is the problem:
```c++
// These template arguments assume tmp_acc0_vec is a Vectorized<int64_t> and request a conversion to Vectorized<int32_t>.
auto int32_t_tmp_acc0_vec = at::vec::convert<int32_t,1,int64_t,2>(tmp_acc0_vec);
```
The issue is that tmp_acc0_vec is of type Vectorized<int32_t>, but the template arguments declare the source to be Vectorized<int64_t> and ask for it to be converted to Vectorized<int32_t>. These two facts conflict. No conversion should be emitted at all, because tmp_acc0_vec is already a Vectorized<int32_t>. The following lines hardcode the output variable type to int64 for every non-floating-point output, which causes this unnecessary and incorrect type conversion.
d89f30ad45/torch/_inductor/codegen/cpp.py (L2985-L2993)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157904
Approved by: https://github.com/jgong5
This commit is contained in:
parent
aa75e917bd
commit
24f43d0da7
|
|
@ -3117,6 +3117,30 @@ class CPUReproTests(TestCase):
|
|||
lengths = torch.zeros(11, dtype=torch.long)
|
||||
get_traj_idx(lengths, num_slices=4)
|
||||
|
||||
def test_store_reduction(self):
|
||||
# fix https://github.com/pytorch/pytorch/issues/157683
|
||||
def fn(x, y):
|
||||
r1 = x.amax(dim=0)
|
||||
r2 = y.amax(dim=0)
|
||||
return r1, r2
|
||||
|
||||
device = "cpu"
|
||||
for int_dypte, float_dtype in zip(
|
||||
[torch.int64, torch.int32, torch.int16, torch.int8],
|
||||
[torch.float64, torch.float32, torch.float16, torch.bfloat16],
|
||||
):
|
||||
x = torch.randint(
|
||||
low=0, high=100, size=(16, 24, 59), dtype=int_dypte, device=device
|
||||
)
|
||||
y = torch.randn(16, 24, 59, dtype=float_dtype, device=device)
|
||||
self.common(
|
||||
fn,
|
||||
(
|
||||
x,
|
||||
y,
|
||||
),
|
||||
)
|
||||
|
||||
@requires_vectorization
|
||||
@patch("torch.cuda.is_available", lambda: False)
|
||||
def test_sign_cpu_only(self):
|
||||
|
|
|
|||
|
|
@ -3218,11 +3218,10 @@ class CppVecKernel(CppKernel):
|
|||
index = self.rename_indexing(index)
|
||||
var = self.args.output(name)
|
||||
out_dtype = V.graph.get_dtype(name)
|
||||
dtype = (
|
||||
(out_dtype if out_dtype == torch.double else torch.float)
|
||||
if out_dtype.is_floating_point
|
||||
else torch.int64
|
||||
)
|
||||
if out_dtype.is_floating_point and out_dtype != torch.double:
|
||||
dtype = torch.float
|
||||
else:
|
||||
dtype = out_dtype
|
||||
out_num_vectors = V.kernel._get_num_vectors(out_dtype)
|
||||
src_num_vectors = V.kernel._get_num_vectors(dtype)
|
||||
code = IndentedBuffer()
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user