Summary: Vulkan rewrite so that quantized transpose 2d ops can run in a model
Test Plan:
Run vulkan api test:
# buck2 build --target-platforms ovr_config//platform/macos:arm64-fbsource //xplat/caffe2:pt_vulkan_api_test_binAppleMac\#macosx-arm64 -c pt.vulkan_full_precision=1 --show-output
# buck-out/v2/gen/fbsource/xplat/caffe2/pt_vulkan_api_test_binAppleMac
Running main() from third-party/googletest/1.14.0/googletest/googletest/src/gtest_main.cc
[==========] Running 418 tests from 1 test suite.
[----------] Global test environment set-up.
[----------] 418 tests from VulkanAPITest
....
[----------] Global test environment tear-down
[==========] 418 tests from 1 test suite ran. (4510 ms total)
[ PASSED ] 417 tests.
[ SKIPPED ] 1 test, listed below:
[ SKIPPED ] VulkanAPITest.querypool_flushed_shader_log
YOU HAVE 9 DISABLED TESTS
Run quantized vulkan api test: Note that the quantized linear tests are failing, but all the convolution tests still pass. The linear failures are being debugged.
# buck2 build --target-platforms ovr_config//platform/macos:arm64-fbsource //xplat/caffe2:pt_vulkan_quantized_api_test_binAppleMac\#macosx-arm64 -c pt.vulkan_full_precision=1 --show-output
# buck-out/v2/gen/fbsource/xplat/caffe2/pt_vulkan_quantized_api_test_binAppleMac
Running main() from third-party/googletest/1.14.0/googletest/googletest/src/gtest_main.cc
[==========] Running 86 tests from 1 test suite.
[----------] Global test environment set-up.
[----------] 86 tests from VulkanAPITest
...
[ PASSED ] 77 tests.
[ FAILED ] 9 tests, listed below:
[ FAILED ] VulkanAPITest.linear_2d_flat
[ FAILED ] VulkanAPITest.linear_2d_small
[ FAILED ] VulkanAPITest.linear_2d_large
[ FAILED ] VulkanAPITest.linear_3d_flat
[ FAILED ] VulkanAPITest.linear_3d_small
[ FAILED ] VulkanAPITest.linear_3d_large
[ FAILED ] VulkanAPITest.linear_4d_flat
[ FAILED ] VulkanAPITest.linear_4d_small
[ FAILED ] VulkanAPITest.linear_4d_large
9 FAILED TESTS
YOU HAVE 8 DISABLED TESTS
# Run CUNET quantized model on hibiki board.
Reviewed By: manuelcandales
Differential Revision: D52344263
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122547
Approved by: https://github.com/manuelcandales, https://github.com/copyrightly, https://github.com/yipjustin
Summary:
`conv1d` has two arguments, `weight` and `bias`, which are stored as constant tensors on the CPU and transferred to the GPU at every inference call. We create a context for this operator to avoid the repeated transfer. Specifically, we
- created `Conv1dPackedContext`, `create_conv1d_context` and `run_conv1d_context` in `Convolution.h` and `Convolution.cpp`
- registered them in `Register.cpp`
- rewrote the graph representation of the op in `vulkan_rewrite.cpp`
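A minimal Python sketch of the pack-once/run-many pattern (the real implementation is C++, where context creation also uploads the constants to the GPU once; the class and argument names below are illustrative):
```
import torch
import torch.nn.functional as F

class Conv1dContext:
    """Toy stand-in for Conv1dPackedContext: capture the constants once."""
    def __init__(self, weight, bias, stride=1, padding=0, dilation=1, groups=1):
        # "create_conv1d_context": done once, outside the inference loop;
        # the real backend transfers weight/bias to the GPU here
        self.weight, self.bias = weight, bias
        self.args = (stride, padding, dilation, groups)

    def run(self, x):
        # "run_conv1d_context": every inference call reuses the stored constants
        return F.conv1d(x, self.weight, self.bias, *self.args)

ctx = Conv1dContext(torch.randn(33, 16, 3), torch.randn(33))
out = ctx.run(torch.randn(20, 16, 50))  # shape (20, 33, 48)
```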
Test Plan:
## Numerical test
```
[luwei@82308.od /data/sandcastle/boxes/fbsource (8a8d911dc)]$ LD_LIBRARY_PATH=third-party/swiftshader/lib/linux-x64/ buck run fbcode/mode/dev-nosan //xplat/caffe2:pt_vulkan_api_test_bin -- --gtest_filter="*conv1d*"
Buck UI: https://www.internalfb.com/buck2/7760800b-fd75-479a-9368-be5fcd5a7fef
Network: Up: 0B Down: 0B
Jobs completed: 4. Time elapsed: 0.6s.
BUILD SUCCEEDED
Running main() from third-party/googletest/1.14.0/googletest/googletest/src/gtest_main.cc
Note: Google Test filter = *conv1d*
[==========] Running 2 tests from 1 test suite.
[----------] Global test environment set-up.
[----------] 2 tests from VulkanAPITest
[ RUN ] VulkanAPITest.conv1d_simple
[ OK ] VulkanAPITest.conv1d_simple (159 ms)
[ RUN ] VulkanAPITest.conv1d
[ OK ] VulkanAPITest.conv1d (57 ms)
[----------] 2 tests from VulkanAPITest (217 ms total)
[----------] Global test environment tear-down
[==========] 2 tests from 1 test suite ran. (217 ms total)
[ PASSED ] 2 tests.
```
Full test results are in P1053644934; summary below:
```
[----------] 419 tests from VulkanAPITest (28080 ms total)
[----------] Global test environment tear-down
[==========] 419 tests from 1 test suite ran. (28080 ms total)
[ PASSED ] 418 tests.
[ SKIPPED ] 1 test, listed below:
[ SKIPPED ] VulkanAPITest.querypool_flushed_shader_log
```
## Graph representation comparison
We created a model using `conv1d` and traced it as shown below:
```
# Define a simple model that uses conv1d
import torch
import torch.nn as nn

class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1d = nn.Conv1d(16, 33, 3)

    def forward(self, x):
        return self.conv1d(x)

# Create an instance of the model
model = MyModel()
# Create a dummy input tensor for tracing
input_tensor = torch.randn(20, 16, 50)
# Use torch.jit.trace to trace the model and generate a graph
traced_model = torch.jit.trace(model, input_tensor)
```
Then we converted the traced model to the Vulkan backend using `optimize_for_mobile`:
```
from torch.utils import mobile_optimizer

to_preserve = []  # no extra methods need preserving for this toy model
vulkan_model = mobile_optimizer.optimize_for_mobile(
    traced_model, backend="vulkan", preserved_methods=to_preserve
)
```
Next we can print the graph of the model with `print(vulkan_model.graph)`:
- before this diff: `conv1d` was used
```
graph(%self.1 : __torch__.___torch_mangle_16.MyModel,
%x : Tensor):
%60 : Device = prim::Constant[value="cpu"]()
%self.conv1d.bias : Float(33, strides=[1], requires_grad=0, device=cpu) = prim::Constant[value=<Tensor>]()
%37 : bool = prim::Constant[value=0]()
%36 : NoneType = prim::Constant()
%59 : Device = prim::Constant[value="vulkan"]()
%self.conv1d.weight : Float(33, 16, 3, strides=[48, 3, 1], requires_grad=0, device=cpu) = prim::Constant[value=<Tensor>]()
%7 : int = prim::Constant[value=1](), scope: __module.conv1d # /mnt/xarfuse/uid-23453/243f3953-seed-nspid4026532834_cgpid7972545-ns-4026532831/torch/nn/modules/conv.py:306:0
%18 : int[] = prim::Constant[value=[1]]()
%19 : int[] = prim::Constant[value=[0]]()
%39 : Tensor = aten::to(%x, %59, %36, %37, %37)
%20 : Tensor = aten::conv1d(%39, %self.conv1d.weight, %self.conv1d.bias, %18, %19, %18, %7)
%58 : Tensor = aten::to(%20, %60, %36, %37, %37)
return (%58)
```
- after this diff: `conv1d` was replaced with `run_conv1d_context`
```
graph(%self.1 : __torch__.___torch_mangle_6.MyModel,
%x : Tensor):
%85 : Device = prim::Constant[value="cpu"]()
%51 : bool = prim::Constant[value=0]()
%50 : NoneType = prim::Constant()
%84 : Device = prim::Constant[value="vulkan"]()
%53 : Tensor = aten::to(%x, %84, %50, %51, %51)
%prepack_folding_forward._jit_pass_packed_weight_0 : __torch__.torch.classes.vulkan.Conv1dPackedContext = prim::GetAttr[name="prepack_folding_forward._jit_pass_packed_weight_0"](%self.1)
%22 : Tensor = vulkan_prepack::run_conv1d_context(%53, %prepack_folding_forward._jit_pass_packed_weight_0)
%83 : Tensor = aten::to(%22, %85, %50, %51, %51)
return (%83)
```
Differential Revision: D52865379
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117780
Approved by: https://github.com/yipjustin
Summary:
`Layernorm` has two arguments, `weight` and `bias`, which are stored as constant tensors on the CPU and transferred to the GPU at every inference call. We create a context for this op to avoid the repeated transfer. Specifically, we
- created `create_layernorm_context` and `run_layernorm_context` in `Layernorm.h` and `Layernorm.cpp`
- registered them in `Register.cpp`
- rewrote the graph representation of the op in `vulkan_rewrite.cpp`
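For reference, the `weight` and `bias` being packed are just the affine parameters of layer norm; a plain eager-mode check (nothing Vulkan-specific) of what they do:
```
import torch
import torch.nn.functional as F

x = torch.randn(1, 10)
weight, bias = torch.randn(10), torch.randn(10)
eps = 1e-5

y_ref = F.layer_norm(x, [10], weight, bias, eps)

# normalize over the last dimension, then apply the affine parameters
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, unbiased=False, keepdim=True)  # biased variance
y = (x - mean) / torch.sqrt(var + eps) * weight + bias
assert torch.allclose(y, y_ref, atol=1e-6)
```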
Test Plan:
## Numerical test
```
[luwei@devbig984.prn1 /data/users/luwei/fbsource (b6ccc956c)]$ LD_LIBRARY_PATH=third-party/swiftshader/lib/linux-x64/ buck run fbcode/mode/dev-nosan //xplat/caffe2:pt_vulkan_api_test_bin -- --gtest_filter="*layer_norm*"
Recommended: For faster builds try buck2: replace 'buck' with 'buck2'
NOTE: buck-out/ has changed: look for files in fbsource/buck-out/v2/
'buck2 build --show-output //xplat/caffe2:pt_vulkan_api_test_bin' will print the new output paths.
If you are building in fbsource//xplat and have questions, post in 'Cross Platform Dev Discussions': https://fb.workplace.com/groups/xplat.qa
Targets matching .buckconfig buck2.supported_projects:
{'//xplat/caffe2:pt_vulkan_api_test_bin': '//xplat'}
To suppress this warning: touch ~/.config/.dont_hint_buck2
Building: finished in 0.1 sec (100%) 339/339 jobs, 0/339 updated
Total time: 0.2 sec
BUILD SUCCEEDED
Running main() from third-party/googletest/1.14.0/googletest/googletest/src/gtest_main.cc
Note: Google Test filter = *layer_norm*
[==========] Running 10 tests from 1 test suite.
[----------] Global test environment set-up.
[----------] 10 tests from VulkanAPITest
[ RUN ] VulkanAPITest.packed_layer_norm_2d
[ OK ] VulkanAPITest.packed_layer_norm_2d (342 ms)
[ RUN ] VulkanAPITest.packed_layer_norm_3d
[ OK ] VulkanAPITest.packed_layer_norm_3d (284 ms)
[ RUN ] VulkanAPITest.packed_layer_norm_4d
[ OK ] VulkanAPITest.packed_layer_norm_4d (5 ms)
[ RUN ] VulkanAPITest.layer_norm_invalid_inputs
[ OK ] VulkanAPITest.layer_norm_invalid_inputs (28 ms)
[ RUN ] VulkanAPITest.layer_norm_2d
[ OK ] VulkanAPITest.layer_norm_2d (1 ms)
[ RUN ] VulkanAPITest.layer_norm_3d
[ OK ] VulkanAPITest.layer_norm_3d (2 ms)
[ RUN ] VulkanAPITest.layer_norm_4d
[ OK ] VulkanAPITest.layer_norm_4d (4 ms)
[ RUN ] VulkanAPITest.native_layer_norm_2d
[ OK ] VulkanAPITest.native_layer_norm_2d (1 ms)
[ RUN ] VulkanAPITest.native_layer_norm_3d
[ OK ] VulkanAPITest.native_layer_norm_3d (2 ms)
[ RUN ] VulkanAPITest.native_layer_norm_4d
[ OK ] VulkanAPITest.native_layer_norm_4d (6 ms)
[----------] 10 tests from VulkanAPITest (679 ms total)
[----------] Global test environment tear-down
[==========] 10 tests from 1 test suite ran. (679 ms total)
[ PASSED ] 10 tests.
```
Full test results are in P888496077; summary below:
```
[----------] 419 tests from VulkanAPITest (21652 ms total)
[----------] Global test environment tear-down
[==========] 419 tests from 1 test suite ran. (21652 ms total)
[ PASSED ] 418 tests.
[ SKIPPED ] 1 test, listed below:
[ SKIPPED ] VulkanAPITest.querypool_flushed_shader_log
```
## Graph representation comparison
We created a model using `layer_norm` and traced it as shown below:
```
import torch

class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.layer_norm = torch.nn.LayerNorm(normalized_shape=10)

    def forward(self, x):
        return self.layer_norm(x)

# Create an instance of the model
model = MyModel()
# Create a dummy input tensor for tracing
input_tensor = torch.randn(1, 10)
# Use torch.jit.trace to trace the model and generate a graph
traced_model = torch.jit.trace(model, input_tensor)
```
Then we converted the traced model to the Vulkan backend using `optimize_for_mobile`:
```
from torch.utils import mobile_optimizer

to_preserve = []  # no extra methods need preserving for this toy model
vulkan_model = mobile_optimizer.optimize_for_mobile(
    traced_model, backend="vulkan", preserved_methods=to_preserve
)
```
Then we can print the graph of the model with `print(vulkan_model.graph)`:
- Before this diff
```
%4 : bool = prim::Constant[value=1](), scope: __module.layer_norm # /mnt/xarfuse/uid-602118/33e18f68-seed-nspid4026531836_cgpid32066351-ns-4026531840/torch/nn/functional.py:2546:0
%5 : float = prim::Constant[value=1.0000000000000001e-05](), scope: __module.layer_norm # /mnt/xarfuse/uid-602118/33e18f68-seed-nspid4026531836_cgpid32066351-ns-4026531840/torch/nn/functional.py:2546:0
%14 : int[] = prim::Constant[value=[10]]()
%33 : Tensor = aten::to(%x, %53, %30, %31, %31)
%10 : Tensor = aten::layer_norm(%33, %14, %self.layer_norm.weight, %self.layer_norm.bias, %5, %4), scope: __module.layer_norm # /mnt/xarfuse/uid-602118/33e18f68-seed-nspid4026531836_cgpid32066351-ns-4026531840/torch/nn/functional.py:2546:0
```
- After this diff
```
%14 : int[] = prim::Constant[value=[10]]()
%47 : Tensor = aten::to(%x, %78, %44, %45, %45)
%16 : Tensor = vulkan_prepack::run_layernorm_context(%47, %14, %17)
```
Reviewed By: SS-JIA
Differential Revision: D51530478
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114701
Approved by: https://github.com/yipjustin
Summary: Add quantized linear for Vulkan to the custom ops so it can be used from a model.
Test Plan:
buck2 run --target-platforms ovr_config//platform/macos:arm64-fbsource -c pt.vulkan_full_precision=1
//xplat/caffe2/fb/custom_ops/vulkan_quantized:pt_vulkan_quantized_test_binAppleMac\#macosx-arm64
[ OK ] VulkanAPITest.convert_qconv2d_context (135 ms)
[ RUN ] VulkanAPITest.linear_2d
[ OK ] VulkanAPITest.linear_2d (4 ms)
[----------] 2 tests from VulkanAPITest (139 ms total)
[----------] Global test environment tear-down
[==========] 2 tests from 1 test suite ran. (139 ms total)
[ PASSED ] 2 tests.
##############################################################
buck2 build --target-platforms ovr_config//platform/macos:arm64-fbsource
//xplat/caffe2:pt_vulkan_quantized_api_test_binAppleMac\#macosx-arm64 -c pt.vulkan_full_precision=1 --show-output
buck-out/v2/gen/fbsource/xplat/caffe2/pt_vulkan_quantized_api_test_binAppleMac
[ OK ] VulkanAPITest.conv2d_pw_quantized_prepack_random_params_int8_int32 (11 ms)
[ RUN ] VulkanAPITest.linear_2d_flat
[ OK ] VulkanAPITest.linear_2d_flat (4 ms)
[ RUN ] VulkanAPITest.linear_2d_small
[ OK ] VulkanAPITest.linear_2d_small (1 ms)
[ RUN ] VulkanAPITest.linear_2d_large
[ OK ] VulkanAPITest.linear_2d_large (1 ms)
[ RUN ] VulkanAPITest.linear_3d_flat
[ OK ] VulkanAPITest.linear_3d_flat (2 ms)
[ RUN ] VulkanAPITest.linear_3d_small
[ OK ] VulkanAPITest.linear_3d_small (2 ms)
[ RUN ] VulkanAPITest.linear_3d_large
[ OK ] VulkanAPITest.linear_3d_large (1 ms)
[ RUN ] VulkanAPITest.linear_4d_flat
[ OK ] VulkanAPITest.linear_4d_flat (1 ms)
[ RUN ] VulkanAPITest.linear_4d_small
[ OK ] VulkanAPITest.linear_4d_small (1 ms)
[ RUN ] VulkanAPITest.linear_4d_large
[ OK ] VulkanAPITest.linear_4d_large (1 ms)
[ RUN ] VulkanAPITest.linear_custom
[ OK ] VulkanAPITest.linear_custom (0 ms)
[----------] 76 tests from VulkanAPITest (1811 ms total)
[----------] Global test environment tear-down
[==========] 76 tests from 1 test suite ran. (1811 ms total)
[ PASSED ] 76 tests.
YOU HAVE 8 DISABLED TESTS
##############################################################
buck2 run --target-platforms ovr_config//platform/macos:arm64-fbsource //xplat/caffe2:pt_vulkan_api_test_binAppleMac\#macosx-arm64 -c pt.vulkan_full_precision=1
[----------] Global test environment tear-down
[==========] 346 tests from 1 test suite ran. (5648 ms total)
[ PASSED ] 345 tests.
[ SKIPPED ] 1 test, listed below:
[ SKIPPED ] VulkanAPITest.querypool_flushed_shader_log
YOU HAVE 5 DISABLED TESTS
Reviewed By: manuelcandales
Differential Revision: D49609985
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111148
Approved by: https://github.com/yipjustin
Summary:
This diff registers the Vulkan quantized binary ops (add/sub/mul/div), and adds graph rewrites for quantized add, mul, conv2d and conv2d_relu.
The rewrites for conv2d and conv2d_relu make use of `convert_qconv2d_context`, introduced in D41595032.
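For context, these are the eager-mode CPU counterparts of the quantized binary ops being registered; note that the output scale and zero_point are explicit arguments, which the graph rewrite has to carry through:
```
import torch

a, b = torch.randn(2, 3), torch.randn(2, 3)
qa = torch.quantize_per_tensor(a, scale=0.1, zero_point=10, dtype=torch.quint8)
qb = torch.quantize_per_tensor(b, scale=0.1, zero_point=10, dtype=torch.quint8)

# quantized add takes an explicit output scale and zero_point
qc = torch.ops.quantized.add(qa, qb, 0.2, 10)
print(qc.dequantize())
```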
Test Plan: export quantized mcs model to vulkan
Reviewed By: SS-JIA
Differential Revision: D44189363
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97468
Approved by: https://github.com/SS-JIA
Summary:
Avoid dereferencing element [0] if the vector is empty.
___
In `transferInputOutputBackends`, one of the Vulkan rewrite passes for `optimize_for_mobile`, an out-of-bounds access happens when trying to insert a backend transfer for an input whose `uses()` is empty. This diff corrects that issue.
Test Plan:
Run tests
___
Phabricator + CI Tests
Reviewed By: SS-JIA
Differential Revision: D41296037
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92918
Approved by: https://github.com/SS-JIA, https://github.com/kirklandsign
Summary:
This diff fixes several issues in the GRU and LSTM Vulkan ops:
- Added `create_gru_context` and `create_lstm_context` to `vulkanFoldPrePackingOps`
- Added a filter to `insertPrePackedGruOp` and `insertPrePackedLstmOp` to avoid matching `gru.data` and `lstm.data` usages
- Fixed the output dimension of GRU and LSTM
- Allowed `batch_first` to be false when batch=1 and seq=1 (see the check below)
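The `batch_first` relaxation is safe because at batch=1 and seq=1 the (N, L, H) and (L, N, H) layouts coincide; a quick eager-mode check with stock torch.nn:
```
import torch

torch.manual_seed(0)
gru_bf = torch.nn.GRU(input_size=4, hidden_size=3, batch_first=True)
gru_sf = torch.nn.GRU(input_size=4, hidden_size=3, batch_first=False)
gru_sf.load_state_dict(gru_bf.state_dict())  # identical weights

# with N=1 and L=1, (N, L, H_in) and (L, N, H_in) are the same tensor layout
x = torch.randn(1, 1, 4)
out_bf, h_bf = gru_bf(x)
out_sf, h_sf = gru_sf(x)
assert torch.allclose(out_bf, out_sf) and torch.allclose(h_bf, h_sf)
```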
Test Plan:
Check that optimize_for_mobile runs and correctly folds the create context ops
```
buck run :export_for_mobile ~/ferraris/ferraris.ptl ~/ferraris
```
Check that vulkan api tests are still passing
```
buck run //xplat/caffe2:pt_vulkan_api_test_binAppleMac\#macosx-arm64
```
Reviewed By: SS-JIA
Differential Revision: D38811967
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83722
Approved by: https://github.com/SS-JIA
Summary:
High level description of this diff:
- VulkanOpContext is eliminated
- LinearPackedContext, Conv2dPackedContext, GruPackedContext & LstmPackedContext are introduced.
- They are child classes of the virtual class VulkanPackedContext.
- Their purpose is to pack and unpack the context for each of those ops. They unpack the context on serialization, and pack it on deserialization.
- They differ from the old op-specific contexts (LinearOpContext, Conv2dOpContext, etc.) in two important ways: they only store the packed data, and they do not contain the logic for running the op. (In this diff, the unpack functions for LinearPackedContext and Conv2dPackedContext haven't been implemented yet, so we cheat by including a private unpacked_ list inside each; a future diff will implement unpacking for those two classes and remove that private list.)
- The old LinearOpContext, GruOpContext & LstmOpContext are completely eliminated. Conv2dOpContext is maintained for backwards compatibility, but it is just a wrapper around Conv2dPackedContext.
- A lot of code from Convolution.cpp was repeated in TransposeConvolution2d.cpp and QuantizedConvolution.cpp, so the logic was combined, introducing transposed and quantized flags where appropriate, and everything was moved to Convolution.cpp & Convolution.h.
- The top level convolution functions defined in Register.cpp are moved to Convolution.cpp.
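A rough Python analogy of the new layout (the real classes are C++, and the packing transform shown here is invented purely for illustration):
```
from abc import ABC, abstractmethod
from typing import Any, List
import torch

class VulkanPackedContext(ABC):
    """Analogy of the virtual base: subclasses hold only packed data."""
    @abstractmethod
    def unpack(self) -> List[Any]:
        """Recover the unpacked arguments (needed when serializing)."""

class LinearPackedContext(VulkanPackedContext):
    def __init__(self, weight: torch.Tensor, bias: torch.Tensor):
        # pack on construction (the deserialization path); the transpose
        # here is an invented stand-in for the real packing transform
        self.packed = (weight.t().contiguous(), bias)

    def unpack(self) -> List[Any]:
        w_packed, bias = self.packed
        return [w_packed.t(), bias]

ctx = LinearPackedContext(torch.randn(3, 4), torch.randn(3))
weight, bias = ctx.unpack()
```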
Test Plan:
Run vulkan_api_test
- On Mac:
```
buck run //xplat/caffe2:pt_vulkan_api_test_binAppleMac
```
- On Android:
```
buck build -c ndk.custom_libcxx=false -c pt.enable_qpl=0 //xplat/caffe2:pt_vulkan_api_test_binAndroid\#android-arm64 --show-output
adb push buck-out/gen/xplat/caffe2/pt_vulkan_api_test_binAndroid\#android-arm64 /data/local/tmp/vulkan_api_test
adb shell "/data/local/tmp/vulkan_api_test"
```
Run vulkan_quantized_api_test
- On Mac:
```
buck run //xplat/caffe2:pt_vulkan_quantized_api_test_binAppleMac
```
- On Android:
```
buck build -c ndk.custom_libcxx=false -c pt.enable_qpl=0 //xplat/caffe2:pt_vulkan_quantized_api_test_binAndroid\#android-arm64 --show-output
adb push buck-out/gen/xplat/caffe2/pt_vulkan_quantized_api_test_binAndroid\#android-arm64 /data/local/tmp/vulkan_quantized_api_test
adb shell "/data/local/tmp/vulkan_quantized_api_test"
```
Reviewed By: SS-JIA
Differential Revision: D38363981
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82730
Approved by: https://github.com/SS-JIA
Summary:
Optimized the LSTM operator by using pre-packing for weights and biases in the Vulkan GPU backend:
- The weights and biases are always on the CPU side by design.
- The packed and unpacked data are stored in a VulkanOpContext
- Ops:
- `at::native::vulkan::ops::create_lstm_context`: Creates a VulkanOpContext object with the packed and unpacked data, and returns a pointer to it.
- `at::native::vulkan::ops::run_lstm_context`: Takes in the three input vulkan tensors (input sequence, initial hidden state and initial cell state) and a pointer to the context, and runs the LSTM operation.
- Registered the ops in [Register.cpp](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/vulkan/ops/Register.cpp).
- Rewrote the subgraph function of LSTM in [vulkan_rewrite.cpp](https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/passes/vulkan_rewrite.cpp) so that `create_lstm_context` and `run_lstm_context` are executed instead on the Vulkan GPU backend.
- Added new test for the LSTM pre-packing and run ops: `lstm_prepack_success`
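In eager-mode terms, the three tensors that `run_lstm_context` consumes are the standard LSTM inputs; the sizes below are arbitrary:
```
import torch

lstm = torch.nn.LSTM(input_size=8, hidden_size=16, num_layers=1, batch_first=True)
x  = torch.randn(1, 5, 8)    # input sequence (N, L, H_in)
h0 = torch.zeros(1, 1, 16)   # initial hidden state (num_layers, N, H_out)
c0 = torch.zeros(1, 1, 16)   # initial cell state
out, (hn, cn) = lstm(x, (h0, c0))
print(out.shape, hn.shape, cn.shape)  # (1, 5, 16), (1, 1, 16), (1, 1, 16)
```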
Test Plan: buck run //xplat/caffe2:pt_vulkan_api_test_binAppleMac
Reviewed By: SS-JIA
Differential Revision: D37052597
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79702
Approved by: https://github.com/SS-JIA
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73599
Optimized GRU operator by using pre-packing for weights and biases in the Vulkan GPU backend:
* The weights and biases are always on the CPU side by design.
* To avoid the overhead of retrieving the weight and bias tensors on every call, the best approach is to store them pre-packed.
* A custom op context `GruOpContext` (derived from `torch::jit::CustomClassHolder`) is created to hold both packed and unpacked data. The unpacked_ struct represents the data needed to construct the op context; this data is pre-packed and stored in the packed_ struct. The constructor of the `GruOpContext` loads the data into both structs.
* The `at::native::vulkan::ops::gru_prepack` and `at::native::vulkan::ops::gru_run` methods use the op context. `gru_prepack` takes in whatever data is needed to construct the op context and returns a pointer to the created context. `gru_run` takes input tensors and a pointer to the op context, and uses the data stored in the context to process the inputs.
* Lastly, we need to register the op context class and ops in [Register.cpp](11dc158129/aten/src/ATen/native/vulkan/ops/Register.cpp), and rewrite the subgraph function of the GRU op in [vulkan_rewrite.cpp](11dc158129/torch/csrc/jit/passes/vulkan_rewrite.cpp) so that the `gru_prepack` and `gru_run` ops are executed instead on the Vulkan GPU backend.
* To avoid `"Undefined symbols for architecture x86_64"` compiler error on the x86_64 platform, `c10::Dispatcher::callBoxed()` API is used to call `vulkan_prepack::gru_prepack` and `vulkan_prepack::gru_run` by name. Otherwise, the test methods can't resolve the symbols.
* Added new tests for the GRU pre-packing and run operations: `gru_prepack_success` and `gru_prepack_invalidinputs_exceptions`
* To build your PyTorch OSS on your local machine:
```
python setup.py clean
git submodule update --init --recursive
USE_VULKAN=1 USE_VULKAN_FP16_INFERENCE=1 python3 setup.py install --cmake
python setup.py develop && python -c "import torch"
```
* To run and dump a model containing GRU operators in Python:
```
import torch
from torch.utils import mobile_optimizer
model = torch.jit.load("Mclaren_traced.pt")
vk_model = mobile_optimizer.optimize_for_mobile(model, backend="vulkan")
print(vk_model.graph)
```
* The following TorchScript is the updated version after GRU pre-packing:
```
%15 : Tensor[] = prim::ListConstruct(%weight_ih_l0.1, %weight_hh_l0.1, %bias_ih_l0.1, %bias_hh_l0.1, %weight_ih_l1.1, %weight_hh_l1.1, %bias_ih_l1.1, %bias_hh_l1.1)
%19 : __torch__.torch.classes.vulkan.GruOpContext = vulkan_prepack::gru_prepack(%15, %4, %5, %6, %3, %3, %4)
%20 : Tensor, %21 : Tensor = vulkan_prepack::gru_run(%input.1, %hx.1, %19)
%18 : (Tensor, Tensor) = prim::TupleConstruct(%21, %20)
return (%18)
```
* This implementation has some limitations:
* Tensor dim should be 3 for input sequence and hidden state.
* has_biases=True
* train=False
* bidirectional=False
* batch_first=True
* dropout=0.0
* D=1 since bidirectional=False
* N=1 (batch size)
* L=1 (sequence length)
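An eager-mode configuration that satisfies all of these constraints (two layers, matching the eight weight/bias tensors in the prepack example above):
```
import torch

gru = torch.nn.GRU(input_size=4, hidden_size=3, num_layers=2,
                   bias=True, batch_first=True,
                   bidirectional=False, dropout=0.0)
x  = torch.randn(1, 1, 4)   # 3-dim input, N=1, L=1
h0 = torch.zeros(2, 1, 3)   # (D * num_layers, N, H_out) with D=1
out, hn = gru(x, h0)
assert out.shape == (1, 1, 3) and hn.shape == (2, 1, 3)
```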
Test Plan:
Build & test on Android:
```
cd ~/fbsource
buck build -c ndk.custom_libcxx=false -c pt.enable_qpl=0 //xplat/caffe2:pt_vulkan_api_test_binAndroid\#android-arm64 --show-output
adb push buck-out/gen/xplat/caffe2/pt_vulkan_api_test_binAndroid\#android-arm64 /data/local/tmp/vulkan_api_test
adb shell "/data/local/tmp/vulkan_api_test"
```
Build & test on MacOS (x86_64):
```
cd ~/fbsource
buck build //xplat/caffe2:pt_vulkan_api_test_binAppleMac
./buck-out/gen/xplat/caffe2/pt_vulkan_api_test_binAppleMac\#macosx-x86_64
```
Test result on Android (Google Pixel 5):
```
Running main() from gtest_main.cc
[==========] Running 4 tests from 1 test case.
[----------] Global test environment set-up.
[----------] 4 tests from VulkanAPITest
[ RUN ] VulkanAPITest.gru_mclareninputs_success
[ OK ] VulkanAPITest.gru_mclareninputs_success (1037 ms)
[ RUN ] VulkanAPITest.gru_invalidinputs_exceptions
[ OK ] VulkanAPITest.gru_invalidinputs_exceptions (16 ms)
[ RUN ] VulkanAPITest.gru_prepack_success
[ OK ] VulkanAPITest.gru_prepack_success (45 ms)
[ RUN ] VulkanAPITest.gru_prepack_invalidinputs_exceptions
[ OK ] VulkanAPITest.gru_prepack_invalidinputs_exceptions (16 ms)
[----------] 4 tests from VulkanAPITest (1114 ms total)
[----------] Global test environment tear-down
[==========] 4 tests from 1 test case ran. (1114 ms total)
[ PASSED ] 4 tests.
```
Test result on MacOS (x86_64):
```
Running main() from gtest_main.cc
[==========] Running 4 tests from 1 test case.
[----------] Global test environment set-up.
[----------] 4 tests from VulkanAPITest
[ RUN ] VulkanAPITest.gru_mclareninputs_success
[ OK ] VulkanAPITest.gru_mclareninputs_success (1012 ms)
[ RUN ] VulkanAPITest.gru_invalidinputs_exceptions
[ OK ] VulkanAPITest.gru_invalidinputs_exceptions (40 ms)
[ RUN ] VulkanAPITest.gru_prepack_success
[ OK ] VulkanAPITest.gru_prepack_success (99 ms)
[ RUN ] VulkanAPITest.gru_prepack_invalidinputs_exceptions
[ OK ] VulkanAPITest.gru_prepack_invalidinputs_exceptions (39 ms)
[----------] 4 tests from VulkanAPITest (1190 ms total)
[----------] Global test environment tear-down
[==========] 4 tests from 1 test case ran. (1190 ms total)
[ PASSED ] 4 tests.
```
Reviewed By: SS-JIA
Differential Revision: D34556940
fbshipit-source-id: dce918de238fb8a4a0ea5e966e05ca99ed910c28
(cherry picked from commit cd1d95ff8d0fa7810cf18a54ba64539e46daa26a)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73872
This diff adds an equivalent target for [`aten_vulkan`](https://fburl.com/code/h9ybej5u) in FBCode as the `ATen-vulkan` target, by creating fbcode equivalents of all the xplat targets needed to build `aten_vulkan`.
The following targets in `xplat/caffe2` have had equivalent targets created in `fbcode/caffe2/aten`:
* `aten_vulkan_glsl_src_path`
* filegroup containing all Vulkan glsl files
* `gen_aten_vulkan_spv_lib`
* python library containing script to generate vulkan spv files
* `gen_aten_vulkan_spv_bin`
* python binary wrapping the above target
* `gen_aten_vulkan_spv`
* genrule to execute the above python script and create C++ headers containing the SPIR-V shader code
* `generated_aten_headers_vulkan`
* C++ library that points to the generated SPIR-V headers from above
* `aten_vulkan`
* Contains the Pytorch Vulkan backend
FBCode targets have also been added for:
* `Vulkan-Headers` which contains Vulkan API function signatures
* `vulkan_wrapper` which loads the vulkan library
* `dotslash:glslc` which wraps the glsl compiler in a target that can be executed by genrules
Test Plan:
Try building the new `ATen-vulkan` target:
```
cd fbsource/fbcode/caffe2/aten
buck build :ATen-vulkan
```
Also tested in the next diff which tries to use this target in a Python script in FBCode.
Reviewed By: beback4u
Differential Revision: D34647445
fbshipit-source-id: 7330df1e3858c88b934b06e8e75f4fdcfa88068e
(cherry picked from commit 25251bed83e97bb9ef96a5f611c6ed72ba4219fc)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73243
The previous version of the Vulkan backend is no longer being used. Delete the dead code from the codebase.
Test Plan: Make sure everything still builds.
Reviewed By: beback4u
Differential Revision: D34400045
fbshipit-source-id: ae2a61452bf9199c11d81cc0369de8a9dd6692b1
(cherry picked from commit 22ee917f05f36ed16226dbe79da7892426eb09a2)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61646
There are several passes which are written to handle both
`CallFunction("linear", ...)` and `aten::linear(...)` despite the two being
functionally identical.
This changes `FuseLinear` to also normalize the `CallFunction` variant to
`aten::linear`. That way each subsequent transformation only has to handle one
form instead of both.
Test Plan: Imported from OSS
Reviewed By: mikaylagawarecki
Differential Revision: D33754261
Pulled By: albanD
fbshipit-source-id: 42465cea790538481efc881a249dafdda4bba5d4
(cherry picked from commit ebeca9434c)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56070
**Summary**
Currently, we're returning copies instead of aliases on mobile GPU (Metal/Vulkan). As suggested by ailzhang, we could use the JIT pass `RemoveTensorMutation` to ban mutations ahead of time. I've tested two scenarios as shown below. They both work fine on mobile.
- view
```
import torch

class Model(torch.nn.Module):
    def forward(self, x):
        y = x.view(-1)
        z = torch.tensor(2.0).float()
        y.add_(z)
        return x

m = Model()
x = torch.rand(2, 3)
y = m(x)
```
- transpose
```
class Model(torch.nn.Module):
    def forward(self, x):
        y = x.transpose(1, 2)
        z = torch.tensor(2.0).float()
        x.add_(z)
        return y

m = Model()
x = torch.rand(1, 2, 3)
y = m(x)
```
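The pass can also be exercised directly from Python; a minimal sketch using the internal `torch._C._jit_pass_remove_mutation` hook (an internal API that may change between releases):
```
import torch

class Model(torch.nn.Module):  # the "view" example from above
    def forward(self, x):
        y = x.view(-1)
        z = torch.tensor(2.0).float()
        y.add_(z)
        return x

scripted = torch.jit.script(Model())
torch._C._jit_pass_remove_mutation(scripted.graph)
# in-place ops are rewritten to out-of-place form only where the pass
# can prove the program's observable behavior is unchanged
print(scripted.graph)
```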
As we're adding more ops, we should add more tests to cover all the alias ops - https://github.com/pytorch/pytorch/blob/master/tools/autograd/gen_inplace_or_view_type.py#L31-L80
**Next step**
Synced offline with eellison. Since mutation removal is also being used in ONNX, Static runtime, some jit optimizations, Torch -> TVM, etc, instead of inventing something new, we would continue to make it better in cases where it fails.
Although this JIT pass could work for most of the mobile models, there are cases that it can't cover. What we're going to do next is to implement stub ops for GPU models to let them run on server side, such that users can compare results to see if there is any discrepancy.
ghstack-source-id: 126802123
Test Plan:
- Sandcastle
- CircleCI
Reviewed By: raziel
Differential Revision: D27692683
fbshipit-source-id: 9d1be8a6c0a276032b1907807a54fbe2afd882f9