// Mirror of https://github.com/zebrajr/pytorch.git
// (synced 2025-12-07 12:21:27 +01:00).
//
// Summary: This PR removes SymbolicVariable from all tests as well as the
// specialize_autogradzero and canonicalize_ops passes. These passes used
// SymbolicVariable in a relatively simple way compared to its few remaining
// uses. Removing SymbolicVariable means graphs must be constructed by other
// methods. IRParser was preferred for tests, but tests requiring pointers to
// graph internals or differentiation use direct construction instead. See
// https://github.com/pytorch/pytorch/issues/23989, which was discovered during
// this process, for why IRParser cannot be used when differentiation is
// required. Direct construction was also used in the updated passes.
// Pull Request resolved: https://github.com/pytorch/pytorch/pull/24007
// Test Plan: Only refactors existing tests and preserves current checks; no
// additional testing needed.
// Differential Revision: D16906045
// Pulled By: mruberry
// fbshipit-source-id: b67df4611562cd7618f969890e2b6840750c7266
//
// File size: 514 lines, 11 KiB, C++.
#include "test/cpp/jit/test_base.h"
|
|
#include "test/cpp/jit/test_utils.h"
|
|
#include "torch/csrc/jit/subgraph_matcher.h"
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
/**
 * A one-node pattern (a::aaa) must be found in a graph that consists of that
 * same single node. The value names differ between graph (%a) and pattern
 * (%x); only the operator is shared.
 */
void testTrivial1() {
  Graph graph, pattern;
  // NOTE(review): the IR text is kept left-aligned with indented statements;
  // the IR parser is presumably layout-sensitive - confirm before reflowing.
  script::parseIR(
      R"IR(
graph(%0):
  %a = a::aaa(%0)
  return (%a))IR",
      &graph);
  script::parseIR(
      R"IR(
graph(%0):
  %x = a::aaa(%0)
  return (%x))IR",
      &pattern);
  // Only the existence of a match is checked here, not its contents.
  AT_ASSERT(!findPatternMatches(pattern, graph).empty());
}
|
|
|
|
void testTrivial2() {
|
|
Graph graph;
|
|
auto* g_in = graph.addInput();
|
|
auto* g_tanh = graph.insertNode(graph.create(aten::tanh, /*num_outputs =*/ 1));
|
|
g_tanh->addInput(g_in);
|
|
graph.registerOutput(g_tanh->output());
|
|
|
|
Graph pattern;
|
|
auto* p_in = pattern.addInput();
|
|
auto* p_tanh = pattern.insertNode(pattern.create(aten::tanh, /*num_outputs =*/ 1));
|
|
p_tanh->addInput(p_in);
|
|
pattern.registerOutput(p_tanh->output());
|
|
|
|
auto matches = findPatternMatches(pattern, graph);
|
|
AT_ASSERT(matches.size() == 1);
|
|
for (const Match& m : matches) {
|
|
AT_ASSERT(m.values_map.at(p_in) == g_in);
|
|
AT_ASSERT(m.values_map.at(p_tanh->output()) == g_tanh->output());
|
|
AT_ASSERT(m.nodes_map.at(p_tanh) == g_tanh);
|
|
}
|
|
}
|
|
|
|
/**
 * The pattern is a lone two-input a::c node. In the graph, a::c is fed by two
 * other nodes (a::a and a::b), whereas in the pattern its operands are plain
 * graph inputs. A match must still be found.
 */
