mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert "[SymmetricMemory] introduce multicast support, multimem_all_reduce_ and multimem_one_shot_all_reduce (#133424)"
This reverts commit 66d3eb783c.
Reverted https://github.com/pytorch/pytorch/pull/133424 on behalf of https://github.com/jeanschmidt due to Broke internal ADS builds, see D61611517 ([comment](https://github.com/pytorch/pytorch/pull/133424#issuecomment-2304676328))
This commit is contained in:
parent
592a172910
commit
cedfac20c7
|
|
@ -744,7 +744,6 @@ cc_library(
|
|||
"torch/csrc/cuda/nccl.cpp",
|
||||
"torch/csrc/distributed/c10d/intra_node_comm.cu",
|
||||
"torch/csrc/distributed/c10d/CUDASymmetricMemory.cu",
|
||||
"torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu",
|
||||
"torch/csrc/distributed/c10d/Utils.cu",
|
||||
"torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
|
||||
],
|
||||
|
|
|
|||
|
|
@ -688,7 +688,6 @@ libtorch_cuda_distributed_extra_sources = [
|
|||
"torch/csrc/distributed/c10d/intra_node_comm.cpp",
|
||||
"torch/csrc/distributed/c10d/intra_node_comm.cu",
|
||||
"torch/csrc/distributed/c10d/CUDASymmetricMemory.cu",
|
||||
"torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu",
|
||||
"torch/csrc/distributed/c10d/Utils.cu",
|
||||
"torch/csrc/distributed/rpc/tensorpipe_cuda.cpp",
|
||||
"torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
|
||||
|
|
|
|||
|
|
@ -20,12 +20,6 @@ DriverAPI create_driver_api() {
|
|||
C10_LIBCUDA_DRIVER_API(LOOKUP_LIBCUDA_ENTRY)
|
||||
#undef LOOKUP_LIBCUDA_ENTRY
|
||||
|
||||
#define LOOKUP_LIBCUDA_ENTRY(name) \
|
||||
r.name##_ = ((decltype(&name))dlsym(handle_0, #name)); \
|
||||
dlerror();
|
||||
C10_LIBCUDA_DRIVER_API_12030(LOOKUP_LIBCUDA_ENTRY)
|
||||
#undef LOOKUP_LIBCUDA_ENTRY
|
||||
|
||||
if (handle_1) {
|
||||
#define LOOKUP_NVML_ENTRY(name) \
|
||||
r.name##_ = ((decltype(&name))dlsym(handle_1, #name)); \
|
||||
|
|
|
|||
|
|
@ -31,15 +31,6 @@
|
|||
_(cuMemImportFromShareableHandle) \
|
||||
_(cuGetErrorString)
|
||||
|
||||
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12030)
|
||||
#define C10_LIBCUDA_DRIVER_API_12030(_) \
|
||||
_(cuMulticastAddDevice) \
|
||||
_(cuMulticastBindMem) \
|
||||
_(cuMulticastCreate)
|
||||
#else
|
||||
#define C10_LIBCUDA_DRIVER_API_12030(_)
|
||||
#endif
|
||||
|
||||
#define C10_NVML_DRIVER_API(_) \
|
||||
_(nvmlInit_v2) \
|
||||
_(nvmlDeviceGetHandleByPciBusId_v2) \
|
||||
|
|
@ -52,7 +43,6 @@ namespace c10::cuda {
|
|||
struct DriverAPI {
|
||||
#define CREATE_MEMBER(name) decltype(&name) name##_;
|
||||
C10_LIBCUDA_DRIVER_API(CREATE_MEMBER)
|
||||
C10_LIBCUDA_DRIVER_API_12030(CREATE_MEMBER)
|
||||
C10_NVML_DRIVER_API(CREATE_MEMBER)
|
||||
#undef CREATE_MEMBER
|
||||
static DriverAPI* get();
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch._C._autograd import DeviceType
|
||||
from torch._C._distributed_c10d import _SymmetricMemory
|
||||
from torch.distributed._symmetric_memory import (
|
||||
_fused_all_gather_matmul_fallback,
|
||||
|
|
@ -45,17 +44,6 @@ def requires_cuda_p2p_access():
|
|||
)
|
||||
|
||||
|
||||
def requires_multicast_support():
|
||||
has_multicast_support = (
|
||||
torch.cuda.is_available()
|
||||
and _SymmetricMemory.has_multicast_support(DeviceType.CUDA)
|
||||
)
|
||||
return skip_but_pass_in_sandcastle_if(
|
||||
not has_multicast_support,
|
||||
"multicast support is not available",
|
||||
)
|
||||
|
||||
|
||||
@instantiate_parametrized_tests
|
||||
@requires_cuda_p2p_access()
|
||||
class SymmetricMemoryTest(MultiProcessTestCase):
|
||||
|
|
@ -107,6 +95,7 @@ class SymmetricMemoryTest(MultiProcessTestCase):
|
|||
@skipIfRocm
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_cuda_nvlink_connectivity_detection(self) -> None:
|
||||
from torch._C._autograd import DeviceType
|
||||
from torch._C._distributed_c10d import _detect_dma_connectivity
|
||||
|
||||
connectivity = _detect_dma_connectivity(DeviceType.CUDA, "nvlink")
|
||||
|
|
@ -433,73 +422,6 @@ class SymmetricMemoryTest(MultiProcessTestCase):
|
|||
|
||||
dist.destroy_process_group()
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@requires_multicast_support()
|
||||
@parametrize("dtype", [torch.float, torch.bfloat16])
|
||||
@parametrize("align_bytes", [4, 8, 16])
|
||||
@parametrize("size_bytes", [4, 8192, 8196])
|
||||
def test_multimem_all_reduce(
|
||||
self, dtype: torch.dtype, size_bytes: int, align_bytes: int
|
||||
) -> None:
|
||||
self._init_process()
|
||||
group_name = dist.group.WORLD.group_name
|
||||
|
||||
t = _SymmetricMemory.empty_strided_p2p(
|
||||
size=(16384,),
|
||||
stride=(1,),
|
||||
dtype=dtype,
|
||||
device=self.device,
|
||||
group_name=group_name,
|
||||
).fill_(1)
|
||||
|
||||
self.assertTrue(t.data_ptr() % 16 == 0)
|
||||
self.assertTrue(align_bytes % t.element_size() == 0)
|
||||
self.assertTrue(size_bytes % t.element_size() == 0)
|
||||
|
||||
shift = align_bytes // t.element_size()
|
||||
numel = size_bytes // t.element_size()
|
||||
x = t[shift : shift + numel]
|
||||
|
||||
torch.ops.symm_mem.multimem_all_reduce_(x, "sum", group_name)
|
||||
self.assertTrue(x.eq(self.world_size).all().item())
|
||||
|
||||
# Head and tail should not be written
|
||||
self.assertTrue(t[:shift].eq(1).all().item())
|
||||
self.assertTrue(t[shift + numel :].eq(1).all().item())
|
||||
dist.destroy_process_group()
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@requires_multicast_support()
|
||||
@parametrize("dtype", [torch.float, torch.bfloat16])
|
||||
@parametrize("align_bytes", [4, 8, 16])
|
||||
@parametrize("size_bytes", [4, 8192, 8196])
|
||||
def test_multimem_one_shot_all_reduce(
|
||||
self, dtype: torch.dtype, size_bytes: int, align_bytes: int
|
||||
) -> None:
|
||||
self._init_process()
|
||||
group_name = dist.group.WORLD.group_name
|
||||
|
||||
t = _SymmetricMemory.empty_strided_p2p(
|
||||
size=(16384,),
|
||||
stride=(1,),
|
||||
dtype=dtype,
|
||||
device=self.device,
|
||||
group_name=group_name,
|
||||
).fill_(0)
|
||||
|
||||
self.assertTrue(t.data_ptr() % 16 == 0)
|
||||
self.assertTrue(align_bytes % t.element_size() == 0)
|
||||
self.assertTrue(size_bytes % t.element_size() == 0)
|
||||
|
||||
shift = align_bytes // t.element_size()
|
||||
numel = size_bytes // t.element_size()
|
||||
x = t[shift : shift + numel]
|
||||
x.fill_(1)
|
||||
|
||||
res = torch.ops.symm_mem.multimem_one_shot_all_reduce(x, "sum", group_name)
|
||||
self.assertTrue(res.eq(self.world_size).all().item())
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -1,256 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && CUDART_VERSION >= 12010
|
||||
#define NVCC_SUPPORTS_MULTICAST 1
|
||||
#endif
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
namespace c10d::symmetric_memory {
|
||||
|
||||
constexpr size_t max_num_threads_per_block = 1024;
|
||||
constexpr size_t max_num_blocks = 8;
|
||||
|
||||
template <typename T>
|
||||
size_t get_alignment(T ptr_or_size) {
|
||||
auto val = reinterpret_cast<uintptr_t>(ptr_or_size);
|
||||
if (val % 16 == 0) {
|
||||
return 16;
|
||||
} else if (val % 8 == 0) {
|
||||
return 8;
|
||||
} else if (val % 4 == 0) {
|
||||
return 4;
|
||||
} else if (val % 2 == 0) {
|
||||
return 2;
|
||||
} else {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
size_t get_alignment<size_t>(size_t size) {
|
||||
return get_alignment(reinterpret_cast<void*>(size));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint32_t
|
||||
cas_sys(uint32_t* addr, uint32_t compare, uint32_t val) {
|
||||
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
|
||||
CUDA_KERNEL_ASSERT(false);
|
||||
#else
|
||||
uint32_t old_val;
|
||||
asm volatile("atom.global.sys.cas.b32 %0, [%1], %2, %3;"
|
||||
: "=r"(old_val)
|
||||
: "l"(addr), "r"(compare), "r"(val)
|
||||
: "memory");
|
||||
return old_val;
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint32_t
|
||||
cas_release_sys(uint32_t* addr, uint32_t compare, uint32_t val) {
|
||||
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
|
||||
CUDA_KERNEL_ASSERT(false);
|
||||
#else
|
||||
uint32_t old_val;
|
||||
asm volatile("atom.global.release.sys.cas.b32 %0, [%1], %2, %3;"
|
||||
: "=r"(old_val)
|
||||
: "l"(addr), "r"(compare), "r"(val)
|
||||
: "memory");
|
||||
return old_val;
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void release_signal(uint32_t* addr) {
|
||||
while (cas_release_sys(addr, 0, 1) != 0)
|
||||
;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void wait_signal(uint32_t* addr) {
|
||||
while (cas_sys(addr, 1, 0) != 1)
|
||||
;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint32_t acquire_signal(uint32_t* addr) {
|
||||
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
|
||||
CUDA_KERNEL_ASSERT(false);
|
||||
#else
|
||||
uint32_t val;
|
||||
asm volatile("ld.acquire.sys.global.u32 %0, [%1];"
|
||||
: "=r"(val)
|
||||
: "l"(addr)
|
||||
: "memory");
|
||||
return val;
|
||||
#endif
|
||||
}
|
||||
|
||||
// Perform a barrier to establish observation order between memory operations
|
||||
// issued before and after the barrier.
|
||||
__device__ __forceinline__ void barrier(
|
||||
uint32_t** signal_pads,
|
||||
size_t rank,
|
||||
size_t world_size) {
|
||||
if (threadIdx.x < world_size) {
|
||||
auto target_rank = threadIdx.x;
|
||||
release_signal(signal_pads[target_rank] + blockIdx.x * world_size + rank);
|
||||
wait_signal(signal_pads[rank] + blockIdx.x * world_size + target_rank);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Perform a barrier and establish causality order between memory operations
|
||||
// issued before the calling kernel on all devices and memory operations
|
||||
// issued after this function by all thread in the calling kernel.
|
||||
//
|
||||
// NOTE: this function does NOT ensure that memory operations issues in the
|
||||
// current kernel are visible to all threads in the current kernel.
|
||||
//
|
||||
// | mem ops (guaranteed to be visible by all threads at point T)
|
||||
// | kernel K
|
||||
// | +- mem ops (not guaranteed to be visible all threads at point T)
|
||||
// | +- barrier_and_acquire_previous_kernel_writes()
|
||||
// | +- point T
|
||||
// v
|
||||
__device__ __forceinline__ void barrier_and_acquire_previous_kernel_writes(
|
||||
uint32_t** signal_pads,
|
||||
size_t rank,
|
||||
size_t world_size) {
|
||||
if (threadIdx.x < world_size) {
|
||||
auto target_rank = threadIdx.x;
|
||||
release_signal(signal_pads[target_rank] + blockIdx.x * world_size + rank);
|
||||
wait_signal(signal_pads[rank] + blockIdx.x * world_size + target_rank);
|
||||
}
|
||||
__syncthreads();
|
||||
// At this point, we established observation order between memory operations
|
||||
// issued before and after the barrier. Now we convert the observation order
|
||||
// into causality order by having every thread acquire the signals released
|
||||
// by threads on peer devices. Due to the implicit synchronizes-with
|
||||
// relationships at task/kernel boundaries, acquiring the signal released by
|
||||
// thread T in kernel K transitively acquires memory operations issued prior
|
||||
// to kernel K.
|
||||
//
|
||||
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#memory-fence-interference
|
||||
for (size_t target_rank = 0; target_rank < world_size; ++target_rank) {
|
||||
acquire_signal(signal_pads[rank] + blockIdx.x * world_size + target_rank);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool Value, class... Args>
|
||||
inline constexpr bool dependent_bool_value = Value;
|
||||
|
||||
template <class... Args>
|
||||
inline constexpr bool dependent_false = dependent_bool_value<false, Args...>;
|
||||
|
||||
template <int Size>
|
||||
union Vec;
|
||||
|
||||
template <>
|
||||
union Vec<4> {
|
||||
uint16_t u16[2];
|
||||
uint32_t u32, as_scalar;
|
||||
};
|
||||
|
||||
template <>
|
||||
union Vec<8> {
|
||||
uint16_t u16[4];
|
||||
uint32_t u32[2];
|
||||
uint64_t u64, as_scalar;
|
||||
};
|
||||
|
||||
template <>
|
||||
union alignas(16) Vec<16> {
|
||||
uint16_t u16[8];
|
||||
uint32_t u32[4];
|
||||
uint64_t u64[2];
|
||||
uint4 u128, as_scalar;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct MultimemLdReduce {
|
||||
template <int Alignment>
|
||||
__device__ __inline__ Vec<Alignment> operator()(T* mc_ptr) {
|
||||
static_assert(dependent_false<T>);
|
||||
}
|
||||
};
|
||||
|
||||
template <int Alignment, typename T>
|
||||
__device__ __inline__ Vec<Alignment> multimem_ld_reduce_add(T* mc_ptr) {
|
||||
MultimemLdReduce<T> functor;
|
||||
return functor.template operator()<Alignment>(mc_ptr);
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM) || !defined(NVCC_SUPPORTS_MULTICAST)
|
||||
#define SPECIALIZE_MULTIMEM_LD_REDUCE_VEC_32(type, asm_type) \
|
||||
template <> \
|
||||
struct MultimemLdReduce<type> { \
|
||||
template <int Alignment> \
|
||||
__device__ __inline__ Vec<Alignment> operator()(type* mc_ptr) { \
|
||||
CUDA_KERNEL_ASSERT(false); \
|
||||
} \
|
||||
};
|
||||
#else
|
||||
#define SPECIALIZE_MULTIMEM_LD_REDUCE_VEC_32(type, asm_type) \
|
||||
template <> \
|
||||
struct MultimemLdReduce<type> { \
|
||||
template <int Alignment> \
|
||||
__device__ __inline__ Vec<Alignment> operator()(type* mc_ptr) { \
|
||||
Vec<Alignment> vec; \
|
||||
if constexpr (Alignment == 16) { \
|
||||
asm("multimem.ld_reduce.relaxed.sys.global.add.v4." asm_type \
|
||||
" {%0,%1,%2,%3}, [%4];" \
|
||||
: "=r"(vec.u32[0]), \
|
||||
"=r"(vec.u32[1]), \
|
||||
"=r"(vec.u32[2]), \
|
||||
"=r"(vec.u32[3]) \
|
||||
: "l"(mc_ptr) \
|
||||
: "memory"); \
|
||||
} else if constexpr (Alignment == 8) { \
|
||||
asm("multimem.ld_reduce.relaxed.sys.global.add.v2." asm_type \
|
||||
" {%0,%1}, [%2];" \
|
||||
: "=r"(vec.u32[0]), "=r"(vec.u32[1]) \
|
||||
: "l"(mc_ptr) \
|
||||
: "memory"); \
|
||||
} else if constexpr (Alignment == 4) { \
|
||||
asm("multimem.ld_reduce.relaxed.sys.global.add." asm_type " %0, [%1];" \
|
||||
: "=r"(vec.u32) \
|
||||
: "l"(mc_ptr) \
|
||||
: "memory"); \
|
||||
} \
|
||||
return vec; \
|
||||
} \
|
||||
};
|
||||
#endif
|
||||
|
||||
SPECIALIZE_MULTIMEM_LD_REDUCE_VEC_32(at::BFloat16, "bf16x2");
|
||||
SPECIALIZE_MULTIMEM_LD_REDUCE_VEC_32(float, "f32");
|
||||
|
||||
template <int Alignment, typename T>
|
||||
__device__ __inline__ void multimem_st(T* mc_ptr, Vec<Alignment>& vec) {
|
||||
#if defined(USE_ROCM) || !defined(NVCC_SUPPORTS_MULTICAST)
|
||||
CUDA_KERNEL_ASSERT(false);
|
||||
#else
|
||||
if constexpr (Alignment == 16) {
|
||||
asm("multimem.st.relaxed.sys.global.v4.f32 [%0], {%1,%2,%3,%4};"
|
||||
:
|
||||
: "l"(mc_ptr),
|
||||
"r"(vec.u32[0]),
|
||||
"r"(vec.u32[1]),
|
||||
"r"(vec.u32[2]),
|
||||
"r"(vec.u32[3])
|
||||
: "memory");
|
||||
} else if constexpr (Alignment == 8) {
|
||||
asm("multimem.st.relaxed.sys.global.v2.f32 [%0], {%1,%2};"
|
||||
:
|
||||
: "l"(mc_ptr), "r"(vec.u32[0]), "r"(vec.u32[1])
|
||||
: "memory");
|
||||
} else if constexpr (Alignment == 4) {
|
||||
asm("multimem.st.relaxed.sys.global.f32 [%0], %1;"
|
||||
:
|
||||
: "l"(mc_ptr), "r"(vec.u32)
|
||||
: "memory");
|
||||
} else {
|
||||
static_assert(dependent_false<T>);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace c10d::symmetric_memory
|
||||
|
|
@ -14,20 +14,8 @@
|
|||
#include <sys/un.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#if defined(CUDART_VERSION) && CUDART_VERSION >= 12030
|
||||
#define CUDART_SUPPORTS_MULTICAST
|
||||
#endif
|
||||
|
||||
namespace {
|
||||
|
||||
bool has_multicast_support() {
|
||||
#if defined(CUDART_SUPPORTS_MULTICAST)
|
||||
return c10::cuda::DriverAPI::get()->cuMulticastCreate_ != nullptr;
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
class IpcChannel {
|
||||
public:
|
||||
IpcChannel() : socket_name_(get_socket_name(getpid())) {
|
||||
|
|
@ -73,7 +61,9 @@ class IpcChannel {
|
|||
memcpy(CMSG_DATA(cmsg), &fd, sizeof(fd));
|
||||
|
||||
TORCH_CHECK(
|
||||
sendmsg(socket_, &msg, 0) > 0, "Failed to send fd: ", strerror(errno));
|
||||
sendmsg(socket_, &msg, 0) > 0,
|
||||
"Failed to send fd: ",
|
||||
strerror(errno));
|
||||
}
|
||||
|
||||
int recv_fd() {
|
||||
|
|
@ -120,25 +110,6 @@ class IpcChannel {
|
|||
return fds;
|
||||
}
|
||||
|
||||
int broadcast_fds(
|
||||
int rank,
|
||||
int src_rank,
|
||||
const std::vector<int>& pids,
|
||||
int fd) {
|
||||
size_t world_size = pids.size();
|
||||
|
||||
if (rank == src_rank) {
|
||||
for (int dst_rank = 0; dst_rank < (int)world_size; ++dst_rank) {
|
||||
if (dst_rank == rank) {
|
||||
continue;
|
||||
}
|
||||
send_fd(pids[dst_rank], fd);
|
||||
}
|
||||
return fd;
|
||||
}
|
||||
return recv_fd();
|
||||
}
|
||||
|
||||
private:
|
||||
static std::string get_socket_name(int pid) {
|
||||
const char* tmp_dir = "/tmp";
|
||||
|
|
@ -242,8 +213,6 @@ CUDASymmetricMemory::CUDASymmetricMemory(
|
|||
size_t block_size,
|
||||
std::vector<void*> buffers,
|
||||
std::vector<void*> signal_pads,
|
||||
HandleType mc_handle,
|
||||
void* mc_addr,
|
||||
size_t buffer_size,
|
||||
int local_device_idx,
|
||||
int rank,
|
||||
|
|
@ -252,8 +221,6 @@ CUDASymmetricMemory::CUDASymmetricMemory(
|
|||
block_size_(block_size),
|
||||
buffers_(std::move(buffers)),
|
||||
signal_pads_(std::move(signal_pads)),
|
||||
mc_handle_(mc_handle),
|
||||
mc_addr_(mc_addr),
|
||||
buffer_size_(buffer_size),
|
||||
local_device_idx_(local_device_idx),
|
||||
rank_(rank),
|
||||
|
|
@ -318,14 +285,6 @@ size_t CUDASymmetricMemory::get_signal_pad_size() {
|
|||
return signal_pad_size;
|
||||
}
|
||||
|
||||
bool CUDASymmetricMemory::has_multicast_support() {
|
||||
return ::has_multicast_support();
|
||||
}
|
||||
|
||||
void* CUDASymmetricMemory::get_multicast_ptr() {
|
||||
return mc_addr_;
|
||||
}
|
||||
|
||||
at::Tensor CUDASymmetricMemory::get_buffer(
|
||||
int rank,
|
||||
c10::IntArrayRef sizes,
|
||||
|
|
@ -642,46 +601,6 @@ c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
|
|||
store_barrier(store, rank, world_size);
|
||||
close(block_fd);
|
||||
|
||||
CUmemGenericAllocationHandle mc_handle{};
|
||||
void* mc_addr = nullptr;
|
||||
#if defined(CUDART_SUPPORTS_MULTICAST)
|
||||
// We have to further check if the driver supports multicast
|
||||
if (has_multicast_support()) {
|
||||
// Rank 0 creates a multicast object and share it with peers
|
||||
if (rank == 0) {
|
||||
CUmulticastObjectProp mc_prop{};
|
||||
mc_prop.numDevices = world_size;
|
||||
mc_prop.handleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
|
||||
mc_prop.size = block->block_size;
|
||||
|
||||
CUresult res = driver_api->cuMulticastCreate_(&mc_handle, &mc_prop);
|
||||
TORCH_CHECK(res == CUDA_SUCCESS);
|
||||
|
||||
int mc_fd;
|
||||
C10_CUDA_DRIVER_CHECK(driver_api->cuMemExportToShareableHandle_(
|
||||
&mc_fd, mc_handle, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0));
|
||||
ipc_channel.broadcast_fds(rank, 0, pids, mc_fd);
|
||||
// Ref count is incremented as soon as SCM_RIGHTS send happens
|
||||
close(mc_fd);
|
||||
} else {
|
||||
int mc_fd = ipc_channel.broadcast_fds(rank, 0, pids, -1);
|
||||
C10_CUDA_DRIVER_CHECK(driver_api->cuMemImportFromShareableHandle_(
|
||||
&mc_handle,
|
||||
(void*)(uintptr_t)mc_fd,
|
||||
CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
|
||||
close(mc_fd);
|
||||
}
|
||||
// All rank adds their physical allocation to the multicast object
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
driver_api->cuMulticastAddDevice_(mc_handle, block->device_idx));
|
||||
C10_CUDA_DRIVER_CHECK(driver_api->cuMulticastBindMem_(
|
||||
mc_handle, 0, block->handle, 0, block->block_size, 0));
|
||||
|
||||
map_block(&mc_addr, mc_handle, block->block_size, block->device_idx);
|
||||
store_barrier(store, rank, world_size);
|
||||
}
|
||||
#endif
|
||||
|
||||
// Initializing CUDASymmetricMemory with an allocation transfers its
|
||||
// ownership to the CUDASymmetricMemory object. So that outstanding
|
||||
// references to the CUDASymmetricMemory object can keep the allocation
|
||||
|
|
@ -691,8 +610,6 @@ c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
|
|||
block->block_size,
|
||||
std::move(buffers),
|
||||
std::move(signal_pads),
|
||||
mc_handle,
|
||||
mc_addr,
|
||||
block->buffer_size,
|
||||
block->device_idx,
|
||||
group_info.rank,
|
||||
|
|
@ -713,10 +630,6 @@ bool CUDASymmetricMemoryAllocator::is_rendezvous_completed(void* ptr) {
|
|||
return block->symm_mem != nullptr;
|
||||
}
|
||||
|
||||
bool CUDASymmetricMemoryAllocator::has_multicast_support() {
|
||||
return ::has_multicast_support();
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<Block> CUDASymmetricMemoryAllocator::find_block(void* ptr) {
|
||||
std::shared_lock lock(mutex_);
|
||||
auto it = ptr_to_block_.find(ptr);
|
||||
|
|
|
|||
|
|
@ -20,8 +20,6 @@ class CUDASymmetricMemory : public SymmetricMemory {
|
|||
size_t block_size,
|
||||
std::vector<void*> buffers,
|
||||
std::vector<void*> signal_pads,
|
||||
HandleType mc_handle,
|
||||
void* mc_addr,
|
||||
size_t buffer_size,
|
||||
int local_device_idx,
|
||||
int rank,
|
||||
|
|
@ -36,9 +34,6 @@ class CUDASymmetricMemory : public SymmetricMemory {
|
|||
size_t get_buffer_size() override;
|
||||
size_t get_signal_pad_size() override;
|
||||
|
||||
bool has_multicast_support() override;
|
||||
void* get_multicast_ptr() override;
|
||||
|
||||
at::Tensor get_buffer(
|
||||
int rank,
|
||||
c10::IntArrayRef sizes,
|
||||
|
|
@ -57,8 +52,6 @@ class CUDASymmetricMemory : public SymmetricMemory {
|
|||
size_t block_size_;
|
||||
std::vector<void*> buffers_;
|
||||
std::vector<void*> signal_pads_;
|
||||
HandleType mc_handle_;
|
||||
void* mc_addr_;
|
||||
size_t buffer_size_;
|
||||
int local_device_idx_;
|
||||
int rank_;
|
||||
|
|
@ -102,7 +95,6 @@ class CUDASymmetricMemoryAllocator : public SymmetricMemoryAllocator {
|
|||
size_t get_alloc_size(void* ptr) override;
|
||||
c10::intrusive_ptr<SymmetricMemory> rendezvous(void* ptr) override;
|
||||
bool is_rendezvous_completed(void* ptr) override;
|
||||
bool has_multicast_support() override;
|
||||
|
||||
private:
|
||||
c10::intrusive_ptr<Block> find_block(void* ptr);
|
||||
|
|
|
|||
|
|
@ -1,267 +0,0 @@
|
|||
#if defined(CUDART_VERSION) && CUDART_VERSION >= 12030
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/ceil_div.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#else
|
||||
#include <ATen/ops/empty_like.h>
|
||||
#endif
|
||||
|
||||
#include <torch/library.h>
|
||||
|
||||
#include <torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h>
|
||||
#include <torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp>
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace c10d::symmetric_memory;
|
||||
|
||||
size_t get_and_verify_alignment(const at::Tensor& input, const char* op_name) {
|
||||
const size_t min_alignment = std::max(4l, input.element_size());
|
||||
// Only check the offset since the multicast address is always at least
|
||||
// 128-bit aligned
|
||||
const size_t ptr_alignment = get_alignment(
|
||||
static_cast<size_t>(input.storage_offset() * input.element_size()));
|
||||
TORCH_CHECK(
|
||||
ptr_alignment >= min_alignment,
|
||||
op_name,
|
||||
"<",
|
||||
input.scalar_type(),
|
||||
">: input ptr + offset must be at least ",
|
||||
min_alignment,
|
||||
"-byte aligned.");
|
||||
|
||||
const size_t size_alignment =
|
||||
get_alignment(static_cast<size_t>(input.numel() * input.element_size()));
|
||||
TORCH_CHECK(
|
||||
size_alignment >= min_alignment,
|
||||
op_name,
|
||||
"<",
|
||||
input.scalar_type(),
|
||||
">: input size must be at least ",
|
||||
min_alignment,
|
||||
"-byte aligned.");
|
||||
return std::min(ptr_alignment, size_alignment);
|
||||
}
|
||||
|
||||
void init_elementwise_launch_config(
|
||||
size_t numel,
|
||||
size_t element_size,
|
||||
size_t alignment,
|
||||
size_t splits,
|
||||
int& num_blocks,
|
||||
int& num_threads) {
|
||||
// Align to preserve alignment in each split
|
||||
const size_t aligned_numel = at::round_up(numel, alignment * splits);
|
||||
const size_t numel_per_split = aligned_numel / splits;
|
||||
const size_t numel_per_thread = alignment / element_size;
|
||||
|
||||
if (numel_per_split <= max_num_threads_per_block * numel_per_thread) {
|
||||
num_blocks = 1;
|
||||
num_threads = at::round_up(
|
||||
at::ceil_div(numel_per_split, numel_per_thread),
|
||||
static_cast<size_t>(C10_WARP_SIZE));
|
||||
} else {
|
||||
num_blocks = std::min(
|
||||
at::ceil_div(
|
||||
numel_per_split, max_num_threads_per_block * numel_per_thread),
|
||||
max_num_blocks);
|
||||
num_threads = max_num_threads_per_block;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int alignment>
|
||||
static __global__ void multimem_all_reduce_kernel(
|
||||
T* input_mc_ptr,
|
||||
size_t numel,
|
||||
uint32_t** signal_pads,
|
||||
size_t rank,
|
||||
size_t world_size) {
|
||||
static_assert(alignment % sizeof(T) == 0);
|
||||
constexpr size_t numel_per_thread = alignment / sizeof(T);
|
||||
|
||||
barrier_and_acquire_previous_kernel_writes(signal_pads, rank, world_size);
|
||||
|
||||
const size_t numel_per_rank =
|
||||
at::round_up(numel, alignment * world_size) / world_size;
|
||||
const size_t start = numel_per_rank * rank;
|
||||
|
||||
auto offset = (blockDim.x * blockIdx.x + threadIdx.x) * numel_per_thread;
|
||||
auto stride = blockDim.x * gridDim.x * numel_per_thread;
|
||||
for (size_t i = offset; i < numel_per_rank; i += stride) {
|
||||
if (start + i >= numel) {
|
||||
continue;
|
||||
}
|
||||
auto vec = multimem_ld_reduce_add<alignment>(input_mc_ptr + start + i);
|
||||
multimem_st<alignment>(input_mc_ptr + start + i, vec);
|
||||
}
|
||||
// Establish observation order - all writes are in-flight beyond this point.
|
||||
barrier(signal_pads, rank, world_size);
|
||||
// Establish causality order - all writes are visible to all devices beyond
|
||||
// this point.
|
||||
__threadfence_system();
|
||||
}
|
||||
|
||||
at::Tensor multimem_all_reduce_(
|
||||
const at::Tensor& input,
|
||||
std::string reduce_op,
|
||||
std::string group_name) {
|
||||
TORCH_CHECK(
|
||||
input.is_contiguous(), "multimem_all_reduce_: input must be contiguous.");
|
||||
TORCH_CHECK(
|
||||
reduce_op == "sum",
|
||||
"multimem_all_reduce_: only sum is supported for now.");
|
||||
|
||||
auto symm_mem = c10d::symmetric_memory::rendezvous(input);
|
||||
TORCH_CHECK(
|
||||
symm_mem != nullptr,
|
||||
"multimem_all_reduce_: input must be allocated with empty_strided_p2p().");
|
||||
TORCH_CHECK(
|
||||
symm_mem->has_multicast_support(),
|
||||
"multimem_all_reduce_: multicast support is required.");
|
||||
|
||||
const size_t alignment =
|
||||
get_and_verify_alignment(input, "multimem_all_reduce_");
|
||||
|
||||
int num_blocks = 0, num_threads = 0;
|
||||
init_elementwise_launch_config(
|
||||
input.numel(),
|
||||
input.element_size(),
|
||||
alignment,
|
||||
symm_mem->get_world_size(),
|
||||
num_blocks,
|
||||
num_threads);
|
||||
|
||||
#define DISPATCH(scalar_t, kernel_alignment) \
|
||||
if (alignment == kernel_alignment) { \
|
||||
multimem_all_reduce_kernel<scalar_t, kernel_alignment> \
|
||||
<<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>( \
|
||||
reinterpret_cast<scalar_t*>(symm_mem->get_multicast_ptr()) + \
|
||||
input.storage_offset(), \
|
||||
input.numel(), \
|
||||
reinterpret_cast<uint32_t**>(symm_mem->get_signal_pad_ptrs_dev()), \
|
||||
symm_mem->get_rank(), \
|
||||
symm_mem->get_world_size()); \
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK(); \
|
||||
}
|
||||
|
||||
AT_DISPATCH_SWITCH(
|
||||
input.scalar_type(),
|
||||
"multimem_all_reduce",
|
||||
AT_DISPATCH_CASE(at::kBFloat16, [&] {
|
||||
DISPATCH(scalar_t, 16);
|
||||
DISPATCH(scalar_t, 8);
|
||||
DISPATCH(scalar_t, 4);
|
||||
}) AT_DISPATCH_CASE(at::kFloat, [&] {
|
||||
DISPATCH(scalar_t, 16);
|
||||
DISPATCH(scalar_t, 8);
|
||||
DISPATCH(scalar_t, 4);
|
||||
}));
|
||||
|
||||
#undef DISPATCH
|
||||
return input;
|
||||
}
|
||||
|
||||
template <typename T, int alignment>
|
||||
static __global__ void multimem_one_shot_all_reduce_kernel(
|
||||
T* input_mc_ptr,
|
||||
T* output_ptr,
|
||||
size_t numel,
|
||||
uint32_t** signal_pads,
|
||||
size_t rank,
|
||||
size_t world_size) {
|
||||
static_assert(alignment % sizeof(T) == 0);
|
||||
constexpr size_t numel_per_thread = alignment / sizeof(T);
|
||||
|
||||
barrier_and_acquire_previous_kernel_writes(signal_pads, rank, world_size);
|
||||
|
||||
auto offset = (blockDim.x * blockIdx.x + threadIdx.x) * numel_per_thread;
|
||||
auto stride = blockDim.x * gridDim.x * numel_per_thread;
|
||||
for (size_t i = offset; i < numel; i += stride) {
|
||||
auto vec = multimem_ld_reduce_add<alignment>(input_mc_ptr + i);
|
||||
*reinterpret_cast<decltype(vec.as_scalar)*>(output_ptr + i) = vec.as_scalar;
|
||||
}
|
||||
}
|
||||
|
||||
at::Tensor multimem_one_shot_all_reduce(
|
||||
const at::Tensor& input,
|
||||
std::string reduce_op,
|
||||
std::string group_name) {
|
||||
TORCH_CHECK(
|
||||
input.is_contiguous(),
|
||||
"multimem_one_shot_all_reduce: input must be contiguous.");
|
||||
TORCH_CHECK(
|
||||
reduce_op == "sum",
|
||||
"multimem_one_shot_all_reduce: only sum is supported for now.");
|
||||
|
||||
auto symm_mem = c10d::symmetric_memory::rendezvous(input);
|
||||
TORCH_CHECK(
|
||||
symm_mem != nullptr,
|
||||
"multimem_one_shot_all_reduce: input must be allocated with empty_strided_p2p().");
|
||||
TORCH_CHECK(
|
||||
symm_mem->has_multicast_support(),
|
||||
"multimem_one_shot_all_reduce: requires multicast support.");
|
||||
|
||||
auto output = at::empty_like(input);
|
||||
|
||||
const size_t alignment =
|
||||
get_and_verify_alignment(input, "multimem_one_shot_all_reduce");
|
||||
|
||||
int num_blocks = 0, num_threads = 0;
|
||||
init_elementwise_launch_config(
|
||||
input.numel(),
|
||||
input.element_size(),
|
||||
alignment,
|
||||
1,
|
||||
num_blocks,
|
||||
num_threads);
|
||||
|
||||
#define DISPATCH(scalar_t, kernel_alignment) \
|
||||
if (alignment == kernel_alignment) { \
|
||||
multimem_one_shot_all_reduce_kernel<scalar_t, kernel_alignment> \
|
||||
<<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>( \
|
||||
reinterpret_cast<scalar_t*>(symm_mem->get_multicast_ptr()) + \
|
||||
input.storage_offset(), \
|
||||
output.data_ptr<scalar_t>(), \
|
||||
input.numel(), \
|
||||
reinterpret_cast<uint32_t**>(symm_mem->get_signal_pad_ptrs_dev()), \
|
||||
symm_mem->get_rank(), \
|
||||
symm_mem->get_world_size()); \
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK(); \
|
||||
}
|
||||
|
||||
AT_DISPATCH_SWITCH(
|
||||
input.scalar_type(),
|
||||
"multimem_all_reduce",
|
||||
AT_DISPATCH_CASE(at::kBFloat16, [&] {
|
||||
DISPATCH(scalar_t, 16);
|
||||
DISPATCH(scalar_t, 8);
|
||||
DISPATCH(scalar_t, 4);
|
||||
}) AT_DISPATCH_CASE(at::kFloat, [&] {
|
||||
DISPATCH(scalar_t, 16);
|
||||
DISPATCH(scalar_t, 8);
|
||||
DISPATCH(scalar_t, 4);
|
||||
}));
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
|
||||
m.def(
|
||||
"multimem_all_reduce_(Tensor input, str reduce_op, str group_name) -> Tensor",
|
||||
torch::dispatch(c10::DispatchKey::CUDA, ::multimem_all_reduce_),
|
||||
{at::Tag::pt2_compliant_tag});
|
||||
|
||||
m.def(
|
||||
"multimem_one_shot_all_reduce(Tensor input, str reduce_op, str group_name) -> Tensor",
|
||||
torch::dispatch(c10::DispatchKey::CUDA, ::multimem_one_shot_all_reduce),
|
||||
{at::Tag::pt2_compliant_tag});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
#endif
|
||||
|
|
@ -176,7 +176,7 @@ at::Tensor empty_strided_p2p(
|
|||
TORCH_API c10::intrusive_ptr<SymmetricMemory> rendezvous(
|
||||
const at::Tensor& tensor) {
|
||||
auto allocator = get_allocator(tensor.device().type());
|
||||
return allocator->rendezvous(tensor.storage().data_ptr().get());
|
||||
return allocator->rendezvous(tensor.data_ptr());
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<SymmetricMemory> get_symmetric_memory(
|
||||
|
|
@ -189,9 +189,5 @@ c10::intrusive_ptr<SymmetricMemory> get_symmetric_memory(
|
|||
return allocator->rendezvous(tensor.data_ptr());
|
||||
}
|
||||
|
||||
TORCH_API bool has_multicast_support(c10::DeviceType device_type) {
|
||||
auto allocator = get_allocator(device_type);
|
||||
return allocator->has_multicast_support();
|
||||
}
|
||||
} // namespace symmetric_memory
|
||||
} // namespace c10d
|
||||
|
|
|
|||
|
|
@ -51,9 +51,6 @@ class TORCH_API SymmetricMemory : public c10::intrusive_ptr_target {
|
|||
virtual size_t get_buffer_size() = 0;
|
||||
virtual size_t get_signal_pad_size() = 0;
|
||||
|
||||
virtual bool has_multicast_support() = 0;
|
||||
virtual void* get_multicast_ptr() = 0;
|
||||
|
||||
virtual at::Tensor get_buffer(
|
||||
int rank,
|
||||
c10::IntArrayRef sizes,
|
||||
|
|
@ -81,7 +78,6 @@ class SymmetricMemoryAllocator : public c10::intrusive_ptr_target {
|
|||
virtual size_t get_alloc_size(void* ptr) = 0;
|
||||
virtual c10::intrusive_ptr<SymmetricMemory> rendezvous(void* ptr) = 0;
|
||||
virtual bool is_rendezvous_completed(void* ptr) = 0;
|
||||
virtual bool has_multicast_support() = 0;
|
||||
};
|
||||
|
||||
C10_EXPORT bool is_finalizing();
|
||||
|
|
@ -154,6 +150,5 @@ TORCH_API c10::intrusive_ptr<SymmetricMemory> rendezvous(
|
|||
TORCH_API c10::intrusive_ptr<SymmetricMemory> get_symmetric_memory(
|
||||
const at::Tensor& tensor);
|
||||
|
||||
TORCH_API bool has_multicast_support(c10::DeviceType device_type);
|
||||
} // namespace symmetric_memory
|
||||
} // namespace c10d
|
||||
|
|
|
|||
|
|
@ -1044,9 +1044,6 @@ This class does not support ``__members__`` property.)");
|
|||
.def_static(
|
||||
"get_symmetric_memory",
|
||||
&::c10d::symmetric_memory::get_symmetric_memory)
|
||||
.def_static(
|
||||
"has_multicast_support",
|
||||
&::c10d::symmetric_memory::has_multicast_support)
|
||||
.def_property_readonly("rank", &SymmetricMemory::get_rank)
|
||||
.def_property_readonly("world_size", &SymmetricMemory::get_world_size)
|
||||
.def_property_readonly(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user