mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/52512 This API is not used at all, and is tricky to maintain. When we last used it, we ran into lifetime issues from using `Value *` as the key. In hindsight, we should have been using `value->unique()`, but regardless, this API is not being used and should be removed. Test Plan: Imported from OSS Reviewed By: navahgar Differential Revision: D26696695 Pulled By: eellison fbshipit-source-id: 97ed92e88ecab0085fabbac46573611666bf2420
72 lines
2.0 KiB
C++
72 lines
2.0 KiB
C++
#include <gtest/gtest.h>
|
|
|
|
#include "test/cpp/jit/test_utils.h"
|
|
|
|
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
|
|
#include "torch/csrc/jit/passes/utils/subgraph_utils.h"
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
// Merges every top-level node of an LSTM graph into a single
// prim::DifferentiableGraph node, unmerges it again, and checks that the
// number of top-level nodes is unchanged.
TEST(SubgraphUtilsTest, Basic) {
  auto graph = build_lstm();
  EliminateCommonSubexpression(graph);

  // Snapshot of the top-level nodes before merging. Only the node count is
  // compared at the end of the test.
  std::vector<Node*> originalNodes(
      graph->nodes().begin(), graph->nodes().end());

  // Merge everything into a single subgraph, walking the node list backwards.
  // The first visited node seeds the subgraph; every later node is merged into
  // it. After each step the iterator is re-derived from the subgraph node
  // (presumably because the visited node no longer sits in the top-level list,
  // so advancing its own iterator is not safe).
  Node* subgraph = nullptr;
  for (auto it = graph->nodes().rbegin(); it != graph->nodes().rend();) {
    if (subgraph == nullptr) {
      subgraph = SubgraphUtils::createSingletonSubgraph(
          *it, prim::DifferentiableGraph);
    } else {
      SubgraphUtils::mergeNodeIntoSubgraph(*it, subgraph);
    }
    // Going back through the loop condition before the next dereference means
    // a single-node graph never dereferences rend().
    it = ++subgraph->reverseIterator();
  }
  // Guards unmergeSubgraph below against a graph that had no nodes to merge.
  ASSERT_TRUE(subgraph != nullptr);

  // Unmerge and compare the node count with the original listing.
  SubgraphUtils::unmergeSubgraph(subgraph);
  EliminateCommonSubexpression(graph);

  std::vector<Node*> newNodes(graph->nodes().begin(), graph->nodes().end());
  ASSERT_EQ(originalNodes.size(), newNodes.size());
}
|
|
|
|
// Exercises SubgraphUtils::generateNameForGraph.
// - With a generous length limit (80), the name is the "graph" prefix followed
//   by the aten op names in node order, joined by '_'.
// - With a tight limit (10), the name must be shortened.
TEST(SubgraphUtilsTest, GraphName) {
  auto graph = std::make_shared<Graph>();

  std::unordered_map<std::string, Value*> parse_map;
  parseIR(
      R"IR(
graph(%a : Tensor, %b : Tensor, %c : Tensor):
  %x : Tensor = aten::tanh(%a)
  %y : Tensor = aten::mul(%a, %b)
  %p : Tensor = aten::div(%c, %b)
  %q1 : Tensor = aten::mul(%p, %a)
  %q2 : Tensor = aten::tanh(%q1)
  %q3 : Tensor = aten::tanh(%q2)
  %q4 : Tensor = aten::tanh(%q3)
  %q5 : Tensor = aten::tanh(%q4)
  return (%x, %y, %q5))IR",
      graph.get(),
      parse_map);

  const std::string ref_full_name = "graph_tanh_mul_div_mul_tanh_tanh_tanh_tanh";
  const std::string full_name =
      SubgraphUtils::generateNameForGraph(graph, 80, "graph");
  ASSERT_EQ(full_name, ref_full_name);

  // The 10-character limit is far below the 42-character full name, so the
  // result must actually differ from it. The size comparison alone would also
  // pass if generateNameForGraph ignored the limit and returned the full name.
  const std::string truncated_name =
      SubgraphUtils::generateNameForGraph(graph, 10, "graph");
  ASSERT_NE(truncated_name, ref_full_name);
  ASSERT_LE(truncated_name.size(), ref_full_name.size());
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|