mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
C API: Avoid converting uninitialized tensorflow::Tensor to TF_Tensor*
And return error messages instead of CHECK failing when the conversion fails. PiperOrigin-RevId: 163863981
This commit is contained in:
parent
9593704b28
commit
96675956ef
|
|
@ -488,7 +488,13 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Non-static for testing.
|
// Non-static for testing.
|
||||||
TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src) {
|
TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
|
||||||
|
TF_Status* status) {
|
||||||
|
if (!src.IsInitialized()) {
|
||||||
|
status->status = FailedPrecondition(
|
||||||
|
"attempt to use a tensor with an uninitialized value");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
if (src.dtype() == DT_RESOURCE) {
|
if (src.dtype() == DT_RESOURCE) {
|
||||||
DCHECK_EQ(0, src.shape().dims()) << src.shape().DebugString();
|
DCHECK_EQ(0, src.shape().dims()) << src.shape().DebugString();
|
||||||
if (src.shape().dims() != 0) {
|
if (src.shape().dims() != 0) {
|
||||||
|
|
@ -528,18 +534,26 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src) {
|
||||||
char* dst = data_start; // Where next string is encoded.
|
char* dst = data_start; // Where next string is encoded.
|
||||||
size_t dst_len = size - static_cast<size_t>(data_start - base);
|
size_t dst_len = size - static_cast<size_t>(data_start - base);
|
||||||
tensorflow::uint64* offsets = reinterpret_cast<tensorflow::uint64*>(base);
|
tensorflow::uint64* offsets = reinterpret_cast<tensorflow::uint64*>(base);
|
||||||
TF_Status status;
|
|
||||||
for (int i = 0; i < srcarray.size(); ++i) {
|
for (int i = 0; i < srcarray.size(); ++i) {
|
||||||
*offsets = (dst - data_start);
|
*offsets = (dst - data_start);
|
||||||
offsets++;
|
offsets++;
|
||||||
const tensorflow::string& s = srcarray(i);
|
const tensorflow::string& s = srcarray(i);
|
||||||
size_t consumed =
|
size_t consumed = TF_StringEncode(s.data(), s.size(), dst, dst_len, status);
|
||||||
TF_StringEncode(s.data(), s.size(), dst, dst_len, &status);
|
if (!status->status.ok()) {
|
||||||
CHECK(status.status.ok());
|
status->status = InvalidArgument(
|
||||||
|
"invalid string tensor encoding (string #", i, " of ",
|
||||||
|
srcarray.size(), "): ", status->status.error_message());
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
dst += consumed;
|
dst += consumed;
|
||||||
dst_len -= consumed;
|
dst_len -= consumed;
|
||||||
}
|
}
|
||||||
CHECK_EQ(dst, base + size);
|
if (dst != base + size) {
|
||||||
|
status->status = InvalidArgument(
|
||||||
|
"invalid string tensor encoding (decoded ", (dst - base),
|
||||||
|
" bytes, but the tensor is encoded in ", size, " bytes");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
auto dims = src.shape().dim_sizes();
|
auto dims = src.shape().dim_sizes();
|
||||||
std::vector<tensorflow::int64> dimvec(dims.size());
|
std::vector<tensorflow::int64> dimvec(dims.size());
|
||||||
|
|
@ -650,7 +664,8 @@ static void TF_Run_Helper(
|
||||||
static_cast<TF_DataType>(src.dtype()), src.shape());
|
static_cast<TF_DataType>(src.dtype()), src.shape());
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
c_outputs[i] = TF_TensorFromTensor(src);
|
c_outputs[i] = TF_TensorFromTensor(src, status);
|
||||||
|
if (!status->status.ok()) return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1605,7 +1620,7 @@ void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name,
|
||||||
Tensor t;
|
Tensor t;
|
||||||
status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t);
|
status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t);
|
||||||
if (!status->status.ok()) return;
|
if (!status->status.ok()) return;
|
||||||
*value = TF_TensorFromTensor(t);
|
*value = TF_TensorFromTensor(t, status);
|
||||||
}
|
}
|
||||||
|
|
||||||
void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
|
void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
|
||||||
|
|
@ -1616,7 +1631,7 @@ void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
|
||||||
if (!status->status.ok()) return;
|
if (!status->status.ok()) return;
|
||||||
const auto len = std::min(max_values, static_cast<int>(ts.size()));
|
const auto len = std::min(max_values, static_cast<int>(ts.size()));
|
||||||
for (int i = 0; i < len; ++i) {
|
for (int i = 0; i < len; ++i) {
|
||||||
values[i] = TF_TensorFromTensor(ts[i]);
|
values[i] = TF_TensorFromTensor(ts[i], status);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -45,7 +45,7 @@ limitations under the License.
|
||||||
#include "tensorflow/core/util/equal_graph_def.h"
|
#include "tensorflow/core/util/equal_graph_def.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
TF_Tensor* TF_TensorFromTensor(const Tensor& src);
|
TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status);
|
||||||
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
|
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
@ -137,6 +137,7 @@ TEST(CAPI, LibraryLoadFunctions) {
|
||||||
|
|
||||||
void TestEncodeDecode(int line, const std::vector<string>& data) {
|
void TestEncodeDecode(int line, const std::vector<string>& data) {
|
||||||
const tensorflow::int64 n = data.size();
|
const tensorflow::int64 n = data.size();
|
||||||
|
TF_Status* status = TF_NewStatus();
|
||||||
for (const std::vector<tensorflow::int64>& dims :
|
for (const std::vector<tensorflow::int64>& dims :
|
||||||
std::vector<std::vector<tensorflow::int64>>{
|
std::vector<std::vector<tensorflow::int64>>{
|
||||||
{n}, {1, n}, {n, 1}, {n / 2, 2}}) {
|
{n}, {1, n}, {n, 1}, {n / 2, 2}}) {
|
||||||
|
|
@ -145,7 +146,8 @@ void TestEncodeDecode(int line, const std::vector<string>& data) {
|
||||||
for (tensorflow::int64 i = 0; i < src.NumElements(); ++i) {
|
for (tensorflow::int64 i = 0; i < src.NumElements(); ++i) {
|
||||||
src.flat<string>()(i) = data[i];
|
src.flat<string>()(i) = data[i];
|
||||||
}
|
}
|
||||||
TF_Tensor* dst = TF_TensorFromTensor(src);
|
TF_Tensor* dst = TF_TensorFromTensor(src, status);
|
||||||
|
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
|
||||||
// Convert back to a C++ Tensor and ensure we get expected output.
|
// Convert back to a C++ Tensor and ensure we get expected output.
|
||||||
Tensor output;
|
Tensor output;
|
||||||
|
|
@ -157,6 +159,7 @@ void TestEncodeDecode(int line, const std::vector<string>& data) {
|
||||||
|
|
||||||
TF_DeleteTensor(dst);
|
TF_DeleteTensor(dst);
|
||||||
}
|
}
|
||||||
|
TF_DeleteStatus(status);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(CAPI, TensorEncodeDecodeStrings) {
|
TEST(CAPI, TensorEncodeDecodeStrings) {
|
||||||
|
|
@ -914,7 +917,8 @@ TEST(CAPI, SavedModel) {
|
||||||
TF_Operation* input_op =
|
TF_Operation* input_op =
|
||||||
TF_GraphOperationByName(graph, input_op_name.c_str());
|
TF_GraphOperationByName(graph, input_op_name.c_str());
|
||||||
ASSERT_TRUE(input_op != nullptr);
|
ASSERT_TRUE(input_op != nullptr);
|
||||||
csession.SetInputs({{input_op, TF_TensorFromTensor(input)}});
|
csession.SetInputs({{input_op, TF_TensorFromTensor(input, s)}});
|
||||||
|
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||||
|
|
||||||
const tensorflow::string output_op_name =
|
const tensorflow::string output_op_name =
|
||||||
tensorflow::ParseTensorName(output_name).first.ToString();
|
tensorflow::ParseTensorName(output_name).first.ToString();
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user