`SymmetricMemory.has_multicast_support()` throws an exception rather than returning `False` when called with a `DeviceType` that does not support. For example:
```
from torch._C._distributed_c10d import _SymmetricMemory
from torch._C._autograd import DeviceType
try:
supports_multicast = _SymmetricMemory.has_multicast_support(DeviceType.CPU, 0)
except RuntimeError as exc:
assert str(exc) == "SymmetricMemory does not support device type cpu"
```
This is problematic when building PyTorch from source without `CUDASymmetricMemory.cu` since the [`@requires_multicast_support`](https://github.com/pytorch/pytorch/blob/main/torch/testing/_internal/common_distributed.py#L353) test decorator will throw an exception rather than skipping the test (as intended)
This PR makes `_SymmetricMemory.has_multicast_support()` properly return `False` when multicast is not supported on the passed device.
cc) @malfet , @atalman
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141598
Approved by: https://github.com/yifuwang
Before this PR, users need to call `empty_strided_p2p()` with a `group_name`:
```python
tensor = _SymmetricMemory.empty_strided_p2p((1024,), (1,), device=device, group_name="0")
symm_mem = _SymmetricMemory.rendezvous(tensor)
```
Users can now omit `group_name` at allocation time and specify it later at rendezvous time:
```python
tensor = _SymmetricMemory.empty_strided_p2p((1024,), (1,), device=device)
symm_mem = _SymmetricMemory.rendezvous(tensor, group_name="0")
```
Rationales for this change:
- This allows the same allocation to establish symmetric memory under different groups
- Specifying `group_name` at rendezvous time instead of allocation time is a more natural UX
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139529
Approved by: https://github.com/lw
This PR updates the binding for `stream_write_value32` to be consistent with `memset32` which IMO makes more sense for this type of utilities:
- Changed the API to take a uint32 tensor as argument, instead of a device pointer
- Changed the Python binding to be a static method of `_SymmetricMemory`, instead of a object method
- Use the dispatcher for device dispatching, as opposed to `SymmetricMemory` backends
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139934
Approved by: https://github.com/weifengpy
ghstack dependencies: #139227
This PR updates the binding for `stream_write_value32` to be consistent with `memset32` which IMO makes more sense for this type of utilities:
- Changed the API to take a uint32 tensor as argument, instead of a device pointer
- Changed the Python binding to be a static method of `_SymmetricMemory`, instead of a object method
- Use the dispatcher for device dispatching, as opposed to `SymmetricMemory` backends
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139934
Approved by: https://github.com/weifengpy
ghstack dependencies: #139227
## This Stack
This stack does the following things to support `xformers`-style, comm-aware Triton kernels:
- Exposes `signal_pad`s as tensors in Python
- Adds a binding for `cuMemsetAsync`
These in combination aims to provide users with more flexibility to express custom signaling/synchronization patterns.
## This PR
```python
# Obtain the signal pad of the specified peer rank as a tensor.
# If both shape and dtype are unspecified, the returned tensor will be a
# 1d uint32 tensor, which is most natural for signaling purposes.
symm_mem.get_signal_pad(peer_rank)
# If only shape is specified, it is equivalent to:
# symm_mem.get_signal_pad(peer_rank)[:shape.numel()].view(shape)
symm_mem.get_signal_pad(peer_rank, shape)
# If only dtype is specified, it is equivalent to:
# symm_mem.get_signal_pad(peer_rank).view(dtype)
symm_mem.get_signal_pad(peer_rank, dtype=dtype)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138754
Approved by: https://github.com/weifengpy, https://github.com/lw
Fixes https://github.com/pytorch/pytorch/issues/136494
Currently, CUDASymmetricMemory::rendezvous() initializes a multicast address if multicast support is present. However, if we believe multicast support is present but cuMulticastCreate still fails for some reason, we do not fallback gracefully.
- In addition to CUDART and driver version check, query CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED to determine multicast support for a rank/device.
- Before initializing multicast for a block, ensure all ranks/devices have multicast support.
- This is unlikely, but if cuMulticastCreate still fails on rank 0, print the corresponding driver error message as a warning, and gracefully skip multicast initialization for the block.
- Introduced an environment variable (TORCH_SYMM_MEM_DISABLE_MULTICAST) to allow users to explicitly disable multicast support as a workaround.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136577
Approved by: https://github.com/Chillee, https://github.com/eqy
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
This PR introduces a prototype for `SymmetricMemory` (including a CUDA implementation) - a remote-memory access-based communication primitive. It allows for user-defined communication patterns/kernels and is designed to be torch.compile-friendly. It addresses the major limitations of `IntraNodeComm` and `ProcessGroupCudaP2p` and serves as a replacement for them.
### SymmetricMemory
`SymmetricMemory` represents symmetric allocations across a group of devices. The allocations represented by a `SymmetricMemory` object are accessible by all devices in the group. The class can be used for **op-level custom communication patterns** (via the get_buffer APIs and the synchronization primitives), as well as **custom communication kernels** (via the buffer and signal_pad device pointers).
### Python API Example
```python
from torch._C.distributed_c10d import _SymmetricMemory
# Set a store for rendezvousing symmetric allocations on a group of devices
# identified by group_name. The concept of groups is logical; users can
# utilize predefined groups (e.g., a group of device identified by a
# ProcessGroup) or create custom ones. Note that a SymmetricMemoryAllocator
# backends might employ a more efficient communication channel for the actual
# rendezvous process and only use the store for bootstrapping purposes.
_SymmetricMemory.set_group_info(group_name, rank, world_size, store)
# Identical to empty_strided, but allows symmetric memory access to be
# established for the allocated tensor via _SymmetricMemory.rendezvous().
# This function itself is not a collective operation.
t = _SymmetricMemory.empty_strided_p2p((64, 64), (64, 1), torch.float32, group_name)
# Users can write Python custom ops that leverages the symmetric memory access.
# Below are examples of things users can do (assuming the group's world_size is 2).
# Establishes symmetric memory access on tensors allocated via
# _SymmetricMemory.empty_strided_p2p(). rendezvous() is a one-time process,
# and the mapping between a local memory region and the associated SymmetricMemory
# object is unique. Subsequent calls to rendezvous() with the same tensor will receive
# the cached SymmetricMemory object.
#
# The function has a collective semantic and must be invoked simultaneously
# from all rendezvous participants.
symm_mem = _SymmetricMemory.rendezvous(t)
# This represents the allocation on rank 0 and is accessible from all devices.
buf = symm_mem.get_buffer(0, (64, 64), torch.float32)
if symm_mem.rank == 0:
symm_mem.wait_signal(src_rank=1)
assert buf.eq(42).all()
else:
# The remote buffer can be used as a regular tensor
buf.fill_(42)
symm_mem.put_signal(dst_rank=0)
symm_mem.barrier()
if symm_mem.rank == 0:
symm_mem.barrier()
assert buf.eq(43).all()
else:
new_val = torch.empty_like(buf)
new_val.fill_(43)
# Contiguous copies to/from a remote buffer utilize copy engines
# which bypasses SMs (i.e. no need to load the data into registers)
buf.copy_(new_val)
symm_mem.barrier()
```
### Custom CUDA Comm Kernels
Given a tensor, users can access the associated `SymmetricMemory` which provides pointer to remote buffers/signal_pads needed for custom communication kernels.
```cpp
TORCH_API c10::intrusive_ptr<SymmetricMemory> get_symmetric_memory(
const at::Tensor& tensor);
class TORCH_API SymmetricMemory : public c10::intrusive_ptr_target {
public:
...
virtual std::vector<void*> get_buffer_ptrs() = 0;
virtual std::vector<void*> get_signal_pad_ptrs() = 0;
virtual void** get_buffer_ptrs_dev() = 0;
virtual void** get_signal_pad_ptrs_dev() = 0;
virtual size_t get_buffer_size() = 0;
virtual size_t get_signal_pad_size() = 0;
virtual int get_rank() = 0;
virtual int get_world_size() = 0;
...
};
```
### Limitations of IntraNodeComm and ProcessGroupCudaP2p
Both `IntraNodeComm` (used by `ProcessGroupCudaP2p`) manages a single fixed-size workspace. This approach:
- Leads to awkward UX in which the required workspace needs to be specified upfront.
- Can not avoid extra copies for some algorithms in eager mode (e.g., custom/multimem all-reduce, reduce-scatter, all-gather).
- Prevents torch.compile from eliminating all copies.
In addition, they only offer out-of-the-box communication kernels and don't expose required pointers for user-defined, custom CUDA comm kernels.
* __->__ #128582
Differential Revision: [D58849033](https://our.internmc.facebook.com/intern/diff/D58849033)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128582
Approved by: https://github.com/wanchaol
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
This PR introduces a prototype for `SymmetricMemory` (including a CUDA implementation) - a remote-memory access-based communication primitive. It allows for user-defined communication patterns/kernels and is designed to be torch.compile-friendly. It addresses the major limitations of `IntraNodeComm` and `ProcessGroupCudaP2p` and serves as a replacement for them.
### SymmetricMemory
`SymmetricMemory` represents symmetric allocations across a group of devices. The allocations represented by a `SymmetricMemory` object are accessible by all devices in the group. The class can be used for **op-level custom communication patterns** (via the get_buffer APIs and the synchronization primitives), as well as **custom communication kernels** (via the buffer and signal_pad device pointers).
### Python API Example
```python
from torch._C.distributed_c10d import _SymmetricMemory
# Set a store for rendezvousing symmetric allocations on a group of devices
# identified by group_name. The concept of groups is logical; users can
# utilize predefined groups (e.g., a group of device identified by a
# ProcessGroup) or create custom ones. Note that a SymmetricMemoryAllocator
# backends might employ a more efficient communication channel for the actual
# rendezvous process and only use the store for bootstrapping purposes.
_SymmetricMemory.set_group_info(group_name, rank, world_size, store)
# Identical to empty_strided, but allows symmetric memory access to be
# established for the allocated tensor via _SymmetricMemory.rendezvous().
# This function itself is not a collective operation.
t = _SymmetricMemory.empty_strided_p2p((64, 64), (64, 1), torch.float32, group_name)
# Users can write Python custom ops that leverages the symmetric memory access.
# Below are examples of things users can do (assuming the group's world_size is 2).
# Establishes symmetric memory access on tensors allocated via
# _SymmetricMemory.empty_strided_p2p(). rendezvous() is a one-time process,
# and the mapping between a local memory region and the associated SymmetricMemory
# object is unique. Subsequent calls to rendezvous() with the same tensor will receive
# the cached SymmetricMemory object.
#
# The function has a collective semantic and must be invoked simultaneously
# from all rendezvous participants.
symm_mem = _SymmetricMemory.rendezvous(t)
# This represents the allocation on rank 0 and is accessible from all devices.
buf = symm_mem.get_buffer(0, (64, 64), torch.float32)
if symm_mem.rank == 0:
symm_mem.wait_signal(src_rank=1)
assert buf.eq(42).all()
else:
# The remote buffer can be used as a regular tensor
buf.fill_(42)
symm_mem.put_signal(dst_rank=0)
symm_mem.barrier()
if symm_mem.rank == 0:
symm_mem.barrier()
assert buf.eq(43).all()
else:
new_val = torch.empty_like(buf)
new_val.fill_(43)
# Contiguous copies to/from a remote buffer utilize copy engines
# which bypasses SMs (i.e. no need to load the data into registers)
buf.copy_(new_val)
symm_mem.barrier()
```
### Custom CUDA Comm Kernels
Given a tensor, users can access the associated `SymmetricMemory` which provides pointer to remote buffers/signal_pads needed for custom communication kernels.
```cpp
TORCH_API c10::intrusive_ptr<SymmetricMemory> get_symmetric_memory(
const at::Tensor& tensor);
class TORCH_API SymmetricMemory : public c10::intrusive_ptr_target {
public:
...
virtual std::vector<void*> get_buffer_ptrs() = 0;
virtual std::vector<void*> get_signal_pad_ptrs() = 0;
virtual void** get_buffer_ptrs_dev() = 0;
virtual void** get_signal_pad_ptrs_dev() = 0;
virtual size_t get_buffer_size() = 0;
virtual size_t get_signal_pad_size() = 0;
virtual int get_rank() = 0;
virtual int get_world_size() = 0;
...
};
```
### Limitations of IntraNodeComm and ProcessGroupCudaP2p
Both `IntraNodeComm` (used by `ProcessGroupCudaP2p`) manages a single fixed-size workspace. This approach:
- Leads to awkward UX in which the required workspace needs to be specified upfront.
- Can not avoid extra copies for some algorithms in eager mode (e.g., custom/multimem all-reduce, reduce-scatter, all-gather).
- Prevents torch.compile from eliminating all copies.
In addition, they only offer out-of-the-box communication kernels and don't expose required pointers for user-defined, custom CUDA comm kernels.
* __->__ #128582
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128582
Approved by: https://github.com/wanchaol
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
This PR introduces a prototype for `SymmetricMemory` (including a CUDA implementation) - a remote-memory access-based communication primitive. It allows for user-defined communication patterns/kernels and is designed to be torch.compile-friendly. It addresses the major limitations of `IntraNodeComm` and `ProcessGroupCudaP2p` and serves as a replacement for them.
### SymmetricMemory
`SymmetricMemory` represents symmetric allocations across a group of devices. The allocations represented by a `SymmetricMemory` object are accessible by all devices in the group. The class can be used for **op-level custom communication patterns** (via the get_buffer APIs and the synchronization primitives), as well as **custom communication kernels** (via the buffer and signal_pad device pointers).
### Python API Example
```python
from torch._C.distributed_c10d import _SymmetricMemory
# Set a store for rendezvousing symmetric allocations on a group of devices
# identified by group_name. The concept of groups is logical; users can
# utilize predefined groups (e.g., a group of device identified by a
# ProcessGroup) or create custom ones. Note that a SymmetricMemoryAllocator
# backends might employ a more efficient communication channel for the actual
# rendezvous process and only use the store for bootstrapping purposes.
_SymmetricMemory.set_group_info(group_name, rank, world_size, store)
# Identical to empty_strided, but allows symmetric memory access to be
# established for the allocated tensor via _SymmetricMemory.rendezvous().
# This function itself is not a collective operation.
t = _SymmetricMemory.empty_strided_p2p((64, 64), (64, 1), torch.float32, group_name)
# Users can write Python custom ops that leverages the symmetric memory access.
# Below are examples of things users can do (assuming the group's world_size is 2).
# Establishes symmetric memory access on tensors allocated via
# _SymmetricMemory.empty_strided_p2p(). rendezvous() is a one-time process,
# and the mapping between a local memory region and the associated SymmetricMemory
# object is unique. Subsequent calls to rendezvous() with the same tensor will receive
# the cached SymmetricMemory object.
#
# The function has a collective semantic and must be invoked simultaneously
# from all rendezvous participants.
symm_mem = _SymmetricMemory.rendezvous(t)
# This represents the allocation on rank 0 and is accessible from all devices.
buf = symm_mem.get_buffer(0, (64, 64), torch.float32)
if symm_mem.rank == 0:
symm_mem.wait_signal(src_rank=1)
assert buf.eq(42).all()
else:
# The remote buffer can be used as a regular tensor
buf.fill_(42)
symm_mem.put_signal(dst_rank=0)
symm_mem.barrier()
if symm_mem.rank == 0:
symm_mem.barrier()
assert buf.eq(43).all()
else:
new_val = torch.empty_like(buf)
new_val.fill_(43)
# Contiguous copies to/from a remote buffer utilize copy engines
# which bypasses SMs (i.e. no need to load the data into registers)
buf.copy_(new_val)
symm_mem.barrier()
```
### Custom CUDA Comm Kernels
Given a tensor, users can access the associated `SymmetricMemory` which provides pointer to remote buffers/signal_pads needed for custom communication kernels.
```cpp
TORCH_API c10::intrusive_ptr<SymmetricMemory> get_symmetric_memory(
const at::Tensor& tensor);
class TORCH_API SymmetricMemory : public c10::intrusive_ptr_target {
public:
...
virtual std::vector<void*> get_buffer_ptrs() = 0;
virtual std::vector<void*> get_signal_pad_ptrs() = 0;
virtual void** get_buffer_ptrs_dev() = 0;
virtual void** get_signal_pad_ptrs_dev() = 0;
virtual size_t get_buffer_size() = 0;
virtual size_t get_signal_pad_size() = 0;
virtual int get_rank() = 0;
virtual int get_world_size() = 0;
...
};
```
### Limitations of IntraNodeComm and ProcessGroupCudaP2p
Both `IntraNodeComm` (used by `ProcessGroupCudaP2p`) manages a single fixed-size workspace. This approach:
- Leads to awkward UX in which the required workspace needs to be specified upfront.
- Can not avoid extra copies for some algorithms in eager mode (e.g., custom/multimem all-reduce, reduce-scatter, all-gather).
- Prevents torch.compile from eliminating all copies.
In addition, they only offer out-of-the-box communication kernels and don't expose required pointers for user-defined, custom CUDA comm kernels.
* __->__ #128582
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128582
Approved by: https://github.com/wanchaol