pytorch/torch/csrc/cuda/python_comm.cpp
Edward Yang 1111a6b810 Use pybind11::gil_scoped_* functions instead of AutoGIL/AutoNoGIL (#30274)
Summary:
Reland of https://github.com/pytorch/pytorch/pull/29095
Pull Request resolved: https://github.com/pytorch/pytorch/pull/30274

Differential Revision: D18762293

Pulled By: ezyang

fbshipit-source-id: d3d50c2dd12bcb678ab25fa708eb6587cc4b66f9
2019-12-02 12:19:58 -08:00

69 lines
2.2 KiB
C++

#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/cuda/comm.h>
#include <torch/csrc/cuda/Stream.h>
#include <torch/csrc/cuda/THCP.h>
#include <pybind11/pybind11.h>
#include <ATen/core/functional.h>
#include <ATen/ATen.h>
#include <THC/THC.h>
#include <cstddef>
#include <vector>
namespace torch { namespace cuda { namespace python {

// Registers the CUDA communication helpers `_broadcast_coalesced`,
// `_broadcast`, `_scatter` and `_gather` as functions on `module`.
//
// `module` is expected to be a Python module object; it is converted with
// py::cast<py::module>, which raises if it is not one.
//
// GIL protocol shared by every binding below:
//   * Python -> C++ argument conversion runs with the GIL held.
//   * The GIL is released only while the comm routine itself runs, so other
//     Python threads are not blocked by it. For `_broadcast_coalesced`,
//     `_broadcast` and `_gather` this is done by
//     py::call_guard<py::gil_scoped_release>, which pybind11 constructs
//     around the call only. `_scatter` releases it by hand because it must
//     first read a Python object.
//   * C++ -> Python result conversion runs after the GIL is re-acquired.
void initCommMethods(PyObject *module) {
  auto m = py::cast<py::module>(module);
  m.def(
       "_broadcast_coalesced",
       [](std::vector<at::Tensor>& tensors,
          std::vector<int64_t> devices,
          size_t buffer_size) {
         return broadcast_coalesced(tensors, devices, buffer_size);
       },
       py::arg("tensors"),
       py::arg("devices"),
       py::arg("buffer_size"),
       py::call_guard<py::gil_scoped_release>())
      .def(
          "_broadcast",
          [](at::Tensor& tensor, std::vector<int64_t> devices) {
            return broadcast(tensor, devices);
          },
          // Named like the sibling bindings so the Python signature is
          // self-describing and keyword calls work; positional calls
          // behave exactly as before.
          py::arg("tensor"),
          py::arg("devices"),
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_scatter",
          [](at::Tensor& tensor,
             std::vector<int64_t>& devices,
             c10::optional<std::vector<int64_t>> chunk_sizes,
             int64_t dim,
             c10::optional<py::object> py_streams) {
            // Converting the Python stream sequence touches PyObjects, so it
            // must happen before the GIL is dropped. When `streams` is None,
            // pybind11's optional caster yields an empty optional and no
            // stream list is built.
            c10::optional<std::vector<c10::optional<at::cuda::CUDAStream>>> streams;
            if (py_streams) {
              py::handle handle = *py_streams;
              streams = THPUtils_PySequence_to_CUDAStreamList(handle.ptr());
            }
            // Note: We're holding the GIL up to here.
            // `no_gil` is a local, so it is destroyed (re-acquiring the GIL)
            // before the by-value `py_streams` parameter releases its
            // reference; no Python refcount is touched without the GIL.
            py::gil_scoped_release no_gil;
            return scatter(tensor, devices, chunk_sizes, dim, streams);
          },
          py::arg("tensor"),
          py::arg("devices"),
          py::arg("chunk_sizes"),
          py::arg("dim"),
          py::arg("streams"))
      .def(
          "_gather",
          [](std::vector<at::Tensor>& tensors,
             int64_t dim,
             c10::optional<int32_t> destination_index) {
            return gather(tensors, dim, destination_index);
          },
          py::arg("tensors"),
          py::arg("dim"),
          py::arg("destination_index"),
          py::call_guard<py::gil_scoped_release>());
}

}}}