[pytorch][torchelastic] Duplicate stdout and stderr and apply custom filter in torchrun (#160712)

Summary:
Part of an effort to extract some important error logs (e.g. [#157996](https://github.com/pytorch/pytorch/pull/157996)) that were `tee`'d to `stdout` and `stderr`.

The general idea is to:

- Duplicate the `tee`s on `stdout` and `stderr` to separate files, `filtered_stdout.log` and `filtered_stderr.log`, respectively.
- In these files, as their names suggest, log only the lines that match a customizable filter.
- In a follow-up PR, append the contents of these files to the reply file.

Outline of changes in this PR:

- Enhance `TailLog` so that it can 1) stream to a file, and 2) write only the lines that match the passed filter.
- Add `filtered_stdout` and `filtered_stderr` to `LogsDest` and have `LogsSpecs` `reify` them.
- In `start_processes()` and `PContext`, add `duplicate_stdout_filters` and `duplicate_stderr_filters` params to filter and write the duplicated streams to the files above (see the sketch after this list). When no filters are passed in, no duplicated streams are created.
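
For illustration, a minimal sketch of the resulting API (the entrypoint, log dir, and filter strings below are hypothetical; `DefaultLogsSpecs` and `Std` are the existing torchelastic log specs, and `log_dir` is expected to already exist):

```python
from torch.distributed.elastic.multiprocessing import (
    DefaultLogsSpecs,
    Std,
    start_processes,
)


def trainer() -> None:
    print("step 1 ok")
    print("CUDA error: illustrative failure line")


if __name__ == "__main__":
    pc = start_processes(
        name="trainer",
        entrypoint=trainer,
        args={0: (), 1: ()},
        envs={0: {}, 1: {}},
        logs_specs=DefaultLogsSpecs(log_dir="/tmp/logs", tee=Std.ALL),
        # a line is duplicated iff it contains ANY of these substrings
        duplicate_stdout_filters=["CUDA error", "Traceback"],
        duplicate_stderr_filters=["CUDA error"],
    )
    pc.wait()
    # aggregated across all tee'd ranks:
    print(pc.filtered_stdout)  # <log_dir>/.../filtered_stdout.log
    print(pc.filtered_stderr)  # <log_dir>/.../filtered_stderr.log
```

The same behavior is exposed on the torchrun CLI via the new `--duplicate-stdout-filters` / `--duplicate-stderr-filters` flags added below.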

Test Plan:
```
$ buck test 'fbcode//mode/opt' caffe2/test/distributed/elastic/multiprocessing:api_test
```
```
Buck UI: https://www.internalfb.com/buck2/f5c6b7da-217d-4a0b-872a-c7cd3d05587f
Test UI: https://www.internalfb.com/intern/testinfra/testrun/4222124951617688
Network: Up: 398B  Down: 44MiB  (reSessionID-a489a961-b602-45be-b851-3490ebb7a26a)
Analyzing targets. Remaining     0/200
Executing actions. Remaining     0/12856                                                                                                                                        0.1s exec time total
Command: test.     Finished 1 local
Time elapsed: 17:37.9s
Tests finished: Pass 52. Fail 0. Fatal 0. Skip 0. Build failure 0
```
```
$ buck test 'fbcode//mode/opt' caffe2/test/distributed/elastic/multiprocessing:tail_log_test
```
```
Buck UI: https://www.internalfb.com/buck2/d6d5c1c1-db98-4d9c-b608-7ba6fbb5e3ee
Test UI: https://www.internalfb.com/intern/testinfra/testrun/13510798985149262
Network: Up: 94KiB  Down: 417MiB  (reSessionID-27b46fba-d31c-4c04-8ede-a506454e6922)
Analyzing targets. Remaining     0/3                                                                                                                                            536 actions, 555 artifacts declared
Executing actions. Remaining     0/186                                                                                                                                          1:05.5s exec time total
Command: test.     Finished 7 local, 1 remote, 115 cache (93% hit)                                                                                                              37.0s exec time cached (56%)
Time elapsed: 1:11.5s
Tests finished: Pass 7. Fail 0. Fatal 0. Skip 0. Build failure 0
```

Rollback Plan:

Differential Revision: D80188995

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160712
Approved by: https://github.com/fduwjj
Commit: cbcb4f7768 (parent: 2b93d5b450)
Author: Phil Hu, 2025-10-23 14:22:21 +00:00, committed by PyTorch MergeBot
11 changed files with 401 additions and 79 deletions

View File

@ -127,8 +127,9 @@ def echo1(msg: str, exitcode: int = 0) -> str:
print(f"exit {exitcode} from {rank}", file=sys.stderr)
sys.exit(exitcode)
else:
print(f"{msg} stdout from {rank}")
print(f"{msg} stderr from {rank}", file=sys.stderr)
for m in msg.split(","):
print(f"{m} stdout from {rank}")
print(f"{m} stderr from {rank}", file=sys.stderr)
return f"{msg}_{rank}"
@ -247,6 +248,13 @@ class _StartProcessesTest(TestCase):
for line in expected:
self.assertIn(line, actual)
def assert_not_in_file(self, lines: list[str], filename: str) -> None:
lines = [f"{line.rstrip()}\n" for line in lines]
with open(filename) as fp:
actual = fp.readlines()
for line in lines:
self.assertNotIn(line, actual)
def assert_pids_noexist(self, pids: dict[int, int]):
for local_rank, pid in pids.items():
with self.assertRaises(
@ -360,8 +368,8 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
self.assertIsNone(pc.wait(timeout=0.1, period=0.01))
self.assertIsNotNone(pc.wait(period=0.1))
-self.assertTrue(pc._stderr_tail.stopped())
-self.assertTrue(pc._stdout_tail.stopped())
+for tail_log in pc._tail_logs:
+    self.assertTrue(tail_log.stopped())
def test_pcontext_wait_on_a_child_thread(self):
asyncio.run(asyncio.to_thread(self.test_pcontext_wait))
@ -379,8 +387,8 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
pids = pc.pids()
pc.close()
self.assert_pids_noexist(pids)
-self.assertTrue(pc._stderr_tail.stopped())
-self.assertTrue(pc._stdout_tail.stopped())
+for tail_log in pc._tail_logs:
+    self.assertTrue(tail_log.stopped())
def test_function_with_tensor(self):
for start_method in self._start_methods:
@ -482,8 +490,8 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
int(error_file_data["message"]["extraInfo"]["timestamp"]),
int(failure.timestamp),
)
-self.assertTrue(pc._stderr_tail.stopped())
-self.assertTrue(pc._stdout_tail.stopped())
+for tail_log in pc._tail_logs:
+    self.assertTrue(tail_log.stopped())
def test_wait_for_all_child_procs_to_exit(self):
"""
@ -580,8 +588,8 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
self.assert_in_file([], results.stdouts[0])
self.assertFalse(results.stderrs[1])
self.assertFalse(results.stdouts[1])
-self.assertTrue(pc._stderr_tail.stopped())
-self.assertTrue(pc._stdout_tail.stopped())
+for tail_log in pc._tail_logs:
+    self.assertTrue(tail_log.stopped())
failure = results.failures[1]
self.assertEqual(-15, failure.exitcode)
@ -731,8 +739,37 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
self.assert_in_file(["hello stderr from 0"], pc.stderrs[0])
self.assert_in_file(["world stderr from 1"], pc.stderrs[1])
self.assertFalse(pc.stdouts[1])
-self.assertTrue(pc._stderr_tail.stopped())
-self.assertTrue(pc._stdout_tail.stopped())
+for tail_log in pc._tail_logs:
+    self.assertTrue(tail_log.stopped())
def test_binary_duplicate_log_filters(self):
pc = start_processes(
name="trainer",
entrypoint=bin("echo1.py"),
args={0: ("helloA,helloB",), 1: ("worldA,worldB",)},
envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
logs_specs=DefaultLogsSpecs(
log_dir=self.log_dir(),
redirects={0: Std.ERR, 1: Std.NONE},
tee={0: Std.OUT, 1: Std.ERR},
),
log_line_prefixes={0: "[rank0]:", 1: "[rank1]:"},
duplicate_stdout_filters=["helloA"],
duplicate_stderr_filters=["worldA", "B"],
start_method="spawn",
)
result = pc.wait()
self.assertFalse(result.is_failed())
self.assert_in_file(["[rank0]:helloA stdout from 0"], pc.filtered_stdout)
self.assert_not_in_file(
["[rank0]:helloB stdout from 0"], pc.filtered_stdout
)
self.assert_in_file(["[rank1]:worldA stderr from 1"], pc.filtered_stderr)
self.assert_in_file(["[rank1]:worldB stderr from 1"], pc.filtered_stderr)
for tail_log in pc._tail_logs:
self.assertTrue(tail_log.stopped())
# tests incompatible with tsan or asan, the redirect functionality does not work on macos or windows
@ -794,8 +831,44 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS or IS_CI):
self.assert_in_file(["hello stderr from 0"], pc.stderrs[0])
self.assert_in_file(["world stderr from 1"], pc.stderrs[1])
self.assertFalse(pc.stdouts[1])
-self.assertTrue(pc._stderr_tail.stopped())
-self.assertTrue(pc._stdout_tail.stopped())
+for tail_log in pc._tail_logs:
+    self.assertTrue(tail_log.stopped())
def test_function_duplicate_log_filters(self):
for start_method in self._start_methods:
with self.subTest(start_method=start_method):
pc = start_processes(
name="trainer",
entrypoint=echo1,
args={0: ("helloA,helloB",), 1: ("worldA,worldB",)},
envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
logs_specs=DefaultLogsSpecs(
log_dir=self.log_dir(),
redirects={0: Std.ERR, 1: Std.NONE},
tee={0: Std.OUT, 1: Std.ERR},
),
duplicate_stdout_filters=["helloA"],
duplicate_stderr_filters=["worldA", "B"],
start_method="spawn",
)
result = pc.wait()
self.assertFalse(result.is_failed())
self.assert_in_file(
["[trainer0]:helloA stdout from 0"], pc.filtered_stdout
)
self.assert_not_in_file(
["[trainer0]:helloB stdout from 0"], pc.filtered_stdout
)
self.assert_in_file(
["[trainer1]:worldA stderr from 1"], pc.filtered_stderr
)
self.assert_in_file(
["[trainer1]:worldB stderr from 1"], pc.filtered_stderr
)
for tail_log in pc._tail_logs:
self.assertTrue(tail_log.stopped())
def test_function(self):
for start_method, redirs in product(self._start_methods, redirects_all()):
@ -880,8 +953,8 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS or IS_CI):
self.assertFalse(results.stdouts[0])
self.assertFalse(results.stderrs[1])
self.assertFalse(results.stdouts[1])
-self.assertTrue(pc._stderr_tail.stopped())
-self.assertTrue(pc._stdout_tail.stopped())
+for tail_log in pc._tail_logs:
+    self.assertTrue(tail_log.stopped())
def test_no_zombie_process_function(self):
signals = [signal.SIGTERM, signal.SIGINT, signal.SIGHUP, signal.SIGQUIT]

View File

@ -23,5 +23,6 @@ if __name__ == "__main__":
print(f"exit {exitcode} from {rank}", file=sys.stderr)
sys.exit(exitcode)
else:
print(f"{args.msg} stdout from {rank}")
print(f"{args.msg} stderr from {rank}", file=sys.stderr)
for msg in args.msg.split(","):
print(f"{msg} stdout from {rank}")
print(f"{msg} stderr from {rank}", file=sys.stderr)

View File

@ -84,6 +84,53 @@ class TailLogTest(unittest.TestCase):
)
self.assertTrue(tail.stopped())
def test_tail_write_to_dst_file(self):
"""
writer() writes 0 - max (one number on each line) to a log file.
Run nprocs such writers and tail the log files into a temp file
and validate that all lines are accounted for.
"""
nprocs = 32
max = 1000
interval_sec = 0.0001
log_files = {
local_rank: os.path.join(self.test_dir, f"{local_rank}_stdout.log")
for local_rank in range(nprocs)
}
dst = os.path.join(self.test_dir, "tailed_stdout.log")
tail = TailLog(
name="writer", log_files=log_files, dst=dst, interval_sec=interval_sec
).start()
# sleep here is intentional to ensure that the log tail
# can gracefully handle and wait for non-existent log files
time.sleep(interval_sec * 10)
futs = []
for local_rank, file in log_files.items():
f = self.threadpool.submit(
write, max=max, sleep=interval_sec * local_rank, file=file
)
futs.append(f)
wait(futs, return_when=ALL_COMPLETED)
self.assertFalse(tail.stopped())
tail.stop()
actual: dict[str, set[int]] = {}
with open(dst) as dst_file:
for line in dst_file:
header, num = line.split(":")
nums = actual.setdefault(header, set())
nums.add(int(num))
self.assertEqual(nprocs, len(actual))
self.assertEqual(
{f"[writer{i}]": set(range(max)) for i in range(nprocs)}, actual
)
self.assertTrue(tail.stopped())
def test_tail_with_custom_prefix(self):
"""
writer() writes 0 - max (one number on each line) to a log file.
@ -131,6 +178,52 @@ class TailLogTest(unittest.TestCase):
self.assertIn(f"[worker{i}][{i}]", headers)
self.assertTrue(tail.stopped())
def test_tail_with_custom_filter(self):
"""
writer() writes 0 - max (one number on each line) to a log file.
Run nprocs such writers and tail the log files into a StringIO
and validate that only the lines passing the filter are accounted for.
"""
nprocs = 3
max = 20
interval_sec = 0.0001
log_files = {
local_rank: os.path.join(self.test_dir, f"{local_rank}_stdout.log")
for local_rank in range(nprocs)
}
dst = io.StringIO()
tail = TailLog(
"writer",
log_files,
dst,
interval_sec=interval_sec,
log_line_filter=lambda line: "2" in line, # only print lines containing '2'
).start()
# sleep here is intentional to ensure that the log tail
# can gracefully handle and wait for non-existent log files
time.sleep(interval_sec * 10)
futs = []
for local_rank, file in log_files.items():
f = self.threadpool.submit(
write, max=max, sleep=interval_sec * local_rank, file=file
)
futs.append(f)
wait(futs, return_when=ALL_COMPLETED)
self.assertFalse(tail.stopped())
tail.stop()
dst.seek(0)
actual: dict[str, set[int]] = {}
for line in dst.readlines():
header, num = line.split(":")
nums = actual.setdefault(header, set())
nums.add(int(num))
self.assertEqual(nprocs, len(actual))
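# with max=20, the only values in range(max) containing the digit "2" are 2 and 12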
self.assertEqual({f"[writer{i}]": {2, 12} for i in range(nprocs)}, actual)
self.assertTrue(tail.stopped())
def test_tail_no_files(self):
"""
Ensures that the log tail can gracefully handle no log files

View File

@ -55,9 +55,10 @@ class SignalHandlingTest(TestCase):
mock_threading.main_thread.return_value
)
mock_pcontext = MagicMock(spec=PContext)
-# Mock the _stdout_tail and _stderr_tail attributes
-mock_pcontext._stdout_tail = MagicMock()
-mock_pcontext._stderr_tail = MagicMock()
+# Mock the stdout_tail and stderr_tail
+mock_stdout_tail = MagicMock()
+mock_stderr_tail = MagicMock()
+mock_pcontext._tail_logs = [mock_stdout_tail, mock_stderr_tail]
# Remove environment variable if it exists to test default behavior
if "TORCHELASTIC_SIGNALS_TO_HANDLE" in os.environ:
@ -84,8 +85,8 @@ class SignalHandlingTest(TestCase):
# Verify _start was called
mock_pcontext._start.assert_called_once()
# Verify _stdout_tail.start() and _stderr_tail.start() were called
-mock_pcontext._stdout_tail.start.assert_called_once()
-mock_pcontext._stderr_tail.start.assert_called_once()
+mock_stdout_tail.start.assert_called_once()
+mock_stderr_tail.start.assert_called_once()
@patch("torch.distributed.elastic.multiprocessing.api.threading")
@patch("torch.distributed.elastic.multiprocessing.api.signal")
@ -99,9 +100,10 @@ class SignalHandlingTest(TestCase):
mock_threading.main_thread.return_value
)
mock_pcontext = MagicMock(spec=PContext)
-# Mock the _stdout_tail and _stderr_tail attributes
-mock_pcontext._stdout_tail = MagicMock()
-mock_pcontext._stderr_tail = MagicMock()
+# Mock the stdout_tail and stderr_tail
+mock_stdout_tail = MagicMock()
+mock_stderr_tail = MagicMock()
+mock_pcontext._tail_logs = [mock_stdout_tail, mock_stderr_tail]
# Set custom signals in the environment variable
os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"] = "SIGTERM,SIGUSR1,SIGUSR2"
@ -139,9 +141,10 @@ class SignalHandlingTest(TestCase):
mock_threading.main_thread.return_value
)
mock_pcontext = MagicMock(spec=PContext)
-# Mock the _stdout_tail and _stderr_tail attributes
-mock_pcontext._stdout_tail = MagicMock()
-mock_pcontext._stderr_tail = MagicMock()
+# Mock the stdout_tail and stderr_tail
+mock_stdout_tail = MagicMock()
+mock_stderr_tail = MagicMock()
+mock_pcontext._tail_logs = [mock_stdout_tail, mock_stderr_tail]
# Set invalid signals in the environment variable
os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"] = "SIGTERM,INVALID_SIGNAL"
@ -180,9 +183,10 @@ class SignalHandlingTest(TestCase):
mock_threading.main_thread.return_value
)
mock_pcontext = MagicMock(spec=PContext)
-# Mock the _stdout_tail and _stderr_tail attributes
-mock_pcontext._stdout_tail = MagicMock()
-mock_pcontext._stderr_tail = MagicMock()
+# Mock the stdout_tail and stderr_tail
+mock_stdout_tail = MagicMock()
+mock_stderr_tail = MagicMock()
+mock_pcontext._tail_logs = [mock_stdout_tail, mock_stderr_tail]
# Set signals including ones not supported on Windows
os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"] = "SIGTERM,SIGHUP,SIGUSR1"
@ -234,9 +238,10 @@ class SignalHandlingTest(TestCase):
mock_threading.current_thread.return_value = MagicMock() # Not the main thread
mock_threading.main_thread.return_value = MagicMock()
mock_pcontext = MagicMock(spec=PContext)
-# Mock the _stdout_tail and _stderr_tail attributes
-mock_pcontext._stdout_tail = MagicMock()
-mock_pcontext._stderr_tail = MagicMock()
+# Mock the stdout_tail and stderr_tail
+mock_stdout_tail = MagicMock()
+mock_stderr_tail = MagicMock()
+mock_pcontext._tail_logs = [mock_stdout_tail, mock_stderr_tail]
# Call the start method
PContext.start(mock_pcontext)
@ -262,9 +267,10 @@ class SignalHandlingTest(TestCase):
mock_threading.main_thread.return_value
)
mock_pcontext = MagicMock(spec=PContext)
-# Mock the _stdout_tail and _stderr_tail attributes
-mock_pcontext._stdout_tail = MagicMock()
-mock_pcontext._stderr_tail = MagicMock()
+# Mock the stdout_tail and stderr_tail
+mock_stdout_tail = MagicMock()
+mock_stderr_tail = MagicMock()
+mock_pcontext._tail_logs = [mock_stdout_tail, mock_stderr_tail]
# Set environment variable to include SIGUSR1 and SIGUSR2
os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"] = "SIGUSR1,SIGUSR2"
@ -323,8 +329,8 @@ class SignalHandlingTest(TestCase):
# Verify _start was called
mock_pcontext._start.assert_called_once()
# Verify _stdout_tail.start() and _stderr_tail.start() were called
-mock_pcontext._stdout_tail.start.assert_called_once()
-mock_pcontext._stderr_tail.start.assert_called_once()
+mock_stdout_tail.start.assert_called_once()
+mock_stderr_tail.start.assert_called_once()
if __name__ == "__main__":

View File

@ -75,6 +75,10 @@ class WorkerSpec:
takes precedence over ``redirects`` settings.
event_log_handler: name of the event logging handler as registered in
`elastic/events/handlers.py <https://docs.pytorch.org/docs/stable/elastic/events.html>`_.
duplicate_stdout_filters: If non-empty, duplicates stdout to a file containing only lines
that match _any_ of the filter strings.
duplicate_stderr_filters: If non-empty, duplicates stderr to a file containing only lines
that match _any_ of the filter strings.
"""
role: str
@ -91,6 +95,8 @@ class WorkerSpec:
local_addr: Optional[str] = None
event_log_handler: str = "null"
numa_options: Optional[NumaOptions] = None
duplicate_stdout_filters: Optional[list[str]] = None
duplicate_stderr_filters: Optional[list[str]] = None
def __post_init__(self):
assert self.local_world_size > 0

View File

@ -356,6 +356,8 @@ class LocalElasticAgent(SimpleElasticAgent):
log_line_prefixes=log_line_prefixes,
start_method=self._start_method,
numa_options=spec.numa_options,
duplicate_stdout_filters=spec.duplicate_stdout_filters,
duplicate_stderr_filters=spec.duplicate_stderr_filters,
)
return self._pcontext.pids()

View File

@ -109,6 +109,8 @@ def start_processes(
log_line_prefixes: Optional[dict[int, str]] = None,
start_method: str = "spawn",
numa_options: Optional[NumaOptions] = None,
duplicate_stdout_filters: Optional[list[str]] = None,
duplicate_stderr_filters: Optional[list[str]] = None,
) -> PContext:
"""
Start ``n`` copies of ``entrypoint`` processes with the provided options.
@ -130,12 +132,17 @@ def start_processes(
this is done by default and there is no need to manually annotate
with the ``@record`` annotation.
-``redirects`` and ``tee`` are bitmasks specifying which std stream(s) to redirect
-to a log file in the ``log_dir``. Valid mask values are defined in ``Std``.
-To redirect/tee only certain local ranks, pass ``redirects`` as a map with the key as
-the local rank to specify the redirect behavior for.
+Inside ``logs_specs``, ``redirects`` and ``tee`` are bitmasks specifying which std
+stream(s) to redirect to a log file in the ``log_dir``. Valid mask values are defined
+in ``Std``. To redirect/tee only certain local ranks, pass ``redirects`` as a map
+with the key as the local rank to specify the redirect behavior for.
Any missing local ranks will default to ``Std.NONE``.
``duplicate_stdout_filters`` and ``duplicate_stderr_filters``, if non-empty,
duplicate stdouts and stderrs respectively specified in ``logs_specs``'s ``tee``
to a file containing only lines that match _any_ of the filter strings. The log
file is aggregated across all ranks selected by ``tee``.
``tee`` acts like the unix "tee" command in that it redirects + prints to console.
To avoid worker stdout/stderr from printing to console, use the ``redirects`` parameter.
@ -144,6 +151,8 @@ def start_processes(
#. ``{local_rank}/error.json``: if the process failed, a file with the error info
#. ``{local_rank}/stdout.log``: if ``redirect & STDOUT == STDOUT``
#. ``{local_rank}/stderr.log``: if ``redirect & STDERR == STDERR``
#. ``filtered_stdout.log``: if ``duplicate_stdout_filters`` is non-empty
#. ``filtered_stderr.log``: if ``duplicate_stderr_filters`` is non-empty
.. note:: It is expected that the ``log_dir`` exists, is empty, and is a directory.
@ -198,9 +207,13 @@ def start_processes(
log_dir: directory used to write log files
start_method: multiprocessing start method (spawn, fork, forkserver)
ignored for binaries
-redirects: which std streams to redirect to a log file
-tee: which std streams to redirect + print to console
+logs_specs: defines ``log_dir``, ``redirects``, and ``tee``.
+    inside ``logs_specs``:
+    - redirects: which std streams to redirect to a log file
+    - tee: which std streams to redirect + print to console
local_ranks_filter: which ranks' logs to print to console
duplicate_stdout_filters: filters for the duplicated stdout logs
duplicate_stderr_filters: filters for the duplicated stderr logs
"""
@ -215,6 +228,8 @@ def start_processes(
entrypoint=entrypoint,
args=args,
envs=envs,
duplicate_stdout_filters=duplicate_stdout_filters,
duplicate_stderr_filters=duplicate_stderr_filters,
logs_specs=logs_specs,
log_line_prefixes=log_line_prefixes,
numa_options=numa_options,
@ -225,6 +240,8 @@ def start_processes(
entrypoint=entrypoint,
args=args,
envs=envs,
duplicate_stdout_filters=duplicate_stdout_filters,
duplicate_stderr_filters=duplicate_stderr_filters,
log_line_prefixes=log_line_prefixes,
start_method=start_method,
logs_specs=logs_specs,

View File

@ -193,6 +193,8 @@ class LogsDest:
tee_stdouts: dict[int, str] = field(default_factory=dict)
tee_stderrs: dict[int, str] = field(default_factory=dict)
error_files: dict[int, str] = field(default_factory=dict)
filtered_stdout: str = field(default_factory=str)
filtered_stderr: str = field(default_factory=str)
class LogsSpecs(ABC):
@ -290,6 +292,8 @@ class DefaultLogsSpecs(LogsSpecs):
- `<log_dir>/<rdzv_run_id>/attempt_<attempt>/<rank>/stdout.log`
- `<log_dir>/<rdzv_run_id>/attempt_<attempt>/<rank>/stderr.log`
- `<log_dir>/<rdzv_run_id>/attempt_<attempt>/<rank>/error.json`
- `<log_dir>/<rdzv_run_id>/attempt_<attempt>/filtered_stdout.log`
- `<log_dir>/<rdzv_run_id>/attempt_<attempt>/filtered_stderr.log`
"""
nprocs = len(envs)
global_env = {} # use only to query properties that are not dependent on a rank
@ -386,7 +390,15 @@ class DefaultLogsSpecs(LogsSpecs):
)
envs[local_rank]["TORCHELASTIC_ERROR_FILE"] = error_file
-return LogsDest(stdouts, stderrs, tee_stdouts, tee_stderrs, error_files)
+return LogsDest(
+    stdouts,
+    stderrs,
+    tee_stdouts,
+    tee_stderrs,
+    error_files,
+    os.path.join(attempt_log_dir, "filtered_stdout.log"),
+    os.path.join(attempt_log_dir, "filtered_stderr.log"),
+)
def __repr__(self) -> str:
return (
@ -438,6 +450,16 @@ class PContext(abc.ABC):
.. warning:: stdouts and stderrs should ALWAYS be a superset of
tee_stdouts and tee_stderrs (respectively) this is b/c
tee is implemented as a redirect + tail -f <stdout/stderr.log>
Args:
duplicate_stdout_filters:
If non-empty, duplicates stdouts specified in ``logs_specs``'s ``tee``
to a file containing only lines that match _any_ of the filter strings.
The log file is aggregated across all ranks selected by ``tee``.
duplicate_stderr_filters:
If non-empty, duplicates stderrs specified in ``logs_specs``'s ``tee``
to a file containing only lines that match _any_ of the filter strings.
The log file is aggregated across all ranks selected by ``tee``.
"""
def __init__(
@ -448,6 +470,8 @@ class PContext(abc.ABC):
envs: dict[int, dict[str, str]],
logs_specs: LogsSpecs,
log_line_prefixes: Optional[dict[int, str]] = None,
duplicate_stdout_filters: Optional[list[str]] = None,
duplicate_stderr_filters: Optional[list[str]] = None,
):
self.name = name
# validate that all mappings have the same number of keys and
@ -467,13 +491,39 @@ class PContext(abc.ABC):
self.stderrs = logs_dest.stderrs
self.error_files = logs_dest.error_files
self.nprocs = nprocs
+self.filtered_stdout = logs_dest.filtered_stdout
+self.filtered_stderr = logs_dest.filtered_stderr
-self._stdout_tail = TailLog(
-    name, logs_dest.tee_stdouts, sys.stdout, log_line_prefixes
-)
-self._stderr_tail = TailLog(
-    name, logs_dest.tee_stderrs, sys.stderr, log_line_prefixes
-)
+self._tail_logs = [
+    TailLog(name, logs_dest.tee_stdouts, sys.stdout, log_line_prefixes),
+    TailLog(name, logs_dest.tee_stderrs, sys.stderr, log_line_prefixes),
+]
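+# optionally duplicate the tee'd streams into the filtered log files, keeping only matching lines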
+if duplicate_stdout_filters:
+    self._tail_logs.append(
+        TailLog(
+            name,
+            logs_dest.tee_stdouts,
+            self.filtered_stdout,
+            log_line_prefixes,
+            log_line_filter=lambda line: any(
+                needle in line for needle in duplicate_stdout_filters
+            ),
+        )
+    )
+if duplicate_stderr_filters:
+    self._tail_logs.append(
+        TailLog(
+            name,
+            logs_dest.tee_stderrs,
+            self.filtered_stderr,
+            log_line_prefixes,
+            log_line_filter=lambda line: any(
+                needle in line for needle in duplicate_stderr_filters
+            ),
+        )
+    )
def start(self) -> None:
"""Start processes using parameters defined in the constructor."""
@ -517,8 +567,8 @@ class PContext(abc.ABC):
"This could lead to orphaned worker processes if the torchrun is terminated."
)
self._start()
-self._stdout_tail.start()
-self._stderr_tail.start()
+for tail_log in self._tail_logs:
+    tail_log.start()
@abc.abstractmethod
def _start(self) -> None:
@ -605,10 +655,8 @@ class PContext(abc.ABC):
if not death_sig:
death_sig = _get_default_signal()
self._close(death_sig=death_sig, timeout=timeout)
-if self._stdout_tail:
-    self._stdout_tail.stop()
-if self._stderr_tail:
-    self._stderr_tail.stop()
+for tail_log in self._tail_logs:
+    tail_log.stop()
def get_std_cm(std_rd: str, redirect_fn):
@ -661,6 +709,8 @@ class MultiprocessContext(PContext):
logs_specs: LogsSpecs,
log_line_prefixes: Optional[dict[int, str]] = None,
numa_options: Optional[NumaOptions] = None,
duplicate_stdout_filters: Optional[list[str]] = None,
duplicate_stderr_filters: Optional[list[str]] = None,
):
super().__init__(
name,
@ -669,6 +719,8 @@ class MultiprocessContext(PContext):
envs,
logs_specs,
log_line_prefixes,
duplicate_stdout_filters,
duplicate_stderr_filters,
)
self.start_method = start_method
@ -846,6 +898,8 @@ class SubprocessContext(PContext):
logs_specs: LogsSpecs,
log_line_prefixes: Optional[dict[int, str]] = None,
numa_options: Optional[NumaOptions] = None,
duplicate_stdout_filters: Optional[list[str]] = None,
duplicate_stderr_filters: Optional[list[str]] = None,
):
super().__init__(
name,
@ -854,6 +908,8 @@ class SubprocessContext(PContext):
envs,
logs_specs,
log_line_prefixes,
duplicate_stdout_filters,
duplicate_stderr_filters,
)
# state vector; _vdone[local_rank] -> is local_rank finished or not

View File

@ -12,11 +12,12 @@ import os
import time
from concurrent.futures.thread import ThreadPoolExecutor
from threading import Event
-from typing import Optional, TextIO, TYPE_CHECKING
+from typing import Callable, Optional, TextIO, TYPE_CHECKING, Union
if TYPE_CHECKING:
from concurrent.futures._base import Future
from io import TextIOWrapper
__all__ = ["tail_logfile", "TailLog"]
@ -24,7 +25,12 @@ logger = logging.getLogger(__name__)
def tail_logfile(
-header: str, file: str, dst: TextIO, finished: Event, interval_sec: float
+header: str,
+file: str,
+dst: TextIO,
+finished: Event,
+interval_sec: float,
+log_line_filter: Optional[Callable[[str], bool]] = None,
):
while not os.path.exists(file):
if finished.is_set():
@ -36,7 +42,8 @@ def tail_logfile(
line = fp.readline()
if line:
dst.write(f"{header}{line}")
if log_line_filter and log_line_filter(line):
dst.write(f"{header}{line}")
else: # reached EOF
if finished.is_set():
# log line producer is finished
@ -90,9 +97,10 @@ class TailLog:
self,
name: str,
log_files: dict[int, str],
-dst: TextIO,
+dst: Union[TextIO, str],
log_line_prefixes: Optional[dict[int, str]] = None,
interval_sec: float = 0.1,
log_line_filter: Callable[[str], bool] = (lambda _: True),
):
n = len(log_files)
self._threadpool = None
@ -104,9 +112,22 @@ class TailLog:
)
self._name = name
-self._dst = dst
+self._dst_file: Optional[TextIOWrapper] = None
+self._dst: Optional[Union[TextIO, TextIOWrapper]] = None
+if isinstance(dst, str):
+    try:
+        self._dst_file = open(dst, mode="w", errors="replace")
+        self._dst = self._dst_file
+    except Exception:
+        logger.exception("error opening dst file %s.", dst)
+        self._dst = None
+        self._dst_file = None
+else:
+    self._dst = dst
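+# if opening dst failed, start() below is a no-op (self._dst is None)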
self._log_files = log_files
self._log_line_prefixes = log_line_prefixes
self._log_line_filter = log_line_filter
self._finished_events: dict[int, Event] = {
local_rank: Event() for local_rank in log_files.keys()
}
@ -115,7 +136,7 @@ class TailLog:
self._stopped = False
def start(self) -> "TailLog":
-if not self._threadpool:
+if not self._threadpool or not self._dst:
return self
for local_rank, file in self._log_files.items():
@ -130,6 +151,7 @@ class TailLog:
dst=self._dst,
finished=self._finished_events[local_rank],
interval_sec=self._interval_sec,
log_line_filter=self._log_line_filter,
)
)
return self
@ -152,6 +174,9 @@ class TailLog:
if self._threadpool:
self._threadpool.shutdown(wait=True)
if self._dst_file:
self._dst_file.close()
self._stopped = True
def stopped(self) -> bool:

View File

@ -71,6 +71,10 @@ class LaunchConfig:
local_ranks_filter: ranks for which to show logs in console. If not set, show from all.
event_log_handler: name of the event logging handler as registered in
`elastic/events/handlers.py <https://docs.pytorch.org/docs/stable/elastic/events.html>`_.
duplicate_stdout_filters: If non-empty, duplicates stdout to a file containing only lines
that match _any_ of the filter strings.
duplicate_stderr_filters: If non-empty, duplicates stderr to a file containing only lines
that match _any_ of the filter strings.
.. note::
@ -98,6 +102,8 @@ class LaunchConfig:
event_log_handler: str = "null"
numa_options: Optional[NumaOptions] = None
signals_to_handle: str = "SIGTERM,SIGINT,SIGHUP,SIGQUIT"
duplicate_stdout_filters: Optional[list[str]] = None
duplicate_stderr_filters: Optional[list[str]] = None
def __post_init__(self):
default_timeout = 900
@ -214,20 +220,22 @@ def launch_agent(
logger.info(
"Starting elastic_operator with launch configs:\n"
" entrypoint : %(entrypoint)s\n"
" min_nodes : %(min_nodes)s\n"
" max_nodes : %(max_nodes)s\n"
" nproc_per_node : %(nproc_per_node)s\n"
" run_id : %(run_id)s\n"
" rdzv_backend : %(rdzv_backend)s\n"
" rdzv_endpoint : %(rdzv_endpoint)s\n"
" rdzv_configs : %(rdzv_configs)s\n"
" max_restarts : %(max_restarts)s\n"
" monitor_interval : %(monitor_interval)s\n"
" log_dir : %(log_dir)s\n"
" metrics_cfg : %(metrics_cfg)s\n"
" event_log_handler : %(event_log_handler)s\n"
" numa_options : %(numa_options)s\n",
" entrypoint : %(entrypoint)s\n"
" min_nodes : %(min_nodes)s\n"
" max_nodes : %(max_nodes)s\n"
" nproc_per_node : %(nproc_per_node)s\n"
" run_id : %(run_id)s\n"
" rdzv_backend : %(rdzv_backend)s\n"
" rdzv_endpoint : %(rdzv_endpoint)s\n"
" rdzv_configs : %(rdzv_configs)s\n"
" max_restarts : %(max_restarts)s\n"
" monitor_interval : %(monitor_interval)s\n"
" log_dir : %(log_dir)s\n"
" metrics_cfg : %(metrics_cfg)s\n"
" event_log_handler : %(event_log_handler)s\n"
" numa_options : %(numa_options)s\n",
" duplicate_stdout_filters : %(duplicate_stdout_filters)s\n",
" duplicate_stderr_filters : %(duplicate_stderr_filters)s\n",
{
"entrypoint": entrypoint_name,
"min_nodes": config.min_nodes,
@ -244,6 +252,8 @@ def launch_agent(
"event_log_handler": config.event_log_handler,
"numa_options": config.numa_options,
"signals_to_handle": config.signals_to_handle,
"duplicate_stdout_filters": config.duplicate_stdout_filters,
"duplicate_stderr_filters": config.duplicate_stderr_filters,
},
)
@ -275,6 +285,8 @@ def launch_agent(
local_addr=config.local_addr,
event_log_handler=config.event_log_handler,
numa_options=config.numa_options,
duplicate_stdout_filters=config.duplicate_stdout_filters,
duplicate_stderr_filters=config.duplicate_stderr_filters,
)
agent = LocalElasticAgent(

View File

@ -399,6 +399,13 @@ def get_args_parser() -> ArgumentParser:
"""Parse the command line options."""
parser = ArgumentParser(description="Torch Distributed Elastic Training Launcher")
def comma_separated_list(value):
placeholder = "<COMMA_PLACEHOLDER>"
value = value.replace(",,", placeholder)
items = value.split(",")
items = [item.replace(placeholder, ",") for item in items]
return items
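# e.g. "apple,orange" -> ["apple", "orange"]; "a,,b" -> ["a,b"]
# (a double comma escapes a literal comma)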
#
# Worker/node size related arguments.
#
@ -571,6 +578,28 @@ def get_args_parser() -> ArgumentParser:
"log files saved via --redirect or --tee",
)
parser.add_argument(
"--duplicate-stdout-filters",
"--duplicate_stdout_filters",
action=env,
type=comma_separated_list,
default=[],
help="Duplicates logs streamed to stdout to another specified file with a list of filters (e.g. "
"[--duplicate_stdout_filters 'apple,orange'] will duplicate log lines matching 'apple' "
"OR 'orange'. An empty filters list won't duplicate any lines. Use double comma to escape a comma) ",
)
parser.add_argument(
"--duplicate-stderr-filters",
"--duplicate_stderr_filters",
action=env,
type=comma_separated_list,
default=[],
help="Duplicates logs streamed to stderr to another specified file with a list of filters (e.g. "
"[--duplicate_stdout_filters 'apple,orange'] will duplicate log lines matching 'apple' "
"OR 'orange'. An empty filters list won't duplicate any lines. Use double comma to escape a comma) ",
)
#
# Backwards compatible parameters with caffe2.distributed.launch.
#
@ -871,6 +900,8 @@ def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], list[str
event_log_handler=args.event_log_handler,
numa_options=numa_options,
signals_to_handle=args.signals_to_handle,
duplicate_stdout_filters=args.duplicate_stdout_filters,
duplicate_stderr_filters=args.duplicate_stderr_filters,
)
with_python = not args.no_python