mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Switch ScriptModuleOp to use a unique_ptr
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/29856 Test Plan: waitforsadcastle Reviewed By: dzhulgakov Differential Revision: D18516553 fbshipit-source-id: d1e2d49ec613d07b21cd30bd777fbd300032cba1
This commit is contained in:
parent
902c1f9ef1
commit
7a6c3b36a1
|
|
@ -18,10 +18,10 @@ class ScriptModuleSerializer : public BlobSerializerBase {
|
|||
TypeMeta typeMeta,
|
||||
const string& name,
|
||||
SerializationAcceptor acceptor) override {
|
||||
CAFFE_ENFORCE(typeMeta.Match<Module>());
|
||||
CAFFE_ENFORCE(typeMeta.Match<std::unique_ptr<Module>>());
|
||||
|
||||
std::stringstream ss;
|
||||
(*static_cast<const Module*>(pointer)).save(ss);
|
||||
(*static_cast<const std::unique_ptr<Module>*>(pointer))->save(ss);
|
||||
|
||||
// NB: wrapping the entire zip archive as one string is probably not a
|
||||
// good idea and might be slow. This is meant as a workaround, any proper
|
||||
|
|
@ -45,7 +45,8 @@ class ScriptModuleDeserializer : public BlobDeserializerBase {
|
|||
std::stringstream ss;
|
||||
ss << serialized;
|
||||
ss.seekg(0);
|
||||
*blob->GetMutable<Module>() = torch::jit::load(ss);
|
||||
blob->GetMutable<std::unique_ptr<Module>>()->reset(
|
||||
new Module(torch::jit::load(ss)));
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -62,7 +63,8 @@ class ScriptModuleLoadOp final : public Operator<CPUContext> {
|
|||
std::stringstream ss;
|
||||
ss << moduleBinary;
|
||||
ss.seekg(0);
|
||||
*OperatorBase::Output<Module>(0) = torch::jit::load(ss);
|
||||
OperatorBase::Output<std::unique_ptr<Module>>(0)->reset(
|
||||
new Module(torch::jit::load(ss)));
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
|
@ -91,8 +93,9 @@ class ScriptModuleOp final : public Operator<Context> {
|
|||
// want their gradients to be tracked in this operator.
|
||||
torch::NoGradGuard guard;
|
||||
|
||||
const auto& module = OperatorBase::Input<Module>(0);
|
||||
Method method = module.get_method(method_name_);
|
||||
const auto& module = OperatorBase::Input<std::unique_ptr<Module>>(0);
|
||||
CAFFE_ENFORCE(module);
|
||||
Method method = module->get_method(method_name_);
|
||||
// Assume all inputs are tensor for now
|
||||
std::vector<IValue> inputs;
|
||||
const int num_inputs = InputSize();
|
||||
|
|
@ -123,9 +126,11 @@ class ScriptModuleOp final : public Operator<Context> {
|
|||
};
|
||||
} // namespace
|
||||
|
||||
CAFFE_KNOWN_TYPE(Module);
|
||||
CAFFE_KNOWN_TYPE(std::unique_ptr<Module>);
|
||||
|
||||
REGISTER_BLOB_SERIALIZER((TypeMeta::Id<Module>()), ScriptModuleSerializer);
|
||||
REGISTER_BLOB_SERIALIZER(
|
||||
(TypeMeta::Id<std::unique_ptr<Module>>()),
|
||||
ScriptModuleSerializer);
|
||||
// NB: the first argument to REGISTER_BLOB_DESERIALIZER macro doesn't really
|
||||
// need to be a real type, it just get converted to string
|
||||
REGISTER_BLOB_DESERIALIZER(
|
||||
|
|
|
|||
|
|
@ -93,12 +93,12 @@ REGISTER_BLOB_FETCHER((TypeMeta::Id<string>()), StringFetcher);
|
|||
class ScriptModuleFetcher : public BlobFetcherBase {
|
||||
public:
|
||||
pybind11::object Fetch(const Blob& blob) override {
|
||||
return py::cast(blob.Get<torch::jit::script::Module>());
|
||||
return py::cast(*blob.Get<std::unique_ptr<torch::jit::script::Module>>());
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_BLOB_FETCHER(
|
||||
(TypeMeta::Id<torch::jit::script::Module>()),
|
||||
(TypeMeta::Id<std::unique_ptr<torch::jit::script::Module>>()),
|
||||
caffe2::python::ScriptModuleFetcher);
|
||||
#endif
|
||||
|
||||
|
|
@ -247,7 +247,7 @@ bool feedBlob(
|
|||
}
|
||||
#ifdef FBCODE_CAFFE2
|
||||
if (auto module = torch::jit::script::as_module(arg)) {
|
||||
*blob->GetMutable<torch::jit::script::Module>() = *module;
|
||||
blob->GetMutable<std::unique_ptr<torch::jit::script::Module>>()->reset(new torch::jit::script::Module(*module));
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user