Build torch.distributed with Gloo backend on macOS (#25260)

Summary:
In facebookincubator/gloo#212, a libuv-based Gloo transport was introduced,
which allows us to use Gloo on macOS (and later perhaps also Windows). This
commit updates the CMake code to enable building with USE_DISTRIBUTED=1 on macOS.

A few notes:
* The Caffe2 Gloo ops are not compiled on macOS, because they depend on `gloo::transport::tcp`.
* The process group implementation uses `gloo::transport::tcp` on Linux (where `epoll(2)` is available) and `gloo::transport::uv` on macOS.
* The TCP store works but sometimes crashes on process termination.
* The distributed tests are not yet run.
* The nightly builds don't use `USE_DISTRIBUTED=1`.
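
As an illustration (not part of this commit), once PyTorch is built this way the Gloo backend is initialized from Python as usual. A minimal sketch, assuming a single-process run and a scratch file for rendezvous (path, rank, and world size are placeholders):

    import torch
    import torch.distributed as dist

    # Minimal sketch: initialize the Gloo backend after building with
    # USE_DISTRIBUTED=1 on macOS. The rendezvous file path, rank, and
    # world size are illustrative placeholders.
    dist.init_process_group(
        backend="gloo",
        init_method="file:///tmp/c10d_rendezvous",  # hypothetical path
        rank=0,
        world_size=1,
    )
    t = torch.ones(3)
    dist.all_reduce(t)  # trivial reduction in a single-process group
    dist.destroy_process_group()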

Pull Request resolved: https://github.com/pytorch/pytorch/pull/25260

Reviewed By: mrshenli

Differential Revision: D17202381

Pulled By: pietern

fbshipit-source-id: ca80a82e78a05b4154271d2fb0ed31c8d9f26a7c
Pieter Noordhuis authored on 2019-09-05 07:08:12 -07:00; committed by Facebook GitHub Bot
parent a3d0abf729
commit 3556bea5aa
17 changed files with 197 additions and 91 deletions


@ -17,6 +17,10 @@ fi
export PATH="${PYTORCH_ENV_DIR}/miniconda3/bin:$PATH"
source ${PYTORCH_ENV_DIR}/miniconda3/bin/activate
conda install -y mkl mkl-include numpy pyyaml setuptools cmake cffi ninja
# Building with USE_DISTRIBUTED=1 requires libuv (for Gloo).
conda install -y libuv pkg-config
rm -rf ${PYTORCH_ENV_DIR}/miniconda3/lib/python3.6/site-packages/torch*
git submodule sync --recursive
@ -66,6 +70,8 @@ export MAX_JOBS=2
export IMAGE_COMMIT_TAG=${BUILD_ENVIRONMENT}-${IMAGE_COMMIT_ID}
export USE_DISTRIBUTED=1
python setup.py install
assert_git_not_dirty


@ -83,6 +83,25 @@ else ()
set(CPU_INTEL OFF)
endif ()
# For Windows, turn USE_DISTRIBUTED off by default.
# It is not tested and likely won't work without additional changes.
if(MSVC)
set(USE_DISTRIBUTED OFF CACHE STRING "Use distributed")
endif()
# For macOS, turn USE_DISTRIBUTED off by default.
# Gloo depends on libuv, which has to be installed by the user first.
# Therefore, the user should explicitly specify -DUSE_DISTRIBUTED=1
# to build torch.distributed, after installing the dependency.
if(APPLE)
set(USE_DISTRIBUTED OFF CACHE STRING "Use distributed")
# If USE_DISTRIBUTED is set, also set USE_LIBUV=1 for Gloo.
# It is the only transport that can be built on macOS.
if(USE_DISTRIBUTED)
set(USE_LIBUV ON CACHE STRING "")
endif()
endif()
# ---[ Options.
# Note to developers: if you add an option below, make sure you also add it to
# cmake/Summary.cmake so that the summary prints out the option values.
@ -264,9 +283,6 @@ if (MSVC)
# Try harder
list(APPEND CUDA_NVCC_FLAGS "-Xcompiler /w -w")
# Turning off USE_DISTRIBUTED on default
set(USE_DISTRIBUTED OFF)
endif(MSVC)
# Set INTERN_BUILD_MOBILE for all mobile builds. Components that are not


@ -768,7 +768,7 @@ ENDIF()
DESTINATION share/cmake/Torch)
if (USE_DISTRIBUTED)
if (NOT MSVC AND NOT APPLE)
if (NOT MSVC)
add_subdirectory(${TORCH_SRC_DIR}/lib/c10d lib_c10d)
endif()
endif()


@ -1,5 +1,4 @@
add_subdirectory(aten)
add_subdirectory(gloo)
add_subdirectory(nccl)
add_subdirectory(opencl)
add_subdirectory(prof)
@ -8,6 +7,12 @@ if (USE_TENSORRT)
add_subdirectory(tensorrt)
endif()
# Only build Gloo Caffe2 ops on Linux, as it hardcodes
# the Linux-specific `gloo::transport::tcp` namespace.
if(${CMAKE_SYSTEM_NAME} STREQUAL "Linux")
add_subdirectory(gloo)
endif()
# Pass the src lists back to the parent
# CPU source, include, deps, test sources, binary sources


@ -50,7 +50,7 @@
#include "caffe2/operators/bbox_transform_op.h"
#include "caffe2/operators/box_with_nms_limit_op.h"
#ifdef CAFFE2_USE_GLOO
#if __linux__ && defined(CAFFE2_USE_GLOO)
#include <caffe2/contrib/gloo/common_world_ops.h>
#include <caffe2/contrib/gloo/broadcast_ops.h>
#include <caffe2/contrib/gloo/allreduce_ops.h>
@ -284,7 +284,7 @@ REGISTER_IDEEP_OPERATOR(
BatchMatMul,
IDEEPFallbackOp<BatchMatMulOp<CPUContext>>);
#ifdef CAFFE2_USE_GLOO
#if __linux__ && defined(CAFFE2_USE_GLOO)
namespace gloo {
// gloo operators
REGISTER_IDEEP_OPERATOR(


@ -941,8 +941,8 @@ if(USE_CUDA)
endif()
if(USE_GLOO)
if(NOT ${CMAKE_SYSTEM_NAME} STREQUAL "Linux")
message(WARNING "Gloo can only be used on Linux.")
if(MSVC)
message(WARNING "Gloo can not be used on Windows.")
caffe2_update_option(USE_GLOO OFF)
elseif(NOT CMAKE_SIZEOF_VOID_P EQUAL 8)
message(WARNING "Gloo can only be used on 64-bit systems.")


@ -181,7 +181,7 @@ import glob
import importlib
from tools.build_pytorch_libs import build_caffe2
from tools.setup_helpers.env import (IS_WINDOWS, IS_DARWIN, IS_LINUX,
from tools.setup_helpers.env import (IS_WINDOWS, IS_DARWIN,
check_env_flag, build_type)
from tools.setup_helpers.cmake import CMake
from tools.setup_helpers.cuda import CUDA_HOME, CUDA_VERSION
@ -403,10 +403,10 @@ class build_ext(setuptools.command.build_ext.build_ext):
else:
report('-- Not using NCCL')
if cmake_cache_vars['USE_DISTRIBUTED']:
if IS_LINUX:
report('-- Building with c10d distributed package ')
if IS_WINDOWS:
report('-- Building without distributed package')
else:
report('-- Building without c10d distributed package')
report('-- Building with distributed package ')
else:
report('-- Building without distributed package')


@ -11,6 +11,7 @@ import threading
import time
import unittest
from datetime import timedelta
from sys import platform
from itertools import groupby
from functools import partial, reduce
@ -38,6 +39,12 @@ if not c10d.is_available():
sys.exit(0)
if platform == 'darwin':
LOOPBACK = 'lo0'
else:
LOOPBACK = 'lo'
def gpus_for_rank(world_size):
"""Multigpu tests are designed to simulate the multi nodes with multi
GPUs on each node. Nccl backend requires equal #GPUs in each process.
@ -511,7 +518,7 @@ class TimeoutTest(TestCase):
class ProcessGroupGlooTest(MultiProcessTestCase):
def opts(self, threads=2):
opts = c10d.ProcessGroupGloo.Options()
opts.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")]
opts.devices = [c10d.ProcessGroupGloo.create_device(interface=LOOPBACK)]
opts.timeout = 5.0
opts.threads = threads
return opts
@ -521,8 +528,8 @@ class ProcessGroupGlooTest(MultiProcessTestCase):
opts = c10d.ProcessGroupGloo.Options()
opts.timeout = 5.0
opts.devices = [
c10d.ProcessGroupGloo.create_tcp_device(interface="lo"),
c10d.ProcessGroupGloo.create_tcp_device(interface="lo"),
c10d.ProcessGroupGloo.create_device(interface=LOOPBACK),
c10d.ProcessGroupGloo.create_device(interface=LOOPBACK),
]
pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, opts)
@ -645,8 +652,7 @@ class ProcessGroupGlooTest(MultiProcessTestCase):
(i * self.world_size) + (i % self.world_size)
]),
inputs[i],
None,
"Mismatch in iteration %d" % i,
message=("Mismatch in iteration %d" % i),
)
def test_broadcast_stress(self):
@ -728,8 +734,7 @@ class ProcessGroupGlooTest(MultiProcessTestCase):
(self.world_size * (self.world_size - 1) / 2)
]),
inputs[i],
None,
"Mismatch in iteration %d" % i,
message=("Mismatch in iteration %d" % i),
)
def test_allreduce_stress(self):
@ -981,8 +986,7 @@ class ProcessGroupGlooTest(MultiProcessTestCase):
self.assertEqual(
torch.Tensor([iter + root]),
outputs[iter][root],
None,
"Mismatch in iteration %d for rank %d" % (iter, root)
message=("Mismatch in iteration %d for rank %d" % (iter, root)),
)
def test_scatter_stress(self):
@ -1128,8 +1132,7 @@ class ProcessGroupGlooTest(MultiProcessTestCase):
self.assertEqual(
expected_outputs[iter],
outputs[iter],
None,
"Mismatch in iteration %d for root %d" % (iter, root)
message=("Mismatch in iteration %d for root %d" % (iter, root))
)
def test_gather_stress(self):
@ -1229,8 +1232,7 @@ class ProcessGroupGlooTest(MultiProcessTestCase):
self.assertEqual(
expected_outputs[i],
outputs[i],
None,
"Mismatch in iteration %d" % i
message=("Mismatch in iteration %d" % i),
)
def test_allgather_stress(self):
@ -1318,8 +1320,7 @@ class ProcessGroupGlooTest(MultiProcessTestCase):
(self.world_size * (self.world_size - 1) / 2)
]),
outputs[i],
None,
"Mismatch in iteration %d with root rank %d" % (iter, root),
message=("Mismatch in iteration %d with root rank %d" % (iter, root)),
)
def test_reduce_stress(self):
@ -1369,6 +1370,7 @@ class ProcessGroupGlooTest(MultiProcessTestCase):
continue
self.assertEqual(torch.Tensor([i]), outputs[i])
@unittest.skipIf(platform == 'darwin', 'ProcessGroup timeout not yet supported on macOS')
def test_timeout_kwarg(self):
store = c10d.FileStore(self.file.name, self.world_size)
pg = c10d.ProcessGroupGloo(
@ -1838,7 +1840,7 @@ class DistributedDataParallelTest(MultiProcessTestCase):
def _test_gloo_backend(self, devices, device_ids, multi_device=False):
store = c10d.FileStore(self.file.name, self.world_size)
options = c10d.ProcessGroupGloo.Options()
options.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")]
options.devices = [c10d.ProcessGroupGloo.create_device(interface=LOOPBACK)]
process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size, options)
self._test_ddp_with_process_group(process_group, devices, device_ids, multi_device)
@ -1980,7 +1982,7 @@ class DistributedDataParallelTest(MultiProcessTestCase):
def test_dist_broadcast_coalesced_gloo(self):
store = c10d.FileStore(self.file.name, self.world_size)
options = c10d.ProcessGroupGloo.Options()
options.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")]
options.devices = [c10d.ProcessGroupGloo.create_device(interface=LOOPBACK)]
process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size, options)
device = torch.device('cuda')
@ -2019,7 +2021,7 @@ class DistributedDataParallelTest(MultiProcessTestCase):
def test_sync_params_no_buffers(self):
store = c10d.FileStore(self.file.name, self.world_size)
options = c10d.ProcessGroupGloo.Options()
options.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")]
options.devices = [c10d.ProcessGroupGloo.create_device(interface=LOOPBACK)]
process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size, options)
# Use all available devices on every process here (data is small, so should be fine).
@ -2046,7 +2048,7 @@ class DistributedDataParallelTest(MultiProcessTestCase):
def test_sync_params_with_buffers(self):
store = c10d.FileStore(self.file.name, self.world_size)
options = c10d.ProcessGroupGloo.Options()
options.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")]
options.devices = [c10d.ProcessGroupGloo.create_device(interface=LOOPBACK)]
process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size, options)
devices = gpus_for_rank(self.world_size)[self.rank]
@ -3053,7 +3055,7 @@ class CommTest(MultiProcessTestCase):
def test_broadcast_coalesced_gloo_cuda(self):
store = c10d.FileStore(self.file.name, self.world_size)
options = c10d.ProcessGroupGloo.Options()
options.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")]
options.devices = [c10d.ProcessGroupGloo.create_device(interface=LOOPBACK)]
process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size, options)
device = torch.device('cuda:%d' % self.rank)
self._test_broadcast_coalesced(process_group, device)
@ -3062,7 +3064,7 @@ class CommTest(MultiProcessTestCase):
def test_broadcast_coalesced_gloo_cpu(self):
store = c10d.FileStore(self.file.name, self.world_size)
options = c10d.ProcessGroupGloo.Options()
options.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")]
options.devices = [c10d.ProcessGroupGloo.create_device(interface=LOOPBACK)]
process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size, options)
device = torch.device('cpu')
self._test_broadcast_coalesced(process_group, device)


@ -34,7 +34,7 @@ class ProcessGroupShareTensorTest(TestCase):
@classmethod
def opts(cls, threads=2):
opts = c10d.ProcessGroupGloo.Options()
opts.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")]
opts.devices = [c10d.ProcessGroupGloo.create_device(interface="lo")]
opts.timeout = 5.0
opts.threads = threads
return opts

third_party/gloo (vendored submodule)

@ -1 +1 @@
Subproject commit a9fa7c8d6e95a6be22b734e22b595bc80f03aea0
Subproject commit 2101e02ceabd9f1b0bb354f6ea705cefe83558b2


@ -221,7 +221,7 @@ endif()
if (USE_DISTRIBUTED)
list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_DISTRIBUTED)
if (NOT MSVC AND NOT APPLE)
if (NOT MSVC)
list(APPEND TORCH_PYTHON_SRCS
${TORCH_SRC_DIR}/csrc/distributed/autograd/init.cpp
${TORCH_SRC_DIR}/csrc/distributed/autograd/context/dist_autograd_container.cpp


@ -17,7 +17,6 @@
#include <c10d/PrefixStore.hpp>
#include <c10d/TCPStore.hpp>
#include <gloo/transport/tcp/device.h>
#include <pybind11/chrono.h>
#include <torch/csrc/Exceptions.h>
@ -35,28 +34,6 @@ namespace {
#ifdef USE_C10D_GLOO
constexpr char* GLOO_SOCKET_IFNAME_ENV = "GLOO_SOCKET_IFNAME";
std::shared_ptr<::gloo::transport::Device> createDeviceForDefaultHostname() {
::gloo::transport::tcp::attr attr;
// Use the hostname to resolve the network address to
// use. Note: if the hostname does not resolve to an address (e.g.
// because of misconfigured /etc/hosts file), this will not work.
std::array<char, HOST_NAME_MAX> hostname{};
auto rv = gethostname(hostname.data(), hostname.size());
if (rv != 0) {
throw std::system_error(errno, std::system_category());
}
attr.hostname = hostname.data();
return ::gloo::transport::tcp::CreateDevice(attr);
}
std::shared_ptr<::gloo::transport::Device> createDeviceForInterface(
std::string iface) {
::gloo::transport::tcp::attr attr;
attr.iface = std::move(iface);
return ::gloo::transport::tcp::CreateDevice(attr);
}
#endif
std::vector<std::string> split(char separator, const std::string& string) {
@ -446,20 +423,17 @@ They are used in specifying strategies for reduction collectives, e.g.,
.def_readwrite("threads", &::c10d::ProcessGroupGloo::Options::threads);
processGroupGloo.def_static(
"create_tcp_device",
"create_device",
[](const std::string& hostname, const std::string& interface)
-> std::shared_ptr<::gloo::transport::Device> {
::gloo::transport::tcp::attr attr;
if (!hostname.empty()) {
attr.hostname = hostname;
} else if (!interface.empty()) {
attr.iface = interface;
} else {
// Neither argument is specified; Gloo itself will use the
// hostname
// Nothing specified, default to something useful
return ::c10d::ProcessGroupGloo::createDeviceForHostname(hostname);
}
return ::gloo::transport::tcp::CreateDevice(attr);
if (!interface.empty()) {
return ::c10d::ProcessGroupGloo::createDeviceForInterface(interface);
}
throw std::invalid_argument(
"Specify either `hostname` or `interface` argument.");
},
py::arg("hostname") = "",
py::arg("interface") = "");
@ -481,10 +455,15 @@ They are used in specifying strategies for reduction collectives, e.g.,
char* ifnameEnv = getenv(GLOO_SOCKET_IFNAME_ENV);
if (ifnameEnv) {
for (const auto& iface : split(',', ifnameEnv)) {
options.devices.push_back(createDeviceForInterface(iface));
options.devices.push_back(
::c10d::ProcessGroupGloo::createDeviceForInterface(iface));
}
} else {
options.devices.push_back(createDeviceForDefaultHostname());
// If no hostname is specified, this function looks up
// the machine's hostname and returns a device instance
// associated with the address that the hostname resolves to.
options.devices.push_back(
::c10d::ProcessGroupGloo::createDeviceForHostname(""));
}
options.timeout = timeout;
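
For reference, the `GLOO_SOCKET_IFNAME` fallback above can also be driven from Python by setting the environment variable before the process group is created. A minimal sketch (interface name and rendezvous path are illustrative):

    import os

    # One or more comma-separated interface names; a Gloo device is created
    # for each entry. "lo" is typical on Linux, "lo0" on macOS.
    os.environ["GLOO_SOCKET_IFNAME"] = "lo0"

    import torch.distributed as dist
    dist.init_process_group(
        backend="gloo",
        init_method="file:///tmp/c10d_rendezvous",  # hypothetical path
        rank=0,
        world_size=1,
    )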


@ -1,5 +1,7 @@
#include <c10d/ProcessGroupGloo.hpp>
#include <unistd.h>
#include <gloo/allgather.h>
#include <gloo/allreduce.h>
#include <gloo/barrier.h>
@ -19,9 +21,31 @@
#include <c10/cuda/CUDAStream.h>
#endif
#include <gloo/config.h>
#include <gloo/rendezvous/context.h>
#include <gloo/rendezvous/prefix_store.h>
#if GLOO_HAVE_TRANSPORT_TCP
#include <gloo/transport/tcp/device.h>
#endif
#if GLOO_HAVE_TRANSPORT_UV
#include <gloo/transport/uv/device.h>
#endif
// On Linux, check that the tcp transport is available.
#ifdef __linux__
#if !GLOO_HAVE_TRANSPORT_TCP
#error "Expected the tcp transport to be available on Linux."
#endif
#endif
// On macOS, check that the uv transport is available.
#ifdef __APPLE__
#if !GLOO_HAVE_TRANSPORT_UV
#error "Expected the uv transport to be available on macOS."
#endif
#endif
#define GENERATE_ALL_TYPES(type, func, args...) \
switch (type) { \
@ -276,6 +300,71 @@ void ProcessGroupGloo::RecvWork::wait() {
ProcessGroupGloo::Options::Options()
: timeout(std::chrono::milliseconds(10 * 1000)), threads(2) {}
#ifdef __linux__
std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo::
createDeviceForInterface(const std::string& interface) {
::gloo::transport::tcp::attr attr;
attr.iface = interface;
return ::gloo::transport::tcp::CreateDevice(attr);
}
#endif
#ifdef __APPLE__
std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo::
createDeviceForInterface(const std::string& interface) {
::gloo::transport::uv::attr attr;
attr.iface = interface;
return ::gloo::transport::uv::CreateDevice(attr);
}
#endif
#ifdef __linux__
std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo::
createDeviceForHostname(const std::string& hostname) {
::gloo::transport::tcp::attr attr;
if (hostname.empty()) {
// Use the hostname to resolve the network address to
// use. Note: if the hostname does not resolve to an address (e.g.
// because of misconfigured /etc/hosts file), this will not work.
std::array<char, HOST_NAME_MAX> buffer{};
auto rv = gethostname(buffer.data(), buffer.size());
if (rv != 0) {
throw std::system_error(errno, std::system_category());
}
attr.hostname = buffer.data();
} else {
attr.hostname = hostname;
}
return ::gloo::transport::tcp::CreateDevice(attr);
}
#endif
#ifdef __APPLE__
std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo::
createDeviceForHostname(const std::string& hostname) {
::gloo::transport::uv::attr attr;
if (hostname.empty()) {
// Use the hostname to resolve the network address to
// use. Note: if the hostname does not resolve to an address (e.g.
// because of misconfigured /etc/hosts file), this will not work.
const auto hostNameMax = sysconf(_SC_HOST_NAME_MAX);
auto buffer = std::unique_ptr<char[]>(new char[hostNameMax]);
auto rv = gethostname(buffer.get(), hostNameMax);
if (rv != 0) {
throw std::system_error(errno, std::system_category());
}
attr.hostname = buffer.get();
} else {
attr.hostname = hostname;
}
return ::gloo::transport::uv::CreateDevice(attr);
}
#endif
ProcessGroupGloo::ProcessGroupGloo(
const std::shared_ptr<Store>& store,
int rank,
@ -644,7 +733,7 @@ class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork {
// Construct from an existing metadata tensor to facilitate structured
// access to metadata from peers, after gathering it.
explicit SparseTensorMetadata(at::Tensor metadata)
: metadata_(metadata), data_(metadata_.data_ptr<long>()) {
: metadata_(metadata), data_(metadata_.data_ptr<int64_t>()) {
AT_ASSERT(metadata.scalar_type() == at::kLong);
AT_ASSERT(metadata.dim() == 1);
AT_ASSERT(metadata.size(0) == dim);
@ -694,7 +783,7 @@ class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork {
protected:
at::Tensor metadata_;
long* data_;
int64_t* data_;
};
// Sparse allreduce is implemented with allgather on indices and values.
@ -719,7 +808,7 @@ class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork {
// Sanity check dimensionality across ranks.
{
const auto expected = metadata[context->rank].sizes();
for (size_t i = 0; i < context->size; i++) {
for (auto i = 0; i < context->size; i++) {
if (i == context->rank) {
continue;
}
@ -733,11 +822,11 @@ class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork {
auto values = allgather_values(input, metadata);
// Perform global reduction.
AT_ASSERT(indices.size() == context->size);
AT_ASSERT(values.size() == context->size);
AT_ASSERT(static_cast<int>(indices.size()) == context->size);
AT_ASSERT(static_cast<int>(values.size()) == context->size);
auto output = at::sparse_coo_tensor(
indices[0], values[0], input.sizes(), input.options());
for (size_t i = 1; i < context->size; i++) {
for (auto i = 1; i < context->size; i++) {
output += at::sparse_coo_tensor(
indices[i], values[i], input.sizes(), input.options());
}
@ -778,7 +867,7 @@ class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork {
// Allgather metadata
gloo::AllgatherOptions opts(context);
opts.setOutput(buffer.data_ptr<long>(), buffer.numel());
opts.setOutput(buffer.data_ptr<int64_t>(), buffer.numel());
opts.setTag(tag);
gloo::allgather(opts);
@ -802,7 +891,7 @@ class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork {
// Allgather indices.
gloo::AllgatherOptions opts(context);
opts.setOutput(buffer.data_ptr<long>(), buffer.numel());
opts.setOutput(buffer.data_ptr<int64_t>(), buffer.numel());
opts.setTag(tag);
gloo::allgather(opts);


@ -127,6 +127,19 @@ class ProcessGroupGloo : public ProcessGroup {
int threads;
};
// Helper functions to create a new device object.
// They are static functions on this class to keep them logically
// separate from the rest of the code base (e.g. torch/csrc/distributed).
// Create new device instance for specific interface.
static std::shared_ptr<::gloo::transport::Device> createDeviceForInterface(
const std::string& interface);
// Create new device instance for hostname or address.
// If specified argument is empty, it defaults to this machine's hostname.
static std::shared_ptr<::gloo::transport::Device> createDeviceForHostname(
const std::string& hostname);
explicit ProcessGroupGloo(
const std::shared_ptr<Store>& store,
int rank,

View File

@ -1,5 +1,3 @@
#include <gloo/transport/tcp/device.h>
#include <ATen/cuda/CUDAMultiStreamGuard.h>
#include <c10/cuda/CUDAGuard.h>
@ -51,8 +49,8 @@ class AsyncTest {
// Use tiny timeout to make this test run fast
::c10d::ProcessGroupGloo::Options options;
options.timeout = std::chrono::milliseconds(50);
::gloo::transport::tcp::attr attr;
options.devices.push_back(::gloo::transport::tcp::CreateDevice(attr));
options.devices.push_back(
::c10d::ProcessGroupGloo::createDeviceForHostname("127.0.0.1"));
pg_ = std::unique_ptr<::c10d::ProcessGroupGloo>(
new ::c10d::ProcessGroupGloo(store, rank, size, options));


@ -9,8 +9,6 @@
#include <sstream>
#include <thread>
#include <gloo/transport/tcp/device.h>
#include <c10d/FileStore.hpp>
#include <c10d/ProcessGroupGloo.hpp>
#include <c10d/test/TestUtils.hpp>
@ -42,8 +40,8 @@ class SignalTest {
// Use tiny timeout to make this test run fast
::c10d::ProcessGroupGloo::Options options;
options.timeout = std::chrono::milliseconds(50);
::gloo::transport::tcp::attr attr;
options.devices.push_back(::gloo::transport::tcp::CreateDevice(attr));
options.devices.push_back(
::c10d::ProcessGroupGloo::createDeviceForHostname("127.0.0.1"));
::c10d::ProcessGroupGloo pg(store, rank, size, options);
@ -127,9 +125,8 @@ class CollectiveTest {
// Use tiny timeout to make this test run fast
::c10d::ProcessGroupGloo::Options options;
options.timeout = std::chrono::milliseconds(50);
::gloo::transport::tcp::attr attr;
options.devices.push_back(::gloo::transport::tcp::CreateDevice(attr));
options.devices.push_back(
::c10d::ProcessGroupGloo::createDeviceForHostname("127.0.0.1"));
pg_ = std::unique_ptr<::c10d::ProcessGroupGloo>(
new ::c10d::ProcessGroupGloo(store, rank, size, options));


@ -1,5 +1,6 @@
#pragma once
#include <signal.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <unistd.h>