mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add PyTorchPredictorContainer (#15899)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/15899 Add PyTorchPredictorContainer to support multiple jit script modules Reviewed By: pritamdamania87 Differential Revision: D13596139 fbshipit-source-id: 3ce0bdf2f4dbba7aa1d20e824d03e5ac98f5d887
This commit is contained in:
parent
1065e7cd24
commit
b329e03684
|
|
@ -3,14 +3,14 @@
|
|||
#include <fstream>
|
||||
#include <memory>
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include "c10/macros/Macros.h"
|
||||
#include "caffe2/serialize/istream_adapter.h"
|
||||
#include "caffe2/serialize/read_adapter_interface.h"
|
||||
|
||||
namespace caffe2 {
|
||||
namespace serialize {
|
||||
|
||||
class FileAdapter final : public ReadAdapterInterface {
|
||||
class CAFFE2_API FileAdapter final : public ReadAdapterInterface {
|
||||
public:
|
||||
C10_DISABLE_COPY_AND_ASSIGN(FileAdapter);
|
||||
explicit FileAdapter(const std::string& file_name);
|
||||
|
|
|
|||
|
|
@ -2,15 +2,14 @@
|
|||
|
||||
#include <istream>
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
|
||||
#include "c10/macros/Macros.h"
|
||||
#include "caffe2/serialize/read_adapter_interface.h"
|
||||
|
||||
namespace caffe2 {
|
||||
namespace serialize {
|
||||
|
||||
// this is a reader implemented by std::istream
|
||||
class IStreamAdapter final : public ReadAdapterInterface {
|
||||
class CAFFE2_API IStreamAdapter final : public ReadAdapterInterface {
|
||||
public:
|
||||
C10_DISABLE_COPY_AND_ASSIGN(IStreamAdapter);
|
||||
explicit IStreamAdapter(std::istream* istream);
|
||||
|
|
|
|||
|
|
@ -3,13 +3,15 @@
|
|||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
|
||||
#include "c10/macros/Macros.h"
|
||||
|
||||
namespace caffe2 {
|
||||
namespace serialize {
|
||||
|
||||
// this is the interface for the (file/stream/memory) reader in
|
||||
// PyTorchStreamReader. with this interface, we can extend the support
|
||||
// besides standard istream
|
||||
class ReadAdapterInterface {
|
||||
class CAFFE2_API ReadAdapterInterface {
|
||||
public:
|
||||
virtual size_t size() const = 0;
|
||||
virtual size_t read(uint64_t pos, void* buf, size_t n, const char* what = "")
|
||||
|
|
|
|||
|
|
@ -8,10 +8,13 @@
|
|||
#include <torch/csrc/jit/operator.h>
|
||||
#include <torch/csrc/utils/functional.h>
|
||||
|
||||
#include <caffe2/core/types.h>
|
||||
#include <caffe2/proto/caffe2_pb.h>
|
||||
#include <caffe2/proto/torch_pb.h>
|
||||
#include <caffe2/serialize/inline_container.h>
|
||||
#include "caffe2/core/common.h"
|
||||
#include "caffe2/core/types.h"
|
||||
#include "caffe2/proto/caffe2_pb.h"
|
||||
#include "caffe2/proto/torch_pb.h"
|
||||
#include "caffe2/serialize/file_adapter.h"
|
||||
#include "caffe2/serialize/inline_container.h"
|
||||
#include "caffe2/serialize/istream_adapter.h"
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
|
|
@ -23,6 +26,10 @@
|
|||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
using caffe2::serialize::ReadAdapterInterface;
|
||||
using caffe2::serialize::IStreamAdapter;
|
||||
using caffe2::serialize::FileAdapter;
|
||||
|
||||
namespace {
|
||||
|
||||
// this is a deserializer class which loads script modules from pt files. the
|
||||
|
|
@ -34,9 +41,8 @@ namespace {
|
|||
class ScriptModuleDeserializer final {
|
||||
public:
|
||||
ScriptModuleDeserializer(const std::string& filename);
|
||||
|
||||
ScriptModuleDeserializer(std::istream* is);
|
||||
|
||||
explicit ScriptModuleDeserializer(std::unique_ptr<ReadAdapterInterface> rai);
|
||||
void deserialize(
|
||||
ModuleLookup module_lookup,
|
||||
c10::optional<at::Device> device);
|
||||
|
|
@ -68,6 +74,9 @@ ScriptModuleDeserializer::ScriptModuleDeserializer(const std::string& filename)
|
|||
ScriptModuleDeserializer::ScriptModuleDeserializer(std::istream* is)
|
||||
: reader_(is) {}
|
||||
|
||||
ScriptModuleDeserializer::ScriptModuleDeserializer(std::unique_ptr<ReadAdapterInterface> rai)
|
||||
: reader_(std::move(rai)) {}
|
||||
|
||||
void ScriptModuleDeserializer::deserialize(
|
||||
ModuleLookup module_lookup,
|
||||
c10::optional<at::Device> device) {
|
||||
|
|
@ -229,9 +238,34 @@ void import_ir_module(
|
|||
deserializer.deserialize(module_lookup, device);
|
||||
}
|
||||
|
||||
void import_ir_module(
|
||||
ModuleLookup module_lookup,
|
||||
std::unique_ptr<ReadAdapterInterface> rai,
|
||||
c10::optional<at::Device> device) {
|
||||
ScriptModuleDeserializer deserializer(std::move(rai));
|
||||
deserializer.deserialize(module_lookup, device);
|
||||
}
|
||||
|
||||
std::shared_ptr<script::Module> load(
|
||||
std::istream& in,
|
||||
c10::optional<at::Device> device) {
|
||||
std::unique_ptr<IStreamAdapter> rai =
|
||||
caffe2::make_unique<IStreamAdapter>(&in);
|
||||
auto module = load(std::move(rai), device);
|
||||
return module;
|
||||
}
|
||||
|
||||
std::shared_ptr<script::Module> load(
|
||||
const std::string& filename,
|
||||
c10::optional<at::Device> device) {
|
||||
std::unique_ptr<FileAdapter> rai = caffe2::make_unique<FileAdapter>(filename);
|
||||
auto module = load(std::move(rai), device);
|
||||
return module;
|
||||
}
|
||||
|
||||
std::shared_ptr<script::Module> load(
|
||||
std::unique_ptr<ReadAdapterInterface> rai,
|
||||
c10::optional<c10::Device> device) {
|
||||
auto module = std::make_shared<script::Module>();
|
||||
|
||||
auto module_lookup = [&](const std::vector<std::string>& qualified_name) {
|
||||
|
|
@ -245,23 +279,11 @@ std::shared_ptr<script::Module> load(
|
|||
return curr;
|
||||
};
|
||||
|
||||
ScriptModuleDeserializer deserializer(&in);
|
||||
ScriptModuleDeserializer deserializer(std::move(rai));
|
||||
deserializer.deserialize(module_lookup, device);
|
||||
|
||||
return module;
|
||||
}
|
||||
|
||||
std::shared_ptr<script::Module> load(
|
||||
const std::string& filename,
|
||||
c10::optional<at::Device> device) {
|
||||
std::ifstream in(filename, std::ios_base::binary);
|
||||
|
||||
AT_CHECK(!in.fail(), "load: could not open file ", filename);
|
||||
|
||||
auto module = load(in, device);
|
||||
|
||||
return module;
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -5,6 +5,12 @@
|
|||
|
||||
#include <istream>
|
||||
|
||||
namespace caffe2 {
|
||||
namespace serialize {
|
||||
class ReadAdapterInterface;
|
||||
} // namespace serialize
|
||||
} // namespace caffe2
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
|
|
@ -21,6 +27,11 @@ TORCH_API void import_ir_module(
|
|||
std::istream& in,
|
||||
c10::optional<c10::Device> device = c10::nullopt);
|
||||
|
||||
TORCH_API void import_ir_module(
|
||||
ModuleLookup module_lookup,
|
||||
std::unique_ptr<caffe2::serialize::ReadAdapterInterface> rai,
|
||||
c10::optional<c10::Device> device = c10::nullopt);
|
||||
|
||||
/// Loads a serialized `script::Module` from the given `istream`.
|
||||
///
|
||||
/// The istream must contain a serialized `script::Module`, exported via
|
||||
|
|
@ -38,5 +49,15 @@ TORCH_API std::shared_ptr<script::Module> load(
|
|||
const std::string& filename,
|
||||
c10::optional<c10::Device> device = c10::nullopt);
|
||||
|
||||
/// Loads a serialized `script::Module` from the given `rai`.
|
||||
///
|
||||
/// The reader adapter, which is for customized input stream, must contain a
|
||||
/// serialized `script::Module`, exported either via `ScriptModule.save()` in
|
||||
/// Python or `torch::jit::ExportModule` in C++.
|
||||
TORCH_API std::shared_ptr<script::Module> load(
|
||||
std::unique_ptr<caffe2::serialize::ReadAdapterInterface> rai,
|
||||
c10::optional<c10::Device> device = c10::nullopt);
|
||||
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user