diff --git a/test/distributed/elastic/agent/server/test/local_elastic_agent_test.py b/test/distributed/elastic/agent/server/test/local_elastic_agent_test.py index 62d50456e6c..adfe863b73f 100644 --- a/test/distributed/elastic/agent/server/test/local_elastic_agent_test.py +++ b/test/distributed/elastic/agent/server/test/local_elastic_agent_test.py @@ -308,13 +308,18 @@ class LocalElasticAgentTest(unittest.TestCase): ) def get_agent( - self, spec: WorkerSpec, start_method: str = "spawn", exit_barrier_timeout=5 + self, + spec: WorkerSpec, + start_method: str = "spawn", + exit_barrier_timeout=5, + log_line_prefix_template: Optional[str] = None, ) -> LocalElasticAgent: return LocalElasticAgent( spec, start_method=start_method, exit_barrier_timeout=exit_barrier_timeout, log_dir=self.log_dir(), + log_line_prefix_template=log_line_prefix_template, ) # pyre-fixme[56]: Pyre was not able to infer the type of the decorator @@ -333,6 +338,7 @@ class LocalElasticAgentTest(unittest.TestCase): master_port_override: Optional[int] = None, is_host=True, monitor_interval=0.01, + log_line_prefix_template: Optional[str] = None, ) -> Optional[RunResult]: """ Runs a single agent. This method can be called either on a separate process @@ -356,6 +362,7 @@ class LocalElasticAgentTest(unittest.TestCase): spec=spec, start_method=start_method, exit_barrier_timeout=exit_barrier_timeout, + log_line_prefix_template=log_line_prefix_template, ) result = agent.run() @@ -371,7 +378,10 @@ class LocalElasticAgentTest(unittest.TestCase): return result def run_job( - self, node_configs: List[Conf], exit_barrier_timeout: int = 5 + self, + node_configs: List[Conf], + exit_barrier_timeout: int = 5, + log_line_prefix_template: Optional[str] = None, ) -> Dict[str, List[RunResult]]: """ Simulates running a distributed job by running multiple agents @@ -398,6 +408,8 @@ class LocalElasticAgentTest(unittest.TestCase): "max_restarts": 0, "exit_barrier_timeout": exit_barrier_timeout, "is_host": node_idx == 0, + "log_line_prefix_template": log_line_prefix_template + } p = mp.Process(target=self.run_agent, kwargs=run_agent_args) procs.append(p) @@ -633,16 +645,16 @@ class LocalElasticAgentTest(unittest.TestCase): def test_simple_dist_sum_etcd_v2(self): self.run_test_with_backend(backend="etcd-v2", test_to_run=self.simple_dist_sum) - def run_distributed_sum_homogeneous(self): + def run_distributed_sum_homogeneous(self, log_line_prefix_template: Optional[str] = None): node_configs = [ - Conf(role="sum", entrypoint=_dist_sum, local_world_size=4), - Conf(role="sum", entrypoint=_dist_sum, local_world_size=4), + Conf(role="sum", entrypoint=_dist_sum, local_world_size=4, tee=Std.ALL), + Conf(role="sum", entrypoint=_dist_sum, local_world_size=4, tee=Std.ALL), ] # When the process method is spawn, the coverage collector hangs # due to getting stuck on the _dist_sum in waiting for TCPStore workers # to join the cluster # TODO(aivanou): t83447589 come up with the proper fix - res = self.run_job(node_configs) + res = self.run_job(node_configs, log_line_prefix_template=log_line_prefix_template) self.assertEqual(2, len(res["sum"])) ranks = set() for run_results in res["sum"]: @@ -659,6 +671,14 @@ class LocalElasticAgentTest(unittest.TestCase): backend="c10d", test_to_run=self.run_distributed_sum_homogeneous ) + def test_run_with_custom_log_lines(self): + log_line_prefix_template = "[${role_name}-${local_rank}:${rank}]:" + self.run_test_with_backend( + backend="c10d", + test_to_run=lambda: self.run_distributed_sum_homogeneous(log_line_prefix_template) + ) + + @unittest.skipIf( TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, "test incompatible with dev/dbg asan or tsan", diff --git a/test/distributed/elastic/multiprocessing/api_test.py b/test/distributed/elastic/multiprocessing/api_test.py index 6240fbe41d6..7ca37cefdf2 100644 --- a/test/distributed/elastic/multiprocessing/api_test.py +++ b/test/distributed/elastic/multiprocessing/api_test.py @@ -609,6 +609,7 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS): args={0: ("hello",), 1: ("hello",)}, envs={0: {"RANK": "0"}, 1: {"RANK": "1"}}, log_dir=self.log_dir(), + log_line_prefixes={0: "[rank0]:", 1: "[rank1]:"}, redirects=redirs, ) @@ -641,6 +642,7 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS): args={0: ("hello",), 1: ("world",)}, envs={0: {"RANK": "0"}, 1: {"RANK": "1"}}, log_dir=self.log_dir(), + log_line_prefixes={0: "[rank0]:", 1: "[rank1]:"}, start_method="spawn", redirects={0: Std.ERR, 1: Std.NONE}, tee={0: Std.OUT, 1: Std.ERR}, diff --git a/test/distributed/elastic/multiprocessing/tail_log_test.py b/test/distributed/elastic/multiprocessing/tail_log_test.py index 3cd1909e2e9..e675f5f7eb8 100644 --- a/test/distributed/elastic/multiprocessing/tail_log_test.py +++ b/test/distributed/elastic/multiprocessing/tail_log_test.py @@ -53,7 +53,7 @@ class TailLogTest(unittest.TestCase): } dst = io.StringIO() - tail = TailLog("writer", log_files, dst, interval_sec).start() + tail = TailLog(name="writer", log_files=log_files, dst=dst, interval_sec=interval_sec).start() # sleep here is intentional to ensure that the log tail # can gracefully handle and wait for non-existent log files time.sleep(interval_sec * 10) @@ -83,6 +83,54 @@ class TailLogTest(unittest.TestCase): ) self.assertTrue(tail.stopped()) + def test_tail_with_custom_prefix(self): + """ + writer() writes 0 - max (on number on each line) to a log file. + Run nprocs such writers and tail the log files into an IOString + and validate that all lines are accounted for. + """ + nprocs = 3 + max = 10 + interval_sec = 0.0001 + + log_files = { + local_rank: os.path.join(self.test_dir, f"{local_rank}_stdout.log") + for local_rank in range(nprocs) + } + + dst = io.StringIO() + log_line_prefixes = {n: f"[worker{n}][{n}]:" for n in range(nprocs)} + tail = TailLog( + "writer", + log_files, + dst, + interval_sec=interval_sec, + log_line_prefixes=log_line_prefixes, + ).start() + # sleep here is intentional to ensure that the log tail + # can gracefully handle and wait for non-existent log files + time.sleep(interval_sec * 10) + futs = [] + for local_rank, file in log_files.items(): + f = self.threadpool.submit( + write, max=max, sleep=interval_sec * local_rank, file=file + ) + futs.append(f) + wait(futs, return_when=ALL_COMPLETED) + self.assertFalse(tail.stopped()) + tail.stop() + dst.seek(0) + + headers: Set[str] = set() + for line in dst.readlines(): + header, _ = line.split(":") + headers.add(header) + self.assertEqual(nprocs, len(headers)) + for i in range(nprocs): + self.assertIn(f"[worker{i}][{i}]", headers) + self.assertTrue(tail.stopped()) + + def test_tail_no_files(self): """ Ensures that the log tail can gracefully handle no log files diff --git a/torch/distributed/elastic/agent/server/local_elastic_agent.py b/torch/distributed/elastic/agent/server/local_elastic_agent.py index 2119745cb89..31b052b042a 100644 --- a/torch/distributed/elastic/agent/server/local_elastic_agent.py +++ b/torch/distributed/elastic/agent/server/local_elastic_agent.py @@ -12,6 +12,7 @@ import os import shutil import signal import socket +from string import Template import tempfile import uuid from typing import Any, Dict, Optional, Tuple @@ -79,6 +80,16 @@ class LocalElasticAgent(SimpleElasticAgent): be propagated to the worker processes to allow them to connect to the same named pipe that ```LocalElasticAgent``` uses. + Logs are written to the specified log directory. Each log line will be by default + prefixed by ``[${role_name}${local_rank}]:`` (e.g. ``[trainer0]: foobar``). + Log prefixes can be customized by passing a `template string + `_ as the + ``log_line_prefix_template`` argument. + The following macros (identifiers) are substituted at runtime: + ``${role_name}, ${local_rank}, ${rank}``. For example, to prefix each log line with + global rank instead of the local rank, set ``log_line_prefix_template = "[${rank}]:``. + + Example launching function :: @@ -129,12 +140,14 @@ class LocalElasticAgent(SimpleElasticAgent): start_method="spawn", exit_barrier_timeout: float = 300, log_dir: Optional[str] = None, + log_line_prefix_template: Optional[str] = None, ): super().__init__(spec, exit_barrier_timeout) self._start_method = start_method self._pcontext: Optional[PContext] = None rdzv_run_id = spec.rdzv_handler.get_run_id() self._log_dir = self._make_log_dir(log_dir, rdzv_run_id) + self._log_line_prefix_template = log_line_prefix_template self._worker_watchdog: Optional[timer.FileTimerServer] = None def _make_log_dir(self, log_dir: Optional[str], rdzv_run_id: str): @@ -229,6 +242,7 @@ class LocalElasticAgent(SimpleElasticAgent): args: Dict[int, Tuple] = {} envs: Dict[int, Dict[str, str]] = {} + log_line_prefixes: Optional[Dict[int, str]] = {} if self._log_line_prefix_template else None for worker in worker_group.workers: local_rank = worker.local_rank worker_env = { @@ -254,6 +268,14 @@ class LocalElasticAgent(SimpleElasticAgent): if "OMP_NUM_THREADS" in os.environ: worker_env["OMP_NUM_THREADS"] = os.environ["OMP_NUM_THREADS"] + + if self._log_line_prefix_template: + log_line_prefix = Template(self._log_line_prefix_template).safe_substitute( + role_name=spec.role, + rank=worker.global_rank, + local_rank=local_rank,) + log_line_prefixes[local_rank] = log_line_prefix + envs[local_rank] = worker_env worker_args = list(spec.args) worker_args = macros.substitute(worker_args, str(local_rank)) @@ -274,6 +296,7 @@ class LocalElasticAgent(SimpleElasticAgent): args=args, envs=envs, log_dir=attempt_log_dir, + log_line_prefixes=log_line_prefixes, start_method=self._start_method, redirects=spec.redirects, tee=spec.tee, diff --git a/torch/distributed/elastic/multiprocessing/__init__.py b/torch/distributed/elastic/multiprocessing/__init__.py index c5738cff749..7e8dfcbd8de 100644 --- a/torch/distributed/elastic/multiprocessing/__init__.py +++ b/torch/distributed/elastic/multiprocessing/__init__.py @@ -64,21 +64,33 @@ implementations of the parent :class:`api.PContext` class. """ import os -from typing import Callable, Dict, Tuple, Union +from typing import Callable, Dict, Optional, Tuple, Union from torch.distributed.elastic.multiprocessing.api import ( # noqa: F401 + _validate_full_rank, MultiprocessContext, PContext, ProcessFailure, RunProcsResult, - Std, SignalException, + Std, SubprocessContext, - _validate_full_rank, to_map, ) from torch.distributed.elastic.utils.logging import get_logger +__all__ = [ + "start_processes", + "MultiprocessContext", + "PContext", + "ProcessFailure", + "RunProcsResult", + "SignalException", + "Std", + "SubprocessContext", + "to_map", +] + log = get_logger(__name__) @@ -88,6 +100,7 @@ def start_processes( args: Dict[int, Tuple], envs: Dict[int, Dict[str, str]], log_dir: str, + log_line_prefixes: Optional[Dict[int, str]] = None, start_method: str = "spawn", redirects: Union[Std, Dict[int, Std]] = Std.NONE, tee: Union[Std, Dict[int, Std]] = Std.NONE, @@ -257,6 +270,7 @@ def start_processes( tee_stdouts=tee_stdouts, tee_stderrs=tee_stderrs, error_files=error_files, + log_line_prefixes=log_line_prefixes, ) else: context = MultiprocessContext( @@ -269,6 +283,7 @@ def start_processes( tee_stdouts=tee_stdouts, tee_stderrs=tee_stderrs, error_files=error_files, + log_line_prefixes=log_line_prefixes, start_method=start_method, ) diff --git a/torch/distributed/elastic/multiprocessing/api.py b/torch/distributed/elastic/multiprocessing/api.py index 4e29c0d2b8f..edc32e51b97 100644 --- a/torch/distributed/elastic/multiprocessing/api.py +++ b/torch/distributed/elastic/multiprocessing/api.py @@ -208,6 +208,7 @@ class PContext(abc.ABC): tee_stdouts: Dict[int, str], tee_stderrs: Dict[int, str], error_files: Dict[int, str], + log_line_prefixes: Optional[Dict[int, str]] = None, ): self.name = name # validate that all mappings have the same number of keys and @@ -224,8 +225,8 @@ class PContext(abc.ABC): self.error_files = error_files self.nprocs = nprocs - self._stdout_tail = TailLog(name, tee_stdouts, sys.stdout) - self._stderr_tail = TailLog(name, tee_stderrs, sys.stderr) + self._stdout_tail = TailLog(name, tee_stdouts, sys.stdout, log_line_prefixes) + self._stderr_tail = TailLog(name, tee_stderrs, sys.stderr, log_line_prefixes) def start(self) -> None: """ @@ -389,6 +390,7 @@ class MultiprocessContext(PContext): tee_stderrs: Dict[int, str], error_files: Dict[int, str], start_method: str, + log_line_prefixes: Optional[Dict[int, str]] = None, ): super().__init__( name, @@ -400,6 +402,7 @@ class MultiprocessContext(PContext): tee_stdouts, tee_stderrs, error_files, + log_line_prefixes, ) self.start_method = start_method @@ -611,6 +614,7 @@ class SubprocessContext(PContext): tee_stdouts: Dict[int, str], tee_stderrs: Dict[int, str], error_files: Dict[int, str], + log_line_prefixes: Optional[Dict[int, str]] = None, ): super().__init__( name, @@ -622,6 +626,7 @@ class SubprocessContext(PContext): tee_stdouts, tee_stderrs, error_files, + log_line_prefixes, ) # state vector; _vdone[local_rank] -> is local_rank finished or not diff --git a/torch/distributed/elastic/multiprocessing/tail_log.py b/torch/distributed/elastic/multiprocessing/tail_log.py index 8feccd95515..03c094945f0 100644 --- a/torch/distributed/elastic/multiprocessing/tail_log.py +++ b/torch/distributed/elastic/multiprocessing/tail_log.py @@ -12,7 +12,7 @@ import time from concurrent.futures._base import Future from concurrent.futures.thread import ThreadPoolExecutor from threading import Event -from typing import Dict, List, TextIO +from typing import Dict, List, Optional, TextIO __all__ = ["tail_logfile", "TailLog"] @@ -55,7 +55,8 @@ class TailLog: Each log file's line will be suffixed with a header of the form: ``[{name}{idx}]:``, where the ``name`` is user-provided and ``idx`` is the index of the log file - in the ``log_files`` mapping. + in the ``log_files`` mapping. ``log_line_prefixes`` can be used to override the + header for each log file. Usage: @@ -86,6 +87,7 @@ class TailLog: name: str, log_files: Dict[int, str], dst: TextIO, + log_line_prefixes: Optional[Dict[int, str]] = None, interval_sec: float = 0.1, ): n = len(log_files) @@ -99,6 +101,7 @@ class TailLog: self._name = name self._dst = dst self._log_files = log_files + self._log_line_prefixes = log_line_prefixes self._finished_events: Dict[int, Event] = { local_rank: Event() for local_rank in log_files.keys() } @@ -111,10 +114,13 @@ class TailLog: return self for local_rank, file in self._log_files.items(): + header = f"[{self._name}{local_rank}]:" + if self._log_line_prefixes and local_rank in self._log_line_prefixes: + header = self._log_line_prefixes[local_rank] self._futs.append( self._threadpool.submit( tail_logfile, - header=f"[{self._name}{local_rank}]:", + header=header, file=file, dst=self._dst, finished=self._finished_events[local_rank], diff --git a/torch/distributed/launcher/api.py b/torch/distributed/launcher/api.py index 8a90ba0baef..f6cc5702e74 100644 --- a/torch/distributed/launcher/api.py +++ b/torch/distributed/launcher/api.py @@ -82,6 +82,7 @@ class LaunchConfig: monitor_interval: float = 30 start_method: str = "spawn" log_dir: Optional[str] = None + log_line_prefix_template: Optional[str] = None redirects: Union[Std, Dict[int, Std]] = Std.NONE tee: Union[Std, Dict[int, Std]] = Std.NONE metrics_cfg: Dict[str, str] = field(default_factory=dict) @@ -245,7 +246,10 @@ def launch_agent( ) agent = LocalElasticAgent( - spec=spec, start_method=config.start_method, log_dir=config.log_dir + spec=spec, + start_method=config.start_method, + log_dir=config.log_dir, + log_line_prefix_template=config.log_line_prefix_template, ) shutdown_rdzv = True diff --git a/torch/distributed/run.py b/torch/distributed/run.py index 9a2beeac744..8ba94a249ae 100644 --- a/torch/distributed/run.py +++ b/torch/distributed/run.py @@ -713,6 +713,8 @@ def config_from_args(args) -> Tuple[LaunchConfig, Union[Callable, str], List[str # This env variable will be passed down to the subprocesses os.environ["OMP_NUM_THREADS"] = str(omp_num_threads) + log_line_prefix_template = os.getenv("TORCHELASTIC_LOG_LINE_PREFIX_TEMPLATE") + rdzv_configs = _parse_rendezvous_config(args.rdzv_conf) if args.rdzv_backend == "static": @@ -735,6 +737,7 @@ def config_from_args(args) -> Tuple[LaunchConfig, Union[Callable, str], List[str redirects=Std.from_str(args.redirects), tee=Std.from_str(args.tee), log_dir=args.log_dir, + log_line_prefix_template=log_line_prefix_template, local_addr=args.local_addr, )