This PR introduces the **-Wmissing-prototypes** check to clang-tidy to prevent further coding errors such as the one fixed by PR #96714.

### 🤖 Generated by Copilot at fd2cf2a

This pull request makes several internal functions static to improve performance and avoid name clashes. It also fixes some typos, formatting, and missing includes in various files, and adds a new .clang-tidy check to warn about missing prototypes for non-static functions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96805
Approved by: https://github.com/malfet, https://github.com/albanD
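As a rough illustration of what the new check enforces (hypothetical function name, not taken from this PR; in clang-tidy terms this surfaces as the `clang-diagnostic-missing-prototypes` check):

```cpp
// Before (would trigger the warning): a definition with external linkage
// and no declaration in any header.
//   int helperChecksum(int x) { return x * 31; }

// After: `static` gives the function internal linkage, so no prototype is
// expected and the name cannot clash with another translation unit.
static int helperChecksum(int x) {
  return x * 31;
}
```

In this very file, `getMessageWithProfiling` below follows the same pattern: it is file-local and declared `static`.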
186 lines · 6.7 KiB · C++
#include <ATen/ThreadLocalState.h>
#include <c10/util/ThreadLocalDebugInfo.h>
#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/distributed/autograd/context/container.h>
#include <torch/csrc/distributed/autograd/functions/recvrpc_backward.h>
#include <torch/csrc/distributed/autograd/functions/sendrpc_backward.h>
#include <torch/csrc/distributed/autograd/utils.h>
#include <torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/types.h>

namespace torch {
namespace distributed {
namespace autograd {

using torch::distributed::autograd::AutogradMetadata;
using torch::distributed::autograd::RpcWithAutograd;
using torch::distributed::rpc::JitFuture;
using torch::distributed::rpc::Message;
using torch::distributed::rpc::MessageType;
using torch::distributed::rpc::RpcAgent;
using torch::distributed::rpc::WorkerInfo;
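
// Builds the 'send' half of a distributed autograd edge: collects the
// tensors that require grad, wires a SendRpcBackward node to their autograd
// graphs, and registers that node in the given context under the message id.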
void addSendRpcBackward(
    const ContextPtr& autogradContext,
    const AutogradMetadata& autogradMetadata,
    std::vector<torch::Tensor>& tensors) {
  // Attach autograd information only for tensors requiring grad.
  std::vector<torch::Tensor> tensors_with_grad;
  std::copy_if(
      tensors.begin(),
      tensors.end(),
      std::back_inserter(tensors_with_grad),
      [](const torch::Tensor& t) { return t.requires_grad(); });

  // Attach the appropriate autograd edges.
  auto grad_fn = std::make_shared<SendRpcBackward>();
  grad_fn->set_next_edges(
      torch::autograd::collect_next_edges(tensors_with_grad));

  // Add the appropriate input metadata for the grad_fn.
  for (const auto& tensor : tensors_with_grad) {
    grad_fn->add_input_metadata(tensor);
  }

  // Record the send autograd function in our current context.
  autogradContext->addSendFunction(grad_fn, autogradMetadata.autogradMessageId);
}
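
// Receiving-side counterpart of addSendRpcBackward: fetches or creates the
// autograd context named in the metadata and, if any incoming tensor
// requires grad, points that tensor's history at a RecvRpcBackward node so
// gradients can flow back to the worker identified by 'fromWorkerId'.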
ContextPtr addRecvRpcBackward(
    const AutogradMetadata& autogradMetadata,
    std::vector<torch::Tensor>& tensors,
    rpc::worker_id_t fromWorkerId,
    const rpc::DeviceMap& deviceMap) {
  // Initialize autograd context if necessary.
  auto& autogradContainer = DistAutogradContainer::getInstance();
  auto autogradContext =
      autogradContainer.getOrCreateContext(autogradMetadata.autogradContextId);

  if (!tensors.empty() && torch::autograd::compute_requires_grad(tensors)) {
    // Attach the tensors as inputs to the autograd function.
    auto grad_fn = std::make_shared<RecvRpcBackward>(
        autogradMetadata, autogradContext, fromWorkerId, deviceMap);
    for (auto& tensor : tensors) {
      if (tensor.requires_grad()) {
        torch::autograd::set_history(tensor, grad_fn);
      }
    }

    // Now update the autograd context with the necessary information.
    autogradContext->addRecvFunction(
        grad_fn, autogradMetadata.autogradMessageId);
  }

  return autogradContext;
}
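
// Wraps an outgoing message in an RpcWithProfilingReq so the remote end
// runs the request under the profiler, keyed by a globally unique id.
// Declared static since it has no callers outside this translation unit,
// which is exactly the pattern the -Wmissing-prototypes check enforces.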
static c10::intrusive_ptr<Message> getMessageWithProfiling(
    c10::intrusive_ptr<torch::distributed::rpc::Message> wrappedRpcMessage,
    MessageType msgType,
    torch::autograd::profiler::ProfilerConfig&& profilerConfig) {
  auto& remoteProfilerManager =
      torch::distributed::rpc::RemoteProfilerManager::getInstance();

  auto key = remoteProfilerManager.getCurrentProfilingKey();
  // Generate a globally unique id.
  auto globallyUniqueProfilingId = remoteProfilerManager.getNextProfilerId();
  // Save a mapping of ID -> RPC profiling key and unset the current TLS key.
  remoteProfilerManager.saveRPCKey(globallyUniqueProfilingId, key);
  remoteProfilerManager.unsetCurrentKey();
  auto wrappedProfilingMsg = RpcWithProfilingReq(
      msgType,
      std::move(wrappedRpcMessage),
      std::move(profilerConfig),
      globallyUniqueProfilingId);

  return std::move(wrappedProfilingMsg).toMessage();
}
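
// Attaches autograd metadata to wrappedRpcMsg when a valid distributed
// autograd context exists and either some tensor requires grad or
// forceGradRecording is set; otherwise returns the original message
// unchanged.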
c10::intrusive_ptr<Message> getMessageWithAutograd(
    const rpc::worker_id_t dstId,
    c10::intrusive_ptr<torch::distributed::rpc::Message> wrappedRpcMsg,
    MessageType msgType,
    bool forceGradRecording,
    const rpc::DeviceMap& deviceMap) {
  auto& autogradContainer = DistAutogradContainer::getInstance();

  // If there is no valid context and no tensor requires grads, send the
  // original rpc message. Otherwise, attach grad info and grad functions
  // and send an rpcWithAutograd message.
  auto tensorsRequireGrad =
      torch::autograd::compute_requires_grad(wrappedRpcMsg->tensors());
  if (!autogradContainer.hasValidContext() ||
      (!forceGradRecording && !tensorsRequireGrad)) {
    return wrappedRpcMsg;
  }

  // Retrieve the appropriate context to modify.
  auto autogradContext = autogradContainer.currentContext();

  // Wrap the original rpc with autograd information.
  AutogradMetadata autogradMetadata(
      autogradContext->contextId(), autogradContainer.newAutogradMessageId());
  auto rpcWithAutograd = std::make_unique<RpcWithAutograd>(
      RpcAgent::getCurrentRpcAgent()->getWorkerInfo().id_,
      msgType,
      autogradMetadata,
      std::move(wrappedRpcMsg),
      deviceMap);

  if (tensorsRequireGrad) {
    // Record autograd information for 'send'.
    addSendRpcBackward(
        autogradContext, autogradMetadata, rpcWithAutograd->tensors());
  }
  // Record the workerID.
  autogradContext->addKnownWorkerId(dstId);

  return std::move(*rpcWithAutograd).toMessage();
}
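
// Top-level send path: attaches autograd metadata (and, under the legacy
// profiler, profiling metadata) to wrappedRpcMsg and dispatches it through
// the given agent, returning a future for the response.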
c10::intrusive_ptr<JitFuture> sendMessageWithAutograd(
    RpcAgent& agent,
    const WorkerInfo& dst,
    c10::intrusive_ptr<torch::distributed::rpc::Message> wrappedRpcMsg,
    bool forceGradRecording,
    const float rpcTimeoutSeconds,
    bool forceDisableProfiling) {
  auto msg = getMessageWithAutograd(
      dst.id_,
      std::move(wrappedRpcMsg),
      MessageType::FORWARD_AUTOGRAD_REQ,
      forceGradRecording,
      agent.getDeviceMap(dst));

  // If profiler is enabled, wrap this message with profiling metadata that
  // will tell the remote end to process this request with the profiler
  // enabled.
  if (!forceDisableProfiling) {
    switch (torch::profiler::impl::profilerType()) {
      case torch::profiler::impl::ActiveProfilerType::LEGACY: {
        auto profilerConfig = torch::autograd::profiler::getProfilerConfig();
        auto msgWithProfiling = getMessageWithProfiling(
            std::move(msg),
            rpc::MessageType::RUN_WITH_PROFILING_REQ,
            std::move(profilerConfig));
        return agent.send(dst, std::move(msgWithProfiling), rpcTimeoutSeconds);
      }
      case torch::profiler::impl::ActiveProfilerType::KINETO:
        TORCH_WARN_ONCE(
            "Profiling a distributed call with the Kineto profiler will profile "
            "the caller, but not the worker.");
        break;
      default:
        break;
    }
  }

  return agent.send(dst, std::move(msg), rpcTimeoutSeconds);
}

} // namespace autograd
} // namespace distributed
} // namespace torch