pytorch/torch/lib/libshm/manager.cpp
Edward Yang 517c7c9861 Canonicalize all includes in PyTorch. (#14849)
Summary:
Anywhere we used #include "foo.h", we now say #include <foo.h>
Paths are adjusted to be rooted out of aten/src, torch/lib, or
the root level directory.

I modified CMakeLists.txt by hand to remove TH and THC from
the include paths.

I used the following script to do the canonicalization:

```
  import subprocess
  import re
  import os.path

  files = subprocess.check_output(['git', 'ls-files']).decode('utf-8').rstrip().split('\n')
  for fn in files:
      if not any(fn.endswith(suff) for suff in ['.cu', '.cpp', '.in', '.h', '.hpp', '.cu', '.cuh', '.cc']):
          continue
      if not any(fn.startswith(pref) for pref in ["aten/", "torch/"]):
          continue
      with open(fn, 'r') as f:
          c = f.read()
      def fmt(p):
          return "#include <{}>".format(p)
      def repl(m):
          p = m.group(1)
          if p in ["dlfcn.h", "unistd.h", "nvrtc.h", "cuda.h", "cuda_runtime.h", "cstdint", "cudnn.h", "Python.h", "cusparse.h", "cuda_runtime_api.h", "cuda_fp16.h", "cublas_v2.h", "stdint.h", "curand_kernel.h"]:
              return fmt(p)
          if any(p.startswith(pref) for pref in ["torch/csrc", "c10/", "ATen/", "caffe2/", "TH/", "THC/", "Eigen/", "gtest/", "zdl/", "gloo/", "onnx/", "miopen/"]):
              return fmt(p)
          for root in ["aten/src", "torch/lib", ""]:
              for bad_root in [os.path.dirname(fn), "aten/src/TH", "aten/src/THC", "torch/csrc"]:
                  new_p = os.path.relpath(os.path.join(bad_root, p), root)
                  if not new_p.startswith("../") and (os.path.exists(os.path.join(root, new_p)) or os.path.exists(os.path.join(root, new_p + ".in"))):
                      return fmt(new_p)
          print("ERROR: ", fn, p)
          return m.group(0)
      new_c = re.sub(r'#include "([^"]+)"', repl, c)
      if new_c != c:
          print(fn)
          with open(fn, 'w') as f:
              f.write(new_c)
```

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14849

Reviewed By: dzhulgakov

Differential Revision: D13363445

Pulled By: ezyang

fbshipit-source-id: 52361f878a672785f9306c9e9ab2513128092b68
2018-12-08 19:38:30 -08:00

169 lines
4.3 KiB
C++

#include <errno.h>
#include <fcntl.h>
#include <poll.h>
#include <sys/mman.h>
#include <unistd.h>

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <memory>
#include <set>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <torch/csrc/utils/tempfile.h>
#include <c10/util/Optional.h>
#include <libshm/err.h>
#include <libshm/socket.h>
// poll() timeout (milliseconds) used while no clients are connected. If it
// elapses with no events, the manager exits and unlinks every shm object it
// is still tracking.
constexpr int SHUTDOWN_TIMEOUT = 2000; // 2s

// DEBUG(fmt, ...) prints a red, newline-terminated message to stderr when the
// file is built with -DDEBUG_LOG, and expands to a no-op otherwise.
#ifdef DEBUG_LOG
#define COLOR "\033[31;1m"
#define RESET "\033[0m"
// The trailing "%c" receives the '\n' that DEBUG appends, so DEBUG also works
// with a bare format string and no further arguments.
// No trailing semicolon: the caller supplies it, which keeps
// `if (c) DEBUG(...); else ...` well-formed. The name avoids a leading double
// underscore, which is reserved for the implementation.
#define DEBUG_IMPL(msg, ...) fprintf(stderr, COLOR msg "%c" RESET, __VA_ARGS__)
#define DEBUG(...) DEBUG_IMPL(__VA_ARGS__, '\n')
#else
#define DEBUG(...) (void)0
#endif
// State the manager keeps for one connected client.
struct ClientSession {
  // Takes ownership of the accepted connection. pid remains 0 until main()
  // copies it from the first AllocInfo this client sends.
  ClientSession(ManagerSocket s) : socket(std::move(s)) {}

  ManagerSocket socket;
  pid_t pid = 0;
};
// Descriptors handed to poll(): the listening socket plus one per client.
std::vector<pollfd> pollfds;
// Connected clients, keyed by their socket descriptor.
std::unordered_map<int, ClientSession> client_sessions;
// Names of shm objects registered by clients; every name still present when
// the manager exits is passed to shm_unlink().
// TODO: check if objects have been freed from time to time
std::set<std::string> used_objects;
// Appends fd to the poll set, watching it for readable data (POLLIN).
void register_fd(int fd) {
  pollfd entry{}; // value-initialized: every field, including revents, is 0
  entry.fd = fd;
  entry.events = POLLIN;
  pollfds.push_back(entry);
}
// Removes every poll entry for fd and erases the matching client session, if
// there is one. Erasing the session destroys its ManagerSocket.
void unregister_fd(int fd) {
  const auto matches_fd = [fd](const pollfd &entry) { return entry.fd == fd; };
  auto new_end = std::remove_if(pollfds.begin(), pollfds.end(), matches_fd);
  pollfds.erase(new_end, pollfds.end());
  client_sessions.erase(fd);
}
// Writes all len bytes of buf to fd. Short writes and EINTR are retried.
// Returns false as soon as write() fails in any other way (or makes no
// progress).
static bool write_whole_buffer(int fd, const char *buf, size_t len) {
  while (len > 0) {
    ssize_t written = write(fd, buf, len);
    if (written <= 0) {
      if (written < 0 && errno == EINTR)
        continue;
      return false;
    }
    buf += written;
    len -= static_cast<size_t>(written);
  }
  return true;
}

// Writes `message` followed by '\n' to stdout (fd 1). main() uses this to
// report the manager socket filename, or "ERROR", during startup.
// Best-effort: write errors are not reported to the caller. Short writes and
// EINTR are retried so the line is not silently truncated.
void print_init_message(const char *message) {
  if (write_whole_buffer(1, message, strlen(message)))
    write_whole_buffer(1, "\n", 1);
}
// Reports whether the POSIX shared-memory object `name` still exists.
// Only ENOENT from shm_open() counts as proof that the object is gone. Any
// other failure (e.g. EACCES, EMFILE) is reported as "exists". The caller then
// keeps the name in used_objects, so it is still shm_unlink()ed at shutdown
// rather than leaked.
bool object_exists(const char *name) {
  int fd = shm_open(name, O_RDONLY, 0);
  if (fd < 0)
    return errno != ENOENT;
  close(fd);
  return true;
}
// Drops `name` from used_objects once its shm object no longer exists. A name
// whose object is still present stays tracked and is unlinked when the manager
// exits.
void free_used_object(const std::string &name) {
  const char *c_name = name.c_str();
  if (object_exists(c_name)) {
    DEBUG("object %s still exists", c_name);
    return;
  }
  DEBUG("object %s appears to have been freed", c_name);
  used_objects.erase(name);
}
// Entry point of the shm manager daemon.
//
// Startup:
//  - Calls setsid().
//  - Creates a ManagerServerSocket on a fresh temporary filename.
//  - Prints that filename as one line on stdout. On failure it prints "ERROR"
//    instead and rethrows.
//
// Event loop: polls the listening socket and every client connection.
//  * A readable listening socket means a new client: it becomes a ClientSession.
//  * A readable client socket carries an AllocInfo.
//    - If it is a registration, the name is added to used_objects and the
//      client is confirm()ed.
//    - If it is a free request, the name is re-checked and forgotten when its
//      shm object no longer exists.
//  * A descriptor reporting POLLERR/POLLHUP/POLLNVAL is dropped.
// Descriptors are added and removed only after the pass over pollfds, so the
// vector is never modified while it is being iterated.
//
// Shutdown: once no clients are connected and SHUTDOWN_TIMEOUT ms pass with no
// events, every name still in used_objects is shm_unlink()ed.
int main(int argc, char *argv[]) {
  setsid(); // Daemonize the process
  std::unique_ptr<ManagerServerSocket> srv_socket;
  const auto tempfile =
      torch::utils::try_make_tempfile(/*name_prefix=*/"torch-shm-file-");
  try {
    if (!tempfile.has_value()) {
      throw std::runtime_error(
          "could not generate a random filename for manager socket");
    }
    // TODO: better strategy for generating tmp names
    // TODO: retry on collisions - this can easily fail
    srv_socket.reset(new ManagerServerSocket(tempfile->name));
    register_fd(srv_socket->socket_fd);
    print_init_message(tempfile->name.c_str());
    DEBUG("opened socket %s", tempfile->name.c_str());
  } catch (...) {
    // Report the failed startup on stdout before propagating the error.
    print_init_message("ERROR");
    throw;
  }

  int timeout = -1;
  std::vector<int> to_add;
  std::vector<int> to_remove;
  for (;;) {
    int nevents;
    // Only time out while idle; any connected client keeps the manager alive.
    if (client_sessions.empty())
      timeout = SHUTDOWN_TIMEOUT;
    SYSCHECK(nevents = poll(pollfds.data(), pollfds.size(), timeout));
    timeout = -1;
    if (nevents == 0 && client_sessions.empty())
      break;

    for (auto &pfd : pollfds) {
      if (pfd.revents & (POLLERR | POLLHUP | POLLNVAL)) {
        // Some process died, or the descriptor is no longer usable.
        // POLLNVAL is reported on every poll() call until the fd is removed.
        // Without this branch it would spin the loop forever.
        DEBUG("detaching process");
        // Use find() rather than at(): the listening socket is in pollfds but
        // not in client_sessions. at() would throw std::out_of_range, and the
        // uncaught exception would terminate the manager before used_objects
        // are unlinked.
        auto it = client_sessions.find(pfd.fd);
        if (it != client_sessions.end()) {
          DEBUG("%d has died", it->second.pid);
        }
        to_remove.push_back(pfd.fd);
      } else if (pfd.revents & POLLIN) {
        if (pfd.fd == srv_socket->socket_fd) {
          // someone is joining
          DEBUG("registered new client");
          auto client = srv_socket->accept();
          int fd = client.socket_fd;
          to_add.push_back(fd);
          client_sessions.emplace(fd, std::move(client));
        } else {
          // someone wants to register a segment
          DEBUG("got alloc info");
          auto &session = client_sessions.at(pfd.fd);
          AllocInfo info = session.socket.receive();
          session.pid = info.pid;
          DEBUG("got alloc info: %d %d %s", static_cast<int>(info.free),
                info.pid, info.filename);
          if (info.free) {
            free_used_object(info.filename);
          } else {
            used_objects.insert(info.filename);
            DEBUG("registered object %s", info.filename);
            // Only registrations are acknowledged; free requests are not.
            session.socket.confirm();
          }
        }
      }
    }
    for (int fd : to_add)
      register_fd(fd);
    to_add.clear();
    for (int fd : to_remove)
      unregister_fd(fd);
    to_remove.clear();
  }

  for (auto &obj_name : used_objects) {
    DEBUG("freeing %s", obj_name.c_str());
    shm_unlink(obj_name.c_str());
  }
  DEBUG("manager done");
  return 0;
}