void testTrivial3() {
  Graph graph, pattern;
  // NOTE(review): the IR text is kept left-aligned with indented statements;
  // the IR parser is presumably layout-sensitive - confirm before reflowing.
  script::parseIR(
      R"IR(
graph(%0):
  %a = a::a(%0)
  %b = a::b(%0)
  %c = a::c(%a, %b)
  return (%c))IR",
      &graph);
  script::parseIR(
      R"IR(
graph(%a, %b):
  %c = a::c(%a, %b)
  return (%c))IR",
      &pattern);
  AT_ASSERT(!findPatternMatches(pattern, graph).empty());
}
|
|
|
|
void testTrivial4() {
|
|
Graph graph;
|
|
auto* g_in0 = graph.addInput();
|
|
auto* g_in1 = graph.addInput();
|
|
auto* g_mul = graph.insertNode(graph.create(aten::mul, /*num_outputs =*/ 1));
|
|
g_mul->addInput(g_in0);
|
|
g_mul->addInput(g_in1);
|
|
graph.registerOutput(g_mul->output());
|
|
|
|
Graph pattern;
|
|
auto* p_in0 = pattern.addInput();
|
|
auto* p_in1 = pattern.addInput();
|
|
auto* p_mul = pattern.insertNode(pattern.create(aten::mul, /*num_outputs =*/ 1));
|
|
p_mul->addInput(p_in0);
|
|
p_mul->addInput(p_in1);
|
|
pattern.registerOutput(p_mul->output());
|
|
|
|
auto matches = findPatternMatches(pattern, graph);
|
|
AT_ASSERT(matches.size() == 1);
|
|
for (const Match& m : matches) {
|
|
AT_ASSERT(m.values_map.at(p_in0) == g_in0);
|
|
AT_ASSERT(m.values_map.at(p_in1) == g_in1);
|
|
AT_ASSERT(m.values_map.at(p_mul->output()) == g_mul->output());
|
|
AT_ASSERT(m.nodes_map.at(p_mul) == g_mul);
|
|
}
|
|
}
|
|
|
|
/**
 * A two-node chain pattern (b::bbb -> c::ccc) must be found in the middle of
 * a longer linear chain (a::aaa -> b::bbb -> c::ccc -> d::ddd).
 *
 * NOTE(review): the graph IR defines %a a second time after %d; the result of
 * that trailing a::aaa node is not returned. It is kept exactly as in the
 * original test - confirm whether the redefinition is intentional.
 */
