mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Pytorch] Build lite interpreter as default for iOS
Summary: Two changes: 1. Build lite interpreter as default for iOS 2. Switch the previous lite interpreter test to full jit build test Test Plan: Imported from OSS Differential Revision: D27698039 Reviewed By: xta0 Pulled By: cccclai fbshipit-source-id: 022b554f4997ae577681f2b79a9ebe9236ca4f7d
This commit is contained in:
parent
8a3fb2689f
commit
b5a834a739
|
|
@ -61,14 +61,20 @@ class IOSJob:
|
|||
|
||||
|
||||
WORKFLOW_DATA = [
|
||||
IOSJob(XCODE_VERSION, ArchVariant("x86_64"), is_org_member_context=False),
|
||||
IOSJob(XCODE_VERSION, ArchVariant("x86_64", "lite_interpreter"), is_org_member_context=False, extra_props={
|
||||
IOSJob(XCODE_VERSION, ArchVariant("x86_64"), is_org_member_context=False, extra_props={
|
||||
"lite_interpreter": miniutils.quote(str(int(True)))}),
|
||||
IOSJob(XCODE_VERSION, ArchVariant("arm64")),
|
||||
IOSJob(XCODE_VERSION, ArchVariant("arm64", "metal"), extra_props={"use_metal": miniutils.quote(str(int(True)))}),
|
||||
IOSJob(XCODE_VERSION, ArchVariant("arm64", "lite_interpreter"), extra_props={
|
||||
IOSJob(XCODE_VERSION, ArchVariant("x86_64", "full_jit"), is_org_member_context=False, extra_props={
|
||||
"lite_interpreter": miniutils.quote(str(int(False)))}),
|
||||
IOSJob(XCODE_VERSION, ArchVariant("arm64"), extra_props={
|
||||
"lite_interpreter": miniutils.quote(str(int(True)))}),
|
||||
IOSJob(XCODE_VERSION, ArchVariant("arm64", "metal"), extra_props={
|
||||
"use_metal": miniutils.quote(str(int(True))),
|
||||
"lite_interpreter": miniutils.quote(str(int(True)))}),
|
||||
IOSJob(XCODE_VERSION, ArchVariant("arm64", "full_jit"), extra_props={
|
||||
"lite_interpreter": miniutils.quote(str(int(False)))}),
|
||||
IOSJob(XCODE_VERSION, ArchVariant("arm64", "custom"), extra_props={
|
||||
"op_list": "mobilenetv2.yaml",
|
||||
"lite_interpreter": miniutils.quote(str(int(True)))}),
|
||||
IOSJob(XCODE_VERSION, ArchVariant("arm64", "custom"), extra_props={"op_list": "mobilenetv2.yaml"}),
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -297,7 +297,7 @@ pytorch_android_params: &pytorch_android_params
|
|||
default: ""
|
||||
lite_interpreter:
|
||||
type: string
|
||||
default: "0"
|
||||
default: "1"
|
||||
environment:
|
||||
BUILD_ENVIRONMENT: pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single
|
||||
DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c"
|
||||
|
|
@ -324,7 +324,7 @@ pytorch_ios_params: &pytorch_ios_params
|
|||
default: "0"
|
||||
lite_interpreter:
|
||||
type: string
|
||||
default: "0"
|
||||
default: "1"
|
||||
environment:
|
||||
BUILD_ENVIRONMENT: << parameters.build_environment >>
|
||||
IOS_ARCH: << parameters.ios_arch >>
|
||||
|
|
@ -1759,8 +1759,8 @@ jobs:
|
|||
no_output_timeout: "30m"
|
||||
command: |
|
||||
set -e
|
||||
if [ ${BUILD_LITE_INTERPRETER} == 1 ]; then
|
||||
echo "Run Build Test is not for BUILD_LITE_INTERPRETER, skipping."
|
||||
if [ ${BUILD_LITE_INTERPRETER} == 0 ]; then
|
||||
echo "Run Build Test is not for full jit, skipping."
|
||||
exit 0
|
||||
fi
|
||||
PROJ_ROOT=/Users/distiller/project
|
||||
|
|
@ -1788,8 +1788,8 @@ jobs:
|
|||
if [ ${IOS_PLATFORM} != "SIMULATOR" ]; then
|
||||
echo "not SIMULATOR build, skip it."
|
||||
exit 0
|
||||
elif [ ${BUILD_LITE_INTERPRETER} == 1 ]; then
|
||||
echo "Run Simulator Tests is not for BUILD_LITE_INTERPRETER, skipping."
|
||||
elif [ ${BUILD_LITE_INTERPRETER} == 0 ]; then
|
||||
echo "Run Simulator Tests is not for full jit, skipping."
|
||||
exit 0
|
||||
fi
|
||||
WORKSPACE=/Users/distiller/workspace
|
||||
|
|
@ -6555,38 +6555,42 @@ workflows:
|
|||
build_environment: pytorch-ios-12.0.0-x86_64_build
|
||||
ios_arch: x86_64
|
||||
ios_platform: SIMULATOR
|
||||
lite_interpreter: "1"
|
||||
name: pytorch_ios_12_0_0_x86_64_build
|
||||
- pytorch_ios_build:
|
||||
build_environment: pytorch-ios-12.0.0-x86_64_lite_interpreter_build
|
||||
build_environment: pytorch-ios-12.0.0-x86_64_full_jit_build
|
||||
ios_arch: x86_64
|
||||
ios_platform: SIMULATOR
|
||||
lite_interpreter: "1"
|
||||
name: pytorch_ios_12_0_0_x86_64_lite_interpreter_build
|
||||
lite_interpreter: "0"
|
||||
name: pytorch_ios_12_0_0_x86_64_full_jit_build
|
||||
- pytorch_ios_build:
|
||||
build_environment: pytorch-ios-12.0.0-arm64_build
|
||||
context: org-member
|
||||
ios_arch: arm64
|
||||
ios_platform: OS
|
||||
lite_interpreter: "1"
|
||||
name: pytorch_ios_12_0_0_arm64_build
|
||||
- pytorch_ios_build:
|
||||
build_environment: pytorch-ios-12.0.0-arm64_metal_build
|
||||
context: org-member
|
||||
ios_arch: arm64
|
||||
ios_platform: OS
|
||||
lite_interpreter: "1"
|
||||
name: pytorch_ios_12_0_0_arm64_metal_build
|
||||
use_metal: "1"
|
||||
- pytorch_ios_build:
|
||||
build_environment: pytorch-ios-12.0.0-arm64_lite_interpreter_build
|
||||
build_environment: pytorch-ios-12.0.0-arm64_full_jit_build
|
||||
context: org-member
|
||||
ios_arch: arm64
|
||||
ios_platform: OS
|
||||
lite_interpreter: "1"
|
||||
name: pytorch_ios_12_0_0_arm64_lite_interpreter_build
|
||||
lite_interpreter: "0"
|
||||
name: pytorch_ios_12_0_0_arm64_full_jit_build
|
||||
- pytorch_ios_build:
|
||||
build_environment: pytorch-ios-12.0.0-arm64_custom_build
|
||||
context: org-member
|
||||
ios_arch: arm64
|
||||
ios_platform: OS
|
||||
lite_interpreter: "1"
|
||||
name: pytorch_ios_12_0_0_arm64_custom_build
|
||||
op_list: mobilenetv2.yaml
|
||||
- pytorch_linux_build:
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ pytorch_android_params: &pytorch_android_params
|
|||
default: ""
|
||||
lite_interpreter:
|
||||
type: string
|
||||
default: "0"
|
||||
default: "1"
|
||||
environment:
|
||||
BUILD_ENVIRONMENT: pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single
|
||||
DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c"
|
||||
|
|
@ -59,7 +59,7 @@ pytorch_ios_params: &pytorch_ios_params
|
|||
default: "0"
|
||||
lite_interpreter:
|
||||
type: string
|
||||
default: "0"
|
||||
default: "1"
|
||||
environment:
|
||||
BUILD_ENVIRONMENT: << parameters.build_environment >>
|
||||
IOS_ARCH: << parameters.ios_arch >>
|
||||
|
|
|
|||
|
|
@ -528,8 +528,8 @@
|
|||
no_output_timeout: "30m"
|
||||
command: |
|
||||
set -e
|
||||
if [ ${BUILD_LITE_INTERPRETER} == 1 ]; then
|
||||
echo "Run Build Test is not for BUILD_LITE_INTERPRETER, skipping."
|
||||
if [ ${BUILD_LITE_INTERPRETER} == 0 ]; then
|
||||
echo "Run Build Test is not for full jit, skipping."
|
||||
exit 0
|
||||
fi
|
||||
PROJ_ROOT=/Users/distiller/project
|
||||
|
|
@ -557,8 +557,8 @@
|
|||
if [ ${IOS_PLATFORM} != "SIMULATOR" ]; then
|
||||
echo "not SIMULATOR build, skip it."
|
||||
exit 0
|
||||
elif [ ${BUILD_LITE_INTERPRETER} == 1 ]; then
|
||||
echo "Run Simulator Tests is not for BUILD_LITE_INTERPRETER, skipping."
|
||||
elif [ ${BUILD_LITE_INTERPRETER} == 0 ]; then
|
||||
echo "Run Simulator Tests is not for full jit, skipping."
|
||||
exit 0
|
||||
fi
|
||||
WORKSPACE=/Users/distiller/workspace
|
||||
|
|
|
|||
|
|
@ -1,14 +1,19 @@
|
|||
#import "Benchmark.h"
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "torch/script.h"
|
||||
|
||||
#include <torch/csrc/jit/mobile/function.h>
|
||||
#include <torch/csrc/jit/mobile/import.h>
|
||||
#include <torch/csrc/jit/mobile/interpreter.h>
|
||||
#include <torch/csrc/jit/mobile/module.h>
|
||||
#include <torch/csrc/jit/mobile/observer.h>
|
||||
#include "ATen/ATen.h"
|
||||
#include "caffe2/core/timer.h"
|
||||
#include "caffe2/utils/string_utils.h"
|
||||
#include "torch/csrc/autograd/grad_mode.h"
|
||||
#include "torch/csrc/jit/serialization/import.h"
|
||||
#include "torch/script.h"
|
||||
|
||||
static std::string model = "model.pt";
|
||||
static std::string model = "model.ptl";
|
||||
static std::string input_dims = "1,3,224,224";
|
||||
static std::string input_type = "float";
|
||||
static BOOL print_output = false;
|
||||
|
|
@ -18,9 +23,9 @@ static int iter = 10;
|
|||
@implementation Benchmark
|
||||
|
||||
+ (BOOL)setup:(NSDictionary*)config {
|
||||
NSString* modelPath = [[NSBundle mainBundle] pathForResource:@"model" ofType:@"pt"];
|
||||
NSString* modelPath = [[NSBundle mainBundle] pathForResource:@"model" ofType:@"ptl"];
|
||||
if (![[NSFileManager defaultManager] fileExistsAtPath:modelPath]) {
|
||||
NSLog(@"model.pt doesn't exist!");
|
||||
NSLog(@"model.ptl doesn't exist!");
|
||||
return NO;
|
||||
}
|
||||
model = std::string(modelPath.UTF8String);
|
||||
|
|
@ -66,10 +71,9 @@ static int iter = 10;
|
|||
}
|
||||
|
||||
c10::InferenceMode mode;
|
||||
torch::jit::GraphOptimizerEnabledGuard opguard(false);
|
||||
auto module = torch::jit::load(model);
|
||||
auto module = torch::jit::_load_for_mobile(model);
|
||||
|
||||
module.eval();
|
||||
// module.eval();
|
||||
if (print_output) {
|
||||
std::cout << module.forward(inputs) << std::endl;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,13 +1,22 @@
|
|||
#import <XCTest/XCTest.h>
|
||||
|
||||
#include <torch/script.h>
|
||||
#include <torch/csrc/jit/mobile/function.h>
|
||||
#include <torch/csrc/jit/mobile/import.h>
|
||||
#include <torch/csrc/jit/mobile/interpreter.h>
|
||||
#include <torch/csrc/jit/mobile/module.h>
|
||||
#include <torch/csrc/jit/mobile/observer.h>
|
||||
#include "ATen/ATen.h"
|
||||
#include "caffe2/core/timer.h"
|
||||
#include "caffe2/utils/string_utils.h"
|
||||
#include "torch/csrc/autograd/grad_mode.h"
|
||||
|
||||
@interface TestAppTests : XCTestCase
|
||||
|
||||
@end
|
||||
|
||||
@implementation TestAppTests {
|
||||
torch::jit::Module _module;
|
||||
torch::jit::mobile::Module _module;
|
||||
}
|
||||
|
||||
+ (void)setUp {
|
||||
|
|
@ -17,14 +26,14 @@
|
|||
- (void)setUp {
|
||||
[super setUp];
|
||||
NSString* modelPath = [[NSBundle bundleForClass:[self class]] pathForResource:@"model"
|
||||
ofType:@"pt"];
|
||||
ofType:@"ptl"];
|
||||
XCTAssertTrue([NSFileManager.defaultManager fileExistsAtPath:modelPath],
|
||||
@"model.pt doesn't exist!");
|
||||
_module = torch::jit::load(modelPath.UTF8String);
|
||||
@"model.ptl doesn't exist!");
|
||||
_module = torch::jit::_load_for_mobile(modelPath.UTF8String);
|
||||
}
|
||||
|
||||
- (void)testForward {
|
||||
_module.eval();
|
||||
// _module.eval();
|
||||
c10::InferenceMode mode;
|
||||
std::vector<c10::IValue> inputs;
|
||||
inputs.push_back(torch::ones({1, 3, 224, 224}, at::ScalarType::Float));
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ targets.each do |target|
|
|||
end
|
||||
end
|
||||
puts "Installing the testing model..."
|
||||
model_path = File.expand_path("./model.pt")
|
||||
model_path = File.expand_path("./model.ptl")
|
||||
if not File.exist?(model_path)
|
||||
raise "model.pt can't be found!"
|
||||
end
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
import torch
|
||||
import torchvision
|
||||
from torch.utils.mobile_optimizer import optimize_for_mobile
|
||||
|
||||
model = torchvision.models.mobilenet_v2(pretrained=True)
|
||||
model.eval()
|
||||
example = torch.rand(1, 3, 224, 224)
|
||||
traced_script_module = torch.jit.trace(model, example)
|
||||
traced_script_module.save("model.pt")
|
||||
traced_script_module = torch.jit.script(model, example)
|
||||
optimized_scripted_module = optimize_for_mobile(traced_script_module)
|
||||
exported_optimized_scripted_module = optimized_scripted_module._save_for_lite_interpreter("model.ptl")
|
||||
|
|
|
|||
|
|
@ -78,10 +78,10 @@ if [ -n "${IOS_ARCH:-}" ]; then
|
|||
CMAKE_ARGS+=("-DIOS_ARCH=${IOS_ARCH}")
|
||||
fi
|
||||
|
||||
if [ "${BUILD_LITE_INTERPRETER}" == 1 ]; then
|
||||
CMAKE_ARGS+=("-DBUILD_LITE_INTERPRETER=ON")
|
||||
else
|
||||
if [ "${BUILD_LITE_INTERPRETER}" == 0 ]; then
|
||||
CMAKE_ARGS+=("-DBUILD_LITE_INTERPRETER=OFF")
|
||||
else
|
||||
CMAKE_ARGS+=("-DBUILD_LITE_INTERPRETER=ON")
|
||||
fi
|
||||
|
||||
# Don't build binaries or tests (only the library)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user