pytorch/test/cpp/jit/test_ir.cpp
Mike Ruberry f01548e5a4 Removes SymbolicVariable from tests (#24007)
Summary:
This PR removes SymbolicVariable from all tests as well as the specialize_autogradzero and canonicalize_ops passes. These passes used SymbolicVariable in a relatively simple way compared to its few remaining uses.

Removing SymbolicVariable means graphs must be constructed by other means. IRParser was preferred for tests, but tests that need pointers into the graph's internals or that exercise differentiation construct their graphs directly. See https://github.com/pytorch/pytorch/issues/23989, discovered while making this change, for why IRParser cannot be used when differentiation is required. Direct construction was also used in the updated passes. Both styles are sketched below.
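
For illustration, a minimal sketch contrasting the two construction styles. This is not code from the commit: the function name is hypothetical, the graph built is arbitrary, and the calls assume the Graph/IRParser APIs at this revision.

#include "torch/csrc/jit/ir.h"
#include "torch/csrc/jit/irparser.h"

// Assumed to live inside namespace torch::jit, like the tests below.
void sketchConstructionStyles() {
  // Style 1: IRParser. Convenient, but the test gets no pointers to the
  // values/nodes it creates.
  auto parsed = std::make_shared<Graph>();
  torch::jit::script::parseIR(R"IR(
    graph(%x : Tensor, %y : Tensor):
      %2 : int = prim::Constant[value=1]()
      %3 : Tensor = aten::add(%x, %y, %2)
      return (%3))IR",
      parsed.get());

  // Style 2: direct construction. More verbose, but the test holds
  // Value*/Node* handles into the graph (needed, e.g., for differentiation).
  auto built = std::make_shared<Graph>();
  Value* x = built->addInput();
  Value* y = built->addInput();
  Value* one = built->insertConstant(1);
  Node* add = built->insertNode(built->create(aten::add, {x, y, one}));
  built->registerOutput(add->output());
}
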
Pull Request resolved: https://github.com/pytorch/pytorch/pull/24007

Test Plan: This change only refactors existing tests and preserves their current checks; no additional testing is needed.

Differential Revision: D16906045

Pulled By: mruberry

fbshipit-source-id: b67df4611562cd7618f969890e2b6840750c7266
2019-08-19 20:49:37 -07:00

#include "test/cpp/jit/test_base.h"
#include "test/cpp/jit/test_utils.h"
#include "torch/csrc/jit/irparser.h"
namespace torch {
namespace jit {
// Exercises Node's attribute API: typed setters/getters, overwriting,
// and copying attributes between nodes.
void testAttributes() {
  Graph g;
  auto one = attr::alpha;
  auto two = attr::device;
  auto three = attr::end;
  auto four = attr::perm;
  Node* n = g.create(Symbol::fromQualString("foo::bar"));
  Node& attr = *n;
  // Attribute setters return the node, so calls can be chained.
  attr.f_(one, 3.4)->i_(two, 5)->s_(three, "what");
  ASSERT_EQ(attr.f(one), 3.4);
  ASSERT_EQ(attr.s(three), "what");
  ASSERT_EQ(attr.i(two), 5);
  // Setting an attribute again overwrites it, even with a new type.
  attr.s_(one, "no");
  ASSERT_EQ(attr.s(one), "no");
  ASSERT_TRUE(attr.hasAttribute(three));
  ASSERT_TRUE(!attr.hasAttribute(four));
  attr.ss_(two, {"hi", "now"});
  ASSERT_EQ(attr.ss(two).at(1), "now");

  // Attributes are copied by value: mutating the copy leaves the
  // original node untouched.
  Node* n2 = g.create(Symbol::fromQualString("foo::baz"));
  Node& attr2 = *n2;
  attr2.copyAttributes(attr);
  ASSERT_EQ(attr2.s(one), "no");
  attr2.f_(one, 5);
  ASSERT_EQ(attr.s(one), "no");
  ASSERT_EQ(attr2.f(one), 5);
}

// Exercises Block manipulation: erasing a block from an If node and
// recursively copying a graph that contains blocks.
void testBlocks() {
  auto g = std::make_shared<Graph>();
  const auto graph_string = R"IR(
    graph(%a : Tensor,
          %b : Tensor,
          %c : Tensor):
      %2 : int = prim::Constant[value=1]()
      %3 : Tensor = aten::add(%a, %b, %2)
      %5 : Tensor = prim::If(%c)
        block0():
          %6 : int = prim::Constant[value=1]()
          %7 : Tensor = aten::add(%3, %3, %6)
          -> (%7)
        block1():
          %8 : int = prim::Constant[value=1]()
          %9 : Tensor = aten::add(%b, %3, %8)
          %10 : int = prim::Constant[value=1]()
          %11 : Tensor = aten::add(%9, %3, %10)
          -> (%11)
      %12 : int = prim::Constant[value=1]()
      %13 : Tensor = aten::add(%5, %3, %12)
      return (%13))IR";
  torch::jit::script::parseIR(graph_string, g.get());
  g->lint();
  // Verify the printed graph's structure: the first add, the If node,
  // block0 and its add, then block1 followed by three more adds.
  testing::FileCheck()
      .check("add")
      ->check("prim::If")
      ->check("block0")
      ->check("aten::add")
      ->check("block1")
      ->check_count("aten::add", 3)
      ->run(*g);

  // Remove block0 of the conditional.
  for (auto* node : g->block()->nodes()) {
    if (node->kind() == prim::If) {
      node->eraseBlock(0);
      break;
    }
  }

  // After erasing, the If should have a single block (printed as
  // block0), so no second "block" may appear.
  testing::FileCheck()
      .check("add")
      ->check("prim::If")
      ->check("block0")
      ->check_not("block")
      ->run(*g);
  g->lint();

  // Test that recursive copying of blocks works.
  auto g2 = g->copy();
  testing::FileCheck()
      .check("add")
      ->check("prim::If")
      ->check("block0")
      ->check_not("block")
      ->run(*g2);
}
} // namespace jit
} // namespace torch