Revert "[SymmetricMemory] introduce multicast support, multimem_all_reduce_ and multimem_one_shot_all_reduce (#133424)"
This reverts commit 66d3eb783c.
Reverted https://github.com/pytorch/pytorch/pull/133424 on behalf of https://github.com/jeanschmidt due to Broke internal ADS builds, see D61611517 ([comment](https://github.com/pytorch/pytorch/pull/133424#issuecomment-2304676328))
parent 592a172910
commit cedfac20c7
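For context, the reverted PR exposed two multicast-backed all-reduce variants under torch.ops.symm_mem. The sketch below is adapted from the tests removed in this commit and only illustrates the reverted API surface; it assumes a CUDA build with driver multicast support and an already-initialized process group, and these ops are no longer available after this revert:

    import torch
    import torch.distributed as dist
    from torch._C._distributed_c10d import _SymmetricMemory

    # Assumes dist.init_process_group(...) has run and each rank has picked
    # its CUDA device.
    group_name = dist.group.WORLD.group_name

    # Allocate a symmetric (peer-mapped) buffer and fill it with ones.
    t = _SymmetricMemory.empty_strided_p2p(
        size=(16384,),
        stride=(1,),
        dtype=torch.bfloat16,
        device=torch.device("cuda"),
        group_name=group_name,
    ).fill_(1)

    # In-place multicast all-reduce over the symmetric buffer (reverted op).
    torch.ops.symm_mem.multimem_all_reduce_(t, "sum", group_name)

    # One-shot variant that returns a new tensor (also reverted).
    out = torch.ops.symm_mem.multimem_one_shot_all_reduce(t, "sum", group_name)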
@@ -744,7 +744,6 @@ cc_library(
         "torch/csrc/cuda/nccl.cpp",
         "torch/csrc/distributed/c10d/intra_node_comm.cu",
         "torch/csrc/distributed/c10d/CUDASymmetricMemory.cu",
-        "torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu",
         "torch/csrc/distributed/c10d/Utils.cu",
         "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
     ],
@@ -688,7 +688,6 @@ libtorch_cuda_distributed_extra_sources = [
     "torch/csrc/distributed/c10d/intra_node_comm.cpp",
     "torch/csrc/distributed/c10d/intra_node_comm.cu",
     "torch/csrc/distributed/c10d/CUDASymmetricMemory.cu",
-    "torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu",
     "torch/csrc/distributed/c10d/Utils.cu",
     "torch/csrc/distributed/rpc/tensorpipe_cuda.cpp",
     "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
@@ -20,12 +20,6 @@ DriverAPI create_driver_api() {
   C10_LIBCUDA_DRIVER_API(LOOKUP_LIBCUDA_ENTRY)
 #undef LOOKUP_LIBCUDA_ENTRY
 
-#define LOOKUP_LIBCUDA_ENTRY(name) \
-  r.name##_ = ((decltype(&name))dlsym(handle_0, #name)); \
-  dlerror();
-  C10_LIBCUDA_DRIVER_API_12030(LOOKUP_LIBCUDA_ENTRY)
-#undef LOOKUP_LIBCUDA_ENTRY
-
   if (handle_1) {
 #define LOOKUP_NVML_ENTRY(name) \
   r.name##_ = ((decltype(&name))dlsym(handle_1, #name)); \
@@ -31,15 +31,6 @@
   _(cuMemImportFromShareableHandle) \
   _(cuGetErrorString)
 
-#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12030)
-#define C10_LIBCUDA_DRIVER_API_12030(_) \
-  _(cuMulticastAddDevice) \
-  _(cuMulticastBindMem) \
-  _(cuMulticastCreate)
-#else
-#define C10_LIBCUDA_DRIVER_API_12030(_)
-#endif
-
 #define C10_NVML_DRIVER_API(_) \
   _(nvmlInit_v2) \
   _(nvmlDeviceGetHandleByPciBusId_v2) \
@@ -52,7 +43,6 @@ namespace c10::cuda {
 struct DriverAPI {
 #define CREATE_MEMBER(name) decltype(&name) name##_;
   C10_LIBCUDA_DRIVER_API(CREATE_MEMBER)
-  C10_LIBCUDA_DRIVER_API_12030(CREATE_MEMBER)
   C10_NVML_DRIVER_API(CREATE_MEMBER)
 #undef CREATE_MEMBER
   static DriverAPI* get();
@@ -2,7 +2,6 @@
 
 import torch
 import torch.distributed as dist
-from torch._C._autograd import DeviceType
 from torch._C._distributed_c10d import _SymmetricMemory
 from torch.distributed._symmetric_memory import (
     _fused_all_gather_matmul_fallback,
@@ -45,17 +44,6 @@ def requires_cuda_p2p_access():
     )
 
 
-def requires_multicast_support():
-    has_multicast_support = (
-        torch.cuda.is_available()
-        and _SymmetricMemory.has_multicast_support(DeviceType.CUDA)
-    )
-    return skip_but_pass_in_sandcastle_if(
-        not has_multicast_support,
-        "multicast support is not available",
-    )
-
-
 @instantiate_parametrized_tests
 @requires_cuda_p2p_access()
 class SymmetricMemoryTest(MultiProcessTestCase):
@@ -107,6 +95,7 @@ class SymmetricMemoryTest(MultiProcessTestCase):
     @skipIfRocm
     @skip_if_lt_x_gpu(2)
     def test_cuda_nvlink_connectivity_detection(self) -> None:
+        from torch._C._autograd import DeviceType
         from torch._C._distributed_c10d import _detect_dma_connectivity
 
         connectivity = _detect_dma_connectivity(DeviceType.CUDA, "nvlink")
@@ -433,73 +422,6 @@ class SymmetricMemoryTest(MultiProcessTestCase):
 
         dist.destroy_process_group()
 
-    @skip_if_lt_x_gpu(2)
-    @requires_multicast_support()
-    @parametrize("dtype", [torch.float, torch.bfloat16])
-    @parametrize("align_bytes", [4, 8, 16])
-    @parametrize("size_bytes", [4, 8192, 8196])
-    def test_multimem_all_reduce(
-        self, dtype: torch.dtype, size_bytes: int, align_bytes: int
-    ) -> None:
-        self._init_process()
-        group_name = dist.group.WORLD.group_name
-
-        t = _SymmetricMemory.empty_strided_p2p(
-            size=(16384,),
-            stride=(1,),
-            dtype=dtype,
-            device=self.device,
-            group_name=group_name,
-        ).fill_(1)
-
-        self.assertTrue(t.data_ptr() % 16 == 0)
-        self.assertTrue(align_bytes % t.element_size() == 0)
-        self.assertTrue(size_bytes % t.element_size() == 0)
-
-        shift = align_bytes // t.element_size()
-        numel = size_bytes // t.element_size()
-        x = t[shift : shift + numel]
-
-        torch.ops.symm_mem.multimem_all_reduce_(x, "sum", group_name)
-        self.assertTrue(x.eq(self.world_size).all().item())
-
-        # Head and tail should not be written
-        self.assertTrue(t[:shift].eq(1).all().item())
-        self.assertTrue(t[shift + numel :].eq(1).all().item())
-        dist.destroy_process_group()
-
-    @skip_if_lt_x_gpu(2)
-    @requires_multicast_support()
-    @parametrize("dtype", [torch.float, torch.bfloat16])
-    @parametrize("align_bytes", [4, 8, 16])
-    @parametrize("size_bytes", [4, 8192, 8196])
-    def test_multimem_one_shot_all_reduce(
-        self, dtype: torch.dtype, size_bytes: int, align_bytes: int
-    ) -> None:
-        self._init_process()
-        group_name = dist.group.WORLD.group_name
-
-        t = _SymmetricMemory.empty_strided_p2p(
-            size=(16384,),
-            stride=(1,),
-            dtype=dtype,
-            device=self.device,
-            group_name=group_name,
-        ).fill_(0)
-
-        self.assertTrue(t.data_ptr() % 16 == 0)
-        self.assertTrue(align_bytes % t.element_size() == 0)
-        self.assertTrue(size_bytes % t.element_size() == 0)
-
-        shift = align_bytes // t.element_size()
-        numel = size_bytes // t.element_size()
-        x = t[shift : shift + numel]
-        x.fill_(1)
-
-        res = torch.ops.symm_mem.multimem_one_shot_all_reduce(x, "sum", group_name)
-        self.assertTrue(res.eq(self.world_size).all().item())
-        dist.destroy_process_group()
-
 
 if __name__ == "__main__":
     run_tests()
@@ -1,256 +0,0 @@
-#pragma once
-
-#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && CUDART_VERSION >= 12010
-#define NVCC_SUPPORTS_MULTICAST 1
-#endif
-
-#include <ATen/ATen.h>
-
-namespace c10d::symmetric_memory {
-
-constexpr size_t max_num_threads_per_block = 1024;
-constexpr size_t max_num_blocks = 8;
-
-template <typename T>
-size_t get_alignment(T ptr_or_size) {
-  auto val = reinterpret_cast<uintptr_t>(ptr_or_size);
-  if (val % 16 == 0) {
-    return 16;
-  } else if (val % 8 == 0) {
-    return 8;
-  } else if (val % 4 == 0) {
-    return 4;
-  } else if (val % 2 == 0) {
-    return 2;
-  } else {
-    return 1;
-  }
-}
-
-template <>
-size_t get_alignment<size_t>(size_t size) {
-  return get_alignment(reinterpret_cast<void*>(size));
-}
-
-__device__ __forceinline__ uint32_t
-cas_sys(uint32_t* addr, uint32_t compare, uint32_t val) {
-#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
-  CUDA_KERNEL_ASSERT(false);
-#else
-  uint32_t old_val;
-  asm volatile("atom.global.sys.cas.b32 %0, [%1], %2, %3;"
-               : "=r"(old_val)
-               : "l"(addr), "r"(compare), "r"(val)
-               : "memory");
-  return old_val;
-#endif
-}
-
-__device__ __forceinline__ uint32_t
-cas_release_sys(uint32_t* addr, uint32_t compare, uint32_t val) {
-#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
-  CUDA_KERNEL_ASSERT(false);
-#else
-  uint32_t old_val;
-  asm volatile("atom.global.release.sys.cas.b32 %0, [%1], %2, %3;"
-               : "=r"(old_val)
-               : "l"(addr), "r"(compare), "r"(val)
-               : "memory");
-  return old_val;
-#endif
-}
-
-__device__ __forceinline__ void release_signal(uint32_t* addr) {
-  while (cas_release_sys(addr, 0, 1) != 0)
-    ;
-}
-
-__device__ __forceinline__ void wait_signal(uint32_t* addr) {
-  while (cas_sys(addr, 1, 0) != 1)
-    ;
-}
-
-__device__ __forceinline__ uint32_t acquire_signal(uint32_t* addr) {
-#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
-  CUDA_KERNEL_ASSERT(false);
-#else
-  uint32_t val;
-  asm volatile("ld.acquire.sys.global.u32 %0, [%1];"
-               : "=r"(val)
-               : "l"(addr)
-               : "memory");
-  return val;
-#endif
-}
-
-// Perform a barrier to establish observation order between memory operations
-// issued before and after the barrier.
-__device__ __forceinline__ void barrier(
-    uint32_t** signal_pads,
-    size_t rank,
-    size_t world_size) {
-  if (threadIdx.x < world_size) {
-    auto target_rank = threadIdx.x;
-    release_signal(signal_pads[target_rank] + blockIdx.x * world_size + rank);
-    wait_signal(signal_pads[rank] + blockIdx.x * world_size + target_rank);
-  }
-  __syncthreads();
-}
-
-// Perform a barrier and establish causality order between memory operations
-// issued before the calling kernel on all devices and memory operations
-// issued after this function by all thread in the calling kernel.
-//
-// NOTE: this function does NOT ensure that memory operations issues in the
-// current kernel are visible to all threads in the current kernel.
-//
-// | mem ops (guaranteed to be visible by all threads at point T)
-// | kernel K
-// | +- mem ops (not guaranteed to be visible all threads at point T)
-// | +- barrier_and_acquire_previous_kernel_writes()
-// | +- point T
-// v
-__device__ __forceinline__ void barrier_and_acquire_previous_kernel_writes(
-    uint32_t** signal_pads,
-    size_t rank,
-    size_t world_size) {
-  if (threadIdx.x < world_size) {
-    auto target_rank = threadIdx.x;
-    release_signal(signal_pads[target_rank] + blockIdx.x * world_size + rank);
-    wait_signal(signal_pads[rank] + blockIdx.x * world_size + target_rank);
-  }
-  __syncthreads();
-  // At this point, we established observation order between memory operations
-  // issued before and after the barrier. Now we convert the observation order
-  // into causality order by having every thread acquire the signals released
-  // by threads on peer devices. Due to the implicit synchronizes-with
-  // relationships at task/kernel boundaries, acquiring the signal released by
-  // thread T in kernel K transitively acquires memory operations issued prior
-  // to kernel K.
-  //
-  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#memory-fence-interference
-  for (size_t target_rank = 0; target_rank < world_size; ++target_rank) {
-    acquire_signal(signal_pads[rank] + blockIdx.x * world_size + target_rank);
-  }
-}
-
-template <bool Value, class... Args>
-inline constexpr bool dependent_bool_value = Value;
-
-template <class... Args>
-inline constexpr bool dependent_false = dependent_bool_value<false, Args...>;
-
-template <int Size>
-union Vec;
-
-template <>
-union Vec<4> {
-  uint16_t u16[2];
-  uint32_t u32, as_scalar;
-};
-
-template <>
-union Vec<8> {
-  uint16_t u16[4];
-  uint32_t u32[2];
-  uint64_t u64, as_scalar;
-};
-
-template <>
-union alignas(16) Vec<16> {
-  uint16_t u16[8];
-  uint32_t u32[4];
-  uint64_t u64[2];
-  uint4 u128, as_scalar;
-};
-
-template <typename T>
-struct MultimemLdReduce {
-  template <int Alignment>
-  __device__ __inline__ Vec<Alignment> operator()(T* mc_ptr) {
-    static_assert(dependent_false<T>);
-  }
-};
-
-template <int Alignment, typename T>
-__device__ __inline__ Vec<Alignment> multimem_ld_reduce_add(T* mc_ptr) {
-  MultimemLdReduce<T> functor;
-  return functor.template operator()<Alignment>(mc_ptr);
-}
-
-#if defined(USE_ROCM) || !defined(NVCC_SUPPORTS_MULTICAST)
-#define SPECIALIZE_MULTIMEM_LD_REDUCE_VEC_32(type, asm_type) \
-  template <> \
-  struct MultimemLdReduce<type> { \
-    template <int Alignment> \
-    __device__ __inline__ Vec<Alignment> operator()(type* mc_ptr) { \
-      CUDA_KERNEL_ASSERT(false); \
-    } \
-  };
-#else
-#define SPECIALIZE_MULTIMEM_LD_REDUCE_VEC_32(type, asm_type) \
-  template <> \
-  struct MultimemLdReduce<type> { \
-    template <int Alignment> \
-    __device__ __inline__ Vec<Alignment> operator()(type* mc_ptr) { \
-      Vec<Alignment> vec; \
-      if constexpr (Alignment == 16) { \
-        asm("multimem.ld_reduce.relaxed.sys.global.add.v4." asm_type \
-            " {%0,%1,%2,%3}, [%4];" \
-            : "=r"(vec.u32[0]), \
-              "=r"(vec.u32[1]), \
-              "=r"(vec.u32[2]), \
-              "=r"(vec.u32[3]) \
-            : "l"(mc_ptr) \
-            : "memory"); \
-      } else if constexpr (Alignment == 8) { \
-        asm("multimem.ld_reduce.relaxed.sys.global.add.v2." asm_type \
-            " {%0,%1}, [%2];" \
-            : "=r"(vec.u32[0]), "=r"(vec.u32[1]) \
-            : "l"(mc_ptr) \
-            : "memory"); \
-      } else if constexpr (Alignment == 4) { \
-        asm("multimem.ld_reduce.relaxed.sys.global.add." asm_type " %0, [%1];" \
-            : "=r"(vec.u32) \
-            : "l"(mc_ptr) \
-            : "memory"); \
-      } \
-      return vec; \
-    } \
-  };
-#endif
-
-SPECIALIZE_MULTIMEM_LD_REDUCE_VEC_32(at::BFloat16, "bf16x2");
-SPECIALIZE_MULTIMEM_LD_REDUCE_VEC_32(float, "f32");
-
-template <int Alignment, typename T>
-__device__ __inline__ void multimem_st(T* mc_ptr, Vec<Alignment>& vec) {
-#if defined(USE_ROCM) || !defined(NVCC_SUPPORTS_MULTICAST)
-  CUDA_KERNEL_ASSERT(false);
-#else
-  if constexpr (Alignment == 16) {
-    asm("multimem.st.relaxed.sys.global.v4.f32 [%0], {%1,%2,%3,%4};"
-        :
-        : "l"(mc_ptr),
-          "r"(vec.u32[0]),
-          "r"(vec.u32[1]),
-          "r"(vec.u32[2]),
-          "r"(vec.u32[3])
-        : "memory");
-  } else if constexpr (Alignment == 8) {
-    asm("multimem.st.relaxed.sys.global.v2.f32 [%0], {%1,%2};"
-        :
-        : "l"(mc_ptr), "r"(vec.u32[0]), "r"(vec.u32[1])
-        : "memory");
-  } else if constexpr (Alignment == 4) {
-    asm("multimem.st.relaxed.sys.global.f32 [%0], %1;"
-        :
-        : "l"(mc_ptr), "r"(vec.u32)
-        : "memory");
-  } else {
-    static_assert(dependent_false<T>);
-  }
-#endif
-}
-
-} // namespace c10d::symmetric_memory
@@ -14,20 +14,8 @@
 #include <sys/un.h>
 #include <unistd.h>
 
-#if defined(CUDART_VERSION) && CUDART_VERSION >= 12030
-#define CUDART_SUPPORTS_MULTICAST
-#endif
-
 namespace {
 
-bool has_multicast_support() {
-#if defined(CUDART_SUPPORTS_MULTICAST)
-  return c10::cuda::DriverAPI::get()->cuMulticastCreate_ != nullptr;
-#else
-  return false;
-#endif
-}
-
 class IpcChannel {
  public:
   IpcChannel() : socket_name_(get_socket_name(getpid())) {
@@ -73,7 +61,9 @@ class IpcChannel {
     memcpy(CMSG_DATA(cmsg), &fd, sizeof(fd));
 
     TORCH_CHECK(
-        sendmsg(socket_, &msg, 0) > 0, "Failed to send fd: ", strerror(errno));
+        sendmsg(socket_, &msg, 0) > 0,
+        "Failed to send fd: ",
+        strerror(errno));
   }
 
   int recv_fd() {
@@ -120,25 +110,6 @@ class IpcChannel {
     return fds;
   }
 
-  int broadcast_fds(
-      int rank,
-      int src_rank,
-      const std::vector<int>& pids,
-      int fd) {
-    size_t world_size = pids.size();
-
-    if (rank == src_rank) {
-      for (int dst_rank = 0; dst_rank < (int)world_size; ++dst_rank) {
-        if (dst_rank == rank) {
-          continue;
-        }
-        send_fd(pids[dst_rank], fd);
-      }
-      return fd;
-    }
-    return recv_fd();
-  }
-
  private:
   static std::string get_socket_name(int pid) {
     const char* tmp_dir = "/tmp";
@@ -242,8 +213,6 @@ CUDASymmetricMemory::CUDASymmetricMemory(
     size_t block_size,
     std::vector<void*> buffers,
     std::vector<void*> signal_pads,
-    HandleType mc_handle,
-    void* mc_addr,
     size_t buffer_size,
     int local_device_idx,
     int rank,
@@ -252,8 +221,6 @@ CUDASymmetricMemory::CUDASymmetricMemory(
       block_size_(block_size),
       buffers_(std::move(buffers)),
       signal_pads_(std::move(signal_pads)),
-      mc_handle_(mc_handle),
-      mc_addr_(mc_addr),
       buffer_size_(buffer_size),
       local_device_idx_(local_device_idx),
       rank_(rank),
@@ -318,14 +285,6 @@ size_t CUDASymmetricMemory::get_signal_pad_size() {
   return signal_pad_size;
 }
 
-bool CUDASymmetricMemory::has_multicast_support() {
-  return ::has_multicast_support();
-}
-
-void* CUDASymmetricMemory::get_multicast_ptr() {
-  return mc_addr_;
-}
-
 at::Tensor CUDASymmetricMemory::get_buffer(
     int rank,
     c10::IntArrayRef sizes,
@@ -642,46 +601,6 @@ c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
   store_barrier(store, rank, world_size);
   close(block_fd);
 
-  CUmemGenericAllocationHandle mc_handle{};
-  void* mc_addr = nullptr;
-#if defined(CUDART_SUPPORTS_MULTICAST)
-  // We have to further check if the driver supports multicast
-  if (has_multicast_support()) {
-    // Rank 0 creates a multicast object and share it with peers
-    if (rank == 0) {
-      CUmulticastObjectProp mc_prop{};
-      mc_prop.numDevices = world_size;
-      mc_prop.handleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
-      mc_prop.size = block->block_size;
-
-      CUresult res = driver_api->cuMulticastCreate_(&mc_handle, &mc_prop);
-      TORCH_CHECK(res == CUDA_SUCCESS);
-
-      int mc_fd;
-      C10_CUDA_DRIVER_CHECK(driver_api->cuMemExportToShareableHandle_(
-          &mc_fd, mc_handle, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0));
-      ipc_channel.broadcast_fds(rank, 0, pids, mc_fd);
-      // Ref count is incremented as soon as SCM_RIGHTS send happens
-      close(mc_fd);
-    } else {
-      int mc_fd = ipc_channel.broadcast_fds(rank, 0, pids, -1);
-      C10_CUDA_DRIVER_CHECK(driver_api->cuMemImportFromShareableHandle_(
-          &mc_handle,
-          (void*)(uintptr_t)mc_fd,
-          CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
-      close(mc_fd);
-    }
-    // All rank adds their physical allocation to the multicast object
-    C10_CUDA_DRIVER_CHECK(
-        driver_api->cuMulticastAddDevice_(mc_handle, block->device_idx));
-    C10_CUDA_DRIVER_CHECK(driver_api->cuMulticastBindMem_(
-        mc_handle, 0, block->handle, 0, block->block_size, 0));
-
-    map_block(&mc_addr, mc_handle, block->block_size, block->device_idx);
-    store_barrier(store, rank, world_size);
-  }
-#endif
-
   // Initializing CUDASymmetricMemory with an allocation transfers its
   // ownership to the CUDASymmetricMemory object. So that outstanding
   // references to the CUDASymmetricMemory object can keep the allocation
@@ -691,8 +610,6 @@ c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
       block->block_size,
       std::move(buffers),
      std::move(signal_pads),
-      mc_handle,
-      mc_addr,
      block->buffer_size,
      block->device_idx,
      group_info.rank,
@@ -713,10 +630,6 @@ bool CUDASymmetricMemoryAllocator::is_rendezvous_completed(void* ptr) {
   return block->symm_mem != nullptr;
 }
 
-bool CUDASymmetricMemoryAllocator::has_multicast_support() {
-  return ::has_multicast_support();
-}
-
 c10::intrusive_ptr<Block> CUDASymmetricMemoryAllocator::find_block(void* ptr) {
   std::shared_lock lock(mutex_);
   auto it = ptr_to_block_.find(ptr);
@@ -20,8 +20,6 @@ class CUDASymmetricMemory : public SymmetricMemory {
       size_t block_size,
       std::vector<void*> buffers,
       std::vector<void*> signal_pads,
-      HandleType mc_handle,
-      void* mc_addr,
       size_t buffer_size,
       int local_device_idx,
       int rank,
@@ -36,9 +34,6 @@ class CUDASymmetricMemory : public SymmetricMemory {
   size_t get_buffer_size() override;
   size_t get_signal_pad_size() override;
 
-  bool has_multicast_support() override;
-  void* get_multicast_ptr() override;
-
   at::Tensor get_buffer(
       int rank,
       c10::IntArrayRef sizes,
@@ -57,8 +52,6 @@ class CUDASymmetricMemory : public SymmetricMemory {
   size_t block_size_;
   std::vector<void*> buffers_;
   std::vector<void*> signal_pads_;
-  HandleType mc_handle_;
-  void* mc_addr_;
   size_t buffer_size_;
   int local_device_idx_;
   int rank_;
@@ -102,7 +95,6 @@ class CUDASymmetricMemoryAllocator : public SymmetricMemoryAllocator {
   size_t get_alloc_size(void* ptr) override;
   c10::intrusive_ptr<SymmetricMemory> rendezvous(void* ptr) override;
   bool is_rendezvous_completed(void* ptr) override;
-  bool has_multicast_support() override;
 
  private:
   c10::intrusive_ptr<Block> find_block(void* ptr);
@@ -1,267 +0,0 @@
-#if defined(CUDART_VERSION) && CUDART_VERSION >= 12030
-
-#include <ATen/ATen.h>
-#include <ATen/ceil_div.h>
-#include <ATen/cuda/CUDAContext.h>
-
-#ifndef AT_PER_OPERATOR_HEADERS
-#include <ATen/Functions.h>
-#include <ATen/NativeFunctions.h>
-#else
-#include <ATen/ops/empty_like.h>
-#endif
-
-#include <torch/library.h>
-
-#include <torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h>
-#include <torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp>
-
-namespace {
-
-using namespace c10d::symmetric_memory;
-
-size_t get_and_verify_alignment(const at::Tensor& input, const char* op_name) {
-  const size_t min_alignment = std::max(4l, input.element_size());
-  // Only check the offset since the multicast address is always at least
-  // 128-bit aligned
-  const size_t ptr_alignment = get_alignment(
-      static_cast<size_t>(input.storage_offset() * input.element_size()));
-  TORCH_CHECK(
-      ptr_alignment >= min_alignment,
-      op_name,
-      "<",
-      input.scalar_type(),
-      ">: input ptr + offset must be at least ",
-      min_alignment,
-      "-byte aligned.");
-
-  const size_t size_alignment =
-      get_alignment(static_cast<size_t>(input.numel() * input.element_size()));
-  TORCH_CHECK(
-      size_alignment >= min_alignment,
-      op_name,
-      "<",
-      input.scalar_type(),
-      ">: input size must be at least ",
-      min_alignment,
-      "-byte aligned.");
-  return std::min(ptr_alignment, size_alignment);
-}
-
-void init_elementwise_launch_config(
-    size_t numel,
-    size_t element_size,
-    size_t alignment,
-    size_t splits,
-    int& num_blocks,
-    int& num_threads) {
-  // Align to preserve alignment in each split
-  const size_t aligned_numel = at::round_up(numel, alignment * splits);
-  const size_t numel_per_split = aligned_numel / splits;
-  const size_t numel_per_thread = alignment / element_size;
-
-  if (numel_per_split <= max_num_threads_per_block * numel_per_thread) {
-    num_blocks = 1;
-    num_threads = at::round_up(
-        at::ceil_div(numel_per_split, numel_per_thread),
-        static_cast<size_t>(C10_WARP_SIZE));
-  } else {
-    num_blocks = std::min(
-        at::ceil_div(
-            numel_per_split, max_num_threads_per_block * numel_per_thread),
-        max_num_blocks);
-    num_threads = max_num_threads_per_block;
-  }
-}
-
-template <typename T, int alignment>
-static __global__ void multimem_all_reduce_kernel(
-    T* input_mc_ptr,
-    size_t numel,
-    uint32_t** signal_pads,
-    size_t rank,
-    size_t world_size) {
-  static_assert(alignment % sizeof(T) == 0);
-  constexpr size_t numel_per_thread = alignment / sizeof(T);
-
-  barrier_and_acquire_previous_kernel_writes(signal_pads, rank, world_size);
-
-  const size_t numel_per_rank =
-      at::round_up(numel, alignment * world_size) / world_size;
-  const size_t start = numel_per_rank * rank;
-
-  auto offset = (blockDim.x * blockIdx.x + threadIdx.x) * numel_per_thread;
-  auto stride = blockDim.x * gridDim.x * numel_per_thread;
-  for (size_t i = offset; i < numel_per_rank; i += stride) {
-    if (start + i >= numel) {
-      continue;
-    }
-    auto vec = multimem_ld_reduce_add<alignment>(input_mc_ptr + start + i);
-    multimem_st<alignment>(input_mc_ptr + start + i, vec);
-  }
-  // Establish observation order - all writes are in-flight beyond this point.
-  barrier(signal_pads, rank, world_size);
-  // Establish causality order - all writes are visible to all devices beyond
-  // this point.
-  __threadfence_system();
-}
-
-at::Tensor multimem_all_reduce_(
-    const at::Tensor& input,
-    std::string reduce_op,
-    std::string group_name) {
-  TORCH_CHECK(
-      input.is_contiguous(), "multimem_all_reduce_: input must be contiguous.");
-  TORCH_CHECK(
-      reduce_op == "sum",
-      "multimem_all_reduce_: only sum is supported for now.");
-
-  auto symm_mem = c10d::symmetric_memory::rendezvous(input);
-  TORCH_CHECK(
-      symm_mem != nullptr,
-      "multimem_all_reduce_: input must be allocated with empty_strided_p2p().");
-  TORCH_CHECK(
-      symm_mem->has_multicast_support(),
-      "multimem_all_reduce_: multicast support is required.");
-
-  const size_t alignment =
-      get_and_verify_alignment(input, "multimem_all_reduce_");
-
-  int num_blocks = 0, num_threads = 0;
-  init_elementwise_launch_config(
-      input.numel(),
-      input.element_size(),
-      alignment,
-      symm_mem->get_world_size(),
-      num_blocks,
-      num_threads);
-
-#define DISPATCH(scalar_t, kernel_alignment) \
-  if (alignment == kernel_alignment) { \
-    multimem_all_reduce_kernel<scalar_t, kernel_alignment> \
-        <<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>( \
-            reinterpret_cast<scalar_t*>(symm_mem->get_multicast_ptr()) + \
-                input.storage_offset(), \
-            input.numel(), \
-            reinterpret_cast<uint32_t**>(symm_mem->get_signal_pad_ptrs_dev()), \
-            symm_mem->get_rank(), \
-            symm_mem->get_world_size()); \
-    C10_CUDA_KERNEL_LAUNCH_CHECK(); \
-  }
-
-  AT_DISPATCH_SWITCH(
-      input.scalar_type(),
-      "multimem_all_reduce",
-      AT_DISPATCH_CASE(at::kBFloat16, [&] {
-        DISPATCH(scalar_t, 16);
-        DISPATCH(scalar_t, 8);
-        DISPATCH(scalar_t, 4);
-      }) AT_DISPATCH_CASE(at::kFloat, [&] {
-        DISPATCH(scalar_t, 16);
-        DISPATCH(scalar_t, 8);
-        DISPATCH(scalar_t, 4);
-      }));
-
-#undef DISPATCH
-  return input;
-}
-
-template <typename T, int alignment>
-static __global__ void multimem_one_shot_all_reduce_kernel(
-    T* input_mc_ptr,
-    T* output_ptr,
-    size_t numel,
-    uint32_t** signal_pads,
-    size_t rank,
-    size_t world_size) {
-  static_assert(alignment % sizeof(T) == 0);
-  constexpr size_t numel_per_thread = alignment / sizeof(T);
-
-  barrier_and_acquire_previous_kernel_writes(signal_pads, rank, world_size);
-
-  auto offset = (blockDim.x * blockIdx.x + threadIdx.x) * numel_per_thread;
-  auto stride = blockDim.x * gridDim.x * numel_per_thread;
-  for (size_t i = offset; i < numel; i += stride) {
-    auto vec = multimem_ld_reduce_add<alignment>(input_mc_ptr + i);
-    *reinterpret_cast<decltype(vec.as_scalar)*>(output_ptr + i) = vec.as_scalar;
-  }
-}
-
-at::Tensor multimem_one_shot_all_reduce(
-    const at::Tensor& input,
-    std::string reduce_op,
-    std::string group_name) {
-  TORCH_CHECK(
-      input.is_contiguous(),
-      "multimem_one_shot_all_reduce: input must be contiguous.");
-  TORCH_CHECK(
-      reduce_op == "sum",
-      "multimem_one_shot_all_reduce: only sum is supported for now.");
-
-  auto symm_mem = c10d::symmetric_memory::rendezvous(input);
-  TORCH_CHECK(
-      symm_mem != nullptr,
-      "multimem_one_shot_all_reduce: input must be allocated with empty_strided_p2p().");
-  TORCH_CHECK(
-      symm_mem->has_multicast_support(),
-      "multimem_one_shot_all_reduce: requires multicast support.");
-
-  auto output = at::empty_like(input);
-
-  const size_t alignment =
-      get_and_verify_alignment(input, "multimem_one_shot_all_reduce");
-
-  int num_blocks = 0, num_threads = 0;
-  init_elementwise_launch_config(
-      input.numel(),
-      input.element_size(),
-      alignment,
-      1,
-      num_blocks,
-      num_threads);
-
-#define DISPATCH(scalar_t, kernel_alignment) \
-  if (alignment == kernel_alignment) { \
-    multimem_one_shot_all_reduce_kernel<scalar_t, kernel_alignment> \
-        <<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>( \
-            reinterpret_cast<scalar_t*>(symm_mem->get_multicast_ptr()) + \
-                input.storage_offset(), \
-            output.data_ptr<scalar_t>(), \
-            input.numel(), \
-            reinterpret_cast<uint32_t**>(symm_mem->get_signal_pad_ptrs_dev()), \
-            symm_mem->get_rank(), \
-            symm_mem->get_world_size()); \
-    C10_CUDA_KERNEL_LAUNCH_CHECK(); \
-  }
-
-  AT_DISPATCH_SWITCH(
-      input.scalar_type(),
-      "multimem_all_reduce",
-      AT_DISPATCH_CASE(at::kBFloat16, [&] {
-        DISPATCH(scalar_t, 16);
-        DISPATCH(scalar_t, 8);
-        DISPATCH(scalar_t, 4);
-      }) AT_DISPATCH_CASE(at::kFloat, [&] {
-        DISPATCH(scalar_t, 16);
-        DISPATCH(scalar_t, 8);
-        DISPATCH(scalar_t, 4);
-      }));
-
-  return output;
-}
-
-TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
-  m.def(
-      "multimem_all_reduce_(Tensor input, str reduce_op, str group_name) -> Tensor",
-      torch::dispatch(c10::DispatchKey::CUDA, ::multimem_all_reduce_),
-      {at::Tag::pt2_compliant_tag});
-
-  m.def(
-      "multimem_one_shot_all_reduce(Tensor input, str reduce_op, str group_name) -> Tensor",
-      torch::dispatch(c10::DispatchKey::CUDA, ::multimem_one_shot_all_reduce),
-      {at::Tag::pt2_compliant_tag});
-}
-
-} // namespace
-
-#endif
@@ -176,7 +176,7 @@ at::Tensor empty_strided_p2p(
 TORCH_API c10::intrusive_ptr<SymmetricMemory> rendezvous(
     const at::Tensor& tensor) {
   auto allocator = get_allocator(tensor.device().type());
-  return allocator->rendezvous(tensor.storage().data_ptr().get());
+  return allocator->rendezvous(tensor.data_ptr());
 }
 
 c10::intrusive_ptr<SymmetricMemory> get_symmetric_memory(
@@ -189,9 +189,5 @@ c10::intrusive_ptr<SymmetricMemory> get_symmetric_memory(
   return allocator->rendezvous(tensor.data_ptr());
 }
 
-TORCH_API bool has_multicast_support(c10::DeviceType device_type) {
-  auto allocator = get_allocator(device_type);
-  return allocator->has_multicast_support();
-}
 } // namespace symmetric_memory
 } // namespace c10d
@@ -51,9 +51,6 @@ class TORCH_API SymmetricMemory : public c10::intrusive_ptr_target {
   virtual size_t get_buffer_size() = 0;
   virtual size_t get_signal_pad_size() = 0;
 
-  virtual bool has_multicast_support() = 0;
-  virtual void* get_multicast_ptr() = 0;
-
   virtual at::Tensor get_buffer(
       int rank,
       c10::IntArrayRef sizes,
@@ -81,7 +78,6 @@ class SymmetricMemoryAllocator : public c10::intrusive_ptr_target {
   virtual size_t get_alloc_size(void* ptr) = 0;
   virtual c10::intrusive_ptr<SymmetricMemory> rendezvous(void* ptr) = 0;
   virtual bool is_rendezvous_completed(void* ptr) = 0;
-  virtual bool has_multicast_support() = 0;
 };
 
 C10_EXPORT bool is_finalizing();
@@ -154,6 +150,5 @@ TORCH_API c10::intrusive_ptr<SymmetricMemory> rendezvous(
 TORCH_API c10::intrusive_ptr<SymmetricMemory> get_symmetric_memory(
     const at::Tensor& tensor);
 
-TORCH_API bool has_multicast_support(c10::DeviceType device_type);
 } // namespace symmetric_memory
 } // namespace c10d
@@ -1044,9 +1044,6 @@ This class does not support ``__members__`` property.)");
       .def_static(
           "get_symmetric_memory",
          &::c10d::symmetric_memory::get_symmetric_memory)
-      .def_static(
-          "has_multicast_support",
-          &::c10d::symmetric_memory::has_multicast_support)
       .def_property_readonly("rank", &SymmetricMemory::get_rank)
       .def_property_readonly("world_size", &SymmetricMemory::get_world_size)
       .def_property_readonly(