pytorch/torch/multiprocessing/spawn.py
Paul de Supinski 33346b5814 Support NUMA Binding for Callable Entrypoints, Take 2 (#161183)
# Context
In #160163, we added support for NUMA binding for `Callable` entrypoints to `elastic_launch`. This requires special consideration because `Callable` entrypoints go through a different subprocess-spawning path than `str` entrypoints, one which does not provide a straightforward way to use the `numactl` CLI. See #160006 for a full description of the challenges.

Although #160163 worked in initial local experiments, we ran into some linker errors in other environments when we tried to call `numactl`. This appeared to be due to interactions with how the `LD_PRELOAD` environment variable was being set.

# This PR
On further thought, the most straightforward, foolproof solution here is to use [the trick that @d4l3k suggested.](https://github.com/pytorch/pytorch/issues/160006#issuecomment-3162018836)

Specifically, for each local rank `i`:
1. The parent process sets its own CPU affinity to what local rank `i`'s should be.
2. Then, the parent spawns the subprocess for local rank `i`.
3. Finally, the parent resets its own CPU affinity to what it was originally (sketched below).
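
In code, the trick is roughly the following. This is a minimal sketch only: `spawn_bound_to` and `cpus_for_rank` are hypothetical names (not what this PR actually adds), and `os.sched_getaffinity`/`os.sched_setaffinity` are Linux-only APIs:

```python
import multiprocessing
import os


def spawn_bound_to(fn, rank, cpus_for_rank):
    """Spawn a subprocess for `rank` bound to `cpus_for_rank`, a set of logical CPU ids."""
    original_affinity = os.sched_getaffinity(0)  # 0 == the current (parent) process
    try:
        # 1. Temporarily bind the parent to the CPUs intended for this rank.
        os.sched_setaffinity(0, cpus_for_rank)
        # 2. The child inherits the parent's CPU affinity at creation time.
        process = multiprocessing.get_context("spawn").Process(target=fn, args=(rank,))
        process.start()
    finally:
        # 3. Restore the parent's original affinity so later ranks (and the parent) are unaffected.
        os.sched_setaffinity(0, original_affinity)
    return process
```

Because the affinity is inherited at process creation, the child is already bound before `torch`/CUDA initialization runs in it, which is better for memory locality than applying the binding inside the child after startup.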

There were other solutions that would work only for `Callable` entrypoints, but I believe this is the simplest approach that works for *both* `str` and `Callable` entrypoints.

This required a bit of refactoring:
1. Turn all the `_get_.*_numactl_options` helpers into functions which return a set of logical CPUs to bind to, rather than options like `--cpunodebind=0`.
2. Instead of wrapping commands with `numactl`, use `os.sched_setaffinity` to bind to the CPUs from (1.).
3. Put this all inside a context manager which encapsulates applying and restoring the bindings in the parent process.
4. Use the context manager for both the `str` and `Callable` paths (a minimal sketch follows below).
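
For illustration, the context manager from (3.) amounts to something like the sketch below. This is not the actual implementation: the real helper lives in `torch.numa.binding` as `maybe_temporarily_apply_numa_binding_to_current_process` and is called with `gpu_index` and `numa_options` rather than a raw CPU set, and `_temporarily_bind_parent` is a hypothetical name:

```python
import os
from contextlib import contextmanager


@contextmanager
def _temporarily_bind_parent(cpus):
    """Hypothetical stand-in: temporarily apply a CPU binding to the parent process.

    `cpus` is the set of logical CPU ids computed by the refactored helpers
    from (1.); None/empty means NUMA binding is disabled.
    """
    if not cpus:
        # Binding disabled: spawn with the parent's existing affinity.
        yield
        return
    original = os.sched_getaffinity(0)
    os.sched_setaffinity(0, cpus)
    try:
        # The subprocess (either the `str` or the `Callable` path) is started here.
        yield
    finally:
        os.sched_setaffinity(0, original)
```

Both spawn paths can then wrap the point where they actually start the subprocess in this single context manager, which is what keeps the `str` and `Callable` handling symmetric.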

# Test Plan
## Automated
`$ pytest test/test_numa_binding.py`

## Manual
See [doc](https://docs.google.com/document/d/1vxD-OKYBTT27jbBwtW9iz9g0tNM0u-i0tiTJg_ieQA8/edit?tab=t.0) (Meta only). TL;DR: tried every combination of `str` vs. `Callable` entrypoint with binding disabled and enabled on the same model, and saw 2x SM utilization with binding enabled.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161183
Approved by: https://github.com/d4l3k
2025-08-23 07:23:22 +00:00


# mypy: allow-untyped-defs
import logging
import multiprocessing
import multiprocessing.connection
import os
import pickle
import signal
import sys
import tempfile
import time
import warnings
from concurrent.futures import as_completed, ThreadPoolExecutor
from typing import Optional
from torch.numa.binding import (
maybe_temporarily_apply_numa_binding_to_current_process,
NumaOptions,
)
from . import _prctl_pr_set_pdeathsig # type: ignore[attr-defined]
ENV_VAR_PARALLEL_START = "TORCH_MP_PARALLEL_START"
log = logging.getLogger(__name__)
__all__ = [
"ProcessContext",
"ProcessException",
"ProcessExitedException",
"ProcessRaisedException",
"should_use_parallel_start",
"spawn",
"SpawnContext",
"start_processes",
]
class ProcessException(Exception):
__slots__ = ["error_index", "error_pid"]
def __init__(self, msg: str, error_index: int, pid: int):
super().__init__(msg)
self.msg = msg
self.error_index = error_index
self.pid = pid
def __reduce__(self):
return type(self), (self.msg, self.error_index, self.pid)
class ProcessRaisedException(ProcessException):
"""Exception raised when a process failed due to an exception raised by the code."""
def __init__(
self,
msg: str,
error_index: int,
error_pid: int,
):
super().__init__(msg, error_index, error_pid)
class ProcessExitedException(ProcessException):
"""Exception raised when a process failed due to signal or exited with a specific code."""
__slots__ = ["exit_code"]
def __init__(
self,
msg: str,
error_index: int,
error_pid: int,
exit_code: int,
signal_name: Optional[str] = None,
):
super().__init__(msg, error_index, error_pid)
self.exit_code = exit_code
self.signal_name = signal_name
def __reduce__(self):
return (
type(self),
(self.msg, self.error_index, self.pid, self.exit_code, self.signal_name),
)
def _wrap(fn, i, args, error_file):
# prctl(2) is a Linux specific system call.
# On other systems the following function call has no effect.
# This is set to ensure that non-daemonic child processes can
# terminate if their parent terminates before they do.
_prctl_pr_set_pdeathsig(signal.SIGINT)
try:
fn(i, *args)
except KeyboardInterrupt:
pass # SIGINT; Killed by parent, do nothing
except Exception:
# Propagate exception to parent process, keeping original traceback
import traceback
with open(error_file, "wb") as fh:
pickle.dump(traceback.format_exc(), fh)
sys.exit(1)
class ProcessContext:
def __init__(self, processes, error_files):
self.error_files = error_files
self.processes = processes
self.sentinels = {
process.sentinel: index for index, process in enumerate(processes)
}
def pids(self):
return [int(process.pid) for process in self.processes]
def _join_procs_with_timeout(self, timeout: float):
"""Attempt to join all processes with a shared timeout."""
end = time.monotonic() + timeout
for process in self.processes:
time_to_wait = max(0, end - time.monotonic())
process.join(time_to_wait)
def join(
self, timeout: Optional[float] = None, grace_period: Optional[float] = None
):
r"""Join one or more processes within spawn context.
Attempt to join one or more processes in this spawn context.
If one of them exited with a non-zero exit status, this function
kills the remaining processes (optionally with a grace period)
and raises an exception with the cause of the first process exiting.
Returns ``True`` if all processes have been joined successfully,
``False`` if there are more processes that need to be joined.
Args:
timeout (float): Wait this long (in seconds) before giving up on waiting.
grace_period (float): When any processes fail, wait this long (in seconds)
for others to shutdown gracefully before terminating them. If they
still don't exit, wait another grace period before killing them.
"""
# Ensure this function can be called even when we're done.
if len(self.sentinels) == 0:
return True
# Wait for any process to fail or all of them to succeed.
ready = multiprocessing.connection.wait(
self.sentinels.keys(),
timeout=timeout,
)
error_index = None
for sentinel in ready:
index = self.sentinels.pop(sentinel)
process = self.processes[index]
process.join()
if process.exitcode != 0:
error_index = index
break
# Return if there was no error.
if error_index is None:
# Return whether or not all processes have been joined.
return len(self.sentinels) == 0
# An error occurred. Clean-up all processes before returning.
# First, allow a grace period for processes to shutdown themselves.
if grace_period is not None:
self._join_procs_with_timeout(grace_period)
# Then, terminate processes that are still alive. Try SIGTERM first.
for process in self.processes:
if process.is_alive():
log.warning("Terminating process %s via signal SIGTERM", process.pid)
process.terminate()
# Try SIGKILL if the process isn't going down after another grace_period.
# The reason is that Python signal handling is limited to the main
# thread; if that thread is stuck in C/C++ land, it cannot handle the
# signal. We have seen processes get stuck and fail to handle SIGTERM
# for exactly this reason.
self._join_procs_with_timeout(30 if grace_period is None else grace_period)
for process in self.processes:
if process.is_alive():
log.warning(
"Unable to shutdown process %s via SIGTERM , forcefully exiting via SIGKILL",
process.pid,
)
process.kill()
process.join()
# The file will only be created if the process crashed.
failed_process = self.processes[error_index]
if not os.access(self.error_files[error_index], os.R_OK):
exitcode = self.processes[error_index].exitcode
if exitcode < 0:
try:
name = signal.Signals(-exitcode).name
except ValueError:
name = f"<Unknown signal {-exitcode}>"
raise ProcessExitedException(
f"process {error_index:d} terminated with signal {name}",
error_index=error_index,
error_pid=failed_process.pid,
exit_code=exitcode,
signal_name=name,
)
else:
raise ProcessExitedException(
f"process {error_index:d} terminated with exit code {exitcode:d}",
error_index=error_index,
error_pid=failed_process.pid,
exit_code=exitcode,
)
with open(self.error_files[error_index], "rb") as fh:
original_trace = pickle.load(fh)
msg = f"\n\n-- Process {error_index:d} terminated with the following error:\n"
msg += original_trace
raise ProcessRaisedException(msg, error_index, failed_process.pid)
class SpawnContext(ProcessContext):
def __init__(self, processes, error_files):
warnings.warn("SpawnContext is renamed to ProcessContext since 1.4 release.")
super().__init__(processes, error_files)
def should_use_parallel_start(start_method: str) -> bool:
"""
Returns:
Whether we will start subprocesses in parallel.
"""
return (
start_method == "forkserver"
and os.environ.get(ENV_VAR_PARALLEL_START, "0") == "1"
)
# Note: [start_processes]
# mp.start_processes handles both start_method='spawn' and 'fork'. It's supposed to be a
# more generalized API than mp.spawn. Currently we only document mp.spawn as it's the
# CUDA-compatible start_method. However, in environments like IPython notebooks, 'fork'
# works better than 'spawn'. Every helper function we created for mp.spawn is indeed
# general enough, and backends like XLA can reuse them in Colab notebooks as well.
# Currently we only add this API first, we can consider adding it to documentation as
# needed in the future.
def start_processes(
fn,
args=(),
nprocs=1,
join=True,
daemon=False,
start_method="spawn",
numa_options: Optional[NumaOptions] = None,
):
# To speed up performance in certain cases (see https://github.com/pytorch/pytorch/issues/133010),
# this func will start processes in parallel if start_method is 'forkserver'.
# Please opt in to this perf optimization by setting env var (TORCH_MP_PARALLEL_START) to 1.
# todo: investigate why spawn does not work with threadpool and raises SIGINT
if should_use_parallel_start(start_method):
log.info("Starting processes in parallel.")
start_parallel = True
else:
# Set env var TORCH_MP_PARALLEL_START to 0 to disable parallel start
start_parallel = False
if numa_options is not None and start_parallel:
raise ValueError("NUMA binding is not compatible with parallel start")
mp = multiprocessing.get_context(start_method)
error_files = [None] * nprocs
processes = [None] * nprocs
def start_process(i):
# Each process is assigned a file to write tracebacks to. We
# use the file being non-empty to indicate an exception
# occurred (vs an expected shutdown). Note: this previously
# used a multiprocessing.Queue but that can be prone to
# deadlocks, so we went with a simpler solution for a one-shot
# message between processes.
tf = tempfile.NamedTemporaryFile(
prefix="pytorch-errorfile-", suffix=".pickle", delete=False
)
tf.close()
os.unlink(tf.name)
process = mp.Process(
target=_wrap,
args=(fn, i, args, tf.name),
daemon=daemon,
)
# HACK [NUMA inheritance]: Subprocesses inherit the parent process's CPU
# affinity. So, we temporarily apply the bindings to the current process,
# and then immediately undo them.
# This is necessary because the alternatives would be to
# either
# 1. Use numactl CLI. However, Python's multiprocessing library
# does not provide an API which would allow us to prepend
# the command it runs with numactl options.
# 2. Wrap the provided function such that it first applies
# NUMA bindings, and then executes as expected. However, this
# can result in worse memory locality, because torch and CUDA
# initialization would occur before applying the bindings, thus
# allowing some memory to be allocated on the wrong NUMA nodes.
with maybe_temporarily_apply_numa_binding_to_current_process(
gpu_index=i, numa_options=numa_options
):
process.start()
return i, process, tf.name
if not start_parallel:
for i in range(nprocs):
idx, process, tf_name = start_process(i)
error_files[idx] = tf_name
processes[idx] = process
else:
with ThreadPoolExecutor(max_workers=nprocs) as executor:
futures = [executor.submit(start_process, i) for i in range(nprocs)]
for fut in as_completed(futures):
idx, process, tf_name = fut.result()
# idx and process rank need to be the same.
error_files[idx] = tf_name
processes[idx] = process
context = ProcessContext(processes, error_files)
if not join:
return context
# Loop on join until it returns True or raises an exception.
while not context.join():
pass
def spawn(fn, args=(), nprocs=1, join=True, daemon=False, start_method="spawn"):
r"""Spawns ``nprocs`` processes that run ``fn`` with ``args``.
If one of the processes exits with a non-zero exit status, the
remaining processes are killed and an exception is raised with the
cause of termination. In the case an exception was caught in the
child process, it is forwarded and its traceback is included in
the exception raised in the parent process.
Args:
fn (function): Function is called as the entrypoint of the
spawned process. This function must be defined at the top
level of a module so it can be pickled and spawned. This
is a requirement imposed by multiprocessing.
The function is called as ``fn(i, *args)``, where ``i`` is
the process index and ``args`` is the passed through tuple
of arguments.
args (tuple): Arguments passed to ``fn``.
nprocs (int): Number of processes to spawn.
join (bool): Perform a blocking join on all processes.
daemon (bool): The spawned processes' daemon flag. If set to True,
daemonic processes will be created.
start_method (str): (deprecated) this method will always use ``spawn``
as the start method. To use a different start method
use ``start_processes()``.
Returns:
None if ``join`` is ``True``,
:class:`~ProcessContext` if ``join`` is ``False``
"""
if start_method != "spawn":
msg = (
f"This method only supports start_method=spawn (got: {start_method}).\n"
"To use a different start_method use:\n\t\t"
" torch.multiprocessing.start_processes(...)"
)
warnings.warn(msg, FutureWarning, stacklevel=2)
return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")