pytorch/test/cpp/jit/test_subgraph_rewriter.cpp
Michael Suo dbe850af5b [jit] do the code reorg (#33851)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33851

Rationale and context described in #33828.

Script to reproduce the move:
https://gist.github.com/suo/16cbefaaeb67ca5a7c6caffd49b7f6e9
ghstack-source-id: 99079645

Test Plan: Make sure CI passes

Reviewed By: jamesr66a

Differential Revision: D20133869

fbshipit-source-id: 390e9241a9c85366d9005c492ac31f10aa96488e
2020-02-27 13:02:51 -08:00

109 lines
2.4 KiB
C++

#include <test/cpp/jit/test_base.h>
#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/testing/file_check.h>
namespace torch {
namespace jit {
using namespace testing;
void testFilterMatch() {
auto graph = std::make_shared<Graph>();
script::parseIR(
R"IR(
graph(%0):
%a = a::aaa(%0)
%b = prim::Constant[value=1]()
%c = c::ccc(%a, %b)
return (%c))IR",
graph.get());
std::string pattern = R"IR(
graph(%a, %b):
%c = c::ccc(%a, %b)
return (%c))IR";
Graph pattern_graph;
std::unordered_map<std::string, Value*> vmap;
script::parseIR(
pattern,
&pattern_graph,
vmap);
auto filter = [](const Match& match,
const std::unordered_map<std::string, Value*>& vmap) {
const auto& match_vmap = match.values_map;
auto b_node = match_vmap.at(vmap.at("b"))->node();
return b_node->kind() == prim::Constant;
};
std::string replacement = R"IR(
graph(%a, %b):
%d = d::ddd(%a, %b)
return (%d))IR";
SubgraphRewriter rewriter;
rewriter.RegisterRewritePattern(pattern, replacement);
rewriter.runOnGraph(graph, filter);
FileCheck().check("d::ddd")
->check_not("c::ccc")
->run(*graph);
}
void testFilterNoMatch() {
auto graph = std::make_shared<Graph>();
script::parseIR(
R"IR(
graph(%0):
%a = a::aaa(%0)
%b = prim::Constant[value=1]()
%c = c::ccc(%a, %b)
return (%c))IR",
graph.get());
std::string pattern = R"IR(
graph(%a, %b):
%c = c::ccc(%a, %b)
return (%c))IR";
Graph pattern_graph;
std::unordered_map<std::string, Value*> vmap;
script::parseIR(
pattern,
&pattern_graph,
vmap);
auto filter = [](const Match& match,
const std::unordered_map<std::string, Value*>& vmap) {
const auto& match_vmap = match.values_map;
auto b_node = match_vmap.at(vmap.at("b"))->node();
// b_node is not Constant, so this won't match and we'll skip the rewrite
return b_node->kind() == prim::Assign;
};
std::string replacement = R"IR(
graph(%a, %b):
%d = d::ddd(%a, %b)
return (%d))IR";
SubgraphRewriter rewriter;
rewriter.RegisterRewritePattern(pattern, replacement);
rewriter.runOnGraph(graph, filter);
FileCheck().check("c::ccc")
->check_not("d::ddd")
->run(*graph);
}
void testSubgraphRewriter() {
testFilterMatch();
testFilterNoMatch();
}
}}