* Add graph rewrite rule that removes repeated application of scalar unary ops that are involutions (their own inverse).

* Update rewrite rule for Transpose to also handle ConjugateTranspose.

PiperOrigin-RevId: 173967184
This commit is contained in:
A. Unique TensorFlower 2017-10-30 16:22:18 -07:00 committed by TensorFlower Gardener
parent ff5c276adf
commit b46c196e9d
2 changed files with 41 additions and 3 deletions

View File

@ -31,6 +31,12 @@ namespace tensorflow {
namespace grappler {
namespace {
static bool IsInvolution(const NodeDef& node) {
const std::unordered_set<string> involution_ops = {"Conj", "Reciprocal",
"Neg", "LogicalNot"};
return involution_ops.count(node.op()) > 0;
}
bool AreInversePermutations(gtl::ArraySlice<int32> a,
gtl::ArraySlice<int32> b) {
if (a.size() != b.size()) {
@ -394,10 +400,20 @@ void ArithmeticOptimizer::DedupComputations(GraphDef* optimized_graph) const {
string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
const NodeDef* node, GraphDef* graph_def, NodeMap* node_map,
std::vector<const NodeDef*>* new_nodes) const {
// Remove inverse transposes.
if (node->op() == "Transpose") {
// Remove involutions applied twice.
if (IsInvolution(*node)) {
// An involution is a function f(x) that is its own inverse,
// i.e. f(f(x)) = x.
const NodeDef* input = node_map->GetNode(node->input(0));
if (input->op() == "Transpose") {
if (input->op() == node->op()) {
return input->input(0);
}
}
// Remove inverse transposes.
if (node->op() == "Transpose" || node->op() == "ConjugateTranspose") {
const NodeDef* input = node_map->GetNode(node->input(0));
if (input->op() == node->op()) {
const NodeDef* node_perm = node_map->GetNode(node->input(1));
const NodeDef* input_perm = node_map->GetNode(input->input(1));
std::vector<int> node_perm_values;

View File

@ -109,6 +109,28 @@ TEST_F(ArithmeticOptimizerTest, OpDedupCommutative) {
EXPECT_EQ("add1", new_add3.input(1));
}
TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsReal) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
Output neg1 = ops::Neg(s.WithOpName("neg1"), c);
Output neg2 = ops::Neg(s.WithOpName("neg2"), neg1);
Output recip1 = ops::Reciprocal(s.WithOpName("recip1"), neg2);
Output recip2 = ops::Reciprocal(s.WithOpName("recip2"), recip1);
Output id = ops::Identity(s.WithOpName("id"), recip2);
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
ArithmeticOptimizer optimizer;
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
EXPECT_EQ(6, output.node_size());
EXPECT_EQ("c", output.node(1).input(0));
EXPECT_EQ("c", output.node(3).input(0));
EXPECT_EQ("c", output.node(5).input(0));
}
TEST_F(ArithmeticOptimizerTest, IdentityReshape) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output inputs =