mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: As GoogleTest `TEST` macro is non-compliant with it as well as `DEFINE_DISPATCH` All changes but the ones to `.clang-tidy` are generated using following script: ``` for i in `find . -type f -iname "*.c*" -or -iname "*.h"|xargs grep cppcoreguidelines-avoid-non-const-global-variables|cut -f1 -d:|sort|uniq`; do sed -i "/\/\/ NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)/d" $i; done ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/62008 Reviewed By: driazati, r-barnes Differential Revision: D29838584 Pulled By: malfet fbshipit-source-id: 1b2f8602c945bd4ce50a9bfdd204755556e31d13
523 lines
18 KiB
C++
523 lines
18 KiB
C++
#include "caffe2/operators/load_save_op.h"
|
|
|
|
#if CAFFE2_HAVE_RE2
|
|
#include <re2/re2.h>
|
|
#else
|
|
#include <regex>
|
|
#endif
|
|
|
|
namespace caffe2 {
|
|
|
|
template <>
|
|
void LoadOp<CPUContext>::SetCurrentDevice(BlobProto* proto) {
|
|
if (proto->has_tensor()) {
|
|
proto->mutable_tensor()->clear_device_detail();
|
|
proto->mutable_tensor()->mutable_device_detail()->set_device_type(
|
|
PROTO_CPU);
|
|
}
|
|
}
|
|
|
|
template <int VALUE_TYPE = TensorProto_DataType_FLOAT>
|
|
std::vector<TensorShape> LoadTensorInference(
|
|
const OperatorDef& def,
|
|
const vector<TensorShape>& /* unused */) {
|
|
ArgumentHelper helper(def);
|
|
auto shape = helper.GetRepeatedArgument<int64_t>("shape");
|
|
vector<TensorShape> out;
|
|
// Currently load op supports only shape.
|
|
// TODO: We have to extend it to support shapes vector.
|
|
// Since it support just one shape, we return
|
|
// the right shape information only when there is just one blob loaded.
|
|
// Otherwise, we return unknown TensorShapes.
|
|
if (def.output_size() == 1 && shape.size() > 0) {
|
|
TensorShape ts;
|
|
ts.set_data_type(static_cast<TensorProto_DataType>(
|
|
helper.GetSingleArgument<int>("dtype", VALUE_TYPE)));
|
|
for (auto d : shape) {
|
|
ts.add_dims(d);
|
|
}
|
|
out.push_back(ts);
|
|
} else {
|
|
for (int i = 0; i < def.output_size(); i++) {
|
|
TensorShape ts;
|
|
ts.set_unknown_shape(true);
|
|
out.push_back(ts);
|
|
}
|
|
}
|
|
return out;
|
|
}
|
|
|
|
namespace internal {
|
|
|
|
// Reads and validates the arguments of a Save-style operator.
//
// Arguments consumed from `op`: strip_prefix, db_type, db_options,
// blob_name_overrides, absolute_path, db, options and chunk_size.
// `operator_def` provides the input blob names and `ws` provides the root
// folder that is prepended to the db name when `absolute_path` is not set.
//
// Enforces that a db type and db name are given, that blob_name_overrides
// (when present) has one entry per input, that strip_prefix and
// blob_name_overrides are not combined, and that the derived blob names are
// unique.
SaveOpImpl::SaveOpImpl(
    OperatorBase* op,
    const OperatorDef& operator_def,
    Workspace* ws)
    : operator_(op),
      strip_prefix_(op->template GetSingleArgument<string>("strip_prefix", "")),
      db_type_(op->template GetSingleArgument<string>("db_type", "")),
      db_options_(op->template GetSingleArgument<string>("db_options", "")),
      blob_names_(
          op->template GetRepeatedArgument<string>("blob_name_overrides")) {
  CAFFE_ENFORCE_GT(db_type_.size(), 0, "Must specify a db type.");
  CAFFE_ENFORCE(
      blob_names_.empty() || blob_names_.size() == op->Inputs().size(),
      "Number of blobs and blob_name_overrides mismatch.");
  CAFFE_ENFORCE(
      blob_names_.empty() || strip_prefix_.empty(),
      "strip_prefix and blob_name_overrides are mutually exclusive.");

  // Resolve the db location, relative to the workspace root by default.
  auto absolute_path =
      op->template GetSingleArgument<int>("absolute_path", false);
  auto db_name = op->template GetSingleArgument<string>("db", "");
  CAFFE_ENFORCE_GT(db_name.size(), 0, "Must specify a db name.");
  if (absolute_path) {
    full_db_name_ = db_name;
  } else {
    full_db_name_ = ws->RootFolder() + "/" + db_name;
  }

  // Optional serialized serialization options.
  auto options_data = op->template GetSingleArgument<string>("options", "");
  if (!options_data.empty() && !options_.ParseFromString(options_data)) {
    CAFFE_ENFORCE(false, "unable to parse serialization options");
  }

  if (op->template HasSingleArgumentOfType<int>("chunk_size")) {
    // The chunk_size argument pre-dates the options argument. When it is
    // passed in, it is appended to the options list as a final default.
    //
    // The legacy argument used 0 for "no chunking" and -1 for "default chunk
    // size", which is the reverse of the chunk_size field in
    // BlobSerializationOptions (protobuf v3 has no custom defaults, so 0 has
    // to mean "default" there). Translate the two legacy values.
    constexpr int kOldDefaultChunkSize = -1;
    constexpr int kOldNoChunking = 0;
    int chunk_size =
        op->template GetSingleArgument<int>("chunk_size", kDefaultChunkSize);
    switch (chunk_size) {
      case kOldDefaultChunkSize:
        chunk_size = kDefaultChunkSize;
        break;
      case kOldNoChunking:
        chunk_size = kNoChunking;
        break;
      default:
        break;
    }
    options_.mutable_options()->Add()->set_chunk_size(chunk_size);
  }

  if (!blob_names_.empty()) {
    // Explicit overrides were supplied; there is nothing left to derive.
    return;
  }

  // Derive the saved names from the operator inputs, dropping everything up
  // to and including the first occurrence of strip_prefix when it is found.
  const int num_inputs = static_cast<int>(op->Inputs().size());
  blob_names_.resize(num_inputs);
  std::set<std::string> input_names;
  for (int i = 0; i < num_inputs; ++i) {
    const auto& original_name = operator_def.input(i);
    size_t keep_from = 0;
    if (!strip_prefix_.empty()) {
      const auto match_pos = original_name.find(strip_prefix_);
      if (match_pos != string::npos) {
        keep_from = match_pos + strip_prefix_.size();
      }
    }
    std::string name = original_name.substr(keep_from);
    CAFFE_ENFORCE(
        input_names.insert(name).second, "Duplicated input: ", name);
    blob_names_[i] = name;
  }
}
|
|
|
|
namespace {
|
|
// Selects the serialization options that apply to `blob_name`.
//
// Entries of `options_list` are examined in order. An entry with an empty
// blob_name_regex applies to every blob; otherwise the entry applies when its
// regex matches the whole blob name. The first applicable entry is returned,
// and `default_options` is returned when no entry applies.
const BlobSerializationOptions& GetBlobOptions(
    c10::string_view blob_name,
    const SerializationOptions& options_list,
    const BlobSerializationOptions& default_options) {
  const auto name_matches = [&blob_name](const std::string& pattern) -> bool {
#if CAFFE2_HAVE_RE2
    // If we have re2, prefer it over std::regex.
    re2::RE2 regex(pattern);
    return re2::RE2::FullMatch(
        re2::StringPiece(blob_name.data(), blob_name.size()), regex);
#else
    // std::regex should be avoided if at all possible, but use it as a
    // fallback if we don't have re2 (e.g., for some issues with it see
    // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=61582)
    return std::regex_match(
        blob_name.begin(), blob_name.end(), std::regex(pattern));
#endif
  };

  for (const auto& candidate : options_list.options()) {
    const auto& pattern = candidate.blob_name_regex();
    if (pattern.empty() || name_matches(pattern)) {
      return candidate;
    }
  }
  return default_options;
}
|
|
} // namespace
|
|
|
|
bool SaveOpImpl::RunOnDevice() {
|
|
std::unique_ptr<DB> out_db(
|
|
caffe2::db::CreateDB(db_type_, full_db_name_, caffe2::db::NEW));
|
|
CAFFE_ENFORCE(
|
|
out_db.get(),
|
|
"Cannot find db implementation of type ",
|
|
db_type_,
|
|
" (while trying to open ",
|
|
full_db_name_,
|
|
")");
|
|
if (!db_options_.empty()) {
|
|
out_db->SetOptions(db_options_);
|
|
}
|
|
|
|
BlobSerializerBase::SerializationAcceptor acceptor =
|
|
[&](const std::string& blobName, std::string&& data) {
|
|
// transaction should take care of locking
|
|
VLOG(2) << "Sending " << blobName << " blob's data of size "
|
|
<< data.size() << " to db";
|
|
auto transaction = out_db->NewTransaction();
|
|
transaction->Put(blobName, std::move(data));
|
|
transaction->Commit();
|
|
};
|
|
|
|
const vector<const Blob*>& inputs = operator_->OperatorBase::Inputs();
|
|
VLOG(0) << "Saving " << inputs.size() << " inputs to " << db_type_ << ": "
|
|
<< full_db_name_;
|
|
BlobSerializationOptions default_options;
|
|
// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
|
|
for (int i = 0; i < inputs.size(); ++i) {
|
|
SerializeBlob(
|
|
*inputs[i],
|
|
blob_names_[i],
|
|
acceptor,
|
|
GetBlobOptions(blob_names_[i], options_, default_options));
|
|
}
|
|
out_db->Close();
|
|
return true;
|
|
}
|
|
|
|
} // namespace internal
|
|
|
|
namespace {
|
|
// Operator that reports, for every blob name known to the workspace, the
// blob's name (output 0) and its estimated serialized size (output 1).
//
// Arguments:
//   include_shared - when non-zero (the default) names come from
//                    Workspace::Blobs(), otherwise from
//                    Workspace::LocalBlobs().
//   options        - optional serialized SerializationOptions used to pick
//                    the per-blob serialization options for the estimate.
class EstimateAllBlobSizesOp final : public Operator<CPUContext> {
 public:
  explicit EstimateAllBlobSizesOp(
      const OperatorDef& operator_def,
      Workspace* ws)
      : Operator<CPUContext>(operator_def, ws),
        include_shared_(GetSingleArgument<int>("include_shared", true)),
        ws_(ws) {
    auto options_data = GetSingleArgument<string>("options", "");
    if (!options_data.empty()) {
      if (!options_.ParseFromString(options_data)) {
        CAFFE_ENFORCE(false, "unable to parse serialization options");
      }
    }
  }

  // Fills both outputs with one entry per blob name. A name whose blob
  // cannot be fetched is logged and reported with a size of 0, so that every
  // slot of both outputs is always written.
  bool RunOnDevice() override {
    const auto& blob_names = include_shared_ ? ws_->Blobs() : ws_->LocalBlobs();
    const auto num_blobs = static_cast<int64_t>(blob_names.size());
    auto* names_out = Output(0, {num_blobs}, at::dtype<std::string>());
    auto* sizes_out = Output(1, {num_blobs}, at::dtype<int64_t>());
    BlobSerializationOptions default_options;
    for (size_t idx = 0; idx < blob_names.size(); ++idx) {
      const auto& name = blob_names[idx];
      // Record the name first so the two outputs stay aligned even when the
      // blob itself turns out to be unavailable.
      names_out->template mutable_data<std::string>()[idx] = name;

      auto* blob = ws_->GetBlob(name);
      if (!blob) {
        LOG(ERROR) << "unable to find blob " << name
                   << " when estimating serialization size";
        // Previously this slot was skipped and kept whatever the freshly
        // allocated tensor happened to contain.
        sizes_out->template mutable_data<int64_t>()[idx] = 0;
        continue;
      }

      const auto& blob_serialization_options =
          internal::GetBlobOptions(name, options_, default_options);
      sizes_out->template mutable_data<int64_t>()[idx] =
          EstimateSerializedBlobSize(*blob, name, blob_serialization_options);
    }
    return true;
  }

 private:
  bool include_shared_{true};
  Workspace* ws_{nullptr};
  SerializationOptions options_;
};
|
|
} // namespace
|
|
|
|
// CPU registrations for the load/save operator family.
REGISTER_CPU_OPERATOR(DBExists, DBExistsOp<CPUContext>);
REGISTER_CPU_OPERATOR(Load, LoadOp<CPUContext>);
REGISTER_CPU_OPERATOR(Save, SaveOp<CPUContext>);
REGISTER_CPU_OPERATOR(Checkpoint, CheckpointOp<CPUContext>);

// CPU Operator old name: do NOT use, we may deprecate this later.
// Snapshot is backed by the same implementation as Checkpoint.
REGISTER_CPU_OPERATOR(Snapshot, CheckpointOp<CPUContext>);

REGISTER_CPU_OPERATOR(EstimateAllBlobSizes, EstimateAllBlobSizesOp);
|
|
|
|
// Schema for DBExists: no inputs, one output, plus the documented arguments.
OPERATOR_SCHEMA(DBExists)
    .NumInputs(0)
    .NumOutputs(1)
    .SetDoc(R"DOC(
Checks if the db described by the arguments exists.

Github Links:

- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/load_save_op.cc

<details>

<summary> <b>Example</b> </summary>

**Code**

```
workspace.ResetWorkspace()

