Commit Graph

69 Commits

Author SHA1 Message Date
Ke Wen
062387fb53 [SymmMem] Speed up tests (#153677)
Use `MultiProcContinousTest` to avoid re-creating the ProcessGroup in each test instance.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153677
Approved by: https://github.com/fegin, https://github.com/Skylion007, https://github.com/ngimel
ghstack dependencies: #153653
2025-05-26 03:39:11 +00:00
Wei Wang
7128b50a65 [CI][CUDA][Distributed] Move cuda 11.8 distributed pull jobs to cuda 12.6 (#151594)
This PR moves distributed cuda CI job from cuda 11.8 to cuda 12.6.
In doing so, a few unit test failures were exposed, some if not all of which would take a while to root-cause and fix, so temporarily skip them after creating the issues.

https://github.com/pytorch/pytorch/issues/153479 test_nan_assert tricky behavior (e.g. skip_but_pass_in_sandcastle, ubuntu 20.04 does not work, ubuntu 22.04 works, Amazon Linux 2023 skip - what is Sandcastle OS?)
https://github.com/pytorch/pytorch/issues/153122 CUDA context related
https://github.com/pytorch/pytorch/issues/153517  NCCL regression, future NCCL may fix it
https://github.com/pytorch/pytorch/issues/154073 skip test_symmetric_memory for cuda 12.6 before it is fixed

See: https://github.com/pytorch/pytorch/issues/147383

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151594
Approved by: https://github.com/eqy, https://github.com/atalman, https://github.com/cyyever, https://github.com/huydhn, https://github.com/kwen2501
2025-05-22 06:33:29 +00:00
Chien-Chin Huang
498f364518 Fix test_fused_scaled_matmul_reduce_scatter when scatter_dim is 0 (#153286)
The function signature of fused_scaled_matmul_reduce_scatter was changed. This PR fixes the function signature. However when scatter_dim is 1, the two outputs are not close. We need a followup on this.

Another followup is to change fused_scaled_matmul_reduce_scatter to make those newly added arguments optional. Users shouldn't need to pass these arguments if they don't flatten the inputs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153286
Approved by: https://github.com/kwen2501
2025-05-12 17:38:49 +00:00
Jithun Nair
fe8ebacee4 [ROCm] Upgrade ROCm CI to ROCm6.4 (#151368)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151368
Approved by: https://github.com/jeffdaily, https://github.com/malfet

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-05-08 16:12:16 +00:00
Prachi Gupta
1ea2731e26 [ROCm] Add support for SymmetricMemory (#150580)
This is an attempt to re-land the initial PR https://github.com/pytorch/pytorch/pull/134817 with recent design changes from upstream.

**NOTE:**
ROCm currently does NOT have multicast/multimem hardware support, so those features are disabled in symmetric memory for ROCm. This also means that we currently do not have a way of lowering add + all_reduce + wait_tensor into the one_shot_all_reduce op in inductor, as it depends on multicast buffer support.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150580
Approved by: https://github.com/jeffdaily, https://github.com/kwen2501, https://github.com/yoyoyocmu

Co-authored-by: Xiaodong Wang <xdwang@fb.com>
2025-05-02 18:35:14 +00:00
Prachi Gupta
7e5f6dcf7f Add @requires_multicast_support to test_multimem_all_gather (#151227)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151227
Approved by: https://github.com/jeffdaily
2025-04-15 18:41:12 +00:00
Natalia Gimelshein
d04a6ec021 add reduce_scatter to symm mem ops (#150813)
+ a few small fixes (don't error out on 0-element tensors, a few more checks for contiguous outputs, more threads for better perf).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150813
Approved by: https://github.com/xw285cornell
2025-04-09 17:59:17 +00:00
Natalia Gimelshein
1700599266 Add one_shot_all_reduce_copy to allow non-symm-mem allocated tensors to be reduced (#150129)
Per title, we want to be able to use it even if inputs are not registered. Separate copy would add latency, and one-shot is all about the lowest possible latency.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150129
Approved by: https://github.com/xw285cornell
2025-04-01 05:36:43 +00:00
Natalia Gimelshein
414b9ae016 enable out variant of 2-shot reduction (#150153)
Per title, this version uses symm mem input both as input source and as a work buffer, so input is modified after the end (similar to what fbgemm car reduction does). It is intended to be wrapped in an op that would first copy the real inputs to symm mem buffers that wouldn't be exposed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150153
Approved by: https://github.com/xw285cornell
2025-04-01 05:36:04 +00:00
PyTorch MergeBot
57fa99c5c3 Revert "enable out variant of 2-shot reduction (#150153)"
This reverts commit cdeb32d2d1.

Reverted https://github.com/pytorch/pytorch/pull/150153 on behalf of https://github.com/clee2000 due to failing internal builds D72083877 ([comment](https://github.com/pytorch/pytorch/pull/150153#issuecomment-2766633712))
2025-03-31 15:43:24 +00:00
PyTorch MergeBot
e57fa18b40 Revert "Add one_shot_all_reduce_copy to allow non-symm-mem allocated tensors to be reduced (#150129)"
This reverts commit 8a872261dc.

Reverted https://github.com/pytorch/pytorch/pull/150129 on behalf of https://github.com/clee2000 due to breaking internal builds D72080428 ([comment](https://github.com/pytorch/pytorch/pull/150129#issuecomment-2766619006))
2025-03-31 15:37:54 +00:00
Natalia Gimelshein
cdeb32d2d1 enable out variant of 2-shot reduction (#150153)
Per title, this version uses symm mem input both as input source and as a work buffer, so input is modified after the end (similar to what fbgemm car reduction does). It is intended to be wrapped in an op that would first copy the real inputs to symm mem buffers that wouldn't be exposed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150153
Approved by: https://github.com/xw285cornell
2025-03-28 19:06:03 +00:00
Natalia Gimelshein
8a872261dc Add one_shot_all_reduce_copy to allow non-symm-mem allocated tensors to be reduced (#150129)
Per title, we want to be able to use it even if inputs are not registered. Separate copy would add latency, and one-shot is all about the lowest possible latency.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150129
Approved by: https://github.com/xw285cornell
2025-03-28 02:14:27 +00:00
Yifu Wang
db33d23aa8 [SymmetricMemory] fix an issue where rendezvous is performed with wrong device context when torch.cuda.set_device() is not called (#144886)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144886
Approved by: https://github.com/awgu
2025-01-28 01:43:37 +00:00
Jagadish Krishnamoorthy
8f3eb84373 ROCm: Enable 4 gpu tests for distributed config (#140319)
Change the label to make sure the jobs land on a
node which has >= 4 GPUs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140319
Approved by: https://github.com/jeffdaily, https://github.com/malfet, https://github.com/kwen2501
2025-01-02 17:22:11 +00:00
Yifu Wang
af190479c8 [fused_all_gather_matmul] use _multimem_all_gather_matmul for small global Ms (#143160)
## Benchmark
M=2048, N=3584, K=8192

baseline (nccl + cublas): 301us
decomp-based async-tp: 354us
comm-aware async-tp: 295us
**multimem_all_gather matmul: 277us**

As M further decreases, the multimem_all_gather approach consistently outperforms the baseline and other approaches (omitted other approaches in the chart as they start to be slower than the baseline):
![image](https://github.com/user-attachments/assets/5811455a-68c9-43fe-9d82-ca488dd77bc1)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143160
Approved by: https://github.com/weifengpy
ghstack dependencies: #142283, #142810, #143159
2024-12-17 01:07:27 +00:00
Yifu Wang
6fae60a34a [SymmetricMemory] introduce multimem_all_gather (#142810)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142810
Approved by: https://github.com/weifengpy
ghstack dependencies: #142283
2024-12-17 01:07:27 +00:00
Tom Ritchford
d25e6e623f Fix unused Python variables in test/[a-d]* (#134665)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134665
Approved by: https://github.com/albanD
2024-12-13 22:13:12 +00:00
Yifu Wang
810808d97d Enable cutlass-based all-gather matmul when TORCH_SYMM_MEM_ENABLE_NATIVE_ASYNC_TP is set (#142283)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142283
Approved by: https://github.com/weifengpy, https://github.com/Chillee
2024-12-13 10:29:14 +00:00
Yifu Wang
5513e2ec35 [SymmetricMemory] use the python version of empty() and rendezvous() for tests and library ops (#142154)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142154
Approved by: https://github.com/weifengpy
2024-12-05 22:09:36 +00:00
Kiuk Chung
5b0b16ca62 [torch/distributed] Make _SymmetricMemory.has_multicast_support() ret… (#141598)
`_SymmetricMemory.has_multicast_support()` throws an exception rather than returning `False` when called with a `DeviceType` that it does not support. For example:

```
from torch._C._distributed_c10d import _SymmetricMemory
from torch._C._autograd import DeviceType

try:
    supports_multicast = _SymmetricMemory.has_multicast_support(DeviceType.CPU, 0)
except RuntimeError as exc:
    assert str(exc) == "SymmetricMemory does not support device type cpu"
```

This is problematic when building PyTorch from source without `CUDASymmetricMemory.cu`, since the [`@requires_multicast_support`](https://github.com/pytorch/pytorch/blob/main/torch/testing/_internal/common_distributed.py#L353) test decorator will throw an exception rather than skipping the test (as intended).

This PR makes `_SymmetricMemory.has_multicast_support()` properly return `False` when multicast is not supported on the passed device.
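
To illustrate why a `False` return matters for a skip decorator, here is a minimal sketch in plain `unittest`. The names `probe_multicast_support`, `requires_multicast`, and `SUPPORTED_DEVICES` are hypothetical stand-ins for this example; they are not the PyTorch binding or the real `@requires_multicast_support` decorator.

```python
import functools
import unittest

# Hypothetical: pretend that no device on this machine supports multicast.
SUPPORTED_DEVICES = set()


def probe_multicast_support(device_type: str, device_idx: int) -> bool:
    """Stand-in probe that reports False for unsupported devices instead of raising."""
    return (device_type, device_idx) in SUPPORTED_DEVICES


def requires_multicast(device_type: str, device_idx: int = 0):
    """Skip the decorated test when the probe reports no multicast support."""
    def decorator(test_fn):
        @functools.wraps(test_fn)
        def wrapper(self, *args, **kwargs):
            if not probe_multicast_support(device_type, device_idx):
                self.skipTest("multicast support is not available")
            return test_fn(self, *args, **kwargs)
        return wrapper
    return decorator


class ExampleTest(unittest.TestCase):
    @requires_multicast("cpu")
    def test_needs_multicast(self):
        self.fail("should have been skipped")


def run_example() -> unittest.TestResult:
    """Run ExampleTest and return the collected result."""
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(ExampleTest)
    result = unittest.TestResult()
    suite.run(result)
    return result
```

If the probe raised instead of returning `False`, the test would be recorded as an error rather than as a skip.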

cc) @malfet , @atalman

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141598
Approved by: https://github.com/yifuwang
2024-11-26 23:36:32 +00:00
Syed Tousif Ahmed
e0482fdf95 Implements user buffer registration using MemPool (#133603)
This PR implements user buffer registration and demonstrates NVLink Sharp (NVLS) reductions by allocating special memory using MemPool and registering it with the NCCL buffer registration APIs.

Part of https://github.com/pytorch/pytorch/issues/124807.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133603
Approved by: https://github.com/kwen2501, https://github.com/eqy
2024-11-21 01:40:11 +00:00
PyTorch MergeBot
496c1e78c5 Revert "Implements user buffer registration using MemPool (#133603)"
This reverts commit 25d9be37be.

Reverted https://github.com/pytorch/pytorch/pull/133603 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/133603#issuecomment-2486897708))
2024-11-19 22:42:26 +00:00
Yifu Wang
5a7e147ef3 [SymmetricMemory] introduce user-facing APIs empty() and rendezvous() (#139677)
Previously `SymmetricMemory` only had private pybind APIs:
```python
from torch.distributed._symmetric_memory import _SymmetricMemory
t = _SymmetricMemory.empty_strided_p2p(
    size=(64,),
    stride=(1,),
    dtype=torch.float32,
    device=device,
)
symm_mem_hdl = _SymmetricMemory.rendezvous(t, group_name=group.group_name)
```

This PR introduces user-facing APIs empty() and rendezvous():
```python
import torch.distributed._symmetric_memory as symm_mem
t = symm_mem.empty(64, device="cuda")
symm_mem_hdl = symm_mem.rendezvous(t, group_name=group.group_name)
```

Notable differences compared to the pybind APIs:
- `empty()` now resembles `torch.empty()`:
  - shape can either be an integer sequence or pack
  - no need to/can't specify stride anymore
  - device can either be `torch.device` or string
- `group_name` needs to be specified at rendezvous time as opposed to allocation time. See https://github.com/pytorch/pytorch/pull/139529 for the rationales. I feel the new semantic is superior, hence enforcing it in the public API.
  - Currently, the pybind API still supports specifying `group_name` at rendezvous time.
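
The "integer sequence or pack" convention mentioned above can be sketched as a small argument normalizer. `normalize_size` is a hypothetical helper written for this illustration; it is not taken from the PyTorch source and only shows one plausible way to accept both call styles.

```python
from typing import Sequence, Tuple, Union


def normalize_size(*size: Union[int, Sequence[int]]) -> Tuple[int, ...]:
    """Accept either normalize_size(2, 3) or normalize_size((2, 3))."""
    if len(size) == 1 and not isinstance(size[0], int):
        # A single sequence argument was passed: unpack it.
        return tuple(int(d) for d in size[0])
    # Otherwise the dimensions were passed as a pack of integers.
    return tuple(int(d) for d in size)
```

With this helper, `normalize_size(64)`, `normalize_size(2, 3)`, and `normalize_size([2, 3])` all produce a plain tuple of dimensions.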

This PR does not change the behavior of the pybind APIs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139677
Approved by: https://github.com/lw
ghstack dependencies: #139529
2024-11-17 20:51:50 +00:00
Yifu Wang
ab5c8857ef [SymmetricMemory] support specifying group_name at rendezvous time (#139529)
Before this PR, users need to call `empty_strided_p2p()` with a `group_name`:

```python
tensor = _SymmetricMemory.empty_strided_p2p((1024,), (1,), device=device, group_name="0")
symm_mem = _SymmetricMemory.rendezvous(tensor)
```

Users can now omit `group_name` at allocation time and specify it later at rendezvous time:

```python
tensor = _SymmetricMemory.empty_strided_p2p((1024,), (1,), device=device)
symm_mem = _SymmetricMemory.rendezvous(tensor, group_name="0")
```

Rationales for this change:
- This allows the same allocation to establish symmetric memory under different groups
- Specifying `group_name` at rendezvous time instead of allocation time is a more natural UX

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139529
Approved by: https://github.com/lw
2024-11-17 09:31:17 +00:00
Syed Tousif Ahmed
25d9be37be Implements user buffer registration using MemPool (#133603)
This PR implements user buffer registration and demonstrates NVLink Sharp (NVLS) reductions by allocating special memory using MemPool and registering it with the NCCL buffer registration APIs.

Part of https://github.com/pytorch/pytorch/issues/124807.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133603
Approved by: https://github.com/kwen2501, https://github.com/eqy
2024-11-15 12:47:49 +00:00
Yifu Wang
02d0c43c32 [SymmetricMemory] fix a bug in symm_mem::memset32_ where the ops fails when offset=0 (#140129)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140129
Approved by: https://github.com/lw
ghstack dependencies: #140127, #140128
2024-11-14 23:29:16 +00:00
Yifu Wang
684db9beb2 [SymmetricMemory] fix a bug where get_signal_pad() returns a tensor backed by a buffer ptr instead of a signal_pad ptr (#140128)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140128
Approved by: https://github.com/lw
ghstack dependencies: #140127
2024-11-14 23:29:16 +00:00
Yifu Wang
c3d61bd367 [SymmetricMemory] allow overlapping devices for testing (#140127)
When `TORCH_SYMM_MEM_ALLOW_OVERLAPPING_DEVICES` is set, the check for overlapping devices and multicast support will be disabled. This is useful for testing with a single device.

Making this an env var instead of an API argument since this is likely only useful for testing.
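
A minimal sketch of how a test might opt in. The variable name comes from the message above; the value `"1"` is an assumption, since the message only says the variable must be "set". It is presumably safest to set it before any symmetric memory setup happens. `overlapping_devices_allowed` is a hypothetical helper and not PyTorch's parsing logic.

```python
import os

# Variable name from the commit message; the value "1" is an assumption.
os.environ["TORCH_SYMM_MEM_ALLOW_OVERLAPPING_DEVICES"] = "1"


def overlapping_devices_allowed() -> bool:
    """Hypothetical helper mirroring an "is set" check on the variable."""
    return "TORCH_SYMM_MEM_ALLOW_OVERLAPPING_DEVICES" in os.environ
```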

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140127
Approved by: https://github.com/lw
2024-11-14 23:29:16 +00:00
Yifu Wang
0a0915fb5e [SymmetricMemory] improve the API for stream_write_value32 (#139934)
This PR updates the binding for `stream_write_value32` to be consistent with `memset32`, which IMO makes more sense for this type of utility:
- Changed the API to take a uint32 tensor as argument, instead of a device pointer
- Changed the Python binding to be a static method of `_SymmetricMemory`, instead of an object method
- Use the dispatcher for device dispatching, as opposed to `SymmetricMemory` backends

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139934
Approved by: https://github.com/weifengpy
ghstack dependencies: #139227
2024-11-11 18:49:22 +00:00
PyTorch MergeBot
5f4a21dc58 Revert "[SymmetricMemory] improve the API for stream_write_value32 (#139934)"
This reverts commit 2f3a5a15ef.

Reverted https://github.com/pytorch/pytorch/pull/139934 on behalf of https://github.com/malfet due to Broke distributed tests, see https://github.com/pytorch/pytorch/actions/runs/11770673088/job/32784210441 ([comment](https://github.com/pytorch/pytorch/pull/139934#issuecomment-2468641512))
2024-11-11 17:02:07 +00:00
Yifu Wang
2f3a5a15ef [SymmetricMemory] improve the API for stream_write_value32 (#139934)
This PR updates the binding for `stream_write_value32` to be consistent with `memset32`, which IMO makes more sense for this type of utility:
- Changed the API to take a uint32 tensor as argument, instead of a device pointer
- Changed the Python binding to be a static method of `_SymmetricMemory`, instead of an object method
- Use the dispatcher for device dispatching, as opposed to `SymmetricMemory` backends

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139934
Approved by: https://github.com/weifengpy
ghstack dependencies: #139227
2024-11-11 01:54:35 +00:00
Yifu Wang
1659e241c8 [experimental] async-tp impl with cutlass-based, progress aware kernel (#139227)
This PR introduces the following:

### torch.ops.symm_mem._async_input_mm

`_async_input_mm(Tensor a, Tensor b, Tensor a_chunk_signals, int a_chunk_pivot) -> Tensor`

An mm impl that supports consuming asynchronous input. It guarantees the following rasterization order, and that the corresponding signal arrives before an input chunk is consumed.
```
num_chunks = a_chunk_signals.numel()
for chunk_idx in range(a_chunk_pivot, num_chunks + a_chunk_pivot):
    chunk_idx = chunk_idx % num_chunks
    wait_signal(a_chunk_signals, chunk_idx)
    # Compute output tiles that consume the input chunk
```
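
The rasterization order above can be made concrete with a small pure-Python sketch. `chunk_visit_order` is a hypothetical helper for illustration only; it reproduces just the index arithmetic of the loop, not the signal waiting or the matmul.

```python
from typing import List


def chunk_visit_order(num_chunks: int, a_chunk_pivot: int) -> List[int]:
    """Return the order in which input chunks are consumed for a given pivot."""
    return [
        chunk_idx % num_chunks
        for chunk_idx in range(a_chunk_pivot, num_chunks + a_chunk_pivot)
    ]
```

For example, with four chunks a pivot of 0 visits them in natural order, while a pivot of 1 starts at chunk 1 and wraps around to chunk 0 last.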

### PersistentAsyncInputScheduler

This is a forked version of PersistentScheduler that supports consuming asynchronous input. This tile scheduler introduces the following arguments:

- `tiles_per_chunk_m` – Specifies the size of an M chunk. Chunks are the granularity at which the asynchronous input becomes ready. It must be an integer multiple of the size of an M tile.
- `chunk_signals` – `chunk_signals[i] == 1` indicates that chunk i is ready. Before returning a work tile, get_current_work() waits for the signal to ensure that the corresponding chunk is ready.
- `tile_idx_pivot_m` – After applying swizzling, apply `pivot(m) => (m + tile_idx_pivot_m) % tiles_m` to `m`. In a distributed setting, this allows different ranks to process different m indices at the same time, thus avoiding communication hotspots.
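
A rough sketch of the pivot formula from the last bullet. The per-rank pivot choice, `rank * (tiles_m // world_size)`, is an assumption made for this example and is not stated in the PR; `pivot` and `first_tile_per_rank` are hypothetical helper names.

```python
from typing import List


def pivot(m: int, tile_idx_pivot_m: int, tiles_m: int) -> int:
    """Apply pivot(m) => (m + tile_idx_pivot_m) % tiles_m."""
    return (m + tile_idx_pivot_m) % tiles_m


def first_tile_per_rank(world_size: int, tiles_m: int) -> List[int]:
    """M index each rank processes first, assuming an evenly spaced pivot per rank."""
    tiles_per_rank = tiles_m // world_size  # assumed pivot spacing
    return [pivot(0, rank * tiles_per_rank, tiles_m) for rank in range(world_size)]
```

Under this assumption, four ranks sharing eight M tiles start at tiles 0, 2, 4 and 6, so no two ranks touch the same m index at the same step.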

Note that this scheduler currently only supports the `KernelTmaWarpSpecializedCooperative` kernel schedule. This is enforced via the template argument `KernelSchedule`.

Usage:
```
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
   Shape<int, int, int, int>,
   CollectiveMainloop,
   CollectiveEpilogue,
   cutlass::gemm::PersistentAsyncInputScheduler<KernelSchedule>>;
```

### _fused_all_gather_matmul_native
An ag-mm impl that combines `torch.ops.symm_mem._async_input_mm` and progress-aware all-gather. This is not yet enabled via the async-tp passes. We will use it as a backend to optimize the current decomposition-based async-tp impl.

## Benchmarks

### 4096x3584x8192
- cublas + nccl: 539us
- decomp-based async-tp w/o cuda graph: 694us
- decomp-based async-tp w/ cuda graph: 478us
- new cutlass kernel: 408us

<img width="478" alt="image" src="https://github.com/user-attachments/assets/39f316ab-36c5-4b41-af77-07854a385dfc">

### 2048x3584x8192
- cublas + nccl: 301us
- decomp-based async-tp w/o cuda graph: 687us
- decomp-based async-tp w/ cuda graph: 356us
- new cutlass kernel: 276us

<img width="441" alt="image" src="https://github.com/user-attachments/assets/9e23ce21-863b-43dd-a562-fb05d3a5a144">

## Next Steps
- Add tuning logic
- Use `_fused_all_gather_matmul_native` as a backend for the decomp-based async-tp impl

Differential temp Revision: [D65623152](https://our.internmc.facebook.com/intern/diff/D65623152)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139227
Approved by: https://github.com/weifengpy, https://github.com/Chillee
2024-11-08 23:28:25 +00:00
PyTorch MergeBot
36e0f119d0 Revert "[experimental] async-tp impl with cutlass-based, progress aware kernel (#139227)"
This reverts commit 5203138483.

Reverted https://github.com/pytorch/pytorch/pull/139227 on behalf of https://github.com/yifuwang due to Need to address internal build failure D65605027 ([comment](https://github.com/pytorch/pytorch/pull/139227#issuecomment-2463204467))
2024-11-07 21:01:36 +00:00
Yifu Wang
5203138483 [experimental] async-tp impl with cutlass-based, progress aware kernel (#139227)
This PR introduces the following:

### torch.ops.symm_mem._async_input_mm

`_async_input_mm(Tensor a, Tensor b, Tensor a_chunk_signals, int a_chunk_pivot) -> Tensor`

An mm impl that supports consuming asynchronous input. It guarantees the following rasterization order, and that the corresponding signal arrives before an input chunk is consumed.
```
num_chunks = a_chunk_signals.numel()
for chunk_idx in range(a_chunk_pivot, num_chunks + a_chunk_pivot):
    chunk_idx = chunk_idx % num_chunks
    wait_signal(a_chunk_signals, chunk_idx)
    # Compute output tiles that consume the input chunk
```

### PersistentAsyncInputScheduler

This is a forked version of PersistentScheduler that supports consuming asynchronous input. This tile scheduler introduces the following arguments:

- `tiles_per_chunk_m` – Specifies the size of an M chunk. Chunks are the granularity at which the asynchronous input becomes ready. It must be an integer multiple of the size of an M tile.
- `chunk_signals` – `chunk_signals[i] == 1` indicates that chunk i is ready. Before returning a work tile, get_current_work() waits for the signal to ensure that the corresponding chunk is ready.
- `tile_idx_pivot_m` – After applying swizzling, apply `pivot(m) => (m + tile_idx_pivot_m) % tiles_m` to `m`. In a distributed setting, this allows different ranks to process different m indices at the same time, thus avoiding communication hotspots.

Note that this scheduler currently only supports the `KernelTmaWarpSpecializedCooperative` kernel schedule. This is enforced via the template argument `KernelSchedule`.

Usage:
```
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
   Shape<int, int, int, int>,
   CollectiveMainloop,
   CollectiveEpilogue,
   cutlass::gemm::PersistentAsyncInputScheduler<KernelSchedule>>;
```

### _fused_all_gather_matmul_native
An ag-mm impl that combines `torch.ops.symm_mem._async_input_mm` and progress-aware all-gather. This is not yet enabled via the async-tp passes. We will use it as a backend to optimize the current decomposition-based async-tp impl.

## Benchmarks

### 4096x3584x8192
- cublas + nccl: 539us
- decomp-based async-tp w/o cuda graph: 694us
- decomp-based async-tp w/ cuda graph: 478us
- new cutlass kernel: 408us

<img width="478" alt="image" src="https://github.com/user-attachments/assets/39f316ab-36c5-4b41-af77-07854a385dfc">

### 2048x3584x8192
- cublas + nccl: 301us
- decomp-based async-tp w/o cuda graph: 687us
- decomp-based async-tp w/ cuda graph: 356us
- new cutlass kernel: 276us

<img width="441" alt="image" src="https://github.com/user-attachments/assets/9e23ce21-863b-43dd-a562-fb05d3a5a144">

## Next Steps
- Add tuning logic
- Use `_fused_all_gather_matmul_native` as a backend for the decomp-based async-tp impl

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139227
Approved by: https://github.com/weifengpy, https://github.com/Chillee
2024-11-07 03:43:12 +00:00
Yifu Wang
ee42a99745 [SymmetricMemory] introduce a binding for cuMemset32Async (#138755)
## This Stack

This stack does the following things to support `xformers`-style, comm-aware Triton kernels:
- Exposes `signal_pad`s as tensors in Python
- Adds a binding for `cuMemset32Async`

These in combination aim to provide users with more flexibility to express custom signaling/synchronization patterns.

## This PR
Make `cuMemset32Async` available via `_SymmetricMemory.memset32`. We chose `cuMemset32Async` over `cudaMemsetAsync` because it allows for `uint32_t`-wise memset. This provides users with better flexibility.

To enable this, we also added the following cuda driver APIs in `c10::cuda::DriverAPI`:
- `cuDevicePrimaryCtxRetain` - for obtaining the primary context of a device in the form of `CUcontext`.
- `cuCtxGetCurrent`/`cuCtxSetCurrent` - for setting and restoring the context for cuda driver APIs such as `cuMemset32Async`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138755
Approved by: https://github.com/weifengpy, https://github.com/eqy, https://github.com/lw
2024-11-05 18:47:24 +00:00
PyTorch MergeBot
3ca794783f Revert "[SymmetricMemory] introduce a binding for cuMemset32Async (#138755)"
This reverts commit 924e726c3a.

Reverted https://github.com/pytorch/pytorch/pull/138755 on behalf of https://github.com/ZainRizvi due to Sorry but this breaks internally.  Can you please fix this PR so it works internally and re-merge it? See D65401876 for more details ([comment](https://github.com/pytorch/pytorch/pull/138755#issuecomment-2455173596))
2024-11-04 16:34:34 +00:00
Yifu Wang
924e726c3a [SymmetricMemory] introduce a binding for cuMemset32Async (#138755)
## This Stack

This stack does the following things to support `xformers`-style, comm-aware Triton kernels:
- Exposes `signal_pad`s as tensors in Python
- Adds a binding for `cuMemset32Async`

These in combination aim to provide users with more flexibility to express custom signaling/synchronization patterns.

## This PR
Make `cuMemset32Async` available via `_SymmetricMemory.memset32`. We chose `cuMemset32Async` over `cudaMemsetAsync` because it allows for `uint32_t`-wise memset. This provides users with better flexibility.

To enable this, we also added the following cuda driver APIs in `c10::cuda::DriverAPI`:
- `cuDevicePrimaryCtxRetain` - for obtaining the primary context of a device in the form of `CUcontext`.
- `cuCtxGetCurrent`/`cuCtxSetCurrent` - for setting and restoring the context for cuda driver APIs such as `cuMemset32Async`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138755
Approved by: https://github.com/weifengpy, https://github.com/eqy, https://github.com/lw
2024-11-03 21:37:31 +00:00
Yifu Wang
0dbc284a72 [SymmetricMemory] expose signal_pads as tensors in Python (#138754)
## This Stack

This stack does the following things to support `xformers`-style, comm-aware Triton kernels:
- Exposes `signal_pad`s as tensors in Python
- Adds a binding for `cuMemset32Async`

These in combination aim to provide users with more flexibility to express custom signaling/synchronization patterns.

## This PR

```python
# Obtain the signal pad of the specified peer rank as a tensor.
# If both shape and dtype are unspecified, the returned tensor will be a
# 1d uint32 tensor, which is most natural for signaling purposes.
symm_mem.get_signal_pad(peer_rank)

# If only shape is specified, it is equivalent to:
# symm_mem.get_signal_pad(peer_rank)[:shape.numel()].view(shape)
symm_mem.get_signal_pad(peer_rank, shape)

# If only dtype is specified, it is equivalent to:
# symm_mem.get_signal_pad(peer_rank).view(dtype)
symm_mem.get_signal_pad(peer_rank, dtype=dtype)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138754
Approved by: https://github.com/weifengpy, https://github.com/lw
2024-11-01 20:17:15 +00:00
eqy
6fc63b4ef1 [ROCM][CUDA][NCCL] Disable test_lowering_one_shot_all_reduce on ROCM (#139414)
I'm not sure this is expected to run if it requires buffer-registration support CC @yifuwang @huydhn @syed-ahmed #138029

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139414
Approved by: https://github.com/huydhn, https://github.com/yifuwang
2024-11-01 18:39:47 +00:00
Yifu Wang
7765d1ef70 Preliminary registered-buffer collective support via Inductor (#138029)
```
NOTE [lowering-time collective optimization]

In collective communication libraries such as NCCL, every rank maintains
communication buffers that are remotely accessible by some peers. Depending
on the underlying transport, remote accessibility may be established via
mechanisms such as ib_reg_mr, CUDA P2P, or CUDA multicast. Typically, these
buffers are private to the communication library by default, and
communication ops copy user data in and out of these buffers.

To prevent these copies, an optimization commonly known as "user buffer
registration" can be employed. This allows direct establishment of remote
accessibility on user buffers, eliminating the need for copying. However,
this optimization introduces stringent usage requirements, which are
typically hard to satisfy without being intrusive to the user code:

- Establishing remote accessibility is expensive and often done ahead of
time. In such implementations, all ranks must agree on the set of allocations
used for every collective op. Failing to meet this requirement can
lead to runtime errors or even silent correctness issues.
- Even if the collective communication library supports gracefully falling
back to "unregistered" implementations, the fallback mechanism would nullify
the optimization.
- Some communication mechanisms impose stricter requirements than others. For
example, CUDA's multicast + multi-mem instructions require all ranks to agree
not only on the allocations used for every collective but also on the offsets
within these allocations.

To support all different mechanisms with optimal results, we aim to satisfy
the strictest requirement for this family of optimizations - we ensure that
every collective op invocation is guaranteed to operate on the same
allocation, at the same offset, in every iteration.

For eligible collective ops, we identify communication buffers at lowering
time and optionally choose to lower the op to a different kernel
(communication libraries like NCCL handle both registered and non-registered
buffers transparently within the same op, though some may require different
ops for different cases). Later, the codegen will perform "persistent
allocation" to satisfy the aforementioned constraints, and optionally,
perform buffer planning to optimize overall memory usage.
```

### Changes
- Created `comm_lowering.py` for the lowerings of `_c10d_functional` ops. This is to prevent cluttering `lowering.py` as we add more lowering-time collective optimizations. This PR moved the lowerings for `all_reduce` and `all_reduce_` to the file.
- Added `comm_buffer_type: Dict[str, str]` to `GraphLowering` to track whether a buffer is a comm buffer and the type of the comm buffer.
- Added codegen allocation support for comm buffers of type "symm_mem".
- Added support for auto-lowering `_c10d_functional.all_reduce_` to `symm_mem.one_shot_all_reduce`.
- Added an Inductor config for collective optimizations in general (`config._collective`).

### Limitation
Currently, each persistently allocated comm buffer is dedicated to a single callsite. This is not viable in terms of memory usage. However, this is a necessary intermediate state before we tackle memory planning for comm buffers.
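
The limitation can be pictured with a toy model in which every callsite owns one persistent buffer. This is only a sketch of the idea under stated assumptions: `PersistentCommBuffers` and the string callsite keys are hypothetical and do not reflect Inductor's actual codegen.

```python
from typing import Dict


class PersistentCommBuffers:
    """Toy model: one dedicated, persistently allocated buffer per callsite."""

    def __init__(self) -> None:
        self._buffers: Dict[str, bytearray] = {}

    def get(self, callsite: str, nbytes: int) -> bytearray:
        """Return the same allocation for a callsite on every iteration."""
        buf = self._buffers.get(callsite)
        if buf is None:
            buf = bytearray(nbytes)
            self._buffers[callsite] = buf
        if len(buf) != nbytes:
            raise ValueError("a callsite must always request the same size")
        return buf

    def total_bytes(self) -> int:
        """Memory held by all callsites; grows with the number of callsites."""
        return sum(len(buf) for buf in self._buffers.values())
```

Because buffers are never shared between callsites, total memory grows linearly with the number of collective callsites, which is the cost that memory planning would later reduce.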

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138029
Approved by: https://github.com/Chillee
ghstack dependencies: #138028
2024-10-30 18:11:09 +00:00
Yifu Wang
ef541c1a65 [fused_all_gather_scaled_matmul] support rowwise scaling (#137805)
This PR adds support for `A_scale` to be a row-wise scale. The op can automatically detect whether the row-wise scale is sharded or replicated. When the row-wise scale is sharded, the op would all-gather the scale in a pipelined fashion.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137805
Approved by: https://github.com/weifengpy
ghstack dependencies: #137643, #137738
2024-10-15 21:35:14 +00:00
Yifu Wang
05edaeaded [fused_scaled_matmul_reduce_scatter] support rowwise scaling (#137738)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137738
Approved by: https://github.com/Chillee, https://github.com/weifengpy
ghstack dependencies: #137643
2024-10-15 21:35:14 +00:00
Yifu Wang
91bc9dc2c9 [SymmetricMemory] implement timeout for barrier(), put_signal() and wait_signal() (#137643)
Suggested by @lw for better safety/reliability.
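
The general idea of a bounded wait can be sketched on the host side as follows. This is only an analogy in plain Python, not the SymmetricMemory implementation: `wait_for_flag`, `read_flag`, and `timeout_s` are hypothetical names.

```python
import time


def wait_for_flag(read_flag, expected, timeout_s):
    """Poll read_flag() until it returns `expected`.

    Raises TimeoutError instead of spinning forever if the peer never signals.
    """
    deadline = time.monotonic() + timeout_s
    while read_flag() != expected:
        if time.monotonic() > deadline:
            raise TimeoutError(
                f"flag did not reach {expected} within {timeout_s} seconds"
            )
        time.sleep(0)  # yield to other threads between polls
    return True


flag = [1]
ok = wait_for_flag(lambda: flag[0], expected=1, timeout_s=0.1)
```

Without the deadline check, a crashed or stalled peer would leave the waiter spinning indefinitely; with it, the failure surfaces as an error.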

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137643
Approved by: https://github.com/weifengpy, https://github.com/lw
2024-10-15 21:35:14 +00:00
Yifu Wang
ea83c78174 [SymmetricMemory] set the storage_offset of tensors returned by get_buffer() to 0 (#137569)
It seems that there's a bug in `TensorMaker` - it would treat `storage_offset` as bytes when calculating the storage size, but as numel when setting the tensor `storage_offset`. This seems to be causing tensors returned by `get_buffer()` with a non-zero offset to report the wrong storage size.

Will look into the `TensorMaker` issue further. But for `get_buffer()`, it seems more natural to just incorporate the offset into the data pointer.
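
The unit mismatch described above can be illustrated with plain arithmetic. This is a sketch only: both helpers are hypothetical and are not `TensorMaker` code.

```python
def required_storage_bytes(numel, itemsize, storage_offset):
    """Bytes needed for a contiguous tensor; storage_offset is in elements."""
    return (storage_offset + numel) * itemsize


def mixed_unit_storage_bytes(numel, itemsize, storage_offset):
    """Hypothetical buggy variant: adds the offset as if it were in bytes."""
    return storage_offset + numel * itemsize


# 16 float32 elements (4 bytes each) starting at element 8.
numel, itemsize, offset = 16, 4, 8
correct = required_storage_bytes(numel, itemsize, offset)  # (8 + 16) * 4
mixed = mixed_unit_storage_bytes(numel, itemsize, offset)  # 8 + 16 * 4
print(correct, mixed)
```

With a zero offset the two computations agree, which is consistent with folding the offset into the data pointer as a way to sidestep the problem.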

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137569
Approved by: https://github.com/weifengpy
ghstack dependencies: #137567
2024-10-10 05:05:58 +00:00
Yifu Wang
fbaf9b62de [SymmetricMemoryOps] use float32 as the accumulator type when accumulating bfloat16 with multimem.ld_reduce (#137529)
This provides better accuracy without additional cost.

Also added documentation to `multimem_one_shot_all_reduce` to note the numerical caveats.
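
To illustrate why the accumulator type matters, here is a small pure-Python simulation. It is a sketch under assumptions: `to_bf16` is a hypothetical helper that emulates bfloat16 rounding (it ignores NaN and infinity), Python floats stand in for the float32 accumulator, and the sequential summation order does not model how `multimem.ld_reduce` actually reduces.

```python
import struct


def to_bf16(x):
    """Round a float to the nearest bfloat16 value (round-to-nearest-even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # bfloat16 keeps the top 16 bits of the float32 encoding.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = ((bits + rounding_bias) >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


# One value per rank; all are exactly representable in bfloat16.
values = [256.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]

# Accumulate in bfloat16: round after every addition.
acc_bf16 = 0.0
for v in values:
    acc_bf16 = to_bf16(acc_bf16 + v)

# Accumulate in higher precision and round once at the end.
acc_wide = to_bf16(sum(values))

print(acc_bf16, acc_wide)  # exact sum is 263
```

Near 256 the spacing between adjacent bfloat16 values is 2, so each `+ 1.0` is rounded away when the running sum itself is kept in bfloat16.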
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137529
Approved by: https://github.com/Chillee
ghstack dependencies: #137471, #137472, #137473, #137474, #137475
2024-10-09 23:30:16 +00:00
Yifu Wang
d3edf4ebf4 [SymmetricMemoryOps] implement two-shot all-reduce (#137473)
## This Stack

Implement custom all-reduce algos available in `IntraNodeComm` as `symm_mem` ops and replace the existing `IntraNodeComm` kernels with them.

## This PR

Implement `symm_mem::two_shot_all_reduce_`. Later we'll replace the two-shot all-reduce in `IntraNodeComm` with it.
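
For readers unfamiliar with the term, the usual two-shot scheme can be simulated in plain Python. This is a sketch of the general algorithm, not the kernel in this PR: `two_shot_all_reduce` here is a hypothetical function operating on per-rank lists, and it assumes the buffer length divides evenly by the number of ranks.

```python
def two_shot_all_reduce(buffers):
    """Simulate a two-shot all-reduce (sum) over per-rank lists.

    Shot 1 (reduce-scatter): rank r sums chunk r across all ranks.
    Shot 2 (all-gather): every rank copies each reduced chunk from its owner.
    """
    world_size = len(buffers)
    n = len(buffers[0])
    assert n % world_size == 0, "sketch assumes the length divides evenly"
    chunk = n // world_size

    # Shot 1: each rank reduces only the chunk it owns.
    reduced = []
    for rank in range(world_size):
        lo, hi = rank * chunk, (rank + 1) * chunk
        reduced.append([sum(buf[i] for buf in buffers) for i in range(lo, hi)])

    # Shot 2: every rank gathers all reduced chunks.
    full = [x for part in reduced for x in part]
    return [list(full) for _ in range(world_size)]


inputs = [[1, 2, 3, 4], [10, 20, 30, 40]]
outputs = two_shot_all_reduce(inputs)
```

Each rank reduces only `1 / world_size` of the data in the first shot, at the cost of a second communication step.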

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137473
Approved by: https://github.com/Chillee
ghstack dependencies: #137471, #137472
2024-10-09 03:49:42 +00:00
Yifu Wang
82e55b624f [SymmetricMemoryOps] implement one_shot_all_reduce (#137472)
## This Stack

Implement custom all-reduce algos available in `IntraNodeComm` as `symm_mem` ops and replace the existing `IntraNodeComm` kernels with them.

## This PR

Implement `symm_mem::one_shot_all_reduce` and `symm_mem::one_shot_all_reduce_out`. Later we'll replace the one-shot all-reduce in `IntraNodeComm` with these.
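
As a rough illustration of the one-shot scheme from a single rank's point of view, consider the following plain-Python simulation. It is a sketch of the general idea only: the functions below are hypothetical stand-ins operating on lists and do not reflect the real op signatures.

```python
def one_shot_all_reduce_out(buffers, out):
    """Reduce (sum) every peer's full buffer into a caller-provided output.

    In the one-shot scheme each rank reads all peers' buffers directly and
    performs the whole reduction itself, in a single step.
    """
    n = len(buffers[0])
    assert len(out) == n, "output must match the input length"
    for i in range(n):
        out[i] = sum(buf[i] for buf in buffers)
    return out


def one_shot_all_reduce(buffers):
    """Same reduction, but allocates the output."""
    return one_shot_all_reduce_out(buffers, [0] * len(buffers[0]))


peers = [[1, 2, 3], [10, 20, 30], [100, 200, 300]]
result = one_shot_all_reduce(peers)

preallocated = [0, 0, 0]
one_shot_all_reduce_out(peers, preallocated)
```

Every rank reads the full input of every peer, so this trades extra memory traffic for a single step, which tends to suit small messages.
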
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137472
Approved by: https://github.com/Chillee, https://github.com/weifengpy
ghstack dependencies: #137471
2024-10-09 03:49:42 +00:00
Yifu Wang
d55eef5c59 [SymmetricMemory] improve multicast initialization/fallback logic (#136577)
Fixes https://github.com/pytorch/pytorch/issues/136494

Currently, CUDASymmetricMemory::rendezvous() initializes a multicast address if multicast support is present. However, if we believe multicast support is present but cuMulticastCreate still fails for some reason, we do not fall back gracefully.

- In addition to CUDART and driver version check, query CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED to determine multicast support for a rank/device.
- Before initializing multicast for a block, ensure all ranks/devices have multicast support.
- This is unlikely, but if cuMulticastCreate still fails on rank 0, print the corresponding driver error message as a warning, and gracefully skip multicast initialization for the block.
- Introduced an environment variable (TORCH_SYMM_MEM_DISABLE_MULTICAST) to allow users to explicitly disable multicast support as a workaround.
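
The decision flow in the list above can be sketched in plain Python. This is an illustration under assumptions, not the C++ implementation: `should_init_multicast`, `init_multicast`, and `create_fn` are hypothetical names, and treating any value other than empty or `"0"` as "disabled" is an assumption about how the environment variable is interpreted.

```python
import os
import warnings


def should_init_multicast(rank_supports_multicast, env=None):
    """Return True only if multicast is not disabled and all ranks support it."""
    env = os.environ if env is None else env
    # Assumption: any non-empty value other than "0" disables multicast.
    flag = env.get("TORCH_SYMM_MEM_DISABLE_MULTICAST", "0")
    if flag not in ("", "0"):
        return False
    return all(rank_supports_multicast)


def init_multicast(rank_supports_multicast, create_fn, env=None):
    """Try to create a multicast object; warn and return None on failure."""
    if not should_init_multicast(rank_supports_multicast, env):
        return None
    try:
        return create_fn()
    except RuntimeError as exc:
        warnings.warn(f"multicast creation failed, skipping: {exc}")
        return None


def failing_create():
    raise RuntimeError("simulated driver error")


handle = init_multicast([True, True], lambda: "mc-handle", env={})
skipped = init_multicast([True, False], lambda: "mc-handle", env={})
fallback = init_multicast([True, True], failing_create, env={})
```

The key property is that every failure path returns "no multicast" rather than raising, so callers can continue on the non-multicast path.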

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136577
Approved by: https://github.com/Chillee, https://github.com/eqy
2024-09-27 20:04:21 +00:00
Yifu Wang
da1560c49f [SymmetricMemory] add support for cuStreamWriteValue32 (#136488)
cuStreamWriteValue efficiently combines the issuing of a system-level fence with the update of a single memory location. It is highly suitable for inter-stream progress sharing (e.g., all_gather_with_progress).

Exposing it via SymmetricMemory allows users to more easily implement efficient progress-aware matmuls in triton ([xformers example](https://github.com/facebookresearch/xformers/blob/main/xformers/ops/_triton/sequence_parallel_fused_kernels.py)).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136488
Approved by: https://github.com/eqy, https://github.com/Chillee
2024-09-24 20:56:29 +00:00