SubgraphMatcher: matching modules support. (#25075)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25075

This change adds a special behavior to the subgraph matcher to allow it
to match calls to modules. Namely, when a node in the pattern graph has
the 'match::module' kind, it is considered a match only when the
corresponding node in the target graph is a 'prim::GetAttr' obtaining a
submodule whose type matches the type specified in the 'name' attribute
of the 'match::module' node.
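
For example, a pattern graph matching a call to a Conv2d submodule can
be written as follows (a sketch mirroring the pattern used in the test
added below; the value names are illustrative):

    conv_call_pattern = """
    graph(%self, %x):
        %conv = match::module[name="Conv2d"](%self)
        %y = prim::CallMethod[name="forward"](%conv, %x)
        return (%y)"""

Here the 'match::module' node matches a 'prim::GetAttr' node in the
target graph that fetches a submodule whose qualified type name (e.g.
"__torch__.torch.nn.modules.conv.Conv2d") matches "Conv2d", and the
'prim::CallMethod' node matches the call to that submodule's forward
method.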

Currently, when comparing the expected module type, we check whether the
string specified in 'name' is a suffix of the qualified name of the
module the GetAttr returns. In the future, when the qualified-name
format is better defined, we will probably change this to an exact
comparison.
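
In other words, the comparison behaves like the following Python sketch
(the qualified name shown is an example of the format produced for
scripted modules):

    def module_type_matches(real_typename, pattern_typename):
        # Mirrors the endsWith() helper added in this change: the
        # pattern's 'name' attribute must be a suffix of the submodule's
        # qualified type name.
        return real_typename.endswith(pattern_typename)

    assert module_type_matches("__torch__.torch.nn.modules.conv.Conv2d", "Conv2d")
    assert not module_type_matches("__torch__.torch.nn.modules.conv.Conv2d", "BatchNorm2d")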

Why do we want this? In some cases we would like to perform fusion at
the module level rather than at the graph level. A popular example of
such a fusion is Conv-BN. It is impractical to match batchnorm at the
graph level, because that would mean specifying its full and exact
implementation in the pattern graph. If we match at the CallMethod
level, however, the problem becomes trivial.

The feature added in this PR makes it possible to detect patterns
containing 'CallMethod' nodes, which in turn allows us to use the
subgraph rewriter to replace such patterns with some node (or nodes). I
expect the usual approach will be to use the subgraph rewriter to
replace all matches with an artificial node and then, in an additional
pass, replace those nodes with calls to another module or something
else. It is not currently possible to use the subgraph rewriter to add a
call to a method of a new module, because it cannot add a new submodule,
but we will probably add a higher-level API for that.
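
Here is a sketch of that two-step workflow, assuming the Conv-BN module
from the test added below. The 'my::matched_conv_bn' op is an arbitrary
placeholder, step 2 is only outlined because the rewriter cannot yet add
submodules, and Graph.findAllNodes is assumed to be available for
locating the placeholder nodes. Note that module calls may be inlined
before the pass runs, which is why the test uses the
@_tmp_donotuse_dont_inline_everything decorator.

    import torch

    class ConvBN(torch.nn.Module):
        def __init__(self):
            super(ConvBN, self).__init__()
            self.conv = torch.nn.Conv2d(1, 20, 5, 1)
            self.bn = torch.nn.BatchNorm2d(num_features=20)

        def forward(self, x):
            return self.bn(self.conv(x))

    m = torch.jit.script(ConvBN())

    # Step 1: collapse every matched Conv2d->BatchNorm2d call pair into
    # a single placeholder node.
    torch._C._jit_pass_custom_pattern_based_rewrite_graph("""
    graph(%self, %x):
        %conv = match::module[name="Conv2d"](%self)
        %y = prim::CallMethod[name="forward"](%conv, %x)
        %bn = match::module[name="BatchNorm2d"](%self)
        %z = prim::CallMethod[name="forward"](%bn, %y)
        return (%z)""", """
    graph(%self, %x):
        %z = my::matched_conv_bn(%self, %x)
        return (%z)""", m._c._get_method("forward").graph)

    # Step 2 (a separate pass, not provided here): find the placeholder
    # nodes and replace them with whatever the fusion needs, e.g. a call
    # to a fused module added by a higher-level API.
    for node in m._c._get_method("forward").graph.findAllNodes("my::matched_conv_bn"):
        pass  # rewrite 'node' here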

Test Plan: Imported from OSS

Differential Revision: D16978652

Pulled By: ZolotukhinM

fbshipit-source-id: 37307a5ec65cf4618ad8eb595ef5f8ae656e2713

@@ -1269,6 +1269,33 @@ graph(%Ra, %Rb):
   return (%Ra)""", graph)
         FileCheck().run(input_str, graph)
 
+    @_tmp_donotuse_dont_inline_everything
+    def test_pattern_based_module_rewrite(self):
+        # Check match::module behavior
+        class Test(torch.nn.Module):
+            def __init__(self):
+                super(Test, self).__init__()
+                self.conv = torch.nn.Conv2d(1, 20, 5, 1)
+                self.bn = torch.nn.BatchNorm2d(num_features=20)
+
+            def forward(self, x):
+                x = self.conv(x)
+                x = self.bn(x)
+                return x
+
+        m = torch.jit.script(Test())
+        torch._C._jit_pass_custom_pattern_based_rewrite_graph("""
+graph(%self, %x):
+        %conv = match::module[name="Conv2d"](%self)
+        %y = prim::CallMethod[name="forward"](%conv, %x)
+        %bn = match::module[name="BatchNorm2d"](%self)
+        %z = prim::CallMethod[name="forward"](%bn, %y)
+        return (%z)""", """
+graph(%self, %x):
+        %z = my::matched_conv_bn(%self, %x)
+        return (%z)""", m._c._get_method("forward").graph)
+        FileCheck().check("my::matched_conv_bn").run(m._c._get_method("forward").graph)
+
     def test_expand_quantlint(self):
         pass


@@ -179,6 +179,11 @@ bool SubgraphMatcher::matchAttributes(const Node* n1, Node* n2) {
   return true;
 }
 
+static bool endsWith(const std::string& str, const std::string& suffix) {
+  return str.size() >= suffix.size() &&
+      0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
+}
+
 /**
  * Compare two Nodes. N1 is from pattern, N2 is from the actual graph.
  *
@@ -212,17 +217,40 @@ bool SubgraphMatcher::matchNodes(const Node* n1, Node* n2) {
     return false;
   }
 
-  if (n1->kind() != n2->kind() ||
-      n1->outputs().size() != n2->outputs().size() ||
-      n1->inputs().size() != n2->inputs().size()) {
-    GRAPH_DEBUG(
-        "Nodes did not match in their kind or number of inputs/outputs:\n",
-        *n1,
-        *n2);
-    return false;
-  }
-  if (!matchAttributes(n1, n2)) {
-    return false;
+  // Special handling for matching modules
+  if (n1->kind() == Symbol::fromQualString("match::module")) {
+    if (n2->kind() == prim::GetAttr) {
+      if (!n1->hasAttributeS("name")) {
+        GRAPH_DEBUG(
+            "Nodes did not match because special node match::module does not have 'name' attribute:\n",
+            *n1,
+            *n2);
+        return false;
+      }
+      auto t = n2->output()->type()->expect<ClassType>();
+      auto real_typename = t->name()->qualifiedName();
+      auto pattern_typename = n1->s(attr::name);
+      if (!endsWith(real_typename, pattern_typename)) {
+        GRAPH_DEBUG("Nodes did not match because expected module type is different:\n");
+        GRAPH_DEBUG("  actualtype: ", real_typename, "\n");
+        GRAPH_DEBUG("  expected type: ", pattern_typename, "\n");
+        GRAPH_DEBUG("Nodes:", *n1, *n2);
+        return false;
+      }
+    }
+  } else {
+    if (n1->kind() != n2->kind() ||
+        n1->outputs().size() != n2->outputs().size() ||
+        n1->inputs().size() != n2->inputs().size()) {
+      GRAPH_DEBUG(
+          "Nodes did not match in their kind or number of inputs/outputs:\n",
+          *n1,
+          *n2);
+      return false;
+    }
+    if (!matchAttributes(n1, n2)) {
+      return false;
+    }
   }
 
   // Add nodes to the map before calling matchValues to avoid infinite