Switch ScriptModuleOp to use a unique_ptr

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/29856

Test Plan: waitforsadcastle

Reviewed By: dzhulgakov

Differential Revision: D18516553

fbshipit-source-id: d1e2d49ec613d07b21cd30bd777fbd300032cba1
This commit is contained in:
James Reed 2019-11-14 19:34:13 -08:00 committed by Facebook Github Bot
parent 902c1f9ef1
commit 7a6c3b36a1
2 changed files with 16 additions and 11 deletions

View File

@ -18,10 +18,10 @@ class ScriptModuleSerializer : public BlobSerializerBase {
TypeMeta typeMeta,
const string& name,
SerializationAcceptor acceptor) override {
CAFFE_ENFORCE(typeMeta.Match<Module>());
CAFFE_ENFORCE(typeMeta.Match<std::unique_ptr<Module>>());
std::stringstream ss;
(*static_cast<const Module*>(pointer)).save(ss);
(*static_cast<const std::unique_ptr<Module>*>(pointer))->save(ss);
// NB: wrapping the entire zip archive as one string is probably not a
// good idea and might be slow. This is meant as a workaround, any proper
@ -45,7 +45,8 @@ class ScriptModuleDeserializer : public BlobDeserializerBase {
std::stringstream ss;
ss << serialized;
ss.seekg(0);
*blob->GetMutable<Module>() = torch::jit::load(ss);
blob->GetMutable<std::unique_ptr<Module>>()->reset(
new Module(torch::jit::load(ss)));
}
};
@ -62,7 +63,8 @@ class ScriptModuleLoadOp final : public Operator<CPUContext> {
std::stringstream ss;
ss << moduleBinary;
ss.seekg(0);
*OperatorBase::Output<Module>(0) = torch::jit::load(ss);
OperatorBase::Output<std::unique_ptr<Module>>(0)->reset(
new Module(torch::jit::load(ss)));
return true;
}
};
@ -91,8 +93,9 @@ class ScriptModuleOp final : public Operator<Context> {
// want their gradients to be tracked in this operator.
torch::NoGradGuard guard;
const auto& module = OperatorBase::Input<Module>(0);
Method method = module.get_method(method_name_);
const auto& module = OperatorBase::Input<std::unique_ptr<Module>>(0);
CAFFE_ENFORCE(module);
Method method = module->get_method(method_name_);
// Assume all inputs are tensor for now
std::vector<IValue> inputs;
const int num_inputs = InputSize();
@ -123,9 +126,11 @@ class ScriptModuleOp final : public Operator<Context> {
};
} // namespace
CAFFE_KNOWN_TYPE(Module);
CAFFE_KNOWN_TYPE(std::unique_ptr<Module>);
REGISTER_BLOB_SERIALIZER((TypeMeta::Id<Module>()), ScriptModuleSerializer);
REGISTER_BLOB_SERIALIZER(
(TypeMeta::Id<std::unique_ptr<Module>>()),
ScriptModuleSerializer);
// NB: the first argument to REGISTER_BLOB_DESERIALIZER macro doesn't really
// need to be a real type, it just get converted to string
REGISTER_BLOB_DESERIALIZER(

View File

@ -93,12 +93,12 @@ REGISTER_BLOB_FETCHER((TypeMeta::Id<string>()), StringFetcher);
class ScriptModuleFetcher : public BlobFetcherBase {
public:
pybind11::object Fetch(const Blob& blob) override {
return py::cast(blob.Get<torch::jit::script::Module>());
return py::cast(*blob.Get<std::unique_ptr<torch::jit::script::Module>>());
}
};
REGISTER_BLOB_FETCHER(
(TypeMeta::Id<torch::jit::script::Module>()),
(TypeMeta::Id<std::unique_ptr<torch::jit::script::Module>>()),
caffe2::python::ScriptModuleFetcher);
#endif
@ -247,7 +247,7 @@ bool feedBlob(
}
#ifdef FBCODE_CAFFE2
if (auto module = torch::jit::script::as_module(arg)) {
*blob->GetMutable<torch::jit::script::Module>() = *module;
blob->GetMutable<std::unique_ptr<torch::jit::script::Module>>()->reset(new torch::jit::script::Module(*module));
return true;
}
#endif