Prevent pattern matches across mutation ops in inductor pre-grad FX passes (#101144)

Per https://github.com/pytorch/pytorch/issues/101124

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101144
Approved by: https://github.com/jansel
This commit is contained in:
William Wen 2023-05-18 22:52:38 +00:00 committed by PyTorch MergeBot
parent 13640bf925
commit dde6d56101
4 changed files with 125 additions and 4 deletions

View File

@ -1,4 +1,7 @@
# Owner(s): ["module: inductor"]
import copy
import unittest
import torch
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.utils import count_calls, counters
@ -280,6 +283,48 @@ class TestPaternMatcher(TestCase):
self.assertEqual(counters["inductor"]["pattern_matcher_count"], 0)
self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 0)
def test_match_with_mutation(self):
from torch._inductor.pattern_matcher import (
CallFunction,
KeywordArg,
PatternMatcherPass,
register_graph_pattern,
)
counter = 0
test_pass = PatternMatcherPass(prevent_match_across_mutations=True)
@register_graph_pattern(
CallFunction(
torch.add, KeywordArg("x"), CallFunction(torch.sin, KeywordArg("x"))
),
pass_dict=test_pass,
)
def _test(match, x):
nonlocal counter
counter += 1
def fn(x, y):
a = torch.sin(x)
x.copy_(y)
b = torch.add(x, a)
return b
args1 = [
torch.randn(5, 5, device="cuda"),
torch.randn(5, 5, device="cuda"),
]
args2 = copy.deepcopy(args1)
with unittest.mock.patch(
"torch._inductor.fx_passes.pre_grad.pattern_matcher_passes", [test_pass]
):
expected = fn(*args1)
actual = torch.compile(fn)(*args2)
# should not match
self.assertEqual(counter, 0)
torch.testing.assert_close(actual, expected)
if __name__ == "__main__":
if IS_LINUX and HAS_CUDA and not TEST_WITH_ROCM:

View File

@ -545,6 +545,25 @@ class TestSplitCatFxPasses(TestCase):
0,
)
@torch._inductor.config.patch(split_cat_fx_passes=True)
def test_split_cat_merge_mutation(self):
args = [
torch.randn(2, 32, 32, 16),
]
def split_cat_mutation(x):
splits = torch.split(x, 4, dim=1)
splits[1].copy_(splits[0])
return torch.cat(splits, dim=1)
expected = split_cat_mutation(*args)
actual = torch.compile(split_cat_mutation, dynamic=True)(*args)
torch.testing.assert_close(actual, expected)
self.assertEqual(counters["inductor"]["scmerge_split_removed"], 0)
self.assertEqual(counters["inductor"]["scmerge_cat_removed"], 0)
if __name__ == "__main__":
if IS_LINUX and HAS_CUDA and not TEST_WITH_ROCM:

View File

@ -22,9 +22,9 @@ from ..utils import is_cpu_device
log = logging.getLogger(__name__)
normalize_split_pass = PatternMatcherPass()
merge_splits_pass = PatternMatcherPass()
merge_split_cat_pass = PatternMatcherPass()
normalize_split_pass = PatternMatcherPass(prevent_match_across_mutations=True)
merge_splits_pass = PatternMatcherPass(prevent_match_across_mutations=True)
merge_split_cat_pass = PatternMatcherPass(prevent_match_across_mutations=True)
pattern_matcher_passes: List[PatternMatcherPass] = [
normalize_split_pass,

View File

@ -721,10 +721,54 @@ def register_graph_pattern(
return decorator
def is_start_of_fx_graph(graph, node):
# first node in the graph
return node is next(iter(graph.nodes))
def is_mutation_op(node):
if node.op == "call_function":
if node.target.__name__.endswith("_"):
return True
elif node.op == "call_method":
if node.target.endswith("_"):
return True
if "out" in node.kwargs:
if node.kwargs["out"] in node.all_input_nodes:
return True
return False
def get_mutation_region_id(graph, node):
n = node
while "mutation_region_id" not in n.meta and not is_start_of_fx_graph(graph, n):
n = n.prev
mutation_region_id = n.meta.get("mutation_region_id", 0)
while n is not node:
n = n.next
if is_mutation_op(n):
mutation_region_id += 1
n.meta["mutation_region_id"] = mutation_region_id
return mutation_region_id
def should_compute_mutation_region_ids(graph):
return "mutation_region_id" not in next(iter(graph.nodes)).meta
def compute_mutation_region_ids(graph):
mutation_region_id = 0
for nd in graph.nodes:
if is_mutation_op(nd):
mutation_region_id += 1
nd.meta["mutation_region_id"] = mutation_region_id
class PatternMatcherPass:
def __init__(self):
def __init__(self, prevent_match_across_mutations=False):
super().__init__()
self.patterns = defaultdict(list)
self.prevent_match_across_mutations = prevent_match_across_mutations
def __getitem__(self, item):
return self.patterns[item]
@ -734,6 +778,12 @@ class PatternMatcherPass:
return 0
if isinstance(graph, torch.fx.GraphModule):
graph = graph.graph
if self.prevent_match_across_mutations:
if should_compute_mutation_region_ids(graph):
compute_mutation_region_ids(graph)
get_mutation_region_id_partial = functools.partial(
get_mutation_region_id, graph
)
count = 0
for node in reversed(graph.nodes):
if (
@ -750,6 +800,13 @@ class PatternMatcherPass:
if node._erased:
break
m = entry.pattern.match(node)
# pattern match crosses mutation barrier - discard
if (
self.prevent_match_across_mutations
and m
and len(set(map(get_mutation_region_id_partial, m.nodes))) != 1
):
continue
if os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_DEBUG") == node.name:
log.warning("%s%s %s %s", node, node.args, m, entry.pattern)
if m and entry.extra_check(m):