Change Store exception handling

This commit is contained in:
Janusz Marcinkiewicz 2017-04-20 04:19:33 -07:00 committed by Adam Paszke
parent 310d08c37b
commit f07f13c6e9
9 changed files with 28 additions and 27 deletions

View File

@ -507,8 +507,6 @@ if BACKEND == 'tcp' or BACKEND == 'gloo':
# self.id() == e.g. '__main__.TestDistributed.test_get_rank'
# We're retreiving a corresponding test and executing it.
getattr(self, self.id().split(".")[2])()
if rank != 0:
time.sleep(0.2) # temporary fix for Gloo
sys.exit(0)
def _join_and_reduce(self):

View File

@ -45,7 +45,7 @@ void send_bytes(int socket, const T* buffer, std::size_t length, bool more_data
ssize_t bytes_sent;
SYSCHECK(bytes_sent = ::send(socket, current_bytes, bytes_to_send, flags))
if (bytes_sent == 0)
throw std::system_error(EBADMSG, std::system_category());
throw std::system_error(ECONNRESET, std::system_category());
bytes_to_send -= bytes_sent;
current_bytes += bytes_sent;
@ -67,7 +67,7 @@ void recv_bytes(int socket, T* buffer, std::size_t length)
ssize_t bytes_received;
SYSCHECK(bytes_received = ::recv(socket, current_bytes, bytes_to_receive, 0))
if (bytes_received == 0)
throw std::system_error(EBADMSG, std::system_category());
throw std::system_error(ECONNRESET, std::system_category());
bytes_to_receive -= bytes_received;
current_bytes += bytes_received;

View File

@ -65,6 +65,8 @@ void DataChannelGloo::RequestGloo::wait() {
DataChannelGloo::DataChannelGloo()
: _rank(load_rank_env())
, _store(nullptr)
, _cache(nullptr)
{
if (_rank == 0) {
_num_processes = load_world_size_env();
@ -73,7 +75,6 @@ DataChannelGloo::DataChannelGloo()
::gloo::transport::tcp::attr attr("localhost"); // default options listen on host
_device = ::gloo::transport::tcp::CreateDevice(attr);
// Master runs the store_thread
if (_rank == 0) {
std::tie(_port, std::ignore) = load_master_env();
_addr = "localhost";

View File

@ -32,8 +32,11 @@ struct hash<std::tuple<::thd::DataOperation, THDGroup>> {
};
template<>
struct hash<std::tuple<::thd::DataOperation, THDGroup, std::size_t, std::size_t, THDReduceOp, thd::rank_type>> {
std::size_t operator()(const std::tuple<::thd::DataOperation, THDGroup, std::size_t, std::size_t, THDReduceOp, thd::rank_type>& k) const {
struct hash<std::tuple<::thd::DataOperation, THDGroup, std::size_t,
std::size_t, THDReduceOp, thd::rank_type>> {
std::size_t operator()(const std::tuple<::thd::DataOperation, THDGroup,
std::size_t, std::size_t, THDReduceOp,
thd::rank_type>& k) const {
return (
hash<::thd::DataOperation>()(std::get<0>(k)) ^
hash<THDGroup>()(std::get<1>(k)) ^
@ -260,7 +263,7 @@ struct algorithm_spec<DataOperation::ALL_REDUCE, T> {
input_bytes, input_bytes, op, 0);
}
static GlooCache::value_type create(GlooCache& cache,
static GlooCache::value_type create(GlooCache& cache,
const DataChannel::Group& group, GlooCache::store_type& store,
std::size_t input_bytes, std::size_t count, THDReduceOp op
) {

View File

@ -35,17 +35,14 @@ Store::StoreDeamon::~StoreDeamon()
}
}
void Store::StoreDeamon::join() {
_deamon.join();
}
void Store::StoreDeamon::deamon() {
int socket;
std::tie(socket, std::ignore) = listen(_port);
// accept WORLD_SIZE connections
for (auto& p_socket : _sockets) {
std::tie(p_socket, std::ignore) = accept(socket);
}
@ -74,9 +71,19 @@ void Store::StoreDeamon::deamon() {
if (fds[rank].revents ^ POLLIN)
throw std::system_error(ECONNABORTED, std::system_category());
finished = query(rank);
if (finished)
try {
finished = query(rank);
if (finished)
break;
} catch (...) {
// There was an error when processing query. Probably exception occured in
// recv/send what would indicate that socket on the other side has been closed.
// If the closing was due to normal exit this indicates that store should
// exit too. Otherwise, when closing was due to unexpected behaviour,
// other processes will get exception when trying to use store.
finished = true;
break;
}
}
}
}
@ -128,16 +135,13 @@ bool Store::StoreDeamon::query(rank_type rank) {
}
return false;
} else if (qt == QueryType::FINISH) {
for (auto socket : _sockets) {
send_value<std::uint8_t>(socket, 1);
}
return true;
} else {
throw std::runtime_error("expected a query type");
}
}
bool Store::StoreDeamon::checkAndUpdate(std::vector<std::string>& keys) {
bool Store::StoreDeamon::checkAndUpdate(std::vector<std::string>& keys) const {
bool ret = true;
for (auto it = keys.begin(); it != keys.end();) {
if (_store.find(*it) == _store.end()) {
@ -160,7 +164,7 @@ Store::Store(rank_type rank, const std::string& addr,
, _socket(-1)
, _store_thread(nullptr)
{
// Master runs the store_thread
// only one process starts store
if (_rank == 0) {
_store_thread = std::unique_ptr<StoreDeamon>(
new StoreDeamon(port, world_size)
@ -177,8 +181,6 @@ Store::~Store() {
_store_thread->join();
}
std::uint8_t barrier_byte;
recv_bytes<std::uint8_t>(_socket, &barrier_byte, 1);
::close(_socket);
}

View File

@ -25,7 +25,7 @@ private:
void deamon();
bool query(rank_type rank);
bool checkAndUpdate(std::vector<std::string>& keys);
bool checkAndUpdate(std::vector<std::string>& keys) const;
port_type _port;

View File

@ -93,7 +93,7 @@ bool MasterCommandChannel::init() {
::close(_sockets[0]);
int fd[2];
SYSCHECK((void)::pipe(fd));
SYSCHECK(::pipe(fd));
_sockets[0] = fd[0];
_error_pipe = fd[1];
_error_thread = std::thread(&MasterCommandChannel::errorHandler, this);

View File

@ -28,7 +28,6 @@ constexpr int BARRIER_WAIT_TIME = 200; // milliseconds
std::vector<std::thread> g_all_workers;
std::mutex g_mutex;
std::string g_data_channel_type;
std::unique_ptr<Barrier> g_barrier;
void test_send_recv_tensor(std::shared_ptr<thd::DataChannel> data_channel) {
@ -685,7 +684,6 @@ void init_gloo_master(int workers) {
assert(masterChannel->init());
run_all_tests(masterChannel, workers);
g_barrier->wait();
}
void init_gloo_worker(unsigned int id, int workers) {
@ -697,7 +695,6 @@ void init_gloo_worker(unsigned int id, int workers) {
assert(worker_channel->init());
run_all_tests(worker_channel, workers);
g_barrier->wait();
}
#endif // WITH_GLOO
@ -736,7 +733,6 @@ int main(int argc, char const *argv[]) {
#ifdef WITH_GLOO
g_data_channel_type = "gloo";
for (auto workers : WORKERS_NUM) {
g_barrier.reset(new Barrier(workers + 1));
std::cout << "Gloo (workers: " << workers << "):" << std::endl;
// start gloo master
std::thread gloo_master_thread(init_gloo_master, workers);

View File

@ -35,7 +35,8 @@ void test(std::shared_ptr<thd::DataChannel> data_channel) {
}
void run_all_tests(std::shared_ptr<thd::DataChannel> data_channel, int workers) {
// NOTE: without store this test would create about (1000 * WORKERS ^ 3) connections
// NOTE: without properly working GlooCache this test would create
// about (1000 * WORKERS ^ 3) connections what is over 'normal' system configuration
for (std::size_t i = 0; i < 1000; ++i) {
test(data_channel);
}