mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: This PR was created to resolve the issue brought up in https://fb.workplace.com/groups/319878845696681/permalink/741428653541696/ Changes: - Adds a timeout argument to RpcAgent.join(). - Adds an optional timeout argument to ThriftRpcAgent barrier(). - During shutdown, ThriftRpcAgent join() calls the barrier; the agent takes the timeout passed to shutdown and passes that timeout into join(). - Updates API.py to also fix a bug (missing timeout for the signal). - Changes the default shutdown timeout to 0 (no timeout). Existing functionality in _all_gather remains the same: it waits indefinitely for the signal if no timeout is set for the function. With the new functionality, the user specifies a timeout for both the signal and the RPC calls. Pull Request resolved: https://github.com/pytorch/pytorch/pull/76194 Test Plan: Modified barrier test buck test torch/fb/distributed/thriftRpcBackend/test:ThriftRpcAgentTest -- BarrierTest Reviewed By: mrshenli Differential Revision: D35825382 fbshipit-source-id: e91e9ab5d9fca08787cb6b6b8125a4b03d1c7cde (cherry picked from commit fcf899a387001574bf4e39a213ea741611d76097)
863 lines
36 KiB
C++
863 lines
36 KiB
C++
#include <torch/csrc/python_headers.h>
|
|
|
|
#include <torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h>
|
|
#include <torch/csrc/distributed/rpc/profiler/server_process_global_profiler.h>
|
|
#include <torch/csrc/distributed/rpc/py_rref.h>
|
|
#include <torch/csrc/distributed/rpc/python_functions.h>
|
|
#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
|
|
#include <torch/csrc/distributed/rpc/request_callback_impl.h>
|
|
#include <torch/csrc/distributed/rpc/rpc_agent.h>
|
|
#include <torch/csrc/distributed/rpc/rref_context.h>
|
|
#include <torch/csrc/distributed/rpc/tensorpipe_agent.h>
|
|
#include <torch/csrc/distributed/rpc/torchscript_functions.h>
|
|
#include <torch/csrc/distributed/rpc/types.h>
|
|
#include <torch/csrc/jit/python/pybind_utils.h>
|
|
#include <torch/csrc/utils/object_ptr.h>
|
|
#include <torch/csrc/utils/pybind.h>
|
|
#include <torch/types.h>
|
|
|
|
#include <pybind11/chrono.h>
|
|
#include <pybind11/operators.h>
|
|
|
|
namespace torch {
|
|
namespace distributed {
|
|
namespace rpc {
|
|
|
|
namespace {
|
|
|
|
// Default ``timeout`` (100000 ms) for the
// ``_delete_all_user_and_unforked_owner_rrefs`` binding registered in rpc_init.
constexpr std::chrono::milliseconds kDeleteAllUsersTimeout{100000};
|
|
|
|
template <typename T>
|
|
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;
|
|
|
|
// Creates the ``torch._C._distributed_rpc`` submodule and registers all of the
// RPC bindings on it: backend options, WorkerInfo, RpcAgent, PyRRef, the
// TensorPipe agent (when USE_TENSORPIPE is defined) and the module-level
// helper functions.
//
// Both parameters are ignored; the function is registered in the method table
// as ``_rpc_init`` with METH_NOARGS.
//
// Throws python_error if ``torch.distributed.rpc`` or ``torch._C`` cannot be
// imported. On success returns ``True`` (via Py_RETURN_TRUE).
PyObject* rpc_init(PyObject* _unused, PyObject* noargs) {
  auto rpc_module =
      THPObjectPtr(PyImport_ImportModule("torch.distributed.rpc"));
  if (!rpc_module) {
    throw python_error();
  }

  auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
  if (!torch_C_module) {
    throw python_error();
  }

  auto torch_C_m = py::handle(torch_C_module).cast<py::module>();
  auto m =
      torch_C_m.def_submodule("_distributed_rpc", "distributed rpc bindings");

  auto module = py::handle(m).cast<py::module>();

  auto rpcBackendOptions =
      shared_ptr_class_<RpcBackendOptions>(
          module,
          "RpcBackendOptions",
          R"(An abstract structure encapsulating the options passed into the RPC
            backend. An instance of this class can be passed in to
            :meth:`~torch.distributed.rpc.init_rpc` in order to initialize RPC
            with specific configurations, such as the RPC timeout and
            ``init_method`` to be used. )")
          .def(py::init<>())
          .def(
              py::init<float, std::string>(),
              py::arg("rpc_timeout") = kDefaultRpcTimeoutSeconds,
              py::arg("init_method") = kDefaultInitMethod)
          .def_readwrite(
              "rpc_timeout",
              &RpcBackendOptions::rpcTimeoutSeconds,
              R"(A float indicating the timeout to use for all
                RPCs. If an RPC does not complete in this timeframe, it will
                complete with an exception indicating that it has timed out.)")
          .def_readwrite(
              "init_method",
              &RpcBackendOptions::initMethod,
              R"(URL specifying how to initialize the process group.
                Default is ``env://``)");

  // The following C++ constants need to be cast so they can be used from
  // python.
  module.attr("_DEFAULT_RPC_TIMEOUT_SEC") = py::cast(kDefaultRpcTimeoutSeconds);
  module.attr("_UNSET_RPC_TIMEOUT") = py::cast(kUnsetRpcTimeout);
  module.attr("_DEFAULT_INIT_METHOD") = py::cast(kDefaultInitMethod);

  auto workerInfo =
      shared_ptr_class_<WorkerInfo>(
          module,
          "WorkerInfo",
          R"(A structure that encapsulates information of a worker in the system.
            Contains the name and ID of the worker. This class is not meant to
            be constructed directly, rather, an instance can be retrieved
            through :meth:`~torch.distributed.rpc.get_worker_info` and the
            result can be passed in to functions such as
            :meth:`~torch.distributed.rpc.rpc_sync`, :meth:`~torch.distributed.rpc.rpc_async`,
            :meth:`~torch.distributed.rpc.remote` to avoid copying a string on
            every invocation.)")
          .def(
              py::init<std::string, worker_id_t>(),
              py::arg("name"),
              py::arg("id"))
          .def_readonly(
              "name", &WorkerInfo::name_, R"(The name of the worker.)")
          .def_readonly(
              "id",
              &WorkerInfo::id_,
              R"(Globally unique id to identify the worker.)")
          .def("__eq__", &WorkerInfo::operator==, py::is_operator())
          // pybind11 suggests the syntax  .def(hash(py::self)), with the
          // unqualified "hash" function call. However the
          // argument-dependent lookup for the function "hash" doesn't get
          // triggered in this context because it conflicts with the struct
          // c10::hash, so we need to use the qualified name
          // py::detail::hash, which unfortunately is in a detail namespace.
          .def(py::detail::hash(py::self)) // NOLINT
          .def(
              "__repr__",
              [](const WorkerInfo& workerInfo) {
                std::ostringstream os;
                os << workerInfo;
                return os.str();
              })
          .def(py::pickle(
              /* __getstate__ */
              [](const WorkerInfo& workerInfo) {
                return py::make_tuple(workerInfo.name_, workerInfo.id_);
              },
              /* __setstate__ */
              [](py::tuple t) {
                TORCH_CHECK(t.size() == 2, "Invalid WorkerInfo state.");

                WorkerInfo info(
                    t[0].cast<std::string>(), t[1].cast<worker_id_t>());
                return info;
              }));

  auto rpcAgent =
      shared_ptr_class_<RpcAgent>(module, "RpcAgent")
          .def(
              "join",
              &RpcAgent::join,
              py::call_guard<py::gil_scoped_release>(),
              py::arg("shutdown") = false,
              py::arg("timeout") = 0)
          .def(
              "sync", &RpcAgent::sync, py::call_guard<py::gil_scoped_release>())
          .def(
              "shutdown",
              &RpcAgent::shutdown,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "get_worker_info",
              (const WorkerInfo& (RpcAgent::*)(void) const) &
                  RpcAgent::getWorkerInfo,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "get_worker_info",
              (const WorkerInfo& (RpcAgent::*)(const std::string&) const) &
                  RpcAgent::getWorkerInfo,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "get_worker_infos",
              &RpcAgent::getWorkerInfos,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_get_device_map",
              &RpcAgent::getDeviceMap,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "get_debug_info",
              &RpcAgent::getDebugInfo,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "get_metrics",
              &RpcAgent::getMetrics,
              py::call_guard<py::gil_scoped_release>());

  auto pyRRef =
      shared_ptr_class_<PyRRef>(module, "PyRRef", R"(
          A class encapsulating a reference to a value of some type on a remote
          worker. This handle will keep the referenced remote value alive on the
          worker. A ``UserRRef`` will be deleted when 1) no references to it in
          both the application code and in the local RRef context, or 2) the
          application has called a graceful shutdown. Invoking methods on a
          deleted RRef leads to undefined behaviors. RRef implementation only
          offers best-effort error detection, and applications should not use
          ``UserRRefs`` after ``rpc.shutdown()``.

          .. warning::
              RRefs can only be serialized and deserialized by the RPC module.
              Serializing and deserializing RRefs without RPC (e.g., Python
              pickle, torch :meth:`~torch.save` / :meth:`~torch.load`,
              JIT :meth:`~torch.jit.save` / :meth:`~torch.jit.load`, etc.) will
              lead to errors.

          Args:
              value (object): The value to be wrapped by this RRef.
              type_hint (Type, optional): Python type that should be passed to
                  ``TorchScript`` compiler as type hint for ``value``.

          Example::
              Following examples skip RPC initialization and shutdown code
              for simplicity. Refer to RPC docs for those details.

              1. Create an RRef using rpc.remote

              >>> import torch
              >>> import torch.distributed.rpc as rpc
              >>> rref = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
              >>> # get a copy of value from the RRef
              >>> x = rref.to_here()

              2. Create an RRef from a local object

              >>> import torch
              >>> from torch.distributed.rpc import RRef
              >>> x = torch.zeros(2, 2)
              >>> rref = RRef(x)

              3. Share an RRef with other workers

              >>> # On both worker0 and worker1:
              >>> def f(rref):
              >>>     return rref.to_here() + 1

              >>> # On worker0:
              >>> import torch
              >>> import torch.distributed.rpc as rpc
              >>> from torch.distributed.rpc import RRef
              >>> rref = RRef(torch.zeros(2, 2))
              >>> # the following RPC shares the rref with worker1, reference
              >>> # count is automatically updated.
              >>> rpc.rpc_sync("worker1", f, args=(rref,))
          )")
          .def(
              py::init<const py::object&, const py::object&>(),
              py::arg("value"),
              py::arg("type_hint") = py::none())
          .def(
              // not releasing GIL here to avoid context switch on getters
              "is_owner",
              &PyRRef::isOwner,
              R"(
                  Returns whether or not the current node is the owner of this
                  ``RRef``.
              )")
          .def(
              "confirmed_by_owner",
              &PyRRef::confirmedByOwner,
              R"(
                  Returns whether this ``RRef`` has been confirmed by the owner.
                  ``OwnerRRef`` always returns true, while ``UserRRef`` only
                  returns true when the owner knows about this ``UserRRef``.
              )")
          .def(
              // not releasing GIL here to avoid context switch on getters
              "owner",
              &PyRRef::owner,
              R"(
                  Returns worker information of the node that owns this ``RRef``.
              )")
          .def(
              // not releasing GIL here to avoid context switch on getters
              "owner_name",
              &PyRRef::ownerName,
              R"(
                  Returns worker name of the node that owns this ``RRef``.
              )")
          .def(
              "to_here",
              &PyRRef::toHere,
              py::arg("timeout") = py::cast(kUnsetRpcTimeout),
              py::call_guard<py::gil_scoped_release>(),
              R"(
                  Blocking call that copies the value of the RRef from the owner
                  to the local node and returns it. If the current node is the
                  owner, returns a reference to the local value.

                  Args:
                      timeout (float, optional): Timeout for ``to_here``. If
                          the call does not complete within this timeframe, an
                          exception indicating so will be raised. If this
                          argument is not provided, the default RPC timeout
                          (60s) will be used.
              )")
          .def(
              "local_value",
              &PyRRef::localValue,
              py::call_guard<py::gil_scoped_release>(),
              R"(
                  If the current node is the owner, returns a reference to the
                  local value. Otherwise, throws an exception.
              )")
          .def(
              "rpc_sync",
              [](const PyRRef& self, float timeoutSeconds) {
                return self.createRRefProxy(
                    RRefProxyType::RPC_SYNC, timeoutSeconds);
              },
              py::arg("timeout") = kUnsetRpcTimeout,
              py::call_guard<py::gil_scoped_release>(),
              R"(
                  Create a helper proxy to easily launch an ``rpc_sync`` using
                  the owner of the RRef as the destination to run functions on
                  the object referenced by this RRef. More specifically,
                  ``rref.rpc_sync().func_name(*args, **kwargs)`` is the same as
                  the following:

                  >>> def run(rref, func_name, args, kwargs):
                  >>>     return getattr(rref.local_value(), func_name)(*args, **kwargs)
                  >>>
                  >>> rpc.rpc_sync(rref.owner(), run, args=(rref, func_name, args, kwargs))

                  Args:
                      timeout (float, optional): Timeout for ``rref.rpc_sync()``.
                          If the call does not complete within this timeframe, an
                          exception indicating so will be raised. If this argument
                          is not provided, the default RPC timeout will be used.

                  Example::
                      >>> from torch.distributed import rpc
                      >>> rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1))
                      >>> rref.rpc_sync().size() # returns torch.Size([2, 2])
                      >>> rref.rpc_sync().view(1, 4) # returns tensor([[1., 1., 1., 1.]])
              )")
          .def(
              "rpc_async",
              [](const PyRRef& self, float timeoutSeconds) {
                return self.createRRefProxy(
                    RRefProxyType::RPC_ASYNC, timeoutSeconds);
              },
              py::arg("timeout") = kUnsetRpcTimeout,
              py::call_guard<py::gil_scoped_release>(),
              R"(
                  Create a helper proxy to easily launch an ``rpc_async`` using
                  the owner of the RRef as the destination to run functions on
                  the object referenced by this RRef. More specifically,
                  ``rref.rpc_async().func_name(*args, **kwargs)`` is the same as
                  the following:

                  >>> def run(rref, func_name, args, kwargs):
                  >>>     return getattr(rref.local_value(), func_name)(*args, **kwargs)
                  >>>
                  >>> rpc.rpc_async(rref.owner(), run, args=(rref, func_name, args, kwargs))

                  Args:
                      timeout (float, optional): Timeout for ``rref.rpc_async()``.
                          If the call does not complete within this timeframe, an
                          exception indicating so will be raised. If this argument
                          is not provided, the default RPC timeout will be used.

                  Example::
                      >>> from torch.distributed import rpc
                      >>> rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1))
                      >>> rref.rpc_async().size().wait() # returns torch.Size([2, 2])
                      >>> rref.rpc_async().view(1, 4).wait() # returns tensor([[1., 1., 1., 1.]])
              )")
          .def(
              "remote",
              [](const PyRRef& self, float timeoutSeconds) {
                return self.createRRefProxy(
                    RRefProxyType::REMOTE, timeoutSeconds);
              },
              py::arg("timeout") = kUnsetRpcTimeout,
              py::call_guard<py::gil_scoped_release>(),
              R"(
                  Create a helper proxy to easily launch a ``remote`` using
                  the owner of the RRef as the destination to run functions on
                  the object referenced by this RRef. More specifically,
                  ``rref.remote().func_name(*args, **kwargs)`` is the same as
                  the following:

                  >>> def run(rref, func_name, args, kwargs):
                  >>>     return getattr(rref.local_value(), func_name)(*args, **kwargs)
                  >>>
                  >>> rpc.remote(rref.owner(), run, args=(rref, func_name, args, kwargs))

                  Args:
                      timeout (float, optional): Timeout for ``rref.remote()``. If
                          the creation of this :class:`~torch.distributed.rpc.RRef`
                          is not successfully completed within the timeout, then the
                          next time there is an attempt to use the RRef
                          (such as ``to_here``), a timeout will be raised. If not
                          provided, the default RPC timeout will be used. Please see
                          ``rpc.remote()`` for specific timeout semantics for
                          :class:`~torch.distributed.rpc.RRef`.

                  Example::
                      >>> from torch.distributed import rpc
                      >>> rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1))
                      >>> rref.remote().size().to_here() # returns torch.Size([2, 2])
                      >>> rref.remote().view(1, 4).to_here() # returns tensor([[1., 1., 1., 1.]])
              )")
          .def(
              py::pickle(
                  /* __getstate__ */
                  [](const PyRRef& /* unused */) {
                    TORCH_CHECK(
                        false,
                        "Can not pickle rref in python pickler, rref can only be "
                        "pickled when using RPC");
                    // Note that this return has no meaning since we always
                    // throw, it's only here to satisfy Pybind API's
                    // requirement.
                    return py::make_tuple();
                  },
                  /* __setstate__ */
                  [](py::tuple /* unused */) { // NOLINT
                    TORCH_CHECK(
                        false,
                        "Can not unpickle rref in python pickler, rref can only be "
                        "unpickled when using RPC");
                    // Note that this return has no meaning since we always
                    // throw, it's only here to satisfy PyBind's API
                    // requirement.
                    return PyRRef(
                        py::cast<py::none>(Py_None),
                        py::cast<py::none>(Py_None));
                  }),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_serialize",
              &PyRRef::pickle,
              py::call_guard<py::gil_scoped_release>())
          .def_static(
              "_deserialize",
              &PyRRef::unpickle,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_get_type",
              // Intentionally not releasing GIL, as most accesses just
              // retrieve cached type py::object
              &PyRRef::getRRefType,
              py::arg("timeout") = kUnsetRpcTimeout,
              py::arg("blocking") = true,
              R"(
                  If ``blocking=True``, returns the type of the data object
                  referenced by this ``RRef``. On the owner, this is same as
                  ``type(rref.local_value())``. Otherwise, returns a future to
                  this result. On a user, this will trigger an RPC to fetch the
                  ``type`` object from the owner. After this function is run
                  once, the ``type`` object is cached by the ``RRef``, and
                  subsequent invocations no longer trigger RPC. Note that this is
                  true regardless of the ``blocking`` argument of subsequent
                  calls.

                  Args:
                      rref (torch.distributed.rpc.RRef): The RRef to get type of.
                      timeout (float, optional): Timeout, in seconds for
                          ``_get_type``. If the call does not complete within
                          this timeframe, an exception indicating so will be
                          raised. If this argument is not provided, the default
                          RPC timeout will be used.
                      blocking (bool, optional): Whether to synchronously wait on
                          the RPC triggered by the first call and return the
                          type. If ``False``, will return a future. Default is
                          ``True``.
              )")
          .def(
              "_get_future",
              [](const PyRRef& self) {
                return std::make_shared<jit::PythonFutureWrapper>(
                    self.getFuture());
              },
              py::call_guard<py::gil_scoped_release>(),
              R"(
                  Returns the future that corresponds to the creation of this RRef
                  on the remote node. This is for internal use cases such as profiling
                  only.
              )")
          .def(
              "_get_profiling_future",
              [](const PyRRef& self) {
                return std::make_shared<jit::PythonFutureWrapper>(
                    self.getProfilingFuture());
              },
              py::call_guard<py::gil_scoped_acquire>(),
              R"(
                  Returns future that completes when the profiling event corresponding
                  to the creation of this RRef on the remote node has been recorded.
              )")
          .def(
              "_set_profiling_future",
              [](PyRRef& self,
                 const std::shared_ptr<jit::PythonFutureWrapper>&
                     wrappedFuture) {
                self.setProfilingFuture(wrappedFuture->fut);
              },
              py::call_guard<py::gil_scoped_acquire>(),
              R"(
                  Set future that is completed when the profiling event corresponding
                  to the creation of this RRef on the remote node has been recorded.
              )")
          .def(
              "backward",
              [](PyRRef& self,
                 int64_t dist_autograd_ctx_id,
                 bool retain_graph) {
                self.backward(dist_autograd_ctx_id, retain_graph);
              },
              py::arg("dist_autograd_ctx_id") = -1,
              py::arg("retain_graph") = false,
              py::call_guard<py::gil_scoped_release>(),
              R"(
                  Runs the backward pass using the RRef as the root of the
                  backward pass. If ``dist_autograd_ctx_id`` is provided,
                  we perform a distributed backward pass using the provided
                  ctx_id starting from the owner of the RRef. In this case,
                  :meth:`~torch.distributed.autograd.get_gradients` should be
                  used to retrieve the gradients. If ``dist_autograd_ctx_id``
                  is ``None``, it is assumed that this is a local autograd graph
                  and we only perform a local backward pass. In the local case,
                  the node calling this API has to be the owner of the RRef.
                  The value of the RRef is expected to be a scalar Tensor.

                  Args:
                      dist_autograd_ctx_id (int, optional): The distributed
                          autograd context id for which we should retrieve the
                          gradients (default: -1).
                      retain_graph(bool, optional): If ``False``, the graph used to
                          compute the grad will be freed. Note that in nearly all
                          cases setting this option to ``True`` is not needed and
                          often can be worked around in a much more efficient way.
                          Usually, you need to set this to ``True`` to run backward
                          multiple times (default: False).

                  Example::
                      >>> import torch.distributed.autograd as dist_autograd
                      >>> with dist_autograd.context() as context_id:
                      >>>     rref.backward(context_id)
              )")
          // not releasing GIL to avoid context switch
          .def("__repr__", &PyRRef::str);