op = core.CreateOperator(
"DBExists",
[],
["exists"],
db_name="test_db",
db_type="leveldb",
)

workspace.RunOperatorOnce(op)
print("exists:", workspace.FetchBlob("exists"))

```

</details>

)DOC")
    .Output(
        0,
        "exists",
        "*(type: Tensor`<bool>`)* Scalar boolean output "
        "tensor. True if the db exists, else false.")
    .Arg(
        "absolute_path",
        "*(type: int; default: 0)* If set to non-zero, save the db directly to "
        "the path specified by the `db` arg. If not set (default), prepend the "
        "path of the current root folder of the workspace to the path specified "
        "by the `db` arg.")
    .Arg(
        "db_name",
        "*(type: string)* Path to the db in question; see the "
        "`absolute_path` arg details for options regarding the current root folder "
        "of the workspace.")
    .Arg(
        "db_type",
        "*(type: string)* Type of db to save (options: \"lmdb\", "
        "\"leveldb\", \"minidb\").");
|
|
|
|
// Schema for Load: any number of inputs (optional DBReaders) and outputs.
// Output shapes are inferred by LoadTensorInference.
OPERATOR_SCHEMA(Load)
    .NumInputs(0, INT_MAX)
    .NumOutputs(0, INT_MAX)
    .TensorInferenceFunction(LoadTensorInference<>)
    .SetDoc(R"DOC(
The Load operator loads a set of serialized blobs from a db or multiple dbs. It
takes $[0, \infty)$ number of inputs and $[0, \infty)$ number of outputs, using
the db keys to match the db entries with the outputs.

If at least one input is passed, then it is assumed that the input blobs are a
set of DBReaders to load from. Otherwise the `db` or `dbs` argument is used to load
blobs from one single db or multiple dbs respectively. `db_type` argument is used
to specify the type of the input db/dbs.

Github Links:

- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/load_save_op.cc

<details>

<summary> <b>Example</b> </summary>

**Code**

```
workspace.ResetWorkspace()

