[inductor] parallel compile: add import of thread_safe_fork for internal (#137155)
Summary: We had a report of crashes in parallel compile subprocesses linked to reading justknobs. See https://fburl.com/workplace/14a4mcbh internally. This is a known issue with justknobs, and we don't have much control over when knobs are evaluated: some are read in inductor (`pytorch/remote_cache:autotune_memcache_version`), but many are read by the Triton compiler. Following the advice in https://fburl.com/workplace/imx9lsx3, we can import thread_safe_fork, which installs functionality to destroy some singletons before forking and re-enable them after. This approach works for the failing workload.

Test Plan: See D63719673, where the reporting user was kind enough to provide us with a local repro. Without the relevant import we can reproduce the crash; with the import, the training runs successfully to completion.

Differential Revision: D63736829

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137155
Approved by: https://github.com/xmfan, https://github.com/eellison
This commit is contained in:
parent f96020c246
commit 8bb8c3997b
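The fix works by import side effect: importing torch._thread_safe_fork registers hooks that run around fork(). The OSS stub of the module is empty (see the new file at the bottom of this page); per the summary, the internal version installs the hooks. A minimal sketch of the pattern, built on the standard-library os.register_at_fork hook, with a hypothetical _client singleton standing in for the internal justknobs machinery:

import os

_client = None  # hypothetical fork-unsafe singleton, standing in for a knob reader

def _get_client() -> object:
    # Rebuild lazily on first use after any fork.
    global _client
    if _client is None:
        _client = object()  # imagine an RPC handle or a thread-backed cache here
    return _client

def _destroy_client() -> None:
    # Tear the singleton down so no locks or descriptors are live across fork().
    global _client
    _client = None

# Registering the hooks is the import-time side effect; this is why the
# subproc_pool.py hunk below imports the module with `# noqa: F401`.
os.register_at_fork(               # POSIX-only, Python >= 3.7
    before=_destroy_client,        # runs in the forking process just before fork()
    after_in_parent=lambda: None,  # parent rebuilds lazily via _get_client()
    after_in_child=lambda: None,   # child starts with a clean singleton
)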
torch/_inductor/compile_worker/subproc_pool.py

@@ -15,6 +15,10 @@ from concurrent.futures import Future, ProcessPoolExecutor
 from concurrent.futures.process import BrokenProcessPool
 from typing import Any, Callable, Dict
 
+# _thread_safe_fork is needed because the subprocesses in the pool can read
+# justknobs, e.g., in the Triton compiler. For internal, the import installs
+# functionality to destroy singletons before forking and re-enable them after.
+import torch._thread_safe_fork  # noqa: F401
 from torch._inductor import config
 from torch._inductor.compile_worker.watchdog import _async_compile_initializer
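For background on the crash class this guards against: if another thread holds a lock when a process fork()s, the child inherits the lock in its locked state with no thread left to release it. A generic, self-contained illustration (not the reported workload):

import multiprocessing
import threading
import time

_singleton_lock = threading.Lock()  # stands in for state inside a knob-reading singleton

def hold_lock() -> None:
    # Background thread holds the lock while the main thread forks.
    with _singleton_lock:
        time.sleep(2)

def child() -> None:
    # Only the forking thread survives fork() in the child, so nobody will
    # ever release the inherited lock; without a timeout this blocks forever.
    print("child acquired:", _singleton_lock.acquire(timeout=1))  # prints False

if __name__ == "__main__":
    threading.Thread(target=hold_lock, daemon=True).start()
    time.sleep(0.1)  # ensure the lock is held at fork time
    proc = multiprocessing.get_context("fork").Process(target=child)  # POSIX only
    proc.start()
    proc.join()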
torch/_inductor/config.py

@@ -563,6 +563,20 @@ _fuse_ddp_communication_passes: List[Union[Callable[..., None], str]] = [
 _micro_pipeline_tp: bool = False
 
 
+def parallel_compile_enabled_internally() -> bool:
+    """
+    TODO: Remove when parallel compile is fully enabled internally. For rollout, use a
+    knob to enable / disable. The justknob should not be performed at import, however.
+    So for fbcode, we assign compile_threads to 'None' below and initialize lazily in
+    async_compile.py.
+    """
+    ENABLE_PARALLEL_COMPILE_VERSION = 1
+
+    jk_name = "pytorch/inductor:enable_parallel_compile_version"
+    version = torch._utils_internal.justknobs_getval_int(jk_name)
+    return ENABLE_PARALLEL_COMPILE_VERSION >= version
+
+
 def decide_compile_threads() -> int:
     """
     Here is the precedence for deciding compile_threads:
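The comparison ENABLE_PARALLEL_COMPILE_VERSION >= version makes the knob a remote killswitch: the knob holds the minimum required code version, so raising it above the version baked into a binary disables the feature there with no code push. A self-contained sketch of the idiom (the stub knob reader and the "example/team" knob name are illustrative, not the real torch._utils_internal API):

MY_FEATURE_VERSION = 1  # bumped in code whenever the feature is re-vetted

def _stub_justknobs_getval_int(name: str) -> int:
    # Stand-in for the remote knob read; imagine ops raising the stored value
    # to 2 to switch off every binary still shipping feature version 1.
    return {"example/team:enable_my_feature_version": 1}[name]

def my_feature_enabled() -> bool:
    required = _stub_justknobs_getval_int("example/team:enable_my_feature_version")
    return MY_FEATURE_VERSION >= required

print(my_feature_enabled())  # True while the knob is <= 1; False once it is raised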
@@ -575,13 +589,7 @@ def decide_compile_threads() -> int:
         return int(os.environ["TORCHINDUCTOR_COMPILE_THREADS"])
     elif sys.platform == "win32":
         return 1
-    # TODO: For internal rollout, we use a killswitch to disable. The justknob should
-    # not be performed at import, however. So for fbcode, we assign compile_threads to
-    # None below and call this method lazily in async_compile.py. Remove this after
-    # rollout completes.
-    elif is_fbcode() and not torch._utils_internal.justknobs_check(
-        "pytorch/inductor:enable_parallel_compile"
-    ):
+    elif is_fbcode() and not parallel_compile_enabled_internally():
         return 1
     else:
         cpu_count = (
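As the precedence above shows, TORCHINDUCTOR_COMPILE_THREADS wins over both the Windows special case and the fbcode knob, so a run can force the worker count directly. A usage sketch, assuming a non-fbcode build where compile_threads is decided at config-import time (as the removed TODO describes):

import os

# Equivalent shell usage: TORCHINDUCTOR_COMPILE_THREADS=1 python train.py
# (train.py is a placeholder script name). The variable must be set before
# torch._inductor.config is imported.
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"  # serialize compiles

from torch._inductor import config  # noqa: E402

print(config.compile_threads)  # expected: 1 on a non-fbcode build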
torch/_thread_safe_fork.py (new file, 0 lines)