mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Change Store exception handling
This commit is contained in:
parent
310d08c37b
commit
f07f13c6e9
|
|
@ -507,8 +507,6 @@ if BACKEND == 'tcp' or BACKEND == 'gloo':
|
|||
# self.id() == e.g. '__main__.TestDistributed.test_get_rank'
|
||||
# We're retrieving a corresponding test and executing it.
|
||||
getattr(self, self.id().split(".")[2])()
|
||||
if rank != 0:
|
||||
time.sleep(0.2) # temporary fix for Gloo
|
||||
sys.exit(0)
|
||||
|
||||
def _join_and_reduce(self):
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ void send_bytes(int socket, const T* buffer, std::size_t length, bool more_data
|
|||
ssize_t bytes_sent;
|
||||
SYSCHECK(bytes_sent = ::send(socket, current_bytes, bytes_to_send, flags))
|
||||
if (bytes_sent == 0)
|
||||
throw std::system_error(EBADMSG, std::system_category());
|
||||
throw std::system_error(ECONNRESET, std::system_category());
|
||||
|
||||
bytes_to_send -= bytes_sent;
|
||||
current_bytes += bytes_sent;
|
||||
|
|
@ -67,7 +67,7 @@ void recv_bytes(int socket, T* buffer, std::size_t length)
|
|||
ssize_t bytes_received;
|
||||
SYSCHECK(bytes_received = ::recv(socket, current_bytes, bytes_to_receive, 0))
|
||||
if (bytes_received == 0)
|
||||
throw std::system_error(EBADMSG, std::system_category());
|
||||
throw std::system_error(ECONNRESET, std::system_category());
|
||||
|
||||
bytes_to_receive -= bytes_received;
|
||||
current_bytes += bytes_received;
|
||||
|
|
|
|||
|
|
@ -65,6 +65,8 @@ void DataChannelGloo::RequestGloo::wait() {
|
|||
|
||||
DataChannelGloo::DataChannelGloo()
|
||||
: _rank(load_rank_env())
|
||||
, _store(nullptr)
|
||||
, _cache(nullptr)
|
||||
{
|
||||
if (_rank == 0) {
|
||||
_num_processes = load_world_size_env();
|
||||
|
|
@ -73,7 +75,6 @@ DataChannelGloo::DataChannelGloo()
|
|||
::gloo::transport::tcp::attr attr("localhost"); // default options listen on host
|
||||
_device = ::gloo::transport::tcp::CreateDevice(attr);
|
||||
|
||||
// Master runs the store_thread
|
||||
if (_rank == 0) {
|
||||
std::tie(_port, std::ignore) = load_master_env();
|
||||
_addr = "localhost";
|
||||
|
|
|
|||
|
|
@ -32,8 +32,11 @@ struct hash<std::tuple<::thd::DataOperation, THDGroup>> {
|
|||
};
|
||||
|
||||
template<>
|
||||
struct hash<std::tuple<::thd::DataOperation, THDGroup, std::size_t, std::size_t, THDReduceOp, thd::rank_type>> {
|
||||
std::size_t operator()(const std::tuple<::thd::DataOperation, THDGroup, std::size_t, std::size_t, THDReduceOp, thd::rank_type>& k) const {
|
||||
struct hash<std::tuple<::thd::DataOperation, THDGroup, std::size_t,
|
||||
std::size_t, THDReduceOp, thd::rank_type>> {
|
||||
std::size_t operator()(const std::tuple<::thd::DataOperation, THDGroup,
|
||||
std::size_t, std::size_t, THDReduceOp,
|
||||
thd::rank_type>& k) const {
|
||||
return (
|
||||
hash<::thd::DataOperation>()(std::get<0>(k)) ^
|
||||
hash<THDGroup>()(std::get<1>(k)) ^
|
||||
|
|
@ -260,7 +263,7 @@ struct algorithm_spec<DataOperation::ALL_REDUCE, T> {
|
|||
input_bytes, input_bytes, op, 0);
|
||||
}
|
||||
|
||||
static GlooCache::value_type create(GlooCache& cache,
|
||||
static GlooCache::value_type create(GlooCache& cache,
|
||||
const DataChannel::Group& group, GlooCache::store_type& store,
|
||||
std::size_t input_bytes, std::size_t count, THDReduceOp op
|
||||
) {
|
||||
|
|
|
|||
|
|
@ -35,17 +35,14 @@ Store::StoreDeamon::~StoreDeamon()
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
void Store::StoreDeamon::join() {
|
||||
_deamon.join();
|
||||
}
|
||||
|
||||
|
||||
void Store::StoreDeamon::deamon() {
|
||||
int socket;
|
||||
|
||||
std::tie(socket, std::ignore) = listen(_port);
|
||||
// accept WORLD_SIZE connections
|
||||
for (auto& p_socket : _sockets) {
|
||||
std::tie(p_socket, std::ignore) = accept(socket);
|
||||
}
|
||||
|
|
@ -74,9 +71,19 @@ void Store::StoreDeamon::deamon() {
|
|||
if (fds[rank].revents ^ POLLIN)
|
||||
throw std::system_error(ECONNABORTED, std::system_category());
|
||||
|
||||
finished = query(rank);
|
||||
if (finished)
|
||||
try {
|
||||
finished = query(rank);
|
||||
if (finished)
|
||||
break;
|
||||
} catch (...) {
|
||||
// There was an error when processing the query. Probably an exception occurred in
|
||||
// recv/send, which would indicate that the socket on the other side has been closed.
|
||||
// If the closing was due to normal exit this indicates that store should
|
||||
// exit too. Otherwise, when closing was due to unexpected behaviour,
|
||||
// other processes will get exception when trying to use store.
|
||||
finished = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -128,16 +135,13 @@ bool Store::StoreDeamon::query(rank_type rank) {
|
|||
}
|
||||
return false;
|
||||
} else if (qt == QueryType::FINISH) {
|
||||
for (auto socket : _sockets) {
|
||||
send_value<std::uint8_t>(socket, 1);
|
||||
}
|
||||
return true;
|
||||
} else {
|
||||
throw std::runtime_error("expected a query type");
|
||||
}
|
||||
}
|
||||
|
||||
bool Store::StoreDeamon::checkAndUpdate(std::vector<std::string>& keys) {
|
||||
bool Store::StoreDeamon::checkAndUpdate(std::vector<std::string>& keys) const {
|
||||
bool ret = true;
|
||||
for (auto it = keys.begin(); it != keys.end();) {
|
||||
if (_store.find(*it) == _store.end()) {
|
||||
|
|
@ -160,7 +164,7 @@ Store::Store(rank_type rank, const std::string& addr,
|
|||
, _socket(-1)
|
||||
, _store_thread(nullptr)
|
||||
{
|
||||
// Master runs the store_thread
|
||||
// only one process starts store
|
||||
if (_rank == 0) {
|
||||
_store_thread = std::unique_ptr<StoreDeamon>(
|
||||
new StoreDeamon(port, world_size)
|
||||
|
|
@ -177,8 +181,6 @@ Store::~Store() {
|
|||
_store_thread->join();
|
||||
}
|
||||
|
||||
std::uint8_t barrier_byte;
|
||||
recv_bytes<std::uint8_t>(_socket, &barrier_byte, 1);
|
||||
::close(_socket);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ private:
|
|||
|
||||
void deamon();
|
||||
bool query(rank_type rank);
|
||||
bool checkAndUpdate(std::vector<std::string>& keys);
|
||||
bool checkAndUpdate(std::vector<std::string>& keys) const;
|
||||
|
||||
port_type _port;
|
||||
|
||||
|
|
|
|||
|
|
@ -93,7 +93,7 @@ bool MasterCommandChannel::init() {
|
|||
::close(_sockets[0]);
|
||||
|
||||
int fd[2];
|
||||
SYSCHECK((void)::pipe(fd));
|
||||
SYSCHECK(::pipe(fd));
|
||||
_sockets[0] = fd[0];
|
||||
_error_pipe = fd[1];
|
||||
_error_thread = std::thread(&MasterCommandChannel::errorHandler, this);
|
||||
|
|
|
|||
|
|
@ -28,7 +28,6 @@ constexpr int BARRIER_WAIT_TIME = 200; // milliseconds
|
|||
std::vector<std::thread> g_all_workers;
|
||||
std::mutex g_mutex;
|
||||
std::string g_data_channel_type;
|
||||
std::unique_ptr<Barrier> g_barrier;
|
||||
|
||||
|
||||
void test_send_recv_tensor(std::shared_ptr<thd::DataChannel> data_channel) {
|
||||
|
|
@ -685,7 +684,6 @@ void init_gloo_master(int workers) {
|
|||
|
||||
assert(masterChannel->init());
|
||||
run_all_tests(masterChannel, workers);
|
||||
g_barrier->wait();
|
||||
}
|
||||
|
||||
void init_gloo_worker(unsigned int id, int workers) {
|
||||
|
|
@ -697,7 +695,6 @@ void init_gloo_worker(unsigned int id, int workers) {
|
|||
|
||||
assert(worker_channel->init());
|
||||
run_all_tests(worker_channel, workers);
|
||||
g_barrier->wait();
|
||||
}
|
||||
#endif // WITH_GLOO
|
||||
|
||||
|
|
@ -736,7 +733,6 @@ int main(int argc, char const *argv[]) {
|
|||
#ifdef WITH_GLOO
|
||||
g_data_channel_type = "gloo";
|
||||
for (auto workers : WORKERS_NUM) {
|
||||
g_barrier.reset(new Barrier(workers + 1));
|
||||
std::cout << "Gloo (workers: " << workers << "):" << std::endl;
|
||||
// start gloo master
|
||||
std::thread gloo_master_thread(init_gloo_master, workers);
|
||||
|
|
|
|||
|
|
@ -35,7 +35,8 @@ void test(std::shared_ptr<thd::DataChannel> data_channel) {
|
|||
}
|
||||
|
||||
void run_all_tests(std::shared_ptr<thd::DataChannel> data_channel, int workers) {
|
||||
// NOTE: without store this test would create about (1000 * WORKERS ^ 3) connections
|
||||
// NOTE: without properly working GlooCache this test would create
|
||||
// about (1000 * WORKERS ^ 3) connections, which exceeds a 'normal' system configuration
|
||||
for (std::size_t i = 0; i < 1000; ++i) {
|
||||
test(data_channel);
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user