pytorch/torch/csrc/distributed/autograd/utils.h
Rohan Varma d4a634c209 [RPC profiling] Don't wrap toHere() calls with profiling (#44655)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44655

Since `toHere()` does not execute operations over RPC and simply transfers the
value to the local node, we don't need to enable the profiler remotely for this
message; doing so only adds unnecessary overhead.

Since `toHere` is a blocking call, we already profile the call on the local node using `RECORD_USER_SCOPE`, so this does not change the expected profiler results (validated by ensuring all remote profiling tests pass).
ghstack-source-id: 112605610

Test Plan: CI

Reviewed By: mrshenli

Differential Revision: D23641466

fbshipit-source-id: 109d9eb10bd7fe76122b2026aaf1c7893ad10588
2020-09-22 21:17:00 -07:00

#pragma once

#include <torch/csrc/distributed/autograd/context/context.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.h>

namespace torch {
namespace distributed {
namespace autograd {

// This method is used to attach the 'send' autograd function to the autograd
// graph when we use RPC. This method creates a new 'send' autograd function
// and attaches the provided tensors as next_edges to the 'send' function. In
// addition to this, it also registers the send function in the provided
// autograd context. Finally, the RPC message is updated with appropriate
// autograd information for the recipient.
TORCH_API void addSendRpcBackward(
    const ContextPtr& autogradContext,
    const AutogradMetadata& autogradMetadata,
    std::vector<torch::Tensor>& tensors);
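
// Rough usage sketch for addSendRpcBackward (illustrative only; `tensorsToSend`
// is a placeholder, and the container accessors mirror how callers in the
// corresponding .cpp file drive this helper):
//   auto& container = DistAutogradContainer::getInstance();
//   ContextPtr ctx = container.currentContext();
//   AutogradMetadata meta(ctx->contextId(), container.newAutogradMessageId());
//   addSendRpcBackward(ctx, meta, tensorsToSend);
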
// This method is used to attach the 'recv' autograd function to the autograd
// graph when we use RPC. This method creates a new 'recv' autograd function
// and attaches the provided tensors as inputs to the 'recv' function. It
// creates a new autograd context if needed and registers the 'recv' function
// with this context.
//
// Returns a pointer to the autograd context created.
TORCH_API ContextPtr addRecvRpcBackward(
    const AutogradMetadata& autogradMetadata,
    std::vector<torch::Tensor>& tensors,
    rpc::worker_id_t fromWorkerId);
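
// Rough usage sketch for addRecvRpcBackward (illustrative only; assumes a
// deserialized RpcWithAutograd named `rpcWithAutograd` on the callee side):
//   ContextPtr ctx = addRecvRpcBackward(
//       rpcWithAutograd.autogradMetadata(),
//       rpcWithAutograd.tensors(),
//       rpcWithAutograd.fromWorkerId());
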
// This method is an internal wrapper utility that attaches autograd info and
// the 'send' autograd function for each type of rpc call. If there is a valid
// autograd context and either the tensors require grads or forceGradRecording
// is true, it returns an RpcWithAutograd message; otherwise it returns the
// original rpc message.
// NB: forceGradRecording is useful when the request does not contain any
// tensors but the corresponding response does.
TORCH_API rpc::Message getMessageWithAutograd(
    const rpc::worker_id_t dstId,
    rpc::Message&& wrappedRpcMsg,
    rpc::MessageType msgType,
    bool forceGradRecording = false);
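
// Rough usage sketch for getMessageWithAutograd (illustrative only; `dst` and
// `wrappedMsg` are placeholders):
//   auto msg = getMessageWithAutograd(
//       dst.id_, std::move(wrappedMsg), rpc::MessageType::FORWARD_AUTOGRAD_REQ);
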
// Sends the message to the destination worker, attaching autograd information
// via getMessageWithAutograd when applicable.
TORCH_API std::shared_ptr<torch::distributed::rpc::FutureMessage>
sendMessageWithAutograd(
    rpc::RpcAgent& agent,
    const rpc::WorkerInfo& dst,
    rpc::Message&& wrappedRpcMsg,
    bool forceGradRecording = false,
    const float rpcTimeoutSeconds = torch::distributed::rpc::kUnsetRpcTimeout,
    bool forceDisableProfiling = false);
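
// Rough usage sketch for sendMessageWithAutograd (illustrative only; `dstInfo`
// and `wrappedMsg` are placeholders):
//   auto futureMessage = sendMessageWithAutograd(
//       *rpc::RpcAgent::getCurrentRpcAgent(),
//       dstInfo,
//       std::move(wrappedMsg),
//       /*forceGradRecording=*/false);
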
} // namespace autograd
} // namespace distributed
} // namespace torch