[inductor] parallel compile: add import of thread_safe_fork for internal (#137155)

Summary: We had a report of crashes in parallel compile subprocesses linked to reading justknobs. See https://fburl.com/workplace/14a4mcbh internally. This is a known issue with justknobs, and we don't have much control over when knobs are evaluated: some are read in inductor (e.g., `pytorch/remote_cache:autotune_memcache_version`), but many are read by the Triton compiler. Per the advice in https://fburl.com/workplace/imx9lsx3, we can import thread_safe_fork, which installs functionality to destroy certain singletons before forking and re-enable them after the fork. This approach works for the failing workload.
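
For context, here is a minimal sketch of the fork-safety pattern such an import can install, based on `os.register_at_fork`. The `_destroy_singletons` / `_recreate_singletons` helpers are hypothetical stand-ins; the actual internals of thread_safe_fork are not part of this PR:

# Sketch only: illustrates the general pattern for making fork() safe around
# stateful singletons (e.g. objects that own locks or background threads).
import os

def _destroy_singletons() -> None:
    # Hypothetical helper: tear down singletons so a forked child does not
    # inherit locks or threads in an inconsistent state.
    ...

def _recreate_singletons() -> None:
    # Hypothetical helper: rebuild the singletons once the fork has completed.
    ...

os.register_at_fork(
    before=_destroy_singletons,            # runs in the parent just before fork()
    after_in_parent=_recreate_singletons,  # restore state in the parent
    after_in_child=_recreate_singletons,   # restore state in the child worker
)

Importing the module at the top of the pool implementation (as in the diff below) ensures these hooks are registered before any worker subprocess is forked.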

Test Plan: See D63719673 where the reporting user was kind enough to provide us with a local repro. Without the relevant import, we can reproduce the crash. With the import, the training runs successfully to completion.

Differential Revision: D63736829

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137155
Approved by: https://github.com/xmfan, https://github.com/eellison
Sam Larsen 2024-10-03 17:37:21 +00:00 committed by PyTorch MergeBot
parent f96020c246
commit 8bb8c3997b
3 changed files with 19 additions and 7 deletions


@@ -15,6 +15,10 @@ from concurrent.futures import Future, ProcessPoolExecutor
from concurrent.futures.process import BrokenProcessPool
from typing import Any, Callable, Dict
# _thread_safe_fork is needed because the subprocesses in the pool can read
# justknobs, e.g., in the Triton compiler. For internal, the import installs
# functionality to destroy singletons before forking and re-enable them after.
import torch._thread_safe_fork # noqa: F401
from torch._inductor import config
from torch._inductor.compile_worker.watchdog import _async_compile_initializer


@@ -563,6 +563,20 @@ _fuse_ddp_communication_passes: List[Union[Callable[..., None], str]] = [
_micro_pipeline_tp: bool = False


def parallel_compile_enabled_internally() -> bool:
    """
    TODO: Remove when parallel compile is fully enabled internally. For rollout, use a
    knob to enable / disable. The justknob should not be performed at import, however.
    So for fbcode, we assign compile_threads to 'None' below and initialize lazily in
    async_compile.py.
    """
    ENABLE_PARALLEL_COMPILE_VERSION = 1
    jk_name = "pytorch/inductor:enable_parallel_compile_version"
    version = torch._utils_internal.justknobs_getval_int(jk_name)
    return ENABLE_PARALLEL_COMPILE_VERSION >= version


def decide_compile_threads() -> int:
    """
    Here is the precedence used to decide compile_threads:
@@ -575,13 +589,7 @@ def decide_compile_threads() -> int:
         return int(os.environ["TORCHINDUCTOR_COMPILE_THREADS"])
     elif sys.platform == "win32":
         return 1
-    # TODO: For internal rollout, we use a killswitch to disable. The justknob should
-    # not be performed at import, however. So for fbcode, we assign compile_threads to
-    # None below and call this method lazily in async_compile.py. Remove this after
-    # rollout completes.
-    elif is_fbcode() and not torch._utils_internal.justknobs_check(
-        "pytorch/inductor:enable_parallel_compile"
-    ):
+    elif is_fbcode() and not parallel_compile_enabled_internally():
         return 1
     else:
         cpu_count = (
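
For reference, a rough sketch of the lazy pattern the TODO above describes, deferring the knob-backed decision from import time to first use. The `get_compile_threads` helper and the module-level cache are illustrative names, not the actual async_compile.py code:

from typing import Optional

from torch._inductor.config import decide_compile_threads  # shown in the hunk above

_compile_threads: Optional[int] = None  # None means "not decided yet" (the fbcode case)

def get_compile_threads() -> int:
    # Hypothetical helper: calling decide_compile_threads() here, rather than
    # at import, keeps the justknobs evaluation off the module import path.
    global _compile_threads
    if _compile_threads is None:
        _compile_threads = decide_compile_threads()
    return _compile_threads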
