This reverts commit a5a63db3bf.
Fixes #ISSUE_NUMBER
Reverts #118368
Got reverted internally, but the branch got deleted, so the automation didn't work.
Mildly edited stack trace
```
...
return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
File "torch/_dynamo/eval_frame.py", line 453, in _fn
return fn(*args, **kwargs)
File "torch/_dynamo/external_utils.py", line 25, in inner
return fn(*args, **kwargs)
File "torch/fx/experimental/proxy_tensor.py", line 635, in dispatch_trace
graph = tracer.trace(root, concrete_args)
File "torch/fx/experimental/proxy_tensor.py", line 995, in trace
res = super().trace(root, concrete_args)
File "torch/_dynamo/eval_frame.py", line 453, in _fn
return fn(*args, **kwargs)
File "torch/_dynamo/external_utils.py", line 25, in inner
return fn(*args, **kwargs)
File "torch/fx/_symbolic_trace.py", line 793, in trace
(self.create_arg(fn(*args)),),
File "torch/fx/experimental/proxy_tensor.py", line 665, in wrapped
out = f(*tensors)
File "<string>", line 1, in <lambda>
File "torch/_functorch/_aot_autograd/traced_function_transforms.py", line 357, in _functionalized_f_helper
f_outs = fn(*f_args)
File "torch/_functorch/_aot_autograd/traced_function_transforms.py", line 68, in inner_fn
outs = fn(*args)
File "torch/_functorch/_aot_autograd/utils.py", line 161, in flat_fn
tree_out = fn(*args, **kwargs)
File "torch/_functorch/_aot_autograd/traced_function_transforms.py", line 618, in functional_call
out = PropagateUnbackedSymInts(mod).run(
File "torch/fx/interpreter.py", line 145, in run
self.env[node] = self.run_node(node)
File "torch/_functorch/_aot_autograd/traced_function_transforms.py", line 593, in run_node
result = super().run_node(n)
File "torch/fx/interpreter.py", line 202, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
File "torch/fx/interpreter.py", line 274, in call_function
return target(*args, **kwargs)
File "torch/_ops.py", line 571, in __call__
return self_._op(*args, **kwargs)
File "torch/_subclasses/functional_tensor.py", line 380, in __torch_dispatch__
outs_unwrapped = func._op_dk(
File "torch/utils/_stats.py", line 20, in wrapper
return fn(*args, **kwargs)
File "torch/fx/experimental/proxy_tensor.py", line 744, in __torch_dispatch__
return self.inner_torch_dispatch(func, types, args, kwargs)
File "torch/fx/experimental/proxy_tensor.py", line 779, in inner_torch_dispatch
return proxy_call(self, func, self.pre_dispatch, args, kwargs)
File "torch/fx/experimental/proxy_tensor.py", line 423, in proxy_call
r = maybe_handle_decomp(proxy_mode, func, args, kwargs)
File "torch/fx/experimental/proxy_tensor.py", line 1225, in maybe_handle_decomp
return CURRENT_DECOMPOSITION_TABLE[op](*args, **kwargs)
File "torch/_decomp/decompositions.py", line 4322, in scaled_dot_product_flash_attention_for_cpu
torch._check(
File "torch/__init__.py", line 1133, in _check
_check_with(RuntimeError, cond, message)
File "torch/__init__.py", line 1116, in _check_with
raise error_type(message_evaluated)
RuntimeError: query must be FP32, FP64, BF16 but got torch.float16
While executing %_scaled_dot_product_flash_attention_for_cpu : [num_users=1] = call_function[target=torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default](args = (%l_q_, %l_k_, %l_v_), kwargs = {attn_mask: %l_attn_mask_})
Original traceback:
File "executorch/backends/xnnpack/partition/graphs/sdpa.py", line 34, in forward
return torch.nn.functional.scaled_dot_product_attention(
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119204
Approved by: https://github.com/kit1980
# Summary
Simplification of Backend Selection
This PR deprecates the `torch.backends.cuda.sdp_kernel` context manager and replaces it with a new context manager, `torch.nn.attention.sdpa_kernel`, which also changes the API.
With `sdp_kernel`, one specified the backend choice by negation, i.e., by disabling the kernels one did not want to run. The purpose of that context manager was only to be a debugging tool: "turn off the math backend" and see whether one of the fused implementations can run.
Problems:
- This pattern makes sense if the majority of users don't care to know anything about the backends that can be run. However, users reaching for this context manager are explicitly trying to run a specific backend.
- This is not scalable. We are working on adding the cuDNN backend, and with this API every additional implementation means more backends that must be turned off if a user wants to explicitly run a given backend.
- Discoverability of the current context manager. It is somewhat unintuitive that this backend manager lives in backends/cuda/init when it now also controls the CPU fused-kernel behavior. I think centralizing it in the attention namespace will be helpful.
Other concerns:
- Typically backends (kernels) for operators are implementation details of the framework and entirely hidden from users. We have already exposed this to users, albeit not by default and with beta warnings. Does making backend choices even more explicit lead to problems when we potentially want to remove existing backends (perhaps input shapes will get covered by newer backends)?
A nice side effect: now that we aren't using the `BACKEND_MAP` in test_transformers, many, many previously failing dynamo CPU tests now pass.
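As a rough sketch of the before/after usage (assuming a CUDA build where the flash kernel is eligible; shapes are illustrative):
```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)

# Old, deprecated API: select flash by turning every other backend off.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v)

# New API: name the backend you actually want to run.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)
```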
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114689
Approved by: https://github.com/cpuhrsch
## Motivation
The current code, `value in [torch.backends.cudnn, torch.ops]`, requires `value` to have an implementation of `__eq__`. If `value` is a custom object that does not implement `__eq__`, Dynamo will throw an error. For example, for ConvolutionOpContext, the custom 'torch._C.ScriptClass' object registered in IPEX, Dynamo will throw the following error:
**torch._dynamo.exc.InternalTorchDynamoError: '__eq__' is not implemented for __torch__.torch.classes.ipex_prepack.ConvolutionOpContext**
I think this is a common issue. To avoid it, the PR replaces the current code `value in [torch.backends.cudnn, torch.ops]` with `isinstance(value, (torch.backends.cudnn.CudnnModule, torch._ops._Ops))`.
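A minimal sketch of why the membership test needs `__eq__` while an `isinstance` check does not (the class below is a hypothetical stand-in for the IPEX ScriptClass object):
```python
class OpContext:  # hypothetical stand-in for an extension object without a usable __eq__
    def __eq__(self, other):
        raise NotImplementedError("'__eq__' is not implemented")

value = OpContext()
candidates = [object(), object()]  # stand-ins for torch.backends.cudnn and torch.ops

# `value in candidates` is evaluated roughly as `value == candidates[0] or value == candidates[1]`,
# so it calls __eq__ and raises:
try:
    value in candidates
except NotImplementedError as exc:
    print(exc)

# isinstance() only inspects the type and never calls __eq__:
print(isinstance(value, OpContext))  # True
```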
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116856
Approved by: https://github.com/jansel
Fixes https://github.com/pytorch/pytorch/issues/116385
Don't call `torch._transformer_encoder_layer_fwd` when `bias=False`
`bias=False` was not something that `torch._transformer_encoder_layer_fwd` was meant to work with; it was my bad that this wasn't tested when I approved https://github.com/pytorch/pytorch/pull/101687.
`bias=False` was causing the `tensor_args` in [`TransformerEncoder`](a17de2d645/torch/nn/modules/transformer.py (L663-L677)) to contain `None`s and to error on fastpath checks like `t.requires_grad for t in tensor_args`.
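A tiny sketch of the failure mode (hypothetical tensors standing in for the module internals):
```python
import torch

weight = torch.randn(4, 4)
bias = None  # what ends up in tensor_args when bias=False

tensor_args = (weight, bias)
try:
    # The fastpath eligibility check iterates over tensor_args, so the None entry blows up:
    any(t.requires_grad for t in tensor_args)
except AttributeError as exc:
    print(exc)  # 'NoneType' object has no attribute 'requires_grad'
```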
Alternative fixes would be to:
1) Pass `torch.zeros_like({*}.weight)` to the kernel when `bias=False` and filter `tensor_args` as appropriate
2) Fix `torch._transformer_encoder_layer_fwd` to take `Optional<Tensor>` for biases and fix the kernels as appropriate
Let me know if either of these approaches is preferable.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116760
Approved by: https://github.com/jbschlosser
Note about the Updates:
This PR:
1. Skips more flash-attention-related UTs on MI200
2. Fixes additional ATen compilation errors after hipification
3. Fixes the author "root" of a specific commit
4. Includes the patch from Nikita in favor of block-level static initialization
CAVEAT: This revised PR has a commit that modifies the CI to force it to run on MI200 nodes. That specific commit must be reverted before merge.
Original PR (https://github.com/pytorch/pytorch/pull/114309) Note:
This pull request adds initial Flash Attention support for the AMD/ROCm platform. It adds a specialized Triton repository/branch as a compile-time dependency for the Flash Attention math library on AMD/ROCm. This Triton submodule is not used at runtime and will not be shipped in the final PyTorch package. We plan to release this specialized Triton as a separate project.
Known limitations:
- Only supports MI200-series GPUs (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
- Only supports power-of-two sequence lengths.
- No support for varlen APIs.
- Only supports head dimensions 16, 32, 64, and 128.
- Performance is still being optimized.
Fixes #112997
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115981
Approved by: https://github.com/malfet
This pull request adds initial Flash Attention support for the AMD/ROCm platform. It adds a specialized Triton repository/branch as a compile-time dependency for the Flash Attention math library on AMD/ROCm. This Triton submodule is not used at runtime and will not be shipped in the final PyTorch package. We plan to release this specialized Triton as a separate project.
Known limitations:
- [ ] Only supports MI200-series GPUs (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
- [ ] Only supports power-of-two sequence lengths.
- [ ] No support for varlen APIs.
- [ ] Only supports head dimensions 16, 32, 64, and 128.
- [ ] Performance is still being optimized.
Fixes https://github.com/pytorch/pytorch/issues/112997
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114309
Approved by: https://github.com/jeffdaily, https://github.com/malfet
---------
Co-authored-by: Joseph Groenenboom <joseph.groenenboom@amd.com>
# Summary
This PR introduces a new Tensor subclass that is designed to be used with torch.nn.functional.scaled_dot_product_attention. Currently we have a boolean `is_causal` flag that allows users to apply causal masking without needing to actually create the "realized" attention bias and pass it into SDPA. We originally added this flag since there is native support in both fused kernels we support. This provides a big performance gain (the kernels only need to iterate over ~0.5x the sequence), and for very large sequence lengths it can provide very large memory improvements.
The flag was introduced early on in the kernel development, and at the time it implicitly meant "upper_left" causal attention. This distinction only matters when the attention bias is not square. For a more detailed breakdown see: https://github.com/pytorch/pytorch/issues/108108. The kernels' default behavior has since changed, largely due to the rise of autoregressive text generation, and unfortunately that would lead to a BC break. In the long term it may actually be beneficial to change the default meaning of `is_causal` to represent lower_right causal masking.
The larger theme is laid out here: https://github.com/pytorch/pytorch/issues/110681. The thesis is that there is a lot of innovation in SDPA revolving around the attention bias being used. This is the first of, hopefully, a few more attention biases that we would like to add. The next interesting one would be `sliding_window`, which is used by the popular Mistral model family.
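To make the upper_left vs. lower_right distinction concrete for a non-square bias, a small sketch (shapes chosen purely for illustration):
```python
import torch

q_len, kv_len = 2, 5
i = torch.arange(q_len).unsqueeze(-1)   # query positions
j = torch.arange(kv_len)                # key/value positions

# upper_left: causal mask anchored at the top-left corner of the bias
upper_left = j <= i
# lower_right: causal mask anchored at the bottom-right corner
lower_right = j <= i + (kv_len - q_len)

print(upper_left.int())
# tensor([[1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0]])
print(lower_right.int())
# tensor([[1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]])
```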
Results from benchmarking: I improved the meff_attention perf, hence the slightly decreased max perf.
```Shell
+---------+--------------------+------------+-----------+-----------+-----------+-----------+----------------+----------+
| Type | Speedup | batch_size | num_heads | q_seq_len | k_seq_len | embed_dim | dtype | head_dim |
+---------+--------------------+------------+-----------+-----------+-----------+-----------+----------------+----------+
| Average | 1.2388050062214226 | | | | | | | |
| Max | 1.831672915579016 | 128 | 32 | 1024 | 2048 | 2048 | torch.bfloat16 | 64 |
| Min | 0.9430534166730135 | 1 | 16 | 256 | 416 | 2048 | torch.bfloat16 | 128 |
+---------+--------------------+------------+-----------+-----------+-----------+-----------+----------------+----------+
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114823
Approved by: https://github.com/cpuhrsch
# Summary
Improved Fix for Attention Mask Alignment Issue (#112577)
This PR addresses Issue #112577 by refining the previously implemented fix, which was found to be incorrect and caused unneeded memory regressions. The update simplifies the approach to handling the alignment of the attention mask for mem-efficient attention.
## Changes
Alignment check and padding: the alignment of the attention mask is checked first. If misalignment is detected, padding is applied, followed by slicing. During this process, a warning is raised to alert users.
Should this be warn_once?
We only call expand once, on the aligned mask.
Reference
https://github.com/facebookresearch/xformers/blob/main/xformers/ops/fmha/cutlass.py#L115
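A rough sketch of the pad-then-slice approach described above (the alignment constant and helper name are assumptions for illustration):
```python
import torch
import torch.nn.functional as F

ALIGNMENT = 16  # assumed alignment requirement for the last dim, in elements

def materialize_aligned_bias(attn_mask: torch.Tensor) -> torch.Tensor:
    last = attn_mask.size(-1)
    if last % ALIGNMENT == 0:
        return attn_mask
    pad = ALIGNMENT - last % ALIGNMENT
    # Allocate an aligned buffer, then slice back to the logical size;
    # the slice is a view, so the underlying storage stays aligned.
    return F.pad(attn_mask, (0, pad))[..., :last]
```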
@albanD, @mruberry, @jbschlosser, @walterddr, and @mikaylagawarecki.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114173
Approved by: https://github.com/danthe3rd
# Summary:
This pull request **removes** support for non-square sequence lengths in causal attention when using FlashAttention V2.
### Why we are doing this
FlashAttention 2 updated the default mask meaning for causal in commit 9e5e8bc91e: it is now aligned to lower_right, which would be a BC break for non-square masks. We will not support non-square masks for causal with FAV2.
For more context see:
https://github.com/pytorch/pytorch/issues/108108
### Followup
A large number of people will likely want to use FAV2 with lower_right causal attention for non-equal sequence lengths. See this RFC: https://github.com/pytorch/pytorch/issues/110681
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111007
Approved by: https://github.com/cpuhrsch
# Summary
<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at 318764f</samp>
This pull request implements the CUDA backend of the SDPA kernel for nested tensors, which enables efficient transformer models with variable-length sequences. It adds a new dispatch key, a backward function, a unit test, and some helper functions for the kernel. It modifies `test/test_transformers.py`, `aten/src/ATen/native/native_functions.yaml`, `aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctionsBackward.cpp`, and `aten/src/ATen/native/nested/cuda/NestedTensorTransformerUtils.h`.
<!--
copilot:poem
-->
### <samp>🤖 Generated by Copilot at ed4a773</samp>
> _Fused kernels of doom, unleash the flash attention_
> _Nested tensors on fire, reshape and pad with caution_
> _Backward pass of power, dispatch the CUDA key_
> _Test the gradients of hell, warn the user if they disagree_
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97485
Approved by: https://github.com/jbschlosser
# Summary
## PR Dependencies
I don't use ghstack :( this is a PR where it would have been helpful. That being said, I am going to peel off some PRs to make reviewing this easier:
- [x] Separate build flags for Flash and MemEff: #107985
### Description
This pull request updates _scaled_dot_product_flash_attention from version 1 to version 2. The changes are based on the flash attention code originally authored by @tridao.
### Changes Made
The majority of the changes in this pull request involve:
- Copying over the flash_attention sources.
- Updating header files.
- Removing padding and slicing code from within the flash_attention kernel and relocating it to the composite implicit region of the SDPA. This was needed to make the kernel functional and appease autograd.
- Introducing a simple kernel generator to generate different instantiations of the forward and backward flash templates.
- Adding conditional compilation (ifdef) to prevent building when nvcc is invoked with gencode < sm80.
- Introducing a separate dependent option for mem_eff_attention, as flash_attention v2 lacks support for Windows and cannot be built for sm50 generation codes.
- Modifying build.sh to reduce parallelization on sm86 runners and to lower the maximum parallelization on the manywheel builds. This adjustment was made to address out-of-memory issues during the compilation of FlashAttentionV2 sources.
- Adding/Updating tests.
### Notes for Reviewers
This is not a fun review, and I apologize in advance.
Most of the files changed are in the flash_attn/ folder. The only files of interest here, IMO, are:
- aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp
- aten/src/ATen/native/transformers/cuda/flash_attn/kernels/generate_kernels.py ( this has been incorporated upstream to flash-attention github)
There are a number of files all related to avoiding OOMs in CI/CD. These are typically shell scripts.
### Follow up items
- Include the updates from e07aa036db and 9e5e8bc91e | https://github.com/pytorch/pytorch/issues/108108
### Work Items
- [x] I don't think Windows will be supported for 3.1.0 - need to update cmake
- [x] Let multi_query/attention pass through and test | UPDATE: I have the fast path implemented here: https://github.com/pytorch/pytorch/pull/106730 but since this will require changes to semantics of math to call repeat_interleave, I think this should be done as a followup.
- [x] Had to drop cutlass back to 3.0.0 to get it to compile. Need to figure out how to upgrade to 3.1.0 and later. Spoke with Tri and he is going to be taking a look. Note: compiling with clang currently errors for the cute headers.
- [x] Update tests to exercise the above codepath
- [x] Still need to disable on seq_len % 128 != 0 for backward (Tri beat me to it: a4f148b6ab)
- [x] Add determinism warning to BWD, Tri got to this one as well: 1c41d2b
- [x] Update dispatcher to universally prefer FlashV2
- [x] Update tests to exercise new head_dims
- [x] Move the head_dim padding from kernel to top level composite implicit function in order to make it purely functional
- [x] Create template generator script
- [x] Initial cmake support for building kernels/ folder
- [x] Replay CudaGraph changes
### Results
#### Forward only
The TFLOPs reported here are on an A100 that is underclocked.

#### Forward+Backward
Ran a sweep, and for large compute-bound sizes we see a ~2x performance increase for forward+backward.
<img width="1684" alt="Screenshot 2023-07-20 at 3 47 47 PM" src="https://github.com/pytorch/pytorch/assets/32754868/fdd26e07-0077-4878-a417-f3a418b6fb3b">
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105602
Approved by: https://github.com/huydhn, https://github.com/cpuhrsch
Summary:
[experimental] use EXCEPT_FOR env to suppress CPU tests from GPU RE -- alternative implementation to D48997976 using preexisting PYTORCH_TESTING_DEVICE_EXCEPT_FOR facility and building remaining logic (for assert-positive listers like test_transformers) on top of that.
Goal: save ~100 GPU (10% of capacity), enabling us to fund more aggressive PyPer unit testing on GPU RE
Test Plan: sandcastle, github
Differential Revision: D48998582
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108672
Approved by: https://github.com/bertmaher
This updates ruff to 0.285, which is faster, better, and fixes a bunch of false negatives with regard to f-strings.
I also enabled RUF017, which looks for accidental quadratic list summation. Luckily, it seems there are no instances of it in our codebase, so I'm enabling it so that it stays that way. :)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107519
Approved by: https://github.com/ezyang
<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at bb1fc29</samp>
This pull request simplifies and refactors the code for fused scaled dot product attention kernels in `attention.cu` and `sdp_utils.cpp`, and adds new input validation checks and tests. It also modifies the `sdp_params` struct to store optional mask tensors directly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106102
Approved by: https://github.com/cpuhrsch
# Summary
The vast majority of tests in this file only run on CUDA. Decorating with @onlyCUDA causes pytest to instantiate 2x the tests and skip half of them. This overhead is non-trivial now that the number of tests has grown as large as it has for this file.
This PR breaks the CUDA-only tests out into a separate class.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105938
Approved by: https://github.com/mikaylagawarecki, https://github.com/malfet
# Summary
### Review Points
- Automatically pad tensors to create aligned masks when seqlen_kv is not a multiple of 16. This causes a memory spike of ~2x the attn_mask size, which could in theory be big. It appears, though, that doing this + mem_eff is faster than no_pad + math, so it seems to be worth it.
- Use expand to view the attn_mask in 4D (see the sketch after this list). This is a little different from how we enforce q, k, v to be viewed in 4D prior to calling. The (b*n_heads, seq_len_q, seq_len_kv) case is also not supported.
- Should enable #96099
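A rough sketch of the expand-to-4D view from the first two points (all names are illustrative):
```python
import torch

batch, num_heads, seq_len_q, seq_len_kv = 8, 16, 128, 144
attn_bias = torch.randn(1, 1, seq_len_q, seq_len_kv, dtype=torch.float16)

# expand() only adjusts strides, so no extra memory is materialized for the broadcast dims.
attn_bias_4d = attn_bias.expand(batch, num_heads, seq_len_q, seq_len_kv)
```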
### Profiling
I ran a bunch of comparisons between sdpa.MATH and sdp.MemEffAttention. I added an attn_bias of shape (1, 1, seqlen_q, seqlen_k). For these experiments seqlen_q == seqlen_k. These were all run on an A100 80GB GPU.
Configs:
```
# Run a bunch of experiments
batch_sizes = [8, 16, 32]
num_heads = [16, 32]
max_seq_lens = [15, 64, 128, 512, 555, 1024]
embed_dims = [32, 64, 128]
dtypes = [torch.float16, torch.bfloat16, torch.float32]
pad_percentages = [None]
backends = [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
run_backward = True
attn_mask = True
```
The function calls `sdpa(input**).sum().backward()`.
I calculated the geomean speedup of the efficient-attention path over the math path for all these configs:
`Geomean Speedup: 1.977`
An example comparison with batch_size = 8, num_heads = 32, embed_dim = 64, and dtype = torch.float16:

This was done using the current state of the branch, where we force alignment of the mask when the last dim is not divisible by 16, which shows up in the seq_len = 15 and 555 cases.
The full data can be found here:
[attn_mask_sweep.csv](https://github.com/pytorch/pytorch/files/11962399/attn_mask_sweep.csv)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104310
Approved by: https://github.com/cpuhrsch
Summary:
* Create a private global-scope function _generate_subsequent because static class-attribute member functions are not supported by TorchScript, resulting in torchscripting errors.
* Make TransformerEncoder and TransformerDecoder consistent w.r.t. is_causal handling by calling _detect_causal_mask
* Clarify documentation that is_causal is a hint
* Move causal mask detection into a method _detect_causal_mask
* Only accept an input-size-compatible causal mask as a causal mask
* Update _generate_subsequent_causal_mask to include factory kwargs for dtype and device:
avoid extra copies and conversions by passing them directly to torch.full (see the sketch below).
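A minimal sketch of such a generator with the factory kwargs forwarded straight to torch.full (illustrative name, not the exact private helper):
```python
import torch

def generate_square_subsequent_mask(sz, device=None, dtype=None):
    # Float mask: -inf above the diagonal, 0 elsewhere. dtype/device go straight to
    # torch.full, so no intermediate tensor has to be converted or copied afterwards.
    return torch.triu(
        torch.full((sz, sz), float("-inf"), device=device, dtype=dtype),
        diagonal=1,
    )
```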
Test Plan: sandcastle & github CICD
Continuation of #101487 (due to a tooling issue) which is a continuation-in-part of https://github.com/pytorch/pytorch/pull/98327 by @janEbert
Differential Revision: D47427117
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105265
Approved by: https://github.com/mikaylagawarecki
Summary: Exercise subclass of TransformerEncoderLayer
Additional unit tests for change in #102045 to show correct e2e operation (cf. issue #100188)
Also: remove batch_first from the list of TS module constants where it is not used, to resolve a torchscripting warning
Test Plan: sandcastle, github
Differential Revision: D47503004
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105297
Approved by: https://github.com/davidberard98
# Summary
Since we have upstreamed the latest changes to memory-efficient attention, we can remove the sm86/sm89-specific check. All head_sizes (assuming correct alignment) should work for sm86 and sm89 and no longer have a max limit.
If head_size > 96 there will be a big drop in performance, but it should not error and still maintains the memory savings of not materializing the attention weights.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102985
Approved by: https://github.com/cpuhrsch
Summary: Move static checks of layers[0] (e.g., the isinstance check) to model build time because isinstance() does not work for torchscripted code. Because the validation is now performed while constructing the object, the isinstance() call runs in eager mode at model build time, and we avoid needing to call isinstance() at runtime to determine whether the layers in a model are instances of the TransformerEncoderLayer class or its derived classes.
Test Plan: sandcastle, github
Differential Revision: D46096222
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102045
Approved by: https://github.com/mikaylagawarecki
# Summary
This is another upstream which is much smaller than the previous one.
This bumps the kernel versions from xformers.
Current: [6425fd0cacb1a6579aa2f0c4a570b737cb10e9c3](6425fd0cac)
With this PR: [1d635e193e169fc677b2e7fa42dad7ebe88eec9e](1d635e193e)
### Notable Changes:
- Drastically improve the BW pass in multiple cases (especially when B*numHeads < 100)
- H100 support: *Warning* While these kernels have been added, we don't have the CI/CD machines to test them.
- Enables a deterministic mode.
## Specific Changes
- Updates to the backward kernel.
- Added num_splits_key, which we hard-code to -1. (This is another performance knob that we set to the heuristic value.)
- Update gen_code and kernels to produce h100 instantiations.
### Due Diligence Checks:
* CUDA_lib size: No changes in size
#### Performance
* Micro Benchmark: (batch_size: 1, num_heads=25, seq_len=4096, embed_dim = 64 | grid:[1,25,1]block: [128,1,1])
* MemEfficientAttention Backward Kernel: 27.972 ms
* After the updated Xformers code(https://github.com/pytorch/pytorch/pull/100583): 23.958 ms
* With this PR: 4.085 ms
* Ran micro benchmarks on sdpa_forw().sum().backward() over a range of dtypes, and input shapes
* Geo_mean increase -> 1.17x
* Max increase -> 2.95x
* min_increase -> 0.8x
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101847
Approved by: https://github.com/cpuhrsch
# Summary
Since the initial upstream of memory-efficient attention from xformers (#86157), significant updates have been made to the kernel, including increased performance, bug fixes, and added functionality. This PR upstreams the latest version of this kernel as of version 0.0.20, or commit [6425fd0cacb1a6579aa2f0c4a570b737cb10e9c3](6425fd0cac).
## Future
Although this version of the kernel has support for dropout and arbitrary attention bias, I did not add this support to SDPA yet and left the guards in sdp_utils. Those will come in follow-up PRs in order to reduce the scope creep of these substantial changes and ensure that nothing is broken.
## Specific Changes
### Minor Changes
* The build system work was done in the previous PR and so no changes were needed to CMAKE 🤞
* Adding the new files and re-arranging/creating folder structure
* Updating include paths
* Switching from xformer specific functions: `XFORMERS_CHECK -> TORCH_CHECK`
* Changes to xformer specific macros
* Updates to `generate_kernels.py` to account for the PyTorch file structure; also added an argparse interface that I could run on a test dir before creating the files in place.
### Bigger Changes
* Previous kernel changes removed the chunk optimization; see the discussion here: https://github.com/pytorch/pytorch/pull/96880
* Increased the number of CUDA kernels -> potentially affecting the cuda_lib size.
* Preemptively made changes to the dtypes of seed and offset in order to allow for CUDA graphs: #100196; this is not finished.
* Made VERY BC-breaking changes to the at::_efficient_attention_forward and at::_efficient_attention_backward function signatures.
* I made these changes in part to enable this PR to land: https://github.com/pytorch/pytorch/pull/100196
### Due Diligence Checks:
* CUDA_lib size:
* Before: 496 MiB
* After: 496MiB
* Performance Sweep:
* I swept over 576 configs for forward-only inference, and the geomean speedup was 0.98x with a min speedup of 0.84 and a max speedup of 1.2.
* For forward+backward, running on 270 configs (to reduce memory), the geomean speedup was 1.02x with a min speedup of 1.02 and a max speedup of 1.35.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100583
Approved by: https://github.com/cpuhrsch
High level approach:
1. I generated a bunch of data comparing FlashAttention and Cutlass implementations (https://pastebin.com/pe0j3YeK)
2. I trained a decision tree using standard train/val split methodology and hyperparameter sweeps (https://pastebin.com/fjYX1HjR).
2a. I did a bunch of feature augmentation to capture interactions between features.
The heuristic I ended up with is:
```
use_flash = seq_len / (num_heads * batch_size) > 6
```
TL;DR: On my dataset, where FlashAttention and Cutlass differ by more than 10%, the existing heuristic achieves 69% accuracy. My new heuristic achieves 94% accuracy.
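Wrapped as a function, the resulting dispatch rule is just the following (a sketch; the threshold comes from the decision-tree fit described above):
```python
def prefer_flash_attention(batch_size: int, num_heads: int, seq_len: int) -> bool:
    # Single-feature rule distilled from the decision tree: favor FlashAttention
    # when the sequence length dominates the batch*heads parallelism.
    return seq_len / (num_heads * batch_size) > 6
```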
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99644
Approved by: https://github.com/ngimel, https://github.com/drisspg
Expand the sdpa_utils.h check to disable FlashAttention (when using autograd) and mem-efficient attention for the following cases:
- head_dim > 64
- sm86 or newer
Previously we only disabled these kernels on sm86 and only for head_dim equal to 128.
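A Python sketch of the expanded condition (the real check lives in C++ in sdp_utils.h; names are illustrative):
```python
def disable_fused_kernel(head_dim: int, cc_major: int, cc_minor: int, requires_grad: bool) -> bool:
    # Previously: only sm86 and only head_dim == 128 were filtered out.
    # Now: any sm86-or-newer device with head_dim > 64 is filtered when grads are needed.
    is_sm86_or_newer = (cc_major, cc_minor) >= (8, 6)
    return requires_grad and is_sm86_or_newer and head_dim > 64
```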
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99105
Approved by: https://github.com/malfet
Summary:
This fixes an issue raised in [is_causal parameter in torch.nn.TransformerEncoderLayer.forward does not work #96941](https://github.com/pytorch/pytorch/issues/96941) where results computed with is_causal do not properly reflect causal masking.
In PyTorch 2.0, Accelerated PT Transformers added the is_causal parameter to legacy nn.Transformer* and nn.MHA APIs aligned with and intended to engage the is_causal parameter of the new scaled_dot_product_attention (SDPA) operator.
At present is_causal works differently for Transformer* modules, the nn.MHA and F.MHA:
* The nn.Transformer* modules treat is_causal as an optional indicator about the format of attn_mask. This is because some layers (such as the CLIP layer) use the attention mask in the layer, and thus attn_mask was a required feature.
* Initially, nn.MHA and F.MHA were defined to align with F.SDPA in behavior: a user may specify either the attention mask, or is_causal, but not both. It seemed to make sense at the time to align SDPA and MHA, esp since there was a larger overlap of parameters which have since changed, e.g., with the removal of need_weights from SDPA. (See below for why this makes sense.)
Unfortunately, this does not work because of how MHA was changed to handle the need_weights parameter. When need_weights is present, we do not (any more) call SDPA because support for need_weights was removed from SDPA before the release. The rationale is that need_weights defeats all optimization at the foundation of SDPA performance. Having the flag might thus mislead users into thinking they get good performance and have them disappointed when they enable a legacy feature of MHA which massively degrades performance. (They might not think anything of enabling that, because it is on by default in MHA today, which leads to more issues.)
Since SDPA does not (no longer) support need_weights, we need to pick a separate path which implements attention using a set of discrete operations that allocates a tensor for weights. Alas, this code path does not have support for is_causal, because attention is implemented as matmul and using the attention mask. Thus, is_causal has no impact. (A substantially similar situation arises with how kpm is implemented today because Nested Tensors are not supported by torch.compile() in 2.0)
This problem was masked because all uses of legacy nn.MHA (and F.MHA) come through nn.Transformer*, which called self-attention (i.e., nn.MHA) only ever with the attention mask attn_mask, and never with is_causal, a missed optimization opportunity that would have been addressed in a future performance update.
Regrettably, always calling nn.MHA with attn_mask prevented diagnosing the issue of not having a suitable attention mask when need_weights support was dropped from SDPA and a discrete implementation of attention was added for that scenario, and for the execution path with key_padding_mask.
We have two options to address this issue:
Solution 1: Whenever nn.MHA and F.MHA are executed with is_causal set, we internally create a causal mask at significant expense of allocating a tensor and filling it with a triangular causal matrix. This increases memory usage, and runtime, for allocating a causal mask. To add insult to injury, in all current (and likely future) execution scenarios, MHA is called by a model using the nn.Transformer API which already has that matrix and passes it from nn.module to nn.module. Then the passing in of attn_mask has to be suppressed by nn.TransformerEncoderLayer, only for nn.MHA to immediately allocate the very same tensor again to satisfy the requirement to have an attention mask for the computation. (We expect new use cases to use SDPA directly.)
Solution 2: We align the behavior of nn.MHA and F.MHA with the rest of the existing nn.Transformer API, and require the attention mask to be passed into nn.MHA in addition to is_causal as an optional indicator about the nature of the attention mask rather than as an alternative to attn_mask. Then, when we choose the code path for processing MHA with need_weights or a key_padding_mask, we have the attn_mask passed down through the nn.Transformer* hierarchy, without the added overhead of allocating an attention mask as in scenario 1.
This PR implements solution 2 which offers better performance and in retrospect aligns MHA better with the rest of the Transformer modules as the definition of SDPA evolved into a more streamlined high-performance operator. It ostensibly changes how is_causal works, by requiring the attention mask to be specified. However, as described here, and as shown in the submitted issue, is_causal is not working as intended today, so it requires a change regardless.
In that sense, a change in API does not occur per se, as the current implementation is not working, and a change has to occur either way to resolve the submitted issue, breaking any use cases that depend on the current implementation. Checks exist (and more can be added) that flag any scenario where is_causal is passed as True but no attention mask is provided, ensuring that there is no silent change from even the faulty behavior present in 2.0.
As an upside, the present implementation will improve performance by addressing the passing of the is_causal flag from Transformer modules to MHA, speeding up training for these examples, e.g., finetuning BERT, RoBERTa, XLM-R models.
Differential Revision: D44245725
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97214
Approved by: https://github.com/albanD
Summary: Decoder native joins the dead code society
With the recent introduction of PT2, we no longer need native decoder operators:
1 - full-function SDPA kernels can be used to implement cross-attention efficiently without the (slower) decoder MHA blob.
2 - torch.compile() generates more efficient code across many platforms from the Python implementation of decoders than the decoder-layer blob, by tailoring the generated code to the target
Test Plan: github & sandcastle
Differential Revision: D43811808
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96025
Approved by: https://github.com/ezyang, https://github.com/albanD
# Summary
This PR adds an optional kwarg to torch.nn.functional.scaled_dot_product_attention().
The new kwarg is a scaling factor that is applied after the q@k.T step of the computation. Updates were made to the efficient kernel to support it, and flash and math were minimally updated to support it as well.
This will reduce the complexity of #94729 and has been asked for by a couple of users.
# Review Highlights
- As far as I know I did this the correct way, and this is both BC- and FC-compliant. However, I always seem to break internal workloads, so I would love it if someone could advise whether I did this right.
- I named the optional arg 'scale'. This is probably dumb and I should name it 'scale_factor'. I will make this change, but it is annoying and will require someone to decide that we should rename it.
- 'scale' is interpreted as `Q@K.T * (scale)` (see the sketch below).
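A quick sketch of what the kwarg means relative to the math definition (checked against the math fallback on CPU; tolerances are arbitrary):
```python
import torch
import torch.nn.functional as F

q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))
scale = 0.5

out = F.scaled_dot_product_attention(q, k, v, scale=scale)

# Reference: softmax((Q @ K^T) * scale) @ V
ref = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1) @ v
torch.testing.assert_close(out, ref, atol=1e-5, rtol=1e-5)
```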
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95259
Approved by: https://github.com/cpuhrsch
Summary: fix src and pad mask bool regression
This fixes a regression introduced previously with #92733. That PR unified the testing of masks to remove byte tensors as a permissible mask, introduced a mask-compatibility check, and added conversion to an FP mask. The problem addressed in this PR is that after the first mask had been converted, the check for mask compatibility would fail.
Test Plan: sandcastle & github
Differential Revision: D43782858
Fixes https://github.com/pytorch/pytorch/issues/95702
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96009
Approved by: https://github.com/malfet
# Summary
Previously, for NestedTensor inputs, flash_attention was disabled due to an illegal memory access error that was occurring on the "cutlass" branch of flash-attention that had been incorporated into core. Since we have switched to the main branch of flash_attention, the existing repro script no longer produces the same memory error. This PR re-enables the FlashAttention path for NTs. It also unifies the nested preprocessing between the two implementations.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95438
Approved by: https://github.com/mikaylagawarecki
# Summary
Add more checks around shape constraints, and update sdp_utils to properly catch differing head_dims between q/k and v for flash_attention, which is not supported.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94274
Approved by: https://github.com/cpuhrsch
# Summary
- Adds a large parameter sweep for testing the various configs a user can call sdpa with and compares the deviation of the fused kernels vs the eager math fallback to test for correctness.
- sm86 + head_dim == 128 was throwing an IMA for memory-efficient attention. We add a filter to use_mem_efficient_attention(). This has since been fixed in the upstream xformers version but will likely not make it before the branch cut.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94009
Approved by: https://github.com/cpuhrsch
# Summary
This PR creates _flash_attention_backward and _scaled_dot_product_flash_attention_backward native functions and registers them to the respective derivatives.yaml.
The goal is to replicate the torch.autograd.Function defined in the FlashAttention repo [here](33e0860c9c/flash_attn/flash_attn_interface.py (L126)) natively in PyTorch. One thing that we don't have access to in native PyTorch is ctx.save_for_backward, so in order to save these variables I extended the objects returned from the forward functions.
### MetaFunctions
I also updated the FlashAttention meta functions to mirror the real outputs. As well, I added a meta registration for backward. I have an XLM-R training script, and while eager training now works with FlashAttention, compiling this module fails with the inductor error down below.
### Questions?
Performance issues vs mem efficient when using torch.nn.mha_forward
TorchCompile -> See proposed solution below.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92917
Approved by: https://github.com/cpuhrsch
# Summary
Add support for fused attention kernels (FlashAttention and memory-efficient attention) on Windows. Previously we could not do this because the fixes required C++17, but we have since updated the PyTorch standard.
This PR:
- Changes invocations of unsigned long to the fixed width integer type
- Adds in the #define FP16_SWITCH(COND, ...) which has been added to the flash_attention main branch
- Changes some macros used within the mem-efficient attention code in order to work around the VA_ARG discrepancy between clang/gcc and msvc. An alternative would be setting the global flag Zc:preprocessor
- Selectively applies /Zc:lambda to only the mem-efficient sources since applying this globally caused quantization files to not compile
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91909
Approved by: https://github.com/cpuhrsch
Summary:
Regularize mask handling for attn_mask and key_padding_mask
* Update documentation to remove reference to byte masks (which were deprecated long ago)
* Introduce a check and warn about deprecation if attn_mask and key_padding_mask types mismatch
* Convert all masks to float before combining
* Combine by adding
Test Plan: sandcastle & github CI
Differential Revision: D42653215
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92733
Approved by: https://github.com/ngimel, https://github.com/drisspg
# Summary
In preparation for the PT 2.0 launch, this PR updates SDPA's API and makes the function a public nn.functional function.
## Changes
### API
Previously the function signature was:
`scaled_dot_product_attention(query, key, value, attn_mask=None, need_attn_weights=False, dropout_p=0.0, is_causal=False) -> (Tensor, Tensor)`
Updated signature:
`scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) -> Tensor`
This PR removes the need_attn_weights optional boolean variable and updates the return type to a singular tensor.
#### Reasoning:
The main goal of this function is to provide an easy interface for users to call into fused attention kernels (e.g., FlashAttention). The fused kernels do not currently support arbitrary attn_mask or dropout, but there is a PR to mem-efficient attention to enable these. We want to have the API surface ready for when the backing kernels get updated.
The fused kernels save on memory usage by not materializing the weights, and it is unlikely that a fast fused implementation will enable this feature, so we are removing it.
Discussed with folks at FAIR/xformers, who +1'd this API change.
#### Make function Public
In preparation for the PT 2.0 launch, we make the function public to start generating user feedback.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92189
Approved by: https://github.com/cpuhrsch
# Summary
Memory-efficient attention is a non-deterministic algorithm.
This PR ensures that sdp_choice will allow memory-efficient attention to be used as the backend for SDPA if we are in warn-only mode. Otherwise, if we have enabled determinism and set warn_only to False, sdp_choice will not return memory-efficient attention as the backend.
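A sketch of the described behavior in terms of the public determinism toggle (dispatcher internals omitted):
```python
import torch

# warn_only=True: non-deterministic ops warn instead of erroring, so the dispatcher
# may still select the memory-efficient attention backend.
torch.use_deterministic_algorithms(True, warn_only=True)

# warn_only=False: determinism is enforced, so memory-efficient attention is not
# returned by backend selection.
torch.use_deterministic_algorithms(True)
```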
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91979
Approved by: https://github.com/cpuhrsch
# Summary
This PR updates the second return value from SDPA to be an empty tensor of size 0, not what it would be if need_attn_weights were True. It also updates the meta function to account for this change.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91782
Approved by: https://github.com/cpuhrsch
# Summary
Creates a callable native function that can determine which implementation of scaled dot product attention will get called. This allows us to re-order the runtime dispatch of SDP to enable autograd.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89029
Approved by: https://github.com/cpuhrsch
# Registers the derivative for mem efficient backward
- Use gradcheck to test correctness. The kernel is not implemented for fp64 so run checks with bumped tolerances in fp32
- I also made updates based off of the xformers main branch and the flash-attention cutlass branch.
- This will enable the fused backward to be called for scaled dot product attention
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88856
Approved by: https://github.com/cpuhrsch
Fixes T135842750 (follow-up for #87377)
## Description
At present, having both `src_key_padding_mask` and `src_mask` at the same time is not supported on the fastpath in Transformer and Multi-Head Attention.
This PR enables using both masks on the fastpath on CPU and GPU: if both masks are passed, we merge them into a 4D mask in Python and change mask type to 2 before passing downstream.
Downstream processing in native code is not changed, as it already supports 4D mask. Indeed, it is done depending on the device:
- on CUDA, by `SoftMax.cu::masked_softmax_cuda`. When mask type is 2, it calls either `dispatch_softmax_forward` -> `softmax_warp_forward` or `at::softmax` (depending on the input size). In both cases 4D mask is supported.
- on CPU, by `SoftMax.cpp::masked_softmax_cpp`. It calls `hosted_softmax` which supports 4D mask.
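A rough sketch of the merge for float masks (all names and shapes are illustrative):
```python
import torch

B, num_heads, L = 4, 8, 32
src_mask = torch.zeros(L, L)              # (L, L) float attention mask
src_key_padding_mask = torch.zeros(B, L)  # (B, L) float padding mask (-inf at padded positions)

# Broadcast both to (B, num_heads, L, L) and combine by addition into a single 4D mask.
merged = src_mask.view(1, 1, L, L) + src_key_padding_mask.view(B, 1, 1, L)
merged = merged.expand(B, num_heads, L, L)
```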
## Tests
- Extended `test_mask_check_fastpath` to check that fast path is indeed taken in Transformer when two masks are passed
- Added `test_multihead_self_attn_two_masks_fast_path_mock` to check that fast path is taken in MHA when two masks are passed
- Added `test_multihead_self_attn_two_masks_fast_path` to check that fast and slow paths give the same result when two masks are passed in MHA
- `test_masked_softmax_mask_types` now covers mask type 2
- `test_transformerencoderlayer_fast_path` (CPU smoke test) is expanded to the case of both masks provided simultaneously
- `test_masked_softmax_devices_parity` checks that mask type 2 is accepted by CPU and CUDA paths
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88488
Approved by: https://github.com/mikekgfb
## Issues
Fixes https://github.com/pytorch/pytorch/issues/81129#issuecomment-1179435674
## Description
Passing a 2D attention mask `src_mask` into the fast path of `TransformerEncoderLayer` on CPU was causing an error, so it was disabled in https://github.com/pytorch/pytorch/pull/81277. This PR rolls back that restriction, enabling `src_mask` on the fast path:
- Either attention mask `src_mask` of shape `(L, L)` or padding mask `src_key_padding_mask` of shape `(B, L)` is now allowed on the CPU fast path. If softmax is applied along the last dimension (as in multi-head attention), these masks are processed without expanding them to 4D. Instead, when iterating through the input, `Softmax.cpp::host_softmax` converts the index to match the mask dimensions, depending on the type.
- If softmax is applied along a dimension other than the last, `Softmax.cpp::masked_softmax_cpu` expands the masks to 4D, converting them to `mask_type=2`. Theoretically one could also add special optimized cases for `dim=0, 1, 2` and process them without mask expansion, but I don't know how often that is used (a broadcast-equivalent sketch of the two mask shapes follows this list).
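The native kernel remaps indices rather than expanding the masks, but in Python the two mask shapes are equivalent to the following broadcasts (a sketch with hypothetical sizes, not the C++ code):
```python
import torch

B, H, L = 2, 4, 5  # hypothetical sizes
scores = torch.randn(B, H, L, L)

# (L, L) attention mask: the same mask applies to every batch element and head.
attn_mask = torch.triu(torch.ones(L, L, dtype=torch.bool), diagonal=1)
out_attn = scores.masked_fill(attn_mask.view(1, 1, L, L), float("-inf")).softmax(-1)

# (B, L) padding mask: masks keys per batch element, shared across heads and queries.
padding_mask = torch.zeros(B, L, dtype=torch.bool)
padding_mask[:, -1] = True
out_pad = scores.masked_fill(padding_mask.view(B, 1, 1, L), float("-inf")).softmax(-1)
```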
## Tests:
- `test_transformerencoderlayer_fast_path` is extended to cover both attention mask and padding mask
- `test_masked_softmax_mask_types_0_1` is added to ensure results from CPU softmax with attention and padding masks match the explicit slow calculation
- `test_masked_softmax_devices_parity` is added to ensure results from masked softmax on CPU and CUDA match
## Note
I had to replace `float` with `torch.get_default_dtype()` in a couple of tests for the following reason:
- `test_nn.py` [sets the default type to `torch.double`](https://github.com/pytorch/pytorch/blob/master/test/test_nn.py#L24-L26)
- If I execute `test_nn.py` and `test_transformers.py` in one `pytest` run, this default still holds for transformer tests
- Some tests in `test_transformers.py` that previously followed the slow path now take the fast path, and the hard-coded `float` started clashing with the default `double`
Let me know if there is a better way around it - or maybe I'm not supposed to run tests with `pytest` like this
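For reference, a minimal sketch of the replacement (hypothetical shapes):
```python
import torch

# If an earlier test file has called torch.set_default_dtype(torch.double),
# building inputs from torch.get_default_dtype() keeps dtype expectations
# consistent instead of hard-coding float32.
dtype = torch.get_default_dtype()
x = torch.randn(2, 16, 64, dtype=dtype)
```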
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87377
Approved by: https://github.com/mikekgfb, https://github.com/weiwangmeta, https://github.com/malfet
# Summary
Use the private _scaled_dot_product_attention to support _native_multiheaded_attention. _SDP provides access to fused kernels when certain conditions are met, enabling a speed-up for MHA.
cc @cpuhrsch @jbschlosser @bhosmer @mikaylagawarecki
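A sketch of how the fused path is typically reached from user code, assuming the nn.MultiheadAttention fastpath preconditions (inference mode, batch_first, no incompatible masks) and hypothetical sizes:
```python
import torch

mha = torch.nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True).eval()
x = torch.randn(2, 16, 64)

# In inference mode MHA can route through the native fused attention
# (and thus the private SDP) when its preconditions are met.
with torch.no_grad():
    out, _ = mha(x, x, x, need_weights=False)
```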
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87312
Approved by: https://github.com/cpuhrsch
# Summary
Adds a torch.backends.cuda flag and updates the context manager to pick between the three implementations of scaled_dot_product_attention.
cc @cpuhrsch @jbschlosser @bhosmer @mikaylagawarecki
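A minimal sketch of how the flags and context manager are used, assuming a CUDA device, the names that later shipped in `torch.backends.cuda` (this PR's exact surface may differ), and the public `F.scaled_dot_product_attention` entry point:
```python
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Global toggles for the three implementations.
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(True)

# Context manager restricting dispatch to a single implementation.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```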
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87946
Approved by: https://github.com/cpuhrsch
# Summary
- This code creates the runtime dispatch system for choosing a performant fused SDP kernel. The only fused kernel currently available is flash_attention. It also creates Python flags and a context manager that can be used to turn dispatch behavior on and off.
- This also adds support for flash_attention with dense tensors.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85984
Approved by: https://github.com/cpuhrsch
# Summary
This exposes the _scaled_dot_product_attention function to Python in the nn namespace. It is still underscored because the API for args and kwargs is still in flux for the next few weeks; it will eventually land as a prototype feature.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85044
Approved by: https://github.com/cpuhrsch
Summary: Checks that the fastpath is taken, and which type (sparsity fastpath or normal), for a mask that is aligned and one that is not.
Test Plan: buck test caffe2/test:test_transformers
Differential Revision: D38259928
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82999
Approved by: https://github.com/jbschlosser
Adds an initial private API version of the SDP interface.
Signature:
```
_scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None,
float dropout_p=0.0, bool need_attn_weights=True, bool is_causal=False) -> (Tensor, Tensor)
```
Returns a tuple of `(output, attn_weights)`.
Note the following:
* `need_attn_weights`: flag indicating that attention weights should be computed. This is useful to toggle off for flash attention as it does not materialize the weights by default, making it more expensive to return them.
* Boolean attention mask support only; `True` values within `attn_mask` indicate that the element should take part in attention (notably, this is the reverse of MHA, which uses `True` to mask *out* values). The mask is optional (see the usage sketch after these notes).
* `is_causal`: Temporary flag indicating whether to use causal attention weighting. If this is set to `True`, it takes precedence over any value passed in for `attn_mask`. Longer term, the `is_causal` flag can be subsumed into the `attn_mask` arg via tensor subclassing (see e.g. [CausalTensor](https://github.com/facebookresearch/xformers/blob/sparse_cleanup/xformers/sparse/causal_tensor.py) in xFormers).
* Testing is currently done via reference with the existing Python impl of `F._scaled_dot_product_attention`.
* This PR does not yet drop-in the new SDP anywhere. A future PR can hook it up in BT or MHA.
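A usage sketch against the signature above (hypothetical shapes; the private API has since changed, so this reflects the PR as written):
```python
import torch
import torch.nn.functional as F

B, H, L, E = 2, 4, 8, 16  # hypothetical sizes
q = torch.randn(B, H, L, E)
k, v = torch.randn_like(q), torch.randn_like(q)

# Boolean mask where True means "attend to this position" (the reverse of
# MHA's masking convention), broadcast against the (L, L) score matrix.
attn_mask = torch.ones(L, L, dtype=torch.bool)

# Toggling need_attn_weights off avoids materializing the weights, which is
# the cheaper path for flash attention.
out, attn_weights = F._scaled_dot_product_attention(
    q, k, v, attn_mask=attn_mask, dropout_p=0.0,
    need_attn_weights=False, is_causal=False,
)
```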
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81956
Approved by: https://github.com/drisspg, https://github.com/erichan1