Better handle nodes with a variable number of outputs
PiperOrigin-RevId: 158435028
This commit is contained in:
parent bc3e20807a
commit f0e185d1f5
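In outline, the change makes the constant folding optimizer stop deriving a node's output count from its OpDef (via a local NumOutputs helper that special-cased ConcatOffset's N attr) and instead read it from the OpKernel instantiated for that node, which already accounts for attr-dependent output lists such as DynamicPartition's. A minimal sketch of the two approaches, with illustrative helper names and assuming the usual TensorFlow core headers; this is not the literal patch:

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

// Old approach (removed below): count the output arguments declared in the
// OpDef. An output argument backed by a number_attr expands to several
// tensors on a concrete node, so such ops had to be special-cased by name.
Status NumOutputsFromOpDef(const NodeDef& node, int* num_outputs) {
  const OpDef* op_def = nullptr;
  TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node.op(), &op_def));
  *num_outputs = op_def->output_arg_size();
  return Status::OK();
}

// New approach: the kernel instantiated for this concrete NodeDef already
// knows how many output tensors the node produces, attrs included.
int NumOutputsFromKernel(const OpKernel& op_kernel) {
  return op_kernel.num_outputs();
}

}  // namespace tensorflow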
@@ -110,15 +110,17 @@ cc_test(
    deps = [
        ":constant_folding",
        "//tensorflow/cc:cc_ops",
        "//tensorflow/core:all_kernels",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:direct_session",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
        "//tensorflow/core/grappler:grappler_item",
        "//tensorflow/core/grappler:utils",
        "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
    ],
)

@@ -90,17 +90,6 @@ class DeviceSimple : public DeviceBase {
  std::unique_ptr<Eigen::ThreadPoolDevice> eigen_device_;
};

Status NumOutputs(const NodeDef& node, int* num_outputs) {
  const OpDef* op_def = nullptr;
  TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node.op(), &op_def));
  if (node.op() == "ConcatOffset") {
    *num_outputs = node.attr().at("N").i();
  } else {
    *num_outputs = op_def->output_arg_size();
  }
  return Status::OK();
}

string AsControlDependency(const NodeDef& node) {
  return strings::StrCat("^", node.name());
}

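For context on the removed helper above: when an op's output list length is driven by an attr, OpDef::output_arg_size() counts the declared output argument once rather than once per produced tensor, which is why ConcatOffset (N outputs) needed a special case and why other attr-driven ops would still be under-counted. DynamicPartition, exercised by the new test further below, is registered roughly as follows (illustrative excerpt from memory, not part of this patch):

// The single "outputs" argument expands to num_partitions tensors on a
// concrete node, so output_arg_size() reports 1 while the node actually
// produces num_partitions output tensors.
REGISTER_OP("DynamicPartition")
    .Input("data: T")
    .Input("partitions: int32")
    .Output("outputs: num_partitions * T")
    .Attr("num_partitions: int")
    .Attr("T: type");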
@@ -341,15 +330,16 @@ Status ConstantFolding::EvaluateNode(const NodeDef& node,
  params.frame_iter = FrameAndIter(0, 0);
  params.inputs = &inputs;
  params.op_kernel = op_kernel.get();
  int num_outputs;
  TF_RETURN_IF_ERROR(NumOutputs(node, &num_outputs));

  gtl::InlinedVector<AllocatorAttributes, 4> output_attrs;
  const int num_outputs = op_kernel->num_outputs();
  for (int i = 0; i < num_outputs; i++) {
    AllocatorAttributes attr;
    attr.set_on_host(true);
    output_attrs.push_back(attr);
  }
  params.output_attr_array = output_attrs.data();

  OpKernelContext op_context(&params);
  op_kernel->Compute(&op_context);
  for (int i = 0; i < num_outputs; i++) {

@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session.h"

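The test added below covers both flavors of variable-output node: a DynamicPartition op with num_partitions = 4 and a ConcatOffset op with four shape inputs (N = 4), each fanned out through Identity fetch nodes. After the optimizer runs, all eight folded outputs are expected to be Const nodes, and their values are checked against a direct evaluation of the original graph.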
@@ -193,6 +194,83 @@ TEST_F(ConstantFoldingTest, ControlDependencies) {
  EXPECT_EQ(2, found);
}

TEST_F(ConstantFoldingTest, VariableNumberOfOutputs) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  // Add a DynamicPartition node to the graph
  Output input = ops::Const(scope.WithOpName("in0"), 314, {3, 4, 5});
  Output indices = ops::Const(scope.WithOpName("indices"), 1, {3, 4});
  int num_partitions = 4;
  ops::DynamicPartition part(scope.WithOpName("partition"), input, indices,
                             num_partitions);

  std::vector<string> outputs;
  for (int i = 0; i < num_partitions; ++i) {
    string part_out_name = strings::StrCat("part_out", i);
    ops::Identity partition_out(scope.WithOpName(part_out_name),
                                {part.outputs[i]});
    outputs.push_back(part_out_name);
  }

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  // Add a ConcatOffset node to the graph
  Tensor initial_val(DT_INT32, TensorShape({3}));
  test::FillIota<int>(&initial_val, 7);
  for (int i = 1; i < 5; ++i) {
    TF_CHECK_OK(NodeDefBuilder(strings::StrCat("in", i), "Const")
                    .Attr("dtype", DT_INT32)
                    .Attr("value", initial_val)
                    .Finalize(item.graph.add_node()));
  }
  Tensor concat_dim(DT_INT32, TensorShape({}));
  test::FillIota<int>(&concat_dim, 0);
  TF_CHECK_OK(NodeDefBuilder("concat_dim", "Const")
                  .Attr("dtype", DT_INT32)
                  .Attr("value", concat_dim)
                  .Finalize(item.graph.add_node()));

  TF_CHECK_OK(NodeDefBuilder("concat_offsets", "ConcatOffset")
                  .Input("concat_dim", 0, DT_INT32)
                  .Input({NodeDefBuilder::NodeOut("in1", 0, DT_INT32),
                          NodeDefBuilder::NodeOut("in2", 0, DT_INT32),
                          NodeDefBuilder::NodeOut("in3", 0, DT_INT32),
                          NodeDefBuilder::NodeOut("in4", 0, DT_INT32)})
                  .Finalize(item.graph.add_node()));

  for (int i = 0; i < 4; ++i) {
    string concat_offset_out_name = strings::StrCat("concat_offset_out", i);
    TF_CHECK_OK(NodeDefBuilder(concat_offset_out_name, "Identity")
                    .Attr("T", DT_INT32)
                    .Input("concat_offsets", i, DT_INT32)
                    .Finalize(item.graph.add_node()));
    outputs.push_back(concat_offset_out_name);
  }

  item.fetch = outputs;
  ConstantFolding fold;
  GraphDef output;
  Status status = fold.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  int constant_folded = 0;
  for (const auto& node : output.node()) {
    if (node.name().find("ConstantFolding/partition") != string::npos ||
        node.name().find("ConstantFolding/concat_offsets") != string::npos) {
      ++constant_folded;
      EXPECT_EQ("Const", node.op());
    }
  }
  EXPECT_EQ(8, constant_folded);

  auto expected = EvaluateNodes(item.graph, outputs);
  auto optimized = EvaluateNodes(output, outputs);
  ASSERT_EQ(expected.size(), optimized.size());
  for (int i = 0; i < expected.size(); ++i) {
    test::ExpectTensorEqual<int>(expected[i], optimized[i]);
  }
}

TEST_F(ConstantFoldingTest, ShapeMaterialization) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output v1 = ops::Variable(scope.WithOpName("v1"), {3}, DT_FLOAT);