#ifndef CAFFE2_OPERATORS_LOAD_SAVE_OP_UTIL_H_ #define CAFFE2_OPERATORS_LOAD_SAVE_OP_UTIL_H_ #include #include #include #include "caffe2/core/blob.h" #include "caffe2/core/blob_serialization.h" namespace caffe2 { namespace load_save_op_util { struct BlobState { int64_t total_size; int64_t current_size; bool is_tensor; std::set seen_chunks_ids; explicit BlobState( int64_t total_size = 0, int64_t current_size = 0, bool is_tensor = false) : total_size(total_size), current_size(current_size), is_tensor(is_tensor) {} }; TORCH_API std::string buildBlobNameFromDbKey( const std::string& dbKey, const std::string& strip_prefix = "", const std::string& add_prefix = ""); // We are tracking sizes of already read tensor parts while reading data // chunks. This way we can make sure that all chunks were loaded in the end. TORCH_API void ProcessBlob( Blob* blob, const BlobProto& proto, std::unordered_map* blob_states_ptr, const std::string& key, int* loaded_blobs); TORCH_API void prepareBlob( Blob* blob, std::unordered_map* blob_states_ptr, const std::string& key); TORCH_API void updateBlobStates( const BlobProto& proto, std::unordered_map* blob_states_ptr, const std::string& key, int* loaded_blobs); TORCH_API void validateBlobStates( const std::unordered_map& blob_states); } // namespace load_save_op_util } // namespace caffe2 #endif // CAFFE2_OPERATORS_LOAD_SAVE_OP_UTIL_H_