[Inductor] support transpose vertical reduction in cpp (#97781)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97781
Approved by: https://github.com/jansel
This commit is contained in:
Jiong Gong 2023-04-02 10:06:47 +00:00 committed by PyTorch MergeBot
parent 76074dc0a3
commit 5d62d12557
2 changed files with 91 additions and 48 deletions

View File

@ -1211,6 +1211,30 @@ class CPUReproTests(TestCase):
self.assertTrue(same(fn2(x), opt_fn2(x))) self.assertTrue(same(fn2(x), opt_fn2(x)))
assert metrics.generated_cpp_vec_kernel_count == 1 assert metrics.generated_cpp_vec_kernel_count == 1
def test_transpose_vertical_sum_cpu_only(self):
    """Compare eager and Inductor results for a dim-1 sum of a product.

    One operand is a (100, 50, 50) tensor and the other is a
    (100, 50, 50) tensor with its last two dims transposed.  Besides
    matching results, the test expects exactly two C++ vectorized
    kernels to have been generated.
    """

    def fn(a, b):
        return (a * b).sum(dim=1)

    metrics.reset()
    # Keep the creation order (plain operand first) so the RNG stream
    # is consumed the same way on every run.
    lhs = torch.randn(100, 50, 50)
    rhs = torch.randn(100, 50, 50).transpose(1, 2)
    compiled_fn = torch._dynamo.optimize("inductor")(fn)
    expected = fn(lhs, rhs)
    actual = compiled_fn(lhs, rhs)
    self.assertTrue(same(expected, actual))
    assert metrics.generated_cpp_vec_kernel_count == 2
def test_transpose_sum2d_cpu_only(self):
    """Compare eager and Inductor results for a full sum of a 2D product.

    One operand is a (50, 50) tensor and the other is a (50, 50) tensor
    with its two dims transposed.  Besides matching results, the test
    expects exactly two C++ vectorized kernels to have been generated.
    """

    def fn(a, b):
        return (a * b).sum()

    metrics.reset()
    # Keep the creation order (plain operand first) so the RNG stream
    # is consumed the same way on every run.
    lhs = torch.randn(50, 50)
    rhs = torch.randn(50, 50).transpose(0, 1)
    compiled_fn = torch._dynamo.optimize("inductor")(fn)
    expected = fn(lhs, rhs)
    actual = compiled_fn(lhs, rhs)
    self.assertTrue(same(expected, actual))
    assert metrics.generated_cpp_vec_kernel_count == 2
if __name__ == "__main__": if __name__ == "__main__":
from torch._dynamo.test_case import run_tests from torch._dynamo.test_case import run_tests

View File

@ -26,6 +26,7 @@ from ..virtualized import ops, V
from .common import ( from .common import (
BracesBuffer, BracesBuffer,
CppWrapperKernelArgs, CppWrapperKernelArgs,
CSE,
CSEVariable, CSEVariable,
DeferredLine, DeferredLine,
ExprPrinter, ExprPrinter,
@ -904,6 +905,7 @@ class CppKernel(Kernel):
self.reduction_prefix = IndentedBuffer() self.reduction_prefix = IndentedBuffer()
self.reduction_suffix = IndentedBuffer() self.reduction_suffix = IndentedBuffer()
self.reduction_var_map = {} self.reduction_var_map = {}
self.reduction_cse = CSE(self.newvar_prefix, self.suffix, name_prefix="tmp_acc")
self.preloads = IndentedBuffer() self.preloads = IndentedBuffer()
self.poststores = IndentedBuffer() self.poststores = IndentedBuffer()
self.num_threads = num_threads # num_threads the kernel specialized for self.num_threads = num_threads # num_threads the kernel specialized for
@ -955,7 +957,7 @@ class CppKernel(Kernel):
def reduction(self, name, dtype, src_dtype, reduction_type, index, value): def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
argmax_or_argmin = reduction_type in {"argmax", "argmin"} argmax_or_argmin = reduction_type in {"argmax", "argmin"}
tmpvar = self.cse.generate( tmpvar = self.reduction_cse.generate(
self.loads, f"reduction {name} {cexpr_index(index)}", write=False self.loads, f"reduction {name} {cexpr_index(index)}", write=False
) )
index = self.rename_indexing(index) index = self.rename_indexing(index)
@ -1044,19 +1046,26 @@ class CppKernel(Kernel):
if hasattr(kernel, "codegen_inner_loops"): if hasattr(kernel, "codegen_inner_loops"):
code.splice(kernel.poststores) code.splice(kernel.poststores)
def get_reduction_code_buffer(loops, is_suffix=True):
    """Return the reduction buffer of the first kernel found in ``loops``.

    ``loops`` is scanned in order and each loop's ``get_kernels()`` is
    consulted; the very first kernel encountered decides the result.

    Args:
        loops: iterable of objects exposing ``get_kernels()``.
        is_suffix: when true, return that kernel's ``reduction_suffix``;
            otherwise return its ``reduction_prefix``.

    Returns:
        The selected buffer, or ``None`` when no loop yields a kernel.
    """
    buffer_attr = "reduction_suffix" if is_suffix else "reduction_prefix"
    for level in loops:
        for first_kernel in level.get_kernels():
            # Only the first kernel is consulted; later ones are ignored.
            return getattr(first_kernel, buffer_attr)
    return None
def gen_loops(loops: List[LoopLevel], in_reduction=False): def gen_loops(loops: List[LoopLevel], in_reduction=False):
with contextlib.ExitStack() as stack_outer: with contextlib.ExitStack() as stack_outer:
if loops: if loops:
loop = loops[0] loop = loops[0]
if loop.is_reduction() and not in_reduction: if loop.is_reduction() and not in_reduction:
kernels = loop.get_kernels() reduction_prefix = get_reduction_code_buffer(
assert kernels loops, is_suffix=False
# TODO(jgong5): should gen prefix for all kernels. )
# currently, Vec kernel generates prefix for both if reduction_prefix:
# vector and scalar kernels.
if kernels[0].reduction_prefix:
stack_outer.enter_context(code.indent()) stack_outer.enter_context(code.indent())
code.splice(kernels[0].reduction_prefix) code.splice(reduction_prefix)
if loop_nest.is_reduction_only() and loop.parallel: if loop_nest.is_reduction_only() and loop.parallel:
worksharing.parallel(threads) worksharing.parallel(threads)
@ -1064,13 +1073,13 @@ class CppKernel(Kernel):
gen_loop(loop, in_reduction) gen_loop(loop, in_reduction)
if loops: if loops:
loop = loops[0]
if loop_nest.is_reduction_only() and loop.parallel: if loop_nest.is_reduction_only() and loop.parallel:
worksharing.close() worksharing.close()
for loop in loops: if loop.is_reduction() and not in_reduction:
if loop.is_reduction() and not in_reduction: code.splice(
kernels = loop.get_kernels() get_reduction_code_buffer(loops, is_suffix=True)
for kernel in kernels: )
code.splice(kernel.reduction_suffix)
def gen_loop(loop: LoopLevel, in_reduction=False): def gen_loop(loop: LoopLevel, in_reduction=False):
with contextlib.ExitStack() as stack: with contextlib.ExitStack() as stack:
@ -1264,7 +1273,7 @@ class CppVecKernel(CppKernel):
self.reduction_omp_dec[reduction_type] = RTYPE_TO_CPP[reduction_type] self.reduction_omp_dec[reduction_type] = RTYPE_TO_CPP[reduction_type]
self.reduction_prefix.writeline(vec_reduc_prefix) self.reduction_prefix.writeline(vec_reduc_prefix)
tmpvar = self.cse.generate( tmpvar = self.reduction_cse.generate(
self.loads, f"reduction {name} {cexpr_index(index)}", write=False self.loads, f"reduction {name} {cexpr_index(index)}", write=False
) )
tmpvar_vec = f"{tmpvar}_vec" tmpvar_vec = f"{tmpvar}_vec"
@ -1283,9 +1292,6 @@ class CppVecKernel(CppKernel):
if self.tiling_idx >= self.reduction_depth: if self.tiling_idx >= self.reduction_depth:
# Horizontal reduction # Horizontal reduction
# NOTE(jgong5): we do not generate the real stores here with the assumption that
# the scalar kernel that handles the loop tail would be generated and generates
# the stores there.
reduce_all_body = "{" reduce_all_body = "{"
if reduction_type == "sum": if reduction_type == "sum":
reduce_all_body += "return x + y;" reduce_all_body += "return x + y;"
@ -1298,21 +1304,28 @@ class CppVecKernel(CppKernel):
next_value = f"{vec_reduce_all_func}([]({vec}& x, {vec}&y) {reduce_all_body}, {tmpvar_vec})" next_value = f"{vec_reduce_all_func}([]({vec}& x, {vec}&y) {reduce_all_body}, {tmpvar_vec})"
self.reduction_suffix.writeline( self.reduction_suffix.writeline(
DeferredLine( DeferredLine(
name, name, f"{reduction_combine(reduction_type, tmpvar, next_value)};"
f"{reduction_combine(reduction_type, tmpvar, next_value)};",
) )
) )
elif name not in V.graph.removed_buffers:
# Vertical reduction if name not in V.graph.removed_buffers:
var = self.args.output(name) var = self.args.output(name)
new_index = self.scale_index_with_offset( if self.tiling_idx >= self.reduction_depth:
index, self.tiling_factor, itervar_idx=self.tiling_idx # Horizontal reduction
) self.reduction_suffix.writeline(
self.reduction_suffix.writeline( DeferredLine(name, f"{var}[{cexpr_index(index)}] = {tmpvar};")
DeferredLine(
name, f"{tmpvar_vec}.store({var} + {cexpr_index(new_index)});"
) )
) else:
# Vertical reduction
new_index = self.scale_index_with_offset(
index, self.tiling_factor, itervar_idx=self.tiling_idx
)
self.reduction_suffix.writeline(
DeferredLine(
name, f"{tmpvar_vec}.store({var} + {cexpr_index(new_index)});"
)
)
self.cse.store_cache[name] = tmpvar self.cse.store_cache[name] = tmpvar
@ -1347,30 +1360,32 @@ class CppTile2DKernel(CppVecKernel):
... ...
""" """
def __init__(self, args, num_threads, tiling_factor, outer_tiling_idx): def __init__(self, args, num_threads, tiling_factor, tiling_indices):
super().__init__(args, num_threads, tiling_factor) super().__init__(args, num_threads, tiling_factor, tiling_indices[1])
self.outer_tiling_idx = outer_tiling_idx self.tiling_indices = tiling_indices
def inner_itervar(self): def inner_itervar(self):
return sympy.symbols(f"{self.itervars[self.outer_tiling_idx]}_inner") return sympy.symbols(f"{self.itervars[self.outer_idx]}_inner")
def need_vec_transpose(self, index): def need_vec_transpose(self, index):
return self.is_stride1_at( return self.is_stride1_at(
self.itervars[self.outer_tiling_idx], index self.itervars[self.outer_idx], index
) and not self.is_invariant_under(self.itervars[-1], index) ) and not self.is_invariant_under(self.itervars[self.tiling_idx], index)
def gen_transposed_tile_load_store(self, name, var, index, is_store): def gen_transposed_tile_load_store(self, name, var, index, is_store):
# transposed tile load/store outside the kernel inner loop # transposed tile load/store outside the kernel inner loop
dtype = V.graph.get_dtype(name) dtype = V.graph.get_dtype(name)
factor = self.tiling_factor factor = self.tiling_factor
new_index = self.scale_index_with_offset(index, factor, itervar_idx=-1)
new_index = self.scale_index_with_offset( new_index = self.scale_index_with_offset(
new_index, factor, itervar_idx=self.outer_tiling_idx index, factor, itervar_idx=self.tiling_idx
)
new_index = self.scale_index_with_offset(
new_index, factor, itervar_idx=self.outer_idx
) )
src = f"{var} + {cexpr_index(new_index)}" src = f"{var} + {cexpr_index(new_index)}"
dst = "__place_holder__" dst = "__place_holder__"
ld_src = f"{cexpr_index(self.stride_at(self.itervars[-1], index))}" ld_src = f"{cexpr_index(self.stride_at(self.itervars[self.tiling_idx], index))}"
ld_dst = f"{factor}" ld_dst = f"{factor}"
if is_store: if is_store:
src, dst = dst, src src, dst = dst, src
@ -1424,7 +1439,7 @@ class CppTile2DKernel(CppVecKernel):
new_index = self.scale_index_with_offset( new_index = self.scale_index_with_offset(
expanded_index, expanded_index,
self.tiling_factor, self.tiling_factor,
itervar_idx=self.outer_tiling_idx, itervar_idx=self.outer_idx,
offset=inner, offset=inner,
) )
return super().load(name, new_index) return super().load(name, new_index)
@ -1457,7 +1472,7 @@ class CppTile2DKernel(CppVecKernel):
new_index = self.scale_index_with_offset( new_index = self.scale_index_with_offset(
expanded_index, expanded_index,
self.tiling_factor, self.tiling_factor,
itervar_idx=self.outer_tiling_idx, itervar_idx=self.outer_idx,
offset=inner, offset=inner,
) )
super().store(name, new_index, value, mode) super().store(name, new_index, value, mode)
@ -1468,6 +1483,16 @@ class CppTile2DKernel(CppVecKernel):
f"for (long {inner} = 0; {inner} < {self.tiling_factor}; {inner}++)" f"for (long {inner} = 0; {inner} < {self.tiling_factor}; {inner}++)"
) )
def set_ranges(self, group, reduction_group):
    """Set the ranges via the parent class, then pick outer/tiling indices.

    After delegating to ``super().set_ranges``, the two entries of
    ``self.tiling_indices`` are assigned to ``self.outer_idx`` and
    ``self.tiling_idx``: in their stored order when the second entry is
    below ``self.reduction_depth``, swapped otherwise.

    Returns:
        Whatever the parent ``set_ranges`` returned.
    """
    itervars = super().set_ranges(group, reduction_group)
    first_idx, second_idx = self.tiling_indices
    # do vertical reduction as the tail loop
    if second_idx < self.reduction_depth:
        self.outer_idx, self.tiling_idx = first_idx, second_idx
    else:
        self.outer_idx, self.tiling_idx = second_idx, first_idx
    return itervars
class CppVecKernelChecker(CppVecKernel): class CppVecKernelChecker(CppVecKernel):
def __init__(self, args, num_threads, tiling_factor, tiling_idx=-1): def __init__(self, args, num_threads, tiling_factor, tiling_idx=-1):
@ -2245,11 +2270,7 @@ class CppKernelProxy(CppKernel):
if vec_checker.simd_vec: if vec_checker.simd_vec:
if len(tiling_indices) == 1: if len(tiling_indices) == 1:
return [tiling_factor], tiling_indices return [tiling_factor], tiling_indices
if len(tiling_indices) == 2 and self.reduction_depth == len( if len(tiling_indices) == 2:
self.itervars
):
# TODO(jgong5): support tile2d with reduction
assert tiling_indices[1] == len(self.itervars) - 1
return [tiling_factor, tiling_factor], tiling_indices return [tiling_factor, tiling_factor], tiling_indices
return [], [] return [], []
@ -2286,12 +2307,10 @@ class CppKernelProxy(CppKernel):
) )
outer_tail_loop.set_kernel(scalar_kernel) outer_tail_loop.set_kernel(scalar_kernel)
inner_main_loop, inner_tail_loop = outer_main_loop.split_with_tiling( inner_main_loop, inner_tail_loop = outer_main_loop.split_with_tiling(
inner_most_idx - tiling_indices[0], factor=tiling_factors[0] tiling_indices[1] - tiling_indices[0], factor=tiling_factors[0]
) )
inner_main_loop.set_kernel( inner_main_loop.set_kernel(
codegen_kernel( codegen_kernel(CppTile2DKernel, tiling_factors[0], tiling_indices)
CppTile2DKernel, tiling_factors[0], tiling_indices[0]
)
) )
inner_tail_loop.set_kernel( inner_tail_loop.set_kernel(
codegen_kernel(CppVecKernel, tiling_factors[0], tiling_indices[0]) codegen_kernel(CppVecKernel, tiling_factors[0], tiling_indices[0])