Add ability to fetch return nodes and unused input mappings from C API GraphDef import

This change introduces yet another ImportGraphDef function to the C
API (TF_GraphImportGraphDefWithResults), but this one has extensible
return values so we shouldn't have to add more in the future.

This change also modifies the ImportGraphDef C interface to manage all
string data for the user.

PiperOrigin-RevId: 173894710
This commit is contained in:
Skye Wanderman-Milne 2017-10-30 08:07:11 -07:00 committed by TensorFlower Gardener
parent ef4490f637
commit ce02381980
6 changed files with 369 additions and 84 deletions

View File

@ -86,6 +86,7 @@ using tensorflow::errors::FailedPrecondition;
using tensorflow::errors::InvalidArgument;
using tensorflow::gtl::ArraySlice;
using tensorflow::mutex_lock;
using tensorflow::string;
using tensorflow::strings::StrCat;
extern "C" {
@ -366,7 +367,7 @@ namespace {
// Reset helper for converting character arrays to string vectors.
void TF_Reset_Helper(const TF_SessionOptions* opt, const char** containers,
int ncontainers, TF_Status* status) {
std::vector<tensorflow::string> container_names(ncontainers);
std::vector<string> container_names(ncontainers);
for (int i = 0; i < ncontainers; ++i) {
container_names[i] = containers[i];
}
@ -482,7 +483,7 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
const char* limit = input + src_size;
*dst = Tensor(static_cast<DataType>(src->dtype), src->shape);
auto dstarray = dst->flat<tensorflow::string>();
auto dstarray = dst->flat<string>();
for (tensorflow::int64 i = 0; i < num_elements; ++i) {
tensorflow::uint64 offset =
reinterpret_cast<const tensorflow::uint64*>(input)[i];
@ -556,9 +557,9 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
// Compute bytes needed for encoding.
size_t size = 0;
const auto& srcarray = src.flat<tensorflow::string>();
const auto& srcarray = src.flat<string>();
for (int i = 0; i < srcarray.size(); ++i) {
const tensorflow::string& s = srcarray(i);
const string& s = srcarray(i);
// uint64 starting_offset, TF_StringEncode-d string.
size += sizeof(tensorflow::uint64) + TF_StringEncodedSize(s.size());
}
@ -572,7 +573,7 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
for (int i = 0; i < srcarray.size(); ++i) {
*offsets = (dst - data_start);
offsets++;
const tensorflow::string& s = srcarray(i);
const string& s = srcarray(i);
size_t consumed = TF_StringEncode(s.data(), s.size(), dst, dst_len, status);
if (!status->status.ok()) {
status->status = InvalidArgument(
@ -637,10 +638,9 @@ static void TF_Run_Setup(int noutputs, TF_Tensor** c_outputs,
}
}
static bool TF_Run_Inputs(
TF_Tensor* const* c_inputs,
std::vector<std::pair<tensorflow::string, Tensor>>* input_pairs,
TF_Status* status) {
static bool TF_Run_Inputs(TF_Tensor* const* c_inputs,
std::vector<std::pair<string, Tensor>>* input_pairs,
TF_Status* status) {
const int ninputs = input_pairs->size();
for (int i = 0; i < ninputs; ++i) {
status->status = TF_TensorToTensor(c_inputs[i], &(*input_pairs)[i].second);
@ -652,13 +652,12 @@ static bool TF_Run_Inputs(
static void TF_Run_Helper(
Session* session, const char* handle, const TF_Buffer* run_options,
// Input tensors
const std::vector<std::pair<tensorflow::string, Tensor>>& input_pairs,
const std::vector<std::pair<string, Tensor>>& input_pairs,
// Output tensors
const std::vector<tensorflow::string>& output_tensor_names,
TF_Tensor** c_outputs,
const std::vector<string>& output_tensor_names, TF_Tensor** c_outputs,
// Target nodes
const std::vector<tensorflow::string>& target_oper_names,
TF_Buffer* run_metadata, TF_Status* status) {
const std::vector<string>& target_oper_names, TF_Buffer* run_metadata,
TF_Status* status) {
const int noutputs = output_tensor_names.size();
std::vector<Tensor> outputs(noutputs);
Status result;
@ -718,16 +717,16 @@ void TF_Run(TF_DeprecatedSession* s, const TF_Buffer* run_options,
const char** c_target_oper_names, int ntargets,
TF_Buffer* run_metadata, TF_Status* status) {
TF_Run_Setup(noutputs, c_outputs, status);
std::vector<std::pair<tensorflow::string, Tensor>> input_pairs(ninputs);
std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return;
for (int i = 0; i < ninputs; ++i) {
input_pairs[i].first = c_input_names[i];
}
std::vector<tensorflow::string> output_names(noutputs);
std::vector<string> output_names(noutputs);
for (int i = 0; i < noutputs; ++i) {
output_names[i] = c_output_names[i];
}
std::vector<tensorflow::string> target_oper_names(ntargets);
std::vector<string> target_oper_names(ntargets);
for (int i = 0; i < ntargets; ++i) {
target_oper_names[i] = c_target_oper_names[i];
}
@ -745,9 +744,9 @@ void TF_PRunSetup(TF_DeprecatedSession* s,
const char** handle, TF_Status* status) {
*handle = nullptr;
std::vector<tensorflow::string> input_names(ninputs);
std::vector<tensorflow::string> output_names(noutputs);
std::vector<tensorflow::string> target_oper_names(ntargets);
std::vector<string> input_names(ninputs);
std::vector<string> output_names(noutputs);
std::vector<string> target_oper_names(ntargets);
for (int i = 0; i < ninputs; ++i) {
input_names[i] = c_input_names[i];
}
@ -757,7 +756,7 @@ void TF_PRunSetup(TF_DeprecatedSession* s,
for (int i = 0; i < ntargets; ++i) {
target_oper_names[i] = c_target_oper_names[i];
}
tensorflow::string new_handle;
string new_handle;
status->status = s->session->PRunSetup(input_names, output_names,
target_oper_names, &new_handle);
if (status->status.ok()) {
@ -776,17 +775,17 @@ void TF_PRun(TF_DeprecatedSession* s, const char* handle,
const char** c_target_oper_names, int ntargets,
TF_Status* status) {
TF_Run_Setup(noutputs, c_outputs, status);
std::vector<std::pair<tensorflow::string, Tensor>> input_pairs(ninputs);
std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return;
for (int i = 0; i < ninputs; ++i) {
input_pairs[i].first = c_input_names[i];
}
std::vector<tensorflow::string> output_names(noutputs);
std::vector<string> output_names(noutputs);
for (int i = 0; i < noutputs; ++i) {
output_names[i] = c_output_names[i];
}
std::vector<tensorflow::string> target_oper_names(ntargets);
std::vector<string> target_oper_names(ntargets);
for (int i = 0; i < ntargets; ++i) {
target_oper_names[i] = c_target_oper_names[i];
}
@ -881,7 +880,7 @@ TF_Operation* ToOperation(Node* node) {
return static_cast<TF_Operation*>(static_cast<void*>(node));
}
tensorflow::string OutputName(const TF_Output& output) {
string OutputName(const TF_Output& output) {
return StrCat(output.oper->node.name(), ":", output.index);
}
@ -1254,7 +1253,7 @@ void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name,
return;
}
desc->colocation_constraints.clear();
for (const tensorflow::string& location : attr_value.list().s()) {
for (const string& location : attr_value.list().s()) {
desc->colocation_constraints.insert(location);
}
} else {
@ -1276,8 +1275,8 @@ static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
if (!desc->colocation_constraints.empty()) {
desc->node_builder.Attr(
tensorflow::kColocationAttrName,
std::vector<tensorflow::string>(desc->colocation_constraints.begin(),
desc->colocation_constraints.end()));
std::vector<string>(desc->colocation_constraints.begin(),
desc->colocation_constraints.end()));
}
status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret);
@ -1500,7 +1499,7 @@ TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper,
for (int i = 0; i < oper->node.op_def().attr_size(); ++i) {
const auto& a = oper->node.op_def().attr(i);
if (a.name().compare(attr_name) != 0) continue;
const tensorflow::string& typestr = a.type();
const string& typestr = a.type();
if (typestr == "list(string)") {
metadata.type = TF_ATTR_STRING;
} else if (typestr == "list(int)") {
@ -1580,7 +1579,7 @@ void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name,
const auto len = std::min(max_values, attr->list().s_size());
char* p = static_cast<char*>(storage);
for (int i = 0; i < len; ++i) {
const tensorflow::string& s = attr->list().s(i);
const string& s = attr->list().s(i);
values[i] = p;
lengths[i] = s.size();
if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) {
@ -1824,7 +1823,11 @@ void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts,
void TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions* opts,
const char* src_name,
int src_index, TF_Output dst) {
opts->opts.input_map[TensorId(src_name, src_index)] = ToTensorId(dst);
opts->tensor_id_data.push_back(src_name);
const string& src_name_str = opts->tensor_id_data.back();
// We don't need to store dst's name in tensor_id_data, since `dst` must
// outlive the ImportGraphDef call.
opts->opts.input_map[TensorId(src_name_str, src_index)] = ToTensorId(dst);
}
void TF_ImportGraphDefOptionsRemapControlDependency(
@ -1840,7 +1843,9 @@ extern void TF_ImportGraphDefOptionsAddControlDependency(
void TF_ImportGraphDefOptionsAddReturnOutput(TF_ImportGraphDefOptions* opts,
const char* oper_name, int index) {
opts->opts.return_tensors.push_back({oper_name, index});
opts->tensor_id_data.push_back(oper_name);
const string& oper_name_str = opts->tensor_id_data.back();
opts->opts.return_tensors.emplace_back(oper_name_str, index);
}
int TF_ImportGraphDefOptionsNumReturnOutputs(
@ -1848,15 +1853,116 @@ int TF_ImportGraphDefOptionsNumReturnOutputs(
return opts->opts.return_tensors.size();
}
void TF_ImportGraphDefOptionsAddReturnOperation(TF_ImportGraphDefOptions* opts,
const char* oper_name) {
opts->opts.return_nodes.push_back(oper_name);
}
int TF_ImportGraphDefOptionsNumReturnOperations(
const TF_ImportGraphDefOptions* opts) {
return opts->opts.return_nodes.size();
}
void TF_ImportGraphDefResultsReturnOutputs(TF_ImportGraphDefResults* results,
int* num_outputs,
TF_Output** outputs) {
*num_outputs = results->return_tensors.size();
*outputs = results->return_tensors.data();
}
void TF_ImportGraphDefResultsReturnOperations(TF_ImportGraphDefResults* results,
int* num_opers,
TF_Operation*** opers) {
*num_opers = results->return_nodes.size();
*opers = results->return_nodes.data();
}
void TF_ImportGraphDefResultsUnusedInputMappings(
TF_ImportGraphDefResults* results, int* num_unused_input_mappings,
const char*** src_names, int** src_indexes) {
*num_unused_input_mappings = results->unused_key_names.size();
*src_names = results->unused_key_names.data();
*src_indexes = results->unused_key_indexes.data();
}
void TF_DeleteImportGraphDefResults(TF_ImportGraphDefResults* results) {
delete results;
}
static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def,
const TF_ImportGraphDefOptions* opts,
TF_Output* return_outputs,
int num_return_outputs, TF_Status* status)
TF_ImportGraphDefResults* tf_results,
TF_Status* status)
EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
if (num_return_outputs != opts->opts.return_tensors.size()) {
const int last_node_id = graph->graph.num_node_ids();
tensorflow::ImportGraphDefResults results;
status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph,
&graph->refiner, &results);
if (!status->status.ok()) return;
// Add new nodes to name_map
for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) {
auto* node = graph->graph.FindNodeId(i);
if (node != nullptr) graph->name_map[node->name()] = node;
}
// Populate return_tensors
DCHECK(tf_results->return_tensors.empty());
tf_results->return_tensors.resize(results.return_tensors.size());
for (int i = 0; i < results.return_tensors.size(); ++i) {
tf_results->return_tensors[i].oper =
ToOperation(results.return_tensors[i].first);
tf_results->return_tensors[i].index = results.return_tensors[i].second;
}
// Populate return_nodes
DCHECK(tf_results->return_nodes.empty());
tf_results->return_nodes.resize(results.return_nodes.size());
for (int i = 0; i < results.return_nodes.size(); ++i) {
tf_results->return_nodes[i] = ToOperation(results.return_nodes[i]);
}
// Populate unused map keys
DCHECK(tf_results->unused_key_names.empty());
DCHECK(tf_results->unused_key_indexes.empty());
DCHECK(tf_results->unused_key_names_data.empty());
tf_results->unused_key_names.resize(results.unused_input_map_keys.size());
tf_results->unused_key_indexes.resize(results.unused_input_map_keys.size());
for (int i = 0; i < results.unused_input_map_keys.size(); ++i) {
TensorId id = results.unused_input_map_keys[i];
tf_results->unused_key_names_data.push_back(id.first.ToString());
tf_results->unused_key_names[i] =
tf_results->unused_key_names_data.back().c_str();
tf_results->unused_key_indexes[i] = id.second;
}
}
TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults(
TF_Graph* graph, const TF_Buffer* graph_def,
const TF_ImportGraphDefOptions* options, TF_Status* status) {
GraphDef def;
if (!def.ParseFromArray(graph_def->data, graph_def->length)) {
status->status = InvalidArgument("Invalid GraphDef");
return nullptr;
}
auto results = new TF_ImportGraphDefResults();
mutex_lock l(graph->mu);
GraphImportGraphDefLocked(graph, def, options, results, status);
if (!status->status.ok()) {
delete results;
return nullptr;
}
return results;
}
void TF_GraphImportGraphDefWithReturnOutputs(
TF_Graph* graph, const TF_Buffer* graph_def,
const TF_ImportGraphDefOptions* options, TF_Output* return_outputs,
int num_return_outputs, TF_Status* status) {
if (num_return_outputs != options->opts.return_tensors.size()) {
status->status = InvalidArgument("Expected 'num_return_outputs' to be ",
opts->opts.return_tensors.size(), ", got ",
num_return_outputs);
options->opts.return_tensors.size(),
", got ", num_return_outputs);
return;
}
if (num_return_outputs > 0 && return_outputs == nullptr) {
@ -1864,41 +1970,25 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def,
"'return_outputs' must be preallocated to length ", num_return_outputs);
return;
}
const int last_node_id = graph->graph.num_node_ids();
tensorflow::ImportGraphDefResults results;
status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph,
&graph->refiner, &results);
if (!status->status.ok()) return;
for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) {
auto* node = graph->graph.FindNodeId(i);
if (node != nullptr) graph->name_map[node->name()] = node;
}
DCHECK_EQ(results.return_tensors.size(), num_return_outputs);
for (int i = 0; i < num_return_outputs; ++i) {
return_outputs[i].oper = ToOperation(results.return_tensors[i].first);
return_outputs[i].index = results.return_tensors[i].second;
}
}
void TF_GraphImportGraphDefWithReturnOutputs(
TF_Graph* graph, const TF_Buffer* graph_def,
const TF_ImportGraphDefOptions* opts, TF_Output* return_outputs,
int num_return_outputs, TF_Status* status) {
GraphDef def;
if (!def.ParseFromArray(graph_def->data, graph_def->length)) {
status->status = InvalidArgument("Invalid GraphDef");
return;
}
TF_ImportGraphDefResults results;
mutex_lock l(graph->mu);
GraphImportGraphDefLocked(graph, def, opts, return_outputs,
num_return_outputs, status);
GraphImportGraphDefLocked(graph, def, options, &results, status);
DCHECK_EQ(results.return_tensors.size(), num_return_outputs);
memcpy(return_outputs, results.return_tensors.data(),
num_return_outputs * sizeof(TF_Output));
}
void TF_GraphImportGraphDef(TF_Graph* graph, const TF_Buffer* graph_def,
const TF_ImportGraphDefOptions* options,
TF_Status* status) {
TF_GraphImportGraphDefWithReturnOutputs(graph, graph_def, options, nullptr, 0,
status);
TF_ImportGraphDefResults* results =
TF_GraphImportGraphDefWithResults(graph, graph_def, options, status);
TF_DeleteImportGraphDefResults(results);
}
// While loop functions -------------------------------------------------------
@ -1930,7 +2020,7 @@ Status CopyGraph(Graph* src_graph, Graph* dst_graph,
tensorflow::ShapeRefiner* dst_refiner,
const TF_Output* src_inputs,
const std::vector<tensorflow::Output>& dst_inputs,
const tensorflow::string& prefix,
const string& prefix,
const std::vector<tensorflow::Operation>& control_deps,
const TF_Output* nodes_to_return, int nreturn_nodes,
std::vector<tensorflow::Output>* return_nodes) {
@ -2257,9 +2347,9 @@ TF_Session* TF_LoadSessionFromSavedModel(
return nullptr;
}
std::unordered_set<tensorflow::string> tag_set;
std::unordered_set<string> tag_set;
for (int i = 0; i < tags_len; i++) {
tag_set.insert(tensorflow::string(tags[i]));
tag_set.insert(string(tags[i]));
}
tensorflow::SavedModelBundle bundle;
@ -2275,8 +2365,9 @@ TF_Session* TF_LoadSessionFromSavedModel(
// TODO(jhseu): When Session is modified to take Graphs instead of
// GraphDefs, return the Graph generated in LoadSavedModel().
TF_ImportGraphDefOptions* import_opts = TF_NewImportGraphDefOptions();
TF_ImportGraphDefResults results;
GraphImportGraphDefLocked(graph, bundle.meta_graph_def.graph_def(),
import_opts, nullptr, 0, status);
import_opts, &results, status);
TF_DeleteImportGraphDefOptions(import_opts);
if (TF_GetCode(status) != TF_OK) return nullptr;
@ -2372,20 +2463,20 @@ void TF_SessionRun(TF_Session* session, const TF_Buffer* run_options,
TF_Run_Setup(noutputs, output_values, status);
// Convert from TF_Output and TF_Tensor to a string and Tensor.
std::vector<std::pair<tensorflow::string, Tensor>> input_pairs(ninputs);
std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
if (!TF_Run_Inputs(input_values, &input_pairs, status)) return;
for (int i = 0; i < ninputs; ++i) {
input_pairs[i].first = OutputName(inputs[i]);
}
// Convert from TF_Output to string names.
std::vector<tensorflow::string> output_names(noutputs);
std::vector<string> output_names(noutputs);
for (int i = 0; i < noutputs; ++i) {
output_names[i] = OutputName(outputs[i]);
}
// Convert from TF_Operation* to string names.
std::vector<tensorflow::string> target_names(ntargets);
std::vector<string> target_names(ntargets);
for (int i = 0; i < ntargets; ++i) {
target_names[i] = target_opers[i]->node.name();
}
@ -2406,22 +2497,22 @@ void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs,
return;
}
std::vector<tensorflow::string> input_names(ninputs);
std::vector<string> input_names(ninputs);
for (int i = 0; i < ninputs; ++i) {
input_names[i] = OutputName(inputs[i]);
}
std::vector<tensorflow::string> output_names(noutputs);
std::vector<string> output_names(noutputs);
for (int i = 0; i < noutputs; ++i) {
output_names[i] = OutputName(outputs[i]);
}
std::vector<tensorflow::string> target_names(ntargets);
std::vector<string> target_names(ntargets);
for (int i = 0; i < ntargets; ++i) {
target_names[i] = target_opers[i]->node.name();
}
tensorflow::string new_handle;
string new_handle;
status->status = session->session->PRunSetup(input_names, output_names,
target_names, &new_handle);
if (status->status.ok()) {
@ -2452,20 +2543,20 @@ void TF_SessionPRun(TF_Session* session, const char* handle,
TF_Run_Setup(noutputs, output_values, status);
// Convert from TF_Output and TF_Tensor to a string and Tensor.
std::vector<std::pair<tensorflow::string, Tensor>> input_pairs(ninputs);
std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
if (!TF_Run_Inputs(input_values, &input_pairs, status)) return;
for (int i = 0; i < ninputs; ++i) {
input_pairs[i].first = OutputName(inputs[i]);
}
// Convert from TF_Output to string names.
std::vector<tensorflow::string> output_names(noutputs);
std::vector<string> output_names(noutputs);
for (int i = 0; i < noutputs; ++i) {
output_names[i] = OutputName(outputs[i]);
}
// Convert from TF_Operation* to string names.
std::vector<tensorflow::string> target_names(ntargets);
std::vector<string> target_names(ntargets);
for (int i = 0; i < ntargets; ++i) {
target_names[i] = target_opers[i]->node.name();
}

View File

@ -914,7 +914,62 @@ TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddReturnOutput(
TF_CAPI_EXPORT extern int TF_ImportGraphDefOptionsNumReturnOutputs(
const TF_ImportGraphDefOptions* opts);
// Add an operation in `graph_def` to be returned via the `return_opers` output
// parameter of TF_GraphImportGraphDef().
TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddReturnOperation(
TF_ImportGraphDefOptions* opts, const char* oper_name);
// Returns the number of return operations added via
// TF_ImportGraphDefOptionsAddReturnOperation().
TF_CAPI_EXPORT extern int TF_ImportGraphDefOptionsNumReturnOperations(
const TF_ImportGraphDefOptions* opts);
// TF_ImportGraphDefResults holds results that are generated by
// TF_GraphImportGraphDefWithResults().
typedef struct TF_ImportGraphDefResults TF_ImportGraphDefResults;
// Fetches the return outputs requested via
// TF_ImportGraphDefOptionsAddReturnOutput(). The number of fetched outputs is
// returned in `num_outputs`. The array of return outputs is returned in
// `outputs`. `*outputs` is owned by and has the lifetime of `results`.
TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsReturnOutputs(
TF_ImportGraphDefResults* results, int* num_outputs, TF_Output** outputs);
// Fetches the return operations requested via
// TF_ImportGraphDefOptionsAddReturnOperation(). The number of fetched
// operations is returned in `num_opers`. The array of return operations is
// returned in `opers`. `*opers` is owned by and has the lifetime of `results`.
TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsReturnOperations(
TF_ImportGraphDefResults* results, int* num_opers, TF_Operation*** opers);
// Fetches any input mappings requested via
// TF_ImportGraphDefOptionsAddInputMapping() that weren't used as input to any
// node in the imported graph def. The number of fetched mappings is returned in
// `num_unused_input_mappings`. The array of each mapping's source node name is
// returned in `src_names`, and the array of each mapping's source index is
// returned in `src_indexes`.
//
// `*src_names`, `*src_indexes`, and the memory backing each string in
// `src_names` are owned by and have the lifetime of `results`.
TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsUnusedInputMappings(
TF_ImportGraphDefResults* results, int* num_unused_input_mappings,
const char*** src_names, int** src_indexes);
// Deletes a results object returned by TF_GraphImportGraphDefWithResults().
TF_CAPI_EXPORT extern void TF_DeleteImportGraphDefResults(
TF_ImportGraphDefResults* results);
// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and
// a bad status on error. Otherwise, returns a populated
// TF_ImportGraphDefResults instance. The returned instance must be deleted via
// TF_DeleteImportGraphDefResults().
TF_CAPI_EXPORT extern TF_ImportGraphDefResults*
TF_GraphImportGraphDefWithResults(TF_Graph* graph, const TF_Buffer* graph_def,
const TF_ImportGraphDefOptions* options,
TF_Status* status);
// Import the graph serialized in `graph_def` into `graph`.
// Convenience function for when only return outputs are needed.
//
// `num_return_outputs` must be the number of return outputs added (i.e. the
// result of TF_ImportGraphDefOptionsNumReturnOutputs()). If
@ -926,7 +981,7 @@ TF_CAPI_EXPORT extern void TF_GraphImportGraphDefWithReturnOutputs(
int num_return_outputs, TF_Status* status);
// Import the graph serialized in `graph_def` into `graph`.
// Convenience function for when no return outputs have been added.
// Convenience function for when no results are needed.
TF_CAPI_EXPORT extern void TF_GraphImportGraphDef(
TF_Graph* graph, const TF_Buffer* graph_def,
const TF_ImportGraphDefOptions* options, TF_Status* status);

View File

@ -18,7 +18,9 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include <list>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>
@ -124,6 +126,20 @@ struct TF_Session {
struct TF_ImportGraphDefOptions {
tensorflow::ImportGraphDefOptions opts;
// Backing memory for TensorId fields in opts.
// TODO(skyewm): it'd be better if ImportGraphDefOptions owned this.
std::list<tensorflow::string> tensor_id_data;
};
struct TF_ImportGraphDefResults {
std::vector<TF_Output> return_tensors;
std::vector<TF_Operation*> return_nodes;
std::vector<const char*> unused_key_names;
std::vector<int> unused_key_indexes;
// Backing memory for unused_key_names values.
std::list<tensorflow::string> unused_key_names_data;
};
struct TF_DeviceList {

View File

@ -573,7 +573,7 @@ TEST(CAPI, ImportGraphDef) {
TF_GraphToGraphDef(graph, graph_def, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Import it again, with a prefix, in a fresh graph.
// Import it, with a prefix, in a fresh graph.
TF_DeleteGraph(graph);
graph = TF_NewGraph();
TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
@ -588,8 +588,8 @@ TEST(CAPI, ImportGraphDef) {
ASSERT_TRUE(feed != nullptr);
ASSERT_TRUE(neg != nullptr);
// Import it again, with an input mapping and return outputs, into the same
// graph.
// Import it again, with an input mapping, return outputs, and a return
// operation, into the same graph.
TF_DeleteImportGraphDefOptions(opts);
opts = TF_NewImportGraphDefOptions();
TF_ImportGraphDefOptionsSetPrefix(opts, "imported2");
@ -597,9 +597,10 @@ TEST(CAPI, ImportGraphDef) {
TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0);
TF_ImportGraphDefOptionsAddReturnOutput(opts, "scalar", 0);
EXPECT_EQ(2, TF_ImportGraphDefOptionsNumReturnOutputs(opts));
TF_Output return_outputs[2];
TF_GraphImportGraphDefWithReturnOutputs(graph, graph_def, opts,
return_outputs, 2, s);
TF_ImportGraphDefOptionsAddReturnOperation(opts, "scalar");
EXPECT_EQ(1, TF_ImportGraphDefOptionsNumReturnOperations(opts));
TF_ImportGraphDefResults* results =
TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_Operation* scalar2 = TF_GraphOperationByName(graph, "imported2/scalar");
@ -615,11 +616,26 @@ TEST(CAPI, ImportGraphDef) {
EXPECT_EQ(0, neg_input.index);
// Check return outputs
TF_Output* return_outputs;
int num_return_outputs;
TF_ImportGraphDefResultsReturnOutputs(results, &num_return_outputs,
&return_outputs);
ASSERT_EQ(2, num_return_outputs);
EXPECT_EQ(feed2, return_outputs[0].oper);
EXPECT_EQ(0, return_outputs[0].index);
EXPECT_EQ(scalar, return_outputs[1].oper); // remapped
EXPECT_EQ(0, return_outputs[1].index);
// Check return operation
TF_Operation** return_opers;
int num_return_opers;
TF_ImportGraphDefResultsReturnOperations(results, &num_return_opers,
&return_opers);
ASSERT_EQ(1, num_return_opers);
EXPECT_EQ(scalar2, return_opers[0]); // not remapped
TF_DeleteImportGraphDefResults(results);
// Import again, with control dependencies, into the same graph.
TF_DeleteImportGraphDefOptions(opts);
opts = TF_NewImportGraphDefOptions();
@ -689,6 +705,113 @@ TEST(CAPI, ImportGraphDef) {
TF_DeleteStatus(s);
}
TEST(CAPI, ImportGraphDef_WithReturnOutputs) {
TF_Status* s = TF_NewStatus();
TF_Graph* graph = TF_NewGraph();
// Create a graph with two nodes: x and 3
Placeholder(graph, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
ASSERT_TRUE(TF_GraphOperationByName(graph, "feed") != nullptr);
TF_Operation* oper = ScalarConst(3, graph, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
ASSERT_TRUE(TF_GraphOperationByName(graph, "scalar") != nullptr);
Neg(oper, graph, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
ASSERT_TRUE(TF_GraphOperationByName(graph, "neg") != nullptr);
// Export to a GraphDef.
TF_Buffer* graph_def = TF_NewBuffer();
TF_GraphToGraphDef(graph, graph_def, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Import it in a fresh graph with return outputs.
TF_DeleteGraph(graph);
graph = TF_NewGraph();
TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0);
TF_ImportGraphDefOptionsAddReturnOutput(opts, "scalar", 0);
EXPECT_EQ(2, TF_ImportGraphDefOptionsNumReturnOutputs(opts));
TF_Output return_outputs[2];
TF_GraphImportGraphDefWithReturnOutputs(graph, graph_def, opts,
return_outputs, 2, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_Operation* scalar = TF_GraphOperationByName(graph, "scalar");
TF_Operation* feed = TF_GraphOperationByName(graph, "feed");
TF_Operation* neg = TF_GraphOperationByName(graph, "neg");
ASSERT_TRUE(scalar != nullptr);
ASSERT_TRUE(feed != nullptr);
ASSERT_TRUE(neg != nullptr);
// Check return outputs
EXPECT_EQ(feed, return_outputs[0].oper);
EXPECT_EQ(0, return_outputs[0].index);
EXPECT_EQ(scalar, return_outputs[1].oper);
EXPECT_EQ(0, return_outputs[1].index);
TF_DeleteImportGraphDefOptions(opts);
TF_DeleteBuffer(graph_def);
TF_DeleteGraph(graph);
TF_DeleteStatus(s);
}
TEST(CAPI, ImportGraphDef_UnusedInputMappings) {
TF_Status* s = TF_NewStatus();
TF_Graph* graph = TF_NewGraph();
// Create a graph with two nodes: x and 3
Placeholder(graph, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
ASSERT_TRUE(TF_GraphOperationByName(graph, "feed") != nullptr);
TF_Operation* oper = ScalarConst(3, graph, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
ASSERT_TRUE(TF_GraphOperationByName(graph, "scalar") != nullptr);
Neg(oper, graph, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
ASSERT_TRUE(TF_GraphOperationByName(graph, "neg") != nullptr);
// Export to a GraphDef.
TF_Buffer* graph_def = TF_NewBuffer();
TF_GraphToGraphDef(graph, graph_def, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Import it in a fresh graph.
TF_DeleteGraph(graph);
graph = TF_NewGraph();
TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
TF_GraphImportGraphDef(graph, graph_def, opts, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_Operation* scalar = TF_GraphOperationByName(graph, "scalar");
// Import it in a fresh graph with an unused input mapping.
TF_DeleteImportGraphDefOptions(opts);
opts = TF_NewImportGraphDefOptions();
TF_ImportGraphDefOptionsSetPrefix(opts, "imported");
TF_ImportGraphDefOptionsAddInputMapping(opts, "scalar", 0, {scalar, 0});
TF_ImportGraphDefOptionsAddInputMapping(opts, "fake", 0, {scalar, 0});
TF_ImportGraphDefResults* results =
TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Check unused input mappings
int num_unused_input_mappings;
const char** src_names;
int* src_indexes;
TF_ImportGraphDefResultsUnusedInputMappings(
results, &num_unused_input_mappings, &src_names, &src_indexes);
ASSERT_EQ(1, num_unused_input_mappings);
EXPECT_EQ(string("fake"), string(src_names[0]));
EXPECT_EQ(0, src_indexes[0]);
TF_DeleteImportGraphDefResults(results);
TF_DeleteImportGraphDefOptions(opts);
TF_DeleteBuffer(graph_def);
TF_DeleteGraph(graph);
TF_DeleteStatus(s);
}
TEST(CAPI, Session) {
TF_Status* s = TF_NewStatus();
TF_Graph* graph = TF_NewGraph();

View File

@ -90,7 +90,7 @@ class GraphConstructor {
bool skip_mapped_nodes;
std::vector<string> control_dependencies;
std::vector<TensorId> return_tensors;
std::vector<StringPiece> return_nodes;
std::vector<string> return_nodes;
// TODO(ashankar): This bool exists to separate out functionality required
// to make ImportGraphDef a close equivalent of Python's import_graph_def

View File

@ -110,7 +110,7 @@ struct ImportGraphDefOptions {
// Unlike `return_tensors`, `input_map` has no effect on the nodes
// returned. `return_nodes` must be empty if `skip_mapped_nodes` is true.
// TODO(skyewm): make this work with `skip_mapped_nodes` if there's a need.
std::vector<StringPiece> return_nodes;
std::vector<string> return_nodes;
// TODO(ashankar): Enable handling of GraphDefs produced by newer binaries
// with ops that are not defined in the binary calling ImportGraphDef.