[Inductor] Lower fallback nodes annotated with "should_fallback" (#166339)
Summary: This PR introduces an Inductor-level fallback mechanism that gives users control over which operations or subgraphs Inductor should lower and which should fall back to preexisting kernels. The motivation is similar to #164776: providing the flexibility to selectively disable Inductor lowering for specific nodes. The implementation simply checks for the `"should_fallback"` metadata annotation on FX graph nodes; if it is set to `True`, the lowerer falls back before attempting the normal lowering path. Since these are user-directed fallbacks that depend on specific, customized conditions, `add_to_fallback_set=False` is used to avoid permanently overwriting Inductor's lowering/fallback rules.

A simple example marking a node for fallback based on a custom predicate:

```
def should_fallback_predicate(
    node: torch.fx.Node, pred: Callable[[torch.fx.Node], bool]
) -> None:
    # Apply the predicate and mark the node for fallback if it matches
    if pred(node):
        node.meta["should_fallback"] = True
```

Test Plan: added a CI test

Differential Revision: D85347587

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166339
Approved by: https://github.com/blaine-rister, https://github.com/eellison
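For context, a minimal end-to-end sketch of driving the annotation from a custom `torch.compile` backend. It mirrors the new test added below; `custom_backend` is an illustrative name, not part of the PR, and the device choice follows the test's GPU gating:

```
import torch
from torch._inductor.compile_fx import compile_fx


def custom_backend(gm: torch.fx.GraphModule, example_inputs):
    # Mark every aten.add.Tensor call for fallback; any predicate works here.
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
            node.meta["should_fallback"] = True
    return compile_fx(gm, example_inputs)


compiled = torch.compile(lambda x, y: (x + y) * 2, backend=custom_backend)
out = compiled(torch.randn(10, device="cuda"), torch.randn(10, device="cuda"))
```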
This commit is contained in:
parent 5fd1d41e62
commit 398fdd32bb
test/inductor/test_selective_lowering.py (new file, 91 lines)
@@ -0,0 +1,91 @@
# Owner(s): ["module: inductor"]
"""
Test selective lowering control via node metadata annotations.
"""

from collections.abc import Callable

import torch
from torch._inductor.test_case import TestCase as InductorTestCase
from torch.testing._internal.common_utils import instantiate_parametrized_tests
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU


@instantiate_parametrized_tests
class SelectiveLoweringTest(InductorTestCase):
    """
    Tests for user-controllable selective lowering using node.meta annotations.
    """

    device = GPU_TYPE

    def _mark_nodes_for_fallback(
        self, gm: torch.fx.GraphModule, predicate: Callable[[torch.fx.Node], bool]
    ) -> torch.fx.GraphModule:
        """
        Helper method to mark nodes with should_fallback metadata based on a predicate.
        """
        for node in gm.graph.nodes:
            if node.op == "call_function" and predicate(node):
                node.meta["should_fallback"] = True
        return gm

    def test_basic_selective_lowering(self):
        """
        Test that nodes marked for fallback use fallback handlers instead of lowerings.
        """

        def foo(x, y):
            a = x + y  # This will be marked for fallback
            b = a * 2  # This will use normal lowering
            return b

        x = torch.randn(10, device=self.device)
        y = torch.randn(10, device=self.device)

        def custom_backend(gm: torch.fx.GraphModule, example_inputs):
            # Mark all add operations for fallback
            def should_fallback_add(node: torch.fx.Node) -> bool:
                return node.target == torch.ops.aten.add.Tensor

            self._mark_nodes_for_fallback(gm, should_fallback_add)

            from torch._inductor.compile_fx import compile_fx

            return compile_fx(gm, example_inputs)

        compiled_fn = torch.compile(foo, backend=custom_backend)
        result = compiled_fn(x, y)
        expected = foo(x, y)

        self.assertTrue(torch.allclose(result, expected))

    def test_no_fallback_when_unmarked(self):
        """
        Test that operations without fallback annotation use normal lowering.
        """

        def foo(x, y):
            return x + y

        x = torch.randn(10, device=self.device)
        y = torch.randn(10, device=self.device)

        def custom_backend(gm: torch.fx.GraphModule, example_inputs):
            # Don't mark anything - all operations should use normal lowering
            from torch._inductor.compile_fx import compile_fx

            return compile_fx(gm, example_inputs)

        compiled_fn = torch.compile(foo, backend=custom_backend)
        result = compiled_fn(x, y)
        expected = foo(x, y)

        self.assertTrue(torch.allclose(result, expected))


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if HAS_GPU:
        run_tests(needs="filelock")
torch/_inductor/graph.py

@@ -1322,6 +1322,11 @@ class GraphLowering(torch.fx.Interpreter):
            else:
                args, kwargs = layout_constraints(n, *args, **kwargs)

            if "should_fallback" in n.meta:
                out = fallback_handler(target, add_to_fallback_set=False)(
                    *args, **kwargs
                )
            else:
                out = lowerings[target](*args, **kwargs)  # type: ignore[index]

            if layout_constraints: