mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/25527 Master GH issue: https://github.com/pytorch/pytorch/issues/23110. This change builds upon https://github.com/pytorch/pytorch/pull/24876 and provides all the autograd hooks needed for a forward pass with distributed rpc for builtin operators. This change does not address distributed rpc for python UDFs and that will be addressed in follow up PRs. Summary of changes: 1. Attach send autograd functions when a request is sent from the client and response is sent from the server. 2. Attach receive autograd functions when a request is received on the server and a response is received on the client. 3. Generate a globally unique autograd_message_id for each send/recv autograd function pair to uniquely identify them. ghstack-source-id: 91240466 Test Plan: unit tests. Differential Revision: D17148077 fbshipit-source-id: 192d8a3f552ed7cc939f55dcca332965c9bd3233
69 lines
2.1 KiB
C++
69 lines
2.1 KiB
C++
#include <torch/csrc/distributed/rpc/utils.h>
|
|
#include <torch/csrc/distributed/rpc/python_udf_call.h>
|
|
#include <torch/csrc/distributed/rpc/python_udf_resp.h>
|
|
#include <torch/csrc/distributed/rpc/rpc_with_autograd.h>
|
|
#include <torch/csrc/distributed/rpc/script_call.h>
|
|
#include <torch/csrc/distributed/rpc/script_remote_call.h>
|
|
#include <torch/csrc/distributed/rpc/script_resp.h>
|
|
#include <torch/csrc/distributed/rpc/script_rref_proto.h>
|
|
|
|
namespace torch {
|
|
namespace distributed {
|
|
namespace rpc {
|
|
|
|
std::unique_ptr<RpcCommandBase> deserializeRequest(const Message& request) {
|
|
switch (request.type()) {
|
|
case MessageType::SCRIPT_CALL: {
|
|
return ScriptCall::fromMessage(request);
|
|
}
|
|
case MessageType::PYTHON_CALL: {
|
|
return PythonUDFCall::fromMessage(request);
|
|
}
|
|
case MessageType::REMOTE_CALL: {
|
|
return ScriptRemoteCall::fromMessage(request);
|
|
}
|
|
case MessageType::RREF_FETCH_CALL: {
|
|
return ScriptRRefFetchCall::fromMessage(request);
|
|
}
|
|
case MessageType::RREF_USER_CREATE: {
|
|
return ScriptRRefCreate::fromMessage(request);
|
|
}
|
|
case MessageType::RREF_USER_DELETE: {
|
|
return ScriptRRefDelete::fromMessage(request);
|
|
}
|
|
case MessageType::MESSAGE_WITH_AUTOGRAD_REQ: {
|
|
return RpcWithAutograd::fromMessage(request);
|
|
}
|
|
default: {
|
|
TORCH_INTERNAL_ASSERT(
|
|
false, "Request type ", request.type(), " not supported.");
|
|
}
|
|
}
|
|
}
|
|
|
|
std::unique_ptr<RpcCommandBase> deserializeResponse(const Message& response) {
|
|
switch (response.type()) {
|
|
case MessageType::SCRIPT_RET: {
|
|
return ScriptResp::fromMessage(response);
|
|
}
|
|
case MessageType::PYTHON_RET: {
|
|
return PythonUDFResp::fromMessage(response);
|
|
}
|
|
case MessageType::EXCEPTION: {
|
|
std::string err(response.payload().begin(), response.payload().end());
|
|
throw std::runtime_error(err);
|
|
}
|
|
case MessageType::MESSAGE_WITH_AUTOGRAD_RESP: {
|
|
return RpcWithAutograd::fromMessage(response);
|
|
}
|
|
default: {
|
|
TORCH_INTERNAL_ASSERT(
|
|
false, "Response type ", response.type(), " not supported.");
|
|
}
|
|
}
|
|
}
|
|
|
|
} // namespace rpc
|
|
} // namespace distributed
|
|
} // namespace torch
|