mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[vulkan] test_app for mobilenetV2 on vulkan api (#48924)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/48924 Test Plan: Imported from OSS Reviewed By: SS-JIA Differential Revision: D25365000 Pulled By: IvanKobzarev fbshipit-source-id: 79295b5781d2494681dbb4e4a741de49ff9c058c
This commit is contained in:
parent
36df25334f
commit
21ba48fe49
|
|
@ -60,20 +60,20 @@ android {
|
|||
//}
|
||||
flavorDimensions "model", "build", "activity"
|
||||
productFlavors {
|
||||
mbq {
|
||||
mnet {
|
||||
dimension "model"
|
||||
applicationIdSuffix ".mbq"
|
||||
buildConfigField("String", "MODULE_ASSET_NAME", "\"mobilenet2q.pt\"")
|
||||
addManifestPlaceholders([APP_NAME: "MBQ"])
|
||||
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mbq\"")
|
||||
applicationIdSuffix ".mnet"
|
||||
buildConfigField("String", "MODULE_ASSET_NAME", "\"mnet.pt\"")
|
||||
addManifestPlaceholders([APP_NAME: "MNET"])
|
||||
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mnet\"")
|
||||
}
|
||||
mbvulkan {
|
||||
mnetVulkan {
|
||||
dimension "model"
|
||||
applicationIdSuffix ".mbvulkan"
|
||||
buildConfigField("String", "MODULE_ASSET_NAME", "\"mobilenet2-vulkan.pt\"")
|
||||
applicationIdSuffix ".mnet_vulkan"
|
||||
buildConfigField("String", "MODULE_ASSET_NAME", "\"mnet_vulkan.pt\"")
|
||||
buildConfigField("boolean", "USE_VULKAN_DEVICE", 'true')
|
||||
addManifestPlaceholders([APP_NAME: "MBQ"])
|
||||
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mbvulkan\"")
|
||||
addManifestPlaceholders([APP_NAME: "MNET_VULKAN"])
|
||||
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mnet-vulkan\"")
|
||||
}
|
||||
resnet18 {
|
||||
dimension "model"
|
||||
|
|
|
|||
|
|
@ -119,17 +119,40 @@ vTensor pack_weights(
|
|||
}
|
||||
|
||||
// shader KO4C4HW_to_image
|
||||
float image[4 * C_4][OC_4][KH * KW][4];
|
||||
memset(image, 0.f, 16 * C_4 * OC_4 * KH * KW * sizeof(float));
|
||||
struct Image3D {
|
||||
float* data_;
|
||||
uint32_t dim0_, dim1_, dim2_;
|
||||
|
||||
Image3D(uint32_t dim0, uint32_t dim1, uint32_t dim2) {
|
||||
dim0_ = dim0;
|
||||
dim1_ = dim1;
|
||||
dim2_ = dim2;
|
||||
data_ = new float[dim0 * dim1 * dim2 * 4];
|
||||
memset(data_, 0.f, dim0 * dim1 * dim2 * 4 * sizeof(float));
|
||||
}
|
||||
|
||||
inline uint32_t idx(uint32_t i0, uint32_t i1, uint32_t i2, uint32_t i3) {
|
||||
return i3 + i2 * 4 + i1 * 4 * dim2_ + i0 * 4 * dim2_ * dim1_;
|
||||
}
|
||||
|
||||
void set(uint32_t i0, uint32_t i1, uint32_t i2, uint32_t i3, float value) {
|
||||
data_[idx(i0, i1, i2, i3)] = value;
|
||||
}
|
||||
|
||||
float get(uint32_t i0, uint32_t i1, uint32_t i2, uint32_t i3) {
|
||||
return data_[idx(i0, i1, i2, i3)];
|
||||
}
|
||||
} image{4 * C_4, OC_4, KH * KW};
|
||||
|
||||
for (uint32_t sx = 0; sx < C_4; ++sx) {
|
||||
for (uint32_t sy = 0; sy < OC_4; ++sy) {
|
||||
for (uint32_t sz = 0; sz < (KH * KW); ++sz) {
|
||||
for (uint32_t vi = 0; vi < 4; ++vi) {
|
||||
int bufferVIdx = 4 * sx * KH * KW + 4 * sy * C_4 * KH * KW + 4 * sz;
|
||||
image[4 * sx + 0][sy][sz][vi] = dst[4 * (bufferVIdx + 0) + vi];
|
||||
image[4 * sx + 1][sy][sz][vi] = dst[4 * (bufferVIdx + 1) + vi];
|
||||
image[4 * sx + 2][sy][sz][vi] = dst[4 * (bufferVIdx + 2) + vi];
|
||||
image[4 * sx + 3][sy][sz][vi] = dst[4 * (bufferVIdx + 3) + vi];
|
||||
image.set(4 * sx + 0, sy, sz, vi, dst[4 * (bufferVIdx + 0) + vi]);
|
||||
image.set(4 * sx + 1, sy, sz, vi, dst[4 * (bufferVIdx + 1) + vi]);
|
||||
image.set(4 * sx + 2, sy, sz, vi, dst[4 * (bufferVIdx + 2) + vi]);
|
||||
image.set(4 * sx + 3, sy, sz, vi, dst[4 * (bufferVIdx + 3) + vi]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -143,7 +166,7 @@ vTensor pack_weights(
|
|||
for (uint32_t sy = 0; sy < H; ++sy) {
|
||||
for (uint32_t sz = 0; sz < D; ++sz) {
|
||||
for (uint32_t szvi = 0; szvi < 4; ++szvi) {
|
||||
dst_weight_ptr[W * sy + sx + (4 * sz + szvi) * W * H] = image[sx][sy][sz][szvi];
|
||||
dst_weight_ptr[W * sy + sx + (4 * sz + szvi) * W * H] = image.get(sx, sy, sz, szvi);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user