Commit Graph

139 Commits

Author SHA1 Message Date
drisspg
4e29f01bf2 Remove sdp_kernel and replace with sdpa_kernel in attention namespace (#114689)
# Summary
Simplification of Backend Selection

This PR deprecates the `torch.backends/cuda/sdp_kernel` context manager and replaces it with a new context manager `torch.nn.attention.sdpa_kernel`. This context manager also changes the api for this context manager.

For `sdp_kernel` one would specify the backend choice by taking the negation of what kernel they would like to run. The purpose of this backend manager was to only to be a debugging tool, "turn off the math backend" and see if you can run one of the fused implementations.

Problems:
- This pattern makes sense if majority of users don't care to know anything about the backends that can be run. However, if users are seeking to use this context manager then they are explicitly trying to run a specific backend.
- This is not scalable. We are working on adding the cudnn backend and this API makes it so so that more implementations will need to be turned off if user wants to explicitly run a given backend.
- Discoverability of the current context manager. It is somewhat un-intutive that this backend manager is in backends/cuda/init when this now also controls the CPU fused kernel behavior. I think centralizing to attention namespace will be helpful.

Other concerns:
- Typically backends (kernels) for operators are entirely hidden from users and implementation details of the framework. We have exposed this to users already, albeit not by default and with beta warnings. Does making backends choices even more explicit lead to problems when we potentially want to remove existing backends, (perhaps inputs shapes will get covered by newer backends).

A nice side effect is now that we aren't using the `BACKEND_MAP` in test_transformers many, many dynamo failures are passing for CPU tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114689
Approved by: https://github.com/cpuhrsch
2024-01-24 22:28:04 +00:00
PyTorch MergeBot
2f84a9d37c Revert "[CUDNN][SDPA] Experimental cuDNN Flash Attention v2 Inference (#115663)"
This reverts commit 5aa92b5090.

Reverted https://github.com/pytorch/pytorch/pull/115663 on behalf of https://github.com/PaliC due to Unfortunately, this pr breaks cuda builds internally ([comment](https://github.com/pytorch/pytorch/pull/115663#issuecomment-1899388813))
2024-01-18 23:40:30 +00:00
Eddie Yan
5aa92b5090 [CUDNN][SDPA] Experimental cuDNN Flash Attention v2 Inference (#115663)
#113713

Going to clean up some of the checks and will remove draft status after.
Can be tested on SM80+ with `TORCH_CUDNN_MHA_ENABLED=1`.

CC @drisspg @ptrblck
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115663
Approved by: https://github.com/drisspg
2024-01-18 01:20:36 +00:00
Sun, Jiayi
d9b265adaf modify the conditions as PythonModuleVariable (#116856)
## Motivation
The current code of `value in [torch.backends.cudnn, torch.ops]` requires `value` to have the implementation of `__eq__`. If the value is a custom object and does not implement `__eq__`, dynamo will throw error. For example, ConvolutionOpContext, the custom 'torch._C.ScriptClass' object registered in IPEX, dynamo will throw the following error:

**torch._dynamo.exc.InternalTorchDynamoError: '__eq__' is not implemented for __torch__.torch.classes.ipex_prepack.ConvolutionOpContext**

I think this is a common issue, To avoid this issue, the PR replaces the current code `value in [torch.backends.cudnn, torch.ops]`with `isinstance(value, (torch.backends.cudnn.CudnnModule, torch._ops._Ops)`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116856
Approved by: https://github.com/jansel
2024-01-15 11:10:57 +00:00
drisspg
19e93b85b9 Fixes last_dim stride check for singleton dimensions (#117001)
Fixes #116333

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117001
Approved by: https://github.com/cpuhrsch
2024-01-10 04:46:49 +00:00
Valentine233
20c2ec9a15 [CPU] Add flash attention mask version (#115913)
Add a masked-version flash attention for CPU.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115913
Approved by: https://github.com/jgong5, https://github.com/drisspg
2024-01-07 04:58:23 +00:00
PyTorch MergeBot
2ccc7af028 Revert "[CPU] Add flash attention mask version (#115913)"
This reverts commit 76a3fbb709.

Reverted https://github.com/pytorch/pytorch/pull/115913 on behalf of https://github.com/zou3519 due to broke transformer test on dynamo shard ([comment](https://github.com/pytorch/pytorch/pull/115913#issuecomment-1878043389))
2024-01-05 02:39:12 +00:00
Valentine233
76a3fbb709 [CPU] Add flash attention mask version (#115913)
Add a masked-version flash attention for CPU.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115913
Approved by: https://github.com/jgong5, https://github.com/drisspg
2024-01-05 01:27:36 +00:00
Mikayla Gawarecki
d0cf2182ea Fix TransformerEncoderLayer for bias=False (#116760)
Fixes https://github.com/pytorch/pytorch/issues/116385

Don't call `torch._transformer_encoder_layer_fwd` when `bias=False`

`bias=False` was not something that `torch._transformer_encoder_layer_fwd`  was meant to work with, it was my bad that this wasn't tested as I approved https://github.com/pytorch/pytorch/pull/101687.

`bias=False` was causing the `tensor_args` in [`TransformerEncoder`](a17de2d645/torch/nn/modules/transformer.py (L663-L677)) to contain `None`s and error on checks for the fastpath like `t.requires_grad for t in tensor_args`.

Alternative fix would be to
1) Pass `torch.zeros_like({*}.weight)` to the kernel when `bias=False` and filter `tensor_args` as appropriate
2) Fix `torch._transformer_encoder_layer_fwd` to take `Optional<Tensor>` for biases and fix the kernels as appropriate

Let me know if these approaches are preferable

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116760
Approved by: https://github.com/jbschlosser
2024-01-05 00:13:10 +00:00
Xinya Zhang
e3ca7346ce Re-add initial Flash Attention support on ROCM (#115981)
Note about the Updates:

This PR:
1. skips more flash attention related UTs on MI200
2. Fix additional ATen compiling errors after hipification
3. Fix the author "root" of a specific commit
4. Includes the patch from Nikita in favor of block level static initialization.

CAVEAT: This revised PR has a commit that modifies the CI to force its running on MI200 nodes. That specific commit must be reverted before merge.

Original PR (https://github.com/pytorch/pytorch/pull/114309) Note:

This pull requests add initial Flash Attention support for AMD/ROCM platform. It added a specialized Triton repository/branch as a compile-time dependency for Flash Attention math library on AMD/ROCM. This triton submodule is not used at runtime and will not be shipped to the final pytorch package. We have the plan to release this specialized Triton as a separate project.

Know limitations:

- Only supports MI200 series GPU (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`.
- Only supports power of two sequence lengths.
- No support for varlen APIs.
- Only support head dimension 16,32,64,128.
- Performance is still being optimized.

Fixes #112997

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115981
Approved by: https://github.com/malfet
2024-01-04 22:21:31 +00:00
Mikayla Gawarecki
0f6f582c0d Add config to disable TransformerEncoder/MHA fastpath (#112212)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112212
Approved by: https://github.com/jbschlosser
2024-01-02 23:59:30 +00:00
Aaron Gokaslan
bd10fea79a [BE]: Enable F821 and fix bugs (#116579)
Fixes #112371

I tried to fix as many of the bugs as I could, a few I could not figure out what the proper fix for them was though and so I left them with noqas.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116579
Approved by: https://github.com/ezyang
2024-01-01 08:40:46 +00:00
drisspg
1e834e0e50 Fix bug in mem_eff kernel with attention mask and MQA (#116234)
# Summary

Found using the repros mentioned in this issue: #112577

After many go rounds with compute-sanitizer and eventual printf debugging I feel pretty confident that this was the underlying issue

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116234
Approved by: https://github.com/malfet, https://github.com/danthe3rd, https://github.com/atalman
2023-12-21 21:52:21 +00:00
drisspg
65d3dde665 Fix allowed dtypes for mem_eff attention (#116026)
# Summary

Fix issue bug in detecting mem eff capability for cuda devices less than sm80:
https://github.com/pytorch-labs/gpt-fast/issues/49

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116026
Approved by: https://github.com/janeyx99
2023-12-21 01:56:38 +00:00
PyTorch MergeBot
af8a50e656 Revert "Fix allowed dtypes for mem_eff attention (#116026)"
This reverts commit fc58909bab.

Reverted https://github.com/pytorch/pytorch/pull/116026 on behalf of https://github.com/jeanschmidt due to breaking internal windows buck builds, check internal diff for more details ([comment](https://github.com/pytorch/pytorch/pull/116026#issuecomment-1864354665))
2023-12-20 12:01:34 +00:00
drisspg
fc58909bab Fix allowed dtypes for mem_eff attention (#116026)
# Summary

Fix issue bug in detecting mem eff capability for cuda devices less than sm80:
https://github.com/pytorch-labs/gpt-fast/issues/49

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116026
Approved by: https://github.com/janeyx99
2023-12-18 23:20:52 +00:00
Jeff Daily
e3aefe2970 Revert "Initial Flash Attention support on ROCM (#114309)" (#115975)
This reverts commit 5bddbed399.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115975
Approved by: https://github.com/atalman, https://github.com/malfet
2023-12-16 03:40:14 +00:00
Xinya Zhang
5bddbed399
Initial Flash Attention support on ROCM (#114309)
This pull requests add initial Flash Attention support for AMD/ROCM platform. It added a specialized Triton repository/branch as a compile-time dependency for Flash Attention math library on AMD/ROCM. This triton submodule is not used at runtime and will not be shipped to the final pytorch package. We have the plan to release this specialized Triton as a separate project.

Know limitations:

- [ ] Only supports MI200 series GPU (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`.
- [ ] Only supports power of two sequence lengths.
- [ ] No support for varlen APIs.
- [ ] Only support head dimension 16,32,64,128.
- [ ] Performance is still being optimized.

Fixes https://github.com/pytorch/pytorch/issues/112997

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114309

Approved by: https://github.com/jeffdaily, https://github.com/malfet

---------

Co-authored-by: Joseph Groenenboom <joseph.groenenboom@amd.com>
2023-12-14 08:52:57 -08:00
Fuzzkatt
661c1cf2aa numerical mismatch fix for test_mem_efficient_attention_attn_mask_vs_math_ref_grads in test_transformers.py (#115707)
adjust dropout_fudge_factor since previous fudge factor was too small and led to numerical mismatch in NVIDIA internal CI
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115707
Approved by: https://github.com/drisspg
2023-12-14 01:04:39 +00:00
Valentine233
064846dbc2 [cpu] flash attention optimization (#115151)
### Modifications
- **EXP**: Add a fast version with a reduced accuracy (ULP20) to vec exp `exp_u20` and use it in flash attention.
- **FUSION**: Do fusion for `softmax` ops.
- **SCALE**: Move the calculation of `scaling_factor` after `gemm`.

### Performance
_Model: Stable Diffusion V2.1_

| Version | BF16 Kernel latency (s) | BF16 speedup | FP32 Kernel latency (s) | FP32 speedup |
| ----- | ----- | ----- | ----- | ----- |
| PT | 15.865 |  | 35.362 |  |
| PT + EXP | 12.518 | 21.10% | 19.327 | 45.35% |
| PT + EXP + FUSION | 11.774 | 25.79% | 18.306 | 48.23% |
| PT + EXP + FUSION + SCALE | 11.053 | 30.33% | 18.360 | 48.08% |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115151
Approved by: https://github.com/jgong5, https://github.com/drisspg
2023-12-12 01:09:55 +00:00
drisspg
d4c79a3078 Add an attention bias subclass for a lower right causal masking (#114823)
# Summary
This PR introduces a new Tensor subclass that is designed to be used with torch.nn.functional.scaled_dot_product_attention. Currently we have a boolean `is_causal` flag that allows users to do do causal masking without the need to actually create the "realized" attention bias and pass into sdpa. We originally added this flag since there is native support in both fused kernels we support. This provides a big performance gain ( the kernels only need to iterate over ~0.5x the sequence, and for very large sequence lengths this can provide vary large memory improvements.

The flag was introduced when the early on in the kernel development and at the time it was implicitly meant to "upper_left" causal attention. This distinction only matters when the attention_bias is not square. For a more detailed break down see: https://github.com/pytorch/pytorch/issues/108108. The kernels default behavior has since changed, largely due to the rise of autogressive text generation. And unfortunately this would lead to a BC break. In the long term it may actually be beneficial to change the default meaning of `is_causal` to represent lower_right causal masking.

The larger theme though is laid here: https://github.com/pytorch/pytorch/issues/110681. The thesis being that there is alot of innovation in SDPA revolving around the attention_bias being used. This is the first in hopefully a few more attention_biases that we would like to add. The next interesting one would be `sliding_window` which is used by the popular mistral model family.

Results from benchmarking, I improved the meff_attention perf hence the slightly decreased max perf.
```Shell
+---------+--------------------+------------+-----------+-----------+-----------+-----------+----------------+----------+
|  Type   |      Speedup       | batch_size | num_heads | q_seq_len | k_seq_len | embed_dim |     dtype      | head_dim |
+---------+--------------------+------------+-----------+-----------+-----------+-----------+----------------+----------+
| Average | 1.2388050062214226 |            |           |           |           |           |                |          |
|   Max   | 1.831672915579016  |    128     |    32     |   1024    |   2048    |   2048    | torch.bfloat16 |    64    |
|   Min   | 0.9430534166730135 |     1      |    16     |    256    |    416    |   2048    | torch.bfloat16 |   128    |
+---------+--------------------+------------+-----------+-----------+-----------+-----------+----------------+----------+
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114823
Approved by: https://github.com/cpuhrsch
2023-12-06 08:29:26 +00:00
drisspg
8556a09d44 Require less alignment for attn bias (#114173)
# Summary
Improved Fix for Attention Mask Alignment Issue (#112577)

This PR addresses Issue #112577 by refining the previously implemented fix, which was found to be incorrect and causes un-needed memory regressions. The update simplifies the approach to handling the alignment of the attention mask for mem eff attention.

## Changes
Alignment Check and Padding: Initially, the alignment of the attention mask is checked. If misalignment is detected, padding is applied, followed by slicing. During this process, a warning is raised to alert users.

Should this be warn_once?

We only call expand, once on the aligned mask.

Reference
https://github.com/facebookresearch/xformers/blob/main/xformers/ops/fmha/cutlass.py#L115

@albanD, @mruberry, @jbschlosser, @walterddr, and @mikaylagawarecki.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114173
Approved by: https://github.com/danthe3rd
2023-11-28 02:40:41 +00:00
PyTorch MergeBot
88a8a0daa4 Revert "Require less alignment for masking (#114173)"
This reverts commit f882c175d8.

Reverted https://github.com/pytorch/pytorch/pull/114173 on behalf of https://github.com/huydhn due to Sorry for reverting you change, but it is failing some inductor tests f882c175d8 ([comment](https://github.com/pytorch/pytorch/pull/114173#issuecomment-1823552362))
2023-11-22 21:49:31 +00:00
drisspg
f882c175d8 Require less alignment for masking (#114173)
# Summary
Improved Fix for Attention Mask Alignment Issue (#112577)

This PR addresses Issue #112577 by refining the previously implemented fix, which was found to be incorrect and causes un-needed memory regressions. The update simplifies the approach to handling the alignment of the attention mask for mem eff attention.

## Changes
Alignment Check and Padding: Initially, the alignment of the attention mask is checked. If misalignment is detected, padding is applied, followed by slicing. During this process, a warning is raised to alert users.

Should this be warn_once?

We only call expand, once on the aligned mask.

Reference
https://github.com/facebookresearch/xformers/blob/main/xformers/ops/fmha/cutlass.py#L115

@albanD, @mruberry, @jbschlosser, @walterddr, and @mikaylagawarecki.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114173
Approved by: https://github.com/danthe3rd
2023-11-22 20:02:51 +00:00
drisspg
9b0f2f8d94 expose sdpa helpers to python (#110496)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110496
Approved by: https://github.com/jbschlosser
2023-11-15 07:34:34 +00:00
drisspg
14811d69d7 [BE] Cleanup sdpa test helper usage (#113294)
# Summary

standardizes usage of the rand_sdpa_tensor helper

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113294
Approved by: https://github.com/soulitzer
2023-11-09 01:16:53 +00:00
drisspg
e509b162ed Disable FlashAttenion for is_causal=True when seqlen q not equal kv (#111007)
# Summary:
This pull request **removes** support for non-square sequence lengths in causal attention when using FlashAttention V2.

### Why are doing this
  // FlashAttention 2 updated the default mask meaning for causal in this PR:
  // 9e5e8bc91e it is now aligned to lower_right which would be a BC break
  // for non-square masks. We will not support non-square masks for causal w/ FAV2

 For more context see:
 https://github.com/pytorch/pytorch/issues/108108

 ### Followup
 A large number of people will likely want to use FAV2 with lower_right causal attention for non equal sequence lengths. See this RFC : https://github.com/pytorch/pytorch/issues/110681

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111007
Approved by: https://github.com/cpuhrsch
2023-10-23 20:33:37 +00:00
drisspg
5183760ca5 Adding Backward Support for NestedTensors and FlashAttention (#97485)
# Summary
<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at 318764f</samp>

This pull request implements the CUDA backend of the SDPA kernel for nested tensors, which enables efficient transformer models with variable-length sequences. It adds a new dispatch key, a backward function, a unit test, and some helper functions for the kernel. It modifies `test/test_transformers.py`, `aten/src/ATen/native/native_functions.yaml`, `aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctionsBackward.cpp`, and `aten/src/ATen/native/nested/cuda/NestedTensorTransformerUtils.h`.

<!--
copilot:poem
-->
### <samp>🤖 Generated by Copilot at ed4a773</samp>

> _Fused kernels of doom, unleash the flash attention_
> _Nested tensors on fire, reshape and pad with caution_
> _Backward pass of power, dispatch the CUDA key_
> _Test the gradients of hell, warn the user if they disagree_

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97485
Approved by: https://github.com/jbschlosser
2023-10-10 18:08:17 +00:00
Fuzzkatt
c28bb46445 Fix test_mem_efficient_attention_vs_math_ref_grads tolerance from test_transformers.py (#108094)
Tolerance currently too low, triggering test failures via numerical mismatch in NVIDIA internal testing for certain H100, A16, A40 configs. cc: @ptrblck @eqy

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108094
Approved by: https://github.com/eqy, https://github.com/msaroufim
2023-10-02 20:42:57 +00:00
PyTorch MergeBot
8d6479725a Revert "Adding Backward Support for NestedTensors and FlashAttention (#97485)"
This reverts commit 28d69d5256.

Reverted https://github.com/pytorch/pytorch/pull/97485 on behalf of https://github.com/huydhn due to Sorry for reverting you change, but one of the tests test_fused_kernels_nested_broadcasting_requires_grad_failure_cuda is failing on Windows CUDA f7ba3e85e2 ([comment](https://github.com/pytorch/pytorch/pull/97485#issuecomment-1743474468))
2023-10-02 17:48:57 +00:00
drisspg
28d69d5256 Adding Backward Support for NestedTensors and FlashAttention (#97485)
# Summary
<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at 318764f</samp>

This pull request implements the CUDA backend of the SDPA kernel for nested tensors, which enables efficient transformer models with variable-length sequences. It adds a new dispatch key, a backward function, a unit test, and some helper functions for the kernel. It modifies `test/test_transformers.py`, `aten/src/ATen/native/native_functions.yaml`, `aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctionsBackward.cpp`, and `aten/src/ATen/native/nested/cuda/NestedTensorTransformerUtils.h`.

<!--
copilot:poem
-->
### <samp>🤖 Generated by Copilot at ed4a773</samp>

> _Fused kernels of doom, unleash the flash attention_
> _Nested tensors on fire, reshape and pad with caution_
> _Backward pass of power, dispatch the CUDA key_
> _Test the gradients of hell, warn the user if they disagree_

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97485
Approved by: https://github.com/jbschlosser
2023-09-29 21:34:47 +00:00
drisspg
ad90ab31f2 Flash Attention v2 (#105602)
# Summary
## PR Dependencies
I don't use ghstack :( this is a PR where it would have been helpful. That beings said I am going to peel off some PRs to make reviewing this easier:
- [x] Separate build flags for Flash and MemEff: #107985

### Description
This pull request updates the version of _scaled_dot_product_flash_attention from version 1 to version 2. The changes are based on the flash attention code originally authored by @tridao

### Changes Made
The majority of the changes in this pull request involve:

- Copying over the flash_attention sources.
- Updating header files.
- Removing padding and slicing code from within the flash_attention kernel and relocating it to the composite implicit region of the SDPA. This was need to make the kernel functional and appease autograd.
- Introducing a simple kernel generator to generate different instantiations of the forward and backward flash templates.
- Adding conditional compilation (ifdef) to prevent building when nvcc is invoked with gencode < sm80.
- Introducing a separate dependent option for mem_eff_attention, as flash_attention v2 lacks support for Windows and cannot be built for sm50 generation codes.
- Modifying build.sh to reduce parallelization on sm86 runners and to lower the maximum parallelization on the manywheel builds. This adjustment was made to address out-of-memory issues during the compilation of FlashAttentionV2 sources.
- Adding/Updating tests.

### Notes for Reviewers
This is not a fun review, and I apologize in advance.
Most of the files-changed are in the flash_attn/ folder. The only files of interest here IMO:
- aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp
- aten/src/ATen/native/transformers/cuda/flash_attn/kernels/generate_kernels.py ( this has been incorporated upstream to flash-attention github)

There are a number of files all related to avoiding OOMs in CI/CD. These are typically shell scripts.

### Follow up items
- Include the updates from e07aa036db and 9e5e8bc91e | https://github.com/pytorch/pytorch/issues/108108

### Work Items
- [x] I don't think Windows will be supported for 3.1.0 - Need to update cmakee
- [x] Let multi_query/attention pass through and test | UPDATE: I have the fast path implemented here: https://github.com/pytorch/pytorch/pull/106730 but since this will require changes to semantics of math to call repeat_interleave, I think this should be done as a followup.
- [x] Had to drop cutlass back to 3.0.0 to get it to compile. Need to figure out how to upgrade to 3.1.0 and later. Spoke with Tri and he is going to be taking a look. Note: compiling with clang currently errors for the cute headers.
- [x] Update test exercise above codepath
- [x] Still need to disable on seq_len % 128 != 0 for backward( Tri beat me to it a4f148b6ab)
- [x] Add determinism warning to BWD, Tri got to this one as well: 1c41d2b
- [x] Update dispatcher to universally prefer FlashV2
- [x] Update tests to exercise new head_dims
- [x] Move the head_dim padding from kernel to top level composite implicit function in order to make it purely functional
- [x] Create template generator script
- [x] Initial cmake support for building kernels/ folder
- [x] Replay CudaGraph changes

### Results
#### Forward only
The TFlops are reported here are on a100 that is underclocked.
![flashv2_tflops_vs_seq_len](https://github.com/pytorch/pytorch/assets/32754868/152de46d-8fa6-42f0-9a9c-ef1eb7ae29e7)

#### Forward+Backward
Ran a sweep and for large compute bound sizes we do see a ~2x performance increase for forw+back.
<img width="1684" alt="Screenshot 2023-07-20 at 3 47 47 PM" src="https://github.com/pytorch/pytorch/assets/32754868/fdd26e07-0077-4878-a417-f3a418b6fb3b">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105602
Approved by: https://github.com/huydhn, https://github.com/cpuhrsch
2023-09-13 13:59:05 +00:00
Huy Do
a9c663c269 Revert "Flash Attention v2 (#105602)" (#108827)
This reverts commit add45aea1c.

There are some conflicts on some benchmark csv file https://github.com/pytorch/pytorch/pull/105602#issuecomment-1710988951 so I need to revert this manually.

The diff has been reverted internally.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108827
Approved by: https://github.com/kit1980
2023-09-08 07:43:04 +00:00
PyTorch MergeBot
e45b290127 Revert "Revert "Flash Attention v2 (#105602)" (#108827)"
This reverts commit 24e9bbe22a.

Reverted https://github.com/pytorch/pytorch/pull/108827 on behalf of https://github.com/huydhn due to I need to land this revert properly as there are new failures showing up on trunk ([comment](https://github.com/pytorch/pytorch/pull/108827#issuecomment-1711020924))
2023-09-08 03:25:45 +00:00
Huy Do
24e9bbe22a Revert "Flash Attention v2 (#105602)" (#108827)
This reverts commit add45aea1c.

There are some conflicts on some benchmark csv file https://github.com/pytorch/pytorch/pull/105602#issuecomment-1710988951 so I need to revert this manually.

The diff has been reverted internally.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108827
Approved by: https://github.com/kit1980
2023-09-08 02:54:20 +00:00
Michael Gschwind
2a40fe2dbf [experimental] use EXCEPT_FOR env to suppress CPU tests from GPU RE (#108672)
Summary:
[experimental] use EXCEPT_FOR env to suppress CPU tests from GPU RE -- alternative implementation to D48997976 using preexisting PYTORCH_TESTING_DEVICE_EXCEPT_FOR facility and building remaining logic (for assert-positive listers like test_transformers)  on top of that.

Goal: save ~100 GPU (10% of capacity), enables us to fund more aggressive PyPer unit testing on GPU RE

Test Plan: sandcastle, github

Differential Revision: D48998582

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108672
Approved by: https://github.com/bertmaher
2023-09-06 23:33:18 +00:00
drisspg
add45aea1c Flash Attention v2 (#105602)
# Summary
## PR Dependencies
I don't use ghstack :( this is a PR where it would have been helpful. That beings said I am going to peel off some PRs to make reviewing this easier:
- [x] Separate build flags for Flash and MemEff: #107985

### Description
This pull request updates the version of _scaled_dot_product_flash_attention from version 1 to version 2. The changes are based on the flash attention code originally authored by @tridao

### Changes Made
The majority of the changes in this pull request involve:

- Copying over the flash_attention sources.
- Updating header files.
- Removing padding and slicing code from within the flash_attention kernel and relocating it to the composite implicit region of the SDPA. This was need to make the kernel functional and appease autograd.
- Introducing a simple kernel generator to generate different instantiations of the forward and backward flash templates.
- Adding conditional compilation (ifdef) to prevent building when nvcc is invoked with gencode < sm80.
- Introducing a separate dependent option for mem_eff_attention, as flash_attention v2 lacks support for Windows and cannot be built for sm50 generation codes.
- Modifying build.sh to reduce parallelization on sm86 runners and to lower the maximum parallelization on the manywheel builds. This adjustment was made to address out-of-memory issues during the compilation of FlashAttentionV2 sources.
- Adding/Updating tests.

### Notes for Reviewers
This is not a fun review, and I apologize in advance.
Most of the files-changed are in the flash_attn/ folder. The only files of interest here IMO:
- aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp
- aten/src/ATen/native/transformers/cuda/flash_attn/kernels/generate_kernels.py ( this has been incorporated upstream to flash-attention github)

There are a number of files all related to avoiding OOMs in CI/CD. These are typically shell scripts.

### Follow up items
- Include the updates from e07aa036db and 9e5e8bc91e | https://github.com/pytorch/pytorch/issues/108108

### Work Items
- [x] I don't think Windows will be supported for 3.1.0 - Need to update cmakee
- [x] Let multi_query/attention pass through and test | UPDATE: I have the fast path implemented here: https://github.com/pytorch/pytorch/pull/106730 but since this will require changes to semantics of math to call repeat_interleave, I think this should be done as a followup.
- [x] Had to drop cutlass back to 3.0.0 to get it to compile. Need to figure out how to upgrade to 3.1.0 and later. Spoke with Tri and he is going to be taking a look. Note: compiling with clang currently errors for the cute headers.
- [x] Update test exercise above codepath
- [x] Still need to disable on seq_len % 128 != 0 for backward( Tri beat me to it a4f148b6ab)
- [x] Add determinism warning to BWD, Tri got to this one as well: 1c41d2b
- [x] Update dispatcher to universally prefer FlashV2
- [x] Update tests to exercise new head_dims
- [x] Move the head_dim padding from kernel to top level composite implicit function in order to make it purely functional
- [x] Create template generator script
- [x] Initial cmake support for building kernels/ folder
- [x] Replay CudaGraph changes

### Results
#### Forward only
The TFlops are reported here are on a100 that is underclocked.
![flashv2_tflops_vs_seq_len](https://github.com/pytorch/pytorch/assets/32754868/152de46d-8fa6-42f0-9a9c-ef1eb7ae29e7)

#### Forward+Backward
Ran a sweep and for large compute bound sizes we do see a ~2x performance increase for forw+back.
<img width="1684" alt="Screenshot 2023-07-20 at 3 47 47 PM" src="https://github.com/pytorch/pytorch/assets/32754868/fdd26e07-0077-4878-a417-f3a418b6fb3b">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105602
Approved by: https://github.com/huydhn, https://github.com/cpuhrsch
2023-09-01 22:14:44 +00:00
PyTorch MergeBot
d569e506ab Revert "Flash Attention v2 (#105602)"
This reverts commit 9df3d882c8.

Reverted https://github.com/pytorch/pytorch/pull/105602 on behalf of https://github.com/huydhn due to I think we miss a case here for sm80 build on inductor workflow as it is now OOM on trunk https://github.com/pytorch/pytorch/actions/runs/6042843139 ([comment](https://github.com/pytorch/pytorch/pull/105602#issuecomment-1701974862))
2023-09-01 01:15:01 +00:00
drisspg
9df3d882c8 Flash Attention v2 (#105602)
# Summary
## PR Dependencies
I don't use ghstack :( this is a PR where it would have been helpful. That beings said I am going to peel off some PRs to make reviewing this easier:
- [x] Separate build flags for Flash and MemEff: #107985

### Description
This pull request updates the version of _scaled_dot_product_flash_attention from version 1 to version 2. The changes are based on the flash attention code originally authored by @tridao

### Changes Made
The majority of the changes in this pull request involve:

- Copying over the flash_attention sources.
- Updating header files.
- Removing padding and slicing code from within the flash_attention kernel and relocating it to the composite implicit region of the SDPA. This was need to make the kernel functional and appease autograd.
- Introducing a simple kernel generator to generate different instantiations of the forward and backward flash templates.
- Adding conditional compilation (ifdef) to prevent building when nvcc is invoked with gencode < sm80.
- Introducing a separate dependent option for mem_eff_attention, as flash_attention v2 lacks support for Windows and cannot be built for sm50 generation codes.
- Modifying build.sh to reduce parallelization on sm86 runners and to lower the maximum parallelization on the manywheel builds. This adjustment was made to address out-of-memory issues during the compilation of FlashAttentionV2 sources.
- Adding/Updating tests.

### Notes for Reviewers
This is not a fun review, and I apologize in advance.
Most of the files-changed are in the flash_attn/ folder. The only files of interest here IMO:
- aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp
- aten/src/ATen/native/transformers/cuda/flash_attn/kernels/generate_kernels.py ( this has been incorporated upstream to flash-attention github)

There are a number of files all related to avoiding OOMs in CI/CD. These are typically shell scripts.

### Follow up items
- Include the updates from e07aa036db and 9e5e8bc91e | https://github.com/pytorch/pytorch/issues/108108

### Work Items
- [x] I don't think Windows will be supported for 3.1.0 - Need to update cmakee
- [x] Let multi_query/attention pass through and test | UPDATE: I have the fast path implemented here: https://github.com/pytorch/pytorch/pull/106730 but since this will require changes to semantics of math to call repeat_interleave, I think this should be done as a followup.
- [x] Had to drop cutlass back to 3.0.0 to get it to compile. Need to figure out how to upgrade to 3.1.0 and later. Spoke with Tri and he is going to be taking a look. Note: compiling with clang currently errors for the cute headers.
- [x] Update test exercise above codepath
- [x] Still need to disable on seq_len % 128 != 0 for backward( Tri beat me to it a4f148b6ab)
- [x] Add determinism warning to BWD, Tri got to this one as well: 1c41d2b
- [x] Update dispatcher to universally prefer FlashV2
- [x] Update tests to exercise new head_dims
- [x] Move the head_dim padding from kernel to top level composite implicit function in order to make it purely functional
- [x] Create template generator script
- [x] Initial cmake support for building kernels/ folder
- [x] Replay CudaGraph changes

### Results
#### Forward only
The TFlops are reported here are on a100 that is underclocked.
![flashv2_tflops_vs_seq_len](https://github.com/pytorch/pytorch/assets/32754868/152de46d-8fa6-42f0-9a9c-ef1eb7ae29e7)

#### Forward+Backward
Ran a sweep and for large compute bound sizes we do see a ~2x performance increase for forw+back.
<img width="1684" alt="Screenshot 2023-07-20 at 3 47 47 PM" src="https://github.com/pytorch/pytorch/assets/32754868/fdd26e07-0077-4878-a417-f3a418b6fb3b">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105602
Approved by: https://github.com/huydhn, https://github.com/cpuhrsch
2023-08-31 16:02:20 +00:00
drisspg
42d60d012e Bias overflow fix mem eff bias (#107968)
Fixes #107959
This should have been fixed here https://github.com/pytorch/pytorch/pull/103201
Edit:
Looking at git blame it appears the dropout revet squashed the changes from this PR
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107968
Approved by: https://github.com/cpuhrsch
2023-08-26 00:00:49 +00:00
Mikayla Gawarecki
48b1208e05 Disable nn.MHA fastpath for floating point masks (#107641)
Fixes https://github.com/pytorch/pytorch/issues/107084 by disabling the fast path when floating point masks (which should be additive) are passed

- [We claim in our docs for MHA that float masks will be added to the attention](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html) (be it `key_padding_mask` or `attn_mask`)
- We always canonicalize any mask at the start of MHA in python by converting it to float
- my understanding from Driss is that SDPA properly supports additive masking (but there are many special cases for mask shape for MHA that don't work properly currently (BxT, TxT) so [we're turning this off for now](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/cuda/attention.cu#L531-L532)
- More broadly, the problem isn't with the SDPA path, but that things are broken for the path it falls back to
-  Right now mha "fast path" code with non-None masks is always going through [this path ](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/cuda/attention.cu#L554-L640) that  has a call to `masked_softmax` that [converts the masks back to bool](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/attention.cpp#L154-L156)
- the implication here is that **additive floating point attn_mask and additive key_padding_mask to nn.MHA fastpath are broken**
- This wasn't broken for the user in [https://github.com/pytorch/pytorch/issues/107084](https://l.workplace.com/l.php?u=https%3A%2F%2Fgithub.com%2Fpytorch%2Fpytorch%2Fissues%2F107084&h=AT35qHIQavtxKtriTkrkPsWRB3eSRh4qH5PQUyiTzrPTshoztPL0593AmKCmSdEQ5O-5wib0Fd4mwztVu4YbMWb2ghZnZw1pvpJb9-FYWjDsPQ6_oHRVPzFfj8xYXC1TaFnJCkMYjrGXkIfzzxZvmcQYNnIPgsJSiWgjIw) in 1.13.1 because of [this check which bypassed the fast path if attn_mask was defined](https://github.com/pytorch/pytorch/blob/v1.13.1/torch/nn/modules/activation.py#L1096-L1097) (as Driss pointed out though additive key_padding_mask with the fast path were probably  broken in 1.13.1)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107641
Approved by: https://github.com/drisspg, https://github.com/jbschlosser
2023-08-23 15:08:18 +00:00
Aaron Gokaslan
660e8060ad [BE]: Update ruff to 0.285 (#107519)
This updates ruff to 0.285 which is faster, better, and have fixes a bunch of false negatives with regards to fstrings.

I also enabled RUF017 which looks for accidental quadratic list summation. Luckily, seems like there are no instances of it in our codebase, so enabling it so that it stays like that. :)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107519
Approved by: https://github.com/ezyang
2023-08-22 23:16:38 +00:00
PyTorch MergeBot
d59a6864fb Revert "[BE]: Update ruff to 0.285 (#107519)"
This reverts commit 88ab3e4322.

Reverted https://github.com/pytorch/pytorch/pull/107519 on behalf of https://github.com/ZainRizvi due to Sorry, but this PR breaks internal tests. @ezyang, can you please hep them get unblocked? It seems like one of the strings was prob accidentally modified ([comment](https://github.com/pytorch/pytorch/pull/107519#issuecomment-1688833480))
2023-08-22 19:53:32 +00:00
Liao, Xuan
71632d4d24 [cpu] add sdpa choice and UT (#105131)
Feature RFC: https://github.com/pytorch/rfcs/pull/56.

Write an SDPA selecting function for CPU to automatically choose one SDPA implementation among several ones. There are two CPU implementations which could be chosen: the unfused SDPA and flash attention. In general, flash attention has a higher priority than the unfused SDPA. For cases where flash attention is not applicable, such as manually disabling flash attention or the inputs not 4 dimensional, the unfused SDPA is chosen.

## Performance of the stack

### NanoGPT's SDPA kernel
Using benchmark [repo](https://github.com/mingfeima/bench_sdpa/blob/main/README.md), with one socket.
Shape: Batch size 1, Sequence length 1024, Head number 25, Head size 64.
Machine: SPR.

| Dtype    | Causal   | Mode      | SDPA            | Time (ms per iter) | Speedup |
| -------- | -------- | -------   | -------         | -------            | ------- |
| float32  | FALSE    | Inference | Unfused         | 3.081              |         |
|          |          |           | Flash attention | 1.665              | **1.85045** |
| float32  | TRUE     | Inference | Unfused         | 3.463              |         |
|          |          |           | Flash attention | 1.662              | **2.083634**|
| bfloat16 | FALSE    | Inference | Unfused         | 1.203              |         |
|          |          |           | Flash attention | 1.154              | **1.042461**|
| bfloat16 | TRUE     | Inference | Unfused         | 1.543              |         |
|          |          |           | Flash attention | 1.154              | **1.337088**|
| float32  | FALSE    | Training  | Unfused         | 54.938             |         |
|          |          |           | Flash attention | 23.029             | **2.385601**|
| float32  | TRUE     | Training  | Unfused         | 58.266             |         |
|          |          |           | Flash attention | 17.835             | **3.266947**|
| bfloat16 | FALSE    | Training  | Unfused         | 18.924             |         |
|          |          |           | Flash attention | 18.886             | **1.002012**|
| bfloat16 | TRUE     | Training  | Unfused         | 21.08              |         |
|          |          |           | Flash attention | 14.172             | **1.48744** |

### Stable Diffusion
Following model's [BKM](https://github.com/intel-innersource/frameworks.ai.models.intel-models/blob/develop/quickstart/diffusion/pytorch/stable_diffusion/inference/cpu/README.md).
Mode: Inference; Machine: SPR.

| Dtype    | SDPA                    | Throughput (fps) | Speedup SDPA | Total Time (ms) | Speedup |
| -------- | --------                | -------          | -------      | -------         | ------- |
| float32  | Unfused                 | 1.63             |              | 1139            |         |
|          | Flash attention         | 1.983            | 1.216564     | 547.488         | **2.080411**|
| bfloat16 | Flash attention in IPEX | 4.784            |              | 429.051         |         |
|          | Flash attention         | 4.857            | 1.015259     | 408.823         | **1.049479**|

### LLM models of Torchbench

Dtype: float32; Mode: Inference, single socket; Machine: CPX.
Model   name | SDPA | Inductor_new | Inductor_old | Inductor   Ratio(old/new)
-- | -- | -- | -- | --
hf_Albert | Unfused -> Flash attention | 0.048629309 | 0.05591545 | **1.14983024**
hf_Bert | Unfused -> Flash attention | 0.053156243 | 0.060732115 | **1.142520841**
hf_Bert_large | Unfused -> Flash attention | 0.141089502 | 0.155190077 | **1.099940636**
llama | Unfused -> Flash attention | 0.033250106 | 0.033720745 | **1.01415451**

Dtype: bfloat16; Mode: Inference, single socket; Machine: SPR.
Model   name | SDPA | Inductor_new | Inductor_old | Inductor   Ratio(old/new)
-- | -- | -- | -- | --
hf_Albert | Unfused -> Flash attention | 0.020681298 | 0.020718282 | **1.001788324**
hf_Bert | Unfused -> Flash attention | 0.019932816 | 0.019935424 | **1.000130842**
hf_Bert_large | Unfused -> Flash attention | 0.047949174 | 0.048312502 | **1.007577355**
llama | Unfused -> Flash attention | 0.018528057 | 0.01861126 | **1.0044907**

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105131
Approved by: https://github.com/drisspg
ghstack dependencies: #104583, #104584, #103826, #104693, #104863, #107128
2023-08-20 08:56:21 +00:00
Aaron Gokaslan
88ab3e4322 [BE]: Update ruff to 0.285 (#107519)
This updates ruff to 0.285 which is faster, better, and have fixes a bunch of false negatives with regards to fstrings.

I also enabled RUF017 which looks for accidental quadratic list summation. Luckily, seems like there are no instances of it in our codebase, so enabling it so that it stays like that. :)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107519
Approved by: https://github.com/ezyang
2023-08-20 01:36:18 +00:00
Fuzzkatt
3c7331742a test_fused_sdp_choice in test_transformers.py fix (#106587)
sdp dispatcher prioritizes flash attention over efficient attention: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp#L684-L687, and flash attention is enabled for sm75+: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp#L625. Thus, the unit test `test_fused_sdp_choice` from `test_transformers.py` which is failing on T4 (sm75) should have this `SM80OrLater` check changed to `SM75OrLater`: https://github.com/pytorch/pytorch/blob/main/test/test_transformers.py#L1914-L1917.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106587
Approved by: https://github.com/drisspg
2023-08-04 03:43:56 +00:00
drisspg
cfa4edcde0 [SDPA] Update dispatch checks to catch last_dim_stride != 1. Also update mask padding logic (#106102)
<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at bb1fc29</samp>

This pull request simplifies and refactors the code for fused scaled dot product attention kernels in `attention.cu` and `sdp_utils.cpp`, and adds new input validation checks and tests. It also modifies the `sdp_params` struct to store optional mask tensors directly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106102
Approved by: https://github.com/cpuhrsch
2023-08-01 19:13:01 +00:00
XiaobingSuper
55f9359d36 fix sdpa math accuracy issue when scale is negative (#105202)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105202
Approved by: https://github.com/jgong5, https://github.com/lezcano, https://github.com/drisspg
2023-08-01 00:19:14 +00:00
Fuzzkatt
1cebfef8a4 sm90 efficient attention test fixes (#105978)
Fixes the following two test cases involving efficient attention on sm90:

Explanations:

functorch/test_ops.py: test_vjp_nn_functional_scaled_dot_product_attention_cuda_float32
* originally the test had xfail for all sm
* in https://github.com/pytorch/pytorch/issues/102029, we found that it was unexpectedly passing on sm90
* I made https://github.com/pytorch/pytorch/pull/102131 to update the test to let it pass
* @drisspg seems to have made changes to the behavior such that the original xfail was getting triggered (https://github.com/pytorch/pytorch/issues/102029#issuecomment-1560071148)
* the CI began complaining about the failure again: https://github.com/pytorch/pytorch/issues/102663
* I'm now reverting https://github.com/pytorch/pytorch/pull/102131 to bring back the original xfail now that the behavior has been fixed by @drisspg to trigger the xfail in sm90 similar to all other sm

test_transformers.py: test_mem_efficient_fail_sm90_cuda
* the test as it's currently written seems to expect the sdp dispatcher to fail for mem efficient attention on sm90; however, testing this on H100, it actually succeeds, so I'm disabling the test for now as the current expected result may be outdated

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105978
Approved by: https://github.com/eqy, https://github.com/kshitij12345, https://github.com/zou3519
2023-07-31 17:59:40 +00:00
drisspg
cb9a4fbbf2 [BE] Improve test_transformers test structure (#105938)
# Summary

We have a vast majority of test that only run on cuda. Decorating with @onlycuda causes pytest to instantiate 2x the tests and skip half of them. This overhead is non trivial when the #tests cross larger like it has for this file.

This breaks up the cuda only tests into a separate class
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105938
Approved by: https://github.com/mikaylagawarecki, https://github.com/malfet
2023-07-26 22:16:20 +00:00