op = core.CreateOperator(
"Load",
[],
["X", "Y"],
db="test_db",
db_type="lmdb"
)

workspace.RunOperatorOnce(op)
print("X:", workspace.FetchBlob("X"))
print("Y:", workspace.FetchBlob("Y"))

```

</details>

)DOC")
    .Input(
        0,
        "X, Y, ...",
        "*(type: List(DBReader))* [OPTIONAL] List of DBReaders to load from. Can "
        "use this instead of the `db`/`dbs` args.")
    .Arg(
        "absolute_path",
        "*(type: int; default: 0)* If set to non-zero, save the db directly to "
        "the path specified by the `db` arg. If not set (default), prepend the "
        "path of the current root folder of the workspace to the path specified "
        "by the `db` arg.")
    .Arg(
        "add_prefix",
        "*(type: string, default: \"\")* Blobs will be prefixed with this when "
        "loading. Useful for avoiding collisions with blobs existing in the "
        "workspace. The output blob names specified to this op should include "
        "this prefix.")
    .Arg(
        "strip_prefix",
        "*(type: string, default: \"\")* Characters in the provided blob names "
        "that match `strip_prefix` will be removed prior to saving. Also, "
        "characters that precede `strip_prefix` will be removed. Useful for "
        "removing device scope from blob names.")
    .Arg(
        "db",
        "*(type: string)* The output path of the db. See the "
        "`absolute_path` arg details for options regarding the current root folder "
        "of the workspace.")
    .Arg(
        "dbs",
        "*(type: List(string))* List of paths to dbs to load blobs from. See "
        "the `absolute_path` arg details for options regarding the current "
        "root folder of the workspace.")
    // The type annotation uses the same `*( ... )*` emphasis markers as the
    // other arguments; the opening marker used to be missing here.
    .Arg(
        "db_type",
        "*(type: string)* Type of db to save (options: \"lmdb\", "
        "\"leveldb\", \"minidb\").")
    .Arg(
        "keep_device",
        "*(type: int; default: 0)* If nonzero, the blobs are loaded into the "
        "device that is specified in the serialized `BlobProto`. Otherwise, "
        "the device will be set as the one that the `Load` operator is being "
        "run under.")
    .Arg(
        "load_all",
        "*(type: int; default: 0)* If nonzero, will load all blobs pointed to "
        "by the db to the workspace overwriting/creating blobs as needed.")
    .Arg(
        "allow_incomplete",
        "*(type: bool; default: False)* If True, will allow not loading all "
        "the output blobs specified in the outputs.")
    .Arg(
        "source_blob_names",
        "*(type: List(string))* If set, used instead of output blob names to "
        "specify which blobs in the db shall be loaded. Must be the same "
        "length as number of output blobs.");
|
|
|
|
// Schema for Save: one or more input blobs, no outputs. The documented
// arguments mirror what the SaveOpImpl constructor reads.
OPERATOR_SCHEMA(Save)
    .NumInputs(1, INT_MAX)
    .NumOutputs(0)
    .SetDoc(R"DOC(
Saves a set of blobs to a db. It takes $[1, \infty)$ number of inputs and has
no output. The contents of the inputs are written into the db using the
settings specified by the arguments.

Github Links:

- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/load_save_op.cc

<details>

<summary> <b>Example</b> </summary>

**Code**

```
workspace.ResetWorkspace()

