diff --git a/test/distributed/elastic/multiprocessing/api_test.py b/test/distributed/elastic/multiprocessing/api_test.py index 458859356c4..4ac0dcacb4b 100644 --- a/test/distributed/elastic/multiprocessing/api_test.py +++ b/test/distributed/elastic/multiprocessing/api_test.py @@ -559,7 +559,7 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS): FAIL = 138 pc = start_processes( name="echo", - entrypoint=bin("echo1.py"), + entrypoint=bin("echo4.py"), args={0: ("--exitcode", FAIL, "foo"), 1: ("--exitcode", 0, "bar")}, envs={0: {"RANK": "0"}, 1: {"RANK": "1"}}, logs_specs=DefaultLogsSpecs( diff --git a/test/distributed/elastic/multiprocessing/bin/echo1.py b/test/distributed/elastic/multiprocessing/bin/echo1.py index 5ffa5bd9045..8bcd574e8d8 100755 --- a/test/distributed/elastic/multiprocessing/bin/echo1.py +++ b/test/distributed/elastic/multiprocessing/bin/echo1.py @@ -9,7 +9,6 @@ import argparse import os import sys -import time if __name__ == "__main__": @@ -24,6 +23,5 @@ if __name__ == "__main__": print(f"exit {exitcode} from {rank}", file=sys.stderr) sys.exit(exitcode) else: - time.sleep(1000) print(f"{args.msg} stdout from {rank}") print(f"{args.msg} stderr from {rank}", file=sys.stderr) diff --git a/test/distributed/elastic/multiprocessing/bin/echo4.py b/test/distributed/elastic/multiprocessing/bin/echo4.py new file mode 100755 index 00000000000..5ffa5bd9045 --- /dev/null +++ b/test/distributed/elastic/multiprocessing/bin/echo4.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os +import sys +import time + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="test binary, exits with exitcode") + parser.add_argument("--exitcode", type=int, default=0) + parser.add_argument("msg", type=str) + args = parser.parse_args() + + rank = int(os.environ["RANK"]) + exitcode = args.exitcode + if exitcode != 0: + print(f"exit {exitcode} from {rank}", file=sys.stderr) + sys.exit(exitcode) + else: + time.sleep(1000) + print(f"{args.msg} stdout from {rank}") + print(f"{args.msg} stderr from {rank}", file=sys.stderr)