diff --git a/test/fx/test_subgraph_rewriter.py b/test/fx/test_subgraph_rewriter.py index 2c644f79271..1578f0e0e7f 100644 --- a/test/fx/test_subgraph_rewriter.py +++ b/test/fx/test_subgraph_rewriter.py @@ -306,3 +306,72 @@ class TestSubgraphRewriter(JitTestCase): traced.graph.lint() self.assertEqual(res, []) + + def test_subgraph_rewriter_placeholder_matching(self): + """ + This tests that a placeholder Node can be matched to a Node with + a different number of input Nodes. In the example below, the + original traced Module looks like this: + + opcode target args kwargs + ------------- ---------------------------------------------------------- ------------------------ -------- + placeholder x () {} + call_function (x, 3) {} + call_method dequantize (add,) {} + call_function (dequantize,) {} + call_method to (sigmoid, torch.float16) {} + output output (to,) {} + + while the pattern we want to match looks like this: + + opcode target args kwargs + ------------- ---------------------------------------------------------- ------------------------ -------- + placeholder x () {} + call_method dequantize (x,) {} + call_function (dequantize,) {} + call_method to (sigmoid, torch.float16) {} + output output (to,) {} + + Here, we want to be able to match the original graph's + `call_function.add` Node with the pattern graph's + `plaeholder.x` Node. + + Credit to Jerry Zhang (GitHub: jerryzh168) for this test case + """ + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.dtype = torch.float16 + + def forward(self, x): + x += 3 + x = x.dequantize() + x = torch.sigmoid(x) + dtype = self.dtype + x = x.to(dtype) + return x + + def pattern(x): + x = x.dequantize() + x = torch.sigmoid(x) + x = x.to(torch.float16) + return x + + def replacement(x): + return x + + def comparison(x): + return x + 3 + + traced = symbolic_trace(M()) + comparison_fn = symbolic_trace(comparison) + + x = torch.randn(3, 4) + + subgraph_rewriter.replace_pattern(traced, pattern, replacement) + + traced.graph.lint() + + ref_outs = comparison_fn(x) + test_outs = traced.forward(x) + self.assertEqual(ref_outs, test_outs) diff --git a/torch/fx/subgraph_rewriter.py b/torch/fx/subgraph_rewriter.py index 1edf1afe8f7..fac3bf036b8 100644 --- a/torch/fx/subgraph_rewriter.py +++ b/torch/fx/subgraph_rewriter.py @@ -64,6 +64,8 @@ class SubgraphMatcher: # Traverse the use-def relationships to ensure that `pn` is a true # match for `gn` + if pn.op == "placeholder": + return True if (pn.op != "output" and len(pn.all_input_nodes) != len(gn.all_input_nodes)): return False