mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Automated Code Change
PiperOrigin-RevId: 826749442
This commit is contained in:
parent
aa4db17b00
commit
ef28899305
|
|
@ -65,7 +65,7 @@ absl::Status GraphTopologyView::InitializeFromGraph(
|
|||
const auto src = node_name_to_index_.find(edge.src.node->name());
|
||||
const bool valid_src = src != node_name_to_index_.end();
|
||||
if (!valid_src) {
|
||||
const string error_message =
|
||||
const std::string error_message =
|
||||
absl::StrCat("Non-existent src node: ", edge.src.node->name());
|
||||
if (skip_invalid_edges_) {
|
||||
VLOG(0) << "Skip error: " << error_message;
|
||||
|
|
@ -78,7 +78,7 @@ absl::Status GraphTopologyView::InitializeFromGraph(
|
|||
const bool valid_dst = dst != node_name_to_index_.end();
|
||||
|
||||
if (!valid_dst) {
|
||||
const string error_message =
|
||||
const std::string error_message =
|
||||
absl::StrCat("Non-existent dst node: ", edge.dst.node->name());
|
||||
if (skip_invalid_edges_) {
|
||||
VLOG(0) << "Skip error: " << error_message;
|
||||
|
|
@ -103,7 +103,7 @@ absl::Status GraphTopologyView::InitializeFromGraph(
|
|||
const NodeDef& node = graph.node(node_idx);
|
||||
fanins_[node_idx].reserve(node.input_size());
|
||||
|
||||
for (const string& input : node.input()) {
|
||||
for (const std::string& input : node.input()) {
|
||||
TensorId tensor = ParseTensorName(input);
|
||||
if (ignore_control_edges && IsTensorIdControl(tensor)) {
|
||||
continue;
|
||||
|
|
@ -112,8 +112,8 @@ absl::Status GraphTopologyView::InitializeFromGraph(
|
|||
const bool valid_input = it != node_name_to_index_.end();
|
||||
|
||||
if (!valid_input) {
|
||||
const string error_message = absl::StrCat("Non-existent input ", input,
|
||||
" in node ", node.name());
|
||||
const std::string error_message = absl::StrCat(
|
||||
"Non-existent input ", input, " in node ", node.name());
|
||||
if (skip_invalid_edges_) {
|
||||
VLOG(3) << "Skip error: " << error_message;
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ namespace grappler {
|
|||
|
||||
class GraphTopologyViewTest : public ::testing::Test {
|
||||
protected:
|
||||
using NodeConfig = std::pair<string, std::vector<string>>;
|
||||
using NodeConfig = std::pair<std::string, std::vector<std::string>>;
|
||||
|
||||
static GraphDef CreateGraph(const std::vector<NodeConfig>& nodes) {
|
||||
GraphDef graph;
|
||||
|
|
@ -35,7 +35,7 @@ class GraphTopologyViewTest : public ::testing::Test {
|
|||
|
||||
NodeDef node_def;
|
||||
node_def.set_name(node_name);
|
||||
for (const string& input : node_inputs) {
|
||||
for (const std::string& input : node_inputs) {
|
||||
node_def.add_input(input);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -161,16 +161,16 @@ TEST_F(GraphViewTest, BasicGraph) {
|
|||
const NodeDef* add_node = graph.GetNode("AddN");
|
||||
EXPECT_NE(add_node, nullptr);
|
||||
|
||||
absl::flat_hash_set<string> fanouts;
|
||||
absl::flat_hash_set<string> expected_fanouts = {"AddN_2:0", "AddN_3:0"};
|
||||
absl::flat_hash_set<std::string> fanouts;
|
||||
absl::flat_hash_set<std::string> expected_fanouts = {"AddN_2:0", "AddN_3:0"};
|
||||
for (const auto& fo : graph.GetFanouts(*add_node, false)) {
|
||||
fanouts.insert(absl::StrCat(fo.node->name(), ":", fo.port_id));
|
||||
}
|
||||
EXPECT_EQ(graph.NumFanouts(*add_node, false), 2);
|
||||
EXPECT_EQ(fanouts, expected_fanouts);
|
||||
|
||||
absl::flat_hash_set<string> fanins;
|
||||
absl::flat_hash_set<string> expected_fanins = {"Sign_1:0", "Sign:0"};
|
||||
absl::flat_hash_set<std::string> fanins;
|
||||
absl::flat_hash_set<std::string> expected_fanins = {"Sign_1:0", "Sign:0"};
|
||||
for (const auto& fi : graph.GetFanins(*add_node, false)) {
|
||||
fanins.insert(absl::StrCat(fi.node->name(), ":", fi.port_id));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -74,9 +74,9 @@ std::vector<const NodeDef*> GrapplerItem::MainOpsFanin() const {
|
|||
}
|
||||
|
||||
std::vector<const NodeDef*> GrapplerItem::EnqueueOpsFanin() const {
|
||||
std::vector<string> enqueue_ops;
|
||||
std::vector<std::string> enqueue_ops;
|
||||
for (const auto& queue_runner : queue_runners) {
|
||||
for (const string& enqueue_op : queue_runner.enqueue_op_name()) {
|
||||
for (const std::string& enqueue_op : queue_runner.enqueue_op_name()) {
|
||||
enqueue_ops.push_back(enqueue_op);
|
||||
}
|
||||
}
|
||||
|
|
@ -103,9 +103,9 @@ std::vector<const NodeDef*> GrapplerItem::MainVariables() const {
|
|||
return vars;
|
||||
}
|
||||
|
||||
std::unordered_set<string> GrapplerItem::NodesToPreserve() const {
|
||||
std::unordered_set<string> result;
|
||||
for (const string& f : fetch) {
|
||||
std::unordered_set<std::string> GrapplerItem::NodesToPreserve() const {
|
||||
std::unordered_set<std::string> result;
|
||||
for (const std::string& f : fetch) {
|
||||
VLOG(1) << "Add fetch " << f;
|
||||
result.insert(NodeName(f));
|
||||
}
|
||||
|
|
@ -130,7 +130,7 @@ std::unordered_set<string> GrapplerItem::NodesToPreserve() const {
|
|||
}
|
||||
|
||||
for (const auto& queue_runner : queue_runners) {
|
||||
for (const string& enqueue_op : queue_runner.enqueue_op_name()) {
|
||||
for (const std::string& enqueue_op : queue_runner.enqueue_op_name()) {
|
||||
result.insert(NodeName(enqueue_op));
|
||||
}
|
||||
if (!queue_runner.close_op_name().empty()) {
|
||||
|
|
@ -167,11 +167,11 @@ std::unordered_set<string> GrapplerItem::NodesToPreserve() const {
|
|||
return result;
|
||||
}
|
||||
|
||||
const std::unordered_set<string>& GrapplerItem::devices() const {
|
||||
const std::unordered_set<std::string>& GrapplerItem::devices() const {
|
||||
return devices_;
|
||||
}
|
||||
|
||||
absl::Status GrapplerItem::AddDevice(const string& device) {
|
||||
absl::Status GrapplerItem::AddDevice(const std::string& device) {
|
||||
DeviceNameUtils::ParsedName name;
|
||||
|
||||
if (!DeviceNameUtils::ParseFullName(device, &name)) {
|
||||
|
|
@ -189,7 +189,7 @@ absl::Status GrapplerItem::AddDevice(const string& device) {
|
|||
|
||||
absl::Status GrapplerItem::AddDevices(const GrapplerItem& other) {
|
||||
std::vector<absl::string_view> invalid_devices;
|
||||
for (const string& device : other.devices()) {
|
||||
for (const std::string& device : other.devices()) {
|
||||
absl::Status added = AddDevice(device);
|
||||
if (!added.ok()) invalid_devices.emplace_back(device);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -46,22 +46,22 @@ struct GrapplerItem {
|
|||
// Create a copy of this GrapplerItem with graph swapped with the argument.
|
||||
GrapplerItem WithGraph(GraphDef&& graph) const;
|
||||
|
||||
string id; // A unique id for this item
|
||||
std::string id; // A unique id for this item
|
||||
|
||||
// Inputs
|
||||
GraphDef graph;
|
||||
std::vector<std::pair<string, Tensor>> feed;
|
||||
std::vector<string> fetch;
|
||||
std::vector<std::pair<std::string, Tensor>> feed;
|
||||
std::vector<std::string> fetch;
|
||||
|
||||
// Initialization op(s).
|
||||
std::vector<string> init_ops;
|
||||
std::vector<std::string> init_ops;
|
||||
// Expected initialization time in seconds, or 0 if unknown
|
||||
int64_t expected_init_time = 0;
|
||||
|
||||
// Save/restore ops (if any)
|
||||
string save_op;
|
||||
string restore_op;
|
||||
string save_restore_loc_tensor;
|
||||
std::string save_op;
|
||||
std::string restore_op;
|
||||
std::string save_restore_loc_tensor;
|
||||
|
||||
// Queue runner(s) required to run the queue(s) of this model.
|
||||
std::vector<QueueRunnerDef> queue_runners;
|
||||
|
|
@ -69,7 +69,7 @@ struct GrapplerItem {
|
|||
// List of op names to keep in the graph. This includes nodes that are
|
||||
// referenced in various collections, and therefore must be preserved to
|
||||
// ensure that the optimized metagraph can still be loaded.
|
||||
std::vector<string> keep_ops;
|
||||
std::vector<std::string> keep_ops;
|
||||
|
||||
// Return the set of node evaluated during a regular train/inference step.
|
||||
std::vector<const NodeDef*> MainOpsFanin() const;
|
||||
|
|
@ -81,7 +81,7 @@ struct GrapplerItem {
|
|||
std::vector<const NodeDef*> MainVariables() const;
|
||||
// Return a set of node names that must be preserved. This includes feed and
|
||||
// fetch nodes, keep_ops, init_ops.
|
||||
std::unordered_set<string> NodesToPreserve() const;
|
||||
std::unordered_set<std::string> NodesToPreserve() const;
|
||||
|
||||
struct OptimizationOptions {
|
||||
// Is it allowed to add nodes to the graph that do not have registered
|
||||
|
|
@ -108,11 +108,11 @@ struct GrapplerItem {
|
|||
int intra_op_parallelism_threads = tsl::port::MaxParallelism();
|
||||
};
|
||||
|
||||
const std::unordered_set<string>& devices() const;
|
||||
const std::unordered_set<std::string>& devices() const;
|
||||
// Adds a device to a set of available devices, only if it's a valid fully
|
||||
// defined device name. Returns `OkStatus()` if successfully added a device,
|
||||
// and an error otherwise.
|
||||
absl::Status AddDevice(const string& device);
|
||||
absl::Status AddDevice(const std::string& device);
|
||||
// Adds all valid devices from the other Grappler item to the device set.
|
||||
absl::Status AddDevices(const GrapplerItem& other);
|
||||
// Adds all valid devices from the nodes of the graph to the device set.
|
||||
|
|
@ -132,7 +132,7 @@ struct GrapplerItem {
|
|||
// A set of fully defined device names that can be used to place the nodes of
|
||||
// the `graph`.
|
||||
// Example of a fully defined name: "/job:work/replica:1/task:1/device:CPU:0"
|
||||
std::unordered_set<string> devices_;
|
||||
std::unordered_set<std::string> devices_;
|
||||
|
||||
OptimizationOptions optimization_options_;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -95,7 +95,7 @@ absl::Status ReplaceUnknownShapeDim(const ItemConfig& cfg,
|
|||
const TensorShapeProto& shape_pb_in,
|
||||
TensorShapeProto* shape_pb_out,
|
||||
TensorShape* shape_out) {
|
||||
std::vector<int32> dims;
|
||||
std::vector<int32_t> dims;
|
||||
for (const auto& dim_proto : shape_pb_in.dim()) {
|
||||
if (cfg.placeholder_unknown_output_shape_dim >= 0 &&
|
||||
dim_proto.size() == -1) {
|
||||
|
|
@ -103,7 +103,7 @@ absl::Status ReplaceUnknownShapeDim(const ItemConfig& cfg,
|
|||
shape_pb_out->add_dim()->set_size(
|
||||
cfg.placeholder_unknown_output_shape_dim);
|
||||
} else {
|
||||
dims.push_back(std::max<int32>(1, dim_proto.size()));
|
||||
dims.push_back(std::max<int32_t>(1, dim_proto.size()));
|
||||
shape_pb_out->add_dim()->set_size(dim_proto.size());
|
||||
}
|
||||
}
|
||||
|
|
@ -117,7 +117,7 @@ absl::Status ReplaceUnknownShapeDim(const ItemConfig& cfg,
|
|||
// (b/134092018).
|
||||
absl::Status UpdatePlaceholderShape(
|
||||
const ItemConfig& cfg,
|
||||
const std::unordered_set<string>& signature_feed_nodes,
|
||||
const std::unordered_set<std::string>& signature_feed_nodes,
|
||||
GrapplerItem* new_item, NodeDef* node) {
|
||||
if (node->attr().count("dtype") == 0) {
|
||||
return absl::InternalError(absl::StrCat("Unknown type for placeholder ",
|
||||
|
|
@ -188,7 +188,7 @@ absl::Status UpdatePlaceholderShape(
|
|||
} else if (cfg.feed_nodes.count(node->name()) > 0) {
|
||||
// If specific feed nodes were given, only update their tensors.
|
||||
auto it = find_if(new_item->feed.begin(), new_item->feed.end(),
|
||||
[&node](std::pair<string, Tensor>& f) {
|
||||
[&node](std::pair<std::string, Tensor>& f) {
|
||||
return f.first == node->name();
|
||||
});
|
||||
DCHECK(it != new_item->feed.end());
|
||||
|
|
@ -294,7 +294,8 @@ absl::Status RuntimeGraphOptimizer(const GraphDef& graph_def_arg,
|
|||
}
|
||||
|
||||
std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
|
||||
const string& id, const MetaGraphDef& meta_graph, const ItemConfig& cfg) {
|
||||
const std::string& id, const MetaGraphDef& meta_graph,
|
||||
const ItemConfig& cfg) {
|
||||
if (id.empty()) {
|
||||
LOG(ERROR) << "id must be non-empty.";
|
||||
return nullptr;
|
||||
|
|
@ -305,7 +306,7 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
|
|||
|
||||
// Fill in feed nodes from config, if any provided.
|
||||
for (const auto& feed_node : cfg.feed_nodes) {
|
||||
const string feed_name = NodeName(feed_node);
|
||||
const std::string feed_name = NodeName(feed_node);
|
||||
new_item->feed.emplace_back(feed_name, Tensor());
|
||||
}
|
||||
for (const auto& fetch_node : cfg.fetch_nodes) {
|
||||
|
|
@ -325,8 +326,8 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
|
|||
|
||||
// Detect feed and fetch nodes from signature defs. Signatures may share same
|
||||
// inputs or outputs.
|
||||
std::unordered_set<string> signature_feed_nodes;
|
||||
std::unordered_set<string> signature_fetch_nodes;
|
||||
std::unordered_set<std::string> signature_feed_nodes;
|
||||
std::unordered_set<std::string> signature_fetch_nodes;
|
||||
for (const auto& name_and_signature : meta_graph.signature_def()) {
|
||||
for (const auto& name_and_input : name_and_signature.second.inputs()) {
|
||||
const TensorInfo& input = name_and_input.second;
|
||||
|
|
@ -442,7 +443,7 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
|
|||
// have to run restore op first.
|
||||
|
||||
// Try to find initializers from variables and tables as init ops.
|
||||
for (const string& var_collection :
|
||||
for (const std::string& var_collection :
|
||||
{"variables", "local_variables", "model_variables",
|
||||
"trainable_variables"}) {
|
||||
if (meta_graph.collection_def().count(var_collection) == 0) {
|
||||
|
|
@ -476,7 +477,7 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
|
|||
// We keep the mapping from asset node to asset files. This should have been
|
||||
// used as feed but since asset node is usually a constant node, we will fill
|
||||
// the values of these constant nodes with their actual asset file paths.
|
||||
std::unordered_map<string, string> asset_node_to_value;
|
||||
std::unordered_map<std::string, std::string> asset_node_to_value;
|
||||
|
||||
// Assets file may have changed their directory, we assemble their new paths
|
||||
// if assets_directory_override is set. We also make sure we still can
|
||||
|
|
@ -495,8 +496,8 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
|
|||
LOG(ERROR) << "Failed to parse AssetFile.";
|
||||
continue;
|
||||
}
|
||||
string asset_filepath = io::JoinPath(cfg.assets_directory_override,
|
||||
asset_file_def.filename());
|
||||
std::string asset_filepath = io::JoinPath(
|
||||
cfg.assets_directory_override, asset_file_def.filename());
|
||||
if (!FilesExist({asset_filepath}, nullptr)) {
|
||||
LOG(ERROR) << "Can't access one or more of the asset files "
|
||||
<< asset_filepath << ", skipping this input";
|
||||
|
|
@ -514,7 +515,7 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
|
|||
} else if (meta_graph.collection_def().count("asset_filepaths") > 0) {
|
||||
const CollectionDef& file_paths =
|
||||
meta_graph.collection_def().at("asset_filepaths");
|
||||
std::vector<string> paths;
|
||||
std::vector<std::string> paths;
|
||||
for (const auto& raw_path : file_paths.bytes_list().value()) {
|
||||
paths.push_back(raw_path);
|
||||
}
|
||||
|
|
@ -544,7 +545,7 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
|
|||
// Add each node referenced in a collection to the list of nodes to keep.
|
||||
for (const auto& col : meta_graph.collection_def()) {
|
||||
const CollectionDef& collection = col.second;
|
||||
for (const string& node : collection.node_list().value()) {
|
||||
for (const std::string& node : collection.node_list().value()) {
|
||||
new_item->keep_ops.push_back(NodeName(node));
|
||||
}
|
||||
}
|
||||
|
|
@ -654,7 +655,7 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
|
|||
}
|
||||
|
||||
// Validate feed, fetch and init nodes
|
||||
std::unordered_set<string> nodes;
|
||||
std::unordered_set<std::string> nodes;
|
||||
for (const auto& node : new_item->graph.node()) {
|
||||
nodes.insert(node.name());
|
||||
}
|
||||
|
|
@ -680,7 +681,8 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
|
|||
}
|
||||
|
||||
std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDefFile(
|
||||
const string& id, const string& meta_graph_file, const ItemConfig& cfg) {
|
||||
const std::string& id, const std::string& meta_graph_file,
|
||||
const ItemConfig& cfg) {
|
||||
MetaGraphDef meta_graph;
|
||||
if (!ReadMetaGraphDefFromFile(meta_graph_file, &meta_graph).ok()) {
|
||||
LOG(ERROR) << "Failed to read " << meta_graph_file;
|
||||
|
|
|
|||
|
|
@ -43,13 +43,13 @@ struct ItemConfig {
|
|||
// Has no effect if "inline_functions" is disabled.
|
||||
bool erase_noinline_attributes = false;
|
||||
// If non-empty, override the directory of asset paths.
|
||||
string assets_directory_override;
|
||||
std::string assets_directory_override;
|
||||
// If true, runs ModelPruner on the graph.
|
||||
bool prune_graph = false;
|
||||
// Override feed nodes list.
|
||||
std::set<string> feed_nodes;
|
||||
std::set<std::string> feed_nodes;
|
||||
// Override fetch nodes list.
|
||||
std::set<string> fetch_nodes;
|
||||
std::set<std::string> fetch_nodes;
|
||||
|
||||
// Configs for graph optimizations from common_runtime. This is NOT Grappler
|
||||
// function optimizer. When Grappler is invoked at runtime, it is typically
|
||||
|
|
@ -71,13 +71,15 @@ absl::Status RuntimeGraphOptimizer(const GraphDef& graph_def_arg,
|
|||
// Factory method for creating a GrapplerItem from a MetaGraphDef.
|
||||
// Returns nullptr if the given meta_graph cannot be converted.
|
||||
std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
|
||||
const string& id, const MetaGraphDef& meta_graph, const ItemConfig& cfg);
|
||||
const std::string& id, const MetaGraphDef& meta_graph,
|
||||
const ItemConfig& cfg);
|
||||
|
||||
// Factory method for creating a GrapplerItem from a file
|
||||
// containing a MetaGraphDef in either binary or text format.
|
||||
// Returns nullptr if the given meta_graph cannot be converted.
|
||||
std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDefFile(
|
||||
const string& id, const string& meta_graph_file, const ItemConfig& cfg);
|
||||
const std::string& id, const std::string& meta_graph_file,
|
||||
const ItemConfig& cfg);
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
|
|
|||
|
|
@ -42,19 +42,19 @@ TEST_F(GrapplerItemBuilderTest, AssetFilepathOverrideTest) {
|
|||
Output var =
|
||||
ops::Variable(s.WithOpName("var"), TensorShape(), DataType::DT_FLOAT);
|
||||
Output filename_node =
|
||||
ops::Const(s.WithOpName("filename"), string("model"), TensorShape());
|
||||
ops::Const(s.WithOpName("filename"), std::string("model"), TensorShape());
|
||||
Output tensor_name =
|
||||
ops::Const(s.WithOpName("tensorname"), string("var"), TensorShape());
|
||||
ops::Const(s.WithOpName("tensorname"), std::string("var"), TensorShape());
|
||||
Output restore = ops::Restore(s.WithOpName("restore"), filename_node,
|
||||
tensor_name, DataType::DT_FLOAT);
|
||||
Output assign = ops::Assign(s.WithOpName("assign"), var, restore);
|
||||
|
||||
TF_CHECK_OK(s.ToGraphDef(meta_graph.mutable_graph_def()));
|
||||
|
||||
string temp_dir = testing::TmpDir();
|
||||
std::string temp_dir = testing::TmpDir();
|
||||
|
||||
Env *env = Env::Default();
|
||||
string filename =
|
||||
std::string filename =
|
||||
io::JoinPath(temp_dir, "grappler_item_builder_test_filename");
|
||||
env->DeleteFile(filename).IgnoreError();
|
||||
std::unique_ptr<WritableFile> file_to_write;
|
||||
|
|
@ -88,7 +88,7 @@ TEST_F(GrapplerItemBuilderTest, AssetFilepathOverrideTest) {
|
|||
ASSERT_TRUE(iter->second.has_tensor());
|
||||
ASSERT_EQ(1, iter->second.tensor().string_val_size());
|
||||
|
||||
string tensor_string_val = iter->second.tensor().string_val(0);
|
||||
std::string tensor_string_val = iter->second.tensor().string_val(0);
|
||||
EXPECT_EQ(tensor_string_val, filename);
|
||||
}
|
||||
}
|
||||
|
|
@ -100,12 +100,12 @@ TEST_F(GrapplerItemBuilderTest, AssetFilepathOverrideTest_FileNotAccessible) {
|
|||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output var =
|
||||
ops::Variable(s.WithOpName("var"), TensorShape(), DataType::DT_FLOAT);
|
||||
Output filename_node1 =
|
||||
ops::Const(s.WithOpName("filename1"), string("model1"), TensorShape());
|
||||
Output filename_node2 =
|
||||
ops::Const(s.WithOpName("filename2"), string("model2"), TensorShape());
|
||||
Output filename_node1 = ops::Const(s.WithOpName("filename1"),
|
||||
std::string("model1"), TensorShape());
|
||||
Output filename_node2 = ops::Const(s.WithOpName("filename2"),
|
||||
std::string("model2"), TensorShape());
|
||||
Output tensor_name =
|
||||
ops::Const(s.WithOpName("tensorname"), string("var"), TensorShape());
|
||||
ops::Const(s.WithOpName("tensorname"), std::string("var"), TensorShape());
|
||||
Output restore1 = ops::Restore(s.WithOpName("restore1"), filename_node1,
|
||||
tensor_name, DataType::DT_FLOAT);
|
||||
Output restore2 = ops::Restore(s.WithOpName("restore2"), filename_node1,
|
||||
|
|
@ -115,11 +115,11 @@ TEST_F(GrapplerItemBuilderTest, AssetFilepathOverrideTest_FileNotAccessible) {
|
|||
|
||||
TF_CHECK_OK(s.ToGraphDef(meta_graph.mutable_graph_def()));
|
||||
|
||||
string temp_dir = testing::TmpDir();
|
||||
std::string temp_dir = testing::TmpDir();
|
||||
|
||||
// Create the first AssetFileDef that has a valid file.
|
||||
Env *env = Env::Default();
|
||||
string filename1 =
|
||||
std::string filename1 =
|
||||
io::JoinPath(temp_dir, "grappler_item_builder_test_filename1");
|
||||
env->DeleteFile(filename1).IgnoreError();
|
||||
std::unique_ptr<WritableFile> file_to_write;
|
||||
|
|
@ -132,7 +132,7 @@ TEST_F(GrapplerItemBuilderTest, AssetFilepathOverrideTest_FileNotAccessible) {
|
|||
*asset_file_def1.mutable_filename() = "grappler_item_builder_test_filename1";
|
||||
|
||||
// Create the second AssetFileDef that has not a valid file.
|
||||
string filename2 =
|
||||
std::string filename2 =
|
||||
io::JoinPath(temp_dir, "grappler_item_builder_test_filename1");
|
||||
env->DeleteFile(filename2).IgnoreError();
|
||||
EXPECT_FALSE(env->FileExists(filename2).ok());
|
||||
|
|
|
|||
|
|
@ -33,11 +33,11 @@ TEST_F(GrapplerItemTest, Basic) {
|
|||
|
||||
EXPECT_TRUE(item.InitOpsFanin().empty());
|
||||
|
||||
std::vector<string> graph_nodes;
|
||||
std::vector<std::string> graph_nodes;
|
||||
for (const auto& node : item.graph.node()) {
|
||||
graph_nodes.push_back(node.name());
|
||||
}
|
||||
std::vector<string> main_ops;
|
||||
std::vector<std::string> main_ops;
|
||||
for (const auto& node : item.MainOpsFanin()) {
|
||||
main_ops.push_back(node->name());
|
||||
}
|
||||
|
|
@ -49,9 +49,9 @@ TEST_F(GrapplerItemTest, Basic) {
|
|||
TEST_F(GrapplerItemTest, InferDevices) {
|
||||
using test::function::NDef;
|
||||
|
||||
const string cpu0 = "/job:work/replica:1/task:1/device:CPU:0";
|
||||
const string cpu1 = "/job:work/replica:1/task:1/device:CPU:1";
|
||||
const string cpu2 = "/device:CPU:2";
|
||||
const std::string cpu0 = "/job:work/replica:1/task:1/device:CPU:0";
|
||||
const std::string cpu1 = "/job:work/replica:1/task:1/device:CPU:1";
|
||||
const std::string cpu2 = "/device:CPU:2";
|
||||
|
||||
GrapplerItem item;
|
||||
item.graph = test::function::GDef(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user