void testLinear1() {
  Graph graph, pattern;
  // NOTE(review): the IR text is kept left-aligned with indented statements;
  // the IR parser is presumably layout-sensitive - confirm before reflowing.
  script::parseIR(
      R"IR(
graph(%0):
  %a = a::aaa(%0)
  %b = b::bbb(%a)
  %c = c::ccc(%b)
  %d = d::ddd(%c)
  %a = a::aaa(%0)
  return (%d))IR",
      &graph);
  script::parseIR(
      R"IR(
graph(%0):
  %x = b::bbb(%0)
  %y = c::ccc(%x)
  return (%y))IR",
      &pattern);
  AT_ASSERT(!findPatternMatches(pattern, graph).empty());
}
|
|
|
|
void testLinear2() {
|
|
Graph graph;
|
|
auto* g_in = graph.addInput();
|
|
|
|
auto* g_tanh = graph.insertNode(graph.create(aten::tanh, /*num_outputs =*/ 1));
|
|
g_tanh->addInput(g_in);
|
|
|
|
auto* g_tanh2 = graph.insertNode(graph.create(aten::tanh, /*num_outputs =*/ 1));
|
|
g_tanh2->addInput(g_tanh->output());
|
|
|
|
graph.registerOutput(g_tanh2->output());
|
|
|
|
Graph pattern;
|
|
auto* p_in = pattern.addInput();
|
|
|
|
auto* p_tanh = pattern.insertNode(pattern.create(aten::tanh, /*num_outputs =*/ 1));
|
|
p_tanh->addInput(p_in);
|
|
|
|
auto* p_tanh2 = pattern.insertNode(pattern.create(aten::tanh, /*num_outputs =*/ 1));
|
|
p_tanh2->addInput(p_tanh->output());
|
|
|
|
pattern.registerOutput(p_tanh2->output());
|
|
|
|
auto matches = findPatternMatches(pattern, graph);
|
|
AT_ASSERT(matches.size() == 1);
|
|
for (const Match& m : matches) {
|
|
AT_ASSERT(m.values_map.at(p_in) == g_in);
|
|
AT_ASSERT(m.values_map.at(p_tanh->output()) == g_tanh->output());
|
|
AT_ASSERT(m.values_map.at(p_tanh2->output()) == g_tanh2->output());
|
|
AT_ASSERT(m.nodes_map.at(p_tanh) == g_tanh);
|
|
AT_ASSERT(m.nodes_map.at(p_tanh2) == g_tanh2);
|
|
}
|
|
}
|
|
|
|
/**
|
|
* Test diamond pattern:
|
|
*
|
|
* ooo
|
|
* |
|
|
* aaa
|
|
* / \
|
|
* bbb ccc
|
|
* \ /
|
|
* ddd
|
|
* |
|
|
* eee
|
|
*/
|
|
void testDiamond1() {
|
|
Graph graph, pattern1, pattern2;
|
|
script::parseIR(
|
|
R"IR(
|
|
graph(%0):
|
|
%o = o::ooo(%0)
|
|
%a = a::aaa(%o)
|
|
%b = b::bbb(%a)
|
|
%c = c::ccc(%a)
|
|
%d = d::ddd(%b, %c)
|
|
%e = e::eee(%d)
|
|
return (%e))IR",
|
|
&graph);
|
|
|
|
script::parseIR(
|
|
R"IR(
|
|
graph(%0):
|
|
%a = a::aaa(%0)
|
|
%b = b::bbb(%a)
|
|
%c = c::ccc(%a)
|
|
%d = d::ddd(%b, %c)
|
|
return (%d))IR",
|
|
&pattern1);
|
|
AT_ASSERT(!findPatternMatches(pattern1, graph).empty());
|
|
|
|
// Check that order of nodes inside the diamond does not affect the result
|
|
script::parseIR(
|
|
R"IR(
|
|
graph(%0):
|
|
%a = a::aaa(%0)
|
|
%c = c::ccc(%a)
|
|
%b = b::bbb(%a)
|
|
%d = d::ddd(%b, %c)
|
|
return (%d))IR",
|
|
&pattern2);
|
|
AT_ASSERT(!findPatternMatches(pattern2, graph).empty());
|
|
}
|
|
|
|
/**
|
|
* Test diamond pattern:
|
|
*
|
|
* i0
|
|
* |
|
|
* chunk
|
|
* / \
|
|
* os[0] os[1]
|
|
* \ /
|
|
* *
|
|
* |
|
|
* o1
|
|
*/
|
|
void testDiamond2() {
|
|
Graph graph;
|
|
auto* g_in = graph.addInput();
|
|
|
|
auto* g_chunk = graph.insertNode(graph.create(prim::ConstantChunk, /*num_outputs =*/ 2));
|
|
g_chunk->i_(attr::chunks, 2)->i_(attr::dim, 0);
|
|
g_chunk->addInput(g_in);
|
|
|
|
auto* g_mul = graph.insertNode(graph.create(aten::mul, /*num_outputs =*/ 1));
|
|
g_mul->addInput(g_chunk->outputs()[0]);
|
|
g_mul->addInput(g_chunk->outputs()[1]);
|
|
graph.registerOutput(g_mul->output());
|
|
|
|
Graph pattern;
|
|
auto* p_in = pattern.addInput();
|
|
auto* p_chunk = pattern.insertNode(pattern.create(prim::ConstantChunk, /*num_outputs =*/ 2));
|
|
p_chunk->i_(attr::chunks, 2)->i_(attr::dim, 0);
|
|
p_chunk->addInput(p_in);
|
|
|
|
auto* p_mul = pattern.insertNode(pattern.create(aten::mul, /*num_outputs =*/ 1));
|
|
p_mul->addInput(p_chunk->outputs()[0]);
|
|
p_mul->addInput(p_chunk->outputs()[1]);
|
|
pattern.registerOutput(p_mul->output());
|
|
|
|
auto matches = findPatternMatches(pattern, graph);
|
|
AT_ASSERT(matches.size() == 1);
|
|
for (const Match& m : matches) {
|
|
AT_ASSERT(m.values_map.at(p_in) == g_in);
|
|
AT_ASSERT(m.values_map.at(p_chunk->outputs()[0]) == g_chunk->outputs()[0]);
|
|
AT_ASSERT(m.values_map.at(p_chunk->outputs()[1]) == g_chunk->outputs()[1]);
|
|
AT_ASSERT(m.values_map.at(p_mul->output()) == g_mul->output());
|
|
AT_ASSERT(m.nodes_map.at(p_mul) == g_mul);
|
|
}
|
|
}
|
|
|
|
/**
 * "X"-shaped graph: two independent branches (b::bbb, c::ccc) join in x::xxx,
 * whose result fans out again to e::eee and f::fff before being merged by
 * g::ggg. The pattern is textually identical to the graph, so the whole graph
 * must be matched.
 */
