Add PyTorchPredictorContainer (#15899)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15899

Add PyTorchPredictorContainer to support multiple jit script modules

Reviewed By: pritamdamania87

Differential Revision: D13596139

fbshipit-source-id: 3ce0bdf2f4dbba7aa1d20e824d03e5ac98f5d887
This commit is contained in:
Lu Fang 2019-01-15 09:13:16 -08:00 committed by Facebook Github Bot
parent 1065e7cd24
commit b329e03684
5 changed files with 69 additions and 25 deletions

View File

@ -3,14 +3,14 @@
#include <fstream>
#include <memory>
#include <c10/macros/Macros.h>
#include "c10/macros/Macros.h"
#include "caffe2/serialize/istream_adapter.h"
#include "caffe2/serialize/read_adapter_interface.h"
namespace caffe2 {
namespace serialize {
class FileAdapter final : public ReadAdapterInterface {
class CAFFE2_API FileAdapter final : public ReadAdapterInterface {
public:
C10_DISABLE_COPY_AND_ASSIGN(FileAdapter);
explicit FileAdapter(const std::string& file_name);

View File

@ -2,15 +2,14 @@
#include <istream>
#include <c10/macros/Macros.h>
#include "c10/macros/Macros.h"
#include "caffe2/serialize/read_adapter_interface.h"
namespace caffe2 {
namespace serialize {
// this is a reader implemented by std::istream
class IStreamAdapter final : public ReadAdapterInterface {
class CAFFE2_API IStreamAdapter final : public ReadAdapterInterface {
public:
C10_DISABLE_COPY_AND_ASSIGN(IStreamAdapter);
explicit IStreamAdapter(std::istream* istream);

View File

@ -3,13 +3,15 @@
#include <cstddef>
#include <cstdint>
#include "c10/macros/Macros.h"
namespace caffe2 {
namespace serialize {
// this is the interface for the (file/stream/memory) reader in
// PyTorchStreamReader. with this interface, we can extend the support
// besides standard istream
class ReadAdapterInterface {
class CAFFE2_API ReadAdapterInterface {
public:
virtual size_t size() const = 0;
virtual size_t read(uint64_t pos, void* buf, size_t n, const char* what = "")

View File

@ -8,10 +8,13 @@
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/utils/functional.h>
#include <caffe2/core/types.h>
#include <caffe2/proto/caffe2_pb.h>
#include <caffe2/proto/torch_pb.h>
#include <caffe2/serialize/inline_container.h>
#include "caffe2/core/common.h"
#include "caffe2/core/types.h"
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/proto/torch_pb.h"
#include "caffe2/serialize/file_adapter.h"
#include "caffe2/serialize/inline_container.h"
#include "caffe2/serialize/istream_adapter.h"
#include <ATen/ATen.h>
@ -23,6 +26,10 @@
namespace torch {
namespace jit {
using caffe2::serialize::ReadAdapterInterface;
using caffe2::serialize::IStreamAdapter;
using caffe2::serialize::FileAdapter;
namespace {
// this is a deserializer class which loads script modules from pt files. the
@ -34,9 +41,8 @@ namespace {
class ScriptModuleDeserializer final {
public:
ScriptModuleDeserializer(const std::string& filename);
ScriptModuleDeserializer(std::istream* is);
explicit ScriptModuleDeserializer(std::unique_ptr<ReadAdapterInterface> rai);
void deserialize(
ModuleLookup module_lookup,
c10::optional<at::Device> device);
@ -68,6 +74,9 @@ ScriptModuleDeserializer::ScriptModuleDeserializer(const std::string& filename)
ScriptModuleDeserializer::ScriptModuleDeserializer(std::istream* is)
: reader_(is) {}
ScriptModuleDeserializer::ScriptModuleDeserializer(std::unique_ptr<ReadAdapterInterface> rai)
: reader_(std::move(rai)) {}
void ScriptModuleDeserializer::deserialize(
ModuleLookup module_lookup,
c10::optional<at::Device> device) {
@ -229,9 +238,34 @@ void import_ir_module(
deserializer.deserialize(module_lookup, device);
}
void import_ir_module(
ModuleLookup module_lookup,
std::unique_ptr<ReadAdapterInterface> rai,
c10::optional<at::Device> device) {
ScriptModuleDeserializer deserializer(std::move(rai));
deserializer.deserialize(module_lookup, device);
}
std::shared_ptr<script::Module> load(
std::istream& in,
c10::optional<at::Device> device) {
std::unique_ptr<IStreamAdapter> rai =
caffe2::make_unique<IStreamAdapter>(&in);
auto module = load(std::move(rai), device);
return module;
}
std::shared_ptr<script::Module> load(
const std::string& filename,
c10::optional<at::Device> device) {
std::unique_ptr<FileAdapter> rai = caffe2::make_unique<FileAdapter>(filename);
auto module = load(std::move(rai), device);
return module;
}
std::shared_ptr<script::Module> load(
std::unique_ptr<ReadAdapterInterface> rai,
c10::optional<c10::Device> device) {
auto module = std::make_shared<script::Module>();
auto module_lookup = [&](const std::vector<std::string>& qualified_name) {
@ -245,23 +279,11 @@ std::shared_ptr<script::Module> load(
return curr;
};
ScriptModuleDeserializer deserializer(&in);
ScriptModuleDeserializer deserializer(std::move(rai));
deserializer.deserialize(module_lookup, device);
return module;
}
std::shared_ptr<script::Module> load(
const std::string& filename,
c10::optional<at::Device> device) {
std::ifstream in(filename, std::ios_base::binary);
AT_CHECK(!in.fail(), "load: could not open file ", filename);
auto module = load(in, device);
return module;
}
} // namespace jit
} // namespace torch

View File

@ -5,6 +5,12 @@
#include <istream>
namespace caffe2 {
namespace serialize {
class ReadAdapterInterface;
} // namespace serialize
} // namespace caffe2
namespace torch {
namespace jit {
@ -21,6 +27,11 @@ TORCH_API void import_ir_module(
std::istream& in,
c10::optional<c10::Device> device = c10::nullopt);
TORCH_API void import_ir_module(
ModuleLookup module_lookup,
std::unique_ptr<caffe2::serialize::ReadAdapterInterface> rai,
c10::optional<c10::Device> device = c10::nullopt);
/// Loads a serialized `script::Module` from the given `istream`.
///
/// The istream must contain a serialized `script::Module`, exported via
@ -38,5 +49,15 @@ TORCH_API std::shared_ptr<script::Module> load(
const std::string& filename,
c10::optional<c10::Device> device = c10::nullopt);
/// Loads a serialized `script::Module` from the given `rai`.
///
/// The reader adapter, which is for customized input stream, must contain a
/// serialized `script::Module`, exported either via `ScriptModule.save()` in
/// Python or `torch::jit::ExportModule` in C++.
TORCH_API std::shared_ptr<script::Module> load(
std::unique_ptr<caffe2::serialize::ReadAdapterInterface> rai,
c10::optional<c10::Device> device = c10::nullopt);
} // namespace jit
} // namespace torch