Better handle nodes with a variable number of outputs

PiperOrigin-RevId: 158435028
This commit is contained in:
Benoit Steiner 2017-06-08 13:39:48 -07:00 committed by TensorFlower Gardener
parent bc3e20807a
commit f0e185d1f5
3 changed files with 84 additions and 14 deletions

View File

@ -110,15 +110,17 @@ cc_test(
deps = [
":constant_folding",
"//tensorflow/cc:cc_ops",
"//tensorflow/core:all_kernels",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:direct_session",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
],
)

View File

@ -90,17 +90,6 @@ class DeviceSimple : public DeviceBase {
std::unique_ptr<Eigen::ThreadPoolDevice> eigen_device_;
};
Status NumOutputs(const NodeDef& node, int* num_outputs) {
const OpDef* op_def = nullptr;
TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node.op(), &op_def));
if (node.op() == "ConcatOffset") {
*num_outputs = node.attr().at("N").i();
} else {
*num_outputs = op_def->output_arg_size();
}
return Status::OK();
}
string AsControlDependency(const NodeDef& node) {
return strings::StrCat("^", node.name());
}
@ -341,15 +330,16 @@ Status ConstantFolding::EvaluateNode(const NodeDef& node,
params.frame_iter = FrameAndIter(0, 0);
params.inputs = &inputs;
params.op_kernel = op_kernel.get();
int num_outputs;
TF_RETURN_IF_ERROR(NumOutputs(node, &num_outputs));
gtl::InlinedVector<AllocatorAttributes, 4> output_attrs;
const int num_outputs = op_kernel->num_outputs();
for (int i = 0; i < num_outputs; i++) {
AllocatorAttributes attr;
attr.set_on_host(true);
output_attrs.push_back(attr);
}
params.output_attr_array = output_attrs.data();
OpKernelContext op_context(&params);
op_kernel->Compute(&op_context);
for (int i = 0; i < num_outputs; i++) {

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session.h"
@ -193,6 +194,83 @@ TEST_F(ConstantFoldingTest, ControlDependencies) {
EXPECT_EQ(2, found);
}
TEST_F(ConstantFoldingTest, VariableNumberOfOutputs) {
tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
// Add a DynamicPartition node to the graph
Output input = ops::Const(scope.WithOpName("in0"), 314, {3, 4, 5});
Output indices = ops::Const(scope.WithOpName("indices"), 1, {3, 4});
int num_partitions = 4;
ops::DynamicPartition part(scope.WithOpName("partition"), input, indices,
num_partitions);
std::vector<string> outputs;
for (int i = 0; i < num_partitions; ++i) {
string part_out_name = strings::StrCat("part_out", i);
ops::Identity partition_out(scope.WithOpName(part_out_name),
{part.outputs[i]});
outputs.push_back(part_out_name);
}
GrapplerItem item;
TF_CHECK_OK(scope.ToGraphDef(&item.graph));
// Add a ConcatOffset node to the graph
Tensor initial_val(DT_INT32, TensorShape({3}));
test::FillIota<int>(&initial_val, 7);
for (int i = 1; i < 5; ++i) {
TF_CHECK_OK(NodeDefBuilder(strings::StrCat("in", i), "Const")
.Attr("dtype", DT_INT32)
.Attr("value", initial_val)
.Finalize(item.graph.add_node()));
}
Tensor concat_dim(DT_INT32, TensorShape({}));
test::FillIota<int>(&concat_dim, 0);
TF_CHECK_OK(NodeDefBuilder("concat_dim", "Const")
.Attr("dtype", DT_INT32)
.Attr("value", concat_dim)
.Finalize(item.graph.add_node()));
TF_CHECK_OK(NodeDefBuilder("concat_offsets", "ConcatOffset")
.Input("concat_dim", 0, DT_INT32)
.Input({NodeDefBuilder::NodeOut("in1", 0, DT_INT32),
NodeDefBuilder::NodeOut("in2", 0, DT_INT32),
NodeDefBuilder::NodeOut("in3", 0, DT_INT32),
NodeDefBuilder::NodeOut("in4", 0, DT_INT32)})
.Finalize(item.graph.add_node()));
for (int i = 0; i < 4; ++i) {
string concat_offset_out_name = strings::StrCat("concat_offset_out", i);
TF_CHECK_OK(NodeDefBuilder(concat_offset_out_name, "Identity")
.Attr("T", DT_INT32)
.Input("concat_offsets", i, DT_INT32)
.Finalize(item.graph.add_node()));
outputs.push_back(concat_offset_out_name);
}
item.fetch = outputs;
ConstantFolding fold;
GraphDef output;
Status status = fold.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
int constant_folded = 0;
for (const auto& node : output.node()) {
if (node.name().find("ConstantFolding/partition") != string::npos ||
node.name().find("ConstantFolding/concat_offsets") != string::npos) {
++constant_folded;
EXPECT_EQ("Const", node.op());
}
}
EXPECT_EQ(8, constant_folded);
auto expected = EvaluateNodes(item.graph, outputs);
auto optimized = EvaluateNodes(output, outputs);
ASSERT_EQ(expected.size(), optimized.size());
for (int i = 0; i < expected.size(); ++i) {
test::ExpectTensorEqual<int>(expected[i], optimized[i]);
}
}
TEST_F(ConstantFoldingTest, ShapeMaterialization) {
tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
Output v1 = ops::Variable(scope.WithOpName("v1"), {3}, DT_FLOAT);