mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Helper function to replace literals that show up in call_function nodes in the graph to become placeholders so that they can be represented as wildcards when matching with the SubgraphMatcher. This pass causes the resulting graph to not be runnable with the original inputs since adding placeholders to the graph will change the number of inputs needed for the graph. Test: `python test/test_fx.py TestMatcher` Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/97683 Approved by: https://github.com/kimishpatel, https://github.com/SherlockNoMad
96 lines
3.9 KiB
Python
96 lines
3.9 KiB
Python
# Owner(s): ["module: fx"]
|
|
|
|
import os
|
|
import sys
|
|
|
|
import torch
|
|
from torch.fx import symbolic_trace
|
|
from torch.fx.experimental.proxy_tensor import make_fx
|
|
|
|
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
|
sys.path.append(pytorch_test_dir)
|
|
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
|
|
from torch.testing._internal.jit_utils import JitTestCase
|
|
|
|
class TestMatcher(JitTestCase):
|
|
def test_subgraph_matcher_with_attributes(self):
|
|
class LargeModel(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self._weight = torch.nn.Parameter(torch.ones(3, 3))
|
|
self._bias = torch.nn.Parameter(torch.ones(3, 3))
|
|
|
|
def forward(self, x):
|
|
return torch.ops.aten.addmm.default(self._bias, x, self._weight)
|
|
|
|
# Large Model graph:
|
|
# opcode name target args kwargs
|
|
# ------------- ------------- ------------------ ------------------- --------
|
|
# placeholder x x () {}
|
|
# get_attr _bias _bias () {}
|
|
# get_attr _weight _weight () {}
|
|
# call_function addmm_default aten.addmm.default (_bias, x, _weight) {}
|
|
# output output output (addmm_default,) {}
|
|
large_model_graph = symbolic_trace(LargeModel()).graph
|
|
|
|
class PatternModel(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self._weight_1 = torch.nn.Parameter(torch.ones(5, 5))
|
|
self._bias_1 = torch.nn.Parameter(torch.ones(5, 5))
|
|
|
|
def forward(self, x):
|
|
return torch.ops.aten.addmm.default(self._bias_1, x, self._weight_1)
|
|
|
|
pattern_graph = torch.fx.symbolic_trace(PatternModel()).graph
|
|
|
|
subgraph_matcher = SubgraphMatcher(pattern_graph)
|
|
match_result = subgraph_matcher.match(large_model_graph)
|
|
self.assertEqual(len(match_result), 1)
|
|
|
|
def test_subgraph_matcher_with_list(self):
|
|
def original(x, y):
|
|
return torch.ops.aten.view(x, [5, y.shape[0]])
|
|
original_graph = torch.fx.symbolic_trace(original).graph
|
|
|
|
def pattern(x, y, z):
|
|
return torch.ops.aten.view(x, [z, y.shape[0]])
|
|
pattern_graph = torch.fx.symbolic_trace(pattern).graph
|
|
|
|
subgraph_matcher = SubgraphMatcher(pattern_graph)
|
|
match_result = subgraph_matcher.match(original_graph)
|
|
self.assertEqual(len(match_result), 1)
|
|
|
|
def test_subgraph_matcher_with_list_bad(self):
|
|
def original(x, y):
|
|
return torch.ops.aten._reshape_alias_copy.default(x, [1, y.shape[0]], [y.shape[1], y.shape[1]])
|
|
original_graph = torch.fx.symbolic_trace(original).graph
|
|
|
|
def pattern(x, y, b):
|
|
return torch.ops.aten._reshape_alias_copy.default(x, [b, y.shape[0], y.shape[1]], [y.shape[1]])
|
|
pattern_graph = torch.fx.symbolic_trace(pattern).graph
|
|
|
|
subgraph_matcher = SubgraphMatcher(pattern_graph)
|
|
match_result = subgraph_matcher.match(original_graph)
|
|
self.assertEqual(len(match_result), 0)
|
|
|
|
def test_subgraph_matcher_ignore_literals(self):
|
|
def original(x):
|
|
return x + 1
|
|
|
|
original_graph = make_fx(original)(torch.ones(3, 3)).graph
|
|
original_graph.eliminate_dead_code()
|
|
|
|
def pattern(x):
|
|
return x + 2
|
|
pattern_graph = make_fx(pattern)(torch.ones(4, 4)).graph
|
|
pattern_graph.eliminate_dead_code()
|
|
|
|
subgraph_matcher = SubgraphMatcher(pattern_graph)
|
|
match_result = subgraph_matcher.match(original_graph)
|
|
self.assertEqual(len(match_result), 0)
|
|
|
|
subgraph_matcher = SubgraphMatcher(pattern_graph, ignore_literals=True)
|
|
match_result = subgraph_matcher.match(original_graph)
|
|
self.assertEqual(len(match_result), 1)
|