#include #include #include #include #include #include #include #include #include #include #include #include namespace torch { namespace distributed { namespace rpc { std::unique_ptr deserializeRequest(const Message& request) { switch (request.type()) { case MessageType::SCRIPT_CALL: { return ScriptCall::fromMessage(request); } case MessageType::PYTHON_CALL: { return PythonCall::fromMessage(request); } case MessageType::SCRIPT_REMOTE_CALL: { return ScriptRemoteCall::fromMessage(request); } case MessageType::PYTHON_REMOTE_CALL: { return PythonRemoteCall::fromMessage(request); } case MessageType::SCRIPT_RREF_FETCH_CALL: { return ScriptRRefFetchCall::fromMessage(request); } case MessageType::PYTHON_RREF_FETCH_CALL: { return PythonRRefFetchCall::fromMessage(request); } case MessageType::RREF_USER_DELETE: { return RRefUserDelete::fromMessage(request); } case MessageType::RREF_CHILD_ACCEPT: { return RRefChildAccept::fromMessage(request); } case MessageType::RREF_FORK_REQUEST: { return RRefForkRequest::fromMessage(request); } case MessageType::FORWARD_AUTOGRAD_REQ: { return autograd::RpcWithAutograd::fromMessage(request); } case MessageType::BACKWARD_AUTOGRAD_REQ: { return autograd::PropagateGradientsReq::fromMessage(request); } case MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ: { return autograd::CleanupAutogradContextReq::fromMessage(request); } default: { TORCH_INTERNAL_ASSERT( false, "Request type ", request.type(), " not supported."); } } } std::unique_ptr deserializeResponse(const Message& response) { switch (response.type()) { case MessageType::SCRIPT_RET: { return ScriptResp::fromMessage(response); } case MessageType::PYTHON_RET: { return PythonResp::fromMessage(response); } case MessageType::REMOTE_RET: { return RemoteRet::fromMessage(response); } case MessageType::SCRIPT_RREF_FETCH_RET: { return ScriptRRefFetchRet::fromMessage(response); } case MessageType::PYTHON_RREF_FETCH_RET: { return PythonRRefFetchRet::fromMessage(response); } case MessageType::RREF_ACK: { return RRefAck::fromMessage(response); } case MessageType::EXCEPTION: { std::string err(response.payload().begin(), response.payload().end()); throw std::runtime_error(err); } case MessageType::FORWARD_AUTOGRAD_RESP: { return autograd::RpcWithAutograd::fromMessage(response); } case MessageType::BACKWARD_AUTOGRAD_RESP: { return autograd::RpcWithAutograd::fromMessage(response); } case MessageType::CLEANUP_AUTOGRAD_CONTEXT_RESP: { return autograd::CleanupAutogradContextResp::fromMessage(response); } default: { TORCH_INTERNAL_ASSERT( false, "Response type ", response.type(), " not supported."); } } } } // namespace rpc } // namespace distributed } // namespace torch