pytorch/torch/csrc/distributed/rpc/utils.cpp
Shen Li 400293fcc6 Support remote for builtin operators in distributed autograd (#28630)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28630

This includes:
1. Respect autograd context in rpc.remote for builtin ops
2. Force setting autograd context in RRef.to_here() even if the
message for to_here() does not contain any tensor.

Test Plan: Imported from OSS

Differential Revision: D18138562

Pulled By: mrshenli

fbshipit-source-id: a39ec83e556d19130f22eb317927241a017000ba
2019-10-29 19:39:00 -07:00

103 lines
3.5 KiB
C++

#include <torch/csrc/distributed/rpc/utils.h>
#include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_resp.h>
#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
#include <torch/csrc/distributed/rpc/python_remote_call.h>
#include <torch/csrc/distributed/rpc/python_udf_call.h>
#include <torch/csrc/distributed/rpc/python_udf_resp.h>
#include <torch/csrc/distributed/rpc/rref_proto.h>
#include <torch/csrc/distributed/rpc/script_call.h>
#include <torch/csrc/distributed/rpc/script_remote_call.h>
#include <torch/csrc/distributed/rpc/script_resp.h>
namespace torch {
namespace distributed {
namespace rpc {
// Rebuilds the RPC command object carried by an incoming request.
//
// Dispatches on the message's type tag and delegates to the matching
// fromMessage() factory. A request type without a case below trips
// TORCH_INTERNAL_ASSERT.
std::unique_ptr<RpcCommandBase> deserializeRequest(const Message& request) {
  const auto requestType = request.type();
  switch (requestType) {
    // Direct calls.
    case MessageType::SCRIPT_CALL:
      return ScriptCall::fromMessage(request);
    case MessageType::PYTHON_CALL:
      return PythonUDFCall::fromMessage(request);

    // Remote calls.
    case MessageType::SCRIPT_REMOTE_CALL:
      return ScriptRemoteCall::fromMessage(request);
    case MessageType::PYTHON_REMOTE_CALL:
      return PythonRemoteCall::fromMessage(request);

    // RRef fetch requests.
    case MessageType::SCRIPT_RREF_FETCH_CALL:
      return ScriptRRefFetchCall::fromMessage(request);
    case MessageType::PYTHON_RREF_FETCH_CALL:
      return PythonRRefFetchCall::fromMessage(request);

    // Other RRef protocol messages.
    case MessageType::RREF_USER_DELETE:
      return RRefUserDelete::fromMessage(request);
    case MessageType::RREF_CHILD_ACCEPT:
      return RRefChildAccept::fromMessage(request);
    case MessageType::RREF_FORK_REQUEST:
      return RRefForkRequest::fromMessage(request);

    // Requests handled by types in the autograd namespace.
    case MessageType::FORWARD_AUTOGRAD_REQ:
      return autograd::RpcWithAutograd::fromMessage(request);
    case MessageType::BACKWARD_AUTOGRAD_REQ:
      return autograd::PropagateGradientsReq::fromMessage(request);
    case MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ:
      return autograd::CleanupAutogradContextReq::fromMessage(request);

    default:
      TORCH_INTERNAL_ASSERT(
          false, "Request type ", requestType, " not supported.");
  }
}
// Rebuilds the RPC command object carried by a response message.
//
// Dispatches on the message's type tag and delegates to the matching
// fromMessage() factory. Two cases do not return a command:
//   - EXCEPTION throws std::runtime_error whose message is the payload bytes.
//   - A response type without a case below trips TORCH_INTERNAL_ASSERT.
std::unique_ptr<RpcCommandBase> deserializeResponse(const Message& response) {
  const auto responseType = response.type();
  switch (responseType) {
    // Return values.
    case MessageType::SCRIPT_RET:
      return ScriptResp::fromMessage(response);
    case MessageType::PYTHON_RET:
      return PythonUDFResp::fromMessage(response);
    case MessageType::REMOTE_RET:
      return RemoteRet::fromMessage(response);

    // RRef protocol responses.
    case MessageType::SCRIPT_RREF_FETCH_RET:
      return ScriptRRefFetchRet::fromMessage(response);
    case MessageType::RREF_ACK:
      return RRefAck::fromMessage(response);

    case MessageType::EXCEPTION: {
      // The payload bytes are used verbatim as the exception message.
      const auto& payload = response.payload();
      throw std::runtime_error(std::string(payload.begin(), payload.end()));
    }

    // Responses handled by types in the autograd namespace.
    case MessageType::FORWARD_AUTOGRAD_RESP:
      return autograd::RpcWithAutograd::fromMessage(response);
    // NOTE(review): this reuses the RpcWithAutograd factory from the forward
    // response, while the backward *request* is decoded as
    // PropagateGradientsReq. Confirm this is intended and not a copy-paste.
    case MessageType::BACKWARD_AUTOGRAD_RESP:
      return autograd::RpcWithAutograd::fromMessage(response);
    case MessageType::CLEANUP_AUTOGRAD_CONTEXT_RESP:
      return autograd::CleanupAutogradContextResp::fromMessage(response);

    default:
      TORCH_INTERNAL_ASSERT(
          false, "Response type ", responseType, " not supported.");
  }
}
} // namespace rpc
} // namespace distributed
} // namespace torch