pytorch/torch/csrc/distributed/rpc/utils.cpp
Pritam Damania fe4170bda8 Add send and recv backward functions for builtin operators RPC. (#25527)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25527

Master GH issue: https://github.com/pytorch/pytorch/issues/23110.

This change builds upon https://github.com/pytorch/pytorch/pull/24876 and
provides all the autograd hooks needed for a forward pass with distributed RPC
for builtin operators. This change does not address distributed RPC for Python
UDFs; that will be addressed in follow-up PRs.

Summary of changes:
1. Attach send autograd functions when a request is sent from the client and
a response is sent from the server.
2. Attach receive autograd functions when a request is received on the server
and a response is received on the client.
3. Generate a globally unique autograd_message_id for each send/recv autograd
function pair to uniquely identify them.
ghstack-source-id: 91240466

Test Plan: unit tests.

Differential Revision: D17148077

fbshipit-source-id: 192d8a3f552ed7cc939f55dcca332965c9bd3233
2019-10-03 01:18:46 -07:00

69 lines
2.1 KiB
C++

#include <torch/csrc/distributed/rpc/utils.h>
#include <torch/csrc/distributed/rpc/python_udf_call.h>
#include <torch/csrc/distributed/rpc/python_udf_resp.h>
#include <torch/csrc/distributed/rpc/rpc_with_autograd.h>
#include <torch/csrc/distributed/rpc/script_call.h>
#include <torch/csrc/distributed/rpc/script_remote_call.h>
#include <torch/csrc/distributed/rpc/script_resp.h>
#include <torch/csrc/distributed/rpc/script_rref_proto.h>
namespace torch {
namespace distributed {
namespace rpc {
// Dispatches on the type of `request` and hands the message to the matching
// command class's fromMessage() factory, returning whatever that factory
// produces. A message type with no case below fails a TORCH_INTERNAL_ASSERT.
std::unique_ptr<RpcCommandBase> deserializeRequest(const Message& request) {
  const auto kind = request.type();
  switch (kind) {
    case MessageType::SCRIPT_CALL:
      return ScriptCall::fromMessage(request);
    case MessageType::PYTHON_CALL:
      return PythonUDFCall::fromMessage(request);
    case MessageType::REMOTE_CALL:
      return ScriptRemoteCall::fromMessage(request);
    case MessageType::RREF_FETCH_CALL:
      return ScriptRRefFetchCall::fromMessage(request);
    case MessageType::RREF_USER_CREATE:
      return ScriptRRefCreate::fromMessage(request);
    case MessageType::RREF_USER_DELETE:
      return ScriptRRefDelete::fromMessage(request);
    case MessageType::MESSAGE_WITH_AUTOGRAD_REQ:
      return RpcWithAutograd::fromMessage(request);
    default:
      break;
  }
  // Only reached when none of the request types above matched.
  TORCH_INTERNAL_ASSERT(false, "Request type ", kind, " not supported.");
}
// Dispatches on the type of `response` and hands the message to the matching
// response class's fromMessage() factory, returning whatever that factory
// produces. An EXCEPTION message is not returned: its payload bytes are copied
// into a string and rethrown as std::runtime_error. A message type with no
// case below fails a TORCH_INTERNAL_ASSERT.
std::unique_ptr<RpcCommandBase> deserializeResponse(const Message& response) {
  const auto kind = response.type();
  switch (kind) {
    case MessageType::SCRIPT_RET:
      return ScriptResp::fromMessage(response);
    case MessageType::PYTHON_RET:
      return PythonUDFResp::fromMessage(response);
    case MessageType::MESSAGE_WITH_AUTOGRAD_RESP:
      return RpcWithAutograd::fromMessage(response);
    case MessageType::EXCEPTION: {
      // The payload is used verbatim as the exception message.
      const auto& bytes = response.payload();
      throw std::runtime_error(std::string(bytes.begin(), bytes.end()));
    }
    default:
      break;
  }
  // Only reached when none of the response types above matched.
  TORCH_INTERNAL_ASSERT(false, "Response type ", kind, " not supported.");
}
} // namespace rpc
} // namespace distributed
} // namespace torch