diff --git a/BUILD.bazel b/BUILD.bazel index f5739a9875e..59d2ea857a1 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -1452,7 +1452,10 @@ cu_library( # https://github.com/pytorch/pytorch/issues/79236 # To solve it we add it into the `caffe2_cuda`, # this is also aligned with the CMake build. - srcs = [":caffe2_cu_srcs"] + ["torch/csrc/distributed/c10d/quantization/quantization_gpu.cu"], + srcs = [":caffe2_cu_srcs"] + [ + "torch/csrc/distributed/c10d/intra_node_comm.cu", + "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", + ], copts = CAFFE2_COPTS + torch_cuda_half_options, visibility = ["//visibility:public"], deps = [ @@ -1619,6 +1622,7 @@ cc_library( exclude = [ "torch/csrc/cuda/python_nccl.cpp", "torch/csrc/cuda/nccl.cpp", + "torch/csrc/distributed/c10d/intra_node_comm.cu", "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", ], )) + torch_sources, diff --git a/build_variables.bzl b/build_variables.bzl index 9d61861e3a3..b028a9b28c0 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -674,6 +674,8 @@ libtorch_cuda_distributed_extra_sources = [ "torch/csrc/distributed/c10d/ProcessGroupUCC.cpp", "torch/csrc/distributed/c10d/UCCTracing.cpp", "torch/csrc/distributed/c10d/UCCUtils.cpp", + "torch/csrc/distributed/c10d/intra_node_comm.cpp", + "torch/csrc/distributed/c10d/intra_node_comm.cu", "torch/csrc/distributed/rpc/tensorpipe_cuda.cpp", "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", ] diff --git a/c10/cuda/driver_api.cpp b/c10/cuda/driver_api.cpp index a90081f6ccd..56243e663c7 100644 --- a/c10/cuda/driver_api.cpp +++ b/c10/cuda/driver_api.cpp @@ -37,7 +37,7 @@ void* DriverAPI::get_nvml_handle() { return nvml_hanle; } -DriverAPI* DriverAPI::get() { +C10_EXPORT DriverAPI* DriverAPI::get() { static DriverAPI singleton = create_driver_api(); return &singleton; } diff --git a/c10/cuda/driver_api.h b/c10/cuda/driver_api.h index 6f5e46f18ed..f4054c23e44 100644 --- a/c10/cuda/driver_api.h +++ b/c10/cuda/driver_api.h @@ -28,9 +28,11 @@ _(cuMemCreate) \ _(cuGetErrorString) -#define C10_NVML_DRIVER_API(_) \ - _(nvmlInit_v2) \ - _(nvmlDeviceGetHandleByPciBusId_v2) \ +#define C10_NVML_DRIVER_API(_) \ + _(nvmlInit_v2) \ + _(nvmlDeviceGetHandleByPciBusId_v2) \ + _(nvmlDeviceGetNvLinkRemoteDeviceType) \ + _(nvmlDeviceGetNvLinkRemotePciInfo_v2) \ _(nvmlDeviceGetComputeRunningProcesses) namespace c10 { diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 748363725bc..f2acc61ad38 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -641,6 +641,10 @@ if(USE_CUDA) append_filelist("libtorch_cuda_distributed_base_sources" Caffe2_GPU_SRCS) if(NOT WIN32) append_filelist("libtorch_cuda_distributed_extra_sources" Caffe2_GPU_SRCS) + set_source_files_properties( + ${TORCH_SRC_DIR}/csrc/distributed/c10d/intra_node_comm.cpp + PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1" + ) endif() endif() set_source_files_properties( diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 2992a262360..15fc8e353c6 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -15,7 +15,7 @@ import warnings from contextlib import contextmanager from datetime import datetime, timedelta from itertools import chain, product -from unittest import mock +from unittest import SkipTest, mock import torch import torch.distributed as c10d @@ -3113,6 +3113,65 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase): for i, t in enumerate(tensors): self.assertEqual(t, 
torch.full_like(t, self.world_size * (i + (self.world_size + 1.) / 2.))) + @requires_nccl() + @skip_if_lt_x_gpu(2) + @skip_if_rocm + def test_intra_node_comm_all_reduce(self): + from torch._C._distributed_c10d import _get_intra_node_comm_usage_counter + from torch.testing._internal.common_cuda import SM80OrLater + for peer in range(self.world_size): + if peer == self.rank: + continue + if not torch._C._cuda_canDeviceAccessPeer(self.rank, peer): + raise SkipTest("Test requires p2p access") + + if not SM80OrLater: + raise SkipTest("Test requires sm>=80") + + store = c10d.FileStore(self.file_name, self.world_size) + os.environ["ENABLE_INTRA_NODE_COMM"] = "1" + os.environ["TEST_INTRA_NODE_COMM"] = "1" + torch.cuda.set_device(self.rank) + c10d.init_process_group( + backend="nccl", rank=self.rank, world_size=self.world_size, store=store + ) + expect = self.world_size * (self.world_size - 1) // 2 + + # IntraNodeComm currently only supports sum and bf16. + # Verify that it is not used in the next two configurations. + t = torch.full((4 * 1024 // 2,), self.rank).cuda() + c10d.all_reduce(t, c10d.ReduceOp.SUM) + self.assertTrue(t.eq(expect).all()) + self.assertEqual(_get_intra_node_comm_usage_counter(), 0) + + t = torch.full((4 * 1024 // 2,), self.rank, dtype=torch.bfloat16).cuda() + c10d.all_reduce(t, c10d.ReduceOp.AVG) + self.assertEqual(_get_intra_node_comm_usage_counter(), 0) + + # Verify that IntraNodeComm is used up to 10MB + t = torch.full((4 * 1024 // 2,), self.rank, dtype=torch.bfloat16).cuda() + c10d.all_reduce(t, c10d.ReduceOp.SUM) + self.assertTrue(t.eq(expect).all()) + self.assertEqual(_get_intra_node_comm_usage_counter(), 1) + + t = torch.full((512 * 1024 // 2,), self.rank, dtype=torch.bfloat16).cuda() + c10d.all_reduce(t, c10d.ReduceOp.SUM) + self.assertTrue(t.eq(expect).all()) + self.assertEqual(_get_intra_node_comm_usage_counter(), 2) + + t = torch.full((10 * 1024 ** 2 // 2,), self.rank, dtype=torch.bfloat16).cuda() + c10d.all_reduce(t, c10d.ReduceOp.SUM) + self.assertTrue(t.eq(expect).all()) + self.assertEqual(_get_intra_node_comm_usage_counter(), 3) + + # Verify that IntraNodeComm is not used beyond 10MB + t = torch.full((10 * 1024 ** 2 // 2 + 1,), self.rank, dtype=torch.bfloat16).cuda() + c10d.all_reduce(t, c10d.ReduceOp.SUM) + self.assertTrue(t.eq(expect).all()) + self.assertEqual(_get_intra_node_comm_usage_counter(), 3) + + c10d.destroy_process_group() + @requires_nccl() @skip_if_lt_x_gpu(2) def test_sequence_num_set_default_pg_nccl(self): diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 026576a6daa..0580ea360a8 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -712,7 +712,8 @@ ProcessGroupNCCL::ProcessGroupNCCL( terminateProcessGroup_(false), terminateHeartbeatMonitorThread_(false), collectiveDebugInfoMode_(false), - uid_(process_group_id++) { + uid_(process_group_id++), + intraNodeComm_(initIntraNodeComm()) { TORCH_CHECK_WITH( ValueError, at::cuda::getNumGPUs() != 0, @@ -896,6 +897,12 @@ void ProcessGroupNCCL::performNocolorSplit(at::Device device) { #endif } +c10::intrusive_ptr ProcessGroupNCCL:: + initIntraNodeComm() { + return intra_node_comm::IntraNodeComm::rendezvous( + store_, std::to_string(uid_), rank_, size_); +} + void ProcessGroupNCCL::runHealthCheck() { // Run health check in a separate thread and wait on CV to handle timeouts, // since majority of getNCCLComm failures are hangs. 
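For reviewers, the user-facing flow exercised by test_intra_node_comm_all_reduce above can be condensed into a small stand-alone sketch (illustrative only, not part of this patch; it assumes a single node with at least two P2P-capable sm_80+ GPUs and a torchrun launch):

    import os
    import torch
    import torch.distributed as dist

    # Opt-in flag read during rendezvous; TEST_INTRA_NODE_COMM is only needed
    # to force Topology::FULLY_CONNECTED on machines without NVLink.
    os.environ["ENABLE_INTRA_NODE_COMM"] = "1"

    dist.init_process_group(backend="nccl")  # e.g. torchrun --nproc-per-node=2
    rank = dist.get_rank()
    torch.cuda.set_device(rank)  # single node, so rank == local device index

    # bf16 + SUM within the 10 MB cap is eligible for the intra-node fast path.
    t = torch.full((512 * 1024 // 2,), rank, dtype=torch.bfloat16, device="cuda")
    dist.all_reduce(t, op=dist.ReduceOp.SUM)

    # Exposed for testing; stays 0 when the fast path was not taken
    # (e.g. non-bf16 input, non-SUM reduce op, or payloads above 10 MB).
    from torch._C._distributed_c10d import _get_intra_node_comm_usage_counter
    print(rank, _get_intra_node_comm_usage_counter())

    dist.destroy_process_group()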
@@ -2842,6 +2849,16 @@ c10::intrusive_ptr ProcessGroupNCCL::allreduce_impl( c10::intrusive_ptr ProcessGroupNCCL::allreduce( std::vector& tensors, const AllreduceOptions& opts) { + if (intraNodeComm_ != nullptr && tensors.size() == 1 && + opts.reduceOp == ReduceOp::SUM) { + using namespace intra_node_comm; + auto algo = intraNodeComm_->selectAllReduceAlgo(tensors[0]); + if (algo != intra_node_comm::AllReduceAlgo::NONE) { + intraNodeComm_->allReduce(tensors[0], algo); + return c10::make_intrusive(); + } + } + check_gpu_tensors_different_devices(tensors); // @lint-ignore CLANGTIDY diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index 00022b16521..4c07053982d 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include @@ -552,6 +553,8 @@ class TORCH_API ProcessGroupNCCL : public Backend { ncclCommsMap, c10::optional abortReason); + c10::intrusive_ptr initIntraNodeComm(); + // Provides an API to abort the ProcessGroup (similar to ncclCommAbort) // instead of relying on ProcessGroupNCCL destructor. void abort(c10::optional abortReason = c10::nullopt); @@ -950,6 +953,8 @@ class TORCH_API ProcessGroupNCCL : public Backend { std::unique_ptr debugInfoWriter_ = nullptr; size_t uid_; + + c10::intrusive_ptr intraNodeComm_; }; TORCH_API std::string dump_nccl_trace(); diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 252cf4a768e..22b02f0a167 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -21,6 +21,7 @@ #ifdef USE_C10D_NCCL #include #include +#include #endif #ifdef USE_C10D_MPI @@ -2328,6 +2329,10 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`). "perform_nocolor_split", &::c10d::ProcessGroupNCCL::performNocolorSplit); + module.def( + "_get_intra_node_comm_usage_counter", + &::c10d::intra_node_comm::getIntraNodeCommUsageCounter); + #ifdef NCCL_HAS_COMM_CTA_CGA py::class_( processGroupNCCL, diff --git a/torch/csrc/distributed/c10d/intra_node_comm.cpp b/torch/csrc/distributed/c10d/intra_node_comm.cpp new file mode 100644 index 00000000000..50b0147b300 --- /dev/null +++ b/torch/csrc/distributed/c10d/intra_node_comm.cpp @@ -0,0 +1,485 @@ +#include + +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include + +#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) +#include +#include +#endif + +#include + +namespace c10d { +namespace intra_node_comm { + +static std::vector ENABLE_INTRA_NODE_COMM = { + "ENABLE_INTRA_NODE_COMM"}; +// Forces detectedTopology() to return Topology::FULLY_CONNECTED, so +// IntraNodeComm can be used even without NVLink connection. This is only used +// for testing purposes. 
+static std::vector TEST_INTRA_NODE_COMM = {"TEST_INTRA_NODE_COMM"}; + +//////////////////////////////////////////////////////////////////////////////// +// CUDA Functions +//////////////////////////////////////////////////////////////////////////////// + +bool isIntraNodeCommSupported(); + +std::optional getHybridCubeMesh(NvlMesh nvlMesh); + +void* initP2pState(); + +void* initTopoInfo(Topology topology, NvlMesh nvlMesh, size_t rank); + +AllReduceAlgo selectAllReduceAlgo( + const at::Tensor& input, + Topology topology, + size_t worldSize); + +at::Tensor allReduce( + const at::Tensor& input, + std::array p2pStates, + std::array buffers, + void* p2pStatesDev, + void* buffersDev, + void* topoInfo, + size_t rank, + size_t worldSize, + AllReduceAlgo algo, + at::cuda::CUDAStream& stream); + +//////////////////////////////////////////////////////////////////////////////// +// Topology Detection +//////////////////////////////////////////////////////////////////////////////// + +// TODO: find a better way to determine this +static constexpr size_t kMaxNvLinks = 20; + +static std::ostream& operator<<(std::ostream& os, const NvlMesh& nvlMesh) { + std::ostringstream oss; + for (size_t i = 0; i < kMaxDevices; ++i) { + for (size_t j = 0; j < kMaxDevices; ++j) { + oss << nvlMesh[i][j] << " "; + } + oss << std::endl; + } + os << oss.str(); + return os; +} + +static bool isSame(NvlMesh lhs, NvlMesh rhs) { + for (size_t i = 0; i < kMaxDevices; ++i) { + for (size_t j = 0; j < kMaxDevices; ++j) { + if (lhs[i][j] != rhs[i][j]) { + return false; + } + } + } + return true; +} + +/** + * Query the nvlink connection among devices. + */ +static NvlMesh getNvlMesh(std::vector rankToBusId) { +#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) + using namespace c10::cuda; + + NvlMesh nvlMesh = {}; + auto driverApi = DriverAPI::get(); + if (driverApi == nullptr) { + return nvlMesh; + } + + const auto worldSize = rankToBusId.size(); + std::vector devices(worldSize, 0); + std::unordered_map busIdToRank; + std::vector switchLinkCount(worldSize, 0); + + for (size_t r = 0; r < worldSize; ++r) { + busIdToRank.emplace(std::make_pair(rankToBusId[r], r)); + TORCH_CHECK( + driverApi->nvmlDeviceGetHandleByPciBusId_v2_( + rankToBusId[r].c_str(), &devices[r]) == NVML_SUCCESS); + } + + // For each device, loop over devices connected to it via NVLink + for (size_t idx = 0; idx < worldSize; ++idx) { + for (size_t link = 0; link < kMaxNvLinks; ++link) { + nvmlReturn_t ret; + nvmlIntNvLinkDeviceType_t deviceType; + ret = driverApi->nvmlDeviceGetNvLinkRemoteDeviceType_( + devices[idx], link, &deviceType); + if (ret != NVML_SUCCESS) { + // We've exhausted the NVLinks connected to this device. + // This error is benign. There doesn't seem to be a reliable + // way to obtain the maximum link value that can be passed to + // the API, so we simply increment the link value until the + // API fails or we hit a predefined maximum value. + break; + } + // Remote device is GPU + if (deviceType == NVML_NVLINK_DEVICE_TYPE_GPU) { + nvmlPciInfo_t pciInfo; + ret = driverApi->nvmlDeviceGetNvLinkRemotePciInfo_v2_( + devices[idx], link, &pciInfo); + if (ret != NVML_SUCCESS) { + // Unexpected error. 
Return an empty NvlMesh + return {}; + } + auto it = busIdToRank.find(pciInfo.busId); + if (it != busIdToRank.end()) { + if (idx != it->second) { + nvlMesh[idx][it->second] += 1; + } + } + // Remote device is NVSwitch + } else if (deviceType == NVML_NVLINK_DEVICE_TYPE_SWITCH) { + switchLinkCount[idx] += 1; + } + } + } + // Process NVSwitch connections. For simplicity, we assume + // all NVSwitches are interconnected. + for (size_t i = 0; i < worldSize; ++i) { + for (size_t j = 0; j < worldSize; ++j) { + if (i == j) { + continue; + } + nvlMesh[i][j] += std::min(switchLinkCount[i], switchLinkCount[j]); + } + } + return nvlMesh; +#else + return {}; +#endif +} + +/** + * Determine if the devices form a hybrid cube mesh + * topology given an NvlMesh. + */ +static bool isHybridCubeMesh(const NvlMesh nvlMesh) { + std::array numNeighbors = {}; + for (size_t i = 0; i < kMaxDevices; ++i) { + for (size_t j = 0; j < kMaxDevices; ++j) { + if (nvlMesh[i][j] > 0) { + numNeighbors[i] += 1; + } + } + } + for (size_t i = 0; i < kMaxDevices; ++i) { + // TODO: this is insufficient and needs revisiting + if (numNeighbors[i] != 4) { + return false; + } + } + return true; +} + +/** + * Detect topology given an NvlMesh. + */ +static Topology detectTopology(const NvlMesh nvlMesh, size_t worldSize) { + if (getCvarBool(TEST_INTRA_NODE_COMM, false)) { + return Topology::FULLY_CONNECTED; + } + bool fullyConnected = true; + for (size_t i = 0; i < worldSize - 1; ++i) { + for (size_t j = i + 1; j < worldSize; ++j) { + if (nvlMesh[i][j] == 0 || nvlMesh[j][i] == 0) { + fullyConnected = false; + } + } + } + if (fullyConnected) { + LOG(INFO) << "IntraNodeComm: Topology::FULLY_CONNECTED"; + return Topology::FULLY_CONNECTED; + } + if (worldSize == kMaxDevices && getHybridCubeMesh(nvlMesh) != std::nullopt) { + LOG(INFO) << "IntraNodeComm: Topology::HYBRID_CUBE_MESH"; + return Topology::HYBRID_CUBE_MESH; + } + LOG(INFO) << "IntraNodeComm: Topology::UNKNOWN"; + return Topology::UNKNOWN; +}; + +//////////////////////////////////////////////////////////////////////////////// +// Rendezvous and Initialization +//////////////////////////////////////////////////////////////////////////////// + +IntraNodeComm::IntraNodeComm( + Topology topology, + std::array p2pStates, + std::array buffers, + void* p2pStatesDev, + void* buffersDev, + void* topoInfo, + size_t rank, + size_t worldSize) + : topology_(topology), + p2pStates_(p2pStates), + buffers_(buffers), + p2pStatesDev_(p2pStatesDev), + buffersDev_(buffersDev), + topoInfo_(topoInfo), + rank_(rank), + worldSize_(worldSize) {} + +IntraNodeComm::~IntraNodeComm() { + // Intentionally releasing resources without synchronizing devices. The + // teardown logic is safe for properly sync'd user programs. We don't want + // improperly sync'd user programs to hang here. + for (size_t r = 0; r < worldSize_; ++r) { + if (r == rank_) { + continue; + } + AT_CUDA_CHECK(cudaIpcCloseMemHandle(p2pStates_[r])); + AT_CUDA_CHECK(cudaIpcCloseMemHandle(buffers_[r])); + } + AT_CUDA_CHECK(cudaFree(p2pStates_[rank_])); + AT_CUDA_CHECK(cudaFree(buffers_[rank_])); + if (topoInfo_ != nullptr) { + AT_CUDA_CHECK(cudaFree(topoInfo_)); + } + AT_CUDA_CHECK(cudaFree(p2pStatesDev_)); + AT_CUDA_CHECK(cudaFree(buffersDev_)); +} + +/** + * Use c10d::Store to perform allgather on a trivially copyable type.
+ */ +template +std::vector storeAllGather( + c10::intrusive_ptr store, + const std::string& prefix, + size_t rank, + size_t worldSize, + T val) { + static_assert(std::is_trivially_copyable::value); + + std::vector peerKeys; + for (size_t r = 0; r < worldSize; ++r) { + std::ostringstream oss; + oss << prefix << "-" << r; + peerKeys.push_back(oss.str()); + } + + { + std::vector payload( + reinterpret_cast(&val), + reinterpret_cast(&val) + sizeof(T)); + store->set(peerKeys[rank], payload); + } + + std::vector peerVals; + for (size_t r = 0; r < worldSize; ++r) { + if (r == rank) { + peerVals.push_back(val); + continue; + } + store->wait({peerKeys[r]}); + auto payload = store->get(peerKeys[r]); + TORCH_CHECK(payload.size() == sizeof(T)); + T peerVal; + std::memcpy(&peerVal, payload.data(), sizeof(T)); + peerVals.push_back(peerVal); + } + return peerVals; +} + +c10::intrusive_ptr IntraNodeComm::rendezvous( + c10::intrusive_ptr store, + const std::string& prefix, + size_t rank, + size_t worldSize) { +#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) + if (!isIntraNodeCommSupported() || + !getCvarBool(ENABLE_INTRA_NODE_COMM, false) || worldSize < 2 || + worldSize > kMaxDevices) { + return nullptr; + } + + int deviceIdx = at::cuda::current_device(); + c10::cuda::CUDAGuard guard(deviceIdx); + + // First hand shake: exchange hostname and device bus ID + struct DevInfo { + char hostname[HOST_NAME_MAX + 1]; + char busId[80]; + }; + + DevInfo devInfo{}; + gethostname(devInfo.hostname, sizeof(devInfo.hostname)); + cudaDeviceProp prop{}; + AT_CUDA_CHECK(cudaGetDeviceProperties(&prop, deviceIdx)); + snprintf( + devInfo.busId, + sizeof(devInfo.busId), + NVML_DEVICE_PCI_BUS_ID_FMT, + prop.pciDomainID, + prop.pciBusID, + prop.pciDeviceID); + + auto peerDevInfos = storeAllGather( + store, prefix + "-IntraNodeCommHandShake-0", rank, worldSize, devInfo); + + std::vector rankToBusId; + for (const auto& info : peerDevInfos) { + if (strcmp(info.hostname, peerDevInfos.front().hostname) != 0) { + LOG(WARNING) << "Aborting IntraNodeComm::rendezvous because some " + "participants are not on the same host (" + << info.hostname << ", " << devInfo.hostname << ")"; + return nullptr; + } + rankToBusId.emplace_back(info.busId); + } + + // Verify unique devices + { + std::unordered_set uniqueBusIds(rankToBusId.begin(), rankToBusId.end()); + TORCH_CHECK( + uniqueBusIds.size() == worldSize, + "IntraNodeComm::rendezvous: detected overlapping devices across ranks. 
" + "Please properly set device via torch.cuda.set_device() before " + "initiating rendezvous."); + } + + // Query nvlink connection + auto nvlMesh = getNvlMesh(rankToBusId); + + // Detect topology + Topology topology = detectTopology(nvlMesh, worldSize); + + // Initialize p2p state + auto p2pState = initP2pState(); + + // Allocate buffer + void* buffer = nullptr; + AT_CUDA_CHECK(cudaMalloc(&buffer, kMaxIntraNodeSize * 2)); + + // Second handshake: exchange topology and CUDA IPC handles + struct IpcInfo { + NvlMesh nvlMesh; + Topology topology; + cudaIpcMemHandle_t p2pStateHandle, bufferHandle; + }; + + // Make p2p state and buffer available for IPC + cudaIpcMemHandle_t p2pStateHandle, bufferHandle; + AT_CUDA_CHECK(cudaIpcGetMemHandle(&p2pStateHandle, p2pState)); + AT_CUDA_CHECK(cudaIpcGetMemHandle(&bufferHandle, buffer)); + + IpcInfo ipcInfo{ + .nvlMesh = nvlMesh, + .topology = topology, + .p2pStateHandle = p2pStateHandle, + .bufferHandle = bufferHandle}; + + auto peerIpcInfos = storeAllGather( + store, prefix + "-IntraNodeCommHandShake-2", rank, worldSize, ipcInfo); + + for (const auto& info : peerIpcInfos) { + if (!isSame(info.nvlMesh, peerIpcInfos.front().nvlMesh) || + info.topology != peerIpcInfos.front().topology) { + LOG(WARNING) << "Aborting IntraNodeComm::rendezvous because some " + "participants are observing different topologies (" + << int(info.topology) << " and " << int(topology) << ")"; + AT_CUDA_CHECK(cudaFree(p2pState)); + AT_CUDA_CHECK(cudaFree(buffer)); + return nullptr; + } + } + + std::array p2pStates = {}, buffers = {}; + for (size_t r = 0; r < peerIpcInfos.size(); ++r) { + if (r == rank) { + p2pStates[r] = p2pState; + buffers[r] = buffer; + } else { + AT_CUDA_CHECK(cudaIpcOpenMemHandle( + &p2pStates[r], + peerIpcInfos[r].p2pStateHandle, + cudaIpcMemLazyEnablePeerAccess)); + AT_CUDA_CHECK(cudaIpcOpenMemHandle( + &buffers[r], + peerIpcInfos[r].bufferHandle, + cudaIpcMemLazyEnablePeerAccess)); + } + } + void* p2pStatesDev = nullptr; + AT_CUDA_CHECK(cudaMalloc(&p2pStatesDev, sizeof(p2pStates))); + AT_CUDA_CHECK(cudaMemcpy( + p2pStatesDev, + p2pStates.data(), + sizeof(p2pStates), + cudaMemcpyHostToDevice)); + + void* buffersDev = nullptr; + AT_CUDA_CHECK(cudaMalloc(&buffersDev, sizeof(buffers))); + AT_CUDA_CHECK(cudaMemcpy( + buffersDev, buffers.data(), sizeof(buffers), cudaMemcpyHostToDevice)); + + void* topoInfo = initTopoInfo(topology, nvlMesh, rank); + return c10::make_intrusive( + topology, + p2pStates, + buffers, + p2pStatesDev, + buffersDev, + topoInfo, + rank, + worldSize); +#else + return nullptr; +#endif +} + +AllReduceAlgo IntraNodeComm::selectAllReduceAlgo(const at::Tensor& input) { + return c10d::intra_node_comm::selectAllReduceAlgo( + input, topology_, worldSize_); +} + +static int64_t usageCounter = 0; + +at::Tensor IntraNodeComm::allReduce( + const at::Tensor& input, + AllReduceAlgo algo) { + // Report usage for testing purposes. + // We don't care about overflowing. 
+ ++usageCounter; + auto stream = at::cuda::getCurrentCUDAStream(); + c10::cuda::CUDACachingAllocator::recordStream( + input.storage().data_ptr(), stream); + return c10d::intra_node_comm::allReduce( + input, + p2pStates_, + buffers_, + p2pStatesDev_, + buffersDev_, + topoInfo_, + rank_, + worldSize_, + algo, + stream); +} + +int64_t getIntraNodeCommUsageCounter() { + return usageCounter; +} + +} // namespace intra_node_comm +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/intra_node_comm.cu b/torch/csrc/distributed/c10d/intra_node_comm.cu new file mode 100644 index 00000000000..7723140a333 --- /dev/null +++ b/torch/csrc/distributed/c10d/intra_node_comm.cu @@ -0,0 +1,729 @@ +#include + +#include +#include +#include + +namespace c10d { +namespace intra_node_comm { + +static constexpr size_t kBytesPerThread = 16; +static constexpr size_t kMaxAllReduceBlocks = 24; +static constexpr size_t kThreadsPerBlock = 1024; +static constexpr size_t kWarpSize = 32; + +static constexpr size_t kHcmThreshBytes = 256 * 1024; +static constexpr size_t kOneShotThreshBytes = 256 * 1024; +static constexpr size_t kTwoShotThreshBytes = 10 * 1024 * 1024; + +#if defined(USE_ROCM) +using __nv_bfloat162 = uint32_t; +#endif + +struct __align__(16) bf16x8 { + __nv_bfloat162 vals[4]; +}; + +#define DEVICE_INLINE __device__ inline __attribute__((always_inline)) + +DEVICE_INLINE __nv_bfloat162 +bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) { +#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) + CUDA_KERNEL_ASSERT(false); +#else + return __hadd2(x, y); +#endif +} + +DEVICE_INLINE bf16x8 add_bf16x8(bf16x8 a, bf16x8 b) { + bf16x8 c; + c.vals[0] = bf16hadd2(a.vals[0], b.vals[0]); + c.vals[1] = bf16hadd2(a.vals[1], b.vals[1]); + c.vals[2] = bf16hadd2(a.vals[2], b.vals[2]); + c.vals[3] = bf16hadd2(a.vals[3], b.vals[3]); + return c; +} + +/** + * NOTE [cross device memory synchronization] + * + * The multi-stage algorithms (e.g. two-shot, hcm allreduce) require the writes + * of a thread to be visible to threads with the same block/thread ID on other + * devices. To satisfy CUDA's memory consistency model, every thread has to + * release its writes at the system scope, and the consuming thread has to + * acquire the writes at the system scope. This incurs high overhead and + * attempts at optimizing this process can be prone to race conditions. + * + * Instead, we sidestep caching by having each thread: + * + * - Directly write to global memory via st.cs (cache-streaming). + * - Synchronize with threads within the block. + * - Perform cross device synchronization at block level (via system scope + * atomic ops). + * - Synchronize with threads within the block. + * - Directly read from global memory via ld.nc (non-coherent/non-cached).
+ */ +template +DEVICE_INLINE void streamLoad128(bf16x8& val, const T* addr) { +#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) + CUDA_KERNEL_ASSERT(false); +#else + unsigned long long int low, high; + asm("ld.global.nc.v2.u64 {%0, %1}, [%2];" + : "=l"(low), "=l"(high) + : "l"(addr)); + reinterpret_cast(&val)[0] = low; + reinterpret_cast(&val)[1] = high; +#endif +} + +__device__ inline void streamStore128(at::BFloat16* addr, const bf16x8& val) { +#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) + CUDA_KERNEL_ASSERT(false); +#else + unsigned long long int low, high; + low = reinterpret_cast(&val)[0]; + high = reinterpret_cast(&val)[1]; + asm("st.global.cs.v2.u64 [%0], {%1, %2};" : : "l"(addr), "l"(low), "l"(high)); +#endif +} + +template +DEVICE_INLINE void load128(bf16x8& val, const T* addr) { + *reinterpret_cast(&val) = reinterpret_cast(addr)[0]; +} + +template +DEVICE_INLINE void store128(T* addr, const bf16x8& val) { + *reinterpret_cast(addr) = reinterpret_cast(&val)[0]; +} + +DEVICE_INLINE void releaseSignal(uint32_t* addr) { +#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) + CUDA_KERNEL_ASSERT(false); +#else + atomicAdd_system(addr, 1); +#endif +} + +DEVICE_INLINE void acquireSignal(uint32_t* addr) { +#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) + CUDA_KERNEL_ASSERT(false); +#else + volatile uint32_t* signal = addr; + uint32_t val; + do { + val = *signal; + } while (val == 0 || atomicCAS_system(addr, val, val - 1) != val); +#endif +} + +//////////////////////////////////////////////////////////////////////////////// +// Fully Connected Algos +//////////////////////////////////////////////////////////////////////////////// + +struct P2pState { + uint32_t signals0[kMaxAllReduceBlocks][kMaxDevices]; + uint32_t signals1[kMaxAllReduceBlocks][kMaxDevices]; +}; + +template +static __global__ void oneShotAllReduceKernel( + at::BFloat16* input, + size_t N, + size_t N_aligned, + P2pState** p2pStates, + at::BFloat16** buffers, + size_t rank) { + const size_t numelPerThread = kBytesPerThread / sizeof(at::BFloat16); + const size_t offset = + (blockDim.x * blockIdx.x + threadIdx.x) * numelPerThread; + const size_t stride = blockDim.x * gridDim.x * numelPerThread; + + // Wait for all other ranks to enter the kernel + if (threadIdx.x < kWorldSize) { + auto targetRank = threadIdx.x; + releaseSignal(&p2pStates[targetRank]->signals0[blockIdx.x][rank]); + acquireSignal(&p2pStates[rank]->signals0[blockIdx.x][targetRank]); + } + __syncthreads(); + + // The source pointers. 
Distributed round-robin for the different warps + const at::BFloat16* srcs[kWorldSize]; +#pragma unroll kWorldSize + for (int ii = 0; ii < kWorldSize; ++ii) { + int srcRank = (rank + ii) % kWorldSize; + srcs[ii] = buffers[srcRank]; + } + + for (size_t i = offset; i < N_aligned; i += stride) { + bf16x8 vals[kWorldSize]; +#pragma unroll kWorldSize + for (size_t ii = 0; ii < kWorldSize; ++ii) { + streamLoad128(vals[ii], &srcs[ii][i]); + } + + bf16x8 sums; + memset(reinterpret_cast(&sums), 0, sizeof(sums)); + +#pragma unroll kWorldSize + for (size_t ii = 0; ii < kWorldSize; ++ii) { + sums = add_bf16x8(sums, vals[ii]); + } + if constexpr (kAligned) { + streamStore128(&input[i], sums); + } else { + for (size_t ii = 0; ii < numelPerThread; ++ii) { + if (i + ii < N) { + input[i + ii] = reinterpret_cast(&sums)[ii]; + } + } + } + } +} + +template +static __launch_bounds__(1024) __global__ void twoShotAllReduceKernel( + at::BFloat16* input, + size_t N_aligned, + P2pState** p2pStates, + at::BFloat16** buffers, + size_t rank) { + const size_t numelPerThread = kBytesPerThread / sizeof(at::BFloat16); + const size_t offset = + (blockDim.x * blockIdx.x + threadIdx.x) * numelPerThread; + const size_t stride = blockDim.x * gridDim.x * numelPerThread; + const size_t N_per_rank = N_aligned / kWorldSize; + const size_t N_start = N_per_rank * rank; + + // Wait for all other ranks to enter the kernel + if (threadIdx.x < kWorldSize) { + auto targetRank = threadIdx.x; + releaseSignal(&p2pStates[targetRank]->signals0[blockIdx.x][rank]); + acquireSignal(&p2pStates[rank]->signals0[blockIdx.x][targetRank]); + } + __syncthreads(); + + // The source pointers. Distributed round-robin for the different warps + at::BFloat16* srcs[kWorldSize]; + size_t srcRanks[kWorldSize]; +#pragma unroll kWorldSize + for (int ii = 0; ii < kWorldSize; ++ii) { + int srcRank = (rank + ii) % kWorldSize; + srcs[ii] = buffers[srcRank]; + srcRanks[ii] = srcRank; + } + + for (size_t i = offset; i < N_per_rank; i += stride) { + bf16x8 vals[kWorldSize]; +#pragma unroll kWorldSize + for (size_t ii = 0; ii < kWorldSize; ++ii) { + streamLoad128(vals[ii], &srcs[ii][N_start + i]); + } + + bf16x8 sums; + memset(reinterpret_cast(&sums), 0, sizeof(sums)); + +#pragma unroll kWorldSize + for (size_t ii = 0; ii < kWorldSize; ++ii) { + sums = add_bf16x8(sums, vals[ii]); + } + streamStore128(&srcs[0][N_start + i], sums); + // Store local sums into input now so we can avoid + // a global memory access later for it. + streamStore128(&input[N_start + i], sums); + } + __syncthreads(); + + if (threadIdx.x < kWorldSize) { + auto targetRank = threadIdx.x; + releaseSignal(&p2pStates[targetRank]->signals1[blockIdx.x][rank]); + acquireSignal(&p2pStates[rank]->signals1[blockIdx.x][targetRank]); + } + __syncthreads(); + + for (size_t i = offset; i < N_per_rank; i += stride) { +#pragma unroll kWorldSize - 1 + for (size_t ii = 1; ii < kWorldSize; ++ii) { + size_t k = N_start + i + (srcRanks[ii] - rank) * N_per_rank; + bf16x8 val; + streamLoad128(val, &srcs[ii][k]); + streamStore128(&input[k], val); + } + } +} + +//////////////////////////////////////////////////////////////////////////////// +// Hybrid Cube Mesh Algos +//////////////////////////////////////////////////////////////////////////////// + +/** + * NOTE [hybrid cube mesh] + * + * In a hybrid cube mesh topology, every device has exactly 4 neighbors + * (directly connected via NVLink). For every device X, it has exactly 1 + * neighbor Y that is a neighbor of the 3 non-neighbor of X. 
We call Y the + * relay neighbor of X. This property is symmetrical: X is also guaranteed to + * be the relay neighbor of Y. + * + * With this property, we can perform a variant of the one-shot allreduce algo + * that only moves data across NVLinks: + * + * - Each device performs a one-shot allreduce among itself and its 3 non-relay neighbors. + * - Each device exchanges data with its relay neighbor. + * + * HybridCubeMesh is a data structure for describing the topology: + * + * - hcm[X][0:3] are the 3 neighbors of X. + * - hcm[X][3] is the relay neighbor of X. + * - For load balancing purposes, we also ensure that if hcm[X][k] = Y, + * hcm[Y][k] = X. + */ +std::optional getHybridCubeMesh(NvlMesh nvlMesh) { + std::array, kMaxDevices> neighbors = {}; + std::array neighborMasks = {}; + for (size_t i = 0; i < kMaxDevices; ++i) { + for (size_t j = 0; j < kMaxDevices; ++j) { + if (nvlMesh[i][j] > 0) { + neighbors[i].insert(j); + neighborMasks[i] |= (1ul << j); + } + } + } + HybridCubeMesh hcm = {}; + for (auto& row : hcm) { + row.fill(-1); + } + // A topology is an HCM if: + // - Every device has exactly 4 neighbors. + // - For every device, it has exactly 1 relay neighbor that is + // a neighbor of the 3 non-neighbors of the device. + for (size_t i = 0; i < kMaxDevices; ++i) { + if (neighbors[i].size() != 4) { + return std::nullopt; + } + // Condition 1: check the number of neighbors + std::vector relayNeighbors; + for (size_t j = 0; j < kMaxDevices; ++j) { + if ((neighborMasks[i] & neighborMasks[j]) == 0) { + relayNeighbors.push_back(j); + } + } + // Condition 2: check the number of relay neighbors + if (relayNeighbors.size() != 1) { + return std::nullopt; + } + neighbors[i].erase(relayNeighbors[0]); + hcm[i][3] = relayNeighbors[0]; + } + + for (size_t i = 0; i < kMaxDevices; ++i) { + for (size_t k = 0; k < 3; ++k) { + // We can only fill hcm[i][k] with j if hcm[j][k] is not filled + for (size_t j : neighbors[i]) { + if (hcm[j][k] == -1) { + hcm[i][k] = j; + hcm[j][k] = i; + break; + } + } + TORCH_CHECK(hcm[i][k] != -1); + neighbors[i].erase(hcm[i][k]); + } + } + return hcm; +} + +template +static __global__ void hybridCubeMeshAllReduceKernel( + at::BFloat16* input, + size_t N, + size_t N_aligned, + P2pState** p2pStates, + at::BFloat16** buffers, + int hcmInfo[4], + size_t rank) { + const size_t numelPerThread = kBytesPerThread / sizeof(at::BFloat16); + const size_t offset = + (blockDim.x * blockIdx.x + threadIdx.x) * numelPerThread; + const size_t stride = blockDim.x * gridDim.x * numelPerThread; + const int relayRank = hcmInfo[3]; + + // Wait for HCM neighbors to enter the kernel + if (threadIdx.x < 3) { + auto targetRank = hcmInfo[threadIdx.x]; + releaseSignal(&p2pStates[targetRank]->signals0[blockIdx.x][rank]); + acquireSignal(&p2pStates[rank]->signals0[blockIdx.x][targetRank]); + } + __syncthreads(); + + const at::BFloat16* srcs[4] = { + buffers[rank], + buffers[hcmInfo[0]], + buffers[hcmInfo[1]], + buffers[hcmInfo[2]], + }; + at::BFloat16* localRelay = buffers[rank] + kMaxIntraNodeSize / 2; + at::BFloat16* remoteRelay = buffers[relayRank] + kMaxIntraNodeSize / 2; + + for (size_t i = offset; i < N_aligned; i += stride) { + bf16x8 vals[4]; + +#pragma unroll 4 + for (size_t ii = 0; ii < 4; ++ii) { + streamLoad128(vals[ii], &srcs[ii][i]); + } + + bf16x8 sums; + memset(reinterpret_cast(&sums), 0, sizeof(sums)); + +#pragma unroll 4 + for (size_t ii = 0; ii < 4; ++ii) { + sums = add_bf16x8(sums, vals[ii]); + } + // Cached store for local sums + store128(&localRelay[i], sums); + } + __syncthreads(); + + if
(threadIdx.x == 0) { + releaseSignal(&p2pStates[relayRank]->signals0[blockIdx.x][rank]); + acquireSignal(&p2pStates[rank]->signals0[blockIdx.x][relayRank]); + } + __syncthreads(); + + for (size_t i = offset; i < N_aligned; i += stride) { + bf16x8 localSum, remoteSum; + // Cached load for local sums + load128(localSum, &localRelay[i]); + streamLoad128(remoteSum, &remoteRelay[i]); + localSum = add_bf16x8(localSum, remoteSum); + if constexpr (kAligned) { + streamStore128(&input[i], localSum); + } else { + for (size_t ii = 0; ii < numelPerThread; ++ii) { + if (i + ii < N) { + input[i + ii] = reinterpret_cast(&localSum)[ii]; + } + } + } + } +} + +static inline size_t divUp(uint32_t a, uint32_t b) { + return (a + b - 1) / b; +} + +static inline size_t alignUp(uint32_t a, uint32_t b) { + return divUp(a, b) * b; +} + +static void checkInput(const at::Tensor& input, size_t rank) { + TORCH_CHECK( + input.dtype() == at::kBFloat16, + "oneShotAllReduce only supports bf16 for now"); + TORCH_CHECK(input.is_non_overlapping_and_dense()); + TORCH_CHECK(input.device().is_cuda()); + TORCH_CHECK(static_cast(input.get_device()) == rank); +} + +static void getLaunchConfig( + size_t N_aligned, + size_t elemSize, + dim3& blocks, + dim3& threads) { + blocks = dim3(0, 1, 1); + threads = dim3(0, 1, 1); + + const auto numelPerThread = kBytesPerThread / elemSize; + const auto numelPerWarp = numelPerThread * kWarpSize; + TORCH_CHECK(N_aligned % numelPerThread == 0); + TORCH_CHECK(N_aligned % numelPerWarp == 0); + if (N_aligned < numelPerThread * kThreadsPerBlock) { + threads.x = N_aligned / numelPerWarp * kWarpSize; + blocks.x = 1; + } else { + auto warpsRequired = N_aligned / numelPerWarp; + auto threadsRequired = N_aligned / numelPerThread; + blocks.x = + std::min(divUp(threadsRequired, kThreadsPerBlock), kMaxAllReduceBlocks); + auto warpsPerBlock = divUp(warpsRequired, blocks.x); + threads.x = std::min(kThreadsPerBlock, warpsPerBlock * kWarpSize); + } +} + +bool isIntraNodeCommSupported() { +#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) + return false; +#else + return true; +#endif +} + +void* initP2pState() { + void* state = nullptr; + AT_CUDA_CHECK(cudaMalloc(&state, sizeof(P2pState))); + AT_CUDA_CHECK(cudaMemset(state, 0, sizeof(P2pState))); + return state; +} + +void* initTopoInfo(Topology topology, NvlMesh nvlMesh, size_t rank) { + void* topoInfo = nullptr; + if (topology != Topology::HYBRID_CUBE_MESH) { + return topoInfo; + } + auto hcm = getHybridCubeMesh(nvlMesh); + int hcmInfo[4]; + std::copy((*hcm)[rank].begin(), (*hcm)[rank].begin() + 4, hcmInfo); + AT_CUDA_CHECK(cudaMalloc(&topoInfo, sizeof(hcmInfo))); + AT_CUDA_CHECK( + cudaMemcpy(topoInfo, hcmInfo, sizeof(hcmInfo), cudaMemcpyHostToDevice)); + return topoInfo; +} + +at::Tensor oneShotAllReduce( + const at::Tensor& input, + std::array p2pStates, + std::array buffers, + void* p2pStatesDev, + void* buffersDev, + size_t rank, + size_t worldSize, + at::cuda::CUDAStream& stream) { + checkInput(input, rank); + + size_t numelPerWarp = kBytesPerThread / input.element_size() * kWarpSize; + size_t N_aligned = alignUp(input.numel(), numelPerWarp); + TORCH_CHECK(N_aligned <= kMaxIntraNodeSize / input.element_size()); + + dim3 blocks, threads; + getLaunchConfig(N_aligned, input.element_size(), blocks, threads); + + at::cuda::OptionalCUDAGuard guard(input.get_device()); + AT_CUDA_CHECK(cudaMemcpyAsync( + buffers[rank], + input.data_ptr(), + input.numel() * input.element_size(), + cudaMemcpyDeviceToDevice, + stream)); + +#define 
X(kWorldSize, kAligned) \ + if (worldSize == kWorldSize) { \ + oneShotAllReduceKernel \ + <<>>( \ + input.data_ptr(), \ + input.numel(), \ + N_aligned, \ + reinterpret_cast(p2pStatesDev), \ + reinterpret_cast(buffersDev), \ + rank); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ + } + +#define DISPATCH_ALL_WORLD_SIZES(kAligned) \ + X(2, kAligned); \ + X(3, kAligned); \ + X(4, kAligned); \ + X(5, kAligned); \ + X(6, kAligned); \ + X(7, kAligned); \ + X(8, kAligned); + + if (N_aligned == static_cast(input.numel())) { + DISPATCH_ALL_WORLD_SIZES(true); + } else { + DISPATCH_ALL_WORLD_SIZES(false); + } + +#undef DISPATCH_ALL_WORLD_SIZES +#undef X + return input; +} + +at::Tensor twoShotAllReduce( + const at::Tensor& input, + std::array p2pStates, + std::array buffers, + void* p2pStatesDev, + void* buffersDev, + size_t rank, + size_t worldSize, + at::cuda::CUDAStream& stream) { + checkInput(input, rank); + + size_t numelPerWarp = kBytesPerThread / input.element_size() * kWarpSize; + size_t N_aligned = alignUp(input.numel(), worldSize * numelPerWarp); + size_t N_per_rank = N_aligned / worldSize; + TORCH_CHECK(N_aligned <= kMaxIntraNodeSize / input.element_size()); + + dim3 blocks, threads; + getLaunchConfig(N_per_rank, input.element_size(), blocks, threads); + + auto output = N_aligned == static_cast(input.numel()) + ? input + : input.new_empty(N_aligned); + + at::cuda::OptionalCUDAGuard guard(input.get_device()); + AT_CUDA_CHECK(cudaMemcpyAsync( + buffers[rank], + input.data_ptr(), + input.numel() * input.element_size(), + cudaMemcpyDeviceToDevice, + stream)); + +#define X(kWorldSize) \ + if (worldSize == kWorldSize) { \ + twoShotAllReduceKernel<<>>( \ + output.data_ptr(), \ + N_aligned, \ + reinterpret_cast(p2pStatesDev), \ + reinterpret_cast(buffersDev), \ + rank); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ + } + X(2); + X(3); + X(4); + X(5); + X(6); + X(7); + X(8); +#undef X + + if (output.data_ptr() != input.data_ptr()) { + AT_CUDA_CHECK(cudaMemcpyAsync( + input.data_ptr(), + output.data_ptr(), + input.numel() * input.element_size(), + cudaMemcpyDeviceToDevice, + stream)); + } + return input; +} + +at::Tensor hybridCubeMeshAllReduce( + const at::Tensor& input, + std::array p2pStates, + std::array buffers, + void* p2pStatesDev, + void* buffersDev, + int hcmInfo[4], + size_t rank, + size_t worldSize, + at::cuda::CUDAStream& stream) { + checkInput(input, rank); + + size_t numelPerWarp = kBytesPerThread / input.element_size() * kWarpSize; + size_t N_aligned = alignUp(input.numel(), numelPerWarp); + TORCH_CHECK(N_aligned <= kMaxIntraNodeSize / input.element_size()); + + dim3 blocks, threads; + getLaunchConfig(N_aligned, input.element_size(), blocks, threads); + + at::cuda::OptionalCUDAGuard guard(input.get_device()); + AT_CUDA_CHECK(cudaMemcpyAsync( + buffers[rank], + input.data_ptr(), + input.numel() * input.element_size(), + cudaMemcpyDeviceToDevice, + stream)); + +#define X(kAligned) \ + hybridCubeMeshAllReduceKernel<<>>( \ + input.data_ptr(), \ + input.numel(), \ + N_aligned, \ + reinterpret_cast(p2pStatesDev), \ + reinterpret_cast(buffersDev), \ + hcmInfo, \ + rank); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + if (N_aligned == static_cast(input.numel())) { + X(true); + } else { + X(false); + } +#undef X + return input; +} + +AllReduceAlgo selectAllReduceAlgo( + const at::Tensor& input, + Topology topology, + size_t worldSize) { + // Only support bf16 for now + if (input.dtype() != at::kBFloat16 || + input.numel() * input.element_size() > kMaxIntraNodeSize) { + return AllReduceAlgo::NONE; + } + const auto 
numel = input.numel(); + const auto numelPerWarp = kBytesPerThread / input.element_size() * kWarpSize; + if (topology == Topology::HYBRID_CUBE_MESH) { + TORCH_CHECK( + worldSize == 8, "hyperCubeAllReduce only supports exactly 8 GPUs"); + if (alignUp(numel, numelPerWarp) <= kHcmThreshBytes) { + return AllReduceAlgo::HCM; + } + } + if (topology == Topology::FULLY_CONNECTED) { + if (alignUp(numel, numelPerWarp) <= kOneShotThreshBytes) { + return AllReduceAlgo::ONE_SHOT; + } + if (alignUp(numel, numelPerWarp * worldSize) <= kTwoShotThreshBytes) { + return AllReduceAlgo::TWO_SHOT; + } + } + return AllReduceAlgo::NONE; +} + +at::Tensor allReduce( + const at::Tensor& input, + std::array p2pStates, + std::array buffers, + void* p2pStatesDev, + void* buffersDev, + void* topoInfo, + size_t rank, + size_t worldSize, + AllReduceAlgo algo, + at::cuda::CUDAStream& stream) { + switch (algo) { + case AllReduceAlgo::ONE_SHOT: + return oneShotAllReduce( + input, + p2pStates, + buffers, + p2pStatesDev, + buffersDev, + rank, + worldSize, + stream); + case AllReduceAlgo::TWO_SHOT: + return twoShotAllReduce( + input, + p2pStates, + buffers, + p2pStatesDev, + buffersDev, + rank, + worldSize, + stream); + case AllReduceAlgo::HCM: + return hybridCubeMeshAllReduce( + input, + p2pStates, + buffers, + p2pStatesDev, + buffersDev, + (int*)topoInfo, + rank, + worldSize, + stream); + default: + C10_THROW_ERROR(ValueError, "IntraNodeComm: invalid algo"); + } +} + +} // namespace intra_node_comm +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/intra_node_comm.hpp b/torch/csrc/distributed/c10d/intra_node_comm.hpp new file mode 100644 index 00000000000..b4949906789 --- /dev/null +++ b/torch/csrc/distributed/c10d/intra_node_comm.hpp @@ -0,0 +1,106 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace c10d { +namespace intra_node_comm { + +constexpr size_t kMaxDevices = 8; +constexpr size_t kMaxIntraNodeSize = 10 * 1024 * 1024; + +using NvlMesh = std::array, kMaxDevices>; +using HybridCubeMesh = std::array, kMaxDevices>; + +enum class Topology { UNKNOWN = 0, FULLY_CONNECTED = 1, HYBRID_CUBE_MESH = 2 }; + +enum class AllReduceAlgo { NONE = 0, ONE_SHOT = 1, TWO_SHOT = 2, HCM = 3 }; + +class TORCH_API IntraNodeComm : public c10::intrusive_ptr_target { + public: + IntraNodeComm( + Topology topology, + std::array p2pStates, + std::array buffers, + void* p2pStatesDev, + void* buffersDev, + void* topoInfo, + size_t rank, + size_t worldSize); + + ~IntraNodeComm(); + + /** + * Rendezvous via a c10d::Store. + * This function may return nullptr if intra-node comm is not applicable. + * It guarantees that all participants either succeed or abort. + */ + static c10::intrusive_ptr rendezvous( + c10::intrusive_ptr store, + const std::string& prefix, + size_t rank, + size_t worldSize); + + /** + * Selects an AllReduceAlgo that we think will outperform NCCL. + * Returns AllReduceAlgo::NONE if we don't think we can outperform NCCL. + */ + AllReduceAlgo selectAllReduceAlgo(const at::Tensor& input); + + at::Tensor allReduce(const at::Tensor& input, AllReduceAlgo algo); + + private: + Topology topology_; + std::array p2pStates_; + std::array buffers_; + void* p2pStatesDev_; + void* buffersDev_; + void* topoInfo_; + size_t rank_; + size_t worldSize_; +}; + +/** + * NOTE [IntraNodeComm Stream Semantics] + * + * ProcessGroupNCCL launches kernels differently from the conventional PyTorch + * CUDA semantics: it always launches collective kernels onto a dedicated + * communication stream.
Therefore, it needs to: + * + * - Synchronize the calling stream and the comm stream. + * - Ensure the memory safety of the operands (via record_stream or stashing). + * - Synchronize the waiting stream with the comm stream. + * + * Unconditionally performing these tasks makes sense when we expect most of the + * communication to benefit from compute/comm overlap. However, IntraNodeComm + * primarily aims to optimize small, latency-sensitive, blocking communication, + * for which the overhead incurred by the above steps can be quite pronounced. + * + * Thus, IntraNodeComm follows the conventional PyTorch CUDA semantics and + * launches kernels onto the stream specified by the user. Although the user + * can perform the necessary synchronization via wait_stream, to provide a UX + * consistent with that of ProcessGroupNCCL, the necessary stream + * synchronization can also be performed via IntraNodeCommWork::wait(). + */ +class IntraNodeCommWork : public c10d::Work { + public: + IntraNodeCommWork() : c10d::Work() { + event_.record(); + } + + bool wait(std::chrono::milliseconds timeout = kNoTimeout) override { + event_.block(at::cuda::getCurrentCUDAStream()); + return true; + } + + private: + at::cuda::CUDAEvent event_; +}; + +TORCH_API int64_t getIntraNodeCommUsageCounter(); + +} // namespace intra_node_comm +} // namespace c10d
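A brief illustration of NOTE [IntraNodeComm Stream Semantics] above: because the intra-node kernels run on the caller's current stream and the returned work object only records a CUDA event, the usual synchronization idioms on the Python side keep working. This is a sketch under the assumption that the fast path was taken and an NCCL process group is already initialized (as in the earlier sketch):

    import torch
    import torch.distributed as dist

    t = torch.ones(4096, dtype=torch.bfloat16, device="cuda")
    work = dist.all_reduce(t, op=dist.ReduceOp.SUM, async_op=True)
    # The work object's wait() blocks the *current* stream on the recorded
    # event, mirroring ProcessGroupNCCL's wait() from the user's point of view.
    work.wait()
    # A host-side sync is still needed before consuming the result on the CPU.
    torch.cuda.current_stream().synchronize()
    print(t[:4])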
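The rendezvous in intra_node_comm.cpp exchanges hostnames, bus IDs and CUDA IPC handles through the c10d store via the storeAllGather helper. The same set/wait/get pattern, restated against the public Python Store API for readers unfamiliar with it (illustrative sketch; the helper name and string payloads are assumptions, the actual implementation exchanges raw structs):

    import torch.distributed as dist

    def store_all_gather(store: dist.Store, prefix: str, rank: int,
                         world_size: int, payload: str) -> list:
        """Publish this rank's payload under a per-rank key, then collect all."""
        keys = [f"{prefix}-{r}" for r in range(world_size)]
        store.set(keys[rank], payload)               # publish our value
        out = []
        for r in range(world_size):
            store.wait([keys[r]])                    # block until rank r has published
            out.append(store.get(keys[r]).decode())  # get() returns bytes
        return out

    # e.g. with a TCPStore shared by all ranks:
    #   hostnames = store_all_gather(store, "hostname", rank, world_size, socket.gethostname())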