mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Remove THD (#22065)
Summary: It's been ~9 months since moving THD to the `torch.distributed.deprecated` namespace (see https://github.com/pytorch/pytorch/issues/11405) and we haven't seen issues related to it, so it's time to remove it. Closes https://github.com/pytorch/pytorch/issues/18967.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/22065
Reviewed By: mrshenli
Differential Revision: D15983669
Pulled By: pietern
fbshipit-source-id: 2a2f5866f9a63040bc7cef3956d5fd215aba7165
This commit is contained in:
parent
bcb5fd8f06
commit
6ff0c6ca3f
@@ -8,7 +8,6 @@ multiple variants of the library, summarized here:
* THC = TorcH Cuda
* THCS = TorcH Cuda Sparse (now defunct)
* THCUNN = TorcH CUda Neural Network (see cunn)
* THD = TorcH Distributed
* THNN = TorcH Neural Network
* THS = TorcH Sparse (now defunct)

@@ -769,7 +769,6 @@ ENDIF()
  DESTINATION share/cmake/Torch)

if (USE_DISTRIBUTED)
  add_subdirectory(${TORCH_SRC_DIR}/lib/THD lib_THD)
  if (NOT MSVC AND NOT APPLE)
    add_subdirectory(${TORCH_SRC_DIR}/lib/c10d lib_c10d)
  endif()

@ -1,280 +0,0 @@
|
|||
.. role:: hidden
    :class: hidden-section
|
||||
|
||||
Distributed communication package (deprecated) - torch.distributed.deprecated
==============================================================================

.. warning::
    torch.distributed.deprecated is the older version of torch.distributed and
    is currently deprecated. It will be removed soon. Please use and refer to
    the documentation for torch.distributed, which is the latest distributed
    communication package for PyTorch.
|
||||
|
||||
.. automodule:: torch.distributed.deprecated
|
||||
.. currentmodule:: torch.distributed.deprecated
|
||||
|
||||
Currently torch.distributed.deprecated supports four backends, each with
different capabilities. The table below shows which functions are available
for use with CPU / CUDA tensors.
MPI supports CUDA only if the implementation used to build PyTorch supports it.
|
||||
|
||||
|
||||
+------------+-----------+-----------+-----------+-----------+
| Backend    | ``tcp``   | ``gloo``  | ``mpi``   | ``nccl``  |
+------------+-----+-----+-----+-----+-----+-----+-----+-----+
| Device     | CPU | GPU | CPU | GPU | CPU | GPU | CPU | GPU |
+============+=====+=====+=====+=====+=====+=====+=====+=====+
| send       | ✓   | ✘   | ✘   | ✘   | ✓   | ?   | ✘   | ✘   |
+------------+-----+-----+-----+-----+-----+-----+-----+-----+
| recv       | ✓   | ✘   | ✘   | ✘   | ✓   | ?   | ✘   | ✘   |
+------------+-----+-----+-----+-----+-----+-----+-----+-----+
| broadcast  | ✓   | ✘   | ✓   | ✓   | ✓   | ?   | ✘   | ✓   |
+------------+-----+-----+-----+-----+-----+-----+-----+-----+
| all_reduce | ✓   | ✘   | ✓   | ✓   | ✓   | ?   | ✘   | ✓   |
+------------+-----+-----+-----+-----+-----+-----+-----+-----+
| reduce     | ✓   | ✘   | ✘   | ✘   | ✓   | ?   | ✘   | ✓   |
+------------+-----+-----+-----+-----+-----+-----+-----+-----+
| all_gather | ✓   | ✘   | ✘   | ✘   | ✓   | ?   | ✘   | ✓   |
+------------+-----+-----+-----+-----+-----+-----+-----+-----+
| gather     | ✓   | ✘   | ✘   | ✘   | ✓   | ?   | ✘   | ✘   |
+------------+-----+-----+-----+-----+-----+-----+-----+-----+
| scatter    | ✓   | ✘   | ✘   | ✘   | ✓   | ?   | ✘   | ✘   |
+------------+-----+-----+-----+-----+-----+-----+-----+-----+
| barrier    | ✓   | ✘   | ✓   | ✓   | ✓   | ?   | ✘   | ✘   |
+------------+-----+-----+-----+-----+-----+-----+-----+-----+
|
||||
|
||||
.. _distributed-deprecated-basics:
|
||||
|
||||
Basics
|
||||
------
|
||||
|
||||
The `torch.distributed.deprecated` package provides PyTorch support and communication primitives
|
||||
for multiprocess parallelism across several computation nodes running on one or more
|
||||
machines. The class :func:`torch.nn.parallel.deprecated.DistributedDataParallel` builds on this
|
||||
functionality to provide synchronous distributed training as a wrapper around any
|
||||
PyTorch model. This differs from the kinds of parallelism provided by
|
||||
:doc:`multiprocessing` and :func:`torch.nn.DataParallel` in that it supports
|
||||
multiple network-connected machines and in that the user must explicitly launch a separate
|
||||
copy of the main training script for each process.
|
||||
|
||||
In the single-machine synchronous case, `torch.distributed.deprecated` or the
|
||||
:func:`torch.nn.parallel.deprecated.DistributedDataParallel` wrapper may still have advantages over other
|
||||
approaches to data-parallelism, including :func:`torch.nn.DataParallel`:
|
||||
|
||||
* Each process maintains its own optimizer and performs a complete optimization step with each
|
||||
iteration. While this may appear redundant, since the gradients have already been gathered
|
||||
together and averaged across processes and are thus the same for every process, this means
|
||||
that no parameter broadcast step is needed, reducing time spent transferring tensors between
|
||||
nodes.
|
||||
* Each process contains an independent Python interpreter, eliminating the extra interpreter
|
||||
overhead and "GIL-thrashing" that comes from driving several execution threads, model
|
||||
replicas, or GPUs from a single Python process. This is especially important for models that
|
||||
make heavy use of the Python runtime, including models with recurrent layers or many small
|
||||
components.
|
||||
|
||||
Initialization
|
||||
--------------
|
||||
|
||||
The package needs to be initialized using the :func:`torch.distributed.deprecated.init_process_group`
|
||||
function before calling any other methods. This blocks until all processes have
|
||||
joined.
|
||||
|
||||
.. autofunction:: init_process_group
|
||||
|
||||
.. autofunction:: get_rank
|
||||
|
||||
.. autofunction:: get_world_size
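
As an illustrative sketch (the address, rank, and ``world_size`` below are
placeholders chosen for illustration), a process can query its rank and the
total number of processes once the package has been initialized:

::

    import torch.distributed.deprecated as dist

    # Initialize first, then query rank and world size
    dist.init_process_group('tcp', init_method='tcp://10.1.1.20:23456',
                            rank=0, world_size=4)
    print('rank', dist.get_rank(), 'of', dist.get_world_size())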
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
Currently three initialization methods are supported:
|
||||
|
||||
TCP initialization
|
||||
^^^^^^^^^^^^^^^^^^
|
||||
|
||||
There are two ways to initialize using TCP, both requiring a network address
|
||||
reachable from all processes and a desired ``world_size``. The first way
|
||||
requires specifying an address that belongs to the rank 0 process. This
|
||||
initialization method requires that all processes have manually specified ranks.
|
||||
|
||||
Alternatively, the address has to be a valid IP multicast address, in which case
|
||||
ranks can be assigned automatically. Multicast initialization also supports
|
||||
a ``group_name`` argument, which allows you to use the same address for multiple
|
||||
jobs, as long as they use different group names.
|
||||
|
||||
::
|
||||
|
||||
import torch.distributed.deprecated as dist
|
||||
|
||||
# Use address of one of the machines
|
||||
dist.init_process_group(backend, init_method='tcp://10.1.1.20:23456', rank=args.rank, world_size=4)
|
||||
|
||||
# or a multicast address - rank will be assigned automatically if unspecified
|
||||
dist.init_process_group(backend, init_method='tcp://[ff15:1e18:5d4c:4cf0:d02d:b659:53ba:b0a7]:23456',
|
||||
world_size=4)
|
||||
|
||||
Shared file-system initialization
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Another initialization method makes use of a file system that is shared and
|
||||
visible from all machines in a group, along with a desired ``world_size``. The URL should start
|
||||
with ``file://`` and contain a path to a non-existent file (in an existing
|
||||
directory) on a shared file system. This initialization method also supports a
|
||||
``group_name`` argument, which allows you to use the same shared file path for
|
||||
multiple jobs, as long as they use different group names.
|
||||
|
||||
.. warning::
|
||||
This method assumes that the file system supports locking using ``fcntl`` - most
|
||||
local systems and NFS support it.
|
||||
|
||||
::
|
||||
|
||||
import torch.distributed.deprecated as dist
|
||||
|
||||
# Rank will be assigned automatically if unspecified
|
||||
dist.init_process_group(backend, init_method='file:///mnt/nfs/sharedfile',
|
||||
world_size=4, group_name=args.group)
|
||||
|
||||
Environment variable initialization
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
This method will read the configuration from environment variables, allowing
|
||||
one to fully customize how the information is obtained. The variables to be set
|
||||
are:
|
||||
|
||||
* ``MASTER_PORT`` - required; has to be a free port on the machine with rank 0
* ``MASTER_ADDR`` - required (except for rank 0); address of the rank 0 node
* ``WORLD_SIZE`` - required; can be set either here, or in a call to the init function
* ``RANK`` - required; can be set either here, or in a call to the init function
|
||||
|
||||
The machine with rank 0 will be used to set up all connections.
|
||||
|
||||
This is the default method, meaning that ``init_method`` does not have to be specified (or
|
||||
can be ``env://``).
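
A minimal, hedged sketch of environment variable initialization follows (the
address, port, rank, and world size are hypothetical values; in practice the
variables are usually exported by the job launcher rather than set in code):

::

    import os
    import torch.distributed.deprecated as dist

    # Normally exported by the launcher; set here only for illustration
    os.environ['MASTER_ADDR'] = '10.1.1.20'
    os.environ['MASTER_PORT'] = '23456'

    # RANK and WORLD_SIZE may also come from the environment
    dist.init_process_group(backend='gloo', init_method='env://',
                            rank=int(os.environ.get('RANK', 0)),
                            world_size=4)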
|
||||
|
||||
Groups
|
||||
------
|
||||
|
||||
By default collectives operate on the default group (also called the world) and
require all processes to enter the distributed function call. However, some workloads can benefit
from more fine-grained communication. This is where distributed groups come
into play. The :func:`~torch.distributed.deprecated.new_group` function can be
used to create new groups, with arbitrary subsets of all processes. It returns
an opaque group handle that can be given as a ``group`` argument to all collectives
(collectives are distributed functions to exchange information in certain well-known programming patterns).
|
||||
|
||||
.. autofunction:: new_group
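
A brief usage sketch (assuming the process group has already been initialized
and the job has at least two ranks):

::

    import torch
    import torch.distributed.deprecated as dist

    # Create a subgroup containing only ranks 0 and 1
    group = dist.new_group([0, 1])

    # Only members of the group participate in its collectives
    if dist.get_rank() in (0, 1):
        t = torch.ones(1)
        dist.all_reduce(t, op=dist.reduce_op.SUM, group=group)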
|
||||
|
||||
Point-to-point communication
|
||||
----------------------------
|
||||
|
||||
.. autofunction:: send
|
||||
|
||||
.. autofunction:: recv
|
||||
|
||||
:func:`~torch.distributed.deprecated.isend` and :func:`~torch.distributed.deprecated.irecv`
return distributed request objects when used. In general, the type of these objects is
unspecified, as they should never be created manually, but they are guaranteed to support
two methods:

* ``is_completed()`` - returns True if the operation has finished
* ``wait()`` - blocks the process until the operation is finished.
  ``is_completed()`` is guaranteed to return True once ``wait()`` returns.
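
For illustration, a hedged sketch of non-blocking point-to-point communication
between rank 0 and rank 1 (the tensor shape and ranks are arbitrary):

::

    import torch
    import torch.distributed.deprecated as dist

    t = torch.zeros(4)
    if dist.get_rank() == 0:
        t += 1.0
        req = dist.isend(t, 1)   # start sending to rank 1
    else:
        req = dist.irecv(t, 0)   # start receiving from rank 0
    req.wait()                   # block until the transfer completes
    assert req.is_completed()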
|
||||
|
||||
When using the MPI backend, :func:`~torch.distributed.deprecated.isend` and :func:`~torch.distributed.deprecated.irecv`
support non-overtaking message delivery, which provides some guarantees on message ordering. For more detail, see
http://mpi-forum.org/docs/mpi-2.2/mpi22-report/node54.htm#Node54
|
||||
|
||||
.. autofunction:: isend
|
||||
|
||||
.. autofunction:: irecv
|
||||
|
||||
Collective functions
|
||||
--------------------
|
||||
|
||||
.. autofunction:: broadcast
|
||||
|
||||
.. autofunction:: all_reduce
|
||||
|
||||
.. autofunction:: reduce
|
||||
|
||||
.. autofunction:: all_gather
|
||||
|
||||
.. autofunction:: gather
|
||||
|
||||
.. autofunction:: scatter
|
||||
|
||||
.. autofunction:: barrier
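
As a small, hedged example that combines a few collectives (it assumes four
initialized processes and uses placeholder values):

::

    import torch
    import torch.distributed.deprecated as dist

    rank = dist.get_rank()

    # Rank 0 broadcasts its tensor to everyone
    t = torch.zeros(3)
    if rank == 0:
        t = torch.arange(3).float()
    dist.broadcast(t, src=0)

    # Every rank contributes to a sum that all ranks receive
    s = torch.ones(1) * rank
    dist.all_reduce(s, op=dist.reduce_op.SUM)

    dist.barrier()  # make sure every process has reached this point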
|
||||
|
||||
Multi-GPU collective functions
|
||||
------------------------------
|
||||
|
||||
If you have more than one GPU on each node, then, when using the NCCL backend,
:func:`~torch.distributed.deprecated.broadcast_multigpu`,
:func:`~torch.distributed.deprecated.all_reduce_multigpu`,
:func:`~torch.distributed.deprecated.reduce_multigpu`, and
:func:`~torch.distributed.deprecated.all_gather_multigpu` support distributed collective
operations among multiple GPUs within each node. These functions can potentially
improve the overall distributed training performance and are easy to use:
simply pass a list of tensors. Each tensor in the passed tensor list needs
to be on a separate GPU device of the host where the function is called. Note
that the length of the tensor list needs to be identical among all the
distributed processes. Also note that currently the multi-GPU collective
functions are only supported by the NCCL backend.

For example, suppose the system we use for distributed training has 2 nodes, each
of which has 8 GPUs. On each of the 16 GPUs, there is a tensor that we would
like to all-reduce. The following code can serve as a reference:
|
||||
|
||||
Code running on Node 0
|
||||
|
||||
::
|
||||
|
||||
import torch
|
||||
import torch.distributed.deprecated as dist
|
||||
|
||||
dist.init_process_group(backend="nccl",
|
||||
init_method="file:///distributed_test",
|
||||
world_size=2,
|
||||
rank=0)
|
||||
tensor_list = []
|
||||
for dev_idx in range(torch.cuda.device_count()):
|
||||
tensor_list.append(torch.FloatTensor([1]).cuda(dev_idx))
|
||||
|
||||
dist.all_reduce_multigpu(tensor_list)
|
||||
|
||||
Code running on Node 1
|
||||
|
||||
::
|
||||
|
||||
import torch
|
||||
import torch.distributed.deprecated as dist
|
||||
|
||||
dist.init_process_group(backend="nccl",
|
||||
init_method="file:///distributed_test",
|
||||
world_size=2,
|
||||
rank=1)
|
||||
tensor_list = []
|
||||
for dev_idx in range(torch.cuda.device_count()):
|
||||
tensor_list.append(torch.FloatTensor([1]).cuda(dev_idx))
|
||||
|
||||
dist.all_reduce_multigpu(tensor_list)
|
||||
|
||||
After the call, all 16 tensors on the two nodes will have the all-reduced value
of 16 (the sum of the initial value 1 across all 16 GPUs).
|
||||
|
||||
.. autofunction:: broadcast_multigpu
|
||||
|
||||
.. autofunction:: all_reduce_multigpu
|
||||
|
||||
.. autofunction:: reduce_multigpu
|
||||
|
||||
.. autofunction:: all_gather_multigpu
|
||||
|
||||
|
||||
Launch utility
|
||||
--------------
|
||||
|
||||
The `torch.distributed.deprecated` package also provides a launch utility in
|
||||
`torch.distributed.deprecated.launch`.
|
||||
|
||||
.. automodule:: torch.distributed.launch
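
As a hedged sketch, a single-node job with one process per GPU might be started
as follows (the script name and its argument are placeholders):

::

    python -m torch.distributed.launch --nproc_per_node=8 \
        your_training_script.py --arg1 value1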
|
||||
|
|
@ -54,7 +54,6 @@ PyTorch is an optimized tensor library for deep learning using GPUs and CPUs.
|
|||
torch.utils.tensorboard (experimental) <tensorboard>
|
||||
onnx
|
||||
torch.__config__ <__config__>
|
||||
torch.distributed.deprecated <distributed_deprecated>
|
||||
|
||||
.. toctree::
|
||||
:glob:
|
||||
|
|
|
|||
1
setup.py
|
|
@ -397,7 +397,6 @@ class build_ext(setuptools.command.build_ext.build_ext):
|
|||
else:
|
||||
report('-- Not using NCCL')
|
||||
if cmake_cache_vars['USE_DISTRIBUTED']:
|
||||
report('-- Building with THD distributed package ')
|
||||
if IS_LINUX:
|
||||
report('-- Building with c10d distributed package ')
|
||||
else:
|
||||
|
|
|
|||
16
test/run_test.py
Normal file → Executable file
|
|
@ -42,7 +42,6 @@ TESTS = [
|
|||
'optim',
|
||||
'quantized',
|
||||
'sparse',
|
||||
'thd_distributed',
|
||||
'torch',
|
||||
'type_info',
|
||||
'type_hints',
|
||||
|
|
@ -55,7 +54,6 @@ TESTS = [
|
|||
|
||||
WINDOWS_BLACKLIST = [
|
||||
'distributed',
|
||||
'thd_distributed',
|
||||
]
|
||||
|
||||
ROCM_BLACKLIST = [
|
||||
|
|
@ -64,7 +62,6 @@ ROCM_BLACKLIST = [
|
|||
'distributed',
|
||||
'multiprocessing',
|
||||
'nccl',
|
||||
'thd_distributed',
|
||||
]
|
||||
|
||||
DISTRIBUTED_TESTS_CONFIG = {
|
||||
|
|
@ -85,16 +82,6 @@ if dist.is_available():
|
|||
}
|
||||
|
||||
|
||||
THD_DISTRIBUTED_TESTS_CONFIG = {
|
||||
'tcp': {
|
||||
'WORLD_SIZE': '3'
|
||||
},
|
||||
'gloo': {
|
||||
'WORLD_SIZE': '2' if torch.cuda.device_count() == 2 else '3'
|
||||
},
|
||||
# THD NCCL and MPI tests are known to be flaky in CI
|
||||
}
|
||||
|
||||
# https://stackoverflow.com/questions/2549939/get-signal-names-from-numbers-in-python
|
||||
SIGNALS_TO_NAMES_DICT = {getattr(signal, n): n for n in dir(signal)
|
||||
if n.startswith('SIG') and '_' not in n}
|
||||
|
|
@ -194,8 +181,6 @@ def test_distributed(executable, test_module, test_directory, options):
|
|||
print_to_stderr(
|
||||
'MPI not available -- MPI backend tests will be skipped')
|
||||
config = DISTRIBUTED_TESTS_CONFIG
|
||||
if test_module == "test_thd_distributed":
|
||||
config = THD_DISTRIBUTED_TESTS_CONFIG
|
||||
for backend, env_vars in config.items():
|
||||
if backend == 'mpi' and not mpi_available:
|
||||
continue
|
||||
|
|
@ -243,7 +228,6 @@ def test_distributed(executable, test_module, test_directory, options):
|
|||
CUSTOM_HANDLERS = {
|
||||
'cpp_extensions': test_cpp_extensions,
|
||||
'distributed': test_distributed,
|
||||
'thd_distributed': test_distributed,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
File diff suppressed because it is too large
|
|
@ -225,7 +225,6 @@ def add_torch_libs():
|
|||
"torch/csrc/autograd/python_variable.cpp",
|
||||
"torch/csrc/autograd/python_variable_indexing.cpp",
|
||||
"torch/csrc/byte_order.cpp",
|
||||
"torch/csrc/distributed/Module.cpp",
|
||||
"torch/csrc/distributed/c10d/comm.cpp",
|
||||
"torch/csrc/distributed/c10d/init.cpp",
|
||||
"torch/csrc/distributed/c10d/reducer.cpp",
|
||||
|
|
@ -427,7 +426,6 @@ def add_torch_libs():
|
|||
":torch-cpp-cpu",
|
||||
":thnn",
|
||||
"//caffe2/torch/fb/init:init",
|
||||
"//caffe2/torch/lib/THD:THD_cpu",
|
||||
"//caffe2/torch/lib/c10d:c10d_cpu",
|
||||
"//caffe2/torch/lib/libshm:libshm",
|
||||
],
|
||||
|
|
@ -448,7 +446,6 @@ def add_torch_libs():
|
|||
":torch-cpp-cuda",
|
||||
":thnn",
|
||||
"//caffe2/torch/fb/init:init",
|
||||
"//caffe2/torch/lib/THD:THD",
|
||||
"//caffe2/torch/lib/c10d:c10d",
|
||||
"//caffe2/torch/lib/libshm:libshm",
|
||||
],
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ then
|
|||
python tools/clang_tidy.py \
|
||||
--paths torch/csrc \
|
||||
--diff HEAD \
|
||||
-g"-torch/csrc/distributed/Module.cpp" \
|
||||
-g"-torch/csrc/jit/export.cpp" \
|
||||
-g"-torch/csrc/jit/import.cpp" \
|
||||
-j
|
||||
|
|
|
|||
|
|
@ -42,7 +42,6 @@ time python tools/clang_tidy.py \
|
|||
--verbose \
|
||||
--paths torch/csrc/ \
|
||||
--diff "$BASE_BRANCH" \
|
||||
-g"-torch/csrc/distributed/Module.cpp" \
|
||||
-g"-torch/csrc/jit/export.cpp" \
|
||||
-g"-torch/csrc/jit/import.cpp" \
|
||||
-g"-torch/csrc/jit/netdef_converter.cpp" \
|
||||
|
|
|
|||
|
|
@ -309,7 +309,6 @@ class CMake:
|
|||
CMAKE_CXX_FLAGS=cflags,
|
||||
CMAKE_EXE_LINKER_FLAGS=ldflags,
|
||||
CMAKE_SHARED_LINKER_FLAGS=ldflags,
|
||||
THD_SO_VERSION="1",
|
||||
CUDA_NVCC_EXECUTABLE=escape_path(os.getenv('CUDA_NVCC_EXECUTABLE')),
|
||||
**build_options)
|
||||
|
||||
|
|
|
|||
|
|
@ -221,9 +221,6 @@ if (USE_ROCM)
|
|||
endif()
|
||||
|
||||
if (USE_DISTRIBUTED)
|
||||
list(APPEND TORCH_PYTHON_SRCS ${TORCH_SRC_DIR}/csrc/distributed/Module.cpp)
|
||||
list(APPEND TORCH_PYTHON_INCLUDE_DIRECTORIES ${TORCH_SRC_DIR}/lib/THD)
|
||||
list(APPEND TORCH_PYTHON_LINK_LIBRARIES THD)
|
||||
list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_DISTRIBUTED)
|
||||
if (NOT MSVC AND NOT APPLE)
|
||||
list(APPEND TORCH_PYTHON_SRCS
|
||||
|
|
|
|||
|
|
@ -536,22 +536,8 @@ void init__THCUNN(PyObject*);
|
|||
|
||||
}} // namespace torch::nn
|
||||
|
||||
bool THDPDoubleStorage_init(PyObject *module);
|
||||
bool THDPFloatStorage_init(PyObject *module);
|
||||
//bool THDPHalfStorage_init(PyObject *module);
|
||||
bool THDPLongStorage_init(PyObject *module);
|
||||
bool THDPIntStorage_init(PyObject *module);
|
||||
bool THDPShortStorage_init(PyObject *module);
|
||||
bool THDPCharStorage_init(PyObject *module);
|
||||
bool THDPByteStorage_init(PyObject *module);
|
||||
bool THDPBoolStorage_init(PyObject *module);
|
||||
|
||||
static std::vector<PyMethodDef> methods;
|
||||
|
||||
#ifdef USE_DISTRIBUTED
|
||||
PyMethodDef* THDPModule_methods();
|
||||
#endif
|
||||
|
||||
// TODO: Refactor this in some less manual way
|
||||
#ifdef USE_CUDNN
|
||||
static PyObject * THCUDNN_cudnn_version(PyObject *self, PyObject *args)
|
||||
|
|
@ -624,7 +610,6 @@ PyObject* initModule() {
|
|||
THPUtils_addPyMethodDefs(methods, THCUDNN_methods());
|
||||
#endif
|
||||
#ifdef USE_DISTRIBUTED
|
||||
THPUtils_addPyMethodDefs(methods, THDPModule_methods());
|
||||
#ifdef USE_C10D
|
||||
THPUtils_addPyMethodDefs(methods, torch::distributed::c10d::python_functions());
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -1,916 +0,0 @@
|
|||
#include <torch/csrc/python_headers.h>
|
||||
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include <torch/csrc/utils/python_strings.h>
|
||||
#include <torch/csrc/distributed/THDP.h>
|
||||
#include <torch/csrc/PythonTypes.h>
|
||||
#include <torch/csrc/autograd/python_variable.h>
|
||||
|
||||
#ifdef USE_CUDA
|
||||
#include <torch/csrc/cuda/Stream.h>
|
||||
#endif
|
||||
|
||||
|
||||
static std::unordered_map<std::string, THDChannelType> name2channel_type = {
|
||||
{"mpi", THDChannelMPI},
|
||||
{"tcp", THDChannelTCP},
|
||||
{"gloo", THDChannelGloo},
|
||||
{"nccl", THDChannelNccl},
|
||||
};
|
||||
|
||||
static std::unordered_map<PyObject*, THDReduceOp> obj2reduceop;
|
||||
static std::unordered_map<PyObject*, THDGroup> obj2group;
|
||||
|
||||
#ifdef USE_CUDA
|
||||
extern THCState* state;
|
||||
#endif
|
||||
|
||||
PyObject* THDPModule_initProcessGroup(PyObject *_unused, PyObject *args)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
if (PyTuple_GET_SIZE(args) != 5 || !THPUtils_checkString(PyTuple_GET_ITEM(args, 0)) ||
|
||||
!THPUtils_checkString(PyTuple_GET_ITEM(args, 1)) ||
|
||||
!THPUtils_checkLong(PyTuple_GET_ITEM(args, 2)) ||
|
||||
!THPUtils_checkString(PyTuple_GET_ITEM(args, 3)) ||
|
||||
!THPUtils_checkLong(PyTuple_GET_ITEM(args, 4))) {
|
||||
THPUtils_invalidArguments(args, nullptr, "init_process_group", 1, "(string backend, string init_method, int world_size, string group_name, int rank)");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::string backend_name = THPUtils_unpackString(PyTuple_GET_ITEM(args, 0));
|
||||
std::string init_method = THPUtils_unpackString(PyTuple_GET_ITEM(args, 1));
|
||||
int world_size = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 2));
|
||||
std::string group_name = THPUtils_unpackString(PyTuple_GET_ITEM(args, 3));
|
||||
int rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 4));
|
||||
|
||||
THDChannelType channel_type = name2channel_type.at(backend_name);
|
||||
{
|
||||
AutoNoGIL nogil;
|
||||
THDProcessGroupInit(channel_type, init_method, world_size, group_name, rank);
|
||||
}
|
||||
#ifdef USE_CUDA
|
||||
THDSetCudaStatePtr(&state);
|
||||
#endif
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
PyObject* THDPModule_destroyProcessGroup(PyObject *_unused) {
|
||||
HANDLE_TH_ERRORS
|
||||
{
|
||||
AutoNoGIL nogil;
|
||||
THDProcessGroupDestroy();
|
||||
}
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
#ifdef USE_CUDA
|
||||
PyObject* THDPModule_registerStream(PyObject *_unused, PyObject *_stream)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
THPUtils_assert(THCPStream_Check(_stream), "_register_stream expects a "
|
||||
"torch.cuda.Stream object");
|
||||
THCPStream *stream = (THCPStream*)_stream;
|
||||
THDRegisterCudaStream(stream->cuda_stream);
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
#endif
|
||||
|
||||
PyObject* THDPModule_getRank(PyObject *_unused)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
return PyInt_FromLong(THDGetRank());
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
PyObject* THDPModule_getNumProcesses(PyObject *_unused)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
return PyInt_FromLong(THDGetNumProcesses());
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
#ifdef USE_CUDA
|
||||
extern PyObject* THCPDoubleTensorClass;
|
||||
extern PyObject* THCPFloatTensorClass;
|
||||
extern PyObject* THCPHalfTensorClass;
|
||||
extern PyObject* THCPLongTensorClass;
|
||||
extern PyObject* THCPIntTensorClass;
|
||||
extern PyObject* THCPShortTensorClass;
|
||||
extern PyObject* THCPCharTensorClass;
|
||||
extern PyObject* THCPByteTensorClass;
|
||||
#endif
|
||||
|
||||
THDTensorDescriptor THDPModule_makeDescriptor(PyObject *obj) {
|
||||
auto var = (THPVariable*)obj;
|
||||
return var->cdata.tensor_data();
|
||||
}
|
||||
|
||||
static THDRequest* _unpackRequest(PyObject *obj)
|
||||
{
|
||||
return static_cast<THDRequest*>(THPWrapper_get(obj));
|
||||
}
|
||||
|
||||
static THDReduceOp _getReduceOp(PyObject *obj)
|
||||
{
|
||||
auto it = obj2reduceop.find(obj);
|
||||
if (it == obj2reduceop.end()) {
|
||||
throw std::runtime_error("op should be a constant from "
|
||||
"torch.distributed.deprecated.reduce_op");
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
static THDGroup _getGroup(PyObject *obj)
|
||||
{
|
||||
auto it = obj2group.find(obj);
|
||||
if (it == obj2group.end()) {
|
||||
if (!THPUtils_checkLong(obj))
|
||||
throw std::runtime_error("group should be an int or one of the values "
|
||||
"from torch.distributed.deprecated.group");
|
||||
return THPUtils_unpackLong(obj);
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
PyObject* THDPModule_clearGroupCache(PyObject *_unused, PyObject *args) {
|
||||
HANDLE_TH_ERRORS
|
||||
if (PyTuple_GET_SIZE(args) != 1) {
|
||||
THPUtils_invalidArguments(args, nullptr, "clear_group_cache", 1, "(group gr)");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
THDGroup group = _getGroup(PyTuple_GET_ITEM(args, 0));
|
||||
|
||||
{
|
||||
AutoNoGIL nogil;
|
||||
THDClearGroupCache(group);
|
||||
}
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
PyObject* THDPModule_isend(PyObject *_unused, PyObject *args)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
if (PyTuple_GET_SIZE(args) != 2 || !THPVariable_Check(PyTuple_GET_ITEM(args, 0)) ||
|
||||
!THPUtils_checkLong(PyTuple_GET_ITEM(args, 1))) {
|
||||
THPUtils_invalidArguments(args, nullptr, "isend", 1, "(tensor input, int dst_rank)");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto desc = THDPModule_makeDescriptor(PyTuple_GET_ITEM(args, 0));
|
||||
int dst_rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1));
|
||||
THDRequest* req;
|
||||
{
|
||||
AutoNoGIL guard;
|
||||
req = THDIsend(desc, dst_rank);
|
||||
}
|
||||
return THPWrapper_New(req, (void(*)(void*))THDRequest_free);
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
PyObject* THDPModule_irecv(PyObject *_unused, PyObject *args)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
if (PyTuple_GET_SIZE(args) != 2 || !THPVariable_Check(PyTuple_GET_ITEM(args, 0)) ||
|
||||
!THPUtils_checkLong(PyTuple_GET_ITEM(args, 1))) {
|
||||
THPUtils_invalidArguments(args, nullptr, "irecv", 1, "(tensor output, int src_rank)");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto desc = THDPModule_makeDescriptor(PyTuple_GET_ITEM(args, 0));
|
||||
int src_rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1));
|
||||
THDRequest* req;
|
||||
{
|
||||
AutoNoGIL guard;
|
||||
req = THDIrecv(desc, src_rank);
|
||||
}
|
||||
return THPWrapper_New(req, (void(*)(void*))THDRequest_free);
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
PyObject* THDPModule_send(PyObject *_unused, PyObject *args)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
if (PyTuple_GET_SIZE(args) != 2 || !THPVariable_Check(PyTuple_GET_ITEM(args, 0)) ||
|
||||
!THPUtils_checkLong(PyTuple_GET_ITEM(args, 1))) {
|
||||
THPUtils_invalidArguments(args, nullptr, "send", 1, "(tensor input, int dst_rank)");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto desc = THDPModule_makeDescriptor(PyTuple_GET_ITEM(args, 0));
|
||||
int dst_rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1));
|
||||
{
|
||||
AutoNoGIL guard;
|
||||
THDSend(desc, dst_rank);
|
||||
}
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
PyObject* THDPModule_recvAnySource(PyObject *_unused, PyObject *_tensor)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
if (!THPVariable_Check(_tensor)) {
|
||||
THPUtils_invalidArguments(_tensor, nullptr, "recv", 1, "(tensor output)");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto desc = THDPModule_makeDescriptor(_tensor);
|
||||
int sender;
|
||||
{
|
||||
AutoNoGIL guard;
|
||||
sender = THDRecvAnySource(desc);
|
||||
}
|
||||
return PyInt_FromLong(sender);
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
PyObject* THDPModule_recv(PyObject *_unused, PyObject *args)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
if (PyTuple_GET_SIZE(args) != 2 || !THPVariable_Check(PyTuple_GET_ITEM(args, 0)) ||
|
||||
!THPUtils_checkLong(PyTuple_GET_ITEM(args, 1))) {
|
||||
THPUtils_invalidArguments(args, nullptr, "recv", 1, "(tensor output, int src_rank)");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto desc = THDPModule_makeDescriptor(PyTuple_GET_ITEM(args, 0));
|
||||
int src_rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1));
|
||||
{
|
||||
AutoNoGIL guard;
|
||||
THDRecv(desc, src_rank);
|
||||
}
|
||||
// Return sender rank
|
||||
Py_INCREF(PyTuple_GET_ITEM(args, 1));
|
||||
return PyTuple_GET_ITEM(args, 1);
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
|
||||
PyObject* THDPModule_allReduceMultiGPU(PyObject *_unused, PyObject *args)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
std::vector<at::Tensor> descriptors;
|
||||
size_t length;
|
||||
THDGroup group;
|
||||
THDReduceOp op;
|
||||
THPObjectPtr sequence;
|
||||
|
||||
if (PyTuple_GET_SIZE(args) != 3) {
|
||||
goto invalid_arguments;
|
||||
}
|
||||
|
||||
if (!PySequence_Check(PyTuple_GET_ITEM(args, 0))) {
|
||||
goto invalid_arguments;
|
||||
}
|
||||
|
||||
sequence = THPObjectPtr(PySequence_Fast(PyTuple_GET_ITEM(args, 0),
|
||||
"expected a sequence"));
|
||||
if (!sequence.get()) {
|
||||
goto invalid_arguments;
|
||||
}
|
||||
|
||||
length = static_cast<size_t>(PySequence_Fast_GET_SIZE(sequence.get()));
|
||||
|
||||
descriptors.reserve(length);
|
||||
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
if (!THPVariable_Check(PySequence_Fast_GET_ITEM(sequence.get(), i))) {
|
||||
goto invalid_arguments;
|
||||
}
|
||||
|
||||
descriptors.push_back(
|
||||
THDPModule_makeDescriptor(PySequence_Fast_GET_ITEM(sequence.get(), i))
|
||||
);
|
||||
}
|
||||
|
||||
group = _getGroup(PyTuple_GET_ITEM(args, 2));
|
||||
op = _getReduceOp(PyTuple_GET_ITEM(args, 1));
|
||||
|
||||
{
|
||||
AutoNoGIL guard;
|
||||
THDAllReduceMultiGPU(descriptors.data(), length, op, group);
|
||||
}
|
||||
Py_RETURN_NONE;
|
||||
|
||||
invalid_arguments:
|
||||
THPUtils_invalidArguments(args, nullptr, "all_reduce_multigpu", 1,
|
||||
"(list[tensor] in_out, reduce_op op, group gr)");
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
|
||||
PyObject* THDPModule_reduceMultiGPU(PyObject *_unused, PyObject *args)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
THPObjectPtr sequence;
|
||||
size_t length;
|
||||
std::vector<at::Tensor> descriptors;
|
||||
THDGroup group;
|
||||
THDReduceOp op;
|
||||
int dst_rank;
|
||||
|
||||
if (PyTuple_GET_SIZE(args) != 4) {
|
||||
goto invalid_arguments;
|
||||
}
|
||||
|
||||
if (!PySequence_Check(PyTuple_GET_ITEM(args, 0)) ||
|
||||
!THPUtils_checkLong(PyTuple_GET_ITEM(args, 1))) {
|
||||
goto invalid_arguments;
|
||||
}
|
||||
|
||||
sequence = THPObjectPtr(PySequence_Fast(PyTuple_GET_ITEM(args, 0),
|
||||
"expected a sequence"));
|
||||
if (!sequence.get()) {
|
||||
goto invalid_arguments;
|
||||
}
|
||||
|
||||
length = static_cast<size_t>(PySequence_Fast_GET_SIZE(sequence.get()));
|
||||
|
||||
descriptors.reserve(length);
|
||||
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
if (!THPVariable_Check(PySequence_Fast_GET_ITEM(sequence.get(), i))) {
|
||||
goto invalid_arguments;
|
||||
}
|
||||
|
||||
descriptors.push_back(
|
||||
THDPModule_makeDescriptor(PySequence_Fast_GET_ITEM(sequence.get(), i))
|
||||
);
|
||||
}
|
||||
|
||||
group = _getGroup(PyTuple_GET_ITEM(args, 3));
|
||||
op = _getReduceOp(PyTuple_GET_ITEM(args, 2));
|
||||
dst_rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1));
|
||||
|
||||
{
|
||||
AutoNoGIL guard;
|
||||
THDReduceMultiGPU(descriptors.data(), length, op, dst_rank, group);
|
||||
}
|
||||
Py_RETURN_NONE;
|
||||
|
||||
invalid_arguments:
|
||||
THPUtils_invalidArguments(args, nullptr, "reduce_multigpu", 1,
|
||||
"(list[tensor] in_out, int dst_rank, "
|
||||
"reduce_op op, group gr)");
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
|
||||
PyObject* THDPModule_broadcastMultiGPU(PyObject *_unused, PyObject *args)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
THPObjectPtr sequence;
|
||||
size_t length;
|
||||
std::vector<at::Tensor> descriptors;
|
||||
THDGroup group;
|
||||
int src_rank;
|
||||
|
||||
if (PyTuple_GET_SIZE(args) != 3) {
|
||||
goto invalid_arguments;
|
||||
}
|
||||
|
||||
if (!PySequence_Check(PyTuple_GET_ITEM(args, 0)) ||
|
||||
!THPUtils_checkLong(PyTuple_GET_ITEM(args, 1))) {
|
||||
goto invalid_arguments;
|
||||
}
|
||||
|
||||
sequence = THPObjectPtr(PySequence_Fast(PyTuple_GET_ITEM(args, 0),
|
||||
"expected a sequence"));
|
||||
if (!sequence.get()) {
|
||||
goto invalid_arguments;
|
||||
}
|
||||
|
||||
length = static_cast<size_t>(PySequence_Fast_GET_SIZE(sequence.get()));
|
||||
|
||||
descriptors.reserve(length);
|
||||
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
if (!THPVariable_Check(PySequence_Fast_GET_ITEM(sequence.get(), i))) {
|
||||
goto invalid_arguments;
|
||||
}
|
||||
|
||||
descriptors.push_back(
|
||||
THDPModule_makeDescriptor(PySequence_Fast_GET_ITEM(sequence.get(), i))
|
||||
);
|
||||
}
|
||||
|
||||
group = _getGroup(PyTuple_GET_ITEM(args, 2));
|
||||
src_rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1));
|
||||
|
||||
{
|
||||
AutoNoGIL guard;
|
||||
THDBroadcastMultiGPU(descriptors.data(), length, src_rank, group);
|
||||
}
|
||||
Py_RETURN_NONE;
|
||||
|
||||
invalid_arguments:
|
||||
THPUtils_invalidArguments(args, nullptr, "broadcast_multigpu", 1,
|
||||
"(list[tensor] in_out, int src_rank, group gr)");
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
|
||||
PyObject* THDPModule_allGatherMultiGPU(PyObject *_unused, PyObject *args)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
THPObjectPtr sequence_one;
|
||||
THPObjectPtr sequence_two;
|
||||
|
||||
size_t length_one;
|
||||
size_t length_two;
|
||||
|
||||
std::vector<at::Tensor> output_descriptors;
|
||||
std::vector<at::Tensor> input_descriptors;
|
||||
|
||||
THDGroup group;
|
||||
|
||||
if (PyTuple_GET_SIZE(args) != 3) {
|
||||
goto invalid_arguments;
|
||||
}
|
||||
|
||||
if (!PySequence_Check(PyTuple_GET_ITEM(args, 0)) ||
|
||||
!PySequence_Check(PyTuple_GET_ITEM(args, 1))) {
|
||||
goto invalid_arguments;
|
||||
}
|
||||
|
||||
sequence_one = THPObjectPtr(PySequence_Fast(PyTuple_GET_ITEM(args, 0),
|
||||
"expected a sequence"));
|
||||
sequence_two = THPObjectPtr(PySequence_Fast(PyTuple_GET_ITEM(args, 1),
|
||||
"expected a sequence"));
|
||||
|
||||
if (!sequence_one.get() || !sequence_two.get()) {
|
||||
goto invalid_arguments;
|
||||
}
|
||||
|
||||
length_one = static_cast<size_t>(
|
||||
PySequence_Fast_GET_SIZE(sequence_one.get()));
|
||||
|
||||
length_two = static_cast<size_t>(
|
||||
PySequence_Fast_GET_SIZE(sequence_two.get()));
|
||||
|
||||
if (length_one != length_two) {
|
||||
goto invalid_arguments;
|
||||
}
|
||||
|
||||
output_descriptors.reserve(length_one);
|
||||
input_descriptors.reserve(length_two);
|
||||
|
||||
// Get the input list
|
||||
for (size_t i = 0; i < length_two; ++i) {
|
||||
if (!THPVariable_Check(PySequence_Fast_GET_ITEM(sequence_two.get(), i)) ||
|
||||
!THPVariable_Check(PySequence_Fast_GET_ITEM(sequence_one.get(), i))) {
|
||||
goto invalid_arguments;
|
||||
}
|
||||
|
||||
input_descriptors.push_back(
|
||||
THDPModule_makeDescriptor(PySequence_Fast_GET_ITEM(sequence_two.get(), i))
|
||||
);
|
||||
|
||||
output_descriptors.push_back(
|
||||
THDPModule_makeDescriptor(PySequence_Fast_GET_ITEM(sequence_one.get(), i))
|
||||
);
|
||||
}
|
||||
|
||||
group = _getGroup(PyTuple_GET_ITEM(args, 2));
|
||||
|
||||
{
|
||||
AutoNoGIL guard;
|
||||
THDAllGatherMultiGPU(output_descriptors.data(),
|
||||
length_one,
|
||||
input_descriptors.data(),
|
||||
length_two,
|
||||
group);
|
||||
}
|
||||
|
||||
Py_RETURN_NONE;
|
||||
|
||||
invalid_arguments:
|
||||
THPUtils_invalidArguments(args, nullptr, "all_gather_multigpu", 1,
|
||||
"(list[list[tensor]] output, list[tensor] input, group gr)");
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
|
||||
PyObject* THDPModule_allReduce(PyObject *_unused, PyObject *args)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
if (PyTuple_GET_SIZE(args) != 3 || !THPVariable_Check(PyTuple_GET_ITEM(args, 0))) {
|
||||
THPUtils_invalidArguments(args, nullptr, "all_reduce", 1, "(tensor in_out, reduce_op op, group gr)");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
THDGroup group = _getGroup(PyTuple_GET_ITEM(args, 2));
|
||||
THDReduceOp op = _getReduceOp(PyTuple_GET_ITEM(args, 1));
|
||||
auto desc = THDPModule_makeDescriptor(PyTuple_GET_ITEM(args, 0));
|
||||
{
|
||||
AutoNoGIL guard;
|
||||
THDAllReduce(desc, op, group);
|
||||
}
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
PyObject* THDPModule_reduce(PyObject *_unused, PyObject *args)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
if (PyTuple_GET_SIZE(args) != 4 || !THPVariable_Check(PyTuple_GET_ITEM(args, 0)) ||
|
||||
!THPUtils_checkLong(PyTuple_GET_ITEM(args, 1))) {
|
||||
THPUtils_invalidArguments(args, nullptr, "reduce", 1,
|
||||
"(tensor reduced, int dst_rank, reduce_op op, group gr)");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
THDGroup group = _getGroup(PyTuple_GET_ITEM(args, 3));
|
||||
THDReduceOp op = _getReduceOp(PyTuple_GET_ITEM(args, 2));
|
||||
auto desc = THDPModule_makeDescriptor(PyTuple_GET_ITEM(args, 0));
|
||||
int dst_rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1));
|
||||
{
|
||||
AutoNoGIL guard;
|
||||
THDReduce(desc, op, dst_rank, group);
|
||||
}
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
PyObject* THDPModule_broadcast(PyObject *_unused, PyObject *args)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
if (PyTuple_GET_SIZE(args) != 3 || !THPVariable_Check(PyTuple_GET_ITEM(args, 0)) ||
|
||||
!THPUtils_checkLong(PyTuple_GET_ITEM(args, 1))) {
|
||||
THPUtils_invalidArguments(args, nullptr, "broadcast", 1,
|
||||
"(tensor src_dst, int src_rank, group gr)");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
THDGroup group = _getGroup(PyTuple_GET_ITEM(args, 2));
|
||||
auto desc = THDPModule_makeDescriptor(PyTuple_GET_ITEM(args, 0));
|
||||
int src_rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1));
|
||||
{
|
||||
AutoNoGIL guard;
|
||||
THDBroadcast(desc, src_rank, group);
|
||||
}
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
PyObject* THDPModule_allGather(PyObject *_unused, PyObject *args)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
THPObjectPtr sequence;
|
||||
size_t length;
|
||||
std::vector<at::Tensor> descriptors;
|
||||
THDGroup group;
|
||||
at::Tensor desc;
|
||||
|
||||
if (PyTuple_GET_SIZE(args) != 3 ||
|
||||
!PySequence_Check(PyTuple_GET_ITEM(args, 0)) ||
|
||||
!THPVariable_Check(PyTuple_GET_ITEM(args, 1))) {
|
||||
|
||||
goto invalid_arguments;
|
||||
}
|
||||
|
||||
sequence = THPObjectPtr(PySequence_Fast(PyTuple_GET_ITEM(args, 0),
|
||||
"expected a sequence"));
|
||||
if (!sequence.get()) {
|
||||
goto invalid_arguments;
|
||||
}
|
||||
|
||||
length = static_cast<size_t>(PySequence_Fast_GET_SIZE(sequence.get()));
|
||||
|
||||
descriptors.reserve(length);
|
||||
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
if (!THPVariable_Check(PySequence_Fast_GET_ITEM(sequence.get(), i)))
|
||||
goto invalid_arguments;
|
||||
|
||||
descriptors.push_back(
|
||||
THDPModule_makeDescriptor(PySequence_Fast_GET_ITEM(sequence.get(), i))
|
||||
);
|
||||
}
|
||||
|
||||
group = _getGroup(PyTuple_GET_ITEM(args, 2));
|
||||
desc = THDPModule_makeDescriptor(PyTuple_GET_ITEM(args, 1));
|
||||
{
|
||||
AutoNoGIL guard;
|
||||
THDAllGather(descriptors.data(), length, desc, group);
|
||||
}
|
||||
Py_RETURN_NONE;
|
||||
|
||||
invalid_arguments:
|
||||
THPUtils_invalidArguments(args, nullptr, "allGather", 1,
|
||||
"(list[tensor] output, tensor input, group gr)");
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
PyObject* THDPModule_gatherSend(PyObject *_unused, PyObject *args)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
if (PyTuple_GET_SIZE(args) != 3 || !THPVariable_Check(PyTuple_GET_ITEM(args, 0))) {
|
||||
THPUtils_invalidArguments(args, nullptr, "gatherSend", 1,
|
||||
"(tensor input, int dst_rank, group gr)");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
THDGroup group = _getGroup(PyTuple_GET_ITEM(args, 2));
|
||||
auto desc = THDPModule_makeDescriptor(PyTuple_GET_ITEM(args, 0));
|
||||
int dst_rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1));
|
||||
{
|
||||
AutoNoGIL guard;
|
||||
THDGatherSend(desc, dst_rank, group);
|
||||
}
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
PyObject* THDPModule_gatherRecv(PyObject *_unused, PyObject *args)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
THPObjectPtr sequence;
|
||||
size_t length;
|
||||
std::vector<at::Tensor> descriptors;
|
||||
THDGroup group;
|
||||
at::Tensor desc;
|
||||
|
||||
if (PyTuple_GET_SIZE(args) != 3 ||
|
||||
!PySequence_Check(PyTuple_GET_ITEM(args, 0)) ||
|
||||
!THPVariable_Check(PyTuple_GET_ITEM(args, 1))) {
|
||||
goto invalid_arguments;
|
||||
}
|
||||
|
||||
sequence = THPObjectPtr(PySequence_Fast(PyTuple_GET_ITEM(args, 0),
|
||||
"expected a sequence"));
|
||||
if (!sequence.get()) {
|
||||
goto invalid_arguments;
|
||||
}
|
||||
|
||||
length = static_cast<size_t>(PySequence_Fast_GET_SIZE(sequence.get()));
|
||||
|
||||
descriptors.reserve(length);
|
||||
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
if (!THPVariable_Check(PySequence_Fast_GET_ITEM(sequence.get(), i)))
|
||||
goto invalid_arguments;
|
||||
|
||||
descriptors.push_back(
|
||||
THDPModule_makeDescriptor(PySequence_Fast_GET_ITEM(sequence.get(), i))
|
||||
);
|
||||
}
|
||||
|
||||
desc = THDPModule_makeDescriptor(PyTuple_GET_ITEM(args, 1));
|
||||
group = _getGroup(PyTuple_GET_ITEM(args, 2));
|
||||
{
|
||||
AutoNoGIL guard;
|
||||
THDGatherRecv(descriptors.data(), length, desc, group);
|
||||
}
|
||||
Py_RETURN_NONE;
|
||||
|
||||
invalid_arguments:
|
||||
THPUtils_invalidArguments(args, nullptr, "gatherRecv", 1,
|
||||
"(list[tensor] output, tensor input, group gr)");
|
||||
return nullptr;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
PyObject* THDPModule_scatterSend(PyObject *_unused, PyObject *args)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
THPObjectPtr sequence;
|
||||
size_t length;
|
||||
std::vector<at::Tensor> descriptors;
|
||||
THDGroup group;
|
||||
at::Tensor desc;
|
||||
|
||||
if (PyTuple_GET_SIZE(args) != 3 ||
|
||||
!PySequence_Check(PyTuple_GET_ITEM(args, 0)) ||
|
||||
!THPVariable_Check(PyTuple_GET_ITEM(args, 1))) {
|
||||
goto invalid_arguments;
|
||||
}
|
||||
|
||||
sequence = THPObjectPtr(PySequence_Fast(PyTuple_GET_ITEM(args, 0),
|
||||
"expected a sequence"));
|
||||
if (!sequence.get()) {
|
||||
goto invalid_arguments;
|
||||
}
|
||||
|
||||
length = static_cast<size_t>(PySequence_Fast_GET_SIZE(sequence.get()));
|
||||
|
||||
descriptors.reserve(length);
|
||||
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
if (!THPVariable_Check(PySequence_Fast_GET_ITEM(sequence.get(), i)))
|
||||
goto invalid_arguments;
|
||||
|
||||
descriptors.push_back(
|
||||
THDPModule_makeDescriptor(PySequence_Fast_GET_ITEM(sequence.get(), i))
|
||||
);
|
||||
}
|
||||
|
||||
desc = THDPModule_makeDescriptor(PyTuple_GET_ITEM(args, 1));
|
||||
group = _getGroup(PyTuple_GET_ITEM(args, 2));
|
||||
{
|
||||
AutoNoGIL guard;
|
||||
THDScatterSend(descriptors.data(), length, desc, group);
|
||||
}
|
||||
Py_RETURN_NONE;
|
||||
|
||||
invalid_arguments:
|
||||
THPUtils_invalidArguments(args, nullptr, "scatterSend", 1,
|
||||
"(list[tensor] input, tensor output, group gr)");
|
||||
return nullptr;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
PyObject* THDPModule_scatterRecv(PyObject *_unused, PyObject *args)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
if (PyTuple_GET_SIZE(args) != 3 || !THPVariable_Check(PyTuple_GET_ITEM(args, 0)) ||
|
||||
!THPUtils_checkLong(PyTuple_GET_ITEM(args, 1))) {
|
||||
THPUtils_invalidArguments(args, nullptr, "scatterRecv", 1,
|
||||
"(tensor output, int src_rank, group gr)");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
THDGroup group = _getGroup(PyTuple_GET_ITEM(args, 2));
|
||||
auto desc = THDPModule_makeDescriptor(PyTuple_GET_ITEM(args, 0));
|
||||
int src_rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1));
|
||||
{
|
||||
AutoNoGIL guard;
|
||||
THDScatterRecv(desc, src_rank, group);
|
||||
}
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
PyObject* THDPModule_barrier(PyObject *_unused, PyObject *_group)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
{
|
||||
AutoNoGIL guard;
|
||||
THDBarrier(_getGroup(_group));
|
||||
}
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
PyObject* THDPModule_newGroup(PyObject *_unused, PyObject *args)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
THPObjectPtr sequence;
|
||||
size_t length;
|
||||
std::vector<int> ranks;
|
||||
|
||||
if (PyTuple_GET_SIZE(args) != 1 ||
|
||||
!PySequence_Check(PyTuple_GET_ITEM(args, 0))) {
|
||||
goto invalid_arguments;
|
||||
}
|
||||
|
||||
sequence = THPObjectPtr(PySequence_Fast(PyTuple_GET_ITEM(args, 0),
|
||||
"expected a sequence"));
|
||||
if (!sequence.get()) {
|
||||
goto invalid_arguments;
|
||||
}
|
||||
|
||||
length = static_cast<size_t>(PySequence_Fast_GET_SIZE(sequence.get()));
|
||||
|
||||
ranks.reserve(length);
|
||||
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
if (!THPUtils_checkLong(PySequence_Fast_GET_ITEM(sequence.get(), i)))
|
||||
goto invalid_arguments;
|
||||
|
||||
ranks.push_back(THPUtils_unpackLong(
|
||||
PySequence_Fast_GET_ITEM(sequence.get(), i)));
|
||||
|
||||
for (size_t j = 0; j < i; ++j)
|
||||
THPUtils_assert(ranks[i] != ranks[j], "ranks should be unique");
|
||||
}
|
||||
|
||||
THDGroup group;
|
||||
{
|
||||
AutoNoGIL guard;
|
||||
group = THDNewGroup(ranks.data(), length);
|
||||
}
|
||||
return PyInt_FromLong(group);
|
||||
|
||||
invalid_arguments:
|
||||
THPUtils_invalidArguments(args, nullptr, "newGroup", 1, "(list[int] ranks)");
|
||||
return nullptr;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
PyObject* THDPModule_requestIsCompleted(PyObject *_unused, PyObject *_req)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
if (!THPWrapper_check(_req)) {
|
||||
THPUtils_invalidArguments(_req, nullptr, "requestIsCompleted", 1, "(request req)");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return PyBool_FromLong(THDRequest_isCompleted(_unpackRequest(_req)));
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
PyObject* THDPModule_requestWait(PyObject *_unused, PyObject *_req)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
if (!THPWrapper_check(_req)) {
|
||||
THPUtils_invalidArguments(_req, nullptr, "requestWait", 1, "(request req)");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
{
|
||||
AutoNoGIL guard;
|
||||
THDRequest_wait(_unpackRequest(_req));
|
||||
}
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
PyObject* THDPModule_initExtension(PyObject *_unused, PyObject *args) {
|
||||
if (PyTuple_GET_SIZE(args) != 3) {
|
||||
THPUtils_invalidArguments(args, nullptr, "initExtension", 1, "(bool is_master_worker, reduce_op obj, group obj)");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
PyObject* is_master_worker_obj = PyTuple_GET_ITEM(args, 0);
|
||||
PyObject* reduce_op_obj = PyTuple_GET_ITEM(args, 1);
|
||||
PyObject* group_obj = PyTuple_GET_ITEM(args, 2);
|
||||
|
||||
THPUtils_assert(PyBool_Check(is_master_worker_obj), "first argument should be a bool");
|
||||
bool is_master_worker = is_master_worker_obj == Py_True;
|
||||
|
||||
THPObjectPtr reduce_op;
|
||||
#define REGISTER_REDUCE_OP(NAME) \
|
||||
reduce_op = PyObject_GetAttrString(reduce_op_obj, #NAME); \
|
||||
THPUtils_assert(reduce_op, "Missing object for reduce op " #NAME); \
|
||||
obj2reduceop.emplace(reduce_op.get(), THDReduce##NAME);
|
||||
REGISTER_REDUCE_OP(SUM);
|
||||
REGISTER_REDUCE_OP(PRODUCT);
|
||||
REGISTER_REDUCE_OP(MIN);
|
||||
REGISTER_REDUCE_OP(MAX);
|
||||
#undef REGISTER_REDUCE_OP
|
||||
|
||||
THPObjectPtr group;
|
||||
#define REGISTER_GROUP(NAME) \
|
||||
group = PyObject_GetAttrString(group_obj, #NAME); \
|
||||
THPUtils_assert(group, "Missing object for group " #NAME); \
|
||||
obj2group.emplace(group.get(), THDGroup##NAME);
|
||||
REGISTER_GROUP(WORLD);
|
||||
#undef REGISTER_GROUP
|
||||
|
||||
if (is_master_worker) {
|
||||
throw std::runtime_error("THD master_worker no longer supported");
|
||||
}
|
||||
Py_RETURN_TRUE;
|
||||
}
|
||||
|
||||
static struct PyMethodDef _THDPModule_methods[] = {
|
||||
{"_dist_init_extension", (PyCFunction)THDPModule_initExtension, METH_VARARGS, nullptr},
|
||||
{"_dist_init_process_group", (PyCFunction)THDPModule_initProcessGroup, METH_VARARGS, nullptr},
|
||||
{"_dist_destroy_process_group", (PyCFunction)THDPModule_destroyProcessGroup, METH_NOARGS, nullptr},
|
||||
{"_dist_clear_group_cache", (PyCFunction)THDPModule_clearGroupCache, METH_VARARGS, nullptr},
|
||||
#ifdef USE_CUDA
|
||||
{"_dist_register_stream", (PyCFunction)THDPModule_registerStream, METH_O, nullptr},
|
||||
#endif
|
||||
{"_dist_get_rank", (PyCFunction)THDPModule_getRank, METH_NOARGS, nullptr},
|
||||
{"_dist_get_num_processes", (PyCFunction)THDPModule_getNumProcesses, METH_NOARGS, nullptr},
|
||||
{"_dist_isend", (PyCFunction)THDPModule_isend, METH_VARARGS, nullptr},
|
||||
{"_dist_irecv", (PyCFunction)THDPModule_irecv, METH_VARARGS, nullptr},
|
||||
{"_dist_send", (PyCFunction)THDPModule_send, METH_VARARGS, nullptr},
|
||||
{"_dist_recv_any_source", (PyCFunction)THDPModule_recvAnySource, METH_O, nullptr},
|
||||
{"_dist_recv", (PyCFunction)THDPModule_recv, METH_VARARGS, nullptr},
|
||||
{"_dist_all_reduce", (PyCFunction)THDPModule_allReduce, METH_VARARGS, nullptr},
|
||||
{"_dist_all_reduce_multigpu", (PyCFunction)THDPModule_allReduceMultiGPU, METH_VARARGS, nullptr},
|
||||
{"_dist_reduce", (PyCFunction)THDPModule_reduce, METH_VARARGS, nullptr},
|
||||
{"_dist_reduce_multigpu", (PyCFunction)THDPModule_reduceMultiGPU, METH_VARARGS, nullptr},
|
||||
{"_dist_broadcast", (PyCFunction)THDPModule_broadcast, METH_VARARGS, nullptr},
|
||||
{"_dist_broadcast_multigpu", (PyCFunction)THDPModule_broadcastMultiGPU, METH_VARARGS, nullptr},
|
||||
{"_dist_all_gather", (PyCFunction)THDPModule_allGather, METH_VARARGS, nullptr},
|
||||
{"_dist_all_gather_multigpu", (PyCFunction)THDPModule_allGatherMultiGPU, METH_VARARGS, nullptr},
|
||||
{"_dist_gather_send", (PyCFunction)THDPModule_gatherSend, METH_VARARGS, nullptr},
|
||||
{"_dist_gather_recv", (PyCFunction)THDPModule_gatherRecv, METH_VARARGS, nullptr},
|
||||
{"_dist_scatter_send", (PyCFunction)THDPModule_scatterSend, METH_VARARGS, nullptr},
|
||||
{"_dist_scatter_recv", (PyCFunction)THDPModule_scatterRecv, METH_VARARGS, nullptr},
|
||||
{"_dist_barrier", (PyCFunction)THDPModule_barrier, METH_O, nullptr},
|
||||
{"_dist_new_group", (PyCFunction)THDPModule_newGroup, METH_VARARGS, nullptr},
|
||||
{"_dist_request_is_completed", (PyCFunction)THDPModule_requestIsCompleted, METH_O, nullptr},
|
||||
{"_dist_request_wait", (PyCFunction)THDPModule_requestWait, METH_O, nullptr},
|
||||
{nullptr}
|
||||
};
|
||||
|
||||
PyMethodDef* THDPModule_methods() {
|
||||
return _THDPModule_methods;
|
||||
}
|
||||
|
|
@ -1,9 +0,0 @@
|
|||
#ifndef THDP_H
|
||||
#define THDP_H
|
||||
|
||||
#include <THD/THD.h>
|
||||
|
||||
#include <torch/csrc/THP.h>
|
||||
#include <torch/csrc/Module.h>
|
||||
|
||||
#endif
|
||||
|
|
@ -25,7 +25,7 @@ static void THPStorage_(dealloc)(THPStorage* self)
|
|||
|
||||
static THWStorage* THPStorage_(newWithAllocator)(int64_t size, at::Allocator* allocator)
|
||||
{
|
||||
#if defined(THC_GENERIC_FILE) || defined(THD_GENERIC_FILE)
|
||||
#if defined(THC_GENERIC_FILE)
|
||||
THPUtils_setError(THPStorageStr " does not support custom allocators");
|
||||
return nullptr;
|
||||
#else
|
||||
|
|
@ -94,9 +94,6 @@ static PyObject * THPStorage_(pynew)(PyTypeObject *type, PyObject *args, PyObjec
|
|||
|
||||
// torch.Storage(sequence)
|
||||
if (num_args == 1 && PySequence_Check(first_arg)) {
|
||||
#ifdef THD_GENERIC_FILE
|
||||
THPUtils_setError("distributed storages don't support construction from a sequence");
|
||||
#else
|
||||
Py_ssize_t length = PySequence_Length(first_arg);
|
||||
THPUtils_assert(length >= 0, "couldn't obtain the length of %s",
|
||||
THPUtils_typename(first_arg));
|
||||
|
|
@ -122,7 +119,6 @@ static PyObject * THPStorage_(pynew)(PyTypeObject *type, PyObject *args, PyObjec
|
|||
return nullptr;
|
||||
}
|
||||
return (PyObject*)self.release();
|
||||
#endif
|
||||
}
|
||||
|
||||
THPUtils_invalidArguments(args, kwargs, THPStorageStr " constructor", 6,
|
||||
|
|
@ -310,7 +306,6 @@ THPCopyList THWStorage_(copy_functions);
|
|||
|
||||
void THPStorage_(initCopyMethods)()
|
||||
{
|
||||
#ifndef THD_GENERIC_FILE
|
||||
auto& h = THWStorage_(copy_functions);
|
||||
// copy from CPU types
|
||||
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPByteStorageType, h, &THWStorage_(copyByte));
|
||||
|
|
@ -350,21 +345,16 @@ void THPStorage_(initCopyMethods)()
|
|||
#undef THCpuStorage
|
||||
#undef THCpuStorage_
|
||||
#endif
|
||||
#endif // !defined(THD_GENERIC_FILE)
|
||||
}
|
||||
|
||||
#include <torch/csrc/generic/StorageMethods.cpp>
|
||||
#ifndef THD_GENERIC_FILE
|
||||
#include <torch/csrc/generic/StorageSharing.cpp>
|
||||
#endif
|
||||
|
||||
bool THPStorage_(init)(PyObject *module)
|
||||
{
|
||||
static std::vector<PyMethodDef> methods;
|
||||
THPUtils_addPyMethodDefs(methods, THPStorage_(methods));
|
||||
#ifndef THD_GENERIC_FILE
|
||||
THPUtils_addPyMethodDefs(methods, THPStorage_(sharingMethods));
|
||||
#endif
|
||||
|
||||
THPStorageType.tp_methods = methods.data();
|
||||
THPStorageType.tp_members = THPStorage_(members);
|
||||
|
|
|
|||
|
|
@ -9,14 +9,12 @@ static PyObject * THPStorage_(size)(THPStorage *self)
|
|||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
#ifndef THD_GENERIC_FILE
|
||||
static PyObject * THPStorage_(dataPtr)(THPStorage *self)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
return PyLong_FromVoidPtr(THWStorage_(data)(LIBRARY_STATE self->cdata));
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
#endif
|
||||
|
||||
static PyObject * THPStorage_(copy_)(PyObject *self, PyObject *args, PyObject *kwargs)
|
||||
{
|
||||
|
|
@ -25,7 +23,6 @@ static PyObject * THPStorage_(copy_)(PyObject *self, PyObject *args, PyObject *k
|
|||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
#ifndef THD_GENERIC_FILE
|
||||
static PyObject * THPStorage_(isPinned)(THPStorage *self)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
|
|
@ -46,7 +43,6 @@ static PyObject * THPStorage_(isPinned)(THPStorage *self)
|
|||
#endif
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
#endif
|
||||
|
||||
static PyObject * THPStorage_(elementSize)(THPStorage *self)
|
||||
{
|
||||
|
|
@ -89,7 +85,7 @@ static PyObject * THPStorage_(fill_)(THPStorage *self, PyObject *number_arg)
|
|||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
#if !defined(THC_GENERIC_FILE) && !defined(THD_GENERIC_FILE)
|
||||
#if !defined(THC_GENERIC_FILE)
|
||||
static PyObject * THPStorage_(fromBuffer)(PyObject *_unused, PyObject *args, PyObject *keywds)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
|
|
@ -206,7 +202,6 @@ static PyObject * THPStorage_(fromFile)(PyObject *_unused, PyObject *args, PyObj
|
|||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
#ifndef THD_GENERIC_FILE
|
||||
PyObject * THPStorage_(writeFile)(THPStorage *self, PyObject *args)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
|
|
@ -287,7 +282,6 @@ static PyObject *THPStorage_(setFromFile)(THPStorage *self, PyObject *args)
|
|||
return (PyObject *) self;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
#endif // !defined(THD_GENERIC_FILE)
|
||||
|
||||
#ifdef THC_GENERIC_FILE
|
||||
PyObject * THPStorage_(getDevice)(THPStorage *self)
|
||||
|
|
@ -320,14 +314,12 @@ static PyMethodDef THPStorage_(methods)[] = {
|
|||
{"new", (PyCFunction)THPStorage_(new), METH_NOARGS, nullptr},
|
||||
{"resize_", (PyCFunction)THPStorage_(resize_), METH_O, nullptr},
|
||||
{"size", (PyCFunction)THPStorage_(size), METH_NOARGS, nullptr},
|
||||
#ifndef THD_GENERIC_FILE
|
||||
{"data_ptr", (PyCFunction)THPStorage_(dataPtr), METH_NOARGS, nullptr},
|
||||
{"is_pinned", (PyCFunction)THPStorage_(isPinned), METH_NOARGS, nullptr},
|
||||
{"_write_file", (PyCFunction)THPStorage_(writeFile), METH_VARARGS, nullptr},
|
||||
{"_new_with_file", (PyCFunction)THPStorage_(newWithFile), METH_O | METH_STATIC, nullptr},
|
||||
{"_set_from_file", (PyCFunction)THPStorage_(setFromFile), METH_VARARGS, nullptr},
|
||||
#endif // !defined(THD_GENERIC_FILE)
|
||||
#if !defined(THC_GENERIC_FILE) && !defined(THD_GENERIC_FILE)
|
||||
#if !defined(THC_GENERIC_FILE)
|
||||
{"from_buffer", (PyCFunction)THPStorage_(fromBuffer), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr},
|
||||
#endif
|
||||
{"from_file", (PyCFunction)THPStorage_(fromFile), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr},
|
||||
|
|
@@ -335,7 +327,5 @@ static PyMethodDef THPStorage_(methods)[] = {
{"get_device", (PyCFunction)THPStorage_(getDevice), METH_NOARGS, nullptr},
|
||||
#endif
|
||||
{"_set_cdata", (PyCFunction)THPStorage_(_setCdata), METH_O, nullptr},
|
||||
#ifndef THD_GENERIC_FILE
|
||||
#endif
|
||||
{nullptr}
|
||||
};
@@ -2,7 +2,7 @@
#define TH_GENERIC_FILE "torch/csrc/generic/utils.cpp"
|
||||
#else
|
||||
|
||||
#if defined(THD_GENERIC_FILE) || defined(TH_REAL_IS_HALF)
|
||||
#if defined(TH_REAL_IS_HALF)
|
||||
#define GENERATE_SPARSE 0
|
||||
#else
|
||||
#define GENERATE_SPARSE 1
|
||||
|
|
|
|||
|
|
@@ -2,7 +2,7 @@
#define TH_GENERIC_FILE "torch/csrc/generic/utils.h"
|
||||
#else
|
||||
|
||||
#if defined(THD_GENERIC_FILE) || defined(TH_REAL_IS_HALF)
|
||||
#if defined(TH_REAL_IS_HALF)
|
||||
#define GENERATE_SPARSE 0
|
||||
#else
|
||||
#define GENERATE_SPARSE 1
|
||||
|
|
@@ -16,7 +16,6 @@ typedef class THPPointer<THWTensor> THWTensorPtr;
typedef class THPPointer<THPStorage> THPStoragePtr;
|
||||
|
||||
#if (!defined(THC_GENERIC_FILE)) && \
|
||||
(!defined(THD_GENERIC_FILE)) && \
|
||||
(!defined(THQUANTIZED))
|
||||
template<>
|
||||
struct THPUtils_typeTraits<scalar_t> {
@@ -1,566 +0,0 @@
"""
|
||||
torch.distributed.deprecated provides an MPI-like interface for exchanging tensor
|
||||
data across multi-machine networks. It supports a few different backends
|
||||
and initialization methods.
|
||||
"""
|
||||
import torch
|
||||
import atexit
|
||||
import warnings
|
||||
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
||||
|
||||
|
||||
class dist_backend:
|
||||
UNDEFINED = -1
|
||||
TCP = 0
|
||||
MPI = 1
|
||||
GLOO = 2
|
||||
NCCL = 3
|
||||
|
||||
|
||||
_INITIALIZED_PG = 1
|
||||
_INITIALIZED_MW = 2
|
||||
_initialized = 0
|
||||
_backend = dist_backend.UNDEFINED
|
||||
_scope = locals()
|
||||
|
||||
|
||||
def _extend_scope(module):
|
||||
_scope.update({k: getattr(module, k) for k in dir(module) if not k.startswith('_')})
|
||||
|
||||
|
||||
def is_available():
|
||||
return torch._C._has_distributed()
|
||||
|
||||
|
||||
def destroy_process_group():
|
||||
r"""Destroy the initialized distributed package
|
||||
"""
|
||||
global _backend
|
||||
global _initialized
|
||||
torch._C._dist_destroy_process_group()
|
||||
_backend = dist_backend.UNDEFINED
|
||||
_initialized = 0
|
||||
|
||||
|
||||
def is_initialized():
|
||||
r"""Checking if the process group has been initialized
|
||||
"""
|
||||
return _initialized == _INITIALIZED_PG
|
||||
|
||||
|
||||
def init_process_group(backend, init_method='env://', **kwargs):
|
||||
r"""Initializes the distributed package.
|
||||
|
||||
Arguments:
|
||||
backend (str): Name of the backend to use. Depending on build-time configuration
|
||||
valid values include: ``tcp``, ``mpi``, ``gloo`` and ``nccl``.
|
||||
init_method (str, optional): URL specifying how to initialize the package.
|
||||
world_size (int, optional): Number of processes participating in the job.
|
||||
rank (int, optional): Rank of the current process.
|
||||
group_name (str, optional): Group name. See description of init methods.
|
||||
|
||||
To enable ``backend == mpi``, PyTorch needs to be built from source on a system that
|
||||
supports MPI. If you want to use Open MPI with CUDA-aware support, please use
|
||||
Open MPI major version 2 and above.
|
||||
|
||||
.. note::
|
||||
This method initializes CUDA context. Therefore, if multiple processes
|
||||
run on a single machine but use different GPUs, make sure to use
|
||||
:func:`torch.cuda.set_device` before this method to avoid unnecessarily
|
||||
creating context on the first visible device.
|
||||
|
||||
"""
|
||||
world_size = kwargs.pop('world_size', -1)
|
||||
group_name = kwargs.pop('group_name', '')
|
||||
rank = kwargs.pop('rank', -1)
|
||||
assert len(kwargs) == 0, "got unexpected keyword arguments: %s" % ",".join(kwargs.keys())
|
||||
|
||||
if not is_available():
|
||||
raise RuntimeError("PyTorch built without distributed support")
|
||||
|
||||
global _initialized
|
||||
if _initialized:
|
||||
raise RuntimeError("trying to initialize torch.distributed.deprecated twice!")
|
||||
|
||||
# Checking and assigning the distributed backend
|
||||
global _backend
|
||||
|
||||
backend = backend.lower()
|
||||
if backend == "tcp":
|
||||
_backend = dist_backend.TCP
|
||||
elif backend == "mpi":
|
||||
_backend = dist_backend.MPI
|
||||
elif backend == "gloo":
|
||||
_backend = dist_backend.GLOO
|
||||
elif backend == "nccl":
|
||||
_backend = dist_backend.NCCL
|
||||
else:
|
||||
raise RuntimeError("Invalid distributed backend name: " + backend)
|
||||
|
||||
torch._C._dist_init_process_group(backend, init_method, world_size,
|
||||
group_name, rank)
|
||||
_initialized = _INITIALIZED_PG
|
||||
|
||||
if _backend == dist_backend.NCCL:
|
||||
atexit.register(destroy_process_group)
|
||||
|
||||
if not torch._C._dist_init_extension(False, reduce_op, group):
|
||||
raise RuntimeError("distributed module initialization failed")
|
||||
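For historical context, a minimal two-process launch against this deprecated API might have looked like the sketch below; the TCP backend, the rendezvous address and port, and the RANK/WORLD_SIZE environment variables are illustrative assumptions, not anything mandated by this file.

    # Hypothetical usage of the removed torch.distributed.deprecated package
    # (run one copy per process, e.g. with RANK=0 and RANK=1 in the environment).
    import os
    import torch.distributed.deprecated as dist

    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')   # assumed rendezvous host
    os.environ.setdefault('MASTER_PORT', '29500')       # assumed free port

    dist.init_process_group(backend='tcp',
                            init_method='env://',
                            world_size=int(os.environ['WORLD_SIZE']),
                            rank=int(os.environ['RANK']))
    print('rank', dist.get_rank(), 'of', dist.get_world_size())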
|
||||
|
||||
def init_master_worker(backend, init_method='env://', **kwargs):
|
||||
warnings.warn("""
|
||||
================================================================================
|
||||
WARNING
|
||||
================================================================================
|
||||
Master-worker mode is still experimental. The API will change without
|
||||
notice and we do not guarantee full correctness and expected performance yet.
|
||||
We'll announce it once it's ready.
|
||||
""")
|
||||
world_size = kwargs.pop('world_size', -1)
|
||||
group_name = kwargs.pop('group_name', '')
|
||||
rank = kwargs.pop('rank', -1)
|
||||
assert len(kwargs) == 0, "got unexpected keyword arguments: %s" % ",".join(kwargs.keys())
|
||||
|
||||
if not is_available():
|
||||
raise RuntimeError("PyTorch built without distributed support")
|
||||
|
||||
global _initialized
|
||||
if _initialized:
|
||||
raise RuntimeError("trying to initialize torch.distributed.deprecated twice!")
|
||||
torch._C._dist_init_master_worker(backend, init_method, world_size,
|
||||
group_name, rank)
|
||||
_initialized = _INITIALIZED_MW
|
||||
import torch.distributed.deprecated.collectives as collectives
|
||||
import torch.distributed.deprecated.remote_types as remote_types
|
||||
_extend_scope(collectives)
|
||||
_extend_scope(remote_types)
|
||||
if not torch._C._dist_init_extension(True, reduce_op, group):
|
||||
raise RuntimeError("distributed module initialization failed")
|
||||
|
||||
|
||||
class reduce_op(object):
|
||||
SUM = object()
|
||||
PRODUCT = object()
|
||||
MAX = object()
|
||||
MIN = object()
|
||||
|
||||
|
||||
class group(object):
|
||||
WORLD = object()
|
||||
|
||||
|
||||
class _DistributedRequest(object):
|
||||
def __init__(self, request):
|
||||
self.request = request
|
||||
|
||||
def is_completed(self):
|
||||
return torch._C._dist_request_is_completed(self.request)
|
||||
|
||||
def wait(self):
|
||||
torch._C._dist_request_wait(self.request)
|
||||
|
||||
|
||||
def get_rank():
|
||||
r"""Returns the rank of current process.
|
||||
|
||||
Rank is a unique identifier assigned to each process within a distributed
|
||||
group. They are always consecutive integers ranging from ``0`` to
|
||||
``world_size - 1`` (inclusive).
|
||||
"""
|
||||
assert torch.distributed.deprecated._initialized
|
||||
return torch._C._dist_get_rank()
|
||||
|
||||
|
||||
def get_world_size():
|
||||
r"""Returns the number of processes in the distributed group."""
|
||||
assert torch.distributed.deprecated._initialized
|
||||
return torch._C._dist_get_num_processes()
|
||||
|
||||
|
||||
def isend(tensor, dst):
|
||||
r"""Sends a tensor asynchronously.
|
||||
|
||||
Arguments:
|
||||
tensor (Tensor): Tensor to send.
|
||||
dst (int): Destination rank.
|
||||
|
||||
Returns:
|
||||
A distributed request object.
|
||||
"""
|
||||
assert torch.distributed.deprecated._initialized == _INITIALIZED_PG, \
|
||||
"collective only supported in process-group mode"
|
||||
return _DistributedRequest(torch._C._dist_isend(tensor, dst))
|
||||
|
||||
|
||||
def irecv(tensor, src):
|
||||
r"""Receives a tensor asynchronously.
|
||||
|
||||
Arguments:
|
||||
tensor (Tensor): Tensor to fill with received data.
|
||||
src (int): Source rank.
|
||||
|
||||
Returns:
|
||||
A distributed request object.
|
||||
"""
|
||||
assert torch.distributed.deprecated._initialized == _INITIALIZED_PG, \
|
||||
"collective only supported in process-group mode"
|
||||
return _DistributedRequest(torch._C._dist_irecv(tensor, src))
|
||||
|
||||
|
||||
def send(tensor, dst):
|
||||
r"""Sends a tensor synchronously.
|
||||
|
||||
Arguments:
|
||||
tensor (Tensor): Tensor to send.
|
||||
dst (int): Destination rank.
|
||||
"""
|
||||
assert torch.distributed.deprecated._initialized == _INITIALIZED_PG, \
|
||||
"collective only supported in process-group mode"
|
||||
return torch._C._dist_send(tensor, dst)
|
||||
|
||||
|
||||
def recv(tensor, src=None):
|
||||
r"""Receives a tensor synchronously.
|
||||
|
||||
Arguments:
|
||||
tensor (Tensor): Tensor to fill with received data.
|
||||
src (int, optional): Source rank. Will receive from any
|
||||
process if unspecified.
|
||||
|
||||
Returns:
|
||||
Sender rank.
|
||||
"""
|
||||
assert torch.distributed.deprecated._initialized == _INITIALIZED_PG, \
|
||||
"collective only supported in process-group mode"
|
||||
if src is None:
|
||||
return torch._C._dist_recv_any_source(tensor)
|
||||
return torch._C._dist_recv(tensor, src)
|
||||
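Taken together, these four point-to-point primitives might be exercised as in the following sketch; the two-rank layout and the tensor size are assumptions for illustration, and initialization as described earlier is presumed to have already happened.

    # Hypothetical blocking and non-blocking exchange between rank 0 and rank 1.
    import torch
    import torch.distributed.deprecated as dist

    t = torch.zeros(4)
    if dist.get_rank() == 0:
        t += 1.0
        dist.send(t, dst=1)           # blocking send to rank 1
        req = dist.isend(t, dst=1)    # asynchronous send returns a request object
        req.wait()                    # block until the request completes
    else:
        sender = dist.recv(t, src=0)  # blocking receive; returns the sender's rank
        req = dist.irecv(t, src=0)    # asynchronous receive
        req.wait()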
|
||||
|
||||
def broadcast_multigpu(tensor_list, src, group=group.WORLD):
|
||||
r"""Broadcasts the tensor to the whole group with multiple GPU tensors
|
||||
per node.
|
||||
|
||||
:attr:`tensor` must have the same number of elements in all the GPUs from
|
||||
all processes participating in the collective. Each tensor in the list must
|
||||
be on a different GPU.
|
||||
|
||||
.. note::
|
||||
Only NCCL backend is currently supported. :attr:`tensor_list` should only
|
||||
contain GPU tensors.
|
||||
|
||||
Arguments:
|
||||
tensor_list (List[Tensor]): Tensors that participate in the collective
|
||||
operation. If ``src`` is the rank, then the first element of
|
||||
``tensor_list`` (``tensor_list[0]``) will be broadcasted to all
|
||||
other tensors (on different GPUs) in the src process and all tensors
|
||||
in ``tensor_list`` of other non-src processes. You also need to make
|
||||
sure that ``len(tensor_list)`` is the same for all the distributed
|
||||
processes calling this function.
|
||||
|
||||
src (int): Source rank.
|
||||
group (optional): Group of the collective.
|
||||
"""
|
||||
assert torch.distributed.deprecated._initialized == _INITIALIZED_PG, \
|
||||
"collective only supported in process-group mode"
|
||||
|
||||
return torch._C._dist_broadcast_multigpu(tensor_list, src, group)
|
||||
|
||||
|
||||
def broadcast(tensor, src, group=group.WORLD):
|
||||
r"""Broadcasts the tensor to the whole group.
|
||||
|
||||
:attr:`tensor` must have the same number of elements in all processes
|
||||
participating in the collective.
|
||||
|
||||
Arguments:
|
||||
tensor (Tensor): Data to be sent if :attr:`src` is the rank of
|
||||
current process, and tensor to be used to save received data
|
||||
otherwise.
|
||||
src (int): Source rank.
|
||||
group (optional): Group of the collective.
|
||||
"""
|
||||
assert torch.distributed.deprecated._initialized == _INITIALIZED_PG, \
|
||||
"collective only supported in process-group mode"
|
||||
return torch._C._dist_broadcast(tensor, src, group)
|
||||
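A one-line illustration of the broadcast above, with rank 0 assumed to be the source:

    # Hypothetical broadcast: rank 0's data overwrites every other rank's copy.
    import torch
    import torch.distributed.deprecated as dist

    t = torch.arange(4.) if dist.get_rank() == 0 else torch.zeros(4)
    dist.broadcast(t, src=0)    # afterwards every rank holds [0., 1., 2., 3.]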
|
||||
|
||||
def all_reduce_multigpu(tensor_list, op=reduce_op.SUM, group=group.WORLD):
|
||||
r"""Reduces the tensor data across all machines in such a way that all get
|
||||
the final result. This function reduces a number of tensors on every node,
|
||||
while each tensor resides on a different GPU.
|
||||
Therefore, the input tensors in the tensor list need to be GPU tensors.
|
||||
Also, each tensor in the tensor list needs to reside on a different GPU.
|
||||
|
||||
After the call, all tensors in :attr:`tensor_list` will be bitwise identical
|
||||
in all processes.
|
||||
|
||||
.. note::
|
||||
Only NCCL backend is currently supported. :attr:`tensor_list` should only
|
||||
contain GPU tensors.
|
||||
|
||||
Arguments:
|
||||
tensor_list (List[Tensor]): List of input and output tensors of
|
||||
the collective. The function operates in-place and requires that
|
||||
each tensor be a GPU tensor residing on a different GPU.
|
||||
You also need to make sure that ``len(tensor_list)`` is the same for
|
||||
all the distributed processes calling this function.
|
||||
|
||||
op (optional): One of the values from ``torch.distributed.deprecated.reduce_op``
|
||||
enum. Specifies an operation used for element-wise reductions.
|
||||
group (optional): Group of the collective.
|
||||
"""
|
||||
assert torch.distributed.deprecated._initialized == _INITIALIZED_PG, \
|
||||
"collective only supported in process-group mode"
|
||||
|
||||
return torch._C._dist_all_reduce_multigpu(tensor_list, op, group)
|
||||
|
||||
|
||||
def all_reduce(tensor, op=reduce_op.SUM, group=group.WORLD):
|
||||
r"""Reduces the tensor data across all machines in such a way that all get
|
||||
the final result.
|
||||
|
||||
After the call :attr:`tensor` will be bitwise identical in all processes.
|
||||
|
||||
Arguments:
|
||||
tensor (Tensor): Input and output of the collective. The function
|
||||
operates in-place.
|
||||
op (optional): One of the values from ``torch.distributed.deprecated.reduce_op``
|
||||
enum. Specifies an operation used for element-wise reductions.
|
||||
group (optional): Group of the collective.
|
||||
"""
|
||||
assert torch.distributed.deprecated._initialized == _INITIALIZED_PG, \
|
||||
"collective only supported in process-group mode"
|
||||
return torch._C._dist_all_reduce(tensor, op, group)
|
||||
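As a concrete illustration of this collective, averaging a tensor across all workers could be sketched as follows; the division by the world size (turning the sum into a mean) is the only step added here.

    # Hypothetical all_reduce: every rank ends up with the global mean.
    import torch
    import torch.distributed.deprecated as dist

    local = torch.randn(10)
    dist.all_reduce(local, op=dist.reduce_op.SUM)   # in-place sum across all ranks
    local /= dist.get_world_size()                  # convert the sum into a mean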
|
||||
|
||||
def reduce_multigpu(tensor_list, dst, op=reduce_op.SUM, group=group.WORLD):
|
||||
r"""Reduces the tensor data on multiple GPUs across all machines. Each tensor
|
||||
in :attr:`tensor_list` should reside on a separate GPU.
|
||||
|
||||
Only the GPU of ``tensor_list[0]`` on the process with rank :attr:`dst` is
|
||||
going to receive the final result.
|
||||
|
||||
.. note::
|
||||
Only NCCL backend is currently supported. :attr:`tensor_list` should only
|
||||
contain GPU tensors.
|
||||
|
||||
Arguments:
|
||||
tensor_list (List[Tensor]): Input and output GPU tensors of the
|
||||
collective. The function operates in-place.
|
||||
You also need to make sure that ``len(tensor_list)`` is the same for
|
||||
all the distributed processes calling this function.
|
||||
|
||||
dst (int): Destination rank
|
||||
op (optional): One of the values from ``torch.distributed.deprecated.reduce_op``
|
||||
enum. Specifies an operation used for element-wise reductions.
|
||||
group (optional): Group of the collective.
|
||||
"""
|
||||
assert torch.distributed.deprecated._initialized == _INITIALIZED_PG, \
|
||||
"collective only supported in process-group mode"
|
||||
|
||||
return torch._C._dist_reduce_multigpu(tensor_list, dst, op, group)
|
||||
|
||||
|
||||
def reduce(tensor, dst, op=reduce_op.SUM, group=group.WORLD):
|
||||
r"""Reduces the tensor data across all machines.
|
||||
|
||||
Only the process with rank :attr:`dst` is going to receive the final result.
|
||||
|
||||
Arguments:
|
||||
tensor (Tensor): Input and output of the collective. The function
|
||||
operates in-place.
|
||||
dst (int): Destination rank
|
||||
op (optional): One of the values from ``torch.distributed.deprecated.reduce_op``
|
||||
enum. Specifies an operation used for element-wise reductions.
|
||||
group (optional): Group of the collective.
|
||||
"""
|
||||
assert torch.distributed.deprecated._initialized == _INITIALIZED_PG, \
|
||||
"collective only supported in process-group mode"
|
||||
return torch._C._dist_reduce(tensor, dst, op, group)
|
||||
|
||||
|
||||
def all_gather_multigpu(output_tensor_lists,
|
||||
input_tensor_list,
|
||||
group=group.WORLD):
|
||||
r"""Gathers tensors from the whole group in a list.
|
||||
Each tensor in :attr:`input_tensor_list` should reside on a separate GPU.
|
||||
|
||||
.. note::
|
||||
Only NCCL backend is currently supported. :attr:`output_tensor_lists` and
|
||||
:attr:`input_tensor_list` should only contain GPU tensors.
|
||||
|
||||
Arguments:
|
||||
output_tensor_lists (List[List[Tensor]]): Output lists. It should
|
||||
contain correctly-sized tensors on each GPU to be used for output of
|
||||
the collective.
|
||||
e.g. ``output_tensor_lists[i]`` contains the all_gather
|
||||
result that resides on the GPU of ``input_tensor_list[i]``.
|
||||
Note that each element of ``output_tensor_lists[i]`` has the size of
|
||||
``world_size * len(input_tensor_list)``, since the function all
|
||||
gathers the result from every single GPU in the group. To interpret
|
||||
each element of ``output_tensor_list[i]``, note that
|
||||
``input_tensor_list[j]`` of rank k will appear in
|
||||
``output_tensor_list[i][rank * world_size + j]``
|
||||
Also note that ``len(output_tensor_lists)``, and the size of each
|
||||
element in ``output_tensor_lists`` (each element is a list,
|
||||
therefore ``len(output_tensor_lists[i])``) need to be the same
|
||||
for all the distributed processes calling this function.
|
||||
|
||||
input_tensor_list (List[Tensor]): List of tensors (on different GPUs) to
|
||||
be broadcast from current process.
|
||||
Note that ``len(input_tensor_list)`` needs to be the same for
|
||||
all the distributed processes calling this function.
|
||||
group (optional): Group of the collective.
|
||||
"""
|
||||
assert torch.distributed.deprecated._initialized == _INITIALIZED_PG, \
|
||||
"collective only supported in process-group mode"
|
||||
|
||||
flatten_tensor_list = []
|
||||
for output_tensor_list in output_tensor_lists:
|
||||
flatten_tensor_list.append(_flatten_dense_tensors(output_tensor_list))
|
||||
|
||||
ret = torch._C._dist_all_gather_multigpu(flatten_tensor_list,
|
||||
input_tensor_list,
|
||||
group)
|
||||
|
||||
for output_tensor_list, flatten_tensor in zip(output_tensor_lists,
|
||||
flatten_tensor_list):
|
||||
for tensor, value in zip(output_tensor_list,
|
||||
_unflatten_dense_tensors(flatten_tensor,
|
||||
output_tensor_list)):
|
||||
tensor.copy_(value)
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def all_gather(tensor_list, tensor, group=group.WORLD):
|
||||
r"""Gathers tensors from the whole group in a list.
|
||||
|
||||
Arguments:
|
||||
tensor_list (list[Tensor]): Output list. It should contain
|
||||
correctly-sized tensors to be used for output of the collective.
|
||||
tensor (Tensor): Tensor to be broadcast from current process.
|
||||
group (optional): Group of the collective.
|
||||
"""
|
||||
assert torch.distributed.deprecated._initialized == _INITIALIZED_PG, \
|
||||
"collective only supported in process-group mode"
|
||||
if _backend != dist_backend.NCCL:
|
||||
return torch._C._dist_all_gather(tensor_list, tensor, group)
|
||||
else:
|
||||
return all_gather_multigpu([tensor_list], [tensor], group)
|
||||
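A short sketch of collecting one tensor from every rank with this function; the tensor shape is an assumption, and the output list must be pre-sized to the world size as the docstring requires.

    # Hypothetical all_gather: collect a copy of `local` from every rank.
    import torch
    import torch.distributed.deprecated as dist

    local = torch.zeros(3) + dist.get_rank()
    gathered = [torch.zeros(3) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, local)    # gathered[i] now holds rank i's tensor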
|
||||
|
||||
def gather(tensor, **kwargs):
|
||||
r"""Gathers a list of tensors in a single process.
|
||||
|
||||
Arguments:
|
||||
tensor (Tensor): Input tensor.
|
||||
dst (int): Destination rank. Required in all processes except the one that
|
||||
is receiving the data.
|
||||
gather_list (list[Tensor]): List of appropriately-sized tensors to
|
||||
use for received data. Required only in the receiving process.
|
||||
group (optional): Group of the collective.
|
||||
"""
|
||||
assert torch.distributed.deprecated._initialized == _INITIALIZED_PG, \
|
||||
"collective only supported in process-group mode"
|
||||
my_rank = get_rank()
|
||||
dst = kwargs.pop('dst', my_rank)
|
||||
gather_list = kwargs.pop('gather_list', None)
|
||||
_group = kwargs.pop('group', group.WORLD)
|
||||
if kwargs:
|
||||
raise RuntimeError("got unexpected kwargs")
|
||||
if dst == my_rank:
|
||||
if gather_list is None:
|
||||
raise RuntimeError("gather_list is a required argument in gather destination")
|
||||
return torch._C._dist_gather_recv(gather_list, tensor, _group)
|
||||
else:
|
||||
if gather_list:
|
||||
raise RuntimeError("non-empty gather_list can be given only to gather destination")
|
||||
return torch._C._dist_gather_send(tensor, dst, _group)
|
||||
|
||||
|
||||
def scatter(tensor, **kwargs):
|
||||
r"""Scatters a list of tensors to all processes in a group.
|
||||
|
||||
Each process will receive exactly one tensor and store its data in the
|
||||
:attr:`tensor` argument.
|
||||
|
||||
Arguments:
|
||||
tensor (Tensor): Output tensor.
|
||||
src (int): Source rank. Required in all processes except the one that
|
||||
is sending the data.
|
||||
scatter_list (list[Tensor]): List of tensors to scatter. Required only
|
||||
in the process that is sending the data.
|
||||
group (optional): Group of the collective.
|
||||
"""
|
||||
assert torch.distributed.deprecated._initialized == _INITIALIZED_PG, \
|
||||
"collective only supported in process-group mode"
|
||||
my_rank = get_rank()
|
||||
src = kwargs.pop('src', my_rank)
|
||||
scatter_list = kwargs.pop('scatter_list', None)
|
||||
_group = kwargs.pop('group', group.WORLD)
|
||||
if kwargs:
|
||||
raise RuntimeError("got unexpected kwargs: {}".format(", ".join(kwargs.keys())))
|
||||
if src == my_rank:
|
||||
if scatter_list is None:
|
||||
raise RuntimeError("scatter_list is a required argument in scatter source")
|
||||
return torch._C._dist_scatter_send(scatter_list, tensor, _group)
|
||||
else:
|
||||
if scatter_list:
|
||||
raise RuntimeError("non-empty can be given only to scatter source")
|
||||
return torch._C._dist_scatter_recv(tensor, src, _group)
|
||||
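The gather/scatter pair above could be exercised as in the sketch below, with rank 0 assumed to be the root that hands out and later collects the shards; note how only the root passes scatter_list and gather_list, as the docstrings require.

    # Hypothetical scatter from rank 0 followed by a gather back to rank 0.
    import torch
    import torch.distributed.deprecated as dist

    world = dist.get_world_size()
    shard = torch.zeros(2)
    if dist.get_rank() == 0:
        shards = [torch.ones(2) * i for i in range(world)]
        dist.scatter(shard, src=0, scatter_list=shards)   # root supplies the shards
        results = [torch.zeros(2) for _ in range(world)]
        dist.gather(shard, dst=0, gather_list=results)    # root collects them back
    else:
        dist.scatter(shard, src=0)                        # non-root only receives
        dist.gather(shard, dst=0)                         # non-root only sends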
|
||||
|
||||
def barrier(group=group.WORLD):
|
||||
r"""Synchronizes all processes.
|
||||
|
||||
This collective blocks processes until the whole group enters this function.
|
||||
|
||||
Arguments:
|
||||
group (optional): Group of the collective.
|
||||
"""
|
||||
assert torch.distributed.deprecated._initialized == _INITIALIZED_PG, \
|
||||
"collective only supported in process-group mode"
|
||||
return torch._C._dist_barrier(group)
|
||||
|
||||
|
||||
def new_group(ranks=None):
|
||||
r"""Creates a new distributed group.
|
||||
|
||||
This function requires that all processes in the main group (i.e., all
|
||||
processes that are part of the distributed job) enter this function, even
|
||||
if they are not going to be members of the group. Additionally, groups
|
||||
should be created in the same order in all processes.
|
||||
|
||||
Arguments:
|
||||
ranks (list[int]): List of ranks of group members.
|
||||
|
||||
Returns:
|
||||
A handle of distributed group that can be given to collective calls.
|
||||
"""
|
||||
assert torch.distributed.deprecated._initialized == _INITIALIZED_PG, \
|
||||
"collective only supported in process-group mode"
|
||||
if ranks is None:
|
||||
ranks = list(range(get_world_size()))
|
||||
return torch._C._dist_new_group(ranks)
|
||||
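A sketch of restricting a collective to a subgroup with this helper; the choice of the first two ranks is arbitrary and assumes a world size of at least two.

    # Hypothetical subgroup: ranks 0 and 1 reduce among themselves only.
    import torch
    import torch.distributed.deprecated as dist

    pair = dist.new_group(ranks=[0, 1])   # every process must enter this call
    t = torch.ones(1)
    if dist.get_rank() in (0, 1):
        dist.all_reduce(t, op=dist.reduce_op.SUM, group=pair)
    dist.barrier()                        # keep the remaining ranks in step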
|
||||
|
||||
def _clear_group_cache(group=group.WORLD):
|
||||
r"""Clear the created distributed group's cached resource.
|
||||
|
||||
Only NCCL backend is currently supported.
|
||||
|
||||
Cached resource includes NCCL communicators and CUDA events.
|
||||
|
||||
Arguments:
|
||||
group (optional): Group of the collective.
|
||||
"""
|
||||
return torch._C._dist_clear_group_cache(group)
|
||||
|
||||
|
||||
def _register_stream(stream):
|
||||
if not _initialized:
|
||||
raise RuntimeError("torch.distributed.deprecated needs to be initialized first")
|
||||
return torch._C._dist_register_stream(stream)
@@ -1,60 +0,0 @@
import torch
|
||||
|
||||
from ..storage import _StorageBase
|
||||
|
||||
|
||||
class _DistributedBase(object):
|
||||
is_cuda = False
|
||||
is_distributed = True
|
||||
|
||||
|
||||
class DoubleStorage(_DistributedBase, torch._C.DistributedDoubleStorageBase, _StorageBase):
|
||||
pass
|
||||
|
||||
|
||||
class FloatStorage(_DistributedBase, torch._C.DistributedFloatStorageBase, _StorageBase):
|
||||
pass
|
||||
|
||||
|
||||
class LongStorage(_DistributedBase, torch._C.DistributedLongStorageBase, _StorageBase):
|
||||
pass
|
||||
|
||||
|
||||
class IntStorage(_DistributedBase, torch._C.DistributedIntStorageBase, _StorageBase):
|
||||
pass
|
||||
|
||||
|
||||
class ShortStorage(_DistributedBase, torch._C.DistributedShortStorageBase, _StorageBase):
|
||||
pass
|
||||
|
||||
|
||||
class CharStorage(_DistributedBase, torch._C.DistributedCharStorageBase, _StorageBase):
|
||||
pass
|
||||
|
||||
|
||||
class ByteStorage(_DistributedBase, torch._C.DistributedByteStorageBase, _StorageBase):
|
||||
pass
|
||||
|
||||
|
||||
class HalfStorage(_DistributedBase, torch._C.DistributedHalfStorageBase, _StorageBase):
|
||||
pass
|
||||
|
||||
|
||||
torch._storage_classes.add(DoubleStorage)
|
||||
torch._storage_classes.add(FloatStorage)
|
||||
torch._storage_classes.add(HalfStorage)
|
||||
torch._storage_classes.add(LongStorage)
|
||||
torch._storage_classes.add(IntStorage)
|
||||
torch._storage_classes.add(ShortStorage)
|
||||
torch._storage_classes.add(CharStorage)
|
||||
torch._storage_classes.add(ByteStorage)
|
||||
|
||||
|
||||
_type_names = ['Double', 'Float', 'Half', 'Long', 'Int', 'Short', 'Char', 'Byte']
|
||||
_locals = locals()
|
||||
_tensors = [_locals[t + 'Tensor'] for t in _type_names]
|
||||
_storages = [_locals[t + 'Storage'] for t in _type_names]
|
||||
for cls in _tensors + _storages:
|
||||
cls.__module__ = 'torch.distributed.deprecated'
|
||||
torch._C._init_names(_tensors + _storages)
|
||||
del _locals, _type_names, _tensors, _storages
@@ -1,184 +0,0 @@
CMAKE_MINIMUM_REQUIRED(VERSION 2.8)
|
||||
# TODO(jiayq): once we have unified CMake entry, remove this module path.
|
||||
SET(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/../../../cmake/Modules ${CMAKE_MODULE_PATH})
|
||||
|
||||
################################################################################
|
||||
# Helper functions
|
||||
################################################################################
|
||||
|
||||
FUNCTION(EXCLUDE_DIR list_name dir_name)
|
||||
# A helper that excludes all files that contain dir_name in their file path
|
||||
SET(local_list ${${list_name}})
|
||||
FOREACH(source ${local_list})
|
||||
IF(${source} MATCHES ${dir_name})
|
||||
MESSAGE(STATUS "Excluding " ${source} " from the build")
|
||||
LIST(REMOVE_ITEM local_list ${source})
|
||||
ENDIF()
|
||||
ENDFOREACH()
|
||||
SET(${list_name} ${local_list} PARENT_SCOPE)
|
||||
ENDFUNCTION()
|
||||
|
||||
################################################################################
|
||||
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
|
||||
|
||||
INCLUDE(CheckCXXSourceCompiles)
|
||||
|
||||
CHECK_CXX_SOURCE_COMPILES("
|
||||
#include <thread>
|
||||
|
||||
thread_local int foo=1;
|
||||
|
||||
int main() {
|
||||
return 0;
|
||||
}" HAS_THREAD_LOCAL)
|
||||
|
||||
IF(NOT HAS_THREAD_LOCAL)
|
||||
MESSAGE(FATAL_ERROR "thread_local not supported. THD requires a compiler"
|
||||
" that supports thread_local. Please upgrade your "
|
||||
"compiler. If you are on macOS, upgrade to "
|
||||
"XCode 8 or newer.")
|
||||
ENDIF(NOT HAS_THREAD_LOCAL)
|
||||
|
||||
|
||||
FIND_PACKAGE(MPI)
|
||||
|
||||
INCLUDE_DIRECTORIES(${CAFFE2_INCLUDE_DIR})
|
||||
|
||||
if (USE_TBB)
|
||||
include_directories(${TBB_ROOT_DIR}/include)
|
||||
endif()
|
||||
|
||||
IF(USE_CUDA)
|
||||
FIND_PACKAGE(CUDA 7.5)
|
||||
IF(CUDA_FOUND)
|
||||
INCLUDE_DIRECTORIES(${CUDA_INCLUDE_DIRS})
|
||||
LINK_DIRECTORIES("${CUDA_TOOLKIT_ROOT_DIR}/lib" "${CUDA_TOOLKIT_ROOT_DIR}/lib64")
|
||||
|
||||
ADD_DEFINITIONS(-DUSE_CUDA=1)
|
||||
ENDIF()
|
||||
ELSEIF(USE_ROCM)
|
||||
INCLUDE_DIRECTORIES(${Caffe2_HIP_INCLUDE})
|
||||
INCLUDE_DIRECTORIES(${GLOO_HIP_INCLUDE})
|
||||
ADD_DEFINITIONS(-DUSE_ROCM=1)
|
||||
ADD_DEFINITIONS(-D__HIP_PLATFORM_HCC__=1)
|
||||
ADD_DEFINITIONS(-DHIP_VERSION=${HIP_VERSION_MAJOR})
|
||||
ELSE()
|
||||
MESSAGE(STATUS "ignoring GPU")
|
||||
ENDIF()
|
||||
|
||||
IF(MPI_FOUND)
|
||||
ADD_DEFINITIONS(-DWITH_MPI=1)
|
||||
MESSAGE(STATUS "MPI_LIBRARIES: ${MPI_LIBRARIES}")
|
||||
ENDIF()
|
||||
|
||||
IF(USE_GLOO AND (USE_CUDA OR USE_ROCM))
|
||||
ADD_DEFINITIONS(-DWITH_GLOO=1)
|
||||
IF(USE_GLOO_IBVERBS)
|
||||
MESSAGE(STATUS "Building the gloo backend with both TCP and infiniband support")
|
||||
ADD_DEFINITIONS(-DUSE_GLOO_IBVERBS=1)
|
||||
ELSE()
|
||||
MESSAGE(STATUS "Building the gloo backend with TCP support only")
|
||||
ENDIF()
|
||||
ENDIF()
|
||||
|
||||
# Can be compiled standalone
|
||||
IF(NOT THD_INSTALL_BIN_DIR OR NOT THD_INSTALL_LIB_DIR OR NOT THD_INSTALL_INCLUDE_DIR)
|
||||
SET(THD_INSTALL_BIN_DIR "bin" CACHE PATH "THD install binary subdirectory")
|
||||
SET(THD_INSTALL_LIB_DIR "lib" CACHE PATH "THD install library subdirectory")
|
||||
SET(THD_INSTALL_INCLUDE_DIR "include" CACHE PATH "THD install include subdirectory")
|
||||
ENDIF()
|
||||
|
||||
FILE(GLOB_RECURSE process_group_h RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "process_group/*.h")
|
||||
FILE(GLOB_RECURSE process_group_cpp "process_group/*.cpp")
|
||||
FILE(GLOB_RECURSE base_h RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "base/*.h")
|
||||
FILE(GLOB_RECURSE base_cpp "base/*.cpp")
|
||||
FILE(GLOB_RECURSE test_cpp "test/*.cpp")
|
||||
|
||||
IF(NOT MPI_FOUND)
|
||||
LIST(REMOVE_ITEM base_cpp "${CMAKE_CURRENT_SOURCE_DIR}/base/data_channels/DataChannelMPI.cpp")
|
||||
LIST(REMOVE_ITEM test_cpp "${CMAKE_CURRENT_SOURCE_DIR}/test/data_channel_mpi_smoke.cpp")
|
||||
ENDIF()
|
||||
|
||||
IF(NOT (USE_GLOO AND (USE_CUDA OR USE_ROCM)))
|
||||
LIST(REMOVE_ITEM base_cpp "${CMAKE_CURRENT_SOURCE_DIR}/base/data_channels/DataChannelGloo.cpp")
|
||||
LIST(REMOVE_ITEM base_cpp "${CMAKE_CURRENT_SOURCE_DIR}/base/data_channels/Store.cpp")
|
||||
LIST(REMOVE_ITEM test_cpp "${CMAKE_CURRENT_SOURCE_DIR}/test/data_channel_gloo_store.cpp")
|
||||
LIST(REMOVE_ITEM test_cpp "${CMAKE_CURRENT_SOURCE_DIR}/test/data_channel_gloo_cache.cpp")
|
||||
ENDIF()
|
||||
|
||||
IF(NOT USE_NCCL)
|
||||
LIST(REMOVE_ITEM base_cpp "${CMAKE_CURRENT_SOURCE_DIR}/base/data_channels/DataChannelNccl.cpp")
|
||||
ENDIF()
|
||||
|
||||
SET(all_cpp ${base_cpp} ${process_group_cpp})
|
||||
SET(all_h THD.h ${base_h} ${process_group_h})
|
||||
|
||||
EXCLUDE_DIR(all_cpp ".*/generic/.*\\.cpp$")
|
||||
|
||||
INCLUDE_DIRECTORIES("${CMAKE_CURRENT_SOURCE_DIR}/..")
|
||||
|
||||
ADD_LIBRARY(THD STATIC ${all_cpp})
|
||||
|
||||
TARGET_COMPILE_DEFINITIONS(THD PRIVATE "_THD_CORE=1")
|
||||
|
||||
target_link_libraries(THD torch)
|
||||
|
||||
target_include_directories(THD PRIVATE
|
||||
${CMAKE_BINARY_DIR}/aten/src # provides "ATen/TypeExtendedInterface.h" to ATen.h
|
||||
${CMAKE_BINARY_DIR}/caffe2/aten/src # provides <TH/THGeneral.h> to THC.h
|
||||
)
|
||||
|
||||
set_property(TARGET THD PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||
|
||||
IF(MPI_FOUND)
|
||||
INCLUDE_DIRECTORIES(${MPI_INCLUDE_PATH})
|
||||
|
||||
IF(MPI_COMPILE_FLAGS)
|
||||
MESSAGE(STATUS "MPI_COMPILE_FLAGS: ${MPI_COMPILE_FLAGS}")
|
||||
SET_TARGET_PROPERTIES(THD PROPERTIES COMPILE_FLAGS "${MPI_COMPILE_FLAGS}")
|
||||
ENDIF()
|
||||
|
||||
IF(MPI_LINK_FLAGS)
|
||||
MESSAGE(STATUS "MPI_LINK_FLAGS: ${MPI_LINK_FLAGS}")
|
||||
SET_TARGET_PROPERTIES(THD PROPERTIES LINK_FLAGS "${MPI_LINK_FLAGS}")
|
||||
ENDIF()
|
||||
ENDIF()
|
||||
|
||||
# TODO we shouldn't need the USE_CUDA condition here. See https://github.com/pytorch/pytorch/issues/13101
|
||||
IF(USE_GLOO)
|
||||
ADD_DEPENDENCIES(THD gloo)
|
||||
IF(USE_CUDA)
|
||||
ADD_DEPENDENCIES(THD gloo_cuda)
|
||||
ELSEIF(USE_ROCM)
|
||||
ADD_DEPENDENCIES(THD gloo_hip)
|
||||
ENDIF()
|
||||
ENDIF()
|
||||
|
||||
IF(USE_NCCL)
|
||||
ADD_DEFINITIONS(-DUSE_DISTRIBUTED_NCCL=1)
|
||||
TARGET_LINK_LIBRARIES(THD PUBLIC __caffe2_nccl)
|
||||
ENDIF()
|
||||
|
||||
# Test executables
|
||||
IF(THD_WITH_TESTS)
|
||||
ENABLE_TESTING()
|
||||
FIND_PACKAGE(Threads)
|
||||
FOREACH(test_source_file ${test_cpp})
|
||||
# Prepare test names
|
||||
GET_FILENAME_COMPONENT(test_source_file ${test_source_file} NAME)
|
||||
STRING(REPLACE ".cpp" "" test_name ${test_source_file})
|
||||
SET(test_executable_name "test_${test_name}")
|
||||
|
||||
ADD_EXECUTABLE(${test_executable_name} "test/${test_source_file}")
|
||||
TARGET_LINK_LIBRARIES(${test_executable_name} THD ${CAFFE2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
|
||||
SET_PROPERTY(TARGET ${test_executable_name} PROPERTY CXX_STANDARD 11)
|
||||
ADD_TEST(${test_name} ${test_executable_name})
|
||||
ENDFOREACH()
|
||||
ENDIF()
|
||||
|
||||
INSTALL(TARGETS THD DESTINATION ${THD_INSTALL_LIB_DIR})
|
||||
|
||||
FOREACH(HEADER ${all_h})
|
||||
STRING(REGEX MATCH "(.*)[/\\]" DIR ${HEADER})
|
||||
INSTALL(FILES ${HEADER} DESTINATION ${THD_INSTALL_INCLUDE_DIR}/THD/${DIR})
|
||||
ENDFOREACH()
@@ -1,20 +0,0 @@
#pragma once
|
||||
|
||||
#ifdef __cplusplus
|
||||
#define THD_API extern "C"
|
||||
#else
|
||||
#define THD_API
|
||||
#endif
|
||||
|
||||
#ifndef _THD_CORE
|
||||
#include <THD/base/DataChannelRequest.h>
|
||||
#include <THD/base/TensorDescriptor.h>
|
||||
#else
|
||||
#include <THD/base/DataChannelRequest.hpp>
|
||||
#include <THD/base/TensorDescriptor.hpp>
|
||||
#endif
|
||||
#include <THD/base/ChannelType.h>
|
||||
#include <THD/base/Cuda.h>
|
||||
|
||||
#include <THD/process_group/Collectives.h>
|
||||
#include <THD/process_group/General.h>
|
||||
|
|
@@ -1,8 +0,0 @@
#pragma once
|
||||
|
||||
enum THDChannelType {
|
||||
THDChannelTCP = 0,
|
||||
THDChannelMPI,
|
||||
THDChannelGloo,
|
||||
THDChannelNccl
|
||||
};
|
||||
|
|
@@ -1,281 +0,0 @@
#include <THD/base/ChannelUtils.hpp>
|
||||
|
||||
#include <arpa/inet.h>
|
||||
#include <fcntl.h>
|
||||
#include <netdb.h>
|
||||
#include <netinet/in.h>
|
||||
#include <netinet/tcp.h>
|
||||
#include <sys/poll.h>
|
||||
#include <unistd.h>
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
|
||||
namespace thd {
|
||||
namespace {
|
||||
|
||||
constexpr int LISTEN_QUEUE_SIZE = 1024;
|
||||
|
||||
void setSocketNoDelay(int socket) {
|
||||
int flag = 1;
|
||||
socklen_t optlen = sizeof(flag);
|
||||
SYSCHECK(setsockopt(socket, IPPROTO_TCP, TCP_NODELAY, (char*)&flag, optlen));
|
||||
}
|
||||
|
||||
port_type getSocketPort(int fd) {
|
||||
port_type listen_port;
|
||||
struct sockaddr_storage addr_storage;
|
||||
socklen_t addr_len = sizeof(addr_storage);
|
||||
SYSCHECK(getsockname(
|
||||
fd, reinterpret_cast<struct sockaddr*>(&addr_storage), &addr_len));
|
||||
if (addr_storage.ss_family == AF_INET) {
|
||||
struct sockaddr_in* addr =
|
||||
reinterpret_cast<struct sockaddr_in*>(&addr_storage);
|
||||
listen_port = ntohs(addr->sin_port);
|
||||
} else if (addr_storage.ss_family == AF_INET6) { // AF_INET6
|
||||
struct sockaddr_in6* addr =
|
||||
reinterpret_cast<struct sockaddr_in6*>(&addr_storage);
|
||||
listen_port = ntohs(addr->sin6_port);
|
||||
} else {
|
||||
throw std::runtime_error("unsupported protocol");
|
||||
}
|
||||
return listen_port;
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
std::pair<std::string, std::string> splitAddress(const std::string& addr) {
|
||||
std::string host, port;
|
||||
auto num_colons = std::count(addr.begin(), addr.end(), ':');
|
||||
if (num_colons > 1) {
|
||||
// IPv6
|
||||
auto end_pos = addr.find(']');
|
||||
if (addr[0] != '[' || end_pos == std::string::npos) {
|
||||
throw std::invalid_argument(
|
||||
"IPv6 address in an incorrect format (maybe you forgot to add [ ])");
|
||||
}
|
||||
host = addr.substr(1, end_pos - 1);
|
||||
port = addr.substr(end_pos + 2);
|
||||
} else if (num_colons == 1) {
|
||||
// IPv4 or HOSTNAME:PORT
|
||||
auto sep_pos = addr.find(':');
|
||||
host = addr.substr(0, sep_pos);
|
||||
port = addr.substr(sep_pos + 1);
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"expected an address in format IP:PORT or HOSTNAME:PORT");
|
||||
}
|
||||
if (addr == "" || port == "") {
|
||||
throw std::invalid_argument("expected an address in format IP:PORT");
|
||||
}
|
||||
return std::make_pair(host, port);
|
||||
}
|
||||
|
||||
std::string sockaddrToString(struct sockaddr* addr) {
|
||||
char address[INET6_ADDRSTRLEN + 1];
|
||||
if (addr->sa_family == AF_INET) {
|
||||
struct sockaddr_in* s = reinterpret_cast<struct sockaddr_in*>(addr);
|
||||
SYSCHECK(::inet_ntop(AF_INET, &(s->sin_addr), address, INET_ADDRSTRLEN))
|
||||
address[INET_ADDRSTRLEN] = '\0';
|
||||
} else if (addr->sa_family == AF_INET6) {
|
||||
struct sockaddr_in6* s = reinterpret_cast<struct sockaddr_in6*>(addr);
|
||||
SYSCHECK(::inet_ntop(AF_INET6, &(s->sin6_addr), address, INET6_ADDRSTRLEN))
|
||||
address[INET6_ADDRSTRLEN] = '\0';
|
||||
} else {
|
||||
throw std::runtime_error("unsupported protocol");
|
||||
}
|
||||
return address;
|
||||
}
|
||||
|
||||
std::pair<int, port_type> listen(port_type port) {
|
||||
struct addrinfo hints, *res = NULL;
|
||||
|
||||
std::memset(&hints, 0x00, sizeof(hints));
|
||||
hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG;
|
||||
hints.ai_family = AF_UNSPEC; // either IPv4 or IPv6
|
||||
hints.ai_socktype = SOCK_STREAM; // TCP
|
||||
|
||||
// `getaddrinfo` will sort addresses according to RFC 3484 and can be tweaked
// by editing `/etc/gai.conf`, so there is no need for manual sorting
// or protocol preference.
|
||||
int err = ::getaddrinfo(nullptr, std::to_string(port).data(), &hints, &res);
|
||||
if (err != 0 || !res) {
|
||||
throw std::invalid_argument(
|
||||
"cannot find host to listen on: " + std::string(gai_strerror(err)));
|
||||
}
|
||||
|
||||
std::shared_ptr<struct addrinfo> addresses(
|
||||
res, [](struct addrinfo* p) { ::freeaddrinfo(p); });
|
||||
|
||||
struct addrinfo* next_addr = addresses.get();
|
||||
int socket;
|
||||
while (true) {
|
||||
try {
|
||||
SYSCHECK(
|
||||
socket = ::socket(
|
||||
next_addr->ai_family,
|
||||
next_addr->ai_socktype,
|
||||
next_addr->ai_protocol))
|
||||
|
||||
int optval = 1;
|
||||
SYSCHECK(
|
||||
::setsockopt(socket, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(int)))
|
||||
SYSCHECK(::bind(socket, next_addr->ai_addr, next_addr->ai_addrlen))
|
||||
SYSCHECK(::listen(socket, LISTEN_QUEUE_SIZE))
|
||||
break;
|
||||
} catch (const std::system_error& e) {
|
||||
::close(socket);
|
||||
next_addr = next_addr->ai_next;
|
||||
|
||||
// we have tried all addresses but could not start listening on any of
|
||||
// them
|
||||
if (!next_addr) {
|
||||
throw;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// get listen port and address
|
||||
return {socket, getSocketPort(socket)};
|
||||
}
|
||||
|
||||
int connect(
|
||||
const std::string& address,
|
||||
port_type port,
|
||||
bool wait,
|
||||
int timeout) {
|
||||
struct addrinfo hints, *res = NULL;
|
||||
|
||||
std::memset(&hints, 0x00, sizeof(hints));
|
||||
hints.ai_flags = AI_NUMERICSERV; // specifies that port (service) is numeric
|
||||
hints.ai_family = AF_UNSPEC; // either IPv4 or IPv6
|
||||
hints.ai_socktype = SOCK_STREAM; // TCP
|
||||
|
||||
// `getaddrinfo` will sort addresses according to RFC 3484 and can be tweaked
// by editing `/etc/gai.conf`, so there is no need for manual sorting
// or protocol preference.
|
||||
int err =
|
||||
::getaddrinfo(address.data(), std::to_string(port).data(), &hints, &res);
|
||||
if (err != 0 || !res) {
|
||||
throw std::invalid_argument(
|
||||
"host not found: " + std::string(gai_strerror(err)));
|
||||
}
|
||||
|
||||
std::shared_ptr<struct addrinfo> addresses(
|
||||
res, [](struct addrinfo* p) { ::freeaddrinfo(p); });
|
||||
|
||||
struct addrinfo* next_addr = addresses.get();
|
||||
int socket;
|
||||
// We'll loop over the addresses only if at least one of them gave us
// ECONNREFUSED. Maybe the host was up, but the server wasn't running.
|
||||
bool any_refused = false;
|
||||
while (true) {
|
||||
try {
|
||||
SYSCHECK(
|
||||
socket = ::socket(
|
||||
next_addr->ai_family,
|
||||
next_addr->ai_socktype,
|
||||
next_addr->ai_protocol))
|
||||
ResourceGuard socket_guard([socket]() { ::close(socket); });
|
||||
|
||||
// We need to connect in non-blocking mode, so we can use a timeout
|
||||
SYSCHECK(::fcntl(socket, F_SETFL, O_NONBLOCK));
|
||||
|
||||
int ret = ::connect(socket, next_addr->ai_addr, next_addr->ai_addrlen);
|
||||
if (ret != 0 && errno != EINPROGRESS)
|
||||
throw std::system_error(errno, std::system_category());
|
||||
|
||||
struct pollfd pfd;
|
||||
pfd.fd = socket;
|
||||
pfd.events = POLLOUT;
|
||||
|
||||
int num_ready = ::poll(&pfd, 1, timeout);
|
||||
if (num_ready < 0) {
|
||||
throw std::system_error(errno, std::system_category());
|
||||
} else if (num_ready == 0) {
|
||||
errno = 0;
|
||||
throw std::runtime_error("connect() timed out");
|
||||
}
|
||||
|
||||
socklen_t err_len = sizeof(errno);
|
||||
errno = 0;
|
||||
::getsockopt(socket, SOL_SOCKET, SO_ERROR, &errno, &err_len);
|
||||
/* `errno` is set when:
|
||||
* 1. `getsockopt` has failed
|
||||
* 2. there is an error pending on the socket (the error is saved to the
|
||||
* `errno` variable)
|
||||
*/
|
||||
if (errno != 0) {
|
||||
throw std::system_error(errno, std::system_category());
|
||||
}
|
||||
|
||||
// Disable non-blocking mode
|
||||
int flags;
|
||||
SYSCHECK(flags = ::fcntl(socket, F_GETFL));
|
||||
SYSCHECK(::fcntl(socket, F_SETFL, flags & (~O_NONBLOCK)));
|
||||
socket_guard.release();
|
||||
break;
|
||||
} catch (std::exception& e) {
|
||||
if (errno == ECONNREFUSED)
|
||||
any_refused = true;
|
||||
|
||||
// We need to move to the next address because this was not available
|
||||
// to connect or to create a socket.
|
||||
next_addr = next_addr->ai_next;
|
||||
|
||||
// We have tried all addresses but could not connect to any of them.
|
||||
if (!next_addr) {
|
||||
if (!wait || !any_refused)
|
||||
throw;
|
||||
std::this_thread::sleep_for(std::chrono::seconds(1));
|
||||
any_refused = false;
|
||||
next_addr = addresses.get();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
setSocketNoDelay(socket);
|
||||
|
||||
return socket;
|
||||
}
|
||||
|
||||
std::tuple<int, std::string> accept(int listen_socket, int timeout) {
|
||||
// Poll on the listen socket; this lets us enforce a timeout.
|
||||
std::unique_ptr<struct pollfd[]> events(new struct pollfd[1]);
|
||||
events[0] = {.fd = listen_socket, .events = POLLIN};
|
||||
|
||||
while (true) {
|
||||
int res = ::poll(events.get(), 1, timeout);
|
||||
if (res == 0) {
|
||||
throw std::runtime_error(
|
||||
"waiting for processes to connect has timed out");
|
||||
} else if (res == -1) {
|
||||
if (errno == EINTR) {
|
||||
continue;
|
||||
}
|
||||
throw std::system_error(errno, std::system_category());
|
||||
} else {
|
||||
if (!(events[0].revents & POLLIN))
|
||||
throw std::system_error(ECONNABORTED, std::system_category());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
int socket;
|
||||
SYSCHECK(socket = ::accept(listen_socket, NULL, NULL))
|
||||
|
||||
// Get address of the connecting process
|
||||
struct sockaddr_storage addr;
|
||||
socklen_t addr_len = sizeof(addr);
|
||||
SYSCHECK(::getpeername(
|
||||
socket, reinterpret_cast<struct sockaddr*>(&addr), &addr_len))
|
||||
|
||||
setSocketNoDelay(socket);
|
||||
|
||||
return std::make_tuple(
|
||||
socket, sockaddrToString(reinterpret_cast<struct sockaddr*>(&addr)));
|
||||
}
|
||||
|
||||
} // namespace thd
|
||||
|
|
@@ -1,224 +0,0 @@
#pragma once
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <sys/socket.h>
|
||||
#include <sys/types.h>
|
||||
#include <cstdint>
|
||||
#include <cstdlib>
|
||||
#include <functional>
|
||||
#include <limits>
|
||||
#include <string>
|
||||
#include <system_error>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
inline void hash_combine(size_t& seed) {}
|
||||
|
||||
template <typename T, typename... Rest>
|
||||
inline void hash_combine(size_t& seed, const T& v, Rest... rest) {
|
||||
std::hash<T> hasher;
|
||||
seed ^= hasher(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
|
||||
hash_combine(seed, rest...);
|
||||
}
|
||||
|
||||
#define MAKE_HASHABLE(type, ...) \
|
||||
namespace std { \
|
||||
template <> \
|
||||
struct hash<type> { \
|
||||
size_t operator()(const type& t) const { \
|
||||
size_t ret = 0; \
|
||||
hash_combine(ret, __VA_ARGS__); \
|
||||
return ret; \
|
||||
} \
|
||||
}; \
|
||||
}
|
||||
|
||||
namespace thd {
|
||||
|
||||
enum class CollectiveType : std::uint8_t {
|
||||
ALL_GATHER = 0,
|
||||
GATHER,
|
||||
SCATTER,
|
||||
ALL_REDUCE,
|
||||
REDUCE,
|
||||
BROADCAST,
|
||||
SEND,
|
||||
BARRIER,
|
||||
LAST
|
||||
};
|
||||
|
||||
enum class DeviceType : std::uint8_t { CPU, CUDA, LAST };
|
||||
|
||||
inline DeviceType getDeviceType(at::Tensor& tensor) {
|
||||
return tensor.type().is_cuda() ? DeviceType::CUDA : DeviceType::CPU;
|
||||
}
|
||||
|
||||
} // namespace thd
|
||||
|
||||
MAKE_HASHABLE(::thd::CollectiveType, static_cast<std::uint8_t>(t));
|
||||
MAKE_HASHABLE(::thd::DeviceType, static_cast<std::uint8_t>(t));
|
||||
|
||||
namespace thd {
|
||||
|
||||
using rank_type = uint32_t;
|
||||
using port_type = uint16_t;
|
||||
using size_type = uint64_t;
|
||||
|
||||
#define SYSCHECK(expr) \
|
||||
{ \
|
||||
errno = 0; \
|
||||
auto ___output = (expr); \
|
||||
(void)___output; \
|
||||
if (errno != 0) \
|
||||
throw std::system_error(errno, std::system_category()); \
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void send_bytes(
|
||||
int socket,
|
||||
const T* buffer,
|
||||
size_t length,
|
||||
bool more_data = false) {
|
||||
size_t bytes_to_send = sizeof(T) * length;
|
||||
if (bytes_to_send == 0)
|
||||
return;
|
||||
|
||||
auto bytes = reinterpret_cast<const std::uint8_t*>(buffer);
|
||||
std::uint8_t* current_bytes = const_cast<std::uint8_t*>(bytes);
|
||||
|
||||
int flags = 0;
|
||||
#ifdef MSG_MORE
|
||||
if (more_data) { // there is more data to send
|
||||
flags |= MSG_MORE;
|
||||
}
|
||||
#endif
|
||||
|
||||
while (bytes_to_send > 0) {
|
||||
ssize_t bytes_sent;
|
||||
SYSCHECK(bytes_sent = ::send(socket, current_bytes, bytes_to_send, flags))
|
||||
if (bytes_sent == 0)
|
||||
throw std::system_error(ECONNRESET, std::system_category());
|
||||
|
||||
bytes_to_send -= bytes_sent;
|
||||
current_bytes += bytes_sent;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void recv_bytes(int socket, T* buffer, size_t length) {
|
||||
size_t bytes_to_receive = sizeof(T) * length;
|
||||
if (bytes_to_receive == 0)
|
||||
return;
|
||||
|
||||
auto bytes = reinterpret_cast<std::uint8_t*>(buffer);
|
||||
std::uint8_t* current_bytes = bytes;
|
||||
|
||||
while (bytes_to_receive > 0) {
|
||||
ssize_t bytes_received;
|
||||
SYSCHECK(
|
||||
bytes_received = ::recv(socket, current_bytes, bytes_to_receive, 0))
|
||||
if (bytes_received == 0)
|
||||
throw std::system_error(ECONNRESET, std::system_category());
|
||||
|
||||
bytes_to_receive -= bytes_received;
|
||||
current_bytes += bytes_received;
|
||||
}
|
||||
}
|
||||
|
||||
inline port_type convertToPort(int64_t port) {
|
||||
if ((port < 0) || (port >= std::numeric_limits<port_type>::max()))
|
||||
throw std::domain_error("invalid port (value out of range)");
|
||||
|
||||
return static_cast<port_type>(port);
|
||||
}
|
||||
|
||||
inline rank_type convertToRank(int64_t rank, int64_t min = 0) {
|
||||
if ((rank < min) || (rank >= std::numeric_limits<rank_type>::max()))
|
||||
throw std::domain_error("invalid rank (value out of range)");
|
||||
|
||||
return static_cast<rank_type>(rank);
|
||||
}
|
||||
|
||||
std::pair<int, port_type> listen(port_type port = 0);
|
||||
int connect(
|
||||
const std::string& address,
|
||||
port_type port,
|
||||
bool wait = true,
|
||||
int timeout = -1);
|
||||
std::tuple<int, std::string> accept(int listen_socket, int timeout = -1);
|
||||
|
||||
std::string sockaddrToString(struct sockaddr* addr);
|
||||
std::pair<std::string, std::string> splitAddress(const std::string& addr);
|
||||
|
||||
/* send a string's length and data */
|
||||
inline void send_string(
|
||||
int socket,
|
||||
const std::string& str,
|
||||
bool more_data = false) {
|
||||
size_type size = str.size();
|
||||
send_bytes<size_type>(socket, &size, 1, true);
|
||||
send_bytes<char>(socket, str.data(), size, more_data);
|
||||
}
|
||||
|
||||
/* receive a string as sent in send_string */
|
||||
inline std::string recv_string(int socket) {
|
||||
size_type value_size;
|
||||
recv_bytes<size_type>(socket, &value_size, 1);
|
||||
std::vector<char> value(value_size);
|
||||
recv_bytes<char>(socket, value.data(), value.size());
|
||||
return std::string(value.data(), value.size());
|
||||
}
|
||||
|
||||
/* send a vector's length and data */
|
||||
template <typename T>
|
||||
void send_vector(
|
||||
int socket,
|
||||
const std::vector<T>& vec,
|
||||
bool more_data = false) {
|
||||
size_type size = vec.size();
|
||||
send_bytes<size_type>(socket, &size, 1, true);
|
||||
send_bytes<T>(socket, vec.data(), size, more_data);
|
||||
}
|
||||
|
||||
/* receive a vector as sent in send_vector */
|
||||
template <typename T>
|
||||
std::vector<T> recv_vector(int socket) {
|
||||
size_type value_size;
|
||||
recv_bytes<size_type>(socket, &value_size, 1);
|
||||
std::vector<char> value(value_size);
|
||||
recv_bytes<char>(socket, value.data(), value.size());
|
||||
return value;
|
||||
}
|
||||
|
||||
/* this is only for convenience when sending rvalues */
|
||||
template <typename T>
|
||||
void send_value(int socket, const T& value, bool more_data = false) {
|
||||
send_bytes<T>(socket, &value, 1, more_data);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T recv_value(int socket) {
|
||||
T value;
|
||||
recv_bytes<T>(socket, &value, 1);
|
||||
return value;
|
||||
}
|
||||
|
||||
class ResourceGuard {
|
||||
std::function<void()> _destructor;
|
||||
bool _released;
|
||||
|
||||
public:
|
||||
ResourceGuard(std::function<void()> destructor)
|
||||
: _destructor(std::move(destructor)), _released(false) {}
|
||||
|
||||
~ResourceGuard() {
|
||||
if (!_released)
|
||||
_destructor();
|
||||
}
|
||||
|
||||
void release() {
|
||||
_released = true;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace thd
@@ -1,28 +0,0 @@
#include <THD/base/Cuda.hpp>
|
||||
#include <unordered_map>
|
||||
|
||||
#ifdef USE_CUDA
|
||||
THCState** _THDCudaState;
|
||||
|
||||
void THDSetCudaStatePtr(THCState** state) {
|
||||
_THDCudaState = state;
|
||||
}
|
||||
|
||||
static int nextStreamId = 1; // 0 for the default stream
|
||||
static std::unordered_map<cudaStream_t, int> streamIdMap;
|
||||
|
||||
void THDRegisterCudaStream(cudaStream_t stream) {
|
||||
streamIdMap.emplace(stream, nextStreamId++);
|
||||
}
|
||||
|
||||
int THDGetStreamId(cudaStream_t stream) {
|
||||
if (!stream)
|
||||
return 0;
|
||||
auto it = streamIdMap.find(stream);
|
||||
if (it == streamIdMap.end()) {
|
||||
throw std::runtime_error(
|
||||
"using a stream that's hasn't been registered in THD");
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
#endif
|
||||
|
|
@@ -1,10 +0,0 @@
#pragma once
|
||||
|
||||
#ifdef USE_CUDA
|
||||
#include <THD/THD.h>
|
||||
|
||||
#include <THC/THC.h>
|
||||
|
||||
THD_API void THDSetCudaStatePtr(THCState** state);
|
||||
THD_API void THDRegisterCudaStream(cudaStream_t stream);
|
||||
#endif
|
||||
|
|
@@ -1,14 +0,0 @@
#pragma once
|
||||
|
||||
#ifdef USE_CUDA
|
||||
#include <THC/THC.h>
|
||||
#include <THD/base/Cuda.h>
|
||||
|
||||
extern THCState** _THDCudaState;
|
||||
|
||||
inline THCState* THDGetCudaState() {
|
||||
return *_THDCudaState;
|
||||
}
|
||||
|
||||
int THDGetStreamId(cudaStream_t stream);
|
||||
#endif
|
||||
|
|
@@ -1,134 +0,0 @@
#include <THD/base/DataChannel.hpp>
|
||||
#ifdef WITH_GLOO
|
||||
#include <THD/base/data_channels/DataChannelGloo.hpp>
|
||||
#endif // WITH_GLOO
|
||||
#ifdef WITH_MPI
|
||||
#include <THD/base/data_channels/DataChannelMPI.hpp>
|
||||
#endif // WITH_MPI
|
||||
#if defined(USE_CUDA) && defined(USE_DISTRIBUTED_NCCL)
|
||||
#include <THD/base/data_channels/DataChannelNccl.hpp>
|
||||
#endif // USE_DISTRIBUTED_NCCL
|
||||
#include <THD/base/data_channels/DataChannelTCP.hpp>
|
||||
|
||||
#include <algorithm>
|
||||
#include <stdexcept>
|
||||
#include <tuple>
|
||||
|
||||
namespace thd {
|
||||
|
||||
#define GET_CONFIG getInitConfig(init_method, world_size, group_name, rank)
|
||||
DataChannel* DataChannel::newChannel(
|
||||
THDChannelType type,
|
||||
std::string init_method,
|
||||
int world_size,
|
||||
std::string group_name,
|
||||
int rank) {
|
||||
switch (type) {
|
||||
case THDChannelTCP:
|
||||
return new DataChannelTCP(GET_CONFIG);
|
||||
|
||||
case THDChannelMPI:
|
||||
#ifdef WITH_MPI
|
||||
return new DataChannelMPI();
|
||||
#endif // WITH_MPI
|
||||
throw std::runtime_error(
|
||||
"the MPI backend is not available; "
|
||||
"try to recompile the THD package with MPI support");
|
||||
|
||||
case THDChannelGloo:
|
||||
#ifdef WITH_GLOO
|
||||
return new DataChannelGloo(GET_CONFIG);
|
||||
#endif // WITH_GLOO
|
||||
throw std::runtime_error(
|
||||
"the Gloo backend is not available; "
|
||||
"try to recompile the THD package with Gloo support");
|
||||
|
||||
case THDChannelNccl:
|
||||
#if defined(USE_CUDA) && defined(USE_DISTRIBUTED_NCCL)
|
||||
return new DataChannelNccl(GET_CONFIG);
|
||||
#endif
|
||||
throw std::runtime_error(
|
||||
"the distributed NCCL backend is not available; "
|
||||
"try to recompile the THD package with CUDA and NCCL 2+ support");
|
||||
|
||||
default:
|
||||
throw std::runtime_error("unsupported data channel type");
|
||||
}
|
||||
}
|
||||
#undef GET_CONFIG
|
||||
|
||||
DataChannel::Group::Group() {}
|
||||
|
||||
DataChannel::Group::Group(std::vector<rank_type> ranks, rank_type max_rank) {
|
||||
if (ranks.size() == 0)
|
||||
throw std::logic_error("cannot create empty group");
|
||||
|
||||
sort(ranks.begin(), ranks.end());
|
||||
if (ranks.back() > max_rank) {
|
||||
throw std::out_of_range(
|
||||
"array of ranks contains invalid rank, "
|
||||
"all ranks should be in range: [0, " +
|
||||
std::to_string(max_rank) + "]");
|
||||
}
|
||||
|
||||
_new2old.reserve(ranks.size());
|
||||
for (size_t i = 0; i < ranks.size(); ++i) {
|
||||
_new2old.push_back(ranks[i]);
|
||||
_old2new.insert({ranks[i], i});
|
||||
}
|
||||
}
|
||||
|
||||
DataChannel::Group::~Group() {}
|
||||
|
||||
auto DataChannel::Group::size() const -> rank_type {
|
||||
return static_cast<rank_type>(_new2old.size());
|
||||
}
|
||||
|
||||
auto DataChannel::Group::mustGetGroupRank(rank_type global_rank) const
|
||||
-> rank_type {
|
||||
rank_type group_rank;
|
||||
bool exists;
|
||||
std::tie(group_rank, exists) = getGroupRank(global_rank);
|
||||
|
||||
if (!exists) {
|
||||
throw std::logic_error(
|
||||
"rank(" + std::to_string(global_rank) + ") is not member of group");
|
||||
}
|
||||
|
||||
return group_rank;
|
||||
}
|
||||
|
||||
auto DataChannel::Group::getGroupRank(rank_type global_rank) const
|
||||
-> std::pair<rank_type, bool> {
|
||||
auto global_rank_it = _old2new.find(global_rank); // O(1) operation
|
||||
if (global_rank_it != _old2new.end())
|
||||
return std::make_pair(global_rank_it->second, true);
|
||||
|
||||
return std::make_pair(0, false);
|
||||
}
|
||||
|
||||
auto DataChannel::Group::mustGetGlobalRank(rank_type group_rank) const
|
||||
-> rank_type {
|
||||
rank_type global_rank;
|
||||
bool exists;
|
||||
std::tie(global_rank, exists) = getGlobalRank(group_rank);
|
||||
|
||||
if (!exists) {
|
||||
throw std::logic_error(
|
||||
"group rank is invalid, rank should be in "
|
||||
"range: [0, " +
|
||||
std::to_string(_new2old.size() - 1) + "]");
|
||||
}
|
||||
|
||||
return global_rank;
|
||||
}
|
||||
|
||||
auto DataChannel::Group::getGlobalRank(rank_type group_rank) const
|
||||
-> std::pair<rank_type, bool> {
|
||||
if (group_rank >= _new2old.size())
|
||||
return std::make_pair(0, false);
|
||||
|
||||
return std::make_pair(_new2old[group_rank], true);
|
||||
}
|
||||
|
||||
} // namespace thd
|
||||
|
|
@@ -1,11 +0,0 @@
#pragma once
|
||||
|
||||
enum THDReduceOp {
|
||||
THDReduceMIN = 0,
|
||||
THDReduceMAX,
|
||||
THDReduceSUM,
|
||||
THDReducePRODUCT,
|
||||
};
|
||||
|
||||
typedef int THDGroup;
|
||||
const THDGroup THDGroupWORLD = 0;
|
||||
|
|
@@ -1,168 +0,0 @@
#pragma once
|
||||
|
||||
#include <THD/base/ChannelType.h>
|
||||
#include <THD/base/ChannelUtils.hpp>
|
||||
#include <THD/base/DataChannel.h>
|
||||
#include <THD/base/Scalar.hpp>
|
||||
#include <THD/base/init_methods/InitMethod.hpp>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
MAKE_HASHABLE(THDReduceOp, static_cast<int>(t));
|
||||
MAKE_HASHABLE(thd::RPCType, static_cast<char>(t));
|
||||
MAKE_HASHABLE(at::ScalarType, static_cast<int>(t));
|
||||
|
||||
namespace thd {
|
||||
|
||||
struct DataChannel {
|
||||
struct Request {
|
||||
Request(){};
|
||||
virtual ~Request(){};
|
||||
|
||||
// Checks if request has completed. Non-blocking operation.
|
||||
virtual bool isCompleted() = 0;
|
||||
// Waits until request completes. Blocking operation.
|
||||
virtual void wait() = 0;
|
||||
};
|
||||
|
||||
struct Group {
|
||||
Group();
|
||||
/*
|
||||
* Constructs `Group` from provided `ranks` and checks if all ranks are
|
||||
* in range: [0, `max_rank`].
|
||||
*
|
||||
* The `ranks` vector should map new (group) ranks to old (global) ranks,
* e.g. ranks = {[0] = 6, [1] = 2} means that 0 and 1 are the new ranks in
* the group and 6 and 2 are the corresponding global ranks.
|
||||
*/
|
||||
Group(std::vector<rank_type> ranks, rank_type max_rank);
|
||||
virtual ~Group();
|
||||
|
||||
rank_type size() const;
|
||||
|
||||
/*
|
||||
* In contrast to `getGroupRank`, this function throws `std::logic_error`
* when the given rank is not a member of this group.
|
||||
*/
|
||||
rank_type mustGetGroupRank(rank_type global_rank) const;
|
||||
std::pair<rank_type, bool> getGroupRank(rank_type global_rank) const;
|
||||
|
||||
/*
|
||||
* In contrast to `getGlobalRank` this function throws `std::logic_error`
|
||||
* when provided `group_rank` is not in range of group.
|
||||
*/
|
||||
rank_type mustGetGlobalRank(rank_type group_rank) const;
|
||||
std::pair<rank_type, bool> getGlobalRank(rank_type group_rank) const;
|
||||
|
||||
private:
|
||||
// maps new group ranks to old ranks (global ranks)
|
||||
std::vector<rank_type> _new2old;
|
||||
|
||||
// maps old ranks (global ranks) to new group ranks
|
||||
std::unordered_map<rank_type, rank_type> _old2new;
|
||||
};
|
||||
|
||||
DataChannel(){};
|
||||
virtual ~DataChannel(){};
|
||||
|
||||
virtual bool init() = 0;
|
||||
|
||||
/**
|
||||
* This is required for the NCCL backend: DataChannel is a static object, so
* its destructor would only run after the CUDA runtime has been unloaded;
* destruction therefore has to happen explicitly through this call.
|
||||
*/
|
||||
virtual void destroy() = 0;
|
||||
|
||||
virtual rank_type getRank() = 0;
|
||||
virtual rank_type getNumProcesses() = 0;
|
||||
|
||||
/**
|
||||
* All-gathers inputs from multiple GPUs; each Tensor in the input vector
* should live on a separate GPU.
|
||||
*
|
||||
* Also note that the output vector is a 1D vector (flattened from 2D),
|
||||
* with the size of input.size() * world_size.
|
||||
*
|
||||
* For instance, rank i 's input[k] tensor would be in
|
||||
* output[i * input.size() + k].
|
||||
*/
|
||||
virtual void allGather(
|
||||
std::vector<at::Tensor>& output,
|
||||
std::vector<at::Tensor>& input,
|
||||
THDGroup groupId = THDGroupWORLD) = 0;
|
||||
virtual void allGather(
|
||||
std::vector<at::Tensor>& output,
|
||||
at::Tensor& input,
|
||||
THDGroup group_id = THDGroupWORLD) = 0;
|
||||
virtual void gather(
|
||||
std::vector<at::Tensor>& output,
|
||||
at::Tensor& input,
|
||||
rank_type dst_rank,
|
||||
THDGroup group_id = THDGroupWORLD) = 0;
|
||||
virtual void scatter(
|
||||
std::vector<at::Tensor>& input,
|
||||
at::Tensor& output,
|
||||
rank_type src_rank,
|
||||
THDGroup group_id = THDGroupWORLD) = 0;
|
||||
// All reduce multiple GPUs on a number of nodes
|
||||
virtual void allReduce(
|
||||
std::vector<at::Tensor>& data,
|
||||
THDReduceOp operation,
|
||||
THDGroup group_id = THDGroupWORLD) = 0;
|
||||
virtual void allReduce(
|
||||
at::Tensor& data,
|
||||
THDReduceOp operation,
|
||||
THDGroup group_id = THDGroupWORLD) = 0;
|
||||
/**
|
||||
* Reduce multiple GPUs on a number of nodes
|
||||
* data[0]'s GPU in dstRank will receive the result
|
||||
*/
|
||||
virtual void reduce(
|
||||
std::vector<at::Tensor>& data,
|
||||
THDReduceOp operation,
|
||||
rank_type dstRank,
|
||||
THDGroup groupId = THDGroupWORLD) = 0;
|
||||
virtual void reduce(
|
||||
at::Tensor& data,
|
||||
THDReduceOp operation,
|
||||
rank_type dst_rank,
|
||||
THDGroup group_id = THDGroupWORLD) = 0;
|
||||
/**
|
||||
* Broadcast multiple GPUs on a number of nodes
|
||||
* data[0]'s GPU in srcRank will be the source to broadcast
|
||||
*/
|
||||
virtual void broadcast(
|
||||
std::vector<at::Tensor>& data,
|
||||
rank_type srcRank,
|
||||
THDGroup groupId = THDGroupWORLD) = 0;
|
||||
virtual void broadcast(
|
||||
at::Tensor& data,
|
||||
rank_type src_rank,
|
||||
THDGroup group_id = THDGroupWORLD) = 0;
|
||||
virtual void send(Scalar& value, rank_type src_rank) = 0;
|
||||
virtual void send(at::Tensor& data, rank_type dst_rank) = 0;
|
||||
virtual void receive(Scalar& value, rank_type src_rank) = 0;
|
||||
virtual rank_type receive(at::Tensor& data) = 0; // receive from any source
|
||||
virtual void receive(at::Tensor& data, rank_type src_rank) = 0;
|
||||
virtual Request* isend(at::Tensor& data, rank_type dst_rank) = 0;
|
||||
virtual Request* ireceive(at::Tensor& data, rank_type src_rank) = 0;
|
||||
|
||||
virtual void barrier(THDGroup group_id = THDGroupWORLD) = 0;
|
||||
|
||||
virtual THDGroup newGroup(const std::vector<rank_type>& ranks) = 0;
|
||||
virtual void clearGroupCache(THDGroup group_id = THDGroupWORLD) = 0;
|
||||
|
||||
static DataChannel* newChannel(
|
||||
THDChannelType type,
|
||||
std::string init_method,
|
||||
int world_size,
|
||||
std::string group_name,
|
||||
int rank);
|
||||
};
|
||||
|
||||
} // namespace thd
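The doc comment on the multi-GPU `allGather` overload in the header above describes the output as a flattened 1D vector of size input.size() * world_size, with rank i's input[k] stored at output[i * input.size() + k]. A small sketch of just that index arithmetic, assuming a hypothetical world size and per-rank input count:

#include <cstddef>
#include <cstdio>

int main() {
  const size_t world_size = 3;       // hypothetical number of processes
  const size_t inputs_per_rank = 2;  // tensors per rank (one per local GPU)

  // Where does rank i's k-th input end up in the flattened output vector?
  for (size_t i = 0; i < world_size; ++i) {
    for (size_t k = 0; k < inputs_per_rank; ++k) {
      std::printf("rank %zu, input %zu -> output[%zu]\n",
                  i, k, i * inputs_per_rank + k);
    }
  }
  return 0;
}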
@ -1,5 +0,0 @@
#include <THD/base/DataChannelRequest.hpp>

THD_API void THDRequest_free(void* request) {
  delete (THDRequest*)request;
}
@ -1,10 +0,0 @@
#pragma once

#include <THD/THD.h>

#ifndef _THD_CORE
struct _THDRequest;
typedef struct _THDRequest THDRequest;
#endif

THD_API void THDRequest_free(void* req);
@ -1,6 +0,0 @@
#pragma once

#include <THD/base/DataChannel.hpp>
using THDRequest = thd::DataChannel::Request;

#include <THD/base/DataChannelRequest.h>
@ -1,10 +0,0 @@
#pragma once

#include <iostream>

#define HANDLE_EXCEPTIONS try {
#define END_HANDLE_EXCEPTIONS \
  } \
  catch (std::exception & e) { \
    THError(e.what()); \
  }
@ -1,26 +0,0 @@
#include <THD/base/RPCType.hpp>

namespace thd {

// Static constexpr variables have to be defined out-of-source in C++11.
// https://stackoverflow.com/questions/8016780/undefined-reference-to-static-constexpr-char
constexpr RPCType type_traits<char>::type;
constexpr RPCType type_traits<int8_t>::type;
constexpr RPCType type_traits<uint8_t>::type;
constexpr RPCType type_traits<float>::type;
constexpr RPCType type_traits<double>::type;
constexpr RPCType type_traits<int16_t>::type;
constexpr RPCType type_traits<int32_t>::type;
constexpr RPCType type_traits<uint32_t>::type;
constexpr RPCType type_traits<uint16_t>::type;
constexpr RPCType type_traits<int64_t>::type;
constexpr RPCType type_traits<uint64_t>::type;
constexpr RPCType type_traits<
    std::conditional<std::is_same<int64_t, long>::value, long long, long>::
        type>::type;
constexpr RPCType type_traits<std::conditional<
    std::is_same<uint64_t, unsigned long>::value,
    unsigned long long,
    unsigned long>::type>::type;

} // namespace thd
@ -1,191 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace thd {
|
||||
|
||||
/*
|
||||
* The following notation comes from:
|
||||
* docs.python.org/3.5/library/struct.html#module-struct
|
||||
* except from 'T', which stands for Tensor
|
||||
*/
|
||||
|
||||
enum class RPCType : char {
|
||||
CHAR = 'c',
|
||||
UCHAR = 'B',
|
||||
FLOAT = 'f',
|
||||
DOUBLE = 'd',
|
||||
HALF = 'a',
|
||||
SHORT = 'h',
|
||||
USHORT = 'H',
|
||||
INT = 'i',
|
||||
UINT = 'I',
|
||||
LONG = 'l',
|
||||
ULONG = 'L',
|
||||
LONG_LONG = 'q',
|
||||
ULONG_LONG = 'Q',
|
||||
LONG_STORAGE = 'X',
|
||||
TENSOR = 'T',
|
||||
STORAGE = 'S',
|
||||
GENERATOR = 'G',
|
||||
};
|
||||
|
||||
inline bool isFloat(RPCType t) {
|
||||
return (t == RPCType::FLOAT || t == RPCType::DOUBLE || t == RPCType::HALF);
|
||||
}
|
||||
|
||||
inline bool isInteger(RPCType t) {
|
||||
return (
|
||||
t == RPCType::CHAR || t == RPCType::UCHAR || t == RPCType::SHORT ||
|
||||
t == RPCType::USHORT || t == RPCType::INT || t == RPCType::UINT ||
|
||||
t == RPCType::LONG || t == RPCType::ULONG || t == RPCType::LONG_LONG ||
|
||||
t == RPCType::ULONG_LONG);
|
||||
}
|
||||
|
||||
inline const char* toString(RPCType t) {
|
||||
switch (t) {
|
||||
case RPCType::CHAR:
|
||||
return "Char";
|
||||
case RPCType::UCHAR:
|
||||
return "Byte";
|
||||
case RPCType::FLOAT:
|
||||
return "Float";
|
||||
case RPCType::DOUBLE:
|
||||
return "Double";
|
||||
case RPCType::HALF:
|
||||
return "Half";
|
||||
case RPCType::SHORT:
|
||||
return "Short";
|
||||
case RPCType::USHORT:
|
||||
return "UShort";
|
||||
case RPCType::INT:
|
||||
return "Int";
|
||||
case RPCType::UINT:
|
||||
return "UInt";
|
||||
case RPCType::LONG:
|
||||
return "Long";
|
||||
case RPCType::ULONG:
|
||||
return "ULong";
|
||||
case RPCType::LONG_LONG:
|
||||
return "LongLong";
|
||||
case RPCType::ULONG_LONG:
|
||||
return "ULongLong";
|
||||
case RPCType::LONG_STORAGE:
|
||||
return "LongStorage";
|
||||
case RPCType::TENSOR:
|
||||
return "Tensor";
|
||||
case RPCType::STORAGE:
|
||||
return "Storage";
|
||||
default:
|
||||
return "<unknown>";
|
||||
}
|
||||
}
|
||||
|
||||
inline bool isObject(RPCType t) {
|
||||
return (
|
||||
t == RPCType::TENSOR || t == RPCType::STORAGE || t == RPCType::GENERATOR);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct type_traits {};
|
||||
|
||||
// NOTE: The `type` static constexpr variables of these specializations are
|
||||
// additionally defined in RPCType.cpp to avoid undefined
|
||||
// reference errors in C++11.
|
||||
template <>
|
||||
struct type_traits<char> {
|
||||
static constexpr RPCType type = RPCType::CHAR;
|
||||
static constexpr bool is_floating_point = false;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct type_traits<int8_t> {
|
||||
static constexpr RPCType type = RPCType::CHAR;
|
||||
static constexpr bool is_floating_point = false;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct type_traits<uint8_t> {
|
||||
static constexpr RPCType type = RPCType::UCHAR;
|
||||
static constexpr bool is_floating_point = false;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct type_traits<float> {
|
||||
static constexpr RPCType type = RPCType::FLOAT;
|
||||
static constexpr bool is_floating_point = true;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct type_traits<double> {
|
||||
static constexpr RPCType type = RPCType::DOUBLE;
|
||||
static constexpr bool is_floating_point = true;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct type_traits<int16_t> {
|
||||
static constexpr RPCType type = RPCType::SHORT;
|
||||
static constexpr bool is_floating_point = false;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct type_traits<uint16_t> {
|
||||
static constexpr RPCType type = RPCType::USHORT;
|
||||
static constexpr bool is_floating_point = false;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct type_traits<int32_t> {
|
||||
static constexpr RPCType type = RPCType::INT;
|
||||
static constexpr bool is_floating_point = false;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct type_traits<uint32_t> {
|
||||
static constexpr RPCType type = RPCType::UINT;
|
||||
static constexpr bool is_floating_point = false;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct type_traits<int64_t> {
|
||||
static constexpr RPCType type =
|
||||
std::is_same<int64_t, long>::value ? RPCType::LONG : RPCType::LONG_LONG;
|
||||
static constexpr bool is_floating_point = false;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct type_traits<uint64_t> {
|
||||
static constexpr RPCType type = std::is_same<uint64_t, unsigned long>::value
|
||||
? RPCType::ULONG
|
||||
: RPCType::ULONG_LONG;
|
||||
static constexpr bool is_floating_point = false;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct type_traits<
|
||||
std::conditional<std::is_same<int64_t, long>::value, long long, long>::
|
||||
type> {
|
||||
static constexpr RPCType type =
|
||||
std::is_same<int64_t, long>::value ? RPCType::LONG_LONG : RPCType::LONG;
|
||||
static constexpr bool is_floating_point = false;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct type_traits<std::conditional<
|
||||
std::is_same<uint64_t, unsigned long>::value,
|
||||
unsigned long long,
|
||||
unsigned long>::type> {
|
||||
static constexpr RPCType type = std::is_same<uint64_t, unsigned long>::value
|
||||
? RPCType::ULONG_LONG
|
||||
: RPCType::ULONG;
|
||||
static constexpr bool is_floating_point = false;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct type_traits<const T> : type_traits<T> {};
|
||||
|
||||
} // namespace thd
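The header above maps C++ scalar types to single-character `RPCType` tags (following Python's struct-module notation) through `type_traits<T>::type`. A self-contained miniature of that lookup, re-declaring only a handful of the enum values for illustration:

#include <cstdint>
#include <cstdio>

// Minimal re-declaration for illustration; the real definitions live in
// THD/base/RPCType.hpp shown above.
enum class RPCType : char { CHAR = 'c', FLOAT = 'f', DOUBLE = 'd', LONG_LONG = 'q' };

template <typename T> struct type_traits;
template <> struct type_traits<float>  { static constexpr RPCType type = RPCType::FLOAT; };
template <> struct type_traits<double> { static constexpr RPCType type = RPCType::DOUBLE; };
template <> struct type_traits<int8_t> { static constexpr RPCType type = RPCType::CHAR; };

int main() {
  std::printf("float -> '%c'\n", static_cast<char>(type_traits<float>::type));   // prints 'f'
  std::printf("int8  -> '%c'\n", static_cast<char>(type_traits<int8_t>::type));  // prints 'c'
  return 0;
}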
@ -1,59 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include <cstddef>
|
||||
|
||||
#include <THD/base/RPCType.hpp>
|
||||
|
||||
namespace thd {
|
||||
|
||||
struct Scalar {
|
||||
Scalar() {}
|
||||
Scalar(const Scalar& other) = delete;
|
||||
Scalar(Scalar&& other) = delete;
|
||||
virtual ~Scalar() {}
|
||||
|
||||
virtual size_t elementSize() const = 0;
|
||||
virtual void* data() = 0;
|
||||
virtual const void* data() const = 0;
|
||||
virtual RPCType type() const = 0;
|
||||
virtual Scalar* clone() const = 0;
|
||||
};
|
||||
|
||||
template <typename real>
|
||||
struct ScalarWrapper : Scalar {
|
||||
ScalarWrapper() {}
|
||||
ScalarWrapper(real value) : _value(value) {}
|
||||
virtual ~ScalarWrapper() {}
|
||||
|
||||
virtual size_t elementSize() const override {
|
||||
return sizeof(real);
|
||||
}
|
||||
|
||||
virtual void* data() override {
|
||||
return &_value;
|
||||
}
|
||||
|
||||
virtual const void* data() const override {
|
||||
return &_value;
|
||||
}
|
||||
|
||||
virtual RPCType type() const override {
|
||||
return type_traits<real>::type;
|
||||
}
|
||||
|
||||
virtual ScalarWrapper* clone() const override {
|
||||
return new ScalarWrapper(value());
|
||||
}
|
||||
|
||||
real value() const {
|
||||
return _value;
|
||||
}
|
||||
|
||||
private:
|
||||
real _value;
|
||||
};
|
||||
|
||||
using FloatScalar = ScalarWrapper<double>;
|
||||
using IntScalar = ScalarWrapper<int64_t>;
|
||||
|
||||
} // namespace thd
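`ScalarWrapper<real>` above is a small type-erasing box: it stores one value and exposes its element size, its `RPCType`, and a pointer to the data; `FloatScalar` and `IntScalar` are the two aliases used elsewhere. A self-contained miniature of the same shape (not the THD types themselves):

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Miniature of thd::ScalarWrapper for illustration only.
struct Scalar {
  virtual ~Scalar() = default;
  virtual size_t elementSize() const = 0;
  virtual const void* data() const = 0;
};

template <typename real>
struct ScalarWrapper : Scalar {
  explicit ScalarWrapper(real value) : _value(value) {}
  size_t elementSize() const override { return sizeof(real); }
  const void* data() const override { return &_value; }
  real value() const { return _value; }
 private:
  real _value;
};

int main() {
  ScalarWrapper<int64_t> s(42);  // corresponds to thd::IntScalar
  std::printf("size=%zu value=%lld\n", s.elementSize(),
              static_cast<long long>(s.value()));
  return 0;
}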
@ -1,96 +0,0 @@
|
|||
#ifndef THD_GENERIC_FILE
|
||||
#error "You must define THD_GENERIC_FILE before including THDGenerateAllTypes.h"
|
||||
#endif
|
||||
|
||||
#define real uint8_t
|
||||
#define accreal int64_t
|
||||
#define Real Byte
|
||||
#define THDInf UCHAR_MAX
|
||||
#define THD_REAL_IS_BYTE
|
||||
#line 1 THD_GENERIC_FILE
|
||||
#include THD_GENERIC_FILE
|
||||
#undef real
|
||||
#undef accreal
|
||||
#undef Real
|
||||
#undef THDInf
|
||||
#undef THD_REAL_IS_BYTE
|
||||
|
||||
#define real int8_t
|
||||
#define accreal int64_t
|
||||
#define Real Char
|
||||
#define THDInf SCHAR_MAX
|
||||
#define THD_REAL_IS_CHAR
|
||||
#line 1 THD_GENERIC_FILE
|
||||
#include THD_GENERIC_FILE
|
||||
#undef real
|
||||
#undef accreal
|
||||
#undef Real
|
||||
#undef THDInf
|
||||
#undef THD_REAL_IS_CHAR
|
||||
|
||||
#define real int16_t
|
||||
#define accreal int64_t
|
||||
#define Real Short
|
||||
#define THDInf SHRT_MAX
|
||||
#define THD_REAL_IS_SHORT
|
||||
#line 1 THD_GENERIC_FILE
|
||||
#include THD_GENERIC_FILE
|
||||
#undef real
|
||||
#undef accreal
|
||||
#undef Real
|
||||
#undef THDInf
|
||||
#undef THD_REAL_IS_SHORT
|
||||
|
||||
#define real int32_t
|
||||
#define accreal int64_t
|
||||
#define Real Int
|
||||
#define THDInf INT_MAX
|
||||
#define THD_REAL_IS_INT
|
||||
#line 1 THD_GENERIC_FILE
|
||||
#include THD_GENERIC_FILE
|
||||
#undef real
|
||||
#undef accreal
|
||||
#undef Real
|
||||
#undef THDInf
|
||||
#undef THD_REAL_IS_INT
|
||||
|
||||
#define real int64_t
|
||||
#define accreal int64_t
|
||||
#define Real Long
|
||||
#define THDInf LONG_MAX
|
||||
#define THD_REAL_IS_LONG
|
||||
#line 1 THD_GENERIC_FILE
|
||||
#include THD_GENERIC_FILE
|
||||
#undef real
|
||||
#undef accreal
|
||||
#undef Real
|
||||
#undef THDInf
|
||||
#undef THD_REAL_IS_LONG
|
||||
|
||||
#define real float
|
||||
#define accreal double
|
||||
#define Real Float
|
||||
#define THDInf FLT_MAX
|
||||
#define THD_REAL_IS_FLOAT
|
||||
#line 1 THD_GENERIC_FILE
|
||||
#include THD_GENERIC_FILE
|
||||
#undef real
|
||||
#undef accreal
|
||||
#undef Real
|
||||
#undef THDInf
|
||||
#undef THD_REAL_IS_FLOAT
|
||||
|
||||
#define real double
|
||||
#define accreal double
|
||||
#define Real Double
|
||||
#define THDInf DBL_MAX
|
||||
#define THD_REAL_IS_DOUBLE
|
||||
#line 1 THD_GENERIC_FILE
|
||||
#include THD_GENERIC_FILE
|
||||
#undef real
|
||||
#undef accreal
|
||||
#undef Real
|
||||
#undef THDInf
|
||||
#undef THD_REAL_IS_DOUBLE
|
||||
|
||||
#undef THD_GENERIC_FILE
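THDGenerateAllTypes.h is the TH-style "generic file" trick: the includer defines THD_GENERIC_FILE, and this header re-includes it once per scalar type with `real`, `accreal`, `Real`, and `THDInf` redefined each time. The following is a single-file sketch of the same pattern under assumed names (the generic body is a macro here instead of a separately included file):

// Sketch of the TH "generic file" pattern; names are illustrative only.
#include <cstdint>
#include <cstdio>

// The "generic body": stamped out once per (Real, real) pair.
#define GENERIC_BODY(Real, real) \
  static void print##Real(real v) { std::printf(#Real ": %f\n", (double)v); }

// Expand the generic body once per type, mirroring the #define/#include/#undef
// dance in THDGenerateAllTypes.h.
GENERIC_BODY(Float, float)
GENERIC_BODY(Double, double)
GENERIC_BODY(Int, int32_t)

int main() {
  printFloat(1.5f);
  printDouble(2.5);
  printInt(3);
  return 0;
}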
@ -1,12 +0,0 @@
#pragma once

#include <TH/TH.h>
#include <THD/THD.h>
#ifdef USE_CUDA
#include <THC/THC.h>
#endif

#ifndef _THD_CORE
#include <ATen/ATen.h>
using THDTensorDescriptor = at::Tensor;
#endif
@ -1,6 +0,0 @@
#pragma once

#include <ATen/ATen.h>
using THDTensorDescriptor = at::Tensor;

#include <THD/base/TensorDescriptor.h>
@ -1,404 +0,0 @@
|
|||
#include <THD/base/data_channels/DataChannelGloo.hpp>
|
||||
#include <THD/base/data_channels/DataChannelUtils.hpp>
|
||||
#include <THD/base/data_channels/GlooCache.hpp>
|
||||
#include <THD/base/data_channels/Store.hpp>
|
||||
|
||||
#if defined(USE_GLOO_IBVERBS) && USE_GLOO_IBVERBS
|
||||
#include <gloo/transport/ibverbs/device.h>
|
||||
#endif
|
||||
|
||||
#include <gloo/transport/tcp/device.h>
|
||||
|
||||
#include <unistd.h>
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <memory>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
#define RETURN_IF_NOT_IN_GROUP \
|
||||
{ \
|
||||
bool exists; \
|
||||
std::tie(std::ignore, exists) = _groups.at(group_id).getGroupRank(_rank); \
|
||||
if (!exists) \
|
||||
return; \
|
||||
}
|
||||
|
||||
// TODO: gloo uses stdint types for integral values and there's some weird
|
||||
// template magic going on that mangles names so that they don't always match
|
||||
// the types below. Only float and double are left enabled for now, because
|
||||
// they're most useful and unambiguous.
|
||||
#define GENERATE_ALL_TYPES(type, func, args...) \
|
||||
switch (type) { \
|
||||
case ::at::ScalarType::Float: \
|
||||
func<float>(args); \
|
||||
break; \
|
||||
case ::at::ScalarType::Double: \
|
||||
func<double>(args); \
|
||||
break; \
|
||||
case ::at::ScalarType::Half: \
|
||||
func<gloo::float16>(args); \
|
||||
break; \
|
||||
case ::at::ScalarType::Char: \
|
||||
func<int8_t>(args); \
|
||||
break; \
|
||||
case ::at::ScalarType::Byte: \
|
||||
func<uint8_t>(args); \
|
||||
break; \
|
||||
case ::at::ScalarType::Int: \
|
||||
func<int32_t>(args); \
|
||||
break; \
|
||||
case ::at::ScalarType::Long: \
|
||||
func<int64_t>(args); \
|
||||
break; \
|
||||
default: \
|
||||
throw std::runtime_error( \
|
||||
"Invalid " + std::string(#func) + " function type"); \
|
||||
}
|
||||
|
||||
namespace thd {
|
||||
|
||||
DataChannelGloo::RequestGloo::RequestGloo(QueueWorker::Request&& request)
|
||||
: _request(std::move(request)) {}
|
||||
|
||||
DataChannelGloo::RequestGloo::~RequestGloo() {}
|
||||
|
||||
bool DataChannelGloo::RequestGloo::isCompleted() {
|
||||
return _request.isCompleted();
|
||||
}
|
||||
|
||||
void DataChannelGloo::RequestGloo::wait() {
|
||||
_request.wait();
|
||||
}
|
||||
|
||||
DataChannelGloo::Group::Group(
|
||||
const std::string& addr,
|
||||
port_type port,
|
||||
std::vector<rank_type> ranks,
|
||||
rank_type max_rank,
|
||||
int store_socket)
|
||||
: DataChannel::Group(std::move(ranks), max_rank),
|
||||
_store(new Store(addr, port, store_socket)) {}
|
||||
|
||||
DataChannelGloo::DataChannelGloo(InitMethod::Config config)
|
||||
: _rank(config.rank), _listen_socket(-1), _cache(nullptr) {
|
||||
_num_processes = config.world_size;
|
||||
|
||||
#if defined(USE_GLOO_IBVERBS) && USE_GLOO_IBVERBS
|
||||
|
||||
// This helper function automatically detects the IB device in the system
|
||||
auto ibDeviceNames = ::gloo::transport::ibverbs::getDeviceNames();
|
||||
|
||||
// If there are IB devices, we will use IB
|
||||
if (!ibDeviceNames.empty()) {
|
||||
// Currently, gloo only supports a single IB device and will use the first
|
||||
auto ibDeviceToUse = ibDeviceNames[0];
|
||||
|
||||
::gloo::transport::ibverbs::attr attr = {
|
||||
.name = ibDeviceToUse,
|
||||
.port = 1,
|
||||
.index = 0,
|
||||
};
|
||||
|
||||
_deviceList.push_back(::gloo::transport::ibverbs::CreateDevice(attr));
|
||||
|
||||
// Otherwise, fallback to use TCP instead
|
||||
} else
|
||||
#endif
|
||||
|
||||
{
|
||||
// Default options listen on this host's name.
|
||||
// NOTE: if the hostname is misconfigured in `/etc/hosts`, processes
// will not be able to connect to each other.
|
||||
::gloo::transport::tcp::attr attr(config.public_address.c_str());
|
||||
_deviceList.push_back(::gloo::transport::tcp::CreateDevice(attr));
|
||||
}
|
||||
|
||||
if (_rank == 0) {
|
||||
_addr = "localhost";
|
||||
_port = config.master.listen_port;
|
||||
_listen_socket = config.master.listen_socket;
|
||||
} else {
|
||||
_addr = config.worker.master_addr;
|
||||
_port = config.worker.master_port;
|
||||
}
|
||||
}
|
||||
|
||||
DataChannelGloo::~DataChannelGloo() {
|
||||
if (_listen_socket != -1) {
|
||||
::close(_listen_socket);
|
||||
}
|
||||
}
|
||||
|
||||
void DataChannelGloo::destroy() {}
|
||||
|
||||
bool DataChannelGloo::init() {
|
||||
_cache = std::unique_ptr<GlooCache>(new GlooCache(_rank, _deviceList));
|
||||
|
||||
std::vector<rank_type> ranks;
|
||||
ranks.reserve(_num_processes);
|
||||
for (rank_type rank = 0; rank < _num_processes; ++rank)
|
||||
ranks.push_back(rank);
|
||||
|
||||
_groups.insert({THDGroupWORLD,
|
||||
Group(
|
||||
_addr,
|
||||
_port,
|
||||
ranks,
|
||||
_num_processes - 1,
|
||||
_rank == 0 ? _listen_socket : Store::CLIENT_ONLY)});
|
||||
return true;
|
||||
}
|
||||
|
||||
rank_type DataChannelGloo::getRank() {
|
||||
return _rank;
|
||||
}
|
||||
|
||||
rank_type DataChannelGloo::getNumProcesses() {
|
||||
return _num_processes;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DataChannelGloo::allGatherT(
|
||||
std::vector<at::Tensor>& output,
|
||||
at::Tensor& input,
|
||||
THDGroup group_id) {
|
||||
auto input_device = getDeviceType(input);
|
||||
for (auto& out : output) {
|
||||
if (input_device != getDeviceType(out)) {
|
||||
throw std::runtime_error(
|
||||
"allGather got input and output on different devices");
|
||||
}
|
||||
}
|
||||
uint64_t tensor_bytes = input.element_size() * input.numel();
|
||||
uint64_t all_tensor_bytes = tensor_bytes * output.size();
|
||||
auto ret = _cache->getAlgorithm<CollectiveType::ALL_GATHER, T>(
|
||||
group_id,
|
||||
_groups.at(group_id),
|
||||
input_device,
|
||||
tensor_bytes,
|
||||
all_tensor_bytes,
|
||||
input.numel());
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(*GlooCache::mutex(ret));
|
||||
std::memcpy(
|
||||
GlooCache::input_buffer(ret).get(), input.data_ptr(), tensor_bytes);
|
||||
GlooCache::algorithm(ret)->run();
|
||||
for (size_t i = 0; i < output.size(); i++) {
|
||||
std::memcpy(
|
||||
output.at(i).data_ptr(),
|
||||
GlooCache::output_buffer(ret).get() + (i * tensor_bytes),
|
||||
tensor_bytes);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void DataChannelGloo::allGather(
|
||||
std::vector<at::Tensor>& output,
|
||||
at::Tensor& input,
|
||||
THDGroup group_id) {
|
||||
RETURN_IF_NOT_IN_GROUP
|
||||
|
||||
if (output.size() != _groups.at(group_id).size())
|
||||
throw std::logic_error(
|
||||
"allGather: number of output tensors and group size does not match");
|
||||
|
||||
for (auto out_tensor : output)
|
||||
assertSameSizeAndType(out_tensor, input, "allGather");
|
||||
|
||||
GENERATE_ALL_TYPES(
|
||||
input.scalar_type(), allGatherT, output, input, group_id)
|
||||
}
|
||||
|
||||
// XXX: `gather` is not supported by Gloo yet.
|
||||
void DataChannelGloo::gather(
|
||||
std::vector<at::Tensor>& output,
|
||||
at::Tensor& input,
|
||||
rank_type dst_rank,
|
||||
THDGroup group_id) {
|
||||
throw std::runtime_error("DataChannelGloo doesn't support gather");
|
||||
}
|
||||
|
||||
// XXX: `scatter` is not supported by Gloo yet.
|
||||
void DataChannelGloo::scatter(
|
||||
std::vector<at::Tensor>& input,
|
||||
at::Tensor& output,
|
||||
rank_type src_rank,
|
||||
THDGroup group_id) {
|
||||
throw std::runtime_error("DataChannelGloo does not support scatter");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DataChannelGloo::allReduceT(
|
||||
at::Tensor& t,
|
||||
THDReduceOp operation,
|
||||
THDGroup group_id) {
|
||||
uint64_t tensor_bytes = t.element_size() * t.numel();
|
||||
auto ret = _cache->getAlgorithm<CollectiveType::ALL_REDUCE, T>(
|
||||
group_id,
|
||||
_groups.at(group_id),
|
||||
getDeviceType(t),
|
||||
tensor_bytes,
|
||||
t.numel(),
|
||||
operation);
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(*GlooCache::mutex(ret));
|
||||
GlooCache::memcpy_input(ret, t);
|
||||
GlooCache::algorithm(ret)->run();
|
||||
GlooCache::memcpy_output(ret, t);
|
||||
}
|
||||
}
|
||||
|
||||
void DataChannelGloo::allReduce(
|
||||
at::Tensor& data,
|
||||
THDReduceOp operation,
|
||||
THDGroup group_id) {
|
||||
RETURN_IF_NOT_IN_GROUP
|
||||
GENERATE_ALL_TYPES(
|
||||
data.scalar_type(), allReduceT, data, operation, group_id)
|
||||
}
|
||||
|
||||
// XXX: `reduce` is not supported by Gloo yet.
|
||||
void DataChannelGloo::reduce(
|
||||
at::Tensor& data,
|
||||
THDReduceOp operation,
|
||||
rank_type dst_rank,
|
||||
THDGroup group_id) {
|
||||
throw std::runtime_error("DataChannelGloo does not support reduce");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DataChannelGloo::broadcastT(
|
||||
at::Tensor& data,
|
||||
rank_type src_rank,
|
||||
THDGroup group_id) {
|
||||
uint64_t tensor_bytes = data.element_size() * data.numel();
|
||||
auto ret = _cache->getAlgorithm<CollectiveType::BROADCAST, T>(
|
||||
group_id,
|
||||
_groups.at(group_id),
|
||||
getDeviceType(data),
|
||||
tensor_bytes,
|
||||
data.numel(),
|
||||
_groups.at(group_id).mustGetGroupRank(src_rank));
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(*GlooCache::mutex(ret));
|
||||
if (_rank == src_rank) {
|
||||
GlooCache::memcpy_input(ret, data);
|
||||
}
|
||||
|
||||
GlooCache::algorithm(ret)->run();
|
||||
|
||||
if (_rank != src_rank) {
|
||||
GlooCache::memcpy_output(ret, data);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void DataChannelGloo::broadcast(
|
||||
at::Tensor& data,
|
||||
rank_type src_rank,
|
||||
THDGroup group_id) {
|
||||
RETURN_IF_NOT_IN_GROUP
|
||||
GENERATE_ALL_TYPES(
|
||||
data.scalar_type(), broadcastT, data, src_rank, group_id)
|
||||
}
|
||||
|
||||
void DataChannelGloo::send(Scalar& data, rank_type dst_rank) {
|
||||
throw std::runtime_error("DataChannelGloo does not support send");
|
||||
}
|
||||
|
||||
void DataChannelGloo::send(at::Tensor& data, rank_type dst_rank) {
|
||||
throw std::runtime_error("DataChannelGloo does not support send");
|
||||
}
|
||||
|
||||
void DataChannelGloo::receive(Scalar& data, rank_type src_rank) {
|
||||
throw std::runtime_error("DataChannelGloo does not support receive");
|
||||
}
|
||||
|
||||
rank_type DataChannelGloo::receive(at::Tensor& data) {
|
||||
throw std::runtime_error(
|
||||
"DataChannelGloo does not support receive from any source");
|
||||
}
|
||||
|
||||
void DataChannelGloo::receive(at::Tensor& data, rank_type src_rank) {
|
||||
throw std::runtime_error("DataChannelGloo does not support receive");
|
||||
}
|
||||
|
||||
auto DataChannelGloo::isend(at::Tensor& data, rank_type dst_rank)
|
||||
-> RequestGloo* {
|
||||
throw std::runtime_error("DataChannelGloo does not support isend");
|
||||
}
|
||||
|
||||
auto DataChannelGloo::ireceive(at::Tensor& data, rank_type src_rank)
|
||||
-> RequestGloo* {
|
||||
throw std::runtime_error("DataChannelGloo does not support ireceive");
|
||||
}
|
||||
|
||||
void DataChannelGloo::allReduce(
|
||||
std::vector<at::Tensor>& data,
|
||||
THDReduceOp operation,
|
||||
THDGroup groupId) {
|
||||
throw std::runtime_error(
|
||||
"DataChannelGloo does not support mult-GPU cross "
|
||||
"node allreduce");
|
||||
}
|
||||
|
||||
void DataChannelGloo::allGather(
|
||||
std::vector<at::Tensor>& output,
|
||||
std::vector<at::Tensor>& input,
|
||||
THDGroup groupId) {
|
||||
throw std::runtime_error(
|
||||
"DataChannelGloo does not support mult-GPU cross "
|
||||
"node allgather");
|
||||
}
|
||||
|
||||
void DataChannelGloo::reduce(
|
||||
std::vector<at::Tensor>& data,
|
||||
THDReduceOp operation,
|
||||
rank_type dstRank,
|
||||
THDGroup groupId) {
|
||||
throw std::runtime_error(
|
||||
"DataChannelGloo does not support mult-GPU cross "
|
||||
"node reduce");
|
||||
}
|
||||
|
||||
void DataChannelGloo::broadcast(
|
||||
std::vector<at::Tensor>& data,
|
||||
rank_type srcRank,
|
||||
THDGroup groupId) {
|
||||
throw std::runtime_error(
|
||||
"DataChannelGloo does not support mult-GPU cross "
|
||||
"node broadcast");
|
||||
}
|
||||
|
||||
void DataChannelGloo::clearGroupCache(THDGroup group_id) {
|
||||
throw std::runtime_error(
|
||||
"DataChannelGloo does not support clear "
|
||||
"group cache");
|
||||
}
|
||||
|
||||
void DataChannelGloo::barrier(THDGroup group_id) {
|
||||
RETURN_IF_NOT_IN_GROUP
|
||||
auto ret = _cache->getAlgorithm<CollectiveType::BARRIER, void>(
|
||||
group_id, _groups.at(group_id));
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(*GlooCache::mutex(ret));
|
||||
GlooCache::algorithm(ret)->run();
|
||||
}
|
||||
}
|
||||
|
||||
THDGroup DataChannelGloo::newGroup(const std::vector<rank_type>& ranks) {
|
||||
auto new_group = DataChannelGloo::Group(
|
||||
_addr, _port, ranks, _num_processes - 1, Store::CLIENT_ONLY);
|
||||
THDGroup new_group_id = static_cast<THDGroup>(_groups.size());
|
||||
|
||||
_groups.insert({new_group_id, new_group});
|
||||
return new_group_id;
|
||||
}
|
||||
|
||||
} // namespace thd
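The GENERATE_ALL_TYPES macro near the top of DataChannelGloo.cpp converts a runtime at::ScalarType tag into a call to a function template instantiated for the matching C++ type. A minimal sketch of that dispatch shape, using a stand-in enum rather than at::ScalarType so it compiles on its own:

#include <cstdio>
#include <stdexcept>

enum class ScalarType { Float, Double, Int };  // stand-in for at::ScalarType

template <typename T>
void printElementSize(const char* name) {
  std::printf("%s -> %zu bytes\n", name, sizeof(T));
}

// Same shape as GENERATE_ALL_TYPES: switch on the runtime tag and call the
// template with the matching static type.
#define DISPATCH_ALL_TYPES(type, func, ...)                         \
  switch (type) {                                                   \
    case ScalarType::Float:  func<float>(__VA_ARGS__);  break;      \
    case ScalarType::Double: func<double>(__VA_ARGS__); break;      \
    case ScalarType::Int:    func<int>(__VA_ARGS__);    break;      \
    default: throw std::runtime_error("unsupported type");          \
  }

int main() {
  DISPATCH_ALL_TYPES(ScalarType::Double, printElementSize, "Double");
  return 0;
}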
@ -1,150 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include <THD/base/ChannelUtils.hpp>
|
||||
#include <THD/base/DataChannel.hpp>
|
||||
#include <THD/base/data_channels/DataChannelUtils.hpp>
|
||||
|
||||
#include <gloo/rendezvous/store.h>
|
||||
#include <gloo/transport/device.h>
|
||||
|
||||
#include <map>
|
||||
|
||||
namespace thd {
|
||||
|
||||
struct GlooCache;
|
||||
|
||||
struct DataChannelGloo : DataChannel {
|
||||
using store_type = ::gloo::rendezvous::Store;
|
||||
|
||||
struct RequestGloo : DataChannel::Request {
|
||||
RequestGloo(QueueWorker::Request&& request);
|
||||
virtual ~RequestGloo();
|
||||
|
||||
virtual bool isCompleted() override;
|
||||
virtual void wait() override;
|
||||
|
||||
private:
|
||||
QueueWorker::Request _request;
|
||||
};
|
||||
|
||||
struct Group : DataChannel::Group {
|
||||
Group(
|
||||
const std::string& addr,
|
||||
port_type port,
|
||||
std::vector<rank_type> ranks,
|
||||
rank_type max_rank,
|
||||
int store_socket);
|
||||
|
||||
std::shared_ptr<store_type> _store;
|
||||
};
|
||||
|
||||
DataChannelGloo(InitMethod::Config config);
|
||||
DataChannelGloo(InitMethod::Config config, int timeout);
|
||||
virtual ~DataChannelGloo();
|
||||
|
||||
bool init() override;
|
||||
void destroy() override;
|
||||
|
||||
rank_type getRank() override;
|
||||
rank_type getNumProcesses() override;
|
||||
|
||||
void allGather(
|
||||
std::vector<at::Tensor>& output,
|
||||
std::vector<at::Tensor>& input,
|
||||
THDGroup group_id = THDGroupWORLD) override;
|
||||
void allGather(
|
||||
std::vector<at::Tensor>& output,
|
||||
at::Tensor& input,
|
||||
THDGroup group_id = THDGroupWORLD) override;
|
||||
void gather(
|
||||
std::vector<at::Tensor>& output,
|
||||
at::Tensor& input,
|
||||
rank_type dst_rank,
|
||||
THDGroup group_id = THDGroupWORLD) override;
|
||||
void scatter(
|
||||
std::vector<at::Tensor>& input,
|
||||
at::Tensor& output,
|
||||
rank_type src_rank,
|
||||
THDGroup group_id = THDGroupWORLD) override;
|
||||
void allReduce(
|
||||
std::vector<at::Tensor>& data,
|
||||
THDReduceOp operation,
|
||||
THDGroup group_id = THDGroupWORLD) override;
|
||||
void allReduce(
|
||||
at::Tensor& data,
|
||||
THDReduceOp operation,
|
||||
THDGroup group_id = THDGroupWORLD) override;
|
||||
void reduce(
|
||||
std::vector<at::Tensor>& data,
|
||||
THDReduceOp operation,
|
||||
rank_type dstRank,
|
||||
THDGroup group_id = THDGroupWORLD) override;
|
||||
void reduce(
|
||||
at::Tensor& data,
|
||||
THDReduceOp operation,
|
||||
rank_type dst_rank,
|
||||
THDGroup group_id = THDGroupWORLD) override;
|
||||
void broadcast(
|
||||
std::vector<at::Tensor>& data,
|
||||
rank_type srcRank,
|
||||
THDGroup group_id = THDGroupWORLD) override;
|
||||
void broadcast(
|
||||
at::Tensor& data,
|
||||
rank_type src_id,
|
||||
THDGroup group_id = THDGroupWORLD) override;
|
||||
void send(Scalar& data, rank_type dst_id) override;
|
||||
void send(at::Tensor& data, rank_type dst_id) override;
|
||||
void receive(Scalar& data, rank_type src_id) override;
|
||||
rank_type receive(at::Tensor& data) override;
|
||||
void receive(at::Tensor& data, rank_type src_id) override;
|
||||
RequestGloo* isend(at::Tensor& data, rank_type dst_rank) override;
|
||||
RequestGloo* ireceive(at::Tensor& data, rank_type src_rank) override;
|
||||
|
||||
void barrier(THDGroup group_id = THDGroupWORLD) override;
|
||||
|
||||
THDGroup newGroup(const std::vector<rank_type>& ranks) override;
|
||||
void clearGroupCache(THDGroup group_id = THDGroupWORLD) override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
void allGatherT(
|
||||
std::vector<at::Tensor>& output,
|
||||
at::Tensor& input,
|
||||
THDGroup group_id);
|
||||
|
||||
template <typename T>
|
||||
void allReduceT(
|
||||
at::Tensor& data,
|
||||
THDReduceOp operation,
|
||||
THDGroup group_id = THDGroupWORLD);
|
||||
|
||||
template <typename T>
|
||||
void broadcastT(
|
||||
at::Tensor& data,
|
||||
rank_type src_rank,
|
||||
THDGroup group_id = THDGroupWORLD);
|
||||
|
||||
rank_type _rank; // Current process' rank
|
||||
std::string _addr;
|
||||
port_type _port;
|
||||
rank_type _num_processes; // Number of processes in network
|
||||
/**
|
||||
* The list of network devices (such as Infiniband) that will be used by Gloo.
|
||||
* Currently Gloo only supports a single network device. Therefore:
|
||||
*
|
||||
* _deviceList.size() will always be less than or equal to 1.
|
||||
*
|
||||
* We make it a vector for the purpose of future extension to support multiple
|
||||
* network devices.
|
||||
*/
|
||||
std::vector<std::shared_ptr<::gloo::transport::Device>> _deviceList;
|
||||
std::unordered_map<THDGroup, Group> _groups;
|
||||
int _listen_socket;
|
||||
|
||||
std::unique_ptr<GlooCache> _cache;
|
||||
|
||||
// Workers
|
||||
QueueWorker _send_worker, _receive_worker;
|
||||
};
|
||||
|
||||
} // namespace thd
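As the DataChannelGloo constructor shows, `_deviceList` is filled with at most one transport: the first detected InfiniBand device when gloo was built with ibverbs support, otherwise a TCP device bound to the configured address. A sketch of that selection flow, with hypothetical stand-ins for the gloo device-creation calls:

#include <cstdio>
#include <string>
#include <vector>

// Hypothetical stand-ins for ::gloo::transport::{ibverbs,tcp}::CreateDevice.
std::string createIbDevice(const std::string& name) { return "ib:" + name; }
std::string createTcpDevice(const std::string& host) { return "tcp:" + host; }

int main() {
  std::vector<std::string> ibDeviceNames = {};  // pretend no IB devices were detected
  std::vector<std::string> deviceList;          // mirrors _deviceList (size <= 1)

  if (!ibDeviceNames.empty()) {
    deviceList.push_back(createIbDevice(ibDeviceNames[0]));  // first IB device only
  } else {
    deviceList.push_back(createTcpDevice("127.0.0.1"));      // TCP fallback
  }

  std::printf("selected transport: %s\n", deviceList.front().c_str());
  return 0;
}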
@ -1,542 +0,0 @@
|
|||
#include <THD/base/data_channels/DataChannelMPI.hpp>
|
||||
#include <THD/base/data_channels/DataChannelUtils.hpp>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
#ifdef USE_CUDA
|
||||
#include <cuda_runtime.h>
|
||||
#endif
|
||||
|
||||
namespace thd {
|
||||
|
||||
namespace {
|
||||
|
||||
std::unordered_map<THDReduceOp, MPI_Op> mpi_op = {
|
||||
{THDReduceOp::THDReduceMIN, MPI_MIN},
|
||||
{THDReduceOp::THDReduceMAX, MPI_MAX},
|
||||
{THDReduceOp::THDReduceSUM, MPI_SUM},
|
||||
{THDReduceOp::THDReducePRODUCT, MPI_PROD},
|
||||
};
|
||||
|
||||
std::unordered_map<at::ScalarType, MPI_Datatype> mpi_datatype = {
|
||||
{at::kByte, MPI_UNSIGNED_CHAR},
|
||||
{at::kChar, MPI_CHAR},
|
||||
{at::kDouble, MPI_DOUBLE},
|
||||
{at::kFloat, MPI_FLOAT},
|
||||
{at::kInt, MPI_INT},
|
||||
{at::kLong, MPI_LONG},
|
||||
{at::kShort, MPI_SHORT},
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
DataChannelMPI::RequestMPI::RequestMPI() {}
|
||||
|
||||
DataChannelMPI::RequestMPI::~RequestMPI() {
|
||||
for (auto& request : _requests) {
|
||||
if (request != MPI_REQUEST_NULL)
|
||||
MPI_Request_free(&request);
|
||||
}
|
||||
}
|
||||
|
||||
bool DataChannelMPI::RequestMPI::isCompleted() {
|
||||
int flag;
|
||||
MPI_Testall(_requests.size(), _requests.data(), &flag, MPI_STATUSES_IGNORE);
|
||||
return static_cast<bool>(flag);
|
||||
}
|
||||
|
||||
void DataChannelMPI::RequestMPI::wait() {
|
||||
MPI_Waitall(_requests.size(), _requests.data(), MPI_STATUSES_IGNORE);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DataChannelMPI::RequestMPI::save_buffer(std::shared_ptr<T> ptr) {
|
||||
_buffers.push_back(std::static_pointer_cast<void>(ptr));
|
||||
}
|
||||
|
||||
void DataChannelMPI::RequestMPI::save_tensor_buffer(at::Tensor& t) {
|
||||
_tensor_buffers.push_back(t);
|
||||
}
|
||||
|
||||
MPI_Request& DataChannelMPI::RequestMPI::new_request() {
|
||||
_requests.push_back(MPI_Request());
|
||||
return _requests.back();
|
||||
}
|
||||
|
||||
DataChannelMPI::DataChannelMPI() : _rank(-1), _num_processes(0) {}
|
||||
|
||||
DataChannelMPI::~DataChannelMPI() {
|
||||
for (auto& group : _groups) {
|
||||
auto comm = group.second.first;
|
||||
if (comm != MPI_COMM_WORLD && comm != MPI_COMM_NULL)
|
||||
MPI_Comm_free(&comm);
|
||||
}
|
||||
|
||||
MPI_Finalize();
|
||||
}
|
||||
|
||||
void DataChannelMPI::destroy() {}
|
||||
|
||||
bool DataChannelMPI::init() {
|
||||
#ifdef OMPI_MAJOR_VERSION
|
||||
// OMPI_* is specific to Openmpi implementation.
|
||||
// Openmpi v1.10 segfaults in MPI_Bcast with CUDA buffer.
|
||||
if (int(OMPI_MAJOR_VERSION) < 2) {
|
||||
throw std::runtime_error(
|
||||
"Please use Openmpi major version 2 and above for distributed.");
|
||||
}
|
||||
#endif /* OMPI_MAJOR_VERSION */
|
||||
|
||||
int provided;
|
||||
MPI_Init_thread(NULL, NULL, MPI_THREAD_MULTIPLE, &provided);
|
||||
if (provided != MPI_THREAD_MULTIPLE) {
|
||||
std::cerr
|
||||
<< "WARNING: Used MPI implementation doesn't support multithreading, "
|
||||
<< "so distributed functions might not work properly."
|
||||
<< "If you are using mpich, try setting environment MPICH_MAX_THREAD_SAFETY=multiple and rerun."
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
int rank, num_processes;
|
||||
MPI_Comm_size(MPI_COMM_WORLD, &num_processes);
|
||||
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
|
||||
|
||||
_rank = convertToRank(rank);
|
||||
_num_processes = convertToRank(num_processes);
|
||||
|
||||
std::vector<rank_type> ranks;
|
||||
ranks.reserve(_num_processes);
|
||||
for (rank_type rank = 0; rank < _num_processes; ++rank)
|
||||
ranks.push_back(rank);
|
||||
|
||||
_groups.insert(
|
||||
{THDGroupWORLD,
|
||||
std::make_pair(
|
||||
MPI_COMM_WORLD, DataChannel::Group(ranks, _num_processes - 1))});
|
||||
return true;
|
||||
}
|
||||
|
||||
rank_type DataChannelMPI::getRank() {
|
||||
return _rank;
|
||||
}
|
||||
|
||||
rank_type DataChannelMPI::getNumProcesses() {
|
||||
return _num_processes;
|
||||
}
|
||||
|
||||
at::Tensor DataChannelMPI::_newLikeFlat(
|
||||
std::vector<at::Tensor>& tensors) const {
|
||||
// TODO: check if all outputs are contiguous in memory and skip this step if
// they are
|
||||
if (tensors.size() == 0)
|
||||
throw std::runtime_error("received an empty list");
|
||||
auto& t = tensors[0];
|
||||
at::DeviceGuard gpu_guard(t.device());
|
||||
std::vector<int64_t> sizes{static_cast<int64_t>(
|
||||
tensors.size())}; // sizes = [output.size()] + input.sizes()
|
||||
sizes.insert(sizes.end(), t.sizes().begin(), t.sizes().end());
|
||||
return at::empty(sizes, t.options());
|
||||
}
|
||||
|
||||
void DataChannelMPI::allGather(
|
||||
std::vector<at::Tensor>& output,
|
||||
at::Tensor& input,
|
||||
THDGroup group_id) {
|
||||
const auto& group_pair = _groups.at(group_id);
|
||||
const auto& comm = group_pair.first;
|
||||
if (comm == MPI_COMM_NULL)
|
||||
return;
|
||||
|
||||
if (output.size() != group_pair.second.size())
|
||||
throw std::logic_error(
|
||||
"allGather: number of output tensors and group size does not match");
|
||||
|
||||
for (auto out_tensor : output)
|
||||
assertSameSizeAndType(out_tensor, input, "allGather");
|
||||
|
||||
auto recv_buffer = _newLikeFlat(output);
|
||||
auto contig_input = input.contiguous();
|
||||
|
||||
MPI_Allgather(
|
||||
contig_input.data_ptr(),
|
||||
contig_input.numel(),
|
||||
mpi_datatype.at(contig_input.scalar_type()),
|
||||
recv_buffer.data_ptr(),
|
||||
contig_input.numel(),
|
||||
mpi_datatype.at(recv_buffer.scalar_type()),
|
||||
comm);
|
||||
|
||||
for (size_t i = 0; i < output.size(); ++i)
|
||||
output[i].copy_(recv_buffer[i]);
|
||||
}
|
||||
|
||||
void DataChannelMPI::gather(
|
||||
std::vector<at::Tensor>& output,
|
||||
at::Tensor& input,
|
||||
rank_type dst_rank,
|
||||
THDGroup group_id) {
|
||||
const auto& group_pair = _groups.at(group_id);
|
||||
const auto& comm = group_pair.first;
|
||||
if (comm == MPI_COMM_NULL)
|
||||
return;
|
||||
|
||||
at::Tensor recv_buffer;
|
||||
void* recvbuf = nullptr;
|
||||
if (_rank != dst_rank) {
|
||||
if (output.size() > 0)
|
||||
throw std::logic_error(
|
||||
"gather: number of input tensors should be 0 for non root");
|
||||
} else {
|
||||
if (output.size() != group_pair.second.size())
|
||||
throw std::logic_error(
|
||||
"gather: number of output tensors and group size does not match");
|
||||
|
||||
for (auto out_tensor : output)
|
||||
assertSameSizeAndType(out_tensor, input, "gather");
|
||||
|
||||
recv_buffer = _newLikeFlat(output);
|
||||
recvbuf = recv_buffer.data_ptr();
|
||||
}
|
||||
|
||||
rank_type group_dst_rank = group_pair.second.mustGetGroupRank(dst_rank);
|
||||
auto contig_input = input.contiguous();
|
||||
|
||||
MPI_Gather(
|
||||
contig_input.data_ptr(),
|
||||
input.numel(),
|
||||
mpi_datatype.at(input.scalar_type()),
|
||||
recvbuf,
|
||||
input.numel(),
|
||||
mpi_datatype.at(input.scalar_type()),
|
||||
group_dst_rank,
|
||||
comm);
|
||||
|
||||
// NOTE: this is a no-op in all processes except dst_rank
|
||||
for (size_t i = 0; i < output.size(); ++i)
|
||||
output[i].copy_(recv_buffer[i]);
|
||||
}
|
||||
|
||||
void DataChannelMPI::scatter(
|
||||
std::vector<at::Tensor>& input,
|
||||
at::Tensor& output,
|
||||
rank_type src_rank,
|
||||
THDGroup group_id) {
|
||||
const auto& group_pair = _groups.at(group_id);
|
||||
const auto& comm = group_pair.first;
|
||||
if (comm == MPI_COMM_NULL)
|
||||
return;
|
||||
|
||||
if (!output.is_contiguous())
|
||||
throw std::runtime_error("scatter output has to be a contiguous tensor");
|
||||
|
||||
at::Tensor send_buffer;
|
||||
void* sendbuf = nullptr;
|
||||
if (_rank != src_rank) {
|
||||
if (input.size() > 0)
|
||||
throw std::logic_error(
|
||||
"scatter: number of input tensors should be 0 for non root");
|
||||
} else {
|
||||
if (input.size() != group_pair.second.size())
|
||||
throw std::logic_error(
|
||||
"scatter: number of input tensors and group size does not match");
|
||||
|
||||
for (auto in_tensor : input)
|
||||
assertSameSizeAndType(in_tensor, output, "scatter");
|
||||
|
||||
send_buffer = _newLikeFlat(input);
|
||||
for (size_t i = 0; i < input.size(); ++i)
|
||||
send_buffer[i].copy_(input[i]);
|
||||
sendbuf = send_buffer.data_ptr();
|
||||
}
|
||||
|
||||
rank_type group_src_rank = group_pair.second.mustGetGroupRank(src_rank);
|
||||
|
||||
MPI_Scatter(
|
||||
sendbuf,
|
||||
output.numel(),
|
||||
mpi_datatype.at(output.scalar_type()),
|
||||
output.data_ptr(),
|
||||
output.numel(),
|
||||
mpi_datatype.at(output.scalar_type()),
|
||||
group_src_rank,
|
||||
comm);
|
||||
}
|
||||
|
||||
void DataChannelMPI::allReduce(
|
||||
at::Tensor& data,
|
||||
THDReduceOp operation,
|
||||
THDGroup group_id) {
|
||||
const auto& comm = _groups.at(group_id).first;
|
||||
if (comm == MPI_COMM_NULL)
|
||||
return;
|
||||
|
||||
if (!data.is_contiguous())
|
||||
throw std::runtime_error("all_reduce input has to be contiguous");
|
||||
|
||||
MPI_Allreduce(
|
||||
MPI_IN_PLACE,
|
||||
data.data_ptr(),
|
||||
data.numel(),
|
||||
mpi_datatype.at(data.scalar_type()),
|
||||
mpi_op.at(operation),
|
||||
comm);
|
||||
}
|
||||
|
||||
void DataChannelMPI::reduce(
|
||||
at::Tensor& data,
|
||||
THDReduceOp operation,
|
||||
rank_type dst_rank,
|
||||
THDGroup group_id) {
|
||||
const auto& group_pair = _groups.at(group_id);
|
||||
const auto& comm = group_pair.first;
|
||||
if (comm == MPI_COMM_NULL)
|
||||
return;
|
||||
|
||||
if (!data.is_contiguous())
|
||||
throw std::runtime_error("reduce input has to be contiguous");
|
||||
|
||||
auto group_dst_rank = group_pair.second.mustGetGroupRank(dst_rank);
|
||||
void* sendbuf = (_rank == dst_rank) ? MPI_IN_PLACE : data.data_ptr();
|
||||
void* recvbuf = (_rank == dst_rank) ? data.data_ptr() : nullptr;
|
||||
MPI_Reduce(
|
||||
sendbuf,
|
||||
recvbuf,
|
||||
data.numel(),
|
||||
mpi_datatype.at(data.scalar_type()),
|
||||
mpi_op.at(operation),
|
||||
group_dst_rank,
|
||||
comm);
|
||||
}
|
||||
|
||||
void DataChannelMPI::broadcast(
|
||||
at::Tensor& data,
|
||||
rank_type src_rank,
|
||||
THDGroup group_id) {
|
||||
const auto& group_pair = _groups.at(group_id);
|
||||
const auto& comm = group_pair.first;
|
||||
if (comm == MPI_COMM_NULL)
|
||||
return;
|
||||
|
||||
if (!data.is_contiguous())
|
||||
throw std::runtime_error("broadcast input has to be contiguous");
|
||||
|
||||
rank_type group_src_rank = group_pair.second.mustGetGroupRank(src_rank);
|
||||
MPI_Bcast(
|
||||
data.data_ptr(),
|
||||
data.numel(),
|
||||
mpi_datatype.at(data.scalar_type()),
|
||||
group_src_rank,
|
||||
comm);
|
||||
}
|
||||
|
||||
void DataChannelMPI::send(Scalar& data, rank_type dst_rank) {
|
||||
MPI_Send(
|
||||
data.data(),
|
||||
data.elementSize(),
|
||||
MPI_UINT8_T,
|
||||
dst_rank,
|
||||
0,
|
||||
MPI_COMM_WORLD);
|
||||
}
|
||||
|
||||
void DataChannelMPI::send(at::Tensor& data, rank_type dst_rank) {
|
||||
if (!data.is_contiguous())
|
||||
throw std::logic_error("tensor to send is not contiguous");
|
||||
|
||||
MPI_Send(
|
||||
data.data_ptr(),
|
||||
data.numel(),
|
||||
mpi_datatype.at(data.scalar_type()),
|
||||
dst_rank,
|
||||
0,
|
||||
MPI_COMM_WORLD);
|
||||
}
|
||||
|
||||
void DataChannelMPI::receive(Scalar& data, rank_type src_rank) {
|
||||
MPI_Recv(
|
||||
data.data(),
|
||||
data.elementSize(),
|
||||
MPI_UINT8_T,
|
||||
src_rank,
|
||||
0,
|
||||
MPI_COMM_WORLD,
|
||||
MPI_STATUS_IGNORE);
|
||||
}
|
||||
|
||||
rank_type DataChannelMPI::receive(at::Tensor& data) {
|
||||
if (!data.is_contiguous())
|
||||
throw std::logic_error("tensor to receive is not contiguous");
|
||||
|
||||
MPI_Status status;
|
||||
MPI_Recv(
|
||||
data.data_ptr(),
|
||||
data.numel(),
|
||||
mpi_datatype.at(data.scalar_type()),
|
||||
MPI_ANY_SOURCE,
|
||||
0,
|
||||
MPI_COMM_WORLD,
|
||||
&status);
|
||||
return status.MPI_SOURCE;
|
||||
}
|
||||
|
||||
void DataChannelMPI::receive(at::Tensor& data, rank_type src_rank) {
|
||||
if (!data.is_contiguous())
|
||||
throw std::logic_error("tensor to receive is not contiguous");
|
||||
|
||||
MPI_Recv(
|
||||
data.data_ptr(),
|
||||
data.numel(),
|
||||
mpi_datatype.at(data.scalar_type()),
|
||||
src_rank,
|
||||
0,
|
||||
MPI_COMM_WORLD,
|
||||
MPI_STATUS_IGNORE);
|
||||
}
|
||||
|
||||
void DataChannelMPI::barrier(THDGroup group_id) {
|
||||
const auto& comm = _groups.at(group_id).first;
|
||||
if (comm == MPI_COMM_NULL)
|
||||
return;
|
||||
|
||||
MPI_Barrier(comm);
|
||||
}
|
||||
|
||||
DataChannelMPI::RequestMPI* DataChannelMPI::isend(
|
||||
at::Tensor& data,
|
||||
rank_type dst_rank) {
|
||||
if (!data.is_contiguous())
|
||||
throw std::logic_error("tensor to send is not contiguous");
|
||||
|
||||
std::unique_ptr<RequestMPI> request{new RequestMPI()};
|
||||
request->save_tensor_buffer(data);
|
||||
auto& mpi_request = request->new_request();
|
||||
MPI_Isend(
|
||||
data.data_ptr(),
|
||||
data.numel(),
|
||||
mpi_datatype.at(data.scalar_type()),
|
||||
dst_rank,
|
||||
0,
|
||||
MPI_COMM_WORLD,
|
||||
&mpi_request);
|
||||
|
||||
return request.release();
|
||||
}
|
||||
|
||||
DataChannelMPI::RequestMPI* DataChannelMPI::ireceive(
|
||||
at::Tensor& data,
|
||||
rank_type src_rank) {
|
||||
if (!data.is_contiguous())
|
||||
throw std::logic_error("tensor to receive is not contiguous");
|
||||
|
||||
std::unique_ptr<RequestMPI> request{new RequestMPI()};
|
||||
request->save_tensor_buffer(data);
|
||||
auto& mpi_request = request->new_request();
|
||||
MPI_Irecv(
|
||||
data.data_ptr(),
|
||||
data.numel(),
|
||||
mpi_datatype.at(data.scalar_type()),
|
||||
src_rank,
|
||||
0,
|
||||
MPI_COMM_WORLD,
|
||||
&mpi_request);
|
||||
|
||||
return request.release();
|
||||
}
|
||||
|
||||
THDGroup DataChannelMPI::newGroup(const std::vector<rank_type>& ranks) {
|
||||
MPI_Group world_group;
|
||||
MPI_Comm_group(MPI_COMM_WORLD, &world_group);
|
||||
|
||||
MPI_Group ranks_group;
|
||||
std::vector<int> int_ranks(ranks.begin(), ranks.end());
|
||||
MPI_Group_incl(world_group, int_ranks.size(), int_ranks.data(), &ranks_group);
|
||||
|
||||
MPI_Comm new_comm;
|
||||
MPI_Comm_create(MPI_COMM_WORLD, ranks_group, &new_comm);
|
||||
|
||||
MPI_Group_free(&world_group);
|
||||
MPI_Group_free(&ranks_group);
|
||||
|
||||
DataChannel::Group new_group;
|
||||
if (new_comm != MPI_COMM_NULL) {
|
||||
int size, mapping_ranks[2];
|
||||
MPI_Comm_size(new_comm, &size);
|
||||
MPI_Comm_rank(new_comm, mapping_ranks); // get rank in new communicator
|
||||
mapping_ranks[1] = _rank; // get rank in world communicator
|
||||
|
||||
std::unique_ptr<int[]> all_mapping_ranks(new int[2 * size]);
|
||||
MPI_Allgather(
|
||||
&mapping_ranks,
|
||||
2,
|
||||
MPI_INT,
|
||||
all_mapping_ranks.get(),
|
||||
2,
|
||||
MPI_INT,
|
||||
new_comm);
|
||||
|
||||
// this vector maps new ranks to ranks in COMM_WORLD (global ranks)
|
||||
std::vector<rank_type> new_ranks(size);
|
||||
for (size_t i = 0; i < 2 * size; i += 2)
|
||||
new_ranks[all_mapping_ranks[i]] = all_mapping_ranks[i + 1];
|
||||
|
||||
new_group = DataChannel::Group(new_ranks, _num_processes - 1);
|
||||
}
|
||||
|
||||
THDGroup new_group_id = static_cast<THDGroup>(_groups.size());
|
||||
_groups.insert({new_group_id, std::make_pair(new_comm, new_group)});
|
||||
return new_group_id;
|
||||
}
|
||||
|
||||
void DataChannelMPI::allReduce(
|
||||
std::vector<at::Tensor>& data,
|
||||
THDReduceOp operation,
|
||||
THDGroup groupId) {
|
||||
throw std::runtime_error(
|
||||
"DataChannelMPI does not support mult-GPU cross "
|
||||
"node allreduce");
|
||||
}
|
||||
|
||||
void DataChannelMPI::allGather(
|
||||
std::vector<at::Tensor>& output,
|
||||
std::vector<at::Tensor>& input,
|
||||
THDGroup groupId) {
|
||||
throw std::runtime_error(
|
||||
"DataChannelMPI does not support mult-GPU cross "
|
||||
"node allgather");
|
||||
}
|
||||
|
||||
void DataChannelMPI::reduce(
|
||||
std::vector<at::Tensor>& data,
|
||||
THDReduceOp operation,
|
||||
rank_type dstRank,
|
||||
THDGroup groupId) {
|
||||
throw std::runtime_error(
|
||||
"DataChannelMPI does not support mult-GPU cross "
|
||||
"node reduce");
|
||||
}
|
||||
|
||||
void DataChannelMPI::broadcast(
|
||||
std::vector<at::Tensor>& data,
|
||||
rank_type srcRank,
|
||||
THDGroup groupId) {
|
||||
throw std::runtime_error(
|
||||
"DataChannelMPI does not support mult-GPU cross "
|
||||
"node broadcast");
|
||||
}
|
||||
|
||||
void DataChannelMPI::clearGroupCache(THDGroup group_id) {
|
||||
throw std::runtime_error(
|
||||
"DataChannelMPI does not support clear "
|
||||
"group cache");
|
||||
}
|
||||
|
||||
} // namespace thd
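`_newLikeFlat` above allocates one receive buffer whose leading dimension is the number of tensors, so MPI_Allgather (or MPI_Gather) can write every contribution contiguously and the results are then copied back out per tensor. A small sketch of just the size computation, without ATen:

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const std::size_t num_tensors = 4;                 // e.g. the group size
  const std::vector<int64_t> tensor_sizes = {3, 5};  // sizes of each input tensor

  // sizes = [num_tensors] + tensor_sizes, as in _newLikeFlat.
  std::vector<int64_t> sizes{static_cast<int64_t>(num_tensors)};
  sizes.insert(sizes.end(), tensor_sizes.begin(), tensor_sizes.end());

  std::printf("flat buffer shape: [");
  for (std::size_t i = 0; i < sizes.size(); ++i)
    std::printf("%s%lld", i ? ", " : "", static_cast<long long>(sizes[i]));
  std::printf("]\n");  // prints [4, 3, 5]
  return 0;
}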
@ -1,111 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include <THD/base/DataChannel.hpp>
|
||||
|
||||
#include <mpi.h>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace thd {
|
||||
|
||||
struct DataChannelMPI : DataChannel {
|
||||
struct RequestMPI : DataChannel::Request {
|
||||
friend class DataChannelMPI; // allows `DataChannelMPI` to access private
|
||||
// members
|
||||
|
||||
RequestMPI();
|
||||
virtual ~RequestMPI();
|
||||
|
||||
virtual bool isCompleted() override;
|
||||
virtual void wait() override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
void save_buffer(std::shared_ptr<T> ptr);
|
||||
void save_tensor_buffer(at::Tensor& t);
|
||||
MPI_Request& new_request();
|
||||
|
||||
std::vector<std::shared_ptr<void>> _buffers;
|
||||
std::vector<at::Tensor> _tensor_buffers;
|
||||
std::vector<MPI_Request> _requests;
|
||||
};
|
||||
|
||||
DataChannelMPI();
|
||||
virtual ~DataChannelMPI();
|
||||
|
||||
bool init() override;
|
||||
void destroy() override;
|
||||
|
||||
rank_type getRank() override;
|
||||
rank_type getNumProcesses() override;
|
||||
|
||||
void allGather(
|
||||
std::vector<at::Tensor>& output,
|
||||
std::vector<at::Tensor>& input,
|
||||
THDGroup group_id = THDGroupWORLD) override;
|
||||
void allGather(
|
||||
std::vector<at::Tensor>& output,
|
||||
at::Tensor& input,
|
||||
THDGroup group_id = THDGroupWORLD) override;
|
||||
void gather(
|
||||
std::vector<at::Tensor>& output,
|
||||
at::Tensor& input,
|
||||
rank_type dst_rank,
|
||||
THDGroup group_id = THDGroupWORLD) override;
|
||||
void scatter(
|
||||
std::vector<at::Tensor>& input,
|
||||
at::Tensor& output,
|
||||
rank_type src_rank,
|
||||
THDGroup group_id = THDGroupWORLD) override;
|
||||
void allReduce(
|
||||
std::vector<at::Tensor>& data,
|
||||
THDReduceOp operation,
|
||||
THDGroup group_id = THDGroupWORLD) override;
|
||||
void allReduce(
|
||||
at::Tensor& data,
|
||||
THDReduceOp operation,
|
||||
THDGroup group_id = THDGroupWORLD) override;
|
||||
void reduce(
|
||||
std::vector<at::Tensor>& data,
|
||||
THDReduceOp operation,
|
||||
rank_type dstRank,
|
||||
THDGroup group_id = THDGroupWORLD) override;
|
||||
void reduce(
|
||||
at::Tensor& data,
|
||||
THDReduceOp operation,
|
||||
rank_type dst_rank,
|
||||
THDGroup group_id = THDGroupWORLD) override;
|
||||
void broadcast(
|
||||
std::vector<at::Tensor>& data,
|
||||
rank_type srcRank,
|
||||
THDGroup group_id = THDGroupWORLD) override;
|
||||
void broadcast(
|
||||
at::Tensor& data,
|
||||
rank_type src_rank,
|
||||
THDGroup group_id = THDGroupWORLD) override;
|
||||
void send(Scalar& data, rank_type dst_rank) override;
|
||||
void send(at::Tensor& data, rank_type dst_rank) override;
|
||||
void receive(Scalar& data, rank_type src_rank) override;
|
||||
rank_type receive(at::Tensor& data) override;
|
||||
void receive(at::Tensor& data, rank_type src_rank) override;
|
||||
RequestMPI* isend(at::Tensor& data, rank_type dst_rank) override;
|
||||
RequestMPI* ireceive(at::Tensor& data, rank_type src_rank) override;
|
||||
|
||||
void barrier(THDGroup group_id = THDGroupWORLD) override;
|
||||
THDGroup newGroup(const std::vector<rank_type>& ranks) override;
|
||||
void clearGroupCache(THDGroup group_id = THDGroupWORLD) override;
|
||||
|
||||
private:
|
||||
at::Tensor _newLikeFlat(std::vector<at::Tensor>& tensors) const;
|
||||
|
||||
rank_type _rank; // Current process' rank
|
||||
rank_type _num_processes; // Number of processes in network
|
||||
|
||||
// Existing groups of processes with assigned MPI communicator
|
||||
// and corresponding group ids
|
||||
std::unordered_map<THDGroup, std::pair<MPI_Comm, DataChannel::Group>> _groups;
|
||||
};
|
||||
|
||||
} // namespace thd
|
||||
|
|
@ -1,715 +0,0 @@
|
|||
#include <THD/base/data_channels/DataChannelNccl.hpp>
|
||||
#include <THD/base/Cuda.hpp>
|
||||
#include <THD/base/data_channels/DataChannelUtils.hpp>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include <THC/THC.h>
|
||||
#include <cuda.h>
|
||||
|
||||
#include <unistd.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
#include <unordered_set>
|
||||
|
||||
namespace thd {
|
||||
|
||||
namespace {
|
||||
|
||||
std::unordered_map<THDReduceOp, ncclRedOp_t> ncclOp = {
|
||||
{THDReduceOp::THDReduceMIN, ncclMin},
|
||||
{THDReduceOp::THDReduceMAX, ncclMax},
|
||||
{THDReduceOp::THDReduceSUM, ncclSum},
|
||||
{THDReduceOp::THDReducePRODUCT, ncclProd},
|
||||
};
|
||||
|
||||
std::unordered_map<at::ScalarType, ncclDataType_t> ncclDatatype = {
|
||||
{at::kChar, ncclInt8},
|
||||
{at::kByte, ncclUint8},
|
||||
{at::kFloat, ncclFloat},
|
||||
{at::kDouble, ncclDouble},
|
||||
{at::kInt, ncclInt32},
|
||||
{at::kLong, ncclInt64},
|
||||
{at::kHalf, ncclHalf},
|
||||
};
|
||||
|
||||
// Helper function that gets the data type and issues an error if it is not supported
|
||||
static ncclDataType_t _getNcclDataType(at::ScalarType type) {
|
||||
try {
|
||||
return ncclDatatype.at(type);
|
||||
} catch (std::out_of_range& e) {
|
||||
throw std::runtime_error("Unsupported data type for NCCL backend");
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function that gets the device list to determine the CUDA devices
|
||||
std::vector<int> getDevicesList(const std::string& deviceSeq) {
|
||||
std::stringstream ss(deviceSeq);
|
||||
std::string device;
|
||||
std::vector<int> devices;
|
||||
while (std::getline(ss, device, ',')) {
|
||||
devices.push_back(stoi(device));
|
||||
}
|
||||
return devices;
|
||||
}
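For illustration only (this snippet is an editor's addition and not part of the removed file): getDevicesList turns the cached comma-separated device string back into CUDA device indices.

// Editor's sketch: example use of the getDevicesList helper above.
std::vector<int> devices = getDevicesList("0,2,3");
// devices now holds {0, 2, 3}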
|
||||
|
||||
} // namespace
|
||||
|
||||
// DataChannelNccl
|
||||
DataChannelNccl::DataChannelNccl(InitMethod::Config config, int timeout)
|
||||
: _rank(config.rank),
|
||||
_numProcesses(config.world_size),
|
||||
_timeout(timeout),
|
||||
_masterListeningSocket(-1),
|
||||
_slaveSocket(-1) {
|
||||
// Establish the socket connections from rank 0 to all others
|
||||
if (_rank == 0) {
|
||||
_masterListeningSocket = config.master.listen_socket;
|
||||
_masterSendingSockets = std::vector<int>(_numProcesses - 1, -1);
|
||||
|
||||
try {
|
||||
for (rank_type i = 0; i < _numProcesses - 1; ++i) {
|
||||
std::tie(_masterSendingSockets[i], std::ignore) =
|
||||
accept(_masterListeningSocket, _timeout);
|
||||
}
|
||||
} catch (...) {
|
||||
// Destroy the created sockets
|
||||
_destroySockets();
|
||||
throw std::runtime_error("Rank 0 cannot establish thelistening socket");
|
||||
}
|
||||
|
||||
} else {
|
||||
_masterAddr = config.worker.master_addr;
|
||||
_masterPort = config.worker.master_port;
|
||||
|
||||
try {
|
||||
_slaveSocket = connect(_masterAddr, _masterPort, true, _timeout);
|
||||
} catch (...) {
|
||||
// Destroy the created sockets
|
||||
_destroySockets();
|
||||
std::string errStr = "Rank: " + std::to_string(_rank) +
|
||||
" cannot "
|
||||
"connect to the master: " +
|
||||
_masterAddr + ":" + std::to_string(_masterPort);
|
||||
throw std::runtime_error(errStr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Use the socket to broadcast NCCL ID
|
||||
void DataChannelNccl::broadcastUniqueNcclId(ncclUniqueId* ncclId) {
|
||||
// Send the unique NCCL id to every rank
|
||||
if (_rank == 0) {
|
||||
for (auto socket : _masterSendingSockets) {
|
||||
send_bytes<uint8_t>(
|
||||
socket, reinterpret_cast<uint8_t*>(ncclId), NCCL_UNIQUE_ID_BYTES);
|
||||
}
|
||||
} else {
|
||||
recv_bytes<uint8_t>(
|
||||
_slaveSocket, reinterpret_cast<uint8_t*>(ncclId), NCCL_UNIQUE_ID_BYTES);
|
||||
}
|
||||
}
|
||||
|
||||
// Destructor will only close all the sockets
|
||||
DataChannelNccl::~DataChannelNccl() {
|
||||
/**
|
||||
* Note that the destructor will be called after the CUDA runtime has been unloaded, since
|
||||
* DataChannel is a global variable.
|
||||
*/
|
||||
_destroySockets();
|
||||
}
|
||||
|
||||
void DataChannelNccl::_destroySockets() {
|
||||
// Destroy all the sockets
|
||||
if (_masterListeningSocket != -1) {
|
||||
::close(_masterListeningSocket);
|
||||
_masterListeningSocket = -1;
|
||||
}
|
||||
if (_slaveSocket != -1) {
|
||||
::close(_slaveSocket);
|
||||
_slaveSocket = -1;
|
||||
}
|
||||
for (size_t i = 0; i < _masterSendingSockets.size(); ++i) {
|
||||
if (_masterSendingSockets[i] != -1) {
|
||||
::close(_masterSendingSockets[i]);
|
||||
_masterSendingSockets[i] = -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Destroy the data channel
|
||||
void DataChannelNccl::destroy() {
|
||||
std::unique_lock<std::mutex> channelLock(_mutex);
|
||||
|
||||
// Destroy all the sockets
|
||||
_destroySockets();
|
||||
|
||||
// Guard GPU device
|
||||
at::cuda::OptionalCUDAGuard gpuGuard;
|
||||
|
||||
/**
|
||||
* Destroy the CUDA and NCCL resources
|
||||
* TODO: creating C++ wrappers for CUDA and NCCL resources to do the
|
||||
* cleanup automatically
|
||||
*/
|
||||
for (auto& itemPair : _groupNcclResources) {
|
||||
auto groupId = itemPair.first;
|
||||
_destroyNcclResources(groupId);
|
||||
}
|
||||
|
||||
_groupNcclResources.clear();
|
||||
_groupDevices.clear();
|
||||
|
||||
_groups.clear();
|
||||
}
|
||||
|
||||
// Helper function that destroys the CUDA event and NCCL communicator
|
||||
void DataChannelNccl::_destroyNcclResources(THDGroup groupId) {
|
||||
if (_groupNcclResources.find(groupId) != _groupNcclResources.end()) {
|
||||
for (int i = 0; i < _groupDevices[groupId].size(); i++) {
|
||||
// Devices used for this group ID
|
||||
auto devices = getDevicesList(_groupDevices[groupId][i]);
|
||||
// Guard GPU device
|
||||
at::cuda::OptionalCUDAGuard gpuGuard;
|
||||
// Destroy the CUDA events
|
||||
size_t idx = 0;
|
||||
for (auto& event : *(_groupNcclResources[groupId][i].ncclCudaEvents())) {
|
||||
gpuGuard.set_index(devices[idx++]);
|
||||
THCudaCheck(cudaEventSynchronize(event));
|
||||
THCudaCheck(cudaEventDestroy(event));
|
||||
}
|
||||
// Destroy the communicators
|
||||
for (auto& comm : *(_groupNcclResources[groupId][i].ncclComms())) {
|
||||
NCCL_CHECK(ncclCommDestroy(comm));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Destroy the cached NCCL resource associated with a given group
|
||||
void DataChannelNccl::clearGroupCache(THDGroup groupId) {
|
||||
std::unique_lock<std::mutex> channelLock(_mutex);
|
||||
|
||||
_destroyNcclResources(groupId);
|
||||
|
||||
_groupNcclResources.erase(groupId);
|
||||
_groupDevices.erase(groupId);
|
||||
}
|
||||
|
||||
// Initialization function
|
||||
bool DataChannelNccl::init() {
|
||||
std::vector<rank_type> ranks;
|
||||
ranks.reserve(_numProcesses);
|
||||
|
||||
for (rank_type rank = 0; rank < _numProcesses; ++rank) {
|
||||
ranks.push_back(rank);
|
||||
}
|
||||
|
||||
// Insert the current group
|
||||
_groups.insert({THDGroupWORLD, DataChannel::Group(ranks, _numProcesses - 1)});
|
||||
|
||||
// Get the GPU count
|
||||
THCudaCheck(cudaGetDeviceCount(&_numGPUs));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
rank_type DataChannelNccl::getRank() {
|
||||
return _rank;
|
||||
}
|
||||
|
||||
rank_type DataChannelNccl::getNumProcesses() {
|
||||
return _numProcesses;
|
||||
}
|
||||
|
||||
NcclResourcePair DataChannelNccl::_getNcclResourcePair(
|
||||
std::vector<at::Tensor>& input,
|
||||
THDGroup groupId) {
|
||||
if (input.empty()) {
|
||||
throw std::runtime_error(
|
||||
"Not able to create/get the Nccl Comm since "
|
||||
"input tensor is empty");
|
||||
}
|
||||
// Get the deviceList String
|
||||
std::string deviceList;
|
||||
for (auto tensor : input) {
|
||||
if (deviceList.empty()) {
|
||||
deviceList = std::to_string(tensor.get_device());
|
||||
} else {
|
||||
deviceList += "," + std::to_string(tensor.get_device());
|
||||
}
|
||||
}
|
||||
|
||||
int index = -1;
|
||||
|
||||
if (_groupDevices.find(groupId) != _groupDevices.end()) {
|
||||
auto pos = std::find(
|
||||
_groupDevices[groupId].begin(),
|
||||
_groupDevices[groupId].end(),
|
||||
deviceList);
|
||||
if (pos != _groupDevices[groupId].end())
|
||||
index = pos - _groupDevices[groupId].begin();
|
||||
}
|
||||
|
||||
if (index >= 0) {
|
||||
return std::make_pair(
|
||||
_groupNcclResources[groupId][index].ncclComms(),
|
||||
_groupNcclResources[groupId][index].ncclCudaEvents());
|
||||
}
|
||||
|
||||
// Add in the device list of the group
|
||||
_groupDevices[groupId].push_back(deviceList);
|
||||
|
||||
// NCCL communicator
|
||||
auto comms =
|
||||
std::unique_ptr<std::vector<ncclComm_t>>(new std::vector<ncclComm_t>());
|
||||
|
||||
comms->resize(input.size());
|
||||
|
||||
// Corresponding CUDA events
|
||||
auto events =
|
||||
std::unique_ptr<std::vector<cudaEvent_t>>(new std::vector<cudaEvent_t>());
|
||||
|
||||
events->resize(input.size());
|
||||
|
||||
// Create the unique NCCL ID and broadcast it
|
||||
ncclUniqueId ncclId;
|
||||
NCCL_CHECK(ncclGetUniqueId(&ncclId));
|
||||
|
||||
// Broadcast so that every process shares the same unique NCCL ID
|
||||
broadcastUniqueNcclId(&ncclId);
|
||||
|
||||
// Guard GPU device
|
||||
at::cuda::OptionalCUDAGuard gpuGuard;
|
||||
|
||||
// Now creating the CUDA events
|
||||
for (size_t i = 0; i < input.size(); ++i) {
|
||||
gpuGuard.set_index(input[i].get_device());
|
||||
THCudaCheck(cudaEventCreate(&((*events)[i])));
|
||||
}
|
||||
// Create the communicator on each device of the input
|
||||
NCCL_CHECK(ncclGroupStart());
|
||||
for (size_t i = 0; i < input.size(); ++i) {
|
||||
int nRanks = int(_numProcesses) * input.size();
|
||||
gpuGuard.set_index(input[i].get_device());
|
||||
NCCL_CHECK(ncclCommInitRank(
|
||||
&((*comms)[i]), nRanks, ncclId, _rank * input.size() + i));
|
||||
}
|
||||
NCCL_CHECK(ncclGroupEnd());
|
||||
|
||||
// Move into the hash table
|
||||
if (_groupNcclResources.find(groupId) == _groupNcclResources.end())
|
||||
_groupNcclResources.emplace(
|
||||
std::make_pair(groupId, std::vector<NcclResources>()));
|
||||
|
||||
_groupNcclResources[groupId].push_back(
|
||||
NcclResources(std::move(comms), std::move(events)));
|
||||
|
||||
return std::make_pair(
|
||||
_groupNcclResources[groupId].back().ncclComms(),
|
||||
_groupNcclResources[groupId].back().ncclCudaEvents());
|
||||
}
|
||||
|
||||
// Helper function that checks the input and output tensors for validity
|
||||
bool DataChannelNccl::_tensorCheckHelper(
|
||||
const std::vector<at::Tensor>& input,
|
||||
const std::vector<at::Tensor>& output,
|
||||
size_t outputOverInput) {
|
||||
if (input.size() != output.size()) {
|
||||
throw std::runtime_error(
|
||||
"Input tensor sequence should have the same "
|
||||
"number of tensors as the output tensor sequence");
|
||||
}
|
||||
|
||||
if (input.size() == 0) {
|
||||
// Return false saying this is a no-op
|
||||
return false;
|
||||
}
|
||||
|
||||
if (input.size() > _numGPUs) {
|
||||
throw std::runtime_error(
|
||||
"The number of input tensors is larger than "
|
||||
"the number of available GPUs");
|
||||
}
|
||||
|
||||
// To make sure each tensor is on separate devices
|
||||
std::unordered_set<int> usedDevices;
|
||||
usedDevices.reserve(input.size());
|
||||
|
||||
uint64_t inputNumElement = input[0].numel();
|
||||
auto elementType = input[0].scalar_type();
|
||||
|
||||
for (size_t i = 0; i < input.size(); ++i) {
|
||||
// Check to make sure it's a GPU dense tensor
|
||||
if (!(input[i].is_cuda() && !input[i].is_sparse() &&
|
||||
output[i].is_cuda() && !output[i].is_sparse())) {
|
||||
throw std::runtime_error(
|
||||
"Only CUDA dense tensor is supported for NCCL "
|
||||
"collective operations");
|
||||
}
|
||||
// Check the tensor type is identical
|
||||
if (input[i].scalar_type() != elementType ||
|
||||
output[i].scalar_type() != elementType) {
|
||||
throw std::runtime_error(
|
||||
"Expecting all GPU tensors to have identical "
|
||||
"type");
|
||||
}
|
||||
// Check the input tensor size is identical
|
||||
if (input[i].numel() != inputNumElement) {
|
||||
throw std::runtime_error(
|
||||
"Expecting all input tensors to have identical "
|
||||
"number of elements");
|
||||
}
|
||||
// Check the output tensor size equals to input tensor size
|
||||
if (output[i].numel() != inputNumElement * outputOverInput) {
|
||||
throw std::runtime_error(
|
||||
"The number of elements of output tensor does "
|
||||
"not match the number of elements of the input "
|
||||
"tensor");
|
||||
}
|
||||
// Contiguous verification
|
||||
if (!input[i].is_contiguous() || !output[i].is_contiguous()) {
|
||||
throw std::runtime_error("Expecting all GPU tensors to be contiguous");
|
||||
}
|
||||
|
||||
bool inserted;
|
||||
std::tie(std::ignore, inserted) = usedDevices.insert(input[i].get_device());
|
||||
// Device verification, if the insertion didn't take place
|
||||
if (!inserted) {
|
||||
throw std::runtime_error("Expecting inputs on different GPU devices");
|
||||
}
|
||||
|
||||
// Now check the output device
|
||||
if (input[i].get_device() != output[i].get_device()) {
|
||||
throw std::runtime_error(
|
||||
"Expecting input and output tensors to be on "
|
||||
"the same device");
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void DataChannelNccl::allReduce(
|
||||
std::vector<at::Tensor>& data,
|
||||
THDReduceOp operation,
|
||||
THDGroup groupId) {
|
||||
std::unique_lock<std::mutex> channelLock(_mutex);
|
||||
// Check the tensor vector for consistency
|
||||
if (!_tensorCheckHelper(data, data)) {
|
||||
return;
|
||||
}
|
||||
_checkGroupIdValid(groupId);
|
||||
|
||||
auto ncclResourcePair = _getNcclResourcePair(data, groupId);
|
||||
auto comms = ncclResourcePair.first;
|
||||
auto events = ncclResourcePair.second;
|
||||
|
||||
// Guard GPU device
|
||||
at::cuda::OptionalCUDAGuard gpuGuard;
|
||||
|
||||
std::unique_lock<std::mutex> cudaFreeMutexLock(
|
||||
*(c10::cuda::CUDACachingAllocator::getFreeMutex()));
|
||||
|
||||
NCCL_CHECK(ncclGroupStart());
|
||||
for (size_t i = 0; i < data.size(); ++i) {
|
||||
auto device = data[i].get_device();
|
||||
gpuGuard.set_index(device);
|
||||
auto stream = THCState_getCurrentStreamOnDevice(THDGetCudaState(), device);
|
||||
|
||||
NCCL_CHECK(ncclAllReduce(
|
||||
data[i].data_ptr(),
|
||||
data[i].data_ptr(),
|
||||
data[i].numel(),
|
||||
_getNcclDataType(data[i].scalar_type()),
|
||||
ncclOp[operation],
|
||||
(*comms)[i],
|
||||
stream));
|
||||
}
|
||||
NCCL_CHECK(ncclGroupEnd());
|
||||
|
||||
for (size_t i = 0; i < data.size(); ++i) {
|
||||
auto device = data[i].get_device();
|
||||
gpuGuard.set_index(device);
|
||||
auto stream = THCState_getCurrentStreamOnDevice(THDGetCudaState(), device);
|
||||
THCudaCheck(cudaEventRecord((*events)[i], stream));
|
||||
}
|
||||
|
||||
cudaFreeMutexLock.unlock();
|
||||
}
|
||||
|
||||
void DataChannelNccl::allReduce(
|
||||
at::Tensor& data,
|
||||
THDReduceOp operation,
|
||||
THDGroup groupId) {
|
||||
std::vector<at::Tensor> dataVec = {data};
|
||||
allReduce(dataVec, operation, groupId);
|
||||
}
|
||||
|
||||
void DataChannelNccl::allGather(
|
||||
std::vector<at::Tensor>& output,
|
||||
std::vector<at::Tensor>& input,
|
||||
THDGroup groupId) {
|
||||
std::unique_lock<std::mutex> channelLock(_mutex);
|
||||
|
||||
if (!_tensorCheckHelper(input, output, _numProcesses * input.size())) {
|
||||
return;
|
||||
}
|
||||
_checkGroupIdValid(groupId);
|
||||
|
||||
auto ncclResourcePair = _getNcclResourcePair(input, groupId);
|
||||
auto comms = ncclResourcePair.first;
|
||||
auto events = ncclResourcePair.second;
|
||||
|
||||
// Guard GPU device
|
||||
at::cuda::OptionalCUDAGuard gpuGuard;
|
||||
|
||||
std::unique_lock<std::mutex> cudaFreeMutexLock(
|
||||
*(c10::cuda::CUDACachingAllocator::getFreeMutex()));
|
||||
|
||||
NCCL_CHECK(ncclGroupStart());
|
||||
for (size_t i = 0; i < input.size(); ++i) {
|
||||
auto device = input[i].get_device();
|
||||
gpuGuard.set_index(device);
|
||||
auto stream = THCState_getCurrentStreamOnDevice(THDGetCudaState(), device);
|
||||
|
||||
NCCL_CHECK(ncclAllGather(
|
||||
input[i].data_ptr(),
|
||||
output[i].data_ptr(),
|
||||
input[i].numel(),
|
||||
_getNcclDataType(input[i].scalar_type()),
|
||||
(*comms)[i],
|
||||
stream));
|
||||
}
|
||||
NCCL_CHECK(ncclGroupEnd());
|
||||
|
||||
for (size_t i = 0; i < input.size(); ++i) {
|
||||
auto device = input[i].get_device();
|
||||
gpuGuard.set_index(device);
|
||||
auto stream = THCState_getCurrentStreamOnDevice(THDGetCudaState(), device);
|
||||
THCudaCheck(cudaEventRecord((*events)[i], stream));
|
||||
}
|
||||
|
||||
cudaFreeMutexLock.unlock();
|
||||
}
|
||||
|
||||
void DataChannelNccl::allGather(
|
||||
std::vector<at::Tensor>& output,
|
||||
at::Tensor& input,
|
||||
THDGroup groupId) {
|
||||
std::vector<at::Tensor> inputDataVec = {input};
|
||||
allGather(output, inputDataVec, groupId);
|
||||
}
|
||||
|
||||
void DataChannelNccl::reduce(
|
||||
std::vector<at::Tensor>& data,
|
||||
THDReduceOp operation,
|
||||
rank_type dstRank,
|
||||
THDGroup groupId) {
|
||||
std::unique_lock<std::mutex> channelLock(_mutex);
|
||||
|
||||
// Check the tensor vector for consistency
|
||||
if (!_tensorCheckHelper(data, data)) {
|
||||
return;
|
||||
}
|
||||
_checkGroupIdValid(groupId);
|
||||
|
||||
auto ncclResourcePair = _getNcclResourcePair(data, groupId);
|
||||
auto comms = ncclResourcePair.first;
|
||||
auto events = ncclResourcePair.second;
|
||||
|
||||
// Guard GPU device
|
||||
at::cuda::OptionalCUDAGuard gpuGuard;
|
||||
|
||||
std::unique_lock<std::mutex> cudaFreeMutexLock(
|
||||
*(c10::cuda::CUDACachingAllocator::getFreeMutex()));
|
||||
|
||||
NCCL_CHECK(ncclGroupStart());
|
||||
for (size_t i = 0; i < data.size(); ++i) {
|
||||
auto device = data[i].get_device();
|
||||
gpuGuard.set_index(device);
|
||||
auto stream = THCState_getCurrentStreamOnDevice(THDGetCudaState(), device);
|
||||
|
||||
NCCL_CHECK(ncclReduce(
|
||||
data[i].data_ptr(),
|
||||
data[i].data_ptr(),
|
||||
data[i].numel(),
|
||||
_getNcclDataType(data[i].scalar_type()),
|
||||
ncclOp[operation],
|
||||
dstRank * data.size(),
|
||||
(*comms)[i],
|
||||
stream));
|
||||
}
|
||||
NCCL_CHECK(ncclGroupEnd());
|
||||
|
||||
for (size_t i = 0; i < data.size(); ++i) {
|
||||
auto device = data[i].get_device();
|
||||
gpuGuard.set_index(device);
|
||||
auto stream = THCState_getCurrentStreamOnDevice(THDGetCudaState(), device);
|
||||
THCudaCheck(cudaEventRecord((*events)[i], stream));
|
||||
}
|
||||
|
||||
cudaFreeMutexLock.unlock();
|
||||
}
|
||||
|
||||
void DataChannelNccl::reduce(
|
||||
at::Tensor& data,
|
||||
THDReduceOp operation,
|
||||
rank_type dstRank,
|
||||
THDGroup groupId) {
|
||||
std::vector<at::Tensor> dataVec = {data};
|
||||
reduce(dataVec, operation, dstRank, groupId);
|
||||
}
|
||||
|
||||
void DataChannelNccl::broadcast(
|
||||
std::vector<at::Tensor>& data,
|
||||
rank_type srcRank,
|
||||
THDGroup groupId) {
|
||||
std::unique_lock<std::mutex> channelLock(_mutex);
|
||||
|
||||
// Check the tensor vector for consistency
|
||||
if (!_tensorCheckHelper(data, data)) {
|
||||
return;
|
||||
}
|
||||
_checkGroupIdValid(groupId);
|
||||
|
||||
auto ncclResourcePair = _getNcclResourcePair(data, groupId);
|
||||
auto comms = ncclResourcePair.first;
|
||||
auto events = ncclResourcePair.second;
|
||||
|
||||
// Guard GPU device
|
||||
at::cuda::OptionalCUDAGuard gpuGuard;
|
||||
|
||||
std::unique_lock<std::mutex> cudaFreeMutexLock(
|
||||
*(c10::cuda::CUDACachingAllocator::getFreeMutex()));
|
||||
|
||||
NCCL_CHECK(ncclGroupStart());
|
||||
for (size_t i = 0; i < data.size(); ++i) {
|
||||
auto device = data[i].get_device();
|
||||
gpuGuard.set_index(device);
|
||||
auto stream = THCState_getCurrentStreamOnDevice(THDGetCudaState(), device);
|
||||
|
||||
NCCL_CHECK(ncclBcast(
|
||||
data[i].data_ptr(),
|
||||
data[i].numel(),
|
||||
_getNcclDataType(data[i].scalar_type()),
|
||||
srcRank * data.size(),
|
||||
(*comms)[i],
|
||||
stream));
|
||||
}
|
||||
NCCL_CHECK(ncclGroupEnd());
|
||||
|
||||
for (size_t i = 0; i < data.size(); ++i) {
|
||||
auto device = data[i].get_device();
|
||||
gpuGuard.set_index(device);
|
||||
auto stream = THCState_getCurrentStreamOnDevice(THDGetCudaState(), device);
|
||||
THCudaCheck(cudaEventRecord((*events)[i], stream));
|
||||
}
|
||||
|
||||
cudaFreeMutexLock.unlock();
|
||||
}
|
||||
|
||||
void DataChannelNccl::broadcast(
|
||||
at::Tensor& data,
|
||||
rank_type srcRank,
|
||||
THDGroup groupId) {
|
||||
std::vector<at::Tensor> dataVec = {data};
|
||||
broadcast(dataVec, srcRank, groupId);
|
||||
}
|
||||
|
||||
void DataChannelNccl::barrier(THDGroup groupId) {
|
||||
throw std::runtime_error("DataChannelNccl does not support barrier");
|
||||
}
|
||||
|
||||
THDGroup DataChannelNccl::newGroup(const std::vector<rank_type>& ranks) {
|
||||
/**
|
||||
* Check if the input rank is a full group since
|
||||
* NCCL data channel currently doesn't support sub-group creation
|
||||
*/
|
||||
std::vector<rank_type> ranksToCompare = std::vector<rank_type>(ranks);
|
||||
std::sort(ranksToCompare.begin(), ranksToCompare.end());
|
||||
for (size_t i = 0; i < ranksToCompare.size(); ++i) {
|
||||
if (ranksToCompare[i] != static_cast<rank_type>(i)) {
|
||||
throw std::runtime_error(
|
||||
"NCCL backend currently only supports fullgroup "
|
||||
"creation. In other words, every rank in the "
|
||||
"process group needs to be a member of the new "
|
||||
"group to be created and sub-group creation is "
|
||||
"currently not supported.");
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_lock<std::mutex> channelLock(_mutex);
|
||||
|
||||
auto newGroup = DataChannel::Group(ranks, _numProcesses - 1);
|
||||
THDGroup newGroupId = static_cast<THDGroup>(_groups.size());
|
||||
|
||||
// Insert the current group
|
||||
_groups.insert({newGroupId, newGroup});
|
||||
|
||||
return newGroupId;
|
||||
}
|
||||
|
||||
// Helper function that checks if the given groupId is valid
|
||||
void DataChannelNccl::_checkGroupIdValid(THDGroup groupId) {
|
||||
if (_groups.find(groupId) == _groups.end()) {
|
||||
std::string errMsg =
|
||||
"Group ID: " + std::to_string(groupId) + " is not valid";
|
||||
throw std::runtime_error(errMsg);
|
||||
}
|
||||
}
|
||||
|
||||
void DataChannelNccl::gather(
|
||||
std::vector<at::Tensor>& output,
|
||||
at::Tensor& input,
|
||||
rank_type dstRank,
|
||||
THDGroup groupId) {
|
||||
throw std::runtime_error("DataChannelNccl does not support gather");
|
||||
}
|
||||
|
||||
void DataChannelNccl::scatter(
|
||||
std::vector<at::Tensor>& input,
|
||||
at::Tensor& output,
|
||||
rank_type srcRank,
|
||||
THDGroup groupId) {
|
||||
throw std::runtime_error("DataChannelNccl does not support scatter");
|
||||
}
|
||||
|
||||
void DataChannelNccl::send(Scalar& data, rank_type dstRank) {
|
||||
throw std::runtime_error("DataChannelNccl does not support send");
|
||||
}
|
||||
|
||||
void DataChannelNccl::send(at::Tensor& data, rank_type dstRank) {
|
||||
throw std::runtime_error("DataChannelNccl does not support send");
|
||||
}
|
||||
|
||||
void DataChannelNccl::receive(Scalar& data, rank_type srcRank) {
|
||||
throw std::runtime_error("DataChannelNccl does not support receive");
|
||||
}
|
||||
|
||||
rank_type DataChannelNccl::receive(at::Tensor& data) {
|
||||
throw std::runtime_error(
|
||||
"DataChannelNccl does not support receive "
|
||||
"from any source");
|
||||
}
|
||||
|
||||
void DataChannelNccl::receive(at::Tensor& data, rank_type srcRank) {
|
||||
throw std::runtime_error("DataChannelNccl does not support receive");
|
||||
}
|
||||
|
||||
DataChannelNccl::RequestNccl* DataChannelNccl::isend(
|
||||
at::Tensor& data,
|
||||
rank_type dstRank) {
|
||||
throw std::runtime_error("DataChannelNccl does not support isend");
|
||||
}
|
||||
|
||||
DataChannelNccl::RequestNccl* DataChannelNccl::ireceive(
|
||||
at::Tensor& data,
|
||||
rank_type srcRank) {
|
||||
throw std::runtime_error("DataChannelNccl does not support ireceive");
|
||||
}
|
||||
|
||||
} // namespace thd
|
||||
|
|
@@ -1,258 +0,0 @@
#pragma once
|
||||
|
||||
#include <THD/base/DataChannel.hpp>
|
||||
#include <THD/base/data_channels/DataChannelUtils.hpp>
|
||||
|
||||
#include <nccl.h>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#define NCCL_CHECK(cmd) \
|
||||
do { \
|
||||
ncclResult_t error = cmd; \
|
||||
if (error != ncclSuccess) { \
|
||||
std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + \
|
||||
std::to_string(__LINE__) + ", " + \
|
||||
std::string(ncclGetErrorString(error)); \
|
||||
throw std::runtime_error(err); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
namespace thd {
|
||||
|
||||
// Type aliasing
|
||||
using NcclResourcePair =
|
||||
std::pair<std::vector<ncclComm_t>*, std::vector<cudaEvent_t>*>;
|
||||
|
||||
struct DataChannelNccl : DataChannel {
|
||||
// Nothing to implement
|
||||
struct RequestNccl : DataChannel::Request {};
|
||||
|
||||
// Wrapper on the pair of NCCL resources
|
||||
class NcclResources {
|
||||
public:
|
||||
NcclResources() = default;
|
||||
NcclResources(
|
||||
std::unique_ptr<std::vector<ncclComm_t>>&& ncclComm,
|
||||
std::unique_ptr<std::vector<cudaEvent_t>>&& event)
|
||||
:
|
||||
|
||||
_commEventPair(std::pair<
|
||||
std::unique_ptr<std::vector<ncclComm_t>>,
|
||||
std::unique_ptr<std::vector<cudaEvent_t>>>(
|
||||
std::move(ncclComm),
|
||||
std::move(event))) {}
|
||||
// Delete copy and assignment ctors
|
||||
NcclResources(const NcclResources&) = delete;
|
||||
NcclResources& operator=(const NcclResources&) = delete;
|
||||
|
||||
// Move ctors by default
|
||||
NcclResources(NcclResources&&) = default;
|
||||
NcclResources& operator=(NcclResources&&) = default;
|
||||
|
||||
// Nccl Communicator Getter
|
||||
std::vector<ncclComm_t>* ncclComms() {
|
||||
return _commEventPair.first.get();
|
||||
}
|
||||
|
||||
// Nccl CUDA event Getter
|
||||
std::vector<cudaEvent_t>* ncclCudaEvents() {
|
||||
return _commEventPair.second.get();
|
||||
}
|
||||
|
||||
private:
|
||||
std::pair<
|
||||
std::unique_ptr<std::vector<ncclComm_t>>,
|
||||
std::unique_ptr<std::vector<cudaEvent_t>>>
|
||||
_commEventPair;
|
||||
};
|
||||
|
||||
// Constructor
|
||||
DataChannelNccl(InitMethod::Config config, int timeout = -1);
|
||||
virtual ~DataChannelNccl();
|
||||
|
||||
bool init() override;
|
||||
void destroy() override;
|
||||
|
||||
rank_type getRank() override;
|
||||
rank_type getNumProcesses() override;
|
||||
|
||||
void allReduce(
|
||||
std::vector<at::Tensor>& data,
|
||||
THDReduceOp operation,
|
||||
THDGroup = THDGroupWORLD) override;
|
||||
|
||||
void allReduce(
|
||||
at::Tensor& data,
|
||||
THDReduceOp operation,
|
||||
THDGroup groupId = THDGroupWORLD) override;
|
||||
|
||||
void allGather(
|
||||
std::vector<at::Tensor>& output,
|
||||
std::vector<at::Tensor>& input,
|
||||
THDGroup groupId = THDGroupWORLD) override;
|
||||
|
||||
void allGather(
|
||||
std::vector<at::Tensor>& output,
|
||||
at::Tensor& input,
|
||||
THDGroup groupId = THDGroupWORLD) override;
|
||||
|
||||
void reduce(
|
||||
std::vector<at::Tensor>& input,
|
||||
THDReduceOp operation,
|
||||
rank_type dstRank,
|
||||
THDGroup groupId = THDGroupWORLD) override;
|
||||
|
||||
void reduce(
|
||||
at::Tensor& data,
|
||||
THDReduceOp operation,
|
||||
rank_type dstRank,
|
||||
THDGroup groupId = THDGroupWORLD) override;
|
||||
|
||||
void broadcast(
|
||||
std::vector<at::Tensor>& data,
|
||||
rank_type srcRank,
|
||||
THDGroup groupId = THDGroupWORLD) override;
|
||||
|
||||
void broadcast(
|
||||
at::Tensor& data,
|
||||
rank_type srcRank,
|
||||
THDGroup groupId = THDGroupWORLD) override;
|
||||
|
||||
void barrier(THDGroup groupId = THDGroupWORLD) override;
|
||||
|
||||
THDGroup newGroup(const std::vector<rank_type>& ranks) override;
|
||||
|
||||
void clearGroupCache(THDGroup groupId = THDGroupWORLD) override;
|
||||
|
||||
// Not supported functions
|
||||
void gather(
|
||||
std::vector<at::Tensor>& output,
|
||||
at::Tensor& input,
|
||||
rank_type dstRank,
|
||||
THDGroup groupId = THDGroupWORLD) override;
|
||||
|
||||
void scatter(
|
||||
std::vector<at::Tensor>& input,
|
||||
at::Tensor& output,
|
||||
rank_type srcRank,
|
||||
THDGroup groupId = THDGroupWORLD) override;
|
||||
|
||||
void send(Scalar& data, rank_type dstRank) override;
|
||||
|
||||
void send(at::Tensor& data, rank_type dstRank) override;
|
||||
|
||||
void receive(Scalar& data, rank_type srcRank) override;
|
||||
|
||||
rank_type receive(at::Tensor& data) override;
|
||||
|
||||
void receive(at::Tensor& data, rank_type srcRank) override;
|
||||
|
||||
RequestNccl* isend(at::Tensor& data, rank_type dstRank) override;
|
||||
|
||||
RequestNccl* ireceive(at::Tensor& data, rank_type srcRank) override;
|
||||
|
||||
private:
|
||||
// Current process' rank
|
||||
rank_type _rank;
|
||||
// Number of processes in network
|
||||
rank_type _numProcesses;
|
||||
|
||||
// Accept waiting timeout in milliseconds, optional
|
||||
int _timeout;
|
||||
// Master's address
|
||||
std::string _masterAddr;
|
||||
// Master's port
|
||||
port_type _masterPort;
|
||||
// Socket on which the master is listening
|
||||
int _masterListeningSocket;
|
||||
/**
|
||||
* Sockets on which the master is sending to each slave
|
||||
* Note that the sockets in the vector can be in arbitrary order and
|
||||
* are not sorted by ranks
|
||||
*/
|
||||
std::vector<int> _masterSendingSockets;
|
||||
/**
|
||||
* Slave socket, used by every rank other than the master
|
||||
* rank (rank 0) to receive rank 0's broadcasted Unique NCCL ID
|
||||
* that is used for building the NCCL communicator
|
||||
*/
|
||||
int _slaveSocket;
|
||||
|
||||
// Number of GPUs on each node
|
||||
int _numGPUs;
|
||||
// Mutex for Nccl Data Channel
|
||||
std::mutex _mutex;
|
||||
|
||||
/**
|
||||
* The GPU devices each group is currently using.
|
||||
* The GPU devices are stored in a device sequence and the cached NCCL
|
||||
* communicator is associated with this GPU device sequence
|
||||
*
|
||||
* e.g. If the group only uses device 0, then the value of
|
||||
* the used device string stored (value of the hashmap) would be "0".
|
||||
*
|
||||
* If the group uses devices 0 - 7 and each tensor of the
* input tensor list is on device 0, 1, 2, 3, 4, 5, 6, 7 respectively,
|
||||
* then the value of the used device string stored would be
|
||||
* "0,1,2,3,4,5,6,7"
|
||||
*
|
||||
* If the group uses devices 0 - 7 and each tensor of the
* input tensor list is on device 0, 4, 5, 6, 7, 1, 2, 3 respectively,
|
||||
* then the value of the used device string stored would be
|
||||
* "0,4,5,6,7,1,2,3"
|
||||
*
|
||||
* Note that the order of the device for the tensor list matters.
|
||||
*
|
||||
* Also note that each group only caches a single NCCL communicator
|
||||
* associated with the current "used device string".
|
||||
*
|
||||
* If a new device string appears, the previous
|
||||
* cached communicator will be destroyed and a new one with the new
|
||||
* device string will be built
|
||||
*/
|
||||
std::unordered_map<THDGroup, std::vector<std::string>> _groupDevices;
|
||||
|
||||
/**
|
||||
* NCCL resources for each THDGroup, including:
|
||||
* NCCL communicator for the current group
|
||||
* Cuda Events for all GPUs for NCCL operations of the current group
|
||||
*/
|
||||
std::unordered_map<THDGroup, std::vector<NcclResources>> _groupNcclResources;
|
||||
|
||||
// Existing groups
|
||||
std::unordered_map<THDGroup, DataChannel::Group> _groups;
|
||||
|
||||
// Helper function that gets the NCCL communicator
|
||||
NcclResourcePair _getNcclResourcePair(
|
||||
std::vector<at::Tensor>& input,
|
||||
THDGroup groupId);
|
||||
|
||||
/**
|
||||
* Helper function that broadcasts the NCCL unique ID to every rank.
* The NCCL ID pointed to by ncclId on rank 0 is sent to the NCCL ID
* buffer pointed to by ncclId on every other rank.
|
||||
*/
|
||||
void broadcastUniqueNcclId(ncclUniqueId* ncclId);
|
||||
|
||||
// Helper that checks the input and output tensors
|
||||
bool _tensorCheckHelper(
|
||||
const std::vector<at::Tensor>& input,
|
||||
const std::vector<at::Tensor>& output,
|
||||
size_t outputOverInput = 1);
|
||||
|
||||
// Helper that destroys a group's NCCL resources
|
||||
void _destroyNcclResources(THDGroup groupId);
|
||||
|
||||
// Group validity checker
|
||||
void _checkGroupIdValid(THDGroup groupId);
|
||||
|
||||
// Helper function that destroys all the open sockets
|
||||
void _destroySockets();
|
||||
};
|
||||
|
||||
} // namespace thd
|
||||
|
|
@@ -1,838 +0,0 @@
#include <THD/base/data_channels/DataChannelTCP.hpp>
|
||||
|
||||
#include <sys/poll.h>
|
||||
#include <unistd.h>
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <future>
|
||||
#include <memory>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <system_error>
|
||||
|
||||
namespace thd {
|
||||
namespace {
|
||||
|
||||
inline uint32_t log2ceil(uint32_t value) {
|
||||
uint32_t dim = 0;
|
||||
#if defined(__GNUC__)
|
||||
if (value <= 1)
|
||||
return 0;
|
||||
dim = 32 - __builtin_clz(value - 1);
|
||||
#else
|
||||
for (uint32_t size = 1; size < value; ++dim, size <<= 1) /* empty */
|
||||
;
|
||||
#endif // defined(__GNUC__)
|
||||
return dim;
|
||||
}
|
||||
|
||||
// Finds nearest power-of-two less than or equal to `value`.
|
||||
template <typename T>
|
||||
inline uint64_t pow2(T value) {
|
||||
uint64_t pof2 = 1;
|
||||
while (pof2 <= value) {
|
||||
pof2 <<= 1;
|
||||
}
|
||||
pof2 >>= 1;
|
||||
return pof2;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
DataChannelTCP::RequestTCP::RequestTCP(QueueWorker::Request&& request)
|
||||
: _request(std::move(request)) {}
|
||||
|
||||
DataChannelTCP::RequestTCP::~RequestTCP() {}
|
||||
|
||||
bool DataChannelTCP::RequestTCP::isCompleted() {
|
||||
return _request.isCompleted();
|
||||
}
|
||||
|
||||
void DataChannelTCP::RequestTCP::wait() {
|
||||
_request.wait();
|
||||
}
|
||||
|
||||
DataChannelTCP::DataChannelTCP(InitMethod::Config config)
|
||||
: DataChannelTCP(config, -1) {}
|
||||
|
||||
DataChannelTCP::DataChannelTCP(InitMethod::Config config, int timeout)
|
||||
: _socket(-1),
|
||||
_port(0),
|
||||
_timeout(timeout),
|
||||
_processes(config.world_size),
|
||||
_poll_events(nullptr) {
|
||||
_rank = config.rank;
|
||||
|
||||
if (_rank == 0) { // MASTER
|
||||
_socket = config.master.listen_socket;
|
||||
_port = config.master.listen_port;
|
||||
|
||||
_processes[0] = {
|
||||
.rank = 0,
|
||||
.address = "",
|
||||
.port = 0,
|
||||
.socket = -1,
|
||||
};
|
||||
} else { // WORKER
|
||||
// add master
|
||||
_processes[0] = {
|
||||
.rank = 0,
|
||||
.address = config.worker.master_addr,
|
||||
.port = config.worker.master_port,
|
||||
.socket = -1,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
DataChannelTCP::~DataChannelTCP() {
|
||||
if (_socket != -1)
|
||||
::close(_socket);
|
||||
|
||||
for (const auto& process : _processes) {
|
||||
if ((process.rank != _rank) && (process.socket != -1))
|
||||
::close(process.socket);
|
||||
}
|
||||
}
|
||||
|
||||
void DataChannelTCP::destroy() {}
|
||||
|
||||
bool DataChannelTCP::initWorker() {
|
||||
auto& master = _processes[0];
|
||||
master.socket = connect(master.address, master.port);
|
||||
|
||||
std::tie(_socket, _port) = listen();
|
||||
|
||||
send_value<rank_type>(master.socket, _rank, true);
|
||||
send_value<port_type>(master.socket, _port); // send listening port to master
|
||||
|
||||
// get all metadata of other processes in network
|
||||
for (size_t i = 1; i < _processes.size(); ++i) {
|
||||
rank_type p_rank = recv_value<rank_type>(master.socket);
|
||||
port_type p_port = recv_value<port_type>(master.socket);
|
||||
std::string p_address = recv_string(master.socket);
|
||||
|
||||
_processes[p_rank] = {
|
||||
.rank = p_rank,
|
||||
.address = p_address,
|
||||
.port = p_port,
|
||||
.socket = -1,
|
||||
};
|
||||
}
|
||||
|
||||
/*
|
||||
* First we connect to the workers with a rank lower than ours,
* then we accept connections from the workers with a higher rank.
*
* This prevents deadlocks where everyone is accepting or everyone is
* trying to connect.
|
||||
*/
|
||||
|
||||
for (rank_type r = 1; r < _rank; ++r) {
|
||||
auto& process = _processes[r];
|
||||
process.socket = connect(process.address, process.port);
|
||||
|
||||
// send rank to tell to the accepting process who we are
|
||||
send_value<rank_type>(process.socket, _rank);
|
||||
}
|
||||
|
||||
for (rank_type i = _rank + 1; i < _processes.size(); ++i) {
|
||||
int socket;
|
||||
std::tie(socket, std::ignore) = accept(_socket, _timeout);
|
||||
|
||||
// get rank of process we have just accepted
|
||||
rank_type p_rank = recv_value<rank_type>(socket);
|
||||
_processes[p_rank].socket = socket;
|
||||
}
|
||||
|
||||
// close socket for listening, we will not use it anymore
|
||||
::close(_socket);
|
||||
_socket = -1;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool DataChannelTCP::initMaster() {
|
||||
// wait for all workers to connect
|
||||
for (size_t i = 1; i < _processes.size(); ++i) {
|
||||
std::string p_address;
|
||||
int p_socket;
|
||||
std::tie(p_socket, p_address) = accept(_socket, _timeout);
|
||||
|
||||
rank_type p_rank = recv_value<rank_type>(p_socket);
|
||||
port_type p_port = recv_value<port_type>(p_socket);
|
||||
|
||||
if (p_rank >= _processes.size()) {
|
||||
throw std::out_of_range(
|
||||
"worker's rank(" + std::to_string(p_rank) +
|
||||
") is out"
|
||||
"of range: [0, " +
|
||||
std::to_string(_processes.size() - 1) + "]");
|
||||
}
|
||||
|
||||
if (_processes[p_rank].rank == p_rank) {
|
||||
throw std::logic_error(
|
||||
"two processes (" + _processes[p_rank].address + ", " + p_address +
|
||||
") "
|
||||
"reported a rank of " +
|
||||
std::to_string(p_rank));
|
||||
}
|
||||
|
||||
_processes[p_rank] = {
|
||||
.rank = p_rank,
|
||||
.address = p_address,
|
||||
.port = p_port,
|
||||
.socket = p_socket,
|
||||
};
|
||||
}
|
||||
|
||||
// send information about processes to all workers
|
||||
for (const auto& worker : _processes) {
|
||||
if (worker.rank == 0)
|
||||
continue;
|
||||
|
||||
for (auto& process : _processes) {
|
||||
if (process.rank == 0)
|
||||
continue;
|
||||
|
||||
send_value<rank_type>(worker.socket, process.rank, true);
|
||||
send_value<port_type>(worker.socket, process.port, true);
|
||||
send_string(worker.socket, process.address);
|
||||
}
|
||||
}
|
||||
|
||||
// close socket for listening, we will not use it anymore
|
||||
::close(_socket);
|
||||
_socket = -1;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool DataChannelTCP::init() {
|
||||
bool ok = (_rank == 0 ? initMaster() : initWorker());
|
||||
if (ok) {
|
||||
std::vector<rank_type> ranks;
|
||||
ranks.reserve(_processes.size());
|
||||
for (rank_type rank = 0; rank < _processes.size(); ++rank)
|
||||
ranks.push_back(rank);
|
||||
|
||||
_groups.insert(
|
||||
{THDGroupWORLD, DataChannel::Group(ranks, _processes.size() - 1)});
|
||||
}
|
||||
|
||||
return ok;
|
||||
}
|
||||
|
||||
rank_type DataChannelTCP::getRank() {
|
||||
return _rank;
|
||||
}
|
||||
|
||||
rank_type DataChannelTCP::getNumProcesses() {
|
||||
return _processes.size();
|
||||
}
|
||||
|
||||
void DataChannelTCP::allGather(
|
||||
std::vector<at::Tensor>& output,
|
||||
at::Tensor& input,
|
||||
THDGroup group_id) {
|
||||
/*
|
||||
* The allgather here is a simple ring algorithm. It performs well on
* large data (> 512 KB) and scales well to large groups of nodes.
|
||||
* More about efficiency can be found here:
|
||||
* > http://www.mcs.anl.gov/~thakur/papers/ijhpca-coll.pdf (section 4.1)
|
||||
*
|
||||
* TODO: implement Bruck / recursive doubling algorithms to make allGather
|
||||
* efficient also for small data (< 512 KB).
|
||||
*/
|
||||
|
||||
std::lock_guard<std::mutex> lock(_mutex);
|
||||
|
||||
const auto& group = _groups.at(group_id);
|
||||
rank_type group_rank;
|
||||
bool exists;
|
||||
std::tie(group_rank, exists) = group.getGroupRank(_rank);
|
||||
if (!exists)
|
||||
return;
|
||||
|
||||
if (output.size() != group.size())
|
||||
throw std::logic_error(
|
||||
"allGather: number of output tensors and group size does not match");
|
||||
|
||||
for (auto out_tensor : output)
|
||||
assertSameSizeAndType(out_tensor, input, "allGather");
|
||||
|
||||
rank_type left = (group.size() + group_rank - 1) % group.size();
|
||||
rank_type right = (group_rank + 1) % group.size();
|
||||
|
||||
memcpy(
|
||||
output[group_rank].data_ptr(),
|
||||
input.data_ptr(),
|
||||
input.element_size() * input.numel());
|
||||
|
||||
auto j = group_rank, jnext = left;
|
||||
for (rank_type i = 0; i < group.size(); ++i) {
|
||||
req_ptr send_request{isend((output[j]), group.mustGetGlobalRank(right))};
|
||||
receive((output[jnext]), group.mustGetGlobalRank(left));
|
||||
send_request->wait();
|
||||
|
||||
j = jnext;
|
||||
jnext = (group.size() + jnext - 1) % group.size();
|
||||
}
|
||||
}
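The ring schedule used above can be restated compactly on plain buffers. The following is an editor's sketch only and is not part of the removed file; MPI point-to-point calls (and the names ringAllGather, input, output) are used purely as a stand-in for the TCP send/receive workers.

#include <mpi.h>

#include <algorithm>
#include <vector>

// Editor's sketch (not THD code): ring all-gather of `count` doubles per rank.
// After size - 1 rotations every rank holds all `size` blocks in `output`.
void ringAllGather(const std::vector<double>& input,
                   std::vector<double>& output, // holds size * count doubles
                   int rank, int size, MPI_Comm comm) {
  const int count = static_cast<int>(input.size());
  const int left = (rank + size - 1) % size;
  const int right = (rank + 1) % size;

  // Start with our own block already in place.
  std::copy(input.begin(), input.end(), output.begin() + rank * count);

  int j = rank;      // block we send in this rotation
  int jnext = left;  // block we receive in this rotation
  for (int i = 0; i < size - 1; ++i) {
    MPI_Sendrecv(output.data() + j * count, count, MPI_DOUBLE, right, 0,
                 output.data() + jnext * count, count, MPI_DOUBLE, left, 0,
                 comm, MPI_STATUS_IGNORE);
    j = jnext;
    jnext = (jnext + size - 1) % size;
  }
}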
|
||||
|
||||
void DataChannelTCP::gather(
|
||||
std::vector<at::Tensor>& output,
|
||||
at::Tensor& input,
|
||||
rank_type dst_rank,
|
||||
THDGroup group_id) {
|
||||
std::lock_guard<std::mutex> lock(_mutex);
|
||||
|
||||
const auto& group = _groups.at(group_id);
|
||||
bool exists;
|
||||
|
||||
std::tie(std::ignore, exists) = group.getGroupRank(_rank);
|
||||
if (!exists)
|
||||
return;
|
||||
|
||||
// assert if dst_rank exists in group
|
||||
group.mustGetGroupRank(dst_rank);
|
||||
if (_rank != dst_rank) {
|
||||
send(input, dst_rank);
|
||||
} else {
|
||||
if (output.size() != group.size())
|
||||
throw std::logic_error(
|
||||
"gather: number of output tensors and group size does not match");
|
||||
|
||||
for (auto out_tensor : output)
|
||||
assertSameSizeAndType(out_tensor, input, "gather");
|
||||
|
||||
for (rank_type i = 0; i < group.size(); ++i) {
|
||||
auto global_rank = group.mustGetGlobalRank(i);
|
||||
if (_rank != global_rank) {
|
||||
receive((output.at(i)), global_rank);
|
||||
} else {
|
||||
memcpy(
|
||||
output.at(i).data_ptr(),
|
||||
input.data_ptr(),
|
||||
input.numel() * input.element_size());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void DataChannelTCP::scatter(
|
||||
std::vector<at::Tensor>& input,
|
||||
at::Tensor& output,
|
||||
rank_type src_rank,
|
||||
THDGroup group_id) {
|
||||
std::lock_guard<std::mutex> lock(_mutex);
|
||||
|
||||
const auto& group = _groups.at(group_id);
|
||||
bool exists;
|
||||
|
||||
std::tie(std::ignore, exists) = group.getGroupRank(_rank);
|
||||
if (!exists)
|
||||
return;
|
||||
|
||||
// assert if src_rank exists in group
|
||||
group.mustGetGroupRank(src_rank);
|
||||
if (_rank != src_rank) {
|
||||
receive(output, src_rank);
|
||||
} else {
|
||||
if (input.size() != group.size())
|
||||
throw std::logic_error(
|
||||
"scatter: number of input tensors and group size does not match");
|
||||
|
||||
for (auto in_tensor : input)
|
||||
assertSameSizeAndType(in_tensor, output, "scatter");
|
||||
|
||||
for (rank_type i = 0; i < group.size(); ++i) {
|
||||
auto global_rank = group.mustGetGlobalRank(i);
|
||||
if (_rank != global_rank) {
|
||||
send((input.at(i)), global_rank);
|
||||
} else {
|
||||
memcpy(
|
||||
output.data_ptr(),
|
||||
input.at(i).data_ptr(),
|
||||
output.numel() * output.element_size());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void DataChannelTCP::allReduce(
|
||||
at::Tensor& data,
|
||||
THDReduceOp operation,
|
||||
THDGroup group_id) {
|
||||
/*
|
||||
* The allreduce implementation is a recursive doubling algorithm. It is a
* good algorithm for small message sizes, but other (theoretically better)
* implementations could not be adapted because of non-commutative
* operations on tensors (the operation cannot be treated as commutative,
* since that could introduce different numerical errors on different workers).
|
||||
*
|
||||
* More about efficiency can be found here:
|
||||
* > http://www.mcs.anl.gov/~thakur/papers/ijhpca-coll.pdf (section 4.5)
|
||||
*
|
||||
* Implementation is based on:
|
||||
* > https://github.com/pmodels/mpich/blob/master/src/mpi/coll/allreduce.c
|
||||
*/
|
||||
|
||||
std::lock_guard<std::mutex> lock(_mutex);
|
||||
|
||||
const auto& group = _groups.at(group_id);
|
||||
rank_type group_rank;
|
||||
bool exists;
|
||||
|
||||
std::tie(group_rank, exists) = group.getGroupRank(_rank);
|
||||
if (!exists)
|
||||
return;
|
||||
|
||||
uint64_t tensor_bytes = data.element_size() * data.numel();
|
||||
auto tmp_tensor = data.clone();
|
||||
|
||||
auto pof2 = pow2(group.size());
|
||||
int rem = group.size() - pof2;
|
||||
int newrank = 0;
|
||||
|
||||
if (group_rank < 2 * rem) {
|
||||
if (group_rank % 2 == 0) {
|
||||
send(data, group.mustGetGlobalRank(group_rank + 1));
|
||||
newrank = -1;
|
||||
} else {
|
||||
receive(tmp_tensor, group.mustGetGlobalRank(group_rank - 1));
|
||||
_reduce(data, tmp_tensor, operation);
|
||||
newrank = group_rank / 2;
|
||||
}
|
||||
} else {
|
||||
newrank = group_rank - rem;
|
||||
}
|
||||
|
||||
if (newrank != -1) {
|
||||
int mask = 0x1;
|
||||
while (mask < pof2) {
|
||||
int newdst = newrank ^ mask;
|
||||
int dst = (newdst < rem) ? (newdst * 2 + 1) : (newdst + rem);
|
||||
|
||||
auto dst_global_rank = group.mustGetGlobalRank(dst);
|
||||
req_ptr send_request{isend(data, dst_global_rank)};
|
||||
receive(tmp_tensor, dst_global_rank);
|
||||
send_request->wait();
|
||||
|
||||
if (dst < group_rank) {
|
||||
_reduce(data, tmp_tensor, operation);
|
||||
} else {
|
||||
_reduce(tmp_tensor, data, operation);
|
||||
std::memcpy(data.data_ptr(), tmp_tensor.data_ptr(), tensor_bytes);
|
||||
}
|
||||
|
||||
mask <<= 1;
|
||||
}
|
||||
}
|
||||
|
||||
if (group_rank < 2 * rem) {
|
||||
if (group_rank % 2) {
|
||||
send(data, group.mustGetGlobalRank(group_rank - 1));
|
||||
} else {
|
||||
receive(data, group.mustGetGlobalRank(group_rank + 1));
|
||||
}
|
||||
}
|
||||
}
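For reference, the recursive-doubling exchange above can be sketched for a single double per rank. This is an editor's sketch and not part of the removed file: it only handles the sum reduction, and MPI point-to-point calls (plus the made-up name recursiveDoublingSum) stand in for the TCP transport.

#include <mpi.h>

// Editor's sketch (not THD code): recursive-doubling sum allreduce.
// Fold the ranks beyond the largest power of two, exchange along each bit,
// then unfold, mirroring the structure of the function above.
double recursiveDoublingSum(double value, int rank, int size, MPI_Comm comm) {
  int pof2 = 1;
  while (pof2 * 2 <= size) pof2 *= 2;          // largest power of two <= size
  const int rem = size - pof2;
  int newrank;

  // Fold: the first 2*rem ranks pair up so that pof2 ranks remain active.
  if (rank < 2 * rem) {
    if (rank % 2 == 0) {
      MPI_Send(&value, 1, MPI_DOUBLE, rank + 1, 0, comm);
      newrank = -1;                            // inactive during the exchange
    } else {
      double other;
      MPI_Recv(&other, 1, MPI_DOUBLE, rank - 1, 0, comm, MPI_STATUS_IGNORE);
      value += other;
      newrank = rank / 2;
    }
  } else {
    newrank = rank - rem;
  }

  // Pairwise exchange along each bit of the power-of-two active group.
  if (newrank != -1) {
    for (int mask = 1; mask < pof2; mask <<= 1) {
      int newdst = newrank ^ mask;
      int dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
      double other;
      MPI_Sendrecv(&value, 1, MPI_DOUBLE, dst, 0,
                   &other, 1, MPI_DOUBLE, dst, 0, comm, MPI_STATUS_IGNORE);
      value += other;
    }
  }

  // Unfold: hand the result back to the ranks that sat out.
  if (rank < 2 * rem) {
    if (rank % 2) {
      MPI_Send(&value, 1, MPI_DOUBLE, rank - 1, 0, comm);
    } else {
      MPI_Recv(&value, 1, MPI_DOUBLE, rank + 1, 0, comm, MPI_STATUS_IGNORE);
    }
  }
  return value;
}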
|
||||
|
||||
void DataChannelTCP::reduce(
|
||||
at::Tensor& data,
|
||||
THDReduceOp operation,
|
||||
rank_type dst_rank,
|
||||
THDGroup group_id) {
|
||||
/*
|
||||
* The idea of this algorithm is similar to broadcast, but with the order
* and direction of communication reversed.
|
||||
*/
|
||||
|
||||
std::lock_guard<std::mutex> lock(_mutex);
|
||||
|
||||
const auto& group = _groups.at(group_id);
|
||||
rank_type group_rank;
|
||||
bool exists;
|
||||
|
||||
std::tie(group_rank, exists) = group.getGroupRank(_rank);
|
||||
if (!exists)
|
||||
return;
|
||||
|
||||
auto group_dst_rank = group.mustGetGroupRank(dst_rank);
|
||||
int dim = log2ceil(group.size());
|
||||
rank_type virtual_rank =
|
||||
(group_rank + group.size() - group_dst_rank) % group.size();
|
||||
int64_t mask = 0;
|
||||
auto result_tensor = data.clone();
|
||||
|
||||
for (int k = 0; k <= dim - 1; mask ^= (1 << k), ++k) {
|
||||
if ((virtual_rank & mask) == 0) {
|
||||
rank_type partner =
|
||||
virtual_rank ^ (1 << k); // partner has opposite bit `k`
|
||||
if (partner >= group.size())
|
||||
continue;
|
||||
|
||||
partner =
|
||||
group.mustGetGlobalRank((partner + group_dst_rank) % group.size());
|
||||
if ((virtual_rank & (1 << k)) != 0) {
|
||||
send(result_tensor, partner);
|
||||
} else {
|
||||
receive(data, partner);
|
||||
_reduce(result_tensor, data, operation);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (_rank == dst_rank)
|
||||
std::memcpy(
|
||||
data.data_ptr(),
|
||||
result_tensor.data_ptr(),
|
||||
data.element_size() * data.numel());
|
||||
}
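The reduce above is effectively the hypercube broadcast run in reverse. A minimal editor's sketch of the same binomial-tree pattern follows; it is not part of the removed file, handles only a sum of one double, and uses MPI calls (and the made-up name binomialTreeSum) as a stand-in transport.

#include <mpi.h>

// Editor's sketch (not THD code): binomial-tree sum toward `root`.
double binomialTreeSum(double value, int root, int rank, int size,
                       MPI_Comm comm) {
  int dim = 0;
  while ((1 << dim) < size) ++dim;             // ceil(log2(size))
  const int vrank = (rank + size - root) % size;

  int mask = 0;
  for (int k = 0; k < dim; ++k) {
    if ((vrank & mask) == 0) {
      int vpartner = vrank ^ (1 << k);         // partner differs in bit k
      if (vpartner < size) {
        int partner = (vpartner + root) % size;
        if (vrank & (1 << k)) {
          MPI_Send(&value, 1, MPI_DOUBLE, partner, 0, comm);
        } else {
          double other;
          MPI_Recv(&other, 1, MPI_DOUBLE, partner, 0, comm,
                   MPI_STATUS_IGNORE);
          value += other;
        }
      }
    }
    mask ^= (1 << k);                          // set bit k for the next round
  }
  return value;                                // meaningful only on `root`
}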
|
||||
|
||||
void DataChannelTCP::broadcast(
|
||||
at::Tensor& data,
|
||||
rank_type src_rank,
|
||||
THDGroup group_id) {
|
||||
/*
|
||||
* The general idea of this algorithm is to send data along a `d`-dimensional
* hypercube whose vertices are nodes (processes) and whose edges are
* network connections that can be used to transfer data.
*
* Since the hypercube algorithm assumes the broadcasting rank is 0,
* we have to create a `virtual_rank` which converts regular ranks to
* virtual ones, where the `virtual_rank` of `src_rank` is 0.
|
||||
*/
|
||||
|
||||
std::lock_guard<std::mutex> lock(_mutex);
|
||||
|
||||
const auto& group = _groups.at(group_id);
|
||||
rank_type group_rank;
|
||||
bool exists;
|
||||
|
||||
std::tie(group_rank, exists) = group.getGroupRank(_rank);
|
||||
if (!exists)
|
||||
return;
|
||||
|
||||
auto group_src_rank = group.mustGetGroupRank(src_rank);
|
||||
int dim = log2ceil(group.size());
|
||||
rank_type virtual_rank =
|
||||
(group_rank + group.size() - group_src_rank) % group.size();
|
||||
int64_t mask = (1 << dim) - 1;
|
||||
|
||||
for (int k = dim - 1; k >= 0; --k) {
|
||||
mask ^= (1 << k); // clear bit `k`
|
||||
if ((virtual_rank & mask) == 0) {
|
||||
rank_type partner =
|
||||
virtual_rank ^ (1 << k); // partner has opposite bit `k`
|
||||
if (partner >= group.size())
|
||||
continue;
|
||||
|
||||
partner =
|
||||
group.mustGetGlobalRank((partner + group_src_rank) % group.size());
|
||||
if ((virtual_rank & (1 << k)) == 0) {
|
||||
send(data, partner);
|
||||
} else {
|
||||
receive(data, partner);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
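A minimal editor's sketch of the same hypercube/binomial broadcast for a single double; it is not part of the removed file. MPI point-to-point calls stand in for the TCP transport, the name hypercubeBroadcast is made up, and the rank shift mirrors the `virtual_rank` trick used above.

#include <mpi.h>

// Editor's sketch (not THD code): hypercube broadcast of one double from `root`.
void hypercubeBroadcast(double* value, int root, int rank, int size,
                        MPI_Comm comm) {
  int dim = 0;
  while ((1 << dim) < size) ++dim;             // ceil(log2(size))
  const int vrank = (rank + size - root) % size;

  int mask = (1 << dim) - 1;
  for (int k = dim - 1; k >= 0; --k) {
    mask ^= (1 << k);                          // clear bit k
    if ((vrank & mask) == 0) {
      int vpartner = vrank ^ (1 << k);         // partner differs in bit k
      if (vpartner >= size) continue;
      int partner = (vpartner + root) % size;
      if ((vrank & (1 << k)) == 0) {
        MPI_Send(value, 1, MPI_DOUBLE, partner, 0, comm);
      } else {
        MPI_Recv(value, 1, MPI_DOUBLE, partner, 0, comm, MPI_STATUS_IGNORE);
      }
    }
  }
}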
|
||||
|
||||
void DataChannelTCP::send(Scalar& data, rank_type dst_rank) {
|
||||
auto request = _send_worker.push(
|
||||
[this, &data, dst_rank] { this->_send(data, dst_rank); });
|
||||
request.wait();
|
||||
}
|
||||
|
||||
void DataChannelTCP::send(at::Tensor& data, rank_type dst_rank) {
|
||||
auto request = _send_worker.push(
|
||||
[this, &data, dst_rank] { this->_send(data, dst_rank); });
|
||||
request.wait();
|
||||
}
|
||||
|
||||
void DataChannelTCP::receive(Scalar& data, rank_type src_rank) {
|
||||
auto request = _receive_worker.push(
|
||||
[this, &data, src_rank] { this->_receive(data, src_rank); });
|
||||
request.wait();
|
||||
}
|
||||
|
||||
rank_type DataChannelTCP::receive(at::Tensor& data) {
|
||||
rank_type sender;
|
||||
auto request = _receive_worker.push([this, &data, &sender] {
|
||||
if (!this->_poll_events) {
|
||||
// cache the poll events array; it will be reused in subsequent `receive` calls
|
||||
this->_poll_events.reset(new struct pollfd[this->_processes.size()]);
|
||||
for (size_t rank = 0; rank < this->_processes.size(); ++rank) {
|
||||
this->_poll_events[rank] = {.fd = this->_processes[rank].socket,
|
||||
.events = POLLIN};
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup
|
||||
for (size_t rank = 0; rank < this->_processes.size(); ++rank) {
|
||||
this->_poll_events[rank].revents = 0;
|
||||
}
|
||||
|
||||
SYSCHECK(::poll(
|
||||
this->_poll_events.get(),
|
||||
this->_processes.size(),
|
||||
-1)) // infinite timeout
|
||||
for (size_t rank = 0; rank < this->_processes.size(); ++rank) {
|
||||
if (this->_poll_events[rank].revents == 0)
|
||||
continue;
|
||||
|
||||
if (this->_poll_events[rank].revents ^ POLLIN)
|
||||
throw std::system_error(ECONNABORTED, std::system_category());
|
||||
|
||||
this->_receive(data, rank);
|
||||
sender = rank;
|
||||
break;
|
||||
}
|
||||
});
|
||||
|
||||
request.wait();
|
||||
return sender;
|
||||
}
|
||||
|
||||
void DataChannelTCP::receive(at::Tensor& data, rank_type src_rank) {
|
||||
auto request = _receive_worker.push(
|
||||
[this, &data, src_rank] { this->_receive(data, src_rank); });
|
||||
request.wait();
|
||||
}
|
||||
|
||||
DataChannelTCP::RequestTCP* DataChannelTCP::isend(
|
||||
at::Tensor& data,
|
||||
rank_type dst_rank) {
|
||||
auto request = _send_worker.push(
|
||||
[this, data, dst_rank] { this->_send(data, dst_rank); });
|
||||
return new DataChannelTCP::RequestTCP(std::move(request));
|
||||
}
|
||||
|
||||
DataChannelTCP::RequestTCP* DataChannelTCP::ireceive(
|
||||
at::Tensor& data,
|
||||
rank_type src_rank) {
|
||||
auto request = _receive_worker.push(
|
||||
[this, data, src_rank] { this->_receive(data, src_rank); });
|
||||
return new DataChannelTCP::RequestTCP(std::move(request));
|
||||
}
|
||||
|
||||
void DataChannelTCP::barrier(THDGroup group_id) {
|
||||
/*
|
||||
* The barrier is an implementation of the Bruck (dissemination) algorithm.
* Every process sends to the process with rank (i + 2^k) and receives from
* the process with rank (i - 2^k), with wrap-around. Since we cannot recv
* and send at the same time, we post the recv asynchronously (on a worker
* thread), send the byte, and then wait for the recv to complete.
|
||||
*/
|
||||
|
||||
std::lock_guard<std::mutex> lock(_mutex);
|
||||
|
||||
const auto& group = _groups.at(group_id);
|
||||
rank_type group_rank;
|
||||
bool exists;
|
||||
|
||||
std::tie(group_rank, exists) = group.getGroupRank(_rank);
|
||||
if (!exists)
|
||||
return;
|
||||
|
||||
std::uint8_t byte = 1;
|
||||
for (rank_type distance = 1; distance < group.size(); distance <<= 1) {
|
||||
rank_type recv_partner =
|
||||
(group_rank + group.size() - distance) % group.size();
|
||||
const auto& recv_process =
|
||||
_processes.at(group.mustGetGlobalRank(recv_partner));
|
||||
auto recv_request = _receive_worker.push([&recv_process, &byte] {
|
||||
recv_bytes<std::uint8_t>(recv_process.socket, &byte, 1);
|
||||
});
|
||||
|
||||
rank_type send_partner = (group_rank + distance) % group.size();
|
||||
const auto& send_process =
|
||||
_processes.at(group.mustGetGlobalRank(send_partner));
|
||||
auto send_request = _send_worker.push([&send_process, &byte] {
|
||||
send_bytes<std::uint8_t>(send_process.socket, &byte, 1);
|
||||
});
|
||||
|
||||
send_request.wait();
|
||||
recv_request.wait();
|
||||
}
|
||||
}
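The same dissemination pattern, sketched by the editor with MPI calls standing in for the asynchronous TCP workers; this is an illustration only, not part of the removed file, and the name disseminationBarrier is made up.

#include <mpi.h>

// Editor's sketch (not THD code): dissemination barrier. Each rank exchanges
// one byte with partners at distance 1, 2, 4, ... until everyone has
// (transitively) heard from every other rank.
void disseminationBarrier(int rank, int size, MPI_Comm comm) {
  unsigned char send_byte = 1;
  unsigned char recv_byte = 0;
  for (int distance = 1; distance < size; distance <<= 1) {
    const int send_to = (rank + distance) % size;
    const int recv_from = (rank + size - distance) % size;
    MPI_Sendrecv(&send_byte, 1, MPI_UNSIGNED_CHAR, send_to, 0,
                 &recv_byte, 1, MPI_UNSIGNED_CHAR, recv_from, 0,
                 comm, MPI_STATUS_IGNORE);
  }
}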
|
||||
|
||||
THDGroup DataChannelTCP::newGroup(const std::vector<rank_type>& ranks) {
|
||||
auto new_group = DataChannel::Group(ranks, _processes.size() - 1);
|
||||
THDGroup new_group_id = static_cast<THDGroup>(_groups.size());
|
||||
|
||||
_groups.insert({new_group_id, new_group});
|
||||
return new_group_id;
|
||||
}
|
||||
|
||||
void DataChannelTCP::_send(const Scalar& data, rank_type dst_rank) {
|
||||
/*
|
||||
* We have to check if dst_rank is positive to properly use `.at` function in
|
||||
* vector. Not checking that can result in int overflow and strange errors.
|
||||
*/
|
||||
|
||||
const auto& process_dst = _processes.at(dst_rank);
|
||||
if (process_dst.rank == _rank)
|
||||
throw std::logic_error("cannot send scalar to process with same rank");
|
||||
|
||||
// send size of scalar in bytes
|
||||
uint64_t scalar_bytes = data.elementSize();
|
||||
send_bytes<uint64_t>(process_dst.socket, &scalar_bytes, 1, true);
|
||||
|
||||
// send data (bytes)
|
||||
send_bytes<std::uint8_t>(
|
||||
process_dst.socket,
|
||||
reinterpret_cast<const std::uint8_t*>(data.data()),
|
||||
scalar_bytes);
|
||||
}
|
||||
|
||||
void DataChannelTCP::_send(const at::Tensor& data, rank_type dst_rank) {
|
||||
/*
|
||||
* We have to check if dst_rank is positive to properly use `.at` function in
|
||||
* vector. Not checking that can result in int overflow and strange errors.
|
||||
*/
|
||||
|
||||
const auto& process_dst = _processes.at(dst_rank);
|
||||
if (process_dst.rank == _rank)
|
||||
throw std::logic_error("cannot send tensor to process with same rank");
|
||||
|
||||
if (!data.is_contiguous())
|
||||
throw std::logic_error("tensor to send is not contiguous");
|
||||
|
||||
// send size of tensor data in bytes
|
||||
uint64_t tensor_bytes = data.element_size() * data.numel();
|
||||
send_bytes<uint64_t>(process_dst.socket, &tensor_bytes, 1, true);
|
||||
|
||||
// send data (bytes)
|
||||
send_bytes<std::uint8_t>(
|
||||
process_dst.socket,
|
||||
reinterpret_cast<const std::uint8_t*>(data.data_ptr()),
|
||||
tensor_bytes);
|
||||
}
|
||||
|
||||
void DataChannelTCP::_receive(Scalar& data, rank_type src_rank) {
|
||||
/*
|
||||
* We have to check if src_rank is positive to properly use `.at` function in
|
||||
* vector. Not checking that can result in int overflow and strange errors.
|
||||
*/
|
||||
|
||||
const auto& process_src = _processes.at(src_rank);
|
||||
if (process_src.rank == _rank)
|
||||
throw std::logic_error("cannot receive scalar from process with same rank");
|
||||
|
||||
// get size of scalar in bytes
|
||||
uint64_t scalar_bytes;
|
||||
recv_bytes<uint64_t>(process_src.socket, &scalar_bytes, 1);
|
||||
|
||||
uint64_t actual_scalar_bytes = data.elementSize();
|
||||
if (actual_scalar_bytes == scalar_bytes) {
|
||||
recv_bytes<std::uint8_t>(
|
||||
process_src.socket,
|
||||
reinterpret_cast<std::uint8_t*>(data.data()),
|
||||
scalar_bytes);
|
||||
} else {
|
||||
// remove invalid data from recv buffer
|
||||
std::unique_ptr<std::uint8_t[]> bytes(new std::uint8_t[scalar_bytes]);
|
||||
recv_bytes<std::uint8_t>(process_src.socket, bytes.get(), scalar_bytes);
|
||||
throw std::logic_error("scalar sizes do not match");
|
||||
}
|
||||
}
|
||||
|
||||
void DataChannelTCP::_receive(const at::Tensor& data, rank_type src_rank) {
|
||||
/*
|
||||
* We have to check if src_rank is positive to properly use `.at` function in
|
||||
* vector. Not checking that can result in int overflow and strange errors.
|
||||
*/
|
||||
|
||||
const auto& process_src = _processes.at(src_rank);
|
||||
if (process_src.rank == _rank)
|
||||
throw std::logic_error("cannot receive tensor from process with same rank");
|
||||
|
||||
if (!data.is_contiguous())
|
||||
throw std::logic_error("tensor to receive is not contiguous");
|
||||
|
||||
// get size of tensor data in bytes
|
||||
uint64_t tensor_bytes;
|
||||
recv_bytes<uint64_t>(process_src.socket, &tensor_bytes, 1);
|
||||
|
||||
uint64_t actual_tensor_bytes =
|
||||
data.element_size() * data.numel();
|
||||
if (actual_tensor_bytes == tensor_bytes) {
|
||||
recv_bytes<std::uint8_t>(
|
||||
process_src.socket,
|
||||
reinterpret_cast<std::uint8_t*>(data.data_ptr()),
|
||||
tensor_bytes);
|
||||
} else {
|
||||
// remove invalid data from recv buffer
|
||||
std::unique_ptr<std::uint8_t[]> bytes(new std::uint8_t[tensor_bytes]);
|
||||
recv_bytes<std::uint8_t>(process_src.socket, bytes.get(), tensor_bytes);
|
||||
throw std::logic_error("tensor sizes do not match");
|
||||
}
|
||||
}
|
||||
|
||||
void DataChannelTCP::_reduce(
|
||||
at::Tensor& result,
|
||||
at::Tensor& data,
|
||||
THDReduceOp operation) const {
|
||||
assertSameSizeAndType(result, data, "reduce");
|
||||
|
||||
if (operation == THDReduceOp::THDReduceMIN) {
|
||||
at::min_out(result, result, data);
|
||||
} else if (operation == THDReduceOp::THDReduceMAX) {
|
||||
at::max_out(result, result, data);
|
||||
} else if (operation == THDReduceOp::THDReduceSUM) {
|
||||
result.add_(data);
|
||||
} else if (operation == THDReduceOp::THDReducePRODUCT) {
|
||||
result.mul_(data);
|
||||
} else {
|
||||
throw std::logic_error("unsupported reduce operation");
|
||||
}
|
||||
}
|
||||
|
||||
void DataChannelTCP::allReduce(
|
||||
std::vector<at::Tensor>& data,
|
||||
THDReduceOp operation,
|
||||
THDGroup groupId) {
|
||||
throw std::runtime_error(
|
||||
"DataChannelTCP does not support mult-GPU cross "
|
||||
"node allreduce");
|
||||
}
|
||||
|
||||
void DataChannelTCP::allGather(
|
||||
std::vector<at::Tensor>& output,
|
||||
std::vector<at::Tensor>& input,
|
||||
THDGroup groupId) {
|
||||
throw std::runtime_error(
|
||||
"DataChannelTCP does not support mult-GPU cross "
|
||||
"node allgather");
|
||||
}
|
||||
|
||||
void DataChannelTCP::reduce(
|
||||
std::vector<at::Tensor>& data,
|
||||
THDReduceOp operation,
|
||||
rank_type dstRank,
|
||||
THDGroup groupId) {
|
||||
throw std::runtime_error(
|
||||
"DataChannelTCP does not support mult-GPU cross "
|
||||
"node reduce");
|
||||
}
|
||||
|
||||
void DataChannelTCP::broadcast(
|
||||
std::vector<at::Tensor>& data,
|
||||
rank_type srcRank,
|
||||
THDGroup groupId) {
|
||||
throw std::runtime_error(
|
||||
"DataChannelTCP does not support mult-GPU cross "
|
||||
"node broadcast");
|
||||
}
|
||||
|
||||
void DataChannelTCP::clearGroupCache(THDGroup group_id) {
|
||||
throw std::runtime_error(
|
||||
"DataChannelTCP does not support clear "
|
||||
"group cache");
|
||||
}
|
||||
|
||||
} // namespace thd
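// _send/_receive above frame every transfer as a uint64_t byte count followed
// by the raw payload, and the receiver drains mismatched payloads so the
// socket stays usable. A minimal standalone sketch of that framing over a
// POSIX socketpair (sendFrame/recvFrame are illustrative helpers, not THD
// functions; short reads/writes are treated as errors for brevity):
#include <cstdint>
#include <stdexcept>
#include <string>
#include <sys/socket.h>
#include <unistd.h>

static void sendFrame(int fd, const void* data, uint64_t bytes) {
  // send the size first, then the payload
  if (::write(fd, &bytes, sizeof(bytes)) != (ssize_t)sizeof(bytes))
    throw std::runtime_error("short write on size header");
  if (::write(fd, data, bytes) != (ssize_t)bytes)
    throw std::runtime_error("short write on payload");
}

static std::string recvFrame(int fd) {
  uint64_t bytes = 0;
  if (::read(fd, &bytes, sizeof(bytes)) != (ssize_t)sizeof(bytes))
    throw std::runtime_error("short read on size header");
  std::string payload(bytes, '\0');
  if (::read(fd, &payload[0], bytes) != (ssize_t)bytes)
    throw std::runtime_error("short read on payload");
  return payload;
}

int main() {
  int fds[2];
  if (::socketpair(AF_UNIX, SOCK_STREAM, 0, fds) != 0)
    return 1;
  sendFrame(fds[0], "hello", 5);
  std::string got = recvFrame(fds[1]);  // == "hello"
  ::close(fds[0]);
  ::close(fds[1]);
  return got == "hello" ? 0 : 1;
}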
|
||||
|
|
@ -1,134 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include <THD/base/DataChannel.hpp>
|
||||
#include <THD/base/data_channels/DataChannelUtils.hpp>
|
||||
|
||||
#include <sys/poll.h>
|
||||
#include <cstdint>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace thd {
|
||||
|
||||
struct DataChannelTCP : DataChannel {
|
||||
struct RequestTCP : DataChannel::Request {
|
||||
RequestTCP(QueueWorker::Request&& request);
|
||||
virtual ~RequestTCP();
|
||||
|
||||
virtual bool isCompleted() override;
|
||||
virtual void wait() override;
|
||||
|
||||
private:
|
||||
QueueWorker::Request _request;
|
||||
};
|
||||
|
||||
DataChannelTCP(InitMethod::Config config);
|
||||
DataChannelTCP(InitMethod::Config config, int timeout);
|
||||
virtual ~DataChannelTCP();
|
||||
|
||||
bool init() override;
|
||||
void destroy() override;
|
||||
|
||||
rank_type getRank() override;
|
||||
rank_type getNumProcesses() override;
|
||||
|
||||
void allGather(
|
||||
std::vector<at::Tensor>& output,
|
||||
std::vector<at::Tensor>& input,
|
||||
THDGroup group_id = THDGroupWORLD) override;
|
||||
void allGather(
|
||||
std::vector<at::Tensor>& output,
|
||||
at::Tensor& input,
|
||||
THDGroup group_id = THDGroupWORLD) override;
|
||||
void gather(
|
||||
std::vector<at::Tensor>& output,
|
||||
at::Tensor& input,
|
||||
rank_type dst_rank,
|
||||
THDGroup group_id = THDGroupWORLD) override;
|
||||
void scatter(
|
||||
std::vector<at::Tensor>& input,
|
||||
at::Tensor& output,
|
||||
rank_type src_rank,
|
||||
THDGroup group_id = THDGroupWORLD) override;
|
||||
void allReduce(
|
||||
std::vector<at::Tensor>& data,
|
||||
THDReduceOp operation,
|
||||
THDGroup group_id = THDGroupWORLD) override;
|
||||
void allReduce(
|
||||
at::Tensor& data,
|
||||
THDReduceOp operation,
|
||||
THDGroup group_id = THDGroupWORLD) override;
|
||||
void reduce(
|
||||
std::vector<at::Tensor>& data,
|
||||
THDReduceOp operation,
|
||||
rank_type dstRank,
|
||||
THDGroup group_id = THDGroupWORLD) override;
|
||||
void reduce(
|
||||
at::Tensor& data,
|
||||
THDReduceOp operation,
|
||||
rank_type dst_rank,
|
||||
THDGroup group_id = THDGroupWORLD) override;
|
||||
void broadcast(
|
||||
std::vector<at::Tensor>& data,
|
||||
rank_type srcRank,
|
||||
THDGroup group_id = THDGroupWORLD) override;
|
||||
void broadcast(
|
||||
at::Tensor& data,
|
||||
rank_type src_id,
|
||||
THDGroup group_id = THDGroupWORLD) override;
|
||||
void send(Scalar& data, rank_type dst_id) override;
|
||||
void send(at::Tensor& data, rank_type dst_id) override;
|
||||
void receive(Scalar& data, rank_type src_id) override;
|
||||
rank_type receive(at::Tensor& data) override;
|
||||
void receive(at::Tensor& data, rank_type src_id) override;
|
||||
RequestTCP* isend(at::Tensor& data, rank_type dst_rank) override;
|
||||
RequestTCP* ireceive(at::Tensor& data, rank_type src_rank) override;
|
||||
|
||||
void barrier(THDGroup group_id = THDGroupWORLD) override;
|
||||
|
||||
THDGroup newGroup(const std::vector<rank_type>& ranks) override;
|
||||
void clearGroupCache(THDGroup group_id = THDGroupWORLD) override;
|
||||
|
||||
private:
|
||||
using req_ptr = std::unique_ptr<RequestTCP>;
|
||||
// Defines a process to which the master or a worker is connected
|
||||
struct Process {
|
||||
rank_type rank;
|
||||
std::string address;
|
||||
port_type port;
|
||||
int socket;
|
||||
};
|
||||
|
||||
bool initMaster();
|
||||
bool initWorker();
|
||||
|
||||
void _send(const Scalar& data, rank_type dst_id);
|
||||
void _send(const at::Tensor& data, rank_type dst_id);
|
||||
void _receive(Scalar& data, rank_type src_id);
|
||||
void _receive(const at::Tensor& data, rank_type src_id);
|
||||
void _reduce(at::Tensor& result, at::Tensor& data, THDReduceOp operation)
|
||||
const;
|
||||
|
||||
rank_type _rank; // Rank of current process, range: [0.._processes.size()-1]
|
||||
int _socket; // Socket on which process is listening
|
||||
port_type _port; // Port on which process is listening
|
||||
int _timeout; // Accept waiting timeout in milliseconds (it is optional,
|
||||
// default = infinity)
|
||||
|
||||
std::vector<Process> _processes; // Other processes in network
|
||||
std::unique_ptr<struct pollfd[]> _poll_events; // Events array for `poll`
|
||||
|
||||
// General mutex for methods - to protect access to the TCP data channel.
|
||||
std::mutex _mutex;
|
||||
|
||||
// Existing groups of processes and corresponding group ids
|
||||
std::unordered_map<THDGroup, DataChannel::Group> _groups;
|
||||
|
||||
// Workers
|
||||
QueueWorker _send_worker, _receive_worker;
|
||||
};
|
||||
|
||||
} // namespace thd
|
||||
|
|
@ -1,156 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include <THD/base/DataChannel.hpp>
|
||||
|
||||
#include <atomic>
|
||||
#include <condition_variable>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <queue>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
|
||||
namespace thd {
|
||||
|
||||
inline void assertSameSizeAndType(
|
||||
const at::Tensor& tensor1,
|
||||
const at::Tensor& tensor2,
|
||||
std::string prefix = std::string()) {
|
||||
bool equal = tensor1.element_size() ==
|
||||
tensor2.element_size() &&
|
||||
tensor1.numel() == tensor2.numel() && tensor1.type() == tensor2.type();
|
||||
|
||||
if (!prefix.empty())
|
||||
prefix = prefix + ": ";
|
||||
|
||||
if (!equal)
|
||||
throw std::logic_error(
|
||||
prefix + "tensors are not equal in size or data type");
|
||||
}
|
||||
|
||||
struct QueueWorker {
|
||||
private:
|
||||
struct Task {
|
||||
Task(std::function<void()>&& handler)
|
||||
: _handler(handler), _completed(false) {}
|
||||
Task(const Task&) = delete;
|
||||
Task& operator=(const Task&) = delete;
|
||||
|
||||
void run() {
|
||||
std::unique_lock<std::mutex> ulock(_mutex);
|
||||
|
||||
try {
|
||||
_handler();
|
||||
} catch (...) {
|
||||
// Do not propagate exception here. We should save it and throw it
|
||||
// in the `isCompleted` or `wait` function to the user.
|
||||
_exception = std::current_exception();
|
||||
}
|
||||
|
||||
_completed = true;
|
||||
ulock.unlock();
|
||||
_cond.notify_all();
|
||||
}
|
||||
|
||||
bool isCompleted() {
|
||||
std::unique_lock<std::mutex> ulock(_mutex);
|
||||
_validate();
|
||||
return _completed;
|
||||
}
|
||||
|
||||
void wait() {
|
||||
std::unique_lock<std::mutex> ulock(_mutex);
|
||||
while (!_completed) // loop to guard against spurious wakeups
|
||||
_cond.wait(ulock);
|
||||
|
||||
_validate();
|
||||
}
|
||||
|
||||
private:
|
||||
void _validate() {
|
||||
if (_exception)
|
||||
std::rethrow_exception(_exception);
|
||||
}
|
||||
|
||||
std::function<void()> _handler;
|
||||
std::atomic<bool> _completed;
|
||||
std::mutex _mutex;
|
||||
std::condition_variable _cond;
|
||||
std::exception_ptr _exception;
|
||||
};
|
||||
|
||||
public:
|
||||
struct Request {
|
||||
Request(std::shared_ptr<QueueWorker::Task> item) : _item(item) {}
|
||||
|
||||
void wait() {
|
||||
_item->wait();
|
||||
}
|
||||
bool isCompleted() {
|
||||
return _item->isCompleted();
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<QueueWorker::Task> _item;
|
||||
};
|
||||
|
||||
QueueWorker() : _exiting(false) {
|
||||
_main_thread = std::thread(&QueueWorker::_runner, this);
|
||||
}
|
||||
|
||||
~QueueWorker() {
|
||||
_exiting = true;
|
||||
_cond.notify_one();
|
||||
_main_thread.join();
|
||||
}
|
||||
|
||||
QueueWorker(const QueueWorker&) = delete;
|
||||
QueueWorker& operator=(const QueueWorker&) = delete;
|
||||
|
||||
Request push(std::function<void()>&& f) {
|
||||
auto item = _push(std::make_shared<Task>(std::move(f)));
|
||||
return Request(item);
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<Task> _pop() {
|
||||
std::unique_lock<std::mutex> ulock(_mutex);
|
||||
while (_queue.empty() && !_exiting) // loop until there is work or we are exiting
|
||||
_cond.wait(ulock);
|
||||
|
||||
if (_exiting) // check if we were woken up by destructor
|
||||
return nullptr;
|
||||
|
||||
auto val = _queue.front();
|
||||
_queue.pop();
|
||||
return val;
|
||||
}
|
||||
|
||||
std::shared_ptr<Task> _push(std::shared_ptr<Task> item) {
|
||||
std::unique_lock<std::mutex> ulock(_mutex);
|
||||
_queue.push(item);
|
||||
ulock.unlock();
|
||||
_cond.notify_one();
|
||||
return item;
|
||||
}
|
||||
|
||||
void _runner() {
|
||||
while (true) {
|
||||
auto item = _pop();
|
||||
if (!item) // empty item -> we need to end (destructor called)
|
||||
return;
|
||||
|
||||
item->run();
|
||||
}
|
||||
}
|
||||
|
||||
std::atomic<bool> _exiting;
|
||||
std::queue<std::shared_ptr<Task>> _queue;
|
||||
std::mutex _mutex;
|
||||
std::condition_variable _cond;
|
||||
|
||||
std::thread _main_thread;
|
||||
};
|
||||
|
||||
} // namespace thd
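// QueueWorker above runs tasks on a single background thread and hands back a
// Request whose wait()/isCompleted() rethrow any exception captured while the
// task ran. Roughly the same contract can be sketched with std::async (an
// illustration only -- unlike QueueWorker, std::async does not serialize all
// tasks onto one FIFO worker thread):
#include <future>
#include <iostream>
#include <stdexcept>

int main() {
  // push() ~ std::async; QueueWorker::Request ~ std::future
  std::future<void> request = std::async(std::launch::async, [] {
    throw std::runtime_error("send failed");  // captured, not lost
  });
  try {
    request.get();  // like Request::wait(): rethrows the stored exception
  } catch (const std::runtime_error& e) {
    std::cout << "caught: " << e.what() << "\n";
  }
  return 0;
}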
|
||||
|
|
@ -1,516 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include <THD/base/ChannelUtils.hpp>
|
||||
#include <THD/base/Cuda.hpp>
|
||||
#include <THD/base/DataChannel.hpp>
|
||||
|
||||
#include <gloo/algorithm.h>
|
||||
#include <gloo/allgather_ring.h>
|
||||
#include <gloo/allreduce_ring.h>
|
||||
#include <gloo/barrier_all_to_all.h>
|
||||
#include <gloo/broadcast_one_to_all.h>
|
||||
#ifdef USE_CUDA
|
||||
#include <gloo/cuda_allreduce_halving_doubling.h>
|
||||
#include <gloo/cuda_allreduce_halving_doubling_pipelined.h>
|
||||
#include <gloo/cuda_allreduce_ring.h>
|
||||
#include <gloo/cuda_broadcast_one_to_all.h>
|
||||
#endif
|
||||
#include <gloo/rendezvous/context.h>
|
||||
#include <gloo/rendezvous/prefix_store.h>
|
||||
#include <gloo/rendezvous/store.h>
|
||||
|
||||
#ifdef USE_CUDA
|
||||
#include <THC/THC.h>
|
||||
#include <cuda.h>
|
||||
#endif
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
namespace thd {
|
||||
namespace gloo_cache {
|
||||
|
||||
using key_type = std::tuple<
|
||||
CollectiveType, // operation
|
||||
THDGroup, // group
|
||||
DeviceType, // tensors device type
|
||||
int, // CUDA stream id used in the algorithm
|
||||
size_t, // input buffer bytes
|
||||
size_t, // output buffer bytes
|
||||
THDReduceOp, // reduce op
|
||||
rank_type // src/dest rank
|
||||
>;
|
||||
|
||||
const DeviceType UNUSED_DEVICE = DeviceType::LAST;
|
||||
const THDReduceOp UNUSED_OP = THDReduceMIN;
|
||||
const int UNUSED_STREAM = -1;
|
||||
const rank_type UNUSED_RANK = -1;
|
||||
const size_t UNUSED_BYTES = 0;
|
||||
|
||||
// Forward declaration
|
||||
template <CollectiveType D, typename T>
|
||||
struct algorithm_spec;
|
||||
|
||||
} // namespace gloo_cache
|
||||
} // namespace thd
|
||||
|
||||
MAKE_HASHABLE(
|
||||
thd::gloo_cache::key_type,
|
||||
std::get<0>(t),
|
||||
std::get<1>(t),
|
||||
std::get<2>(t),
|
||||
std::get<3>(t),
|
||||
std::get<4>(t),
|
||||
std::get<5>(t),
|
||||
std::get<6>(t),
|
||||
std::get<7>(t));
|
||||
|
||||
namespace thd {
|
||||
|
||||
struct GlooCache {
|
||||
using buffer_type = char;
|
||||
using algorithm_type = ::gloo::Algorithm;
|
||||
using context_type = ::gloo::rendezvous::Context;
|
||||
using prefix_store_type = ::gloo::rendezvous::PrefixStore;
|
||||
using store_type = ::gloo::rendezvous::Store;
|
||||
|
||||
using key_type = gloo_cache::key_type;
|
||||
using value_type = std::tuple<
|
||||
std::shared_ptr<algorithm_type>, // algorithm
|
||||
std::shared_ptr<buffer_type>, // input buffer (nullptr if not used)
|
||||
std::shared_ptr<buffer_type>, // output buffer (nullptr if not used)
|
||||
std::shared_ptr<std::mutex> // mutex to protect the same algorithm from
|
||||
// running concurrently
|
||||
>;
|
||||
|
||||
GlooCache(
|
||||
rank_type rank,
|
||||
std::vector<std::shared_ptr<::gloo::transport::Device>> deviceList)
|
||||
: _rank(rank), _deviceList(deviceList) {}
|
||||
|
||||
GlooCache(GlooCache const&) = delete;
|
||||
void operator=(GlooCache const&) = delete;
|
||||
|
||||
// Accessors for value_type tuple
|
||||
static inline std::shared_ptr<algorithm_type> algorithm(const value_type& t) {
|
||||
return std::get<0>(t);
|
||||
}
|
||||
|
||||
static inline std::shared_ptr<buffer_type> input_buffer(const value_type& t) {
|
||||
return std::get<1>(t);
|
||||
}
|
||||
|
||||
static inline std::shared_ptr<buffer_type> output_buffer(
|
||||
const value_type& t) {
|
||||
return std::get<2>(t);
|
||||
}
|
||||
|
||||
static inline std::shared_ptr<std::mutex> mutex(const value_type& t) {
|
||||
return std::get<3>(t);
|
||||
}
|
||||
|
||||
// NOTE: this function needs to be thread safe
|
||||
std::shared_ptr<context_type> createContext(
|
||||
const DataChannelGloo::Group& group,
|
||||
const std::string& prefix) {
|
||||
/**
|
||||
* We currently only support a single Infiniband interface. In other words,
|
||||
* if there are multiple Infiniband devices in the system, Gloo will detect
|
||||
* all of them and use the first device.
|
||||
*
|
||||
* TODO: This can be extended later to utilize multiple Infiniband devices
|
||||
*
|
||||
* For ethernet, _deviceList[0] will always have the default ethernet
|
||||
* device that is detected from the user's provided IP address and there
|
||||
* won't be more than one device in _deviceList
|
||||
*
|
||||
* For Infiniband, _deviceList[0], which is the first found IB interface,
|
||||
* will be used by all Gloo operations.
|
||||
*/
|
||||
size_t curDevice = 0;
|
||||
auto context = std::make_shared<context_type>(
|
||||
group.mustGetGroupRank(_rank), group.size());
|
||||
prefix_store_type prefix_store(prefix, *group._store);
|
||||
context->connectFullMesh(prefix_store, _deviceList[curDevice]);
|
||||
return context;
|
||||
}
|
||||
|
||||
// NOTE: this function needs to be thread safe
|
||||
std::shared_ptr<buffer_type> createBuffer(size_t bytes, DeviceType device)
|
||||
const {
|
||||
if (device == DeviceType::CPU) {
|
||||
return std::shared_ptr<buffer_type>(
|
||||
new char[bytes], std::default_delete<char[]>());
|
||||
#ifdef USE_CUDA
|
||||
} else if (device == DeviceType::CUDA) {
|
||||
buffer_type* buf =
|
||||
static_cast<buffer_type*>(THCudaMalloc(THDGetCudaState(), bytes));
|
||||
return std::shared_ptr<buffer_type>(
|
||||
buf, [](char* ptr) { THCudaFree(THDGetCudaState(), ptr); });
|
||||
#endif
|
||||
} else {
|
||||
throw std::runtime_error("unsupported device in GlooCache::createBuffer");
|
||||
}
|
||||
}
|
||||
|
||||
template <CollectiveType D, typename T, typename... Args>
|
||||
value_type getAlgorithm(
|
||||
THDGroup group_id,
|
||||
const DataChannelGloo::Group& group,
|
||||
Args... args) {
|
||||
auto key = gloo_cache::algorithm_spec<D, T>::key(group_id, args...);
|
||||
|
||||
std::unique_lock<std::mutex> lock(_mutex);
|
||||
auto it = _algorithms.find(key);
|
||||
if (it == _algorithms.end()) {
|
||||
lock.unlock();
|
||||
|
||||
auto algorithm = gloo_cache::algorithm_spec<D, T>::create(
|
||||
*this, group, print_key(key), std::forward<Args>(args)...);
|
||||
|
||||
lock.lock();
|
||||
|
||||
bool inserted;
|
||||
std::tie(it, inserted) =
|
||||
_algorithms.emplace(std::move(key), std::move(algorithm));
|
||||
if (!inserted)
|
||||
throw std::runtime_error(
|
||||
"detected a race when creating Gloo algorithm");
|
||||
}
|
||||
|
||||
return it->second;
|
||||
}
|
||||
|
||||
static void memcpy_input(value_type& info, at::Tensor& t) {
|
||||
uint64_t tensor_bytes = t.element_size() * t.numel();
|
||||
auto t_dev = getDeviceType(t);
|
||||
auto input_buffer = GlooCache::input_buffer(info).get();
|
||||
|
||||
if (t_dev == DeviceType::CPU) {
|
||||
std::memcpy(input_buffer, t.data_ptr(), tensor_bytes);
|
||||
#ifdef USE_CUDA
|
||||
} else if (t_dev == DeviceType::CUDA) {
|
||||
auto stream = THCState_getCurrentStream(THDGetCudaState());
|
||||
THCudaCheck(cudaMemcpyAsync(
|
||||
input_buffer,
|
||||
t.data_ptr(),
|
||||
tensor_bytes,
|
||||
cudaMemcpyDeviceToDevice,
|
||||
stream));
|
||||
#endif
|
||||
} else {
|
||||
throw std::runtime_error("unsupported device in memcpy_input");
|
||||
}
|
||||
}
|
||||
|
||||
static void memcpy_output(value_type& info, at::Tensor& t) {
|
||||
uint64_t tensor_bytes = t.element_size() * t.numel();
|
||||
auto t_dev = getDeviceType(t);
|
||||
auto output_buffer = GlooCache::output_buffer(info).get();
|
||||
|
||||
if (t_dev == DeviceType::CPU) {
|
||||
std::memcpy(t.data_ptr(), output_buffer, tensor_bytes);
|
||||
#ifdef USE_CUDA
|
||||
} else if (t_dev == DeviceType::CUDA) {
|
||||
auto stream = THCState_getCurrentStream(THDGetCudaState());
|
||||
THCudaCheck(cudaMemcpyAsync(
|
||||
t.data_ptr(),
|
||||
output_buffer,
|
||||
tensor_bytes,
|
||||
cudaMemcpyDeviceToDevice,
|
||||
stream));
|
||||
#endif
|
||||
} else {
|
||||
throw std::runtime_error("unsupported device in memcpy_input");
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::string print_key(const key_type& k) {
|
||||
return std::to_string(static_cast<uint8_t>(std::get<0>(k))) + "-" +
|
||||
std::to_string(std::get<1>(k)) + "-" +
|
||||
std::to_string(static_cast<uint8_t>(std::get<2>(k))) + "-" +
|
||||
std::to_string(std::get<3>(k)) + "-" + std::to_string(std::get<4>(k)) +
|
||||
"-" + std::to_string(std::get<5>(k)) + "-" +
|
||||
std::to_string(std::get<6>(k)) + "-" + std::to_string(std::get<7>(k));
|
||||
}
|
||||
|
||||
rank_type _rank;
|
||||
std::vector<std::shared_ptr<::gloo::transport::Device>> _deviceList;
|
||||
std::shared_ptr<store_type> _store;
|
||||
|
||||
std::mutex _mutex;
|
||||
|
||||
std::unordered_map<key_type, value_type> _algorithms;
|
||||
};
|
||||
|
||||
namespace gloo_cache {
|
||||
|
||||
template <typename T>
|
||||
const ::gloo::ReductionFunction<T>* THDToGlooReduceOp(THDReduceOp op) {
|
||||
switch (op) {
|
||||
case THDReduceMIN:
|
||||
return ::gloo::ReductionFunction<T>::min;
|
||||
case THDReduceMAX:
|
||||
return ::gloo::ReductionFunction<T>::max;
|
||||
case THDReduceSUM:
|
||||
return ::gloo::ReductionFunction<T>::sum;
|
||||
case THDReducePRODUCT:
|
||||
return ::gloo::ReductionFunction<T>::product;
|
||||
default:
|
||||
throw std::invalid_argument("unknown reduce operation");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct algorithm_spec<CollectiveType::ALL_GATHER, T> {
|
||||
static GlooCache::key_type key(
|
||||
THDGroup group_id,
|
||||
DeviceType device,
|
||||
size_t input_bytes,
|
||||
size_t output_bytes,
|
||||
size_t unused_count) {
|
||||
return std::make_tuple(
|
||||
CollectiveType::ALL_GATHER,
|
||||
group_id,
|
||||
device,
|
||||
UNUSED_STREAM,
|
||||
input_bytes,
|
||||
output_bytes,
|
||||
UNUSED_OP,
|
||||
UNUSED_RANK);
|
||||
}
|
||||
|
||||
static GlooCache::value_type create(
|
||||
GlooCache& cache,
|
||||
const DataChannelGloo::Group& group,
|
||||
const std::string& store_prefix,
|
||||
DeviceType device,
|
||||
size_t input_bytes,
|
||||
size_t output_bytes,
|
||||
size_t count) {
|
||||
auto context = cache.createContext(group, store_prefix);
|
||||
auto input_buffer = cache.createBuffer(input_bytes, device);
|
||||
auto output_buffer = cache.createBuffer(output_bytes, device);
|
||||
|
||||
std::shared_ptr<GlooCache::algorithm_type> algo;
|
||||
if (device == DeviceType::CPU) {
|
||||
algo = std::make_shared<::gloo::AllgatherRing<T>>(
|
||||
context,
|
||||
std::initializer_list<const T*>{
|
||||
reinterpret_cast<const T*>(input_buffer.get())},
|
||||
reinterpret_cast<T*>(output_buffer.get()),
|
||||
count);
|
||||
} else {
|
||||
throw std::runtime_error("unsupported device in Gloo allGather");
|
||||
}
|
||||
|
||||
return std::make_tuple(
|
||||
algo, input_buffer, output_buffer, std::make_shared<std::mutex>());
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct algorithm_spec<CollectiveType::ALL_REDUCE, T> {
|
||||
static GlooCache::key_type key(
|
||||
THDGroup group_id,
|
||||
DeviceType device,
|
||||
size_t input_bytes,
|
||||
size_t unused_count,
|
||||
THDReduceOp op) {
|
||||
int stream = UNUSED_STREAM;
|
||||
#ifdef USE_CUDA
|
||||
if (device == DeviceType::CUDA) {
|
||||
auto cuda_stream = THCState_getCurrentStream(THDGetCudaState());
|
||||
stream = THDGetStreamId(cuda_stream);
|
||||
}
|
||||
#endif
|
||||
return std::make_tuple(
|
||||
CollectiveType::ALL_REDUCE,
|
||||
group_id,
|
||||
device,
|
||||
stream,
|
||||
input_bytes,
|
||||
input_bytes,
|
||||
op,
|
||||
UNUSED_RANK);
|
||||
}
|
||||
|
||||
static GlooCache::value_type create(
|
||||
GlooCache& cache,
|
||||
const DataChannelGloo::Group& group,
|
||||
const std::string& store_prefix,
|
||||
DeviceType device,
|
||||
size_t input_bytes,
|
||||
size_t count,
|
||||
THDReduceOp op) {
|
||||
auto context = cache.createContext(group, store_prefix);
|
||||
auto input_buffer = cache.createBuffer(input_bytes, device);
|
||||
|
||||
std::shared_ptr<GlooCache::algorithm_type> algo;
|
||||
if (device == DeviceType::CPU) {
|
||||
algo = std::make_shared<::gloo::AllreduceRing<T>>(
|
||||
context,
|
||||
std::initializer_list<T*>{reinterpret_cast<T*>(input_buffer.get())},
|
||||
count,
|
||||
THDToGlooReduceOp<T>(op));
|
||||
#ifdef USE_CUDA
|
||||
} else if (device == DeviceType::CUDA) {
|
||||
if (op != THDReduceSUM) {
|
||||
throw std::runtime_error(
|
||||
"Gloo backend only supports sum op for CUDA all reduce");
|
||||
}
|
||||
auto stream = THCState_getCurrentStream(THDGetCudaState());
|
||||
|
||||
#if defined(USE_GLOO_IBVERBS) && USE_GLOO_IBVERBS
|
||||
// Only enable GPU direct if the device supports it
|
||||
if (context->getDevice()->hasGPUDirect()) {
|
||||
algo = std::make_shared<::gloo::CudaAllreduceHalvingDoublingPipelined<
|
||||
T,
|
||||
::gloo::CudaDeviceWorkspace<T>>>(
|
||||
context,
|
||||
std::initializer_list<T*>{reinterpret_cast<T*>(input_buffer.get())},
|
||||
count,
|
||||
std::vector<cudaStream_t>{stream});
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
algo = std::make_shared<::gloo::CudaAllreduceHalvingDoublingPipelined<
|
||||
T,
|
||||
::gloo::CudaHostWorkspace<T>>>(
|
||||
context,
|
||||
std::initializer_list<T*>{reinterpret_cast<T*>(input_buffer.get())},
|
||||
count,
|
||||
std::vector<cudaStream_t>{stream});
|
||||
}
|
||||
#endif
|
||||
|
||||
} else {
|
||||
throw std::runtime_error("unsupported tensor device in Gloo allReduce");
|
||||
}
|
||||
|
||||
return std::make_tuple(
|
||||
algo,
|
||||
input_buffer,
|
||||
input_buffer, // we get the result in same buffer
|
||||
std::make_shared<std::mutex>());
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct algorithm_spec<CollectiveType::BROADCAST, T> {
|
||||
static GlooCache::key_type key(
|
||||
THDGroup group_id,
|
||||
DeviceType device,
|
||||
size_t input_bytes,
|
||||
size_t unused_count,
|
||||
rank_type src_rank) {
|
||||
int stream = UNUSED_STREAM;
|
||||
#ifdef USE_CUDA
|
||||
if (device == DeviceType::CUDA) {
|
||||
auto cuda_stream = THCState_getCurrentStream(THDGetCudaState());
|
||||
stream = THDGetStreamId(cuda_stream);
|
||||
}
|
||||
#endif
|
||||
return std::make_tuple(
|
||||
CollectiveType::BROADCAST,
|
||||
group_id,
|
||||
device,
|
||||
stream,
|
||||
input_bytes,
|
||||
input_bytes,
|
||||
UNUSED_OP,
|
||||
src_rank);
|
||||
}
|
||||
|
||||
static GlooCache::value_type create(
|
||||
GlooCache& cache,
|
||||
const DataChannelGloo::Group& group,
|
||||
const std::string& store_prefix,
|
||||
DeviceType device,
|
||||
size_t input_bytes,
|
||||
size_t count,
|
||||
rank_type src_rank) {
|
||||
auto context = cache.createContext(group, store_prefix);
|
||||
auto input_buffer = cache.createBuffer(input_bytes, device);
|
||||
|
||||
std::shared_ptr<GlooCache::algorithm_type> algo;
|
||||
if (device == DeviceType::CPU) {
|
||||
algo = std::make_shared<::gloo::BroadcastOneToAll<T>>(
|
||||
context,
|
||||
std::initializer_list<T*>{reinterpret_cast<T*>(input_buffer.get())},
|
||||
count,
|
||||
src_rank);
|
||||
#ifdef USE_CUDA
|
||||
} else if (device == DeviceType::CUDA) {
|
||||
auto stream = THCState_getCurrentStream(THDGetCudaState());
|
||||
|
||||
#if defined(USE_GLOO_IBVERBS) && USE_GLOO_IBVERBS
|
||||
// Only enable GPU direct if the device supports it
|
||||
if (context->getDevice()->hasGPUDirect()) {
|
||||
algo = std::make_shared<
|
||||
::gloo::CudaBroadcastOneToAll<T, ::gloo::CudaDeviceWorkspace<T>>>(
|
||||
context,
|
||||
std::initializer_list<T*>{reinterpret_cast<T*>(input_buffer.get())},
|
||||
count,
|
||||
src_rank,
|
||||
0,
|
||||
std::vector<cudaStream_t>{stream});
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
algo = std::make_shared<
|
||||
::gloo::CudaBroadcastOneToAll<T, ::gloo::CudaHostWorkspace<T>>>(
|
||||
context,
|
||||
std::initializer_list<T*>{reinterpret_cast<T*>(input_buffer.get())},
|
||||
count,
|
||||
src_rank,
|
||||
0,
|
||||
std::vector<cudaStream_t>{stream});
|
||||
}
|
||||
#endif
|
||||
|
||||
} else {
|
||||
throw std::runtime_error("unsupported tensor device in Gloo broadcast");
|
||||
}
|
||||
|
||||
return std::make_tuple(
|
||||
algo,
|
||||
input_buffer,
|
||||
input_buffer, // we get the result in same buffer
|
||||
std::make_shared<std::mutex>());
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T> // unused
|
||||
struct algorithm_spec<CollectiveType::BARRIER, T> {
|
||||
static GlooCache::key_type key(THDGroup group_id) {
|
||||
return std::make_tuple(
|
||||
CollectiveType::BARRIER,
|
||||
group_id,
|
||||
UNUSED_DEVICE,
|
||||
UNUSED_STREAM,
|
||||
UNUSED_BYTES,
|
||||
UNUSED_BYTES,
|
||||
UNUSED_OP,
|
||||
UNUSED_RANK);
|
||||
}
|
||||
|
||||
static GlooCache::value_type create(
|
||||
GlooCache& cache,
|
||||
const DataChannelGloo::Group& group,
|
||||
const std::string& store_prefix) {
|
||||
auto context = cache.createContext(group, store_prefix);
|
||||
return std::make_tuple(
|
||||
std::make_shared<::gloo::BarrierAllToAll>(context),
|
||||
nullptr,
|
||||
nullptr,
|
||||
std::make_shared<std::mutex>());
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace gloo_cache
|
||||
|
||||
} // namespace thd
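// GlooCache above memoizes fully-constructed Gloo algorithms (plus their
// staging buffers and a per-algorithm mutex) keyed by the tuple of collective
// type, group, device, stream, buffer sizes, reduce op and src/dst rank, so a
// collective with an identical signature is created only once. A generic
// sketch of that tuple-keyed memoization (not tied to Gloo; the hash combine
// is only illustrative of what MAKE_HASHABLE provides):
#include <cstddef>
#include <functional>
#include <string>
#include <tuple>
#include <unordered_map>

using Key = std::tuple<int /*collective*/, int /*group*/, std::size_t /*bytes*/>;

struct KeyHash {
  std::size_t operator()(const Key& k) const {
    std::size_t h = std::hash<int>()(std::get<0>(k));
    h ^= std::hash<int>()(std::get<1>(k)) + 0x9e3779b9 + (h << 6) + (h >> 2);
    h ^= std::hash<std::size_t>()(std::get<2>(k)) + 0x9e3779b9 + (h << 6) + (h >> 2);
    return h;
  }
};

std::unordered_map<Key, std::string, KeyHash> algorithm_cache;

const std::string& getOrCreate(const Key& key) {
  auto it = algorithm_cache.find(key);
  if (it == algorithm_cache.end())
    it = algorithm_cache.emplace(key, "expensive-to-build algorithm").first;  // built once
  return it->second;
}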
|
||||
|
|
@ -1,186 +0,0 @@
|
|||
#include <THD/base/data_channels/Store.hpp>
|
||||
#include <THD/base/ChannelUtils.hpp>
|
||||
|
||||
#include <poll.h>
|
||||
#include <unistd.h>
|
||||
#include <system_error>
|
||||
|
||||
namespace thd {
|
||||
|
||||
namespace {
|
||||
|
||||
enum class QueryType : std::uint8_t { SET, GET, WAIT, STOP_WAITING };
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
Store::StoreDeamon::StoreDeamon(int listen_socket)
|
||||
: _listen_socket(listen_socket), _keys_awaited(), _sockets() {
|
||||
_deamon = std::thread(&Store::StoreDeamon::deamon, this);
|
||||
}
|
||||
|
||||
Store::StoreDeamon::~StoreDeamon() {
|
||||
::close(_listen_socket);
|
||||
for (auto socket : _sockets) {
|
||||
if (socket != -1)
|
||||
::close(socket);
|
||||
}
|
||||
}
|
||||
|
||||
void Store::StoreDeamon::join() {
|
||||
_deamon.join();
|
||||
}
|
||||
|
||||
void Store::StoreDeamon::deamon() {
|
||||
std::vector<struct pollfd> fds;
|
||||
fds.push_back({.fd = _listen_socket, .events = POLLIN});
|
||||
|
||||
// receive the queries
|
||||
bool finished = false;
|
||||
while (!finished) {
|
||||
for (size_t i = 0; i < _sockets.size(); i++) {
|
||||
fds[i].revents = 0;
|
||||
}
|
||||
|
||||
SYSCHECK(::poll(fds.data(), fds.size(), -1));
|
||||
if (fds[0].revents != 0) {
|
||||
if (fds[0].revents ^ POLLIN)
|
||||
throw std::system_error(ECONNABORTED, std::system_category());
|
||||
|
||||
int sock_fd = std::get<0>(accept(_listen_socket));
|
||||
_sockets.push_back(sock_fd);
|
||||
_keys_awaited.push_back(0);
|
||||
fds.push_back({.fd = sock_fd, .events = POLLIN});
|
||||
}
|
||||
for (size_t rank = 0; rank < _sockets.size(); rank++) {
|
||||
if (fds[rank + 1].revents == 0)
|
||||
continue;
|
||||
|
||||
if (fds[rank + 1].revents ^ POLLIN)
|
||||
throw std::system_error(ECONNABORTED, std::system_category());
|
||||
|
||||
try {
|
||||
query(rank);
|
||||
} catch (...) {
|
||||
// There was an error when processing query. Probably an exception
|
||||
// occurred in recv/send, which would indicate that the socket on the other
|
||||
// side has been closed. If the closing was due to normal exit, then the
|
||||
// store should exit too. Otherwise, if it was a different exception,
|
||||
// other processes will get an exception once they try to use the store.
|
||||
finished = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* query communicates with the worker. The format
|
||||
* of the query is as follows:
|
||||
* type of query | size of arg1 | arg1 | size of arg2 | arg2 | ...
|
||||
* or, in the case of wait
|
||||
* type of query | number of args | size of arg1 | arg1 | ...
|
||||
*/
|
||||
void Store::StoreDeamon::query(rank_type rank) {
|
||||
int socket = _sockets[rank];
|
||||
QueryType qt;
|
||||
recv_bytes<QueryType>(socket, &qt, 1);
|
||||
if (qt == QueryType::SET) {
|
||||
std::string key = recv_string(socket);
|
||||
_store[key] = recv_vector<char>(socket);
|
||||
// On "set", wake up all of the processes that wait
|
||||
// for keys already in the store
|
||||
auto to_wake = _waiting.find(key);
|
||||
if (to_wake != _waiting.end()) {
|
||||
for (int proc : to_wake->second) {
|
||||
if (--_keys_awaited[proc] == 0)
|
||||
send_value<QueryType>(_sockets[proc], QueryType::STOP_WAITING);
|
||||
}
|
||||
_waiting.erase(to_wake);
|
||||
}
|
||||
} else if (qt == QueryType::GET) {
|
||||
std::string key = recv_string(socket);
|
||||
std::vector<char> data = _store.at(key);
|
||||
send_vector(socket, data);
|
||||
} else if (qt == QueryType::WAIT) {
|
||||
size_type nargs;
|
||||
recv_bytes<size_type>(socket, &nargs, 1);
|
||||
std::vector<std::string> keys(nargs);
|
||||
for (size_t i = 0; i < nargs; i++) {
|
||||
keys[i] = recv_string(socket);
|
||||
}
|
||||
if (checkAndUpdate(keys)) {
|
||||
send_value<QueryType>(socket, QueryType::STOP_WAITING);
|
||||
} else {
|
||||
for (auto& key : keys) {
|
||||
_waiting[key].push_back(rank);
|
||||
}
|
||||
_keys_awaited[rank] = keys.size();
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error("expected a query type");
|
||||
}
|
||||
}
|
||||
|
||||
bool Store::StoreDeamon::checkAndUpdate(std::vector<std::string>& keys) const {
|
||||
bool ret = true;
|
||||
for (auto it = keys.begin(); it != keys.end();) {
|
||||
if (_store.count(*it) == 0) {
|
||||
ret = false;
|
||||
it++;
|
||||
} else {
|
||||
it = keys.erase(it);
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
Store::Store(const std::string& addr, port_type port, int listen_socket)
|
||||
: _socket(-1),
|
||||
_store_addr(addr),
|
||||
_store_port(port),
|
||||
_store_thread(nullptr) {
|
||||
if (listen_socket != Store::CLIENT_ONLY) {
|
||||
_store_thread =
|
||||
std::unique_ptr<StoreDeamon>(new StoreDeamon(listen_socket));
|
||||
}
|
||||
|
||||
_socket = connect(_store_addr, _store_port);
|
||||
}
|
||||
|
||||
Store::~Store() {
|
||||
::close(_socket);
|
||||
|
||||
// The store daemon should end because the connection has been closed.
|
||||
if (_store_thread) {
|
||||
_store_thread->join();
|
||||
}
|
||||
}
|
||||
|
||||
void Store::set(const std::string& key, const std::vector<char>& data) {
|
||||
send_value<QueryType>(_socket, QueryType::SET);
|
||||
send_string(_socket, key, true);
|
||||
send_vector<char>(_socket, data);
|
||||
}
|
||||
|
||||
std::vector<char> Store::get(const std::string& key) {
|
||||
wait({key});
|
||||
send_value<QueryType>(_socket, QueryType::GET);
|
||||
send_string(_socket, key);
|
||||
return recv_vector<char>(_socket);
|
||||
}
|
||||
|
||||
void Store::wait(const std::vector<std::string>& keys) {
|
||||
send_value<QueryType>(_socket, QueryType::WAIT);
|
||||
size_type nkeys = keys.size();
|
||||
send_bytes<size_type>(_socket, &nkeys, 1, (nkeys > 0));
|
||||
for (size_t i = 0; i < nkeys; i++) {
|
||||
send_string(_socket, keys[i], (i != (nkeys - 1)));
|
||||
}
|
||||
// after sending the query, wait for a 'stop_waiting' response
|
||||
QueryType qr;
|
||||
recv_bytes<QueryType>(_socket, &qr, 1);
|
||||
if (qr != QueryType::STOP_WAITING)
|
||||
throw std::runtime_error("stop_waiting response expected");
|
||||
}
|
||||
|
||||
} // namespace thd
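// The query() comment above documents the store's wire format: a one-byte
// query type followed by length-prefixed arguments. A sketch of how a SET
// request could be framed into a byte buffer (the explicit enum values and the
// uint64_t length type are assumptions for illustration; THD's send_string /
// send_vector helpers perform the equivalent writes directly to the socket):
#include <cstdint>
#include <string>
#include <vector>

enum class QueryType : uint8_t { SET = 0, GET = 1, WAIT = 2, STOP_WAITING = 3 };

static void appendArg(std::vector<uint8_t>& buf, const std::string& arg) {
  uint64_t len = arg.size();
  const uint8_t* p = reinterpret_cast<const uint8_t*>(&len);
  buf.insert(buf.end(), p, p + sizeof(len));      // size of arg
  buf.insert(buf.end(), arg.begin(), arg.end());  // arg bytes
}

std::vector<uint8_t> frameSet(const std::string& key, const std::string& value) {
  std::vector<uint8_t> buf;
  buf.push_back(static_cast<uint8_t>(QueryType::SET));  // type of query
  appendArg(buf, key);                                  // size of arg1 | arg1
  appendArg(buf, value);                                // size of arg2 | arg2
  return buf;
}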
|
||||
|
|
@ -1,62 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include <THD/base/ChannelUtils.hpp>
|
||||
#include <gloo/rendezvous/store.h>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
namespace thd {
|
||||
|
||||
struct Store : public ::gloo::rendezvous::Store {
|
||||
private:
|
||||
struct StoreDeamon {
|
||||
StoreDeamon() = delete;
|
||||
StoreDeamon(int listen_socket);
|
||||
~StoreDeamon();
|
||||
|
||||
void join();
|
||||
|
||||
private:
|
||||
using store_type = std::unordered_map<std::string, std::vector<char>>;
|
||||
|
||||
void deamon();
|
||||
void query(rank_type rank);
|
||||
bool checkAndUpdate(std::vector<std::string>& keys) const;
|
||||
|
||||
int _listen_socket;
|
||||
|
||||
std::thread _deamon;
|
||||
store_type _store;
|
||||
std::unordered_map<std::string, std::vector<rank_type>> _waiting;
|
||||
std::vector<size_t> _keys_awaited;
|
||||
std::vector<int> _sockets;
|
||||
};
|
||||
|
||||
public:
|
||||
// A special value for listen_socket which doesn't launch the daemon
|
||||
static constexpr int CLIENT_ONLY = -1;
|
||||
|
||||
Store(
|
||||
const std::string& addr,
|
||||
port_type port,
|
||||
int listen_socket = CLIENT_ONLY);
|
||||
~Store();
|
||||
|
||||
void set(const std::string& key, const std::vector<char>& data) override;
|
||||
std::vector<char> get(const std::string& key) override;
|
||||
void wait(const std::vector<std::string>& keys) override;
|
||||
|
||||
private:
|
||||
int _listen_socket;
|
||||
int _socket;
|
||||
std::string _store_addr;
|
||||
port_type _store_port;
|
||||
std::unique_ptr<StoreDeamon>
|
||||
_store_thread; // it is initialised only in a selected process
|
||||
};
|
||||
|
||||
} // namespace thd
|
||||
|
|
@ -1,64 +0,0 @@
|
|||
#include <THD/base/init_methods/InitMethod.hpp>
|
||||
|
||||
#ifdef THD_INIT_EXTENSION_H
|
||||
#define INCF(F) INCF_(F)
|
||||
#define INCF_(F) #F
|
||||
#include INCF(THD_INIT_EXTENSION_H)
|
||||
#endif
|
||||
|
||||
namespace thd {
|
||||
namespace init {
|
||||
|
||||
InitMethod::Config initTCP(
|
||||
std::string argument,
|
||||
int world_size_r,
|
||||
std::string group_name,
|
||||
int rank);
|
||||
InitMethod::Config initFile(
|
||||
std::string argument,
|
||||
int world_size_r,
|
||||
std::string group_name,
|
||||
int rank);
|
||||
InitMethod::Config initEnv(
|
||||
std::string argument,
|
||||
int world_size_r,
|
||||
std::string group_name,
|
||||
int rank);
|
||||
|
||||
InitMethodFuncMap initMethods(
|
||||
{{"env://", ::thd::init::initEnv},
|
||||
{"file://", ::thd::init::initFile},
|
||||
{"tcp://", ::thd::init::initTCP}
|
||||
|
||||
#ifdef THD_INIT_EXTENSION_H
|
||||
,
|
||||
/**
|
||||
* Additional method pairs can be defined in THD_INIT_EXTENSION_H header
|
||||
* to extend the init methods
|
||||
*/
|
||||
THD_INIT_EXTENSION_METHODS
|
||||
#endif
|
||||
|
||||
});
|
||||
|
||||
} // namespace init
|
||||
|
||||
InitMethod::Config getInitConfig(
|
||||
std::string argument,
|
||||
int world_size,
|
||||
std::string group_name,
|
||||
int rank) {
|
||||
InitMethod::Config config;
|
||||
|
||||
for (auto& methodPair : init::initMethods) {
|
||||
auto initMethodPrefix = methodPair.first;
|
||||
auto initMethodFunc = methodPair.second;
|
||||
if (argument.find(initMethodPrefix) == 0) {
|
||||
config = initMethodFunc(argument, world_size, group_name, rank);
|
||||
}
|
||||
}
|
||||
config.validate();
|
||||
return config;
|
||||
}
|
||||
|
||||
} // namespace thd
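// getInitConfig() above selects an init method by matching the URL scheme
// prefix of the argument ("env://", "file://", "tcp://") against the
// registered handlers. A generic sketch of that dispatch pattern (the address
// below is only an example value):
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

int main() {
  std::unordered_map<std::string, std::function<std::string(const std::string&)>>
      handlers = {
          {"tcp://", [](const std::string& a) { return "tcp rendezvous at " + a.substr(6); }},
          {"file://", [](const std::string& a) { return "file rendezvous at " + a.substr(7); }},
          {"env://", [](const std::string&) { return std::string("environment rendezvous"); }},
      };

  std::string argument = "tcp://127.0.0.1:29500";
  for (const auto& handler : handlers) {
    if (argument.find(handler.first) == 0)          // prefix match, as in getInitConfig
      std::cout << handler.second(argument) << "\n";  // "tcp rendezvous at 127.0.0.1:29500"
  }
  return 0;
}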
|
||||
|
|
@ -1,82 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include <THD/base/ChannelUtils.hpp>
|
||||
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace thd {
|
||||
|
||||
struct InitMethod {
|
||||
struct Config {
|
||||
struct MasterConfig {
|
||||
int listen_socket;
|
||||
port_type listen_port;
|
||||
};
|
||||
|
||||
struct WorkerConfig {
|
||||
std::string master_addr;
|
||||
port_type master_port;
|
||||
};
|
||||
|
||||
Config() {
|
||||
rank = -1;
|
||||
world_size = 0;
|
||||
public_address = "";
|
||||
master.listen_socket = -1;
|
||||
master.listen_port = 0;
|
||||
worker.master_addr = "";
|
||||
worker.master_port = 0;
|
||||
}
|
||||
|
||||
rank_type rank;
|
||||
rank_type world_size;
|
||||
std::string public_address;
|
||||
MasterConfig master;
|
||||
WorkerConfig worker;
|
||||
|
||||
void validate() {
|
||||
if (world_size == 0)
|
||||
throw std::logic_error("world_size was not set in config");
|
||||
|
||||
if (rank >= world_size || rank == -1)
|
||||
throw std::logic_error("rank was not set in config");
|
||||
|
||||
if (public_address == "")
|
||||
throw std::logic_error("public_address was not set in config");
|
||||
|
||||
if (rank == 0) {
|
||||
if (master.listen_socket < 0)
|
||||
throw std::logic_error("master:listen_socket was not set in config");
|
||||
|
||||
if (master.listen_port <= 0)
|
||||
throw std::logic_error("master:listen_port was not set in config");
|
||||
} else {
|
||||
if (worker.master_addr == "")
|
||||
throw std::logic_error("worker:master_addr was not set in config");
|
||||
|
||||
if (worker.master_port <= 0)
|
||||
throw std::logic_error("worker:master_port was not set in config");
|
||||
}
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
namespace init {
|
||||
|
||||
using InitMethodFuncMap = std::unordered_map<
|
||||
std::string,
|
||||
std::function<
|
||||
::thd::InitMethod::Config(std::string, int, std::string, int)>>;
|
||||
|
||||
} // namespace init
|
||||
|
||||
InitMethod::Config getInitConfig(
|
||||
std::string argument,
|
||||
int world_size = -1,
|
||||
std::string group_name = "",
|
||||
int rank = -1);
|
||||
|
||||
} // namespace thd
|
||||
|
|
@ -1,86 +0,0 @@
|
|||
#include <THD/base/init_methods/InitMethod.hpp>
|
||||
#include <THD/base/init_methods/InitMethodUtils.hpp>
|
||||
|
||||
namespace thd {
|
||||
namespace init {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr char RANK_ENV[] = "RANK";
|
||||
constexpr char WORLD_SIZE_ENV[] = "WORLD_SIZE";
|
||||
constexpr char MASTER_PORT_ENV[] = "MASTER_PORT";
|
||||
constexpr char MASTER_ADDR_ENV[] = "MASTER_ADDR";
|
||||
|
||||
const char* mustGetEnv(const char* env) {
|
||||
const char* value = std::getenv(env);
|
||||
if (value == nullptr) {
|
||||
throw std::logic_error(
|
||||
std::string("") + "failed to read the " + env +
|
||||
" environmental variable; maybe you forgot to set it?");
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
std::tuple<std::string, port_type> loadWorkerEnv() {
|
||||
std::string str_port = mustGetEnv(MASTER_PORT_ENV);
|
||||
auto port = convertToPort(std::stoul(str_port));
|
||||
return std::make_tuple(mustGetEnv(MASTER_ADDR_ENV), port);
|
||||
}
|
||||
|
||||
rank_type maybeLoadEnv(
|
||||
const char* env_name,
|
||||
int value,
|
||||
std::string parameter_name) {
|
||||
const char* env_value_str = std::getenv(env_name);
|
||||
int env_value = value;
|
||||
if (env_value_str != nullptr)
|
||||
env_value = std::stol(env_value_str);
|
||||
if (value != -1 && env_value != value)
|
||||
throw std::runtime_error(
|
||||
parameter_name +
|
||||
" specified both as an "
|
||||
"environmental variable and to the initializer");
|
||||
if (env_value == -1)
|
||||
throw std::runtime_error(
|
||||
parameter_name +
|
||||
" is not set but it is required for "
|
||||
"env:// init method");
|
||||
|
||||
return convertToRank(env_value);
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
InitMethod::Config initEnv(
|
||||
std::string argument, /* unused */
|
||||
int world_size_r,
|
||||
std::string group_name,
|
||||
int rank) {
|
||||
InitMethod::Config config;
|
||||
|
||||
config.rank = maybeLoadEnv(RANK_ENV, rank, "rank");
|
||||
config.world_size = maybeLoadEnv(WORLD_SIZE_ENV, world_size_r, "world_size");
|
||||
|
||||
if (group_name != "") {
|
||||
throw std::runtime_error(
|
||||
"group_name is not supported in env:// init method");
|
||||
}
|
||||
|
||||
if (config.rank == 0) {
|
||||
config.master.listen_port =
|
||||
convertToPort(std::stoul(mustGetEnv(MASTER_PORT_ENV)));
|
||||
std::tie(config.master.listen_socket, std::ignore) =
|
||||
listen(config.master.listen_port);
|
||||
config.public_address =
|
||||
discoverWorkers(config.master.listen_socket, config.world_size);
|
||||
} else {
|
||||
std::tie(config.worker.master_addr, config.worker.master_port) =
|
||||
loadWorkerEnv();
|
||||
std::tie(std::ignore, config.public_address) =
|
||||
discoverMaster({config.worker.master_addr}, config.worker.master_port);
|
||||
}
|
||||
return config;
|
||||
}
|
||||
|
||||
} // namespace init
|
||||
} // namespace thd
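// initEnv() above resolves rank and world_size either from the RANK /
// WORLD_SIZE environment variables or from the values passed to the
// initializer, refusing conflicting settings. A stripped-down sketch of that
// precedence rule (resolve() is an illustrative helper, not THD code):
#include <cstdlib>
#include <stdexcept>
#include <string>

int resolve(const char* env_name, int explicit_value) {
  const char* env = std::getenv(env_name);
  int env_value = env ? std::stoi(env) : explicit_value;
  if (explicit_value != -1 && env && env_value != explicit_value)
    throw std::runtime_error(std::string(env_name) +
                             " set both in the environment and explicitly");
  if (env_value == -1)
    throw std::runtime_error(std::string(env_name) + " is required");
  return env_value;
}

// With `RANK=2 WORLD_SIZE=4` exported before launch:
//   resolve("RANK", -1)      -> 2
//   resolve("WORLD_SIZE", 4) -> 4
//   resolve("WORLD_SIZE", 8) -> throws (conflicting values)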
|
||||
|
|
@ -1,249 +0,0 @@
|
|||
#include <THD/base/init_methods/InitMethod.hpp>
|
||||
#include <THD/base/init_methods/InitMethodUtils.hpp>
|
||||
|
||||
#include <fcntl.h>
|
||||
#include <sys/stat.h>
|
||||
#include <unistd.h>
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <iterator>
|
||||
#include <system_error>
|
||||
#include <thread>
|
||||
|
||||
namespace thd {
|
||||
namespace init {
|
||||
|
||||
namespace {
|
||||
|
||||
void lockLoop(int fd, struct flock& oflock) {
|
||||
while (true) {
|
||||
int err = ::fcntl(fd, F_SETLKW, &oflock);
|
||||
if (err == 0)
|
||||
break;
|
||||
else if (errno == EINTR)
|
||||
continue;
|
||||
else
|
||||
throw std::system_error(errno, std::system_category());
|
||||
}
|
||||
}
|
||||
|
||||
void lockFile(int fd) {
|
||||
struct flock oflock;
|
||||
oflock.l_type = F_WRLCK; // write lock
|
||||
oflock.l_whence = SEEK_SET;
|
||||
oflock.l_start = 0;
|
||||
oflock.l_len = 0; // lock whole file
|
||||
lockLoop(fd, oflock);
|
||||
}
|
||||
|
||||
void unlockFile(int fd) {
|
||||
struct flock oflock;
|
||||
oflock.l_type = F_UNLCK; // unlock
|
||||
oflock.l_whence = SEEK_SET;
|
||||
oflock.l_start = 0;
|
||||
oflock.l_len = 0; // unlock whole file
|
||||
lockLoop(fd, oflock);
|
||||
}
|
||||
|
||||
// file_descriptor, number_of_lines_in_file
|
||||
std::pair<int, size_t> waitForGroup(
|
||||
std::string file_path,
|
||||
std::string group_name,
|
||||
std::fstream& file) {
|
||||
int fd;
|
||||
std::string content;
|
||||
struct stat fd_stat, path_stat;
|
||||
// Loop until the file is either empty, or filled with our group_name
|
||||
while (true) {
|
||||
// Loop until we have an open, locked and valid file
|
||||
while (true) {
|
||||
fd = ::open(file_path.c_str(), O_RDWR | O_CREAT, 0644);
|
||||
if (fd == -1) {
|
||||
throw std::system_error(
|
||||
fd,
|
||||
std::generic_category(),
|
||||
"cannot access '" + file_path + "' file");
|
||||
}
|
||||
lockFile(fd);
|
||||
|
||||
// This helps prevent a race in which, while we were waiting for the lock,
|
||||
// the file has been removed from the fs
|
||||
SYSCHECK(::fstat(fd, &fd_stat));
|
||||
int err = stat(file_path.c_str(), &path_stat);
|
||||
if (err == 0 && fd_stat.st_dev == path_stat.st_dev &&
|
||||
fd_stat.st_ino == path_stat.st_ino) {
|
||||
break;
|
||||
}
|
||||
::close(fd);
|
||||
}
|
||||
|
||||
file.close();
|
||||
file.open(file_path);
|
||||
content = {std::istreambuf_iterator<char>(file),
|
||||
std::istreambuf_iterator<char>()};
|
||||
|
||||
if (content.length() == 0 || content.find(group_name) == 0)
|
||||
break;
|
||||
|
||||
unlockFile(fd);
|
||||
::close(fd);
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(200));
|
||||
}
|
||||
|
||||
return {fd, std::count(content.begin(), content.end(), '\n')};
|
||||
}
|
||||
|
||||
size_t waitForData(int fd, std::fstream& file, rank_type world_size) {
|
||||
size_t lines = 0;
|
||||
// Wait until all processes have written their info
|
||||
while (lines < world_size) {
|
||||
unlockFile(fd);
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(200));
|
||||
lockFile(fd);
|
||||
|
||||
file.seekp(0, std::ios_base::beg);
|
||||
file.sync();
|
||||
std::string content = {std::istreambuf_iterator<char>(file),
|
||||
std::istreambuf_iterator<char>()};
|
||||
lines = std::count(content.begin(), content.end(), '\n');
|
||||
}
|
||||
|
||||
file.seekp(0, std::ios_base::beg);
|
||||
return lines;
|
||||
}
|
||||
|
||||
// master_port, master_addrs, ranks
|
||||
std::tuple<port_type, std::vector<std::string>, std::vector<int>> parseFile(
|
||||
std::fstream& file,
|
||||
rank_type world_size,
|
||||
std::string group_name) {
|
||||
port_type master_port;
|
||||
std::vector<std::string> master_addrs;
|
||||
std::vector<int> ranks(world_size);
|
||||
// Parse the file
|
||||
for (size_t i = 0; i < world_size; ++i) {
|
||||
std::string proc_group_name;
|
||||
size_t proc_addrs_count;
|
||||
int proc_rank;
|
||||
port_type proc_port;
|
||||
|
||||
file >> proc_group_name >> proc_rank >> proc_port >> proc_addrs_count;
|
||||
if (proc_group_name != group_name) {
|
||||
throw std::logic_error("proc_group_name != group_name");
|
||||
}
|
||||
|
||||
std::vector<std::string> proc_addrs(proc_addrs_count);
|
||||
for (auto& str : proc_addrs) {
|
||||
file >> str;
|
||||
}
|
||||
|
||||
ranks[i] = proc_rank;
|
||||
/*
|
||||
* Master data is found only when:
|
||||
* 1. proc_rank has been manually assigned as 0 (first condition)
|
||||
* 2. the process has no assigned rank and master data has not been set yet.
|
||||
*/
|
||||
if (proc_rank == 0 || (proc_rank == -1 && master_addrs.size() == 0)) {
|
||||
master_port = proc_port;
|
||||
master_addrs = std::move(proc_addrs);
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure there are no duplicates
|
||||
for (size_t i = 0; i < ranks.size(); ++i) {
|
||||
for (size_t j = i + 1; j < ranks.size(); ++j) {
|
||||
if (ranks[i] >= 0 && (ranks[i] == ranks[j]))
|
||||
throw std::logic_error("more than one node have assigned same rank");
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_tuple(master_port, master_addrs, ranks);
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
InitMethod::Config initFile(
|
||||
std::string argument,
|
||||
int world_size_r,
|
||||
std::string group_name,
|
||||
int assigned_rank) {
|
||||
group_name.append("#"); // To make sure it's not empty
|
||||
std::string file_path = argument.substr(7); // chop "file://"
|
||||
rank_type world_size;
|
||||
try {
|
||||
world_size = convertToRank(world_size_r);
|
||||
} catch (std::exception& e) {
|
||||
if (world_size_r == -1) {
|
||||
throw std::invalid_argument(
|
||||
"world_size is not set - it is required for "
|
||||
"`file://` init methods with this backend");
|
||||
}
|
||||
throw std::invalid_argument("invalid world_size");
|
||||
}
|
||||
|
||||
InitMethod::Config config;
|
||||
int fd;
|
||||
size_t order;
|
||||
std::fstream file;
|
||||
|
||||
std::tie(fd, order) = waitForGroup(file_path, group_name, file);
|
||||
// NOTE: the function returns a locked fd
|
||||
|
||||
int listen_socket;
|
||||
port_type port;
|
||||
std::tie(listen_socket, port) = listen();
|
||||
|
||||
// Append our information
|
||||
auto if_addrs = getInterfaceAddresses();
|
||||
file << group_name << ' ' << assigned_rank << ' ' << port << ' '
|
||||
<< if_addrs.size();
|
||||
for (auto addr_str : if_addrs) {
|
||||
file << ' ' << addr_str;
|
||||
}
|
||||
file << std::endl;
|
||||
|
||||
size_t lines = waitForData(fd, file, world_size);
|
||||
|
||||
port_type master_port = -1;
|
||||
std::vector<std::string> master_addrs;
|
||||
std::vector<int> ranks;
|
||||
std::tie(master_port, master_addrs, ranks) =
|
||||
parseFile(file, world_size, group_name);
|
||||
|
||||
config.rank = getRank(ranks, assigned_rank, order);
|
||||
|
||||
// Last process removes the file.
|
||||
file.seekp(0, std::ios_base::end);
|
||||
file << std::endl;
|
||||
lines++;
|
||||
if (lines == 2 * world_size) {
|
||||
::remove(file_path.c_str());
|
||||
}
|
||||
|
||||
file.close();
|
||||
unlockFile(fd);
|
||||
|
||||
config.world_size = world_size;
|
||||
if (config.rank == 0) {
|
||||
config.public_address = discoverWorkers(listen_socket, world_size);
|
||||
config.master = {
|
||||
.listen_socket = listen_socket,
|
||||
.listen_port = master_port,
|
||||
};
|
||||
} else {
|
||||
::close(listen_socket);
|
||||
|
||||
std::string master_address;
|
||||
std::tie(master_address, config.public_address) =
|
||||
discoverMaster(master_addrs, master_port);
|
||||
config.worker = {
|
||||
.master_addr = master_address,
|
||||
.master_port = master_port,
|
||||
};
|
||||
}
|
||||
|
||||
return config;
|
||||
}
|
||||
|
||||
} // namespace init
|
||||
} // namespace thd
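// The file-based rendezvous above appends one whitespace-separated record per
// process -- group name, assigned rank (-1 if unassigned), listen port,
// address count, then that many interface addresses -- and parseFile() reads
// them back. For a two-process group the file might contain (values are only
// an example):
//
//   mygroup# 0 49152 2 10.0.0.1 192.168.1.5
//   mygroup# -1 49153 1 10.0.0.2
//
// A minimal sketch of reading one such record:
#include <cstddef>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

int main() {
  std::istringstream record("mygroup# 0 49152 2 10.0.0.1 192.168.1.5");
  std::string group;
  int rank;
  unsigned short port;
  std::size_t addr_count;
  record >> group >> rank >> port >> addr_count;
  std::vector<std::string> addrs(addr_count);
  for (auto& addr : addrs)
    record >> addr;
  std::cout << group << " rank=" << rank << " port=" << port
            << " first_addr=" << addrs[0] << "\n";
  return 0;
}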
|
||||
|
|
@ -1,392 +0,0 @@
|
|||
#include <THD/base/init_methods/InitMethod.hpp>
|
||||
#include <THD/base/init_methods/InitMethodUtils.hpp>
|
||||
|
||||
#include <arpa/inet.h>
|
||||
#include <fcntl.h>
|
||||
#include <netdb.h>
|
||||
#include <netinet/in.h>
|
||||
#include <netinet/tcp.h>
|
||||
#include <sys/socket.h>
|
||||
#include <sys/types.h>
|
||||
#include <unistd.h>
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <chrono>
|
||||
#include <cstring>
|
||||
#include <iterator>
|
||||
#include <random>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <thread>
|
||||
|
||||
constexpr size_t num_rand_bytes = 32;
|
||||
constexpr size_t max_msg_length = 4000;
|
||||
|
||||
namespace thd {
|
||||
namespace init {
|
||||
namespace {
|
||||
|
||||
std::string getRandomString() {
|
||||
static constexpr char charset[] =
|
||||
"0123456789"
|
||||
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
"abcdefghijklmnopqrstuvwxyz";
|
||||
int fd;
|
||||
uint8_t rand_bytes[num_rand_bytes];
|
||||
ssize_t bytes_read;
|
||||
SYSCHECK(fd = open("/dev/urandom", O_RDONLY));
|
||||
SYSCHECK(bytes_read = read(fd, &rand_bytes, sizeof(rand_bytes)));
|
||||
if (bytes_read != sizeof(rand_bytes))
|
||||
throw std::runtime_error("failed to read from /dev/urandom");
|
||||
SYSCHECK(::close(fd));
|
||||
|
||||
std::string str;
|
||||
str.reserve(num_rand_bytes);
|
||||
for (uint8_t* byte = rand_bytes; byte != rand_bytes + num_rand_bytes;
|
||||
++byte) {
|
||||
str.push_back(charset[(*byte) % (sizeof(charset) - 1)]);
|
||||
}
|
||||
return str;
|
||||
}
|
||||
|
||||
struct MulticastMessage {
|
||||
std::string uid;
|
||||
std::string group_name;
|
||||
std::vector<std::string> addresses;
|
||||
port_type port;
|
||||
int rank;
|
||||
|
||||
MulticastMessage(std::string group_name, port_type port, int rank)
|
||||
: uid(getRandomString()),
|
||||
group_name(group_name),
|
||||
addresses(getInterfaceAddresses()),
|
||||
port(port),
|
||||
rank(rank) {}
|
||||
|
||||
MulticastMessage(std::string msg) {
|
||||
std::istringstream ss{msg};
|
||||
ss >> uid >> group_name >> port >> rank;
|
||||
addresses = {std::istream_iterator<std::string>(ss),
|
||||
std::istream_iterator<std::string>()};
|
||||
}
|
||||
|
||||
std::string pack() {
|
||||
std::ostringstream ss;
|
||||
ss << uid << ' ' << group_name << ' ' << port << ' ' << rank;
|
||||
for (const auto& address : addresses) {
|
||||
ss << ' ' << address;
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
};
|
||||
|
||||
bool isMulticastAddress(struct sockaddr* address) {
|
||||
if (address->sa_family == AF_INET) {
|
||||
struct sockaddr_in* address_ipv4 =
|
||||
reinterpret_cast<struct sockaddr_in*>(address);
|
||||
uint32_t host_addr = ntohl(address_ipv4->sin_addr.s_addr);
|
||||
return (host_addr & 0xF0000000) == 0xE0000000;
|
||||
} else if (address->sa_family == AF_INET6) {
|
||||
struct sockaddr_in6* address_ipv6 =
|
||||
reinterpret_cast<struct sockaddr_in6*>(address);
|
||||
auto& addr_bytes = address_ipv6->sin6_addr.s6_addr;
|
||||
// NOTE: address is in network byte order
|
||||
return addr_bytes[0] == 0xff;
|
||||
} else {
|
||||
throw std::invalid_argument("unsupported address family");
|
||||
}
|
||||
}
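// The IPv4 branch above checks the top four bits of the address: multicast is
// 224.0.0.0/4, i.e. 0xE0000000-0xEFFFFFFF in host byte order, so masking with
// 0xF0000000 must yield 0xE0000000; for IPv6 the multicast range is ff00::/8,
// hence the first-byte check. A few worked values (illustrative only):
//
//   224.0.0.1   -> 0xE0000001 & 0xF0000000 == 0xE0000000  -> multicast
//   239.255.0.1 -> 0xEFFF0001 & 0xF0000000 == 0xE0000000  -> multicast
//   192.168.1.5 -> 0xC0A80105 & 0xF0000000 == 0xC0000000  -> not multicast
static_assert((0xE0000001u & 0xF0000000u) == 0xE0000000u, "224.0.0.1 is multicast");
static_assert((0xC0A80105u & 0xF0000000u) == 0xC0000000u, "192.168.1.5 is unicast");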
|
||||
|
||||
int bindMulticastSocket(
|
||||
struct sockaddr* address,
|
||||
struct sockaddr_storage* sock_addr,
|
||||
int timeout_sec = 1,
|
||||
int ttl = 1) {
|
||||
struct timeval timeout = {.tv_sec = timeout_sec, .tv_usec = 0};
|
||||
|
||||
int socket, optval;
|
||||
SYSCHECK(socket = ::socket(address->sa_family, SOCK_DGRAM, 0));
|
||||
optval = 1;
|
||||
SYSCHECK(
|
||||
::setsockopt(socket, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(int)));
|
||||
|
||||
if (address->sa_family == AF_INET) {
|
||||
struct sockaddr_in* sock_addr_ipv4 =
|
||||
reinterpret_cast<struct sockaddr_in*>(sock_addr);
|
||||
struct sockaddr_in* address_ipv4 =
|
||||
reinterpret_cast<struct sockaddr_in*>(address);
|
||||
std::memset(sock_addr_ipv4, 0, sizeof(*sock_addr_ipv4));
|
||||
sock_addr_ipv4->sin_family = address->sa_family;
|
||||
sock_addr_ipv4->sin_addr.s_addr = INADDR_ANY;
|
||||
sock_addr_ipv4->sin_port = address_ipv4->sin_port;
|
||||
|
||||
SYSCHECK(::bind(
|
||||
socket,
|
||||
reinterpret_cast<struct sockaddr*>(sock_addr_ipv4),
|
||||
sizeof(*sock_addr_ipv4)));
|
||||
SYSCHECK(::setsockopt(
|
||||
socket, SOL_SOCKET, SO_RCVTIMEO, (char*)&timeout, sizeof(timeout)));
|
||||
|
||||
struct ip_mreq mreq;
|
||||
mreq.imr_multiaddr = address_ipv4->sin_addr;
|
||||
mreq.imr_interface.s_addr = htonl(INADDR_ANY);
|
||||
SYSCHECK(
|
||||
::setsockopt(socket, IPPROTO_IP, IP_MULTICAST_TTL, &ttl, sizeof(ttl)));
|
||||
SYSCHECK(::setsockopt(
|
||||
socket, IPPROTO_IP, IP_ADD_MEMBERSHIP, &mreq, sizeof(mreq)));
|
||||
|
||||
sock_addr_ipv4->sin_addr = address_ipv4->sin_addr;
|
||||
} else if (address->sa_family == AF_INET6) {
|
||||
struct sockaddr_in6* sock_addr_ipv6 =
|
||||
reinterpret_cast<struct sockaddr_in6*>(sock_addr);
|
||||
struct sockaddr_in6* address_ipv6 =
|
||||
reinterpret_cast<struct sockaddr_in6*>(address);
|
||||
std::memset(sock_addr_ipv6, 0, sizeof(*sock_addr_ipv6));
|
||||
sock_addr_ipv6->sin6_family = address->sa_family;
|
||||
sock_addr_ipv6->sin6_addr = in6addr_any;
|
||||
sock_addr_ipv6->sin6_port = address_ipv6->sin6_port;
|
||||
|
||||
SYSCHECK(::bind(
|
||||
socket,
|
||||
reinterpret_cast<struct sockaddr*>(sock_addr_ipv6),
|
||||
sizeof(*sock_addr_ipv6)));
|
||||
SYSCHECK(::setsockopt(
|
||||
socket, SOL_SOCKET, SO_RCVTIMEO, (char*)&timeout, sizeof(timeout)));
|
||||
|
||||
struct ipv6_mreq mreq;
|
||||
mreq.ipv6mr_multiaddr = address_ipv6->sin6_addr;
|
||||
mreq.ipv6mr_interface = 0;
|
||||
SYSCHECK(::setsockopt(
|
||||
socket, IPPROTO_IPV6, IPV6_MULTICAST_HOPS, &ttl, sizeof(ttl)));
|
||||
SYSCHECK(::setsockopt(
|
||||
socket, IPPROTO_IPV6, IPV6_JOIN_GROUP, &mreq, sizeof(mreq)));
|
||||
|
||||
sock_addr_ipv6->sin6_addr = address_ipv6->sin6_addr;
|
||||
}
|
||||
|
||||
return socket;
|
||||
}
|
||||
|
||||
// messages
|
||||
std::vector<MulticastMessage> getMessages(
|
||||
struct sockaddr* addr,
|
||||
rank_type world_size,
|
||||
std::string group_name,
|
||||
std::string packed_msg) {
|
||||
struct sockaddr_storage sock_addr;
|
||||
int socket = bindMulticastSocket(addr, &sock_addr);
|
||||
// NOTE: Multicast membership is dropped on close
|
||||
ResourceGuard socket_guard([socket]() { ::close(socket); });
|
||||
|
||||
std::set<std::string> msgs = {packed_msg};
|
||||
|
||||
char recv_message[max_msg_length];
|
||||
if (packed_msg.length() + 1 > max_msg_length) {
|
||||
throw std::logic_error("message too long for multicast init");
|
||||
}
|
||||
|
||||
auto broadcast = [socket, &sock_addr, &packed_msg]() {
|
||||
SYSCHECK(::sendto(
|
||||
socket,
|
||||
packed_msg.c_str(),
|
||||
packed_msg.size() + 1,
|
||||
0,
|
||||
reinterpret_cast<struct sockaddr*>(&sock_addr),
|
||||
sock_addr.ss_family == AF_INET ? sizeof(struct sockaddr_in)
|
||||
: sizeof(struct sockaddr_in6)));
|
||||
};
|
||||
|
||||
broadcast();
|
||||
|
||||
// Wait for messages from all processes
|
||||
while (msgs.size() < world_size) {
|
||||
try {
|
||||
SYSCHECK(::recv(socket, recv_message, sizeof(recv_message), 0));
|
||||
std::string recv_message_str(recv_message);
|
||||
|
||||
if (recv_message_str == packed_msg)
|
||||
continue; // ignore multicast loopback
|
||||
|
||||
// We should ignore messages coming from a different group
|
||||
auto recv_msg = MulticastMessage(recv_message_str);
|
||||
if (recv_msg.group_name != group_name) {
|
||||
continue;
|
||||
}
|
||||
|
||||
msgs.insert(
|
||||
recv_message_str); // set will automatically deduplicate messages
|
||||
} catch (const std::system_error& e) {
|
||||
// Check if this was really a timeout from `recv` or a different error.
|
||||
if (errno != EAGAIN && errno != EWOULDBLOCK)
|
||||
throw;
|
||||
}
|
||||
|
||||
broadcast();
|
||||
}
|
||||
|
||||
// Just to decrease the probability of packet loss deadlocking the system
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(50));
|
||||
broadcast();
|
||||
|
||||
std::vector<MulticastMessage> unpacked_msgs;
|
||||
for (auto& msg : msgs) {
|
||||
unpacked_msgs.emplace_back(msg);
|
||||
}
|
||||
|
||||
return unpacked_msgs;
|
||||
}
|
||||
|
||||
InitMethod::Config initTCPMaster(
|
||||
std::string address,
|
||||
std::string str_port,
|
||||
rank_type world_size,
|
||||
int assigned_rank) {
|
||||
InitMethod::Config config;
|
||||
if (assigned_rank == -1) {
|
||||
throw std::invalid_argument(
|
||||
"tcp:// method with non-multicast addresses "
|
||||
"requires manual rank assignment");
|
||||
}
|
||||
|
||||
config.rank = convertToRank(assigned_rank);
|
||||
config.world_size = world_size;
|
||||
auto port = convertToPort(std::stoul(str_port));
|
||||
if (config.rank == 0) {
|
||||
config.master.listen_port = port;
|
||||
std::tie(config.master.listen_socket, std::ignore) = listen(port);
|
||||
config.public_address =
|
||||
discoverWorkers(config.master.listen_socket, world_size);
|
||||
} else {
|
||||
config.worker.master_addr = address;
|
||||
config.worker.master_port = port;
|
||||
std::tie(std::ignore, config.public_address) =
|
||||
discoverMaster({address}, port);
|
||||
}
|
||||
|
||||
return config;
|
||||
}
|
||||
|
||||
InitMethod::Config initTCPMulticast(
|
||||
std::string group_name,
|
||||
rank_type world_size,
|
||||
int assigned_rank,
|
||||
struct sockaddr* addr) {
|
||||
InitMethod::Config config;
|
||||
|
||||
int listen_socket;
|
||||
port_type listen_port;
|
||||
std::tie(listen_socket, listen_port) = listen();
|
||||
ResourceGuard listen_socket_guard(
|
||||
[listen_socket]() { ::close(listen_socket); });
|
||||
|
||||
MulticastMessage msg{group_name, listen_port, assigned_rank};
|
||||
std::string packed_msg = msg.pack();
|
||||
|
||||
std::vector<MulticastMessage> msgs =
|
||||
getMessages(addr, world_size, group_name, packed_msg);
|
||||
|
||||
std::vector<MulticastMessage*> sorted_msgs(msgs.size());
|
||||
|
||||
// Pre-fill sorted_msgs with processes that had their ranks assigned manually
|
||||
for (auto& msg : msgs) {
|
||||
if (msg.rank >= 0) {
|
||||
if (sorted_msgs[msg.rank] != nullptr)
|
||||
throw std::logic_error("more than one node has been assigned the same rank");
|
||||
sorted_msgs[msg.rank] = &msg;
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: msgs are already sorted lexicographically, so we can greedily
// insert them into free slots
|
||||
size_t free_pos = 0;
|
||||
for (auto& msg : msgs) {
|
||||
if (msg.rank >= 0)
|
||||
continue; // These were sorted in the previous loop
|
||||
while (sorted_msgs[free_pos] != nullptr)
|
||||
free_pos++;
|
||||
sorted_msgs[free_pos] = &msg;
|
||||
}
|
||||
|
||||
auto& master_msg = *sorted_msgs[0];
|
||||
for (size_t rank = 0; rank < sorted_msgs.size(); ++rank) {
|
||||
if (packed_msg == sorted_msgs[rank]->pack()) {
|
||||
config.rank = rank;
|
||||
config.world_size = world_size;
|
||||
if (config.rank == 0) {
|
||||
listen_socket_guard.release();
|
||||
config.master = {
|
||||
.listen_socket = listen_socket,
|
||||
.listen_port = master_msg.port,
|
||||
};
|
||||
|
||||
config.public_address = discoverWorkers(listen_socket, world_size);
|
||||
} else {
|
||||
std::string master_address;
|
||||
std::tie(master_address, config.public_address) =
|
||||
discoverMaster(master_msg.addresses, master_msg.port);
|
||||
config.worker = {
|
||||
.master_addr = master_address,
|
||||
.master_port = master_msg.port,
|
||||
};
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return config;
|
||||
}
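// Editor's sketch, not part of the original file: the rank-assignment logic
// in initTCPMulticast above boils down to "keep manually assigned ranks,
// then fill the remaining slots in lexicographic message order", which every
// process can compute identically from the same sorted, deduplicated set.
// A minimal standalone version of the same idea, with hypothetical names:
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <vector>

std::map<std::string, size_t> assignRanksExample(
    const std::set<std::string>& keys, // sorted and deduplicated, like `msgs`
    const std::map<std::string, int>& manual_rank) { // -1 means unassigned
  std::vector<std::string> slots(keys.size());

  // Pass 1: honour manual assignments (mirrors the sorted_msgs pre-fill).
  for (const auto& key : keys) {
    if (manual_rank.at(key) >= 0) {
      assert(slots[manual_rank.at(key)].empty());
      slots[manual_rank.at(key)] = key;
    }
  }

  // Pass 2: greedily place the rest into the first free slots.
  size_t free_pos = 0;
  for (const auto& key : keys) {
    if (manual_rank.at(key) >= 0)
      continue;
    while (!slots[free_pos].empty())
      ++free_pos;
    slots[free_pos] = key;
  }

  std::map<std::string, size_t> ranks;
  for (size_t rank = 0; rank < slots.size(); ++rank)
    ranks[slots[rank]] = rank;
  return ranks;
}
// For example, keys {"a", "b", "c"} with "c" pinned to rank 0 yield
// c -> 0, a -> 1, b -> 2 on every process.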
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
InitMethod::Config initTCP(
|
||||
std::string argument,
|
||||
int world_size_r,
|
||||
std::string group_name,
|
||||
int rank) {
|
||||
group_name.append("#"); // To make sure it's not empty
|
||||
argument.erase(0, 6); // chop "tcp://"
|
||||
rank_type world_size;
|
||||
try {
|
||||
world_size = convertToRank(world_size_r);
|
||||
} catch (std::exception& e) {
|
||||
if (world_size_r == -1) {
|
||||
throw std::invalid_argument(
|
||||
"world_size is not set - it is required for "
|
||||
"`tcp://` init methods with this backend");
|
||||
}
|
||||
throw std::invalid_argument("invalid world_size");
|
||||
}
|
||||
|
||||
// Parse arguments
|
||||
std::string address, str_port;
|
||||
std::tie(address, str_port) = splitAddress(argument);
|
||||
|
||||
// Resolve addr and select init method
|
||||
struct addrinfo hints = {0};
|
||||
hints.ai_family = AF_UNSPEC;
|
||||
struct addrinfo* res;
|
||||
if (::getaddrinfo(address.c_str(), str_port.c_str(), &hints, &res)) {
|
||||
throw std::invalid_argument("invalid init address");
|
||||
}
|
||||
ResourceGuard res_guard([res]() { ::freeaddrinfo(res); });
|
||||
|
||||
for (struct addrinfo* head = res; head != NULL; head = head->ai_next) {
|
||||
if (head->ai_family != AF_INET && head->ai_family != AF_INET6)
|
||||
continue;
|
||||
try {
|
||||
if (isMulticastAddress(head->ai_addr)) {
|
||||
return initTCPMulticast(group_name, world_size, rank, head->ai_addr);
|
||||
} else {
|
||||
return initTCPMaster(address, str_port, world_size, rank);
|
||||
}
|
||||
} catch (std::exception& e) {
|
||||
if (!head->ai_next)
|
||||
throw;
|
||||
}
|
||||
}
|
||||
throw std::runtime_error("failed to initialize THD using given address");
|
||||
}
|
||||
|
||||
} // namespace init
|
||||
} // namespace thd
|
||||
|
|
@ -1,120 +0,0 @@
|
|||
#include <THD/base/init_methods/InitMethodUtils.hpp>
|
||||
|
||||
#include <ifaddrs.h>
|
||||
#include <net/if.h>
|
||||
#include <sys/ioctl.h>
|
||||
#include <sys/types.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include <tuple>
|
||||
|
||||
namespace thd {
|
||||
|
||||
namespace {
|
||||
|
||||
void sendPeerName(int socket) {
|
||||
struct sockaddr_storage master_addr;
|
||||
socklen_t master_addr_len = sizeof(master_addr);
|
||||
SYSCHECK(getpeername(
|
||||
socket,
|
||||
reinterpret_cast<struct sockaddr*>(&master_addr),
|
||||
&master_addr_len));
|
||||
|
||||
std::string addr_str =
|
||||
sockaddrToString(reinterpret_cast<struct sockaddr*>(&master_addr));
|
||||
send_string(socket, addr_str);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::vector<std::string> getInterfaceAddresses() {
|
||||
struct ifaddrs* ifa;
|
||||
SYSCHECK(getifaddrs(&ifa));
|
||||
ResourceGuard ifaddrs_guard([ifa]() { ::freeifaddrs(ifa); });
|
||||
|
||||
std::vector<std::string> addresses;
|
||||
|
||||
while (ifa != nullptr) {
|
||||
struct sockaddr* addr = ifa->ifa_addr;
|
||||
if (addr) {
|
||||
bool is_loopback = ifa->ifa_flags & IFF_LOOPBACK;
|
||||
bool is_ip = addr->sa_family == AF_INET || addr->sa_family == AF_INET6;
|
||||
if (is_ip && !is_loopback) {
|
||||
addresses.push_back(sockaddrToString(addr));
|
||||
}
|
||||
}
|
||||
ifa = ifa->ifa_next;
|
||||
}
|
||||
|
||||
return addresses;
|
||||
}
|
||||
|
||||
std::string discoverWorkers(int listen_socket, rank_type world_size) {
|
||||
// accept connections from workers so they can know our address
|
||||
std::vector<int> sockets(world_size - 1);
|
||||
for (rank_type i = 0; i < world_size - 1; ++i) {
|
||||
std::tie(sockets[i], std::ignore) = accept(listen_socket);
|
||||
}
|
||||
|
||||
std::string public_addr;
|
||||
for (auto socket : sockets) {
|
||||
sendPeerName(socket);
|
||||
public_addr = recv_string(socket);
|
||||
::close(socket);
|
||||
}
|
||||
return public_addr;
|
||||
}
|
||||
|
||||
std::pair<std::string, std::string> discoverMaster(
|
||||
std::vector<std::string> addresses,
|
||||
port_type port) {
|
||||
// try to connect to address via any of the addresses
|
||||
std::string master_address = "";
|
||||
int socket;
|
||||
for (const auto& address : addresses) {
|
||||
try {
|
||||
socket = connect(address, port, true, 2000);
|
||||
master_address = address;
|
||||
break;
|
||||
} catch (...) {
|
||||
} // when connection fails just try different address
|
||||
}
|
||||
|
||||
if (master_address == "") {
|
||||
throw std::runtime_error(
|
||||
"could not establish connection with other processes");
|
||||
}
|
||||
ResourceGuard socket_guard([socket]() { ::close(socket); });
|
||||
sendPeerName(socket);
|
||||
std::string my_address = recv_string(socket);
|
||||
|
||||
return std::make_pair(master_address, my_address);
|
||||
}
|
||||
|
||||
rank_type getRank(
|
||||
const std::vector<int>& ranks,
|
||||
int assigned_rank,
|
||||
size_t order) {
|
||||
if (assigned_rank >= 0) {
|
||||
return assigned_rank;
|
||||
} else {
|
||||
std::vector<bool> taken_ranks(ranks.size());
|
||||
for (auto rank : ranks) {
|
||||
if (rank >= 0)
|
||||
taken_ranks[rank] = true;
|
||||
}
|
||||
|
||||
auto unassigned = std::count(ranks.begin(), ranks.begin() + order, -1) + 1;
|
||||
rank_type rank = 0;
|
||||
while (true) {
|
||||
if (!taken_ranks[rank])
|
||||
unassigned--;
|
||||
if (unassigned == 0)
|
||||
break;
|
||||
rank++;
|
||||
}
|
||||
|
||||
return rank;
|
||||
}
|
||||
}
|
||||
} // namespace thd
|
||||
|
|
@ -1,24 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include <THD/base/ChannelUtils.hpp>
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace thd {
|
||||
|
||||
std::vector<std::string> getInterfaceAddresses();
|
||||
|
||||
std::string discoverWorkers(int listen_socket, rank_type world_size);
|
||||
|
||||
// pair of master_address, my_address
|
||||
std::pair<std::string, std::string> discoverMaster(
|
||||
std::vector<std::string> addresses,
|
||||
port_type port);
|
||||
|
||||
// Helper that gets the rank based on the input order
|
||||
rank_type getRank(
|
||||
const std::vector<int>& ranks,
|
||||
int assigned_rank,
|
||||
size_t order);
|
||||
} // namespace thd
|
||||
|
|
@ -1,192 +0,0 @@
|
|||
import argparse
|
||||
import os
|
||||
from timeit import default_timer as timer
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
def print_header(title):
    print(title)
    print("{:>8}\t{:>5}\t{:<{num_tensors_width}}\t{:>11}\t{:>11}".
          format("MB/s", "MB", "#", "s", "ms/op",
                 num_tensors_width=MAX_NUM_TENSORS))


def print_stats(bytes, num_tensors, time):
    print("{:>8.3f}\t{:>5.1f}\t{:<{num_tensors_width}}\t{:>11.3f}\t{:>11.3f}".
          format(bytes * num_tensors / (2**20 * time),
                 bytes / 2**20,
                 num_tensors,
                 time,
                 1000 * time / num_tensors,
                 num_tensors_width=MAX_NUM_TENSORS))
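# Editor's note, not part of the original file: as a worked example of the
# columns printed above, 1000 tensors of 2**22 bytes (4 MB) sent in 2.0 s
# reports 4 * 1000 / 2.0 = 2000 MB/s, 4.0 MB, 1000, 2.000 s and 2.000 ms/op.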
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description='Benchmark torch.distributed.')
|
||||
parser.add_argument('--max-bytes', dest='max_bytes', action='store', default=28,
|
||||
type=int,
|
||||
help='set the inclusive upper limit for tensor size; ' +
|
||||
'default: 22 (2**22 = 4 MB)')
|
||||
parser.add_argument('--max-num-tensors', dest='max_num_tensors', action='store',
|
||||
default=3, type=int,
|
||||
help='set the inclusive upper limit for the number of ' +
|
||||
'tensors to be sent during one test run; ' +
|
||||
'default: 3 (10**3 = 1000)')
|
||||
parser.add_argument('--min-bytes', dest='min_bytes', action='store', default=19,
|
||||
type=int,
|
||||
help='set the inclusive lower limit for tensor size; ' +
|
||||
'default: 19 (2**19 = 512 KB)')
|
||||
parser.add_argument('--min-num-tensors', dest='min_num_tensors', action='store',
|
||||
default=2, type=int,
|
||||
help='set the inclusive lower limit for the number of ' +
|
||||
'tensors to be sent during one test run; ' +
|
||||
'default: 2 (10**2 = 100)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
MIN_NUM_TENSORS = args.min_num_tensors
|
||||
MIN_BYTES = args.min_bytes
|
||||
MAX_NUM_TENSORS = args.max_num_tensors + 1
|
||||
MAX_BYTES = args.max_bytes + 1
|
||||
|
||||
dist.init_process_group(backend=os.environ['BACKEND'])
|
||||
|
||||
rank = dist.get_rank()
|
||||
dist.barrier()
|
||||
|
||||
if rank == 0:
|
||||
print_header("broadcast")
|
||||
for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
|
||||
tensor = torch.ByteTensor(bytes).fill_(42)
|
||||
for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
|
||||
start = timer()
|
||||
for i in range(0, num_tensors):
|
||||
dist.broadcast(tensor, 0)
|
||||
end = timer()
|
||||
print_stats(bytes, num_tensors, end - start)
|
||||
print()
|
||||
else:
|
||||
for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
|
||||
tensor = torch.ByteTensor(bytes)
|
||||
for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
|
||||
for i in range(0, num_tensors):
|
||||
dist.broadcast(tensor, 0)
|
||||
dist.barrier()
|
||||
|
||||
if rank == 0:
|
||||
print_header("send from 0 to 1")
|
||||
for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
|
||||
tensor = torch.ByteTensor(bytes).fill_(42)
|
||||
for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
|
||||
start = timer()
|
||||
for i in range(0, num_tensors):
|
||||
dist.send(tensor, 1)
|
||||
end = timer()
|
||||
print_stats(bytes, num_tensors, end - start)
|
||||
print()
|
||||
elif rank == 1:
|
||||
for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
|
||||
tensor = torch.ByteTensor(bytes).fill_(42)
|
||||
for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
|
||||
for i in range(0, num_tensors):
|
||||
dist.recv(tensor, 0)
|
||||
dist.barrier()
|
||||
|
||||
if rank == 0:
|
||||
print_header("reduce")
|
||||
for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
|
||||
tensor = torch.ByteTensor(bytes).fill_(42)
|
||||
for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
|
||||
start = timer()
|
||||
for i in range(0, num_tensors):
|
||||
dist.reduce(tensor, 0)
|
||||
end = timer()
|
||||
print_stats(bytes, num_tensors, end - start)
|
||||
print()
|
||||
else:
|
||||
for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
|
||||
tensor = torch.ByteTensor(bytes).fill_(42)
|
||||
for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
|
||||
for i in range(0, num_tensors):
|
||||
dist.reduce(tensor, 0)
|
||||
dist.barrier()
|
||||
|
||||
if rank == 0:
|
||||
print_header("all reduce")
|
||||
for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
|
||||
tensor = torch.ByteTensor(bytes).fill_(42)
|
||||
for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
|
||||
start = timer()
|
||||
for i in range(0, num_tensors):
|
||||
dist.all_reduce(tensor)
|
||||
end = timer()
|
||||
print_stats(bytes, num_tensors, end - start)
|
||||
print()
|
||||
else:
|
||||
for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
|
||||
tensor = torch.ByteTensor(bytes).fill_(42)
|
||||
for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
|
||||
for i in range(0, num_tensors):
|
||||
dist.all_reduce(tensor)
|
||||
dist.barrier()
|
||||
|
||||
if rank == 0:
|
||||
print_header("scatter")
|
||||
for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
|
||||
tensor = torch.ByteTensor(bytes).fill_(42)
|
||||
tensors = [tensor for n in range(0, dist.get_world_size())]
|
||||
for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
|
||||
start = timer()
|
||||
for i in range(0, num_tensors):
|
||||
dist.scatter(tensor, scatter_list=tensors)
|
||||
end = timer()
|
||||
print_stats(bytes, num_tensors, end - start)
|
||||
print()
|
||||
else:
|
||||
for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
|
||||
tensor = torch.ByteTensor(bytes).fill_(42)
|
||||
for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
|
||||
for i in range(0, num_tensors):
|
||||
dist.scatter(tensor, src=0)
|
||||
dist.barrier()
|
||||
|
||||
if rank == 0:
|
||||
print_header("gather")
|
||||
for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
|
||||
tensor = torch.ByteTensor(bytes).fill_(42)
|
||||
tensors = [tensor for n in range(0, dist.get_world_size())]
|
||||
for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
|
||||
start = timer()
|
||||
for i in range(0, num_tensors):
|
||||
dist.gather(tensor, gather_list=tensors)
|
||||
end = timer()
|
||||
print_stats(bytes, num_tensors, end - start)
|
||||
print()
|
||||
else:
|
||||
for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
|
||||
tensor = torch.ByteTensor(bytes).fill_(42)
|
||||
for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
|
||||
for i in range(0, num_tensors):
|
||||
dist.gather(tensor, dst=0)
|
||||
dist.barrier()
|
||||
|
||||
if rank == 0:
|
||||
print_header("all gather")
|
||||
for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
|
||||
tensor = torch.ByteTensor(bytes).fill_(42)
|
||||
tensors = [tensor for n in range(0, dist.get_world_size())]
|
||||
for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
|
||||
start = timer()
|
||||
for i in range(0, num_tensors):
|
||||
dist.all_gather(tensors, tensor)
|
||||
end = timer()
|
||||
print_stats(bytes, num_tensors, end - start)
|
||||
print()
|
||||
else:
|
||||
for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
|
||||
tensor = torch.ByteTensor(bytes).fill_(42)
|
||||
tensors = [tensor for n in range(0, dist.get_world_size())]
|
||||
for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
|
||||
for i in range(0, num_tensors):
|
||||
dist.all_gather(tensors, tensor)
|
||||
dist.barrier()
|
||||
|
|
@ -1,163 +0,0 @@
|
|||
#! /bin/sh
|
||||
|
||||
set -eu
|
||||
|
||||
BACKEND=mpi
|
||||
engine="$PWD/benchmark.py"
|
||||
environment=/dev/null
|
||||
master_hostname=localhost
|
||||
output_file=/dev/stdout
|
||||
hosts=localhost
|
||||
MASTER_PORT=29500
|
||||
MASTER_ADDR="$master_hostname:$MASTER_PORT"
|
||||
WORLD_SIZE=2
|
||||
|
||||
errxit() {
|
||||
printf "%s\n" "$*" 1>&2
|
||||
exit 1
|
||||
}
|
||||
|
||||
usage() {
|
||||
cat <<-EOF
|
||||
Usage: ./run_benchmark [ OPTIONS ]
|
||||
|
||||
Optional arguments:
|
||||
-., --env FILE
|
||||
Set the path to a file to source. When using the MPI backend, the file
|
||||
will be sourced before running 'mpirun'. In case of the TCP backend
|
||||
the file will be sourced on every host after establishing a successful
|
||||
SSH connection. Default: '/dev/null'.
|
||||
|
||||
-b, --backend BACKEND
|
||||
Set the backend to benchmark. Default: 'mpi'.
|
||||
|
||||
-e, --engine ENGINE
|
||||
Set the path to the benchmarking script to run. Use absolute paths if
|
||||
you'll be using the tcp backend. Default: '\$PWD/benchmark.py'.
|
||||
|
||||
--help
|
||||
Show this help and exit successfully.
|
||||
|
||||
-h, --hosts HOSTS
|
||||
Set the list of hosts to run the benchmark on. Format: 'host1,host2'.
|
||||
Default: 'localhost'.
|
||||
|
||||
--max-bytes MAX_BYTES
|
||||
Set the inclusive upper limit for tensor size.
|
||||
Default: 22 (2**22 = 4 MB).
|
||||
|
||||
--max-num-tensors MAX_NUM_TENSORS
|
||||
Set the inclusive upper limit for the number of tensors to be sent
|
||||
during one test run. Default: 3 (10**3 = 1000).
|
||||
|
||||
--min-bytes MIN_BYTES
|
||||
Set the inclusive lower limit for tensor size.
|
||||
Default: 19 (2**19 = 512 KB).
|
||||
|
||||
--min-num-tensors MIN_NUM_TENSORS
|
||||
Set the inclusive lower limit for the number of tensors to be sent
|
||||
during one test run. Default: 2 (10**2 = 100).
|
||||
|
||||
-n, --name NAME
|
||||
Set the ip address/host name of the master node. Default: 'localhost'.
|
||||
|
||||
-o, --output FILE
|
||||
Set the path to the output file where the master host will append
|
||||
benchmark results. Default: '/dev/stdout'.
|
||||
|
||||
-p, --port PORT
|
||||
Set the port number master is listening on. Default: '29500'.
|
||||
|
||||
-s, --world-size WORLD_SIZE
|
||||
Set the number of processes to be spawned. Default: '2'.
|
||||
EOF
|
||||
}
|
||||
|
||||
while [ $# -gt 0 ]; do
|
||||
case "$1" in
|
||||
'-.'|--env)
|
||||
environment="$2"
|
||||
shift 2
|
||||
;;
|
||||
--backend|-b)
|
||||
BACKEND="$2"
|
||||
shift 2
|
||||
;;
|
||||
--engine|-e)
|
||||
engine="$2"
|
||||
shift 2
|
||||
;;
|
||||
--help)
|
||||
usage
|
||||
exit 0
|
||||
;;
|
||||
--hosts|-h)
|
||||
hosts="$2"
|
||||
shift 2
|
||||
;;
|
||||
--port|-p)
|
||||
MASTER_PORT="$2"
|
||||
shift 2
|
||||
;;
|
||||
--min-num-tensors)
|
||||
min_num_tensors="--min-num-tensors $2"
|
||||
shift 2
|
||||
;;
|
||||
--min-bytes)
|
||||
min_bytes="--min-bytes $2"
|
||||
shift 2
|
||||
;;
|
||||
--max-num-tensors)
|
||||
max_num_tensors="--max-num-tensors $2"
|
||||
shift 2
|
||||
;;
|
||||
--max-bytes)
|
||||
max_bytes="--max-bytes $2"
|
||||
shift 2
|
||||
;;
|
||||
--name|-n)
|
||||
master_hostname="$2"
|
||||
shift 2
|
||||
;;
|
||||
--output|-o)
|
||||
output_file="$2"
|
||||
shift 2
|
||||
;;
|
||||
--world-size|-s)
|
||||
WORLD_SIZE="$2"
|
||||
shift 2
|
||||
;;
|
||||
*)
|
||||
errxit "Unknown option '$1'"
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
MASTER_ADDR="$master_hostname:$MASTER_PORT"
|
||||
if [ x"$BACKEND" = xtcp ]; then
|
||||
RANK=0
|
||||
host_list="$(printf "%s\n" "$hosts" | tr ',' ' ')"
|
||||
if [ "$(printf '%s\n' "$host_list" | wc -w)" -ne "$WORLD_SIZE" ]; then
|
||||
errxit "Number of hosts ($host_list) doesn't match" \
|
||||
"the world size ($WORLD_SIZE)"
|
||||
fi
|
||||
for host in $host_list; do
|
||||
ssh "$host" ". $environment &&" \
|
||||
"BACKEND=$BACKEND MASTER_ADDR=$MASTER_ADDR" \
|
||||
"MASTER_PORT=$MASTER_PORT WORLD_SIZE=$WORLD_SIZE RANK=$RANK" \
|
||||
"python $engine" \
|
||||
">> ${output_file:-}" \
|
||||
"${min_num_tensors:-} ${min_bytes:-} ${max_num_tensors:-} ${max_bytes:-}" &
|
||||
RANK=$((RANK+1))
|
||||
done
|
||||
wait
|
||||
elif [ x"$BACKEND" = xmpi ]; then
|
||||
. "$environment"
|
||||
export BACKEND
|
||||
mpirun -hosts "$hosts" -n "$WORLD_SIZE" >> ${output_file:-} \
|
||||
python "$engine" \
|
||||
${min_num_tensors:-} ${min_bytes:-} ${max_num_tensors:-} ${max_bytes:-}
|
||||
else
|
||||
errxit "Invalid backend: '$BACKEND'"
|
||||
fi
|
||||
|
||||
|
|
@ -1,154 +0,0 @@
|
|||
#include <THD/process_group/Collectives.hpp>
|
||||
#include <THD/base/ChannelUtils.hpp>
|
||||
#include <THD/process_group/General.hpp>
|
||||
|
||||
#include <vector>
|
||||
|
||||
using namespace thd;
|
||||
|
||||
int THDGetRank() {
|
||||
return static_cast<int>(dataChannel->getRank());
|
||||
}
|
||||
|
||||
int THDGetNumProcesses() {
|
||||
return static_cast<int>(dataChannel->getNumProcesses());
|
||||
}
|
||||
|
||||
void THDAllReduceMultiGPU(
|
||||
THDTensorDescriptor* data,
|
||||
size_t len,
|
||||
THDReduceOp operation,
|
||||
THDGroup group) {
|
||||
std::vector<at::Tensor> dataVec(data, data + len);
|
||||
dataChannel->allReduce(dataVec, operation, group);
|
||||
}
|
||||
|
||||
void THDAllReduce(
|
||||
THDTensorDescriptor& desc,
|
||||
THDReduceOp operation,
|
||||
THDGroup group) {
|
||||
dataChannel->allReduce(desc, operation, group);
|
||||
}
|
||||
|
||||
void THDReduceMultiGPU(
|
||||
THDTensorDescriptor* desc,
|
||||
size_t len,
|
||||
THDReduceOp operation,
|
||||
int dst_rank,
|
||||
THDGroup group) {
|
||||
std::vector<at::Tensor> dataVec(desc, desc + len);
|
||||
dataChannel->reduce(dataVec, operation, convertToRank(dst_rank), group);
|
||||
}
|
||||
|
||||
void THDReduce(
|
||||
THDTensorDescriptor& desc,
|
||||
THDReduceOp operation,
|
||||
int dst_rank,
|
||||
THDGroup group) {
|
||||
dataChannel->reduce(desc, operation, convertToRank(dst_rank), group);
|
||||
}
|
||||
|
||||
void THDBroadcastMultiGPU(
|
||||
THDTensorDescriptor* desc,
|
||||
size_t len,
|
||||
int src_rank,
|
||||
THDGroup group) {
|
||||
std::vector<at::Tensor> dataVec(desc, desc + len);
|
||||
dataChannel->broadcast(dataVec, convertToRank(src_rank), group);
|
||||
}
|
||||
|
||||
void THDBroadcast(THDTensorDescriptor& desc, int src_rank, THDGroup group) {
|
||||
dataChannel->broadcast(desc, convertToRank(src_rank), group);
|
||||
}
|
||||
|
||||
THDRequest* THDIsend(THDTensorDescriptor& desc, int dst_rank) {
|
||||
return dataChannel->isend(desc, convertToRank(dst_rank));
|
||||
}
|
||||
|
||||
THDRequest* THDIrecv(THDTensorDescriptor& desc, int src_rank) {
|
||||
return dataChannel->ireceive(desc, convertToRank(src_rank));
|
||||
}
|
||||
|
||||
void THDSend(THDTensorDescriptor& desc, int dst_rank) {
|
||||
dataChannel->send(desc, convertToRank(dst_rank));
|
||||
}
|
||||
|
||||
int THDRecvAnySource(THDTensorDescriptor& desc) {
|
||||
return dataChannel->receive(desc);
|
||||
}
|
||||
|
||||
void THDRecv(THDTensorDescriptor& desc, int src_rank) {
|
||||
dataChannel->receive(desc, convertToRank(src_rank));
|
||||
}
|
||||
|
||||
void THDAllGatherMultiGPU(
|
||||
THDTensorDescriptor* output,
|
||||
size_t outputLen,
|
||||
THDTensorDescriptor* input,
|
||||
size_t inputLen,
|
||||
THDGroup group) {
|
||||
std::vector<at::Tensor> outputVec(output, output + outputLen);
|
||||
std::vector<at::Tensor> inputVec(input, input + inputLen);
|
||||
dataChannel->allGather(outputVec, inputVec, group);
|
||||
}
|
||||
|
||||
void THDAllGather(
|
||||
THDTensorDescriptor* output,
|
||||
size_t len,
|
||||
THDTensorDescriptor& input,
|
||||
THDGroup group) {
|
||||
std::vector<at::Tensor> v_output(output, output + len);
|
||||
dataChannel->allGather(v_output, input, group);
|
||||
}
|
||||
|
||||
void THDGatherSend(THDTensorDescriptor& input, int dst_rank, THDGroup group) {
|
||||
std::vector<at::Tensor> v_output;
|
||||
dataChannel->gather(v_output, input, convertToRank(dst_rank), group);
|
||||
}
|
||||
|
||||
void THDGatherRecv(
|
||||
THDTensorDescriptor* output,
|
||||
size_t len,
|
||||
THDTensorDescriptor& input,
|
||||
THDGroup group) {
|
||||
std::vector<at::Tensor> v_output(output, output + len);
|
||||
dataChannel->gather(v_output, input, dataChannel->getRank(), group);
|
||||
}
|
||||
|
||||
void THDScatterSend(
|
||||
THDTensorDescriptor* input,
|
||||
size_t len,
|
||||
THDTensorDescriptor& output,
|
||||
THDGroup group) {
|
||||
std::vector<at::Tensor> v_input(input, input + len);
|
||||
dataChannel->scatter(v_input, output, dataChannel->getRank(), group);
|
||||
}
|
||||
|
||||
void THDScatterRecv(THDTensorDescriptor& output, int src_rank, THDGroup group) {
|
||||
if (src_rank < 0)
|
||||
throw std::domain_error("src_rank should not be negative");
|
||||
|
||||
std::vector<at::Tensor> v_input;
|
||||
dataChannel->scatter(v_input, output, convertToRank(src_rank), group);
|
||||
}
|
||||
|
||||
void THDBarrier(THDGroup group) {
|
||||
dataChannel->barrier(group);
|
||||
}
|
||||
|
||||
THDGroup THDNewGroup(const int* ranks, size_t len) {
|
||||
std::vector<rank_type> v_ranks(len);
|
||||
for (size_t i = 0; i < len; ++i) {
|
||||
v_ranks[i] = convertToRank(ranks[i]);
|
||||
}
|
||||
|
||||
return dataChannel->newGroup(v_ranks);
|
||||
}
|
||||
|
||||
bool THDRequest_isCompleted(THDRequest* request) {
|
||||
return request->isCompleted();
|
||||
}
|
||||
|
||||
void THDRequest_wait(THDRequest* request) {
|
||||
request->wait();
|
||||
}
|
||||
|
|
@ -1,74 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include <THD/THD.h>
|
||||
#include <THD/base/DataChannel.h>
|
||||
|
||||
THD_API int THDGetRank();
|
||||
THD_API int THDGetNumProcesses();
|
||||
THD_API void THDAllReduceMultiGPU(
|
||||
THDTensorDescriptor* data,
|
||||
size_t len,
|
||||
THDReduceOp operation,
|
||||
THDGroup group);
|
||||
THD_API void THDAllReduce(
|
||||
THDTensorDescriptor& desc,
|
||||
THDReduceOp operation,
|
||||
THDGroup group);
|
||||
THD_API void THDReduceMultiGPU(
|
||||
THDTensorDescriptor* desc,
|
||||
size_t len,
|
||||
THDReduceOp operation,
|
||||
int dst_rank,
|
||||
THDGroup group);
|
||||
THD_API void THDReduce(
|
||||
THDTensorDescriptor& desc,
|
||||
THDReduceOp operation,
|
||||
int dst_rank,
|
||||
THDGroup group);
|
||||
THD_API void THDBroadcastMultiGPU(
|
||||
THDTensorDescriptor* desc,
|
||||
size_t len,
|
||||
int src_rank,
|
||||
THDGroup group);
|
||||
THD_API void THDBroadcast(
|
||||
THDTensorDescriptor& desc,
|
||||
int src_rank,
|
||||
THDGroup group);
|
||||
THD_API THDRequest* THDIsend(THDTensorDescriptor& desc, int dst_rank);
|
||||
THD_API THDRequest* THDIrecv(THDTensorDescriptor& desc, int src_rank);
|
||||
THD_API void THDSend(THDTensorDescriptor& desc, int dst_rank);
|
||||
THD_API int THDRecvAnySource(THDTensorDescriptor& desc);
|
||||
THD_API void THDRecv(THDTensorDescriptor& desc, int src_rank);
|
||||
THD_API void THDAllGatherMultiGPU(
|
||||
THDTensorDescriptor* output,
|
||||
size_t outputLen,
|
||||
THDTensorDescriptor* input,
|
||||
size_t inputLen,
|
||||
THDGroup group);
|
||||
THD_API void THDAllGather(
|
||||
THDTensorDescriptor* output,
|
||||
size_t len,
|
||||
THDTensorDescriptor& input,
|
||||
THDGroup group);
|
||||
THD_API void THDGatherSend(
|
||||
THDTensorDescriptor& input,
|
||||
int dst_rank,
|
||||
THDGroup group);
|
||||
THD_API void THDGatherRecv(
|
||||
THDTensorDescriptor* output,
|
||||
size_t len,
|
||||
THDTensorDescriptor& input,
|
||||
THDGroup group);
|
||||
THD_API void THDScatterSend(
|
||||
THDTensorDescriptor* input,
|
||||
size_t len,
|
||||
THDTensorDescriptor& output,
|
||||
THDGroup group);
|
||||
THD_API void THDScatterRecv(
|
||||
THDTensorDescriptor& output,
|
||||
int src_rank,
|
||||
THDGroup group);
|
||||
THD_API void THDBarrier(THDGroup group);
|
||||
THD_API THDGroup THDNewGroup(const int* ranks, size_t len);
|
||||
THD_API bool THDRequest_isCompleted(THDRequest* request);
|
||||
THD_API void THDRequest_wait(THDRequest* request);
|
||||
|
|
@ -1,4 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include <THD/process_group/Collectives.h>
|
||||
#include <THD/base/TensorDescriptor.hpp>
|
||||
|
|
@ -1,38 +0,0 @@
|
|||
#include <THD/process_group/General.hpp>
|
||||
#include <THD/base/Exceptions.hpp>
|
||||
|
||||
namespace thd {
|
||||
std::unique_ptr<DataChannel> dataChannel;
|
||||
} // namespace thd
|
||||
|
||||
using namespace thd;
|
||||
|
||||
void THDProcessGroupInit(
|
||||
THDChannelType channel_type,
|
||||
std::string init_method = "env://",
|
||||
int world_size = -1,
|
||||
std::string group_name = "",
|
||||
int rank = -1) {
|
||||
HANDLE_EXCEPTIONS
|
||||
dataChannel = std::unique_ptr<DataChannel>(thd::DataChannel::newChannel(
|
||||
channel_type, init_method, world_size, group_name, rank));
|
||||
dataChannel->init();
|
||||
END_HANDLE_EXCEPTIONS
|
||||
}
|
||||
|
||||
void THDProcessGroupDestroy() {
|
||||
HANDLE_EXCEPTIONS
|
||||
if (dataChannel) {
|
||||
dataChannel->destroy();
|
||||
dataChannel.reset(nullptr);
|
||||
}
|
||||
END_HANDLE_EXCEPTIONS
|
||||
}
|
||||
|
||||
void THDClearGroupCache(THDGroup group) {
|
||||
HANDLE_EXCEPTIONS
|
||||
if (dataChannel) {
|
||||
dataChannel->clearGroupCache(group);
|
||||
}
|
||||
END_HANDLE_EXCEPTIONS
|
||||
}
|
||||
|
|
@ -1,14 +0,0 @@
|
|||
#pragma once

#include <string>
#include <THD/THD.h>
#include <THD/base/DataChannel.h>

THD_API void THDProcessGroupInit(
    THDChannelType channel_type,
    std::string init_method,
    int world_size,
    std::string group_name,
    int rank);
THD_API void THDProcessGroupDestroy();
THD_API void THDClearGroupCache(THDGroup group);
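// Editor's sketch, not part of the original file: how the C API declared in
// Collectives.h and General.h above was typically driven. The concrete
// THDChannelType enumerator and the world-group THDGroup constant are defined
// in THD/base/DataChannel.h (not shown in this diff), so both are taken as
// parameters here; the sketch also assumes THDTensorDescriptor aliases
// at::Tensor, as the Collectives.cpp implementation above suggests.
#include <ATen/ATen.h>
#include <THD/process_group/Collectives.h>
#include <THD/process_group/General.h>

void allReduceExample(
    THDChannelType channel_type, // e.g. the TCP channel
    THDGroup world_group,        // the default "all processes" group
    int world_size,
    int rank) {
  // Rendezvous via a tcp:// init method with an explicitly assigned rank
  // (see initTCP above); the hostname and port are placeholders.
  THDProcessGroupInit(
      channel_type, "tcp://127.0.0.1:29500", world_size, "demo", rank);

  // Every process contributes a tensor filled with its own rank.
  THDTensorDescriptor t =
      at::full({5}, static_cast<float>(THDGetRank()), at::kFloat);

  // After the sum-reduction every element equals 0 + 1 + ... + (world_size-1).
  THDAllReduce(t, THDReduceOp::THDReduceSUM, world_group);

  THDBarrier(world_group);
  THDProcessGroupDestroy();
}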
|
||||
|
|
@ -1,9 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <THD/process_group/General.h>
|
||||
#include <THD/base/DataChannel.hpp>
|
||||
|
||||
namespace thd {
|
||||
extern std::unique_ptr<DataChannel> dataChannel;
|
||||
} // namespace thd
|
||||
|
|
@ -1,95 +0,0 @@
|
|||
#include <THPP/tensors/THTensor.hpp>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <chrono>
|
||||
#include <cmath>
|
||||
#include <condition_variable>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <vector>
|
||||
|
||||
constexpr char RANK_ENV[] = "RANK";
|
||||
constexpr char WORLD_SIZE_ENV[] = "WORLD_SIZE";
|
||||
constexpr char MASTER_PORT_ENV[] = "MASTER_PORT";
|
||||
constexpr char MASTER_ADDR_ENV[] = "MASTER_ADDR";
|
||||
|
||||
struct Barrier {
  Barrier() : _count(0) {}
  Barrier(size_t count) : _count(count) {}

  void wait() {
    std::unique_lock<std::mutex> lock{_mutex};
    if (--_count == 0) {
      _cv.notify_all();
    } else {
      _cv.wait(lock);
    }
  }

 private:
  std::mutex _mutex;
  std::condition_variable _cv;
  size_t _count;
};
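// Editor's sketch, not part of the original file: typical use of the Barrier
// helper above -- construct it with the number of participants and have each
// thread block in wait() until everyone has arrived. Note that this Barrier
// is single-use and wait() has no predicate, so a spurious wakeup could in
// principle release a waiter early; that is acceptable for these tests.
#include <thread>
#include <vector>

void barrierExample() {
  constexpr size_t kThreads = 4;
  Barrier barrier(kThreads);

  std::vector<std::thread> threads;
  for (size_t i = 0; i < kThreads; ++i) {
    threads.emplace_back([&barrier]() {
      // ... per-thread setup would happen here ...
      barrier.wait(); // the last thread to arrive wakes all the others
    });
  }
  for (auto& t : threads) {
    t.join();
  }
}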
|
||||
|
||||
template <typename T>
typename std::enable_if<!std::numeric_limits<T>::is_integer, bool>::type
check_equal(T x, T y, int ulp = 5) {
  auto eps = std::numeric_limits<T>::epsilon();
  auto min = std::numeric_limits<T>::min();
  return (std::abs(x - y) < eps * std::abs(x + y) * ulp) ||
      (std::abs(x - y) < min);
}
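// Editor's note, not part of the original file: with the default ulp = 5 and
// T = float, x = 1.0f and y = 1.0f + 2 * FLT_EPSILON give |x - y| ~= 2.4e-7,
// below eps * |x + y| * ulp ~= 1.2e-6, so they compare equal, while
// y = 1.0f + 1e-5f does not. The `< min` clause keeps comparisons of values
// near zero from failing the purely relative test.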
|
||||
|
||||
template <typename T>
|
||||
typename std::enable_if<std::numeric_limits<T>::is_integer, bool>::type
|
||||
check_equal(T x, T y) {
|
||||
return x == y;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::shared_ptr<thpp::THTensor<T>> buildTensor(
|
||||
std::vector<int64_t> shape,
|
||||
T value) {
|
||||
auto tensor = std::make_shared<thpp::THTensor<T>>();
|
||||
tensor->resize(shape);
|
||||
tensor->fill(value);
|
||||
return tensor;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline bool contains(std::vector<T> v, T value) {
|
||||
return std::find(v.begin(), v.end(), value) != v.end();
|
||||
}
|
||||
|
||||
inline int64_t nowInMilliseconds() {
|
||||
return std::chrono::duration_cast<std::chrono::milliseconds>(
|
||||
std::chrono::system_clock::now().time_since_epoch())
|
||||
.count();
|
||||
}
|
||||
|
||||
inline int64_t factorial(int n) {
|
||||
int64_t a = 1;
|
||||
for (int64_t i = 1; i <= n; ++i) {
|
||||
a *= i;
|
||||
}
|
||||
return a;
|
||||
}
|
||||
|
||||
#define ASSERT_TENSOR_VALUE(T, tensor, value) \
|
||||
{ \
|
||||
for (size_t idx = 0; idx < (tensor).numel(); idx++) \
|
||||
assert(check_equal( \
|
||||
reinterpret_cast<T*>((tensor).data())[idx], static_cast<T>(value))); \
|
||||
}
|
||||
|
||||
#define ASSERT_THROWS(exception, expr) \
|
||||
{ \
|
||||
try { \
|
||||
(expr); \
|
||||
assert(false); \
|
||||
} catch (const exception& e) { \
|
||||
} \
|
||||
}
|
||||
|
|
@ -1,818 +0,0 @@
|
|||
#ifdef WITH_GLOO
|
||||
#include <THD/base/data_channels/DataChannelGloo.hpp>
|
||||
#endif // WITH_GLOO
|
||||
#ifdef WITH_MPI
|
||||
#include <THD/base/data_channels/DataChannelMPI.hpp>
|
||||
#endif // WITH_MPI
|
||||
#include <THD/base/data_channels/DataChannelTCP.hpp>
|
||||
#include <THD/test/TestUtils.hpp>
|
||||
|
||||
#include <THPP/tensors/THTensor.hpp>
|
||||
|
||||
#include <unistd.h>
|
||||
#include <array>
|
||||
#include <cassert>
|
||||
#include <chrono>
|
||||
#include <cmath>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <set>
|
||||
#include <thread>
|
||||
|
||||
constexpr std::array<int, 4> WORKERS_NUM = {2, 4, 7, 13};
|
||||
constexpr int MASTER_PORT = 45678;
|
||||
constexpr int BARRIER_WAIT_TIME = 200; // milliseconds
|
||||
|
||||
std::vector<std::thread> g_all_workers;
|
||||
std::mutex g_mutex;
|
||||
std::string g_data_channel_type;
|
||||
std::unique_ptr<Barrier> g_barrier;
|
||||
|
||||
void test_send_recv_tensor(std::shared_ptr<thd::DataChannel> data_channel) {
|
||||
if (g_data_channel_type == "gloo") {
|
||||
return; // XXX: Gloo does not support send/recv
|
||||
}
|
||||
|
||||
if (data_channel->getRank() == 0) {
|
||||
auto float_tensor = buildTensor<float>({1, 2, 3}, 4.2);
|
||||
data_channel->send(*float_tensor, 1);
|
||||
} else if (data_channel->getRank() == 1) {
|
||||
auto float_tensor = buildTensor<float>({1, 2, 3}, -1.0);
|
||||
data_channel->receive(*float_tensor, 0);
|
||||
ASSERT_TENSOR_VALUE(float, *float_tensor, 4.2);
|
||||
}
|
||||
}
|
||||
|
||||
void test_send_recv_tensor_any_source(
|
||||
std::shared_ptr<thd::DataChannel> data_channel,
|
||||
int workers) {
|
||||
if (g_data_channel_type == "gloo") {
|
||||
return; // XXX: Gloo does not support send/recv from any source
|
||||
}
|
||||
|
||||
if (data_channel->getRank() == 0) {
|
||||
std::set<int> ranks;
|
||||
for (int i = 0; i < workers; i++) {
|
||||
auto int_tensor = buildTensor<int>({1, 2, 3}, -1);
|
||||
data_channel->receive(*int_tensor);
|
||||
ranks.insert(static_cast<int*>(int_tensor->data())[0]);
|
||||
}
|
||||
|
||||
assert(ranks.size() == workers);
|
||||
} else {
|
||||
auto int_tensor = buildTensor<int>({1, 2, 3}, data_channel->getRank());
|
||||
data_channel->send(*int_tensor, 0);
|
||||
}
|
||||
}
|
||||
|
||||
void test_send_recv_scalar(std::shared_ptr<thd::DataChannel> data_channel) {
|
||||
if (g_data_channel_type == "gloo") {
|
||||
return; // XXX: Gloo does not support send/recv
|
||||
}
|
||||
|
||||
if (data_channel->getRank() == 0) {
|
||||
thd::ScalarWrapper<int> scalar((int)1232);
|
||||
data_channel->send(scalar, 1);
|
||||
} else if (data_channel->getRank() == 1) {
|
||||
thd::ScalarWrapper<int> scalar((int)-1);
|
||||
data_channel->receive(scalar, 0);
|
||||
assert(scalar.value() == 1232);
|
||||
}
|
||||
}
|
||||
|
||||
void test_broadcast(std::shared_ptr<thd::DataChannel> data_channel) {
|
||||
for (size_t dest = 0; dest < data_channel->getNumProcesses(); ++dest) {
|
||||
if (data_channel->getRank() == dest) {
|
||||
auto float_tensor = buildTensor<float>({1, 2, 3, 4, 5}, 10.123);
|
||||
data_channel->broadcast(*float_tensor, dest);
|
||||
} else {
|
||||
auto float_tensor = buildTensor<float>({1, 2, 3, 4, 5}, -1.0);
|
||||
data_channel->broadcast(*float_tensor, dest);
|
||||
ASSERT_TENSOR_VALUE(float, *float_tensor, 10.123)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void _test_reduce_helper(
|
||||
std::shared_ptr<thd::DataChannel> data_channel,
|
||||
THDReduceOp op_type,
|
||||
int64_t init_value,
|
||||
int64_t expected_value) {
|
||||
if (data_channel->getRank() == 0) {
|
||||
auto int_tensor = buildTensor<int>({1, 2, 3, 4, 5}, init_value);
|
||||
data_channel->reduce(*int_tensor, op_type, 0);
|
||||
ASSERT_TENSOR_VALUE(int, *int_tensor, expected_value)
|
||||
} else {
|
||||
auto int_tensor =
|
||||
buildTensor<int>({1, 2, 3, 4, 5}, data_channel->getRank());
|
||||
data_channel->reduce(*int_tensor, op_type, 0);
|
||||
}
|
||||
}
|
||||
|
||||
void test_reduce(std::shared_ptr<thd::DataChannel> data_channel, int workers) {
|
||||
if (g_data_channel_type == "gloo") {
|
||||
return; // XXX: Gloo does not support reduce
|
||||
}
|
||||
|
||||
_test_reduce_helper(
|
||||
data_channel,
|
||||
THDReduceOp::THDReduceSUM,
|
||||
2,
|
||||
2 + (workers * (workers + 1) / 2));
|
||||
_test_reduce_helper(
|
||||
data_channel, THDReduceOp::THDReducePRODUCT, 2, 2 * factorial(workers));
|
||||
_test_reduce_helper(data_channel, THDReduceOp::THDReduceMIN, 10010, 1);
|
||||
_test_reduce_helper(
|
||||
data_channel,
|
||||
THDReduceOp::THDReduceMAX,
|
||||
-1,
|
||||
data_channel->getNumProcesses() - 1);
|
||||
}
|
||||
|
||||
void _test_allReduce_helper(
|
||||
std::shared_ptr<thd::DataChannel> data_channel,
|
||||
THDReduceOp op_type,
|
||||
int64_t init_value,
|
||||
int64_t expected_value) {
|
||||
if (data_channel->getRank() == 0) {
|
||||
auto int_tensor = buildTensor<int>({1, 2, 3, 4, 5, 6, 7, 100}, init_value);
|
||||
data_channel->allReduce(*int_tensor, op_type, 0);
|
||||
ASSERT_TENSOR_VALUE(int, *int_tensor, expected_value)
|
||||
} else {
|
||||
auto int_tensor =
|
||||
buildTensor<int>({1, 2, 3, 4, 5, 6, 7, 100}, data_channel->getRank());
|
||||
data_channel->allReduce(*int_tensor, op_type, 0);
|
||||
ASSERT_TENSOR_VALUE(int, *int_tensor, expected_value)
|
||||
}
|
||||
}
|
||||
|
||||
void test_allReduce(
|
||||
std::shared_ptr<thd::DataChannel> data_channel,
|
||||
int workers) {
|
||||
_test_allReduce_helper(
|
||||
data_channel,
|
||||
THDReduceOp::THDReduceSUM,
|
||||
2,
|
||||
2 + (workers * (workers + 1) / 2));
|
||||
_test_allReduce_helper(
|
||||
data_channel, THDReduceOp::THDReducePRODUCT, 2, 2 * factorial(workers));
|
||||
_test_allReduce_helper(data_channel, THDReduceOp::THDReduceMIN, 10010, 1);
|
||||
_test_allReduce_helper(
|
||||
data_channel,
|
||||
THDReduceOp::THDReduceMAX,
|
||||
-1,
|
||||
data_channel->getNumProcesses() - 1);
|
||||
}
|
||||
|
||||
void test_scatter(std::shared_ptr<thd::DataChannel> data_channel) {
|
||||
if (g_data_channel_type == "gloo") {
|
||||
return; // XXX: Gloo does not support scatter
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<thpp::IntTensor>> tensors;
|
||||
std::vector<thpp::Tensor*> raw_tensors;
|
||||
if (data_channel->getRank() == 0) {
|
||||
for (size_t i = 0; i < data_channel->getNumProcesses(); ++i) {
|
||||
tensors.push_back(buildTensor<int>({1, 2, 3, 4, 5}, i));
|
||||
raw_tensors.push_back(tensors.back().get());
|
||||
}
|
||||
}
|
||||
|
||||
auto int_tensor = buildTensor<int>({1, 2, 3, 4, 5}, -1);
|
||||
data_channel->scatter(raw_tensors, *int_tensor, 0);
|
||||
ASSERT_TENSOR_VALUE(int, *int_tensor, data_channel->getRank())
|
||||
}
|
||||
|
||||
void test_gather(std::shared_ptr<thd::DataChannel> data_channel) {
|
||||
if (g_data_channel_type == "gloo") {
|
||||
return; // XXX: Gloo does not support gather
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<thpp::IntTensor>> tensors;
|
||||
std::vector<thpp::Tensor*> raw_tensors;
|
||||
auto int_tensor = buildTensor<int>({1, 2, 3, 4, 5}, data_channel->getRank());
|
||||
if (data_channel->getRank() == 0) {
|
||||
for (size_t i = 0; i < data_channel->getNumProcesses(); ++i) {
|
||||
tensors.push_back(buildTensor<int>({1, 2, 3, 4, 5}, -1));
|
||||
raw_tensors.push_back(tensors.back().get());
|
||||
}
|
||||
|
||||
data_channel->gather(raw_tensors, *int_tensor, 0);
|
||||
for (size_t i = 0; i < tensors.size(); ++i)
|
||||
ASSERT_TENSOR_VALUE(int, *(tensors[i]), i)
|
||||
} else {
|
||||
data_channel->gather(raw_tensors, *int_tensor, 0);
|
||||
}
|
||||
}
|
||||
|
||||
void test_allGather(std::shared_ptr<thd::DataChannel> data_channel) {
|
||||
std::vector<std::shared_ptr<thpp::IntTensor>> tensors;
|
||||
std::vector<thpp::Tensor*> raw_tensors;
|
||||
auto int_tensor = buildTensor<int>({1, 2, 3, 4, 5}, data_channel->getRank());
|
||||
for (size_t i = 0; i < data_channel->getNumProcesses(); ++i) {
|
||||
tensors.push_back(buildTensor<int>({1, 2, 3, 4, 5}, -1));
|
||||
raw_tensors.push_back(tensors.back().get());
|
||||
}
|
||||
|
||||
data_channel->allGather(raw_tensors, *int_tensor, 0);
|
||||
for (size_t i = 0; i < tensors.size(); ++i)
|
||||
ASSERT_TENSOR_VALUE(int, *(tensors[i]), i)
|
||||
}
|
||||
|
||||
void test_barrier(std::shared_ptr<thd::DataChannel> data_channel) {
|
||||
for (int i = 0; i < data_channel->getNumProcesses(); ++i) {
|
||||
if (data_channel->getRank() == i) {
|
||||
int64_t time_after_barrier = nowInMilliseconds() + BARRIER_WAIT_TIME;
|
||||
auto time_tensor = buildTensor<int64_t>({1}, time_after_barrier);
|
||||
data_channel->broadcast(*time_tensor, i);
|
||||
std::this_thread::sleep_for(
|
||||
std::chrono::milliseconds(BARRIER_WAIT_TIME + 10));
|
||||
data_channel->barrier();
|
||||
} else {
|
||||
auto time_tensor = buildTensor<int64_t>({1}, -1);
|
||||
data_channel->broadcast(
|
||||
*time_tensor, i); // get expected time after barrier
|
||||
data_channel->barrier();
|
||||
assert(
|
||||
nowInMilliseconds() >=
|
||||
reinterpret_cast<int64_t*>(time_tensor->data())[0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void test_isend(std::shared_ptr<thd::DataChannel> data_channel) {
|
||||
if (g_data_channel_type == "gloo") {
|
||||
return; // XXX: Gloo does not support isend
|
||||
}
|
||||
|
||||
if (data_channel->getRank() == 0) {
|
||||
std::vector<std::shared_ptr<thd::DataChannel::Request>> requests;
|
||||
for (size_t i = 1; i < data_channel->getNumProcesses(); ++i) {
|
||||
auto tensor = buildTensor<int>({1, 2, 3, 4, 5}, i);
|
||||
requests.push_back(std::shared_ptr<thd::DataChannel::Request>(
|
||||
data_channel->isend(*tensor, i)));
|
||||
}
|
||||
|
||||
for (auto request : requests) {
|
||||
request->wait();
|
||||
assert(request->isCompleted());
|
||||
}
|
||||
} else {
|
||||
auto int_tensor = buildTensor<int>({1, 2, 3, 4, 5}, -1);
|
||||
data_channel->receive(*int_tensor, 0);
|
||||
ASSERT_TENSOR_VALUE(int, *int_tensor, data_channel->getRank())
|
||||
}
|
||||
}
|
||||
|
||||
void test_irecv(std::shared_ptr<thd::DataChannel> data_channel) {
|
||||
if (g_data_channel_type == "gloo") {
|
||||
return; // XXX: Gloo does not support irecv
|
||||
}
|
||||
|
||||
if (data_channel->getRank() == 0) {
|
||||
std::vector<std::shared_ptr<thd::DataChannel::Request>> requests;
|
||||
std::vector<std::shared_ptr<thpp::IntTensor>> tensors;
|
||||
for (size_t i = 1; i < data_channel->getNumProcesses(); ++i) {
|
||||
tensors.push_back(buildTensor<int>({1, 2, 3, 4, 5}, -1));
|
||||
requests.push_back(std::shared_ptr<thd::DataChannel::Request>(
|
||||
data_channel->ireceive(*tensors.back(), i)));
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < requests.size(); ++i) {
|
||||
requests.at(i)->wait();
|
||||
assert(requests.at(i)->isCompleted());
|
||||
ASSERT_TENSOR_VALUE(int, *tensors.at(i), i + 1)
|
||||
}
|
||||
} else {
|
||||
auto int_tensor =
|
||||
buildTensor<int>({1, 2, 3, 4, 5}, data_channel->getRank());
|
||||
data_channel->send(*int_tensor, 0);
|
||||
}
|
||||
}
|
||||
|
||||
void test_interlaces(std::shared_ptr<thd::DataChannel> data_channel) {
|
||||
if (g_data_channel_type == "gloo") {
|
||||
return; // XXX: Gloo does not support isend, irecv, send, recv
|
||||
}
|
||||
|
||||
if (data_channel->getRank() == 0) {
|
||||
std::vector<std::shared_ptr<thd::DataChannel::Request>> requests;
|
||||
for (size_t i = 1; i < data_channel->getNumProcesses(); ++i) {
|
||||
auto tensor = buildTensor<int>({1, 2, 3, 4, 5}, 10);
|
||||
requests.push_back(std::shared_ptr<thd::DataChannel::Request>(
|
||||
data_channel->isend(*tensor, i)));
|
||||
}
|
||||
|
||||
data_channel->barrier();
|
||||
|
||||
for (size_t i = 1; i < data_channel->getNumProcesses(); ++i) {
|
||||
auto int_tensor = buildTensor<int>({1, 2, 3, 4, 5}, 20);
|
||||
data_channel->send(*int_tensor, i);
|
||||
}
|
||||
} else {
|
||||
auto int_tensor1 = buildTensor<int>({1, 2, 3, 4, 5}, -1);
|
||||
auto request = std::shared_ptr<thd::DataChannel::Request>(
|
||||
data_channel->ireceive(*int_tensor1, 0));
|
||||
|
||||
data_channel->barrier();
|
||||
|
||||
auto int_tensor2 = buildTensor<int>({1, 2, 3, 4, 5}, -1);
|
||||
data_channel->receive(*int_tensor2, 0);
|
||||
request->wait();
|
||||
|
||||
ASSERT_TENSOR_VALUE(int, *int_tensor1, 10)
|
||||
ASSERT_TENSOR_VALUE(int, *int_tensor2, 20)
|
||||
}
|
||||
}
|
||||
|
||||
/*
 * In group tests we call the same functions in processes that do not belong
 * to those groups, to check that this does not affect any computations.
 *
 * Processes that do not belong to the group do not have to call those
 * methods!
 */

////////////
// GROUPS //
////////////
|
||||
|
||||
void test_broadcast_group(
|
||||
std::shared_ptr<thd::DataChannel> data_channel,
|
||||
THDGroup group,
|
||||
std::vector<thd::rank_type> group_ranks) {
|
||||
if (contains(group_ranks, data_channel->getRank())) {
|
||||
auto int_tensor = buildTensor({1, 2, 3, 4, 5}, -1);
|
||||
if (data_channel->getRank() == group_ranks[0])
|
||||
int_tensor->fill(2000);
|
||||
|
||||
data_channel->broadcast(*int_tensor, group_ranks[0], group);
|
||||
ASSERT_TENSOR_VALUE(int, *int_tensor, 2000)
|
||||
} else {
|
||||
auto int_tensor = buildTensor({1, 2, 3, 4, 5}, 1000);
|
||||
data_channel->broadcast(*int_tensor, group_ranks[0], group);
|
||||
ASSERT_TENSOR_VALUE(int, *int_tensor, 1000)
|
||||
}
|
||||
}
|
||||
|
||||
void test_reduce_group(
|
||||
std::shared_ptr<thd::DataChannel> data_channel,
|
||||
THDGroup group,
|
||||
std::vector<thd::rank_type> group_ranks) {
|
||||
if (g_data_channel_type == "gloo") {
|
||||
return; // XXX: Gloo does not support reduce
|
||||
}
|
||||
|
||||
if (contains(group_ranks, data_channel->getRank())) {
|
||||
auto int_tensor = buildTensor({1, 2, 3, 4, 5}, 10);
|
||||
data_channel->reduce(
|
||||
*int_tensor, THDReduceOp::THDReduceSUM, group_ranks[0], group);
|
||||
if (data_channel->getRank() == group_ranks[0]) {
|
||||
ASSERT_TENSOR_VALUE(int, *int_tensor, 10 * group_ranks.size())
|
||||
} else {
|
||||
ASSERT_TENSOR_VALUE(int, *int_tensor, 10)
|
||||
}
|
||||
} else {
|
||||
auto int_tensor = buildTensor({1, 2, 3, 4, 5}, 1000);
|
||||
data_channel->reduce(
|
||||
*int_tensor, THDReduceOp::THDReduceSUM, group_ranks[0], group);
|
||||
ASSERT_TENSOR_VALUE(int, *int_tensor, 1000)
|
||||
}
|
||||
}
|
||||
|
||||
void test_allReduce_group(
|
||||
std::shared_ptr<thd::DataChannel> data_channel,
|
||||
THDGroup group,
|
||||
std::vector<thd::rank_type> group_ranks) {
|
||||
if (contains(group_ranks, data_channel->getRank())) {
|
||||
auto int_tensor = buildTensor({1, 2, 3, 4, 5, 6, 7, 100}, 10);
|
||||
data_channel->allReduce(*int_tensor, THDReduceOp::THDReduceSUM, group);
|
||||
ASSERT_TENSOR_VALUE(int, *int_tensor, 10 * group_ranks.size())
|
||||
} else {
|
||||
auto int_tensor = buildTensor({1, 2, 3, 4, 5, 6, 7, 100}, 1000);
|
||||
data_channel->allReduce(*int_tensor, THDReduceOp::THDReduceSUM, group);
|
||||
ASSERT_TENSOR_VALUE(int, *int_tensor, 1000)
|
||||
}
|
||||
}
|
||||
|
||||
void test_scatter_group(
|
||||
std::shared_ptr<thd::DataChannel> data_channel,
|
||||
THDGroup group,
|
||||
std::vector<thd::rank_type> group_ranks) {
|
||||
if (g_data_channel_type == "gloo") {
|
||||
return; // XXX: Gloo does not support scatter
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<thpp::IntTensor>> tensors;
|
||||
std::vector<thpp::Tensor*> raw_tensors;
|
||||
if (contains(group_ranks, data_channel->getRank())) {
|
||||
if (data_channel->getRank() == group_ranks[0]) {
|
||||
for (size_t i = 0; i < group_ranks.size(); ++i) {
|
||||
tensors.push_back(buildTensor<int>({1, 2, 3, 4, 5}, group_ranks[i]));
|
||||
raw_tensors.push_back(tensors.back().get());
|
||||
}
|
||||
}
|
||||
|
||||
auto int_tensor = buildTensor<int>({1, 2, 3, 4, 5}, -1);
|
||||
data_channel->scatter(raw_tensors, *int_tensor, group_ranks[0], group);
|
||||
ASSERT_TENSOR_VALUE(int, *int_tensor, data_channel->getRank())
|
||||
} else {
|
||||
auto int_tensor = buildTensor({1, 2, 3, 4, 5}, 1000);
|
||||
data_channel->scatter(raw_tensors, *int_tensor, group_ranks[0], group);
|
||||
ASSERT_TENSOR_VALUE(int, *int_tensor, 1000)
|
||||
}
|
||||
}
|
||||
|
||||
void test_gather_group(
|
||||
std::shared_ptr<thd::DataChannel> data_channel,
|
||||
THDGroup group,
|
||||
std::vector<thd::rank_type> group_ranks) {
|
||||
if (g_data_channel_type == "gloo") {
|
||||
return; // XXX: Gloo does not support gather
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<thpp::IntTensor>> tensors;
|
||||
std::vector<thpp::Tensor*> raw_tensors;
|
||||
if (contains(group_ranks, data_channel->getRank())) {
|
||||
auto int_tensor =
|
||||
buildTensor<int>({1, 2, 3, 4, 5}, data_channel->getRank());
|
||||
if (data_channel->getRank() == group_ranks[0]) {
|
||||
for (size_t i = 0; i < group_ranks.size(); ++i) {
|
||||
tensors.push_back(buildTensor<int>({1, 2, 3, 4, 5}, -1));
|
||||
raw_tensors.push_back(tensors.back().get());
|
||||
}
|
||||
|
||||
data_channel->gather(raw_tensors, *int_tensor, group_ranks[0], group);
|
||||
for (size_t i = 0; i < tensors.size(); ++i)
|
||||
ASSERT_TENSOR_VALUE(int, *(tensors[i]), group_ranks[i])
|
||||
} else {
|
||||
data_channel->gather(raw_tensors, *int_tensor, group_ranks[0], group);
|
||||
}
|
||||
} else {
|
||||
auto int_tensor = buildTensor({1, 2, 3, 4, 5}, 1000);
|
||||
data_channel->gather(raw_tensors, *int_tensor, group_ranks[0], group);
|
||||
ASSERT_TENSOR_VALUE(int, *int_tensor, 1000)
|
||||
}
|
||||
}
|
||||
|
||||
void test_allGather_group(
|
||||
std::shared_ptr<thd::DataChannel> data_channel,
|
||||
THDGroup group,
|
||||
std::vector<thd::rank_type> group_ranks) {
|
||||
std::vector<std::shared_ptr<thpp::IntTensor>> tensors;
|
||||
std::vector<thpp::Tensor*> raw_tensors;
|
||||
if (contains(group_ranks, data_channel->getRank())) {
|
||||
auto int_tensor =
|
||||
buildTensor<int>({1, 2, 3, 4, 5}, data_channel->getRank());
|
||||
for (size_t i = 0; i < group_ranks.size(); ++i) {
|
||||
tensors.push_back(buildTensor<int>({1, 2, 3, 4, 5}, -1));
|
||||
raw_tensors.push_back(tensors.back().get());
|
||||
}
|
||||
|
||||
data_channel->allGather(raw_tensors, *int_tensor, group);
|
||||
for (size_t i = 0; i < tensors.size(); ++i)
|
||||
ASSERT_TENSOR_VALUE(int, *(tensors[i]), group_ranks[i])
|
||||
} else {
|
||||
auto int_tensor = buildTensor({1, 2, 3, 4, 5}, 1000);
|
||||
data_channel->allGather(raw_tensors, *int_tensor, group);
|
||||
ASSERT_TENSOR_VALUE(int, *int_tensor, 1000)
|
||||
}
|
||||
}
|
||||
|
||||
void test_barrier_group(
|
||||
std::shared_ptr<thd::DataChannel> data_channel,
|
||||
THDGroup group,
|
||||
std::vector<thd::rank_type> group_ranks) {
|
||||
if (contains(group_ranks, data_channel->getRank())) {
|
||||
for (int i = 0; i < group_ranks.size(); ++i) {
|
||||
if (data_channel->getRank() == group_ranks[i]) {
|
||||
int64_t time_after_barrier = nowInMilliseconds() + BARRIER_WAIT_TIME;
|
||||
auto time_tensor = buildTensor<int64_t>({1}, time_after_barrier);
|
||||
data_channel->broadcast(*time_tensor, group_ranks[i], group);
|
||||
std::this_thread::sleep_for(
|
||||
std::chrono::milliseconds(BARRIER_WAIT_TIME + 10));
|
||||
data_channel->barrier(group);
|
||||
} else {
|
||||
auto time_tensor = buildTensor<int64_t>({1}, -1);
|
||||
data_channel->broadcast(
|
||||
*time_tensor,
|
||||
group_ranks[i],
|
||||
group); // get expected time after barrier
|
||||
data_channel->barrier(group);
|
||||
assert(
|
||||
nowInMilliseconds() >=
|
||||
reinterpret_cast<int64_t*>(time_tensor->data())[0]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
std::this_thread::sleep_for(
|
||||
std::chrono::milliseconds(BARRIER_WAIT_TIME + 100));
|
||||
data_channel->barrier(group);
|
||||
}
|
||||
}
|
||||
|
||||
////////////////
|
||||
// EXCEPTIONS //
|
||||
////////////////
|
||||
|
||||
void test_send_recv_invalid_rank(
|
||||
std::shared_ptr<thd::DataChannel> data_channel) {
|
||||
if (g_data_channel_type == "gloo") {
|
||||
return; // XXX: Gloo does not support send/recv
|
||||
}
|
||||
|
||||
if (g_data_channel_type == "mpi") {
|
||||
return; // XXX: MPI does not throw exceptions
|
||||
}
|
||||
|
||||
auto rank = data_channel->getRank();
|
||||
auto int_tensor = buildTensor({1, 2, 3, 4, 5}, -1);
|
||||
|
||||
{// cannot send or receive to self
|
||||
ASSERT_THROWS(std::logic_error, data_channel->send(*int_tensor, rank))
|
||||
ASSERT_THROWS(
|
||||
std::logic_error, data_channel->receive(*int_tensor, rank))}
|
||||
|
||||
{ // cannot send or receive to/from process with rank -1
|
||||
ASSERT_THROWS(std::out_of_range, data_channel->send(*int_tensor, -1))
|
||||
ASSERT_THROWS(std::out_of_range, data_channel->receive(*int_tensor, -1))
|
||||
}
|
||||
}
|
||||
|
||||
// Cannot create empty group or group will be null
|
||||
void test_empty_group(std::shared_ptr<thd::DataChannel> data_channel) {
|
||||
// in MPI an empty group yields a NULL communicator instead of throwing
|
||||
if (g_data_channel_type == "tcp" || g_data_channel_type == "gloo") {
|
||||
ASSERT_THROWS(std::logic_error, data_channel->newGroup({}))
|
||||
}
|
||||
}
|
||||
|
||||
// Process with rank 0 is not part of the group, so operations addressing it must fail
|
||||
void test_process_not_in_group(std::shared_ptr<thd::DataChannel> data_channel) {
|
||||
auto int_tensor = buildTensor({1, 2, 3, 4, 5}, -1);
|
||||
|
||||
THDGroup group = data_channel->newGroup({1});
|
||||
std::vector<std::shared_ptr<thpp::IntTensor>> tensors = {
|
||||
buildTensor<int>({1, 2, 3, 4, 5}, -1)};
|
||||
std::vector<thpp::Tensor*> raw_tensors = {tensors.back().get()};
|
||||
|
||||
if (data_channel->getRank() == 1) {
|
||||
ASSERT_THROWS(
|
||||
std::logic_error, data_channel->broadcast(*int_tensor, 0, group))
|
||||
|
||||
if (g_data_channel_type == "gloo") {
|
||||
return; // XXX: Gloo does not support scatter/gather/reduce
|
||||
}
|
||||
|
||||
ASSERT_THROWS(
|
||||
std::logic_error,
|
||||
data_channel->reduce(*int_tensor, THDReduceOp::THDReduceSUM, 0, group))
|
||||
|
||||
ASSERT_THROWS(
|
||||
std::logic_error,
|
||||
data_channel->scatter(raw_tensors, *int_tensor, 0, group))
|
||||
|
||||
ASSERT_THROWS(
|
||||
std::logic_error,
|
||||
data_channel->gather(raw_tensors, *int_tensor, 0, group))
|
||||
}
|
||||
}
|
||||
|
||||
// the number of input_tensors does not match the group size
|
||||
void test_tensors_do_not_match_group_size(
|
||||
std::shared_ptr<thd::DataChannel> data_channel) {
|
||||
auto int_tensor = buildTensor({1, 2, 3, 4, 5}, -1);
|
||||
THDGroup group = data_channel->newGroup({1, 2});
|
||||
std::vector<std::shared_ptr<thpp::IntTensor>> tensors = {
|
||||
buildTensor<int>({1, 2, 3, 4, 5}, -1)};
|
||||
std::vector<thpp::Tensor*> raw_tensors = {tensors.back().get()};
|
||||
|
||||
if (data_channel->getRank() == 1 || data_channel->getRank() == 2) {
|
||||
ASSERT_THROWS(
|
||||
std::logic_error,
|
||||
data_channel->allGather(raw_tensors, *int_tensor, group))
|
||||
|
||||
if (g_data_channel_type == "gloo") {
|
||||
return; // XXX: Gloo does not support scatter/gather
|
||||
}
|
||||
|
||||
if (data_channel->getRank() == 1) {
|
||||
ASSERT_THROWS(
|
||||
std::logic_error,
|
||||
data_channel->scatter(raw_tensors, *int_tensor, 1, group))
|
||||
|
||||
ASSERT_THROWS(
|
||||
std::logic_error,
|
||||
data_channel->gather(raw_tensors, *int_tensor, 1, group))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// input_tensors do not all have the same shape
|
||||
void test_tensors_are_not_the_same(
|
||||
std::shared_ptr<thd::DataChannel> data_channel) {
|
||||
auto int_tensor = buildTensor({1, 2, 3, 4, 5}, -1);
|
||||
THDGroup group = data_channel->newGroup({1, 2});
|
||||
std::vector<std::shared_ptr<thpp::IntTensor>> tensors = {
|
||||
buildTensor<int>({1, 2, 3, 4, 5}, -1),
|
||||
buildTensor<int>({1, 2, 3, 4}, -1)};
|
||||
std::vector<thpp::Tensor*> raw_tensors = {tensors[0].get(), tensors[1].get()};
|
||||
|
||||
if (data_channel->getRank() == 1 || data_channel->getRank() == 2) {
|
||||
ASSERT_THROWS(
|
||||
std::logic_error,
|
||||
data_channel->allGather(raw_tensors, *int_tensor, group))
|
||||
|
||||
if (g_data_channel_type == "gloo") {
|
||||
return; // XXX: Gloo does not support scatter/gather
|
||||
}
|
||||
|
||||
if (data_channel->getRank() == 1) {
|
||||
ASSERT_THROWS(
|
||||
std::logic_error,
|
||||
data_channel->scatter(raw_tensors, *int_tensor, 1, group))
|
||||
|
||||
ASSERT_THROWS(
|
||||
std::logic_error,
|
||||
data_channel->gather(raw_tensors, *int_tensor, 1, group))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void run_all_tests(
|
||||
std::shared_ptr<thd::DataChannel> data_channel,
|
||||
int workers) {
|
||||
test_send_recv_tensor(data_channel);
|
||||
test_send_recv_tensor_any_source(data_channel, workers);
|
||||
test_send_recv_scalar(data_channel);
|
||||
test_broadcast(data_channel);
|
||||
test_reduce(data_channel, workers);
|
||||
test_allReduce(data_channel, workers);
|
||||
test_scatter(data_channel);
|
||||
test_gather(data_channel);
|
||||
test_allGather(data_channel);
|
||||
test_barrier(data_channel);
|
||||
test_isend(data_channel);
|
||||
test_irecv(data_channel);
|
||||
test_interlaces(data_channel);
|
||||
|
||||
std::vector<thd::rank_type> group_ranks = {1, 2};
|
||||
THDGroup group = data_channel->newGroup(group_ranks);
|
||||
test_broadcast_group(data_channel, group, group_ranks);
|
||||
test_reduce_group(data_channel, group, group_ranks);
|
||||
test_allReduce_group(data_channel, group, group_ranks);
|
||||
test_scatter_group(data_channel, group, group_ranks);
|
||||
test_gather_group(data_channel, group, group_ranks);
|
||||
test_allGather_group(data_channel, group, group_ranks);
|
||||
test_barrier_group(data_channel, group, group_ranks);
|
||||
|
||||
test_send_recv_invalid_rank(data_channel);
|
||||
test_empty_group(data_channel);
|
||||
test_process_not_in_group(data_channel);
|
||||
test_tensors_do_not_match_group_size(data_channel);
|
||||
test_tensors_are_not_the_same(data_channel);
|
||||
}
|
||||
|
||||
void init_tcp_master(int workers) {
|
||||
g_mutex.lock();
|
||||
setenv(WORLD_SIZE_ENV, std::to_string((workers + 1)).data(), 1);
|
||||
setenv(RANK_ENV, "0", 1);
|
||||
setenv(MASTER_PORT_ENV, std::to_string(MASTER_PORT).data(), 1);
|
||||
auto masterChannel = std::make_shared<thd::DataChannelTCP>(
|
||||
thd::getInitConfig("env://")); // reads all env variable
|
||||
g_mutex.unlock();
|
||||
|
||||
assert(masterChannel->init());
|
||||
run_all_tests(masterChannel, workers);
|
||||
|
||||
// wait for all workers to finish
|
||||
for (auto& worker : g_all_workers) {
|
||||
worker.join();
|
||||
}
|
||||
}
|
||||
|
||||
void init_tcp_worker(unsigned int id, int workers) {
|
||||
g_mutex.lock();
|
||||
setenv(RANK_ENV, std::to_string(id).data(), 1);
|
||||
setenv(
|
||||
MASTER_ADDR_ENV,
|
||||
std::string("127.0.0.1:" + std::to_string(MASTER_PORT)).data(),
|
||||
1);
|
||||
auto worker_channel = std::make_shared<thd::DataChannelTCP>(
|
||||
thd::getInitConfig("env://")); // reads all env variable
|
||||
g_mutex.unlock();
|
||||
|
||||
assert(worker_channel->init());
|
||||
run_all_tests(worker_channel, workers);
|
||||
}
|
||||
|
||||
#ifdef WITH_GLOO
|
||||
void init_gloo_master(int workers) {
|
||||
g_mutex.lock();
|
||||
setenv(WORLD_SIZE_ENV, std::to_string((workers + 1)).data(), 1);
|
||||
setenv(RANK_ENV, "0", 1);
|
||||
setenv(MASTER_PORT_ENV, std::to_string(MASTER_PORT).data(), 1);
|
||||
auto masterChannel = std::make_shared<thd::DataChannelGloo>(
|
||||
thd::getInitConfig("env://")); // reads all env variable
|
||||
g_mutex.unlock();
|
||||
|
||||
assert(masterChannel->init());
|
||||
run_all_tests(masterChannel, workers);
|
||||
|
||||
g_barrier->wait();
|
||||
}
|
||||
|
||||
void init_gloo_worker(unsigned int id, int workers) {
|
||||
g_mutex.lock();
|
||||
setenv(RANK_ENV, std::to_string(id).data(), 1);
|
||||
setenv(
|
||||
MASTER_ADDR_ENV,
|
||||
std::string("127.0.0.1:" + std::to_string(MASTER_PORT)).data(),
|
||||
1);
|
||||
auto worker_channel = std::make_shared<thd::DataChannelGloo>(
|
||||
thd::getInitConfig("env://")); // reads all env variable
|
||||
g_mutex.unlock();
|
||||
|
||||
assert(worker_channel->init());
|
||||
run_all_tests(worker_channel, workers);
|
||||
|
||||
g_barrier->wait();
|
||||
}
|
||||
#endif // WITH_GLOO
|
||||
|
||||
#ifdef WITH_MPI
|
||||
void init_mpi_process() {
|
||||
auto data_channel = std::make_shared<thd::DataChannelMPI>();
|
||||
assert(data_channel->init());
|
||||
run_all_tests(data_channel, WORKERS_NUM[0]);
|
||||
|
||||
std::cout << "MPI OK (id: " << data_channel->getRank() << ")" << std::endl;
|
||||
}
|
||||
#endif // WITH_MPI
|
||||
|
||||
int main(int argc, char const* argv[]) {
|
||||
#ifdef WITH_MPI
|
||||
if (argc == 1) {
|
||||
#endif // WITH_MPI
|
||||
g_data_channel_type = "tcp";
|
||||
for (auto workers : WORKERS_NUM) {
|
||||
std::cout << "TCP (workers: " << workers << "):" << std::endl;
|
||||
// start tcp master
|
||||
std::thread tcp_master_thread(init_tcp_master, workers);
|
||||
|
||||
// start tcp worker
|
||||
for (int id = 1; id <= workers; ++id) {
|
||||
g_all_workers.push_back(std::thread(init_tcp_worker, id, workers));
|
||||
}
|
||||
|
||||
tcp_master_thread.join();
|
||||
g_all_workers.clear();
|
||||
|
||||
std::cout << "TCP - OK" << std::endl;
|
||||
}
|
||||
|
||||
#ifdef WITH_GLOO
|
||||
g_data_channel_type = "gloo";
|
||||
for (auto workers : WORKERS_NUM) {
|
||||
g_barrier.reset(new Barrier(workers + 1));
|
||||
std::cout << "Gloo (workers: " << workers << "):" << std::endl;
|
||||
// start gloo master
|
||||
std::thread gloo_master_thread(init_gloo_master, workers);
|
||||
|
||||
// start gloo worker
|
||||
for (int id = 1; id <= workers; ++id) {
|
||||
g_all_workers.push_back(std::thread(init_gloo_worker, id, workers));
|
||||
}
|
||||
|
||||
// wait for all workers to finish
|
||||
for (auto& worker : g_all_workers) {
|
||||
worker.join();
|
||||
}
|
||||
|
||||
gloo_master_thread.join();
|
||||
g_all_workers.clear();
|
||||
|
||||
std::cout << "Gloo - OK" << std::endl;
|
||||
}
|
||||
#endif // WITH_GLOO
|
||||
|
||||
#ifdef WITH_MPI
|
||||
std::cout << "--------------------------" << std::endl;
|
||||
|
||||
// start MPI processes
|
||||
std::cout << "MPI:" << std::endl;
|
||||
execlp(
|
||||
"mpirun",
|
||||
"mpirun",
|
||||
"-n",
|
||||
std::to_string(WORKERS_NUM[0] + 1).data(),
|
||||
argv[0],
|
||||
"1",
|
||||
NULL);
|
||||
} else {
|
||||
g_data_channel_type = "mpi";
|
||||
init_mpi_process();
|
||||
}
|
||||
#endif // WITH_MPI
|
||||
return 0;
|
||||
}
|
||||
|
|
@ -1,94 +0,0 @@
|
|||
#include <THD/base/data_channels/DataChannelGloo.hpp>
|
||||
#include <THD/test/TestUtils.hpp>
|
||||
|
||||
#include <THPP/tensors/THTensor.hpp>
|
||||
|
||||
#include <unistd.h>
|
||||
#include <array>
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
constexpr std::array<int, 1> WORKERS_NUM = {10};
|
||||
constexpr int MASTER_PORT = 45678;
|
||||
|
||||
std::vector<std::thread> g_all_workers;
|
||||
std::mutex g_mutex;
|
||||
|
||||
void test(std::shared_ptr<thd::DataChannel> data_channel) {
|
||||
for (size_t dest = 0; dest < data_channel->getNumProcesses(); ++dest) {
|
||||
if (data_channel->getRank() == dest) {
|
||||
auto float_tensor = buildTensor<float>({1, 2, 3, 4, 5}, 10.123);
|
||||
data_channel->broadcast(*float_tensor, dest);
|
||||
} else {
|
||||
auto float_tensor = buildTensor<float>({1, 2, 3, 4, 5}, -1.0);
|
||||
data_channel->broadcast(*float_tensor, dest);
|
||||
ASSERT_TENSOR_VALUE(float, *float_tensor, 10.123)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void run_all_tests(
|
||||
std::shared_ptr<thd::DataChannel> data_channel,
|
||||
int workers) {
|
||||
// NOTE: without a properly working GlooCache this test would create
// about 1000 * WORKERS^3 connections (roughly a million with the 10
// workers configured above), far more than a 'normal' system allows.
|
||||
for (size_t i = 0; i < 1000; ++i) {
|
||||
test(data_channel);
|
||||
}
|
||||
}
|
||||
|
||||
void init_gloo_master(int workers) {
|
||||
g_mutex.lock();
|
||||
setenv(WORLD_SIZE_ENV, std::to_string((workers + 1)).data(), 1);
|
||||
setenv(RANK_ENV, "0", 1);
|
||||
setenv(MASTER_PORT_ENV, std::to_string(MASTER_PORT).data(), 1);
|
||||
auto masterChannel = std::make_shared<thd::DataChannelGloo>(
|
||||
thd::getInitConfig("env://")); // reads all env variable
|
||||
g_mutex.unlock();
|
||||
|
||||
assert(masterChannel->init());
|
||||
run_all_tests(masterChannel, workers);
|
||||
}
|
||||
|
||||
void init_gloo_worker(unsigned int id, int workers) {
|
||||
g_mutex.lock();
|
||||
setenv(RANK_ENV, std::to_string(id).data(), 1);
|
||||
setenv(
|
||||
MASTER_ADDR_ENV,
|
||||
std::string("127.0.0.1:" + std::to_string(MASTER_PORT)).data(),
|
||||
1);
|
||||
auto worker_channel = std::make_shared<thd::DataChannelGloo>(
|
||||
thd::getInitConfig("env://")); // reads all env variable
|
||||
g_mutex.unlock();
|
||||
|
||||
assert(worker_channel->init());
|
||||
run_all_tests(worker_channel, workers);
|
||||
}
|
||||
|
||||
int main(void) {
|
||||
for (auto workers : WORKERS_NUM) {
|
||||
std::cout << "Gloo (workers: " << workers << "):" << std::endl;
|
||||
// start gloo master
|
||||
std::thread gloo_master_thread(init_gloo_master, workers);
|
||||
|
||||
// start gloo worker
|
||||
for (int id = 1; id <= workers; ++id) {
|
||||
g_all_workers.push_back(std::thread(init_gloo_worker, id, workers));
|
||||
}
|
||||
|
||||
// wait for all workers to finish
|
||||
for (auto& worker : g_all_workers) {
|
||||
worker.join();
|
||||
}
|
||||
|
||||
gloo_master_thread.join();
|
||||
g_all_workers.clear();
|
||||
|
||||
std::cout << "Gloo - OK" << std::endl;
|
||||
}
|
||||
}
|
||||
|
|
@ -1,27 +0,0 @@
|
|||
#include <THD/base/data_channels/DataChannelMPI.hpp>
|
||||
|
||||
#include <unistd.h>
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
|
||||
constexpr int WORKERS_NUM = 2;
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
if (argc == 1) {
|
||||
execlp(
|
||||
"mpirun",
|
||||
"mpirun",
|
||||
"-n",
|
||||
std::to_string(WORKERS_NUM + 1).data(),
|
||||
argv[0],
|
||||
"1",
|
||||
NULL);
|
||||
}
|
||||
|
||||
auto dataChannel = std::make_shared<thd::DataChannelMPI>();
|
||||
assert(dataChannel->init());
|
||||
assert(dataChannel->getNumProcesses() == (WORKERS_NUM + 1));
|
||||
std::cout << "OK (id: " << dataChannel->getRank() << ")" << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
|
@ -1,27 +0,0 @@
|
|||
#include <THD/base/data_channels/DataChannelTCP.hpp>
|
||||
#include <THD/test/TestUtils.hpp>
|
||||
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <thread>
|
||||
|
||||
constexpr int WORKERS_NUM = 2;
|
||||
constexpr int MASTER_PORT = 45680;
|
||||
|
||||
void master() {
|
||||
setenv(WORLD_SIZE_ENV, std::to_string((WORKERS_NUM + 1)).data(), 1);
|
||||
setenv(RANK_ENV, "0", 1);
|
||||
setenv(MASTER_PORT_ENV, std::to_string(MASTER_PORT).data(), 1);
|
||||
auto masterChannel = std::make_shared<thd::DataChannelTCP>(
|
||||
thd::getInitConfig("env://"), 2000); // timeout after 2s
|
||||
|
||||
ASSERT_THROWS(std::exception, masterChannel->init())
|
||||
}
|
||||
|
||||
int main() {
|
||||
std::thread master_thread(master);
|
||||
master_thread.join();
|
||||
std::cout << "OK" << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
|
@ -1,71 +0,0 @@
|
|||
#include <THD/base/data_channels/DataChannelTCP.hpp>
|
||||
#include <THD/test/TestUtils.hpp>
|
||||
|
||||
#include <THPP/tensors/THTensor.hpp>
|
||||
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <thread>
|
||||
|
||||
constexpr int WORKERS_NUM = 2;
|
||||
constexpr int MASTER_PORT = 45679;
|
||||
|
||||
std::vector<std::thread> g_all_workers;
|
||||
std::mutex g_mutex;
|
||||
|
||||
void master() {
|
||||
g_mutex.lock();
|
||||
setenv(WORLD_SIZE_ENV, std::to_string((WORKERS_NUM + 1)).data(), 1);
|
||||
setenv(RANK_ENV, "0", 1);
|
||||
setenv(MASTER_PORT_ENV, std::to_string(MASTER_PORT).data(), 1);
|
||||
auto masterChannel = std::make_shared<thd::DataChannelTCP>(
|
||||
thd::getInitConfig("env://")); // reads all env variable
|
||||
g_mutex.unlock();
|
||||
|
||||
// wait a long time before init
|
||||
std::this_thread::sleep_for(std::chrono::seconds(4));
|
||||
|
||||
assert(masterChannel->init());
|
||||
|
||||
auto float_tensor = buildTensor<float>({1, 2, 3}, 4);
|
||||
masterChannel->broadcast(*float_tensor, 0); // send good tensor
|
||||
|
||||
// wait for all workers to finish
|
||||
for (auto& worker : g_all_workers) {
|
||||
worker.join();
|
||||
}
|
||||
}
|
||||
|
||||
void worker(int id) {
|
||||
g_mutex.lock();
|
||||
setenv(RANK_ENV, std::to_string(id).data(), 1);
|
||||
setenv(
|
||||
MASTER_ADDR_ENV,
|
||||
std::string("127.0.0.1:" + std::to_string(MASTER_PORT)).data(),
|
||||
1);
|
||||
auto workerChannel = std::make_shared<thd::DataChannelTCP>(
|
||||
thd::getInitConfig("env://")); // reads all env variable
|
||||
g_mutex.unlock();
|
||||
|
||||
assert(workerChannel->init());
|
||||
|
||||
auto float_tensor = buildTensor<float>({1, 2, 3}, -1);
|
||||
workerChannel->broadcast(*float_tensor, 0);
|
||||
ASSERT_TENSOR_VALUE(float, *float_tensor, 4)
|
||||
}
|
||||
|
||||
int main() {
|
||||
// start master
|
||||
std::thread master_thread(master);
|
||||
|
||||
// start worker
|
||||
for (int id = 1; id <= WORKERS_NUM; ++id) {
|
||||
g_all_workers.push_back(std::thread(worker, id));
|
||||
}
|
||||
|
||||
master_thread.join();
|
||||
std::cout << "OK" << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
|
@ -1,65 +0,0 @@
|
|||
#include <THD/base/data_channels/DataChannelTCP.hpp>
|
||||
#include <THD/test/TestUtils.hpp>
|
||||
|
||||
#include <THPP/tensors/THTensor.hpp>
|
||||
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <thread>
|
||||
|
||||
constexpr int WORKERS_NUM = 2;
|
||||
constexpr int MASTER_PORT = 45678;
|
||||
|
||||
std::vector<std::thread> g_all_workers;
|
||||
std::mutex g_mutex;
|
||||
|
||||
void master() {
|
||||
g_mutex.lock();
|
||||
setenv(WORLD_SIZE_ENV, std::to_string((WORKERS_NUM + 1)).data(), 1);
|
||||
setenv(RANK_ENV, "0", 1);
|
||||
setenv(MASTER_PORT_ENV, std::to_string(MASTER_PORT).data(), 1);
|
||||
auto masterChannel = std::make_shared<thd::DataChannelTCP>(
|
||||
thd::getInitConfig("env://")); // reads all env variable
|
||||
g_mutex.unlock();
|
||||
|
||||
assert(masterChannel->init());
|
||||
assert(masterChannel->getRank() == 0);
|
||||
assert(masterChannel->getNumProcesses() == WORKERS_NUM + 1);
|
||||
|
||||
// wait for all workers to finish
|
||||
for (auto& worker : g_all_workers) {
|
||||
worker.join();
|
||||
}
|
||||
}
|
||||
|
||||
void worker(int id) {
|
||||
g_mutex.lock();
|
||||
setenv(RANK_ENV, std::to_string(id).data(), 1);
|
||||
setenv(
|
||||
MASTER_ADDR_ENV,
|
||||
std::string("127.0.0.1:" + std::to_string(MASTER_PORT)).data(),
|
||||
1);
|
||||
auto workerChannel = std::make_shared<thd::DataChannelTCP>(
|
||||
thd::getInitConfig("env://")); // reads all env variable
|
||||
g_mutex.unlock();
|
||||
|
||||
assert(workerChannel->init());
|
||||
assert(workerChannel->getRank() == id);
|
||||
assert(workerChannel->getNumProcesses() == WORKERS_NUM + 1);
|
||||
}
|
||||
|
||||
int main() {
|
||||
// start master
|
||||
std::thread master_thread(master);
|
||||
|
||||
// start worker
|
||||
for (int id = 1; id <= WORKERS_NUM; ++id) {
|
||||
g_all_workers.push_back(std::thread(worker, id));
|
||||
}
|
||||
|
||||
master_thread.join();
|
||||
std::cout << "OK" << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
|
@ -1,43 +0,0 @@
|
|||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <typeinfo>
|
||||
|
||||
#include <THPP/tensors/THTensor.hpp>
|
||||
|
||||
using namespace std;
|
||||
|
||||
int main() {
|
||||
thpp::FloatTensor* tensor = new thpp::THTensor<float>();
|
||||
thpp::FloatTensor* tensor2 = new thpp::THTensor<float>();
|
||||
assert(tensor->nDim() == 0);
|
||||
|
||||
tensor->resize({1, 2, 3});
|
||||
assert(tensor->nDim() == 3);
|
||||
int i = 0;
|
||||
for (auto s : tensor->sizes())
|
||||
assert(s == ++i);
|
||||
|
||||
vector<int64_t> sizes = {2, 2};
|
||||
tensor2->resize(sizes);
|
||||
tensor2->fill(4);
|
||||
tensor->add(*tensor2, 1);
|
||||
assert(tensor->nDim() == 2);
|
||||
|
||||
for (auto s : tensor->sizes())
|
||||
assert(s == 2);
|
||||
for (int i = 0; i < 2; i++)
|
||||
assert(reinterpret_cast<float*>(tensor->data())[i] == 5);
|
||||
|
||||
bool thrown = false;
|
||||
try {
|
||||
thpp::IntTensor& a = dynamic_cast<thpp::IntTensor&>(*tensor);
|
||||
} catch (std::bad_cast& e) {
|
||||
thrown = true;
|
||||
}
|
||||
assert(thrown);
|
||||
|
||||
delete tensor;
|
||||
delete tensor2;
|
||||
cout << "OK" << endl;
|
||||
return 0;
|
||||
}
|
||||
|
|
@ -1,19 +0,0 @@
|
|||
# THD refactor
|
||||
|
||||
This is a work in progress. It is separate from the main THD directory
to avoid disrupting THD users or having to deal with backwards
compatibility early on. Once this gets to a usable state, we'll add
Python bindings and a compat layer.
|
||||
|
||||
See https://github.com/pytorch/pytorch/issues/7434 for the main issue.
|
||||
|
||||
This tree is intentionally not part of the main build and will be
|
||||
buildable/testable in isolation, as long as ATen is available in
|
||||
`<repository root>/torch/lib/tmp_install`.
|
||||
|
||||
To build and install ATen here, navigate to the root of this
|
||||
repository and run:
|
||||
|
||||
``` shell
|
||||
tools/build_pytorch_libs.sh --with-cuda ATen
|
||||
```
|
||||
|
|
@ -4,7 +4,6 @@ from .data_parallel import DataParallel, data_parallel
|
|||
from .scatter_gather import scatter, gather
|
||||
from .distributed import DistributedDataParallel
|
||||
from .distributed_cpu import DistributedDataParallelCPU
|
||||
import torch.nn.parallel.deprecated # noqa: F401
|
||||
|
||||
__all__ = ['replicate', 'scatter', 'parallel_apply', 'gather', 'data_parallel',
|
||||
'DataParallel', 'DistributedDataParallel', 'DistributedDataParallelCPU']
|
||||
|
|
|
|||
|
|
@ -1,4 +0,0 @@
|
|||
from .distributed import DistributedDataParallel
|
||||
from .distributed_cpu import DistributedDataParallelCPU
|
||||
|
||||
__all__ = ['DistributedDataParallel', 'DistributedDataParallelCPU']
|
||||
|
|
@ -1,484 +0,0 @@
|
|||
import sys
|
||||
import threading
|
||||
import copy
|
||||
|
||||
import torch
|
||||
from torch.autograd import Variable
|
||||
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors, \
|
||||
_take_tensors
|
||||
import torch.utils.hooks
|
||||
|
||||
from torch.cuda.comm import broadcast_coalesced
|
||||
from torch.cuda import nccl
|
||||
import torch.distributed.deprecated as dist
|
||||
|
||||
from ...modules import Module
|
||||
from ..replicate import replicate
|
||||
from ..scatter_gather import scatter_kwargs, gather
|
||||
from ..parallel_apply import parallel_apply
|
||||
|
||||
if sys.version_info[0] == 3:
|
||||
import queue
|
||||
else:
|
||||
import Queue as queue
|
||||
|
||||
|
||||
class DistributedDataParallel(Module):
|
||||
r"""Implements distributed data parallelism at the module level.
|
||||
|
||||
This container parallelizes the application of the given module by
|
||||
splitting the input across the specified devices by chunking in the batch
|
||||
dimension. The module is replicated on each machine and each device, and
|
||||
each such replica handles a portion of the input. During the backwards
|
||||
pass, gradients from each node are averaged.
|
||||
|
||||
The batch size should be larger than the number of GPUs used locally. It
|
||||
should also be an integer multiple of the number of GPUs so that each chunk
|
||||
is the same size (so that each GPU processes the same number of samples).
|
||||
|
||||
See also: :ref:`distributed-basics` and :ref:`cuda-nn-dataparallel-instead`.
|
||||
The same constraints on input as in :class:`torch.nn.DataParallel` apply.
|
||||
|
||||
Creation of this class requires the distributed package to be already
|
||||
initialized in the process group mode
|
||||
(see :func:`torch.distributed.deprecated.init_process_group`).
|
||||
|
||||
.. warning::
|
||||
This module works only with the ``nccl`` and ``gloo`` backends.
|
||||
|
||||
.. warning::
|
||||
Constructor, forward method, and differentiation of the output (or a
|
||||
function of the output of this module) is a distributed synchronization
|
||||
point. Take that into account in case different processes might be
|
||||
executing different code.
|
||||
|
||||
.. warning::
|
||||
This module assumes all parameters are registered in the model by the
|
||||
time it is created. No parameters should be added nor removed later.
|
||||
Same applies to buffers.
|
||||
|
||||
.. warning::
|
||||
This module assumes all buffers and gradients are dense.
|
||||
|
||||
.. warning::
|
||||
This module doesn't work with :func:`torch.autograd.grad` (i.e. it will
|
||||
only work if gradients are to be accumulated in ``.grad`` attributes of
|
||||
parameters).
|
||||
|
||||
.. warning::
|
||||
If you plan on using this module with a ``nccl`` backend or a ``gloo``
|
||||
backend (that uses Infiniband), together with a DataLoader that uses
|
||||
multiple workers, please change the multiprocessing start method to
|
||||
``forkserver`` (Python 3 only) or ``spawn``. Unfortunately
|
||||
Gloo (that uses Infiniband) and NCCL2 are not fork safe, and you will
|
||||
likely experience deadlocks if you don't change this setting.
|
||||
|
||||
.. note::
|
||||
Parameters are never broadcast between processes. The module performs
|
||||
an all-reduce step on gradients and assumes that they will be modified
|
||||
by the optimizer in all processes in the same way. Buffers
|
||||
(e.g. BatchNorm stats) are broadcast from the module of the rank 0
process to all other replicas in the system in every iteration.
|
||||
|
||||
.. warning::
|
||||
Forward and backward hooks defined on :attr:`module` and its submodules
|
||||
won't be invoked anymore, unless the hooks are initialized in the
|
||||
:meth:`forward` method.
|
||||
|
||||
Args:
|
||||
module: module to be parallelized
|
||||
device_ids: CUDA devices (default: all devices)
|
||||
output_device: device location of output (default: device_ids[0])
|
||||
broadcast_buffers: flag that enables syncing (broadcasting) buffers of
the module at the beginning of the forward function.
|
||||
(default: True)
|
||||
|
||||
Attributes:
|
||||
module (Module): the module to be parallelized
|
||||
|
||||
Example::
|
||||
|
||||
>>> torch.distributed.deprecated.init_process_group(world_size=4, init_method='...')
|
||||
>>> net = torch.nn.DistributedDataParallel(model)
|
||||
"""
|
||||
|
||||
def __init__(self, module, device_ids=None, output_device=None, dim=0,
|
||||
broadcast_buffers=True):
|
||||
super(DistributedDataParallel, self).__init__()
|
||||
if dist._backend not in (dist.dist_backend.NCCL, dist.dist_backend.GLOO):
|
||||
raise ValueError('Invalid backend, only NCCL and GLOO backends are supported by DistributedDataParallel')
|
||||
|
||||
if device_ids is None:
|
||||
device_ids = list(range(torch.cuda.device_count()))
|
||||
if output_device is None:
|
||||
output_device = device_ids[0]
|
||||
self.dim = dim
|
||||
self.module = module
|
||||
self.device_ids = device_ids
|
||||
self.output_device = output_device
|
||||
self.broadcast_buffers = broadcast_buffers
|
||||
|
||||
# Flag used by the NCCL backend to make sure we only reduce gradients
|
||||
# one time in the execution engine
|
||||
self.need_reduction = False
|
||||
|
||||
MB = 1024 * 1024
|
||||
# used for intra-node param sync and inter-node sync as well
|
||||
self.broadcast_bucket_size = 10 * MB
|
||||
self.nccl_reduce_bucket_size = 256 * MB
|
||||
|
||||
# Sync params and buffers
|
||||
module_states = list(self.module.state_dict().values())
|
||||
if len(module_states) > 0:
|
||||
self._dist_broadcast_coalesced(module_states,
|
||||
self.broadcast_bucket_size)
|
||||
|
||||
if len(device_ids) > 1:
|
||||
# TODO: we don't need to replicate params here; they're always going to
# be broadcast using larger blocks in broadcast_coalesced, so it might be
# better not to pollute the caches with these small blocks
|
||||
self._module_copies = replicate(self.module, self.device_ids, detach=True)
|
||||
self._module_copies[0] = self.module
|
||||
|
||||
for module_copy in self._module_copies[1:]:
|
||||
for param, copy_param in zip(self.module.parameters(), module_copy.parameters()):
|
||||
copy_param.requires_grad = param.requires_grad
|
||||
|
||||
else:
|
||||
self._module_copies = [self.module]
|
||||
|
||||
# For the NCCL backend, every NCCL call is asynchronous, so we
# directly enqueue all the NCCL reduction calls onto the
# default CUDA stream without spawning separate reduction threads.
# This achieves the best performance.
|
||||
if dist._backend == dist.dist_backend.NCCL:
|
||||
self._register_nccl_grad_hook()
|
||||
return
|
||||
|
||||
bucket_bytes_cap = 1 * MB
|
||||
|
||||
# This is a triply-nested list where the "dimensions" are: devices, buckets, bucket_elems
|
||||
param_buckets = []
|
||||
# Split the parameters into buckets and by types as well
|
||||
for dev_idx, module in enumerate(self._module_copies):
|
||||
param_buckets.append(list(_take_tensors(module.parameters(), bucket_bytes_cap)))
|
||||
|
||||
self.bucket_sizes = []
|
||||
self.bucket_map = {}
|
||||
|
||||
# We transpose param_buckets, so the loop is over buckets.
|
||||
# param_buckets_tuple is a doubly-nested list with "dims": devices, bucket_elems
|
||||
for bucket_idx, param_buckets_tuple in enumerate(zip(*param_buckets)):
|
||||
self.bucket_sizes.append(0)
|
||||
# Now, we transpose again, so we iterate over bucket_elems, but getting tuples
|
||||
# of params from each device.
|
||||
for idx, param_tuple in enumerate(zip(*param_buckets_tuple)):
|
||||
if idx == 0:
|
||||
# Bucket parameter type tracking
|
||||
bucket_param_type = param_tuple[0].type()
|
||||
# Only gloo and nccl support half-precision
|
||||
if bucket_param_type == torch.cuda.HalfTensor and \
|
||||
dist._backend != dist.dist_backend.GLOO:
|
||||
raise RuntimeError("DistributedDataParallel currently only "
|
||||
"supports half precision parameters "
|
||||
"with Nccl and Gloo backend")
|
||||
if not param_tuple[0].requires_grad:
|
||||
continue
|
||||
for p in param_tuple:
|
||||
self.bucket_map[p] = bucket_idx
|
||||
self.bucket_sizes[bucket_idx] += 1
|
||||
|
||||
self.buckets = [[[] for _ in range(len(self.device_ids))] for _ in range(len(self.bucket_sizes))]
|
||||
self.bucket_events = [[None] * len(self.device_ids) for _ in range(len(self.bucket_sizes))]
|
||||
self.reduced = [False] * len(self.bucket_sizes)
|
||||
|
||||
self._register_grad_hooks()
|
||||
self.dispatch_lock = threading.Lock()
|
||||
self._start_reduction_threads()
|
||||
|
||||
def __getstate__(self):
|
||||
attrs = copy.copy(self.__dict__)
|
||||
if dist._backend != dist.dist_backend.NCCL:
|
||||
del attrs['_grad_accs'], attrs['_reduction_queues'], \
|
||||
attrs['_reduction_streams'], attrs['_reduction_threads'], \
|
||||
attrs['_nccl_streams'], attrs['_default_streams'], \
|
||||
attrs['dispatch_lock']
|
||||
return attrs
|
||||
|
||||
def __setstate__(self, state):
|
||||
super(DistributedDataParallel, self).__setstate__(state)
|
||||
if dist._backend == dist.dist_backend.NCCL:
|
||||
self._register_nccl_grad_hook()
|
||||
else:
|
||||
self._register_grad_hooks()
|
||||
self.dispatch_lock = threading.Lock()
|
||||
self._start_reduction_threads()
|
||||
|
||||
def forward(self, *inputs, **kwargs):
|
||||
self.need_reduction = True
|
||||
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
|
||||
self._sync_params()
|
||||
if len(self.device_ids) == 1:
|
||||
return self.module(*inputs[0], **kwargs[0])
|
||||
outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs)
|
||||
return self.gather(outputs, self.output_device)
|
||||
|
||||
def scatter(self, inputs, kwargs, device_ids):
|
||||
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
|
||||
|
||||
def parallel_apply(self, replicas, inputs, kwargs):
|
||||
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
|
||||
|
||||
def gather(self, outputs, output_device):
|
||||
return gather(outputs, output_device, dim=self.dim)
|
||||
|
||||
def train(self, mode=True):
|
||||
super(DistributedDataParallel, self).train(mode)
|
||||
for module in self._module_copies[1:]:
|
||||
module.train(mode)
|
||||
|
||||
def _dist_broadcast_coalesced(self, tensors, buffer_size):
|
||||
"""
|
||||
Broadcast a sequence of tensors to the default group from rank 0.
|
||||
Small tensors are first coalesced into a buffer to reduce the number of
|
||||
broadcasts.
|
||||
|
||||
tensors (sequence): tensors to broadcast. Each tensor needs to be on the
|
||||
same GPU.
|
||||
buffer_size (int): maximum size of the buffer for coalescing
|
||||
"""
|
||||
for tensors in _take_tensors(tensors, buffer_size):
|
||||
flat_tensors = _flatten_dense_tensors(tensors)
|
||||
dist.broadcast(flat_tensors, 0)
|
||||
for tensor, synced in zip(tensors,
|
||||
_unflatten_dense_tensors(flat_tensors, tensors)):
|
||||
tensor.copy_(synced)
|
||||
|
||||
def _sync_params(self):
|
||||
if len(self.device_ids) > 1:
|
||||
# intra-node parameter sync
|
||||
params = [p.data for p in self.module.parameters()]
|
||||
result = broadcast_coalesced(params, self.device_ids, self.broadcast_bucket_size)
|
||||
for tensors, module in zip(result[1:], self._module_copies[1:]):
|
||||
for tensor, param in zip(tensors, module.parameters()):
|
||||
param.data.set_(tensor)
|
||||
|
||||
# module buffer sync
|
||||
if self.broadcast_buffers:
|
||||
buffers = [b.data for b in self.module.buffers()]
|
||||
if len(buffers) > 0:
|
||||
# cross-node buffer sync
|
||||
self._dist_broadcast_coalesced(buffers, self.broadcast_bucket_size)
|
||||
|
||||
if len(self.device_ids) > 1:
|
||||
# intra-node buffer sync
|
||||
result = broadcast_coalesced(buffers, self.device_ids, self.broadcast_bucket_size)
|
||||
for tensors, module in zip(result[1:], self._module_copies[1:]):
|
||||
for tensor, buf in zip(tensors, module.buffers()):
|
||||
buf.data.set_(tensor)
|
||||
|
||||
def _register_grad_hooks(self):
|
||||
self._grad_accs = [] # need to keep them in scope
|
||||
for device_idx, module in enumerate(self._module_copies):
|
||||
for p in module.parameters():
|
||||
if p.requires_grad:
|
||||
p_tmp = p.expand_as(p)
|
||||
grad_acc = p_tmp.grad_fn.next_functions[0][0]
|
||||
grad_acc.register_hook(self._make_param_hook(p, device_idx))
|
||||
self._grad_accs.append(grad_acc)
|
||||
|
||||
def _register_nccl_grad_hook(self):
|
||||
"""
|
||||
This function registers the all-reduce callback for the
NCCL backend. All gradients will be all-reduced in a single step.
|
||||
The NCCL reduction will directly be enqueued into the
|
||||
default CUDA stream. Therefore, no synchronization is needed.
|
||||
"""
|
||||
# Creating a new group
|
||||
self.nccl_reduction_group_id = dist.new_group()
|
||||
|
||||
def reduction_fn_nccl():
|
||||
# This function only needs to be called once
|
||||
if not self.need_reduction:
|
||||
return
|
||||
|
||||
self.need_reduction = False
|
||||
all_grads = [[] for _ in range(len(self._module_copies))]
|
||||
all_grads_buckets_iters = []
|
||||
|
||||
# Bucketing all the gradients
|
||||
for dev_idx, module in enumerate(self._module_copies):
|
||||
for param in module.parameters():
|
||||
if not param.requires_grad or param.grad is None:
|
||||
continue
|
||||
if param.grad.requires_grad:
|
||||
raise RuntimeError("DistributedDataParallel only works "
|
||||
"with gradients that don't require "
|
||||
"grad")
|
||||
# Adding the gradients for reduction
|
||||
all_grads[dev_idx].append(param.grad.data)
|
||||
|
||||
# Now bucketing the gradients
|
||||
dev_grads_buckets = _take_tensors(all_grads[dev_idx],
|
||||
self.nccl_reduce_bucket_size)
|
||||
|
||||
all_grads_buckets_iters.append(dev_grads_buckets)
|
||||
|
||||
# Now reduce each bucket one after another
|
||||
for grads_batch in zip(*all_grads_buckets_iters):
|
||||
grads_batch_coalesced = []
|
||||
# Coalesce each bucket
|
||||
for dev_idx, dev_grads_batch in enumerate(grads_batch):
|
||||
dev_id = self.device_ids[dev_idx]
|
||||
with torch.cuda.device(dev_id):
|
||||
dev_grads_batch_coalesced = _flatten_dense_tensors(dev_grads_batch)
|
||||
grads_batch_coalesced.append(dev_grads_batch_coalesced)
|
||||
|
||||
# We will only use device 0's results, but this single op should be
# faster than doing the following two operations sequentially:
# (1) intra-node reduce to the lead GPU, followed by
# (2) inter-node allreduce across the lead GPUs of all nodes
|
||||
dist.all_reduce_multigpu(grads_batch_coalesced,
|
||||
group=self.nccl_reduction_group_id)
|
||||
|
||||
# Now only work on the first device of self.device_ids, uncoalesce
|
||||
# the gradients for each bucket
|
||||
grads_batch_coalesced[0] /= dist.get_world_size()
|
||||
grads_batch_reduced = _unflatten_dense_tensors(grads_batch_coalesced[0], grads_batch[0])
|
||||
for grad, reduced in zip(grads_batch[0], grads_batch_reduced):
|
||||
grad.copy_(reduced)
|
||||
|
||||
# clear the gradients and save memory for replicas
|
||||
for module in self._module_copies[1:]:
|
||||
for param in module.parameters():
|
||||
if param.requires_grad:
|
||||
param.grad = None
|
||||
param.data.set_()
|
||||
|
||||
# Now register the reduction hook on the parameters
|
||||
for p in self.module.parameters():
|
||||
if not p.requires_grad:
|
||||
continue
|
||||
|
||||
@torch.utils.hooks.unserializable_hook
|
||||
def allreduce_hook(*unused):
|
||||
Variable._execution_engine.queue_callback(reduction_fn_nccl)
|
||||
|
||||
p.register_hook(allreduce_hook)
|
||||
|
||||
def _make_param_hook(self, param, device_idx):
|
||||
|
||||
bucket_idx = self.bucket_map[param]
|
||||
|
||||
def distributed_data_parallel_hook(*unused):
|
||||
if param.grad.requires_grad:
|
||||
raise RuntimeError("DistributedDataParallel only works with "
|
||||
"gradients that don't require grad")
|
||||
bucket = self.buckets[bucket_idx][device_idx]
|
||||
bucket.append(param.grad.data)
|
||||
|
||||
# We can flush these and save memory for replicas
|
||||
if device_idx > 0:
|
||||
param.grad = None
|
||||
param.data.set_()
|
||||
|
||||
# Current device's bucket is full
|
||||
if len(bucket) == self.bucket_sizes[bucket_idx]:
|
||||
with torch.cuda.device(self.device_ids[device_idx]):
|
||||
event = torch.cuda.Event()
|
||||
event.record()
|
||||
with self.dispatch_lock:
|
||||
self.bucket_events[bucket_idx][device_idx] = event
|
||||
self._queue_reduction(bucket_idx)
|
||||
|
||||
return distributed_data_parallel_hook
|
||||
|
||||
def _queue_reduction(self, bucket_idx):
|
||||
dev_buckets = self.buckets[bucket_idx]
|
||||
dev_events = self.bucket_events[bucket_idx]
|
||||
|
||||
# Check if it's ready
|
||||
if any(evt is None for evt in dev_events):
|
||||
return
|
||||
|
||||
# Queue the reduction and make sure backward waits for it
|
||||
event = threading.Event()
|
||||
self._reduction_queues[bucket_idx].put((dev_buckets, dev_events, event))
|
||||
Variable._execution_engine.queue_callback(lambda: event.wait())
|
||||
|
||||
# Reset bucket state
|
||||
self.buckets[bucket_idx] = [[] for _ in range(len(self.device_ids))]
|
||||
self.bucket_events[bucket_idx] = [None] * len(self.device_ids)
|
||||
self.reduced[bucket_idx] = True
|
||||
if all(self.reduced):
|
||||
self.reduced = [False] * len(self.bucket_sizes)
|
||||
|
||||
def sync_reduction_streams():
|
||||
# We only have to sync with the first one, but it's safer to do it this way
|
||||
# in case we change the way in which we parallelize work
|
||||
r_streams = zip(*self._reduction_streams)
|
||||
for dev_id, default_stream, dev_r_streams in zip(self.device_ids, self._default_streams, r_streams):
|
||||
with torch.cuda.device(dev_id):
|
||||
for reduction_stream in dev_r_streams:
|
||||
default_stream.wait_stream(reduction_stream)
|
||||
Variable._execution_engine.queue_callback(sync_reduction_streams)
|
||||
|
||||
def _start_reduction_threads(self):
|
||||
num_buckets = len(self.bucket_sizes)
|
||||
self._reduction_queues = [queue.Queue() for _ in range(num_buckets)]
|
||||
self._reduction_threads = []
|
||||
self._reduction_streams = [[] for _ in range(num_buckets)]
|
||||
self._nccl_streams = []
|
||||
self._default_streams = []
|
||||
for dev_id in self.device_ids:
|
||||
with torch.cuda.device(dev_id):
|
||||
# TODO: don't assume we're on a default stream
|
||||
self._default_streams.append(torch.cuda.current_stream())
|
||||
self._nccl_streams.append(torch.cuda.Stream())
|
||||
for reduction_queue, reduction_streams in zip(self._reduction_queues, self._reduction_streams):
|
||||
for dev_id in self.device_ids:
|
||||
with torch.cuda.device(dev_id):
|
||||
reduction_streams.append(torch.cuda.Stream())
|
||||
# We only use the first device for distributed reductions
|
||||
dist._register_stream(reduction_streams[0])
|
||||
|
||||
group_id = dist.new_group()
|
||||
|
||||
self._reduction_threads.append(threading.Thread(
|
||||
target=self._reduction_thread_fn,
|
||||
args=(reduction_queue, group_id, self.device_ids, reduction_streams, self._nccl_streams)))
|
||||
self._reduction_threads[-1].daemon = True
|
||||
self._reduction_threads[-1].start()
|
||||
|
||||
@staticmethod
|
||||
def _reduction_thread_fn(queue, group_id, device_ids, reduction_streams, nccl_streams):
|
||||
|
||||
def _process_batch():
|
||||
dev_grad_batch, dev_events, job_event = queue.get()
|
||||
dev_coalesced = []
|
||||
# Coalesce the tensors on all devices and start a local reduction
|
||||
for dev_id, grad_batch, event, stream in zip(device_ids, dev_grad_batch, dev_events, reduction_streams):
|
||||
with torch.cuda.device(dev_id), torch.cuda.stream(stream):
|
||||
stream.wait_event(event)
|
||||
coalesced = _flatten_dense_tensors(grad_batch)
|
||||
dev_coalesced.append(coalesced)
|
||||
# Wait for all copies to complete before starting the NCCL kernel
|
||||
for stream in reduction_streams:
|
||||
stream.synchronize()
|
||||
nccl.reduce(dev_coalesced, root=0, streams=nccl_streams)
|
||||
|
||||
# From now on we're only going to work on the first device (from device_ids)
|
||||
grad_batch = dev_grad_batch[0]
|
||||
coalesced = dev_coalesced[0]
|
||||
reduce_stream = reduction_streams[0]
|
||||
with torch.cuda.stream(reduce_stream):
|
||||
reduce_stream.wait_stream(nccl_streams[0])
|
||||
coalesced /= dist.get_world_size()
|
||||
dist.all_reduce(coalesced, group=group_id)
|
||||
for grad, reduced in zip(grad_batch, _unflatten_dense_tensors(coalesced, grad_batch)):
|
||||
grad.copy_(reduced)
|
||||
job_event.set()
|
||||
|
||||
with torch.cuda.device(device_ids[0]):
|
||||
while True:
|
||||
_process_batch() # just to have a clear scope
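# --- Illustrative usage (editor's sketch, not part of the original file) ---
# A minimal sketch of how the deprecated DistributedDataParallel defined
# above was typically driven, assuming a 4-process job; the model, loader,
# rendezvous file, and this helper's name are hypothetical placeholders.
def _example_usage(rank, model, loader):
    import torch
    import torch.nn.functional as F
    import torch.distributed.deprecated as dist_deprecated
    from torch.nn.parallel.deprecated import DistributedDataParallel

    # One process per rank joins the (deprecated) process group.
    dist_deprecated.init_process_group(
        backend='gloo',
        init_method='file:///tmp/ddp_init',  # hypothetical rendezvous file
        world_size=4,
        rank=rank)

    net = DistributedDataParallel(model.cuda())
    optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
    for inputs, targets in loader:
        optimizer.zero_grad()
        loss = F.cross_entropy(net(inputs.cuda()), targets.cuda())
        loss.backward()   # gradients are all-reduced across processes here
        optimizer.step()  # every process applies the same update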
|
||||
|
|
@ -1,107 +0,0 @@
|
|||
import torch
|
||||
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
||||
import torch.distributed.deprecated as dist
|
||||
from torch.nn.modules import Module
|
||||
from collections import defaultdict
|
||||
from torch.autograd import Variable
|
||||
import torch.utils.hooks
|
||||
|
||||
|
||||
class DistributedDataParallelCPU(Module):
|
||||
r"""Implements distributed data parallelism for CPU at the module level.
|
||||
|
||||
This module supports the ``mpi``, ``gloo``, and ``tcp`` backends.
|
||||
|
||||
This container parallelizes the application of the given module by
|
||||
splitting the input across the specified devices by chunking in the batch
|
||||
dimension. The module is replicated on each machine, and each such replica
|
||||
handles a portion of the input. During the backwards pass, gradients from
|
||||
each node are averaged.
|
||||
|
||||
This module can be used in conjunction with
:class:`torch.utils.data.distributed.DistributedSampler`, which loads a
subset of the original dataset for each node with the same batch size.
Strong scaling should therefore be configured like this:
|
||||
n = 1, batch size = 128
|
||||
n = 2, batch size = 64
|
||||
n = 4, batch size = 32
|
||||
n = 8, batch size = 16
|
||||
|
||||
Creation of this class requires the distributed package to be already
|
||||
initialized in the process group mode
|
||||
(see :func:`torch.distributed.deprecated.init_process_group`).
|
||||
|
||||
.. warning::
|
||||
Constructor, forward method, and differentiation of the output (or a
|
||||
function of the output of this module) is a distributed synchronization
|
||||
point. Take that into account in case different nodes might be
executing different code.
|
||||
|
||||
.. warning::
|
||||
This module assumes all parameters are registered in the model by the
|
||||
time it is created. No parameters should be added nor removed later.
|
||||
|
||||
.. warning::
|
||||
This module assumes all gradients are dense.
|
||||
|
||||
.. warning::
|
||||
This module doesn't work with :func:`torch.autograd.grad` (i.e. it will
|
||||
only work if gradients are to be accumulated in ``.grad`` attributes of
|
||||
parameters).
|
||||
|
||||
.. note::
|
||||
Parameters are broadcast between nodes in the __init__() function. The
|
||||
module performs an all-reduce step on gradients and assumes that they
|
||||
will be modified by the optimizer in all nodes in the same way.
|
||||
|
||||
.. warning::
|
||||
Forward and backward hooks defined on :attr:`module` and its submodules
|
||||
won't be invoked anymore, unless the hooks are initialized in the
|
||||
:meth:`forward` method.
|
||||
|
||||
Args:
|
||||
module: module to be parallelized
|
||||
|
||||
Example::
|
||||
|
||||
>>> torch.distributed.deprecated.init_process_group(world_size=4, init_method='...')
|
||||
>>> net = torch.nn.DistributedDataParallelCPU(model)
|
||||
"""
|
||||
|
||||
def __init__(self, module):
|
||||
super(DistributedDataParallelCPU, self).__init__()
|
||||
self.module = module
|
||||
self.sync_parameters()
|
||||
|
||||
def allreduce_params():
|
||||
if self.needs_reduction:
|
||||
self.needs_reduction = False
|
||||
buckets = defaultdict(list)
|
||||
for param in self.module.parameters():
|
||||
if param.requires_grad and param.grad is not None:
|
||||
tp = type(param.data)
|
||||
buckets[tp].append(param)
|
||||
|
||||
for bucket in buckets.values():
|
||||
grads = [param.grad.data for param in bucket]
|
||||
coalesced = _flatten_dense_tensors(grads)
|
||||
dist.all_reduce(coalesced)
|
||||
coalesced /= dist.get_world_size()
|
||||
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
|
||||
buf.copy_(synced)
|
||||
|
||||
for param in list(self.module.parameters()):
|
||||
@torch.utils.hooks.unserializable_hook
|
||||
def allreduce_hook(*unused):
|
||||
Variable._execution_engine.queue_callback(allreduce_params)
|
||||
|
||||
if param.requires_grad:
|
||||
param.register_hook(allreduce_hook)
|
||||
|
||||
def sync_parameters(self):
|
||||
for param in self.module.parameters():
|
||||
dist.broadcast(param.data, 0)
|
||||
|
||||
def forward(self, *inputs, **kwargs):
|
||||
self.needs_reduction = True
|
||||
return self.module(*inputs, **kwargs)
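# --- Illustrative usage (editor's sketch, not part of the original file) ---
# A minimal sketch of the strong-scaling setup described in the docstring
# above: each node loads a disjoint shard via DistributedSampler and shrinks
# its local batch so the global batch stays at 128. The dataset, model,
# rendezvous file, and this helper's name are hypothetical placeholders.
def _example_usage(rank, world_size, dataset, model):
    import torch
    import torch.nn.functional as F
    import torch.distributed.deprecated as dist_deprecated
    from torch.nn.parallel.deprecated import DistributedDataParallelCPU
    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler

    dist_deprecated.init_process_group(
        backend='tcp',
        init_method='file:///tmp/ddp_cpu_init',  # hypothetical rendezvous file
        world_size=world_size,
        rank=rank)

    # n = 1 -> 128, n = 2 -> 64, n = 4 -> 32, n = 8 -> 16 samples per node.
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(dataset, batch_size=128 // world_size, sampler=sampler)

    net = DistributedDataParallelCPU(model)
    optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
    for inputs, targets in loader:
        optimizer.zero_grad()
        loss = F.cross_entropy(net(inputs), targets)
        loss.backward()   # gradients are all-reduced across nodes here
        optimizer.step()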