mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Break up test_misc so that the tests for each source file live in a corresponding test_<filename>. I think we should wait before moving test files into the source directory, since that would involve moving some tests over to the C10 folder, and in my opinion this change already gets us 99% of the way toward test discoverability anyway. I added a test_utils file for common functions invoked by the tests. Pull Request resolved: https://github.com/pytorch/pytorch/pull/18071 Differential Revision: D14485787 Pulled By: eellison fbshipit-source-id: dcb20d1978d490999d435ea20c1d0503413a5c80
90 lines
2.2 KiB
C++
90 lines
2.2 KiB
C++
#pragma once
|
|
|
|
#include "test/cpp/jit/test_base.h"
|
|
#include "test/cpp/jit/test_utils.h"
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
namespace test {
|
|
|
|
void testAttributes() {
|
|
Graph g;
|
|
auto one = attr::alpha;
|
|
auto two = attr::device;
|
|
auto three = attr::end;
|
|
auto four = attr::perm;
|
|
Node* n = g.create(Symbol::fromQualString("foo::bar"));
|
|
Node& attr = *n;
|
|
attr.f_(one, 3.4)->i_(two, 5)->s_(three, "what");
|
|
ASSERT_EQ(attr.f(one), 3.4);
|
|
ASSERT_EQ(attr.s(three), "what");
|
|
ASSERT_EQ(attr.i(two), 5);
|
|
attr.s_(one, "no");
|
|
ASSERT_EQ(attr.s(one), "no");
|
|
ASSERT_TRUE(attr.hasAttribute(three));
|
|
ASSERT_TRUE(!attr.hasAttribute(four));
|
|
attr.ss_(two, {"hi", "now"});
|
|
ASSERT_EQ(attr.ss(two).at(1), "now");
|
|
|
|
Node* n2 = g.create(Symbol::fromQualString("foo::baz"));
|
|
Node& attr2 = *n2;
|
|
attr2.copyAttributes(attr);
|
|
ASSERT_EQ(attr2.s(one), "no");
|
|
attr2.f_(one, 5);
|
|
ASSERT_EQ(attr.s(one), "no");
|
|
ASSERT_EQ(attr2.f(one), 5);
|
|
}
|
|
|
|
void testBlocks(std::ostream& out = std::cout) {
|
|
auto g = std::make_shared<Graph>();
|
|
// auto g = *graph;
|
|
auto a = Var::asNewInput(*g, "a");
|
|
auto b = Var::asNewInput(*g, "b");
|
|
auto c = a + b;
|
|
auto r =
|
|
g->appendNode(g->create(prim::If, {Var::asNewInput(*g, "c").value()}));
|
|
auto then_block = r->addBlock();
|
|
auto else_block = r->addBlock();
|
|
{
|
|
WithInsertPoint guard(then_block);
|
|
auto t = c + c;
|
|
then_block->registerOutput(t.value());
|
|
}
|
|
{
|
|
WithInsertPoint guard(else_block);
|
|
auto d = b + c;
|
|
auto e = d + c;
|
|
else_block->registerOutput(e.value());
|
|
}
|
|
g->registerOutput((Var(r->output()) + c).value());
|
|
g->lint();
|
|
testing::FileCheck()
|
|
.check("add")
|
|
->check("prim::If")
|
|
->check("block0")
|
|
->check("aten::add")
|
|
->check("block1")
|
|
->check_count("aten::add", 3)
|
|
->run(*g);
|
|
r->eraseBlock(0);
|
|
testing::FileCheck()
|
|
.check("add")
|
|
->check("prim::If")
|
|
->check("block0")
|
|
->check_not("block")
|
|
->run(*g);
|
|
g->lint();
|
|
// test recursive copy of blocks works
|
|
auto g2 = g->copy();
|
|
testing::FileCheck()
|
|
.check("add")
|
|
->check("prim::If")
|
|
->check("block0")
|
|
->check_not("block")
|
|
->run(*g2);
|
|
}
|
|
|
|
} // namespace test
|
|
} // namespace jit
|
|
} // namespace torch
|