diff --git a/test/run_test.py b/test/run_test.py
index 097dc5e8f7c..b13ec37f765 100755
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -24,6 +24,7 @@ from torch.testing._internal.common_utils import (
     IS_CI,
     is_slow_gradcheck_env,
     parser as common_parser,
+    retry_shell,
     set_cwd,
     shell,
     TEST_WITH_ROCM,
@@ -41,6 +42,8 @@ try:
         get_reordered_tests,
         get_test_case_configs,
         NUM_PROCS,
+        ShardedTest,
+        THRESHOLD,
     )
 
     HAVE_TEST_SELECTION_TOOLS = True
@@ -51,6 +54,9 @@ except ImportError:
     )
 
 
+RERUN_DISABLED_TESTS = os.getenv("PYTORCH_TEST_RERUN_DISABLED_TESTS", "0") == "1"
+
+
 # Note [ROCm parallel CI testing]
 # https://github.com/pytorch/pytorch/pull/85770 added file-granularity parallel testing.
 # In .ci/pytorch/test.sh, TEST_CONFIG == "default", CUDA and HIP_VISIBLE_DEVICES is set to 0.
@@ -277,6 +283,13 @@ CI_SERIAL_LIST = [
     "_nvfuser/test_torchscript",  # OOM on test_issue_1785
     "test_schema_check",  # Cause CUDA illegal memory access https://github.com/pytorch/pytorch/issues/95749
     "functorch/test_memory_efficient_fusion",  # Cause CUDA OOM on ROCm
+    "test_utils",  # OOM
+    "test_sort_and_select",  # OOM
+    "test_backward_compatible_arguments",  # OOM
+    "test_module_init",  # OOM
+    "test_autocast",  # OOM
+    "test_native_mha",  # OOM
+    "test_module_hooks",  # OOM
 ]
 
 # A subset of our TEST list that validates PyTorch's ops, modules, and autograd function as expected
@@ -291,19 +304,6 @@ CORE_TEST_LIST = [
     "test_torch",
 ]
 
-# A list of distributed tests that run on multiple backends, i.e. gloo, nccl. These backends are spread out
-# among all available shards to speed up the test. The list of backends are copied from the tests themselves
-DISTRIBUTED_TESTS_WITH_MULTIPLE_BACKENDS = {
-    "distributed/test_distributed_spawn": [
-        "gloo",
-        "nccl",
-        "ucc",
-    ],
-    "distributed/algorithms/quantization/test_quantization": [
-        "gloo",
-        "nccl",
-    ],
-}
 
 # if a test file takes longer than 5 min, we add it to TARGET_DET_LIST
 SLOW_TEST_THRESHOLD = 300
@@ -406,9 +406,19 @@ def run_test(
 ) -> int:
     maybe_set_hip_visible_devies()
     unittest_args = options.additional_unittest_args.copy()
+    test_file = test_module
+    if isinstance(test_file, ShardedTest):
+        unittest_args.extend(
+            [
+                f"--shard-id={test_module.shard - 1}",
+                f"--num-shards={test_module.num_shards}",
+            ]
+        )
+        test_file = test_module.name
+
     if options.verbose:
         unittest_args.append(f'-{"v"*options.verbose}')  # in case of pytest
-    if test_module in RUN_PARALLEL_BLOCKLIST:
+    if test_file in RUN_PARALLEL_BLOCKLIST:
         unittest_args = [
             arg for arg in unittest_args if not arg.startswith("--run-parallel")
         ]
