mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/27951 We want to clean up the distributed autograd context across the other nodes when a single node is done (here, "done" means having exited the context manager `with dist_autograd.context() as context_id: ...`). This PR does a few things to implement the above: 1) Adds classes to encapsulate the messages for requesting this context release and the corresponding response. 2) Handles this request in `request_callback_impl.cpp`: when we receive this request, we look up the context for the given context_id and release it. 3) Adds an RPC call in `DistAutogradContainer::releaseContext` to send this command. This currently does not wait for an ack or implement any sort of retrying. We send the RPC to all the workerIds we have come into contact with (implemented in https://github.com/pytorch/pytorch/pull/26324). 4) Adds the relevant unit tests. In follow-up PRs, we will add error checking + retries for this call. ghstack-source-id: 92269279 Test Plan: Added/modified unit tests in `test/dist_autograd_test.py` Differential Revision: D17920137 fbshipit-source-id: 7403512ab5fcbc28d21c548b2e45319dd472e26a
103 lines
3.5 KiB
C++
103 lines
3.5 KiB
C++
#include <torch/csrc/distributed/rpc/utils.h>
|
|
#include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_req.h>
|
|
#include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_resp.h>
|
|
#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.h>
|
|
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
|
|
#include <torch/csrc/distributed/rpc/python_remote_call.h>
|
|
#include <torch/csrc/distributed/rpc/python_udf_call.h>
|
|
#include <torch/csrc/distributed/rpc/python_udf_resp.h>
|
|
#include <torch/csrc/distributed/rpc/rref_proto.h>
|
|
#include <torch/csrc/distributed/rpc/script_call.h>
|
|
#include <torch/csrc/distributed/rpc/script_remote_call.h>
|
|
#include <torch/csrc/distributed/rpc/script_resp.h>
|
|
|
|
namespace torch {
|
|
namespace distributed {
|
|
namespace rpc {
|
|
|
|
std::unique_ptr<RpcCommandBase> deserializeRequest(const Message& request) {
|
|
switch (request.type()) {
|
|
case MessageType::SCRIPT_CALL: {
|
|
return ScriptCall::fromMessage(request);
|
|
}
|
|
case MessageType::PYTHON_CALL: {
|
|
return PythonUDFCall::fromMessage(request);
|
|
}
|
|
case MessageType::SCRIPT_REMOTE_CALL: {
|
|
return ScriptRemoteCall::fromMessage(request);
|
|
}
|
|
case MessageType::PYTHON_REMOTE_CALL: {
|
|
return PythonRemoteCall::fromMessage(request);
|
|
}
|
|
case MessageType::SCRIPT_RREF_FETCH_CALL: {
|
|
return ScriptRRefFetchCall::fromMessage(request);
|
|
}
|
|
case MessageType::PYTHON_RREF_FETCH_CALL: {
|
|
return PythonRRefFetchCall::fromMessage(request);
|
|
}
|
|
case MessageType::RREF_USER_DELETE: {
|
|
return RRefUserDelete::fromMessage(request);
|
|
}
|
|
case MessageType::RREF_CHILD_ACCEPT: {
|
|
return RRefChildAccept::fromMessage(request);
|
|
}
|
|
case MessageType::RREF_FORK_REQUEST: {
|
|
return RRefForkRequest::fromMessage(request);
|
|
}
|
|
case MessageType::FORWARD_AUTOGRAD_REQ: {
|
|
return autograd::RpcWithAutograd::fromMessage(request);
|
|
}
|
|
case MessageType::BACKWARD_AUTOGRAD_REQ: {
|
|
return autograd::PropagateGradientsReq::fromMessage(request);
|
|
}
|
|
case MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ: {
|
|
return autograd::CleanupAutogradContextReq::fromMessage(request);
|
|
}
|
|
default: {
|
|
TORCH_INTERNAL_ASSERT(
|
|
false, "Request type ", request.type(), " not supported.");
|
|
}
|
|
}
|
|
}
|
|
|
|
std::unique_ptr<RpcCommandBase> deserializeResponse(const Message& response) {
|
|
switch (response.type()) {
|
|
case MessageType::SCRIPT_RET: {
|
|
return ScriptResp::fromMessage(response);
|
|
}
|
|
case MessageType::PYTHON_RET: {
|
|
return PythonUDFResp::fromMessage(response);
|
|
}
|
|
case MessageType::REMOTE_RET: {
|
|
return RemoteRet::fromMessage(response);
|
|
}
|
|
case MessageType::RREF_FETCH_RET: {
|
|
return RRefFetchRet::fromMessage(response);
|
|
}
|
|
case MessageType::RREF_ACK: {
|
|
return RRefAck::fromMessage(response);
|
|
}
|
|
case MessageType::EXCEPTION: {
|
|
std::string err(response.payload().begin(), response.payload().end());
|
|
throw std::runtime_error(err);
|
|
}
|
|
case MessageType::FORWARD_AUTOGRAD_RESP: {
|
|
return autograd::RpcWithAutograd::fromMessage(response);
|
|
}
|
|
case MessageType::BACKWARD_AUTOGRAD_RESP: {
|
|
return autograd::RpcWithAutograd::fromMessage(response);
|
|
}
|
|
case MessageType::CLEANUP_AUTOGRAD_CONTEXT_RESP: {
|
|
return autograd::CleanupAutogradContextResp::fromMessage(response);
|
|
}
|
|
default: {
|
|
TORCH_INTERNAL_ASSERT(
|
|
false, "Response type ", response.type(), " not supported.");
|
|
}
|
|
}
|
|
}
|
|
|
|
} // namespace rpc
|
|
} // namespace distributed
|
|
} // namespace torch
|