#include "torch/csrc/utils/pybind.h"
#include "torch/csrc/cuda/comm.h"

#include <cstddef>
#include <cstdint>
#include <vector>

namespace torch { namespace cuda { namespace python {

// Registers the CUDA communication helpers on the given Python module.
//
// Two functions are bound, each a thin forwarder to the C++ function of the
// matching name:
//   _broadcast_coalesced(tensors, devices, buffer_size)
//       -> broadcast_coalesced(tensors, devices, buffer_size)
//   _broadcast(tensor, devices)
//       -> broadcast(tensor, devices)
//
// Only `_broadcast_coalesced` declares keyword names for its arguments;
// `_broadcast` is positional-only from Python.
//
// Both bindings carry a call guard that releases the GIL for the duration of
// the forwarded call, so other Python threads can run while it executes.
//
// NOTE(review): the template arguments in this function (tensor / device
// element types, the py::module cast target and the call-guard type) were
// reconstructed; confirm the device element type against the declarations
// in comm.h.
//
// @param module  Python module object that receives the bindings.
void initCommMethods(PyObject *module) {
  // View the raw PyObject* as a pybind11 module so that .def() is available.
  auto m = py::cast<py::module>(module);

  m.def(
       "_broadcast_coalesced",
       [](std::vector<at::Tensor>& tensors,
          std::vector<int64_t> devices,
          size_t buffer_size) {
         return broadcast_coalesced(tensors, devices, buffer_size);
       },
       py::arg("tensors"),
       py::arg("devices"),
       py::arg("buffer_size"),
       py::call_guard<py::gil_scoped_release>())
      .def(
          "_broadcast",
          [](at::Tensor& tensor, std::vector<int64_t> devices) {
            return broadcast(tensor, devices);
          },
          py::call_guard<py::gil_scoped_release>());
}

}}} // namespace torch::cuda::python