# Motivation
Add context variable `torch.bachend.mkldnn.allow_tf32` to control tf32 computation in convolution kernels at XPU side. The tf32 data type is beneficial to improve the performance of deep learning workloads during training/inference. Current PR uses the [oneDNN API fpmath_mode](https://oneapi-src.github.io/oneDNN/dev_guide_attributes_fpmath_mode.html#the-floating-point-math-mode-attribute) to trigger the tf32 acceleration in convolution kernels.
# Valiadation
* ut to test context variable
`python test/xpu/test_conv.py -k test_mkldnn_allow_tf32_get_set`
* Runtime exemplification
```
onednn_verbose,primitive,exec,gpu:0,convolution,jit:ir,forward_training,src_f32::blocked:abcd::f0 wei_f32::blocked:abcd::f0 bia_f32::blocked:a::f0 dst_f32::blocked:abcd::f0,attr-scratchpad:user attr-fpmath:tf32,alg:convolution_direct,mb20_ic16oc33_ih50oh24kh3sh2dh0ph0_iw100ow49kw3sw2dw0pw0,0.649902
onednn_verbose,primitive,exec,gpu:0,convolution,jit:ir,forward_training,src_f32::blocked:abcd::f0 wei_f32::blocked:abcd::f0 bia_f32::blocked:a::f0 dst_f32::blocked:abcd::f0,attr-scratchpad:user attr-fpmath:tf32,alg:convolution_direct,mb20_ic33oc33_ih24oh24kh3sh1dh0ph1_iw49ow49kw3sw1dw0pw1,0.151855
onednn_verbose,primitive,exec,gpu:0,convolution,jit:ir,backward_data,src_f32::blocked:abcd::f0 wei_f32::blocked:abcd::f0 bia_undef::undef::: dst_f32::blocked:abcd::f0,attr-scratchpad:user attr-fpmath:tf32,alg:convolution_direct,mb20_ic33oc33_ih24oh24kh3sh1dh0ph1_iw49ow49kw3sw1dw0pw1,0.167969
onednn_verbose,primitive,exec,gpu:0,convolution,jit:ir,backward_weights,src_f32::blocked:abcd::f0 wei_f32::blocked:abcd::f0 bia_f32::blocked:a::f0 dst_f32::blocked:abcd::f0,attr-scratchpad:user attr-fpmath:tf32,alg:convolution_direct,mb20_ic33oc33_ih24oh24kh3sh1dh0ph1_iw49ow49kw3sw1dw0pw1,0.26709
onednn_verbose,primitive,exec,gpu:0,convolution,jit:ir,backward_weights,src_f32::blocked:abcd::f0 wei_f32::blocked:abcd::f0 bia_f32::blocked:a::f0 dst_f32::blocked:abcd::f0,attr-scratchpad:user attr-fpmath:tf32,alg:convolution_direct,mb20_ic16oc33_ih50oh24kh3sh2dh0ph0_iw100ow49kw3sw2dw0pw0,0.219971
```
According to the field `fpmath:tf32` in verbose, we could see that, current context setting utils could successfully trigger tf32 computation in conv forward/backward_data/backward_weights kernels.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137570
Approved by: https://github.com/guangyey, https://github.com/EikanWang, https://github.com/atalman, https://github.com/malfet
Co-authored-by: Yu, Guangye <guangye.yu@intel.com>
Replace https://github.com/pytorch/pytorch/pull/138947 for re-import.
Replaces https://github.com/ROCm/pytorch/pull/1592
This PR contains the initial implementation of SDPA with composable_kernel backend. The CK path can be forced by simply calling torch.backends.cuda.preferred_rocm_fa_library("ck"). Similarly, you can force the incumbent aotriton implementation by passing in "aotriton" or "default". As you'd expect, not setting this option will result in aotriton to be used as the backend. In the case of CK, if pytorch deems flash attention usable, then it will use the CK path in all the same places aotriton would have been used. This PR makes no changes to the heuristics which select which attention scheme to use (i.e. flash attention vs memory efficient attention vs math etc etc). It only gets called when flash attention is both enabled (via USE_FLASH_ATTENTION) and is selected at runtime by the existing heuristics.
Files located in pytorch/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha* have been pulled from https://github.com/Dao-AILab/flash-attention courtesy of @tridao's hard work who is the co-author
NOTE: In order to use this backend, the user MUST set USE_CK_FLASH_ATTENTION=1 in their environment when they build PyTorch.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143695
Approved by: https://github.com/malfet
Co-authored-by: Andy Lugo <Andy.LugoReyes@amd.com>
Co-authored-by: Jithun Nair <jithun.nair@amd.com>
Description:
1. Quantize Linear Layer Weights to 4-bits:
Quantize the weights of the Linear layer to 4 bits, using symmetric quantization.
Pack two 4-bit weights into one uint8 container.
Choose a quantization scheme (channel-wise or group-wise), with the group size being a multiple of 32.
2. Prepare Quantized Weights, Scales, and Optional Bias:
After quantizing, obtain the quantized_weights, scales, and groupsize.
If the original Linear layer has a bias, prepare it as well.
3. Pack the Weights Efficiently:
Use torch.ops.aten._dyn_quant_pack_4bit_weight to optimally pack the weights, scales, and optional bias.
```python
packed_weights = torch.ops.aten._dyn_quant_pack_4bit_weight(weight, scales_and_zeros, bias, groupsize, in_features, out_features)
```
Input parameters should include:
in_features and out_features (the same as the Linear layer’s corresponding parameters).
4. Perform Dynamic Quantized Matrix Multiplication:
Use torch.ops.aten._dyn_quant_matmul_4bit to perform matrix multiplication with quantized weights.
```python
output = torch.ops.aten._dyn_quant_matmul_4bit(input, packed_weights, groupsize, in_features, out_features)
```
Inputs required include:
The input tensor, packed_weights , groupsize, and the in_features and out_features.
API Usage: https://github.com/pytorch/pytorch/issues/143289
Model Perf :
7B Transformer model:
Prefill : 340 t/s
Decode : 40 t/s
2B Transformer model
Prefill : 747 t/s
Decode : 80 t/s
Tests:
python test/test_linalg.py -k test__dyn_quant_pack_4bit_weight
Ran 1 test in 0.016s
OK
python test/test_linalg.py -k test__dyn_quant_matmul_4bit
Ran 8 tests in 0.077s
OK
python test/test_linalg.py -k test_compile_dyn_quant_matmul_4bit
Ran 8 tests in 11.454s
Change-Id: Ia1672bad5e6ec94e64d8bb1971395d60f4b3a452
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134124
Approved by: https://github.com/digantdesai, https://github.com/malfet
Description:
1. Quantize Linear Layer Weights to 4-bits:
Quantize the weights of the Linear layer to 4 bits, using symmetric quantization.
Pack two 4-bit weights into one uint8 container.
Choose a quantization scheme (channel-wise or group-wise), with the group size being a multiple of 32.
2. Prepare Quantized Weights, Scales, and Optional Bias:
After quantizing, obtain the quantized_weights, scales, and groupsize.
If the original Linear layer has a bias, prepare it as well.
3. Pack the Weights Efficiently:
Use torch.ops.aten._dyn_quant_pack_4bit_weight to optimally pack the weights, scales, and optional bias.
```python
packed_weights = torch.ops.aten._dyn_quant_pack_4bit_weight(weight, scales_and_zeros, bias, groupsize, in_features, out_features)
```
Input parameters should include:
in_features and out_features (the same as the Linear layer’s corresponding parameters).
4. Perform Dynamic Quantized Matrix Multiplication:
Use torch.ops.aten._dyn_quant_matmul_4bit to perform matrix multiplication with quantized weights.
```python
output = torch.ops.aten._dyn_quant_matmul_4bit(input, packed_weights, groupsize, in_features, out_features)
```
Inputs required include:
The input tensor, packed_weights , groupsize, and the in_features and out_features.
API Usage: https://github.com/pytorch/pytorch/issues/143289
Model Perf :
7B Transformer model:
Prefill : 340 t/s
Decode : 40 t/s
2B Transformer model
Prefill : 747 t/s
Decode : 80 t/s
Tests:
python test/test_linalg.py -k test__dyn_quant_pack_4bit_weight
Ran 1 test in 0.016s
OK
python test/test_linalg.py -k test__dyn_quant_matmul_4bit
Ran 8 tests in 0.077s
OK
python test/test_linalg.py -k test_compile_dyn_quant_matmul_4bit
Ran 8 tests in 11.454s
Change-Id: Ia1672bad5e6ec94e64d8bb1971395d60f4b3a452
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134124
Approved by: https://github.com/digantdesai, https://github.com/malfet
Description:
1. Quantize Linear Layer Weights to 4-bits:
Quantize the weights of the Linear layer to 4 bits, using symmetric quantization.
Pack two 4-bit weights into one uint8 container.
Choose a quantization scheme (channel-wise or group-wise), with the group size being a multiple of 32.
2. Prepare Quantized Weights, Scales, and Optional Bias:
After quantizing, obtain the quantized_weights, scales, and groupsize.
If the original Linear layer has a bias, prepare it as well.
3. Pack the Weights Efficiently:
Use torch.ops.aten._dyn_quant_pack_4bit_weight to optimally pack the weights, scales, and optional bias.
```python
packed_weights = torch.ops.aten._dyn_quant_pack_4bit_weight(weight, scales_and_zeros, bias, groupsize, in_features, out_features)
```
Input parameters should include:
in_features and out_features (the same as the Linear layer’s corresponding parameters).
4. Perform Dynamic Quantized Matrix Multiplication:
Use torch.ops.aten._dyn_quant_matmul_4bit to perform matrix multiplication with quantized weights.
```python
output = torch.ops.aten._dyn_quant_matmul_4bit(input, packed_weights, groupsize, in_features, out_features)
```
Inputs required include:
The input tensor, packed_weights , groupsize, and the in_features and out_features.
API Usage: https://github.com/pytorch/pytorch/issues/143289
Model Perf :
7B Transformer model:
Prefill : 340 t/s
Decode : 40 t/s
2B Transformer model
Prefill : 747 t/s
Decode : 80 t/s
Tests:
python test/test_linalg.py -k test__dyn_quant_pack_4bit_weight
Ran 1 test in 0.016s
OK
python test/test_linalg.py -k test__dyn_quant_matmul_4bit
Ran 8 tests in 0.077s
OK
python test/test_linalg.py -k test_compile_dyn_quant_matmul_4bit
Ran 8 tests in 11.454s
Change-Id: Ia1672bad5e6ec94e64d8bb1971395d60f4b3a452
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134124
Approved by: https://github.com/digantdesai, https://github.com/malfet
Replaces https://github.com/ROCm/pytorch/pull/1592
This PR contains the initial implementation of SDPA with composable_kernel backend. The CK path can be forced by simply calling `torch.backends.cuda.preferred_rocm_fa_library("ck")`. Similarly, you can force the incumbent aotriton implementation by passing in "aotriton" or "default". As you'd expect, not setting this option will result in aotriton to be used as the backend. In the case of CK, if pytorch deems flash attention usable, then it will use the CK path in all the same places aotriton would have been used. This PR makes no changes to the heuristics which select which attention scheme to use (i.e. flash attention vs memory efficient attention vs math etc etc). It only gets called when flash attention is both enabled (via `USE_FLASH_ATTENTION`) and is selected at runtime by the existing heuristics.
Files located in pytorch/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha* have been pulled from https://github.com/Dao-AILab/flash-attention courtesy of @tridao's hard work who is the co-author
NOTE: In order to use this backend, the user MUST set USE_CK_FLASH_ATTENTION=1 in their environment when they build PyTorch.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138947
Approved by: https://github.com/pruthvistony, https://github.com/xw285cornell, https://github.com/leitian
Co-authored-by: Xiaodong Wang <xw285@cornell.edu>
Summary:
This PR adds a lowering for `torch._cslt_sparse_mm` to find the optimal
alg_id and cache it when running with `torch.compile`
Seeing speedups on both bfloat16 and float8 dtypes:
<img width="641" alt="Screenshot 2024-10-17 at 2 10 38 PM" src="https://github.com/user-attachments/assets/b928cd11-32a3-43e5-b209-8e4028896f0b">
<img width="1274" alt="Screenshot 2024-10-17 at 1 39 03 PM" src="https://github.com/user-attachments/assets/d9edd684-a8ec-46fd-b3da-2e76dbcb7bb6">
* `torch._cslt_sparse_mm_search` has been modified to return optimal
split-k parameters as well as max alg_id.
* max_id is now available in `torch.backends.cusparselt` via
`torch.backends.cusparselt.get_max_alg_id()`
* fixed meta registrations for float8
Test Plan:
python test/test_sparse_semi_structured.py
Reviewers:
Subscribers:
Tasks:
Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137427
Approved by: https://github.com/cpuhrsch
- composable_kernel as a third_party submodule
- "ck" as a `torch.backends.cuda.preferred_linalg_library()`
- reference CK gemm implementations for float, bfloat16, and half types
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131004
Approved by: https://github.com/xw285cornell, https://github.com/pruthvistony
Co-authored-by: Andres Lugo <Andy.LugoReyes@amd.com>
Co-authored-by: Pruthvi Madugundu <pruthvigithub@gmail.com>
Summary:
This PR adds `torch.float8e4m3fn` support to cuSPARSELt and `to_sparse_semi_structured`.
This will let users to run fp8 + 2:4 sparse matmuls on Hopper GPUs with
cusparselt >= 0.6.2, via to `scaled_mm` API.
```
A = rand_sparse_semi_structured_mask(256, 128, dtype=torch.float16)
B = torch.rand(dense_input_shape, device=device).to(torch.float16).t()
A_fp8, A_scale = to_float8(A)
B_fp8, B_scale = to_float8(B)
dense_result = torch._scaled_mm(
A_fp8, B_fp8,
scale_a=A_scale, scale_b=B_scale,
out_dtype=out_dtype
)
A_fp8_sparse = to_sparse_semi_structured(A_fp8)
sparse_result = torch._scaled_mm(
A_fp8_sparse, B_fp8,
scale_a=A_scale, scale_b=B_scale,
out_dtype=out_dtype
)
```
Note that to keep this consistent with normal torch behavior, calling
`torch.mm(A_fp8_sparse, B_fp8)` will raise a NotImplementedError.
I also turned on cuSPARSELt by default and added CUSPARSELT_MAX_ID to the
backend to make the tests a bit cleaner
Test Plan:
```
python test/test_sparse_semi_structured -k scaled_mm
python test/test_sparse_semi_structured -k fp8
```
Reviewers:
Subscribers:
Tasks:
Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136397
Approved by: https://github.com/drisspg
**Motivation:**
In Pytorch, Aten vectorization supports multiple platforms, including x86 and Arm, as well as multiple data types. It provides a generic implementation of Vector (Vec) type that allows the programmer to write code packing various primitives (such as floats) within 256bit & 512bits registers. It can be extended to support other ISAs easily by adding more VecISA sub-classes.
**Reference Link:** https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cpu/vec
**This PR:**
* Our goal with this contribution is to add support for SVE backend for Vec in the Aten vectorization for CPU backend which can be benefitted by any ARM architecture supported CPU's that supports SVE.
* More about SVE ISA for ARM: [https://developer.arm.com/Architectures/Scalable Vector Extensions](https://developer.arm.com/Architectures/Scalable%20Vector%20Extensions)
* We are using the ARM C Language Extensions for SVE (https://developer.arm.com/documentation/102699/0100/Optimizing-with-intrinsics ) to accelerate performance for various operators in the SVE backend for Vec.
* Currently we are adding support only for SVE ISA with the vector length of 256 bits (SVE 256). In future, we plan to extend this SVE support for other vector lengths as well.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119571
Approved by: https://github.com/malfet, https://github.com/snadampal
Co-authored-by: Divya Kotadiya <divya.kotadiya@fujitsu.com>
Summary:
This PR adds in cuSPARSELt as a backend to PyTorch.
It is now possible to see if cuSPARSELt is available and the version if
it is with
```
torch.backends.cusparselt.is_available()
torch.backends.cusparselt.version()
```
Test Plan:
```
python test/test_sparse_semi_structured.py -k test_cusparselt_backend
```
Reviewers:
Subscribers:
Tasks:
Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128534
Approved by: https://github.com/cpuhrsch, https://github.com/eqy, https://github.com/syed-ahmed
Looks like one of the first failures seen is `test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` when `test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` passes.
What seems interesting here is that the `torch.compile` version fails while the eager version passes. Not sure what the difference would be here...
Nevertheless, is there a recommended mechanism to skip cuDNN SDPA as a backend for this test? CC @drisspg
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125343
Approved by: https://github.com/Skylion007
Looks like one of the first failures seen is `test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` when `test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` passes.
What seems interesting here is that the `torch.compile` version fails while the eager version passes. Not sure what the difference would be here...
Nevertheless, is there a recommended mechanism to skip cuDNN SDPA as a backend for this test? CC @drisspg
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125343
Approved by: https://github.com/Skylion007
Looks like one of the first failures seen is `test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` when `test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` passes.
What seems interesting here is that the `torch.compile` version fails while the eager version passes. Not sure what the difference would be here...
Nevertheless, is there a recommended mechanism to skip cuDNN SDPA as a backend for this test? CC @drisspg
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125343
Approved by: https://github.com/Skylion007