Commit Graph

285 Commits

Author SHA1 Message Date
Wei-Sheng Chin
bca75fe97a [MAIA] [Autocast] Enable autocast on MAIA device (#148511)
Fixes #148510.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148511
Approved by: https://github.com/albanD
2025-03-18 03:46:22 +00:00
Jane Xu
cccdf860e2 [BE] Add STABLE_LIBRARY test for multiple returns (#149230)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149230
Approved by: https://github.com/albanD, https://github.com/zou3519
ghstack dependencies: #149052
2025-03-18 02:40:54 +00:00
Jane Xu
988827cdfb Use schema as source of truth + support ones_like/empty_like (#149052)
This change does 2 important things:
(a) Instead of relying on IValue type as source of truth, we use the schema as the source of truth, which is important as IValue types are overloaded and can ambiguously convert incorrectly. For example, a MemoryFormat will look like an int + get converted to an int64_t vs a MemoryFormat!

(b) This PR expands support for many more types to encompass way more schemas, e.g., Optional, Device, dtype, etc. The main win from this PR is the ability for aoti_torch_call_dispatcher to call TensorFactory ops like ones_like/empty_like!

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149052
Approved by: https://github.com/albanD
2025-03-18 02:40:54 +00:00
Jane Xu
e6ef0620cc Add shim.h C API to call dispatcher on our own aten ops (#148832)
This PR still needs testing through some cpp extension

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148832
Approved by: https://github.com/albanD, https://github.com/atalman
ghstack dependencies: #148124
2025-03-11 21:02:04 +00:00
Jane Xu
971606befa Add a stable TORCH_LIBRARY to C shim (#148124)
This PR adds two main parts:
- shim.h stable C APIs into torch::Library APIs
- a higher level API in torch/csrc/stable/library.h that calls into this shim.h + otherwise is self contained

Goal: custom kernel writers should be able to call the apis in the directories above in order to register their library in a way that allows their custom extension to run with a different libtorch version than it was built with.

Subplots resolved:

- Do we want a whole separate StableLibrary or do we want to freeze torch::Library and add `m.stable_impl(cstring, void (*fn)(void **, int64_t, int64_t)` into it
    - Yes, we want a separate StableLibrary. We cannot freeze Library and it is NOT header only.
- Should I use unint64_t as the common denominator instead of void* to support 32bit architectures better?
    -  Yes, and done
- Should I add a stable `def` and `fragment` when those can be done in python?
    - I think we do want these --- and now they're done
- Where should library_stable_impl.cpp live? -- no longer relevant
- I need some solid test cases to make sure everything's going ok. I've intentionally thrown in a bunch of random dtypes into the signature, but I still haven't tested returning multiple things, returning nothing, complex dtypes, etc.
    - Have since tested all the torch library endpoints. the others can be tested in a followup to separate components that need to be in shim.h vs can be added later

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148124
Approved by: https://github.com/albanD, https://github.com/zou3519, https://github.com/atalman
2025-03-11 19:12:46 +00:00
PyTorch MergeBot
275a7c5dbb Revert "Add a stable TORCH_LIBRARY to C shim (#148124)"
This reverts commit 327e07ac1d.

Reverted https://github.com/pytorch/pytorch/pull/148124 on behalf of https://github.com/malfet due to Sorry for reverting your PR, but somehow it caused test failures in newly introduced tests, see https://hud.pytorch.org/hud/pytorch/pytorch/main/1?per_page=50&name_filter=pull%20%2F%20linux-focal-cuda12.6-py3.10-gcc11-sm89%20%2F%20test%20(default%2C%201&mergeLF=true ([comment](https://github.com/pytorch/pytorch/pull/148124#issuecomment-2709057833))
2025-03-09 20:44:56 +00:00
Jane Xu
327e07ac1d Add a stable TORCH_LIBRARY to C shim (#148124)
This PR adds two main parts:
- shim.h stable C APIs into torch::Library APIs
- a higher level API in torch/csrc/stable/library.h that calls into this shim.h + otherwise is self contained

Goal: custom kernel writers should be able to call the apis in the directories above in order to register their library in a way that allows their custom extension to run with a different libtorch version than it was built with.

Subplots resolved:

- Do we want a whole separate StableLibrary or do we want to freeze torch::Library and add `m.stable_impl(cstring, void (*fn)(void **, int64_t, int64_t)` into it
    - Yes, we want a separate StableLibrary. We cannot freeze Library and it is NOT header only.
- Should I use unint64_t as the common denominator instead of void* to support 32bit architectures better?
    -  Yes, and done
- Should I add a stable `def` and `fragment` when those can be done in python?
    - I think we do want these --- and now they're done
- Where should library_stable_impl.cpp live? -- no longer relevant
- I need some solid test cases to make sure everything's going ok. I've intentionally thrown in a bunch of random dtypes into the signature, but I still haven't tested returning multiple things, returning nothing, complex dtypes, etc.
    - Have since tested all the torch library endpoints. the others can be tested in a followup to separate components that need to be in shim.h vs can be added later

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148124
Approved by: https://github.com/albanD, https://github.com/zou3519
2025-03-09 10:07:25 +00:00
Dmitry Rogozhkin
d27ecf85db xpu: support sycl with torch.utils.cpp_extension APIs (#132945)
This patch adds support for sycl kernels build via `torch.utils.cpp_extension.load`, `torch.utils.cpp_extension.load_inline` and (new) `class SyclExtension` APIs. Files having `.sycl` extension are considered to have sycl kernels and are compiled with `icpx` (dpc++ sycl compiler from Intel). Files with other extensions, `.cpp`, `.cu`, are handled as before. API supports building sycl along with other file types into single extension.

Note that `.sycl` file extension is a PyTorch convention for files containing sycl code which I propose to adopt. We did follow up with compiler team to introduce such file extension in the compiler, but they are opposed to this. At the same time discussion around sycl file extension and adding sycl language support into such tools as cmake is ongoing. Eventually cmake also considers to introduce some file extension convention for sycl. I hope we can further influence cmake and compiler communities to broader adopt `.sycl` file extension.

By default SYCL kernels are compiled for all Intel GPU devices for which pytorch native aten SYCL kernels are compiled. At the moment `pvc,xe-lpg`. This behavior can be overridden by setting `TORCH_XPU_ARCH_LIST` environment variables to the comma separated list of desired devices to compile for.

Fixes: #132944

CC: @gujinghui @EikanWang @fengyuan14 @guangyey @jgong5

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132945
Approved by: https://github.com/albanD, https://github.com/guangyey, https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-02-16 16:50:59 +00:00
PyTorch MergeBot
dd5d0ea6bb Revert "xpu: support sycl with torch.utils.cpp_extension APIs (#132945)"
This reverts commit 607379960b.

Reverted https://github.com/pytorch/pytorch/pull/132945 on behalf of https://github.com/malfet due to It just broke all the tests, see b16ae97ad0/1 ([comment](https://github.com/pytorch/pytorch/pull/132945#issuecomment-2661498747))
2025-02-16 16:03:42 +00:00
Dmitry Rogozhkin
607379960b xpu: support sycl with torch.utils.cpp_extension APIs (#132945)
This patch adds support for sycl kernels build via `torch.utils.cpp_extension.load`, `torch.utils.cpp_extension.load_inline` and (new) `class SyclExtension` APIs. Files having `.sycl` extension are considered to have sycl kernels and are compiled with `icpx` (dpc++ sycl compiler from Intel). Files with other extensions, `.cpp`, `.cu`, are handled as before. API supports building sycl along with other file types into single extension.

Note that `.sycl` file extension is a PyTorch convention for files containing sycl code which I propose to adopt. We did follow up with compiler team to introduce such file extension in the compiler, but they are opposed to this. At the same time discussion around sycl file extension and adding sycl language support into such tools as cmake is ongoing. Eventually cmake also considers to introduce some file extension convention for sycl. I hope we can further influence cmake and compiler communities to broader adopt `.sycl` file extension.

By default SYCL kernels are compiled for all Intel GPU devices for which pytorch native aten SYCL kernels are compiled. At the moment `pvc,xe-lpg`. This behavior can be overridden by setting `TORCH_XPU_ARCH_LIST` environment variables to the comma separated list of desired devices to compile for.

Fixes: #132944

CC: @gujinghui @EikanWang @fengyuan14 @guangyey @jgong5

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132945
Approved by: https://github.com/albanD, https://github.com/guangyey
2025-02-16 10:16:09 +00:00
Jane Xu
515e55e692 Set -DPy_LIMITED_API flag for py_limited_api=True extensions (#145764)
This could be BC breaking, because there was a period of time when we use py_limited_api=True but don't enforce the flag, and now that we will start enforcing the flag, people's custom extensions may fail to build.

This is strictly still better behavior, as it is sketchy to claim CPython agnosticism without the flag, but calling this out as potential people yelling at us. Ways to mitigate this risk + reasons this may not be too big a deal:
- People haven't known about py_limited_api for extensions much due to lack of docs from python so usage is low right now
- My current tutorial is in store to make new users of py_limited_api pass this flag, so it'd be a noop for them.

Test plan:
* Locally i'm confident as I tried rebuilding ao with this change and it reliably failed (cuz importing torch/extension.h is a nono)
* Unit test wise, the normal python_agnostic one I added should work

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145764
Approved by: https://github.com/ezyang, https://github.com/zou3519, https://github.com/albanD
2025-01-28 20:11:05 +00:00
Zhenbin Lin
a08f7f3266 OpenReg: fix issue of pin_memory (#145046)
Fix issue of `pin_memory` when rewrapping a storage.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145046
Approved by: https://github.com/albanD
2025-01-28 09:41:04 +00:00
Zhenbin Lin
392dc177a9 OpenReg: Refactor impl_registry (#145465)
Refactor impl_registry to use `driver.exec` as fallback.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145465
Approved by: https://github.com/albanD
2025-01-25 03:31:49 +00:00
Zhenbin Lin
47e65077b1 OpenReg: Remove REGISTER_GENERATOR_PRIVATEUSE1 (#144841)
Replace REGISTER_GENERATOR_PRIVATEUSE1 with new API in AcceleratorHooksInterface.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144841
Approved by: https://github.com/albanD
2025-01-24 01:52:10 +00:00
Aaron Orenstein
99dbc5b0e2 PEP585 update - test (#145176)
See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145176
Approved by: https://github.com/bobrenjc93
2025-01-22 04:48:28 +00:00
Zhenbin Lin
adbbcd87d9 OpenReg: Split Allocator (#144843)
Split the Allocator into HostAllocator and DeviceAllocator.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144843
Approved by: https://github.com/albanD
2025-01-17 03:38:15 +00:00
Zhenbin Lin
52a620845b OpenReg: Use device agnostic API (#144840)
Use `torch.accelerator.device_count()` to get the number of devices.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144840
Approved by: https://github.com/albanD
2025-01-16 03:31:52 +00:00
Zhenbin Lin
cbb1ed2966 [1/N] OpenReg: Replace open_registration_extension.cpp with openreg (#141815)
As described in OpenReg [next-steps](https://github.com/pytorch/pytorch/blob/main/test/cpp_extensions/open_registration_extension/README.md#next-steps), here we replace the current `open_registration_extension.cpp` test in PyTorch CI with openreg.

The current `open_registration_extension.cpp` contains two parts:
1. Implentations to support `PrivateUse1` backend.
2. Helper functions used for UTs in `test_cpp_extensions_open_device_registration.py` and `test_transformers.py`.

For the first part, we'll replace it with openreg. For the second part, we'll migrate them to ut files step by step.

@albanD

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141815
Approved by: https://github.com/albanD
2025-01-14 15:59:00 +00:00
FFFrog
f47aac6bc2 Make Context to be Device-agnostic Step by Step (3/N) (#137578)
Detailed Descriptions:
- Using unified Device-agnostic API to create new generator for accelerator.
- Add deprecated info for GeneratorForPrivateuseone

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137578
Approved by: https://github.com/cyyever, https://github.com/ezyang
2024-12-18 15:12:19 +00:00
Jane Xu
be27dbf2b8 Enable CPP/CUDAExtension with py_limited_api for python agnosticism (#138088)
Getting tested with ao, but now there is a real test i added.

## What does this PR do?

We want to allow custom PyTorch extensions to be able to build one wheel for multiple Python versions, in other words, achieve python agnosticism. It turns out that there is such a way that setuptools/Python provides already! Namely, if the user promises to use only the Python limited API in their extension, they can pass in `py_limited_api` to their Extension class and to the bdist_wheel command (with a min python version) in order to build 1 wheel that will suffice across multiple Python versions.

Sounds lovely! Why don't people do that already with PyTorch? Well 2 things. This workflow is hardly documented (even searching for python agnostic specifically does not reveal many answers) so I'd expect that people simply don't know about it. But even if they did, _PyTorch_ custom Extensions would still not work because we always link torch_python, which does not abide by py_limited_api rules.

So this is where this PR comes in! We respect when the user specifies py_limited_api and skip linking torch_python under that condition, allowing users to enroll in the provided functionality I just described.

## How do I know this PR works?

I manually tested my silly little ultra_norm locally (with `import python_agnostic`) and wrote a test case for the extension showing that
- torch_python doesn't show up in the ldd tree
- no Py- symbols show up
It may be a little confusing that our test case is actually python-free (more clean than python-agnostic) but it is sufficient (and not necessary) towards showing that this change works.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138088
Approved by: https://github.com/ezyang, https://github.com/albanD
2024-12-11 18:22:55 +00:00
Zhenbin Lin
e0d97e936a OpenReg: Fix releasing tensor issue when exiting process (#140936)
When executing the following code:

```
import pytorch_openreg

import torch

if __name__ == "__main__":
    a = torch.tensor(1, device="openreg")

```
Sometimes releases tensor a failed after the process finishes executing `main` function. The trace of releasing `a` is `~Tensor()` -> ... -> `OpenRegMem.cpp` -> `OpenRegHooks.cpp` -> `_aten_impl.py`.

There are two failed scenarios I've found:

1. Segmentation fault: Before executing `~Tensor()`, the process has released global variables in `_aten_impl.py`, which causes the issue.
2. Waiting indefinitely: The main process passes the `free ptr` command  to deamon process, however daemon processes have shutdown.

The way to fix this issue is when the process is shutting down, we ignore the del ptr operation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140936
Approved by: https://github.com/ezyang
2024-11-22 13:50:35 +00:00
Zhenbin Lin
c9c8370feb Openreg: Add RNG Generator (#138449)
Implement RNG Generator by falling back to CPUGeneratorImpl.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138449
Approved by: https://github.com/ezyang
2024-11-20 09:27:55 +00:00
Zhenbin Lin
217d328764 OpenReg: Support autograd (#140662)
Add some unfinished implements to support autograd.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140662
Approved by: https://github.com/ezyang
2024-11-14 23:47:56 +00:00
Zhenbin Lin
9a051f6ee0 OpenReg: Fix issue when creating empty tensor (#140496)
On the exeuctor side, when it is found that meta.data_ptr is not in the allocated memory, tensor creation will fail, but there is no need to allocate memory when creating an empty tensor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140496
Approved by: https://github.com/ezyang
2024-11-14 11:10:37 +00:00
Zhenbin Lin
a8de84998d OpenReg: Export the number of devices (#140492)
Export the number of devices so that it can be used in ut.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140492
Approved by: https://github.com/ezyang
2024-11-13 22:08:37 +00:00
Zhenbin Lin
8304a1faad OpenReg: Fix issue when casting tensor on the executor size (#140255)
Previously we assumed that the number of tensor elements multiplied by the type size is not greater than the allocated memory size. However in some scenarios such as `tensor.expand`, the stride can be zero, which makes the assumption not true.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140255
Approved by: https://github.com/ezyang
2024-11-12 19:29:21 +00:00
Zhenbin Lin
d90c25e3e2 OpenReg: Support event (#140111)
Support events. Since cpu backend doesn't support asynchronous execution, all event operations will be executed immediately on the executor side.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140111
Approved by: https://github.com/ezyang
2024-11-10 08:38:45 +00:00
Zhenbin Lin
d36fdaf157 Openreg: Support stream (#136991)
Support stream. When the driver communicates with the executor, it will send the stream id corresponding to the execution command; when the executor receives the command with the stream id, it will ignore the stream id because cpu backend doesn't support asynchronous execution.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136991
Approved by: https://github.com/ezyang
2024-11-07 06:09:07 +00:00
Edward Yang
b14269dcfb Make Context to be Device-agnostic Step by Step (1/N) (#136519) (#138155)
Summary:
- make init to be device-agnostic and move it to AcceleratorHooksInterface
- refactoring context related to device initialization

Original pull request: https://github.com/pytorch/pytorch/pull/136519

Test Plan: contbuild & OSS CI, see 4a8e49389c

Reviewed By: malfet

Differential Revision: D64471142

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138155
Approved by: https://github.com/malfet, https://github.com/bobrenjc93
2024-10-17 20:58:56 +00:00
PyTorch MergeBot
d4d687ffb2 Revert "Make Context to be Device-agnostic Step by Step (1/N) (#136519)"
This reverts commit 4a8e49389c.

Reverted https://github.com/pytorch/pytorch/pull/136519 on behalf of https://github.com/clee2000 due to breaking internal tests related to MITA, @ezyang has a forward fix? ([comment](https://github.com/pytorch/pytorch/pull/136519#issuecomment-2414588302))
2024-10-15 17:19:16 +00:00
FFFrog
4a8e49389c Make Context to be Device-agnostic Step by Step (1/N) (#136519)
----

- make init to be device-agnostic and move it to AcceleratorHooksInterface
- refactoring context related to device initialization

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136519
Approved by: https://github.com/ezyang, https://github.com/EikanWang, https://github.com/guangyey
2024-10-13 12:38:02 +00:00
PyTorch MergeBot
079f909263 Revert "Make Context to be Device-agnostic Step by Step (1/N) (#136519)"
This reverts commit be0b75256a.

Reverted https://github.com/pytorch/pytorch/pull/136519 on behalf of https://github.com/jovianjaison due to this pr is causing errors internally ([comment](https://github.com/pytorch/pytorch/pull/136519#issuecomment-2405781093))
2024-10-10 18:32:17 +00:00
FFFrog
be0b75256a Make Context to be Device-agnostic Step by Step (1/N) (#136519)
- make init to be device-agnostic and move it to AcceleratorHooksInterface
- refactoring context related to device initialization

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136519
Approved by: https://github.com/ezyang, https://github.com/EikanWang, https://github.com/guangyey
2024-10-09 02:13:36 +00:00
Zhenbin Lin
de7f32a205 openreg add pin_memory (#135339)
Occording to `Next steps` in test/cpp_extensions/open_registration_extension/README.md, add Pinned memory and HostAllocator.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135339
Approved by: https://github.com/albanD
2024-10-09 00:07:59 +00:00
Zhenbin Lin
46525abb71 OpenReg: support multiple executors (#136249)
From PR https://github.com/pytorch/pytorch/pull/135646 we have split the daemon into drvier/executor, however, current executor stands for all devices and allocate memory all together. In order to better simulate device behavior, here we support multiple executors, each executor stands for one device.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136249
Approved by: https://github.com/FFFrog, https://github.com/albanD
2024-10-08 01:37:08 +00:00
Joshua Rosenkranz
e20e7a8c38 Fixed developer setup issue in open_registration_extension (#137355)
This PR fixes an issue where when running `python setup.py develop`, the `open_registration_extension` self contained example will not build due to the following:

```
error: 'synchronizeStream' overrides a member function but is not marked 'override'
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137355
Approved by: https://github.com/albanD, https://github.com/spzala
2024-10-07 15:25:37 +00:00
Nikita Shulga
57c02e5a00 [BE] Use helper functions in mps_extension (#137313)
This should be a no-op change, i.e. it runs the same code, but replaces verbose ObjectiveC invocation with helper function from OperationUtils.h, which this example already depends on
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137313
Approved by: https://github.com/atalman
2024-10-04 00:26:38 +00:00
Zhenbin Lin
41b58a1bec OpenReg: Fix issue when copying on the same device (#135956)
Current copy gets wrong value when src and dst are both openreg.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135956
Approved by: https://github.com/albanD
2024-09-14 09:57:45 +00:00
Zhenbin Lin
8d68a02905 OpenReg: Split the daemon into drvier/executor (#135646)
Split the daemon into a proper user-process driver vs device-process executor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135646
Approved by: https://github.com/albanD
2024-09-12 05:03:46 +00:00
FFFrog
27d86f93fe Remove redundant code (#134955)
Remove GetPrivateUse1HooksInterface
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134955
Approved by: https://github.com/Skylion007
2024-09-05 01:11:32 +00:00
albanD
3b33f26513 Add device daemon (#131814)
Base implementation aiming towards https://github.com/pytorch/rfcs/pull/64

Details of the implementation and next steps in https://github.com/pytorch/pytorch/blob/gh/albanD/3/head/test/cpp_extensions/open_registration_extension/README.md

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131814
Approved by: https://github.com/ezyang
2024-08-27 23:32:07 +00:00
cyy
8f7cf796ea [14/N] Use std::optional (#133417)
Follows #132527
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133417
Approved by: https://github.com/ezyang
2024-08-16 00:48:34 +00:00
cyyever
636a7c4859 [13/N] Use std::optional (#132527)
Follows #132361

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132527
Approved by: https://github.com/ezyang
2024-08-08 03:16:28 +00:00
Apurva Jain
8bc5ef563e Grouped Query Attention (#132689)
### Approach: Using the current function declaration

**Constraint:** Q_Heads % KV_Heads == 0

**Major change:**
- Added a new argument enable_gqa: bool to sdpa function call
- It adds a meaning to the last third dimension.

Sample use cases this would enable:
LLama3

```
# LLama3 8b call to SDPA
query = torch.rand(batch, 32, seq_len_q, D)
key = torch.rand(batch, 8, seq_len_kv, D)
value = torch.rand(batch, 8, seq_len_kv, D)

output = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)

# Output Shape
(batch, 32, seq_len_q, D)
```

### Design Choice:

- Check if Query.size(-3) == Key.size(-3) == Value.size(-3) or, Query.size(-3) % Key.size(-3) == 0
- The function adjusts the key and value tensors to match the query tensor's head dimension by using repeat_interleave if their number of heads are not equal, facilitating correct and efficient computation in attention mechanisms.
- By default the enable_gqa flag is set to False, which ensures that regular sdpa functionality remains unchanged.

### Benchmarks:

- **sdpa.py: #130634**
For different batch sizes enable_gqa=True shows a substansial improvement in the run_time of sdpa

 | batch_size | q_num_heads | kv_num_heads | q_seq_len | kv_seq_len | embed_dim | forward_time when enable_gqa=True   |   forward_time when enable_gqa=False    |
| ------------ | ------------- | -------------- | ----------- | ------------ | ----------- | ----------- | ---------------- |
|     1      |     32      |      8       |   2048    |    2048    |   2048    |   100.71  |  119.70  |
|     8      |     32      |      8       |   2048    |    2048    |   2048    |   539.78  |  628.83  |
|     16     |     32      |      8       |   2048    |    2048    |   2048    |   1056.81  |  1225.48  |
|     32      |     32      |      8       |   2048    |    2048    |   2048    |   2099.54  |  2440.45  |

![Screenshot 2024-07-25 at 9 07 40 PM](https://github.com/user-attachments/assets/a3e5f716-c39f-4096-9e6c-82a735e57b7b)

- **TorchTitan: https://github.com/pytorch/torchtitan/pull/458**

Differential Revision: D60772086

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132689
Approved by: https://github.com/drisspg
2024-08-07 05:35:36 +00:00
albanD
3d87dfc088 Add basic OpenReg module scaffolding with autograd (#131708)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131708
Approved by: https://github.com/ezyang
2024-08-05 17:07:11 +00:00
PyTorch MergeBot
bcb4f7c172 Revert "Grouped Query Attention (#128898)"
This reverts commit 6b28af1b79.

Reverted https://github.com/pytorch/pytorch/pull/128898 on behalf of https://github.com/ZainRizvi due to Sorry, this broke a bunch of tests internally. See D60638265 ([comment](https://github.com/pytorch/pytorch/pull/128898#issuecomment-2265961038))
2024-08-02 18:58:46 +00:00
jainapurva
6b28af1b79 Grouped Query Attention (#128898)
### Approach: Using the current function declaration

**Constraint:** Q_Heads % KV_Heads == 0

**Major change:**
- Added a new argument enable_gqa: bool to sdpa function call
- It adds a meaning to the last third dimension.

Sample use cases this would enable:
LLama3

```
# LLama3 8b call to SDPA
query = torch.rand(batch, 32, seq_len_q, D)
key = torch.rand(batch, 8, seq_len_kv, D)
value = torch.rand(batch, 8, seq_len_kv, D)

output = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)

# Output Shape
(batch, 32, seq_len_q, D)
```

### Design Choice:

- Check if Query.size(-3) == Key.size(-3) == Value.size(-3) or, Query.size(-3) % Key.size(-3) == 0
- The function adjusts the key and value tensors to match the query tensor's head dimension by using repeat_interleave if their number of heads are not equal, facilitating correct and efficient computation in attention mechanisms.
- By default the enable_gqa flag is set to False, which ensures that regular sdpa functionality remains unchanged.

### Benchmarks:

- **sdpa.py: #130634**
For different batch sizes enable_gqa=True shows a substansial improvement in the run_time of sdpa

 | batch_size | q_num_heads | kv_num_heads | q_seq_len | kv_seq_len | embed_dim | forward_time when enable_gqa=True   |   forward_time when enable_gqa=False    |
| ------------ | ------------- | -------------- | ----------- | ------------ | ----------- | ----------- | ---------------- |
|     1      |     32      |      8       |   2048    |    2048    |   2048    |   100.71  |  119.70  |
|     8      |     32      |      8       |   2048    |    2048    |   2048    |   539.78  |  628.83  |
|     16     |     32      |      8       |   2048    |    2048    |   2048    |   1056.81  |  1225.48  |
|     32      |     32      |      8       |   2048    |    2048    |   2048    |   2099.54  |  2440.45  |

![Screenshot 2024-07-25 at 9 07 40 PM](https://github.com/user-attachments/assets/a3e5f716-c39f-4096-9e6c-82a735e57b7b)

- **TorchTitan: https://github.com/pytorch/torchtitan/pull/458**

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128898
Approved by: https://github.com/drisspg
2024-07-31 22:58:51 +00:00
Xuehai Pan
548c460bf1 [BE][Easy][7/19] enforce style for empty lines in import segments in test/[a-c]*/ and test/[q-z]*/ (#129758)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129758
Approved by: https://github.com/ezyang
2024-07-31 10:54:03 +00:00
PyTorch MergeBot
499ead96ff Revert "Grouped Query Attention (#128898)"
This reverts commit d039b14207.

Reverted https://github.com/pytorch/pytorch/pull/128898 on behalf of https://github.com/albanD due to Broken test on main ([comment](https://github.com/pytorch/pytorch/pull/128898#issuecomment-2258314481))
2024-07-30 13:11:24 +00:00
jainapurva
d039b14207 Grouped Query Attention (#128898)
### Approach: Using the current function declaration

**Constraint:** Q_Heads % KV_Heads == 0

**Major change:**
- Added a new argument enable_gqa: bool to sdpa function call
- It adds a meaning to the last third dimension.

Sample use cases this would enable:
LLama3

```
# LLama3 8b call to SDPA
query = torch.rand(batch, 32, seq_len_q, D)
key = torch.rand(batch, 8, seq_len_kv, D)
value = torch.rand(batch, 8, seq_len_kv, D)

output = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)

# Output Shape
(batch, 32, seq_len_q, D)
```

### Design Choice:

- Check if Query.size(-3) == Key.size(-3) == Value.size(-3) or, Query.size(-3) % Key.size(-3) == 0
- The function adjusts the key and value tensors to match the query tensor's head dimension by using repeat_interleave if their number of heads are not equal, facilitating correct and efficient computation in attention mechanisms.
- By default the enable_gqa flag is set to False, which ensures that regular sdpa functionality remains unchanged.

### Benchmarks:

- **sdpa.py: #130634**
For different batch sizes enable_gqa=True shows a substansial improvement in the run_time of sdpa

 | batch_size | q_num_heads | kv_num_heads | q_seq_len | kv_seq_len | embed_dim | forward_time when enable_gqa=True   |   forward_time when enable_gqa=False    |
| ------------ | ------------- | -------------- | ----------- | ------------ | ----------- | ----------- | ---------------- |
|     1      |     32      |      8       |   2048    |    2048    |   2048    |   100.71  |  119.70  |
|     8      |     32      |      8       |   2048    |    2048    |   2048    |   539.78  |  628.83  |
|     16     |     32      |      8       |   2048    |    2048    |   2048    |   1056.81  |  1225.48  |
|     32      |     32      |      8       |   2048    |    2048    |   2048    |   2099.54  |  2440.45  |

![Screenshot 2024-07-25 at 9 07 40 PM](https://github.com/user-attachments/assets/a3e5f716-c39f-4096-9e6c-82a735e57b7b)

- **TorchTitan: https://github.com/pytorch/torchtitan/pull/458**

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128898
Approved by: https://github.com/drisspg
2024-07-29 21:49:06 +00:00