mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[PFC] Native UCC process group for Pytorch (#79918)
Summary: This diff integrates the UCC process group as a native component of the PyTorch Distributed core. It is based on the existing torch-ucc (https://github.com/facebookresearch/torch_ucc), which serves as the wrapper for the UCC collective communication library. The environment and CMake variables are named to mirror those of the existing process groups such as NCCL and Gloo. Specifically, - USE_UCC: enables the UCC PG. This defaults to OFF, so existing builds that do not have the external UCX/UCC libraries are not broken. - USE_SYSTEM_UCC: uses the external UCX and UCC shared libraries located via UCX_HOME and UCC_HOME. Currently, this diff only supports USE_SYSTEM_UCC=ON, i.e., it requires users to specify external libraries for UCX and UCC. In subsequent diffs, we will add the UCX and UCC repos as third-party dependencies in pytorch/third-party. Test Plan: Passed Torch-UCC tests that invoke the UCC process group. For example: $ sh test/start_test.sh test/torch_allreduce_test.py --backend gloo --use-cuda ... Test allreduce: succeeded Differential Revision: D36973688 Pull Request resolved: https://github.com/pytorch/pytorch/pull/79918 Approved by: https://github.com/kwen2501, https://github.com/kingchc
This commit is contained in:
parent
268d910170
commit
54bdaf76d6
|
|
@ -308,6 +308,14 @@ option(USE_DISTRIBUTED "Use distributed" ON)
|
||||||
cmake_dependent_option(
|
cmake_dependent_option(
|
||||||
USE_MPI "Use MPI for Caffe2. Only available if USE_DISTRIBUTED is on." ON
|
USE_MPI "Use MPI for Caffe2. Only available if USE_DISTRIBUTED is on." ON
|
||||||
"USE_DISTRIBUTED" OFF)
|
"USE_DISTRIBUTED" OFF)
|
||||||
|
# UCC process-group switches.
#   USE_UCC        - offered only when USE_DISTRIBUTED is on; defaults to OFF.
#   USE_SYSTEM_UCC - offered only when USE_UCC is on; defaults to OFF.
#   USE_C10D_UCC   - offered only when both USE_DISTRIBUTED and USE_UCC are on;
#                    defaults to ON in that case.
# Each option is forced to OFF whenever its dependency condition is not met.
cmake_dependent_option(
    USE_UCC "Use UCC. Only available if USE_DISTRIBUTED is on." OFF
    "USE_DISTRIBUTED" OFF)
cmake_dependent_option(
    USE_SYSTEM_UCC "Use system-wide UCC" OFF
    "USE_UCC" OFF)
cmake_dependent_option(
    USE_C10D_UCC "USE C10D UCC" ON
    "USE_DISTRIBUTED;USE_UCC" OFF)
|
||||||
cmake_dependent_option(
|
cmake_dependent_option(
|
||||||
USE_GLOO "Use Gloo. Only available if USE_DISTRIBUTED is on." ON
|
USE_GLOO "Use Gloo. Only available if USE_DISTRIBUTED is on." ON
|
||||||
"USE_DISTRIBUTED" OFF)
|
"USE_DISTRIBUTED" OFF)
|
||||||
|
|
|
||||||
|
|
@ -744,6 +744,9 @@ libtorch_cuda_distributed_base_sources = [
|
||||||
libtorch_cuda_distributed_extra_sources = [
|
libtorch_cuda_distributed_extra_sources = [
|
||||||
"torch/csrc/distributed/c10d/NCCLUtils.cpp",
|
"torch/csrc/distributed/c10d/NCCLUtils.cpp",
|
||||||
"torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp",
|
"torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp",
|
||||||
|
"torch/csrc/distributed/c10d/ProcessGroupUCC.cpp",
|
||||||
|
"torch/csrc/distributed/c10d/UCCTracing.cpp",
|
||||||
|
"torch/csrc/distributed/c10d/UCCUtils.cpp",
|
||||||
"torch/csrc/distributed/rpc/tensorpipe_cuda.cpp",
|
"torch/csrc/distributed/rpc/tensorpipe_cuda.cpp",
|
||||||
"torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
|
"torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -911,6 +911,11 @@ if(HAVE_SOVERSION)
|
||||||
VERSION ${TORCH_VERSION} SOVERSION ${TORCH_SOVERSION})
|
VERSION ${TORCH_VERSION} SOVERSION ${TORCH_SOVERSION})
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
if(USE_UCC)
|
||||||
|
target_link_libraries(torch_cpu PRIVATE __caffe2_ucc)
|
||||||
|
target_compile_definitions(torch_cpu PRIVATE USE_UCC)
|
||||||
|
endif()
|
||||||
|
|
||||||
if(USE_ROCM)
|
if(USE_ROCM)
|
||||||
filter_list(__caffe2_hip_srcs_cpp Caffe2_HIP_SRCS "\\.(cu|hip)$")
|
filter_list(__caffe2_hip_srcs_cpp Caffe2_HIP_SRCS "\\.(cu|hip)$")
|
||||||
set_source_files_properties(${__caffe2_hip_srcs_cpp} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
|
set_source_files_properties(${__caffe2_hip_srcs_cpp} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
|
||||||
|
|
@ -972,6 +977,13 @@ elseif(USE_CUDA)
|
||||||
target_link_libraries(torch_cuda PRIVATE __caffe2_nccl)
|
target_link_libraries(torch_cuda PRIVATE __caffe2_nccl)
|
||||||
target_compile_definitions(torch_cuda PRIVATE USE_NCCL)
|
target_compile_definitions(torch_cuda PRIVATE USE_NCCL)
|
||||||
endif()
|
endif()
|
||||||
|
if(USE_UCC AND BUILD_SPLIT_CUDA)
|
||||||
|
target_link_libraries(torch_cuda_cpp PRIVATE __caffe2_ucc)
|
||||||
|
target_compile_definitions(torch_cuda_cpp PRIVATE USE_UCC)
|
||||||
|
elseif(USE_UCC)
|
||||||
|
target_link_libraries(torch_cuda PRIVATE __caffe2_ucc)
|
||||||
|
target_compile_definitions(torch_cuda PRIVATE USE_UCC)
|
||||||
|
endif()
|
||||||
if(BUILD_LAZY_CUDA_LINALG)
|
if(BUILD_LAZY_CUDA_LINALG)
|
||||||
add_library(torch_cuda_linalg ${ATen_CUDA_LINALG_SRCS})
|
add_library(torch_cuda_linalg ${ATen_CUDA_LINALG_SRCS})
|
||||||
target_compile_definitions(torch_cuda_linalg PRIVATE USE_CUDA BUILD_LAZY_CUDA_LINALG)
|
target_compile_definitions(torch_cuda_linalg PRIVATE USE_CUDA BUILD_LAZY_CUDA_LINALG)
|
||||||
|
|
@ -1347,6 +1359,16 @@ if(USE_DISTRIBUTED)
|
||||||
if(USE_GLOO AND USE_C10D_GLOO)
|
if(USE_GLOO AND USE_C10D_GLOO)
|
||||||
target_compile_definitions(torch_cpu PUBLIC USE_C10D_GLOO)
|
target_compile_definitions(torch_cpu PUBLIC USE_C10D_GLOO)
|
||||||
endif()
|
endif()
|
||||||
|
# Publish the USE_C10D_UCC compile definition on torch_cpu and, for CUDA
# builds, on whichever CUDA library target this configuration uses
# (torch_cuda_cpp for split builds, torch_cuda otherwise).
if(USE_UCC AND USE_C10D_UCC)
  target_compile_definitions(torch_cpu PUBLIC USE_C10D_UCC)
  if(USE_CUDA AND BUILD_SPLIT_CUDA)
    target_compile_definitions(torch_cuda_cpp PUBLIC USE_C10D_UCC)
  elseif(USE_CUDA)
    target_compile_definitions(torch_cuda PUBLIC USE_C10D_UCC)
  endif()
endif()
|
||||||
if(USE_NCCL AND USE_C10D_NCCL)
|
if(USE_NCCL AND USE_C10D_NCCL)
|
||||||
if(USE_ROCM)
|
if(USE_ROCM)
|
||||||
target_compile_definitions(torch_hip PUBLIC USE_C10D_NCCL)
|
target_compile_definitions(torch_hip PUBLIC USE_C10D_NCCL)
|
||||||
|
|
|
||||||
|
|
@ -1362,6 +1362,16 @@ if(USE_NCCL)
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
# ---[ UCC
|
||||||
|
# Pull in the UCC target definition on Linux; on any other system warn and
# switch USE_UCC off via caffe2_update_option.
if(USE_UCC)
  if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
    include(${CMAKE_CURRENT_LIST_DIR}/External/ucc.cmake)
  else()
    message(WARNING "UCC is currently only supported under Linux.")
    caffe2_update_option(USE_UCC OFF)
  endif()
endif()
|
||||||
|
|
||||||
# ---[ CUB
|
# ---[ CUB
|
||||||
if(USE_CUDA)
|
if(USE_CUDA)
|
||||||
find_package(CUB)
|
find_package(CUB)
|
||||||
|
|
|
||||||
20
cmake/External/ucc.cmake
vendored
Normal file
20
cmake/External/ucc.cmake
vendored
Normal file
|
|
@ -0,0 +1,20 @@
|
||||||
|
# Defines the `__caffe2_ucc` INTERFACE target, which carries the UCX/UCC
# include directories and shared libraries for its consumers.
#
# Inputs:
#   USE_SYSTEM_UCC - must be ON; a bundled build is not supported yet.
#   UCX_HOME       - UCX install prefix (cache entry, seeded from the
#                    UCX_HOME environment variable at configure time).
#   UCC_HOME       - UCC install prefix (cache entry, seeded from the
#                    UCC_HOME environment variable at configure time).
#
# The include guard makes repeated include() calls a no-op.
if(NOT __UCC_INCLUDED)
  set(__UCC_INCLUDED TRUE)

  if(USE_SYSTEM_UCC)
    # Quote the expansions so an unset environment variable yields a
    # well-defined empty value instead of dropping the argument.
    set(UCX_HOME "$ENV{UCX_HOME}" CACHE PATH "UCX install directory")
    set(UCC_HOME "$ENV{UCC_HOME}" CACHE PATH "UCC install directory")

    # Fail early with an actionable message rather than silently producing
    # paths such as "/include/" and "/lib/libucp.so".
    foreach(__ucc_home_var UCX_HOME UCC_HOME)
      if("${${__ucc_home_var}}" STREQUAL "")
        # Drop the empty cache entry so the next configure run can pick up a
        # newly exported environment variable.
        unset(${__ucc_home_var} CACHE)
        message(FATAL_ERROR
          "USE_SYSTEM_UCC=ON requires ${__ucc_home_var} to point to an install "
          "directory. Set the ${__ucc_home_var} environment variable or pass "
          "-D${__ucc_home_var}=<path>.")
      endif()
    endforeach()

    add_library(__caffe2_ucc INTERFACE)

    target_include_directories(__caffe2_ucc INTERFACE
      "${UCX_HOME}/include/"
      "${UCC_HOME}/include/")

    target_link_libraries(__caffe2_ucc INTERFACE
      "${UCX_HOME}/lib/libucp.so"
      "${UCX_HOME}/lib/libucs.so"
      "${UCC_HOME}/lib/libucc.so")
  else()
    message(FATAL_ERROR "USE_SYSTEM_UCC=OFF is not supported yet when using UCC")
  endif()

endif()
|
||||||
|
|
@ -146,6 +146,10 @@ function(caffe2_print_configuration_summary)
|
||||||
message(STATUS " USE_MKLDNN_ACL : ${USE_MKLDNN_ACL}")
|
message(STATUS " USE_MKLDNN_ACL : ${USE_MKLDNN_ACL}")
|
||||||
message(STATUS " USE_MKLDNN_CBLAS : ${USE_MKLDNN_CBLAS}")
|
message(STATUS " USE_MKLDNN_CBLAS : ${USE_MKLDNN_CBLAS}")
|
||||||
endif()
|
endif()
|
||||||
|
message(STATUS " USE_UCC : ${USE_UCC}")
|
||||||
|
if(${USE_UCC})
|
||||||
|
message(STATUS " USE_SYSTEM_UCC : ${USE_SYSTEM_UCC}")
|
||||||
|
endif()
|
||||||
message(STATUS " USE_NCCL : ${USE_NCCL}")
|
message(STATUS " USE_NCCL : ${USE_NCCL}")
|
||||||
if(${USE_NCCL})
|
if(${USE_NCCL})
|
||||||
message(STATUS " USE_SYSTEM_NCCL : ${USE_SYSTEM_NCCL}")
|
message(STATUS " USE_SYSTEM_NCCL : ${USE_SYSTEM_NCCL}")
|
||||||
|
|
|
||||||
|
|
@ -251,6 +251,9 @@ if(USE_DISTRIBUTED)
|
||||||
if(USE_NCCL)
|
if(USE_NCCL)
|
||||||
list(APPEND TORCH_PYTHON_LINK_LIBRARIES __caffe2_nccl)
|
list(APPEND TORCH_PYTHON_LINK_LIBRARIES __caffe2_nccl)
|
||||||
endif()
|
endif()
|
||||||
|
if(USE_UCC)
|
||||||
|
list(APPEND TORCH_PYTHON_LINK_LIBRARIES __caffe2_ucc)
|
||||||
|
endif()
|
||||||
# Same for MPI.
|
# Same for MPI.
|
||||||
if(USE_MPI)
|
if(USE_MPI)
|
||||||
list(APPEND TORCH_PYTHON_LINK_LIBRARIES ${MPI_CXX_LIBRARIES})
|
list(APPEND TORCH_PYTHON_LINK_LIBRARIES ${MPI_CXX_LIBRARIES})
|
||||||
|
|
@ -284,6 +287,9 @@ if(USE_DEPLOY)
|
||||||
if(USE_GLOO AND USE_C10D_GLOO)
|
if(USE_GLOO AND USE_C10D_GLOO)
|
||||||
list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D_GLOO)
|
list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D_GLOO)
|
||||||
endif()
|
endif()
|
||||||
|
if(USE_UCC AND USE_C10D_UCC)
|
||||||
|
list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D_UCC)
|
||||||
|
endif()
|
||||||
if(USE_NCCL AND USE_C10D_NCCL)
|
if(USE_NCCL AND USE_C10D_NCCL)
|
||||||
list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D_NCCL)
|
list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D_NCCL)
|
||||||
# Put nccl headers on the include path. We are specifically only setting
|
# Put nccl headers on the include path. We are specifically only setting
|
||||||
|
|
|
||||||
|
|
@ -371,6 +371,15 @@ class ProcessGroupNCCL(ProcessGroup):
|
||||||
def _group_end() -> None: ...
|
def _group_end() -> None: ...
|
||||||
...
|
...
|
||||||
|
|
||||||
|
class ProcessGroupUCC(ProcessGroup):
    # Type stub for the UCC process group binding; the constructor takes the
    # store, this process's rank, the group size, and a timeout.
    def __init__(self, store: Store, rank: int, size: int, timeout: timedelta): ...
