pytorch/torch/distributed/launcher/api.py
Paul de Supinski 33346b5814 Support NUMA Binding for Callable Entrypoints, Take 2 (#161183)
# Context
In #160163, we added support for NUMA binding for `Callable` entrypoints to `elastic_launch`. This requires special consideration because `Callable` entrypoints go through a different subprocess-spawning path than `str` entrypoints, one that does not provide a straightforward way to utilize the `numactl` CLI. See #160006 for a full description of the challenges.

Although #160163 worked in initial local experiments, we ran into some linker errors in other environments when we tried to call `numactl`. This appeared to be due to interactions with how the `LD_PRELOAD` environment variable was being set.

# This PR
On further thought, the most straightforward, foolproof solution here is to use [the trick that @d4l3k suggested](https://github.com/pytorch/pytorch/issues/160006#issuecomment-3162018836).

Specifically, for each local rank `i` (see the sketch after this list):
1. The parent process sets its own CPU affinity to what local rank `i`'s should be.
2. Then, the parent spawns the subprocess for local rank `i`.
3. Finally, the parent resets its own CPU affinity to what it was originally.
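
As a rough illustration only (not the actual PyTorch implementation; `worker_fn` and the rank-to-CPU mapping are hypothetical, and `os.sched_getaffinity`/`os.sched_setaffinity` are Linux-only APIs), the spawn loop looks roughly like this:

```python
import multiprocessing as mp
import os


def spawn_with_numa_binding(worker_fn, cpus_for_local_rank):
    """Hypothetical sketch of the parent-affinity trick described above."""
    original_affinity = os.sched_getaffinity(0)  # 0 = the current (parent) process
    procs = []
    try:
        for local_rank, cpus in enumerate(cpus_for_local_rank):
            # 1. Pin the parent to the CPUs intended for this local rank.
            os.sched_setaffinity(0, cpus)
            # 2. Spawn the subprocess; it inherits the parent's affinity.
            p = mp.get_context("spawn").Process(target=worker_fn, args=(local_rank,))
            p.start()
            procs.append(p)
    finally:
        # 3. Restore the parent's original affinity.
        os.sched_setaffinity(0, original_affinity)
    return procs
```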

There were other solutions that would work just for `Callable` entrypoints, but I believe this is the simplest approach that works for *both* `str` and `Callable` entrypoints.

This required a bit of refactoring:
1. Turn all the `_get_.*_numactl_options` functions into functions that return a set of logical CPUs to bind to, rather than options like `--cpunodebind=0`.
2. Instead of wrapping commands with `numactl`, use `os.sched_setaffinity` to bind to the CPUs from (1).
3. Put all of this inside a context manager that encapsulates applying and restoring the bindings in the parent process.
4. Use the context manager for both the `str` and `Callable` paths (see the sketch below).
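
A minimal sketch of such a context manager, assuming Linux's `os.sched_setaffinity` (the name `temporarily_bind_cpus` is hypothetical, not the actual helper in this PR):

```python
import os
from contextlib import contextmanager


@contextmanager
def temporarily_bind_cpus(cpus):
    """Hypothetical sketch: pin the calling (parent) process to `cpus`,
    then restore its original affinity on exit. Any subprocess started
    inside the `with` block inherits the binding, whether it is launched
    via `subprocess` (str entrypoints) or `multiprocessing` (Callables)."""
    original = os.sched_getaffinity(0)
    os.sched_setaffinity(0, cpus)
    try:
        yield
    finally:
        os.sched_setaffinity(0, original)
```

Because the binding is applied to the parent rather than injected into the child's command line, the same mechanism covers both spawn paths without touching `LD_PRELOAD` or `numactl`.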

# Test Plan
## Automated
`$ pytest test/test_numa_binding.py`

## Manual
See [doc](https://docs.google.com/document/d/1vxD-OKYBTT27jbBwtW9iz9g0tNM0u-i0tiTJg_ieQA8/edit?tab=t.0) (Meta only). TL;DR: we tried every combination of `str` and `Callable` entrypoints with binding disabled and enabled on the same model, and saw 2x SM utilization with binding enabled.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161183
Approved by: https://github.com/d4l3k
2025-08-23 07:23:22 +00:00

#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import sys
import uuid
from dataclasses import dataclass, field
from typing import Any, Callable, Optional, Union

import torch
import torch.distributed.elastic.rendezvous.registry as rdzv_registry
from torch._utils_internal import get_default_numa_options
from torch.distributed.elastic import events, metrics
from torch.distributed.elastic.agent.server.api import WorkerSpec
from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent
from torch.distributed.elastic.multiprocessing import (
    DefaultLogsSpecs,
    LogsSpecs,
    SignalException,
)
from torch.distributed.elastic.multiprocessing.errors import ChildFailedError
from torch.distributed.elastic.rendezvous import RendezvousParameters
from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint
from torch.distributed.elastic.utils.logging import get_logger
from torch.multiprocessing.spawn import should_use_parallel_start
from torch.numa.binding import NumaOptions


__all__ = ["LaunchConfig", "elastic_launch", "launch_agent"]

logger = get_logger(__name__)


@dataclass
class LaunchConfig:
    """
    Creates a rendezvous config.

    Args:
        min_nodes: Minimum number of nodes that the user function will
            be launched on. The elastic agent ensures that the user
            function starts only once the min_nodes number of nodes
            enters the rendezvous.
        max_nodes: Maximum number of nodes that the user function
            will be launched on.
        nproc_per_node: On each node, the elastic agent will launch
            this number of workers that will execute the user-defined
            function.
        rdzv_backend: rdzv_backend to use in the rendezvous (zeus-adapter, etcd).
        rdzv_endpoint: The endpoint of the rdzv sync. storage.
        rdzv_configs: Key-value pairs that specify rendezvous-specific configuration.
        rdzv_timeout: Legacy argument that specifies the timeout for the rendezvous.
            It is going to be removed in future versions; see the note below.
            The default timeout is 900 seconds.
        run_id: The unique run id of the job (if not passed, a unique one will be
            deduced from the run environment - flow workflow id in flow - or auto generated).
        role: User-defined role of the worker (defaults to "trainer").
        max_restarts: The maximum number of restarts that the elastic agent will conduct
            on workers before failure.
        monitor_interval: The interval in seconds that is used by the elastic_agent
            as a period of monitoring workers.
        start_method: The method used by the elastic agent to start the
            workers (spawn, fork, forkserver).
        metrics_cfg: Configuration to initialize metrics.
        local_addr: Address of the local node, if any. If not set, a lookup on the local
            machine's FQDN will be performed.
        local_ranks_filter: Ranks for which to show logs in console. If not set, show from all.
        event_log_handler: Name of the event logging handler as registered in
            `elastic/events/handlers.py <https://docs.pytorch.org/docs/stable/elastic/events.html>`_.

    .. note::
        `rdzv_timeout` is a legacy argument that will be removed in future.
        Set the timeout via `rdzv_configs['timeout']`.
    """

    min_nodes: int
    max_nodes: int
    nproc_per_node: int
    logs_specs: Optional[LogsSpecs] = None
    run_id: str = ""
    role: str = "default_role"
    rdzv_endpoint: str = ""
    rdzv_backend: str = "etcd"
    rdzv_configs: dict[str, Any] = field(default_factory=dict)
    rdzv_timeout: int = -1
    max_restarts: int = 3
    monitor_interval: float = 0.1
    start_method: str = "spawn"
    log_line_prefix_template: Optional[str] = None
    metrics_cfg: dict[str, str] = field(default_factory=dict)
    local_addr: Optional[str] = None
    event_log_handler: str = "null"
    numa_options: Optional[NumaOptions] = None

    def __post_init__(self):
        default_timeout = 900
        if self.rdzv_timeout != -1:
            self.rdzv_configs["timeout"] = self.rdzv_timeout
        elif "timeout" not in self.rdzv_configs:
            self.rdzv_configs["timeout"] = default_timeout

        # Post-processing to enable refactoring to introduce logs_specs due to non-torchrun API usage
        if self.logs_specs is None:
            self.logs_specs = DefaultLogsSpecs()

        if (
            self.numa_options is None
            # The way we apply NUMA bindings currently depends
            # on the processes being started sequentially.
            # Technically, this filter does not matter for str entrypoints,
            # but we ignore that nuance for now.
            and not should_use_parallel_start(self.start_method)
            and torch.cuda.is_available()
            # We assume local_rank n uses cuda device n.
            and torch.cuda.device_count() == self.nproc_per_node
        ):
            self.numa_options = get_default_numa_options()
            logger.info("Using default numa options = %r", self.numa_options)


class elastic_launch:
    """
    Launches a torchelastic agent on the container that invoked the entrypoint.

    1. Pass the ``entrypoint`` arguments as non-``kwargs`` (i.e. no named parameters).
       ``entrypoint`` can be a function or a command.
    2. The return value is a map of each worker's output mapped
       by their respective global rank.

    Usage

    ::

        def worker_fn(foo):
            # ...

        def main():
            # entrypoint is a function.
            outputs = elastic_launch(LaunchConfig, worker_fn)(foo)
            # return rank 0's output
            return outputs[0]

            # entrypoint is a command and ``script.py`` is the python module.
            outputs = elastic_launch(LaunchConfig, "script.py")(args)
            outputs = elastic_launch(LaunchConfig, "python")("script.py")
    """

    def __init__(
        self,
        config: LaunchConfig,
        entrypoint: Union[Callable, str, None],
    ):
        self._config = config
        self._entrypoint = entrypoint

    def __call__(self, *args):
        return launch_agent(self._config, self._entrypoint, list(args))


def _get_entrypoint_name(
    entrypoint: Union[Callable, str, None], args: list[Any]
) -> str:
    """Retrieve the entrypoint name with the rule:
    1. If entrypoint is a function, use ``entrypoint.__qualname__``.
    2. If entrypoint is a string, check its value:
        2.1 if entrypoint equals ``sys.executable`` (like "python"), use the first
            element from ``args`` which does not start with a hyphen
            (for example, "-u" will be skipped).
        2.2 otherwise, use the ``entrypoint`` value.
    3. Otherwise, return an empty string.
    """
    if isinstance(entrypoint, Callable):  # type: ignore[arg-type]
        return entrypoint.__name__  # type: ignore[union-attr]
    elif isinstance(entrypoint, str):
        if entrypoint == sys.executable:
            return next((arg for arg in args if arg[0] != "-"), "")
        else:
            return entrypoint
    else:
        return ""


def _get_addr_and_port(
    rdzv_parameters: RendezvousParameters,
) -> tuple[Optional[str], Optional[int]]:
    """Resolve the master address and port for the static rendezvous backend."""
    if rdzv_parameters.backend != "static":
        return (None, None)
    endpoint = rdzv_parameters.endpoint
    endpoint = endpoint.strip()
    if not endpoint:
        raise ValueError(
            "Endpoint is missing. Try to add --master-addr and --master-port"
        )
    master_addr, master_port = parse_rendezvous_endpoint(endpoint, default_port=-1)
    if master_port == -1:
        raise ValueError(
            f"port is missing in endpoint: {endpoint}. Try to specify --master-port"
        )
    return (master_addr, master_port)


def launch_agent(
    config: LaunchConfig,
    entrypoint: Union[Callable, str, None],
    args: list[Any],
) -> dict[int, Any]:
    if not config.run_id:
        run_id = str(uuid.uuid4().int)
        logger.warning("config has no run_id, generated a random run_id: %s", run_id)
        config.run_id = run_id

    entrypoint_name = _get_entrypoint_name(entrypoint, args)

    logger.info(
        "Starting elastic_operator with launch configs:\n"
        "  entrypoint        : %(entrypoint)s\n"
        "  min_nodes         : %(min_nodes)s\n"
        "  max_nodes         : %(max_nodes)s\n"
        "  nproc_per_node    : %(nproc_per_node)s\n"
        "  run_id            : %(run_id)s\n"
        "  rdzv_backend      : %(rdzv_backend)s\n"
        "  rdzv_endpoint     : %(rdzv_endpoint)s\n"
        "  rdzv_configs      : %(rdzv_configs)s\n"
        "  max_restarts      : %(max_restarts)s\n"
        "  monitor_interval  : %(monitor_interval)s\n"
        "  log_dir           : %(log_dir)s\n"
        "  metrics_cfg       : %(metrics_cfg)s\n"
        "  event_log_handler : %(event_log_handler)s\n"
        "  numa_options      : %(numa_options)s\n",
        {
            "entrypoint": entrypoint_name,
            "min_nodes": config.min_nodes,
            "max_nodes": config.max_nodes,
            "nproc_per_node": config.nproc_per_node,
            "run_id": config.run_id,
            "rdzv_backend": config.rdzv_backend,
            "rdzv_endpoint": config.rdzv_endpoint,
            "rdzv_configs": config.rdzv_configs,
            "max_restarts": config.max_restarts,
            "monitor_interval": config.monitor_interval,
            "log_dir": config.logs_specs.root_log_dir,  # type: ignore[union-attr]
            "metrics_cfg": config.metrics_cfg,
            "event_log_handler": config.event_log_handler,
            "numa_options": config.numa_options,
        },
    )

    rdzv_parameters = RendezvousParameters(
        backend=config.rdzv_backend,
        endpoint=config.rdzv_endpoint,
        run_id=config.run_id,
        min_nodes=config.min_nodes,
        max_nodes=config.max_nodes,
        local_addr=config.local_addr,
        **config.rdzv_configs,
    )

    master_addr, master_port = _get_addr_and_port(rdzv_parameters)

    spec = WorkerSpec(
        role=config.role,
        local_world_size=config.nproc_per_node,
        entrypoint=entrypoint,
        args=tuple(args),
        rdzv_handler=rdzv_registry.get_rendezvous_handler(rdzv_parameters),
        max_restarts=config.max_restarts,
        monitor_interval=config.monitor_interval,
        master_addr=master_addr,
        master_port=master_port,
        local_addr=config.local_addr,
        event_log_handler=config.event_log_handler,
        numa_options=config.numa_options,
    )

    agent = LocalElasticAgent(
        spec=spec,
        logs_specs=config.logs_specs,  # type: ignore[arg-type]
        start_method=config.start_method,
        log_line_prefix_template=config.log_line_prefix_template,
    )

    shutdown_rdzv = True
    try:
        metrics.initialize_metrics(metrics.MetricsConfig(config.metrics_cfg))

        result = agent.run()
        # records that agent.run() has succeeded NOT that workers have succeeded
        events.record(agent.get_event_succeeded(), config.event_log_handler)

        if result.is_failed():
            # ChildFailedError is treated specially by @record
            # if the error files for the failed children exist
            # @record will copy the first error (root cause)
            # to the error file of the launcher process.
            raise ChildFailedError(
                name=entrypoint_name,
                failures=result.failures,
            )

        return result.return_values
    except ChildFailedError:
        raise
    except SignalException:
        # when the agent dies with a signal do NOT shutdown the rdzv_handler
        # since this closes the rendezvous on this rdzv_id permanently and
        # prevents any additional scaling events
        shutdown_rdzv = False
        events.record(agent.get_event_failed(), config.event_log_handler)
        raise
    except Exception:
        events.record(agent.get_event_failed(), config.event_log_handler)
        raise
    finally:
        if shutdown_rdzv:
            spec.rdzv_handler.shutdown()