[PFC] Native UCC process group for Pytorch (#79918)

Summary:
This diff integrates the UCC process group as a native component of PyTorch Distributed core. It is based on the existing torch-ucc (https://github.com/facebookresearch/torch_ucc), which wraps the UCC collective communication library.
The environment and CMake variables are named to mirror those of the existing process groups such as NCCL and Gloo. Specifically,
- USE_UCC: enables the UCC process group. This defaults to OFF, so existing builds that do not have the UCX/UCC external libraries are unaffected.
- USE_SYSTEM_UCC: uses external UCX and UCC shared libraries, located via UCX_HOME and UCC_HOME.

Currently, this diff only supports USE_SYSTEM_UCC=ON, i.e., users must supply external UCX and UCC libraries. In subsequent diffs, we will add the UCX and UCC repos as third-party dependencies in pytorch/third-party.
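
For context, here is a minimal end-to-end usage sketch once PyTorch has been built with USE_UCC=1 and USE_SYSTEM_UCC=1 (with UCX_HOME/UCC_HOME pointing at the installed libraries). The script below is illustrative, not part of this diff; it assumes a two-process launch via torchrun and relies only on the "ucc" backend string that this diff registers:

import torch
import torch.distributed as dist

# Assumes the usual torchrun environment variables (RANK, WORLD_SIZE,
# MASTER_ADDR, MASTER_PORT), e.g.:
#   torchrun --nproc_per_node=2 ucc_allreduce.py
dist.init_process_group(backend="ucc")

t = torch.ones(4) * (dist.get_rank() + 1)
dist.all_reduce(t, op=dist.ReduceOp.SUM)  # allreduce through the UCC backend
print(f"rank {dist.get_rank()}: {t}")

dist.destroy_process_group()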

Test Plan:
Passed torch-ucc tests that invoke the UCC process group. For example:

$ sh test/start_test.sh test/torch_allreduce_test.py --backend gloo --use-cuda
...
Test allreduce: succeeded

Differential Revision: D36973688

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79918
Approved by: https://github.com/kwen2501, https://github.com/kingchc
Terry Lam 2022-07-12 14:45:44 +00:00, committed by PyTorch MergeBot
parent 268d910170
commit 54bdaf76d6
16 changed files with 3152 additions and 10 deletions


@@ -308,6 +308,14 @@ option(USE_DISTRIBUTED "Use distributed" ON)
cmake_dependent_option(
USE_MPI "Use MPI for Caffe2. Only available if USE_DISTRIBUTED is on." ON
"USE_DISTRIBUTED" OFF)
cmake_dependent_option(
USE_UCC "Use UCC. Only available if USE_DISTRIBUTED is on." OFF
"USE_DISTRIBUTED" OFF)
cmake_dependent_option(
USE_SYSTEM_UCC "Use system-wide UCC" OFF
"USE_UCC" OFF)
cmake_dependent_option(
USE_C10D_UCC "USE C10D UCC" ON "USE_DISTRIBUTED;USE_UCC" OFF)
cmake_dependent_option(
USE_GLOO "Use Gloo. Only available if USE_DISTRIBUTED is on." ON
"USE_DISTRIBUTED" OFF)


@@ -744,6 +744,9 @@ libtorch_cuda_distributed_base_sources = [
libtorch_cuda_distributed_extra_sources = [
"torch/csrc/distributed/c10d/NCCLUtils.cpp",
"torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp",
"torch/csrc/distributed/c10d/ProcessGroupUCC.cpp",
"torch/csrc/distributed/c10d/UCCTracing.cpp",
"torch/csrc/distributed/c10d/UCCUtils.cpp",
"torch/csrc/distributed/rpc/tensorpipe_cuda.cpp",
"torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
]


@@ -911,6 +911,11 @@ if(HAVE_SOVERSION)
VERSION ${TORCH_VERSION} SOVERSION ${TORCH_SOVERSION})
endif()
if(USE_UCC)
target_link_libraries(torch_cpu PRIVATE __caffe2_ucc)
target_compile_definitions(torch_cpu PRIVATE USE_UCC)
endif()
if(USE_ROCM)
filter_list(__caffe2_hip_srcs_cpp Caffe2_HIP_SRCS "\\.(cu|hip)$")
set_source_files_properties(${__caffe2_hip_srcs_cpp} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
@@ -972,6 +977,13 @@ elseif(USE_CUDA)
target_link_libraries(torch_cuda PRIVATE __caffe2_nccl)
target_compile_definitions(torch_cuda PRIVATE USE_NCCL)
endif()
if(USE_UCC AND BUILD_SPLIT_CUDA)
target_link_libraries(torch_cuda_cpp PRIVATE __caffe2_ucc)
target_compile_definitions(torch_cuda_cpp PRIVATE USE_UCC)
elseif(USE_UCC)
target_link_libraries(torch_cuda PRIVATE __caffe2_ucc)
target_compile_definitions(torch_cuda PRIVATE USE_UCC)
endif()
if(BUILD_LAZY_CUDA_LINALG)
add_library(torch_cuda_linalg ${ATen_CUDA_LINALG_SRCS})
target_compile_definitions(torch_cuda_linalg PRIVATE USE_CUDA BUILD_LAZY_CUDA_LINALG)
@@ -1347,6 +1359,16 @@ if(USE_DISTRIBUTED)
if(USE_GLOO AND USE_C10D_GLOO)
target_compile_definitions(torch_cpu PUBLIC USE_C10D_GLOO)
endif()
if(USE_UCC AND USE_C10D_UCC)
target_compile_definitions(torch_cpu PUBLIC USE_C10D_UCC)
if(USE_CUDA)
if(BUILD_SPLIT_CUDA)
target_compile_definitions(torch_cuda_cpp PUBLIC USE_C10D_UCC)
else()
target_compile_definitions(torch_cuda PUBLIC USE_C10D_UCC)
endif()
endif()
endif()
if(USE_NCCL AND USE_C10D_NCCL)
if(USE_ROCM)
target_compile_definitions(torch_hip PUBLIC USE_C10D_NCCL)


@@ -1362,6 +1362,16 @@ if(USE_NCCL)
endif()
endif()
# ---[ UCC
if(USE_UCC)
if(NOT CMAKE_SYSTEM_NAME STREQUAL "Linux")
message(WARNING "UCC is currently only supported under Linux.")
caffe2_update_option(USE_UCC OFF)
else()
include(${CMAKE_CURRENT_LIST_DIR}/External/ucc.cmake)
endif()
endif()
# ---[ CUB
if(USE_CUDA)
find_package(CUB)

cmake/External/ucc.cmake (new file)

@@ -0,0 +1,20 @@
if(NOT __UCC_INCLUDED)
set(__UCC_INCLUDED TRUE)
if(USE_SYSTEM_UCC)
set(UCX_HOME $ENV{UCX_HOME} CACHE PATH "UCX install directory")
set(UCC_HOME $ENV{UCC_HOME} CACHE PATH "UCC install directory")
add_library(__caffe2_ucc INTERFACE)
target_include_directories(__caffe2_ucc INTERFACE ${UCX_HOME}/include/)
target_include_directories(__caffe2_ucc INTERFACE ${UCC_HOME}/include/)
target_link_libraries(__caffe2_ucc INTERFACE ${UCX_HOME}/lib/libucp.so)
target_link_libraries(__caffe2_ucc INTERFACE ${UCX_HOME}/lib/libucs.so)
target_link_libraries(__caffe2_ucc INTERFACE ${UCC_HOME}/lib/libucc.so)
else()
message(FATAL_ERROR "USE_SYSTEM_UCC=OFF is not supported yet when using UCC")
endif()
endif()


@@ -146,6 +146,10 @@ function(caffe2_print_configuration_summary)
message(STATUS " USE_MKLDNN_ACL : ${USE_MKLDNN_ACL}")
message(STATUS " USE_MKLDNN_CBLAS : ${USE_MKLDNN_CBLAS}")
endif()
message(STATUS " USE_UCC : ${USE_UCC}")
if(${USE_UCC})
message(STATUS " USE_SYSTEM_UCC : ${USE_SYSTEM_UCC}")
endif()
message(STATUS " USE_NCCL : ${USE_NCCL}")
if(${USE_NCCL})
message(STATUS " USE_SYSTEM_NCCL : ${USE_SYSTEM_NCCL}")


@@ -251,6 +251,9 @@ if(USE_DISTRIBUTED)
if(USE_NCCL)
list(APPEND TORCH_PYTHON_LINK_LIBRARIES __caffe2_nccl)
endif()
if(USE_UCC)
list(APPEND TORCH_PYTHON_LINK_LIBRARIES __caffe2_ucc)
endif()
# Same for MPI.
if(USE_MPI)
list(APPEND TORCH_PYTHON_LINK_LIBRARIES ${MPI_CXX_LIBRARIES})
@@ -284,6 +287,9 @@ if(USE_DEPLOY)
if(USE_GLOO AND USE_C10D_GLOO)
list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D_GLOO)
endif()
if(USE_UCC AND USE_C10D_UCC)
list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D_UCC)
endif()
if(USE_NCCL AND USE_C10D_NCCL)
list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D_NCCL)
# Put nccl headers on the include path. We are specifically only setting


@@ -371,6 +371,15 @@ class ProcessGroupNCCL(ProcessGroup):
def _group_end() -> None: ...
...
class ProcessGroupUCC(ProcessGroup):
def __init__(
self,
store: Store,
rank: int,
size: int,
timeout: timedelta,
): ...
class ProcessGroupMPI(ProcessGroup):
def __init__(
self,

File diff suppressed because it is too large.


@@ -0,0 +1,407 @@
#pragma once
#ifdef USE_C10D_UCC
#include <c10d/UCCUtils.hpp>
#include <exception>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>
#include <c10d/ProcessGroup.hpp>
#include <c10d/Store.hpp>
#include <c10d/Types.hpp>
#include <c10d/Utils.hpp>
#ifdef USE_CUDA
#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAStream.h>
#endif
namespace c10d {
#define TORCH_UCC_DEVICE_NOT_SET -2
#define TORCH_UCX_MAKE_P2P_TAG(_tag, _rank, _comm) \
((((uint64_t)(_tag)) << TORCH_UCX_TAG_BITS_OFFSET) | \
(((uint64_t)(_rank)) << TORCH_UCX_RANK_BITS_OFFSET) | \
(((uint64_t)(_comm)) << TORCH_UCX_COMM_BITS_OFFSET))
#define TORCH_UCX_MAKE_OOB_TAG(_tag, _rank, _comm) \
((((uint64_t)(_tag)) << TORCH_UCX_OOB_BITS_OFFSET) | \
(((uint64_t)(_rank)) << TORCH_UCX_RANK_BITS_OFFSET) | \
(((uint64_t)(_comm)) << TORCH_UCX_COMM_BITS_OFFSET))
#define TORCH_UCX_MAKE_SEND_TAG(_ucp_tag, _tag, _rank, _comm) \
do { \
(_ucp_tag) = TORCH_UCX_MAKE_P2P_TAG((_tag), (_rank), (_comm)); \
} while (0)
#define TORCH_UCX_ANY_SOURCE (TORCH_UCX_MAX_RANK - 1)
#define TORCH_UCX_ANY_SOURCE_MASK (~TORCH_UCX_RANK_MASK)
#define TORCH_UCX_SPECIFIC_SOURCE_MASK ((uint64_t)-1)
#define TORCH_UCX_MAKE_RECV_TAG(_ucp_tag, _ucp_tag_mask, _tag, _rank, _comm) \
do { \
(_ucp_tag) = TORCH_UCX_MAKE_P2P_TAG((_tag), (_rank), (_comm)); \
if ((_rank) == TORCH_UCX_ANY_SOURCE) { \
(_ucp_tag_mask) = TORCH_UCX_ANY_SOURCE_MASK; \
} else { \
(_ucp_tag_mask) = TORCH_UCX_SPECIFIC_SOURCE_MASK; \
} \
} while (0)
#define TORCH_UCX_MAKE_OOB_SEND_TAG(_ucp_tag, _tag, _rank, _comm) \
do { \
(_ucp_tag) = TORCH_UCX_MAKE_OOB_TAG((_tag), (_rank), (_comm)); \
} while (0)
#define TORCH_UCX_MAKE_OOB_RECV_TAG( \
_ucp_tag, _ucp_tag_mask, _tag, _rank, _comm) \
do { \
(_ucp_tag) = TORCH_UCX_MAKE_OOB_TAG((_tag), (_rank), (_comm)); \
(_ucp_tag_mask) = (uint64_t)-1; \
} while (0)
#ifdef USE_CUDA
#define SAVE_TENSORS(_TENSORS, _DATA) \
do { \
if ((_TENSORS)[0].device().is_cuda()) { \
for (const auto i : c10::irange((_TENSORS).size())) { \
c10::cuda::CUDACachingAllocator::recordStream( \
(_TENSORS)[i].storage().data_ptr(), (*stream)); \
} \
} else { \
(_DATA) = (_TENSORS); \
} \
} while (0)
#else
#define SAVE_TENSORS(_TENSORS, _DATA) (_DATA) = (_TENSORS);
#endif
constexpr const char* UCC_BACKEND_NAME = "ucc";
enum torch_ucx_tag_type_t { TORCH_UCX_P2P_TAG, TORCH_UCX_OOB_TAG };
struct event_pool_t {
#ifdef USE_CUDA
std::queue<std::unique_ptr<at::cuda::CUDAEvent>> event_pool;
#endif
std::mutex event_pool_mutex;
};
class Comm;
// UCC does not support multiple CUDA devices per process.
class TORCH_API ProcessGroupUCC : public ProcessGroup {
private:
void set_timeout(ucc_coll_args_t& args);
public:
class WorkData {
public:
std::vector<at::Tensor> src;
std::vector<at::Tensor> dst;
std::vector<at::Tensor> flat;
WorkData() {}
virtual ~WorkData() = default;
};
class AlltoallWorkData : public WorkData {
public:
AlltoallWorkData(int size)
: send_lengths(size),
send_offsets(size),
recv_lengths(size),
recv_offsets(size) {}
std::vector<uint64_t> send_lengths;
std::vector<uint64_t> send_offsets;
std::vector<uint64_t> recv_lengths;
std::vector<uint64_t> recv_offsets;
};
class AllgathervWorkData : public WorkData {
public:
AllgathervWorkData(int size) : recv_lengths(size), recv_offsets(size) {}
std::vector<uint64_t> recv_lengths;
std::vector<uint64_t> recv_offsets;
};
class ScattervWorkData : public WorkData {
public:
ScattervWorkData(int size) : send_lengths(size), send_offsets(size) {}
std::vector<uint64_t> send_lengths;
std::vector<uint64_t> send_offsets;
};
class ProgressEntry {
friend class ProcessGroupUCC;
friend class Comm;
public:
ProgressEntry(CommBase* comm, ucc_coll_req_h request)
: status_(UCC_INPROGRESS), comm_(comm), request_(request) {}
// Finalizes UCC status or exception of collective request.
void finalize(std::exception_ptr eptr = nullptr);
ucc_status_t status_;
CommBase* comm_;
ucc_coll_req_h request_;
std::unique_ptr<WorkData> data;
c10::intrusive_ptr<c10::ivalue::Future> future_;
std::exception_ptr eptr_;
};
class WorkUCC : public ProcessGroup::Work {
friend class ProcessGroupUCC;
friend class Comm;
public:
WorkUCC(OpType opType, const char* prof_title)
: ProcessGroup::Work(-1, opType, prof_title) {}
WorkUCC(
OpType opType,
const char* prof_title,
const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger)
: ProcessGroup::Work(-1, opType, prof_title), logger_(logger) {}
~WorkUCC();
void setException();
void setAndThrowException();
bool isCompleted() override;
bool isSuccess() const override;
bool wait(std::chrono::milliseconds timeout = kUnsetTimeout) override;
c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;
std::vector<at::Tensor> result() override;
#ifdef USE_CUDA
std::unique_ptr<at::cuda::CUDAEvent> fence = nullptr;
event_pool_t* ep = nullptr;
#endif
protected:
std::shared_ptr<ProgressEntry> entry_;
c10::intrusive_ptr<ProcessGroupUCCLogger> logger_;
private:
// The future returned by getFuture.
c10::intrusive_ptr<at::ivalue::Future> future_;
// Store a reference to collective's outputs, used by result
std::shared_ptr<std::vector<at::Tensor>> outputs_;
};
explicit ProcessGroupUCC(
const c10::intrusive_ptr<Store>& store,
int rank = -1,
int size = -1,
std::chrono::duration<float> timeout = kProcessGroupDefaultTimeout);
void initComm(c10::Device dev);
~ProcessGroupUCC() override;
const std::string getBackendName() const override {
return std::string(UCC_BACKEND_NAME);
}
#ifdef USE_CUDA
std::unique_ptr<at::cuda::CUDAEvent> getPooledEvent();
#endif
// Performs a health check by initializing dummy UCC & UCX communicators and
// then destroying them. This will help indicate and signal any
// UCC/UCX-related issues prior to the first collective. The actual
initialization and subsequent destruction is run on a separate thread, and
// the main thread is signalled about timeouts/errors to report to the
// application.
void runHealthCheck();
template <typename PreProcess, typename PostProcess>
c10::intrusive_ptr<ProcessGroup::Work> collective_post(
OpType opType,
PreProcess preproc,
PostProcess postproc,
ucc_coll_args_t& coll,
std::unique_ptr<ProcessGroupUCC::WorkData> data,
c10::Device dev,
std::vector<at::Tensor>& inputTensors,
std::vector<at::Tensor>& outputTensors,
const char* prof_title);
c10::intrusive_ptr<ProcessGroup::Work> broadcast(
std::vector<at::Tensor>& data,
const BroadcastOptions& opts = BroadcastOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> allreduce(
std::vector<at::Tensor>& tensors,
const AllreduceOptions& opts = AllreduceOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> allreduce_coalesced(
std::vector<at::Tensor>& tensors,
const AllreduceCoalescedOptions& opts =
AllreduceCoalescedOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> reduce(
std::vector<at::Tensor>& tensors,
const ReduceOptions& opts = ReduceOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> allgather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& opts = AllgatherOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> _allgather_base(
at::Tensor& outputBuffer,
at::Tensor& inputBuffer,
const AllgatherOptions& opts = AllgatherOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> barrier(
const BarrierOptions& opts = BarrierOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> gather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const GatherOptions& opts = GatherOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ScatterOptions& opts = ScatterOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> reduce_scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ReduceScatterOptions& opts = ReduceScatterOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> alltoall_base(
at::Tensor& outputTensor,
at::Tensor& inputTensor,
std::vector<int64_t>& outputSplitSizes,
std::vector<int64_t>& inputSplitSizes,
const AllToAllOptions& opts = AllToAllOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> alltoall(
std::vector<at::Tensor>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllToAllOptions& opts = AllToAllOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> send(
std::vector<at::Tensor>& tensors,
int dstRank,
int tag) override;
c10::intrusive_ptr<ProcessGroup::Work> recv(
std::vector<at::Tensor>& tensors,
int srcRank,
int tag) override;
c10::intrusive_ptr<ProcessGroup::Work> recvAnysource(
std::vector<at::Tensor>& tensors,
int tag) override;
static c10::intrusive_ptr<ProcessGroup> createProcessGroupUCC(
const c10::intrusive_ptr<::c10d::Store>& store,
int rank,
int size,
const std::chrono::duration<float>& timeout);
protected:
const std::chrono::duration<float> timeout_;
std::shared_ptr<torch_ucc_oob_coll_info_t> oob;
std::shared_ptr<Comm> comm = {nullptr};
uint32_t comm_id;
std::vector<ucp_ep_h> eps;
ucc_team_h team{nullptr};
ucc_ee_h cuda_ee{nullptr};
#ifdef USE_CUDA
std::unique_ptr<at::cuda::CUDAStream> stream = nullptr;
event_pool_t ep;
#endif
c10::intrusive_ptr<ProcessGroupUCCLogger> logger;
};
class Comm {
c10::intrusive_ptr<ProcessGroupUCCLogger> logger;
std::shared_ptr<torch_ucc_oob_coll_info_t> oob;
CommUCX ucx_comm;
CommUCC ucc_comm;
std::mutex mutex;
std::thread progress_thread;
std::condition_variable queue_produce_cv;
std::condition_variable queue_consume_cv;
std::deque<std::shared_ptr<ProcessGroupUCC::ProgressEntry>> progress_queue;
bool stop_progress_loop;
bool collective_inprogress;
torch_ucc_phase_t finalize_phase;
public:
c10::DeviceIndex cuda_device_index;
Comm(
const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger,
std::shared_ptr<torch_ucc_oob_coll_info_t> oob,
c10::Device dev,
bool is_health_check);
~Comm();
// Connects UCX end points.
void ucx_connect_eps(
std::vector<ucp_ep_h>& eps,
std::shared_ptr<torch_ucc_oob_coll_info_t> oob);
// Disconnects UCX end points.
void ucx_disconnect_eps(
std::vector<ucp_ep_h>& eps,
std::shared_ptr<torch_ucc_oob_coll_info_t> oob);
void ucc_create_team(
ucc_team_h& team,
std::shared_ptr<torch_ucc_oob_coll_info_t> oob);
void ucc_destroy_team(ucc_team_h& team);
c10::intrusive_ptr<ProcessGroup::Work> enqueue_p2p(
OpType opType,
ucc_coll_req_h request,
const char* prof_title);
#ifdef USE_CUDA
void enqueue_cuda_collective(
std::unique_ptr<ProcessGroupUCC::WorkData> data,
c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> work,
ucc_coll_args_t& coll,
ucc_team_h team,
ucc_ee_h ee);
#endif
void enqueue_collective(
std::unique_ptr<ProcessGroupUCC::WorkData> data,
c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> work,
ucc_coll_args_t& coll,
ucc_team_h team);
static std::shared_ptr<Comm> get_comm(
uint32_t& id,
c10::Device dev,
std::shared_ptr<torch_ucc_oob_coll_info_t> oob,
const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger,
bool is_health_check = false);
void progress_loop();
ucc_coll_req_h send_nb(
ucp_ep_h ep,
void* data,
ucs_memory_type_t mtype,
size_t size,
ucp_tag_t ucp_tag);
ucc_coll_req_h recv_nb(
void* data,
ucs_memory_type_t mtype,
size_t size,
ucp_tag_t ucp_tag,
ucp_tag_t ucp_tag_mask);
};
} // namespace c10d
#endif // USE_C10D_UCC

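To make the tag layout concrete, here is a small Python sketch (illustrative only) of what TORCH_UCX_MAKE_P2P_TAG computes, using the TORCH_UCX_* bit widths defined in UCCUtils.hpp later in this diff (comm id: 15 bits, rank: 16 bits, user tag: 32 bits, OOB flag: 1 bit):

TORCH_UCX_COMM_BITS = 15
TORCH_UCX_RANK_BITS = 16
TORCH_UCX_COMM_BITS_OFFSET = 0
TORCH_UCX_RANK_BITS_OFFSET = TORCH_UCX_COMM_BITS
TORCH_UCX_TAG_BITS_OFFSET = TORCH_UCX_COMM_BITS + TORCH_UCX_RANK_BITS

def make_p2p_tag(tag: int, rank: int, comm: int) -> int:
    # Python analogue of TORCH_UCX_MAKE_P2P_TAG: pack the user tag, the peer
    # rank, and the comm id into a single 64-bit UCP tag.
    return (
        (tag << TORCH_UCX_TAG_BITS_OFFSET)
        | (rank << TORCH_UCX_RANK_BITS_OFFSET)
        | (comm << TORCH_UCX_COMM_BITS_OFFSET)
    )

# e.g. user tag 7, peer rank 3, comm id 1:
assert make_p2p_tag(7, 3, 1) == (7 << 31) | (3 << 15) | 1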

@@ -0,0 +1,176 @@
#ifdef USE_C10D_UCC
#include <c10d/UCCTracing.hpp>
#include <c10d/UCCUtils.hpp>
#include <c10d/ParamCommsUtils.hpp>
#include <sys/stat.h>
#include <cstdlib>
#include <ctime>
#include <fstream>
#ifdef FBCODE_CAFFE2
#include <c10d/UCCInternalUtils.hpp>
#endif
namespace c10d {
void ProcessGroupUCCLogger::initCommsTracer() {
trace_generator = std::make_shared<CommTraceLogger>();
initialized_CommTraceLogger = true;
}
void ProcessGroupUCCLogger::flushComms(int rank, int world_size) {
if (!initialized_CommTraceLogger ||
trace_generator->getCommsTrace().empty()) {
return;
}
std::string dirname = c10::str("ProcessGroupUCC_trace_np", world_size);
time_t now_ = time(0);
std::tm* ltm = localtime(&now_);
if (ltm) {
dirname += c10::str(
"_", (1 + ltm->tm_mon), "_", ltm->tm_mday, "_", (1900 + ltm->tm_year));
}
std::string fullpath = "/tmp/" + dirname;
char* user_path = std::getenv("TORCH_UCC_COMMS_TRACE_OUTPUT_DIR");
if (user_path) {
fullpath = user_path;
}
std::string trace_filename = c10::str(fullpath, "/rank", rank, ".json");
std::ofstream _outfile;
if (!_outfile.is_open()) {
if (mkdir(fullpath.c_str(), 0777) != 0 && errno != EEXIST) {
LOG(INFO) << getLogPrefix() << "[INFO] failed to mkdir " << fullpath;
return;
}
_outfile.open(trace_filename, std::ofstream::out | std::ofstream::trunc);
}
// flush the traced comms
if (_outfile.is_open()) {
_outfile << "[" << c10::Join(",", trace_generator->getCommsTrace())
<< "\n]";
_outfile.flush();
_outfile.close();
}
#ifdef FBCODE_CAFFE2
uploadTrace_internal(
trace_filename, dirname, c10::str("rank", rank, ".json"));
#endif
}
/* unused */
void CommTraceLogger::setCurBlock(const std::string& name) {
curBlocks_.push_back(
c10::str("\"", name, "\"")); // add quote marks for JSON format
}
/* unused */
void CommTraceLogger::popBlock() {
// TODO: remove specific name
curBlocks_.pop_back();
}
void CommTraceLogger::recordOptionalInfo(int root) {
curRoot_ = root;
}
void CommTraceLogger::recordOptionalInfo(
const std::vector<int64_t>& outputSplitSizes,
const std::vector<int64_t>& inputSplitSizes) {
curOutSplitSizes_ = outputSplitSizes;
curInSplitSizes_ = inputSplitSizes;
}
void CommTraceLogger::recordComms(
const std::string& commName,
const uintptr_t workReq,
const int rank,
const int world_size,
const std::vector<at::Tensor>& inputTensors,
const std::vector<at::Tensor>& outputTensors) {
auto inSize = (!inputTensors.empty()) ? inputTensors[0].numel() : 0;
auto outSize = (!outputTensors.empty()) ? outputTensors[0].numel() : 0;
auto dtype =
(!outputTensors.empty()) ? outputTensors[0].scalar_type() : at::kByte;
auto devType = (!outputTensors.empty()) ? outputTensors[0].device().type()
: c10::DeviceType::CPU;
auto now = std::chrono::system_clock::now();
static auto startTS = now;
int64_t time_since_begin =
std::chrono::duration_cast<std::chrono::nanoseconds>(now - startTS)
.count();
// TODO: get markers from torch profiler if enabled
// common fields for all operations
std::string cur_trace_ = c10::str(
"\n\t\t\"markers\": [",
curBlocks_,
"]",
",\n\t\t\"startTime_ns\": ",
time_since_begin,
",\n\t\t\"comms\": \"",
commName,
"\"",
",\n\t\t\"req\": ",
workReq,
",\n\t\t\"seqnum\": ",
seqnum++,
",\n\t\t\"world_size\": ",
world_size);
if (inSize > 0 || outSize > 0) {
// for most collectives - append msg sizes, data type, device type
cur_trace_ = c10::str(
cur_trace_,
",\n\t\t\"in_msg_size\": ",
inSize,
",\n\t\t\"out_msg_size\": ",
outSize,
",\n\t\t\"dtype\": \"",
at::toString(dtype),
"\",\n\t\t\"devType\": \"",
c10::DeviceTypeName(devType),
"\"");
}
if (curRoot_ != -1) {
// append root rank if applicable, e.g., broadcast, gather, scatter
cur_trace_ = c10::str(cur_trace_, ",\n\t\t\"root\": ", curRoot_);
}
if (!curInSplitSizes_.empty() || !curOutSplitSizes_.empty()) {
// append input and output splits if applicable, e.g., ALLTOALL_BASE
cur_trace_ = c10::str(
cur_trace_,
",\n\t\t\"in_split\": [",
c10::Join(",", curInSplitSizes_),
"]"
",\n\t\t\"out_split\": [",
c10::Join(",", curOutSplitSizes_),
"]");
}
comms_trace_.push_back(c10::str("\n\t{", cur_trace_, "\n\t}"));
// record the trace to kineto trace if applicable
RECORD_PARAM_COMMS(
rank,
commName.c_str(),
inSize,
outSize,
dtype,
curInSplitSizes_,
curOutSplitSizes_);
// reset optional field
curRoot_ = -1;
curInSplitSizes_ = {};
curOutSplitSizes_ = {};
}
} // namespace c10d
#endif // USE_C10D_UCC

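As a sketch of how the resulting trace might be consumed: assuming the comms logger has been enabled (the enabling flag lives in the ProcessGroupUCC.cpp diff suppressed above), each rank writes <dir>/rank<N>.json, a JSON array of the records built by recordComms. The path below is illustrative:

import json

with open("/tmp/ProcessGroupUCC_trace_np2_7_12_2022/rank0.json") as f:
    trace = json.load(f)

for rec in trace:
    # "comms", "seqnum", and "world_size" are present on every record; the
    # message sizes and dtype only appear for data-carrying collectives.
    print(rec["seqnum"], rec["comms"], rec.get("in_msg_size"), rec.get("out_msg_size"))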

@@ -0,0 +1,58 @@
#pragma once
#ifdef USE_C10D_UCC
#include <c10d/UCCUtils.hpp>
namespace c10d {
#define RECORD_COMMS_TRACE( \
_comms_tracer, _work, _opType, _rank, _comm_size, _inTensors, _outTensors) \
do { \
if (torch_ucc_config.enable_comms_logger) { \
_comms_tracer->recordComms( \
opTypeToString(_opType), \
(uintptr_t)_work.get(), \
_rank, \
_comm_size, \
_inTensors, \
_outTensors); \
} \
} while (0)
// interfaces to collect communication traces
class TORCH_API CommTraceLogger : public torch::CustomClassHolder {
private:
std::vector<std::string> comms_trace_;
std::vector<std::string> curBlocks_; /* unused */
std::vector<int64_t> curOutSplitSizes_;
std::vector<int64_t> curInSplitSizes_;
int curRoot_ = -1;
unsigned long seqnum = 0;
public:
void setCurBlock(const std::string& name); /* unused */
void popBlock(); /* unused */
// record root info if applicable, e.g., broadcast, gather, scatter
void recordOptionalInfo(int root = -1);
// record input/output splits of Alltoallv
void recordOptionalInfo(
const std::vector<int64_t>& outputSplitSizes = {},
const std::vector<int64_t>& inputSplitSizes = {});
// record essential comms information
void recordComms(
const std::string& collName,
const uintptr_t workReq = 0,
const int rank = -1,
const int world_size = -1,
const std::vector<at::Tensor>& inputTensors = {},
const std::vector<at::Tensor>& outputTensor = {});
// return collected comms traces
std::vector<std::string>& getCommsTrace() {
return comms_trace_;
}
};
} // namespace c10d
#endif // USE_C10D_UCC


@@ -0,0 +1,285 @@
#ifdef USE_C10D_UCC
#include <c10d/UCCTracing.hpp>
#include <c10d/UCCUtils.hpp>
namespace c10d {
namespace {
// Constants for store keys.
constexpr char kTeamRank[] = "teamr";
constexpr char kAllGatherDone[] = "ag_done";
constexpr char kAllGatherFree[] = "ag_free";
} // namespace
CommUCX::CommUCX(
int comm_size,
const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger)
: CommBase(logger) {
ucp_params_t params;
ucp_config_t* config;
ucs_status_t st;
ucp_worker_params_t worker_params;
ucp_lib_attr_t ucp_attr;
ucp_attr.field_mask = UCP_LIB_ATTR_FIELD_MAX_THREAD_LEVEL;
TORCH_UCX_CHECK(
ucp_lib_query(&ucp_attr), "failed to query UCP lib attributes");
TORCH_CHECK(
ucp_attr.max_thread_level == UCS_THREAD_MODE_MULTI,
"ucx library wasn't initialized with multithreading support, "
"please check ucx build options");
TORCH_UCX_CHECK(
ucp_config_read("TORCH", nullptr, &config), "failed to read UCP config");
memset(&params, 0, sizeof(ucp_params_t));
params.field_mask = UCP_PARAM_FIELD_FEATURES | UCP_PARAM_FIELD_REQUEST_SIZE |
UCP_PARAM_FIELD_ESTIMATED_NUM_EPS | UCP_PARAM_FIELD_TAG_SENDER_MASK |
UCP_PARAM_FIELD_REQUEST_INIT | UCP_PARAM_FIELD_REQUEST_CLEANUP;
params.request_size = sizeof(ucc_coll_req_t);
params.features = UCP_FEATURE_TAG;
params.estimated_num_eps = comm_size;
params.tag_sender_mask = TORCH_UCX_RANK_MASK;
params.request_init = [](void* request) {
static_cast<ucc_coll_req_h>(request)->status = UCC_INPROGRESS;
};
params.request_cleanup = [](void*) {};
TORCH_UCX_CHECK(
ucp_init(&params, config, &context), "failed to init UCP context");
ucp_config_release(config);
memset(&worker_params, 0, sizeof(ucp_worker_params_t));
worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE;
worker_params.thread_mode = UCS_THREAD_MODE_MULTI;
st = ucp_worker_create(context, &worker_params, &worker);
if (st != UCS_OK) {
TORCH_UCC_LOG_ERROR(
TORCH_UCC_INIT,
c10::str("UCX failed to create UCP worker:", ucs_status_string(st)));
ucp_cleanup(context);
throw std::runtime_error(ucs_status_string(st));
}
}
void CommUCX::progress() {
ucp_worker_progress(worker);
}
void CommUCX::free_request(ucc_coll_req_h request) {
request->status = UCC_INPROGRESS;
ucp_request_free(request);
}
CommUCX::~CommUCX() {
if (worker != nullptr) {
ucp_worker_destroy(worker);
}
if (context != nullptr) {
ucp_cleanup(context);
}
worker = nullptr;
context = nullptr;
}
ucc_status_t oob_allgather(
void* sbuf,
void* rbuf,
size_t msglen,
void* coll_info,
void** req) {
auto* info = reinterpret_cast<torch_ucc_oob_coll_info_t*>(coll_info);
TORCH_CHECK(info != nullptr);
std::vector<uint8_t> val = std::vector<uint8_t>(
reinterpret_cast<uint8_t*>(sbuf),
reinterpret_cast<uint8_t*>(sbuf) + msglen);
try {
info->store->set(info->getKey(kTeamRank + std::to_string(info->rank)), val);
info->rbuf = rbuf;
info->msglen = msglen;
*req = coll_info;
} catch (std::exception& ex) {
LOG(ERROR) << "(oob_allgather) Caught exception in Store Operation .. "
<< "[" << ex.what() << "]";
return UCC_ERR_NO_MESSAGE;
}
return UCC_OK;
}
ucc_status_t oob_allgather_test(void* req) {
auto* info = reinterpret_cast<torch_ucc_oob_coll_info_t*>(req);
TORCH_CHECK(info != nullptr);
try {
for (int r = 0; r < info->size; r++) {
if (!info->store->check({info->getKey(kTeamRank + std::to_string(r))})) {
return UCC_INPROGRESS;
}
}
for (int r = 0; r < info->size; r++) {
std::vector<uint8_t> data =
info->store->get(info->getKey(kTeamRank + std::to_string(r)));
memcpy(
(void*)((ptrdiff_t)info->rbuf + info->msglen * r),
data.data(),
info->msglen);
}
} catch (std::exception& ex) {
LOG(ERROR) << "(oob_allgather) Caught exception in Store Operation .. "
<< "[" << ex.what() << "]";
return UCC_ERR_NO_MESSAGE;
}
return UCC_OK;
}
ucc_status_t oob_allgather_free(void* req) {
auto* info = reinterpret_cast<torch_ucc_oob_coll_info_t*>(req);
TORCH_CHECK(info != nullptr);
try {
int num_done = info->store->add({info->getKey(kAllGatherDone)}, 1);
if (num_done == info->size) {
info->store->deleteKey(info->getKey(kAllGatherDone));
// Note: to avoid race condition, it's important to remove all keys in
// oob_allgather_free first and only after that signal completion to
// other ranks
for (const auto r : c10::irange(info->size)) {
info->store->deleteKey(info->getKey(kTeamRank + std::to_string(r)));
}
for (const auto r : c10::irange(info->size)) {
info->store->add({info->getKey(kAllGatherFree + std::to_string(r))}, 1);
}
} else {
info->store->wait(
{info->getKey(kAllGatherFree + std::to_string(info->rank))});
}
info->store->deleteKey(
info->getKey(kAllGatherFree + std::to_string(info->rank)));
} catch (std::exception& ex) {
LOG(ERROR) << "(oob_allgather) Caught exception in Store Operation .. "
<< "[" << ex.what() << "]";
return UCC_ERR_NO_MESSAGE;
}
return UCC_OK;
}
CommUCC::CommUCC(
std::shared_ptr<torch_ucc_oob_coll_info_t> oob,
const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger)
: CommBase(logger) {
ucc_lib_config_h lib_config;
ucc_context_config_h context_config;
ucc_lib_params_t lib_params;
ucc_context_params_t context_params;
ucc_status_t st;
TORCH_UCC_CHECK(
ucc_lib_config_read("TORCH", nullptr, &lib_config),
"failed to read UCC lib config");
memset(&lib_params, 0, sizeof(ucc_lib_params_t));
lib_params.mask = UCC_LIB_PARAM_FIELD_THREAD_MODE;
lib_params.thread_mode = UCC_THREAD_MULTIPLE;
TORCH_UCC_CHECK(
ucc_init(&lib_params, lib_config, &lib), "failed to init UCC lib");
ucc_lib_config_release(lib_config);
ucc_lib_attr_t lib_attr;
lib_attr.mask = UCC_LIB_ATTR_FIELD_THREAD_MODE;
TORCH_UCC_CHECK(
ucc_lib_get_attr(lib, &lib_attr), "failed to query for lib attr");
TORCH_CHECK(
lib_attr.thread_mode == UCC_THREAD_MULTIPLE,
"ucc library wasn't initialized with multithreading support, "
"please check ucc build options");
st = ucc_context_config_read(lib, NULL, &context_config);
if (st != UCC_OK) {
// FIXME: would this cause deadlock if only one rank fails?
TORCH_UCC_CHECK(
ucc_finalize(lib),
"failed to finalize UCC library when failing to read UCC context config");
TORCH_UCC_LOG_ERROR(
TORCH_UCC_INIT,
c10::str("failed to read UCC context config: ", ucc_status_string(st)));
throw std::runtime_error(ucc_status_string(st));
}
st = ucc_context_config_modify(
context_config,
NULL,
"ESTIMATED_NUM_EPS",
std::to_string(oob->size).c_str());
if (st != UCC_OK) {
ucc_context_config_release(context_config);
ucc_finalize(lib);
TORCH_UCC_LOG_ERROR(
TORCH_UCC_INIT,
c10::str(
"UCC failed to modify UCC context config: ",
ucc_status_string(st)));
throw std::runtime_error(ucc_status_string(st));
}
memset(&context_params, 0, sizeof(ucc_context_params_t));
context_params.mask =
UCC_CONTEXT_PARAM_FIELD_TYPE | UCC_CONTEXT_PARAM_FIELD_OOB;
context_params.type = UCC_CONTEXT_SHARED;
context_params.oob.n_oob_eps = oob->size;
context_params.oob.oob_ep = oob->rank;
context_params.oob.allgather = oob_allgather;
context_params.oob.req_test = oob_allgather_test;
context_params.oob.req_free = oob_allgather_free;
context_params.oob.coll_info = oob.get();
st = ucc_context_create(lib, &context_params, context_config, &context);
ucc_context_config_release(context_config);
if (st != UCC_OK) {
TORCH_UCC_CHECK(
ucc_finalize(lib),
"failed to finalize UCC library when failing to creat UCC context");
TORCH_UCC_LOG_ERROR(
TORCH_UCC_INIT,
c10::str("UCC failed to create UCC context: ", ucc_status_string(st)));
throw std::runtime_error(ucc_status_string(st));
}
}
void CommUCC::progress() {
TORCH_UCC_CHECK(
ucc_context_progress(context), "failed to progress UCC collective");
}
void CommUCC::free_request(ucc_coll_req_h request) {
TORCH_UCC_CHECK(
ucc_collective_finalize(request), "failed to release UCC request");
}
CommUCC::~CommUCC() {
if (context != nullptr) {
TORCH_UCC_CHECK(
ucc_context_destroy(context), "failed to destory UCC context");
}
if (lib != nullptr) {
TORCH_UCC_CHECK(ucc_finalize(lib), "failed to finalize UCC library");
}
context = nullptr;
lib = nullptr;
}
std::string ProcessGroupUCCLogger::getLogPrefix(torch_ucc_phase_t phase) {
// caller can override the phase stored locally
torch_ucc_phase_t phase_ =
(local_phase != phase && phase != TORCH_UCC_UNKNOWN) ? phase
: local_phase;
return c10::str(log_prefix, "[", ucc_phase_map.at(phase_), "]");
}
void ProcessGroupUCCLogger::setLogPrefix(std::string log_prefix_) {
log_prefix = log_prefix_;
}
ProcessGroupUCCLogger::ProcessGroupUCCLogger() {
setLogPrefix("[ProcessGroupUCC]");
}
ProcessGroupUCCLogger::ProcessGroupUCCLogger(
std::string log_prefix,
torch_ucc_phase_t phase)
: local_phase(phase) {
setLogPrefix(log_prefix);
}
} // namespace c10d
#endif // USE_C10D_UCC

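The oob_allgather / oob_allgather_test / oob_allgather_free trio above implements UCC's out-of-band bootstrap on top of the c10d Store: each rank publishes its buffer under a per-rank key, waits until every rank's key exists, and then reads all buffers back in rank order. A rough Python analogue of the same pattern (key names illustrative; this version blocks in Store.wait where the C++ code polls check()):

from datetime import timedelta

import torch.distributed as dist

def store_allgather(store: dist.Store, rank: int, world_size: int, payload: str):
    # Publish this rank's buffer under a per-rank key (cf. kTeamRank = "teamr").
    store.set(f"teamr{rank}", payload)
    # Block until every rank has published its key.
    store.wait([f"teamr{r}" for r in range(world_size)], timedelta(seconds=30))
    # Read everything back in rank order (get() returns bytes).
    return [store.get(f"teamr{r}") for r in range(world_size)]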

@@ -0,0 +1,193 @@
#pragma once
#ifdef USE_C10D_UCC
#include <c10d/ProcessGroup.hpp>
#include <c10d/Store.hpp>
#include <ucc/api/ucc.h>
#include <ucp/api/ucp.h>
#define TORCH_UCX_COMM_BITS 15
#define TORCH_UCX_RANK_BITS 16
#define TORCH_UCX_TAG_BITS 32
#define TORCH_UCX_OOB_BITS 1
#define TORCH_UCX_COMM_BITS_OFFSET 0
#define TORCH_UCX_RANK_BITS_OFFSET TORCH_UCX_COMM_BITS
#define TORCH_UCX_TAG_BITS_OFFSET (TORCH_UCX_COMM_BITS + TORCH_UCX_RANK_BITS)
#define TORCH_UCX_OOB_BITS_OFFSET \
(TORCH_UCX_COMM_BITS + TORCH_UCX_RANK_BITS + TORCH_UCX_TAG_BITS)
#define TORCH_UCX_MAX_COMM ((((uint64_t)1) << TORCH_UCX_COMM_BITS) - 1)
#define TORCH_UCX_MAX_RANK ((((uint64_t)1) << TORCH_UCX_RANK_BITS) - 1)
#define TORCH_UCX_MAX_TAG ((((uint64_t)1) << TORCH_UCX_TAG_BITS) - 1)
#define TORCH_UCX_MAX_OOB ((((uint64_t)1) << TORCH_UCX_OOB_BITS) - 1)
#define TORCH_UCX_COMM_MASK (TORCH_UCX_MAX_COMM << TORCH_UCX_COMM_BITS_OFFSET)
#define TORCH_UCX_RANK_MASK (TORCH_UCX_MAX_RANK << TORCH_UCX_RANK_BITS_OFFSET)
#define TORCH_UCX_TAG_MASK (TORCH_UCX_MAX_TAG << TORCH_UCX_TAG_BITS_OFFSET)
#define TORCH_UCX_OOB_MASK (TORCH_UCX_MAX_OOB << TORCH_UCX_OOB_BITS_OFFSET)
namespace c10d {
// Macro to throw on a non-successful UCC return value.
#define TORCH_UCC_CHECK(_cmd, _error_msg) \
do { \
ucc_status_t result = _cmd; \
if (result != UCC_OK) { \
std::string err = c10::str( \
"[", \
std::string(__FILE__), \
":", \
std::to_string(__LINE__), \
"] ", \
logger->getLogPrefix(), \
_error_msg, \
", error code ", \
result, \
": ", \
ucc_status_string(result), \
", system error code ", \
errno); \
TORCH_CHECK(false, err); \
} \
} while (0)
// Macro to throw on a non-successful UCX return value.
#define TORCH_UCX_CHECK(_cmd, _error_msg) \
do { \
ucs_status_t result = _cmd; \
if (result != UCS_OK) { \
std::string err = c10::str( \
"[", \
std::string(__FILE__), \
":", \
std::to_string(__LINE__), \
"] ", \
logger->getLogPrefix(), \
_error_msg, \
", error code ", \
result, \
": ", \
ucs_status_string(result), \
", system error code ", \
errno); \
TORCH_CHECK(false, err); \
} \
} while (0)
// Macros to print logs with unified format
#define TORCH_UCC_LOG_ERROR(_phase, _msg) \
LOG(ERROR) << logger->getLogPrefix(_phase) << "[ERROR] " << _msg;
#define TORCH_UCC_LOG_INFO(_phase, _msg) \
LOG(INFO) << logger->getLogPrefix(_phase) << "[INFO] " << _msg;
#define TORCH_UCC_LOG_DEBUG(_phase, _msg) \
VLOG(1) << logger->getLogPrefix(_phase) << "[DEBUG] " << _msg;
enum torch_ucc_phase_t {
TORCH_UCC_UNKNOWN = -1,
TORCH_UCC_INIT,
TORCH_UCC_HEALTH_CHECK,
TORCH_UCC_READY,
TORCH_UCC_COLL_POST,
TORCH_UCC_COLL_PROGRESS,
TORCH_UCC_FINALIZE,
};
const std::map<torch_ucc_phase_t, std::string> ucc_phase_map = {
{TORCH_UCC_UNKNOWN, "UNKNOWN"},
{TORCH_UCC_INIT, "INIT"},
{TORCH_UCC_HEALTH_CHECK, "HEALTH_CHECK"},
{TORCH_UCC_READY, "READY"},
{TORCH_UCC_COLL_POST, "COLL_POST"},
{TORCH_UCC_COLL_PROGRESS, "COLL_PROGRESS"},
{TORCH_UCC_FINALIZE, "FINALIZE"},
};
class CommTraceLogger;
class TORCH_API ProcessGroupUCCLogger : public torch::CustomClassHolder {
public:
ProcessGroupUCCLogger();
ProcessGroupUCCLogger(std::string log_prefix, torch_ucc_phase_t phase);
std::string getLogPrefix(torch_ucc_phase_t phase = TORCH_UCC_UNKNOWN);
void setLogPrefix(std::string log_prefix);
inline void setPhase(torch_ucc_phase_t phase) {
local_phase = phase;
}
void initCommsTracer();
void flushComms(int rank, int world_size);
std::shared_ptr<CommTraceLogger> trace_generator = nullptr;
protected:
std::string log_prefix;
torch_ucc_phase_t local_phase = TORCH_UCC_UNKNOWN;
bool initialized_CommTraceLogger = false;
};
struct torch_ucc_oob_coll_info_t {
c10::intrusive_ptr<Store> store;
uint32_t comm_id;
int rank;
int size;
void* rbuf;
size_t msglen;
std::string getKey(std::string key) {
return std::to_string(comm_id) + key;
}
};
class CommBase {
public:
CommBase(const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger_)
: logger(logger_) {}
virtual void progress() = 0;
virtual void free_request(ucc_coll_req_h request) = 0;
virtual ~CommBase() {}
c10::intrusive_ptr<ProcessGroupUCCLogger> logger;
};
class CommUCX : public CommBase {
public:
ucp_context_h context{nullptr};
ucp_worker_h worker{nullptr};
public:
void progress() override;
void free_request(ucc_coll_req_h request) override;
CommUCX(
int comm_size,
const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger);
~CommUCX();
};
class CommUCC : public CommBase {
public:
ucc_lib_h lib{nullptr};
ucc_context_h context{nullptr};
public:
void progress() override;
CommUCC(
std::shared_ptr<torch_ucc_oob_coll_info_t> oob,
const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger);
void free_request(ucc_coll_req_h request) override;
~CommUCC();
};
ucc_status_t oob_allgather(
void* sbuf,
void* rbuf,
size_t msglen,
void* coll_info,
void** req);
ucc_status_t oob_allgather_test(void* req);
ucc_status_t oob_allgather_free(void* req);
} // namespace c10d
#endif // USE_C10D_UCC


@@ -24,6 +24,10 @@
#include <c10d/ProcessGroupMPI.hpp>
#endif
#ifdef USE_C10D_UCC
#include <c10d/ProcessGroupUCC.hpp>
#endif
#include <c10d/PrefixStore.hpp>
#include <fmt/format.h>
#include <pybind11/chrono.h>
@@ -1557,6 +1561,25 @@ Example::
py::call_guard<py::gil_scoped_release>());
#endif
#ifdef USE_C10D_UCC
auto processGroupUCC =
intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupUCC>(
module, "ProcessGroupUCC", processGroup)
.def(
py::init([](const c10::intrusive_ptr<::c10d::Store>& store,
int rank,
int size,
const std::chrono::milliseconds& timeout) {
return c10::make_intrusive<::c10d::ProcessGroupUCC>(
store, rank, size, timeout);
}),
py::arg("store"),
py::arg("rank"),
py::arg("size"),
py::arg("timeout") = kProcessGroupDefaultTimeout,
py::call_guard<py::gil_scoped_release>());
#endif
py::class_<
::c10d::ProcessGroup::Work,
c10::intrusive_ptr<::c10d::ProcessGroup::Work>,

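With this binding in place, the process group can also be constructed directly from Python, bypassing init_process_group. A sketch, with host, port, and timeout illustrative (the import path matches the one used by distributed_c10d.py below):

import os
from datetime import timedelta

import torch.distributed as dist
from torch._C._distributed_c10d import ProcessGroupUCC

rank = int(os.environ.get("RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))

# Any c10d store works; TCPStore is the usual choice for multi-process runs.
store = dist.TCPStore("127.0.0.1", 29500, world_size, rank == 0)
pg = ProcessGroupUCC(store, rank, world_size, timedelta(seconds=60))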

@@ -39,6 +39,7 @@ from .rendezvous import register_rendezvous_handler, rendezvous # noqa: F401
_MPI_AVAILABLE = True
_NCCL_AVAILABLE = True
_GLOO_AVAILABLE = True
_UCC_AVAILABLE = True
_pickler = pickle.Pickler
_unpickler = pickle.Unpickler
@@ -59,6 +60,11 @@ try:
except ImportError:
_GLOO_AVAILABLE = False
try:
from torch._C._distributed_c10d import ProcessGroupUCC
except ImportError:
_UCC_AVAILABLE = False
logger = logging.getLogger(__name__)
@@ -86,7 +92,7 @@ def supports_complex(reduceOp: ReduceOp) -> bool:
class Backend(object):
"""
-An enum-like class of available backends: GLOO, NCCL, MPI, and other registered
+An enum-like class of available backends: GLOO, NCCL, UCC, MPI, and other registered
backends.
The values of this class are lowercase strings, e.g., ``"gloo"``. They can
@@ -105,6 +111,7 @@ class Backend(object):
UNDEFINED = "undefined"
GLOO = "gloo"
NCCL = "nccl"
UCC = "ucc"
MPI = "mpi"
TCP = "tcp"
_plugins: Dict[str, Callable] = {}
@@ -122,7 +129,7 @@ class Backend(object):
)
elif value == Backend.UNDEFINED:
raise ValueError("Invalid backend: '{}'".format(name))
-elif value != Backend.GLOO and value != Backend.NCCL and value != Backend.MPI:
+elif value != Backend.GLOO and value != Backend.NCCL and value != Backend.UCC and value != Backend.MPI:
value = name.lower()
return value
@@ -145,6 +152,9 @@ class Backend(object):
.. note:: This support of 3rd party backend is experimental and subject to change.
"""
# Allow UCC plugin if Pytorch is not built with native support.
# TODO: remove this exception once UCC plugin is fully deprecated.
if (name != Backend.UCC or (name == Backend.UCC and is_ucc_available())):
assert not hasattr(Backend, name.upper()), (
f"{name.upper()} c10d backend already exist"
)
@@ -412,6 +422,13 @@ def is_gloo_available():
return _GLOO_AVAILABLE
def is_ucc_available():
"""
Checks if the UCC backend is available.
"""
return _UCC_AVAILABLE
def is_initialized():
"""
Checking if the default process group has been initialized
@@ -511,12 +528,13 @@ def init_process_group(
Args:
backend (str or Backend): The backend to use. Depending on
build-time configurations, valid values include ``mpi``, ``gloo``,
-and ``nccl``. This field should be given as a lowercase string
-(e.g., ``"gloo"``), which can also be accessed via
+``nccl``, and ``ucc``. This field should be given as a lowercase
+string (e.g., ``"gloo"``), which can also be accessed via
:class:`Backend` attributes (e.g., ``Backend.GLOO``). If using
multiple processes per machine with ``nccl`` backend, each process
must have exclusive access to every GPU it uses, as sharing GPUs
-between processes can result in deadlocks.
+between processes can result in deadlocks. The ``ucc`` backend is
+experimental.
init_method (str, optional): URL specifying how to initialize the
process group. Default is "env://" if no
``init_method`` or ``store`` is specified.
@@ -547,6 +565,9 @@
continue executing user code since failed async NCCL operations
might result in subsequent CUDA operations running on corrupted
data. Only one of these two environment variables should be set.
For ``ucc``, blocking wait is supported similarly to NCCL. However,
async error handling works differently since UCC uses a progress
thread rather than a watchdog thread.
group_name (str, optional, deprecated): Group name.
pg_options (ProcessGroupOptions, optional): process group options
specifying what additional options need to be passed in during
@@ -682,7 +703,7 @@ def _new_process_group_helper(
is_default_group = len(group_ranks) == 0
backend = Backend(backend)
-pg: Union[ProcessGroupGloo, ProcessGroupMPI, ProcessGroupNCCL]
+pg: Union[ProcessGroupGloo, ProcessGroupMPI, ProcessGroupNCCL, ProcessGroupUCC]
if backend == Backend.MPI:
if not is_mpi_available():
raise RuntimeError(
@@ -767,6 +788,33 @@
)
_pg_map[pg] = (Backend.NCCL, store)
_pg_names[pg] = group_name
elif backend == Backend.UCC and is_ucc_available():
# TODO: once UCC plugin is fully deprecated, remove
# is_ucc_available() from above elif-condition and raise
# RuntimeError if is_ucc_available() returns false.
pg = ProcessGroupUCC(prefix_store, rank, world_size, timeout=timeout)
# In debug mode and if GLOO is available, wrap in a wrapper PG that
# enables enhanced collective checking for debugability.
if get_debug_level() == DebugLevel.DETAIL:
if not _GLOO_AVAILABLE:
logger.info(
"""TORCH_DISTRIBUTED_DEBUG was set to DETAIL, but
GLOO is not available. Build with Gloo to
create a wrapper process group in debug mode
to aid collective desynchronization debugging."""
)
else:
pg = _create_process_group_wrapper(
wrapped_pg=pg,
store_prefix=group_name,
store=store,
rank=rank,
world_size=world_size,
timeout=timeout,
)
_pg_map[pg] = (Backend.UCC, store)
_pg_names[pg] = group_name
else:
assert backend.upper() in Backend._plugins, (
f"unknown c10d backend type {backend.upper()}"
@@ -1064,7 +1112,7 @@ def batch_isend_irecv(p2p_op_list):
Send or Receive a batch of tensors asynchronously and return a list of requests.
Process each of the operations in ``p2p_op_list`` and return the corresponding
-requests. NCCL and Gloo backend are currently supported.
+requests. NCCL, Gloo, and UCC backends are currently supported.
Args:
p2p_op_list: A list of point-to-point operations(type of each operator is