Use c10::ThreadPool to send and receive messages (#23968)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/23968

The existing ProcessGroupAgent uses a single thread to send all messages and a single thread to listen for and process all received messages. This causes performance problems and also prevents nested RPCs: when running a nested RPC A->B->A->B, the second recv on B cannot start until the first recv on B finishes, so if the second recv is triggered by a nested RPC issued from inside the first recv, it deadlocks. Ideally, we should expose something like a responder or FutureResult to the Python land to support nested asynchronous UDFs.

This diff adds a shared ThreadPool for send and recv. Send uses it to send out messages, and recv uses it to process received messages. There is still a dedicated thread that listens for incoming messages and adds them to the task queue. There are two goals: 1) speed up ProcessGroupAgent, and 2) use the ThreadPool as a temporary solution for (a small number of) nested RPCs.

ghstack-source-id: 88476246
Differential Revision: D16695091
fbshipit-source-id: fd18a5c65e7fcd1331b73d1287673e6e10d2dd86
parent dd97743de7
commit 99dea08e60
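To make the deadlock argument concrete, here is a minimal illustration (not code from this PR) of why handing received work to a pool of threads lets a nested RPC make progress. It only assumes the c10::ThreadPool calls the agent itself uses (run() and waitWorkComplete()); the pool size of 4 mirrors the new numSendRecvThreads default, and the std::promise/std::future pair stands in for a nested RPC reply.

#include <c10/core/thread_pool.h>

#include <future>
#include <iostream>

int main() {
  // Four worker threads, mirroring the numSendRecvThreads default.
  c10::ThreadPool pool(4);

  std::promise<int> nestedReply;
  std::future<int> nestedFuture = nestedReply.get_future();

  // "First recv on B": blocks until the nested RPC's reply arrives. With a
  // single processing thread, this wait could never be satisfied, because no
  // thread would be left to handle the nested request.
  pool.run([&nestedFuture] {
    std::cout << "request 1: waiting on nested reply" << std::endl;
    std::cout << "request 1: got " << nestedFuture.get() << std::endl;
  });

  // "Second recv on B" (the nested request): runs on another pool thread and
  // produces the reply the first task is waiting for.
  pool.run([&nestedReply] { nestedReply.set_value(42); });

  pool.waitWorkComplete();
  return 0;
}

With a pool size of 1 the first task would block forever, which is exactly the nested A->B->A->B deadlock described above.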
@@ -19,6 +19,17 @@ def my_function(a, b, c):
def no_result():
    print("do nothing")

def nested_rpc(dst):
    return dist.rpc(dst, torch.add, args=(torch.ones(2, 2), 1))

def light_rpc():
    return 0

def heavy_rpc(tensor):
    for i in range(1, 100):
        tensor *= i
        tensor /= i + 1
    return 0

# it is used to test python user defined class and methods over rpc
class my_class:
@@ -222,6 +233,39 @@ class RpcTest(MultiProcessTestCase):
        expected = "run_python_udf_internal caught exception: " + str(e)
        self.assertEqual(ret, expected)

    @_wrap_with_rpc
    def test_nested_rpc(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = dist.rpc('worker{}'.format(dst_rank), nested_rpc,
                       args=('worker{}'.format(self.rank),))
        self.assertEqual(ret, torch.ones(2, 2) + 1)

    def _stress_test_rpc(self, f, repeat=1000, args=()):
        import time

        n = self.rank + 1
        dst_rank = n % self.world_size
        futs = []
        tik = time.time()
        for _ in range(repeat):
            fut = dist.rpc('worker{}'.format(dst_rank), f, args=args, async_call=True)
            futs.append(fut)

        for fut in futs:
            self.assertEqual(fut.wait(), 0)
        tok = time.time()
        print("Rank {} finished testing {} {} times in {} seconds.".format(
            self.rank, f.__name__, repeat, tok - tik
        ))

    @_wrap_with_rpc
    def test_stress_light_rpc(self):
        self._stress_test_rpc(light_rpc)

    @_wrap_with_rpc
    def test_stress_heavy_rpc(self):
        self._stress_test_rpc(heavy_rpc, repeat=20, args=(torch.ones(100, 100),))


if __name__ == '__main__':
    run_tests()
@@ -1,4 +1,6 @@
 #include <torch/csrc/distributed/rpc/ProcessGroupAgent.h>
+#include <c10d/ProcessGroup.hpp>
+
 #include <Python.h>

 namespace torch {
@@ -20,28 +22,25 @@ void serialize(const Message& message, std::ostream& os) {
   std::vector<torch::Tensor> tensors = message.tensors();
   // append payload as a tensor
   tensors.push_back(torch::from_blob(payload, payload_size, {torch::kChar}));
-  // append id and type as a tensor
-  tensors.push_back(torch::tensor(
-      {message.id(), (int64_t) message.type()}, {torch::kInt64}
+  // append id as a tensor
+  tensors.push_back(torch::tensor({message.id()}, {torch::kInt64}
   ));

   torch::save(tensors, os);
 }

-Message deserialize(std::istream& is) {
+Message deserialize(MessageType type, std::istream& is) {
   std::vector<torch::Tensor> tensors;

   torch::load(tensors, is);

   TORCH_CHECK(tensors.size() >= 2, "Failed to deserialize a message.");
-  auto miscTensor = std::move(tensors.back());
+  auto idTensor = std::move(tensors.back());
   tensors.pop_back();
   auto payloadTensor = std::move(tensors.back());
   tensors.pop_back();

-  int64_t* miscItems = miscTensor.storage().data<int64_t>();
-  int64_t id = miscItems[0];
-  MessageType type = MessageType(miscItems[1]);
+  int64_t id = idTensor.storage().data<int64_t>()[0];

   std::vector<char> payload(payloadTensor.numel());
@@ -59,12 +58,15 @@ Message deserialize(std::istream& is) {
 ProcessGroupAgent::ProcessGroupAgent(
     std::string workerName,
     std::unordered_map<std::string, int> nameMap,
-    std::shared_ptr<c10d::ProcessGroup> pg)
+    std::shared_ptr<c10d::ProcessGroup> pg,
+    int numSendRecvThreads)
     : RpcAgent(std::move(workerName), processRequestBlocking),
       nameMap_(std::move(nameMap)),
-      stop_(false),
       pg_(std::move(pg)),
-      nextId_(0) {
+      nextId_(0),
+      sendMutexes_(pg_->getSize()),
+      threadPool_(numSendRecvThreads) {
   TORCH_CHECK(nameMap_.size() > 1, "ProcessGroupAgent requires world_size to "
       "be at least 2, but got ", nameMap_.size());
   auto workerRankIter = nameMap_.find(workerName_);
@@ -79,7 +81,6 @@ ProcessGroupAgent::ProcessGroupAgent(
     names_[entry.second] = entry.first;
   }
   PythonRpcHandler::init();
-  sendThread_ = std::thread(&ProcessGroupAgent::sendLoop, this);
   listenerThread_ = std::thread(&ProcessGroupAgent::listenLoop, this);
 }

@@ -92,14 +93,8 @@ void ProcessGroupAgent::join() {
   // effort to fix this problem).
   sync();
   int dst = (pg_->getRank() + 1) % pg_->getSize();
-  enqueue(SendWork(dst, Message({}, {}, MessageType::SHUTDOWN)));
-  std::unique_lock<std::mutex> lock(sendQueueMutex_);
-  workConsumeCV_.wait(lock, [&] { return sendQueue_.empty(); });
-  stop_ = true;
-  lock.unlock();
-
-  workProduceCV_.notify_all();
-  sendThread_.join();
+  enqueueSend(SendWork(dst, Message({}, {}, MessageType::SHUTDOWN)));
+  threadPool_.waitWorkComplete();
   listenerThread_.join();
 }

@@ -108,11 +103,9 @@ void ProcessGroupAgent::sync() {
   // the lock below, because other processes might not enter sync() until it
   // gets some response from this RpcAgent.
   pg_->barrier()->wait();
-  // Acquire the lock on the send queue to prevent additional messages to be put
-  // onto the send queue.
-  std::unique_lock<std::mutex> lock(sendQueueMutex_);
-  // Wait until the send queue is depleted.
-  workConsumeCV_.wait(lock, [&] { return sendQueue_.empty(); });
+  // Wait until the all send works are done.
+  // NB: There might be additional send works inserted while waiting.
+  threadPool_.waitWorkComplete();
   // Use another barrier in case different RpcAgent handles different amounts of
   // workloads.
   pg_->barrier()->wait();
@@ -140,87 +133,111 @@ std::shared_ptr<FutureMessage> ProcessGroupAgent::send(
     future->markCompleted();
   }

-  enqueue(SendWork(dstRank, std::move(message)));
+  enqueueSend(SendWork(dstRank, std::move(message)));
   return future;
 }

-void ProcessGroupAgent::enqueue(SendWork work) {
-  std::unique_lock<std::mutex> lock(sendQueueMutex_);
-  sendQueue_.emplace_back(std::move(work));
-  lock.unlock();
+void ProcessGroupAgent::enqueueSend(SendWork work) {
+  // NB: this can be changed to use a native move capture when moved to C++14
+  threadPool_.run(std::bind(
+      [&](const SendWork& work) {
+        std::stringstream ss;
+        serialize(work.message_, ss);
+        std::string serializedPayload = ss.str();

-  workProduceCV_.notify_one();
+        std::vector<torch::Tensor> preamble = {
+          torch::tensor(
+            {
+              (int64_t)pg_->getRank(),
+              (int64_t)serializedPayload.length(),
+              (int64_t)work.message_.type()
+            }, {torch::kLong})
+        };
+
+        // ProcessGroup is not thread-safe when sending with the same tag, hence
+        // the lock
+        std::vector<std::shared_ptr<c10d::ProcessGroup::Work>> pendingSends;
+        if (work.message_.isShutdown()) {
+          pendingSends.reserve(1);
+          std::lock_guard<std::mutex> guard(sendMutexes_[work.to_]);
+          pendingSends.emplace_back(
+              pg_->send(preamble, work.to_, work.to_ /* channelTag */));
+        } else {
+          std::vector<torch::Tensor> payload = {
+              torch::from_blob(
+                  (void *)serializedPayload.c_str(),
+                  serializedPayload.length(),
+                  {torch::kChar}
+              )
+          };
+          pendingSends.reserve(2);
+          std::lock_guard<std::mutex> guard(sendMutexes_[work.to_]);
+          pendingSends.emplace_back(
+              pg_->send(preamble, work.to_, work.to_ /* channelTag */));
+          pendingSends.emplace_back(
+              pg_->send(payload, work.to_, work.to_ /* channelTag */));
+        }
+        for (auto& pendingSend: pendingSends) {
+          pendingSend->wait();
+        }
+
+      },
+      std::move(work)
+  ));
 }

-// making sure tensors are not deleted before send finishes
-void ProcessGroupAgent::sendLoop() {
-  std::unique_lock<std::mutex> lock(sendQueueMutex_);
+void ProcessGroupAgent::enqueueRecv(RecvWork work) {
+  threadPool_.run(std::bind(
+      [&](RecvWork& work) {

-  while (!stop_) {
-    if (sendQueue_.empty()) {
-      workProduceCV_.wait(lock);
-      continue;
-    }
+        torch::Tensor& payload = work.payload_;
+        std::stringstream ss(std::string(
+          (char*)payload.storage().data<signed char>(), payload.numel()));

-    auto work = std::move(sendQueue_.front());
-    sendQueue_.pop_front();
-    lock.unlock();
+        Message message = deserialize(work.type_, ss);

-    workConsumeCV_.notify_one();
-
-    std::stringstream ss;
-    serialize(work.message_, ss);
-    std::string str = ss.str();
-
-    std::vector<torch::Tensor> preamble = {
-      torch::tensor(
+        if (message.isRequest()) {
+          cb_(names_[work.from_], std::move(message), *this);
+        } else if (message.isResponse()) {
+          auto id = message.id();
           {
-            (int64_t)pg_->getRank(),
-            (int64_t)str.length(),
-          }, {torch::kLong})
-    };
-    pg_->send(preamble, work.dstRank_, work.dstRank_ /* channelTag */)->wait();
-    std::vector<torch::Tensor> payload =
-        {torch::from_blob((void *)str.c_str(), str.length(), {torch::kChar})};
-    pg_->send(payload, work.dstRank_, work.dstRank_ /* channelTag */)->wait();
-
-    lock.lock();
-  }
+            std::lock_guard<std::mutex> lock{futureMutex_};
+            futures_[id]->markCompleted(std::move(message));
+            futures_.erase(id);
+          }
+        } else {
+          // TODO: pass the error back to the caller instead of crashing here.
+          AT_ERROR("unrecognized message type ", message.type());
+        }
+      },
+      std::move(work)
+  ));
 }

 void ProcessGroupAgent::listenLoop() {
   while (true) {
-    // rank, tensor size
-    std::vector<torch::Tensor> preamble = {torch::empty({2}, {torch::kInt64})};
+    // rank, tensor size, message type
+    std::vector<torch::Tensor> preamble = {torch::empty({3}, {torch::kInt64})};
     pg_->recvAnysource(preamble, pg_->getRank())->wait();
     int64_t* preamble_items = preamble.front().storage().data<int64_t>();

     auto srcRank = preamble_items[0];
     auto size = preamble_items[1];
+    MessageType type = MessageType(preamble_items[2]);
+
+    if (type == MessageType::SHUTDOWN) {
+      // FIXME: This LOG also prints warnings no InitGoogleLogging() was invoked
+      // before logging, but it is not appropriate to call InitGoogleLogging()
+      // here either.
+      LOG(INFO) << "Shutting down ProcessGroupAgent "
+                << workerName_ << std::endl;
+      return;
+    }

     std::vector<torch::Tensor> tensors = {torch::empty({size}, {torch::kChar})};
     pg_->recv(tensors, srcRank, pg_->getRank())->wait();

-    std::stringstream ss(std::string(
-      (char*)tensors[0].storage().data<signed char>(), tensors[0].numel()));
-
-    Message message = deserialize(ss);
-
-    if (message.isRequest()) {
-      cb_(names_[srcRank], std::move(message), *this);
-    } else if (message.isResponse()) {
-      auto id = message.id();
-      {
-        std::lock_guard<std::mutex> lock{futureMutex_};
-        futures_[id]->markCompleted(std::move(message));
-        futures_.erase(id);
-      }
-    } else if (message.isShutdown()) {
-      break;
-    } else {
-      AT_ERROR("unrecognized message type ", message.type());
-    }
+    enqueueRecv(RecvWork(srcRank, type, std::move(tensors[0])));
   }
 }
@@ -1,27 +1,38 @@
 #pragma once

+#include <c10/core/thread_pool.h>
 #include <c10d/ProcessGroup.hpp>
 #include <torch/csrc/distributed/rpc/FutureMessage.h>
 #include <torch/csrc/distributed/rpc/RpcAgent.h>
 #include <torch/csrc/distributed/rpc/functions.h>
 #include <torch/csrc/distributed/rpc/PythonRpcHandler.h>

-#include <deque>
 #include <thread>

 namespace torch {
 namespace distributed {
 namespace rpc {

+// SendWork and RecvWork will be put into a task queue, and later picked up by
+// worker threads from the same ThreadPool.
 struct SendWork {
-  SendWork(const int dstRank,
-           Message&& message)
-      : dstRank_(dstRank),
-        message_(message) {}
+  SendWork(const int to, Message&& message) :
+      to_(to), message_(message) {}

-  const int dstRank_;
+  const int to_;
   Message message_;
 };

+// SendWork wraps a Message and RecvWork wraps a Tensor. The difference here is
+// to allow us to run serialization/deserialization in the worker threads.
+struct RecvWork {
+  RecvWork(const int from, MessageType type, torch::Tensor&& payload)
+      : from_(from), type_(type), payload_(payload) {}
+
+  const int from_;
+  const MessageType type_;
+  torch::Tensor payload_;
+};
+
 class ProcessGroupAgent : public RpcAgent {
@@ -29,7 +40,8 @@ class ProcessGroupAgent : public RpcAgent {

  ProcessGroupAgent(std::string workerName,
      std::unordered_map<std::string, int> nameMap,
-     std::shared_ptr<c10d::ProcessGroup> pg);
+     std::shared_ptr<c10d::ProcessGroup> pg,
+     int numSendRecvThreads = 4);

  // This method wraps the destination information and the message into a
  // SendWork object, and put the SendWork into a queue. Another thread will
@@ -42,10 +54,10 @@ class ProcessGroupAgent : public RpcAgent {
  void sync() override;

 private:
-  // put SendWork into a queue and notify the sendLoop thread
-  void enqueue(SendWork work);
-  // sending out the message
-  void sendLoop();
+  // put SendWork into a queue and notify the worker thread
+  void enqueueSend(SendWork work);
+  // put RecvWork into a queue and notify the worker thread
+  void enqueueRecv(RecvWork work);
  // receiving messages
  void listenLoop();

@@ -61,12 +73,20 @@ class ProcessGroupAgent : public RpcAgent {
  // names_[rank] stores the name of the corresponding worker, use this vector
  // to get worker name from rank and pass it to the RequestCallback.
  std::vector<std::string> names_;
-  std::deque<SendWork> sendQueue_;
-  std::mutex sendQueueMutex_;
-  std::condition_variable workProduceCV_;
-  std::condition_variable workConsumeCV_;
-  std::thread sendThread_;
+  // one mutex per ProcessGroup rank, as ProcessGroup::send is not thread-safe
+  // when using the same tag.
+  std::vector<std::mutex> sendMutexes_;
  std::thread listenerThread_;
+  // A threadPool that processing both SendWork and RecvWork. There are two
+  // motivations for adding a ThreadPool:
+  // (1) RPC serialization/deserialization and processing can be expensive,
+  //     hence using multiple threads to speed it up.
+  // (2) The current RPC API does not support asynchronous UDFs, e.g., UDFs can
+  //     not yield in the middle of execution to wait for IO, and resume the IO
+  //     is done. This would result in deadlocks when we have nested RPC calls.
+  //     NB: Ideally, this should be addressed by supporting asynchronous UDF.
+  //     This is just a temporary solution for (2).
+  ThreadPool threadPool_;
  std::unordered_map<int64_t, std::shared_ptr<FutureMessage>> futures_;
  std::mutex futureMutex_;
 };

@@ -48,7 +48,12 @@ PyObject* rpc_init(PyObject* /* unused */) {
       module, "ProcessGroupAgent", rpcAgent)
       .def(py::init<std::string,
                     std::unordered_map<std::string, int>,
-                    std::shared_ptr<::c10d::ProcessGroup>>())
+                    std::shared_ptr<::c10d::ProcessGroup>,
+                    int>(),
+           py::arg("name"),
+           py::arg("name_map"),
+           py::arg("process_group"),
+           py::arg("num_send_recv_threads") = 4)
       .def("join",
            &ProcessGroupAgent::join,
            py::call_guard<py::gil_scoped_release>())