mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
* Add graph rewrite rule that removes repeated application of scalar unary ops that are involutions (their own inverse).
* Update rewrite rule for Transpose to also handle ConjugateTranspose. PiperOrigin-RevId: 173967184
This commit is contained in:
parent
ff5c276adf
commit
b46c196e9d
|
|
@ -31,6 +31,12 @@ namespace tensorflow {
|
|||
namespace grappler {
|
||||
namespace {
|
||||
|
||||
static bool IsInvolution(const NodeDef& node) {
|
||||
const std::unordered_set<string> involution_ops = {"Conj", "Reciprocal",
|
||||
"Neg", "LogicalNot"};
|
||||
return involution_ops.count(node.op()) > 0;
|
||||
}
|
||||
|
||||
bool AreInversePermutations(gtl::ArraySlice<int32> a,
|
||||
gtl::ArraySlice<int32> b) {
|
||||
if (a.size() != b.size()) {
|
||||
|
|
@ -394,10 +400,20 @@ void ArithmeticOptimizer::DedupComputations(GraphDef* optimized_graph) const {
|
|||
string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
|
||||
const NodeDef* node, GraphDef* graph_def, NodeMap* node_map,
|
||||
std::vector<const NodeDef*>* new_nodes) const {
|
||||
// Remove inverse transposes.
|
||||
if (node->op() == "Transpose") {
|
||||
// Remove involutions applied twice.
|
||||
if (IsInvolution(*node)) {
|
||||
// An involution is a function f(x) that is its own inverse,
|
||||
// i.e. f(f(x)) = x.
|
||||
const NodeDef* input = node_map->GetNode(node->input(0));
|
||||
if (input->op() == "Transpose") {
|
||||
if (input->op() == node->op()) {
|
||||
return input->input(0);
|
||||
}
|
||||
}
|
||||
|
||||
// Remove inverse transposes.
|
||||
if (node->op() == "Transpose" || node->op() == "ConjugateTranspose") {
|
||||
const NodeDef* input = node_map->GetNode(node->input(0));
|
||||
if (input->op() == node->op()) {
|
||||
const NodeDef* node_perm = node_map->GetNode(node->input(1));
|
||||
const NodeDef* input_perm = node_map->GetNode(input->input(1));
|
||||
std::vector<int> node_perm_values;
|
||||
|
|
|
|||
|
|
@ -109,6 +109,28 @@ TEST_F(ArithmeticOptimizerTest, OpDedupCommutative) {
|
|||
EXPECT_EQ("add1", new_add3.input(1));
|
||||
}
|
||||
|
||||
TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsReal) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
|
||||
Output neg1 = ops::Neg(s.WithOpName("neg1"), c);
|
||||
Output neg2 = ops::Neg(s.WithOpName("neg2"), neg1);
|
||||
Output recip1 = ops::Reciprocal(s.WithOpName("recip1"), neg2);
|
||||
Output recip2 = ops::Reciprocal(s.WithOpName("recip2"), recip1);
|
||||
Output id = ops::Identity(s.WithOpName("id"), recip2);
|
||||
GrapplerItem item;
|
||||
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||
|
||||
ArithmeticOptimizer optimizer;
|
||||
GraphDef output;
|
||||
Status status = optimizer.Optimize(nullptr, item, &output);
|
||||
TF_EXPECT_OK(status);
|
||||
|
||||
EXPECT_EQ(6, output.node_size());
|
||||
EXPECT_EQ("c", output.node(1).input(0));
|
||||
EXPECT_EQ("c", output.node(3).input(0));
|
||||
EXPECT_EQ("c", output.node(5).input(0));
|
||||
}
|
||||
|
||||
TEST_F(ArithmeticOptimizerTest, IdentityReshape) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output inputs =
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user