improve android instrumentation test and update README

Added tests for lite interpreter. By default the run_test.sh will use lite interpreter, unless manually set BUILD_LITE_INTERPRETER=0

Also fixed model generation script for android instrumentation test and README.

Verified test can pass for both full jit and lite interpreter. Also tested on emulator and real device using different abis.

Lite interpreter
```
./scripts/build_pytorch_android.sh x86
./android/run_tests.sh
```

Full JIT
```
BUILD_LITE_INTERPRETER=0 ./scripts/build_pytorch_android.sh x86
BUILD_LITE_INTERPRETER=0 ./android/run_tests.sh
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/72736
This commit is contained in:
Linbin Yu 2022-02-22 08:05:33 +00:00 committed by PyTorch MergeBot
parent c2255c36ec
commit 99bcadced4
10 changed files with 122 additions and 28 deletions

View File

@ -14,9 +14,16 @@ repositories {
jcenter()
}
# lite interpreter build
dependencies {
implementation 'org.pytorch:pytorch_android:1.6.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.6.0'
implementation 'org.pytorch:pytorch_android_lite:1.10.0'
implementation 'org.pytorch:pytorch_android_torchvision_lite:1.10.0'
}
# full jit build
dependencies {
implementation 'org.pytorch:pytorch_android:1.10.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.10.0'
}
```
@ -32,6 +39,15 @@ repositories {
}
}
# lite interpreter build
dependencies {
...
implementation 'org.pytorch:pytorch_android_lite:1.12.0-SNAPSHOT'
implementation 'org.pytorch:pytorch_android_torchvision_lite:1.12.0-SNAPSHOT'
...
}
# full jit build
dependencies {
...
implementation 'org.pytorch:pytorch_android:1.12.0-SNAPSHOT'
@ -68,7 +84,7 @@ They are specified as environment variables:
`ANDROID_HOME` - path to [Android SDK](https://developer.android.com/studio/command-line/sdkmanager.html)
`ANDROID_NDK` - path to [Android NDK](https://developer.android.com/studio/projects/install-ndk)
`ANDROID_NDK` - path to [Android NDK](https://developer.android.com/studio/projects/install-ndk). It's recommended to use NDK 21.x.
`GRADLE_HOME` - path to [gradle](https://gradle.org/releases/)
@ -133,7 +149,7 @@ android {
}
dependencies {
extractForNativeBuild('org.pytorch:pytorch_android:1.6.0')
extractForNativeBuild('org.pytorch:pytorch_android:1.10.0')
}
task extractAARForNativeBuild {

View File

@ -50,7 +50,17 @@ android {
}
androidTest {
java {
exclude 'org/pytorch/PytorchHostTests.java'
if(System.env.BUILD_LITE_INTERPRETER == '0') {
println 'Build test for full jit (pytorch_jni)'
exclude 'org/pytorch/PytorchHostTests.java'
exclude 'org/pytorch/PytorchLiteInstrumentedTests.java'
exclude 'org/pytorch/suite/PytorchLiteInstrumentedTestSuite.java'
} else {
println 'Build test for lite interpreter (pytorch_jni_lite)'
exclude 'org/pytorch/PytorchHostTests.java'
exclude 'org/pytorch/PytorchInstrumentedTests.java'
exclude 'org/pytorch/suite/PytorchInstrumentedTestSuite.java'
}
}
}
}

View File

@ -1,4 +1,6 @@
import torch
from torch import Tensor
from typing import Dict, List, Tuple, Optional
OUTPUT_DIR = "src/androidTest/assets/"
@ -7,7 +9,8 @@ def scriptAndSave(module, fileName):
script_module = torch.jit.script(module)
print(script_module.graph)
outputFileName = OUTPUT_DIR + fileName
script_module.save(outputFileName)
# note that the lite interpreter model can also be used in full JIT
script_module._save_for_lite_interpreter(outputFileName)
print("Saved to " + outputFileName)
print('=' * 80)

View File

@ -25,6 +25,7 @@ sourceSets {
java {
srcDir '../src/androidTest/java'
exclude '**/PytorchInstrumented*'
exclude '**/PytorchLiteInstrumented*'
}
resources.srcDirs = ["../src/androidTest/assets"]
}

View File

@ -10,7 +10,11 @@ import java.util.Objects;
public class PytorchHostTests extends PytorchTestBase {
@Override
protected String assetFilePath(String assetName) throws IOException {
protected Module loadModel(String path) throws IOException {
return Module.load(assetFilePath(path));
}
private String assetFilePath(String assetName) throws IOException {
Path tempFile = Files.createTempFile("test", ".pt");
try (InputStream resource =
Objects.requireNonNull(getClass().getClassLoader().getResourceAsStream("test.pt"))) {

View File

@ -14,7 +14,11 @@ import org.junit.runner.RunWith;
public class PytorchInstrumentedTests extends PytorchTestBase {
@Override
protected String assetFilePath(String assetName) throws IOException {
protected Module loadModel(String path) throws IOException {
return Module.load(assetFilePath(path));
}
private String assetFilePath(String assetName) throws IOException {
final Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext();
File file = new File(appContext.getFilesDir(), assetName);
if (file.exists() && file.length() > 0) {
@ -35,4 +39,5 @@ public class PytorchInstrumentedTests extends PytorchTestBase {
throw e;
}
}
}

View File

@ -0,0 +1,46 @@
package org.pytorch;
import android.content.Context;
import androidx.test.InstrumentationRegistry;
import androidx.test.runner.AndroidJUnit4;
import org.junit.runner.RunWith;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
@RunWith(AndroidJUnit4.class)
public class PytorchLiteInstrumentedTests extends PytorchTestBase {
@Override
protected Module loadModel(String path) throws IOException {
return LiteModuleLoader.load(assetFilePath(path));
}
private String assetFilePath(String assetName) throws IOException {
final Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext();
File file = new File(appContext.getFilesDir(), assetName);
if (file.exists() && file.length() > 0) {
return file.getAbsolutePath();
}
try (InputStream is = appContext.getAssets().open(assetName)) {
try (OutputStream os = new FileOutputStream(file)) {
byte[] buffer = new byte[4 * 1024];
int read;
while ((read = is.read(buffer)) != -1) {
os.write(buffer, 0, read);
}
os.flush();
}
return file.getAbsolutePath();
} catch (IOException e) {
throw e;
}
}
}

View File

@ -16,7 +16,7 @@ public abstract class PytorchTestBase {
@Test
public void testForwardNull() throws IOException {
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue input = IValue.from(Tensor.fromBlob(Tensor.allocateByteBuffer(1), new long[] {1}));
assertTrue(input.isTensor());
final IValue output = module.forward(input);
@ -25,7 +25,7 @@ public abstract class PytorchTestBase {
@Test
public void testEqBool() throws IOException {
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
for (boolean value : new boolean[] {false, true}) {
final IValue input = IValue.from(value);
assertTrue(input.isBool());
@ -38,7 +38,7 @@ public abstract class PytorchTestBase {
@Test
public void testEqInt() throws IOException {
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
for (long value : new long[] {Long.MIN_VALUE, -1024, -1, 0, 1, 1024, Long.MAX_VALUE}) {
final IValue input = IValue.from(value);
assertTrue(input.isLong());
@ -51,7 +51,7 @@ public abstract class PytorchTestBase {
@Test
public void testEqFloat() throws IOException {
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
double[] values =
new double[] {
-Double.MAX_VALUE,
@ -86,7 +86,7 @@ public abstract class PytorchTestBase {
}
final Tensor inputTensor = Tensor.fromBlob(inputTensorData, inputTensorShape);
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue input = IValue.from(inputTensor);
assertTrue(input.isTensor());
assertTrue(inputTensor == input.toTensor());
@ -103,7 +103,7 @@ public abstract class PytorchTestBase {
@Test
public void testEqDictIntKeyIntValue() throws IOException {
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final Map<Long, IValue> inputMap = new HashMap<>();
inputMap.put(Long.MIN_VALUE, IValue.from(-Long.MIN_VALUE));
@ -127,7 +127,7 @@ public abstract class PytorchTestBase {
@Test
public void testEqDictStrKeyIntValue() throws IOException {
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final Map<String, IValue> inputMap = new HashMap<>();
inputMap.put("long_min_value", IValue.from(Long.MIN_VALUE));
@ -151,7 +151,7 @@ public abstract class PytorchTestBase {
@Test
public void testListIntSumReturnTuple() throws IOException {
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
for (int n : new int[] {0, 1, 128}) {
long[] a = new long[n];
@ -178,7 +178,7 @@ public abstract class PytorchTestBase {
@Test
public void testOptionalIntIsNone() throws IOException {
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
assertFalse(module.runMethod("optionalIntIsNone", IValue.from(1l)).toBool());
assertTrue(module.runMethod("optionalIntIsNone", IValue.optionalNull()).toBool());
@ -186,7 +186,7 @@ public abstract class PytorchTestBase {
@Test
public void testIntEq0None() throws IOException {
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
assertTrue(module.runMethod("intEq0None", IValue.from(0l)).isNull());
assertTrue(module.runMethod("intEq0None", IValue.from(1l)).toLong() == 1l);
@ -194,7 +194,7 @@ public abstract class PytorchTestBase {
@Test(expected = IllegalArgumentException.class)
public void testRunUndefinedMethod() throws IOException {
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
module.runMethod("test_undefined_method_throws_exception");
}
@ -241,7 +241,7 @@ public abstract class PytorchTestBase {
@Test
public void testEqString() throws IOException {
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
String[] values =
new String[] {
"smoketest",
@ -260,7 +260,7 @@ public abstract class PytorchTestBase {
@Test
public void testStr3Concat() throws IOException {
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
String[] values =
new String[] {
"smoketest",
@ -281,7 +281,7 @@ public abstract class PytorchTestBase {
@Test
public void testEmptyShape() throws IOException {
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final long someNumber = 43;
final IValue input = IValue.from(Tensor.fromBlob(new long[] {someNumber}, new long[] {}));
final IValue output = module.runMethod("newEmptyShapeWithItem", input);
@ -293,7 +293,7 @@ public abstract class PytorchTestBase {
@Test
public void testAliasWithOffset() throws IOException {
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue output = module.runMethod("testAliasWithOffset");
assertTrue(output.isTensorList());
Tensor[] tensors = output.toTensorList();
@ -303,7 +303,7 @@ public abstract class PytorchTestBase {
@Test
public void testNonContiguous() throws IOException {
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue output = module.runMethod("testNonContiguous");
assertTrue(output.isTensor());
Tensor value = output.toTensor();
@ -316,7 +316,7 @@ public abstract class PytorchTestBase {
long[] inputShape = new long[] {1, 3, 2, 2};
long[] data = new long[] {1, 11, 101, 2, 12, 102, 3, 13, 103, 4, 14, 104};
Tensor inputNHWC = Tensor.fromBlob(data, inputShape, MemoryFormat.CHANNELS_LAST);
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue outputNCHW = module.runMethod("contiguous", IValue.from(inputNHWC));
assertIValueTensor(
outputNCHW,
@ -334,7 +334,7 @@ public abstract class PytorchTestBase {
long[] dataNHWDC = new long[] {1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15, 8, 16};
Tensor inputNHWDC = Tensor.fromBlob(dataNHWDC, shape, MemoryFormat.CHANNELS_LAST_3D);
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue outputNCHWD = module.runMethod("contiguous", IValue.from(inputNHWDC));
assertIValueTensor(outputNCHWD, MemoryFormat.CONTIGUOUS, shape, dataNCHWD);
@ -358,7 +358,7 @@ public abstract class PytorchTestBase {
long[] dataWeightOHWI = new long[] {2, 0, 0, 0, 1, 0, 0, 0, -1};
Tensor wNHWC = Tensor.fromBlob(dataWeightOHWI, weightShape, MemoryFormat.CHANNELS_LAST);
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue outputNCHW =
module.runMethod("conv2d", IValue.from(inputNCHW), IValue.from(wNCHW), IValue.from(false));
@ -389,5 +389,5 @@ public abstract class PytorchTestBase {
assertArrayEquals(expectedData, t.getDataAsLongArray());
}
protected abstract String assetFilePath(String assetName) throws IOException;
protected abstract Module loadModel(String assetName) throws IOException;
}

View File

@ -0,0 +1,9 @@
package org.pytorch.suite;
import org.junit.runner.RunWith;
import org.junit.runners.Suite;
import org.pytorch.PytorchLiteInstrumentedTests;
@RunWith(Suite.class)
@Suite.SuiteClasses({PytorchLiteInstrumentedTests.class})
public class PytorchLiteInstrumentedTestSuite {}