mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Improve placeholder matching in subgraph rewriter (#54958)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/54958 Test Plan: Imported from OSS Reviewed By: jamesr66a Differential Revision: D27431889 Pulled By: ansley fbshipit-source-id: 8b1b4f2f0202305530b9648b6b770f9e2ecacfe2
This commit is contained in:
parent
f5d6b90c35
commit
18e61d1ce9
|
|
@ -306,3 +306,72 @@ class TestSubgraphRewriter(JitTestCase):
|
|||
traced.graph.lint()
|
||||
|
||||
self.assertEqual(res, [])
|
||||
|
||||
def test_subgraph_rewriter_placeholder_matching(self):
|
||||
"""
|
||||
This tests that a placeholder Node can be matched to a Node with
|
||||
a different number of input Nodes. In the example below, the
|
||||
original traced Module looks like this:
|
||||
|
||||
opcode target args kwargs
|
||||
------------- ---------------------------------------------------------- ------------------------ --------
|
||||
placeholder x () {}
|
||||
call_function <built-in function add> (x, 3) {}
|
||||
call_method dequantize (add,) {}
|
||||
call_function <built-in method sigmoid of type object at 0x7f7c1f440fe0> (dequantize,) {}
|
||||
call_method to (sigmoid, torch.float16) {}
|
||||
output output (to,) {}
|
||||
|
||||
while the pattern we want to match looks like this:
|
||||
|
||||
opcode target args kwargs
|
||||
------------- ---------------------------------------------------------- ------------------------ --------
|
||||
placeholder x () {}
|
||||
call_method dequantize (x,) {}
|
||||
call_function <built-in method sigmoid of type object at 0x7f7c1f440fe0> (dequantize,) {}
|
||||
call_method to (sigmoid, torch.float16) {}
|
||||
output output (to,) {}
|
||||
|
||||
Here, we want to be able to match the original graph's
|
||||
`call_function.add` Node with the pattern graph's
|
||||
`plaeholder.x` Node.
|
||||
|
||||
Credit to Jerry Zhang (GitHub: jerryzh168) for this test case
|
||||
"""
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.dtype = torch.float16
|
||||
|
||||
def forward(self, x):
|
||||
x += 3
|
||||
x = x.dequantize()
|
||||
x = torch.sigmoid(x)
|
||||
dtype = self.dtype
|
||||
x = x.to(dtype)
|
||||
return x
|
||||
|
||||
def pattern(x):
|
||||
x = x.dequantize()
|
||||
x = torch.sigmoid(x)
|
||||
x = x.to(torch.float16)
|
||||
return x
|
||||
|
||||
def replacement(x):
|
||||
return x
|
||||
|
||||
def comparison(x):
|
||||
return x + 3
|
||||
|
||||
traced = symbolic_trace(M())
|
||||
comparison_fn = symbolic_trace(comparison)
|
||||
|
||||
x = torch.randn(3, 4)
|
||||
|
||||
subgraph_rewriter.replace_pattern(traced, pattern, replacement)
|
||||
|
||||
traced.graph.lint()
|
||||
|
||||
ref_outs = comparison_fn(x)
|
||||
test_outs = traced.forward(x)
|
||||
self.assertEqual(ref_outs, test_outs)
|
||||
|
|
|
|||
|
|
@ -64,6 +64,8 @@ class SubgraphMatcher:
|
|||
|
||||
# Traverse the use-def relationships to ensure that `pn` is a true
|
||||
# match for `gn`
|
||||
if pn.op == "placeholder":
|
||||
return True
|
||||
if (pn.op != "output"
|
||||
and len(pn.all_input_nodes) != len(gn.all_input_nodes)):
|
||||
return False
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user