[Inductor] support transpose vertical reduction in cpp (#97781)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97781
Approved by: https://github.com/jansel
commit 5d62d12557
parent 76074dc0a3
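The change teaches the Inductor C++ backend to vectorize reductions where the tiled (vectorized) dimension is not the reduced one — the "vertical reduction" named in the kernel code below — which shows up when one operand is transposed. Below is a minimal standalone sketch of the pattern, written to mirror the shapes and transpose of the new test_transpose_vertical_sum_cpu_only test further down; it is illustrative only and not part of the commit.

# Illustrative repro only (not part of the commit); mirrors the new test below.
import torch
import torch._dynamo as dynamo

def fn(a, b):
    c = a * b
    return c.sum(dim=1)  # reduce dim 1; dim 2 survives and can be vectorized

x = torch.randn(100, 50, 50)
y = torch.randn(100, 50, 50).transpose(1, 2)  # strided along the reduced dim

opt_fn = dynamo.optimize("inductor")(fn)
torch.testing.assert_close(fn(x, y), opt_fn(x, y))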
--- a/test/inductor/test_cpu_repro.py
+++ b/test/inductor/test_cpu_repro.py
@@ -1211,6 +1211,30 @@ class CPUReproTests(TestCase):
         self.assertTrue(same(fn2(x), opt_fn2(x)))
         assert metrics.generated_cpp_vec_kernel_count == 1
 
+    def test_transpose_vertical_sum_cpu_only(self):
+        def fn(a, b):
+            c = a * b
+            return c.sum(dim=1)
+
+        metrics.reset()
+        x = torch.randn(100, 50, 50)
+        y = torch.randn(100, 50, 50).transpose(1, 2)
+        opt_fn = torch._dynamo.optimize("inductor")(fn)
+        self.assertTrue(same(fn(x, y), opt_fn(x, y)))
+        assert metrics.generated_cpp_vec_kernel_count == 2
+
+    def test_transpose_sum2d_cpu_only(self):
+        def fn(a, b):
+            c = a * b
+            return c.sum()
+
+        metrics.reset()
+        x = torch.randn(50, 50)
+        y = torch.randn(50, 50).transpose(0, 1)
+        opt_fn = torch._dynamo.optimize("inductor")(fn)
+        self.assertTrue(same(fn(x, y), opt_fn(x, y)))
+        assert metrics.generated_cpp_vec_kernel_count == 2
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -26,6 +26,7 @@ from ..virtualized import ops, V
 from .common import (
     BracesBuffer,
     CppWrapperKernelArgs,
+    CSE,
     CSEVariable,
     DeferredLine,
     ExprPrinter,
@@ -904,6 +905,7 @@ class CppKernel(Kernel):
         self.reduction_prefix = IndentedBuffer()
         self.reduction_suffix = IndentedBuffer()
         self.reduction_var_map = {}
+        self.reduction_cse = CSE(self.newvar_prefix, self.suffix, name_prefix="tmp_acc")
         self.preloads = IndentedBuffer()
         self.poststores = IndentedBuffer()
         self.num_threads = num_threads  # num_threads the kernel specialized for
@@ -955,7 +957,7 @@ class CppKernel(Kernel):
 
     def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
         argmax_or_argmin = reduction_type in {"argmax", "argmin"}
-        tmpvar = self.cse.generate(
+        tmpvar = self.reduction_cse.generate(
             self.loads, f"reduction {name} {cexpr_index(index)}", write=False
         )
         index = self.rename_indexing(index)
@@ -1044,19 +1046,26 @@ class CppKernel(Kernel):
             if hasattr(kernel, "codegen_inner_loops"):
                 code.splice(kernel.poststores)
 
+        def get_reduction_code_buffer(loops, is_suffix=True):
+            for loop in loops:
+                for kernel in loop.get_kernels():
+                    if is_suffix:
+                        return kernel.reduction_suffix
+                    else:
+                        return kernel.reduction_prefix
+            return None
+
         def gen_loops(loops: List[LoopLevel], in_reduction=False):
             with contextlib.ExitStack() as stack_outer:
                 if loops:
                     loop = loops[0]
                     if loop.is_reduction() and not in_reduction:
-                        kernels = loop.get_kernels()
-                        assert kernels
-                        # TODO(jgong5): should gen prefix for all kernels.
-                        # currently, Vec kernel generates prefix for both
-                        # vector and scalar kernels.
-                        if kernels[0].reduction_prefix:
+                        reduction_prefix = get_reduction_code_buffer(
+                            loops, is_suffix=False
+                        )
+                        if reduction_prefix:
                             stack_outer.enter_context(code.indent())
-                            code.splice(kernels[0].reduction_prefix)
+                            code.splice(reduction_prefix)
                     if loop_nest.is_reduction_only() and loop.parallel:
                         worksharing.parallel(threads)
 
@@ -1064,13 +1073,13 @@ class CppKernel(Kernel):
                     gen_loop(loop, in_reduction)
 
                 if loops:
                     loop = loops[0]
                     if loop_nest.is_reduction_only() and loop.parallel:
                         worksharing.close()
-                for loop in loops:
-                    if loop.is_reduction() and not in_reduction:
-                        kernels = loop.get_kernels()
-                        for kernel in kernels:
-                            code.splice(kernel.reduction_suffix)
+                    if loop.is_reduction() and not in_reduction:
+                        code.splice(
+                            get_reduction_code_buffer(loops, is_suffix=True)
+                        )
 
         def gen_loop(loop: LoopLevel, in_reduction=False):
             with contextlib.ExitStack() as stack:
@@ -1264,7 +1273,7 @@ class CppVecKernel(CppKernel):
             self.reduction_omp_dec[reduction_type] = RTYPE_TO_CPP[reduction_type]
             self.reduction_prefix.writeline(vec_reduc_prefix)
 
-        tmpvar = self.cse.generate(
+        tmpvar = self.reduction_cse.generate(
             self.loads, f"reduction {name} {cexpr_index(index)}", write=False
         )
         tmpvar_vec = f"{tmpvar}_vec"
@@ -1283,9 +1292,6 @@ class CppVecKernel(CppKernel):
 
         if self.tiling_idx >= self.reduction_depth:
             # Horizontal reduction
-            # NOTE(jgong5): we do not generate the real stores here with the assumption that
-            # the scalar kernel that handles the loop tail would be generated and generates
-            # the stores there.
             reduce_all_body = "{"
             if reduction_type == "sum":
                 reduce_all_body += "return x + y;"
@@ -1298,21 +1304,28 @@ class CppVecKernel(CppKernel):
             next_value = f"{vec_reduce_all_func}([]({vec}& x, {vec}&y) {reduce_all_body}, {tmpvar_vec})"
             self.reduction_suffix.writeline(
                 DeferredLine(
-                    name,
-                    f"{reduction_combine(reduction_type, tmpvar, next_value)};",
+                    name, f"{reduction_combine(reduction_type, tmpvar, next_value)};"
                 )
             )
-        elif name not in V.graph.removed_buffers:
-            # Vertical reduction
+
+        if name not in V.graph.removed_buffers:
             var = self.args.output(name)
-            new_index = self.scale_index_with_offset(
-                index, self.tiling_factor, itervar_idx=self.tiling_idx
-            )
-            self.reduction_suffix.writeline(
-                DeferredLine(
-                    name, f"{tmpvar_vec}.store({var} + {cexpr_index(new_index)});"
+            if self.tiling_idx >= self.reduction_depth:
+                # Horizontal reduction
+                self.reduction_suffix.writeline(
+                    DeferredLine(name, f"{var}[{cexpr_index(index)}] = {tmpvar};")
                 )
-            )
+            else:
+                # Vertical reduction
+                new_index = self.scale_index_with_offset(
+                    index, self.tiling_factor, itervar_idx=self.tiling_idx
+                )
+                self.reduction_suffix.writeline(
+                    DeferredLine(
+                        name, f"{tmpvar_vec}.store({var} + {cexpr_index(new_index)});"
+                    )
+                )
 
         self.cse.store_cache[name] = tmpvar
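The two branches above emit different reduction stores: the horizontal case collapses the SIMD accumulator into one scalar and writes a single output element, while the new vertical case keeps one partial result per lane and stores the whole vector. A rough, illustrative-only Python sketch of the two store patterns follows (this is not the generated C++, and the helper names are made up for the example).

# Conceptual sketch only; "acc_vec" stands in for the tiling_factor-wide SIMD
# accumulator (tmpvar_vec above) and "out" for the output buffer.
def horizontal_store(acc_vec, out, i):
    # Horizontal reduction: reduce across the lanes and write one scalar,
    # mirroring the `out[i] = tmpvar` DeferredLine above.
    out[i] = sum(acc_vec)

def vertical_store(acc_vec, out, base, tiling_factor):
    # Vertical reduction: each lane already holds the result for a different
    # output element, so the whole vector is stored contiguously, mirroring
    # `tmpvar_vec.store(var + new_index)` above.
    out[base : base + tiling_factor] = acc_vec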
@@ -1347,30 +1360,32 @@ class CppTile2DKernel(CppVecKernel):
     ...
     """
 
-    def __init__(self, args, num_threads, tiling_factor, outer_tiling_idx):
-        super().__init__(args, num_threads, tiling_factor)
-        self.outer_tiling_idx = outer_tiling_idx
+    def __init__(self, args, num_threads, tiling_factor, tiling_indices):
+        super().__init__(args, num_threads, tiling_factor, tiling_indices[1])
+        self.tiling_indices = tiling_indices
 
     def inner_itervar(self):
-        return sympy.symbols(f"{self.itervars[self.outer_tiling_idx]}_inner")
+        return sympy.symbols(f"{self.itervars[self.outer_idx]}_inner")
 
     def need_vec_transpose(self, index):
         return self.is_stride1_at(
-            self.itervars[self.outer_tiling_idx], index
-        ) and not self.is_invariant_under(self.itervars[-1], index)
+            self.itervars[self.outer_idx], index
+        ) and not self.is_invariant_under(self.itervars[self.tiling_idx], index)
 
     def gen_transposed_tile_load_store(self, name, var, index, is_store):
         # transposed tile load/store outside the kernel inner loop
         dtype = V.graph.get_dtype(name)
         factor = self.tiling_factor
-        new_index = self.scale_index_with_offset(index, factor, itervar_idx=-1)
         new_index = self.scale_index_with_offset(
-            new_index, factor, itervar_idx=self.outer_tiling_idx
+            index, factor, itervar_idx=self.tiling_idx
         )
+        new_index = self.scale_index_with_offset(
+            new_index, factor, itervar_idx=self.outer_idx
+        )
 
         src = f"{var} + {cexpr_index(new_index)}"
         dst = "__place_holder__"
-        ld_src = f"{cexpr_index(self.stride_at(self.itervars[-1], index))}"
+        ld_src = f"{cexpr_index(self.stride_at(self.itervars[self.tiling_idx], index))}"
         ld_dst = f"{factor}"
         if is_store:
             src, dst = dst, src
@@ -1424,7 +1439,7 @@ class CppTile2DKernel(CppVecKernel):
             new_index = self.scale_index_with_offset(
                 expanded_index,
                 self.tiling_factor,
-                itervar_idx=self.outer_tiling_idx,
+                itervar_idx=self.outer_idx,
                 offset=inner,
             )
             return super().load(name, new_index)
@@ -1457,7 +1472,7 @@ class CppTile2DKernel(CppVecKernel):
             new_index = self.scale_index_with_offset(
                 expanded_index,
                 self.tiling_factor,
-                itervar_idx=self.outer_tiling_idx,
+                itervar_idx=self.outer_idx,
                 offset=inner,
             )
             super().store(name, new_index, value, mode)
@@ -1468,6 +1483,16 @@ class CppTile2DKernel(CppVecKernel):
             f"for (long {inner} = 0; {inner} < {self.tiling_factor}; {inner}++)"
         )
 
+    def set_ranges(self, group, reduction_group):
+        vars = super().set_ranges(group, reduction_group)
+        # do vertical reduction as the tail loop
+        self.outer_idx, self.tiling_idx = (
+            self.tiling_indices
+            if self.tiling_indices[1] < self.reduction_depth
+            else reversed(self.tiling_indices)
+        )
+        return vars
+
 
 class CppVecKernelChecker(CppVecKernel):
     def __init__(self, args, num_threads, tiling_factor, tiling_idx=-1):
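The new set_ranges override picks which of the two tiled loop levels becomes the vectorized tiling index and which becomes the outer index, so that (per the comment above) the vertical reduction ends up in the tail loop. A small standalone illustration of that selection, with invented example values:

# Invented example values; only the selection logic mirrors set_ranges above.
tiling_indices = (0, 1)  # the two tiled loop levels, innermost last
reduction_depth = 1      # loop levels at or beyond this index are reduction loops

outer_idx, tiling_idx = (
    tiling_indices
    if tiling_indices[1] < reduction_depth
    else tuple(reversed(tiling_indices))
)

# tiling_indices[1] (1) is not < reduction_depth (1), so the pair is reversed:
# the reduction level becomes outer_idx and the non-reduction level is the one
# that gets vectorized.
assert (outer_idx, tiling_idx) == (1, 0)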
@@ -2245,11 +2270,7 @@ class CppKernelProxy(CppKernel):
                 if vec_checker.simd_vec:
                     if len(tiling_indices) == 1:
                         return [tiling_factor], tiling_indices
-                    if len(tiling_indices) == 2 and self.reduction_depth == len(
-                        self.itervars
-                    ):
-                        # TODO(jgong5): support tile2d with reduction
-                        assert tiling_indices[1] == len(self.itervars) - 1
+                    if len(tiling_indices) == 2:
                         return [tiling_factor, tiling_factor], tiling_indices
             return [], []
 
@@ -2286,12 +2307,10 @@ class CppKernelProxy(CppKernel):
                 )
                 outer_tail_loop.set_kernel(scalar_kernel)
                 inner_main_loop, inner_tail_loop = outer_main_loop.split_with_tiling(
-                    inner_most_idx - tiling_indices[0], factor=tiling_factors[0]
+                    tiling_indices[1] - tiling_indices[0], factor=tiling_factors[0]
                 )
                 inner_main_loop.set_kernel(
-                    codegen_kernel(
-                        CppTile2DKernel, tiling_factors[0], tiling_indices[0]
-                    )
+                    codegen_kernel(CppTile2DKernel, tiling_factors[0], tiling_indices)
                 )
                 inner_tail_loop.set_kernel(
                     codegen_kernel(CppVecKernel, tiling_factors[0], tiling_indices[0])