#include <c10d/ProcessGroupGloo.hpp>

#include <c10d/GlooDeviceFactory.hpp>
#include <c10d/Utils.hpp>

#include <netdb.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>

#include <chrono>
#include <cstring>
#include <type_traits>

#include <gloo/allgather.h>
#include <gloo/allgatherv.h>
#include <gloo/allreduce.h>
#include <gloo/barrier.h>
#include <gloo/broadcast.h>
#include <gloo/gather.h>
#include <gloo/reduce.h>
#include <gloo/scatter.h>

#ifdef USE_CUDA
#include <ATen/cuda/CUDAEvent.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/PinnedMemoryAllocator.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#endif

#include <gloo/rendezvous/context.h>
#include <gloo/rendezvous/prefix_store.h>

#define GENERATE_ALL_TYPES(type, func, args...)        \
  switch (type) {                                      \
    case ::at::ScalarType::Float:                      \
      func<float>(args);                               \
      break;                                           \
    case ::at::ScalarType::Double:                     \
      func<double>(args);                              \
      break;                                           \
    case ::at::ScalarType::Half:                       \
      func<gloo::float16>(args);                       \
      break;                                           \
    case ::at::ScalarType::Char:                       \
      func<int8_t>(args);                              \
      break;                                           \
    case ::at::ScalarType::Byte:                       \
      func<uint8_t>(args);                             \
      break;                                           \
    case ::at::ScalarType::Int:                        \
      func<int32_t>(args);                             \
      break;                                           \
    case ::at::ScalarType::Long:                       \
      func<int64_t>(args);                             \
      break;                                           \
    default:                                           \
      throw std::runtime_error("Invalid scalar type"); \
  }

namespace c10d {

namespace {

// Wrap c10d store as Gloo store
class GlooStore : public ::gloo::rendezvous::Store {
 public:
  GlooStore(const std::shared_ptr<::c10d::Store>& store) : store_(store) {}

  void set(const std::string& key, const std::vector<char>& value) override {
    std::vector<uint8_t> tmp(value.begin(), value.end());
    store_->set(key, tmp);
  }

  std::vector<char> get(const std::string& key) override {
    auto value = store_->get(key);
    return std::vector<char>(value.begin(), value.end());
  }

  void wait(const std::vector<std::string>& keys) override {
    store_->wait(keys, Store::kDefaultTimeout);
  }

  void wait(
      const std::vector<std::string>& keys,
      const std::chrono::milliseconds& timeout) override {
    store_->wait(keys, timeout);
  }

 protected:
  std::shared_ptr<::c10d::Store> store_;
};

typedef void (*ReduceFunc)(void*, const void*, const void*, size_t);

template <
    typename T,
    typename std::enable_if<!std::is_integral<T>::value, int>::type = 0>
ReduceFunc toFunction(const ReduceOp& r) {
  switch (r) {
    case ReduceOp::SUM:
      return ReduceFunc(&::gloo::sum<T>);
    case ReduceOp::PRODUCT:
      return ReduceFunc(&::gloo::product<T>);
    case ReduceOp::MIN:
      return ReduceFunc(&::gloo::min<T>);
    case ReduceOp::MAX:
      return ReduceFunc(&::gloo::max<T>);
    case ReduceOp::BAND:
      throw std::runtime_error(
          "Cannot use ReduceOp.BAND with non-integral dtype");
      break;
    case ReduceOp::BOR:
      throw std::runtime_error(
          "Cannot use ReduceOp.BOR with non-integral dtype");
      break;
    case ReduceOp::BXOR:
      throw std::runtime_error(
          "Cannot use ReduceOp.BXOR with non-integral dtype");
      break;
    case ReduceOp::UNUSED:
      break;
  }

  throw std::runtime_error("Unhandled ReduceOp");
}

// Bitwise AND with SFINAE guard for integral types.
template <
    typename T,
    typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
void band(void* c, const void* a, const void* b, size_t n) {
  auto tc = static_cast<T*>(c);
  auto ta = static_cast<const T*>(a);
  auto tb = static_cast<const T*>(b);
  for (size_t i = 0; i < n; i++) {
    tc[i] = ta[i] & tb[i];
  }
}

// Bitwise OR with SFINAE guard for integral types.
template <
    typename T,
    typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
void bor(void* c, const void* a, const void* b, size_t n) {
  auto tc = static_cast<T*>(c);
  auto ta = static_cast<const T*>(a);
  auto tb = static_cast<const T*>(b);
  for (size_t i = 0; i < n; i++) {
    tc[i] = ta[i] | tb[i];
  }
}

// Bitwise XOR with SFINAE guard for integral types.
template < typename T, typename std::enable_if::value, int>::type = 0> void bxor(void* c, const void* a, const void* b, size_t n) { auto tc = static_cast(c); auto ta = static_cast(a); auto tb = static_cast(b); for (size_t i = 0; i < n; i++) { tc[i] = ta[i] ^ tb[i]; } } template < typename T, typename std::enable_if::value, int>::type = 0> ReduceFunc toFunction(const ReduceOp& r) { switch (r) { case ReduceOp::SUM: return ReduceFunc(&::gloo::sum); case ReduceOp::PRODUCT: return ReduceFunc(&::gloo::product); case ReduceOp::MIN: return ReduceFunc(&::gloo::min); case ReduceOp::MAX: return ReduceFunc(&::gloo::max); case ReduceOp::BAND: return ReduceFunc(&band); case ReduceOp::BOR: return ReduceFunc(&bor); case ReduceOp::BXOR: return ReduceFunc(&bxor); case ReduceOp::UNUSED: break; } throw std::runtime_error("Unhandled ReduceOp"); } template void setInputs(O& opts, std::vector& tensors) { opts.setInputs(getDataPointers(tensors), tensors[0].numel()); } template void setInput(O& opts, at::Tensor& tensor) { opts.setInput(getDataPointer(tensor), tensor.numel()); } template void setOutputs(O& opts, std::vector& tensors) { opts.setOutputs(getDataPointers(tensors), tensors[0].numel()); } template void setOutput(O& opts, at::Tensor& tensor) { opts.setOutput(getDataPointer(tensor), tensor.numel()); } template void setOutput(O& opts, at::Tensor& tensor, std::vector& counts) { opts.setOutput(getDataPointer(tensor), counts); } #ifdef USE_CUDA at::Tensor pinnedLike(at::Tensor& tensor) { auto* allocator = at::cuda::getPinnedMemoryAllocator(); auto storage = c10::Storage( tensor.dtype(), at::detail::computeStorageSize(tensor.sizes(), tensor.strides()), allocator, /*resizable=*/false); return at::empty({0}, tensor.options().device(at::kCPU)) .set_(storage, 0, tensor.sizes(), tensor.strides()); } // This function initializes a vector of CUDA streams, one for every // tensor in the input tensor vector, and ensures that these streams are // synchronized with the current default streams. This is needed so // that new work on the new streams is serialized w.r.t. all operations // on the tensors. void initializeStreamsEvents( std::vector& tensors, std::vector& streams, std::vector& events) { at::cuda::OptionalCUDAGuard guard; streams.reserve(tensors.size()); events.resize(tensors.size()); for (size_t i = 0; i < tensors.size(); i++) { guard.set_index(tensors[i].device().index()); // Record event on current stream events[i].record(at::cuda::getCurrentCUDAStream()); // Get a non-default stream to execute asynchronous CUDA operations // on for this device. This ensures that the default stream used // by the caller is not occupied by c10d related operations. streams.push_back(at::cuda::getStreamFromPool( /* isHighPriority */ true, tensors[i].device().index())); // Ensure the new stream is synchronized with the current stream. events[i].block(streams[i]); // `tensors` are created on a different stream. Hence, they must record // new streams in this Work to prevent being freed before the Work finishes. if (tensors[i].is_sparse()) { if (tensors[i].is_coalesced()) { c10::cuda::CUDACachingAllocator::recordStream( tensors[i].indices().storage().data(), streams[i]); c10::cuda::CUDACachingAllocator::recordStream( tensors[i].values().storage().data(), streams[i]); } else { // We will need to coalesce first, which means new tensors will // be allocated on the streams we just allocated, and there // is no need to record them separately. 
} } else { c10::cuda::CUDACachingAllocator::recordStream( tensors[i].storage().data(), streams[i]); } } } // This function initializes a vector of CUDA streams, one per device, // and ensures that these streams are synchronized with the current default // streams. It is assumed that the tensors in the nested tensor vectors are // on the same device. void initializeStreamsEvents( std::vector>& tensors, std::vector& streams, std::vector& events) { // Ensure that the tensors in the nested tensor vectors are on the same // device. for (size_t i = 0; i < tensors.size(); i++) { auto device_id = tensors[i][0].device().index(); for (size_t j = 1; j < tensors[i].size(); j++) { if (tensors[i][j].device().index() != device_id) { throw std::runtime_error( "tensors in the nested tensor vectors need to " "be on the same device"); } } } at::cuda::OptionalCUDAGuard guard; streams.reserve(tensors.size()); events.resize(tensors.size()); for (size_t i = 0; i < tensors.size(); i++) { guard.set_index(tensors[i][0].device().index()); // Record event on current stream events[i].record(at::cuda::getCurrentCUDAStream()); // Get a non-default stream to execute asynchronous CUDA operations // on for this output. This ensures that the default stream used // by the caller is not occupied by c10d related operations. streams.push_back(at::cuda::getStreamFromPool( /* isHighPriority */ true, tensors[i][0].device().index())); // Ensure the new stream is synchronized with the current stream. events[i].block(streams[i]); for (at::Tensor& tensor : tensors[i]) { // `tensors` are created on a different stream. Hence, they must record // new streams in this Work to prevent being freed before the Work // finishes. c10::cuda::CUDACachingAllocator::recordStream( tensor.storage().data(), streams[i]); } } } #endif const auto kLoopbackAddress = "127.0.0.1"; } // namespace ProcessGroupGloo::SendWork::SendWork( at::Tensor& tensor, std::unique_ptr<::gloo::transport::UnboundBuffer> buffer) : tensor_(tensor), buffer_(std::move(buffer)) {} bool ProcessGroupGloo::SendWork::wait() { bool sendCompleted = false; std::unique_lock lock(mutex_); try { sendCompleted = buffer_->waitSend(); } catch (...) { exception_ = std::current_exception(); } completed_ = true; if (exception_) { std::rethrow_exception(exception_); } return sendCompleted; } void ProcessGroupGloo::SendWork::abort() { buffer_->abortWaitSend(); } ProcessGroupGloo::RecvWork::RecvWork( at::Tensor& tensor, std::unique_ptr<::gloo::transport::UnboundBuffer> buffer) : tensor_(tensor), buffer_(std::move(buffer)), srcRank_(-1) {} int ProcessGroupGloo::RecvWork::sourceRank() const { std::lock_guard lock(mutex_); return srcRank_; } bool ProcessGroupGloo::RecvWork::wait() { bool recvCompleted = false; std::unique_lock lock(mutex_); try { recvCompleted = buffer_->waitRecv(&srcRank_); } catch (...) { exception_ = std::current_exception(); } completed_ = true; if (exception_) { std::rethrow_exception(exception_); } return recvCompleted; } void ProcessGroupGloo::RecvWork::abort() { buffer_->abortWaitRecv(); } ProcessGroupGloo::Options::Options() : timeout(std::chrono::milliseconds(10 * 1000)), threads(2) {} namespace { // Gloo assumes that this machine's hostname can always be resolved // to an address. If it doesn't it throws a runtime error saying // that it can't be resolved. Instead of catching it, we choose // to proactively check if an address can be resolved, so we can // gracefully fall back to an alternative if it doesn't. 
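//
// Illustrative sketch (not code from this file) of the resulting device
// selection order a caller can rely on. Only createDeviceForInterface,
// createDeviceForHostname and createDefaultDevice are defined below; the
// helper name pickGlooDevice and the direct environment lookup are
// assumptions made for the example.
//
//   std::shared_ptr<::gloo::transport::Device> pickGlooDevice() {
//     if (const char* iface = std::getenv("GLOO_SOCKET_IFNAME")) {
//       // Bind to an explicitly configured interface, e.g. "eth0".
//       return ProcessGroupGloo::createDeviceForInterface(iface);
//     }
//     // Try this machine's hostname first; fall back to the loopback
//     // address (kLoopbackAddress) if it does not resolve to a usable
//     // local address.
//     return ProcessGroupGloo::createDefaultDevice();
//   }
//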
bool doesHostnameResolveToUsableAddress(const std::string& hostname) {
  struct addrinfo hints;
  memset(&hints, 0, sizeof(hints));
  hints.ai_family = AF_UNSPEC;
  hints.ai_socktype = SOCK_STREAM;
  struct addrinfo* result;
  auto rv = getaddrinfo(hostname.c_str(), nullptr, &hints, &result);
  if (rv != 0) {
    // getaddrinfo returns a nonzero (not necessarily negative) code on error.
    return false;
  }
  struct addrinfo* rp;
  for (rp = result; rp != nullptr; rp = rp->ai_next) {
    auto fd = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
    if (fd == -1) {
      continue;
    }
    rv = bind(fd, rp->ai_addr, rp->ai_addrlen);
    close(fd);
    if (rv == -1) {
      continue;
    }
    break;
  }
  freeaddrinfo(result);
  return rp != nullptr;
}

} // namespace

#if defined(__linux__) || defined(__APPLE__)
std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo::
    createDeviceForInterface(const std::string& interface) {
  return ::c10d::GlooDeviceFactory::makeDeviceForInterface(interface);
}
#endif

#if defined(__linux__) || defined(__APPLE__)
std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo::
    createDeviceForHostname(const std::string& hostname) {
  TORCH_CHECK(
      doesHostnameResolveToUsableAddress(hostname),
      "Cannot resolve ",
      hostname,
      " to a (local) address");
  return ::c10d::GlooDeviceFactory::makeDeviceForHostname(hostname);
}
#endif

#ifdef __linux__
std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo::
    createDefaultDevice() {
  // Use the hostname to resolve the network address to
  // use. Note: if the hostname does not resolve to an address (e.g.
  // because of misconfigured /etc/hosts file), this will not work.
  std::array<char, HOST_NAME_MAX> hostname{};
  auto rv = gethostname(hostname.data(), HOST_NAME_MAX);
  if (rv != 0) {
    throw std::system_error(errno, std::system_category());
  }

  // Use this machine's hostname if it resolves to an address.
  if (doesHostnameResolveToUsableAddress(hostname.data())) {
    return ::c10d::GlooDeviceFactory::makeDeviceForHostname(hostname.data());
  }

  // Otherwise, use the loopback address.
  TORCH_WARN_ONCE(
      "Unable to resolve hostname to a (local) address. ",
      "Using the loopback address as fallback. ",
      "Manually set the network interface to bind to with GLOO_SOCKET_IFNAME.");
  return createDeviceForHostname(kLoopbackAddress);
}
#endif

#ifdef __APPLE__
std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo::
    createDefaultDevice() {
  // Use the hostname to resolve the network address to
  // use. Note: if the hostname does not resolve to an address (e.g.
  // because of misconfigured /etc/hosts file), this will not work.
  const auto hostNameMax = sysconf(_SC_HOST_NAME_MAX);
  auto hostname = std::unique_ptr<char[]>(new char[hostNameMax]);
  auto rv = gethostname(hostname.get(), hostNameMax);
  if (rv != 0) {
    throw std::system_error(errno, std::system_category());
  }

  // Use this machine's hostname if it resolves to an address.
  if (doesHostnameResolveToUsableAddress(hostname.get())) {
    return ::c10d::GlooDeviceFactory::makeDeviceForHostname(hostname.get());
  }

  // Otherwise, use the loopback address.
  TORCH_WARN_ONCE(
      "Unable to resolve hostname to a (local) address. ",
      "Using the loopback address as fallback. ",
      "Manually set the network interface to bind to with GLOO_SOCKET_IFNAME.");
  return createDeviceForHostname(kLoopbackAddress);
}
#endif

ProcessGroupGloo::ProcessGroupGloo(
    const std::shared_ptr<Store>& store,
    int rank,
    int size,
    Options options)
    : ProcessGroup(rank, size),
      store_(new GlooStore(store)),
      stop_(false),
      collectiveCounter_(0) {
  auto& devices = options.devices;
  if (devices.empty()) {
    throw std::runtime_error("No device(s) specified");
  }

  // Create and connect a context for every device.
// // Note that the same device can be specified multiple times, either // the same object, or the same logical device as different objects. // Either mode is fine and only has performance implications. // // Using the same object multiple times means all contexts share a // single I/O thread. If you use different objects for the same // logical device they will have independent I/O threads. The latter // option is needed if you have a fast NIC that cannot be saturated // by a single I/O thread. // contexts_.reserve(options.devices.size()); for (size_t i = 0; i < options.devices.size(); i++) { auto context = std::make_shared<::gloo::rendezvous::Context>(rank_, size_); auto store = ::gloo::rendezvous::PrefixStore(std::to_string(i), *store_); context->setTimeout(options.timeout); context->connectFullMesh(store, options.devices[i]); contexts_.push_back(std::move(context)); } // Every worker thread stores the AsyncWork object it's currently // working on in the workInProgress_ vector. It must have size equal // to the number of workers such that they can simply index into it // using the worker index they are started with. workInProgress_.resize(options.threads); threads_.resize(options.threads); for (size_t i = 0; i < threads_.size(); i++) { threads_[i] = std::thread(&ProcessGroupGloo::runLoop, this, i); } } ProcessGroupGloo::~ProcessGroupGloo() { std::unique_lock lock(workMutex_); workConsumeCV_.wait(lock, [&] { return workQueue_.empty(); }); // Queue is empty, signal stop stop_ = true; // Release lock to allow threads to terminate lock.unlock(); workProduceCV_.notify_all(); // Wait for worker threads to terminate for (auto& thread : threads_) { thread.join(); } } uint32_t ProcessGroupGloo::nextTag() { return collectiveCounter_++; } std::shared_ptr<::gloo::Context> ProcessGroupGloo::getContext(uint32_t tag) { return contexts_[tag % contexts_.size()]; } void ProcessGroupGloo::runLoop(int workerIndex) { std::unique_lock lock(workMutex_); while (!stop_) { if (workQueue_.empty()) { workProduceCV_.wait(lock); continue; } auto work = std::move(workQueue_.front()); workQueue_.pop_front(); workInProgress_[workerIndex] = work; lock.unlock(); // Notify after releasing the lock so that the waiter // does not immediately block. workConsumeCV_.notify_one(); AsyncWork::execute(std::move(work)); lock.lock(); workInProgress_[workerIndex] = nullptr; } } void ProcessGroupGloo::enqueue(std::shared_ptr work) { std::unique_lock lock(workMutex_); workQueue_.push_back(std::move(work)); lock.unlock(); // Notify after releasing the lock so that the waiter // does not immediately block. 
workProduceCV_.notify_one(); } namespace { class AsyncBroadcastWork : public ProcessGroupGloo::AsyncWork { public: AsyncBroadcastWork( const std::shared_ptr& context, std::vector& inputs, int rootRank, int rootTensor, uint32_t tag) : context(context), inputs(inputs), rootRank(rootRank), rootTensor(rootTensor), tag(tag) {} std::shared_ptr context; std::vector inputs; const int rootRank; const int rootTensor; const uint32_t tag; void broadcast(at::Tensor& tensor) { const auto& scalarType = tensor.scalar_type(); gloo::BroadcastOptions opts(context); opts.setRoot(rootRank); opts.setTag(tag); GENERATE_ALL_TYPES(scalarType, setOutput, opts, tensor); gloo::broadcast(opts); } void run() override { broadcast(inputs[rootTensor]); // Copy to non-root tensors for (size_t i = 0; i < inputs.size(); i++) { if (i == static_cast(rootTensor)) { continue; } inputs[i].copy_(inputs[rootTensor]); } } }; #ifdef USE_CUDA class AsyncBroadcastCUDAWork : public AsyncBroadcastWork { public: AsyncBroadcastCUDAWork( const std::shared_ptr& context, std::vector& inputs, int rootRank, int rootTensor, uint32_t tag) : AsyncBroadcastWork(context, inputs, rootRank, rootTensor, tag) { initializeStreamsEvents(inputs, streams, events); // Create pinned host side tensors. tmp = pinnedLike(inputs[rootTensor]); at::cuda::OptionalCUDAStreamGuard guard; if (context->rank == rootRank) { guard.reset_stream(streams[rootTensor]); tmp.copy_(inputs[rootTensor], /* non_blocking */ true); } } void run() override { at::cuda::OptionalCUDAStreamGuard guard; // Synchronize with copy operation if applicable. if (context->rank == rootRank) { guard.reset_stream(streams[rootTensor]); AT_CUDA_CHECK(cudaStreamSynchronize(streams[rootTensor])); } // Run broadcast on host side tensors. broadcast(tmp); // Kick off copy back to the CUDA tensors. for (size_t i = 0; i < inputs.size(); i++) { guard.reset_stream(streams[i]); inputs[i].copy_(tmp, /* non_blocking */ true); events[i].record(streams[i]); } } void synchronize() override { at::cuda::OptionalCUDAGuard guard; // Synchronize with the copy back to CUDA tensors. 
for (size_t i = 0; i < inputs.size(); i++) { guard.set_index(inputs[i].device().index()); events[i].block(at::cuda::getCurrentCUDAStream()); } } at::Tensor tmp; std::vector streams; std::vector events; }; #endif } // namespace std::shared_ptr ProcessGroupGloo::broadcast( std::vector& inputs, const BroadcastOptions& opts) { static auto invalidArgument = [](const std::string& msg) { throw std::invalid_argument("ProcessGroupGloo::broadcast: " + msg); }; assertRootRank(invalidArgument, opts.rootRank, size_); assertRootTensor(invalidArgument, opts.rootTensor, inputs.size()); assertDense(invalidArgument, inputs); assertTypeAndSizesMatch(invalidArgument, inputs); const auto& device = inputs[0].device(); switch (device.type()) { case at::kCPU: #ifdef USE_CUDA case at::kCUDA: #endif break; default: invalidArgument("unsupported device type"); } std::shared_ptr work; auto tag = nextTag(); auto context = getContext(tag); if (device.type() == at::kCPU) { work = std::make_shared( std::move(context), inputs, opts.rootRank, opts.rootTensor, tag); #ifdef USE_CUDA } else if (device.type() == at::kCUDA) { work = std::make_shared( std::move(context), inputs, opts.rootRank, opts.rootTensor, tag); #endif } else { throw std::runtime_error("Invalid backend"); } enqueue(work); return work; } namespace { class AsyncAllreduceWork : public ProcessGroupGloo::AsyncWork { public: AsyncAllreduceWork( const std::shared_ptr& context, std::vector& inputs, ReduceOp reduceOp, uint32_t tag) : context(context), inputs(inputs), reduceOp(reduceOp), tag(tag) {} std::shared_ptr context; std::vector inputs; const ReduceOp reduceOp; const uint32_t tag; void allreduce(std::vector& tensors) { const auto& scalarType = tensors[0].scalar_type(); gloo::AllreduceOptions opts(context); opts.setReduceFunction(getFunction(scalarType, reduceOp)); opts.setTag(tag); GENERATE_ALL_TYPES(scalarType, setOutputs, opts, tensors); gloo::allreduce(opts); } void run() override { allreduce(inputs); // Only the first output in the tensor list contains the results. // See https://github.com/facebookincubator/gloo/issues/152. // The contents is the same for every entry in the tensor list, so // we can use the first entry as the source of the copy below. for (size_t i = 1; i < inputs.size(); i++) { inputs[i].copy_(inputs[0]); } } template void getFunction(gloo::AllreduceOptions::Func& fn, const ReduceOp op) { fn = toFunction(op); } gloo::AllreduceOptions::Func getFunction( const at::ScalarType& dtype, const ReduceOp op) { gloo::AllreduceOptions::Func fn; GENERATE_ALL_TYPES(dtype, getFunction, fn, op); return fn; } }; class AsyncAllreduceCoalescedWork : public AsyncAllreduceWork { public: AsyncAllreduceCoalescedWork( const std::shared_ptr& context, std::vector& inputs, ReduceOp reduceOp, uint32_t tag) : AsyncAllreduceWork(context, inputs, reduceOp, tag) {} void run() override { allreduceCoalesced(inputs); } private: void allreduceCoalesced(std::vector& tensors) { // reduce coalesced, flattened tensors. at::Tensor coalescedTensor = flattenDenseTensors(tensors); std::vector allreduceInput = {coalescedTensor}; allreduce(allreduceInput); // separate and reshape tensors. 
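    //
    // A small sketch of the round trip performed below (shapes chosen only
    // for illustration): flattening tensors of shapes {2, 3} and {4} yields
    // a 10-element buffer; after the allreduce, slice [0, 6) is viewed back
    // as {2, 3} and slice [6, 10) as {4}.
    //
    //   at::Tensor a = at::ones({2, 3});
    //   at::Tensor b = at::ones({4});
    //   at::Tensor flat = at::cat({a.view(-1), b.view(-1)}); // numel() == 10
    //   at::Tensor a2 = flat.slice(0, 0, 6).view({2, 3});
    //   at::Tensor b2 = flat.slice(0, 6, 10).view({4});
    //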
size_t offset = 0; for (at::Tensor& tensor : tensors) { const int64_t tensorNumel = tensor.numel(); const c10::IntArrayRef tensorShape = tensor.sizes(); tensor.copy_(coalescedTensor.slice(0, offset, offset + tensorNumel) .view(tensorShape)); offset += tensorNumel; } } }; class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork { public: AsyncSparseAllreduceWork( const std::shared_ptr& context, std::vector& inputs, uint32_t tag) : context(context), inputs(inputs), tag(tag) {} std::shared_ptr context; std::vector inputs; std::vector outputs; const uint32_t tag; // We share dimensionality about the sparse tensors before collecting // their contents. We assume here that the maximum number of sparse // and dense dimensions is 4. This is stored in a contiguous piece of // memory so that we can easily run allgather on it. // // The layout of this memory is as follows: // // - [0:4]: sparse dims // - [4:8]: dense dims // - [8]: nnz // class SparseTensorMetadata { public: static constexpr auto dim = 9; // Construct from an existing metadata tensor to facilitate structured // access to metadata from peers, after gathering it. explicit SparseTensorMetadata(at::Tensor metadata) : metadata_(metadata), data_(metadata_.data_ptr()) { AT_ASSERT(metadata.scalar_type() == at::kLong); AT_ASSERT(metadata.dim() == 1); AT_ASSERT(metadata.size(0) == dim); } // Populate the metadata. void populate_from_sparse_tensor(const at::Tensor& tensor) { const auto sparse_dim = tensor.sparse_dim(); AT_ASSERT(sparse_dim <= 4); for (auto i = 0; i < 4; i++) { if (i < sparse_dim) { data_[i] = tensor.size(i); } } const auto dense_dim = tensor.dense_dim(); AT_ASSERT(dense_dim <= 4); for (auto i = 0; i < 4; i++) { if (i < dense_dim) { data_[i + 4] = tensor.size(sparse_dim + i); } } data_[8] = tensor._nnz(); } std::vector sizes() const { std::vector sizes; // Sparse sizes for (auto i = 0; i < 4; i++) { if (data_[i] <= 0) { break; } sizes.push_back(data_[i]); } // Dense sizes for (auto i = 4; i < 8; i++) { if (data_[i] <= 0) { break; } sizes.push_back(data_[i]); } return sizes; } int64_t nnz() const { return data_[8]; } protected: at::Tensor metadata_; int64_t* data_; }; // Sparse allreduce is implemented with allgather on indices and values. // Every process then sums the resulting sparse tensors locally. // The nnz for sparse tensors may be different across processes, so first // we run allgather on the nnz, and then allgather with max(nnz). // We could use an allgatherv for this, if it were available. at::Tensor allreduce(std::vector& tensors) { // TODO: This is a massive hack! There is some confusion about // Variable/Tensor inside the body of this function. Turning off // grad smooths over the confusion for now. This fixes // test/test_c10d.py ProcessGroupGlooTest.test_sparse_allreduce_basics // // The correct fix is to stop allocating tensors that are not variables, // but to conveniently do this c10d must depend on torch not ATen at::AutoNonVariableTypeMode _no_grad(true); auto input = tensors[0]; // Perform local reduction if we have multiple inputs. for (size_t i = 1; i < tensors.size(); i++) { input += tensors[i]; } // Need to coalesce before we can access indices and values. input = input.coalesce(); // Gather metadata information from all ranks. auto metadata = allgather_metadata(input); // Sanity check dimensionality across ranks. 
{ const auto expected = metadata[context->rank].sizes(); for (auto i = 0; i < context->size; i++) { if (i == context->rank) { continue; } const auto actual = metadata[i].sizes(); AT_CHECK(actual == expected, "Sparse dimensions do not match"); } } // Gather all indices and all values. auto indices = allgather_indices(input, metadata); auto values = allgather_values(input, metadata); // Perform global reduction. AT_ASSERT(static_cast(indices.size()) == context->size); AT_ASSERT(static_cast(values.size()) == context->size); auto output = at::sparse_coo_tensor( indices[0], values[0], input.sizes(), input.options()); for (auto i = 1; i < context->size; i++) { output += at::sparse_coo_tensor( indices[i], values[i], input.sizes(), input.options()); } // Coalesce for good measure. return output.coalesce(); } void run() override { auto output = allreduce(inputs); // Copy back to input tensors. outputs.reserve(inputs.size()); for (size_t i = 0; i < inputs.size(); i++) { if (output.is_sparse()) { outputs.push_back(output.clone()); } else { outputs.push_back(output.clone(at::MemoryFormat::Contiguous)); } } } std::vector result() const override { return outputs; } private: std::vector allgather_metadata( const at::Tensor& tensor) { auto buffer = at::zeros({context->size, SparseTensorMetadata::dim}, at::kLong); // Prepare metadata vector (1 entry per rank) std::vector metadata; metadata.reserve(context->size); for (auto i = 0; i < context->size; i++) { metadata.emplace_back(buffer.select(0, i)); } // Populate data for this rank metadata[context->rank].populate_from_sparse_tensor(tensor); // Allgather metadata gloo::AllgatherOptions opts(context); opts.setOutput(buffer.data_ptr(), buffer.numel()); opts.setTag(tag); gloo::allgather(opts); return metadata; } std::vector allgather_indices( const at::Tensor& tensor, const std::vector& metadata) { const auto sparseDim = tensor.sparse_dim(); std::vector counts(context->size); int64_t totalSize = 0; for (size_t i = 0; i < metadata.size(); i++) { counts[i] = metadata[i].nnz() * sparseDim; totalSize += counts[i]; } auto output = at::empty({totalSize}, at::kLong); // tensors copied from cuda may not be contiguous, get a contiguous // tensor before use its data_ptr auto input = tensor.indices().contiguous(); // Allgatherv indices. gloo::AllgathervOptions opts(context); opts.setInput(input.data_ptr(), input.numel()); opts.setOutput(output.data_ptr(), counts); opts.setTag(tag); gloo::allgatherv(opts); // Compile indices tensor per rank. std::vector indices; indices.reserve(metadata.size()); size_t offset = 0; for (size_t i = 0; i < metadata.size(); i++) { const auto nnz = metadata[i].nnz(); const auto numel = sparseDim * nnz; indices.push_back( output.narrow(0, offset, numel).reshape({sparseDim, nnz})); offset += numel; } return indices; } std::vector allgather_values( const at::Tensor& tensor, const std::vector& metadata) { // There are nnz #dense_dim()-dimensional tensors per rank. const auto valueShape = tensor.sizes().slice(tensor.sparse_dim()); size_t denseNumel = 1; for (auto dim : valueShape) { denseNumel *= dim; } std::vector counts(context->size); int64_t totalSize = 0; for (size_t i = 0; i < metadata.size(); i++) { counts[i] = metadata[i].nnz() * denseNumel; totalSize += counts[i]; } auto output = at::empty({totalSize}, tensor.scalar_type()); // Allgatherv indices. 
gloo::AllgathervOptions opts(context); // tensors copied from cuda may not be contiguous, get a contiguous // tensor before use its data_ptr at::Tensor valueTensor = tensor.values().contiguous(); GENERATE_ALL_TYPES(valueTensor.scalar_type(), setInput, opts, valueTensor); GENERATE_ALL_TYPES( valueTensor.scalar_type(), setOutput, opts, output, counts); opts.setTag(tag); gloo::allgatherv(opts); // Compile values tensor per rank. std::vector values; values.reserve(metadata.size()); size_t offset = 0; for (size_t i = 0; i < metadata.size(); i++) { const auto nnz = metadata[i].nnz(); const auto numel = denseNumel * nnz; auto tensorShape = std::vector({(int64_t)nnz}); std::copy( valueShape.begin(), valueShape.end(), std::back_inserter(tensorShape)); values.push_back(output.narrow(0, offset, numel).reshape(tensorShape)); offset += numel; } return values; } }; #ifdef USE_CUDA class AsyncAllreduceCUDAWork : public AsyncAllreduceWork { public: AsyncAllreduceCUDAWork( const std::shared_ptr& context, std::vector& inputs, ReduceOp reduceOp, uint32_t tag) : AsyncAllreduceWork(context, inputs, reduceOp, tag) { initializeStreamsEvents(inputs, streams, events); // Kick off copy from CUDA tensors to pinned CPU tensors. tmp.reserve(inputs.size()); at::cuda::OptionalCUDAStreamGuard guard; for (size_t i = 0; i < inputs.size(); i++) { guard.reset_stream(streams[i]); tmp.push_back(pinnedLike(inputs[i]).copy_(inputs[i], true)); } } void run() override { // Synchronize with copy operations. at::cuda::OptionalCUDAGuard device_guard; for (size_t i = 0; i < inputs.size(); i++) { device_guard.set_index(inputs[i].device().index()); AT_CUDA_CHECK(cudaStreamSynchronize(streams[i])); } // Run allreduce on host side tensors. allreduce(tmp); // Kick off copy back to the CUDA tensors. // Only the first output in the tensor list contains the results. // See https://github.com/facebookincubator/gloo/issues/152. // The contents is the same for every entry in the tensor list, so // we can use the first entry as the source of the copy below. at::cuda::OptionalCUDAStreamGuard stream_guard; for (size_t i = 0; i < inputs.size(); i++) { stream_guard.reset_stream(streams[i]); inputs[i].copy_(tmp[0], /* non_blocking */ true); events[i].record(streams[i]); } } void synchronize() override { // Synchronize with the copy back to CUDA tensors. at::cuda::OptionalCUDAGuard guard; for (size_t i = 0; i < inputs.size(); i++) { guard.set_index(inputs[i].device().index()); events[i].block(at::cuda::getCurrentCUDAStream()); } } std::vector tmp; std::vector streams; std::vector events; }; class AsyncSparseAllreduceCUDAWork : public AsyncSparseAllreduceWork { public: AsyncSparseAllreduceCUDAWork( const std::shared_ptr& context, std::vector& inputs, uint32_t tag) : AsyncSparseAllreduceWork(context, inputs, tag) { initializeStreamsEvents(inputs, streams, events); // Kick off copy from CUDA tensors to CPU tensors. // Note that both coalescing the sparse tensor and copying it to CPU // memory must be performed asynchronously, or we block the caller. tmp.reserve(inputs.size()); at::cuda::OptionalCUDAStreamGuard guard; for (size_t i = 0; i < inputs.size(); i++) { guard.reset_stream(streams[i]); tmp.push_back( inputs[i].coalesce().to(at::DeviceType::CPU, /*non_blocking=*/true)); } } void run() override { // Synchronize with copy operations. 
at::cuda::OptionalCUDAGuard device_guard; for (size_t i = 0; i < inputs.size(); i++) { device_guard.set_index(inputs[i].device().index()); AT_CUDA_CHECK(cudaStreamSynchronize(streams[i])); } // Run allreduce on host side tensors. auto output = allreduce(tmp); // Kick off copy back to the CUDA tensors. at::cuda::OptionalCUDAStreamGuard stream_guard; for (size_t i = 0; i < inputs.size(); i++) { stream_guard.reset_stream(streams[i]); outputs.push_back(output.to(inputs[i].device(), /*non_blocking=*/true)); events[i].record(streams[i]); } } void synchronize() override { // Synchronize with the copy back to CUDA tensors. at::cuda::OptionalCUDAGuard guard; for (size_t i = 0; i < inputs.size(); i++) { guard.set_index(inputs[i].device().index()); events[i].block(at::cuda::getCurrentCUDAStream()); } } std::vector tmp; std::vector streams; std::vector events; }; #endif } // namespace std::shared_ptr ProcessGroupGloo::allreduce( std::vector& inputs, const AllreduceOptions& opts) { static auto invalidArgument = [](const std::string& msg) { throw std::invalid_argument("ProcessGroupGloo::allreduce: " + msg); }; assertNonEmpty(invalidArgument, inputs); assertLayoutMatch(invalidArgument, inputs); assertTypeAndSizesMatch(invalidArgument, inputs); const auto& device = inputs[0].device(); switch (device.type()) { case at::kCPU: #ifdef USE_CUDA case at::kCUDA: #endif break; default: invalidArgument("unsupported device type"); } const auto& layout = inputs[0].layout(); if (layout == c10::kSparse && opts.reduceOp != ReduceOp::SUM) { invalidArgument( "unsupported reduction operation " "(allreduce of sparse tensors only works with ReduceOp.SUM)"); } std::shared_ptr work; auto tag = nextTag(); auto context = getContext(tag); if (device.type() == at::kCPU) { if (layout == c10::kStrided) { work = std::make_shared( std::move(context), inputs, opts.reduceOp, tag); } else if (layout == c10::kSparse) { work = std::make_shared( std::move(context), inputs, tag); } else { invalidArgument("unsupported layout"); } #ifdef USE_CUDA } else if (device.type() == at::kCUDA) { if (layout == c10::kStrided) { work = std::make_shared( std::move(context), inputs, opts.reduceOp, tag); } else if (layout == c10::kSparse) { work = std::make_shared( std::move(context), inputs, tag); } else { invalidArgument("unsupported layout"); } #endif } else { throw std::runtime_error("Invalid backend"); } enqueue(work); return work; } std::shared_ptr ProcessGroupGloo::allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts) { static auto invalidArgument = [](const std::string& msg) { throw std::invalid_argument( "ProcessGroupGloo::allreduce_coalesced: " + msg); }; assertNonEmpty(invalidArgument, tensors); // tensors will be flattened and concatenated (coalesced). This means that // input // tensors must have the same device, layout and type. assertLayoutMatch(invalidArgument, tensors); if (!std::all_of(tensors.begin(), tensors.end(), [&](at::Tensor& t) { return t.options().type_equal(tensors[0].options()); })) { invalidArgument("tensors must all have the same type"); } if (!std::all_of(tensors.begin(), tensors.end(), [&](at::Tensor& t) { return t.device() == tensors[0].device(); })) { invalidArgument("tensors must all be on the same device"); } const c10::Device& device = tensors[0].device(); const c10::Layout& layout = tensors[0].layout(); // invalid arguments are detected early here before any calls to nextTag() // which result in the collectiveCounter_ being incremented. 
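  //
  // Usage sketch (illustrative only; `pg` stands for an already constructed
  // ProcessGroupGloo): coalesced allreduce sums a list of CPU tensors of
  // different shapes through a single flattened collective.
  //
  //   std::vector<at::Tensor> grads = {at::ones({2, 2}), at::ones({5})};
  //   auto work = pg.allreduce_coalesced(grads, AllreduceCoalescedOptions());
  //   work->wait(); // each grads[i] now holds the element-wise sum
  //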
switch (device.type()) { case c10::kCPU: break; default: invalidArgument("unsupported device type"); } switch (layout) { case c10::kStrided: break; default: invalidArgument("unsupported layout"); } std::shared_ptr work; const uint32_t tag = nextTag(); std::shared_ptr context = getContext(tag); if (device.type() == c10::kCPU) { if (layout == c10::kStrided) { work = std::make_shared( std::move(context), tensors, opts.reduceOp, tag); } else { invalidArgument("unsupported layout"); } } else { throw std::runtime_error("Invalid backend"); } enqueue(work); return work; } namespace { class AsyncReduceWork : public ProcessGroupGloo::AsyncWork { public: AsyncReduceWork( const std::shared_ptr& context, std::vector& inputs, int rootRank, int rootTensor, ReduceOp reduceOp, uint32_t tag) : context(context), inputs(inputs), rootRank(rootRank), rootTensor(rootTensor), reduceOp(reduceOp), tag(tag) {} std::shared_ptr context; std::vector inputs; const int rootRank; const int rootTensor; const ReduceOp reduceOp; const uint32_t tag; void reduce(std::vector& tensors) { const auto& scalarType = tensors[0].scalar_type(); gloo::ReduceOptions opts(context); opts.setRoot(rootRank); opts.setTag(tag); opts.setReduceFunction(getFunction(scalarType, reduceOp)); GENERATE_ALL_TYPES(scalarType, setOutput, opts, tensors[0]); gloo::reduce(opts); } void run() override { reduce(inputs); } protected: template void getFunction(gloo::ReduceOptions::Func& fn, const ReduceOp op) { fn = toFunction(op); } gloo::ReduceOptions::Func getFunction( const at::ScalarType& dtype, const ReduceOp op) { gloo::ReduceOptions::Func fn; GENERATE_ALL_TYPES(dtype, getFunction, fn, op); return fn; } }; #ifdef USE_CUDA class AsyncReduceCUDAWork : public AsyncReduceWork { public: AsyncReduceCUDAWork( const std::shared_ptr& context, std::vector& inputs, int rootRank, int rootTensor, ReduceOp reduceOp, uint32_t tag) : AsyncReduceWork(context, inputs, rootRank, rootTensor, reduceOp, tag) { initializeStreamsEvents(inputs, streams, events); // Kick off copy from CUDA tensors to pinned CPU tensors. tmp.reserve(inputs.size()); at::cuda::OptionalCUDAStreamGuard guard; for (size_t i = 0; i < inputs.size(); i++) { guard.reset_stream(streams[i]); tmp.push_back(pinnedLike(inputs[i]).copy_(inputs[i], true)); } } void run() override { // Synchronize with copy operations. at::cuda::OptionalCUDAGuard device_guard; for (size_t i = 0; i < inputs.size(); i++) { device_guard.set_index(inputs[i].device().index()); AT_CUDA_CHECK(cudaStreamSynchronize(streams[i])); } // Run reduce on host side tensors. reduce(tmp); // Kick off copy back to the CUDA tensors. at::cuda::OptionalCUDAStreamGuard stream_guard; for (size_t i = 0; i < inputs.size(); i++) { stream_guard.reset_stream(streams[i]); inputs[i].copy_(tmp[i], /* non_blocking */ true); events[i].record(streams[i]); } } void synchronize() override { // Synchronize with the copy back to CUDA tensors. 
at::cuda::OptionalCUDAGuard guard; for (size_t i = 0; i < inputs.size(); i++) { guard.set_index(inputs[i].device().index()); events[i].block(at::cuda::getCurrentCUDAStream()); } } std::vector tmp; std::vector streams; std::vector events; }; #endif } // namespace std::shared_ptr ProcessGroupGloo::reduce( std::vector& inputs, const ReduceOptions& opts) { static auto invalidArgument = [](const std::string& msg) { throw std::invalid_argument("ProcessGroupGloo::reduce: " + msg); }; assertRootRank(invalidArgument, opts.rootRank, size_); assertRootTensor(invalidArgument, opts.rootTensor, inputs.size()); assertSingleElement(invalidArgument, inputs); assertDense(invalidArgument, inputs); const auto& device = inputs[0].device(); switch (device.type()) { case at::kCPU: #ifdef USE_CUDA case at::kCUDA: #endif break; default: invalidArgument("unsupported device type"); } std::shared_ptr work; auto tag = nextTag(); auto context = getContext(tag); if (device.type() == at::kCPU) { work = std::make_shared( std::move(context), inputs, opts.rootRank, opts.rootTensor, opts.reduceOp, tag); #ifdef USE_CUDA } else if (device.type() == at::kCUDA) { work = std::make_shared( std::move(context), inputs, opts.rootRank, opts.rootTensor, opts.reduceOp, tag); #endif } else { throw std::runtime_error("Invalid backend"); } enqueue(work); return work; } namespace { class AsyncAllgatherWork : public ProcessGroupGloo::AsyncWork { public: AsyncAllgatherWork( const std::shared_ptr& context, std::vector>& outputs, std::vector& inputs, uint32_t tag) : context(context), outputs(outputs), inputs(inputs), tag(tag) {} std::shared_ptr context; std::vector> outputs; std::vector inputs; const uint32_t tag; void allgather( std::vector>& outputs, std::vector& inputs) { const auto& scalarType = inputs[0].scalar_type(); gloo::AllgatherOptions opts(context); opts.setTag(tag); // Use single flattened input tensor. at::Tensor flatInputTensor = flattenDenseTensors(inputs); GENERATE_ALL_TYPES(scalarType, setInput, opts, flatInputTensor); // Use single flat output tensor. // The first dimension corresponds to the index into outputs[N], // so copying into the actual output later is easy. at::Tensor flatOutputTensor = newLikeFlat(outputs[0]); GENERATE_ALL_TYPES(scalarType, setOutput, opts, flatOutputTensor); gloo::allgather(opts); // Unflatten into output tensors. for (size_t i = 0; i < outputs.size(); i++) { for (size_t j = 0; j < outputs[i].size(); j++) { outputs[i][j].copy_(flatOutputTensor[j]); } } } void run() override { allgather(outputs, inputs); } }; #ifdef USE_CUDA // Note: current CUDA implementation holds the assumption that the // tensors in the nested output tensor vectors are on the same device. class AsyncAllgatherCUDAWork : public AsyncAllgatherWork { public: AsyncAllgatherCUDAWork( const std::shared_ptr& context, std::vector>& outputs, std::vector& inputs, uint32_t tag) : AsyncAllgatherWork(context, outputs, inputs, tag) { initializeStreamsEvents(inputs, inputStreams, inputEvents); initializeStreamsEvents(outputs, outputStreams, outputEvents); // Kick off copy from CUDA tensors to pinned CPU tensors. 
tmpInputs.reserve(inputs.size()); at::cuda::OptionalCUDAStreamGuard guard; for (size_t i = 0; i < inputs.size(); i++) { guard.reset_stream(inputStreams[i]); tmpInputs.push_back(pinnedLike(inputs[i]).copy_(inputs[i], true)); } tmpOutputs.resize(outputs.size()); for (size_t i = 0; i < outputs.size(); i++) { tmpOutputs[i].reserve(outputs[i].size()); for (size_t j = 0; j < outputs[i].size(); j++) { tmpOutputs[i].push_back(pinnedLike(outputs[i][j])); } } } void run() override { // Synchronize with copy operations. at::cuda::OptionalCUDAGuard device_guard; for (size_t i = 0; i < inputs.size(); i++) { device_guard.set_index(inputs[i].device().index()); AT_CUDA_CHECK(cudaStreamSynchronize(inputStreams[i])); } for (size_t i = 0; i < outputs.size(); i++) { device_guard.set_index(outputs[i][0].device().index()); AT_CUDA_CHECK(cudaStreamSynchronize(outputStreams[i])); } // Run allgather on host side tensors. allgather(tmpOutputs, tmpInputs); // Kick off copy back to the CUDA tensors. at::cuda::OptionalCUDAStreamGuard stream_guard; for (size_t i = 0; i < outputs.size(); i++) { stream_guard.reset_stream(outputStreams[i]); for (size_t j = 0; j < outputs[i].size(); j++) { outputs[i][j].copy_(tmpOutputs[i][j], /* non_blocking */ true); } outputEvents[i].record(outputStreams[i]); } } void synchronize() override { // Synchronize with the copy back to CUDA tensors. at::cuda::OptionalCUDAGuard guard; for (size_t i = 0; i < outputs.size(); i++) { guard.set_index(outputs[i][0].device().index()); outputEvents[i].block(at::cuda::getCurrentCUDAStream()); } } std::vector tmpInputs; std::vector inputStreams; std::vector inputEvents; std::vector> tmpOutputs; std::vector outputStreams; std::vector outputEvents; }; #endif } // namespace // Note: current CUDA implementation holds the assumption that the // tensors in the nested output tensor vectors are on the same device. 
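//
// Usage sketch (illustrative only; `pg` is an already constructed
// ProcessGroupGloo and `worldSize` its size): for every input tensor,
// allgather expects an output list of worldSize tensors with matching type
// and shape.
//
//   std::vector<at::Tensor> inputs = {at::ones({4})};
//   std::vector<std::vector<at::Tensor>> outputs(1);
//   for (int r = 0; r < worldSize; r++) {
//     outputs[0].push_back(at::empty({4}));
//   }
//   auto work = pg.allgather(outputs, inputs, AllgatherOptions());
//   work->wait(); // outputs[0][r] holds rank r's input tensor
//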
std::shared_ptr ProcessGroupGloo::allgather( std::vector>& outputs, std::vector& inputs, const AllgatherOptions& opts) { static auto invalidArgument = [](const std::string& msg) { throw std::invalid_argument("ProcessGroupGloo::allgather: " + msg); }; if (inputs.size() == 0) { invalidArgument("requires non-empty input tensor list"); } if (inputs.size() != outputs.size()) { invalidArgument( "requires input/output tensor lists to have the same length"); } for (size_t i = 0; i < outputs.size(); i++) { const auto expected = inputs.size() * getSize(); const auto actual = outputs[i].size(); if (actual != expected) { invalidArgument( "invalid output tensor list at index " + std::to_string(i) + " (expected length " + std::to_string(expected) + ", got " + std::to_string(actual) + ")"); } } assertDense(invalidArgument, inputs); // Expect all input/output tensors to have the same type and sizes const auto& options = inputs[0].options(); const auto& sizes = inputs[0].sizes(); assertTypeAndSizesMatch(invalidArgument, inputs, options, sizes); for (size_t i = 0; i < outputs.size(); i++) { assertTypeAndSizesMatch(invalidArgument, outputs[i], options, sizes); } const auto& device = inputs[0].device(); switch (device.type()) { case at::kCPU: #ifdef USE_CUDA case at::kCUDA: #endif break; default: invalidArgument("unsupported device type"); } std::shared_ptr work; auto tag = nextTag(); auto context = getContext(tag); if (device.type() == at::kCPU) { work = std::make_shared( std::move(context), outputs, inputs, tag); #ifdef USE_CUDA } else if (device.type() == at::kCUDA) { work = std::make_shared( std::move(context), outputs, inputs, tag); #endif } else { throw std::runtime_error("Invalid backend"); } enqueue(work); return work; } namespace { class AsyncAllgatherCoalescedWork : public ProcessGroupGloo::AsyncWork { public: AsyncAllgatherCoalescedWork( const std::shared_ptr& context, std::vector>& output_lists, std::vector& input_list, uint32_t tag) : context(context), output_lists(output_lists), input_list(input_list), tag(tag) {} std::shared_ptr context; std::vector> output_lists; std::vector input_list; const uint32_t tag; void allgather_coalesced() { assert(!output_lists.empty()); assert(!output_lists[0].empty()); assert(!input_list.empty()); const auto& scalarType = input_list[0].scalar_type(); gloo::AllgatherOptions opts(context); opts.setTag(tag); // Use single flattened input tensor. at::Tensor flatInputTensor = flattenDenseTensors(input_list); GENERATE_ALL_TYPES(scalarType, setInput, opts, flatInputTensor); // Compute total number of elements we need to allocate for all tensors // requested. int64_t output_numel = 0; for (const auto& t : output_lists[0]) { output_numel += t.numel(); } output_numel *= output_lists.size(); // Use single flat output tensor. 
at::Tensor flatOutputTensor = at::empty({output_numel}, output_lists[0][0].options()); GENERATE_ALL_TYPES(scalarType, setOutput, opts, flatOutputTensor); gloo::allgather(opts); int64_t current_element = 0; for (auto& output_list : output_lists) { for (auto& output_tensor : output_list) { output_tensor.copy_( flatOutputTensor.narrow(0, current_element, output_tensor.numel()) .reshape(output_tensor.sizes()), true); current_element += output_tensor.numel(); } } } void run() override { allgather_coalesced(); } }; } // namespace std::shared_ptr ProcessGroupGloo::allgather_coalesced( std::vector>& output_lists, std::vector& input_list, const AllgatherOptions& /* unused */) { static auto invalidArgument = [](const std::string& msg) { throw std::invalid_argument( "ProcessGroupGloo::allgather_coalesced: " + msg); }; if (input_list.empty()) { invalidArgument("requires non-empty input tensor list"); } if (output_lists.size() != getSize()) { invalidArgument("output lists should be equal to world size"); } assertSameDevice(invalidArgument, input_list); // Expect i'th tensor of each list from 'output_lists' match i'th tensor // from 'input_list' in type and size. for (const auto& output_list : output_lists) { if (output_list.size() != input_list.size()) { invalidArgument( "invalid output size: (expected length " + std::to_string(input_list.size()) + ", got " + std::to_string(output_list.size()) + ")"); } for (int i = 0; i < output_list.size(); ++i) { const auto expected = input_list[i].sizes(); const auto actual = output_list[i].sizes(); if (actual != expected) { invalidArgument( "invalid size of output tensor at index " + std::to_string(i) + " (expected length " + toString(expected) + ", got " + toString(actual) + ")"); } if (!input_list[i].options().type_equal(output_list[i].options())) { invalidArgument( "invalid tensor type at index " + std::to_string(i) + " (expected " + input_list[i].toString() + ", got " + output_list[i].toString() + ")"); } } } assertDense(invalidArgument, input_list); auto tag = nextTag(); auto context = getContext(tag); auto work = std::make_shared( std::move(context), output_lists, input_list, tag); enqueue(work); return work; } namespace { class AsyncGatherWork : public ProcessGroupGloo::AsyncWork { public: AsyncGatherWork( const std::shared_ptr& context, std::vector>& outputs, std::vector& inputs, int root, uint32_t tag) : context(context), outputs(outputs), inputs(inputs), root(root), tag(tag) {} std::shared_ptr context; std::vector> outputs; std::vector inputs; const int root; const uint32_t tag; void gather( std::vector>& outputs, std::vector& inputs) { const auto scalarType = inputs[0].scalar_type(); gloo::GatherOptions opts(context); opts.setRoot(root); opts.setTag(tag); // Set single temporary tensor on root process. // This is later scattered to the separate output tensors. at::Tensor flatOutputTensor; if (context->rank == root) { flatOutputTensor = newLikeFlat(outputs[0]); GENERATE_ALL_TYPES(scalarType, setOutput, opts, flatOutputTensor); } // Set single input tensor on all processes. GENERATE_ALL_TYPES(scalarType, setInput, opts, inputs[0]); gloo::gather(opts); // Unflatten into output tensors on root process. 
if (context->rank == root) { for (size_t i = 0; i < outputs[0].size(); i++) { outputs[0][i].copy_(flatOutputTensor[i]); } } } void run() override { gather(outputs, inputs); } }; #ifdef USE_CUDA // Note: current CUDA implementation holds the assumptions: // - inputs.size() is 1 // - outputs.size() is 1 // - the size of the nested output tensors is world size, i.e., // outputs[0].size, is world size class AsyncGatherCUDAWork : public AsyncGatherWork { public: AsyncGatherCUDAWork( const std::shared_ptr& context, std::vector>& outputs, std::vector& inputs, int root, uint32_t tag) : AsyncGatherWork(context, outputs, inputs, root, tag) { initializeStreamsEvents(inputs, inputStreams, inputEvents); initializeStreamsEvents(outputs, outputStreams, outputEvents); // Kick off copy from CUDA tensors to pinned CPU tensors. tmpInputs.reserve(inputs.size()); at::cuda::OptionalCUDAStreamGuard guard; for (size_t i = 0; i < inputs.size(); i++) { guard.reset_stream(inputStreams[i]); tmpInputs.push_back(pinnedLike(inputs[i]).copy_(inputs[i], true)); } tmpOutputs.resize(outputs.size()); for (size_t i = 0; i < outputs.size(); i++) { tmpOutputs[i].reserve(outputs[i].size()); for (size_t j = 0; j < outputs[i].size(); j++) { tmpOutputs[i].push_back(pinnedLike(outputs[i][j])); } } } void run() override { // Synchronize with copy operations. at::cuda::OptionalCUDAGuard device_guard; for (size_t i = 0; i < inputs.size(); i++) { device_guard.set_index(inputs[i].get_device()); AT_CUDA_CHECK(cudaStreamSynchronize(inputStreams[i])); } for (size_t i = 0; i < outputs.size(); i++) { device_guard.set_index(outputs[i][0].get_device()); AT_CUDA_CHECK(cudaStreamSynchronize(outputStreams[i])); } // Run gather on host side tensors. gather(tmpOutputs, tmpInputs); // Kick off copy back to the CUDA tensors. at::cuda::OptionalCUDAStreamGuard stream_guard; for (size_t i = 0; i < outputs.size(); i++) { stream_guard.reset_stream(outputStreams[i]); for (size_t j = 0; j < outputs[i].size(); j++) { outputs[i][j].copy_(tmpOutputs[i][j], /* non_blocking */ true); } outputEvents[i].record(outputStreams[i]); } } void synchronize() override { // Synchronize with the copy back to CUDA tensors. at::cuda::OptionalCUDAGuard guard; for (size_t i = 0; i < outputs.size(); i++) { guard.set_index(static_cast(outputs[i][0].get_device())); outputEvents[i].block(at::cuda::getCurrentCUDAStream()); } } std::vector tmpInputs; std::vector inputStreams; std::vector inputEvents; std::vector> tmpOutputs; std::vector outputStreams; std::vector outputEvents; }; #endif } // namespace std::shared_ptr ProcessGroupGloo::gather( std::vector>& outputs, std::vector& inputs, const GatherOptions& opts) { static auto invalidArgument = [](const std::string& msg) { throw std::invalid_argument("ProcessGroupGloo::gather: " + msg); }; assertRootRank(invalidArgument, opts.rootRank, size_); assertSingleElementInput(invalidArgument, inputs); assertDense(invalidArgument, inputs); if (getRank() == opts.rootRank) { if (outputs.size() != 1) { std::stringstream ss; ss << "requires a single-element output list containing a list with " << getSize() << " tensors."; invalidArgument(ss.str()); } else if (outputs[0].size() != static_cast(getSize())) { std::stringstream ss; ss << "Incorrect output list size " << outputs[0].size() << ". 
Output list size should be " << getSize() << ", same as size of the process group."; invalidArgument(ss.str()); } const auto& options = inputs[0].options(); const auto& sizes = inputs[0].sizes(); assertTypeAndSizesMatch(invalidArgument, outputs[0], options, sizes); } else { if (outputs.size() != 0) { invalidArgument("requires empty output on non-root"); } } const auto& device = inputs[0].device(); switch (device.type()) { case at::kCPU: #ifdef USE_CUDA case at::kCUDA: #endif break; default: invalidArgument("unsupported device type"); } std::shared_ptr work; auto tag = nextTag(); auto context = getContext(tag); if (device.type() == at::kCPU) { work = std::make_shared( std::move(context), outputs, inputs, opts.rootRank, tag); #ifdef USE_CUDA } else if (device.type() == at::kCUDA) { work = std::make_shared( std::move(context), outputs, inputs, opts.rootRank, tag); #endif } else { throw std::runtime_error("Invalid backend"); } enqueue(work); return work; } namespace { class AsyncScatterWork : public ProcessGroupGloo::AsyncWork { public: AsyncScatterWork( const std::shared_ptr& context, std::vector& outputs, std::vector>& inputs, int root, uint32_t tag) : context(context), outputs(outputs), inputs(inputs), root(root), tag(tag) {} std::shared_ptr context; std::vector outputs; std::vector> inputs; const int root; const uint32_t tag; void scatter( std::vector& outputs, std::vector>& inputs) { const auto scalarType = outputs[0].scalar_type(); gloo::ScatterOptions opts(context); opts.setRoot(root); opts.setTag(tag); // Set list of input tensors on root process if (context->rank == root) { GENERATE_ALL_TYPES(scalarType, setInputs, opts, inputs[0]); } // Set single output tensor on all processes GENERATE_ALL_TYPES(scalarType, setOutput, opts, outputs[0]); gloo::scatter(opts); } void run() override { scatter(outputs, inputs); } }; #ifdef USE_CUDA class AsyncScatterCUDAWork : public AsyncScatterWork { public: AsyncScatterCUDAWork( const std::shared_ptr& context, std::vector& outputs, std::vector>& inputs, int root, uint32_t tag) : AsyncScatterWork(context, outputs, inputs, root, tag) { initializeStreamsEvents(inputs, inputStreams, inputEvents); initializeStreamsEvents(outputs, outputStreams, outputEvents); // Kick off copy from CUDA tensors to pinned CPU tensors. tmpInputs.resize(inputs.size()); at::cuda::OptionalCUDAStreamGuard guard; for (size_t i = 0; i < inputs.size(); i++) { guard.reset_stream(inputStreams[i]); tmpInputs[i].reserve(inputs[i].size()); for (size_t j = 0; j < inputs[i].size(); j++) { tmpInputs[i].push_back( pinnedLike(inputs[i][j]).copy_(inputs[i][j], true)); } } tmpOutputs.reserve(outputs.size()); for (size_t i = 0; i < outputs.size(); i++) { tmpOutputs.push_back(pinnedLike(outputs[i])); } } void run() override { // Synchronize with copy operations. at::cuda::OptionalCUDAGuard device_guard; for (size_t i = 0; i < inputs.size(); i++) { device_guard.set_index(inputs[i][0].get_device()); AT_CUDA_CHECK(cudaStreamSynchronize(inputStreams[i])); } for (size_t i = 0; i < outputs.size(); i++) { device_guard.set_index(outputs[i].get_device()); AT_CUDA_CHECK(cudaStreamSynchronize(outputStreams[i])); } // Run scatter on host side tensors. scatter(tmpOutputs, tmpInputs); // Kick off copy back to the CUDA tensors. 
at::cuda::OptionalCUDAStreamGuard stream_guard; for (size_t i = 0; i < outputs.size(); i++) { stream_guard.reset_stream(outputStreams[i]); outputs[i].copy_(tmpOutputs[i], /* non_blocking */ true); outputEvents[i].record(outputStreams[i]); } } void synchronize() override { // Synchronize with the copy back to CUDA tensors. at::cuda::OptionalCUDAGuard guard; for (size_t i = 0; i < outputs.size(); i++) { guard.set_index(static_cast(outputs[i].get_device())); outputEvents[i].block(at::cuda::getCurrentCUDAStream()); } } std::vector tmpOutputs; std::vector outputStreams; std::vector outputEvents; std::vector> tmpInputs; std::vector inputStreams; std::vector inputEvents; }; #endif } // namespace std::shared_ptr ProcessGroupGloo::scatter( std::vector& outputs, std::vector>& inputs, const ScatterOptions& opts) { static auto invalidArgument = [](const std::string& msg) { throw std::invalid_argument("ProcessGroupGloo::scatter: " + msg); }; assertRootRank(invalidArgument, opts.rootRank, size_); assertSingleElementOutput(invalidArgument, outputs); assertDense(invalidArgument, outputs); if (getRank() == opts.rootRank) { if (inputs.size() != 1) { std::stringstream ss; ss << "requires a single-element input list containing a list with " << getSize() << " tensors"; invalidArgument(ss.str()); } else if (inputs[0].size() != static_cast(getSize())) { std::stringstream ss; ss << "Incorrect input list size " << inputs[0].size() << ". Input list size should be " << getSize() << ", same as size of the process group."; invalidArgument(ss.str()); } const auto& options = outputs[0].options(); const auto& sizes = outputs[0].sizes(); assertTypeAndSizesMatch(invalidArgument, inputs[0], options, sizes); } else { if (inputs.size() != 0) { invalidArgument("requires empty input on non-root"); } } const auto& device = outputs[0].device(); switch (device.type()) { case at::kCPU: #ifdef USE_CUDA case at::kCUDA: #endif break; default: invalidArgument("unsupported device type"); } std::shared_ptr work; auto tag = nextTag(); auto context = getContext(tag); if (device.type() == at::kCPU) { work = std::make_shared( std::move(context), outputs, inputs, opts.rootRank, tag); #ifdef USE_CUDA } else if (device.type() == at::kCUDA) { work = std::make_shared( std::move(context), outputs, inputs, opts.rootRank, tag); #endif } else { throw std::runtime_error("Invalid backend"); } enqueue(work); return work; } std::shared_ptr ProcessGroupGloo::reduce_scatter( std::vector& outputs, std::vector>& inputs, const ReduceScatterOptions& opts) { throw std::runtime_error("ProcessGroupGloo does not support reduce_scatter"); } at::Tensor& checkSingleTensor(std::vector& tensors) { if (tensors.size() != 1) { throw std::runtime_error("ProcessGroupGloo::send takes a single tensor"); } auto& tensor = tensors[0]; if (!tensor.is_contiguous()) { throw std::runtime_error("input tensor has to be contiguous"); } if (tensor.is_sparse()) { throw std::runtime_error("input tensor has to be dense"); } return tensor; } uint32_t checkTag(int32_t tag) { if (tag < 0) { throw std::runtime_error("Tag must be >= 0"); } return (uint32_t)tag; } std::shared_ptr ProcessGroupGloo::send( std::vector& tensors, int dstRank, int tag) { auto& tensor = checkSingleTensor(tensors); auto utag = checkTag(tag); auto ptr = tensor.data_ptr(); auto size = tensor.numel() * tensor.element_size(); // Construct unbound buffer. 
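  //
  // Point-to-point usage sketch (illustrative only; `pg` is an already
  // constructed ProcessGroupGloo of size 2): send and recv pair up through
  // matching tags.
  //
  //   if (pg.getRank() == 0) {
  //     std::vector<at::Tensor> t = {at::ones({4})};
  //     pg.send(t, /*dstRank=*/1, /*tag=*/0)->wait();
  //   } else {
  //     std::vector<at::Tensor> t = {at::zeros({4})};
  //     pg.recv(t, /*srcRank=*/0, /*tag=*/0)->wait();
  //   }
  //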
  auto context = getContext(tag);
  auto buf = context->createUnboundBuffer(ptr, size);
  buf->send(dstRank, utag);

  // The work captures the tensor to prevent it being deallocated and
  // the unbound buffer to synchronize on completion of the send.
  return std::make_shared<SendWork>(tensor, std::move(buf));
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupGloo::recv(
    std::vector<at::Tensor>& tensors,
    int srcRank,
    int tag) {
  auto& tensor = checkSingleTensor(tensors);
  auto utag = checkTag(tag);
  auto ptr = tensor.data_ptr();
  auto size = tensor.numel() * tensor.element_size();

  // Construct unbound buffer.
  auto context = getContext(tag);
  auto buf = context->createUnboundBuffer(ptr, size);
  buf->recv(srcRank, utag);

  // The work captures the tensor to prevent it being deallocated and
  // the unbound buffer to synchronize on completion of the recv.
  return std::make_shared<RecvWork>(tensor, std::move(buf));
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupGloo::recvAnysource(
    std::vector<at::Tensor>& tensors,
    int tag) {
  auto& tensor = checkSingleTensor(tensors);
  auto utag = checkTag(tag);
  auto ptr = tensor.data_ptr();
  auto size = tensor.numel() * tensor.element_size();

  // Construct unbound buffer.
  auto context = getContext(tag);
  auto buf = context->createUnboundBuffer(ptr, size);

  // Build list of ranks that this operation can recv from. In these
  // bindings we don't differentiate between ranks and can receive
  // from any other process in the group.
  std::vector<int> srcRanks;
  srcRanks.reserve(size_);
  for (auto i = 0; i < size_; i++) {
    srcRanks.push_back(i);
  }

  buf->recv(srcRanks, utag);

  // The work captures the tensor to prevent it being deallocated and
  // the unbound buffer to synchronize on completion of the recv.
  return std::make_shared<RecvWork>(tensor, std::move(buf));
}

namespace {

class AsyncBarrierWork : public ProcessGroupGloo::AsyncWork {
 public:
  AsyncBarrierWork(
      const std::shared_ptr<gloo::Context>& context,
      std::vector<std::weak_ptr<AsyncWork>> priorWork,
      uint32_t tag)
      : context(context), priorWork(std::move(priorWork)), tag(tag) {}

  std::shared_ptr<gloo::Context> context;
  std::vector<std::weak_ptr<AsyncWork>> priorWork;
  const uint32_t tag;

  void run() override {
    // Wait on prior work to complete
    for (auto& weakWork : priorWork) {
      auto work = weakWork.lock();
      if (work) {
        work->wait();
      }
    }

    gloo::BarrierOptions opts(context);
    opts.setTag(tag);
    gloo::barrier(opts);
  }
};

} // namespace

std::shared_ptr<ProcessGroup::Work> ProcessGroupGloo::barrier(
    const BarrierOptions& opts) {
  std::vector<std::weak_ptr<AsyncWork>> priorWork;

  // Snapshot all in progress and pending work as weak_ptr.
  // When executing a barrier, we need to ensure that all prior work
  // has completed before completing itself.
  {
    std::unique_lock<std::mutex> lock(workMutex_);
    priorWork.insert(
        priorWork.end(), workInProgress_.begin(), workInProgress_.end());
    priorWork.insert(priorWork.end(), workQueue_.begin(), workQueue_.end());
  }

  auto tag = nextTag();
  auto context = getContext(tag);
  auto work = std::make_shared<AsyncBarrierWork>(
      std::move(context), std::move(priorWork), tag);
  enqueue(work);
  return work;
}

} // namespace c10d
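
// ---------------------------------------------------------------------------
// End-to-end usage sketch (illustrative only, not part of this translation
// unit). The FileStore-based rendezvous, its path, the world size of 2, and
// the `rank` variable are assumptions made for the example; the rest is the
// ProcessGroupGloo API implemented above.
//
//   auto store = std::make_shared<::c10d::FileStore>("/tmp/c10d_example", 2);
//   ::c10d::ProcessGroupGloo::Options options;
//   options.devices.push_back(
//       ::c10d::ProcessGroupGloo::createDefaultDevice());
//   ::c10d::ProcessGroupGloo pg(store, rank, /*size=*/2, options);
//
//   std::vector<at::Tensor> tensors = {at::ones({8})};
//   auto work = pg.allreduce(tensors, ::c10d::AllreduceOptions());
//   work->wait(); // tensors[0] now holds the element-wise sum across ranks
// ---------------------------------------------------------------------------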