mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Add a function in shape inference that tries to infer the output tensor content based on
the input shapes of the op. In some cases (e.g., the Shape op), knowing the shapes of the inputs is all that is necessary to infer the content of the output tensor. This improves shape inference. PiperOrigin-RevId: 158079306
This commit is contained in:
parent
0cc851c08f
commit
a58553e4db
|
|
@ -1580,6 +1580,7 @@ tf_cuda_library(
|
|||
":lib_internal",
|
||||
":proto_text",
|
||||
":protos_all_cc",
|
||||
"//tensorflow/core/kernels:bounds_check",
|
||||
"//tensorflow/core/kernels:required",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/kernels/bounds_check.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/gtl/stl_util.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
|
|
@ -256,6 +257,85 @@ Status ShapeRefiner::EvaluateConstantTensorForEdge(const Node* node,
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Tries to materialize the tensor flowing across 'edge' using only the shape
// metadata already inferred for the edge's source node. For the "Shape",
// "Rank" and "Size" ops, fully-known input shape information is sufficient to
// compute the output tensor without executing the graph.
//
// On success, '*output' holds the inferred tensor and '*success' is set to
// true. If the op is not one of the handled kinds, or its input shape
// metadata is incomplete, returns OK with '*success' == false. Returns an
// error status if the source node has no inference context or the result
// does not fit in the op's output dtype.
Status ShapeRefiner::TryToInferTensorOutputFromInputShapes(const Edge* edge,
                                                           Tensor* output,
                                                           bool* success) {
  *success = false;
  const Node* node = edge->src();
  auto it = node_to_context_.find(node);
  if (it == node_to_context_.end()) {
    return errors::FailedPrecondition("Node does not have context.");
  }
  InferenceContext* c = it->second.get();

  if (node->def().op() == "Shape") {
    // If the input shape is fully defined (rank and every dimension known),
    // the output of Shape is just the vector of dimension sizes.
    if (c->FullyDefined(c->input(0))) {
      int input_rank = c->Rank(c->input(0));
      Tensor t(node->output_type(0), TensorShape({input_rank}));
      if (node->output_type(0) == DT_INT32) {
        auto flat = t.flat<int32>();
        for (int i = 0; i < input_rank; i++) {
          int64 dimension = c->Value(c->Dim(c->input(0), i));
          if (!FastBoundsCheck(dimension, std::numeric_limits<int32>::max())) {
            return errors::FailedPrecondition(
                "Shape has output type int32, but dimension exceeds maximum "
                "int32 value");
          }
          flat(i) = static_cast<int32>(dimension);
        }
      } else if (node->output_type(0) == DT_INT64) {
        auto flat = t.flat<int64>();
        for (int i = 0; i < input_rank; i++) {
          flat(i) = c->Value(c->Dim(c->input(0), i));
        }
      } else {
        return errors::FailedPrecondition(
            "Shape has output type that is not int32 or int64");
      }
      *output = t;
      *success = true;
    }
  } else if (node->def().op() == "Rank") {
    // Rank only needs the rank of the input, not the dimension values.
    // The Rank op always produces an int32 scalar.
    if (c->RankKnown(c->input(0))) {
      int32 input_rank = c->Rank(c->input(0));
      Tensor t(node->output_type(0), TensorShape({}));
      t.flat<int32>()(0) = input_rank;
      *output = t;
      *success = true;
    }
  } else if (node->def().op() == "Size") {
    if (c->FullyDefined(c->input(0))) {
      int32 rank = c->Rank(c->input(0));
      Tensor t(node->output_type(0), TensorShape({}));
      int64 size = 1;
      for (int i = 0; i < rank; i++) {
        const int64 dim = c->Value(c->Dim(c->input(0), i));
        // Guard against silent int64 overflow while accumulating the product
        // of the dimension sizes; otherwise a wrong constant could be fed
        // downstream.
        if (dim != 0 && size > std::numeric_limits<int64>::max() / dim) {
          return errors::FailedPrecondition(
              "Size exceeds maximum int64 value");
        }
        size *= dim;
      }
      if (node->output_type(0) == DT_INT32) {
        if (!FastBoundsCheck(size, std::numeric_limits<int32>::max())) {
          return errors::FailedPrecondition(
              "Size has output type int32, but size exceeds maximum int32 "
              "value");
        }
        t.flat<int32>()(0) = static_cast<int32>(size);
      } else if (node->output_type(0) == DT_INT64) {
        t.flat<int64>()(0) = size;
      } else {
        return errors::FailedPrecondition(
            "Size has output type that is not int32 or int64");
      }
      *output = t;
      *success = true;
    }
  }
  return Status::OK();
}
|
||||
|
||||
Status ShapeRefiner::ExtractConstantSubgraph(
|
||||
Node* target_node, Graph* out_graph, bool* is_constant_graph,
|
||||
std::vector<std::pair<string, Tensor>>* const_inputs) {
|
||||
|
|
@ -356,15 +436,27 @@ Status ShapeRefiner::ExtractConstantSubgraph(
|
|||
dst_copy, current_edge->dst_input());
|
||||
}
|
||||
|
||||
// If we have a copy of the input tensor materialized already,
|
||||
// then add to the list of inputs to feed and do not recurse further.
|
||||
const string& output_tensor_name =
|
||||
strings::StrCat(current_node->name(), ":", current_edge->src_output());
|
||||
|
||||
// Some tensor values can be inferred. For example, a shape op
|
||||
// with input shapes fully defined can have its output tensor inferred.
|
||||
Tensor tensor_inferred;
|
||||
bool successfully_inferred_tensor = false;
|
||||
TF_RETURN_IF_ERROR(TryToInferTensorOutputFromInputShapes(
|
||||
current_edge, &tensor_inferred, &successfully_inferred_tensor));
|
||||
if (successfully_inferred_tensor) {
|
||||
const_inputs->emplace_back(output_tensor_name, tensor_inferred);
|
||||
const_inputs_added.insert(output_tensor_name);
|
||||
continue;
|
||||
}
|
||||
|
||||
// If we have a copy of the input tensor materialized already,
|
||||
// then add to the list of inputs to feed and do not recurse further.
|
||||
auto it = const_tensor_map_.find(output_tensor_name);
|
||||
if (it != const_tensor_map_.end() &&
|
||||
const_inputs_added.count(output_tensor_name) == 0) {
|
||||
const_inputs->emplace_back(
|
||||
std::make_pair(output_tensor_name, it->second));
|
||||
const_inputs->emplace_back(output_tensor_name, it->second);
|
||||
const_inputs_added.insert(output_tensor_name);
|
||||
continue;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -78,6 +78,13 @@ class ShapeRefiner {
|
|||
}
|
||||
|
||||
private:
|
||||
// Tries to infer tensor output based on the input shapes of the node. In some
|
||||
// cases, the shapes of the inputs are sufficient for inferring the contents
|
||||
// of the output tensor. For example, a Shape op with fully defined input
|
||||
// shapes can have its output tensor inferred.
|
||||
Status TryToInferTensorOutputFromInputShapes(const Edge* edge, Tensor* output,
|
||||
bool* success);
|
||||
|
||||
// Extracts the subgraph ending at 'node' that is statically
|
||||
// computable and inserts into 'out_graph'. If statically computable,
|
||||
// 'is_constant_graph' will be true.
|
||||
|
|
|
|||
|
|
@ -268,8 +268,234 @@ REGISTER_OP("ShapeData")
|
|||
return Status::OK();
|
||||
});
|
||||
|
||||
// Test-only op: treats the values of its (constant) int64 input tensor as
// the dimension sizes of its output shape; yields an unknown shape when the
// input is not a constant.
REGISTER_OP("ShapeDataInt64")
    .Input("a: int64")
    .Output("o: int64")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      const Tensor* dims_tensor = c->input_tensor(0);
      if (dims_tensor == nullptr) {
        // Input value is unavailable at graph-construction time.
        return shape_inference::UnknownShape(c);
      }

      const int64 num_dims = dims_tensor->NumElements();
      std::vector<shape_inference::DimensionHandle> dims;
      dims.reserve(num_dims);
      for (int64 i = 0; i < num_dims; ++i) {
        dims.push_back(c->MakeDim(dims_tensor->flat<int64>()(i)));
      }
      c->set_output(0, c->MakeShape(dims));
      return Status::OK();
    });
