mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
SubgraphMatcher: matching modules support. (#25075)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/25075 This change adds a special behavior to subgraph matcher to allow it to match calls to modules. Namely, when a node in the pattern graph has a 'match::module' type, it is considered 'match' only when the corresponding node in the target graph is a 'prim::GetAttr' obtaining a submodule which type matches the type specified in 'name' attribute of the 'match::module' node. Currently when comparing the expected module type we check if the string specified in 'name' prefixes qualified name of the module GetAttr returns. In future when qualified name format is better defined we will probably change it for the exact comparison. Why do we want this? In some cases we would like to perform fusion on a module level rather than on a graph-level. A popular example of such fusion would be Conv-BN. It is inpractical to match batchnorm on graph-evel because that would mean we woudl need to specify its full and exact implementation in the pattern graph. If we match on the CallMethod level, however, the problem becomes trivial. The feature added in this PR allows to detect patterns with 'CallMethod' nodes, which in its turn allows us to use subgraph rewriter to replace such patterns with some node (or nodes). I expect that a usual approach would be to use subgraph-rewriter to replace all matches with some artificial node and then in additional pass replace such nodes with calls to another module or something else. It is not possible at the moment to use subgraph-rewriter to add a call to a method of a new module, because it can not add a new submodule, but we probably would add a higher level API to do that. Test Plan: Imported from OSS Differential Revision: D16978652 Pulled By: ZolotukhinM fbshipit-source-id: 37307a5ec65cf4618ad8eb595ef5f8ae656e2713
This commit is contained in:
parent
16289c2fdc
commit
85bca16a61
|
|
@ -1269,6 +1269,33 @@ graph(%Ra, %Rb):
|
|||
return (%Ra)""", graph)
|
||||
FileCheck().run(input_str, graph)
|
||||
|
||||
@_tmp_donotuse_dont_inline_everything
|
||||
def test_pattern_based_module_rewrite(self):
|
||||
# Check match::module behavior
|
||||
class Test(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(Test, self).__init__()
|
||||
self.conv = torch.nn.Conv2d(1, 20, 5, 1)
|
||||
self.bn = torch.nn.BatchNorm2d(num_features=20)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.bn(x)
|
||||
return x
|
||||
m = torch.jit.script(Test())
|
||||
torch._C._jit_pass_custom_pattern_based_rewrite_graph("""
|
||||
graph(%self, %x):
|
||||
%conv = match::module[name="Conv2d"](%self)
|
||||
%y = prim::CallMethod[name="forward"](%conv, %x)
|
||||
%bn = match::module[name="BatchNorm2d"](%self)
|
||||
%z = prim::CallMethod[name="forward"](%bn, %y)
|
||||
return (%z)""", """
|
||||
graph(%self, %x):
|
||||
%z = my::matched_conv_bn(%self, %x)
|
||||
return (%z)""", m._c._get_method("forward").graph)
|
||||
|
||||
FileCheck().check("my::matched_conv_bn").run(m._c._get_method("forward").graph)
|
||||
|
||||
def test_expand_quantlint(self):
|
||||
pass
|
||||
|
||||
|
|
|
|||
|
|
@ -179,6 +179,11 @@ bool SubgraphMatcher::matchAttributes(const Node* n1, Node* n2) {
|
|||
return true;
|
||||
}
|
||||
|
||||
static bool endsWith(const std::string& str, const std::string& suffix) {
|
||||
return str.size() >= suffix.size() &&
|
||||
0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
|
||||
}
|
||||
|
||||
/**
|
||||
* Compare two Nodes. N1 is from pattern, N2 is from the actual graph.
|
||||
*
|
||||
|
|
@ -212,17 +217,40 @@ bool SubgraphMatcher::matchNodes(const Node* n1, Node* n2) {
|
|||
return false;
|
||||
}
|
||||
|
||||
if (n1->kind() != n2->kind() ||
|
||||
n1->outputs().size() != n2->outputs().size() ||
|
||||
n1->inputs().size() != n2->inputs().size()) {
|
||||
GRAPH_DEBUG(
|
||||
"Nodes did not match in their kind or number of inputs/outputs:\n",
|
||||
*n1,
|
||||
*n2);
|
||||
return false;
|
||||
}
|
||||
if (!matchAttributes(n1, n2)) {
|
||||
return false;
|
||||
// Special handling for matching modules
|
||||
if (n1->kind() == Symbol::fromQualString("match::module")) {
|
||||
if (n2->kind() == prim::GetAttr) {
|
||||
if (!n1->hasAttributeS("name")) {
|
||||
GRAPH_DEBUG(
|
||||
"Nodes did not match because special node match::module does not have 'name' attribute:\n",
|
||||
*n1,
|
||||
*n2);
|
||||
return false;
|
||||
}
|
||||
auto t = n2->output()->type()->expect<ClassType>();
|
||||
auto real_typename = t->name()->qualifiedName();
|
||||
auto pattern_typename = n1->s(attr::name);
|
||||
if (!endsWith(real_typename, pattern_typename)) {
|
||||
GRAPH_DEBUG("Nodes did not match because expected module type is different:\n");
|
||||
GRAPH_DEBUG(" actualtype: ", real_typename, "\n");
|
||||
GRAPH_DEBUG(" expected type: ", pattern_typename, "\n");
|
||||
GRAPH_DEBUG("Nodes:", *n1, *n2);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (n1->kind() != n2->kind() ||
|
||||
n1->outputs().size() != n2->outputs().size() ||
|
||||
n1->inputs().size() != n2->inputs().size()) {
|
||||
GRAPH_DEBUG(
|
||||
"Nodes did not match in their kind or number of inputs/outputs:\n",
|
||||
*n1,
|
||||
*n2);
|
||||
return false;
|
||||
}
|
||||
if (!matchAttributes(n1, n2)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Add nodes to the map before calling matchValues to avoid infinite
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user