#ifdef USE_C10D_UCC

#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
#include <c10/util/CallOnce.h>
#include <c10/util/env.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupUCC.hpp>
#include <torch/csrc/distributed/c10d/UCCTracing.hpp>
#include <torch/csrc/distributed/c10d/UCCUtils.hpp>
#include <list>
#include <memory>
#include <unordered_map>
#include <unordered_set>

namespace c10d {

namespace {

const std::map<c10::DeviceType, ucc_memory_type_t> ucc_mtype_map = {
    {c10::kCPU, UCC_MEMORY_TYPE_HOST},
    {c10::kCUDA, UCC_MEMORY_TYPE_CUDA},
};
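
// Maps a c10 device type to the corresponding UCC memory type; any device
// type not present in ucc_mtype_map is reported as UCC_MEMORY_TYPE_UNKNOWN
// and left for UCC to handle.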
ucc_memory_type_t to_ucc_memType(c10::DeviceType _c10_type) {
  if (ucc_mtype_map.find(_c10_type) != ucc_mtype_map.end())
    return ucc_mtype_map.at(_c10_type);
  else
    return UCC_MEMORY_TYPE_UNKNOWN;
}

const std::map<at::ScalarType, ucc_datatype_t> ucc_dtype_map = {
    {at::kByte, UCC_DT_UINT8},
    {at::kChar, UCC_DT_INT8},
    {at::kHalf, UCC_DT_FLOAT16},
    {at::kBFloat16, UCC_DT_BFLOAT16},
    {at::kDouble, UCC_DT_FLOAT64},
    {at::kFloat, UCC_DT_FLOAT32},
    {at::kInt, UCC_DT_INT32},
    {at::kLong, UCC_DT_INT64},
    {at::kBool, UCC_DT_UINT8},
};
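
// Maps a tensor's scalar type to a UCC datatype. Bool tensors are shipped as
// UCC_DT_UINT8 (one byte per element); unsupported scalar types raise a
// TORCH_CHECK error.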
ucc_datatype_t to_ucc_dType(at::Tensor _tensor) {
  if (_tensor.scalar_type() == at::kBool && _tensor.element_size() != 1) {
    TORCH_CHECK(
        false, "Size of Boolean type larger than 1 is not supported in UCC");
  }
  try {
    return ucc_dtype_map.at(_tensor.scalar_type());
  } catch (const std::out_of_range&) {
    TORCH_CHECK(false, "Unsupported data type for UCC");
  }
}

const std::map<ReduceOp, ucc_reduction_op_t> ucc_op_map = {
    {ReduceOp::SUM, UCC_OP_SUM},
    {ReduceOp::PRODUCT, UCC_OP_PROD},
    {ReduceOp::MIN, UCC_OP_MIN},
    {ReduceOp::MAX, UCC_OP_MAX},
    {ReduceOp::BAND, UCC_OP_BAND},
    {ReduceOp::BOR, UCC_OP_BOR},
    {ReduceOp::BXOR, UCC_OP_BXOR},
    {ReduceOp::AVG, UCC_OP_AVG},
};
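
// Translates a c10d ReduceOp into a UCC reduction op. Boolean tensors are
// stored as uint8, so logical OR/AND are emulated with MAX/MIN respectively;
// AVG has no meaningful boolean interpretation and is rejected.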
ucc_reduction_op_t to_ucc_reduceOp(
    const ReduceOp _op,
    const at::ScalarType _dt) {
  if (_dt == at::kBool) {
    if (_op == ReduceOp::SUM) {
      // bitwise or
      return UCC_OP_MAX;
    } else if (_op == ReduceOp::PRODUCT) {
      // bitwise and
      return UCC_OP_MIN;
    } else if (_op == ReduceOp::AVG) {
      TORCH_CHECK(false, "Cannot use ReduceOp.AVG with boolean inputs");
    }
  }

  try {
    return ucc_op_map.at(_op);
  } catch (const std::out_of_range&) {
    TORCH_CHECK(false, "Unsupported ReduceOp for UCC");
  }
}

struct torch_ucc_config_t {
  c10::once_flag flag;
  std::array<bool, 32> blocking_wait;
  bool enable_comms_logger;
  bool use_future;
  // Sharing UCC communicator among multiple PGs to save resource.
  bool shared_comm;
  // Using allgatherv to achieve allgather, without flattening the list of
  // (potentially non-contiguous) tensors.
  bool use_allgatherv;
  bool enable_health_check;
} torch_ucc_config;

std::unordered_map<std::string, std::string> torch_ucc_envs_map = {
    // TORCH_UCC_BLOCKING_WAIT allowed syntax:
    // - TORCH_UCC_BLOCKING_WAIT=none --> blocking wait completely disabled
    // - TORCH_UCC_BLOCKING_WAIT=all --> blocking wait completely enabled
    // - TORCH_UCC_BLOCKING_WAIT=allreduce,send,recv --> blocking wait enabled
    //   on selected operations
    // Supported operations:
    // [allgather,allgather_base,allreduce,alltoall,broadcast,
    //  gather,reduce,reduce_scatter,reduce_scatter_base,scatter,send,recv]
    {"TORCH_UCC_BLOCKING_WAIT", "none"},

    {"TORCH_UCC_USE_FUTURE", "1"},
    {"TORCH_UCC_PROFILING_ENABLE", "0"},
    {"TORCH_UCC_SHARED_COMM", "1"},
    {"TORCH_UCC_USE_ALLGATHERV", "0"},
    {"TORCH_UCC_ENABLE_HEALTH_CHECK", "0"},
    {"TORCH_UCC_ENABLE_COMMS_LOGGER", "0"},
};
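
// Parses the TORCH_UCC_BLOCKING_WAIT value described above into the list of
// operation types that should use blocking wait; "none" yields an empty list
// and "all" enables it for every supported operation.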
std::vector<OpType> parse_blocking_wait(std::string op_list_string) {
  const static std::unordered_map<std::string, OpType> str2op = {
      {"allgather", OpType::ALLGATHER},
      {"allgather_base", OpType::_ALLGATHER_BASE},
      {"allreduce", OpType::ALLREDUCE},
      {"alltoall_base", OpType::ALLTOALL_BASE},
      {"broadcast", OpType::BROADCAST},
      {"gather", OpType::GATHER},
      {"reduce", OpType::REDUCE},
      {"reduce_scatter", OpType::REDUCE_SCATTER},
      {"reduce_scatter_base", OpType::_REDUCE_SCATTER_BASE},
      {"scatter", OpType::SCATTER},
      {"send", OpType::SEND},
      {"recv", OpType::RECV},
  };
  auto op_list = parse_list(op_list_string);
  if (op_list == std::vector<std::string>{"none"}) {
    return {};
  }
  std::vector<OpType> result;
  if (op_list == std::vector<std::string>{"all"}) {
    for (auto entry : str2op) {
      result.push_back(entry.second);
    }
  } else {
    for (auto op_string : op_list) {
      result.push_back(str2op.at(op_string));
    }
  }
  return result;
}

} // namespace
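
// Populates torch_ucc_config once per process: defaults are set first, then
// every TORCH_UCC_* environment variable present in the environment overrides
// the corresponding entry in torch_ucc_envs_map.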
void read_config() {
  // default configuration
  torch_ucc_config.blocking_wait.fill(false);
  torch_ucc_config.use_future = true;
  torch_ucc_config.shared_comm = false;
  torch_ucc_config.use_allgatherv = false;
  torch_ucc_config.enable_health_check = false;
  torch_ucc_config.enable_comms_logger = false;

  // read all torch_ucc env. variables and update the map
  for (auto& [env_name, value] : torch_ucc_envs_map) {
    auto env = c10::utils::get_env(env_name.c_str());
    if (env.has_value()) {
      value = std::move(env.value());
    }
  }

  auto blocking_wait_str = torch_ucc_envs_map.at("TORCH_UCC_BLOCKING_WAIT");
  for (auto op : parse_blocking_wait(blocking_wait_str)) {
    torch_ucc_config.blocking_wait[(std::uint8_t)op] = true;
  }
  // barrier is always blocking
  torch_ucc_config.blocking_wait[(std::uint8_t)OpType::BARRIER] = true;

  torch_ucc_config.use_future =
      std::stoi(torch_ucc_envs_map.at("TORCH_UCC_USE_FUTURE"));
  torch_ucc_config.shared_comm =
      std::stoi(torch_ucc_envs_map.at("TORCH_UCC_SHARED_COMM"));
  torch_ucc_config.use_allgatherv =
      std::stoi(torch_ucc_envs_map.at("TORCH_UCC_USE_ALLGATHERV"));
  torch_ucc_config.enable_health_check =
      std::stoi(torch_ucc_envs_map.at("TORCH_UCC_ENABLE_HEALTH_CHECK"));
  torch_ucc_config.enable_comms_logger =
      std::stoi(torch_ucc_envs_map.at("TORCH_UCC_ENABLE_COMMS_LOGGER"));
}
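
// Argument validation helpers: collectives in this backend accept exactly one
// dense, contiguous tensor per rank, and CUDA tensors used in one operation
// must live on the same device.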
void check_device(c10::Device dev1, c10::Device dev2) {
  if (dev1.is_cuda() && dev2.is_cuda() && dev1 != dev2) {
    throw std::invalid_argument("ProcessGroupUCC multidevice is not supported");
  }
}

void check_tensor(const std::vector<at::Tensor>& tensors) {
  if (tensors.size() != 1) {
    throw std::invalid_argument(
        "ProcessGroupUCC takes 1 tensor. Got " +
        std::to_string(tensors.size()) + ". ");
  }
  if (!tensors[0].is_contiguous()) {
    throw std::invalid_argument(
        "ProcessGroupUCC input tensor has to be contiguous");
  }
  if (tensors[0].is_sparse()) {
    throw std::invalid_argument("ProcessGroupUCC input tensor has to be dense");
  }
  // TODO: check cuda case
}

ProcessGroupUCC::WorkUCC::~WorkUCC() {
#ifdef USE_CUDA
  if (fence && ep) {
    std::lock_guard<std::mutex> lock(ep->event_pool_mutex);
    ep->event_pool.push(std::move(fence));
  }
#endif
}

void ProcessGroupUCC::WorkUCC::setException() {
  if (exception() || !entry_) {
    return;
  }
  exception_ = entry_->eptr_;
}

void ProcessGroupUCC::WorkUCC::setAndThrowException() {
  setException();
  if (exception()) {
    std::rethrow_exception(exception());
  }
}

bool ProcessGroupUCC::WorkUCC::isCompleted() {
  if (!entry_) {
    return true;
  }
  setException();
  // status_ <= 0 to avoid listing all possible status codes. The main thread
  // needs to be unblocked when UCC (in progress thread) returns success (== 0)
  // or any error code (< 0).
  return exception() || entry_->status_ <= 0;
}

bool ProcessGroupUCC::WorkUCC::isSuccess() const {
  if (!entry_) {
    return true;
  }
  return !exception() && entry_->status_ == 0;
}
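
// Waits for the collective associated with this work object. For CUDA
// non-blocking mode the wait only makes the caller's current stream block on
// the recorded fence; otherwise the caller spins until the progress thread
// marks the request completed or failed.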
bool ProcessGroupUCC::WorkUCC::wait(std::chrono::milliseconds /* unused */) {
  if (torch_ucc_config.enable_comms_logger && logger_) {
    logger_->trace_generator->recordComms("wait", (uintptr_t)this, rank_);
  }
#ifdef USE_CUDA
  if (fence && !torch_ucc_config.blocking_wait[(int)opType_]) {
    // block user stream
    setAndThrowException();
    fence->block(at::cuda::getCurrentCUDAStream());
    return true;
  }
#endif
  // wait for complete. For blocking case, the main thread will be blocked in
  // this loop until the progress thread changes the status of this request.
  // If timeout occurs, UCC will return UCC_ERR_TIMEOUT as the status. The
  // main thread will throw out the exception then. There is no "abort"
  // function in UCC currently.
  while (!isCompleted())
    ;
  setAndThrowException();
  // manually call profiling end callbacks if they are set,
  // since progress thread does not own WorkUCC
  if (Work::recordFunctionEndCallback_) {
    Work::recordFunctionEndCallback_();
    Work::recordFunctionEndCallback_ = nullptr;
  }
  if (c10d::allow_inflight_collective_as_graph_input()) {
    c10d::unregister_work(
        c10::intrusive_ptr<
            ProcessGroupUCC::WorkUCC>::unsafe_reclaim_from_nonowning(this));
  }
  return true;
}

c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupUCC::WorkUCC::getFuture() {
  return future_;
}

int ProcessGroupUCC::WorkUCC::sourceRank() const {
  if (opType_ != OpType::RECV && opType_ != OpType::RECVANYSOURCE) {
    // Throw an error
    return Work::sourceRank();
  }
  return sourceRank_;
}

std::vector<at::Tensor> ProcessGroupUCC::WorkUCC::result() {
  return *outputs_;
}

void ProcessGroupUCC::ProgressEntry::finalize(std::exception_ptr eptr) {
  ucc_status_t status = UCC_OK;

  if (request_ != nullptr) {
    status = request_->status;
    comm_->free_request(request_);
  }
  if (eptr) {
    eptr_ = eptr;
  } else {
    status_ = status;
  }
  if (future_) {
    if (eptr) {
      future_->setError(eptr);
    } else {
      future_->markCompleted(
          c10::IValue(data ? data->dst : std::vector<at::Tensor>()));
    }
  }
}

Comm::Comm(
    const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger_,
    std::shared_ptr<torch_ucc_oob_coll_info_t> oob_,
    c10::Device dev,
    bool is_health_check)
    : logger(logger_),
      oob(oob_),
      ucc_comm(oob, logger),
      finalize_phase(
          is_health_check ? TORCH_UCC_HEALTH_CHECK : TORCH_UCC_FINALIZE),
      cuda_device_index(TORCH_UCC_DEVICE_NOT_SET) {
  if (dev.is_cuda()) {
    cuda_device_index = dev.index();
  }
  stop_progress_loop = false;
  collective_inprogress = false;
  progress_thread = std::thread(&Comm::progress_loop, this);
#ifdef _GNU_SOURCE
  pthread_setname_np(progress_thread.native_handle(), "ucc-progress");
#endif
}

Comm::~Comm() {
  std::unique_lock<std::mutex> lock(mutex);
  queue_consume_cv.wait(
      lock, [&] { return progress_queue.empty() && !collective_inprogress; });
  stop_progress_loop = true;
  lock.unlock();
  queue_produce_cv.notify_all();
  progress_thread.join();
}
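
// Returns the UCC communicator for a process group. Ranks agree on a unique
// comm id through the c10d store (rank 0 collects the proposals and publishes
// the maximum); when TORCH_UCC_SHARED_COMM is enabled, a single process-wide
// Comm is reused across process groups via a static weak_ptr.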
std::shared_ptr<Comm> Comm::get_comm(
    uint32_t& id,
    c10::Device dev,
    std::shared_ptr<torch_ucc_oob_coll_info_t> oob,
    const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger,
    bool is_health_check) {
  static std::mutex m;
  static std::weak_ptr<Comm> comm;
  static uint32_t comm_id;

  std::lock_guard<std::mutex> lock(m);
  id = comm_id;

  std::string group_id = "group_id";
  if (is_health_check) {
    group_id = c10::str(dev.type()) + "/" + group_id;
  }

  std::vector<uint8_t> remote_comm_id;
  oob->store->deleteKey(group_id + std::to_string(0));
  if (oob->rank != 0) {
    std::vector<uint8_t> val = std::vector<uint8_t>(
        reinterpret_cast<uint8_t*>(&id),
        reinterpret_cast<uint8_t*>(&id) + sizeof(id));
    oob->store->set(group_id + std::to_string(oob->rank), val);
  } else {
    for (int i = 1; i < oob->size; i++) {
      remote_comm_id = oob->store->get(group_id + std::to_string(i));
      oob->store->deleteKey(group_id + std::to_string(i));
      // Find the highest id.
      id = std::max(id, *(reinterpret_cast<uint32_t*>(remote_comm_id.data())));
    }
    std::vector<uint8_t> val = std::vector<uint8_t>(
        reinterpret_cast<uint8_t*>(&id),
        reinterpret_cast<uint8_t*>(&id) + sizeof(id));
    oob->store->set(group_id + std::to_string(oob->rank), val);
  }
  remote_comm_id = oob->store->get(group_id + std::to_string(0));
  oob->comm_id = *(reinterpret_cast<uint32_t*>(remote_comm_id.data()));
  // Prepare comm_id (static variable) to the next id.
  comm_id = oob->comm_id + 1;

  if (torch_ucc_config.shared_comm) {
    std::shared_ptr<Comm> shared_comm = comm.lock();
    if (!shared_comm) {
      shared_comm = std::make_shared<Comm>(logger, oob, dev, is_health_check);
      comm = shared_comm;
    } else {
      if (dev.is_cuda() && !is_health_check) {
        if ((shared_comm->cuda_device_index != TORCH_UCC_DEVICE_NOT_SET) &&
            (shared_comm->cuda_device_index != dev.index())) {
          TORCH_UCC_LOG_ERROR(
              is_health_check ? TORCH_UCC_HEALTH_CHECK : TORCH_UCC_INIT,
              "ucc communicator was initialized with different cuda device, "
              "multi device is not supported");
          throw std::invalid_argument(ucc_status_string(UCC_ERR_NOT_SUPPORTED));
        }
        shared_comm->cuda_device_index = dev.index();
      }
    }
    return shared_comm;
  } else {
    return std::make_shared<Comm>(logger, oob, dev, is_health_check);
  }
}
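
// Creates a UCC team over all ranks of the process group, using the c10d
// store-backed out-of-band allgather for bootstrap, and spins on
// ucc_team_create_test() until team creation completes.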
void Comm::ucc_create_team(
    ucc_team_h& team,
    std::shared_ptr<torch_ucc_oob_coll_info_t> oob) {
  ucc_status_t st;
  ucc_team_params_t team_params;
  team_params.mask = UCC_TEAM_PARAM_FIELD_EP | UCC_TEAM_PARAM_FIELD_EP_RANGE |
      UCC_TEAM_PARAM_FIELD_OOB;
  team_params.oob.allgather = oob_allgather;
  team_params.oob.req_test = oob_allgather_test;
  team_params.oob.req_free = oob_allgather_free;
  team_params.oob.coll_info = oob.get();
  team_params.oob.n_oob_eps = oob->size;
  team_params.oob.oob_ep = oob->rank;
  team_params.ep = oob->rank;
  team_params.ep_range = UCC_COLLECTIVE_EP_RANGE_CONTIG;
  TORCH_UCC_CHECK(
      ucc_team_create_post(&ucc_comm.context, 1, &team_params, &team),
      "failed to post team create");
  do {
    st = ucc_team_create_test(team);
    ucc_context_progress(ucc_comm.context);
  } while (st == UCC_INPROGRESS);
  TORCH_UCC_CHECK(st, "failed to create UCC team");
}

void Comm::ucc_destroy_team(ucc_team_h& team) {
  std::unique_lock<std::mutex> lock(mutex);
  queue_consume_cv.wait(
      lock, [&] { return progress_queue.empty() && !collective_inprogress; });

  ucc_status_t status;
  while (UCC_INPROGRESS == (status = ucc_team_destroy(team))) {
    if (UCC_OK != status) {
      TORCH_UCC_LOG_ERROR(
          finalize_phase,
          c10::str("ucc team destroy error: ", ucc_status_string(status)));
      break;
    }
  }

  lock.unlock();
}

void Comm::enqueue_collective(
    std::unique_ptr<ProcessGroupUCC::WorkData> data,
    c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> work,
    ucc_coll_args_t& coll,
    ucc_team_h team) {
  ucc_coll_req_h request;
  TORCH_UCC_CHECK(
      ucc_collective_init(&coll, &request, team), "failed to init collective");
  TORCH_UCC_CHECK_REQUEST(
      request, ucc_collective_post(request), "failed to post collective");

  auto entry =
      std::make_shared<ProcessGroupUCC::ProgressEntry>(&ucc_comm, request);
  entry->data = std::move(data);
  entry->future_ = work->getFuture();
  work->entry_ = entry;
  std::unique_lock<std::mutex> lock(mutex);
  progress_queue.push_back(entry);
  lock.unlock();
  queue_produce_cv.notify_one();
}

#ifdef USE_CUDA
void Comm::enqueue_cuda_collective(
    std::unique_ptr<ProcessGroupUCC::WorkData> data,
    c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> work,
    ucc_coll_args_t& coll,
    ucc_team_h team,
    ucc_ee_h ee) {
  ucc_coll_req_h request;
  TORCH_UCC_CHECK(
      ucc_collective_init(&coll, &request, team),
      "failed to init cuda collective");
  ucc_ev_t comp_ev, *post_ev;
  comp_ev.ev_type = UCC_EVENT_COMPUTE_COMPLETE;
  comp_ev.ev_context = nullptr;
  comp_ev.ev_context_size = 0;
  comp_ev.req = request;
  TORCH_UCC_CHECK_REQUEST(
      request,
      ucc_collective_triggered_post(ee, &comp_ev),
      "failed to post triggered collective");
  ucc_status_t st = ucc_ee_get_event(ee, &post_ev);
  TORCH_CHECK(st == UCC_OK && post_ev->ev_type == UCC_EVENT_COLLECTIVE_POST);
  ucc_ee_ack_event(ee, post_ev);
  auto entry =
      std::make_shared<ProcessGroupUCC::ProgressEntry>(&ucc_comm, request);
  entry->data = std::move(data);
  work->entry_ = entry;
  std::unique_lock<std::mutex> lock(mutex);
  progress_queue.push_back(entry);
  lock.unlock();
  queue_produce_cv.notify_one();
}
#endif
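
// Body of the dedicated progress thread: it pops posted requests from
// progress_queue, drives ucc_comm.progress() until each request leaves the
// in-progress state, and finalizes the corresponding ProgressEntry with
// either a status or an exception.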
void Comm::progress_loop() {
  std::unique_lock<std::mutex> lock(mutex);
#ifdef USE_CUDA
  bool device_set = false;
#endif
  while (!stop_progress_loop) {
    if (progress_queue.empty()) {
      queue_produce_cv.wait(lock);
      continue;
    }
    collective_inprogress = true;
    auto work = progress_queue.front();
    progress_queue.pop_front();
    lock.unlock();
#ifdef USE_CUDA
    if ((!device_set) && (cuda_device_index != TORCH_UCC_DEVICE_NOT_SET)) {
      c10::cuda::set_device(cuda_device_index);
      CUcontext pctx = nullptr;
      at::globalContext().getNVRTC().cuCtxGetCurrent(&pctx);
      if (C10_UNLIKELY(!pctx)) {
        at::globalContext().getNVRTC().cuDevicePrimaryCtxRetain(
            &pctx, cuda_device_index);
        at::globalContext().getNVRTC().cuCtxSetCurrent(pctx);
      }
      device_set = true;
    }
#endif
    std::exception_ptr eptr;
    try {
      while (work->request_->status > 0) {
        ucc_comm.progress();
      }
      if (work->request_->status < 0) {
        eptr = std::make_exception_ptr(
            std::runtime_error(ucc_status_string(work->request_->status)));
        std::string err_log = c10::str(
            "Failed to progress communication", // TODO: report exact op type
                                                // or id?
            ucc_status_string(work->request_->status));
        TORCH_UCC_LOG_ERROR(TORCH_UCC_COLL_PROGRESS, err_log);
      }
    } catch (...) {
      eptr = std::current_exception();
    }
    work->finalize(eptr);
    work = nullptr;
    collective_inprogress = false;
    queue_consume_cv.notify_one();
    lock.lock();
  }
}

ProcessGroupUCC::ProcessGroupUCC(
    const c10::intrusive_ptr<Store>& store,
    int rank,
    int size,
    std::chrono::duration<float> timeout)
    : Backend(rank, size), timeout_(timeout) {
  c10::call_once(torch_ucc_config.flag, read_config);
  oob = std::make_shared<torch_ucc_oob_coll_info_t>();
  oob->rank = rank;
  oob->size = size;
  oob->store = store;
  comm = nullptr;
  cuda_ee = nullptr;
  static uint32_t id = 0;
  uint32_t pg_id = id++;

  logger = c10::make_intrusive<ProcessGroupUCCLogger>(
      c10::str("[Rank ", rank_, "]", "[ProcessGroupUCC-", pg_id, "]"),
      TORCH_UCC_INIT);
  TORCH_UCC_LOG_INFO(
      TORCH_UCC_INIT,
      c10::str(
          "Created ProcessGroupUCC with ",
          size,
          " ranks, with timeout ",
          timeout_.count(),
          " secs"));
  std::string envs = "";
  for (auto& torch_ucc_env : torch_ucc_envs_map) {
    envs += ("\n\t" + torch_ucc_env.first + "=" + torch_ucc_env.second);
  }
  TORCH_UCC_LOG_INFO(
      TORCH_UCC_INIT,
      c10::str(
          "Successfully read and set ProcessGroupUCC env. variables as follows",
          envs));

  if (torch_ucc_config.enable_health_check) {
    // Perform health check by initializing dummy communicators and destroying
    // them. This will help indicate any UCC/UCX-related issues prior to the
    // first collective. Run it in a separate thread and wait on CV to handle
    // timeouts so that if there are hangs, the main thread can still run
    // correctly.
    runHealthCheck();
  }
  if (torch_ucc_config.enable_comms_logger) {
    logger->initCommsTracer();
  }
}

ProcessGroupUCC::~ProcessGroupUCC() {
  if (torch_ucc_config.enable_comms_logger) {
    logger->flushComms(this->getRank(), this->getSize());
  }
  if (comm) {
    logger->setPhase(TORCH_UCC_FINALIZE);
    comm->ucc_destroy_team(team);
    TORCH_UCC_LOG_INFO(
        TORCH_UCC_FINALIZE, "Successfully destroyed UCC library");
    try {
      if (cuda_ee) {
        ucc_ee_destroy(cuda_ee);
        ucc_ee_destroy(cuda_ee_p2p[0]);
        ucc_ee_destroy(cuda_ee_p2p[1]);
      }
    } catch (std::exception& ex) {
      TORCH_UCC_LOG_INFO(
          TORCH_UCC_FINALIZE,
          c10::str(
              "(~ProcessGroupUCC) Caught error in Store Operation .. ",
              "[",
              ex.what(),
              "]"));
    }
    comm = nullptr;
  }
}

#ifdef USE_CUDA
// Return CUDA device with ordinal given by input rank.
c10::Device getCUDADeviceForRank(int rank) {
  TORCH_CHECK(rank >= 0, "Invalid rank ", rank);
  auto numGPUs = at::cuda::getNumGPUs();
  auto deviceIdx = static_cast<c10::DeviceIndex>(rank % numGPUs);
  return c10::Device(c10::DeviceType::CUDA, deviceIdx);
}
#endif

void ProcessGroupUCC::runHealthCheck() {
  // Run health check in a separate thread and wait on CV to handle timeouts.
  // This design allows us to handle hangs.

  // When size_ is 1, there is no need to do any communication at all.
  if (size_ == 1)
    return;

  struct HealthCheckData {
    std::mutex healthCheckMutex;
    std::condition_variable healthCheckCv;
    bool uccHealthCheckSuccess = false;
    std::exception_ptr healthCheckException;
  } healthCheckData;

  auto t = std::thread([&healthCheckData, this]() {
    std::list<c10::Device> devices{c10::kCPU};
#ifdef USE_CUDA
    c10::cuda::OptionalCUDAGuard gpuGuard;
    if (at::cuda::is_available()) {
      devices.emplace_front(getCUDADeviceForRank(rank_));
    }
#endif
    for (auto device : devices) {
      bool is_last_device = (device == devices.back());
      try {
        auto oob = std::make_shared<torch_ucc_oob_coll_info_t>();
        oob->rank = this->oob->rank;
        oob->size = this->oob->size;
        oob->store = this->oob->store;
        ucc_team_h team = nullptr;
        uint32_t comm_id;
#ifdef USE_CUDA
        if (device.is_cuda()) {
          gpuGuard.set_index(device.index());
        }
#endif
        auto comm = Comm::get_comm(comm_id, device, oob, logger, true);
        comm->ucc_create_team(team, oob);
        comm->ucc_destroy_team(team);
        TORCH_UCC_LOG_INFO(
            TORCH_UCC_HEALTH_CHECK,
            c10::str(
                "UCC library health check succeeded for device ",
                c10::DeviceTypeName(device.type())));
        // Mark ucc health check as complete.
        if (is_last_device) {
          std::lock_guard<std::mutex> lk(healthCheckData.healthCheckMutex);
          healthCheckData.uccHealthCheckSuccess = true;
        }

        comm = nullptr;
        oob = nullptr;
        // Notify main thread the health check is complete.
        if (is_last_device) {
          healthCheckData.healthCheckCv.notify_one();
        }
      } catch (const std::exception&) {
        // Populate exception ptr.
        healthCheckData.healthCheckException = std::current_exception();
        // Unblock waiting main thread which will report exception.
        healthCheckData.healthCheckCv.notify_one();
      } // Unknown exceptions will just cause the program to terminate.
    }
  });
  // We don't need to join the thread, just need to verify health check via the
  // CV. Hence we detach the thread here.
  t.detach(); // NOLINT
  TORCH_UCC_LOG_INFO(
      TORCH_UCC_HEALTH_CHECK,
      c10::str(
          "will wait up to ",
          timeout_.count(),
          " secs for UCC health check to complete."));
  std::unique_lock<std::mutex> lock(healthCheckData.healthCheckMutex);
  healthCheckData.healthCheckCv.wait_for(lock, timeout_, [&healthCheckData]() {
    return healthCheckData.uccHealthCheckSuccess;
  });

  if (healthCheckData.healthCheckException) {
    std::rethrow_exception(healthCheckData.healthCheckException);
  }
  // If there is no exception, the likely culprit is a timeout/hang
  TORCH_CHECK(
      healthCheckData.uccHealthCheckSuccess,
      "ProcessGroupUCC: Health check failure: Failed to initialize UCC on rank ",
      rank_);
}

void ProcessGroupUCC::set_timeout(ucc_coll_args_t& args) {
  args.mask |= UCC_COLL_ARGS_FIELD_FLAGS;
  args.flags |= UCC_COLL_ARGS_FLAG_TIMEOUT;
  args.timeout = timeout_.count();
}

#ifdef USE_CUDA
std::unique_ptr<at::cuda::CUDAEvent> ProcessGroupUCC::getPooledEvent() {
  std::unique_ptr<at::cuda::CUDAEvent> ev;
  std::lock_guard<std::mutex> lock(ep.event_pool_mutex);
  if (ep.event_pool.empty()) {
    ev = std::make_unique<at::cuda::CUDAEvent>();
  } else {
    ev = std::move(ep.event_pool.front());
    ep.event_pool.pop();
  }
  return ev;
}
#endif
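
// Common post path for every collective: stamps the sequence number and
// timeout, creates the WorkUCC object, and hands the prepared ucc_coll_args_t
// to the progress thread. On CPU the collective is enqueued directly; on CUDA
// it is posted as a triggered operation on a backend-owned stream, with CUDA
// events fencing it against the caller's current stream.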
template <typename PreProcess, typename PostProcess>
c10::intrusive_ptr<Work> ProcessGroupUCC::collective_post(
    OpType opType,
    PreProcess preproc,
    PostProcess postproc,
    ucc_coll_args_t& coll,
    std::unique_ptr<ProcessGroupUCC::WorkData> data,
    c10::Device dev,
    std::vector<at::Tensor>& inputTensors,
    std::vector<at::Tensor>& outputTensors,
    const char* prof_title) {
  seq_++;
  set_timeout(coll);
  auto work = c10::make_intrusive<ProcessGroupUCC::WorkUCC>(
      opType, seq_, prof_title, inputTensors, logger);

  if (opType == OpType::RECV) {
    work->sourceRank_ = coll.root;
  }

  RECORD_COMMS_TRACE(
      logger->trace_generator,
      work,
      opType,
      this->getRank(),
      this->getSize(),
      inputTensors,
      outputTensors);

  // Store references to outputs to be used by result
  work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputTensors);
  switch (dev.type()) {
    case c10::DeviceType::CPU: {
      if (torch_ucc_config.use_future) {
        work->future_ = c10::make_intrusive<at::ivalue::Future>(
            c10::ListType::create(c10::TensorType::get()));
      }
      preproc();
      comm->enqueue_collective(std::move(data), work, coll, team);
      postproc();
      return work;
    }
#ifdef USE_CUDA
    case c10::DeviceType::CUDA: {
      auto cuda_ev = getPooledEvent();
      at::cuda::CUDAStream* op_stream;
      ucc_ee_h* op_ee;
      if (opType == OpType::SEND) {
        op_stream = stream_p2p[0].get();
        op_ee = &cuda_ee_p2p[0];
      } else if (opType == OpType::RECV) {
        op_stream = stream_p2p[1].get();
        op_ee = &cuda_ee_p2p[1];
      } else {
        op_stream = stream.get();
        op_ee = &cuda_ee;
      }

      cuda_ev->record(at::cuda::getCurrentCUDAStream(dev.index()));
      cuda_ev->block(*op_stream);
      at::cuda::CUDAStreamGuard guard(*op_stream);
      preproc();
      comm->enqueue_cuda_collective(std::move(data), work, coll, team, *op_ee);
      postproc();
      cuda_ev->record(*op_stream);
      work->fence = std::move(cuda_ev);
      work->ep = &ep;
      if (torch_ucc_config.use_future) {
        c10::cuda::CUDAMultiStreamGuard streamGuard(*op_stream);
        std::vector<c10::Device> devList{dev};
        work->future_ = c10::make_intrusive<at::ivalue::Future>(
            c10::ListType::create(c10::TensorType::get()), devList);
        // Add a callback that runs profiling end callbacks
        if (work->recordFunctionEndCallback_) {
          work->future_->addCallback([work](at::ivalue::Future& /* unused */) {
            work->recordFunctionEndCallback_();
          });
        }

        work->future_->markCompleted(c10::IValue(outputTensors));
      }
      return work;
    }
#endif // #ifdef USE_CUDA
    default: {
      TORCH_UCC_LOG_ERROR(
          TORCH_UCC_COLL_POST, c10::str("unsupported device type ", dev.str()));
      throw std::invalid_argument(ucc_status_string(UCC_ERR_NOT_SUPPORTED));
    }
  }
}

c10::intrusive_ptr<Work> ProcessGroupUCC::allgather(
    std::vector<std::vector<at::Tensor>>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const AllgatherOptions& /* unused */) {
  auto& tensor = inputTensors[0];
  check_device(tensor.device(), outputTensors[0][0].device());
  initComm(tensor.device());

  if (tensor.device().is_cpu() || torch_ucc_config.use_allgatherv) {
    AllgathervWorkData* data = new AllgathervWorkData(size_);
    for (int i = 0; i < size_; i++) {
      data->recv_lengths[i] = tensor.element_size() * tensor.numel();
      data->recv_offsets[i] = (uint64_t)outputTensors[0][i].data_ptr();
    }
    ucc_coll_args_t coll;
    coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
    coll.flags =
        UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT;
    coll.coll_type = UCC_COLL_TYPE_ALLGATHERV;
    coll.src.info.buffer = tensor.data_ptr();
    coll.src.info.count = tensor.element_size() * tensor.numel();
    coll.src.info.datatype = UCC_DT_UINT8;
    coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
    coll.dst.info_v.buffer = nullptr;
    coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data();
    coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data();
    coll.dst.info_v.datatype = UCC_DT_UINT8;
    coll.dst.info_v.mem_type =
        to_ucc_memType(outputTensors[0][0].device().type());
    SAVE_TENSORS(inputTensors, data->src);
    SAVE_TENSORS(outputTensors[0], data->dst);

    return collective_post(
        OpType::ALLGATHER,
        []() {},
        []() {},
        coll,
        std::unique_ptr<WorkData>(data),
        tensor.device(),
        inputTensors,
        outputTensors[0],
        "ucc:all_gather");
  } else {
    WorkData* data = new WorkData();
    std::vector<at::Tensor> flat_output(outputTensors.size());
    for (size_t i = 0; i < outputTensors.size(); i++) {
      TORCH_CHECK(
          outputTensors[i].size() == outputTensors.size() * size_,
          "Tensor output list is not valid for the number of participants");
      flat_output[i] = c10d::newLikeFlat(outputTensors, i);
    }
    SAVE_TENSORS(flat_output, data->flat);
    ucc_coll_args_t coll;
    coll.mask = 0;
    coll.flags = 0;
    coll.coll_type = UCC_COLL_TYPE_ALLGATHER;
    coll.src.info.buffer = tensor.data_ptr();
    coll.src.info.count = tensor.numel();
    coll.src.info.datatype = to_ucc_dType(tensor);
    coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
    coll.dst.info.buffer = flat_output[0].data_ptr();
    coll.dst.info.count = flat_output[0].numel();
    coll.dst.info.datatype = to_ucc_dType(flat_output[0]);
    coll.dst.info.mem_type =
        to_ucc_memType(outputTensors[0][0].device().type());

    auto copy_from_flat = [&] {
      bool asyncCopy = false;
#ifdef USE_CUDA
      bool isCuda = outputTensors[0][0].device().is_cuda();
#endif
      for (size_t i = 0; i < outputTensors.size(); i++) {
        auto inumel = inputTensors[i].numel();
        for (size_t j = 0; j < outputTensors[i].size(); j++) {
          TORCH_CHECK(
              (outputTensors[i][j].numel() == inumel),
              "Tensor operand counts must be same");
#ifdef USE_CUDA
          if (isCuda) {
            c10::cuda::CUDACachingAllocator::recordStream(
                outputTensors[i][j].storage().data_ptr(), (*stream));
            asyncCopy = true;
          }
#endif
          outputTensors[i][j].copy_(flat_output[i][j], asyncCopy);
        }
      }
    };
    return collective_post(
        OpType::ALLGATHER,
        []() {},
        copy_from_flat,
        coll,
        std::unique_ptr<WorkData>(data),
        tensor.device(),
        inputTensors,
        outputTensors[0],
        "ucc:all_gather");
  }
}

c10::intrusive_ptr<Work> ProcessGroupUCC::_allgather_base(
    at::Tensor& outputTensor,
    at::Tensor& inputTensor,
    const AllgatherOptions& opts) {
  check_tensor({outputTensor});
  check_tensor({inputTensor});
  initComm(outputTensor.device());

  WorkData* data = new WorkData();

  ucc_coll_args_t coll;
  coll.mask = 0;
  coll.flags = 0;
  coll.coll_type = UCC_COLL_TYPE_ALLGATHER;
  coll.src.info.buffer = inputTensor.data_ptr();
  coll.src.info.count = inputTensor.numel();
  coll.src.info.datatype = ucc_dtype_map.at(inputTensor.scalar_type());
  coll.src.info.mem_type = to_ucc_memType(inputTensor.device().type());
  coll.dst.info.buffer = outputTensor.data_ptr();
  coll.dst.info.count = outputTensor.numel();
  coll.dst.info.datatype = ucc_dtype_map.at(outputTensor.scalar_type());
  coll.dst.info.mem_type = to_ucc_memType(outputTensor.device().type());

  std::vector<at::Tensor> inputTensors = {inputTensor};
  std::vector<at::Tensor> outputTensors = {outputTensor};
  SAVE_TENSORS(inputTensors, data->src);
  SAVE_TENSORS(outputTensors, data->dst);

  return collective_post(
      OpType::_ALLGATHER_BASE,
      []() {},
      []() {},
      coll,
      std::unique_ptr<WorkData>(data),
      outputTensor.device(),
      inputTensors,
      outputTensors,
      "ucc:allgather_base");
}

c10::intrusive_ptr<Work> ProcessGroupUCC::allreduce(
    std::vector<at::Tensor>& tensors,
    const AllreduceOptions& opts) {
  check_tensor(tensors);
  auto& tensor = tensors[0];
  initComm(tensor.device());
  WorkData* data = new WorkData();

  ucc_coll_args_t coll;
  coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
  coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
  coll.coll_type = UCC_COLL_TYPE_ALLREDUCE;
  coll.op = to_ucc_reduceOp(opts.reduceOp, tensor.scalar_type());
  coll.src.info.buffer = nullptr;
  coll.src.info.count = tensor.numel();
  coll.src.info.datatype = to_ucc_dType(tensor);
  coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
  coll.dst.info.buffer = tensor.data_ptr();
  coll.dst.info.count = tensor.numel();
  coll.dst.info.datatype = to_ucc_dType(tensor);
  coll.dst.info.mem_type = to_ucc_memType(tensor.device().type());
  SAVE_TENSORS(tensors, data->dst);
  return collective_post(
      OpType::ALLREDUCE,
      []() {},
      []() {},
      coll,
      std::unique_ptr<WorkData>(data),
      tensor.device(),
      tensors,
      tensors,
      "ucc:all_reduce");
}

c10::intrusive_ptr<Work> ProcessGroupUCC::allreduce_coalesced(
    std::vector<at::Tensor>& /* unused */,
    const AllreduceCoalescedOptions& /* unused */) {
  throw std::invalid_argument(
      "ProcessGroupUCC does not support allreduce_coalesced");
}

c10::intrusive_ptr<Work> ProcessGroupUCC::alltoall(
    std::vector<at::Tensor>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const AllToAllOptions& /* unused */) {
  auto device = outputTensors[0].device();
  for (const auto r : c10::irange(outputTensors.size())) {
    TORCH_CHECK(
        device == outputTensors[r].device() &&
            device == inputTensors[r].device(),
        "Tensors must be on the same device")
  }

  initComm(device);
  ucc_coll_args_t coll;
  AlltoallWorkData* data;
  data = new AlltoallWorkData(size_);

  /* to avoid flattening the tensors, we use alltoallv to achieve alltoall as
     follows:
      1. store addresses of each tensor directly in displacements, keep buffer
         to nullptr, i.e., 0
      2. convert datatype to UINT8, which is always 1 byte, to avoid wrong size
         calculation in UCC layer
      3. post Alltoallv
  */
  for (const auto i : c10::irange(size_)) {
    data->send_lengths[i] =
        (uint64_t)(inputTensors[i].element_size() * inputTensors[i].numel());
    data->send_offsets[i] = (uint64_t)inputTensors[i].data_ptr();
    data->recv_lengths[i] =
        (uint64_t)(outputTensors[i].element_size() * outputTensors[i].numel());
    data->recv_offsets[i] = (uint64_t)outputTensors[i].data_ptr();
  }

  coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
  coll.flags =
      UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT;
  coll.coll_type = UCC_COLL_TYPE_ALLTOALLV;
  coll.src.info_v.buffer = 0;
  coll.src.info_v.counts = (ucc_count_t*)data->send_lengths.data();
  coll.src.info_v.displacements = (ucc_aint_t*)data->send_offsets.data();
  coll.src.info_v.datatype = UCC_DT_UINT8;
  coll.src.info_v.mem_type = to_ucc_memType(inputTensors[0].device().type());
  coll.dst.info_v.buffer = 0;
  coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data();
  coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data();
  coll.dst.info_v.datatype = UCC_DT_UINT8;
  coll.dst.info_v.mem_type = to_ucc_memType(outputTensors[0].device().type());

  SAVE_TENSORS(inputTensors, data->src);
  SAVE_TENSORS(outputTensors, data->dst);

  return collective_post(
      OpType::ALLTOALL,
      []() {},
      []() {},
      coll,
      std::unique_ptr<WorkData>(data),
      device,
      inputTensors,
      outputTensors,
      "ucc:alltoall");
}

c10::intrusive_ptr<Work> ProcessGroupUCC::alltoall_base(
    at::Tensor& outputTensor,
    at::Tensor& inputTensor,
    std::vector<int64_t>& outputSplitSizes,
    std::vector<int64_t>& inputSplitSizes,
    const AllToAllOptions& /* unused */) {
  check_device(inputTensor.device(), outputTensor.device());
  initComm(inputTensor.device());
  ucc_coll_args_t coll;
  AlltoallWorkData* data;

  if ((outputSplitSizes.size() == 0) && (inputSplitSizes.size() == 0)) {
    data = new AlltoallWorkData(0);
    TORCH_CHECK(
        (outputTensor.size(0) % size_ == 0) &&
            (inputTensor.size(0) % size_ == 0),
        "Tensor's dim 0 does not divide equally across group size");
    coll.mask = 0;
    coll.flags = 0;
    coll.coll_type = UCC_COLL_TYPE_ALLTOALL;
    coll.src.info.buffer = inputTensor.data_ptr();
    coll.src.info.count = inputTensor.element_size() * inputTensor.numel();
    coll.src.info.datatype = UCC_DT_UINT8;
    coll.src.info.mem_type = to_ucc_memType(inputTensor.device().type());
    coll.dst.info.buffer = outputTensor.data_ptr();
    coll.dst.info.count = outputTensor.element_size() * outputTensor.numel();
    coll.dst.info.datatype = UCC_DT_UINT8;
    coll.dst.info.mem_type = to_ucc_memType(outputTensor.device().type());
    coll.flags = 0;
  } else {
    data = new AlltoallWorkData(size_);
    c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_);
    c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_);
    computeLengthsAndOffsets(
        outputSplitSizes,
        outputTensor,
        &data->recv_lengths,
        &data->recv_offsets);
    computeLengthsAndOffsets(
        inputSplitSizes, inputTensor, &data->send_lengths, &data->send_offsets);
    coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
    coll.coll_type = UCC_COLL_TYPE_ALLTOALLV;
    coll.src.info_v.buffer = inputTensor.data_ptr();
    coll.src.info_v.counts = (ucc_count_t*)data->send_lengths.data();
    coll.src.info_v.displacements = (ucc_aint_t*)data->send_offsets.data();
    coll.src.info_v.datatype = to_ucc_dType(inputTensor);
    coll.src.info_v.mem_type = to_ucc_memType(inputTensor.device().type());
    coll.dst.info_v.buffer = outputTensor.data_ptr();
    coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data();
    coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data();
    coll.dst.info_v.datatype = to_ucc_dType(outputTensor);
    coll.dst.info_v.mem_type = to_ucc_memType(outputTensor.device().type());
    coll.flags = UCC_COLL_ARGS_FLAG_CONTIG_SRC_BUFFER |
        UCC_COLL_ARGS_FLAG_CONTIG_DST_BUFFER | UCC_COLL_ARGS_FLAG_COUNT_64BIT |
        UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT;

    if (torch_ucc_config.enable_comms_logger) {
      logger->trace_generator->recordOptionalInfo(
          outputSplitSizes, inputSplitSizes);
    }
  }
  std::vector<at::Tensor> inputTensors = {inputTensor};
  std::vector<at::Tensor> outputTensors = {outputTensor};
  SAVE_TENSORS(inputTensors, data->src);
  SAVE_TENSORS(outputTensors, data->dst);

  return collective_post(
      OpType::ALLTOALL_BASE,
      []() {},
      []() {},
      coll,
      std::unique_ptr<WorkData>(data),
      inputTensor.device(),
      inputTensors,
      outputTensors,
      "ucc:alltoall");
}

c10::intrusive_ptr<Work> ProcessGroupUCC::barrier(const BarrierOptions& opts) {
  c10::Device device = c10::Device(c10::DeviceType::CPU);
#ifdef USE_CUDA
  auto numGPUs = c10::cuda::device_count();
  if (!opts.device_ids.empty()) {
    device = c10::Device(c10::DeviceType::CUDA, opts.device_ids.front());
  } else if (comm && comm->cuda_device_index != TORCH_UCC_DEVICE_NOT_SET) {
    device = c10::Device(c10::DeviceType::CUDA, comm->cuda_device_index);
  } else if (numGPUs > 0) {
    int8_t deviceIdx = static_cast<int8_t>(c10::cuda::current_device());
    // if current device is 0, likely the device is not set, use the best guess
    if (0 == (int)deviceIdx) {
      deviceIdx = static_cast<int8_t>(this->getRank() % numGPUs);
    }
    TORCH_UCC_LOG_INFO(
        TORCH_UCC_COLL_POST,
        c10::str(
            "post barrier before specifying any GPU while there are ",
            numGPUs,
            " GPUs available. ",
            "Not clear if GPU barrier is required, using GPU ",
            (int)deviceIdx,
            " to perform barrier. ",
            "Specify device_ids option in barrier() to force ",
            "use of a particular device"));
    device = c10::Device(c10::DeviceType::CUDA, deviceIdx);
  }
#endif
  initComm(device);

  ucc_coll_args_t coll;
  coll.mask = 0;
  coll.flags = 0;
  coll.coll_type = UCC_COLL_TYPE_BARRIER;
  auto dummy_tensor = std::vector<at::Tensor>();
  return collective_post(
      OpType::BARRIER,
      []() {},
      []() {},
      coll,
      nullptr,
      device,
      dummy_tensor,
      dummy_tensor,
      "ucc:barrier");
}

c10::intrusive_ptr<Work> ProcessGroupUCC::broadcast(
    std::vector<at::Tensor>& tensors,
    const BroadcastOptions& opts) {
  check_tensor(tensors);
  auto& tensor = tensors[0];
  initComm(tensor.device());
  WorkData* data = new WorkData();

  ucc_coll_args_t coll;
  coll.mask = 0;
  coll.flags = 0;
  coll.coll_type = UCC_COLL_TYPE_BCAST;
  coll.src.info.buffer = tensor.data_ptr();
  coll.src.info.count = tensor.numel();
  coll.src.info.datatype = to_ucc_dType(tensor);
  coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
  coll.root = opts.rootRank;
  SAVE_TENSORS(tensors, data->dst);

  if (torch_ucc_config.enable_comms_logger) {
    logger->trace_generator->recordOptionalInfo(opts.rootRank);
  }

  return collective_post(
      OpType::BROADCAST,
      []() {},
      []() {},
      coll,
      std::unique_ptr<WorkData>(data),
      tensor.device(),
      tensors,
      tensors,
      "ucc:broadcast");
}

c10::intrusive_ptr<Work> ProcessGroupUCC::gather(
    std::vector<std::vector<at::Tensor>>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const GatherOptions& opts) {
  std::vector<at::Tensor> outputs;
  auto& input = inputTensors[0];
  initComm(input.device());

  AllgathervWorkData* data = new AllgathervWorkData(size_);
  ucc_coll_args_t coll;
  coll.root = opts.rootRank;
  coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
  coll.flags =
      UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT;
  coll.coll_type = UCC_COLL_TYPE_GATHERV;

  /* for non-root ranks, only src is valid */
  coll.src.info.buffer = input.data_ptr();
  coll.src.info.count = (uint64_t)(input.element_size() * input.numel());
  coll.src.info.datatype = UCC_DT_UINT8;
  coll.src.info.mem_type = to_ucc_memType(input.device().type());

  if (getRank() == opts.rootRank) {
    if (outputTensors.size() != 1) {
      TORCH_UCC_LOG_ERROR(
          TORCH_UCC_COLL_POST,
          c10::str(
              "gather requires a single-element output list containing a list with ",
              getSize(),
              " tensors."));
    } else if (outputTensors[0].size() != static_cast<size_t>(getSize())) {
      TORCH_UCC_LOG_ERROR(
          TORCH_UCC_COLL_POST,
          c10::str(
              "Incorrect output list size ",
              outputTensors[0].size(),
              ". Output list size should be ",
              getSize(),
              ", same as size of the process group."));
    }
    outputs = outputTensors[0];

    for (int i = 0; i < size_; i++) {
      data->recv_lengths[i] =
          (uint64_t)(outputs[i].element_size() * outputs[i].numel());
      data->recv_offsets[i] = (uint64_t)outputs[i].data_ptr();
    }
    /* use gatherv and store non-contiguous addresses in displacements to avoid
     * flattening outputTensors */
    coll.dst.info_v.buffer = nullptr;
    coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data();
    coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data();
    coll.dst.info_v.datatype = UCC_DT_UINT8;
    coll.dst.info_v.mem_type = to_ucc_memType(outputs[0].device().type());

    SAVE_TENSORS(outputs, data->dst);
  } else {
    // for non-root ranks, outputTensors should be an empty list
    if (!outputTensors.empty()) {
      TORCH_UCC_LOG_ERROR(
          TORCH_UCC_COLL_POST, "requires empty output on non-root");
    }
    outputs = {};
    // append an empty tensor to the list to be used by future mark
    outputs.emplace_back();
  }

  SAVE_TENSORS(inputTensors, data->src);

  return collective_post(
      OpType::GATHER,
      []() {},
      []() {},
      coll,
      std::unique_ptr<WorkData>(data),
      input.device(),
      inputTensors,
      outputs,
      "ucc:gather");
}

c10::intrusive_ptr<Work> ProcessGroupUCC::reduce(
    std::vector<at::Tensor>& tensors,
    const ReduceOptions& opts) {
  check_tensor(tensors);
  auto& tensor = tensors[0];
  initComm(tensor.device());
  WorkData* data = new WorkData();

  ucc_coll_args_t coll;
  coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
  coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
  coll.coll_type = UCC_COLL_TYPE_REDUCE;
  coll.op = ucc_op_map.at(opts.reduceOp);
  coll.root = opts.rootRank;
  coll.src.info.buffer = tensor.data_ptr();
  coll.src.info.count = tensor.numel();
  coll.src.info.datatype = ucc_dtype_map.at(tensor.scalar_type());
  coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
  coll.dst.info.buffer = tensor.data_ptr();
  coll.dst.info.count = tensor.numel();
  coll.dst.info.datatype = ucc_dtype_map.at(tensor.scalar_type());
  coll.dst.info.mem_type = to_ucc_memType(tensor.device().type());
  SAVE_TENSORS(tensors, data->dst);
  return collective_post(
      OpType::REDUCE,
      []() {},
      []() {},
      coll,
      std::unique_ptr<WorkData>(data),
      tensor.device(),
      tensors,
      tensors,
      "ucc:reduce");
}

c10::intrusive_ptr<Work> ProcessGroupUCC::reduce_scatter(
    std::vector<at::Tensor>& outputTensors,
    std::vector<std::vector<at::Tensor>>& inputTensors,
    const ReduceScatterOptions& opts) {
  TORCH_CHECK(
      (outputTensors.size() == inputTensors.size()),
      "Tensor input/output list for reduce_scatter must have same size");
  check_tensor(outputTensors);
  check_device(inputTensors[0][0].device(), outputTensors[0].device());
  initComm(inputTensors[0][0].device());
  auto data = std::make_unique<WorkData>();
  std::vector<at::Tensor> flat_input(inputTensors.size());
  for (size_t i = 0; i < inputTensors.size(); i++) {
    TORCH_CHECK(
        inputTensors[i].size() == inputTensors.size() * size_,
        "Tensor input list is not valid for the number of participants");
    flat_input[i] = c10d::newLikeFlat(inputTensors, i);
  }
  SAVE_TENSORS(flat_input, data->flat);
  check_tensor(flat_input);
  ucc_coll_args_t coll;
  coll.mask = 0;
  coll.flags = 0;
  coll.coll_type = UCC_COLL_TYPE_REDUCE_SCATTER;
  coll.op = to_ucc_reduceOp(opts.reduceOp, flat_input[0].scalar_type());

  coll.src.info.buffer = flat_input[0].data_ptr();
  coll.src.info.count = flat_input[0].numel();
  coll.src.info.datatype = to_ucc_dType(flat_input[0]);
  coll.src.info.mem_type = to_ucc_memType(flat_input[0].device().type());
  coll.dst.info.buffer = outputTensors[0].data_ptr();
  coll.dst.info.count = outputTensors[0].numel();
  coll.dst.info.datatype = to_ucc_dType(outputTensors[0]);
  coll.dst.info.mem_type = to_ucc_memType(outputTensors[0].device().type());

  SAVE_TENSORS(inputTensors[0], data->src);
  SAVE_TENSORS(outputTensors, data->dst);

  auto copy_to_flat = [&] {
    bool asyncCopy = false;
    auto isize = inputTensors.size();
#ifdef USE_CUDA
    bool isCuda = inputTensors[0][0].device().is_cuda();
#endif
    for (size_t i = 0; i < isize; i++) {
      auto onumel = outputTensors[i].numel();
      for (size_t j = 0; j < inputTensors[i].size(); j++) {
        TORCH_CHECK(
            (inputTensors[i][j].numel() == onumel),
            "Tensor operand counts must be same");
#ifdef USE_CUDA
        if (isCuda) {
          c10::cuda::CUDACachingAllocator::recordStream(
              inputTensors[i][j].storage().data_ptr(), (*stream));
          asyncCopy = true;
        }
#endif
        flat_input[i][j].copy_(inputTensors[i][j], asyncCopy);
      }
    }
  };

  return collective_post(
      OpType::REDUCE_SCATTER,
      copy_to_flat,
      []() {},
      coll,
      std::move(data),
      inputTensors[0][0].device(),
      inputTensors[0],
      outputTensors,
      "ucc:reduce_scatter");
}

c10::intrusive_ptr<Work> ProcessGroupUCC::_reduce_scatter_base(
    at::Tensor& outputTensor,
    at::Tensor& inputTensor,
    const ReduceScatterOptions& opts) {
  check_tensor({outputTensor});
  check_tensor({inputTensor});
  initComm(outputTensor.device());

  auto data = std::make_unique<WorkData>();

  ucc_coll_args_t coll;
  coll.mask = 0;
  coll.flags = 0;
  coll.coll_type = UCC_COLL_TYPE_REDUCE_SCATTER;
  coll.op = to_ucc_reduceOp(opts.reduceOp, inputTensor.scalar_type());

  coll.src.info.buffer = inputTensor.data_ptr();
  coll.src.info.count = inputTensor.numel();
  coll.src.info.datatype = ucc_dtype_map.at(inputTensor.scalar_type());
  coll.src.info.mem_type = to_ucc_memType(inputTensor.device().type());
  coll.dst.info.buffer = outputTensor.data_ptr();
  coll.dst.info.count = outputTensor.numel();
  coll.dst.info.datatype = ucc_dtype_map.at(outputTensor.scalar_type());
  coll.dst.info.mem_type = to_ucc_memType(outputTensor.device().type());

  std::vector<at::Tensor> inputTensors = {inputTensor};
  std::vector<at::Tensor> outputTensors = {outputTensor};
  SAVE_TENSORS(inputTensors, data->src);
  SAVE_TENSORS(outputTensors, data->dst);

  return collective_post(
      OpType::_REDUCE_SCATTER_BASE,
      []() {},
      []() {},
      coll,
      std::move(data),
      outputTensor.device(),
      inputTensors,
      outputTensors,
      "ucc:_reduce_scatter_base");
}

c10::intrusive_ptr<Work> ProcessGroupUCC::scatter(
    std::vector<at::Tensor>& outputTensors,
    std::vector<std::vector<at::Tensor>>& inputTensors,
    const ScatterOptions& opts) {
  auto& tensor = outputTensors[0];
  initComm(tensor.device());

  ScattervWorkData* data = new ScattervWorkData(size_);
  ucc_coll_args_t coll;
  coll.root = opts.rootRank;
  coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
  coll.flags =
      UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT;
  coll.coll_type = UCC_COLL_TYPE_SCATTERV;

  if (getRank() == opts.rootRank) {
    /* src is only valid at the root rank */
    if (inputTensors.size() != 1) {
      TORCH_UCC_LOG_ERROR(
          TORCH_UCC_COLL_POST,
          c10::str(
              "scatter requires a single-element input list containing a list with ",
              getSize(),
              " tensors."));
    } else if (inputTensors[0].size() != static_cast<size_t>(getSize())) {
      TORCH_UCC_LOG_ERROR(
          TORCH_UCC_COLL_POST,
          c10::str(
              "Incorrect input list size ",
              inputTensors[0].size(),
              ". Input list size should be ",
              getSize(),
              ", same as size of the process group."));
    }

    for (int i = 0; i < size_; i++) {
      data->send_lengths[i] = (uint64_t)tensor.element_size() * tensor.numel();
      data->send_offsets[i] = (uint64_t)inputTensors[0][i].data_ptr();
    }
    /* use scatterv and store non-contiguous addresses in displacements to
     * avoid flattening inputTensors */
    coll.src.info_v.buffer = nullptr;
    coll.src.info_v.counts = (ucc_count_t*)data->send_lengths.data();
    coll.src.info_v.displacements = (ucc_aint_t*)data->send_offsets.data();
    coll.src.info_v.datatype = UCC_DT_UINT8;
    coll.src.info_v.mem_type =
        to_ucc_memType(inputTensors[0][0].device().type());

    SAVE_TENSORS(inputTensors[0], data->src);
  } else {
    // for non-root ranks, inputTensors should be an empty list
    if (!inputTensors.empty()) {
      TORCH_UCC_LOG_ERROR(
          TORCH_UCC_COLL_POST, "requires empty input on non-root");
    }
  }

  coll.dst.info.buffer = tensor.data_ptr();
  coll.dst.info.count = (uint64_t)tensor.element_size() * tensor.numel();
  coll.dst.info.datatype = UCC_DT_UINT8;
  coll.dst.info.mem_type = to_ucc_memType(tensor.device().type());
  SAVE_TENSORS(outputTensors, data->dst);

  return collective_post(
      OpType::SCATTER,
      []() {},
      []() {},
      coll,
      std::unique_ptr<WorkData>(data),
      tensor.device(),
      (getRank() == opts.rootRank) ? inputTensors[0] : outputTensors,
      outputTensors,
      "ucc:scatter");
}
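
// Point-to-point send/recv are implemented on top of a UCC broadcast
// restricted to a two-rank active set: the sender is the broadcast root and
// the stride selects the single peer, with the user-provided tag
// disambiguating concurrent operations.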
c10::intrusive_ptr<Work> ProcessGroupUCC::send(
    std::vector<at::Tensor>& tensors,
    int dstRank,
    int tag) {
  check_tensor(tensors);
  auto& tensor = tensors[0];
  initComm(tensor.device());

  WorkData* data = new WorkData();
  ucc_coll_args_t coll;
  coll.tag = tag;
  coll.mask = UCC_COLL_ARGS_FIELD_ACTIVE_SET | UCC_COLL_ARGS_FIELD_TAG;
  coll.flags = 0;
  coll.coll_type = UCC_COLL_TYPE_BCAST;
  coll.src.info.buffer = tensor.data_ptr();
  coll.src.info.count = tensor.numel();
  coll.src.info.datatype = to_ucc_dType(tensor);
  coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
  coll.root = getRank();

  coll.active_set.size = 2;
  coll.active_set.start = getRank();
  coll.active_set.stride = dstRank - getRank();
  SAVE_TENSORS(tensors, data->dst);

  return collective_post(
      OpType::SEND,
      []() {},
      []() {},
      coll,
      std::unique_ptr<WorkData>(data),
      tensor.device(),
      tensors,
      tensors,
      "ucc:send");
}

c10::intrusive_ptr<Work> ProcessGroupUCC::recv(
    std::vector<at::Tensor>& tensors,
    int srcRank,
    int tag) {
  check_tensor(tensors);
  auto& tensor = tensors[0];
  initComm(tensor.device());

  WorkData* data = new WorkData();
  ucc_coll_args_t coll;
  coll.tag = tag;
  coll.mask = UCC_COLL_ARGS_FIELD_ACTIVE_SET | UCC_COLL_ARGS_FIELD_TAG;
  coll.flags = 0;
  coll.coll_type = UCC_COLL_TYPE_BCAST;
  coll.src.info.buffer = tensor.data_ptr();
  coll.src.info.count = tensor.numel();
  coll.src.info.datatype = to_ucc_dType(tensor);
  coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
  coll.root = srcRank;

  coll.active_set.size = 2;
  coll.active_set.start = srcRank;
  coll.active_set.stride = getRank() - srcRank;
  SAVE_TENSORS(tensors, data->dst);

  return collective_post(
      OpType::RECV,
      []() {},
      []() {},
      coll,
      std::unique_ptr<WorkData>(data),
      tensor.device(),
      tensors,
      tensors,
      "ucc:recv");
}

void ProcessGroupUCC::setSequenceNumberForGroup() {}

uint64_t ProcessGroupUCC::getSequenceNumberForGroup() {
  return seq_;
}

c10::intrusive_ptr<Backend> ProcessGroupUCC::createProcessGroupUCC(
    const c10::intrusive_ptr<::c10d::Store>& store,
    int rank,
    int size,
    const std::chrono::duration<float>& timeout) {
  return c10::make_intrusive<ProcessGroupUCC>(store, rank, size, timeout);
}
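
// Usage sketch (not part of this file): from Python this backend is typically
// reached via torch.distributed.init_process_group(backend="ucc"), which ends
// up constructing a ProcessGroupUCC through the factory above.

// Lazily creates the UCC communicator and team on first use, pins the CUDA
// device for CUDA tensors, and sets up the execution engines and dedicated
// streams used by collective_post for CUDA collectives and p2p.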
void ProcessGroupUCC::initComm(c10::Device dev) {
  if (!comm) {
#ifdef USE_CUDA
    if (dev.is_cuda()) {
      c10::cuda::set_device(dev.index());
    }
#endif
    comm = Comm::get_comm(comm_id, dev, oob, logger);
    TORCH_UCC_LOG_INFO(TORCH_UCC_INIT, "Successfully initialized UCX library");
    comm->ucc_create_team(team, oob);
    TORCH_UCC_LOG_INFO(TORCH_UCC_INIT, "Successfully initialized UCC library");
    logger->setPhase(TORCH_UCC_READY);
  } else {
    if (dev.is_cuda()) {
      if ((comm->cuda_device_index != TORCH_UCC_DEVICE_NOT_SET) &&
          (comm->cuda_device_index != dev.index())) {
        TORCH_UCC_LOG_ERROR(
            TORCH_UCC_INIT,
            "ucc communicator was initialized with different cuda device, "
            "multi device is not supported");
        throw std::invalid_argument(ucc_status_string(UCC_ERR_NOT_SUPPORTED));
      }
      comm->cuda_device_index = dev.index();
    }
  }
#ifdef USE_CUDA
  // Create UCC execution engine.
  if (!cuda_ee && dev.is_cuda()) {
    stream = std::make_unique<at::cuda::CUDAStream>(
        at::cuda::getStreamFromPool(true, dev.index()));
    ucc_ee_params_t params;
    params.ee_type = UCC_EE_CUDA_STREAM;
    params.ee_context = (void*)stream->stream();
    params.ee_context_size = sizeof(cudaStream_t);
    TORCH_UCC_CHECK(
        ucc_ee_create(team, &params, &cuda_ee),
        "failed to create UCC execution engine");
    for (int i = 0; i < 2; i++) {
      stream_p2p[i] = std::make_unique<at::cuda::CUDAStream>(
          at::cuda::getStreamFromPool(true, dev.index()));
      ucc_ee_params_t params;
      params.ee_type = UCC_EE_CUDA_STREAM;
      params.ee_context = (void*)stream_p2p[i]->stream();
      params.ee_context_size = sizeof(cudaStream_t);
      TORCH_UCC_CHECK(
          ucc_ee_create(team, &params, &cuda_ee_p2p[i]),
          "failed to create UCC P2P execution engine");
    }
  }
#endif
}

} // namespace c10d

#endif // USE_C10D_UCC