pytorch/torch/lib/THD/base/ChannelUtils.cpp
Edward Yang 517c7c9861 Canonicalize all includes in PyTorch. (#14849)
Summary:
Anywhere we used #include "foo.h", we now say #include <foo.h>
Paths are adjusted to be rooted out of aten/src, torch/lib, or
the root level directory.

I modified CMakeLists.txt by hand to remove TH and THC from
the include paths.

I used the following script to do the canonicalization:

```
  import subprocess
  import re
  import os.path

  files = subprocess.check_output(['git', 'ls-files']).decode('utf-8').rstrip().split('\n')
  for fn in files:
      if not any(fn.endswith(suff) for suff in ['.cu', '.cpp', '.in', '.h', '.hpp', '.cu', '.cuh', '.cc']):
          continue
      if not any(fn.startswith(pref) for pref in ["aten/", "torch/"]):
          continue
      with open(fn, 'r') as f:
          c = f.read()
      def fmt(p):
          return "#include <{}>".format(p)
      def repl(m):
          p = m.group(1)
          if p in ["dlfcn.h", "unistd.h", "nvrtc.h", "cuda.h", "cuda_runtime.h", "cstdint", "cudnn.h", "Python.h", "cusparse.h", "cuda_runtime_api.h", "cuda_fp16.h", "cublas_v2.h", "stdint.h", "curand_kernel.h"]:
              return fmt(p)
          if any(p.startswith(pref) for pref in ["torch/csrc", "c10/", "ATen/", "caffe2/", "TH/", "THC/", "Eigen/", "gtest/", "zdl/", "gloo/", "onnx/", "miopen/"]):
              return fmt(p)
          for root in ["aten/src", "torch/lib", ""]:
              for bad_root in [os.path.dirname(fn), "aten/src/TH", "aten/src/THC", "torch/csrc"]:
                  new_p = os.path.relpath(os.path.join(bad_root, p), root)
                  if not new_p.startswith("../") and (os.path.exists(os.path.join(root, new_p)) or os.path.exists(os.path.join(root, new_p + ".in"))):
                      return fmt(new_p)
          print("ERROR: ", fn, p)
          return m.group(0)
      new_c = re.sub(r'#include "([^"]+)"', repl, c)
      if new_c != c:
          print(fn)
          with open(fn, 'w') as f:
              f.write(new_c)
```

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14849

Reviewed By: dzhulgakov

Differential Revision: D13363445

Pulled By: ezyang

fbshipit-source-id: 52361f878a672785f9306c9e9ab2513128092b68
2018-12-08 19:38:30 -08:00

282 lines
8.6 KiB
C++

#include <THD/base/ChannelUtils.hpp>
#include <arpa/inet.h>
#include <fcntl.h>
#include <netdb.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <sys/poll.h>
#include <unistd.h>
#include <algorithm>
#include <cstring>
#include <memory>
#include <string>
#include <thread>
namespace thd {
namespace {
constexpr int LISTEN_QUEUE_SIZE = 1024;
void setSocketNoDelay(int socket) {
int flag = 1;
socklen_t optlen = sizeof(flag);
SYSCHECK(setsockopt(socket, IPPROTO_TCP, TCP_NODELAY, (char*)&flag, optlen));
}
port_type getSocketPort(int fd) {
port_type listen_port;
struct sockaddr_storage addr_storage;
socklen_t addr_len = sizeof(addr_storage);
SYSCHECK(getsockname(
fd, reinterpret_cast<struct sockaddr*>(&addr_storage), &addr_len));
if (addr_storage.ss_family == AF_INET) {
struct sockaddr_in* addr =
reinterpret_cast<struct sockaddr_in*>(&addr_storage);
listen_port = ntohs(addr->sin_port);
} else if (addr_storage.ss_family == AF_INET6) { // AF_INET6
struct sockaddr_in6* addr =
reinterpret_cast<struct sockaddr_in6*>(&addr_storage);
listen_port = ntohs(addr->sin6_port);
} else {
throw std::runtime_error("unsupported protocol");
}
return listen_port;
}
} // anonymous namespace
std::pair<std::string, std::string> splitAddress(const std::string& addr) {
std::string host, port;
auto num_colons = std::count(addr.begin(), addr.end(), ':');
if (num_colons > 1) {
// IPv6
auto end_pos = addr.find(']');
if (addr[0] != '[' || end_pos == std::string::npos) {
throw std::invalid_argument(
"IPv6 address in an incorrect format (maybe you forgot to add [ ])");
}
host = addr.substr(1, end_pos - 1);
port = addr.substr(end_pos + 2);
} else if (num_colons == 1) {
// IPv4 or HOSTNAME:PORT
auto sep_pos = addr.find(':');
host = addr.substr(0, sep_pos);
port = addr.substr(sep_pos + 1);
} else {
throw std::invalid_argument(
"expected an address in format IP:PORT or HOSTNAME:PORT");
}
if (addr == "" || port == "") {
throw std::invalid_argument("expected an address in format IP:PORT");
}
return std::make_pair(host, port);
}
std::string sockaddrToString(struct sockaddr* addr) {
char address[INET6_ADDRSTRLEN + 1];
if (addr->sa_family == AF_INET) {
struct sockaddr_in* s = reinterpret_cast<struct sockaddr_in*>(addr);
SYSCHECK(::inet_ntop(AF_INET, &(s->sin_addr), address, INET_ADDRSTRLEN))
address[INET_ADDRSTRLEN] = '\0';
} else if (addr->sa_family == AF_INET6) {
struct sockaddr_in6* s = reinterpret_cast<struct sockaddr_in6*>(addr);
SYSCHECK(::inet_ntop(AF_INET6, &(s->sin6_addr), address, INET6_ADDRSTRLEN))
address[INET6_ADDRSTRLEN] = '\0';
} else {
throw std::runtime_error("unsupported protocol");
}
return address;
}
std::pair<int, port_type> listen(port_type port) {
struct addrinfo hints, *res = NULL;
std::memset(&hints, 0x00, sizeof(hints));
hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG;
hints.ai_family = AF_UNSPEC; // either IPv4 or IPv6
hints.ai_socktype = SOCK_STREAM; // TCP
// `getaddrinfo` will sort addresses according to RFC 3484 and can be tweeked
// by editing `/etc/gai.conf`. so there is no need to manual sorting
// or protocol preference.
int err = ::getaddrinfo(nullptr, std::to_string(port).data(), &hints, &res);
if (err != 0 || !res) {
throw std::invalid_argument(
"cannot find host to listen on: " + std::string(gai_strerror(err)));
}
std::shared_ptr<struct addrinfo> addresses(
res, [](struct addrinfo* p) { ::freeaddrinfo(p); });
struct addrinfo* next_addr = addresses.get();
int socket;
while (true) {
try {
SYSCHECK(
socket = ::socket(
next_addr->ai_family,
next_addr->ai_socktype,
next_addr->ai_protocol))
int optval = 1;
SYSCHECK(
::setsockopt(socket, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(int)))
SYSCHECK(::bind(socket, next_addr->ai_addr, next_addr->ai_addrlen))
SYSCHECK(::listen(socket, LISTEN_QUEUE_SIZE))
break;
} catch (const std::system_error& e) {
::close(socket);
next_addr = next_addr->ai_next;
// we have tried all addresses but could not start listening on any of
// them
if (!next_addr) {
throw;
}
}
}
// get listen port and address
return {socket, getSocketPort(socket)};
}
int connect(
const std::string& address,
port_type port,
bool wait,
int timeout) {
struct addrinfo hints, *res = NULL;
std::memset(&hints, 0x00, sizeof(hints));
hints.ai_flags = AI_NUMERICSERV; // specifies that port (service) is numeric
hints.ai_family = AF_UNSPEC; // either IPv4 or IPv6
hints.ai_socktype = SOCK_STREAM; // TCP
// `getaddrinfo` will sort addresses according to RFC 3484 and can be tweeked
// by editing `/etc/gai.conf`. so there is no need to manual sorting
// or protcol preference.
int err =
::getaddrinfo(address.data(), std::to_string(port).data(), &hints, &res);
if (err != 0 || !res) {
throw std::invalid_argument(
"host not found: " + std::string(gai_strerror(err)));
}
std::shared_ptr<struct addrinfo> addresses(
res, [](struct addrinfo* p) { ::freeaddrinfo(p); });
struct addrinfo* next_addr = addresses.get();
int socket;
// we'll loop over the addresses only if at least of them gave us
// ECONNREFUSED. Maybe the host was up, but the server wasn't running.
bool any_refused = false;
while (true) {
try {
SYSCHECK(
socket = ::socket(
next_addr->ai_family,
next_addr->ai_socktype,
next_addr->ai_protocol))
ResourceGuard socket_guard([socket]() { ::close(socket); });
// We need to connect in non-blocking mode, so we can use a timeout
SYSCHECK(::fcntl(socket, F_SETFL, O_NONBLOCK));
int ret = ::connect(socket, next_addr->ai_addr, next_addr->ai_addrlen);
if (ret != 0 && errno != EINPROGRESS)
throw std::system_error(errno, std::system_category());
struct pollfd pfd;
pfd.fd = socket;
pfd.events = POLLOUT;
int num_ready = ::poll(&pfd, 1, timeout);
if (num_ready < 0) {
throw std::system_error(errno, std::system_category());
} else if (num_ready == 0) {
errno = 0;
throw std::runtime_error("connect() timed out");
}
socklen_t err_len = sizeof(errno);
errno = 0;
::getsockopt(socket, SOL_SOCKET, SO_ERROR, &errno, &err_len);
/* `errno` is set when:
* 1. `getsockopt` has failed
* 2. there is awaiting error in the socket (the error is saved to the
* `errno` variable)
*/
if (errno != 0) {
throw std::system_error(errno, std::system_category());
}
// Disable non-blocking mode
int flags;
SYSCHECK(flags = ::fcntl(socket, F_GETFL));
SYSCHECK(::fcntl(socket, F_SETFL, flags & (~O_NONBLOCK)));
socket_guard.release();
break;
} catch (std::exception& e) {
if (errno == ECONNREFUSED)
any_refused = true;
// We need to move to the next address because this was not available
// to connect or to create a socket.
next_addr = next_addr->ai_next;
// We have tried all addresses but could not connect to any of them.
if (!next_addr) {
if (!wait || !any_refused)
throw;
std::this_thread::sleep_for(std::chrono::seconds(1));
any_refused = false;
next_addr = addresses.get();
}
}
}
setSocketNoDelay(socket);
return socket;
}
std::tuple<int, std::string> accept(int listen_socket, int timeout) {
// poll on listen socket, it allows to make timeout
std::unique_ptr<struct pollfd[]> events(new struct pollfd[1]);
events[0] = {.fd = listen_socket, .events = POLLIN};
while (true) {
int res = ::poll(events.get(), 1, timeout);
if (res == 0) {
throw std::runtime_error(
"waiting for processes to connect has timed out");
} else if (res == -1) {
if (errno == EINTR) {
continue;
}
throw std::system_error(errno, std::system_category());
} else {
if (!(events[0].revents & POLLIN))
throw std::system_error(ECONNABORTED, std::system_category());
break;
}
}
int socket;
SYSCHECK(socket = ::accept(listen_socket, NULL, NULL))
// Get address of the connecting process
struct sockaddr_storage addr;
socklen_t addr_len = sizeof(addr);
SYSCHECK(::getpeername(
socket, reinterpret_cast<struct sockaddr*>(&addr), &addr_len))
setSocketNoDelay(socket);
return std::make_tuple(
socket, sockaddrToString(reinterpret_cast<struct sockaddr*>(&addr)));
}
} // namespace thd