pytorch/torch/csrc/distributed/rpc/utils.cpp
Shen Li 043530a9b9 Support remote for Python UDF in distributed autograd
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/28656

Test Plan: Imported from OSS

Differential Revision: D18138561

Pulled By: mrshenli

fbshipit-source-id: 798e7c00465b5a299f7b4642683bc407895bc7da
2019-10-29 19:39:04 -07:00

106 lines
3.6 KiB
C++

#include <torch/csrc/distributed/rpc/utils.h>
#include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_resp.h>
#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
#include <torch/csrc/distributed/rpc/python_remote_call.h>
#include <torch/csrc/distributed/rpc/python_udf_call.h>
#include <torch/csrc/distributed/rpc/python_udf_resp.h>
#include <torch/csrc/distributed/rpc/rref_proto.h>
#include <torch/csrc/distributed/rpc/script_call.h>
#include <torch/csrc/distributed/rpc/script_remote_call.h>
#include <torch/csrc/distributed/rpc/script_resp.h>
namespace torch {
namespace distributed {
namespace rpc {
// Rebuilds the concrete RPC command carried by a request message.
//
// The message's type tag picks the command class; that class's static
// fromMessage() parses `request` and its result is returned as-is. A tag with
// no branch below fails TORCH_INTERNAL_ASSERT, with the tag in the message.
std::unique_ptr<RpcCommandBase> deserializeRequest(const Message& request) {
  const auto requestType = request.type();
  switch (requestType) {
    // SCRIPT_* / PYTHON_* call requests.
    case MessageType::SCRIPT_CALL:
      return ScriptCall::fromMessage(request);
    case MessageType::PYTHON_CALL:
      return PythonUDFCall::fromMessage(request);
    case MessageType::SCRIPT_REMOTE_CALL:
      return ScriptRemoteCall::fromMessage(request);
    case MessageType::PYTHON_REMOTE_CALL:
      return PythonRemoteCall::fromMessage(request);

    // *_RREF_FETCH_CALL and RREF_* requests.
    case MessageType::SCRIPT_RREF_FETCH_CALL:
      return ScriptRRefFetchCall::fromMessage(request);
    case MessageType::PYTHON_RREF_FETCH_CALL:
      return PythonRRefFetchCall::fromMessage(request);
    case MessageType::RREF_USER_DELETE:
      return RRefUserDelete::fromMessage(request);
    case MessageType::RREF_CHILD_ACCEPT:
      return RRefChildAccept::fromMessage(request);
    case MessageType::RREF_FORK_REQUEST:
      return RRefForkRequest::fromMessage(request);

    // *_AUTOGRAD_* requests, parsed by classes in the autograd namespace.
    case MessageType::FORWARD_AUTOGRAD_REQ:
      return autograd::RpcWithAutograd::fromMessage(request);
    case MessageType::BACKWARD_AUTOGRAD_REQ:
      return autograd::PropagateGradientsReq::fromMessage(request);
    case MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ:
      return autograd::CleanupAutogradContextReq::fromMessage(request);

    default:
      // The condition is the constant `false`, so reaching this branch always
      // fails the assert.
      TORCH_INTERNAL_ASSERT(
          false, "Request type ", requestType, " not supported.");
  }
}
// Rebuilds the concrete RPC command carried by a response message.
//
// The message's type tag picks the command class; that class's static
// fromMessage() parses `response` and its result is returned as-is.
//
// Throws std::runtime_error when the tag is MessageType::EXCEPTION; the
// exception text is the message payload copied byte-for-byte. A tag with no
// branch below fails TORCH_INTERNAL_ASSERT, with the tag in the message.
std::unique_ptr<RpcCommandBase> deserializeResponse(const Message& response) {
  const auto responseType = response.type();
  switch (responseType) {
    // SCRIPT_RET / PYTHON_RET / REMOTE_RET responses.
    case MessageType::SCRIPT_RET:
      return ScriptResp::fromMessage(response);
    case MessageType::PYTHON_RET:
      return PythonUDFResp::fromMessage(response);
    case MessageType::REMOTE_RET:
      return RemoteRet::fromMessage(response);

    // *_RREF_FETCH_RET and RREF_ACK responses.
    case MessageType::SCRIPT_RREF_FETCH_RET:
      return ScriptRRefFetchRet::fromMessage(response);
    case MessageType::PYTHON_RREF_FETCH_RET:
      return PythonRRefFetchRet::fromMessage(response);
    case MessageType::RREF_ACK:
      return RRefAck::fromMessage(response);

    case MessageType::EXCEPTION: {
      // No command is built for this tag: the payload bytes become the text
      // of the exception thrown to the caller.
      const auto& bytes = response.payload();
      throw std::runtime_error(std::string(bytes.begin(), bytes.end()));
    }

    // NOTE(review): BACKWARD_AUTOGRAD_RESP goes through the same
    // RpcWithAutograd parser as FORWARD_AUTOGRAD_RESP -- confirm that this is
    // intended rather than a copy of the branch above it.
    case MessageType::FORWARD_AUTOGRAD_RESP:
    case MessageType::BACKWARD_AUTOGRAD_RESP:
      return autograd::RpcWithAutograd::fromMessage(response);
    case MessageType::CLEANUP_AUTOGRAD_CONTEXT_RESP:
      return autograd::CleanupAutogradContextResp::fromMessage(response);

    default:
      // The condition is the constant `false`, so reaching this branch always
      // fails the assert.
      TORCH_INTERNAL_ASSERT(
          false, "Response type ", responseType, " not supported.");
  }
}
} // namespace rpc
} // namespace distributed
} // namespace torch