diff --git a/test/run_test.py b/test/run_test.py index 4c3d3ea5a9c..72aaba91b16 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -15,7 +15,7 @@ import tempfile import time from contextlib import ExitStack from datetime import datetime -from typing import Any, cast, Dict, List, NamedTuple, Optional, Tuple, Union +from typing import Any, cast, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union import pkg_resources @@ -40,12 +40,18 @@ REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent # using tools/ to optimize test run. sys.path.insert(0, str(REPO_ROOT)) -from tools.stats.import_test_stats import ADDITIONAL_CI_FILES_FOLDER, TEST_TIMES_FILE +from tools.stats.import_test_stats import ( + ADDITIONAL_CI_FILES_FOLDER, + TEST_CLASS_TIMES_FILE, + TEST_TIMES_FILE, +) from tools.stats.upload_metrics import add_global_metric, emit_metric from tools.testing.target_determination.determinator import ( AggregatedHeuristics, get_test_prioritizations, ) + +from tools.testing.test_run import TestRun from tools.testing.test_selections import ( calculate_shards, get_test_case_configs, @@ -479,7 +485,7 @@ def get_executable_command(options, disable_coverage=False, is_cpp_test=False): def run_test( - test_module, + test_module: ShardedTest, test_directory, options, launcher_cmd=None, @@ -488,14 +494,9 @@ def run_test( ) -> int: maybe_set_hip_visible_devies() unittest_args = options.additional_unittest_args.copy() - test_file = test_module + test_file = test_module.name stepcurrent_key = test_file - use_sharded_test = False - if isinstance(test_file, ShardedTest): - test_file = test_module.name - use_sharded_test = True - is_distributed_test = test_file.startswith(DISTRIBUTED_TEST_PREFIX) is_cpp_test = test_file.startswith(CPP_TEST_PREFIX) # NB: Rerun disabled tests depends on pytest-flakefinder and it doesn't work with @@ -507,17 +508,16 @@ def run_test( ) return 0 - if use_sharded_test: - if is_cpp_test: - stepcurrent_key = test_file - else: - unittest_args.extend( - [ - f"--shard-id={test_module.shard - 1}", - f"--num-shards={test_module.num_shards}", - ] - ) - stepcurrent_key = f"{test_file}_{test_module.shard - 1}" + if is_cpp_test: + stepcurrent_key = test_file + else: + unittest_args.extend( + [ + f"--shard-id={test_module.shard - 1}", + f"--num-shards={test_module.num_shards}", + ] + ) + stepcurrent_key = f"{test_file}_{test_module.shard - 1}" if options.verbose: unittest_args.append(f'-{"v"*options.verbose}') # in case of pytest @@ -541,6 +541,7 @@ def run_test( is_distributed_test=is_distributed_test, ) ) + unittest_args.extend(test_module.get_pytest_args()) unittest_args = [arg if arg != "-f" else "-x" for arg in unittest_args] # TODO: These features are not available for C++ test yet @@ -706,7 +707,7 @@ def _test_cpp_extensions_aot(test_directory, options, use_ninja): assert install_directory, "install_directory must not be empty" os.environ["PYTHONPATH"] = os.pathsep.join([install_directory, python_path]) - return run_test(test_module, test_directory, options) + return run_test(ShardedTest(test_module, 1, 1), test_directory, options) finally: os.environ["PYTHONPATH"] = python_path if os.path.exists(test_directory + "/" + test_module + ".py"): @@ -1424,14 +1425,14 @@ def get_selected_tests(options) -> List[str]: return selected_tests -def download_test_times( - file: str = ADDITIONAL_CI_FILES_FOLDER / TEST_TIMES_FILE, -) -> Dict[str, float]: - # Download previous test times to make sharding decisions +def load_test_times_from_file( + file: str, +) -> Dict[str, Any]: + # Load 
previous test times to make sharding decisions path = os.path.join(str(REPO_ROOT), file) if not os.path.exists(path): print_to_stderr( - "::warning:: Failed to find test times file. Using round robin sharding." + f"::warning:: Failed to find test times file `{path}`. Using round robin sharding." ) return {} @@ -1456,6 +1457,18 @@ def download_test_times( return test_times_file["default"]["default"] +def load_test_file_times( + file: str = ADDITIONAL_CI_FILES_FOLDER / TEST_TIMES_FILE, +) -> Dict[str, float]: + return cast(Dict[str, float], load_test_times_from_file(file)) + + +def load_test_class_times( + file: str = ADDITIONAL_CI_FILES_FOLDER / TEST_CLASS_TIMES_FILE, +) -> Dict[str, Dict[str, float]]: + return cast(Dict[str, Dict[str, float]], load_test_times_from_file(file)) + + def get_sharding_opts(options) -> Tuple[int, int]: which_shard, num_shards = 1, 1 if options.shard: @@ -1471,8 +1484,9 @@ def get_sharding_opts(options) -> Tuple[int, int]: def do_sharding( options, - selected_tests: List[str], + selected_tests: Sequence[TestRun], test_file_times: Dict[str, float], + test_class_times: Dict[str, Dict[str, float]], sort_by_time: bool = True, ) -> List[ShardedTest]: which_shard, num_shards = get_sharding_opts(options) @@ -1482,6 +1496,7 @@ def do_sharding( num_shards, selected_tests, test_file_times, + test_class_times=test_class_times, must_serial=must_serial, sort_by_time=sort_by_time, ) @@ -1492,16 +1507,16 @@ def do_sharding( class TestFailure(NamedTuple): - test: str + test: TestRun message: str def run_test_module( - test: Union[ShardedTest, str], test_directory: str, options + test: ShardedTest, test_directory: str, options ) -> Optional[TestFailure]: maybe_set_hip_visible_devies() - test_name = test.name if isinstance(test, ShardedTest) else test + test_name = test.name # Printing the date here can help diagnose which tests are slow print_to_stderr(f"Running {str(test)} ... [{datetime.now()}]") @@ -1519,7 +1534,7 @@ def run_test_module( # return code -N, where N is the signal number. 
signal_name = SIGNALS_TO_NAMES_DICT[-return_code] message += f" Received signal: {signal_name}" - return TestFailure(test_name, message) + return TestFailure(test.test, message) def run_tests( @@ -1671,6 +1686,9 @@ def main(): "cpp": options.cpp, } + test_file_times_dict = load_test_file_times() + test_class_times_dict = load_test_class_times() + class TestBatch: """Defines a set of tests with similar priority that should be run together on the current shard""" @@ -1678,11 +1696,17 @@ def main(): sharded_tests: List[ShardedTest] failures: List[TestFailure] - def __init__(self, name: str, raw_tests: List[str], should_sort_shard: bool): + def __init__( + self, name: str, raw_tests: Sequence[TestRun], should_sort_shard: bool + ): self.name = name self.failures = [] self.sharded_tests = do_sharding( - options, raw_tests, test_times_dict, sort_by_time=should_sort_shard + options, + raw_tests, + test_file_times_dict, + test_class_times_dict, + sort_by_time=should_sort_shard, ) def __str__(self): @@ -1697,7 +1721,6 @@ def main(): ) return s.strip() - test_times_dict = download_test_times(ADDITIONAL_CI_FILES_FOLDER / TEST_TIMES_FILE) test_batches: List[TestBatch] = [] # Each batch will be run sequentially @@ -1749,7 +1772,7 @@ def main(): test_batch.sharded_tests, test_directory, options, test_batch.failures ) metrics_dict[f"{test_batch.name}_failures"] = [ - x.test for x in test_batch.failures + str(x.test) for x in test_batch.failures ] finally: diff --git a/tools/test/test_heuristics.py b/tools/test/test_heuristics.py index 615b35e626a..371613ece3c 100644 --- a/tools/test/test_heuristics.py +++ b/tools/test/test_heuristics.py @@ -3,7 +3,7 @@ import json import pathlib import sys import unittest -from typing import Any, Dict, Set +from typing import Any, Dict, Optional, Set from unittest import mock REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent @@ -12,12 +12,15 @@ try: sys.path.append(str(REPO_ROOT)) from tools.testing.target_determination.determinator import ( + AggregatedHeuristics, get_test_prioritizations, TestPrioritizations, ) + from tools.testing.target_determination.heuristics import HEURISTICS from tools.testing.target_determination.heuristics.previously_failed_in_pr import ( _get_previously_failing_tests, ) + from tools.testing.test_run import TestRun, TestRuns except ModuleNotFoundError: print("Can't import required modules, exiting") @@ -31,7 +34,37 @@ def mocked_file(contents: Dict[Any, Any]) -> io.IOBase: return file_object -class TestParsePrevTests(unittest.TestCase): +class HeuristicsTestMixin(unittest.TestCase): + def assert_heuristics_match( + self, + test_prioritizations: TestPrioritizations, + expected_high_tests: Optional[TestRuns] = None, + expected_probable_tests: Optional[TestRuns] = None, + expected_unranked_tests: Optional[TestRuns] = None, + ) -> None: + if expected_unranked_tests: + self.assertTupleEqual( + test_prioritizations.get_unranked_relevance_tests(), + expected_unranked_tests, + "Unranked tests differ", + ) + + if expected_probable_tests: + self.assertTupleEqual( + test_prioritizations.get_probable_relevance_tests(), + expected_probable_tests, + "Probable relevance tests differ", + ) + + if expected_high_tests: + self.assertTupleEqual( + test_prioritizations.get_high_relevance_tests(), + expected_high_tests, + "High relevance tests differ", + ) + + +class TestParsePrevTests(HeuristicsTestMixin): @mock.patch("pathlib.Path.exists", return_value=False) def test_cache_does_not_exist(self, mock_exists: Any) -> None: expected_failing_test_files: 
Set[str] = set() @@ -94,18 +127,257 @@ class TestParsePrevTests(unittest.TestCase): tests ).get_aggregated_priorities() - self.assertTupleEqual( - expected_prioritizations.get_high_relevance_tests(), - test_prioritizations.get_high_relevance_tests(), + self.assert_heuristics_match( + test_prioritizations, + expected_high_tests=expected_prioritizations.get_high_relevance_tests(), + expected_probable_tests=expected_prioritizations.get_probable_relevance_tests(), + expected_unranked_tests=expected_prioritizations.get_unranked_relevance_tests(), ) - self.assertTupleEqual( - expected_prioritizations.get_probable_relevance_tests(), - test_prioritizations.get_probable_relevance_tests(), + + +class TestInterface(HeuristicsTestMixin): + def test_class_prioritization(self) -> None: + tests = ["test1", "test2", "test3", "test4", "test5"] + + prioritizations = TestPrioritizations( + tests_being_ranked=tests, + probable_relevance=["test2::TestFooClass", "test3"], ) - self.assertTupleEqual( - expected_prioritizations.get_unranked_relevance_tests(), - test_prioritizations.get_unranked_relevance_tests(), + + expected_probable_tests = tuple( + TestRun(test) for test in ["test2::TestFooClass", "test3"] ) + expected_unranked_tests = ( + TestRun("test1"), + TestRun("test2", excluded=["TestFooClass"]), + TestRun("test4"), + TestRun("test5"), + ) + + self.assert_heuristics_match( + prioritizations, + expected_probable_tests=expected_probable_tests, + expected_unranked_tests=expected_unranked_tests, + ) + + +class TestAggregatedHeuristics(HeuristicsTestMixin): + def test_merging_multiple_test_class_heuristics(self) -> None: + tests = ["test1", "test2", "test3", "test4"] + + heuristic1 = TestPrioritizations( + tests_being_ranked=tests, + probable_relevance=["test2::TestFooClass", "test3"], + ) + + heuristic2 = TestPrioritizations( + tests_being_ranked=tests, + high_relevance=["test2::TestFooClass", "test3::TestBarClass"], + ) + + expected_high_relevance = tuple( + TestRun(test) for test in ["test2::TestFooClass", "test3::TestBarClass"] + ) + expected_probable_relevance = (TestRun("test3", excluded=["TestBarClass"]),) + expected_unranked_relevance = ( + TestRun("test1"), + TestRun("test2", excluded=["TestFooClass"]), + TestRun("test4"), + ) + + aggregator = AggregatedHeuristics(unranked_tests=tests) + aggregator.add_heuristic_results(HEURISTICS[0], heuristic1) + aggregator.add_heuristic_results(HEURISTICS[1], heuristic2) + + aggregated_pris = aggregator.get_aggregated_priorities() + + self.assert_heuristics_match( + aggregated_pris, + expected_high_tests=expected_high_relevance, + expected_probable_tests=expected_probable_relevance, + expected_unranked_tests=expected_unranked_relevance, + ) + + def test_merging_file_heuristic_after_class_heuristic(self) -> None: + tests = ["test1", "test2", "test3", "test4", "test5"] + heuristic1 = TestPrioritizations( + tests_being_ranked=tests, + high_relevance=["test2::TestFooClass"], + ) + heuristic2 = TestPrioritizations( + tests_being_ranked=tests, + probable_relevance=["test2", "test3"], + ) + + expected_aggregated_high_relevance = tuple( + TestRun(test) for test in ["test2::TestFooClass"] + ) + expected_aggregated_probable_relevance = ( + TestRun("test2", excluded=["TestFooClass"]), + TestRun("test3"), + ) + expected_aggregated_unranked_relevance = ( + TestRun("test1"), + TestRun("test4"), + TestRun("test5"), + ) + + aggregator = AggregatedHeuristics(unranked_tests=tests) + aggregator.add_heuristic_results(HEURISTICS[0], heuristic1) + 
aggregator.add_heuristic_results(HEURISTICS[1], heuristic2) + + aggregated_pris = aggregator.get_aggregated_priorities() + + self.assert_heuristics_match( + aggregated_pris, + expected_high_tests=expected_aggregated_high_relevance, + expected_probable_tests=expected_aggregated_probable_relevance, + expected_unranked_tests=expected_aggregated_unranked_relevance, + ) + + def test_get_test_stats_with_whole_tests(self) -> None: + self.maxDiff = None + tests = ["test1", "test2", "test3", "test4", "test5"] + heuristic1 = TestPrioritizations( + tests_being_ranked=tests, + high_relevance=["test3", "test4"], + ) + heuristic2 = TestPrioritizations( + tests_being_ranked=tests, + probable_relevance=["test5"], + ) + + aggregator = AggregatedHeuristics(unranked_tests=tests) + aggregator.add_heuristic_results(HEURISTICS[0], heuristic1) + aggregator.add_heuristic_results(HEURISTICS[1], heuristic2) + + expected_test3_stats = { + "test_name": "test3", + "test_filters": "", + "without_heuristics": { + "relevance_group": "UNRANKED", + "order_within_relevance_group": 2, + "num_tests_in_relevance_group": 5, + "order_overall": 2, + "heuristic_name": "baseline", + }, + "heuristics": [ + { + "relevance_group": "HIGH", + "order_within_relevance_group": 0, + "num_tests_in_relevance_group": 2, + "order_overall": 0, + "heuristic_name": HEURISTICS[0].name, + "trial_mode": False, + }, + { + "relevance_group": "UNRANKED", + "order_within_relevance_group": 2, + "num_tests_in_relevance_group": 4, + "order_overall": 3, + "heuristic_name": HEURISTICS[1].name, + "trial_mode": False, + }, + ], + "num_heuristics_prioritized_by": 1, + "aggregated": { + "relevance_group": "HIGH", + "order_within_relevance_group": 0, + "num_tests_in_relevance_group": 2, + "order_overall": 0, + }, + "aggregated_trial": { + "relevance_group": "HIGH", + "order_within_relevance_group": 0, + "num_tests_in_relevance_group": 2, + "order_overall": 0, + }, + "highest_ranking_heuristic": HEURISTICS[0].name, + } + + test3_stats = aggregator.get_test_stats(TestRun("test3")) + + self.assertDictEqual(test3_stats, expected_test3_stats) + + def test_get_test_stats_only_contains_allowed_types(self) -> None: + self.maxDiff = None + tests = ["test1", "test2", "test3", "test4", "test5"] + heuristic1 = TestPrioritizations( + tests_being_ranked=tests, + high_relevance=["test3", "test4"], + ) + heuristic2 = TestPrioritizations( + tests_being_ranked=tests, + probable_relevance=["test5::classA"], + ) + + aggregator = AggregatedHeuristics(unranked_tests=tests) + aggregator.add_heuristic_results(HEURISTICS[0], heuristic1) + aggregator.add_heuristic_results(HEURISTICS[1], heuristic2) + + stats3 = aggregator.get_test_stats(TestRun("test3")) + stats5 = aggregator.get_test_stats(TestRun("test5::classA")) + + def assert_valid_dict(dict_contents: Dict[str, Any]) -> None: + for key, value in dict_contents.items(): + self.assertTrue(isinstance(key, str)) + self.assertTrue( + isinstance(value, (str, float, int, list, dict)), + f"{value} is not a str, float, or dict", + ) + if isinstance(value, dict): + assert_valid_dict(value) + elif isinstance(value, list): + for item in value: + assert_valid_dict(item) + + assert_valid_dict(stats3) + assert_valid_dict(stats5) + + def test_get_test_stats_gets_rank_for_test_classes(self) -> None: + self.maxDiff = None + tests = ["test1", "test2", "test3", "test4", "test5"] + heuristic1 = TestPrioritizations( + tests_being_ranked=tests, + high_relevance=["test3", "test4"], + ) + heuristic2 = TestPrioritizations( + tests_being_ranked=tests, + 
probable_relevance=["test5::classA"], + ) + + aggregator = AggregatedHeuristics(unranked_tests=tests) + aggregator.add_heuristic_results(HEURISTICS[0], heuristic1) + aggregator.add_heuristic_results(HEURISTICS[1], heuristic2) + + statsInclusive = aggregator.get_test_stats( + TestRun("test5", included=["classA"]) + ) + statsExclusive = aggregator.get_test_stats( + TestRun("test5", excluded=["classA"]) + ) + + print("h") + # Validate the heuristic level stats are correct + self.assertEqual( + statsInclusive["heuristics"][1]["order_within_relevance_group"], 0 + ) + self.assertEqual( + statsInclusive["heuristics"][1]["num_tests_in_relevance_group"], 1 + ) + self.assertEqual(statsInclusive["heuristics"][1]["order_overall"], 0) + self.assertEqual(statsInclusive["heuristics"][1]["relevance_group"], "PROBABLE") + self.assertEqual(statsInclusive["aggregated"]["order_overall"], 2) + + self.assertEqual( + statsExclusive["heuristics"][1]["order_within_relevance_group"], 4 + ) + self.assertEqual( + statsExclusive["heuristics"][1]["num_tests_in_relevance_group"], 5 + ) + self.assertEqual(statsExclusive["heuristics"][1]["order_overall"], 5) + self.assertEqual(statsExclusive["heuristics"][1]["relevance_group"], "UNRANKED") + self.assertEqual(statsExclusive["aggregated"]["order_overall"], 5) if __name__ == "__main__": diff --git a/tools/test/test_test_run.py b/tools/test/test_test_run.py new file mode 100644 index 00000000000..9e1e4702d19 --- /dev/null +++ b/tools/test/test_test_run.py @@ -0,0 +1,169 @@ +import pathlib +import sys +import unittest + +REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent +try: + # using tools/ to optimize test run. + sys.path.append(str(REPO_ROOT)) + from tools.testing.test_run import ShardedTest, TestRun +except ModuleNotFoundError: + print("Can't import required modules, exiting") + sys.exit(1) + + +class TestTestRun(unittest.TestCase): + def test_union_with_full_run(self) -> None: + run1 = TestRun("foo") + run2 = TestRun("foo::bar") + + self.assertEqual(run1 | run2, run1) + self.assertEqual(run2 | run1, run1) + + def test_union_with_inclusions(self) -> None: + run1 = TestRun("foo::bar") + run2 = TestRun("foo::baz") + + expected = TestRun("foo") + expected._included.add("bar") + expected._included.add("baz") + + self.assertEqual(run1 | run2, expected) + self.assertEqual(run2 | run1, expected) + + def test_union_with_non_overlapping_exclusions(self) -> None: + run1 = TestRun("foo", excluded=["bar"]) + run2 = TestRun("foo", excluded=["baz"]) + + expected = TestRun("foo") + + self.assertEqual(run1 | run2, expected) + self.assertEqual(run2 | run1, expected) + + def test_union_with_overlapping_exclusions(self) -> None: + run1 = TestRun("foo", excluded=["bar", "car"]) + run2 = TestRun("foo", excluded=["bar", "caz"]) + + expected = TestRun("foo", excluded=["bar"]) + + self.assertEqual(run1 | run2, expected) + self.assertEqual(run2 | run1, expected) + + def test_union_with_mixed_inclusion_exclusions(self) -> None: + run1 = TestRun("foo", excluded=["baz", "car"]) + run2 = TestRun("foo", included=["baz"]) + + expected = TestRun("foo", excluded=["car"]) + + self.assertEqual(run1 | run2, expected) + self.assertEqual(run2 | run1, expected) + + def test_union_with_mixed_files_fails(self) -> None: + run1 = TestRun("foo") + run2 = TestRun("bar") + + with self.assertRaises(AssertionError): + run1 | run2 + + def test_union_with_empty_file_yields_orig_file(self) -> None: + run1 = TestRun("foo") + run2 = TestRun.empty() + + self.assertEqual(run1 | run2, run1) + 
self.assertEqual(run2 | run1, run1) + + def test_subtracting_full_run_fails(self) -> None: + run1 = TestRun("foo::bar") + run2 = TestRun("foo") + + self.assertEqual(run1 - run2, TestRun.empty()) + + def test_subtracting_empty_file_yields_orig_file(self) -> None: + run1 = TestRun("foo") + run2 = TestRun.empty() + + self.assertEqual(run1 - run2, run1) + self.assertEqual(run2 - run1, TestRun.empty()) + + def test_empty_is_falsey(self) -> None: + self.assertFalse(TestRun.empty()) + + def test_subtracting_inclusion_from_full_run(self) -> None: + run1 = TestRun("foo") + run2 = TestRun("foo::bar") + + expected = TestRun("foo", excluded=["bar"]) + + self.assertEqual(run1 - run2, expected) + + def test_subtracting_inclusion_from_overlapping_inclusion(self) -> None: + run1 = TestRun("foo", included=["bar", "baz"]) + run2 = TestRun("foo::baz") + + self.assertEqual(run1 - run2, TestRun("foo", included=["bar"])) + + def test_subtracting_inclusion_from_nonoverlapping_inclusion(self) -> None: + run1 = TestRun("foo", included=["bar", "baz"]) + run2 = TestRun("foo", included=["car"]) + + self.assertEqual(run1 - run2, TestRun("foo", included=["bar", "baz"])) + + def test_subtracting_exclusion_from_full_run(self) -> None: + run1 = TestRun("foo") + run2 = TestRun("foo", excluded=["bar"]) + + self.assertEqual(run1 - run2, TestRun("foo", included=["bar"])) + + def test_subtracting_exclusion_from_superset_exclusion(self) -> None: + run1 = TestRun("foo", excluded=["bar", "baz"]) + run2 = TestRun("foo", excluded=["baz"]) + + self.assertEqual(run1 - run2, TestRun.empty()) + self.assertEqual(run2 - run1, TestRun("foo", included=["bar"])) + + def test_subtracting_exclusion_from_nonoverlapping_exclusion(self) -> None: + run1 = TestRun("foo", excluded=["bar", "baz"]) + run2 = TestRun("foo", excluded=["car"]) + + self.assertEqual(run1 - run2, TestRun("foo", included=["car"])) + self.assertEqual(run2 - run1, TestRun("foo", included=["bar", "baz"])) + + def test_subtracting_inclusion_from_exclusion_without_overlaps(self) -> None: + run1 = TestRun("foo", excluded=["bar", "baz"]) + run2 = TestRun("foo", included=["bar"]) + + self.assertEqual(run1 - run2, run1) + self.assertEqual(run2 - run1, run2) + + def test_subtracting_inclusion_from_exclusion_with_overlaps(self) -> None: + run1 = TestRun("foo", excluded=["bar", "baz"]) + run2 = TestRun("foo", included=["bar", "car"]) + + self.assertEqual(run1 - run2, TestRun("foo", excluded=["bar", "baz", "car"])) + self.assertEqual(run2 - run1, TestRun("foo", included=["bar"])) + + def test_and(self) -> None: + run1 = TestRun("foo", included=["bar", "baz"]) + run2 = TestRun("foo", included=["bar", "car"]) + + self.assertEqual(run1 & run2, TestRun("foo", included=["bar"])) + + def test_and_exclusions(self) -> None: + run1 = TestRun("foo", excluded=["bar", "baz"]) + run2 = TestRun("foo", excluded=["bar", "car"]) + + self.assertEqual(run1 & run2, TestRun("foo", excluded=["bar", "baz", "car"])) + + +class TestShardedTest(unittest.TestCase): + def test_get_pytest_args(self) -> None: + test = TestRun("foo", included=["bar", "baz"]) + sharded_test = ShardedTest(test, 1, 1) + + expected_args = ["-k", "bar or baz"] + + self.assertListEqual(sharded_test.get_pytest_args(), expected_args) + + +if __name__ == "__main__": + unittest.main() diff --git a/tools/test/test_test_selections.py b/tools/test/test_test_selections.py index 1a06cb20675..a6f6dc2e2bf 100644 --- a/tools/test/test_test_selections.py +++ b/tools/test/test_test_selections.py @@ -9,25 +9,30 @@ REPO_ROOT = 
pathlib.Path(__file__).resolve().parent.parent.parent try: # using tools/ to optimize test run. sys.path.append(str(REPO_ROOT)) - from tools.testing.test_selections import calculate_shards, ShardedTest, THRESHOLD + from tools.testing.test_run import ShardedTest, TestRun + from tools.testing.test_selections import calculate_shards, THRESHOLD except ModuleNotFoundError: print("Can't import required modules, exiting") sys.exit(1) +def gen_class_times(test_times: Dict[str, float]) -> Dict[str, Dict[str, float]]: + return {k: {"class1": v} for k, v in test_times.items()} + + class TestCalculateShards(unittest.TestCase): - tests: List[str] = [ - "super_long_test", - "long_test1", - "long_test2", - "normal_test1", - "normal_test2", - "normal_test3", - "short_test1", - "short_test2", - "short_test3", - "short_test4", - "short_test5", + tests: List[TestRun] = [ + TestRun("super_long_test"), + TestRun("long_test1"), + TestRun("long_test2"), + TestRun("normal_test1"), + TestRun("normal_test2"), + TestRun("normal_test3"), + TestRun("short_test1"), + TestRun("short_test2"), + TestRun("short_test3"), + TestRun("short_test4"), + TestRun("short_test5"), ] test_times: Dict[str, float] = { @@ -44,6 +49,20 @@ class TestCalculateShards(unittest.TestCase): "short_test5": 0.01, } + test_class_times: Dict[str, Dict[str, float]] = { + "super_long_test": {"class1": 55}, + "long_test1": {"class1": 1, "class2": 21}, + "long_test2": {"class1": 10, "class2": 8}, + "normal_test1": {"class1": 9}, + "normal_test2": {"class1": 7}, + "normal_test3": {"class1": 5}, + "short_test1": {"class1": 1}, + "short_test2": {"class1": 0.6}, + "short_test3": {"class1": 0.4}, + "short_test4": {"class1": 0.3}, + "short_test5": {"class1": 0.01}, + } + def assert_shards_equal( self, expected_shards: List[Tuple[float, List[ShardedTest]]], @@ -58,81 +77,92 @@ class TestCalculateShards(unittest.TestCase): ( 60.0, [ - ShardedTest(name="super_long_test", shard=1, num_shards=1, time=55), - ShardedTest(name="normal_test3", shard=1, num_shards=1, time=5), + ShardedTest(test="super_long_test", shard=1, num_shards=1, time=55), + ShardedTest(test="normal_test3", shard=1, num_shards=1, time=5), ], ), ( 58.31, [ - ShardedTest(name="long_test1", shard=1, num_shards=1, time=22), - ShardedTest(name="long_test2", shard=1, num_shards=1, time=18), - ShardedTest(name="normal_test1", shard=1, num_shards=1, time=9), - ShardedTest(name="normal_test2", shard=1, num_shards=1, time=7), - ShardedTest(name="short_test1", shard=1, num_shards=1, time=1), - ShardedTest(name="short_test2", shard=1, num_shards=1, time=0.6), - ShardedTest(name="short_test3", shard=1, num_shards=1, time=0.4), - ShardedTest(name="short_test4", shard=1, num_shards=1, time=0.3), - ShardedTest(name="short_test5", shard=1, num_shards=1, time=0.01), + ShardedTest(test="long_test1", shard=1, num_shards=1, time=22), + ShardedTest(test="long_test2", shard=1, num_shards=1, time=18), + ShardedTest(test="normal_test1", shard=1, num_shards=1, time=9), + ShardedTest(test="normal_test2", shard=1, num_shards=1, time=7), + ShardedTest(test="short_test1", shard=1, num_shards=1, time=1), + ShardedTest(test="short_test2", shard=1, num_shards=1, time=0.6), + ShardedTest(test="short_test3", shard=1, num_shards=1, time=0.4), + ShardedTest(test="short_test4", shard=1, num_shards=1, time=0.3), + ShardedTest(test="short_test5", shard=1, num_shards=1, time=0.01), ], ), ] self.assert_shards_equal( - expected_shards, calculate_shards(2, self.tests, self.test_times) + expected_shards, + calculate_shards(2, self.tests, 
self.test_times, self.test_class_times), ) def test_calculate_1_shard_with_complete_test_times(self) -> None: + tests = self.tests.copy() + class_test1 = TestRun("long_test1", excluded=["class2"]) + class_test2 = TestRun("long_test1", included=["class2"]) + tests.append(class_test1) + tests.append(class_test2) + expected_shards = [ ( - 118.31, + 140.31, [ - ShardedTest(name="super_long_test", shard=1, num_shards=1, time=55), - ShardedTest(name="long_test1", shard=1, num_shards=1, time=22), - ShardedTest(name="long_test2", shard=1, num_shards=1, time=18), - ShardedTest(name="normal_test1", shard=1, num_shards=1, time=9), - ShardedTest(name="normal_test2", shard=1, num_shards=1, time=7), - ShardedTest(name="normal_test3", shard=1, num_shards=1, time=5), - ShardedTest(name="short_test1", shard=1, num_shards=1, time=1), - ShardedTest(name="short_test2", shard=1, num_shards=1, time=0.6), - ShardedTest(name="short_test3", shard=1, num_shards=1, time=0.4), - ShardedTest(name="short_test4", shard=1, num_shards=1, time=0.3), - ShardedTest(name="short_test5", shard=1, num_shards=1, time=0.01), + ShardedTest(test="super_long_test", shard=1, num_shards=1, time=55), + ShardedTest(test="long_test1", shard=1, num_shards=1, time=22), + ShardedTest(class_test2, shard=1, num_shards=1, time=21), + ShardedTest(test="long_test2", shard=1, num_shards=1, time=18), + ShardedTest(test="normal_test1", shard=1, num_shards=1, time=9), + ShardedTest(test="normal_test2", shard=1, num_shards=1, time=7), + ShardedTest(test="normal_test3", shard=1, num_shards=1, time=5), + ShardedTest(test="short_test1", shard=1, num_shards=1, time=1), + ShardedTest(class_test1, shard=1, num_shards=1, time=1), + ShardedTest(test="short_test2", shard=1, num_shards=1, time=0.6), + ShardedTest(test="short_test3", shard=1, num_shards=1, time=0.4), + ShardedTest(test="short_test4", shard=1, num_shards=1, time=0.3), + ShardedTest(test="short_test5", shard=1, num_shards=1, time=0.01), ], ) ] self.assert_shards_equal( - expected_shards, calculate_shards(1, self.tests, self.test_times) + expected_shards, + calculate_shards(1, tests, self.test_times, self.test_class_times), ) def test_calculate_5_shards_with_complete_test_times(self) -> None: expected_shards = [ ( 55.0, - [ShardedTest(name="super_long_test", shard=1, num_shards=1, time=55)], + [ShardedTest(test="super_long_test", shard=1, num_shards=1, time=55)], ), - (22.0, [ShardedTest(name="long_test1", shard=1, num_shards=1, time=22)]), - (18.0, [ShardedTest(name="long_test2", shard=1, num_shards=1, time=18)]), + (22.0, [ShardedTest(test="long_test1", shard=1, num_shards=1, time=22)]), + (18.0, [ShardedTest(test="long_test2", shard=1, num_shards=1, time=18)]), ( 11.31, [ - ShardedTest(name="normal_test1", shard=1, num_shards=1, time=9), - ShardedTest(name="short_test1", shard=1, num_shards=1, time=1), - ShardedTest(name="short_test2", shard=1, num_shards=1, time=0.6), - ShardedTest(name="short_test3", shard=1, num_shards=1, time=0.4), - ShardedTest(name="short_test4", shard=1, num_shards=1, time=0.3), - ShardedTest(name="short_test5", shard=1, num_shards=1, time=0.01), + ShardedTest(test="normal_test1", shard=1, num_shards=1, time=9), + ShardedTest(test="short_test1", shard=1, num_shards=1, time=1), + ShardedTest(test="short_test2", shard=1, num_shards=1, time=0.6), + ShardedTest(test="short_test3", shard=1, num_shards=1, time=0.4), + ShardedTest(test="short_test4", shard=1, num_shards=1, time=0.3), + ShardedTest(test="short_test5", shard=1, num_shards=1, time=0.01), ], ), ( 12.0, [ - 
ShardedTest(name="normal_test2", shard=1, num_shards=1, time=7), - ShardedTest(name="normal_test3", shard=1, num_shards=1, time=5), + ShardedTest(test="normal_test2", shard=1, num_shards=1, time=7), + ShardedTest(test="normal_test3", shard=1, num_shards=1, time=5), ], ), ] self.assert_shards_equal( - expected_shards, calculate_shards(5, self.tests, self.test_times) + expected_shards, + calculate_shards(5, self.tests, self.test_times, self.test_class_times), ) def test_calculate_2_shards_with_incomplete_test_times(self) -> None: @@ -143,29 +173,35 @@ class TestCalculateShards(unittest.TestCase): ( 22.0, [ - ShardedTest(name="long_test1", shard=1, num_shards=1, time=22), - ShardedTest(name="long_test2", shard=1, num_shards=1, time=None), - ShardedTest(name="normal_test3", shard=1, num_shards=1, time=None), - ShardedTest(name="short_test3", shard=1, num_shards=1, time=None), - ShardedTest(name="short_test5", shard=1, num_shards=1, time=None), + ShardedTest(test="long_test1", shard=1, num_shards=1, time=22), + ShardedTest(test="long_test2", shard=1, num_shards=1, time=None), + ShardedTest(test="normal_test3", shard=1, num_shards=1, time=None), + ShardedTest(test="short_test3", shard=1, num_shards=1, time=None), + ShardedTest(test="short_test5", shard=1, num_shards=1, time=None), ], ), ( 10.0, [ - ShardedTest(name="normal_test1", shard=1, num_shards=1, time=9), - ShardedTest(name="short_test1", shard=1, num_shards=1, time=1), + ShardedTest(test="normal_test1", shard=1, num_shards=1, time=9), + ShardedTest(test="short_test1", shard=1, num_shards=1, time=1), ShardedTest( - name="super_long_test", shard=1, num_shards=1, time=None + test="super_long_test", shard=1, num_shards=1, time=None ), - ShardedTest(name="normal_test2", shard=1, num_shards=1, time=None), - ShardedTest(name="short_test2", shard=1, num_shards=1, time=None), - ShardedTest(name="short_test4", shard=1, num_shards=1, time=None), + ShardedTest(test="normal_test2", shard=1, num_shards=1, time=None), + ShardedTest(test="short_test2", shard=1, num_shards=1, time=None), + ShardedTest(test="short_test4", shard=1, num_shards=1, time=None), ], ), ] self.assert_shards_equal( - expected_shards, calculate_shards(2, self.tests, incomplete_test_times) + expected_shards, + calculate_shards( + 2, + self.tests, + incomplete_test_times, + gen_class_times(incomplete_test_times), + ), ) def test_calculate_5_shards_with_incomplete_test_times(self) -> None: @@ -176,54 +212,66 @@ class TestCalculateShards(unittest.TestCase): ( 22.0, [ - ShardedTest(name="long_test1", shard=1, num_shards=1, time=22), - ShardedTest(name="normal_test2", shard=1, num_shards=1, time=None), - ShardedTest(name="short_test5", shard=1, num_shards=1, time=None), + ShardedTest(test="long_test1", shard=1, num_shards=1, time=22), + ShardedTest(test="normal_test2", shard=1, num_shards=1, time=None), + ShardedTest(test="short_test5", shard=1, num_shards=1, time=None), ], ), ( 9.0, [ - ShardedTest(name="normal_test1", shard=1, num_shards=1, time=9), - ShardedTest(name="normal_test3", shard=1, num_shards=1, time=None), + ShardedTest(test="normal_test1", shard=1, num_shards=1, time=9), + ShardedTest(test="normal_test3", shard=1, num_shards=1, time=None), ], ), ( 1.0, [ - ShardedTest(name="short_test1", shard=1, num_shards=1, time=1), - ShardedTest(name="short_test2", shard=1, num_shards=1, time=None), + ShardedTest(test="short_test1", shard=1, num_shards=1, time=1), + ShardedTest(test="short_test2", shard=1, num_shards=1, time=None), ], ), ( 0.0, [ ShardedTest( - name="super_long_test", 
shard=1, num_shards=1, time=None + test="super_long_test", shard=1, num_shards=1, time=None ), - ShardedTest(name="short_test3", shard=1, num_shards=1, time=None), + ShardedTest(test="short_test3", shard=1, num_shards=1, time=None), ], ), ( 0.0, [ - ShardedTest(name="long_test2", shard=1, num_shards=1, time=None), - ShardedTest(name="short_test4", shard=1, num_shards=1, time=None), + ShardedTest(test="long_test2", shard=1, num_shards=1, time=None), + ShardedTest(test="short_test4", shard=1, num_shards=1, time=None), ], ), ] self.assert_shards_equal( - expected_shards, calculate_shards(5, self.tests, incomplete_test_times) + expected_shards, + calculate_shards( + 5, + self.tests, + incomplete_test_times, + gen_class_times(incomplete_test_times), + ), ) def test_split_shards(self) -> None: test_times: Dict[str, float] = {"test1": THRESHOLD, "test2": THRESHOLD} expected_shards = [ - (600.0, [ShardedTest(name="test1", shard=1, num_shards=1, time=THRESHOLD)]), - (600.0, [ShardedTest(name="test2", shard=1, num_shards=1, time=THRESHOLD)]), + (600.0, [ShardedTest(test="test1", shard=1, num_shards=1, time=THRESHOLD)]), + (600.0, [ShardedTest(test="test2", shard=1, num_shards=1, time=THRESHOLD)]), ] self.assert_shards_equal( - expected_shards, calculate_shards(2, list(test_times.keys()), test_times) + expected_shards, + calculate_shards( + 2, + [TestRun(t) for t in test_times.keys()], + test_times, + gen_class_times(test_times), + ), ) test_times = {"test1": THRESHOLD * 4, "test2": THRESHOLD * 2.5} @@ -231,35 +279,47 @@ class TestCalculateShards(unittest.TestCase): ( 2200.0, [ - ShardedTest(name="test1", shard=1, num_shards=4, time=600.0), - ShardedTest(name="test1", shard=3, num_shards=4, time=600.0), - ShardedTest(name="test2", shard=1, num_shards=3, time=500.0), - ShardedTest(name="test2", shard=3, num_shards=3, time=500.0), + ShardedTest(test="test1", shard=1, num_shards=4, time=600.0), + ShardedTest(test="test1", shard=3, num_shards=4, time=600.0), + ShardedTest(test="test2", shard=1, num_shards=3, time=500.0), + ShardedTest(test="test2", shard=3, num_shards=3, time=500.0), ], ), ( 1700.0, [ - ShardedTest(name="test1", shard=2, num_shards=4, time=600.0), - ShardedTest(name="test1", shard=4, num_shards=4, time=600.0), - ShardedTest(name="test2", shard=2, num_shards=3, time=500.0), + ShardedTest(test="test1", shard=2, num_shards=4, time=600.0), + ShardedTest(test="test1", shard=4, num_shards=4, time=600.0), + ShardedTest(test="test2", shard=2, num_shards=3, time=500.0), ], ), ] self.assert_shards_equal( - expected_shards, calculate_shards(2, list(test_times.keys()), test_times) + expected_shards, + calculate_shards( + 2, + [TestRun(t) for t in test_times.keys()], + test_times, + gen_class_times(test_times), + ), ) test_times = {"test1": THRESHOLD / 2, "test2": THRESHOLD} expected_shards = [ - (600.0, [ShardedTest(name="test2", shard=1, num_shards=1, time=THRESHOLD)]), + (600.0, [ShardedTest(test="test2", shard=1, num_shards=1, time=THRESHOLD)]), ( 300.0, - [ShardedTest(name="test1", shard=1, num_shards=1, time=THRESHOLD / 2)], + [ShardedTest(test="test1", shard=1, num_shards=1, time=THRESHOLD / 2)], ), ] self.assert_shards_equal( - expected_shards, calculate_shards(2, list(test_times.keys()), test_times) + expected_shards, + calculate_shards( + 2, + [TestRun(t) for t in test_times.keys()], + test_times, + gen_class_times(test_times), + ), ) def test_split_shards_random(self) -> None: @@ -272,7 +332,10 @@ class TestCalculateShards(unittest.TestCase): } shards = calculate_shards( - num_shards, 
list(random_times.keys()), random_times + num_shards, + [TestRun(t) for t in random_times.keys()], + random_times, + gen_class_times(random_times), ) times = [x[0] for x in shards] @@ -300,7 +363,7 @@ class TestCalculateShards(unittest.TestCase): def test_calculate_2_shards_against_optimal_shards(self) -> None: random.seed(120) for _ in range(100): - random_times = {k: random.random() * 10 for k in self.tests} + random_times = {k.test_file: random.random() * 10 for k in self.tests} # all test times except first two rest_of_tests = [ i @@ -315,12 +378,14 @@ class TestCalculateShards(unittest.TestCase): # (sum_of_rest, ['super_long_test', 'long_test1']), # (sum_of_rest, [i for i in self.tests if i != 'super_long_test' and i != 'long_test1']), # ] - calculated_shards = calculate_shards(2, self.tests, random_times) + calculated_shards = calculate_shards( + 2, self.tests, random_times, gen_class_times(random_times) + ) max_shard_time = max(calculated_shards[0][0], calculated_shards[1][0]) if sum_of_rest != 0: # The calculated shard should not have a ratio worse than 7/6 for num_shards = 2 self.assertGreaterEqual(7.0 / 6.0, max_shard_time / sum_of_rest) - sorted_tests = sorted(self.tests) + sorted_tests = sorted([t.test_file for t in self.tests]) sorted_shard_tests = sorted( calculated_shards[0][1] + calculated_shards[1][1] ) diff --git a/tools/testing/target_determination/heuristics/interface.py b/tools/testing/target_determination/heuristics/interface.py index c06af4f00cb..6f323fe949f 100644 --- a/tools/testing/target_determination/heuristics/interface.py +++ b/tools/testing/target_determination/heuristics/interface.py @@ -1,20 +1,53 @@ import sys from abc import abstractmethod +from copy import copy from enum import Enum +from functools import total_ordering from itertools import chain -from typing import Any, Dict, FrozenSet, Iterable, List, Optional, Tuple +from typing import ( + Any, + Callable, + Dict, + FrozenSet, + Iterable, + Iterator, + List, + Optional, + Tuple, +) -StringTuple = Tuple[str, ...] +from tools.testing.test_run import TestRun, TestRuns # Note: Keep the implementation of Relevance private to this file so # that it's easy to change in the future as we discover what's needed +@total_ordering class Relevance(Enum): - HIGH = 0 - PROBABLE = 1 + HIGH = 4 + PROBABLE = 3 UNRANKED = 2 - UNLIKELY = 3 # Not yet supported. Needs more infra to be usable - NONE = 4 # Not yet supported. Needs more infra to be usable + UNLIKELY = 1 # Not yet supported. Needs more infra to be usable + NONE = 0 # Not yet supported. 
Needs more infra to be usable + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Relevance): + return False + + return self.value == other.value + + def __lt__(self, other: object) -> bool: + if not isinstance(other, Relevance): + raise NotImplementedError(f"Can't compare {self} to {other}") + + return self.value < other.value + + @staticmethod + def priority_traversal() -> Iterator["Relevance"]: + yield Relevance.HIGH + yield Relevance.PROBABLE + yield Relevance.UNRANKED + yield Relevance.UNLIKELY + yield Relevance.NONE METRIC_RELEVANCE_GROUP = "relevance_group" @@ -37,7 +70,7 @@ class TestPrioritizations: otherwise it breaks the test sharding logic """ - _test_priorities: List[StringTuple] # This list MUST be ordered by Relevance + _test_priorities: List[List[TestRun]] # This list MUST be ordered by Relevance _original_tests: FrozenSet[str] def __init__( @@ -51,106 +84,168 @@ class TestPrioritizations: ) -> None: self._original_tests = frozenset(tests_being_ranked) - self._test_priorities = [tuple() for _ in range(5)] + self._test_priorities = [[] for _ in range(5)] + # Setup the initial priorities + self._test_priorities[Relevance.UNRANKED.value] = [ + TestRun(test) for test in tests_being_ranked + ] - self._test_priorities[Relevance.HIGH.value] = self.filter_out_extra_tests( - high_relevance - ) - self._test_priorities[Relevance.PROBABLE.value] = self.filter_out_extra_tests( - probable_relevance - ) - self._test_priorities[Relevance.UNRANKED.value] = self.filter_out_extra_tests( - unranked_relevance - ) - self._test_priorities[Relevance.UNLIKELY.value] = self.filter_out_extra_tests( - unlikely_relevance - ) - self._test_priorities[Relevance.NONE.value] = self.filter_out_extra_tests( - no_relevance - ) - - # If any of the original tests were missed from the other lists, add them to the unranked_relevance list - missing_tests = sorted(self._original_tests - set(self.get_all_tests())) - self._test_priorities[Relevance.UNRANKED.value] = self._test_priorities[ - Relevance.UNRANKED.value - ] + tuple(missing_tests) + for test in high_relevance or []: + self.set_test_relevance(TestRun(test), Relevance.HIGH) + for test in probable_relevance or []: + self.set_test_relevance(TestRun(test), Relevance.PROBABLE) + for test in unranked_relevance or []: + self.set_test_relevance(TestRun(test), Relevance.UNRANKED) + for test in unlikely_relevance or []: + self.set_test_relevance(TestRun(test), Relevance.UNLIKELY) + for test in no_relevance or []: + self.set_test_relevance(TestRun(test), Relevance.NONE) self.validate_test_priorities() - def filter_out_extra_tests( - self, relevance_group: Optional[List[str]] - ) -> StringTuple: - if not relevance_group: - return tuple() - return tuple(filter(lambda test: test in self._original_tests, relevance_group)) + def _traverse_priorities(self) -> Iterator[Tuple[Relevance, List[TestRun]]]: + for relevance in Relevance.priority_traversal(): + yield (relevance, self._test_priorities[relevance.value]) + + def get_pointer_to_test(self, test_run: TestRun) -> Iterator[Tuple[Relevance, int]]: + """ + Returns all test runs that contain any subset of the given test_run and their current relevance. + + self._test_priorities should NOT have items added or removed form it while iterating over the + results of this function. + """ + # Find a test run that contains the given TestRun and it's relevance. 
+ found_match = False + for relevance, tests in self._traverse_priorities(): + for idx, existing_test_run in enumerate(tests): + # Does the existing test run contain any of the test we're looking for? + shared_test = existing_test_run & test_run + if not shared_test.is_empty(): + found_match = True + yield (Relevance(relevance), idx) + + if not found_match: + raise ValueError(f"Test {test_run} not found in any relevance group") + + def _update_test_relevance( + self, + test_run: TestRun, + new_relevance: Relevance, + acceptable_relevance_fn: Callable[[Relevance, Relevance], bool], + ) -> None: + """ + Updates the test run's relevance to the new relevance. + + If the tests in the test run were previously split up into multiple test runs, all the chunks at a lower + relevance will be merged into one new test run at the new relevance, appended to the end of the relevance group. + + However, any tests in a test run that are already at the desired relevance will be left alone, keeping it's + original place in the relevance group. + """ + if test_run.test_file not in self._original_tests: + return # We don't need this test + + # The tests covered by test_run could potentially have been split up into + # multiple test runs, each at a different relevance. Let's make sure to bring + # all of them up to the minimum relevance + upgraded_tests = TestRun.empty() + tests_to_remove = [] + for curr_relevance, test_run_idx in self.get_pointer_to_test(test_run): + if acceptable_relevance_fn(curr_relevance, new_relevance): + # This test is already at the desired relevance + continue # no changes needed + + test_run_to_rerank = self._test_priorities[curr_relevance.value][ + test_run_idx + ] + # Remove the requested tests from their current relevance group, to be added to the new one + remaining_tests = test_run_to_rerank - test_run + upgraded_tests |= test_run_to_rerank & test_run + + # Remove the tests that are being upgraded + if remaining_tests: + self._test_priorities[curr_relevance.value][ + test_run_idx + ] = remaining_tests + else: + # List traversal prevents us from deleting these immediately, so note them for later + tests_to_remove.append((curr_relevance, test_run_idx)) + + for relevance, test_idx in tests_to_remove: + del self._test_priorities[relevance.value][test_idx] + + # And add them to the desired relevance group + if upgraded_tests: + self._test_priorities[new_relevance.value].append(upgraded_tests) + + def set_test_relevance(self, test_run: TestRun, new_relevance: Relevance) -> None: + return self._update_test_relevance( + test_run, new_relevance, lambda curr, new: curr == new + ) + + def raise_test_relevance(self, test_run: TestRun, new_relevance: Relevance) -> None: + return self._update_test_relevance( + test_run, new_relevance, lambda curr, new: curr >= new + ) def validate_test_priorities(self) -> None: + # Union all TestRuns that contain include/exclude pairs + all_tests = self.get_all_tests() + files = {} + for test in all_tests: + if test.test_file not in files: + files[test.test_file] = copy(test) + else: + files[test.test_file] |= test + + for test in files.values(): + assert ( + test.is_full_file() + ), f"All includes should have been excluded elsewhere, and vice versa. 
Test run `{test}` violates that" + # Ensure that the set of tests in the TestPrioritizations is identical to the set of tests passed in assert self._original_tests == set( - self.get_all_tests() + files.keys() ), "The set of tests in the TestPrioritizations must be identical to the set of tests passed in" - @staticmethod - def _merge_tests( - current_tests: Iterable[str], - new_tests: Iterable[str], - higher_pri_tests: Iterable[str], - ) -> StringTuple: - """ - We append all new tests to the current tests, while preserving the sorting on the new_tests - However, exclude any specified tests which have now moved to a higher priority list or tests - that weren't originally in the self's TestPrioritizations - """ - merged_tests = [ - test - for test in chain(current_tests, new_tests) - if test not in higher_pri_tests - ] # skip the excluded tests - return tuple(dict.fromkeys(merged_tests)) # remove dupes while preseving order - def integrate_priorities(self, other: "TestPrioritizations") -> None: """ Integrates priorities from another TestPrioritizations object. The final result takes all tests from the `self` and rearranges them based on priorities from `other`. - If there are tests mentioned in `other` which are not in `self`, those tests are ignored. - (For example, that can happen if a heuristic reports tests that are not run in the current job) + Currently it will only raise the priority of a test, never lower it. """ assert ( self._original_tests == other._original_tests ), "Both tests should stem from the same original test list" - higher_pri_tests: List[str] = [] - for relevance, _ in enumerate(self._test_priorities): - self._test_priorities[relevance] = TestPrioritizations._merge_tests( - current_tests=self._test_priorities[relevance], - new_tests=other._test_priorities[relevance], - higher_pri_tests=higher_pri_tests, - ) - - # Don't let the tests we just added to the current relevance group be added to a lower relevance group - higher_pri_tests.extend(self._test_priorities[relevance]) + for relevance, _ in other._traverse_priorities(): + if relevance > Relevance.UNRANKED: + for test in other._test_priorities[relevance.value]: + self.raise_test_relevance(test, relevance) + # TODO: Hande the case where a test is moved to a lower relevance group (once we support that scenario) self.validate_test_priorities() + return - def get_all_tests(self) -> StringTuple: + def get_all_tests(self) -> TestRuns: """Returns all tests in the TestPrioritizations""" - return tuple(test for test in chain(*self._test_priorities)) + return tuple(chain(*self._test_priorities)) - def get_prioritized_tests(self) -> StringTuple: + def get_prioritized_tests(self) -> TestRuns: return self.get_high_relevance_tests() + self.get_probable_relevance_tests() - def get_high_relevance_tests(self) -> StringTuple: + def get_high_relevance_tests(self) -> TestRuns: return tuple(test for test in self._test_priorities[Relevance.HIGH.value]) - def get_probable_relevance_tests(self) -> StringTuple: + def get_probable_relevance_tests(self) -> TestRuns: return tuple(test for test in self._test_priorities[Relevance.PROBABLE.value]) - def get_unranked_relevance_tests(self) -> StringTuple: + def get_unranked_relevance_tests(self) -> TestRuns: return tuple(test for test in self._test_priorities[Relevance.UNRANKED.value]) def print_info(self) -> None: - def _print_tests(label: str, tests: StringTuple) -> None: + def _print_tests(label: str, tests: List[TestRun]) -> None: if not tests: return @@ -159,48 +254,61 @@ class 
TestPrioritizations: if test in tests: print(f" {test}") - for relevance_group, tests in enumerate(self._test_priorities): + for relevance_group, tests in self._traverse_priorities(): _print_tests(f"{Relevance(relevance_group).name.title()} Relevance", tests) - def _get_test_relevance_group(self, test_name: str) -> Relevance: - """Returns the priority of a test.""" - for relevance_group, tests in enumerate(self._test_priorities): - if test_name in tests: + def _get_test_relevance_group(self, test_run: TestRun) -> Relevance: + """Returns the rank of the given test run.""" + for relevance_group, tests in self._traverse_priorities(): + if any(t.contains(test_run) for t in tests): return Relevance(relevance_group) - raise ValueError(f"Test {test_name} not found in any relevance group") - def _get_test_order(self, test_name: str) -> int: - """Returns the rank of the test specified by this heuristic.""" + raise ValueError(f"Test {test_run} not found in any relevance group") + + def _get_test_order(self, test_run: TestRun) -> int: + """Returns the rank this heuristic suggested for the test run.""" base_rank = 0 - for relevance_group_tests in self._test_priorities: - if test_name in relevance_group_tests: - return base_rank + relevance_group_tests.index(test_name) + for _, relevance_group_tests in self._traverse_priorities(): + for idx, test in enumerate(relevance_group_tests): + if test.contains( + test_run + ): # test could be the entire test_run or a superset + return base_rank + idx base_rank += len(relevance_group_tests) - raise ValueError(f"Test {test_name} not found in any relevance group") + raise ValueError(f"Test {test_run} not found in any relevance group") - def _get_test_order_within_relevance_group(self, test_name: str) -> int: - for relevance_group_tests in self._test_priorities: - if test_name not in relevance_group_tests: - continue + def _get_test_order_within_relevance_group(self, test_run: TestRun) -> int: + """Returns the order of the given test run within its relevance group.""" + for _, relevance_group_tests in self._traverse_priorities(): + for idx, test in enumerate(relevance_group_tests): + if test.contains( + test_run + ): # test could be the entire test_run or a superset + return idx - return relevance_group_tests.index(test_name) + raise ValueError(f"Test {test_run} not found in any relevance group") - raise ValueError(f"Test {test_name} not found in any relevance group") - - def get_priority_info_for_test(self, test_name: str) -> Dict[str, Any]: + def get_priority_info_for_test(self, test_run: TestRun) -> Dict[str, Any]: """Given a failing test, returns information about its prioritization that we want to emit in our metrics.""" + relevance = self._get_test_relevance_group(test_run) return { - METRIC_RELEVANCE_GROUP: self._get_test_relevance_group(test_name).name, + METRIC_RELEVANCE_GROUP: relevance.name, METRIC_ORDER_WITHIN_RELEVANCE_GROUP: self._get_test_order_within_relevance_group( - test_name + test_run ), METRIC_NUM_TESTS_IN_RELEVANCE_GROUP: len( - self._test_priorities[self._get_test_relevance_group(test_name).value] + self._test_priorities[relevance.value] ), - METRIC_ORDER_OVERALL: self._get_test_order(test_name), + METRIC_ORDER_OVERALL: self._get_test_order(test_run), } @@ -247,12 +355,13 @@ class AggregatedHeuristics: return
aggregated_priorities - def get_test_stats(self, test: str) -> Dict[str, Any]: + def get_test_stats(self, test: TestRun) -> Dict[str, Any]: """ Returns the aggregated statistics for a given test. """ stats: Dict[str, Any] = { - "test_name": test, + "test_name": test.test_file, + "test_filters": test.get_pytest_filter(), } # Get baseline metrics assuming we didn't have any TD heuristics diff --git a/tools/testing/test_run.py b/tools/testing/test_run.py new file mode 100644 index 00000000000..e9eee27b4df --- /dev/null +++ b/tools/testing/test_run.py @@ -0,0 +1,305 @@ +from copy import copy +from functools import total_ordering +from typing import Iterable, List, Optional, Set, Tuple, Union + + +class TestRun: + """ + TestRun defines the set of tests that should be run together in a single pytest invocation. + It'll either be a whole test file or a subset of a test file. + + This class assumes that we won't always know the full set of TestClasses in a test file. + So it's designed to include or exclude explicitly requested TestClasses, while having accepting + that there will be an ambiguous set of "unknown" test classes that are not expliclty called out. + Those manifest as tests that haven't been explicitly excluded. + """ + + test_file: str + _exclued: Set[str] # Tests that should be excluded from this test run + _included: Set[str] # If non-empy, only these tests should be run in this test run + + def __init__( + self, + name: str, + excluded: Optional[Iterable[str]] = None, + included: Optional[Iterable[str]] = None, + ) -> None: + self._excluded = set() + self._included = set() + + if excluded and included: + raise ValueError("Can't specify both included and excluded") + + if "::" in name: + assert ( + not included and not excluded + ), "Can't specify included or excluded tests when specifying a test class in the file name" + self.test_file, test_class = name.split("::") + self._included.add(test_class) + else: + self.test_file = name + + # For testing purposes + if excluded: + self._excluded = set(excluded) + if included: + self._included = set(included) + + @staticmethod + def empty() -> "TestRun": + return TestRun("") + + def is_empty(self) -> bool: + # Lack of a test_file means that this is an empty run, + # which means there is nothing to run. It's the zero. + return not self.test_file + + def is_full_file(self) -> bool: + return not self._included and not self._excluded + + def included(self) -> Set[str]: + return self._included.copy() + + def excluded(self) -> Set[str]: + return self._excluded.copy() + + def get_pytest_filter(self) -> str: + if self._included: + return " or ".join(sorted(self._included)) + elif self._excluded: + return f"not ({' and '.join(sorted(self._excluded))})" + else: + return "" + + def contains(self, test: "TestRun") -> bool: + if self.test_file != test.test_file: + return False + + if self.is_full_file(): + return True # self contains all tests + + if test.is_full_file(): + return False # test contains all tests, but self doesn't + + # Does self exclude a subset of what test excldes? + if test._excluded: + return test._excluded.issubset(self._excluded) + + # Does self include everything test includes? + if self._included: + return test._included.issubset(self._included) + + # Getting to here means that test includes and self excludes + # Does self exclude anything test includes? 
+    def contains(self, test: "TestRun") -> bool:
+        if self.test_file != test.test_file:
+            return False
+
+        if self.is_full_file():
+            return True  # self contains all tests
+
+        if test.is_full_file():
+            return False  # test contains all tests, but self doesn't
+
+        # Does self exclude a subset of what test excludes?
+        if test._excluded:
+            return self._excluded.issubset(test._excluded)
+
+        # Does self include everything test includes?
+        if self._included:
+            return test._included.issubset(self._included)
+
+        # Getting to here means that test includes and self excludes
+        # Does self exclude anything test includes? If not, we're good
+        return not self._excluded.intersection(test._included)
+
+    def __copy__(self) -> "TestRun":
+        return TestRun(self.test_file, excluded=self._excluded, included=self._included)
+
+    def __bool__(self) -> bool:
+        return not self.is_empty()
+
+    def __repr__(self) -> str:
+        r: str = f"TestRun({self.test_file}"
+        r += f", included: {self._included}" if self._included else ""
+        r += f", excluded: {self._excluded}" if self._excluded else ""
+        r += ")"
+        return r
+
+    def __str__(self) -> str:
+        if self.is_empty():
+            return "Empty"
+
+        pytest_filter = self.get_pytest_filter()
+        if pytest_filter:
+            return self.test_file + ", " + pytest_filter
+        return self.test_file
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, TestRun):
+            return False
+
+        ret = self.test_file == other.test_file
+        ret = ret and self._included == other._included
+        ret = ret and self._excluded == other._excluded
+        return ret
+
+    def __ior__(  # noqa: PYI034 Method returns `self`
+        self, other: "TestRun"
+    ) -> "TestRun":
+        res = self | other
+        self.test_file = res.test_file
+        self._included = res._included
+        self._excluded = res._excluded
+
+        return self
+
+    def __or__(self, other: "TestRun") -> "TestRun":
+        """
+        To OR/union test runs means to run all the tests that either of the two runs specifies.
+        """
+
+        # Is either run empty?
+        if self.is_empty():
+            return other
+        if other.is_empty():
+            return copy(self)
+
+        # If not, ensure we have the same file
+        assert (
+            self.test_file == other.test_file
+        ), f"Can't union {self} and {other} because they're not the same test file"
+
+        # 4 possible cases:
+
+        # 1. Either run is the full file, so the union is everything
+        if self.is_full_file() or other.is_full_file():
+            # The union is the whole file
+            return TestRun(self.test_file)
+
+        # 2. Both runs only include what's in _included, so the union is the union of the two sets
+        if self._included and other._included:
+            return TestRun(
+                self.test_file, included=self._included.union(other._included)
+            )
+
+        # 3. Both runs only exclude what's in _excluded, so the union is the intersection of the two sets
+        if self._excluded and other._excluded:
+            return TestRun(
+                self.test_file, excluded=self._excluded.intersection(other._excluded)
+            )
+
+        # 4. One run includes and the other excludes, so we keep excluding the _excluded set minus
+        #    whatever is in the _included set
+        included = self._included | other._included
+        excluded = self._excluded | other._excluded
+        return TestRun(self.test_file, excluded=excluded - included)
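# Illustrative sketch (not part of the diff); file and class names are hypothetical.
# Union combines what either run would execute:
a = TestRun("test_foo.py::TestA")
b = TestRun("test_foo.py::TestB")
assert (a | b) == TestRun("test_foo.py", included={"TestA", "TestB"})  # case 2

# Excluding a class and then OR-ing in the run for that same class gives back the full file:
assert (TestRun("test_foo.py", excluded={"TestA"}) | a) == TestRun("test_foo.py")  # case 4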
+    def __isub__(  # noqa: PYI034 Method returns `self`
+        self, other: "TestRun"
+    ) -> "TestRun":
+        res = self - other
+        self.test_file = res.test_file
+        self._included = res._included
+        self._excluded = res._excluded
+        return self
+
+    def __sub__(self, other: "TestRun") -> "TestRun":
+        """
+        To subtract test runs means to run all the tests in the first run except for what the second run specifies.
+        """
+
+        # Is either run empty?
+        if self.is_empty():
+            return TestRun.empty()
+        if other.is_empty():
+            return copy(self)
+
+        # Are you trying to subtract tests that don't even exist in this test run?
+        if self.test_file != other.test_file:
+            return copy(self)
+
+        # You're subtracting everything?
+        if other.is_full_file():
+            return TestRun.empty()
+
+        def return_inclusions_or_empty(inclusions: Set[str]) -> TestRun:
+            if inclusions:
+                return TestRun(self.test_file, included=inclusions)
+            return TestRun.empty()
+
+        if other._included:
+            if self._included:
+                return return_inclusions_or_empty(self._included - other._included)
+            else:
+                return TestRun(
+                    self.test_file, excluded=self._excluded | other._included
+                )
+        else:
+            if self._included:
+                return return_inclusions_or_empty(self._included & other._excluded)
+            else:
+                return return_inclusions_or_empty(other._excluded - self._excluded)
+
+    def __and__(self, other: "TestRun") -> "TestRun":
+        if self.test_file != other.test_file:
+            return TestRun.empty()
+
+        return (self | other) - (self - other) - (other - self)
+
+
+TestRuns = Tuple[TestRun, ...]
+
+
+@total_ordering
+class ShardedTest:
+    test: TestRun
+    shard: int
+    num_shards: int
+    time: Optional[float]  # In seconds
+
+    def __init__(
+        self,
+        test: Union[TestRun, str],
+        shard: int,
+        num_shards: int,
+        time: Optional[float] = None,
+    ) -> None:
+        if isinstance(test, str):
+            test = TestRun(test)
+        self.test = test
+        self.shard = shard
+        self.num_shards = num_shards
+        self.time = time
+
+    @property
+    def name(self) -> str:
+        return self.test.test_file
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, ShardedTest):
+            return False
+        return (
+            self.test == other.test
+            and self.shard == other.shard
+            and self.num_shards == other.num_shards
+            and self.time == other.time
+        )
+
+    def __repr__(self) -> str:
+        ret = f"{self.test} {self.shard}/{self.num_shards}"
+        if self.time:
+            ret += f" ({self.time}s)"
+
+        return ret
+
+    def __lt__(self, other: object) -> bool:
+        if not isinstance(other, ShardedTest):
+            raise NotImplementedError
+
+        # This is how the list was implicitly sorted when it was a NamedTuple
+        if self.name != other.name:
+            return self.name < other.name
+        if self.shard != other.shard:
+            return self.shard < other.shard
+        if self.num_shards != other.num_shards:
+            return self.num_shards < other.num_shards
+
+        # None is the smallest value
+        if self.time is None:
+            return True
+        if other.time is None:
+            return False
+        return self.time < other.time
+
+    def __str__(self) -> str:
+        return f"{self.test} {self.shard}/{self.num_shards}"
+
+    def get_time(self) -> float:
+        return self.time or 0
+
+    def get_pytest_args(self) -> List[str]:
+        pytest_filter = self.test.get_pytest_filter()
+        if pytest_filter:
+            return ["-k", pytest_filter]
+        return []
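# Illustrative sketch (not part of the diff); file and class names are hypothetical.
# Subtracting a class from a whole-file run leaves an exclusion, and ShardedTest turns
# that exclusion into pytest arguments:
remainder = TestRun("test_foo.py") - TestRun("test_foo.py::TestBar")
assert remainder == TestRun("test_foo.py", excluded={"TestBar"})

shard = ShardedTest(remainder, shard=1, num_shards=1)
assert shard.get_pytest_args() == ["-k", "not (TestBar)"]
assert str(shard) == "test_foo.py, not (TestBar) 1/1"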
diff --git a/tools/testing/test_selections.py b/tools/testing/test_selections.py
index 590f915053c..340db6f499b 100644
--- a/tools/testing/test_selections.py
+++ b/tools/testing/test_selections.py
@@ -3,9 +3,10 @@
 import os
 import subprocess
 from pathlib import Path
-from typing import Callable, Dict, List, NamedTuple, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple
 
 from tools.stats.import_test_stats import get_disabled_tests, get_slow_tests
+from tools.testing.test_run import ShardedTest, TestRun
 
 REPO_ROOT = Path(__file__).resolve().parent.parent.parent
 
@@ -42,19 +43,6 @@ if IS_ROCM and not IS_MEM_LEAK_CHECK:
     NUM_PROCS = 1
 
 
-class ShardedTest(NamedTuple):
-    name: str
-    shard: int
-    num_shards: int
-    time: Optional[float]  # In seconds
-
-    def __str__(self) -> str:
-        return f"{self.name} {self.shard}/{self.num_shards}"
-
-    def get_time(self) -> float:
-        return self.time or 0
-
-
 class ShardJob:
     def __init__(self) -> None:
         self.serial: List[ShardedTest] = []
@@ -73,11 +61,51 @@ class ShardJob:
 
 
 def get_with_pytest_shard(
-    tests: List[str], test_file_times: Dict[str, float]
+    tests: Sequence[TestRun],
+    test_file_times: Dict[str, float],
+    test_class_times: Optional[Dict[str, Dict[str, float]]],
 ) -> List[ShardedTest]:
     sharded_tests: List[ShardedTest] = []
+
+    def get_duration_for_classes(
+        test_file: str, test_classes: Set[str]
+    ) -> Optional[float]:
+        duration: float = 0
+        if not test_class_times:
+            return None
+
+        for test_class in test_classes:
+            class_duration = test_class_times.get(test_file, {}).get(test_class, None)
+            if class_duration is None:
+                return None
+            if class_duration:
+                duration += class_duration
+        return duration
+
     for test in tests:
-        duration = test_file_times.get(test, None)
+        file_duration = test_file_times.get(test.test_file, None)
+        included = test.included()
+        excluded = test.excluded()
+        included_classes_duration = get_duration_for_classes(test.test_file, included)
+        excluded_classes_duration = get_duration_for_classes(test.test_file, excluded)
+
+        if included:
+            # If we don't have the time for all included classes, our upper bound is the file duration
+            duration = (
+                included_classes_duration
+                if included_classes_duration is not None
+                else file_duration
+            )
+        elif excluded:
+            # If we don't have the time for all excluded classes, our upper bound is the file duration
+            duration = (
+                file_duration - excluded_classes_duration
+                if excluded_classes_duration is not None and file_duration is not None
+                else file_duration
+            )
+        else:
+            duration = file_duration
+
         if duration and duration > THRESHOLD:
             num_shards = math.ceil(duration / THRESHOLD)
             for i in range(num_shards):
@@ -91,21 +119,27 @@ def get_with_pytest_shard(
 
 def calculate_shards(
     num_shards: int,
-    tests: List[str],
+    tests: Sequence[TestRun],
     test_file_times: Dict[str, float],
+    test_class_times: Optional[Dict[str, Dict[str, float]]],
     must_serial: Optional[Callable[[str], bool]] = None,
     sort_by_time: bool = True,
 ) -> List[Tuple[float, List[ShardedTest]]]:
     must_serial = must_serial or (lambda x: True)
 
-    known_tests = tests
-    unknown_tests = []
+    known_tests: Sequence[TestRun] = tests
+    unknown_tests: Sequence[TestRun] = []
 
     if sort_by_time:
-        known_tests = [x for x in tests if x in test_file_times]
+        known_tests = [
+            x
+            for x in tests
+            if x.test_file in test_file_times
+            or (test_class_times and x.test_file in test_class_times)
+        ]
         unknown_tests = [x for x in tests if x not in known_tests]
 
-    known_tests = get_with_pytest_shard(known_tests, test_file_times)
+    known_tests = get_with_pytest_shard(known_tests, test_file_times, test_class_times)
 
     if sort_by_time:
        known_tests = sorted(known_tests, key=lambda j: j.get_time(), reverse=True)
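# Illustrative sketch (not part of the diff); file names, class names, and timings are
# made up. With class-level times available, the duration estimates used for sharding can
# be narrower than the whole-file time:
file_times = {"test_foo.py": 100.0}
class_times = {"test_foo.py": {"TestA": 60.0, "TestB": 30.0}}

only_a = TestRun("test_foo.py::TestA")                  # estimated at 60s (class time)
without_a = TestRun("test_foo.py", excluded={"TestA"})  # estimated at 100 - 60 = 40s

shards = calculate_shards(
    num_shards=2,
    tests=[only_a, without_a],
    test_file_times=file_times,
    test_class_times=class_times,
    must_serial=lambda name: False,  # let everything run in parallel for this example
)
# Each entry in `shards` is a (total_estimated_time, [ShardedTest, ...]) pair.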