### Approach: Using the current function declaration

**Constraint:** `Q_Heads % KV_Heads == 0`

**Major change:**
- Added a new boolean argument `enable_gqa` to the `scaled_dot_product_attention` (SDPA) call.
- It changes the interpretation of the third-to-last dimension: the query may carry more heads than the key/value tensors, with each key/value head shared by a group of query heads.

Sample use case this would enable: Llama 3

```
# Llama 3 8B call to SDPA
query = torch.rand(batch, 32, seq_len_q, D)
key = torch.rand(batch, 8, seq_len_kv, D)
value = torch.rand(batch, 8, seq_len_kv, D)
output = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)
# Output shape: (batch, 32, seq_len_q, D)
```

### Design Choice:
- Check that either `Query.size(-3) == Key.size(-3) == Value.size(-3)` or `Query.size(-3) % Key.size(-3) == 0`.
- If the numbers of heads differ, the function expands the key and value tensors to match the query tensor's head count using `repeat_interleave`, keeping the attention computation correct and efficient (see the equivalence sketch after the benchmarks below).
- `enable_gqa` defaults to `False`, so regular SDPA behaviour is unchanged.

### Benchmarks:
- **sdpa.py: #130634**

Across batch sizes, `enable_gqa=True` shows a substantial improvement in SDPA runtime.

| batch_size | q_num_heads | kv_num_heads | q_seq_len | kv_seq_len | embed_dim | forward_time (enable_gqa=True) | forward_time (enable_gqa=False) |
| --- | --- | --- | --- | --- | --- | --- | --- |
| 1 | 32 | 8 | 2048 | 2048 | 2048 | 100.71 | 119.70 |
| 8 | 32 | 8 | 2048 | 2048 | 2048 | 539.78 | 628.83 |
| 16 | 32 | 8 | 2048 | 2048 | 2048 | 1056.81 | 1225.48 |
| 32 | 32 | 8 | 2048 | 2048 | 2048 | 2099.54 | 2440.45 |

- **TorchTitan: https://github.com/pytorch/torchtitan/pull/458**

Differential Revision: D60772086

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132689
Approved by: https://github.com/drisspg
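To make the head-expansion semantics concrete, here is a minimal sketch (not part of the PR; the shapes and group factor are illustrative assumptions) showing that `enable_gqa=True` produces the same result as manually broadcasting the key/value heads with `repeat_interleave` before a regular SDPA call:

```python
# Sketch: GQA via enable_gqa=True vs. manual repeat_interleave expansion.
# Shapes are illustrative; any q_heads divisible by kv_heads works.
import torch
import torch.nn.functional as F

batch, q_heads, kv_heads, seq_len, head_dim = 2, 32, 8, 128, 64

query = torch.rand(batch, q_heads, seq_len, head_dim)
key = torch.rand(batch, kv_heads, seq_len, head_dim)
value = torch.rand(batch, kv_heads, seq_len, head_dim)

# GQA path: 32 query heads share 8 key/value heads (32 % 8 == 0).
out_gqa = F.scaled_dot_product_attention(
    query, key, value, is_causal=True, enable_gqa=True
)

# Reference path: repeat each KV head for its group of query heads,
# then run the regular (non-GQA) SDPA.
group = q_heads // kv_heads
key_rep = key.repeat_interleave(group, dim=-3)
value_rep = value.repeat_interleave(group, dim=-3)
out_ref = F.scaled_dot_product_attention(query, key_rep, value_rep, is_causal=True)

torch.testing.assert_close(out_gqa, out_ref)
print(out_gqa.shape)  # torch.Size([2, 32, 128, 64])
```

The GQA path avoids materializing the expanded key/value tensors, which is where the runtime and memory savings in the benchmarks above come from.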