mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/52881 **This PR adds:** 1. logic to parse complex constants (complex literals of the form `bj`) 2. logic to parse complex lists 3. support for complex constructors: `complex(tensor/int/float/bool, tensor/int/float/bool)` 4. Limited operator support - `add`, `sub`, `mul`, `torch.tensor`, `torch.as_tensor` **Follow-up work:** 1. Add complex support for unary and other registered ops. 2. support complex constructor with string as input (this is supported in Python eager mode). 3. Test all emitXYZ for all XYZ in `ir_emitter.cpp` (currently only emitConst, emitValueToTensor are tested). e.g., test loops etc. 4. onnx doesn't support complex tensors, so we should error out with a clear and descriptive error message. Test Plan: Imported from OSS Reviewed By: bdhirsh Differential Revision: D27245059 Pulled By: anjali411 fbshipit-source-id: af043b5159ae99a9cc8691b5a8401503fa8d6f05
568 lines
12 KiB
C++
568 lines
12 KiB
C++
#include <gtest/gtest.h>
|
|
|
|
#include "test/cpp/jit/test_utils.h"
|
|
#include "torch/csrc/jit/ir/subgraph_matcher.h"
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
// A one-node pattern must be found in a graph that consists of the same
// single node (value names differ between the two, node kinds are equal).
TEST(SubgraphMatcherTest, Trivial1) {
  Graph graph, pattern;
  parseIR(
      R"IR(
graph(%0):
  %a = a::aaa(%0)
  return (%a))IR",
      &graph);
  parseIR(
      R"IR(
graph(%0):
  %x = a::aaa(%0)
  return (%x))IR",
      &pattern);
  ASSERT_FALSE(findPatternMatches(pattern, graph).empty());
}
|
|
|
|
// Same scenario as a single-node match, but both graphs are built through the
// Graph API so that the returned value/node mappings can be checked exactly.
TEST(SubgraphMatcherTest, Trivial2) {
  // Target graph: in -> tanh -> out.
  Graph graph;
  auto* g_in = graph.addInput();
  auto* g_tanh = graph.insertNode(graph.create(aten::tanh, /*num_outputs =*/1));
  g_tanh->addInput(g_in);
  graph.registerOutput(g_tanh->output());

  // Pattern graph with the identical structure.
  Graph pattern;
  auto* p_in = pattern.addInput();
  auto* p_tanh =
      pattern.insertNode(pattern.create(aten::tanh, /*num_outputs =*/1));
  p_tanh->addInput(p_in);
  pattern.registerOutput(p_tanh->output());

  const auto matches = findPatternMatches(pattern, graph);
  ASSERT_EQ(matches.size(), 1u);
  for (const Match& m : matches) {
    // Every pattern value/node must map onto its counterpart in the graph.
    ASSERT_EQ(m.values_map.at(p_in), g_in);
    ASSERT_EQ(m.values_map.at(p_tanh->output()), g_tanh->output());
    ASSERT_EQ(m.nodes_map.at(p_tanh), g_tanh);
  }
}
|
|
|
|
// The pattern is a single two-input node; in the target graph both of its
// operands are produced by other nodes rather than being graph inputs.
TEST(SubgraphMatcherTest, Trivial3) {
  Graph graph, pattern;
  parseIR(
      R"IR(
graph(%0):
  %a = a::a(%0)
  %b = a::b(%0)
  %c = a::c(%a, %b)
  return (%c))IR",
      &graph);
  parseIR(
      R"IR(
graph(%a, %b):
  %c = a::c(%a, %b)
  return (%c))IR",
      &pattern);
  ASSERT_FALSE(findPatternMatches(pattern, graph).empty());
}
|
|
|
|
// Single two-input node (aten::mul) built through the Graph API; checks that
// both inputs, the output and the node itself are mapped to the right
// counterparts in the target graph.
TEST(SubgraphMatcherTest, Trivial4) {
  // Target graph: mul(in0, in1).
  Graph graph;
  auto* g_in0 = graph.addInput();
  auto* g_in1 = graph.addInput();
  auto* g_mul = graph.insertNode(graph.create(aten::mul, /*num_outputs =*/1));
  g_mul->addInput(g_in0);
  g_mul->addInput(g_in1);
  graph.registerOutput(g_mul->output());

  // Pattern graph with the identical structure.
  Graph pattern;
  auto* p_in0 = pattern.addInput();
  auto* p_in1 = pattern.addInput();
  auto* p_mul =
      pattern.insertNode(pattern.create(aten::mul, /*num_outputs =*/1));
  p_mul->addInput(p_in0);
  p_mul->addInput(p_in1);
  pattern.registerOutput(p_mul->output());

  const auto matches = findPatternMatches(pattern, graph);
  ASSERT_EQ(matches.size(), 1u);
  for (const Match& m : matches) {
    ASSERT_EQ(m.values_map.at(p_in0), g_in0);
    ASSERT_EQ(m.values_map.at(p_in1), g_in1);
    ASSERT_EQ(m.values_map.at(p_mul->output()), g_mul->output());
    ASSERT_EQ(m.nodes_map.at(p_mul), g_mul);
  }
}
|
|
|
|
// A two-node chain (bbb -> ccc) must be found in the middle of a longer
// linear chain of nodes.
TEST(SubgraphMatcherTest, Linear1) {
  Graph graph, pattern;
  parseIR(
      R"IR(
graph(%0):
  %a = a::aaa(%0)
  %b = b::bbb(%a)
  %c = c::ccc(%b)
  %d = d::ddd(%c)
  %a = a::aaa(%0)
  return (%d))IR",
      &graph);
  parseIR(
      R"IR(
graph(%0):
  %x = b::bbb(%0)
  %y = c::ccc(%x)
  return (%y))IR",
      &pattern);
  ASSERT_FALSE(findPatternMatches(pattern, graph).empty());
}
|
|
|
|
// Two chained tanh nodes built through the Graph API; the pattern has the
// same shape, so exactly one match is expected and every pattern value/node
// must map to the corresponding graph value/node.
TEST(SubgraphMatcherTest, Linear2) {
  // Target graph: in -> tanh -> tanh -> out.
  Graph graph;
  auto* g_in = graph.addInput();

  auto* g_tanh = graph.insertNode(graph.create(aten::tanh, /*num_outputs =*/1));
  g_tanh->addInput(g_in);

  auto* g_tanh2 =
      graph.insertNode(graph.create(aten::tanh, /*num_outputs =*/1));
  g_tanh2->addInput(g_tanh->output());

  graph.registerOutput(g_tanh2->output());

  // Pattern graph with the identical structure.
  Graph pattern;
  auto* p_in = pattern.addInput();

  auto* p_tanh =
      pattern.insertNode(pattern.create(aten::tanh, /*num_outputs =*/1));
  p_tanh->addInput(p_in);

  auto* p_tanh2 =
      pattern.insertNode(pattern.create(aten::tanh, /*num_outputs =*/1));
  p_tanh2->addInput(p_tanh->output());

  pattern.registerOutput(p_tanh2->output());

  const auto matches = findPatternMatches(pattern, graph);
  ASSERT_EQ(matches.size(), 1u);
  for (const Match& m : matches) {
    ASSERT_EQ(m.values_map.at(p_in), g_in);
    ASSERT_EQ(m.values_map.at(p_tanh->output()), g_tanh->output());
    ASSERT_EQ(m.values_map.at(p_tanh2->output()), g_tanh2->output());
    ASSERT_EQ(m.nodes_map.at(p_tanh), g_tanh);
    ASSERT_EQ(m.nodes_map.at(p_tanh2), g_tanh2);
  }
}
|
|
|
|
/**
|
|
* Test diamond pattern:
|
|
*
|
|
* ooo
|
|
* |
|
|
* aaa
|
|
* / \
|
|
* bbb ccc
|
|
* \ /
|
|
* ddd
|
|
* |
|
|
* eee
|
|
*/
|
|
TEST(SubgraphMatcherTest, Diamond1) {
|
|
Graph graph, pattern1, pattern2;
|
|
parseIR(
|
|
R"IR(
|
|
graph(%0):
|
|
%o = o::ooo(%0)
|
|
%a = a::aaa(%o)
|
|
%b = b::bbb(%a)
|
|
%c = c::ccc(%a)
|
|
%d = d::ddd(%b, %c)
|
|
%e = e::eee(%d)
|
|
return (%e))IR",
|
|
&graph);
|
|
|
|
parseIR(
|
|
R"IR(
|
|
graph(%0):
|
|
%a = a::aaa(%0)
|
|
%b = b::bbb(%a)
|
|
%c = c::ccc(%a)
|
|
%d = d::ddd(%b, %c)
|
|
return (%d))IR",
|
|
&pattern1);
|
|
AT_ASSERT(!findPatternMatches(pattern1, graph).empty());
|
|
|
|
// Check that order of nodes inside the diamond does not affect the result
|
|
parseIR(
|
|
R"IR(
|
|
graph(%0):
|
|
%a = a::aaa(%0)
|
|
%c = c::ccc(%a)
|
|
%b = b::bbb(%a)
|
|
%d = d::ddd(%b, %c)
|
|
return (%d))IR",
|
|
&pattern2);
|
|
AT_ASSERT(!findPatternMatches(pattern2, graph).empty());
|
|
}
|
|
|
|
/**
|
|
* Test diamond pattern:
|
|
*
|
|
* i0
|
|
* |
|
|
* chunk
|
|
* / \
|
|
* os[0] os[1]
|
|
* \ /
|
|
* *
|
|
* |
|
|
* o1
|
|
*/
|
|
TEST(SubgraphMatcherTest, Diamond2) {
|
|
Graph graph;
|
|
auto* g_in = graph.addInput();
|
|
|
|
auto* g_chunk =
|
|
graph.insertNode(graph.create(prim::ConstantChunk, /*num_outputs =*/2));
|
|
g_chunk->i_(attr::chunks, 2)->i_(attr::dim, 0);
|
|
g_chunk->addInput(g_in);
|
|
|
|
auto* g_mul = graph.insertNode(graph.create(aten::mul, /*num_outputs =*/1));
|
|
g_mul->addInput(g_chunk->outputs()[0]);
|
|
g_mul->addInput(g_chunk->outputs()[1]);
|
|
graph.registerOutput(g_mul->output());
|
|
|
|
Graph pattern;
|
|
auto* p_in = pattern.addInput();
|
|
auto* p_chunk = pattern.insertNode(
|
|
pattern.create(prim::ConstantChunk, /*num_outputs =*/2));
|
|
p_chunk->i_(attr::chunks, 2)->i_(attr::dim, 0);
|
|
p_chunk->addInput(p_in);
|
|
|
|
auto* p_mul =
|
|
pattern.insertNode(pattern.create(aten::mul, /*num_outputs =*/1));
|
|
p_mul->addInput(p_chunk->outputs()[0]);
|
|
p_mul->addInput(p_chunk->outputs()[1]);
|
|
pattern.registerOutput(p_mul->output());
|
|
|
|
auto matches = findPatternMatches(pattern, graph);
|
|
AT_ASSERT(matches.size() == 1);
|
|
for (const Match& m : matches) {
|
|
AT_ASSERT(m.values_map.at(p_in) == g_in);
|
|
AT_ASSERT(m.values_map.at(p_chunk->outputs()[0]) == g_chunk->outputs()[0]);
|
|
AT_ASSERT(m.values_map.at(p_chunk->outputs()[1]) == g_chunk->outputs()[1]);
|
|
AT_ASSERT(m.values_map.at(p_mul->output()) == g_mul->output());
|
|
AT_ASSERT(m.nodes_map.at(p_mul) == g_mul);
|
|
}
|
|
}
|
|
|
|
// "X"-shaped graph: two branches (bbb, ccc) join in xxx, which fans out again
// into eee and fff before being merged by ggg. The pattern is identical to
// the target graph, so at least one match must be found.
TEST(SubgraphMatcherTest, XPattern) {
  Graph graph, pattern;
  parseIR(
      R"IR(
graph(%0, %1):
  %b = b::bbb(%0)
  %c = c::ccc(%1)
  %x = x::xxx(%b, %c)
  %e = e::eee(%x)
  %f = f::fff(%x)
  %g = g::ggg(%e, %f)
  return (%g))IR",
      &graph);
  parseIR(
      R"IR(
graph(%0, %1):
  %b = b::bbb(%0)
  %c = c::ccc(%1)
  %x = x::xxx(%b, %c)
  %e = e::eee(%x)
  %f = f::fff(%x)
  %g = g::ggg(%e, %f)
  return (%g))IR",
      &pattern);
  ASSERT_FALSE(findPatternMatches(pattern, graph).empty());
}
|
|
|
|
// A one-node pattern applied to a chain of four identical nodes must yield
// one match per node.
TEST(SubgraphMatcherTest, MultipleMatches) {
  Graph graph, pattern;
  parseIR(
      R"IR(
graph(%t0):
  %t1 = a::aaa(%t0)
  %t2 = a::aaa(%t1)
  %t3 = a::aaa(%t2)
  %t4 = a::aaa(%t3)
  return (%t4))IR",
      &graph);
  parseIR(
      R"IR(
graph(%t0):
  %t1 = a::aaa(%t0)
  return (%t1))IR",
      &pattern);
  const auto matches = findPatternMatches(pattern, graph);
  ASSERT_EQ(matches.size(), 4u);
}
|
|
|
|
// A two-node pattern applied to a chain of four identical nodes: three
// matches are expected, i.e. matches that share nodes are all reported.
TEST(SubgraphMatcherTest, OverlappingMatches) {
  Graph graph, pattern;
  parseIR(
      R"IR(
graph(%t0):
  %t1 = a::aaa(%t0)
  %t2 = a::aaa(%t1)
  %t3 = a::aaa(%t2)
  %t4 = a::aaa(%t3)
  return (%t4))IR",
      &graph);
  parseIR(
      R"IR(
graph(%t0):
  %t1 = a::aaa(%t0)
  %t2 = a::aaa(%t1)
  return (%t2))IR",
      &pattern);
  const auto matches = findPatternMatches(pattern, graph);
  ASSERT_EQ(matches.size(), 3u);
}
|
|
|
|
// Matching inside the sub-blocks of a prim::If node: single-node patterns are
// found in the top-level block and in both branches, while a pattern that
// would have to span the top-level block and a branch is not matched.
TEST(SubgraphMatcherTest, MatchInBasicBlocks1) {
  Graph graph;
  parseIR(
      R"IR(
graph(%a, %b, %c):
  %d = aten::mul(%a, %b)
  %x = prim::If(%c)
    block0():
      %x1 = aten::mul(%a, %d)
      -> (%x1)
    block1():
      %x2 = aten::mul(%b, %d)
      -> (%x2)
  return (%x))IR",
      &graph);

  // A single mul is present three times: once at the top level and once in
  // each branch of the If.
  Graph pattern0;
  parseIR(
      R"IR(
graph(%x, %y):
  %z = aten::mul(%x, %y)
  return (%z))IR",
      &pattern0);
  ASSERT_EQ(findPatternMatches(pattern0, graph).size(), 3u);

  // Ensure the matches don't cross basic block boundaries
  Graph pattern1;
  parseIR(
      R"IR(
graph(%x, %y):
  %z1 = aten::mul(%x, %y)
  %z2 = aten::mul(%y, %z1)
  return (%z2))IR",
      &pattern1);
  ASSERT_TRUE(findPatternMatches(pattern1, graph).empty());
}
|
|
|
|
// Same idea as matching inside prim::If branches, but with a custom node that
// owns a single sub-block: nodes inside the sub-block are matched, and no
// match may span the outer block and the sub-block.
TEST(SubgraphMatcherTest, MatchInBasicBlocks2) {
  Graph graph;
  parseIR(
      R"IR(
graph(%a, %b):
  %x = my::mul(%a, %b)
  %y = my::node_with_subblock()
    block0():
      %z = my::mul(%b, %x)
      -> (%z)
  return (%y))IR",
      &graph);

  // Check that we can match both mul ops
  Graph pattern0;
  parseIR(
      R"IR(
graph(%x, %y):
  %z = my::mul(%x, %y)
  return (%z))IR",
      &pattern0);
  ASSERT_EQ(findPatternMatches(pattern0, graph).size(), 2u);

  // Ensure the matches don't cross basic block boundaries
  Graph pattern1;
  parseIR(
      R"IR(
graph(%x, %y):
  %u = my::mul(%x, %y)
  %v = my::mul(%y, %u)
  return (%v))IR",
      &pattern1);
  ASSERT_TRUE(findPatternMatches(pattern1, graph).empty());
}
|
|
|
|
// Checks how node attributes take part in matching. One target graph with
// int-list, int, float, complex and string attributes is matched against a
// series of small patterns, each exercising one attribute scenario.
TEST(SubgraphMatcherTest, MatchesAttributes) {
  Graph graph;
  parseIR(
      R"IR(
graph(%0):
  %a = a::a[isattr=[1,2]](%0)
  %b = a::b[intattr=10, floatattr=3.14, complexattr=-3.14j](%0)
  %c = a::c[myattr="qqq"](%a, %b)
  return (%c))IR",
      &graph);

  {
    // Identical string attribute value: expected to match.
    Graph pattern;
    parseIR(
        R"IR(
graph(%a, %b):
  %c = a::c[myattr="qqq"](%a, %b)
  return (%c))IR",
        &pattern);
    ASSERT_FALSE(findPatternMatches(pattern, graph).empty());
  }
  {
    // Same attribute name, different string value: no match.
    Graph pattern;
    parseIR(
        R"IR(
graph(%a, %b):
  %c = a::c[myattr="zzz"](%a, %b)
  return (%c))IR",
        &pattern);
    ASSERT_TRUE(findPatternMatches(pattern, graph).empty());
  }
  {
    // Pattern requires an attribute the graph node does not have: no match.
    Graph pattern;
    parseIR(
        R"IR(
graph(%0):
  %b = a::b[extraattr=10](%0)
  return (%b))IR",
        &pattern);
    ASSERT_TRUE(findPatternMatches(pattern, graph).empty());
  }
  {
    // Int, float and complex attributes all equal: expected to match.
    Graph pattern;
    parseIR(
        R"IR(
graph(%0):
  %b = a::b[intattr=10, floatattr=3.14, complexattr=-3.14j](%0)
  return (%b))IR",
        &pattern);
    ASSERT_FALSE(findPatternMatches(pattern, graph).empty());
  }
  {
    // Matching attributes plus one extra attribute in the pattern: no match.
    Graph pattern;
    parseIR(
        R"IR(
graph(%0):
  %b = a::b[intattr=10, floatattr=3.14, complexattr=-3.14j, strattr="rrr"](%0)
  return (%b))IR",
        &pattern);
    ASSERT_TRUE(findPatternMatches(pattern, graph).empty());
  }
  {
    Graph pattern;
    parseIR(
        R"IR(
graph(%0):
  %a = a::a[isattr=[1,2]](%0)
  return (%a))IR",
        &pattern);
    // Lists are not supported yet, thus we shouldn't match for now.
    ASSERT_TRUE(findPatternMatches(pattern, graph).empty());
  }
  {
    // The pattern value "q.*" is expected to match the graph value "qqq".
    // NOTE(review): presumably string attributes in the pattern are treated
    // as regular expressions - confirm against the matcher implementation.
    Graph pattern;
    parseIR(
        R"IR(
graph(%a, %b):
  %c = a::c[myattr="q.*"](%a, %b)
  return (%c))IR",
        &pattern);
    ASSERT_FALSE(findPatternMatches(pattern, graph).empty());
  }
}
|
|
|
|
// Patterns the matcher cannot handle must make findPatternMatches throw
// rather than return a result.
TEST(SubgraphMatcherTest, BadPattern) {
  Graph target;
  parseIR(
      R"IR(
graph(%x):
  %y = my::op1(%x)
  %z = my::op2(%x)
  return (%y, %z))IR",
      &target);

  // No support for patterns with subblocks
  Graph subblock_pattern;
  parseIR(
      R"IR(
graph(%x):
  %y = my::node_with_subblock()
    block0():
      %z = my::op(%x)
      -> (%z)
  return (%y))IR",
      &subblock_pattern);
  ASSERT_ANY_THROW(findPatternMatches(subblock_pattern, target));

  // Not supported multi-output pattern, because not the whole pattern is
  // covered by a traversal up from the first output (`%z = ...` is not
  // visited). See the note "Multi-output Patterns" in subgraph_matcher.h.
  Graph disjoint_outputs_pattern;
  parseIR(
      R"IR(
graph(%x):
  %y = my::op1(%x)
  %z = my::op2(%x)
  return (%y, %z))IR",
      &disjoint_outputs_pattern);
  ASSERT_ANY_THROW(findPatternMatches(disjoint_outputs_pattern, target));
}
|
|
|
|
// Patterns that return more than one value. In both scenarios the
// aaa -> bbb fragment occurs twice in the target graph, so two matches are
// expected.
TEST(SubgraphMatcherTest, MultiOutput) {
  {
    // Pattern outputs are produced by two different nodes (%b and %a).
    Graph graph, pattern;
    parseIR(
        R"IR(
graph(%0):
  %a = a::aaa(%0)
  %b = b::bbb(%a)
  %c = c::ccc(%a, %b)
  %x = a::aaa(%c)
  %y = b::bbb(%x)
  %z = d::ddd(%x, %y)
  return (%y))IR",
        &graph);
    parseIR(
        R"IR(
graph(%0):
  %a = a::aaa(%0)
  %b = b::bbb(%a)
  return (%b, %a))IR",
        &pattern);
    ASSERT_EQ(findPatternMatches(pattern, graph).size(), 2u);
  }
  {
    // One of the pattern outputs is the second output of a two-output node.
    Graph graph, pattern;
    parseIR(
        R"IR(
graph(%0, %1):
  %a1, %a2 = a::aaa(%0, %1)
  %b = b::bbb(%a1)
  %c = c::ccc(%b)

  %x1, %x2 = a::aaa(%c, %a2)
  %y = b::bbb(%x1)
  %z = d::ddd(%y)
  return (%z))IR",
        &graph);
    parseIR(
        R"IR(
graph(%0, %1):
  %a1, %a2 = a::aaa(%0, %1)
  %b = b::bbb(%a1)
  return (%b, %a2))IR",
        &pattern);
    ASSERT_EQ(findPatternMatches(pattern, graph).size(), 2u);
  }
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|