mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Since caffe2 and torch have been consolidated, CAFFE2_API should be merged with TORCH_API. Addresses a TODO. Manually edited some references of the removed `CAFFE2_API`: * `CONTRIBUTING.md` * `caffe2/proto/CMakeLists.txt` * `cmake/ProtoBuf.cmake` * `c10/macros/Export.h` * `torch/csrc/WindowsTorchApiMacro.h` Pull Request resolved: https://github.com/pytorch/pytorch/pull/49496 Reviewed By: malfet, samestep Differential Revision: D25600726 Pulled By: janeyx99 fbshipit-source-id: 7e068d959e397ac183c097d7e9a9afeca5ddd782
61 lines
1.6 KiB
C++
61 lines
1.6 KiB
C++
#ifndef CAFFE2_OPERATORS_LOAD_SAVE_OP_UTIL_H_
|
|
#define CAFFE2_OPERATORS_LOAD_SAVE_OP_UTIL_H_
|
|
|
|
#include <cstdint>
#include <set>
#include <string>
#include <unordered_map>

#include "caffe2/core/blob.h"
#include "caffe2/core/blob_serialization.h"
|
|
|
|
namespace caffe2 {
|
|
namespace load_save_op_util {
|
|
|
|
struct BlobState {
|
|
int64_t total_size;
|
|
int64_t current_size;
|
|
bool is_tensor;
|
|
std::set<int32_t> seen_chunks_ids;
|
|
|
|
explicit BlobState(
|
|
int64_t total_size = 0,
|
|
int64_t current_size = 0,
|
|
bool is_tensor = false)
|
|
: total_size(total_size),
|
|
current_size(current_size),
|
|
is_tensor(is_tensor) {}
|
|
};
|
|
|
|
TORCH_API std::string buildBlobNameFromDbKey(
|
|
const std::string& dbKey,
|
|
const std::string& strip_prefix = "",
|
|
const std::string& add_prefix = "");
|
|
|
|
// We are tracking sizes of already read tensor parts while reading data
|
|
// chunks. This way we can make sure that all chunks were loaded in the end.
|
|
TORCH_API void ProcessBlob(
|
|
Blob* blob,
|
|
const BlobProto& proto,
|
|
std::unordered_map<std::string, BlobState>* blob_states_ptr,
|
|
const std::string& key,
|
|
int* loaded_blobs);
|
|
|
|
TORCH_API void prepareBlob(
|
|
Blob* blob,
|
|
std::unordered_map<std::string, BlobState>* blob_states_ptr,
|
|
const std::string& key);
|
|
|
|
TORCH_API void updateBlobStates(
|
|
const BlobProto& proto,
|
|
std::unordered_map<std::string, BlobState>* blob_states_ptr,
|
|
const std::string& key,
|
|
int* loaded_blobs);
|
|
|
|
TORCH_API void validateBlobStates(
|
|
const std::unordered_map<std::string, BlobState>& blob_states);
|
|
|
|
} // namespace load_save_op_util
|
|
} // namespace caffe2
|
|
|
|
#endif // CAFFE2_OPERATORS_LOAD_SAVE_OP_UTIL_H_
|