#ifdef USE_TENSORPIPE

  // Base class: torch.distributed.rpc.RpcBackendOptions.
  py::class_<TensorPipeRpcBackendOptions>(
      module, "_TensorPipeRpcBackendOptionsBase", rpcBackendOptions)
      .def(
          py::init<
              int,
              optional<std::vector<std::string>>,
              optional<std::vector<std::string>>,
              float,
              std::string,
              std::unordered_map<std::string, DeviceMap>,
              std::vector<c10::Device>>(),
          py::arg("num_worker_threads") = kDefaultNumWorkerThreads,
          py::arg("_transports") = optional<std::vector<std::string>>(),
          py::arg("_channels") = optional<std::vector<std::string>>(),
          py::arg("rpc_timeout") = kDefaultRpcTimeoutSeconds,
          py::arg("init_method") = kDefaultInitMethod,
          py::arg("device_maps") = std::unordered_map<std::string, DeviceMap>(),
          py::arg("devices") = std::vector<c10::Device>())
      .def_readwrite(
          "num_worker_threads",
          &TensorPipeRpcBackendOptions::numWorkerThreads,
          R"(
              The number of threads in the thread-pool used by
              :class:`~torch.distributed.rpc.TensorPipeAgent` to execute
              requests.
          )")
      .def_readwrite(
          "device_maps",
          &TensorPipeRpcBackendOptions::deviceMaps,
          R"(The device map locations.)")
      .def_readwrite(
          "devices",
          &TensorPipeRpcBackendOptions::devices,
          R"(All devices used by the local agent.)")
      .def("_set_device_map", &TensorPipeRpcBackendOptions::setDeviceMap);

  module.attr("_DEFAULT_NUM_WORKER_THREADS") =
      py::cast(kDefaultNumWorkerThreads);

  shared_ptr_class_<TensorPipeAgent>(module, "TensorPipeAgent", rpcAgent)
      .def(
          py::init(
              [](const c10::intrusive_ptr<::c10d::Store>& store,
                 std::string selfName,
                 worker_id_t selfId,
                 optional<int> worldSize,
                 TensorPipeRpcBackendOptions opts,
                 std::unordered_map<std::string, DeviceMap> reverseDeviceMaps,
                 std::vector<c10::Device> devices) {
                // The agent is created with a custom deleter
                // (impl::destroy_without_gil) instead of std::make_shared.
                return std::shared_ptr<TensorPipeAgent>(
                    new TensorPipeAgent(
                        store,
                        std::move(selfName),
                        selfId,
                        worldSize,
                        std::move(opts),
                        std::move(reverseDeviceMaps),
                        std::move(devices),
                        std::make_unique<RequestCallbackImpl>()),
                    impl::destroy_without_gil<TensorPipeAgent>);
              }),
          py::arg("store"),
          py::arg("name"),
          py::arg("rank"),
          py::arg("world_size"),
          py::arg("rpc_backend_options"),
          py::arg("reverse_device_maps"),
          py::arg("devices"))
      .def(
          "join",
          &TensorPipeAgent::join,
          py::call_guard<py::gil_scoped_release>(),
          py::arg("shutdown") = false,
          py::arg("timeout") = 0)
      .def(
          "shutdown",
          &TensorPipeAgent::shutdown,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "get_worker_info",
          (const WorkerInfo& (TensorPipeAgent::*)(void) const) &
              RpcAgent::getWorkerInfo,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "get_worker_info",
          (const WorkerInfo& (TensorPipeAgent::*)(const std::string&) const) &
              TensorPipeAgent::getWorkerInfo,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "get_worker_info",
          (const WorkerInfo& (TensorPipeAgent::*)(worker_id_t id) const) &
              TensorPipeAgent::getWorkerInfo,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "get_worker_infos",
          (std::vector<WorkerInfo>(TensorPipeAgent::*)() const) &
              TensorPipeAgent::getWorkerInfos,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_get_device_map",
          (DeviceMap(TensorPipeAgent::*)(const WorkerInfo& dst) const) &
              TensorPipeAgent::getDeviceMap,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_get_backend_options",
          &TensorPipeAgent::getBackendOptions,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_update_group_membership",
          &TensorPipeAgent::updateGroupMembership,
          py::call_guard<py::gil_scoped_release>())
      .def_readonly("is_static_group", &TensorPipeAgent::isStaticGroup_)
      .def_property_readonly("store", &TensorPipeAgent::getStore);

