Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
[pytorch][torchelastic] Duplicate stdout and stderr and apply custom filter in torchrun (#160712)
Summary:
Part of an effort to extract some important error logs (e.g. [#157996](https://github.com/pytorch/pytorch/pull/157996)) that were `tee`'d to `stdout` and `stderr`. The general idea is to:
- Duplicate the `tee`s on `stdout` and `stderr` to separate files, `filtered_stdout.log` and `filtered_stderr.log`, respectively.
- In these files, as their names suggest, log only the lines matching a customizable filter.
- Later on, in another PR, append the contents of these files to the reply file.

Outline of changes in this PR:
- Enhance `TailLog` so it can 1) stream to a file, and 2) write only the lines that match the passed filter.
- Add `filtered_stdout` and `filtered_stderr` to `LogsDest` and have `LogsSpecs` `reify` them.
- In `start_processes()` and `PContext`, add the params `duplicate_stdout_filters` and `duplicate_stderr_filters` to filter and write the duplicated streams to the files above. When no filters are passed in, no duplicated streams are created.

Test Plan:
```
$ buck test 'fbcode//mode/opt' caffe2/test/distributed/elastic/multiprocessing:api_test
```
```
Buck UI: https://www.internalfb.com/buck2/f5c6b7da-217d-4a0b-872a-c7cd3d05587f
Test UI: https://www.internalfb.com/intern/testinfra/testrun/4222124951617688
Network: Up: 398B  Down: 44MiB  (reSessionID-a489a961-b602-45be-b851-3490ebb7a26a)
Analyzing targets. Remaining 0/200
Executing actions. Remaining 0/12856  0.1s exec time total
Command: test.  Finished 1 local
Time elapsed: 17:37.9s
Tests finished: Pass 52. Fail 0. Fatal 0. Skip 0. Build failure 0
```
```
$ buck test 'fbcode//mode/opt' caffe2/test/distributed/elastic/multiprocessing:tail_log_test
```
```
Buck UI: https://www.internalfb.com/buck2/d6d5c1c1-db98-4d9c-b608-7ba6fbb5e3ee
Test UI: https://www.internalfb.com/intern/testinfra/testrun/13510798985149262
Network: Up: 94KiB  Down: 417MiB  (reSessionID-27b46fba-d31c-4c04-8ede-a506454e6922)
Analyzing targets. Remaining 0/3  536 actions, 555 artifacts declared
Executing actions. Remaining 0/186  1:05.5s exec time total
Command: test.  Finished 7 local, 1 remote, 115 cache (93% hit)  37.0s exec time cached (56%)
Time elapsed: 1:11.5s
Tests finished: Pass 7. Fail 0. Fatal 0. Skip 0. Build failure 0
```

Rollback Plan:

Differential Revision: D80188995

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160712
Approved by: https://github.com/fduwjj
This commit is contained in:
parent 2b93d5b450
commit cbcb4f7768
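The core of the change is a substring-match predicate threaded through `TailLog`. A minimal sketch, mirroring the `any(...)` filter this diff adds in `PContext` (the filter strings here are illustrative):

```python
# A line is duplicated iff it contains any of the filter strings.
filters = ["helloA", "B"]
log_line_filter = lambda line: any(needle in line for needle in filters)

assert log_line_filter("worldB stderr from 1")      # contains "B"
assert not log_line_filter("hello stdout from 0")   # matches neither filter
```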
@@ -127,8 +127,9 @@ def echo1(msg: str, exitcode: int = 0) -> str:
         print(f"exit {exitcode} from {rank}", file=sys.stderr)
         sys.exit(exitcode)
     else:
-        print(f"{msg} stdout from {rank}")
-        print(f"{msg} stderr from {rank}", file=sys.stderr)
+        for m in msg.split(","):
+            print(f"{m} stdout from {rank}")
+            print(f"{m} stderr from {rank}", file=sys.stderr)
         return f"{msg}_{rank}"
@@ -247,6 +248,13 @@ class _StartProcessesTest(TestCase):
         for line in expected:
             self.assertIn(line, actual)

+    def assert_not_in_file(self, lines: list[str], filename: str) -> None:
+        lines = [f"{line.rstrip()}\n" for line in lines]
+        with open(filename) as fp:
+            actual = fp.readlines()
+        for line in lines:
+            self.assertNotIn(line, actual)
+
     def assert_pids_noexist(self, pids: dict[int, int]):
         for local_rank, pid in pids.items():
             with self.assertRaises(
@@ -360,8 +368,8 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):

             self.assertIsNone(pc.wait(timeout=0.1, period=0.01))
             self.assertIsNotNone(pc.wait(period=0.1))
-            self.assertTrue(pc._stderr_tail.stopped())
-            self.assertTrue(pc._stdout_tail.stopped())
+            for tail_log in pc._tail_logs:
+                self.assertTrue(tail_log.stopped())

         def test_pcontext_wait_on_a_child_thread(self):
             asyncio.run(asyncio.to_thread(self.test_pcontext_wait))
@@ -379,8 +387,8 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
             pids = pc.pids()
             pc.close()
             self.assert_pids_noexist(pids)
-            self.assertTrue(pc._stderr_tail.stopped())
-            self.assertTrue(pc._stdout_tail.stopped())
+            for tail_log in pc._tail_logs:
+                self.assertTrue(tail_log.stopped())

         def test_function_with_tensor(self):
             for start_method in self._start_methods:
@@ -482,8 +490,8 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
                 int(error_file_data["message"]["extraInfo"]["timestamp"]),
                 int(failure.timestamp),
             )
-            self.assertTrue(pc._stderr_tail.stopped())
-            self.assertTrue(pc._stdout_tail.stopped())
+            for tail_log in pc._tail_logs:
+                self.assertTrue(tail_log.stopped())

         def test_wait_for_all_child_procs_to_exit(self):
             """
@@ -580,8 +588,8 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
             self.assert_in_file([], results.stdouts[0])
             self.assertFalse(results.stderrs[1])
             self.assertFalse(results.stdouts[1])
-            self.assertTrue(pc._stderr_tail.stopped())
-            self.assertTrue(pc._stdout_tail.stopped())
+            for tail_log in pc._tail_logs:
+                self.assertTrue(tail_log.stopped())

             failure = results.failures[1]
             self.assertEqual(-15, failure.exitcode)
@@ -731,8 +739,37 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
             self.assert_in_file(["hello stderr from 0"], pc.stderrs[0])
             self.assert_in_file(["world stderr from 1"], pc.stderrs[1])
             self.assertFalse(pc.stdouts[1])
-            self.assertTrue(pc._stderr_tail.stopped())
-            self.assertTrue(pc._stdout_tail.stopped())
+            for tail_log in pc._tail_logs:
+                self.assertTrue(tail_log.stopped())
+
+        def test_binary_duplicate_log_filters(self):
+            pc = start_processes(
+                name="trainer",
+                entrypoint=bin("echo1.py"),
+                args={0: ("helloA,helloB",), 1: ("worldA,worldB",)},
+                envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
+                logs_specs=DefaultLogsSpecs(
+                    log_dir=self.log_dir(),
+                    redirects={0: Std.ERR, 1: Std.NONE},
+                    tee={0: Std.OUT, 1: Std.ERR},
+                ),
+                log_line_prefixes={0: "[rank0]:", 1: "[rank1]:"},
+                duplicate_stdout_filters=["helloA"],
+                duplicate_stderr_filters=["worldA", "B"],
+                start_method="spawn",
+            )
+
+            result = pc.wait()
+
+            self.assertFalse(result.is_failed())
+            self.assert_in_file(["[rank0]:helloA stdout from 0"], pc.filtered_stdout)
+            self.assert_not_in_file(
+                ["[rank0]:helloB stdout from 0"], pc.filtered_stdout
+            )
+            self.assert_in_file(["[rank1]:worldA stderr from 1"], pc.filtered_stderr)
+            self.assert_in_file(["[rank1]:worldB stderr from 1"], pc.filtered_stderr)
+            for tail_log in pc._tail_logs:
+                self.assertTrue(tail_log.stopped())


 # tests incompatible with tsan or asan, the redirect functionality does not work on macos or windows
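Note why both `worldA` and `worldB` land in `filtered_stderr` in the test above: the stderr filters are `["worldA", "B"]`, and `worldB` matches the one-character filter `"B"`. A worked check of the substring semantics:

```python
stderr_filters = ["worldA", "B"]
lines = ["[rank1]:worldA stderr from 1", "[rank1]:worldB stderr from 1"]
kept = [l for l in lines if any(f in l for f in stderr_filters)]
assert kept == lines  # "worldA" matches the first filter, "worldB" matches "B"

stdout_filters = ["helloA"]
lines = ["[rank0]:helloA stdout from 0", "[rank0]:helloB stdout from 0"]
kept = [l for l in lines if any(f in l for f in stdout_filters)]
assert kept == ["[rank0]:helloA stdout from 0"]  # "helloB" matches nothing
```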
@@ -794,8 +831,44 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS or IS_CI):
             self.assert_in_file(["hello stderr from 0"], pc.stderrs[0])
             self.assert_in_file(["world stderr from 1"], pc.stderrs[1])
             self.assertFalse(pc.stdouts[1])
-            self.assertTrue(pc._stderr_tail.stopped())
-            self.assertTrue(pc._stdout_tail.stopped())
+            for tail_log in pc._tail_logs:
+                self.assertTrue(tail_log.stopped())
+
+        def test_function_duplicate_log_filters(self):
+            for start_method in self._start_methods:
+                with self.subTest(start_method=start_method):
+                    pc = start_processes(
+                        name="trainer",
+                        entrypoint=echo1,
+                        args={0: ("helloA,helloB",), 1: ("worldA,worldB",)},
+                        envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
+                        logs_specs=DefaultLogsSpecs(
+                            log_dir=self.log_dir(),
+                            redirects={0: Std.ERR, 1: Std.NONE},
+                            tee={0: Std.OUT, 1: Std.ERR},
+                        ),
+                        duplicate_stdout_filters=["helloA"],
+                        duplicate_stderr_filters=["worldA", "B"],
+                        start_method="spawn",
+                    )
+
+                    result = pc.wait()
+
+                    self.assertFalse(result.is_failed())
+                    self.assert_in_file(
+                        ["[trainer0]:helloA stdout from 0"], pc.filtered_stdout
+                    )
+                    self.assert_not_in_file(
+                        ["[trainer0]:helloB stdout from 0"], pc.filtered_stdout
+                    )
+                    self.assert_in_file(
+                        ["[trainer1]:worldA stderr from 1"], pc.filtered_stderr
+                    )
+                    self.assert_in_file(
+                        ["[trainer1]:worldB stderr from 1"], pc.filtered_stderr
+                    )
+                    for tail_log in pc._tail_logs:
+                        self.assertTrue(tail_log.stopped())

         def test_function(self):
             for start_method, redirs in product(self._start_methods, redirects_all()):
@@ -880,8 +953,8 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS or IS_CI):
             self.assertFalse(results.stdouts[0])
             self.assertFalse(results.stderrs[1])
             self.assertFalse(results.stdouts[1])
-            self.assertTrue(pc._stderr_tail.stopped())
-            self.assertTrue(pc._stdout_tail.stopped())
+            for tail_log in pc._tail_logs:
+                self.assertTrue(tail_log.stopped())

         def test_no_zombie_process_function(self):
             signals = [signal.SIGTERM, signal.SIGINT, signal.SIGHUP, signal.SIGQUIT]
@@ -23,5 +23,6 @@ if __name__ == "__main__":
         print(f"exit {exitcode} from {rank}", file=sys.stderr)
         sys.exit(exitcode)
     else:
-        print(f"{args.msg} stdout from {rank}")
-        print(f"{args.msg} stderr from {rank}", file=sys.stderr)
+        for msg in args.msg.split(","):
+            print(f"{msg} stdout from {rank}")
+            print(f"{msg} stderr from {rank}", file=sys.stderr)
@@ -84,6 +84,53 @@ class TailLogTest(unittest.TestCase):
         )
         self.assertTrue(tail.stopped())

+    def test_tail_write_to_dst_file(self):
+        """
+        writer() writes 0 - max (on number on each line) to a log file.
+        Run nprocs such writers and tail the log files into a temp file
+        and validate that all lines are accounted for.
+        """
+        nprocs = 32
+        max = 1000
+        interval_sec = 0.0001
+
+        log_files = {
+            local_rank: os.path.join(self.test_dir, f"{local_rank}_stdout.log")
+            for local_rank in range(nprocs)
+        }
+
+        dst = os.path.join(self.test_dir, "tailed_stdout.log")
+        tail = TailLog(
+            name="writer", log_files=log_files, dst=dst, interval_sec=interval_sec
+        ).start()
+        # sleep here is intentional to ensure that the log tail
+        # can gracefully handle and wait for non-existent log files
+        time.sleep(interval_sec * 10)
+
+        futs = []
+        for local_rank, file in log_files.items():
+            f = self.threadpool.submit(
+                write, max=max, sleep=interval_sec * local_rank, file=file
+            )
+            futs.append(f)
+
+        wait(futs, return_when=ALL_COMPLETED)
+        self.assertFalse(tail.stopped())
+        tail.stop()
+
+        actual: dict[int, set[int]] = {}
+        with open(dst) as dst_file:
+            for line in dst_file:
+                header, num = line.split(":")
+                nums = actual.setdefault(header, set())
+                nums.add(int(num))
+
+        self.assertEqual(nprocs, len(actual))
+        self.assertEqual(
+            {f"[writer{i}]": set(range(max)) for i in range(nprocs)}, actual
+        )
+        self.assertTrue(tail.stopped())
+
     def test_tail_with_custom_prefix(self):
         """
         writer() writes 0 - max (on number on each line) to a log file.
@@ -131,6 +178,52 @@ class TailLogTest(unittest.TestCase):
             self.assertIn(f"[worker{i}][{i}]", headers)
         self.assertTrue(tail.stopped())

+    def test_tail_with_custom_filter(self):
+        """
+        writer() writes 0 - max (on number on each line) to a log file.
+        Run nprocs such writers and tail the log files into an IOString
+        and validate that all lines are accounted for.
+        """
+        nprocs = 3
+        max = 20
+        interval_sec = 0.0001
+
+        log_files = {
+            local_rank: os.path.join(self.test_dir, f"{local_rank}_stdout.log")
+            for local_rank in range(nprocs)
+        }
+
+        dst = io.StringIO()
+        tail = TailLog(
+            "writer",
+            log_files,
+            dst,
+            interval_sec=interval_sec,
+            log_line_filter=lambda line: "2" in line,  # only print lines containing '2'
+        ).start()
+        # sleep here is intentional to ensure that the log tail
+        # can gracefully handle and wait for non-existent log files
+        time.sleep(interval_sec * 10)
+        futs = []
+        for local_rank, file in log_files.items():
+            f = self.threadpool.submit(
+                write, max=max, sleep=interval_sec * local_rank, file=file
+            )
+            futs.append(f)
+        wait(futs, return_when=ALL_COMPLETED)
+        self.assertFalse(tail.stopped())
+        tail.stop()
+        dst.seek(0)
+
+        actual: dict[int, set[int]] = {}
+        for line in dst.readlines():
+            header, num = line.split(":")
+            nums = actual.setdefault(header, set())
+            nums.add(int(num))
+        self.assertEqual(nprocs, len(actual))
+        self.assertEqual({f"[writer{i}]": {2, 12} for i in range(nprocs)}, actual)
+        self.assertTrue(tail.stopped())
+
     def test_tail_no_files(self):
         """
         Ensures that the log tail can gracefully handle no log files
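A quick sanity check on the expected set in the test above: with `max = 20` the writer emits 0 through 19, and only the numbers whose decimal form contains the character `"2"` pass the filter.

```python
# Numbers 0..19 whose string form contains "2": exactly {2, 12}.
matching = {n for n in range(20) if "2" in str(n)}
assert matching == {2, 12}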
@@ -55,9 +55,10 @@ class SignalHandlingTest(TestCase):
             mock_threading.main_thread.return_value
         )
         mock_pcontext = MagicMock(spec=PContext)
-        # Mock the _stdout_tail and _stderr_tail attributes
-        mock_pcontext._stdout_tail = MagicMock()
-        mock_pcontext._stderr_tail = MagicMock()
+        # Mock the stdout_tail and stderr_tail
+        mock_stdout_tail = MagicMock()
+        mock_stderr_tail = MagicMock()
+        mock_pcontext._tail_logs = [mock_stdout_tail, mock_stderr_tail]

         # Remove environment variable if it exists to test default behavior
         if "TORCHELASTIC_SIGNALS_TO_HANDLE" in os.environ:
@@ -84,8 +85,8 @@ class SignalHandlingTest(TestCase):
         # Verify _start was called
         mock_pcontext._start.assert_called_once()
         # Verify _stdout_tail.start() and _stderr_tail.start() were called
-        mock_pcontext._stdout_tail.start.assert_called_once()
-        mock_pcontext._stderr_tail.start.assert_called_once()
+        mock_stdout_tail.start.assert_called_once()
+        mock_stderr_tail.start.assert_called_once()

     @patch("torch.distributed.elastic.multiprocessing.api.threading")
     @patch("torch.distributed.elastic.multiprocessing.api.signal")
@@ -99,9 +100,10 @@ class SignalHandlingTest(TestCase):
             mock_threading.main_thread.return_value
         )
         mock_pcontext = MagicMock(spec=PContext)
-        # Mock the _stdout_tail and _stderr_tail attributes
-        mock_pcontext._stdout_tail = MagicMock()
-        mock_pcontext._stderr_tail = MagicMock()
+        # Mock the stdout_tail and stderr_tail
+        mock_stdout_tail = MagicMock()
+        mock_stderr_tail = MagicMock()
+        mock_pcontext._tail_logs = [mock_stdout_tail, mock_stderr_tail]

         # Set custom signals in the environment variable
         os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"] = "SIGTERM,SIGUSR1,SIGUSR2"
@@ -139,9 +141,10 @@ class SignalHandlingTest(TestCase):
             mock_threading.main_thread.return_value
         )
         mock_pcontext = MagicMock(spec=PContext)
-        # Mock the _stdout_tail and _stderr_tail attributes
-        mock_pcontext._stdout_tail = MagicMock()
-        mock_pcontext._stderr_tail = MagicMock()
+        # Mock the stdout_tail and stderr_tail
+        mock_stdout_tail = MagicMock()
+        mock_stderr_tail = MagicMock()
+        mock_pcontext._tail_logs = [mock_stdout_tail, mock_stderr_tail]

         # Set invalid signals in the environment variable
         os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"] = "SIGTERM,INVALID_SIGNAL"
@@ -180,9 +183,10 @@ class SignalHandlingTest(TestCase):
             mock_threading.main_thread.return_value
         )
         mock_pcontext = MagicMock(spec=PContext)
-        # Mock the _stdout_tail and _stderr_tail attributes
-        mock_pcontext._stdout_tail = MagicMock()
-        mock_pcontext._stderr_tail = MagicMock()
+        # Mock the stdout_tail and stderr_tail
+        mock_stdout_tail = MagicMock()
+        mock_stderr_tail = MagicMock()
+        mock_pcontext._tail_logs = [mock_stdout_tail, mock_stderr_tail]

         # Set signals including ones not supported on Windows
         os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"] = "SIGTERM,SIGHUP,SIGUSR1"
@@ -234,9 +238,10 @@ class SignalHandlingTest(TestCase):
         mock_threading.current_thread.return_value = MagicMock()  # Not the main thread
         mock_threading.main_thread.return_value = MagicMock()
         mock_pcontext = MagicMock(spec=PContext)
-        # Mock the _stdout_tail and _stderr_tail attributes
-        mock_pcontext._stdout_tail = MagicMock()
-        mock_pcontext._stderr_tail = MagicMock()
+        # Mock the stdout_tail and stderr_tail
+        mock_stdout_tail = MagicMock()
+        mock_stderr_tail = MagicMock()
+        mock_pcontext._tail_logs = [mock_stdout_tail, mock_stderr_tail]

         # Call the start method
         PContext.start(mock_pcontext)
@@ -262,9 +267,10 @@ class SignalHandlingTest(TestCase):
             mock_threading.main_thread.return_value
         )
         mock_pcontext = MagicMock(spec=PContext)
-        # Mock the _stdout_tail and _stderr_tail attributes
-        mock_pcontext._stdout_tail = MagicMock()
-        mock_pcontext._stderr_tail = MagicMock()
+        # Mock the stdout_tail and stderr_tail
+        mock_stdout_tail = MagicMock()
+        mock_stderr_tail = MagicMock()
+        mock_pcontext._tail_logs = [mock_stdout_tail, mock_stderr_tail]

         # Set environment variable to include SIGUSR1 and SIGUSR2
         os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"] = "SIGUSR1,SIGUSR2"
@@ -323,8 +329,8 @@ class SignalHandlingTest(TestCase):
         # Verify _start was called
         mock_pcontext._start.assert_called_once()
         # Verify _stdout_tail.start() and _stderr_tail.start() were called
-        mock_pcontext._stdout_tail.start.assert_called_once()
-        mock_pcontext._stderr_tail.start.assert_called_once()
+        mock_stdout_tail.start.assert_called_once()
+        mock_stderr_tail.start.assert_called_once()


 if __name__ == "__main__":
@@ -75,6 +75,10 @@ class WorkerSpec:
         takes precedence over ``redirects`` settings.
         event_log_handler: name of the event logging handler as registered in
             `elastic/events/handlers.py <https://docs.pytorch.org/docs/stable/elastic/events.html>`_.
+        duplicate_stdout_filters: If non-empty, duplicates stdout to a file containing only lines
+            that match _any_ of the filter strings.
+        duplicate_stderr_filters: If non-empty, duplicates stderr to a file containing only lines
+            that match _any_ of the filter strings.
     """

     role: str
@@ -91,6 +95,8 @@ class WorkerSpec:
     local_addr: Optional[str] = None
     event_log_handler: str = "null"
     numa_options: Optional[NumaOptions] = None
+    duplicate_stdout_filters: Optional[list[str]] = None
+    duplicate_stderr_filters: Optional[list[str]] = None

     def __post_init__(self):
         assert self.local_world_size > 0
@@ -356,6 +356,8 @@ class LocalElasticAgent(SimpleElasticAgent):
             log_line_prefixes=log_line_prefixes,
             start_method=self._start_method,
             numa_options=spec.numa_options,
+            duplicate_stdout_filters=spec.duplicate_stdout_filters,
+            duplicate_stderr_filters=spec.duplicate_stderr_filters,
         )

         return self._pcontext.pids()
@@ -109,6 +109,8 @@ def start_processes(
     log_line_prefixes: Optional[dict[int, str]] = None,
     start_method: str = "spawn",
     numa_options: Optional[NumaOptions] = None,
+    duplicate_stdout_filters: Optional[list[str]] = None,
+    duplicate_stderr_filters: Optional[list[str]] = None,
 ) -> PContext:
     """
     Start ``n`` copies of ``entrypoint`` processes with the provided options.
@@ -130,12 +132,17 @@ def start_processes(
     this is done by default and there is no need to manually annotate
     with the ``@record`` annotation.

-    ``redirects`` and ``tee`` are bitmasks specifying which std stream(s) to redirect
-    to a log file in the ``log_dir``. Valid mask values are defined in ``Std``.
-    To redirect/tee only certain local ranks, pass ``redirects`` as a map with the key as
-    the local rank to specify the redirect behavior for.
+    Inside ``logs_specs``, ``redirects`` and ``tee`` are bitmasks specifying which std
+    stream(s) to redirect to a log file in the ``log_dir``. Valid mask values are defined
+    in ``Std``. To redirect/tee only certain local ranks, pass ``redirects`` as a map
+    with the key as the local rank to specify the redirect behavior for.
     Any missing local ranks will default to ``Std.NONE``.

+    ``duplicate_stdout_filters`` and ``duplicate_stderr_filters``, if non-empty,
+    duplicate stdouts and stderrs respectively specified in ``logs_specs``'s ``tee``
+    to a file containing only lines that match _any_ of the filter strings. The log
+    file is aggregated across all ranks selected by ``tee``.
+
     ``tee`` acts like the unix "tee" command in that it redirects + prints to console.
     To avoid worker stdout/stderr from printing to console, use the ``redirects`` parameter.
@@ -144,6 +151,8 @@ def start_processes(
     #. ``{local_rank}/error.json``: if the process failed, a file with the error info
     #. ``{local_rank}/stdout.log``: if ``redirect & STDOUT == STDOUT``
     #. ``{local_rank}/stderr.log``: if ``redirect & STDERR == STDERR``
+    #. ``filtered_stdout.log``: if ``duplicate_stdout_filters`` is non-empty
+    #. ``filtered_stderr.log``: if ``duplicate_stderr_filters`` is non-empty

     .. note:: It is expected that the ``log_dir`` exists, is empty, and is a directory.
@@ -198,9 +207,13 @@ def start_processes(
         log_dir: directory used to write log files
         start_method: multiprocessing start method (spawn, fork, forkserver)
             ignored for binaries
-        redirects: which std streams to redirect to a log file
-        tee: which std streams to redirect + print to console
+        logs_specs: defines ``log_dir``, ``redirects``, and ``tee``.
+            inside ``logs_specs``:
+            - redirects: which std streams to redirect to a log file
+            - tee: which std streams to redirect + print to console
         local_ranks_filter: which ranks' logs to print to console
+        duplicate_stdout_filters: filters for the duplicated stdout logs
+        duplicate_stderr_filters: filters for the duplicated stderr logs

     """
@@ -215,6 +228,8 @@ def start_processes(
             entrypoint=entrypoint,
             args=args,
             envs=envs,
+            duplicate_stdout_filters=duplicate_stdout_filters,
+            duplicate_stderr_filters=duplicate_stderr_filters,
             logs_specs=logs_specs,
             log_line_prefixes=log_line_prefixes,
             numa_options=numa_options,
@@ -225,6 +240,8 @@ def start_processes(
             entrypoint=entrypoint,
             args=args,
             envs=envs,
+            duplicate_stdout_filters=duplicate_stdout_filters,
+            duplicate_stderr_filters=duplicate_stderr_filters,
             log_line_prefixes=log_line_prefixes,
             start_method=start_method,
             logs_specs=logs_specs,
@@ -193,6 +193,8 @@ class LogsDest:
     tee_stdouts: dict[int, str] = field(default_factory=dict)
     tee_stderrs: dict[int, str] = field(default_factory=dict)
     error_files: dict[int, str] = field(default_factory=dict)
+    filtered_stdout: str = field(default_factory=str)
+    filtered_stderr: str = field(default_factory=str)


 class LogsSpecs(ABC):
@@ -290,6 +292,8 @@ class DefaultLogsSpecs(LogsSpecs):
         - `<log_dir>/<rdzv_run_id>/attempt_<attempt>/<rank>/stdout.log`
         - `<log_dir>/<rdzv_run_id>/attempt_<attempt>/<rank>/stderr.log`
         - `<log_dir>/<rdzv_run_id>/attempt_<attempt>/<rank>/error.json`
+        - `<log_dir>/<rdzv_run_id>/attempt_<attempt>/filtered_stdout.log`
+        - `<log_dir>/<rdzv_run_id>/attempt_<attempt>/filtered_stderr.log`
         """
         nprocs = len(envs)
         global_env = {}  # use only to query properties that are not dependent on a rank
@@ -386,7 +390,15 @@ class DefaultLogsSpecs(LogsSpecs):
             )
             envs[local_rank]["TORCHELASTIC_ERROR_FILE"] = error_file

-        return LogsDest(stdouts, stderrs, tee_stdouts, tee_stderrs, error_files)
+        return LogsDest(
+            stdouts,
+            stderrs,
+            tee_stdouts,
+            tee_stderrs,
+            error_files,
+            os.path.join(attempt_log_dir, "filtered_stdout.log"),
+            os.path.join(attempt_log_dir, "filtered_stderr.log"),
+        )

     def __repr__(self) -> str:
         return (
@@ -438,6 +450,16 @@ class PContext(abc.ABC):
     .. warning:: stdouts and stderrs should ALWAYS be a superset of
                  tee_stdouts and tee_stderrs (respectively) this is b/c
                  tee is implemented as a redirect + tail -f <stdout/stderr.log>
+
+    Args:
+        duplicate_stdout_filters:
+            If non-empty, duplicates stdouts specified in ``logs_specs``'s ``tee``
+            to a file containing only lines that match _any_ of the filter strings.
+            The log file is aggregated across all ranks selected by ``tee``.
+        duplicate_stderr_filters:
+            If non-empty, duplicates stderrs specified in ``logs_specs``'s ``tee``
+            to a file containing only lines that match _any_ of the filter strings.
+            The log file is aggregated across all ranks selected by ``tee``.
     """

     def __init__(
@@ -448,6 +470,8 @@ class PContext(abc.ABC):
         envs: dict[int, dict[str, str]],
         logs_specs: LogsSpecs,
         log_line_prefixes: Optional[dict[int, str]] = None,
+        duplicate_stdout_filters: Optional[list[str]] = None,
+        duplicate_stderr_filters: Optional[list[str]] = None,
     ):
         self.name = name
         # validate that all mappings have the same number of keys and
@@ -467,13 +491,39 @@ class PContext(abc.ABC):
         self.stderrs = logs_dest.stderrs
         self.error_files = logs_dest.error_files
         self.nprocs = nprocs
+        self.filtered_stdout = logs_dest.filtered_stdout
+        self.filtered_stderr = logs_dest.filtered_stderr

-        self._stdout_tail = TailLog(
-            name, logs_dest.tee_stdouts, sys.stdout, log_line_prefixes
-        )
-        self._stderr_tail = TailLog(
-            name, logs_dest.tee_stderrs, sys.stderr, log_line_prefixes
-        )
+        self._tail_logs = [
+            TailLog(name, logs_dest.tee_stdouts, sys.stdout, log_line_prefixes),
+            TailLog(name, logs_dest.tee_stderrs, sys.stderr, log_line_prefixes),
+        ]
+
+        if duplicate_stdout_filters:
+            self._tail_logs.append(
+                TailLog(
+                    name,
+                    logs_dest.tee_stdouts,
+                    self.filtered_stdout,
+                    log_line_prefixes,
+                    log_line_filter=lambda line: any(
+                        needle in line for needle in duplicate_stdout_filters
+                    ),
+                )
+            )
+
+        if duplicate_stderr_filters:
+            self._tail_logs.append(
+                TailLog(
+                    name,
+                    logs_dest.tee_stderrs,
+                    self.filtered_stderr,
+                    log_line_prefixes,
+                    log_line_filter=lambda line: any(
+                        needle in line for needle in duplicate_stderr_filters
+                    ),
+                )
+            )

     def start(self) -> None:
         """Start processes using parameters defined in the constructor."""
@@ -517,8 +567,8 @@ class PContext(abc.ABC):
                 "This could lead to orphaned worker processes if the torchrun is terminated."
             )
         self._start()
-        self._stdout_tail.start()
-        self._stderr_tail.start()
+        for tail_log in self._tail_logs:
+            tail_log.start()

     @abc.abstractmethod
     def _start(self) -> None:
@@ -605,10 +655,8 @@ class PContext(abc.ABC):
         if not death_sig:
             death_sig = _get_default_signal()
         self._close(death_sig=death_sig, timeout=timeout)
-        if self._stdout_tail:
-            self._stdout_tail.stop()
-        if self._stderr_tail:
-            self._stderr_tail.stop()
+        for tail_log in self._tail_logs:
+            tail_log.stop()


 def get_std_cm(std_rd: str, redirect_fn):
@@ -661,6 +709,8 @@ class MultiprocessContext(PContext):
         logs_specs: LogsSpecs,
         log_line_prefixes: Optional[dict[int, str]] = None,
         numa_options: Optional[NumaOptions] = None,
+        duplicate_stdout_filters: Optional[list[str]] = None,
+        duplicate_stderr_filters: Optional[list[str]] = None,
     ):
         super().__init__(
             name,
@@ -669,6 +719,8 @@ class MultiprocessContext(PContext):
             envs,
             logs_specs,
             log_line_prefixes,
+            duplicate_stdout_filters,
+            duplicate_stderr_filters,
         )

         self.start_method = start_method
@@ -846,6 +898,8 @@ class SubprocessContext(PContext):
         logs_specs: LogsSpecs,
         log_line_prefixes: Optional[dict[int, str]] = None,
         numa_options: Optional[NumaOptions] = None,
+        duplicate_stdout_filters: Optional[list[str]] = None,
+        duplicate_stderr_filters: Optional[list[str]] = None,
     ):
         super().__init__(
             name,
@@ -854,6 +908,8 @@ class SubprocessContext(PContext):
             envs,
             logs_specs,
             log_line_prefixes,
+            duplicate_stdout_filters,
+            duplicate_stderr_filters,
         )

         # state vector; _vdone[local_rank] -> is local_rank finished or not
@@ -12,11 +12,12 @@ import os
 import time
 from concurrent.futures.thread import ThreadPoolExecutor
 from threading import Event
-from typing import Optional, TextIO, TYPE_CHECKING
+from typing import Callable, Optional, TextIO, TYPE_CHECKING, Union


 if TYPE_CHECKING:
     from concurrent.futures._base import Future
+    from io import TextIOWrapper

 __all__ = ["tail_logfile", "TailLog"]
@@ -24,7 +25,12 @@ logger = logging.getLogger(__name__)


 def tail_logfile(
-    header: str, file: str, dst: TextIO, finished: Event, interval_sec: float
+    header: str,
+    file: str,
+    dst: TextIO,
+    finished: Event,
+    interval_sec: float,
+    log_line_filter: Optional[Callable[[str], bool]] = None,
 ):
     while not os.path.exists(file):
         if finished.is_set():
@@ -36,7 +42,8 @@ def tail_logfile(
             line = fp.readline()

             if line:
-                dst.write(f"{header}{line}")
+                if log_line_filter and log_line_filter(line):
+                    dst.write(f"{header}{line}")
             else:  # reached EOF
                 if finished.is_set():
                     # log line producer is finished
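The write is now gated on the filter; since `TailLog` (later in this diff) defaults `log_line_filter` to `lambda _: True`, plain tees still copy every line. A standalone sketch of the gate, with a hypothetical `maybe_write` helper mirroring the branch above:

```python
import io

def maybe_write(dst, header, line, log_line_filter=None):
    # Mirrors the gated write in tail_logfile: emit the line only if the filter accepts it.
    if log_line_filter and log_line_filter(line):
        dst.write(f"{header}{line}")

buf = io.StringIO()
maybe_write(buf, "[rank0]:", "loss=0.5\n", log_line_filter=lambda _: True)
maybe_write(buf, "[rank0]:", "noise\n", log_line_filter=lambda l: "loss" in l)
assert buf.getvalue() == "[rank0]:loss=0.5\n"
```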
@@ -90,9 +97,10 @@ class TailLog:
         self,
         name: str,
         log_files: dict[int, str],
-        dst: TextIO,
+        dst: Union[TextIO, str],
         log_line_prefixes: Optional[dict[int, str]] = None,
         interval_sec: float = 0.1,
+        log_line_filter: Callable[[str], bool] = (lambda _: True),
     ):
         n = len(log_files)
         self._threadpool = None
@@ -104,9 +112,22 @@ class TailLog:
         )

         self._name = name
-        self._dst = dst
+        self._dst_file: Optional[TextIOWrapper] = None
+        self._dst: Optional[Union[TextIO, TextIOWrapper]] = None
+        if isinstance(dst, str):
+            try:
+                self._dst_file = open(dst, mode="w", errors="replace")
+                self._dst = self._dst_file
+            except Exception:
+                logger.exception("error opening dst file %s.", dst)
+                self._dst = None
+                self._dst_file = None
+
+        else:
+            self._dst = dst
         self._log_files = log_files
         self._log_line_prefixes = log_line_prefixes
+        self._log_line_filter = log_line_filter
         self._finished_events: dict[int, Event] = {
             local_rank: Event() for local_rank in log_files.keys()
         }
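With `dst` now accepting a path as well as a stream, a `TailLog` that writes a filtered duplicate to disk can be built directly. A minimal usage sketch (file names and the filter string are illustrative):

```python
from torch.distributed.elastic.multiprocessing.tail_log import TailLog

# Tail two ranks' stdout logs into one filtered file on disk.
tail = TailLog(
    name="trainer",
    log_files={0: "0_stdout.log", 1: "1_stdout.log"},  # illustrative paths
    dst="filtered_stdout.log",                         # str dst -> opened as a file
    log_line_filter=lambda line: "ERROR" in line,      # keep only matching lines
).start()
# ... workers write to their log files ...
tail.stop()  # joins the tail threads and closes the opened dst file
```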
@@ -115,7 +136,7 @@ class TailLog:
         self._stopped = False

     def start(self) -> "TailLog":
-        if not self._threadpool:
+        if not self._threadpool or not self._dst:
             return self

         for local_rank, file in self._log_files.items():
@@ -130,6 +151,7 @@ class TailLog:
                     dst=self._dst,
                     finished=self._finished_events[local_rank],
                     interval_sec=self._interval_sec,
+                    log_line_filter=self._log_line_filter,
                 )
             )
         return self
@@ -152,6 +174,9 @@ class TailLog:
         if self._threadpool:
             self._threadpool.shutdown(wait=True)

+        if self._dst_file:
+            self._dst_file.close()
+
         self._stopped = True

     def stopped(self) -> bool:
@@ -71,6 +71,10 @@ class LaunchConfig:
         local_ranks_filter: ranks for which to show logs in console. If not set, show from all.
         event_log_handler: name of the event logging handler as registered in
             `elastic/events/handlers.py <https://docs.pytorch.org/docs/stable/elastic/events.html>`_.
+        duplicate_stdout_filters: If non-empty, duplicates stdout to a file containing only lines
+            that match _any_ of the filter strings.
+        duplicate_stderr_filters: If non-empty, duplicates stderr to a file containing only lines
+            that match _any_ of the filter strings.


     .. note::
@@ -98,6 +102,8 @@ class LaunchConfig:
     event_log_handler: str = "null"
     numa_options: Optional[NumaOptions] = None
     signals_to_handle: str = "SIGTERM,SIGINT,SIGHUP,SIGQUIT"
+    duplicate_stdout_filters: Optional[list[str]] = None
+    duplicate_stderr_filters: Optional[list[str]] = None

     def __post_init__(self):
         default_timeout = 900
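For programmatic launches, the new fields ride along on `LaunchConfig`. A hedged sketch, assuming the config's three required fields and using illustrative filter strings:

```python
from torch.distributed.launcher.api import LaunchConfig

# Illustrative: duplicate any teed line mentioning OOM or NCCL errors.
config = LaunchConfig(
    min_nodes=1,
    max_nodes=1,
    nproc_per_node=2,
    duplicate_stdout_filters=["CUDA out of memory"],  # illustrative filter string
    duplicate_stderr_filters=["NCCL error"],          # illustrative filter string
)
```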
@@ -214,20 +220,22 @@ def launch_agent(

     logger.info(
         "Starting elastic_operator with launch configs:\n"
-        " entrypoint : %(entrypoint)s\n"
-        " min_nodes : %(min_nodes)s\n"
-        " max_nodes : %(max_nodes)s\n"
-        " nproc_per_node : %(nproc_per_node)s\n"
-        " run_id : %(run_id)s\n"
-        " rdzv_backend : %(rdzv_backend)s\n"
-        " rdzv_endpoint : %(rdzv_endpoint)s\n"
-        " rdzv_configs : %(rdzv_configs)s\n"
-        " max_restarts : %(max_restarts)s\n"
-        " monitor_interval : %(monitor_interval)s\n"
-        " log_dir : %(log_dir)s\n"
-        " metrics_cfg : %(metrics_cfg)s\n"
-        " event_log_handler : %(event_log_handler)s\n"
-        " numa_options : %(numa_options)s\n",
+        " entrypoint : %(entrypoint)s\n"
+        " min_nodes : %(min_nodes)s\n"
+        " max_nodes : %(max_nodes)s\n"
+        " nproc_per_node : %(nproc_per_node)s\n"
+        " run_id : %(run_id)s\n"
+        " rdzv_backend : %(rdzv_backend)s\n"
+        " rdzv_endpoint : %(rdzv_endpoint)s\n"
+        " rdzv_configs : %(rdzv_configs)s\n"
+        " max_restarts : %(max_restarts)s\n"
+        " monitor_interval : %(monitor_interval)s\n"
+        " log_dir : %(log_dir)s\n"
+        " metrics_cfg : %(metrics_cfg)s\n"
+        " event_log_handler : %(event_log_handler)s\n"
+        " numa_options : %(numa_options)s\n",
+        " duplicate_stdout_filters : %(duplicate_stdout_filters)s\n",
+        " duplicate_stderr_filters : %(duplicate_stderr_filters)s\n",
         {
             "entrypoint": entrypoint_name,
             "min_nodes": config.min_nodes,
@@ -244,6 +252,8 @@ def launch_agent(
             "event_log_handler": config.event_log_handler,
             "numa_options": config.numa_options,
             "signals_to_handle": config.signals_to_handle,
+            "duplicate_stdout_filters": config.duplicate_stdout_filters,
+            "duplicate_stderr_filters": config.duplicate_stderr_filters,
         },
     )
@@ -275,6 +285,8 @@ def launch_agent(
         local_addr=config.local_addr,
         event_log_handler=config.event_log_handler,
         numa_options=config.numa_options,
+        duplicate_stdout_filters=config.duplicate_stdout_filters,
+        duplicate_stderr_filters=config.duplicate_stderr_filters,
     )

     agent = LocalElasticAgent(
@@ -399,6 +399,13 @@ def get_args_parser() -> ArgumentParser:
     """Parse the command line options."""
     parser = ArgumentParser(description="Torch Distributed Elastic Training Launcher")

+    def comma_separated_list(value):
+        placeholder = "<COMMA_PLACEHOLDER>"
+        value = value.replace(",,", placeholder)
+        items = value.split(",")
+        items = [item.replace(placeholder, ",") for item in items]
+        return items
+
     #
     # Worker/node size related arguments.
     #
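The double comma serves as an escape so a literal comma can appear inside a single filter; tracing the helper above:

```python
# ",," is swapped for a placeholder before the split on ",", so it survives
# the split and is then restored as a literal comma inside the item.
assert comma_separated_list("apple,orange") == ["apple", "orange"]
assert comma_separated_list("out of memory,,retrying") == ["out of memory,retrying"]
```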
@@ -571,6 +578,28 @@ def get_args_parser() -> ArgumentParser:
         "log files saved via --redirect or --tee",
     )

+    parser.add_argument(
+        "--duplicate-stdout-filters",
+        "--duplicate_stdout_filters",
+        action=env,
+        type=comma_separated_list,
+        default=[],
+        help="Duplicates logs streamed to stdout to another specified file with a list of filters (e.g. "
+        "[--duplicate_stdout_filters 'apple,orange'] will duplicate log lines matching 'apple' "
+        "OR 'orange'. An empty filters list won't duplicate any lines. Use double comma to escape a comma) ",
+    )
+
+    parser.add_argument(
+        "--duplicate-stderr-filters",
+        "--duplicate_stderr_filters",
+        action=env,
+        type=comma_separated_list,
+        default=[],
+        help="Duplicates logs streamed to stderr to another specified file with a list of filters (e.g. "
+        "[--duplicate_stdout_filters 'apple,orange'] will duplicate log lines matching 'apple' "
+        "OR 'orange'. An empty filters list won't duplicate any lines. Use double comma to escape a comma) ",
+    )
+
     #
     # Backwards compatible parameters with caffe2.distributed.launch.
     #
@@ -871,6 +900,8 @@ def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], list[str]]:
         event_log_handler=args.event_log_handler,
         numa_options=numa_options,
         signals_to_handle=args.signals_to_handle,
+        duplicate_stdout_filters=args.duplicate_stdout_filters,
+        duplicate_stderr_filters=args.duplicate_stderr_filters,
     )

     with_python = not args.no_python