mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Add ability for a mobile::Module to save as flatbuffer (#67351)
Summary:
Included functions:
* save_mobile_module -> saves a mobile::Module to flatbuffer
* load_mobile_module_from_file -> loads a flatbuffer into mobile::Module
* parse_mobile_module -> parses from bytes or deserialized flatbuffer
Module object
Fixes #{issue number}
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67351
Reviewed By: iseeyuan
Differential Revision: D32010095
Pulled By: qihqi
fbshipit-source-id: d763b0557780f7c2661b6485105b045e41a5e8f1
This commit is contained in:
parent
40fb28ea87
commit
41d35dc201
3
.gitignore
vendored
3
.gitignore
vendored
|
|
@ -311,3 +311,6 @@ pr.diff
|
||||||
|
|
||||||
# coverage files
|
# coverage files
|
||||||
*/**/.coverage.*
|
*/**/.coverage.*
|
||||||
|
|
||||||
|
# generated flatbuffer schema header
|
||||||
|
torch/csrc/jit/serialization/mobile_bytecode_generated.h
|
||||||
|
|
|
||||||
3
.gitmodules
vendored
3
.gitmodules
vendored
|
|
@ -142,3 +142,6 @@
|
||||||
[submodule "third_party/breakpad"]
|
[submodule "third_party/breakpad"]
|
||||||
path = third_party/breakpad
|
path = third_party/breakpad
|
||||||
url = https://github.com/driazati/breakpad.git
|
url = https://github.com/driazati/breakpad.git
|
||||||
|
[submodule "third_party/flatbuffers"]
|
||||||
|
path = third_party/flatbuffers
|
||||||
|
url = https://github.com/google/flatbuffers.git
|
||||||
|
|
|
||||||
|
|
@ -56,6 +56,8 @@ if [[ $PYLONG_API_CHECK == 0 ]]; then
|
||||||
fi
|
fi
|
||||||
set -ex
|
set -ex
|
||||||
|
|
||||||
|
echo 'python %*' > /c/Windows/py.bat
|
||||||
|
|
||||||
"$SCRIPT_HELPERS_DIR"/build_pytorch.bat
|
"$SCRIPT_HELPERS_DIR"/build_pytorch.bat
|
||||||
|
|
||||||
assert_git_not_dirty
|
assert_git_not_dirty
|
||||||
|
|
|
||||||
38
BUILD.bazel
38
BUILD.bazel
|
|
@ -7,6 +7,7 @@ load("//:tools/build_variables.bzl", "torch_cpp_srcs", "libtorch_python_core_sou
|
||||||
load("//tools/rules:cu.bzl", "cu_library")
|
load("//tools/rules:cu.bzl", "cu_library")
|
||||||
load("//tools/config:defs.bzl", "if_cuda")
|
load("//tools/config:defs.bzl", "if_cuda")
|
||||||
load("//:aten.bzl", "intern_build_aten_ops")
|
load("//:aten.bzl", "intern_build_aten_ops")
|
||||||
|
load("@com_github_google_flatbuffers//:build_defs.bzl", "flatbuffer_cc_library")
|
||||||
|
|
||||||
COMMON_COPTS = [
|
COMMON_COPTS = [
|
||||||
"-DHAVE_MALLOC_USABLE_SIZE=1",
|
"-DHAVE_MALLOC_USABLE_SIZE=1",
|
||||||
|
|
@ -1833,6 +1834,14 @@ genrule(
|
||||||
tools = [':gen_version_header']
|
tools = [':gen_version_header']
|
||||||
)
|
)
|
||||||
|
|
||||||
|
flatbuffer_cc_library(
|
||||||
|
name = "mobile_bytecode_header",
|
||||||
|
srcs = ["torch/csrc/jit/serialization/mobile_bytecode.fbs"],
|
||||||
|
out_prefix = "torch/csrc/jit/serialization/",
|
||||||
|
flatc_args=["--gen-mutable", "--scoped-enums",],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
torch_cuda_headers = glob(["torch/csrc/cuda/*.h"])
|
torch_cuda_headers = glob(["torch/csrc/cuda/*.h"])
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "torch_headers",
|
name = "torch_headers",
|
||||||
|
|
@ -1864,6 +1873,7 @@ cc_library(
|
||||||
":aten_headers",
|
":aten_headers",
|
||||||
":c10_headers",
|
":c10_headers",
|
||||||
":caffe2_headers",
|
":caffe2_headers",
|
||||||
|
":mobile_bytecode_header",
|
||||||
"@local_config_python//:python_headers",
|
"@local_config_python//:python_headers",
|
||||||
"@onnx",
|
"@onnx",
|
||||||
],
|
],
|
||||||
|
|
@ -1906,6 +1916,32 @@ cc_library(
|
||||||
alwayslink = True,
|
alwayslink = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "flatbuffer_loader",
|
||||||
|
srcs = [
|
||||||
|
"torch/csrc/jit/mobile/flatbuffer_loader.cpp"
|
||||||
|
],
|
||||||
|
hdrs = [
|
||||||
|
"torch/csrc/jit/mobile/flatbuffer_loader.h"
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":torch"
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "flatbuffer_serializer",
|
||||||
|
srcs = [
|
||||||
|
"torch/csrc/jit/serialization/flatbuffer_serializer.cpp"
|
||||||
|
],
|
||||||
|
hdrs = [
|
||||||
|
"torch/csrc/jit/serialization/flatbuffer_serializer.h"
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":torch"
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "shm",
|
name = "shm",
|
||||||
srcs = glob(["torch/lib/libshm/*.cpp"]),
|
srcs = glob(["torch/lib/libshm/*.cpp"]),
|
||||||
|
|
@ -2056,6 +2092,8 @@ cc_test(
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":torch",
|
":torch",
|
||||||
|
":flatbuffer_serializer",
|
||||||
|
":flatbuffer_loader",
|
||||||
"@com_google_googletest//:gtest_main",
|
"@com_google_googletest//:gtest_main",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -770,6 +770,7 @@ if(NOT MSVC)
|
||||||
string(APPEND CMAKE_CXX_FLAGS " -Wno-error=deprecated-declarations")
|
string(APPEND CMAKE_CXX_FLAGS " -Wno-error=deprecated-declarations")
|
||||||
if(CMAKE_COMPILER_IS_GNUCXX AND NOT (CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.0.0))
|
if(CMAKE_COMPILER_IS_GNUCXX AND NOT (CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.0.0))
|
||||||
string(APPEND CMAKE_CXX_FLAGS " -Wno-stringop-overflow")
|
string(APPEND CMAKE_CXX_FLAGS " -Wno-stringop-overflow")
|
||||||
|
string(APPEND CMAKE_CXX_FLAGS " -Wno-noexcept-type")
|
||||||
endif()
|
endif()
|
||||||
if(CMAKE_COMPILER_IS_GNUCXX)
|
if(CMAKE_COMPILER_IS_GNUCXX)
|
||||||
# Suppress "The ABI for passing parameters with 64-byte alignment has changed in GCC 4.6"
|
# Suppress "The ABI for passing parameters with 64-byte alignment has changed in GCC 4.6"
|
||||||
|
|
|
||||||
|
|
@ -181,3 +181,8 @@ new_empty_repository(
|
||||||
name = "cuda",
|
name = "cuda",
|
||||||
build_file = "//third_party:cuda.BUILD",
|
build_file = "//third_party:cuda.BUILD",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
local_repository(
|
||||||
|
name = "com_github_google_flatbuffers",
|
||||||
|
path = "third_party/flatbuffers",
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -557,6 +557,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
|
||||||
${TORCH_SRC_DIR}/csrc/jit/mobile/interpreter.cpp
|
${TORCH_SRC_DIR}/csrc/jit/mobile/interpreter.cpp
|
||||||
${TORCH_SRC_DIR}/csrc/jit/mobile/model_compatibility.cpp
|
${TORCH_SRC_DIR}/csrc/jit/mobile/model_compatibility.cpp
|
||||||
${TORCH_SRC_DIR}/csrc/jit/mobile/module.cpp
|
${TORCH_SRC_DIR}/csrc/jit/mobile/module.cpp
|
||||||
|
${TORCH_SRC_DIR}/csrc/jit/mobile/flatbuffer_loader.cpp
|
||||||
${TORCH_SRC_DIR}/csrc/jit/mobile/observer.cpp
|
${TORCH_SRC_DIR}/csrc/jit/mobile/observer.cpp
|
||||||
${TORCH_SRC_DIR}/csrc/jit/mobile/parse_bytecode.cpp
|
${TORCH_SRC_DIR}/csrc/jit/mobile/parse_bytecode.cpp
|
||||||
${TORCH_SRC_DIR}/csrc/jit/mobile/parse_operators.cpp
|
${TORCH_SRC_DIR}/csrc/jit/mobile/parse_operators.cpp
|
||||||
|
|
@ -591,8 +592,10 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
|
||||||
${TORCH_SRC_DIR}/csrc/jit/serialization/export.cpp
|
${TORCH_SRC_DIR}/csrc/jit/serialization/export.cpp
|
||||||
${TORCH_SRC_DIR}/csrc/jit/serialization/export_bytecode.cpp
|
${TORCH_SRC_DIR}/csrc/jit/serialization/export_bytecode.cpp
|
||||||
${TORCH_SRC_DIR}/csrc/jit/serialization/export_module.cpp
|
${TORCH_SRC_DIR}/csrc/jit/serialization/export_module.cpp
|
||||||
|
${TORCH_SRC_DIR}/csrc/jit/serialization/flatbuffer_serializer.cpp
|
||||||
${TORCH_SRC_DIR}/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp
|
${TORCH_SRC_DIR}/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp
|
||||||
${TORCH_SRC_DIR}/csrc/jit/api/module_save.cpp
|
${TORCH_SRC_DIR}/csrc/jit/api/module_save.cpp
|
||||||
|
${TORCH_SRC_DIR}/csrc/jit/testing/module_differ.cpp
|
||||||
${TORCH_SRC_DIR}/csrc/utils/byte_order.cpp
|
${TORCH_SRC_DIR}/csrc/utils/byte_order.cpp
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -1640,6 +1643,32 @@ if(APPLE AND USE_PYTORCH_METAL)
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
|
||||||
|
set(schema ${TORCH_SRC_DIR}/csrc/jit/serialization/mobile_bytecode.fbs)
|
||||||
|
set(generated_include
|
||||||
|
"${TORCH_ROOT}/build/torch/csrc/jit/serialization/mobile_bytecode_generated.h")
|
||||||
|
## cann add--reflect-names
|
||||||
|
add_custom_command(
|
||||||
|
OUTPUT ${generated_include}
|
||||||
|
COMMAND bash ${TORCH_ROOT}/scripts/gen_flatbuffer.sh
|
||||||
|
DEPENDS ${schema}
|
||||||
|
WORKING_DIRECTORY "${TORCH_ROOT}"
|
||||||
|
COMMENT "Generating mobile_bytecode_generated.h"
|
||||||
|
)
|
||||||
|
add_library(mobile_bytecode_generated_h INTERFACE)
|
||||||
|
target_sources(
|
||||||
|
mobile_bytecode_generated_h
|
||||||
|
INTERFACE ${generated_include}
|
||||||
|
)
|
||||||
|
add_dependencies(mobile_bytecode_generated_h flatc ${generated_include})
|
||||||
|
target_include_directories(
|
||||||
|
mobile_bytecode_generated_h
|
||||||
|
INTERFACE ${TORCH_ROOT}/build)
|
||||||
|
|
||||||
|
add_dependencies(torch_cpu mobile_bytecode_generated_h)
|
||||||
|
target_link_libraries(
|
||||||
|
torch_cpu PRIVATE mobile_bytecode_generated_h flatbuffers)
|
||||||
|
|
||||||
# Note [Global dependencies]
|
# Note [Global dependencies]
|
||||||
# Some libraries (e.g. OpenMPI) like to dlopen plugins after they're initialized,
|
# Some libraries (e.g. OpenMPI) like to dlopen plugins after they're initialized,
|
||||||
# and they assume that all of their symbols will be available in the global namespace.
|
# and they assume that all of their symbols will be available in the global namespace.
|
||||||
|
|
|
||||||
|
|
@ -1995,3 +1995,6 @@ if(USE_KINETO)
|
||||||
message(STATUS "Configured Kineto")
|
message(STATUS "Configured Kineto")
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
# Include google/FlatBuffers
|
||||||
|
include(${CMAKE_CURRENT_LIST_DIR}/FlatBuffers.cmake)
|
||||||
|
|
|
||||||
2
cmake/FlatBuffers.cmake
Normal file
2
cmake/FlatBuffers.cmake
Normal file
|
|
@ -0,0 +1,2 @@
|
||||||
|
set(CMAKE_BUILD_WITH_INSTALL_RPATH ON CACHE BOOL "" FORCE)
|
||||||
|
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/../third_party/flatbuffers ${CMAKE_CURRENT_BINARY_DIR}/flatbuffers-build EXCLUDE_FROM_ALL)
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
# Python dependencies required for development
|
# Python dependencies required for development
|
||||||
astunparse
|
astunparse
|
||||||
expecttest
|
expecttest
|
||||||
|
flatbuffers
|
||||||
future
|
future
|
||||||
numpy
|
numpy
|
||||||
psutil
|
psutil
|
||||||
|
|
|
||||||
15
scripts/gen_flatbuffer.sh
Executable file
15
scripts/gen_flatbuffer.sh
Executable file
|
|
@ -0,0 +1,15 @@
|
||||||
|
#!/bin/bash
|
||||||
|
ROOT=$(pwd)
|
||||||
|
FF_LOCATION="$ROOT/third_party/flatbuffers"
|
||||||
|
cd "$FF_LOCATION" || exit
|
||||||
|
mkdir build
|
||||||
|
cd build || exit
|
||||||
|
py() { command python "$@"; }
|
||||||
|
cmake ..
|
||||||
|
cmake --build . --target flatc
|
||||||
|
mkdir -p "$ROOT/build/torch/csrc/jit/serialization"
|
||||||
|
./flatc --cpp --gen-mutable --scoped-enums \
|
||||||
|
-o "$ROOT/build/torch/csrc/jit/serialization" \
|
||||||
|
-c "$ROOT/torch/csrc/jit/serialization/mobile_bytecode.fbs"
|
||||||
|
cd "$ROOT" || exit
|
||||||
|
exit
|
||||||
|
|
@ -97,6 +97,9 @@ add_executable(test_jit
|
||||||
${TORCH_ROOT}/test/cpp/common/main.cpp
|
${TORCH_ROOT}/test/cpp/common/main.cpp
|
||||||
${JIT_TEST_SRCS}
|
${JIT_TEST_SRCS}
|
||||||
)
|
)
|
||||||
|
add_dependencies(test_jit flatbuffers)
|
||||||
|
target_link_libraries(test_jit PRIVATE flatbuffers)
|
||||||
|
|
||||||
|
|
||||||
# TODO temporary until we can delete the old gtest polyfills.
|
# TODO temporary until we can delete the old gtest polyfills.
|
||||||
target_compile_definitions(test_jit PRIVATE USE_GTEST)
|
target_compile_definitions(test_jit PRIVATE USE_GTEST)
|
||||||
|
|
|
||||||
|
|
@ -2,13 +2,21 @@
|
||||||
#include <test/cpp/jit/test_utils.h>
|
#include <test/cpp/jit/test_utils.h>
|
||||||
#include <torch/csrc/jit/api/module.h>
|
#include <torch/csrc/jit/api/module.h>
|
||||||
#include <torch/csrc/jit/backends/backend_detail.h>
|
#include <torch/csrc/jit/backends/backend_detail.h>
|
||||||
|
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
|
||||||
#include <torch/csrc/jit/mobile/import.h>
|
#include <torch/csrc/jit/mobile/import.h>
|
||||||
|
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
|
||||||
#include <torch/csrc/jit/serialization/import.h>
|
#include <torch/csrc/jit/serialization/import.h>
|
||||||
#include <torch/torch.h>
|
#include <torch/torch.h>
|
||||||
|
|
||||||
// Tests go in torch::jit
|
// Tests go in torch::jit
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
|
|
||||||
|
mobile::Module load_mobile_module(void* data, size_t) {
|
||||||
|
auto* flatbuffer_module = mobile::serialization::GetMutableModule(data);
|
||||||
|
return initialize_mobile_module(flatbuffer_module);
|
||||||
|
}
|
||||||
|
|
||||||
TEST(BackendTest, ToBackend) {
|
TEST(BackendTest, ToBackend) {
|
||||||
Module m("m");
|
Module m("m");
|
||||||
m.define(R"(
|
m.define(R"(
|
||||||
|
|
@ -141,6 +149,11 @@ TEST(BackendTest, TestCompiler) {
|
||||||
auto mlm = _load_for_mobile(ss);
|
auto mlm = _load_for_mobile(ss);
|
||||||
auto mres = mlm.forward(inputs);
|
auto mres = mlm.forward(inputs);
|
||||||
AT_ASSERT(mres.toTensor().equal(ref.toTensor()));
|
AT_ASSERT(mres.toTensor().equal(ref.toTensor()));
|
||||||
|
|
||||||
|
auto buff = save_mobile_module_to_bytes(mlm);
|
||||||
|
mobile::Module mlm2 = load_mobile_module(buff.data(), buff.size());
|
||||||
|
auto mres2 = mlm2.forward(inputs);
|
||||||
|
AT_ASSERT(mres2.toTensor().equal(ref.toTensor()));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(BackendTest, TestComposite) {
|
TEST(BackendTest, TestComposite) {
|
||||||
|
|
@ -183,8 +196,12 @@ TEST(BackendTest, TestComposite) {
|
||||||
c._save_for_mobile(ss);
|
c._save_for_mobile(ss);
|
||||||
auto mc = _load_for_mobile(ss);
|
auto mc = _load_for_mobile(ss);
|
||||||
auto res_mobile = mc.forward(inputs);
|
auto res_mobile = mc.forward(inputs);
|
||||||
|
|
||||||
AT_ASSERT(res_jit.toTensor().equal(res_mobile.toTensor()));
|
AT_ASSERT(res_jit.toTensor().equal(res_mobile.toTensor()));
|
||||||
|
|
||||||
|
auto buff = save_mobile_module_to_bytes(mc);
|
||||||
|
mobile::Module mlm2 = load_mobile_module(buff.data(), buff.size());
|
||||||
|
auto mres2 = mlm2.forward(inputs);
|
||||||
|
AT_ASSERT(mres2.toTensor().equal(res_jit.toTensor()));
|
||||||
}
|
}
|
||||||
|
|
||||||
Module getCompositeModuleWithSameNameSubModules() {
|
Module getCompositeModuleWithSameNameSubModules() {
|
||||||
|
|
@ -241,6 +258,11 @@ TEST(BackendTest, TestCompositeWithSetStates) {
|
||||||
auto mc = _load_for_mobile(ss);
|
auto mc = _load_for_mobile(ss);
|
||||||
auto res_mobile = mc.forward(inputs);
|
auto res_mobile = mc.forward(inputs);
|
||||||
AT_ASSERT(res_jit.toTensor().equal(res_mobile.toTensor()));
|
AT_ASSERT(res_jit.toTensor().equal(res_mobile.toTensor()));
|
||||||
|
|
||||||
|
auto buff = save_mobile_module_to_bytes(mc);
|
||||||
|
mobile::Module mlm2 = load_mobile_module(buff.data(), buff.size());
|
||||||
|
auto mres2 = mlm2.forward(inputs);
|
||||||
|
AT_ASSERT(mres2.toTensor().equal(res_jit.toTensor()));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(BackendTest, TestConsistencyOfCompositeWithSetStates) {
|
TEST(BackendTest, TestConsistencyOfCompositeWithSetStates) {
|
||||||
|
|
@ -256,6 +278,11 @@ TEST(BackendTest, TestConsistencyOfCompositeWithSetStates) {
|
||||||
auto mc = _load_for_mobile(ss);
|
auto mc = _load_for_mobile(ss);
|
||||||
auto res_mobile = mc.forward(inputs);
|
auto res_mobile = mc.forward(inputs);
|
||||||
|
|
||||||
|
auto buff = save_mobile_module_to_bytes(mc);
|
||||||
|
mobile::Module mlm2 = load_mobile_module(buff.data(), buff.size());
|
||||||
|
auto mres2 = mlm2.forward(inputs);
|
||||||
|
AT_ASSERT(mres2.toTensor().equal(res_mobile.toTensor()));
|
||||||
|
|
||||||
// check if the methods names are always the same
|
// check if the methods names are always the same
|
||||||
// by reloading the script module and saving it back as mobile
|
// by reloading the script module and saving it back as mobile
|
||||||
// The below checks ensure that the names of Methods
|
// The below checks ensure that the names of Methods
|
||||||
|
|
@ -354,6 +381,13 @@ Traceback of TorchScript (most recent call last):
|
||||||
~~~~~ <--- HERE
|
~~~~~ <--- HERE
|
||||||
)";
|
)";
|
||||||
ASSERT_THROWS_WITH_MESSAGE(mlm.forward(inputs), error_pattern);
|
ASSERT_THROWS_WITH_MESSAGE(mlm.forward(inputs), error_pattern);
|
||||||
|
|
||||||
|
/* TODO(add debug info to flatbuffer)
|
||||||
|
auto buff = save_mobile_module_to_bytes(mlm);
|
||||||
|
mobile::Module mlm2 = load_mobile_module(buff.data(), buff.size());
|
||||||
|
mlm2.forward(inputs);
|
||||||
|
ASSERT_THROWS_WITH_MESSAGE(mlm2.forward(inputs), error_pattern);
|
||||||
|
*/
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(BackendTestDebugInfo, TestExceptionStackForCompilerWithModuleHierarchy) {
|
TEST(BackendTestDebugInfo, TestExceptionStackForCompilerWithModuleHierarchy) {
|
||||||
|
|
@ -414,6 +448,12 @@ Traceback of TorchScript (most recent call last):
|
||||||
~~~~~ <--- HERE
|
~~~~~ <--- HERE
|
||||||
)";
|
)";
|
||||||
ASSERT_THROWS_WITH_MESSAGE(mlm.forward(inputs), error_pattern);
|
ASSERT_THROWS_WITH_MESSAGE(mlm.forward(inputs), error_pattern);
|
||||||
|
|
||||||
|
/* TODO(add debug info to flatbuffer)
|
||||||
|
auto buff = save_mobile_module_to_bytes(mlm);
|
||||||
|
mobile::Module mlm2 = load_mobile_module(buff.data(), buff.size());
|
||||||
|
ASSERT_THROWS_WITH_MESSAGE(mlm2.forward(inputs), error_pattern);
|
||||||
|
*/
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(
|
TEST(
|
||||||
|
|
@ -512,7 +552,13 @@ Traceback of TorchScript (most recent call last):
|
||||||
return x + y
|
return x + y
|
||||||
~~~~~ <--- HERE
|
~~~~~ <--- HERE
|
||||||
)";
|
)";
|
||||||
|
|
||||||
ASSERT_THROWS_WITH_MESSAGE(mlm.forward(inputs), error_pattern);
|
ASSERT_THROWS_WITH_MESSAGE(mlm.forward(inputs), error_pattern);
|
||||||
|
/* TODO(add debug info to flatbuffer)
|
||||||
|
auto buff = save_mobile_module_to_bytes(mlm);
|
||||||
|
mobile::Module mlm2 = load_mobile_module(buff.data(), buff.size());
|
||||||
|
ASSERT_THROWS_WITH_MESSAGE(mlm2.forward(inputs), error_pattern);
|
||||||
|
*/
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(BackendTestDebugInfo, TestExceptionStackForCompilerWithLoweredSubModule) {
|
TEST(BackendTestDebugInfo, TestExceptionStackForCompilerWithLoweredSubModule) {
|
||||||
|
|
@ -594,6 +640,11 @@ Traceback of TorchScript (most recent call last):
|
||||||
~~~~~ <--- HERE
|
~~~~~ <--- HERE
|
||||||
)";
|
)";
|
||||||
ASSERT_THROWS_WITH_MESSAGE(c_loaded.forward(inputs), error_pattern);
|
ASSERT_THROWS_WITH_MESSAGE(c_loaded.forward(inputs), error_pattern);
|
||||||
|
/* TODO(add debug info to flatbuffer)
|
||||||
|
auto buff = save_mobile_module_to_bytes(c_loaded);
|
||||||
|
mobile::Module mlm2 = load_mobile_module(buff.data(), buff.size());
|
||||||
|
ASSERT_THROWS_WITH_MESSAGE(mlm2.forward(inputs), error_pattern);
|
||||||
|
*/
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(
|
TEST(
|
||||||
|
|
@ -721,6 +772,11 @@ Traceback of TorchScript (most recent call last):
|
||||||
~~~~~ <--- HERE
|
~~~~~ <--- HERE
|
||||||
)";
|
)";
|
||||||
ASSERT_THROWS_WITH_MESSAGE(c_loaded.forward(inputs), error_pattern);
|
ASSERT_THROWS_WITH_MESSAGE(c_loaded.forward(inputs), error_pattern);
|
||||||
|
/* TODO(add debug info to flatbuffer)
|
||||||
|
auto buff = save_mobile_module_to_bytes(c_loaded);
|
||||||
|
mobile::Module mlm2 = load_mobile_module(buff.data(), buff.size());
|
||||||
|
ASSERT_THROWS_WITH_MESSAGE(mlm2.forward(inputs), error_pattern);
|
||||||
|
*/
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@
|
||||||
#include <torch/csrc/jit/frontend/resolver.h>
|
#include <torch/csrc/jit/frontend/resolver.h>
|
||||||
#include <torch/csrc/jit/mobile/backport.h>
|
#include <torch/csrc/jit/mobile/backport.h>
|
||||||
#include <torch/csrc/jit/mobile/backport_manager.h>
|
#include <torch/csrc/jit/mobile/backport_manager.h>
|
||||||
|
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
|
||||||
#include <torch/csrc/jit/mobile/import.h>
|
#include <torch/csrc/jit/mobile/import.h>
|
||||||
#include <torch/csrc/jit/mobile/interpreter.h>
|
#include <torch/csrc/jit/mobile/interpreter.h>
|
||||||
#include <torch/csrc/jit/mobile/model_compatibility.h>
|
#include <torch/csrc/jit/mobile/model_compatibility.h>
|
||||||
|
|
@ -16,6 +17,7 @@
|
||||||
#include <torch/csrc/jit/mobile/parse_operators.h>
|
#include <torch/csrc/jit/mobile/parse_operators.h>
|
||||||
#include <torch/csrc/jit/mobile/runtime_compatibility.h>
|
#include <torch/csrc/jit/mobile/runtime_compatibility.h>
|
||||||
#include <torch/csrc/jit/serialization/export.h>
|
#include <torch/csrc/jit/serialization/export.h>
|
||||||
|
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
|
||||||
#include <torch/csrc/jit/serialization/import.h>
|
#include <torch/csrc/jit/serialization/import.h>
|
||||||
#include <torch/custom_class.h>
|
#include <torch/custom_class.h>
|
||||||
#include <torch/torch.h>
|
#include <torch/torch.h>
|
||||||
|
|
@ -26,6 +28,11 @@
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
|
|
||||||
|
mobile::Module parse_mobile_module(void* data, size_t) {
|
||||||
|
auto* flatbuffer_module = mobile::serialization::GetMutableModule(data);
|
||||||
|
return initialize_mobile_module(flatbuffer_module);
|
||||||
|
}
|
||||||
|
|
||||||
TEST(LiteInterpreterTest, UpsampleNearest2d) {
|
TEST(LiteInterpreterTest, UpsampleNearest2d) {
|
||||||
Module m("m");
|
Module m("m");
|
||||||
m.define(R"(
|
m.define(R"(
|
||||||
|
|
@ -47,6 +54,12 @@ TEST(LiteInterpreterTest, UpsampleNearest2d) {
|
||||||
auto resd = res.toTensor();
|
auto resd = res.toTensor();
|
||||||
auto refd = ref.toTensor();
|
auto refd = ref.toTensor();
|
||||||
ASSERT_TRUE(resd.equal(refd));
|
ASSERT_TRUE(resd.equal(refd));
|
||||||
|
|
||||||
|
auto buff = save_mobile_module_to_bytes(bc);
|
||||||
|
mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size());
|
||||||
|
auto res2 = bc2.forward(inputs);
|
||||||
|
auto resd2 = res2.toTensor();
|
||||||
|
ASSERT_TRUE(resd2.equal(refd));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(LiteInterpreterTest, CheckAttrAccess) {
|
TEST(LiteInterpreterTest, CheckAttrAccess) {
|
||||||
|
|
@ -66,6 +79,11 @@ TEST(LiteInterpreterTest, CheckAttrAccess) {
|
||||||
mobile_optimized = bc.attr("mobile_optimized", false).toBool();
|
mobile_optimized = bc.attr("mobile_optimized", false).toBool();
|
||||||
|
|
||||||
AT_ASSERT(!mobile_optimized);
|
AT_ASSERT(!mobile_optimized);
|
||||||
|
|
||||||
|
auto buff = save_mobile_module_to_bytes(bc);
|
||||||
|
mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size());
|
||||||
|
auto mobile_optimized2 = bc2.attr("mobile_optimized", false).toBool();
|
||||||
|
AT_ASSERT(!mobile_optimized2);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(LiteInterpreterTest, MethodInvocation) { // NOLINT (use =delete in gtest)
|
TEST(LiteInterpreterTest, MethodInvocation) { // NOLINT (use =delete in gtest)
|
||||||
|
|
@ -110,6 +128,16 @@ TEST(LiteInterpreterTest, MethodInvocation) { // NOLINT (use =delete in gtest)
|
||||||
auto resd = res.toTensor().item<float>();
|
auto resd = res.toTensor().item<float>();
|
||||||
auto refd = ref.toTensor().item<float>();
|
auto refd = ref.toTensor().item<float>();
|
||||||
AT_ASSERT(resd == refd);
|
AT_ASSERT(resd == refd);
|
||||||
|
|
||||||
|
auto buff = save_mobile_module_to_bytes(bc);
|
||||||
|
mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size());
|
||||||
|
const auto& test_func2 = bc2.get_method("test_func");
|
||||||
|
IValue res2;
|
||||||
|
for (int i = 0; i < 3; ++i) {
|
||||||
|
res2 = test_func2({minput});
|
||||||
|
}
|
||||||
|
auto resd2 = res2.toTensor().item<float>();
|
||||||
|
AT_ASSERT(resd2 == refd);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -144,6 +172,16 @@ TEST(LiteInterpreterTest, Conv) {
|
||||||
AT_ASSERT(outputref.dim() == output.dim());
|
AT_ASSERT(outputref.dim() == output.dim());
|
||||||
AT_ASSERT(
|
AT_ASSERT(
|
||||||
outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());
|
outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());
|
||||||
|
|
||||||
|
auto buff = save_mobile_module_to_bytes(bc);
|
||||||
|
mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size());
|
||||||
|
for (int i = 0; i < 3; ++i) {
|
||||||
|
res = bc2.get_method("forward")(inputs);
|
||||||
|
}
|
||||||
|
output = res.toTensor();
|
||||||
|
AT_ASSERT(outputref.dim() == output.dim());
|
||||||
|
AT_ASSERT(
|
||||||
|
outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(LiteInterpreterTest, Inline) {
|
TEST(LiteInterpreterTest, Inline) {
|
||||||
|
|
@ -164,6 +202,12 @@ TEST(LiteInterpreterTest, Inline) {
|
||||||
std::vector<torch::jit::IValue> inputs({torch::ones({})});
|
std::vector<torch::jit::IValue> inputs({torch::ones({})});
|
||||||
auto output = bc.get_method("foo3")(inputs);
|
auto output = bc.get_method("foo3")(inputs);
|
||||||
AT_ASSERT(output.toTensor().item<float>() == 7.0);
|
AT_ASSERT(output.toTensor().item<float>() == 7.0);
|
||||||
|
|
||||||
|
auto buff = save_mobile_module_to_bytes(bc);
|
||||||
|
mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size());
|
||||||
|
std::vector<torch::jit::IValue> inputs2({torch::ones({})});
|
||||||
|
output = bc2.get_method("foo3")(inputs2);
|
||||||
|
AT_ASSERT(output.toTensor().item<float>() == 7.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(LiteInterpreterTest, Tuple) {
|
TEST(LiteInterpreterTest, Tuple) {
|
||||||
|
|
@ -182,6 +226,11 @@ TEST(LiteInterpreterTest, Tuple) {
|
||||||
std::vector<torch::jit::IValue> inputs({torch::ones({})});
|
std::vector<torch::jit::IValue> inputs({torch::ones({})});
|
||||||
auto output = bc.get_method("forward")(inputs);
|
auto output = bc.get_method("forward")(inputs);
|
||||||
AT_ASSERT(output.toTupleRef().elements()[1].toInt() == 2);
|
AT_ASSERT(output.toTupleRef().elements()[1].toInt() == 2);
|
||||||
|
|
||||||
|
auto buff = save_mobile_module_to_bytes(bc);
|
||||||
|
mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size());
|
||||||
|
output = bc2.get_method("forward")(inputs);
|
||||||
|
AT_ASSERT(output.toTuple()->elements()[1].toInt() == 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(LiteInterpreterTest, Dict) {
|
TEST(LiteInterpreterTest, Dict) {
|
||||||
|
|
@ -200,6 +249,11 @@ TEST(LiteInterpreterTest, Dict) {
|
||||||
std::vector<torch::jit::IValue> inputs({torch::ones({})});
|
std::vector<torch::jit::IValue> inputs({torch::ones({})});
|
||||||
auto output = bc.get_method("forward")(inputs);
|
auto output = bc.get_method("forward")(inputs);
|
||||||
AT_ASSERT(output.toGenericDict().at("result").toTensor().item().toInt() == 2);
|
AT_ASSERT(output.toGenericDict().at("result").toTensor().item().toInt() == 2);
|
||||||
|
|
||||||
|
auto buff = save_mobile_module_to_bytes(bc);
|
||||||
|
mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size());
|
||||||
|
output = bc2.get_method("forward")(inputs);
|
||||||
|
AT_ASSERT(output.toGenericDict().at("result").toTensor().item().toInt() == 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(LiteInterpreterTest, PrimOverload) {
|
TEST(LiteInterpreterTest, PrimOverload) {
|
||||||
|
|
@ -246,6 +300,16 @@ TEST(LiteInterpreterTest, Prim) {
|
||||||
auto resi = res.toInt();
|
auto resi = res.toInt();
|
||||||
auto refi = ref.toInt();
|
auto refi = ref.toInt();
|
||||||
AT_ASSERT(resi == refi);
|
AT_ASSERT(resi == refi);
|
||||||
|
|
||||||
|
auto buff = save_mobile_module_to_bytes(bc);
|
||||||
|
mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size());
|
||||||
|
for (int i = 0; i < 3; ++i) {
|
||||||
|
// NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
|
||||||
|
auto bcinputs = inputs;
|
||||||
|
res = bc2.get_method("forward")(bcinputs);
|
||||||
|
}
|
||||||
|
auto resi2 = res.toInt();
|
||||||
|
AT_ASSERT(resi2 == refi);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(LiteInterpreterTest, PrimScalar) {
|
TEST(LiteInterpreterTest, PrimScalar) {
|
||||||
|
|
@ -273,6 +337,16 @@ TEST(LiteInterpreterTest, PrimScalar) {
|
||||||
auto resi = res.toInt();
|
auto resi = res.toInt();
|
||||||
auto refi = ref.toInt();
|
auto refi = ref.toInt();
|
||||||
AT_ASSERT(resi == refi);
|
AT_ASSERT(resi == refi);
|
||||||
|
|
||||||
|
auto buff = save_mobile_module_to_bytes(bc);
|
||||||
|
mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size());
|
||||||
|
for (int i = 0; i < 3; ++i) {
|
||||||
|
// NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
|
||||||
|
auto bcinputs = inputs;
|
||||||
|
res = bc2.get_method("forward")(bcinputs);
|
||||||
|
}
|
||||||
|
auto resi2 = res.toInt();
|
||||||
|
AT_ASSERT(resi2 == refi);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(LiteInterpreterTest, LoadOrigJit) {
|
TEST(LiteInterpreterTest, LoadOrigJit) {
|
||||||
|
|
@ -304,6 +378,11 @@ TEST(LiteInterpreterTest, WrongMethodName) {
|
||||||
inputs.emplace_back(minput);
|
inputs.emplace_back(minput);
|
||||||
ASSERT_THROWS_WITH_MESSAGE(
|
ASSERT_THROWS_WITH_MESSAGE(
|
||||||
bc.get_method("forward")(inputs), "is not defined");
|
bc.get_method("forward")(inputs), "is not defined");
|
||||||
|
|
||||||
|
auto buff = save_mobile_module_to_bytes(bc);
|
||||||
|
mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size());
|
||||||
|
ASSERT_THROWS_WITH_MESSAGE(
|
||||||
|
bc2.get_method("forward")(inputs), "is not defined");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(LiteInterpreterTest, SetState) {
|
TEST(LiteInterpreterTest, SetState) {
|
||||||
|
|
@ -311,7 +390,7 @@ TEST(LiteInterpreterTest, SetState) {
|
||||||
m.register_parameter("foo", torch::ones({}), false);
|
m.register_parameter("foo", torch::ones({}), false);
|
||||||
m.define(R"(
|
m.define(R"(
|
||||||
def __getstate__(self):
|
def __getstate__(self):
|
||||||
return self.foo + self.foo
|
return self.foo
|
||||||
def __setstate__(self, a):
|
def __setstate__(self, a):
|
||||||
self.foo = a
|
self.foo = a
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
|
@ -341,6 +420,17 @@ TEST(LiteInterpreterTest, SetState) {
|
||||||
auto resd = res.toTensor().item<float>();
|
auto resd = res.toTensor().item<float>();
|
||||||
auto refd = ref.toTensor().item<float>();
|
auto refd = ref.toTensor().item<float>();
|
||||||
AT_ASSERT(resd == refd);
|
AT_ASSERT(resd == refd);
|
||||||
|
|
||||||
|
auto buff = save_mobile_module_to_bytes(bc);
|
||||||
|
mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size());
|
||||||
|
for (int i = 0; i < 3; ++i) {
|
||||||
|
// NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
|
||||||
|
auto bcinputs = inputs;
|
||||||
|
res = bc2.get_method("forward")(bcinputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto resd2 = res.toTensor().item<float>();
|
||||||
|
AT_ASSERT(resd2 == refd);
|
||||||
}
|
}
|
||||||
|
|
||||||
class TorchBindLiteInterpreterTestStruct
|
class TorchBindLiteInterpreterTestStruct
|
||||||
|
|
@ -435,6 +525,12 @@ TEST(LiteInterpreterTest, BuiltinClass) {
|
||||||
const auto& str = res.toStringRef();
|
const auto& str = res.toStringRef();
|
||||||
std::string expected = "Hello! Your tensor has 12 elements!";
|
std::string expected = "Hello! Your tensor has 12 elements!";
|
||||||
AT_ASSERT(str == expected);
|
AT_ASSERT(str == expected);
|
||||||
|
|
||||||
|
auto buff = save_mobile_module_to_bytes(bc);
|
||||||
|
mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size());
|
||||||
|
res = bc2.get_method("forward")(std::vector<IValue>{torch::zeros({3, 4})});
|
||||||
|
const auto& str2 = res.toStringRef();
|
||||||
|
AT_ASSERT(str2 == expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(LiteInterpreterTest, BuiltinFunction) {
|
TEST(LiteInterpreterTest, BuiltinFunction) {
|
||||||
|
|
@ -456,6 +552,13 @@ TEST(LiteInterpreterTest, BuiltinFunction) {
|
||||||
auto str = res.toStringRef();
|
auto str = res.toStringRef();
|
||||||
std::string expected = "Hello! Your tensor has 12 elements!";
|
std::string expected = "Hello! Your tensor has 12 elements!";
|
||||||
AT_ASSERT(str == expected);
|
AT_ASSERT(str == expected);
|
||||||
|
|
||||||
|
auto buff = save_mobile_module_to_bytes(bc);
|
||||||
|
mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size());
|
||||||
|
res = bc2.get_method("forward")(std::vector<IValue>{torch::zeros({3, 4})});
|
||||||
|
// NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
|
||||||
|
str = res.toStringRef();
|
||||||
|
AT_ASSERT(str == expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
#if !defined FB_XPLAT_BUILD
|
#if !defined FB_XPLAT_BUILD
|
||||||
|
|
@ -776,6 +879,17 @@ TEST(LiteInterpreterTest, Eval) {
|
||||||
AT_ASSERT(outputref.dim() == output.dim());
|
AT_ASSERT(outputref.dim() == output.dim());
|
||||||
AT_ASSERT(
|
AT_ASSERT(
|
||||||
outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());
|
outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());
|
||||||
|
|
||||||
|
auto buff = save_mobile_module_to_bytes(bc);
|
||||||
|
mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size());
|
||||||
|
bc2.eval();
|
||||||
|
for (int i = 0; i < 3; ++i) {
|
||||||
|
res = bc2.get_method("forward")(inputs);
|
||||||
|
}
|
||||||
|
output = res.toTensor();
|
||||||
|
AT_ASSERT(outputref.dim() == output.dim());
|
||||||
|
AT_ASSERT(
|
||||||
|
outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(LiteInterpreterTest, FindWrongMethodName) {
|
TEST(LiteInterpreterTest, FindWrongMethodName) {
|
||||||
|
|
@ -790,6 +904,10 @@ TEST(LiteInterpreterTest, FindWrongMethodName) {
|
||||||
m._save_for_mobile(ss);
|
m._save_for_mobile(ss);
|
||||||
mobile::Module bc = _load_for_mobile(ss);
|
mobile::Module bc = _load_for_mobile(ss);
|
||||||
ASSERT_TRUE(bc.find_method("forward") == c10::nullopt);
|
ASSERT_TRUE(bc.find_method("forward") == c10::nullopt);
|
||||||
|
|
||||||
|
auto buff = save_mobile_module_to_bytes(bc);
|
||||||
|
mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size());
|
||||||
|
ASSERT_TRUE(bc2.find_method("forward") == c10::nullopt);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(LiteInterpreterTest, FindAndRunMethod) {
|
TEST(LiteInterpreterTest, FindAndRunMethod) {
|
||||||
|
|
@ -820,6 +938,19 @@ TEST(LiteInterpreterTest, FindAndRunMethod) {
|
||||||
auto resd = res.toTensor().item<float>();
|
auto resd = res.toTensor().item<float>();
|
||||||
auto refd = ref.toTensor().item<float>();
|
auto refd = ref.toTensor().item<float>();
|
||||||
AT_ASSERT(resd == refd);
|
AT_ASSERT(resd == refd);
|
||||||
|
|
||||||
|
auto buff = save_mobile_module_to_bytes(bc);
|
||||||
|
mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size());
|
||||||
|
|
||||||
|
for (int i = 0; i < 3; ++i) {
|
||||||
|
auto bcinputs = inputs;
|
||||||
|
auto method = bc2.find_method("add_it");
|
||||||
|
AT_ASSERT(method != c10::nullopt);
|
||||||
|
res = (*method)(std::move(bcinputs));
|
||||||
|
}
|
||||||
|
|
||||||
|
resd = res.toTensor().item<float>();
|
||||||
|
AT_ASSERT(resd == refd);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(LiteInterpreterTest, RunMethodVariadic) {
|
TEST(LiteInterpreterTest, RunMethodVariadic) {
|
||||||
|
|
@ -843,6 +974,12 @@ TEST(LiteInterpreterTest, RunMethodVariadic) {
|
||||||
auto resd = res.toTensor().item<float>();
|
auto resd = res.toTensor().item<float>();
|
||||||
auto refd = ref.toTensor().item<float>();
|
auto refd = ref.toTensor().item<float>();
|
||||||
AT_ASSERT(resd == refd);
|
AT_ASSERT(resd == refd);
|
||||||
|
|
||||||
|
auto buff = save_mobile_module_to_bytes(bc);
|
||||||
|
mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size());
|
||||||
|
res = bc.run_method("add_three", inputx, inputy);
|
||||||
|
resd = res.toTensor().item<float>();
|
||||||
|
AT_ASSERT(resd == refd);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(LiteInterpreterTest, DuplicateSetState) {
|
TEST(LiteInterpreterTest, DuplicateSetState) {
|
||||||
|
|
@ -872,6 +1009,11 @@ TEST(LiteInterpreterTest, DuplicateSetState) {
|
||||||
const auto methods = bc.get_methods();
|
const auto methods = bc.get_methods();
|
||||||
const size_t expected_n = 3;
|
const size_t expected_n = 3;
|
||||||
ASSERT_EQ(methods.size(), expected_n);
|
ASSERT_EQ(methods.size(), expected_n);
|
||||||
|
|
||||||
|
auto buff = save_mobile_module_to_bytes(bc);
|
||||||
|
mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size());
|
||||||
|
const auto methods2 = bc.get_methods();
|
||||||
|
ASSERT_EQ(methods2.size(), expected_n);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(LiteInterpreterTest, ExtraFiles) {
|
TEST(LiteInterpreterTest, ExtraFiles) {
|
||||||
|
|
@ -940,6 +1082,12 @@ TEST(LiteInterpreterTest, OpNameExportFetchRootOperators) {
|
||||||
};
|
};
|
||||||
EXPECT_EQ(operator_names, expected_operator_names)
|
EXPECT_EQ(operator_names, expected_operator_names)
|
||||||
<< "Expected the root operator lists to be the same";
|
<< "Expected the root operator lists to be the same";
|
||||||
|
|
||||||
|
auto buff = save_mobile_module_to_bytes(ptl_model);
|
||||||
|
mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size());
|
||||||
|
operator_names = torch::jit::mobile::_export_operator_list(bc2);
|
||||||
|
EXPECT_EQ(operator_names, expected_operator_names)
|
||||||
|
<< "Expected the root operator lists to be the same";
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(LiteInterpreterTest, DefaultArgsConv) {
|
TEST(LiteInterpreterTest, DefaultArgsConv) {
|
||||||
|
|
@ -957,7 +1105,7 @@ TEST(LiteInterpreterTest, DefaultArgsConv) {
|
||||||
return torch.conv2d(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], 1)
|
return torch.conv2d(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], 1)
|
||||||
)");
|
)");
|
||||||
|
|
||||||
inputs.push_back(torch::ones({1, 1, 28, 28}));
|
inputs.emplace_back(torch::ones({1, 1, 28, 28}));
|
||||||
|
|
||||||
auto outputref = m.forward(inputs).toTensor();
|
auto outputref = m.forward(inputs).toTensor();
|
||||||
|
|
||||||
|
|
@ -971,6 +1119,15 @@ TEST(LiteInterpreterTest, DefaultArgsConv) {
|
||||||
auto output = res.toTensor();
|
auto output = res.toTensor();
|
||||||
AT_ASSERT(outputref.dim() == output.dim());
|
AT_ASSERT(outputref.dim() == output.dim());
|
||||||
AT_ASSERT(output.equal(outputref));
|
AT_ASSERT(output.equal(outputref));
|
||||||
|
|
||||||
|
auto buff = save_mobile_module_to_bytes(bc);
|
||||||
|
mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size());
|
||||||
|
for (int i = 0; i < 1; ++i) {
|
||||||
|
res = bc2.get_method("forward")(inputs);
|
||||||
|
}
|
||||||
|
output = res.toTensor();
|
||||||
|
AT_ASSERT(outputref.dim() == output.dim());
|
||||||
|
AT_ASSERT(output.equal(outputref));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(RunTimeTest, ParseBytecode) {
|
TEST(RunTimeTest, ParseBytecode) {
|
||||||
|
|
@ -1016,8 +1173,8 @@ TEST(RunTimeTest, ParseBytecode) {
|
||||||
std::vector<IValue> types{"List[int]", "List[int]"};
|
std::vector<IValue> types{"List[int]", "List[int]"};
|
||||||
// 2. Parse the function
|
// 2. Parse the function
|
||||||
std::string function_name("test_function");
|
std::string function_name("test_function");
|
||||||
auto function = std::unique_ptr<mobile::Function>(
|
auto function =
|
||||||
new mobile::Function(c10::QualifiedName(function_name)));
|
std::make_unique<mobile::Function>(c10::QualifiedName(function_name));
|
||||||
c10::ivalue::TupleElements debug_handles_m_tuple;
|
c10::ivalue::TupleElements debug_handles_m_tuple;
|
||||||
parseInstructions(
|
parseInstructions(
|
||||||
function_name,
|
function_name,
|
||||||
|
|
@ -1077,8 +1234,8 @@ TEST(RunTimeTest, ParseOperator) {
|
||||||
int64_t model_version = caffe2::serialize::kProducedBytecodeVersion;
|
int64_t model_version = caffe2::serialize::kProducedBytecodeVersion;
|
||||||
// 2. Parse the function
|
// 2. Parse the function
|
||||||
std::string function_name("test_function");
|
std::string function_name("test_function");
|
||||||
auto function = std::unique_ptr<mobile::Function>(
|
auto function =
|
||||||
new mobile::Function(c10::QualifiedName(function_name)));
|
std::make_unique<mobile::Function>(c10::QualifiedName(function_name));
|
||||||
c10::ivalue::TupleElements debug_handles_m_tuple;
|
c10::ivalue::TupleElements debug_handles_m_tuple;
|
||||||
parseInstructions(
|
parseInstructions(
|
||||||
function_name,
|
function_name,
|
||||||
|
|
@ -1120,6 +1277,15 @@ void testLiteModuleCompareResultTensors(
|
||||||
auto output = res.toTensor();
|
auto output = res.toTensor();
|
||||||
AT_ASSERT(outputref.dim() == output.dim());
|
AT_ASSERT(outputref.dim() == output.dim());
|
||||||
AT_ASSERT(output.equal(outputref));
|
AT_ASSERT(output.equal(outputref));
|
||||||
|
|
||||||
|
auto buff = save_mobile_module_to_bytes(bc);
|
||||||
|
mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size());
|
||||||
|
for (int i = 0; i < 3; ++i) {
|
||||||
|
res = bc2.get_method(method_name)(inputs);
|
||||||
|
}
|
||||||
|
output = res.toTensor();
|
||||||
|
AT_ASSERT(outputref.dim() == output.dim());
|
||||||
|
AT_ASSERT(output.equal(outputref));
|
||||||
}
|
}
|
||||||
|
|
||||||
void testDefaultArgsPinv(int num_args) {
|
void testDefaultArgsPinv(int num_args) {
|
||||||
|
|
@ -1146,7 +1312,7 @@ void testDefaultArgsPinv(int num_args) {
|
||||||
auto input = torch::range(1, N * N, 1);
|
auto input = torch::range(1, N * N, 1);
|
||||||
input[0] = 1; // a more stable matrix
|
input[0] = 1; // a more stable matrix
|
||||||
input = input.view({N, N});
|
input = input.view({N, N});
|
||||||
inputs.push_back(input);
|
inputs.emplace_back(input);
|
||||||
testLiteModuleCompareResultTensors(m, inputs);
|
testLiteModuleCompareResultTensors(m, inputs);
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
@ -1246,7 +1412,7 @@ TEST(LiteInterpreterTest, DefaultArgsTensorinvSpecifyDefault) {
|
||||||
std::vector<torch::jit::IValue> inputs;
|
std::vector<torch::jit::IValue> inputs;
|
||||||
const int N = 4;
|
const int N = 4;
|
||||||
auto input = torch::rand({N, N, N, N});
|
auto input = torch::rand({N, N, N, N});
|
||||||
inputs.push_back(input);
|
inputs.emplace_back(input);
|
||||||
testLiteModuleCompareResultTensors(m, inputs);
|
testLiteModuleCompareResultTensors(m, inputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1310,6 +1476,19 @@ TEST(LiteInterpreterTest, DefaultArgsWithOutArg) {
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
op != ops.end() && op->second.num_schema_args.has_value() &&
|
op != ops.end() && op->second.num_schema_args.has_value() &&
|
||||||
op->second.num_schema_args.value() == 3);
|
op->second.num_schema_args.value() == 3);
|
||||||
|
|
||||||
|
auto buff = save_mobile_module_to_bytes(bc);
|
||||||
|
mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size());
|
||||||
|
auto input_x2 = 2 * torch::ones({});
|
||||||
|
auto input_h2 = torch::ones({});
|
||||||
|
m.run_method("forward", input_x2, input_h2);
|
||||||
|
bc2.run_method("forward", input_x2, input_h2);
|
||||||
|
AT_ASSERT(input_x2.equal(4 * torch::ones({})));
|
||||||
|
ops = _get_model_ops_and_info(ss);
|
||||||
|
op = ops.find("aten::add.out");
|
||||||
|
TORCH_CHECK(
|
||||||
|
op != ops.end() && op->second.num_schema_args.has_value() &&
|
||||||
|
op->second.num_schema_args.value() == 3);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(LiteInterpreterTest, TestExceptionStackWithTwoLevelModuleHierarchy) {
|
TEST(LiteInterpreterTest, TestExceptionStackWithTwoLevelModuleHierarchy) {
|
||||||
|
|
@ -1517,6 +1696,13 @@ TEST(LiteInterpreterTest, OperatorSize1) {
|
||||||
ASSERT_EQ(
|
ASSERT_EQ(
|
||||||
func.get_code()->operator_input_sizes_.size(),
|
func.get_code()->operator_input_sizes_.size(),
|
||||||
func.get_code()->operators_.size());
|
func.get_code()->operators_.size());
|
||||||
|
|
||||||
|
auto buff = save_mobile_module_to_bytes(bc);
|
||||||
|
mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size());
|
||||||
|
const auto& func2 = bc.get_method("forward").function();
|
||||||
|
ASSERT_EQ(
|
||||||
|
func2.get_code()->operator_input_sizes_.size(),
|
||||||
|
func2.get_code()->operators_.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(LiteInterpreterTest, OperatorTest2) { // NOLINT (use =delete in gtest)
|
TEST(LiteInterpreterTest, OperatorTest2) { // NOLINT (use =delete in gtest)
|
||||||
|
|
@ -1552,6 +1738,13 @@ TEST(LiteInterpreterTest, OperatorTest2) { // NOLINT (use =delete in gtest)
|
||||||
ASSERT_EQ(
|
ASSERT_EQ(
|
||||||
func.get_code()->operator_input_sizes_.size(),
|
func.get_code()->operator_input_sizes_.size(),
|
||||||
func.get_code()->operators_.size());
|
func.get_code()->operators_.size());
|
||||||
|
|
||||||
|
auto buff = save_mobile_module_to_bytes(bc);
|
||||||
|
mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size());
|
||||||
|
const auto& func2 = bc.get_method("test_func").function();
|
||||||
|
ASSERT_EQ(
|
||||||
|
func2.get_code()->operator_input_sizes_.size(),
|
||||||
|
func2.get_code()->operators_.size());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
1
third_party/flatbuffers
vendored
Submodule
1
third_party/flatbuffers
vendored
Submodule
|
|
@ -0,0 +1 @@
|
||||||
|
Subproject commit f2f9380c86a762ef0d9410693c61c35567923d63
|
||||||
|
|
@ -56,10 +56,19 @@ def run_autogen() -> None:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def gen_flatbuffers() -> None:
|
||||||
|
run_timed_cmd(
|
||||||
|
[
|
||||||
|
"bash", "scripts/gen_flatbuffer.sh"
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def generate_build_files() -> None:
|
def generate_build_files() -> None:
|
||||||
update_submodules()
|
update_submodules()
|
||||||
gen_compile_commands()
|
gen_compile_commands()
|
||||||
run_autogen()
|
run_autogen()
|
||||||
|
gen_flatbuffers()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -70,6 +70,7 @@ set(TORCH_PYTHON_INCLUDE_DIRECTORIES
|
||||||
|
|
||||||
${TORCH_ROOT}/third_party/gloo
|
${TORCH_ROOT}/third_party/gloo
|
||||||
${TORCH_ROOT}/third_party/onnx
|
${TORCH_ROOT}/third_party/onnx
|
||||||
|
${TORCH_ROOT}/third_party/flatbuffers/include
|
||||||
${pybind11_INCLUDE_DIRS}
|
${pybind11_INCLUDE_DIRS}
|
||||||
|
|
||||||
${TORCH_SRC_DIR}/csrc
|
${TORCH_SRC_DIR}/csrc
|
||||||
|
|
@ -343,6 +344,8 @@ if(HAVE_SOVERSION)
|
||||||
VERSION ${TORCH_VERSION} SOVERSION ${TORCH_SOVERSION})
|
VERSION ${TORCH_VERSION} SOVERSION ${TORCH_SOVERSION})
|
||||||
endif()
|
endif()
|
||||||
add_dependencies(torch_python torch_python_stubs)
|
add_dependencies(torch_python torch_python_stubs)
|
||||||
|
add_dependencies(torch_python flatbuffers)
|
||||||
|
|
||||||
|
|
||||||
if(USE_PRECOMPILED_HEADERS)
|
if(USE_PRECOMPILED_HEADERS)
|
||||||
target_precompile_headers(torch_python PRIVATE
|
target_precompile_headers(torch_python PRIVATE
|
||||||
|
|
|
||||||
|
|
@ -261,8 +261,11 @@ def _jit_assert_is_instance(obj: Any, type: JitType): ...
|
||||||
def _jit_clear_class_registry() -> None: ...
|
def _jit_clear_class_registry() -> None: ...
|
||||||
def _jit_set_emit_hooks(ModuleHook: Optional[Callable], FunctionHook: Optional[Callable]) -> None: ...
|
def _jit_set_emit_hooks(ModuleHook: Optional[Callable], FunctionHook: Optional[Callable]) -> None: ...
|
||||||
def _jit_get_emit_hooks() -> Tuple[Callable, Callable]: ...
|
def _jit_get_emit_hooks() -> Tuple[Callable, Callable]: ...
|
||||||
def _load_for_lite_interpreter(filename: Union[str, Path], map_location: Union[_device, str, None]): ...
|
def _load_for_lite_interpreter(filename: Union[str, Path], map_location: Union[_device, str, None], is_flatbuffer: _bool): ...
|
||||||
def _load_for_lite_interpreter_from_buffer(buffer: BinaryIO, map_location: Union[_device, str, None]): ...
|
def _load_for_lite_interpreter_from_buffer(buffer: BinaryIO, map_location: Union[_device, str, None], is_flatbuffer: _bool): ...
|
||||||
|
def _save_mobile_module(module: LiteScriptModule, filename: str): ...
|
||||||
|
def _jit_module_to_mobile(module: ScriptModule): ...
|
||||||
|
def _module_equals(lhs: LiteScriptModule, rhs: LiteScriptModule): ...
|
||||||
def _export_operator_list(module: LiteScriptModule): ...
|
def _export_operator_list(module: LiteScriptModule): ...
|
||||||
def _get_model_bytecode_version(filename: Union[str, Path]) -> _int: ...
|
def _get_model_bytecode_version(filename: Union[str, Path]) -> _int: ...
|
||||||
def _get_model_bytecode_version_from_buffer(buffer: BinaryIO) -> _int: ...
|
def _get_model_bytecode_version_from_buffer(buffer: BinaryIO) -> _int: ...
|
||||||
|
|
|
||||||
515
torch/csrc/jit/mobile/flatbuffer_loader.cpp
Normal file
515
torch/csrc/jit/mobile/flatbuffer_loader.cpp
Normal file
|
|
@ -0,0 +1,515 @@
|
||||||
|
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
|
||||||
|
|
||||||
|
#include <ATen/core/ivalue.h>
|
||||||
|
#include <ATen/core/qualified_name.h>
|
||||||
|
#include <c10/core/CPUAllocator.h>
|
||||||
|
#include <c10/util/Exception.h>
|
||||||
|
#include <c10/util/Optional.h>
|
||||||
|
#include <c10/util/ScopeExit.h>
|
||||||
|
#include <caffe2/serialize/inline_container.h>
|
||||||
|
#include <torch/csrc/jit/frontend/script_type_parser.h>
|
||||||
|
#include <torch/csrc/jit/mobile/import.h>
|
||||||
|
#include <torch/csrc/jit/mobile/interpreter.h>
|
||||||
|
#include <torch/csrc/jit/mobile/observer.h>
|
||||||
|
#include <torch/csrc/jit/mobile/type_parser.h>
|
||||||
|
#include <torch/csrc/jit/runtime/instruction.h>
|
||||||
|
#include <torch/csrc/jit/serialization/import_export_constants.h>
|
||||||
|
#include <torch/csrc/jit/serialization/import_read.h>
|
||||||
|
#include <torch/custom_class.h>
|
||||||
|
|
||||||
|
#include <flatbuffers/flatbuffers.h>
|
||||||
|
|
||||||
|
#if defined(HAVE_MMAP)
|
||||||
|
#include <fcntl.h>
|
||||||
|
#include <sys/mman.h>
|
||||||
|
#include <sys/stat.h>
|
||||||
|
#include <unistd.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace torch {
|
||||||
|
namespace jit {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using caffe2::serialize::IStreamAdapter;
|
||||||
|
using caffe2::serialize::PyTorchStreamReader;
|
||||||
|
using caffe2::serialize::ReadAdapterInterface;
|
||||||
|
|
||||||
|
static constexpr c10::string_view kCustomClassPrefix =
|
||||||
|
"__torch__.torch.classes";
|
||||||
|
static constexpr c10::string_view kTorchPrefix = "__torch__";
|
||||||
|
static constexpr c10::string_view kJitPrefix = "torch.jit";
|
||||||
|
|
||||||
|
class FlatbufferLoader {
|
||||||
|
public:
|
||||||
|
FlatbufferLoader()
|
||||||
|
: mcu_(std::make_shared<mobile::CompilationUnit>()),
|
||||||
|
cu_(std::make_shared<CompilationUnit>()) {}
|
||||||
|
|
||||||
|
mobile::Module parseModule(mobile::serialization::Module* module);
|
||||||
|
|
||||||
|
private:
|
||||||
|
IValue parseIValue(const mobile::serialization::IValue* ivalue);
|
||||||
|
IValue parseList(const mobile::serialization::List* list);
|
||||||
|
at::Tensor parseTensor(const mobile::serialization::TensorMetadata* tensor);
|
||||||
|
IValue parseTuple(const mobile::serialization::Tuple* tuple);
|
||||||
|
IValue parseDict(const mobile::serialization::Dict* dict);
|
||||||
|
IValue parseObject(const mobile::serialization::Object* object);
|
||||||
|
std::unique_ptr<mobile::Function> parseFunction(
|
||||||
|
const mobile::serialization::Function* method);
|
||||||
|
|
||||||
|
IValue& getIValue(uint32_t pos) {
|
||||||
|
TORCH_CHECK(pos < all_ivalues_.size());
|
||||||
|
return all_ivalues_[pos];
|
||||||
|
}
|
||||||
|
|
||||||
|
mobile::Function* getFunction(uint32_t pos) {
|
||||||
|
return all_functions_[pos];
|
||||||
|
}
|
||||||
|
|
||||||
|
ClassTypePtr getType(uint32_t pos) const {
|
||||||
|
TORCH_CHECK(pos < all_ivalues_.size());
|
||||||
|
return all_types_[pos];
|
||||||
|
// auto iter = all_types_.find(pos);
|
||||||
|
// AT_ASSERT(iter != all_types_.end(), "type not found at pos: ", pos);
|
||||||
|
// return iter->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
c10::Storage getStorage(uint32_t index);
|
||||||
|
TypePtr getOrCreateTypeAnnotations(const flatbuffers::String* offset);
|
||||||
|
|
||||||
|
// fields
|
||||||
|
std::unordered_map<uint32_t, mobile::Function*> all_functions_;
|
||||||
|
std::vector<ClassTypePtr> all_types_;
|
||||||
|
std::unordered_set<uint32_t> initialized_types_;
|
||||||
|
std::unordered_map<const flatbuffers::String*, TypePtr> type_annotations_;
|
||||||
|
std::vector<bool> storage_loaded_;
|
||||||
|
std::vector<c10::Storage> storages_;
|
||||||
|
std::vector<IValue> all_ivalues_;
|
||||||
|
std::shared_ptr<mobile::CompilationUnit> mcu_;
|
||||||
|
std::shared_ptr<CompilationUnit> cu_;
|
||||||
|
mobile::serialization::Module* module_ = nullptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
mobile::Module FlatbufferLoader::parseModule(
|
||||||
|
mobile::serialization::Module* module) {
|
||||||
|
module_ = module;
|
||||||
|
all_ivalues_.clear();
|
||||||
|
all_types_.clear();
|
||||||
|
storages_.clear();
|
||||||
|
storage_loaded_.clear();
|
||||||
|
|
||||||
|
const auto* ivalues = module->ivalues();
|
||||||
|
all_ivalues_.resize(ivalues->size());
|
||||||
|
all_types_.resize(module->object_types()->size());
|
||||||
|
storages_.resize(module->storage_data_size());
|
||||||
|
storage_loaded_.resize(module->storage_data_size(), false);
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < ivalues->size(); i++) {
|
||||||
|
const auto* ival = ivalues->Get(i);
|
||||||
|
if (const auto* func = ival->val_as_Function()) {
|
||||||
|
auto func_ptr = parseFunction(func);
|
||||||
|
all_functions_[i] = func_ptr.get();
|
||||||
|
mcu_->register_function(std::move(func_ptr));
|
||||||
|
} else {
|
||||||
|
all_ivalues_[i] = parseIValue(ival);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
IValue& module_ivalue = getIValue(module->state_obj());
|
||||||
|
// register function to class
|
||||||
|
// for (const auto& func: all_functions_) {
|
||||||
|
// const auto* fb_func = ivalues->Get(func.first)->val_as_Function();
|
||||||
|
// auto class_type = getType(fb_func->class_type());
|
||||||
|
// class_type->addMethod(func.second);
|
||||||
|
// }
|
||||||
|
return mobile::Module(module_ivalue.toObject(), mcu_);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<mobile::Function> FlatbufferLoader::parseFunction(
|
||||||
|
const mobile::serialization::Function* method) {
|
||||||
|
auto function = std::make_unique<mobile::Function>(
|
||||||
|
c10::QualifiedName(method->qn()->str()));
|
||||||
|
// TODO(qihan) add debug handle
|
||||||
|
// const auto* debug_handle = method->debug_info()->debug_handle();
|
||||||
|
for (const auto* inst : *method->instructions()) {
|
||||||
|
function->append_instruction(
|
||||||
|
static_cast<OpCode>(inst->op()), inst->x(), inst->n());
|
||||||
|
}
|
||||||
|
|
||||||
|
for (uint32_t i : *method->constants()) {
|
||||||
|
function->append_constant(getIValue(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unordered_set<std::string> unsupported_op_names;
|
||||||
|
const int64_t model_version = 0x6L;
|
||||||
|
for (const auto* op : *method->operators()) {
|
||||||
|
c10::optional<int> num_args = c10::nullopt;
|
||||||
|
if (op->num_args_serialized() > -1) {
|
||||||
|
num_args = op->num_args_serialized();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto op_found = function->append_operator(
|
||||||
|
op->name()->str(), op->overload_name()->str(), num_args, model_version);
|
||||||
|
|
||||||
|
if (!op_found) {
|
||||||
|
unsupported_op_names.emplace(
|
||||||
|
op->name()->str() + "/" + op->overload_name()->str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
AT_ASSERT(unsupported_op_names.empty());
|
||||||
|
|
||||||
|
for (const auto i : *method->type_annotations()) {
|
||||||
|
function->append_type(getOrCreateTypeAnnotations(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
function->set_register_size(method->register_size());
|
||||||
|
if (method->schema()) {
|
||||||
|
auto parseArgList = [this](const auto* args_fb) {
|
||||||
|
std::vector<c10::Argument> args;
|
||||||
|
for (const auto* arg_tb : *args_fb) {
|
||||||
|
IValue default_value = getIValue(arg_tb->default_value());
|
||||||
|
TypePtr type_ptr = getOrCreateTypeAnnotations(arg_tb->type());
|
||||||
|
auto arg = c10::Argument(
|
||||||
|
arg_tb->name()->str(),
|
||||||
|
std::move(type_ptr),
|
||||||
|
c10::nullopt /*N*/,
|
||||||
|
std::move(default_value));
|
||||||
|
args.emplace_back(std::move(arg));
|
||||||
|
}
|
||||||
|
return args;
|
||||||
|
};
|
||||||
|
c10::FunctionSchema schema(
|
||||||
|
method->qn()->str(),
|
||||||
|
"" /*overload_name*/,
|
||||||
|
parseArgList(method->schema()->arguments()),
|
||||||
|
parseArgList(method->schema()->returns()),
|
||||||
|
false /*is_varargs*/,
|
||||||
|
false /*is_varret*/);
|
||||||
|
|
||||||
|
function->setSchema(std::move(schema));
|
||||||
|
}
|
||||||
|
return function;
|
||||||
|
}
|
||||||
|
|
||||||
|
at::Tensor FlatbufferLoader::parseTensor(
|
||||||
|
const mobile::serialization::TensorMetadata* tensor_md) {
|
||||||
|
at::ScalarType type = static_cast<at::ScalarType>(tensor_md->scalar_type());
|
||||||
|
auto options = at::CPU(type).options();
|
||||||
|
at::Tensor tensor;
|
||||||
|
if (tensor_md->quantized_schema() != nullptr) {
|
||||||
|
// is quantized
|
||||||
|
const auto* schema = tensor_md->quantized_schema();
|
||||||
|
auto qscheme_type = static_cast<at::QScheme>(schema->qscheme());
|
||||||
|
switch (qscheme_type) {
|
||||||
|
case at::kPerTensorAffine: {
|
||||||
|
tensor = at::_empty_affine_quantized(
|
||||||
|
{0}, options, schema->scale(), schema->zero_point());
|
||||||
|
} break;
|
||||||
|
case at::kPerChannelAffineFloatQParams:
|
||||||
|
case at::kPerChannelAffine: {
|
||||||
|
at::Tensor scales = parseTensor(schema->scales());
|
||||||
|
at::Tensor zero_points = parseTensor(schema->zero_points());
|
||||||
|
tensor = at::_empty_per_channel_affine_quantized(
|
||||||
|
{0}, scales, zero_points, schema->axis(), options);
|
||||||
|
} break;
|
||||||
|
default:
|
||||||
|
TORCH_CHECK(
|
||||||
|
false,
|
||||||
|
"Unsupported tensor quantization type in serialization ",
|
||||||
|
toString(qscheme_type));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
tensor = at::empty({0}, options);
|
||||||
|
}
|
||||||
|
at::TensorImpl* impl = tensor.unsafeGetTensorImpl();
|
||||||
|
|
||||||
|
c10::Storage storage;
|
||||||
|
storage = getStorage(tensor_md->storage_location_index());
|
||||||
|
impl->set_storage_keep_dtype(storage);
|
||||||
|
impl->set_storage_offset(tensor_md->storage_offset());
|
||||||
|
|
||||||
|
std::vector<int64_t> size{
|
||||||
|
tensor_md->sizes()->begin(), tensor_md->sizes()->end()};
|
||||||
|
std::vector<int64_t> stride{
|
||||||
|
tensor_md->strides()->begin(), tensor_md->strides()->end()};
|
||||||
|
impl->set_sizes_and_strides(size, stride);
|
||||||
|
tensor = autograd::make_variable(tensor, tensor_md->requires_grad());
|
||||||
|
return tensor;
|
||||||
|
}
|
||||||
|
IValue FlatbufferLoader::parseList(const mobile::serialization::List* list) {
|
||||||
|
auto res = c10::impl::GenericList(AnyType::get());
|
||||||
|
for (int i : *list->items()) {
|
||||||
|
res.emplace_back(getIValue(i));
|
||||||
|
}
|
||||||
|
auto type =
|
||||||
|
getOrCreateTypeAnnotations(list->annotation_str())->cast<ListType>();
|
||||||
|
res.unsafeSetElementType(type->getElementType());
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
IValue FlatbufferLoader::parseTuple(const mobile::serialization::Tuple* tuple) {
|
||||||
|
std::vector<IValue> res;
|
||||||
|
for (int i : *tuple->items()) {
|
||||||
|
res.emplace_back(getIValue(i));
|
||||||
|
}
|
||||||
|
return c10::ivalue::Tuple::create(res);
|
||||||
|
}
|
||||||
|
|
||||||
|
IValue FlatbufferLoader::parseDict(const mobile::serialization::Dict* dict) {
|
||||||
|
auto result = c10::impl::GenericDict(AnyType::get(), AnyType::get());
|
||||||
|
const auto* keys = dict->keys();
|
||||||
|
const auto* values = dict->values();
|
||||||
|
for (size_t i = 0; i < keys->size(); ++i) {
|
||||||
|
uint32_t key = keys->Get(i);
|
||||||
|
uint32_t val = values->Get(i);
|
||||||
|
result.insert_or_assign(getIValue(key), getIValue(val));
|
||||||
|
}
|
||||||
|
auto type =
|
||||||
|
getOrCreateTypeAnnotations(dict->annotation_str())->cast<DictType>();
|
||||||
|
result.unsafeSetKeyType(type->getKeyType());
|
||||||
|
result.unsafeSetValueType(type->getValueType());
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
IValue FlatbufferLoader::parseObject(
|
||||||
|
const mobile::serialization::Object* object) {
|
||||||
|
const mobile::serialization::ObjectType* obj_type =
|
||||||
|
module_->object_types()->Get(object->type_index());
|
||||||
|
auto cls = getType(object->type_index());
|
||||||
|
bool initialized = true;
|
||||||
|
if (cls == nullptr) {
|
||||||
|
c10::string_view qn_str(
|
||||||
|
obj_type->type_name()->c_str(), obj_type->type_name()->size());
|
||||||
|
if (qn_str.starts_with(kTorchPrefix) || qn_str.starts_with(kJitPrefix)) {
|
||||||
|
c10::QualifiedName qn(obj_type->type_name()->str());
|
||||||
|
cls = cu_->get_class(qn);
|
||||||
|
if (cls == nullptr) {
|
||||||
|
cls = ClassType::create(qn, cu_, true);
|
||||||
|
cu_->register_type(cls);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
cls = c10::parseType(std::string(qn_str))->cast<ClassType>();
|
||||||
|
}
|
||||||
|
TORCH_CHECK(object->type_index() < all_ivalues_.size());
|
||||||
|
all_types_[object->type_index()] = cls;
|
||||||
|
initialized = false;
|
||||||
|
}
|
||||||
|
Stack stack;
|
||||||
|
switch (obj_type->type()) {
|
||||||
|
case mobile::serialization::TypeType::CLASS_WITH_FIELD: {
|
||||||
|
auto obj = c10::ivalue::Object::create(
|
||||||
|
at::StrongTypePtr(cu_, cls), object->attrs()->size());
|
||||||
|
if (!initialized) {
|
||||||
|
for (uint32_t i = 0; i < object->attrs()->size(); i++) {
|
||||||
|
IValue val = getIValue(object->attrs()->Get(i));
|
||||||
|
cls->addAttribute(obj_type->attr_names()->Get(i)->str(), val.type());
|
||||||
|
obj->setSlot(i, std::move(val));
|
||||||
|
}
|
||||||
|
initialized_types_.insert(object->type_index());
|
||||||
|
} else {
|
||||||
|
for (uint32_t i = 0; i < object->attrs()->size(); i++) {
|
||||||
|
IValue val = getIValue(object->attrs()->Get(i));
|
||||||
|
obj->setSlot(i, std::move(val));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return obj;
|
||||||
|
}
|
||||||
|
case mobile::serialization::TypeType::CLASS_WITH_SETSTATE: {
|
||||||
|
IValue input = getIValue(object->state());
|
||||||
|
mobile::Function* setstate = getFunction(object->setstate_func());
|
||||||
|
auto obj = c10::ivalue::Object::create(at::StrongTypePtr(cu_, cls), 0);
|
||||||
|
stack.push_back(obj);
|
||||||
|
stack.emplace_back(std::move(input));
|
||||||
|
setstate->run(stack);
|
||||||
|
return obj;
|
||||||
|
}
|
||||||
|
case mobile::serialization::TypeType::CUSTOM_CLASS: {
|
||||||
|
auto custom_class_type =
|
||||||
|
torch::jit::getCustomClass(cls->name()->qualifiedName());
|
||||||
|
IValue input = getIValue(object->state());
|
||||||
|
auto obj = c10::ivalue::Object::create(
|
||||||
|
c10::StrongTypePtr(nullptr, custom_class_type), 1);
|
||||||
|
stack.push_back(obj);
|
||||||
|
stack.emplace_back(std::move(input));
|
||||||
|
custom_class_type->getMethod("__setstate__").run(stack);
|
||||||
|
return obj;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
AT_ASSERT(false, "need to be object");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U>
|
||||||
|
std::vector<T> parseListNative(const U* list) {
|
||||||
|
return {list->items()->begin(), list->items()->end()};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert one serialized mobile::serialization::IValue into a runtime
// c10::IValue by dispatching on the union tag. Scalars are copied out
// directly; containers recurse through the parse*/getIValue helpers.
// Unknown tags (and NONE) produce a default-constructed (None) IValue.
IValue FlatbufferLoader::parseIValue(
    const mobile::serialization::IValue* ivalue) {
  switch (ivalue->val_type()) {
    case mobile::serialization::IValueUnion::NONE:
      return {};
    case mobile::serialization::IValueUnion::Int:
      return ivalue->val_as_Int()->int_val();
    case mobile::serialization::IValueUnion::Bool:
      return ivalue->val_as_Bool()->bool_val();
    case mobile::serialization::IValueUnion::Double:
      return ivalue->val_as_Double()->double_val();
    case mobile::serialization::IValueUnion::ComplexDouble: {
      const auto* comp = ivalue->val_as_ComplexDouble();
      return c10::complex<double>(comp->real(), comp->imag());
    }
    case mobile::serialization::IValueUnion::TensorMetadata:
      return parseTensor(ivalue->val_as_TensorMetadata());
    case mobile::serialization::IValueUnion::String:
      return ivalue->val_as_String()->data()->str();
    case mobile::serialization::IValueUnion::List:
      return parseList(ivalue->val_as_List());
    case mobile::serialization::IValueUnion::IntList:
      return parseListNative<int64_t>(ivalue->val_as_IntList());
    case mobile::serialization::IValueUnion::DoubleList:
      return parseListNative<double>(ivalue->val_as_DoubleList());
    case mobile::serialization::IValueUnion::BoolList: {
      // Bools are stored as uint8_t on the wire; widen element-by-element
      // into a typed c10::List<bool>.
      std::vector<uint8_t> res =
          parseListNative<uint8_t>(ivalue->val_as_BoolList());
      c10::List<bool> boollist;
      for (auto x : res) {
        boollist.push_back(x);
      }
      return boollist;
    }
    case mobile::serialization::IValueUnion::Tuple:
      return parseTuple(ivalue->val_as_Tuple());
    case mobile::serialization::IValueUnion::Dict:
      return parseDict(ivalue->val_as_Dict());
    case mobile::serialization::IValueUnion::Object: {
      auto val = parseObject(ivalue->val_as_Object());
      return val;
    }
    case mobile::serialization::IValueUnion::Device: {
      // Device is serialized as its string form, e.g. "cpu" / "cuda:0".
      return c10::Device(ivalue->val_as_Device()->str()->str());
    }
    case mobile::serialization::IValueUnion::EnumValue: {
      const auto* enum_val = ivalue->val_as_EnumValue();
      // Resolve (or create) the EnumType from its qualified name, then
      // find the enum member whose underlying value matches.
      auto enum_type = getOrCreateTypeAnnotations(enum_val->type_name())
                           ->cast<c10::EnumType>();
      AT_ASSERT(
          enum_type,
          "Enum with type: " + enum_val->type_name()->str() + " not found.");
      IValue val = getIValue(enum_val->value());
      for (const auto& p : enum_type->enumNamesValues()) {
        if (p.second == val) {
          auto enum_holder = c10::make_intrusive<at::ivalue::EnumHolder>(
              enum_type, p.first, p.second);
          return IValue(std::move(enum_holder));
        }
      }
      // No member of the enum matched the serialized value.
      AT_ASSERT(
          false,
          "Enum with type: " + enum_val->type_name()->str() + " not found.");
    }
    default:
      return {};
  }
}
|
||||||
|
|
||||||
|
// Intentional no-op deleter handed to at::DataPtr in getStorage() below:
// the bytes live inside the flatbuffer buffer, which is owned and freed
// elsewhere, so the storage must not free them itself.
// The forward declaration avoids a missing-prototype warning.
void deleteNothing2(void*);
void deleteNothing2(void*) {}
|
||||||
|
|
||||||
|
// Return the c10::Storage for storage slot `index`, materializing it
// lazily on first access. The storage wraps the flatbuffer's own bytes
// (zero-copy) via a no-op deleter, so the buffer must outlive it.
c10::Storage FlatbufferLoader::getStorage(uint32_t index) {
  TORCH_CHECK(index < storage_loaded_.size());
  TORCH_CHECK(index < storages_.size());
  if (storage_loaded_[index]) {
    return storages_[index];
  }
  // First access: wrap the raw bytes owned by the flatbuffer in a Storage.
  auto* record = module_->storage_data()->GetMutableObject(index);
  size_t nbytes = record->data()->size();
  void* raw = static_cast<void*>(record->mutable_data()->data());
  at::DataPtr data(raw, raw, deleteNothing2, DeviceType::CPU);
  storages_[index] =
      c10::Storage(c10::Storage::use_byte_size_t(), nbytes, std::move(data));
  storage_loaded_[index] = true;
  return storages_[index];
}
|
||||||
|
|
||||||
|
// Resolve a serialized type-annotation string into a TypePtr, memoizing
// the result per flatbuffer string offset.
//  - "__torch__.torch.classes..." names resolve to registered custom
//    (torchbind) classes and must already exist.
//  - other "__torch__." / "torch.jit." names resolve to (or create and
//    register) a ClassType in the loader's CompilationUnit.
//  - everything else is parsed as a plain type expression.
// Fix over original: the CompilationUnit class lookup is done once and
// reused instead of calling cu_->get_class(qn) twice.
TypePtr FlatbufferLoader::getOrCreateTypeAnnotations(
    const flatbuffers::String* offset) {
  auto iter = type_annotations_.find(offset);
  if (iter != type_annotations_.end()) {
    return iter->second;
  }
  TypePtr type;
  c10::string_view qn_str(offset->c_str(), offset->size());
  c10::QualifiedName qn(offset->str());
  if (qn_str.starts_with(kCustomClassPrefix)) {
    type = getCustomClass(qn.qualifiedName());
    TORCH_CHECK(
        type,
        "The implementation of class ",
        qn.qualifiedName(),
        " cannot be found.");
  } else if (
      qn_str.starts_with(kTorchPrefix) || qn_str.starts_with(kJitPrefix)) {
    // Single lookup; create + register the class only if it is missing.
    auto classtype = cu_->get_class(qn);
    if (classtype == nullptr) {
      classtype = ClassType::create(qn, cu_, true);
      cu_->register_type(classtype);
    }
    type = classtype;
  } else {
    type = c10::parseType(qn.qualifiedName());
  }
  type_annotations_[offset] = type;
  return type;
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Deserialize the flatbuffer held in `data` and build a mobile::Module
// from it. Ownership of the raw buffer is handed to the returned module
// (set_delete_memory) so flatbuffer-backed storage stays alive with it.
// The size and device parameters are accepted but currently unused.
mobile::Module parse_and_initialize_mobile_module(
    std::shared_ptr<char> data,
    size_t,
    c10::optional<at::Device>) {
  auto* flatbuffer_module = mobile::serialization::GetMutableModule(data.get());
  FlatbufferLoader loader;
  mobile::Module result = loader.parseModule(flatbuffer_module);
  result.set_delete_memory(std::move(data));
  return result;
}
|
||||||
|
|
||||||
|
// Build a mobile::Module from an already-deserialized flatbuffer Module.
// The caller keeps ownership of `flatbuffer_module`; the device parameter
// is currently unused.
mobile::Module initialize_mobile_module(
    mobile::serialization::Module* flatbuffer_module,
    c10::optional<at::Device>) {
  FlatbufferLoader loader;
  return loader.parseModule(flatbuffer_module);
}
|
||||||
|
|
||||||
|
// Load a flatbuffer-serialized mobile::Module from `filename`.
// On platforms with mmap the file is mapped read-only (zero-copy) and
// unmapped when the module releases the buffer; otherwise the whole file
// is read into a malloc'd buffer.
// Fixes over original: open/fopen, fstat, mmap and fread failures are now
// checked instead of being silently ignored, and the mmap-branch size is
// size_t instead of int (st_size could truncate on >2GB files).
mobile::Module load_mobile_module_from_file(
    const std::string& filename,
    c10::optional<c10::Device> device) {
#if defined(HAVE_MMAP)
  int fd = open(filename.c_str(), O_RDONLY);
  TORCH_CHECK(fd >= 0, "Cannot open file for reading: ", filename);
  struct stat statbuf {};
  TORCH_CHECK(fstat(fd, &statbuf) == 0, "Cannot stat file: ", filename);
  size_t size = statbuf.st_size;
  void* ptr = mmap(nullptr, statbuf.st_size, PROT_READ, MAP_PRIVATE, fd, 0);
  close(fd); // the mapping stays valid after the fd is closed
  TORCH_CHECK(ptr != MAP_FAILED, "Cannot mmap file: ", filename);
  auto deleter = [statbuf](char* ptr) { munmap(ptr, statbuf.st_size); };
  std::shared_ptr<char> data(reinterpret_cast<char*>(ptr), deleter);
#else
  FILE* f = fopen(filename.c_str(), "rb");
  TORCH_CHECK(f != nullptr, "Cannot open file for reading: ", filename);
  fseek(f, 0, SEEK_END);
  long size = ftell(f);
  fseek(f, 0, SEEK_SET);
  std::shared_ptr<char> data(static_cast<char*>(malloc(size)), free); // NOLINT
  size_t nread = fread(data.get(), size, 1, f);
  fclose(f);
  TORCH_CHECK(nread == 1, "Failed to read file: ", filename);
#endif
  return parse_and_initialize_mobile_module(std::move(data), size, device);
}
|
||||||
|
|
||||||
|
} // namespace jit
|
||||||
|
} // namespace torch
|
||||||
54
torch/csrc/jit/mobile/flatbuffer_loader.h
Normal file
54
torch/csrc/jit/mobile/flatbuffer_loader.h
Normal file
|
|
@ -0,0 +1,54 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <ATen/core/ivalue.h>
|
||||||
|
#include <caffe2/serialize/inline_container.h>
|
||||||
|
#include <torch/csrc/jit/mobile/function.h>
|
||||||
|
#include <torch/csrc/jit/mobile/interpreter.h>
|
||||||
|
#include <torch/csrc/jit/mobile/module.h>
|
||||||
|
#include <torch/csrc/jit/runtime/instruction.h>
|
||||||
|
#include <torch/csrc/jit/serialization/mobile_bytecode_generated.h> // NOLINT
|
||||||
|
#include <torch/custom_class.h>
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace torch {
|
||||||
|
namespace jit {
|
||||||
|
|
||||||
|
// On high level, to produce a Module from a file on disk, we need to go
|
||||||
|
// through the follow steps:
|
||||||
|
// 1. Read: Read the file from disk -> memory
|
||||||
|
// 2. Deserialize: Parse the bytes to produce some in memory manipulable
|
||||||
|
// structure
|
||||||
|
// 3. Module initialization: Produce mobile::Module out of the structure
|
||||||
|
// produced in 2.
|
||||||
|
// Under this context, the structure described in 2. is
|
||||||
|
// mobile::serialization::Module
|
||||||
|
|
||||||
|
// Parse a mobile::Module from flatbuffer's in-memory Module representation.
|
||||||
|
// The caller is assumed to manage the lifetimes of Module.
|
||||||
|
// This function does step 3 described above.
|
||||||
|
TORCH_API mobile::Module initialize_mobile_module(
|
||||||
|
mobile::serialization::Module* flatbuffer_module,
|
||||||
|
c10::optional<at::Device> device = c10::nullopt);
|
||||||
|
|
||||||
|
// Parse a mobile::Module from raw bytes.
|
||||||
|
// ownership of data is shared to the returned Module.
|
||||||
|
// (Feel free to pass in a unique_ptr too!)
|
||||||
|
// This function does steps 2+3 described above
|
||||||
|
TORCH_API mobile::Module parse_and_initialize_mobile_module(
|
||||||
|
std::shared_ptr<char> data,
|
||||||
|
size_t size,
|
||||||
|
c10::optional<at::Device> device = c10::nullopt);
|
||||||
|
|
||||||
|
// Load a mobile::Module from a filepath.
|
||||||
|
// This function does steps 1+2+3 described above.
|
||||||
|
// We need to have this as a convenience because Python
|
||||||
|
// API will need to wrap this. C++ clients should use one
|
||||||
|
// versions above.
|
||||||
|
TORCH_API mobile::Module load_mobile_module_from_file(
|
||||||
|
const std::string& filename,
|
||||||
|
c10::optional<at::Device> device = c10::nullopt);
|
||||||
|
|
||||||
|
} // namespace jit
|
||||||
|
} // namespace torch
|
||||||
|
|
@ -130,12 +130,19 @@ class TORCH_API Module {
|
||||||
return *cu_.get();
|
return *cu_.get();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void set_delete_memory(std::shared_ptr<char> delete_mem) {
|
||||||
|
mem_to_delete_ = delete_mem;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
c10::intrusive_ptr<c10::ivalue::Object> object_;
|
c10::intrusive_ptr<c10::ivalue::Object> object_;
|
||||||
std::unordered_map<std::string, std::string> metadata_;
|
std::unordered_map<std::string, std::string> metadata_;
|
||||||
std::shared_ptr<CompilationUnit> cu_;
|
std::shared_ptr<CompilationUnit> cu_;
|
||||||
MobileDebugTable debug_table_;
|
MobileDebugTable debug_table_;
|
||||||
bool has_debug_handles_ = false;
|
bool has_debug_handles_ = false;
|
||||||
|
|
||||||
|
// Extra handle for the module to delete when itself is deleted
|
||||||
|
std::shared_ptr<char> mem_to_delete_;
|
||||||
};
|
};
|
||||||
} // namespace mobile
|
} // namespace mobile
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@
|
||||||
#include <torch/csrc/jit/frontend/ir_emitter.h>
|
#include <torch/csrc/jit/frontend/ir_emitter.h>
|
||||||
#include <torch/csrc/jit/frontend/sugared_value.h>
|
#include <torch/csrc/jit/frontend/sugared_value.h>
|
||||||
#include <torch/csrc/jit/mobile/backport.h>
|
#include <torch/csrc/jit/mobile/backport.h>
|
||||||
|
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
|
||||||
#include <torch/csrc/jit/mobile/import.h>
|
#include <torch/csrc/jit/mobile/import.h>
|
||||||
#include <torch/csrc/jit/mobile/model_compatibility.h>
|
#include <torch/csrc/jit/mobile/model_compatibility.h>
|
||||||
#include <torch/csrc/jit/mobile/module.h>
|
#include <torch/csrc/jit/mobile/module.h>
|
||||||
|
|
@ -31,9 +32,12 @@
|
||||||
#include <torch/csrc/jit/runtime/graph_executor.h>
|
#include <torch/csrc/jit/runtime/graph_executor.h>
|
||||||
#include <torch/csrc/jit/runtime/logging.h>
|
#include <torch/csrc/jit/runtime/logging.h>
|
||||||
#include <torch/csrc/jit/serialization/export.h>
|
#include <torch/csrc/jit/serialization/export.h>
|
||||||
|
#include <torch/csrc/jit/serialization/export_bytecode.h>
|
||||||
|
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
|
||||||
#include <torch/csrc/jit/serialization/import_source.h>
|
#include <torch/csrc/jit/serialization/import_source.h>
|
||||||
#include <torch/csrc/jit/serialization/python_print.h>
|
#include <torch/csrc/jit/serialization/python_print.h>
|
||||||
#include <torch/csrc/jit/testing/hooks_for_testing.h>
|
#include <torch/csrc/jit/testing/hooks_for_testing.h>
|
||||||
|
#include <torch/csrc/jit/testing/module_differ.h>
|
||||||
|
|
||||||
#include <torch/csrc/api/include/torch/ordered_dict.h>
|
#include <torch/csrc/api/include/torch/ordered_dict.h>
|
||||||
|
|
||||||
|
|
@ -47,6 +51,7 @@
|
||||||
#include <pybind11/stl_bind.h>
|
#include <pybind11/stl_bind.h>
|
||||||
#include <chrono>
|
#include <chrono>
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
|
#include <cstdlib>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
@ -1762,18 +1767,42 @@ void initJitScriptBindings(PyObject* module) {
|
||||||
});
|
});
|
||||||
m.def(
|
m.def(
|
||||||
"_load_for_lite_interpreter",
|
"_load_for_lite_interpreter",
|
||||||
[](const std::string& filename, py::object map_location) {
|
[](const std::string& filename,
|
||||||
|
py::object map_location,
|
||||||
|
bool is_flatbuffer = false) {
|
||||||
c10::optional<at::Device> optional_device;
|
c10::optional<at::Device> optional_device;
|
||||||
if (!map_location.is(py::none())) {
|
if (!map_location.is(py::none())) {
|
||||||
AT_ASSERT(THPDevice_Check(map_location.ptr()));
|
AT_ASSERT(THPDevice_Check(map_location.ptr()));
|
||||||
optional_device =
|
optional_device =
|
||||||
reinterpret_cast<THPDevice*>(map_location.ptr())->device;
|
reinterpret_cast<THPDevice*>(map_location.ptr())->device;
|
||||||
}
|
}
|
||||||
|
if (is_flatbuffer) {
|
||||||
|
return load_mobile_module_from_file(filename, optional_device);
|
||||||
|
} else {
|
||||||
return _load_for_mobile(filename, optional_device);
|
return _load_for_mobile(filename, optional_device);
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
m.def(
|
||||||
|
"_save_mobile_module",
|
||||||
|
[](const torch::jit::mobile::Module& m, std::string& filename) {
|
||||||
|
save_mobile_module(m, filename);
|
||||||
|
});
|
||||||
|
m.def(
|
||||||
|
"_module_equals",
|
||||||
|
[](const torch::jit::mobile::Module& lhs,
|
||||||
|
const torch::jit::mobile::Module& rhs) {
|
||||||
|
return moduleEquals(lhs, rhs);
|
||||||
|
});
|
||||||
|
m.def("_jit_module_to_mobile", [](const torch::jit::Module& mod) {
|
||||||
|
CompilationOptions options;
|
||||||
|
return jitModuleToMobile(mod, options);
|
||||||
|
});
|
||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"_load_for_lite_interpreter_from_buffer",
|
"_load_for_lite_interpreter_from_buffer",
|
||||||
[](const std::string& buffer, py::object map_location) {
|
[](const std::string& buffer,
|
||||||
|
py::object map_location,
|
||||||
|
bool is_flatbuffer = false) {
|
||||||
std::istringstream in(buffer);
|
std::istringstream in(buffer);
|
||||||
c10::optional<at::Device> optional_device;
|
c10::optional<at::Device> optional_device;
|
||||||
if (!map_location.is(py::none())) {
|
if (!map_location.is(py::none())) {
|
||||||
|
|
@ -1781,7 +1810,16 @@ void initJitScriptBindings(PyObject* module) {
|
||||||
optional_device =
|
optional_device =
|
||||||
reinterpret_cast<THPDevice*>(map_location.ptr())->device;
|
reinterpret_cast<THPDevice*>(map_location.ptr())->device;
|
||||||
}
|
}
|
||||||
|
if (is_flatbuffer) {
|
||||||
|
size_t size = buffer.size();
|
||||||
|
std::shared_ptr<char> data(
|
||||||
|
static_cast<char*>(malloc(size)), free); // NOLINT
|
||||||
|
memcpy(data.get(), buffer.data(), size); // NOLINT
|
||||||
|
return parse_and_initialize_mobile_module(
|
||||||
|
std::move(data), size, optional_device);
|
||||||
|
} else {
|
||||||
return _load_for_mobile(in, optional_device);
|
return _load_for_mobile(in, optional_device);
|
||||||
|
}
|
||||||
});
|
});
|
||||||
m.def(
|
m.def(
|
||||||
"_backport_for_mobile",
|
"_backport_for_mobile",
|
||||||
|
|
|
||||||
|
|
@ -76,6 +76,7 @@ std::ostream& operator<<(std::ostream& out, Instruction inst);
|
||||||
|
|
||||||
bool isOpSupportedInMobile(OpCode op);
|
bool isOpSupportedInMobile(OpCode op);
|
||||||
char const* toString(OpCode op);
|
char const* toString(OpCode op);
|
||||||
|
std::ostream& operator<<(std::ostream& out, Instruction inst);
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
|
|
||||||
682
torch/csrc/jit/serialization/flatbuffer_serializer.cpp
Normal file
682
torch/csrc/jit/serialization/flatbuffer_serializer.cpp
Normal file
|
|
@ -0,0 +1,682 @@
|
||||||
|
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
|
||||||
|
|
||||||
|
#include <c10/core/CPUAllocator.h>
|
||||||
|
#include <flatbuffers/flatbuffers.h>
|
||||||
|
#include <torch/csrc/jit/mobile/code.h>
|
||||||
|
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
|
||||||
|
#include <torch/csrc/jit/passes/inliner.h>
|
||||||
|
#include <torch/csrc/jit/runtime/instruction.h>
|
||||||
|
#include <torch/csrc/jit/serialization/export.h>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
namespace torch {
|
||||||
|
namespace jit {
|
||||||
|
|
||||||
|
using flatbuffers::FlatBufferBuilder;
|
||||||
|
using mobile::serialization::CreateArg;
|
||||||
|
using mobile::serialization::CreateDebugInfo;
|
||||||
|
using mobile::serialization::CreateDict;
|
||||||
|
using mobile::serialization::CreateFunctionDirect;
|
||||||
|
using mobile::serialization::CreateIValue;
|
||||||
|
using mobile::serialization::CreateList;
|
||||||
|
using mobile::serialization::CreateModule;
|
||||||
|
using mobile::serialization::CreateObject;
|
||||||
|
using mobile::serialization::CreateOperator;
|
||||||
|
using mobile::serialization::CreateTensorMetadataDirect;
|
||||||
|
using mobile::serialization::CreateTupleDirect;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Serializes a mobile::Module into a flatbuffer (DetachedBuffer).
// A single serializer instance accumulates IValues, functions, class
// types and tensor storages into internal tables while walking the
// module, then emits them all in serializeModule().
class FlatbufferSerializer {
 public:
  FlatbufferSerializer() = default;

  // Entry point: walk `module` and produce the finished flatbuffer.
  // When include_tensor_data_in_flatbuffer is true the raw tensor bytes
  // are embedded in the buffer as StorageData records.
  flatbuffers::DetachedBuffer serializeModule(
      const mobile::Module& module,
      bool include_tensor_data_in_flatbuffer);

 private:
  // Serialize every IValue in [begin, end) and return their table indexes.
  template <typename It>
  std::vector<uint32_t> storeIValuesAndGetIndexes(
      flatbuffers::FlatBufferBuilder& fbb,
      It begin,
      It end) {
    std::vector<uint32_t> indexes;
    for (; begin != end; ++begin) {
      indexes.push_back(storeIValueAndGetIndex(fbb, *begin));
    }
    return indexes;
  }

  // Per-kind converters from a runtime IValue to its flatbuffer table.
  flatbuffers::Offset<mobile::serialization::Tuple> tupleToFB(
      flatbuffers::FlatBufferBuilder& fbb,
      const IValue& tuple);

  flatbuffers::Offset<mobile::serialization::List> listToFB(
      flatbuffers::FlatBufferBuilder& fbb,
      const IValue& list);

  flatbuffers::Offset<mobile::serialization::Dict> dictToFB(
      flatbuffers::FlatBufferBuilder& fbb,
      const IValue& list);

  flatbuffers::Offset<mobile::serialization::Object> objectToFB(
      flatbuffers::FlatBufferBuilder& fbb,
      const IValue& ivalue);

  flatbuffers::Offset<mobile::serialization::TensorMetadata> tensorToFB(
      flatbuffers::FlatBufferBuilder& fbb,
      const IValue& ivalue);

  // Serialize one mobile function (bytecode, operators, constants,
  // types, schema, debug info) under qualified name `qn`.
  flatbuffers::Offset<mobile::serialization::Function> functionToFB(
      flatbuffers::FlatBufferBuilder& fbb,
      const std::string& qn,
      const mobile::Function& func);

  flatbuffers::Offset<mobile::serialization::IValue> iValueToFB(
      flatbuffers::FlatBufferBuilder& fbb,
      const IValue& ivalue);

  // Serialize a function schema (argument and return lists).
  flatbuffers::Offset<jit::mobile::serialization::Schema> CreateFBSchema(
      flatbuffers::FlatBufferBuilder& fbb,
      const std::vector<Argument>& args,
      const std::vector<Argument>& returns,
      c10::TypePrinter type_printer);

  flatbuffers::Offset<mobile::serialization::ObjectType> classTypeToFB(
      flatbuffers::FlatBufferBuilder& fbb,
      ClassTypePtr class_ptr);

  // store* helpers serialize once and return a stable table index,
  // memoizing by qualified name / value where applicable.
  uint32_t storeIValueAndGetIndex(
      flatbuffers::FlatBufferBuilder& fbb,
      const IValue& ivalue);
  uint32_t storeFunctionAndGetIndex(
      flatbuffers::FlatBufferBuilder& fbb,
      const std::string& qn,
      const mobile::Function& function);

  uint32_t storeClassTypeAndGetIndex(
      flatbuffers::FlatBufferBuilder& fbb,
      ClassTypePtr class_type);

  // cached stuff
  // Append an already-built IValue offset and return its index.
  uint32_t insertIValue(
      flatbuffers::Offset<mobile::serialization::IValue> ivalue) {
    uint32_t size = ivalue_offsets_.size();
    ivalue_offsets_.push_back(ivalue);
    return size;
  }

  // Tensors whose storage gets written out (optionally) at the end.
  std::vector<at::Tensor> tensor_data_;

  // Storage pointer -> storage index, to deduplicate shared storages.
  std::unordered_map<const void*, uint32_t> memoized_storage_map_;

  // All serialized IValues / object types, in index order.
  std::vector<flatbuffers::Offset<mobile::serialization::IValue>>
      ivalue_offsets_;
  std::vector<flatbuffers::Offset<mobile::serialization::ObjectType>>
      obj_types_offset_;

  // qualified name to serialized class, type or function
  std::unordered_map<std::string, uint32_t> qn_to_serialized_values_;

  // cache of some ivalues
  struct IValueHash {
    size_t operator()(const IValue& val) const {
      return IValue::hash(val);
    }
  };

  std::unordered_map<IValue, uint32_t, IValueHash> cached_ivalues_;
};
|
||||||
|
|
||||||
|
// Serialize a function schema: each Argument (name, printed type, index
// of its serialized default value) in `args` and `returns` becomes an
// Arg table, and the two vectors are wrapped into a Schema table.
// Fix over original: the argument- and return-serialization loops were
// verbatim duplicates; the shared conversion is factored into a lambda.
flatbuffers::Offset<jit::mobile::serialization::Schema> FlatbufferSerializer::
    CreateFBSchema(
        flatbuffers::FlatBufferBuilder& fbb,
        const std::vector<Argument>& args,
        const std::vector<Argument>& returns,
        c10::TypePrinter type_printer) {
  // Convert one c10::Argument into its serialized Arg table.
  auto to_fb_arg = [&](const Argument& arg) {
    int index = storeIValueAndGetIndex(fbb, arg.default_value());
    return CreateArg(
        fbb,
        fbb.CreateSharedString(arg.name()),
        fbb.CreateSharedString(arg.type()->annotation_str(type_printer)),
        index);
  };

  std::vector<flatbuffers::Offset<jit::mobile::serialization::Arg>> arg_vec;
  arg_vec.reserve(args.size());
  for (const auto& arg : args) {
    arg_vec.push_back(to_fb_arg(arg));
  }

  std::vector<flatbuffers::Offset<jit::mobile::serialization::Arg>> return_vec;
  return_vec.reserve(returns.size());
  for (const auto& ret : returns) {
    return_vec.push_back(to_fb_arg(ret));
  }

  return CreateSchema(
      fbb, fbb.CreateVector(arg_vec), fbb.CreateVector(return_vec));
}
|
||||||
|
|
||||||
|
// Serialize one mobile::Function under qualified name `qn`: its bytecode
// instructions, operator table, constant pool (as IValue indexes), type
// annotation strings, register count, optional schema and debug handles
// are packed into a mobile::serialization::Function table.
flatbuffers::Offset<mobile::serialization::Function> FlatbufferSerializer::
    functionToFB(
        FlatBufferBuilder& fbb,
        const std::string& qn,
        const mobile::Function& func) {
  const auto* code = func.get_code().get();

  // instructions: (op, N, X) triples copied verbatim from the bytecode.
  std::vector<mobile::serialization::Instruction> instruction_vector;
  for (const auto& inst : code->instructions_) {
    instruction_vector.emplace_back(inst.op, inst.N, inst.X);
  }

  // operators: name, overload name, and declared input count per op.
  std::vector<flatbuffers::Offset<mobile::serialization::Operator>>
      operator_vector;
  operator_vector.reserve(code->op_names_.size());
  for (int i = 0; i < code->op_names_.size(); ++i) {
    const auto& opname = code->op_names_[i];
    const int op_size = code->operator_input_sizes_[i];
    operator_vector.push_back(CreateOperator(
        fbb,
        fbb.CreateSharedString(opname.name),
        fbb.CreateSharedString(opname.overload_name),
        op_size));
  }

  const auto& constants = code->constants_;

  // constants are stored in the shared IValue table; keep their indexes.
  std::vector<uint32_t> constant_indexes;
  constant_indexes.reserve(constants.size());
  for (const auto& constant : constants) {
    constant_indexes.push_back(storeIValueAndGetIndex(fbb, constant));
  }

  // types: annotation strings; only torchbind custom classes are allowed
  // among __torch__-prefixed types in the lite interpreter.
  static const std::string torch_prefix("__torch__");
  static const std::string class_prefix("__torch__.torch.classes");
  std::vector<flatbuffers::Offset<flatbuffers::String>> type_offsets;

  for (const TypePtr& t : code->types_) {
    auto type_str = t->annotation_str();
    if (type_str.find(torch_prefix) == 0) {
      TORCH_CHECK(
          type_str.find(class_prefix) == 0,
          "__torch__ types other than torchbind (__torch__.torch.classes)"
          "are not supported in lite interpreter. ",
          "Workaround: instead of using arbitrary class type (class Foo()), ",
          "define a pytorch class (class Foo(torch.nn.Module)).");
    }

    type_offsets.push_back(fbb.CreateSharedString(type_str));
  }

  // since the register location is embedded into the bytecode, pass the
  // register size
  auto register_size = static_cast<int>(code->register_size_);

  // schema: print named types by their qualified name rather than by
  // expanding their structure.
  auto type_printer =
      [&](const c10::ConstTypePtr& t) -> c10::optional<std::string> {
    auto namedType = t->cast<c10::NamedType>();
    if (namedType && namedType->name()) {
      return namedType->name().value().qualifiedName();
    }
    return c10::nullopt;
  };

  // Schema is optional; offset 0 means "no schema".
  flatbuffers::Offset<mobile::serialization::Schema> schema_offset = 0;
  if (func.hasSchema()) {
    const auto& schema = func.getSchema();
    TORCH_CHECK(
        schema.overload_name().empty(), // @TODO: is this check correct?
        "Overloads are not supported in mobile modules.");
    TORCH_CHECK(
        !schema.is_vararg(),
        "Python *args are not supported in mobile modules.");
    TORCH_CHECK(
        !schema.is_varret(),
        "A variable number of return values is not supported in mobile modules.");
    schema_offset =
        CreateFBSchema(fbb, schema.arguments(), schema.returns(), type_printer);
  }

  auto debug_info_offset =
      CreateDebugInfo(fbb, fbb.CreateVector(code->debug_handles_));

  // auto classtype = schema.arguments()[0].type()->cast<ClassType>();
  // uint32_t class_type = storeClassTypeAndGetIndex(fbb, classtype);

  auto function_offset = CreateFunctionDirect(
      fbb,
      qn.c_str(),
      &instruction_vector,
      &operator_vector,
      &constant_indexes,
      &type_offsets,
      register_size,
      schema_offset,
      debug_info_offset,
      0);
  return function_offset;
}
|
||||||
|
|
||||||
|
// Top-level serialization: walk `module`, serialize its methods and its
// root object ivalue into the shared tables, optionally embed the raw
// tensor storage bytes, and finish the flatbuffer.
flatbuffers::DetachedBuffer FlatbufferSerializer::serializeModule(
    const mobile::Module& module,
    bool include_tensor_data_in_flatbuffer) {
  FlatBufferBuilder fbb;

  // first element is None.
  insertIValue(CreateIValue(fbb, mobile::serialization::IValueUnion::NONE, 0));

  // Serialize every method; storeFunctionAndGetIndex deduplicates by
  // qualified name.
  auto methods = module.get_methods();
  std::vector<uint32_t> functions_index;
  functions_index.reserve(methods.size());
  for (const auto& method : methods) {
    auto func_offset = storeFunctionAndGetIndex(
        fbb, method.function().qualname().qualifiedName(), method.function());
    functions_index.push_back(func_offset);
  }

  auto functions_offset = fbb.CreateVector(functions_index);
  // The module's root object (attributes, submodules) as an IValue index.
  uint32_t ivalue_index = storeIValueAndGetIndex(fbb, module._ivalue());

  // Storage payload is optional; offset 0 means "no embedded tensor data".
  flatbuffers::Offset<flatbuffers::Vector<
      flatbuffers::Offset<mobile::serialization::StorageData>>>
      storage_data_offset = 0;
  if (include_tensor_data_in_flatbuffer) {
    std::vector<flatbuffers::Offset<mobile::serialization::StorageData>>
        storage_data;
    for (auto td : tensor_data_) {
      // Non-CPU storages are first copied to CPU as a flat 1-D view of
      // the whole underlying storage.
      if (td.storage().device_type() != DeviceType::CPU) {
        td = at::empty({0}, td.options())
                 .set_(
                     td.storage(),
                     /* storage_offset = */ 0,
                     /* size = */
                     {static_cast<int64_t>(
                         td.storage().nbytes() / td.element_size())},
                     /* stride = */ {1})
                 .cpu();
      }
      // Align the byte vector so the data can be used in place on load.
      fbb.ForceVectorAlignment(
          td.storage().nbytes(), sizeof(uint8_t), FLATBUFFERS_MAX_ALIGNMENT);
      auto storage_offset = mobile::serialization::CreateStorageData(
          fbb,
          fbb.CreateVector(
              reinterpret_cast<const uint8_t*>(td.storage().data()),
              td.storage().nbytes()));
      storage_data.push_back(storage_offset);
    }
    storage_data_offset = fbb.CreateVector(storage_data);
  }

  auto mod = CreateModule(
      fbb,
      0, /* version */
      0, /* extra_files */
      functions_offset,
      ivalue_index,
      fbb.CreateVector(ivalue_offsets_),
      tensor_data_.size(),
      storage_data_offset,
      fbb.CreateVector(obj_types_offset_));
  fbb.Finish(mod);
  return fbb.Release();
}
|
||||||
|
|
||||||
|
// Serialize a tuple IValue: each element is stored in the shared IValue
// table and the tuple records only their indexes.
flatbuffers::Offset<mobile::serialization::Tuple> FlatbufferSerializer::
    tupleToFB(flatbuffers::FlatBufferBuilder& fbb, const IValue& tuple) {
  const auto& elems = tuple.toTuple()->elements();
  std::vector<uint32_t> element_indexes =
      storeIValuesAndGetIndexes(fbb, elems.begin(), elems.end());
  return CreateTupleDirect(fbb, &element_indexes);
}
|
||||||
|
|
||||||
|
// Serialize a generic list IValue: element indexes into the shared
// IValue table plus the list's type annotation string.
flatbuffers::Offset<mobile::serialization::List> FlatbufferSerializer::listToFB(
    flatbuffers::FlatBufferBuilder& fbb,
    const IValue& list) {
  const auto& elems = list.toList();
  std::vector<uint32_t> element_indexes =
      storeIValuesAndGetIndexes(fbb, elems.begin(), elems.end());
  auto items_offset = fbb.CreateVector(element_indexes);
  auto annotation_offset =
      fbb.CreateSharedString(list.type()->annotation_str());
  return CreateList(fbb, items_offset, annotation_offset);
}
|
||||||
|
|
||||||
|
// Serialize a dict IValue as two parallel index vectors (keys, values)
// into the shared IValue table, plus the dict's type annotation string.
flatbuffers::Offset<mobile::serialization::Dict> FlatbufferSerializer::dictToFB(
    flatbuffers::FlatBufferBuilder& fbb,
    const IValue& ivalue) {
  const auto& dict = ivalue.toGenericDict();
  std::vector<uint32_t> key_indexes;
  std::vector<uint32_t> value_indexes;
  key_indexes.reserve(dict.size());
  value_indexes.reserve(dict.size());
  for (const auto& entry : dict) {
    key_indexes.push_back(storeIValueAndGetIndex(fbb, entry.key()));
    value_indexes.push_back(storeIValueAndGetIndex(fbb, entry.value()));
  }
  return CreateDict(
      fbb,
      fbb.CreateVector(key_indexes),
      fbb.CreateVector(value_indexes),
      fbb.CreateSharedString(ivalue.type()->annotation_str()));
}
|
||||||
|
|
||||||
|
// Serialize a ClassType into an ObjectType record, classifying how its
// instances are (de)serialized:
//  - CLASS_WITH_SETSTATE: the class has a mobile __setstate__ (either a
//    found method that is a mobile::Function, or one already serialized
//    under "<qn>.__setstate__").
//  - CUSTOM_CLASS: __setstate__ exists but is not a mobile::Function
//    (a torchbind custom class).
//  - CLASS_WITH_FIELD: no __setstate__; attribute names are recorded so
//    instances serialize as plain attribute lists.
// Fix over original: removed the unused local `std::vector<uint32_t>
// type_index` and added a reserve() for the attribute-name vector.
flatbuffers::Offset<mobile::serialization::ObjectType> FlatbufferSerializer::
    classTypeToFB(FlatBufferBuilder& fbb, ClassTypePtr class_ptr) {
  mobile::serialization::TypeType typetype =
      mobile::serialization::TypeType::UNSET;

  flatbuffers::Offset<
      flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>>
      names_offset = 0;
  Function* setstate = class_ptr->findMethod("__setstate__");
  if (setstate == nullptr) {
    const std::string setstate_qn =
        class_ptr->name()->qualifiedName() + ".__setstate__";
    if (qn_to_serialized_values_.find(setstate_qn) !=
        qn_to_serialized_values_.end()) {
      typetype = mobile::serialization::TypeType::CLASS_WITH_SETSTATE;
    } else {
      // Plain attribute-bag class: record the attribute names in order.
      size_t num_attr = class_ptr->numAttributes();
      std::vector<flatbuffers::Offset<flatbuffers::String>> names;
      names.reserve(num_attr);
      for (size_t i = 0; i < num_attr; ++i) {
        names.push_back(fbb.CreateSharedString(class_ptr->getAttributeName(i)));
      }
      names_offset = fbb.CreateVector(names);
      typetype = mobile::serialization::TypeType::CLASS_WITH_FIELD;
    }
  } else {
    auto* mobile_func = dynamic_cast<mobile::Function*>(setstate);
    if (mobile_func != nullptr) {
      typetype = mobile::serialization::TypeType::CLASS_WITH_SETSTATE;
    } else {
      typetype = mobile::serialization::TypeType::CUSTOM_CLASS;
    }
  }

  auto name_offset = fbb.CreateString(class_ptr->name()->qualifiedName());
  return CreateObjectType(fbb, name_offset, typetype, names_offset);
}
|
||||||
|
|
||||||
|
uint32_t FlatbufferSerializer::storeFunctionAndGetIndex(
    flatbuffers::FlatBufferBuilder& fbb,
    const std::string& qn,
    const mobile::Function& function) {
  // Memoized by qualified name: each function is serialized at most once.
  const auto cached = qn_to_serialized_values_.find(qn);
  if (cached != qn_to_serialized_values_.end()) {
    return cached->second;
  }

  // Wrap the serialized function in an IValue union entry so it lives in the
  // common ivalue table like every other value.
  auto fn_ivalue = CreateIValue(
      fbb,
      mobile::serialization::IValueUnion::Function,
      functionToFB(fbb, qn, function).Union());

  const uint32_t index = insertIValue(fn_ivalue);
  qn_to_serialized_values_[qn] = index;
  return index;
}
|
||||||
|
|
||||||
|
uint32_t FlatbufferSerializer::storeClassTypeAndGetIndex(
    FlatBufferBuilder& fbb,
    ClassTypePtr class_ptr) {
  // Memoized by qualified type name; class types live in their own table
  // (obj_types_offset_), separate from the ivalue table.
  const auto& qualified_name = class_ptr->name()->qualifiedName();
  const auto cached = qn_to_serialized_values_.find(qualified_name);
  if (cached != qn_to_serialized_values_.end()) {
    return cached->second;
  }

  const uint32_t index = obj_types_offset_.size();
  obj_types_offset_.push_back(classTypeToFB(fbb, class_ptr));
  qn_to_serialized_values_[qualified_name] = index;
  return index;
}
|
||||||
|
|
||||||
|
flatbuffers::Offset<mobile::serialization::Object> FlatbufferSerializer::
    objectToFB(flatbuffers::FlatBufferBuilder& fbb, const IValue& ivalue) {
  // An object is stored either as a __getstate__ snapshot (plus the index of
  // its already-serialized __setstate__ function) or, when the class has no
  // getstate/setstate pair, as a vector of attribute-slot indices.
  auto obj = ivalue.toObject();
  auto type = obj->type();

  flatbuffers::Offset<flatbuffers::Vector<uint32_t>> attr_vec = 0;
  uint32_t state_index = 0;
  uint32_t setstate_func_index = 0;
  const auto setstate_qn = type->name()->qualifiedName() + ".__setstate__";
  auto getstate = type->findMethod("__getstate__");
  auto setstate = type->findMethod("__setstate__");
  if (getstate && setstate) {
    // Pickle-style: serialize whatever __getstate__ returns.
    auto state = (*getstate)({obj});
    state_index = storeIValueAndGetIndex(fbb, state);
    // The setstate function index stays 0 when the function has not been
    // serialized yet (resolved by qualified name at load time).
    auto it = qn_to_serialized_values_.find(setstate_qn);
    if (it != qn_to_serialized_values_.end()) {
      setstate_func_index = it->second;
    }
  } else {
    // Slot-style: serialize every attribute in declaration order.
    const size_t num_attr = type->numAttributes();
    std::vector<uint32_t> slot_indices;
    for (size_t i = 0; i < num_attr; ++i) {
      slot_indices.push_back(storeIValueAndGetIndex(fbb, obj->getSlot(i)));
    }
    attr_vec = fbb.CreateVector(slot_indices);
  }

  uint32_t type_index = storeClassTypeAndGetIndex(fbb, type);
  return CreateObject(
      fbb, type_index, state_index, attr_vec, setstate_func_index);
}
|
||||||
|
|
||||||
|
// Serializes tensor *metadata* (shape, strides, dtype, storage index, and
// quantization parameters); the raw storage bytes are collected separately
// through tensor_data_ and written later.
// Fix: the original definition doubled the class qualifier
// ("FlatbufferSerializer::FlatbufferSerializer::tensorToFB"); the redundant
// injected-class-name is dropped here.
flatbuffers::Offset<mobile::serialization::TensorMetadata> FlatbufferSerializer::
    tensorToFB(flatbuffers::FlatBufferBuilder& fbb, const IValue& ivalue) {
  auto& tensor = ivalue.toTensor();
  bool quantized = tensor.is_quantized();
  const at::Storage& storage = tensor.storage();

  flatbuffers::Offset<mobile::serialization::QuantizedSchema> qschema_offset =
      0;
  if (quantized) {
    double scale = 0;
    int32_t zero_point = 0;
    flatbuffers::Offset<mobile::serialization::TensorMetadata> scales = 0;
    flatbuffers::Offset<mobile::serialization::TensorMetadata> zero_points = 0;
    int32_t axis = 0;

    switch (tensor.qscheme()) {
      case at::kPerTensorAffine:
        scale = tensor.q_scale();
        zero_point = tensor.q_zero_point();
        break;
      case at::kPerChannelAffineFloatQParams:
      case at::kPerChannelAffine: {
        // Per-channel parameters are themselves tensors; serialize them
        // recursively as nested TensorMetadata entries.
        scales = tensorToFB(fbb, tensor.q_per_channel_scales());
        zero_points = tensorToFB(fbb, tensor.q_per_channel_zero_points());
        axis = tensor.q_per_channel_axis();
      } break;
      default:
        TORCH_CHECK(
            false,
            "Unsupported tensor quantization type in serialization ",
            toString(tensor.qscheme()));
        break;
    }

    qschema_offset = mobile::serialization::CreateQuantizedSchema(
        fbb,
        static_cast<int8_t>(tensor.qscheme()),
        scale,
        zero_point,
        scales,
        zero_points,
        axis);
  }

  // Deduplicate storages by StorageImpl address so aliasing tensors (views)
  // share a single data blob in the output.
  void* addr = storage.unsafeGetStorageImpl();
  uint32_t storage_index = 0;
  auto it = memoized_storage_map_.find(addr);
  if (it != memoized_storage_map_.end()) {
    storage_index = it->second;
  } else {
    storage_index = tensor_data_.size();
    memoized_storage_map_[addr] = storage_index;
    tensor_data_.push_back(tensor);
  }

  std::vector<int> sizes{tensor.sizes().begin(), tensor.sizes().end()};
  std::vector<int> strides{tensor.strides().begin(), tensor.strides().end()};

  return CreateTensorMetadataDirect(
      fbb,
      /* storage_location_index */ storage_index,
      /* scalar_type */ static_cast<int8_t>(tensor.scalar_type()),
      /* int32_t storage_offset */ tensor.storage_offset(),
      /* sizes */ &sizes,
      /* strides */ &strides,
      /* bool requires_grad */ tensor.requires_grad(),
      /* qschema */ qschema_offset);
}
|
||||||
|
|
||||||
|
uint32_t FlatbufferSerializer::storeIValueAndGetIndex(
    flatbuffers::FlatBufferBuilder& fbb,
    const IValue& ivalue) {
  // Index 0 is reserved for None.
  if (ivalue.isNone()) {
    return 0;
  }

  // Hashing an IValue can throw for unhashable kinds; in that case skip
  // memoization and fall through to serialize a fresh copy.
  try {
    const auto memo = cached_ivalues_.find(ivalue);
    if (memo != cached_ivalues_.end()) {
      return memo->second;
    }
  } catch (const std::runtime_error&) {
  } catch (const c10::Error&) {
  }

  const uint32_t index = insertIValue(iValueToFB(fbb, ivalue));
  try {
    cached_ivalues_[ivalue] = index;
  } catch (const std::runtime_error&) {
  } catch (const c10::Error&) {
  }

  return index;
}
|
||||||
|
|
||||||
|
flatbuffers::Offset<mobile::serialization::IValue> FlatbufferSerializer::
    iValueToFB(flatbuffers::FlatBufferBuilder& fbb, const IValue& ivalue) {
  using mobile::serialization::IValueUnion;

  // Dispatch on the runtime tag: `tag` records which union member `payload`
  // holds. The test order matters — the typed lists (int/double/bool) must be
  // checked before the generic isList() case, which would also match them.
  IValueUnion tag = IValueUnion::NONE;
  flatbuffers::Offset<void> payload = 0;

  if (ivalue.isTensor()) {
    tag = IValueUnion::TensorMetadata;
    payload = tensorToFB(fbb, ivalue).Union();
  } else if (ivalue.isTuple()) {
    tag = IValueUnion::Tuple;
    payload = tupleToFB(fbb, ivalue).Union();
  } else if (ivalue.isDouble()) {
    tag = IValueUnion::Double;
    payload = fbb.CreateStruct(mobile::serialization::Double(ivalue.toDouble()))
                  .Union();
  } else if (ivalue.isComplexDouble()) {
    auto c = ivalue.toComplexDouble();
    tag = IValueUnion::ComplexDouble;
    payload =
        fbb.CreateStruct(
               mobile::serialization::ComplexDouble(c.real(), c.imag()))
            .Union();
  } else if (ivalue.isInt()) {
    tag = IValueUnion::Int;
    payload =
        fbb.CreateStruct(mobile::serialization::Int(ivalue.toInt())).Union();
  } else if (ivalue.isBool()) {
    tag = IValueUnion::Bool;
    payload =
        fbb.CreateStruct(mobile::serialization::Bool(ivalue.toBool())).Union();
  } else if (ivalue.isString()) {
    tag = IValueUnion::String;
    payload = mobile::serialization::CreateString(
                  fbb, fbb.CreateSharedString(ivalue.toString()->string()))
                  .Union();
  } else if (ivalue.isGenericDict()) {
    tag = IValueUnion::Dict;
    payload = dictToFB(fbb, ivalue).Union();
  } else if (ivalue.isNone()) {
    tag = IValueUnion::NONE;
    payload = 0;
  } else if (ivalue.isIntList()) {
    tag = IValueUnion::IntList;
    payload = mobile::serialization::CreateIntList(
                  fbb, fbb.CreateVector(ivalue.toIntVector()))
                  .Union();
  } else if (ivalue.isDoubleList()) {
    tag = IValueUnion::DoubleList;
    payload = mobile::serialization::CreateDoubleList(
                  fbb, fbb.CreateVector(ivalue.toDoubleVector()))
                  .Union();
  } else if (ivalue.isBoolList()) {
    tag = IValueUnion::BoolList;
    auto bools = ivalue.toBoolList();
    // flatbuffers stores bool vectors as bytes; widen explicitly.
    std::vector<uint8_t> bytes(bools.begin(), bools.end());
    payload = mobile::serialization::CreateBoolListDirect(fbb, &bytes).Union();
  } else if (ivalue.isList()) {
    tag = IValueUnion::List;
    payload = listToFB(fbb, ivalue).Union();
  } else if (ivalue.isObject()) {
    tag = IValueUnion::Object;
    payload = objectToFB(fbb, ivalue).Union();
  } else if (ivalue.isDevice()) {
    tag = IValueUnion::Device;
    payload = mobile::serialization::CreateDevice(
                  fbb, fbb.CreateSharedString(ivalue.toDevice().str()))
                  .Union();
  } else if (ivalue.isEnum()) {
    const auto& enum_holder = ivalue.toEnumHolder();
    const auto& qualified_class_name =
        enum_holder->type()->qualifiedClassName();
    uint32_t val_index = storeIValueAndGetIndex(fbb, enum_holder->value());
    tag = IValueUnion::EnumValue;
    payload = mobile::serialization::CreateEnumValue(
                  fbb,
                  fbb.CreateSharedString(qualified_class_name.qualifiedName()),
                  val_index)
                  .Union();
  } else {
    AT_ERROR("Unknown IValue type for pickling: ", ivalue.tagKind());
  }
  return CreateIValue(fbb, tag, payload);
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void save_mobile_module(
|
||||||
|
const mobile::Module& module,
|
||||||
|
const std::string& filename) {
|
||||||
|
FlatbufferSerializer fb_serializer;
|
||||||
|
auto buffer = fb_serializer.serializeModule(module, true);
|
||||||
|
std::fstream ofile(filename, std::ios::binary | std::ios::out);
|
||||||
|
ofile.write(reinterpret_cast<char*>(buffer.data()), buffer.size());
|
||||||
|
ofile.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serializes `module` to an in-memory flatbuffer; the caller owns the
// returned detached buffer.
flatbuffers::DetachedBuffer save_mobile_module_to_bytes(
    const mobile::Module& module) {
  FlatbufferSerializer serializer;
  return serializer.serializeModule(module, /*include_tensor_data_in_flatbuffer=*/true);
}
|
||||||
|
|
||||||
|
} // namespace jit
|
||||||
|
} // namespace torch
|
||||||
26
torch/csrc/jit/serialization/flatbuffer_serializer.h
Normal file
26
torch/csrc/jit/serialization/flatbuffer_serializer.h
Normal file
|
|
@ -0,0 +1,26 @@
|
||||||
|
#pragma once

#include <ATen/core/qualified_name.h>
#include <flatbuffers/flatbuffers.h>
#include <string>
#include <vector>

#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>
#include <torch/csrc/jit/backends/backend_debug_handler.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/serialization/type_name_uniquer.h>

#include <torch/csrc/jit/serialization/mobile_bytecode_generated.h> // NOLINT

namespace torch {
namespace jit {

// Serializes a mobile::Module to a flatbuffer and writes it to `filename`.
TORCH_API void save_mobile_module(
    const mobile::Module& module,
    const std::string& filename);
// Serializes a mobile::Module to an in-memory flatbuffer; the caller owns
// the returned detached buffer.
TORCH_API flatbuffers::DetachedBuffer save_mobile_module_to_bytes(
    const mobile::Module& module);

} // namespace jit
} // namespace torch
||||||
197
torch/csrc/jit/serialization/mobile_bytecode.fbs
Normal file
197
torch/csrc/jit/serialization/mobile_bytecode.fbs
Normal file
|
|
@ -0,0 +1,197 @@
|
||||||
|
namespace torch.jit.mobile.serialization;

// Fixed-size scalar wrappers (structs are stored inline in the union).
struct Int {
  int_val:long;
}

struct Bool {
  bool_val:bool;
}

struct Double{
  double_val:double;
}

struct PerTensorAffineSchema {
  q_scale:double;
  q_zero_point:int;
}

// Quantization parameters for a quantized tensor; either the scalar
// scale/zero_point pair (per-tensor) or the scales/zero_points tensors plus
// axis (per-channel) are populated, depending on qscheme.
table QuantizedSchema {
  qscheme:byte;
  scale:double;
  zero_point:int;
  scales:TensorMetadata;
  zero_points:TensorMetadata;
  axis:int;
}

table TensorMetadata {
  // torch._utils _rebuild_tensor_v2
  storage_location_index: uint;
  // enum ScalarType
  scalar_type: byte;
  storage_offset: int;
  sizes:[int];
  strides:[int];
  requires_grad:bool;

  // only set for quantized tensors
  quantized_schema:QuantizedSchema;
}

table String {
  data: string;
}

table Device {
  str:string;
}

// Generic list; elements are indices into Module.ivalues.
table List {
  items: [uint];
  annotation_str: string; // to recover key/val type
}

table IntList {
  items: [long];
}

table DoubleList {
  items: [double];
}

table BoolList {
  items: [bool];
}

table Tuple {
  items: [uint];
}

// keys[i] pairs with values[i]; both are indices into Module.ivalues.
table Dict {
  keys: [uint];
  values: [uint];
  annotation_str: string; // to recover key/val type
}

enum TypeType : ubyte {
  UNSET,
  CLASS_WITH_FIELD,
  CUSTOM_CLASS,
  CLASS_WITH_SETSTATE,
  NON_OBJ,
}

table ObjectType {
  type_name:string;
  type: TypeType;
  // Below fields are optional
  attr_names:[string];
}

// Instance data: either `state` + `setstate_func` (pickle-style) or `attrs`
// (slot-style), matching the class's TypeType.
table Object {
  type_index: uint;
  state: uint;
  attrs: [uint];
  setstate_func: uint;
}

struct ComplexDouble {
  real:double;
  imag:double;
}

table EnumValue {
  type_name:string;
  value:uint; // index to ivalues;
}

struct Instruction {
  // Should op be enum instead?
  op:byte;
  n:ushort;
  x:int;
}

table Operator {
  name:string;
  overload_name:string;
  num_args_serialized:int = -1;
}

table Arg {
  name:string;
  // Why do we use string to represent types
  // rather than index into Code.types?
  type: string;
  default_value:uint; // position into ivalues
}

table Schema {
  arguments:[Arg];
  returns:[Arg];
}

table DebugInfo {
  debug_handle:[long];
}

// Bytecode for one mobile method.
table Function {
  qn:string;
  instructions:[Instruction];
  operators:[Operator];
  constants:[uint]; // index to ivalue
  type_annotations: [string];
  register_size:int;
  schema:Schema;
  debug_info:DebugInfo;
  class_type:uint; // index into type table
}

table StorageData {
  data: [ubyte] (force_align: 16);
}

// Is it needed to represent other types?
union IValueUnion {
  Int,
  Bool,
  Double,
  ComplexDouble,
  TensorMetadata,
  String,
  List,
  Tuple,
  Dict,
  Object,
  IntList,
  DoubleList,
  BoolList,
  Device,
  EnumValue,
  Function,
}

table IValue {
  val: IValueUnion;
}

table ExtraFile {
  name: string;
  content: string;
}

// Top-level container: the ivalue table plus storages, class types, and the
// indices of the module's methods and root state object.
table Module {
  version:int;
  extra_files:[ExtraFile];
  methods:[uint]; // index to ivalues
  state_obj: uint; // index to ivalues
  ivalues: [IValue];
  storage_data_size:int; // number of storage data;
  storage_data: [StorageData];
  object_types: [ObjectType];
}

root_type Module;
||||||
348
torch/csrc/jit/testing/module_differ.cpp
Normal file
348
torch/csrc/jit/testing/module_differ.cpp
Normal file
|
|
@ -0,0 +1,348 @@
|
||||||
|
#include <torch/csrc/jit/testing/module_differ.h>
|
||||||
|
|
||||||
|
#include <torch/csrc/jit/mobile/interpreter.h>
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace torch {
|
||||||
|
namespace jit {
|
||||||
|
|
||||||
|
// Compares two IValue ranges element-wise with ivalueEquals.
// Fixes over the original:
//  - returned true when one range was a strict prefix of the other (the loop
//    stopped at the shorter range and length was never checked);
//  - printed the "element differs" diagnostic even when `print` was false.
template <typename It>
bool ivalueListEquals(
    It lbegin,
    It lend,
    It rbegin,
    It rend,
    bool print,
    int print_indent) {
  int i = 0;
  const std::string indent(print_indent, '\t');
  for (; lbegin != lend && rbegin != rend; ++lbegin, ++rbegin, ++i) {
    if (!ivalueEquals(*lbegin, *rbegin, print, print_indent + 1)) {
      if (print) {
        std::cout << indent << "list element differs at position " << i
                  << std::endl;
      }
      return false;
    }
  }
  if (lbegin != lend || rbegin != rend) {
    if (print) {
      std::cout << indent << "lists have different lengths" << std::endl;
    }
    return false;
  }
  return true;
}
|
||||||
|
|
||||||
|
bool ivalueEquals(
|
||||||
|
const IValue& lhs,
|
||||||
|
const IValue& rhs,
|
||||||
|
bool print,
|
||||||
|
int print_indent) {
|
||||||
|
const std::string indent(print_indent, '\t');
|
||||||
|
if (lhs.tagKind() != rhs.tagKind()) {
|
||||||
|
if (print) {
|
||||||
|
std::cout << indent << "lhs is type: " << lhs.tagKind()
|
||||||
|
<< "rhs is type: " << rhs.tagKind() << std::endl;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (lhs.isCapsule()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (lhs.isDouble() || lhs.isComplexDouble() || lhs.isInt() || lhs.isBool() ||
|
||||||
|
lhs.isString() || lhs.isDevice() || lhs.isCapsule() || lhs.isRRef() ||
|
||||||
|
lhs.isEnum() || lhs.isIntList() || lhs.isDoubleList() ||
|
||||||
|
lhs.isBoolList() || lhs.isNone()) {
|
||||||
|
// operator == should do what we want
|
||||||
|
if (lhs != rhs) {
|
||||||
|
if (print) {
|
||||||
|
std::cout << indent << "lhs is " << lhs << " || rhs is " << rhs
|
||||||
|
<< std::endl;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (lhs.isTensor()) {
|
||||||
|
const auto& lt = lhs.toTensor();
|
||||||
|
const auto& rt = rhs.toTensor();
|
||||||
|
std::stringstream lsize;
|
||||||
|
std::stringstream rsize;
|
||||||
|
for (const auto x : lt.sizes()) {
|
||||||
|
lsize << x << ",";
|
||||||
|
}
|
||||||
|
for (const auto x : rt.sizes()) {
|
||||||
|
rsize << x << ",";
|
||||||
|
}
|
||||||
|
if (lsize.str() != lsize.str()) {
|
||||||
|
if (print) {
|
||||||
|
std::cout << indent << "left tensor is of shape " << lsize.str()
|
||||||
|
<< "but right tensor is of shape " << rsize.str()
|
||||||
|
<< std::endl;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (lt.allclose(rt)) {
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
if (print) {
|
||||||
|
std::cout << indent << "rhs and lhs has are not close" << std::endl;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (lhs.isGenericDict()) {
|
||||||
|
const auto& ldict = lhs.toGenericDict();
|
||||||
|
const auto& rdict = rhs.toGenericDict();
|
||||||
|
if (ldict.size() != rdict.size()) {
|
||||||
|
if (print) {
|
||||||
|
std::cout << indent << "lhs and rhs are dicts of different sizes: "
|
||||||
|
<< ldict.size() << " vs. " << rdict.size() << std::endl;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const auto& kv : ldict) {
|
||||||
|
auto rhs_iter = rdict.find(kv.key());
|
||||||
|
if (rhs_iter == rdict.end()) {
|
||||||
|
if (print) {
|
||||||
|
std::cout << indent << "rhs missing key: " << kv.key() << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!ivalueEquals(
|
||||||
|
kv.value(), rhs_iter->value(), print, print_indent + 1)) {
|
||||||
|
if (print) {
|
||||||
|
std::cout << indent << "for key: " << kv.key() << " value differs."
|
||||||
|
<< std::endl;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
} else if (lhs.isTensorList() || lhs.isList()) {
|
||||||
|
const auto& vec = lhs.toList();
|
||||||
|
const auto& rvec = rhs.toList();
|
||||||
|
return ivalueListEquals(
|
||||||
|
vec.begin(), vec.end(), rvec.begin(), rvec.end(), print, print_indent);
|
||||||
|
} else if (lhs.isTuple()) {
|
||||||
|
const auto vec = lhs.toTuple()->elements();
|
||||||
|
const auto rvec = rhs.toTuple()->elements();
|
||||||
|
return ivalueListEquals(
|
||||||
|
vec.begin(), vec.end(), rvec.begin(), rvec.end(), print, print_indent);
|
||||||
|
} else if (lhs.isObject()) {
|
||||||
|
auto lobj = lhs.toObject();
|
||||||
|
auto robj = rhs.toObject();
|
||||||
|
auto ltype = lobj->type();
|
||||||
|
auto rtype = robj->type();
|
||||||
|
|
||||||
|
if (ltype->name() != rtype->name()) {
|
||||||
|
if (print) {
|
||||||
|
std::cerr << indent << "left object is of type: "
|
||||||
|
<< ltype->name()->qualifiedName()
|
||||||
|
<< " but right obj is of type: "
|
||||||
|
<< rtype->name()->qualifiedName() << std::endl;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto getstate = ltype->findMethod("__getstate__");
|
||||||
|
if (getstate != nullptr) {
|
||||||
|
return ivalueEquals(
|
||||||
|
(*getstate)({lobj}), (*getstate)({robj}), print, print_indent + 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < ltype->numAttributes(); i++) {
|
||||||
|
if (!ivalueEquals(
|
||||||
|
lobj->getSlot(i), robj->getSlot(i), print, print_indent + 1)) {
|
||||||
|
std::cout << "attribute differs at position " << i << std::endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
std::cerr << " I am here and should not be: " << rhs.tagKind() << std::endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Element-wise comparison of two vectors: `comparator` decides equality and
// `printer` renders an element for the diagnostic emitted when `print` is set.
template <typename T, typename COMP, typename PRINTER>
bool vectorEqual(
    const std::vector<T>& lhs,
    const std::vector<T>& rhs,
    bool print,
    COMP comparator,
    PRINTER printer) {
  if (lhs.size() != rhs.size()) {
    if (print) {
      std::cout << "lhs and rhs has different size: " << lhs.size() << "vs. "
                << rhs.size() << std::endl;
    }
    return false;
  }

  for (size_t idx = 0; idx < lhs.size(); ++idx) {
    if (comparator(lhs[idx], rhs[idx])) {
      continue;
    }
    if (print) {
      std::cout << idx << "th element of lhs and rhs differs \n lhs is "
                << printer(lhs[idx]) << " rhs is " << printer(rhs[idx])
                << std::endl;
    }
    return false;
  }

  return true;
}
|
||||||
|
|
||||||
|
// Compares two mobile functions section by section (instructions, constants,
// operators, register size, debug handles, types, schema) and reports the
// first section that differs.
// NOTE(review): the per-section " pass."/" fail." lines and the vectorEqual
// calls pass `true` unconditionally, so this always prints regardless of the
// `print` parameter — confirm whether that is intended.
bool moduleFunctionEquals(
    const mobile::Function& lhs,
    const mobile::Function& rhs,
    bool print) {
  const auto* lhs_code = lhs.get_code().get();
  const auto* rhs_code = rhs.get_code().get();

  // instructions
  if (print) {
    std::cout << "> Diffing instructions..." << std::endl;
  }
  auto ins_equal = [](Instruction lins, Instruction rins) -> bool {
    return (lins.op == rins.op && lins.N == rins.N && lins.X == rins.X);
  };
  auto id = [](auto ins) {
    return ins; // operator << works for it already
  };
  if (vectorEqual(
          lhs_code->instructions_,
          rhs_code->instructions_,
          true,
          ins_equal,
          id)) {
    std::cout << " pass." << std::endl;
  } else {
    std::cout << " fail." << std::endl;
    return false;
  }

  // constants

  if (print) {
    std::cout << "> Diffing constants..." << std::endl;
  }
  if (ivalueListEquals(
          lhs_code->constants_.begin(),
          lhs_code->constants_.end(),
          rhs_code->constants_.begin(),
          rhs_code->constants_.end(),
          true,
          2)) {
    std::cout << " pass" << std::endl;
  } else {
    std::cout << " fail" << std::endl;
    return false;
  }

  // diffing operators
  if (print) {
    std::cout << "> Diffing operators ..." << std::endl;
  }
  auto equals = [](auto op1, auto op2) -> bool { return op1 == op2; };
  if (vectorEqual(lhs_code->op_names_, rhs_code->op_names_, true, equals, id)) {
    std::cout << " pass." << std::endl;
  } else {
    std::cout << " fail." << std::endl;
    return false;
  }

  if (lhs_code->register_size_ != rhs_code->register_size_) {
    std::cout << "Register size differs: " << lhs_code->register_size_
              << " vs. " << rhs_code->register_size_ << std::endl;
    return false;
  }

  // debug handles
  if (print) {
    std::cout << "> Diffing debug handles..." << std::endl;
  }
  if (vectorEqual(
          lhs_code->debug_handles_,
          rhs_code->debug_handles_,
          true,
          equals,
          id)) {
    std::cout << " pass." << std::endl;
  } else {
    std::cout << " fail." << std::endl;
    return false;
  }

  // types (compared by textual representation, not pointer identity)
  auto type_eq = [](auto t1, auto t2) { return t1->str() == t2->str(); };
  auto type_print = [](auto t1) { return t1->str(); };

  if (print) {
    std::cout << "> Diffing types..." << std::endl;
  }
  if (vectorEqual(
          lhs_code->types_, rhs_code->types_, true, type_eq, type_print)) {
    std::cout << " pass." << std::endl;
  } else {
    std::cout << " fail." << std::endl;
    return false;
  }

  if (print) {
    std::cout << "> Diffing schema..." << std::endl;
  }
  // NOTE: Schema has Argument; which has TypePtr. operator== of
  // TypePtr is pointer identity. This behavior is not suitable for
  // our use case.
  if (toString(lhs.getSchema()) == toString(rhs.getSchema())) {
    std::cout << " pass." << std::endl;
  } else {
    std::cout << " lhs is " << lhs.getSchema() << std::endl;
    std::cout << " rhs is " << rhs.getSchema() << std::endl;
    std::cout << " fail." << std::endl;
    return false;
  }

  return true;
}
|
||||||
|
|
||||||
|
// Compares two mobile modules: every method of lhs must exist in rhs and
// compare equal (instructions, constants, operators, ...), and the modules'
// root _ivalue() objects must compare equal.
// Fix: the original printed "only exists in lhs" and then fell through,
// dereferencing rhs_name_to_func.end(); it now returns false.
bool moduleEquals(const mobile::Module& lhs, const mobile::Module& rhs) {
  std::unordered_map<std::string, const mobile::Function*> lhs_name_to_func;
  std::unordered_map<std::string, const mobile::Function*> rhs_name_to_func;

  for (const auto& func : lhs.compilation_unit().methods()) {
    lhs_name_to_func[func->name()] = func.get();
  }
  for (const auto& func : rhs.compilation_unit().methods()) {
    rhs_name_to_func[func->name()] = func.get();
  }

  for (const auto& name_func : lhs_name_to_func) {
    auto rhs_func = rhs_name_to_func.find(name_func.first);
    if (rhs_func == rhs_name_to_func.end()) {
      std::cout << "Method with name: " << name_func.first
                << " only exists in lhs";
      return false;
    }
    std::cout << "comparing method with name " << name_func.first << std::endl;
    if (moduleFunctionEquals(*name_func.second, *rhs_func->second, true)) {
      std::cout << "pass" << std::endl;
    } else {
      std::cout << "fail" << std::endl;
      return false;
    }
  }

  std::cout << "Diffing m._ivalue()..." << std::endl;
  if (ivalueEquals(lhs._ivalue(), rhs._ivalue(), true, 0)) {
    std::cout << " pass." << std::endl;
  } else {
    std::cout << " fail." << std::endl;
    return false;
  }
  return true;
}
|
||||||
|
|
||||||
|
} // namespace jit
|
||||||
|
} // namespace torch
|
||||||
32
torch/csrc/jit/testing/module_differ.h
Normal file
32
torch/csrc/jit/testing/module_differ.h
Normal file
|
|
@ -0,0 +1,32 @@
|
||||||
|
// Fix: this header had no include guard; double inclusion would redeclare
// the functions (harmless for plain declarations but fragile). Guard added.
#pragma once

#include <ATen/core/ivalue.h>
#include <torch/csrc/jit/mobile/module.h>

namespace torch {
namespace jit {

// Compares 2 mobile::Module. Comparison is done as follows:
// 1. _ivalue() returned by both should be equal according to ivalueEquals below
// 2. all functions with same name shall have same instructions and constants
// 3. all functions in lhs exists in rhs.
TORCH_API bool moduleEquals(
    const mobile::Module& lhs,
    const mobile::Module& rhs);

// This is a function used in unittests to see if 2 IValue are the same.
// If print is true; then it will print out where the ivalue differs.
// Behavior of this function is different from IValue::operator== in the
// following parts:
// 1. Tensors are compared with allclose and returns bool (instead of bool
//    tensor)
// 2. Therefore, comparing List[Tensor] or deeply nested tensor works
// 3. 2 Capsules compares to true: this is because we intend to use this to
//    compare 2 IValue's after saving and loading.
TORCH_API bool ivalueEquals(
    const IValue& lhs,
    const IValue& rhs,
    bool print,
    int print_indent = 0);

} // namespace jit
} // namespace torch
|
||||||
|
|
@ -41,13 +41,19 @@ def _load_for_lite_interpreter(f, map_location=None):
|
||||||
raise ValueError("The provided filename {} does not exist".format(f))
|
raise ValueError("The provided filename {} does not exist".format(f))
|
||||||
if os.path.isdir(f):
|
if os.path.isdir(f):
|
||||||
raise ValueError("The provided filename {} is a directory".format(f))
|
raise ValueError("The provided filename {} is a directory".format(f))
|
||||||
|
zip_magic = b'PK\x03\x04'
|
||||||
map_location = validate_map_location(map_location)
|
map_location = validate_map_location(map_location)
|
||||||
|
|
||||||
if isinstance(f, str) or isinstance(f, pathlib.Path):
|
if isinstance(f, str) or isinstance(f, pathlib.Path):
|
||||||
cpp_module = torch._C._load_for_lite_interpreter(f, map_location)
|
is_flatbuffer = False
|
||||||
|
with open(f, 'rb') as fi:
|
||||||
|
magic_bytes = fi.read(4)
|
||||||
|
is_flatbuffer = (magic_bytes != zip_magic)
|
||||||
|
cpp_module = torch._C._load_for_lite_interpreter(f, map_location, is_flatbuffer)
|
||||||
else:
|
else:
|
||||||
cpp_module = torch._C._load_for_lite_interpreter_from_buffer(f.read(), map_location)
|
all_bytes = f.read()
|
||||||
|
is_flatbuffer = (all_bytes[:4] != zip_magic)
|
||||||
|
cpp_module = torch._C._load_for_lite_interpreter_from_buffer(
|
||||||
|
all_bytes, map_location, is_flatbuffer)
|
||||||
|
|
||||||
return LiteScriptModule(cpp_module)
|
return LiteScriptModule(cpp_module)
|
||||||
|
|
||||||
|
|
@ -215,3 +221,15 @@ def _get_model_ops_and_info(f_input):
|
||||||
return torch._C._get_model_ops_and_info(str(f_input))
|
return torch._C._get_model_ops_and_info(str(f_input))
|
||||||
else:
|
else:
|
||||||
return torch._C._get_model_ops_and_info(f_input.read())
|
return torch._C._get_model_ops_and_info(f_input.read())
|
||||||
|
|
||||||
|
|
||||||
|
def save_mobile_module(m: LiteScriptModule, filename: str):
|
||||||
|
torch._C._save_mobile_module(m._c, filename)
|
||||||
|
|
||||||
|
def jit_module_to_mobile(m):
|
||||||
|
mobile_m = torch._C._jit_module_to_mobile(m._c)
|
||||||
|
return LiteScriptModule(mobile_m)
|
||||||
|
|
||||||
|
|
||||||
|
def module_equals(lhs: LiteScriptModule, rhs: LiteScriptModule):
|
||||||
|
torch._C._module_equals(lhs._c, rhs._c)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user