void testXPattern() {
  Graph graph, pattern;
  // NOTE(review): the IR text is kept left-aligned with indented statements;
  // the IR parser is presumably layout-sensitive - confirm before reflowing.
  script::parseIR(
      R"IR(
graph(%0, %1):
  %b = b::bbb(%0)
  %c = c::ccc(%1)
  %x = x::xxx(%b, %c)
  %e = e::eee(%x)
  %f = f::fff(%x)
  %g = g::ggg(%e, %f)
  return (%g))IR",
      &graph);
  script::parseIR(
      R"IR(
graph(%0, %1):
  %b = b::bbb(%0)
  %c = c::ccc(%1)
  %x = x::xxx(%b, %c)
  %e = e::eee(%x)
  %f = f::fff(%x)
  %g = g::ggg(%e, %f)
  return (%g))IR",
      &pattern);
  AT_ASSERT(!findPatternMatches(pattern, graph).empty());
}
|
|
|
|
void testMultipleMatches() {
|
|
Graph graph, pattern;
|
|
script::parseIR(
|
|
R"IR(
|
|
graph(%t0):
|
|
%t1 = a::aaa(%t0)
|
|
%t2 = a::aaa(%t1)
|
|
%t3 = a::aaa(%t2)
|
|
%t4 = a::aaa(%t3)
|
|
return (%t4))IR",
|
|
&graph);
|
|
script::parseIR(
|
|
R"IR(
|
|
graph(%t0):
|
|
%t1 = a::aaa(%t0)
|
|
return (%t1))IR",
|
|
&pattern);
|
|
auto matches = findPatternMatches(pattern, graph);
|
|
AT_ASSERT(matches.size() == 4);
|
|
}
|
|
|
|
void testOverlappingMatches() {
|
|
Graph graph, pattern;
|
|
script::parseIR(
|
|
R"IR(
|
|
graph(%t0):
|
|
%t1 = a::aaa(%t0)
|
|
%t2 = a::aaa(%t1)
|
|
%t3 = a::aaa(%t2)
|
|
%t4 = a::aaa(%t3)
|
|
return (%t4))IR",
|
|
&graph);
|
|
script::parseIR(
|
|
R"IR(
|
|
graph(%t0):
|
|
%t1 = a::aaa(%t0)
|
|
%t2 = a::aaa(%t1)
|
|
return (%t2))IR",
|
|
&pattern);
|
|
auto matches = findPatternMatches(pattern, graph);
|
|
AT_ASSERT(matches.size() == 3);
|
|
}
|
|
|
|
void testMatchInBasicBlocks1() {
|
|
Graph graph;
|
|
script::parseIR(
|
|
R"IR(
|
|
graph(%a, %b, %c):
|
|
%d = aten::mul(%a, %b)
|
|
%x = prim::If(%c)
|
|
block0():
|
|
%x1 = aten::mul(%a, %d)
|
|
-> (%x1)
|
|
block1():
|
|
%x2 = aten::mul(%b, %d)
|
|
-> (%x2)
|
|
return (%x))IR",
|
|
&graph);
|
|
|
|
// Ensure the matches don't cross basic block boundaries
|
|
Graph pattern0;
|
|
script::parseIR(
|
|
R"IR(
|
|
graph(%x, %y):
|
|
%z = aten::mul(%x, %y)
|
|
return (%z))IR",
|
|
&pattern0);
|
|
AT_ASSERT(findPatternMatches(pattern0, graph).size() == 3);
|
|
|
|
Graph pattern1;
|
|
script::parseIR(
|
|
R"IR(
|
|
graph(%x, %y):
|
|
%z1 = aten::mul(%x, %y)
|
|
%z2 = aten::mul(%y, %z1)
|
|
return (%z2))IR",
|
|
&pattern1);
|
|
AT_ASSERT(findPatternMatches(pattern1, graph).size() == 0);
|
|
}
|
|
|
|
void testMatchInBasicBlocks2() {
|
|
Graph graph;
|
|
script::parseIR(
|
|
R"IR(
|
|
graph(%a, %b):
|
|
%x = my::mul(%a, %b)
|
|
%y = my::node_with_subblock()
|
|
block0():
|
|
%z = my::mul(%b, %x)
|
|
-> (%z)
|
|
return (%y))IR",
|
|
&graph);
|
|
|
|
// Check that we can match both mul ops
|
|
Graph pattern0;
|
|
script::parseIR(
|
|
R"IR(
|
|
graph(%x, %y):
|
|
%z = my::mul(%x, %y)
|
|
return (%z))IR",
|
|
&pattern0);
|
|
AT_ASSERT(findPatternMatches(pattern0, graph).size() == 2);
|
|
|
|
// Ensure the matches don't cross basic block boundaries
|
|
Graph pattern1;
|
|
script::parseIR(
|
|
R"IR(
|
|
graph(%x, %y):
|
|
%u = my::mul(%x, %y)
|
|
%v = my::mul(%y, %u)
|
|
return (%v))IR",
|
|
&pattern1);
|
|
AT_ASSERT(findPatternMatches(pattern1, graph).size() == 0);
|
|
}
|
|
|
|
void testMatchesAttributes() {
|
|
Graph graph;
|
|
script::parseIR(
|
|
R"IR(
|
|
graph(%0):
|
|
%a = a::a[isattr=[1,2]](%0)
|
|
%b = a::b[intattr=10, floatattr=3.14](%0)
|
|
%c = a::c[myattr="qqq"](%a, %b)
|
|
return (%c))IR",
|
|
&graph);
|
|
|
|
{
|
|
Graph pattern;
|
|
script::parseIR(
|
|
R"IR(
|
|
graph(%a, %b):
|
|
%c = a::c[myattr="qqq"](%a, %b)
|
|
return (%c))IR",
|
|
&pattern);
|
|
AT_ASSERT(!findPatternMatches(pattern, graph).empty());
|
|
}
|
|
{
|
|
Graph pattern;
|
|
script::parseIR(
|
|
R"IR(
|
|
graph(%a, %b):
|
|
%c = a::c[myattr="zzz"](%a, %b)
|
|
return (%c))IR",
|
|
&pattern);
|
|
AT_ASSERT(findPatternMatches(pattern, graph).empty());
|
|
}
|
|
{
|
|
Graph pattern;
|
|
script::parseIR(
|
|
R"IR(
|
|
graph(%0):
|
|
%b = a::b[extraattr=10](%0)
|
|
return (%b))IR",
|
|
&pattern);
|
|
AT_ASSERT(findPatternMatches(pattern, graph).empty());
|
|
}
|
|
{
|
|
Graph pattern;
|
|
script::parseIR(
|
|
R"IR(
|
|
graph(%0):
|
|
%b = a::b[intattr=10, floatattr=3.14](%0)
|
|
return (%b))IR",
|
|
&pattern);
|
|
AT_ASSERT(!findPatternMatches(pattern, graph).empty());
|
|
}
|
|
{
|
|
Graph pattern;
|
|
script::parseIR(
|
|
R"IR(
|
|
graph(%0):
|
|
%b = a::b[intattr=10, floatattr=3.14, strattr="rrr"](%0)
|
|
return (%b))IR",
|
|
&pattern);
|
|
AT_ASSERT(findPatternMatches(pattern, graph).empty());
|
|
}
|
|
{
|
|
Graph pattern;
|
|
script::parseIR(
|
|
R"IR(
|
|
graph(%0):
|
|
%a = a::a[isattr=[1,2]](%0)
|
|
return (%a))IR",
|
|
&pattern);
|
|
// Lists are not supported yet, thus we shouldn't match for now.
|
|
AT_ASSERT(findPatternMatches(pattern, graph).empty());
|
|
}
|
|
}
|
|
|
|
/**
 * Patterns for which findPatternMatches is expected to throw instead of
 * returning a result:
 *   pattern1 - contains a node that owns a sub-block;
 *   pattern2 - returns two values.
 * NOTE(review): presumably the matcher rejects patterns with sub-blocks and
 * with more than one output - confirm against the subgraph matcher.
 */
