#include "torch/csrc/utils/pybind.h" #include "torch/csrc/cuda/comm.h" #include "torch/csrc/cuda/Stream.h" #include "torch/csrc/cuda/THCP.h" #include "torch/csrc/utils/auto_gil.h" #include "torch/csrc/utils/functional.h" #include #include #include #include namespace torch { namespace cuda { namespace python { void initCommMethods(PyObject *module) { auto m = py::cast(module); m.def( "_broadcast_coalesced", [](std::vector& tensors, std::vector devices, size_t buffer_size) { return broadcast_coalesced(tensors, devices, buffer_size); }, py::arg("tensors"), py::arg("devices"), py::arg("buffer_size"), py::call_guard()) .def( "_broadcast", [](at::Tensor& tensor, std::vector devices) { return broadcast(tensor, devices); }, py::call_guard()) .def( "_scatter", [](at::Tensor& tensor, std::vector& devices, c10::optional> chunk_sizes, int64_t dim, c10::optional py_streams) { c10::optional>> streams; if (py_streams) { py::handle handle = *py_streams; streams = THPUtils_PySequence_to_CUDAStreamList(handle.ptr()); } // Note: We're holding the GIL up to here. AutoNoGIL no_gil; return scatter(tensor, devices, chunk_sizes, dim, streams); }, py::arg("tensor"), py::arg("devices"), py::arg("chunk_sizes"), py::arg("dim"), py::arg("streams")) .def( "_gather", [](std::vector& tensors, int64_t dim, c10::optional destination_index) { return gather(tensors, dim, destination_index); }, py::arg("tensors"), py::arg("dim"), py::arg("destination_index"), py::call_guard()); } }}}