Use c10::ThreadPool to send and receive messages (#23968)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/23968

The existing ProcessGroupAgent uses a single thread to send all messages, and
a single thread to listen for and process all received messages. This causes
performance problems and also prevents nested RPCs. For example, when running
a nested RPC A->B->A->B, the second recv on B cannot start until the first
recv on B finishes. If the second recv is triggered by a nested RPC issued
from within the first recv, it will deadlock. Ideally, we should expose
something like a responder or FutureResult to the Python land to support
nested asynchronous UDFs.

This diff adds a shared ThreadPool for send and recv: send uses it to send out
messages, and recv uses it to process received messages. There is still a
dedicated thread that listens for incoming messages and adds them to the task
queue. There are two goals: 1) speed up ProcessGroupAgent, and 2) use the
ThreadPool as a temporary solution for (a small number of) nested RPCs.
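
To illustrate the idea, here is a minimal, hedged sketch (not code from this
diff) of the pattern the agent now follows, using only the c10::ThreadPool
APIs the diff itself relies on (run() and waitWorkComplete()): a handler
running on one pool thread can enqueue nested work for another pool thread
instead of blocking the only processing thread.

#include <c10/core/thread_pool.h>
#include <iostream>

int main() {
  // Shared pool, analogous to the numSendRecvThreads-sized pool added in this diff.
  c10::ThreadPool pool(4);

  // "send" work: serialize and ship a message out.
  pool.run([] { std::cout << "send: outer request\n"; });

  // "recv" work: process an incoming request; the handler may itself enqueue
  // more work (a nested RPC) without blocking the thread that received it.
  pool.run([&pool] {
    std::cout << "recv: outer request\n";
    pool.run([] { std::cout << "recv: nested request\n"; });
  });

  // join()/sync() in this diff drain pending tasks the same way.
  pool.waitWorkComplete();
  return 0;
}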

ghstack-source-id: 88476246

Differential Revision: D16695091

fbshipit-source-id: fd18a5c65e7fcd1331b73d1287673e6e10d2dd86
Authored by Shen Li on 2019-08-16 17:47:33 -07:00; committed by Facebook Github Bot
parent dd97743de7
commit 99dea08e60
4 changed files with 185 additions and 99 deletions


@@ -19,6 +19,17 @@ def my_function(a, b, c):
def no_result():
    print("do nothing")


def nested_rpc(dst):
    return dist.rpc(dst, torch.add, args=(torch.ones(2, 2), 1))


def light_rpc():
    return 0


def heavy_rpc(tensor):
    for i in range(1, 100):
        tensor *= i
        tensor /= i + 1
    return 0


# it is used to test python user defined class and methods over rpc
class my_class:
@@ -222,6 +233,39 @@ class RpcTest(MultiProcessTestCase):
            expected = "run_python_udf_internal caught exception: " + str(e)
        self.assertEqual(ret, expected)

    @_wrap_with_rpc
    def test_nested_rpc(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = dist.rpc('worker{}'.format(dst_rank), nested_rpc,
                       args=('worker{}'.format(self.rank),))
        self.assertEqual(ret, torch.ones(2, 2) + 1)

    def _stress_test_rpc(self, f, repeat=1000, args=()):
        import time
        n = self.rank + 1
        dst_rank = n % self.world_size
        futs = []
        tik = time.time()
        for _ in range(repeat):
            fut = dist.rpc('worker{}'.format(dst_rank), f, args=args, async_call=True)
            futs.append(fut)
        for fut in futs:
            self.assertEqual(fut.wait(), 0)
        tok = time.time()
        print("Rank {} finished testing {} {} times in {} seconds.".format(
            self.rank, f.__name__, repeat, tok - tik
        ))

    @_wrap_with_rpc
    def test_stress_light_rpc(self):
        self._stress_test_rpc(light_rpc)

    @_wrap_with_rpc
    def test_stress_heavy_rpc(self):
        self._stress_test_rpc(heavy_rpc, repeat=20, args=(torch.ones(100, 100),))


if __name__ == '__main__':
    run_tests()


@@ -1,4 +1,6 @@
#include <torch/csrc/distributed/rpc/ProcessGroupAgent.h>
#include <c10d/ProcessGroup.hpp>
#include <Python.h>
namespace torch {
@@ -20,28 +22,25 @@ void serialize(const Message& message, std::ostream& os) {
std::vector<torch::Tensor> tensors = message.tensors();
// append payload as a tensor
tensors.push_back(torch::from_blob(payload, payload_size, {torch::kChar}));
// append id and type as a tensor
tensors.push_back(torch::tensor(
{message.id(), (int64_t) message.type()}, {torch::kInt64}
// append id as a tensor
tensors.push_back(torch::tensor({message.id()}, {torch::kInt64}
));
torch::save(tensors, os);
}
Message deserialize(std::istream& is) {
Message deserialize(MessageType type, std::istream& is) {
std::vector<torch::Tensor> tensors;
torch::load(tensors, is);
TORCH_CHECK(tensors.size() >= 2, "Failed to deserialize a message.");
auto miscTensor = std::move(tensors.back());
auto idTensor = std::move(tensors.back());
tensors.pop_back();
auto payloadTensor = std::move(tensors.back());
tensors.pop_back();
int64_t* miscItems = miscTensor.storage().data<int64_t>();
int64_t id = miscItems[0];
MessageType type = MessageType(miscItems[1]);
int64_t id = idTensor.storage().data<int64_t>()[0];
std::vector<char> payload(payloadTensor.numel());
@@ -59,12 +58,15 @@ Message deserialize(std::istream& is) {
ProcessGroupAgent::ProcessGroupAgent(
std::string workerName,
std::unordered_map<std::string, int> nameMap,
std::shared_ptr<c10d::ProcessGroup> pg)
std::shared_ptr<c10d::ProcessGroup> pg,
int numSendRecvThreads)
: RpcAgent(std::move(workerName), processRequestBlocking),
nameMap_(std::move(nameMap)),
stop_(false),
pg_(std::move(pg)),
nextId_(0) {
nextId_(0),
sendMutexes_(pg_->getSize()),
threadPool_(numSendRecvThreads) {
TORCH_CHECK(nameMap_.size() > 1, "ProcessGroupAgent requires world_size to "
"be at least 2, but got ", nameMap_.size());
auto workerRankIter = nameMap_.find(workerName_);
@@ -79,7 +81,6 @@ ProcessGroupAgent::ProcessGroupAgent(
names_[entry.second] = entry.first;
}
PythonRpcHandler::init();
sendThread_ = std::thread(&ProcessGroupAgent::sendLoop, this);
listenerThread_ = std::thread(&ProcessGroupAgent::listenLoop, this);
}
@@ -92,14 +93,8 @@ void ProcessGroupAgent::join() {
// effort to fix this problem).
sync();
int dst = (pg_->getRank() + 1) % pg_->getSize();
enqueue(SendWork(dst, Message({}, {}, MessageType::SHUTDOWN)));
std::unique_lock<std::mutex> lock(sendQueueMutex_);
workConsumeCV_.wait(lock, [&] { return sendQueue_.empty(); });
stop_ = true;
lock.unlock();
workProduceCV_.notify_all();
sendThread_.join();
enqueueSend(SendWork(dst, Message({}, {}, MessageType::SHUTDOWN)));
threadPool_.waitWorkComplete();
listenerThread_.join();
}
@@ -108,11 +103,9 @@ void ProcessGroupAgent::sync() {
// the lock below, because other processes might not enter sync() until it
// gets some response from this RpcAgent.
pg_->barrier()->wait();
// Acquire the lock on the send queue to prevent additional messages to be put
// onto the send queue.
std::unique_lock<std::mutex> lock(sendQueueMutex_);
// Wait until the send queue is depleted.
workConsumeCV_.wait(lock, [&] { return sendQueue_.empty(); });
// Wait until all send works are done.
// NB: There might be additional send works inserted while waiting.
threadPool_.waitWorkComplete();
// Use another barrier in case different RpcAgent handles different amounts of
// workloads.
pg_->barrier()->wait();
@@ -140,87 +133,111 @@ std::shared_ptr<FutureMessage> ProcessGroupAgent::send(
future->markCompleted();
}
enqueue(SendWork(dstRank, std::move(message)));
enqueueSend(SendWork(dstRank, std::move(message)));
return future;
}
void ProcessGroupAgent::enqueue(SendWork work) {
std::unique_lock<std::mutex> lock(sendQueueMutex_);
sendQueue_.emplace_back(std::move(work));
lock.unlock();
void ProcessGroupAgent::enqueueSend(SendWork work) {
// NB: this can be changed to use a native move capture when moved to C++14
threadPool_.run(std::bind(
[&](const SendWork& work) {
std::stringstream ss;
serialize(work.message_, ss);
std::string serializedPayload = ss.str();
workProduceCV_.notify_one();
std::vector<torch::Tensor> preamble = {
torch::tensor(
{
(int64_t)pg_->getRank(),
(int64_t)serializedPayload.length(),
(int64_t)work.message_.type()
}, {torch::kLong})
};
// ProcessGroup is not thread-safe when sending with the same tag, hence
// the lock
std::vector<std::shared_ptr<c10d::ProcessGroup::Work>> pendingSends;
if (work.message_.isShutdown()) {
pendingSends.reserve(1);
std::lock_guard<std::mutex> guard(sendMutexes_[work.to_]);
pendingSends.emplace_back(
pg_->send(preamble, work.to_, work.to_ /* channelTag */));
} else {
std::vector<torch::Tensor> payload = {
torch::from_blob(
(void *)serializedPayload.c_str(),
serializedPayload.length(),
{torch::kChar}
)
};
pendingSends.reserve(2);
std::lock_guard<std::mutex> guard(sendMutexes_[work.to_]);
pendingSends.emplace_back(
pg_->send(preamble, work.to_, work.to_ /* channelTag */));
pendingSends.emplace_back(
pg_->send(payload, work.to_, work.to_ /* channelTag */));
}
for (auto& pendingSend: pendingSends) {
pendingSend->wait();
}
},
std::move(work)
));
}
// making sure tensors are not deleted before send finishes
void ProcessGroupAgent::sendLoop() {
std::unique_lock<std::mutex> lock(sendQueueMutex_);
void ProcessGroupAgent::enqueueRecv(RecvWork work) {
threadPool_.run(std::bind(
[&](RecvWork& work) {
while (!stop_) {
if (sendQueue_.empty()) {
workProduceCV_.wait(lock);
continue;
}
torch::Tensor& payload = work.payload_;
std::stringstream ss(std::string(
(char*)payload.storage().data<signed char>(), payload.numel()));
auto work = std::move(sendQueue_.front());
sendQueue_.pop_front();
lock.unlock();
Message message = deserialize(work.type_, ss);
workConsumeCV_.notify_one();
std::stringstream ss;
serialize(work.message_, ss);
std::string str = ss.str();
std::vector<torch::Tensor> preamble = {
torch::tensor(
if (message.isRequest()) {
cb_(names_[work.from_], std::move(message), *this);
} else if (message.isResponse()) {
auto id = message.id();
{
(int64_t)pg_->getRank(),
(int64_t)str.length(),
}, {torch::kLong})
};
pg_->send(preamble, work.dstRank_, work.dstRank_ /* channelTag */)->wait();
std::vector<torch::Tensor> payload =
{torch::from_blob((void *)str.c_str(), str.length(), {torch::kChar})};
pg_->send(payload, work.dstRank_, work.dstRank_ /* channelTag */)->wait();
lock.lock();
}
std::lock_guard<std::mutex> lock{futureMutex_};
futures_[id]->markCompleted(std::move(message));
futures_.erase(id);
}
} else {
// TODO: pass the error back to the caller instead of crashing here.
AT_ERROR("unrecognized message type ", message.type());
}
},
std::move(work)
));
}
void ProcessGroupAgent::listenLoop() {
while (true) {
// rank, tensor size
std::vector<torch::Tensor> preamble = {torch::empty({2}, {torch::kInt64})};
// rank, tensor size, message type
std::vector<torch::Tensor> preamble = {torch::empty({3}, {torch::kInt64})};
pg_->recvAnysource(preamble, pg_->getRank())->wait();
int64_t* preamble_items = preamble.front().storage().data<int64_t>();
auto srcRank = preamble_items[0];
auto size = preamble_items[1];
MessageType type = MessageType(preamble_items[2]);
if (type == MessageType::SHUTDOWN) {
// FIXME: This LOG also prints a warning that no InitGoogleLogging() was
// invoked before logging, but it is not appropriate to call
// InitGoogleLogging() here either.
LOG(INFO) << "Shutting down ProcessGroupAgent "
<< workerName_ << std::endl;
return;
}
std::vector<torch::Tensor> tensors = {torch::empty({size}, {torch::kChar})};
pg_->recv(tensors, srcRank, pg_->getRank())->wait();
std::stringstream ss(std::string(
(char*)tensors[0].storage().data<signed char>(), tensors[0].numel()));
Message message = deserialize(ss);
if (message.isRequest()) {
cb_(names_[srcRank], std::move(message), *this);
} else if (message.isResponse()) {
auto id = message.id();
{
std::lock_guard<std::mutex> lock{futureMutex_};
futures_[id]->markCompleted(std::move(message));
futures_.erase(id);
}
} else if (message.isShutdown()) {
break;
} else {
AT_ERROR("unrecognized message type ", message.type());
}
enqueueRecv(RecvWork(srcRank, type, std::move(tensors[0])));
}
}
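
As a side note on the wire format used above: serialize() packs the payload
bytes and the message id into tensors and writes them with torch::save, while
the message type now travels in the 3-element preamble so that listenLoop()
can see it (e.g., to handle SHUTDOWN) without deserializing the payload. The
following standalone round-trip is a hedged sketch of that scheme, not code
from this diff; the payload string and id are made-up placeholders.

#include <torch/torch.h>
#include <sstream>
#include <string>
#include <vector>

int main() {
  std::string payload = "pickled python UDF bytes";  // placeholder payload
  int64_t id = 7;                                    // placeholder message id

  // Sender side, mirroring serialize(): payload and id become tensors.
  std::vector<torch::Tensor> tensors;
  tensors.push_back(
      torch::from_blob((void*)payload.data(), (int64_t)payload.size(), {torch::kChar}).clone());
  tensors.push_back(torch::tensor({id}, {torch::kInt64}));

  std::stringstream ss;
  torch::save(tensors, ss);

  // Receiver side, mirroring deserialize(); the MessageType would come from
  // the preamble tensor, not from this stream.
  std::vector<torch::Tensor> loaded;
  torch::load(loaded, ss);
  int64_t recoveredId = loaded.back().item<int64_t>();
  loaded.pop_back();
  std::string recoveredPayload(
      (char*)loaded.back().data_ptr(), (size_t)loaded.back().numel());

  return (recoveredId == id && recoveredPayload == payload) ? 0 : 1;
}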


@@ -1,27 +1,38 @@
#pragma once
#include <c10/core/thread_pool.h>
#include <c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/rpc/FutureMessage.h>
#include <torch/csrc/distributed/rpc/RpcAgent.h>
#include <torch/csrc/distributed/rpc/functions.h>
#include <torch/csrc/distributed/rpc/PythonRpcHandler.h>
#include <deque>
#include <thread>
namespace torch {
namespace distributed {
namespace rpc {
// SendWork and RecvWork will be put into a task queue, and later picked up by
// worker threads from the same ThreadPool.
struct SendWork {
SendWork(const int dstRank,
Message&& message)
: dstRank_(dstRank),
message_(message) {}
SendWork(const int to, Message&& message) :
to_(to), message_(message) {}
const int dstRank_;
const int to_;
Message message_;
};
// SendWork wraps a Message and RecvWork wraps a Tensor. This difference
// allows us to run serialization/deserialization in the worker threads.
struct RecvWork {
RecvWork(const int from, MessageType type, torch::Tensor&& payload)
: from_(from), type_(type), payload_(payload) {}
const int from_;
const MessageType type_;
torch::Tensor payload_;
};
class ProcessGroupAgent : public RpcAgent {
@@ -29,7 +40,8 @@ class ProcessGroupAgent : public RpcAgent {
ProcessGroupAgent(std::string workerName,
std::unordered_map<std::string, int> nameMap,
std::shared_ptr<c10d::ProcessGroup> pg);
std::shared_ptr<c10d::ProcessGroup> pg,
int numSendRecvThreads = 4);
// This method wraps the destination information and the message into a
// SendWork object, and puts the SendWork into a queue. Another thread will
@@ -42,10 +54,10 @@ class ProcessGroupAgent : public RpcAgent {
void sync() override;
private:
// put SendWork into a queue and notify the sendLoop thread
void enqueue(SendWork work);
// sending out the message
void sendLoop();
// put SendWork into a queue and notify the worker thread
void enqueueSend(SendWork work);
// put RecvWork into a queue and notify the worker thread
void enqueueRecv(RecvWork work);
// receiving messages
void listenLoop();
@@ -61,12 +73,20 @@ class ProcessGroupAgent : public RpcAgent {
// names_[rank] stores the name of the corresponding worker, use this vector
// to get worker name from rank and pass it to the RequestCallback.
std::vector<std::string> names_;
std::deque<SendWork> sendQueue_;
std::mutex sendQueueMutex_;
std::condition_variable workProduceCV_;
std::condition_variable workConsumeCV_;
std::thread sendThread_;
// one mutex per ProcessGroup rank, as ProcessGroup::send is not thread-safe
// when using the same tag.
std::vector<std::mutex> sendMutexes_;
std::thread listenerThread_;
// A threadPool that processes both SendWork and RecvWork. There are two
// motivations for adding a ThreadPool:
// (1) RPC serialization/deserialization and processing can be expensive,
// hence we use multiple threads to speed it up.
// (2) The current RPC API does not support asynchronous UDFs, e.g., UDFs
// cannot yield in the middle of execution to wait for IO and resume when
// the IO is done. This would result in deadlocks when we have nested RPC
// calls.
// NB: Ideally, (2) should be addressed by supporting asynchronous UDFs.
// This ThreadPool is just a temporary solution for it.
ThreadPool threadPool_;
std::unordered_map<int64_t, std::shared_ptr<FutureMessage>> futures_;
std::mutex futureMutex_;
};


@@ -48,7 +48,12 @@ PyObject* rpc_init(PyObject* /* unused */) {
module, "ProcessGroupAgent", rpcAgent)
.def(py::init<std::string,
std::unordered_map<std::string, int>,
std::shared_ptr<::c10d::ProcessGroup>>())
std::shared_ptr<::c10d::ProcessGroup>,
int>(),
py::arg("name"),
py::arg("name_map"),
py::arg("process_group"),
py::arg("num_send_recv_threads") = 4)
.def("join",
&ProcessGroupAgent::join,
py::call_guard<py::gil_scoped_release>())