@@ -422,11 +432,11 @@ def run_test(
         unittest_args = [arg if arg != "-f" else "-x" for arg in unittest_args]
     if IS_CI:
         ci_args = ["--import-slow-tests", "--import-disabled-tests"]
-        if os.getenv("PYTORCH_TEST_RERUN_DISABLED_TESTS", "0") == "1":
+        if RERUN_DISABLED_TESTS:
             ci_args.append("--rerun-disabled-tests")
         # use the downloaded test cases configuration, not supported in pytest
         unittest_args.extend(ci_args)
-    if test_module in PYTEST_SKIP_RETRIES:
+    if test_file in PYTEST_SKIP_RETRIES:
         if not options.pytest:
             raise RuntimeError(
                 "A test running without pytest cannot skip retries using "
@@ -439,19 +449,35 @@ def run_test(
     # Can't call `python -m unittest test_*` here because it doesn't run code
     # in `if __name__ == '__main__': `. So call `python test_*.py` instead.
-    argv = [test_module + ".py"] + unittest_args
+    argv = [test_file + ".py"] + unittest_args
 
     os.makedirs(REPO_ROOT / "test" / "test-reports", exist_ok=True)
     log_fd, log_path = tempfile.mkstemp(
         dir=REPO_ROOT / "test" / "test-reports",
-        prefix="{}_".format(test_module.replace("\\", "-").replace("/", "-")),
+        prefix="{}_".format(test_file.replace("\\", "-").replace("/", "-")),
         suffix=".log",
     )
     os.close(log_fd)
 
     command = (launcher_cmd or []) + executable + argv
+    should_file_rerun = (
+        "--subprocess" not in command
+        and not RERUN_DISABLED_TESTS
+        and isinstance(test_module, ShardedTest)
+        and test_module.time is not None
+    )
+    timeout = THRESHOLD * 2 if should_file_rerun else None
     print_to_stderr("Executing {} ... [{}]".format(command, datetime.now()))
     with open(log_path, "w") as f:
-        ret_code = shell(command, test_directory, stdout=f, stderr=f, env=env)
+        ret_code = retry_shell(
+            command,
+            test_directory,
+            stdout=f,
+            stderr=f,
+            env=env,
+            timeout=timeout,
+            retries=1 if should_file_rerun else 0,
+        )
+
     print_log_file(test_module, log_path, failed=(ret_code != 0))
     os.remove(log_path)
     return ret_code
@@ -549,33 +575,12 @@ def test_distributed(test_module, test_directory, options):
     if options.verbose and not mpi_available:
         print_to_stderr("MPI not available -- MPI backend tests will be skipped")
 
-    if options.shard:
-        which_shard, num_shards = options.shard
-    else:
-        which_shard = num_shards = 1
-    # Round-robin all backends to different shards
-    backend_to_shard = {
-        backend: i % num_shards + 1
-        for i, backend in enumerate(
-            DISTRIBUTED_TESTS_WITH_MULTIPLE_BACKENDS[test_module]
-        )
-    }
-    print_to_stderr(
-        f"Map different backends to different shards for {test_module}: {backend_to_shard}"
-    )
-
     config = DISTRIBUTED_TESTS_CONFIG
     for backend, env_vars in config.items():
         if sys.platform == "win32" and backend != "gloo":
             continue
         if backend == "mpi" and not mpi_available:
             continue
-        # Default to the first shard if seeing an unrecognized backend
-        if which_shard != backend_to_shard.get(backend, 1):
-            print_to_stderr(
-                f"Shard {which_shard}: {backend} should be run in {backend_to_shard.get(backend, 1)}"
-            )
-            continue
         for with_init_file in {True, False}:
             if sys.platform == "win32" and not with_init_file:
                 continue
@@ -584,8 +589,8 @@ def test_distributed(test_module, test_directory, options):
             init_str = "with {} init_method"
             with_init = init_str.format("file" if with_init_file else "env")
             print_to_stderr(
-                "Running distributed tests for the {} backend {} in shard {} of {}".format(
-                    backend, with_init, which_shard, num_shards
+                "Running distributed tests for the {} backend {}".format(
+                    backend, with_init
                 )
             )
             old_environ = dict(os.environ)
@@ -594,7 +599,7 @@ def test_distributed(test_module, test_directory, options):
             os.environ["INIT_METHOD"] = "env://"
             os.environ.update(env_vars)
             if with_init_file:
-                if test_module == "test_distributed_spawn":
+                if test_module.name == "test_distributed_spawn":
                     init_method = f"{FILE_SCHEMA}{tmp_dir}/"
                 else:
                     init_method = f"{FILE_SCHEMA}{tmp_dir}/shared_init_file"
@@ -778,6 +783,7 @@ def run_doctests(test_module, test_directory, options):
 
 def print_log_file(test: str, file_path: str, failed: bool) -> None:
     num_lines = sum(1 for _ in open(file_path, "rb"))
+    test = str(test)
     n = 100
     with open(file_path, "r") as f:
         print_to_stderr("")
@@ -805,7 +811,7 @@ def print_log_file(test: str, file_path: str, failed: bool) -> None:
 
 
 def get_pytest_args(options):
-    if os.getenv("PYTORCH_TEST_RERUN_DISABLED_TESTS", "0") == "1":
+    if RERUN_DISABLED_TESTS:
        # When under rerun-disabled-tests mode, run the same tests multiple times to determine their
        # flakiness status. Default to 50 re-runs
        rerun_options = ["--flake-finder", "--flake-runs=50"]
@@ -815,7 +821,7 @@ def get_pytest_args(options):
     else:
         # When under the normal mode, retry a failed test 2 more times. -x means stop at the first
         # failure
-        rerun_options = ["-x", "--reruns=2"]
+        rerun_options = ["-x", "--reruns=2", "--sw"]
 
     pytest_args = [
         "--use-pytest",
@@ -828,55 +834,6 @@ def get_pytest_args(options):
     return pytest_args
 
 
-def run_test_ops(test_module, test_directory, options):
-    default_unittest_args = get_pytest_args(options)
-
-    return_codes = []
-    os.environ["NUM_PARALLEL_PROCS"] = str(NUM_PROCS)
-    pool = get_context("spawn").Pool(NUM_PROCS)
-    for i in range(NUM_PROCS):
-        extra_unittest_args = default_unittest_args.copy()
-        extra_unittest_args.extend(
-            [
-                f"--shard-id={i}",
-                f"--num-shards={NUM_PROCS}",
-                "-k=not _linalg_cholesky_",
-            ]
-        )
-
-        return_code = pool.apply_async(
-            run_test,
-            args=(test_module, test_directory, copy.deepcopy(options)),
-            kwds={
-                "extra_unittest_args": extra_unittest_args,
-            },
-        )
-        return_codes.append(return_code)
-
-    pool.close()
-    pool.join()
-    del os.environ["NUM_PARALLEL_PROCS"]
-
-    for return_code in return_codes:
-        if return_code.get() != 0:
-            return return_code.get()
-
-    extra_unittest_args = default_unittest_args.copy()
-    extra_unittest_args.extend(
-        [
-            "-k=_linalg_cholesky_",
-        ]
-    )
-
-    return_code = run_test(
-        test_module,
-        test_directory,
-        copy.deepcopy(options),
-        extra_unittest_args=extra_unittest_args,
-    )
-    return return_code
-
-
 CUSTOM_HANDLERS = {
     "test_cuda_primary_ctx": test_cuda_primary_ctx,
     "test_cuda_nvml_based_avail": get_run_test_with_subprocess_fn(),
@@ -898,15 +855,6 @@ CUSTOM_HANDLERS = {
     "distributed/rpc/test_share_memory": get_run_test_with_subprocess_fn(),
     "distributed/rpc/cuda/test_tensorpipe_agent": get_run_test_with_subprocess_fn(),
     "doctests": run_doctests,
-    "inductor/test_torchinductor_opinfo": run_test_ops,
-    "test_ops": run_test_ops,
-    "test_ops_gradients": run_test_ops,
-    "test_ops_fwd_gradients": run_test_ops,
-    "test_ops_jit": run_test_ops,
-    "functorch/test_ops": run_test_ops,
-    # not a test_ops file, but takes 2 hrs on some architectures and
-    # run_test_ops is good at parallelizing things
-    "test_decomp": run_test_ops,
 }
 
 
@@ -1179,13 +1127,7 @@ def get_selected_tests(options):
 
     if options.distributed_tests:
         selected_tests = list(
-            filter(
-                lambda test_name: (
-                    test_name in DISTRIBUTED_TESTS
-                    and test_name not in DISTRIBUTED_TESTS_WITH_MULTIPLE_BACKENDS
-                ),
-                selected_tests,
-            )
+            filter(lambda test_name: test_name in DISTRIBUTED_TESTS, selected_tests)
         )
 
     # Filter to only run core tests when --core option is specified
@@ -1246,46 +1188,6 @@ def get_selected_tests(options):
     elif TEST_WITH_ROCM:
         selected_tests = exclude_tests(ROCM_BLOCKLIST, selected_tests, "on ROCm")
 
-    # sharding
-    if options.shard:
-        assert len(options.shard) == 2, "Unexpected shard format"
-        assert min(options.shard) > 0, "Shards must be positive numbers"
-        which_shard, num_shards = options.shard
-        assert (
-            which_shard <= num_shards
-        ), "Selected shard must be less than or equal to total number of shards"
-        assert num_shards <= len(
-            selected_tests
-        ), f"Number of shards must be less than {len(selected_tests)}"
-
-        if num_shards == 1:
-            return selected_tests
-
-        # Download previous test times to make sharding decisions
-        path = os.path.join(str(REPO_ROOT), TEST_TIMES_FILE)
-        if os.path.exists(path):
-            with open(path, "r") as f:
-                test_file_times = cast(Dict[str, Any], json.load(f))
-        else:
-            test_file_times = {}
-        test_config = os.environ.get("TEST_CONFIG")
-        if test_config not in test_file_times:
-            print(
-                "::warning:: Gathered no stats from artifacts. Proceeding with default sharding plan."
-            )
-            selected_tests = selected_tests[which_shard - 1 :: num_shards]
-        else:
-            print("Found test time stats from artifacts")
-            test_file_times_config = test_file_times[test_config]
-            shards = calculate_shards(
-                num_shards,
-                selected_tests,
-                test_file_times_config,
-                must_serial=must_serial,
-            )
-            _, tests_from_shard = shards[which_shard - 1]
-            selected_tests = tests_from_shard
-
     # skip all distributed tests if distributed package is not available.
     if not dist.is_available():
         selected_tests = exclude_tests(
@@ -1311,30 +1213,61 @@ def get_selected_tests(options):
             exact_match=True,
         )
 
-    if options.distributed_tests:
-        # Run distributed tests with multiple backends across all shards, one per backend
-        selected_tests.extend(DISTRIBUTED_TESTS_WITH_MULTIPLE_BACKENDS.keys())
-        selected_tests.reverse()
+    selected_tests = [parse_test_module(x) for x in selected_tests]
+
+    # sharding
+    which_shard, num_shards = 1, 1
+    if options.shard:
+        assert len(options.shard) == 2, "Unexpected shard format"
+        assert min(options.shard) > 0, "Shards must be positive numbers"
+        which_shard, num_shards = options.shard
+        assert (
+            which_shard <= num_shards
+        ), "Selected shard must be less than or equal to total number of shards"
+        assert num_shards <= len(
+            selected_tests
+        ), f"Number of shards must be less than {len(selected_tests)}"
+
+    # Download previous test times to make sharding decisions
+    path = os.path.join(str(REPO_ROOT), TEST_TIMES_FILE)
+    if os.path.exists(path):
+        with open(path, "r") as f:
+            test_file_times = cast(Dict[str, Any], json.load(f))
+    else:
+        test_file_times = {}
+    test_config = os.environ.get("TEST_CONFIG")
+    if test_config not in test_file_times:
+        print(
+            "::warning:: Gathered no stats from artifacts. Proceeding with default sharding plan."
+        )
+    else:
+        print("Found test time stats from artifacts")
+
+    # Do sharding
+    test_file_times_config = test_file_times.get(test_config, {})
+    shards = calculate_shards(
+        num_shards, selected_tests, test_file_times_config, must_serial=must_serial
+    )
+    _, tests_from_shard = shards[which_shard - 1]
+    selected_tests = tests_from_shard
 
     return selected_tests
 
 
-def run_test_module(test: str, test_directory: str, options) -> Optional[str]:
+def run_test_module(test: ShardedTest, test_directory: str, options) -> Optional[str]:
     maybe_set_hip_visible_devies()
 
-    test_module = parse_test_module(test)
-
     # Printing the date here can help diagnose which tests are slow
-    print_to_stderr("Running {} ... [{}]".format(test, datetime.now()))
-    handler = CUSTOM_HANDLERS.get(test_module, run_test)
-    return_code = handler(test_module, test_directory, options)
+    print_to_stderr("Running {} ... [{}]".format(str(test), datetime.now()))
+    handler = CUSTOM_HANDLERS.get(test.name, run_test)
+    return_code = handler(test, test_directory, options)
     assert isinstance(return_code, int) and not isinstance(
         return_code, bool
-    ), f"While running {test} got non integer return code {return_code}"
+    ), f"While running {str(test)} got non integer return code {return_code}"
     if return_code == 0:
         return None
 
-    message = f"{test} failed!"
+    message = f"{str(test)} failed!"
     if return_code < 0:
         # subprocess.Popen returns the child process' exit signal as
         # return code -N, where N is the signal number.
@@ -1350,7 +1283,9 @@ def main():
     selected_tests = get_selected_tests(options)
 
     if options.verbose:
-        print_to_stderr("Selected tests:\n {}".format("\n ".join(selected_tests)))
+        print_to_stderr(
+            "Selected tests:\n {}".format("\n ".join(str(x) for x in selected_tests))
+        )
 
     if options.dry_run:
         return
@@ -1372,18 +1307,18 @@ def main():
 
     # parallel = in parallel with other files
     # serial = this file on it's own. The file might still be run in parallel with itself (ex test_ops)
-    selected_tests_parallel = [x for x in selected_tests if not must_serial(x)]
+    selected_tests_parallel = [x for x in selected_tests if not must_serial(x.name)]
     selected_tests_serial = [
         x for x in selected_tests if x not in selected_tests_parallel
     ]
     print_to_stderr(
         "parallel (file granularity) tests:\n {}".format(
-            "\n ".join(selected_tests_parallel)
+            "\n ".join(str(x) for x in selected_tests_parallel)
         )
     )
     print_to_stderr(
         "serial (file granularity) tests:\n {}".format(
-            "\n ".join(selected_tests_serial)
+            "\n ".join(str(x) for x in selected_tests_serial)
        )
    )
 
@@ -1403,7 +1338,7 @@ def main():
         return False
 
     try:
-        os.environ["PARALLEL_TESTING"] = "1"
+        os.environ["NUM_PARALLEL_PROCS"] = str(NUM_PROCS)
         for test in selected_tests_parallel:
             options_clone = copy.deepcopy(options)
             if can_run_in_pytest(test):
@@ -1415,7 +1350,7 @@ def main():
             )
         pool.close()
         pool.join()
-        del os.environ["PARALLEL_TESTING"]
+        del os.environ["NUM_PARALLEL_PROCS"]
 
         if not options.continue_through_error and len(failure_messages) != 0:
             raise RuntimeError(
diff --git a/tools/test/test_test_selections.py b/tools/test/test_test_selections.py
index 23f05cb99fe..f8aa78b9946 100644
--- a/tools/test/test_test_selections.py
+++ b/tools/test/test_test_selections.py
@@ -1,8 +1,18 @@
+import pathlib
 import random
+import sys
 import unittest
+from collections import defaultdict
 from typing import Dict, List, Tuple
 
-from tools.testing.test_selections import calculate_shards
+REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent
+try:
+    # using tools/ to optimize test run.
+    sys.path.append(str(REPO_ROOT))
+    from tools.testing.test_selections import calculate_shards, ShardedTest, THRESHOLD
+except ModuleNotFoundError:
+    print("Can't import required modules, exiting")
+    exit(1)
 
 
 class TestCalculateShards(unittest.TestCase):
@@ -36,8 +46,8 @@ class TestCalculateShards(unittest.TestCase):
 
     def assert_shards_equal(
         self,
-        expected_shards: List[Tuple[float, List[str]]],
-        actual_shards: List[Tuple[float, List[str]]],
+        expected_shards: List[Tuple[float, List[ShardedTest]]],
+        actual_shards: List[Tuple[float, List[ShardedTest]]],
     ) -> None:
         for expected, actual in zip(expected_shards, actual_shards):
             self.assertAlmostEqual(expected[0], actual[0])
@@ -45,19 +55,25 @@ class TestCalculateShards(unittest.TestCase):
     def test_calculate_2_shards_with_complete_test_times(self) -> None:
         expected_shards = [
-            (60, ["super_long_test", "normal_test3"]),
+            (
+                60.0,
+                [
+                    ShardedTest(name="super_long_test", shard=1, num_shards=1, time=55),
+                    ShardedTest(name="normal_test3", shard=1, num_shards=1, time=5),
+                ],
+            ),
             (
                 58.31,
                 [
-                    "long_test1",
-                    "long_test2",
-                    "normal_test1",
-                    "normal_test2",
-                    "short_test1",
-                    "short_test2",
-                    "short_test3",
-                    "short_test4",
-                    "short_test5",
+                    ShardedTest(name="long_test1", shard=1, num_shards=1, time=22),
+                    ShardedTest(name="long_test2", shard=1, num_shards=1, time=18),
+                    ShardedTest(name="normal_test1", shard=1, num_shards=1, time=9),
+                    ShardedTest(name="normal_test2", shard=1, num_shards=1, time=7),
+                    ShardedTest(name="short_test1", shard=1, num_shards=1, time=1),
+                    ShardedTest(name="short_test2", shard=1, num_shards=1, time=0.6),
+                    ShardedTest(name="short_test3", shard=1, num_shards=1, time=0.4),
+                    ShardedTest(name="short_test4", shard=1, num_shards=1, time=0.3),
+                    ShardedTest(name="short_test5", shard=1, num_shards=1, time=0.01),
                 ],
             ),
         ]
         self.assert_shards_equal(
             expected_shards, calculate_shards(2, self.tests, self.test_times)
         )
@@ -70,19 +86,19 @@ class TestCalculateShards(unittest.TestCase):
             (
                 118.31,
                 [
-                    "super_long_test",
-                    "long_test1",
-                    "long_test2",
-                    "normal_test1",
-                    "normal_test2",
-                    "normal_test3",
-                    "short_test1",
-                    "short_test2",
-                    "short_test3",
-                    "short_test4",
-                    "short_test5",
+                    ShardedTest(name="super_long_test", shard=1, num_shards=1, time=55),
+                    ShardedTest(name="long_test1", shard=1, num_shards=1, time=22),
+                    ShardedTest(name="long_test2", shard=1, num_shards=1, time=18),
+                    ShardedTest(name="normal_test1", shard=1, num_shards=1, time=9),
+                    ShardedTest(name="normal_test2", shard=1, num_shards=1, time=7),
+                    ShardedTest(name="normal_test3", shard=1, num_shards=1, time=5),
+                    ShardedTest(name="short_test1", shard=1, num_shards=1, time=1),
+                    ShardedTest(name="short_test2", shard=1, num_shards=1, time=0.6),
+                    ShardedTest(name="short_test3", shard=1, num_shards=1, time=0.4),
+                    ShardedTest(name="short_test4", shard=1, num_shards=1, time=0.3),
+                    ShardedTest(name="short_test5", shard=1, num_shards=1, time=0.01),
                 ],
-            ),
+            )
         ]
         self.assert_shards_equal(
             expected_shards, calculate_shards(1, self.tests, self.test_times)
         )
@@ -90,31 +106,30 @@ class TestCalculateShards(unittest.TestCase):
     def test_calculate_5_shards_with_complete_test_times(self) -> None:
         expected_shards = [
-            (55.0, ["super_long_test"]),
             (
-                22.0,
-                [
-                    "long_test1",
-                ],
-            ),
-            (
-                18.0,
-                [
-                    "long_test2",
-                ],
+                55.0,
+                [ShardedTest(name="super_long_test", shard=1, num_shards=1, time=55)],
             ),
+            (22.0, [ShardedTest(name="long_test1", shard=1, num_shards=1, time=22)]),
+            (18.0, [ShardedTest(name="long_test2", shard=1, num_shards=1, time=18)]),
             (
                 11.31,
                 [
-                    "normal_test1",
-                    "short_test1",
-                    "short_test2",
-                    "short_test3",
-                    "short_test4",
-                    "short_test5",
+                    ShardedTest(name="normal_test1", shard=1, num_shards=1, time=9),
+                    ShardedTest(name="short_test1", shard=1, num_shards=1, time=1),
+                    ShardedTest(name="short_test2", shard=1, num_shards=1, time=0.6),
+                    ShardedTest(name="short_test3", shard=1, num_shards=1, time=0.4),
+                    ShardedTest(name="short_test4", shard=1, num_shards=1, time=0.3),
+                    ShardedTest(name="short_test5", shard=1, num_shards=1, time=0.01),
+                ],
+            ),
+            (
+                12.0,
+                [
+                    ShardedTest(name="normal_test2", shard=1, num_shards=1, time=7),
+                    ShardedTest(name="normal_test3", shard=1, num_shards=1, time=5),
                 ],
             ),
-            (12.0, ["normal_test2", "normal_test3"]),
         ]
         self.assert_shards_equal(
             expected_shards, calculate_shards(5, self.tests, self.test_times)
         )
@@ -128,22 +143,24 @@ class TestCalculateShards(unittest.TestCase):
             (
                 22.0,
                 [
-                    "long_test1",
-                    "long_test2",
-                    "normal_test3",
-                    "short_test3",
-                    "short_test5",
+                    ShardedTest(name="long_test1", shard=1, num_shards=1, time=22),
+                    ShardedTest(name="long_test2", shard=1, num_shards=1, time=None),
+                    ShardedTest(name="normal_test3", shard=1, num_shards=1, time=None),
+                    ShardedTest(name="short_test3", shard=1, num_shards=1, time=None),
+                    ShardedTest(name="short_test5", shard=1, num_shards=1, time=None),
                 ],
             ),
             (
                 10.0,
                 [
-                    "normal_test1",
-                    "short_test1",
-                    "super_long_test",
-                    "normal_test2",
-                    "short_test2",
-                    "short_test4",
+                    ShardedTest(name="normal_test1", shard=1, num_shards=1, time=9),
+                    ShardedTest(name="short_test1", shard=1, num_shards=1, time=1),
+                    ShardedTest(
+                        name="super_long_test", shard=1, num_shards=1, time=None
+                    ),
+                    ShardedTest(name="normal_test2", shard=1, num_shards=1, time=None),
+                    ShardedTest(name="short_test2", shard=1, num_shards=1, time=None),
+                    ShardedTest(name="short_test4", shard=1, num_shards=1, time=None),
                 ],
             ),
         ]
@@ -156,19 +173,133 @@ class TestCalculateShards(unittest.TestCase):
             k: v for k, v in self.test_times.items() if "test1" in k
         }
         expected_shards = [
-            (22.0, ["long_test1", "normal_test2", "short_test5"]),
-            (9.0, ["normal_test1", "normal_test3"]),
-            (1.0, ["short_test1", "short_test2"]),
-            (0.0, ["super_long_test", "short_test3"]),
-            (0.0, ["long_test2", "short_test4"]),
+            (
+                22.0,
+                [
+                    ShardedTest(name="long_test1", shard=1, num_shards=1, time=22),
+                    ShardedTest(name="normal_test2", shard=1, num_shards=1, time=None),
+                    ShardedTest(name="short_test5", shard=1, num_shards=1, time=None),
+                ],
+            ),
+            (
+                9.0,
+                [
+                    ShardedTest(name="normal_test1", shard=1, num_shards=1, time=9),
+                    ShardedTest(name="normal_test3", shard=1, num_shards=1, time=None),
+                ],
+            ),
+            (
+                1.0,
+                [
+                    ShardedTest(name="short_test1", shard=1, num_shards=1, time=1),
+                    ShardedTest(name="short_test2", shard=1, num_shards=1, time=None),
+                ],
+            ),
+            (
+                0.0,
+                [
+                    ShardedTest(
+                        name="super_long_test", shard=1, num_shards=1, time=None
+                    ),
+                    ShardedTest(name="short_test3", shard=1, num_shards=1, time=None),
+                ],
+            ),
+            (
+                0.0,
+                [
+                    ShardedTest(name="long_test2", shard=1, num_shards=1, time=None),
+                    ShardedTest(name="short_test4", shard=1, num_shards=1, time=None),
+                ],
+            ),
         ]
         self.assert_shards_equal(
             expected_shards, calculate_shards(5, self.tests, incomplete_test_times)
         )
 
-    def test_calculate_2_shards_against_optimal_shards(self) -> None:
+    def test_split_shards(self) -> None:
+        test_times: Dict[str, float] = {"test1": THRESHOLD, "test2": THRESHOLD}
+        expected_shards = [
+            (600.0, [ShardedTest(name="test1", shard=1, num_shards=1, time=THRESHOLD)]),
+            (600.0, [ShardedTest(name="test2", shard=1, num_shards=1, time=THRESHOLD)]),
+        ]
+        self.assert_shards_equal(
+            expected_shards, calculate_shards(2, list(test_times.keys()), test_times)
+        )
+
+        test_times = {"test1": THRESHOLD * 4, "test2": THRESHOLD * 2.5}
+        expected_shards = [
+            (
+                2200.0,
+                [
+                    ShardedTest(name="test1", shard=1, num_shards=4, time=600.0),
+                    ShardedTest(name="test1", shard=3, num_shards=4, time=600.0),
+                    ShardedTest(name="test2", shard=1, num_shards=3, time=500.0),
+                    ShardedTest(name="test2", shard=3, num_shards=3, time=500.0),
+                ],
+            ),
+            (
+                1700.0,
+                [
+                    ShardedTest(name="test1", shard=2, num_shards=4, time=600.0),
+                    ShardedTest(name="test1", shard=4, num_shards=4, time=600.0),
+                    ShardedTest(name="test2", shard=2, num_shards=3, time=500.0),
+                ],
+            ),
+        ]
+        self.assert_shards_equal(
+            expected_shards, calculate_shards(2, list(test_times.keys()), test_times)
+        )
+
+        test_times = {"test1": THRESHOLD / 2, "test2": THRESHOLD}
+        expected_shards = [
+            (600.0, [ShardedTest(name="test2", shard=1, num_shards=1, time=THRESHOLD)]),
+            (
+                300.0,
+                [ShardedTest(name="test1", shard=1, num_shards=1, time=THRESHOLD / 2)],
+            ),
+        ]
+        self.assert_shards_equal(
+            expected_shards, calculate_shards(2, list(test_times.keys()), test_times)
+        )
+
+    def test_split_shards_random(self) -> None:
+        random.seed(120)
+        for _ in range(100):
+            num_shards = random.randint(1, 10)
+            num_tests = random.randint(1, 100)
+            random_times: Dict[str, float] = {
+                str(i): random.randint(0, THRESHOLD * 10) for i in range(num_tests)
+            }
+
+            shards = calculate_shards(
+                num_shards, list(random_times.keys()), random_times
+            )
+
+            times = [x[0] for x in shards]
+            max_diff = max(times) - min(times)
+            self.assertTrue(max_diff <= THRESHOLD)
+
+            all_sharded_tests = defaultdict(list)
+            for time, sharded_tests in shards:
+                self.assertEqual(time, sum(x.time for x in sharded_tests))
+                for sharded_test in sharded_tests:
+                    all_sharded_tests[sharded_test.name].append(sharded_test)
+
+            self.assertListEqual(
+                sorted(random_times.keys()), sorted(all_sharded_tests.keys())
+            )
+            for test, sharded_tests in all_sharded_tests.items():
+                self.assertAlmostEqual(
+                    random_times[test], sum(x.time or 0 for x in sharded_tests)
+                )
+                self.assertListEqual(
+                    list(range(sharded_tests[0].num_shards)),
+                    sorted(x.shard - 1 for x in sharded_tests),
+                )
+
+    def test_calculate_2_shards_against_optimal_shards(self) -> None:
+        random.seed(120)
         for _ in range(100):
-            random.seed(120)
             random_times = {k: random.random() * 10 for k in self.tests}
             # all test times except first two
             rest_of_tests = [
@@ -194,7 +325,7 @@ class TestCalculateShards(unittest.TestCase):
                 calculated_shards[0][1] + calculated_shards[1][1]
             )
             # All the tests should be represented by some shard
-            self.assertEqual(sorted_tests, sorted_shard_tests)
+            self.assertEqual(sorted_tests, [x.name for x in sorted_shard_tests])
 
 
 if __name__ == "__main__":
diff --git a/tools/testing/test_selections.py b/tools/testing/test_selections.py
index bde066de7a6..013cc05d242 100644
--- a/tools/testing/test_selections.py
+++ b/tools/testing/test_selections.py
@@ -1,13 +1,15 @@
+import math
 import os
 import subprocess
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Callable, Dict, List, NamedTuple, Optional, Tuple
 
 from tools.stats.import_test_stats import get_disabled_tests, get_slow_tests
 
 
 IS_MEM_LEAK_CHECK = os.getenv("PYTORCH_TEST_CUDA_MEM_LEAK_CHECK", "0") == "1"
 
 NUM_PROCS = 1 if IS_MEM_LEAK_CHECK else 2
+THRESHOLD = 60 * 10  # 10 minutes
 
 # See Note [ROCm parallel CI testing]
 # Special logic for ROCm GHA runners to query number of GPUs available.
@@ -30,43 +32,73 @@ if os.path.exists("/opt/rocm") and not IS_MEM_LEAK_CHECK:
             NUM_PROCS = 1
 
 
+class ShardedTest(NamedTuple):
+    name: str
+    shard: int
+    num_shards: int
+    time: Optional[float]
+
+    def __str__(self) -> str:
+        return f"{self.name} {self.shard}/{self.num_shards}"
+
+    def get_time(self) -> float:
+        return self.time or 0
+
+
 class ShardJob:
-    def __init__(self, test_times: Dict[str, float]):
-        self.test_times = test_times
-        self.serial: List[str] = []
-        self.parallel: List[str] = []
+    def __init__(self) -> None:
+        self.serial: List[ShardedTest] = []
+        self.parallel: List[ShardedTest] = []
 
     def get_total_time(self) -> float:
         procs = [0.0 for _ in range(NUM_PROCS)]
         for test in self.parallel:
-            test_time = self.test_times.get(test, 0)
             min_index = procs.index(min(procs))
-            procs[min_index] += test_time
-        time = max(procs) + sum(self.test_times.get(test, 0) for test in self.serial)
+            procs[min_index] += test.get_time()
+        time = max(procs) + sum(test.get_time() for test in self.serial)
         return time
 
-    def convert_to_tuple(self) -> Tuple[float, List[str]]:
+    def convert_to_tuple(self) -> Tuple[float, List[ShardedTest]]:
         return (self.get_total_time(), self.serial + self.parallel)
 
 
+def get_with_pytest_shard(
+    tests: List[str], test_file_times: Dict[str, float]
+) -> List[ShardedTest]:
+    sharded_tests: List[ShardedTest] = []
+    for test in tests:
+        duration = test_file_times[test]
+        if duration > THRESHOLD:
+            num_shards = math.ceil(duration / THRESHOLD)
+            for i in range(num_shards):
+                sharded_tests.append(
+                    ShardedTest(test, i + 1, num_shards, duration / num_shards)
+                )
+        else:
+            sharded_tests.append(ShardedTest(test, 1, 1, duration))
+    return sharded_tests
+
+
 def calculate_shards(
     num_shards: int,
     tests: List[str],
     test_file_times: Dict[str, float],
     must_serial: Optional[Callable[[str], bool]] = None,
-) -> List[Tuple[float, List[str]]]:
+) -> List[Tuple[float, List[ShardedTest]]]:
     must_serial = must_serial or (lambda x: True)
 
     known_tests = [x for x in tests if x in test_file_times]
     unknown_tests: List[str] = [x for x in tests if x not in known_tests]
 
