mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-08 07:38:39 +01:00
Add optional unused_input_map_keys output param to ImportGraphDef
This is a more general feature than that in the Python importer, which raises an exception if the input map contains unused names. PiperOrigin-RevId: 171029316
This commit is contained in:
parent
4f10a6597c
commit
9d7843c0a8
|
|
@ -108,14 +108,15 @@ class GraphConstructor {
|
|||
const VersionDef* versions,
|
||||
const FunctionDefLibrary* library, Graph* g,
|
||||
ShapeRefiner* refiner,
|
||||
std::vector<std::pair<Node*, int>>* return_tensors) {
|
||||
std::vector<std::pair<Node*, int>>* return_tensors,
|
||||
std::vector<TensorId>* unused_input_map_keys) {
|
||||
if (versions) {
|
||||
TF_RETURN_IF_ERROR(CheckVersions(*versions, TF_GRAPH_DEF_VERSION,
|
||||
TF_GRAPH_DEF_VERSION_MIN_PRODUCER,
|
||||
"GraphDef", "graph"));
|
||||
}
|
||||
GraphConstructor c(opts, node_defs, versions, library, g, refiner,
|
||||
return_tensors);
|
||||
return_tensors, unused_input_map_keys);
|
||||
const Status s = c.TryImport();
|
||||
if (!s.ok()) c.Undo();
|
||||
return s;
|
||||
|
|
@ -126,7 +127,8 @@ class GraphConstructor {
|
|||
const VersionDef* versions,
|
||||
const FunctionDefLibrary* library, Graph* g,
|
||||
ShapeRefiner* refiner,
|
||||
std::vector<std::pair<Node*, int>>* return_tensors)
|
||||
std::vector<std::pair<Node*, int>>* return_tensors,
|
||||
std::vector<TensorId>* unused_input_map_keys)
|
||||
: opts_(opts),
|
||||
node_defs_(node_defs),
|
||||
versions_(versions),
|
||||
|
|
@ -134,7 +136,8 @@ class GraphConstructor {
|
|||
g_(g),
|
||||
original_versions_(g->versions()),
|
||||
refiner_(refiner),
|
||||
return_tensors_(return_tensors) {}
|
||||
return_tensors_(return_tensors),
|
||||
unused_input_map_keys_(unused_input_map_keys) {}
|
||||
|
||||
Status TryImport() {
|
||||
TF_RETURN_IF_ERROR(EnsureNoNameCollisions());
|
||||
|
|
@ -193,7 +196,13 @@ class GraphConstructor {
|
|||
// May be null. Not owned.
|
||||
std::vector<std::pair<Node*, int>>* return_tensors_;
|
||||
|
||||
// Mapping from node name to the index within node_defs_
|
||||
// May be null. Not owned.
|
||||
std::vector<TensorId>* unused_input_map_keys_;
|
||||
|
||||
// Intermediate datastructure used to populate `unused_input_map_keys_`.
|
||||
std::set<TensorId> used_input_map_keys_;
|
||||
|
||||
// Mapping from node name to the index within node_defs_.
|
||||
struct NodeInfo {
|
||||
explicit NodeInfo(int i) : gdef_index(i), node(nullptr) {}
|
||||
// std::unordered_map<> requires that we have a default constructor.
|
||||
|
|
@ -583,6 +592,7 @@ void GraphConstructor::RemapNodeDefInputs(
|
|||
for (int i = 0; i < node_def->input_size(); ++i) {
|
||||
auto iter = opts_.input_map.find(ParseTensorName(node_def->input(i)));
|
||||
if (iter == opts_.input_map.end()) continue;
|
||||
used_input_map_keys_.insert(iter->first);
|
||||
|
||||
TensorId new_input = iter->second;
|
||||
if (new_input.second == Graph::kControlSlot) {
|
||||
|
|
@ -840,6 +850,16 @@ Status GraphConstructor::Convert() {
|
|||
return errors::InvalidArgument(node_defs_.size() - processed,
|
||||
" nodes in a cycle");
|
||||
}
|
||||
|
||||
// Update unused_input_map_keys_
|
||||
if (unused_input_map_keys_ != nullptr) {
|
||||
for (const auto& pair : opts_.input_map) {
|
||||
if (used_input_map_keys_.find(pair.first) == used_input_map_keys_.end()) {
|
||||
unused_input_map_keys_->push_back(pair.first);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
@ -943,8 +963,9 @@ Status GraphConstructor::MakeEdge(Node* src, int output_index, Node* dst,
|
|||
Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
|
||||
const GraphDef& gdef, Graph* g) {
|
||||
ShapeRefiner refiner(gdef.versions().producer(), g->op_registry());
|
||||
return GraphConstructor::Construct(opts, gdef.node(), &gdef.versions(),
|
||||
&gdef.library(), g, &refiner, nullptr);
|
||||
return GraphConstructor::Construct(
|
||||
opts, gdef.node(), &gdef.versions(), &gdef.library(), g, &refiner,
|
||||
/*return_tensors=*/nullptr, /*unused_input_map_keys=*/nullptr);
|
||||
}
|
||||
|
||||
Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts,
|
||||
|
|
@ -956,25 +977,33 @@ Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts,
|
|||
node_defs.push_back(&n);
|
||||
}
|
||||
return GraphConstructor::Construct(opts, node_defs, nullptr, nullptr, g,
|
||||
&refiner, nullptr);
|
||||
&refiner, /*return_tensors=*/nullptr,
|
||||
/*unused_input_map_keys=*/nullptr);
|
||||
}
|
||||
|
||||
Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef,
|
||||
Graph* g, ShapeRefiner* refiner,
|
||||
std::vector<std::pair<Node*, int>>* return_tensors) {
|
||||
std::vector<std::pair<Node*, int>>* return_tensors,
|
||||
std::vector<TensorId>* unused_input_map_keys) {
|
||||
if (!opts.return_tensors.empty()) {
|
||||
if (return_tensors == nullptr) {
|
||||
return errors::InvalidArgument(
|
||||
"return_tensors argument to ImportNodeDef() must be non-null if "
|
||||
"return_tensors argument to ImportGraphDef() must be non-null if "
|
||||
"opts.return_tensors is non-empty");
|
||||
}
|
||||
if (!return_tensors->empty()) {
|
||||
return errors::InvalidArgument(
|
||||
"return_tensors argument to ImportNodeDef() should be empty (has "
|
||||
"return_tensors argument to ImportGraphDef() should be empty (has "
|
||||
"size ",
|
||||
return_tensors->size(), ")");
|
||||
}
|
||||
}
|
||||
if (unused_input_map_keys != nullptr && !unused_input_map_keys->empty()) {
|
||||
return errors::InvalidArgument(
|
||||
"If non-null, unused_input_map_keys argument to ImportGraphDef() should"
|
||||
" be empty (has size ",
|
||||
unused_input_map_keys->size(), ")");
|
||||
}
|
||||
|
||||
ShapeRefiner default_refiner(gdef.versions().producer(), g->op_registry());
|
||||
if (refiner == nullptr) {
|
||||
|
|
@ -1007,7 +1036,7 @@ Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef,
|
|||
|
||||
return GraphConstructor::Construct(opts, gdef.node(), &gdef.versions(),
|
||||
&gdef.library(), g, refiner,
|
||||
return_tensors);
|
||||
return_tensors, unused_input_map_keys);
|
||||
}
|
||||
|
||||
void CopyGraph(const Graph& src, Graph* dest) {
|
||||
|
|
|
|||
|
|
@ -52,17 +52,7 @@ extern Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
|
|||
extern Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts,
|
||||
gtl::ArraySlice<NodeDef> nodes, Graph* g);
|
||||
|
||||
// Add the graph in GraphDef gdef into an existing Graph *g.
|
||||
//
|
||||
// On error, returns non-OK and leaves *g unmodified.
|
||||
//
|
||||
// "shape_refiner" can be null. It should be non-null if the caller
|
||||
// intends to add additional nodes to the graph after the import. This
|
||||
// allows the caller to validate shapes of those nodes (since
|
||||
// ShapeRefiner::AddNode must be called in topological order).
|
||||
//
|
||||
// TODO(ashankar): Push this mechanism and get rid of Session::Extend()
|
||||
// as a means of enhancing an existing Graph.
|
||||
// Options for calling ImportGraphDef().
|
||||
struct ImportGraphDefOptions {
|
||||
ImportGraphDefOptions() : skip_mapped_nodes(false) {}
|
||||
|
||||
|
|
@ -116,13 +106,30 @@ struct ImportGraphDefOptions {
|
|||
// python API.
|
||||
};
|
||||
|
||||
// Adds the graph in GraphDef `gdef` into an existing Graph `*g`.
|
||||
//
|
||||
// On error, returns non-OK and leaves `*g` unmodified.
|
||||
//
|
||||
// `refiner` can be null. It should be non-null if the caller
|
||||
// intends to add additional nodes to the graph after the import. This
|
||||
// allows the caller to validate shapes of those nodes (since
|
||||
// ShapeRefiner::AddNode must be called in topological order).
|
||||
//
|
||||
// Each `return_tensors` entry is the requested node and output index. The index
|
||||
// is included in case the returned tensor has been remapped according to
|
||||
// `input_map`.
|
||||
//
|
||||
// If `unused_input_map_keys` is non-null, it should be empty and will be
|
||||
// populated with any keys in `opts.input_map` that aren't used as an input to
|
||||
// any node in `gdef`.
|
||||
//
|
||||
// TODO(ashankar): Push this mechanism and get rid of Session::Extend()
|
||||
// as a means of enhancing an existing Graph.
|
||||
extern Status ImportGraphDef(
|
||||
const ImportGraphDefOptions& opts, const GraphDef& gdef, Graph* g,
|
||||
ShapeRefiner* refiner,
|
||||
std::vector<std::pair<Node*, int>>* return_tensors = nullptr);
|
||||
std::vector<std::pair<Node*, int>>* return_tensors = nullptr,
|
||||
std::vector<TensorId>* unused_input_map_keys = nullptr);
|
||||
|
||||
// Make a copy of "src" into "*dest".
|
||||
//
|
||||
|
|
|
|||
|
|
@ -68,17 +68,17 @@ class GraphConstructorTest : public ::testing::Test {
|
|||
EXPECT_EQ(original_graph_description, GraphDebugString());
|
||||
}
|
||||
|
||||
void ExpectError(
|
||||
const string& gdef_ascii, const ImportGraphDefOptions& opts,
|
||||
const std::vector<string>& expected_error_strs,
|
||||
ShapeRefiner* refiner = nullptr,
|
||||
std::vector<std::pair<Node*, int>>* return_tensors = nullptr) {
|
||||
void ExpectError(const string& gdef_ascii, const ImportGraphDefOptions& opts,
|
||||
const std::vector<string>& expected_error_strs,
|
||||
ShapeRefiner* refiner = nullptr,
|
||||
std::vector<std::pair<Node*, int>>* return_tensors = nullptr,
|
||||
std::vector<TensorId>* unused_input_map_keys = nullptr) {
|
||||
// Used to verify that errors don't change graph
|
||||
const string original_graph_description = GraphDebugString();
|
||||
|
||||
Convert(gdef_ascii);
|
||||
Status status =
|
||||
ImportGraphDef(opts, gdef_, &graph_, refiner, return_tensors);
|
||||
Status status = ImportGraphDef(opts, gdef_, &graph_, refiner,
|
||||
return_tensors, unused_input_map_keys);
|
||||
EXPECT_FALSE(status.ok());
|
||||
|
||||
for (const string& error : expected_error_strs) {
|
||||
|
|
@ -97,9 +97,11 @@ class GraphConstructorTest : public ::testing::Test {
|
|||
|
||||
void ExpectOK(const string& gdef_ascii, const ImportGraphDefOptions& opts,
|
||||
ShapeRefiner* refiner = nullptr,
|
||||
std::vector<std::pair<Node*, int>>* return_tensors = nullptr) {
|
||||
std::vector<std::pair<Node*, int>>* return_tensors = nullptr,
|
||||
std::vector<TensorId>* unused_input_map_keys = nullptr) {
|
||||
Convert(gdef_ascii);
|
||||
Status s = ImportGraphDef(opts, gdef_, &graph_, refiner, return_tensors);
|
||||
Status s = ImportGraphDef(opts, gdef_, &graph_, refiner, return_tensors,
|
||||
unused_input_map_keys);
|
||||
EXPECT_EQ(Status::OK(), s) << s;
|
||||
}
|
||||
|
||||
|
|
@ -1279,8 +1281,9 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithControlEdges) {
|
|||
|
||||
// Create input_map containing control edges and use it to import more nodes
|
||||
ImportGraphDefOptions opts;
|
||||
opts.input_map[TensorId("W2", -1)] = TensorId("W1", -1);
|
||||
opts.input_map[TensorId("W3", -1)] = TensorId("W1", -1);
|
||||
const int kControlSlot = Graph::kControlSlot;
|
||||
opts.input_map[TensorId("W2", kControlSlot)] = TensorId("W1", kControlSlot);
|
||||
opts.input_map[TensorId("W3", kControlSlot)] = TensorId("W1", kControlSlot);
|
||||
ExpectOK(
|
||||
R"EOF(
|
||||
node { name: 'W2' op: 'TestParams' }
|
||||
|
|
@ -1316,7 +1319,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithControlEdges) {
|
|||
// node
|
||||
opts.prefix = "import";
|
||||
opts.input_map.clear();
|
||||
opts.input_map[TensorId("W1", -1)] = TensorId("W1", -1);
|
||||
opts.input_map[TensorId("W1", kControlSlot)] = TensorId("W1", kControlSlot);
|
||||
ExpectOK(
|
||||
R"EOF(
|
||||
node { name: 'W1' op: 'TestParams' }
|
||||
|
|
@ -1343,7 +1346,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithBadControlEdge) {
|
|||
|
||||
// Create input_map with bad control edge mapping
|
||||
ImportGraphDefOptions opts;
|
||||
opts.input_map[TensorId("W2", -1)] = TensorId("W1", 0);
|
||||
opts.input_map[TensorId("W2", Graph::kControlSlot)] = TensorId("W1", 0);
|
||||
ExpectError(
|
||||
R"EOF(
|
||||
node { name: 'W2' op: 'TestParams' }
|
||||
|
|
@ -1355,7 +1358,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithBadControlEdge) {
|
|||
|
||||
opts.input_map.clear();
|
||||
// "W2:0" isn't used in the imported graph but still causes an error
|
||||
opts.input_map[TensorId("W2", 0)] = TensorId("W1", -1);
|
||||
opts.input_map[TensorId("W2", 0)] = TensorId("W1", Graph::kControlSlot);
|
||||
ExpectError(
|
||||
R"EOF(
|
||||
node { name: 'W2' op: 'TestParams' }
|
||||
|
|
@ -1396,7 +1399,8 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithMissingEntries) {
|
|||
|
||||
// Create input_map referencing node that doesn't exist in graph
|
||||
ImportGraphDefOptions opts;
|
||||
opts.input_map[TensorId("W2", -1)] = TensorId("DNE", -1);
|
||||
const int kControlSlot = Graph::kControlSlot;
|
||||
opts.input_map[TensorId("W2", kControlSlot)] = TensorId("DNE", kControlSlot);
|
||||
ExpectError(
|
||||
R"EOF(
|
||||
node { name: 'W2' op: 'TestParams' }
|
||||
|
|
@ -1433,6 +1437,49 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapDuplicateNodeNames) {
|
|||
&refiner);
|
||||
}
|
||||
|
||||
TEST_F(GraphConstructorTest, ImportGraphDef_InputMapUnusedKeys) {
|
||||
ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry());
|
||||
|
||||
std::vector<TensorId> unused_input_map_keys;
|
||||
|
||||
// No input map
|
||||
ImportGraphDefOptions opts;
|
||||
ExpectOK(
|
||||
"node { name: 'W1' op: 'TestParams' }"
|
||||
"node { name: 'input' op: 'TestInput' }",
|
||||
opts, &refiner, nullptr, &unused_input_map_keys);
|
||||
EXPECT_TRUE(unused_input_map_keys.empty());
|
||||
|
||||
// Non-empty unused_input_map_keys
|
||||
unused_input_map_keys.push_back(TensorId());
|
||||
ExpectError("node { name: 'W2' op: 'TestParams' }", opts,
|
||||
{"If non-null, unused_input_map_keys argument to ImportGraphDef()"
|
||||
" should be empty (has size 1)"},
|
||||
&refiner, nullptr, &unused_input_map_keys);
|
||||
|
||||
// Input map with some used, some unused keys
|
||||
const int kControlSlot = Graph::kControlSlot;
|
||||
unused_input_map_keys.clear();
|
||||
opts.input_map[TensorId("W2", kControlSlot)] = TensorId("W1", kControlSlot);
|
||||
opts.input_map[TensorId("new_input", 0)] = TensorId("input", 0);
|
||||
opts.input_map[TensorId("new_input", 1)] = TensorId("input", 0);
|
||||
opts.input_map[TensorId("new_input", kControlSlot)] =
|
||||
TensorId("input", kControlSlot);
|
||||
opts.input_map[TensorId("t1", 1)] = TensorId("input", 0);
|
||||
ExpectOK(
|
||||
R"EOF(
|
||||
node { name: 'W2' op: 'TestParams' }
|
||||
node { name: 'new_input' op: 'TestInput' input: [ '^W2' ] }
|
||||
node { name: 't1' op: 'TestMul' input: [ 'new_input:0', 'new_input:1' ] }
|
||||
node { name: 't2' op: 'TestMul' input: [ 't1:0', 't1:0' ] }
|
||||
)EOF",
|
||||
opts, &refiner, nullptr, &unused_input_map_keys);
|
||||
|
||||
std::vector<TensorId> expected_unused_keys = {
|
||||
TensorId("new_input", kControlSlot), TensorId("t1", 1)};
|
||||
EXPECT_EQ(unused_input_map_keys, expected_unused_keys);
|
||||
}
|
||||
|
||||
TEST_F(GraphConstructorTest, ImportGraphDef_SkipMappedNodes_FullyMapped) {
|
||||
ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry());
|
||||
|
||||
|
|
@ -1586,13 +1633,13 @@ TEST_F(GraphConstructorTest, ImportGraphDef_ReturnTensorsErrors) {
|
|||
// Null return_tensors with non-empty opts.return_tensors
|
||||
opts.return_tensors.push_back({"new_input", 0});
|
||||
ExpectError("node { name: 'new_input' op: 'TestInput' }", opts,
|
||||
{"return_tensors argument to ImportNodeDef() must be non-null "
|
||||
{"return_tensors argument to ImportGraphDef() must be non-null "
|
||||
"if opts.return_tensors is non-empty"});
|
||||
|
||||
// Non-empty return_tensors
|
||||
return_tensors.push_back({nullptr, 0});
|
||||
ExpectError("node { name: 'new_input' op: 'TestInput' }", opts,
|
||||
{"return_tensors argument to ImportNodeDef() should be empty "
|
||||
{"return_tensors argument to ImportGraphDef() should be empty "
|
||||
"(has size 1)"},
|
||||
nullptr, &return_tensors);
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user