mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33851
Rationale and context described in #33828.
Script to reproduce the move: https://gist.github.com/suo/16cbefaaeb67ca5a7c6caffd49b7f6e9
ghstack-source-id: 99079645
Test Plan: Make sure CI passes
Reviewed By: jamesr66a
Differential Revision: D20133869
fbshipit-source-id: 390e9241a9c85366d9005c492ac31f10aa96488e
109 lines
2.4 KiB
C++
109 lines
2.4 KiB
C++
#include <test/cpp/jit/test_base.h>
|
|
#include <test/cpp/jit/test_utils.h>
|
|
#include <torch/csrc/jit/ir/subgraph_matcher.h>
|
|
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
|
|
#include <torch/csrc/jit/testing/file_check.h>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
using namespace testing;
|
|
|
|
void testFilterMatch() {
|
|
auto graph = std::make_shared<Graph>();
|
|
|
|
script::parseIR(
|
|
R"IR(
|
|
graph(%0):
|
|
%a = a::aaa(%0)
|
|
%b = prim::Constant[value=1]()
|
|
%c = c::ccc(%a, %b)
|
|
return (%c))IR",
|
|
graph.get());
|
|
|
|
std::string pattern = R"IR(
|
|
graph(%a, %b):
|
|
%c = c::ccc(%a, %b)
|
|
return (%c))IR";
|
|
Graph pattern_graph;
|
|
std::unordered_map<std::string, Value*> vmap;
|
|
|
|
script::parseIR(
|
|
pattern,
|
|
&pattern_graph,
|
|
vmap);
|
|
|
|
auto filter = [](const Match& match,
|
|
const std::unordered_map<std::string, Value*>& vmap) {
|
|
const auto& match_vmap = match.values_map;
|
|
auto b_node = match_vmap.at(vmap.at("b"))->node();
|
|
return b_node->kind() == prim::Constant;
|
|
};
|
|
|
|
std::string replacement = R"IR(
|
|
graph(%a, %b):
|
|
%d = d::ddd(%a, %b)
|
|
return (%d))IR";
|
|
|
|
SubgraphRewriter rewriter;
|
|
rewriter.RegisterRewritePattern(pattern, replacement);
|
|
rewriter.runOnGraph(graph, filter);
|
|
|
|
FileCheck().check("d::ddd")
|
|
->check_not("c::ccc")
|
|
->run(*graph);
|
|
}
|
|
|
|
void testFilterNoMatch() {
|
|
auto graph = std::make_shared<Graph>();
|
|
script::parseIR(
|
|
R"IR(
|
|
graph(%0):
|
|
%a = a::aaa(%0)
|
|
%b = prim::Constant[value=1]()
|
|
%c = c::ccc(%a, %b)
|
|
return (%c))IR",
|
|
graph.get());
|
|
|
|
std::string pattern = R"IR(
|
|
graph(%a, %b):
|
|
%c = c::ccc(%a, %b)
|
|
return (%c))IR";
|
|
Graph pattern_graph;
|
|
std::unordered_map<std::string, Value*> vmap;
|
|
|
|
script::parseIR(
|
|
pattern,
|
|
&pattern_graph,
|
|
vmap);
|
|
|
|
auto filter = [](const Match& match,
|
|
const std::unordered_map<std::string, Value*>& vmap) {
|
|
const auto& match_vmap = match.values_map;
|
|
auto b_node = match_vmap.at(vmap.at("b"))->node();
|
|
// b_node is not Constant, so this won't match and we'll skip the rewrite
|
|
return b_node->kind() == prim::Assign;
|
|
};
|
|
|
|
std::string replacement = R"IR(
|
|
graph(%a, %b):
|
|
%d = d::ddd(%a, %b)
|
|
return (%d))IR";
|
|
|
|
SubgraphRewriter rewriter;
|
|
rewriter.RegisterRewritePattern(pattern, replacement);
|
|
rewriter.runOnGraph(graph, filter);
|
|
|
|
FileCheck().check("c::ccc")
|
|
->check_not("d::ddd")
|
|
->run(*graph);
|
|
|
|
}
|
|
|
|
|
|
// Driver for the SubgraphRewriter filter tests. Runs the case where the
// filter accepts the match first, then the case where it rejects it.
void testSubgraphRewriter() {
  testFilterMatch();
  testFilterNoMatch();
}
|
|
|
|
}}
|