op = core.CreateOperator(
"Save",
["X", "Y", "Z"],
[],
db="test_db2",
db_type="leveldb",
blob_name_overrides=["x_scores", "y_scores", "z_scores"]
)

workspace.FeedBlob("X", np.random.randint(20, size=(5,5)))
workspace.FeedBlob("Y", np.random.randint(20, size=(5,5)))
workspace.FeedBlob("Z", np.random.randint(20, size=(5,5)))
workspace.RunOperatorOnce(op)

```

</details>

)DOC")
    .Arg(
        "absolute_path",
        "*(type: int; default: 0)* If set to non-zero, save the db directly to "
        "the path specified by the `db` arg. If not set (default), prepend the "
        "path of the current root folder of the workspace to the path specified "
        "by the `db` arg.")
    .Arg(
        "strip_prefix",
        "*(type: string, default: \"\")* Characters in the provided blob names "
        "that match `strip_prefix` will be removed prior to saving. Also, "
        "characters that precede `strip_prefix` will be removed. Useful for "
        "removing device scope from blob names.")
    .Arg(
        "blob_name_overrides",
        "*(type: List(string))* If set, used as blob names instead of original "
        "blob names. Must be same length as number of blobs.")
    .Arg(
        "db",
        "*(type: string)* The output path of the db. See the "
        "`absolute_path` arg details for options regarding the current root folder "
        "of the workspace.")
    .Arg(
        "db_type",
        "*(type: string)* Type of db to save (options: \"lmdb\", "
        "\"leveldb\", \"minidb\").")
    // Passed to the db through SetOptions() when non-empty.
    .Arg(
        "db_options",
        "*(type: string, default: \"\")* If non-empty, passed to the db "
        "implementation as its options string after the db is created.")
    // Parsed with ParseFromString(); a parse failure aborts construction.
    .Arg(
        "options",
        "*(type: string, default: \"\")* A serialized `SerializationOptions` "
        "message specifying how specific blobs should be serialized.")
    // The operator reads chunk_size as an int, not a string.
    .Arg(
        "chunk_size",
        "*(type: int; default: kDefaultChunkSize)* The chunk "
        "size to split tensor data into. If not set, caffe2_tensor_chunk_size will "
        "be used.")
    .Input(0, "X", "*(type: Tensor)* Input tensor(s).");
|
|
|
|
// Schema for Checkpoint: one or more inputs (the first being the iteration
// counter, per the doc text below) and no outputs.
OPERATOR_SCHEMA(Checkpoint)
    .NumInputs(1, INT_MAX)
    .NumOutputs(0)
    .SetDoc(R"DOC(
The Checkpoint operator is similar to the Save operator, but allows one to save
to db every few iterations, with a db name that is appended with the iteration
count. It takes [1, infinity) number of inputs and has no output. The first
input has to be a TensorCPU of type int and has size 1 (i.e. the iteration
counter). This is determined whether we need to do checkpointing.
)DOC")
    .Arg(
        "absolute_path",
        "(int, default 0) if set, use the db path directly and do not prepend "
        "the current root folder of the workspace.")
    .Arg(
        "db",
        "(string) a template string that one can combine with the "
        "iteration to create the final db name. For example, "
        "\"/home/lonestarr/checkpoint_%08d.db\"")
    .Arg(
        "db_type",
        "(string) the type of the db.")
    .Arg(
        "every",
        "(int, default 1) the checkpointing is carried out when "
        "(iter mod every) is zero.");

// Snapshot is the old name of Checkpoint; its schema is intentionally left
// without constraints or documentation.
OPERATOR_SCHEMA(Snapshot);
|
|
|
|
// Schema for EstimateAllBlobSizes: no inputs, two parallel 1D outputs
// (blob names and estimated serialized sizes).
OPERATOR_SCHEMA(EstimateAllBlobSizes)
    .NumInputs(0)
    .NumOutputs(2)
    .SetDoc(R"DOC(
Returns two outputs: a 1D tensor of strings containing the names
of each blob in the active workspace, and a 1D tensor of integers containing the
estimated serialized size of each blob (in bytes).
)DOC")
    .Arg(
        "include_shared",
        "(bool, default true) Whether to include blobs "
        "inherited from parent workspaces.")
    // The operator parses this string into a SerializationOptions message
    // (the list type), not into a single BlobSerializationOptions entry.
    .Arg(
        "options",
        "(string, default empty) A serialized SerializationOptions message "
        "specifying options for how specific blobs should be serialized.")
    .Output(0, "blob_names", "1D tensor of strings containing blob names.")
    .Output(1, "blob_sizes", "1D tensor of int64_t containing blob sizes.");
|
|
|
|
// None of the load/save operators take part in gradient computation.
NO_GRADIENT(Load);
SHOULD_NOT_DO_GRADIENT(DBExists);
SHOULD_NOT_DO_GRADIENT(Save);
SHOULD_NOT_DO_GRADIENT(Checkpoint);
SHOULD_NOT_DO_GRADIENT(Snapshot);
// Gradient entries are keyed by the registered operator name, which is
// "EstimateAllBlobSizes" (the C++ class name carries an extra "Op" suffix).
SHOULD_NOT_DO_GRADIENT(EstimateAllBlobSizes);
|
|
|
|
} // namespace caffe2
|