Commit Graph

232 Commits

Author SHA1 Message Date
cyy
df458be4e5 [4/N] Apply py39 ruff and pyupgrade fixes (#143257)
```torch/fx/passes/annotate_getitem_nodes.py``` was changed to support the new type hinting annotations.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143257
Approved by: https://github.com/justinchuby, https://github.com/albanD
2025-01-04 10:47:51 +00:00
Xiaodong Wang
0a94bb432e [ROCm] CK Flash Attention Backend (#143695)
Replace https://github.com/pytorch/pytorch/pull/138947 for re-import.

Replaces https://github.com/ROCm/pytorch/pull/1592

This PR contains the initial implementation of SDPA with composable_kernel backend. The CK path can be forced by simply calling torch.backends.cuda.preferred_rocm_fa_library("ck"). Similarly, you can force the incumbent aotriton implementation by passing in "aotriton" or "default". As you'd expect, not setting this option will result in aotriton to be used as the backend. In the case of CK, if pytorch deems flash attention usable, then it will use the CK path in all the same places aotriton would have been used. This PR makes no changes to the heuristics which select which attention scheme to use (i.e. flash attention vs memory efficient attention vs math etc etc). It only gets called when flash attention is both enabled (via USE_FLASH_ATTENTION) and is selected at runtime by the existing heuristics.

Files located in pytorch/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha* have been pulled from https://github.com/Dao-AILab/flash-attention courtesy of @tridao's hard work who is the co-author

NOTE: In order to use this backend, the user MUST set USE_CK_FLASH_ATTENTION=1 in their environment when they build PyTorch.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143695
Approved by: https://github.com/malfet

Co-authored-by: Andy Lugo <Andy.LugoReyes@amd.com>
Co-authored-by: Jithun Nair <jithun.nair@amd.com>
2025-01-03 22:01:36 +00:00
Tom Ritchford
f1cbf4b1b5 Enable ruff's unused variable checking everywhere in pytorch (#136965)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136965
Approved by: https://github.com/cyyever, https://github.com/albanD
2024-12-22 02:33:11 +00:00
Tom Ritchford
993b2f0ee0 Fix unused variables in test/test_transformers.py (#143407)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143407
Approved by: https://github.com/drisspg
2024-12-18 09:59:24 +00:00
PyTorch MergeBot
969b07b96f Revert "[ROCm] CK Flash Attention Backend (#138947)"
This reverts commit 500d02921b.

Reverted https://github.com/pytorch/pytorch/pull/138947 on behalf of https://github.com/atalman due to Breaks default windows checkout ([comment](https://github.com/pytorch/pytorch/pull/138947#issuecomment-2548998359))
2024-12-17 16:46:57 +00:00
Andy Lugo
500d02921b [ROCm] CK Flash Attention Backend (#138947)
Replaces https://github.com/ROCm/pytorch/pull/1592

This PR contains the initial implementation of SDPA with composable_kernel backend. The CK path can be forced by simply calling `torch.backends.cuda.preferred_rocm_fa_library("ck")`. Similarly, you can force the incumbent aotriton implementation by passing in "aotriton" or "default". As you'd expect, not setting this option will result in aotriton to be used as the backend. In the case of CK, if pytorch deems flash attention usable, then it will use the CK path in all the same places aotriton would have been used. This PR makes no changes to the heuristics which select which attention scheme to use (i.e. flash attention vs memory efficient attention vs math etc etc). It only gets called when flash attention is both enabled (via `USE_FLASH_ATTENTION`) and is selected at runtime by the existing heuristics.

Files located in pytorch/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha* have been pulled from https://github.com/Dao-AILab/flash-attention courtesy of @tridao's hard work who is the co-author

NOTE: In order to use this backend, the user MUST set USE_CK_FLASH_ATTENTION=1 in their environment when they build PyTorch.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138947
Approved by: https://github.com/pruthvistony, https://github.com/xw285cornell, https://github.com/leitian

Co-authored-by: Xiaodong Wang <xw285@cornell.edu>
2024-12-17 02:18:07 +00:00
Xinya Zhang
424156c26c [ROCm] Update to AOTriton 0.8b (#140172)
Notable new features for SDPA operators on AMD systems from AOTriton 0.8b:

1. Nestedtensor support;
2. MQA/GQA support;
3. Restore Efficient attention support for causal=True and seqlen_q != seqlen_k cases;
    + The kernel should use top-left alignment, bottom right alignment will be added later
4. Move gfx1100 (RX7900/W7800/W7900) out of experimental support status.
   However, users are strongly recommended to update to ROCM 6.2.4, notably for
   its firmware updates.

Related unit tests are enabled as well.

Notable related changes from AOTriton 0.8b:

1. AOTriton 0.8b moves the GPU kernel out of libaotriton.so to a separate directory `aotriton.images`;
2. LZMA replaces ZSTD as GPU kernel compression algorithm for better compression ratio: aotriton0.8b (.so + aotriton.images take 350MB) compared to aotriton0.7b .so: 800MB
3. The compression cannot be disabled now, and `liblzma` is hard run-time dependency.
    + Should not be a problem, since `lzma` is part of Python Standard Library

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140172
Approved by: https://github.com/jithunnair-amd, https://github.com/jeffdaily

Co-authored-by: Jithun Nair <37884920+jithunnair-amd@users.noreply.github.com>
2024-12-06 21:45:18 +00:00
eqy
8fc6d3a5d8 [SDPA] Allow user-specified priority order with context manager (#140467)
TODO: docs changes?
For better debuggability of issues like https://github.com/pytorch/pytorch/issues/139298

Better testing, current sketch:

``` Python
import torch
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(64, 1024, 8, 64, dtype=torch.half, device='cuda')
print(torch._C._get_sdp_priority_order())

orders = [[SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION],
          [SDPBackend.MATH, SDPBackend.CUDNN_ATTENTION, SDPBackend.EFFICIENT_ATTENTION],
          [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]]
import time
times = list()
for order in orders:
    print(order)
    with sdpa_kernel(order, set_priority=True):
        scaled_dot_product_attention(q, q, q)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    with sdpa_kernel(order, set_priority=True):
        scaled_dot_product_attention(q, q, q)
    torch.cuda.synchronize()
    t1 = time.perf_counter()
    times.append(t1 - t0)
print(times)
assert times[0] < times[1]
assert times[0] > times[2]
assert times[1] > times[2]
print(torch._C._get_sdp_priority_order())
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140467
Approved by: https://github.com/drisspg
2024-12-06 07:56:35 +00:00
drisspg
02b52572db Lint: switch oncall owner for test_transformers (#141722)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141722
Approved by: https://github.com/malfet
2024-11-27 21:45:43 +00:00
drisspg
7671dd436e [SDPA-CPU] Fix Edge case w/ fused flash cpu kernel (#141519)
Fixes https://github.com/pytorch/pytorch/issues/141128

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141519
Approved by: https://github.com/jgong5, https://github.com/jbschlosser
2024-11-26 22:07:56 +00:00
PyTorch MergeBot
ad37afd590 Revert "Always unspecialize float in OSS (#138922)"
This reverts commit ba5253da9b.

Reverted https://github.com/pytorch/pytorch/pull/138922 on behalf of https://github.com/yf225 due to perf regression on torchbench ([comment](https://github.com/pytorch/pytorch/pull/138922#issuecomment-2499277511))
2024-11-26 00:03:03 +00:00
Bob Ren
ba5253da9b Always unspecialize float in OSS (#138922)
Fixes https://github.com/pytorch/pytorch/issues/107277

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138922
Approved by: https://github.com/ezyang

Co-authored-by: Edward Z. Yang <ezyang@meta.com>
2024-11-24 01:58:13 +00:00
PyTorch MergeBot
a8c90e5140 Revert "Always unspecialize float in OSS (#138922)"
This reverts commit 6d779d0549.

Reverted https://github.com/pytorch/pytorch/pull/138922 on behalf of https://github.com/huydhn due to Sorry for reverting your change but there is some slow tests failing after this land ([comment](https://github.com/pytorch/pytorch/pull/138922#issuecomment-2495076878))
2024-11-22 23:18:36 +00:00
Bob Ren
6d779d0549 Always unspecialize float in OSS (#138922)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138922
Approved by: https://github.com/ezyang

Co-authored-by: Edward Z. Yang <ezyang@meta.com>
2024-11-22 17:54:42 +00:00
zeshengzong
cb71bcc542 Replace clone.detach with detach.clone (#140264)
Fixes #64532

As state in issue, replace `clone.detach` by `detach.clone`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140264
Approved by: https://github.com/soulitzer
2024-11-13 07:01:02 +00:00
Eddie Yan
1565eba4b4 [cuDNN][SDPA] Match query's memory layout ordering for output in cuDNN SDPA (#138354)
For #138340

~~We might consider more sophisticated logic here but the corresponding logic in other backends doesn't seem to do anything fancy for non BSHD/BHSD cases ea8ea2f33f/aten/src/ATen/native/transformers/cuda/attention.cu (L1145~~)

ended up going with a more general approach to much more or less arbitrary layouts

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138354
Approved by: https://github.com/drisspg
2024-11-04 23:49:09 +00:00
drisspg
80c7c7178e Make sure all SDPA tests are ran with tensor cores enabled (#135592)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135592
Approved by: https://github.com/eqy
2024-10-29 20:53:10 +00:00
eqy
07b0d633b8 [cuDNN][SDPA] Bail out of cuDNN SDPA for seqlen 1 inputs (#138531)
Forwarded #138529 to the cuDNN team but for now but we want to avoid dispatching to unsupported cases

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138531
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2024-10-29 01:03:36 +00:00
eqy
f7bb11dcc2 [cuDNN][cuDNN Frontend] Check in test for previously broken dBias check (#138725)
see https://github.com/pytorch/pytorch/issues/137347, let's try to land before https://github.com/pytorch/pytorch/pull/138709

CC @malfet @drisspg @Skylion007

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138725
Approved by: https://github.com/Skylion007, https://github.com/drisspg
2024-10-24 15:33:58 +00:00
drisspg
c92459488b Fix test on windows (#138641)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138641
Approved by: https://github.com/huydhn
2024-10-23 21:53:32 +00:00
drisspg
9a9a0abc28 [SDPA-CUDNN] Make CuDNN Attention Opt in (#138522)
# Summary
Currently we have a `cudnn_order` that says on H100 w/ new enough CuDNN backend (we ship a 9.1 version in OSS) try to run CuDNN attention first. We have already encountered a few bugs with the release of 2.5:

1. https://github.com/pytorch/pytorch/issues/138529
2. https://github.com/huggingface/diffusers/issues/9704
3. https://github.com/pytorch/pytorch/pull/138354

In light of the above we are going to make the CuDNN backend Opt-in by default.

This can be done easily with the context manager for choosing backends I.e.:
``` Python
from torch.nn.attention import sdpa_kernel, SDPBackend

with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)

```

This PR puts the CuDNN backend as the lowest precedence in the backend list, meaning that the Math backend will always be chosen unless disabled (which is done via the context manager).

Cc @atalman

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138522
Approved by: https://github.com/ngimel, https://github.com/eqy, https://github.com/malfet
2024-10-22 07:23:06 +00:00
Eddie Yan
d52b2cf92f [CUDA][SDPA] Fix TF32 handling and bump threshold for multiheadattention test (#137752)
For sm90, main issue was that `torch.testing.assert_close` bypasses the `tf32_on_and_off` tolerance switch decorator

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137752
Approved by: https://github.com/ezyang
2024-10-12 03:05:21 +00:00
eqy
6be53d52c5 [CUDA][SDPA] Bump tolerances for grad_query in mem_eff test (#137750)
(for sm80)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137750
Approved by: https://github.com/drisspg
2024-10-12 02:15:14 +00:00
Simon Fan
b86269fab5 Unify cpp_extension build directory removal (#136059)
Keeps existing default directory clearing logic, even though it fails when TORCH_EXTENSIONS_DIR is set. To properly clear, we'd need to track all the folders we compiled the extensions to.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136059
Approved by: https://github.com/ezyang, https://github.com/albanD
2024-10-03 06:22:11 +00:00
drisspg
5c0ce8d0a6 Skip Flaky Test: for #134602 (#137226)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137226
Approved by: https://github.com/cpuhrsch
2024-10-03 01:53:59 +00:00
eqy
be423a8480 [SDPA] Bump grad_query fudge factor for Flash Attention (#135711)
Tolerance issue for small GPUs e.g., (A16, A2)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135711
Approved by: https://github.com/Skylion007, https://github.com/drisspg
2024-10-02 21:35:00 +00:00
Jianyu Huang
0a35986cdb Add option to configure reduced precision math backend for SDPA (#135964)
Summary: Address https://github.com/pytorch/pytorch/issues/135778 by adding a global flag to configure whether using high precision or low precision for math backend of SDPA.

Test Plan: buck2 run mode/opt //scripts/feikou/llm:run_attn_kernels

Differential Revision: D62625515

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135964
Approved by: https://github.com/jbschlosser
2024-09-24 07:11:38 +00:00
Xinya Zhang
74fd1bf965 [ROCm] Update to AOTriton 0.7b (#134498)
Notable changes:
1. Enable CudaGraph related tests
2. Fix UT problems
3. EXPERIMENTAL Navi31 support. User should enable Navi31 support with Env Var `TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1`

Know Problem:
1. `test/test_transformers.py` will massive failures and/or NaN outputs with `--use-pytest`
    + Update: Confirmed skip `class TestSDPAPrivateUse1Only` can fix the problem with `--use-pytest`

Note:
AOTriton 0.7b adds support to nestedtenosrs+SDPA but need more work (and consequently a separate PR) to enable it.

Fixes #133540

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134498
Approved by: https://github.com/pruthvistony, https://github.com/jeffdaily, https://github.com/malfet
2024-09-11 20:34:01 +00:00
Eddie Yan
3084b7b5c0 [cuDNN][SDPA] Support attn_bias in cuDNN (#130482)
CC @drisspg

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130482
Approved by: https://github.com/drisspg, https://github.com/Skylion007, https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2024-09-11 05:59:25 +00:00
CaoE
a6fae2e811 Use BRGEMM for Half flash attention forward kernel (#131879)
Use oneDNN BRGEMM on packed data to get better performance on the 5th generation of Xeon where Intel® Advanced Matrix Extensions (AMX) will have fp16 support, e.g. amx-fp16.
Multiple models have achieved acceleration, for instance, FP16 stable diffusion v2.1 has achieved over 50% improvement.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131879
Approved by: https://github.com/jgong5, https://github.com/peterbell10
ghstack dependencies: #131878
2024-09-08 12:32:23 +00:00
Valentine233
0dbc72887b [CPU][flash attention] make the stride of output align with input (#134656)
Fixes #133671

Currently, the output of CPU flash attention has a fixed layout, no matter what the input is. This PR makes the stride of output align with input q/k/v, which is the same behavior as math backend.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134656
Approved by: https://github.com/jgong5, https://github.com/drisspg
2024-08-29 16:04:25 +00:00
Xinya Zhang
46ecc673ae [ROCm] Prevent accidental enablement of efficient attention. (#133331)
Currently Efficient attention and Flash attention share the same set of GPU
kernels on ROCM and have common limitations on head sizes.

Fixes https://github.com/pytorch/pytorch/issues/132004

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133331
Approved by: https://github.com/malfet, https://github.com/jithunnair-amd
2024-08-27 00:03:45 +00:00
eqy
e93ca12c88 [CUDNN][SDPA] Fix unsupported trivial stride-1 transpose case (#134031)
Fixes #134001
Incorrect assumption that two same-shape tensors being contiguous meant that they would have the same stride

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134031
Approved by: https://github.com/drisspg, https://github.com/Skylion007

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
2024-08-25 14:31:30 +00:00
drisspg
fb26b84390 Update fused kernels and call _safe_softmax from SDPA (#133882)
# UPDATE:
This is  take 3 of https://github.com/pytorch/pytorch/pull/131863 which was landed via co dev but not applying correclty

# Summary
Changes the stance of SDPA on what to do for fully masked out rows

## Current Behavior
Several PyTorch users have expressed frustration over this issue:
- https://github.com/pytorch/pytorch/issues/41508
- https://github.com/pytorch/pytorch/issues/103749
- https://github.com/pytorch/pytorch/issues/103963

These are significant issues with extensive discussion but no satisfactory resolution. The PyTorch team's consensus, as stated here:
https://github.com/pytorch/pytorch/issues/24816#issuecomment-524415617

Can be paraphrased as follows:

When passing in fully masked out rows, attention becomes ambiguous. We have two main options:

1. Uniformly attend to all values:
   ```python
   scores[masked_out_rows] = 1 / len(row)
   out[masked_out_rows] = 1 / len(row) * value
   ```

2. Decide that attention between no queries (masked) and no keys (masked) is meaningless:
   ```python
   output[fully_masked_rows] = NaN
   ```

We went with option 2. Partially because it was easier to implement, but also people argued that users can slice the output to remove the NaNs:
``` Python
>fill_value = -float("inf")
>row0 = torch.randn(4)
>row1 = torch.tensor([(fill_value for _ in range(4)])
>matrix = torch.stack([row0, row1]).requires_grad_(True)
>out = torch.softmax(matrix, 1)
>out = out[0]
>print(out)
tensor([0.5377, 0.2729, 0.0692, 0.1201])
```
Cool, problem solved. But what happends when you call backwards..
```Python
>out.backward(torch.ones_like(out))
>print(matrix.grad)
tensor([[3.0957e-08, 1.4157e-08, 7.7802e-10, 1.3713e-08],
        [       nan,        nan,        nan,        nan]])
```
Those pesky NaNs are back!

## Why do we see NaNs today?

The core of the problem revolves around using softmax function in sdpa:

```python
> row = torch.tensor([(-float("inf")) for _ in range(4)])
> torch.softmax(row, 0)
tensor([nan, nan, nan, nan])
```

## Quick Aside: Masking in Attention

Attention itself doesn't have a concept of masking. The `sdpa` function has an argument called `attn_mask`, which would be more accurately named `attn_bias`. This is because we don't actually "mask" entries when computing attention. Instead, due to implementation details([performance](https://github.com/pytorch/pytorch/issues/25110#issuecomment-524519087)), we add a value to the masked-out query/key pairs.

We use a large negative number (typically -inf) to decrease the attention weight, as softmax assigns more weight to larger values.

## Alternative Approaches

If we use a very large negative number instead of -inf:

```python
> row = torch.tensor([(-1e6) for _ in range(4)])
> torch.softmax(row, 0)
tensor([0.2500, 0.2500, 0.2500, 0.2500])
```
However if users always remembered to "slice" out their outputs i.e.:
```Python
>fill_value = -1e6
>...
>out.backward(torch.ones_like(out))
>print(matrix.grad)
tensor([[-0.0563, -0.0564,  0.1613, -0.0486],
        [ 0.0000,  0.0000,  0.0000,  0.0000]])
```
This would bring us back into a better state.

## A Third Option

We don't necessarily need to alter the behavior of softmax for -inf or very large negative numbers. The fundamental goal is to exclude certain query/key pairs from attention, regardless of the underlying implementation.

This PR implements the new semantic for masking w/ attention in fully masked-out rows:
```python
out[masked_out_rows] = 0
```

**Important Note**: This idea isn't entirely new. The [MaskedTensor](https://pytorch.org/tutorials/prototype/maskedtensor_overview#safe-softmax) prototype, a tensor subclass, was designed to handle such cases. However, it remains a prototype feature and hasn't gained widespread adoption.

## Details
This PR stack does 3 things:
1. Adds a PRIVATE _safe_softmax op
2. Updates semantic for flash_cpu fused kernel
3. Updates semantic for efficient_cuda fused kernel

_safe_softmax is not supposed to be used generically and is only meant to be used within the context of SDPA. Due to this fact instead of decomposing softmax and checking for -inf rows we instead "cheat" and use nan_to_num.

Why I think this is okay? (please find a counter point if avail)
There are multiple ways NaNs can emerge. For the fully masked out rows case nan_to_num works. But what if there were other NaNs, wouldn't this silently remove them?

The only case that this can happen is if the input itself had a NaN or an Inf
For example:
```Python
a = torch.ones([4], requires_grad=False, dtype=torch.float16)
a[1] = torch.finfo(torch.float16).max
print(a.softmax(-1))
```
Will return
`tensor([0., 1., 0., 0.], dtype=torch.float16)`

Where
```Python
a = torch.ones([4], requires_grad=False, dtype=torch.float16)
a[1] = float("inf")
a.softmax(-1)
```
returns:
`tensor([nan, nan, nan, nan], dtype=torch.float16)`

If we dont want to even allow for the possibility of "inf" or "NaN" attention scores to be converted to 0 then we can implemented it something like this

```Python
max = torch.max(a, dim=-1, keepdim=True)
exp = torch.exp(a - max.values)
denom = torch.sum(exp, dim=-1, keepdim=True)
softmax = exp / denom
softmax = torch.where(max.values == float('-inf'), 0.0, softmax)
```
however we would be paying for this in math performance.

## Why Now
I think one point that has substantially changed where PyTorch should lie on this argument is the fact that we have fused implementations for SDPA now. And these fused implementations allow us to easily and performantly support this new semantic.

Differential Revision: [D61418679](https://our.internmc.facebook.com/intern/diff/D61418679)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133882
Approved by: https://github.com/soulitzer
2024-08-19 18:53:11 +00:00
eqy
c0c82a5f6a [CUDA][SDPA] Bump tolerances for test_mem_efficient_attention_attn_mask_vs (#133738)
Same thing as #133051 but for efficient attention

CC @drisspg @nWEIdia

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133738
Approved by: https://github.com/drisspg, https://github.com/nWEIdia, https://github.com/Skylion007
2024-08-18 19:14:29 +00:00
PyTorch MergeBot
cfec69e2a1 Revert "Update fused kernels and call _safe_softmax from SDPA (#131863)"
This reverts commit caba37e99b.

Reverted https://github.com/pytorch/pytorch/pull/131863 on behalf of https://github.com/izaitsevfb due to breaks executorch test executorch/backends/apple/coreml:test - test_vit_skip_conv (executorch.backends.apple.coreml.test.test_coreml_partitioner.TestCoreMLPartitioner) ([comment](https://github.com/pytorch/pytorch/pull/131863#issuecomment-2291855634))
2024-08-15 17:55:07 +00:00
drisspg
caba37e99b Update fused kernels and call _safe_softmax from SDPA (#131863)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131863
Approved by: https://github.com/jbschlosser, https://github.com/Chillee
2024-08-13 23:37:50 +00:00
PyTorch MergeBot
4cca18d5b6 Revert "Update fused kernels and call _safe_softmax from SDPA (#131863)"
This reverts commit e61def65d5.

Reverted https://github.com/pytorch/pytorch/pull/131863 on behalf of https://github.com/albanD due to Broke forward AD tests in main ([comment](https://github.com/pytorch/pytorch/pull/131863#issuecomment-2286432628))
2024-08-13 14:44:08 +00:00
drisspg
e61def65d5 Update fused kernels and call _safe_softmax from SDPA (#131863)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131863
Approved by: https://github.com/jbschlosser
2024-08-13 00:51:55 +00:00
Apurva Jain
79ca596dc6 Optimize test_transformers.py (#133049)
- Reduced number of skipped test cases
- Merged redundant test cases

**Benchmark:**

| | Original | New |
| ----- | ----- | ----- |
| Run time | 60 mins | 35 mins |
| Total tests | 75k | 18k |
| Skipped tests | 20k | 4k |

_These are approximate numbers from running test_transformers.py on a single H100, and can change based on the device._

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133049
Approved by: https://github.com/drisspg
2024-08-11 05:20:58 +00:00
eqy
c89936eaa0 [CUDA][SDPA] Bump grad_key fudge factor in test_flash_attention_vs_math_ref_grads (#133051)
Abates failures like `ValueError: grad_key Test error 1.592235639691353e-05 is greater than threshold 1.5236437320709229e-05!` that we've seen when bringing up newer versions of CUDA

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133051
Approved by: https://github.com/drisspg, https://github.com/Skylion007
2024-08-10 01:49:30 +00:00
Apurva Jain
8bc5ef563e Grouped Query Attention (#132689)
### Approach: Using the current function declaration

**Constraint:** Q_Heads % KV_Heads == 0

**Major change:**
- Added a new argument enable_gqa: bool to sdpa function call
- It adds a meaning to the last third dimension.

Sample use cases this would enable:
LLama3

```
# LLama3 8b call to SDPA
query = torch.rand(batch, 32, seq_len_q, D)
key = torch.rand(batch, 8, seq_len_kv, D)
value = torch.rand(batch, 8, seq_len_kv, D)

output = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)

# Output Shape
(batch, 32, seq_len_q, D)
```

### Design Choice:

- Check if Query.size(-3) == Key.size(-3) == Value.size(-3) or, Query.size(-3) % Key.size(-3) == 0
- The function adjusts the key and value tensors to match the query tensor's head dimension by using repeat_interleave if their number of heads are not equal, facilitating correct and efficient computation in attention mechanisms.
- By default the enable_gqa flag is set to False, which ensures that regular sdpa functionality remains unchanged.

### Benchmarks:

- **sdpa.py: #130634**
For different batch sizes enable_gqa=True shows a substansial improvement in the run_time of sdpa

 | batch_size | q_num_heads | kv_num_heads | q_seq_len | kv_seq_len | embed_dim | forward_time when enable_gqa=True   |   forward_time when enable_gqa=False    |
| ------------ | ------------- | -------------- | ----------- | ------------ | ----------- | ----------- | ---------------- |
|     1      |     32      |      8       |   2048    |    2048    |   2048    |   100.71  |  119.70  |
|     8      |     32      |      8       |   2048    |    2048    |   2048    |   539.78  |  628.83  |
|     16     |     32      |      8       |   2048    |    2048    |   2048    |   1056.81  |  1225.48  |
|     32      |     32      |      8       |   2048    |    2048    |   2048    |   2099.54  |  2440.45  |

![Screenshot 2024-07-25 at 9 07 40 PM](https://github.com/user-attachments/assets/a3e5f716-c39f-4096-9e6c-82a735e57b7b)

- **TorchTitan: https://github.com/pytorch/torchtitan/pull/458**

Differential Revision: D60772086

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132689
Approved by: https://github.com/drisspg
2024-08-07 05:35:36 +00:00
Jianyu Huang
c7cfa51721 Always use high precision for SDPA math backend (#128922)
Summary:
feikou observed the big numerical gaps when using math backend on AMD and NV GPUs. It's mainly because we are not using higher precision FP32 for the intermediate accumulated/materialized parts.

Since math backend is expected to be slower anyways, and we expect math backend to generate the correct reference result, I think it should be worth to upcast FP16/BF16 input to FP32, and do FP32/TF32 computations, and then downcast FP32 output back to FP16/BF16.

Differential Revision: D58710805

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128922
Approved by: https://github.com/xw285cornell, https://github.com/drisspg
2024-08-04 23:58:14 +00:00
Mikayla Gawarecki
f49d5e30eb Change owners of test/test_transformers.py to module: multi-headed-attention (#132519)
So flaky tests get tagged with `module: multi-headed-attention` instead of `module: nn`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132519
Approved by: https://github.com/Skylion007
2024-08-02 20:12:33 +00:00
PyTorch MergeBot
bcb4f7c172 Revert "Grouped Query Attention (#128898)"
This reverts commit 6b28af1b79.

Reverted https://github.com/pytorch/pytorch/pull/128898 on behalf of https://github.com/ZainRizvi due to Sorry, this broke a bunch of tests internally. See D60638265 ([comment](https://github.com/pytorch/pytorch/pull/128898#issuecomment-2265961038))
2024-08-02 18:58:46 +00:00
PyTorch MergeBot
59b73079a0 Revert "Always use high precision for SDPA math backend (#128922)"
This reverts commit fbf3bc0a60.

Reverted https://github.com/pytorch/pytorch/pull/128922 on behalf of https://github.com/ZainRizvi due to Sorry, but this PR has a dependency on another PR (https://github.com/pytorch/pytorch/pull/128898) that has to be reverted ([comment](https://github.com/pytorch/pytorch/pull/128922#issuecomment-2265949958))
2024-08-02 18:46:50 +00:00
Jianyu Huang
fbf3bc0a60 Always use high precision for SDPA math backend (#128922)
Summary:
feikou observed the big numerical gaps when using math backend on AMD and NV GPUs. It's mainly because we are not using higher precision FP32 for the intermediate accumulated/materialized parts.

Since math backend is expected to be slower anyways, and we expect math backend to generate the correct reference result, I think it should be worth to upcast FP16/BF16 input to FP32, and do FP32/TF32 computations, and then downcast FP32 output back to FP16/BF16.

Differential Revision: D58710805

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128922
Approved by: https://github.com/xw285cornell, https://github.com/drisspg
2024-08-01 18:55:48 +00:00
jainapurva
6b28af1b79 Grouped Query Attention (#128898)
### Approach: Using the current function declaration

**Constraint:** Q_Heads % KV_Heads == 0

**Major change:**
- Added a new argument enable_gqa: bool to sdpa function call
- It adds a meaning to the last third dimension.

Sample use cases this would enable:
LLama3

```
# LLama3 8b call to SDPA
query = torch.rand(batch, 32, seq_len_q, D)
key = torch.rand(batch, 8, seq_len_kv, D)
value = torch.rand(batch, 8, seq_len_kv, D)

output = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)

# Output Shape
(batch, 32, seq_len_q, D)
```

### Design Choice:

- Check if Query.size(-3) == Key.size(-3) == Value.size(-3) or, Query.size(-3) % Key.size(-3) == 0
- The function adjusts the key and value tensors to match the query tensor's head dimension by using repeat_interleave if their number of heads are not equal, facilitating correct and efficient computation in attention mechanisms.
- By default the enable_gqa flag is set to False, which ensures that regular sdpa functionality remains unchanged.

### Benchmarks:

- **sdpa.py: #130634**
For different batch sizes enable_gqa=True shows a substansial improvement in the run_time of sdpa

 | batch_size | q_num_heads | kv_num_heads | q_seq_len | kv_seq_len | embed_dim | forward_time when enable_gqa=True   |   forward_time when enable_gqa=False    |
| ------------ | ------------- | -------------- | ----------- | ------------ | ----------- | ----------- | ---------------- |
|     1      |     32      |      8       |   2048    |    2048    |   2048    |   100.71  |  119.70  |
|     8      |     32      |      8       |   2048    |    2048    |   2048    |   539.78  |  628.83  |
|     16     |     32      |      8       |   2048    |    2048    |   2048    |   1056.81  |  1225.48  |
|     32      |     32      |      8       |   2048    |    2048    |   2048    |   2099.54  |  2440.45  |

![Screenshot 2024-07-25 at 9 07 40 PM](https://github.com/user-attachments/assets/a3e5f716-c39f-4096-9e6c-82a735e57b7b)

- **TorchTitan: https://github.com/pytorch/torchtitan/pull/458**

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128898
Approved by: https://github.com/drisspg
2024-07-31 22:58:51 +00:00
PyTorch MergeBot
499ead96ff Revert "Grouped Query Attention (#128898)"
This reverts commit d039b14207.

Reverted https://github.com/pytorch/pytorch/pull/128898 on behalf of https://github.com/albanD due to Broken test on main ([comment](https://github.com/pytorch/pytorch/pull/128898#issuecomment-2258314481))
2024-07-30 13:11:24 +00:00
jainapurva
d039b14207 Grouped Query Attention (#128898)
### Approach: Using the current function declaration

**Constraint:** Q_Heads % KV_Heads == 0

**Major change:**
- Added a new argument enable_gqa: bool to sdpa function call
- It adds a meaning to the last third dimension.

Sample use cases this would enable:
LLama3

```
# LLama3 8b call to SDPA
query = torch.rand(batch, 32, seq_len_q, D)
key = torch.rand(batch, 8, seq_len_kv, D)
value = torch.rand(batch, 8, seq_len_kv, D)

output = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)

# Output Shape
(batch, 32, seq_len_q, D)
```

### Design Choice:

- Check if Query.size(-3) == Key.size(-3) == Value.size(-3) or, Query.size(-3) % Key.size(-3) == 0
- The function adjusts the key and value tensors to match the query tensor's head dimension by using repeat_interleave if their number of heads are not equal, facilitating correct and efficient computation in attention mechanisms.
- By default the enable_gqa flag is set to False, which ensures that regular sdpa functionality remains unchanged.

### Benchmarks:

- **sdpa.py: #130634**
For different batch sizes enable_gqa=True shows a substansial improvement in the run_time of sdpa

 | batch_size | q_num_heads | kv_num_heads | q_seq_len | kv_seq_len | embed_dim | forward_time when enable_gqa=True   |   forward_time when enable_gqa=False    |
| ------------ | ------------- | -------------- | ----------- | ------------ | ----------- | ----------- | ---------------- |
|     1      |     32      |      8       |   2048    |    2048    |   2048    |   100.71  |  119.70  |
|     8      |     32      |      8       |   2048    |    2048    |   2048    |   539.78  |  628.83  |
|     16     |     32      |      8       |   2048    |    2048    |   2048    |   1056.81  |  1225.48  |
|     32      |     32      |      8       |   2048    |    2048    |   2048    |   2099.54  |  2440.45  |

![Screenshot 2024-07-25 at 9 07 40 PM](https://github.com/user-attachments/assets/a3e5f716-c39f-4096-9e6c-82a735e57b7b)

- **TorchTitan: https://github.com/pytorch/torchtitan/pull/458**

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128898
Approved by: https://github.com/drisspg
2024-07-29 21:49:06 +00:00