[BE][flatbuffer] Remove code duplications and refactor (#79184)

Summary:
Remove code dup in import.cpp / export_modules.cpp such that
1. Only one copy of switching logic (detect flatbuffer / is_flatbuffer);
2. Move detection of includeness of flatbuffer to runtime (so no more macros)

This also reverts the dependency of import.cpp -> flatbuffer_loader.cpp to flatbuffer_loader.cpp -> import.cpp.

Differential Revision: D36926217

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79184
Approved by: https://github.com/zhxchen17
This commit is contained in:
Han Qi (qihqi) 2022-06-20 16:37:38 +00:00 committed by PyTorch MergeBot
parent 7de231a813
commit fed12ff680
24 changed files with 617 additions and 560 deletions

View File

@ -1645,7 +1645,8 @@ cc_library(
],
)) + libtorch_core_sources + libtorch_distributed_sources + torch_cpp_srcs + libtorch_extra_sources + jit_core_sources + lazy_tensor_ts_sources + GENERATED_AUTOGRAD_CPP + [
"torch/csrc/jit/serialization/flatbuffer_serializer.cpp",
"torch/csrc/jit/mobile/flatbuffer_loader.cpp"
"torch/csrc/jit/mobile/flatbuffer_loader.cpp",
"torch/csrc/jit/serialization/flatbuffer_serializer_jit.cpp",
],
copts = TORCH_COPTS,
defines = [

View File

@ -0,0 +1,32 @@
#pragma once
#include <cstring>
#include <caffe2/serialize/read_adapter_interface.h>
namespace caffe2 {
namespace serialize {
class MemoryReadAdapter final : public caffe2::serialize::ReadAdapterInterface {
public:
explicit MemoryReadAdapter(const void* data, off_t size)
: data_(data), size_(size) {}
size_t size() const override {
return size_;
}
size_t read(uint64_t pos, void* buf, size_t n, const char* what = "")
const override {
(void) what;
memcpy(buf, (int8_t*)(data_) + pos, n);
return n;
}
private:
const void* data_;
off_t size_;
};
} // namespace serialize
} // namespace caffe2

View File

@ -174,7 +174,7 @@ TEST(FlatbufferTest, MethodInvocation) { // NOLINT (use =delete in gtest)
}
}
#if defined(ENABLE_FLATBUFFER) && !defined(FB_XPLAT_BUILD)
#if !defined(FB_XPLAT_BUILD)
TEST(FlatbufferTest, FlatbufferBackPortTest) {
Module m("m");
m.define(R"(
@ -188,7 +188,7 @@ TEST(FlatbufferTest, FlatbufferBackPortTest) {
bool backPortSuccess = _backport_for_mobile(ss, oss, 5);
ASSERT_TRUE(backPortSuccess);
}
#endif // defined(ENABLE_FLATBUFFER) && !defined(FB_XPLAT_BUILD)
#endif // !defined(FB_XPLAT_BUILD)
TEST(FlatbufferTest, ExtraFiles) {
const auto script = R"JIT(
@ -207,7 +207,6 @@ TEST(FlatbufferTest, ExtraFiles) {
extra_files["mobile_info.json"] = "{\"key\": 23}";
std::unordered_map<std::string, std::string> loaded_extra_files;
#if defined ENABLE_FLATBUFFER
std::stringstream ss;
module->_save_for_mobile(ss, extra_files, true, /*use_flatbuffer=*/true);
@ -219,17 +218,6 @@ TEST(FlatbufferTest, ExtraFiles) {
// load it twice using the same stream
auto mobile_module2 = _load_for_mobile(ss, c10::nullopt, loaded_extra_files);
#else
CompilationOptions options;
mobile::Module bc = jitModuleToMobile(*module, options);
auto buff = save_mobile_module_to_bytes(bc, extra_files);
loaded_extra_files["metadata.json"] = "";
auto* flatbuffer_module =
mobile::serialization::GetMutableModule(buff.data());
parseExtraFiles(flatbuffer_module, loaded_extra_files);
#endif
ASSERT_EQ(loaded_extra_files["metadata.json"], "abc");
ASSERT_EQ(loaded_extra_files["mobile_info.json"], "{\"key\": 23}");
@ -1283,7 +1271,6 @@ Module jitModuleFromBuffer(void* data) {
mobilem._ivalue(), files, constants, 8);
}
#if defined(ENABLE_FLATBUFFER)
TEST(TestSourceFlatbuffer, UpsampleNearest2d) {
Module m("m");
m.define(R"(
@ -1375,7 +1362,6 @@ TEST(TestSourceFlatbuffer,
AT_ASSERT(resd == refd);
}
}
#endif
#if !defined FB_XPLAT_BUILD
// The following test run in fbcode only

View File

@ -635,7 +635,7 @@ void backportAllVersionCheck(
std::vector<IValue>& expect_result_list,
const uint64_t expect_from_version) {
auto from_version = _get_model_bytecode_version(test_model_file_stream);
AT_ASSERT(from_version == expect_from_version);
EXPECT_EQ(from_version, expect_from_version);
AT_ASSERT(from_version > 0);
// Backport script_module_v5.ptl to an older version
@ -717,15 +717,11 @@ TEST(LiteInterpreterTest, BackPortByteCodeModelAllVersions) {
torch::jit::Module module_freeze = freeze(module);
std::stringstream input_model_stream;
#if defined(ENABLE_FLATBUFFER)
module_freeze._save_for_mobile(
input_model_stream,
/*extra_files=*/{},
/*save_mobile_debug_info=*/false,
/*use_flatbuffer=*/true);
#else
module_freeze._save_for_mobile(input_model_stream);
#endif
std::vector<IValue> input_data =
std::vector<IValue>({torch::ones({1, 1, 28, 28})});
std::vector<IValue> expect_result_list;
@ -748,7 +744,7 @@ TEST(LiteInterpreterTest, BackPortByteCodeModelAllVersions) {
input_model_stream,
input_data,
expect_result_list,
caffe2::serialize::kProducedBytecodeVersion);
9); // flatbuffer starts at 9
}
#endif // !defined(FB_XPLAT_BUILD)

View File

@ -10,6 +10,7 @@
#include <torch/csrc/jit/mobile/train/optim/sgd.h>
#include <torch/csrc/jit/mobile/train/random.h>
#include <torch/csrc/jit/mobile/train/sequential.h>
#include <torch/csrc/jit/serialization/flatbuffer_serializer_jit.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/data/dataloader.h>
#include <torch/torch.h>
@ -172,9 +173,9 @@ TEST(MobileTest, SaveParametersDefaultsToZip) {
EXPECT_EQ(ss_data.str()[3], '\x04');
}
#if defined(ENABLE_FLATBUFFER)
TEST(MobileTest, SaveParametersCanUseFlatbuffer) {
// Save some empty parameters using flatbuffer.
register_flatbuffer_all();
std::map<std::string, at::Tensor> empty_parameters;
std::stringstream ss_data;
_save_parameters(empty_parameters, ss_data, /*use_flatbuffer=*/true);
@ -188,34 +189,10 @@ TEST(MobileTest, SaveParametersCanUseFlatbuffer) {
EXPECT_EQ(ss_data.str()[6], 'M');
EXPECT_EQ(ss_data.str()[7], 'F');
}
#else // !defined(ENABLE_FLATBUFFER)
TEST(MobileTest, SaveParametersThrowsWithoutFlatbufferSupport) {
// Some empty parameters to try saving.
std::map<std::string, at::Tensor> empty_parameters;
std::stringstream ss_data;
// Save using flatbuffers should fail when support isn't compiled in. Make
// sure we get the exception that explicitly mentions the lack of flatbuffer
// support.
try {
_save_parameters(empty_parameters, ss_data, /*use_flatbuffer=*/true);
FAIL() << "_save_parameters should have thrown";
} catch (const ::c10::Error& e) {
static const std::string kExpectedSubstring =
"build hasn't enabled flatbuffer";
EXPECT_TRUE(
std::string(e.msg()).find(kExpectedSubstring) != std::string::npos)
<< "Exception message does not contain expected substring \""
<< kExpectedSubstring << "\": actual message \"" << e.msg() << "\"";
} catch (...) {
FAIL() << "Unexpected exception type";
}
}
#endif // !defined(ENABLE_FLATBUFFER)
#if defined(ENABLE_FLATBUFFER)
TEST(MobileTest, SaveLoadParametersUsingFlatbuffers) {
// Create some simple parameters to save.
register_flatbuffer_all();
std::map<std::string, at::Tensor> input_params;
input_params["four_by_ones"] = 4 * torch::ones({});
input_params["three_by_ones"] = 3 * torch::ones({});
@ -244,33 +221,6 @@ TEST(MobileTest, SaveLoadParametersUsingFlatbuffers) {
output_params["three_by_ones"].item<int>(), three_by_ones.item<int>());
}
}
#else // !defined(ENABLE_FLATBUFFER)
TEST(MobileTest, LoadParametersFailsWithoutFlatbufferSupport) {
// Create some data that looks like a flatbuffer header.
std::stringstream data;
data << "abcd"
<< "PTMF" // Flatbuffer magic
<< "ijkl";
// Loading the "flatbuffer" data should fail. Make sure we see the expected
// exception, not just any exception; since this isn't properly-formed
// flatbuffer data, any attempt to parse it might throw a different error type
// or message, but we don't expect anyone to try parsing it.
try {
_load_parameters(data);
FAIL() << "_load_parameters should have thrown";
} catch (const ::c10::Error& e) {
static const std::string kExpectedSubstring =
"build hasn't enabled flatbuffer";
EXPECT_TRUE(
std::string(e.msg()).find(kExpectedSubstring) != std::string::npos)
<< "Exception message does not contain expected substring \""
<< kExpectedSubstring << "\": actual message \"" << e.msg() << "\"";
} catch (...) {
FAIL() << "Unexpected exception type";
}
}
#endif // !defined(ENABLE_FLATBUFFER)
TEST(MobileTest, LoadParametersUnexpectedFormatShouldThrow) {
// Manually create some data that doesn't look like a ZIP or Flatbuffer file.

View File

@ -153,7 +153,7 @@ def add_torch_libs():
] if enable_flatbuffer else []),
link_whole = True,
include_directories = include_directories,
propagated_pp_flags = propagated_pp_flags_cpu + (["-DENABLE_FLATBUFFER"] if enable_flatbuffer else []),
propagated_pp_flags = propagated_pp_flags_cpu,
exported_deps = (
[
":ATen-cpu",

View File

@ -4,14 +4,12 @@
#include <torch/csrc/jit/api/compilation_unit.h> // removed after using simple type_resolver/obj_loader
#include <torch/csrc/jit/mobile/compatibility/model_compatibility.h>
#include <torch/csrc/jit/mobile/file_format.h>
#if defined(ENABLE_FLATBUFFER)
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
#endif
#include <torch/csrc/jit/mobile/import.h> // removed after using simple type_resolver/obj_loader
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/serialization/import_export_constants.h>
#include <torch/csrc/jit/serialization/import_read.h>
#include <caffe2/serialize/in_memory_adapter.h>
#include <sstream>
#include <string>
#include <unordered_set>
@ -71,59 +69,33 @@ std::vector<IValue> get_bytecode_ivalues(PyTorchStreamReader& reader) {
// Forward declare
uint64_t _get_model_bytecode_version(
const std::vector<IValue>& bytecode_ivalues);
static uint64_t _get_model_bytecode_version_from_bytes(char* data, size_t size);
uint64_t _get_model_bytecode_version(std::istream& in) {
auto orig_pos = in.tellg();
in.seekg(0, in.beg);
auto format = getFileFormat(in);
switch (format) {
case FileFormat::FlatbufferFileFormat: {
#if !defined(ENABLE_FLATBUFFER)
TORCH_CHECK(
false,
"Flatbuffer input file but the build hasn't enabled flatbuffer");
#else
return get_bytecode_version(in);
#endif
}
case FileFormat::ZipFileFormat: {
std::unique_ptr<IStreamAdapter> rai =
std::make_unique<IStreamAdapter>(&in);
auto version = _get_model_bytecode_version(std::move(rai));
in.seekg(orig_pos, in.beg);
return version;
}
default:
TORCH_CHECK(false, "Unrecognized data format");
}
std::shared_ptr<char> data;
size_t size = 0;
std::tie(data, size) = get_stream_content(in);
in.seekg(orig_pos, in.beg);
return _get_model_bytecode_version_from_bytes(data.get(), size);
}
uint64_t _get_model_bytecode_version(const std::string& filename) {
auto format = getFileFormat(filename);
switch (format) {
case FileFormat::FlatbufferFileFormat: {
#if !defined(ENABLE_FLATBUFFER)
TORCH_CHECK(
false,
"Flatbuffer input file but the build hasn't enabled flatbuffer");
#else
return get_bytecode_version(filename);
#endif
}
case FileFormat::ZipFileFormat: {
std::unique_ptr<FileAdapter> rai =
std::make_unique<FileAdapter>(filename);
return _get_model_bytecode_version(std::move(rai));
}
default:
TORCH_CHECK(false, "Unrecognized data format");
}
std::ifstream ifile(filename);
return _get_model_bytecode_version(ifile);
}
uint64_t _get_model_bytecode_version(
std::shared_ptr<ReadAdapterInterface> rai) {
std::shared_ptr<char> data;
size_t size = 0;
std::tie(data, size) = get_rai_content(rai.get());
return _get_model_bytecode_version_from_bytes(data.get(), size);
}
uint64_t _get_model_bytecode_version_zip(
std::shared_ptr<ReadAdapterInterface> rai) {
if (!check_zip_file(rai)) {
TORCH_CHECK(
false,
@ -134,6 +106,31 @@ uint64_t _get_model_bytecode_version(
return _get_model_bytecode_version(bytecode_values);
}
uint64_t _get_model_bytecode_version_from_bytes(char* data, size_t size) {
TORCH_CHECK(size >= kFileFormatHeaderSize, "Unrecognized data format");
auto format = getFileFormat(data);
switch (format) {
case FileFormat::FlatbufferFileFormat: {
if (get_flatbuffer_bytecode_version == nullptr) {
TORCH_CHECK(
false,
"Flatbuffer input file but the build hasn't enabled flatbuffer");
} else {
return get_flatbuffer_bytecode_version(data);
}
}
case FileFormat::ZipFileFormat: {
auto rai =
std::make_unique<caffe2::serialize::MemoryReadAdapter>(data, size);
auto version = _get_model_bytecode_version_zip(std::move(rai));
return version;
}
default:
TORCH_CHECK(false, "Unrecognized data format");
}
}
uint64_t _get_model_bytecode_version(
const std::vector<IValue>& bytecode_ivalues) {
if (!bytecode_ivalues.empty() && bytecode_ivalues[0].isInt()) {

View File

@ -1,10 +1,24 @@
#pragma once
#include <array>
#include <cerrno>
#include <cstddef>
#include <cstring>
#include <fstream>
#include <istream>
#include <memory>
#include <c10/core/CPUAllocator.h>
#include <c10/core/impl/alloc_cpu.h>
#include <caffe2/serialize/read_adapter_interface.h>
#if defined(HAVE_MMAP)
#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#endif
/**
* @file
@ -27,18 +41,16 @@ enum class FileFormat {
ZipFileFormat,
};
namespace internal {
/// The size of the buffer to pass to #getFileFormat(), in bytes.
constexpr size_t kFileFormatHeaderSize = 8;
constexpr size_t kMaxAlignment = 16;
/**
* Returns the likely file format based on the magic header bytes in @p header,
* which should contain the first bytes of a file or data stream.
*/
// NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration)
static inline FileFormat getFileFormat(
const std::array<char, kFileFormatHeaderSize>& header) {
static inline FileFormat getFileFormat(const char* data) {
// The size of magic strings to look for in the buffer.
static constexpr size_t kMagicSize = 4;
@ -60,22 +72,19 @@ static inline FileFormat getFileFormat(
// that do not typically cross into the printable ASCII range, so a ZIP file
// should never have a header that looks like a Flatbuffer file.
if (std::memcmp(
header.data() + kFlatbufferMagicOffset,
data + kFlatbufferMagicOffset,
kFlatbufferMagicString.data(),
kMagicSize) == 0) {
// Magic header for a binary file containing a Flatbuffer-serialized mobile
// Module.
return FileFormat::FlatbufferFileFormat;
} else if (
std::memcmp(header.data(), kZipMagicString.data(), kMagicSize) == 0) {
} else if (std::memcmp(data, kZipMagicString.data(), kMagicSize) == 0) {
// Magic header for a zip file, which we use to store pickled sub-files.
return FileFormat::ZipFileFormat;
}
return FileFormat::UnknownFileFormat;
}
} // namespace internal
/**
* Returns the likely file format based on the magic header bytes of @p data.
* If the stream position changes while inspecting the data, this function will
@ -86,10 +95,10 @@ static inline FileFormat getFileFormat(std::istream& data) {
FileFormat format = FileFormat::UnknownFileFormat;
std::streampos orig_pos = data.tellg();
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
std::array<char, internal::kFileFormatHeaderSize> header;
std::array<char, kFileFormatHeaderSize> header;
data.read(header.data(), header.size());
if (data.good()) {
format = internal::getFileFormat(header);
format = getFileFormat(header.data());
}
data.seekg(orig_pos, data.beg);
return format;
@ -105,5 +114,83 @@ static inline FileFormat getFileFormat(const std::string& filename) {
return getFileFormat(data);
}
// NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration)
static void file_not_found_error() {
std::stringstream message;
message << "Error while opening file: ";
if (errno == ENOENT) {
message << "no such file or directory" << std::endl;
} else {
message << "error no is: " << errno << std::endl;
}
TORCH_CHECK(false, message.str());
}
// NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration)
static inline std::tuple<std::shared_ptr<char>, size_t> get_file_content(
const char* filename) {
#if defined(HAVE_MMAP)
int fd = open(filename, O_RDONLY);
if (fd < 0) {
// failed to open file, chances are it's no such file or directory.
file_not_found_error();
}
struct stat statbuf {};
fstat(fd, &statbuf);
size_t size = statbuf.st_size;
void* ptr = mmap(nullptr, statbuf.st_size, PROT_READ, MAP_PRIVATE, fd, 0);
close(fd);
auto deleter = [statbuf](char* ptr) { munmap(ptr, statbuf.st_size); };
std::shared_ptr<char> data(reinterpret_cast<char*>(ptr), deleter);
#else
FILE* f = fopen(filename, "rb");
if (f == nullptr) {
file_not_found_error();
}
fseek(f, 0, SEEK_END);
size_t size = ftell(f);
fseek(f, 0, SEEK_SET);
// make sure buffer size is multiple of alignment
size_t buffer_size = (size / kMaxAlignment + 1) * kMaxAlignment;
std::shared_ptr<char> data(
static_cast<char*>(c10::alloc_cpu(buffer_size)), c10::free_cpu);
fread(data.get(), size, 1, f);
fclose(f);
#endif
return std::make_tuple(data, size);
}
// NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration)
static inline std::tuple<std::shared_ptr<char>, size_t> get_stream_content(
std::istream& in) {
// get size of the stream and reset to orig
std::streampos orig_pos = in.tellg();
in.seekg(orig_pos, std::ios::end);
const long size = in.tellg();
in.seekg(orig_pos, in.beg);
// read stream
// NOLINT make sure buffer size is multiple of alignment
size_t buffer_size = (size / kMaxAlignment + 1) * kMaxAlignment;
std::shared_ptr<char> data(
static_cast<char*>(c10::alloc_cpu(buffer_size)), c10::free_cpu);
in.read(data.get(), size);
// reset stream to original position
in.seekg(orig_pos, in.beg);
return std::make_tuple(data, size);
}
// NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration)
static inline std::tuple<std::shared_ptr<char>, size_t> get_rai_content(
caffe2::serialize::ReadAdapterInterface* rai) {
size_t buffer_size = (rai->size() / kMaxAlignment + 1) * kMaxAlignment;
std::shared_ptr<char> data(
static_cast<char*>(c10::alloc_cpu(buffer_size)), c10::free_cpu);
rai->read(
0, data.get(), rai->size(), "Loading ReadAdapterInterface to bytes");
return std::make_tuple(data, buffer_size);
}
} // namespace jit
} // namespace torch

View File

@ -12,6 +12,7 @@
#include <c10/util/ScopeExit.h>
#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/jit/frontend/script_type_parser.h>
#include <torch/csrc/jit/mobile/file_format.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/interpreter.h>
#include <torch/csrc/jit/mobile/observer.h>
@ -618,53 +619,6 @@ TypePtr FlatbufferLoader::getOrCreateTypeAnnotations(
return type;
}
std::tuple<std::shared_ptr<char>, size_t> get_file_content(
const char* filename) {
#if defined(HAVE_MMAP)
int fd = open(filename, O_RDONLY);
struct stat statbuf {};
fstat(fd, &statbuf);
size_t size = statbuf.st_size;
void* ptr = mmap(nullptr, statbuf.st_size, PROT_READ, MAP_PRIVATE, fd, 0);
close(fd);
auto deleter = [statbuf](char* ptr) { munmap(ptr, statbuf.st_size); };
std::shared_ptr<char> data(reinterpret_cast<char*>(ptr), deleter);
#else
FILE* f = fopen(filename, "rb");
fseek(f, 0, SEEK_END);
size_t size = ftell(f);
fseek(f, 0, SEEK_SET);
// make sure buffer size is multiple of alignment
size_t buffer_size =
(size / FLATBUFFERS_MAX_ALIGNMENT + 1) * FLATBUFFERS_MAX_ALIGNMENT;
std::shared_ptr<char> data(
static_cast<char*>(c10::alloc_cpu(buffer_size)), c10::free_cpu);
fread(data.get(), size, 1, f);
fclose(f);
#endif
return std::make_tuple(data, size);
}
std::tuple<std::shared_ptr<char>, size_t> get_stream_content(std::istream& in) {
// get size of the stream and reset to orig
std::streampos orig_pos = in.tellg();
in.seekg(orig_pos, std::ios::end);
const long size = in.tellg();
in.seekg(orig_pos, in.beg);
// read stream
// NOLINT make sure buffer size is multiple of alignment
size_t buffer_size =
(size / FLATBUFFERS_MAX_ALIGNMENT + 1) * FLATBUFFERS_MAX_ALIGNMENT;
std::shared_ptr<char> data(
static_cast<char*>(c10::alloc_cpu(buffer_size)), c10::free_cpu);
in.read(data.get(), size);
// reset stream to original position
in.seekg(orig_pos, in.beg);
return std::make_tuple(data, size);
}
void FlatbufferLoader::extractJitSourceAndConstants(
ExtraFilesMap* jit_sources,
std::vector<IValue>* constants) {
@ -696,13 +650,17 @@ void FlatbufferLoader::extractJitSourceAndConstants(
mobile::Module parse_and_initialize_mobile_module(
std::shared_ptr<char> data,
size_t,
c10::optional<at::Device>) {
c10::optional<at::Device>,
ExtraFilesMap* extra_files) {
TORCH_CHECK(
mobile::serialization::ModuleBufferHasIdentifier(data.get()),
"Format error");
auto* flatbuffer_module = mobile::serialization::GetMutableModule(data.get());
mobile::Module m = FlatbufferLoader().parseModule(flatbuffer_module);
m.set_delete_memory(std::move(data));
if (extra_files != nullptr) {
parseExtraFiles(flatbuffer_module, *extra_files);
}
return m;
}
@ -718,32 +676,35 @@ mobile::Module initialize_mobile_module(
mobile::Module load_mobile_module_from_file(
const std::string& filename,
c10::optional<c10::Device> device) {
c10::optional<c10::Device> device,
ExtraFilesMap* extra_files) {
std::shared_ptr<char> data;
size_t size = 0;
std::tie(data, size) = get_file_content(filename.c_str());
return parse_and_initialize_mobile_module(std::move(data), size, device);
return parse_and_initialize_mobile_module(
std::move(data), size, device, extra_files);
}
uint64_t get_bytecode_version(std::istream& in) {
std::shared_ptr<char> data;
size_t size = 0;
std::tie(data, size) = get_stream_content(in);
TORCH_CHECK(
mobile::serialization::ModuleBufferHasIdentifier(data.get()),
"Format error");
auto* flatbuffer_module = mobile::serialization::GetMutableModule(data.get());
return flatbuffer_module->bytecode_version();
return get_bytecode_version_from_bytes(data.get());
}
uint64_t get_bytecode_version(const std::string& filename) {
std::shared_ptr<char> data;
size_t size = 0;
std::tie(data, size) = get_file_content(filename.c_str());
return get_bytecode_version_from_bytes(data.get());
}
uint64_t get_bytecode_version_from_bytes(char* flatbuffer_content) {
TORCH_CHECK(
mobile::serialization::ModuleBufferHasIdentifier(data.get()),
mobile::serialization::ModuleBufferHasIdentifier(flatbuffer_content),
"Format error");
auto* flatbuffer_module = mobile::serialization::GetMutableModule(data.get());
auto* flatbuffer_module =
mobile::serialization::GetMutableModule(flatbuffer_content);
return flatbuffer_module->bytecode_version();
}
@ -755,5 +716,55 @@ mobile::ModuleInfo get_module_info_from_flatbuffer(char* flatbuffer_content) {
return mobile::get_module_info(m);
}
mobile::Module load_mobile_module_from_stream_with_copy(
std::istream& in,
c10::optional<at::Device> device,
ExtraFilesMap* extra_files) {
std::shared_ptr<char> data;
size_t size = 0;
std::tie(data, size) = get_stream_content(in);
return parse_and_initialize_mobile_module(
std::move(data), size, device, extra_files);
}
static mobile::Module parse_flatbuffer_no_object(
std::shared_ptr<char> data,
size_t size,
c10::optional<at::Device> device) {
(void)device;
(void)size;
auto* flatbuffer_module = mobile::serialization::GetMutableModule(data.get());
FlatbufferLoader loader;
// replace parserObject with to handle only class with field case
// function.
loader.registerIValueParser(
mobile::serialization::IValueUnion::Object,
+[](FlatbufferLoader& loader,
const mobile::serialization::IValue& ivalue) {
const mobile::serialization::Object* object = ivalue.val_as_Object();
auto cls = loader.getOrCreateClassTypeForObject(object);
auto obj = c10::ivalue::Object::create(
at::StrongTypePtr(loader.cu_, cls), object->attrs()->size());
for (uint32_t i = 0; i < object->attrs()->size(); i++) {
IValue val = loader.getIValue(object->attrs()->Get(i));
obj->setSlot(i, std::move(val));
}
return static_cast<c10::IValue>(obj);
});
mobile::Module m = loader.parseModule(flatbuffer_module);
m.set_delete_memory(std::move(data));
return m;
}
bool register_flatbuffer_loader() {
load_flatbuffer_bytes = parse_and_initialize_mobile_module;
load_flatbuffer_bytes_no_object = parse_flatbuffer_no_object;
get_flatbuffer_bytecode_version = get_bytecode_version_from_bytes;
return true;
}
const bool kRegisteredFlatbufferLoader = register_flatbuffer_loader();
} // namespace jit
} // namespace torch

View File

@ -47,7 +47,8 @@ TORCH_API mobile::Module initialize_mobile_module(
TORCH_API mobile::Module parse_and_initialize_mobile_module(
std::shared_ptr<char> data,
size_t size,
c10::optional<at::Device> device = c10::nullopt);
c10::optional<at::Device> device = c10::nullopt,
ExtraFilesMap* extra_files = nullptr);
// Load a mobile::Module from a filepath.
// This function does steps 1+2+3 described above.
@ -56,24 +57,33 @@ TORCH_API mobile::Module parse_and_initialize_mobile_module(
// versions above.
TORCH_API mobile::Module load_mobile_module_from_file(
const std::string& filename,
c10::optional<at::Device> device = c10::nullopt);
c10::optional<at::Device> device = c10::nullopt,
ExtraFilesMap* extra_files = nullptr);
TORCH_API void parseExtraFiles(
mobile::serialization::Module* module,
ExtraFilesMap& extra_files);
TORCH_API std::tuple<std::shared_ptr<char>, size_t> get_file_content(
const char* filename);
TORCH_API std::tuple<std::shared_ptr<char>, size_t> get_stream_content(
std::istream& in);
TORCH_API uint64_t get_bytecode_version(std::istream& in);
TORCH_API uint64_t get_bytecode_version(const std::string& filename);
TORCH_API uint64_t get_bytecode_version_from_bytes(char* flatbuffer_content);
TORCH_API mobile::ModuleInfo get_module_info_from_flatbuffer(
char* flatbuffer_content);
// The methods below are less efficient because it need to read the stream in
// its entirity to a buffer
TORCH_API mobile::Module load_mobile_module_from_stream_with_copy(
std::istream& in,
c10::optional<at::Device> device = c10::nullopt,
ExtraFilesMap* extra_files = nullptr);
// This function will make the capabilities to load
// Module as a flatbuffer file available for use by _load_for_mobile
// and friends. This is NOT needed if using the other functions
// in this file directly.
TORCH_API bool register_flatbuffer_loader();
class TORCH_API FlatbufferLoader {
public:
FlatbufferLoader();

View File

@ -7,13 +7,12 @@
#include <c10/util/Exception.h>
#include <c10/util/ScopeExit.h>
#include <c10/util/irange.h>
#include <caffe2/serialize/in_memory_adapter.h>
#include <caffe2/serialize/inline_container.h>
#include <caffe2/serialize/read_adapter_interface.h>
#include <caffe2/serialize/versions.h>
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/mobile/file_format.h>
#if defined(ENABLE_FLATBUFFER)
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
#endif
#include <torch/csrc/jit/mobile/interpreter.h>
#include <torch/csrc/jit/mobile/observer.h>
#include <torch/csrc/jit/mobile/type_parser.h>
@ -85,10 +84,23 @@
namespace torch {
namespace jit {
using caffe2::serialize::IStreamAdapter;
using caffe2::serialize::MemoryReadAdapter;
using caffe2::serialize::PyTorchStreamReader;
using caffe2::serialize::ReadAdapterInterface;
mobile::Module (*load_flatbuffer_bytes)(
std::shared_ptr<char>,
size_t size,
c10::optional<at::Device>,
ExtraFilesMap*) = nullptr;
mobile::Module (*load_flatbuffer_bytes_no_object)(
std::shared_ptr<char>,
size_t size,
c10::optional<at::Device>) = nullptr;
uint64_t (*get_flatbuffer_bytecode_version)(char* flatbuffer_content) = nullptr;
OpCode parseOpCode(const char* str);
TypePtr resolveTypeNameMobile(
@ -516,6 +528,13 @@ mobile::Module _load_for_mobile_impl(
ExtraFilesMap& extra_files,
uint64_t module_load_options);
mobile::Module _load_mobile_from_bytes(
std::shared_ptr<char> data,
size_t size,
c10::optional<c10::Device> device,
ExtraFilesMap& extra_files,
uint64_t module_load_options);
mobile::Module _load_for_mobile(
std::istream& in,
c10::optional<at::Device> device) {
@ -541,38 +560,11 @@ mobile::Module _load_for_mobile(
std::istream& in,
c10::optional<at::Device> device,
ExtraFilesMap& extra_files) {
in.seekg(0, in.beg);
auto format = getFileFormat(in);
switch (format) {
case FileFormat::ZipFileFormat: {
std::unique_ptr<IStreamAdapter> rai =
std::make_unique<IStreamAdapter>(&in);
auto module = _load_for_mobile(std::move(rai), device, extra_files);
return module;
}
#if defined(ENABLE_FLATBUFFER)
case FileFormat::FlatbufferFileFormat: {
std::shared_ptr<char> data;
size_t size = 0;
std::tie(data, size) = get_stream_content(in);
auto* flatbuffer_module =
mobile::serialization::GetMutableModule(data.get());
mobile::Module m = initialize_mobile_module(flatbuffer_module);
parseExtraFiles(flatbuffer_module, extra_files);
m.set_delete_memory(data);
return m;
}
#else
case FileFormat::FlatbufferFileFormat: {
TORCH_CHECK(
false,
"Flatbuffer input file but the build hasn't enabled flatbuffer");
}
#endif
default: {
TORCH_CHECK(false, "Format error");
}
}
std::shared_ptr<char> data;
size_t size = 0;
std::tie(data, size) = get_stream_content(in);
return _load_mobile_from_bytes(
data, size, device, extra_files, kDefaultMobileLoadOptions);
}
mobile::Module _load_for_mobile(
@ -580,10 +572,7 @@ mobile::Module _load_for_mobile(
c10::optional<at::Device> device,
ExtraFilesMap& extra_files) {
return _load_for_mobile(
filename,
device,
extra_files,
/*module_load_options=*/_default_mobile_module_load_options);
filename, device, extra_files, kDefaultMobileLoadOptions);
}
mobile::Module _load_for_mobile(
@ -591,49 +580,55 @@ mobile::Module _load_for_mobile(
c10::optional<at::Device> device,
ExtraFilesMap& extra_files,
uint64_t module_load_options) {
auto format = getFileFormat(filename);
std::shared_ptr<char> data;
size_t size = 0;
std::tie(data, size) = get_file_content(filename.c_str());
return _load_mobile_from_bytes(
data, size, device, extra_files, module_load_options);
}
TORCH_API mobile::Module _load_for_mobile(
std::unique_ptr<ReadAdapterInterface> rai,
c10::optional<c10::Device> device,
ExtraFilesMap& extra_files,
uint64_t module_load_options) {
std::shared_ptr<char> data;
size_t size = 0;
std::tie(data, size) = get_rai_content(rai.get());
return _load_mobile_from_bytes(
data, size, device, extra_files, module_load_options);
}
mobile::Module _load_mobile_from_bytes(
std::shared_ptr<char> data,
size_t size,
c10::optional<c10::Device> device,
ExtraFilesMap& extra_files,
uint64_t module_load_options) {
TORCH_CHECK(size >= kFileFormatHeaderSize, "Format error");
auto format = getFileFormat(data.get());
switch (format) {
case FileFormat::ZipFileFormat: {
std::unique_ptr<FileAdapter> rai =
std::make_unique<FileAdapter>(filename);
auto module = _load_for_mobile_impl(
std::unique_ptr<ReadAdapterInterface> rai =
std::make_unique<MemoryReadAdapter>(data.get(), size);
return _load_for_mobile_impl(
std::move(rai), device, extra_files, module_load_options);
return module;
}
#if defined(ENABLE_FLATBUFFER)
case FileFormat::FlatbufferFileFormat: {
std::shared_ptr<char> data;
size_t size = 0;
std::tie(data, size) = get_file_content(filename.c_str());
auto* flatbuffer_module =
mobile::serialization::GetMutableModule(data.get());
mobile::Module m = initialize_mobile_module(flatbuffer_module);
parseExtraFiles(flatbuffer_module, extra_files);
m.set_delete_memory(data);
return m;
if (load_flatbuffer_bytes != nullptr) {
return load_flatbuffer_bytes(data, size, device, &extra_files);
} else {
TORCH_CHECK(
false,
"Flatbuffer input file but the build hasn't enabled flatbuffer");
}
}
#else
case FileFormat::FlatbufferFileFormat: {
TORCH_CHECK(
false,
"Flatbuffer input file but the build hasn't enabled flatbuffer");
}
#endif
default: {
TORCH_CHECK(false, "Format error");
}
}
}
mobile::Module _load_for_mobile(
std::unique_ptr<ReadAdapterInterface> rai,
c10::optional<c10::Device> device,
ExtraFilesMap& extra_files) {
auto module = _load_for_mobile_impl(
std::move(rai), device, extra_files, _default_mobile_module_load_options);
return module;
}
mobile::Module _load_for_mobile_impl(
std::unique_ptr<ReadAdapterInterface> rai,
c10::optional<c10::Device> device,

View File

@ -1,5 +1,6 @@
#pragma once
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/mobile/parse_operators.h>
#include <istream>
#include <memory>
@ -32,7 +33,8 @@ TORCH_API mobile::Module _load_for_mobile(
TORCH_API mobile::Module _load_for_mobile(
std::unique_ptr<ReadAdapterInterface> rai,
c10::optional<c10::Device> device,
ExtraFilesMap& extra_files);
ExtraFilesMap& extra_files,
uint64_t module_load_options = kDefaultMobileLoadOptions);
TORCH_API mobile::Module _load_for_mobile(
const std::string& filename,
@ -104,5 +106,19 @@ TORCH_API std::set<std::string> _export_operator_list(
torch::jit::mobile::Module& module);
} // namespace mobile
extern mobile::Module (*load_flatbuffer_bytes)(
std::shared_ptr<char>,
size_t size,
c10::optional<at::Device>,
ExtraFilesMap*);
extern mobile::Module (*load_flatbuffer_bytes_no_object)(
std::shared_ptr<char>,
size_t size,
c10::optional<at::Device>);
extern uint64_t (*get_flatbuffer_bytecode_version)(char* flatbuffer_content);
} // namespace jit
} // namespace torch

View File

@ -7,6 +7,7 @@
#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/mobile/file_format.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/import_export_common.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/mobile/observer.h>
@ -15,10 +16,7 @@
#include <torch/csrc/jit/serialization/unpickler.h>
#include <torch/custom_class.h>
#if defined(ENABLE_FLATBUFFER)
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
#endif // defined(ENABLE_FLATBUFFER)
#include <caffe2/serialize/in_memory_adapter.h>
#include <exception>
#include <fstream>
#include <string>
@ -28,6 +26,7 @@ namespace torch {
namespace jit {
using caffe2::serialize::FileAdapter;
using caffe2::serialize::IStreamAdapter;
using caffe2::serialize::MemoryReadAdapter;
using caffe2::serialize::PyTorchStreamReader;
using caffe2::serialize::ReadAdapterInterface;
@ -238,58 +237,30 @@ std::map<std::string, at::Tensor> mobile_module_to_parameter_map(
"' in deserialized mobile::Module");
}
std::map<std::string, at::Tensor> _load_parameters(
std::istream& in,
std::map<std::string, at::Tensor> _load_parameters_bytes(
std::shared_ptr<char> data,
size_t size,
c10::optional<at::Device> device) {
// Detect the data format from the head of the input stream.
FileFormat format = getFileFormat(in);
TORCH_CHECK(size >= kFileFormatHeaderSize, "Unrecognized data format");
FileFormat format = getFileFormat(data.get());
// Call the appropriate parser.
std::map<std::string, at::Tensor> map;
switch (format) {
case FileFormat::FlatbufferFileFormat: {
#if defined(ENABLE_FLATBUFFER)
std::shared_ptr<char> data;
size_t size = 0;
std::tie(data, size) = get_stream_content(in);
TORCH_CHECK(
mobile::serialization::ModuleBufferHasIdentifier(data.get()),
"Format error");
auto* flatbuffer_module =
mobile::serialization::GetMutableModule(data.get());
FlatbufferLoader loader;
// replace parserObject with to handle only class with field case
// function.
loader.registerIValueParser(
mobile::serialization::IValueUnion::Object,
+[](FlatbufferLoader& loader,
const mobile::serialization::IValue& ivalue) {
const mobile::serialization::Object* object =
ivalue.val_as_Object();
auto cls = loader.getOrCreateClassTypeForObject(object);
auto obj = c10::ivalue::Object::create(
at::StrongTypePtr(loader.cu_, cls), object->attrs()->size());
for (uint32_t i = 0; i < object->attrs()->size(); i++) {
IValue val = loader.getIValue(object->attrs()->Get(i));
obj->setSlot(i, std::move(val));
}
return static_cast<c10::IValue>(obj);
});
mobile::Module m = loader.parseModule(flatbuffer_module);
m.set_delete_memory(std::move(data));
map = mobile_module_to_parameter_map(m);
#else // !defined(ENABLE_FLATBUFFER)
TORCH_CHECK(
false,
"Flatbuffer input file but the build hasn't enabled flatbuffer");
#endif // !defined(ENABLE_FLATBUFFER)
if (load_flatbuffer_bytes_no_object != nullptr) {
auto m = load_flatbuffer_bytes_no_object(data, size, device);
map = mobile_module_to_parameter_map(m);
} else {
TORCH_CHECK(
false,
"Flatbuffer input file but the build hasn't enabled flatbuffer");
}
break;
}
case FileFormat::ZipFileFormat: {
std::unique_ptr<IStreamAdapter> rai =
std::make_unique<IStreamAdapter>(&in);
auto rai = std::make_unique<caffe2::serialize::MemoryReadAdapter>(
data.get(), size);
map = load_parameters_from_zip(std::move(rai), device);
break;
}
@ -300,38 +271,22 @@ std::map<std::string, at::Tensor> _load_parameters(
return map;
}
std::map<std::string, at::Tensor> _load_parameters(
std::istream& in,
c10::optional<at::Device> device) {
std::shared_ptr<char> data;
size_t size = 0;
std::tie(data, size) = get_stream_content(in);
return _load_parameters_bytes(std::move(data), size, device);
}
std::map<std::string, at::Tensor> _load_parameters(
const std::string& filename,
c10::optional<at::Device> device) {
// Detect the file format from its header.
FileFormat format = getFileFormat(filename);
// Call the appropriate parser.
std::map<std::string, at::Tensor> map;
switch (format) {
case FileFormat::FlatbufferFileFormat: {
#if defined(ENABLE_FLATBUFFER)
mobile::Module module = load_mobile_module_from_file(filename, device);
map = mobile_module_to_parameter_map(module);
#else // !defined(ENABLE_FLATBUFFER)
TORCH_CHECK(
false,
"Flatbuffer input file but the build hasn't enabled flatbuffer");
#endif // !defined(ENABLE_FLATBUFFER)
break;
}
case FileFormat::ZipFileFormat: {
std::unique_ptr<FileAdapter> rai =
std::make_unique<FileAdapter>(filename);
map = load_parameters_from_zip(std::move(rai), device);
break;
}
default:
TORCH_CHECK(false, "Unrecognized data format");
}
return map;
std::shared_ptr<char> data;
size_t size = 0;
std::tie(data, size) = get_file_content(filename.c_str());
return _load_parameters_bytes(std::move(data), size, device);
}
} // namespace jit

View File

@ -9,7 +9,7 @@ enum MobileModuleLoadOptions {
OPERATOR_CHECK = 1,
};
const uint64_t _default_mobile_module_load_options =
const uint64_t kDefaultMobileLoadOptions =
MobileModuleLoadOptions::OPERATOR_CHECK;
namespace mobile {

View File

@ -11,11 +11,6 @@
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>
#if defined(ENABLE_FLATBUFFER)
#include <flatbuffers/flatbuffers.h>
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
#endif // defined(ENABLE_FLATBUFFER)
#include <string>
#include <vector>
@ -119,36 +114,34 @@ mobile::Module tensor_dict_to_mobile(
} // namespace mobile
void (*_save_mobile_module_to)(
const mobile::Module& module,
const std::function<size_t(const void*, size_t)>& writer_func) = nullptr;
void _save_parameters(
const std::map<std::string, at::Tensor>& map,
std::ostream& out,
bool use_flatbuffer) {
auto dict = mobile::tensor_map_to_dict(map);
if (use_flatbuffer) {
#if defined(ENABLE_FLATBUFFER)
// For Flatbuffer, we serialize an entire, mostly-empty module containing
// the dict as an attribute.
flatbuffers::DetachedBuffer bytes = torch::jit::save_mobile_module_to_bytes(
mobile::tensor_dict_to_mobile(dict));
auto write_func = [&out](const void* buf, size_t nbytes) -> size_t {
out.write(
reinterpret_cast<char*>(bytes.data()),
static_cast<std::streamsize>(bytes.size()));
#else // !defined(ENABLE_FLATBUFFER)
TORCH_CHECK(
false,
"Trying to export as flatbuffer file but "
"the build hasn't enabled flatbuffer");
#endif // !defined(ENABLE_FLATBUFFER)
static_cast<const char*>(buf), static_cast<std::streamsize>(nbytes));
return !out ? 0 : nbytes;
};
if (use_flatbuffer) {
if (_save_mobile_module_to != nullptr) {
_save_mobile_module_to(mobile::tensor_dict_to_mobile(dict), write_func);
} else {
TORCH_CHECK(
false,
"Trying to export as flatbuffer file but "
"the build hasn't enabled flatbuffer");
}
} else {
// For Pickle, we only serialize the dict itself.
mobile::IValuePickler pickler(
[&](const void* buf, size_t nbytes) -> size_t {
out.write(
static_cast<const char*>(buf),
static_cast<std::streamsize>(nbytes));
return !out ? 0 : nbytes;
});
mobile::IValuePickler pickler(write_func);
pickler.serialize(dict);
}
}
@ -159,23 +152,8 @@ void _save_parameters(
bool use_flatbuffer) {
auto dict = mobile::tensor_map_to_dict(map);
if (use_flatbuffer) {
#if defined(ENABLE_FLATBUFFER)
// For Flatbuffer, we serialize an entire, mostly-empty module containing
// the dict as an attribute.
torch::jit::save_mobile_module(
mobile::tensor_dict_to_mobile(dict), filename);
#else // !defined(ENABLE_FLATBUFFER)
TORCH_CHECK(
false,
"Trying to export as flatbuffer file but "
"the build hasn't enabled flatbuffer");
#endif // !defined(ENABLE_FLATBUFFER)
} else {
// For Pickle, we only serialize the dict itself.
mobile::IValuePickler pickler(filename);
pickler.serialize(dict);
}
std::ofstream ifile(filename);
_save_parameters(map, ifile, use_flatbuffer);
}
} // namespace jit

View File

@ -45,5 +45,9 @@ c10::Dict<std::string, at::Tensor> tensor_map_to_dict(
} // namespace mobile
extern void (*_save_mobile_module_to)(
const mobile::Module& module,
const std::function<size_t(const void*, size_t)>& writer_func);
} // namespace jit
} // namespace torch

View File

@ -262,5 +262,11 @@ bool getMobileInterfaceCallExport();
CompilationOptions getOptionsFromGlobal();
extern void (*_save_jit_module_to)(
const Module& module,
const ExtraFilesMap& extra_files,
bool save_mobile_debug_info,
const std::function<size_t(const void*, size_t)>& writer_func);
} // namespace jit
} // namespace torch

View File

@ -16,10 +16,6 @@
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/serialization/callstack_debug_info_serialization.h>
#include <torch/csrc/jit/serialization/export_bytecode.h>
#if defined(ENABLE_FLATBUFFER)
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
#include <torch/csrc/jit/serialization/flatbuffer_serializer_jit.h>
#endif
#include <torch/csrc/jit/serialization/import_export_constants.h>
#include <torch/csrc/jit/serialization/import_export_functions.h>
#include <torch/csrc/jit/serialization/import_export_helpers.h>
@ -34,6 +30,8 @@
#include <ATen/core/jit_type.h>
#include <ATen/core/qualified_name.h>
#include <cerrno>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
@ -817,23 +815,6 @@ SerializationStorageContext& ScriptModuleSerializer::storage_context() {
return storage_context_;
}
#if defined(ENABLE_FLATBUFFER)
void save_mobile_module_to(
const Module& module,
const ExtraFilesMap& extra_files,
bool save_mobile_debug_info,
const std::function<size_t(const void*, size_t)>& writer_func) {
ExtraFilesMap jitFiles;
CompilationOptions options = getOptionsFromGlobal();
std::vector<IValue> constants;
jitModuleToPythonCodeAndConstants(module, &jitFiles, &constants);
mobile::Module mod = jitModuleToMobile(module, options);
auto buffer =
save_mobile_module_to_bytes(mod, extra_files, jitFiles, constants);
writer_func(reinterpret_cast<void*>(buffer.data()), buffer.size());
}
#endif
void ExportModule(
const Module& module,
std::ostream& out,
@ -845,21 +826,13 @@ void ExportModule(
out.write(static_cast<const char*>(buf), nbytes);
return !out ? 0 : nbytes;
};
if (use_flatbuffer) {
#if defined(ENABLE_FLATBUFFER)
save_mobile_module_to(
module, extra_files, save_mobile_debug_info, writer_func);
#else
TORCH_CHECK(
false,
"Trying to export as flatbuffer file but the build hasn't enabled flatbuffer");
#endif
} else {
caffe2::serialize::PyTorchStreamWriter writer(writer_func);
ScriptModuleSerializer serializer(writer);
serializer.serialize(
module, extra_files, bytecode_format, save_mobile_debug_info);
}
ExportModule(
module,
writer_func,
extra_files,
bytecode_format,
save_mobile_debug_info,
use_flatbuffer);
}
void ExportModule(
@ -869,29 +842,41 @@ void ExportModule(
bool bytecode_format,
bool save_mobile_debug_info,
bool use_flatbuffer) {
if (use_flatbuffer) {
#if defined(ENABLE_FLATBUFFER)
auto writer_func = [&](const void* buf, size_t nbytes) -> size_t {
std::fstream ofile(filename, std::ios::binary | std::ios::out);
ofile.write(static_cast<const char*>(buf), nbytes);
ofile.close();
return !ofile ? 0 : nbytes;
};
save_mobile_module_to(
module, extra_files, save_mobile_debug_info, writer_func);
#else
TORCH_CHECK(
false,
"Trying to export as flatbuffer file but the build hasn't enabled flatbuffer");
#endif
} else {
if (!use_flatbuffer) {
// the zip archive need to know the filepath
caffe2::serialize::PyTorchStreamWriter writer(filename);
ScriptModuleSerializer serializer(writer);
serializer.serialize(
module, extra_files, bytecode_format, save_mobile_debug_info);
return;
}
std::ofstream ofile;
ofile.open(filename, std::ios::binary | std::ios::out);
if (ofile.fail()) {
std::stringstream message;
if (errno == ENOENT) {
message << "Parent directory of " << filename << " does not exist.\n";
} else {
message << "Error while opening file: " << errno << std::endl;
;
}
TORCH_CHECK(false, message.str());
}
ExportModule(
module,
ofile,
extra_files,
bytecode_format,
save_mobile_debug_info,
use_flatbuffer);
}
void (*_save_jit_module_to)(
const Module& module,
const ExtraFilesMap& extra_files,
bool save_mobile_debug_info,
const std::function<size_t(const void*, size_t)>& writer_func) = nullptr;
void ExportModule(
const Module& module,
const std::function<size_t(const void*, size_t)>& writer_func,
@ -900,14 +885,14 @@ void ExportModule(
bool save_mobile_debug_info,
bool use_flatbuffer) {
if (use_flatbuffer) {
#if defined(ENABLE_FLATBUFFER)
save_mobile_module_to(
module, extra_files, save_mobile_debug_info, writer_func);
#else
TORCH_CHECK(
false,
"Trying to export as flatbuffer file but the build hasn't enabled flatbuffer");
#endif
if (_save_jit_module_to != nullptr) {
_save_jit_module_to(
module, extra_files, save_mobile_debug_info, writer_func);
} else {
TORCH_CHECK(
false,
"Trying to export as flatbuffer file but the build hasn't enabled flatbuffer");
}
} else {
caffe2::serialize::PyTorchStreamWriter writer(writer_func);
ScriptModuleSerializer serializer(writer);

View File

@ -2,10 +2,12 @@
#include <ATen/ATen.h>
#include <c10/core/CPUAllocator.h>
#include <c10/util/Exception.h>
#include <caffe2/serialize/versions.h>
#include <flatbuffers/flatbuffers.h>
#include <torch/csrc/jit/mobile/code.h>
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
#include <torch/csrc/jit/mobile/train/export_data.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/serialization/export.h>
@ -776,5 +778,19 @@ flatbuffers::DetachedBuffer save_mobile_module_to_bytes(
jit_constants);
}
void save_mobile_module_to_func(
const mobile::Module& module,
const std::function<size_t(const void*, size_t)>& writer_func) {
auto buffer = save_mobile_module_to_bytes(module);
writer_func(buffer.data(), buffer.size());
}
bool register_flatbuffer_serializer() {
_save_mobile_module_to = save_mobile_module_to_func;
return true;
}
const bool kFlatbufferSerializerRegistered = register_flatbuffer_serializer();
} // namespace jit
} // namespace torch

View File

@ -29,5 +29,11 @@ TORCH_API flatbuffers::DetachedBuffer save_mobile_module_to_bytes(
const ExtraFilesMap& jit_sources = ExtraFilesMap(),
const std::vector<IValue>& jit_constants = {});
// This function will make the capabilities to load and safe
// Module as a flatbuffer file available for use by _load_for_mobile
// and friends. This is NOT needed if using the other functions
// in this file directly.
TORCH_API bool register_flatbuffer_serializer();
} // namespace jit
} // namespace torch

View File

@ -1,5 +1,6 @@
#include <torch/csrc/jit/serialization/flatbuffer_serializer_jit.h>
#include <torch/csrc/jit/mobile/file_format.h>
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
#include <torch/csrc/jit/operator_upgraders/upgraders_entry.h>
#include <torch/csrc/jit/serialization/export.h>
@ -73,5 +74,25 @@ flatbuffers::DetachedBuffer save_jit_module_to_bytes(
return save_mobile_module_to_bytes(mobilem, extra_files, jitfiles, constants);
}
void save_jit_module_to_write_func(
const Module& module,
const ExtraFilesMap& extra_files,
bool save_mobile_debug_info,
const std::function<size_t(const void*, size_t)>& writer_func) {
(void)save_mobile_debug_info;
auto buffer = save_jit_module_to_bytes(module, extra_files);
writer_func(reinterpret_cast<void*>(buffer.data()), buffer.size());
}
bool register_flatbuffer_all() {
(void)register_flatbuffer_loader();
(void)register_flatbuffer_serializer();
_save_jit_module_to = save_jit_module_to_write_func;
_load_jit_module_from_flatbuffer_bytes = parse_and_initialize_jit_module;
return true;
}
const bool kFlatbufferSerializerJitInitialized = register_flatbuffer_all();
} // namespace jit
} // namespace torch

View File

@ -31,5 +31,11 @@ TORCH_API Module load_jit_module_from_stream(
ExtraFilesMap& extra_files,
c10::optional<at::Device> device = c10::nullopt);
// This function will make the capabilities to load and safe
// Module as a flatbuffer file available for use by _load_for_mobile
// and friends. This is NOT needed if using the other functions
// in this file directly.
TORCH_API bool register_flatbuffer_all();
} // namespace jit
} // namespace torch

View File

@ -1,4 +1,17 @@
#include <ATen/core/interned_strings.h>
#include <c10/core/CPUAllocator.h>
#include <c10/core/impl/alloc_cpu.h>
#include <caffe2/serialize/file_adapter.h>
#include <caffe2/serialize/in_memory_adapter.h>
#include <caffe2/serialize/inline_container.h>
#include <caffe2/serialize/istream_adapter.h>
#include <caffe2/serialize/read_adapter_interface.h>
#include <caffe2/serialize/versions.h>
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/mobile/file_format.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/csrc/jit/serialization/source_range_serialization.h>
#include <ATen/core/functional.h>
#include <ATen/core/ivalue_inl.h>
@ -19,16 +32,6 @@
#include <torch/csrc/jit/serialization/source_range_serialization.h>
#include <torch/csrc/jit/serialization/unpickler.h>
#if defined(ENABLE_FLATBUFFER)
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
#include <torch/csrc/jit/serialization/flatbuffer_serializer_jit.h>
#endif
#include <caffe2/serialize/file_adapter.h>
#include <caffe2/serialize/inline_container.h>
#include <caffe2/serialize/istream_adapter.h>
#include <caffe2/serialize/versions.h>
#include <ATen/ATen.h>
#include <fmt/format.h>
@ -42,6 +45,7 @@ namespace jit {
using caffe2::serialize::FileAdapter;
using caffe2::serialize::IStreamAdapter;
using caffe2::serialize::MemoryReadAdapter;
using caffe2::serialize::PyTorchStreamReader;
using caffe2::serialize::ReadAdapterInterface;
@ -291,31 +295,29 @@ Module import_ir_module(
return import_ir_module(std::move(cu), in, device, extra_files);
}
Module (*_load_jit_module_from_flatbuffer_bytes)(
std::shared_ptr<char>,
size_t,
ExtraFilesMap&,
c10::optional<at::Device>) = nullptr;
static Module _load_jit_module_from_bytes(
std::shared_ptr<char> data,
size_t size,
std::shared_ptr<CompilationUnit> cu,
c10::optional<c10::Device> device,
ExtraFilesMap& extra_files);
Module import_ir_module(
std::shared_ptr<CompilationUnit> cu,
std::istream& in,
c10::optional<at::Device> device,
ExtraFilesMap& extra_files) {
in.seekg(0, in.beg);
auto format = getFileFormat(in);
switch (format) {
case FileFormat::FlatbufferFileFormat: {
#if defined(ENABLE_FLATBUFFER)
return load_jit_module_from_stream(in, extra_files, device);
#else
TORCH_CHECK(
false, "Flatbuffer input file but the build hasn't enable flatbuffer")
#endif
}
case FileFormat::ZipFileFormat: {
auto reader = torch::make_unique<PyTorchStreamReader>(&in);
ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
return deserializer.deserialize(device, extra_files);
}
default:
TORCH_CHECK(false, "Unrecognized data format");
}
std::shared_ptr<char> data;
size_t size = 0;
std::tie(data, size) = get_stream_content(in);
return _load_jit_module_from_bytes(data, size, cu, device, extra_files);
}
// For reading unified serialization format from torch.Package.
@ -348,25 +350,10 @@ Module import_ir_module(
const std::string& filename,
c10::optional<at::Device> device,
ExtraFilesMap& extra_files) {
auto format = getFileFormat(filename);
switch (format) {
case FileFormat::FlatbufferFileFormat: {
#if defined(ENABLE_FLATBUFFER)
return load_jit_module_from_file(filename, extra_files, device);
#else
TORCH_CHECK(
false, "Flatbuffer input file but the build hasn't enable flatbuffer")
#endif
}
case FileFormat::ZipFileFormat: {
auto reader = torch::make_unique<PyTorchStreamReader>(filename);
ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
return deserializer.deserialize(device, extra_files);
}
default:
TORCH_CHECK(false, "Unrecognized data format");
}
std::shared_ptr<char> data;
size_t size = 0;
std::tie(data, size) = get_file_content(filename.c_str());
return _load_jit_module_from_bytes(data, size, cu, device, extra_files);
}
Module import_ir_module(
@ -382,100 +369,92 @@ Module import_ir_module(
std::unique_ptr<ReadAdapterInterface> rai,
c10::optional<at::Device> device,
ExtraFilesMap& extra_files) {
auto reader = torch::make_unique<PyTorchStreamReader>(std::move(rai));
ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
return deserializer.deserialize(device, extra_files);
std::shared_ptr<ReadAdapterInterface> rai_shared = std::move(rai);
return import_ir_module(cu, rai_shared, device, extra_files);
}
Module import_ir_module(
std::shared_ptr<CompilationUnit> cu,
std::shared_ptr<ReadAdapterInterface> rai,
c10::optional<at::Device> device,
ExtraFilesMap& extra_files) {
std::shared_ptr<char> data;
size_t size = 0;
std::tie(data, size) = get_rai_content(rai.get());
return _load_jit_module_from_bytes(data, size, cu, device, extra_files);
}
Module load(std::istream& in, c10::optional<at::Device> device) {
ExtraFilesMap extra_files;
return load(in, device, extra_files);
auto cu = std::make_shared<CompilationUnit>();
return import_ir_module(std::move(cu), in, device);
}
Module load(
std::istream& in,
c10::optional<at::Device> device,
ExtraFilesMap& extra_files) {
in.seekg(0, in.beg);
auto format = getFileFormat(in);
switch (format) {
case FileFormat::FlatbufferFileFormat: {
#if defined(ENABLE_FLATBUFFER)
return load_jit_module_from_stream(in, extra_files, device);
#else
TORCH_CHECK(
false, "Flatbuffer input file but the build hasn't enable flatbuffer")
#endif
}
case FileFormat::ZipFileFormat: {
std::unique_ptr<IStreamAdapter> rai =
std::make_unique<IStreamAdapter>(&in);
auto module = load(std::move(rai), device, extra_files);
return module;
}
default:
TORCH_CHECK(false, "Unrecognized data format");
}
auto cu = std::make_shared<CompilationUnit>();
return import_ir_module(std::move(cu), in, device, extra_files);
}
Module load(const std::string& filename, c10::optional<at::Device> device) {
ExtraFilesMap extra_files;
return load(filename, device, extra_files);
auto cu = std::make_shared<CompilationUnit>();
return import_ir_module(std::move(cu), filename, device);
}
Module load(
const std::string& filename,
c10::optional<at::Device> device,
ExtraFilesMap& extra_files) {
auto format = getFileFormat(filename);
switch (format) {
case FileFormat::FlatbufferFileFormat: {
#if defined(ENABLE_FLATBUFFER)
return load_jit_module_from_file(filename, extra_files, device);
#else
TORCH_CHECK(
false, "Flatbuffer input file but the build hasn't enable flatbuffer")
#endif
case FileFormat::ZipFileFormat: {
std::unique_ptr<FileAdapter> rai =
std::make_unique<FileAdapter>(filename);
auto module = load(std::move(rai), device, extra_files);
return module;
}
default:
TORCH_CHECK(false, "Unrecognized data format");
}
}
auto cu = std::make_shared<CompilationUnit>();
return import_ir_module(std::move(cu), filename, device, extra_files);
}
Module load(
std::shared_ptr<ReadAdapterInterface> rai,
c10::optional<c10::Device> device) {
auto cu = std::make_shared<CompilationUnit>();
ExtraFilesMap extra_files;
return load(std::move(rai), device, extra_files);
return import_ir_module(std::move(cu), std::move(rai), device, extra_files);
}
Module load(
std::shared_ptr<ReadAdapterInterface> rai,
c10::optional<c10::Device> device,
ExtraFilesMap& extra_files) {
// Verify that we're loading a zip archive and not a torch.save pickle
// archive (marked by the 0x80 0x02 bytes at the start)
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
TORCH_CHECK(
check_zip_file(rai),
"`torch::jit::load()` received a file from `torch.save()`, "
"but `torch::jit::load()` can only load files"
" produced by `torch.jit.save()`");
auto reader = std::make_shared<PyTorchStreamReader>(std::move(rai));
auto cu = std::make_shared<CompilationUnit>();
return import_ir_module(std::move(cu), std::move(rai), device, extra_files);
}
ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
return deserializer.deserialize(device, extra_files);
Module _load_jit_module_from_bytes(
std::shared_ptr<char> data,
size_t size,
std::shared_ptr<CompilationUnit> cu,
c10::optional<c10::Device> device,
ExtraFilesMap& extra_files) {
TORCH_CHECK(size >= kFileFormatHeaderSize, "Unrecorgnized data format");
auto format = getFileFormat(data.get());
switch (format) {
case FileFormat::FlatbufferFileFormat: {
if (_load_jit_module_from_flatbuffer_bytes != nullptr) {
return _load_jit_module_from_flatbuffer_bytes(
data, size, extra_files, device);
} else {
TORCH_CHECK(
false,
"Flatbuffer input file but the build hasn't enable flatbuffer")
}
}
case FileFormat::ZipFileFormat: {
auto rai = std::make_unique<MemoryReadAdapter>(data.get(), size);
auto reader = torch::make_unique<PyTorchStreamReader>(std::move(rai));
ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
return deserializer.deserialize(device, extra_files);
}
default:
TORCH_CHECK(false, "Unrecognized data format");
}
}
// Replace object with a newly created but equivalent object.

View File

@ -58,6 +58,12 @@ TORCH_API Module import_ir_module(
c10::optional<c10::Device> device,
ExtraFilesMap& extra_files);
TORCH_API Module import_ir_module(
std::shared_ptr<CompilationUnit> cu,
std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai,
c10::optional<c10::Device> device,
ExtraFilesMap& extra_files);
/// Loads a serialized `Module` from the given `istream`.
///
/// The istream must contain a serialized `Module`, exported via
@ -104,5 +110,19 @@ TORCH_API Module jitModuleFromSourceAndConstants(
const std::vector<IValue>& constants,
int32_t version);
extern Module (*_load_jit_module_from_flatbuffer_bytes)(
// comp unit
std::shared_ptr<char>,
size_t,
ExtraFilesMap&,
c10::optional<at::Device>);
extern Module (*_load_jit_module_from_flatbuffer_bytes)(
// comp unit
std::shared_ptr<char>,
size_t,
ExtraFilesMap&,
c10::optional<at::Device>);
} // namespace jit
} // namespace torch