#include <torch/csrc/distributed/rpc/rref_impl.h>

#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
#include <torch/csrc/distributed/autograd/utils.h>
#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
#include <torch/csrc/distributed/rpc/rref_context.h>
#include <torch/csrc/distributed/rpc/rref_proto.h>
#include <torch/csrc/distributed/rpc/utils.h>

namespace torch {
namespace distributed {
namespace rpc {

namespace {

constexpr int OWNER_IDX = 0; // index of ownerId in the tuple
constexpr int RREFID_ON_IDX = 1; // index of RRefId.createdOn_ in the tuple
constexpr int RREFID_ID_IDX = 2; // index of RRefId.localId_ in the tuple
constexpr int FORKID_ON_IDX = 3; // index of ForkId.createdOn_ in the tuple
constexpr int FORKID_ID_IDX = 4; // index of ForkId.localId_ in the tuple
constexpr int PARENT_IDX = 5; // index of parent in the tuple

// NB: if more fields are added, make sure this field is also bumped
constexpr int RFD_TUPLE_SIZE = 6; // number of RRefForkData fields in py::tuple

// If the response was sent through the distributed autograd machinery, attach
// the 'recv' autograd function and unwrap the nested RPC; otherwise, cast the
// response directly to the requested type.
template <typename T>
T& unwrapAutogradMessage(
    const Message& message,
    std::unique_ptr<RpcCommandBase>& response) {
  if (message.type() == MessageType::FORWARD_AUTOGRAD_RESP) {
    auto& rpcWithAutograd = static_cast<autograd::RpcWithAutograd&>(*response);

    // Attach 'recv' autograd function.
    addRecvRpcBackward(
        rpcWithAutograd.autogradMetadata(),
        rpcWithAutograd.tensors(),
        rpcWithAutograd.fromWorkerId());

    auto& wrappedRpc = rpcWithAutograd.wrappedRpc();
    return static_cast<T&>(wrappedRpc);
  } else {
    return static_cast<T&>(*response);
  }
}

} // namespace

std::atomic<local_id_t> RRefContext::nextLocalId_{0};

//////////////////////////  RRefForkData  /////////////////////////////////

RRefForkData::RRefForkData(
    worker_id_t ownerId,
    const RRefId& rrefId,
    const ForkId& forkId,
    worker_id_t parent)
    : ownerId_(ownerId), rrefId_(rrefId), forkId_(forkId), parent_(parent) {}

py::tuple RRefForkData::toPyTuple() const {
  return py::make_tuple(
      ownerId_,
      rrefId_.createdOn_,
      rrefId_.localId_,
      forkId_.createdOn_,
      forkId_.localId_,
      parent_);
}

RRefForkData RRefForkData::fromPyTuple(const py::tuple& t) {
  TORCH_INTERNAL_ASSERT(
      t.size() == RFD_TUPLE_SIZE,
      "Pickled RRefForkData must contain 6 numbers.");
  worker_id_t ownerId = t[OWNER_IDX].cast<worker_id_t>();
  // const reference will extend the lifetime of the temporary variable
  const RRefId& rrefId = RRefId(
      t[RREFID_ON_IDX].cast<worker_id_t>(),
      t[RREFID_ID_IDX].cast<local_id_t>());
  const RRefId& forkId = RRefId(
      t[FORKID_ON_IDX].cast<worker_id_t>(),
      t[FORKID_ID_IDX].cast<local_id_t>());
  worker_id_t parent = t[PARENT_IDX].cast<worker_id_t>();
  return RRefForkData(ownerId, rrefId, forkId, parent);
}

//////////////////////////////  RRef  /////////////////////////////////////

RRef::RRef(worker_id_t ownerId, const RRefId& rrefId)
    : RRefInterface(), ownerId_(ownerId), rrefId_(rrefId) {}

RRefForkData RRef::fork() const {
  auto& ctx = RRefContext::getInstance();
  return RRefForkData(
      ownerId_, rrefId_, ctx.genGloballyUniqueId(), ctx.getWorkerId());
}

//////////////////////////  UserRRef  /////////////////////////////////////

template <typename T>
UserRRef<T>::UserRRef(
    worker_id_t ownerId,
    const RRefId& rrefId,
    const ForkId& forkId)
    : RRef(ownerId, rrefId), forkId_(forkId) {
  // Do nothing:
  // (1) If this UserRRef is a fork of an existing RRef, RRefContext will send
  //     a RREF_FORK_REQUEST message to the owner.
  // (2) If this is the creator UserRRef, ScriptRemoteCall or PythonRemoteCall
  //     will properly notify the owner.
}

template <typename T>
UserRRef<T>::~UserRRef() {
  try {
    RRefContext::getInstance().delUser(ownerId_, rrefId_, forkId_);
  } catch (const std::exception& ex) {
    LOG(ERROR) << "Error occurred when deleting UserRRef instance, "
               << "RRefId = " << rrefId_ << ", ForkId = " << forkId_ << " : "
               << ex.what();
  } catch (...) {
    LOG(ERROR) << "Error occurred when deleting UserRRef instance, "
               << "RRefId = " << rrefId_ << ", ForkId = " << forkId_ << " : "
               << "unknown error";
  }
}

template <typename T>
const ForkId& UserRRef<T>::forkId() const {
  return forkId_;
}

template <>
IValue UserRRef<IValue>::toHere() {
  auto agent = RpcAgent::getDefaultRpcAgent();

  // ScriptRRefFetchCall message always carries autograd context id even if
  // the message itself does not contain any tensor, because the response
  // could potentially contain tensors.
  auto futureResponse = autograd::sendMessageWithAutograd(
      *agent,
      agent->getWorkerInfo(ownerId_),
      ScriptRRefFetchCall(ownerId_, rrefId()).toMessage(),
      true /* forceGradRecording */);

  const Message& message = futureResponse->wait();
  auto response = deserializeResponse(message);
  auto& rfr = unwrapAutogradMessage<RRefFetchRet>(message, response);
  return rfr.values().front();
}

template <>
py::object UserRRef<py::object>::toHere() {
  auto agent = RpcAgent::getDefaultRpcAgent();

  // PythonRRefFetchCall message always carries autograd context id even if
  // the message itself does not contain any tensor, because the response
  // could potentially contain tensors.
  auto futureResponse = autograd::sendMessageWithAutograd(
      *agent,
      agent->getWorkerInfo(ownerId_),
      PythonRRefFetchCall(ownerId_, rrefId()).toMessage(),
      true /* forceGradRecording */);

  const Message& message = futureResponse->wait();
  auto response = deserializeResponse(message);
  auto& rfr = unwrapAutogradMessage<RRefFetchRet>(message, response);
  return PythonRpcHandler::getInstance().deserialize(
      SerializedPyObj::fromIValues(rfr.values()));
}

template class UserRRef<IValue>;
template class UserRRef<py::object>;

//////////////////////////  OwnerRRef  /////////////////////////////////////

template <typename T>
const T& OwnerRRef<T>::getValue() const {
  std::unique_lock<std::mutex> lock(mutex_);
  // Block until setValue() has provided the value.
  valueCV_.wait(lock, [this] { return value_.has_value(); });
  return value_.value();
}

template <typename T>
bool OwnerRRef<T>::hasValue() const {
  std::lock_guard<std::mutex> lock(mutex_);
  return value_.has_value();
}

template <typename T>
std::shared_ptr<FutureMessage> OwnerRRef<T>::getFuture() {
  std::unique_lock<std::mutex> lock(mutex_);
  if (future_.get()) {
    return future_;
  }
  future_ = std::make_shared<FutureMessage>();
  std::shared_ptr<FutureMessage> ret = future_;
  if (value_.has_value()) {
    // The value is already available; complete the future outside the lock so
    // that future callbacks do not run while holding mutex_.
    lock.unlock();
    ret->markCompleted(Message());
  }
  return ret;
}

template <typename T>
void OwnerRRef<T>::setValue(T&& value) {
  std::unique_lock<std::mutex> lock(mutex_);
  value_ = std::move(value);
  std::shared_ptr<FutureMessage> future;
  future.swap(future_);
  // Release the lock before notifying waiters and completing the future, so
  // that neither getValue() nor future callbacks contend on mutex_.
  lock.unlock();
  valueCV_.notify_all();
  if (future.get() && !future->completed()) {
    future->markCompleted(Message());
  }
}

template class OwnerRRef<IValue>;
template class OwnerRRef<py::object>;

} // namespace rpc
} // namespace distributed
} // namespace torch