From a0ab243c3a5dfe12b392e4074d69360fd013f842 Mon Sep 17 00:00:00 2001
From: PyTorch MergeBot
Date: Thu, 10 Apr 2025 21:02:14 +0000
Subject: [PATCH] Revert "Generalize poison fork logic for each device backend (#144664)"

This reverts commit 83bd0b63b55f224fada6d5f6dd7eb5b4cb3072fb.

Reverted https://github.com/pytorch/pytorch/pull/144664 on behalf of https://github.com/atalman due to failing internal tests ([comment](https://github.com/pytorch/pytorch/pull/144664#issuecomment-2795157082))
---
 test/test_cpp_extensions_mtia_backend.py | 12 ++++---
 torch/csrc/cuda/Module.cpp               | 36 ++++++++++++++++-----
 torch/csrc/mps/Module.cpp                | 30 +++++++++++++++---
 torch/csrc/mtia/Module.cpp               | 31 +++++++++++++++---
 torch/csrc/utils/device_lazy_init.cpp    | 40 ------------------------
 torch/csrc/utils/device_lazy_init.h      | 17 ----------
 torch/csrc/xpu/Module.cpp                | 34 ++++++++++++++++----
 7 files changed, 115 insertions(+), 85 deletions(-)

diff --git a/test/test_cpp_extensions_mtia_backend.py b/test/test_cpp_extensions_mtia_backend.py
index 8082113f263..6203b799328 100644
--- a/test/test_cpp_extensions_mtia_backend.py
+++ b/test/test_cpp_extensions_mtia_backend.py
@@ -11,18 +11,20 @@ from torch.testing._internal.common_utils import (
     IS_ARM64,
     IS_LINUX,
     skipIfTorchDynamo,
+    TEST_CUDA,
     TEST_PRIVATEUSE1,
+    TEST_XPU,
 )
+from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
 
 
-# This TestCase should be mutually exclusive with other backends.
-HAS_CUDA = torch.backends.cuda.is_built()
-HAS_XPU = torch.xpu._is_compiled()
-HAS_MPS = torch.backends.mps.is_built()
+# define TEST_ROCM before changing TEST_CUDA
+TEST_ROCM = TEST_CUDA and torch.version.hip is not None and ROCM_HOME is not None
+TEST_CUDA = TEST_CUDA and CUDA_HOME is not None
 
 
 @unittest.skipIf(
-    IS_ARM64 or not IS_LINUX or HAS_CUDA or HAS_XPU or HAS_MPS or TEST_PRIVATEUSE1,
+    IS_ARM64 or not IS_LINUX or TEST_CUDA or TEST_PRIVATEUSE1 or TEST_ROCM or TEST_XPU,
     "Only on linux platform and mutual exclusive to other backends",
 )
 @torch.testing._internal.common_utils.markDynamoStrictTest
diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp
index f5365a674d2..1ff4079a56e 100644
--- a/torch/csrc/cuda/Module.cpp
+++ b/torch/csrc/cuda/Module.cpp
@@ -51,9 +51,32 @@
 #include
 #include
 #include
+#ifndef WIN32
+#include <pthread.h>
+#endif
 
 using namespace torch;
 
+static bool in_bad_fork = false; // True for children forked after cuda init
+
+#ifndef WIN32
+// Called in the forked child if cuda has already been initialized
+static void forked_child() {
+  in_bad_fork = true;
+  torch::utils::set_requires_device_init(at::kCUDA, true);
+}
+#endif
+
+// Should be called before the first cuda call.
+// Note: This is distinct from initExtension because a stub cuda implementation
+// has some working functions (e.g. device_count) but cannot fully initialize.
+static void poison_fork() {
+#ifndef WIN32
+  static auto result [[maybe_unused]] =
+      pthread_atfork(nullptr, nullptr, forked_child);
+#endif
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 // CUDA management methods
 ////////////////////////////////////////////////////////////////////////////////
@@ -137,17 +160,14 @@ PyObject* THCPModule_canDeviceAccessPeer_wrap(PyObject* self, PyObject* args) {
 
 PyObject* THCPModule_getDeviceCount_wrap(PyObject* self, PyObject* noargs) {
   HANDLE_TH_ERRORS
-  // Note: This is distinct from initExtension because a stub cuda
-  // implementation has some working functions (e.g. device_count) but cannot
-  // fully initialize.
-  torch::utils::register_fork_handler_for_device_init(at::kCUDA);
+  poison_fork();
   return THPUtils_packUInt64(at::cuda::device_count());
   END_HANDLE_TH_ERRORS
 }
 
 PyObject* THCPModule_getArchFlags(PyObject* self, PyObject* noargs) {
   HANDLE_TH_ERRORS
-  torch::utils::register_fork_handler_for_device_init(at::kCUDA);
+  poison_fork();
 #ifdef CUDA_ARCH_FLAGS
   static const char* flags = C10_STRINGIZE(CUDA_ARCH_FLAGS);
   return THPUtils_packString(flags);
@@ -159,7 +179,7 @@ PyObject* THCPModule_getArchFlags(PyObject* self, PyObject* noargs) {
 
 static PyObject* THCPModule_isInBadFork(PyObject* self, PyObject* noargs) {
   HANDLE_TH_ERRORS
-  return PyBool_FromLong(torch::utils::is_device_in_bad_fork(at::kCUDA));
+  return PyBool_FromLong(in_bad_fork);
   END_HANDLE_TH_ERRORS
 }
 
@@ -1493,8 +1513,8 @@ static PyObject* THCPModule_initExtension(PyObject* self, PyObject* noargs) {
       "please rebuild pytorch without asan if you need to use this module");
 #endif
   HANDLE_TH_ERRORS
-  TORCH_INTERNAL_ASSERT(!torch::utils::is_device_in_bad_fork(at::kCUDA));
-  torch::utils::register_fork_handler_for_device_init(at::kCUDA);
+  TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level
+  poison_fork();
   at::globalContext().lazyInitDevice(c10::DeviceType::CUDA);
 
   auto m = THPObjectPtr(PyImport_ImportModule("torch.cuda"));
diff --git a/torch/csrc/mps/Module.cpp b/torch/csrc/mps/Module.cpp
index 8c33c596b32..3cd75cedada 100644
--- a/torch/csrc/mps/Module.cpp
+++ b/torch/csrc/mps/Module.cpp
@@ -6,12 +6,16 @@
 #include
 #include
 #include
-#include
 #include
 #include
 #include
 #include
 
+// pthread.h is included for tracking bad forks
+#ifndef WIN32
+#include <pthread.h>
+#endif
+
 #ifdef USE_MPS
 #include
 #include
@@ -19,9 +23,27 @@
 
 namespace torch::mps {
 
+namespace {
+// True for children forked after mps init
+static bool in_bad_fork = false;
+
+// Called in the forked child if mps has already been initialized
+static void forked_mps_child() {
+  in_bad_fork = true;
+}
+
+// Should be called before the first mps call.
+static void track_bad_mps_fork() {
+#ifndef WIN32
+  static auto result [[maybe_unused]] =
+      pthread_atfork(nullptr, nullptr, forked_mps_child);
+#endif
+}
+} // namespace
+
 static PyObject* MPSModule_isInBadFork(PyObject* self, PyObject* noargs) {
   HANDLE_TH_ERRORS
-  return PyBool_FromLong(torch::utils::is_device_in_bad_fork(at::kMPS));
+  return PyBool_FromLong(in_bad_fork);
   END_HANDLE_TH_ERRORS
 }
 
@@ -29,7 +51,7 @@ static PyObject* MPSModule_getDefaultMPSGenerator(
     PyObject* _unused,
     PyObject* noargs) {
   HANDLE_TH_ERRORS
-  torch::utils::register_fork_handler_for_device_init(at::kMPS);
+  track_bad_mps_fork();
   return THPGenerator_initDefaultGenerator(
       at::detail::getMPSHooks().getDefaultGenerator());
   END_HANDLE_TH_ERRORS
@@ -37,8 +59,8 @@ static PyObject* MPSModule_getDefaultMPSGenerator(
 
 static PyObject* MPSModule_isAvailable(PyObject* _unused, PyObject* noargs) {
   HANDLE_TH_ERRORS
+  track_bad_mps_fork();
   if (at::detail::getMPSHooks().hasMPS()) {
-    torch::utils::register_fork_handler_for_device_init(at::kMPS);
     Py_RETURN_TRUE;
   } else {
     Py_RETURN_FALSE;
diff --git a/torch/csrc/mtia/Module.cpp b/torch/csrc/mtia/Module.cpp
index ec6229967e0..405b9d78002 100644
--- a/torch/csrc/mtia/Module.cpp
+++ b/torch/csrc/mtia/Module.cpp
@@ -7,15 +7,38 @@
 #include
 #include
 #include
+#ifndef WIN32
+#include <pthread.h>
+#endif
 
 namespace torch::mtia {
 
+static bool in_bad_fork = false; // True for children forked after mtia init
+
+#ifndef WIN32
+// Called in the forked child if mtia has already been initialized
+static void forked_child() {
+  in_bad_fork = true;
+  torch::utils::set_requires_device_init(at::kMTIA, true);
+}
+#endif
+
+// Should be called before the first mtia call.
+// Note: This is distinct from initExtension because a stub mtia implementation
+// has some working functions (e.g. device_count) but cannot fully initialize.
+static void poison_fork() {
+#ifndef WIN32
+  static auto result [[maybe_unused]] =
+      pthread_atfork(nullptr, nullptr, forked_child);
+#endif
+}
+
 void initModule(PyObject* module) {
   auto m = py::handle(module).cast<py::module>();
 
   m.def("_mtia_init", []() {
-    TORCH_INTERNAL_ASSERT(!torch::utils::is_device_in_bad_fork(at::kMTIA));
-    torch::utils::register_fork_handler_for_device_init(at::kMTIA);
+    TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level
+    poison_fork();
     at::globalContext().lazyInitDevice(c10::DeviceType::MTIA);
   });
 
@@ -24,9 +47,7 @@
     return at::detail::isMTIAHooksBuilt();
   });
 
-  m.def("_mtia_isInBadFork", []() {
-    return torch::utils::is_device_in_bad_fork(at::kMTIA);
-  });
+  m.def("_mtia_isInBadFork", []() { return in_bad_fork; });
 
   m.def("_mtia_getCurrentStream", [](c10::DeviceIndex device_index) {
     torch::utils::device_lazy_init(at::kMTIA);
diff --git a/torch/csrc/utils/device_lazy_init.cpp b/torch/csrc/utils/device_lazy_init.cpp
index c5a6512b363..74adb6b5e6b 100644
--- a/torch/csrc/utils/device_lazy_init.cpp
+++ b/torch/csrc/utils/device_lazy_init.cpp
@@ -1,23 +1,13 @@
 #include
-#include <c10/util/CallOnce.h>
 
 #include
 #include
 #include
 #include
-
-#ifndef WIN32
-#include <pthread.h>
-#endif
-
 namespace torch::utils {
 namespace {
 
 std::array<bool, at::COMPILE_TIME_MAX_DEVICE_TYPES> is_initialized{};
-std::array<bool, at::COMPILE_TIME_MAX_DEVICE_TYPES> is_in_bad_fork{};
-std::array<c10::once_flag, at::COMPILE_TIME_MAX_DEVICE_TYPES>
-    at_fork_once_flags{};
-std::optional<at::DeviceType> at_fork_device_type{};
 
 } // anonymous namespace
 
@@ -68,34 +58,4 @@ void set_requires_device_init(at::DeviceType device_type, bool value) {
   is_initialized[static_cast<size_t>(device_type)] = !value;
 }
 
-bool is_device_in_bad_fork(at::DeviceType device_type) {
-  return is_in_bad_fork[static_cast<size_t>(device_type)];
-}
-
-void set_device_in_bad_fork(at::DeviceType device_type, bool value) {
-  is_in_bad_fork[static_cast<size_t>(device_type)] = value;
-}
-
-// Should be called before the first device runtime call.
-void register_fork_handler_for_device_init(at::DeviceType device_type) {
-#ifndef WIN32
-  auto& flag = at_fork_once_flags[static_cast<size_t>(device_type)];
-  c10::call_once(flag, [device_type]() {
-    TORCH_CHECK(
-        !at_fork_device_type,
-        "Only one device type can be registered. But now, we have two device types: ",
-        at_fork_device_type.value(),
-        " and ",
-        device_type);
-    at_fork_device_type = device_type;
-    pthread_atfork(nullptr, nullptr, []() {
-      set_device_in_bad_fork(at_fork_device_type.value(), true);
-      if (is_device_lazy_init_supported(at_fork_device_type.value())) {
-        set_requires_device_init(at_fork_device_type.value(), true);
-      }
-    });
-  });
-#endif
-}
-
 } // namespace torch::utils
diff --git a/torch/csrc/utils/device_lazy_init.h b/torch/csrc/utils/device_lazy_init.h
index e65f16ace16..e1f480a60f7 100644
--- a/torch/csrc/utils/device_lazy_init.h
+++ b/torch/csrc/utils/device_lazy_init.h
@@ -67,21 +67,4 @@ inline void maybe_initialize_device(
 
 bool is_device_initialized(at::DeviceType device_type);
 
-TORCH_PYTHON_API bool is_device_in_bad_fork(at::DeviceType device_type);
-
-TORCH_PYTHON_API void set_device_in_bad_fork(
-    at::DeviceType device_type,
-    bool value);
-
-TORCH_PYTHON_API void register_fork_handler_for_device_init(
-    at::DeviceType device_type);
-
-inline void maybe_register_fork_handler_for_device_init(
-    std::optional<at::DeviceType>& device_type) {
-  if (!device_type.has_value()) {
-    return;
-  }
-  register_fork_handler_for_device_init(device_type.value());
-}
-
 } // namespace torch::utils
diff --git a/torch/csrc/xpu/Module.cpp b/torch/csrc/xpu/Module.cpp
index 8144dddd829..43ad06365ef 100644
--- a/torch/csrc/xpu/Module.cpp
+++ b/torch/csrc/xpu/Module.cpp
@@ -11,8 +11,32 @@
 #include
 #include
 
+#ifndef WIN32
+#include <pthread.h>
+#endif
+
 using namespace torch;
 
+static bool in_bad_fork = false; // True for children forked after xpu init
+
+#ifndef WIN32
+// Called in the forked child if xpu has already been initialized
+static void forked_child() {
+  in_bad_fork = true;
+  torch::utils::set_requires_device_init(at::kXPU, true);
+}
+#endif
+
+// Should be called before the first xpu call. It is mainly called in lazy_init.
+// Note: This is distinct from initExtension because a stub xpu implementation
+// has some working functions (e.g. device_count) but cannot fully initialize.
+static void poison_fork() {
+#ifndef WIN32
+  static auto result [[maybe_unused]] =
+      pthread_atfork(nullptr, nullptr, forked_child);
+#endif
+}
+
 // XPU management methods
 static PyObject* THXPModule_getArchFlags(PyObject* self, PyObject* noargs) {
   HANDLE_TH_ERRORS
@@ -28,7 +52,7 @@
 
 static PyObject* THXPModule_isInBadFork_wrap(PyObject* self, PyObject* noargs) {
   HANDLE_TH_ERRORS
-  return PyBool_FromLong(torch::utils::is_device_in_bad_fork(at::kXPU));
+  return PyBool_FromLong(in_bad_fork);
   END_HANDLE_TH_ERRORS
 }
 
@@ -91,9 +115,7 @@ static PyObject* THXPModule_getDeviceCount_wrap(
     PyObject* self,
     PyObject* noargs) {
   HANDLE_TH_ERRORS
-  // Note: This is distinct from initExtension because a stub xpu implementation
-  // has some working functions (e.g. device_count) but cannot fully initialize.
-  torch::utils::register_fork_handler_for_device_init(at::kXPU);
+  poison_fork();
   return THPUtils_packUInt64(at::xpu::device_count());
   END_HANDLE_TH_ERRORS
 }
@@ -398,8 +420,8 @@ static void initXpuMethodBindings(PyObject* module) {
 // classes
 static PyObject* THXPModule_initExtension(PyObject* self, PyObject* noargs) {
   HANDLE_TH_ERRORS
-  TORCH_INTERNAL_ASSERT(!torch::utils::is_device_in_bad_fork(at::kXPU));
-  torch::utils::register_fork_handler_for_device_init(at::kXPU);
+  TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level
+  poison_fork();
   at::globalContext().lazyInitDevice(c10::DeviceType::XPU);
 
   auto m = THPObjectPtr(PyImport_ImportModule("torch.xpu"));
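
Note on the restored pattern: each backend keeps its own in_bad_fork flag and registers a pthread_atfork child handler the first time its runtime is touched, so a process forked after initialization can be detected and refused re-initialization. The standalone sketch below is illustrative only; it assumes a POSIX platform, and the names g_in_bad_fork, poison_fork and lazy_init are placeholders rather than the exact PyTorch symbols.

    // Illustrative sketch of the per-backend poison-fork pattern (not the
    // exact PyTorch wiring; g_in_bad_fork / poison_fork / lazy_init are
    // placeholder names).
    #include <atomic>
    #include <cstdio>
    #include <stdexcept>
    #ifndef WIN32
    #include <pthread.h>
    #endif

    static std::atomic<bool> g_in_bad_fork{false};

    #ifndef WIN32
    // Runs in the child process after fork() once the handler is registered.
    static void forked_child() {
      g_in_bad_fork = true;
    }
    #endif

    // Registers the atfork handler exactly once (via the function-local
    // static); call before the first call into the device runtime.
    static void poison_fork() {
    #ifndef WIN32
      static auto result [[maybe_unused]] =
          pthread_atfork(nullptr, nullptr, forked_child);
    #endif
    }

    // Stand-in for a backend's lazy initialization entry point.
    static void lazy_init() {
      if (g_in_bad_fork) {
        throw std::runtime_error(
            "cannot re-initialize the device runtime in a forked child");
      }
      poison_fork();
      // ... real backend initialization would go here ...
    }

    int main() {
      lazy_init();
      std::puts(g_in_bad_fork ? "in bad fork" : "ok");
      return 0;
    }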