|
||||||
|
|
||||||
class ProcessGroupMPI(ProcessGroup):
|
class ProcessGroupMPI(ProcessGroup):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
1870
torch/csrc/distributed/c10d/ProcessGroupUCC.cpp
Normal file
1870
torch/csrc/distributed/c10d/ProcessGroupUCC.cpp
Normal file
File diff suppressed because it is too large
Load Diff
407
torch/csrc/distributed/c10d/ProcessGroupUCC.hpp
Normal file
407
torch/csrc/distributed/c10d/ProcessGroupUCC.hpp
Normal file
|
|
@ -0,0 +1,407 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#ifdef USE_C10D_UCC
|
||||||
|
|
||||||
|
#include <c10d/UCCUtils.hpp>
|
||||||
|
|
||||||
|
#include <exception>
|
||||||
|
#include <memory>
|
||||||
|
#include <mutex>
|
||||||
|
#include <queue>
|
||||||
|
#include <thread>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include <c10d/ProcessGroup.hpp>
|
||||||
|
#include <c10d/Store.hpp>
|
||||||
|
#include <c10d/Types.hpp>
|
||||||
|
#include <c10d/Utils.hpp>
|
||||||
|
#ifdef USE_CUDA
|
||||||
|
#include <ATen/cuda/CUDAEvent.h>
|
||||||
|
#include <c10/cuda/CUDAStream.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace c10d {
|
||||||
|
|
||||||
|
#define TORCH_UCC_DEVICE_NOT_SET -2
|
||||||
|
|
||||||
|
#define TORCH_UCX_MAKE_P2P_TAG(_tag, _rank, _comm) \
|
||||||
|
((((uint64_t)(_tag)) << TORCH_UCX_TAG_BITS_OFFSET) | \
|
||||||
|
(((uint64_t)(_rank)) << TORCH_UCX_RANK_BITS_OFFSET) | \
|
||||||
|
(((uint64_t)(_comm)) << TORCH_UCX_COMM_BITS_OFFSET))
|
||||||
|
|
||||||
|
// Builds a 64-bit out-of-band tag by OR-ing three shifted fields:
//   _tag  at TORCH_UCX_OOB_BITS_OFFSET,
//   _rank at TORCH_UCX_RANK_BITS_OFFSET,
//   _comm at TORCH_UCX_COMM_BITS_OFFSET.
// The communicator field is taken from _comm (it was previously filled from
// _rank, which left the _comm argument unused).
#define TORCH_UCX_MAKE_OOB_TAG(_tag, _rank, _comm)       \
  ((((uint64_t)(_tag)) << TORCH_UCX_OOB_BITS_OFFSET) |   \
   (((uint64_t)(_rank)) << TORCH_UCX_RANK_BITS_OFFSET) | \
   (((uint64_t)(_comm)) << TORCH_UCX_COMM_BITS_OFFSET))
