[Inductor] Lower fallback nodes annotated with "should_fallback" (#166339)

Summary:
This PR introduces an inductor-level fallback mechanism that gives users control over which operations or subgraphs Inductor should lower and which should fall back to preexisting kernels. This has similar motivation as #164776 in providing flexibility to selectively disable Inductor lowering for specific nodes.

The implementation simply adds a check for the `"should_fallback"` metadata annotation on FX graph nodes. If this is set to `True`, the lowerer falls back before attempting the normal lowering path. Note that because these are user-directed fallbacks that depend on specific, customized conditions, the implementation passes `add_to_fallback_set=False` to avoid permanently overwriting Inductor's lowering/fallback rules.

Simple example marking nodes for fallback based on custom predicates:

```
def should_fallback_predicate(node: torch.fx.Node, pred: Callable[[torch.fx.Node], bool]):
    # Apply predicate and mark for fallback if needed
    if pred(node):
        node.meta["should_fallback"] = True
```

Test Plan: added a CI test

Differential Revision: D85347587

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166339
Approved by: https://github.com/blaine-rister, https://github.com/eellison
This commit is contained in:
Millie Chen 2025-10-29 16:33:55 +00:00 committed by PyTorch MergeBot
parent 5fd1d41e62
commit 398fdd32bb
2 changed files with 97 additions and 1 deletions

View File

@ -0,0 +1,91 @@
# Owner(s): ["module: inductor"]
"""
Test selective lowering control via node metadata annotations.
"""
from collections.abc import Callable
import torch
from torch._inductor.test_case import TestCase as InductorTestCase
from torch.testing._internal.common_utils import instantiate_parametrized_tests
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
@instantiate_parametrized_tests
class SelectiveLoweringTest(InductorTestCase):
    """
    Verify that users can steer Inductor's lowering decisions per-node by
    annotating FX nodes with ``node.meta["should_fallback"]``.
    """

    # These tests exercise real compilation, so they target the GPU backend.
    device = GPU_TYPE

    def _mark_nodes_for_fallback(
        self, gm: torch.fx.GraphModule, predicate: Callable[[torch.fx.Node], bool]
    ) -> torch.fx.GraphModule:
        """
        Annotate every call_function node satisfying ``predicate`` with the
        ``should_fallback`` metadata flag, then return the (mutated) module.
        """
        selected = (
            candidate
            for candidate in gm.graph.nodes
            if candidate.op == "call_function" and predicate(candidate)
        )
        for candidate in selected:
            candidate.meta["should_fallback"] = True
        return gm

    def test_basic_selective_lowering(self):
        """
        Nodes carrying the fallback annotation must route through fallback
        handlers while unannotated nodes keep the normal lowering path.
        """

        def foo(x, y):
            a = x + y  # annotated for fallback by the backend below
            b = a * 2  # left alone: normal lowering
            return b

        lhs = torch.randn(10, device=self.device)
        rhs = torch.randn(10, device=self.device)

        def custom_backend(gm: torch.fx.GraphModule, example_inputs):
            # Flag every aten.add.Tensor node before handing off to Inductor.
            self._mark_nodes_for_fallback(
                gm, lambda node: node.target == torch.ops.aten.add.Tensor
            )
            from torch._inductor.compile_fx import compile_fx

            return compile_fx(gm, example_inputs)

        compiled_fn = torch.compile(foo, backend=custom_backend)
        actual = compiled_fn(lhs, rhs)
        reference = foo(lhs, rhs)
        self.assertTrue(torch.allclose(actual, reference))

    def test_no_fallback_when_unmarked(self):
        """
        Without any annotation, compilation proceeds through normal lowering
        and still produces correct results.
        """

        def foo(x, y):
            return x + y

        lhs = torch.randn(10, device=self.device)
        rhs = torch.randn(10, device=self.device)

        def custom_backend(gm: torch.fx.GraphModule, example_inputs):
            # Deliberately no annotations: everything lowers normally.
            from torch._inductor.compile_fx import compile_fx

            return compile_fx(gm, example_inputs)

        compiled_fn = torch.compile(foo, backend=custom_backend)
        actual = compiled_fn(lhs, rhs)
        reference = foo(lhs, rhs)
        self.assertTrue(torch.allclose(actual, reference))
if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    # Skip entirely on GPU-less machines; "filelock" is needed for
    # Inductor's on-disk caches during the test run.
    if HAS_GPU:
        run_tests(needs="filelock")

View File

@ -1322,7 +1322,12 @@ class GraphLowering(torch.fx.Interpreter):
else:
args, kwargs = layout_constraints(n, *args, **kwargs)
out = lowerings[target](*args, **kwargs) # type: ignore[index]
if "should_fallback" in n.meta:
out = fallback_handler(target, add_to_fallback_set=False)(
*args, **kwargs
)
else:
out = lowerings[target](*args, **kwargs) # type: ignore[index]
if layout_constraints:
# layout_constraints are allowed to make new copies of the inputs.