#endif // USE_TENSORPIPE

  module.def("_is_current_rpc_agent_set", &RpcAgent::isCurrentRpcAgentSet);

  module.def("_get_current_rpc_agent", &RpcAgent::getCurrentRpcAgent);

  module.def(
      "_set_and_start_rpc_agent",
      [](const std::shared_ptr<RpcAgent>& rpcAgent) {
        RpcAgent::setCurrentRpcAgent(rpcAgent);
        // Initializing typeResolver inside RpcAgent constructor will make
        // RpcAgent have python dependency. To avoid RpcAgent to have python
        // dependency, setTypeResolver() here.
        //
        // The resolver is handed to the agent and is invoked after this
        // binding returns, so it must not capture locals by reference. It
        // needs no captures at all: it only uses its argument and the
        // PythonRpcHandler singleton.
        std::shared_ptr<TypeResolver> typeResolver =
            std::make_shared<TypeResolver>([](const c10::QualifiedName& qn) {
              auto typePtr = PythonRpcHandler::getInstance().parseTypeFromStr(
                  qn.qualifiedName());
              return c10::StrongTypePtr(
                  PythonRpcHandler::getInstance().jitCompilationUnit(),
                  std::move(typePtr));
            });
        rpcAgent->setTypeResolver(typeResolver);
        rpcAgent->start();
      },
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_reset_current_rpc_agent",
      []() { RpcAgent::setCurrentRpcAgent(nullptr); },
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_delete_all_user_and_unforked_owner_rrefs",
      [](std::chrono::milliseconds timeoutMillis) {
        RRefContext::getInstance().delAllUsersAndUnforkedOwners(timeoutMillis);
      },
      py::arg("timeout") = kDeleteAllUsersTimeout,
      py::call_guard<py::gil_scoped_release>());

  module.def("_destroy_rref_context", [](bool ignoreRRefLeak) {
    // NB: do not release GIL in the function. The destroyInstance() method
    // returns a list of deleted OwnerRRefs that hold py::object instances.
    // Clearing those OwnerRRefs are likely to trigger Python deref, which
    // requires GIL.
    RRefContext::getInstance().destroyInstance(ignoreRRefLeak).clear();
  });

  module.def("_rref_context_get_debug_info", []() {
    return RRefContext::getInstance().getDebugInfo();
  });

  module.def(
      "_cleanup_python_rpc_handler",
      []() { PythonRpcHandler::getInstance().cleanup(); },
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_invoke_rpc_builtin",
      [](const WorkerInfo& dst,
         const std::string& opName,
         const float rpcTimeoutSeconds,
         const py::args& args,
         const py::kwargs& kwargs) {
        return std::make_shared<jit::PythonFutureWrapper>(
            pyRpcBuiltin(dst, opName, args, kwargs, rpcTimeoutSeconds));
      },
      py::call_guard<py::gil_scoped_acquire>());

  module.def(
      "_invoke_rpc_python_udf",
      [](const WorkerInfo& dst,
         std::string& pickledPythonUDF,
         std::vector<torch::Tensor>& tensors,
         const float rpcTimeoutSeconds,
         const bool isAsyncExecution) {
        return std::make_shared<jit::PythonFutureWrapper>(pyRpcPythonUdf(
            dst,
            pickledPythonUDF,
            tensors,
            rpcTimeoutSeconds,
            isAsyncExecution));
      },
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_invoke_rpc_torchscript",
      [](const std::string& dstWorkerName,
         const std::string& qualifiedNameStr,
         const py::tuple& argsTuple,
         const py::dict& kwargsDict,
         const float rpcTimeoutSeconds,
         const bool isAsyncExecution) {
        return std::make_shared<jit::PythonFutureWrapper>(pyRpcTorchscript(
            dstWorkerName,
            qualifiedNameStr,
            argsTuple,
            kwargsDict,
            rpcTimeoutSeconds,
            isAsyncExecution));
      },
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_invoke_remote_builtin",
      &pyRemoteBuiltin,
      py::call_guard<py::gil_scoped_acquire>());

  module.def(
      "_invoke_remote_python_udf",
      &pyRemotePythonUdf,
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_invoke_remote_torchscript",
      &pyRemoteTorchscript,
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "get_rpc_timeout",
      []() {
        return RpcAgent::getCurrentRpcAgent()->getRpcTimeout().count() /
            kSecToMsConversion;
      },
      R"(
          Retrieve the default timeout for all RPCs that was set during RPC initialization.
          The returned value will be in seconds.
          Returns:
            ``float`` indicating the RPC timeout in seconds.
      )");

  module.def(
      "enable_gil_profiling",
      [](bool flag) {
        RpcAgent::getCurrentRpcAgent()->enableGILProfiling(flag);
      },
      R"(
    Set whether GIL wait times should be enabled or not. This incurs a slight
    overhead cost. Default is disabled for performance reasons.

    Args:
        flag (bool): True to set GIL profiling, False to disable.
      )");

  module.def(
      "_set_rpc_timeout",
      [](const float rpcTimeoutSeconds) {
        // Cast to the duration's own representation type rather than `int`:
        // a 32-bit int overflows once the timeout exceeds roughly 24 days'
        // worth of milliseconds.
        auto rpcTimeout = std::chrono::milliseconds(
            static_cast<std::chrono::milliseconds::rep>(
                rpcTimeoutSeconds * kSecToMsConversion));
        RpcAgent::getCurrentRpcAgent()->setRpcTimeout(rpcTimeout);
      },
      R"(
          Set the default timeout for all RPCs. The input unit is expected to be
          in seconds. If an RPC is not completed within this time, an exception
          indicating it has timed out will be raised. To control timeout for
          specific RPCs, a timeout parameter can be passed into
          :meth:`~torch.distributed.rpc.rpc_sync` and
          :meth:`~torch.distributed.rpc.rpc_async`.

          Args:
            rpcTimeoutSeconds (float): Timeout value in seconds.
      )");

  module.def(
      "_enable_server_process_global_profiler",
      &profiler::processglobal::enableServer);
  module.def(
      "_disable_server_process_global_profiler",
      &profiler::processglobal::disableServer);

  module.def("_set_profiler_node_id", &at::RecordFunction::setDefaultNodeId);

  // RemoteProfilerManager is bound with a py::nodelete holder, so Python never
  // destroys the instance it refers to.
  py::class_<
      RemoteProfilerManager,
      std::unique_ptr<RemoteProfilerManager, py::nodelete>>(
      module, "RemoteProfilerManager")
      .def("set_current_profiling_key", [](const std::string& key) {
        auto& inst = RemoteProfilerManager::getInstance();
        inst.setCurrentKey(key);
      });

  module.def(
      "_enable_jit_rref_pickle",
      &enableJitRRefPickle,
      R"(
        Allows ``torch.jit.save`` to save a ``torch.jit.ScriptModule`` with
        pickled RRefs out of RPC contexts.


        .. warning::
            This is dangerous. If the module contains RRefs, the pickled
            result must be sent over RPC and get unpickled on the receiving side
            to restore the module. Otherwise, there will be RRef leaks, which
            can potentially lead to program hang. When using this API, it is
            applications responsibility to make sure that the above assumption
            always holds.
      )");
  module.def("_disable_jit_rref_pickle", &disableJitRRefPickle);

  Py_RETURN_TRUE;
}
|
|
|
|
} // namespace
|
|
|
|
// Python method table for the RPC bindings: one ``_rpc_init`` entry
// (METH_NOARGS) that dispatches to rpc_init, followed by the all-null
// sentinel row that marks the end of the table.
static PyMethodDef methods[] = { // NOLINT
    {"_rpc_init", rpc_init, METH_NOARGS, nullptr},
    {nullptr, nullptr, 0, nullptr},
};
|
|
|
|
// Returns a pointer to the first entry of the file-local RPC method table.
PyMethodDef* python_functions() {
  PyMethodDef* const table = methods;
  return table;
}
|
|
|
|
} // namespace rpc
|
|
} // namespace distributed
|
|
} // namespace torch
|