Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 00:21:07 +01:00)
Summary:
Anywhere we used `#include "foo.h"`, we now say `#include <foo.h>`.
Paths are adjusted to be rooted at aten/src, torch/lib, or
the repository root directory.
I modified CMakeLists.txt by hand to remove TH and THC from
the include paths.
I used the following script to do the canonicalization:
```
import subprocess
import re
import os.path

# Canonicalize quoted includes (#include "foo.h") in tracked aten/ and torch/
# sources into angle-bracket includes rooted at aten/src, torch/lib, or the
# repository root. Intended to be run from the repository root (all paths are
# resolved relative to the current directory). Files are rewritten in place;
# each changed file is printed, and an ERROR line is printed for every include
# that could not be resolved (that include is left untouched).
files = subprocess.check_output(['git', 'ls-files']).decode('utf-8').rstrip().split('\n')
for fn in files:
    # Only C/C++/CUDA sources and templated (.in) files are considered.
    # ('.cu' was listed twice in the original list; once is enough.)
    if not any(fn.endswith(suff) for suff in ['.cu', '.cpp', '.in', '.h', '.hpp', '.cuh', '.cc']):
        continue
    if not any(fn.startswith(pref) for pref in ["aten/", "torch/"]):
        continue
    with open(fn, 'r') as f:
        c = f.read()

    def fmt(p):
        return "#include <{}>".format(p)

    def repl(m):
        p = m.group(1)
        # System / third-party headers: only switch the quoting style.
        if p in ["dlfcn.h", "unistd.h", "nvrtc.h", "cuda.h", "cuda_runtime.h", "cstdint", "cudnn.h", "Python.h", "cusparse.h", "cuda_runtime_api.h", "cuda_fp16.h", "cublas_v2.h", "stdint.h", "curand_kernel.h"]:
            return fmt(p)
        # Paths that already start at a known include root.
        if any(p.startswith(pref) for pref in ["torch/csrc", "c10/", "ATen/", "caffe2/", "TH/", "THC/", "Eigen/", "gtest/", "zdl/", "gloo/", "onnx/", "miopen/"]):
            return fmt(p)
        # Otherwise resolve p against the including file's directory or one of
        # the legacy TH/THC/torch/csrc include dirs, and re-root it at the first
        # include root under which the file (or its .in template) exists.
        for root in ["aten/src", "torch/lib", ""]:
            for bad_root in [os.path.dirname(fn), "aten/src/TH", "aten/src/THC", "torch/csrc"]:
                new_p = os.path.relpath(os.path.join(bad_root, p), root)
                if not new_p.startswith("../") and (os.path.exists(os.path.join(root, new_p)) or os.path.exists(os.path.join(root, new_p + ".in"))):
                    return fmt(new_p)
        print("ERROR: ", fn, p)
        return m.group(0)

    new_c = re.sub(r'#include "([^"]+)"', repl, c)
    if new_c != c:
        print(fn)
        with open(fn, 'w') as f:
            f.write(new_c)
```
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14849
Reviewed By: dzhulgakov
Differential Revision: D13363445
Pulled By: ezyang
fbshipit-source-id: 52361f878a672785f9306c9e9ab2513128092b68
381 lines · 11 KiB · C++
#include <torch/csrc/cuda/python_nccl.h>
|
|
|
|
#include <torch/csrc/cuda/nccl.h>
|
|
#include <torch/csrc/DynamicTypes.h>
|
|
#include <torch/csrc/Exceptions.h>
|
|
#include <torch/csrc/THP.h>
|
|
#include <torch/csrc/Types.h>
|
|
#include <torch/csrc/cuda/THCP.h>
|
|
#include <torch/csrc/cuda/nccl.h>
|
|
#include <torch/csrc/utils/functional.h>
|
|
|
|
#include <ATen/cuda/CUDAGuard.h>
|
|
|
|
#include <nccl.h>
|
|
|
|
#include <sstream>
|
|
#include <unordered_map>
|
|
|
|
using namespace at;
|
|
using namespace torch;
|
|
using namespace torch::cuda::nccl;
|
|
using namespace torch::cuda::nccl::detail;
|
|
|
|
// Name tag given to every communicator PyCapsule created in this file and
// checked by PyCapsule_GetPointer when a capsule is unpacked. Both the
// characters and the pointer are const so the tag cannot be re-pointed at
// runtime (a mismatch would make every existing capsule fail to unpack).
static const char* const COMM_CAPSULE_NAME = "torch.cuda.nccl.Communicator";
|
|
|
|
// Python binding: returns the value of version() as a Python int.
PyObject* THCPModule_nccl_version(PyObject* self, PyObject* args) {
  auto nccl_version = version();
  return PyInt_FromLong(nccl_version);
}
|
|
|
|
// Python binding: asks NCCL for a fresh ncclUniqueId and returns its first
// NCCL_UNIQUE_ID_BYTES bytes as a Python bytes object. nccl_init_rank expects
// exactly this many id bytes. NCCL failures are raised through NCCL_CHECK and
// translated to Python exceptions by HANDLE_TH_ERRORS.
PyObject* THCPModule_nccl_unique_id(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  ncclUniqueId unique_id;
  NCCL_CHECK(ncclGetUniqueId(&unique_id));
  const char* raw_bytes = reinterpret_cast<const char*>(&unique_id);
  return PyBytes_FromStringAndSize(raw_bytes, NCCL_UNIQUE_ID_BYTES);
  END_HANDLE_TH_ERRORS
}
|
|
|
|
// Returns the ncclComm_t stored in a communicator capsule. If `capsule` is not
// a capsule named COMM_CAPSULE_NAME, PyCapsule_GetPointer returns null with a
// Python exception set, which is propagated by throwing python_error.
static ncclComm_t unpack_nccl_comm(PyObject* capsule) {
  void* raw_pointer = PyCapsule_GetPointer(capsule, COMM_CAPSULE_NAME);
  if (raw_pointer == nullptr) {
    throw python_error();
  }
  return static_cast<ncclComm_t>(raw_pointer);
}
|
|
|
|
// PyCapsule destructor registered for communicator capsules.
//
// TODO(T30279827): ncclCommDestroy is deliberately NOT called for now.
// According to Nvidia, calling ncclCommDestroy while the program is exiting
// is undefined and leads to a segfault in NCCL 2, whether it runs before or
// after the CUDA runtime destructor. Until there is a long-term fix from
// Nvidia, this destructor returns immediately. The statements after the early
// return are unreachable but still compiled, so the real implementation can be
// restored by deleting the `return`.
static void destroy_nccl_comm(PyObject* capsule) {
  return;

  HANDLE_TH_ERRORS
  ncclComm_t comm_to_destroy = unpack_nccl_comm(capsule);
  with_no_gil([comm_to_destroy] { ncclCommDestroy(comm_to_destroy); });
  END_HANDLE_TH_ERRORS_RET()
}
|
|
|
|
// Converts the Python `streams` argument into one optional CUDA stream per
// input.
// - `None`: returns `size` empty optionals; the collectives in this file then
//   fall back to each device's current CUDA stream.
// - Otherwise: converts `obj` with THPUtils_PySequence_to_CUDAStreamList and
//   throws std::runtime_error unless the result has exactly `size` entries.
static std::vector<c10::optional<at::cuda::CUDAStream>> unpack_streams(PyObject* obj, size_t size) {
  using StreamList = std::vector<c10::optional<at::cuda::CUDAStream>>;
  if (obj == Py_None) {
    return StreamList(size, c10::nullopt);
  }

  auto streams = THPUtils_PySequence_to_CUDAStreamList(obj);
  if (streams.size() == size) {
    return streams;
  }
  throw std::runtime_error(
      "number of streams is not equal to number of inputs");
}
|
|
|
|
// Forward declaration; the definition appears later in this file. Converts a
// Python sequence of Tensors into a std::vector<at::Tensor>.
static std::vector<at::Tensor> extract_tensors(PyObject* obj);
|
|
|
|
// Converts the Python `comms` argument into a list of NCCL communicators.
// - `None`: returns an empty vector. The collectives in this file treat an
//   empty list as "use _get_communicators(inputs)".
// - A single communicator capsule: a one-element list.
// - Any other sequence: one communicator per element.
// In the non-None cases, the result must contain exactly `size` communicators,
// otherwise std::runtime_error is thrown. python_error is thrown if `obj` is
// not a sequence or an element is not a communicator capsule.
static std::vector<ncclComm_t> unpack_comms(PyObject* obj, size_t size) {
  if (obj == Py_None) {
    return std::vector<ncclComm_t>();
  }

  std::vector<ncclComm_t> comms;
  if (PyCapsule_CheckExact(obj)) {
    comms = {unpack_nccl_comm(obj)};
  } else {
    auto seq = THPObjectPtr(PySequence_Fast(obj, "comm is not a sequence"));
    if (!seq)
      throw python_error();
    // Named num_items rather than `size` so it no longer shadows the expected
    // count parameter that is checked after this branch.
    const Py_ssize_t num_items = PySequence_Fast_GET_SIZE(seq.get());
    comms.reserve(static_cast<size_t>(num_items));
    for (Py_ssize_t i = 0; i < num_items; i++) {
      comms.push_back(unpack_nccl_comm(PySequence_Fast_GET_ITEM(seq.get(), i)));
    }
  }

  if (comms.size() != size) {
    throw std::runtime_error(
        "number of communicators is not equal to number of inputs");
  }
  return comms;
}
|
|
|
|
// Python binding: nccl_init_rank(nranks, id, rank), parsed with format "is#i".
// `id` must be exactly NCCL_UNIQUE_ID_BYTES long (the bytes produced by
// nccl_unique_id). Creates the communicator for `rank` out of `nranks` with
// ncclCommInitRank and returns it wrapped in a PyCapsule named
// COMM_CAPSULE_NAME, with destroy_nccl_comm as the capsule destructor.
PyObject* THCPModule_nccl_init_rank(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  int nranks;
  const char* id;
  Py_ssize_t id_len;
  int rank;

  if (!PyArg_ParseTuple(
          args, "is#i:nccl_init_rank", &nranks, &id, &id_len, &rank)) {
    return nullptr;
  }
  // Fixed the "unqiue_id" typo in this user-facing message.
  THPUtils_assert(
      id_len == NCCL_UNIQUE_ID_BYTES,
      "invalid unique_id (expected %d bytes, got %zd)",
      NCCL_UNIQUE_ID_BYTES,
      id_len);

  ncclUniqueId commId;
  memcpy(&commId, id, NCCL_UNIQUE_ID_BYTES);
  ncclComm_t comm;
  // The GIL is released while NCCL initializes the communicator.
  with_no_gil(
      [&] { NCCL_CHECK(ncclCommInitRank(&comm, nranks, commId, rank)); });
  return PyCapsule_New(comm, COMM_CAPSULE_NAME, &destroy_nccl_comm);
  END_HANDLE_TH_ERRORS
}
|
|
|
|
// Python binding: nccl_reduce(inputs, outputs, root, op, streams, comms).
// Converts the arguments and forwards them to torch::cuda::nccl::reduce with
// the GIL released. `streams` and `comms` may be None (see unpack_streams /
// unpack_comms).
//
// The usage message now closes its parenthesis and lists every parsed
// argument, matching nccl_all_reduce. Previously it was unterminated and
// omitted `comms`.
PyObject* THCPModule_nccl_reduce(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  PyObject *_inputs, *_outputs, *_streams, *_comms;
  int root, op;

  if (!PyArg_ParseTuple(
          args,
          "OOiiOO",
          &_inputs,
          &_outputs,
          &root,
          &op,
          &_streams,
          &_comms)) {
    THPUtils_invalidArguments(
        args,
        nullptr,
        "nccl_reduce",
        1,
        "(sequence[Tensor] inputs, sequence[Tensor] outputs, int root,"
        " int op, sequence[torch.cuda.Stream] streams,"
        " sequence[torch.cuda.nccl.Communicator] comms)");
    return nullptr;
  }

  std::vector<at::Tensor> inputs = extract_tensors(_inputs);
  std::vector<at::Tensor> outputs = extract_tensors(_outputs);
  std::vector<c10::optional<at::cuda::CUDAStream>> streams =
      unpack_streams(_streams, inputs.size());
  auto user_comms = unpack_comms(_comms, inputs.size());

  with_no_gil([&] {
    torch::cuda::nccl::reduce(inputs, outputs, root, op, streams, user_comms);
  });

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
|
|
|
|
// Python binding: nccl_all_reduce(inputs, outputs, op, streams, comms).
// For each i, enqueues ncclAllReduce(inputs[i] -> outputs[i]) with reduction
// `op` on the device of inputs[i], all inside one NCCL group. A None entry in
// `streams` selects that device's current CUDA stream. An empty (None) `comms`
// selects _get_communicators(inputs). The GIL is released for the whole
// enqueue phase.
PyObject* THCPModule_nccl_all_reduce(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  PyObject *_inputs, *_outputs, *_streams, *_comms;
  int op;

  const bool parsed = PyArg_ParseTuple(
      args, "OOiOO", &_inputs, &_outputs, &op, &_streams, &_comms);
  if (!parsed) {
    THPUtils_invalidArguments(
        args,
        nullptr,
        "nccl_all_reduce",
        1,
        "(sequence[Tensor] inputs, sequence[Tensor] outputs, int op,"
        " sequence[torch.cuda.Stream] streams,"
        " sequence[torch.cuda.nccl.Communicator] comms)");
    return nullptr;
  }

  auto inputs = extract_tensors(_inputs);
  auto outputs = extract_tensors(_outputs);
  auto streams = unpack_streams(_streams, inputs.size());
  auto user_comms = unpack_comms(_comms, inputs.size());

  with_no_gil([&] {
    _check_inputs(inputs, outputs, 1, 1);
    const size_t num_inputs = inputs.size();

    ncclDataType_t data_type = _get_data_type(inputs[0].type());
    int64_t count = inputs[0].numel();

    // Held for the whole enqueue loop. NOTE(review): presumably this keeps the
    // caching allocator from freeing memory while NCCL calls are being issued.
    std::lock_guard<std::mutex> free_lock(*(THCCachingAllocator_getCudaFreeMutex()));
    auto comms = user_comms.empty() ? _get_communicators(inputs)
                                    : ArrayRef<ncclComm_t>(user_comms);
    at::cuda::OptionalCUDAGuard device_guard;
    AutoNcclGroup nccl_group_guard;
    for (size_t idx = 0; idx < num_inputs; ++idx) {
      const int device = inputs[idx].get_device();
      device_guard.set_index(device);
      auto stream = streams[idx]
          ? streams[idx]->stream()
          : at::cuda::getCurrentCUDAStream(device).stream();
      NCCL_CHECK(ncclAllReduce(
          inputs[idx].data_ptr(),
          outputs[idx].data_ptr(),
          count,
          data_type,
          static_cast<ncclRedOp_t>(op),
          comms[idx],
          stream));
    }
  });

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
|
|
|
|
// Python binding: nccl_broadcast(inputs, root, streams, comms).
// Validates `root` as an index into `inputs`, then calls
// torch::cuda::nccl::broadcast with the GIL released. `streams` and `comms`
// may be None (see unpack_streams / unpack_comms).
//
// The usage message now lists all four parsed arguments. Previously it
// mentioned only inputs and root.
PyObject* THCPModule_nccl_broadcast(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  PyObject *_inputs, *_streams, *_comms;
  int root;

  if (!PyArg_ParseTuple(args, "OiOO", &_inputs, &root, &_streams, &_comms)) {
    THPUtils_invalidArguments(
        args,
        nullptr,
        "nccl_broadcast",
        1,
        "(sequence[Tensor] inputs, int root,"
        " sequence[torch.cuda.Stream] streams,"
        " sequence[torch.cuda.nccl.Communicator] comms)");
    return nullptr;
  }

  std::vector<at::Tensor> inputs = extract_tensors(_inputs);
  THPUtils_assert(root >= 0 && (size_t)root < inputs.size(), "invalid root");
  auto streams = unpack_streams(_streams, inputs.size());
  auto user_comms = unpack_comms(_comms, inputs.size());

  // NOTE(review): `root` is range-checked above but is NOT passed to
  // torch::cuda::nccl::broadcast, so the broadcast source is chosen by that
  // function alone. Confirm whether a non-zero root is actually honored.
  with_no_gil(
      [&] { torch::cuda::nccl::broadcast(inputs, streams, user_comms); });

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
|
|
|
|
// Python binding: nccl_all_gather(inputs, outputs, streams, comms).
// For each i, enqueues ncclAllGather(inputs[i] -> outputs[i]) on the device of
// inputs[i], all inside one NCCL group, with the GIL released. `count` is the
// element count of inputs[0]. _check_inputs is called with output multiplier
// len(inputs), presumably requiring each output to hold len(inputs) times the
// input size; confirm in nccl.cpp. A None stream entry selects the device's
// current CUDA stream. An empty (None) `comms` selects
// _get_communicators(inputs).
//
// The usage message now closes its parenthesis and lists streams/comms.
// Previously it was unterminated.
PyObject* THCPModule_nccl_all_gather(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  PyObject *_inputs, *_outputs, *_streams, *_comms;

  if (!PyArg_ParseTuple(
          args, "OOOO", &_inputs, &_outputs, &_streams, &_comms)) {
    THPUtils_invalidArguments(
        args,
        nullptr,
        "nccl_all_gather",
        1,
        "(sequence[Tensor] inputs, sequence[Tensor] outputs,"
        " sequence[torch.cuda.Stream] streams,"
        " sequence[torch.cuda.nccl.Communicator] comms)");
    return nullptr;
  }

  std::vector<at::Tensor> inputs = extract_tensors(_inputs);
  std::vector<at::Tensor> outputs = extract_tensors(_outputs);
  auto streams = unpack_streams(_streams, inputs.size());
  auto user_comms = unpack_comms(_comms, inputs.size());

  with_no_gil([&] {
    size_t len = inputs.size();
    _check_inputs(inputs, outputs, len, 1);

    ncclDataType_t data_type = _get_data_type(inputs[0].type());

    int64_t count = inputs[0].numel();
    // Held for the whole enqueue loop. NOTE(review): presumably this keeps the
    // caching allocator from freeing memory while NCCL calls are being issued.
    std::lock_guard<std::mutex> lock(*(THCCachingAllocator_getCudaFreeMutex()));
    auto comms = user_comms.empty() ? _get_communicators(inputs)
                                    : ArrayRef<ncclComm_t>(user_comms);
    at::cuda::OptionalCUDAGuard device_guard;
    AutoNcclGroup nccl_group_guard;
    for (size_t i = 0; i < len; i++) {
      int device = inputs[i].get_device();
      device_guard.set_index(device);
      auto stream = !streams[i]
          ? at::cuda::getCurrentCUDAStream(device).stream()
          : streams[i]->stream();
      // NCCL 2 changed ncclAllGather's argument order (recvbuff now comes
      // before count); the #else branch keeps the NCCL 1 order.
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
      NCCL_CHECK(ncclAllGather(
          inputs[i].data_ptr(),
          outputs[i].data_ptr(),
          count,
          data_type,
          comms[i],
          stream));
#else
      NCCL_CHECK(ncclAllGather(
          inputs[i].data_ptr(),
          count,
          data_type,
          outputs[i].data_ptr(),
          comms[i],
          stream));
#endif
    }
  });

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
|
|
|
|
// Python binding: nccl_reduce_scatter(inputs, outputs, op, streams, comms).
// For each i, enqueues ncclReduceScatter(inputs[i] -> outputs[i]) with
// reduction `op` on the device of inputs[i], all inside one NCCL group, with
// the GIL released. The per-rank `count` is inputs[0].numel() / len(inputs).
// A None stream entry selects the device's current CUDA stream. An empty
// (None) `comms` selects _get_communicators(inputs).
//
// The usage message now closes its parenthesis and lists streams/comms.
// Previously it was unterminated.
PyObject* THCPModule_nccl_reduce_scatter(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  PyObject *_inputs, *_outputs, *_streams, *_comms;
  int op;

  if (!PyArg_ParseTuple(
          args, "OOiOO", &_inputs, &_outputs, &op, &_streams, &_comms)) {
    THPUtils_invalidArguments(
        args,
        nullptr,
        "nccl_reduce_scatter",
        1,
        "(sequence[Tensor] inputs, sequence[Tensor] outputs, int op,"
        " sequence[torch.cuda.Stream] streams,"
        " sequence[torch.cuda.nccl.Communicator] comms)");
    return nullptr;
  }

  std::vector<at::Tensor> inputs = extract_tensors(_inputs);
  std::vector<at::Tensor> outputs = extract_tensors(_outputs);
  auto streams = unpack_streams(_streams, inputs.size());
  auto user_comms = unpack_comms(_comms, inputs.size());

  with_no_gil([&] {
    size_t len = inputs.size();
    // _check_inputs runs before the division below.
    // NOTE(review): presumably it rejects an empty input list; confirm.
    _check_inputs(inputs, outputs, 1, len);

    ncclDataType_t data_type = _get_data_type(inputs[0].type());

    int64_t count = inputs[0].numel() / len;
    // Held for the whole enqueue loop. NOTE(review): presumably this keeps the
    // caching allocator from freeing memory while NCCL calls are being issued.
    std::lock_guard<std::mutex> lock(*(THCCachingAllocator_getCudaFreeMutex()));
    auto comms = user_comms.empty() ? _get_communicators(inputs)
                                    : ArrayRef<ncclComm_t>(user_comms);
    at::cuda::OptionalCUDAGuard device_guard;
    AutoNcclGroup nccl_group_guard;
    for (size_t i = 0; i < len; i++) {
      int device = inputs[i].get_device();
      device_guard.set_index(device);
      auto stream = !streams[i]
          ? at::cuda::getCurrentCUDAStream(device).stream()
          : streams[i]->stream();
      NCCL_CHECK(ncclReduceScatter(
          inputs[i].data_ptr(),
          outputs[i].data_ptr(),
          count,
          data_type,
          (ncclRedOp_t)op,
          comms[i],
          stream));
    }
  });

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
|
|
|
|
// Converts a Python sequence of Tensors into a std::vector<at::Tensor>, storing
// each element's Variable `data()` tensor.
// Throws python_error (Python exception already set) if `obj` is not a
// sequence. Throws TypeError naming the index and type of the first element
// that is not a Tensor.
static std::vector<at::Tensor> extract_tensors(PyObject* obj) {
  THPObjectPtr fast_seq(PySequence_Fast(obj, "expected a sequence"));
  if (!fast_seq) {
    throw python_error();
  }

  const Py_ssize_t num_items = PySequence_Fast_GET_SIZE(fast_seq.get());
  std::vector<at::Tensor> tensors;
  for (Py_ssize_t idx = 0; idx < num_items; ++idx) {
    PyObject* element = PySequence_Fast_GET_ITEM(fast_seq.get(), idx);
    if (!THPVariable_Check(element)) {
      throw TypeError(
          "expected Tensor at %d (got %s)", (int)idx, Py_TYPE(element)->tp_name);
    }
    THPVariable* variable = reinterpret_cast<THPVariable*>(element);
    tensors.emplace_back(variable->cdata.data());
  }
  return tensors;
}
|