[Inductor] support masked vectorization for the tail_loop for float64 datatype (#163316)

**Summary:**
Support masked vectorization for the tail loop for the float64 (double) dtype. Previously the tail iterations fell back to a scalar loop; with this change they use a masked vector load/store with the residual element count, matching what the CPP backend already does for float32, bfloat16, and float16.

**Example:**
```
import torch

def fn(x):
    return x * x

x = torch.randn((22, 22), dtype=torch.double)
with torch.no_grad():
    compiled_fn = torch.compile(fn)
    compiled_fn(x)
```
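Here the tensor has 22 × 22 = 484 contiguous elements. The generated kernel below steps through them with a 16-lane double vector (`at::vec::VectorizedN<double,2>`, i.e. two 8-double registers in the AVX-512 build this output comes from), so the main loop covers the first 480 elements and a 4-element tail remains.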

**Generated code:**

- Before
```
cpp_fused_mul_0 = async_compile.cpp_pybinding(['const double*', 'double*'], r'''
#include <torch/csrc/inductor/cpp_prefix.h>
extern "C"  void  kernel(const double* in_ptr0,
                       double* out_ptr0)
{
    {
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(484L); x0+=static_cast<int64_t>(16L))
        {
            {
                if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(480L)))
                {
                    auto tmp0 = at::vec::VectorizedN<double,2>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                    auto tmp1 = tmp0 * tmp0;
                    tmp1.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                }
                if(C10_UNLIKELY(x0 >= static_cast<int64_t>(480L) && x0 < static_cast<int64_t>(484L)))
                {
                    for (int64_t x0_tail = static_cast<int64_t>(480L);x0_tail < static_cast<int64_t>(484L); x0_tail++)
                    {
                        auto tmp0 = in_ptr0[static_cast<int64_t>(x0_tail)];
                        auto tmp1 = double(tmp0 * tmp0);
                        out_ptr0[static_cast<int64_t>(x0_tail)] = tmp1;
                    }
                }
            }
        }
    }
}
''')

async_compile.wait(globals())
del async_compile

class Runner:
    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        arg0_1, = args
        args.clear()
        assert_size_stride(arg0_1, (22, 22), (22, 1))
        buf0 = empty_strided_cpu((22, 22), (22, 1), torch.float64)
        # [Provenance debug handles] cpp_fused_mul_0:1
        cpp_fused_mul_0(arg0_1, buf0)
        del arg0_1
        return (buf0, )
```
- After
```
cpp_fused_mul_0 = async_compile.cpp_pybinding(['const double*', 'double*'], r'''
#include <torch/csrc/inductor/cpp_prefix.h>
extern "C"  void  kernel(const double* in_ptr0,
                       double* out_ptr0)
{
    {
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(484L); x0+=static_cast<int64_t>(16L))
        {
            {
                if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(480L)))
                {
                    auto tmp0 = at::vec::VectorizedN<double,2>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                    auto tmp1 = tmp0 * tmp0;
                    tmp1.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                }
                if(C10_UNLIKELY(x0 >= static_cast<int64_t>(480L) && x0 < static_cast<int64_t>(484L)))
                {
                    auto tmp0 = at::vec::VectorizedN<double,2>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(4L));
                    auto tmp1 = tmp0 * tmp0;
                    tmp1.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(4L));
                }
            }
        }
    }
}
''')

async_compile.wait(globals())
del async_compile

class Runner:
    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        arg0_1, = args
        args.clear()
        assert_size_stride(arg0_1, (22, 22), (22, 1))
        buf0 = empty_strided_cpu((22, 22), (22, 1), torch.float64)
        # [Provenance debug handles] cpp_fused_mul_0:1
        cpp_fused_mul_0(arg0_1, buf0)
        del arg0_1
        return (buf0, )
```
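The main-loop branch is unchanged; only the tail branch differs: the scalar `x0_tail` loop is replaced by the same masked `loadu`/`store` pair with an element count of 4, so the last 4 elements are handled in a single masked vector operation. To inspect the generated kernel yourself, you can capture the C++ source with `run_and_get_cpp_code`, the same helper the new tests use. A minimal sketch, assuming an AVX-512 machine so the kernel matches the output above:
```
import torch
from torch._inductor.utils import run_and_get_cpp_code

def fn(x):
    return x * x

x = torch.randn((22, 22), dtype=torch.double)
with torch.no_grad():
    compiled_fn = torch.compile(fn)
    # Returns the call result plus the generated C++ source as a string.
    _, code = run_and_get_cpp_code(compiled_fn, x)

# With the tail loop vectorized, both branches load via the masked loadu,
# so the call appears twice; before this change it appeared only once.
assert code.count("at::vec::VectorizedN<double,2>::loadu") == 2
```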

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163316
Approved by: https://github.com/mingfeima, https://github.com/jansel
commit 20be077085 (parent 94eaeb9cb8)
Jiayi Sun, 2025-10-28 16:37:00 +00:00; committed by PyTorch MergeBot
2 changed files with 69 additions and 0 deletions

`test/inductor/test_cpu_repro.py`:

```
@@ -4810,6 +4810,23 @@ class CPUReproTests(TestCase):
        self.common(fn, (x,))
        check_metrics_vec_kernel_count(1)
        # Tail vectorization case
        x = torch.randn((37, 37), dtype=torch.double)
        torch._dynamo.reset()
        metrics.reset()
        with torch.no_grad():
            expected = fn(x)
            compiled_fn = torch.compile(fn)
            actual, code = run_and_get_cpp_code(compiled_fn, x)
            self.assertEqual(expected, actual)
            # 1 generated vec kernel
            check_metrics_vec_kernel_count(1)
            # Check that both main and tail loops are vectorized
            if _can_check_vec_metrics():
                FileCheck().check_count(
                    "at::vec::VectorizedN<double,2>::loadu", 2, exactly=True
                ).run(code)

    def test_double_reduction_vec(self):
        def fn(x):
            return x.sum(dim=1)
@@ -4819,6 +4836,23 @@ class CPUReproTests(TestCase):
        self.common(fn, (x,))
        check_metrics_vec_kernel_count(1)
        # Tail vectorization case
        x = torch.randn((37, 37), dtype=torch.double)
        torch._dynamo.reset()
        metrics.reset()
        with torch.no_grad():
            expected = fn(x)
            compiled_fn = torch.compile(fn)
            actual, code = run_and_get_cpp_code(compiled_fn, x)
            self.assertEqual(expected, actual)
            # 1 generated vec kernel
            check_metrics_vec_kernel_count(1)
            # Check that both main and tail loops are vectorized
            if _can_check_vec_metrics():
                FileCheck().check_count(
                    "at::vec::VectorizedN<double,2>::loadu", 2, exactly=True
                ).run(code)

    def test_convert_fp32_to_double_vec(self):
        def fn(x):
            return x.to(torch.double)
@@ -4828,6 +4862,23 @@ class CPUReproTests(TestCase):
        self.common(fn, (x,))
        check_metrics_vec_kernel_count(1)
        # Tail vectorization case
        x = torch.randn(37, 37)
        torch._dynamo.reset()
        metrics.reset()
        with torch.no_grad():
            expected = fn(x)
            compiled_fn = torch.compile(fn)
            actual, code = run_and_get_cpp_code(compiled_fn, x)
            self.assertEqual(expected, actual)
            # 1 generated vec kernel
            check_metrics_vec_kernel_count(1)
            # Check that both main and tail loops are vectorized
            if _can_check_vec_metrics():
                FileCheck().check_count(
                    "at::vec::convert<double,2,float,1>", 2, exactly=True
                ).run(code)

    def test_convert_double_to_fp32_vec(self):
        def fn(x):
            return x.to(torch.float32)
@@ -4837,6 +4888,23 @@ class CPUReproTests(TestCase):
        self.common(fn, (x,))
        check_metrics_vec_kernel_count(1)
        # Tail vectorization case
        x = torch.randn((37, 37), dtype=torch.double)
        torch._dynamo.reset()
        metrics.reset()
        with torch.no_grad():
            expected = fn(x)
            compiled_fn = torch.compile(fn)
            actual, code = run_and_get_cpp_code(compiled_fn, x)
            self.assertEqual(expected, actual)
            # 1 generated vec kernel
            check_metrics_vec_kernel_count(1)
            # Check that both main and tail loops are vectorized
            if _can_check_vec_metrics():
                FileCheck().check_count(
                    "at::vec::convert<float,1,double,2>", 2, exactly=True
                ).run(code)

    def test_no_redundant_to_dtypes_between_fused_scheduler_node(self):
        # https://github.com/pytorch/pytorch/issues/115260
        p0 = torch.tensor([1.0879], dtype=torch.float16)
```
`torch/_inductor/codegen/cpp.py`:

```
@@ -159,6 +159,7 @@ VECTORIZABLE_DTYPES: list[torch.dtype] = [
]
MASKED_VECTORIZABLE_DTYPES: list[torch.dtype] = [
    torch.float64,
    torch.float,
    torch.bfloat16,
    torch.float16,
```
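Registering `torch.float64` in `MASKED_VECTORIZABLE_DTYPES` is the whole codegen-side change: the CPP backend already knows how to emit masked `loadu`/`store` for the dtypes in this list, so double now takes the same tail-loop path as `torch.float`, `torch.bfloat16`, and `torch.float16`. This one line is the only non-test addition in the commit; the other 68 of the 69 added lines are the tests above.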