[TD] Enable Test Class granularity on heuristics (#112161)
Changes the heuristic framework to support prioritizing individual classes within a test file. Components of this included:

- Updating TestPrioritizations to accept individual test classes being prioritized. Previously, when a heuristic wanted to prioritize a test file it would pass in the test's name; now, to prioritize a class within a test, it uses the notation "test::classname".
- Changes are fully backwards compatible with existing heuristics.
- Test sharding now supports sharding individual tests (for when they're prioritized).
- When a TestClass is prioritized, we pass the appropriate "-k" flags down to pytest.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112161
Approved by: https://github.com/huydhn
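For reference, a minimal sketch of the new notation and the pytest flags it produces, based on the unit tests added in this PR (the file and class names below are made up for illustration):

from tools.testing.test_run import ShardedTest, TestRun

# A whole file, a single class within it, and the file minus that class:
whole_file = TestRun("test_foo")
one_class = TestRun("test_foo::TestBar")                  # equivalent to included=["TestBar"]
rest_of_file = TestRun("test_foo", excluded=["TestBar"])

# When a class-level TestRun is run, the matching pytest "-k" filter is emitted:
sharded = ShardedTest(TestRun("test_foo", included=["TestBar", "TestBaz"]), 1, 1)
print(sharded.get_pytest_args())  # expected: ["-k", "TestBar or TestBaz"], per the get_pytest_args unit test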
This commit is contained in:
parent
5cd1208415
commit
a5641bc56b
|
|
@ -15,7 +15,7 @@ import tempfile
|
|||
import time
|
||||
from contextlib import ExitStack
|
||||
from datetime import datetime
|
||||
from typing import Any, cast, Dict, List, NamedTuple, Optional, Tuple, Union
|
||||
from typing import Any, cast, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
|
||||
|
||||
import pkg_resources
|
||||
|
||||
|
|
@ -40,12 +40,18 @@ REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent
|
|||
|
||||
# using tools/ to optimize test run.
|
||||
sys.path.insert(0, str(REPO_ROOT))
|
||||
from tools.stats.import_test_stats import ADDITIONAL_CI_FILES_FOLDER, TEST_TIMES_FILE
|
||||
from tools.stats.import_test_stats import (
|
||||
ADDITIONAL_CI_FILES_FOLDER,
|
||||
TEST_CLASS_TIMES_FILE,
|
||||
TEST_TIMES_FILE,
|
||||
)
|
||||
from tools.stats.upload_metrics import add_global_metric, emit_metric
|
||||
from tools.testing.target_determination.determinator import (
|
||||
AggregatedHeuristics,
|
||||
get_test_prioritizations,
|
||||
)
|
||||
|
||||
from tools.testing.test_run import TestRun
|
||||
from tools.testing.test_selections import (
|
||||
calculate_shards,
|
||||
get_test_case_configs,
|
||||
|
|
@ -479,7 +485,7 @@ def get_executable_command(options, disable_coverage=False, is_cpp_test=False):
|
|||
|
||||
|
||||
def run_test(
|
||||
test_module,
|
||||
test_module: ShardedTest,
|
||||
test_directory,
|
||||
options,
|
||||
launcher_cmd=None,
|
||||
|
|
@ -488,14 +494,9 @@ def run_test(
|
|||
) -> int:
|
||||
maybe_set_hip_visible_devies()
|
||||
unittest_args = options.additional_unittest_args.copy()
|
||||
test_file = test_module
|
||||
test_file = test_module.name
|
||||
stepcurrent_key = test_file
|
||||
|
||||
use_sharded_test = False
|
||||
if isinstance(test_file, ShardedTest):
|
||||
test_file = test_module.name
|
||||
use_sharded_test = True
|
||||
|
||||
is_distributed_test = test_file.startswith(DISTRIBUTED_TEST_PREFIX)
|
||||
is_cpp_test = test_file.startswith(CPP_TEST_PREFIX)
|
||||
# NB: Rerun disabled tests depends on pytest-flakefinder and it doesn't work with
|
||||
|
|
@ -507,17 +508,16 @@ def run_test(
|
|||
)
|
||||
return 0
|
||||
|
||||
if use_sharded_test:
|
||||
if is_cpp_test:
|
||||
stepcurrent_key = test_file
|
||||
else:
|
||||
unittest_args.extend(
|
||||
[
|
||||
f"--shard-id={test_module.shard - 1}",
|
||||
f"--num-shards={test_module.num_shards}",
|
||||
]
|
||||
)
|
||||
stepcurrent_key = f"{test_file}_{test_module.shard - 1}"
|
||||
if is_cpp_test:
|
||||
stepcurrent_key = test_file
|
||||
else:
|
||||
unittest_args.extend(
|
||||
[
|
||||
f"--shard-id={test_module.shard - 1}",
|
||||
f"--num-shards={test_module.num_shards}",
|
||||
]
|
||||
)
|
||||
stepcurrent_key = f"{test_file}_{test_module.shard - 1}"
|
||||
|
||||
if options.verbose:
|
||||
unittest_args.append(f'-{"v"*options.verbose}') # in case of pytest
|
||||
|
|
@ -541,6 +541,7 @@ def run_test(
|
|||
is_distributed_test=is_distributed_test,
|
||||
)
|
||||
)
|
||||
unittest_args.extend(test_module.get_pytest_args())
|
||||
unittest_args = [arg if arg != "-f" else "-x" for arg in unittest_args]
|
||||
|
||||
# TODO: These features are not available for C++ test yet
|
||||
|
|
@ -706,7 +707,7 @@ def _test_cpp_extensions_aot(test_directory, options, use_ninja):
|
|||
|
||||
assert install_directory, "install_directory must not be empty"
|
||||
os.environ["PYTHONPATH"] = os.pathsep.join([install_directory, python_path])
|
||||
return run_test(test_module, test_directory, options)
|
||||
return run_test(ShardedTest(test_module, 1, 1), test_directory, options)
|
||||
finally:
|
||||
os.environ["PYTHONPATH"] = python_path
|
||||
if os.path.exists(test_directory + "/" + test_module + ".py"):
|
||||
|
|
@ -1424,14 +1425,14 @@ def get_selected_tests(options) -> List[str]:
|
|||
return selected_tests
|
||||
|
||||
|
||||
def download_test_times(
|
||||
file: str = ADDITIONAL_CI_FILES_FOLDER / TEST_TIMES_FILE,
|
||||
) -> Dict[str, float]:
|
||||
# Download previous test times to make sharding decisions
|
||||
def load_test_times_from_file(
|
||||
file: str,
|
||||
) -> Dict[str, Any]:
|
||||
# Load previous test times to make sharding decisions
|
||||
path = os.path.join(str(REPO_ROOT), file)
|
||||
if not os.path.exists(path):
|
||||
print_to_stderr(
|
||||
"::warning:: Failed to find test times file. Using round robin sharding."
|
||||
f"::warning:: Failed to find test times file `{path}`. Using round robin sharding."
|
||||
)
|
||||
return {}
|
||||
|
||||
|
|
@ -1456,6 +1457,18 @@ def download_test_times(
|
|||
return test_times_file["default"]["default"]
|
||||
|
||||
|
||||
def load_test_file_times(
|
||||
file: str = ADDITIONAL_CI_FILES_FOLDER / TEST_TIMES_FILE,
|
||||
) -> Dict[str, float]:
|
||||
return cast(Dict[str, float], load_test_times_from_file(file))
|
||||
|
||||
|
||||
def load_test_class_times(
|
||||
file: str = ADDITIONAL_CI_FILES_FOLDER / TEST_CLASS_TIMES_FILE,
|
||||
) -> Dict[str, Dict[str, float]]:
|
||||
return cast(Dict[str, Dict[str, float]], load_test_times_from_file(file))
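# Illustrative shapes only (assumed from the annotations above and the fixtures
# in the sharding unit tests later in this diff); "test_foo" and the class
# names are made up:
example_file_times: Dict[str, float] = {"test_foo": 120.5}  # seconds per test file
example_class_times: Dict[str, Dict[str, float]] = {
    "test_foo": {"TestBar": 90.0, "TestBaz": 30.5},  # seconds per class within the file
}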
|
||||
|
||||
|
||||
def get_sharding_opts(options) -> Tuple[int, int]:
|
||||
which_shard, num_shards = 1, 1
|
||||
if options.shard:
|
||||
|
|
@ -1471,8 +1484,9 @@ def get_sharding_opts(options) -> Tuple[int, int]:
|
|||
|
||||
def do_sharding(
|
||||
options,
|
||||
selected_tests: List[str],
|
||||
selected_tests: Sequence[TestRun],
|
||||
test_file_times: Dict[str, float],
|
||||
test_class_times: Dict[str, Dict[str, float]],
|
||||
sort_by_time: bool = True,
|
||||
) -> List[ShardedTest]:
|
||||
which_shard, num_shards = get_sharding_opts(options)
|
||||
|
|
@ -1482,6 +1496,7 @@ def do_sharding(
|
|||
num_shards,
|
||||
selected_tests,
|
||||
test_file_times,
|
||||
test_class_times=test_class_times,
|
||||
must_serial=must_serial,
|
||||
sort_by_time=sort_by_time,
|
||||
)
|
||||
|
|
@ -1492,16 +1507,16 @@ def do_sharding(
|
|||
|
||||
|
||||
class TestFailure(NamedTuple):
|
||||
test: str
|
||||
test: TestRun
|
||||
message: str
|
||||
|
||||
|
||||
def run_test_module(
|
||||
test: Union[ShardedTest, str], test_directory: str, options
|
||||
test: ShardedTest, test_directory: str, options
|
||||
) -> Optional[TestFailure]:
|
||||
maybe_set_hip_visible_devies()
|
||||
|
||||
test_name = test.name if isinstance(test, ShardedTest) else test
|
||||
test_name = test.name
|
||||
|
||||
# Printing the date here can help diagnose which tests are slow
|
||||
print_to_stderr(f"Running {str(test)} ... [{datetime.now()}]")
|
||||
|
|
@ -1519,7 +1534,7 @@ def run_test_module(
|
|||
# return code -N, where N is the signal number.
|
||||
signal_name = SIGNALS_TO_NAMES_DICT[-return_code]
|
||||
message += f" Received signal: {signal_name}"
|
||||
return TestFailure(test_name, message)
|
||||
return TestFailure(test.test, message)
|
||||
|
||||
|
||||
def run_tests(
|
||||
|
|
@ -1671,6 +1686,9 @@ def main():
|
|||
"cpp": options.cpp,
|
||||
}
|
||||
|
||||
test_file_times_dict = load_test_file_times()
|
||||
test_class_times_dict = load_test_class_times()
|
||||
|
||||
class TestBatch:
|
||||
"""Defines a set of tests with similar priority that should be run together on the current shard"""
|
||||
|
||||
|
|
@ -1678,11 +1696,17 @@ def main():
|
|||
sharded_tests: List[ShardedTest]
|
||||
failures: List[TestFailure]
|
||||
|
||||
def __init__(self, name: str, raw_tests: List[str], should_sort_shard: bool):
|
||||
def __init__(
|
||||
self, name: str, raw_tests: Sequence[TestRun], should_sort_shard: bool
|
||||
):
|
||||
self.name = name
|
||||
self.failures = []
|
||||
self.sharded_tests = do_sharding(
|
||||
options, raw_tests, test_times_dict, sort_by_time=should_sort_shard
|
||||
options,
|
||||
raw_tests,
|
||||
test_file_times_dict,
|
||||
test_class_times_dict,
|
||||
sort_by_time=should_sort_shard,
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
|
|
@ -1697,7 +1721,6 @@ def main():
|
|||
)
|
||||
return s.strip()
|
||||
|
||||
test_times_dict = download_test_times(ADDITIONAL_CI_FILES_FOLDER / TEST_TIMES_FILE)
|
||||
test_batches: List[TestBatch] = []
|
||||
|
||||
# Each batch will be run sequentially
|
||||
|
|
@ -1749,7 +1772,7 @@ def main():
|
|||
test_batch.sharded_tests, test_directory, options, test_batch.failures
|
||||
)
|
||||
metrics_dict[f"{test_batch.name}_failures"] = [
|
||||
x.test for x in test_batch.failures
|
||||
str(x.test) for x in test_batch.failures
|
||||
]
|
||||
|
||||
finally:
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import json
|
|||
import pathlib
|
||||
import sys
|
||||
import unittest
|
||||
from typing import Any, Dict, Set
|
||||
from typing import Any, Dict, Optional, Set
|
||||
from unittest import mock
|
||||
|
||||
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent
|
||||
|
|
@ -12,12 +12,15 @@ try:
|
|||
sys.path.append(str(REPO_ROOT))
|
||||
|
||||
from tools.testing.target_determination.determinator import (
|
||||
AggregatedHeuristics,
|
||||
get_test_prioritizations,
|
||||
TestPrioritizations,
|
||||
)
|
||||
from tools.testing.target_determination.heuristics import HEURISTICS
|
||||
from tools.testing.target_determination.heuristics.previously_failed_in_pr import (
|
||||
_get_previously_failing_tests,
|
||||
)
|
||||
from tools.testing.test_run import TestRun, TestRuns
|
||||
|
||||
except ModuleNotFoundError:
|
||||
print("Can't import required modules, exiting")
|
||||
|
|
@ -31,7 +34,37 @@ def mocked_file(contents: Dict[Any, Any]) -> io.IOBase:
|
|||
return file_object
|
||||
|
||||
|
||||
class TestParsePrevTests(unittest.TestCase):
|
||||
class HeuristicsTestMixin(unittest.TestCase):
|
||||
def assert_heuristics_match(
|
||||
self,
|
||||
test_prioritizations: TestPrioritizations,
|
||||
expected_high_tests: Optional[TestRuns] = None,
|
||||
expected_probable_tests: Optional[TestRuns] = None,
|
||||
expected_unranked_tests: Optional[TestRuns] = None,
|
||||
) -> None:
|
||||
if expected_unranked_tests:
|
||||
self.assertTupleEqual(
|
||||
test_prioritizations.get_unranked_relevance_tests(),
|
||||
expected_unranked_tests,
|
||||
"Unranked tests differ",
|
||||
)
|
||||
|
||||
if expected_probable_tests:
|
||||
self.assertTupleEqual(
|
||||
test_prioritizations.get_probable_relevance_tests(),
|
||||
expected_probable_tests,
|
||||
"Probable relevance tests differ",
|
||||
)
|
||||
|
||||
if expected_high_tests:
|
||||
self.assertTupleEqual(
|
||||
test_prioritizations.get_high_relevance_tests(),
|
||||
expected_high_tests,
|
||||
"High relevance tests differ",
|
||||
)
|
||||
|
||||
|
||||
class TestParsePrevTests(HeuristicsTestMixin):
|
||||
@mock.patch("pathlib.Path.exists", return_value=False)
|
||||
def test_cache_does_not_exist(self, mock_exists: Any) -> None:
|
||||
expected_failing_test_files: Set[str] = set()
|
||||
|
|
@ -94,18 +127,257 @@ class TestParsePrevTests(unittest.TestCase):
|
|||
tests
|
||||
).get_aggregated_priorities()
|
||||
|
||||
self.assertTupleEqual(
|
||||
expected_prioritizations.get_high_relevance_tests(),
|
||||
test_prioritizations.get_high_relevance_tests(),
|
||||
self.assert_heuristics_match(
|
||||
test_prioritizations,
|
||||
expected_high_tests=expected_prioritizations.get_high_relevance_tests(),
|
||||
expected_probable_tests=expected_prioritizations.get_probable_relevance_tests(),
|
||||
expected_unranked_tests=expected_prioritizations.get_unranked_relevance_tests(),
|
||||
)
|
||||
self.assertTupleEqual(
|
||||
expected_prioritizations.get_probable_relevance_tests(),
|
||||
test_prioritizations.get_probable_relevance_tests(),
|
||||
|
||||
|
||||
class TestInterface(HeuristicsTestMixin):
|
||||
def test_class_prioritization(self) -> None:
|
||||
tests = ["test1", "test2", "test3", "test4", "test5"]
|
||||
|
||||
prioritizations = TestPrioritizations(
|
||||
tests_being_ranked=tests,
|
||||
probable_relevance=["test2::TestFooClass", "test3"],
|
||||
)
|
||||
self.assertTupleEqual(
|
||||
expected_prioritizations.get_unranked_relevance_tests(),
|
||||
test_prioritizations.get_unranked_relevance_tests(),
|
||||
|
||||
expected_probable_tests = tuple(
|
||||
TestRun(test) for test in ["test2::TestFooClass", "test3"]
|
||||
)
|
||||
expected_unranked_tests = (
|
||||
TestRun("test1"),
|
||||
TestRun("test2", excluded=["TestFooClass"]),
|
||||
TestRun("test4"),
|
||||
TestRun("test5"),
|
||||
)
|
||||
|
||||
self.assert_heuristics_match(
|
||||
prioritizations,
|
||||
expected_probable_tests=expected_probable_tests,
|
||||
expected_unranked_tests=expected_unranked_tests,
|
||||
)
|
||||
|
||||
|
||||
class TestAggregatedHeuristics(HeuristicsTestMixin):
|
||||
def test_merging_multiple_test_class_heuristics(self) -> None:
|
||||
tests = ["test1", "test2", "test3", "test4"]
|
||||
|
||||
heuristic1 = TestPrioritizations(
|
||||
tests_being_ranked=tests,
|
||||
probable_relevance=["test2::TestFooClass", "test3"],
|
||||
)
|
||||
|
||||
heuristic2 = TestPrioritizations(
|
||||
tests_being_ranked=tests,
|
||||
high_relevance=["test2::TestFooClass", "test3::TestBarClass"],
|
||||
)
|
||||
|
||||
expected_high_relevance = tuple(
|
||||
TestRun(test) for test in ["test2::TestFooClass", "test3::TestBarClass"]
|
||||
)
|
||||
expected_probable_relevance = (TestRun("test3", excluded=["TestBarClass"]),)
|
||||
expected_unranked_relevance = (
|
||||
TestRun("test1"),
|
||||
TestRun("test2", excluded=["TestFooClass"]),
|
||||
TestRun("test4"),
|
||||
)
|
||||
|
||||
aggregator = AggregatedHeuristics(unranked_tests=tests)
|
||||
aggregator.add_heuristic_results(HEURISTICS[0], heuristic1)
|
||||
aggregator.add_heuristic_results(HEURISTICS[1], heuristic2)
|
||||
|
||||
aggregated_pris = aggregator.get_aggregated_priorities()
|
||||
|
||||
self.assert_heuristics_match(
|
||||
aggregated_pris,
|
||||
expected_high_tests=expected_high_relevance,
|
||||
expected_probable_tests=expected_probable_relevance,
|
||||
expected_unranked_tests=expected_unranked_relevance,
|
||||
)
|
||||
|
||||
def test_merging_file_heuristic_after_class_heuristic(self) -> None:
|
||||
tests = ["test1", "test2", "test3", "test4", "test5"]
|
||||
heuristic1 = TestPrioritizations(
|
||||
tests_being_ranked=tests,
|
||||
high_relevance=["test2::TestFooClass"],
|
||||
)
|
||||
heuristic2 = TestPrioritizations(
|
||||
tests_being_ranked=tests,
|
||||
probable_relevance=["test2", "test3"],
|
||||
)
|
||||
|
||||
expected_aggregated_high_relevance = tuple(
|
||||
TestRun(test) for test in ["test2::TestFooClass"]
|
||||
)
|
||||
expected_aggregated_probable_relevance = (
|
||||
TestRun("test2", excluded=["TestFooClass"]),
|
||||
TestRun("test3"),
|
||||
)
|
||||
expected_aggregated_unranked_relevance = (
|
||||
TestRun("test1"),
|
||||
TestRun("test4"),
|
||||
TestRun("test5"),
|
||||
)
|
||||
|
||||
aggregator = AggregatedHeuristics(unranked_tests=tests)
|
||||
aggregator.add_heuristic_results(HEURISTICS[0], heuristic1)
|
||||
aggregator.add_heuristic_results(HEURISTICS[1], heuristic2)
|
||||
|
||||
aggregated_pris = aggregator.get_aggregated_priorities()
|
||||
|
||||
self.assert_heuristics_match(
|
||||
aggregated_pris,
|
||||
expected_high_tests=expected_aggregated_high_relevance,
|
||||
expected_probable_tests=expected_aggregated_probable_relevance,
|
||||
expected_unranked_tests=expected_aggregated_unranked_relevance,
|
||||
)
|
||||
|
||||
def test_get_test_stats_with_whole_tests(self) -> None:
|
||||
self.maxDiff = None
|
||||
tests = ["test1", "test2", "test3", "test4", "test5"]
|
||||
heuristic1 = TestPrioritizations(
|
||||
tests_being_ranked=tests,
|
||||
high_relevance=["test3", "test4"],
|
||||
)
|
||||
heuristic2 = TestPrioritizations(
|
||||
tests_being_ranked=tests,
|
||||
probable_relevance=["test5"],
|
||||
)
|
||||
|
||||
aggregator = AggregatedHeuristics(unranked_tests=tests)
|
||||
aggregator.add_heuristic_results(HEURISTICS[0], heuristic1)
|
||||
aggregator.add_heuristic_results(HEURISTICS[1], heuristic2)
|
||||
|
||||
expected_test3_stats = {
|
||||
"test_name": "test3",
|
||||
"test_filters": "",
|
||||
"without_heuristics": {
|
||||
"relevance_group": "UNRANKED",
|
||||
"order_within_relevance_group": 2,
|
||||
"num_tests_in_relevance_group": 5,
|
||||
"order_overall": 2,
|
||||
"heuristic_name": "baseline",
|
||||
},
|
||||
"heuristics": [
|
||||
{
|
||||
"relevance_group": "HIGH",
|
||||
"order_within_relevance_group": 0,
|
||||
"num_tests_in_relevance_group": 2,
|
||||
"order_overall": 0,
|
||||
"heuristic_name": HEURISTICS[0].name,
|
||||
"trial_mode": False,
|
||||
},
|
||||
{
|
||||
"relevance_group": "UNRANKED",
|
||||
"order_within_relevance_group": 2,
|
||||
"num_tests_in_relevance_group": 4,
|
||||
"order_overall": 3,
|
||||
"heuristic_name": HEURISTICS[1].name,
|
||||
"trial_mode": False,
|
||||
},
|
||||
],
|
||||
"num_heuristics_prioritized_by": 1,
|
||||
"aggregated": {
|
||||
"relevance_group": "HIGH",
|
||||
"order_within_relevance_group": 0,
|
||||
"num_tests_in_relevance_group": 2,
|
||||
"order_overall": 0,
|
||||
},
|
||||
"aggregated_trial": {
|
||||
"relevance_group": "HIGH",
|
||||
"order_within_relevance_group": 0,
|
||||
"num_tests_in_relevance_group": 2,
|
||||
"order_overall": 0,
|
||||
},
|
||||
"highest_ranking_heuristic": HEURISTICS[0].name,
|
||||
}
|
||||
|
||||
test3_stats = aggregator.get_test_stats(TestRun("test3"))
|
||||
|
||||
self.assertDictEqual(test3_stats, expected_test3_stats)
|
||||
|
||||
def test_get_test_stats_only_contains_allowed_types(self) -> None:
|
||||
self.maxDiff = None
|
||||
tests = ["test1", "test2", "test3", "test4", "test5"]
|
||||
heuristic1 = TestPrioritizations(
|
||||
tests_being_ranked=tests,
|
||||
high_relevance=["test3", "test4"],
|
||||
)
|
||||
heuristic2 = TestPrioritizations(
|
||||
tests_being_ranked=tests,
|
||||
probable_relevance=["test5::classA"],
|
||||
)
|
||||
|
||||
aggregator = AggregatedHeuristics(unranked_tests=tests)
|
||||
aggregator.add_heuristic_results(HEURISTICS[0], heuristic1)
|
||||
aggregator.add_heuristic_results(HEURISTICS[1], heuristic2)
|
||||
|
||||
stats3 = aggregator.get_test_stats(TestRun("test3"))
|
||||
stats5 = aggregator.get_test_stats(TestRun("test5::classA"))
|
||||
|
||||
def assert_valid_dict(dict_contents: Dict[str, Any]) -> None:
|
||||
for key, value in dict_contents.items():
|
||||
self.assertTrue(isinstance(key, str))
|
||||
self.assertTrue(
|
||||
isinstance(value, (str, float, int, list, dict)),
|
||||
f"{value} is not a str, float, or dict",
|
||||
)
|
||||
if isinstance(value, dict):
|
||||
assert_valid_dict(value)
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
assert_valid_dict(item)
|
||||
|
||||
assert_valid_dict(stats3)
|
||||
assert_valid_dict(stats5)
|
||||
|
||||
def test_get_test_stats_gets_rank_for_test_classes(self) -> None:
|
||||
self.maxDiff = None
|
||||
tests = ["test1", "test2", "test3", "test4", "test5"]
|
||||
heuristic1 = TestPrioritizations(
|
||||
tests_being_ranked=tests,
|
||||
high_relevance=["test3", "test4"],
|
||||
)
|
||||
heuristic2 = TestPrioritizations(
|
||||
tests_being_ranked=tests,
|
||||
probable_relevance=["test5::classA"],
|
||||
)
|
||||
|
||||
aggregator = AggregatedHeuristics(unranked_tests=tests)
|
||||
aggregator.add_heuristic_results(HEURISTICS[0], heuristic1)
|
||||
aggregator.add_heuristic_results(HEURISTICS[1], heuristic2)
|
||||
|
||||
statsInclusive = aggregator.get_test_stats(
|
||||
TestRun("test5", included=["classA"])
|
||||
)
|
||||
statsExclusive = aggregator.get_test_stats(
|
||||
TestRun("test5", excluded=["classA"])
|
||||
)
|
||||
|
||||
print("h")
|
||||
# Validate the heuristic level stats are correct
|
||||
self.assertEqual(
|
||||
statsInclusive["heuristics"][1]["order_within_relevance_group"], 0
|
||||
)
|
||||
self.assertEqual(
|
||||
statsInclusive["heuristics"][1]["num_tests_in_relevance_group"], 1
|
||||
)
|
||||
self.assertEqual(statsInclusive["heuristics"][1]["order_overall"], 0)
|
||||
self.assertEqual(statsInclusive["heuristics"][1]["relevance_group"], "PROBABLE")
|
||||
self.assertEqual(statsInclusive["aggregated"]["order_overall"], 2)
|
||||
|
||||
self.assertEqual(
|
||||
statsExclusive["heuristics"][1]["order_within_relevance_group"], 4
|
||||
)
|
||||
self.assertEqual(
|
||||
statsExclusive["heuristics"][1]["num_tests_in_relevance_group"], 5
|
||||
)
|
||||
self.assertEqual(statsExclusive["heuristics"][1]["order_overall"], 5)
|
||||
self.assertEqual(statsExclusive["heuristics"][1]["relevance_group"], "UNRANKED")
|
||||
self.assertEqual(statsExclusive["aggregated"]["order_overall"], 5)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
tools/test/test_test_run.py (new file, 169 lines)
|
|
@ -0,0 +1,169 @@
|
|||
import pathlib
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent
|
||||
try:
|
||||
# using tools/ to optimize test run.
|
||||
sys.path.append(str(REPO_ROOT))
|
||||
from tools.testing.test_run import ShardedTest, TestRun
|
||||
except ModuleNotFoundError:
|
||||
print("Can't import required modules, exiting")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
class TestTestRun(unittest.TestCase):
|
||||
def test_union_with_full_run(self) -> None:
|
||||
run1 = TestRun("foo")
|
||||
run2 = TestRun("foo::bar")
|
||||
|
||||
self.assertEqual(run1 | run2, run1)
|
||||
self.assertEqual(run2 | run1, run1)
|
||||
|
||||
def test_union_with_inclusions(self) -> None:
|
||||
run1 = TestRun("foo::bar")
|
||||
run2 = TestRun("foo::baz")
|
||||
|
||||
expected = TestRun("foo")
|
||||
expected._included.add("bar")
|
||||
expected._included.add("baz")
|
||||
|
||||
self.assertEqual(run1 | run2, expected)
|
||||
self.assertEqual(run2 | run1, expected)
|
||||
|
||||
def test_union_with_non_overlapping_exclusions(self) -> None:
|
||||
run1 = TestRun("foo", excluded=["bar"])
|
||||
run2 = TestRun("foo", excluded=["baz"])
|
||||
|
||||
expected = TestRun("foo")
|
||||
|
||||
self.assertEqual(run1 | run2, expected)
|
||||
self.assertEqual(run2 | run1, expected)
|
||||
|
||||
def test_union_with_overlapping_exclusions(self) -> None:
|
||||
run1 = TestRun("foo", excluded=["bar", "car"])
|
||||
run2 = TestRun("foo", excluded=["bar", "caz"])
|
||||
|
||||
expected = TestRun("foo", excluded=["bar"])
|
||||
|
||||
self.assertEqual(run1 | run2, expected)
|
||||
self.assertEqual(run2 | run1, expected)
|
||||
|
||||
def test_union_with_mixed_inclusion_exclusions(self) -> None:
|
||||
run1 = TestRun("foo", excluded=["baz", "car"])
|
||||
run2 = TestRun("foo", included=["baz"])
|
||||
|
||||
expected = TestRun("foo", excluded=["car"])
|
||||
|
||||
self.assertEqual(run1 | run2, expected)
|
||||
self.assertEqual(run2 | run1, expected)
|
||||
|
||||
def test_union_with_mixed_files_fails(self) -> None:
|
||||
run1 = TestRun("foo")
|
||||
run2 = TestRun("bar")
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
run1 | run2
|
||||
|
||||
def test_union_with_empty_file_yields_orig_file(self) -> None:
|
||||
run1 = TestRun("foo")
|
||||
run2 = TestRun.empty()
|
||||
|
||||
self.assertEqual(run1 | run2, run1)
|
||||
self.assertEqual(run2 | run1, run1)
|
||||
|
||||
def test_subtracting_full_run_fails(self) -> None:
|
||||
run1 = TestRun("foo::bar")
|
||||
run2 = TestRun("foo")
|
||||
|
||||
self.assertEqual(run1 - run2, TestRun.empty())
|
||||
|
||||
def test_subtracting_empty_file_yields_orig_file(self) -> None:
|
||||
run1 = TestRun("foo")
|
||||
run2 = TestRun.empty()
|
||||
|
||||
self.assertEqual(run1 - run2, run1)
|
||||
self.assertEqual(run2 - run1, TestRun.empty())
|
||||
|
||||
def test_empty_is_falsey(self) -> None:
|
||||
self.assertFalse(TestRun.empty())
|
||||
|
||||
def test_subtracting_inclusion_from_full_run(self) -> None:
|
||||
run1 = TestRun("foo")
|
||||
run2 = TestRun("foo::bar")
|
||||
|
||||
expected = TestRun("foo", excluded=["bar"])
|
||||
|
||||
self.assertEqual(run1 - run2, expected)
|
||||
|
||||
def test_subtracting_inclusion_from_overlapping_inclusion(self) -> None:
|
||||
run1 = TestRun("foo", included=["bar", "baz"])
|
||||
run2 = TestRun("foo::baz")
|
||||
|
||||
self.assertEqual(run1 - run2, TestRun("foo", included=["bar"]))
|
||||
|
||||
def test_subtracting_inclusion_from_nonoverlapping_inclusion(self) -> None:
|
||||
run1 = TestRun("foo", included=["bar", "baz"])
|
||||
run2 = TestRun("foo", included=["car"])
|
||||
|
||||
self.assertEqual(run1 - run2, TestRun("foo", included=["bar", "baz"]))
|
||||
|
||||
def test_subtracting_exclusion_from_full_run(self) -> None:
|
||||
run1 = TestRun("foo")
|
||||
run2 = TestRun("foo", excluded=["bar"])
|
||||
|
||||
self.assertEqual(run1 - run2, TestRun("foo", included=["bar"]))
|
||||
|
||||
def test_subtracting_exclusion_from_superset_exclusion(self) -> None:
|
||||
run1 = TestRun("foo", excluded=["bar", "baz"])
|
||||
run2 = TestRun("foo", excluded=["baz"])
|
||||
|
||||
self.assertEqual(run1 - run2, TestRun.empty())
|
||||
self.assertEqual(run2 - run1, TestRun("foo", included=["bar"]))
|
||||
|
||||
def test_subtracting_exclusion_from_nonoverlapping_exclusion(self) -> None:
|
||||
run1 = TestRun("foo", excluded=["bar", "baz"])
|
||||
run2 = TestRun("foo", excluded=["car"])
|
||||
|
||||
self.assertEqual(run1 - run2, TestRun("foo", included=["car"]))
|
||||
self.assertEqual(run2 - run1, TestRun("foo", included=["bar", "baz"]))
|
||||
|
||||
def test_subtracting_inclusion_from_exclusion_without_overlaps(self) -> None:
|
||||
run1 = TestRun("foo", excluded=["bar", "baz"])
|
||||
run2 = TestRun("foo", included=["bar"])
|
||||
|
||||
self.assertEqual(run1 - run2, run1)
|
||||
self.assertEqual(run2 - run1, run2)
|
||||
|
||||
def test_subtracting_inclusion_from_exclusion_with_overlaps(self) -> None:
|
||||
run1 = TestRun("foo", excluded=["bar", "baz"])
|
||||
run2 = TestRun("foo", included=["bar", "car"])
|
||||
|
||||
self.assertEqual(run1 - run2, TestRun("foo", excluded=["bar", "baz", "car"]))
|
||||
self.assertEqual(run2 - run1, TestRun("foo", included=["bar"]))
|
||||
|
||||
def test_and(self) -> None:
|
||||
run1 = TestRun("foo", included=["bar", "baz"])
|
||||
run2 = TestRun("foo", included=["bar", "car"])
|
||||
|
||||
self.assertEqual(run1 & run2, TestRun("foo", included=["bar"]))
|
||||
|
||||
def test_and_exclusions(self) -> None:
|
||||
run1 = TestRun("foo", excluded=["bar", "baz"])
|
||||
run2 = TestRun("foo", excluded=["bar", "car"])
|
||||
|
||||
self.assertEqual(run1 & run2, TestRun("foo", excluded=["bar", "baz", "car"]))
|
||||
|
||||
|
||||
class TestShardedTest(unittest.TestCase):
|
||||
def test_get_pytest_args(self) -> None:
|
||||
test = TestRun("foo", included=["bar", "baz"])
|
||||
sharded_test = ShardedTest(test, 1, 1)
|
||||
|
||||
expected_args = ["-k", "bar or baz"]
|
||||
|
||||
self.assertListEqual(sharded_test.get_pytest_args(), expected_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -9,25 +9,30 @@ REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent
|
|||
try:
|
||||
# using tools/ to optimize test run.
|
||||
sys.path.append(str(REPO_ROOT))
|
||||
from tools.testing.test_selections import calculate_shards, ShardedTest, THRESHOLD
|
||||
from tools.testing.test_run import ShardedTest, TestRun
|
||||
from tools.testing.test_selections import calculate_shards, THRESHOLD
|
||||
except ModuleNotFoundError:
|
||||
print("Can't import required modules, exiting")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def gen_class_times(test_times: Dict[str, float]) -> Dict[str, Dict[str, float]]:
|
||||
return {k: {"class1": v} for k, v in test_times.items()}
|
||||
|
||||
|
||||
class TestCalculateShards(unittest.TestCase):
|
||||
tests: List[str] = [
|
||||
"super_long_test",
|
||||
"long_test1",
|
||||
"long_test2",
|
||||
"normal_test1",
|
||||
"normal_test2",
|
||||
"normal_test3",
|
||||
"short_test1",
|
||||
"short_test2",
|
||||
"short_test3",
|
||||
"short_test4",
|
||||
"short_test5",
|
||||
tests: List[TestRun] = [
|
||||
TestRun("super_long_test"),
|
||||
TestRun("long_test1"),
|
||||
TestRun("long_test2"),
|
||||
TestRun("normal_test1"),
|
||||
TestRun("normal_test2"),
|
||||
TestRun("normal_test3"),
|
||||
TestRun("short_test1"),
|
||||
TestRun("short_test2"),
|
||||
TestRun("short_test3"),
|
||||
TestRun("short_test4"),
|
||||
TestRun("short_test5"),
|
||||
]
|
||||
|
||||
test_times: Dict[str, float] = {
|
||||
|
|
@ -44,6 +49,20 @@ class TestCalculateShards(unittest.TestCase):
|
|||
"short_test5": 0.01,
|
||||
}
|
||||
|
||||
test_class_times: Dict[str, Dict[str, float]] = {
|
||||
"super_long_test": {"class1": 55},
|
||||
"long_test1": {"class1": 1, "class2": 21},
|
||||
"long_test2": {"class1": 10, "class2": 8},
|
||||
"normal_test1": {"class1": 9},
|
||||
"normal_test2": {"class1": 7},
|
||||
"normal_test3": {"class1": 5},
|
||||
"short_test1": {"class1": 1},
|
||||
"short_test2": {"class1": 0.6},
|
||||
"short_test3": {"class1": 0.4},
|
||||
"short_test4": {"class1": 0.3},
|
||||
"short_test5": {"class1": 0.01},
|
||||
}
|
||||
|
||||
def assert_shards_equal(
|
||||
self,
|
||||
expected_shards: List[Tuple[float, List[ShardedTest]]],
|
||||
|
|
@ -58,81 +77,92 @@ class TestCalculateShards(unittest.TestCase):
|
|||
(
|
||||
60.0,
|
||||
[
|
||||
ShardedTest(name="super_long_test", shard=1, num_shards=1, time=55),
|
||||
ShardedTest(name="normal_test3", shard=1, num_shards=1, time=5),
|
||||
ShardedTest(test="super_long_test", shard=1, num_shards=1, time=55),
|
||||
ShardedTest(test="normal_test3", shard=1, num_shards=1, time=5),
|
||||
],
|
||||
),
|
||||
(
|
||||
58.31,
|
||||
[
|
||||
ShardedTest(name="long_test1", shard=1, num_shards=1, time=22),
|
||||
ShardedTest(name="long_test2", shard=1, num_shards=1, time=18),
|
||||
ShardedTest(name="normal_test1", shard=1, num_shards=1, time=9),
|
||||
ShardedTest(name="normal_test2", shard=1, num_shards=1, time=7),
|
||||
ShardedTest(name="short_test1", shard=1, num_shards=1, time=1),
|
||||
ShardedTest(name="short_test2", shard=1, num_shards=1, time=0.6),
|
||||
ShardedTest(name="short_test3", shard=1, num_shards=1, time=0.4),
|
||||
ShardedTest(name="short_test4", shard=1, num_shards=1, time=0.3),
|
||||
ShardedTest(name="short_test5", shard=1, num_shards=1, time=0.01),
|
||||
ShardedTest(test="long_test1", shard=1, num_shards=1, time=22),
|
||||
ShardedTest(test="long_test2", shard=1, num_shards=1, time=18),
|
||||
ShardedTest(test="normal_test1", shard=1, num_shards=1, time=9),
|
||||
ShardedTest(test="normal_test2", shard=1, num_shards=1, time=7),
|
||||
ShardedTest(test="short_test1", shard=1, num_shards=1, time=1),
|
||||
ShardedTest(test="short_test2", shard=1, num_shards=1, time=0.6),
|
||||
ShardedTest(test="short_test3", shard=1, num_shards=1, time=0.4),
|
||||
ShardedTest(test="short_test4", shard=1, num_shards=1, time=0.3),
|
||||
ShardedTest(test="short_test5", shard=1, num_shards=1, time=0.01),
|
||||
],
|
||||
),
|
||||
]
|
||||
self.assert_shards_equal(
|
||||
expected_shards, calculate_shards(2, self.tests, self.test_times)
|
||||
expected_shards,
|
||||
calculate_shards(2, self.tests, self.test_times, self.test_class_times),
|
||||
)
|
||||
|
||||
def test_calculate_1_shard_with_complete_test_times(self) -> None:
|
||||
tests = self.tests.copy()
|
||||
class_test1 = TestRun("long_test1", excluded=["class2"])
|
||||
class_test2 = TestRun("long_test1", included=["class2"])
|
||||
tests.append(class_test1)
|
||||
tests.append(class_test2)
|
||||
|
||||
expected_shards = [
|
||||
(
|
||||
118.31,
|
||||
140.31,
|
||||
[
|
||||
ShardedTest(name="super_long_test", shard=1, num_shards=1, time=55),
|
||||
ShardedTest(name="long_test1", shard=1, num_shards=1, time=22),
|
||||
ShardedTest(name="long_test2", shard=1, num_shards=1, time=18),
|
||||
ShardedTest(name="normal_test1", shard=1, num_shards=1, time=9),
|
||||
ShardedTest(name="normal_test2", shard=1, num_shards=1, time=7),
|
||||
ShardedTest(name="normal_test3", shard=1, num_shards=1, time=5),
|
||||
ShardedTest(name="short_test1", shard=1, num_shards=1, time=1),
|
||||
ShardedTest(name="short_test2", shard=1, num_shards=1, time=0.6),
|
||||
ShardedTest(name="short_test3", shard=1, num_shards=1, time=0.4),
|
||||
ShardedTest(name="short_test4", shard=1, num_shards=1, time=0.3),
|
||||
ShardedTest(name="short_test5", shard=1, num_shards=1, time=0.01),
|
||||
ShardedTest(test="super_long_test", shard=1, num_shards=1, time=55),
|
||||
ShardedTest(test="long_test1", shard=1, num_shards=1, time=22),
|
||||
ShardedTest(class_test2, shard=1, num_shards=1, time=21),
|
||||
ShardedTest(test="long_test2", shard=1, num_shards=1, time=18),
|
||||
ShardedTest(test="normal_test1", shard=1, num_shards=1, time=9),
|
||||
ShardedTest(test="normal_test2", shard=1, num_shards=1, time=7),
|
||||
ShardedTest(test="normal_test3", shard=1, num_shards=1, time=5),
|
||||
ShardedTest(test="short_test1", shard=1, num_shards=1, time=1),
|
||||
ShardedTest(class_test1, shard=1, num_shards=1, time=1),
|
||||
ShardedTest(test="short_test2", shard=1, num_shards=1, time=0.6),
|
||||
ShardedTest(test="short_test3", shard=1, num_shards=1, time=0.4),
|
||||
ShardedTest(test="short_test4", shard=1, num_shards=1, time=0.3),
|
||||
ShardedTest(test="short_test5", shard=1, num_shards=1, time=0.01),
|
||||
],
|
||||
)
|
||||
]
|
||||
self.assert_shards_equal(
|
||||
expected_shards, calculate_shards(1, self.tests, self.test_times)
|
||||
expected_shards,
|
||||
calculate_shards(1, tests, self.test_times, self.test_class_times),
|
||||
)
|
||||
|
||||
def test_calculate_5_shards_with_complete_test_times(self) -> None:
|
||||
expected_shards = [
|
||||
(
|
||||
55.0,
|
||||
[ShardedTest(name="super_long_test", shard=1, num_shards=1, time=55)],
|
||||
[ShardedTest(test="super_long_test", shard=1, num_shards=1, time=55)],
|
||||
),
|
||||
(22.0, [ShardedTest(name="long_test1", shard=1, num_shards=1, time=22)]),
|
||||
(18.0, [ShardedTest(name="long_test2", shard=1, num_shards=1, time=18)]),
|
||||
(22.0, [ShardedTest(test="long_test1", shard=1, num_shards=1, time=22)]),
|
||||
(18.0, [ShardedTest(test="long_test2", shard=1, num_shards=1, time=18)]),
|
||||
(
|
||||
11.31,
|
||||
[
|
||||
ShardedTest(name="normal_test1", shard=1, num_shards=1, time=9),
|
||||
ShardedTest(name="short_test1", shard=1, num_shards=1, time=1),
|
||||
ShardedTest(name="short_test2", shard=1, num_shards=1, time=0.6),
|
||||
ShardedTest(name="short_test3", shard=1, num_shards=1, time=0.4),
|
||||
ShardedTest(name="short_test4", shard=1, num_shards=1, time=0.3),
|
||||
ShardedTest(name="short_test5", shard=1, num_shards=1, time=0.01),
|
||||
ShardedTest(test="normal_test1", shard=1, num_shards=1, time=9),
|
||||
ShardedTest(test="short_test1", shard=1, num_shards=1, time=1),
|
||||
ShardedTest(test="short_test2", shard=1, num_shards=1, time=0.6),
|
||||
ShardedTest(test="short_test3", shard=1, num_shards=1, time=0.4),
|
||||
ShardedTest(test="short_test4", shard=1, num_shards=1, time=0.3),
|
||||
ShardedTest(test="short_test5", shard=1, num_shards=1, time=0.01),
|
||||
],
|
||||
),
|
||||
(
|
||||
12.0,
|
||||
[
|
||||
ShardedTest(name="normal_test2", shard=1, num_shards=1, time=7),
|
||||
ShardedTest(name="normal_test3", shard=1, num_shards=1, time=5),
|
||||
ShardedTest(test="normal_test2", shard=1, num_shards=1, time=7),
|
||||
ShardedTest(test="normal_test3", shard=1, num_shards=1, time=5),
|
||||
],
|
||||
),
|
||||
]
|
||||
self.assert_shards_equal(
|
||||
expected_shards, calculate_shards(5, self.tests, self.test_times)
|
||||
expected_shards,
|
||||
calculate_shards(5, self.tests, self.test_times, self.test_class_times),
|
||||
)
|
||||
|
||||
def test_calculate_2_shards_with_incomplete_test_times(self) -> None:
|
||||
|
|
@ -143,29 +173,35 @@ class TestCalculateShards(unittest.TestCase):
|
|||
(
|
||||
22.0,
|
||||
[
|
||||
ShardedTest(name="long_test1", shard=1, num_shards=1, time=22),
|
||||
ShardedTest(name="long_test2", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(name="normal_test3", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(name="short_test3", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(name="short_test5", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(test="long_test1", shard=1, num_shards=1, time=22),
|
||||
ShardedTest(test="long_test2", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(test="normal_test3", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(test="short_test3", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(test="short_test5", shard=1, num_shards=1, time=None),
|
||||
],
|
||||
),
|
||||
(
|
||||
10.0,
|
||||
[
|
||||
ShardedTest(name="normal_test1", shard=1, num_shards=1, time=9),
|
||||
ShardedTest(name="short_test1", shard=1, num_shards=1, time=1),
|
||||
ShardedTest(test="normal_test1", shard=1, num_shards=1, time=9),
|
||||
ShardedTest(test="short_test1", shard=1, num_shards=1, time=1),
|
||||
ShardedTest(
|
||||
name="super_long_test", shard=1, num_shards=1, time=None
|
||||
test="super_long_test", shard=1, num_shards=1, time=None
|
||||
),
|
||||
ShardedTest(name="normal_test2", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(name="short_test2", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(name="short_test4", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(test="normal_test2", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(test="short_test2", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(test="short_test4", shard=1, num_shards=1, time=None),
|
||||
],
|
||||
),
|
||||
]
|
||||
self.assert_shards_equal(
|
||||
expected_shards, calculate_shards(2, self.tests, incomplete_test_times)
|
||||
expected_shards,
|
||||
calculate_shards(
|
||||
2,
|
||||
self.tests,
|
||||
incomplete_test_times,
|
||||
gen_class_times(incomplete_test_times),
|
||||
),
|
||||
)
|
||||
|
||||
def test_calculate_5_shards_with_incomplete_test_times(self) -> None:
|
||||
|
|
@ -176,54 +212,66 @@ class TestCalculateShards(unittest.TestCase):
|
|||
(
|
||||
22.0,
|
||||
[
|
||||
ShardedTest(name="long_test1", shard=1, num_shards=1, time=22),
|
||||
ShardedTest(name="normal_test2", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(name="short_test5", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(test="long_test1", shard=1, num_shards=1, time=22),
|
||||
ShardedTest(test="normal_test2", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(test="short_test5", shard=1, num_shards=1, time=None),
|
||||
],
|
||||
),
|
||||
(
|
||||
9.0,
|
||||
[
|
||||
ShardedTest(name="normal_test1", shard=1, num_shards=1, time=9),
|
||||
ShardedTest(name="normal_test3", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(test="normal_test1", shard=1, num_shards=1, time=9),
|
||||
ShardedTest(test="normal_test3", shard=1, num_shards=1, time=None),
|
||||
],
|
||||
),
|
||||
(
|
||||
1.0,
|
||||
[
|
||||
ShardedTest(name="short_test1", shard=1, num_shards=1, time=1),
|
||||
ShardedTest(name="short_test2", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(test="short_test1", shard=1, num_shards=1, time=1),
|
||||
ShardedTest(test="short_test2", shard=1, num_shards=1, time=None),
|
||||
],
|
||||
),
|
||||
(
|
||||
0.0,
|
||||
[
|
||||
ShardedTest(
|
||||
name="super_long_test", shard=1, num_shards=1, time=None
|
||||
test="super_long_test", shard=1, num_shards=1, time=None
|
||||
),
|
||||
ShardedTest(name="short_test3", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(test="short_test3", shard=1, num_shards=1, time=None),
|
||||
],
|
||||
),
|
||||
(
|
||||
0.0,
|
||||
[
|
||||
ShardedTest(name="long_test2", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(name="short_test4", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(test="long_test2", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(test="short_test4", shard=1, num_shards=1, time=None),
|
||||
],
|
||||
),
|
||||
]
|
||||
self.assert_shards_equal(
|
||||
expected_shards, calculate_shards(5, self.tests, incomplete_test_times)
|
||||
expected_shards,
|
||||
calculate_shards(
|
||||
5,
|
||||
self.tests,
|
||||
incomplete_test_times,
|
||||
gen_class_times(incomplete_test_times),
|
||||
),
|
||||
)
|
||||
|
||||
def test_split_shards(self) -> None:
|
||||
test_times: Dict[str, float] = {"test1": THRESHOLD, "test2": THRESHOLD}
|
||||
expected_shards = [
|
||||
(600.0, [ShardedTest(name="test1", shard=1, num_shards=1, time=THRESHOLD)]),
|
||||
(600.0, [ShardedTest(name="test2", shard=1, num_shards=1, time=THRESHOLD)]),
|
||||
(600.0, [ShardedTest(test="test1", shard=1, num_shards=1, time=THRESHOLD)]),
|
||||
(600.0, [ShardedTest(test="test2", shard=1, num_shards=1, time=THRESHOLD)]),
|
||||
]
|
||||
self.assert_shards_equal(
|
||||
expected_shards, calculate_shards(2, list(test_times.keys()), test_times)
|
||||
expected_shards,
|
||||
calculate_shards(
|
||||
2,
|
||||
[TestRun(t) for t in test_times.keys()],
|
||||
test_times,
|
||||
gen_class_times(test_times),
|
||||
),
|
||||
)
|
||||
|
||||
test_times = {"test1": THRESHOLD * 4, "test2": THRESHOLD * 2.5}
|
||||
|
|
@ -231,35 +279,47 @@ class TestCalculateShards(unittest.TestCase):
|
|||
(
|
||||
2200.0,
|
||||
[
|
||||
ShardedTest(name="test1", shard=1, num_shards=4, time=600.0),
|
||||
ShardedTest(name="test1", shard=3, num_shards=4, time=600.0),
|
||||
ShardedTest(name="test2", shard=1, num_shards=3, time=500.0),
|
||||
ShardedTest(name="test2", shard=3, num_shards=3, time=500.0),
|
||||
ShardedTest(test="test1", shard=1, num_shards=4, time=600.0),
|
||||
ShardedTest(test="test1", shard=3, num_shards=4, time=600.0),
|
||||
ShardedTest(test="test2", shard=1, num_shards=3, time=500.0),
|
||||
ShardedTest(test="test2", shard=3, num_shards=3, time=500.0),
|
||||
],
|
||||
),
|
||||
(
|
||||
1700.0,
|
||||
[
|
||||
ShardedTest(name="test1", shard=2, num_shards=4, time=600.0),
|
||||
ShardedTest(name="test1", shard=4, num_shards=4, time=600.0),
|
||||
ShardedTest(name="test2", shard=2, num_shards=3, time=500.0),
|
||||
ShardedTest(test="test1", shard=2, num_shards=4, time=600.0),
|
||||
ShardedTest(test="test1", shard=4, num_shards=4, time=600.0),
|
||||
ShardedTest(test="test2", shard=2, num_shards=3, time=500.0),
|
||||
],
|
||||
),
|
||||
]
|
||||
self.assert_shards_equal(
|
||||
expected_shards, calculate_shards(2, list(test_times.keys()), test_times)
|
||||
expected_shards,
|
||||
calculate_shards(
|
||||
2,
|
||||
[TestRun(t) for t in test_times.keys()],
|
||||
test_times,
|
||||
gen_class_times(test_times),
|
||||
),
|
||||
)
|
||||
|
||||
test_times = {"test1": THRESHOLD / 2, "test2": THRESHOLD}
|
||||
expected_shards = [
|
||||
(600.0, [ShardedTest(name="test2", shard=1, num_shards=1, time=THRESHOLD)]),
|
||||
(600.0, [ShardedTest(test="test2", shard=1, num_shards=1, time=THRESHOLD)]),
|
||||
(
|
||||
300.0,
|
||||
[ShardedTest(name="test1", shard=1, num_shards=1, time=THRESHOLD / 2)],
|
||||
[ShardedTest(test="test1", shard=1, num_shards=1, time=THRESHOLD / 2)],
|
||||
),
|
||||
]
|
||||
self.assert_shards_equal(
|
||||
expected_shards, calculate_shards(2, list(test_times.keys()), test_times)
|
||||
expected_shards,
|
||||
calculate_shards(
|
||||
2,
|
||||
[TestRun(t) for t in test_times.keys()],
|
||||
test_times,
|
||||
gen_class_times(test_times),
|
||||
),
|
||||
)
|
||||
|
||||
def test_split_shards_random(self) -> None:
|
||||
|
|
@ -272,7 +332,10 @@ class TestCalculateShards(unittest.TestCase):
|
|||
}
|
||||
|
||||
shards = calculate_shards(
|
||||
num_shards, list(random_times.keys()), random_times
|
||||
num_shards,
|
||||
[TestRun(t) for t in random_times.keys()],
|
||||
random_times,
|
||||
gen_class_times(random_times),
|
||||
)
|
||||
|
||||
times = [x[0] for x in shards]
|
||||
|
|
@ -300,7 +363,7 @@ class TestCalculateShards(unittest.TestCase):
|
|||
def test_calculate_2_shards_against_optimal_shards(self) -> None:
|
||||
random.seed(120)
|
||||
for _ in range(100):
|
||||
random_times = {k: random.random() * 10 for k in self.tests}
|
||||
random_times = {k.test_file: random.random() * 10 for k in self.tests}
|
||||
# all test times except first two
|
||||
rest_of_tests = [
|
||||
i
|
||||
|
|
@ -315,12 +378,14 @@ class TestCalculateShards(unittest.TestCase):
|
|||
# (sum_of_rest, ['super_long_test', 'long_test1']),
|
||||
# (sum_of_rest, [i for i in self.tests if i != 'super_long_test' and i != 'long_test1']),
|
||||
# ]
|
||||
calculated_shards = calculate_shards(2, self.tests, random_times)
|
||||
calculated_shards = calculate_shards(
|
||||
2, self.tests, random_times, gen_class_times(random_times)
|
||||
)
|
||||
max_shard_time = max(calculated_shards[0][0], calculated_shards[1][0])
|
||||
if sum_of_rest != 0:
|
||||
# The calculated shard should not have a ratio worse than 7/6 for num_shards = 2
|
||||
self.assertGreaterEqual(7.0 / 6.0, max_shard_time / sum_of_rest)
|
||||
sorted_tests = sorted(self.tests)
|
||||
sorted_tests = sorted([t.test_file for t in self.tests])
|
||||
sorted_shard_tests = sorted(
|
||||
calculated_shards[0][1] + calculated_shards[1][1]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,20 +1,53 @@
|
|||
import sys
|
||||
from abc import abstractmethod
|
||||
from copy import copy
|
||||
from enum import Enum
|
||||
from functools import total_ordering
|
||||
from itertools import chain
|
||||
from typing import Any, Dict, FrozenSet, Iterable, List, Optional, Tuple
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
FrozenSet,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
StringTuple = Tuple[str, ...]
|
||||
from tools.testing.test_run import TestRun, TestRuns
|
||||
|
||||
|
||||
# Note: Keep the implementation of Relevance private to this file so
|
||||
# that it's easy to change in the future as we discover what's needed
|
||||
@total_ordering
|
||||
class Relevance(Enum):
|
||||
HIGH = 0
|
||||
PROBABLE = 1
|
||||
HIGH = 4
|
||||
PROBABLE = 3
|
||||
UNRANKED = 2
|
||||
UNLIKELY = 3 # Not yet supported. Needs more infra to be usable
|
||||
NONE = 4 # Not yet supported. Needs more infra to be usable
|
||||
UNLIKELY = 1 # Not yet supported. Needs more infra to be usable
|
||||
NONE = 0 # Not yet supported. Needs more infra to be usable
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, Relevance):
|
||||
return False
|
||||
|
||||
return self.value == other.value
|
||||
|
||||
def __lt__(self, other: object) -> bool:
|
||||
if not isinstance(other, Relevance):
|
||||
raise NotImplementedError(f"Can't compare {self} to {other}")
|
||||
|
||||
return self.value < other.value
|
||||
|
||||
@staticmethod
|
||||
def priority_traversal() -> Iterator["Relevance"]:
|
||||
yield Relevance.HIGH
|
||||
yield Relevance.PROBABLE
|
||||
yield Relevance.UNRANKED
|
||||
yield Relevance.UNLIKELY
|
||||
yield Relevance.NONE
|
||||
|
||||
|
||||
METRIC_RELEVANCE_GROUP = "relevance_group"
|
||||
|
|
@ -37,7 +70,7 @@ class TestPrioritizations:
|
|||
otherwise it breaks the test sharding logic
|
||||
"""
|
||||
|
||||
_test_priorities: List[StringTuple] # This list MUST be ordered by Relevance
|
||||
_test_priorities: List[List[TestRun]] # This list MUST be ordered by Relevance
|
||||
_original_tests: FrozenSet[str]
|
||||
|
||||
def __init__(
|
||||
|
|
@ -51,106 +84,168 @@ class TestPrioritizations:
|
|||
) -> None:
|
||||
self._original_tests = frozenset(tests_being_ranked)
|
||||
|
||||
self._test_priorities = [tuple() for _ in range(5)]
|
||||
self._test_priorities = [[] for _ in range(5)]
|
||||
# Setup the initial priorities
|
||||
self._test_priorities[Relevance.UNRANKED.value] = [
|
||||
TestRun(test) for test in tests_being_ranked
|
||||
]
|
||||
|
||||
self._test_priorities[Relevance.HIGH.value] = self.filter_out_extra_tests(
|
||||
high_relevance
|
||||
)
|
||||
self._test_priorities[Relevance.PROBABLE.value] = self.filter_out_extra_tests(
|
||||
probable_relevance
|
||||
)
|
||||
self._test_priorities[Relevance.UNRANKED.value] = self.filter_out_extra_tests(
|
||||
unranked_relevance
|
||||
)
|
||||
self._test_priorities[Relevance.UNLIKELY.value] = self.filter_out_extra_tests(
|
||||
unlikely_relevance
|
||||
)
|
||||
self._test_priorities[Relevance.NONE.value] = self.filter_out_extra_tests(
|
||||
no_relevance
|
||||
)
|
||||
|
||||
# If any of the original tests were missed from the other lists, add them to the unranked_relevance list
|
||||
missing_tests = sorted(self._original_tests - set(self.get_all_tests()))
|
||||
self._test_priorities[Relevance.UNRANKED.value] = self._test_priorities[
|
||||
Relevance.UNRANKED.value
|
||||
] + tuple(missing_tests)
|
||||
for test in high_relevance or []:
|
||||
self.set_test_relevance(TestRun(test), Relevance.HIGH)
|
||||
for test in probable_relevance or []:
|
||||
self.set_test_relevance(TestRun(test), Relevance.PROBABLE)
|
||||
for test in unranked_relevance or []:
|
||||
self.set_test_relevance(TestRun(test), Relevance.UNRANKED)
|
||||
for test in unlikely_relevance or []:
|
||||
self.set_test_relevance(TestRun(test), Relevance.UNLIKELY)
|
||||
for test in no_relevance or []:
|
||||
self.set_test_relevance(TestRun(test), Relevance.NONE)
|
||||
|
||||
self.validate_test_priorities()
|
||||
|
||||
def filter_out_extra_tests(
|
||||
self, relevance_group: Optional[List[str]]
|
||||
) -> StringTuple:
|
||||
if not relevance_group:
|
||||
return tuple()
|
||||
return tuple(filter(lambda test: test in self._original_tests, relevance_group))
|
||||
def _traverse_priorities(self) -> Iterator[Tuple[Relevance, List[TestRun]]]:
|
||||
for relevance in Relevance.priority_traversal():
|
||||
yield (relevance, self._test_priorities[relevance.value])
|
||||
|
||||
def get_pointer_to_test(self, test_run: TestRun) -> Iterator[Tuple[Relevance, int]]:
|
||||
"""
|
||||
Returns all test runs that contain any subset of the given test_run and their current relevance.
|
||||
|
||||
self._test_priorities should NOT have items added or removed form it while iterating over the
|
||||
results of this function.
|
||||
"""
|
||||
# Find a test run that contains the given TestRun and it's relevance.
|
||||
found_match = False
|
||||
for relevance, tests in self._traverse_priorities():
|
||||
for idx, existing_test_run in enumerate(tests):
|
||||
# Does the existing test run contain any of the test we're looking for?
|
||||
shared_test = existing_test_run & test_run
|
||||
if not shared_test.is_empty():
|
||||
found_match = True
|
||||
yield (Relevance(relevance), idx)
|
||||
|
||||
if not found_match:
|
||||
raise ValueError(f"Test {test_run} not found in any relevance group")
|
||||
|
||||
def _update_test_relevance(
|
||||
self,
|
||||
test_run: TestRun,
|
||||
new_relevance: Relevance,
|
||||
acceptable_relevance_fn: Callable[[Relevance, Relevance], bool],
|
||||
) -> None:
|
||||
"""
|
||||
Updates the test run's relevance to the new relevance.
|
||||
|
||||
If the tests in the test run were previously split up into multiple test runs, all the chunks at a lower
|
||||
relevance will be merged into one new test run at the new relevance, appended to the end of the relevance group.
|
||||
|
||||
However, any tests in a test run that are already at the desired relevance will be left alone, keeping it's
|
||||
original place in the relevance group.
|
||||
"""
|
||||
if test_run.test_file not in self._original_tests:
|
||||
return # We don't need this test
|
||||
|
||||
# The tests covered by test_run could potentially have been split up into
|
||||
# multiple test runs, each at a different relevance. Let's make sure to bring
|
||||
# all of them up to the minimum relevance
|
||||
upgraded_tests = TestRun.empty()
|
||||
tests_to_remove = []
|
||||
for curr_relevance, test_run_idx in self.get_pointer_to_test(test_run):
|
||||
if acceptable_relevance_fn(curr_relevance, new_relevance):
|
||||
# This test is already at the desired relevance
|
||||
continue # no changes needed
|
||||
|
||||
test_run_to_rerank = self._test_priorities[curr_relevance.value][
|
||||
test_run_idx
|
||||
]
|
||||
# Remove the requested tests from their current relevance group, to be added to the new one
|
||||
remaining_tests = test_run_to_rerank - test_run
|
||||
upgraded_tests |= test_run_to_rerank & test_run
|
||||
|
||||
# Remove the tests that are being upgraded
|
||||
if remaining_tests:
|
||||
self._test_priorities[curr_relevance.value][
|
||||
test_run_idx
|
||||
] = remaining_tests
|
||||
else:
|
||||
# List traversal prevents us from deleting these immediately, so note them for later
|
||||
tests_to_remove.append((curr_relevance, test_run_idx))
|
||||
|
||||
for relevance, test_idx in tests_to_remove:
|
||||
del self._test_priorities[relevance.value][test_idx]
|
||||
|
||||
# And add them to the desired relevance group
|
||||
if upgraded_tests:
|
||||
self._test_priorities[new_relevance.value].append(upgraded_tests)
|
||||
|
||||
def set_test_relevance(self, test_run: TestRun, new_relevance: Relevance) -> None:
|
||||
return self._update_test_relevance(
|
||||
test_run, new_relevance, lambda curr, new: curr == new
|
||||
)
|
||||
|
||||
def raise_test_relevance(self, test_run: TestRun, new_relevance: Relevance) -> None:
|
||||
return self._update_test_relevance(
|
||||
test_run, new_relevance, lambda curr, new: curr >= new
|
||||
)
|
||||
|
||||
def validate_test_priorities(self) -> None:
|
||||
# Union all TestRuns that contain include/exclude pairs
|
||||
all_tests = self.get_all_tests()
|
||||
files = {}
|
||||
for test in all_tests:
|
||||
if test.test_file not in files:
|
||||
files[test.test_file] = copy(test)
|
||||
else:
|
||||
files[test.test_file] |= test
|
||||
|
||||
for test in files.values():
|
||||
assert (
|
||||
test.is_full_file()
|
||||
), f"All includes should have been excluded elsewhere, and vice versa. Test run `{test}` violates that"
|
||||
|
||||
# Ensure that the set of tests in the TestPrioritizations is identical to the set of tests passed in
|
||||
assert self._original_tests == set(
|
||||
self.get_all_tests()
|
||||
files.keys()
|
||||
), "The set of tests in the TestPrioritizations must be identical to the set of tests passed in"
|
||||
|
||||
@staticmethod
|
||||
def _merge_tests(
|
||||
current_tests: Iterable[str],
|
||||
new_tests: Iterable[str],
|
||||
higher_pri_tests: Iterable[str],
|
||||
) -> StringTuple:
|
||||
"""
|
||||
We append all new tests to the current tests, while preserving the sorting on the new_tests
|
||||
However, exclude any specified tests which have now moved to a higher priority list or tests
|
||||
that weren't originally in the self's TestPrioritizations
|
||||
"""
|
||||
merged_tests = [
|
||||
test
|
||||
for test in chain(current_tests, new_tests)
|
||||
if test not in higher_pri_tests
|
||||
] # skip the excluded tests
|
||||
return tuple(dict.fromkeys(merged_tests)) # remove dupes while preseving order
|
||||
|
||||
    def integrate_priorities(self, other: "TestPrioritizations") -> None:
        """
        Integrates priorities from another TestPrioritizations object.

        The final result takes all tests from `self` and rearranges them based on priorities from `other`.
        If there are tests mentioned in `other` which are not in `self`, those tests are ignored.
        (For example, that can happen if a heuristic reports tests that are not run in the current job.)
        Currently it will only raise the priority of a test, never lower it.
        """
        assert (
            self._original_tests == other._original_tests
        ), "Both tests should stem from the same original test list"

        higher_pri_tests: List[str] = []
        for relevance, _ in enumerate(self._test_priorities):
            self._test_priorities[relevance] = TestPrioritizations._merge_tests(
                current_tests=self._test_priorities[relevance],
                new_tests=other._test_priorities[relevance],
                higher_pri_tests=higher_pri_tests,
            )

            # Don't let the tests we just added to the current relevance group be added to a lower relevance group
            higher_pri_tests.extend(self._test_priorities[relevance])
        for relevance, _ in other._traverse_priorities():
            if relevance > Relevance.UNRANKED:
                for test in other._test_priorities[relevance.value]:
                    self.raise_test_relevance(test, relevance)
        # TODO: Handle the case where a test is moved to a lower relevance group (once we support that scenario)

        self.validate_test_priorities()
        return

    def get_all_tests(self) -> StringTuple:
    def get_all_tests(self) -> TestRuns:
        """Returns all tests in the TestPrioritizations"""
        return tuple(test for test in chain(*self._test_priorities))
        return tuple(chain(*self._test_priorities))

    def get_prioritized_tests(self) -> StringTuple:
    def get_prioritized_tests(self) -> TestRuns:
        return self.get_high_relevance_tests() + self.get_probable_relevance_tests()

    def get_high_relevance_tests(self) -> StringTuple:
    def get_high_relevance_tests(self) -> TestRuns:
        return tuple(test for test in self._test_priorities[Relevance.HIGH.value])

    def get_probable_relevance_tests(self) -> StringTuple:
    def get_probable_relevance_tests(self) -> TestRuns:
        return tuple(test for test in self._test_priorities[Relevance.PROBABLE.value])

    def get_unranked_relevance_tests(self) -> StringTuple:
    def get_unranked_relevance_tests(self) -> TestRuns:
        return tuple(test for test in self._test_priorities[Relevance.UNRANKED.value])

    def print_info(self) -> None:
        def _print_tests(label: str, tests: StringTuple) -> None:
        def _print_tests(label: str, tests: List[TestRun]) -> None:
            if not tests:
                return


@@ -159,48 +254,61 @@ class TestPrioritizations:
                if test in tests:
                    print(f" {test}")

        for relevance_group, tests in enumerate(self._test_priorities):
        for relevance_group, tests in self._traverse_priorities():
            _print_tests(f"{Relevance(relevance_group).name.title()} Relevance", tests)

    def _get_test_relevance_group(self, test_name: str) -> Relevance:
        """Returns the priority of a test."""
        for relevance_group, tests in enumerate(self._test_priorities):
            if test_name in tests:
    def _get_test_relevance_group(self, test_run: TestRun) -> Relevance:
        """Returns the rank of the given test run."""
        for relevance_group, tests in self._traverse_priorities():
            if any(t.contains(test_run) for t in tests):
                return Relevance(relevance_group)

        raise ValueError(f"Test {test_name} not found in any relevance group")
        print("holup, retry")
        for relevance_group, tests in self._traverse_priorities():
            if any(
                t.contains(test_run) for t in tests
            ):  # t could be the entire test_run or a superset
                return Relevance(relevance_group)

    def _get_test_order(self, test_name: str) -> int:
        """Returns the rank of the test specified by this heuristic."""
        raise ValueError(f"Test {test_run} not found in any relevance group")

    def _get_test_order(self, test_run: TestRun) -> int:
        """Returns the rank this heuristic suggested for the test run."""
        base_rank = 0

        for relevance_group_tests in self._test_priorities:
            if test_name in relevance_group_tests:
                return base_rank + relevance_group_tests.index(test_name)
        for _, relevance_group_tests in self._traverse_priorities():
            for idx, test in enumerate(relevance_group_tests):
                if test.contains(
                    test_run
                ):  # test could be the entire test_run or a superset
                    return base_rank + idx
            base_rank += len(relevance_group_tests)

        raise ValueError(f"Test {test_name} not found in any relevance group")
        raise ValueError(f"Test {test_run} not found in any relevance group")

    def _get_test_order_within_relevance_group(self, test_name: str) -> int:
        for relevance_group_tests in self._test_priorities:
            if test_name not in relevance_group_tests:
                continue
    def _get_test_order_within_relevance_group(self, test_run: TestRun) -> int:
        """Returns the highest test order of any test class within the same relevance group."""
        for _, relevance_group_tests in self._traverse_priorities():
            for idx, test in enumerate(relevance_group_tests):
                if test.contains(
                    test_run
                ):  # test could be the entire test_run or a superset
                    return idx

            return relevance_group_tests.index(test_name)
        raise ValueError(f"Test {test_run} not found in any relevance group")

        raise ValueError(f"Test {test_name} not found in any relevance group")

    def get_priority_info_for_test(self, test_name: str) -> Dict[str, Any]:
    def get_priority_info_for_test(self, test_run: TestRun) -> Dict[str, Any]:
        """Given a failing test, returns information about its prioritization that we want to emit in our metrics."""
        relevance = self._get_test_relevance_group(test_run)
        return {
            METRIC_RELEVANCE_GROUP: self._get_test_relevance_group(test_name).name,
            METRIC_RELEVANCE_GROUP: relevance.name,
            METRIC_ORDER_WITHIN_RELEVANCE_GROUP: self._get_test_order_within_relevance_group(
                test_name
                test_run
            ),
            METRIC_NUM_TESTS_IN_RELEVANCE_GROUP: len(
                self._test_priorities[self._get_test_relevance_group(test_name).value]
                self._test_priorities[relevance.value]
            ),
            METRIC_ORDER_OVERALL: self._get_test_order(test_name),
            METRIC_ORDER_OVERALL: self._get_test_order(test_run),
        }


@@ -247,12 +355,13 @@ class AggregatedHeuristics:
        return aggregated_priorities

    def get_test_stats(self, test: str) -> Dict[str, Any]:
    def get_test_stats(self, test: TestRun) -> Dict[str, Any]:
        """
        Returns the aggregated statistics for a given test.
        """
        stats: Dict[str, Any] = {
            "test_name": test,
            "test_name": test.test_file,
            "test_filters": test.get_pytest_filter(),
        }

        # Get baseline metrics assuming we didn't have any TD heuristics

305 tools/testing/test_run.py Normal file

@@ -0,0 +1,305 @@
from copy import copy
from functools import total_ordering
from typing import Iterable, List, Optional, Set, Tuple, Union


class TestRun:
    """
    TestRun defines the set of tests that should be run together in a single pytest invocation.
    It'll either be a whole test file or a subset of a test file.

    This class assumes that we won't always know the full set of TestClasses in a test file.
    So it's designed to include or exclude explicitly requested TestClasses, while accepting
    that there will be an ambiguous set of "unknown" test classes that are not explicitly called out.
    Those manifest as tests that haven't been explicitly excluded.
    """

    test_file: str
    _excluded: Set[str]  # Tests that should be excluded from this test run
    _included: Set[str]  # If non-empty, only these tests should be run in this test run

    def __init__(
        self,
        name: str,
        excluded: Optional[Iterable[str]] = None,
        included: Optional[Iterable[str]] = None,
    ) -> None:
        self._excluded = set()
        self._included = set()

        if excluded and included:
            raise ValueError("Can't specify both included and excluded")

        if "::" in name:
            assert (
                not included and not excluded
            ), "Can't specify included or excluded tests when specifying a test class in the file name"
            self.test_file, test_class = name.split("::")
            self._included.add(test_class)
        else:
            self.test_file = name

        # For testing purposes
        if excluded:
            self._excluded = set(excluded)
        if included:
            self._included = set(included)

    @staticmethod
    def empty() -> "TestRun":
        return TestRun("")

    def is_empty(self) -> bool:
        # Lack of a test_file means that this is an empty run,
        # which means there is nothing to run. It's the zero.
        return not self.test_file

    def is_full_file(self) -> bool:
        return not self._included and not self._excluded

    def included(self) -> Set[str]:
        return self._included.copy()

    def excluded(self) -> Set[str]:
        return self._excluded.copy()

    def get_pytest_filter(self) -> str:
        if self._included:
            return " or ".join(sorted(self._included))
        elif self._excluded:
            return f"not ({' and '.join(sorted(self._excluded))})"
        else:
            return ""

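As a quick sanity check of the filter strings built above, a small usage sketch; it assumes being run from the PyTorch repo root so that `tools.testing.test_run` is importable, and the file and class names are placeholders.

from tools.testing.test_run import TestRun

# A "file::class" name is stored as an inclusion and becomes a plain -k expression.
print(TestRun("test_ops::TestCommon").get_pytest_filter())  # TestCommon

# Exclusions are wrapped in a negated expression.
print(TestRun("test_ops", excluded={"TestCommon"}).get_pytest_filter())  # not (TestCommon)

# A whole-file run needs no filter at all.
print(TestRun("test_ops").get_pytest_filter())  # prints an empty string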
    def contains(self, test: "TestRun") -> bool:
        if self.test_file != test.test_file:
            return False

        if self.is_full_file():
            return True  # self contains all tests

        if test.is_full_file():
            return False  # test contains all tests, but self doesn't

        # Does self exclude a subset of what test excludes?
        if test._excluded:
            return test._excluded.issubset(self._excluded)

        # Does self include everything test includes?
        if self._included:
            return test._included.issubset(self._included)

        # Getting to here means that test includes and self excludes
        # Does self exclude anything test includes? If not, we're good
        return not self._excluded.intersection(test._included)

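The containment rules above are easiest to see with a few toy runs. A short sketch, again assuming the module is importable from the repo root and using placeholder names:

from tools.testing.test_run import TestRun

whole_file = TestRun("test_ops")
one_class = TestRun("test_ops::TestCommon")
all_but_one = TestRun("test_ops", excluded={"TestCommon"})

print(whole_file.contains(one_class))   # True  - the full file covers any subset
print(one_class.contains(whole_file))   # False - a single class can't cover the whole file
print(all_but_one.contains(one_class))  # False - TestCommon is explicitly excluded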
    def __copy__(self) -> "TestRun":
        return TestRun(self.test_file, excluded=self._excluded, included=self._included)

    def __bool__(self) -> bool:
        return not self.is_empty()

    def __repr__(self) -> str:
        r: str = f"RunTest({self.test_file}"
        r += f", included: {self._included}" if self._included else ""
        r += f", excluded: {self._excluded}" if self._excluded else ""
        r += ")"
        return r

    def __str__(self) -> str:
        if self.is_empty():
            return "Empty"

        pytest_filter = self.get_pytest_filter()
        if pytest_filter:
            return self.test_file + ", " + pytest_filter
        return self.test_file

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, TestRun):
            return False

        ret = self.test_file == other.test_file
        ret = ret and self._included == other._included
        ret = ret and self._excluded == other._excluded
        return ret

    def __ior__(  # noqa: PYI034 Method returns `self`
        self, other: "TestRun"
    ) -> "TestRun":
        res = self | other
        self.test_file = res.test_file
        self._included = res._included
        self._excluded = res._excluded

        return self

    def __or__(self, other: "TestRun") -> "TestRun":
        """
        To OR/Union test runs means to run all the tests that either of the two runs specify.
        """

        # Is any file empty?
        if self.is_empty():
            return other
        if other.is_empty():
            return copy(self)

        # If not, ensure we have the same file
        assert (
            self.test_file == other.test_file
        ), f"Can't merge {self} and {other} because they're not the same test file"

        # 4 possible cases:

        # 1. Either file is the full file, so union is everything
        if self.is_full_file() or other.is_full_file():
            # The union is the whole file
            return TestRun(self.test_file)

        # 2. Both files only run what's in _included, so union is the union of the two sets
        if self._included and other._included:
            return TestRun(
                self.test_file, included=self._included.union(other._included)
            )

        # 3. Both files only exclude what's in _excluded, so union is the intersection of the two sets
        if self._excluded and other._excluded:
            return TestRun(
                self.test_file, excluded=self._excluded.intersection(other._excluded)
            )

        # 4. One file includes and the other excludes, so we then continue excluding the _excluded set minus
        # whatever is in the _included set
        included = self._included | other._included
        excluded = self._excluded | other._excluded
        return TestRun(self.test_file, excluded=excluded - included)

    def __isub__(  # noqa: PYI034 Method returns `self`
        self, other: "TestRun"
    ) -> "TestRun":
        res = self - other
        self.test_file = res.test_file
        self._included = res._included
        self._excluded = res._excluded
        return self

    def __sub__(self, other: "TestRun") -> "TestRun":
        """
        To subtract test runs means to run all the tests in the first run except for what the second run specifies.
        """

        # Is any file empty?
        if self.is_empty():
            return TestRun.empty()
        if other.is_empty():
            return copy(self)

        # Are you trying to subtract tests that don't even exist in this test run?
        if self.test_file != other.test_file:
            return copy(self)

        # You're subtracting everything?
        if other.is_full_file():
            return TestRun.empty()

        def return_inclusions_or_empty(inclusions: Set[str]) -> TestRun:
            if inclusions:
                return TestRun(self.test_file, included=inclusions)
            return TestRun.empty()

        if other._included:
            if self._included:
                return return_inclusions_or_empty(self._included - other._included)
            else:
                return TestRun(
                    self.test_file, excluded=self._excluded | other._included
                )
        else:
            if self._included:
                return return_inclusions_or_empty(self._included & other._excluded)
            else:
                return return_inclusions_or_empty(other._excluded - self._excluded)

    def __and__(self, other: "TestRun") -> "TestRun":
        if self.test_file != other.test_file:
            return TestRun.empty()

        return (self | other) - (self - other) - (other - self)

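Taken together, `__or__`, `__sub__`, and `__and__` give TestRun a simple set-like algebra over test classes. A brief sketch with placeholder names, assuming the module is importable from the repo root:

from tools.testing.test_run import TestRun

a = TestRun("test_ops", included={"TestCommon", "TestMathBits"})
b = TestRun("test_ops::TestCommon")

print(a | b)  # test_ops, TestCommon or TestMathBits
print(a - b)  # test_ops, TestMathBits
print(a & b)  # test_ops, TestCommon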
TestRuns = Tuple[TestRun, ...]


@total_ordering
class ShardedTest:
    test: TestRun
    shard: int
    num_shards: int
    time: Optional[float]  # In seconds

    def __init__(
        self,
        test: Union[TestRun, str],
        shard: int,
        num_shards: int,
        time: Optional[float] = None,
    ) -> None:
        if isinstance(test, str):
            test = TestRun(test)
        self.test = test
        self.shard = shard
        self.num_shards = num_shards
        self.time = time

    @property
    def name(self) -> str:
        return self.test.test_file

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ShardedTest):
            return False
        return (
            self.test == other.test
            and self.shard == other.shard
            and self.num_shards == other.num_shards
            and self.time == other.time
        )

    def __repr__(self) -> str:
        ret = f"{self.test} {self.shard}/{self.num_shards}"
        if self.time:
            ret += f" ({self.time}s)"

        return ret

    def __lt__(self, other: object) -> bool:
        if not isinstance(other, ShardedTest):
            raise NotImplementedError

        # This is how the list was implicitly sorted when it was a NamedTuple
        if self.name != other.name:
            return self.name < other.name
        if self.shard != other.shard:
            return self.shard < other.shard
        if self.num_shards != other.num_shards:
            return self.num_shards < other.num_shards

        # None is the smallest value
        if self.time is None:
            return True
        if other.time is None:
            return False
        return self.time < other.time

    def __str__(self) -> str:
        return f"{self.test} {self.shard}/{self.num_shards}"

    def get_time(self) -> float:
        return self.time or 0

    def get_pytest_args(self) -> List[str]:
        filter = self.test.get_pytest_filter()
        if filter:
            return ["-k", self.test.get_pytest_filter()]
        return []
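To show how a prioritized test class eventually reaches pytest, here is a hedged sketch of `ShardedTest` producing the `-k` arguments that the test runner forwards; the file name, shard counts, and timing value are made up for illustration.

from tools.testing.test_run import ShardedTest, TestRun

shard = ShardedTest(TestRun("test_ops::TestCommon"), shard=1, num_shards=1, time=420.0)

print(shard.name)               # test_ops  (the file the runner invokes)
print(shard.get_pytest_args())  # ['-k', 'TestCommon']
print(shard)                    # test_ops, TestCommon 1/1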
@@ -3,9 +3,10 @@ import os
import subprocess
from pathlib import Path

from typing import Callable, Dict, List, NamedTuple, Optional, Tuple
from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple

from tools.stats.import_test_stats import get_disabled_tests, get_slow_tests
from tools.testing.test_run import ShardedTest, TestRun

REPO_ROOT = Path(__file__).resolve().parent.parent.parent

@@ -42,19 +43,6 @@ if IS_ROCM and not IS_MEM_LEAK_CHECK:
    NUM_PROCS = 1


class ShardedTest(NamedTuple):
    name: str
    shard: int
    num_shards: int
    time: Optional[float]  # In seconds

    def __str__(self) -> str:
        return f"{self.name} {self.shard}/{self.num_shards}"

    def get_time(self) -> float:
        return self.time or 0


class ShardJob:
    def __init__(self) -> None:
        self.serial: List[ShardedTest] = []

@@ -73,11 +61,51 @@ class ShardJob:


def get_with_pytest_shard(
    tests: List[str], test_file_times: Dict[str, float]
    tests: Sequence[TestRun],
    test_file_times: Dict[str, float],
    test_class_times: Optional[Dict[str, Dict[str, float]]],
) -> List[ShardedTest]:
    sharded_tests: List[ShardedTest] = []

    def get_duration_for_classes(
        test_file: str, test_classes: Set[str]
    ) -> Optional[float]:
        duration: float = 0
        if not test_class_times:
            return None

        for test_class in test_classes:
            class_duration = test_class_times.get(test_file, {}).get(test_class, None)
            if class_duration is None:
                return None
            if class_duration:
                duration += class_duration
        return duration

    for test in tests:
        duration = test_file_times.get(test, None)
        file_duration = test_file_times.get(test.test_file, None)
        included = test.included()
        excluded = test.excluded()
        included_classes_duration = get_duration_for_classes(test.test_file, included)
        excluded_classes_duration = get_duration_for_classes(test.test_file, excluded)

        if included:
            # If we don't have the time for all included classes, our upper bound is the file duration
            duration = (
                included_classes_duration
                if included_classes_duration is not None
                else file_duration
            )
        elif excluded:
            # If we don't have the time for all excluded classes, our upper bound is the file duration
            duration = (
                file_duration - excluded_classes_duration
                if excluded_classes_duration is not None and file_duration is not None
                else file_duration
            )
        else:
            duration = file_duration

        if duration and duration > THRESHOLD:
            num_shards = math.ceil(duration / THRESHOLD)
            for i in range(num_shards):

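The duration fallback logic above can be summarized in a standalone sketch. This mirrors, but is not, the helper in the diff; the file name, class names, and timing numbers are invented.

from typing import Dict, Optional, Set

file_times: Dict[str, float] = {"test_ops": 600.0}
class_times: Dict[str, Dict[str, float]] = {"test_ops": {"TestCommon": 450.0}}

def estimate(test_file: str, included: Set[str], excluded: Set[str]) -> Optional[float]:
    file_duration = file_times.get(test_file)

    def classes_duration(classes: Set[str]) -> Optional[float]:
        total = 0.0
        for cls in classes:
            t = class_times.get(test_file, {}).get(cls)
            if t is None:
                return None  # missing data for any class -> caller falls back to file time
            total += t
        return total

    if included:
        inc = classes_duration(included)
        return inc if inc is not None else file_duration
    if excluded:
        known = classes_duration(excluded)
        if known is not None and file_duration is not None:
            return file_duration - known
        return file_duration
    return file_duration

print(estimate("test_ops", {"TestCommon"}, set()))   # 450.0 - sum of included class times
print(estimate("test_ops", set(), {"TestCommon"}))   # 150.0 - file time minus excluded classes
print(estimate("test_ops", {"TestUnknown"}, set()))  # 600.0 - unknown class, fall back to file time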
@@ -91,21 +119,27 @@ def get_with_pytest_shard(


def calculate_shards(
    num_shards: int,
    tests: List[str],
    tests: Sequence[TestRun],
    test_file_times: Dict[str, float],
    test_class_times: Optional[Dict[str, Dict[str, float]]],
    must_serial: Optional[Callable[[str], bool]] = None,
    sort_by_time: bool = True,
) -> List[Tuple[float, List[ShardedTest]]]:
    must_serial = must_serial or (lambda x: True)

    known_tests = tests
    unknown_tests = []
    known_tests: Sequence[TestRun] = tests
    unknown_tests: Sequence[TestRun] = []

    if sort_by_time:
        known_tests = [x for x in tests if x in test_file_times]
        known_tests = [
            x
            for x in tests
            if x.test_file in test_file_times
            or (test_class_times and x.test_file in test_class_times)
        ]
        unknown_tests = [x for x in tests if x not in known_tests]

    known_tests = get_with_pytest_shard(known_tests, test_file_times)
    known_tests = get_with_pytest_shard(known_tests, test_file_times, test_class_times)

    if sort_by_time:
        known_tests = sorted(known_tests, key=lambda j: j.get_time(), reverse=True)

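Finally, a hedged usage sketch of the updated `calculate_shards` signature. The test names and timings are invented, and the resulting split depends on the greedy packing code outside this hunk, so no particular output is implied.

from tools.testing.test_run import TestRun
from tools.testing.test_selections import calculate_shards

tests = [TestRun("test_ops::TestCommon"), TestRun("test_nn")]
file_times = {"test_ops": 600.0, "test_nn": 300.0}
class_times = {"test_ops": {"TestCommon": 450.0}}

shards = calculate_shards(
    num_shards=2,
    tests=tests,
    test_file_times=file_times,
    test_class_times=class_times,
    must_serial=lambda name: False,  # let everything run in the parallel pool
)
for total_time, sharded_tests in shards:
    print(round(total_time, 1), [str(t) for t in sharded_tests])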