mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[inductor] parallel compile: add import of thread_safe_fork for internal (#137155)
Summary: We had a report of crashes in parallel compile subprocesses linked to reading justknobs. See https://fburl.com/workplace/14a4mcbh internally. This is a known issue with justknobs. It looks like we don't have a lot of control over evaluating knobs. Some are read in inductor (`"pytorch/remote_cache:autotune_memcache_version`), but many are read by the triton compiler. According to this advice https://fburl.com/workplace/imx9lsx3, we can import thread_safe_fork which installs some functionality to destroy some singletons before forking and re-enable them after. This approach works for the failing workload. Test Plan: See D63719673 where the reporting user was kind enough to provide us with a local repro. Without the relevant import, we can reproduce the crash. With the import, the training runs successfully to completion. Differential Revision: D63736829 Pull Request resolved: https://github.com/pytorch/pytorch/pull/137155 Approved by: https://github.com/xmfan, https://github.com/eellison
This commit is contained in:
parent
f96020c246
commit
8bb8c3997b
|
|
@ -15,6 +15,10 @@ from concurrent.futures import Future, ProcessPoolExecutor
|
||||||
from concurrent.futures.process import BrokenProcessPool
|
from concurrent.futures.process import BrokenProcessPool
|
||||||
from typing import Any, Callable, Dict
|
from typing import Any, Callable, Dict
|
||||||
|
|
||||||
|
# _thread_safe_fork is needed because the subprocesses in the pool can read
|
||||||
|
# justknobs, e.g., in the Triton compiler. For internal, the import installs
|
||||||
|
# functionality to destroy singletons before forking and re-enable them after.
|
||||||
|
import torch._thread_safe_fork # noqa: F401
|
||||||
from torch._inductor import config
|
from torch._inductor import config
|
||||||
from torch._inductor.compile_worker.watchdog import _async_compile_initializer
|
from torch._inductor.compile_worker.watchdog import _async_compile_initializer
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -563,6 +563,20 @@ _fuse_ddp_communication_passes: List[Union[Callable[..., None], str]] = [
|
||||||
_micro_pipeline_tp: bool = False
|
_micro_pipeline_tp: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
def parallel_compile_enabled_internally() -> bool:
|
||||||
|
"""
|
||||||
|
TODO: Remove when parallel compiled is fully enabled internally. For rollout, use a
|
||||||
|
knob to enable / disable. The justknob should not be performed at import, however.
|
||||||
|
So for fbcode, we assign compile_threads to 'None' below and initialize lazily in
|
||||||
|
async_compile.py.
|
||||||
|
"""
|
||||||
|
ENABLE_PARALLEL_COMPILE_VERSION = 1
|
||||||
|
|
||||||
|
jk_name = "pytorch/inductor:enable_parallel_compile_version"
|
||||||
|
version = torch._utils_internal.justknobs_getval_int(jk_name)
|
||||||
|
return ENABLE_PARALLEL_COMPILE_VERSION >= version
|
||||||
|
|
||||||
|
|
||||||
def decide_compile_threads() -> int:
|
def decide_compile_threads() -> int:
|
||||||
"""
|
"""
|
||||||
Here are the precedence to decide compile_threads
|
Here are the precedence to decide compile_threads
|
||||||
|
|
@ -575,13 +589,7 @@ def decide_compile_threads() -> int:
|
||||||
return int(os.environ["TORCHINDUCTOR_COMPILE_THREADS"])
|
return int(os.environ["TORCHINDUCTOR_COMPILE_THREADS"])
|
||||||
elif sys.platform == "win32":
|
elif sys.platform == "win32":
|
||||||
return 1
|
return 1
|
||||||
# TODO: For internal rollout, we use a killswitch to disable. The justknob should
|
elif is_fbcode() and not parallel_compile_enabled_internally():
|
||||||
# not be performed at import, however. So for fbcode, we assign compile_threads to
|
|
||||||
# None below and call this method lazily in async_compile.py. Remove this after
|
|
||||||
# rollout completes.
|
|
||||||
elif is_fbcode() and not torch._utils_internal.justknobs_check(
|
|
||||||
"pytorch/inductor:enable_parallel_compile"
|
|
||||||
):
|
|
||||||
return 1
|
return 1
|
||||||
else:
|
else:
|
||||||
cpu_count = (
|
cpu_count = (
|
||||||
|
|
|
||||||
0
torch/_thread_safe_fork.py
Normal file
0
torch/_thread_safe_fork.py
Normal file
Loading…
Reference in New Issue
Block a user