[pytorch] Delete TorchScript based Android demo app and point user to ExecuTorch (#153767)

Summary: A retry of #153656. This time start from co-dev to make sure we capture internal signals.

Test Plan: Rely on CI jobs.

Differential Revision: D74911818

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153767
Approved by: https://github.com/kirklandsign, https://github.com/cyyever, https://github.com/Skylion007
This commit is contained in:
Mengwei Liu 2025-05-19 17:20:36 +00:00 committed by PyTorch MergeBot
parent 6487ea30b3
commit be36bacdaa
26 changed files with 4 additions and 1987 deletions

View File

@ -2,7 +2,9 @@
## Demo applications and tutorials ## Demo applications and tutorials
Demo applications with code walk-through can be find in [this github repo](https://github.com/pytorch/android-demo-app). Please refer to [pytorch-labs/executorch-examples](https://github.com/pytorch-labs/executorch-examples/tree/main/dl3/android/DeepLabV3Demo) for the Android demo app based on [ExecuTorch](https://github.com/pytorch/executorch).
Please join our [Discord](https://discord.com/channels/1334270993966825602/1349854760299270284) for any questions.
## Publishing ## Publishing
@ -119,8 +121,6 @@ We also have to add all transitive dependencies of our aars.
As `pytorch_android` [depends](https://github.com/pytorch/pytorch/blob/master/android/pytorch_android/build.gradle#L76-L77) on `'com.facebook.soloader:nativeloader:0.10.5'` and `'com.facebook.fbjni:fbjni-java-only:0.2.2'`, we need to add them. As `pytorch_android` [depends](https://github.com/pytorch/pytorch/blob/master/android/pytorch_android/build.gradle#L76-L77) on `'com.facebook.soloader:nativeloader:0.10.5'` and `'com.facebook.fbjni:fbjni-java-only:0.2.2'`, we need to add them.
(In case of using maven dependencies they are added automatically from `pom.xml`). (In case of using maven dependencies they are added automatically from `pom.xml`).
You can check out [test app example](https://github.com/pytorch/pytorch/blob/master/android/test_app/app/build.gradle) that uses aars directly.
## Linking to prebuilt libtorch library from gradle dependency ## Linking to prebuilt libtorch library from gradle dependency
In some cases, you may want to use libtorch from your android native build. In some cases, you may want to use libtorch from your android native build.
@ -202,7 +202,7 @@ find_library(FBJNI_LIBRARY fbjni
NO_CMAKE_FIND_ROOT_PATH) NO_CMAKE_FIND_ROOT_PATH)
target_link_libraries(${PROJECT_NAME} target_link_libraries(${PROJECT_NAME}
${PYTORCH_LIBRARY}) ${PYTORCH_LIBRARY}
${FBJNI_LIBRARY}) ${FBJNI_LIBRARY})
``` ```
@ -233,8 +233,6 @@ void loadAndForwardModel(const std::string& modelPath) {
To load torchscript model for mobile we need some special setup which is placed in `struct JITCallGuard` in this example. It may change in future, you can track the latest changes keeping an eye in our [pytorch android jni code]([https://github.com/pytorch/pytorch/blob/master/android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp#L28) To load torchscript model for mobile we need some special setup which is placed in `struct JITCallGuard` in this example. It may change in future, you can track the latest changes keeping an eye in our [pytorch android jni code]([https://github.com/pytorch/pytorch/blob/master/android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp#L28)
[Example of linking to libtorch from aar](https://github.com/pytorch/pytorch/tree/master/android/test_app)
## PyTorch Android API Javadoc ## PyTorch Android API Javadoc
You can find more details about the PyTorch Android API in the [Javadoc](https://pytorch.org/javadoc/). You can find more details about the PyTorch Android API in the [Javadoc](https://pytorch.org/javadoc/).

View File

@ -1,30 +0,0 @@
#!/bin/bash
set -eux
PYTORCH_DIR="$(cd $(dirname $0)/..; pwd -P)"
PYTORCH_ANDROID_DIR=$PYTORCH_DIR/android
echo "PYTORCH_DIR:$PYTORCH_DIR"
source "$PYTORCH_ANDROID_DIR/common.sh"
check_android_sdk
check_gradle
parse_abis_list "$@"
build_android
# To set proxy for gradle add following lines to ./gradle/gradle.properties:
# systemProp.http.proxyHost=...
# systemProp.http.proxyPort=8080
# systemProp.https.proxyHost=...
# systemProp.https.proxyPort=8080
if [ "$CUSTOM_ABIS_LIST" = true ]; then
NDK_DEBUG=1 $GRADLE_PATH -PnativeLibsDoNotStrip=true -PABI_FILTERS=$ABIS_LIST -p $PYTORCH_ANDROID_DIR clean test_app:assembleDebug
else
NDK_DEBUG=1 $GRADLE_PATH -PnativeLibsDoNotStrip=true -p $PYTORCH_ANDROID_DIR clean test_app:assembleDebug
fi
find $PYTORCH_ANDROID_DIR -type f -name *apk
find $PYTORCH_ANDROID_DIR -type f -name *apk | xargs echo "To install apk run: $ANDROID_HOME/platform-tools/adb install -r "

View File

@ -1,32 +0,0 @@
#!/bin/bash
###############################################################################
# This script tests the custom selective build flow for PyTorch Android, which
# optimizes library size by only including ops used by a specific model.
###############################################################################
set -eux
PYTORCH_DIR="$(cd $(dirname $0)/..; pwd -P)"
PYTORCH_ANDROID_DIR="${PYTORCH_DIR}/android"
BUILD_ROOT="${PYTORCH_DIR}/build_pytorch_android_custom"
source "${PYTORCH_ANDROID_DIR}/common.sh"
prepare_model_and_dump_root_ops() {
cd "${BUILD_ROOT}"
MODEL="${BUILD_ROOT}/MobileNetV2.pt"
ROOT_OPS="${BUILD_ROOT}/MobileNetV2.yaml"
python "${PYTORCH_ANDROID_DIR}/test_app/make_assets_custom.py"
cp "${MODEL}" "${PYTORCH_ANDROID_DIR}/test_app/app/src/main/assets/mobilenet2.pt"
}
# Start building
mkdir -p "${BUILD_ROOT}"
check_android_sdk
check_gradle
parse_abis_list "$@"
prepare_model_and_dump_root_ops
SELECTED_OP_LIST="${ROOT_OPS}" build_android
# TODO: change this to build test_app instead
$GRADLE_PATH -PABI_FILTERS=$ABIS_LIST -p $PYTORCH_ANDROID_DIR clean assembleRelease

View File

@ -3,4 +3,3 @@ include ':app', ':pytorch_android', ':pytorch_android_torchvision', ':pytorch_ho
project(':pytorch_android_torchvision').projectDir = file('pytorch_android_torchvision') project(':pytorch_android_torchvision').projectDir = file('pytorch_android_torchvision')
project(':pytorch_host').projectDir = file('pytorch_android/host') project(':pytorch_host').projectDir = file('pytorch_android/host')
project(':test_app').projectDir = file('test_app/app')

View File

@ -1,9 +0,0 @@
local.properties
**/*.iml
.gradle
gradlew*
gradle/wrapper
.idea/*
.DS_Store
build
.externalNativeBuild

View File

@ -1,38 +0,0 @@
cmake_minimum_required(VERSION 3.5)
set(PROJECT_NAME pytorch_testapp_jni)
project(${PROJECT_NAME} CXX)
set(CMAKE_CXX_STANDARD 17 CACHE STRING "The C++ standard whose features are requested to build this target.")
set(CMAKE_VERBOSE_MAKEFILE ON)
set(build_DIR ${CMAKE_SOURCE_DIR}/build)
set(pytorch_testapp_cpp_DIR ${CMAKE_CURRENT_LIST_DIR}/src/main/cpp)
message(STATUS "ANDROID_STL:${ANDROID_STL}")
file(GLOB pytorch_testapp_SOURCES
${pytorch_testapp_cpp_DIR}/pytorch_testapp_jni.cpp
)
add_library(${PROJECT_NAME} SHARED
${pytorch_testapp_SOURCES}
)
file(GLOB PYTORCH_INCLUDE_DIRS "${build_DIR}/pytorch_android*.aar/headers")
file(GLOB PYTORCH_LINK_DIRS "${build_DIR}/pytorch_android*.aar/jni/${ANDROID_ABI}")
target_compile_options(${PROJECT_NAME} PRIVATE
-fexceptions
)
set(BUILD_SUBDIR ${ANDROID_ABI})
target_include_directories(${PROJECT_NAME} PRIVATE
${PYTORCH_INCLUDE_DIRS}
)
find_library(PYTORCH_LIBRARY pytorch_jni
PATHS ${PYTORCH_LINK_DIRS}
NO_CMAKE_FIND_ROOT_PATH)
target_link_libraries(${PROJECT_NAME}
${PYTORCH_LIBRARY}
log)

View File

@ -1,190 +0,0 @@
apply plugin: 'com.android.application'
repositories {
jcenter()
maven {
url "https://oss.sonatype.org/content/repositories/snapshots"
}
flatDir {
dirs 'aars'
}
}
android {
configurations {
extractForNativeBuild
}
compileOptions {
sourceCompatibility 1.8
targetCompatibility 1.8
}
compileSdkVersion rootProject.compileSdkVersion
buildToolsVersion rootProject.buildToolsVersion
defaultConfig {
applicationId "org.pytorch.testapp"
minSdkVersion rootProject.minSdkVersion
targetSdkVersion rootProject.targetSdkVersion
versionCode 1
versionName "1.0"
ndk {
abiFilters ABI_FILTERS.split(",")
}
// Commented due to dependency on local copy of pytorch_android aar to aars folder
//externalNativeBuild {
// cmake {
// abiFilters ABI_FILTERS.split(",")
// arguments "-DANDROID_STL=c++_shared"
// }
//}
buildConfigField("String", "MODULE_ASSET_NAME", "\"mobilenet2q.pt\"")
buildConfigField("String", "LOGCAT_TAG", "@string/app_name")
buildConfigField("long[]", "INPUT_TENSOR_SHAPE", "new long[]{1, 3, 224, 224}")
buildConfigField("boolean", "NATIVE_BUILD", 'false')
buildConfigField("boolean", "USE_VULKAN_DEVICE", 'false')
buildConfigField(
"int",
"BUILD_LITE_INTERPRETER",
System.env.BUILD_LITE_INTERPRETER != null ? System.env.BUILD_LITE_INTERPRETER : "1"
)
addManifestPlaceholders([APP_NAME: "@string/app_name", MAIN_ACTIVITY: "org.pytorch.testapp.MainActivity"])
}
buildTypes {
debug {
minifyEnabled false
debuggable true
}
release {
minifyEnabled false
}
}
// Commented due to dependency on local copy of pytorch_android aar to aars folder
//externalNativeBuild {
// cmake {
// path "CMakeLists.txt"
// }
//}
flavorDimensions "model", "build", "activity"
productFlavors {
mnet {
dimension "model"
applicationIdSuffix ".mnet"
buildConfigField("String", "MODULE_ASSET_NAME", "\"mobilenet_v2.ptl\"")
addManifestPlaceholders([APP_NAME: "MNET"])
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mnet\"")
}
// NB: This is not working atm https://github.com/pytorch/pytorch/issues/102966
mnetVulkan {
dimension "model"
applicationIdSuffix ".mnet_vulkan"
buildConfigField("String", "MODULE_ASSET_NAME", "\"mobilenet_v2_vulkan.ptl\"")
buildConfigField("boolean", "USE_VULKAN_DEVICE", 'true')
addManifestPlaceholders([APP_NAME: "MNET_VULKAN"])
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mnet-vulkan\"")
}
resnet18 {
dimension "model"
applicationIdSuffix ".resnet18"
buildConfigField("String", "MODULE_ASSET_NAME", "\"resnet18.ptl\"")
addManifestPlaceholders([APP_NAME: "RN18"])
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-resnet18\"")
}
local {
dimension "build"
}
nightly {
dimension "build"
}
aar {
dimension "build"
}
// Commented due to dependency on local copy of pytorch_android aar to aars folder
//nativeBuild {
// dimension "build"
// buildConfigField("boolean", "NATIVE_BUILD", "true")
//}
camera {
dimension "activity"
addManifestPlaceholders([MAIN_ACTIVITY: "org.pytorch.testapp.CameraActivity"])
}
base {
dimension "activity"
sourceSets {
main {
java {
exclude 'org/pytorch/testapp/CameraActivity.java'
}
}
}
}
}
packagingOptions {
doNotStrip '**.so'
}
// Filtering for CI
if (!testAppAllVariantsEnabled.toBoolean()) {
variantFilter { variant ->
def names = variant.flavors*.name
if (names.contains("nightly")
|| names.contains("camera")
|| names.contains("aar")
|| names.contains("nativeBuild")) {
setIgnore(true)
}
}
}
}
tasks.all { task ->
// Disable externalNativeBuild for all but nativeBuild variant
if (task.name.startsWith('externalNativeBuild')
&& !task.name.contains('NativeBuild')) {
task.enabled = false
}
}
dependencies {
implementation 'com.android.support:appcompat-v7:28.0.0'
implementation 'com.facebook.soloader:nativeloader:0.10.5'
localImplementation project(':pytorch_android')
localImplementation project(':pytorch_android_torchvision')
// Commented due to dependency on local copy of pytorch_android aar to aars folder
//nativeBuildImplementation(name: 'pytorch_android-release', ext: 'aar')
//nativeBuildImplementation(name: 'pytorch_android_torchvision-release', ext: 'aar')
//extractForNativeBuild(name: 'pytorch_android-release', ext: 'aar')
nightlyImplementation 'org.pytorch:pytorch_android:2.2.0-SNAPSHOT'
nightlyImplementation 'org.pytorch:pytorch_android_torchvision:2.2.0-SNAPSHOT'
aarImplementation(name:'pytorch_android', ext:'aar')
aarImplementation(name:'pytorch_android_torchvision', ext:'aar')
aarImplementation 'com.facebook.soloader:nativeloader:0.10.5'
aarImplementation 'com.facebook.fbjni:fbjni-java-only:0.2.2'
def camerax_version = "1.0.0-alpha05"
cameraImplementation "androidx.camera:camera-core:$camerax_version"
cameraImplementation "androidx.camera:camera-camera2:$camerax_version"
cameraImplementation 'com.google.android.material:material:1.0.0-beta01'
}
task extractAARForNativeBuild {
doLast {
configurations.extractForNativeBuild.files.each {
def file = it.absoluteFile
copy {
from zipTree(file)
into "$buildDir/$file.name"
include "headers/**"
include "jni/**"
}
}
}
}
tasks.whenTaskAdded { task ->
if (task.name.contains('externalNativeBuild')) {
task.dependsOn(extractAARForNativeBuild)
}
}

View File

@ -1,27 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="org.pytorch.testapp">
<application
android:allowBackup="true"
android:label="${APP_NAME}"
android:supportsRtl="true"
android:theme="@style/AppTheme">
<activity android:name="${MAIN_ACTIVITY}">
<intent-filter>
<action android:name="android.intent.action.MAIN" />
<category android:name="android.intent.category.LAUNCHER" />
</intent-filter>
</activity>
</application>
<uses-permission android:name="android.permission.CAMERA" />
<!--
Permissions required by the Snapdragon Profiler to collect GPU metrics.
-->
<uses-permission android:name="android.permission.INTERNET" />
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE" />
</manifest>

View File

@ -1,3 +0,0 @@
*
*/
!.gitignore

View File

@ -1,77 +0,0 @@
#include <android/log.h>
#include <pthread.h>
#include <unistd.h>
#include <cassert>
#include <cmath>
#include <vector>
#define ALOGI(...) \
__android_log_print(ANDROID_LOG_INFO, "PyTorchTestAppJni", __VA_ARGS__)
#define ALOGE(...) \
__android_log_print(ANDROID_LOG_ERROR, "PyTorchTestAppJni", __VA_ARGS__)
#include "jni.h"
#include <torch/script.h>
namespace pytorch_testapp_jni {
namespace {
template <typename T>
void log(const char* m, T t) {
std::ostringstream os;
os << t << std::endl;
ALOGI("%s %s", m, os.str().c_str());
}
struct JITCallGuard {
c10::InferenceMode guard;
torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard{false};
};
} // namespace
static void loadAndForwardModel(JNIEnv* env, jclass, jstring jModelPath) {
const char* modelPath = env->GetStringUTFChars(jModelPath, 0);
assert(modelPath);
// To load torchscript model for mobile we need set these guards,
// because mobile build doesn't support features like autograd for smaller
// build size which is placed in `struct JITCallGuard` in this example. It may
// change in future, you can track the latest changes keeping an eye in
// android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp
JITCallGuard guard;
torch::jit::Module module = torch::jit::load(modelPath);
module.eval();
torch::Tensor t = torch::randn({1, 3, 224, 224});
log("input tensor:", t);
c10::IValue t_out = module.forward({t});
log("output tensor:", t_out);
env->ReleaseStringUTFChars(jModelPath, modelPath);
}
} // namespace pytorch_testapp_jni
JNIEXPORT jint JNI_OnLoad(JavaVM* vm, void*) {
JNIEnv* env;
if (vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION_1_6) != JNI_OK) {
return JNI_ERR;
}
jclass c =
env->FindClass("org/pytorch/testapp/LibtorchNativeClient$NativePeer");
if (c == nullptr) {
return JNI_ERR;
}
static const JNINativeMethod methods[] = {
{"loadAndForwardModel",
"(Ljava/lang/String;)V",
(void*)pytorch_testapp_jni::loadAndForwardModel},
};
int rc = env->RegisterNatives(
c, methods, sizeof(methods) / sizeof(JNINativeMethod));
if (rc != JNI_OK) {
return rc;
}
return JNI_VERSION_1_6;
}

View File

@ -1,214 +0,0 @@
package org.pytorch.testapp;
import android.Manifest;
import android.content.pm.PackageManager;
import android.os.Bundle;
import android.os.Handler;
import android.os.HandlerThread;
import android.os.SystemClock;
import android.util.Log;
import android.util.Size;
import android.view.TextureView;
import android.view.ViewStub;
import android.widget.TextView;
import android.widget.Toast;
import androidx.annotation.Nullable;
import androidx.annotation.UiThread;
import androidx.annotation.WorkerThread;
import androidx.appcompat.app.AppCompatActivity;
import androidx.camera.core.CameraX;
import androidx.camera.core.ImageAnalysis;
import androidx.camera.core.ImageAnalysisConfig;
import androidx.camera.core.ImageProxy;
import androidx.camera.core.Preview;
import androidx.camera.core.PreviewConfig;
import androidx.core.app.ActivityCompat;
import java.nio.FloatBuffer;
import org.pytorch.IValue;
import org.pytorch.MemoryFormat;
import org.pytorch.Module;
import org.pytorch.PyTorchAndroid;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;
public class CameraActivity extends AppCompatActivity {
private static final String TAG = BuildConfig.LOGCAT_TAG;
private static final int TEXT_TRIM_SIZE = 4096;
private static final int REQUEST_CODE_CAMERA_PERMISSION = 200;
private static final String[] PERMISSIONS = {Manifest.permission.CAMERA};
private long mLastAnalysisResultTime;
protected HandlerThread mBackgroundThread;
protected Handler mBackgroundHandler;
protected Handler mUIHandler;
private TextView mTextView;
private StringBuilder mTextViewStringBuilder = new StringBuilder();
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_camera);
mTextView = findViewById(R.id.text);
mUIHandler = new Handler(getMainLooper());
startBackgroundThread();
if (ActivityCompat.checkSelfPermission(this, Manifest.permission.CAMERA)
!= PackageManager.PERMISSION_GRANTED) {
ActivityCompat.requestPermissions(this, PERMISSIONS, REQUEST_CODE_CAMERA_PERMISSION);
} else {
setupCameraX();
}
}
@Override
protected void onPostCreate(@Nullable Bundle savedInstanceState) {
super.onPostCreate(savedInstanceState);
startBackgroundThread();
}
protected void startBackgroundThread() {
mBackgroundThread = new HandlerThread("ModuleActivity");
mBackgroundThread.start();
mBackgroundHandler = new Handler(mBackgroundThread.getLooper());
}
@Override
protected void onDestroy() {
stopBackgroundThread();
super.onDestroy();
}
protected void stopBackgroundThread() {
mBackgroundThread.quitSafely();
try {
mBackgroundThread.join();
mBackgroundThread = null;
mBackgroundHandler = null;
} catch (InterruptedException e) {
Log.e(TAG, "Error on stopping background thread", e);
}
}
@Override
public void onRequestPermissionsResult(
int requestCode, String[] permissions, int[] grantResults) {
if (requestCode == REQUEST_CODE_CAMERA_PERMISSION) {
if (grantResults[0] == PackageManager.PERMISSION_DENIED) {
Toast.makeText(
this,
"You can't use image classification example without granting CAMERA permission",
Toast.LENGTH_LONG)
.show();
finish();
} else {
setupCameraX();
}
}
}
private static final int TENSOR_WIDTH = 224;
private static final int TENSOR_HEIGHT = 224;
private void setupCameraX() {
final TextureView textureView =
((ViewStub) findViewById(R.id.camera_texture_view_stub))
.inflate()
.findViewById(R.id.texture_view);
final PreviewConfig previewConfig = new PreviewConfig.Builder().build();
final Preview preview = new Preview(previewConfig);
preview.setOnPreviewOutputUpdateListener(
new Preview.OnPreviewOutputUpdateListener() {
@Override
public void onUpdated(Preview.PreviewOutput output) {
textureView.setSurfaceTexture(output.getSurfaceTexture());
}
});
final ImageAnalysisConfig imageAnalysisConfig =
new ImageAnalysisConfig.Builder()
.setTargetResolution(new Size(TENSOR_WIDTH, TENSOR_HEIGHT))
.setCallbackHandler(mBackgroundHandler)
.setImageReaderMode(ImageAnalysis.ImageReaderMode.ACQUIRE_LATEST_IMAGE)
.build();
final ImageAnalysis imageAnalysis = new ImageAnalysis(imageAnalysisConfig);
imageAnalysis.setAnalyzer(
new ImageAnalysis.Analyzer() {
@Override
public void analyze(ImageProxy image, int rotationDegrees) {
if (SystemClock.elapsedRealtime() - mLastAnalysisResultTime < 500) {
return;
}
final Result result = CameraActivity.this.analyzeImage(image, rotationDegrees);
if (result != null) {
mLastAnalysisResultTime = SystemClock.elapsedRealtime();
CameraActivity.this.runOnUiThread(
new Runnable() {
@Override
public void run() {
CameraActivity.this.handleResult(result);
}
});
}
}
});
CameraX.bindToLifecycle(this, preview, imageAnalysis);
}
private Module mModule;
private FloatBuffer mInputTensorBuffer;
private Tensor mInputTensor;
@WorkerThread
@Nullable
protected Result analyzeImage(ImageProxy image, int rotationDegrees) {
Log.i(TAG, String.format("analyzeImage(%s, %d)", image, rotationDegrees));
if (mModule == null) {
Log.i(TAG, "Loading module from asset '" + BuildConfig.MODULE_ASSET_NAME + "'");
mModule = PyTorchAndroid.loadModuleFromAsset(getAssets(), BuildConfig.MODULE_ASSET_NAME);
mInputTensorBuffer = Tensor.allocateFloatBuffer(3 * TENSOR_WIDTH * TENSOR_HEIGHT);
mInputTensor =
Tensor.fromBlob(mInputTensorBuffer, new long[] {1, 3, TENSOR_WIDTH, TENSOR_HEIGHT});
}
final long startTime = SystemClock.elapsedRealtime();
TensorImageUtils.imageYUV420CenterCropToFloatBuffer(
image.getImage(),
rotationDegrees,
TENSOR_WIDTH,
TENSOR_HEIGHT,
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
TensorImageUtils.TORCHVISION_NORM_STD_RGB,
mInputTensorBuffer,
0,
MemoryFormat.CHANNELS_LAST);
final long moduleForwardStartTime = SystemClock.elapsedRealtime();
final Tensor outputTensor = mModule.forward(IValue.from(mInputTensor)).toTensor();
final long moduleForwardDuration = SystemClock.elapsedRealtime() - moduleForwardStartTime;
final float[] scores = outputTensor.getDataAsFloatArray();
final long analysisDuration = SystemClock.elapsedRealtime() - startTime;
return new Result(scores, moduleForwardDuration, analysisDuration);
}
@UiThread
protected void handleResult(Result result) {
int ixs[] = Utils.topK(result.scores, 1);
String message =
String.format(
"forwardDuration:%d class:%s",
result.moduleForwardDuration, Constants.IMAGENET_CLASSES[ixs[0]]);
Log.i(TAG, message);
mTextViewStringBuilder.insert(0, '\n').insert(0, message);
if (mTextViewStringBuilder.length() > TEXT_TRIM_SIZE) {
mTextViewStringBuilder.delete(TEXT_TRIM_SIZE, mTextViewStringBuilder.length());
}
mTextView.setText(mTextViewStringBuilder.toString());
}
}

View File

@ -1,22 +0,0 @@
package org.pytorch.testapp;
import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;
public final class LibtorchNativeClient {
public static void loadAndForwardModel(final String modelPath) {
NativePeer.loadAndForwardModel(modelPath);
}
private static class NativePeer {
static {
if (!NativeLoader.isInitialized()) {
NativeLoader.init(new SystemDelegate());
}
NativeLoader.loadLibrary("pytorch_testapp_jni");
}
private static native void loadAndForwardModel(final String modelPath);
}
}

View File

@ -1,171 +0,0 @@
package org.pytorch.testapp;
import android.content.Context;
import android.os.Bundle;
import android.os.Handler;
import android.os.HandlerThread;
import android.os.SystemClock;
import android.util.Log;
import android.widget.TextView;
import androidx.annotation.Nullable;
import androidx.annotation.UiThread;
import androidx.annotation.WorkerThread;
import androidx.appcompat.app.AppCompatActivity;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.FloatBuffer;
import org.pytorch.Device;
import org.pytorch.IValue;
import org.pytorch.MemoryFormat;
import org.pytorch.Module;
import org.pytorch.PyTorchAndroid;
import org.pytorch.Tensor;
public class MainActivity extends AppCompatActivity {
private static final String TAG = BuildConfig.LOGCAT_TAG;
private static final int TEXT_TRIM_SIZE = 4096;
private TextView mTextView;
protected HandlerThread mBackgroundThread;
protected Handler mBackgroundHandler;
private Module mModule;
private FloatBuffer mInputTensorBuffer;
private Tensor mInputTensor;
private StringBuilder mTextViewStringBuilder = new StringBuilder();
private final Runnable mModuleForwardRunnable =
new Runnable() {
@Override
public void run() {
final Result result = doModuleForward();
runOnUiThread(
new Runnable() {
@Override
public void run() {
handleResult(result);
if (mBackgroundHandler != null) {
mBackgroundHandler.post(mModuleForwardRunnable);
}
}
});
}
};
public static String assetFilePath(Context context, String assetName) {
File file = new File(context.getFilesDir(), assetName);
if (file.exists() && file.length() > 0) {
return file.getAbsolutePath();
}
try (InputStream is = context.getAssets().open(assetName)) {
try (OutputStream os = new FileOutputStream(file)) {
byte[] buffer = new byte[4 * 1024];
int read;
while ((read = is.read(buffer)) != -1) {
os.write(buffer, 0, read);
}
os.flush();
}
return file.getAbsolutePath();
} catch (IOException e) {
Log.e(TAG, "Error process asset " + assetName + " to file path");
}
return null;
}
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
if (BuildConfig.NATIVE_BUILD) {
final String modelFileAbsoluteFilePath =
new File(assetFilePath(this, BuildConfig.MODULE_ASSET_NAME)).getAbsolutePath();
LibtorchNativeClient.loadAndForwardModel(modelFileAbsoluteFilePath);
return;
}
setContentView(R.layout.activity_main);
mTextView = findViewById(R.id.text);
startBackgroundThread();
mBackgroundHandler.post(mModuleForwardRunnable);
}
protected void startBackgroundThread() {
mBackgroundThread = new HandlerThread(TAG + "_bg");
mBackgroundThread.start();
mBackgroundHandler = new Handler(mBackgroundThread.getLooper());
}
@Override
protected void onDestroy() {
stopBackgroundThread();
super.onDestroy();
}
protected void stopBackgroundThread() {
mBackgroundThread.quitSafely();
try {
mBackgroundThread.join();
mBackgroundThread = null;
mBackgroundHandler = null;
} catch (InterruptedException e) {
Log.e(TAG, "Error stopping background thread", e);
}
}
@WorkerThread
@Nullable
protected Result doModuleForward() {
if (mModule == null) {
final long[] shape = BuildConfig.INPUT_TENSOR_SHAPE;
long numElements = 1;
for (int i = 0; i < shape.length; i++) {
numElements *= shape[i];
}
mInputTensorBuffer = Tensor.allocateFloatBuffer((int) numElements);
mInputTensor =
Tensor.fromBlob(
mInputTensorBuffer, BuildConfig.INPUT_TENSOR_SHAPE, MemoryFormat.CHANNELS_LAST);
PyTorchAndroid.setNumThreads(1);
mModule =
BuildConfig.USE_VULKAN_DEVICE
? PyTorchAndroid.loadModuleFromAsset(
getAssets(), BuildConfig.MODULE_ASSET_NAME, Device.VULKAN)
: PyTorchAndroid.loadModuleFromAsset(getAssets(), BuildConfig.MODULE_ASSET_NAME);
}
final long startTime = SystemClock.elapsedRealtime();
final long moduleForwardStartTime = SystemClock.elapsedRealtime();
final Tensor outputTensor = mModule.forward(IValue.from(mInputTensor)).toTensor();
final long moduleForwardDuration = SystemClock.elapsedRealtime() - moduleForwardStartTime;
final float[] scores = outputTensor.getDataAsFloatArray();
final long analysisDuration = SystemClock.elapsedRealtime() - startTime;
return new Result(scores, moduleForwardDuration, analysisDuration);
}
static class Result {
private final float[] scores;
private final long totalDuration;
private final long moduleForwardDuration;
public Result(float[] scores, long moduleForwardDuration, long totalDuration) {
this.scores = scores;
this.moduleForwardDuration = moduleForwardDuration;
this.totalDuration = totalDuration;
}
}
@UiThread
protected void handleResult(Result result) {
String message = String.format("forwardDuration:%d", result.moduleForwardDuration);
mTextViewStringBuilder.insert(0, '\n').insert(0, message);
if (mTextViewStringBuilder.length() > TEXT_TRIM_SIZE) {
mTextViewStringBuilder.delete(TEXT_TRIM_SIZE, mTextViewStringBuilder.length());
}
mTextView.setText(mTextViewStringBuilder.toString());
}
}

View File

@ -1,14 +0,0 @@
package org.pytorch.testapp;
class Result {
public final float[] scores;
public final long totalDuration;
public final long moduleForwardDuration;
public Result(float[] scores, long moduleForwardDuration, long totalDuration) {
this.scores = scores;
this.moduleForwardDuration = moduleForwardDuration;
this.totalDuration = totalDuration;
}
}

View File

@ -1,28 +0,0 @@
package org.pytorch.testapp;
import java.util.Arrays;
public class Utils {
public static int[] topK(float[] a, final int topk) {
float values[] = new float[topk];
Arrays.fill(values, -Float.MAX_VALUE);
int ixs[] = new int[topk];
Arrays.fill(ixs, -1);
for (int i = 0; i < a.length; i++) {
for (int j = 0; j < topk; j++) {
if (a[i] > values[j]) {
for (int k = topk - 1; k >= j + 1; k--) {
values[k] = values[k - 1];
ixs[k] = ixs[k - 1];
}
values[j] = a[i];
ixs[j] = i;
break;
}
}
}
return ixs;
}
}

View File

@ -1,23 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<FrameLayout
xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
tools:context=".CameraActivity">
<ViewStub
android:id="@+id/camera_texture_view_stub"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:layout="@layout/texture_view"/>
<TextView
android:id="@+id/text"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:layout_gravity="top"
android:textSize="16sp"
android:textStyle="bold"
android:textColor="#ff0000"/>
</FrameLayout>

View File

@ -1,17 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<FrameLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
tools:context=".MainActivity">
<TextView
android:id="@+id/text"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:layout_gravity="top"
android:textSize="14sp"
android:background="@android:color/black"
android:textColor="@android:color/white" />
</FrameLayout>

View File

@ -1,5 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<TextureView xmlns:android="http://schemas.android.com/apk/res/android"
android:id="@+id/texture_view"
android:layout_width="match_parent"
android:layout_height="0dp" />

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.0 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.7 KiB

View File

@ -1,6 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<resources>
<color name="colorPrimary">#008577</color>
<color name="colorPrimaryDark">#00574B</color>
<color name="colorAccent">#D81B60</color>
</resources>

View File

@ -1,3 +0,0 @@
<resources>
<string name="app_name">PyTest</string>
</resources>

View File

@ -1,11 +0,0 @@
<resources>
<!-- Base application theme. -->
<style name="AppTheme" parent="Theme.AppCompat.Light.DarkActionBar">
<!-- Customize your theme here. -->
<item name="colorPrimary">@color/colorPrimary</item>
<item name="colorPrimaryDark">@color/colorPrimaryDark</item>
<item name="colorAccent">@color/colorAccent</item>
</style>
</resources>

View File

@ -1,24 +0,0 @@
from torchvision import models
import torch
print(torch.version.__version__)
resnet18 = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
resnet18.eval()
resnet18_traced = torch.jit.trace(resnet18, torch.rand(1, 3, 224, 224)).save(
"app/src/main/assets/resnet18.pt"
)
resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
resnet50.eval()
torch.jit.trace(resnet50, torch.rand(1, 3, 224, 224)).save(
"app/src/main/assets/resnet50.pt"
)
mobilenet2q = models.quantization.mobilenet_v2(pretrained=True, quantize=True)
mobilenet2q.eval()
torch.jit.trace(mobilenet2q, torch.rand(1, 3, 224, 224)).save(
"app/src/main/assets/mobilenet2q.pt"
)

View File

@ -1,27 +0,0 @@
"""
This is a script for PyTorch Android custom selective build test. It prepares
MobileNetV2 TorchScript model, and dumps root ops used by the model for custom
build script to create a tailored build which only contains these used ops.
"""
import yaml
from torchvision import models
import torch
# Download and trace the model.
model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
model.eval()
example = torch.rand(1, 3, 224, 224)
# TODO: create script model with `torch.jit.script`
traced_script_module = torch.jit.trace(model, example)
# Save traced TorchScript model.
traced_script_module.save("MobileNetV2.pt")
# Dump root ops used by the model (for custom build optimization).
ops = torch.jit.export_opnames(traced_script_module)
with open("MobileNetV2.yaml", "w") as output:
yaml.dump(ops, output)