Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32197

This relands https://github.com/pytorch/pytorch/pull/30063. The main change is to match a general exception and grep for the "pickle" error word in the "test_script_functions_not_supported" unit test, because Python 3.5 and Python 3.6 throw different exception types with different error messages for the RPC call in that test. [test all]

This diff makes the following changes:
1. Provides a new set of private Python RPC APIs that accept an annotated TorchScript call; the call can be serialized, deserialized, and executed in C++ without the GIL. These private APIs will be bound to JIT in the future, and they differ from the public APIs in that the future JIT-bound private APIs will accept a qualified_name rather than a callable. These private APIs are subject to deprecation once JIT supports TorchScript functions as a JIT type. They also require the TorchScript function to be defined and annotated by users in Python land; it cannot be a script class/module constructor or a class/module method.
2. Also allows the public RPC APIs to accept an annotated TorchScript call and execute the same code path the private APIs above run on. Therefore, if users invoke an annotated TorchScript call over RPC, that call can be serialized, deserialized, and executed in C++ without the GIL as well.
3. The private APIs above call a newly defined C++ function that serializes, deserializes, and executes the RPC TorchScript call in C++ land. This C++ function returns an ivalue::Future, so that in a follow-up diff it can be called once these private APIs are bound to JIT.
4. Refactors script_call.cpp/.h and request_callback_impl.cpp so that TorchScript calls and builtin calls share the same message type and code.
5. Refactors deserializeResponse() and adds a new utility to deserialize a response to an IValue.

ghstack-source-id: 96879167

Test Plan: unit test

Differential Revision: D19402374

fbshipit-source-id: 04efcc7c167d08a6503f29efe55e76f2be4b2c5e
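As a quick illustration of item 2 above, here is a minimal sketch (not part of the PR) of invoking a user-defined, annotated TorchScript function through the public RPC API. The worker names, ranks, and tensor arguments are illustrative assumptions, and rendezvous settings (e.g. MASTER_ADDR/MASTER_PORT) are assumed to be configured elsewhere:

import torch
import torch.distributed.rpc as rpc

# An annotated TorchScript function defined by the user in Python land
# (not a script class/module constructor or method).
@torch.jit.script
def script_add(t1: torch.Tensor, t2: torch.Tensor) -> torch.Tensor:
    return t1 + t2

def run(rank: int, world_size: int) -> None:
    # Worker names and world size are hypothetical values for this sketch.
    rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)
    if rank == 0:
        # The scripted call is serialized, deserialized, and executed in C++
        # without holding the GIL on the callee, per the summary above.
        ret = rpc.rpc_sync(
            "worker1", script_add, args=(torch.ones(2), torch.ones(2)))
        print(ret)  # tensor([2., 2.])
    rpc.shutdown()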
211 lines
6.8 KiB
C++
#include <torch/csrc/distributed/rpc/rref_impl.h>

#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
#include <torch/csrc/distributed/autograd/utils.h>
#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
#include <torch/csrc/distributed/rpc/rref_context.h>
#include <torch/csrc/distributed/rpc/rref_proto.h>
#include <torch/csrc/distributed/rpc/utils.h>

namespace torch {
namespace distributed {
namespace rpc {

namespace {

constexpr int OWNER_IDX = 0; // index of ownerId in the tuple
constexpr int RREFID_ON_IDX = 1; // index of RRefId.createdOn_ in the tuple
constexpr int RREFID_ID_IDX = 2; // index of RRefId.localId_ in the tuple
constexpr int FORKID_ON_IDX = 3; // index of ForkId.createdOn_ in the tuple
constexpr int FORKID_ID_IDX = 4; // index of ForkId.localId_ in the tuple
constexpr int PARENT_IDX = 5; // index of parent in the tuple

// NB: if more fields are added, make sure this field is also bumped
constexpr int RFD_TUPLE_SIZE = 6; // number of RRefForkData fields in py::tuple

} // namespace

std::atomic<local_id_t> RRefContext::nextLocalId_{0};

////////////////////////// RRefForkData /////////////////////////////////

RRefForkData::RRefForkData(
    worker_id_t ownerId,
    const RRefId& rrefId,
    const ForkId& forkId,
    worker_id_t parent)
    : ownerId_(ownerId), rrefId_(rrefId), forkId_(forkId), parent_(parent) {}

py::tuple RRefForkData::toPyTuple() const {
  return py::make_tuple(
      ownerId_,
      rrefId_.createdOn_,
      rrefId_.localId_,
      forkId_.createdOn_,
      forkId_.localId_,
      parent_);
}

RRefForkData RRefForkData::fromPyTuple(const py::tuple& t) {
  TORCH_INTERNAL_ASSERT(
      t.size() == RFD_TUPLE_SIZE,
      "Pickled RRefForkData must contain 6 numbers.");
  worker_id_t ownerId = t[OWNER_IDX].cast<worker_id_t>();
  // const reference will extend the lifetime of the temporary variable
  const RRefId& rrefId = RRefId(
      t[RREFID_ON_IDX].cast<worker_id_t>(),
      t[RREFID_ID_IDX].cast<local_id_t>());
  const RRefId& forkId = RRefId(
      t[FORKID_ON_IDX].cast<worker_id_t>(),
      t[FORKID_ID_IDX].cast<local_id_t>());
  worker_id_t parent = t[PARENT_IDX].cast<worker_id_t>();
  return RRefForkData(ownerId, rrefId, forkId, parent);
}

////////////////////////////// RRef /////////////////////////////////////

RRef::RRef(worker_id_t ownerId, const RRefId& rrefId)
    : RRefInterface(), ownerId_(ownerId), rrefId_(rrefId) {}

RRefForkData RRef::fork() const {
  auto& ctx = RRefContext::getInstance();
  return RRefForkData(
      ownerId_, rrefId_, ctx.genGloballyUniqueId(), ctx.getWorkerId());
}

////////////////////////// UserRRef /////////////////////////////////////

template <typename T>
UserRRef<T>::UserRRef(
    worker_id_t ownerId,
    const RRefId& rrefId,
    const ForkId& forkId)
    : RRef(ownerId, rrefId), forkId_(forkId) {
  // Do nothing:
  // (1) If this UserRRef is a fork of an existing RRef, RRefContext will send
  //     a RREF_FORK_REQUEST message to the owner.
  // (2) If this is the creator UserRRef, ScriptRemoteCall or PythonRemoteCall
  //     will properly notify the owner.
}

template <typename T>
UserRRef<T>::~UserRRef() {
  try {
    RRefContext::getInstance().delUser(ownerId_, rrefId_, forkId_);
  } catch (const std::exception& ex) {
    LOG(ERROR) << "Error occurred when deleting UserRRef instance, "
               << "RRefId = " << rrefId_ << ", ForkId = " << forkId_ << " : "
               << ex.what();
  } catch (...) {
    LOG(ERROR) << "Error occurred when deleting UserRRef instance, "
               << "RRefId = " << rrefId_ << ", ForkId = " << forkId_ << " : "
               << "unknown error";
  }
}

template <typename T>
const ForkId& UserRRef<T>::forkId() const {
  return forkId_;
}

template <>
IValue UserRRef<IValue>::toHere() {
  auto agent = RpcAgent::getDefaultRpcAgent();

  // ScriptRRefFetchCall message always carries autograd context id even if
  // the message itself does not contain any tensor, because the response would
  // potentially contain tensors.
  auto futureResponse = autograd::sendMessageWithAutograd(
      *agent,
      agent->getWorkerInfo(ownerId_),
      ScriptRRefFetchCall(ownerId_, rrefId()).toMessage(),
      true /* forceGradRecording */);

  const Message& message = futureResponse->wait();
  MessageType msgType = message.type();
  auto response = deserializeResponse(message, msgType);
  TORCH_INTERNAL_ASSERT(
      msgType == MessageType::SCRIPT_RREF_FETCH_RET,
      "Message type should be SCRIPT_RREF_FETCH_RET.");
  RpcCommandBase& rpc = *response;
  auto& rfr = static_cast<ScriptRRefFetchRet&>(rpc);
  return rfr.values().front();
}

template <>
py::object UserRRef<py::object>::toHere() {
  auto agent = RpcAgent::getDefaultRpcAgent();

  // PythonRRefFetchCall message always carries autograd context id even if
  // the message itself does not contain any tensor, because the response would
  // potentially contain tensors.
  auto futureResponse = autograd::sendMessageWithAutograd(
      *agent,
      agent->getWorkerInfo(ownerId_),
      PythonRRefFetchCall(ownerId_, rrefId()).toMessage(),
      true /* forceGradRecording */);

  const Message& message = futureResponse->wait();
  MessageType msgType = message.type();
  auto response = deserializeResponse(message, msgType);
  TORCH_INTERNAL_ASSERT(
      msgType == MessageType::PYTHON_RREF_FETCH_RET,
      "Message type should be PYTHON_RREF_FETCH_RET.");
  RpcCommandBase& rpc = *response;
  auto& rfr = static_cast<PythonRRefFetchRet&>(rpc);
  return PythonRpcHandler::getInstance().deserialize(
      SerializedPyObj::fromIValues(rfr.values()));
}

template class UserRRef<IValue>;
template class UserRRef<py::object>;

////////////////////////// OwnerRRef /////////////////////////////////////

template <typename T>
const T& OwnerRRef<T>::getValue() const {
  std::unique_lock<std::mutex> lock(mutex_);
  valueCV_.wait(lock, [this] { return value_.has_value(); });
  return value_.value();
}

template <typename T>
bool OwnerRRef<T>::hasValue() const {
  std::lock_guard<std::mutex> lock(mutex_);
  return value_.has_value();
}

template <typename T>
std::shared_ptr<FutureMessage> OwnerRRef<T>::getFuture() {
  std::unique_lock<std::mutex> lock(mutex_);
  if (future_.get()) {
    return future_;
  }
  future_ = std::make_shared<FutureMessage>();
  std::shared_ptr<FutureMessage> ret = future_;
  if (value_.has_value()) {
    lock.unlock();
    ret->markCompleted(Message());
  }
  return ret;
}

template <typename T>
void OwnerRRef<T>::setValue(T&& value) {
  std::unique_lock<std::mutex> lock(mutex_);
  value_ = std::move(value);
  std::shared_ptr<FutureMessage> future;
  future.swap(future_);
  lock.unlock();
  valueCV_.notify_all();
  if (future.get() && !future->completed()) {
    future->markCompleted(Message());
  }
}

template class OwnerRRef<IValue>;
template class OwnerRRef<py::object>;

} // namespace rpc
} // namespace distributed
} // namespace torch