|
||||||
|
|
||||||
|
#define TORCH_UCX_MAKE_SEND_TAG(_ucp_tag, _tag, _rank, _comm) \
|
||||||
|
do { \
|
||||||
|
(_ucp_tag) = TORCH_UCX_MAKE_P2P_TAG((_tag), (_rank), (_comm)); \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
#define TORCH_UCX_ANY_SOURCE (TORCH_UCX_MAX_RANK - 1)
|
||||||
|
#define TORCH_UCX_ANY_SOURCE_MASK (~TORCH_UCX_RANK_MASK)
|
||||||
|
#define TORCH_UCX_SPECIFIC_SOURCE_MASK ((uint64_t)-1)
|
||||||
|
|
||||||
|
#define TORCH_UCX_MAKE_RECV_TAG(_ucp_tag, _ucp_tag_mask, _tag, _rank, _comm) \
|
||||||
|
do { \
|
||||||
|
(_ucp_tag) = TORCH_UCX_MAKE_P2P_TAG((_tag), (_rank), (_comm)); \
|
||||||
|
if ((_rank) == TORCH_UCX_ANY_SOURCE) { \
|
||||||
|
(_ucp_tag_mask) = TORCH_UCX_ANY_SOURCE_MASK; \
|
||||||
|
} else { \
|
||||||
|
(_ucp_tag_mask) = TORCH_UCX_SPECIFIC_SOURCE_MASK; \
|
||||||
|
} \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
#define TORCH_UCX_MAKE_OOB_SEND_TAG(_ucp_tag, _tag, _rank, _comm) \
|
||||||
|
do { \
|
||||||
|
(_ucp_tag) = TORCH_UCX_MAKE_OOB_TAG((_tag), (_rank), (_comm)); \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
#define TORCH_UCX_MAKE_OOB_RECV_TAG( \
|
||||||
|
_ucp_tag, _ucp_tag_mask, _tag, _rank, _comm) \
|
||||||
|
do { \
|
||||||
|
(_ucp_tag) = TORCH_UCX_MAKE_OOB_TAG((_tag), (_rank), (_comm)); \
|
||||||
|
(_ucp_tag_mask) = (uint64_t)-1; \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
// SAVE_TENSORS(_TENSORS, _DATA)
//
// CUDA builds: when the first tensor lives on a CUDA device, every tensor's
// storage pointer is passed to CUDACachingAllocator::recordStream together
// with `*stream` (a `stream` variable must be in scope at the call site);
// otherwise, including for an empty list, _TENSORS is copied into _DATA.
// Non-CUDA builds: _TENSORS is always copied into _DATA.
//
// Both variants expand to a single `do { ... } while (0)` statement, so the
// macro must be followed by `;` and is safe inside unbraced if/else.
#ifdef USE_CUDA
#define SAVE_TENSORS(_TENSORS, _DATA)                              \
  do {                                                             \
    if (!(_TENSORS).empty() && (_TENSORS)[0].device().is_cuda()) { \
      for (const auto i : c10::irange((_TENSORS).size())) {        \
        c10::cuda::CUDACachingAllocator::recordStream(             \
            (_TENSORS)[i].storage().data_ptr(), (*stream));        \
      }                                                            \
    } else {                                                       \
      (_DATA) = (_TENSORS);                                        \
    }                                                              \
  } while (0)

#else
#define SAVE_TENSORS(_TENSORS, _DATA) \
  do {                                \
    (_DATA) = (_TENSORS);             \
  } while (0)
#endif
|
||||||
|
|
||||||
|
constexpr const char* UCC_BACKEND_NAME = "ucc";
|
||||||
|
|
||||||
|
enum torch_ucx_tag_type_t { TORCH_UCX_P2P_TAG, TORCH_UCX_OOB_TAG };
|
||||||
|
|
||||||
|
// Holds a queue of owned CUDA events (CUDA builds only) together with a
// mutex. NOTE(review): the mutex presumably guards event_pool; confirm
// against the users of this struct.
struct event_pool_t {
#ifdef USE_CUDA
  std::queue<std::unique_ptr<at::cuda::CUDAEvent>> event_pool;
#endif
  std::mutex event_pool_mutex;
};
|
||||||
|
|
||||||
|
class Comm;
|
||||||
|
|
||||||
|
// UCC does not support multiple CUDA devices per process.
|
||||||
|
class TORCH_API ProcessGroupUCC : public ProcessGroup {
 private:
  void set_timeout(ucc_coll_args_t& args);

 public:
  // Tensor lists attached to an in-flight operation. Subclasses add the
  // per-rank length/offset arrays needed by specific collectives.
  class WorkData {
   public:
    std::vector<at::Tensor> src;
    std::vector<at::Tensor> dst;
    std::vector<at::Tensor> flat;
    WorkData() {}
    virtual ~WorkData() = default;
  };

  // Work data with send/recv length and offset arrays, each sized `size`.
  class AlltoallWorkData : public WorkData {
   public:
    AlltoallWorkData(int size)
        : send_lengths(size),
          send_offsets(size),
          recv_lengths(size),
          recv_offsets(size) {}
    std::vector<uint64_t> send_lengths;
    std::vector<uint64_t> send_offsets;
    std::vector<uint64_t> recv_lengths;
    std::vector<uint64_t> recv_offsets;
  };

  // Work data with receive length and offset arrays, each sized `size`.
  class AllgathervWorkData : public WorkData {
   public:
    AllgathervWorkData(int size) : recv_lengths(size), recv_offsets(size) {}
    std::vector<uint64_t> recv_lengths;
    std::vector<uint64_t> recv_offsets;
  };

  // Work data with send length and offset arrays, each sized `size`.
  class ScattervWorkData : public WorkData {
   public:
    ScattervWorkData(int size) : send_lengths(size), send_offsets(size) {}
    std::vector<uint64_t> send_lengths;
    std::vector<uint64_t> send_offsets;
  };

  // Bookkeeping for one posted request: its status (starts as
  // UCC_INPROGRESS), the communicator and request handle it was created
  // with, plus the attached work data, future and exception pointer.
  class ProgressEntry {
    friend class ProcessGroupUCC;
    friend class Comm;

   public:
    ProgressEntry(CommBase* comm, ucc_coll_req_h request)
        : status_(UCC_INPROGRESS), comm_(comm), request_(request) {}
    // Finalizes UCC status or exception of collective request.
    void finalize(std::exception_ptr eptr = nullptr);
    ucc_status_t status_;
    CommBase* comm_;
    ucc_coll_req_h request_;
    std::unique_ptr<WorkData> data;
    c10::intrusive_ptr<c10::ivalue::Future> future_;
    std::exception_ptr eptr_;
  };

  // ProcessGroup::Work implementation handed back by the operations below.
  // Both constructors pass -1 as the first argument of ProcessGroup::Work.
  class WorkUCC : public ProcessGroup::Work {
    friend class ProcessGroupUCC;
    friend class Comm;

   public:
    WorkUCC(OpType opType, const char* prof_title)
        : ProcessGroup::Work(-1, opType, prof_title) {}
    WorkUCC(
        OpType opType,
        const char* prof_title,
        const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger)
        : ProcessGroup::Work(-1, opType, prof_title), logger_(logger) {}
    ~WorkUCC();
    void setException();
    void setAndThrowException();
    bool isCompleted() override;
    bool isSuccess() const override;
    bool wait(std::chrono::milliseconds timeout = kUnsetTimeout) override;
    c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;
    std::vector<at::Tensor> result() override;
#ifdef USE_CUDA
    std::unique_ptr<at::cuda::CUDAEvent> fence = nullptr;
    event_pool_t* ep = nullptr;
#endif

   protected:
    std::shared_ptr<ProgressEntry> entry_;
    c10::intrusive_ptr<ProcessGroupUCCLogger> logger_;

   private:
    // The future returned by getFuture.
    c10::intrusive_ptr<at::ivalue::Future> future_;
    // Store a reference to collective's outputs, used by result
    std::shared_ptr<std::vector<at::Tensor>> outputs_;
  };

  explicit ProcessGroupUCC(
      const c10::intrusive_ptr<Store>& store,
      int rank = -1,
      int size = -1,
      std::chrono::duration<float> timeout = kProcessGroupDefaultTimeout);

  void initComm(c10::Device dev);

  ~ProcessGroupUCC() override;

  // Returns UCC_BACKEND_NAME as a std::string.
  const std::string getBackendName() const override {
    return std::string(UCC_BACKEND_NAME);
  }

#ifdef USE_CUDA
  std::unique_ptr<at::cuda::CUDAEvent> getPooledEvent();
#endif

  // Performs a health check by initializing dummy UCC & UCX communicators and
  // then destroying them. This will help indicate and signal any
  // UCC/UCX-related issues prior to the first collective. The actual
  // initialization and subsequent destruction is ran on a separate thread and
  // the main thread is signalled about timeouts/errors to report to the
  // application.
  void runHealthCheck();

  template <typename PreProcess, typename PostProcess>
  c10::intrusive_ptr<ProcessGroup::Work> collective_post(
      OpType opType,
      PreProcess preproc,
      PostProcess postproc,
      ucc_coll_args_t& coll,
      std::unique_ptr<ProcessGroupUCC::WorkData> data,
      c10::Device dev,
      std::vector<at::Tensor>& inputTensors,
      std::vector<at::Tensor>& outputTensors,
      const char* prof_title);

  // ---- ProcessGroup operation overrides ----

  c10::intrusive_ptr<ProcessGroup::Work> broadcast(
      std::vector<at::Tensor>& data,
      const BroadcastOptions& opts = BroadcastOptions()) override;

  c10::intrusive_ptr<ProcessGroup::Work> allreduce(
      std::vector<at::Tensor>& tensors,
      const AllreduceOptions& opts = AllreduceOptions()) override;

  c10::intrusive_ptr<ProcessGroup::Work> allreduce_coalesced(
      std::vector<at::Tensor>& tensors,
      const AllreduceCoalescedOptions& opts =
          AllreduceCoalescedOptions()) override;

  c10::intrusive_ptr<ProcessGroup::Work> reduce(
      std::vector<at::Tensor>& tensors,
      const ReduceOptions& opts = ReduceOptions()) override;

  c10::intrusive_ptr<ProcessGroup::Work> allgather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<ProcessGroup::Work> _allgather_base(
      at::Tensor& outputBuffer,
      at::Tensor& inputBuffer,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<ProcessGroup::Work> barrier(
      const BarrierOptions& opts = BarrierOptions()) override;

  c10::intrusive_ptr<ProcessGroup::Work> gather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const GatherOptions& opts = GatherOptions()) override;

  c10::intrusive_ptr<ProcessGroup::Work> scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ScatterOptions& opts = ScatterOptions()) override;

  c10::intrusive_ptr<ProcessGroup::Work> reduce_scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) override;

  c10::intrusive_ptr<ProcessGroup::Work> alltoall_base(
      at::Tensor& outputTensor,
      at::Tensor& inputTensor,
      std::vector<int64_t>& outputSplitSizes,
      std::vector<int64_t>& inputSplitSizes,
      const AllToAllOptions& opts = AllToAllOptions()) override;

  c10::intrusive_ptr<ProcessGroup::Work> alltoall(
      std::vector<at::Tensor>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllToAllOptions& opts = AllToAllOptions()) override;

  c10::intrusive_ptr<ProcessGroup::Work> send(
      std::vector<at::Tensor>& tensors,
      int dstRank,
      int tag) override;

  c10::intrusive_ptr<ProcessGroup::Work> recv(
      std::vector<at::Tensor>& tensors,
      int srcRank,
      int tag) override;

  c10::intrusive_ptr<ProcessGroup::Work> recvAnysource(
      std::vector<at::Tensor>& tensors,
      int tag) override;

  // Static factory taking the same store/rank/size/timeout arguments as the
  // constructor.
  static c10::intrusive_ptr<ProcessGroup> createProcessGroupUCC(
      const c10::intrusive_ptr<::c10d::Store>& store,
      int rank,
      int size,
      const std::chrono::duration<float>& timeout);

 protected:
  const std::chrono::duration<float> timeout_;
  std::shared_ptr<torch_ucc_oob_coll_info_t> oob;
  std::shared_ptr<Comm> comm = {nullptr};
  uint32_t comm_id;
  std::vector<ucp_ep_h> eps;
  ucc_team_h team{nullptr};
  ucc_ee_h cuda_ee{nullptr};
#ifdef USE_CUDA
  std::unique_ptr<at::cuda::CUDAStream> stream = nullptr;
  event_pool_t ep;
#endif
  c10::intrusive_ptr<ProcessGroupUCCLogger> logger;
};
|
||||||
|
|
||||||
|
class Comm {
|
||||||
|
c10::intrusive_ptr<ProcessGroupUCCLogger> logger;
|
||||||
|
std::shared_ptr<torch_ucc_oob_coll_info_t> oob;
|
||||||
|
CommUCX ucx_comm;
|
||||||
|
CommUCC ucc_comm;
|
||||||
|
std::mutex mutex;
|
||||||
|
std::thread progress_thread;
|
||||||
|
std::condition_variable queue_produce_cv;
|
||||||
|
std::condition_variable queue_consume_cv;
|
||||||
|
std::deque<std::shared_ptr<ProcessGroupUCC::ProgressEntry>> progress_queue;
|
||||||
|
bool stop_progress_loop;
|
||||||
|
bool collective_inprogress;
|
||||||
|
torch_ucc_phase_t finalize_phase;
|
||||||
|
|
||||||
|
public:
|
||||||
|
c10::DeviceIndex cuda_device_index;
|
||||||
|
Comm(
|
||||||
|
const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger,
|
||||||
|
std::shared_ptr<torch_ucc_oob_coll_info_t> oob,
|
||||||
|
c10::Device dev,
|
||||||
|
bool is_health_check);
|
||||||
|
|
||||||
|
~Comm();
|
||||||
|
|
||||||
|
// Connects UCX end points.
|
||||||
|
void ucx_connect_eps(
|
||||||
|
std::vector<ucp_ep_h>& eps,
|
||||||
|
std::shared_ptr<torch_ucc_oob_coll_info_t> oob);
|
||||||
|
|
||||||
|
// Disconnects UCX end points.
|
||||||
|
void ucx_disconnect_eps(
|
||||||
|
std::vector<ucp_ep_h>& eps,
|
||||||
|
std::shared_ptr<torch_ucc_oob_coll_info_t> oob);
|
||||||
|
|
||||||
|
void ucc_create_team(
|
||||||
|
ucc_team_h& team,
|
||||||
|
std::shared_ptr<torch_ucc_oob_coll_info_t> oob);
|
||||||
|
|
||||||
|
void ucc_destroy_team(ucc_team_h& team);
|
||||||
|
|
||||||
|
c10::intrusive_ptr<ProcessGroup::Work> enqueue_p2p(
|
||||||
|
OpType opType,
|
||||||
|
ucc_coll_req_h request,
|
||||||
|
const char* prof_title);
|
||||||
|
|
||||||
|
#ifdef USE_CUDA
|
||||||
|
void enqueue_cuda_collective(
|
||||||
|
std::unique_ptr<ProcessGroupUCC::WorkData> data,
|
||||||
|
c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> work,
|
||||||
|
ucc_coll_args_t& coll,
|
||||||
|
ucc_team_h team,
|
||||||
|
ucc_ee_h ee);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
void enqueue_collective(
|
||||||
|
std::unique_ptr<ProcessGroupUCC::WorkData> data,
|
||||||
|
c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> work,
|
||||||
|
ucc_coll_args_t& coll,
|
||||||
|
ucc_team_h team);
|
||||||
|
|
||||||
|
static std::shared_ptr<Comm> get_comm(
|
||||||
|
uint32_t& id,
|
||||||
|
c10::Device dev,
|
||||||
|
std::shared_ptr<torch_ucc_oob_coll_info_t> oob,
|
||||||
|
const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger,
|
||||||
|
bool is_health_check = false);
|
||||||
|
|
||||||
|
void progress_loop();
|
||||||
|
|
||||||
|
ucc_coll_req_h send_nb(
|
||||||
|
ucp_ep_h ep,
|
||||||
|
void* data,
|
||||||
|
ucs_memory_type_t mtype,
|
||||||
|
size_t size,
|
||||||
|
ucp_tag_t ucp_tag);
|
||||||
|
|
||||||
|
ucc_coll_req_h recv_nb(
|
||||||
|
void* data,
|
||||||
|
ucs_memory_type_t mtype,
|
||||||
|
size_t size,
|
||||||
|
ucp_tag_t ucp_tag,
|
||||||
|
ucp_tag_t ucp_tag_mask);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace c10d
|
||||||
|
|
||||||
|
#endif // USE_C10D_UCC
|
||||||
176
torch/csrc/distributed/c10d/UCCTracing.cpp
Normal file
176
torch/csrc/distributed/c10d/UCCTracing.cpp
Normal file
|
|
@ -0,0 +1,176 @@
|
||||||
|
#ifdef USE_C10D_UCC
|
||||||
|
|
||||||
|
#include <c10d/UCCTracing.hpp>
|
||||||
|
#include <c10d/UCCUtils.hpp>
|
||||||
|
|
||||||
|
#include <c10d/ParamCommsUtils.hpp>
|
||||||
|
|
||||||
|
#include <sys/stat.h>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <ctime>
|
||||||
|
#include <fstream>
|
||||||
|
|
||||||
|
#ifdef FBCODE_CAFFE2
|
||||||
|
#include <c10d/UCCInternalUtils.hpp>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace c10d {
|
||||||
|
|
||||||
|
// Creates the CommTraceLogger instance and marks it as initialized so that
// flushComms() will write out whatever it records.
void ProcessGroupUCCLogger::initCommsTracer() {
  trace_generator = std::make_shared<CommTraceLogger>();
  initialized_CommTraceLogger = true;
}
|
||||||
|
|
||||||
|
void ProcessGroupUCCLogger::flushComms(int rank, int world_size) {
|
||||||
|
if (!initialized_CommTraceLogger ||
|
||||||
|
trace_generator->getCommsTrace().empty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string dirname = c10::str("ProcessGroupUCC_trace_np", world_size);
|
||||||
|
time_t now_ = time(0);
|
||||||
|
std::tm* ltm = localtime(&now_);
|
||||||
|
if (ltm) {
|
||||||
|
dirname += c10::str(
|
||||||
|
"_", (1 + ltm->tm_mon), "_", ltm->tm_mday, "_", (1900 + ltm->tm_year));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string fullpath = "/tmp/" + dirname;
|
||||||
|
char* user_path = std::getenv("TORCH_UCC_COMMS_TRACE_OUTPUT_DIR");
|
||||||
|
if (user_path) {
|
||||||
|
fullpath = user_path;
|
||||||
|
}
|
||||||
|
std::string trace_filename = c10::str(fullpath, "/rank", rank, ".json");
|
||||||
|
std::ofstream _outfile;
|
||||||
|
if (!_outfile.is_open()) {
|
||||||
|
if (!mkdir(fullpath.c_str(), 0777)) {
|
||||||
|
LOG(INFO) << getLogPrefix() << "[INFO] failed to mkdir " << fullpath;
|
||||||
|
} else if (errno != EEXIST) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
_outfile.open(trace_filename, std::ofstream::out | std::ofstream::trunc);
|
||||||
|
}
|
||||||
|
// flush the traced comms
|
||||||
|
if (_outfile.is_open()) {
|
||||||
|
_outfile << "[" << c10::Join(",", trace_generator->getCommsTrace())
|
||||||
|
<< "\n]";
|
||||||
|
_outfile.flush();
|
||||||
|
_outfile.close();
|
||||||
|
}
|
||||||
|
#ifdef FBCODE_CAFFE2
|
||||||
|
uploadTrace_internal(
|
||||||
|
trace_filename, dirname, c10::str("rank", rank, ".json"));
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
/* unused */
|
||||||
|
// Pushes `name`, wrapped in double quotes so it is emitted as a JSON string,
// onto the stack of current block markers.
void CommTraceLogger::setCurBlock(const std::string& name) {
  curBlocks_.emplace_back(c10::str("\"", name, "\""));
}
|
||||||
|
|
||||||
|
/* unused */
|
||||||
|
// Removes the most recently pushed block marker. Calling this with no marker
// pushed is a no-op (pop_back on an empty vector would be undefined behavior).
void CommTraceLogger::popBlock() {
  // TODO: remove specific name
  if (curBlocks_.empty()) {
    return;
  }
  curBlocks_.pop_back();
}
|
||||||
|
|
||||||
|
// Remembers the root rank so the next recordComms() call can include it in
// the trace entry; recordComms() resets it to -1 afterwards.
void CommTraceLogger::recordOptionalInfo(int root) {
  curRoot_ = root;
}
|
||||||
|
|
||||||
|
void CommTraceLogger::recordOptionalInfo(
|
||||||
|
const std::vector<int64_t>& outputSplitSizes,
|
||||||
|
const std::vector<int64_t>& inputSplitSizes) {
|
||||||
|
curOutSplitSizes_ = outputSplitSizes;
|
||||||
|
curInSplitSizes_ = inputSplitSizes;
|
||||||
|
}
|
||||||
|
|
||||||
|
void CommTraceLogger::recordComms(
|
||||||
|
const std::string& commName,
|
||||||
|
const uintptr_t workReq,
|
||||||
|
const int rank,
|
||||||
|
const int world_size,
|
||||||
|
const std::vector<at::Tensor>& inputTensors,
|
||||||
|
const std::vector<at::Tensor>& outputTensors) {
|
||||||
|
auto inSize = (!inputTensors.empty()) ? inputTensors[0].numel() : 0;
|
||||||
|
auto outSize = (!outputTensors.empty()) ? outputTensors[0].numel() : 0;
|
||||||
|
auto dtype =
|
||||||
|
(!outputTensors.empty()) ? outputTensors[0].scalar_type() : at::kByte;
|
||||||
|
auto devType = (!outputTensors.empty()) ? outputTensors[0].device().type()
|
||||||
|
: c10::DeviceType::CPU;
|
||||||
|
auto now = std::chrono::system_clock::now();
|
||||||
|
static auto startTS = now;
|
||||||
|
int64_t time_since_begin =
|
||||||
|
std::chrono::duration_cast<std::chrono::nanoseconds>(now - startTS)
|
||||||
|
.count();
|
||||||
|
|
||||||
|
// TODO: get markers from torch profiler if enabled
|
||||||
|
|
||||||
|
// common fields for all operations
|
||||||
|
std::string cur_trace_ = c10::str(
|
||||||
|
"\n\t\t\"markers\": [",
|
||||||
|
curBlocks_,
|
||||||
|
"]",
|
||||||
|
",\n\t\t\"startTime_ns\": ",
|
||||||
|
time_since_begin,
|
||||||
|
",\n\t\t\"comms\": \"",
|
||||||
|
commName,
|
||||||
|
"\"",
|
||||||
|
",\n\t\t\"req\": ",
|
||||||
|
workReq,
|
||||||
|
",\n\t\t\"seqnum\": ",
|
||||||
|
seqnum++,
|
||||||
|
",\n\t\t\"world_size\": ",
|
||||||
|
world_size);
|
||||||
|
|
||||||
|
if (inSize > 0 || outSize > 0) {
|
||||||
|
// for most collectives - append msg sizes, data type, device type
|
||||||
|
cur_trace_ = c10::str(
|
||||||
|
cur_trace_,
|
||||||
|
",\n\t\t\"in_msg_size\": ",
|
||||||
|
inSize,
|
||||||
|
",\n\t\t\"out_msg_size\": ",
|
||||||
|
outSize,
|
||||||
|
",\n\t\t\"dtype\": \"",
|
||||||
|
at::toString(dtype),
|
||||||
|
"\",\n\t\t\"devType\": \"",
|
||||||
|
c10::DeviceTypeName(devType),
|
||||||
|
"\"");
|
||||||
|
}
|
||||||
|
if (curRoot_ != -1) {
|
||||||
|
// append root rank if applicable, e.g., broadcast, gather, scatter
|
||||||
|
cur_trace_ = c10::str(cur_trace_, ",\n\t\t\"root\": ", curRoot_);
|
||||||
|
}
|
||||||
|
if (!curInSplitSizes_.empty() || !curOutSplitSizes_.empty()) {
|
||||||
|
// append input and output splits if applicable, e.g., ALLTOALL_BASE
|
||||||
|
cur_trace_ = c10::str(
|
||||||
|
cur_trace_,
|
||||||
|
",\n\t\t\"in_split\": [",
|
||||||
|
c10::Join(",", curInSplitSizes_),
|
||||||
|
"]"
|
||||||
|
",\n\t\t\"out_split\": [",
|
||||||
|
c10::Join(",", curOutSplitSizes_),
|
||||||
|
"]");
|
||||||
|
}
|
||||||
|
comms_trace_.push_back(c10::str("\n\t{", cur_trace_, "\n\t}"));
|
||||||
|
|
||||||
|
// record the trace to kineto trace if applicable
|
||||||
|
RECORD_PARAM_COMMS(
|
||||||
|
rank,
|
||||||
|
commName.c_str(),
|
||||||
|
inSize,
|
||||||
|
outSize,
|
||||||
|
dtype,
|
||||||
|
curInSplitSizes_,
|
||||||
|
curOutSplitSizes_);
|
||||||
|
|
||||||
|
// reset optional field
|
||||||
|
curRoot_ = -1;
|
||||||
|
curInSplitSizes_ = {};
|
||||||
|
curOutSplitSizes_ = {};
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace c10d
|
||||||
|
|
||||||
|
#endif // USE_C10D_UCC
|
||||||
58
torch/csrc/distributed/c10d/UCCTracing.hpp
Normal file
58
torch/csrc/distributed/c10d/UCCTracing.hpp
Normal file
|
|
@ -0,0 +1,58 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#ifdef USE_C10D_UCC
|
||||||
|
|
||||||
|
#include <c10d/UCCUtils.hpp>
|
||||||
|
|
||||||
|
namespace c10d {
|
||||||
|
|
||||||
|
// Forwards one collective to `_comms_tracer->recordComms(...)`, but only when
// torch_ucc_config.enable_comms_logger is set (that object must be visible at
// the expansion site). `_work` must expose .get(); its pointer value is
// recorded as the request id. Wrapped in do/while(0) so the macro behaves as a
// single statement.
#define RECORD_COMMS_TRACE(                                                    \
    _comms_tracer, _work, _opType, _rank, _comm_size, _inTensors, _outTensors) \
  do {                                                                         \
    if (!torch_ucc_config.enable_comms_logger) {                               \
      break;                                                                   \
    }                                                                          \
    _comms_tracer->recordComms(                                                \
        opTypeToString(_opType),                                               \
        (uintptr_t)_work.get(),                                                \
        _rank,                                                                 \
        _comm_size,                                                            \
        _inTensors,                                                            \
        _outTensors);                                                          \
  } while (0)
|
||||||
|
|
||||||
|
// interfaces to collect communication traces
|
||||||
|
// Interface to collect communication traces. Each recorded collective becomes
// one formatted text entry appended to an in-memory list.
class TORCH_API CommTraceLogger : public torch::CustomClassHolder {
 private:
  // One formatted entry per recorded collective.
  std::vector<std::string> comms_trace_;
  std::vector<std::string> curBlocks_; /* unused */
  // Optional per-collective fields, set through recordOptionalInfo().
  std::vector<int64_t> curOutSplitSizes_;
  std::vector<int64_t> curInSplitSizes_;
  // -1 means "no root recorded".
  int curRoot_ = -1;
  // Sequence number attached to each entry.
  unsigned long seqnum = 0;

 public:
  void setCurBlock(const std::string& name); /* unused */
  void popBlock(); /* unused */

  // Record root info if applicable, e.g., broadcast, gather, scatter.
  // A call with no arguments resolves to this overload.
  void recordOptionalInfo(int root = -1);

  // Record input/output splits of Alltoallv.
  // The first argument deliberately has no default: with every argument
  // defaulted in both overloads, `recordOptionalInfo()` was ambiguous.
  void recordOptionalInfo(
      const std::vector<int64_t>& outputSplitSizes,
      const std::vector<int64_t>& inputSplitSizes = {});

  // Record essential comms information: collective name, request id, rank,
  // communicator size, and the tensors involved.
  void recordComms(
      const std::string& commName,
      const uintptr_t workReq = 0,
      const int rank = -1,
      const int world_size = -1,
      const std::vector<at::Tensor>& inputTensors = {},
      const std::vector<at::Tensor>& outputTensors = {});

  // Return collected comms traces (mutable reference to the internal list).
  std::vector<std::string>& getCommsTrace() {
    return comms_trace_;
  }
};
|
||||||
|
|
||||||
|
} // namespace c10d
|
||||||
|
|
||||||
|
#endif // USE_C10D_UCC
|
||||||
285
torch/csrc/distributed/c10d/UCCUtils.cpp
Normal file
285
torch/csrc/distributed/c10d/UCCUtils.cpp
Normal file
|
|
@ -0,0 +1,285 @@
|
||||||
|
#ifdef USE_C10D_UCC
|
||||||
|
|
||||||
|
#include <c10d/UCCTracing.hpp>
|
||||||
|
#include <c10d/UCCUtils.hpp>
|
||||||
|
|
||||||
|
namespace c10d {
|
||||||
|
|
||||||
|
namespace {

// Key fragments used with the c10d Store by the out-of-band allgather below.
// Each is combined with a rank and/or the communicator id to form a full key.

// Per-rank payload key: kTeamRank + <rank>.
constexpr char kTeamRank[] = "teamr";
// Counter of ranks that have finished reading the gathered data.
constexpr char kAllGatherDone[] = "ag_done";
// Per-rank "safe to return" signal key: kAllGatherFree + <rank>.
constexpr char kAllGatherFree[] = "ag_free";

} // namespace
|
||||||
|
|
||||||
|
// Creates the UCP context and worker used for point-to-point traffic.
//
// comm_size: passed to UCX as the estimated number of endpoints.
// logger:    used by the TORCH_UCX_CHECK / TORCH_UCC_LOG_* macros.
//
// Throws (via TORCH_CHECK / TORCH_UCX_CHECK) when UCX lacks multi-thread
// support or context creation fails, and std::runtime_error when the worker
// cannot be created.
CommUCX::CommUCX(
    int comm_size,
    const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger)
    : CommBase(logger) {
  // The worker below is created in UCS_THREAD_MODE_MULTI, so the library
  // itself must support that level.
  ucp_lib_attr_t ucp_attr;
  ucp_attr.field_mask = UCP_LIB_ATTR_FIELD_MAX_THREAD_LEVEL;
  TORCH_UCX_CHECK(
      ucp_lib_query(&ucp_attr), "failed to query UCP lib attributes");
  TORCH_CHECK(
      ucp_attr.max_thread_level == UCS_THREAD_MODE_MULTI,
      "ucx library wasn't initialized with multithreading support, "
      "please check ucx build options");

  ucp_config_t* config = nullptr;
  TORCH_UCX_CHECK(
      ucp_config_read("TORCH", nullptr, &config), "failed to read UCP config");

  ucp_params_t params;
  memset(&params, 0, sizeof(ucp_params_t));
  params.field_mask = UCP_PARAM_FIELD_FEATURES | UCP_PARAM_FIELD_REQUEST_SIZE |
      UCP_PARAM_FIELD_ESTIMATED_NUM_EPS | UCP_PARAM_FIELD_TAG_SENDER_MASK |
      UCP_PARAM_FIELD_REQUEST_INIT | UCP_PARAM_FIELD_REQUEST_CLEANUP;
  params.request_size = sizeof(ucc_coll_req_t);
  params.features = UCP_FEATURE_TAG;
  params.estimated_num_eps = comm_size;
  params.tag_sender_mask = TORCH_UCX_RANK_MASK;
  // Every freshly allocated request starts out "in progress".
  params.request_init = [](void* request) {
    static_cast<ucc_coll_req_h>(request)->status = UCC_INPROGRESS;
  };
  params.request_cleanup = [](void*) {};

  // Release the config before checking the status so it is not leaked when
  // ucp_init fails and the check throws.
  const ucs_status_t init_st = ucp_init(&params, config, &context);
  ucp_config_release(config);
  TORCH_UCX_CHECK(init_st, "failed to init UCP context");

  ucp_worker_params_t worker_params;
  memset(&worker_params, 0, sizeof(ucp_worker_params_t));
  worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE;
  worker_params.thread_mode = UCS_THREAD_MODE_MULTI;
  const ucs_status_t st = ucp_worker_create(context, &worker_params, &worker);
  if (st != UCS_OK) {
    TORCH_UCC_LOG_ERROR(
        TORCH_UCC_INIT,
        c10::str("UCX failed to create UCP worker:", ucs_status_string(st)));
    // The destructor does not run for a throwing constructor, so the context
    // must be cleaned up here.
    ucp_cleanup(context);
    throw std::runtime_error(ucs_status_string(st));
  }
}
|
||||||
|
|
||||||
|
// Drives outstanding UCX operations on this communicator's worker.
// The return value of ucp_worker_progress is intentionally ignored.
void CommUCX::progress() {
  (void)ucp_worker_progress(worker);
}
|
||||||
|
|
||||||
|
// Hands a request back to UCX. The status is reset to UCC_INPROGRESS first,
// the same initial state request_init assigns to new requests.
void CommUCX::free_request(ucc_coll_req_h request) {
  request->status = UCC_INPROGRESS;
  ucp_request_free(request);
}
|
||||||
|
|
||||||
|
// Tears down the worker first, then the context it was created from, and
// clears both handles.
CommUCX::~CommUCX() {
  if (worker != nullptr) {
    ucp_worker_destroy(worker);
  }
  worker = nullptr;

  if (context != nullptr) {
    ucp_cleanup(context);
  }
  context = nullptr;
}
|
||||||
|
|
||||||
|
// Out-of-band allgather "post" step, registered as the UCC OOB allgather
// callback. Publishes this rank's `msglen` bytes from `sbuf` to the store
// under the key kTeamRank + <rank>, and remembers `rbuf`/`msglen` in the
// collective info so oob_allgather_test can fill the receive buffer later.
//
// Returns UCC_OK on success, UCC_ERR_NO_MESSAGE if the store operation throws.
// `*req` is set to `coll_info`, which doubles as the request handle.
ucc_status_t oob_allgather(
    void* sbuf,
    void* rbuf,
    size_t msglen,
    void* coll_info,
    void** req) {
  auto* info = reinterpret_cast<torch_ucc_oob_coll_info_t*>(coll_info);
  TORCH_CHECK(info != nullptr);

  // Copy the payload into a byte vector, the value type the store accepts.
  auto* first = reinterpret_cast<uint8_t*>(sbuf);
  std::vector<uint8_t> val(first, first + msglen);

  try {
    info->store->set(info->getKey(kTeamRank + std::to_string(info->rank)), val);
    info->rbuf = rbuf;
    info->msglen = msglen;
    *req = coll_info;
  } catch (std::exception& ex) {
    LOG(ERROR) << "(oob_allgather) Caught exception in Store Operation .. "
               << "[" << ex.what() << "]";
    return UCC_ERR_NO_MESSAGE;
  }
  return UCC_OK;
}
|
||||||
|
|
||||||
|
// Out-of-band allgather "test" step. Returns UCC_INPROGRESS until every
// rank's kTeamRank key is present in the store; then copies each rank's
// payload into the receive buffer at offset msglen * rank and returns UCC_OK.
// Returns UCC_ERR_NO_MESSAGE if a store operation throws or a stored payload
// is shorter than the expected message length.
ucc_status_t oob_allgather_test(void* req) {
  auto* info = reinterpret_cast<torch_ucc_oob_coll_info_t*>(req);
  TORCH_CHECK(info != nullptr);

  try {
    // First pass: bail out early if any rank has not published yet.
    for (int r = 0; r < info->size; r++) {
      if (!info->store->check({info->getKey(kTeamRank + std::to_string(r))})) {
        return UCC_INPROGRESS;
      }
    }
    // Second pass: gather all payloads into rbuf.
    auto* dst = static_cast<uint8_t*>(info->rbuf);
    for (int r = 0; r < info->size; r++) {
      std::vector<uint8_t> data =
          info->store->get(info->getKey(kTeamRank + std::to_string(r)));
      // Guard the memcpy: copying msglen bytes out of a shorter vector would
      // read past the end of its storage.
      if (data.size() < info->msglen) {
        LOG(ERROR) << "(oob_allgather) Store value for rank " << r << " has "
                   << data.size() << " bytes, expected " << info->msglen;
        return UCC_ERR_NO_MESSAGE;
      }
      memcpy(dst + info->msglen * r, data.data(), info->msglen);
    }
  } catch (std::exception& ex) {
    LOG(ERROR) << "(oob_allgather) Caught exception in Store Operation .. "
               << "[" << ex.what() << "]";
    return UCC_ERR_NO_MESSAGE;
  }
  return UCC_OK;
}
|
||||||
|
|
||||||
|
// Out-of-band allgather "free" step. Each rank bumps the kAllGatherDone
// counter. The rank that observes the counter equal to the group size deletes
// the shared keys and then signals every rank through its kAllGatherFree key;
// all other ranks wait on their own kAllGatherFree key. Finally each rank
// deletes its own kAllGatherFree key.
// Returns UCC_ERR_NO_MESSAGE if any store operation throws, UCC_OK otherwise.
ucc_status_t oob_allgather_free(void* req) {
  auto* info = reinterpret_cast<torch_ucc_oob_coll_info_t*>(req);
  TORCH_CHECK(info != nullptr);

  try {
    int num_done = info->store->add({info->getKey(kAllGatherDone)}, 1);
    const bool cleans_up = (num_done == info->size);
    if (!cleans_up) {
      info->store->wait(
          {info->getKey(kAllGatherFree + std::to_string(info->rank))});
    } else {
      info->store->deleteKey(info->getKey(kAllGatherDone));
      // Note: to avoid race condition, it's important to remove all keys in
      // oob_allgather_free first and only after that signal completion to
      // other ranks
      for (const auto r : c10::irange(info->size)) {
        info->store->deleteKey(info->getKey(kTeamRank + std::to_string(r)));
      }
      for (const auto r : c10::irange(info->size)) {
        info->store->add({info->getKey(kAllGatherFree + std::to_string(r))}, 1);
      }
    }
    info->store->deleteKey(
        info->getKey(kAllGatherFree + std::to_string(info->rank)));
  } catch (std::exception& ex) {
    LOG(ERROR) << "(oob_allgather) Caught exception in Store Operation .. "
               << "[" << ex.what() << "]";
    return UCC_ERR_NO_MESSAGE;
  }
  return UCC_OK;
}
|
||||||
|
|
||||||
|
// Initializes the UCC library and creates a shared UCC context whose
// out-of-band exchange is backed by the oob_allgather* callbacks.
//
// oob:    rank/size/store description; `oob.get()` is handed to UCC as the
//         callback cookie, so it must outlive the context.
// logger: used by the TORCH_UCC_CHECK / TORCH_UCC_LOG_* macros.
//
// Throws via TORCH_CHECK / TORCH_UCC_CHECK or std::runtime_error on failure.
CommUCC::CommUCC(
    std::shared_ptr<torch_ucc_oob_coll_info_t> oob,
    const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger)
    : CommBase(logger) {
  ucc_lib_config_h lib_config;
  ucc_context_config_h context_config;
  ucc_lib_params_t lib_params;
  ucc_context_params_t context_params;
  ucc_status_t st;

  TORCH_UCC_CHECK(
      ucc_lib_config_read("TORCH", nullptr, &lib_config),
      "failed to read UCC lib config");
  memset(&lib_params, 0, sizeof(ucc_lib_params_t));
  lib_params.mask = UCC_LIB_PARAM_FIELD_THREAD_MODE;
  lib_params.thread_mode = UCC_THREAD_MULTIPLE;
  // Release the lib config before checking the status so it is not leaked
  // when ucc_init fails and the check throws.
  st = ucc_init(&lib_params, lib_config, &lib);
  ucc_lib_config_release(lib_config);
  TORCH_UCC_CHECK(st, "failed to init UCC lib");

  ucc_lib_attr_t lib_attr;
  lib_attr.mask = UCC_LIB_ATTR_FIELD_THREAD_MODE;
  TORCH_UCC_CHECK(
      ucc_lib_get_attr(lib, &lib_attr), "failed to query for lib attr");
  TORCH_CHECK(
      lib_attr.thread_mode == UCC_THREAD_MULTIPLE,
      "ucc library wasn't initialized with multithreading support, "
      "please check ucc build options");

  st = ucc_context_config_read(lib, nullptr, &context_config);
  if (st != UCC_OK) {
    // Log the root cause first: the finalize check below may throw and would
    // otherwise hide it.
    TORCH_UCC_LOG_ERROR(
        TORCH_UCC_INIT,
        c10::str("failed to read UCC context config: ", ucc_status_string(st)));
    // FIXME: would this cause deadlock if only one rank fails?
    TORCH_UCC_CHECK(
        ucc_finalize(lib),
        "failed to finalize UCC library when failing to read UCC context config");
    throw std::runtime_error(ucc_status_string(st));
  }

  st = ucc_context_config_modify(
      context_config,
      nullptr,
      "ESTIMATED_NUM_EPS",
      std::to_string(oob->size).c_str());
  if (st != UCC_OK) {
    ucc_context_config_release(context_config);
    ucc_finalize(lib);
    TORCH_UCC_LOG_ERROR(
        TORCH_UCC_INIT,
        c10::str(
            "UCC failed to modify UCC context config: ",
            ucc_status_string(st)));
    throw std::runtime_error(ucc_status_string(st));
  }

  memset(&context_params, 0, sizeof(ucc_context_params_t));
  context_params.mask =
      UCC_CONTEXT_PARAM_FIELD_TYPE | UCC_CONTEXT_PARAM_FIELD_OOB;
  context_params.type = UCC_CONTEXT_SHARED;
  context_params.oob.n_oob_eps = oob->size;
  context_params.oob.oob_ep = oob->rank;
  context_params.oob.allgather = oob_allgather;
  context_params.oob.req_test = oob_allgather_test;
  context_params.oob.req_free = oob_allgather_free;
  context_params.oob.coll_info = oob.get();
  st = ucc_context_create(lib, &context_params, context_config, &context);
  ucc_context_config_release(context_config);
  if (st != UCC_OK) {
    // Same ordering as above: report the original failure before attempting
    // a finalize that may itself throw.
    TORCH_UCC_LOG_ERROR(
        TORCH_UCC_INIT,
        c10::str("UCC failed to create UCC context: ", ucc_status_string(st)));
    TORCH_UCC_CHECK(
        ucc_finalize(lib),
        "failed to finalize UCC library when failing to create UCC context");
    throw std::runtime_error(ucc_status_string(st));
  }
}
|
||||||
|
|
||||||
|
void CommUCC::progress() {
|
||||||
|
TORCH_UCC_CHECK(
|
||||||
|
ucc_context_progress(context), "failed to progress UCC collective");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Releases a UCC collective request via ucc_collective_finalize; throws
// (through TORCH_UCC_CHECK) if UCC reports a failure.
void CommUCC::free_request(ucc_coll_req_h request) {
  TORCH_UCC_CHECK(
      ucc_collective_finalize(request), "failed to release UCC request");
}
|
||||||
|
|
||||||
|
// Destroys the UCC context and then finalizes the library.
//
// Failures are logged rather than thrown: destructors are implicitly
// noexcept, so a throwing check here would call std::terminate, and it would
// also skip ucc_finalize whenever the context destroy failed.
CommUCC::~CommUCC() {
  if (context != nullptr) {
    const ucc_status_t st = ucc_context_destroy(context);
    if (st != UCC_OK) {
      TORCH_UCC_LOG_ERROR(
          TORCH_UCC_FINALIZE,
          c10::str("failed to destroy UCC context: ", ucc_status_string(st)));
    }
    context = nullptr;
  }
  if (lib != nullptr) {
    const ucc_status_t st = ucc_finalize(lib);
    if (st != UCC_OK) {
      TORCH_UCC_LOG_ERROR(
          TORCH_UCC_FINALIZE,
          c10::str("failed to finalize UCC library: ", ucc_status_string(st)));
    }
    lib = nullptr;
  }
}
|
||||||
|
|
||||||
|
// Builds "<log_prefix>[<PHASE>]". An explicit `phase` other than
// TORCH_UCC_UNKNOWN overrides the phase stored in the logger; otherwise the
// stored phase is used. Throws std::out_of_range if the chosen phase has no
// entry in ucc_phase_map.
std::string ProcessGroupUCCLogger::getLogPrefix(torch_ucc_phase_t phase) {
  // caller can override the phase stored locally
  torch_ucc_phase_t effective = local_phase;
  if (phase != TORCH_UCC_UNKNOWN) {
    effective = phase;
  }
  const std::string& phase_name = ucc_phase_map.at(effective);
  return c10::str(log_prefix, "[", phase_name, "]");
}
|
||||||
|
void ProcessGroupUCCLogger::setLogPrefix(std::string log_prefix_) {
|
||||||
|
log_prefix = log_prefix_;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default logger: prefix "[ProcessGroupUCC]", phase left at its in-class
// default.
ProcessGroupUCCLogger::ProcessGroupUCCLogger() {
  const std::string default_prefix = "[ProcessGroupUCC]";
  setLogPrefix(default_prefix);
}
|
||||||
|
// Logger with a caller-supplied prefix and initial phase.
// Note: the `log_prefix` parameter shadows the member of the same name; the
// member is assigned through setLogPrefix().
ProcessGroupUCCLogger::ProcessGroupUCCLogger(
    std::string log_prefix,
    torch_ucc_phase_t phase)
    : local_phase(phase) {
  this->setLogPrefix(log_prefix);
}
|
||||||
|
|
||||||
|
} // namespace c10d
|
||||||
|
|
||||||
|
#endif // USE_C10D_UCC
|
||||||
193
torch/csrc/distributed/c10d/UCCUtils.hpp
Normal file
193
torch/csrc/distributed/c10d/UCCUtils.hpp
Normal file
|
|
@ -0,0 +1,193 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#ifdef USE_C10D_UCC
|
||||||
|
|
||||||
|
#include <c10d/ProcessGroup.hpp>
|
||||||
|
#include <c10d/Store.hpp>
|
||||||
|
#include <ucc/api/ucc.h>
|
||||||
|
#include <ucp/api/ucp.h>
|
||||||
|
|
||||||
|
// Layout of the 64-bit UCX tag, from least to most significant bit:
//   [ comm : 15 ][ rank : 16 ][ tag : 32 ][ oob : 1 ]
// Field widths.
#define TORCH_UCX_COMM_BITS 15
#define TORCH_UCX_RANK_BITS 16
#define TORCH_UCX_TAG_BITS 32
#define TORCH_UCX_OOB_BITS 1

// Bit offset of each field: the sum of the widths of all lower fields.
#define TORCH_UCX_COMM_BITS_OFFSET 0
#define TORCH_UCX_RANK_BITS_OFFSET TORCH_UCX_COMM_BITS
#define TORCH_UCX_TAG_BITS_OFFSET (TORCH_UCX_COMM_BITS + TORCH_UCX_RANK_BITS)
#define TORCH_UCX_OOB_BITS_OFFSET \
  (TORCH_UCX_COMM_BITS + TORCH_UCX_RANK_BITS + TORCH_UCX_TAG_BITS)

// Largest value each field can hold (all field bits set, unshifted).
#define TORCH_UCX_MAX_COMM ((static_cast<uint64_t>(1) << TORCH_UCX_COMM_BITS) - 1)
#define TORCH_UCX_MAX_RANK ((static_cast<uint64_t>(1) << TORCH_UCX_RANK_BITS) - 1)
#define TORCH_UCX_MAX_TAG ((static_cast<uint64_t>(1) << TORCH_UCX_TAG_BITS) - 1)
#define TORCH_UCX_MAX_OOB ((static_cast<uint64_t>(1) << TORCH_UCX_OOB_BITS) - 1)

// Mask selecting each field in place within the 64-bit tag.
#define TORCH_UCX_COMM_MASK (TORCH_UCX_MAX_COMM << TORCH_UCX_COMM_BITS_OFFSET)
#define TORCH_UCX_RANK_MASK (TORCH_UCX_MAX_RANK << TORCH_UCX_RANK_BITS_OFFSET)
#define TORCH_UCX_TAG_MASK (TORCH_UCX_MAX_TAG << TORCH_UCX_TAG_BITS_OFFSET)
#define TORCH_UCX_OOB_MASK (TORCH_UCX_MAX_OOB << TORCH_UCX_OOB_BITS_OFFSET)
|
||||||
|
|
||||||
|
namespace c10d {
|
||||||
|
|
||||||
|
// Macro to throw on a non-successful UCC return value.
// Evaluates `_cmd` exactly once. On any status other than UCC_OK it raises
// through TORCH_CHECK with a message containing file:line, the logger prefix,
// `_error_msg`, the UCC status code and string, and errno.
// Requires a `logger` with getLogPrefix() in scope at the expansion site.
#define TORCH_UCC_CHECK(_cmd, _error_msg) \
  do {                                    \
    ucc_status_t result = _cmd;           \
    if (result == UCC_OK) {               \
      break;                              \
    }                                     \
    std::string err = c10::str(           \
        "[",                              \
        std::string(__FILE__),            \
        ":",                              \
        std::to_string(__LINE__),         \
        "] ",                             \
        logger->getLogPrefix(),           \
        _error_msg,                       \
        ", error code ",                  \
        result,                           \
        ": ",                             \
        ucc_status_string(result),        \
        ", system error code ",           \
        errno);                           \
    TORCH_CHECK(false, err);              \
  } while (0)
|
||||||
|
|
||||||
|
// Macro to throw on a non-successful UCX return value.
// Evaluates `_cmd` exactly once. On any status other than UCS_OK it raises
// through TORCH_CHECK with a message containing file:line, the logger prefix,
// `_error_msg`, the UCS status code and string, and errno.
// Requires a `logger` with getLogPrefix() in scope at the expansion site.
#define TORCH_UCX_CHECK(_cmd, _error_msg) \
  do {                                    \
    ucs_status_t result = _cmd;           \
    if (result == UCS_OK) {               \
      break;                              \
    }                                     \
    std::string err = c10::str(           \
        "[",                              \
        std::string(__FILE__),            \
        ":",                              \
        std::to_string(__LINE__),         \
        "] ",                             \
        logger->getLogPrefix(),           \
        _error_msg,                       \
        ", error code ",                  \
        result,                           \
        ": ",                             \
        ucs_status_string(result),        \
        ", system error code ",           \
        errno);                           \
    TORCH_CHECK(false, err);              \
  } while (0)
|
||||||
|
|
||||||
|
// Macros to print logs with unified format:
//   <logger prefix for _phase>[LEVEL] <message>
// Each requires a `logger` with getLogPrefix(phase) in scope.
// NOTE(review): every expansion already ends in ';', so these are not safe as
// the sole statement of an unbraced if/else; the semicolon is kept because
// call sites outside this file may depend on it.
#define TORCH_UCC_LOG_ERROR(_phase, _msg) \
  LOG(ERROR) << logger->getLogPrefix(_phase) << "[ERROR] " << _msg;

#define TORCH_UCC_LOG_INFO(_phase, _msg) \
  LOG(INFO) << logger->getLogPrefix(_phase) << "[INFO] " << _msg;

// Debug messages go to verbose level 1 rather than a LOG severity.
#define TORCH_UCC_LOG_DEBUG(_phase, _msg) \
  VLOG(1) << logger->getLogPrefix(_phase) << "[DEBUG] " << _msg;
|
||||||
|
|
||||||
|
// Phases used to tag log lines. TORCH_UCC_UNKNOWN doubles as the "no
// override" sentinel in ProcessGroupUCCLogger::getLogPrefix().
enum torch_ucc_phase_t {
  TORCH_UCC_UNKNOWN = -1,
  TORCH_UCC_INIT = 0,
  TORCH_UCC_HEALTH_CHECK = 1,
  TORCH_UCC_READY = 2,
  TORCH_UCC_COLL_POST = 3,
  TORCH_UCC_COLL_PROGRESS = 4,
  TORCH_UCC_FINALIZE = 5,
};

// Display name for every phase above; looked up with .at(), so each
// enumerator needs an entry here.
const std::map<torch_ucc_phase_t, std::string> ucc_phase_map = {
    {TORCH_UCC_UNKNOWN, "UNKNOWN"},
    {TORCH_UCC_INIT, "INIT"},
    {TORCH_UCC_HEALTH_CHECK, "HEALTH_CHECK"},
    {TORCH_UCC_READY, "READY"},
    {TORCH_UCC_COLL_POST, "COLL_POST"},
    {TORCH_UCC_COLL_PROGRESS, "COLL_PROGRESS"},
    {TORCH_UCC_FINALIZE, "FINALIZE"},
};
|
||||||
|
|
||||||
|
class CommTraceLogger;
|
||||||
|
|
||||||
|
// Produces the "<prefix>[<PHASE>]" tag used by the TORCH_UCC_* macros and
// owns the optional communication tracer.
class TORCH_API ProcessGroupUCCLogger : public torch::CustomClassHolder {
 public:
  ProcessGroupUCCLogger();
  ProcessGroupUCCLogger(std::string log_prefix, torch_ucc_phase_t phase);

  // Returns the prefix followed by a bracketed phase name. Passing
  // TORCH_UCC_UNKNOWN (the default) selects the phase stored by setPhase().
  std::string getLogPrefix(torch_ucc_phase_t phase = TORCH_UCC_UNKNOWN);
  void setLogPrefix(std::string log_prefix);

  // Stores the phase that getLogPrefix() falls back to.
  void setPhase(torch_ucc_phase_t phase) {
    local_phase = phase;
  }

  // Defined outside this view.
  void initCommsTracer();
  void flushComms(int rank, int world_size);

  // Null until a tracer is attached.
  std::shared_ptr<CommTraceLogger> trace_generator = nullptr;

 protected:
  std::string log_prefix;
  torch_ucc_phase_t local_phase = TORCH_UCC_UNKNOWN;
  bool initialized_CommTraceLogger = false;
};
|
||||||
|
|
||||||
|
struct torch_ucc_oob_coll_info_t {
|
||||||
|
c10::intrusive_ptr<Store> store;
|
||||||
|
uint32_t comm_id;
|
||||||
|
int rank;
|
||||||
|
int size;
|
||||||
|
void* rbuf;
|
||||||
|
size_t msglen;
|
||||||
|
std::string getKey(std::string key) {
|
||||||
|
return std::to_string(comm_id) + key;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class CommBase {
|
||||||
|
public:
|
||||||
|
CommBase(const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger_)
|
||||||
|
: logger(logger_) {}
|
||||||
|
virtual void progress() = 0;
|
||||||
|
virtual void free_request(ucc_coll_req_h request) = 0;
|
||||||
|
virtual ~CommBase() {}
|
||||||
|
c10::intrusive_ptr<ProcessGroupUCCLogger> logger;
|
||||||
|
};
|
||||||
|
|
||||||
|
class CommUCX : public CommBase {
|
||||||
|
public:
|
||||||
|
ucp_context_h context{nullptr};
|
||||||
|
ucp_worker_h worker{nullptr};
|
||||||
|
|
||||||
|
public:
|
||||||
|
void progress() override;
|
||||||
|
void free_request(ucc_coll_req_h request) override;
|
||||||
|
CommUCX(
|
||||||
|
int comm_size,
|
||||||
|
const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger);
|
||||||
|
~CommUCX();
|
||||||
|
};
|
||||||
|
|
||||||
|
class CommUCC : public CommBase {
|
||||||
|
public:
|
||||||
|
ucc_lib_h lib{nullptr};
|
||||||
|
ucc_context_h context{nullptr};
|
||||||
|
|
||||||
|
public:
|
||||||
|
void progress() override;
|
||||||
|
CommUCC(
|
||||||
|
std::shared_ptr<torch_ucc_oob_coll_info_t> oob,
|
||||||
|
const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger);
|
||||||
|
void free_request(ucc_coll_req_h request) override;
|
||||||
|
~CommUCC();
|
||||||
|
};
|
||||||
|
|
||||||
|
// Out-of-band allgather callbacks installed into ucc_context_params_t::oob.

// Post: publish `msglen` bytes from `sbuf`; `coll_info` is the
// torch_ucc_oob_coll_info_t cookie and is returned through `req`.
ucc_status_t oob_allgather(
    void* sbuf,
    void* rbuf,
    size_t msglen,
    void* coll_info,
    void** req);

// Test: UCC_INPROGRESS until all ranks have published, then fills the
// receive buffer.
ucc_status_t oob_allgather_test(void* req);

// Free: cleans up the keys used by the exchange.
ucc_status_t oob_allgather_free(void* req);
|
||||||
|
|
||||||
|
} // namespace c10d
|
||||||
|
|
||||||
|
#endif // USE_C10D_UCC
|
||||||
|
|
@ -24,6 +24,10 @@
|
||||||
#include <c10d/ProcessGroupMPI.hpp>
|
#include <c10d/ProcessGroupMPI.hpp>
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#ifdef USE_C10D_UCC
|
||||||
|
#include <c10d/ProcessGroupUCC.hpp>
|
||||||
|
#endif
|
||||||
|
|
||||||
#include <c10d/PrefixStore.hpp>
|
#include <c10d/PrefixStore.hpp>
|
||||||
#include <fmt/format.h>
|
#include <fmt/format.h>
|
||||||
#include <pybind11/chrono.h>
|
#include <pybind11/chrono.h>
|
||||||
|
|
@ -1557,6 +1561,25 @@ Example::
|
||||||
py::call_guard<py::gil_scoped_release>());
|
py::call_guard<py::gil_scoped_release>());
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#ifdef USE_C10D_UCC
|
||||||
|
auto processGroupUCC =
|
||||||
|
intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupUCC>(
|
||||||
|
module, "ProcessGroupUCC", processGroup)
|
||||||
|
.def(
|
||||||
|
py::init([](const c10::intrusive_ptr<::c10d::Store>& store,
|
||||||
|
int rank,
|
||||||
|
int size,
|
||||||
|
const std::chrono::milliseconds& timeout) {
|
||||||
|
return c10::make_intrusive<::c10d::ProcessGroupUCC>(
|
||||||
|
store, rank, size, timeout);
|
||||||
|
}),
|
||||||
|
py::arg("store"),
|
||||||
|
py::arg("rank"),
|
||||||
|
py::arg("size"),
|
||||||
|
py::arg("timeout") = kProcessGroupDefaultTimeout,
|
||||||
|
py::call_guard<py::gil_scoped_release>());
|
||||||
|
#endif
|
||||||
|
|
||||||
py::class_<
|
py::class_<
|
||||||
::c10d::ProcessGroup::Work,
|
::c10d::ProcessGroup::Work,
|
||||||
c10::intrusive_ptr<::c10d::ProcessGroup::Work>,
|
c10::intrusive_ptr<::c10d::ProcessGroup::Work>,
|
||||||
|
|
|
||||||
|
|
@ -39,6 +39,7 @@ from .rendezvous import register_rendezvous_handler, rendezvous # noqa: F401
|
||||||
_MPI_AVAILABLE = True
|
_MPI_AVAILABLE = True
|
||||||
_NCCL_AVAILABLE = True
|
_NCCL_AVAILABLE = True
|
||||||
_GLOO_AVAILABLE = True
|
_GLOO_AVAILABLE = True
|
||||||
|
_UCC_AVAILABLE = True
|
||||||
|
|
||||||
_pickler = pickle.Pickler
|
_pickler = pickle.Pickler
|
||||||
_unpickler = pickle.Unpickler
|
_unpickler = pickle.Unpickler
|
||||||
|
|
@ -59,6 +60,11 @@ try:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
_GLOO_AVAILABLE = False
|
_GLOO_AVAILABLE = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
from torch._C._distributed_c10d import ProcessGroupUCC
|
||||||
|
except ImportError:
|
||||||
|
_UCC_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -86,7 +92,7 @@ def supports_complex(reduceOp: ReduceOp) -> bool:
|
||||||
|
|
||||||
class Backend(object):
|
class Backend(object):
|
||||||
"""
|
"""
|
||||||
An enum-like class of available backends: GLOO, NCCL, MPI, and other registered
|
An enum-like class of available backends: GLOO, NCCL, UCC, MPI, and other registered
|
||||||
backends.
|
backends.
|
||||||
|
|
||||||
The values of this class are lowercase strings, e.g., ``"gloo"``. They can
|
The values of this class are lowercase strings, e.g., ``"gloo"``. They can
|
||||||
|
|
@ -105,6 +111,7 @@ class Backend(object):
|
||||||
UNDEFINED = "undefined"
|
UNDEFINED = "undefined"
|
||||||
GLOO = "gloo"
|
GLOO = "gloo"
|
||||||
NCCL = "nccl"
|
NCCL = "nccl"
|
||||||
|
UCC = "ucc"
|
||||||
MPI = "mpi"
|
MPI = "mpi"
|
||||||
TCP = "tcp"
|
TCP = "tcp"
|
||||||
_plugins: Dict[str, Callable] = {}
|
_plugins: Dict[str, Callable] = {}
|
||||||
|
|
@ -122,7 +129,7 @@ class Backend(object):
|
||||||
)
|
)
|
||||||
elif value == Backend.UNDEFINED:
|
elif value == Backend.UNDEFINED:
|
||||||
raise ValueError("Invalid backend: '{}'".format(name))
|
raise ValueError("Invalid backend: '{}'".format(name))
|
||||||
elif value != Backend.GLOO and value != Backend.NCCL and value != Backend.MPI:
|
elif value != Backend.GLOO and value != Backend.NCCL and value != Backend.UCC and value != Backend.MPI:
|
||||||
value = name.lower()
|
value = name.lower()
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
@ -145,9 +152,12 @@ class Backend(object):
|
||||||
.. note:: This support of 3rd party backend is experimental and subject to change.
|
.. note:: This support of 3rd party backend is experimental and subject to change.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
assert not hasattr(Backend, name.upper()), (
|
# Allow UCC plugin if Pytorch is not built with native support.
|
||||||
f"{name.upper()} c10d backend already exist"
|
# TODO: remove this exception once UCC plugin is fully deprecated.
|
||||||
)
|
if (name != Backend.UCC or (name == Backend.UCC and is_ucc_available())):
|
||||||
|
assert not hasattr(Backend, name.upper()), (
|
||||||
|
f"{name.upper()} c10d backend already exist"
|
||||||
|
)
|
||||||
assert name.upper() not in Backend._plugins, (
|
assert name.upper() not in Backend._plugins, (
|
||||||
f"{name.upper()} c10d backend creator function already exist"
|
f"{name.upper()} c10d backend creator function already exist"
|
||||||
)
|
)
|
||||||
|
|
@ -412,6 +422,13 @@ def is_gloo_available():
|
||||||
return _GLOO_AVAILABLE
|
return _GLOO_AVAILABLE
|
||||||
|
|
||||||
|
|
||||||
|
def is_ucc_available():
    """
    Checks if the UCC backend is available.

    Returns the module-level ``_UCC_AVAILABLE`` flag, which is cleared when
    importing ``ProcessGroupUCC`` from ``torch._C._distributed_c10d`` fails.
    """
    available = _UCC_AVAILABLE
    return available
|
||||||
|
|
||||||
|
|
||||||
def is_initialized():
|
def is_initialized():
|
||||||
"""
|
"""
|
||||||
Checking if the default process group has been initialized
|
Checking if the default process group has been initialized
|
||||||
|
|
@ -511,12 +528,13 @@ def init_process_group(
|
||||||
Args:
|
Args:
|
||||||
backend (str or Backend): The backend to use. Depending on
|
backend (str or Backend): The backend to use. Depending on
|
||||||
build-time configurations, valid values include ``mpi``, ``gloo``,
|
build-time configurations, valid values include ``mpi``, ``gloo``,
|
||||||
and ``nccl``. This field should be given as a lowercase string
|
``nccl``, and ``ucc``. This field should be given as a lowercase
|
||||||
(e.g., ``"gloo"``), which can also be accessed via
|
string (e.g., ``"gloo"``), which can also be accessed via
|
||||||
:class:`Backend` attributes (e.g., ``Backend.GLOO``). If using
|
:class:`Backend` attributes (e.g., ``Backend.GLOO``). If using
|
||||||
multiple processes per machine with ``nccl`` backend, each process
|
multiple processes per machine with ``nccl`` backend, each process
|
||||||
must have exclusive access to every GPU it uses, as sharing GPUs
|
must have exclusive access to every GPU it uses, as sharing GPUs
|
||||||
between processes can result in deadlocks.
|
between processes can result in deadlocks. ``ucc`` backend is
|
||||||
|
experimental.
|
||||||
init_method (str, optional): URL specifying how to initialize the
|
init_method (str, optional): URL specifying how to initialize the
|
||||||
process group. Default is "env://" if no
|
process group. Default is "env://" if no
|
||||||
``init_method`` or ``store`` is specified.
|
``init_method`` or ``store`` is specified.
|
||||||
|
|
@ -547,6 +565,9 @@ def init_process_group(
|
||||||
continue executing user code since failed async NCCL operations
|
continue executing user code since failed async NCCL operations
|
||||||
might result in subsequent CUDA operations running on corrupted
|
might result in subsequent CUDA operations running on corrupted
|
||||||
data. Only one of these two environment variables should be set.
|
data. Only one of these two environment variables should be set.
|
||||||
|
For ``ucc``, blocking wait is supported similar to NCCL. However,
|
||||||
|
async error handling is done differently since with UCC we have
|
||||||
|
progress thread and not watch-dog thread.
|
||||||
group_name (str, optional, deprecated): Group name.
|
group_name (str, optional, deprecated): Group name.
|
||||||
pg_options (ProcessGroupOptions, optional): process group options
|
pg_options (ProcessGroupOptions, optional): process group options
|
||||||
specifying what additional options need to be passed in during
|
specifying what additional options need to be passed in during
|
||||||
|
|
@ -682,7 +703,7 @@ def _new_process_group_helper(
|
||||||
is_default_group = len(group_ranks) == 0
|
is_default_group = len(group_ranks) == 0
|
||||||
|
|
||||||
backend = Backend(backend)
|
backend = Backend(backend)
|
||||||
pg: Union[ProcessGroupGloo, ProcessGroupMPI, ProcessGroupNCCL]
|
pg: Union[ProcessGroupGloo, ProcessGroupMPI, ProcessGroupNCCL, ProcessGroupUCC]
|
||||||
if backend == Backend.MPI:
|
if backend == Backend.MPI:
|
||||||
if not is_mpi_available():
|
if not is_mpi_available():
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
|
|
@ -767,6 +788,33 @@ def _new_process_group_helper(
|
||||||
)
|
)
|
||||||
_pg_map[pg] = (Backend.NCCL, store)
|
_pg_map[pg] = (Backend.NCCL, store)
|
||||||
_pg_names[pg] = group_name
|
_pg_names[pg] = group_name
|
||||||
|
elif backend == Backend.UCC and is_ucc_available():
|
||||||
|
# TODO: once UCC plugin is fully deprecated, remove
|
||||||
|
# is_ucc_available() from above elif-condition and raise
|
||||||
|
# RuntimeError if is_ucc_available() returns false.
|
||||||
|
|
||||||
|
pg = ProcessGroupUCC(prefix_store, rank, world_size, timeout=timeout)
|
||||||
|
# In debug mode and if GLOO is available, wrap in a wrapper PG that
|
||||||
|
# enables enhanced collective checking for debugability.
|
||||||
|
if get_debug_level() == DebugLevel.DETAIL:
|
||||||
|
if not _GLOO_AVAILABLE:
|
||||||
|
logger.info(
|
||||||
|
"""TORCH_DISTRIBUTED_DEBUG was set to DETAIL, but
|
||||||
|
GLOO is not available. Build with Gloo to
|
||||||
|
create a wrapper process group in debug mode
|
||||||
|
to aid collective desynchronization debugging."""
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
pg = _create_process_group_wrapper(
|
||||||
|
wrapped_pg=pg,
|
||||||
|
store_prefix=group_name,
|
||||||
|
store=store,
|
||||||
|
rank=rank,
|
||||||
|
world_size=world_size,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
_pg_map[pg] = (Backend.UCC, store)
|
||||||
|
_pg_names[pg] = group_name
|
||||||
else:
|
else:
|
||||||
assert backend.upper() in Backend._plugins, (
|
assert backend.upper() in Backend._plugins, (
|
||||||
f"unknown c10d backend type {backend.upper()}"
|
f"unknown c10d backend type {backend.upper()}"
|
||||||
|
|
@ -1064,7 +1112,7 @@ def batch_isend_irecv(p2p_op_list):
|
||||||
Send or Receive a batch of tensors asynchronously and return a list of requests.
|
Send or Receive a batch of tensors asynchronously and return a list of requests.
|
||||||
|
|
||||||
Process each of the operations in ``p2p_op_list`` and return the corresponding
|
Process each of the operations in ``p2p_op_list`` and return the corresponding
|
||||||
requests. NCCL and Gloo backend are currently supported.
|
requests. NCCL, Gloo, and UCC backend are currently supported.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
p2p_op_list: A list of point-to-point operations(type of each operator is
|
p2p_op_list: A list of point-to-point operations(type of each operator is
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user