pytorch/test/cpp/jit/test_subgraph_rewriter.cpp
Jerry Zhang 004aa089a6 [jit][subgraph_rewriter] Support list of filters (#39867)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39867

Support a list of filters in the subgraph rewriter; the rewrite executes only
when the match passes all filter checks. This is useful for letting different
matches share the same filter.

Test Plan: Imported from OSS

Differential Revision: D22009855

fbshipit-source-id: 67aab8d6326b2011a9061397699dc62ee9ad4e2d
2020-06-12 08:24:49 -07:00

131 lines
3.4 KiB
C++

#include <test/cpp/jit/test_base.h>
#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/testing/file_check.h>
namespace torch {
namespace jit {
using namespace testing;
void testFilterMatch() {
auto graph = std::make_shared<Graph>();
parseIR(
R"IR(
graph(%0):
%a = a::aaa(%0)
%b : int = prim::Constant[value=1]()
%c = c::ccc(%a, %b)
return (%c))IR",
graph.get());
std::string pattern = R"IR(
graph(%a, %b):
%c = c::ccc(%a, %b)
return (%c))IR";
Graph pattern_graph;
std::unordered_map<std::string, Value*> vmap;
parseIR(pattern, &pattern_graph, vmap);
auto b_is_constant = [](const Match& match,
const std::unordered_map<std::string, Value*>& vmap) {
const auto& match_vmap = match.values_map;
auto b_node = match_vmap.at(vmap.at("b"))->node();
return b_node->kind() == prim::Constant;
};
auto b_is_one = [](const Match& match,
const std::unordered_map<std::string, Value*>& vmap) {
const auto& match_vmap = match.values_map;
auto b_val = toIValue(match_vmap.at(vmap.at("b")));
return b_val && b_val->isInt() && b_val->toInt() == 1;
};
auto b_is_two = [](const Match& match,
const std::unordered_map<std::string, Value*>& vmap) {
const auto& match_vmap = match.values_map;
auto b_val = toIValue(match_vmap.at(vmap.at("b")));
return b_val && b_val->isInt() && b_val->toInt() == 2;
};
std::string replacement = R"IR(
graph(%a, %b):
%d = d::ddd(%a, %b)
return (%d))IR";
SubgraphRewriter rewriter;
rewriter.RegisterRewritePattern(pattern, replacement);
// b is constant, so the match will succeed
{
auto g = graph->copy();
rewriter.runOnGraph(g, b_is_constant);
FileCheck().check("d::ddd")->check_not("c::ccc")->run(*g);
}
// b is constant and the value is one, the match will succeed
{
auto g = graph->copy();
rewriter.runOnGraph(g, {b_is_constant, b_is_one});
FileCheck().check("d::ddd")->check_not("c::ccc")->run(*g);
}
// b is constant but the value is not two, the match will fail
{
auto g = graph->copy();
rewriter.runOnGraph(g, {b_is_constant, b_is_two});
FileCheck().check("c::ccc")->check_not("d::ddd")->run(*g);
}
}
void testFilterNoMatch() {
auto graph = std::make_shared<Graph>();
parseIR(
R"IR(
graph(%0):
%a = a::aaa(%0)
%b = prim::Constant[value=1]()
%c = c::ccc(%a, %b)
return (%c))IR",
graph.get());
std::string pattern = R"IR(
graph(%a, %b):
%c = c::ccc(%a, %b)
return (%c))IR";
Graph pattern_graph;
std::unordered_map<std::string, Value*> vmap;
parseIR(pattern, &pattern_graph, vmap);
auto filter = [](const Match& match,
const std::unordered_map<std::string, Value*>& vmap) {
const auto& match_vmap = match.values_map;
auto b_node = match_vmap.at(vmap.at("b"))->node();
// b_node is not prim::Assign, so this won't match and we'll skip the
// rewrite
return b_node->kind() == prim::Assign;
};
std::string replacement = R"IR(
graph(%a, %b):
%d = d::ddd(%a, %b)
return (%d))IR";
SubgraphRewriter rewriter;
rewriter.RegisterRewritePattern(pattern, replacement);
rewriter.runOnGraph(graph, filter);
FileCheck().check("c::ccc")->check_not("d::ddd")->run(*graph);
}
// Entry point for the subgraph rewriter tests: runs the filter-accepts cases
// first, then the filter-rejects case.
void testSubgraphRewriter() {
  testFilterMatch(); // filters that accept (and one list that rejects)
  testFilterNoMatch(); // single filter that rejects the match
}
} // namespace jit
} // namespace torch