#include #include #include #include #include #include #include #include #include #include #include #include namespace torch { namespace distributed { namespace rpc { using torch::distributed::autograd::DistAutogradContainer; using torch::distributed::autograd::DistAutogradContext; DistAutogradContainer* getDistAutogradContainer(); class TestE2EBase : public ::testing::Test { protected: void SetUp() override { // Setup distributed autograd. autogradContainer = getDistAutogradContainer(); // Setup server store. c10d::TCPStoreOptions opts{ /* port */ 0, /* isServer */ true, numWorkers, /* waitWorkers */ true, /* timeout */ std::chrono::seconds(10)}; store = c10::make_intrusive(serverAddress, opts); buildRpcAgent(); rpcAgentPostProcessing(); } void rpcAgentPostProcessing() { RpcAgent::setCurrentRpcAgent(rpcAgent); std::shared_ptr typeResolver = std::make_shared([&](const c10::QualifiedName& qn) { // For Dict that is used for device map. auto pos = qn.name().find("Dict"); if (pos != std::string::npos) { return c10::StrongTypePtr( nullptr, c10::DictType::create( c10::StringType::get(), c10::StringType::get())); } return c10::StrongTypePtr( nullptr, c10::TensorType::create(at::Tensor())); }); rpcAgent->setTypeResolver(typeResolver); rpcAgent->start(); } void TearDown() override { rpcAgent->join(); rpcAgent->shutdown(); RpcAgent::setCurrentRpcAgent(nullptr); } c10::intrusive_ptr createRemoteRRef( at::Tensor t1, at::Tensor t2, std::shared_ptr op) { auto& ctx = RRefContext::getInstance(); auto ownerRRef = ctx.createOwnerRRef(c10::TensorType::create(t1)); // prevent this owner RRef being deleted due to other forks ctx.addSelfAsFork(ownerRRef); ScriptRemoteCall scriptRemoteCall( op, {t1, t2, 1}, ownerRRef->rrefId(), ownerRRef->rrefId()); auto jitFuture = autograd::sendMessageWithAutograd( *rpcAgent, rpcAgent->getWorkerInfo("worker"), std::move(scriptRemoteCall).toMessage(), false); ownerRRef->registerOwnerCreationFuture(jitFuture); // Builtin operators does not return py::object, and hence does not require // GIL for destructing the potentially deleted OwerRRef. jitFuture->addCallback( [ownerRRefId = ownerRRef->rrefId()](JitFuture& jitFuture) { callback::finishCreatingOwnerRRef(jitFuture, ownerRRefId); }); return ownerRRef; } at::Tensor remoteAdd( at::Tensor t1, at::Tensor t2, std::shared_ptr op) { ScriptCall scriptCall(op, {t1, t2, /* alpha */ 1}); // Send the RPC and return result. auto response = autograd::sendMessageWithAutograd( *rpcAgent, rpcAgent->getWorkerInfo("worker"), std::move(scriptCall).toMessage()); response->waitAndThrow(); MessageType messageType = MessageType::FORWARD_AUTOGRAD_RESP; auto wrappedResponse = deserializeResponse( std::move(*response->value().toCustomClass()), messageType); return static_cast(*wrappedResponse).value().toTensor(); } virtual void buildRpcAgent() = 0; class AutogradContextGuard { public: explicit AutogradContextGuard() : context(DistAutogradContainer::getInstance().newContext()) {} ~AutogradContextGuard() { DistAutogradContainer::getInstance().releaseContext(context->contextId()); } private: std::shared_ptr context; }; void runTrainingLoop() { auto options = at::TensorOptions().requires_grad(true); auto t1 = torch::ones({3, 3}, options); auto t2 = torch::ones({3, 3}, options); c10::OperatorName full_name("aten::add", "Tensor"); auto matchedOp = torch::jit::findOperatorFor(full_name); ASSERT_TRUE(matchedOp); for (size_t i = 0; i < numIters; i++) { // Create the autograd context guard. AutogradContextGuard guard; // Multiple RPCs within one autograd context for the forward pass. auto result = remoteAdd(t1, t2, matchedOp); for (size_t j = 0; j < 5; j++) { result = remoteAdd(t1, result, matchedOp); } auto rref = createRemoteRRef(t1, result, matchedOp); result = rref->getValue().toTensor(); // Run backward pass now. autograd::DistEngine::getInstance().execute( DistAutogradContainer::currentContextId(), {torch::sum(result)}, /* retainGraph */ false); } } DistAutogradContainer* autogradContainer; std::shared_ptr rpcAgent; static const size_t numIters; static const size_t numWorkers; c10::intrusive_ptr store; static const char* serverAddress; }; } // namespace rpc } // namespace distributed } // namespace torch