mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
Add ability to fetch return nodes and unused input mappings from C API GraphDef import
This change introduces yet another ImportGraphDef function to the C API (TF_GraphImportGraphDefWithResults), but this one has extensible return values so we shouldn't have to add more in the future. This change also modifies the ImportGraphDef C interface to manage all string data for the user. PiperOrigin-RevId: 173894710
This commit is contained in:
parent
ef4490f637
commit
ce02381980
|
|
@ -86,6 +86,7 @@ using tensorflow::errors::FailedPrecondition;
|
|||
using tensorflow::errors::InvalidArgument;
|
||||
using tensorflow::gtl::ArraySlice;
|
||||
using tensorflow::mutex_lock;
|
||||
using tensorflow::string;
|
||||
using tensorflow::strings::StrCat;
|
||||
|
||||
extern "C" {
|
||||
|
|
@ -366,7 +367,7 @@ namespace {
|
|||
// Reset helper for converting character arrays to string vectors.
|
||||
void TF_Reset_Helper(const TF_SessionOptions* opt, const char** containers,
|
||||
int ncontainers, TF_Status* status) {
|
||||
std::vector<tensorflow::string> container_names(ncontainers);
|
||||
std::vector<string> container_names(ncontainers);
|
||||
for (int i = 0; i < ncontainers; ++i) {
|
||||
container_names[i] = containers[i];
|
||||
}
|
||||
|
|
@ -482,7 +483,7 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
|
|||
const char* limit = input + src_size;
|
||||
|
||||
*dst = Tensor(static_cast<DataType>(src->dtype), src->shape);
|
||||
auto dstarray = dst->flat<tensorflow::string>();
|
||||
auto dstarray = dst->flat<string>();
|
||||
for (tensorflow::int64 i = 0; i < num_elements; ++i) {
|
||||
tensorflow::uint64 offset =
|
||||
reinterpret_cast<const tensorflow::uint64*>(input)[i];
|
||||
|
|
@ -556,9 +557,9 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
|
|||
|
||||
// Compute bytes needed for encoding.
|
||||
size_t size = 0;
|
||||
const auto& srcarray = src.flat<tensorflow::string>();
|
||||
const auto& srcarray = src.flat<string>();
|
||||
for (int i = 0; i < srcarray.size(); ++i) {
|
||||
const tensorflow::string& s = srcarray(i);
|
||||
const string& s = srcarray(i);
|
||||
// uint64 starting_offset, TF_StringEncode-d string.
|
||||
size += sizeof(tensorflow::uint64) + TF_StringEncodedSize(s.size());
|
||||
}
|
||||
|
|
@ -572,7 +573,7 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
|
|||
for (int i = 0; i < srcarray.size(); ++i) {
|
||||
*offsets = (dst - data_start);
|
||||
offsets++;
|
||||
const tensorflow::string& s = srcarray(i);
|
||||
const string& s = srcarray(i);
|
||||
size_t consumed = TF_StringEncode(s.data(), s.size(), dst, dst_len, status);
|
||||
if (!status->status.ok()) {
|
||||
status->status = InvalidArgument(
|
||||
|
|
@ -637,10 +638,9 @@ static void TF_Run_Setup(int noutputs, TF_Tensor** c_outputs,
|
|||
}
|
||||
}
|
||||
|
||||
static bool TF_Run_Inputs(
|
||||
TF_Tensor* const* c_inputs,
|
||||
std::vector<std::pair<tensorflow::string, Tensor>>* input_pairs,
|
||||
TF_Status* status) {
|
||||
static bool TF_Run_Inputs(TF_Tensor* const* c_inputs,
|
||||
std::vector<std::pair<string, Tensor>>* input_pairs,
|
||||
TF_Status* status) {
|
||||
const int ninputs = input_pairs->size();
|
||||
for (int i = 0; i < ninputs; ++i) {
|
||||
status->status = TF_TensorToTensor(c_inputs[i], &(*input_pairs)[i].second);
|
||||
|
|
@ -652,13 +652,12 @@ static bool TF_Run_Inputs(
|
|||
static void TF_Run_Helper(
|
||||
Session* session, const char* handle, const TF_Buffer* run_options,
|
||||
// Input tensors
|
||||
const std::vector<std::pair<tensorflow::string, Tensor>>& input_pairs,
|
||||
const std::vector<std::pair<string, Tensor>>& input_pairs,
|
||||
// Output tensors
|
||||
const std::vector<tensorflow::string>& output_tensor_names,
|
||||
TF_Tensor** c_outputs,
|
||||
const std::vector<string>& output_tensor_names, TF_Tensor** c_outputs,
|
||||
// Target nodes
|
||||
const std::vector<tensorflow::string>& target_oper_names,
|
||||
TF_Buffer* run_metadata, TF_Status* status) {
|
||||
const std::vector<string>& target_oper_names, TF_Buffer* run_metadata,
|
||||
TF_Status* status) {
|
||||
const int noutputs = output_tensor_names.size();
|
||||
std::vector<Tensor> outputs(noutputs);
|
||||
Status result;
|
||||
|
|
@ -718,16 +717,16 @@ void TF_Run(TF_DeprecatedSession* s, const TF_Buffer* run_options,
|
|||
const char** c_target_oper_names, int ntargets,
|
||||
TF_Buffer* run_metadata, TF_Status* status) {
|
||||
TF_Run_Setup(noutputs, c_outputs, status);
|
||||
std::vector<std::pair<tensorflow::string, Tensor>> input_pairs(ninputs);
|
||||
std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
|
||||
if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return;
|
||||
for (int i = 0; i < ninputs; ++i) {
|
||||
input_pairs[i].first = c_input_names[i];
|
||||
}
|
||||
std::vector<tensorflow::string> output_names(noutputs);
|
||||
std::vector<string> output_names(noutputs);
|
||||
for (int i = 0; i < noutputs; ++i) {
|
||||
output_names[i] = c_output_names[i];
|
||||
}
|
||||
std::vector<tensorflow::string> target_oper_names(ntargets);
|
||||
std::vector<string> target_oper_names(ntargets);
|
||||
for (int i = 0; i < ntargets; ++i) {
|
||||
target_oper_names[i] = c_target_oper_names[i];
|
||||
}
|
||||
|
|
@ -745,9 +744,9 @@ void TF_PRunSetup(TF_DeprecatedSession* s,
|
|||
const char** handle, TF_Status* status) {
|
||||
*handle = nullptr;
|
||||
|
||||
std::vector<tensorflow::string> input_names(ninputs);
|
||||
std::vector<tensorflow::string> output_names(noutputs);
|
||||
std::vector<tensorflow::string> target_oper_names(ntargets);
|
||||
std::vector<string> input_names(ninputs);
|
||||
std::vector<string> output_names(noutputs);
|
||||
std::vector<string> target_oper_names(ntargets);
|
||||
for (int i = 0; i < ninputs; ++i) {
|
||||
input_names[i] = c_input_names[i];
|
||||
}
|
||||
|
|
@ -757,7 +756,7 @@ void TF_PRunSetup(TF_DeprecatedSession* s,
|
|||
for (int i = 0; i < ntargets; ++i) {
|
||||
target_oper_names[i] = c_target_oper_names[i];
|
||||
}
|
||||
tensorflow::string new_handle;
|
||||
string new_handle;
|
||||
status->status = s->session->PRunSetup(input_names, output_names,
|
||||
target_oper_names, &new_handle);
|
||||
if (status->status.ok()) {
|
||||
|
|
@ -776,17 +775,17 @@ void TF_PRun(TF_DeprecatedSession* s, const char* handle,
|
|||
const char** c_target_oper_names, int ntargets,
|
||||
TF_Status* status) {
|
||||
TF_Run_Setup(noutputs, c_outputs, status);
|
||||
std::vector<std::pair<tensorflow::string, Tensor>> input_pairs(ninputs);
|
||||
std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
|
||||
if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return;
|
||||
for (int i = 0; i < ninputs; ++i) {
|
||||
input_pairs[i].first = c_input_names[i];
|
||||
}
|
||||
|
||||
std::vector<tensorflow::string> output_names(noutputs);
|
||||
std::vector<string> output_names(noutputs);
|
||||
for (int i = 0; i < noutputs; ++i) {
|
||||
output_names[i] = c_output_names[i];
|
||||
}
|
||||
std::vector<tensorflow::string> target_oper_names(ntargets);
|
||||
std::vector<string> target_oper_names(ntargets);
|
||||
for (int i = 0; i < ntargets; ++i) {
|
||||
target_oper_names[i] = c_target_oper_names[i];
|
||||
}
|
||||
|
|
@ -881,7 +880,7 @@ TF_Operation* ToOperation(Node* node) {
|
|||
return static_cast<TF_Operation*>(static_cast<void*>(node));
|
||||
}
|
||||
|
||||
tensorflow::string OutputName(const TF_Output& output) {
|
||||
string OutputName(const TF_Output& output) {
|
||||
return StrCat(output.oper->node.name(), ":", output.index);
|
||||
}
|
||||
|
||||
|
|
@ -1254,7 +1253,7 @@ void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name,
|
|||
return;
|
||||
}
|
||||
desc->colocation_constraints.clear();
|
||||
for (const tensorflow::string& location : attr_value.list().s()) {
|
||||
for (const string& location : attr_value.list().s()) {
|
||||
desc->colocation_constraints.insert(location);
|
||||
}
|
||||
} else {
|
||||
|
|
@ -1276,8 +1275,8 @@ static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
|
|||
if (!desc->colocation_constraints.empty()) {
|
||||
desc->node_builder.Attr(
|
||||
tensorflow::kColocationAttrName,
|
||||
std::vector<tensorflow::string>(desc->colocation_constraints.begin(),
|
||||
desc->colocation_constraints.end()));
|
||||
std::vector<string>(desc->colocation_constraints.begin(),
|
||||
desc->colocation_constraints.end()));
|
||||
}
|
||||
status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret);
|
||||
|
||||
|
|
@ -1500,7 +1499,7 @@ TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper,
|
|||
for (int i = 0; i < oper->node.op_def().attr_size(); ++i) {
|
||||
const auto& a = oper->node.op_def().attr(i);
|
||||
if (a.name().compare(attr_name) != 0) continue;
|
||||
const tensorflow::string& typestr = a.type();
|
||||
const string& typestr = a.type();
|
||||
if (typestr == "list(string)") {
|
||||
metadata.type = TF_ATTR_STRING;
|
||||
} else if (typestr == "list(int)") {
|
||||
|
|
@ -1580,7 +1579,7 @@ void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name,
|
|||
const auto len = std::min(max_values, attr->list().s_size());
|
||||
char* p = static_cast<char*>(storage);
|
||||
for (int i = 0; i < len; ++i) {
|
||||
const tensorflow::string& s = attr->list().s(i);
|
||||
const string& s = attr->list().s(i);
|
||||
values[i] = p;
|
||||
lengths[i] = s.size();
|
||||
if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) {
|
||||
|
|
@ -1824,7 +1823,11 @@ void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts,
|
|||
void TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions* opts,
|
||||
const char* src_name,
|
||||
int src_index, TF_Output dst) {
|
||||
opts->opts.input_map[TensorId(src_name, src_index)] = ToTensorId(dst);
|
||||
opts->tensor_id_data.push_back(src_name);
|
||||
const string& src_name_str = opts->tensor_id_data.back();
|
||||
// We don't need to store dst's name in tensor_id_data, since `dst` must
|
||||
// outlive the ImportGraphDef call.
|
||||
opts->opts.input_map[TensorId(src_name_str, src_index)] = ToTensorId(dst);
|
||||
}
|
||||
|
||||
void TF_ImportGraphDefOptionsRemapControlDependency(
|
||||
|
|
@ -1840,7 +1843,9 @@ extern void TF_ImportGraphDefOptionsAddControlDependency(
|
|||
|
||||
void TF_ImportGraphDefOptionsAddReturnOutput(TF_ImportGraphDefOptions* opts,
|
||||
const char* oper_name, int index) {
|
||||
opts->opts.return_tensors.push_back({oper_name, index});
|
||||
opts->tensor_id_data.push_back(oper_name);
|
||||
const string& oper_name_str = opts->tensor_id_data.back();
|
||||
opts->opts.return_tensors.emplace_back(oper_name_str, index);
|
||||
}
|
||||
|
||||
int TF_ImportGraphDefOptionsNumReturnOutputs(
|
||||
|
|
@ -1848,15 +1853,116 @@ int TF_ImportGraphDefOptionsNumReturnOutputs(
|
|||
return opts->opts.return_tensors.size();
|
||||
}
|
||||
|
||||
void TF_ImportGraphDefOptionsAddReturnOperation(TF_ImportGraphDefOptions* opts,
|
||||
const char* oper_name) {
|
||||
opts->opts.return_nodes.push_back(oper_name);
|
||||
}
|
||||
|
||||
int TF_ImportGraphDefOptionsNumReturnOperations(
|
||||
const TF_ImportGraphDefOptions* opts) {
|
||||
return opts->opts.return_nodes.size();
|
||||
}
|
||||
|
||||
void TF_ImportGraphDefResultsReturnOutputs(TF_ImportGraphDefResults* results,
|
||||
int* num_outputs,
|
||||
TF_Output** outputs) {
|
||||
*num_outputs = results->return_tensors.size();
|
||||
*outputs = results->return_tensors.data();
|
||||
}
|
||||
|
||||
void TF_ImportGraphDefResultsReturnOperations(TF_ImportGraphDefResults* results,
|
||||
int* num_opers,
|
||||
TF_Operation*** opers) {
|
||||
*num_opers = results->return_nodes.size();
|
||||
*opers = results->return_nodes.data();
|
||||
}
|
||||
|
||||
void TF_ImportGraphDefResultsUnusedInputMappings(
|
||||
TF_ImportGraphDefResults* results, int* num_unused_input_mappings,
|
||||
const char*** src_names, int** src_indexes) {
|
||||
*num_unused_input_mappings = results->unused_key_names.size();
|
||||
*src_names = results->unused_key_names.data();
|
||||
*src_indexes = results->unused_key_indexes.data();
|
||||
}
|
||||
|
||||
void TF_DeleteImportGraphDefResults(TF_ImportGraphDefResults* results) {
|
||||
delete results;
|
||||
}
|
||||
|
||||
static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def,
|
||||
const TF_ImportGraphDefOptions* opts,
|
||||
TF_Output* return_outputs,
|
||||
int num_return_outputs, TF_Status* status)
|
||||
TF_ImportGraphDefResults* tf_results,
|
||||
TF_Status* status)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
|
||||
if (num_return_outputs != opts->opts.return_tensors.size()) {
|
||||
const int last_node_id = graph->graph.num_node_ids();
|
||||
tensorflow::ImportGraphDefResults results;
|
||||
status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph,
|
||||
&graph->refiner, &results);
|
||||
if (!status->status.ok()) return;
|
||||
|
||||
// Add new nodes to name_map
|
||||
for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) {
|
||||
auto* node = graph->graph.FindNodeId(i);
|
||||
if (node != nullptr) graph->name_map[node->name()] = node;
|
||||
}
|
||||
|
||||
// Populate return_tensors
|
||||
DCHECK(tf_results->return_tensors.empty());
|
||||
tf_results->return_tensors.resize(results.return_tensors.size());
|
||||
for (int i = 0; i < results.return_tensors.size(); ++i) {
|
||||
tf_results->return_tensors[i].oper =
|
||||
ToOperation(results.return_tensors[i].first);
|
||||
tf_results->return_tensors[i].index = results.return_tensors[i].second;
|
||||
}
|
||||
|
||||
// Populate return_nodes
|
||||
DCHECK(tf_results->return_nodes.empty());
|
||||
tf_results->return_nodes.resize(results.return_nodes.size());
|
||||
for (int i = 0; i < results.return_nodes.size(); ++i) {
|
||||
tf_results->return_nodes[i] = ToOperation(results.return_nodes[i]);
|
||||
}
|
||||
|
||||
// Populate unused map keys
|
||||
DCHECK(tf_results->unused_key_names.empty());
|
||||
DCHECK(tf_results->unused_key_indexes.empty());
|
||||
DCHECK(tf_results->unused_key_names_data.empty());
|
||||
tf_results->unused_key_names.resize(results.unused_input_map_keys.size());
|
||||
tf_results->unused_key_indexes.resize(results.unused_input_map_keys.size());
|
||||
for (int i = 0; i < results.unused_input_map_keys.size(); ++i) {
|
||||
TensorId id = results.unused_input_map_keys[i];
|
||||
tf_results->unused_key_names_data.push_back(id.first.ToString());
|
||||
tf_results->unused_key_names[i] =
|
||||
tf_results->unused_key_names_data.back().c_str();
|
||||
tf_results->unused_key_indexes[i] = id.second;
|
||||
}
|
||||
}
|
||||
|
||||
TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults(
|
||||
TF_Graph* graph, const TF_Buffer* graph_def,
|
||||
const TF_ImportGraphDefOptions* options, TF_Status* status) {
|
||||
GraphDef def;
|
||||
if (!def.ParseFromArray(graph_def->data, graph_def->length)) {
|
||||
status->status = InvalidArgument("Invalid GraphDef");
|
||||
return nullptr;
|
||||
}
|
||||
auto results = new TF_ImportGraphDefResults();
|
||||
mutex_lock l(graph->mu);
|
||||
GraphImportGraphDefLocked(graph, def, options, results, status);
|
||||
if (!status->status.ok()) {
|
||||
delete results;
|
||||
return nullptr;
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
void TF_GraphImportGraphDefWithReturnOutputs(
|
||||
TF_Graph* graph, const TF_Buffer* graph_def,
|
||||
const TF_ImportGraphDefOptions* options, TF_Output* return_outputs,
|
||||
int num_return_outputs, TF_Status* status) {
|
||||
if (num_return_outputs != options->opts.return_tensors.size()) {
|
||||
status->status = InvalidArgument("Expected 'num_return_outputs' to be ",
|
||||
opts->opts.return_tensors.size(), ", got ",
|
||||
num_return_outputs);
|
||||
options->opts.return_tensors.size(),
|
||||
", got ", num_return_outputs);
|
||||
return;
|
||||
}
|
||||
if (num_return_outputs > 0 && return_outputs == nullptr) {
|
||||
|
|
@ -1864,41 +1970,25 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def,
|
|||
"'return_outputs' must be preallocated to length ", num_return_outputs);
|
||||
return;
|
||||
}
|
||||
const int last_node_id = graph->graph.num_node_ids();
|
||||
tensorflow::ImportGraphDefResults results;
|
||||
status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph,
|
||||
&graph->refiner, &results);
|
||||
if (!status->status.ok()) return;
|
||||
for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) {
|
||||
auto* node = graph->graph.FindNodeId(i);
|
||||
if (node != nullptr) graph->name_map[node->name()] = node;
|
||||
}
|
||||
DCHECK_EQ(results.return_tensors.size(), num_return_outputs);
|
||||
for (int i = 0; i < num_return_outputs; ++i) {
|
||||
return_outputs[i].oper = ToOperation(results.return_tensors[i].first);
|
||||
return_outputs[i].index = results.return_tensors[i].second;
|
||||
}
|
||||
}
|
||||
|
||||
void TF_GraphImportGraphDefWithReturnOutputs(
|
||||
TF_Graph* graph, const TF_Buffer* graph_def,
|
||||
const TF_ImportGraphDefOptions* opts, TF_Output* return_outputs,
|
||||
int num_return_outputs, TF_Status* status) {
|
||||
GraphDef def;
|
||||
if (!def.ParseFromArray(graph_def->data, graph_def->length)) {
|
||||
status->status = InvalidArgument("Invalid GraphDef");
|
||||
return;
|
||||
}
|
||||
TF_ImportGraphDefResults results;
|
||||
mutex_lock l(graph->mu);
|
||||
GraphImportGraphDefLocked(graph, def, opts, return_outputs,
|
||||
num_return_outputs, status);
|
||||
GraphImportGraphDefLocked(graph, def, options, &results, status);
|
||||
DCHECK_EQ(results.return_tensors.size(), num_return_outputs);
|
||||
memcpy(return_outputs, results.return_tensors.data(),
|
||||
num_return_outputs * sizeof(TF_Output));
|
||||
}
|
||||
|
||||
void TF_GraphImportGraphDef(TF_Graph* graph, const TF_Buffer* graph_def,
|
||||
const TF_ImportGraphDefOptions* options,
|
||||
TF_Status* status) {
|
||||
TF_GraphImportGraphDefWithReturnOutputs(graph, graph_def, options, nullptr, 0,
|
||||
status);
|
||||
TF_ImportGraphDefResults* results =
|
||||
TF_GraphImportGraphDefWithResults(graph, graph_def, options, status);
|
||||
TF_DeleteImportGraphDefResults(results);
|
||||
}
|
||||
|
||||
// While loop functions -------------------------------------------------------
|
||||
|
|
@ -1930,7 +2020,7 @@ Status CopyGraph(Graph* src_graph, Graph* dst_graph,
|
|||
tensorflow::ShapeRefiner* dst_refiner,
|
||||
const TF_Output* src_inputs,
|
||||
const std::vector<tensorflow::Output>& dst_inputs,
|
||||
const tensorflow::string& prefix,
|
||||
const string& prefix,
|
||||
const std::vector<tensorflow::Operation>& control_deps,
|
||||
const TF_Output* nodes_to_return, int nreturn_nodes,
|
||||
std::vector<tensorflow::Output>* return_nodes) {
|
||||
|
|
@ -2257,9 +2347,9 @@ TF_Session* TF_LoadSessionFromSavedModel(
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
std::unordered_set<tensorflow::string> tag_set;
|
||||
std::unordered_set<string> tag_set;
|
||||
for (int i = 0; i < tags_len; i++) {
|
||||
tag_set.insert(tensorflow::string(tags[i]));
|
||||
tag_set.insert(string(tags[i]));
|
||||
}
|
||||
|
||||
tensorflow::SavedModelBundle bundle;
|
||||
|
|
@ -2275,8 +2365,9 @@ TF_Session* TF_LoadSessionFromSavedModel(
|
|||
// TODO(jhseu): When Session is modified to take Graphs instead of
|
||||
// GraphDefs, return the Graph generated in LoadSavedModel().
|
||||
TF_ImportGraphDefOptions* import_opts = TF_NewImportGraphDefOptions();
|
||||
TF_ImportGraphDefResults results;
|
||||
GraphImportGraphDefLocked(graph, bundle.meta_graph_def.graph_def(),
|
||||
import_opts, nullptr, 0, status);
|
||||
import_opts, &results, status);
|
||||
TF_DeleteImportGraphDefOptions(import_opts);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
|
||||
|
|
@ -2372,20 +2463,20 @@ void TF_SessionRun(TF_Session* session, const TF_Buffer* run_options,
|
|||
TF_Run_Setup(noutputs, output_values, status);
|
||||
|
||||
// Convert from TF_Output and TF_Tensor to a string and Tensor.
|
||||
std::vector<std::pair<tensorflow::string, Tensor>> input_pairs(ninputs);
|
||||
std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
|
||||
if (!TF_Run_Inputs(input_values, &input_pairs, status)) return;
|
||||
for (int i = 0; i < ninputs; ++i) {
|
||||
input_pairs[i].first = OutputName(inputs[i]);
|
||||
}
|
||||
|
||||
// Convert from TF_Output to string names.
|
||||
std::vector<tensorflow::string> output_names(noutputs);
|
||||
std::vector<string> output_names(noutputs);
|
||||
for (int i = 0; i < noutputs; ++i) {
|
||||
output_names[i] = OutputName(outputs[i]);
|
||||
}
|
||||
|
||||
// Convert from TF_Operation* to string names.
|
||||
std::vector<tensorflow::string> target_names(ntargets);
|
||||
std::vector<string> target_names(ntargets);
|
||||
for (int i = 0; i < ntargets; ++i) {
|
||||
target_names[i] = target_opers[i]->node.name();
|
||||
}
|
||||
|
|
@ -2406,22 +2497,22 @@ void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs,
|
|||
return;
|
||||
}
|
||||
|
||||
std::vector<tensorflow::string> input_names(ninputs);
|
||||
std::vector<string> input_names(ninputs);
|
||||
for (int i = 0; i < ninputs; ++i) {
|
||||
input_names[i] = OutputName(inputs[i]);
|
||||
}
|
||||
|
||||
std::vector<tensorflow::string> output_names(noutputs);
|
||||
std::vector<string> output_names(noutputs);
|
||||
for (int i = 0; i < noutputs; ++i) {
|
||||
output_names[i] = OutputName(outputs[i]);
|
||||
}
|
||||
|
||||
std::vector<tensorflow::string> target_names(ntargets);
|
||||
std::vector<string> target_names(ntargets);
|
||||
for (int i = 0; i < ntargets; ++i) {
|
||||
target_names[i] = target_opers[i]->node.name();
|
||||
}
|
||||
|
||||
tensorflow::string new_handle;
|
||||
string new_handle;
|
||||
status->status = session->session->PRunSetup(input_names, output_names,
|
||||
target_names, &new_handle);
|
||||
if (status->status.ok()) {
|
||||
|
|
@ -2452,20 +2543,20 @@ void TF_SessionPRun(TF_Session* session, const char* handle,
|
|||
TF_Run_Setup(noutputs, output_values, status);
|
||||
|
||||
// Convert from TF_Output and TF_Tensor to a string and Tensor.
|
||||
std::vector<std::pair<tensorflow::string, Tensor>> input_pairs(ninputs);
|
||||
std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
|
||||
if (!TF_Run_Inputs(input_values, &input_pairs, status)) return;
|
||||
for (int i = 0; i < ninputs; ++i) {
|
||||
input_pairs[i].first = OutputName(inputs[i]);
|
||||
}
|
||||
|
||||
// Convert from TF_Output to string names.
|
||||
std::vector<tensorflow::string> output_names(noutputs);
|
||||
std::vector<string> output_names(noutputs);
|
||||
for (int i = 0; i < noutputs; ++i) {
|
||||
output_names[i] = OutputName(outputs[i]);
|
||||
}
|
||||
|
||||
// Convert from TF_Operation* to string names.
|
||||
std::vector<tensorflow::string> target_names(ntargets);
|
||||
std::vector<string> target_names(ntargets);
|
||||
for (int i = 0; i < ntargets; ++i) {
|
||||
target_names[i] = target_opers[i]->node.name();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -914,7 +914,62 @@ TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddReturnOutput(
|
|||
TF_CAPI_EXPORT extern int TF_ImportGraphDefOptionsNumReturnOutputs(
|
||||
const TF_ImportGraphDefOptions* opts);
|
||||
|
||||
// Add an operation in `graph_def` to be returned via the `return_opers` output
|
||||
// parameter of TF_GraphImportGraphDef().
|
||||
TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddReturnOperation(
|
||||
TF_ImportGraphDefOptions* opts, const char* oper_name);
|
||||
|
||||
// Returns the number of return operations added via
|
||||
// TF_ImportGraphDefOptionsAddReturnOperation().
|
||||
TF_CAPI_EXPORT extern int TF_ImportGraphDefOptionsNumReturnOperations(
|
||||
const TF_ImportGraphDefOptions* opts);
|
||||
|
||||
// TF_ImportGraphDefResults holds results that are generated by
|
||||
// TF_GraphImportGraphDefWithResults().
|
||||
typedef struct TF_ImportGraphDefResults TF_ImportGraphDefResults;
|
||||
|
||||
// Fetches the return outputs requested via
|
||||
// TF_ImportGraphDefOptionsAddReturnOutput(). The number of fetched outputs is
|
||||
// returned in `num_outputs`. The array of return outputs is returned in
|
||||
// `outputs`. `*outputs` is owned by and has the lifetime of `results`.
|
||||
TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsReturnOutputs(
|
||||
TF_ImportGraphDefResults* results, int* num_outputs, TF_Output** outputs);
|
||||
|
||||
// Fetches the return operations requested via
|
||||
// TF_ImportGraphDefOptionsAddReturnOperation(). The number of fetched
|
||||
// operations is returned in `num_opers`. The array of return operations is
|
||||
// returned in `opers`. `*opers` is owned by and has the lifetime of `results`.
|
||||
TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsReturnOperations(
|
||||
TF_ImportGraphDefResults* results, int* num_opers, TF_Operation*** opers);
|
||||
|
||||
// Fetches any input mappings requested via
|
||||
// TF_ImportGraphDefOptionsAddInputMapping() that weren't used as input to any
|
||||
// node in the imported graph def. The number of fetched mappings is returned in
|
||||
// `num_unused_input_mappings`. The array of each mapping's source node name is
|
||||
// returned in `src_names`, and the array of each mapping's source index is
|
||||
// returned in `src_indexes`.
|
||||
//
|
||||
// `*src_names`, `*src_indexes`, and the memory backing each string in
|
||||
// `src_names` are owned by and have the lifetime of `results`.
|
||||
TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsUnusedInputMappings(
|
||||
TF_ImportGraphDefResults* results, int* num_unused_input_mappings,
|
||||
const char*** src_names, int** src_indexes);
|
||||
|
||||
// Deletes a results object returned by TF_GraphImportGraphDefWithResults().
|
||||
TF_CAPI_EXPORT extern void TF_DeleteImportGraphDefResults(
|
||||
TF_ImportGraphDefResults* results);
|
||||
|
||||
// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and
|
||||
// a bad status on error. Otherwise, returns a populated
|
||||
// TF_ImportGraphDefResults instance. The returned instance must be deleted via
|
||||
// TF_DeleteImportGraphDefResults().
|
||||
TF_CAPI_EXPORT extern TF_ImportGraphDefResults*
|
||||
TF_GraphImportGraphDefWithResults(TF_Graph* graph, const TF_Buffer* graph_def,
|
||||
const TF_ImportGraphDefOptions* options,
|
||||
TF_Status* status);
|
||||
|
||||
// Import the graph serialized in `graph_def` into `graph`.
|
||||
// Convenience function for when only return outputs are needed.
|
||||
//
|
||||
// `num_return_outputs` must be the number of return outputs added (i.e. the
|
||||
// result of TF_ImportGraphDefOptionsNumReturnOutputs()). If
|
||||
|
|
@ -926,7 +981,7 @@ TF_CAPI_EXPORT extern void TF_GraphImportGraphDefWithReturnOutputs(
|
|||
int num_return_outputs, TF_Status* status);
|
||||
|
||||
// Import the graph serialized in `graph_def` into `graph`.
|
||||
// Convenience function for when no return outputs have been added.
|
||||
// Convenience function for when no results are needed.
|
||||
TF_CAPI_EXPORT extern void TF_GraphImportGraphDef(
|
||||
TF_Graph* graph, const TF_Buffer* graph_def,
|
||||
const TF_ImportGraphDefOptions* options, TF_Status* status);
|
||||
|
|
|
|||
|
|
@ -18,7 +18,9 @@ limitations under the License.
|
|||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
|
||||
#include <list>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
|
|
@ -124,6 +126,20 @@ struct TF_Session {
|
|||
|
||||
struct TF_ImportGraphDefOptions {
|
||||
tensorflow::ImportGraphDefOptions opts;
|
||||
|
||||
// Backing memory for TensorId fields in opts.
|
||||
// TODO(skyewm): it'd be better if ImportGraphDefOptions owned this.
|
||||
std::list<tensorflow::string> tensor_id_data;
|
||||
};
|
||||
|
||||
struct TF_ImportGraphDefResults {
|
||||
std::vector<TF_Output> return_tensors;
|
||||
std::vector<TF_Operation*> return_nodes;
|
||||
std::vector<const char*> unused_key_names;
|
||||
std::vector<int> unused_key_indexes;
|
||||
|
||||
// Backing memory for unused_key_names values.
|
||||
std::list<tensorflow::string> unused_key_names_data;
|
||||
};
|
||||
|
||||
struct TF_DeviceList {
|
||||
|
|
|
|||
|
|
@ -573,7 +573,7 @@ TEST(CAPI, ImportGraphDef) {
|
|||
TF_GraphToGraphDef(graph, graph_def, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
|
||||
// Import it again, with a prefix, in a fresh graph.
|
||||
// Import it, with a prefix, in a fresh graph.
|
||||
TF_DeleteGraph(graph);
|
||||
graph = TF_NewGraph();
|
||||
TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
|
||||
|
|
@ -588,8 +588,8 @@ TEST(CAPI, ImportGraphDef) {
|
|||
ASSERT_TRUE(feed != nullptr);
|
||||
ASSERT_TRUE(neg != nullptr);
|
||||
|
||||
// Import it again, with an input mapping and return outputs, into the same
|
||||
// graph.
|
||||
// Import it again, with an input mapping, return outputs, and a return
|
||||
// operation, into the same graph.
|
||||
TF_DeleteImportGraphDefOptions(opts);
|
||||
opts = TF_NewImportGraphDefOptions();
|
||||
TF_ImportGraphDefOptionsSetPrefix(opts, "imported2");
|
||||
|
|
@ -597,9 +597,10 @@ TEST(CAPI, ImportGraphDef) {
|
|||
TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0);
|
||||
TF_ImportGraphDefOptionsAddReturnOutput(opts, "scalar", 0);
|
||||
EXPECT_EQ(2, TF_ImportGraphDefOptionsNumReturnOutputs(opts));
|
||||
TF_Output return_outputs[2];
|
||||
TF_GraphImportGraphDefWithReturnOutputs(graph, graph_def, opts,
|
||||
return_outputs, 2, s);
|
||||
TF_ImportGraphDefOptionsAddReturnOperation(opts, "scalar");
|
||||
EXPECT_EQ(1, TF_ImportGraphDefOptionsNumReturnOperations(opts));
|
||||
TF_ImportGraphDefResults* results =
|
||||
TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
|
||||
TF_Operation* scalar2 = TF_GraphOperationByName(graph, "imported2/scalar");
|
||||
|
|
@ -615,11 +616,26 @@ TEST(CAPI, ImportGraphDef) {
|
|||
EXPECT_EQ(0, neg_input.index);
|
||||
|
||||
// Check return outputs
|
||||
TF_Output* return_outputs;
|
||||
int num_return_outputs;
|
||||
TF_ImportGraphDefResultsReturnOutputs(results, &num_return_outputs,
|
||||
&return_outputs);
|
||||
ASSERT_EQ(2, num_return_outputs);
|
||||
EXPECT_EQ(feed2, return_outputs[0].oper);
|
||||
EXPECT_EQ(0, return_outputs[0].index);
|
||||
EXPECT_EQ(scalar, return_outputs[1].oper); // remapped
|
||||
EXPECT_EQ(0, return_outputs[1].index);
|
||||
|
||||
// Check return operation
|
||||
TF_Operation** return_opers;
|
||||
int num_return_opers;
|
||||
TF_ImportGraphDefResultsReturnOperations(results, &num_return_opers,
|
||||
&return_opers);
|
||||
ASSERT_EQ(1, num_return_opers);
|
||||
EXPECT_EQ(scalar2, return_opers[0]); // not remapped
|
||||
|
||||
TF_DeleteImportGraphDefResults(results);
|
||||
|
||||
// Import again, with control dependencies, into the same graph.
|
||||
TF_DeleteImportGraphDefOptions(opts);
|
||||
opts = TF_NewImportGraphDefOptions();
|
||||
|
|
@ -689,6 +705,113 @@ TEST(CAPI, ImportGraphDef) {
|
|||
TF_DeleteStatus(s);
|
||||
}
|
||||
|
||||
TEST(CAPI, ImportGraphDef_WithReturnOutputs) {
|
||||
TF_Status* s = TF_NewStatus();
|
||||
TF_Graph* graph = TF_NewGraph();
|
||||
|
||||
// Create a graph with two nodes: x and 3
|
||||
Placeholder(graph, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
ASSERT_TRUE(TF_GraphOperationByName(graph, "feed") != nullptr);
|
||||
TF_Operation* oper = ScalarConst(3, graph, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
ASSERT_TRUE(TF_GraphOperationByName(graph, "scalar") != nullptr);
|
||||
Neg(oper, graph, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
ASSERT_TRUE(TF_GraphOperationByName(graph, "neg") != nullptr);
|
||||
|
||||
// Export to a GraphDef.
|
||||
TF_Buffer* graph_def = TF_NewBuffer();
|
||||
TF_GraphToGraphDef(graph, graph_def, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
|
||||
// Import it in a fresh graph with return outputs.
|
||||
TF_DeleteGraph(graph);
|
||||
graph = TF_NewGraph();
|
||||
TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
|
||||
TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0);
|
||||
TF_ImportGraphDefOptionsAddReturnOutput(opts, "scalar", 0);
|
||||
EXPECT_EQ(2, TF_ImportGraphDefOptionsNumReturnOutputs(opts));
|
||||
TF_Output return_outputs[2];
|
||||
TF_GraphImportGraphDefWithReturnOutputs(graph, graph_def, opts,
|
||||
return_outputs, 2, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
|
||||
TF_Operation* scalar = TF_GraphOperationByName(graph, "scalar");
|
||||
TF_Operation* feed = TF_GraphOperationByName(graph, "feed");
|
||||
TF_Operation* neg = TF_GraphOperationByName(graph, "neg");
|
||||
ASSERT_TRUE(scalar != nullptr);
|
||||
ASSERT_TRUE(feed != nullptr);
|
||||
ASSERT_TRUE(neg != nullptr);
|
||||
|
||||
// Check return outputs
|
||||
EXPECT_EQ(feed, return_outputs[0].oper);
|
||||
EXPECT_EQ(0, return_outputs[0].index);
|
||||
EXPECT_EQ(scalar, return_outputs[1].oper);
|
||||
EXPECT_EQ(0, return_outputs[1].index);
|
||||
|
||||
TF_DeleteImportGraphDefOptions(opts);
|
||||
TF_DeleteBuffer(graph_def);
|
||||
TF_DeleteGraph(graph);
|
||||
TF_DeleteStatus(s);
|
||||
}
|
||||
|
||||
TEST(CAPI, ImportGraphDef_UnusedInputMappings) {
|
||||
TF_Status* s = TF_NewStatus();
|
||||
TF_Graph* graph = TF_NewGraph();
|
||||
|
||||
// Create a graph with two nodes: x and 3
|
||||
Placeholder(graph, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
ASSERT_TRUE(TF_GraphOperationByName(graph, "feed") != nullptr);
|
||||
TF_Operation* oper = ScalarConst(3, graph, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
ASSERT_TRUE(TF_GraphOperationByName(graph, "scalar") != nullptr);
|
||||
Neg(oper, graph, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
ASSERT_TRUE(TF_GraphOperationByName(graph, "neg") != nullptr);
|
||||
|
||||
// Export to a GraphDef.
|
||||
TF_Buffer* graph_def = TF_NewBuffer();
|
||||
TF_GraphToGraphDef(graph, graph_def, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
|
||||
// Import it in a fresh graph.
|
||||
TF_DeleteGraph(graph);
|
||||
graph = TF_NewGraph();
|
||||
TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
|
||||
TF_GraphImportGraphDef(graph, graph_def, opts, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
|
||||
TF_Operation* scalar = TF_GraphOperationByName(graph, "scalar");
|
||||
|
||||
// Import it in a fresh graph with an unused input mapping.
|
||||
TF_DeleteImportGraphDefOptions(opts);
|
||||
opts = TF_NewImportGraphDefOptions();
|
||||
TF_ImportGraphDefOptionsSetPrefix(opts, "imported");
|
||||
TF_ImportGraphDefOptionsAddInputMapping(opts, "scalar", 0, {scalar, 0});
|
||||
TF_ImportGraphDefOptionsAddInputMapping(opts, "fake", 0, {scalar, 0});
|
||||
TF_ImportGraphDefResults* results =
|
||||
TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
|
||||
// Check unused input mappings
|
||||
int num_unused_input_mappings;
|
||||
const char** src_names;
|
||||
int* src_indexes;
|
||||
TF_ImportGraphDefResultsUnusedInputMappings(
|
||||
results, &num_unused_input_mappings, &src_names, &src_indexes);
|
||||
ASSERT_EQ(1, num_unused_input_mappings);
|
||||
EXPECT_EQ(string("fake"), string(src_names[0]));
|
||||
EXPECT_EQ(0, src_indexes[0]);
|
||||
|
||||
TF_DeleteImportGraphDefResults(results);
|
||||
TF_DeleteImportGraphDefOptions(opts);
|
||||
TF_DeleteBuffer(graph_def);
|
||||
TF_DeleteGraph(graph);
|
||||
TF_DeleteStatus(s);
|
||||
}
|
||||
|
||||
TEST(CAPI, Session) {
|
||||
TF_Status* s = TF_NewStatus();
|
||||
TF_Graph* graph = TF_NewGraph();
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@ class GraphConstructor {
|
|||
bool skip_mapped_nodes;
|
||||
std::vector<string> control_dependencies;
|
||||
std::vector<TensorId> return_tensors;
|
||||
std::vector<StringPiece> return_nodes;
|
||||
std::vector<string> return_nodes;
|
||||
|
||||
// TODO(ashankar): This bool exists to separate out functionality required
|
||||
// to make ImportGraphDef a close equivalent of Python's import_graph_def
|
||||
|
|
|
|||
|
|
@ -110,7 +110,7 @@ struct ImportGraphDefOptions {
|
|||
// Unlike `return_tensors`, `input_map` has no effect on the nodes
|
||||
// returned. `return_nodes` must be empty if `skip_mapped_nodes` is true.
|
||||
// TODO(skyewm): make this work with `skip_mapped_nodes` if there's a need.
|
||||
std::vector<StringPiece> return_nodes;
|
||||
std::vector<string> return_nodes;
|
||||
|
||||
// TODO(ashankar): Enable handling of GraphDefs produced by newer binaries
|
||||
// with ops that are not defined in the binary calling ImportGraphDef.
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user