Prevent pattern matches across mutation ops in inductor pre-grad FX passes (#101144)
Per https://github.com/pytorch/pytorch/issues/101124

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101144
Approved by: https://github.com/jansel
This commit is contained in:
parent
13640bf925
commit
dde6d56101
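Context for the change: a pre-grad FX pattern such as add(x, sin(x)) is only safe to rewrite if nothing mutates x between the sin and the add. A minimal eager-mode sketch (illustration only, not part of the patch) of why an intervening copy_ makes such a rewrite unsound:

import torch

def fn(x, y):
    a = torch.sin(x)
    x.copy_(y)              # in-place mutation between the sin and the add
    return torch.add(x, a)  # = y + sin(original x)

def rewritten_ignoring_mutation(x, y):
    # What a replacement that treats add(x, sin(x)) as one unit would effectively
    # compute if it ignored the intervening copy_.
    x.copy_(y)
    return torch.add(x, torch.sin(x))  # = y + sin(y)

x, y = torch.randn(3), torch.randn(3)
print(torch.allclose(fn(x.clone(), y), rewritten_ignoring_mutation(x.clone(), y)))  # False in general

The patch below tags every FX node with a "mutation region id" that increments at each mutating op, and discards any pattern match whose nodes span more than one region.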
@@ -1,4 +1,7 @@
 # Owner(s): ["module: inductor"]
+import copy
+import unittest
+
 import torch
 from torch._dynamo.test_case import run_tests, TestCase
 from torch._dynamo.utils import count_calls, counters
@@ -280,6 +283,48 @@ class TestPaternMatcher(TestCase):
         self.assertEqual(counters["inductor"]["pattern_matcher_count"], 0)
         self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 0)
 
+    def test_match_with_mutation(self):
+        from torch._inductor.pattern_matcher import (
+            CallFunction,
+            KeywordArg,
+            PatternMatcherPass,
+            register_graph_pattern,
+        )
+
+        counter = 0
+        test_pass = PatternMatcherPass(prevent_match_across_mutations=True)
+
+        @register_graph_pattern(
+            CallFunction(
+                torch.add, KeywordArg("x"), CallFunction(torch.sin, KeywordArg("x"))
+            ),
+            pass_dict=test_pass,
+        )
+        def _test(match, x):
+            nonlocal counter
+            counter += 1
+
+        def fn(x, y):
+            a = torch.sin(x)
+            x.copy_(y)
+            b = torch.add(x, a)
+            return b
+
+        args1 = [
+            torch.randn(5, 5, device="cuda"),
+            torch.randn(5, 5, device="cuda"),
+        ]
+        args2 = copy.deepcopy(args1)
+
+        with unittest.mock.patch(
+            "torch._inductor.fx_passes.pre_grad.pattern_matcher_passes", [test_pass]
+        ):
+            expected = fn(*args1)
+            actual = torch.compile(fn)(*args2)
+            # should not match
+            self.assertEqual(counter, 0)
+            torch.testing.assert_close(actual, expected)
+
 
 if __name__ == "__main__":
     if IS_LINUX and HAS_CUDA and not TEST_WITH_ROCM:
@@ -545,6 +545,25 @@ class TestSplitCatFxPasses(TestCase):
             0,
         )
 
+    @torch._inductor.config.patch(split_cat_fx_passes=True)
+    def test_split_cat_merge_mutation(self):
+        args = [
+            torch.randn(2, 32, 32, 16),
+        ]
+
+        def split_cat_mutation(x):
+            splits = torch.split(x, 4, dim=1)
+            splits[1].copy_(splits[0])
+            return torch.cat(splits, dim=1)
+
+        expected = split_cat_mutation(*args)
+        actual = torch.compile(split_cat_mutation, dynamic=True)(*args)
+
+        torch.testing.assert_close(actual, expected)
+
+        self.assertEqual(counters["inductor"]["scmerge_split_removed"], 0)
+        self.assertEqual(counters["inductor"]["scmerge_cat_removed"], 0)
+
 
 if __name__ == "__main__":
     if IS_LINUX and HAS_CUDA and not TEST_WITH_ROCM:
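Why the counters must stay at zero: torch.split returns views of its input, so the copy_ between the split and the cat writes through to the original tensor. A small standalone check (illustration only, not part of the patch):

import torch

x = torch.randn(2, 32, 32, 16)
ref = x.clone()

splits = torch.split(x, 4, dim=1)   # eight views of x, 4 channels each along dim 1
splits[1].copy_(splits[0])          # writes into x itself through the view
out = torch.cat(splits, dim=1)

print(torch.equal(x[:, 4:8], ref[:, 0:4]))  # True: the mutation is visible through x
print(torch.equal(out, x))                  # True: cat of the views reproduces the mutated x

Eliminating the split/cat pair would drop that side effect, so the merge passes must not fire here.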
@@ -22,9 +22,9 @@ from ..utils import is_cpu_device
 
 log = logging.getLogger(__name__)
 
-normalize_split_pass = PatternMatcherPass()
-merge_splits_pass = PatternMatcherPass()
-merge_split_cat_pass = PatternMatcherPass()
+normalize_split_pass = PatternMatcherPass(prevent_match_across_mutations=True)
+merge_splits_pass = PatternMatcherPass(prevent_match_across_mutations=True)
+merge_split_cat_pass = PatternMatcherPass(prevent_match_across_mutations=True)
 
 pattern_matcher_passes: List[PatternMatcherPass] = [
     normalize_split_pass,
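The new flag defaults to False, so only the passes that opt in (the pre-grad split/cat passes above) change behavior; other PatternMatcherPass instances are untouched. A minimal sketch:

from torch._inductor.pattern_matcher import PatternMatcherPass

legacy_pass = PatternMatcherPass()  # default: no mutation-region bookkeeping, behavior as before
guarded_pass = PatternMatcherPass(prevent_match_across_mutations=True)  # as used above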
@@ -721,10 +721,54 @@ def register_graph_pattern(
     return decorator
 
 
+def is_start_of_fx_graph(graph, node):
+    # first node in the graph
+    return node is next(iter(graph.nodes))
+
+
+def is_mutation_op(node):
+    if node.op == "call_function":
+        if node.target.__name__.endswith("_"):
+            return True
+    elif node.op == "call_method":
+        if node.target.endswith("_"):
+            return True
+    if "out" in node.kwargs:
+        if node.kwargs["out"] in node.all_input_nodes:
+            return True
+    return False
+
+
+def get_mutation_region_id(graph, node):
+    n = node
+    while "mutation_region_id" not in n.meta and not is_start_of_fx_graph(graph, n):
+        n = n.prev
+    mutation_region_id = n.meta.get("mutation_region_id", 0)
+    while n is not node:
+        n = n.next
+        if is_mutation_op(n):
+            mutation_region_id += 1
+        n.meta["mutation_region_id"] = mutation_region_id
+    return mutation_region_id
+
+
+def should_compute_mutation_region_ids(graph):
+    return "mutation_region_id" not in next(iter(graph.nodes)).meta
+
+
+def compute_mutation_region_ids(graph):
+    mutation_region_id = 0
+    for nd in graph.nodes:
+        if is_mutation_op(nd):
+            mutation_region_id += 1
+        nd.meta["mutation_region_id"] = mutation_region_id
+
+
 class PatternMatcherPass:
-    def __init__(self):
+    def __init__(self, prevent_match_across_mutations=False):
         super().__init__()
         self.patterns = defaultdict(list)
+        self.prevent_match_across_mutations = prevent_match_across_mutations
 
     def __getitem__(self, item):
         return self.patterns[item]
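To make the region bookkeeping concrete, here is a standalone sketch (illustration only; it re-implements the trailing-underscore part of is_mutation_op rather than importing the inductor helpers) that walks a symbolically traced graph the same way compute_mutation_region_ids does:

import torch
import torch.fx

def fn(x, y):
    a = torch.sin(x)
    x.copy_(y)            # mutation op: everything after it belongs to a new region
    return torch.add(x, a)

gm = torch.fx.symbolic_trace(fn)

def is_mutation(node):
    # trailing-underscore heuristic, as in is_mutation_op above
    if node.op == "call_function":
        return getattr(node.target, "__name__", "").endswith("_")
    if node.op == "call_method":
        return node.target.endswith("_")
    return False

region_id = 0
for node in gm.graph.nodes:
    if is_mutation(node):
        region_id += 1
    node.meta["mutation_region_id"] = region_id
    print(node.op, node.target, "-> region", region_id)

The sin node lands in region 0 while copy_ and add land in region 1, so a pattern whose matched nodes include both sin and add gets rejected.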
@@ -734,6 +778,12 @@ class PatternMatcherPass:
             return 0
         if isinstance(graph, torch.fx.GraphModule):
             graph = graph.graph
+        if self.prevent_match_across_mutations:
+            if should_compute_mutation_region_ids(graph):
+                compute_mutation_region_ids(graph)
+            get_mutation_region_id_partial = functools.partial(
+                get_mutation_region_id, graph
+            )
         count = 0
         for node in reversed(graph.nodes):
             if (
@@ -750,6 +800,13 @@
                     if node._erased:
                         break
                     m = entry.pattern.match(node)
+                    # pattern match crosses mutation barrier - discard
+                    if (
+                        self.prevent_match_across_mutations
+                        and m
+                        and len(set(map(get_mutation_region_id_partial, m.nodes))) != 1
+                    ):
+                        continue
                     if os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_DEBUG") == node.name:
                         log.warning("%s%s %s %s", node, node.args, m, entry.pattern)
                     if m and entry.extra_check(m):
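The rejection condition in apply() is simply "all nodes of the match share one mutation region id". A tiny illustration with hypothetical region ids:

same_region = [3, 3, 3]       # e.g. sin and add with no mutation in between
crosses_mutation = [0, 1]     # e.g. sin before a copy_, add after it

print(len(set(same_region)) != 1)       # False: match is kept
print(len(set(crosses_mutation)) != 1)  # True: match is discarded (continue)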