mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
Fixed BuildOpInfoWithoutDevice
PiperOrigin-RevId: 165653933
This commit is contained in:
parent
d7e425f0bd
commit
513def0bb2
|
|
@ -141,6 +141,24 @@ tf_cuda_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "utils_test",
|
||||||
|
srcs = ["utils_test.cc"],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
":utils",
|
||||||
|
"//tensorflow/cc:cc_ops",
|
||||||
|
"//tensorflow/core:all_kernels",
|
||||||
|
"//tensorflow/core:core_cpu_internal",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core:tensor_testutil",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core:testlib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "cost_estimator",
|
name = "cost_estimator",
|
||||||
hdrs = ["cost_estimator.h"],
|
hdrs = ["cost_estimator.h"],
|
||||||
|
|
@ -170,7 +188,7 @@ cc_test(
|
||||||
srcs = ["virtual_placer_test.cc"],
|
srcs = ["virtual_placer_test.cc"],
|
||||||
deps = [
|
deps = [
|
||||||
":virtual_placer",
|
":virtual_placer",
|
||||||
"//tensorflow/core:core_cpu_internal",
|
"//tensorflow/core:core_cpu",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
|
|
|
||||||
|
|
@ -70,11 +70,12 @@ static std::vector<TensorProto> ExtractTensors(const AttrValue& attr_value) {
|
||||||
return tensors;
|
return tensors;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Annotate the op_info inputs with extra information when possible (e.g. the
|
||||||
|
// input value if it's known statically).
|
||||||
static void ExtractExtraProperties(
|
static void ExtractExtraProperties(
|
||||||
const NodeDef& node,
|
const NodeDef& node,
|
||||||
const std::unordered_map<string, const NodeDef*>& name_to_node,
|
const std::unordered_map<string, const NodeDef*>& name_to_node,
|
||||||
std::vector<OpInfo::TensorProperties>* extra_inputs,
|
OpInfo* op_info) {
|
||||||
protobuf::Map<string, AttrValue>* attr_map) {
|
|
||||||
OpRegistry* op_registry = OpRegistry::Global();
|
OpRegistry* op_registry = OpRegistry::Global();
|
||||||
const OpDef* op_def = nullptr;
|
const OpDef* op_def = nullptr;
|
||||||
auto s = op_registry->LookUpOpDef(node.op(), &op_def);
|
auto s = op_registry->LookUpOpDef(node.op(), &op_def);
|
||||||
|
|
@ -102,11 +103,8 @@ static void ExtractExtraProperties(
|
||||||
if (tensors.empty()) continue;
|
if (tensors.empty()) continue;
|
||||||
|
|
||||||
const TensorProto& t = tensors[0];
|
const TensorProto& t = tensors[0];
|
||||||
OpInfo::TensorProperties input;
|
OpInfo::TensorProperties* input = op_info->mutable_inputs(i);
|
||||||
input.set_dtype(t.dtype());
|
*(input->mutable_value()) = t;
|
||||||
*(input.mutable_shape()) = t.tensor_shape();
|
|
||||||
*(input.mutable_value()) = t;
|
|
||||||
extra_inputs->push_back(input);
|
|
||||||
|
|
||||||
// For filename input, the file size can also be useful.
|
// For filename input, the file size can also be useful.
|
||||||
if (op_def && i < op_def->input_arg_size() &&
|
if (op_def && i < op_def->input_arg_size() &&
|
||||||
|
|
@ -129,7 +127,7 @@ static void ExtractExtraProperties(
|
||||||
AttrValue attr;
|
AttrValue attr;
|
||||||
attr.set_i(stat.length);
|
attr.set_i(stat.length);
|
||||||
string attr_key = strings::StrCat("input_", i, "_filesize");
|
string attr_key = strings::StrCat("input_", i, "_filesize");
|
||||||
(*attr_map)[attr_key] = attr;
|
(*op_info->mutable_attr())[attr_key] = attr;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -140,7 +138,7 @@ static void ExtractExtraProperties(
|
||||||
string new_key = strings::StrCat("parent_", i, "_op");
|
string new_key = strings::StrCat("parent_", i, "_op");
|
||||||
AttrValue attr;
|
AttrValue attr;
|
||||||
attr.set_s(input_node->op());
|
attr.set_s(input_node->op());
|
||||||
(*attr_map)[new_key] = attr;
|
(*op_info->mutable_attr())[new_key] = attr;
|
||||||
// TODO(yuefengz): Only parent node's op name is copied. Copy inputs
|
// TODO(yuefengz): Only parent node's op name is copied. Copy inputs
|
||||||
// and attributes when necessary.
|
// and attributes when necessary.
|
||||||
}
|
}
|
||||||
|
|
@ -212,14 +210,7 @@ OpInfo BuildOpInfoWithoutDevice(
|
||||||
for (auto& input : inputs) {
|
for (auto& input : inputs) {
|
||||||
*op_info.add_inputs() = input;
|
*op_info.add_inputs() = input;
|
||||||
}
|
}
|
||||||
|
ExtractExtraProperties(node, name_to_node, &op_info);
|
||||||
std::vector<OpInfo::TensorProperties> extra_inputs;
|
|
||||||
ExtractExtraProperties(node, name_to_node, &extra_inputs,
|
|
||||||
op_info.mutable_attr());
|
|
||||||
for (auto& input : extra_inputs) {
|
|
||||||
*op_info.add_inputs() = input;
|
|
||||||
}
|
|
||||||
|
|
||||||
return op_info;
|
return op_info;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
150
tensorflow/core/grappler/costs/utils_test.cc
Normal file
150
tensorflow/core/grappler/costs/utils_test.cc
Normal file
|
|
@ -0,0 +1,150 @@
|
||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/grappler/costs/utils.h"
|
||||||
|
#include "tensorflow/core/framework/graph.pb.h"
|
||||||
|
#include "tensorflow/core/framework/node_def_builder.h"
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
|
#include "tensorflow/core/framework/types.h"
|
||||||
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace grappler {
|
||||||
|
|
||||||
|
class UtilsTest : public ::testing::Test {
|
||||||
|
public:
|
||||||
|
void CreateConstOp(const string& name, std::initializer_list<int64> dims,
|
||||||
|
NodeDef* node) {
|
||||||
|
Tensor tensor(DT_FLOAT, TensorShape(dims));
|
||||||
|
for (int64 i = 0; i < tensor.NumElements(); ++i) {
|
||||||
|
tensor.flat<float>()(i) = i / 10.0f;
|
||||||
|
}
|
||||||
|
TF_CHECK_OK(NodeDefBuilder(name, "Const")
|
||||||
|
.Attr("dtype", DT_FLOAT)
|
||||||
|
.Attr("value", tensor)
|
||||||
|
.Finalize(node));
|
||||||
|
}
|
||||||
|
|
||||||
|
void CreateConstSizesOp(const string& name, const std::vector<int32>& sizes,
|
||||||
|
NodeDef* node) {
|
||||||
|
TensorShape shape;
|
||||||
|
shape.AddDim(sizes.size());
|
||||||
|
Tensor tensor(DT_INT32, shape);
|
||||||
|
for (int64 i = 0; i < tensor.NumElements(); ++i) {
|
||||||
|
tensor.flat<int32>()(i) = sizes[i];
|
||||||
|
}
|
||||||
|
TF_CHECK_OK(NodeDefBuilder(name, "Const")
|
||||||
|
.Attr("dtype", DT_INT32)
|
||||||
|
.Attr("value", tensor)
|
||||||
|
.Finalize(node));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(UtilsTest, ConvOpInfo) {
|
||||||
|
int batch = 32;
|
||||||
|
int rows = 7;
|
||||||
|
int cols = 9;
|
||||||
|
int filter_rows = 3;
|
||||||
|
int filter_cols = 3;
|
||||||
|
int out_rows = 7;
|
||||||
|
int out_cols = 9;
|
||||||
|
int in_depth = 3;
|
||||||
|
int out_depth = 5;
|
||||||
|
int stride = 1;
|
||||||
|
|
||||||
|
std::unordered_map<string, const NodeDef*> name_to_node;
|
||||||
|
GraphDef graph;
|
||||||
|
NodeDef* input = graph.add_node();
|
||||||
|
name_to_node["input"] = input;
|
||||||
|
CreateConstOp("input", {batch, rows, cols, in_depth}, input);
|
||||||
|
NodeDef* filter = graph.add_node();
|
||||||
|
name_to_node["filter"] = filter;
|
||||||
|
CreateConstOp("filter", {filter_rows, filter_cols, in_depth, out_depth},
|
||||||
|
filter);
|
||||||
|
NodeDef* output_backprop = graph.add_node();
|
||||||
|
name_to_node["output_backprop"] = output_backprop;
|
||||||
|
CreateConstOp("output_backprop", {batch, out_rows, out_cols, out_depth},
|
||||||
|
output_backprop);
|
||||||
|
NodeDef* input_sizes = graph.add_node();
|
||||||
|
name_to_node["input_sizes"] = input;
|
||||||
|
CreateConstSizesOp("input_sizes",
|
||||||
|
std::vector<int32>({batch, rows, cols, in_depth}),
|
||||||
|
input_sizes);
|
||||||
|
NodeDef* filter_sizes = graph.add_node();
|
||||||
|
name_to_node["filter_sizes"] = filter_sizes;
|
||||||
|
CreateConstSizesOp(
|
||||||
|
"filter_sizes",
|
||||||
|
std::vector<int32>({filter_rows, filter_cols, in_depth, out_depth}),
|
||||||
|
filter_sizes);
|
||||||
|
|
||||||
|
TensorShape paddings_shape({4, 2});
|
||||||
|
Tensor paddings_tensor(DT_INT32, paddings_shape);
|
||||||
|
for (int64 i = 0; i < paddings_tensor.NumElements(); ++i) {
|
||||||
|
paddings_tensor.flat<int32>()(i) = 0;
|
||||||
|
}
|
||||||
|
TF_CHECK_OK(NodeDefBuilder("paddings", "Const")
|
||||||
|
.Attr("dtype", DT_INT32)
|
||||||
|
.Attr("value", paddings_tensor)
|
||||||
|
.Finalize(graph.add_node()));
|
||||||
|
|
||||||
|
// Now add the convolution op
|
||||||
|
NodeDef* conv = graph.add_node();
|
||||||
|
TF_CHECK_OK(NodeDefBuilder("conv2d", "Conv2D")
|
||||||
|
.Input("input", 0, DT_FLOAT)
|
||||||
|
.Input("filter", 0, DT_FLOAT)
|
||||||
|
.Attr("strides", {1, stride, stride, 1})
|
||||||
|
.Attr("padding", "SAME")
|
||||||
|
.Finalize(conv));
|
||||||
|
|
||||||
|
NodeDef* conv_bp_in = graph.add_node();
|
||||||
|
TF_CHECK_OK(NodeDefBuilder("conv2d_bp_in", "Conv2DBackpropInput")
|
||||||
|
.Input("input_sizes", 0, DT_INT32)
|
||||||
|
.Input("filter", 0, DT_FLOAT)
|
||||||
|
.Input("output_backprop", 0, DT_FLOAT)
|
||||||
|
.Attr("strides", {1, stride, stride, 1})
|
||||||
|
.Attr("padding", "SAME")
|
||||||
|
.Finalize(conv_bp_in));
|
||||||
|
|
||||||
|
NodeDef* conv_bp_filter = graph.add_node();
|
||||||
|
TF_CHECK_OK(NodeDefBuilder("conv2d_bp_filter", "Conv2DBackpropFilter")
|
||||||
|
.Input("input", 0, DT_FLOAT)
|
||||||
|
.Input("filter_sizes", 0, DT_INT32)
|
||||||
|
.Input("output_backprop", 0, DT_FLOAT)
|
||||||
|
.Attr("strides", {1, stride, stride, 1})
|
||||||
|
.Attr("padding", "SAME")
|
||||||
|
.Finalize(conv_bp_filter));
|
||||||
|
|
||||||
|
for (const auto& node : graph.node()) {
|
||||||
|
if (node.name().find("conv2d") != 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
std::vector<OpInfo::TensorProperties> inputs;
|
||||||
|
inputs.resize(node.input_size());
|
||||||
|
OpInfo info = BuildOpInfoWithoutDevice(node, name_to_node, inputs);
|
||||||
|
if (node.name() == "conv2d") {
|
||||||
|
EXPECT_EQ(2, info.inputs_size());
|
||||||
|
} else if (node.name() == "conv2dbp_in") {
|
||||||
|
EXPECT_EQ(3, info.inputs_size());
|
||||||
|
} else if (node.name() == "conv2d_bp_filter") {
|
||||||
|
EXPECT_EQ(3, info.inputs_size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // end namespace grappler
|
||||||
|
} // end namespace tensorflow
|
||||||
Loading…
Reference in New Issue
Block a user