Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38590

This PR implements timeout semantics for RRef, for parity with rpc_sync and rpc_async. How it works (a usage sketch follows this list):

- A timeout parameter is added to rpc.remote. If the rpc.remote call times out, the error is not raised to the user in that call, since the call is non-blocking (similar to rpc_async). Instead, the timeout error is raised the next time the RRef is used (either by pickling or a to_here call).
- Error-handling semantics are added to RRef to deal with timeout errors. Previously, if there was an error creating the OwnerRRef, the callback on the local user would throw from within a callback, resulting in a `std::terminate`. Instead, the error is now caught and surfaced to the user the next time the RRef is used. As part of this, we have added an `RPCErrorType` enum and defined RRef error handlers to handle the `RPCErrorType`s (currently just timeout and unknown).
- A timeout parameter is added to `to_here()`, which gives the user control over the maximum amount of time it can block for.
- `ctx.prepareChildForFork()`, which is called when the RRef is pickled (i.e. used as an arg over RPC), checks whether the `rpc.remote()` call has timed out, and if so, raises that error to the user.
- Tests are added, primarily via delay injection.

ghstack-source-id: 105232837

Test Plan: CI

Differential Revision: D21588165

fbshipit-source-id: c9f9e8aa3521012ea1de3e0f152a41afdf8b23f3
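For illustration, a minimal sketch of the user-facing semantics described above. It assumes a two-worker setup with hypothetical names "worker0" and "worker1" (the peer process is assumed to be launched elsewhere), and that the timeout is exposed as a `timeout` kwarg in seconds; exact defaults and the exception type may differ across PyTorch versions.

# Hedged usage sketch -- worker names and timeout values are hypothetical.
import torch
import torch.distributed.rpc as rpc

rpc.init_rpc("worker0", rank=0, world_size=2)

# rpc.remote() is non-blocking, so a timeout error is NOT raised here even
# if RRef creation on the owner takes longer than 0.5s...
rref = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1), timeout=0.5)

try:
    # ...instead the deferred error surfaces on the next use of the RRef:
    # on to_here() (which also takes its own blocking timeout), or when the
    # RRef is pickled as an argument to another RPC.
    value = rref.to_here(timeout=0.5)
except RuntimeError as err:  # exception type may vary by version
    print("RRef timed out:", err)

rpc.shutdown()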
265 lines
8.7 KiB
C++
#include <torch/csrc/distributed/rpc/rref_impl.h>

#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
#include <torch/csrc/distributed/autograd/utils.h>
#include <torch/csrc/distributed/rpc/rref_context.h>
#include <torch/csrc/distributed/rpc/rref_proto.h>
#include <torch/csrc/distributed/rpc/utils.h>

namespace {

// If the type is a subtype of a named type, return its qualified name;
// otherwise return its type string.
std::string getTypeStr(const c10::TypePtr& type) {
  switch (type->kind()) {
    case c10::TypeKind::FunctionType:
      return type->cast<c10::FunctionType>()->name()->qualifiedName();
    case c10::TypeKind::TupleType:
      return type->cast<c10::TupleType>()->name()->qualifiedName();
    case c10::TypeKind::ClassType:
      return type->cast<c10::ClassType>()->name()->qualifiedName();
    case c10::TypeKind::InterfaceType:
      return type->cast<c10::InterfaceType>()->name()->qualifiedName();
    default:
      return type->str();
  }
}

} // namespace

namespace torch {
namespace distributed {
namespace rpc {

std::atomic<local_id_t> RRefContext::nextLocalId_{0};

////////////////////////// RRefForkData /////////////////////////////////

RRefForkData::RRefForkData(
    worker_id_t ownerId,
    const RRefId& rrefId,
    const ForkId& forkId,
    worker_id_t parent,
    std::string typeStr)
    : ownerId_(ownerId),
      rrefId_(rrefId),
      forkId_(forkId),
      parent_(parent),
      typeStr_(std::move(typeStr)) {}

////////////////////////////// RRef /////////////////////////////////////

RRef::RRef(worker_id_t ownerId, const RRefId& rrefId, TypePtr type)
    : RRefInterface(),
      ownerId_(ownerId),
      rrefId_(rrefId),
      type_(std::move(type)) {}

RRefForkData RRef::fork() const {
  auto& ctx = RRefContext::getInstance();
  return RRefForkData(
      ownerId_,
      rrefId_,
      ctx.genGloballyUniqueId(),
      ctx.getWorkerId(),
      getTypeStr(type_));
}

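// Dispatches on RPCErrorType when an RPC associated with this RRef fails.
// Per the summary above, TIMEOUT and INTENTIONAL_FAILURE (used for
// test-time failure injection) mark this RRef as timed out so the error
// surfaces on the next use, while UNKNOWN_ERROR is rethrown immediately.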
void RRef::handleError(
    RPCErrorType errorType,
    const FutureMessage& futMessage) {
  // Note: the handlers capture `this`, so the map is rebuilt per call rather
  // than cached in a static; a static map would permanently bind the handlers
  // to whichever RRef instance first initialized it.
  const std::unordered_map<
      RPCErrorType,
      std::function<void(const FutureMessage& fm)>,
      std::hash<int>>
      errorHandlers = {
          {RPCErrorType::TIMEOUT,
           [this](const FutureMessage& /* unused */) { setTimedOut(); }},
          {RPCErrorType::INTENTIONAL_FAILURE,
           [this](const FutureMessage& /* unused */) { setTimedOut(); }},
          {RPCErrorType::UNKNOWN_ERROR, [](const FutureMessage& fm) {
             // Default error handler, equivalent to
             // RRefContext::handleException().
             VLOG(1) << "Got exception: " << fm.error()->what();
             throw std::runtime_error(fm.error()->what());
           }}};
  // at() throws if an unhandled RPCErrorType slips through, rather than the
  // undefined behavior of dereferencing a failed find().
  errorHandlers.at(errorType)(futMessage);
}

////////////////////////// UserRRef /////////////////////////////////////

UserRRef::UserRRef(
    worker_id_t ownerId,
    const RRefId& rrefId,
    const ForkId& forkId,
    TypePtr type)
    : RRef(ownerId, rrefId, std::move(type)),
      forkId_(forkId),
      confirmedByOwner_(false) {
  // Do nothing here:
  // (1) If this UserRRef is a fork of an existing RRef, RRefContext will
  //     send a RREF_FORK_REQUEST message to the owner.
  // (2) If this is the creator UserRRef, ScriptRemoteCall or
  //     PythonRemoteCall will properly notify the owner.
}

void UserRRef::tryDel() {
  std::lock_guard<std::mutex> lockGuard(deletedOnOwnerMutex_);
  if (!deletedOnOwner_) {
    try {
      RRefContext::getInstance().delUser(ownerId_, rrefId_, forkId_);
      deletedOnOwner_ = true;
    } catch (const std::exception& ex) {
      LOG(ERROR) << "Error occurred when deleting UserRRef instance, "
                 << "RRefId = " << rrefId_ << ", ForkId = " << forkId_
                 << " : " << ex.what();
    } catch (...) {
      LOG(ERROR) << "Error occurred when deleting UserRRef instance, "
                 << "RRefId = " << rrefId_ << ", ForkId = " << forkId_
                 << " : unknown error";
    }
  }
}

void UserRRef::release_resources() {
  tryDel();
}

const ForkId& UserRRef::forkId() const {
  return forkId_;
}

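// Blocks until the value is fetched from the owner, waiting at most
// timeoutSeconds for the fetch. If the rpc.remote() call that created this
// RRef already timed out, the deferred creation error is raised here
// instead of attempting a fetch.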
IValue UserRRef::toHere(const float timeoutSeconds) const {
  if (this->getTimedOut()) {
    throw std::runtime_error(
        "RRef creation via rpc.remote() timed out, and it "
        "is possible that the RRef on the owner node does not exist.");
  }
  // See Note [Best-Effort Check on Deleted UserRRefs]
  TORCH_CHECK(
      !deletedOnOwner_,
      "User RRef with RRefId=",
      rrefId(),
      " and ForkId=",
      forkId(),
      " has been deleted. Cannot call to_here() on it after deletion.");
  TORCH_CHECK(
      !type_->is_module(),
      "User RRef with RRefId=",
      rrefId(),
      " and ForkId=",
      forkId(),
      " is an RRef to a ScriptModule. "
      "It can't be sent through RPC "
      "from owner, ",
      ownerName(),
      ", to user, ",
      RpcAgent::getCurrentRpcAgent()->getWorkerInfo().name_,
      ".");

  auto agent = RpcAgent::getCurrentRpcAgent();

  // The ScriptRRefFetchCall message always carries the autograd context id,
  // even if the message itself does not contain any tensors, because the
  // response could potentially contain tensors.
  Message msgToSend;

  if (isPyObj()) {
    msgToSend = PythonRRefFetchCall(ownerId_, rrefId()).toMessage();
  } else {
    msgToSend = ScriptRRefFetchCall(ownerId_, rrefId()).toMessage();
  }

  auto futureResponse = autograd::sendMessageWithAutograd(
      *agent,
      agent->getWorkerInfo(ownerId_),
      std::move(msgToSend),
      true /* forceGradRecording */,
      timeoutSeconds);

  // TODO: ideally, we should be able to interrupt this blocking wait if
  // isTimedOut() becomes true.
  const Message& message = futureResponse->wait();
  MessageType msgType = message.type();
  auto response = deserializeResponse(message, msgType);
  TORCH_INTERNAL_ASSERT(
      msgType == MessageType::SCRIPT_RREF_FETCH_RET ||
          msgType == MessageType::PYTHON_RREF_FETCH_RET,
      "Message type should either be SCRIPT_RREF_FETCH_RET "
      "or PYTHON_RREF_FETCH_RET");
  RpcCommandBase& rpc = *response;
  auto& rrefFetchRet = static_cast<RRefFetchRet&>(rpc);
  if (isPyObj()) {
    // Wrap the Python-serialized vector of IValues into a tuple; this lets
    // the C++ toHere interface return a single IValue.
    return ivalue::Tuple::create(rrefFetchRet.values());
  } else {
    return rrefFetchRet.values().front();
  }
}

RRefForkData UserRRef::fork() const {
  // Note [Best-Effort Check on Deleted UserRRefs]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // This check does not guarantee correctness, as there could be another
  // thread trying to delete this UserRRef concurrently. Passing this check
  // does not mean this RRef will stay alive throughout this function. This
  // is just our best-effort attempt to raise proper error messages. The
  // behavior of using deleted UserRRefs is undefined.
  //
  // The reasons for not implementing strict checks are:
  // 1. This would need to acquire the lock on deletedOnOwnerMutex_, which
  //    would introduce unnecessary overhead for most normal use cases.
  // 2. This would introduce a lot of complexity to get the behavior correct.
  //    Assume we acquired the lock here, and there is another thread X
  //    blocked waiting in tryDel() on the lock. Exiting this fork function
  //    would unblock thread X. However, while X proceeds with deleting this
  //    UserRRef, the call site of fork() might have added the UserRRef to
  //    the pendingChildren_ map, but up to this point, nothing prevents X
  //    from deleting this RRef even if it shouldn't do so due to the state
  //    change in pendingChildren_. We might be able to get this right for
  //    now by locking and checking pendingChildren_ in X, but the gain does
  //    not seem worth the complexity.
  TORCH_CHECK(
      !deletedOnOwner_,
      "User RRef with RRefId=",
      rrefId(),
      " and ForkId=",
      forkId(),
      " has been deleted. Cannot fork a UserRRef after deletion.");
  return RRef::fork();
}

////////////////////////// OwnerRRef /////////////////////////////////////

const IValue& OwnerRRef::getValue() const {
  if (this->getTimedOut()) {
    throw std::runtime_error(
        "RRef creation via rpc.remote() to self timed out, and it "
        "is possible that the RRef on the owner node does not exist.");
  }
  future_->wait();
  if (future_->hasError()) {
    (void)future_->value(); // Throws the error.
  }
  return future_->constValue();
}

bool OwnerRRef::hasValue() const {
  return future_->completed();
}

std::shared_ptr<JitFuture> OwnerRRef::getFuture() {
  return future_;
}

void OwnerRRef::setValue(IValue&& value) {
  // Move the value into the future to avoid an unnecessary copy.
  future_->markCompleted(std::move(value));
}

void OwnerRRef::setError(const std::string& error) {
  future_->setErrorIfNeeded(error);
}

} // namespace rpc
} // namespace distributed
} // namespace torch