[Inductor-FX] Support ScatterFallback (#162686)

# Problem
Inductor has a `ScatterFallback` op with custom Python and C++ wrapper codegen. It is used in situations where the default Triton codegen doesn't apply, especially for reductions that need to be deterministic. Because this op emitted Python/C++ source directly, it wasn't compatible with the FX backend.
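
For context (this snippet is not part of the PR's diff), the fallback path is typically exercised when deterministic algorithms are requested: Inductor then avoids generated Triton scatter kernels, whose write order for duplicate indices is generally not deterministic, and calls the ATen kernel instead. A minimal sketch mirroring the new test below, assuming a CUDA device:

```python
import torch

@torch.compile
def foo(input_, index):
    # Scalar "src" variant; duplicate indices make the write order matter.
    return torch.ops.aten.scatter(input_, 0, index, 1.5)

# With deterministic algorithms enabled, Inductor routes this scatter through
# ScatterFallback (an ATen call) rather than a generated Triton kernel.
torch.use_deterministic_algorithms(True)
index = torch.randint(8, (8,), device="cuda")
input_ = torch.randn(8, device="cuda")
out = foo(input_, index)
```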

# Feature
This PR refactors the associated wrapper codegen to support `ScatterFallback`. It follows the same basic steps that were used for other fallback ops, including `MultiOutput` and `ExternKernel` (see the condensed sketch after this list):

1. Create a new wrapper IR op called `ScatterFallbackLine`. Move the logic from `ScatterFallback.codegen` into `ScatterFallbackLine.codegen`, so that it no longer affects the FX backend. This logic is unsafe for FX because it may generate Python or C++ strings via methods like `codegen_reference()`.
2. To eliminate the dependence on `V.graph`, move language-specific logic into the respective wrapper codegen subclasses. In this case, C++ codegen has some special logic (mapping reduction names to their ATen operator enum), which moves to `CppWrapperCpu`.
3. Create a new method in `FxConverter` to handle `ScatterFallbackLine`.
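
A condensed sketch of the resulting control flow, with method bodies elided; the names are taken from the diffs below, so this is illustrative rather than standalone-runnable code:

```python
class ScatterFallback(ExternKernel):
    def codegen(self, wrapper: PythonWrapperCodegen) -> None:
        # The IR node no longer builds Python/C++ strings itself; it just
        # records a wrapper IR line.
        wrapper.generate_scatter_fallback(self)

class PythonWrapperCodegen(CodeGen):
    def generate_scatter_fallback(self, node: ir.ScatterFallback):
        self.writeline(ScatterFallbackLine(self, node))

@dataclasses.dataclass
class ScatterFallbackLine(WrapperLine):
    wrapper: PythonWrapperCodegen
    node: ir.ScatterFallback

    def codegen(self, code: IndentedBuffer) -> None:
        # Python/C++ wrappers: build the call string via codegen_reference()
        # and the backend-specific _generate_scatter_fallback override.
        ...

    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
        # FX backend: emit a call_function node instead of source text.
        return converter._generate_scatter_fallback
```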

# Test plan
Added a couple of CI tests for the FX backend with scatter fallbacks.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162686
Approved by: https://github.com/jansel
Author: Blaine Burton Rister, 2025-09-12 08:41:47 +00:00, committed by PyTorch MergeBot
Commit: a7bbc5fea7 (parent 98e9440f30)
6 changed files with 116 additions and 27 deletions

@@ -25,6 +25,7 @@ from torch._inductor.codegen.wrapper_fxir import FxConverter, WrapperFxCodegen
from torch._inductor.test_case import TestCase as InductorTestCase
from torch.export import Dim
from torch.testing._internal.common_utils import (
    DeterministicGuard,
    instantiate_parametrized_tests,
    parametrize,
)
@@ -567,6 +568,50 @@ class FxirTestCase(InductorTestCase):
        self.assertTrue(same(ref, result))

    def test_scatter_fallback_scalar_src(self):
        """
        Test a special case where ScatterFallback takes a scalar 'src' argument.
        """

        def foo(input_):
            dim = 0
            src = 1.5
            return torch.ops.aten.scatter(input_, dim, index, src)

        length = 8
        index = torch.randint(length, (length,), device=self.device)
        input_ = torch.randn(length, device=self.device)
        with DeterministicGuard(True):
            (gm,) = self._compile_and_check(
                foo,
                (input_,),
            )

        # Check for the fallback op.
        num_fallback = self._count_ops(gm, torch.ops.aten.scatter_.value)
        self.assertEqual(num_fallback, 1)

    def test_scatter_reduce_fallback(self):
        """
        Test the customized wrapper codegen for ScatterFallback ops.
        """
        fallback_op = torch.ops.aten.scatter_reduce_.two

        def foo(out, index, src):
            dim = 0
            out = fallback_op(out, dim, index, src, reduce="amax", include_self=False)
            return out + 1

        length = 8
        out, src = [torch.randn(length, device=self.device) for _ in range(2)]
        index = torch.randint(length, (length,), device=self.device)
        (gm,) = self._compile_and_check(
            foo, (out, index, src), expected_num_triton_kernels=2
        )

        # Check for the fallback.
        self.assertEqual(self._count_ops(gm, fallback_op), 1)

    @torch._inductor.config.patch("graph_partition", True)
    def test_subgraph_raises(self):
        """

@@ -1398,7 +1398,15 @@ class CppWrapperCpu(PythonWrapperCodegen):
            kernel, args, device, debug_handle=debug_handle
        )

    def generate_scatter_fallback(

    def _get_scatter_reduce_enum(self, reduce):
        # Follow aten/src/ATen/native/ReductionType.h:get_operator_enum
        get_operator_enum = {"add": "sum", "multiply": "prod"}
        if reduce in get_operator_enum:
            reduce = get_operator_enum[reduce]
        return reduce

    def _generate_scatter_fallback(
        self,
        output,
        inputs,
@@ -1408,6 +1416,8 @@ class CppWrapperCpu(PythonWrapperCodegen):
        reduce,
        kwargs,
    ):
        reduce = self._get_scatter_reduce_enum(reduce)
        # call the ABI shim function instead of the ATen one
        cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name, self.device)
        # TODO: consider remove "_out" and add missing inplace variants to fallback_ops.py

@@ -694,7 +694,12 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
            kernel, wrapped_args, device, debug_args=args
        )

    def generate_scatter_fallback(

    def generate_scatter_fallback(self, node: ir.ScatterFallback):
        # No stack allocation when there is a fallback op
        self.allow_stack_allocation = False
        super().generate_scatter_fallback(node)

    def _generate_scatter_fallback(
        self,
        output,
        inputs,
@@ -704,8 +709,7 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
        reduce,
        kwargs,
    ):
        # No stack allocation when there is a fallback op
        self.allow_stack_allocation = False
        reduce = self._get_scatter_reduce_enum(reduce)
        # call the ABI shim function instead of the ATen one
        cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name, self.device)

@@ -908,6 +908,33 @@ class MultiOutputLine(WrapperLine):
        return converter._generate_multi_output

@dataclasses.dataclass
class ScatterFallbackLine(WrapperLine):
    wrapper: PythonWrapperCodegen
    node: ir.ScatterFallback

    def codegen(self, code: IndentedBuffer) -> None:
        node = self.node
        assert ir.is_node_sequence(node.inputs)
        if node.src_is_tensor:
            (x, index, src) = (t.codegen_reference() for t in node.inputs)
        else:
            (x, index) = (t.codegen_reference() for t in node.inputs)
            src = node.constant_args[1]
        self.wrapper._generate_scatter_fallback(
            x,
            [x, node.constant_args[0], index, src],
            node.cpp_kernel_name,
            node.python_kernel_name,
            node.src_is_tensor,
            node.kwargs["reduce"],
            node.codegen_kwargs(),
        )

    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
        return converter._generate_scatter_fallback

@dataclasses.dataclass
class SymbolicCallArgLine(WrapperLine):
    wrapper: PythonWrapperCodegen
@@ -1511,7 +1538,10 @@ class PythonWrapperCodegen(CodeGen):
        line = f"{desc.name} = {call}{self.ending}"
        self.writeline(line)

    def generate_scatter_fallback(

    def generate_scatter_fallback(self, node: ir.ScatterFallback):
        self.writeline(ScatterFallbackLine(self, node))

    def _generate_scatter_fallback(
        self,
        output,
        inputs,

@@ -63,6 +63,7 @@ from .wrapper import (
    PythonWrapperCodegen,
    ReinterpretLine,
    ReuseLine,
    ScatterFallbackLine,
    SymbolicCallArg,
    SymbolicCallArgLine,
    WrapperLine,
@@ -653,6 +654,26 @@ class FxConverter:
            node.name = line.result_name
            self.buffer_to_node[line.result_name] = node

    def _generate_scatter_fallback(self, line: WrapperLine) -> None:
        assert isinstance(line, ScatterFallbackLine)
        ir_node = line.node
        assert ir.is_node_sequence(ir_node.inputs)
        (x, index, src) = [self._generate_buffer(t) for t in ir_node.inputs] + (
            [] if ir_node.src_is_tensor else [ir_node.constant_args[1]]
        )
        args = (x, ir_node.constant_args[0], index, src)
        kwargs = {}
        if reduce := ir_node.kwargs.get("reduce"):
            kwargs["reduce"] = reduce
        fx_node = self.gm.graph.call_function(
            ir_node.op_overload,  # type: ignore[arg-type]
            args=args,
            kwargs=kwargs,
        )
        result_buffer = ir_node.codegen_reference()
        self.buffer_to_node[result_buffer] = fx_node

    def _generate_null(self, line: WrapperLine) -> None:
        assert isinstance(line, NullLine)
        # Does nothing.

@@ -7060,28 +7060,7 @@ class ScatterFallback(ExternKernel):
    """

    def codegen(self, wrapper: PythonWrapperCodegen) -> None:
        reduce = self.kwargs["reduce"]
        if V.graph.cpp_wrapper:
            # Follow aten/src/ATen/native/ReductionType.h:get_operator_enum
            get_operator_enum = {"add": "sum", "multiply": "prod"}
            if reduce in get_operator_enum:
                reduce = get_operator_enum[reduce]
        assert is_node_sequence(self.inputs)
        if self.src_is_tensor:
            (x, index, src) = (t.codegen_reference() for t in self.inputs)
        else:
            (x, index) = (t.codegen_reference() for t in self.inputs)
            src = self.constant_args[1]
        wrapper.generate_scatter_fallback(
            x,
            [x, self.constant_args[0], index, src],
            self.cpp_kernel_name,
            self.python_kernel_name,
            self.src_is_tensor,
            reduce,
            self.codegen_kwargs(),
        )
        wrapper.generate_scatter_fallback(self)

    def should_allocate(self) -> bool:
        return False