|
||||
|
||||
} // namespace
|
||||
|
||||
// A Shape node with a fully defined input shape should have its output
// tensor content inferred, letting the value propagate through Slice into
// the shape function of a downstream op.
TEST(ShapeRefinerTest, PropagateShapeAcrossTensorContent) {
  Scope root = Scope::NewRootScope();

  // 2x4 variable; its Shape output is the vector [2, 4].
  auto var = ops::Variable(root, {2, 4}, DT_INT32);
  auto var_shape = ops::Shape(root, var);

  // Slice out the second element of the shape vector (the 4).
  auto begin = ops::Const(root, {1});
  auto slice = ops::Slice(root, var_shape, begin, begin);

  Node* consumer;
  TF_ASSERT_OK(NodeBuilder("Test", "ShapeData")
                   .Input(slice.node())
                   .Finalize(root.graph(), &consumer));

  ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
  TF_ASSERT_OK(refiner.AddNode(begin.node()));
  TF_ASSERT_OK(refiner.AddNode(var.node()));
  TF_ASSERT_OK(refiner.AddNode(var_shape.node()));
  TF_ASSERT_OK(refiner.AddNode(slice.node()));
  TF_ASSERT_OK(refiner.AddNode(consumer));

  shape_inference::InferenceContext* ic = refiner.GetContext(consumer);
  EXPECT_EQ("[4]", ic->DebugString(ic->output(0)));
}
|
||||
|
||||
// Same propagation as above, but one dimension does not fit in int32, so
// Shape is configured with an int64 output type.
TEST(ShapeRefinerTest, PropagateShapeAcrossTensorContentInt64) {
  Scope root = Scope::NewRootScope();

  // Variable whose last dimension exceeds the int32 range.
  auto var = ops::Variable(
      root, {2, 4, static_cast<int64>(std::numeric_limits<int32>::max()) * 2},
      DT_INT64);

  auto attrs = ops::Shape::OutType(DT_INT64);
  auto var_shape = ops::Shape(root, var, attrs);

  // Slice out the second element of the shape vector (the 4).
  auto begin = ops::Const(root, {1});
  auto slice = ops::Slice(root, var_shape, begin, begin);

  Node* consumer;
  TF_ASSERT_OK(NodeBuilder("Test", "ShapeDataInt64")
                   .Input(slice.node())
                   .Finalize(root.graph(), &consumer));

  ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
  TF_ASSERT_OK(refiner.AddNode(begin.node()));
  TF_ASSERT_OK(refiner.AddNode(var.node()));
  TF_ASSERT_OK(refiner.AddNode(var_shape.node()));
  TF_ASSERT_OK(refiner.AddNode(slice.node()));
  TF_ASSERT_OK(refiner.AddNode(consumer));

  shape_inference::InferenceContext* ic = refiner.GetContext(consumer);
  EXPECT_EQ("[4]", ic->DebugString(ic->output(0)));
}
|
||||
|
||||
// One dimension exceeds int32, but Shape uses the default int32 output type,
// so inferring the shape tensor must fail.
TEST(ShapeRefinerTest, PropagateShapeAcrossTensorContentInt32Overflow) {
  Scope root = Scope::NewRootScope();

  // Variable whose last dimension exceeds the int32 range.
  auto var = ops::Variable(
      root, {2, 4, static_cast<int64>(std::numeric_limits<int32>::max()) * 2},
      DT_INT32);
  auto var_shape = ops::Shape(root, var);

  // Slice out the second element of the shape vector.
  auto begin = ops::Const(root, {1});
  auto slice = ops::Slice(root, var_shape, begin, begin);

  Node* consumer;
  TF_ASSERT_OK(NodeBuilder("Test", "ShapeData")
                   .Input(slice.node())
                   .Finalize(root.graph(), &consumer));

  ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
  TF_ASSERT_OK(refiner.AddNode(begin.node()));
  TF_ASSERT_OK(refiner.AddNode(var.node()));
  TF_ASSERT_OK(refiner.AddNode(var_shape.node()));
  TF_ASSERT_OK(refiner.AddNode(slice.node()));

  // Adding the consumer should surface the int32 overflow as an error.
  EXPECT_FALSE(refiner.AddNode(consumer).ok());
}
|
||||
|
||||
// A Rank node with a known-rank input should have its output tensor content
// inferred and propagated through Identity.
TEST(ShapeRefinerTest, PropagateRankAcrossTensorContent) {
  Scope root = Scope::NewRootScope();

  // The rank of a 2x4x3 variable is 3.
  auto var = ops::Variable(root, {2, 4, 3}, DT_INT32);
  auto rank_op = ops::Rank(root, var);
  auto passthrough = ops::Identity(root, rank_op);

  Node* consumer;
  TF_ASSERT_OK(NodeBuilder("Test", "ShapeData")
                   .Input(passthrough.node())
                   .Finalize(root.graph(), &consumer));

  ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
  TF_ASSERT_OK(refiner.AddNode(var.node()));
  TF_ASSERT_OK(refiner.AddNode(rank_op.node()));
  TF_ASSERT_OK(refiner.AddNode(passthrough.node()));
  TF_ASSERT_OK(refiner.AddNode(consumer));

  shape_inference::InferenceContext* ic = refiner.GetContext(consumer);
  EXPECT_EQ("[3]", ic->DebugString(ic->output(0)));
}
|
||||
|
||||
// A Size node with a fully defined input shape should have its output tensor
// content inferred and propagated through Identity.
TEST(ShapeRefinerTest, PropagateSizeAcrossTensorContent) {
  Scope root = Scope::NewRootScope();

  // Size of a {1,2,3,4,5} variable is 5! = 120.
  auto var = ops::Variable(root, {1, 2, 3, 4, 5}, DT_INT32);
  auto size_op = ops::Size(root, var);
  auto passthrough = ops::Identity(root, size_op);

  Node* consumer;
  TF_ASSERT_OK(NodeBuilder("Test", "ShapeData")
                   .Input(passthrough.node())
                   .Finalize(root.graph(), &consumer));

  ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
  TF_ASSERT_OK(refiner.AddNode(var.node()));
  TF_ASSERT_OK(refiner.AddNode(size_op.node()));
  TF_ASSERT_OK(refiner.AddNode(passthrough.node()));
  TF_ASSERT_OK(refiner.AddNode(consumer));

  shape_inference::InferenceContext* ic = refiner.GetContext(consumer);
  EXPECT_EQ("[120]", ic->DebugString(ic->output(0)));
}
|
||||
|
||||
// Size that only fits in int64: 5! * (int32 max * 2) = 515396075280.
TEST(ShapeRefinerTest, PropagateSizeAcrossTensorContentInt64) {
  Scope root = Scope::NewRootScope();

  // Variable whose element count exceeds the int32 range.
  auto var =
      ops::Variable(root,
                    {1, 2, 3, 4, 5,
                     static_cast<int64>(std::numeric_limits<int32>::max()) * 2},
                    DT_INT64);

  // Request an int64 output so the size fits.
  auto attrs = ops::Size::OutType(DT_INT64);
  auto size_op = ops::Size(root, var, attrs);
  auto passthrough = ops::Identity(root, size_op);

  Node* consumer;
  TF_ASSERT_OK(NodeBuilder("Test", "ShapeDataInt64")
                   .Input(passthrough.node())
                   .Finalize(root.graph(), &consumer));

  ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
  TF_ASSERT_OK(refiner.AddNode(var.node()));
  TF_ASSERT_OK(refiner.AddNode(size_op.node()));
  TF_ASSERT_OK(refiner.AddNode(passthrough.node()));
  TF_ASSERT_OK(refiner.AddNode(consumer));

  shape_inference::InferenceContext* ic = refiner.GetContext(consumer);
  EXPECT_EQ("[515396075280]", ic->DebugString(ic->output(0)));
}
|
||||
|
||||
// The element count exceeds int32, but Size uses the default int32 output
// type, so inferring the size tensor must fail.
TEST(ShapeRefinerTest, PropagateSizeAcrossTensorContentInt32Overflow) {
  Scope root = Scope::NewRootScope();

  // Variable whose element count exceeds the int32 range.
  auto var =
      ops::Variable(root,
                    {1, 2, 3, 4, 5,
                     static_cast<int64>(std::numeric_limits<int32>::max()) * 2},
                    DT_INT32);
  auto size_op = ops::Size(root, var);
  auto passthrough = ops::Identity(root, size_op);

  Node* consumer;
  TF_ASSERT_OK(NodeBuilder("Test", "ShapeData")
                   .Input(passthrough.node())
                   .Finalize(root.graph(), &consumer));

  ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
  TF_ASSERT_OK(refiner.AddNode(var.node()));
  TF_ASSERT_OK(refiner.AddNode(size_op.node()));
  TF_ASSERT_OK(refiner.AddNode(passthrough.node()));

  // Adding the consumer should surface the int32 overflow as an error.
  EXPECT_FALSE(refiner.AddNode(consumer).ok());
}
|
||||
|
||||
TEST(ShapeRefinerTest, PropagateShape) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
// 3x2 input
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user