mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Fixed BuildOpInfoWithoutDevice
PiperOrigin-RevId: 165653933
This commit is contained in:
parent
d7e425f0bd
commit
513def0bb2
|
|
@ -141,6 +141,24 @@ tf_cuda_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "utils_test",
|
||||
srcs = ["utils_test.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":utils",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/core:all_kernels",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:tensor_testutil",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cost_estimator",
|
||||
hdrs = ["cost_estimator.h"],
|
||||
|
|
@ -170,7 +188,7 @@ cc_test(
|
|||
srcs = ["virtual_placer_test.cc"],
|
||||
deps = [
|
||||
":virtual_placer",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
|
|
|
|||
|
|
@ -70,11 +70,12 @@ static std::vector<TensorProto> ExtractTensors(const AttrValue& attr_value) {
|
|||
return tensors;
|
||||
}
|
||||
|
||||
// Annotate the op_info inputs with extra information when possible (e.g. the
|
||||
// input value if it's known statically).
|
||||
static void ExtractExtraProperties(
|
||||
const NodeDef& node,
|
||||
const std::unordered_map<string, const NodeDef*>& name_to_node,
|
||||
std::vector<OpInfo::TensorProperties>* extra_inputs,
|
||||
protobuf::Map<string, AttrValue>* attr_map) {
|
||||
OpInfo* op_info) {
|
||||
OpRegistry* op_registry = OpRegistry::Global();
|
||||
const OpDef* op_def = nullptr;
|
||||
auto s = op_registry->LookUpOpDef(node.op(), &op_def);
|
||||
|
|
@ -102,11 +103,8 @@ static void ExtractExtraProperties(
|
|||
if (tensors.empty()) continue;
|
||||
|
||||
const TensorProto& t = tensors[0];
|
||||
OpInfo::TensorProperties input;
|
||||
input.set_dtype(t.dtype());
|
||||
*(input.mutable_shape()) = t.tensor_shape();
|
||||
*(input.mutable_value()) = t;
|
||||
extra_inputs->push_back(input);
|
||||
OpInfo::TensorProperties* input = op_info->mutable_inputs(i);
|
||||
*(input->mutable_value()) = t;
|
||||
|
||||
// For filename input, the file size can also be useful.
|
||||
if (op_def && i < op_def->input_arg_size() &&
|
||||
|
|
@ -129,7 +127,7 @@ static void ExtractExtraProperties(
|
|||
AttrValue attr;
|
||||
attr.set_i(stat.length);
|
||||
string attr_key = strings::StrCat("input_", i, "_filesize");
|
||||
(*attr_map)[attr_key] = attr;
|
||||
(*op_info->mutable_attr())[attr_key] = attr;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -140,7 +138,7 @@ static void ExtractExtraProperties(
|
|||
string new_key = strings::StrCat("parent_", i, "_op");
|
||||
AttrValue attr;
|
||||
attr.set_s(input_node->op());
|
||||
(*attr_map)[new_key] = attr;
|
||||
(*op_info->mutable_attr())[new_key] = attr;
|
||||
// TODO(yuefengz): Only parent node's op name is copied. Copy inputs
|
||||
// and attributes when necessary.
|
||||
}
|
||||
|
|
@ -212,14 +210,7 @@ OpInfo BuildOpInfoWithoutDevice(
|
|||
for (auto& input : inputs) {
|
||||
*op_info.add_inputs() = input;
|
||||
}
|
||||
|
||||
std::vector<OpInfo::TensorProperties> extra_inputs;
|
||||
ExtractExtraProperties(node, name_to_node, &extra_inputs,
|
||||
op_info.mutable_attr());
|
||||
for (auto& input : extra_inputs) {
|
||||
*op_info.add_inputs() = input;
|
||||
}
|
||||
|
||||
ExtractExtraProperties(node, name_to_node, &op_info);
|
||||
return op_info;
|
||||
}
|
||||
|
||||
|
|
|
|||
150
tensorflow/core/grappler/costs/utils_test.cc
Normal file
150
tensorflow/core/grappler/costs/utils_test.cc
Normal file
|
|
@ -0,0 +1,150 @@
|
|||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/costs/utils.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
class UtilsTest : public ::testing::Test {
|
||||
public:
|
||||
void CreateConstOp(const string& name, std::initializer_list<int64> dims,
|
||||
NodeDef* node) {
|
||||
Tensor tensor(DT_FLOAT, TensorShape(dims));
|
||||
for (int64 i = 0; i < tensor.NumElements(); ++i) {
|
||||
tensor.flat<float>()(i) = i / 10.0f;
|
||||
}
|
||||
TF_CHECK_OK(NodeDefBuilder(name, "Const")
|
||||
.Attr("dtype", DT_FLOAT)
|
||||
.Attr("value", tensor)
|
||||
.Finalize(node));
|
||||
}
|
||||
|
||||
void CreateConstSizesOp(const string& name, const std::vector<int32>& sizes,
|
||||
NodeDef* node) {
|
||||
TensorShape shape;
|
||||
shape.AddDim(sizes.size());
|
||||
Tensor tensor(DT_INT32, shape);
|
||||
for (int64 i = 0; i < tensor.NumElements(); ++i) {
|
||||
tensor.flat<int32>()(i) = sizes[i];
|
||||
}
|
||||
TF_CHECK_OK(NodeDefBuilder(name, "Const")
|
||||
.Attr("dtype", DT_INT32)
|
||||
.Attr("value", tensor)
|
||||
.Finalize(node));
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(UtilsTest, ConvOpInfo) {
|
||||
int batch = 32;
|
||||
int rows = 7;
|
||||
int cols = 9;
|
||||
int filter_rows = 3;
|
||||
int filter_cols = 3;
|
||||
int out_rows = 7;
|
||||
int out_cols = 9;
|
||||
int in_depth = 3;
|
||||
int out_depth = 5;
|
||||
int stride = 1;
|
||||
|
||||
std::unordered_map<string, const NodeDef*> name_to_node;
|
||||
GraphDef graph;
|
||||
NodeDef* input = graph.add_node();
|
||||
name_to_node["input"] = input;
|
||||
CreateConstOp("input", {batch, rows, cols, in_depth}, input);
|
||||
NodeDef* filter = graph.add_node();
|
||||
name_to_node["filter"] = filter;
|
||||
CreateConstOp("filter", {filter_rows, filter_cols, in_depth, out_depth},
|
||||
filter);
|
||||
NodeDef* output_backprop = graph.add_node();
|
||||
name_to_node["output_backprop"] = output_backprop;
|
||||
CreateConstOp("output_backprop", {batch, out_rows, out_cols, out_depth},
|
||||
output_backprop);
|
||||
NodeDef* input_sizes = graph.add_node();
|
||||
name_to_node["input_sizes"] = input;
|
||||
CreateConstSizesOp("input_sizes",
|
||||
std::vector<int32>({batch, rows, cols, in_depth}),
|
||||
input_sizes);
|
||||
NodeDef* filter_sizes = graph.add_node();
|
||||
name_to_node["filter_sizes"] = filter_sizes;
|
||||
CreateConstSizesOp(
|
||||
"filter_sizes",
|
||||
std::vector<int32>({filter_rows, filter_cols, in_depth, out_depth}),
|
||||
filter_sizes);
|
||||
|
||||
TensorShape paddings_shape({4, 2});
|
||||
Tensor paddings_tensor(DT_INT32, paddings_shape);
|
||||
for (int64 i = 0; i < paddings_tensor.NumElements(); ++i) {
|
||||
paddings_tensor.flat<int32>()(i) = 0;
|
||||
}
|
||||
TF_CHECK_OK(NodeDefBuilder("paddings", "Const")
|
||||
.Attr("dtype", DT_INT32)
|
||||
.Attr("value", paddings_tensor)
|
||||
.Finalize(graph.add_node()));
|
||||
|
||||
// Now add the convolution op
|
||||
NodeDef* conv = graph.add_node();
|
||||
TF_CHECK_OK(NodeDefBuilder("conv2d", "Conv2D")
|
||||
.Input("input", 0, DT_FLOAT)
|
||||
.Input("filter", 0, DT_FLOAT)
|
||||
.Attr("strides", {1, stride, stride, 1})
|
||||
.Attr("padding", "SAME")
|
||||
.Finalize(conv));
|
||||
|
||||
NodeDef* conv_bp_in = graph.add_node();
|
||||
TF_CHECK_OK(NodeDefBuilder("conv2d_bp_in", "Conv2DBackpropInput")
|
||||
.Input("input_sizes", 0, DT_INT32)
|
||||
.Input("filter", 0, DT_FLOAT)
|
||||
.Input("output_backprop", 0, DT_FLOAT)
|
||||
.Attr("strides", {1, stride, stride, 1})
|
||||
.Attr("padding", "SAME")
|
||||
.Finalize(conv_bp_in));
|
||||
|
||||
NodeDef* conv_bp_filter = graph.add_node();
|
||||
TF_CHECK_OK(NodeDefBuilder("conv2d_bp_filter", "Conv2DBackpropFilter")
|
||||
.Input("input", 0, DT_FLOAT)
|
||||
.Input("filter_sizes", 0, DT_INT32)
|
||||
.Input("output_backprop", 0, DT_FLOAT)
|
||||
.Attr("strides", {1, stride, stride, 1})
|
||||
.Attr("padding", "SAME")
|
||||
.Finalize(conv_bp_filter));
|
||||
|
||||
for (const auto& node : graph.node()) {
|
||||
if (node.name().find("conv2d") != 0) {
|
||||
continue;
|
||||
}
|
||||
std::vector<OpInfo::TensorProperties> inputs;
|
||||
inputs.resize(node.input_size());
|
||||
OpInfo info = BuildOpInfoWithoutDevice(node, name_to_node, inputs);
|
||||
if (node.name() == "conv2d") {
|
||||
EXPECT_EQ(2, info.inputs_size());
|
||||
} else if (node.name() == "conv2dbp_in") {
|
||||
EXPECT_EQ(3, info.inputs_size());
|
||||
} else if (node.name() == "conv2d_bp_filter") {
|
||||
EXPECT_EQ(3, info.inputs_size());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
||||
Loading…
Reference in New Issue
Block a user