#include <gtest/gtest.h>

#include <torch/csrc/distributed/autograd/context/container.h>
#include <torch/csrc/distributed/autograd/context/context.h>
#include <torch/csrc/distributed/autograd/engine/dist_engine.h>
#include <torch/csrc/distributed/autograd/utils.h>
#include <torch/csrc/distributed/c10d/TCPStore.hpp>
#include <torch/csrc/distributed/rpc/rref_context.h>
#include <torch/csrc/distributed/rpc/script_call.h>
#include <torch/csrc/distributed/rpc/script_remote_call.h>
#include <torch/csrc/distributed/rpc/script_resp.h>
#include <torch/csrc/distributed/rpc/utils.h>
#include <torch/csrc/jit/runtime/operator.h>

namespace torch {
namespace distributed {
namespace rpc {

using torch::distributed::autograd::DistAutogradContainer;
using torch::distributed::autograd::DistAutogradContext;

DistAutogradContainer* getDistAutogradContainer();

// Base test fixture for end-to-end RPC + distributed autograd tests.
// Concrete subclasses provide the RpcAgent via buildRpcAgent().
class TestE2EBase : public ::testing::Test {
 protected:
  void SetUp() override {
    // Setup distributed autograd.
    autogradContainer = getDistAutogradContainer();

    // Setup server store.
    c10d::TCPStoreOptions opts{
        /* port */ 0,
        /* isServer */ true,
        numWorkers,
        /* waitWorkers */ true,
        /* timeout */ std::chrono::seconds(10)};

    store = c10::make_intrusive<c10d::TCPStore>(serverAddress, opts);

    buildRpcAgent();

    rpcAgentPostProcessing();
  }

  void rpcAgentPostProcessing() {
    RpcAgent::setCurrentRpcAgent(rpcAgent);
    std::shared_ptr<TypeResolver> typeResolver =
        std::make_shared<TypeResolver>([&](const c10::QualifiedName& qn) {
          // For the Dict type that is used for the device map.
          auto pos = qn.name().find("Dict");
          if (pos != std::string::npos) {
            return c10::StrongTypePtr(
                nullptr,
                c10::DictType::create(
                    c10::StringType::get(), c10::StringType::get()));
          }
          return c10::StrongTypePtr(
              nullptr, c10::TensorType::create(at::Tensor()));
        });
    rpcAgent->setTypeResolver(typeResolver);
    rpcAgent->start();
  }

  void TearDown() override {
    rpcAgent->join();
    rpcAgent->shutdown();
    RpcAgent::setCurrentRpcAgent(nullptr);
  }

  c10::intrusive_ptr<OwnerRRef> createRemoteRRef(
      at::Tensor t1,
      at::Tensor t2,
      std::shared_ptr<torch::jit::Operator> op) {
    auto& ctx = RRefContext::getInstance();
    auto ownerRRef = ctx.createOwnerRRef(c10::TensorType::create(t1));
    // Prevent this owner RRef from being deleted due to other forks.
    ctx.addSelfAsFork(ownerRRef);

    ScriptRemoteCall scriptRemoteCall(
        op, {t1, t2, 1}, ownerRRef->rrefId(), ownerRRef->rrefId());
    auto jitFuture = autograd::sendMessageWithAutograd(
        *rpcAgent,
        rpcAgent->getWorkerInfo("worker"),
        std::move(scriptRemoteCall).toMessage(),
        false);

    ownerRRef->registerOwnerCreationFuture(jitFuture);

    // Builtin operators do not return py::object, and hence do not require
    // the GIL for destructing the potentially deleted OwnerRRef.
    jitFuture->addCallback(
        [ownerRRefId = ownerRRef->rrefId()](JitFuture& jitFuture) {
          callback::finishCreatingOwnerRRef(jitFuture, ownerRRefId);
        });
    return ownerRRef;
  }

  at::Tensor remoteAdd(
      at::Tensor t1,
      at::Tensor t2,
      std::shared_ptr<torch::jit::Operator> op) {
    ScriptCall scriptCall(op, {t1, t2, /* alpha */ 1});

    // Send the RPC and return the result.
    auto response = autograd::sendMessageWithAutograd(
        *rpcAgent,
        rpcAgent->getWorkerInfo("worker"),
        std::move(scriptCall).toMessage());
    response->waitAndThrow();

    MessageType messageType = MessageType::FORWARD_AUTOGRAD_RESP;
    auto wrappedResponse = deserializeResponse(
        std::move(*response->value().toCustomClass<Message>()), messageType);
    return static_cast<ScriptResp&>(*wrappedResponse).value().toTensor();
  }

  // Implemented by subclasses to construct the backend-specific RpcAgent and
  // assign it to `rpcAgent`.
  virtual void buildRpcAgent() = 0;

  // RAII helper that creates a fresh distributed autograd context and
  // releases it when the guard goes out of scope.
  class AutogradContextGuard {
   public:
    explicit AutogradContextGuard()
        : context(DistAutogradContainer::getInstance().newContext()) {}

    ~AutogradContextGuard() {
      DistAutogradContainer::getInstance().releaseContext(context->contextId());
    }

   private:
    std::shared_ptr<DistAutogradContext> context;
  };

  void runTrainingLoop() {
    auto options = at::TensorOptions().requires_grad(true);
    auto t1 = torch::ones({3, 3}, options);
    auto t2 = torch::ones({3, 3}, options);

    c10::OperatorName full_name("aten::add", "Tensor");
    auto matchedOp = torch::jit::findOperatorFor(full_name);
    ASSERT_TRUE(matchedOp);

    for (size_t i = 0; i < numIters; i++) {
      // Create the autograd context guard.
      AutogradContextGuard guard;

      // Multiple RPCs within one autograd context for the forward pass.
      auto result = remoteAdd(t1, t2, matchedOp);
      for (size_t j = 0; j < 5; j++) {
        result = remoteAdd(t1, result, matchedOp);
      }

      auto rref = createRemoteRRef(t1, result, matchedOp);
      result = rref->getValue().toTensor();

      // Run backward pass now.
      autograd::DistEngine::getInstance().execute(
          DistAutogradContainer::currentContextId(),
          {torch::sum(result)},
          /* retainGraph */ false);
    }
  }

  DistAutogradContainer* autogradContainer;
  std::shared_ptr<RpcAgent> rpcAgent;
  static const size_t numIters;
  static const size_t numWorkers;
  c10::intrusive_ptr<c10d::Store> store;
  static const char* serverAddress;
};
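
// Illustrative sketch (not from the original file): a concrete fixture is
// expected to subclass TestE2EBase, implement buildRpcAgent() to construct a
// backend-specific agent and assign it to `rpcAgent`, and then exercise
// runTrainingLoop() from a TEST_F. The names below (TestE2EMyBackend,
// MyBackendAgent) are hypothetical placeholders; the actual agent
// construction depends on the chosen backend and is therefore elided.
//
//   class TestE2EMyBackend : public TestE2EBase {
//    protected:
//     void buildRpcAgent() override {
//       // rpcAgent = std::make_shared<MyBackendAgent>(
//       //     /* store, worker name "worker", backend options ... */);
//     }
//   };
//
//   TEST_F(TestE2EMyBackend, TestTrainingLoop) {
//     runTrainingLoop();
//   }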

} // namespace rpc
} // namespace distributed
} // namespace torch