Revert "[SymmetricMemory] introduce multicast support, multimem_all_reduce_ and multimem_one_shot_all_reduce (#133424)"

This reverts commit 66d3eb783c.

Reverted https://github.com/pytorch/pytorch/pull/133424 on behalf of https://github.com/jeanschmidt due to Broke internal ADS builds, see D61611517 ([comment](https://github.com/pytorch/pytorch/pull/133424#issuecomment-2304676328))
PyTorch MergeBot 2024-08-22 13:29:27 +00:00
parent 592a172910
commit cedfac20c7
12 changed files with 5 additions and 731 deletions

View File

@@ -744,7 +744,6 @@ cc_library(
"torch/csrc/cuda/nccl.cpp",
"torch/csrc/distributed/c10d/intra_node_comm.cu",
"torch/csrc/distributed/c10d/CUDASymmetricMemory.cu",
"torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu",
"torch/csrc/distributed/c10d/Utils.cu",
"torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
],

View File

@@ -688,7 +688,6 @@ libtorch_cuda_distributed_extra_sources = [
"torch/csrc/distributed/c10d/intra_node_comm.cpp",
"torch/csrc/distributed/c10d/intra_node_comm.cu",
"torch/csrc/distributed/c10d/CUDASymmetricMemory.cu",
"torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu",
"torch/csrc/distributed/c10d/Utils.cu",
"torch/csrc/distributed/rpc/tensorpipe_cuda.cpp",
"torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",

View File

@@ -20,12 +20,6 @@ DriverAPI create_driver_api() {
C10_LIBCUDA_DRIVER_API(LOOKUP_LIBCUDA_ENTRY)
#undef LOOKUP_LIBCUDA_ENTRY
#define LOOKUP_LIBCUDA_ENTRY(name) \
r.name##_ = ((decltype(&name))dlsym(handle_0, #name)); \
dlerror();
C10_LIBCUDA_DRIVER_API_12030(LOOKUP_LIBCUDA_ENTRY)
#undef LOOKUP_LIBCUDA_ENTRY
if (handle_1) {
#define LOOKUP_NVML_ENTRY(name) \
r.name##_ = ((decltype(&name))dlsym(handle_1, #name)); \

View File

@@ -31,15 +31,6 @@
_(cuMemImportFromShareableHandle) \
_(cuGetErrorString)
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12030)
#define C10_LIBCUDA_DRIVER_API_12030(_) \
_(cuMulticastAddDevice) \
_(cuMulticastBindMem) \
_(cuMulticastCreate)
#else
#define C10_LIBCUDA_DRIVER_API_12030(_)
#endif
#define C10_NVML_DRIVER_API(_) \
_(nvmlInit_v2) \
_(nvmlDeviceGetHandleByPciBusId_v2) \
@@ -52,7 +43,6 @@ namespace c10::cuda {
struct DriverAPI {
#define CREATE_MEMBER(name) decltype(&name) name##_;
C10_LIBCUDA_DRIVER_API(CREATE_MEMBER)
C10_LIBCUDA_DRIVER_API_12030(CREATE_MEMBER)
C10_NVML_DRIVER_API(CREATE_MEMBER)
#undef CREATE_MEMBER
static DriverAPI* get();

View File

@@ -2,7 +2,6 @@
import torch
import torch.distributed as dist
from torch._C._autograd import DeviceType
from torch._C._distributed_c10d import _SymmetricMemory
from torch.distributed._symmetric_memory import (
_fused_all_gather_matmul_fallback,
@@ -45,17 +44,6 @@ def requires_cuda_p2p_access():
)
def requires_multicast_support():
has_multicast_support = (
torch.cuda.is_available()
and _SymmetricMemory.has_multicast_support(DeviceType.CUDA)
)
return skip_but_pass_in_sandcastle_if(
not has_multicast_support,
"multicast support is not available",
)
@instantiate_parametrized_tests
@requires_cuda_p2p_access()
class SymmetricMemoryTest(MultiProcessTestCase):
@@ -107,6 +95,7 @@ class SymmetricMemoryTest(MultiProcessTestCase):
@skipIfRocm
@skip_if_lt_x_gpu(2)
def test_cuda_nvlink_connectivity_detection(self) -> None:
from torch._C._autograd import DeviceType
from torch._C._distributed_c10d import _detect_dma_connectivity
connectivity = _detect_dma_connectivity(DeviceType.CUDA, "nvlink")
@@ -433,73 +422,6 @@ class SymmetricMemoryTest(MultiProcessTestCase):
dist.destroy_process_group()
@skip_if_lt_x_gpu(2)
@requires_multicast_support()
@parametrize("dtype", [torch.float, torch.bfloat16])
@parametrize("align_bytes", [4, 8, 16])
@parametrize("size_bytes", [4, 8192, 8196])
def test_multimem_all_reduce(
self, dtype: torch.dtype, size_bytes: int, align_bytes: int
) -> None:
self._init_process()
group_name = dist.group.WORLD.group_name
t = _SymmetricMemory.empty_strided_p2p(
size=(16384,),
stride=(1,),
dtype=dtype,
device=self.device,
group_name=group_name,
).fill_(1)
self.assertTrue(t.data_ptr() % 16 == 0)
self.assertTrue(align_bytes % t.element_size() == 0)
self.assertTrue(size_bytes % t.element_size() == 0)
shift = align_bytes // t.element_size()
numel = size_bytes // t.element_size()
x = t[shift : shift + numel]
torch.ops.symm_mem.multimem_all_reduce_(x, "sum", group_name)
self.assertTrue(x.eq(self.world_size).all().item())
# Head and tail should not be written
self.assertTrue(t[:shift].eq(1).all().item())
self.assertTrue(t[shift + numel :].eq(1).all().item())
dist.destroy_process_group()
@skip_if_lt_x_gpu(2)
@requires_multicast_support()
@parametrize("dtype", [torch.float, torch.bfloat16])
@parametrize("align_bytes", [4, 8, 16])
@parametrize("size_bytes", [4, 8192, 8196])
def test_multimem_one_shot_all_reduce(
self, dtype: torch.dtype, size_bytes: int, align_bytes: int
) -> None:
self._init_process()
group_name = dist.group.WORLD.group_name
t = _SymmetricMemory.empty_strided_p2p(
size=(16384,),
stride=(1,),
dtype=dtype,
device=self.device,
group_name=group_name,
).fill_(0)
self.assertTrue(t.data_ptr() % 16 == 0)
self.assertTrue(align_bytes % t.element_size() == 0)
self.assertTrue(size_bytes % t.element_size() == 0)
shift = align_bytes // t.element_size()
numel = size_bytes // t.element_size()
x = t[shift : shift + numel]
x.fill_(1)
res = torch.ops.symm_mem.multimem_one_shot_all_reduce(x, "sum", group_name)
self.assertTrue(res.eq(self.world_size).all().item())
dist.destroy_process_group()
if __name__ == "__main__":
run_tests()

View File

@@ -1,256 +0,0 @@
#pragma once
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && CUDART_VERSION >= 12010
#define NVCC_SUPPORTS_MULTICAST 1
#endif
#include <ATen/ATen.h>
namespace c10d::symmetric_memory {
constexpr size_t max_num_threads_per_block = 1024;
constexpr size_t max_num_blocks = 8;
template <typename T>
size_t get_alignment(T ptr_or_size) {
auto val = reinterpret_cast<uintptr_t>(ptr_or_size);
if (val % 16 == 0) {
return 16;
} else if (val % 8 == 0) {
return 8;
} else if (val % 4 == 0) {
return 4;
} else if (val % 2 == 0) {
return 2;
} else {
return 1;
}
}
template <>
size_t get_alignment<size_t>(size_t size) {
return get_alignment(reinterpret_cast<void*>(size));
}
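The get_alignment helper above maps a pointer or byte count to the largest power-of-two alignment it satisfies, capped at 16 bytes. A small host-side sketch of what it computes for the size_bytes values the removed tests parametrize over (illustrative only, not part of the reverted files; it assumes the removed header is still available to include):

#include <cassert>
#include <cstddef>
#include <torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h>

// Illustrative only: checks get_alignment against the 4 / 8192 / 8196 byte
// sizes used by the removed multimem all-reduce tests.
void get_alignment_examples() {
  using c10d::symmetric_memory::get_alignment;
  assert(get_alignment(static_cast<size_t>(4)) == 4);      // 4 % 8 != 0
  assert(get_alignment(static_cast<size_t>(8192)) == 16);  // divisible by 16
  assert(get_alignment(static_cast<size_t>(8196)) == 4);   // 8196 % 8 == 4
}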
__device__ __forceinline__ uint32_t
cas_sys(uint32_t* addr, uint32_t compare, uint32_t val) {
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
CUDA_KERNEL_ASSERT(false);
#else
uint32_t old_val;
asm volatile("atom.global.sys.cas.b32 %0, [%1], %2, %3;"
: "=r"(old_val)
: "l"(addr), "r"(compare), "r"(val)
: "memory");
return old_val;
#endif
}
__device__ __forceinline__ uint32_t
cas_release_sys(uint32_t* addr, uint32_t compare, uint32_t val) {
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
CUDA_KERNEL_ASSERT(false);
#else
uint32_t old_val;
asm volatile("atom.global.release.sys.cas.b32 %0, [%1], %2, %3;"
: "=r"(old_val)
: "l"(addr), "r"(compare), "r"(val)
: "memory");
return old_val;
#endif
}
__device__ __forceinline__ void release_signal(uint32_t* addr) {
while (cas_release_sys(addr, 0, 1) != 0)
;
}
__device__ __forceinline__ void wait_signal(uint32_t* addr) {
while (cas_sys(addr, 1, 0) != 1)
;
}
__device__ __forceinline__ uint32_t acquire_signal(uint32_t* addr) {
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
CUDA_KERNEL_ASSERT(false);
#else
uint32_t val;
asm volatile("ld.acquire.sys.global.u32 %0, [%1];"
: "=r"(val)
: "l"(addr)
: "memory");
return val;
#endif
}
// Perform a barrier to establish observation order between memory operations
// issued before and after the barrier.
__device__ __forceinline__ void barrier(
uint32_t** signal_pads,
size_t rank,
size_t world_size) {
if (threadIdx.x < world_size) {
auto target_rank = threadIdx.x;
release_signal(signal_pads[target_rank] + blockIdx.x * world_size + rank);
wait_signal(signal_pads[rank] + blockIdx.x * world_size + target_rank);
}
__syncthreads();
}
// Perform a barrier and establish causality order between memory operations
// issued before the calling kernel on all devices and memory operations
// issued after this function by all threads in the calling kernel.
//
// NOTE: this function does NOT ensure that memory operations issued in the
// current kernel are visible to all threads in the current kernel.
//
// | mem ops (guaranteed to be visible by all threads at point T)
// | kernel K
// | +- mem ops (not guaranteed to be visible to all threads at point T)
// | +- barrier_and_acquire_previous_kernel_writes()
// | +- point T
// v
__device__ __forceinline__ void barrier_and_acquire_previous_kernel_writes(
uint32_t** signal_pads,
size_t rank,
size_t world_size) {
if (threadIdx.x < world_size) {
auto target_rank = threadIdx.x;
release_signal(signal_pads[target_rank] + blockIdx.x * world_size + rank);
wait_signal(signal_pads[rank] + blockIdx.x * world_size + target_rank);
}
__syncthreads();
// At this point, we established observation order between memory operations
// issued before and after the barrier. Now we convert the observation order
// into causality order by having every thread acquire the signals released
// by threads on peer devices. Due to the implicit synchronizes-with
// relationships at task/kernel boundaries, acquiring the signal released by
// thread T in kernel K transitively acquires memory operations issued prior
// to kernel K.
//
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#memory-fence-interference
for (size_t target_rank = 0; target_rank < world_size; ++target_rank) {
acquire_signal(signal_pads[rank] + blockIdx.x * world_size + target_rank);
}
}
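The removed all-reduce kernels later in this diff use these two barriers in a fixed pattern: acquire every peer's pre-kernel writes before reading through the multicast pointer, perform the reduction, then publish the stores with a plain barrier followed by a system-wide fence. A condensed CUDA sketch of that pattern (the helpers are the ones defined above; the kernel itself is illustrative, not one of the reverted kernels):

#include <torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h>

using namespace c10d::symmetric_memory;

// Illustrative skeleton only; in the reverted kernels the real work happens
// where the "reduction loop" comment sits.
template <typename T, int alignment>
__global__ void reduce_pattern_sketch(
    T* input_mc_ptr,       // multicast pointer into the symmetric buffer
    size_t numel,
    uint32_t** signal_pads,
    size_t rank,
    size_t world_size) {
  // 1) Acquire peers' pre-kernel writes so they are visible before any reads.
  barrier_and_acquire_previous_kernel_writes(signal_pads, rank, world_size);

  // ... reduction loop: multimem_ld_reduce_add<alignment>() followed by
  //     multimem_st<alignment>() over this rank's slice of [0, numel) ...

  // 2) Observation order: all stores issued above are in flight past here.
  barrier(signal_pads, rank, world_size);
  // 3) Promote to causality order so every device sees the stores.
  __threadfence_system();
}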
template <bool Value, class... Args>
inline constexpr bool dependent_bool_value = Value;
template <class... Args>
inline constexpr bool dependent_false = dependent_bool_value<false, Args...>;
template <int Size>
union Vec;
template <>
union Vec<4> {
uint16_t u16[2];
uint32_t u32, as_scalar;
};
template <>
union Vec<8> {
uint16_t u16[4];
uint32_t u32[2];
uint64_t u64, as_scalar;
};
template <>
union alignas(16) Vec<16> {
uint16_t u16[8];
uint32_t u32[4];
uint64_t u64[2];
uint4 u128, as_scalar;
};
template <typename T>
struct MultimemLdReduce {
template <int Alignment>
__device__ __inline__ Vec<Alignment> operator()(T* mc_ptr) {
static_assert(dependent_false<T>);
}
};
template <int Alignment, typename T>
__device__ __inline__ Vec<Alignment> multimem_ld_reduce_add(T* mc_ptr) {
MultimemLdReduce<T> functor;
return functor.template operator()<Alignment>(mc_ptr);
}
#if defined(USE_ROCM) || !defined(NVCC_SUPPORTS_MULTICAST)
#define SPECIALIZE_MULTIMEM_LD_REDUCE_VEC_32(type, asm_type) \
template <> \
struct MultimemLdReduce<type> { \
template <int Alignment> \
__device__ __inline__ Vec<Alignment> operator()(type* mc_ptr) { \
CUDA_KERNEL_ASSERT(false); \
} \
};
#else
#define SPECIALIZE_MULTIMEM_LD_REDUCE_VEC_32(type, asm_type) \
template <> \
struct MultimemLdReduce<type> { \
template <int Alignment> \
__device__ __inline__ Vec<Alignment> operator()(type* mc_ptr) { \
Vec<Alignment> vec; \
if constexpr (Alignment == 16) { \
asm("multimem.ld_reduce.relaxed.sys.global.add.v4." asm_type \
" {%0,%1,%2,%3}, [%4];" \
: "=r"(vec.u32[0]), \
"=r"(vec.u32[1]), \
"=r"(vec.u32[2]), \
"=r"(vec.u32[3]) \
: "l"(mc_ptr) \
: "memory"); \
} else if constexpr (Alignment == 8) { \
asm("multimem.ld_reduce.relaxed.sys.global.add.v2." asm_type \
" {%0,%1}, [%2];" \
: "=r"(vec.u32[0]), "=r"(vec.u32[1]) \
: "l"(mc_ptr) \
: "memory"); \
} else if constexpr (Alignment == 4) { \
asm("multimem.ld_reduce.relaxed.sys.global.add." asm_type " %0, [%1];" \
: "=r"(vec.u32) \
: "l"(mc_ptr) \
: "memory"); \
} \
return vec; \
} \
};
#endif
SPECIALIZE_MULTIMEM_LD_REDUCE_VEC_32(at::BFloat16, "bf16x2");
SPECIALIZE_MULTIMEM_LD_REDUCE_VEC_32(float, "f32");
template <int Alignment, typename T>
__device__ __inline__ void multimem_st(T* mc_ptr, Vec<Alignment>& vec) {
#if defined(USE_ROCM) || !defined(NVCC_SUPPORTS_MULTICAST)
CUDA_KERNEL_ASSERT(false);
#else
if constexpr (Alignment == 16) {
asm("multimem.st.relaxed.sys.global.v4.f32 [%0], {%1,%2,%3,%4};"
:
: "l"(mc_ptr),
"r"(vec.u32[0]),
"r"(vec.u32[1]),
"r"(vec.u32[2]),
"r"(vec.u32[3])
: "memory");
} else if constexpr (Alignment == 8) {
asm("multimem.st.relaxed.sys.global.v2.f32 [%0], {%1,%2};"
:
: "l"(mc_ptr), "r"(vec.u32[0]), "r"(vec.u32[1])
: "memory");
} else if constexpr (Alignment == 4) {
asm("multimem.st.relaxed.sys.global.f32 [%0], %1;"
:
: "l"(mc_ptr), "r"(vec.u32)
: "memory");
} else {
static_assert(dependent_false<T>);
}
#endif
}
} // namespace c10d::symmetric_memory

View File

@@ -14,20 +14,8 @@
#include <sys/un.h>
#include <unistd.h>
#if defined(CUDART_VERSION) && CUDART_VERSION >= 12030
#define CUDART_SUPPORTS_MULTICAST
#endif
namespace {
bool has_multicast_support() {
#if defined(CUDART_SUPPORTS_MULTICAST)
return c10::cuda::DriverAPI::get()->cuMulticastCreate_ != nullptr;
#else
return false;
#endif
}
class IpcChannel {
public:
IpcChannel() : socket_name_(get_socket_name(getpid())) {
@@ -73,7 +61,9 @@ class IpcChannel {
memcpy(CMSG_DATA(cmsg), &fd, sizeof(fd));
TORCH_CHECK(
sendmsg(socket_, &msg, 0) > 0, "Failed to send fd: ", strerror(errno));
sendmsg(socket_, &msg, 0) > 0,
"Failed to send fd: ",
strerror(errno));
}
int recv_fd() {
@@ -120,25 +110,6 @@ class IpcChannel {
return fds;
}
int broadcast_fds(
int rank,
int src_rank,
const std::vector<int>& pids,
int fd) {
size_t world_size = pids.size();
if (rank == src_rank) {
for (int dst_rank = 0; dst_rank < (int)world_size; ++dst_rank) {
if (dst_rank == rank) {
continue;
}
send_fd(pids[dst_rank], fd);
}
return fd;
}
return recv_fd();
}
private:
static std::string get_socket_name(int pid) {
const char* tmp_dir = "/tmp";
@@ -242,8 +213,6 @@ CUDASymmetricMemory::CUDASymmetricMemory(
size_t block_size,
std::vector<void*> buffers,
std::vector<void*> signal_pads,
HandleType mc_handle,
void* mc_addr,
size_t buffer_size,
int local_device_idx,
int rank,
@@ -252,8 +221,6 @@ CUDASymmetricMemory::CUDASymmetricMemory(
block_size_(block_size),
buffers_(std::move(buffers)),
signal_pads_(std::move(signal_pads)),
mc_handle_(mc_handle),
mc_addr_(mc_addr),
buffer_size_(buffer_size),
local_device_idx_(local_device_idx),
rank_(rank),
@@ -318,14 +285,6 @@ size_t CUDASymmetricMemory::get_signal_pad_size() {
return signal_pad_size;
}
bool CUDASymmetricMemory::has_multicast_support() {
return ::has_multicast_support();
}
void* CUDASymmetricMemory::get_multicast_ptr() {
return mc_addr_;
}
at::Tensor CUDASymmetricMemory::get_buffer(
int rank,
c10::IntArrayRef sizes,
@@ -642,46 +601,6 @@ c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
store_barrier(store, rank, world_size);
close(block_fd);
CUmemGenericAllocationHandle mc_handle{};
void* mc_addr = nullptr;
#if defined(CUDART_SUPPORTS_MULTICAST)
// We have to further check if the driver supports multicast
if (has_multicast_support()) {
// Rank 0 creates a multicast object and shares it with peers
if (rank == 0) {
CUmulticastObjectProp mc_prop{};
mc_prop.numDevices = world_size;
mc_prop.handleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
mc_prop.size = block->block_size;
CUresult res = driver_api->cuMulticastCreate_(&mc_handle, &mc_prop);
TORCH_CHECK(res == CUDA_SUCCESS);
int mc_fd;
C10_CUDA_DRIVER_CHECK(driver_api->cuMemExportToShareableHandle_(
&mc_fd, mc_handle, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0));
ipc_channel.broadcast_fds(rank, 0, pids, mc_fd);
// Ref count is incremented as soon as SCM_RIGHTS send happens
close(mc_fd);
} else {
int mc_fd = ipc_channel.broadcast_fds(rank, 0, pids, -1);
C10_CUDA_DRIVER_CHECK(driver_api->cuMemImportFromShareableHandle_(
&mc_handle,
(void*)(uintptr_t)mc_fd,
CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
close(mc_fd);
}
// All ranks add their physical allocation to the multicast object
C10_CUDA_DRIVER_CHECK(
driver_api->cuMulticastAddDevice_(mc_handle, block->device_idx));
C10_CUDA_DRIVER_CHECK(driver_api->cuMulticastBindMem_(
mc_handle, 0, block->handle, 0, block->block_size, 0));
map_block(&mc_addr, mc_handle, block->block_size, block->device_idx);
store_barrier(store, rank, world_size);
}
#endif
// Initializing CUDASymmetricMemory with an allocation transfers its
// ownership to the CUDASymmetricMemory object. So that outstanding
// references to the CUDASymmetricMemory object can keep the allocation
@@ -691,8 +610,6 @@ c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
block->block_size,
std::move(buffers),
std::move(signal_pads),
mc_handle,
mc_addr,
block->buffer_size,
block->device_idx,
group_info.rank,
@@ -713,10 +630,6 @@ bool CUDASymmetricMemoryAllocator::is_rendezvous_completed(void* ptr) {
return block->symm_mem != nullptr;
}
bool CUDASymmetricMemoryAllocator::has_multicast_support() {
return ::has_multicast_support();
}
c10::intrusive_ptr<Block> CUDASymmetricMemoryAllocator::find_block(void* ptr) {
std::shared_lock lock(mutex_);
auto it = ptr_to_block_.find(ptr);

View File

@@ -20,8 +20,6 @@ class CUDASymmetricMemory : public SymmetricMemory {
size_t block_size,
std::vector<void*> buffers,
std::vector<void*> signal_pads,
HandleType mc_handle,
void* mc_addr,
size_t buffer_size,
int local_device_idx,
int rank,
@@ -36,9 +34,6 @@ class CUDASymmetricMemory : public SymmetricMemory {
size_t get_buffer_size() override;
size_t get_signal_pad_size() override;
bool has_multicast_support() override;
void* get_multicast_ptr() override;
at::Tensor get_buffer(
int rank,
c10::IntArrayRef sizes,
@@ -57,8 +52,6 @@ class CUDASymmetricMemory : public SymmetricMemory {
size_t block_size_;
std::vector<void*> buffers_;
std::vector<void*> signal_pads_;
HandleType mc_handle_;
void* mc_addr_;
size_t buffer_size_;
int local_device_idx_;
int rank_;
@@ -102,7 +95,6 @@ class CUDASymmetricMemoryAllocator : public SymmetricMemoryAllocator {
size_t get_alloc_size(void* ptr) override;
c10::intrusive_ptr<SymmetricMemory> rendezvous(void* ptr) override;
bool is_rendezvous_completed(void* ptr) override;
bool has_multicast_support() override;
private:
c10::intrusive_ptr<Block> find_block(void* ptr);

View File

@@ -1,267 +0,0 @@
#if defined(CUDART_VERSION) && CUDART_VERSION >= 12030
#include <ATen/ATen.h>
#include <ATen/ceil_div.h>
#include <ATen/cuda/CUDAContext.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty_like.h>
#endif
#include <torch/library.h>
#include <torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h>
#include <torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp>
namespace {
using namespace c10d::symmetric_memory;
size_t get_and_verify_alignment(const at::Tensor& input, const char* op_name) {
const size_t min_alignment = std::max(4l, input.element_size());
// Only check the offset since the multicast address is always at least
// 128-bit aligned
const size_t ptr_alignment = get_alignment(
static_cast<size_t>(input.storage_offset() * input.element_size()));
TORCH_CHECK(
ptr_alignment >= min_alignment,
op_name,
"<",
input.scalar_type(),
">: input ptr + offset must be at least ",
min_alignment,
"-byte aligned.");
const size_t size_alignment =
get_alignment(static_cast<size_t>(input.numel() * input.element_size()));
TORCH_CHECK(
size_alignment >= min_alignment,
op_name,
"<",
input.scalar_type(),
">: input size must be at least ",
min_alignment,
"-byte aligned.");
return std::min(ptr_alignment, size_alignment);
}
void init_elementwise_launch_config(
size_t numel,
size_t element_size,
size_t alignment,
size_t splits,
int& num_blocks,
int& num_threads) {
// Align to preserve alignment in each split
const size_t aligned_numel = at::round_up(numel, alignment * splits);
const size_t numel_per_split = aligned_numel / splits;
const size_t numel_per_thread = alignment / element_size;
if (numel_per_split <= max_num_threads_per_block * numel_per_thread) {
num_blocks = 1;
num_threads = at::round_up(
at::ceil_div(numel_per_split, numel_per_thread),
static_cast<size_t>(C10_WARP_SIZE));
} else {
num_blocks = std::min(
at::ceil_div(
numel_per_split, max_num_threads_per_block * numel_per_thread),
max_num_blocks);
num_threads = max_num_threads_per_block;
}
}
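As a worked example of the launch configuration above, consider the case the removed tests cover with dtype=torch.float, size_bytes=8192 and align_bytes=16 on two ranks: each split covers 1024 elements and each thread handles 4, so a single block of 256 threads suffices. The sketch below redoes that arithmetic with max_num_threads_per_block = 1024 from the removed header and the CUDA warp size of 32 for C10_WARP_SIZE (illustrative only, not part of the reverted files):

#include <cassert>
#include <cstddef>

// Illustrative only: mirrors init_elementwise_launch_config for
// 8192 bytes of float input, 16-byte alignment, world_size (splits) == 2.
void launch_config_example() {
  const size_t numel = 8192 / sizeof(float);   // 2048 elements
  const size_t element_size = sizeof(float);   // 4 bytes
  const size_t alignment = 16;                 // from get_and_verify_alignment
  const size_t splits = 2;                     // world_size for multimem_all_reduce_
  const size_t max_threads = 1024;             // max_num_threads_per_block
  const size_t warp = 32;                      // C10_WARP_SIZE on NVIDIA GPUs

  const size_t chunk = alignment * splits;                          // 32
  const size_t aligned_numel = (numel + chunk - 1) / chunk * chunk; // 2048
  const size_t numel_per_split = aligned_numel / splits;            // 1024
  const size_t numel_per_thread = alignment / element_size;         // 4

  assert(numel_per_split <= max_threads * numel_per_thread);        // one block
  const size_t threads =
      (numel_per_split / numel_per_thread + warp - 1) / warp * warp; // 256
  assert(threads == 256);
}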
template <typename T, int alignment>
static __global__ void multimem_all_reduce_kernel(
T* input_mc_ptr,
size_t numel,
uint32_t** signal_pads,
size_t rank,
size_t world_size) {
static_assert(alignment % sizeof(T) == 0);
constexpr size_t numel_per_thread = alignment / sizeof(T);
barrier_and_acquire_previous_kernel_writes(signal_pads, rank, world_size);
const size_t numel_per_rank =
at::round_up(numel, alignment * world_size) / world_size;
const size_t start = numel_per_rank * rank;
auto offset = (blockDim.x * blockIdx.x + threadIdx.x) * numel_per_thread;
auto stride = blockDim.x * gridDim.x * numel_per_thread;
for (size_t i = offset; i < numel_per_rank; i += stride) {
if (start + i >= numel) {
continue;
}
auto vec = multimem_ld_reduce_add<alignment>(input_mc_ptr + start + i);
multimem_st<alignment>(input_mc_ptr + start + i, vec);
}
// Establish observation order - all writes are in-flight beyond this point.
barrier(signal_pads, rank, world_size);
// Establish causality order - all writes are visible to all devices beyond
// this point.
__threadfence_system();
}
at::Tensor multimem_all_reduce_(
const at::Tensor& input,
std::string reduce_op,
std::string group_name) {
TORCH_CHECK(
input.is_contiguous(), "multimem_all_reduce_: input must be contiguous.");
TORCH_CHECK(
reduce_op == "sum",
"multimem_all_reduce_: only sum is supported for now.");
auto symm_mem = c10d::symmetric_memory::rendezvous(input);
TORCH_CHECK(
symm_mem != nullptr,
"multimem_all_reduce_: input must be allocated with empty_strided_p2p().");
TORCH_CHECK(
symm_mem->has_multicast_support(),
"multimem_all_reduce_: multicast support is required.");
const size_t alignment =
get_and_verify_alignment(input, "multimem_all_reduce_");
int num_blocks = 0, num_threads = 0;
init_elementwise_launch_config(
input.numel(),
input.element_size(),
alignment,
symm_mem->get_world_size(),
num_blocks,
num_threads);
#define DISPATCH(scalar_t, kernel_alignment) \
if (alignment == kernel_alignment) { \
multimem_all_reduce_kernel<scalar_t, kernel_alignment> \
<<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>( \
reinterpret_cast<scalar_t*>(symm_mem->get_multicast_ptr()) + \
input.storage_offset(), \
input.numel(), \
reinterpret_cast<uint32_t**>(symm_mem->get_signal_pad_ptrs_dev()), \
symm_mem->get_rank(), \
symm_mem->get_world_size()); \
C10_CUDA_KERNEL_LAUNCH_CHECK(); \
}
AT_DISPATCH_SWITCH(
input.scalar_type(),
"multimem_all_reduce",
AT_DISPATCH_CASE(at::kBFloat16, [&] {
DISPATCH(scalar_t, 16);
DISPATCH(scalar_t, 8);
DISPATCH(scalar_t, 4);
}) AT_DISPATCH_CASE(at::kFloat, [&] {
DISPATCH(scalar_t, 16);
DISPATCH(scalar_t, 8);
DISPATCH(scalar_t, 4);
}));
#undef DISPATCH
return input;
}
template <typename T, int alignment>
static __global__ void multimem_one_shot_all_reduce_kernel(
T* input_mc_ptr,
T* output_ptr,
size_t numel,
uint32_t** signal_pads,
size_t rank,
size_t world_size) {
static_assert(alignment % sizeof(T) == 0);
constexpr size_t numel_per_thread = alignment / sizeof(T);
barrier_and_acquire_previous_kernel_writes(signal_pads, rank, world_size);
auto offset = (blockDim.x * blockIdx.x + threadIdx.x) * numel_per_thread;
auto stride = blockDim.x * gridDim.x * numel_per_thread;
for (size_t i = offset; i < numel; i += stride) {
auto vec = multimem_ld_reduce_add<alignment>(input_mc_ptr + i);
*reinterpret_cast<decltype(vec.as_scalar)*>(output_ptr + i) = vec.as_scalar;
}
}
at::Tensor multimem_one_shot_all_reduce(
const at::Tensor& input,
std::string reduce_op,
std::string group_name) {
TORCH_CHECK(
input.is_contiguous(),
"multimem_one_shot_all_reduce: input must be contiguous.");
TORCH_CHECK(
reduce_op == "sum",
"multimem_one_shot_all_reduce: only sum is supported for now.");
auto symm_mem = c10d::symmetric_memory::rendezvous(input);
TORCH_CHECK(
symm_mem != nullptr,
"multimem_one_shot_all_reduce: input must be allocated with empty_strided_p2p().");
TORCH_CHECK(
symm_mem->has_multicast_support(),
"multimem_one_shot_all_reduce: requires multicast support.");
auto output = at::empty_like(input);
const size_t alignment =
get_and_verify_alignment(input, "multimem_one_shot_all_reduce");
int num_blocks = 0, num_threads = 0;
init_elementwise_launch_config(
input.numel(),
input.element_size(),
alignment,
1,
num_blocks,
num_threads);
#define DISPATCH(scalar_t, kernel_alignment) \
if (alignment == kernel_alignment) { \
multimem_one_shot_all_reduce_kernel<scalar_t, kernel_alignment> \
<<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>( \
reinterpret_cast<scalar_t*>(symm_mem->get_multicast_ptr()) + \
input.storage_offset(), \
output.data_ptr<scalar_t>(), \
input.numel(), \
reinterpret_cast<uint32_t**>(symm_mem->get_signal_pad_ptrs_dev()), \
symm_mem->get_rank(), \
symm_mem->get_world_size()); \
C10_CUDA_KERNEL_LAUNCH_CHECK(); \
}
AT_DISPATCH_SWITCH(
input.scalar_type(),
"multimem_all_reduce",
AT_DISPATCH_CASE(at::kBFloat16, [&] {
DISPATCH(scalar_t, 16);
DISPATCH(scalar_t, 8);
DISPATCH(scalar_t, 4);
}) AT_DISPATCH_CASE(at::kFloat, [&] {
DISPATCH(scalar_t, 16);
DISPATCH(scalar_t, 8);
DISPATCH(scalar_t, 4);
}));
return output;
}
TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
m.def(
"multimem_all_reduce_(Tensor input, str reduce_op, str group_name) -> Tensor",
torch::dispatch(c10::DispatchKey::CUDA, ::multimem_all_reduce_),
{at::Tag::pt2_compliant_tag});
m.def(
"multimem_one_shot_all_reduce(Tensor input, str reduce_op, str group_name) -> Tensor",
torch::dispatch(c10::DispatchKey::CUDA, ::multimem_one_shot_all_reduce),
{at::Tag::pt2_compliant_tag});
}
} // namespace
#endif

View File

@@ -176,7 +176,7 @@ at::Tensor empty_strided_p2p(
TORCH_API c10::intrusive_ptr<SymmetricMemory> rendezvous(
const at::Tensor& tensor) {
auto allocator = get_allocator(tensor.device().type());
return allocator->rendezvous(tensor.storage().data_ptr().get());
return allocator->rendezvous(tensor.data_ptr());
}
c10::intrusive_ptr<SymmetricMemory> get_symmetric_memory(
@@ -189,9 +189,5 @@ c10::intrusive_ptr<SymmetricMemory> get_symmetric_memory(
return allocator->rendezvous(tensor.data_ptr());
}
TORCH_API bool has_multicast_support(c10::DeviceType device_type) {
auto allocator = get_allocator(device_type);
return allocator->has_multicast_support();
}
} // namespace symmetric_memory
} // namespace c10d

View File

@@ -51,9 +51,6 @@ class TORCH_API SymmetricMemory : public c10::intrusive_ptr_target {
virtual size_t get_buffer_size() = 0;
virtual size_t get_signal_pad_size() = 0;
virtual bool has_multicast_support() = 0;
virtual void* get_multicast_ptr() = 0;
virtual at::Tensor get_buffer(
int rank,
c10::IntArrayRef sizes,
@@ -81,7 +78,6 @@ class SymmetricMemoryAllocator : public c10::intrusive_ptr_target {
virtual size_t get_alloc_size(void* ptr) = 0;
virtual c10::intrusive_ptr<SymmetricMemory> rendezvous(void* ptr) = 0;
virtual bool is_rendezvous_completed(void* ptr) = 0;
virtual bool has_multicast_support() = 0;
};
C10_EXPORT bool is_finalizing();
@@ -154,6 +150,5 @@ TORCH_API c10::intrusive_ptr<SymmetricMemory> rendezvous(
TORCH_API c10::intrusive_ptr<SymmetricMemory> get_symmetric_memory(
const at::Tensor& tensor);
TORCH_API bool has_multicast_support(c10::DeviceType device_type);
} // namespace symmetric_memory
} // namespace c10d

View File

@@ -1044,9 +1044,6 @@ This class does not support ``__members__`` property.)");
.def_static(
"get_symmetric_memory",
&::c10d::symmetric_memory::get_symmetric_memory)
.def_static(
"has_multicast_support",
&::c10d::symmetric_memory::has_multicast_support)
.def_property_readonly("rank", &SymmetricMemory::get_rank)
.def_property_readonly("world_size", &SymmetricMemory::get_world_size)
.def_property_readonly(