void testBadPattern() {
  Graph graph, pattern1, pattern2;
  // NOTE(review): the IR text is kept left-aligned with indented statements;
  // the IR parser is presumably layout-sensitive - confirm before reflowing.
  script::parseIR(
      R"IR(
graph(%0):
  %a = a::aaa(%0)
  return (%a))IR",
      &graph);

  // Pattern with a nested block.
  script::parseIR(
      R"IR(
graph(%x):
  %y = my::node_with_subblock()
    block0():
      %z = my::op(%x)
      -> (%z)
  return (%y))IR",
      &pattern1);
  ASSERT_ANY_THROW(findPatternMatches(pattern1, graph));

  // Pattern with two return values.
  script::parseIR(
      R"IR(
graph(%x):
  %y = my::op1(%x)
  %z = my::op2(%x)
  return (%y, %z))IR",
      &pattern2);
  ASSERT_ANY_THROW(findPatternMatches(pattern2, graph));
}
|
|
|
|
/**
 * Entry point of this test file: runs every subgraph-matcher test case
 * defined above, in order. Each case reports failure by throwing from its
 * assertions, so the first failing case stops the sequence.
 */
void testSubgraphMatching() {
  // Single-node and two-node basics.
  testTrivial1();
  testTrivial2();
  testTrivial3();
  testTrivial4();
  testLinear1();
  testLinear2();
  // Patterns with fan-out / fan-in.
  testDiamond1();
  testDiamond2();
  testXPattern();
  // Number of reported matches.
  testMultipleMatches();
  testOverlappingMatches();
  // Graphs with nested blocks.
  testMatchInBasicBlocks1();
  testMatchInBasicBlocks2();
  // Attributes and invalid patterns.
  testMatchesAttributes();
  testBadPattern();
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|