#pragma once

#include <torch/csrc/distributed/autograd/context/context.h>
#include <torch/csrc/distributed/autograd/rpc_messages/autograd_metadata.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>

namespace torch::distributed::autograd {

// This method is used to attach the 'send' autograd function to the autograd
// graph when we use RPC. It creates a new 'send' autograd function and
// attaches the provided tensors as next_edges to the 'send' function. In
// addition, it registers the 'send' function in the provided autograd
// context. Finally, the RPC message is updated with the appropriate autograd
// information for the recipient.
TORCH_API void addSendRpcBackward(
    const ContextPtr& autogradContext,
    const AutogradMetadata& autogradMetadata,
    std::vector<torch::Tensor>& tensors);

// This method is used to attach the 'recv' autograd function to the autograd
// graph when we use RPC. It creates a new 'recv' autograd function and
// attaches the provided tensors as inputs to the 'recv' function. It creates
// a new autograd context if needed and registers the 'recv' function with
// this context.
//
// Returns a pointer to the autograd context created.
TORCH_API ContextPtr addRecvRpcBackward(
    const AutogradMetadata& autogradMetadata,
    std::vector<torch::Tensor>& tensors,
    rpc::worker_id_t fromWorkerId,
    const rpc::DeviceMap& deviceMap);

// This method is a wrapper utility used internally to wrap autograd info and
// attach the autograd function for each type of rpc call. If there is a valid
// context and the tensors require grads, or forceGradRecording is true, it
// returns an RpcWithAutograd message; otherwise it returns the original rpc
// message.
// NB: forceGradRecording is useful when the request does not contain any
// tensors but the corresponding response does.
TORCH_API c10::intrusive_ptr<rpc::Message> getMessageWithAutograd(
    const rpc::worker_id_t dstId,
    c10::intrusive_ptr<rpc::Message> wrappedRpcMsg,
    rpc::MessageType msgType,
    bool forceGradRecording = false,
    const rpc::DeviceMap& deviceMap = {});

// Send message after autograd checking.
TORCH_API c10::intrusive_ptr<c10::ivalue::Future> sendMessageWithAutograd(
    rpc::RpcAgent& agent,
    const rpc::WorkerInfo& dst,
    c10::intrusive_ptr<rpc::Message> wrappedRpcMsg,
    bool forceGradRecording = false,
    const float rpcTimeoutSeconds = torch::distributed::rpc::kUnsetRpcTimeout,
    bool forceDisableProfiling = false);

} // namespace torch::distributed::autograd
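
// Example (illustrative sketch, caller side): how an RPC layer could use
// sendMessageWithAutograd when issuing a request. The wrapper internally
// calls getMessageWithAutograd to decide whether to wrap the request in an
// RpcWithAutograd message before handing it to the agent. The function name
// sendUserRequest and its arguments below are hypothetical placeholders.
//
//   #include <torch/csrc/distributed/autograd/utils.h>
//
//   using namespace torch::distributed;
//
//   c10::intrusive_ptr<c10::ivalue::Future> sendUserRequest(
//       rpc::RpcAgent& agent,
//       const rpc::WorkerInfo& dst,
//       c10::intrusive_ptr<rpc::Message> request) {
//     // Wraps 'request' with autograd metadata only when a distributed
//     // autograd context is active and a tensor in the message requires
//     // grad (or forceGradRecording is set), then sends it via the agent.
//     return autograd::sendMessageWithAutograd(
//         agent, dst, std::move(request), /*forceGradRecording=*/false);
//   }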
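
// Example (illustrative sketch, callee side): how a request handler could use
// addRecvRpcBackward after deserializing an RpcWithAutograd request, so that
// the received tensors are connected to the local autograd graph. The
// variable names below are hypothetical placeholders.
//
//   // 'rpcWithAutograd' is a deserialized RpcWithAutograd request and
//   // 'fromWorkerId' identifies the caller.
//   auto& tensors = rpcWithAutograd.tensors();
//   ContextPtr ctx = addRecvRpcBackward(
//       rpcWithAutograd.autogradMetadata(),
//       tensors,
//       fromWorkerId,
//       rpcWithAutograd.deviceMap());
//   // 'ctx' is the autograd context under which the wrapped request
//   // should now be processed.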