mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary:
Anywhere we used #include "foo.h", we now say #include <foo.h>
Paths are adjusted to be rooted out of aten/src, torch/lib, or
the root level directory.
I modified CMakeLists.txt by hand to remove TH and THC from
the include paths.
I used the following script to do the canonicalization:
```
import subprocess
import re
import os.path
files = subprocess.check_output(['git', 'ls-files']).decode('utf-8').rstrip().split('\n')
for fn in files:
    if not any(fn.endswith(suff) for suff in ['.cu', '.cpp', '.in', '.h', '.hpp', '.cu', '.cuh', '.cc']):
        continue
    if not any(fn.startswith(pref) for pref in ["aten/", "torch/"]):
        continue
    with open(fn, 'r') as f:
        c = f.read()
    def fmt(p):
        return "#include <{}>".format(p)
    def repl(m):
        p = m.group(1)
        if p in ["dlfcn.h", "unistd.h", "nvrtc.h", "cuda.h", "cuda_runtime.h", "cstdint", "cudnn.h", "Python.h", "cusparse.h", "cuda_runtime_api.h", "cuda_fp16.h", "cublas_v2.h", "stdint.h", "curand_kernel.h"]:
            return fmt(p)
        if any(p.startswith(pref) for pref in ["torch/csrc", "c10/", "ATen/", "caffe2/", "TH/", "THC/", "Eigen/", "gtest/", "zdl/", "gloo/", "onnx/", "miopen/"]):
            return fmt(p)
        for root in ["aten/src", "torch/lib", ""]:
            for bad_root in [os.path.dirname(fn), "aten/src/TH", "aten/src/THC", "torch/csrc"]:
                new_p = os.path.relpath(os.path.join(bad_root, p), root)
                if not new_p.startswith("../") and (os.path.exists(os.path.join(root, new_p)) or os.path.exists(os.path.join(root, new_p + ".in"))):
                    return fmt(new_p)
        print("ERROR: ", fn, p)
        return m.group(0)
    new_c = re.sub(r'#include "([^"]+)"', repl, c)
    if new_c != c:
        print(fn)
        with open(fn, 'w') as f:
            f.write(new_c)
```
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14849
Reviewed By: dzhulgakov
Differential Revision: D13363445
Pulled By: ezyang
fbshipit-source-id: 52361f878a672785f9306c9e9ab2513128092b68
169 lines
4.3 KiB
C++
169 lines
4.3 KiB
C++
#include <sys/mman.h>
|
|
#include <poll.h>
|
|
#include <errno.h>
|
|
#include <unistd.h>
|
|
#include <fcntl.h>
|
|
#include <vector>
|
|
#include <set>
|
|
#include <algorithm>
|
|
#include <memory>
|
|
#include <unordered_map>
|
|
|
|
#include <torch/csrc/utils/tempfile.h>
|
|
#include <c10/util/Optional.h>
|
|
|
|
#include <libshm/err.h>
|
|
#include <libshm/socket.h>
|
|
|
|
// Idle grace period, in milliseconds (2 seconds): used as the poll() timeout
// while no client is connected.
constexpr int SHUTDOWN_TIMEOUT = 2000;
|
|
|
|
// Debug logging.
//
// With DEBUG_LOG defined, DEBUG(fmt, ...) prints a printf-style message to
// stderr, wrapped in the COLOR/RESET escape sequences and followed by a
// newline. Without it, DEBUG(...) expands to a no-op expression.
//
// Both forms expand to a single expression WITHOUT a trailing semicolon, so
// `DEBUG(...);` is exactly one statement and is safe in an unbraced if/else.
#ifdef DEBUG_LOG
#define COLOR "\033[31;1m"
#define RESET "\033[0m"
// DEBUG appends '\n' to the argument list; the "%c" added here consumes it.
// That also guarantees __VA_ARGS__ is never empty, even for DEBUG("text").
#define SHM_MANAGER_DEBUG_IMPL(msg, ...) \
  fprintf(stderr, COLOR msg "%c" RESET, __VA_ARGS__)
