Improve placeholder matching in subgraph rewriter (#54958)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/54958

Test Plan: Imported from OSS

Reviewed By: jamesr66a

Differential Revision: D27431889

Pulled By: ansley

fbshipit-source-id: 8b1b4f2f0202305530b9648b6b770f9e2ecacfe2
This commit is contained in:
Ansley Ussery 2021-03-30 11:39:17 -07:00 committed by Facebook GitHub Bot
parent f5d6b90c35
commit 18e61d1ce9
2 changed files with 71 additions and 0 deletions

View File

@ -306,3 +306,72 @@ class TestSubgraphRewriter(JitTestCase):
traced.graph.lint()
self.assertEqual(res, [])
def test_subgraph_rewriter_placeholder_matching(self):
"""
This tests that a placeholder Node can be matched to a Node with
a different number of input Nodes. In the example below, the
original traced Module looks like this:
opcode target args kwargs
------------- ---------------------------------------------------------- ------------------------ --------
placeholder x () {}
call_function <built-in function add> (x, 3) {}
call_method dequantize (add,) {}
call_function <built-in method sigmoid of type object at 0x7f7c1f440fe0> (dequantize,) {}
call_method to (sigmoid, torch.float16) {}
output output (to,) {}
while the pattern we want to match looks like this:
opcode target args kwargs
------------- ---------------------------------------------------------- ------------------------ --------
placeholder x () {}
call_method dequantize (x,) {}
call_function <built-in method sigmoid of type object at 0x7f7c1f440fe0> (dequantize,) {}
call_method to (sigmoid, torch.float16) {}
output output (to,) {}
Here, we want to be able to match the original graph's
`call_function.add` Node with the pattern graph's
`plaeholder.x` Node.
Credit to Jerry Zhang (GitHub: jerryzh168) for this test case
"""
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.dtype = torch.float16
def forward(self, x):
x += 3
x = x.dequantize()
x = torch.sigmoid(x)
dtype = self.dtype
x = x.to(dtype)
return x
def pattern(x):
x = x.dequantize()
x = torch.sigmoid(x)
x = x.to(torch.float16)
return x
def replacement(x):
return x
def comparison(x):
return x + 3
traced = symbolic_trace(M())
comparison_fn = symbolic_trace(comparison)
x = torch.randn(3, 4)
subgraph_rewriter.replace_pattern(traced, pattern, replacement)
traced.graph.lint()
ref_outs = comparison_fn(x)
test_outs = traced.forward(x)
self.assertEqual(ref_outs, test_outs)

View File

@ -64,6 +64,8 @@ class SubgraphMatcher:
# Traverse the use-def relationships to ensure that `pn` is a true
# match for `gn`
if pn.op == "placeholder":
return True
if (pn.op != "output"
and len(pn.all_input_nodes) != len(gn.all_input_nodes)):
return False