mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Inductor] support masked vectorization for the tail_loop for float64 datatype (#163316)
**Summary:**
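Support masked vectorization of the tail loop for the float64 (`torch.double`) dtype in Inductor's C++ CPU codegen. `torch.float64` is added to `MASKED_VECTORIZABLE_DTYPES`, so the leftover elements that do not fill a full vector step are handled by a masked vector load/compute/store instead of a scalar loop.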
**Example:**
```
import torch

def fn(x):
    return x * x

x = torch.randn((22, 22), dtype=torch.double)
with torch.no_grad():
    compiled_fn = torch.compile(fn)
    compiled_fn(x)
```
**Generated code:**
- Before (the 4 leftover elements, indices 480–483, are processed one at a time by a scalar `x0_tail` loop)
```
cpp_fused_mul_0 = async_compile.cpp_pybinding(['const double*', 'double*'], r'''
#include <torch/csrc/inductor/cpp_prefix.h>
extern "C" void kernel(const double* in_ptr0,
double* out_ptr0)
{
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(484L); x0+=static_cast<int64_t>(16L))
{
{
if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(480L)))
{
auto tmp0 = at::vec::VectorizedN<double,2>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
auto tmp1 = tmp0 * tmp0;
tmp1.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
}
if(C10_UNLIKELY(x0 >= static_cast<int64_t>(480L) && x0 < static_cast<int64_t>(484L)))
{
for (int64_t x0_tail = static_cast<int64_t>(480L);x0_tail < static_cast<int64_t>(484L); x0_tail++)
{
auto tmp0 = in_ptr0[static_cast<int64_t>(x0_tail)];
auto tmp1 = double(tmp0 * tmp0);
out_ptr0[static_cast<int64_t>(x0_tail)] = tmp1;
}
}
}
}
}
}
''')
async_compile.wait(globals())
del async_compile
class Runner:
def __init__(self, partitions):
self.partitions = partitions
def recursively_apply_fns(self, fns):
new_callables = []
for fn, c in zip(fns, self.partitions):
new_callables.append(fn(c))
self.partitions = new_callables
def call(self, args):
arg0_1, = args
args.clear()
assert_size_stride(arg0_1, (22, 22), (22, 1))
buf0 = empty_strided_cpu((22, 22), (22, 1), torch.float64)
# [Provenance debug handles] cpp_fused_mul_0:1
cpp_fused_mul_0(arg0_1, buf0)
del arg0_1
return (buf0, )
```
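- After (the same 4 leftover elements are processed by a single vector load/multiply/store that is passed an element count of 4)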
```
cpp_fused_mul_0 = async_compile.cpp_pybinding(['const double*', 'double*'], r'''
#include <torch/csrc/inductor/cpp_prefix.h>
extern "C" void kernel(const double* in_ptr0,
double* out_ptr0)
{
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(484L); x0+=static_cast<int64_t>(16L))
{
{
if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(480L)))
{
auto tmp0 = at::vec::VectorizedN<double,2>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
auto tmp1 = tmp0 * tmp0;
tmp1.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
}
if(C10_UNLIKELY(x0 >= static_cast<int64_t>(480L) && x0 < static_cast<int64_t>(484L)))
{
auto tmp0 = at::vec::VectorizedN<double,2>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(4L));
auto tmp1 = tmp0 * tmp0;
tmp1.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(4L));
}
}
}
}
}
''')
async_compile.wait(globals())
del async_compile
class Runner:
def __init__(self, partitions):
self.partitions = partitions
def recursively_apply_fns(self, fns):
new_callables = []
for fn, c in zip(fns, self.partitions):
new_callables.append(fn(c))
self.partitions = new_callables
def call(self, args):
arg0_1, = args
args.clear()
assert_size_stride(arg0_1, (22, 22), (22, 1))
buf0 = empty_strided_cpu((22, 22), (22, 1), torch.float64)
# [Provenance debug handles] cpp_fused_mul_0:1
cpp_fused_mul_0(arg0_1, buf0)
del arg0_1
return (buf0, )
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163316
Approved by: https://github.com/mingfeima, https://github.com/jansel
This commit is contained in:
parent
94eaeb9cb8
commit
20be077085
|
|
@ -4810,6 +4810,23 @@ class CPUReproTests(TestCase):
|
||||||
self.common(fn, (x,))
|
self.common(fn, (x,))
|
||||||
check_metrics_vec_kernel_count(1)
|
check_metrics_vec_kernel_count(1)
|
||||||
|
|
||||||
|
# Tail vectorization case
|
||||||
|
x = torch.randn((37, 37), dtype=torch.double)
|
||||||
|
torch._dynamo.reset()
|
||||||
|
metrics.reset()
|
||||||
|
with torch.no_grad():
|
||||||
|
expected = fn(x)
|
||||||
|
compiled_fn = torch.compile(fn)
|
||||||
|
actual, code = run_and_get_cpp_code(compiled_fn, x)
|
||||||
|
self.assertEqual(expected, actual)
|
||||||
|
# 1 generated vec kernel
|
||||||
|
check_metrics_vec_kernel_count(1)
|
||||||
|
# Check that both main and tail loops are vectorized
|
||||||
|
if _can_check_vec_metrics():
|
||||||
|
FileCheck().check_count(
|
||||||
|
"at::vec::VectorizedN<double,2>::loadu", 2, exactly=True
|
||||||
|
).run(code)
|
||||||
|
|
||||||
def test_double_reduction_vec(self):
|
def test_double_reduction_vec(self):
|
||||||
def fn(x):
|
def fn(x):
|
||||||
return x.sum(dim=1)
|
return x.sum(dim=1)
|
||||||
|
|
@ -4819,6 +4836,23 @@ class CPUReproTests(TestCase):
|
||||||
self.common(fn, (x,))
|
self.common(fn, (x,))
|
||||||
check_metrics_vec_kernel_count(1)
|
check_metrics_vec_kernel_count(1)
|
||||||
|
|
||||||
|
# Tail vectorization case
|
||||||
|
x = torch.randn((37, 37), dtype=torch.double)
|
||||||
|
torch._dynamo.reset()
|
||||||
|
metrics.reset()
|
||||||
|
with torch.no_grad():
|
||||||
|
expected = fn(x)
|
||||||
|
compiled_fn = torch.compile(fn)
|
||||||
|
actual, code = run_and_get_cpp_code(compiled_fn, x)
|
||||||
|
self.assertEqual(expected, actual)
|
||||||
|
# 1 generated vec kernel
|
||||||
|
check_metrics_vec_kernel_count(1)
|
||||||
|
# Check that both main and tail loops are vectorized
|
||||||
|
if _can_check_vec_metrics():
|
||||||
|
FileCheck().check_count(
|
||||||
|
"at::vec::VectorizedN<double,2>::loadu", 2, exactly=True
|
||||||
|
).run(code)
|
||||||
|
|
||||||
def test_convert_fp32_to_double_vec(self):
|
def test_convert_fp32_to_double_vec(self):
|
||||||
def fn(x):
|
def fn(x):
|
||||||
return x.to(torch.double)
|
return x.to(torch.double)
|
||||||
|
|
@ -4828,6 +4862,23 @@ class CPUReproTests(TestCase):
|
||||||
self.common(fn, (x,))
|
self.common(fn, (x,))
|
||||||
check_metrics_vec_kernel_count(1)
|
check_metrics_vec_kernel_count(1)
|
||||||
|
|
||||||
|
# Tail vectorization case
|
||||||
|
x = torch.randn(37, 37)
|
||||||
|
torch._dynamo.reset()
|
||||||
|
metrics.reset()
|
||||||
|
with torch.no_grad():
|
||||||
|
expected = fn(x)
|
||||||
|
compiled_fn = torch.compile(fn)
|
||||||
|
actual, code = run_and_get_cpp_code(compiled_fn, x)
|
||||||
|
self.assertEqual(expected, actual)
|
||||||
|
# 1 generated vec kernel
|
||||||
|
check_metrics_vec_kernel_count(1)
|
||||||
|
# Check that both main and tail loops are vectorized
|
||||||
|
if _can_check_vec_metrics():
|
||||||
|
FileCheck().check_count(
|
||||||
|
"at::vec::convert<double,2,float,1>", 2, exactly=True
|
||||||
|
).run(code)
|
||||||
|
|
||||||
def test_convert_double_to_fp32_vec(self):
|
def test_convert_double_to_fp32_vec(self):
|
||||||
def fn(x):
|
def fn(x):
|
||||||
return x.to(torch.float32)
|
return x.to(torch.float32)
|
||||||
|
|
@ -4837,6 +4888,23 @@ class CPUReproTests(TestCase):
|
||||||
self.common(fn, (x,))
|
self.common(fn, (x,))
|
||||||
check_metrics_vec_kernel_count(1)
|
check_metrics_vec_kernel_count(1)
|
||||||
|
|
||||||
|
# Tail vectorization case
|
||||||
|
x = torch.randn((37, 37), dtype=torch.double)
|
||||||
|
torch._dynamo.reset()
|
||||||
|
metrics.reset()
|
||||||
|
with torch.no_grad():
|
||||||
|
expected = fn(x)
|
||||||
|
compiled_fn = torch.compile(fn)
|
||||||
|
actual, code = run_and_get_cpp_code(compiled_fn, x)
|
||||||
|
self.assertEqual(expected, actual)
|
||||||
|
# 1 generated vec kernel
|
||||||
|
check_metrics_vec_kernel_count(1)
|
||||||
|
# Check that both main and tail loops are vectorized
|
||||||
|
if _can_check_vec_metrics():
|
||||||
|
FileCheck().check_count(
|
||||||
|
"at::vec::convert<float,1,double,2>", 2, exactly=True
|
||||||
|
).run(code)
|
||||||
|
|
||||||
def test_no_redundant_to_dtypes_between_fused_scheduler_node(self):
|
def test_no_redundant_to_dtypes_between_fused_scheduler_node(self):
|
||||||
# https://github.com/pytorch/pytorch/issues/115260
|
# https://github.com/pytorch/pytorch/issues/115260
|
||||||
p0 = torch.tensor([1.0879], dtype=torch.float16)
|
p0 = torch.tensor([1.0879], dtype=torch.float16)
|
||||||
|
|
|
||||||
|
|
@ -159,6 +159,7 @@ VECTORIZABLE_DTYPES: list[torch.dtype] = [
|
||||||
]
|
]
|
||||||
|
|
||||||
MASKED_VECTORIZABLE_DTYPES: list[torch.dtype] = [
|
MASKED_VECTORIZABLE_DTYPES: list[torch.dtype] = [
|
||||||
|
torch.float64,
|
||||||
torch.float,
|
torch.float,
|
||||||
torch.bfloat16,
|
torch.bfloat16,
|
||||||
torch.float16,
|
torch.float16,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user