pytorch/test/cpp_extensions
Apurva Jain 8bc5ef563e Grouped Query Attention (#132689)
### Approach: Using the current function declaration

**Constraint:** Q_Heads % KV_Heads == 0
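Here is a minimal sketch of what this constraint implies. Each group of Q_Heads / KV_Heads consecutive query heads shares one key/value head, which is the ordering that repeat_interleave produces along the head dimension. The helper name kv_head_for_query_head is hypothetical and is not part of PyTorch.

```python
def kv_head_for_query_head(q_heads: int, kv_heads: int) -> list[int]:
    """Return, for each query head, the index of the KV head it shares (hypothetical helper)."""
    if q_heads % kv_heads != 0:
        raise ValueError(f"Q_Heads ({q_heads}) must be divisible by KV_Heads ({kv_heads})")
    group_size = q_heads // kv_heads  # query heads per KV head
    # Consecutive query heads map to the same KV head, matching repeat_interleave ordering
    return [h // group_size for h in range(q_heads)]


mapping = kv_head_for_query_head(32, 8)  # Llama 3 8B head counts
print(mapping[:8])  # [0, 0, 0, 0, 1, 1, 1, 1]
```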

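**Major change:**
- Added a new argument, enable_gqa: bool, to the SDPA function call.
- It gives meaning to the third-to-last dimension (dim -3): with enable_gqa=True, this dimension is treated as the number of heads, and key/value may have fewer heads than query.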

Sample use case this enables: Llama 3

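```python
import torch
from torch.nn.functional import scaled_dot_product_attention

# Llama 3 8B call to SDPA: 32 query heads, 8 key/value heads
batch, seq_len_q, seq_len_kv, D = 2, 128, 128, 128  # example sizes
query = torch.rand(batch, 32, seq_len_q, D)
key = torch.rand(batch, 8, seq_len_kv, D)
value = torch.rand(batch, 8, seq_len_kv, D)

output = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)

# Output shape: (batch, 32, seq_len_q, D)
print(output.shape)  # torch.Size([2, 32, 128, 128])
```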

### Design Choice:

- Check that either Query.size(-3) == Key.size(-3) == Value.size(-3), or Query.size(-3) % Key.size(-3) == 0.
- If the number of heads differs, the function expands key and value along dim -3 with repeat_interleave so that they match the query's number of heads; attention is then computed as in the regular multi-head case.
- By default, enable_gqa is False, so regular SDPA behavior remains unchanged.

### Benchmarks:

- **sdpa.py: #130634**
Across all batch sizes tested, enable_gqa=True reduces SDPA forward time by roughly 14% to 16% compared with enable_gqa=False.

| batch_size | q_num_heads | kv_num_heads | q_seq_len | kv_seq_len | embed_dim | forward_time when enable_gqa=True | forward_time when enable_gqa=False |
| ---------- | ----------- | ------------ | --------- | ---------- | --------- | --------------------------------- | ---------------------------------- |
| 1          | 32          | 8            | 2048      | 2048       | 2048      | 100.71                            | 119.70                             |
| 8          | 32          | 8            | 2048      | 2048       | 2048      | 539.78                            | 628.83                             |
| 16         | 32          | 8            | 2048      | 2048       | 2048      | 1056.81                           | 1225.48                            |
| 32         | 32          | 8            | 2048      | 2048       | 2048      | 2099.54                           | 2440.45                            |
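The table does not state units for forward_time, but the ratio between its two timing columns is unit-free. A short sketch that computes the per-batch-size speedup directly from the numbers in the table:

```python
# (batch_size, forward_time with enable_gqa=True, forward_time with enable_gqa=False),
# copied from the benchmark table
results = [
    (1, 100.71, 119.70),
    (8, 539.78, 628.83),
    (16, 1056.81, 1225.48),
    (32, 2099.54, 2440.45),
]

speedups = {}
for batch_size, t_gqa, t_no_gqa in results:
    speedups[batch_size] = t_no_gqa / t_gqa  # > 1 means enable_gqa=True is faster
    reduction = 100 * (1 - t_gqa / t_no_gqa)  # percent less forward time
    print(f"batch_size={batch_size}: {speedups[batch_size]:.2f}x, {reduction:.1f}% less time")
```

With these numbers, the speedup is roughly 1.16x to 1.19x, with the largest gain at batch_size=1.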

![Screenshot 2024-07-25 at 9 07 40 PM](https://github.com/user-attachments/assets/a3e5f716-c39f-4096-9e6c-82a735e57b7b)

- **TorchTitan: https://github.com/pytorch/torchtitan/pull/458**

Differential Revision: D60772086

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132689
Approved by: https://github.com/drisspg
2024-08-07 05:35:36 +00:00
no_python_abi_suffix_test [BE][Easy][7/19] enforce style for empty lines in import segments in test/[a-c]*/ and test/[q-z]*/ (#129758) 2024-07-31 10:54:03 +00:00
open_registration_extension Add basic OpenReg module scaffolding with autograd (#131708) 2024-08-05 17:07:11 +00:00
self_compiler_include_dirs_test
torch_test_cpp_extension [RFC] Add support for device extension autoloading (#127074) 2024-07-09 06:14:13 +00:00
cpp_c10d_extension.cpp
cpp_c10d_extension.hpp
cpp_frontend_extension.cpp
cublas_extension.cpp
cuda_dlink_extension_add.cu
cuda_dlink_extension_add.cuh
cuda_dlink_extension_kernel.cu
cuda_dlink_extension.cpp
cuda_extension_kernel.cu
cuda_extension_kernel2.cu
cuda_extension.cpp
cuda_extension.cu
cudnn_extension.cpp
cusolver_extension.cpp
dangling_impl_extension.cpp
doubler.h
extension.cpp Support torch.dtype as parameter in pybind11 cpp extension. (#126865) 2024-05-29 23:19:32 +00:00
identity.cpp
jit_extension.cpp
jit_extension2.cpp
maia_extension.cpp [codemod] c10::optional -> std::optional in caffe2/aten/src/ATen/DeviceGuard.h +117 (#126901) 2024-05-24 00:26:15 +00:00
mps_extension.mm
mtia_extension.cpp Support generic stream/event on CUDA/HIP backend (#125757) 2024-05-10 13:34:09 +00:00
open_registration_extension.cpp Grouped Query Attention (#132689) 2024-08-07 05:35:36 +00:00
rng_extension.cpp [codemod] c10:optional -> std::optional (#126135) 2024-05-14 19:35:51 +00:00
setup.py [BE][Easy][7/19] enforce style for empty lines in import segments in test/[a-c]*/ and test/[q-z]*/ (#129758) 2024-07-31 10:54:03 +00:00
torch_library.cu