#include #include namespace torch::nativert { TEST(ExecutionPlannerTest, CreatePlan) { auto graph = stringToGraph(R"( graph(%x, %y): %a = foo(a=%x, b=%y) %b = foo1(a=%x, b=%y) %c = foo2(c=%a, d=%b) return(%c) )"); { auto plan = ExecutionPlanner{*graph}.createPlan(); auto& values_to_free = plan->valuesToFree; EXPECT_EQ(values_to_free.size(), 5); for (const auto i : c10::irange(3)) { EXPECT_TRUE(values_to_free[i].empty()); } EXPECT_EQ(values_to_free[3].size(), 2); std::set ids{values_to_free[3].begin(), values_to_free[3].end()}; EXPECT_EQ( ids, std::set( {graph->tryGetValue("a")->id(), graph->tryGetValue("b")->id()})); EXPECT_EQ(values_to_free[4].size(), 0); } { auto static_values = ExecutionPlanner::staticValues(*graph); std::set static_ids{static_values.begin(), static_values.end()}; EXPECT_EQ( static_ids, std::set( {graph->tryGetValue("x")->id(), graph->tryGetValue("y")->id(), graph->tryGetValue("c")->id()})); } } } // namespace torch::nativert