mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: att. Test Plan: ci. Rollback Plan: (none given). Differential Revision: D76167093. Pull Request resolved: https://github.com/pytorch/pytorch/pull/155374. Approved by: https://github.com/zhxchen17
48 lines
1.2 KiB
C++
48 lines
1.2 KiB
C++
#include <gtest/gtest.h>
|
|
#include <torch/nativert/executor/ExecutionPlanner.h>
|
|
|
|
namespace torch::nativert {
|
|
|
|
TEST(ExecutionPlannerTest, CreatePlan) {
|
|
auto graph = stringToGraph(R"(
|
|
graph(%x, %y):
|
|
%a = foo(a=%x, b=%y)
|
|
%b = foo1(a=%x, b=%y)
|
|
%c = foo2(c=%a, d=%b)
|
|
return(%c)
|
|
)");
|
|
|
|
{
|
|
auto plan = ExecutionPlanner{*graph}.createPlan();
|
|
|
|
auto& values_to_free = plan->valuesToFree;
|
|
EXPECT_EQ(values_to_free.size(), 5);
|
|
|
|
for (const auto i : c10::irange(3)) {
|
|
EXPECT_TRUE(values_to_free[i].empty());
|
|
}
|
|
|
|
EXPECT_EQ(values_to_free[3].size(), 2);
|
|
std::set<int64_t> ids{values_to_free[3].begin(), values_to_free[3].end()};
|
|
EXPECT_EQ(
|
|
ids,
|
|
std::set<int64_t>(
|
|
{graph->tryGetValue("a")->id(), graph->tryGetValue("b")->id()}));
|
|
|
|
EXPECT_EQ(values_to_free[4].size(), 0);
|
|
}
|
|
|
|
{
|
|
auto static_values = ExecutionPlanner::staticValues(*graph);
|
|
std::set<int64_t> static_ids{static_values.begin(), static_values.end()};
|
|
EXPECT_EQ(
|
|
static_ids,
|
|
std::set<int64_t>(
|
|
{graph->tryGetValue("x")->id(),
|
|
graph->tryGetValue("y")->id(),
|
|
graph->tryGetValue("c")->id()}));
|
|
}
|
|
}
|
|
|
|
} // namespace torch::nativert
|