mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[vulkan][android][test_app] Add test_app variant that runs module on Vulkan (#44897)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44897 Test Plan: Imported from OSS Reviewed By: dreiss Differential Revision: D23763770 Pulled By: IvanKobzarev fbshipit-source-id: 6ad16b7271c745313a71da64a629a764258bbc85
This commit is contained in:
parent
2c300fd74c
commit
17be7c6e5c
|
|
@ -40,6 +40,7 @@ android {
|
||||||
buildConfigField("String", "LOGCAT_TAG", "@string/app_name")
|
buildConfigField("String", "LOGCAT_TAG", "@string/app_name")
|
||||||
buildConfigField("long[]", "INPUT_TENSOR_SHAPE", "new long[]{1, 3, 224, 224}")
|
buildConfigField("long[]", "INPUT_TENSOR_SHAPE", "new long[]{1, 3, 224, 224}")
|
||||||
buildConfigField("boolean", "NATIVE_BUILD", 'false')
|
buildConfigField("boolean", "NATIVE_BUILD", 'false')
|
||||||
|
buildConfigField("boolean", "USE_VULKAN_DEVICE", 'false')
|
||||||
addManifestPlaceholders([APP_NAME: "@string/app_name", MAIN_ACTIVITY: "org.pytorch.testapp.MainActivity"])
|
addManifestPlaceholders([APP_NAME: "@string/app_name", MAIN_ACTIVITY: "org.pytorch.testapp.MainActivity"])
|
||||||
}
|
}
|
||||||
buildTypes {
|
buildTypes {
|
||||||
|
|
@ -66,9 +67,17 @@ android {
|
||||||
addManifestPlaceholders([APP_NAME: "MBQ"])
|
addManifestPlaceholders([APP_NAME: "MBQ"])
|
||||||
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mbq\"")
|
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mbq\"")
|
||||||
}
|
}
|
||||||
|
mbvulkan {
|
||||||
|
dimension "model"
|
||||||
|
applicationIdSuffix ".mbvulkan"
|
||||||
|
buildConfigField("String", "MODULE_ASSET_NAME", "\"mobilenet2-vulkan.pt\"")
|
||||||
|
buildConfigField("boolean", "USE_VULKAN_DEVICE", 'true')
|
||||||
|
addManifestPlaceholders([APP_NAME: "MBQ"])
|
||||||
|
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mbvulkan\"")
|
||||||
|
}
|
||||||
resnet18 {
|
resnet18 {
|
||||||
dimension "model"
|
dimension "model"
|
||||||
applicationIdSuffix ".resneti18"
|
applicationIdSuffix ".resnet18"
|
||||||
buildConfigField("String", "MODULE_ASSET_NAME", "\"resnet18.pt\"")
|
buildConfigField("String", "MODULE_ASSET_NAME", "\"resnet18.pt\"")
|
||||||
addManifestPlaceholders([APP_NAME: "RN18"])
|
addManifestPlaceholders([APP_NAME: "RN18"])
|
||||||
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-resnet18\"")
|
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-resnet18\"")
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@ import java.io.IOException;
|
||||||
import java.io.InputStream;
|
import java.io.InputStream;
|
||||||
import java.io.OutputStream;
|
import java.io.OutputStream;
|
||||||
import java.nio.FloatBuffer;
|
import java.nio.FloatBuffer;
|
||||||
|
import org.pytorch.Device;
|
||||||
import org.pytorch.IValue;
|
import org.pytorch.IValue;
|
||||||
import org.pytorch.Module;
|
import org.pytorch.Module;
|
||||||
import org.pytorch.PyTorchAndroid;
|
import org.pytorch.PyTorchAndroid;
|
||||||
|
|
@ -126,7 +127,9 @@ public class MainActivity extends AppCompatActivity {
|
||||||
mInputTensorBuffer = Tensor.allocateFloatBuffer((int) numElements);
|
mInputTensorBuffer = Tensor.allocateFloatBuffer((int) numElements);
|
||||||
mInputTensor = Tensor.fromBlob(mInputTensorBuffer, BuildConfig.INPUT_TENSOR_SHAPE);
|
mInputTensor = Tensor.fromBlob(mInputTensorBuffer, BuildConfig.INPUT_TENSOR_SHAPE);
|
||||||
PyTorchAndroid.setNumThreads(1);
|
PyTorchAndroid.setNumThreads(1);
|
||||||
mModule = PyTorchAndroid.loadModuleFromAsset(getAssets(), BuildConfig.MODULE_ASSET_NAME);
|
mModule = BuildConfig.USE_VULKAN_DEVICE
|
||||||
|
? PyTorchAndroid.loadModuleFromAsset(getAssets(), BuildConfig.MODULE_ASSET_NAME, Device.VULKAN)
|
||||||
|
: PyTorchAndroid.loadModuleFromAsset(getAssets(), BuildConfig.MODULE_ASSET_NAME);
|
||||||
}
|
}
|
||||||
|
|
||||||
final long startTime = SystemClock.elapsedRealtime();
|
final long startTime = SystemClock.elapsedRealtime();
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user