refactor _save_parameters to _save_data (#43162)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/43162

Test Plan: Imported from OSS

Reviewed By: iseeyuan

Differential Revision: D23175286

Pulled By: ann-ss

fbshipit-source-id: 6f930b98c367242fd4efbf51cb1d09995f7c4b40
This commit is contained in:
Ann Shan 2020-08-18 14:53:51 -07:00 committed by Facebook GitHub Bot
parent 888ae1b3d8
commit 2e6e295ecc
7 changed files with 30 additions and 28 deletions

View File

@ -442,7 +442,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/jit/mobile/module.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/observer.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/interpreter.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/export.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/export_data.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/optim/sgd.cpp
)
list(APPEND TORCH_SRCS ${MOBILE_SRCS})

View File

@ -2,7 +2,7 @@
#include <test/cpp/jit/test_base.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/mobile/export.h>
#include <torch/csrc/jit/mobile/export_data.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/import_data.h>
#include <torch/csrc/jit/mobile/module.h>
@ -117,8 +117,8 @@ void testMobileSaveLoadData() {
m._save_for_mobile(ss);
mobile::Module bc = _load_for_mobile(ss);
_save_parameters(bc, ss_data);
auto mobile_params = _load_parameters(ss_data);
mobile::_save_data(bc, ss_data);
auto mobile_params = mobile::_load_data(ss_data).named_parameters();
AT_ASSERT(full_params.size() == mobile_params.size());
for (const auto& e : full_params) {
AT_ASSERT(e.value.item<int>() == mobile_params[e.name].item<int>());

View File

@ -312,7 +312,7 @@ libtorch_extra_sources = libtorch_core_jit_sources + [
"torch/csrc/autograd/VariableTypeManual.cpp",
"torch/csrc/jit/api/module_save.cpp",
"torch/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp",
"torch/csrc/jit/mobile/export.cpp",
"torch/csrc/jit/mobile/export_data.cpp",
"torch/csrc/jit/mobile/function.cpp",
"torch/csrc/jit/mobile/import.cpp",
"torch/csrc/jit/mobile/import_data.cpp",

View File

@ -1,4 +1,4 @@
#include <torch/csrc/jit/mobile/export.h>
#include <torch/csrc/jit/mobile/export_data.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/runtime/instruction.h>
@ -66,7 +66,7 @@ class ScriptModuleSerializer {
} // namespace
void _save_parameters(const Module& module, std::ostream& out) {
void _save_data(const Module& module, std::ostream& out) {
ScriptModuleSerializer serializer(
[&](const void* buf, size_t nbytes) -> size_t {
out.write(static_cast<const char*>(buf), nbytes);
@ -75,7 +75,7 @@ void _save_parameters(const Module& module, std::ostream& out) {
serializer.serialize(module._ivalue());
}
void _save_parameters(const Module& module, const std::string& filename) {
void _save_data(const Module& module, const std::string& filename) {
ScriptModuleSerializer serializer(filename);
serializer.serialize(module._ivalue());
}

View File

@ -6,11 +6,9 @@ namespace torch {
namespace jit {
namespace mobile {
TORCH_API void _save_parameters(const Module& module, std::ostream& out);
TORCH_API void _save_data(const Module& module, std::ostream& out);
TORCH_API void _save_parameters(
const Module& module,
const std::string& filename);
TORCH_API void _save_data(const Module& module, const std::string& filename);
} // namespace mobile
} // namespace jit

View File

@ -30,7 +30,7 @@ namespace {
class BytecodeDeserializer final {
public:
explicit BytecodeDeserializer(std::unique_ptr<PyTorchStreamReader> reader);
mobile::Module deserialize(c10::optional<at::Device> device);
c10::IValue deserialize(c10::optional<at::Device> device);
private:
c10::IValue readArchive(
@ -47,12 +47,11 @@ BytecodeDeserializer::BytecodeDeserializer(
: compilation_unit_(std::make_shared<CompilationUnit>()),
reader_(std::move(reader)) {}
mobile::Module BytecodeDeserializer::deserialize(
c10::IValue BytecodeDeserializer::deserialize(
c10::optional<at::Device> device) {
auto mcu = std::make_shared<mobile::CompilationUnit>();
auto temp = readArchive("data", mcu, std::move(device));
return mobile::Module(temp.toObject(), mcu);
return readArchive("data", mcu, std::move(device));
}
c10::IValue BytecodeDeserializer::readArchive(
@ -154,21 +153,21 @@ c10::IValue BytecodeDeserializer::readArchive(
} // namespace
std::map<std::string, at::Tensor> _load_parameters(
std::istream& in,
c10::optional<at::Device> device) {
namespace mobile {
mobile::Module _load_data(std::istream& in, c10::optional<at::Device> device) {
std::unique_ptr<IStreamAdapter> rai = std::make_unique<IStreamAdapter>(&in);
return _load_parameters(std::move(rai), std::move(device));
return _load_data(std::move(rai), std::move(device));
}
std::map<std::string, at::Tensor> _load_parameters(
mobile::Module _load_data(
const std::string& filename,
c10::optional<at::Device> device) {
std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename);
return _load_parameters(std::move(rai), std::move(device));
return _load_data(std::move(rai), std::move(device));
}
std::map<std::string, at::Tensor> _load_parameters(
mobile::Module _load_data(
std::unique_ptr<ReadAdapterInterface> rai,
c10::optional<c10::Device> device) {
auto observer = torch::observerConfig().getModuleObserver();
@ -178,12 +177,14 @@ std::map<std::string, at::Tensor> _load_parameters(
try {
auto reader = torch::make_unique<PyTorchStreamReader>(std::move(rai));
BytecodeDeserializer deserializer(std::move(reader));
mobile::Module result = deserializer.deserialize(std::move(device));
auto mcu = std::make_shared<mobile::CompilationUnit>();
mobile::Module result = mobile::Module(
deserializer.deserialize(std::move(device)).toObject(), mcu);
std::string name = result.name();
if (observer) {
observer->onExitLoadModel(name);
}
return result.named_parameters();
return result;
} catch (const std::exception& ex) {
if (observer) {
observer->onFailLoadModel(
@ -198,5 +199,6 @@ std::map<std::string, at::Tensor> _load_parameters(
}
}
} // namespace mobile
} // namespace jit
} // namespace torch

View File

@ -13,16 +13,18 @@ using caffe2::serialize::FileAdapter;
using caffe2::serialize::IStreamAdapter;
using caffe2::serialize::ReadAdapterInterface;
TORCH_API std::map<std::string, at::Tensor> _load_parameters(
namespace mobile {
TORCH_API mobile::Module _load_data(
std::istream& in,
c10::optional<at::Device> device = c10::nullopt);
TORCH_API std::map<std::string, at::Tensor> _load_parameters(
TORCH_API mobile::Module _load_data(
const std::string& filename,
c10::optional<at::Device> device = c10::nullopt);
TORCH_API std::map<std::string, at::Tensor> _load_parameters(
TORCH_API mobile::Module _load_data(
std::unique_ptr<ReadAdapterInterface> rai,
c10::optional<c10::Device> device = c10::nullopt);
} // namespace mobile
} // namespace jit
} // namespace torch