[PFC] Native UCC process group for PyTorch (#79918)

Summary:
This diff integrates the UCC process group as a native component of PyTorch Distributed core. It is based on the existing torch-ucc (https://github.com/facebookresearch/torch_ucc), which wraps the UCC collective communication library.
The environment and CMake variables are named to mirror those of the existing process groups such as NCCL and Gloo. Specifically,
- USE_UCC: enables the UCC process group. It defaults to OFF, so existing builds that do not have the external UCX/UCC libraries are unaffected.
- USE_SYSTEM_UCC: uses external UCX and UCC shared libraries, located via UCX_HOME and UCC_HOME.

Currently, this diff only supports USE_SYSTEM_UCC=ON, i.e., users must provide external UCX and UCC libraries. In subsequent diffs, we will add the UCX and UCC repos as third-party dependencies in pytorch/third-party.

Test Plan:
Passed the Torch-UCC tests that invoke the UCC process group. For example:

$ sh test/start_test.sh test/torch_allreduce_test.py --backend gloo --use-cuda
...
Test allreduce: succeeded
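
For context, a minimal sketch of the kind of collective such a test exercises. This is illustrative only (the script name, tensor shape, and launcher are assumptions, not part of this diff); it presumes a multi-process launch, e.g. via torchrun, with RANK/WORLD_SIZE and the rendezvous variables set:

# allreduce_ucc_sketch.py -- hypothetical example, not part of this diff
import os
import torch
import torch.distributed as dist

rank = int(os.environ["RANK"])              # provided by the launcher
world_size = int(os.environ["WORLD_SIZE"])

# "ucc" is the backend string this diff registers (Backend.UCC).
dist.init_process_group(backend="ucc", rank=rank, world_size=world_size)

# Each rank contributes its rank id; after the allreduce every rank
# holds 0 + 1 + ... + (world_size - 1) in each element.
t = torch.full((4,), float(rank))
dist.all_reduce(t, op=dist.ReduceOp.SUM)
assert t[0].item() == world_size * (world_size - 1) / 2

dist.destroy_process_group()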

Differential Revision: D36973688

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79918
Approved by: https://github.com/kwen2501, https://github.com/kingchc
Terry Lam 2022-07-12 14:45:44 +00:00 committed by PyTorch MergeBot
parent 268d910170
commit 54bdaf76d6
16 changed files with 3152 additions and 10 deletions

View File

@ -308,6 +308,14 @@ option(USE_DISTRIBUTED "Use distributed" ON)
cmake_dependent_option(
USE_MPI "Use MPI for Caffe2. Only available if USE_DISTRIBUTED is on." ON
"USE_DISTRIBUTED" OFF)
cmake_dependent_option(
USE_UCC "Use UCC. Only available if USE_DISTRIBUTED is on." OFF
"USE_DISTRIBUTED" OFF)
cmake_dependent_option(
USE_SYSTEM_UCC "Use system-wide UCC" OFF
"USE_UCC" OFF)
cmake_dependent_option(
USE_C10D_UCC "USE C10D UCC" ON "USE_DISTRIBUTED;USE_UCC" OFF)
cmake_dependent_option(
USE_GLOO "Use Gloo. Only available if USE_DISTRIBUTED is on." ON
"USE_DISTRIBUTED" OFF)

View File

@ -744,6 +744,9 @@ libtorch_cuda_distributed_base_sources = [
libtorch_cuda_distributed_extra_sources = [
"torch/csrc/distributed/c10d/NCCLUtils.cpp",
"torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp",
"torch/csrc/distributed/c10d/ProcessGroupUCC.cpp",
"torch/csrc/distributed/c10d/UCCTracing.cpp",
"torch/csrc/distributed/c10d/UCCUtils.cpp",
"torch/csrc/distributed/rpc/tensorpipe_cuda.cpp", "torch/csrc/distributed/rpc/tensorpipe_cuda.cpp",
"torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
] ]

View File

@ -911,6 +911,11 @@ if(HAVE_SOVERSION)
VERSION ${TORCH_VERSION} SOVERSION ${TORCH_SOVERSION})
endif()
if(USE_UCC)
target_link_libraries(torch_cpu PRIVATE __caffe2_ucc)
target_compile_definitions(torch_cpu PRIVATE USE_UCC)
endif()
if(USE_ROCM)
filter_list(__caffe2_hip_srcs_cpp Caffe2_HIP_SRCS "\\.(cu|hip)$")
set_source_files_properties(${__caffe2_hip_srcs_cpp} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
@ -972,6 +977,13 @@ elseif(USE_CUDA)
target_link_libraries(torch_cuda PRIVATE __caffe2_nccl)
target_compile_definitions(torch_cuda PRIVATE USE_NCCL)
endif()
if(USE_UCC AND BUILD_SPLIT_CUDA)
target_link_libraries(torch_cuda_cpp PRIVATE __caffe2_ucc)
target_compile_definitions(torch_cuda_cpp PRIVATE USE_UCC)
elseif(USE_UCC)
target_link_libraries(torch_cuda PRIVATE __caffe2_ucc)
target_compile_definitions(torch_cuda PRIVATE USE_UCC)
endif()
if(BUILD_LAZY_CUDA_LINALG)
add_library(torch_cuda_linalg ${ATen_CUDA_LINALG_SRCS})
target_compile_definitions(torch_cuda_linalg PRIVATE USE_CUDA BUILD_LAZY_CUDA_LINALG)
@ -1347,6 +1359,16 @@ if(USE_DISTRIBUTED)
if(USE_GLOO AND USE_C10D_GLOO)
target_compile_definitions(torch_cpu PUBLIC USE_C10D_GLOO)
endif()
if(USE_UCC AND USE_C10D_UCC)
target_compile_definitions(torch_cpu PUBLIC USE_C10D_UCC)
if(USE_CUDA)
if(BUILD_SPLIT_CUDA)
target_compile_definitions(torch_cuda_cpp PUBLIC USE_C10D_UCC)
else()
target_compile_definitions(torch_cuda PUBLIC USE_C10D_UCC)
endif()
endif()
endif()
if(USE_NCCL AND USE_C10D_NCCL)
if(USE_ROCM)
target_compile_definitions(torch_hip PUBLIC USE_C10D_NCCL)

View File

@ -1362,6 +1362,16 @@ if(USE_NCCL)
endif()
endif()
# ---[ UCC
if(USE_UCC)
if(NOT CMAKE_SYSTEM_NAME STREQUAL "Linux")
message(WARNING "UCC is currently only supported under Linux.")
caffe2_update_option(USE_UCC OFF)
else()
include(${CMAKE_CURRENT_LIST_DIR}/External/ucc.cmake)
endif()
endif()
# ---[ CUB
if(USE_CUDA)
find_package(CUB)

cmake/External/ucc.cmake (vendored, new file)
View File

@ -0,0 +1,20 @@
if(NOT __UCC_INCLUDED)
set(__UCC_INCLUDED TRUE)
if(USE_SYSTEM_UCC)
set(UCX_HOME $ENV{UCX_HOME} CACHE PATH "UCX install directory")
set(UCC_HOME $ENV{UCC_HOME} CACHE PATH "UCC install directory")
add_library(__caffe2_ucc INTERFACE)
target_include_directories(__caffe2_ucc INTERFACE ${UCX_HOME}/include/)
target_include_directories(__caffe2_ucc INTERFACE ${UCC_HOME}/include/)
target_link_libraries(__caffe2_ucc INTERFACE ${UCX_HOME}/lib/libucp.so)
target_link_libraries(__caffe2_ucc INTERFACE ${UCX_HOME}/lib/libucs.so)
target_link_libraries(__caffe2_ucc INTERFACE ${UCC_HOME}/lib/libucc.so)
else()
message(FATAL_ERROR "USE_SYSTEM_UCC=OFF is not supported yet when using UCC")
endif()
endif()

View File

@ -146,6 +146,10 @@ function(caffe2_print_configuration_summary)
message(STATUS " USE_MKLDNN_ACL : ${USE_MKLDNN_ACL}") message(STATUS " USE_MKLDNN_ACL : ${USE_MKLDNN_ACL}")
message(STATUS " USE_MKLDNN_CBLAS : ${USE_MKLDNN_CBLAS}") message(STATUS " USE_MKLDNN_CBLAS : ${USE_MKLDNN_CBLAS}")
endif() endif()
message(STATUS " USE_UCC : ${USE_UCC}")
if(${USE_UCC})
message(STATUS " USE_SYSTEM_UCC : ${USE_SYSTEM_UCC}")
endif()
message(STATUS " USE_NCCL : ${USE_NCCL}") message(STATUS " USE_NCCL : ${USE_NCCL}")
if(${USE_NCCL}) if(${USE_NCCL})
message(STATUS " USE_SYSTEM_NCCL : ${USE_SYSTEM_NCCL}") message(STATUS " USE_SYSTEM_NCCL : ${USE_SYSTEM_NCCL}")

View File

@ -251,6 +251,9 @@ if(USE_DISTRIBUTED)
if(USE_NCCL)
list(APPEND TORCH_PYTHON_LINK_LIBRARIES __caffe2_nccl)
endif()
if(USE_UCC)
list(APPEND TORCH_PYTHON_LINK_LIBRARIES __caffe2_ucc)
endif()
# Same for MPI.
if(USE_MPI)
list(APPEND TORCH_PYTHON_LINK_LIBRARIES ${MPI_CXX_LIBRARIES})
@ -284,6 +287,9 @@ if(USE_DEPLOY)
if(USE_GLOO AND USE_C10D_GLOO)
list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D_GLOO)
endif()
if(USE_UCC AND USE_C10D_UCC)
list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D_UCC)
endif()
if(USE_NCCL AND USE_C10D_NCCL)
list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D_NCCL)
# Put nccl headers on the include path. We are specifically only setting

View File

@ -371,6 +371,15 @@ class ProcessGroupNCCL(ProcessGroup):
def _group_end() -> None: ...
...
class ProcessGroupUCC(ProcessGroup):
def __init__(
self,
store: Store,
rank: int,
size: int,
timeout: timedelta,
): ...
class ProcessGroupMPI(ProcessGroup):
def __init__(
self,

File diff suppressed because it is too large.

View File

@ -0,0 +1,407 @@
#pragma once
#ifdef USE_C10D_UCC
#include <c10d/UCCUtils.hpp>
#include <exception>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>
#include <c10d/ProcessGroup.hpp>
#include <c10d/Store.hpp>
#include <c10d/Types.hpp>
#include <c10d/Utils.hpp>
#ifdef USE_CUDA
#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAStream.h>
#endif
namespace c10d {
#define TORCH_UCC_DEVICE_NOT_SET -2
#define TORCH_UCX_MAKE_P2P_TAG(_tag, _rank, _comm) \
((((uint64_t)(_tag)) << TORCH_UCX_TAG_BITS_OFFSET) | \
(((uint64_t)(_rank)) << TORCH_UCX_RANK_BITS_OFFSET) | \
(((uint64_t)(_comm)) << TORCH_UCX_COMM_BITS_OFFSET))
#define TORCH_UCX_MAKE_OOB_TAG(_tag, _rank, _comm) \
((((uint64_t)(_tag)) << TORCH_UCX_OOB_BITS_OFFSET) | \
(((uint64_t)(_rank)) << TORCH_UCX_RANK_BITS_OFFSET) | \
(((uint64_t)(_rank)) << TORCH_UCX_COMM_BITS_OFFSET))
#define TORCH_UCX_MAKE_SEND_TAG(_ucp_tag, _tag, _rank, _comm) \
do { \
(_ucp_tag) = TORCH_UCX_MAKE_P2P_TAG((_tag), (_rank), (_comm)); \
} while (0)
#define TORCH_UCX_ANY_SOURCE (TORCH_UCX_MAX_RANK - 1)
#define TORCH_UCX_ANY_SOURCE_MASK (~TORCH_UCX_RANK_MASK)
#define TORCH_UCX_SPECIFIC_SOURCE_MASK ((uint64_t)-1)
#define TORCH_UCX_MAKE_RECV_TAG(_ucp_tag, _ucp_tag_mask, _tag, _rank, _comm) \
do { \
(_ucp_tag) = TORCH_UCX_MAKE_P2P_TAG((_tag), (_rank), (_comm)); \
if ((_rank) == TORCH_UCX_ANY_SOURCE) { \
(_ucp_tag_mask) = TORCH_UCX_ANY_SOURCE_MASK; \
} else { \
(_ucp_tag_mask) = TORCH_UCX_SPECIFIC_SOURCE_MASK; \
} \
} while (0)
#define TORCH_UCX_MAKE_OOB_SEND_TAG(_ucp_tag, _tag, _rank, _comm) \
do { \
(_ucp_tag) = TORCH_UCX_MAKE_OOB_TAG((_tag), (_rank), (_comm)); \
} while (0)
#define TORCH_UCX_MAKE_OOB_RECV_TAG( \
_ucp_tag, _ucp_tag_mask, _tag, _rank, _comm) \
do { \
(_ucp_tag) = TORCH_UCX_MAKE_OOB_TAG((_tag), (_rank), (_comm)); \
(_ucp_tag_mask) = (uint64_t)-1; \
} while (0)
#ifdef USE_CUDA
#define SAVE_TENSORS(_TENSORS, _DATA) \
do { \
if ((_TENSORS)[0].device().is_cuda()) { \
for (const auto i : c10::irange((_TENSORS).size())) { \
c10::cuda::CUDACachingAllocator::recordStream( \
(_TENSORS)[i].storage().data_ptr(), (*stream)); \
} \
} else { \
(_DATA) = (_TENSORS); \
} \
} while (0)
#else
#define SAVE_TENSORS(_TENSORS, _DATA) (_DATA) = (_TENSORS);
#endif
constexpr const char* UCC_BACKEND_NAME = "ucc";
enum torch_ucx_tag_type_t { TORCH_UCX_P2P_TAG, TORCH_UCX_OOB_TAG };
struct event_pool_t {
#ifdef USE_CUDA
std::queue<std::unique_ptr<at::cuda::CUDAEvent>> event_pool;
#endif
std::mutex event_pool_mutex;
};
class Comm;
// UCC does not support multiple CUDA devices per process.
class TORCH_API ProcessGroupUCC : public ProcessGroup {
private:
void set_timeout(ucc_coll_args_t& args);
public:
class WorkData {
public:
std::vector<at::Tensor> src;
std::vector<at::Tensor> dst;
std::vector<at::Tensor> flat;
WorkData() {}
virtual ~WorkData() = default;
};
class AlltoallWorkData : public WorkData {
public:
AlltoallWorkData(int size)
: send_lengths(size),
send_offsets(size),
recv_lengths(size),
recv_offsets(size) {}
std::vector<uint64_t> send_lengths;
std::vector<uint64_t> send_offsets;
std::vector<uint64_t> recv_lengths;
std::vector<uint64_t> recv_offsets;
};
class AllgathervWorkData : public WorkData {
public:
AllgathervWorkData(int size) : recv_lengths(size), recv_offsets(size) {}
std::vector<uint64_t> recv_lengths;
std::vector<uint64_t> recv_offsets;
};
class ScattervWorkData : public WorkData {
public:
ScattervWorkData(int size) : send_lengths(size), send_offsets(size) {}
std::vector<uint64_t> send_lengths;
std::vector<uint64_t> send_offsets;
};
class ProgressEntry {
friend class ProcessGroupUCC;
friend class Comm;
public:
ProgressEntry(CommBase* comm, ucc_coll_req_h request)
: status_(UCC_INPROGRESS), comm_(comm), request_(request) {}
// Finalizes UCC status or exception of collective request.
void finalize(std::exception_ptr eptr = nullptr);
ucc_status_t status_;
CommBase* comm_;
ucc_coll_req_h request_;
std::unique_ptr<WorkData> data;
c10::intrusive_ptr<c10::ivalue::Future> future_;
std::exception_ptr eptr_;
};
class WorkUCC : public ProcessGroup::Work {
friend class ProcessGroupUCC;
friend class Comm;
public:
WorkUCC(OpType opType, const char* prof_title)
: ProcessGroup::Work(-1, opType, prof_title) {}
WorkUCC(
OpType opType,
const char* prof_title,
const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger)
: ProcessGroup::Work(-1, opType, prof_title), logger_(logger) {}
~WorkUCC();
void setException();
void setAndThrowException();
bool isCompleted() override;
bool isSuccess() const override;
bool wait(std::chrono::milliseconds timeout = kUnsetTimeout) override;
c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;
std::vector<at::Tensor> result() override;
#ifdef USE_CUDA
std::unique_ptr<at::cuda::CUDAEvent> fence = nullptr;
event_pool_t* ep = nullptr;
#endif
protected:
std::shared_ptr<ProgressEntry> entry_;
c10::intrusive_ptr<ProcessGroupUCCLogger> logger_;
private:
// The future returned by getFuture.
c10::intrusive_ptr<at::ivalue::Future> future_;
// Store a reference to collective's outputs, used by result
std::shared_ptr<std::vector<at::Tensor>> outputs_;
};
explicit ProcessGroupUCC(
const c10::intrusive_ptr<Store>& store,
int rank = -1,
int size = -1,
std::chrono::duration<float> timeout = kProcessGroupDefaultTimeout);
void initComm(c10::Device dev);
~ProcessGroupUCC() override;
const std::string getBackendName() const override {
return std::string(UCC_BACKEND_NAME);
}
#ifdef USE_CUDA
std::unique_ptr<at::cuda::CUDAEvent> getPooledEvent();
#endif
// Performs a health check by initializing dummy UCC & UCX communicators and
// then destroying them. This will help indicate and signal any
// UCC/UCX-related issues prior to the first collective. The actual
// initialization and subsequent destruction are run on a separate thread, and
// the main thread is signalled about timeouts/errors to report to the
// application.
void runHealthCheck();
template <typename PreProcess, typename PostProcess>
c10::intrusive_ptr<ProcessGroup::Work> collective_post(
OpType opType,
PreProcess preproc,
PostProcess postproc,
ucc_coll_args_t& coll,
std::unique_ptr<ProcessGroupUCC::WorkData> data,
c10::Device dev,
std::vector<at::Tensor>& inputTensors,
std::vector<at::Tensor>& outputTensors,
const char* prof_title);
c10::intrusive_ptr<ProcessGroup::Work> broadcast(
std::vector<at::Tensor>& data,
const BroadcastOptions& opts = BroadcastOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> allreduce(
std::vector<at::Tensor>& tensors,
const AllreduceOptions& opts = AllreduceOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> allreduce_coalesced(
std::vector<at::Tensor>& tensors,
const AllreduceCoalescedOptions& opts =
AllreduceCoalescedOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> reduce(
std::vector<at::Tensor>& tensors,
const ReduceOptions& opts = ReduceOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> allgather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& opts = AllgatherOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> _allgather_base(
at::Tensor& outputBuffer,
at::Tensor& inputBuffer,
const AllgatherOptions& opts = AllgatherOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> barrier(
const BarrierOptions& opts = BarrierOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> gather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const GatherOptions& opts = GatherOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ScatterOptions& opts = ScatterOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> reduce_scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ReduceScatterOptions& opts = ReduceScatterOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> alltoall_base(
at::Tensor& outputTensor,
at::Tensor& inputTensor,
std::vector<int64_t>& outputSplitSizes,
std::vector<int64_t>& inputSplitSizes,
const AllToAllOptions& opts = AllToAllOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> alltoall(
std::vector<at::Tensor>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllToAllOptions& opts = AllToAllOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> send(
std::vector<at::Tensor>& tensors,
int dstRank,
int tag) override;
c10::intrusive_ptr<ProcessGroup::Work> recv(
std::vector<at::Tensor>& tensors,
int srcRank,
int tag) override;
c10::intrusive_ptr<ProcessGroup::Work> recvAnysource(
std::vector<at::Tensor>& tensors,
int tag) override;
static c10::intrusive_ptr<ProcessGroup> createProcessGroupUCC(
const c10::intrusive_ptr<::c10d::Store>& store,
int rank,
int size,
const std::chrono::duration<float>& timeout);
protected:
const std::chrono::duration<float> timeout_;
std::shared_ptr<torch_ucc_oob_coll_info_t> oob;
std::shared_ptr<Comm> comm = {nullptr};
uint32_t comm_id;
std::vector<ucp_ep_h> eps;
ucc_team_h team{nullptr};
ucc_ee_h cuda_ee{nullptr};
#ifdef USE_CUDA
std::unique_ptr<at::cuda::CUDAStream> stream = nullptr;
event_pool_t ep;
#endif
c10::intrusive_ptr<ProcessGroupUCCLogger> logger;
};
class Comm {
c10::intrusive_ptr<ProcessGroupUCCLogger> logger;
std::shared_ptr<torch_ucc_oob_coll_info_t> oob;
CommUCX ucx_comm;
CommUCC ucc_comm;
std::mutex mutex;
std::thread progress_thread;
std::condition_variable queue_produce_cv;
std::condition_variable queue_consume_cv;
std::deque<std::shared_ptr<ProcessGroupUCC::ProgressEntry>> progress_queue;
bool stop_progress_loop;
bool collective_inprogress;
torch_ucc_phase_t finalize_phase;
public:
c10::DeviceIndex cuda_device_index;
Comm(
const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger,
std::shared_ptr<torch_ucc_oob_coll_info_t> oob,
c10::Device dev,
bool is_health_check);
~Comm();
// Connects UCX end points.
void ucx_connect_eps(
std::vector<ucp_ep_h>& eps,
std::shared_ptr<torch_ucc_oob_coll_info_t> oob);
// Disconnects UCX end points.
void ucx_disconnect_eps(
std::vector<ucp_ep_h>& eps,
std::shared_ptr<torch_ucc_oob_coll_info_t> oob);
void ucc_create_team(
ucc_team_h& team,
std::shared_ptr<torch_ucc_oob_coll_info_t> oob);
void ucc_destroy_team(ucc_team_h& team);
c10::intrusive_ptr<ProcessGroup::Work> enqueue_p2p(
OpType opType,
ucc_coll_req_h request,
const char* prof_title);
#ifdef USE_CUDA
void enqueue_cuda_collective(
std::unique_ptr<ProcessGroupUCC::WorkData> data,
c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> work,
ucc_coll_args_t& coll,
ucc_team_h team,
ucc_ee_h ee);
#endif
void enqueue_collective(
std::unique_ptr<ProcessGroupUCC::WorkData> data,
c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> work,
ucc_coll_args_t& coll,
ucc_team_h team);
static std::shared_ptr<Comm> get_comm(
uint32_t& id,
c10::Device dev,
std::shared_ptr<torch_ucc_oob_coll_info_t> oob,
const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger,
bool is_health_check = false);
void progress_loop();
ucc_coll_req_h send_nb(
ucp_ep_h ep,
void* data,
ucs_memory_type_t mtype,
size_t size,
ucp_tag_t ucp_tag);
ucc_coll_req_h recv_nb(
void* data,
ucs_memory_type_t mtype,
size_t size,
ucp_tag_t ucp_tag,
ucp_tag_t ucp_tag_mask);
};
} // namespace c10d
#endif // USE_C10D_UCC
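
Every collective above returns a WorkUCC handle whose completion is driven by the progress thread. From Python the generic Work API applies; a short sketch, assuming a UCC process group has already been initialized:

import torch
import torch.distributed as dist

# Assumes dist.init_process_group(backend="ucc", ...) has already run.
t = torch.ones(8)
work = dist.all_reduce(t, op=dist.ReduceOp.SUM, async_op=True)

work.wait()                      # maps to WorkUCC::wait()
fut = work.get_future()          # maps to WorkUCC::getFuture()
assert work.is_completed() and fut.done()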

View File

@ -0,0 +1,176 @@
#ifdef USE_C10D_UCC
#include <c10d/UCCTracing.hpp>
#include <c10d/UCCUtils.hpp>
#include <c10d/ParamCommsUtils.hpp>
#include <sys/stat.h>
#include <cstdlib>
#include <ctime>
#include <fstream>
#ifdef FBCODE_CAFFE2
#include <c10d/UCCInternalUtils.hpp>
#endif
namespace c10d {
void ProcessGroupUCCLogger::initCommsTracer() {
trace_generator = std::make_shared<CommTraceLogger>();
initialized_CommTraceLogger = true;
}
void ProcessGroupUCCLogger::flushComms(int rank, int world_size) {
if (!initialized_CommTraceLogger ||
trace_generator->getCommsTrace().empty()) {
return;
}
std::string dirname = c10::str("ProcessGroupUCC_trace_np", world_size);
time_t now_ = time(0);
std::tm* ltm = localtime(&now_);
if (ltm) {
dirname += c10::str(
"_", (1 + ltm->tm_mon), "_", ltm->tm_mday, "_", (1900 + ltm->tm_year));
}
std::string fullpath = "/tmp/" + dirname;
char* user_path = std::getenv("TORCH_UCC_COMMS_TRACE_OUTPUT_DIR");
if (user_path) {
fullpath = user_path;
}
std::string trace_filename = c10::str(fullpath, "/rank", rank, ".json");
std::ofstream _outfile;
if (!_outfile.is_open()) {
if (mkdir(fullpath.c_str(), 0777) != 0) {
LOG(INFO) << getLogPrefix() << "[INFO] failed to mkdir " << fullpath;
if (errno != EEXIST) {
return;
}
}
_outfile.open(trace_filename, std::ofstream::out | std::ofstream::trunc);
}
// flush the traced comms
if (_outfile.is_open()) {
_outfile << "[" << c10::Join(",", trace_generator->getCommsTrace())
<< "\n]";
_outfile.flush();
_outfile.close();
}
#ifdef FBCODE_CAFFE2
uploadTrace_internal(
trace_filename, dirname, c10::str("rank", rank, ".json"));
#endif
}
/* unused */
void CommTraceLogger::setCurBlock(const std::string& name) {
curBlocks_.push_back(
c10::str("\"", name, "\"")); // add quote marks for JSON format
}
/* unused */
void CommTraceLogger::popBlock() {
// TODO: remove specific name
curBlocks_.pop_back();
}
void CommTraceLogger::recordOptionalInfo(int root) {
curRoot_ = root;
}
void CommTraceLogger::recordOptionalInfo(
const std::vector<int64_t>& outputSplitSizes,
const std::vector<int64_t>& inputSplitSizes) {
curOutSplitSizes_ = outputSplitSizes;
curInSplitSizes_ = inputSplitSizes;
}
void CommTraceLogger::recordComms(
const std::string& commName,
const uintptr_t workReq,
const int rank,
const int world_size,
const std::vector<at::Tensor>& inputTensors,
const std::vector<at::Tensor>& outputTensors) {
auto inSize = (!inputTensors.empty()) ? inputTensors[0].numel() : 0;
auto outSize = (!outputTensors.empty()) ? outputTensors[0].numel() : 0;
auto dtype =
(!outputTensors.empty()) ? outputTensors[0].scalar_type() : at::kByte;
auto devType = (!outputTensors.empty()) ? outputTensors[0].device().type()
: c10::DeviceType::CPU;
auto now = std::chrono::system_clock::now();
static auto startTS = now;
int64_t time_since_begin =
std::chrono::duration_cast<std::chrono::nanoseconds>(now - startTS)
.count();
// TODO: get markers from torch profiler if enabled
// common fields for all operations
std::string cur_trace_ = c10::str(
"\n\t\t\"markers\": [",
curBlocks_,
"]",
",\n\t\t\"startTime_ns\": ",
time_since_begin,
",\n\t\t\"comms\": \"",
commName,
"\"",
",\n\t\t\"req\": ",
workReq,
",\n\t\t\"seqnum\": ",
seqnum++,
",\n\t\t\"world_size\": ",
world_size);
if (inSize > 0 || outSize > 0) {
// for most collectives - append msg sizes, data type, device type
cur_trace_ = c10::str(
cur_trace_,
",\n\t\t\"in_msg_size\": ",
inSize,
",\n\t\t\"out_msg_size\": ",
outSize,
",\n\t\t\"dtype\": \"",
at::toString(dtype),
"\",\n\t\t\"devType\": \"",
c10::DeviceTypeName(devType),
"\"");
}
if (curRoot_ != -1) {
// append root rank if applicable, e.g., broadcast, gather, scatter
cur_trace_ = c10::str(cur_trace_, ",\n\t\t\"root\": ", curRoot_);
}
if (!curInSplitSizes_.empty() || !curOutSplitSizes_.empty()) {
// append input and output splits if applicable, e.g., ALLTOALL_BASE
cur_trace_ = c10::str(
cur_trace_,
",\n\t\t\"in_split\": [",
c10::Join(",", curInSplitSizes_),
"]"
",\n\t\t\"out_split\": [",
c10::Join(",", curOutSplitSizes_),
"]");
}
comms_trace_.push_back(c10::str("\n\t{", cur_trace_, "\n\t}"));
// record the trace to kineto trace if applicable
RECORD_PARAM_COMMS(
rank,
commName.c_str(),
inSize,
outSize,
dtype,
curInSplitSizes_,
curOutSplitSizes_);
// reset optional field
curRoot_ = -1;
curInSplitSizes_ = {};
curOutSplitSizes_ = {};
}
} // namespace c10d
#endif // USE_C10D_UCC
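
The tracer writes one JSON array per rank, either under a date-stamped directory in /tmp or under TORCH_UCC_COMMS_TRACE_OUTPUT_DIR when that variable is set. A hedged sketch for inspecting a flushed trace offline (it assumes the comms logger was enabled via the UCC config and that the output directory was overridden):

import json
import os

# Matches the file layout produced by flushComms() above.
trace_dir = os.environ.get("TORCH_UCC_COMMS_TRACE_OUTPUT_DIR", "/tmp/ucc_trace")
with open(os.path.join(trace_dir, "rank0.json")) as f:
    records = json.load(f)

for rec in records:
    # Fields emitted by recordComms(): comms, seqnum, world_size,
    # plus message sizes and dtype for most collectives.
    print(rec["seqnum"], rec["comms"], rec.get("in_msg_size"), rec.get("out_msg_size"))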

View File

@ -0,0 +1,58 @@
#pragma once
#ifdef USE_C10D_UCC
#include <c10d/UCCUtils.hpp>
namespace c10d {
#define RECORD_COMMS_TRACE( \
_comms_tracer, _work, _opType, _rank, _comm_size, _inTensors, _outTensors) \
do { \
if (torch_ucc_config.enable_comms_logger) { \
_comms_tracer->recordComms( \
opTypeToString(_opType), \
(uintptr_t)_work.get(), \
_rank, \
_comm_size, \
_inTensors, \
_outTensors); \
} \
} while (0)
// interfaces to collect communication traces
class TORCH_API CommTraceLogger : public torch::CustomClassHolder {
private:
std::vector<std::string> comms_trace_;
std::vector<std::string> curBlocks_; /* unused */
std::vector<int64_t> curOutSplitSizes_;
std::vector<int64_t> curInSplitSizes_;
int curRoot_ = -1;
unsigned long seqnum = 0;
public:
void setCurBlock(const std::string& name); /* unused */
void popBlock(); /* unused */
// record root info if applicable, e.g., broadcast, gather, scatter
void recordOptionalInfo(int root = -1);
// record input/output splits of Alltoallv
void recordOptionalInfo(
const std::vector<int64_t>& outputSplitSizes = {},
const std::vector<int64_t>& inputSplitSizes = {});
// record essential comms information
void recordComms(
const std::string& collName,
const uintptr_t workReq = 0,
const int rank = -1,
const int world_size = -1,
const std::vector<at::Tensor>& inputTensors = {},
const std::vector<at::Tensor>& outputTensor = {});
// return collected comms traces
std::vector<std::string>& getCommsTrace() {
return comms_trace_;
}
};
} // namespace c10d
#endif // USE_C10D_UCC

View File

@ -0,0 +1,285 @@
#ifdef USE_C10D_UCC
#include <c10d/UCCTracing.hpp>
#include <c10d/UCCUtils.hpp>
namespace c10d {
namespace {
// Constants for store keys.
constexpr char kTeamRank[] = "teamr";
constexpr char kAllGatherDone[] = "ag_done";
constexpr char kAllGatherFree[] = "ag_free";
} // namespace
CommUCX::CommUCX(
int comm_size,
const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger)
: CommBase(logger) {
ucp_params_t params;
ucp_config_t* config;
ucs_status_t st;
ucp_worker_params_t worker_params;
ucp_lib_attr_t ucp_attr;
ucp_attr.field_mask = UCP_LIB_ATTR_FIELD_MAX_THREAD_LEVEL;
TORCH_UCX_CHECK(
ucp_lib_query(&ucp_attr), "failed to query UCP lib attributes");
TORCH_CHECK(
ucp_attr.max_thread_level == UCS_THREAD_MODE_MULTI,
"ucx library wasn't initialized with multithreading support, "
"please check ucx build options");
TORCH_UCX_CHECK(
ucp_config_read("TORCH", nullptr, &config), "failed to read UCP config");
memset(&params, 0, sizeof(ucp_params_t));
params.field_mask = UCP_PARAM_FIELD_FEATURES | UCP_PARAM_FIELD_REQUEST_SIZE |
UCP_PARAM_FIELD_ESTIMATED_NUM_EPS | UCP_PARAM_FIELD_TAG_SENDER_MASK |
UCP_PARAM_FIELD_REQUEST_INIT | UCP_PARAM_FIELD_REQUEST_CLEANUP;
params.request_size = sizeof(ucc_coll_req_t);
params.features = UCP_FEATURE_TAG;
params.estimated_num_eps = comm_size;
params.tag_sender_mask = TORCH_UCX_RANK_MASK;
params.request_init = [](void* request) {
static_cast<ucc_coll_req_h>(request)->status = UCC_INPROGRESS;
};
params.request_cleanup = [](void*) {};
TORCH_UCX_CHECK(
ucp_init(&params, config, &context), "failed to init UCP context");
ucp_config_release(config);
memset(&worker_params, 0, sizeof(ucp_worker_params_t));
worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE;
worker_params.thread_mode = UCS_THREAD_MODE_MULTI;
st = ucp_worker_create(context, &worker_params, &worker);
if (st != UCS_OK) {
TORCH_UCC_LOG_ERROR(
TORCH_UCC_INIT,
c10::str("UCX failed to create UCP worker:", ucs_status_string(st)));
ucp_cleanup(context);
throw std::runtime_error(ucs_status_string(st));
}
}
void CommUCX::progress() {
ucp_worker_progress(worker);
}
void CommUCX::free_request(ucc_coll_req_h request) {
request->status = UCC_INPROGRESS;
ucp_request_free(request);
}
CommUCX::~CommUCX() {
if (worker != nullptr) {
ucp_worker_destroy(worker);
}
if (context != nullptr) {
ucp_cleanup(context);
}
worker = nullptr;
context = nullptr;
}
ucc_status_t oob_allgather(
void* sbuf,
void* rbuf,
size_t msglen,
void* coll_info,
void** req) {
auto* info = reinterpret_cast<torch_ucc_oob_coll_info_t*>(coll_info);
TORCH_CHECK(info != nullptr);
std::vector<uint8_t> val = std::vector<uint8_t>(
reinterpret_cast<uint8_t*>(sbuf),
reinterpret_cast<uint8_t*>(sbuf) + msglen);
try {
info->store->set(info->getKey(kTeamRank + std::to_string(info->rank)), val);
info->rbuf = rbuf;
info->msglen = msglen;
*req = coll_info;
} catch (std::exception& ex) {
LOG(ERROR) << "(oob_allgather) Caught exception in Store Operation .. "
<< "[" << ex.what() << "]";
return UCC_ERR_NO_MESSAGE;
}
return UCC_OK;
}
ucc_status_t oob_allgather_test(void* req) {
auto* info = reinterpret_cast<torch_ucc_oob_coll_info_t*>(req);
TORCH_CHECK(info != nullptr);
try {
for (int r = 0; r < info->size; r++) {
if (!info->store->check({info->getKey(kTeamRank + std::to_string(r))})) {
return UCC_INPROGRESS;
}
}
for (int r = 0; r < info->size; r++) {
std::vector<uint8_t> data =
info->store->get(info->getKey(kTeamRank + std::to_string(r)));
memcpy(
(void*)((ptrdiff_t)info->rbuf + info->msglen * r),
data.data(),
info->msglen);
}
} catch (std::exception& ex) {
LOG(ERROR) << "(oob_allgather) Caught exception in Store Operation .. "
<< "[" << ex.what() << "]";
return UCC_ERR_NO_MESSAGE;
}
return UCC_OK;
}
ucc_status_t oob_allgather_free(void* req) {
auto* info = reinterpret_cast<torch_ucc_oob_coll_info_t*>(req);
TORCH_CHECK(info != nullptr);
try {
int num_done = info->store->add({info->getKey(kAllGatherDone)}, 1);
if (num_done == info->size) {
info->store->deleteKey(info->getKey(kAllGatherDone));
// Note: to avoid race condition, it's important to remove all keys in
// oob_allgather_free first and only after that signal completion to
// other ranks
for (const auto r : c10::irange(info->size)) {
info->store->deleteKey(info->getKey(kTeamRank + std::to_string(r)));
}
for (const auto r : c10::irange(info->size)) {
info->store->add({info->getKey(kAllGatherFree + std::to_string(r))}, 1);
}
} else {
info->store->wait(
{info->getKey(kAllGatherFree + std::to_string(info->rank))});
}
info->store->deleteKey(
info->getKey(kAllGatherFree + std::to_string(info->rank)));
} catch (std::exception& ex) {
LOG(ERROR) << "(oob_allgather) Caught exception in Store Operation .. "
<< "[" << ex.what() << "]";
return UCC_ERR_NO_MESSAGE;
}
return UCC_OK;
}
CommUCC::CommUCC(
std::shared_ptr<torch_ucc_oob_coll_info_t> oob,
const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger)
: CommBase(logger) {
ucc_lib_config_h lib_config;
ucc_context_config_h context_config;
ucc_lib_params_t lib_params;
ucc_context_params_t context_params;
ucc_status_t st;
TORCH_UCC_CHECK(
ucc_lib_config_read("TORCH", nullptr, &lib_config),
"failed to read UCC lib config");
memset(&lib_params, 0, sizeof(ucc_lib_params_t));
lib_params.mask = UCC_LIB_PARAM_FIELD_THREAD_MODE;
lib_params.thread_mode = UCC_THREAD_MULTIPLE;
TORCH_UCC_CHECK(
ucc_init(&lib_params, lib_config, &lib), "failed to init UCC lib");
ucc_lib_config_release(lib_config);
ucc_lib_attr_t lib_attr;
lib_attr.mask = UCC_LIB_ATTR_FIELD_THREAD_MODE;
TORCH_UCC_CHECK(
ucc_lib_get_attr(lib, &lib_attr), "failed to query for lib attr");
TORCH_CHECK(
lib_attr.thread_mode == UCC_THREAD_MULTIPLE,
"ucc library wasn't initialized with multithreading support, "
"please check ucc build options");
st = ucc_context_config_read(lib, NULL, &context_config);
if (st != UCC_OK) {
// FIXME: would this cause deadlock if only one rank fails?
TORCH_UCC_CHECK(
ucc_finalize(lib),
"failed to finalize UCC library when failing to read UCC context config");
TORCH_UCC_LOG_ERROR(
TORCH_UCC_INIT,
c10::str("failed to read UCC context config: ", ucc_status_string(st)));
throw std::runtime_error(ucc_status_string(st));
}
st = ucc_context_config_modify(
context_config,
NULL,
"ESTIMATED_NUM_EPS",
std::to_string(oob->size).c_str());
if (st != UCC_OK) {
ucc_context_config_release(context_config);
ucc_finalize(lib);
TORCH_UCC_LOG_ERROR(
TORCH_UCC_INIT,
c10::str(
"UCC failed to modify UCC context config: ",
ucc_status_string(st)));
throw std::runtime_error(ucc_status_string(st));
}
memset(&context_params, 0, sizeof(ucc_context_params_t));
context_params.mask =
UCC_CONTEXT_PARAM_FIELD_TYPE | UCC_CONTEXT_PARAM_FIELD_OOB;
context_params.type = UCC_CONTEXT_SHARED;
context_params.oob.n_oob_eps = oob->size;
context_params.oob.oob_ep = oob->rank;
context_params.oob.allgather = oob_allgather;
context_params.oob.req_test = oob_allgather_test;
context_params.oob.req_free = oob_allgather_free;
context_params.oob.coll_info = oob.get();
st = ucc_context_create(lib, &context_params, context_config, &context);
ucc_context_config_release(context_config);
if (st != UCC_OK) {
TORCH_UCC_CHECK(
ucc_finalize(lib),
"failed to finalize UCC library when failing to creat UCC context");
TORCH_UCC_LOG_ERROR(
TORCH_UCC_INIT,
c10::str("UCC failed to create UCC context: ", ucc_status_string(st)));
throw std::runtime_error(ucc_status_string(st));
}
}
void CommUCC::progress() {
TORCH_UCC_CHECK(
ucc_context_progress(context), "failed to progress UCC collective");
}
void CommUCC::free_request(ucc_coll_req_h request) {
TORCH_UCC_CHECK(
ucc_collective_finalize(request), "failed to release UCC request");
}
CommUCC::~CommUCC() {
if (context != nullptr) {
TORCH_UCC_CHECK(
ucc_context_destroy(context), "failed to destroy UCC context");
}
if (lib != nullptr) {
TORCH_UCC_CHECK(ucc_finalize(lib), "failed to finalize UCC library");
}
context = nullptr;
lib = nullptr;
}
std::string ProcessGroupUCCLogger::getLogPrefix(torch_ucc_phase_t phase) {
// caller can override the phase stored locally
torch_ucc_phase_t phase_ =
(local_phase != phase && phase != TORCH_UCC_UNKNOWN) ? phase
: local_phase;
return c10::str(log_prefix, "[", ucc_phase_map.at(phase_), "]");
}
void ProcessGroupUCCLogger::setLogPrefix(std::string log_prefix_) {
log_prefix = log_prefix_;
}
ProcessGroupUCCLogger::ProcessGroupUCCLogger() {
setLogPrefix("[ProcessGroupUCC]");
}
ProcessGroupUCCLogger::ProcessGroupUCCLogger(
std::string log_prefix,
torch_ucc_phase_t phase)
: local_phase(phase) {
setLogPrefix(log_prefix);
}
} // namespace c10d
#endif // USE_C10D_UCC
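
The store-based OOB allgather above boils down to a simple key-value protocol: publish your chunk under a per-rank key, wait until all ranks' keys exist, then read every chunk back in rank order. A Python sketch of the same protocol against a c10d store (illustrative only; the authoritative implementation is the C++ above, and the completion/free handshake is omitted):

def oob_allgather_sketch(store, comm_id, rank, size, payload):
    # Publish this rank's contribution under "<comm_id>teamr<rank>",
    # mirroring the kTeamRank keys used by oob_allgather().
    store.set(f"{comm_id}teamr{rank}", payload)
    # Block until every rank's key exists (the C++ callback polls instead).
    store.wait([f"{comm_id}teamr{r}" for r in range(size)])
    # Assemble all contributions in rank order.
    return [store.get(f"{comm_id}teamr{r}") for r in range(size)]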

View File

@ -0,0 +1,193 @@
#pragma once
#ifdef USE_C10D_UCC
#include <c10d/ProcessGroup.hpp>
#include <c10d/Store.hpp>
#include <ucc/api/ucc.h>
#include <ucp/api/ucp.h>
#define TORCH_UCX_COMM_BITS 15
#define TORCH_UCX_RANK_BITS 16
#define TORCH_UCX_TAG_BITS 32
#define TORCH_UCX_OOB_BITS 1
#define TORCH_UCX_COMM_BITS_OFFSET 0
#define TORCH_UCX_RANK_BITS_OFFSET TORCH_UCX_COMM_BITS
#define TORCH_UCX_TAG_BITS_OFFSET (TORCH_UCX_COMM_BITS + TORCH_UCX_RANK_BITS)
#define TORCH_UCX_OOB_BITS_OFFSET \
(TORCH_UCX_COMM_BITS + TORCH_UCX_RANK_BITS + TORCH_UCX_TAG_BITS)
#define TORCH_UCX_MAX_COMM ((((uint64_t)1) << TORCH_UCX_COMM_BITS) - 1)
#define TORCH_UCX_MAX_RANK ((((uint64_t)1) << TORCH_UCX_RANK_BITS) - 1)
#define TORCH_UCX_MAX_TAG ((((uint64_t)1) << TORCH_UCX_TAG_BITS) - 1)
#define TORCH_UCX_MAX_OOB ((((uint64_t)1) << TORCH_UCX_OOB_BITS) - 1)
#define TORCH_UCX_COMM_MASK (TORCH_UCX_MAX_COMM << TORCH_UCX_COMM_BITS_OFFSET)
#define TORCH_UCX_RANK_MASK (TORCH_UCX_MAX_RANK << TORCH_UCX_RANK_BITS_OFFSET)
#define TORCH_UCX_TAG_MASK (TORCH_UCX_MAX_TAG << TORCH_UCX_TAG_BITS_OFFSET)
#define TORCH_UCX_OOB_MASK (TORCH_UCX_MAX_OOB << TORCH_UCX_OOB_BITS_OFFSET)
namespace c10d {
// Macro to throw on a non-successful UCC return value.
#define TORCH_UCC_CHECK(_cmd, _error_msg) \
do { \
ucc_status_t result = _cmd; \
if (result != UCC_OK) { \
std::string err = c10::str( \
"[", \
std::string(__FILE__), \
":", \
std::to_string(__LINE__), \
"] ", \
logger->getLogPrefix(), \
_error_msg, \
", error code ", \
result, \
": ", \
ucc_status_string(result), \
", system error code ", \
errno); \
TORCH_CHECK(false, err); \
} \
} while (0)
// Macro to throw on a non-successful UCX return value.
#define TORCH_UCX_CHECK(_cmd, _error_msg) \
do { \
ucs_status_t result = _cmd; \
if (result != UCS_OK) { \
std::string err = c10::str( \
"[", \
std::string(__FILE__), \
":", \
std::to_string(__LINE__), \
"] ", \
logger->getLogPrefix(), \
_error_msg, \
", error code ", \
result, \
": ", \
ucs_status_string(result), \
", system error code ", \
errno); \
TORCH_CHECK(false, err); \
} \
} while (0)
// Macros to print logs with unified format
#define TORCH_UCC_LOG_ERROR(_phase, _msg) \
LOG(ERROR) << logger->getLogPrefix(_phase) << "[ERROR] " << _msg;
#define TORCH_UCC_LOG_INFO(_phase, _msg) \
LOG(INFO) << logger->getLogPrefix(_phase) << "[INFO] " << _msg;
#define TORCH_UCC_LOG_DEBUG(_phase, _msg) \
VLOG(1) << logger->getLogPrefix(_phase) << "[DEBUG] " << _msg;
enum torch_ucc_phase_t {
TORCH_UCC_UNKNOWN = -1,
TORCH_UCC_INIT,
TORCH_UCC_HEALTH_CHECK,
TORCH_UCC_READY,
TORCH_UCC_COLL_POST,
TORCH_UCC_COLL_PROGRESS,
TORCH_UCC_FINALIZE,
};
const std::map<torch_ucc_phase_t, std::string> ucc_phase_map = {
{TORCH_UCC_UNKNOWN, "UNKNOWN"},
{TORCH_UCC_INIT, "INIT"},
{TORCH_UCC_HEALTH_CHECK, "HEALTH_CHECK"},
{TORCH_UCC_READY, "READY"},
{TORCH_UCC_COLL_POST, "COLL_POST"},
{TORCH_UCC_COLL_PROGRESS, "COLL_PROGRESS"},
{TORCH_UCC_FINALIZE, "FINALIZE"},
};
class CommTraceLogger;
class TORCH_API ProcessGroupUCCLogger : public torch::CustomClassHolder {
public:
ProcessGroupUCCLogger();
ProcessGroupUCCLogger(std::string log_prefix, torch_ucc_phase_t phase);
std::string getLogPrefix(torch_ucc_phase_t phase = TORCH_UCC_UNKNOWN);
void setLogPrefix(std::string log_prefix);
inline void setPhase(torch_ucc_phase_t phase) {
local_phase = phase;
}
void initCommsTracer();
void flushComms(int rank, int world_size);
std::shared_ptr<CommTraceLogger> trace_generator = nullptr;
protected:
std::string log_prefix;
torch_ucc_phase_t local_phase = TORCH_UCC_UNKNOWN;
bool initialized_CommTraceLogger = false;
};
struct torch_ucc_oob_coll_info_t {
c10::intrusive_ptr<Store> store;
uint32_t comm_id;
int rank;
int size;
void* rbuf;
size_t msglen;
std::string getKey(std::string key) {
return std::to_string(comm_id) + key;
}
};
class CommBase {
public:
CommBase(const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger_)
: logger(logger_) {}
virtual void progress() = 0;
virtual void free_request(ucc_coll_req_h request) = 0;
virtual ~CommBase() {}
c10::intrusive_ptr<ProcessGroupUCCLogger> logger;
};
class CommUCX : public CommBase {
public:
ucp_context_h context{nullptr};
ucp_worker_h worker{nullptr};
public:
void progress() override;
void free_request(ucc_coll_req_h request) override;
CommUCX(
int comm_size,
const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger);
~CommUCX();
};
class CommUCC : public CommBase {
public:
ucc_lib_h lib{nullptr};
ucc_context_h context{nullptr};
public:
void progress() override;
CommUCC(
std::shared_ptr<torch_ucc_oob_coll_info_t> oob,
const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger);
void free_request(ucc_coll_req_h request) override;
~CommUCC();
};
ucc_status_t oob_allgather(
void* sbuf,
void* rbuf,
size_t msglen,
void* coll_info,
void** req);
ucc_status_t oob_allgather_test(void* req);
ucc_status_t oob_allgather_free(void* req);
} // namespace c10d
#endif // USE_C10D_UCC
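
The TORCH_UCX_* constants pack a 64-bit UCP tag as, from least to most significant bits, communicator id (15), source rank (16), user tag (32), and an OOB flag (1). A Python sketch of the p2p packing, for illustration only:

COMM_BITS, RANK_BITS, TAG_BITS = 15, 16, 32
COMM_OFF = 0
RANK_OFF = COMM_BITS              # 15
TAG_OFF = COMM_BITS + RANK_BITS   # 31

def make_p2p_tag(tag: int, rank: int, comm: int) -> int:
    # Mirrors TORCH_UCX_MAKE_P2P_TAG: [oob:1 | tag:32 | rank:16 | comm:15].
    return (tag << TAG_OFF) | (rank << RANK_OFF) | (comm << COMM_OFF)

# Example: user tag 7 sent by rank 3 on communicator 1.
assert make_p2p_tag(7, 3, 1) == (7 << 31) | (3 << 15) | 1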

View File

@ -24,6 +24,10 @@
#include <c10d/ProcessGroupMPI.hpp>
#endif
#ifdef USE_C10D_UCC
#include <c10d/ProcessGroupUCC.hpp>
#endif
#include <c10d/PrefixStore.hpp>
#include <fmt/format.h>
#include <pybind11/chrono.h>
@ -1557,6 +1561,25 @@ Example::
py::call_guard<py::gil_scoped_release>());
#endif
#ifdef USE_C10D_UCC
auto processGroupUCC =
intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupUCC>(
module, "ProcessGroupUCC", processGroup)
.def(
py::init([](const c10::intrusive_ptr<::c10d::Store>& store,
int rank,
int size,
const std::chrono::milliseconds& timeout) {
return c10::make_intrusive<::c10d::ProcessGroupUCC>(
store, rank, size, timeout);
}),
py::arg("store"),
py::arg("rank"),
py::arg("size"),
py::arg("timeout") = kProcessGroupDefaultTimeout,
py::call_guard<py::gil_scoped_release>());
#endif
py::class_<
::c10d::ProcessGroup::Work,
c10::intrusive_ptr<::c10d::ProcessGroup::Work>,
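
With this binding in place, the process group can also be constructed directly rather than through init_process_group; a hedged sketch (the store endpoint values are placeholders, and it assumes the class is re-exported from torch.distributed like the other process groups):

from datetime import timedelta
import torch.distributed as dist

rank, world_size = 0, 2                     # per-process values
store = dist.TCPStore("127.0.0.1", 29500, world_size, rank == 0)
pg = dist.ProcessGroupUCC(store, rank, world_size, timedelta(seconds=60))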

View File

@ -39,6 +39,7 @@ from .rendezvous import register_rendezvous_handler, rendezvous # noqa: F401
_MPI_AVAILABLE = True
_NCCL_AVAILABLE = True
_GLOO_AVAILABLE = True
_UCC_AVAILABLE = True
_pickler = pickle.Pickler
_unpickler = pickle.Unpickler
@ -59,6 +60,11 @@ try:
except ImportError:
_GLOO_AVAILABLE = False
try:
from torch._C._distributed_c10d import ProcessGroupUCC
except ImportError:
_UCC_AVAILABLE = False
logger = logging.getLogger(__name__)
@ -86,7 +92,7 @@ def supports_complex(reduceOp: ReduceOp) -> bool:
class Backend(object):
"""
An enum-like class of available backends: GLOO, NCCL, UCC, MPI, and other registered
backends.
The values of this class are lowercase strings, e.g., ``"gloo"``. They can
@ -105,6 +111,7 @@ class Backend(object):
UNDEFINED = "undefined" UNDEFINED = "undefined"
GLOO = "gloo" GLOO = "gloo"
NCCL = "nccl" NCCL = "nccl"
UCC = "ucc"
MPI = "mpi" MPI = "mpi"
TCP = "tcp" TCP = "tcp"
_plugins: Dict[str, Callable] = {} _plugins: Dict[str, Callable] = {}
@ -122,7 +129,7 @@ class Backend(object):
)
elif value == Backend.UNDEFINED:
raise ValueError("Invalid backend: '{}'".format(name))
elif value != Backend.GLOO and value != Backend.NCCL and value != Backend.UCC and value != Backend.MPI:
value = name.lower()
return value
@ -145,9 +152,12 @@ class Backend(object):
.. note:: This support of 3rd party backend is experimental and subject to change.
"""
# Allow UCC plugin if PyTorch is not built with native support.
# TODO: remove this exception once UCC plugin is fully deprecated.
if (name != Backend.UCC or (name == Backend.UCC and is_ucc_available())):
assert not hasattr(Backend, name.upper()), (
f"{name.upper()} c10d backend already exist"
)
assert name.upper() not in Backend._plugins, (
f"{name.upper()} c10d backend creator function already exist"
)
@ -412,6 +422,13 @@ def is_gloo_available():
return _GLOO_AVAILABLE
def is_ucc_available():
"""
Checks if the UCC backend is available.
"""
return _UCC_AVAILABLE
def is_initialized():
"""
Checking if the default process group has been initialized
@ -511,12 +528,13 @@ def init_process_group(
Args:
backend (str or Backend): The backend to use. Depending on
build-time configurations, valid values include ``mpi``, ``gloo``,
``nccl``, and ``ucc``. This field should be given as a lowercase
string (e.g., ``"gloo"``), which can also be accessed via
:class:`Backend` attributes (e.g., ``Backend.GLOO``). If using
multiple processes per machine with ``nccl`` backend, each process
must have exclusive access to every GPU it uses, as sharing GPUs
between processes can result in deadlocks. ``ucc`` backend is
experimental.
init_method (str, optional): URL specifying how to initialize the
process group. Default is "env://" if no
``init_method`` or ``store`` is specified.
@ -547,6 +565,9 @@ def init_process_group(
continue executing user code since failed async NCCL operations
might result in subsequent CUDA operations running on corrupted
data. Only one of these two environment variables should be set.
For ``ucc``, blocking wait is supported similarly to NCCL. However,
async error handling is done differently, since UCC uses a progress
thread rather than a watchdog thread.
group_name (str, optional, deprecated): Group name.
pg_options (ProcessGroupOptions, optional): process group options
specifying what additional options need to be passed in during
@ -682,7 +703,7 @@ def _new_process_group_helper(
is_default_group = len(group_ranks) == 0
backend = Backend(backend)
pg: Union[ProcessGroupGloo, ProcessGroupMPI, ProcessGroupNCCL, ProcessGroupUCC]
if backend == Backend.MPI:
if not is_mpi_available():
raise RuntimeError(
@ -767,6 +788,33 @@ def _new_process_group_helper(
)
_pg_map[pg] = (Backend.NCCL, store)
_pg_names[pg] = group_name
elif backend == Backend.UCC and is_ucc_available():
# TODO: once UCC plugin is fully deprecated, remove
# is_ucc_available() from above elif-condition and raise
# RuntimeError if is_ucc_available() returns false.
pg = ProcessGroupUCC(prefix_store, rank, world_size, timeout=timeout)
# In debug mode and if GLOO is available, wrap in a wrapper PG that
# enables enhanced collective checking for debugability.
if get_debug_level() == DebugLevel.DETAIL:
if not _GLOO_AVAILABLE:
logger.info(
"""TORCH_DISTRIBUTED_DEBUG was set to DETAIL, but
GLOO is not available. Build with Gloo to
create a wrapper process group in debug mode
to aid collective desynchronization debugging."""
)
else:
pg = _create_process_group_wrapper(
wrapped_pg=pg,
store_prefix=group_name,
store=store,
rank=rank,
world_size=world_size,
timeout=timeout,
)
_pg_map[pg] = (Backend.UCC, store)
_pg_names[pg] = group_name
else:
assert backend.upper() in Backend._plugins, (
f"unknown c10d backend type {backend.upper()}"
@ -1064,7 +1112,7 @@ def batch_isend_irecv(p2p_op_list):
Send or Receive a batch of tensors asynchronously and return a list of requests.
Process each of the operations in ``p2p_op_list`` and return the corresponding
requests. NCCL, Gloo, and UCC backend are currently supported.
Args:
p2p_op_list: A list of point-to-point operations(type of each operator is p2p_op_list: A list of point-to-point operations(type of each operator is