Support NUMA Binding for Callable Entrypoints, Take 2 (#161183)

# Context
In #160163, we added support for NUMA binding for `Callable` entrypoints to `elastic_launch`. This requires special consideration because `Callable` entrypoints go through a different path to spawn subprocesses than `str` entrypoints do, one which does not provide a straightforward way to use the `numactl` CLI. See #160006 for a full description of the challenges.

Although #160163 worked in initial local experiments, we ran into some linker errors in other environments when we tried to call `numactl`. This appeared to be due to interactions with how the `LD_PRELOAD` environment variable was being set.

# This PR
On further thought, the most straightforward, foolproof solution here is to use [the trick that @d4l3k suggested](https://github.com/pytorch/pytorch/issues/160006#issuecomment-3162018836).

Specifically, for each local rank `i`:
1. The parent process sets its own CPU affinity to what local rank `i`'s should be.
2. Then, the parent spawns the subprocess for local rank `i`.
3. Finally, the parent resets its own CPU affinity to what it was originally.

There were other solutions that would work only for `Callable` entrypoints, but I believe this is the simplest approach that works for *both* `str` and `Callable` entrypoints.
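
A minimal sketch of the trick, assuming Linux's `os.sched_setaffinity`/`os.sched_getaffinity` APIs (the helper name below is invented for illustration; the real implementation added in this PR is `maybe_temporarily_apply_numa_binding_to_current_process` in `torch/numa/binding.py`):

```python
import os
from contextlib import contextmanager

@contextmanager
def temporarily_bind_parent_to(cpu_indices):
    # Hypothetical helper name, for illustration only.
    original = os.sched_getaffinity(0)    # 0 means "the current process" (Linux-only API)
    os.sched_setaffinity(0, cpu_indices)  # 1. bind the parent to local rank i's CPUs
    try:
        yield                             # 2. spawn the subprocess here; it inherits the affinity
    finally:
        os.sched_setaffinity(0, original)  # 3. restore the parent's original affinity
```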

This required a bit of refactoring:
1. Turn each of the `_get_.*_numactl_options` functions into a function that returns the set of logical CPUs to bind to, rather than `numactl` options like `--cpunodebind=0`.
2. Instead of wrapping commands with `numactl`, use `os.sched_setaffinity` to bind to the CPUs from (1).
3. Put all of this inside a context manager which encapsulates applying and restoring the bindings in the parent process.
4. Use the context manager for both the `str` and `Callable` paths (see the usage sketch after this list).
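
As a usage sketch (simplified from the diff below; `nproc_per_node` and `launch_subprocess_for` are hypothetical stand-ins), both the `str` path (`SubprocessHandler`) and the `Callable` path (`start_processes`) now wrap the actual launch like this:

```python
from torch.numa.binding import (
    AffinityMode,
    NumaOptions,
    maybe_temporarily_apply_numa_binding_to_current_process,
)

numa_options = NumaOptions(affinity_mode=AffinityMode.NODE)
nproc_per_node = 2  # hypothetical; normally comes from the launcher config


def launch_subprocess_for(local_rank: int) -> None:
    # Hypothetical stand-in for Popen(...) in the str path or process.start()
    # in the Callable path; the child inherits the parent's temporary affinity.
    ...


for local_rank in range(nproc_per_node):
    with maybe_temporarily_apply_numa_binding_to_current_process(
        gpu_index=local_rank, numa_options=numa_options
    ):
        launch_subprocess_for(local_rank)
```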

# Test Plan
## Automated
`$ pytest test/test_numa_binding.py`

## Manual
See [doc](https://docs.google.com/document/d/1vxD-OKYBTT27jbBwtW9iz9g0tNM0u-i0tiTJg_ieQA8/edit?tab=t.0) (Meta only). TL;DR: we tried every combination of `str` vs. `Callable` entrypoint with binding enabled vs. disabled on the same model, and saw 2x SM utilization with binding enabled.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161183
Approved by: https://github.com/d4l3k
Paul de Supinski 2025-08-23 07:23:20 +00:00 committed by PyTorch MergeBot
parent 431846a632
commit 33346b5814
6 changed files with 360 additions and 496 deletions

View File: docs/source/conf.py

@ -1221,6 +1221,9 @@ coverage_ignore_functions = [
"reduce_typed_storage_child",
"storage_from_cache",
# torch.multiprocessing.spawn
# Added docstring for this but I think we need to go through
# and add the entire torch.multiprocessing.spawn module to a .rst...
"should_use_parallel_start",
"start_processes",
# torch.nn.functional
"adaptive_max_pool1d_with_indices", # documented as adaptive_max_pool1d

View File: test/test_numa_binding.py

@ -3,12 +3,10 @@
from __future__ import annotations
import json
import multiprocessing.spawn as spawn
import os
import subprocess
import sys
import tempfile
from dataclasses import dataclass
from multiprocessing.context import SpawnProcess
from typing import Any, Optional
from unittest import skipUnless
from unittest.mock import mock_open, patch
@ -16,6 +14,9 @@ from unittest.mock import mock_open, patch
import torch
from torch._utils_internal import signpost_event
from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, start_processes
from torch.distributed.elastic.multiprocessing.subprocess_handler import (
SubprocessHandler,
)
from torch.numa.binding import (
_get_ranges_str_from_ints,
_get_set_of_int_from_ranges_str,
@ -40,7 +41,6 @@ class MockDeviceProperties:
_real_open = open
_real_mkstemp = tempfile.mkstemp
@skipUnless(sys.platform == "linux", "Only linux currently supported")
@ -56,7 +56,6 @@ class NumaBindingTest(TestCase):
self._mock_num_logical_cpus = 0
self._mock_num_numa_nodes = 0
self._mock_num_sockets = 0
self._temp_file_paths = []
self._context_managers_to_apply_to_all_tests = [
patch("torch.cuda.device_count", self._mock_device_count),
@ -67,9 +66,6 @@ class NumaBindingTest(TestCase):
patch("builtins.open", new=self._mock_open),
patch("os.listdir", new=self._mock_listdir),
patch("os.sched_getaffinity", new=self._mock_sched_getaffinity),
patch("shutil.which", return_value="/usr/bin/numactl"),
patch("torch.numa.binding.run"),
patch("torch.numa.binding.mkstemp", self._mock_mkstemp),
patch("torch.numa.binding.signpost_event", self._mock_signpost_event),
]
@ -77,14 +73,6 @@ class NumaBindingTest(TestCase):
context_manager.__enter__()
def tearDown(self) -> None:
# Clean up temporary files
for temp_file_path in self._temp_file_paths:
try:
os.unlink(temp_file_path)
except FileNotFoundError:
# File may have already been deleted or doesn't exist
pass
for context_manager in self._context_managers_to_apply_to_all_tests:
context_manager.__exit__(None, None, None)
super().tearDown()
@ -94,12 +82,6 @@ class NumaBindingTest(TestCase):
json.dumps(kwargs["parameters"])
return signpost_event(*args, **kwargs)
def _mock_mkstemp(self, *args, **kwargs):
# Just keep track of temp files so we can delete them
fd, path = _real_mkstemp(*args, **kwargs)
self._temp_file_paths.append(path)
return fd, path
def _add_mock_hardware(
self,
*,
@ -249,18 +231,41 @@ class NumaBindingTest(TestCase):
def _mock_sched_getaffinity(self, pid: int) -> set[int]:
return set(range(self._mock_num_logical_cpus))
def _start_processes_for_str_entrypoint_and_get_Popen_args(
def _start_processes_for_str_entrypoint_and_get_sched_setaffinity_cpus(
self, *, numa_options: Optional[NumaOptions], target_local_rank: int
) -> tuple[str, ...]:
"""
Calls start_processes like elastic_launch ultimately would
and returns the commandline args tuple input to Popen.
Does not actually create the processes.
"""
) -> Optional[set[int]]:
active_local_rank = None
target_sched_setaffinity_logical_cpu_indices = None
with patch(
"torch.distributed.elastic.multiprocessing.subprocess_handler.subprocess_handler.Popen"
) as mock_popen:
real_subprocess_handler_init = SubprocessHandler.__init__
def mock_SubprocessHandler__init__(*args, **kwargs) -> None:
nonlocal active_local_rank
active_local_rank = kwargs["local_rank_id"]
return real_subprocess_handler_init(*args, **kwargs)
def mock_sched_setaffinity(*args, **kwargs) -> None:
nonlocal target_sched_setaffinity_logical_cpu_indices
if (
active_local_rank == target_local_rank
# We only care about the first call, not the second
# one where it gets reset
and target_sched_setaffinity_logical_cpu_indices is None
):
target_sched_setaffinity_logical_cpu_indices = args[1]
with (
patch(
"os.sched_setaffinity", mock_sched_setaffinity
) as mock_sched_setaffinity,
patch(
"torch.distributed.elastic.multiprocessing.subprocess_handler.subprocess_handler.Popen"
),
patch(
"torch.distributed.elastic.multiprocessing.subprocess_handler.SubprocessHandler.__init__",
mock_SubprocessHandler__init__,
),
):
start_processes(
name="test_process",
entrypoint="echo",
@ -273,40 +278,40 @@ class NumaBindingTest(TestCase):
logs_specs=DefaultLogsSpecs(),
numa_options=numa_options,
)
# This will raise an exception if there is no call from the desired local_rank
call_args = next(
call_args
for call_args in mock_popen.call_args_list
if call_args.kwargs.get("env", {}).get("LOCAL_RANK")
== str(target_local_rank)
)
return call_args.kwargs["args"]
def _start_processes_for_callable_entrypoint_and_get_executable_contents(
return target_sched_setaffinity_logical_cpu_indices
def _start_processes_for_callable_entrypoint_and_get_sched_setaffinity_cpus(
self, *, numa_options: Optional[NumaOptions], target_local_rank: int
) -> str:
) -> Optional[set[int]]:
active_local_rank = None
executable_path = None
target_sched_setaffinity_logical_cpu_indices = None
def _mock_process_start(self: Any) -> None:
real_process__init__ = SpawnProcess.__init__
def _mock_process__init__(*args, **kwargs) -> None:
nonlocal active_local_rank
active_local_rank = self._args[1]
spawn.get_command_line()
self._target(*self._args)
active_local_rank = kwargs["args"][1]
return real_process__init__(*args, **kwargs)
original_get_command_line = spawn.get_command_line
def _mock_get_command_line(*args, **kwargs) -> list[str]:
nonlocal executable_path
result = original_get_command_line(*args, **kwargs)
if active_local_rank == target_local_rank:
executable_path = result[0]
return result
def mock_sched_setaffinity(*args, **kwargs) -> None:
nonlocal target_sched_setaffinity_logical_cpu_indices
if (
active_local_rank == target_local_rank
# We only care about the first call, not the second
# one where it gets reset
and target_sched_setaffinity_logical_cpu_indices is None
):
target_sched_setaffinity_logical_cpu_indices = args[1]
with (
patch("multiprocessing.context.SpawnProcess.start", _mock_process_start),
patch("multiprocessing.spawn.get_command_line", _mock_get_command_line),
patch(
"os.sched_setaffinity", mock_sched_setaffinity
) as mock_sched_setaffinity,
patch("multiprocessing.context.SpawnProcess.start"),
patch(
"multiprocessing.context.SpawnProcess.__init__", _mock_process__init__
),
patch("multiprocessing.process.BaseProcess.sentinel", 1),
# Prevent hanging
patch(
@ -325,9 +330,7 @@ class NumaBindingTest(TestCase):
numa_options=numa_options,
)
assert executable_path is not None
with open(executable_path) as executable_file:
return executable_file.read()
return target_sched_setaffinity_logical_cpu_indices
def test_node_numa_binding(self) -> None:
self._add_mock_hardware(
@ -338,20 +341,19 @@ class NumaBindingTest(TestCase):
num_physical_core_per_l3_cache=2,
)
command_args = self._start_processes_for_str_entrypoint_and_get_Popen_args(
numa_options=NumaOptions(affinity_mode=AffinityMode.NODE),
target_local_rank=11,
bound_logical_cpu_indices = (
self._start_processes_for_str_entrypoint_and_get_sched_setaffinity_cpus(
numa_options=NumaOptions(affinity_mode=AffinityMode.NODE),
target_local_rank=11,
)
)
self.assertEqual(
command_args,
bound_logical_cpu_indices,
# There are 8 numa nodes and 2 GPUs per numa node, so GPU 11 would be
# on numa node 11 // 2 = 5.
(
"numactl",
"--cpunodebind=5",
"echo",
"Hello, world!",
),
# Each numa node has 4 * 2 * 2 = 16 logical CPUs
# Numa node 5 has CPUs 80-95
set(range(80, 96)),
)
def test_no_numa_binding_if_numa_options_not_provided(self) -> None:
@ -363,15 +365,14 @@ class NumaBindingTest(TestCase):
num_physical_core_per_l3_cache=2,
)
command_args = self._start_processes_for_str_entrypoint_and_get_Popen_args(
numa_options=None, target_local_rank=11
bound_logical_cpu_indices = (
self._start_processes_for_str_entrypoint_and_get_sched_setaffinity_cpus(
numa_options=None, target_local_rank=11
)
)
self.assertEqual(
command_args,
(
"echo",
"Hello, world!",
),
bound_logical_cpu_indices,
None,
)
def test_default_numa_binding(self) -> None:
@ -407,7 +408,7 @@ class NumaBindingTest(TestCase):
def test_fallback(self) -> None:
self._add_mock_hardware(
num_sockets=1,
num_sockets=2,
num_numa_nodes_per_socket=1,
num_gpus_per_numa_node=1,
num_l3_caches_per_numa_node=1,
@ -417,28 +418,27 @@ class NumaBindingTest(TestCase):
with (
patch("torch.numa.binding.signpost_event") as signpost_patch,
patch(
"torch.numa.binding.run",
side_effect=subprocess.CalledProcessError(1, "numactl"),
"torch.numa.binding._get_numa_node_index_for_gpu_index",
side_effect=Exception("Mock exception!"),
),
):
command_args = self._start_processes_for_str_entrypoint_and_get_Popen_args(
numa_options=NumaOptions(
affinity_mode=AffinityMode.NODE,
should_fall_back_if_binding_fails=True,
),
target_local_rank=0,
bound_logical_cpu_indices = (
self._start_processes_for_str_entrypoint_and_get_sched_setaffinity_cpus(
numa_options=NumaOptions(
affinity_mode=AffinityMode.NODE,
should_fall_back_if_binding_fails=True,
),
target_local_rank=0,
)
)
self.assertIn(
"subprocess.CalledProcessError",
"Mock exception!",
signpost_patch.call_args.kwargs["parameters"]["traceback"],
)
self.assertEqual(
command_args,
# No numa bindings due to exception
(
"echo",
"Hello, world!",
),
bound_logical_cpu_indices,
# We should just reset to the original CPU affinity, which is all the CPUs
set(range(4)),
)
def test_explicit_numa_options_overrides_default(self) -> None:
@ -460,7 +460,7 @@ class NumaBindingTest(TestCase):
NumaOptions(affinity_mode=AffinityMode.EXCLUSIVE),
)
def test_fork_start_method_does_not_call_get_default_numa_options(self) -> None:
def test_parallel_start_does_not_call_get_default_numa_options(self) -> None:
# Inner import to avoid crashing if not torch.distributed.is_available()
from torch.distributed.launcher.api import LaunchConfig
@ -475,16 +475,14 @@ class NumaBindingTest(TestCase):
with patch(
"torch.distributed.launcher.api.get_default_numa_options"
) as mock_get_default_numa_options:
os.environ["TORCH_MP_PARALLEL_START"] = "1"
launch_config = LaunchConfig(
min_nodes=1,
max_nodes=1,
nproc_per_node=2,
start_method="fork",
# Don't provide numa_options
start_method="forkserver",
)
# Verify get_default_numa_options was not called
mock_get_default_numa_options.assert_not_called()
# Verify numa_options is None when start_method is fork
self.assertIsNone(launch_config.numa_options)
def test_nproc_must_equal_cuda_device_count_to_use_default_numa_options(
@ -509,9 +507,7 @@ class NumaBindingTest(TestCase):
max_nodes=1,
nproc_per_node=2,
)
# Verify get_default_numa_options was not called
mock_get_default_numa_options.assert_not_called()
# Verify numa_options is None when start_method is fork
self.assertIsNone(launch_config.numa_options)
def test_socket_numa_binding_with_multiple_numa_per_socket(self) -> None:
@ -523,18 +519,18 @@ class NumaBindingTest(TestCase):
num_physical_core_per_l3_cache=2,
)
command_args = self._start_processes_for_str_entrypoint_and_get_Popen_args(
numa_options=NumaOptions(affinity_mode=AffinityMode.SOCKET),
target_local_rank=15,
bound_logical_cpu_indices = (
self._start_processes_for_str_entrypoint_and_get_sched_setaffinity_cpus(
numa_options=NumaOptions(affinity_mode=AffinityMode.SOCKET),
target_local_rank=15,
)
)
self.assertEqual(
command_args,
(
"numactl",
"--cpunodebind=6-7",
"echo",
"Hello, world!",
),
bound_logical_cpu_indices,
# GPU 15 is on numa node 15 // 2 = 7, which is on socket 3 (numa nodes 6 and 7)
# Each numa node has 4 * 2 * 2 = 16 logical CPUs
# Numa nodes 6 and 7 have CPUs 96-111 and 112-127
set(range(96, 128)),
)
def test_socket_numa_binding_with_single_numa_per_socket(self) -> None:
@ -546,18 +542,18 @@ class NumaBindingTest(TestCase):
num_physical_core_per_l3_cache=2,
)
command_args = self._start_processes_for_str_entrypoint_and_get_Popen_args(
numa_options=NumaOptions(affinity_mode=AffinityMode.SOCKET),
target_local_rank=7,
bound_logical_cpu_indices = (
self._start_processes_for_str_entrypoint_and_get_sched_setaffinity_cpus(
numa_options=NumaOptions(affinity_mode=AffinityMode.SOCKET),
target_local_rank=7,
)
)
self.assertEqual(
command_args,
(
"numactl",
"--cpunodebind=3",
"echo",
"Hello, world!",
),
bound_logical_cpu_indices,
# GPU 7 is on numa node 7 // 2 = 3, which is socket 3 by itself
# Each numa node has 4 * 2 * 2 = 16 logical CPUs
# Numa node 3 has CPUs 48-63
set(range(48, 64)),
)
def test_exclusive_numa_binding(self) -> None:
@ -569,34 +565,30 @@ class NumaBindingTest(TestCase):
num_physical_core_per_l3_cache=3,
)
command_args_0 = self._start_processes_for_str_entrypoint_and_get_Popen_args(
numa_options=NumaOptions(affinity_mode=AffinityMode.EXCLUSIVE),
target_local_rank=0,
bound_logical_cpu_indices_0 = (
self._start_processes_for_str_entrypoint_and_get_sched_setaffinity_cpus(
numa_options=NumaOptions(affinity_mode=AffinityMode.EXCLUSIVE),
target_local_rank=0,
)
)
self.assertEqual(
command_args_0,
(
"numactl",
# Gets an extra physical core due to odd number of physical cores on numa node
"--physcpubind=0-3",
"echo",
"Hello, world!",
),
bound_logical_cpu_indices_0,
# Gets an extra physical core due to odd number of physical cores on numa node
# 3 physical cores total, 2 GPUs: GPU 0 gets 2 physical cores (CPUs 0-3)
set(range(0, 4)),
)
command_args_1 = self._start_processes_for_str_entrypoint_and_get_Popen_args(
numa_options=NumaOptions(affinity_mode=AffinityMode.EXCLUSIVE),
target_local_rank=1,
bound_logical_cpu_indices_1 = (
self._start_processes_for_str_entrypoint_and_get_sched_setaffinity_cpus(
numa_options=NumaOptions(affinity_mode=AffinityMode.EXCLUSIVE),
target_local_rank=1,
)
)
self.assertEqual(
command_args_1,
(
"numactl",
# Does not get an extra physical core, since the 1st GPU already took the extra.
"--physcpubind=4-5",
"echo",
"Hello, world!",
),
bound_logical_cpu_indices_1,
# Does not get an extra physical core, since the 1st GPU already took the extra.
# GPU 1 gets 1 physical core (CPUs 4-5)
set(range(4, 6)),
)
def test_exclusive_raises_if_too_few_physical_cores(self) -> None:
@ -612,7 +604,7 @@ class NumaBindingTest(TestCase):
RuntimeError,
"There are only 1 physical cores on numa_node_index=0, but there are 2 GPUs associated with this NUMA node.",
):
self._start_processes_for_str_entrypoint_and_get_Popen_args(
self._start_processes_for_str_entrypoint_and_get_sched_setaffinity_cpus(
numa_options=NumaOptions(affinity_mode=AffinityMode.EXCLUSIVE),
target_local_rank=1,
)
@ -626,19 +618,18 @@ class NumaBindingTest(TestCase):
num_physical_core_per_l3_cache=3,
)
command_args = self._start_processes_for_str_entrypoint_and_get_Popen_args(
numa_options=NumaOptions(affinity_mode=AffinityMode.CORE_COMPLEX),
target_local_rank=3,
bound_logical_cpu_indices = (
self._start_processes_for_str_entrypoint_and_get_sched_setaffinity_cpus(
numa_options=NumaOptions(affinity_mode=AffinityMode.CORE_COMPLEX),
target_local_rank=3,
)
)
self.assertEqual(
command_args,
(
"numactl",
# The second L3 on the second numa node
"--physcpubind=24-29",
"echo",
"Hello, world!",
),
bound_logical_cpu_indices,
# GPU 3 is on numa node 3 // 2 = 1, relative GPU index is 3 % 2 = 1
# The second L3 on the second numa node (numa node 1)
# Second numa node starts at CPU 18, second L3 cache is CPUs 24-29
set(range(24, 30)),
)
def test_core_complex_numa_binding_with_fewer_l3_than_gpu(self) -> None:
@ -650,20 +641,18 @@ class NumaBindingTest(TestCase):
num_physical_core_per_l3_cache=3,
)
command_args = self._start_processes_for_str_entrypoint_and_get_Popen_args(
numa_options=NumaOptions(affinity_mode=AffinityMode.CORE_COMPLEX),
target_local_rank=3,
bound_logical_cpu_indices = (
self._start_processes_for_str_entrypoint_and_get_sched_setaffinity_cpus(
numa_options=NumaOptions(affinity_mode=AffinityMode.CORE_COMPLEX),
target_local_rank=3,
)
)
self.assertEqual(
command_args,
(
"numactl",
# There are only 2 L3 caches, so the 4th GPU shares the same
# cores as the 3rd GPU.
"--physcpubind=6-11",
"echo",
"Hello, world!",
),
bound_logical_cpu_indices,
# GPU 3 is on numa node 3 // 2 = 1, relative GPU index is 3 % 2 = 1
# With 1 L3 cache per numa node, GPU 3 uses L3 cache index 1 % 1 = 0 (the only cache)
# Second numa node starts at CPU 6, single L3 cache spans CPUs 6-11
set(range(6, 12)),
)
def test_core_complex_prefers_caches_with_more_cpus(self) -> None:
@ -677,20 +666,17 @@ class NumaBindingTest(TestCase):
# Only some subset of the CPUs are available this time.
with patch("os.sched_getaffinity", return_value={0, 4, 6, 7, 9}):
command_args = self._start_processes_for_str_entrypoint_and_get_Popen_args(
numa_options=NumaOptions(affinity_mode=AffinityMode.CORE_COMPLEX),
target_local_rank=0,
bound_logical_cpu_indices = (
self._start_processes_for_str_entrypoint_and_get_sched_setaffinity_cpus(
numa_options=NumaOptions(affinity_mode=AffinityMode.CORE_COMPLEX),
target_local_rank=0,
)
)
self.assertEqual(
command_args,
(
"numactl",
# Binds to the second L3 because it has the most available CPUs
"--physcpubind=6-7,9",
"echo",
"Hello, world!",
),
bound_logical_cpu_indices,
# Binds to the second L3 because it has the most available CPUs
{6, 7, 9},
)
def test_core_complex_tiebreak_prefers_lower_cache_key(self) -> None:
@ -706,36 +692,19 @@ class NumaBindingTest(TestCase):
num_physical_core_per_l3_cache=1,
)
command_args = self._start_processes_for_str_entrypoint_and_get_Popen_args(
numa_options=NumaOptions(affinity_mode=AffinityMode.CORE_COMPLEX),
target_local_rank=0,
)
self.assertEqual(
command_args,
(
"numactl",
"--physcpubind=0-1",
"echo",
"Hello, world!",
),
)
def test_raises_error_if_numactl_unavailable(self) -> None:
self._add_mock_hardware(
num_sockets=1,
num_numa_nodes_per_socket=1,
num_gpus_per_numa_node=1,
num_l3_caches_per_numa_node=1,
num_physical_core_per_l3_cache=1,
)
with (
patch("shutil.which", return_value=None),
self.assertRaisesRegex(RuntimeError, r".*numactl.*"),
):
self._start_processes_for_str_entrypoint_and_get_Popen_args(
numa_options=NumaOptions(affinity_mode=AffinityMode.NODE),
bound_logical_cpu_indices = (
self._start_processes_for_str_entrypoint_and_get_sched_setaffinity_cpus(
numa_options=NumaOptions(affinity_mode=AffinityMode.CORE_COMPLEX),
target_local_rank=0,
)
)
self.assertEqual(
bound_logical_cpu_indices,
# 1 numa node, 2 L3 caches, 1 physical core per L3 cache = 2 logical CPUs per cache
# L3 cache 0: CPUs 0-1, L3 cache 1: CPUs 2-3
# Both have same number of CPUs, so prefer lower cache key (0)
set(range(0, 2)),
)
def test_binds_to_node_0_if_node_stored_as_minus_one(self) -> None:
self._add_mock_hardware(
@ -755,18 +724,18 @@ class NumaBindingTest(TestCase):
contents="-1",
)
command_args = self._start_processes_for_str_entrypoint_and_get_Popen_args(
numa_options=NumaOptions(affinity_mode=AffinityMode.NODE),
target_local_rank=0,
bound_logical_cpu_indices = (
self._start_processes_for_str_entrypoint_and_get_sched_setaffinity_cpus(
numa_options=NumaOptions(affinity_mode=AffinityMode.NODE),
target_local_rank=0,
)
)
self.assertEqual(
command_args,
(
"numactl",
"--cpunodebind=0",
"echo",
"Hello, world!",
),
bound_logical_cpu_indices,
# GPU 0 has numa node stored as -1, which is treated as numa node 0
# Each numa node has 1 * 1 * 2 = 2 logical CPUs
# Numa node 0 has CPUs 0-1
set(range(0, 2)),
)
def test_callable_entrypoint_basic(self) -> None:
@ -778,27 +747,41 @@ class NumaBindingTest(TestCase):
num_physical_core_per_l3_cache=2,
)
executable_contents = (
self._start_processes_for_callable_entrypoint_and_get_executable_contents(
numa_options=NumaOptions(affinity_mode=AffinityMode.NODE),
target_local_rank=11,
)
bound_logical_cpu_indices = self._start_processes_for_callable_entrypoint_and_get_sched_setaffinity_cpus(
numa_options=NumaOptions(affinity_mode=AffinityMode.NODE),
target_local_rank=11,
)
self.assertEqual(
executable_contents,
bound_logical_cpu_indices,
# There are 8 numa nodes and 2 GPUs per numa node, so GPU 11 would be
# on numa node 11 // 2 = 5.
f"""#!/bin/bash
# If this file is more than a few minutes old and still exists on your machine,
# that is NOT expected. It should have deleted itself. If you are seeing an accumulation of such
# files, that could suggest a bug in pytorch. See https://github.com/pytorch/pytorch/pull/160163.
rm -- "$0"
numactl --cpunodebind=5 {sys.executable} "$@"
""",
# Each numa node has 4 * 2 * 2 = 16 logical CPUs
# Numa node 5 has CPUs 80-95
set(range(80, 96)),
)
def test_raises_if_binding_to_empty_set(self) -> None:
self._add_mock_hardware(
num_sockets=1,
num_numa_nodes_per_socket=1,
num_gpus_per_numa_node=1,
num_l3_caches_per_numa_node=1,
num_physical_core_per_l3_cache=1,
)
with (
patch(
"torch.numa.binding._get_logical_cpus_to_bind_to", return_value=set()
),
self.assertRaisesRegex(
RuntimeError, "Must bind to a non-empty set of CPU indices"
),
):
self._start_processes_for_callable_entrypoint_and_get_sched_setaffinity_cpus(
numa_options=NumaOptions(affinity_mode=AffinityMode.NODE),
target_local_rank=0,
)
def test_get_set_of_int_from_ranges_str(self) -> None:
self.assertEqual(
_get_set_of_int_from_ranges_str("0-2,4,6-7"), {0, 1, 2, 4, 6, 7}

View File: torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py

@ -11,7 +11,10 @@ import sys
from subprocess import Popen
from typing import Any, Optional
from torch.numa.binding import maybe_wrap_command_with_numa_bindings, NumaOptions
from torch.numa.binding import (
maybe_temporarily_apply_numa_binding_to_current_process,
NumaOptions,
)
__all__ = ["SubprocessHandler"]
@ -50,22 +53,20 @@ class SubprocessHandler:
env_vars.update(env)
args_str = (entrypoint, *[str(e) for e in args])
args_str = (
maybe_wrap_command_with_numa_bindings(
command_args=args_str,
gpu_index=local_rank_id,
numa_options=numa_options,
)
or args_str
)
self.local_rank_id = local_rank_id
self.proc: Popen = self._popen(args_str, env_vars)
# See HACK [NUMA inheritance] in spawn.py for context.
with maybe_temporarily_apply_numa_binding_to_current_process(
gpu_index=local_rank_id, numa_options=numa_options
):
self.proc: Popen = self._popen(args_str, env_vars)
def _popen(self, args: tuple, env: dict[str, str]) -> Popen:
kwargs: dict[str, Any] = {}
if not IS_WINDOWS:
kwargs["start_new_session"] = True
return Popen(
# pyre-fixme[6]: Expected `Union[typing.Sequence[Union[_PathLike[bytes],
# _PathLike[str], bytes, str]], bytes, str]` for 1st param but got

View File: torch/distributed/launcher/api.py

@ -26,6 +26,7 @@ from torch.distributed.elastic.multiprocessing.errors import ChildFailedError
from torch.distributed.elastic.rendezvous import RendezvousParameters
from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint
from torch.distributed.elastic.utils.logging import get_logger
from torch.multiprocessing.spawn import should_use_parallel_start
from torch.numa.binding import NumaOptions
@ -109,9 +110,11 @@ class LaunchConfig:
if (
self.numa_options is None
# NOTE: This filter isn't relevant for str entrypoints,
# but it's the default anyway.
and self.start_method == "spawn"
# The way we apply NUMA bindings currently depends
# on the processes being started sequentially.
# Technically, this filter does not matter for str entrypoints,
# but we ignore that nuance for now.
and not should_use_parallel_start(self.start_method)
and torch.cuda.is_available()
# We assume local_rank n uses cuda device n.
and torch.cuda.device_count() == self.nproc_per_node

View File: torch/multiprocessing/spawn.py

@ -2,7 +2,6 @@
import logging
import multiprocessing
import multiprocessing.connection
import multiprocessing.spawn as mp_spawn
import os
import pickle
import signal
@ -14,7 +13,7 @@ from concurrent.futures import as_completed, ThreadPoolExecutor
from typing import Optional
from torch.numa.binding import (
maybe_get_temporary_python_executable_with_numa_bindings,
maybe_temporarily_apply_numa_binding_to_current_process,
NumaOptions,
)
@ -30,6 +29,7 @@ __all__ = [
"ProcessException",
"ProcessExitedException",
"ProcessRaisedException",
"should_use_parallel_start",
"spawn",
"SpawnContext",
"start_processes",
@ -227,6 +227,17 @@ class SpawnContext(ProcessContext):
super().__init__(processes, error_files)
def should_use_parallel_start(start_method: str) -> bool:
"""
Returns:
Whether we will start subprocesses in parallel.
"""
return (
start_method == "forkserver"
and os.environ.get(ENV_VAR_PARALLEL_START, "0") == "1"
)
# Note: [start_processes]
# mp.start_processes handles both start_method='spawn' and 'fork'. It's supposed to be a
# more generalized API than mp.spawn. Currently we only document mp.spawn as it's the
@ -248,53 +259,21 @@ def start_processes(
# this func will start processes in parallel if start_method is 'forkserver'.
# Please opt in to this perf optimization by setting env var (TORCH_MP_PARALLEL_START) to 1.
# todo: investigate why spawn does not work with threadpool and raises SIGINT
if (
start_method == "forkserver"
and os.environ.get(ENV_VAR_PARALLEL_START, "0") == "1"
):
if should_use_parallel_start(start_method):
log.info("Starting processes in parallel.")
start_parallel = True
else:
# Set env var TORCH_MP_PARALLEL_START to 0 to disable parallel start
start_parallel = False
if numa_options is not None and start_method != "spawn":
raise ValueError("NUMA binding is only compatible with spawn")
if numa_options is not None and start_parallel:
raise ValueError("NUMA binding is not compatible with parallel start")
mp = multiprocessing.get_context(start_method)
error_files = [None] * nprocs
processes = [None] * nprocs
original_executable = mp_spawn.get_executable()
def start_process(i):
# HACK: We want to force Process.start() to kick off the subprocess
# using a custom numactl command per rank. However, the API exposed
# by multiprocessing only allows us to override the executable for
# the entire context, and only with a single str rather than a tuple.
# Furthermore, there is no API for passing additional options, e.g.
# to make LOCAL_RANK available to the executable.
#
# In order to get around these limitations, we pre-compute
# the appropriate command containing NUMA bindings and store it in a
# temporary executable which passes Python args on to the original
# executable. Then, we call set_executable before and after each
# Process.start() call.
#
# This assumes that, under the hood, Process.start() for rank n
# will not call get_executable after start_process for rank n+1
# calls set_executable again. We guarantee this by
# raising an exception if `start_parallel`, above. (Not clear
# if there would be a race condition otherwise, but we want to be safe.)
temporary_executable_path = (
maybe_get_temporary_python_executable_with_numa_bindings(
python_executable_path=original_executable,
gpu_index=i,
numa_options=numa_options,
)
)
# Each process is assigned a file to write tracebacks to. We
# use the file being non-empty to indicate an exception
# occurred (vs an expected shutdown). Note: this previously
@ -307,18 +286,29 @@ def start_processes(
tf.close()
os.unlink(tf.name)
try:
if temporary_executable_path is not None:
mp.set_executable(temporary_executable_path)
process = mp.Process(
target=_wrap,
args=(fn, i, args, tf.name),
daemon=daemon,
)
process = mp.Process(
target=_wrap,
args=(fn, i, args, tf.name),
daemon=daemon,
)
# HACK [NUMA inheritance]: Subprocesses inherit the parent process's CPU
# affinity. So, we temporarily apply the bindings to the current process,
# and then immediately undo them.
# This is necessary because the alternatives would be to
# either
# 1. Use numactl CLI. However, Python's multiprocessing library
# does not provide an API which would allow us to prepend
# the command it runs with numactl options.
# 2. Wrap the provided function such that it first applies
# NUMA bindings, and then executes as expected. However, this
# can result in worse memory locality, because torch and CUDA
# initialization would occur before applying the bindings, thus
# allowing some memory to be allocated on the wrong NUMA nodes.
with maybe_temporarily_apply_numa_binding_to_current_process(
gpu_index=i, numa_options=numa_options
):
process.start()
finally:
if temporary_executable_path is not None:
mp.set_executable(original_executable)
return i, process, tf.name
if not start_parallel:

View File: torch/numa/binding.py

@ -1,15 +1,11 @@
import os
import shutil
import stat
import subprocess
import traceback
from collections import defaultdict
from collections.abc import Iterable
from collections.abc import Iterable, Iterator
from contextlib import contextmanager
from dataclasses import asdict, dataclass
from enum import Enum
from logging import getLogger
from subprocess import run
from tempfile import mkstemp
from typing import Callable, Optional, TypeVar
import torch
@ -18,13 +14,10 @@ from torch._utils_internal import signpost_event
__all__ = [
"AffinityMode",
"maybe_get_temporary_python_executable_with_numa_bindings",
"maybe_wrap_command_with_numa_bindings",
"maybe_temporarily_apply_numa_binding_to_current_process",
"NumaOptions",
]
_NUMACTL_COMMAND = "numactl"
logger = getLogger(__name__)
@ -54,248 +47,136 @@ class NumaOptions:
should_fall_back_if_binding_fails: bool = False
def maybe_get_temporary_python_executable_with_numa_bindings(
*, python_executable_path: str, gpu_index: int, numa_options: Optional[NumaOptions]
) -> Optional[str]:
@contextmanager
def maybe_temporarily_apply_numa_binding_to_current_process(
*, gpu_index: int, numa_options: Optional[NumaOptions]
) -> Iterator[None]:
"""
Args:
python_executable_path: E.g., "/usr/local/bin/python"
Returns:
Path to a temporary file. This file can be executed just like the original python
executable, except it will first apply NUMA bindings.
1. Applies NUMA binding to the current process, suitable for the process
which will be interacting with GPU gpu_index.
2. Resets to the original CPU affinity before exiting the context manager.
"""
if numa_options is None:
logger.info("Received numa_options=None, not creating numa executable.")
return None
yield
return
if isinstance(python_executable_path, bytes):
python_executable_path = python_executable_path.decode()
full_numactl_command = maybe_wrap_command_with_numa_bindings(
# "$@", i.e. pass through any args the python executable would have
# received.
command_args=(python_executable_path, '"$@"'),
gpu_index=gpu_index,
numa_options=numa_options,
original_logical_cpu_indices = _get_allowed_cpu_indices_for_current_process()
_apply_numa_binding_to_current_process(
gpu_index=gpu_index, numa_options=numa_options
)
yield
_bind_current_process_to_logical_cpus(
logical_cpu_indices=original_logical_cpu_indices
)
if full_numactl_command is None:
return None
executable_path = _get_temporary_executable_for_command(
command_args=full_numactl_command
)
logger.info("Returning python executable with NUMA bindings %s", executable_path)
return executable_path
def maybe_wrap_command_with_numa_bindings(
*,
command_args: tuple[str, ...],
gpu_index: int,
numa_options: Optional[NumaOptions],
) -> Optional[tuple[str, ...]]:
"""
Args:
command_args: Full shell command, like ("/usr/local/bin/python", "train.py")
gpu_index: The index of the GPU which command_args should bind to
Returns:
command_args, but wrapped so that it runs with NUMA bindings corresponding to
gpu_index and numa_options.
E.g., ("numactl", "--cpunodebind=0", "/usr/local/bin/python", "train.py")
"""
if not numa_options:
logger.info("Received numa_options=None, not applying bindings.")
return None
def _apply_numa_binding_to_current_process(
*, gpu_index: int, numa_options: NumaOptions
) -> None:
kwargs = {
"command_args": command_args,
"gpu_index": gpu_index,
"numa_options": asdict(numa_options),
}
logger.info("Attempting to wrap command with NUMA bindings, given input %r", kwargs)
logger.info("Attempting to apply NUMA binding, given input %r", kwargs)
try:
_raise_if_numactl_not_available()
numactl_options = _get_numactl_cli_options(
command_args=command_args, gpu_index=gpu_index, numa_options=numa_options
)
logger.info("Computed numactl_options=%r", numactl_options)
_raise_if_numactl_fails_dry_run(numactl_options=numactl_options)
logger.info("Validated numactl_options=%r", numactl_options)
full_numactl_command = _get_assembled_command_from_pieces(
command_args=command_args, numactl_options=numactl_options
logical_cpu_indices = _get_logical_cpus_to_bind_to(
gpu_index=gpu_index, numa_options=numa_options
)
logger.info(
"Successfully wrapped command with numa_bindings. Returning %r",
full_numactl_command,
"Computed logical_cpu_indices=%s for NUMA binding",
_get_ranges_str_from_ints(logical_cpu_indices),
)
_raise_if_logical_cpu_indices_invalid(logical_cpu_indices=logical_cpu_indices)
logger.info(
"Validated logical_cpu_indices=%s for NUMA binding",
_get_ranges_str_from_ints(logical_cpu_indices),
)
_bind_current_process_to_logical_cpus(logical_cpu_indices=logical_cpu_indices)
logger.info(
"Successfully bound to logical_cpu_indices=%r for NUMA binding",
_get_ranges_str_from_ints(logical_cpu_indices),
)
signpost_event(
category="numa_binding",
name="wrap_command_success",
parameters={**kwargs, "result": full_numactl_command},
name="apply_success",
parameters={
**kwargs,
"logical_cpu_indices": _get_ranges_str_from_ints(logical_cpu_indices),
},
)
return full_numactl_command
except Exception:
signpost_event(
category="numa_binding",
name="wrap_command_exception",
name="apply_exception",
parameters={
**kwargs,
"traceback": traceback.format_exc(),
},
)
logger.exception(
"Failed to wrap command with NUMA bindings for input = %r", kwargs
)
logger.exception("Failed to apply NUMA binding for input=%r", kwargs)
if numa_options.should_fall_back_if_binding_fails:
logger.warning("Falling back to original command without NUMA bindings.")
logger.warning(
"Continuing executing without applying NUMA binding, despite exception %s",
traceback.format_exc(),
)
return None
raise
def _get_temporary_executable_for_command(
def _raise_if_logical_cpu_indices_invalid(*, logical_cpu_indices: set[int]) -> None:
if not logical_cpu_indices:
raise RuntimeError("Must bind to a non-empty set of CPU indices")
def _bind_current_process_to_logical_cpus(*, logical_cpu_indices: set[int]) -> None:
# 0 represents the current process
os.sched_setaffinity(0, logical_cpu_indices)
def _get_logical_cpus_to_bind_to(
*,
command_args: tuple[str, ...],
) -> str:
"""
Returns:
Path to a temporary file which executes the specified command. The executable
deletes itself the first time it runs, so do not try to run it multiple times.
"""
fd, path = mkstemp(
prefix="pytorch-numa-bind",
suffix=".sh",
)
# We do rm first to guarantee the file deletes itself. The rest of the file
# will still run as intended.
contents = f"""#!/bin/bash
# If this file is more than a few minutes old and still exists on your machine,
# that is NOT expected. It should have deleted itself. If you are seeing an accumulation of such
# files, that could suggest a bug in pytorch. See https://github.com/pytorch/pytorch/pull/160163.
rm -- "$0"
{" ".join(command_args)}
"""
with os.fdopen(fd, "w") as file:
file.write(contents)
# Ensure the file is fully synced, in order to avoid race condition
# from trying to execute it too early.
file.flush()
os.fsync(fd)
# Make the script executable
os.chmod(path, stat.S_IRWXU)
logger.info(
"Created temporary executable at path %s, with contents\n%s", path, contents
)
return path
def _get_numactl_cli_options(
*,
command_args: tuple[str, ...],
gpu_index: int,
numa_options: NumaOptions,
) -> tuple[str, ...]:
) -> set[int]:
"""
Args:
command_args: The args for a command, such as might be input to Popen.
Example: ("python", "trainer.py")
gpu_index: The index of the GPU that will be used by the subprocess which executes command_args.
gpu_index: The index of the GPU that will be used by the subprocess.
Example: 0
numa_options: See NumaOptions for details.
Returns:
Depending on numa_options, something like
("--cpunodebind=0")
Set of logical CPU indices to bind to.
"""
if numa_options.affinity_mode == AffinityMode.NODE:
numactl_command_options = _get_node_numactl_options(gpu_index=gpu_index)
logical_cpus = _node_get_logical_cpus_to_bind_to(gpu_index=gpu_index)
elif numa_options.affinity_mode == AffinityMode.SOCKET:
numactl_command_options = _get_socket_numactl_options(gpu_index=gpu_index)
logical_cpus = _socket_get_logical_cpus_to_bind_to(gpu_index=gpu_index)
elif numa_options.affinity_mode == AffinityMode.EXCLUSIVE:
numactl_command_options = _get_exclusive_numactl_options(gpu_index=gpu_index)
logical_cpus = _exclusive_get_logical_cpus_to_bind_to(gpu_index=gpu_index)
elif numa_options.affinity_mode == AffinityMode.CORE_COMPLEX:
numactl_command_options = _get_core_complex_numactl_options(gpu_index=gpu_index)
logical_cpus = _core_complex_get_logical_cpus_to_bind_to(gpu_index=gpu_index)
else:
raise ValueError(f"Affinity mode {numa_options.affinity_mode} not supported.")
return numactl_command_options
return logical_cpus
def _raise_if_numactl_fails_dry_run(*, numactl_options: tuple[str, ...]) -> None:
noop_args = _get_assembled_command_from_pieces(
# Execute arbitrary noop
command_args=("true",),
numactl_options=numactl_options,
)
temporary_executable_path = _get_temporary_executable_for_command(
command_args=noop_args
)
try:
run(
(temporary_executable_path,),
stdout=subprocess.DEVNULL,
# These allow us to capture the stderr as text
stderr=subprocess.PIPE,
text=True,
# Raise exception if nonzero exit status.
check=True,
)
except subprocess.CalledProcessError as e:
raise RuntimeError(
f"""Our binding logic inferred to prepend your command with options {noop_args[:-1]}.
Before doing that, we did a noop dry run with args {noop_args}, but that command failed.
This should NOT happen, and likely suggests a bug in pytorch's numa binding logic.
The {_NUMACTL_COMMAND} command itself had this stderr:
{e.stderr}
"""
) from e
def _get_assembled_command_from_pieces(
*, command_args: tuple[str, ...], numactl_options: tuple[str, ...]
) -> tuple[str, ...]:
# Syntax for invoking a command but with numactl activated is numactl <args> command <args>
return (_NUMACTL_COMMAND, *numactl_options, *command_args)
def _raise_if_numactl_not_available() -> None:
if not shutil.which(_NUMACTL_COMMAND):
raise RuntimeError(
f"{_NUMACTL_COMMAND} shell command is required for NUMA bindings."
)
def _get_node_numactl_options(*, gpu_index: int) -> tuple[str, ...]:
def _node_get_logical_cpus_to_bind_to(*, gpu_index: int) -> set[int]:
"""
Core logic of 'node' numa strategy.
Returns options to be used with numactl. E.g.,
("--cpunodebind=0").
"""
numa_node_index = _get_numa_node_index_for_gpu_index(gpu_index=gpu_index)
return (f"--cpunodebind={numa_node_index}",)
return _get_allowed_logical_cpu_indices_for_numa_node(
numa_node_index=numa_node_index
)
def _get_socket_numactl_options(*, gpu_index: int) -> tuple[str, ...]:
def _socket_get_logical_cpus_to_bind_to(*, gpu_index: int) -> set[int]:
"""
Core logic of 'socket' numa strategy.
"""
@ -306,12 +187,19 @@ def _get_socket_numactl_options(*, gpu_index: int) -> tuple[str, ...]:
numa_node_indices = _get_numa_node_indices_for_socket_index(
socket_index=socket_index
)
numa_node_indices_str = _get_ranges_str_from_ints(numa_node_indices)
return (f"--cpunodebind={numa_node_indices_str}",)
logical_cpus = set()
for numa_node_index in numa_node_indices:
logical_cpus.update(
_get_allowed_logical_cpu_indices_for_numa_node(
numa_node_index=numa_node_index
)
)
return logical_cpus
def _get_exclusive_numactl_options(*, gpu_index: int) -> tuple[str, ...]:
def _exclusive_get_logical_cpus_to_bind_to(*, gpu_index: int) -> set[int]:
"""
Core logic of 'exclusive' numa strategy.
"""
@ -370,20 +258,18 @@ def _get_exclusive_numactl_options(*, gpu_index: int) -> tuple[str, ...]:
)
# Slice and flatten the logical CPUs from the selected physical cores
logical_cpu_indices_for_original_gpu = (
logical_cpu_indices_for_original_gpu = {
logical_cpu_index
for logical_cpu_indices in list(
physical_core_to_allowed_logical_cpu_indices.values()
)[start:end]
for logical_cpu_index in logical_cpu_indices
)
}
return (
f"--physcpubind={_get_ranges_str_from_ints(logical_cpu_indices_for_original_gpu)}",
)
return logical_cpu_indices_for_original_gpu
def _get_core_complex_numactl_options(*, gpu_index: int) -> tuple[str, ...]:
def _core_complex_get_logical_cpus_to_bind_to(*, gpu_index: int) -> set[int]:
"""
Core logic of 'core-complex' numa strategy.
@ -427,9 +313,7 @@ def _get_core_complex_numactl_options(*, gpu_index: int) -> tuple[str, ...]:
max_level_cache_to_allowed_logical_cpu_indices.values()
)[cache_index_for_original_gpu]
return (
f"--physcpubind={_get_ranges_str_from_ints(logical_cpu_indices_for_original_gpu)}",
)
return logical_cpu_indices_for_original_gpu
K = TypeVar("K")