Commit Graph

253 Commits

Author SHA1 Message Date
CaoE
a6fae2e811 Use BRGEMM for Half flash attention forward kernel (#131879)
Use oneDNN BRGEMM on packed data to get better performance on 5th-generation Xeon, where Intel® Advanced Matrix Extensions (AMX) adds fp16 support (amx-fp16).
Multiple models see a speedup; for instance, FP16 Stable Diffusion v2.1 improves by over 50%.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131879
Approved by: https://github.com/jgong5, https://github.com/peterbell10
ghstack dependencies: #131878
2024-09-08 12:32:23 +00:00
Valentine233
0dbc72887b [CPU][flash attention] make the stride of output align with input (#134656)
Fixes #133671

Currently, the output of CPU flash attention has a fixed layout, no matter what the input is. This PR makes the output strides align with the input q/k/v, matching the behavior of the math backend.
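A small sketch of the intended behavior (illustrative shapes; which backend runs, and therefore the exact output strides, depends on the build and inputs):

```python
import torch
import torch.nn.functional as F

# q/k/v stored in (batch, seq, heads, dim) memory order, viewed as (batch, heads, seq, dim).
q = torch.randn(2, 128, 8, 64).transpose(1, 2)
k = torch.randn(2, 128, 8, 64).transpose(1, 2)
v = torch.randn(2, 128, 8, 64).transpose(1, 2)

out = F.scaled_dot_product_attention(q, k, v)
print(q.stride())    # (65536, 64, 512, 1)
print(out.stride())  # with this PR, expected to follow q's layout instead of a fixed one
```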

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134656
Approved by: https://github.com/jgong5, https://github.com/drisspg
2024-08-29 16:04:25 +00:00
Xinya Zhang
46ecc673ae [ROCm] Prevent accidental enablement of efficient attention. (#133331)
Currently, Efficient Attention and Flash Attention share the same set of GPU kernels on ROCm and therefore have the same limitations on head sizes.

Fixes https://github.com/pytorch/pytorch/issues/132004

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133331
Approved by: https://github.com/malfet, https://github.com/jithunnair-amd
2024-08-27 00:03:45 +00:00
eqy
e93ca12c88 [CUDNN][SDPA] Fix unsupported trivial stride-1 transpose case (#134031)
Fixes #134001
Fixes the incorrect assumption that two same-shape tensors that are both contiguous must have the same strides.
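A small illustration of how that assumption can break (illustrative shapes, not the ones from the issue): two same-shape contiguous tensors may still report different strides when a size-1 dimension is involved.

```python
import torch

a = torch.randn(1, 4, 8)                     # strides (32, 8, 1)
b = torch.randn(4, 1, 8).transpose(0, 1)     # same shape (1, 4, 8), strides (8, 8, 1)

print(a.shape == b.shape)                    # True
print(a.is_contiguous(), b.is_contiguous())  # True True
print(a.stride(), b.stride())                # (32, 8, 1) (8, 8, 1)
```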

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134031
Approved by: https://github.com/drisspg, https://github.com/Skylion007

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
2024-08-25 14:31:30 +00:00
drisspg
fb26b84390 Update fused kernels and call _safe_softmax from SDPA (#133882)
# UPDATE:
This is take 3 of https://github.com/pytorch/pytorch/pull/131863, which was landed via co-dev but did not apply correctly.

# Summary
Changes the stance of SDPA on what to do for fully masked out rows

## Current Behavior
Several PyTorch users have expressed frustration over this issue:
- https://github.com/pytorch/pytorch/issues/41508
- https://github.com/pytorch/pytorch/issues/103749
- https://github.com/pytorch/pytorch/issues/103963

These are significant issues with extensive discussion but no satisfactory resolution. The PyTorch team's consensus, as stated here:
https://github.com/pytorch/pytorch/issues/24816#issuecomment-524415617

can be paraphrased as follows:

When passing in fully masked out rows, attention becomes ambiguous. We have two main options:

1. Uniformly attend to all values:
   ```python
   scores[masked_out_rows] = 1 / len(row)
   out[masked_out_rows] = 1 / len(row) * value
   ```

2. Decide that attention between no queries (masked) and no keys (masked) is meaningless:
   ```python
   output[fully_masked_rows] = NaN
   ```

We went with option 2, partly because it was easier to implement, but also because people argued that users can slice the output to remove the NaNs:
``` Python
>fill_value = -float("inf")
>row0 = torch.randn(4)
>row1 = torch.tensor([fill_value for _ in range(4)])
>matrix = torch.stack([row0, row1]).requires_grad_(True)
>out = torch.softmax(matrix, 1)
>out = out[0]
>print(out)
tensor([0.5377, 0.2729, 0.0692, 0.1201])
```
Cool, problem solved. But what happens when you call backward?
```Python
>out.backward(torch.ones_like(out))
>print(matrix.grad)
tensor([[3.0957e-08, 1.4157e-08, 7.7802e-10, 1.3713e-08],
        [       nan,        nan,        nan,        nan]])
```
Those pesky NaNs are back!

## Why do we see NaNs today?

The core of the problem revolves around the use of the softmax function in sdpa:

```python
> row = torch.tensor([(-float("inf")) for _ in range(4)])
> torch.softmax(row, 0)
tensor([nan, nan, nan, nan])
```

## Quick Aside: Masking in Attention

Attention itself doesn't have a concept of masking. The `sdpa` function has an argument called `attn_mask`, which would be more accurately named `attn_bias`. This is because we don't actually "mask" entries when computing attention. Instead, due to implementation details([performance](https://github.com/pytorch/pytorch/issues/25110#issuecomment-524519087)), we add a value to the masked-out query/key pairs.

We use a large negative number (typically -inf) to decrease the attention weight, as softmax assigns more weight to larger values.
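A tiny illustration of this additive-bias view (shapes are illustrative, not taken from the PR):

```python
import torch

# (batch, heads, seq, head_dim)
q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
scores = q @ k.transpose(-2, -1) / 8 ** 0.5   # (1, 2, 4, 4)

# "Masking" is just adding a bias: -inf entries push the softmax weight to 0.
attn_bias = torch.zeros(4, 4)
attn_bias[:, 2:] = float("-inf")              # hide keys 2 and 3 from every query
weights = torch.softmax(scores + attn_bias, dim=-1)
print(weights[0, 0, 0])                       # last two weights are 0
```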

## Alternative Approaches

If we use a very large negative number instead of -inf:

```python
> row = torch.tensor([(-1e6) for _ in range(4)])
> torch.softmax(row, 0)
tensor([0.2500, 0.2500, 0.2500, 0.2500])
```
However, if users always remembered to "slice" out their outputs, i.e.:
```Python
>fill_value = -1e6
>...
>out.backward(torch.ones_like(out))
>print(matrix.grad)
tensor([[-0.0563, -0.0564,  0.1613, -0.0486],
        [ 0.0000,  0.0000,  0.0000,  0.0000]])
```
This would bring us back into a better state.

## A Third Option

We don't necessarily need to alter the behavior of softmax for -inf or very large negative numbers. The fundamental goal is to exclude certain query/key pairs from attention, regardless of the underlying implementation.

This PR implements the new semantics for masking with attention in fully masked-out rows:
```python
out[masked_out_rows] = 0
```

**Important Note**: This idea isn't entirely new. The [MaskedTensor](https://pytorch.org/tutorials/prototype/maskedtensor_overview#safe-softmax) prototype, a tensor subclass, was designed to handle such cases. However, it remains a prototype feature and hasn't gained widespread adoption.

## Details
This PR stack does 3 things:
1. Adds a PRIVATE _safe_softmax op
2. Updates semantic for flash_cpu fused kernel
3. Updates semantic for efficient_cuda fused kernel

_safe_softmax is not meant to be used generically; it is only meant to be used within the context of SDPA. Because of this, instead of decomposing softmax and checking for -inf rows, we "cheat" and use nan_to_num.
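A minimal sketch of that idea (not the actual op, and it assumes NaNs only come from fully masked rows):

```python
import torch

def safe_softmax_sketch(scores: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Rows that are entirely -inf produce NaN under a regular softmax;
    # nan_to_num maps those NaNs to 0, i.e. out[masked_out_rows] = 0.
    return torch.softmax(scores, dim).nan_to_num(0.0)

row = torch.full((4,), float("-inf"))
print(torch.softmax(row, 0))        # tensor([nan, nan, nan, nan])
print(safe_softmax_sketch(row, 0))  # tensor([0., 0., 0., 0.])
```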

Why do I think this is okay? (Please find a counterpoint if available.)
There are multiple ways NaNs can emerge. For the fully masked-out rows case, nan_to_num works. But if there were other NaNs, wouldn't this silently remove them?

The only case where this can happen is if the input itself contains a NaN or an Inf.
For example:
```Python
a = torch.ones([4], requires_grad=False, dtype=torch.float16)
a[1] = torch.finfo(torch.float16).max
print(a.softmax(-1))
```
Will return
`tensor([0., 1., 0., 0.], dtype=torch.float16)`

Where
```Python
a = torch.ones([4], requires_grad=False, dtype=torch.float16)
a[1] = float("inf")
a.softmax(-1)
```
returns:
`tensor([nan, nan, nan, nan], dtype=torch.float16)`

If we don't want to allow even the possibility of "inf" or "NaN" attention scores being converted to 0, then we can implement it with something like this:

```Python
# Numerically stable softmax that zeroes out fully masked rows (rows whose max is -inf)
# instead of letting them become NaN.
max_vals = torch.max(a, dim=-1, keepdim=True).values
exp = torch.exp(a - max_vals)
denom = torch.sum(exp, dim=-1, keepdim=True)
softmax = exp / denom
softmax = torch.where(max_vals == float('-inf'), 0.0, softmax)
```
However, we would be paying for this in math performance.

## Why Now
One point that has substantially changed where PyTorch should stand on this question is that we now have fused implementations for SDPA, and these fused implementations allow us to support this new semantics easily and performantly.

Differential Revision: [D61418679](https://our.internmc.facebook.com/intern/diff/D61418679)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133882
Approved by: https://github.com/soulitzer
2024-08-19 18:53:11 +00:00
eqy
c0c82a5f6a [CUDA][SDPA] Bump tolerances for test_mem_efficient_attention_attn_mask_vs (#133738)
Same thing as #133051 but for efficient attention

CC @drisspg @nWEIdia

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133738
Approved by: https://github.com/drisspg, https://github.com/nWEIdia, https://github.com/Skylion007
2024-08-18 19:14:29 +00:00
PyTorch MergeBot
cfec69e2a1 Revert "Update fused kernels and call _safe_softmax from SDPA (#131863)"
This reverts commit caba37e99b.

Reverted https://github.com/pytorch/pytorch/pull/131863 on behalf of https://github.com/izaitsevfb due to breaks executorch test executorch/backends/apple/coreml:test - test_vit_skip_conv (executorch.backends.apple.coreml.test.test_coreml_partitioner.TestCoreMLPartitioner) ([comment](https://github.com/pytorch/pytorch/pull/131863#issuecomment-2291855634))
2024-08-15 17:55:07 +00:00
drisspg
caba37e99b Update fused kernels and call _safe_softmax from SDPA (#131863)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131863
Approved by: https://github.com/jbschlosser, https://github.com/Chillee
2024-08-13 23:37:50 +00:00
PyTorch MergeBot
4cca18d5b6 Revert "Update fused kernels and call _safe_softmax from SDPA (#131863)"
This reverts commit e61def65d5.

Reverted https://github.com/pytorch/pytorch/pull/131863 on behalf of https://github.com/albanD due to Broke forward AD tests in main ([comment](https://github.com/pytorch/pytorch/pull/131863#issuecomment-2286432628))
2024-08-13 14:44:08 +00:00
drisspg
e61def65d5 Update fused kernels and call _safe_softmax from SDPA (#131863)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131863
Approved by: https://github.com/jbschlosser
2024-08-13 00:51:55 +00:00
Apurva Jain
79ca596dc6 Optimize test_transformers.py (#133049)
- Reduced number of skipped test cases
- Merged redundant test cases

**Benchmark:**

| | Original | New |
| ----- | ----- | ----- |
| Run time | 60 mins | 35 mins |
| Total tests | 75k | 18k |
| Skipped tests | 20k | 4k |

_These are approximate numbers from running test_transformers.py on a single H100, and can change based on the device._

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133049
Approved by: https://github.com/drisspg
2024-08-11 05:20:58 +00:00
eqy
c89936eaa0 [CUDA][SDPA] Bump grad_key fudge factor in test_flash_attention_vs_math_ref_grads (#133051)
Abates failures like `ValueError: grad_key Test error 1.592235639691353e-05 is greater than threshold 1.5236437320709229e-05!` that we've seen when bringing up newer versions of CUDA

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133051
Approved by: https://github.com/drisspg, https://github.com/Skylion007
2024-08-10 01:49:30 +00:00
Apurva Jain
8bc5ef563e Grouped Query Attention (#132689)
### Approach: Using the current function declaration

**Constraint:** Q_Heads % KV_Heads == 0

**Major change:**
- Added a new argument enable_gqa: bool to the sdpa function call.
- It gives meaning to the third-to-last (head) dimension.

Sample use case this would enable: Llama 3

```
import torch
from torch.nn.functional import scaled_dot_product_attention

# Llama 3 8B-style call to SDPA (batch/seq sizes are illustrative; head dim D = 128)
batch, seq_len_q, seq_len_kv, D = 2, 2048, 2048, 128
query = torch.rand(batch, 32, seq_len_q, D)
key = torch.rand(batch, 8, seq_len_kv, D)
value = torch.rand(batch, 8, seq_len_kv, D)

output = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)

# Output shape: (batch, 32, seq_len_q, D)
```

### Design Choice:

- Check that Query.size(-3) == Key.size(-3) == Value.size(-3), or that Query.size(-3) % Key.size(-3) == 0.
- If the head counts differ, the function adjusts the key and value tensors to match the query tensor's number of heads using repeat_interleave, keeping the attention computation correct and efficient (see the sketch after this list).
- By default the enable_gqa flag is set to False, which ensures that regular sdpa functionality remains unchanged.
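A minimal sketch of the repeat_interleave equivalence described above (illustrative sizes; assumes a PyTorch build that includes the enable_gqa flag):

```python
import torch
import torch.nn.functional as F

batch, q_heads, kv_heads, seq_len, head_dim = 2, 32, 8, 16, 64

query = torch.rand(batch, q_heads, seq_len, head_dim)
key = torch.rand(batch, kv_heads, seq_len, head_dim)
value = torch.rand(batch, kv_heads, seq_len, head_dim)

# GQA through the new flag.
out_gqa = F.scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)

# "Manual" GQA: expand K/V heads to match Q via repeat_interleave.
rep = q_heads // kv_heads
out_ref = F.scaled_dot_product_attention(
    query,
    key.repeat_interleave(rep, dim=-3),
    value.repeat_interleave(rep, dim=-3),
    is_causal=True,
)
print(torch.allclose(out_gqa, out_ref, atol=1e-5))  # should closely match
```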

### Benchmarks:

- **sdpa.py: #130634**
For different batch sizes, enable_gqa=True shows a substantial improvement in the runtime of sdpa.

 | batch_size | q_num_heads | kv_num_heads | q_seq_len | kv_seq_len | embed_dim | forward_time when enable_gqa=True   |   forward_time when enable_gqa=False    |
| ------------ | ------------- | -------------- | ----------- | ------------ | ----------- | ----------- | ---------------- |
|     1      |     32      |      8       |   2048    |    2048    |   2048    |   100.71  |  119.70  |
|     8      |     32      |      8       |   2048    |    2048    |   2048    |   539.78  |  628.83  |
|     16     |     32      |      8       |   2048    |    2048    |   2048    |   1056.81  |  1225.48  |
|     32      |     32      |      8       |   2048    |    2048    |   2048    |   2099.54  |  2440.45  |

![Screenshot 2024-07-25 at 9 07 40 PM](https://github.com/user-attachments/assets/a3e5f716-c39f-4096-9e6c-82a735e57b7b)

- **TorchTitan: https://github.com/pytorch/torchtitan/pull/458**

Differential Revision: D60772086

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132689
Approved by: https://github.com/drisspg
2024-08-07 05:35:36 +00:00
Jianyu Huang
c7cfa51721 Always use high precision for SDPA math backend (#128922)
Summary:
feikou observed big numerical gaps when using the math backend on AMD and NV GPUs. This is mainly because we are not using higher-precision FP32 for the intermediate accumulated/materialized values.

Since the math backend is expected to be slower anyway, and we expect it to generate the correct reference result, it is worth upcasting FP16/BF16 inputs to FP32, doing the computations in FP32/TF32, and then downcasting the FP32 output back to FP16/BF16.
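A minimal sketch of the upcast-compute-downcast pattern described above (not the actual kernel code):

```python
import torch

def math_sdpa_high_precision(q, k, v):
    # Upcast FP16/BF16 inputs to FP32 for the matmuls and softmax,
    # then downcast the result back to the original dtype.
    orig_dtype = q.dtype
    q32, k32, v32 = q.float(), k.float(), v.float()
    scores = q32 @ k32.transpose(-2, -1) / q32.shape[-1] ** 0.5
    attn = torch.softmax(scores, dim=-1)
    return (attn @ v32).to(orig_dtype)

q, k, v = (torch.randn(1, 4, 8, 16, dtype=torch.bfloat16) for _ in range(3))
print(math_sdpa_high_precision(q, k, v).dtype)  # torch.bfloat16
```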

Differential Revision: D58710805

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128922
Approved by: https://github.com/xw285cornell, https://github.com/drisspg
2024-08-04 23:58:14 +00:00
Mikayla Gawarecki
f49d5e30eb Change owners of test/test_transformers.py to module: multi-headed-attention (#132519)
So flaky tests get tagged with `module: multi-headed-attention` instead of `module: nn`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132519
Approved by: https://github.com/Skylion007
2024-08-02 20:12:33 +00:00
PyTorch MergeBot
bcb4f7c172 Revert "Grouped Query Attention (#128898)"
This reverts commit 6b28af1b79.

Reverted https://github.com/pytorch/pytorch/pull/128898 on behalf of https://github.com/ZainRizvi due to Sorry, this broke a bunch of tests internally. See D60638265 ([comment](https://github.com/pytorch/pytorch/pull/128898#issuecomment-2265961038))
2024-08-02 18:58:46 +00:00
PyTorch MergeBot
59b73079a0 Revert "Always use high precision for SDPA math backend (#128922)"
This reverts commit fbf3bc0a60.

Reverted https://github.com/pytorch/pytorch/pull/128922 on behalf of https://github.com/ZainRizvi due to Sorry, but this PR has a dependency on another PR (https://github.com/pytorch/pytorch/pull/128898) that has to be reverted ([comment](https://github.com/pytorch/pytorch/pull/128922#issuecomment-2265949958))
2024-08-02 18:46:50 +00:00
Jianyu Huang
fbf3bc0a60 Always use high precision for SDPA math backend (#128922)
Summary:
feikou observed big numerical gaps when using the math backend on AMD and NV GPUs. This is mainly because we are not using higher-precision FP32 for the intermediate accumulated/materialized values.

Since the math backend is expected to be slower anyway, and we expect it to generate the correct reference result, it is worth upcasting FP16/BF16 inputs to FP32, doing the computations in FP32/TF32, and then downcasting the FP32 output back to FP16/BF16.

Differential Revision: D58710805

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128922
Approved by: https://github.com/xw285cornell, https://github.com/drisspg
2024-08-01 18:55:48 +00:00
jainapurva
6b28af1b79 Grouped Query Attention (#128898)
### Approach: Using the current function declaration

**Constraint:** Q_Heads % KV_Heads == 0

**Major change:**
- Added a new argument enable_gqa: bool to the sdpa function call.
- It gives meaning to the third-to-last (head) dimension.

Sample use cases this would enable:
LLama3

```
# LLama3 8b call to SDPA
query = torch.rand(batch, 32, seq_len_q, D)
key = torch.rand(batch, 8, seq_len_kv, D)
value = torch.rand(batch, 8, seq_len_kv, D)

output = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)

# Output Shape
(batch, 32, seq_len_q, D)
```

### Design Choice:

- Check that Query.size(-3) == Key.size(-3) == Value.size(-3), or that Query.size(-3) % Key.size(-3) == 0.
- If the head counts differ, the function adjusts the key and value tensors to match the query tensor's number of heads using repeat_interleave, keeping the attention computation correct and efficient.
- By default the enable_gqa flag is set to False, which ensures that regular sdpa functionality remains unchanged.

### Benchmarks:

- **sdpa.py: #130634**
For different batch sizes, enable_gqa=True shows a substantial improvement in the runtime of sdpa.

 | batch_size | q_num_heads | kv_num_heads | q_seq_len | kv_seq_len | embed_dim | forward_time when enable_gqa=True   |   forward_time when enable_gqa=False    |
| ------------ | ------------- | -------------- | ----------- | ------------ | ----------- | ----------- | ---------------- |
|     1      |     32      |      8       |   2048    |    2048    |   2048    |   100.71  |  119.70  |
|     8      |     32      |      8       |   2048    |    2048    |   2048    |   539.78  |  628.83  |
|     16     |     32      |      8       |   2048    |    2048    |   2048    |   1056.81  |  1225.48  |
|     32      |     32      |      8       |   2048    |    2048    |   2048    |   2099.54  |  2440.45  |

![Screenshot 2024-07-25 at 9 07 40 PM](https://github.com/user-attachments/assets/a3e5f716-c39f-4096-9e6c-82a735e57b7b)

- **TorchTitan: https://github.com/pytorch/torchtitan/pull/458**

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128898
Approved by: https://github.com/drisspg
2024-07-31 22:58:51 +00:00
PyTorch MergeBot
499ead96ff Revert "Grouped Query Attention (#128898)"
This reverts commit d039b14207.

Reverted https://github.com/pytorch/pytorch/pull/128898 on behalf of https://github.com/albanD due to Broken test on main ([comment](https://github.com/pytorch/pytorch/pull/128898#issuecomment-2258314481))
2024-07-30 13:11:24 +00:00
jainapurva
d039b14207 Grouped Query Attention (#128898)
### Approach: Using the current function declaration

**Constraint:** Q_Heads % KV_Heads == 0

**Major change:**
- Added a new argument enable_gqa: bool to the sdpa function call.
- It gives meaning to the third-to-last (head) dimension.

Sample use cases this would enable:
LLama3

```
# LLama3 8b call to SDPA
query = torch.rand(batch, 32, seq_len_q, D)
key = torch.rand(batch, 8, seq_len_kv, D)
value = torch.rand(batch, 8, seq_len_kv, D)

output = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)

# Output Shape
(batch, 32, seq_len_q, D)
```

### Design Choice:

- Check that Query.size(-3) == Key.size(-3) == Value.size(-3), or that Query.size(-3) % Key.size(-3) == 0.
- If the head counts differ, the function adjusts the key and value tensors to match the query tensor's number of heads using repeat_interleave, keeping the attention computation correct and efficient.
- By default the enable_gqa flag is set to False, which ensures that regular sdpa functionality remains unchanged.

### Benchmarks:

- **sdpa.py: #130634**
For different batch sizes, enable_gqa=True shows a substantial improvement in the runtime of sdpa.

 | batch_size | q_num_heads | kv_num_heads | q_seq_len | kv_seq_len | embed_dim | forward_time when enable_gqa=True   |   forward_time when enable_gqa=False    |
| ------------ | ------------- | -------------- | ----------- | ------------ | ----------- | ----------- | ---------------- |
|     1      |     32      |      8       |   2048    |    2048    |   2048    |   100.71  |  119.70  |
|     8      |     32      |      8       |   2048    |    2048    |   2048    |   539.78  |  628.83  |
|     16     |     32      |      8       |   2048    |    2048    |   2048    |   1056.81  |  1225.48  |
|     32      |     32      |      8       |   2048    |    2048    |   2048    |   2099.54  |  2440.45  |

![Screenshot 2024-07-25 at 9 07 40 PM](https://github.com/user-attachments/assets/a3e5f716-c39f-4096-9e6c-82a735e57b7b)

- **TorchTitan: https://github.com/pytorch/torchtitan/pull/458**

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128898
Approved by: https://github.com/drisspg
2024-07-29 21:49:06 +00:00
Yan Zhiwei
2a02b5cd22 [Intel GPU] Dispatch Stub support (#130019)
# Motivation
Structured codegen makes it easier to decouple tensor meta setting from the kernel implementation. At present, XPU operators have to handle tensor metas in a hand-written way.

We plan to leverage the codegen system to auto-generate structured operators. This PR adds `DispatchStub` support for Intel GPUs. Based on that, XPU operators will be able to register kernel functors with the operator stubs.

This is a prerequisite of PR #130082, where we will modify the codegen system to generate the source files and headers needed for XPU.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130019
Approved by: https://github.com/EikanWang, https://github.com/gujinghui, https://github.com/albanD
2024-07-29 02:18:52 +00:00
drisspg
1bfe7eb7e6 Update how we do sdpa testing (#131743)
## Motivation

This refactor aligns our testing methodology with the Flash Attention upstream repository while addressing several key issues:

1. **Standardized comparison**: We now compare fused kernels against float64 references, using the maximum of a calculated tolerance (based on the same-precision math implementation) and the standard float32 `atol` (see the sketch after this list).

2. **Reduced redundancy**: Utilizing the same tensors for both same-precision math and fused kernel runs eliminates duplication.

3. **Improved maintainability**: The new approach simplifies tolerance adjustments across all affected tests.

4. **Consistency**: Standardizing tensor comparisons ensures a more uniform and reliable testing suite.

These changes collectively simplify our testing code, improve its maintainability, and provide a more robust framework for validating our attention mechanisms.
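A rough sketch of the tolerance scheme from point 1 above (hypothetical helper name and a hypothetical float32 `atol` of 1e-5, not the actual test-suite code):

```python
import torch

def fused_vs_fp64_atol(ref_lowp: torch.Tensor, ref_fp64: torch.Tensor) -> float:
    # Calibrate atol from how far the same-precision math reference already
    # deviates from the float64 reference, floored at a standard float32 atol.
    float32_atol = 1e-5
    calculated = (ref_lowp.double() - ref_fp64).abs().max().item()
    return max(calculated, float32_atol)

# The fused-kernel output would then be compared against ref_fp64 with this atol, e.g.:
# torch.testing.assert_close(fused_out.double(), ref_fp64, atol=fused_vs_fp64_atol(...), rtol=0)
```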

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131743
Approved by: https://github.com/jainapurva, https://github.com/jbschlosser
2024-07-27 03:58:49 +00:00
Valentine233
868d9a4f12 [cpu][flash attention] fix nan issue (#130014)
Fixes #127055.

NaNs are generated in flash attention because of the computation of `std::exp((-inf) - (-inf))` and `+/-inf * 0` in the lazy softmax. We fix the issue by avoiding the related calculations.
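A quick illustration of the two problematic expressions (plain Python, not the C++ kernel):

```python
inf = float("inf")

print((-inf) - (-inf))  # nan: subtracting the running max when both are -inf
print(inf * 0.0)        # nan: +/-inf times a zero attention weight
# exp(nan) stays nan, so these NaNs propagate through the lazy softmax.
```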

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130014
Approved by: https://github.com/jgong5, https://github.com/drisspg
2024-07-10 02:33:26 +00:00
eqy
86fb76e871 [SDPA] Clean up print in test/test_transformers.py (#130302)
Left this in #125343, oops...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130302
Approved by: https://github.com/awgu
2024-07-09 09:20:52 +00:00
Huy Do
8f70bf7a94 Skip TestSDPAPrivateUse1Only on FBCODE (#129997)
Summary: The test is from D59181111, but I couldn't figure out a way to make it pass on FBCODE, because loading a PyTorch C++ extension requires Ninja, which is not going to work with BUCK.

Test Plan: `buck2 test 'fbcode//mode/dev-nosan' fbcode//caffe2/test:transformers`

Differential Revision: D59304327

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129997
Approved by: https://github.com/drisspg
2024-07-03 06:48:51 +00:00
eqy
24b6c5a41f [cuDNN][SDPA] Bail out of dispatching to cuDNN for head dim > 128 on Ampere (#129587)
Fix for #129579

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129587
Approved by: https://github.com/Skylion007, https://github.com/drisspg
2024-06-30 19:37:44 +00:00
eqy
f845a7a91a [cuDNN][SDPA] Remove TORCH_CUDNN_SDPA_ENABLED=1, enable cuDNN SDPA by default on H100 and 2nd on other archs >= sm80 (#125343)
Looks like one of the first failures seen is `test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` when `test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` passes.

What seems interesting here is that the `torch.compile` version fails while the eager version passes. Not sure what the difference would be here...

Nevertheless, is there a recommended mechanism to skip cuDNN SDPA as a backend for this test? CC @drisspg

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125343
Approved by: https://github.com/Skylion007
2024-06-30 19:22:16 +00:00
FEI
59e4e92556 sdp::SDPBackend::flash_attention support PrivateUse1 (#126392)
Fixes https://github.com/pytorch/pytorch/issues/124271

cc  @cpuhrsch @drisspg @albanD @soulitzer

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126392
Approved by: https://github.com/drisspg
2024-06-28 17:48:40 +00:00
PyTorch MergeBot
999eec8dea Revert "[cuDNN][SDPA] Remove TORCH_CUDNN_SDPA_ENABLED=1, enable cuDNN SDPA by default on H100 and 2nd on other archs >= sm80 (#125343)"
This reverts commit b7e7a4cb01.

Reverted https://github.com/pytorch/pytorch/pull/125343 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it seems to break some test_transformer running on internal A100 and V100 ([comment](https://github.com/pytorch/pytorch/pull/125343#issuecomment-2196202003))
2024-06-28 06:03:54 +00:00
PyTorch MergeBot
d21993bbb8 Revert "[cuDNN][SDPA] Bail out of dispatching to cuDNN for head dim > 128 on Ampere (#129587)"
This reverts commit 7854d84acb.

Reverted https://github.com/pytorch/pytorch/pull/129587 on behalf of https://github.com/huydhn due to Sorry for revert yet another of your change but I need to revert this to cleanly revert https://github.com/pytorch/pytorch/pull/125343#issuecomment-2196187332 ([comment](https://github.com/pytorch/pytorch/pull/129587#issuecomment-2196198756))
2024-06-28 06:01:07 +00:00
eqy
7854d84acb [cuDNN][SDPA] Bail out of dispatching to cuDNN for head dim > 128 on Ampere (#129587)
Fix for #129579

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129587
Approved by: https://github.com/Skylion007, https://github.com/drisspg
2024-06-28 04:42:45 +00:00
Eddie Yan
b7e7a4cb01 [cuDNN][SDPA] Remove TORCH_CUDNN_SDPA_ENABLED=1, enable cuDNN SDPA by default on H100 and 2nd on other archs >= sm80 (#125343)
Looks like one of the first failures seen is `test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` when `test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` passes.

What seems interesting here is that the `torch.compile` version fails while the eager version passes. Not sure what the difference would be here...

Nevertheless, is there a recommended mechanism to skip cuDNN SDPA as a backend for this test? CC @drisspg

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125343
Approved by: https://github.com/Skylion007
2024-06-26 00:49:18 +00:00
iibrahimli
2db33054b3 Disable fast path in TransformerEncoderLayer when there are forward (pre-)hooks attached to modules (#128415)
Fixes #128413

Disable fast-path if there are forward hooks or pre-hooks.

Example failure case given in the issue.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128415
Approved by: https://github.com/mikaylagawarecki
2024-06-21 17:38:08 +00:00
Valentine233
5da428d9eb [cpu][flash attention] fix attention mask issue (#128816)
For attention mask in flash attention:

- Fix the issue of accessing illegal memory when the last dimension of the mask has size 1 (see the sketch below).
- Add UTs for attention masks of various shapes.
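A small sketch of the broadcastable-mask case being fixed (illustrative shapes):

```python
import torch
import torch.nn.functional as F

q = torch.randn(2, 4, 16, 32)
k = torch.randn(2, 4, 16, 32)
v = torch.randn(2, 4, 16, 32)

# Additive mask whose last dimension is 1: broadcast over the key dimension.
mask = torch.randn(2, 1, 16, 1)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out.shape)  # torch.Size([2, 4, 16, 32])
```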

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128816
Approved by: https://github.com/jgong5, https://github.com/drisspg
2024-06-21 01:12:48 +00:00
PyTorch MergeBot
817ce6835b Revert "[cuDNN][SDPA] Remove TORCH_CUDNN_SDPA_ENABLED=1, enable cuDNN SDPA by default on H100 and 2nd on other archs >= sm80 (#125343)"
This reverts commit 4c971932e8.

Reverted https://github.com/pytorch/pytorch/pull/125343 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/125343#issuecomment-2163690162))
2024-06-12 18:47:52 +00:00
PyTorch MergeBot
7db501ba2b Revert "[cuDNN][SDPA] Support different key, value dimension in cuDNN SDPA (#128350)"
This reverts commit 45dccfddcd.

Reverted https://github.com/pytorch/pytorch/pull/128350 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/128350#issuecomment-2163669538))
2024-06-12 18:35:18 +00:00
eqy
45dccfddcd [cuDNN][SDPA] Support different key, value dimension in cuDNN SDPA (#128350)
CC @vedaanta-nvidia @drisspg

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128350
Approved by: https://github.com/Skylion007
2024-06-11 19:22:21 +00:00
eqy
4c971932e8 [cuDNN][SDPA] Remove TORCH_CUDNN_SDPA_ENABLED=1, enable cuDNN SDPA by default on H100 and 2nd on other archs >= sm80 (#125343)
Looks like one of the first failures seen is `test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` when `test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` passes.

What seems interesting here is that the `torch.compile` version fails while the eager version passes. Not sure what the difference would be here...

Nevertheless, is there a recommended mechanism to skip cuDNN SDPA as a backend for this test? CC @drisspg
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125343
Approved by: https://github.com/Skylion007
2024-06-09 06:53:34 +00:00
Xinya Zhang
d34075e0bd Add Efficient Attention support on ROCM (#124885)
This patch implements `with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):` by reusing AOTriton's accelerated SDPA implementation.
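A minimal usage sketch (assumes a ROCm build of PyTorch where this backend is available; `device="cuda"` maps to the HIP device there):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)

# Restrict dispatch to the efficient-attention backend (AOTriton on ROCm).
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)
```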

Known limitations:
- Only supports MI200/MI300X GPUs
- Does not support varlen
- Does not support `CausalVariant`
- Optional arguments `causal_diagonal` and `seqlen_k` in `_efficient_attention_forward/backward` must be null
- Does not work well with inductor's SDPA rewriter. The rewriter has been updated to only use math and flash attention on ROCM.

This PR also uses a different approach of installing AOTriton binary instead of building it from source in the base docker image. More details on motivation: https://github.com/pytorch/pytorch/pull/124885#issuecomment-2153229129

`PYTORCH_TEST_WITH_ROCM=1 PYTORCH_TESTING_DEVICE_ONLY_FOR="cuda" python test/test_transformers.py` yields "55028 passed, 20784 skipped" results with this change.  [Previous result](https://hud.pytorch.org/pr/127528) of `test_transformers.py` was 0 error, 0 failure, 55229 skipped out of 75517 tests in total (the XML report does not contain total number of passed tests).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124885
Approved by: https://github.com/malfet
2024-06-08 22:41:05 +00:00
PyTorch MergeBot
a9309502af Revert "Refactoring to remove unused variable (#125252)"
This reverts commit b094622bc9.

Reverted https://github.com/pytorch/pytorch/pull/125252 on behalf of https://github.com/drisspg due to going to land codev ([comment](https://github.com/pytorch/pytorch/pull/125252#issuecomment-2089394606))
2024-05-02 01:49:57 +00:00
Apurva Jain
b094622bc9 Refactoring to remove unused variable (#125252)
Summary: Removed unused variable for running encoder

Test Plan: buck test //caffe2/test:transformers

Differential Revision: D56771972

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125252
Approved by: https://github.com/drisspg
2024-05-01 15:17:45 +00:00
Fuzzkatt
1cf62e86a4 skip various unit tests for Jetson (#122531)
skip multiprocessing, cuda expandable segments, mem eff and flash attention tests on Jetson due to hanging / sigkill issues from nvidia internal testing

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122531
Approved by: https://github.com/eqy, https://github.com/malfet
2024-04-16 01:26:26 +00:00
Aaron Gokaslan
1d6c5972c1 [BE]: Optimize min/max/sum comprehensions C419 (#123960)
Automatic fixes that replace certain list comprehensions with generator expressions where appropriate, so that they are immediately consumed. This is preview functionality in ruff for rule C419, and it was applied automatically.
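An example of the C419 pattern (hypothetical snippet, not taken from the diff):

```python
xs = [3, 1, 2]

# Before: a throwaway list is materialized just to be consumed by min().
smallest = min([x for x in xs])

# After: the generator expression is consumed directly, no intermediate list.
smallest = min(x for x in xs)
print(smallest)  # 1
```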

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123960
Approved by: https://github.com/malfet
2024-04-12 23:54:15 +00:00
William Wen
cbde0f048b [dynamo, 3.12] enable tests disabled due to missing dynamo 3.12 support (#123300)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123300
Approved by: https://github.com/jansel, https://github.com/malfet, https://github.com/zou3519
2024-04-05 20:13:17 +00:00
Xinya Zhang
b83c94339e Fix performance regression and memory storage handling of Flash Attention on ROCM (#122857)
This PR fixes two major issues that were discovered after the initial merge of PR #121561:
1. The Flash Attention support added by that PR has severe performance regressions on regular shapes (power-of-two head dimensions and sequence lengths) compared with PR #115981; its performance is worse than the math backend and it only offers numerical-stability advantages. This PR fixes that problem.
2. There is a flaw in the memory storage handling of PR #121561 which does not copy the gradients back to the designated output tensor. This PR removes the deprecated `TensorStorageSanitizer` class, which is unnecessary thanks to the more flexible backward kernel shipped by PR #121561.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122857
Approved by: https://github.com/jeffdaily, https://github.com/drisspg
2024-03-29 16:37:24 +00:00
Xinya Zhang
12116aee68 Add Flash Attention support on ROCM (#121561)
This patch addresses the major limitations in our previous [PR #115981](https://github.com/pytorch/pytorch/pull/115981) through the new dedicated repository [AOTriton](https://github.com/ROCm/aotriton)

- [x] Only supports MI200 series GPUs (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
    * MI300X is now supported. More architectures will be added once Triton supports them.
- [x] Only supports power-of-two sequence lengths.
    * Now it supports arbitrary sequence lengths.
- [ ] No support for varlen APIs.
    * The varlen API will be supported in a future release of AOTriton.
- [x] Only supports head dimensions 16, 32, 64, 128.
    * Now it supports arbitrary head dimensions <= 256.
- [x] Performance is still being optimized.
    * The kernel is selected according to autotune information from Triton.

Other improvements from AOTriton include
* Allow more flexible Tensor storage layout
* More flexible API

This is a more extensive fix to #112997

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121561
Approved by: https://github.com/huydhn
2024-03-28 00:27:38 +00:00
FEI
e08cbc0d41 update comment of test_invalid_last_dim_stride in test_transformers.py (#122679)
Fixes #122594

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122679
Approved by: https://github.com/mikaylagawarecki
2024-03-26 15:40:24 +00:00
PyTorch MergeBot
764eae9c4e Revert "Add Flash Attention support on ROCM (#121561)"
This reverts commit a37e22de70.

Reverted https://github.com/pytorch/pytorch/pull/121561 on behalf of https://github.com/huydhn due to Sorry for reverting your change but this needs more work to be able to land in fbcode because https://github.com/ROCm/aotriton is not available there atm.  We are working to reland this change before 2.3 release ([comment](https://github.com/pytorch/pytorch/pull/121561#issuecomment-2007717091))
2024-03-19 17:14:28 +00:00
drisspg
42624bceb6 Fixes nan with large bf16 values (#122135)
Fixes #121558

Performance on main:
``` Markdown
+------------+-----------+-----------+------------+-----------+-----------+----------------+--------------------+--------------------+
| batch_size | num_heads | q_seq_len | kv_seq_len | embed_dim | is_causal |     dtype      |    forward_time    |   backward_time    |
+------------+-----------+-----------+------------+-----------+-----------+----------------+--------------------+--------------------+
|     1      |    16     |    128    |    128     |   2048    |   True    | torch.bfloat16 | 12.608132004970683 | 65.90210803551601  |
|     1      |    16     |    128    |    128     |   2048    |   False   | torch.bfloat16 | 11.75877740024589  | 64.83824399765581  |
|     1      |    16     |    256    |    256     |   2048    |   True    | torch.bfloat16 | 16.465420153690506 |  67.6770955324173  |
|     1      |    16     |    256    |    256     |   2048    |   False   | torch.bfloat16 | 17.398148600477725 | 68.19829455344006  |
|     1      |    16     |    512    |    512     |   2048    |   True    | torch.bfloat16 | 29.053532000398263 | 99.58901099162175  |
|     1      |    16     |    512    |    512     |   2048    |   False   | torch.bfloat16 |  27.826815698063   | 98.05690299253911  |
|     1      |    16     |   1024    |    1024    |   2048    |   True    | torch.bfloat16 | 49.89655229728669  | 178.24282555375248 |
|     1      |    16     |   1024    |    1024    |   2048    |   False   | torch.bfloat16 | 48.840098950313404 | 174.5950729819015  |
|     1      |    16     |   4096    |    2048    |   2048    |   True    | torch.bfloat16 | 505.66218036692584 | 1865.9265094902366 |
|     1      |    16     |   4096    |    2048    |   2048    |   False   | torch.bfloat16 | 295.0534054543823  | 967.3831606050952  |
|     1      |    32     |    128    |    128     |   2048    |   True    | torch.bfloat16 | 11.496030446141958 | 55.11070846114308  |
|     1      |    32     |    128    |    128     |   2048    |   False   | torch.bfloat16 | 11.47399884648621  | 55.452342028729625 |
|     1      |    32     |    256    |    256     |   2048    |   True    | torch.bfloat16 | 13.216444296995178 | 55.14447903260589  |
|     1      |    32     |    256    |    256     |   2048    |   False   | torch.bfloat16 | 12.763233599252999 | 55.142355500720434 |
|     1      |    32     |    512    |    512     |   2048    |   True    | torch.bfloat16 | 19.409965351223946 |  74.9107634765096  |
|     1      |    32     |    512    |    512     |   2048    |   False   | torch.bfloat16 | 19.02470579952933  | 74.84168506925926  |
|     1      |    32     |   1024    |    1024    |   2048    |   True    | torch.bfloat16 | 46.37695319834165  | 172.19150450546294 |
|     1      |    32     |   1024    |    1024    |   2048    |   False   | torch.bfloat16 | 45.225963747361675 | 185.19691249821335 |
|     1      |    32     |   4096    |    2048    |   2048    |   True    | torch.bfloat16 | 634.3090848531574  | 2249.057865119539  |
|     1      |    32     |   4096    |    2048    |   2048    |   False   | torch.bfloat16 | 320.47313248040155 | 1053.0515247955916 |
|     4      |    16     |    128    |    128     |   2048    |   True    | torch.bfloat16 | 13.448987301671878 | 63.63581650657579  |
|     4      |    16     |    128    |    128     |   2048    |   False   | torch.bfloat16 | 12.509283400140703 | 63.059300999157124 |
|     4      |    16     |    256    |    256     |   2048    |   True    | torch.bfloat16 | 19.71098779467866  | 105.55780201684684 |
|     4      |    16     |    256    |    256     |   2048    |   False   | torch.bfloat16 | 18.264925852417946 | 105.12311349157244 |
|     4      |    16     |    512    |    512     |   2048    |   True    | torch.bfloat16 | 45.218703348655254 | 222.87272597895935 |
|     4      |    16     |    512    |    512     |   2048    |   False   | torch.bfloat16 | 43.55393464793451  | 230.63290398567915 |
|     4      |    16     |   1024    |    1024    |   2048    |   True    | torch.bfloat16 | 134.02968645095825 | 514.6893998607993  |
|     4      |    16     |   1024    |    1024    |   2048    |   False   | torch.bfloat16 | 157.13709802366793 | 624.5892751030624  |
|     4      |    16     |   4096    |    2048    |   2048    |   True    | torch.bfloat16 | 1776.7079547047617 | 6353.551096981391  |
|     4      |    16     |   4096    |    2048    |   2048    |   False   | torch.bfloat16 | 1143.6000745743513 | 3811.8767354171723 |
|     4      |    32     |    128    |    128     |   2048    |   True    | torch.bfloat16 | 11.717129248427227 | 55.35991647047922  |
|     4      |    32     |    128    |    128     |   2048    |   False   | torch.bfloat16 | 11.746983398916198 | 55.76716404175386  |
|     4      |    32     |    256    |    256     |   2048    |   True    | torch.bfloat16 | 17.255573300644752 | 106.47456656442955 |
|     4      |    32     |    256    |    256     |   2048    |   False   | torch.bfloat16 | 16.46409669774584  | 108.07770595420152 |
|     4      |    32     |    512    |    512     |   2048    |   True    | torch.bfloat16 | 46.63354124641045  | 213.74862996162847 |
|     4      |    32     |    512    |    512     |   2048    |   False   | torch.bfloat16 | 47.01801469782367  | 240.78139301855117 |
|     4      |    32     |   1024    |    1024    |   2048    |   True    | torch.bfloat16 | 127.76448752265424 | 508.08745552785695 |
|     4      |    32     |   1024    |    1024    |   2048    |   False   | torch.bfloat16 | 168.6308984644711  | 667.2996102133766  |
|     4      |    32     |   4096    |    2048    |   2048    |   True    | torch.bfloat16 | 2268.1598202325404 | 7727.2648515645415 |
|     4      |    32     |   4096    |    2048    |   2048    |   False   | torch.bfloat16 | 1242.8469699807465 | 4161.965740495361  |
|     8      |    16     |    128    |    128     |   2048    |   True    | torch.bfloat16 | 14.340955897932872 | 93.72280450770633  |
|     8      |    16     |    128    |    128     |   2048    |   False   | torch.bfloat16 | 13.25262250029482  |  93.2030284893699  |
|     8      |    16     |    256    |    256     |   2048    |   True    | torch.bfloat16 | 27.598425600444898 | 183.23776399483904 |
|     8      |    16     |    256    |    256     |   2048    |   False   | torch.bfloat16 | 26.362583553418514 | 183.51862096460536 |
|     8      |    16     |    512    |    512     |   2048    |   True    | torch.bfloat16 | 84.52303148806094  | 383.50319798337296 |
|     8      |    16     |    512    |    512     |   2048    |   False   | torch.bfloat16 | 89.41743348259479  | 432.5502900755964  |
|     8      |    16     |   1024    |    1024    |   2048    |   True    | torch.bfloat16 | 217.76640450116247 | 943.9354750793427  |
|     8      |    16     |   1024    |    1024    |   2048    |   False   | torch.bfloat16 | 303.0781910638325  | 1225.4394043702632 |
|     8      |    16     |   4096    |    2048    |   2048    |   True    | torch.bfloat16 | 3470.8542854059488 | 12194.579601055011 |
|     8      |    16     |   4096    |    2048    |   2048    |   False   | torch.bfloat16 | 2268.1174043100327 | 7608.0941944383085 |
|     8      |    32     |    128    |    128     |   2048    |   True    | torch.bfloat16 | 12.289720651460811 | 95.88620596332476  |
|     8      |    32     |    128    |    128     |   2048    |   False   | torch.bfloat16 | 11.618648946750909 | 95.56685149436818  |
|     8      |    32     |    256    |    256     |   2048    |   True    | torch.bfloat16 | 31.567946751601994 | 180.62468653079122 |
|     8      |    32     |    256    |    256     |   2048    |   False   | torch.bfloat16 | 28.611703700153157 | 189.4215695792809  |
|     8      |    32     |    512    |    512     |   2048    |   True    | torch.bfloat16 | 84.11306998459621  | 385.25596749968827 |
|     8      |    32     |    512    |    512     |   2048    |   False   | torch.bfloat16 | 93.82540901424363  | 455.77428903197875 |
|     8      |    32     |   1024    |    1024    |   2048    |   True    | torch.bfloat16 | 226.80530551588163 | 965.8026450779289  |
|     8      |    32     |   1024    |    1024    |   2048    |   False   | torch.bfloat16 | 327.4116570246406  | 1312.5067745568228 |
|     8      |    32     |   4096    |    2048    |   2048    |   True    | torch.bfloat16 | 4445.5064804060385 | 15020.768146496266 |
|     8      |    32     |   4096    |    2048    |   2048    |   False   | torch.bfloat16 | 2433.0302356975153 | 8300.016750581563  |
+------------+-----------+-----------+------------+-----------+-----------+----------------+--------------------+--------------------+
```

Performance on this branch:
```Markdown
+------------+-----------+-----------+------------+-----------+-----------+----------------+--------------------+--------------------+
| batch_size | num_heads | q_seq_len | kv_seq_len | embed_dim | is_causal |     dtype      |    forward_time    |   backward_time    |
+------------+-----------+-----------+------------+-----------+-----------+----------------+--------------------+--------------------+
|     1      |    16     |    128    |    128     |   2048    |   True    | torch.bfloat16 | 12.783618393586949 | 65.59692794689909  |
|     1      |    16     |    128    |    128     |   2048    |   False   | torch.bfloat16 | 12.064015300711617 | 56.99719698168337  |
|     1      |    16     |    256    |    256     |   2048    |   True    | torch.bfloat16 | 16.629025398287922 | 68.65267595276237  |
|     1      |    16     |    256    |    256     |   2048    |   False   | torch.bfloat16 | 17.462356004398313 | 68.35797848179936  |
|     1      |    16     |    512    |    512     |   2048    |   True    | torch.bfloat16 |  29.5476081490051  | 101.22994752600789 |
|     1      |    16     |    512    |    512     |   2048    |   False   | torch.bfloat16 | 28.395320149138573 | 98.62275794148445  |
|     1      |    16     |   1024    |    1024    |   2048    |   True    | torch.bfloat16 | 50.50016101449728  | 181.4357690163888  |
|     1      |    16     |   1024    |    1024    |   2048    |   False   | torch.bfloat16 | 49.450615647947416 | 175.86063902126625 |
|     1      |    16     |   4096    |    2048    |   2048    |   True    | torch.bfloat16 | 506.06461532879626 | 1866.0613044630736 |
|     1      |    16     |   4096    |    2048    |   2048    |   False   | torch.bfloat16 | 299.9336270149797  | 976.4662646921353  |
|     1      |    32     |    128    |    128     |   2048    |   True    | torch.bfloat16 | 11.45752210286446  | 58.79682704107836  |
|     1      |    32     |    128    |    128     |   2048    |   False   | torch.bfloat16 | 11.407129396684468 | 58.14061599085107  |
|     1      |    32     |    256    |    256     |   2048    |   True    | torch.bfloat16 | 13.822759891627355 | 56.56979401828722  |
|     1      |    32     |    256    |    256     |   2048    |   False   | torch.bfloat16 | 13.39154909946956  |  56.7130644340068  |
|     1      |    32     |    512    |    512     |   2048    |   True    | torch.bfloat16 | 20.282494352431968 | 77.29688903782517  |
|     1      |    32     |    512    |    512     |   2048    |   False   | torch.bfloat16 | 19.899454596452415 |  75.4446149803698  |
|     1      |    32     |   1024    |    1024    |   2048    |   True    | torch.bfloat16 | 48.494275606935844 | 177.5322465109639  |
|     1      |    32     |   1024    |    1024    |   2048    |   False   | torch.bfloat16 | 46.84524350450374  | 189.1778860008344  |
|     1      |    32     |   4096    |    2048    |   2048    |   True    | torch.bfloat16 | 635.1026654010639  | 2248.0451600858937 |
|     1      |    32     |   4096    |    2048    |   2048    |   False   | torch.bfloat16 | 335.1591735263355  | 1080.4320796160027 |
|     4      |    16     |    128    |    128     |   2048    |   True    | torch.bfloat16 | 13.63953539985232  | 65.50709309522063  |
|     4      |    16     |    128    |    128     |   2048    |   False   | torch.bfloat16 | 12.858113402035087 | 63.021871959790595 |
|     4      |    16     |    256    |    256     |   2048    |   True    | torch.bfloat16 | 19.98318645055406  | 105.87883047992364 |
|     4      |    16     |    256    |    256     |   2048    |   False   | torch.bfloat16 | 18.619045056402683 | 104.90188701078296 |
|     4      |    16     |    512    |    512     |   2048    |   True    | torch.bfloat16 | 45.91175540117546  | 226.00732848513871 |
|     4      |    16     |    512    |    512     |   2048    |   False   | torch.bfloat16 | 44.39614630537107  | 232.39317198749632 |
|     4      |    16     |   1024    |    1024    |   2048    |   True    | torch.bfloat16 | 135.5409600073472  | 522.7949097752571  |
|     4      |    16     |   1024    |    1024    |   2048    |   False   | torch.bfloat16 | 158.79383607534692 | 628.5856699105352  |
|     4      |    16     |   4096    |    2048    |   2048    |   True    | torch.bfloat16 | 1775.9978299727663 | 6343.203847063706  |
|     4      |    16     |   4096    |    2048    |   2048    |   False   | torch.bfloat16 | 1160.680354805663  | 3842.235009651631  |
|     4      |    32     |    128    |    128     |   2048    |   True    | torch.bfloat16 | 11.553713708417488 | 65.50691701704638  |
|     4      |    32     |    128    |    128     |   2048    |   False   | torch.bfloat16 | 11.486379051348194 |  56.9980075233616  |
|     4      |    32     |    256    |    256     |   2048    |   True    | torch.bfloat16 | 17.56585600087419  | 107.89892700267956 |
|     4      |    32     |    256    |    256     |   2048    |   False   | torch.bfloat16 | 16.828144202008843 | 109.05519902007653 |
|     4      |    32     |    512    |    512     |   2048    |   True    | torch.bfloat16 | 48.23235589428805  | 217.8974545095116  |
|     4      |    32     |    512    |    512     |   2048    |   False   | torch.bfloat16 | 49.09284680034033  | 244.73925953498107 |
|     4      |    32     |   1024    |    1024    |   2048    |   True    | torch.bfloat16 | 134.77827049791813 | 522.7259948151186  |
|     4      |    32     |   1024    |    1024    |   2048    |   False   | torch.bfloat16 | 176.60772847011688 | 681.5171707421541  |
|     4      |    32     |   4096    |    2048    |   2048    |   True    | torch.bfloat16 | 2267.821540008299  | 7720.425300067291  |
|     4      |    32     |   4096    |    2048    |   2048    |   False   | torch.bfloat16 | 1295.3941145678982 | 4272.425139788538  |
|     8      |    16     |    128    |    128     |   2048    |   True    | torch.bfloat16 | 14.514714101096615 |  94.2192979855463  |
|     8      |    16     |    128    |    128     |   2048    |   False   | torch.bfloat16 | 13.553097198018804 |  93.244242540095   |
|     8      |    16     |    256    |    256     |   2048    |   True    | torch.bfloat16 | 27.95821905019693  | 185.0469880155288  |
|     8      |    16     |    256    |    256     |   2048    |   False   | torch.bfloat16 | 26.709681446664035 | 184.22623950755226 |
|     8      |    16     |    512    |    512     |   2048    |   True    | torch.bfloat16 | 85.85420495364815  | 388.3417735341937  |
|     8      |    16     |    512    |    512     |   2048    |   False   | torch.bfloat16 | 89.97473795898259  | 434.4228169647977  |
|     8      |    16     |   1024    |    1024    |   2048    |   True    | torch.bfloat16 | 220.6919804448262  | 958.9654899900779  |
|     8      |    16     |   1024    |    1024    |   2048    |   False   | torch.bfloat16 | 306.55586952343583 | 1233.2170095760375 |
|     8      |    16     |   4096    |    2048    |   2048    |   True    | torch.bfloat16 | 3470.7326447824016 | 12183.611298678443 |
|     8      |    16     |   4096    |    2048    |   2048    |   False   | torch.bfloat16 | 2299.064100370742  | 7669.618452200666  |
|     8      |    32     |    128    |    128     |   2048    |   True    | torch.bfloat16 | 12.427107692928985 | 96.96270158747211  |
|     8      |    32     |    128    |    128     |   2048    |   False   | torch.bfloat16 | 11.856995843118057 | 96.38117247959599  |
|     8      |    32     |    256    |    256     |   2048    |   True    | torch.bfloat16 |  32.9956392000895  | 182.52741603646427 |
|     8      |    32     |    256    |    256     |   2048    |   False   | torch.bfloat16 | 29.397601098753512 | 191.0755339777097  |
|     8      |    32     |    512    |    512     |   2048    |   True    | torch.bfloat16 | 89.06024845782667  | 392.2585004474967  |
|     8      |    32     |    512    |    512     |   2048    |   False   | torch.bfloat16 | 97.78487798757851  | 462.07307645818213 |
|     8      |    32     |   1024    |    1024    |   2048    |   True    | torch.bfloat16 |  240.521906001959  | 992.4693452194335  |
|     8      |    32     |   1024    |    1024    |   2048    |   False   | torch.bfloat16 | 341.98952303268015 | 1339.2950996058062 |
|     8      |    32     |   4096    |    2048    |   2048    |   True    | torch.bfloat16 | 4445.311005110853  | 15001.030603889374 |
|     8      |    32     |   4096    |    2048    |   2048    |   False   | torch.bfloat16 | 2535.9767401823774 | 8528.990152990447  |
+------------+-----------+-----------+------------+-----------+-----------+----------------+--------------------+--------------------+
```

```
{'avg_forward_time_nan_fix': 399.7900972732653,
 'avg_backward_time_nan_fix': 1409.652114014413,
 'avg_forward_time_main_branch': 394.6807206988645,
 'avg_backward_time_main_branch': 1399.4055472857629,
 'geo_mean_nan_fix': 150.95049601244946,
 'geo_mean_main_branch': 148.3381648508822}
 ```

The y-axis label is wrong (the unit is microseconds), but the relative comparison still holds.
<img width="790" alt="Screenshot 2024-03-18 at 3 34 15 PM" src="https://github.com/pytorch/pytorch/assets/32754868/ca278c15-b815-4535-bdcd-07e522055466">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122135
Approved by: https://github.com/cpuhrsch
2024-03-19 16:32:00 +00:00
Xinya Zhang
a37e22de70 Add Flash Attention support on ROCM (#121561)
This patch addresses the major limitations in our previous [PR #115981](https://github.com/pytorch/pytorch/pull/115981) through the new dedicated repository [AOTriton](https://github.com/ROCm/aotriton)

- [x] Only supports MI200 series GPUs (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
    * MI300X is now supported. More architectures will be added once Triton supports them.
- [x] Only supports power-of-two sequence lengths.
    * Now it supports arbitrary sequence lengths.
- [ ] No support for varlen APIs.
    * The varlen API will be supported in the next release of AOTriton.
- [x] Only supports head dimensions 16, 32, 64, 128.
    * Now it supports arbitrary head dimensions <= 256.
- [x] Performance is still being optimized.
    * The kernel is selected according to autotune information from Triton.

Other improvements from AOTriton include
* Allow more flexible Tensor storage layout
* More flexible API

This is a more extensive fix to #112997

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121561
Approved by: https://github.com/malfet, https://github.com/atalman
2024-03-12 01:16:53 +00:00
y-sq
393b4ab432 Fixes issue_119785 (#121048)
Fixes #119785

- Removed all sentinel files of `test_causal_variants_.*`.

- The `test_causal_variants_causal_variant_` tests pass after the dynamo_skips files are removed.

- The `test_causal_variants_compile_causal_variant` tests fail with `PYTORCH_TEST_WITH_DYNAMO=1`. These tests already call torch.compile, so @skipIfTorchDynamo was added to skip them under `PYTORCH_TEST_WITH_DYNAMO`.

**Tests**
```
$ PYTORCH_TEST_WITH_DYNAMO=1 pytest test_transformers.py -v -k "test_causal_variants"
================================================================== test session starts ==================================================================
platform linux -- Python 3.10.13, pytest-7.4.0, pluggy-1.0.0 -- /home/shuqiyang/.conda/envs/pytorch/bin/python
cachedir: .pytest_cache
rootdir: /data/users/shuqiyang/pytorch
configfile: pytest.ini
collected 77250 items / 77218 deselected / 32 selected
Running 32 items in this shard

test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cpu PASSED [0.7745s]                  [  3%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape1_cpu PASSED [0.8020s]                  [  6%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape2_cpu SKIPPED [0.0385s] (Lower righ...) [  9%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape3_cpu PASSED [0.5046s]                  [ 12%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape0_cpu PASSED [0.6483s]                   [ 15%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape1_cpu PASSED [0.8537s]                   [ 18%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape2_cpu PASSED [0.8388s]                   [ 21%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape3_cpu PASSED [0.4859s]                   [ 25%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cpu SKIPPED [0.0084s] (Th...) [ 28%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape1_cpu SKIPPED [0.0086s] (Th...) [ 31%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape2_cpu SKIPPED [0.0081s] (Th...) [ 34%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape3_cpu SKIPPED [0.0085s] (Th...) [ 37%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape0_cpu SKIPPED [0.0082s] (Thi...) [ 40%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape1_cpu SKIPPED [0.0085s] (Thi...) [ 43%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape2_cpu SKIPPED [0.0081s] (Thi...) [ 46%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape3_cpu SKIPPED [0.0085s] (Thi...) [ 50%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda PASSED [9.4185s]                [ 53%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape1_cuda PASSED [0.4273s]                [ 56%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape2_cuda SKIPPED [0.0280s] (Lower ri...) [ 59%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape3_cuda PASSED [8.0999s]                [ 62%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape0_cuda PASSED [0.3785s]                 [ 65%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape1_cuda PASSED [0.3818s]                 [ 68%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape2_cuda PASSED [0.3864s]                 [ 71%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape3_cuda PASSED [0.7668s]                 [ 75%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda SKIPPED [0.0089s] (...) [ 78%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape1_cuda SKIPPED [0.0087s] (...) [ 81%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape2_cuda SKIPPED [0.0087s] (...) [ 84%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape3_cuda SKIPPED [0.0084s] (...) [ 87%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape0_cuda SKIPPED [0.0087s] (T...) [ 90%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape1_cuda SKIPPED [0.0087s] (T...) [ 93%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape2_cuda SKIPPED [0.0084s] (T...) [ 96%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape3_cuda SKIPPED [0.0087s] (T...) [100%]

=================================================== 14 passed, 18 skipped, 77218 deselected in 39.72s ===================================================
```
```
$ pytest test_transformers.py -v -k "test_causal_variants"
================================================================== test session starts ==================================================================
platform linux -- Python 3.10.13, pytest-7.4.0, pluggy-1.0.0 -- /home/shuqiyang/.conda/envs/pytorch/bin/python
cachedir: .pytest_cache
rootdir: /data/users/shuqiyang/pytorch
configfile: pytest.ini
collected 77250 items / 77218 deselected / 32 selected
Running 32 items in this shard

test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cpu PASSED [0.2410s]                  [  3%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape1_cpu PASSED [0.3984s]                  [  6%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape2_cpu SKIPPED [0.0011s] (Lower righ...) [  9%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape3_cpu PASSED [0.0095s]                  [ 12%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape0_cpu PASSED [0.1749s]                   [ 15%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape1_cpu PASSED [0.2138s]                   [ 18%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape2_cpu PASSED [0.2715s]                   [ 21%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape3_cpu PASSED [0.0108s]                   [ 25%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cpu PASSED [0.4864s]          [ 28%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape1_cpu PASSED [0.5346s]          [ 31%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape2_cpu SKIPPED [0.0011s] (Lo...) [ 34%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape3_cpu PASSED [0.1722s]          [ 37%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape0_cpu PASSED [0.2341s]           [ 40%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape1_cpu PASSED [0.4786s]           [ 43%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape2_cpu PASSED [0.4635s]           [ 46%]
test_transformers.py::TestAttnBiasCPU::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape3_cpu PASSED [0.0861s]           [ 50%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda PASSED [9.7579s]                [ 53%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape1_cuda PASSED [0.0044s]                [ 56%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape2_cuda SKIPPED [0.0007s] (Lower ri...) [ 59%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape3_cuda PASSED [9.2065s]                [ 62%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape0_cuda PASSED [0.0081s]                 [ 65%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape1_cuda PASSED [0.0063s]                 [ 68%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape2_cuda PASSED [0.0059s]                 [ 71%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape3_cuda PASSED [0.0055s]                 [ 75%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda PASSED [0.1200s]        [ 78%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape1_cuda PASSED [0.1032s]        [ 81%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape2_cuda SKIPPED [0.0010s] (...) [ 84%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape3_cuda PASSED [0.1151s]        [ 87%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape0_cuda PASSED [0.0705s]         [ 90%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape1_cuda PASSED [0.0713s]         [ 93%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape2_cuda PASSED [0.0696s]         [ 96%]
test_transformers.py::TestAttnBiasCUDA::test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape3_cuda PASSED [0.1516s]         [100%]

=================================================== 28 passed, 4 skipped, 77218 deselected in 39.23s ====================================================
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121048
Approved by: https://github.com/zou3519
2024-03-05 20:19:02 +00:00
drisspg
2e6c08a14b Update flash_attention kernel from 2.3.6 to 2.5.5 (#118935)
# Summary
Updates FlashAttention kernel code from tag [2.3.6](https://github.com/Dao-AILab/flash-attention/releases/tag/v2.3.6) to [2.5.5](https://github.com/Dao-AILab/flash-attention/releases/tag/v2.5.5).

The usual changes were then re-rolled on top of the updated kernel: changing how dropout state is saved for backward, and removing the head_dim_pad, since it would make the kernel mutate in place, which interacts badly with functionalization.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118935
Approved by: https://github.com/cpuhrsch
2024-03-04 17:36:22 +00:00
PyTorch MergeBot
1458f1de66 Revert "Update flash_attention kernel from 2.3.6 to 2.5.5 (#118935)"
This reverts commit 4b7a521856.

Reverted https://github.com/pytorch/pytorch/pull/118935 on behalf of https://github.com/atalman due to Significantly increases build time. Optimization is needed ([comment](https://github.com/pytorch/pytorch/pull/118935#issuecomment-1971723284))
2024-02-29 18:42:21 +00:00
drisspg
4b7a521856 Update flash_attention kernel from 2.3.6 to 2.5.5 (#118935)
# Summary
Updates FlashAttention kernel code from tag [2.3.6](https://github.com/Dao-AILab/flash-attention/releases/tag/v2.3.6) to [2.5.5](https://github.com/Dao-AILab/flash-attention/releases/tag/v2.5.5).

The usual changes were then re-rolled on top of the updated kernel: changing how dropout state is saved for backward, and removing the head_dim_pad, since it would make the kernel mutate in place, which interacts badly with functionalization.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118935
Approved by: https://github.com/cpuhrsch
2024-02-28 19:31:15 +00:00
Eddie Yan
702e82da28 [cuDNN][Flash Attention] Minor cleanup for cuDNN SDPA (#120750)
Cleaning up before hopefully starting work on backward

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120750
Approved by: https://github.com/Skylion007, https://github.com/drisspg
2024-02-28 17:32:07 +00:00
Eddie Yan
cd380c794f [CUDNN][SDPA] Experimental cuDNN Flash Attention v2 Inference (#115663)
#113713

Going to clean up some of the checks and will remove draft status after.
Can be tested on SM80+ with `TORCH_CUDNN_MHA_ENABLED=1`.

CC @drisspg @ptrblck
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115663
Approved by: https://github.com/drisspg
2024-02-14 22:02:06 +00:00
atalman
244b124bb8 Add linux cpu test for 3.12 (#117853)
This is continuation of work: https://github.com/pytorch/pytorch/pull/113987

Co-authored-by: albanD <desmaison.alban@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117853
Approved by: https://github.com/albanD
2024-02-14 20:52:23 +00:00
CaoE
dfdbd73360 add Half support for flash attention (#119247)
Re-open for https://github.com/pytorch/pytorch/pull/118368.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119247
Approved by: https://github.com/drisspg, https://github.com/malfet
2024-02-07 05:57:41 +00:00
CK Luk
2ad3599a71 Add torch.backends.mha.get_fastpath_enabled to FUNC_INLINELIST (#118979)
Summary: Add torch.backends.mha.get_fastpath_enabled to FUNC_INLINELIST

Test Plan: See the one in D53154041
Reviewed By: yjhao, yanboliang, Yuzhen11

Differential Revision: D53154041

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118979
Approved by: https://github.com/yanboliang
2024-02-06 16:25:33 +00:00
Catherine Lee
f481835115 Revert "add Half support for flash attention on CPU (#118368)" (#119204)
This reverts commit a5a63db3bf.

Fixes #ISSUE_NUMBER

Reverts #118368

Got reverted internally, but the branch was deleted, so the automation didn't work

Mildly edited stack trace
```

...
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
  File "torch/_dynamo/eval_frame.py", line 453, in _fn
    return fn(*args, **kwargs)
  File "torch/_dynamo/external_utils.py", line 25, in inner
    return fn(*args, **kwargs)
  File "torch/fx/experimental/proxy_tensor.py", line 635, in dispatch_trace
    graph = tracer.trace(root, concrete_args)
  File "torch/fx/experimental/proxy_tensor.py", line 995, in trace
    res = super().trace(root, concrete_args)
  File "torch/_dynamo/eval_frame.py", line 453, in _fn
    return fn(*args, **kwargs)
  File "torch/_dynamo/external_utils.py", line 25, in inner
    return fn(*args, **kwargs)
  File "torch/fx/_symbolic_trace.py", line 793, in trace
    (self.create_arg(fn(*args)),),
  File "torch/fx/experimental/proxy_tensor.py", line 665, in wrapped
    out = f(*tensors)
  File "<string>", line 1, in <lambda>
  File "torch/_functorch/_aot_autograd/traced_function_transforms.py", line 357, in _functionalized_f_helper
    f_outs = fn(*f_args)
  File "torch/_functorch/_aot_autograd/traced_function_transforms.py", line 68, in inner_fn
    outs = fn(*args)
  File "torch/_functorch/_aot_autograd/utils.py", line 161, in flat_fn
    tree_out = fn(*args, **kwargs)
  File "torch/_functorch/_aot_autograd/traced_function_transforms.py", line 618, in functional_call
    out = PropagateUnbackedSymInts(mod).run(
  File "torch/fx/interpreter.py", line 145, in run
    self.env[node] = self.run_node(node)
  File "torch/_functorch/_aot_autograd/traced_function_transforms.py", line 593, in run_node
    result = super().run_node(n)
  File "torch/fx/interpreter.py", line 202, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "torch/fx/interpreter.py", line 274, in call_function
    return target(*args, **kwargs)
  File "torch/_ops.py", line 571, in __call__
    return self_._op(*args, **kwargs)
  File "torch/_subclasses/functional_tensor.py", line 380, in __torch_dispatch__
    outs_unwrapped = func._op_dk(
  File "torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "torch/fx/experimental/proxy_tensor.py", line 744, in __torch_dispatch__
    return self.inner_torch_dispatch(func, types, args, kwargs)
  File "torch/fx/experimental/proxy_tensor.py", line 779, in inner_torch_dispatch
    return proxy_call(self, func, self.pre_dispatch, args, kwargs)
  File "torch/fx/experimental/proxy_tensor.py", line 423, in proxy_call
    r = maybe_handle_decomp(proxy_mode, func, args, kwargs)
  File "torch/fx/experimental/proxy_tensor.py", line 1225, in maybe_handle_decomp
    return CURRENT_DECOMPOSITION_TABLE[op](*args, **kwargs)
  File "torch/_decomp/decompositions.py", line 4322, in scaled_dot_product_flash_attention_for_cpu
    torch._check(
  File "torch/__init__.py", line 1133, in _check
    _check_with(RuntimeError, cond, message)
  File "torch/__init__.py", line 1116, in _check_with
    raise error_type(message_evaluated)
RuntimeError: query must be FP32, FP64, BF16 but got torch.float16

While executing %_scaled_dot_product_flash_attention_for_cpu : [num_users=1] = call_function[target=torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default](args = (%l_q_, %l_k_, %l_v_), kwargs = {attn_mask: %l_attn_mask_})
Original traceback:
  File "executorch/backends/xnnpack/partition/graphs/sdpa.py", line 34, in forward
    return torch.nn.functional.scaled_dot_product_attention(
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119204
Approved by: https://github.com/kit1980
2024-02-05 18:24:53 +00:00
CaoE
a5a63db3bf add Half support for flash attention on CPU (#118368)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118368
Approved by: https://github.com/jgong5, https://github.com/Valentine233, https://github.com/drisspg
ghstack dependencies: #118367
2024-02-02 01:08:39 +00:00
drisspg
126c1621ce Add Support for CausalBias to torch compile (#116071)
Fixes #115363

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116071
Approved by: https://github.com/mlazos
2024-01-30 02:22:48 +00:00
Wei Wang
80cb6db90d [CUDA] [CI] Disable flash attention for sm87 architecture when the head dim > 192 (#117678)
Head dim > 192 requires A100/H100 (sm80 or sm90) per TORCH_CHECK [here](0c26565d5d/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp (L760)).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117678
Approved by: https://github.com/eqy, https://github.com/malfet
2024-01-27 01:22:47 +00:00
drisspg
4e29f01bf2 Remove sdp_kernel and replace with sdpa_kernel in attention namespace (#114689)
# Summary
Simplification of Backend Selection

This PR deprecates the `torch.backends.cuda.sdp_kernel` context manager and replaces it with a new context manager, `torch.nn.attention.sdpa_kernel`, which also changes the API.

With `sdp_kernel`, one specified a backend choice by negating the kernels one did *not* want to run. The purpose of this backend manager was only to be a debugging tool: "turn off the math backend" and see if you can run one of the fused implementations.

Problems:
- This pattern makes sense if the majority of users don't care to know anything about the backends that can be run. However, users reaching for this context manager are explicitly trying to run a specific backend.
- This is not scalable. We are working on adding the cuDNN backend, and with this API more implementations have to be turned off whenever the user wants to explicitly run a given backend.
- Discoverability of the current context manager. It is somewhat unintuitive that this backend manager lives in `torch.backends.cuda` when it now also controls the CPU fused-kernel behavior. I think centralizing it in the attention namespace will be helpful.

Other concerns:
- Typically, backends (kernels) for operators are entirely hidden from users and are implementation details of the framework. We have already exposed this to users, albeit not by default and with beta warnings. Does making backend choices even more explicit lead to problems when we eventually want to remove existing backends (perhaps input shapes will get covered by newer backends)?

A nice side effect: now that we aren't using the `BACKEND_MAP` in test_transformers, many, many dynamo failures are passing for CPU tests.
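
A minimal before/after sketch of the API change (assumes a CUDA device with a fused backend available):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)

# Old (deprecated): express the choice by disabling every other backend.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v)

# New: name the backend(s) you actually want to run.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)
```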

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114689
Approved by: https://github.com/cpuhrsch
2024-01-24 22:28:04 +00:00
PyTorch MergeBot
2f84a9d37c Revert "[CUDNN][SDPA] Experimental cuDNN Flash Attention v2 Inference (#115663)"
This reverts commit 5aa92b5090.

Reverted https://github.com/pytorch/pytorch/pull/115663 on behalf of https://github.com/PaliC due to Unfortunately, this pr breaks cuda builds internally ([comment](https://github.com/pytorch/pytorch/pull/115663#issuecomment-1899388813))
2024-01-18 23:40:30 +00:00
Eddie Yan
5aa92b5090 [CUDNN][SDPA] Experimental cuDNN Flash Attention v2 Inference (#115663)
#113713

Going to clean up some of the checks and will remove draft status after.
Can be tested on SM80+ with `TORCH_CUDNN_MHA_ENABLED=1`.

CC @drisspg @ptrblck
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115663
Approved by: https://github.com/drisspg
2024-01-18 01:20:36 +00:00
Sun, Jiayi
d9b265adaf modify the conditions as PythonModuleVariable (#116856)
## Motivation
The current code `value in [torch.backends.cudnn, torch.ops]` requires `value` to implement `__eq__`. If the value is a custom object that does not implement `__eq__`, dynamo throws an error. For example, for ConvolutionOpContext, the custom 'torch._C.ScriptClass' object registered in IPEX, dynamo throws the following error:

**torch._dynamo.exc.InternalTorchDynamoError: '__eq__' is not implemented for __torch__.torch.classes.ipex_prepack.ConvolutionOpContext**

I think this is a common issue. To avoid it, this PR replaces the current code `value in [torch.backends.cudnn, torch.ops]` with `isinstance(value, (torch.backends.cudnn.CudnnModule, torch._ops._Ops))`.
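
A toy illustration of the difference (the `ScriptClassLike` class below is hypothetical, standing in for an extension object whose `__eq__` raises):

```python
class ScriptClassLike:
    def __eq__(self, other):
        raise RuntimeError("'__eq__' is not implemented")

value = ScriptClassLike()
candidates = [object(), object()]

try:
    _ = value in candidates  # membership falls back to value.__eq__ and raises
except RuntimeError as e:
    print(e)

print(isinstance(value, ScriptClassLike))  # True -- no __eq__ call involved
```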

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116856
Approved by: https://github.com/jansel
2024-01-15 11:10:57 +00:00
drisspg
19e93b85b9 Fixes last_dim stride check for singleton dimensions (#117001)
Fixes #116333

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117001
Approved by: https://github.com/cpuhrsch
2024-01-10 04:46:49 +00:00
Valentine233
20c2ec9a15 [CPU] Add flash attention mask version (#115913)
Add a masked-version flash attention for CPU.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115913
Approved by: https://github.com/jgong5, https://github.com/drisspg
2024-01-07 04:58:23 +00:00
PyTorch MergeBot
2ccc7af028 Revert "[CPU] Add flash attention mask version (#115913)"
This reverts commit 76a3fbb709.

Reverted https://github.com/pytorch/pytorch/pull/115913 on behalf of https://github.com/zou3519 due to broke transformer test on dynamo shard ([comment](https://github.com/pytorch/pytorch/pull/115913#issuecomment-1878043389))
2024-01-05 02:39:12 +00:00
Valentine233
76a3fbb709 [CPU] Add flash attention mask version (#115913)
Add a masked-version flash attention for CPU.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115913
Approved by: https://github.com/jgong5, https://github.com/drisspg
2024-01-05 01:27:36 +00:00
Mikayla Gawarecki
d0cf2182ea Fix TransformerEncoderLayer for bias=False (#116760)
Fixes https://github.com/pytorch/pytorch/issues/116385

Don't call `torch._transformer_encoder_layer_fwd` when `bias=False`

`bias=False` was not something that `torch._transformer_encoder_layer_fwd` was meant to work with; it was my oversight that this wasn't tested when I approved https://github.com/pytorch/pytorch/pull/101687.

`bias=False` was causing the `tensor_args` in [`TransformerEncoder`](a17de2d645/torch/nn/modules/transformer.py (L663-L677)) to contain `None`s and error on checks for the fastpath like `t.requires_grad for t in tensor_args`.

Alternative fixes would be to:
1) Pass `torch.zeros_like({*}.weight)` to the kernel when `bias=False` and filter `tensor_args` as appropriate
2) Fix `torch._transformer_encoder_layer_fwd` to take `Optional<Tensor>` for biases and fix the kernels as appropriate

Let me know if these approaches are preferable
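
A minimal repro sketch of the configuration in question (sizes are illustrative); with this fix the layer should simply take the regular, non-fused path:

```python
import torch

layer = torch.nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True, bias=False)
layer.eval()
with torch.no_grad():
    out = layer(torch.randn(2, 10, 64))  # previously errored on the fastpath checks
print(out.shape)  # torch.Size([2, 10, 64])
```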

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116760
Approved by: https://github.com/jbschlosser
2024-01-05 00:13:10 +00:00
Xinya Zhang
e3ca7346ce Re-add initial Flash Attention support on ROCM (#115981)
Note about the Updates:

This PR:
1. skips more flash attention related UTs on MI200
2. Fix additional ATen compilation errors after hipification
3. Fix the author "root" of a specific commit
4. Includes the patch from Nikita in favor of block level static initialization.

CAVEAT: This revised PR has a commit that modifies the CI to force its running on MI200 nodes. That specific commit must be reverted before merge.

Original PR (https://github.com/pytorch/pytorch/pull/114309) Note:

This pull requests add initial Flash Attention support for AMD/ROCM platform. It added a specialized Triton repository/branch as a compile-time dependency for Flash Attention math library on AMD/ROCM. This triton submodule is not used at runtime and will not be shipped to the final pytorch package. We have the plan to release this specialized Triton as a separate project.

Known limitations:

- Only supports MI200 series GPUs (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
- Only supports power of two sequence lengths.
- No support for varlen APIs.
- Only support head dimension 16,32,64,128.
- Performance is still being optimized.

Fixes #112997

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115981
Approved by: https://github.com/malfet
2024-01-04 22:21:31 +00:00
Mikayla Gawarecki
0f6f582c0d Add config to disable TransformerEncoder/MHA fastpath (#112212)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112212
Approved by: https://github.com/jbschlosser
2024-01-02 23:59:30 +00:00
Aaron Gokaslan
bd10fea79a [BE]: Enable F821 and fix bugs (#116579)
Fixes #112371

I tried to fix as many of the bugs as I could; for a few I could not figure out the proper fix, so I left them with noqas.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116579
Approved by: https://github.com/ezyang
2024-01-01 08:40:46 +00:00
drisspg
1e834e0e50 Fix bug in mem_eff kernel with attention mask and MQA (#116234)
# Summary

Found using the repros mentioned in this issue: #112577

After many go-rounds with compute-sanitizer and eventual printf debugging, I feel pretty confident that this was the underlying issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116234
Approved by: https://github.com/malfet, https://github.com/danthe3rd, https://github.com/atalman
2023-12-21 21:52:21 +00:00
drisspg
65d3dde665 Fix allowed dtypes for mem_eff attention (#116026)
# Summary

Fix a bug in detecting mem-eff capability for CUDA devices older than sm80:
https://github.com/pytorch-labs/gpt-fast/issues/49

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116026
Approved by: https://github.com/janeyx99
2023-12-21 01:56:38 +00:00
PyTorch MergeBot
af8a50e656 Revert "Fix allowed dtypes for mem_eff attention (#116026)"
This reverts commit fc58909bab.

Reverted https://github.com/pytorch/pytorch/pull/116026 on behalf of https://github.com/jeanschmidt due to breaking internal windows buck builds, check internal diff for more details ([comment](https://github.com/pytorch/pytorch/pull/116026#issuecomment-1864354665))
2023-12-20 12:01:34 +00:00
drisspg
fc58909bab Fix allowed dtypes for mem_eff attention (#116026)
# Summary

Fix a bug in detecting mem-eff capability for CUDA devices older than sm80:
https://github.com/pytorch-labs/gpt-fast/issues/49

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116026
Approved by: https://github.com/janeyx99
2023-12-18 23:20:52 +00:00
Jeff Daily
e3aefe2970 Revert "Initial Flash Attention support on ROCM (#114309)" (#115975)
This reverts commit 5bddbed399.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115975
Approved by: https://github.com/atalman, https://github.com/malfet
2023-12-16 03:40:14 +00:00
Xinya Zhang
5bddbed399
Initial Flash Attention support on ROCM (#114309)
This pull requests add initial Flash Attention support for AMD/ROCM platform. It added a specialized Triton repository/branch as a compile-time dependency for Flash Attention math library on AMD/ROCM. This triton submodule is not used at runtime and will not be shipped to the final pytorch package. We have the plan to release this specialized Triton as a separate project.

Known limitations:

- [ ] Only supports MI200 series GPUs (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
- [ ] Only supports power of two sequence lengths.
- [ ] No support for varlen APIs.
- [ ] Only support head dimension 16,32,64,128.
- [ ] Performance is still being optimized.

Fixes https://github.com/pytorch/pytorch/issues/112997

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114309

Approved by: https://github.com/jeffdaily, https://github.com/malfet

---------

Co-authored-by: Joseph Groenenboom <joseph.groenenboom@amd.com>
2023-12-14 08:52:57 -08:00
Fuzzkatt
661c1cf2aa numerical mismatch fix for test_mem_efficient_attention_attn_mask_vs_math_ref_grads in test_transformers.py (#115707)
Adjust the dropout_fudge_factor, since the previous fudge factor was too small and led to numerical mismatches in NVIDIA internal CI.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115707
Approved by: https://github.com/drisspg
2023-12-14 01:04:39 +00:00
Valentine233
064846dbc2 [cpu] flash attention optimization (#115151)
### Modifications
- **EXP**: Add a fast, reduced-accuracy (ULP20) variant `exp_u20` to the vectorized exp and use it in flash attention.
- **FUSION**: Fuse the `softmax` ops.
- **SCALE**: Move the calculation of `scaling_factor` to after the `gemm` (see the sketch below).
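
A minimal sketch of the SCALE change only (the actual kernel operates on blocked/packed data); folding the softmax scale in after the gemm is mathematically equivalent to pre-scaling the query:

```python
import torch

q, k = torch.randn(8, 64), torch.randn(8, 64)
scale = q.shape[-1] ** -0.5

scores_before = (q * scale) @ k.T  # scale applied to q before the gemm
scores_after = (q @ k.T) * scale   # scale folded in after the gemm
assert torch.allclose(scores_before, scores_after, atol=1e-6)
```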

### Performance
_Model: Stable Diffusion V2.1_

| Version | BF16 Kernel latency (s) | BF16 speedup | FP32 Kernel latency (s) | FP32 speedup |
| ----- | ----- | ----- | ----- | ----- |
| PT | 15.865 |  | 35.362 |  |
| PT + EXP | 12.518 | 21.10% | 19.327 | 45.35% |
| PT + EXP + FUSION | 11.774 | 25.79% | 18.306 | 48.23% |
| PT + EXP + FUSION + SCALE | 11.053 | 30.33% | 18.360 | 48.08% |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115151
Approved by: https://github.com/jgong5, https://github.com/drisspg
2023-12-12 01:09:55 +00:00
drisspg
d4c79a3078 Add an attention bias subclass for a lower right causal masking (#114823)
# Summary
This PR introduces a new Tensor subclass that is designed to be used with torch.nn.functional.scaled_dot_product_attention. Currently we have a boolean `is_causal` flag that allows users to do causal masking without needing to actually create the "realized" attention bias and pass it into sdpa. We originally added this flag since there is native support in both fused kernels we support. This provides a big performance gain (the kernels only need to iterate over ~0.5x the sequence), and for very large sequence lengths it can also provide very large memory improvements.

The flag was introduced early on in the kernel development, and at the time it implicitly meant "upper_left" causal attention. This distinction only matters when the attention_bias is not square. For a more detailed breakdown see: https://github.com/pytorch/pytorch/issues/108108. The kernels' default behavior has since changed, largely due to the rise of autoregressive text generation. Unfortunately, this would lead to a BC break. In the long term it may actually be beneficial to change the default meaning of `is_causal` to represent lower_right causal masking.

The larger theme is laid out here: https://github.com/pytorch/pytorch/issues/110681. The thesis is that there is a lot of innovation in SDPA revolving around the attention_bias being used. This is the first of hopefully a few more attention biases that we would like to add. The next interesting one would be `sliding_window`, which is used by the popular Mistral model family.
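
A hedged usage sketch of the subclass with SDPA (module path and factory name as they appear in later releases; shapes are illustrative):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention.bias import causal_lower_right

q = torch.randn(1, 8, 4, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 12, 64, device="cuda", dtype=torch.float16)
v = torch.randn_like(k)

# Non-square case (q_len != kv_len) where upper-left vs lower-right alignment differs.
attn_bias = causal_lower_right(q.size(-2), k.size(-2))
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
```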

Results from benchmarking: I improved the meff_attention perf, hence the slightly decreased max speedup.
```Shell
+---------+--------------------+------------+-----------+-----------+-----------+-----------+----------------+----------+
|  Type   |      Speedup       | batch_size | num_heads | q_seq_len | k_seq_len | embed_dim |     dtype      | head_dim |
+---------+--------------------+------------+-----------+-----------+-----------+-----------+----------------+----------+
| Average | 1.2388050062214226 |            |           |           |           |           |                |          |
|   Max   | 1.831672915579016  |    128     |    32     |   1024    |   2048    |   2048    | torch.bfloat16 |    64    |
|   Min   | 0.9430534166730135 |     1      |    16     |    256    |    416    |   2048    | torch.bfloat16 |   128    |
+---------+--------------------+------------+-----------+-----------+-----------+-----------+----------------+----------+
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114823
Approved by: https://github.com/cpuhrsch
2023-12-06 08:29:26 +00:00
drisspg
8556a09d44 Require less alignment for attn bias (#114173)
# Summary
Improved Fix for Attention Mask Alignment Issue (#112577)

This PR addresses Issue #112577 by refining the previously implemented fix, which was found to be incorrect and caused unneeded memory regressions. The update simplifies the approach to handling the alignment of the attention mask for mem-eff attention.

## Changes
Alignment Check and Padding: Initially, the alignment of the attention mask is checked. If misalignment is detected, padding is applied, followed by slicing. During this process, a warning is raised to alert users.

Should this be warn_once?

We only call expand once, on the aligned mask.

Reference
https://github.com/facebookresearch/xformers/blob/main/xformers/ops/fmha/cutlass.py#L115
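
In the spirit of the xformers reference above, a conceptual pad-then-slice sketch (the alignment value is an assumption for illustration, not the kernel's actual requirement):

```python
import torch
import torch.nn.functional as F

def align_last_dim(attn_mask: torch.Tensor, alignment: int = 8) -> torch.Tensor:
    # Pad the last dim up to a multiple of `alignment`, then slice back so the
    # logical shape is unchanged but the underlying storage/stride is aligned.
    pad = (-attn_mask.size(-1)) % alignment
    if pad == 0:
        return attn_mask
    return F.pad(attn_mask, (0, pad))[..., : attn_mask.size(-1)]
```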

@albanD, @mruberry, @jbschlosser, @walterddr, and @mikaylagawarecki.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114173
Approved by: https://github.com/danthe3rd
2023-11-28 02:40:41 +00:00
PyTorch MergeBot
88a8a0daa4 Revert "Require less alignment for masking (#114173)"
This reverts commit f882c175d8.

Reverted https://github.com/pytorch/pytorch/pull/114173 on behalf of https://github.com/huydhn due to Sorry for reverting you change, but it is failing some inductor tests f882c175d8 ([comment](https://github.com/pytorch/pytorch/pull/114173#issuecomment-1823552362))
2023-11-22 21:49:31 +00:00
drisspg
f882c175d8 Require less alignment for masking (#114173)
# Summary
Improved Fix for Attention Mask Alignment Issue (#112577)

This PR addresses Issue #112577 by refining the previously implemented fix, which was found to be incorrect and caused unneeded memory regressions. The update simplifies the approach to handling the alignment of the attention mask for mem-eff attention.

## Changes
Alignment Check and Padding: Initially, the alignment of the attention mask is checked. If misalignment is detected, padding is applied, followed by slicing. During this process, a warning is raised to alert users.

Should this be warn_once?

We only call expand once, on the aligned mask.

Reference
https://github.com/facebookresearch/xformers/blob/main/xformers/ops/fmha/cutlass.py#L115

@albanD, @mruberry, @jbschlosser, @walterddr, and @mikaylagawarecki.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114173
Approved by: https://github.com/danthe3rd
2023-11-22 20:02:51 +00:00
drisspg
9b0f2f8d94 expose sdpa helpers to python (#110496)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110496
Approved by: https://github.com/jbschlosser
2023-11-15 07:34:34 +00:00
drisspg
14811d69d7 [BE] Cleanup sdpa test helper usage (#113294)
# Summary

standardizes usage of the rand_sdpa_tensor helper

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113294
Approved by: https://github.com/soulitzer
2023-11-09 01:16:53 +00:00
drisspg
e509b162ed Disable FlashAttenion for is_causal=True when seqlen q not equal kv (#111007)
# Summary:
This pull request **removes** support for non-square sequence lengths in causal attention when using FlashAttention V2.

### Why are we doing this
  // FlashAttention 2 updated the default mask meaning for causal in this PR:
  // 9e5e8bc91e it is now aligned to lower_right which would be a BC break
  // for non-square masks. We will not support non-square masks for causal w/ FAV2

 For more context see:
 https://github.com/pytorch/pytorch/issues/108108
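
A quick illustration of the two alignments the quoted comment refers to, for a non-square mask (pure PyTorch; sizes are illustrative):

```python
import torch

L, S = 3, 5  # query length != key/value length
upper_left = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
lower_right = torch.ones(L, S, dtype=torch.bool).tril(diagonal=S - L)
print(upper_left.int())   # causal mask anchored to the top-left corner
print(lower_right.int())  # causal mask anchored to the bottom-right corner
```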

 ### Followup
 A large number of people will likely want to use FAV2 with lower_right causal attention for non-equal sequence lengths. See this RFC: https://github.com/pytorch/pytorch/issues/110681

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111007
Approved by: https://github.com/cpuhrsch
2023-10-23 20:33:37 +00:00
drisspg
5183760ca5 Adding Backward Support for NestedTensors and FlashAttention (#97485)
# Summary
<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at 318764f</samp>

This pull request implements the CUDA backend of the SDPA kernel for nested tensors, which enables efficient transformer models with variable-length sequences. It adds a new dispatch key, a backward function, a unit test, and some helper functions for the kernel. It modifies `test/test_transformers.py`, `aten/src/ATen/native/native_functions.yaml`, `aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctionsBackward.cpp`, and `aten/src/ATen/native/nested/cuda/NestedTensorTransformerUtils.h`.

<!--
copilot:poem
-->
### <samp>🤖 Generated by Copilot at ed4a773</samp>

> _Fused kernels of doom, unleash the flash attention_
> _Nested tensors on fire, reshape and pad with caution_
> _Backward pass of power, dispatch the CUDA key_
> _Test the gradients of hell, warn the user if they disagree_

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97485
Approved by: https://github.com/jbschlosser
2023-10-10 18:08:17 +00:00
Fuzzkatt
c28bb46445 Fix test_mem_efficient_attention_vs_math_ref_grads tolerance from test_transformers.py (#108094)
The tolerance is currently too tight, triggering test failures via numerical mismatches in NVIDIA internal testing for certain H100, A16, and A40 configs. cc: @ptrblck @eqy

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108094
Approved by: https://github.com/eqy, https://github.com/msaroufim
2023-10-02 20:42:57 +00:00
PyTorch MergeBot
8d6479725a Revert "Adding Backward Support for NestedTensors and FlashAttention (#97485)"
This reverts commit 28d69d5256.

Reverted https://github.com/pytorch/pytorch/pull/97485 on behalf of https://github.com/huydhn due to Sorry for reverting you change, but one of the tests test_fused_kernels_nested_broadcasting_requires_grad_failure_cuda is failing on Windows CUDA f7ba3e85e2 ([comment](https://github.com/pytorch/pytorch/pull/97485#issuecomment-1743474468))
2023-10-02 17:48:57 +00:00
drisspg
28d69d5256 Adding Backward Support for NestedTensors and FlashAttention (#97485)
# Summary
<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at 318764f</samp>

This pull request implements the CUDA backend of the SDPA kernel for nested tensors, which enables efficient transformer models with variable-length sequences. It adds a new dispatch key, a backward function, a unit test, and some helper functions for the kernel. It modifies `test/test_transformers.py`, `aten/src/ATen/native/native_functions.yaml`, `aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctionsBackward.cpp`, and `aten/src/ATen/native/nested/cuda/NestedTensorTransformerUtils.h`.

<!--
copilot:poem
-->
### <samp>🤖 Generated by Copilot at ed4a773</samp>

> _Fused kernels of doom, unleash the flash attention_
> _Nested tensors on fire, reshape and pad with caution_
> _Backward pass of power, dispatch the CUDA key_
> _Test the gradients of hell, warn the user if they disagree_

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97485
Approved by: https://github.com/jbschlosser
2023-09-29 21:34:47 +00:00
drisspg
ad90ab31f2 Flash Attention v2 (#105602)
# Summary
## PR Dependencies
I don't use ghstack :( and this is a PR where it would have been helpful. That being said, I am going to peel off some PRs to make reviewing this easier:
- [x] Separate build flags for Flash and MemEff: #107985

### Description
This pull request updates the version of _scaled_dot_product_flash_attention from version 1 to version 2. The changes are based on the flash attention code originally authored by @tridao

### Changes Made
The majority of the changes in this pull request involve:

- Copying over the flash_attention sources.
- Updating header files.
- Removing padding and slicing code from within the flash_attention kernel and relocating it to the composite implicit region of the SDPA. This was needed to make the kernel functional and appease autograd.
- Introducing a simple kernel generator to generate different instantiations of the forward and backward flash templates.
- Adding conditional compilation (ifdef) to prevent building when nvcc is invoked with gencode < sm80.
- Introducing a separate dependent option for mem_eff_attention, as flash_attention v2 lacks support for Windows and cannot be built for sm50 generation codes.
- Modifying build.sh to reduce parallelization on sm86 runners and to lower the maximum parallelization on the manywheel builds. This adjustment was made to address out-of-memory issues during the compilation of FlashAttentionV2 sources.
- Adding/Updating tests.

### Notes for Reviewers
This is not a fun review, and I apologize in advance.
Most of the files-changed are in the flash_attn/ folder. The only files of interest here IMO:
- aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp
- aten/src/ATen/native/transformers/cuda/flash_attn/kernels/generate_kernels.py ( this has been incorporated upstream to flash-attention github)

There are a number of files all related to avoiding OOMs in CI/CD. These are typically shell scripts.

### Follow up items
- Include the updates from e07aa036db and 9e5e8bc91e | https://github.com/pytorch/pytorch/issues/108108

### Work Items
- [x] I don't think Windows will be supported for 3.1.0 - Need to update cmakee
- [x] Let multi_query/attention pass through and test | UPDATE: I have the fast path implemented here: https://github.com/pytorch/pytorch/pull/106730, but since this will require changing the math backend's semantics to call repeat_interleave (see the sketch after this list), I think this should be done as a follow-up.
- [x] Had to drop cutlass back to 3.0.0 to get it to compile. Need to figure out how to upgrade to 3.1.0 and later. Spoke with Tri and he is going to be taking a look. Note: compiling with clang currently errors for the cute headers.
- [x] Update test exercise above codepath
- [x] Still need to disable on seq_len % 128 != 0 for backward (Tri beat me to it: a4f148b6ab)
- [x] Add determinism warning to BWD, Tri got to this one as well: 1c41d2b
- [x] Update dispatcher to universally prefer FlashV2
- [x] Update tests to exercise new head_dims
- [x] Move the head_dim padding from kernel to top level composite implicit function in order to make it purely functional
- [x] Create template generator script
- [x] Initial cmake support for building kernels/ folder
- [x] Replay CudaGraph changes
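
As noted in the multi_query work item above, a minimal sketch of emulating the pass-through in the math path by repeating kv heads (shapes are illustrative, not the actual implementation):

```python
import torch
import torch.nn.functional as F

# MQA/GQA-style inputs: 8 query heads sharing 2 kv heads.
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 2, 16, 64)
v = torch.randn(1, 2, 16, 64)

# Repeat kv heads up front so every query head has a matching kv head.
k = k.repeat_interleave(q.size(1) // k.size(1), dim=1)
v = v.repeat_interleave(q.size(1) // v.size(1), dim=1)
out = F.scaled_dot_product_attention(q, k, v)
```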

### Results
#### Forward only
The TFLOPs reported here are from an underclocked A100.
![flashv2_tflops_vs_seq_len](https://github.com/pytorch/pytorch/assets/32754868/152de46d-8fa6-42f0-9a9c-ef1eb7ae29e7)

#### Forward+Backward
Ran a sweep and for large compute bound sizes we do see a ~2x performance increase for forw+back.
<img width="1684" alt="Screenshot 2023-07-20 at 3 47 47 PM" src="https://github.com/pytorch/pytorch/assets/32754868/fdd26e07-0077-4878-a417-f3a418b6fb3b">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105602
Approved by: https://github.com/huydhn, https://github.com/cpuhrsch
2023-09-13 13:59:05 +00:00
Huy Do
a9c663c269 Revert "Flash Attention v2 (#105602)" (#108827)
This reverts commit add45aea1c.

There are some conflicts on some benchmark csv file https://github.com/pytorch/pytorch/pull/105602#issuecomment-1710988951 so I need to revert this manually.

The diff has been reverted internally.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108827
Approved by: https://github.com/kit1980
2023-09-08 07:43:04 +00:00
PyTorch MergeBot
e45b290127 Revert "Revert "Flash Attention v2 (#105602)" (#108827)"
This reverts commit 24e9bbe22a.

Reverted https://github.com/pytorch/pytorch/pull/108827 on behalf of https://github.com/huydhn due to I need to land this revert properly as there are new failures showing up on trunk ([comment](https://github.com/pytorch/pytorch/pull/108827#issuecomment-1711020924))
2023-09-08 03:25:45 +00:00
Huy Do
24e9bbe22a Revert "Flash Attention v2 (#105602)" (#108827)
This reverts commit add45aea1c.

There are some conflicts on some benchmark csv file https://github.com/pytorch/pytorch/pull/105602#issuecomment-1710988951 so I need to revert this manually.

The diff has been reverted internally.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108827
Approved by: https://github.com/kit1980
2023-09-08 02:54:20 +00:00
Michael Gschwind
2a40fe2dbf [experimental] use EXCEPT_FOR env to suppress CPU tests from GPU RE (#108672)
Summary:
[experimental] use EXCEPT_FOR env to suppress CPU tests from GPU RE -- alternative implementation to D48997976 using preexisting PYTORCH_TESTING_DEVICE_EXCEPT_FOR facility and building remaining logic (for assert-positive listers like test_transformers)  on top of that.

Goal: save ~100 GPUs (10% of capacity), which enables us to fund more aggressive PyPer unit testing on GPU RE

Test Plan: sandcastle, github

Differential Revision: D48998582

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108672
Approved by: https://github.com/bertmaher
2023-09-06 23:33:18 +00:00