From d86c14156d875b782b82dda96842a1f77910f010 Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Sun, 6 Apr 2025 09:08:18 +0000 Subject: [PATCH] Generalize poison fork logic for each device backend (#144664) # Motivation Generalize the poison_fork code to make it reusable across different devices. Pull Request resolved: https://github.com/pytorch/pytorch/pull/144664 Approved by: https://github.com/EikanWang, https://github.com/albanD --- torch/csrc/cuda/Module.cpp | 36 ++++++------------------ torch/csrc/mps/Module.cpp | 30 +++----------------- torch/csrc/mtia/Module.cpp | 31 ++++----------------- torch/csrc/utils/device_lazy_init.cpp | 40 +++++++++++++++++++++++++++ torch/csrc/utils/device_lazy_init.h | 17 ++++++++++++ torch/csrc/xpu/Module.cpp | 34 ++++------------------- 6 files changed, 80 insertions(+), 108 deletions(-) diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp index 1ff4079a56e..f5365a674d2 100644 --- a/torch/csrc/cuda/Module.cpp +++ b/torch/csrc/cuda/Module.cpp @@ -51,32 +51,9 @@ #include #include #include -#ifndef WIN32 -#include -#endif using namespace torch; -static bool in_bad_fork = false; // True for children forked after cuda init - -#ifndef WIN32 -// Called in the forked child if cuda has already been initialized -static void forked_child() { - in_bad_fork = true; - torch::utils::set_requires_device_init(at::kCUDA, true); -} -#endif - -// Should be called before the first cuda call. -// Note: This is distinct from initExtension because a stub cuda implementation -// has some working functions (e.g. device_count) but cannot fully initialize. 
-static void poison_fork() { -#ifndef WIN32 - static auto result [[maybe_unused]] = - pthread_atfork(nullptr, nullptr, forked_child); -#endif -} - //////////////////////////////////////////////////////////////////////////////// // CUDA management methods //////////////////////////////////////////////////////////////////////////////// @@ -160,14 +137,17 @@ PyObject* THCPModule_canDeviceAccessPeer_wrap(PyObject* self, PyObject* args) { PyObject* THCPModule_getDeviceCount_wrap(PyObject* self, PyObject* noargs) { HANDLE_TH_ERRORS - poison_fork(); + // Note: This is distinct from initExtension because a stub cuda + // implementation has some working functions (e.g. device_count) but cannot + // fully initialize. + torch::utils::register_fork_handler_for_device_init(at::kCUDA); return THPUtils_packUInt64(at::cuda::device_count()); END_HANDLE_TH_ERRORS } PyObject* THCPModule_getArchFlags(PyObject* self, PyObject* noargs) { HANDLE_TH_ERRORS - poison_fork(); + torch::utils::register_fork_handler_for_device_init(at::kCUDA); #ifdef CUDA_ARCH_FLAGS static const char* flags = C10_STRINGIZE(CUDA_ARCH_FLAGS); return THPUtils_packString(flags); @@ -179,7 +159,7 @@ PyObject* THCPModule_getArchFlags(PyObject* self, PyObject* noargs) { static PyObject* THCPModule_isInBadFork(PyObject* self, PyObject* noargs) { HANDLE_TH_ERRORS - return PyBool_FromLong(in_bad_fork); + return PyBool_FromLong(torch::utils::is_device_in_bad_fork(at::kCUDA)); END_HANDLE_TH_ERRORS } @@ -1513,8 +1493,8 @@ static PyObject* THCPModule_initExtension(PyObject* self, PyObject* noargs) { "please rebuild pytorch without asan if you need to use this module"); #endif HANDLE_TH_ERRORS - TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level - poison_fork(); + TORCH_INTERNAL_ASSERT(!torch::utils::is_device_in_bad_fork(at::kCUDA)); + torch::utils::register_fork_handler_for_device_init(at::kCUDA); at::globalContext().lazyInitDevice(c10::DeviceType::CUDA); auto m = 
THPObjectPtr(PyImport_ImportModule("torch.cuda")); diff --git a/torch/csrc/mps/Module.cpp b/torch/csrc/mps/Module.cpp index 3694cd19417..0ec9b8418c6 100644 --- a/torch/csrc/mps/Module.cpp +++ b/torch/csrc/mps/Module.cpp @@ -6,16 +6,12 @@ #include #include #include +#include #include #include #include #include -// pthread.h is included for tracking bad forks -#ifndef WIN32 -#include -#endif - #ifdef USE_MPS #include #include @@ -23,27 +19,9 @@ namespace torch::mps { -namespace { -// True for children forked after mps init -static bool in_bad_fork = false; - -// Called in the forked child if mps has already been initialized -static void forked_mps_child() { - in_bad_fork = true; -} - -// Should be called before the first mps call. -static void track_bad_mps_fork() { -#ifndef WIN32 - static auto result [[maybe_unused]] = - pthread_atfork(nullptr, nullptr, forked_mps_child); -#endif -} -} // namespace - static PyObject* MPSModule_isInBadFork(PyObject* self, PyObject* noargs) { HANDLE_TH_ERRORS - return PyBool_FromLong(in_bad_fork); + return PyBool_FromLong(torch::utils::is_device_in_bad_fork(at::kMPS)); END_HANDLE_TH_ERRORS } @@ -51,7 +29,7 @@ static PyObject* MPSModule_getDefaultMPSGenerator( PyObject* _unused, PyObject* noargs) { HANDLE_TH_ERRORS - track_bad_mps_fork(); + torch::utils::register_fork_handler_for_device_init(at::kMPS); return THPGenerator_initDefaultGenerator( at::detail::getMPSHooks().getDefaultGenerator()); END_HANDLE_TH_ERRORS @@ -59,8 +37,8 @@ static PyObject* MPSModule_getDefaultMPSGenerator( static PyObject* MPSModule_isAvailable(PyObject* _unused, PyObject* noargs) { HANDLE_TH_ERRORS - track_bad_mps_fork(); if (at::detail::getMPSHooks().hasMPS()) { + torch::utils::register_fork_handler_for_device_init(at::kMPS); Py_RETURN_TRUE; } else { Py_RETURN_FALSE; diff --git a/torch/csrc/mtia/Module.cpp b/torch/csrc/mtia/Module.cpp index 405b9d78002..ec6229967e0 100644 --- a/torch/csrc/mtia/Module.cpp +++ b/torch/csrc/mtia/Module.cpp @@ -7,38 +7,15 @@ 
#include #include #include -#ifndef WIN32 -#include -#endif namespace torch::mtia { -static bool in_bad_fork = false; // True for children forked after mtia init - -#ifndef WIN32 -// Called in the forked child if mtia has already been initialized -static void forked_child() { - in_bad_fork = true; - torch::utils::set_requires_device_init(at::kMTIA, true); -} -#endif - -// Should be called before the first mtia call. -// Note: This is distinct from initExtension because a stub mtia implementation -// has some working functions (e.g. device_count) but cannot fully initialize. -static void poison_fork() { -#ifndef WIN32 - static auto result [[maybe_unused]] = - pthread_atfork(nullptr, nullptr, forked_child); -#endif -} - void initModule(PyObject* module) { auto m = py::handle(module).cast(); m.def("_mtia_init", []() { - TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level - poison_fork(); + TORCH_INTERNAL_ASSERT(!torch::utils::is_device_in_bad_fork(at::kMTIA)); + torch::utils::register_fork_handler_for_device_init(at::kMTIA); at::globalContext().lazyInitDevice(c10::DeviceType::MTIA); }); @@ -47,7 +24,9 @@ void initModule(PyObject* module) { return at::detail::isMTIAHooksBuilt(); }); - m.def("_mtia_isInBadFork", []() { return in_bad_fork; }); + m.def("_mtia_isInBadFork", []() { + return torch::utils::is_device_in_bad_fork(at::kMTIA); + }); m.def("_mtia_getCurrentStream", [](c10::DeviceIndex device_index) { torch::utils::device_lazy_init(at::kMTIA); diff --git a/torch/csrc/utils/device_lazy_init.cpp b/torch/csrc/utils/device_lazy_init.cpp index 74adb6b5e6b..c5a6512b363 100644 --- a/torch/csrc/utils/device_lazy_init.cpp +++ b/torch/csrc/utils/device_lazy_init.cpp @@ -1,13 +1,23 @@ #include +#include #include #include #include #include + +#ifndef WIN32 +#include +#endif + namespace torch::utils { namespace { std::array is_initialized{}; +std::array is_in_bad_fork{}; +std::array + at_fork_once_flags{}; +std::optional at_fork_device_type{}; } // anonymous 
namespace @@ -58,4 +68,34 @@ void set_requires_device_init(at::DeviceType device_type, bool value) { is_initialized[static_cast(device_type)] = !value; } +bool is_device_in_bad_fork(at::DeviceType device_type) { + return is_in_bad_fork[static_cast(device_type)]; +} + +void set_device_in_bad_fork(at::DeviceType device_type, bool value) { + is_in_bad_fork[static_cast(device_type)] = value; +} + +// Should be called before the first device runtime call. +void register_fork_handler_for_device_init(at::DeviceType device_type) { +#ifndef WIN32 + auto& flag = at_fork_once_flags[static_cast(device_type)]; + c10::call_once(flag, [device_type]() { + TORCH_CHECK( + !at_fork_device_type, + "Only one device type can be registered. But now, we have two device types: ", + at_fork_device_type.value(), + " and ", + device_type); + at_fork_device_type = device_type; + pthread_atfork(nullptr, nullptr, []() { + set_device_in_bad_fork(at_fork_device_type.value(), true); + if (is_device_lazy_init_supported(at_fork_device_type.value())) { + set_requires_device_init(at_fork_device_type.value(), true); + } + }); + }); +#endif +} + } // namespace torch::utils diff --git a/torch/csrc/utils/device_lazy_init.h b/torch/csrc/utils/device_lazy_init.h index e1f480a60f7..e65f16ace16 100644 --- a/torch/csrc/utils/device_lazy_init.h +++ b/torch/csrc/utils/device_lazy_init.h @@ -67,4 +67,21 @@ inline void maybe_initialize_device( bool is_device_initialized(at::DeviceType device_type); +TORCH_PYTHON_API bool is_device_in_bad_fork(at::DeviceType device_type); + +TORCH_PYTHON_API void set_device_in_bad_fork( + at::DeviceType device_type, + bool value); + +TORCH_PYTHON_API void register_fork_handler_for_device_init( + at::DeviceType device_type); + +inline void maybe_register_fork_handler_for_device_init( + std::optional& device_type) { + if (!device_type.has_value()) { + return; + } + register_fork_handler_for_device_init(device_type.value()); +} + } // namespace torch::utils diff --git 
a/torch/csrc/xpu/Module.cpp b/torch/csrc/xpu/Module.cpp index 43ad06365ef..8144dddd829 100644 --- a/torch/csrc/xpu/Module.cpp +++ b/torch/csrc/xpu/Module.cpp @@ -11,32 +11,8 @@ #include #include -#ifndef WIN32 -#include -#endif - using namespace torch; -static bool in_bad_fork = false; // True for children forked after xpu init - -#ifndef WIN32 -// Called in the forked child if xpu has already been initialized -static void forked_child() { - in_bad_fork = true; - torch::utils::set_requires_device_init(at::kXPU, true); -} -#endif - -// Should be called before the first xpu call. It is mainly called in lazy_init. -// Note: This is distinct from initExtension because a stub xpu implementation -// has some working functions (e.g. device_count) but cannot fully initialize. -static void poison_fork() { -#ifndef WIN32 - static auto result [[maybe_unused]] = - pthread_atfork(nullptr, nullptr, forked_child); -#endif -} - // XPU management methods static PyObject* THXPModule_getArchFlags(PyObject* self, PyObject* noargs) { @@ -52,7 +28,7 @@ static PyObject* THXPModule_getArchFlags(PyObject* self, PyObject* noargs) { static PyObject* THXPModule_isInBadFork_wrap(PyObject* self, PyObject* noargs) { HANDLE_TH_ERRORS - return PyBool_FromLong(in_bad_fork); + return PyBool_FromLong(torch::utils::is_device_in_bad_fork(at::kXPU)); END_HANDLE_TH_ERRORS } @@ -115,7 +91,9 @@ static PyObject* THXPModule_getDeviceCount_wrap( PyObject* self, PyObject* noargs) { HANDLE_TH_ERRORS - poison_fork(); + // Note: This is distinct from initExtension because a stub xpu implementation + // has some working functions (e.g. device_count) but cannot fully initialize. 
+ torch::utils::register_fork_handler_for_device_init(at::kXPU); return THPUtils_packUInt64(at::xpu::device_count()); END_HANDLE_TH_ERRORS } @@ -420,8 +398,8 @@ static void initXpuMethodBindings(PyObject* module) { // classes static PyObject* THXPModule_initExtension(PyObject* self, PyObject* noargs) { HANDLE_TH_ERRORS - TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level - poison_fork(); + TORCH_INTERNAL_ASSERT(!torch::utils::is_device_in_bad_fork(at::kXPU)); + torch::utils::register_fork_handler_for_device_init(at::kXPU); at::globalContext().lazyInitDevice(c10::DeviceType::XPU); auto m = THPObjectPtr(PyImport_ImportModule("torch.xpu"));