[Inductor] support transpose vertical reduction in cpp (#97781)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97781
Approved by: https://github.com/jansel
This commit is contained in:
Jiong Gong 2023-04-02 10:06:47 +00:00 committed by PyTorch MergeBot
parent 76074dc0a3
commit 5d62d12557
2 changed files with 91 additions and 48 deletions

View File

@ -1211,6 +1211,30 @@ class CPUReproTests(TestCase):
self.assertTrue(same(fn2(x), opt_fn2(x)))
assert metrics.generated_cpp_vec_kernel_count == 1
def test_transpose_vertical_sum_cpu_only(self):
def fn(a, b):
c = a * b
return c.sum(dim=1)
metrics.reset()
x = torch.randn(100, 50, 50)
y = torch.randn(100, 50, 50).transpose(1, 2)
opt_fn = torch._dynamo.optimize("inductor")(fn)
self.assertTrue(same(fn(x, y), opt_fn(x, y)))
assert metrics.generated_cpp_vec_kernel_count == 2
def test_transpose_sum2d_cpu_only(self):
def fn(a, b):
c = a * b
return c.sum()
metrics.reset()
x = torch.randn(50, 50)
y = torch.randn(50, 50).transpose(0, 1)
opt_fn = torch._dynamo.optimize("inductor")(fn)
self.assertTrue(same(fn(x, y), opt_fn(x, y)))
assert metrics.generated_cpp_vec_kernel_count == 2
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -26,6 +26,7 @@ from ..virtualized import ops, V
from .common import (
BracesBuffer,
CppWrapperKernelArgs,
CSE,
CSEVariable,
DeferredLine,
ExprPrinter,
@ -904,6 +905,7 @@ class CppKernel(Kernel):
self.reduction_prefix = IndentedBuffer()
self.reduction_suffix = IndentedBuffer()
self.reduction_var_map = {}
self.reduction_cse = CSE(self.newvar_prefix, self.suffix, name_prefix="tmp_acc")
self.preloads = IndentedBuffer()
self.poststores = IndentedBuffer()
self.num_threads = num_threads # num_threads the kernel specialized for
@ -955,7 +957,7 @@ class CppKernel(Kernel):
def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
argmax_or_argmin = reduction_type in {"argmax", "argmin"}
tmpvar = self.cse.generate(
tmpvar = self.reduction_cse.generate(
self.loads, f"reduction {name} {cexpr_index(index)}", write=False
)
index = self.rename_indexing(index)
@ -1044,19 +1046,26 @@ class CppKernel(Kernel):
if hasattr(kernel, "codegen_inner_loops"):
code.splice(kernel.poststores)
def get_reduction_code_buffer(loops, is_suffix=True):
for loop in loops:
for kernel in loop.get_kernels():
if is_suffix:
return kernel.reduction_suffix
else:
return kernel.reduction_prefix
return None
def gen_loops(loops: List[LoopLevel], in_reduction=False):
with contextlib.ExitStack() as stack_outer:
if loops:
loop = loops[0]
if loop.is_reduction() and not in_reduction:
kernels = loop.get_kernels()
assert kernels
# TODO(jgong5): should gen prefix for all kernels.
# currently, Vec kernel generates prefix for both
# vector and scalar kernels.
if kernels[0].reduction_prefix:
reduction_prefix = get_reduction_code_buffer(
loops, is_suffix=False
)
if reduction_prefix:
stack_outer.enter_context(code.indent())
code.splice(kernels[0].reduction_prefix)
code.splice(reduction_prefix)
if loop_nest.is_reduction_only() and loop.parallel:
worksharing.parallel(threads)
@ -1064,13 +1073,13 @@ class CppKernel(Kernel):
gen_loop(loop, in_reduction)
if loops:
loop = loops[0]
if loop_nest.is_reduction_only() and loop.parallel:
worksharing.close()
for loop in loops:
if loop.is_reduction() and not in_reduction:
kernels = loop.get_kernels()
for kernel in kernels:
code.splice(kernel.reduction_suffix)
if loop.is_reduction() and not in_reduction:
code.splice(
get_reduction_code_buffer(loops, is_suffix=True)
)
def gen_loop(loop: LoopLevel, in_reduction=False):
with contextlib.ExitStack() as stack:
@ -1264,7 +1273,7 @@ class CppVecKernel(CppKernel):
self.reduction_omp_dec[reduction_type] = RTYPE_TO_CPP[reduction_type]
self.reduction_prefix.writeline(vec_reduc_prefix)
tmpvar = self.cse.generate(
tmpvar = self.reduction_cse.generate(
self.loads, f"reduction {name} {cexpr_index(index)}", write=False
)
tmpvar_vec = f"{tmpvar}_vec"
@ -1283,9 +1292,6 @@ class CppVecKernel(CppKernel):
if self.tiling_idx >= self.reduction_depth:
# Horizontal reduction
# NOTE(jgong5): we do not generate the real stores here with the assumption that
# the scalar kernel that handles the loop tail would be generated and generates
# the stores there.
reduce_all_body = "{"
if reduction_type == "sum":
reduce_all_body += "return x + y;"
@ -1298,21 +1304,28 @@ class CppVecKernel(CppKernel):
next_value = f"{vec_reduce_all_func}([]({vec}& x, {vec}&y) {reduce_all_body}, {tmpvar_vec})"
self.reduction_suffix.writeline(
DeferredLine(
name,
f"{reduction_combine(reduction_type, tmpvar, next_value)};",
name, f"{reduction_combine(reduction_type, tmpvar, next_value)};"
)
)
elif name not in V.graph.removed_buffers:
# Vertical reduction
if name not in V.graph.removed_buffers:
var = self.args.output(name)
new_index = self.scale_index_with_offset(
index, self.tiling_factor, itervar_idx=self.tiling_idx
)
self.reduction_suffix.writeline(
DeferredLine(
name, f"{tmpvar_vec}.store({var} + {cexpr_index(new_index)});"
if self.tiling_idx >= self.reduction_depth:
# Horizontal reduction
self.reduction_suffix.writeline(
DeferredLine(name, f"{var}[{cexpr_index(index)}] = {tmpvar};")
)
)
else:
# Vertical reduction
new_index = self.scale_index_with_offset(
index, self.tiling_factor, itervar_idx=self.tiling_idx
)
self.reduction_suffix.writeline(
DeferredLine(
name, f"{tmpvar_vec}.store({var} + {cexpr_index(new_index)});"
)
)
self.cse.store_cache[name] = tmpvar
@ -1347,30 +1360,32 @@ class CppTile2DKernel(CppVecKernel):
...
"""
def __init__(self, args, num_threads, tiling_factor, outer_tiling_idx):
super().__init__(args, num_threads, tiling_factor)
self.outer_tiling_idx = outer_tiling_idx
def __init__(self, args, num_threads, tiling_factor, tiling_indices):
super().__init__(args, num_threads, tiling_factor, tiling_indices[1])
self.tiling_indices = tiling_indices
def inner_itervar(self):
return sympy.symbols(f"{self.itervars[self.outer_tiling_idx]}_inner")
return sympy.symbols(f"{self.itervars[self.outer_idx]}_inner")
def need_vec_transpose(self, index):
return self.is_stride1_at(
self.itervars[self.outer_tiling_idx], index
) and not self.is_invariant_under(self.itervars[-1], index)
self.itervars[self.outer_idx], index
) and not self.is_invariant_under(self.itervars[self.tiling_idx], index)
def gen_transposed_tile_load_store(self, name, var, index, is_store):
# transposed tile load/store outside the kernel inner loop
dtype = V.graph.get_dtype(name)
factor = self.tiling_factor
new_index = self.scale_index_with_offset(index, factor, itervar_idx=-1)
new_index = self.scale_index_with_offset(
new_index, factor, itervar_idx=self.outer_tiling_idx
index, factor, itervar_idx=self.tiling_idx
)
new_index = self.scale_index_with_offset(
new_index, factor, itervar_idx=self.outer_idx
)
src = f"{var} + {cexpr_index(new_index)}"
dst = "__place_holder__"
ld_src = f"{cexpr_index(self.stride_at(self.itervars[-1], index))}"
ld_src = f"{cexpr_index(self.stride_at(self.itervars[self.tiling_idx], index))}"
ld_dst = f"{factor}"
if is_store:
src, dst = dst, src
@ -1424,7 +1439,7 @@ class CppTile2DKernel(CppVecKernel):
new_index = self.scale_index_with_offset(
expanded_index,
self.tiling_factor,
itervar_idx=self.outer_tiling_idx,
itervar_idx=self.outer_idx,
offset=inner,
)
return super().load(name, new_index)
@ -1457,7 +1472,7 @@ class CppTile2DKernel(CppVecKernel):
new_index = self.scale_index_with_offset(
expanded_index,
self.tiling_factor,
itervar_idx=self.outer_tiling_idx,
itervar_idx=self.outer_idx,
offset=inner,
)
super().store(name, new_index, value, mode)
@ -1468,6 +1483,16 @@ class CppTile2DKernel(CppVecKernel):
f"for (long {inner} = 0; {inner} < {self.tiling_factor}; {inner}++)"
)
def set_ranges(self, group, reduction_group):
vars = super().set_ranges(group, reduction_group)
# do vertical reduction as the tail loop
self.outer_idx, self.tiling_idx = (
self.tiling_indices
if self.tiling_indices[1] < self.reduction_depth
else reversed(self.tiling_indices)
)
return vars
class CppVecKernelChecker(CppVecKernel):
def __init__(self, args, num_threads, tiling_factor, tiling_idx=-1):
@ -2245,11 +2270,7 @@ class CppKernelProxy(CppKernel):
if vec_checker.simd_vec:
if len(tiling_indices) == 1:
return [tiling_factor], tiling_indices
if len(tiling_indices) == 2 and self.reduction_depth == len(
self.itervars
):
# TODO(jgong5): support tile2d with reduction
assert tiling_indices[1] == len(self.itervars) - 1
if len(tiling_indices) == 2:
return [tiling_factor, tiling_factor], tiling_indices
return [], []
@ -2286,12 +2307,10 @@ class CppKernelProxy(CppKernel):
)
outer_tail_loop.set_kernel(scalar_kernel)
inner_main_loop, inner_tail_loop = outer_main_loop.split_with_tiling(
inner_most_idx - tiling_indices[0], factor=tiling_factors[0]
tiling_indices[1] - tiling_indices[0], factor=tiling_factors[0]
)
inner_main_loop.set_kernel(
codegen_kernel(
CppTile2DKernel, tiling_factors[0], tiling_indices[0]
)
codegen_kernel(CppTile2DKernel, tiling_factors[0], tiling_indices)
)
inner_tail_loop.set_kernel(
codegen_kernel(CppVecKernel, tiling_factors[0], tiling_indices[0])