pytorch/test/cpp/nativert/test_execution_planner.cpp
dolpm 8892b782a8 [nativert] move execution planner to torch (#155374)
Summary: att

Test Plan:
ci

Rollback Plan:

Differential Revision: D76167093

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155374
Approved by: https://github.com/zhxchen17
2025-06-10 22:36:06 +00:00

48 lines
1.2 KiB
C++

#include <set>

#include <gtest/gtest.h>
#include <torch/nativert/executor/ExecutionPlanner.h>
namespace torch::nativert {

// Builds a small graph in which %a and %b are both computed from the inputs
// and then consumed only by foo2, and checks two ExecutionPlanner results:
//   * createPlan(): the test expects valuesToFree to have 5 entries, with
//     only entry 3 non-empty and holding exactly the ids of %a and %b.
//     NOTE(review): presumably one entry per node (3 ops plus the graph
//     input/output nodes) — confirm against ExecutionPlanner.
//   * staticValues(): the test expects exactly the ids of the graph inputs
//     (%x, %y) and the graph output (%c).
TEST(ExecutionPlannerTest, CreatePlan) {
  auto graph = stringToGraph(R"(
graph(%x, %y):
%a = foo(a=%x, b=%y)
%b = foo1(a=%x, b=%y)
%c = foo2(c=%a, d=%b)
return(%c)
)");

  // Resolve every value the checks refer to up front. tryGetValue may return
  // nullptr, so ASSERT here: a missing/renamed value then fails the test
  // cleanly instead of crashing on a null dereference.
  auto* x = graph->tryGetValue("x");
  auto* y = graph->tryGetValue("y");
  auto* a = graph->tryGetValue("a");
  auto* b = graph->tryGetValue("b");
  auto* c = graph->tryGetValue("c");
  ASSERT_NE(x, nullptr);
  ASSERT_NE(y, nullptr);
  ASSERT_NE(a, nullptr);
  ASSERT_NE(b, nullptr);
  ASSERT_NE(c, nullptr);

  {
    auto plan = ExecutionPlanner{*graph}.createPlan();
    const auto& values_to_free = plan->valuesToFree;
    // ASSERT (not EXPECT): the checks below index entries 0..4 directly, so
    // continuing after a size mismatch would read out of bounds.
    ASSERT_EQ(values_to_free.size(), 5u);
    for (const auto i : c10::irange(3)) {
      EXPECT_TRUE(values_to_free[i].empty());
    }

    // %a and %b are last used by foo2; both are expected at entry 3.
    EXPECT_EQ(values_to_free[3].size(), 2u);
    const std::set<int64_t> ids(
        values_to_free[3].begin(), values_to_free[3].end());
    EXPECT_EQ(ids, (std::set<int64_t>{a->id(), b->id()}));

    EXPECT_TRUE(values_to_free[4].empty());
  }

  {
    const auto static_values = ExecutionPlanner::staticValues(*graph);
    const std::set<int64_t> static_ids(
        static_values.begin(), static_values.end());
    EXPECT_EQ(static_ids, (std::set<int64_t>{x->id(), y->id(), c->id()}));
  }
}

} // namespace torch::nativert