mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[Inductor-FX] Support ScatterFallback (#162686)
# Problem Inductor has a `ScatterFallback` op with custom Python and C++ wrapper codegen macros. This is used in certain situations where the default Triton codegen doesn't apply, and especially for reductions which need to be deterministic. Since this op used direct Python/C++ codegen, it wasn't compatible with the FX backend. # Feature This PR refactors the associated wrapper codegen to support `ScatterFallback`. This follows the same basic steps that were used for other fallback ops including `MultiOutput` and `ExternKernel`: 1. Create a new wrapper IR op called `ScatterFallbackLine`. Move the logic in `ScatterFallback.cogeden` to `ScatterFallbackLine.codegen`, to prevent it from affecting the FX backend. This logic is unsafe for FX because it may generate Python or C++ strings with methods like `codegen_reference()`. 2. To eleminate the dependence on `V.graph`, move language-specific logic to the respective wrapper codegen subclasses. In this case, C++ codegen has some special logic, which is moved to `CppWrapperCpu`. 3. Create a new method in `FXWrapperCodegen` to handle `ScatterFallbackLine`. # Test plan Added a couple of CI tests for the FX backend with scatter fallbacks. Pull Request resolved: https://github.com/pytorch/pytorch/pull/162686 Approved by: https://github.com/jansel
This commit is contained in:
parent
98e9440f30
commit
a7bbc5fea7
|
|
@ -25,6 +25,7 @@ from torch._inductor.codegen.wrapper_fxir import FxConverter, WrapperFxCodegen
|
|||
from torch._inductor.test_case import TestCase as InductorTestCase
|
||||
from torch.export import Dim
|
||||
from torch.testing._internal.common_utils import (
|
||||
DeterministicGuard,
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
)
|
||||
|
|
@ -567,6 +568,50 @@ class FxirTestCase(InductorTestCase):
|
|||
|
||||
self.assertTrue(same(ref, result))
|
||||
|
||||
def test_scatter_fallback_scalar_src(self):
|
||||
"""
|
||||
Test a special case where ScatterFallback takes a scalar 'src' argument.
|
||||
"""
|
||||
|
||||
def foo(input_):
|
||||
dim = 0
|
||||
src = 1.5
|
||||
return torch.ops.aten.scatter(input_, dim, index, src)
|
||||
|
||||
length = 8
|
||||
index = torch.randint(length, (length,), device=self.device)
|
||||
input_ = torch.randn(length, device=self.device)
|
||||
with DeterministicGuard(True):
|
||||
(gm,) = self._compile_and_check(
|
||||
foo,
|
||||
(input_,),
|
||||
)
|
||||
|
||||
# Check for the fallback op.
|
||||
num_fallback = self._count_ops(gm, torch.ops.aten.scatter_.value)
|
||||
self.assertEqual(num_fallback, 1)
|
||||
|
||||
def test_scatter_reduce_fallback(self):
|
||||
"""
|
||||
Test the customized wrapper codegen for ScatterFallback ops.
|
||||
"""
|
||||
fallback_op = torch.ops.aten.scatter_reduce_.two
|
||||
|
||||
def foo(out, index, src):
|
||||
dim = 0
|
||||
out = fallback_op(out, dim, index, src, reduce="amax", include_self=False)
|
||||
return out + 1
|
||||
|
||||
length = 8
|
||||
out, src = [torch.randn(length, device=self.device) for _ in range(2)]
|
||||
index = torch.randint(length, (length,), device=self.device)
|
||||
(gm,) = self._compile_and_check(
|
||||
foo, (out, index, src), expected_num_triton_kernels=2
|
||||
)
|
||||
|
||||
# Check for the fallback.
|
||||
self.assertEqual(self._count_ops(gm, fallback_op), 1)
|
||||
|
||||
@torch._inductor.config.patch("graph_partition", True)
|
||||
def test_subgraph_raises(self):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1398,7 +1398,15 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
|||
kernel, args, device, debug_handle=debug_handle
|
||||
)
|
||||
|
||||
def generate_scatter_fallback(
|
||||
def _get_scatter_reduce_enum(self, reduce):
|
||||
# Follow aten/src/ATen/native/ReductionType.h:get_operator_enum
|
||||
get_operator_enum = {"add": "sum", "multiply": "prod"}
|
||||
if reduce in get_operator_enum:
|
||||
reduce = get_operator_enum[reduce]
|
||||
|
||||
return reduce
|
||||
|
||||
def _generate_scatter_fallback(
|
||||
self,
|
||||
output,
|
||||
inputs,
|
||||
|
|
@ -1408,6 +1416,8 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
|||
reduce,
|
||||
kwargs,
|
||||
):
|
||||
reduce = self._get_scatter_reduce_enum(reduce)
|
||||
|
||||
# call the ABI shim function instead of the ATen one
|
||||
cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name, self.device)
|
||||
# TODO: consider remove "_out" and add missing inplace variants to fallback_ops.py
|
||||
|
|
|
|||
|
|
@ -694,7 +694,12 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
|
|||
kernel, wrapped_args, device, debug_args=args
|
||||
)
|
||||
|
||||
def generate_scatter_fallback(
|
||||
def generate_scatter_fallback(self, node: ir.ScatterFallback):
|
||||
# No stack allocation when there is a fallback op
|
||||
self.allow_stack_allocation = False
|
||||
super().generate_scatter_fallback(node)
|
||||
|
||||
def _generate_scatter_fallback(
|
||||
self,
|
||||
output,
|
||||
inputs,
|
||||
|
|
@ -704,8 +709,7 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
|
|||
reduce,
|
||||
kwargs,
|
||||
):
|
||||
# No stack allocation when there is a fallback op
|
||||
self.allow_stack_allocation = False
|
||||
reduce = self._get_scatter_reduce_enum(reduce)
|
||||
|
||||
# call the ABI shim function instead of the ATen one
|
||||
cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name, self.device)
|
||||
|
|
|
|||
|
|
@ -908,6 +908,33 @@ class MultiOutputLine(WrapperLine):
|
|||
return converter._generate_multi_output
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ScatterFallbackLine(WrapperLine):
|
||||
wrapper: PythonWrapperCodegen
|
||||
node: ir.ScatterFallback
|
||||
|
||||
def codegen(self, code: IndentedBuffer) -> None:
|
||||
node = self.node
|
||||
assert ir.is_node_sequence(node.inputs)
|
||||
if node.src_is_tensor:
|
||||
(x, index, src) = (t.codegen_reference() for t in node.inputs)
|
||||
else:
|
||||
(x, index) = (t.codegen_reference() for t in node.inputs)
|
||||
src = node.constant_args[1]
|
||||
self.wrapper._generate_scatter_fallback(
|
||||
x,
|
||||
[x, node.constant_args[0], index, src],
|
||||
node.cpp_kernel_name,
|
||||
node.python_kernel_name,
|
||||
node.src_is_tensor,
|
||||
node.kwargs["reduce"],
|
||||
node.codegen_kwargs(),
|
||||
)
|
||||
|
||||
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
|
||||
return converter._generate_scatter_fallback
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SymbolicCallArgLine(WrapperLine):
|
||||
wrapper: PythonWrapperCodegen
|
||||
|
|
@ -1511,7 +1538,10 @@ class PythonWrapperCodegen(CodeGen):
|
|||
line = f"{desc.name} = {call}{self.ending}"
|
||||
self.writeline(line)
|
||||
|
||||
def generate_scatter_fallback(
|
||||
def generate_scatter_fallback(self, node: ir.ScatterFallback):
|
||||
self.writeline(ScatterFallbackLine(self, node))
|
||||
|
||||
def _generate_scatter_fallback(
|
||||
self,
|
||||
output,
|
||||
inputs,
|
||||
|
|
|
|||
|
|
@ -63,6 +63,7 @@ from .wrapper import (
|
|||
PythonWrapperCodegen,
|
||||
ReinterpretLine,
|
||||
ReuseLine,
|
||||
ScatterFallbackLine,
|
||||
SymbolicCallArg,
|
||||
SymbolicCallArgLine,
|
||||
WrapperLine,
|
||||
|
|
@ -653,6 +654,26 @@ class FxConverter:
|
|||
node.name = line.result_name
|
||||
self.buffer_to_node[line.result_name] = node
|
||||
|
||||
def _generate_scatter_fallback(self, line: WrapperLine) -> None:
|
||||
assert isinstance(line, ScatterFallbackLine)
|
||||
ir_node = line.node
|
||||
assert ir.is_node_sequence(ir_node.inputs)
|
||||
(x, index, src) = [self._generate_buffer(t) for t in ir_node.inputs] + (
|
||||
[] if ir_node.src_is_tensor else [ir_node.constant_args[1]]
|
||||
)
|
||||
args = (x, ir_node.constant_args[0], index, src)
|
||||
kwargs = {}
|
||||
if reduce := ir_node.kwargs.get("reduce"):
|
||||
kwargs["reduce"] = reduce
|
||||
|
||||
fx_node = self.gm.graph.call_function(
|
||||
ir_node.op_overload, # type: ignore[arg-type]
|
||||
args=args,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
result_buffer = ir_node.codegen_reference()
|
||||
self.buffer_to_node[result_buffer] = fx_node
|
||||
|
||||
def _generate_null(self, line: WrapperLine) -> None:
|
||||
assert isinstance(line, NullLine)
|
||||
# Does nothing.
|
||||
|
|
|
|||
|
|
@ -7060,28 +7060,7 @@ class ScatterFallback(ExternKernel):
|
|||
"""
|
||||
|
||||
def codegen(self, wrapper: PythonWrapperCodegen) -> None:
|
||||
reduce = self.kwargs["reduce"]
|
||||
if V.graph.cpp_wrapper:
|
||||
# Follow aten/src/ATen/native/ReductionType.h:get_operator_enum
|
||||
get_operator_enum = {"add": "sum", "multiply": "prod"}
|
||||
if reduce in get_operator_enum:
|
||||
reduce = get_operator_enum[reduce]
|
||||
|
||||
assert is_node_sequence(self.inputs)
|
||||
if self.src_is_tensor:
|
||||
(x, index, src) = (t.codegen_reference() for t in self.inputs)
|
||||
else:
|
||||
(x, index) = (t.codegen_reference() for t in self.inputs)
|
||||
src = self.constant_args[1]
|
||||
wrapper.generate_scatter_fallback(
|
||||
x,
|
||||
[x, self.constant_args[0], index, src],
|
||||
self.cpp_kernel_name,
|
||||
self.python_kernel_name,
|
||||
self.src_is_tensor,
|
||||
reduce,
|
||||
self.codegen_kwargs(),
|
||||
)
|
||||
wrapper.generate_scatter_fallback(self)
|
||||
|
||||
def should_allocate(self) -> bool:
|
||||
return False
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user