#include <torch/csrc/python_headers.h>

#include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/process_group_agent.h>
#include <torch/csrc/distributed/rpc/py_rref.h>
#include <torch/csrc/distributed/rpc/python_functions.h>
#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/rref_context.h>
#include <torch/csrc/distributed/rpc/types.h>
#include <torch/csrc/jit/pybind_utils.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/types.h>

#include <pybind11/chrono.h>
#include <pybind11/operators.h>

namespace torch {
namespace distributed {
namespace rpc {

namespace {

template <typename T>
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;

PyObject* rpc_init(PyObject* /* unused */) {
  auto rpc_module =
      THPObjectPtr(PyImport_ImportModule("torch.distributed.rpc"));
  if (!rpc_module) {
    throw python_error();
  }

  auto module = py::handle(rpc_module).cast<py::module>();

  auto rpcAgentOptions =
      shared_ptr_class_<RpcAgentOptions>(module, "RpcAgentOptions")
          .def_readwrite("rpc_timeout", &RpcAgentOptions::rpcTimeout);

  auto workerInfo =
      shared_ptr_class_<WorkerInfo>(
          module,
          "WorkerInfo",
          R"(Encapsulates information of a worker in the system.)")
          .def_readonly("name", &WorkerInfo::name_, R"(Name of the worker.)")
          .def_readonly(
              "id", &WorkerInfo::id_, R"(Globally unique id of the worker.)")
          .def("__eq__", &WorkerInfo::operator==, py::is_operator())
          // pybind11 suggests the syntax .def(hash(py::self)), with the
          // unqualified "hash" function call. However, the
          // argument-dependent lookup for the function "hash" doesn't get
          // triggered in this context because it conflicts with the struct
          // torch::hash, so we need to use the qualified name
          // py::detail::hash, which unfortunately is in a detail namespace.
          .def(py::detail::hash(py::self));

  auto rpcAgent =
      shared_ptr_class_<RpcAgent>(module, "RpcAgent")
          .def(
              "join",
              &RpcAgent::join,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "sync",
              &RpcAgent::sync,
              py::call_guard<py::gil_scoped_release>());

  auto pyRRef =
      shared_ptr_class_<PyRRef>(module, "RRef", R"(
          A class encapsulating a reference to a value of some type on a
          remote worker. This handle will keep the referenced remote value
          alive on the worker.
      )")
          .def(py::init<const py::object&>())
          .def(
              // not releasing GIL here to avoid context switch on getters
              "is_owner",
              &PyRRef::isOwner,
              R"(
                  Returns whether or not the current node is the owner of
                  this ``RRef``.
              )")
          .def(
              // not releasing GIL here to avoid context switch on getters
              "owner",
              &PyRRef::owner,
              R"(
                  Returns worker information of the node that owns this
                  ``RRef``.
              )")
          .def(
              "to_here",
              &PyRRef::toHere,
              py::call_guard<py::gil_scoped_release>(),
              R"(
                  Blocking call that copies the value of the RRef from the
                  owner to the local node and returns it. If the current node
                  is the owner, returns a reference to the local value.
              )")
          .def(
              "local_value",
              &PyRRef::localValue,
              py::call_guard<py::gil_scoped_release>(),
              R"(
                  If the current node is the owner, returns a reference to the
                  local value. Otherwise, throws an exception.
              )")
          .def(py::pickle(
              [](const PyRRef& self) {
                // __getstate__
                return self.pickle();
              },
              [](py::tuple t) { // NOLINT
                // __setstate__
                return PyRRef::unpickle(t);
              }));

  // future.wait() should not be called after join_rpc(), e.g.,
  // pythonRpcHandler is cleaned up in join_rpc(); after join_rpc(), Python
  // objects returned from an rpc Python call can no longer be resolved.
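  // FutureMessage exposes the pending result of an asynchronous RPC to
  // Python: wait() blocks (with the GIL released via the call guard) until
  // the response Message is available, then converts it into a Python object
  // via toPyObj.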
  auto futureMessage =
      shared_ptr_class_<FutureMessage>(module, "FutureMessage")
          .def(
              "wait",
              [&](FutureMessage& fut) { return toPyObj(fut.wait()); },
              py::call_guard<py::gil_scoped_release>());

  shared_ptr_class_<ProcessGroupRpcAgentOptions>(
      module, "ProcessGroupRpcAgentOptions", rpcAgentOptions)
      .def(py::init<>())
      .def_readwrite(
          "num_send_recv_threads",
          &ProcessGroupRpcAgentOptions::numSendRecvThreads);

  shared_ptr_class_<ProcessGroupAgent>(module, "ProcessGroupAgent", rpcAgent)
      .def(
          py::init<
              std::string,
              std::shared_ptr<::c10d::ProcessGroup>,
              int,
              std::chrono::milliseconds>(),
          py::arg("name"),
          py::arg("process_group"),
          py::arg("num_send_recv_threads"),
          py::arg("rpc_timeout"))
      .def(
          "get_worker_info",
          (const WorkerInfo& (ProcessGroupAgent::*)(void) const) &
              RpcAgent::getWorkerInfo,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "get_worker_info",
          (const WorkerInfo& (ProcessGroupAgent::*)(const std::string&) const) &
              ProcessGroupAgent::getWorkerInfo,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "join",
          &ProcessGroupAgent::join,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "sync",
          &ProcessGroupAgent::sync,
          py::call_guard<py::gil_scoped_release>());

  module.def("_start_rpc_agent", [](const std::shared_ptr<RpcAgent>& agent) {
    RpcAgent::setDefaultRpcAgent(agent);
    agent->start();
  });

  module.def("_destroy_rref_context", []() {
    RRefContext::getInstance().destroyInstance();
  });

  module.def("_cleanup_python_rpc_handler", []() {
    PythonRpcHandler::getInstance().cleanup();
  });

  module.def(
      "_invoke_rpc_builtin",
      [](RpcAgent& agent,
         const WorkerInfo& dst,
         const std::string& opName,
         const py::args& args,
         const py::kwargs& kwargs) {
        return pyRpcBuiltin(agent, dst, opName, args, kwargs);
      });

  module.def(
      "_invoke_rpc_python_udf",
      [](RpcAgent& agent,
         const WorkerInfo& dst,
         std::string& pickledPythonUDF,
         std::vector<torch::Tensor>& tensors) {
        return pyRpcPythonUdf(agent, dst, pickledPythonUDF, tensors);
      });

  module.def(
      "_invoke_remote_builtin",
      [](RpcAgent& agent,
         const WorkerInfo& dst,
         const std::string& opName,
         const py::args& args,
         const py::kwargs& kwargs) {
        return pyRemoteBuiltin(agent, dst, opName, args, kwargs);
      });

  module.def(
      "_invoke_remote_python_udf",
      [](RpcAgent& agent,
         const WorkerInfo& dst,
         std::string& pickledPythonUDF,
         std::vector<torch::Tensor>& tensors) {
        return pyRemotePythonUdf(agent, dst, pickledPythonUDF, tensors);
      });

  module.def(
      "get_rpc_timeout",
      []() { return RpcAgent::getDefaultRpcAgent()->getRpcTimeout(); },
      R"(
          Retrieve the timeout for all RPCs that was set during RPC
          initialization.

          Returns:
            `datetime.timedelta` instance indicating the RPC timeout.
      )");

  module.def(
      "_set_rpc_timeout",
      [](const std::chrono::milliseconds& rpcTimeout) {
        RpcAgent::getDefaultRpcAgent()->setRpcTimeout(rpcTimeout);
      },
      R"(
          Set the timeout for all RPCs. If an RPC is not completed within this
          time, an exception indicating that it has timed out will be raised.
      )");

  Py_RETURN_TRUE;
}

} // namespace

static PyMethodDef methods[] = { // NOLINT
    {"_rpc_init", (PyCFunction)rpc_init, METH_NOARGS, nullptr},
    {nullptr, nullptr, 0, nullptr}};

PyMethodDef* python_functions() {
  return methods;
}

} // namespace rpc
} // namespace distributed
} // namespace torch
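// Illustrative usage sketch (an assumption about the public Python API that
// drives these private bindings, not part of this translation unit; names
// such as init_rpc/rpc_sync/shutdown may differ between releases):
//
//   import torch
//   import torch.distributed.rpc as rpc
//
//   # construct and start the default agent (reaching _start_rpc_agent);
//   # assumes MASTER_ADDR/MASTER_PORT are set for rendezvous
//   rpc.init_rpc("worker0", rank=0, world_size=2)
//
//   # a builtin-op RPC goes through _invoke_rpc_builtin and waits on the
//   # returned FutureMessage
//   ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))
//
//   # tears down the agent, the RRef context, and the Python RPC handler
//   rpc.shutdown()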