Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44655

Since `toHere()` does not execute operations over RPC and simply transfers the value to the local node, there is no need to enable the profiler remotely for this message; doing so only adds overhead. Since `toHere()` is a blocking call, we already profile it on the local node using `RECORD_USER_SCOPE`, so this does not change the expected profiler results (validated by ensuring all remote profiling tests pass).

ghstack-source-id: 112605610

Test Plan: CI

Reviewed By: mrshenli

Differential Revision: D23641466

fbshipit-source-id: 109d9eb10bd7fe76122b2026aaf1c7893ad10588
60 lines
2.5 KiB
C++
#pragma once

#include <torch/csrc/distributed/autograd/context/context.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.h>

namespace torch {
namespace distributed {
namespace autograd {

// This method is used to attach the 'send' autograd function to the autograd
// graph when we use RPC. This method creates a new 'send' autograd function
// and attaches the provided tensors as next_edges to the 'send' function. In
// addition to this, it also registers the send function in the provided
// autograd context. Finally, the RPC message is updated with appropriate
// autograd information for the recipient.
TORCH_API void addSendRpcBackward(
    const ContextPtr& autogradContext,
    const AutogradMetadata& autogradMetadata,
    std::vector<torch::Tensor>& tensors);
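
// Illustrative usage sketch (not part of this header, and assuming the
// DistAutogradContainer header is included): the caller grabs the current
// distributed autograd context, builds AutogradMetadata for the outgoing
// message, and attaches the 'send' function to the tensors being shipped.
// The tensor values below are placeholders.
//
//   auto& container = DistAutogradContainer::getInstance();
//   ContextPtr ctx = container.currentContext();
//   AutogradMetadata meta(ctx->contextId(), container.newAutogradMessageId());
//   std::vector<torch::Tensor> tensors = {
//       torch::ones({2, 2}, torch::requires_grad())};
//   addSendRpcBackward(ctx, meta, tensors);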

// This method is used to attach the 'recv' autograd function to the autograd
// graph when we use RPC. This method creates a new 'recv' autograd function
// and attaches the provided tensors as inputs to the 'recv' function. It
// creates a new autograd context if needed and registers the 'recv' function
// with this context.
//
// Returns a pointer to the autograd context created.
TORCH_API ContextPtr addRecvRpcBackward(
    const AutogradMetadata& autogradMetadata,
    std::vector<torch::Tensor>& tensors,
    rpc::worker_id_t fromWorkerId);
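
// Illustrative usage sketch (not part of this header): on the callee, the
// metadata and tensors unpacked from an incoming RpcWithAutograd request are
// passed in together with the sender's worker id; the returned context is the
// one gradients will later be accumulated into. The accessor names below are
// assumed from RpcWithAutograd.
//
//   // RpcWithAutograd& rpcWithAutograd = ...;  (deserialized request)
//   ContextPtr ctx = addRecvRpcBackward(
//       rpcWithAutograd.autogradMetadata(),
//       rpcWithAutograd.tensors(),
//       rpcWithAutograd.fromWorkerId());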

// This method is an internal wrapper utility that attaches autograd
// information and the autograd function for any type of RPC call. If there is
// a valid autograd context and either the tensors require gradients or
// forceGradRecording is true, it returns an RpcWithAutograd message; otherwise
// it returns the original RPC message unchanged.
// NB: forceGradRecording is useful when the request does not contain any
// tensor but the corresponding response does.
TORCH_API rpc::Message getMessageWithAutograd(
    const rpc::worker_id_t dstId,
    rpc::Message&& wrappedRpcMsg,
    rpc::MessageType msgType,
    bool forceGradRecording = false);
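
// Illustrative usage sketch (not part of this header): wrapping an outgoing
// request before handing it to the RPC agent. MessageType::FORWARD_AUTOGRAD_REQ
// is the usual message type for autograd-wrapped requests; dstWorkerId and
// originalMessage are placeholders.
//
//   rpc::Message wrapped = getMessageWithAutograd(
//       dstWorkerId,
//       std::move(originalMessage),
//       rpc::MessageType::FORWARD_AUTOGRAD_REQ);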

// Send the message after performing the autograd checks above.
TORCH_API std::shared_ptr<torch::distributed::rpc::FutureMessage>
sendMessageWithAutograd(
    rpc::RpcAgent& agent,
    const rpc::WorkerInfo& dst,
    rpc::Message&& wrappedRpcMsg,
    bool forceGradRecording = false,
    const float rpcTimeoutSeconds = torch::distributed::rpc::kUnsetRpcTimeout,
    bool forceDisableProfiling = false);
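
// Illustrative usage sketch (not part of this header): sending an
// autograd-wrapped request through the current RPC agent. "worker1" and
// requestMessage are placeholders; the agent accessors are the standard
// RpcAgent ones.
//
//   auto agent = rpc::RpcAgent::getCurrentRpcAgent();
//   const auto& dst = agent->getWorkerInfo("worker1");
//   auto futureResponse = sendMessageWithAutograd(
//       *agent, dst, std::move(requestMessage), /*forceGradRecording=*/false);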

} // namespace autograd
} // namespace distributed
} // namespace torch