Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32633

There were two sources of the current RPC agent.

- One is in the Python world, `torch.distributed.rpc.api._agent`.
- The other is in the C++ world, `RpcAgent::defaultRpcAgent_`.

Setting the Python `_agent` to `None` does not necessarily reset the C++ `defaultRpcAgent_` to `nullptr`, i.e.

```
torch.distributed.rpc.api._agent = None
```

does not translate to

```
RpcAgent::defaultRpcAgent_ = nullptr
```

This PR removes this ambiguity and uses the C++ pointer as the source of truth. The solution is to leverage pybind11's behavior of implicitly casting a C++ `shared_ptr<RpcAgent>(nullptr)` to Python `None`.

ghstack-source-id: 97293315

Test Plan:
```
buck test mode/dev-nosan //caffe2/test/distributed/rpc:rpc_fork -- test_duplicate_name
buck build mode/dev-nosan //caffe2/test/distributed/rpc:rpc_fork
buck-out/gen/caffe2/test/distributed/rpc/rpc_fork\#binary.par -r test_process_group_debug_info
```

```
buck test mode/dev-nosan //caffe2/torch/fb/distributed/pytorch/tests:test_remote_module
buck test mode/dev-nosan //caffe2/torch/fb/distributed/modules/tests:test_sharded_embedding
buck test mode/dev-nosan //caffe2/torch/fb/distributed/modules/tests:test_sharded_pairwise_attention_pooling
buck test mode/dev-nosan //caffe2/torch/fb/distributed/pytorch/tests:test_rpc
```

Differential Revision: D5733066

fbshipit-source-id: b3e6032ee975f19ca556497edbbf40b517b25be8
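To illustrate the pybind11 behavior the fix relies on, here is a minimal, hypothetical sketch (the `agent_demo` module, the `Agent` struct, and the function names are illustrative stand-ins, not the actual PyTorch bindings): a bound function that returns an empty `std::shared_ptr` is seen as `None` on the Python side, so exposing the C++ pointer directly keeps the two worlds in agreement.

```
// Hypothetical pybind11 module sketching the behavior the PR leverages:
// returning an empty shared_ptr from C++ yields None in Python.
#include <memory>
#include <pybind11/pybind11.h>

namespace py = pybind11;

struct Agent {
  int id = 0;
};

// Stand-in for RpcAgent::defaultRpcAgent_, the single C++ source of truth.
static std::shared_ptr<Agent> defaultAgent;

PYBIND11_MODULE(agent_demo, m) {
  py::class_<Agent, std::shared_ptr<Agent>>(m, "Agent")
      .def(py::init<>())
      .def_readonly("id", &Agent::id);

  // Returns None while defaultAgent is nullptr, and the Agent once it is set.
  m.def("get_current_agent", []() { return defaultAgent; });
  m.def("set_current_agent",
        [](std::shared_ptr<Agent> a) { defaultAgent = std::move(a); });
  m.def("reset_current_agent", []() { defaultAgent.reset(); });
}
```

With this shape, Python code such as `agent_demo.get_current_agent() is None` checks the same state the C++ side sees, which is the ambiguity the PR removes for `RpcAgent::defaultRpcAgent_`.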
#include <torch/csrc/distributed/rpc/rref_impl.h>

#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
#include <torch/csrc/distributed/autograd/utils.h>
#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
#include <torch/csrc/distributed/rpc/rref_context.h>
#include <torch/csrc/distributed/rpc/rref_proto.h>
#include <torch/csrc/distributed/rpc/utils.h>
#include <torch/csrc/jit/pybind_utils.h>

namespace torch {
namespace distributed {
namespace rpc {

namespace {

constexpr int OWNER_IDX = 0; // index of ownerId in the tuple
constexpr int RREFID_ON_IDX = 1; // index of RRefId.createdOn_ in the tuple
constexpr int RREFID_ID_IDX = 2; // index of RRefId.localId_ in the tuple
constexpr int FORKID_ON_IDX = 3; // index of ForkId.createdOn_ in the tuple
constexpr int FORKID_ID_IDX = 4; // index of ForkId.localId_ in the tuple
constexpr int PARENT_IDX = 5; // index of parent in the tuple
constexpr int TYPE_IDX = 6; // index of type str in the tuple

// NB: if more fields are added, make sure this field is also bumped
constexpr int RFD_TUPLE_SIZE = 7; // number of RRefForkData fields in py::tuple
} // namespace

std::atomic<local_id_t> RRefContext::nextLocalId_{0};

////////////////////////// RRefForkData /////////////////////////////////

RRefForkData::RRefForkData(
    worker_id_t ownerId,
    const RRefId& rrefId,
    const ForkId& forkId,
    worker_id_t parent,
    std::string type_str)
    : ownerId_(ownerId),
      rrefId_(rrefId),
      forkId_(forkId),
      parent_(parent),
      type_str_(std::move(type_str)) {}

py::tuple RRefForkData::toPyTuple() const {
  return py::make_tuple(
      ownerId_,
      rrefId_.createdOn_,
      rrefId_.localId_,
      forkId_.createdOn_,
      forkId_.localId_,
      parent_,
      type_str_);
}

RRefForkData RRefForkData::fromPyTuple(const py::tuple& t) {
  TORCH_INTERNAL_ASSERT(
      t.size() == RFD_TUPLE_SIZE,
      "Pickled RRefForkData must contain 7 entries.");
  worker_id_t ownerId = t[OWNER_IDX].cast<worker_id_t>();
  // const reference will extend the lifetime of the temporary variable
  const RRefId& rrefId = RRefId(
      t[RREFID_ON_IDX].cast<worker_id_t>(),
      t[RREFID_ID_IDX].cast<local_id_t>());
  const RRefId& forkId = RRefId(
      t[FORKID_ON_IDX].cast<worker_id_t>(),
      t[FORKID_ID_IDX].cast<local_id_t>());

  worker_id_t parent = t[PARENT_IDX].cast<worker_id_t>();
  const std::string& typeStr = t[TYPE_IDX].cast<std::string>();

  return RRefForkData(ownerId, rrefId, forkId, parent, typeStr);
}

////////////////////////////// RRef /////////////////////////////////////

RRef::RRef(worker_id_t ownerId, const RRefId& rrefId, TypePtr type)
    : RRefInterface(),
      ownerId_(ownerId),
      rrefId_(rrefId),
      type_(std::move(type)) {}

RRefForkData RRef::fork() const {
  auto& ctx = RRefContext::getInstance();
  return RRefForkData(
      ownerId_,
      rrefId_,
      ctx.genGloballyUniqueId(),
      ctx.getWorkerId(),
      type_->str());
}

////////////////////////// UserRRef /////////////////////////////////////

UserRRef::UserRRef(
    worker_id_t ownerId,
    const RRefId& rrefId,
    const ForkId& forkId,
    TypePtr type)
    : RRef(ownerId, rrefId, std::move(type)), forkId_(forkId) {
  // Do nothing,
  // (1) If this UserRRef is a fork of an existing RRef, RRefContext will send
  //     a RREF_FORK_REQUEST message to the owner.
  // (2) If this is the creator UserRRef, ScriptRemoteCall or PythonRemoteCall
  //     will properly notify the owner.
}

UserRRef::~UserRRef() {
  try {
    RRefContext::getInstance().delUser(ownerId_, rrefId_, forkId_);
  } catch (const std::exception& ex) {
    LOG(ERROR) << "Error occurred when deleting UserRRef instance, "
               << "RRefId = " << rrefId_ << ", ForkId = " << forkId_ << " : "
               << ex.what();
  } catch (...) {
    LOG(ERROR) << "Error occurred when deleting UserRRef instance, "
               << "RRefId = " << rrefId_ << ", ForkId = " << forkId_ << " : "
               << "unknown error";
  }
}

const ForkId& UserRRef::forkId() const {
  return forkId_;
}

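// Fetches the value of this RRef from the owner: sends a ScriptRRefFetchCall
// (or PythonRRefFetchCall for Python objects) through the current RpcAgent and
// blocks until the response message arrives.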
IValue UserRRef::toHere() {
  auto agent = RpcAgent::getCurrentRpcAgent();

  // ScriptRRefFetchCall message always carries autograd context id even if
  // the message itself does not contain any tensor, because the response would
  // potentially contain tensors.
  Message msgToSend;

  if (isPyObj()) {
    msgToSend = PythonRRefFetchCall(ownerId_, rrefId()).toMessage();
  } else {
    msgToSend = ScriptRRefFetchCall(ownerId_, rrefId()).toMessage();
  }

  auto futureResponse = autograd::sendMessageWithAutograd(
      *agent,
      agent->getWorkerInfo(ownerId_),
      std::move(msgToSend),
      true /* forceGradRecording */);

  const Message& message = futureResponse->wait();
  MessageType msgType = message.type();
  auto response = deserializeResponse(message, msgType);
  TORCH_INTERNAL_ASSERT(
      msgType == MessageType::SCRIPT_RREF_FETCH_RET ||
          msgType == MessageType::PYTHON_RREF_FETCH_RET,
      "Message type should either be SCRIPT_RREF_FETCH_RET "
      "or PYTHON_RREF_FETCH_RET");
  RpcCommandBase& rpc = *response;
  if (isPyObj()) {
    auto& rfr = static_cast<PythonRRefFetchRet&>(rpc);
    return jit::toIValue(
        PythonRpcHandler::getInstance().deserialize(
            SerializedPyObj::fromIValues(rfr.values())),
        PyObjectType::get());
  } else {
    auto& rfr = static_cast<ScriptRRefFetchRet&>(rpc);
    return rfr.values().front();
  }
}

////////////////////////// OwnerRRef /////////////////////////////////////

const IValue& OwnerRRef::getValue() const {
  std::unique_lock<std::mutex> lock(mutex_);
  valueCV_.wait(lock, [this] { return value_.has_value(); });
  return value_.value();
}

bool OwnerRRef::hasValue() const {
  std::lock_guard<std::mutex> lock(mutex_);
  return value_.has_value();
}

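// Returns a future that is marked completed once this OwnerRRef's value has
// been set. If the value is already present, the future is completed before it
// is returned; otherwise setValue() completes it.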
std::shared_ptr<FutureMessage> OwnerRRef::getFuture() {
  std::unique_lock<std::mutex> lock(mutex_);
  if (future_.get()) {
    return future_;
  }
  future_ = std::make_shared<FutureMessage>();
  std::shared_ptr<FutureMessage> ret = future_;
  if (value_.has_value()) {
    lock.unlock();
    ret->markCompleted(Message());
  }
  return ret;
}

void OwnerRRef::setValue(IValue&& value) {
  std::unique_lock<std::mutex> lock(mutex_);
  value_ = std::move(value);
  std::shared_ptr<FutureMessage> future;
  future.swap(future_);
  lock.unlock();
  valueCV_.notify_all();
  if (future.get() && !future->completed()) {
    future->markCompleted(Message());
  }
}

} // namespace rpc
} // namespace distributed
} // namespace torch