Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
[BE][flatbuffer] Remove code duplications and refactor (#79184)
Summary: Remove code duplication in import.cpp / export_modules.cpp so that (1) there is only one copy of the format-switching logic (detect flatbuffer / is_flatbuffer), and (2) detection of whether flatbuffer support is included happens at runtime (so no more macros). This also reverses the dependency: instead of import.cpp depending on flatbuffer_loader.cpp, flatbuffer_loader.cpp now depends on import.cpp. Differential Revision: D36926217. Pull Request resolved: https://github.com/pytorch/pytorch/pull/79184. Approved by: https://github.com/zhxchen17
This commit is contained in:
parent: 7de231a813
commit: fed12ff680
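The core idea of the diff below is to replace compile-time ENABLE_FLATBUFFER branches with function pointers (load_flatbuffer_bytes, register_flatbuffer_loader, get_flatbuffer_bytecode_version, _save_mobile_module_to, _save_jit_module_to) that the flatbuffer translation units fill in from static initializers when they are linked into the build. A minimal, self-contained sketch of that registration pattern, with simplified hypothetical names and signatures (not the actual PyTorch declarations), shows why the macros become unnecessary:

#include <cstddef>
#include <cstdio>

// Core library (always built): a hook that stays null unless the optional
// flatbuffer object file is linked in.
using LoaderFn = int (*)(const char* data, std::size_t size);
LoaderFn load_flatbuffer_bytes = nullptr;

// Optional flatbuffer library: the real implementation plus its registration.
int parse_flatbuffer(const char* data, std::size_t size) {
  (void)data;
  return static_cast<int>(size);  // placeholder for real parsing
}
bool register_flatbuffer_loader() {
  load_flatbuffer_bytes = &parse_flatbuffer;
  return true;
}
// Runs at static-initialization time only if this translation unit is linked.
const bool kRegisteredFlatbufferLoader = register_flatbuffer_loader();

int main() {
  if (load_flatbuffer_bytes != nullptr) {
    std::printf("flatbuffer loader available: %d\n", load_flatbuffer_bytes("PTMF", 4));
  } else {
    std::printf("build hasn't enabled flatbuffer\n");
  }
  return 0;
}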
@@ -1645,7 +1645,8 @@ cc_library(
],
)) + libtorch_core_sources + libtorch_distributed_sources + torch_cpp_srcs + libtorch_extra_sources + jit_core_sources + lazy_tensor_ts_sources + GENERATED_AUTOGRAD_CPP + [
"torch/csrc/jit/serialization/flatbuffer_serializer.cpp",
"torch/csrc/jit/mobile/flatbuffer_loader.cpp"
"torch/csrc/jit/mobile/flatbuffer_loader.cpp",
"torch/csrc/jit/serialization/flatbuffer_serializer_jit.cpp",
],
copts = TORCH_COPTS,
defines = [
caffe2/serialize/in_memory_adapter.h (new file, 32 lines)
@@ -0,0 +1,32 @@
#pragma once
#include <cstring>
#include <caffe2/serialize/read_adapter_interface.h>


namespace caffe2 {
namespace serialize {

class MemoryReadAdapter final : public caffe2::serialize::ReadAdapterInterface {
 public:
  explicit MemoryReadAdapter(const void* data, off_t size)
      : data_(data), size_(size) {}

  size_t size() const override {
    return size_;
  }

  size_t read(uint64_t pos, void* buf, size_t n, const char* what = "")
      const override {
    (void) what;
    memcpy(buf, (int8_t*)(data_) + pos, n);
    return n;
  }

 private:
  const void* data_;
  off_t size_;
};


} // namespace serialize
} // namespace caffe2
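MemoryReadAdapter is what lets the refactored loaders route an in-memory buffer through the same ReadAdapterInterface path as files and streams. A hedged usage sketch, assuming only the header above (note the adapter does not own the buffer):

#include <caffe2/serialize/in_memory_adapter.h>
#include <cstdio>
#include <vector>

int main() {
  // Pretend this vector holds serialized module bytes.
  std::vector<char> bytes = {'P', 'K', '\x03', '\x04'};
  caffe2::serialize::MemoryReadAdapter adapter(bytes.data(),
                                               static_cast<off_t>(bytes.size()));

  char magic[4] = {};
  adapter.read(/*pos=*/0, magic, sizeof(magic));  // copies out of the caller-owned buffer
  std::printf("size=%zu first bytes=%c%c\n", adapter.size(), magic[0], magic[1]);
  return 0;
}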
@@ -174,7 +174,7 @@ TEST(FlatbufferTest, MethodInvocation) { // NOLINT (use =delete in gtest)
}
}

#if defined(ENABLE_FLATBUFFER) && !defined(FB_XPLAT_BUILD)
#if !defined(FB_XPLAT_BUILD)
TEST(FlatbufferTest, FlatbufferBackPortTest) {
Module m("m");
m.define(R"(

@@ -188,7 +188,7 @@ TEST(FlatbufferTest, FlatbufferBackPortTest) {
bool backPortSuccess = _backport_for_mobile(ss, oss, 5);
ASSERT_TRUE(backPortSuccess);
}
#endif // defined(ENABLE_FLATBUFFER) && !defined(FB_XPLAT_BUILD)
#endif // !defined(FB_XPLAT_BUILD)

TEST(FlatbufferTest, ExtraFiles) {
const auto script = R"JIT(

@@ -207,7 +207,6 @@ TEST(FlatbufferTest, ExtraFiles) {
extra_files["mobile_info.json"] = "{\"key\": 23}";

std::unordered_map<std::string, std::string> loaded_extra_files;
#if defined ENABLE_FLATBUFFER
std::stringstream ss;
module->_save_for_mobile(ss, extra_files, true, /*use_flatbuffer=*/true);

@@ -219,17 +218,6 @@ TEST(FlatbufferTest, ExtraFiles) {

// load it twice using the same stream
auto mobile_module2 = _load_for_mobile(ss, c10::nullopt, loaded_extra_files);
#else
CompilationOptions options;
mobile::Module bc = jitModuleToMobile(*module, options);
auto buff = save_mobile_module_to_bytes(bc, extra_files);

loaded_extra_files["metadata.json"] = "";
auto* flatbuffer_module =
mobile::serialization::GetMutableModule(buff.data());

parseExtraFiles(flatbuffer_module, loaded_extra_files);
#endif

ASSERT_EQ(loaded_extra_files["metadata.json"], "abc");
ASSERT_EQ(loaded_extra_files["mobile_info.json"], "{\"key\": 23}");

@@ -1283,7 +1271,6 @@ Module jitModuleFromBuffer(void* data) {
mobilem._ivalue(), files, constants, 8);
}

#if defined(ENABLE_FLATBUFFER)
TEST(TestSourceFlatbuffer, UpsampleNearest2d) {
Module m("m");
m.define(R"(

@@ -1375,7 +1362,6 @@ TEST(TestSourceFlatbuffer,
AT_ASSERT(resd == refd);
}
}
#endif

#if !defined FB_XPLAT_BUILD
// The following test run in fbcode only
@@ -635,7 +635,7 @@ void backportAllVersionCheck(
std::vector<IValue>& expect_result_list,
const uint64_t expect_from_version) {
auto from_version = _get_model_bytecode_version(test_model_file_stream);
AT_ASSERT(from_version == expect_from_version);
EXPECT_EQ(from_version, expect_from_version);
AT_ASSERT(from_version > 0);

// Backport script_module_v5.ptl to an older version

@@ -717,15 +717,11 @@ TEST(LiteInterpreterTest, BackPortByteCodeModelAllVersions) {
torch::jit::Module module_freeze = freeze(module);

std::stringstream input_model_stream;
#if defined(ENABLE_FLATBUFFER)
module_freeze._save_for_mobile(
input_model_stream,
/*extra_files=*/{},
/*save_mobile_debug_info=*/false,
/*use_flatbuffer=*/true);
#else
module_freeze._save_for_mobile(input_model_stream);
#endif
std::vector<IValue> input_data =
std::vector<IValue>({torch::ones({1, 1, 28, 28})});
std::vector<IValue> expect_result_list;

@@ -748,7 +744,7 @@ TEST(LiteInterpreterTest, BackPortByteCodeModelAllVersions) {
input_model_stream,
input_data,
expect_result_list,
caffe2::serialize::kProducedBytecodeVersion);
9); // flatbuffer starts at 9
}
#endif // !defined(FB_XPLAT_BUILD)

@@ -10,6 +10,7 @@
#include <torch/csrc/jit/mobile/train/optim/sgd.h>
#include <torch/csrc/jit/mobile/train/random.h>
#include <torch/csrc/jit/mobile/train/sequential.h>
#include <torch/csrc/jit/serialization/flatbuffer_serializer_jit.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/data/dataloader.h>
#include <torch/torch.h>

@@ -172,9 +173,9 @@ TEST(MobileTest, SaveParametersDefaultsToZip) {
EXPECT_EQ(ss_data.str()[3], '\x04');
}

#if defined(ENABLE_FLATBUFFER)
TEST(MobileTest, SaveParametersCanUseFlatbuffer) {
// Save some empty parameters using flatbuffer.
register_flatbuffer_all();
std::map<std::string, at::Tensor> empty_parameters;
std::stringstream ss_data;
_save_parameters(empty_parameters, ss_data, /*use_flatbuffer=*/true);

@@ -188,34 +189,10 @@ TEST(MobileTest, SaveParametersCanUseFlatbuffer) {
EXPECT_EQ(ss_data.str()[6], 'M');
EXPECT_EQ(ss_data.str()[7], 'F');
}
#else // !defined(ENABLE_FLATBUFFER)
TEST(MobileTest, SaveParametersThrowsWithoutFlatbufferSupport) {
// Some empty parameters to try saving.
std::map<std::string, at::Tensor> empty_parameters;
std::stringstream ss_data;

// Save using flatbuffers should fail when support isn't compiled in. Make
// sure we get the exception that explicitly mentions the lack of flatbuffer
// support.
try {
_save_parameters(empty_parameters, ss_data, /*use_flatbuffer=*/true);
FAIL() << "_save_parameters should have thrown";
} catch (const ::c10::Error& e) {
static const std::string kExpectedSubstring =
"build hasn't enabled flatbuffer";
EXPECT_TRUE(
std::string(e.msg()).find(kExpectedSubstring) != std::string::npos)
<< "Exception message does not contain expected substring \""
<< kExpectedSubstring << "\": actual message \"" << e.msg() << "\"";
} catch (...) {
FAIL() << "Unexpected exception type";
}
}
#endif // !defined(ENABLE_FLATBUFFER)

#if defined(ENABLE_FLATBUFFER)
TEST(MobileTest, SaveLoadParametersUsingFlatbuffers) {
// Create some simple parameters to save.
register_flatbuffer_all();
std::map<std::string, at::Tensor> input_params;
input_params["four_by_ones"] = 4 * torch::ones({});
input_params["three_by_ones"] = 3 * torch::ones({});

@@ -244,33 +221,6 @@ TEST(MobileTest, SaveLoadParametersUsingFlatbuffers) {
output_params["three_by_ones"].item<int>(), three_by_ones.item<int>());
}
}
#else // !defined(ENABLE_FLATBUFFER)
TEST(MobileTest, LoadParametersFailsWithoutFlatbufferSupport) {
// Create some data that looks like a flatbuffer header.
std::stringstream data;
data << "abcd"
<< "PTMF" // Flatbuffer magic
<< "ijkl";

// Loading the "flatbuffer" data should fail. Make sure we see the expected
// exception, not just any exception; since this isn't properly-formed
// flatbuffer data, any attempt to parse it might throw a different error type
// or message, but we don't expect anyone to try parsing it.
try {
_load_parameters(data);
FAIL() << "_load_parameters should have thrown";
} catch (const ::c10::Error& e) {
static const std::string kExpectedSubstring =
"build hasn't enabled flatbuffer";
EXPECT_TRUE(
std::string(e.msg()).find(kExpectedSubstring) != std::string::npos)
<< "Exception message does not contain expected substring \""
<< kExpectedSubstring << "\": actual message \"" << e.msg() << "\"";
} catch (...) {
FAIL() << "Unexpected exception type";
}
}
#endif // !defined(ENABLE_FLATBUFFER)

TEST(MobileTest, LoadParametersUnexpectedFormatShouldThrow) {
// Manually create some data that doesn't look like a ZIP or Flatbuffer file.
@@ -153,7 +153,7 @@ def add_torch_libs():
] if enable_flatbuffer else []),
link_whole = True,
include_directories = include_directories,
propagated_pp_flags = propagated_pp_flags_cpu + (["-DENABLE_FLATBUFFER"] if enable_flatbuffer else []),
propagated_pp_flags = propagated_pp_flags_cpu,
exported_deps = (
[
":ATen-cpu",
@@ -4,14 +4,12 @@
#include <torch/csrc/jit/api/compilation_unit.h> // removed after using simple type_resolver/obj_loader
#include <torch/csrc/jit/mobile/compatibility/model_compatibility.h>
#include <torch/csrc/jit/mobile/file_format.h>
#if defined(ENABLE_FLATBUFFER)
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
#endif
#include <torch/csrc/jit/mobile/import.h> // removed after using simple type_resolver/obj_loader
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/serialization/import_export_constants.h>
#include <torch/csrc/jit/serialization/import_read.h>

#include <caffe2/serialize/in_memory_adapter.h>
#include <sstream>
#include <string>
#include <unordered_set>

@@ -71,59 +69,33 @@ std::vector<IValue> get_bytecode_ivalues(PyTorchStreamReader& reader) {
// Forward declare
uint64_t _get_model_bytecode_version(
const std::vector<IValue>& bytecode_ivalues);
static uint64_t _get_model_bytecode_version_from_bytes(char* data, size_t size);

uint64_t _get_model_bytecode_version(std::istream& in) {
auto orig_pos = in.tellg();
in.seekg(0, in.beg);
auto format = getFileFormat(in);
switch (format) {
case FileFormat::FlatbufferFileFormat: {
#if !defined(ENABLE_FLATBUFFER)
TORCH_CHECK(
false,
"Flatbuffer input file but the build hasn't enabled flatbuffer");
#else
return get_bytecode_version(in);
#endif
}
case FileFormat::ZipFileFormat: {
std::unique_ptr<IStreamAdapter> rai =
std::make_unique<IStreamAdapter>(&in);
auto version = _get_model_bytecode_version(std::move(rai));
in.seekg(orig_pos, in.beg);
return version;
}

default:
TORCH_CHECK(false, "Unrecognized data format");
}
std::shared_ptr<char> data;
size_t size = 0;
std::tie(data, size) = get_stream_content(in);
in.seekg(orig_pos, in.beg);
return _get_model_bytecode_version_from_bytes(data.get(), size);
}

uint64_t _get_model_bytecode_version(const std::string& filename) {
auto format = getFileFormat(filename);
switch (format) {
case FileFormat::FlatbufferFileFormat: {
#if !defined(ENABLE_FLATBUFFER)
TORCH_CHECK(
false,
"Flatbuffer input file but the build hasn't enabled flatbuffer");
#else
return get_bytecode_version(filename);
#endif
}
case FileFormat::ZipFileFormat: {
std::unique_ptr<FileAdapter> rai =
std::make_unique<FileAdapter>(filename);
return _get_model_bytecode_version(std::move(rai));
}

default:
TORCH_CHECK(false, "Unrecognized data format");
}
std::ifstream ifile(filename);
return _get_model_bytecode_version(ifile);
}

uint64_t _get_model_bytecode_version(
std::shared_ptr<ReadAdapterInterface> rai) {
std::shared_ptr<char> data;
size_t size = 0;
std::tie(data, size) = get_rai_content(rai.get());
return _get_model_bytecode_version_from_bytes(data.get(), size);
}

uint64_t _get_model_bytecode_version_zip(
std::shared_ptr<ReadAdapterInterface> rai) {
if (!check_zip_file(rai)) {
TORCH_CHECK(
false,

@@ -134,6 +106,31 @@ uint64_t _get_model_bytecode_version(
return _get_model_bytecode_version(bytecode_values);
}

uint64_t _get_model_bytecode_version_from_bytes(char* data, size_t size) {
TORCH_CHECK(size >= kFileFormatHeaderSize, "Unrecognized data format");
auto format = getFileFormat(data);
switch (format) {
case FileFormat::FlatbufferFileFormat: {
if (get_flatbuffer_bytecode_version == nullptr) {
TORCH_CHECK(
false,
"Flatbuffer input file but the build hasn't enabled flatbuffer");
} else {
return get_flatbuffer_bytecode_version(data);
}
}
case FileFormat::ZipFileFormat: {
auto rai =
std::make_unique<caffe2::serialize::MemoryReadAdapter>(data, size);
auto version = _get_model_bytecode_version_zip(std::move(rai));
return version;
}

default:
TORCH_CHECK(false, "Unrecognized data format");
}
}

uint64_t _get_model_bytecode_version(
const std::vector<IValue>& bytecode_ivalues) {
if (!bytecode_ivalues.empty() && bytecode_ivalues[0].isInt()) {
@@ -1,10 +1,24 @@
#pragma once

#include <array>
#include <cerrno>
#include <cstddef>
#include <cstring>
#include <fstream>
#include <istream>
#include <memory>

#include <c10/core/CPUAllocator.h>
#include <c10/core/impl/alloc_cpu.h>
#include <caffe2/serialize/read_adapter_interface.h>

#if defined(HAVE_MMAP)
#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#endif

/**
* @file

@@ -27,18 +41,16 @@ enum class FileFormat {
ZipFileFormat,
};

namespace internal {

/// The size of the buffer to pass to #getFileFormat(), in bytes.
constexpr size_t kFileFormatHeaderSize = 8;
constexpr size_t kMaxAlignment = 16;

/**
* Returns the likely file format based on the magic header bytes in @p header,
* which should contain the first bytes of a file or data stream.
*/
// NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration)
static inline FileFormat getFileFormat(
const std::array<char, kFileFormatHeaderSize>& header) {
static inline FileFormat getFileFormat(const char* data) {
// The size of magic strings to look for in the buffer.
static constexpr size_t kMagicSize = 4;

@@ -60,22 +72,19 @@ static inline FileFormat getFileFormat(
// that do not typically cross into the printable ASCII range, so a ZIP file
// should never have a header that looks like a Flatbuffer file.
if (std::memcmp(
header.data() + kFlatbufferMagicOffset,
data + kFlatbufferMagicOffset,
kFlatbufferMagicString.data(),
kMagicSize) == 0) {
// Magic header for a binary file containing a Flatbuffer-serialized mobile
// Module.
return FileFormat::FlatbufferFileFormat;
} else if (
std::memcmp(header.data(), kZipMagicString.data(), kMagicSize) == 0) {
} else if (std::memcmp(data, kZipMagicString.data(), kMagicSize) == 0) {
// Magic header for a zip file, which we use to store pickled sub-files.
return FileFormat::ZipFileFormat;
}
return FileFormat::UnknownFileFormat;
}

} // namespace internal

/**
* Returns the likely file format based on the magic header bytes of @p data.
* If the stream position changes while inspecting the data, this function will

@@ -86,10 +95,10 @@ static inline FileFormat getFileFormat(std::istream& data) {
FileFormat format = FileFormat::UnknownFileFormat;
std::streampos orig_pos = data.tellg();
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
std::array<char, internal::kFileFormatHeaderSize> header;
std::array<char, kFileFormatHeaderSize> header;
data.read(header.data(), header.size());
if (data.good()) {
format = internal::getFileFormat(header);
format = getFileFormat(header.data());
}
data.seekg(orig_pos, data.beg);
return format;

@@ -105,5 +114,83 @@ static inline FileFormat getFileFormat(const std::string& filename) {
return getFileFormat(data);
}

// NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration)
static void file_not_found_error() {
std::stringstream message;
message << "Error while opening file: ";
if (errno == ENOENT) {
message << "no such file or directory" << std::endl;
} else {
message << "error no is: " << errno << std::endl;
}
TORCH_CHECK(false, message.str());
}

// NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration)
static inline std::tuple<std::shared_ptr<char>, size_t> get_file_content(
const char* filename) {
#if defined(HAVE_MMAP)
int fd = open(filename, O_RDONLY);
if (fd < 0) {
// failed to open file, chances are it's no such file or directory.
file_not_found_error();
}
struct stat statbuf {};
fstat(fd, &statbuf);
size_t size = statbuf.st_size;
void* ptr = mmap(nullptr, statbuf.st_size, PROT_READ, MAP_PRIVATE, fd, 0);
close(fd);
auto deleter = [statbuf](char* ptr) { munmap(ptr, statbuf.st_size); };
std::shared_ptr<char> data(reinterpret_cast<char*>(ptr), deleter);
#else
FILE* f = fopen(filename, "rb");
if (f == nullptr) {
file_not_found_error();
}
fseek(f, 0, SEEK_END);
size_t size = ftell(f);
fseek(f, 0, SEEK_SET);
// make sure buffer size is multiple of alignment
size_t buffer_size = (size / kMaxAlignment + 1) * kMaxAlignment;
std::shared_ptr<char> data(
static_cast<char*>(c10::alloc_cpu(buffer_size)), c10::free_cpu);
fread(data.get(), size, 1, f);
fclose(f);
#endif
return std::make_tuple(data, size);
}

// NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration)
static inline std::tuple<std::shared_ptr<char>, size_t> get_stream_content(
std::istream& in) {
// get size of the stream and reset to orig
std::streampos orig_pos = in.tellg();
in.seekg(orig_pos, std::ios::end);
const long size = in.tellg();
in.seekg(orig_pos, in.beg);

// read stream
// NOLINT make sure buffer size is multiple of alignment
size_t buffer_size = (size / kMaxAlignment + 1) * kMaxAlignment;
std::shared_ptr<char> data(
static_cast<char*>(c10::alloc_cpu(buffer_size)), c10::free_cpu);
in.read(data.get(), size);

// reset stream to original position
in.seekg(orig_pos, in.beg);
return std::make_tuple(data, size);
}

// NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration)
static inline std::tuple<std::shared_ptr<char>, size_t> get_rai_content(
caffe2::serialize::ReadAdapterInterface* rai) {
size_t buffer_size = (rai->size() / kMaxAlignment + 1) * kMaxAlignment;
std::shared_ptr<char> data(
static_cast<char*>(c10::alloc_cpu(buffer_size)), c10::free_cpu);
rai->read(
0, data.get(), rai->size(), "Loading ReadAdapterInterface to bytes");
return std::make_tuple(data, buffer_size);
}

} // namespace jit
} // namespace torch
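file_format.h now exposes getFileFormat(const char*) alongside get_file_content/get_stream_content, so format sniffing can run on any in-memory buffer of at least kFileFormatHeaderSize (8) bytes. A self-contained illustration of the same magic-byte dispatch; this re-implements the check for demonstration only (the layout follows the "abcd" + "PTMF" bytes exercised by the tests above, not the actual torch header):

#include <cstdio>
#include <cstring>

enum class FileFormat { UnknownFileFormat, FlatbufferFileFormat, ZipFileFormat };

FileFormat detect(const char* data) {
  if (std::memcmp(data + 4, "PTMF", 4) == 0) {    // flatbuffer file identifier at offset 4
    return FileFormat::FlatbufferFileFormat;
  }
  if (std::memcmp(data, "PK\x03\x04", 4) == 0) {  // zip local-file-header magic at offset 0
    return FileFormat::ZipFileFormat;
  }
  return FileFormat::UnknownFileFormat;
}

int main() {
  const char header[8] = {'a', 'b', 'c', 'd', 'P', 'T', 'M', 'F'};
  std::printf("flatbuffer? %d\n", detect(header) == FileFormat::FlatbufferFileFormat);
  return 0;
}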
@@ -12,6 +12,7 @@
#include <c10/util/ScopeExit.h>
#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/jit/frontend/script_type_parser.h>
#include <torch/csrc/jit/mobile/file_format.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/interpreter.h>
#include <torch/csrc/jit/mobile/observer.h>

@@ -618,53 +619,6 @@ TypePtr FlatbufferLoader::getOrCreateTypeAnnotations(
return type;
}

std::tuple<std::shared_ptr<char>, size_t> get_file_content(
const char* filename) {
#if defined(HAVE_MMAP)
int fd = open(filename, O_RDONLY);
struct stat statbuf {};
fstat(fd, &statbuf);
size_t size = statbuf.st_size;
void* ptr = mmap(nullptr, statbuf.st_size, PROT_READ, MAP_PRIVATE, fd, 0);
close(fd);
auto deleter = [statbuf](char* ptr) { munmap(ptr, statbuf.st_size); };
std::shared_ptr<char> data(reinterpret_cast<char*>(ptr), deleter);
#else
FILE* f = fopen(filename, "rb");
fseek(f, 0, SEEK_END);
size_t size = ftell(f);
fseek(f, 0, SEEK_SET);
// make sure buffer size is multiple of alignment
size_t buffer_size =
(size / FLATBUFFERS_MAX_ALIGNMENT + 1) * FLATBUFFERS_MAX_ALIGNMENT;
std::shared_ptr<char> data(
static_cast<char*>(c10::alloc_cpu(buffer_size)), c10::free_cpu);
fread(data.get(), size, 1, f);
fclose(f);
#endif
return std::make_tuple(data, size);
}

std::tuple<std::shared_ptr<char>, size_t> get_stream_content(std::istream& in) {
// get size of the stream and reset to orig
std::streampos orig_pos = in.tellg();
in.seekg(orig_pos, std::ios::end);
const long size = in.tellg();
in.seekg(orig_pos, in.beg);

// read stream
// NOLINT make sure buffer size is multiple of alignment
size_t buffer_size =
(size / FLATBUFFERS_MAX_ALIGNMENT + 1) * FLATBUFFERS_MAX_ALIGNMENT;
std::shared_ptr<char> data(
static_cast<char*>(c10::alloc_cpu(buffer_size)), c10::free_cpu);
in.read(data.get(), size);

// reset stream to original position
in.seekg(orig_pos, in.beg);
return std::make_tuple(data, size);
}

void FlatbufferLoader::extractJitSourceAndConstants(
ExtraFilesMap* jit_sources,
std::vector<IValue>* constants) {

@@ -696,13 +650,17 @@ void FlatbufferLoader::extractJitSourceAndConstants(
mobile::Module parse_and_initialize_mobile_module(
std::shared_ptr<char> data,
size_t,
c10::optional<at::Device>) {
c10::optional<at::Device>,
ExtraFilesMap* extra_files) {
TORCH_CHECK(
mobile::serialization::ModuleBufferHasIdentifier(data.get()),
"Format error");
auto* flatbuffer_module = mobile::serialization::GetMutableModule(data.get());
mobile::Module m = FlatbufferLoader().parseModule(flatbuffer_module);
m.set_delete_memory(std::move(data));
if (extra_files != nullptr) {
parseExtraFiles(flatbuffer_module, *extra_files);
}
return m;
}

@@ -718,32 +676,35 @@ mobile::Module initialize_mobile_module(

mobile::Module load_mobile_module_from_file(
const std::string& filename,
c10::optional<c10::Device> device) {
c10::optional<c10::Device> device,
ExtraFilesMap* extra_files) {
std::shared_ptr<char> data;
size_t size = 0;
std::tie(data, size) = get_file_content(filename.c_str());
return parse_and_initialize_mobile_module(std::move(data), size, device);
return parse_and_initialize_mobile_module(
std::move(data), size, device, extra_files);
}

uint64_t get_bytecode_version(std::istream& in) {
std::shared_ptr<char> data;
size_t size = 0;
std::tie(data, size) = get_stream_content(in);
TORCH_CHECK(
mobile::serialization::ModuleBufferHasIdentifier(data.get()),
"Format error");
auto* flatbuffer_module = mobile::serialization::GetMutableModule(data.get());
return flatbuffer_module->bytecode_version();
return get_bytecode_version_from_bytes(data.get());
}

uint64_t get_bytecode_version(const std::string& filename) {
std::shared_ptr<char> data;
size_t size = 0;
std::tie(data, size) = get_file_content(filename.c_str());
return get_bytecode_version_from_bytes(data.get());
}

uint64_t get_bytecode_version_from_bytes(char* flatbuffer_content) {
TORCH_CHECK(
mobile::serialization::ModuleBufferHasIdentifier(data.get()),
mobile::serialization::ModuleBufferHasIdentifier(flatbuffer_content),
"Format error");
auto* flatbuffer_module = mobile::serialization::GetMutableModule(data.get());
auto* flatbuffer_module =
mobile::serialization::GetMutableModule(flatbuffer_content);
return flatbuffer_module->bytecode_version();
}

@@ -755,5 +716,55 @@ mobile::ModuleInfo get_module_info_from_flatbuffer(char* flatbuffer_content) {
return mobile::get_module_info(m);
}

mobile::Module load_mobile_module_from_stream_with_copy(
std::istream& in,
c10::optional<at::Device> device,
ExtraFilesMap* extra_files) {
std::shared_ptr<char> data;
size_t size = 0;
std::tie(data, size) = get_stream_content(in);
return parse_and_initialize_mobile_module(
std::move(data), size, device, extra_files);
}

static mobile::Module parse_flatbuffer_no_object(
std::shared_ptr<char> data,
size_t size,
c10::optional<at::Device> device) {
(void)device;
(void)size;
auto* flatbuffer_module = mobile::serialization::GetMutableModule(data.get());
FlatbufferLoader loader;
// replace parserObject with to handle only class with field case
// function.
loader.registerIValueParser(
mobile::serialization::IValueUnion::Object,
+[](FlatbufferLoader& loader,
const mobile::serialization::IValue& ivalue) {
const mobile::serialization::Object* object = ivalue.val_as_Object();
auto cls = loader.getOrCreateClassTypeForObject(object);
auto obj = c10::ivalue::Object::create(
at::StrongTypePtr(loader.cu_, cls), object->attrs()->size());
for (uint32_t i = 0; i < object->attrs()->size(); i++) {
IValue val = loader.getIValue(object->attrs()->Get(i));
obj->setSlot(i, std::move(val));
}
return static_cast<c10::IValue>(obj);
});

mobile::Module m = loader.parseModule(flatbuffer_module);
m.set_delete_memory(std::move(data));
return m;
}

bool register_flatbuffer_loader() {
load_flatbuffer_bytes = parse_and_initialize_mobile_module;
load_flatbuffer_bytes_no_object = parse_flatbuffer_no_object;
get_flatbuffer_bytecode_version = get_bytecode_version_from_bytes;
return true;
}

const bool kRegisteredFlatbufferLoader = register_flatbuffer_loader();

} // namespace jit
} // namespace torch
@@ -47,7 +47,8 @@ TORCH_API mobile::Module initialize_mobile_module(
TORCH_API mobile::Module parse_and_initialize_mobile_module(
std::shared_ptr<char> data,
size_t size,
c10::optional<at::Device> device = c10::nullopt);
c10::optional<at::Device> device = c10::nullopt,
ExtraFilesMap* extra_files = nullptr);

// Load a mobile::Module from a filepath.
// This function does steps 1+2+3 described above.

@@ -56,24 +57,33 @@ TORCH_API mobile::Module parse_and_initialize_mobile_module(
// versions above.
TORCH_API mobile::Module load_mobile_module_from_file(
const std::string& filename,
c10::optional<at::Device> device = c10::nullopt);
c10::optional<at::Device> device = c10::nullopt,
ExtraFilesMap* extra_files = nullptr);

TORCH_API void parseExtraFiles(
mobile::serialization::Module* module,
ExtraFilesMap& extra_files);

TORCH_API std::tuple<std::shared_ptr<char>, size_t> get_file_content(
const char* filename);

TORCH_API std::tuple<std::shared_ptr<char>, size_t> get_stream_content(
std::istream& in);

TORCH_API uint64_t get_bytecode_version(std::istream& in);
TORCH_API uint64_t get_bytecode_version(const std::string& filename);
TORCH_API uint64_t get_bytecode_version_from_bytes(char* flatbuffer_content);

TORCH_API mobile::ModuleInfo get_module_info_from_flatbuffer(
char* flatbuffer_content);

// The methods below are less efficient because it need to read the stream in
// its entirity to a buffer
TORCH_API mobile::Module load_mobile_module_from_stream_with_copy(
std::istream& in,
c10::optional<at::Device> device = c10::nullopt,
ExtraFilesMap* extra_files = nullptr);

// This function will make the capabilities to load
// Module as a flatbuffer file available for use by _load_for_mobile
// and friends. This is NOT needed if using the other functions
// in this file directly.
TORCH_API bool register_flatbuffer_loader();

class TORCH_API FlatbufferLoader {
public:
FlatbufferLoader();
@@ -7,13 +7,12 @@
#include <c10/util/Exception.h>
#include <c10/util/ScopeExit.h>
#include <c10/util/irange.h>
#include <caffe2/serialize/in_memory_adapter.h>
#include <caffe2/serialize/inline_container.h>
#include <caffe2/serialize/read_adapter_interface.h>
#include <caffe2/serialize/versions.h>
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/mobile/file_format.h>
#if defined(ENABLE_FLATBUFFER)
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
#endif
#include <torch/csrc/jit/mobile/interpreter.h>
#include <torch/csrc/jit/mobile/observer.h>
#include <torch/csrc/jit/mobile/type_parser.h>

@@ -85,10 +84,23 @@

namespace torch {
namespace jit {
using caffe2::serialize::IStreamAdapter;
using caffe2::serialize::MemoryReadAdapter;
using caffe2::serialize::PyTorchStreamReader;
using caffe2::serialize::ReadAdapterInterface;

mobile::Module (*load_flatbuffer_bytes)(
std::shared_ptr<char>,
size_t size,
c10::optional<at::Device>,
ExtraFilesMap*) = nullptr;

mobile::Module (*load_flatbuffer_bytes_no_object)(
std::shared_ptr<char>,
size_t size,
c10::optional<at::Device>) = nullptr;

uint64_t (*get_flatbuffer_bytecode_version)(char* flatbuffer_content) = nullptr;

OpCode parseOpCode(const char* str);

TypePtr resolveTypeNameMobile(

@@ -516,6 +528,13 @@ mobile::Module _load_for_mobile_impl(
ExtraFilesMap& extra_files,
uint64_t module_load_options);

mobile::Module _load_mobile_from_bytes(
std::shared_ptr<char> data,
size_t size,
c10::optional<c10::Device> device,
ExtraFilesMap& extra_files,
uint64_t module_load_options);

mobile::Module _load_for_mobile(
std::istream& in,
c10::optional<at::Device> device) {

@@ -541,38 +560,11 @@ mobile::Module _load_for_mobile(
std::istream& in,
c10::optional<at::Device> device,
ExtraFilesMap& extra_files) {
in.seekg(0, in.beg);
auto format = getFileFormat(in);
switch (format) {
case FileFormat::ZipFileFormat: {
std::unique_ptr<IStreamAdapter> rai =
std::make_unique<IStreamAdapter>(&in);
auto module = _load_for_mobile(std::move(rai), device, extra_files);
return module;
}
#if defined(ENABLE_FLATBUFFER)
case FileFormat::FlatbufferFileFormat: {
std::shared_ptr<char> data;
size_t size = 0;
std::tie(data, size) = get_stream_content(in);
auto* flatbuffer_module =
mobile::serialization::GetMutableModule(data.get());
mobile::Module m = initialize_mobile_module(flatbuffer_module);
parseExtraFiles(flatbuffer_module, extra_files);
m.set_delete_memory(data);
return m;
}
#else
case FileFormat::FlatbufferFileFormat: {
TORCH_CHECK(
false,
"Flatbuffer input file but the build hasn't enabled flatbuffer");
}
#endif
default: {
TORCH_CHECK(false, "Format error");
}
}
std::shared_ptr<char> data;
size_t size = 0;
std::tie(data, size) = get_stream_content(in);
return _load_mobile_from_bytes(
data, size, device, extra_files, kDefaultMobileLoadOptions);
}

mobile::Module _load_for_mobile(

@@ -580,10 +572,7 @@ mobile::Module _load_for_mobile(
c10::optional<at::Device> device,
ExtraFilesMap& extra_files) {
return _load_for_mobile(
filename,
device,
extra_files,
/*module_load_options=*/_default_mobile_module_load_options);
filename, device, extra_files, kDefaultMobileLoadOptions);
}

mobile::Module _load_for_mobile(

@@ -591,49 +580,55 @@ mobile::Module _load_for_mobile(
c10::optional<at::Device> device,
ExtraFilesMap& extra_files,
uint64_t module_load_options) {
auto format = getFileFormat(filename);
std::shared_ptr<char> data;
size_t size = 0;
std::tie(data, size) = get_file_content(filename.c_str());
return _load_mobile_from_bytes(
data, size, device, extra_files, module_load_options);
}

TORCH_API mobile::Module _load_for_mobile(
std::unique_ptr<ReadAdapterInterface> rai,
c10::optional<c10::Device> device,
ExtraFilesMap& extra_files,
uint64_t module_load_options) {
std::shared_ptr<char> data;
size_t size = 0;
std::tie(data, size) = get_rai_content(rai.get());
return _load_mobile_from_bytes(
data, size, device, extra_files, module_load_options);
}

mobile::Module _load_mobile_from_bytes(
std::shared_ptr<char> data,
size_t size,
c10::optional<c10::Device> device,
ExtraFilesMap& extra_files,
uint64_t module_load_options) {
TORCH_CHECK(size >= kFileFormatHeaderSize, "Format error");
auto format = getFileFormat(data.get());
switch (format) {
case FileFormat::ZipFileFormat: {
std::unique_ptr<FileAdapter> rai =
std::make_unique<FileAdapter>(filename);
auto module = _load_for_mobile_impl(
std::unique_ptr<ReadAdapterInterface> rai =
std::make_unique<MemoryReadAdapter>(data.get(), size);
return _load_for_mobile_impl(
std::move(rai), device, extra_files, module_load_options);
return module;
}
#if defined(ENABLE_FLATBUFFER)
case FileFormat::FlatbufferFileFormat: {
std::shared_ptr<char> data;
size_t size = 0;
std::tie(data, size) = get_file_content(filename.c_str());
auto* flatbuffer_module =
mobile::serialization::GetMutableModule(data.get());
mobile::Module m = initialize_mobile_module(flatbuffer_module);
parseExtraFiles(flatbuffer_module, extra_files);
m.set_delete_memory(data);
return m;
if (load_flatbuffer_bytes != nullptr) {
return load_flatbuffer_bytes(data, size, device, &extra_files);
} else {
TORCH_CHECK(
false,
"Flatbuffer input file but the build hasn't enabled flatbuffer");
}
}
#else
case FileFormat::FlatbufferFileFormat: {
TORCH_CHECK(
false,
"Flatbuffer input file but the build hasn't enabled flatbuffer");
}
#endif
default: {
TORCH_CHECK(false, "Format error");
}
}
}

mobile::Module _load_for_mobile(
std::unique_ptr<ReadAdapterInterface> rai,
c10::optional<c10::Device> device,
ExtraFilesMap& extra_files) {
auto module = _load_for_mobile_impl(
std::move(rai), device, extra_files, _default_mobile_module_load_options);
return module;
}

mobile::Module _load_for_mobile_impl(
std::unique_ptr<ReadAdapterInterface> rai,
c10::optional<c10::Device> device,
@@ -1,5 +1,6 @@
#pragma once
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/mobile/parse_operators.h>

#include <istream>
#include <memory>

@@ -32,7 +33,8 @@ TORCH_API mobile::Module _load_for_mobile(
TORCH_API mobile::Module _load_for_mobile(
std::unique_ptr<ReadAdapterInterface> rai,
c10::optional<c10::Device> device,
ExtraFilesMap& extra_files);
ExtraFilesMap& extra_files,
uint64_t module_load_options = kDefaultMobileLoadOptions);

TORCH_API mobile::Module _load_for_mobile(
const std::string& filename,

@@ -104,5 +106,19 @@ TORCH_API std::set<std::string> _export_operator_list(
torch::jit::mobile::Module& module);

} // namespace mobile

extern mobile::Module (*load_flatbuffer_bytes)(
std::shared_ptr<char>,
size_t size,
c10::optional<at::Device>,
ExtraFilesMap*);

extern mobile::Module (*load_flatbuffer_bytes_no_object)(
std::shared_ptr<char>,
size_t size,
c10::optional<at::Device>);

extern uint64_t (*get_flatbuffer_bytecode_version)(char* flatbuffer_content);

} // namespace jit
} // namespace torch
@@ -7,6 +7,7 @@
#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/mobile/file_format.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/import_export_common.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/mobile/observer.h>

@@ -15,10 +16,7 @@
#include <torch/csrc/jit/serialization/unpickler.h>
#include <torch/custom_class.h>

#if defined(ENABLE_FLATBUFFER)
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
#endif // defined(ENABLE_FLATBUFFER)

#include <caffe2/serialize/in_memory_adapter.h>
#include <exception>
#include <fstream>
#include <string>

@@ -28,6 +26,7 @@ namespace torch {
namespace jit {
using caffe2::serialize::FileAdapter;
using caffe2::serialize::IStreamAdapter;
using caffe2::serialize::MemoryReadAdapter;
using caffe2::serialize::PyTorchStreamReader;
using caffe2::serialize::ReadAdapterInterface;

@@ -238,58 +237,30 @@ std::map<std::string, at::Tensor> mobile_module_to_parameter_map(
"' in deserialized mobile::Module");
}

std::map<std::string, at::Tensor> _load_parameters(
std::istream& in,
std::map<std::string, at::Tensor> _load_parameters_bytes(
std::shared_ptr<char> data,
size_t size,
c10::optional<at::Device> device) {
// Detect the data format from the head of the input stream.
FileFormat format = getFileFormat(in);

TORCH_CHECK(size >= kFileFormatHeaderSize, "Unrecognized data format");
FileFormat format = getFileFormat(data.get());
// Call the appropriate parser.
std::map<std::string, at::Tensor> map;
switch (format) {
case FileFormat::FlatbufferFileFormat: {
#if defined(ENABLE_FLATBUFFER)
std::shared_ptr<char> data;
size_t size = 0;
std::tie(data, size) = get_stream_content(in);
TORCH_CHECK(
mobile::serialization::ModuleBufferHasIdentifier(data.get()),
"Format error");
auto* flatbuffer_module =
mobile::serialization::GetMutableModule(data.get());
FlatbufferLoader loader;
// replace parserObject with to handle only class with field case
// function.
loader.registerIValueParser(
mobile::serialization::IValueUnion::Object,
+[](FlatbufferLoader& loader,
const mobile::serialization::IValue& ivalue) {
const mobile::serialization::Object* object =
ivalue.val_as_Object();
auto cls = loader.getOrCreateClassTypeForObject(object);
auto obj = c10::ivalue::Object::create(
at::StrongTypePtr(loader.cu_, cls), object->attrs()->size());
for (uint32_t i = 0; i < object->attrs()->size(); i++) {
IValue val = loader.getIValue(object->attrs()->Get(i));
obj->setSlot(i, std::move(val));
}
return static_cast<c10::IValue>(obj);
});

mobile::Module m = loader.parseModule(flatbuffer_module);
m.set_delete_memory(std::move(data));
map = mobile_module_to_parameter_map(m);
#else // !defined(ENABLE_FLATBUFFER)
TORCH_CHECK(
false,
"Flatbuffer input file but the build hasn't enabled flatbuffer");
#endif // !defined(ENABLE_FLATBUFFER)
if (load_flatbuffer_bytes_no_object != nullptr) {
auto m = load_flatbuffer_bytes_no_object(data, size, device);
map = mobile_module_to_parameter_map(m);
} else {
TORCH_CHECK(
false,
"Flatbuffer input file but the build hasn't enabled flatbuffer");
}
break;
}

case FileFormat::ZipFileFormat: {
std::unique_ptr<IStreamAdapter> rai =
std::make_unique<IStreamAdapter>(&in);
auto rai = std::make_unique<caffe2::serialize::MemoryReadAdapter>(
data.get(), size);
map = load_parameters_from_zip(std::move(rai), device);
break;
}

@@ -300,38 +271,22 @@ std::map<std::string, at::Tensor> _load_parameters(
return map;
}

std::map<std::string, at::Tensor> _load_parameters(
std::istream& in,
c10::optional<at::Device> device) {
std::shared_ptr<char> data;
size_t size = 0;
std::tie(data, size) = get_stream_content(in);
return _load_parameters_bytes(std::move(data), size, device);
}

std::map<std::string, at::Tensor> _load_parameters(
const std::string& filename,
c10::optional<at::Device> device) {
// Detect the file format from its header.
FileFormat format = getFileFormat(filename);

// Call the appropriate parser.
std::map<std::string, at::Tensor> map;
switch (format) {
case FileFormat::FlatbufferFileFormat: {
#if defined(ENABLE_FLATBUFFER)
mobile::Module module = load_mobile_module_from_file(filename, device);
map = mobile_module_to_parameter_map(module);
#else // !defined(ENABLE_FLATBUFFER)
TORCH_CHECK(
false,
"Flatbuffer input file but the build hasn't enabled flatbuffer");
#endif // !defined(ENABLE_FLATBUFFER)
break;
}

case FileFormat::ZipFileFormat: {
std::unique_ptr<FileAdapter> rai =
std::make_unique<FileAdapter>(filename);
map = load_parameters_from_zip(std::move(rai), device);
break;
}

default:
TORCH_CHECK(false, "Unrecognized data format");
}
return map;
std::shared_ptr<char> data;
size_t size = 0;
std::tie(data, size) = get_file_content(filename.c_str());
return _load_parameters_bytes(std::move(data), size, device);
}

} // namespace jit
@@ -9,7 +9,7 @@ enum MobileModuleLoadOptions {
OPERATOR_CHECK = 1,
};

const uint64_t _default_mobile_module_load_options =
const uint64_t kDefaultMobileLoadOptions =
MobileModuleLoadOptions::OPERATOR_CHECK;

namespace mobile {
@@ -11,11 +11,6 @@
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>

#if defined(ENABLE_FLATBUFFER)
#include <flatbuffers/flatbuffers.h>
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
#endif // defined(ENABLE_FLATBUFFER)

#include <string>
#include <vector>

@@ -119,36 +114,34 @@ mobile::Module tensor_dict_to_mobile(

} // namespace mobile

void (*_save_mobile_module_to)(
const mobile::Module& module,
const std::function<size_t(const void*, size_t)>& writer_func) = nullptr;

void _save_parameters(
const std::map<std::string, at::Tensor>& map,
std::ostream& out,
bool use_flatbuffer) {
auto dict = mobile::tensor_map_to_dict(map);

if (use_flatbuffer) {
#if defined(ENABLE_FLATBUFFER)
// For Flatbuffer, we serialize an entire, mostly-empty module containing
// the dict as an attribute.
flatbuffers::DetachedBuffer bytes = torch::jit::save_mobile_module_to_bytes(
mobile::tensor_dict_to_mobile(dict));
auto write_func = [&out](const void* buf, size_t nbytes) -> size_t {
out.write(
reinterpret_cast<char*>(bytes.data()),
static_cast<std::streamsize>(bytes.size()));
#else // !defined(ENABLE_FLATBUFFER)
TORCH_CHECK(
false,
"Trying to export as flatbuffer file but "
"the build hasn't enabled flatbuffer");
#endif // !defined(ENABLE_FLATBUFFER)
static_cast<const char*>(buf), static_cast<std::streamsize>(nbytes));
return !out ? 0 : nbytes;
};

if (use_flatbuffer) {
if (_save_mobile_module_to != nullptr) {
_save_mobile_module_to(mobile::tensor_dict_to_mobile(dict), write_func);
} else {
TORCH_CHECK(
false,
"Trying to export as flatbuffer file but "
"the build hasn't enabled flatbuffer");
}
} else {
// For Pickle, we only serialize the dict itself.
mobile::IValuePickler pickler(
[&](const void* buf, size_t nbytes) -> size_t {
out.write(
static_cast<const char*>(buf),
static_cast<std::streamsize>(nbytes));
return !out ? 0 : nbytes;
});
mobile::IValuePickler pickler(write_func);
pickler.serialize(dict);
}
}

@@ -159,23 +152,8 @@ void _save_parameters(
bool use_flatbuffer) {
auto dict = mobile::tensor_map_to_dict(map);

if (use_flatbuffer) {
#if defined(ENABLE_FLATBUFFER)
// For Flatbuffer, we serialize an entire, mostly-empty module containing
// the dict as an attribute.
torch::jit::save_mobile_module(
mobile::tensor_dict_to_mobile(dict), filename);
#else // !defined(ENABLE_FLATBUFFER)
TORCH_CHECK(
false,
"Trying to export as flatbuffer file but "
"the build hasn't enabled flatbuffer");
#endif // !defined(ENABLE_FLATBUFFER)
} else {
// For Pickle, we only serialize the dict itself.
mobile::IValuePickler pickler(filename);
pickler.serialize(dict);
}
std::ofstream ifile(filename);
_save_parameters(map, ifile, use_flatbuffer);
}

} // namespace jit
@@ -45,5 +45,9 @@ c10::Dict<std::string, at::Tensor> tensor_map_to_dict(

} // namespace mobile

extern void (*_save_mobile_module_to)(
const mobile::Module& module,
const std::function<size_t(const void*, size_t)>& writer_func);

} // namespace jit
} // namespace torch
@@ -262,5 +262,11 @@ bool getMobileInterfaceCallExport();

CompilationOptions getOptionsFromGlobal();

extern void (*_save_jit_module_to)(
const Module& module,
const ExtraFilesMap& extra_files,
bool save_mobile_debug_info,
const std::function<size_t(const void*, size_t)>& writer_func);

} // namespace jit
} // namespace torch
@@ -16,10 +16,6 @@
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/serialization/callstack_debug_info_serialization.h>
#include <torch/csrc/jit/serialization/export_bytecode.h>
#if defined(ENABLE_FLATBUFFER)
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
#include <torch/csrc/jit/serialization/flatbuffer_serializer_jit.h>
#endif
#include <torch/csrc/jit/serialization/import_export_constants.h>
#include <torch/csrc/jit/serialization/import_export_functions.h>
#include <torch/csrc/jit/serialization/import_export_helpers.h>

@@ -34,6 +30,8 @@

#include <ATen/core/jit_type.h>
#include <ATen/core/qualified_name.h>
#include <cerrno>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>

@@ -817,23 +815,6 @@ SerializationStorageContext& ScriptModuleSerializer::storage_context() {
return storage_context_;
}

#if defined(ENABLE_FLATBUFFER)
void save_mobile_module_to(
const Module& module,
const ExtraFilesMap& extra_files,
bool save_mobile_debug_info,
const std::function<size_t(const void*, size_t)>& writer_func) {
ExtraFilesMap jitFiles;
CompilationOptions options = getOptionsFromGlobal();
std::vector<IValue> constants;
jitModuleToPythonCodeAndConstants(module, &jitFiles, &constants);
mobile::Module mod = jitModuleToMobile(module, options);
auto buffer =
save_mobile_module_to_bytes(mod, extra_files, jitFiles, constants);
writer_func(reinterpret_cast<void*>(buffer.data()), buffer.size());
}
#endif

void ExportModule(
const Module& module,
std::ostream& out,

@@ -845,21 +826,13 @@ void ExportModule(
out.write(static_cast<const char*>(buf), nbytes);
return !out ? 0 : nbytes;
};
if (use_flatbuffer) {
#if defined(ENABLE_FLATBUFFER)
save_mobile_module_to(
module, extra_files, save_mobile_debug_info, writer_func);
#else
TORCH_CHECK(
false,
"Trying to export as flatbuffer file but the build hasn't enabled flatbuffer");
#endif
} else {
caffe2::serialize::PyTorchStreamWriter writer(writer_func);
ScriptModuleSerializer serializer(writer);
serializer.serialize(
module, extra_files, bytecode_format, save_mobile_debug_info);
}
ExportModule(
module,
writer_func,
extra_files,
bytecode_format,
save_mobile_debug_info,
use_flatbuffer);
}

void ExportModule(

@@ -869,29 +842,41 @@ void ExportModule(
bool bytecode_format,
bool save_mobile_debug_info,
bool use_flatbuffer) {
if (use_flatbuffer) {
#if defined(ENABLE_FLATBUFFER)
auto writer_func = [&](const void* buf, size_t nbytes) -> size_t {
std::fstream ofile(filename, std::ios::binary | std::ios::out);
ofile.write(static_cast<const char*>(buf), nbytes);
ofile.close();
return !ofile ? 0 : nbytes;
};
save_mobile_module_to(
module, extra_files, save_mobile_debug_info, writer_func);
#else
TORCH_CHECK(
false,
"Trying to export as flatbuffer file but the build hasn't enabled flatbuffer");
#endif
} else {
if (!use_flatbuffer) {
// the zip archive need to know the filepath
caffe2::serialize::PyTorchStreamWriter writer(filename);
ScriptModuleSerializer serializer(writer);
serializer.serialize(
module, extra_files, bytecode_format, save_mobile_debug_info);
return;
}
std::ofstream ofile;
ofile.open(filename, std::ios::binary | std::ios::out);
if (ofile.fail()) {
std::stringstream message;
if (errno == ENOENT) {
message << "Parent directory of " << filename << " does not exist.\n";
} else {
message << "Error while opening file: " << errno << std::endl;
;
}
TORCH_CHECK(false, message.str());
}
ExportModule(
module,
ofile,
extra_files,
bytecode_format,
save_mobile_debug_info,
use_flatbuffer);
}

void (*_save_jit_module_to)(
const Module& module,
const ExtraFilesMap& extra_files,
bool save_mobile_debug_info,
const std::function<size_t(const void*, size_t)>& writer_func) = nullptr;

void ExportModule(
const Module& module,
const std::function<size_t(const void*, size_t)>& writer_func,

@@ -900,14 +885,14 @@ void ExportModule(
bool save_mobile_debug_info,
bool use_flatbuffer) {
if (use_flatbuffer) {
#if defined(ENABLE_FLATBUFFER)
save_mobile_module_to(
module, extra_files, save_mobile_debug_info, writer_func);
#else
TORCH_CHECK(
false,
"Trying to export as flatbuffer file but the build hasn't enabled flatbuffer");
#endif
if (_save_jit_module_to != nullptr) {
_save_jit_module_to(
module, extra_files, save_mobile_debug_info, writer_func);
} else {
TORCH_CHECK(
false,
"Trying to export as flatbuffer file but the build hasn't enabled flatbuffer");
}
} else {
caffe2::serialize::PyTorchStreamWriter writer(writer_func);
ScriptModuleSerializer serializer(writer);
@ -2,10 +2,12 @@
|
|||
|
||||
#include <ATen/ATen.h>
|
||||
#include <c10/core/CPUAllocator.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <caffe2/serialize/versions.h>
|
||||
#include <flatbuffers/flatbuffers.h>
|
||||
#include <torch/csrc/jit/mobile/code.h>
|
||||
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
|
||||
#include <torch/csrc/jit/mobile/train/export_data.h>
|
||||
#include <torch/csrc/jit/passes/inliner.h>
|
||||
#include <torch/csrc/jit/runtime/instruction.h>
|
||||
#include <torch/csrc/jit/serialization/export.h>
|
||||
|
|
@@ -776,5 +778,19 @@ flatbuffers::DetachedBuffer save_mobile_module_to_bytes(
      jit_constants);
}

void save_mobile_module_to_func(
    const mobile::Module& module,
    const std::function<size_t(const void*, size_t)>& writer_func) {
  auto buffer = save_mobile_module_to_bytes(module);
  writer_func(buffer.data(), buffer.size());
}

bool register_flatbuffer_serializer() {
  _save_mobile_module_to = save_mobile_module_to_func;
  return true;
}

const bool kFlatbufferSerializerRegistered = register_flatbuffer_serializer();

} // namespace jit
} // namespace torch
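The kFlatbufferSerializerRegistered global above installs the save hook through static initialization, so simply linking this translation unit turns the feature on at runtime. A minimal, self-contained sketch of that pattern follows; every name in it (g_save_hook, real_save, kRegistered) is illustrative and not part of PyTorch.

// Sketch of the function-pointer hook + static-registration pattern.
#include <cstdio>

using SaveFn = void (*)(const char* payload);

// Hook owned by the "core" translation unit; stays null if the optional
// TU below is not linked into the binary.
SaveFn g_save_hook = nullptr;

// Optional TU provides the real implementation and installs it.
void real_save(const char* payload) { std::printf("saving: %s\n", payload); }

bool register_save() {
  g_save_hook = real_save;
  return true;
}
// Runs before main() whenever this TU is linked in.
const bool kRegistered = register_save();

int main() {
  if (g_save_hook != nullptr) {
    g_save_hook("model");  // feature available at runtime
  } else {
    std::puts("flatbuffer-style backend not linked");  // graceful fallback
  }
  return 0;
}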
@@ -29,5 +29,11 @@ TORCH_API flatbuffers::DetachedBuffer save_mobile_module_to_bytes(
    const ExtraFilesMap& jit_sources = ExtraFilesMap(),
    const std::vector<IValue>& jit_constants = {});

// This function makes the ability to load and save a Module as a flatbuffer
// file available to _load_for_mobile and friends. It is NOT needed when the
// other functions in this file are used directly.
TORCH_API bool register_flatbuffer_serializer();

} // namespace jit
} // namespace torch
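If a binary's link settings might discard the registering translation unit, a caller can also invoke the declared registration function directly. The snippet below is illustrative and assumes the flatbuffer serializer object file is linked in.

// Sketch: forcing installation of the flatbuffer save hook explicitly.
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>

int main() {
  // Install the hook instead of relying on the static initializer
  // surviving the linker.
  bool ok = torch::jit::register_flatbuffer_serializer();
  return ok ? 0 : 1;
}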
@@ -1,5 +1,6 @@
#include <torch/csrc/jit/serialization/flatbuffer_serializer_jit.h>

#include <torch/csrc/jit/mobile/file_format.h>
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
#include <torch/csrc/jit/operator_upgraders/upgraders_entry.h>
#include <torch/csrc/jit/serialization/export.h>
@@ -73,5 +74,25 @@ flatbuffers::DetachedBuffer save_jit_module_to_bytes(
  return save_mobile_module_to_bytes(mobilem, extra_files, jitfiles, constants);
}

void save_jit_module_to_write_func(
    const Module& module,
    const ExtraFilesMap& extra_files,
    bool save_mobile_debug_info,
    const std::function<size_t(const void*, size_t)>& writer_func) {
  (void)save_mobile_debug_info;
  auto buffer = save_jit_module_to_bytes(module, extra_files);
  writer_func(reinterpret_cast<void*>(buffer.data()), buffer.size());
}

bool register_flatbuffer_all() {
  (void)register_flatbuffer_loader();
  (void)register_flatbuffer_serializer();
  _save_jit_module_to = save_jit_module_to_write_func;
  _load_jit_module_from_flatbuffer_bytes = parse_and_initialize_jit_module;
  return true;
}

const bool kFlatbufferSerializerJitInitialized = register_flatbuffer_all();

} // namespace jit
} // namespace torch
@@ -31,5 +31,11 @@ TORCH_API Module load_jit_module_from_stream(
    ExtraFilesMap& extra_files,
    c10::optional<at::Device> device = c10::nullopt);

// This function makes the ability to load and save a Module as a flatbuffer
// file available to _load_for_mobile and friends. It is NOT needed when the
// other functions in this file are used directly.
TORCH_API bool register_flatbuffer_all();

} // namespace jit
} // namespace torch
@@ -1,4 +1,17 @@
#include <ATen/core/interned_strings.h>
#include <c10/core/CPUAllocator.h>
#include <c10/core/impl/alloc_cpu.h>
#include <caffe2/serialize/file_adapter.h>
#include <caffe2/serialize/in_memory_adapter.h>
#include <caffe2/serialize/inline_container.h>
#include <caffe2/serialize/istream_adapter.h>
#include <caffe2/serialize/read_adapter_interface.h>
#include <caffe2/serialize/versions.h>

#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/mobile/file_format.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/csrc/jit/serialization/source_range_serialization.h>

#include <ATen/core/functional.h>
#include <ATen/core/ivalue_inl.h>
@@ -19,16 +32,6 @@
#include <torch/csrc/jit/serialization/source_range_serialization.h>
#include <torch/csrc/jit/serialization/unpickler.h>

#if defined(ENABLE_FLATBUFFER)
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
#include <torch/csrc/jit/serialization/flatbuffer_serializer_jit.h>
#endif

#include <caffe2/serialize/file_adapter.h>
#include <caffe2/serialize/inline_container.h>
#include <caffe2/serialize/istream_adapter.h>
#include <caffe2/serialize/versions.h>

#include <ATen/ATen.h>
#include <fmt/format.h>
@@ -42,6 +45,7 @@ namespace jit {

using caffe2::serialize::FileAdapter;
using caffe2::serialize::IStreamAdapter;
using caffe2::serialize::MemoryReadAdapter;
using caffe2::serialize::PyTorchStreamReader;
using caffe2::serialize::ReadAdapterInterface;
@@ -291,31 +295,29 @@ Module import_ir_module(
  return import_ir_module(std::move(cu), in, device, extra_files);
}

Module (*_load_jit_module_from_flatbuffer_bytes)(
    std::shared_ptr<char>,
    size_t,
    ExtraFilesMap&,
    c10::optional<at::Device>) = nullptr;

static Module _load_jit_module_from_bytes(
    std::shared_ptr<char> data,
    size_t size,
    std::shared_ptr<CompilationUnit> cu,
    c10::optional<c10::Device> device,
    ExtraFilesMap& extra_files);

Module import_ir_module(
    std::shared_ptr<CompilationUnit> cu,
    std::istream& in,
    c10::optional<at::Device> device,
    ExtraFilesMap& extra_files) {
  in.seekg(0, in.beg);
  auto format = getFileFormat(in);
  switch (format) {
    case FileFormat::FlatbufferFileFormat: {
#if defined(ENABLE_FLATBUFFER)
      return load_jit_module_from_stream(in, extra_files, device);
#else
      TORCH_CHECK(
          false, "Flatbuffer input file but the build hasn't enable flatbuffer")
#endif
    }
    case FileFormat::ZipFileFormat: {
      auto reader = torch::make_unique<PyTorchStreamReader>(&in);
      ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
      return deserializer.deserialize(device, extra_files);
    }

    default:
      TORCH_CHECK(false, "Unrecognized data format");
  }
  std::shared_ptr<char> data;
  size_t size = 0;
  std::tie(data, size) = get_stream_content(in);
  return _load_jit_module_from_bytes(data, size, cu, device, extra_files);
}

// For reading unified serialization format from torch.Package.
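The rewritten overload above reads the whole stream once (get_stream_content) and defers to _load_jit_module_from_bytes, which dispatches on the detected format at runtime. A self-contained sketch of that shape is below; get_bytes, looks_like_flatbuffer, and the "PTMF" identifier check are assumptions for illustration, not the real getFileFormat implementation.

// Sketch: read bytes once, detect the container format, then dispatch.
#include <cstring>
#include <memory>
#include <string>
#include <utility>

std::pair<std::shared_ptr<char>, size_t> get_bytes(const std::string& s) {
  std::shared_ptr<char> buf(new char[s.size()], std::default_delete<char[]>());
  std::memcpy(buf.get(), s.data(), s.size());
  return {buf, s.size()};
}

bool looks_like_flatbuffer(const char* data, size_t size) {
  // Flatbuffer mobile files are assumed here to carry a "PTMF" identifier at
  // offset 4; this check is illustrative, not the real detector.
  return size >= 8 && std::memcmp(data + 4, "PTMF", 4) == 0;
}

int load_from_bytes(std::shared_ptr<char> data, size_t size) {
  if (looks_like_flatbuffer(data.get(), size)) {
    return 1;  // would call the flatbuffer loader hook if registered
  }
  return 2;    // would fall back to the zip/pickle deserializer
}

int main() {
  auto [data, size] = get_bytes(std::string("PK\x03\x04 fake zip header"));
  return load_from_bytes(std::move(data), size) == 2 ? 0 : 1;
}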
@@ -348,25 +350,10 @@ Module import_ir_module(
    const std::string& filename,
    c10::optional<at::Device> device,
    ExtraFilesMap& extra_files) {
  auto format = getFileFormat(filename);
  switch (format) {
    case FileFormat::FlatbufferFileFormat: {
#if defined(ENABLE_FLATBUFFER)
      return load_jit_module_from_file(filename, extra_files, device);
#else
      TORCH_CHECK(
          false, "Flatbuffer input file but the build hasn't enable flatbuffer")
#endif
    }
    case FileFormat::ZipFileFormat: {
      auto reader = torch::make_unique<PyTorchStreamReader>(filename);
      ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
      return deserializer.deserialize(device, extra_files);
    }

    default:
      TORCH_CHECK(false, "Unrecognized data format");
  }
  std::shared_ptr<char> data;
  size_t size = 0;
  std::tie(data, size) = get_file_content(filename.c_str());
  return _load_jit_module_from_bytes(data, size, cu, device, extra_files);
}

Module import_ir_module(
@@ -382,100 +369,92 @@ Module import_ir_module(
    std::unique_ptr<ReadAdapterInterface> rai,
    c10::optional<at::Device> device,
    ExtraFilesMap& extra_files) {
  auto reader = torch::make_unique<PyTorchStreamReader>(std::move(rai));
  ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
  return deserializer.deserialize(device, extra_files);
  std::shared_ptr<ReadAdapterInterface> rai_shared = std::move(rai);
  return import_ir_module(cu, rai_shared, device, extra_files);
}

Module import_ir_module(
    std::shared_ptr<CompilationUnit> cu,
    std::shared_ptr<ReadAdapterInterface> rai,
    c10::optional<at::Device> device,
    ExtraFilesMap& extra_files) {
  std::shared_ptr<char> data;
  size_t size = 0;
  std::tie(data, size) = get_rai_content(rai.get());
  return _load_jit_module_from_bytes(data, size, cu, device, extra_files);
}

Module load(std::istream& in, c10::optional<at::Device> device) {
  ExtraFilesMap extra_files;
  return load(in, device, extra_files);
  auto cu = std::make_shared<CompilationUnit>();
  return import_ir_module(std::move(cu), in, device);
}

Module load(
    std::istream& in,
    c10::optional<at::Device> device,
    ExtraFilesMap& extra_files) {
  in.seekg(0, in.beg);
  auto format = getFileFormat(in);
  switch (format) {
    case FileFormat::FlatbufferFileFormat: {
#if defined(ENABLE_FLATBUFFER)
      return load_jit_module_from_stream(in, extra_files, device);
#else
      TORCH_CHECK(
          false, "Flatbuffer input file but the build hasn't enable flatbuffer")
#endif
    }
    case FileFormat::ZipFileFormat: {
      std::unique_ptr<IStreamAdapter> rai =
          std::make_unique<IStreamAdapter>(&in);
      auto module = load(std::move(rai), device, extra_files);
      return module;
    }

    default:
      TORCH_CHECK(false, "Unrecognized data format");
  }
  auto cu = std::make_shared<CompilationUnit>();
  return import_ir_module(std::move(cu), in, device, extra_files);
}

Module load(const std::string& filename, c10::optional<at::Device> device) {
  ExtraFilesMap extra_files;
  return load(filename, device, extra_files);
  auto cu = std::make_shared<CompilationUnit>();
  return import_ir_module(std::move(cu), filename, device);
}

Module load(
    const std::string& filename,
    c10::optional<at::Device> device,
    ExtraFilesMap& extra_files) {
  auto format = getFileFormat(filename);
  switch (format) {
    case FileFormat::FlatbufferFileFormat: {
#if defined(ENABLE_FLATBUFFER)
      return load_jit_module_from_file(filename, extra_files, device);
#else
      TORCH_CHECK(
          false, "Flatbuffer input file but the build hasn't enable flatbuffer")
#endif

    case FileFormat::ZipFileFormat: {
      std::unique_ptr<FileAdapter> rai =
          std::make_unique<FileAdapter>(filename);
      auto module = load(std::move(rai), device, extra_files);
      return module;
    }

    default:
      TORCH_CHECK(false, "Unrecognized data format");
  }
  }
  auto cu = std::make_shared<CompilationUnit>();
  return import_ir_module(std::move(cu), filename, device, extra_files);
}

Module load(
    std::shared_ptr<ReadAdapterInterface> rai,
    c10::optional<c10::Device> device) {
  auto cu = std::make_shared<CompilationUnit>();
  ExtraFilesMap extra_files;
  return load(std::move(rai), device, extra_files);
  return import_ir_module(std::move(cu), std::move(rai), device, extra_files);
}

Module load(
    std::shared_ptr<ReadAdapterInterface> rai,
    c10::optional<c10::Device> device,
    ExtraFilesMap& extra_files) {
  // Verify that we're loading a zip archive and not a torch.save pickle
  // archive (marked by the 0x80 0x02 bytes at the start)
  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  TORCH_CHECK(
      check_zip_file(rai),
      "`torch::jit::load()` received a file from `torch.save()`, "
      "but `torch::jit::load()` can only load files"
      " produced by `torch.jit.save()`");

  auto reader = std::make_shared<PyTorchStreamReader>(std::move(rai));
  auto cu = std::make_shared<CompilationUnit>();
  return import_ir_module(std::move(cu), std::move(rai), device, extra_files);
}

  ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
  return deserializer.deserialize(device, extra_files);
Module _load_jit_module_from_bytes(
    std::shared_ptr<char> data,
    size_t size,
    std::shared_ptr<CompilationUnit> cu,
    c10::optional<c10::Device> device,
    ExtraFilesMap& extra_files) {
  TORCH_CHECK(size >= kFileFormatHeaderSize, "Unrecorgnized data format");
  auto format = getFileFormat(data.get());
  switch (format) {
    case FileFormat::FlatbufferFileFormat: {
      if (_load_jit_module_from_flatbuffer_bytes != nullptr) {
        return _load_jit_module_from_flatbuffer_bytes(
            data, size, extra_files, device);
      } else {
        TORCH_CHECK(
            false,
            "Flatbuffer input file but the build hasn't enable flatbuffer")
      }
    }
    case FileFormat::ZipFileFormat: {
      auto rai = std::make_unique<MemoryReadAdapter>(data.get(), size);
      auto reader = torch::make_unique<PyTorchStreamReader>(std::move(rai));
      ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
      return deserializer.deserialize(device, extra_files);
    }

    default:
      TORCH_CHECK(false, "Unrecognized data format");
  }
}

// Replace object with a newly created but equivalent object.
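With detection moved into _load_jit_module_from_bytes, the public entry point stays the same; the usage sketch below assumes a model file path and simply shows that one torch::jit::load() call covers both container formats, while the flatbuffer branch still requires the loader hook to be registered at link time.

// Sketch: a single load call for either a zip or a flatbuffer file.
// "model.pt" is an illustrative path.
#include <torch/script.h>
#include <iostream>

int main() {
  try {
    torch::jit::Module m = torch::jit::load("model.pt");  // format detected at runtime
    std::cout << "loaded\n";
  } catch (const c10::Error& e) {
    // e.g. a flatbuffer file when the flatbuffer loader was not linked in
    std::cerr << e.what() << "\n";
  }
  return 0;
}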
@@ -58,6 +58,12 @@ TORCH_API Module import_ir_module(
    c10::optional<c10::Device> device,
    ExtraFilesMap& extra_files);

TORCH_API Module import_ir_module(
    std::shared_ptr<CompilationUnit> cu,
    std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai,
    c10::optional<c10::Device> device,
    ExtraFilesMap& extra_files);

/// Loads a serialized `Module` from the given `istream`.
///
/// The istream must contain a serialized `Module`, exported via
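The new shared_ptr<ReadAdapterInterface> overload declared above pairs naturally with MemoryReadAdapter for modules that are already resident in memory. A sketch follows, assuming the buffer holds a valid serialized module; load_from_buffer is an illustrative helper, not a PyTorch API.

// Sketch: loading a module from an in-memory buffer via the new overload.
#include <caffe2/serialize/in_memory_adapter.h>
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/serialization/import.h>

torch::jit::Module load_from_buffer(const void* data, size_t size) {
  auto cu = std::make_shared<torch::jit::CompilationUnit>();
  auto rai =
      std::make_shared<caffe2::serialize::MemoryReadAdapter>(data, size);
  torch::jit::ExtraFilesMap extra_files;
  // Device is left unset; extra files are returned through extra_files.
  return torch::jit::import_ir_module(
      std::move(cu), std::move(rai), c10::nullopt, extra_files);
}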
@@ -104,5 +110,19 @@ TORCH_API Module jitModuleFromSourceAndConstants(
    const std::vector<IValue>& constants,
    int32_t version);

extern Module (*_load_jit_module_from_flatbuffer_bytes)(
    // comp unit
    std::shared_ptr<char>,
    size_t,
    ExtraFilesMap&,
    c10::optional<at::Device>);

} // namespace jit
} // namespace torch