Commit Graph

125 Commits

Author SHA1 Message Date
PyTorch MergeBot
af8a50e656 Revert "Fix allowed dtypes for mem_eff attention (#116026)"
This reverts commit fc58909bab.

Reverted https://github.com/pytorch/pytorch/pull/116026 on behalf of https://github.com/jeanschmidt due to breaking internal windows buck builds, check internal diff for more details ([comment](https://github.com/pytorch/pytorch/pull/116026#issuecomment-1864354665))
2023-12-20 12:01:34 +00:00
drisspg
fc58909bab Fix allowed dtypes for mem_eff attention (#116026)
# Summary

Fix issue bug in detecting mem eff capability for cuda devices less than sm80:
https://github.com/pytorch-labs/gpt-fast/issues/49

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116026
Approved by: https://github.com/janeyx99
2023-12-18 23:20:52 +00:00
Jeff Daily
e3aefe2970 Revert "Initial Flash Attention support on ROCM (#114309)" (#115975)
This reverts commit 5bddbed399.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115975
Approved by: https://github.com/atalman, https://github.com/malfet
2023-12-16 03:40:14 +00:00
Xinya Zhang
5bddbed399
Initial Flash Attention support on ROCM (#114309)
This pull requests add initial Flash Attention support for AMD/ROCM platform. It added a specialized Triton repository/branch as a compile-time dependency for Flash Attention math library on AMD/ROCM. This triton submodule is not used at runtime and will not be shipped to the final pytorch package. We have the plan to release this specialized Triton as a separate project.

Know limitations:

- [ ] Only supports MI200 series GPU (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`.
- [ ] Only supports power of two sequence lengths.
- [ ] No support for varlen APIs.
- [ ] Only support head dimension 16,32,64,128.
- [ ] Performance is still being optimized.

Fixes https://github.com/pytorch/pytorch/issues/112997

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114309

Approved by: https://github.com/jeffdaily, https://github.com/malfet

---------

Co-authored-by: Joseph Groenenboom <joseph.groenenboom@amd.com>
2023-12-14 08:52:57 -08:00
Fuzzkatt
661c1cf2aa numerical mismatch fix for test_mem_efficient_attention_attn_mask_vs_math_ref_grads in test_transformers.py (#115707)
adjust dropout_fudge_factor since previous fudge factor was too small and led to numerical mismatch in NVIDIA internal CI
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115707
Approved by: https://github.com/drisspg
2023-12-14 01:04:39 +00:00
Valentine233
064846dbc2 [cpu] flash attention optimization (#115151)
### Modifications
- **EXP**: Add a fast version with a reduced accuracy (ULP20) to vec exp `exp_u20` and use it in flash attention.
- **FUSION**: Do fusion for `softmax` ops.
- **SCALE**: Move the calculation of `scaling_factor` after `gemm`.

### Performance
_Model: Stable Diffusion V2.1_

| Version | BF16 Kernel latency (s) | BF16 speedup | FP32 Kernel latency (s) | FP32 speedup |
| ----- | ----- | ----- | ----- | ----- |
| PT | 15.865 |  | 35.362 |  |
| PT + EXP | 12.518 | 21.10% | 19.327 | 45.35% |
| PT + EXP + FUSION | 11.774 | 25.79% | 18.306 | 48.23% |
| PT + EXP + FUSION + SCALE | 11.053 | 30.33% | 18.360 | 48.08% |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115151
Approved by: https://github.com/jgong5, https://github.com/drisspg
2023-12-12 01:09:55 +00:00
drisspg
d4c79a3078 Add an attention bias subclass for a lower right causal masking (#114823)
# Summary
This PR introduces a new Tensor subclass that is designed to be used with torch.nn.functional.scaled_dot_product_attention. Currently we have a boolean `is_causal` flag that allows users to do do causal masking without the need to actually create the "realized" attention bias and pass into sdpa. We originally added this flag since there is native support in both fused kernels we support. This provides a big performance gain ( the kernels only need to iterate over ~0.5x the sequence, and for very large sequence lengths this can provide vary large memory improvements.

The flag was introduced when the early on in the kernel development and at the time it was implicitly meant to "upper_left" causal attention. This distinction only matters when the attention_bias is not square. For a more detailed break down see: https://github.com/pytorch/pytorch/issues/108108. The kernels default behavior has since changed, largely due to the rise of autogressive text generation. And unfortunately this would lead to a BC break. In the long term it may actually be beneficial to change the default meaning of `is_causal` to represent lower_right causal masking.

The larger theme though is laid here: https://github.com/pytorch/pytorch/issues/110681. The thesis being that there is alot of innovation in SDPA revolving around the attention_bias being used. This is the first in hopefully a few more attention_biases that we would like to add. The next interesting one would be `sliding_window` which is used by the popular mistral model family.

Results from benchmarking, I improved the meff_attention perf hence the slightly decreased max perf.
```Shell
+---------+--------------------+------------+-----------+-----------+-----------+-----------+----------------+----------+
|  Type   |      Speedup       | batch_size | num_heads | q_seq_len | k_seq_len | embed_dim |     dtype      | head_dim |
+---------+--------------------+------------+-----------+-----------+-----------+-----------+----------------+----------+
| Average | 1.2388050062214226 |            |           |           |           |           |                |          |
|   Max   | 1.831672915579016  |    128     |    32     |   1024    |   2048    |   2048    | torch.bfloat16 |    64    |
|   Min   | 0.9430534166730135 |     1      |    16     |    256    |    416    |   2048    | torch.bfloat16 |   128    |
+---------+--------------------+------------+-----------+-----------+-----------+-----------+----------------+----------+
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114823
Approved by: https://github.com/cpuhrsch
2023-12-06 08:29:26 +00:00
drisspg
8556a09d44 Require less alignment for attn bias (#114173)
# Summary
Improved Fix for Attention Mask Alignment Issue (#112577)

This PR addresses Issue #112577 by refining the previously implemented fix, which was found to be incorrect and causes un-needed memory regressions. The update simplifies the approach to handling the alignment of the attention mask for mem eff attention.

## Changes
Alignment Check and Padding: Initially, the alignment of the attention mask is checked. If misalignment is detected, padding is applied, followed by slicing. During this process, a warning is raised to alert users.

Should this be warn_once?

We only call expand, once on the aligned mask.

Reference
https://github.com/facebookresearch/xformers/blob/main/xformers/ops/fmha/cutlass.py#L115

@albanD, @mruberry, @jbschlosser, @walterddr, and @mikaylagawarecki.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114173
Approved by: https://github.com/danthe3rd
2023-11-28 02:40:41 +00:00
PyTorch MergeBot
88a8a0daa4 Revert "Require less alignment for masking (#114173)"
This reverts commit f882c175d8.

Reverted https://github.com/pytorch/pytorch/pull/114173 on behalf of https://github.com/huydhn due to Sorry for reverting you change, but it is failing some inductor tests f882c175d8 ([comment](https://github.com/pytorch/pytorch/pull/114173#issuecomment-1823552362))
2023-11-22 21:49:31 +00:00
drisspg
f882c175d8 Require less alignment for masking (#114173)
# Summary
Improved Fix for Attention Mask Alignment Issue (#112577)

This PR addresses Issue #112577 by refining the previously implemented fix, which was found to be incorrect and causes un-needed memory regressions. The update simplifies the approach to handling the alignment of the attention mask for mem eff attention.

## Changes
Alignment Check and Padding: Initially, the alignment of the attention mask is checked. If misalignment is detected, padding is applied, followed by slicing. During this process, a warning is raised to alert users.

Should this be warn_once?

We only call expand, once on the aligned mask.

Reference
https://github.com/facebookresearch/xformers/blob/main/xformers/ops/fmha/cutlass.py#L115

@albanD, @mruberry, @jbschlosser, @walterddr, and @mikaylagawarecki.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114173
Approved by: https://github.com/danthe3rd
2023-11-22 20:02:51 +00:00
drisspg
9b0f2f8d94 expose sdpa helpers to python (#110496)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110496
Approved by: https://github.com/jbschlosser
2023-11-15 07:34:34 +00:00
drisspg
14811d69d7 [BE] Cleanup sdpa test helper usage (#113294)
# Summary

standardizes usage of the rand_sdpa_tensor helper

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113294
Approved by: https://github.com/soulitzer
2023-11-09 01:16:53 +00:00
drisspg
e509b162ed Disable FlashAttenion for is_causal=True when seqlen q not equal kv (#111007)
# Summary:
This pull request **removes** support for non-square sequence lengths in causal attention when using FlashAttention V2.

### Why are doing this
  // FlashAttention 2 updated the default mask meaning for causal in this PR:
  // 9e5e8bc91e it is now aligned to lower_right which would be a BC break
  // for non-square masks. We will not support non-square masks for causal w/ FAV2

 For more context see:
 https://github.com/pytorch/pytorch/issues/108108

 ### Followup
 A large number of people will likely want to use FAV2 with lower_right causal attention for non equal sequence lengths. See this RFC : https://github.com/pytorch/pytorch/issues/110681

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111007
Approved by: https://github.com/cpuhrsch
2023-10-23 20:33:37 +00:00
drisspg
5183760ca5 Adding Backward Support for NestedTensors and FlashAttention (#97485)
# Summary
<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at 318764f</samp>

This pull request implements the CUDA backend of the SDPA kernel for nested tensors, which enables efficient transformer models with variable-length sequences. It adds a new dispatch key, a backward function, a unit test, and some helper functions for the kernel. It modifies `test/test_transformers.py`, `aten/src/ATen/native/native_functions.yaml`, `aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctionsBackward.cpp`, and `aten/src/ATen/native/nested/cuda/NestedTensorTransformerUtils.h`.

<!--
copilot:poem
-->
### <samp>🤖 Generated by Copilot at ed4a773</samp>

> _Fused kernels of doom, unleash the flash attention_
> _Nested tensors on fire, reshape and pad with caution_
> _Backward pass of power, dispatch the CUDA key_
> _Test the gradients of hell, warn the user if they disagree_

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97485
Approved by: https://github.com/jbschlosser
2023-10-10 18:08:17 +00:00
Fuzzkatt
c28bb46445 Fix test_mem_efficient_attention_vs_math_ref_grads tolerance from test_transformers.py (#108094)
Tolerance currently too low, triggering test failures via numerical mismatch in NVIDIA internal testing for certain H100, A16, A40 configs. cc: @ptrblck @eqy

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108094
Approved by: https://github.com/eqy, https://github.com/msaroufim
2023-10-02 20:42:57 +00:00
PyTorch MergeBot
8d6479725a Revert "Adding Backward Support for NestedTensors and FlashAttention (#97485)"
This reverts commit 28d69d5256.

Reverted https://github.com/pytorch/pytorch/pull/97485 on behalf of https://github.com/huydhn due to Sorry for reverting you change, but one of the tests test_fused_kernels_nested_broadcasting_requires_grad_failure_cuda is failing on Windows CUDA f7ba3e85e2 ([comment](https://github.com/pytorch/pytorch/pull/97485#issuecomment-1743474468))
2023-10-02 17:48:57 +00:00
drisspg
28d69d5256 Adding Backward Support for NestedTensors and FlashAttention (#97485)
# Summary
<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at 318764f</samp>

This pull request implements the CUDA backend of the SDPA kernel for nested tensors, which enables efficient transformer models with variable-length sequences. It adds a new dispatch key, a backward function, a unit test, and some helper functions for the kernel. It modifies `test/test_transformers.py`, `aten/src/ATen/native/native_functions.yaml`, `aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctionsBackward.cpp`, and `aten/src/ATen/native/nested/cuda/NestedTensorTransformerUtils.h`.

<!--
copilot:poem
-->
### <samp>🤖 Generated by Copilot at ed4a773</samp>

> _Fused kernels of doom, unleash the flash attention_
> _Nested tensors on fire, reshape and pad with caution_
> _Backward pass of power, dispatch the CUDA key_
> _Test the gradients of hell, warn the user if they disagree_

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97485
Approved by: https://github.com/jbschlosser
2023-09-29 21:34:47 +00:00
drisspg
ad90ab31f2 Flash Attention v2 (#105602)
# Summary
## PR Dependencies
I don't use ghstack :( this is a PR where it would have been helpful. That beings said I am going to peel off some PRs to make reviewing this easier:
- [x] Separate build flags for Flash and MemEff: #107985

### Description
This pull request updates the version of _scaled_dot_product_flash_attention from version 1 to version 2. The changes are based on the flash attention code originally authored by @tridao

### Changes Made
The majority of the changes in this pull request involve:

- Copying over the flash_attention sources.
- Updating header files.
- Removing padding and slicing code from within the flash_attention kernel and relocating it to the composite implicit region of the SDPA. This was need to make the kernel functional and appease autograd.
- Introducing a simple kernel generator to generate different instantiations of the forward and backward flash templates.
- Adding conditional compilation (ifdef) to prevent building when nvcc is invoked with gencode < sm80.
- Introducing a separate dependent option for mem_eff_attention, as flash_attention v2 lacks support for Windows and cannot be built for sm50 generation codes.
- Modifying build.sh to reduce parallelization on sm86 runners and to lower the maximum parallelization on the manywheel builds. This adjustment was made to address out-of-memory issues during the compilation of FlashAttentionV2 sources.
- Adding/Updating tests.

### Notes for Reviewers
This is not a fun review, and I apologize in advance.
Most of the files-changed are in the flash_attn/ folder. The only files of interest here IMO:
- aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp
- aten/src/ATen/native/transformers/cuda/flash_attn/kernels/generate_kernels.py ( this has been incorporated upstream to flash-attention github)

There are a number of files all related to avoiding OOMs in CI/CD. These are typically shell scripts.

### Follow up items
- Include the updates from e07aa036db and 9e5e8bc91e | https://github.com/pytorch/pytorch/issues/108108

### Work Items
- [x] I don't think Windows will be supported for 3.1.0 - Need to update cmakee
- [x] Let multi_query/attention pass through and test | UPDATE: I have the fast path implemented here: https://github.com/pytorch/pytorch/pull/106730 but since this will require changes to semantics of math to call repeat_interleave, I think this should be done as a followup.
- [x] Had to drop cutlass back to 3.0.0 to get it to compile. Need to figure out how to upgrade to 3.1.0 and later. Spoke with Tri and he is going to be taking a look. Note: compiling with clang currently errors for the cute headers.
- [x] Update test exercise above codepath
- [x] Still need to disable on seq_len % 128 != 0 for backward( Tri beat me to it a4f148b6ab)
- [x] Add determinism warning to BWD, Tri got to this one as well: 1c41d2b
- [x] Update dispatcher to universally prefer FlashV2
- [x] Update tests to exercise new head_dims
- [x] Move the head_dim padding from kernel to top level composite implicit function in order to make it purely functional
- [x] Create template generator script
- [x] Initial cmake support for building kernels/ folder
- [x] Replay CudaGraph changes

### Results
#### Forward only
The TFlops are reported here are on a100 that is underclocked.
![flashv2_tflops_vs_seq_len](https://github.com/pytorch/pytorch/assets/32754868/152de46d-8fa6-42f0-9a9c-ef1eb7ae29e7)

#### Forward+Backward
Ran a sweep and for large compute bound sizes we do see a ~2x performance increase for forw+back.
<img width="1684" alt="Screenshot 2023-07-20 at 3 47 47 PM" src="https://github.com/pytorch/pytorch/assets/32754868/fdd26e07-0077-4878-a417-f3a418b6fb3b">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105602
Approved by: https://github.com/huydhn, https://github.com/cpuhrsch
2023-09-13 13:59:05 +00:00
Huy Do
a9c663c269 Revert "Flash Attention v2 (#105602)" (#108827)
This reverts commit add45aea1c.

There are some conflicts on some benchmark csv file https://github.com/pytorch/pytorch/pull/105602#issuecomment-1710988951 so I need to revert this manually.

The diff has been reverted internally.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108827
Approved by: https://github.com/kit1980
2023-09-08 07:43:04 +00:00
PyTorch MergeBot
e45b290127 Revert "Revert "Flash Attention v2 (#105602)" (#108827)"
This reverts commit 24e9bbe22a.

Reverted https://github.com/pytorch/pytorch/pull/108827 on behalf of https://github.com/huydhn due to I need to land this revert properly as there are new failures showing up on trunk ([comment](https://github.com/pytorch/pytorch/pull/108827#issuecomment-1711020924))
2023-09-08 03:25:45 +00:00
Huy Do
24e9bbe22a Revert "Flash Attention v2 (#105602)" (#108827)
This reverts commit add45aea1c.

There are some conflicts on some benchmark csv file https://github.com/pytorch/pytorch/pull/105602#issuecomment-1710988951 so I need to revert this manually.

The diff has been reverted internally.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108827
Approved by: https://github.com/kit1980
2023-09-08 02:54:20 +00:00
Michael Gschwind
2a40fe2dbf [experimental] use EXCEPT_FOR env to suppress CPU tests from GPU RE (#108672)
Summary:
[experimental] use EXCEPT_FOR env to suppress CPU tests from GPU RE -- alternative implementation to D48997976 using preexisting PYTORCH_TESTING_DEVICE_EXCEPT_FOR facility and building remaining logic (for assert-positive listers like test_transformers)  on top of that.

Goal: save ~100 GPU (10% of capacity), enables us to fund more aggressive PyPer unit testing on GPU RE

Test Plan: sandcastle, github

Differential Revision: D48998582

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108672
Approved by: https://github.com/bertmaher
2023-09-06 23:33:18 +00:00
drisspg
add45aea1c Flash Attention v2 (#105602)
# Summary
## PR Dependencies
I don't use ghstack :( this is a PR where it would have been helpful. That beings said I am going to peel off some PRs to make reviewing this easier:
- [x] Separate build flags for Flash and MemEff: #107985

### Description
This pull request updates the version of _scaled_dot_product_flash_attention from version 1 to version 2. The changes are based on the flash attention code originally authored by @tridao

### Changes Made
The majority of the changes in this pull request involve:

- Copying over the flash_attention sources.
- Updating header files.
- Removing padding and slicing code from within the flash_attention kernel and relocating it to the composite implicit region of the SDPA. This was need to make the kernel functional and appease autograd.
- Introducing a simple kernel generator to generate different instantiations of the forward and backward flash templates.
- Adding conditional compilation (ifdef) to prevent building when nvcc is invoked with gencode < sm80.
- Introducing a separate dependent option for mem_eff_attention, as flash_attention v2 lacks support for Windows and cannot be built for sm50 generation codes.
- Modifying build.sh to reduce parallelization on sm86 runners and to lower the maximum parallelization on the manywheel builds. This adjustment was made to address out-of-memory issues during the compilation of FlashAttentionV2 sources.
- Adding/Updating tests.

### Notes for Reviewers
This is not a fun review, and I apologize in advance.
Most of the files-changed are in the flash_attn/ folder. The only files of interest here IMO:
- aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp
- aten/src/ATen/native/transformers/cuda/flash_attn/kernels/generate_kernels.py ( this has been incorporated upstream to flash-attention github)

There are a number of files all related to avoiding OOMs in CI/CD. These are typically shell scripts.

### Follow up items
- Include the updates from e07aa036db and 9e5e8bc91e | https://github.com/pytorch/pytorch/issues/108108

### Work Items
- [x] I don't think Windows will be supported for 3.1.0 - Need to update cmakee
- [x] Let multi_query/attention pass through and test | UPDATE: I have the fast path implemented here: https://github.com/pytorch/pytorch/pull/106730 but since this will require changes to semantics of math to call repeat_interleave, I think this should be done as a followup.
- [x] Had to drop cutlass back to 3.0.0 to get it to compile. Need to figure out how to upgrade to 3.1.0 and later. Spoke with Tri and he is going to be taking a look. Note: compiling with clang currently errors for the cute headers.
- [x] Update test exercise above codepath
- [x] Still need to disable on seq_len % 128 != 0 for backward( Tri beat me to it a4f148b6ab)
- [x] Add determinism warning to BWD, Tri got to this one as well: 1c41d2b
- [x] Update dispatcher to universally prefer FlashV2
- [x] Update tests to exercise new head_dims
- [x] Move the head_dim padding from kernel to top level composite implicit function in order to make it purely functional
- [x] Create template generator script
- [x] Initial cmake support for building kernels/ folder
- [x] Replay CudaGraph changes

### Results
#### Forward only
The TFlops are reported here are on a100 that is underclocked.
![flashv2_tflops_vs_seq_len](https://github.com/pytorch/pytorch/assets/32754868/152de46d-8fa6-42f0-9a9c-ef1eb7ae29e7)

#### Forward+Backward
Ran a sweep and for large compute bound sizes we do see a ~2x performance increase for forw+back.
<img width="1684" alt="Screenshot 2023-07-20 at 3 47 47 PM" src="https://github.com/pytorch/pytorch/assets/32754868/fdd26e07-0077-4878-a417-f3a418b6fb3b">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105602
Approved by: https://github.com/huydhn, https://github.com/cpuhrsch
2023-09-01 22:14:44 +00:00
PyTorch MergeBot
d569e506ab Revert "Flash Attention v2 (#105602)"
This reverts commit 9df3d882c8.

Reverted https://github.com/pytorch/pytorch/pull/105602 on behalf of https://github.com/huydhn due to I think we miss a case here for sm80 build on inductor workflow as it is now OOM on trunk https://github.com/pytorch/pytorch/actions/runs/6042843139 ([comment](https://github.com/pytorch/pytorch/pull/105602#issuecomment-1701974862))
2023-09-01 01:15:01 +00:00
drisspg
9df3d882c8 Flash Attention v2 (#105602)
# Summary
## PR Dependencies
I don't use ghstack :( this is a PR where it would have been helpful. That beings said I am going to peel off some PRs to make reviewing this easier:
- [x] Separate build flags for Flash and MemEff: #107985

### Description
This pull request updates the version of _scaled_dot_product_flash_attention from version 1 to version 2. The changes are based on the flash attention code originally authored by @tridao

### Changes Made
The majority of the changes in this pull request involve:

- Copying over the flash_attention sources.
- Updating header files.
- Removing padding and slicing code from within the flash_attention kernel and relocating it to the composite implicit region of the SDPA. This was need to make the kernel functional and appease autograd.
- Introducing a simple kernel generator to generate different instantiations of the forward and backward flash templates.
- Adding conditional compilation (ifdef) to prevent building when nvcc is invoked with gencode < sm80.
- Introducing a separate dependent option for mem_eff_attention, as flash_attention v2 lacks support for Windows and cannot be built for sm50 generation codes.
- Modifying build.sh to reduce parallelization on sm86 runners and to lower the maximum parallelization on the manywheel builds. This adjustment was made to address out-of-memory issues during the compilation of FlashAttentionV2 sources.
- Adding/Updating tests.

### Notes for Reviewers
This is not a fun review, and I apologize in advance.
Most of the files-changed are in the flash_attn/ folder. The only files of interest here IMO:
- aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp
- aten/src/ATen/native/transformers/cuda/flash_attn/kernels/generate_kernels.py ( this has been incorporated upstream to flash-attention github)

There are a number of files all related to avoiding OOMs in CI/CD. These are typically shell scripts.

### Follow up items
- Include the updates from e07aa036db and 9e5e8bc91e | https://github.com/pytorch/pytorch/issues/108108

### Work Items
- [x] I don't think Windows will be supported for 3.1.0 - Need to update cmakee
- [x] Let multi_query/attention pass through and test | UPDATE: I have the fast path implemented here: https://github.com/pytorch/pytorch/pull/106730 but since this will require changes to semantics of math to call repeat_interleave, I think this should be done as a followup.
- [x] Had to drop cutlass back to 3.0.0 to get it to compile. Need to figure out how to upgrade to 3.1.0 and later. Spoke with Tri and he is going to be taking a look. Note: compiling with clang currently errors for the cute headers.
- [x] Update test exercise above codepath
- [x] Still need to disable on seq_len % 128 != 0 for backward( Tri beat me to it a4f148b6ab)
- [x] Add determinism warning to BWD, Tri got to this one as well: 1c41d2b
- [x] Update dispatcher to universally prefer FlashV2
- [x] Update tests to exercise new head_dims
- [x] Move the head_dim padding from kernel to top level composite implicit function in order to make it purely functional
- [x] Create template generator script
- [x] Initial cmake support for building kernels/ folder
- [x] Replay CudaGraph changes

### Results
#### Forward only
The TFlops are reported here are on a100 that is underclocked.
![flashv2_tflops_vs_seq_len](https://github.com/pytorch/pytorch/assets/32754868/152de46d-8fa6-42f0-9a9c-ef1eb7ae29e7)

#### Forward+Backward
Ran a sweep and for large compute bound sizes we do see a ~2x performance increase for forw+back.
<img width="1684" alt="Screenshot 2023-07-20 at 3 47 47 PM" src="https://github.com/pytorch/pytorch/assets/32754868/fdd26e07-0077-4878-a417-f3a418b6fb3b">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105602
Approved by: https://github.com/huydhn, https://github.com/cpuhrsch
2023-08-31 16:02:20 +00:00
drisspg
42d60d012e Bias overflow fix mem eff bias (#107968)
Fixes #107959
This should have been fixed here https://github.com/pytorch/pytorch/pull/103201
Edit:
Looking at git blame it appears the dropout revet squashed the changes from this PR
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107968
Approved by: https://github.com/cpuhrsch
2023-08-26 00:00:49 +00:00
Mikayla Gawarecki
48b1208e05 Disable nn.MHA fastpath for floating point masks (#107641)
Fixes https://github.com/pytorch/pytorch/issues/107084 by disabling the fast path when floating point masks (which should be additive) are passed

- [We claim in our docs for MHA that float masks will be added to the attention](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html) (be it `key_padding_mask` or `attn_mask`)
- We always canonicalize any mask at the start of MHA in python by converting it to float
- my understanding from Driss is that SDPA properly supports additive masking (but there are many special cases for mask shape for MHA that don't work properly currently (BxT, TxT) so [we're turning this off for now](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/cuda/attention.cu#L531-L532)
- More broadly, the problem isn't with the SDPA path, but that things are broken for the path it falls back to
-  Right now mha "fast path" code with non-None masks is always going through [this path ](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/cuda/attention.cu#L554-L640) that  has a call to `masked_softmax` that [converts the masks back to bool](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/attention.cpp#L154-L156)
- the implication here is that **additive floating point attn_mask and additive key_padding_mask to nn.MHA fastpath are broken**
- This wasn't broken for the user in [https://github.com/pytorch/pytorch/issues/107084](https://l.workplace.com/l.php?u=https%3A%2F%2Fgithub.com%2Fpytorch%2Fpytorch%2Fissues%2F107084&h=AT35qHIQavtxKtriTkrkPsWRB3eSRh4qH5PQUyiTzrPTshoztPL0593AmKCmSdEQ5O-5wib0Fd4mwztVu4YbMWb2ghZnZw1pvpJb9-FYWjDsPQ6_oHRVPzFfj8xYXC1TaFnJCkMYjrGXkIfzzxZvmcQYNnIPgsJSiWgjIw) in 1.13.1 because of [this check which bypassed the fast path if attn_mask was defined](https://github.com/pytorch/pytorch/blob/v1.13.1/torch/nn/modules/activation.py#L1096-L1097) (as Driss pointed out though additive key_padding_mask with the fast path were probably  broken in 1.13.1)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107641
Approved by: https://github.com/drisspg, https://github.com/jbschlosser
2023-08-23 15:08:18 +00:00
Aaron Gokaslan
660e8060ad [BE]: Update ruff to 0.285 (#107519)
This updates ruff to 0.285 which is faster, better, and have fixes a bunch of false negatives with regards to fstrings.

I also enabled RUF017 which looks for accidental quadratic list summation. Luckily, seems like there are no instances of it in our codebase, so enabling it so that it stays like that. :)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107519
Approved by: https://github.com/ezyang
2023-08-22 23:16:38 +00:00
PyTorch MergeBot
d59a6864fb Revert "[BE]: Update ruff to 0.285 (#107519)"
This reverts commit 88ab3e4322.

Reverted https://github.com/pytorch/pytorch/pull/107519 on behalf of https://github.com/ZainRizvi due to Sorry, but this PR breaks internal tests. @ezyang, can you please hep them get unblocked? It seems like one of the strings was prob accidentally modified ([comment](https://github.com/pytorch/pytorch/pull/107519#issuecomment-1688833480))
2023-08-22 19:53:32 +00:00
Liao, Xuan
71632d4d24 [cpu] add sdpa choice and UT (#105131)
Feature RFC: https://github.com/pytorch/rfcs/pull/56.

Write an SDPA selecting function for CPU to automatically choose one SDPA implementation among several ones. There are two CPU implementations which could be chosen: the unfused SDPA and flash attention. In general, flash attention has a higher priority than the unfused SDPA. For cases where flash attention is not applicable, such as manually disabling flash attention or the inputs not 4 dimensional, the unfused SDPA is chosen.

## Performance of the stack

### NanoGPT's SDPA kernel
Using benchmark [repo](https://github.com/mingfeima/bench_sdpa/blob/main/README.md), with one socket.
Shape: Batch size 1, Sequence length 1024, Head number 25, Head size 64.
Machine: SPR.

| Dtype    | Causal   | Mode      | SDPA            | Time (ms per iter) | Speedup |
| -------- | -------- | -------   | -------         | -------            | ------- |
| float32  | FALSE    | Inference | Unfused         | 3.081              |         |
|          |          |           | Flash attention | 1.665              | **1.85045** |
| float32  | TRUE     | Inference | Unfused         | 3.463              |         |
|          |          |           | Flash attention | 1.662              | **2.083634**|
| bfloat16 | FALSE    | Inference | Unfused         | 1.203              |         |
|          |          |           | Flash attention | 1.154              | **1.042461**|
| bfloat16 | TRUE     | Inference | Unfused         | 1.543              |         |
|          |          |           | Flash attention | 1.154              | **1.337088**|
| float32  | FALSE    | Training  | Unfused         | 54.938             |         |
|          |          |           | Flash attention | 23.029             | **2.385601**|
| float32  | TRUE     | Training  | Unfused         | 58.266             |         |
|          |          |           | Flash attention | 17.835             | **3.266947**|
| bfloat16 | FALSE    | Training  | Unfused         | 18.924             |         |
|          |          |           | Flash attention | 18.886             | **1.002012**|
| bfloat16 | TRUE     | Training  | Unfused         | 21.08              |         |
|          |          |           | Flash attention | 14.172             | **1.48744** |

### Stable Diffusion
Following model's [BKM](https://github.com/intel-innersource/frameworks.ai.models.intel-models/blob/develop/quickstart/diffusion/pytorch/stable_diffusion/inference/cpu/README.md).
Mode: Inference; Machine: SPR.

| Dtype    | SDPA                    | Throughput (fps) | Speedup SDPA | Total Time (ms) | Speedup |
| -------- | --------                | -------          | -------      | -------         | ------- |
| float32  | Unfused                 | 1.63             |              | 1139            |         |
|          | Flash attention         | 1.983            | 1.216564     | 547.488         | **2.080411**|
| bfloat16 | Flash attention in IPEX | 4.784            |              | 429.051         |         |
|          | Flash attention         | 4.857            | 1.015259     | 408.823         | **1.049479**|

### LLM models of Torchbench

Dtype: float32; Mode: Inference, single socket; Machine: CPX.
Model   name | SDPA | Inductor_new | Inductor_old | Inductor   Ratio(old/new)
-- | -- | -- | -- | --
hf_Albert | Unfused -> Flash attention | 0.048629309 | 0.05591545 | **1.14983024**
hf_Bert | Unfused -> Flash attention | 0.053156243 | 0.060732115 | **1.142520841**
hf_Bert_large | Unfused -> Flash attention | 0.141089502 | 0.155190077 | **1.099940636**
llama | Unfused -> Flash attention | 0.033250106 | 0.033720745 | **1.01415451**

Dtype: bfloat16; Mode: Inference, single socket; Machine: SPR.
Model   name | SDPA | Inductor_new | Inductor_old | Inductor   Ratio(old/new)
-- | -- | -- | -- | --
hf_Albert | Unfused -> Flash attention | 0.020681298 | 0.020718282 | **1.001788324**
hf_Bert | Unfused -> Flash attention | 0.019932816 | 0.019935424 | **1.000130842**
hf_Bert_large | Unfused -> Flash attention | 0.047949174 | 0.048312502 | **1.007577355**
llama | Unfused -> Flash attention | 0.018528057 | 0.01861126 | **1.0044907**

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105131
Approved by: https://github.com/drisspg
ghstack dependencies: #104583, #104584, #103826, #104693, #104863, #107128
2023-08-20 08:56:21 +00:00
Aaron Gokaslan
88ab3e4322 [BE]: Update ruff to 0.285 (#107519)
This updates ruff to 0.285 which is faster, better, and have fixes a bunch of false negatives with regards to fstrings.

I also enabled RUF017 which looks for accidental quadratic list summation. Luckily, seems like there are no instances of it in our codebase, so enabling it so that it stays like that. :)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107519
Approved by: https://github.com/ezyang
2023-08-20 01:36:18 +00:00
Fuzzkatt
3c7331742a test_fused_sdp_choice in test_transformers.py fix (#106587)
sdp dispatcher prioritizes flash attention over efficient attention: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp#L684-L687, and flash attention is enabled for sm75+: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp#L625. Thus, the unit test `test_fused_sdp_choice` from `test_transformers.py` which is failing on T4 (sm75) should have this `SM80OrLater` check changed to `SM75OrLater`: https://github.com/pytorch/pytorch/blob/main/test/test_transformers.py#L1914-L1917.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106587
Approved by: https://github.com/drisspg
2023-08-04 03:43:56 +00:00
drisspg
cfa4edcde0 [SDPA] Update dispatch checks to catch last_dim_stride != 1. Also update mask padding logic (#106102)
<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at bb1fc29</samp>

This pull request simplifies and refactors the code for fused scaled dot product attention kernels in `attention.cu` and `sdp_utils.cpp`, and adds new input validation checks and tests. It also modifies the `sdp_params` struct to store optional mask tensors directly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106102
Approved by: https://github.com/cpuhrsch
2023-08-01 19:13:01 +00:00
XiaobingSuper
55f9359d36 fix sdpa math accuracy issue when scale is negative (#105202)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105202
Approved by: https://github.com/jgong5, https://github.com/lezcano, https://github.com/drisspg
2023-08-01 00:19:14 +00:00
Fuzzkatt
1cebfef8a4 sm90 efficient attention test fixes (#105978)
Fixes the following two test cases involving efficient attention on sm90:

Explanations:

functorch/test_ops.py: test_vjp_nn_functional_scaled_dot_product_attention_cuda_float32
* originally the test had xfail for all sm
* in https://github.com/pytorch/pytorch/issues/102029, we found that it was unexpectedly passing on sm90
* I made https://github.com/pytorch/pytorch/pull/102131 to update the test to let it pass
* @drisspg seems to have made changes to the behavior such that the original xfail was getting triggered (https://github.com/pytorch/pytorch/issues/102029#issuecomment-1560071148)
* the CI began complaining about the failure again: https://github.com/pytorch/pytorch/issues/102663
* I'm now reverting https://github.com/pytorch/pytorch/pull/102131 to bring back the original xfail now that the behavior has been fixed by @drisspg to trigger the xfail in sm90 similar to all other sm

test_transformers.py: test_mem_efficient_fail_sm90_cuda
* the test as it's currently written seems to expect the sdp dispatcher to fail for mem efficient attention on sm90; however, testing this on H100, it actually succeeds, so I'm disabling the test for now as the current expected result may be outdated

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105978
Approved by: https://github.com/eqy, https://github.com/kshitij12345, https://github.com/zou3519
2023-07-31 17:59:40 +00:00
drisspg
cb9a4fbbf2 [BE] Improve test_transformers test structure (#105938)
# Summary

We have a vast majority of test that only run on cuda. Decorating with @onlycuda causes pytest to instantiate 2x the tests and skip half of them. This overhead is non trivial when the #tests cross larger like it has for this file.

This breaks up the cuda only tests into a separate class
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105938
Approved by: https://github.com/mikaylagawarecki, https://github.com/malfet
2023-07-26 22:16:20 +00:00
drisspg
c4b7311fc2 Meff Attn Bias (#104310)
# Summary

### Review Points
- Automatically pad tensors to create aligned masks when seqlen_kv is not multiple of 16. This will cause memory spike ~ 2 * attn_mask size which could in theory be big.  At appears though that doing this + mem_eff is faster than no_pad + math. SO seems to be worth it
- Using expand to view the attn_mask in 4d. This is a little different to how we enforce q,k,v to be viewed in 4d prior to calling. Also not supprint b*n_heads, seq_lenq, seq_lenkv case.
- Should enable, #96099

### Profiling
I ran a bunch of comparisons between sdpa.MATH and sdp.MemEffAttention.  I added a attn_bias of shape (1, 1, seqlen_q, seqln_k). For these experiments seqlen_q == seqlen_k. These were all ran on an a100 80gb gpu.
Configs:
```
    # Run a bunch of experiments
    batch_sizes = [8, 16, 32]
    num_heads = [16, 32]
    max_seq_lens = [15, 64, 128, 512, 555, 1024]
    embed_dims = [32, 64, 128]
    dtypes = [torch.float16, torch.bfloat16, torch.float32]
    pad_percentages = [None]
    backends = [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
    run_backward = True
    attn_mask = True
```

   The function calls `sdpa(input**).sum().backward()`.

   I calculated the geomean speedup of the efficient attention path of the math path for all these configs:
   `Geomean Speedup: 1.977`

An example comparision with batchsize = 8, num_heads = 32, embed_dim = 64, and dtype = torch.float16:
![attn_mask_compare_bsz_8_num_heads_32_embed_dim_64_dtype_fp16](https://github.com/pytorch/pytorch/assets/32754868/0d75bffe-350b-43f2-a37f-514f9158dcff)

 This was done using the current state of the branch where we force alignment of mask when the last dim is not divisible by 16, which shows up in seq_len = 15 and 555 case.

The full data can be found here:

[attn_mask_sweep.csv](https://github.com/pytorch/pytorch/files/11962399/attn_mask_sweep.csv)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104310
Approved by: https://github.com/cpuhrsch
2023-07-26 15:51:59 +00:00
PyTorch MergeBot
340ec1f460 Revert "Meff Attn Bias (#104310)"
This reverts commit 5453508115.

Reverted https://github.com/pytorch/pytorch/pull/104310 on behalf of https://github.com/DanilBaibak due to PR introduced cuda OOM issue ([comment](https://github.com/pytorch/pytorch/pull/104310#issuecomment-1650171538))
2023-07-25 16:37:32 +00:00
drisspg
5453508115 Meff Attn Bias (#104310)
# Summary

### Review Points
- Automatically pad tensors to create aligned masks when seqlen_kv is not multiple of 16. This will cause memory spike ~ 2 * attn_mask size which could in theory be big.  At appears though that doing this + mem_eff is faster than no_pad + math. SO seems to be worth it
- Using expand to view the attn_mask in 4d. This is a little different to how we enforce q,k,v to be viewed in 4d prior to calling. Also not supprint b*n_heads, seq_lenq, seq_lenkv case.
- Should enable, #96099

### Profiling
I ran a bunch of comparisons between sdpa.MATH and sdp.MemEffAttention.  I added a attn_bias of shape (1, 1, seqlen_q, seqln_k). For these experiments seqlen_q == seqlen_k. These were all ran on an a100 80gb gpu.
Configs:
```
    # Run a bunch of experiments
    batch_sizes = [8, 16, 32]
    num_heads = [16, 32]
    max_seq_lens = [15, 64, 128, 512, 555, 1024]
    embed_dims = [32, 64, 128]
    dtypes = [torch.float16, torch.bfloat16, torch.float32]
    pad_percentages = [None]
    backends = [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
    run_backward = True
    attn_mask = True
```

   The function calls `sdpa(input**).sum().backward()`.

   I calculated the geomean speedup of the efficient attention path of the math path for all these configs:
   `Geomean Speedup: 1.977`

An example comparision with batchsize = 8, num_heads = 32, embed_dim = 64, and dtype = torch.float16:
![attn_mask_compare_bsz_8_num_heads_32_embed_dim_64_dtype_fp16](https://github.com/pytorch/pytorch/assets/32754868/0d75bffe-350b-43f2-a37f-514f9158dcff)

 This was done using the current state of the branch where we force alignment of mask when the last dim is not divisible by 16, which shows up in seq_len = 15 and 555 case.

The full data can be found here:

[attn_mask_sweep.csv](https://github.com/pytorch/pytorch/files/11962399/attn_mask_sweep.csv)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104310
Approved by: https://github.com/cpuhrsch
2023-07-24 22:19:26 +00:00
Michael Gschwind
11b753af01 Refactor causal mask generation and detection for nn.transformer (#105265)
Summary:
* Create a private global-scope function _generate_subsequent because static class attribute member functions not supported by TorchScript resulting in torchscripting errors.
* Make TransformerEncoder and TransformerDecoder consistent w.r.t. is_causal handling by calling _detect_casual_mask
* Clarify documentation that is_causal is a hint
* Move causal mask detection into a method _detect_causal_mask
* only accept input-size compatible causal mask as causal mask
* update _generate_subsequent_causal_mask to include factory kwargs for dtype and device:
   avoid extra copies & conversions by passing directly to torch.full.

Test Plan: sandcastle & github CICD
Continuation of #101487 (due to a tooling issue) which is a continuation-in-part of https://github.com/pytorch/pytorch/pull/98327 by @janEbert

Differential Revision: D47427117

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105265
Approved by: https://github.com/mikaylagawarecki
2023-07-19 01:26:50 +00:00
Michael Gschwind
07a1c3f7ff Exercise subclass of TransformerEncoderLayer (#105297)
Summary: Exercise subclass of TransformerEncoderLayer
Additional unit tests for change in #102045 to show correct e2e operation (cf. issue #100188)

Also: remove batch_first from list of TS module constants where it is not used to resolve torchscripting warning

Test Plan: saqndcastle, github

Differential Revision: D47503004

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105297
Approved by: https://github.com/davidberard98
2023-07-17 16:03:10 +00:00
Driss Guessous
4a008d268a REDO of dropout support for mem eff #102038 (#103704)
THIS IS A new PR with the changes from #102038 + #103201 +  plus namespacing changes to fix bug.

# Summary
This PR builds off of:
- https://github.com/pytorch/pytorch/pull/101847
- https://github.com/pytorch/pytorch/pull/100583

It specifically adds dropout support to the memory efficient attention kernel. In the process of doing so roughly 3 changes were made:
- Update sdpa dispatching to allow for inputs requiring grad to be sent to efficient attention
- Update how memory efficient attention handles passing the rng state from forward to backward in order to enable cuda_graph support
- Fix a bug in the kernel that was causing incorrect gradients to be produced for num_keys > 64 with dropout and causal masking set. https://github.com/facebookresearch/xformers/pull/755

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103704
Approved by: https://github.com/cpuhrsch
2023-06-26 23:05:03 +00:00
Yinghai Lu
4c3799447f
Back out "Dropout support for memory efficient attention (#102038)" & "Two small mem_eff bug fixes (#103201)" (#103464)
Summary:
Original commit changeset: 04c4473d8510

Original Phabricator Diff: D46584152 & D46582033

Test Plan: Already explained in summary.

Reviewed By: yinghai

Differential Revision: D46633283

fbshipit-source-id: c23c2945408988f3c4339dfd5cd40ae46261716c

Co-authored-by: Shenxiu Liu <shenxiu@meta.com>
2023-06-12 18:56:48 -07:00
Driss Guessous
606fb882c4 Dropout support for memory efficient attention (#102038)
# Summary
This PR builds off of:
- https://github.com/pytorch/pytorch/pull/101847
- https://github.com/pytorch/pytorch/pull/100583

It specifically adds dropout support to the memory efficient attention kernel. In the process of doing so roughly 3 changes were made:
- Update sdpa dispatching to allow for inputs requiring grad to be sent to efficient attention
- Update how memory efficient attention handles passing the rng state from forward to backward in order to enable cuda_graph support
- Fix a bug in the kernel that was causing incorrect gradients to be produced for num_keys > 64 with dropout and causal masking set. https://github.com/facebookresearch/xformers/pull/755

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102038
Approved by: https://github.com/cpuhrsch
2023-06-08 21:50:12 +00:00
Driss Guessous
2800a04a17 Add device range helper and remove sm86 specific check for memory efficient attention (#102985)
# Summary
Since we have upstreamed the latest changes of memory efficient attetnion we can remove the sm86/sm89 specific check. All head_sizes (assuming correctly alignment) should work for sm86 and sm89 size and don't have a max capability.

If head_size > 96 there will be a big drop in performance but should not error and still maintain memory savings by not materializing attention weights.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102985
Approved by: https://github.com/cpuhrsch
2023-06-07 00:28:40 +00:00
Michael Gschwind
4d89489df5 Move static checks of layers[0] (e.g., isinstance check) to model build time (#102045)
Summary: Move static checks of layers[0] (e.g., isinstance check) to model build time because isinstance() does not work for torchscripted code.  Because the validation is now performed while constructing the object, the isinstance() call is performed in eager mode at model build time, and we avoid needing to call  isinstance() at runtime to determine whether the layers in a model are an instance of the TransformerEncoderLayer class, or its derived classes.

Test Plan: sandcastle, github

Differential Revision: D46096222

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102045
Approved by: https://github.com/mikaylagawarecki
2023-05-30 19:42:01 +00:00
Driss Guessous
ef13fde290 Increase mem eff backward performance (#101847)
# Summary
This is another upstream which is much smaller than the previous.
This bumps the kernel versions from  xformers
Current: [6425fd0cacb1a6579aa2f0c4a570b737cb10e9c3](6425fd0cac)
With this PR: [1d635e193e169fc677b2e7fa42dad7ebe88eec9e](1d635e193e)

### Notable Changes:
- Drastically improve the BW pass in multiple cases (especially when B*numHeads < 100)
- H100 Support: *Warning* While these kernels have been added, we don't have the CI/CD machines to test.
- Enables a deterministic mode.

## Specific Changes
- Updates to the backward kernel.
- Added num_splits_key which we hard code to -1. (This is a another performance knob that we set to the heuristic)
- Update gen_code and kernels to produce h100 instantiations.

### Due Diligence Checks:
* CUDA_lib size: No changes in size

#### Peformance
* Micro Benchmark: (batch_size: 1, num_heads=25, seq_len=4096, embed_dim = 64 | grid:[1,25,1]block: [128,1,1])
    * MemEfficientAttention Backward Kernel: 27.972 ms
    * After the updated Xformers code(https://github.com/pytorch/pytorch/pull/100583): 23.958 ms
    * With this PR: 4.085 ms
* Ran micro benchmarks on sdpa_forw().sum().backward() over a range of dtypes, and input shapes
   * Geo_mean increase -> 1.17x
   * Max increase -> 2.95x
   * min_increase -> 0.8x

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101847
Approved by: https://github.com/cpuhrsch
2023-05-26 02:25:31 +00:00
Aaron Gokaslan
3e2ea32dab [BE]: Enable ruff rule TRY302 and apply fixes (#101874)
Removes useless try statements and unreachable code.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101874
Approved by: https://github.com/malfet
2023-05-19 17:30:52 +00:00
Driss Guessous
c9ba967c21 Upstream xformers code (#100583)
# Summary
Since the initial upstream of memory efficient attention from xformers: #86157, significant work updates have been made to the kernel including - increased performance, bug-fixes, and added functionality. This PR upstreams the latest version of this kernel as of: version 0.0.20 or commit: [6425fd0cacb1a6579aa2f0c4a570b737cb10e9c3](6425fd0cac)

## Future
Although this version of the Kernel has support for dropout and arbitrary attention bias, I did not add this support to SDPA yet, and left the guards in sdp_utils. Those will follow up PRs in order to reduce the scope creep of these substantial changes, and ensure that nothing is broken.

## Specific Changes
### Minor Changes
* The build system work was done in the previous PR and so no changes were needed to CMAKE 🤞
* Adding the new files and re-arranging/creating folder structure
* Updating include paths
* Switching from xformer specific functions: `XFORMERS_CHECK -> TORCH_CHECK`
* Changes to xformer specific macros
* Updates to the `generate_kernels.py` to use account for Pytorch file structure, also added an arg parse that I could run on a test dir before creating the files in place.

### Bigger Changes
* Previous Kernel changes "Removed the chunk optimization: see discussion here: https://github.com/pytorch/pytorch/pull/96880"
* Increased the number of cuda kernels -> potentially effecting the cuda_lib size.
* Preemptively made changes to the dtypes of seed and offset in order to allow for cuda_graphs: #100196 this is not finished.
* Made VERY BC breaking changes to at::_efficient_attention_forward and at::_efficeint_attention_backward function signatures.
    * I made these changes due to in part to the ability for this PR to land:https://github.com/pytorch/pytorch/pull/100196

### Due Diligence Checks:
* CUDA_lib size:
    * Before: 496 MiB
    * After:    496MiB
* Performance Sweep:
    * I sweeped over 576 configs for forward only inference and the geomean speedup was 0.98x with a min speed up of 0.84 and a max speedup of 1.2
    * For Forw+Back running on 270 configs ( to reduce memory) the geomean speedup was 1.02X with a min speed up of 1.02 and a max speedup of 1.35.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100583
Approved by: https://github.com/cpuhrsch
2023-05-18 16:15:34 +00:00
Driss Guessous
52363de2ec Clean up grad check in sdp_utils.h (#101435)
# Summary
The priorty order was not being run correctly because of confusing function name
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101435
Approved by: https://github.com/jbschlosser
2023-05-16 02:22:45 +00:00