mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
improve android instrumentation test and update README
Added tests for lite interpreter. By default the run_test.sh will use lite interpreter, unless manually set BUILD_LITE_INTERPRETER=0 Also fixed model generation script for android instrumentation test and README. Verified test can pass for both full jit and lite interpreter. Also tested on emulator and real device using different abis. Lite interpreter ``` ./scripts/build_pytorch_android.sh x86 ./android/run_tests.sh ``` Full JIT ``` BUILD_LITE_INTERPRETER=0 ./scripts/build_pytorch_android.sh x86 BUILD_LITE_INTERPRETER=0 ./android/run_tests.sh ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/72736
This commit is contained in:
parent
c2255c36ec
commit
99bcadced4
|
|
@ -14,9 +14,16 @@ repositories {
|
|||
jcenter()
|
||||
}
|
||||
|
||||
# lite interpreter build
|
||||
dependencies {
|
||||
implementation 'org.pytorch:pytorch_android:1.6.0'
|
||||
implementation 'org.pytorch:pytorch_android_torchvision:1.6.0'
|
||||
implementation 'org.pytorch:pytorch_android_lite:1.10.0'
|
||||
implementation 'org.pytorch:pytorch_android_torchvision_lite:1.10.0'
|
||||
}
|
||||
|
||||
# full jit build
|
||||
dependencies {
|
||||
implementation 'org.pytorch:pytorch_android:1.10.0'
|
||||
implementation 'org.pytorch:pytorch_android_torchvision:1.10.0'
|
||||
}
|
||||
```
|
||||
|
||||
|
|
@ -32,6 +39,15 @@ repositories {
|
|||
}
|
||||
}
|
||||
|
||||
# lite interpreter build
|
||||
dependencies {
|
||||
...
|
||||
implementation 'org.pytorch:pytorch_android_lite:1.12.0-SNAPSHOT'
|
||||
implementation 'org.pytorch:pytorch_android_torchvision_lite:1.12.0-SNAPSHOT'
|
||||
...
|
||||
}
|
||||
|
||||
# full jit build
|
||||
dependencies {
|
||||
...
|
||||
implementation 'org.pytorch:pytorch_android:1.12.0-SNAPSHOT'
|
||||
|
|
@ -68,7 +84,7 @@ They are specified as environment variables:
|
|||
|
||||
`ANDROID_HOME` - path to [Android SDK](https://developer.android.com/studio/command-line/sdkmanager.html)
|
||||
|
||||
`ANDROID_NDK` - path to [Android NDK](https://developer.android.com/studio/projects/install-ndk)
|
||||
`ANDROID_NDK` - path to [Android NDK](https://developer.android.com/studio/projects/install-ndk). It's recommended to use NDK 21.x.
|
||||
|
||||
`GRADLE_HOME` - path to [gradle](https://gradle.org/releases/)
|
||||
|
||||
|
|
@ -133,7 +149,7 @@ android {
|
|||
}
|
||||
|
||||
dependencies {
|
||||
extractForNativeBuild('org.pytorch:pytorch_android:1.6.0')
|
||||
extractForNativeBuild('org.pytorch:pytorch_android:1.10.0')
|
||||
}
|
||||
|
||||
task extractAARForNativeBuild {
|
||||
|
|
|
|||
|
|
@ -50,7 +50,17 @@ android {
|
|||
}
|
||||
androidTest {
|
||||
java {
|
||||
exclude 'org/pytorch/PytorchHostTests.java'
|
||||
if(System.env.BUILD_LITE_INTERPRETER == '0') {
|
||||
println 'Build test for full jit (pytorch_jni)'
|
||||
exclude 'org/pytorch/PytorchHostTests.java'
|
||||
exclude 'org/pytorch/PytorchLiteInstrumentedTests.java'
|
||||
exclude 'org/pytorch/suite/PytorchLiteInstrumentedTestSuite.java'
|
||||
} else {
|
||||
println 'Build test for lite interpreter (pytorch_jni_lite)'
|
||||
exclude 'org/pytorch/PytorchHostTests.java'
|
||||
exclude 'org/pytorch/PytorchInstrumentedTests.java'
|
||||
exclude 'org/pytorch/suite/PytorchInstrumentedTestSuite.java'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
import torch
|
||||
from torch import Tensor
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
|
||||
OUTPUT_DIR = "src/androidTest/assets/"
|
||||
|
||||
|
|
@ -7,7 +9,8 @@ def scriptAndSave(module, fileName):
|
|||
script_module = torch.jit.script(module)
|
||||
print(script_module.graph)
|
||||
outputFileName = OUTPUT_DIR + fileName
|
||||
script_module.save(outputFileName)
|
||||
# note that the lite interpreter model can also be used in full JIT
|
||||
script_module._save_for_lite_interpreter(outputFileName)
|
||||
print("Saved to " + outputFileName)
|
||||
print('=' * 80)
|
||||
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ sourceSets {
|
|||
java {
|
||||
srcDir '../src/androidTest/java'
|
||||
exclude '**/PytorchInstrumented*'
|
||||
exclude '**/PytorchLiteInstrumented*'
|
||||
}
|
||||
resources.srcDirs = ["../src/androidTest/assets"]
|
||||
}
|
||||
|
|
|
|||
Binary file not shown.
|
|
@ -10,7 +10,11 @@ import java.util.Objects;
|
|||
public class PytorchHostTests extends PytorchTestBase {
|
||||
|
||||
@Override
|
||||
protected String assetFilePath(String assetName) throws IOException {
|
||||
protected Module loadModel(String path) throws IOException {
|
||||
return Module.load(assetFilePath(path));
|
||||
}
|
||||
|
||||
private String assetFilePath(String assetName) throws IOException {
|
||||
Path tempFile = Files.createTempFile("test", ".pt");
|
||||
try (InputStream resource =
|
||||
Objects.requireNonNull(getClass().getClassLoader().getResourceAsStream("test.pt"))) {
|
||||
|
|
|
|||
|
|
@ -14,7 +14,11 @@ import org.junit.runner.RunWith;
|
|||
public class PytorchInstrumentedTests extends PytorchTestBase {
|
||||
|
||||
@Override
|
||||
protected String assetFilePath(String assetName) throws IOException {
|
||||
protected Module loadModel(String path) throws IOException {
|
||||
return Module.load(assetFilePath(path));
|
||||
}
|
||||
|
||||
private String assetFilePath(String assetName) throws IOException {
|
||||
final Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext();
|
||||
File file = new File(appContext.getFilesDir(), assetName);
|
||||
if (file.exists() && file.length() > 0) {
|
||||
|
|
@ -35,4 +39,5 @@ public class PytorchInstrumentedTests extends PytorchTestBase {
|
|||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,46 @@
|
|||
package org.pytorch;
|
||||
|
||||
import android.content.Context;
|
||||
|
||||
import androidx.test.InstrumentationRegistry;
|
||||
import androidx.test.runner.AndroidJUnit4;
|
||||
|
||||
import org.junit.runner.RunWith;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
|
||||
@RunWith(AndroidJUnit4.class)
|
||||
public class PytorchLiteInstrumentedTests extends PytorchTestBase {
|
||||
|
||||
@Override
|
||||
protected Module loadModel(String path) throws IOException {
|
||||
return LiteModuleLoader.load(assetFilePath(path));
|
||||
}
|
||||
|
||||
private String assetFilePath(String assetName) throws IOException {
|
||||
final Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext();
|
||||
File file = new File(appContext.getFilesDir(), assetName);
|
||||
if (file.exists() && file.length() > 0) {
|
||||
return file.getAbsolutePath();
|
||||
}
|
||||
|
||||
try (InputStream is = appContext.getAssets().open(assetName)) {
|
||||
try (OutputStream os = new FileOutputStream(file)) {
|
||||
byte[] buffer = new byte[4 * 1024];
|
||||
int read;
|
||||
while ((read = is.read(buffer)) != -1) {
|
||||
os.write(buffer, 0, read);
|
||||
}
|
||||
os.flush();
|
||||
}
|
||||
return file.getAbsolutePath();
|
||||
} catch (IOException e) {
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -16,7 +16,7 @@ public abstract class PytorchTestBase {
|
|||
|
||||
@Test
|
||||
public void testForwardNull() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final IValue input = IValue.from(Tensor.fromBlob(Tensor.allocateByteBuffer(1), new long[] {1}));
|
||||
assertTrue(input.isTensor());
|
||||
final IValue output = module.forward(input);
|
||||
|
|
@ -25,7 +25,7 @@ public abstract class PytorchTestBase {
|
|||
|
||||
@Test
|
||||
public void testEqBool() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
for (boolean value : new boolean[] {false, true}) {
|
||||
final IValue input = IValue.from(value);
|
||||
assertTrue(input.isBool());
|
||||
|
|
@ -38,7 +38,7 @@ public abstract class PytorchTestBase {
|
|||
|
||||
@Test
|
||||
public void testEqInt() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
for (long value : new long[] {Long.MIN_VALUE, -1024, -1, 0, 1, 1024, Long.MAX_VALUE}) {
|
||||
final IValue input = IValue.from(value);
|
||||
assertTrue(input.isLong());
|
||||
|
|
@ -51,7 +51,7 @@ public abstract class PytorchTestBase {
|
|||
|
||||
@Test
|
||||
public void testEqFloat() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
double[] values =
|
||||
new double[] {
|
||||
-Double.MAX_VALUE,
|
||||
|
|
@ -86,7 +86,7 @@ public abstract class PytorchTestBase {
|
|||
}
|
||||
final Tensor inputTensor = Tensor.fromBlob(inputTensorData, inputTensorShape);
|
||||
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final IValue input = IValue.from(inputTensor);
|
||||
assertTrue(input.isTensor());
|
||||
assertTrue(inputTensor == input.toTensor());
|
||||
|
|
@ -103,7 +103,7 @@ public abstract class PytorchTestBase {
|
|||
|
||||
@Test
|
||||
public void testEqDictIntKeyIntValue() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final Map<Long, IValue> inputMap = new HashMap<>();
|
||||
|
||||
inputMap.put(Long.MIN_VALUE, IValue.from(-Long.MIN_VALUE));
|
||||
|
|
@ -127,7 +127,7 @@ public abstract class PytorchTestBase {
|
|||
|
||||
@Test
|
||||
public void testEqDictStrKeyIntValue() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final Map<String, IValue> inputMap = new HashMap<>();
|
||||
|
||||
inputMap.put("long_min_value", IValue.from(Long.MIN_VALUE));
|
||||
|
|
@ -151,7 +151,7 @@ public abstract class PytorchTestBase {
|
|||
|
||||
@Test
|
||||
public void testListIntSumReturnTuple() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
|
||||
for (int n : new int[] {0, 1, 128}) {
|
||||
long[] a = new long[n];
|
||||
|
|
@ -178,7 +178,7 @@ public abstract class PytorchTestBase {
|
|||
|
||||
@Test
|
||||
public void testOptionalIntIsNone() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
|
||||
assertFalse(module.runMethod("optionalIntIsNone", IValue.from(1l)).toBool());
|
||||
assertTrue(module.runMethod("optionalIntIsNone", IValue.optionalNull()).toBool());
|
||||
|
|
@ -186,7 +186,7 @@ public abstract class PytorchTestBase {
|
|||
|
||||
@Test
|
||||
public void testIntEq0None() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
|
||||
assertTrue(module.runMethod("intEq0None", IValue.from(0l)).isNull());
|
||||
assertTrue(module.runMethod("intEq0None", IValue.from(1l)).toLong() == 1l);
|
||||
|
|
@ -194,7 +194,7 @@ public abstract class PytorchTestBase {
|
|||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void testRunUndefinedMethod() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
module.runMethod("test_undefined_method_throws_exception");
|
||||
}
|
||||
|
||||
|
|
@ -241,7 +241,7 @@ public abstract class PytorchTestBase {
|
|||
|
||||
@Test
|
||||
public void testEqString() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
String[] values =
|
||||
new String[] {
|
||||
"smoketest",
|
||||
|
|
@ -260,7 +260,7 @@ public abstract class PytorchTestBase {
|
|||
|
||||
@Test
|
||||
public void testStr3Concat() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
String[] values =
|
||||
new String[] {
|
||||
"smoketest",
|
||||
|
|
@ -281,7 +281,7 @@ public abstract class PytorchTestBase {
|
|||
|
||||
@Test
|
||||
public void testEmptyShape() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final long someNumber = 43;
|
||||
final IValue input = IValue.from(Tensor.fromBlob(new long[] {someNumber}, new long[] {}));
|
||||
final IValue output = module.runMethod("newEmptyShapeWithItem", input);
|
||||
|
|
@ -293,7 +293,7 @@ public abstract class PytorchTestBase {
|
|||
|
||||
@Test
|
||||
public void testAliasWithOffset() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final IValue output = module.runMethod("testAliasWithOffset");
|
||||
assertTrue(output.isTensorList());
|
||||
Tensor[] tensors = output.toTensorList();
|
||||
|
|
@ -303,7 +303,7 @@ public abstract class PytorchTestBase {
|
|||
|
||||
@Test
|
||||
public void testNonContiguous() throws IOException {
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final IValue output = module.runMethod("testNonContiguous");
|
||||
assertTrue(output.isTensor());
|
||||
Tensor value = output.toTensor();
|
||||
|
|
@ -316,7 +316,7 @@ public abstract class PytorchTestBase {
|
|||
long[] inputShape = new long[] {1, 3, 2, 2};
|
||||
long[] data = new long[] {1, 11, 101, 2, 12, 102, 3, 13, 103, 4, 14, 104};
|
||||
Tensor inputNHWC = Tensor.fromBlob(data, inputShape, MemoryFormat.CHANNELS_LAST);
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final IValue outputNCHW = module.runMethod("contiguous", IValue.from(inputNHWC));
|
||||
assertIValueTensor(
|
||||
outputNCHW,
|
||||
|
|
@ -334,7 +334,7 @@ public abstract class PytorchTestBase {
|
|||
long[] dataNHWDC = new long[] {1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15, 8, 16};
|
||||
|
||||
Tensor inputNHWDC = Tensor.fromBlob(dataNHWDC, shape, MemoryFormat.CHANNELS_LAST_3D);
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
final IValue outputNCHWD = module.runMethod("contiguous", IValue.from(inputNHWDC));
|
||||
assertIValueTensor(outputNCHWD, MemoryFormat.CONTIGUOUS, shape, dataNCHWD);
|
||||
|
||||
|
|
@ -358,7 +358,7 @@ public abstract class PytorchTestBase {
|
|||
long[] dataWeightOHWI = new long[] {2, 0, 0, 0, 1, 0, 0, 0, -1};
|
||||
Tensor wNHWC = Tensor.fromBlob(dataWeightOHWI, weightShape, MemoryFormat.CHANNELS_LAST);
|
||||
|
||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
|
||||
final IValue outputNCHW =
|
||||
module.runMethod("conv2d", IValue.from(inputNCHW), IValue.from(wNCHW), IValue.from(false));
|
||||
|
|
@ -389,5 +389,5 @@ public abstract class PytorchTestBase {
|
|||
assertArrayEquals(expectedData, t.getDataAsLongArray());
|
||||
}
|
||||
|
||||
protected abstract String assetFilePath(String assetName) throws IOException;
|
||||
protected abstract Module loadModel(String assetName) throws IOException;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,9 @@
|
|||
package org.pytorch.suite;
|
||||
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Suite;
|
||||
import org.pytorch.PytorchLiteInstrumentedTests;
|
||||
|
||||
@RunWith(Suite.class)
|
||||
@Suite.SuiteClasses({PytorchLiteInstrumentedTests.class})
|
||||
public class PytorchLiteInstrumentedTestSuite {}
|
||||
Loading…
Reference in New Issue
Block a user