Inductor cpp wrapper: support Reduction (#88561)

For reductions, the code string in the codegen stage and the code string in the execution stage differ because of the trailing `\` line continuations.

- The code string obtained from `code.getvalue()` (`code` is an `IndentedBuffer`) in the codegen stage:
  ```
  #pragma omp declare reduction(argmax : struct IndexValue_1 :\
                  omp_out.value = omp_in.value < omp_out.value ? omp_out.value : omp_in.value,\
                  omp_out.index = omp_in.value < omp_out.value ? omp_out.index : omp_in.index)\
                  initializer(omp_priv = {0, -std::numeric_limits<float>::infinity()})
  ```

- The code string loaded during execution (each trailing `\` is consumed together with the newline after it, so the continued lines are joined into one):
  ```
  #pragma omp declare reduction(argmax : struct IndexValue_1 :                omp_out.value = omp_in.value < omp_out.value ? omp_out.value : omp_in.value,                omp_out.index = omp_in.value < omp_out.value ? omp_out.index : omp_in.index)                  initializer(omp_priv = {0, -std::numeric_limits<float>::infinity()})
  ```

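As a result, the two code strings do not produce the same hash value.
This PR adds a function, `IndentedBuffer.getrawvalue()`, that applies the same transformation in the codegen stage: a line ending in a backslash is written without the backslash and joined with the line that follows it.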

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88561
Approved by: https://github.com/jgong5, https://github.com/jansel, https://github.com/desertfire
This commit is contained in:
chunyuan 2022-12-14 15:43:28 +00:00 committed by PyTorch MergeBot
parent 7963dbf3db
commit d35aa2f65a
4 changed files with 27 additions and 4 deletions

View File

@ -4922,6 +4922,7 @@ class CommonTemplate:
"test_cat", # alias
"test_lowmem_dropout1", # None as output
"test_profiler_mark_wrapper_call", # TODO: fallback to default wrapper for now
"test_reduction1", # Reduction
"test_relu", # multiple inputs
"test_silu", # single input, single output
"test_transpose", # multiple outputs, buffer clear

View File

@ -146,6 +146,22 @@ class IndentedBuffer:
buf.write("\n")
return buf.getvalue()
# Build one string from self._lines; lines that end in a backslash are written
# without that backslash.
# NOTE(review): indentation is not preserved in this view, so the nesting of the
# statements below is inferred rather than shown — confirm against the real file.
def getrawvalue(self):
buf = StringIO()
for line in self._lines:
# A DeferredLine entry is called to obtain the actual line.
if isinstance(line, DeferredLine):
line = line()
# A None line contributes nothing to the output.
if line is None:
continue
assert isinstance(line, str)
# backslash implies line continuation
if line.endswith("\\"):
# Write the line minus its final backslash; presumably no newline follows,
# so the next line continues on the same output line — TODO confirm.
buf.write(line[:-1])
else:
buf.write(line)
buf.write("\n")
return buf.getvalue()
# Remove every entry from self._lines.
def clear(self):
self._lines.clear()

View File

@ -651,7 +651,16 @@ class CppWrapperCodeGen(WrapperCodeGen):
ext = "so"
extra = cpp_compile_command("i", "o", vec_isa=picked_vec_isa)
# \n is required to match with the CodeCache behavior
source_code = "\n" + code.getvalue()
# For reductions, the code string returned by code.getvalue() ends some lines with a
# backslash '\' (line continuation) for readability:
# #pragma omp declare reduction(xxx :\
# omp_out.value = xxx,\
# The code string loaded during execution has those continuations already joined
# (the trailing backslash is dropped together with the newline after it):
# #pragma omp declare reduction(xxx : omp_out.value = xxx,
# Use code.getrawvalue() here, which joins the continued lines the same way,
# so that the same code string is used during compilation and execution
# and the hash value is the same.
source_code = "\n" + code.getrawvalue()
_, _, kernel_path = get_code_path(source_code, ext, extra)
return kernel_path

View File

@ -158,9 +158,6 @@ class GraphLowering(torch.fx.Interpreter):
# Call self.disable_cpp_wrapper(...) when the given buffer matches one of the
# checks below.
# NOTE(review): the parameter is annotated ir.ComputedBuffer, yet the first check
# tests for ir.ExternKernel — the annotation may be narrower than what callers
# actually pass; confirm.
def check_buffer_for_cpp_wrapper(self, buffer: ir.ComputedBuffer):
if isinstance(buffer, ir.ExternKernel):
self.disable_cpp_wrapper("ExternKernel")
# NOTE(review): the hunk header (-158,9 +158,6) drops three lines and the commit
# message says reductions are now supported, so the three lines below are
# presumably the ones this commit removes; the +/- markers are not visible in
# this view — confirm.
if isinstance(buffer, ir.ComputedBuffer):
# Pass "Reduction" when get_reduction_type() returns a truthy value.
if buffer.data.get_reduction_type():
self.disable_cpp_wrapper("Reduction")
def register_buffer(self, buffer: ir.ComputedBuffer):
if config.cpp_wrapper: