# Context

Previously, we would modify the parent process's NUMA bindings in order to force child processes to inherit them. However, this did not work correctly with `start_method="forkserver"`, because subprocesses actually inherit their bindings from the forkserver middleman process. In that case, the inherited affinity was incorrect for all but the first subprocess: the forkserver process is created lazily, so it inherits, and then sticks with, the bindings intended for the first subprocess.

# This PR

* `str` entrypoints: use the `numactl` CLI.
* `Callable` entrypoints: wrap the `Callable` entrypoint and call `os.sched_setaffinity` inside the wrapper.

Hopefully this will be the last necessary iteration.

# Test Plan

## Automated

`$ pytest test/test_numa_binding.py`

## Manual

Verified flops/sec and memory-locality wins on several different types of jobs:

* `Callable` entrypoint with forkserver
* `str` entrypoint with spawn
* `Callable` entrypoint with spawn

More details in [this doc (Meta-only)](https://docs.google.com/document/d/1vxD-OKYBTT27jbBwtW9iz9g0tNM0u-i0tiTJg_ieQA8/edit?tab=t.scjv58yswi64).

# Later PR

Update all the documentation once we're confident this has stabilized.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166026
Approved by: https://github.com/d4l3k
Co-authored-by: PyTorch MergeBot <pytorchmergebot@users.noreply.github.com>
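To make the two strategies concrete, here is a minimal sketch of what each amounts to. The `cpus` set and the wrapper name are illustrative assumptions, not the actual `torch.numa.binding` implementation:

```python
import os
import shutil

# Hypothetical: the logical CPUs chosen for this local rank. In reality this
# set is derived from the sysfs topology and the configured affinity mode.
cpus = {0, 1, 2, 3}

# str entrypoint: prefix the subprocess command with the numactl CLI, so the
# kernel applies the binding when the child is exec'd.
assert shutil.which("numactl") is not None, "numactl CLI is required"
command = (
    "numactl",
    f"--physcpubind={','.join(map(str, sorted(cpus)))}",
    "echo",
    "Hello, world!",
)

# Callable entrypoint: wrap the target function so the binding is applied
# inside the child process itself, after any forkserver middleman.
def wrap_with_numa_binding(fn):
    def wrapped(*args, **kwargs):
        os.sched_setaffinity(0, cpus)  # pid 0 means the calling process
        return fn(*args, **kwargs)

    return wrapped
```

Binding inside the child, rather than mutating the parent's affinity and hoping it is inherited, is what makes both paths robust to the forkserver middleman described above.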
# Owner(s): ["oncall: distributed"]

from __future__ import annotations

import json
import sys
from dataclasses import dataclass
from typing import Any, Optional
from unittest import skipUnless
from unittest.mock import mock_open, patch

import torch
from torch._utils_internal import signpost_event
from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, start_processes
from torch.distributed.elastic.multiprocessing.api import _wrap
from torch.numa.binding import (
    _bind_all_threads_in_current_process_to_logical_cpus,
    _get_ranges_str_from_ints,
    _get_set_of_int_from_ranges_str,
    AffinityMode,
    NumaOptions,
)
from torch.testing._internal.common_utils import run_tests, TestCase


@dataclass(frozen=True)
class MockDeviceProperties:
    name: str
    major: int
    minor: int
    total_memory: str
    multi_processor_count: int
    uuid: str
    pci_bus_id: int
    pci_device_id: int
    pci_domain_id: int
    L2_cache_size: str


_real_open = open


@skipUnless(sys.platform == "linux", "Only linux currently supported")
@skipUnless(
    torch.distributed.is_available(), "Need access to some distributed submodules"
)
class NumaBindingTest(TestCase):
    def setUp(self) -> None:
        super().setUp()

        self._mock_file_path_to_contents: dict[str, str] = {}
        self._mock_device_properties: list[MockDeviceProperties] = []
        self._mock_num_logical_cpus = 0
        self._mock_num_numa_nodes = 0
        self._mock_num_sockets = 0

        self._context_managers_to_apply_to_all_tests = [
            patch("torch.cuda.device_count", self._mock_device_count),
            patch("torch.cuda.get_device_properties", self._mock_get_device_properties),
            patch("torch.cuda.is_available", self._mock_is_available),
            # Implicitly used by dynamo
            patch("torch.cuda.get_rng_state"),
            patch("builtins.open", new=self._mock_open),
            patch("os.listdir", new=self._mock_listdir),
            patch("os.sched_getaffinity", new=self._mock_sched_getaffinity),
            patch("torch.numa.binding.signpost_event", self._mock_signpost_event),
            patch("torch.numa.binding.shutil.which", return_value="/usr/bin/numactl"),
        ]

        for context_manager in self._context_managers_to_apply_to_all_tests:
            context_manager.__enter__()

    def tearDown(self) -> None:
        for context_manager in self._context_managers_to_apply_to_all_tests:
            context_manager.__exit__(None, None, None)
        super().tearDown()

    def _mock_signpost_event(self, *args, **kwargs) -> None:
        # Please keep these parameters JSON serializable for logging purposes
        json.dumps(kwargs["parameters"])
        return signpost_event(*args, **kwargs)

    def _add_mock_hardware(
        self,
        *,
        num_sockets: int,
        num_numa_nodes_per_socket: int,
        num_gpus_per_numa_node: int,
        num_l3_caches_per_numa_node: int,
        num_physical_core_per_l3_cache: int,
    ) -> None:
        """
        It's not fun, but we mock everything down to sysfs level
        to make sure we get really thorough coverage.
        """
        for socket_index in range(num_sockets):
            for numa_node_index in range(
                self._mock_num_numa_nodes,
                self._mock_num_numa_nodes + num_numa_nodes_per_socket,
            ):
                self._mock_file_contents(
                    file_path=f"/sys/devices/system/node/node{numa_node_index}/cpulist",
                    contents=f"{self._mock_num_logical_cpus}-"
                    + f"{self._mock_num_logical_cpus + num_l3_caches_per_numa_node * num_physical_core_per_l3_cache * 2 - 1}",
                )
                for gpu_index in range(
                    len(self._mock_device_properties),
                    len(self._mock_device_properties) + num_gpus_per_numa_node,
                ):
                    device_properties = MockDeviceProperties(
                        name=f"mock_gpu_{gpu_index}",
                        major=8,
                        minor=0,
                        total_memory="512GB",
                        multi_processor_count=256,
                        uuid=f"mock_gpu_uuid_{gpu_index}",
                        pci_bus_id=gpu_index,
                        pci_device_id=gpu_index,
                        pci_domain_id=gpu_index,
                        L2_cache_size="40MB",
                    )
                    self._mock_device_properties.append(device_properties)
                    pci_numa_node_path = (
                        self._get_corresponding_pci_numa_node_file_path(
                            device_properties=device_properties
                        )
                    )
                    self._mock_file_contents(
                        file_path=pci_numa_node_path,
                        contents=str(numa_node_index),
                    )

                for _ in range(num_l3_caches_per_numa_node):
                    lowest_logical_cpu_index_on_l3 = self._mock_num_logical_cpus
                    highest_logical_cpu_index_on_l3 = (
                        self._mock_num_logical_cpus
                        + 2 * num_physical_core_per_l3_cache
                        - 1
                    )
                    for logical_cpu_index in range(
                        self._mock_num_logical_cpus,
                        self._mock_num_logical_cpus
                        # Assume hyperthreaded
                        + 2 * num_physical_core_per_l3_cache,
                    ):
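                        # The mock topology is hyperthreaded: even CPU n and odd
                        # CPU n + 1 share a physical core, so the siblings list
                        # is "n-(n+1)" for even n and "(n-1)-n" for odd n.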
                        thread_siblings_range_str = (
                            f"{logical_cpu_index - 1}-{logical_cpu_index}"
                            if logical_cpu_index % 2
                            else f"{logical_cpu_index}-{logical_cpu_index + 1}"
                        )
                        self._mock_file_contents(
                            file_path=f"/sys/devices/system/cpu/cpu{logical_cpu_index}/topology/thread_siblings_list",
                            contents=thread_siblings_range_str,
                        )
                        # Unrelated file our logic should know to skip
                        self._mock_file_contents(
                            file_path=f"/sys/devices/system/cpu/cpu{logical_cpu_index}/cache/paulwuzhere",
                            contents="Data",
                        )
                        self._mock_file_contents(
                            file_path=f"/sys/devices/system/cpu/cpu{logical_cpu_index}/topology/physical_package_id",
                            contents=str(socket_index),
                        )
                        for cache_level in range(5):
                            self._mock_file_contents(
                                file_path=f"/sys/devices/system/cpu/cpu{logical_cpu_index}/cache/index{cache_level}/type",
                                contents="ShouldSkip" if cache_level == 4 else "Data",
                            )
                            self._mock_file_contents(
                                file_path=f"/sys/devices/system/cpu/cpu{logical_cpu_index}/cache/index{cache_level}/level",
                                contents=str(cache_level),
                            )
                            self._mock_file_contents(
                                file_path=f"/sys/devices/system/cpu/cpu{logical_cpu_index}/cache/index{cache_level}/shared_cpu_list",
                                contents=(
                                    f"{lowest_logical_cpu_index_on_l3}-{highest_logical_cpu_index_on_l3}"
                                    if cache_level == 3
                                    # Assume L1-2 are per physical core
                                    else thread_siblings_range_str
                                ),
                            )
                        self._mock_num_logical_cpus += 1
                self._mock_num_numa_nodes += 1
            self._mock_num_sockets += 1
        self._mock_file_contents(
            file_path="/sys/devices/system/node/possible",
            contents=f"0-{self._mock_num_numa_nodes - 1}",
        )

    def _mock_is_available(self) -> bool:
        return len(self._mock_device_properties) > 0
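
    # On Linux, sysfs reports each PCI device's NUMA node at
    # /sys/bus/pci/devices/<domain>:<bus>:<device>.<function>/numa_node;
    # mocking that file is how these tests control the GPU-to-node mapping.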
    def _get_corresponding_pci_numa_node_file_path(
        self, *, device_properties: MockDeviceProperties
    ) -> str:
        pci_addr = (
            f"{device_properties.pci_domain_id:04x}:"
            + f"{device_properties.pci_bus_id:02x}:{device_properties.pci_device_id:02x}.0"
        )
        return f"/sys/bus/pci/devices/{pci_addr}/numa_node"

    def _mock_file_contents(self, *, file_path: str, contents: str) -> None:
        self._mock_file_path_to_contents[file_path] = contents

    def _mock_device_count(self) -> int:
        return len(self._mock_device_properties)

    def _mock_get_device_properties(self, index: int) -> MockDeviceProperties:
        return self._mock_device_properties[index]

    def _mock_open(self, path: str, *args, **kwargs) -> Any:
        if path in self._mock_file_path_to_contents:
            return mock_open(read_data=self._mock_file_path_to_contents[path])()
        if isinstance(path, str) and path.startswith("/sys/"):
            raise FileNotFoundError(f"File {path} was not mocked.")
        # Looks like CI is calling open and intending to open an actual file in some places.
        # Need this to make the CI pass.
        return _real_open(path, *args, **kwargs)
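
    # Emulates os.listdir over the mocked sysfs tree: returns the immediate
    # children of target_path among all mocked file paths.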
    def _mock_listdir(self, target_path: str) -> set[str]:
        if not target_path.endswith("/"):
            target_path += "/"
        return {
            mock_path.split(target_path)[1].split("/")[0]
            for mock_path in self._mock_file_path_to_contents
            if mock_path.startswith(target_path)
        }

    def _mock_sched_getaffinity(self, pid: int) -> set[int]:
        return set(range(self._mock_num_logical_cpus))
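
    # Patches Popen, runs start_processes with a `str` entrypoint, and returns
    # the argv that would have been used to launch the given local rank, so
    # tests can assert on the numactl prefix.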
    def _start_processes_for_str_entrypoint_and_get_command_args(
        self, *, numa_options: Optional[NumaOptions], target_local_rank: int
    ) -> tuple[str, ...]:
        with patch(
            "torch.distributed.elastic.multiprocessing.subprocess_handler.subprocess_handler.Popen"
        ) as mock_popen:
            start_processes(
                name="test_process",
                entrypoint="echo",
                args=dict.fromkeys(
                    range(self._mock_device_count()), ("Hello, world!",)
                ),
                envs={
                    i: {"LOCAL_RANK": str(i)} for i in range(self._mock_device_count())
                },
                logs_specs=DefaultLogsSpecs(),
                numa_options=numa_options,
            )
        call_args = next(
            call_args
            for call_args in mock_popen.call_args_list
            if call_args.kwargs.get("env", {}).get("LOCAL_RANK")
            == str(target_local_rank)
        )
        return call_args.kwargs["args"]
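
    # Calls _wrap (the real subprocess entrypoint for Callable entrypoints)
    # directly in this process, with os.sched_setaffinity patched, and returns
    # the CPU set the wrapper tried to bind to.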
    def _start_processes_for_callable_entrypoint_and_get_sched_setaffinity_cpus(
        self, *, numa_options: Optional[NumaOptions], target_local_rank: int
    ) -> Optional[set[int]]:
        target_sched_setaffinity_logical_cpu_indices = None

        def mock_sched_setaffinity(*args, **kwargs) -> None:
            nonlocal target_sched_setaffinity_logical_cpu_indices
            target_sched_setaffinity_logical_cpu_indices = args[1]

        def dummy_fn():
            pass

        import torch.multiprocessing as mp

        ctx = mp.get_context()
        mock_queue = ctx.SimpleQueue()
        mock_event = ctx.Event()

        with patch("os.sched_setaffinity", mock_sched_setaffinity):
            mock_event.set()  # Prevent hanging
            # This is the entrypoint for subprocesses with Callable entrypoints
            _wrap(
                local_rank=target_local_rank,
                fn=dummy_fn,
                args={target_local_rank: ()},
                envs={target_local_rank: {}},
                stdout_redirects={target_local_rank: ""},
                stderr_redirects={target_local_rank: ""},
                ret_vals={target_local_rank: mock_queue},
                queue_finished_reading_event=mock_event,
                numa_options=numa_options,
            )

        return target_sched_setaffinity_logical_cpu_indices

    def test_node_numa_binding(self) -> None:
        self._add_mock_hardware(
            num_sockets=4,
            num_numa_nodes_per_socket=2,
            num_gpus_per_numa_node=2,
            num_l3_caches_per_numa_node=4,
            num_physical_core_per_l3_cache=2,
        )

        command_args = self._start_processes_for_str_entrypoint_and_get_command_args(
            numa_options=NumaOptions(affinity_mode=AffinityMode.NODE),
            target_local_rank=11,
        )
        self.assertEqual(
            command_args,
            # There are 8 numa nodes and 2 GPUs per numa node, so GPU 11 would be
            # on numa node 11 // 2 = 5.
            # Each numa node has 4 * 2 * 2 = 16 logical CPUs
            # Numa node 5 has CPUs 80-95
            ("numactl", "--physcpubind=80-95", "echo", "Hello, world!"),
        )

    def test_no_numa_binding_if_numa_options_not_provided(self) -> None:
        self._add_mock_hardware(
            num_sockets=4,
            num_numa_nodes_per_socket=2,
            num_gpus_per_numa_node=2,
            num_l3_caches_per_numa_node=4,
            num_physical_core_per_l3_cache=2,
        )

        command_args = self._start_processes_for_str_entrypoint_and_get_command_args(
            numa_options=None, target_local_rank=11
        )
        self.assertEqual(
            command_args,
            ("echo", "Hello, world!"),
        )

    def test_default_numa_binding(self) -> None:
        # Inner import to avoid crashing if not torch.distributed.is_available()
        from torch.distributed.launcher.api import LaunchConfig

        self._add_mock_hardware(
            num_sockets=1,
            num_numa_nodes_per_socket=1,
            num_gpus_per_numa_node=1,
            num_l3_caches_per_numa_node=1,
            num_physical_core_per_l3_cache=1,
        )

        with patch(
            "torch.distributed.launcher.api.get_default_numa_options",
            return_value=NumaOptions(
                affinity_mode=AffinityMode.NODE, should_fall_back_if_binding_fails=True
            ),
        ):
            launch_config = LaunchConfig(
                min_nodes=1,
                max_nodes=1,
                nproc_per_node=1,
                # Don't provide numa_options
            )
            self.assertEqual(
                launch_config.numa_options,
                NumaOptions(
                    affinity_mode=AffinityMode.NODE,
                    should_fall_back_if_binding_fails=True,
                ),
            )

    def test_fallback(self) -> None:
        self._add_mock_hardware(
            num_sockets=2,
            num_numa_nodes_per_socket=1,
            num_gpus_per_numa_node=1,
            num_l3_caches_per_numa_node=1,
            num_physical_core_per_l3_cache=1,
        )

        with (
            patch("torch.numa.binding.signpost_event") as signpost_patch,
            patch(
                "torch.numa.binding._get_numa_node_index_for_gpu_index",
                side_effect=Exception("Mock exception!"),
            ),
        ):
            command_args = (
                self._start_processes_for_str_entrypoint_and_get_command_args(
                    numa_options=NumaOptions(
                        affinity_mode=AffinityMode.NODE,
                        should_fall_back_if_binding_fails=True,
                    ),
                    target_local_rank=0,
                )
            )
        self.assertIn(
            "Mock exception!",
            signpost_patch.call_args.kwargs["parameters"]["traceback"],
        )
        self.assertEqual(
            command_args,
            ("echo", "Hello, world!"),
        )

    def test_fallback_if_numactl_not_available(self) -> None:
        self._add_mock_hardware(
            num_sockets=2,
            num_numa_nodes_per_socket=1,
            num_gpus_per_numa_node=1,
            num_l3_caches_per_numa_node=1,
            num_physical_core_per_l3_cache=1,
        )

        with (
            patch("torch.numa.binding.signpost_event") as signpost_patch,
            patch("torch.numa.binding.shutil.which", return_value=None),
        ):
            command_args = (
                self._start_processes_for_str_entrypoint_and_get_command_args(
                    numa_options=NumaOptions(
                        affinity_mode=AffinityMode.NODE,
                        should_fall_back_if_binding_fails=True,
                    ),
                    target_local_rank=0,
                )
            )
        self.assertIn(
            "numactl CLI is required for NUMA binding",
            signpost_patch.call_args.kwargs["parameters"]["traceback"],
        )
        self.assertEqual(
            command_args,
            ("echo", "Hello, world!"),
        )

    def test_explicit_numa_options_overrides_default(self) -> None:
        # Inner import to avoid crashing if not torch.distributed.is_available()
        from torch.distributed.launcher.api import LaunchConfig

        with patch(
            "torch.distributed.launcher.api.get_default_numa_options",
            return_value=NumaOptions(affinity_mode=AffinityMode.NODE),
        ):
            launch_config = LaunchConfig(
                min_nodes=1,
                max_nodes=1,
                nproc_per_node=1,
                numa_options=NumaOptions(affinity_mode=AffinityMode.EXCLUSIVE),
            )
            self.assertEqual(
                launch_config.numa_options,
                NumaOptions(affinity_mode=AffinityMode.EXCLUSIVE),
            )

    def test_nproc_must_equal_cuda_device_count_to_use_default_numa_options(
        self,
    ) -> None:
        # Inner import to avoid crashing if not torch.distributed.is_available()
        from torch.distributed.launcher.api import LaunchConfig

        self._add_mock_hardware(
            num_sockets=1,
            num_numa_nodes_per_socket=1,
            num_gpus_per_numa_node=1,
            num_l3_caches_per_numa_node=1,
            num_physical_core_per_l3_cache=1,
        )

        with patch(
            "torch.distributed.launcher.api.get_default_numa_options"
        ) as mock_get_default_numa_options:
            launch_config = LaunchConfig(
                min_nodes=1,
                max_nodes=1,
                nproc_per_node=2,
            )
            mock_get_default_numa_options.assert_not_called()
            self.assertIsNone(launch_config.numa_options)

    def test_socket_numa_binding_with_multiple_numa_per_socket(self) -> None:
        self._add_mock_hardware(
            num_sockets=4,
            num_numa_nodes_per_socket=2,
            num_gpus_per_numa_node=2,
            num_l3_caches_per_numa_node=4,
            num_physical_core_per_l3_cache=2,
        )

        command_args = self._start_processes_for_str_entrypoint_and_get_command_args(
            numa_options=NumaOptions(affinity_mode=AffinityMode.SOCKET),
            target_local_rank=15,
        )
        self.assertEqual(
            command_args,
            # GPU 15 is on numa node 15 // 2 = 7, which is on socket 3 (numa nodes 6 and 7)
            # Each numa node has 4 * 2 * 2 = 16 logical CPUs
            # Numa nodes 6 and 7 have CPUs 96-111 and 112-127
            ("numactl", "--physcpubind=96-127", "echo", "Hello, world!"),
        )

    def test_socket_numa_binding_with_single_numa_per_socket(self) -> None:
        self._add_mock_hardware(
            num_sockets=4,
            num_numa_nodes_per_socket=1,
            num_gpus_per_numa_node=2,
            num_l3_caches_per_numa_node=4,
            num_physical_core_per_l3_cache=2,
        )

        command_args = self._start_processes_for_str_entrypoint_and_get_command_args(
            numa_options=NumaOptions(affinity_mode=AffinityMode.SOCKET),
            target_local_rank=7,
        )
        self.assertEqual(
            command_args,
            # GPU 7 is on numa node 7 // 2 = 3, which is socket 3 by itself
            # Each numa node has 4 * 2 * 2 = 16 logical CPUs
            # Numa node 3 has CPUs 48-63
            ("numactl", "--physcpubind=48-63", "echo", "Hello, world!"),
        )

    def test_exclusive_numa_binding(self) -> None:
        self._add_mock_hardware(
            num_sockets=2,
            num_numa_nodes_per_socket=1,
            num_gpus_per_numa_node=2,
            num_l3_caches_per_numa_node=1,
            num_physical_core_per_l3_cache=3,
        )

        command_args_0 = self._start_processes_for_str_entrypoint_and_get_command_args(
            numa_options=NumaOptions(affinity_mode=AffinityMode.EXCLUSIVE),
            target_local_rank=0,
        )
        self.assertEqual(
            command_args_0,
            # Gets an extra physical core due to odd number of physical cores on numa node
            # 3 physical cores total, 2 GPUs: GPU 0 gets 2 physical cores (CPUs 0-3)
            ("numactl", "--physcpubind=0-3", "echo", "Hello, world!"),
        )

        command_args_1 = self._start_processes_for_str_entrypoint_and_get_command_args(
            numa_options=NumaOptions(affinity_mode=AffinityMode.EXCLUSIVE),
            target_local_rank=1,
        )
        self.assertEqual(
            command_args_1,
            # Does not get an extra physical core, since the 1st GPU already took the extra.
            # GPU 1 gets 1 physical core (CPUs 4-5)
            ("numactl", "--physcpubind=4-5", "echo", "Hello, world!"),
        )

    def test_exclusive_raises_if_too_few_physical_cores(self) -> None:
        self._add_mock_hardware(
            num_sockets=2,
            num_numa_nodes_per_socket=1,
            num_gpus_per_numa_node=2,
            num_l3_caches_per_numa_node=1,
            num_physical_core_per_l3_cache=1,
        )

        with self.assertRaisesRegex(
            RuntimeError,
            "There are only 1 physical cores on numa_node_index=0, but there are 2 GPUs associated with this NUMA node.",
        ):
            self._start_processes_for_str_entrypoint_and_get_command_args(
                numa_options=NumaOptions(affinity_mode=AffinityMode.EXCLUSIVE),
                target_local_rank=1,
            )

    def test_core_complex_numa_binding_with_extra_l3(self) -> None:
        self._add_mock_hardware(
            num_sockets=2,
            num_numa_nodes_per_socket=1,
            num_gpus_per_numa_node=2,
            num_l3_caches_per_numa_node=3,
            num_physical_core_per_l3_cache=3,
        )

        command_args = self._start_processes_for_str_entrypoint_and_get_command_args(
            numa_options=NumaOptions(affinity_mode=AffinityMode.CORE_COMPLEX),
            target_local_rank=3,
        )
        self.assertEqual(
            command_args,
            # GPU 3 is on numa node 3 // 2 = 1, relative GPU index is 3 % 2 = 1
            # The second L3 on the second numa node (numa node 1)
            # Second numa node starts at CPU 18, second L3 cache is CPUs 24-29
            ("numactl", "--physcpubind=24-29", "echo", "Hello, world!"),
        )

    def test_core_complex_numa_binding_with_fewer_l3_than_gpu(self) -> None:
        self._add_mock_hardware(
            num_sockets=2,
            num_numa_nodes_per_socket=1,
            num_gpus_per_numa_node=2,
            num_l3_caches_per_numa_node=1,
            num_physical_core_per_l3_cache=3,
        )

        command_args = self._start_processes_for_str_entrypoint_and_get_command_args(
            numa_options=NumaOptions(affinity_mode=AffinityMode.CORE_COMPLEX),
            target_local_rank=3,
        )
        self.assertEqual(
            command_args,
            # GPU 3 is on numa node 3 // 2 = 1, relative GPU index is 3 % 2 = 1
            # With 1 L3 cache per numa node, GPU 3 uses L3 cache index 1 % 1 = 0 (the only cache)
            # Second numa node starts at CPU 6, single L3 cache spans CPUs 6-11
            ("numactl", "--physcpubind=6-11", "echo", "Hello, world!"),
        )

    def test_core_complex_prefers_caches_with_more_cpus(self) -> None:
        self._add_mock_hardware(
            num_sockets=1,
            num_numa_nodes_per_socket=1,
            num_gpus_per_numa_node=1,
            num_l3_caches_per_numa_node=3,
            num_physical_core_per_l3_cache=3,
        )

        # Only some subset of the CPUs are available this time.
        with patch("os.sched_getaffinity", return_value={0, 4, 6, 7, 9}):
            command_args = (
                self._start_processes_for_str_entrypoint_and_get_command_args(
                    numa_options=NumaOptions(affinity_mode=AffinityMode.CORE_COMPLEX),
                    target_local_rank=0,
                )
            )

        # Binds to the second L3 because it has the most available CPUs
        self.assertEqual(
            command_args,
            ("numactl", "--physcpubind=6-7,9", "echo", "Hello, world!"),
        )

    def test_core_complex_tiebreak_prefers_lower_cache_key(self) -> None:
        """
        When several max-level caches expose the same number of logical CPUs,
        prioritize binding to caches with lower CPU indices first.
        """
        self._add_mock_hardware(
            num_sockets=1,
            num_numa_nodes_per_socket=1,
            num_gpus_per_numa_node=1,
            num_l3_caches_per_numa_node=2,
            num_physical_core_per_l3_cache=1,
        )

        command_args = self._start_processes_for_str_entrypoint_and_get_command_args(
            numa_options=NumaOptions(affinity_mode=AffinityMode.CORE_COMPLEX),
            target_local_rank=0,
        )
        # 1 numa node, 2 L3 caches, 1 physical core per L3 cache = 2 logical CPUs per cache
        # L3 cache 0: CPUs 0-1, L3 cache 1: CPUs 2-3
        # Both have same number of CPUs, so prefer lower cache key (0)
        self.assertEqual(
            command_args,
            ("numactl", "--physcpubind=0-1", "echo", "Hello, world!"),
        )

    def test_binds_to_node_0_if_node_stored_as_minus_one(self) -> None:
        self._add_mock_hardware(
            num_sockets=1,
            num_numa_nodes_per_socket=1,
            num_gpus_per_numa_node=1,
            num_l3_caches_per_numa_node=1,
            num_physical_core_per_l3_cache=1,
        )

        device_0_properties = self._mock_get_device_properties(0)
        # Overwrite the existing mock file
        self._mock_file_contents(
            file_path=self._get_corresponding_pci_numa_node_file_path(
                device_properties=device_0_properties
            ),
            contents="-1",
        )

        command_args = self._start_processes_for_str_entrypoint_and_get_command_args(
            numa_options=NumaOptions(affinity_mode=AffinityMode.NODE),
            target_local_rank=0,
        )
        # GPU 0 has numa node stored as -1, which is treated as numa node 0
        # Each numa node has 1 * 1 * 2 = 2 logical CPUs
        # Numa node 0 has CPUs 0-1
        self.assertEqual(
            command_args,
            ("numactl", "--physcpubind=0-1", "echo", "Hello, world!"),
        )

    def test_callable_entrypoint_basic(self) -> None:
        self._add_mock_hardware(
            num_sockets=4,
            num_numa_nodes_per_socket=2,
            num_gpus_per_numa_node=2,
            num_l3_caches_per_numa_node=4,
            num_physical_core_per_l3_cache=2,
        )

        bound_logical_cpu_indices = self._start_processes_for_callable_entrypoint_and_get_sched_setaffinity_cpus(
            numa_options=NumaOptions(affinity_mode=AffinityMode.NODE),
            target_local_rank=11,
        )
        self.assertEqual(
            bound_logical_cpu_indices,
            # There are 8 numa nodes and 2 GPUs per numa node, so GPU 11 would be
            # on numa node 11 // 2 = 5.
            # Each numa node has 4 * 2 * 2 = 16 logical CPUs
            # Numa node 5 has CPUs 80-95
            set(range(80, 96)),
        )

    def test_raises_if_binding_to_empty_set(self) -> None:
        self._add_mock_hardware(
            num_sockets=1,
            num_numa_nodes_per_socket=1,
            num_gpus_per_numa_node=1,
            num_l3_caches_per_numa_node=1,
            num_physical_core_per_l3_cache=1,
        )

        with (
            patch(
                "torch.numa.binding._get_logical_cpus_to_bind_to", return_value=set()
            ),
            self.assertRaisesRegex(
                RuntimeError, "Must bind to a non-empty set of CPU indices"
            ),
        ):
            self._start_processes_for_callable_entrypoint_and_get_sched_setaffinity_cpus(
                numa_options=NumaOptions(affinity_mode=AffinityMode.NODE),
                target_local_rank=0,
            )

    def test_get_set_of_int_from_ranges_str(self) -> None:
        self.assertEqual(
            _get_set_of_int_from_ranges_str("0-2,4,6-7"), {0, 1, 2, 4, 6, 7}
        )

    def test_get_range_str_from_ints(self) -> None:
        self.assertEqual(_get_ranges_str_from_ints([7, 0, 1, 6, 2, 4]), "0-2,4,6-7")
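
    # Linux exposes one directory per thread under /proc/<pid>/task/<tid>;
    # mocking entries under /proc/self/task simulates extra threads in this
    # process.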
    def test_bind_all_threads_in_current_process_to_logical_cpus(self) -> None:
        self._add_mock_hardware(
            num_sockets=1,
            num_numa_nodes_per_socket=1,
            num_gpus_per_numa_node=1,
            num_l3_caches_per_numa_node=1,
            num_physical_core_per_l3_cache=1,
        )

        self._mock_file_contents(file_path="/proc/self/task/8675309", contents="")
        self._mock_file_contents(file_path="/proc/self/task/1111111", contents="")
        self._mock_file_contents(
            # The exception from casting not_an_integer to int should get silenced.
            file_path="/proc/self/task/not_an_integer",
            contents="",
        )

        # Mock the original affinity for all threads
        original_main_thread_affinity = set(range(8))
        manually_reaffinitized_thread_affinity = {0}

        def mock_sched_getaffinity_impl(tid: int) -> set[int]:
            if tid == 1111111:
                return manually_reaffinitized_thread_affinity
            return original_main_thread_affinity

        call_order = []

        def mock_sched_setaffinity_impl(tid: int, cpus: set[int]) -> None:
            call_order.append((tid, cpus))

        with (
            patch("os.sched_getaffinity", side_effect=mock_sched_getaffinity_impl),
            patch(
                "os.sched_setaffinity", side_effect=mock_sched_setaffinity_impl
            ) as mock_sched_setaffinity,
        ):
            _bind_all_threads_in_current_process_to_logical_cpus(
                logical_cpu_indices={2, 0, 9}  # arbitrary
            )

        # Should set affinity for main thread (0) and thread 8675309 (same affinity as main)
        # Should NOT set affinity for thread 1111111 (manually reaffinitized)
        self.assertEqual(mock_sched_setaffinity.call_count, 2)
        mock_sched_setaffinity.assert_any_call(0, {2, 0, 9})
        mock_sched_setaffinity.assert_any_call(8675309, {2, 0, 9})

        # Verify that thread 1111111 was NOT reaffinitized
        for tid, _ in call_order:
            self.assertNotEqual(
                tid, 1111111, "Thread 1111111 should not have been reaffinitized"
            )


if __name__ == "__main__":
    run_tests()