mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Inductor cpp wrapper: support Reduction (#88561)
For reductions, the code string produced in the codegen stage differs from the one loaded in the execution stage because of the trailing line-continuation backslash (`\`).
- The code string obtained from `code.getvalue()` (where `code` is an `IndentedBuffer`) in the codegen stage:
```
#pragma omp declare reduction(argmax : struct IndexValue_1 :\
omp_out.value = omp_in.value < omp_out.value ? omp_out.value : omp_in.value,\
omp_out.index = omp_in.value < omp_out.value ? omp_out.index : omp_in.index)\
initializer(omp_priv = {0, -std::numeric_limits<float>::infinity()})
```
- The code string loaded during execution (each trailing `\` and the line break after it are gone, so the continued lines are joined into one):
```
#pragma omp declare reduction(argmax : struct IndexValue_1 : omp_out.value = omp_in.value < omp_out.value ? omp_out.value : omp_in.value, omp_out.index = omp_in.value < omp_out.value ? omp_out.index : omp_in.index) initializer(omp_priv = {0, -std::numeric_limits<float>::infinity()})
```
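As a result, these two pieces of code do not hash to the same value.
This PR adds a function, `getrawvalue()`, that applies the same transformation in the codegen stage: it drops each trailing backslash and joins the continued lines, so that the same code string (and therefore the same hash) is used in both stages.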
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88561
Approved by: https://github.com/jgong5, https://github.com/jansel, https://github.com/desertfire
This commit is contained in:
parent
7963dbf3db
commit
d35aa2f65a
|
|
@ -4922,6 +4922,7 @@ class CommonTemplate:
|
|||
"test_cat", # alias
|
||||
"test_lowmem_dropout1", # None as output
|
||||
"test_profiler_mark_wrapper_call", # TODO: fallback to default wrapper for now
|
||||
"test_reduction1", # Reduction
|
||||
"test_relu", # multiple inputs
|
||||
"test_silu", # single input, single output
|
||||
"test_transpose", # multiple outputs, buffer clear
|
||||
|
|
|
|||
|
|
@ -146,6 +146,22 @@ class IndentedBuffer:
|
|||
buf.write("\n")
|
||||
return buf.getvalue()
|
||||
|
||||
def getrawvalue(self):
    """Return the buffered lines as one string, joining backslash continuations.

    Each entry of ``self._lines`` is processed in order:

    * a ``DeferredLine`` entry is called to obtain its text; if the call
      returns ``None`` the entry is skipped;
    * an entry ending in a backslash is written without that backslash and
      without a trailing newline, so the following entry continues on the
      same output line;
    * any other entry is written followed by a newline.

    Returns:
        The concatenated text of the buffer.
    """
    out = StringIO()
    for entry in self._lines:
        if isinstance(entry, DeferredLine):
            entry = entry()
            if entry is None:
                continue
        assert isinstance(entry, str)
        # A trailing backslash implies line continuation: drop it and emit
        # no newline so the next entry is appended directly.
        is_continued = entry.endswith("\\")
        out.write(entry[:-1] if is_continued else entry + "\n")
    return out.getvalue()
|
||||
|
||||
def clear(self):
|
||||
self._lines.clear()
|
||||
|
||||
|
|
|
|||
|
|
@ -651,7 +651,16 @@ class CppWrapperCodeGen(WrapperCodeGen):
|
|||
ext = "so"
|
||||
extra = cpp_compile_command("i", "o", vec_isa=picked_vec_isa)
|
||||
# \n is required to match with the CodeCache behavior
|
||||
source_code = "\n" + code.getvalue()
|
||||
# For reductions, the code string gotten from code.getvalue() will use backslash '\'
|
||||
# at the end of lines for readability purpose:
|
||||
# #pragma omp declare reduction(xxx :\
|
||||
# omp_out.value = xxx,\
|
||||
# While the code string loaded during the execution will escape the backslash '\':
|
||||
# #pragma omp declare reduction(xxx : omp_out.value = xxx,
|
||||
# Use code.getrawvalue() here to escape the backslash to
|
||||
# make sure the same code string is used during compilation and execution,
|
||||
# so that the hash value is the same.
|
||||
source_code = "\n" + code.getrawvalue()
|
||||
_, _, kernel_path = get_code_path(source_code, ext, extra)
|
||||
return kernel_path
|
||||
|
||||
|
|
|
|||
|
|
@ -158,9 +158,6 @@ class GraphLowering(torch.fx.Interpreter):
|
|||
def check_buffer_for_cpp_wrapper(self, buffer: ir.ComputedBuffer):
|
||||
if isinstance(buffer, ir.ExternKernel):
|
||||
self.disable_cpp_wrapper("ExternKernel")
|
||||
if isinstance(buffer, ir.ComputedBuffer):
|
||||
if buffer.data.get_reduction_type():
|
||||
self.disable_cpp_wrapper("Reduction")
|
||||
|
||||
def register_buffer(self, buffer: ir.ComputedBuffer):
|
||||
if config.cpp_wrapper:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user