#tf-data TFRecordWriter checks the return value of SerializeToString.

PiperOrigin-RevId: 516594306
This commit is contained in:
Yang Chen 2023-03-14 12:05:10 -07:00 committed by TensorFlower Gardener
parent 7403ddf320
commit aaf1f47074
2 changed files with 24 additions and 2 deletions

View File

@ -463,6 +463,7 @@ cc_library(
"//tensorflow/core/platform:random",
"//tensorflow/core/profiler/lib:traceme",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],
)

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/data/snapshot_utils.h"
#include <algorithm>
#include <climits>
#include <functional>
#include <memory>
#include <optional>
@ -25,6 +26,7 @@ limitations under the License.
#include <vector>
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/framework/dataset.h"
@ -62,6 +64,18 @@ constexpr const char* const kCurrentCheckpointID = "current_checkpoint_id";
constexpr const char* const kIndex = "index";
constexpr const char* const kStartIndex = "start_index";
// Builds the error text reported when `proto` could not be serialized for
// writing to `output_file`.
//
// The message always states the proto's byte size (as reported by
// `ByteSizeLong()`) and the destination file. When that size is greater than
// INT_MAX, a suffix is added pointing at the 2GB protobuf size limit.
std::string ProtoSerializationErrorMessage(const TensorProto& proto,
                                           const std::string& output_file) {
  const auto serialized_size = proto.ByteSizeLong();
  // Only mention the size limit when the proto is actually over it.
  const bool over_size_limit = serialized_size > INT_MAX;
  return absl::StrCat(
      "Failed to serialize tensor proto of ", serialized_size,
      " bytes to file: ", output_file,
      over_size_limit ? ": exceeded maximum protobuf size of 2GB." : "");
}
} // namespace
/* static */ constexpr const int64_t
@ -142,13 +156,20 @@ Status TFRecordWriter::WriteTensors(const std::vector<Tensor>& tensors) {
// will result in a smart pointer being moved upon function creation, which
// will result in proto_buffer == nullptr when WriteRecord happens.
auto* proto_buffer = new std::string();
proto.SerializeToString(proto_buffer);
if (!proto.SerializeToString(proto_buffer)) {
delete proto_buffer;
return errors::DataLoss(ProtoSerializationErrorMessage(proto, filename_));
}
absl::Cord proto_serialized = absl::MakeCordFromExternal(
*proto_buffer,
[proto_buffer](absl::string_view) { delete proto_buffer; });
TF_RETURN_IF_ERROR(record_writer_->WriteRecord(proto_serialized));
#else // TF_CORD_SUPPORT
TF_RETURN_IF_ERROR(record_writer_->WriteRecord(proto.SerializeAsString()));
std::string proto_serialized;
if (!proto.SerializeToString(&proto_serialized)) {
return errors::DataLoss(ProtoSerializationErrorMessage(proto, filename_));
}
TF_RETURN_IF_ERROR(record_writer_->WriteRecord(proto_serialized));
#endif // TF_CORD_SUPPORT
}
return OkStatus();