pytorch/test/cpp/jit/test_ir.h
Elias Ellison f3806094d5 Breakup Test Misc (batch 1/2) (#18071)
Summary:
Break up test_misc so that the tests for a given file live in test_<filename>. I think we might want to wait on moving test files into the source directory, since that would involve moving some tests over to the C10 folder, and this already goes 99% of the way toward test discoverability IMO anyway.

I added a file test_utils for common functions invoked in the tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18071

Differential Revision: D14485787

Pulled By: eellison

fbshipit-source-id: dcb20d1978d490999d435ea20c1d0503413a5c80
2019-03-15 13:56:19 -07:00

90 lines
2.2 KiB
C++

#pragma once
#include "test/cpp/jit/test_base.h"
#include "test/cpp/jit/test_utils.h"
namespace torch {
namespace jit {
namespace test {
// Exercises the typed attribute API on Node:
//   - setting and reading float, int, string and string-list attributes;
//   - overwriting an attribute name with a value of a different kind;
//   - hasAttribute;
//   - copyAttributes, checking that every attribute is copied and that the
//     copy is independent of the source node.
//
// Defined `inline` because this function lives in a header. Without it, each
// translation unit that includes this header would emit its own external
// definition.
inline void testAttributes() {
  Graph g;
  auto one = attr::alpha;
  auto two = attr::device;
  auto three = attr::end;
  auto four = attr::perm;
  Node* n = g.create(Symbol::fromQualString("foo::bar"));
  Node& attr = *n;
  // The setters return a Node*, so calls can be chained.
  attr.f_(one, 3.4)->i_(two, 5)->s_(three, "what");
  ASSERT_EQ(attr.f(one), 3.4);
  ASSERT_EQ(attr.s(three), "what");
  ASSERT_EQ(attr.i(two), 5);
  // Setting an existing name with a different kind overwrites it.
  attr.s_(one, "no");
  ASSERT_EQ(attr.s(one), "no");
  ASSERT_TRUE(attr.hasAttribute(three));
  ASSERT_TRUE(!attr.hasAttribute(four));
  attr.ss_(two, {"hi", "now"});
  ASSERT_EQ(attr.ss(two).at(1), "now");

  Node* n2 = g.create(Symbol::fromQualString("foo::baz"));
  Node& attr2 = *n2;
  attr2.copyAttributes(attr);
  // Every attribute of `attr` must be present on the copy, not just `one`,
  // and attributes `attr` never had must not appear.
  ASSERT_EQ(attr2.s(one), "no");
  ASSERT_EQ(attr2.ss(two).at(1), "now");
  ASSERT_EQ(attr2.s(three), "what");
  ASSERT_TRUE(!attr2.hasAttribute(four));
  // Mutating the copy must leave the source node untouched.
  attr2.f_(one, 5);
  ASSERT_EQ(attr.s(one), "no");
  ASSERT_EQ(attr2.f(one), 5);
}
void testBlocks(std::ostream& out = std::cout) {
auto g = std::make_shared<Graph>();
// auto g = *graph;
auto a = Var::asNewInput(*g, "a");
auto b = Var::asNewInput(*g, "b");
auto c = a + b;
auto r =
g->appendNode(g->create(prim::If, {Var::asNewInput(*g, "c").value()}));
auto then_block = r->addBlock();
auto else_block = r->addBlock();
{
WithInsertPoint guard(then_block);
auto t = c + c;
then_block->registerOutput(t.value());
}
{
WithInsertPoint guard(else_block);
auto d = b + c;
auto e = d + c;
else_block->registerOutput(e.value());
}
g->registerOutput((Var(r->output()) + c).value());
g->lint();
testing::FileCheck()
.check("add")
->check("prim::If")
->check("block0")
->check("aten::add")
->check("block1")
->check_count("aten::add", 3)
->run(*g);
r->eraseBlock(0);
testing::FileCheck()
.check("add")
->check("prim::If")
->check("block0")
->check_not("block")
->run(*g);
g->lint();
// test recursive copy of blocks works
auto g2 = g->copy();
testing::FileCheck()
.check("add")
->check("prim::If")
->check("block0")
->check_not("block")
->run(*g2);
}
} // namespace test
} // namespace jit
} // namespace torch