mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: This PR removes SymbolicVariable from all tests as well as the specialize_autogradzero and canonicalize_ops passes. These passes used SymbolicVariable in a relatively simple way compared to its few remaining uses. Removing SymbolicVariable means graphs must be constructed by other methods. IRParser was preferred for tests, but tests requiring pointers to graph internals or differentiation use direct construction instead. See https://github.com/pytorch/pytorch/issues/23989, which was discovered during this process, for why IRParser cannot be used when differentiation is required. Direct construction was also used in the updated passes. Pull Request resolved: https://github.com/pytorch/pytorch/pull/24007 Test Plan: Only refactors existing tests and preserves current checks; no additional testing needed. Differential Revision: D16906045 Pulled By: mruberry fbshipit-source-id: b67df4611562cd7618f969890e2b6840750c7266
97 lines
2.5 KiB
C++
97 lines
2.5 KiB
C++
#include "test/cpp/jit/test_base.h"
|
|
#include "test/cpp/jit/test_utils.h"
|
|
#include "torch/csrc/jit/irparser.h"
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
void testAttributes() {
|
|
Graph g;
|
|
auto one = attr::alpha;
|
|
auto two = attr::device;
|
|
auto three = attr::end;
|
|
auto four = attr::perm;
|
|
Node* n = g.create(Symbol::fromQualString("foo::bar"));
|
|
Node& attr = *n;
|
|
attr.f_(one, 3.4)->i_(two, 5)->s_(three, "what");
|
|
ASSERT_EQ(attr.f(one), 3.4);
|
|
ASSERT_EQ(attr.s(three), "what");
|
|
ASSERT_EQ(attr.i(two), 5);
|
|
attr.s_(one, "no");
|
|
ASSERT_EQ(attr.s(one), "no");
|
|
ASSERT_TRUE(attr.hasAttribute(three));
|
|
ASSERT_TRUE(!attr.hasAttribute(four));
|
|
attr.ss_(two, {"hi", "now"});
|
|
ASSERT_EQ(attr.ss(two).at(1), "now");
|
|
|
|
Node* n2 = g.create(Symbol::fromQualString("foo::baz"));
|
|
Node& attr2 = *n2;
|
|
attr2.copyAttributes(attr);
|
|
ASSERT_EQ(attr2.s(one), "no");
|
|
attr2.f_(one, 5);
|
|
ASSERT_EQ(attr.s(one), "no");
|
|
ASSERT_EQ(attr2.f(one), 5);
|
|
}
|
|
|
|
void testBlocks() {
|
|
auto g = std::make_shared<Graph>();
|
|
const auto graph_string = R"IR(
|
|
graph(%a : Tensor,
|
|
%b : Tensor,
|
|
%c : Tensor):
|
|
%2 : int = prim::Constant[value=1]()
|
|
%3 : Tensor = aten::add(%a, %b, %2)
|
|
%5 : Tensor = prim::If(%c)
|
|
block0():
|
|
%6 : int = prim::Constant[value=1]()
|
|
%7 : Tensor = aten::add(%3, %3, %6)
|
|
-> (%7)
|
|
block1():
|
|
%8 : int = prim::Constant[value=1]()
|
|
%9 : Tensor = aten::add(%b, %3, %8)
|
|
%10 : int = prim::Constant[value=1]()
|
|
%11 : Tensor = aten::add(%9, %3, %10)
|
|
-> (%11)
|
|
%12 : int = prim::Constant[value=1]()
|
|
%13 : Tensor = aten::add(%5, %3, %12)
|
|
return (%13))IR";
|
|
torch::jit::script::parseIR(graph_string, g.get());
|
|
|
|
g->lint();
|
|
testing::FileCheck()
|
|
.check("add")
|
|
->check("prim::If")
|
|
->check("block0")
|
|
->check("aten::add")
|
|
->check("block1")
|
|
->check_count("aten::add", 3)
|
|
->run(*g);
|
|
|
|
// Removes block0 of the conditional
|
|
for (auto* node : g->block()->nodes()) {
|
|
if (node->kind() == prim::If) {
|
|
node->eraseBlock(0);
|
|
break;
|
|
}
|
|
}
|
|
|
|
testing::FileCheck()
|
|
.check("add")
|
|
->check("prim::If")
|
|
->check("block0")
|
|
->check_not("block")
|
|
->run(*g);
|
|
g->lint();
|
|
// test recursive copy of blocks works
|
|
auto g2 = g->copy();
|
|
testing::FileCheck()
|
|
.check("add")
|
|
->check("prim::If")
|
|
->check("block0")
|
|
->check_not("block")
|
|
->run(*g2);
|
|
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|