Notable changes:
1. Enable CudaGraph-related tests
2. Fix UT problems
3. EXPERIMENTAL Navi31 support. Users should enable Navi31 support with the environment variable `TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1`
Known problem:
1. `test/test_transformers.py` will produce massive failures and/or NaN outputs with `--use-pytest`
+ Update: Confirmed that skipping `class TestSDPAPrivateUse1Only` fixes the problem with `--use-pytest`
Note:
AOTriton 0.7b adds support for nested tensors + SDPA, but more work (and consequently a separate PR) is needed to enable it.
Fixes #133540
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134498
Approved by: https://github.com/pruthvistony, https://github.com/jeffdaily, https://github.com/malfet
Use oneDNN BRGEMM on packed data to get better performance on 5th-generation Xeon, where Intel® Advanced Matrix Extensions (AMX) gains fp16 support (amx-fp16).
Multiple models achieve acceleration; for instance, FP16 Stable Diffusion v2.1 improves by over 50%.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131879
Approved by: https://github.com/jgong5, https://github.com/peterbell10
ghstack dependencies: #131878
# UPDATE:
This is take 3 of https://github.com/pytorch/pytorch/pull/131863, which was landed via co-dev but did not apply correctly.
# Summary
Changes SDPA's stance on what to do for fully masked-out rows.
## Current Behavior
Several PyTorch users have expressed frustration over this issue:
- https://github.com/pytorch/pytorch/issues/41508
- https://github.com/pytorch/pytorch/issues/103749
- https://github.com/pytorch/pytorch/issues/103963
These are significant issues with extensive discussion but no satisfactory resolution. The PyTorch team's consensus, as stated here:
https://github.com/pytorch/pytorch/issues/24816#issuecomment-524415617
can be paraphrased as follows:
When passing in fully masked out rows, attention becomes ambiguous. We have two main options:
1. Uniformly attend to all values:
```python
scores[masked_out_rows] = 1 / len(row)
out[masked_out_rows] = 1 / len(row) * value
```
2. Decide that attention between no queries (masked) and no keys (masked) is meaningless:
```python
output[fully_masked_rows] = NaN
```
We went with option 2, partly because it was easier to implement, but also because people argued that users can slice the output to remove the NaNs:
``` Python
>fill_value = -float("inf")
>row0 = torch.randn(4)
>row1 = torch.tensor([fill_value for _ in range(4)])
>matrix = torch.stack([row0, row1]).requires_grad_(True)
>out = torch.softmax(matrix, 1)
>out = out[0]
>print(out)
tensor([0.5377, 0.2729, 0.0692, 0.1201])
```
Cool, problem solved. But what happens when you call backward?
```Python
>out.backward(torch.ones_like(out))
>print(matrix.grad)
tensor([[3.0957e-08, 1.4157e-08, 7.7802e-10, 1.3713e-08],
[ nan, nan, nan, nan]])
```
Those pesky NaNs are back!
## Why do we see NaNs today?
The core of the problem revolves around using the softmax function in SDPA:
```python
> row = torch.tensor([(-float("inf")) for _ in range(4)])
> torch.softmax(row, 0)
tensor([nan, nan, nan, nan])
```
## Quick Aside: Masking in Attention
Attention itself doesn't have a concept of masking. The `sdpa` function has an argument called `attn_mask`, which would be more accurately named `attn_bias`. This is because we don't actually "mask" entries when computing attention. Instead, due to implementation details ([performance](https://github.com/pytorch/pytorch/issues/25110#issuecomment-524519087)), we add a value to the masked-out query/key pairs.
We use a large negative number (typically -inf) to decrease the attention weight, since softmax assigns more weight to larger values.
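A minimal sketch of this additive-bias view of masking (shapes and values are illustrative, not the internal implementation):
```python
import torch

# "Masking" is really adding a bias: masked-out query/key pairs get a large
# negative value so softmax gives them ~zero weight.
scores = torch.randn(2, 4, 4)                       # (batch, q_len, k_len)
keep = torch.ones(4, 4).tril().bool()               # boolean "mask" of kept pairs
attn_bias = torch.zeros(4, 4).masked_fill(~keep, float("-inf"))
weights = torch.softmax(scores + attn_bias, dim=-1)  # masked entries -> ~0 weight
```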
## Alternative Approaches
If we use a very large negative number instead of -inf:
```python
> row = torch.tensor([(-1e6) for _ in range(4)])
> torch.softmax(row, 0)
tensor([0.2500, 0.2500, 0.2500, 0.2500])
```
However, if users always remembered to "slice" out their outputs, i.e.:
```Python
>fill_value = -1e6
>...
>out.backward(torch.ones_like(out))
>print(matrix.grad)
tensor([[-0.0563, -0.0564, 0.1613, -0.0486],
[ 0.0000, 0.0000, 0.0000, 0.0000]])
```
This would bring us back into a better state.
## A Third Option
We don't necessarily need to alter the behavior of softmax for -inf or very large negative numbers. The fundamental goal is to exclude certain query/key pairs from attention, regardless of the underlying implementation.
This PR implements the new semantics for masking with attention in fully masked-out rows:
```python
out[masked_out_rows] = 0
```
**Important Note**: This idea isn't entirely new. The [MaskedTensor](https://pytorch.org/tutorials/prototype/maskedtensor_overview#safe-softmax) prototype, a tensor subclass, was designed to handle such cases. However, it remains a prototype feature and hasn't gained widespread adoption.
## Details
This PR stack does three things:
1. Adds a PRIVATE `_safe_softmax` op
2. Updates the semantics of the flash_cpu fused kernel
3. Updates the semantics of the efficient_cuda fused kernel
`_safe_softmax` is not supposed to be used generically and is only meant to be used within the context of SDPA. Because of this, instead of decomposing softmax and checking for -inf rows, we "cheat" and use `nan_to_num`.
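A minimal sketch of what the "cheat" looks like on a fully masked-out row:
```python
import torch

row = torch.full((4,), float("-inf"))   # a fully masked-out row of scores
out = torch.softmax(row, dim=-1)        # tensor([nan, nan, nan, nan])
safe = torch.nan_to_num(out, nan=0.0)   # tensor([0., 0., 0., 0.])
```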
Why do I think this is okay? (Please find a counterpoint if available.)
There are multiple ways NaNs can emerge. For the fully masked-out rows case, `nan_to_num` works. But what if there were other NaNs; wouldn't this silently remove them?
The only case where this can happen is if the input itself had a NaN or an Inf.
For example:
```Python
a = torch.ones([4], requires_grad=False, dtype=torch.float16)
a[1] = torch.finfo(torch.float16).max
print(a.softmax(-1))
```
Will return
`tensor([0., 1., 0., 0.], dtype=torch.float16)`
Where
```Python
a = torch.ones([4], requires_grad=False, dtype=torch.float16)
a[1] = float("inf")
a.softmax(-1)
```
returns:
`tensor([nan, nan, nan, nan], dtype=torch.float16)`
If we don't want to even allow for the possibility of "inf" or "NaN" attention scores being converted to 0, we can implement it something like this:
```Python
# Standard softmax, computed explicitly so fully -inf rows can be detected
max = torch.max(a, dim=-1, keepdim=True)
exp = torch.exp(a - max.values)
denom = torch.sum(exp, dim=-1, keepdim=True)
softmax = exp / denom
# Rows whose max is -inf were fully masked: zero them out instead of NaN
softmax = torch.where(max.values == float('-inf'), 0.0, softmax)
```
However, we would be paying for this in math performance.
## Why Now
I think one point that has substantially changed where PyTorch should land on this argument is that we now have fused implementations for SDPA, and these fused implementations allow us to easily and performantly support the new semantics.
Differential Revision: [D61418679](https://our.internmc.facebook.com/intern/diff/D61418679)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133882
Approved by: https://github.com/soulitzer
- Reduced number of skipped test cases
- Merged redundant test cases
**Benchmark:**
| | Original | New |
| ----- | ----- | ----- |
| Run time | 60 mins | 35 mins |
| Total tests | 75k | 18k |
| Skipped tests | 20k | 4k |
_These are approximate numbers from running test_transformers.py on a single H100, and can change based on the device._
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133049
Approved by: https://github.com/drisspg
Summary:
feikou observed large numerical gaps when using the math backend on AMD and NV GPUs. This is mainly because we are not using higher-precision FP32 for the intermediate accumulated/materialized parts.
Since the math backend is expected to be slower anyway, and we expect the math backend to generate the correct reference result, it is worth upcasting FP16/BF16 inputs to FP32, doing the FP32/TF32 computations, and then downcasting the FP32 output back to FP16/BF16.
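A rough sketch of the upcast/compute/downcast pattern being described (illustrative only, not the actual ATen math backend):
```python
import torch

def math_sdpa_reference(q, k, v, scale=None):
    # Upcast low-precision inputs to FP32, compute in FP32/TF32,
    # then downcast the result back to the original dtype.
    orig_dtype = q.dtype
    q, k, v = (t.to(torch.float32) for t in (q, k, v))
    scale = scale if scale is not None else q.size(-1) ** -0.5
    attn = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1)
    return (attn @ v).to(orig_dtype)
```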
Differential Revision: D58710805
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128922
Approved by: https://github.com/xw285cornell, https://github.com/drisspg
# Motivation
Structured codegen makes it easier to decouple tensor meta setting from kernel implementation. At present, XPU operators have to handle tensor metas in a hand-written way.
We plan to leverage the codegen system to auto-generate structured operators. This PR adds `DispatchStub` support for Intel GPUs. Based on that, XPU operators will be able to register kernel functors to operator stubs.
This is a prerequisite for PR #130082, where we will modify the codegen system to generate the XPU source files and headers.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130019
Approved by: https://github.com/EikanWang, https://github.com/gujinghui, https://github.com/albanD
## Motivation
This refactor aligns our testing methodology with the Flash Attention upstream repository while addressing several key issues:
1. **Standardized comparison**: We now compare fused kernels against float64 references, using the maximum of a calculated tolerance (based on the same-precision math implementation) or the standard float32 `atol` (a sketch follows below).
2. **Reduced redundancy**: Utilizing the same tensors for both same-precision math and fused kernel runs eliminates duplication.
3. **Improved maintainability**: The new approach simplifies tolerance adjustments across all affected tests.
4. **Consistency**: Standardizing tensor comparisons ensures a more uniform and reliable testing suite.
These changes collectively simplify our testing code, improve its maintainability, and provide a more robust framework for validating our attention mechanisms.
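A rough sketch of the tolerance scheme described in point 1; the helper name, fudge factor, and float32 `atol` value here are assumptions, not the test suite's actual constants:
```python
import torch

def assert_close_to_fp64_ref(out_fused, out_math_lowp, out_ref_fp64,
                             fudge_factor=4.0, float32_atol=1e-5):
    # Tolerance is driven by how far the same-precision math implementation
    # drifts from the float64 reference, floored at a standard float32 atol.
    deviation = (out_math_lowp.to(torch.float64) - out_ref_fp64).abs().max().item()
    atol = max(fudge_factor * deviation, float32_atol)
    torch.testing.assert_close(out_fused.to(torch.float64), out_ref_fp64,
                               atol=atol, rtol=0.0)
```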
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131743
Approved by: https://github.com/jainapurva, https://github.com/jbschlosser
Summary: The test is from D59181111, but I couldn't figure out a way to make it pass in FBCODE because loading a PyTorch C++ extension requires Ninja, which is not going to work with BUCK.
Test Plan: `buck2 test 'fbcode//mode/dev-nosan' fbcode//caffe2/test:transformers`
Differential Revision: D59304327
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129997
Approved by: https://github.com/drisspg
Looks like one of the first failures seen is `test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` when `test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` passes.
What seems interesting here is that the `torch.compile` version fails while the eager version passes. Not sure what the difference would be here...
Nevertheless, is there a recommended mechanism to skip cuDNN SDPA as a backend for this test? CC @drisspg
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125343
Approved by: https://github.com/Skylion007
This patch implements `with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):` on ROCm by reusing AOTriton's accelerated SDPA implementation (a usage sketch follows the list of limitations below).
Known limitations:
- Only supports MI200/MI300X GPUs
- Does not support varlen
- Does not support `CausalVariant`
- Optional arguments `causal_diagonal` and `seqlen_k` in `_efficient_attention_forward/backward` must be null
- Does not work well with inductor's SDPA rewriter. The rewriter has been updated to only use math and flash attention on ROCM.
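A minimal usage sketch of the context manager on a ROCm/HIP device (exposed as `cuda`); shapes and dtype are illustrative:
```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q, k, v = (torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
           for _ in range(3))

# Restrict SDPA to the memory-efficient backend (served by AOTriton on ROCm).
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)
```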
This PR also uses a different approach: installing the AOTriton binary instead of building it from source in the base docker image. More details on the motivation: https://github.com/pytorch/pytorch/pull/124885#issuecomment-2153229129
`PYTORCH_TEST_WITH_ROCM=1 PYTORCH_TESTING_DEVICE_ONLY_FOR="cuda" python test/test_transformers.py` yields "55028 passed, 20784 skipped" with this change. The [previous result](https://hud.pytorch.org/pr/127528) of `test_transformers.py` was 0 errors, 0 failures, 55229 skipped out of 75517 tests in total (the XML report does not contain the total number of passed tests).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124885
Approved by: https://github.com/malfet
Automatic fixes that replace certain list comprehensions with generator expressions where appropriate so that they are immediately consumed. This is preview functionality in ruff for rule C419, and it was applied automatically.
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123960
Approved by: https://github.com/malfet
This PR fixes two major issues that were discovered after the initial merge of PR #121561:
1. The Flash Attention support added by that PR has severe performance regressions on regular shapes (power-of-two head dimensions and sequence lengths) compared with PR #115981. Its performance is worse than the math backend and only has numerical stability advantages. This PR fixes that problem.
2. There is a flaw in the memory storage handling in PR #121561: it does not copy the gradients back to the designated output tensor. This PR removes the deprecated `TensorStorageSanitizer` class, which is unnecessary due to the more flexible backward kernel shipped by PR #121561.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122857
Approved by: https://github.com/jeffdaily, https://github.com/drisspg
This patch addresses the major limitations in our previous [PR #115981](https://github.com/pytorch/pytorch/pull/115981) through the new dedicated repository [AOTriton](https://github.com/ROCm/aotriton):
- [x] Only supports MI200 series GPUs (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
* MI300X is now supported. More architectures will be added once Triton supports them.
- [x] Only supports power-of-two sequence lengths.
* Now it supports arbitrary sequence lengths.
- [ ] No support for varlen APIs.
* The varlen API will be supported in a future release of AOTriton.
- [x] Only supports head dimensions 16, 32, 64, 128.
* Now it supports arbitrary head dimensions <= 256.
- [x] Performance is still being optimized.
* The kernel is selected according to autotune information from Triton.
Other improvements from AOTriton include:
* More flexible Tensor storage layouts
* A more flexible API
This is a more extensive fix to #112997
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121561
Approved by: https://github.com/huydhn
This patch addresses the major limitations in our previous [PR #115981](https://github.com/pytorch/pytorch/pull/115981) through the new dedicated repository [AOTriton](https://github.com/ROCm/aotriton):
- [x] Only supports MI200 series GPUs (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
* MI300X is now supported. More architectures will be added once Triton supports them.
- [x] Only supports power-of-two sequence lengths.
* Now it supports arbitrary sequence lengths.
- [ ] No support for varlen APIs.
* The varlen API will be supported in the next release of AOTriton.
- [x] Only supports head dimensions 16, 32, 64, 128.
* Now it supports arbitrary head dimensions <= 256.
- [x] Performance is still being optimized.
* The kernel is selected according to autotune information from Triton.
Other improvements from AOTriton include:
* More flexible Tensor storage layouts
* A more flexible API
This is a more extensive fix to #112997
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121561
Approved by: https://github.com/malfet, https://github.com/atalman
This reverts commit a5a63db3bf.
Fixes #ISSUE_NUMBER
Reverts #118368
Got reverted internally, but the branch got deleted so the automation didn't work.
Mildly edited stack trace
```
...
return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
File "torch/_dynamo/eval_frame.py", line 453, in _fn
return fn(*args, **kwargs)
File "torch/_dynamo/external_utils.py", line 25, in inner
return fn(*args, **kwargs)
File "torch/fx/experimental/proxy_tensor.py", line 635, in dispatch_trace
graph = tracer.trace(root, concrete_args)
File "torch/fx/experimental/proxy_tensor.py", line 995, in trace
res = super().trace(root, concrete_args)
File "torch/_dynamo/eval_frame.py", line 453, in _fn
return fn(*args, **kwargs)
File "torch/_dynamo/external_utils.py", line 25, in inner
return fn(*args, **kwargs)
File "torch/fx/_symbolic_trace.py", line 793, in trace
(self.create_arg(fn(*args)),),
File "torch/fx/experimental/proxy_tensor.py", line 665, in wrapped
out = f(*tensors)
File "<string>", line 1, in <lambda>
File "torch/_functorch/_aot_autograd/traced_function_transforms.py", line 357, in _functionalized_f_helper
f_outs = fn(*f_args)
File "torch/_functorch/_aot_autograd/traced_function_transforms.py", line 68, in inner_fn
outs = fn(*args)
File "torch/_functorch/_aot_autograd/utils.py", line 161, in flat_fn
tree_out = fn(*args, **kwargs)
File "torch/_functorch/_aot_autograd/traced_function_transforms.py", line 618, in functional_call
out = PropagateUnbackedSymInts(mod).run(
File "torch/fx/interpreter.py", line 145, in run
self.env[node] = self.run_node(node)
File "torch/_functorch/_aot_autograd/traced_function_transforms.py", line 593, in run_node
result = super().run_node(n)
File "torch/fx/interpreter.py", line 202, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
File "torch/fx/interpreter.py", line 274, in call_function
return target(*args, **kwargs)
File "torch/_ops.py", line 571, in __call__
return self_._op(*args, **kwargs)
File "torch/_subclasses/functional_tensor.py", line 380, in __torch_dispatch__
outs_unwrapped = func._op_dk(
File "torch/utils/_stats.py", line 20, in wrapper
return fn(*args, **kwargs)
File "torch/fx/experimental/proxy_tensor.py", line 744, in __torch_dispatch__
return self.inner_torch_dispatch(func, types, args, kwargs)
File "torch/fx/experimental/proxy_tensor.py", line 779, in inner_torch_dispatch
return proxy_call(self, func, self.pre_dispatch, args, kwargs)
File "torch/fx/experimental/proxy_tensor.py", line 423, in proxy_call
r = maybe_handle_decomp(proxy_mode, func, args, kwargs)
File "torch/fx/experimental/proxy_tensor.py", line 1225, in maybe_handle_decomp
return CURRENT_DECOMPOSITION_TABLE[op](*args, **kwargs)
File "torch/_decomp/decompositions.py", line 4322, in scaled_dot_product_flash_attention_for_cpu
torch._check(
File "torch/__init__.py", line 1133, in _check
_check_with(RuntimeError, cond, message)
File "torch/__init__.py", line 1116, in _check_with
raise error_type(message_evaluated)
RuntimeError: query must be FP32, FP64, BF16 but got torch.float16
While executing %_scaled_dot_product_flash_attention_for_cpu : [num_users=1] = call_function[target=torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default](args = (%l_q_, %l_k_, %l_v_), kwargs = {attn_mask: %l_attn_mask_})
Original traceback:
File "executorch/backends/xnnpack/partition/graphs/sdpa.py", line 34, in forward
return torch.nn.functional.scaled_dot_product_attention(
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119204
Approved by: https://github.com/kit1980
# Summary
Simplification of Backend Selection
This PR deprecates the `torch.backends.cuda.sdp_kernel` context manager and replaces it with a new context manager, `torch.nn.attention.sdpa_kernel`, which also changes the API for selecting backends.
With `sdp_kernel`, one would specify the backend choice by negating the kernels one did not want to run. The purpose of this context manager was only to be a debugging tool: "turn off the math backend" and see if you can run one of the fused implementations.
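A minimal sketch of the API difference (tensor shapes are illustrative; the old manager disables backends by negation, the new one names the backend to run):
```python
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(2, 8, 128, 64)

# Old: turn everything else off to force flash attention.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False,
                                    enable_mem_efficient=False):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)

# New: name the backend you want to run.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
```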
Problems:
- This pattern makes sense if the majority of users don't care to know anything about the backends that can be run. However, if users are seeking to use this context manager, then they are explicitly trying to run a specific backend.
- This is not scalable. We are working on adding the cuDNN backend, and this API means that more implementations will need to be turned off if a user wants to explicitly run a given backend.
- Discoverability of the current context manager. It is somewhat unintuitive that this backend manager lives in `torch.backends.cuda` when it now also controls the CPU fused kernel behavior. I think centralizing it in the attention namespace will be helpful.
Other concerns:
- Typically, backends (kernels) for operators are entirely hidden from users as implementation details of the framework. We have already exposed this to users, albeit not by default and with beta warnings. Does making backend choices even more explicit lead to problems when we potentially want to remove existing backends (perhaps input shapes will get covered by newer backends)?
A nice side effect: now that we aren't using the `BACKEND_MAP` in test_transformers, many, many dynamo failures are passing for the CPU tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114689
Approved by: https://github.com/cpuhrsch
## Motivation
The current code `value in [torch.backends.cudnn, torch.ops]` requires `value` to have an implementation of `__eq__`. If the value is a custom object that does not implement `__eq__`, dynamo will throw an error. For example, for ConvolutionOpContext, the custom 'torch._C.ScriptClass' object registered in IPEX, dynamo throws the following error:
**torch._dynamo.exc.InternalTorchDynamoError: '__eq__' is not implemented for __torch__.torch.classes.ipex_prepack.ConvolutionOpContext**
I think this is a common issue. To avoid it, the PR replaces the current code `value in [torch.backends.cudnn, torch.ops]` with `isinstance(value, (torch.backends.cudnn.CudnnModule, torch._ops._Ops))`.
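A minimal illustration (with a stand-in class, not the actual IPEX object) of why the membership test needs `__eq__` while the `isinstance` check does not:
```python
class OpContextStub:
    # Stand-in for a custom object whose __eq__ raises, like the
    # torch._C.ScriptClass case described above.
    def __eq__(self, other):
        raise NotImplementedError("'__eq__' is not implemented")

value = OpContextStub()

try:
    value in [object(), object()]   # membership test invokes __eq__
except NotImplementedError as e:
    print("membership test failed:", e)

print(isinstance(value, OpContextStub))  # isinstance never calls __eq__ -> True
```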
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116856
Approved by: https://github.com/jansel
Fixes https://github.com/pytorch/pytorch/issues/116385
Don't call `torch._transformer_encoder_layer_fwd` when `bias=False`.
`bias=False` was not something that `torch._transformer_encoder_layer_fwd` was meant to work with; it was my bad that this wasn't tested when I approved https://github.com/pytorch/pytorch/pull/101687.
`bias=False` was causing the `tensor_args` in [`TransformerEncoder`](a17de2d645/torch/nn/modules/transformer.py (L663-L677)) to contain `None`s and error on fastpath checks like `t.requires_grad for t in tensor_args`.
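A tiny sketch of the failure mode (`tensor_args` here is a hypothetical stand-in, not the actual module internals):
```python
import torch

tensor_args = (torch.randn(2, 2), None)  # bias=False leaves a None behind

try:
    any(t.requires_grad for t in tensor_args)
except AttributeError as e:
    print(e)  # 'NoneType' object has no attribute 'requires_grad'
```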
Alternative fixes would be to:
1) Pass `torch.zeros_like({*}.weight)` to the kernel when `bias=False` and filter `tensor_args` as appropriate
2) Fix `torch._transformer_encoder_layer_fwd` to take `Optional<Tensor>` for biases and fix the kernels as appropriate
Let me know if either of these approaches is preferable.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116760
Approved by: https://github.com/jbschlosser
Note about the updates:
This PR:
1. Skips more flash-attention-related UTs on MI200
2. Fixes additional ATen compilation errors after hipification
3. Fixes the author "root" of a specific commit
4. Includes the patch from Nikita in favor of block-level static initialization
CAVEAT: This revised PR has a commit that modifies the CI to force it to run on MI200 nodes. That specific commit must be reverted before merge.
Original PR (https://github.com/pytorch/pytorch/pull/114309) Note:
This pull request adds initial Flash Attention support for the AMD/ROCm platform. It adds a specialized Triton repository/branch as a compile-time dependency for the Flash Attention math library on AMD/ROCm. This Triton submodule is not used at runtime and will not be shipped in the final PyTorch package. We plan to release this specialized Triton as a separate project.
Known limitations:
- Only supports MI200 series GPUs (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
- Only supports power-of-two sequence lengths.
- No support for varlen APIs.
- Only supports head dimensions 16, 32, 64, 128.
- Performance is still being optimized.
Fixes #112997
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115981
Approved by: https://github.com/malfet
This pull request adds initial Flash Attention support for the AMD/ROCm platform. It adds a specialized Triton repository/branch as a compile-time dependency for the Flash Attention math library on AMD/ROCm. This Triton submodule is not used at runtime and will not be shipped in the final PyTorch package. We plan to release this specialized Triton as a separate project.
Known limitations:
- [ ] Only supports MI200 series GPUs (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
- [ ] Only supports power-of-two sequence lengths.
- [ ] No support for varlen APIs.
- [ ] Only supports head dimensions 16, 32, 64, 128.
- [ ] Performance is still being optimized.
Fixes https://github.com/pytorch/pytorch/issues/112997
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114309
Approved by: https://github.com/jeffdaily, https://github.com/malfet
---------
Co-authored-by: Joseph Groenenboom <joseph.groenenboom@amd.com>
# Summary
This PR introduces a new Tensor subclass that is designed to be used with torch.nn.functional.scaled_dot_product_attention. Currently we have a boolean `is_causal` flag that allows users to do causal masking without needing to actually create the "realized" attention bias and pass it into SDPA. We originally added this flag since there is native support in both fused kernels we support. This provides a big performance gain (the kernels only need to iterate over ~0.5x the sequence), and for very large sequence lengths it can also provide very large memory improvements.
The flag was introduced early on in the kernel development, and at the time it implicitly meant "upper_left" causal attention. This distinction only matters when the attention bias is not square. For a more detailed breakdown see: https://github.com/pytorch/pytorch/issues/108108. The kernels' default behavior has since changed, largely due to the rise of autoregressive text generation, and unfortunately following that change in PyTorch would be a BC break. In the long term it may actually be beneficial to change the default meaning of `is_causal` to represent lower_right causal masking.
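A small sketch of the two alignments for a non-square case (3 queries, 5 keys); `1` marks positions a query may attend to:
```python
import torch

q_len, k_len = 3, 5

# "upper_left": the causal diagonal starts at the top-left corner.
upper_left = torch.ones(q_len, k_len).tril().bool()

# "lower_right": the diagonal ends at the bottom-right corner, so every
# query can attend to the trailing keys (the autoregressive convention).
lower_right = torch.ones(q_len, k_len).tril(diagonal=k_len - q_len).bool()

print(upper_left.int())
print(lower_right.int())
```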
The larger theme, though, is laid out here: https://github.com/pytorch/pytorch/issues/110681. The thesis is that there is a lot of innovation in SDPA revolving around the attention_bias being used. This is the first of hopefully a few more attention biases that we would like to add. The next interesting one would be `sliding_window`, which is used by the popular Mistral model family.
Results from benchmarking: I improved the meff_attention perf, hence the slightly decreased max speedup.
```Shell
+---------+--------------------+------------+-----------+-----------+-----------+-----------+----------------+----------+
| Type | Speedup | batch_size | num_heads | q_seq_len | k_seq_len | embed_dim | dtype | head_dim |
+---------+--------------------+------------+-----------+-----------+-----------+-----------+----------------+----------+
| Average | 1.2388050062214226 | | | | | | | |
| Max | 1.831672915579016 | 128 | 32 | 1024 | 2048 | 2048 | torch.bfloat16 | 64 |
| Min | 0.9430534166730135 | 1 | 16 | 256 | 416 | 2048 | torch.bfloat16 | 128 |
+---------+--------------------+------------+-----------+-----------+-----------+-----------+----------------+----------+
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114823
Approved by: https://github.com/cpuhrsch
# Summary
Improved Fix for Attention Mask Alignment Issue (#112577)
This PR addresses Issue #112577 by refining the previously implemented fix, which was found to be incorrect and caused unneeded memory regressions. The update simplifies the approach to handling the alignment of the attention mask for memory-efficient attention.
## Changes
Alignment Check and Padding: Initially, the alignment of the attention mask is checked. If misalignment is detected, padding is applied, followed by slicing. During this process, a warning is raised to alert users.
Should this be warn_once?
We only call expand once, on the aligned mask.
Reference
https://github.com/facebookresearch/xformers/blob/main/xformers/ops/fmha/cutlass.py#L115
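A rough sketch of the pad-then-slice idea described above (the alignment constant is an assumption for illustration):
```python
import torch
import torch.nn.functional as F

ALIGNMENT = 8  # assumed required last-dim alignment, in elements
attn_bias = torch.randn(2, 8, 128, 129, dtype=torch.float16)

last = attn_bias.size(-1)
if last % ALIGNMENT != 0:
    pad = ALIGNMENT - last % ALIGNMENT
    # Pad the storage so the last dim is a multiple of ALIGNMENT, then slice
    # back to the logical size; the view keeps the original shape while the
    # underlying allocation is aligned.
    attn_bias = F.pad(attn_bias, (0, pad))[..., :last]
```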
@albanD, @mruberry, @jbschlosser, @walterddr, and @mikaylagawarecki.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114173
Approved by: https://github.com/danthe3rd
# Summary:
This pull request **removes** support for non-square sequence lengths in causal attention when using FlashAttention V2.
### Why are we doing this
// FlashAttention 2 updated the default mask meaning for causal in this PR:
// 9e5e8bc91e it is now aligned to lower_right which would be a BC break
// for non-square masks. We will not support non-square masks for causal w/ FAV2
For more context see:
https://github.com/pytorch/pytorch/issues/108108
### Followup
A large number of people will likely want to use FAV2 with lower_right causal attention for non-equal sequence lengths. See this RFC: https://github.com/pytorch/pytorch/issues/110681
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111007
Approved by: https://github.com/cpuhrsch
# Summary
<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at 318764f</samp>
This pull request implements the CUDA backend of the SDPA kernel for nested tensors, which enables efficient transformer models with variable-length sequences. It adds a new dispatch key, a backward function, a unit test, and some helper functions for the kernel. It modifies `test/test_transformers.py`, `aten/src/ATen/native/native_functions.yaml`, `aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctionsBackward.cpp`, and `aten/src/ATen/native/nested/cuda/NestedTensorTransformerUtils.h`.
<!--
copilot:poem
-->
### <samp>🤖 Generated by Copilot at ed4a773</samp>
> _Fused kernels of doom, unleash the flash attention_
> _Nested tensors on fire, reshape and pad with caution_
> _Backward pass of power, dispatch the CUDA key_
> _Test the gradients of hell, warn the user if they disagree_
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97485
Approved by: https://github.com/jbschlosser