pytorch/test/distributed/launcher/api_test.py
Tristan Rice 952a00eda7 torchelastic: change monitor_interval default to 0.1 (#124692)
This reduces the default monitor_interval for torchelastic to 0.1s as testing shows negligible load for common use cases. Even at the extremes, 100k processes is only 45.4% cpu util of a single core.

Torchelastic monitor_interval only monitors the processes on a single worker so under typical loads even for huge jobs we expect ~8 subprocesses per machine with one per GPU.

As an external datapoint, Python's wait polls every 50usec-50ms (https://github.com/python/cpython/blob/main/Lib/subprocess.py#L2035).

## Motivation

This setting is used to control how frequently we poll for failed processes in elastic.

* For some jobs of note we run elastic 3 times per try so with the default timeout of 5 seconds we should save ~15 seconds per retry.
* @kiukchung's use case: Apparently this is annoying in notebooks etc since it adds delay to shutdown when testing things

## Results

This is measured in cores (100% is a single core under full load).

| monitor_interval (s) | nproc-per-node | CPU util (highest observed) |
| -------------------- | -------------- | --------------------------- |
| 1.0                  | 10             | 0.2%                        |
| 0.1                  | 1              | 0.4%                        |
| 0.1                  | 10             | 0.4%                        |
| 0.01                 | 10             | 0.9%                        |
| 0.001                | 10             | 4.0%                        |
| 0.1                  | 100            | 0.5%                        |
| 0.1                  | 1000           | 2.2%                        |
| 0.1                  | 10000          | 15.7%                       |
| 0.1                  | 100000         | 45.4%                       |

## Methodology

```sh
# run command
$ LOGLEVEL=INFO torchrun --nnodes 1 --nproc-per-node 10 --monitor-interval 0.1 ~/wait.py

# wait a few seconds for all processes to start and reach steady state and then run, wait ~30s or 3 prints and take the highest
$ top -b -d 10 -c | rg 'torchrun.*wait'
```

wait.py

```py
import time

time.sleep(10*60)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124692
Approved by: https://github.com/kiukchung, https://github.com/kurman
2024-04-24 01:44:41 +00:00

414 lines
13 KiB
Python

#!/usr/bin/env python3
# Owner(s): ["oncall: r2p"]
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import multiprocessing as mp
import os
import shutil
import signal
import sys
import tempfile
import time
import unittest
import uuid
from contextlib import closing
from typing import Any, Dict, Optional
from unittest import mock
from unittest.mock import MagicMock, Mock, patch
import torch
import torch.distributed as dist
from torch.distributed.elastic.agent.server.api import RunResult, WorkerState
from torch.distributed.elastic.multiprocessing.api import SignalException
from torch.distributed.elastic.multiprocessing.errors import ChildFailedError
from torch.distributed.elastic.rendezvous.etcd_server import EtcdServer
from torch.distributed.elastic.utils import get_socket_with_port
from torch.distributed.launcher.api import (
_get_entrypoint_name,
elastic_launch,
launch_agent,
LaunchConfig,
)
from torch.testing._internal.common_utils import (
skip_but_pass_in_sandcastle_if,
TEST_WITH_DEV_DBG_ASAN,
)
def path(script):
    """Return the path of ``script`` resolved relative to this test file's directory."""
    base_dir = os.path.dirname(__file__)
    return os.path.join(base_dir, script)
def simple_rank_scale():
    """Worker entrypoint: return 10 plus this worker's global rank (RANK env var)."""
    return 10 + int(os.environ["RANK"])
def function_with_bug():
    """Deliberately failing worker entrypoint, used to exercise failure handling."""
    message = "test error"
    raise RuntimeError(message)
def get_test_launch_config(
    rdzv_endpoint: str,
    min_nodes: int,
    max_nodes: int,
    nproc_per_node: int,
    run_id: str = "",
    rdzv_backend: str = "etcd",
    config: Optional[Dict[str, Any]] = None,
) -> LaunchConfig:
    """Build a ``LaunchConfig`` tuned for fast unit tests.

    Uses a short monitor interval, the ``spawn`` start method, and zero
    restarts so worker failures surface immediately.  Extra rendezvous
    settings may be supplied via ``config``.
    """
    extra_rdzv_configs = dict(config) if config else {}
    return LaunchConfig(
        min_nodes=min_nodes,
        max_nodes=max_nodes,
        nproc_per_node=nproc_per_node,
        run_id=run_id,
        rdzv_endpoint=rdzv_endpoint,
        monitor_interval=0.1,
        rdzv_backend=rdzv_backend,
        start_method="spawn",
        max_restarts=0,
        rdzv_configs=extra_rdzv_configs,
    )
def elastic_launch_wrapper(
    test_dir: str,
    rdzv_endpoint: str,
    min_nodes: int,
    max_nodes: int,
    nproc_per_node: int,
    run_id: str,
):
    """A wrapper function for class `elastic_launch.` in order to make multiprocess returns correct exit code."""
    launch_config = get_test_launch_config(
        rdzv_endpoint, min_nodes, max_nodes, nproc_per_node, run_id
    )
    launcher = elastic_launch(launch_config, sys.executable)
    launcher("-u", path("bin/test_script.py"), f"--touch-file-dir={test_dir}")
def _dist_sum(wait=0):
    """All-reduce this worker's rank with SUM over gloo and return the result.

    Expects the launcher-provided env vars (RANK, MASTER_ADDR, ...) to be set
    so ``init_process_group`` can rendezvous.  ``wait`` seconds of sleep are
    inserted before the all-reduce to simulate unevenly paced workers.
    """
    rank = int(os.environ["RANK"])
    dist.init_process_group(backend="gloo")
    t = torch.tensor(rank)
    time.sleep(wait)
    # Fix: ``dist.reduce_op`` is a long-deprecated alias that newer PyTorch
    # releases warn on / remove; use the supported ``dist.ReduceOp`` enum.
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    return t.item()
# Dotted patch targets used by the mock-based tests below.
ELASTIC_AGENT_RUN = "torch.distributed.launcher.api.LocalElasticAgent.run"
EVENTS_RECORD = "torch.distributed.launcher.api.events.record"
GET_RDZV_HANDLER = (
    "torch.distributed.elastic.rendezvous.registry.get_rendezvous_handler"
)
class MockException(Exception):
    """Sentinel exception raised from mocked components in the tests below."""
def short_hash():
    """Return a short random id: the first group (8 hex chars) of a UUID4."""
    return uuid.uuid4().hex[:8]
class ElasticLaunchTest(unittest.TestCase):
    """Tests for ``torch.distributed.launcher.api`` (elastic_launch / launch_agent).

    Real-launch tests rendezvous through a class-scoped standalone etcd
    server; the remaining tests mock out the agent and/or rendezvous handler
    to verify error propagation and shutdown behavior.
    """

    @classmethod
    def setUpClass(cls):
        # start a standalone, single process etcd server to use for all tests.
        cls._etcd_server = EtcdServer()
        cls._etcd_server.start()
        cls._etcd_endpoint = cls._etcd_server.get_endpoint()

    @classmethod
    def tearDownClass(cls):
        # stop the standalone etcd server.
        cls._etcd_server.stop()

    def setUp(self):
        # fresh scratch dir per test; workers touch files named by rank here.
        self.test_dir = tempfile.mkdtemp()
        # remove any lingering environment variables.
        # NOTE(review): deleting from os.environ while iterating its keys()
        # view can raise RuntimeError when a PET_* variable is actually
        # present; iterating over list(os.environ.keys()) would be safer --
        # confirm whether any test sets PET_* vars.
        for env in os.environ.keys():
            if env.startswith("PET_"):
                del os.environ[env]
        # set a sentinel env var on the parent proc.
        # this should be present on the child and gets
        # asserted in ``bin/test_script.py``.
        os.environ["TEST_SENTINEL_PARENT"] = "FOOBAR"
        os.environ["OMP_NUM_THREADS"] = str(1)

    def tearDown(self):
        shutil.rmtree(self.test_dir)

    def check_works_ran(self, world_size: int):
        # Each worker touches a file named after its global rank, so the
        # scratch dir listing must equal {"0", ..., str(world_size - 1)}.
        self.assertSetEqual(
            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
        )

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_launch_script_python(self):
        """Launch a python script entrypoint on one node; verify all workers ran."""
        nnodes = 1
        nproc_per_node = 4
        elastic_launch(
            get_test_launch_config(self._etcd_endpoint, nnodes, nnodes, nproc_per_node),
            sys.executable,
        )("-u", path("bin/test_script.py"), f"--touch-file-dir={self.test_dir}")
        # make sure all the workers ran.
        # each worker touches a file with its global rank as the name.
        world_size = nnodes * nproc_per_node
        self.check_works_ran(world_size)

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_launch_script_python_local_rank_transfer(self):
        """Same launch path as above; ranks are transferred to the child procs."""
        nnodes = 1
        nproc_per_node = 4
        elastic_launch(
            get_test_launch_config(self._etcd_endpoint, nnodes, nnodes, nproc_per_node),
            sys.executable,
        )("-u", path("bin/test_script.py"), f"--touch-file-dir={self.test_dir}")
        # make sure all the workers ran.
        # each worker touches a file with its global rank as the name.
        world_size = nnodes * nproc_per_node
        self.check_works_ran(world_size)

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_launch_script_bash(self):
        """Launch a bash script entrypoint; verify all workers ran."""
        nnodes = 1
        nproc_per_node = 4
        elastic_launch(
            get_test_launch_config(self._etcd_endpoint, nnodes, nnodes, nproc_per_node),
            path("bin/test_script.sh"),
        )(f"{self.test_dir}")
        world_size = nnodes * nproc_per_node
        self.check_works_ran(world_size)

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_launch_function(self):
        """Launch a python function entrypoint; each rank returns 10 + rank."""
        nnodes = 1
        nproc_per_node = 4
        res = elastic_launch(
            get_test_launch_config(self._etcd_endpoint, nnodes, nnodes, nproc_per_node),
            simple_rank_scale,
        )()
        expected_res = [10, 11, 12, 13]
        actual_res = sorted(value for value in res.values())
        self.assertEqual(expected_res, actual_res)

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_launch_dist_sum_with_static_rdzv(self):
        """Launch with the static rendezvous backend and run a gloo all-reduce."""
        nnodes = 1
        nproc_per_node = 4
        # grab a free port for the static master endpoint.
        sock = get_socket_with_port()
        with closing(sock):
            master_port = sock.getsockname()[1]
        rdzv_endpoint = f"127.0.0.1:{master_port}"
        rank = 0
        rdzv_config = {
            "rank": rank,
        }
        res = elastic_launch(
            get_test_launch_config(
                rdzv_endpoint,
                nnodes,
                nnodes,
                nproc_per_node,
                rdzv_backend="static",
                config=rdzv_config,
            ),
            _dist_sum,
        )()
        # every rank's all-reduce result is sum(0..nproc_per_node-1).
        expected_res = [sum(range(nproc_per_node))] * nproc_per_node
        actual_res = sorted(value for value in res.values())
        self.assertEqual(expected_res, actual_res)

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_launch_elastic(self):
        """Elastic launch with min_nodes=1, max_nodes=2 on a single node."""
        nproc_per_node = 4
        elastic_launch(
            get_test_launch_config(self._etcd_endpoint, 1, 2, nproc_per_node),
            sys.executable,
        )("-u", path("bin/test_script.py"), f"--touch-file-dir={self.test_dir}")
        world_size = nproc_per_node
        self.check_works_ran(world_size)

    @mock.patch("torch.distributed.elastic.events.record")
    def test_launch_elastic_worker_raise_exception(self, record_mock):
        """
        Asserts that when the worker program fails the launcher raises an
        exception to indicate that the worker process failed.
        """
        nproc_per_node = 4
        with self.assertRaises(ChildFailedError):
            elastic_launch(
                get_test_launch_config(self._etcd_endpoint, 1, 2, nproc_per_node),
                sys.executable,
            )("-u", path("bin/test_script.py"), "--fail")
        record_mock.assert_called_once()

    # NOTE(review): @mock.patch decorators are applied bottom-up, so the
    # *first* mock argument corresponds to the *bottom* decorator
    # (LocalElasticAgent.run) and the second to events.record -- the
    # parameter names here appear swapped relative to what they receive.
    # The test still passes (the exception propagates and the agent's run
    # is called once), but the naming is misleading -- confirm.
    @mock.patch("torch.distributed.elastic.events.record")
    @mock.patch(
        "torch.distributed.elastic.agent.server.local_elastic_agent.LocalElasticAgent.run"
    )
    def test_launch_elastic_agent_raise_exception(self, record_mock, mock_agent_run):
        """
        Asserts that when the agent raises an exception
        the launcher re-raises the original exception.
        """
        mock_agent_run.side_effect = MockException
        with self.assertRaises(MockException):
            elastic_launch(
                get_test_launch_config(self._etcd_endpoint, 1, 2, 4),
                sys.executable,
            )("-u", path("bin/test_script.py"), f"--touch-file-dir={self.test_dir}")
        record_mock.assert_called_once()

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_launch_elastic_multiple_agents(self):
        """Run two agents (one in a spawned process) that rendezvous together."""
        min_nodes = 1
        max_nodes = 2
        nproc_per_node = 4
        nnodes = 2
        run_id = str(uuid.uuid4().int)
        procs = []
        ctx = mp.get_context("spawn")
        # spawn nnodes-1 agents in child processes; the last runs in-process.
        for _ in range(nnodes - 1):
            p = ctx.Process(
                target=elastic_launch_wrapper,
                args=(
                    self.test_dir,
                    self._etcd_endpoint,
                    min_nodes,
                    max_nodes,
                    nproc_per_node,
                    run_id,
                ),
            )
            procs.append(p)
            p.start()
        elastic_launch_wrapper(
            self.test_dir,
            self._etcd_endpoint,
            min_nodes,
            max_nodes,
            nproc_per_node,
            run_id,
        )
        for i in range(nnodes - 1):
            p = procs[i]
            p.join()
            self.assertEqual(0, p.exitcode)
        # make sure all the workers ran
        # each worker touches a file with its global rank as the name
        world_size = nnodes * nproc_per_node
        self.assertSetEqual(
            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
        )

    @patch("torch.distributed.launcher.api.LocalElasticAgent")
    def test_launch_shutdown(self, agent_mock_cls):
        """On a successful run the rendezvous handler must be shut down once."""
        agent_mock = Mock()
        agent_mock.run.return_value = RunResult(WorkerState.SUCCEEDED)
        agent_mock_cls.return_value = agent_mock
        rdzv_handler_mock = Mock()
        with patch(
            "torch.distributed.elastic.rendezvous.registry.get_rendezvous_handler"
        ) as param_mock:
            param_mock.return_value = rdzv_handler_mock
            elastic_launch(
                get_test_launch_config(self._etcd_endpoint, 1, 1, 4),
                sys.executable,
            )("-u", path("bin/test_script.py"), f"--touch-file-dir={self.test_dir}")
            rdzv_handler_mock.shutdown.assert_called_once()

    def test_get_entrypoint_name(self):
        """_get_entrypoint_name: function name, script basename, or empty string."""
        self.assertEqual(
            "simple_rank_scale", _get_entrypoint_name(simple_rank_scale, [])
        )
        self.assertEqual("", _get_entrypoint_name(sys.executable, []))
        self.assertEqual("", _get_entrypoint_name(sys.executable, ["-u"]))
        self.assertEqual(
            "test_script.py",
            _get_entrypoint_name(sys.executable, ["-u", "test_script.py"]),
        )
        self.assertEqual("", _get_entrypoint_name(None, []))

    @patch(ELASTIC_AGENT_RUN)
    @patch(GET_RDZV_HANDLER)
    def test_rdzv_handler_shutdown_on_agent_signal(self, mock_get_rdzv, mock_agent_run):
        """On SIGTERM/SIGINT the rendezvous handler must NOT be shut down
        (so surviving agents can re-rendezvous), but the event is recorded."""
        config = get_test_launch_config(
            self._etcd_endpoint, min_nodes=1, max_nodes=1, nproc_per_node=1
        )
        for sigval in [signal.SIGTERM, signal.SIGINT]:
            with patch(EVENTS_RECORD) as record_event_mock:
                rdzv_handler_mock = MagicMock()
                rdzv_handler_mock.get_run_id.return_value = short_hash()
                mock_get_rdzv.return_value = rdzv_handler_mock
                mock_agent_run.side_effect = SignalException("test", sigval)
                with self.assertRaises(SignalException):
                    launch_agent(config, simple_rank_scale, [])
                rdzv_handler_mock.shutdown.assert_not_called()
                record_event_mock.assert_called_once()

    @patch(ELASTIC_AGENT_RUN)
    @patch(GET_RDZV_HANDLER)
    def test_rdzv_handler_shutdown_on_agent_error(self, mock_get_rdzv, mock_agent_run):
        """On any non-signal agent error the rendezvous handler IS shut down
        and the event is recorded."""
        config = get_test_launch_config(
            self._etcd_endpoint, min_nodes=1, max_nodes=1, nproc_per_node=1
        )
        with patch(EVENTS_RECORD) as record_event_mock:
            rdzv_handler_mock = MagicMock()
            rdzv_handler_mock.get_run_id.return_value = short_hash()
            mock_get_rdzv.return_value = rdzv_handler_mock
            mock_agent_run.side_effect = RuntimeError("any other exception")
            with self.assertRaises(RuntimeError):
                launch_agent(config, simple_rank_scale, [])
            rdzv_handler_mock.shutdown.assert_called_once()
            record_event_mock.assert_called_once()