-    sorted_tests = sorted(known_tests, key=lambda j: test_file_times[j], reverse=True)
+    sorted_tests = sorted(
+        get_with_pytest_shard(known_tests, test_file_times),
+        key=lambda j: j.get_time(),
+        reverse=True,
+    )
 
-    sharded_jobs: List[ShardJob] = [
-        ShardJob(test_file_times) for _ in range(num_shards)
-    ]
+    sharded_jobs: List[ShardJob] = [ShardJob() for _ in range(num_shards)]
     for test in sorted_tests:
-        if must_serial(test):
+        if must_serial(test.name):
             min_sharded_job = min(sharded_jobs, key=lambda j: j.get_total_time())
             min_sharded_job.serial.append(test)
         else:
@@ -75,8 +107,8 @@ def calculate_shards(
 
     # Round robin the unknown jobs starting with the smallest shard
     index = min(range(num_shards), key=lambda i: sharded_jobs[i].get_total_time())
-    for test in unknown_tests:
-        sharded_jobs[index].serial.append(test)
+    for unknown_test in unknown_tests:
+        sharded_jobs[index].serial.append(ShardedTest(unknown_test, 1, 1, None))
         index = (index + 1) % num_shards
 
     return [job.convert_to_tuple() for job in sharded_jobs]
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index 170f5058ae6..13588a1c8da 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -116,11 +116,9 @@ slow_tests_dict = {}
 if os.getenv("SLOW_TESTS_FILE", ""):
     with open(os.getenv("SLOW_TESTS_FILE"), 'r') as fp:
         slow_tests_dict = json.load(fp)
-        warnings.warn(f"loaded {len(slow_tests_dict)} slow tests")
 if os.getenv("DISABLED_TESTS_FILE", ""):
     with open(os.getenv("DISABLED_TESTS_FILE"), 'r') as fp:
         disabled_tests_dict = json.load(fp)
-        warnings.warn(f"loaded {len(disabled_tests_dict)} disabled tests")
 
 
 NATIVE_DEVICES = ('cpu', 'cuda', 'meta')
@@ -571,9 +569,9 @@ CI_TEST_PREFIX = str(Path(os.getcwd()))
 CI_PT_ROOT = str(Path(os.getcwd()).parent)
 CI_FUNCTORCH_ROOT = str(os.path.join(Path(os.getcwd()).parent, "functorch"))
 
-def wait_for_process(p):
+def wait_for_process(p, timeout=None):
     try:
-        return p.wait()
+        return p.wait(timeout=timeout)
     except KeyboardInterrupt:
         # Give `p` a chance to handle KeyboardInterrupt. Without this,
         # `pytest` can't print errors it collected so far upon KeyboardInterrupt.
@@ -590,7 +588,7 @@ def wait_for_process(p, timeout=None):
         # Always call p.wait() to ensure exit
         p.wait()
 
-def shell(command, cwd=None, env=None, stdout=None, stderr=None):
+def shell(command, cwd=None, env=None, stdout=None, stderr=None, timeout=None):
     sys.stdout.flush()
     sys.stderr.flush()
     # The following cool snippet is copied from Py3 core library subprocess.call
@@ -602,7 +600,22 @@ def shell(command, cwd=None, env=None, stdout=None, stderr=None, timeout=None):
     # https://github.com/python/cpython/blob/71b6c1af727fbe13525fb734568057d78cea33f3/Lib/subprocess.py#L309-L323
     assert not isinstance(command, str), "Command to shell should be a list or tuple of tokens"
     p = subprocess.Popen(command, universal_newlines=True, cwd=cwd, env=env, stdout=stdout, stderr=stderr)
-    return wait_for_process(p)
+    return wait_for_process(p, timeout=timeout)
+
+
+def retry_shell(command, cwd=None, env=None, stdout=None, stderr=None, timeout=None, retries=1):
+    assert retries >= 0, f"Expecting non negative number for number of retries, got {retries}"
+    try:
+        exit_code = shell(command, cwd=cwd, env=env, stdout=stdout, stderr=stderr, timeout=timeout)
+        if exit_code == 0 or retries == 0:
+            return exit_code
+        print(f"Got exit code {exit_code}, retrying (retries left={retries})", file=stdout, flush=True)
+    except subprocess.TimeoutExpired:
+        if retries == 0:
+            print(f"Command took >{timeout // 60}min, returning 124", file=stdout, flush=True)
+            return 124
+        print(f"Command took >{timeout // 60}min, retrying (retries left={retries})", file=stdout, flush=True)
+    return retry_shell(command, cwd=cwd, env=env, stdout=stdout, stderr=stderr, timeout=timeout, retries=retries - 1)
 
 
 def discover_test_cases_recursively(suite_or_case):
@@ -753,7 +766,10 @@ def run_tests(argv=UNITTEST_ARGS):
                 [test_case_full_name]
             )
             string_cmd = " ".join(cmd)
-            exitcode = shell(cmd)
+
+            timeout = None if RERUN_DISABLED_TESTS else 15 * 60
+
+            exitcode = retry_shell(cmd, timeout=timeout, retries=0 if RERUN_DISABLED_TESTS else 1)
             if exitcode != 0:
                 # This is sort of hacky, but add on relevant env variables for distributed tests.
@@ -794,7 +810,6 @@ def run_tests(argv=UNITTEST_ARGS):
         exit_code = pytest.main(args=pytest_args)
         if TEST_SAVE_XML:
             sanitize_pytest_xml(test_report_path)
-            print("If in CI, skip info is located in the xml test reports, please either go to s3 or the hud to download them")
 
         if not RERUN_DISABLED_TESTS:
             # exitcode of 5 means no tests were found, which happens since some test configs don't