Add a function to shape inference that tries to infer the content of an op's
output tensor from the shapes of its inputs. In some cases (e.g., Shape),
knowing the input shapes is all that is necessary to infer the content of the
output tensor. This improves shape inference.
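
A minimal sketch of the idea (illustrative only, using the ops API exercised in
the tests below): when an input's shape is fully defined, the outputs of Shape,
Rank, and Size are constants that can be inferred without evaluating the input.

  auto input = ops::Variable(root, {2, 4}, DT_INT32);
  auto shape = ops::Shape(root, input);  // contents inferable as [2, 4]
  auto rank = ops::Rank(root, input);    // contents inferable as 2
  auto size = ops::Size(root, input);    // contents inferable as 8 (2 * 4)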

PiperOrigin-RevId: 158079306
A. Unique TensorFlower 2017-06-05 16:39:45 -07:00 committed by TensorFlower Gardener
parent 0cc851c08f
commit a58553e4db
4 changed files with 330 additions and 4 deletions


@@ -1580,6 +1580,7 @@ tf_cuda_library(
":lib_internal",
":proto_text",
":protos_all_cc",
"//tensorflow/core/kernels:bounds_check",
"//tensorflow/core/kernels:required",
"//third_party/eigen3",
],


@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/public/session.h"
@@ -256,6 +257,85 @@ Status ShapeRefiner::EvaluateConstantTensorForEdge(const Node* node,
return Status::OK();
}

Status ShapeRefiner::TryToInferTensorOutputFromInputShapes(const Edge* edge,
Tensor* output,
bool* success) {
*success = false;
const Node* node = edge->src();
auto it = node_to_context_.find(node);
if (it == node_to_context_.end()) {
return errors::FailedPrecondition("Node does not have context.");
}
InferenceContext* c = it->second.get();
if (node->def().op() == "Shape") {
// If input shapes to the shape op are fully defined,
// we can infer the shape op's output tensor.
bool fully_defined_inputs = c->FullyDefined(c->input(0));
if (fully_defined_inputs) {
int input_rank = c->Rank(c->input(0));
Tensor t(node->output_type(0), TensorShape({input_rank}));
if (node->output_type(0) == DT_INT32) {
auto flat = t.flat<int>();
for (int i = 0; i < input_rank; i++) {
int64 dimension = c->Value(c->Dim(c->input(0), i));
if (!FastBoundsCheck(dimension, std::numeric_limits<int32>::max())) {
return errors::FailedPrecondition(
"Shape has output type int32, but dimension exceeds maximum "
"int32 value");
}
flat(i) = static_cast<int32>(dimension);
}
} else if (node->output_type(0) == DT_INT64) {
auto flat = t.flat<int64>();
for (int i = 0; i < input_rank; i++) {
flat(i) = c->Value(c->Dim(c->input(0), i));
}
} else {
return errors::FailedPrecondition(
"Shape has output type that is not int32 or int64");
}
*output = t;
*success = true;
}
} else if (node->def().op() == "Rank") {
bool rank_known = c->RankKnown(c->input(0));
if (rank_known) {
int32 input_rank = c->Rank(c->input(0));
Tensor t(node->output_type(0), TensorShape({}));
t.flat<int32>()(0) = input_rank;
*output = t;
*success = true;
}
} else if (node->def().op() == "Size") {
bool fully_defined_inputs = c->FullyDefined(c->input(0));
if (fully_defined_inputs) {
int32 rank = c->Rank(c->input(0));
Tensor t(node->output_type(0), TensorShape({}));
int64 size = 1;
for (int i = 0; i < rank; i++) {
size *= c->Value(c->Dim(c->input(0), i));
}
if (node->output_type(0) == DT_INT32) {
if (!FastBoundsCheck(size, std::numeric_limits<int32>::max())) {
return errors::FailedPrecondition(
"Size has output type int32, but size exceeds maximum int32 "
"value");
}
t.flat<int32>()(0) = static_cast<int32>(size);
} else if (node->output_type(0) == DT_INT64) {
t.flat<int64>()(0) = size;
} else {
return errors::FailedPrecondition(
"Size has output type that is not int32 or int64");
}
*output = t;
*success = true;
}
}
return Status::OK();
}

Status ShapeRefiner::ExtractConstantSubgraph(
Node* target_node, Graph* out_graph, bool* is_constant_graph,
std::vector<std::pair<string, Tensor>>* const_inputs) {
@@ -356,15 +436,27 @@ Status ShapeRefiner::ExtractConstantSubgraph(
dst_copy, current_edge->dst_input());
}
const string& output_tensor_name =
strings::StrCat(current_node->name(), ":", current_edge->src_output());
// Some tensor values can be inferred. For example, a shape op
// with input shapes fully defined can have its output tensor inferred.
Tensor tensor_inferred;
bool successfully_inferred_tensor = false;
TF_RETURN_IF_ERROR(TryToInferTensorOutputFromInputShapes(
current_edge, &tensor_inferred, &successfully_inferred_tensor));
if (successfully_inferred_tensor) {
const_inputs->emplace_back(output_tensor_name, tensor_inferred);
const_inputs_added.insert(output_tensor_name);
continue;
}
// If we have a copy of the input tensor materialized already,
// then add to the list of inputs to feed and do not recurse further.
auto it = const_tensor_map_.find(output_tensor_name);
if (it != const_tensor_map_.end() &&
const_inputs_added.count(output_tensor_name) == 0) {
const_inputs->emplace_back(output_tensor_name, it->second);
const_inputs_added.insert(output_tensor_name);
continue;
}


@@ -78,6 +78,13 @@ class ShapeRefiner {
}
private:
// Tries to infer the node's output tensor from its input shapes. In some
// cases, the shapes of the inputs are sufficient to infer the contents of
// the output tensor. For example, a Shape op with fully defined input
// shapes can have its output tensor inferred.
Status TryToInferTensorOutputFromInputShapes(const Edge* edge, Tensor* output,
bool* success);

// Extracts the subgraph ending at 'node' that is statically
// computable and inserts into 'out_graph'. If statically computable,
// 'is_constant_graph' will be true.


@@ -268,8 +268,234 @@ REGISTER_OP("ShapeData")
return Status::OK();
});

REGISTER_OP("ShapeDataInt64")
.Input("a: int64")
.Output("o: int64")
.SetShapeFn([](shape_inference::InferenceContext* c) {
const Tensor* shape_data = c->input_tensor(0);
if (shape_data == nullptr) {
return shape_inference::UnknownShape(c);
}
std::vector<shape_inference::DimensionHandle> dims;
dims.reserve(shape_data->NumElements());
for (int i = 0; i < shape_data->NumElements(); ++i) {
dims.emplace_back(c->MakeDim(shape_data->flat<int64>()(i)));
}
c->set_output(0, c->MakeShape(dims));
return Status::OK();
});

}  // namespace

TEST(ShapeRefinerTest, PropagateShapeAcrossTensorContent) {
Scope root = Scope::NewRootScope();
// Create a 2x4 variable tensor.
auto input = ops::Variable(root, {2, 4}, DT_INT32);
// Shape is a vector of 2 elements: (2, 4).
auto shape = ops::Shape(root, input);
// Ones for the begin index and size of the slice (to get the 4).
auto ones = ops::Const(root, {1});
// Slice out one element of the shape (the 4).
auto sliced = ops::Slice(root, shape, ones, ones);
Node* shape_data;
TF_ASSERT_OK(NodeBuilder("Test", "ShapeData")
.Input(sliced.node())
.Finalize(root.graph(), &shape_data));
ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
TF_ASSERT_OK(m.AddNode(ones.node()));
TF_ASSERT_OK(m.AddNode(input.node()));
TF_ASSERT_OK(m.AddNode(shape.node()));
TF_ASSERT_OK(m.AddNode(sliced.node()));
TF_ASSERT_OK(m.AddNode(shape_data));
shape_inference::InferenceContext* ctx = m.GetContext(shape_data);
EXPECT_EQ("[4]", ctx->DebugString(ctx->output(0)));
}

TEST(ShapeRefinerTest, PropagateShapeAcrossTensorContentInt64) {
Scope root = Scope::NewRootScope();
// Create a variable tensor whose last dimension exceeds the int32 range.
auto input = ops::Variable(
root, {2, 4, static_cast<int64>(std::numeric_limits<int32>::max()) * 2},
DT_INT64);
// Shape is a vector of 3 elements; the last does not fit in int32, so
// request an int64 output type.
auto attrs = ops::Shape::OutType(DT_INT64);
auto shape = ops::Shape(root, input, attrs);
// Ones for the begin index and size of the slice (to get the 4).
auto ones = ops::Const(root, {1});
// Slice out one element of the shape (the 4).
auto sliced = ops::Slice(root, shape, ones, ones);
Node* shape_data;
TF_ASSERT_OK(NodeBuilder("Test", "ShapeDataInt64")
.Input(sliced.node())
.Finalize(root.graph(), &shape_data));
ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
TF_ASSERT_OK(m.AddNode(ones.node()));
TF_ASSERT_OK(m.AddNode(input.node()));
TF_ASSERT_OK(m.AddNode(shape.node()));
TF_ASSERT_OK(m.AddNode(sliced.node()));
TF_ASSERT_OK(m.AddNode(shape_data));
shape_inference::InferenceContext* ctx = m.GetContext(shape_data);
EXPECT_EQ("[4]", ctx->DebugString(ctx->output(0)));
}

TEST(ShapeRefinerTest, PropagateShapeAcrossTensorContentInt32Overflow) {
Scope root = Scope::NewRootScope();
// Create a variable tensor whose last dimension exceeds the int32 range.
auto input = ops::Variable(
root, {2, 4, static_cast<int64>(std::numeric_limits<int32>::max()) * 2},
DT_INT32);
// Shape keeps the default int32 output type, but the last dimension does
// not fit in int32.
auto shape = ops::Shape(root, input);
// Ones for the begin index and size of the slice (to get the 4).
auto ones = ops::Const(root, {1});
// Slice out one element of the shape (the 4).
auto sliced = ops::Slice(root, shape, ones, ones);
Node* shape_data;
TF_ASSERT_OK(NodeBuilder("Test", "ShapeData")
.Input(sliced.node())
.Finalize(root.graph(), &shape_data));
ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
TF_ASSERT_OK(m.AddNode(ones.node()));
TF_ASSERT_OK(m.AddNode(input.node()));
TF_ASSERT_OK(m.AddNode(shape.node()));
TF_ASSERT_OK(m.AddNode(sliced.node()));
// Expect an error since there's an overflow.
EXPECT_FALSE(m.AddNode(shape_data).ok());
}

TEST(ShapeRefinerTest, PropagateRankAcrossTensorContent) {
Scope root = Scope::NewRootScope();
// Create a 2x4x3 variable tensor.
auto input = ops::Variable(root, {2, 4, 3}, DT_INT32);
// Rank 3.
auto rank = ops::Rank(root, input);
auto identity = ops::Identity(root, rank);
Node* shape_data;
TF_ASSERT_OK(NodeBuilder("Test", "ShapeData")
.Input(identity.node())
.Finalize(root.graph(), &shape_data));
ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
TF_ASSERT_OK(m.AddNode(input.node()));
TF_ASSERT_OK(m.AddNode(rank.node()));
TF_ASSERT_OK(m.AddNode(identity.node()));
TF_ASSERT_OK(m.AddNode(shape_data));
shape_inference::InferenceContext* ctx = m.GetContext(shape_data);
EXPECT_EQ("[3]", ctx->DebugString(ctx->output(0)));
}

TEST(ShapeRefinerTest, PropagateSizeAcrossTensorContent) {
Scope root = Scope::NewRootScope();
// Create a variable tensor.
auto input = ops::Variable(root, {1, 2, 3, 4, 5}, DT_INT32);
// Size is 1 * 2 * 3 * 4 * 5 = 5! = 120.
auto size = ops::Size(root, input);
auto identity = ops::Identity(root, size);
Node* shape_data;
TF_ASSERT_OK(NodeBuilder("Test", "ShapeData")
.Input(identity.node())
.Finalize(root.graph(), &shape_data));
ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
TF_ASSERT_OK(m.AddNode(input.node()));
TF_ASSERT_OK(m.AddNode(size.node()));
TF_ASSERT_OK(m.AddNode(identity.node()));
TF_ASSERT_OK(m.AddNode(shape_data));
shape_inference::InferenceContext* ctx = m.GetContext(shape_data);
EXPECT_EQ("[120]", ctx->DebugString(ctx->output(0)));
}

TEST(ShapeRefinerTest, PropagateSizeAcrossTensorContentInt64) {
Scope root = Scope::NewRootScope();
// Create a variable tensor.
auto input =
ops::Variable(root,
{1, 2, 3, 4, 5,
static_cast<int64>(std::numeric_limits<int32>::max()) * 2},
DT_INT64);
// Size is 5! * int32_max * 2, which requires an int64 output type.
auto attrs = ops::Size::OutType(DT_INT64);
auto size = ops::Size(root, input, attrs);
auto identity = ops::Identity(root, size);
Node* shape_data;
TF_ASSERT_OK(NodeBuilder("Test", "ShapeDataInt64")
.Input(identity.node())
.Finalize(root.graph(), &shape_data));
ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
TF_ASSERT_OK(m.AddNode(input.node()));
TF_ASSERT_OK(m.AddNode(size.node()));
TF_ASSERT_OK(m.AddNode(identity.node()));
TF_ASSERT_OK(m.AddNode(shape_data));
shape_inference::InferenceContext* ctx = m.GetContext(shape_data);
EXPECT_EQ("[515396075280]", ctx->DebugString(ctx->output(0)));
}
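
// Arithmetic check for the expected value above: 5! = 120 and
// 2 * (2^31 - 1) = 4294967294, so the inferred size is
// 120 * 4294967294 = 515396075280.
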
TEST(ShapeRefinerTest, PropagateSizeAcrossTensorContentInt32Overflow) {
Scope root = Scope::NewRootScope();
// Create a variable tensor.
auto input =
ops::Variable(root,
{1, 2, 3, 4, 5,
static_cast<int64>(std::numeric_limits<int32>::max()) * 2},
DT_INT32);
// Size is 5! * int32_max * 2, which overflows the default int32 output type.
auto size = ops::Size(root, input);
auto identity = ops::Identity(root, size);
Node* shape_data;
TF_ASSERT_OK(NodeBuilder("Test", "ShapeData")
.Input(identity.node())
.Finalize(root.graph(), &shape_data));
ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
TF_ASSERT_OK(m.AddNode(input.node()));
TF_ASSERT_OK(m.AddNode(size.node()));
TF_ASSERT_OK(m.AddNode(identity.node()));
EXPECT_FALSE(m.AddNode(shape_data).ok());
}

TEST(ShapeRefinerTest, PropagateShape) {
Scope root = Scope::NewRootScope();
// 3x2 input