Apurva Jain
|
8bc5ef563e
|
Grouped Query Attention (#132689)
### Approach: Using the current function declaration
**Constraint:** Q_Heads % KV_Heads == 0
**Major change:**
- Added a new argument `enable_gqa: bool` to the sdpa function call
- It gives the third-to-last (head) dimension an explicit meaning: with `enable_gqa=True`, the query may carry more heads than the key and value
Sample use cases this would enable:
LLama3
```python
# LLama3 8b call to SDPA (illustrative shapes)
import torch
from torch.nn.functional import scaled_dot_product_attention

batch, seq_len_q, seq_len_kv, D = 4, 2048, 2048, 128  # example sizes
query = torch.rand(batch, 32, seq_len_q, D)  # 32 query heads
key = torch.rand(batch, 8, seq_len_kv, D)    # 8 key/value heads
value = torch.rand(batch, 8, seq_len_kv, D)
output = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)
# output.shape == (batch, 32, seq_len_q, D)
```
### Design Choice:
- Check that either `Query.size(-3) == Key.size(-3) == Value.size(-3)` or `Query.size(-3) % Key.size(-3) == 0`
- If the numbers of heads differ, the key and value tensors are expanded along the head dimension with `repeat_interleave` to match the query tensor's head count, which keeps the attention computation correct and efficient (see the sketch after this list)
- By default the `enable_gqa` flag is `False`, so regular sdpa behavior is unchanged
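For illustration only, here is a minimal reference sketch of the `repeat_interleave` expansion described above; the function name `sdpa_gqa_reference` and the toy shapes are hypothetical, not part of this PR:
```python
import torch
import torch.nn.functional as F

def sdpa_gqa_reference(query, key, value, is_causal=False):
    # query: (batch, Hq, Lq, D); key/value: (batch, Hkv, Lkv, D), with Hq % Hkv == 0
    Hq, Hkv = query.size(-3), key.size(-3)
    assert Hq % Hkv == 0, "Q_Heads must be a multiple of KV_Heads"
    if Hq != Hkv:
        # Repeat each key/value head Hq // Hkv times along the head dimension (-3)
        key = key.repeat_interleave(Hq // Hkv, dim=-3)
        value = value.repeat_interleave(Hq // Hkv, dim=-3)
    return F.scaled_dot_product_attention(query, key, value, is_causal=is_causal)

# Toy check: 32 query heads attending over 8 key/value heads
q = torch.rand(2, 32, 16, 64)
k = torch.rand(2, 8, 16, 64)
v = torch.rand(2, 8, 16, 64)
print(sdpa_gqa_reference(q, k, v, is_causal=True).shape)  # torch.Size([2, 32, 16, 64])
```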
### Benchmarks:
- **sdpa.py: #130634**
For different batch sizes, `enable_gqa=True` shows a substantial improvement in the forward time of sdpa:
| batch_size | q_num_heads | kv_num_heads | q_seq_len | kv_seq_len | embed_dim | forward_time when enable_gqa=True | forward_time when enable_gqa=False |
| ------------ | ------------- | -------------- | ----------- | ------------ | ----------- | ----------- | ---------------- |
| 1 | 32 | 8 | 2048 | 2048 | 2048 | 100.71 | 119.70 |
| 8 | 32 | 8 | 2048 | 2048 | 2048 | 539.78 | 628.83 |
| 16 | 32 | 8 | 2048 | 2048 | 2048 | 1056.81 | 1225.48 |
| 32 | 32 | 8 | 2048 | 2048 | 2048 | 2099.54 | 2440.45 |

- **TorchTitan: https://github.com/pytorch/torchtitan/pull/458**
Differential Revision: D60772086
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132689
Approved by: https://github.com/drisspg
|
2024-08-07 05:35:36 +00:00 |
|
PyTorch MergeBot
|
bcb4f7c172
|
Revert "Grouped Query Attention (#128898)"
This reverts commit 6b28af1b79.
Reverted https://github.com/pytorch/pytorch/pull/128898 on behalf of https://github.com/ZainRizvi due to Sorry, this broke a bunch of tests internally. See D60638265 ([comment](https://github.com/pytorch/pytorch/pull/128898#issuecomment-2265961038))
|
2024-08-02 18:58:46 +00:00 |
|
jainapurva
|
6b28af1b79
|
Grouped Query Attention (#128898)
### Approach: Using the current function declaration
**Constraint:** Q_Heads % KV_Heads == 0
**Major change:**
- Added a new argument `enable_gqa: bool` to the sdpa function call
- It gives the third-to-last (head) dimension an explicit meaning: with `enable_gqa=True`, the query may carry more heads than the key and value
Sample use cases this would enable:
LLama3
```python
# LLama3 8b call to SDPA (illustrative shapes)
import torch
from torch.nn.functional import scaled_dot_product_attention

batch, seq_len_q, seq_len_kv, D = 4, 2048, 2048, 128  # example sizes
query = torch.rand(batch, 32, seq_len_q, D)  # 32 query heads
key = torch.rand(batch, 8, seq_len_kv, D)    # 8 key/value heads
value = torch.rand(batch, 8, seq_len_kv, D)
output = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)
# output.shape == (batch, 32, seq_len_q, D)
```
### Design Choice:
- Check that either `Query.size(-3) == Key.size(-3) == Value.size(-3)` or `Query.size(-3) % Key.size(-3) == 0`
- If the numbers of heads differ, the key and value tensors are expanded along the head dimension with `repeat_interleave` to match the query tensor's head count, which keeps the attention computation correct and efficient
- By default the `enable_gqa` flag is `False`, so regular sdpa behavior is unchanged
### Benchmarks:
- **sdpa.py: #130634**
For different batch sizes, `enable_gqa=True` shows a substantial improvement in the forward time of sdpa:
| batch_size | q_num_heads | kv_num_heads | q_seq_len | kv_seq_len | embed_dim | forward_time when enable_gqa=True | forward_time when enable_gqa=False |
| ------------ | ------------- | -------------- | ----------- | ------------ | ----------- | ----------- | ---------------- |
| 1 | 32 | 8 | 2048 | 2048 | 2048 | 100.71 | 119.70 |
| 8 | 32 | 8 | 2048 | 2048 | 2048 | 539.78 | 628.83 |
| 16 | 32 | 8 | 2048 | 2048 | 2048 | 1056.81 | 1225.48 |
| 32 | 32 | 8 | 2048 | 2048 | 2048 | 2099.54 | 2440.45 |

- **TorchTitan: https://github.com/pytorch/torchtitan/pull/458**
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128898
Approved by: https://github.com/drisspg
|
2024-07-31 22:58:51 +00:00 |
|
PyTorch MergeBot
|
499ead96ff
|
Revert "Grouped Query Attention (#128898)"
This reverts commit d039b14207.
Reverted https://github.com/pytorch/pytorch/pull/128898 on behalf of https://github.com/albanD due to Broken test on main ([comment](https://github.com/pytorch/pytorch/pull/128898#issuecomment-2258314481))
|
2024-07-30 13:11:24 +00:00 |
|
jainapurva
|
d039b14207
|
Grouped Query Attention (#128898)
### Approach: Using the current function declaration
**Constraint:** Q_Heads % KV_Heads == 0
**Major change:**
- Added a new argument `enable_gqa: bool` to the sdpa function call
- It gives the third-to-last (head) dimension an explicit meaning: with `enable_gqa=True`, the query may carry more heads than the key and value
Sample use cases this would enable:
LLama3
```python
# LLama3 8b call to SDPA (illustrative shapes)
import torch
from torch.nn.functional import scaled_dot_product_attention

batch, seq_len_q, seq_len_kv, D = 4, 2048, 2048, 128  # example sizes
query = torch.rand(batch, 32, seq_len_q, D)  # 32 query heads
key = torch.rand(batch, 8, seq_len_kv, D)    # 8 key/value heads
value = torch.rand(batch, 8, seq_len_kv, D)
output = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)
# output.shape == (batch, 32, seq_len_q, D)
```
### Design Choice:
- Check that either `Query.size(-3) == Key.size(-3) == Value.size(-3)` or `Query.size(-3) % Key.size(-3) == 0`
- If the numbers of heads differ, the key and value tensors are expanded along the head dimension with `repeat_interleave` to match the query tensor's head count, which keeps the attention computation correct and efficient
- By default the `enable_gqa` flag is `False`, so regular sdpa behavior is unchanged
### Benchmarks:
- **sdpa.py: #130634**
For different batch sizes, `enable_gqa=True` shows a substantial improvement in the forward time of sdpa:
| batch_size | q_num_heads | kv_num_heads | q_seq_len | kv_seq_len | embed_dim | forward_time when enable_gqa=True | forward_time when enable_gqa=False |
| ------------ | ------------- | -------------- | ----------- | ------------ | ----------- | ----------- | ---------------- |
| 1 | 32 | 8 | 2048 | 2048 | 2048 | 100.71 | 119.70 |
| 8 | 32 | 8 | 2048 | 2048 | 2048 | 539.78 | 628.83 |
| 16 | 32 | 8 | 2048 | 2048 | 2048 | 1056.81 | 1225.48 |
| 32 | 32 | 8 | 2048 | 2048 | 2048 | 2099.54 | 2440.45 |

- **TorchTitan: https://github.com/pytorch/torchtitan/pull/458**
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128898
Approved by: https://github.com/drisspg
|
2024-07-29 21:49:06 +00:00 |
|
Michael Lazos
|
c51a4e64c0
|
Add support for compiling SDPAParams (#117207)
Allows us to `allow_in_graph` this `torch._C` struct (`SDPAParams`) to support scaled dot product attention under `torch.compile`.
Helps unblock https://github.com/pytorch/pytorch/pull/116071
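For illustration only, a hedged sketch of the kind of code this unblocks, branching on SDPA backend eligibility inside a `torch.compile`'d function; the `SDPAParams` constructor order (query, key, value, attn_mask, dropout_p, is_causal) reflects the PyTorch 2.2-era API and is an assumption, not taken from this commit:
```python
import torch
import torch.nn.functional as F
from torch.backends.cuda import SDPAParams, can_use_flash_attention

@torch.compile
def attention(q, k, v):
    # Constructor argument order is an assumption; it has changed across PyTorch versions.
    params = SDPAParams(q, k, v, None, 0.0, True)  # attn_mask=None, dropout_p=0.0, is_causal=True
    if can_use_flash_attention(params):
        # FlashAttention is eligible for these inputs on this device
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)
    # Illustrative fallback: explicit causal attention in plain PyTorch ops
    scores = q @ k.transpose(-2, -1) / (q.size(-1) ** 0.5)
    causal = torch.ones(q.size(-2), k.size(-2), dtype=torch.bool, device=q.device).tril()
    return scores.masked_fill(~causal, float("-inf")).softmax(dim=-1) @ v
```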
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117207
Approved by: https://github.com/voznesenskym
|
2024-01-19 05:51:15 +00:00 |
|