Add optional unused_input_map_keys output param to ImportGraphDef

This is a more general feature than that in the Python importer, which
raises an exception if the input map contains unused names.

PiperOrigin-RevId: 171029316
This commit is contained in:
Skye Wanderman-Milne 2017-10-04 10:40:04 -07:00 committed by TensorFlower Gardener
parent 4f10a6597c
commit 9d7843c0a8
3 changed files with 124 additions and 41 deletions

View File

@ -108,14 +108,15 @@ class GraphConstructor {
const VersionDef* versions,
const FunctionDefLibrary* library, Graph* g,
ShapeRefiner* refiner,
std::vector<std::pair<Node*, int>>* return_tensors) {
std::vector<std::pair<Node*, int>>* return_tensors,
std::vector<TensorId>* unused_input_map_keys) {
if (versions) {
TF_RETURN_IF_ERROR(CheckVersions(*versions, TF_GRAPH_DEF_VERSION,
TF_GRAPH_DEF_VERSION_MIN_PRODUCER,
"GraphDef", "graph"));
}
GraphConstructor c(opts, node_defs, versions, library, g, refiner,
return_tensors);
return_tensors, unused_input_map_keys);
const Status s = c.TryImport();
if (!s.ok()) c.Undo();
return s;
@ -126,7 +127,8 @@ class GraphConstructor {
const VersionDef* versions,
const FunctionDefLibrary* library, Graph* g,
ShapeRefiner* refiner,
std::vector<std::pair<Node*, int>>* return_tensors)
std::vector<std::pair<Node*, int>>* return_tensors,
std::vector<TensorId>* unused_input_map_keys)
: opts_(opts),
node_defs_(node_defs),
versions_(versions),
@ -134,7 +136,8 @@ class GraphConstructor {
g_(g),
original_versions_(g->versions()),
refiner_(refiner),
return_tensors_(return_tensors) {}
return_tensors_(return_tensors),
unused_input_map_keys_(unused_input_map_keys) {}
Status TryImport() {
TF_RETURN_IF_ERROR(EnsureNoNameCollisions());
@ -193,7 +196,13 @@ class GraphConstructor {
// May be null. Not owned.
std::vector<std::pair<Node*, int>>* return_tensors_;
// Mapping from node name to the index within node_defs_
// May be null. Not owned.
std::vector<TensorId>* unused_input_map_keys_;
// Intermediate datastructure used to populate `unused_input_map_keys_`.
std::set<TensorId> used_input_map_keys_;
// Mapping from node name to the index within node_defs_.
struct NodeInfo {
explicit NodeInfo(int i) : gdef_index(i), node(nullptr) {}
// std::unordered_map<> requires that we have a default constructor.
@ -583,6 +592,7 @@ void GraphConstructor::RemapNodeDefInputs(
for (int i = 0; i < node_def->input_size(); ++i) {
auto iter = opts_.input_map.find(ParseTensorName(node_def->input(i)));
if (iter == opts_.input_map.end()) continue;
used_input_map_keys_.insert(iter->first);
TensorId new_input = iter->second;
if (new_input.second == Graph::kControlSlot) {
@ -840,6 +850,16 @@ Status GraphConstructor::Convert() {
return errors::InvalidArgument(node_defs_.size() - processed,
" nodes in a cycle");
}
// Update unused_input_map_keys_
if (unused_input_map_keys_ != nullptr) {
for (const auto& pair : opts_.input_map) {
if (used_input_map_keys_.find(pair.first) == used_input_map_keys_.end()) {
unused_input_map_keys_->push_back(pair.first);
}
}
}
return Status::OK();
}
@ -943,8 +963,9 @@ Status GraphConstructor::MakeEdge(Node* src, int output_index, Node* dst,
Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
const GraphDef& gdef, Graph* g) {
ShapeRefiner refiner(gdef.versions().producer(), g->op_registry());
return GraphConstructor::Construct(opts, gdef.node(), &gdef.versions(),
&gdef.library(), g, &refiner, nullptr);
return GraphConstructor::Construct(
opts, gdef.node(), &gdef.versions(), &gdef.library(), g, &refiner,
/*return_tensors=*/nullptr, /*unused_input_map_keys=*/nullptr);
}
Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts,
@ -956,25 +977,33 @@ Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts,
node_defs.push_back(&n);
}
return GraphConstructor::Construct(opts, node_defs, nullptr, nullptr, g,
&refiner, nullptr);
&refiner, /*return_tensors=*/nullptr,
/*unused_input_map_keys=*/nullptr);
}
Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef,
Graph* g, ShapeRefiner* refiner,
std::vector<std::pair<Node*, int>>* return_tensors) {
std::vector<std::pair<Node*, int>>* return_tensors,
std::vector<TensorId>* unused_input_map_keys) {
if (!opts.return_tensors.empty()) {
if (return_tensors == nullptr) {
return errors::InvalidArgument(
"return_tensors argument to ImportNodeDef() must be non-null if "
"return_tensors argument to ImportGraphDef() must be non-null if "
"opts.return_tensors is non-empty");
}
if (!return_tensors->empty()) {
return errors::InvalidArgument(
"return_tensors argument to ImportNodeDef() should be empty (has "
"return_tensors argument to ImportGraphDef() should be empty (has "
"size ",
return_tensors->size(), ")");
}
}
if (unused_input_map_keys != nullptr && !unused_input_map_keys->empty()) {
return errors::InvalidArgument(
"If non-null, unused_input_map_keys argument to ImportGraphDef() should"
" be empty (has size ",
unused_input_map_keys->size(), ")");
}
ShapeRefiner default_refiner(gdef.versions().producer(), g->op_registry());
if (refiner == nullptr) {
@ -1007,7 +1036,7 @@ Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef,
return GraphConstructor::Construct(opts, gdef.node(), &gdef.versions(),
&gdef.library(), g, refiner,
return_tensors);
return_tensors, unused_input_map_keys);
}
void CopyGraph(const Graph& src, Graph* dest) {

View File

@ -52,17 +52,7 @@ extern Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
extern Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts,
gtl::ArraySlice<NodeDef> nodes, Graph* g);
// Add the graph in GraphDef gdef into an existing Graph *g.
//
// On error, returns non-OK and leaves *g unmodified.
//
// "shape_refiner" can be null. It should be non-null if the caller
// intends to add additional nodes to the graph after the import. This
// allows the caller to validate shapes of those nodes (since
// ShapeRefiner::AddNode must be called in topological order).
//
// TODO(ashankar): Push this mechanism and get rid of Session::Extend()
// as a means of enhancing an existing Graph.
// Options for calling ImportGraphDef().
struct ImportGraphDefOptions {
ImportGraphDefOptions() : skip_mapped_nodes(false) {}
@ -116,13 +106,30 @@ struct ImportGraphDefOptions {
// python API.
};
// Adds the graph in GraphDef `gdef` into an existing Graph `*g`.
//
// On error, returns non-OK and leaves `*g` unmodified.
//
// `refiner` can be null. It should be non-null if the caller
// intends to add additional nodes to the graph after the import. This
// allows the caller to validate shapes of those nodes (since
// ShapeRefiner::AddNode must be called in topological order).
//
// Each `return_tensors` entry is the requested node and output index. The index
// is included in case the returned tensor has been remapped according to
// `input_map`.
//
// If `unused_input_map_keys` is non-null, it should be empty and will be
// populated with any keys in `opts.input_map` that aren't used as an input to
// any node in `gdef`.
//
// TODO(ashankar): Push this mechanism and get rid of Session::Extend()
// as a means of enhancing an existing Graph.
extern Status ImportGraphDef(
const ImportGraphDefOptions& opts, const GraphDef& gdef, Graph* g,
ShapeRefiner* refiner,
std::vector<std::pair<Node*, int>>* return_tensors = nullptr);
std::vector<std::pair<Node*, int>>* return_tensors = nullptr,
std::vector<TensorId>* unused_input_map_keys = nullptr);
// Make a copy of "src" into "*dest".
//

View File

@ -68,17 +68,17 @@ class GraphConstructorTest : public ::testing::Test {
EXPECT_EQ(original_graph_description, GraphDebugString());
}
void ExpectError(
const string& gdef_ascii, const ImportGraphDefOptions& opts,
const std::vector<string>& expected_error_strs,
ShapeRefiner* refiner = nullptr,
std::vector<std::pair<Node*, int>>* return_tensors = nullptr) {
void ExpectError(const string& gdef_ascii, const ImportGraphDefOptions& opts,
const std::vector<string>& expected_error_strs,
ShapeRefiner* refiner = nullptr,
std::vector<std::pair<Node*, int>>* return_tensors = nullptr,
std::vector<TensorId>* unused_input_map_keys = nullptr) {
// Used to verify that errors don't change graph
const string original_graph_description = GraphDebugString();
Convert(gdef_ascii);
Status status =
ImportGraphDef(opts, gdef_, &graph_, refiner, return_tensors);
Status status = ImportGraphDef(opts, gdef_, &graph_, refiner,
return_tensors, unused_input_map_keys);
EXPECT_FALSE(status.ok());
for (const string& error : expected_error_strs) {
@ -97,9 +97,11 @@ class GraphConstructorTest : public ::testing::Test {
void ExpectOK(const string& gdef_ascii, const ImportGraphDefOptions& opts,
ShapeRefiner* refiner = nullptr,
std::vector<std::pair<Node*, int>>* return_tensors = nullptr) {
std::vector<std::pair<Node*, int>>* return_tensors = nullptr,
std::vector<TensorId>* unused_input_map_keys = nullptr) {
Convert(gdef_ascii);
Status s = ImportGraphDef(opts, gdef_, &graph_, refiner, return_tensors);
Status s = ImportGraphDef(opts, gdef_, &graph_, refiner, return_tensors,
unused_input_map_keys);
EXPECT_EQ(Status::OK(), s) << s;
}
@ -1279,8 +1281,9 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithControlEdges) {
// Create input_map containing control edges and use it to import more nodes
ImportGraphDefOptions opts;
opts.input_map[TensorId("W2", -1)] = TensorId("W1", -1);
opts.input_map[TensorId("W3", -1)] = TensorId("W1", -1);
const int kControlSlot = Graph::kControlSlot;
opts.input_map[TensorId("W2", kControlSlot)] = TensorId("W1", kControlSlot);
opts.input_map[TensorId("W3", kControlSlot)] = TensorId("W1", kControlSlot);
ExpectOK(
R"EOF(
node { name: 'W2' op: 'TestParams' }
@ -1316,7 +1319,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithControlEdges) {
// node
opts.prefix = "import";
opts.input_map.clear();
opts.input_map[TensorId("W1", -1)] = TensorId("W1", -1);
opts.input_map[TensorId("W1", kControlSlot)] = TensorId("W1", kControlSlot);
ExpectOK(
R"EOF(
node { name: 'W1' op: 'TestParams' }
@ -1343,7 +1346,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithBadControlEdge) {
// Create input_map with bad control edge mapping
ImportGraphDefOptions opts;
opts.input_map[TensorId("W2", -1)] = TensorId("W1", 0);
opts.input_map[TensorId("W2", Graph::kControlSlot)] = TensorId("W1", 0);
ExpectError(
R"EOF(
node { name: 'W2' op: 'TestParams' }
@ -1355,7 +1358,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithBadControlEdge) {
opts.input_map.clear();
// "W2:0" isn't used in the imported graph but still causes an error
opts.input_map[TensorId("W2", 0)] = TensorId("W1", -1);
opts.input_map[TensorId("W2", 0)] = TensorId("W1", Graph::kControlSlot);
ExpectError(
R"EOF(
node { name: 'W2' op: 'TestParams' }
@ -1396,7 +1399,8 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithMissingEntries) {
// Create input_map referencing node that doesn't exist in graph
ImportGraphDefOptions opts;
opts.input_map[TensorId("W2", -1)] = TensorId("DNE", -1);
const int kControlSlot = Graph::kControlSlot;
opts.input_map[TensorId("W2", kControlSlot)] = TensorId("DNE", kControlSlot);
ExpectError(
R"EOF(
node { name: 'W2' op: 'TestParams' }
@ -1433,6 +1437,49 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapDuplicateNodeNames) {
&refiner);
}
TEST_F(GraphConstructorTest, ImportGraphDef_InputMapUnusedKeys) {
ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry());
std::vector<TensorId> unused_input_map_keys;
// No input map
ImportGraphDefOptions opts;
ExpectOK(
"node { name: 'W1' op: 'TestParams' }"
"node { name: 'input' op: 'TestInput' }",
opts, &refiner, nullptr, &unused_input_map_keys);
EXPECT_TRUE(unused_input_map_keys.empty());
// Non-empty unused_input_map_keys
unused_input_map_keys.push_back(TensorId());
ExpectError("node { name: 'W2' op: 'TestParams' }", opts,
{"If non-null, unused_input_map_keys argument to ImportGraphDef()"
" should be empty (has size 1)"},
&refiner, nullptr, &unused_input_map_keys);
// Input map with some used, some unused keys
const int kControlSlot = Graph::kControlSlot;
unused_input_map_keys.clear();
opts.input_map[TensorId("W2", kControlSlot)] = TensorId("W1", kControlSlot);
opts.input_map[TensorId("new_input", 0)] = TensorId("input", 0);
opts.input_map[TensorId("new_input", 1)] = TensorId("input", 0);
opts.input_map[TensorId("new_input", kControlSlot)] =
TensorId("input", kControlSlot);
opts.input_map[TensorId("t1", 1)] = TensorId("input", 0);
ExpectOK(
R"EOF(
node { name: 'W2' op: 'TestParams' }
node { name: 'new_input' op: 'TestInput' input: [ '^W2' ] }
node { name: 't1' op: 'TestMul' input: [ 'new_input:0', 'new_input:1' ] }
node { name: 't2' op: 'TestMul' input: [ 't1:0', 't1:0' ] }
)EOF",
opts, &refiner, nullptr, &unused_input_map_keys);
std::vector<TensorId> expected_unused_keys = {
TensorId("new_input", kControlSlot), TensorId("t1", 1)};
EXPECT_EQ(unused_input_map_keys, expected_unused_keys);
}
TEST_F(GraphConstructorTest, ImportGraphDef_SkipMappedNodes_FullyMapped) {
ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry());
@ -1586,13 +1633,13 @@ TEST_F(GraphConstructorTest, ImportGraphDef_ReturnTensorsErrors) {
// Null return_tensors with non-empty opts.return_tensors
opts.return_tensors.push_back({"new_input", 0});
ExpectError("node { name: 'new_input' op: 'TestInput' }", opts,
{"return_tensors argument to ImportNodeDef() must be non-null "
{"return_tensors argument to ImportGraphDef() must be non-null "
"if opts.return_tensors is non-empty"});
// Non-empty return_tensors
return_tensors.push_back({nullptr, 0});
ExpectError("node { name: 'new_input' op: 'TestInput' }", opts,
{"return_tensors argument to ImportNodeDef() should be empty "
{"return_tensors argument to ImportGraphDef() should be empty "
"(has size 1)"},
nullptr, &return_tensors);