ankushwahaRH
ba3c2c80ab
SDP Backend function fix (#161169)
...
The issue cannot be reproduced using the original repro code provided in the issue description.
However, the underlying issue mentioned by the maintainer (missing functions in `builder.py` and `trace_rules.py`) was never addressed and can still be reproduced with this test case:
```python
import torch
from torch.nn.attention import _cur_sdpa_kernel_backends


@torch.compile(fullgraph=True)
def test_function_that_triggers_error():
    return _cur_sdpa_kernel_backends()


print("Calling torch.compile function...")
try:
    result = test_function_that_triggers_error()
    print(f"Success: {result}")
except Exception as e:
    print(f"ERROR: {e}")
    print(f"Error type: {type(e)}")
```
The original repro likely no longer triggers the issue because of code-path changes in the SDPA implementation, but calling `_cur_sdpa_kernel_backends()` directly still exposes the underlying problem: certain `torch._C` functions that return non-Tensor values are not handled properly by dynamo tracing.
I have implemented the fix by adding the missing functions to both `builder.py` and `trace_rules.py` so these cases are handled correctly during compilation.
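With the added rules in place, a call like the following should compile without errors under `fullgraph=True`. This is a hedged usage sketch rather than the PR's own test; `_cur_sdpa_kernel_backends` is a private API, and the exact backend list printed depends on the build.
```python
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel, _cur_sdpa_kernel_backends


@torch.compile(fullgraph=True)
def current_backends():
    # Returns a non-Tensor value from a torch._C-backed query; with the missing
    # entries added to builder.py / trace_rules.py this no longer errors out.
    return _cur_sdpa_kernel_backends()


# Restrict the active SDPA backends, then ask the compiled function what is enabled.
with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
    print(current_backends())
```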
@guilhermeleobas
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161169
Approved by: https://github.com/guilhermeleobas, https://github.com/StrongerXi
2025-09-19 20:19:59 +00:00
Apurva Jain
8bc5ef563e
Grouped Query Attention (#132689)
...
### Approach: Using the current function declaration
**Constraint:** Q_Heads % KV_Heads == 0
**Major change:**
- Added a new argument `enable_gqa: bool` to the SDPA function call
- It gives meaning to the third-to-last (head) dimension.
Sample use case this would enable:
Llama3
```
import torch
from torch.nn.functional import scaled_dot_product_attention

# Llama3-8B-style call to SDPA: 32 query heads, 8 KV heads
# (example sizes below are reduced so the snippet runs quickly)
batch, seq_len_q, seq_len_kv, D = 2, 256, 256, 128
query = torch.rand(batch, 32, seq_len_q, D)
key = torch.rand(batch, 8, seq_len_kv, D)
value = torch.rand(batch, 8, seq_len_kv, D)
output = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)
# output.shape == (batch, 32, seq_len_q, D)
```
### Design Choice:
- Check that `Query.size(-3) == Key.size(-3) == Value.size(-3)`, or that `Query.size(-3) % Key.size(-3) == 0`
- When the head counts differ, the function expands the key and value tensors along the head dimension with `repeat_interleave` so they match the query's number of heads, keeping the attention computation correct and efficient (a minimal sketch of this step follows the list).
- By default the `enable_gqa` flag is `False`, so regular SDPA behavior remains unchanged.
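For illustration, here is a minimal sketch of the head-broadcast step described above, assuming the constraint Q_Heads % KV_Heads == 0; `expand_kv_heads` is a hypothetical helper for this write-up, not the code inside `scaled_dot_product_attention`.
```python
import torch


def expand_kv_heads(query, key, value):
    # Broadcast KV heads up to the query head count via repeat_interleave,
    # mirroring the adjustment described above (hypothetical helper).
    q_heads, kv_heads = query.size(-3), key.size(-3)
    assert q_heads % kv_heads == 0, "GQA requires Q_Heads % KV_Heads == 0"
    if q_heads != kv_heads:
        key = key.repeat_interleave(q_heads // kv_heads, dim=-3)
        value = value.repeat_interleave(q_heads // kv_heads, dim=-3)
    return key, value


q = torch.rand(2, 32, 16, 64)  # 32 query heads
k = torch.rand(2, 8, 16, 64)   # 8 KV heads
v = torch.rand(2, 8, 16, 64)
k, v = expand_kv_heads(q, k, v)
print(k.shape)  # torch.Size([2, 32, 16, 64])
```
The fused SDPA kernels can avoid materializing this expansion; the sketch only makes the shape relationship explicit.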
### Benchmarks:
- **sdpa.py: #130634**
For different batch sizes, `enable_gqa=True` shows a substantial improvement in the runtime of SDPA:
| batch_size | q_num_heads | kv_num_heads | q_seq_len | kv_seq_len | embed_dim | forward_time when enable_gqa=True | forward_time when enable_gqa=False |
| ------------ | ------------- | -------------- | ----------- | ------------ | ----------- | ----------- | ---------------- |
| 1 | 32 | 8 | 2048 | 2048 | 2048 | 100.71 | 119.70 |
| 8 | 32 | 8 | 2048 | 2048 | 2048 | 539.78 | 628.83 |
| 16 | 32 | 8 | 2048 | 2048 | 2048 | 1056.81 | 1225.48 |
| 32 | 32 | 8 | 2048 | 2048 | 2048 | 2099.54 | 2440.45 |

- **TorchTitan:** https://github.com/pytorch/torchtitan/pull/458
Differential Revision: D60772086
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132689
Approved by: https://github.com/drisspg
2024-08-07 05:35:36 +00:00
PyTorch MergeBot
bcb4f7c172
Revert "Grouped Query Attention ( #128898 )"
...
This reverts commit 6b28af1b79.
Reverted https://github.com/pytorch/pytorch/pull/128898 on behalf of https://github.com/ZainRizvi due to Sorry, this broke a bunch of tests internally. See D60638265 ([comment](https://github.com/pytorch/pytorch/pull/128898#issuecomment-2265961038))
2024-08-02 18:58:46 +00:00
jainapurva
6b28af1b79
Grouped Query Attention (#128898)
...
### Approach: Using the current function declaration
**Constraint:** Q_Heads % KV_Heads == 0
**Major change:**
- Added a new argument `enable_gqa: bool` to the SDPA function call
- It gives meaning to the third-to-last (head) dimension.
Sample use case this would enable:
Llama3
```
import torch
from torch.nn.functional import scaled_dot_product_attention

# Llama3-8B-style call to SDPA: 32 query heads, 8 KV heads
# (example sizes below are reduced so the snippet runs quickly)
batch, seq_len_q, seq_len_kv, D = 2, 256, 256, 128
query = torch.rand(batch, 32, seq_len_q, D)
key = torch.rand(batch, 8, seq_len_kv, D)
value = torch.rand(batch, 8, seq_len_kv, D)
output = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)
# output.shape == (batch, 32, seq_len_q, D)
```
### Design Choice:
- Check that `Query.size(-3) == Key.size(-3) == Value.size(-3)`, or that `Query.size(-3) % Key.size(-3) == 0`
- When the head counts differ, the function expands the key and value tensors along the head dimension with `repeat_interleave` so they match the query's number of heads, keeping the attention computation correct and efficient.
- By default the `enable_gqa` flag is `False`, so regular SDPA behavior remains unchanged.
### Benchmarks:
- **sdpa.py: #130634**
For different batch sizes, `enable_gqa=True` shows a substantial improvement in the runtime of SDPA:
| batch_size | q_num_heads | kv_num_heads | q_seq_len | kv_seq_len | embed_dim | forward_time when enable_gqa=True | forward_time when enable_gqa=False |
| ------------ | ------------- | -------------- | ----------- | ------------ | ----------- | ----------- | ---------------- |
| 1 | 32 | 8 | 2048 | 2048 | 2048 | 100.71 | 119.70 |
| 8 | 32 | 8 | 2048 | 2048 | 2048 | 539.78 | 628.83 |
| 16 | 32 | 8 | 2048 | 2048 | 2048 | 1056.81 | 1225.48 |
| 32 | 32 | 8 | 2048 | 2048 | 2048 | 2099.54 | 2440.45 |

- **TorchTitan:** https://github.com/pytorch/torchtitan/pull/458
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128898
Approved by: https://github.com/drisspg
2024-07-31 22:58:51 +00:00
PyTorch MergeBot
499ead96ff
Revert "Grouped Query Attention ( #128898 )"
...
This reverts commit d039b14207.
Reverted https://github.com/pytorch/pytorch/pull/128898 on behalf of https://github.com/albanD due to Broken test on main ([comment](https://github.com/pytorch/pytorch/pull/128898#issuecomment-2258314481))
2024-07-30 13:11:24 +00:00
jainapurva
d039b14207
Grouped Query Attention (#128898)
...
### Approach: Using the current function declaration
**Constraint:** Q_Heads % KV_Heads == 0
**Major change:**
- Added a new argument `enable_gqa: bool` to the SDPA function call
- It gives meaning to the third-to-last (head) dimension.
Sample use case this would enable:
Llama3
```
import torch
from torch.nn.functional import scaled_dot_product_attention

# Llama3-8B-style call to SDPA: 32 query heads, 8 KV heads
# (example sizes below are reduced so the snippet runs quickly)
batch, seq_len_q, seq_len_kv, D = 2, 256, 256, 128
query = torch.rand(batch, 32, seq_len_q, D)
key = torch.rand(batch, 8, seq_len_kv, D)
value = torch.rand(batch, 8, seq_len_kv, D)
output = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)
# output.shape == (batch, 32, seq_len_q, D)
```
### Design Choice:
- Check that `Query.size(-3) == Key.size(-3) == Value.size(-3)`, or that `Query.size(-3) % Key.size(-3) == 0`
- When the head counts differ, the function expands the key and value tensors along the head dimension with `repeat_interleave` so they match the query's number of heads, keeping the attention computation correct and efficient.
- By default the `enable_gqa` flag is `False`, so regular SDPA behavior remains unchanged.
### Benchmarks:
- **sdpa.py: #130634**
For different batch sizes, `enable_gqa=True` shows a substantial improvement in the runtime of SDPA:
| batch_size | q_num_heads | kv_num_heads | q_seq_len | kv_seq_len | embed_dim | forward_time when enable_gqa=True | forward_time when enable_gqa=False |
| ------------ | ------------- | -------------- | ----------- | ------------ | ----------- | ----------- | ---------------- |
| 1 | 32 | 8 | 2048 | 2048 | 2048 | 100.71 | 119.70 |
| 8 | 32 | 8 | 2048 | 2048 | 2048 | 539.78 | 628.83 |
| 16 | 32 | 8 | 2048 | 2048 | 2048 | 1056.81 | 1225.48 |
| 32 | 32 | 8 | 2048 | 2048 | 2048 | 2099.54 | 2440.45 |

- **TorchTitan:** https://github.com/pytorch/torchtitan/pull/458
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128898
Approved by: https://github.com/drisspg
2024-07-29 21:49:06 +00:00
Michael Lazos
c51a4e64c0
Add support for compiling SDPAParams (#117207)
...
Allows us to `allow_in_graph` this `torch._C` struct to support scaled dot product attention.
Helps unblock https://github.com/pytorch/pytorch/pull/116071
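As a general illustration of the mechanism (not this PR's diff), `torch._dynamo.allow_in_graph` tells dynamo to keep a call as a single node in the captured graph instead of tracing into it; `scale` below is a hypothetical stand-in, not SDPAParams itself.
```python
import torch
import torch._dynamo


def scale(x):
    # Hypothetical stand-in for something dynamo should not trace into.
    return x * 2.0


# Keep `scale` as an opaque call_function node in the dynamo graph.
torch._dynamo.allow_in_graph(scale)


@torch.compile(fullgraph=True)
def fn(x):
    return scale(x) + 1.0


print(fn(torch.ones(3)))
```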
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117207
Approved by: https://github.com/voznesenskym
2024-01-19 05:51:15 +00:00