mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
#tf-data TFRecordWriter checks the return value of SerializeToString.
PiperOrigin-RevId: 516594306
This commit is contained in:
parent
7403ddf320
commit
aaf1f47074
|
|
@ -463,6 +463,7 @@ cc_library(
|
|||
"//tensorflow/core/platform:random",
|
||||
"//tensorflow/core/profiler/lib:traceme",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||
#include "tensorflow/core/data/snapshot_utils.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <climits>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
|
|
@ -25,6 +26,7 @@ limitations under the License.
|
|||
#include <vector>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/core/common_runtime/dma_helper.h"
|
||||
#include "tensorflow/core/data/name_utils.h"
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
|
|
@ -62,6 +64,18 @@ constexpr const char* const kCurrentCheckpointID = "current_checkpoint_id";
|
|||
constexpr const char* const kIndex = "index";
|
||||
constexpr const char* const kStartIndex = "start_index";
|
||||
|
||||
std::string ProtoSerializationErrorMessage(const TensorProto& proto,
|
||||
const std::string& output_file) {
|
||||
const auto proto_byte_size = proto.ByteSizeLong();
|
||||
std::string error_message =
|
||||
absl::StrCat("Failed to serialize tensor proto of ", proto_byte_size,
|
||||
" bytes to file: ", output_file);
|
||||
if (proto_byte_size > INT_MAX) {
|
||||
absl::StrAppend(&error_message, ": exceeded maximum protobuf size of 2GB.");
|
||||
}
|
||||
return error_message;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
/* static */ constexpr const int64_t
|
||||
|
|
@ -142,13 +156,20 @@ Status TFRecordWriter::WriteTensors(const std::vector<Tensor>& tensors) {
|
|||
// will result in a smart pointer being moved upon function creation, which
|
||||
// will result in proto_buffer == nullptr when WriteRecord happens.
|
||||
auto* proto_buffer = new std::string();
|
||||
proto.SerializeToString(proto_buffer);
|
||||
if (!proto.SerializeToString(proto_buffer)) {
|
||||
delete proto_buffer;
|
||||
return errors::DataLoss(ProtoSerializationErrorMessage(proto, filename_));
|
||||
}
|
||||
absl::Cord proto_serialized = absl::MakeCordFromExternal(
|
||||
*proto_buffer,
|
||||
[proto_buffer](absl::string_view) { delete proto_buffer; });
|
||||
TF_RETURN_IF_ERROR(record_writer_->WriteRecord(proto_serialized));
|
||||
#else // TF_CORD_SUPPORT
|
||||
TF_RETURN_IF_ERROR(record_writer_->WriteRecord(proto.SerializeAsString()));
|
||||
std::string proto_serialized;
|
||||
if (!proto.SerializeToString(&proto_serialized)) {
|
||||
return errors::DataLoss(ProtoSerializationErrorMessage(proto, filename_));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(record_writer_->WriteRecord(proto_serialized));
|
||||
#endif // TF_CORD_SUPPORT
|
||||
}
|
||||
return OkStatus();
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user