Build torch.distributed with Gloo backend on macOS (#25260)
Summary: In facebookincubator/gloo#212, a libuv-based Gloo transport was introduced, which allows us to use Gloo on macOS (and later perhaps also Windows). This commit updates the CMake code to enable building with USE_DISTRIBUTED=1 on macOS. A few notes:
* The Caffe2 ops are not compiled, as they depend on `gloo::transport::tcp`.
* The process group implementation uses `gloo::transport::tcp` on Linux (because it can rely on `epoll(2)`) and `gloo::transport::uv` on macOS.
* The TCP store works but sometimes crashes on process termination.
* The distributed tests are not yet run.
* The nightly builds don't use `USE_DISTRIBUTED=1`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25260
Reviewed By: mrshenli
Differential Revision: D17202381
Pulled By: pietern
fbshipit-source-id: ca80a82e78a05b4154271d2fb0ed31c8d9f26a7c
This commit is contained in: parent a3d0abf729, commit 3556bea5aa
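For orientation, a minimal usage sketch (not part of this commit): on a macOS build configured with USE_DISTRIBUTED=1 and libuv installed, the Gloo backend can be initialized through the regular torch.distributed API. The rendezvous file path and the loopback interface name below are illustrative assumptions.

# Minimal single-process init of the Gloo backend on a macOS build with
# USE_DISTRIBUTED=1; the file path and interface name are hypothetical.
import os
import torch.distributed as dist

os.environ["GLOO_SOCKET_IFNAME"] = "lo0"  # loopback on macOS; 'lo' on Linux

dist.init_process_group(
    backend="gloo",
    init_method="file:///tmp/gloo_rendezvous",  # file-based rendezvous
    rank=0,
    world_size=1,
)
print(dist.get_rank(), dist.get_world_size())
dist.destroy_process_group()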
@@ -17,6 +17,10 @@ fi
 export PATH="${PYTORCH_ENV_DIR}/miniconda3/bin:$PATH"
 source ${PYTORCH_ENV_DIR}/miniconda3/bin/activate
 conda install -y mkl mkl-include numpy pyyaml setuptools cmake cffi ninja
+
+# Building with USE_DISTRIBUTED=1 requires libuv (for Gloo).
+conda install -y libuv pkg-config
+
 rm -rf ${PYTORCH_ENV_DIR}/miniconda3/lib/python3.6/site-packages/torch*

 git submodule sync --recursive
@@ -66,6 +70,8 @@ export MAX_JOBS=2

 export IMAGE_COMMIT_TAG=${BUILD_ENVIRONMENT}-${IMAGE_COMMIT_ID}

+export USE_DISTRIBUTED=1
+
 python setup.py install

 assert_git_not_dirty
@@ -83,6 +83,25 @@ else ()
   set(CPU_INTEL OFF)
 endif ()

+# For Windows, turn USE_DISTRIBUTED off by default.
+# It is not tested and likely won't work without additional changes.
+if(MSVC)
+  set(USE_DISTRIBUTED OFF CACHE STRING "Use distributed")
+endif()
+
+# For macOS, turn USE_DISTRIBUTED off by default.
+# Gloo depends on libuv, which has to be installed by the user first.
+# Therefore, the user should explicitly specify -DUSE_DISTRIBUTED=1
+# to build torch.distributed, after installing the dependency.
+if(APPLE)
+  set(USE_DISTRIBUTED OFF CACHE STRING "Use distributed")
+  # If USE_DISTRIBUTED is set, also set USE_LIBUV=1 for Gloo.
+  # It is the only transport that can be built on macOS.
+  if(USE_DISTRIBUTED)
+    set(USE_LIBUV ON CACHE STRING "")
+  endif()
+endif()
+
 # ---[ Options.
 # Note to developers: if you add an option below, make sure you also add it to
 # cmake/Summary.cmake so that the summary prints out the option values.
@@ -264,9 +283,6 @@ if (MSVC)

   # Try harder
   list(APPEND CUDA_NVCC_FLAGS "-Xcompiler /w -w")
-
-  # Turning off USE_DISTRIBUTED on default
-  set(USE_DISTRIBUTED OFF)
 endif(MSVC)

 # Set INTERN_BUILD_MOBILE for all mobile builds. Components that are not
@@ -768,7 +768,7 @@ ENDIF()
   DESTINATION share/cmake/Torch)

 if (USE_DISTRIBUTED)
-  if (NOT MSVC AND NOT APPLE)
+  if (NOT MSVC)
     add_subdirectory(${TORCH_SRC_DIR}/lib/c10d lib_c10d)
   endif()
 endif()
@@ -1,5 +1,4 @@
 add_subdirectory(aten)
-add_subdirectory(gloo)
 add_subdirectory(nccl)
 add_subdirectory(opencl)
 add_subdirectory(prof)
@@ -8,6 +7,12 @@ if (USE_TENSORRT)
   add_subdirectory(tensorrt)
 endif()

+# Only build Gloo Caffe2 ops on Linux, as it hardcodes
+# the Linux-specific `gloo::transport::tcp` namespace.
+if(${CMAKE_SYSTEM_NAME} STREQUAL "Linux")
+  add_subdirectory(gloo)
+endif()
+
 # Pass the src lists back to the parent

 # CPU source, include, deps, test sources, binary sources
@@ -50,7 +50,7 @@
 #include "caffe2/operators/bbox_transform_op.h"
 #include "caffe2/operators/box_with_nms_limit_op.h"

-#ifdef CAFFE2_USE_GLOO
+#if __linux__ && defined(CAFFE2_USE_GLOO)
 #include <caffe2/contrib/gloo/common_world_ops.h>
 #include <caffe2/contrib/gloo/broadcast_ops.h>
 #include <caffe2/contrib/gloo/allreduce_ops.h>
@@ -284,7 +284,7 @@ REGISTER_IDEEP_OPERATOR(
     BatchMatMul,
     IDEEPFallbackOp<BatchMatMulOp<CPUContext>>);

-#ifdef CAFFE2_USE_GLOO
+#if __linux__ && defined(CAFFE2_USE_GLOO)
 namespace gloo {
 // gloo operators
 REGISTER_IDEEP_OPERATOR(
@@ -941,8 +941,8 @@ if(USE_CUDA)
 endif()

 if(USE_GLOO)
-  if(NOT ${CMAKE_SYSTEM_NAME} STREQUAL "Linux")
-    message(WARNING "Gloo can only be used on Linux.")
+  if(MSVC)
+    message(WARNING "Gloo can not be used on Windows.")
     caffe2_update_option(USE_GLOO OFF)
   elseif(NOT CMAKE_SIZEOF_VOID_P EQUAL 8)
     message(WARNING "Gloo can only be used on 64-bit systems.")
setup.py
@@ -181,7 +181,7 @@ import glob
 import importlib

 from tools.build_pytorch_libs import build_caffe2
-from tools.setup_helpers.env import (IS_WINDOWS, IS_DARWIN, IS_LINUX,
+from tools.setup_helpers.env import (IS_WINDOWS, IS_DARWIN,
                                      check_env_flag, build_type)
 from tools.setup_helpers.cmake import CMake
 from tools.setup_helpers.cuda import CUDA_HOME, CUDA_VERSION
@@ -403,10 +403,10 @@ class build_ext(setuptools.command.build_ext.build_ext):
        else:
            report('-- Not using NCCL')
        if cmake_cache_vars['USE_DISTRIBUTED']:
-           if IS_LINUX:
-               report('-- Building with c10d distributed package ')
+           if IS_WINDOWS:
+               report('-- Building without distributed package')
            else:
-               report('-- Building without c10d distributed package')
+               report('-- Building with distributed package ')
        else:
            report('-- Building without distributed package')

@@ -11,6 +11,7 @@ import threading
 import time
 import unittest
 from datetime import timedelta
+from sys import platform

 from itertools import groupby
 from functools import partial, reduce
@@ -38,6 +39,12 @@ if not c10d.is_available():
     sys.exit(0)


+if platform == 'darwin':
+    LOOPBACK = 'lo0'
+else:
+    LOOPBACK = 'lo'
+
+
 def gpus_for_rank(world_size):
     """Multigpu tests are designed to simulate the multi nodes with multi
     GPUs on each node. Nccl backend requires equal #GPUs in each process.
@@ -511,7 +518,7 @@ class TimeoutTest(TestCase):
 class ProcessGroupGlooTest(MultiProcessTestCase):
     def opts(self, threads=2):
         opts = c10d.ProcessGroupGloo.Options()
-        opts.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")]
+        opts.devices = [c10d.ProcessGroupGloo.create_device(interface=LOOPBACK)]
         opts.timeout = 5.0
         opts.threads = threads
         return opts
@@ -521,8 +528,8 @@ class ProcessGroupGlooTest(MultiProcessTestCase):
         opts = c10d.ProcessGroupGloo.Options()
         opts.timeout = 5.0
         opts.devices = [
-            c10d.ProcessGroupGloo.create_tcp_device(interface="lo"),
-            c10d.ProcessGroupGloo.create_tcp_device(interface="lo"),
+            c10d.ProcessGroupGloo.create_device(interface=LOOPBACK),
+            c10d.ProcessGroupGloo.create_device(interface=LOOPBACK),
         ]
         pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, opts)

@@ -645,8 +652,7 @@ class ProcessGroupGlooTest(MultiProcessTestCase):
                     (i * self.world_size) + (i % self.world_size)
                 ]),
                 inputs[i],
-                None,
-                "Mismatch in iteration %d" % i,
+                message=("Mismatch in iteration %d" % i),
             )

     def test_broadcast_stress(self):
@@ -728,8 +734,7 @@ class ProcessGroupGlooTest(MultiProcessTestCase):
                     (self.world_size * (self.world_size - 1) / 2)
                 ]),
                 inputs[i],
-                None,
-                "Mismatch in iteration %d" % i,
+                message=("Mismatch in iteration %d" % i),
             )

     def test_allreduce_stress(self):
@@ -981,8 +986,7 @@ class ProcessGroupGlooTest(MultiProcessTestCase):
             self.assertEqual(
                 torch.Tensor([iter + root]),
                 outputs[iter][root],
-                None,
-                "Mismatch in iteration %d for rank %d" % (iter, root)
+                message=("Mismatch in iteration %d for rank %d" % (iter, root)),
             )

     def test_scatter_stress(self):
@@ -1128,8 +1132,7 @@ class ProcessGroupGlooTest(MultiProcessTestCase):
             self.assertEqual(
                 expected_outputs[iter],
                 outputs[iter],
-                None,
-                "Mismatch in iteration %d for root %d" % (iter, root)
+                message=("Mismatch in iteration %d for root %d" % (iter, root))
             )

     def test_gather_stress(self):
@@ -1229,8 +1232,7 @@ class ProcessGroupGlooTest(MultiProcessTestCase):
             self.assertEqual(
                 expected_outputs[i],
                 outputs[i],
-                None,
-                "Mismatch in iteration %d" % i
+                message=("Mismatch in iteration %d" % i),
             )

     def test_allgather_stress(self):
@@ -1318,8 +1320,7 @@ class ProcessGroupGlooTest(MultiProcessTestCase):
                     (self.world_size * (self.world_size - 1) / 2)
                 ]),
                 outputs[i],
-                None,
-                "Mismatch in iteration %d with root rank %d" % (iter, root),
+                message=("Mismatch in iteration %d with root rank %d" % (iter, root)),
             )

     def test_reduce_stress(self):
@@ -1369,6 +1370,7 @@ class ProcessGroupGlooTest(MultiProcessTestCase):
                 continue
             self.assertEqual(torch.Tensor([i]), outputs[i])

+    @unittest.skipIf(platform == 'darwin', 'ProcessGroup timeout not yet supported on macOS')
     def test_timeout_kwarg(self):
         store = c10d.FileStore(self.file.name, self.world_size)
         pg = c10d.ProcessGroupGloo(
@@ -1838,7 +1840,7 @@ class DistributedDataParallelTest(MultiProcessTestCase):
     def _test_gloo_backend(self, devices, device_ids, multi_device=False):
         store = c10d.FileStore(self.file.name, self.world_size)
         options = c10d.ProcessGroupGloo.Options()
-        options.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")]
+        options.devices = [c10d.ProcessGroupGloo.create_device(interface=LOOPBACK)]
         process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size, options)
         self._test_ddp_with_process_group(process_group, devices, device_ids, multi_device)

@@ -1980,7 +1982,7 @@ class DistributedDataParallelTest(MultiProcessTestCase):
     def test_dist_broadcast_coalesced_gloo(self):
         store = c10d.FileStore(self.file.name, self.world_size)
         options = c10d.ProcessGroupGloo.Options()
-        options.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")]
+        options.devices = [c10d.ProcessGroupGloo.create_device(interface=LOOPBACK)]
         process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size, options)

         device = torch.device('cuda')
@@ -2019,7 +2021,7 @@ class DistributedDataParallelTest(MultiProcessTestCase):
     def test_sync_params_no_buffers(self):
         store = c10d.FileStore(self.file.name, self.world_size)
         options = c10d.ProcessGroupGloo.Options()
-        options.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")]
+        options.devices = [c10d.ProcessGroupGloo.create_device(interface=LOOPBACK)]
         process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size, options)

         # Use all available devices on every process here (data is small, so should be fine).
@@ -2046,7 +2048,7 @@ class DistributedDataParallelTest(MultiProcessTestCase):
     def test_sync_params_with_buffers(self):
         store = c10d.FileStore(self.file.name, self.world_size)
         options = c10d.ProcessGroupGloo.Options()
-        options.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")]
+        options.devices = [c10d.ProcessGroupGloo.create_device(interface=LOOPBACK)]
         process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size, options)

         devices = gpus_for_rank(self.world_size)[self.rank]
@@ -3053,7 +3055,7 @@ class CommTest(MultiProcessTestCase):
     def test_broadcast_coalesced_gloo_cuda(self):
         store = c10d.FileStore(self.file.name, self.world_size)
         options = c10d.ProcessGroupGloo.Options()
-        options.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")]
+        options.devices = [c10d.ProcessGroupGloo.create_device(interface=LOOPBACK)]
         process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size, options)
         device = torch.device('cuda:%d' % self.rank)
         self._test_broadcast_coalesced(process_group, device)
@@ -3062,7 +3064,7 @@ class CommTest(MultiProcessTestCase):
     def test_broadcast_coalesced_gloo_cpu(self):
         store = c10d.FileStore(self.file.name, self.world_size)
         options = c10d.ProcessGroupGloo.Options()
-        options.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")]
+        options.devices = [c10d.ProcessGroupGloo.create_device(interface=LOOPBACK)]
         process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size, options)
         device = torch.device('cpu')
         self._test_broadcast_coalesced(process_group, device)
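The c10d test updates above swap `create_tcp_device` for the new transport-agnostic `create_device` helper and pick the loopback interface per platform. As a rough standalone illustration (not part of the diff; the temporary file, world size, and interface choice are arbitrary assumptions), constructing a Gloo process group directly looks like this:

# Sketch only: single-process ProcessGroupGloo built via the new create_device helper.
import sys
import tempfile
import torch.distributed as c10d

loopback = 'lo0' if sys.platform == 'darwin' else 'lo'
store = c10d.FileStore(tempfile.mktemp(), 1)  # file-backed rendezvous store for 1 worker
opts = c10d.ProcessGroupGloo.Options()
opts.devices = [c10d.ProcessGroupGloo.create_device(interface=loopback)]
opts.timeout = 5.0
pg = c10d.ProcessGroupGloo(store, 0, 1, opts)  # rank 0 in a world of size 1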
@@ -34,7 +34,7 @@ class ProcessGroupShareTensorTest(TestCase):
     @classmethod
     def opts(cls, threads=2):
         opts = c10d.ProcessGroupGloo.Options()
-        opts.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")]
+        opts.devices = [c10d.ProcessGroupGloo.create_device(interface="lo")]
         opts.timeout = 5.0
         opts.threads = threads
         return opts
third_party/gloo (vendored)
@@ -1 +1 @@
-Subproject commit a9fa7c8d6e95a6be22b734e22b595bc80f03aea0
+Subproject commit 2101e02ceabd9f1b0bb354f6ea705cefe83558b2
@@ -221,7 +221,7 @@ endif()

 if (USE_DISTRIBUTED)
   list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_DISTRIBUTED)
-  if (NOT MSVC AND NOT APPLE)
+  if (NOT MSVC)
     list(APPEND TORCH_PYTHON_SRCS
       ${TORCH_SRC_DIR}/csrc/distributed/autograd/init.cpp
       ${TORCH_SRC_DIR}/csrc/distributed/autograd/context/dist_autograd_container.cpp
@@ -17,7 +17,6 @@

 #include <c10d/PrefixStore.hpp>
 #include <c10d/TCPStore.hpp>
-#include <gloo/transport/tcp/device.h>
 #include <pybind11/chrono.h>

 #include <torch/csrc/Exceptions.h>
@@ -35,28 +34,6 @@ namespace {

 #ifdef USE_C10D_GLOO
 constexpr char* GLOO_SOCKET_IFNAME_ENV = "GLOO_SOCKET_IFNAME";
-
-std::shared_ptr<::gloo::transport::Device> createDeviceForDefaultHostname() {
-  ::gloo::transport::tcp::attr attr;
-
-  // Use the hostname to resolve the network address to
-  // use. Note: if the hostname does not resolve to an address (e.g.
-  // because of misconfigured /etc/hosts file), this will not work.
-  std::array<char, HOST_NAME_MAX> hostname{};
-  auto rv = gethostname(hostname.data(), hostname.size());
-  if (rv != 0) {
-    throw std::system_error(errno, std::system_category());
-  }
-  attr.hostname = hostname.data();
-  return ::gloo::transport::tcp::CreateDevice(attr);
-}
-
-std::shared_ptr<::gloo::transport::Device> createDeviceForInterface(
-    std::string iface) {
-  ::gloo::transport::tcp::attr attr;
-  attr.iface = std::move(iface);
-  return ::gloo::transport::tcp::CreateDevice(attr);
-}
 #endif

 std::vector<std::string> split(char separator, const std::string& string) {
@@ -446,20 +423,17 @@ They are used in specifying strategies for reduction collectives, e.g.,
       .def_readwrite("threads", &::c10d::ProcessGroupGloo::Options::threads);

   processGroupGloo.def_static(
-      "create_tcp_device",
+      "create_device",
       [](const std::string& hostname, const std::string& interface)
          -> std::shared_ptr<::gloo::transport::Device> {
-        ::gloo::transport::tcp::attr attr;
        if (!hostname.empty()) {
-          attr.hostname = hostname;
-        } else if (!interface.empty()) {
-          attr.iface = interface;
-        } else {
-          // Neither argument is specified; Gloo itself will use the
-          // hostname
+          // Nothing specified, default to something useful
+          return ::c10d::ProcessGroupGloo::createDeviceForHostname(hostname);
        }
-        return ::gloo::transport::tcp::CreateDevice(attr);
+        if (!interface.empty()) {
+          return ::c10d::ProcessGroupGloo::createDeviceForInterface(interface);
+        }
+        throw std::invalid_argument(
+            "Specify either `hostname` or `interface` argument.");
      },
      py::arg("hostname") = "",
      py::arg("interface") = "");
@@ -481,10 +455,15 @@ They are used in specifying strategies for reduction collectives, e.g.,
        char* ifnameEnv = getenv(GLOO_SOCKET_IFNAME_ENV);
        if (ifnameEnv) {
          for (const auto& iface : split(',', ifnameEnv)) {
-            options.devices.push_back(createDeviceForInterface(iface));
+            options.devices.push_back(
+                ::c10d::ProcessGroupGloo::createDeviceForInterface(iface));
          }
        } else {
-          options.devices.push_back(createDeviceForDefaultHostname());
+          // If no hostname is specified, this function looks up
+          // the machine's hostname and returns a device instance
+          // associated with the address that the hostname resolves to.
+          options.devices.push_back(
+              ::c10d::ProcessGroupGloo::createDeviceForHostname(""));
        }

        options.timeout = timeout;
@@ -1,5 +1,7 @@
 #include <c10d/ProcessGroupGloo.hpp>

+#include <unistd.h>
+
 #include <gloo/allgather.h>
 #include <gloo/allreduce.h>
 #include <gloo/barrier.h>
@@ -19,9 +21,31 @@
 #include <c10/cuda/CUDAStream.h>
 #endif

+#include <gloo/config.h>
 #include <gloo/rendezvous/context.h>
 #include <gloo/rendezvous/prefix_store.h>
-#include <gloo/transport/tcp/device.h>
+
+#if GLOO_HAVE_TRANSPORT_TCP
+#include <gloo/transport/tcp/device.h>
+#endif
+
+#if GLOO_HAVE_TRANSPORT_UV
+#include <gloo/transport/uv/device.h>
+#endif
+
+// On Linux, check that the tcp transport is available.
+#ifdef __linux__
+#if !GLOO_HAVE_TRANSPORT_TCP
+#error "Expected the tcp transport to be available on Linux."
+#endif
+#endif
+
+// On macOS, check that the uv transport is available.
+#ifdef __APPLE__
+#if !GLOO_HAVE_TRANSPORT_UV
+#error "Expected the uv transport to be available on macOS."
+#endif
+#endif

 #define GENERATE_ALL_TYPES(type, func, args...) \
   switch (type) {                               \
@@ -276,6 +300,71 @@ void ProcessGroupGloo::RecvWork::wait() {
 ProcessGroupGloo::Options::Options()
     : timeout(std::chrono::milliseconds(10 * 1000)), threads(2) {}

+#ifdef __linux__
+std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo::
+    createDeviceForInterface(const std::string& interface) {
+  ::gloo::transport::tcp::attr attr;
+  attr.iface = interface;
+  return ::gloo::transport::tcp::CreateDevice(attr);
+}
+#endif
+
+#ifdef __APPLE__
+std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo::
+    createDeviceForInterface(const std::string& interface) {
+  ::gloo::transport::uv::attr attr;
+  attr.iface = interface;
+  return ::gloo::transport::uv::CreateDevice(attr);
+}
+#endif
+
+#ifdef __linux__
+std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo::
+    createDeviceForHostname(const std::string& hostname) {
+  ::gloo::transport::tcp::attr attr;
+
+  if (hostname.empty()) {
+    // Use the hostname to resolve the network address to
+    // use. Note: if the hostname does not resolve to an address (e.g.
+    // because of misconfigured /etc/hosts file), this will not work.
+    std::array<char, HOST_NAME_MAX> buffer{};
+    auto rv = gethostname(buffer.data(), buffer.size());
+    if (rv != 0) {
+      throw std::system_error(errno, std::system_category());
+    }
+    attr.hostname = buffer.data();
+  } else {
+    attr.hostname = hostname;
+  }
+
+  return ::gloo::transport::tcp::CreateDevice(attr);
+}
+#endif
+
+#ifdef __APPLE__
+std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo::
+    createDeviceForHostname(const std::string& hostname) {
+  ::gloo::transport::uv::attr attr;
+
+  if (hostname.empty()) {
+    // Use the hostname to resolve the network address to
+    // use. Note: if the hostname does not resolve to an address (e.g.
+    // because of misconfigured /etc/hosts file), this will not work.
+    const auto hostNameMax = sysconf(_SC_HOST_NAME_MAX);
+    auto buffer = std::unique_ptr<char[]>(new char[hostNameMax]);
+    auto rv = gethostname(buffer.get(), hostNameMax);
+    if (rv != 0) {
+      throw std::system_error(errno, std::system_category());
+    }
+    attr.hostname = buffer.get();
+  } else {
+    attr.hostname = hostname;
+  }
+
+  return ::gloo::transport::uv::CreateDevice(attr);
+}
+#endif
+
 ProcessGroupGloo::ProcessGroupGloo(
     const std::shared_ptr<Store>& store,
     int rank,
@@ -644,7 +733,7 @@ class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork {
   // Construct from an existing metadata tensor to facilitate structured
   // access to metadata from peers, after gathering it.
   explicit SparseTensorMetadata(at::Tensor metadata)
-      : metadata_(metadata), data_(metadata_.data_ptr<long>()) {
+      : metadata_(metadata), data_(metadata_.data_ptr<int64_t>()) {
     AT_ASSERT(metadata.scalar_type() == at::kLong);
     AT_ASSERT(metadata.dim() == 1);
     AT_ASSERT(metadata.size(0) == dim);
@@ -694,7 +783,7 @@ class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork {

  protected:
   at::Tensor metadata_;
-  long* data_;
+  int64_t* data_;
 };

 // Sparse allreduce is implemented with allgather on indices and values.
@@ -719,7 +808,7 @@ class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork {
     // Sanity check dimensionality across ranks.
     {
       const auto expected = metadata[context->rank].sizes();
-      for (size_t i = 0; i < context->size; i++) {
+      for (auto i = 0; i < context->size; i++) {
         if (i == context->rank) {
           continue;
         }
@@ -733,11 +822,11 @@ class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork {
     auto values = allgather_values(input, metadata);

     // Perform global reduction.
-    AT_ASSERT(indices.size() == context->size);
-    AT_ASSERT(values.size() == context->size);
+    AT_ASSERT(static_cast<int>(indices.size()) == context->size);
+    AT_ASSERT(static_cast<int>(values.size()) == context->size);
     auto output = at::sparse_coo_tensor(
         indices[0], values[0], input.sizes(), input.options());
-    for (size_t i = 1; i < context->size; i++) {
+    for (auto i = 1; i < context->size; i++) {
       output += at::sparse_coo_tensor(
           indices[i], values[i], input.sizes(), input.options());
     }
@@ -778,7 +867,7 @@ class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork {

     // Allgather metadata
     gloo::AllgatherOptions opts(context);
-    opts.setOutput(buffer.data_ptr<long>(), buffer.numel());
+    opts.setOutput(buffer.data_ptr<int64_t>(), buffer.numel());
     opts.setTag(tag);
     gloo::allgather(opts);

@@ -802,7 +891,7 @@ class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork {

     // Allgather indices.
     gloo::AllgatherOptions opts(context);
-    opts.setOutput(buffer.data_ptr<long>(), buffer.numel());
+    opts.setOutput(buffer.data_ptr<int64_t>(), buffer.numel());
     opts.setTag(tag);
     gloo::allgather(opts);

@@ -127,6 +127,19 @@ class ProcessGroupGloo : public ProcessGroup {
     int threads;
   };

+  // Helper functions to create a new device object.
+  // They are static functions on this class to keep them logically
+  // separate from the rest of the code base (e.g. torch/csrc/distributed).
+
+  // Create new device instance for specific interface.
+  static std::shared_ptr<::gloo::transport::Device> createDeviceForInterface(
+      const std::string& interface);
+
+  // Create new device instance for hostname or address.
+  // If specified argument is empty, it defaults to this machine's hostname.
+  static std::shared_ptr<::gloo::transport::Device> createDeviceForHostname(
+      const std::string& hostname);
+
   explicit ProcessGroupGloo(
       const std::shared_ptr<Store>& store,
       int rank,
@@ -1,5 +1,3 @@
-#include <gloo/transport/tcp/device.h>
-
 #include <ATen/cuda/CUDAMultiStreamGuard.h>
 #include <c10/cuda/CUDAGuard.h>

@@ -51,8 +49,8 @@ class AsyncTest {
     // Use tiny timeout to make this test run fast
     ::c10d::ProcessGroupGloo::Options options;
     options.timeout = std::chrono::milliseconds(50);
-    ::gloo::transport::tcp::attr attr;
-    options.devices.push_back(::gloo::transport::tcp::CreateDevice(attr));
+    options.devices.push_back(
+        ::c10d::ProcessGroupGloo::createDeviceForHostname("127.0.0.1"));

     pg_ = std::unique_ptr<::c10d::ProcessGroupGloo>(
         new ::c10d::ProcessGroupGloo(store, rank, size, options));
@@ -9,8 +9,6 @@
 #include <sstream>
 #include <thread>

-#include <gloo/transport/tcp/device.h>
-
 #include <c10d/FileStore.hpp>
 #include <c10d/ProcessGroupGloo.hpp>
 #include <c10d/test/TestUtils.hpp>
@@ -42,8 +40,8 @@ class SignalTest {
     // Use tiny timeout to make this test run fast
     ::c10d::ProcessGroupGloo::Options options;
     options.timeout = std::chrono::milliseconds(50);
-    ::gloo::transport::tcp::attr attr;
-    options.devices.push_back(::gloo::transport::tcp::CreateDevice(attr));
+    options.devices.push_back(
+        ::c10d::ProcessGroupGloo::createDeviceForHostname("127.0.0.1"));

     ::c10d::ProcessGroupGloo pg(store, rank, size, options);

@@ -127,9 +125,8 @@ class CollectiveTest {
     // Use tiny timeout to make this test run fast
     ::c10d::ProcessGroupGloo::Options options;
     options.timeout = std::chrono::milliseconds(50);
-
-    ::gloo::transport::tcp::attr attr;
-    options.devices.push_back(::gloo::transport::tcp::CreateDevice(attr));
+    options.devices.push_back(
+        ::c10d::ProcessGroupGloo::createDeviceForHostname("127.0.0.1"));

     pg_ = std::unique_ptr<::c10d::ProcessGroupGloo>(
         new ::c10d::ProcessGroupGloo(store, rank, size, options));
@@ -1,5 +1,6 @@
 #pragma once

+#include <signal.h>
 #include <sys/types.h>
 #include <sys/wait.h>
 #include <unistd.h>