mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
improve android instrumentation test and update README
Added tests for lite interpreter. By default the run_test.sh will use lite interpreter, unless manually set BUILD_LITE_INTERPRETER=0 Also fixed model generation script for android instrumentation test and README. Verified test can pass for both full jit and lite interpreter. Also tested on emulator and real device using different abis. Lite interpreter ``` ./scripts/build_pytorch_android.sh x86 ./android/run_tests.sh ``` Full JIT ``` BUILD_LITE_INTERPRETER=0 ./scripts/build_pytorch_android.sh x86 BUILD_LITE_INTERPRETER=0 ./android/run_tests.sh ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/72736
This commit is contained in:
parent
c2255c36ec
commit
99bcadced4
|
|
@ -14,9 +14,16 @@ repositories {
|
||||||
jcenter()
|
jcenter()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# lite interpreter build
|
||||||
dependencies {
|
dependencies {
|
||||||
implementation 'org.pytorch:pytorch_android:1.6.0'
|
implementation 'org.pytorch:pytorch_android_lite:1.10.0'
|
||||||
implementation 'org.pytorch:pytorch_android_torchvision:1.6.0'
|
implementation 'org.pytorch:pytorch_android_torchvision_lite:1.10.0'
|
||||||
|
}
|
||||||
|
|
||||||
|
# full jit build
|
||||||
|
dependencies {
|
||||||
|
implementation 'org.pytorch:pytorch_android:1.10.0'
|
||||||
|
implementation 'org.pytorch:pytorch_android_torchvision:1.10.0'
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
@ -32,6 +39,15 @@ repositories {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# lite interpreter build
|
||||||
|
dependencies {
|
||||||
|
...
|
||||||
|
implementation 'org.pytorch:pytorch_android_lite:1.12.0-SNAPSHOT'
|
||||||
|
implementation 'org.pytorch:pytorch_android_torchvision_lite:1.12.0-SNAPSHOT'
|
||||||
|
...
|
||||||
|
}
|
||||||
|
|
||||||
|
# full jit build
|
||||||
dependencies {
|
dependencies {
|
||||||
...
|
...
|
||||||
implementation 'org.pytorch:pytorch_android:1.12.0-SNAPSHOT'
|
implementation 'org.pytorch:pytorch_android:1.12.0-SNAPSHOT'
|
||||||
|
|
@ -68,7 +84,7 @@ They are specified as environment variables:
|
||||||
|
|
||||||
`ANDROID_HOME` - path to [Android SDK](https://developer.android.com/studio/command-line/sdkmanager.html)
|
`ANDROID_HOME` - path to [Android SDK](https://developer.android.com/studio/command-line/sdkmanager.html)
|
||||||
|
|
||||||
`ANDROID_NDK` - path to [Android NDK](https://developer.android.com/studio/projects/install-ndk)
|
`ANDROID_NDK` - path to [Android NDK](https://developer.android.com/studio/projects/install-ndk). It's recommended to use NDK 21.x.
|
||||||
|
|
||||||
`GRADLE_HOME` - path to [gradle](https://gradle.org/releases/)
|
`GRADLE_HOME` - path to [gradle](https://gradle.org/releases/)
|
||||||
|
|
||||||
|
|
@ -133,7 +149,7 @@ android {
|
||||||
}
|
}
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
extractForNativeBuild('org.pytorch:pytorch_android:1.6.0')
|
extractForNativeBuild('org.pytorch:pytorch_android:1.10.0')
|
||||||
}
|
}
|
||||||
|
|
||||||
task extractAARForNativeBuild {
|
task extractAARForNativeBuild {
|
||||||
|
|
|
||||||
|
|
@ -50,7 +50,17 @@ android {
|
||||||
}
|
}
|
||||||
androidTest {
|
androidTest {
|
||||||
java {
|
java {
|
||||||
exclude 'org/pytorch/PytorchHostTests.java'
|
if(System.env.BUILD_LITE_INTERPRETER == '0') {
|
||||||
|
println 'Build test for full jit (pytorch_jni)'
|
||||||
|
exclude 'org/pytorch/PytorchHostTests.java'
|
||||||
|
exclude 'org/pytorch/PytorchLiteInstrumentedTests.java'
|
||||||
|
exclude 'org/pytorch/suite/PytorchLiteInstrumentedTestSuite.java'
|
||||||
|
} else {
|
||||||
|
println 'Build test for lite interpreter (pytorch_jni_lite)'
|
||||||
|
exclude 'org/pytorch/PytorchHostTests.java'
|
||||||
|
exclude 'org/pytorch/PytorchInstrumentedTests.java'
|
||||||
|
exclude 'org/pytorch/suite/PytorchInstrumentedTestSuite.java'
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,6 @@
|
||||||
import torch
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
from typing import Dict, List, Tuple, Optional
|
||||||
|
|
||||||
OUTPUT_DIR = "src/androidTest/assets/"
|
OUTPUT_DIR = "src/androidTest/assets/"
|
||||||
|
|
||||||
|
|
@ -7,7 +9,8 @@ def scriptAndSave(module, fileName):
|
||||||
script_module = torch.jit.script(module)
|
script_module = torch.jit.script(module)
|
||||||
print(script_module.graph)
|
print(script_module.graph)
|
||||||
outputFileName = OUTPUT_DIR + fileName
|
outputFileName = OUTPUT_DIR + fileName
|
||||||
script_module.save(outputFileName)
|
# note that the lite interpreter model can also be used in full JIT
|
||||||
|
script_module._save_for_lite_interpreter(outputFileName)
|
||||||
print("Saved to " + outputFileName)
|
print("Saved to " + outputFileName)
|
||||||
print('=' * 80)
|
print('=' * 80)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,7 @@ sourceSets {
|
||||||
java {
|
java {
|
||||||
srcDir '../src/androidTest/java'
|
srcDir '../src/androidTest/java'
|
||||||
exclude '**/PytorchInstrumented*'
|
exclude '**/PytorchInstrumented*'
|
||||||
|
exclude '**/PytorchLiteInstrumented*'
|
||||||
}
|
}
|
||||||
resources.srcDirs = ["../src/androidTest/assets"]
|
resources.srcDirs = ["../src/androidTest/assets"]
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Binary file not shown.
|
|
@ -10,7 +10,11 @@ import java.util.Objects;
|
||||||
public class PytorchHostTests extends PytorchTestBase {
|
public class PytorchHostTests extends PytorchTestBase {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected String assetFilePath(String assetName) throws IOException {
|
protected Module loadModel(String path) throws IOException {
|
||||||
|
return Module.load(assetFilePath(path));
|
||||||
|
}
|
||||||
|
|
||||||
|
private String assetFilePath(String assetName) throws IOException {
|
||||||
Path tempFile = Files.createTempFile("test", ".pt");
|
Path tempFile = Files.createTempFile("test", ".pt");
|
||||||
try (InputStream resource =
|
try (InputStream resource =
|
||||||
Objects.requireNonNull(getClass().getClassLoader().getResourceAsStream("test.pt"))) {
|
Objects.requireNonNull(getClass().getClassLoader().getResourceAsStream("test.pt"))) {
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,11 @@ import org.junit.runner.RunWith;
|
||||||
public class PytorchInstrumentedTests extends PytorchTestBase {
|
public class PytorchInstrumentedTests extends PytorchTestBase {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected String assetFilePath(String assetName) throws IOException {
|
protected Module loadModel(String path) throws IOException {
|
||||||
|
return Module.load(assetFilePath(path));
|
||||||
|
}
|
||||||
|
|
||||||
|
private String assetFilePath(String assetName) throws IOException {
|
||||||
final Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext();
|
final Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext();
|
||||||
File file = new File(appContext.getFilesDir(), assetName);
|
File file = new File(appContext.getFilesDir(), assetName);
|
||||||
if (file.exists() && file.length() > 0) {
|
if (file.exists() && file.length() > 0) {
|
||||||
|
|
@ -35,4 +39,5 @@ public class PytorchInstrumentedTests extends PytorchTestBase {
|
||||||
throw e;
|
throw e;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,46 @@
|
||||||
|
package org.pytorch;
|
||||||
|
|
||||||
|
import android.content.Context;
|
||||||
|
|
||||||
|
import androidx.test.InstrumentationRegistry;
|
||||||
|
import androidx.test.runner.AndroidJUnit4;
|
||||||
|
|
||||||
|
import org.junit.runner.RunWith;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.FileOutputStream;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.InputStream;
|
||||||
|
import java.io.OutputStream;
|
||||||
|
|
||||||
|
@RunWith(AndroidJUnit4.class)
|
||||||
|
public class PytorchLiteInstrumentedTests extends PytorchTestBase {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected Module loadModel(String path) throws IOException {
|
||||||
|
return LiteModuleLoader.load(assetFilePath(path));
|
||||||
|
}
|
||||||
|
|
||||||
|
private String assetFilePath(String assetName) throws IOException {
|
||||||
|
final Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext();
|
||||||
|
File file = new File(appContext.getFilesDir(), assetName);
|
||||||
|
if (file.exists() && file.length() > 0) {
|
||||||
|
return file.getAbsolutePath();
|
||||||
|
}
|
||||||
|
|
||||||
|
try (InputStream is = appContext.getAssets().open(assetName)) {
|
||||||
|
try (OutputStream os = new FileOutputStream(file)) {
|
||||||
|
byte[] buffer = new byte[4 * 1024];
|
||||||
|
int read;
|
||||||
|
while ((read = is.read(buffer)) != -1) {
|
||||||
|
os.write(buffer, 0, read);
|
||||||
|
}
|
||||||
|
os.flush();
|
||||||
|
}
|
||||||
|
return file.getAbsolutePath();
|
||||||
|
} catch (IOException e) {
|
||||||
|
throw e;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
@ -16,7 +16,7 @@ public abstract class PytorchTestBase {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testForwardNull() throws IOException {
|
public void testForwardNull() throws IOException {
|
||||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||||
final IValue input = IValue.from(Tensor.fromBlob(Tensor.allocateByteBuffer(1), new long[] {1}));
|
final IValue input = IValue.from(Tensor.fromBlob(Tensor.allocateByteBuffer(1), new long[] {1}));
|
||||||
assertTrue(input.isTensor());
|
assertTrue(input.isTensor());
|
||||||
final IValue output = module.forward(input);
|
final IValue output = module.forward(input);
|
||||||
|
|
@ -25,7 +25,7 @@ public abstract class PytorchTestBase {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testEqBool() throws IOException {
|
public void testEqBool() throws IOException {
|
||||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||||
for (boolean value : new boolean[] {false, true}) {
|
for (boolean value : new boolean[] {false, true}) {
|
||||||
final IValue input = IValue.from(value);
|
final IValue input = IValue.from(value);
|
||||||
assertTrue(input.isBool());
|
assertTrue(input.isBool());
|
||||||
|
|
@ -38,7 +38,7 @@ public abstract class PytorchTestBase {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testEqInt() throws IOException {
|
public void testEqInt() throws IOException {
|
||||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||||
for (long value : new long[] {Long.MIN_VALUE, -1024, -1, 0, 1, 1024, Long.MAX_VALUE}) {
|
for (long value : new long[] {Long.MIN_VALUE, -1024, -1, 0, 1, 1024, Long.MAX_VALUE}) {
|
||||||
final IValue input = IValue.from(value);
|
final IValue input = IValue.from(value);
|
||||||
assertTrue(input.isLong());
|
assertTrue(input.isLong());
|
||||||
|
|
@ -51,7 +51,7 @@ public abstract class PytorchTestBase {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testEqFloat() throws IOException {
|
public void testEqFloat() throws IOException {
|
||||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||||
double[] values =
|
double[] values =
|
||||||
new double[] {
|
new double[] {
|
||||||
-Double.MAX_VALUE,
|
-Double.MAX_VALUE,
|
||||||
|
|
@ -86,7 +86,7 @@ public abstract class PytorchTestBase {
|
||||||
}
|
}
|
||||||
final Tensor inputTensor = Tensor.fromBlob(inputTensorData, inputTensorShape);
|
final Tensor inputTensor = Tensor.fromBlob(inputTensorData, inputTensorShape);
|
||||||
|
|
||||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||||
final IValue input = IValue.from(inputTensor);
|
final IValue input = IValue.from(inputTensor);
|
||||||
assertTrue(input.isTensor());
|
assertTrue(input.isTensor());
|
||||||
assertTrue(inputTensor == input.toTensor());
|
assertTrue(inputTensor == input.toTensor());
|
||||||
|
|
@ -103,7 +103,7 @@ public abstract class PytorchTestBase {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testEqDictIntKeyIntValue() throws IOException {
|
public void testEqDictIntKeyIntValue() throws IOException {
|
||||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||||
final Map<Long, IValue> inputMap = new HashMap<>();
|
final Map<Long, IValue> inputMap = new HashMap<>();
|
||||||
|
|
||||||
inputMap.put(Long.MIN_VALUE, IValue.from(-Long.MIN_VALUE));
|
inputMap.put(Long.MIN_VALUE, IValue.from(-Long.MIN_VALUE));
|
||||||
|
|
@ -127,7 +127,7 @@ public abstract class PytorchTestBase {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testEqDictStrKeyIntValue() throws IOException {
|
public void testEqDictStrKeyIntValue() throws IOException {
|
||||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||||
final Map<String, IValue> inputMap = new HashMap<>();
|
final Map<String, IValue> inputMap = new HashMap<>();
|
||||||
|
|
||||||
inputMap.put("long_min_value", IValue.from(Long.MIN_VALUE));
|
inputMap.put("long_min_value", IValue.from(Long.MIN_VALUE));
|
||||||
|
|
@ -151,7 +151,7 @@ public abstract class PytorchTestBase {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testListIntSumReturnTuple() throws IOException {
|
public void testListIntSumReturnTuple() throws IOException {
|
||||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||||
|
|
||||||
for (int n : new int[] {0, 1, 128}) {
|
for (int n : new int[] {0, 1, 128}) {
|
||||||
long[] a = new long[n];
|
long[] a = new long[n];
|
||||||
|
|
@ -178,7 +178,7 @@ public abstract class PytorchTestBase {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testOptionalIntIsNone() throws IOException {
|
public void testOptionalIntIsNone() throws IOException {
|
||||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||||
|
|
||||||
assertFalse(module.runMethod("optionalIntIsNone", IValue.from(1l)).toBool());
|
assertFalse(module.runMethod("optionalIntIsNone", IValue.from(1l)).toBool());
|
||||||
assertTrue(module.runMethod("optionalIntIsNone", IValue.optionalNull()).toBool());
|
assertTrue(module.runMethod("optionalIntIsNone", IValue.optionalNull()).toBool());
|
||||||
|
|
@ -186,7 +186,7 @@ public abstract class PytorchTestBase {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testIntEq0None() throws IOException {
|
public void testIntEq0None() throws IOException {
|
||||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||||
|
|
||||||
assertTrue(module.runMethod("intEq0None", IValue.from(0l)).isNull());
|
assertTrue(module.runMethod("intEq0None", IValue.from(0l)).isNull());
|
||||||
assertTrue(module.runMethod("intEq0None", IValue.from(1l)).toLong() == 1l);
|
assertTrue(module.runMethod("intEq0None", IValue.from(1l)).toLong() == 1l);
|
||||||
|
|
@ -194,7 +194,7 @@ public abstract class PytorchTestBase {
|
||||||
|
|
||||||
@Test(expected = IllegalArgumentException.class)
|
@Test(expected = IllegalArgumentException.class)
|
||||||
public void testRunUndefinedMethod() throws IOException {
|
public void testRunUndefinedMethod() throws IOException {
|
||||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||||
module.runMethod("test_undefined_method_throws_exception");
|
module.runMethod("test_undefined_method_throws_exception");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -241,7 +241,7 @@ public abstract class PytorchTestBase {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testEqString() throws IOException {
|
public void testEqString() throws IOException {
|
||||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||||
String[] values =
|
String[] values =
|
||||||
new String[] {
|
new String[] {
|
||||||
"smoketest",
|
"smoketest",
|
||||||
|
|
@ -260,7 +260,7 @@ public abstract class PytorchTestBase {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testStr3Concat() throws IOException {
|
public void testStr3Concat() throws IOException {
|
||||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||||
String[] values =
|
String[] values =
|
||||||
new String[] {
|
new String[] {
|
||||||
"smoketest",
|
"smoketest",
|
||||||
|
|
@ -281,7 +281,7 @@ public abstract class PytorchTestBase {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testEmptyShape() throws IOException {
|
public void testEmptyShape() throws IOException {
|
||||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||||
final long someNumber = 43;
|
final long someNumber = 43;
|
||||||
final IValue input = IValue.from(Tensor.fromBlob(new long[] {someNumber}, new long[] {}));
|
final IValue input = IValue.from(Tensor.fromBlob(new long[] {someNumber}, new long[] {}));
|
||||||
final IValue output = module.runMethod("newEmptyShapeWithItem", input);
|
final IValue output = module.runMethod("newEmptyShapeWithItem", input);
|
||||||
|
|
@ -293,7 +293,7 @@ public abstract class PytorchTestBase {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testAliasWithOffset() throws IOException {
|
public void testAliasWithOffset() throws IOException {
|
||||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||||
final IValue output = module.runMethod("testAliasWithOffset");
|
final IValue output = module.runMethod("testAliasWithOffset");
|
||||||
assertTrue(output.isTensorList());
|
assertTrue(output.isTensorList());
|
||||||
Tensor[] tensors = output.toTensorList();
|
Tensor[] tensors = output.toTensorList();
|
||||||
|
|
@ -303,7 +303,7 @@ public abstract class PytorchTestBase {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNonContiguous() throws IOException {
|
public void testNonContiguous() throws IOException {
|
||||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||||
final IValue output = module.runMethod("testNonContiguous");
|
final IValue output = module.runMethod("testNonContiguous");
|
||||||
assertTrue(output.isTensor());
|
assertTrue(output.isTensor());
|
||||||
Tensor value = output.toTensor();
|
Tensor value = output.toTensor();
|
||||||
|
|
@ -316,7 +316,7 @@ public abstract class PytorchTestBase {
|
||||||
long[] inputShape = new long[] {1, 3, 2, 2};
|
long[] inputShape = new long[] {1, 3, 2, 2};
|
||||||
long[] data = new long[] {1, 11, 101, 2, 12, 102, 3, 13, 103, 4, 14, 104};
|
long[] data = new long[] {1, 11, 101, 2, 12, 102, 3, 13, 103, 4, 14, 104};
|
||||||
Tensor inputNHWC = Tensor.fromBlob(data, inputShape, MemoryFormat.CHANNELS_LAST);
|
Tensor inputNHWC = Tensor.fromBlob(data, inputShape, MemoryFormat.CHANNELS_LAST);
|
||||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||||
final IValue outputNCHW = module.runMethod("contiguous", IValue.from(inputNHWC));
|
final IValue outputNCHW = module.runMethod("contiguous", IValue.from(inputNHWC));
|
||||||
assertIValueTensor(
|
assertIValueTensor(
|
||||||
outputNCHW,
|
outputNCHW,
|
||||||
|
|
@ -334,7 +334,7 @@ public abstract class PytorchTestBase {
|
||||||
long[] dataNHWDC = new long[] {1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15, 8, 16};
|
long[] dataNHWDC = new long[] {1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15, 8, 16};
|
||||||
|
|
||||||
Tensor inputNHWDC = Tensor.fromBlob(dataNHWDC, shape, MemoryFormat.CHANNELS_LAST_3D);
|
Tensor inputNHWDC = Tensor.fromBlob(dataNHWDC, shape, MemoryFormat.CHANNELS_LAST_3D);
|
||||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||||
final IValue outputNCHWD = module.runMethod("contiguous", IValue.from(inputNHWDC));
|
final IValue outputNCHWD = module.runMethod("contiguous", IValue.from(inputNHWDC));
|
||||||
assertIValueTensor(outputNCHWD, MemoryFormat.CONTIGUOUS, shape, dataNCHWD);
|
assertIValueTensor(outputNCHWD, MemoryFormat.CONTIGUOUS, shape, dataNCHWD);
|
||||||
|
|
||||||
|
|
@ -358,7 +358,7 @@ public abstract class PytorchTestBase {
|
||||||
long[] dataWeightOHWI = new long[] {2, 0, 0, 0, 1, 0, 0, 0, -1};
|
long[] dataWeightOHWI = new long[] {2, 0, 0, 0, 1, 0, 0, 0, -1};
|
||||||
Tensor wNHWC = Tensor.fromBlob(dataWeightOHWI, weightShape, MemoryFormat.CHANNELS_LAST);
|
Tensor wNHWC = Tensor.fromBlob(dataWeightOHWI, weightShape, MemoryFormat.CHANNELS_LAST);
|
||||||
|
|
||||||
final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME));
|
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||||
|
|
||||||
final IValue outputNCHW =
|
final IValue outputNCHW =
|
||||||
module.runMethod("conv2d", IValue.from(inputNCHW), IValue.from(wNCHW), IValue.from(false));
|
module.runMethod("conv2d", IValue.from(inputNCHW), IValue.from(wNCHW), IValue.from(false));
|
||||||
|
|
@ -389,5 +389,5 @@ public abstract class PytorchTestBase {
|
||||||
assertArrayEquals(expectedData, t.getDataAsLongArray());
|
assertArrayEquals(expectedData, t.getDataAsLongArray());
|
||||||
}
|
}
|
||||||
|
|
||||||
protected abstract String assetFilePath(String assetName) throws IOException;
|
protected abstract Module loadModel(String assetName) throws IOException;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,9 @@
|
||||||
|
package org.pytorch.suite;
|
||||||
|
|
||||||
|
import org.junit.runner.RunWith;
|
||||||
|
import org.junit.runners.Suite;
|
||||||
|
import org.pytorch.PytorchLiteInstrumentedTests;
|
||||||
|
|
||||||
|
@RunWith(Suite.class)
|
||||||
|
@Suite.SuiteClasses({PytorchLiteInstrumentedTests.class})
|
||||||
|
public class PytorchLiteInstrumentedTestSuite {}
|
||||||
Loading…
Reference in New Issue
Block a user