#define DEBUG(...) SHM_MANAGER_DEBUG_IMPL(__VA_ARGS__, '\n')
#else
#define DEBUG(...) (void)0
#endif
|
|
|
|
struct ClientSession {
|
|
ClientSession(ManagerSocket s): socket(std::move(s)), pid(0) {}
|
|
|
|
ManagerSocket socket;
|
|
pid_t pid;
|
|
};
|
|
|
|
|
|
std::vector<struct pollfd> pollfds;
|
|
std::unordered_map<int, ClientSession> client_sessions;
|
|
// TODO: check if objects have been freed from time to time
|
|
std::set<std::string> used_objects;
|
|
|
|
|
|
// Adds `fd` to the set of descriptors polled for readability (POLLIN).
void register_fd(int fd) {
  struct pollfd entry = {};
  entry.fd = fd;
  entry.events = POLLIN;
  pollfds.push_back(entry);
}
|
|
|
|
|
|
// Stops polling `fd` and drops the client session stored under it, if any.
void unregister_fd(int fd) {
  const auto refers_to_fd = [fd](const struct pollfd &pfd) {
    return pfd.fd == fd;
  };
  pollfds.erase(
      std::remove_if(pollfds.begin(), pollfds.end(), refers_to_fd),
      pollfds.end());
  client_sessions.erase(fd);
}
|
|
|
|
|
|
// Writes `message` followed by a newline to stdout (fd 1) using raw write().
//
// Delivery is best-effort, since there is nowhere to report a failure, but
// short writes and EINTR are retried so the message is not silently
// truncated. Any other error abandons the chunk being written.
void print_init_message(const char *message) {
  const auto write_fully = [](const char *buf, size_t len) {
    while (len > 0) {
      const ssize_t written = write(1, buf, len);
      if (written > 0) {
        buf += written;
        len -= static_cast<size_t>(written);
        continue;
      }
      if (written < 0 && errno == EINTR) {
        continue; // interrupted before anything was written: try again
      }
      return; // hard error (or no progress): give up on this chunk
    }
  };
  write_fully(message, strlen(message));
  write_fully("\n", 1);
}
|
|
|
|
// Returns true if the shared-memory object `name` can be opened read-only.
// Any shm_open() failure is reported as "does not exist". The probe
// descriptor is closed before returning.
bool object_exists(const char *name) {
  const int fd = shm_open(name, O_RDONLY, 0);
  if (fd < 0) {
    return false;
  }
  close(fd);
  return true;
}
|
|
|
|
void free_used_object(const std::string &name) {
|
|
if (!object_exists(name.c_str())) {
|
|
DEBUG("object %s appears to have been freed", name.c_str());
|
|
used_objects.erase(name);
|
|
} else {
|
|
DEBUG("object %s still exists", name.c_str());
|
|
}
|
|
}
|
|
|
|
int main(int argc, char *argv[]) {
|
|
setsid(); // Daemonize the process
|
|
|
|
std::unique_ptr<ManagerServerSocket> srv_socket;
|
|
const auto tempfile =
|
|
torch::utils::try_make_tempfile(/*name_prefix=*/"torch-shm-file-");
|
|
try {
|
|
if (!tempfile.has_value()) {
|
|
throw std::runtime_error(
|
|
"could not generate a random filename for manager socket");
|
|
}
|
|
// TODO: better strategy for generating tmp names
|
|
// TODO: retry on collisions - this can easily fail
|
|
srv_socket.reset(new ManagerServerSocket(tempfile->name));
|
|
register_fd(srv_socket->socket_fd);
|
|
print_init_message(tempfile->name.c_str());
|
|
DEBUG("opened socket %s", tempfile->name.c_str());
|
|
} catch (...) {
|
|
print_init_message("ERROR");
|
|
throw;
|
|
}
|
|
|
|
int timeout = -1;
|
|
std::vector<int> to_add;
|
|
std::vector<int> to_remove;
|
|
for (;;) {
|
|
int nevents;
|
|
if (client_sessions.size() == 0)
|
|
timeout = SHUTDOWN_TIMEOUT;
|
|
SYSCHECK(nevents = poll(pollfds.data(), pollfds.size(), timeout));
|
|
timeout = -1;
|
|
if (nevents == 0 && client_sessions.size() == 0)
|
|
break;
|
|
|
|
for (auto &pfd: pollfds) {
|
|
if (pfd.revents & (POLLERR | POLLHUP)) {
|
|
// some process died
|
|
DEBUG("detaching process");
|
|
auto &session = client_sessions.at(pfd.fd);
|
|
DEBUG("%d has died", session.pid);
|
|
to_remove.push_back(pfd.fd);
|
|
} else if (pfd.revents & POLLIN) {
|
|
if (pfd.fd == srv_socket->socket_fd) {
|
|
// someone is joining
|
|
DEBUG("registered new client");
|
|
auto client = srv_socket->accept();
|
|
int fd = client.socket_fd;
|
|
to_add.push_back(fd);
|
|
client_sessions.emplace(fd, std::move(client));
|
|
} else {
|
|
// someone wants to register a segment
|
|
DEBUG("got alloc info");
|
|
auto &session = client_sessions.at(pfd.fd);
|
|
AllocInfo info = session.socket.receive();
|
|
session.pid = info.pid;
|
|
DEBUG("got alloc info: %d %d %s", (int)info.free, info.pid, info.filename);
|
|
if (info.free) {
|
|
free_used_object(info.filename);
|
|
} else {
|
|
used_objects.insert(info.filename);
|
|
DEBUG("registered object %s", info.filename);
|
|
session.socket.confirm();
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
for (int fd: to_add)
|
|
register_fd(fd);
|
|
to_add.clear();
|
|
|
|
for (int fd: to_remove)
|
|
unregister_fd(fd);
|
|
to_remove.clear();
|
|
}
|
|
|
|
for (auto &obj_name: used_objects) {
|
|
DEBUG("freeing %s", obj_name.c_str());
|
|
shm_unlink(obj_name.c_str());
|
|
}
|
|
|
|
DEBUG("manager done");
|
|
return 0;
|
|
}
|