mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
refactor _save_parameters to _save_data (#43162)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/43162 Test Plan: Imported from OSS Reviewed By: iseeyuan Differential Revision: D23175286 Pulled By: ann-ss fbshipit-source-id: 6f930b98c367242fd4efbf51cb1d09995f7c4b40
This commit is contained in:
parent
888ae1b3d8
commit
2e6e295ecc
|
|
@ -442,7 +442,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
|
|||
${TORCH_SRC_DIR}/csrc/jit/mobile/module.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/mobile/observer.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/mobile/interpreter.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/mobile/export.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/mobile/export_data.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/mobile/optim/sgd.cpp
|
||||
)
|
||||
list(APPEND TORCH_SRCS ${MOBILE_SRCS})
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
#include <test/cpp/jit/test_base.h>
|
||||
#include <torch/csrc/autograd/generated/variable_factories.h>
|
||||
#include <torch/csrc/jit/api/module.h>
|
||||
#include <torch/csrc/jit/mobile/export.h>
|
||||
#include <torch/csrc/jit/mobile/export_data.h>
|
||||
#include <torch/csrc/jit/mobile/import.h>
|
||||
#include <torch/csrc/jit/mobile/import_data.h>
|
||||
#include <torch/csrc/jit/mobile/module.h>
|
||||
|
|
@ -117,8 +117,8 @@ void testMobileSaveLoadData() {
|
|||
m._save_for_mobile(ss);
|
||||
mobile::Module bc = _load_for_mobile(ss);
|
||||
|
||||
_save_parameters(bc, ss_data);
|
||||
auto mobile_params = _load_parameters(ss_data);
|
||||
mobile::_save_data(bc, ss_data);
|
||||
auto mobile_params = mobile::_load_data(ss_data).named_parameters();
|
||||
AT_ASSERT(full_params.size() == mobile_params.size());
|
||||
for (const auto& e : full_params) {
|
||||
AT_ASSERT(e.value.item<int>() == mobile_params[e.name].item<int>());
|
||||
|
|
|
|||
|
|
@ -312,7 +312,7 @@ libtorch_extra_sources = libtorch_core_jit_sources + [
|
|||
"torch/csrc/autograd/VariableTypeManual.cpp",
|
||||
"torch/csrc/jit/api/module_save.cpp",
|
||||
"torch/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp",
|
||||
"torch/csrc/jit/mobile/export.cpp",
|
||||
"torch/csrc/jit/mobile/export_data.cpp",
|
||||
"torch/csrc/jit/mobile/function.cpp",
|
||||
"torch/csrc/jit/mobile/import.cpp",
|
||||
"torch/csrc/jit/mobile/import_data.cpp",
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
#include <torch/csrc/jit/mobile/export.h>
|
||||
#include <torch/csrc/jit/mobile/export_data.h>
|
||||
|
||||
#include <torch/csrc/jit/mobile/module.h>
|
||||
#include <torch/csrc/jit/runtime/instruction.h>
|
||||
|
|
@ -66,7 +66,7 @@ class ScriptModuleSerializer {
|
|||
|
||||
} // namespace
|
||||
|
||||
void _save_parameters(const Module& module, std::ostream& out) {
|
||||
void _save_data(const Module& module, std::ostream& out) {
|
||||
ScriptModuleSerializer serializer(
|
||||
[&](const void* buf, size_t nbytes) -> size_t {
|
||||
out.write(static_cast<const char*>(buf), nbytes);
|
||||
|
|
@ -75,7 +75,7 @@ void _save_parameters(const Module& module, std::ostream& out) {
|
|||
serializer.serialize(module._ivalue());
|
||||
}
|
||||
|
||||
void _save_parameters(const Module& module, const std::string& filename) {
|
||||
void _save_data(const Module& module, const std::string& filename) {
|
||||
ScriptModuleSerializer serializer(filename);
|
||||
serializer.serialize(module._ivalue());
|
||||
}
|
||||
|
|
@ -6,11 +6,9 @@ namespace torch {
|
|||
namespace jit {
|
||||
namespace mobile {
|
||||
|
||||
TORCH_API void _save_parameters(const Module& module, std::ostream& out);
|
||||
TORCH_API void _save_data(const Module& module, std::ostream& out);
|
||||
|
||||
TORCH_API void _save_parameters(
|
||||
const Module& module,
|
||||
const std::string& filename);
|
||||
TORCH_API void _save_data(const Module& module, const std::string& filename);
|
||||
|
||||
} // namespace mobile
|
||||
} // namespace jit
|
||||
|
|
@ -30,7 +30,7 @@ namespace {
|
|||
class BytecodeDeserializer final {
|
||||
public:
|
||||
explicit BytecodeDeserializer(std::unique_ptr<PyTorchStreamReader> reader);
|
||||
mobile::Module deserialize(c10::optional<at::Device> device);
|
||||
c10::IValue deserialize(c10::optional<at::Device> device);
|
||||
|
||||
private:
|
||||
c10::IValue readArchive(
|
||||
|
|
@ -47,12 +47,11 @@ BytecodeDeserializer::BytecodeDeserializer(
|
|||
: compilation_unit_(std::make_shared<CompilationUnit>()),
|
||||
reader_(std::move(reader)) {}
|
||||
|
||||
mobile::Module BytecodeDeserializer::deserialize(
|
||||
c10::IValue BytecodeDeserializer::deserialize(
|
||||
c10::optional<at::Device> device) {
|
||||
auto mcu = std::make_shared<mobile::CompilationUnit>();
|
||||
|
||||
auto temp = readArchive("data", mcu, std::move(device));
|
||||
return mobile::Module(temp.toObject(), mcu);
|
||||
return readArchive("data", mcu, std::move(device));
|
||||
}
|
||||
|
||||
c10::IValue BytecodeDeserializer::readArchive(
|
||||
|
|
@ -154,21 +153,21 @@ c10::IValue BytecodeDeserializer::readArchive(
|
|||
|
||||
} // namespace
|
||||
|
||||
std::map<std::string, at::Tensor> _load_parameters(
|
||||
std::istream& in,
|
||||
c10::optional<at::Device> device) {
|
||||
namespace mobile {
|
||||
|
||||
mobile::Module _load_data(std::istream& in, c10::optional<at::Device> device) {
|
||||
std::unique_ptr<IStreamAdapter> rai = std::make_unique<IStreamAdapter>(&in);
|
||||
return _load_parameters(std::move(rai), std::move(device));
|
||||
return _load_data(std::move(rai), std::move(device));
|
||||
}
|
||||
|
||||
std::map<std::string, at::Tensor> _load_parameters(
|
||||
mobile::Module _load_data(
|
||||
const std::string& filename,
|
||||
c10::optional<at::Device> device) {
|
||||
std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename);
|
||||
return _load_parameters(std::move(rai), std::move(device));
|
||||
return _load_data(std::move(rai), std::move(device));
|
||||
}
|
||||
|
||||
std::map<std::string, at::Tensor> _load_parameters(
|
||||
mobile::Module _load_data(
|
||||
std::unique_ptr<ReadAdapterInterface> rai,
|
||||
c10::optional<c10::Device> device) {
|
||||
auto observer = torch::observerConfig().getModuleObserver();
|
||||
|
|
@ -178,12 +177,14 @@ std::map<std::string, at::Tensor> _load_parameters(
|
|||
try {
|
||||
auto reader = torch::make_unique<PyTorchStreamReader>(std::move(rai));
|
||||
BytecodeDeserializer deserializer(std::move(reader));
|
||||
mobile::Module result = deserializer.deserialize(std::move(device));
|
||||
auto mcu = std::make_shared<mobile::CompilationUnit>();
|
||||
mobile::Module result = mobile::Module(
|
||||
deserializer.deserialize(std::move(device)).toObject(), mcu);
|
||||
std::string name = result.name();
|
||||
if (observer) {
|
||||
observer->onExitLoadModel(name);
|
||||
}
|
||||
return result.named_parameters();
|
||||
return result;
|
||||
} catch (const std::exception& ex) {
|
||||
if (observer) {
|
||||
observer->onFailLoadModel(
|
||||
|
|
@ -198,5 +199,6 @@ std::map<std::string, at::Tensor> _load_parameters(
|
|||
}
|
||||
}
|
||||
|
||||
} // namespace mobile
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -13,16 +13,18 @@ using caffe2::serialize::FileAdapter;
|
|||
using caffe2::serialize::IStreamAdapter;
|
||||
using caffe2::serialize::ReadAdapterInterface;
|
||||
|
||||
TORCH_API std::map<std::string, at::Tensor> _load_parameters(
|
||||
namespace mobile {
|
||||
TORCH_API mobile::Module _load_data(
|
||||
std::istream& in,
|
||||
c10::optional<at::Device> device = c10::nullopt);
|
||||
|
||||
TORCH_API std::map<std::string, at::Tensor> _load_parameters(
|
||||
TORCH_API mobile::Module _load_data(
|
||||
const std::string& filename,
|
||||
c10::optional<at::Device> device = c10::nullopt);
|
||||
|
||||
TORCH_API std::map<std::string, at::Tensor> _load_parameters(
|
||||
TORCH_API mobile::Module _load_data(
|
||||
std::unique_ptr<ReadAdapterInterface> rai,
|
||||
c10::optional<c10::Device> device = c10::nullopt);
|
||||
} // namespace mobile
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user