mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
53 lines
1.8 KiB
C++
53 lines
1.8 KiB
C++
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#include "tensorflow/c/python_api.h"
|
|
|
|
#include "tensorflow/c/c_api_internal.h"
|
|
|
|
namespace tensorflow {
|
|
|
|
void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input) {
|
|
mutex_lock l(graph->mu);
|
|
graph->graph.AddControlEdge(&input->node, &op->node);
|
|
}
|
|
|
|
void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
|
|
TF_Buffer* attr_value_proto, TF_Status* status) {
|
|
AttrValue attr_val;
|
|
if (!attr_val.ParseFromArray(attr_value_proto->data,
|
|
attr_value_proto->length)) {
|
|
status->status =
|
|
tensorflow::errors::InvalidArgument("Invalid AttrValue proto");
|
|
}
|
|
|
|
mutex_lock l(graph->mu);
|
|
op->node.AddAttr(attr_name, attr_val);
|
|
}
|
|
|
|
void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) {
|
|
mutex_lock l(graph->mu);
|
|
op->node.set_requested_device(device);
|
|
}
|
|
|
|
void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
|
|
TF_Status* status) {
|
|
mutex_lock l(graph->mu);
|
|
status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index,
|
|
&dst.oper->node, dst.index);
|
|
}
|
|
|
|
} // namespace tensorflow
|