[PT-Vulkan] aten::unsqueeze - nit optimization (#118575)

Summary:
Learning Vulkan shaders and realized one of the branches can be easily optimized.

The relevant branch is only taken when we unsqueeze along `dim == 1` for 3D tensors.
1. There's an unnecessary for-loop.
2. There's an unnecessary dependency on the output tensor's number of channels.

## CPU Tensor
```
3D->4D: (c, h, w) -> (c, 0, h, w)
```
## GPU Texture
```
3D->4D: (w, h, c/4)[c%4] -> (w, h, c)[0]
```

Note the GPU Texture's output is always at `[0]` and the output tensor's number of channels is always 1.

We are currently writing the same value `v[p]` to all elements of the texel `out_texel`, but we need only write it to `out_texel[0]`:

Test Plan:
```
[jorgep31415@161342.od /data/sandcastle/boxes/fbsource (ca3b566bc)]$ LD_LIBRARY_PATH=third-party/swiftshader/lib/linux-x64/ buck2 run fbcode/mode/dev-nosan //xplat/caffe2:pt_vulkan_api_test_bin -- --gtest_filter="*unsqueeze*"
File changed: fbcode//caffe2/aten/src/ATen/native/vulkan/glsl/unsqueeze.glsl
File changed: fbsource//xplat/caffe2/aten/src/ATen/native/vulkan/glsl/unsqueeze.glsl
Buck UI: https://www.internalfb.com/buck2/2c7f1365-e004-41a0-9201-473929a2738a
Network: Up: 174B  Down: 0B  (reSessionID-c54d25da-f44b-49f7-8bfd-1db4eee50f6d)
Jobs completed: 6. Time elapsed: 14.4s.
Cache hits: 0%. Commands: 1 (cached: 0, remote: 0, local: 1)
BUILD SUCCEEDED
Running main() from third-party/googletest/1.14.0/googletest/googletest/src/gtest_main.cc
Note: Google Test filter = *unsqueeze*
[==========] Running 10 tests from 1 test suite.
[----------] Global test environment set-up.
[----------] 10 tests from VulkanAPITest
[ RUN      ] VulkanAPITest.unsqueeze_0dto1d_dim0
[       OK ] VulkanAPITest.unsqueeze_0dto1d_dim0 (60 ms)
[ RUN      ] VulkanAPITest.unsqueeze_1dto2d_dim0
[       OK ] VulkanAPITest.unsqueeze_1dto2d_dim0 (0 ms)
[ RUN      ] VulkanAPITest.unsqueeze_1dto2d_dim1
[       OK ] VulkanAPITest.unsqueeze_1dto2d_dim1 (132 ms)
[ RUN      ] VulkanAPITest.unsqueeze_2dto3d_dim0
[       OK ] VulkanAPITest.unsqueeze_2dto3d_dim0 (20 ms)
[ RUN      ] VulkanAPITest.unsqueeze_2dto3d_dim1
[       OK ] VulkanAPITest.unsqueeze_2dto3d_dim1 (66 ms)
[ RUN      ] VulkanAPITest.unsqueeze_2dto3d_dim2
[       OK ] VulkanAPITest.unsqueeze_2dto3d_dim2 (3 ms)
[ RUN      ] VulkanAPITest.unsqueeze_3dto4d_dim0
[       OK ] VulkanAPITest.unsqueeze_3dto4d_dim0 (19 ms)
[ RUN      ] VulkanAPITest.unsqueeze_3dto4d_dim1
[       OK ] VulkanAPITest.unsqueeze_3dto4d_dim1 (1 ms)
[ RUN      ] VulkanAPITest.unsqueeze_3dto4d_dim2
[       OK ] VulkanAPITest.unsqueeze_3dto4d_dim2 (1 ms)
[ RUN      ] VulkanAPITest.unsqueeze_3dto4d_dim3
[       OK ] VulkanAPITest.unsqueeze_3dto4d_dim3 (1 ms)
[----------] 10 tests from VulkanAPITest (307 ms total)

[----------] Global test environment tear-down
[==========] 10 tests from 1 test suite ran. (307 ms total)
[  PASSED  ] 10 tests.
[
```

Differential Revision: D53189637

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118575
Approved by: https://github.com/yipjustin
This commit is contained in:
Jorge Pineda 2024-01-30 20:01:18 +00:00 committed by PyTorch MergeBot
parent d0627cc2af
commit 9247641f34

View File

@ -41,12 +41,10 @@ void main() {
if (dim == 1) {
int src_x = pos.x;
int src_y = pos.y;
for (int i = 0; i < 4; i++) {
int src_z = pos.z / (channels * 4);
int p = (pos.z / channels) % 4;
int src_z = pos.z / 4;
int p = pos.z % 4;
const vec4 v = texelFetch(uImage, ivec3(src_x, src_y, src_z), 0);
out_texel[i] = v[p];
}
out_texel[0] = v[p];
imageStore(uOutput, pos, out_texel);
} else if (dim == 2) {
int src_x = pos.x;