[TD] Enable Test Class granularity on heuristics (#112161)

Changes the heuristic framework to support prioritizing individual test classes within a test file.

Components of this change include:
- Updating TestPrioritizations to accept individual test classes being prioritized. Previously, when a heuristic wanted to prioritize a test file it would pass in the test's name; now, to prioritize a class within a test file, it uses the notation "test::classname"
- Changes are fully backwards compatible with existing heuristics
- Test sharding now supports sharding individual test classes (for when they're prioritized)
- When a TestClass is prioritized, we pass the appropriate "-k" flags down to pytest (see the sketch after this list)
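
For illustration, here is a minimal sketch (not part of the diff; the file and class names are made up) of how the new notation flows through the TestRun and ShardedTest helpers added in tools/testing/test_run.py and ends up as a pytest "-k" filter:

from tools.testing.test_run import ShardedTest, TestRun

run = TestRun("test_foo::TestFooClass")      # the new "test::classname" notation
print(run.test_file)                         # "test_foo"
print(run.get_pytest_filter())               # "TestFooClass"

# Once wrapped for sharding, the class filter is forwarded to pytest via "-k"
sharded = ShardedTest(run, shard=1, num_shards=1)
print(sharded.get_pytest_args())             # ["-k", "TestFooClass"]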

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112161
Approved by: https://github.com/huydhn
Zain Rizvi, 2023-10-30 18:31:40 -05:00; committed by PyTorch MergeBot
commit a5641bc56b (parent 5cd1208415)
7 changed files with 1235 additions and 258 deletions


@ -15,7 +15,7 @@ import tempfile
import time
from contextlib import ExitStack
from datetime import datetime
from typing import Any, cast, Dict, List, NamedTuple, Optional, Tuple, Union
from typing import Any, cast, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
import pkg_resources
@ -40,12 +40,18 @@ REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent
# using tools/ to optimize test run.
sys.path.insert(0, str(REPO_ROOT))
from tools.stats.import_test_stats import ADDITIONAL_CI_FILES_FOLDER, TEST_TIMES_FILE
from tools.stats.import_test_stats import (
ADDITIONAL_CI_FILES_FOLDER,
TEST_CLASS_TIMES_FILE,
TEST_TIMES_FILE,
)
from tools.stats.upload_metrics import add_global_metric, emit_metric
from tools.testing.target_determination.determinator import (
AggregatedHeuristics,
get_test_prioritizations,
)
from tools.testing.test_run import TestRun
from tools.testing.test_selections import (
calculate_shards,
get_test_case_configs,
@ -479,7 +485,7 @@ def get_executable_command(options, disable_coverage=False, is_cpp_test=False):
def run_test(
test_module,
test_module: ShardedTest,
test_directory,
options,
launcher_cmd=None,
@ -488,14 +494,9 @@ def run_test(
) -> int:
maybe_set_hip_visible_devies()
unittest_args = options.additional_unittest_args.copy()
test_file = test_module
test_file = test_module.name
stepcurrent_key = test_file
use_sharded_test = False
if isinstance(test_file, ShardedTest):
test_file = test_module.name
use_sharded_test = True
is_distributed_test = test_file.startswith(DISTRIBUTED_TEST_PREFIX)
is_cpp_test = test_file.startswith(CPP_TEST_PREFIX)
# NB: Rerun disabled tests depends on pytest-flakefinder and it doesn't work with
@ -507,17 +508,16 @@ def run_test(
)
return 0
if use_sharded_test:
if is_cpp_test:
stepcurrent_key = test_file
else:
unittest_args.extend(
[
f"--shard-id={test_module.shard - 1}",
f"--num-shards={test_module.num_shards}",
]
)
stepcurrent_key = f"{test_file}_{test_module.shard - 1}"
if is_cpp_test:
stepcurrent_key = test_file
else:
unittest_args.extend(
[
f"--shard-id={test_module.shard - 1}",
f"--num-shards={test_module.num_shards}",
]
)
stepcurrent_key = f"{test_file}_{test_module.shard - 1}"
if options.verbose:
unittest_args.append(f'-{"v"*options.verbose}') # in case of pytest
@ -541,6 +541,7 @@ def run_test(
is_distributed_test=is_distributed_test,
)
)
unittest_args.extend(test_module.get_pytest_args())
unittest_args = [arg if arg != "-f" else "-x" for arg in unittest_args]
# TODO: These features are not available for C++ test yet
@ -706,7 +707,7 @@ def _test_cpp_extensions_aot(test_directory, options, use_ninja):
assert install_directory, "install_directory must not be empty"
os.environ["PYTHONPATH"] = os.pathsep.join([install_directory, python_path])
return run_test(test_module, test_directory, options)
return run_test(ShardedTest(test_module, 1, 1), test_directory, options)
finally:
os.environ["PYTHONPATH"] = python_path
if os.path.exists(test_directory + "/" + test_module + ".py"):
@ -1424,14 +1425,14 @@ def get_selected_tests(options) -> List[str]:
return selected_tests
def download_test_times(
file: str = ADDITIONAL_CI_FILES_FOLDER / TEST_TIMES_FILE,
) -> Dict[str, float]:
# Download previous test times to make sharding decisions
def load_test_times_from_file(
file: str,
) -> Dict[str, Any]:
# Load previous test times to make sharding decisions
path = os.path.join(str(REPO_ROOT), file)
if not os.path.exists(path):
print_to_stderr(
"::warning:: Failed to find test times file. Using round robin sharding."
f"::warning:: Failed to find test times file `{path}`. Using round robin sharding."
)
return {}
@ -1456,6 +1457,18 @@ def download_test_times(
return test_times_file["default"]["default"]
def load_test_file_times(
file: str = ADDITIONAL_CI_FILES_FOLDER / TEST_TIMES_FILE,
) -> Dict[str, float]:
return cast(Dict[str, float], load_test_times_from_file(file))
def load_test_class_times(
file: str = ADDITIONAL_CI_FILES_FOLDER / TEST_CLASS_TIMES_FILE,
) -> Dict[str, Dict[str, float]]:
return cast(Dict[str, Dict[str, float]], load_test_times_from_file(file))
def get_sharding_opts(options) -> Tuple[int, int]:
which_shard, num_shards = 1, 1
if options.shard:
@ -1471,8 +1484,9 @@ def get_sharding_opts(options) -> Tuple[int, int]:
def do_sharding(
options,
selected_tests: List[str],
selected_tests: Sequence[TestRun],
test_file_times: Dict[str, float],
test_class_times: Dict[str, Dict[str, float]],
sort_by_time: bool = True,
) -> List[ShardedTest]:
which_shard, num_shards = get_sharding_opts(options)
@ -1482,6 +1496,7 @@ def do_sharding(
num_shards,
selected_tests,
test_file_times,
test_class_times=test_class_times,
must_serial=must_serial,
sort_by_time=sort_by_time,
)
@ -1492,16 +1507,16 @@ def do_sharding(
class TestFailure(NamedTuple):
test: str
test: TestRun
message: str
def run_test_module(
test: Union[ShardedTest, str], test_directory: str, options
test: ShardedTest, test_directory: str, options
) -> Optional[TestFailure]:
maybe_set_hip_visible_devies()
test_name = test.name if isinstance(test, ShardedTest) else test
test_name = test.name
# Printing the date here can help diagnose which tests are slow
print_to_stderr(f"Running {str(test)} ... [{datetime.now()}]")
@ -1519,7 +1534,7 @@ def run_test_module(
# return code -N, where N is the signal number.
signal_name = SIGNALS_TO_NAMES_DICT[-return_code]
message += f" Received signal: {signal_name}"
return TestFailure(test_name, message)
return TestFailure(test.test, message)
def run_tests(
@ -1671,6 +1686,9 @@ def main():
"cpp": options.cpp,
}
test_file_times_dict = load_test_file_times()
test_class_times_dict = load_test_class_times()
class TestBatch:
"""Defines a set of tests with similar priority that should be run together on the current shard"""
@ -1678,11 +1696,17 @@ def main():
sharded_tests: List[ShardedTest]
failures: List[TestFailure]
def __init__(self, name: str, raw_tests: List[str], should_sort_shard: bool):
def __init__(
self, name: str, raw_tests: Sequence[TestRun], should_sort_shard: bool
):
self.name = name
self.failures = []
self.sharded_tests = do_sharding(
options, raw_tests, test_times_dict, sort_by_time=should_sort_shard
options,
raw_tests,
test_file_times_dict,
test_class_times_dict,
sort_by_time=should_sort_shard,
)
def __str__(self):
@ -1697,7 +1721,6 @@ def main():
)
return s.strip()
test_times_dict = download_test_times(ADDITIONAL_CI_FILES_FOLDER / TEST_TIMES_FILE)
test_batches: List[TestBatch] = []
# Each batch will be run sequentially
@ -1749,7 +1772,7 @@ def main():
test_batch.sharded_tests, test_directory, options, test_batch.failures
)
metrics_dict[f"{test_batch.name}_failures"] = [
x.test for x in test_batch.failures
str(x.test) for x in test_batch.failures
]
finally:


@ -3,7 +3,7 @@ import json
import pathlib
import sys
import unittest
from typing import Any, Dict, Set
from typing import Any, Dict, Optional, Set
from unittest import mock
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent
@ -12,12 +12,15 @@ try:
sys.path.append(str(REPO_ROOT))
from tools.testing.target_determination.determinator import (
AggregatedHeuristics,
get_test_prioritizations,
TestPrioritizations,
)
from tools.testing.target_determination.heuristics import HEURISTICS
from tools.testing.target_determination.heuristics.previously_failed_in_pr import (
_get_previously_failing_tests,
)
from tools.testing.test_run import TestRun, TestRuns
except ModuleNotFoundError:
print("Can't import required modules, exiting")
@ -31,7 +34,37 @@ def mocked_file(contents: Dict[Any, Any]) -> io.IOBase:
return file_object
class TestParsePrevTests(unittest.TestCase):
class HeuristicsTestMixin(unittest.TestCase):
def assert_heuristics_match(
self,
test_prioritizations: TestPrioritizations,
expected_high_tests: Optional[TestRuns] = None,
expected_probable_tests: Optional[TestRuns] = None,
expected_unranked_tests: Optional[TestRuns] = None,
) -> None:
if expected_unranked_tests:
self.assertTupleEqual(
test_prioritizations.get_unranked_relevance_tests(),
expected_unranked_tests,
"Unranked tests differ",
)
if expected_probable_tests:
self.assertTupleEqual(
test_prioritizations.get_probable_relevance_tests(),
expected_probable_tests,
"Probable relevance tests differ",
)
if expected_high_tests:
self.assertTupleEqual(
test_prioritizations.get_high_relevance_tests(),
expected_high_tests,
"High relevance tests differ",
)
class TestParsePrevTests(HeuristicsTestMixin):
@mock.patch("pathlib.Path.exists", return_value=False)
def test_cache_does_not_exist(self, mock_exists: Any) -> None:
expected_failing_test_files: Set[str] = set()
@ -94,18 +127,257 @@ class TestParsePrevTests(unittest.TestCase):
tests
).get_aggregated_priorities()
self.assertTupleEqual(
expected_prioritizations.get_high_relevance_tests(),
test_prioritizations.get_high_relevance_tests(),
self.assert_heuristics_match(
test_prioritizations,
expected_high_tests=expected_prioritizations.get_high_relevance_tests(),
expected_probable_tests=expected_prioritizations.get_probable_relevance_tests(),
expected_unranked_tests=expected_prioritizations.get_unranked_relevance_tests(),
)
self.assertTupleEqual(
expected_prioritizations.get_probable_relevance_tests(),
test_prioritizations.get_probable_relevance_tests(),
class TestInterface(HeuristicsTestMixin):
def test_class_prioritization(self) -> None:
tests = ["test1", "test2", "test3", "test4", "test5"]
prioritizations = TestPrioritizations(
tests_being_ranked=tests,
probable_relevance=["test2::TestFooClass", "test3"],
)
self.assertTupleEqual(
expected_prioritizations.get_unranked_relevance_tests(),
test_prioritizations.get_unranked_relevance_tests(),
expected_probable_tests = tuple(
TestRun(test) for test in ["test2::TestFooClass", "test3"]
)
expected_unranked_tests = (
TestRun("test1"),
TestRun("test2", excluded=["TestFooClass"]),
TestRun("test4"),
TestRun("test5"),
)
self.assert_heuristics_match(
prioritizations,
expected_probable_tests=expected_probable_tests,
expected_unranked_tests=expected_unranked_tests,
)
class TestAggregatedHeuristics(HeuristicsTestMixin):
def test_merging_multiple_test_class_heuristics(self) -> None:
tests = ["test1", "test2", "test3", "test4"]
heuristic1 = TestPrioritizations(
tests_being_ranked=tests,
probable_relevance=["test2::TestFooClass", "test3"],
)
heuristic2 = TestPrioritizations(
tests_being_ranked=tests,
high_relevance=["test2::TestFooClass", "test3::TestBarClass"],
)
expected_high_relevance = tuple(
TestRun(test) for test in ["test2::TestFooClass", "test3::TestBarClass"]
)
expected_probable_relevance = (TestRun("test3", excluded=["TestBarClass"]),)
expected_unranked_relevance = (
TestRun("test1"),
TestRun("test2", excluded=["TestFooClass"]),
TestRun("test4"),
)
aggregator = AggregatedHeuristics(unranked_tests=tests)
aggregator.add_heuristic_results(HEURISTICS[0], heuristic1)
aggregator.add_heuristic_results(HEURISTICS[1], heuristic2)
aggregated_pris = aggregator.get_aggregated_priorities()
self.assert_heuristics_match(
aggregated_pris,
expected_high_tests=expected_high_relevance,
expected_probable_tests=expected_probable_relevance,
expected_unranked_tests=expected_unranked_relevance,
)
def test_merging_file_heuristic_after_class_heuristic(self) -> None:
tests = ["test1", "test2", "test3", "test4", "test5"]
heuristic1 = TestPrioritizations(
tests_being_ranked=tests,
high_relevance=["test2::TestFooClass"],
)
heuristic2 = TestPrioritizations(
tests_being_ranked=tests,
probable_relevance=["test2", "test3"],
)
expected_aggregated_high_relevance = tuple(
TestRun(test) for test in ["test2::TestFooClass"]
)
expected_aggregated_probable_relevance = (
TestRun("test2", excluded=["TestFooClass"]),
TestRun("test3"),
)
expected_aggregated_unranked_relevance = (
TestRun("test1"),
TestRun("test4"),
TestRun("test5"),
)
aggregator = AggregatedHeuristics(unranked_tests=tests)
aggregator.add_heuristic_results(HEURISTICS[0], heuristic1)
aggregator.add_heuristic_results(HEURISTICS[1], heuristic2)
aggregated_pris = aggregator.get_aggregated_priorities()
self.assert_heuristics_match(
aggregated_pris,
expected_high_tests=expected_aggregated_high_relevance,
expected_probable_tests=expected_aggregated_probable_relevance,
expected_unranked_tests=expected_aggregated_unranked_relevance,
)
def test_get_test_stats_with_whole_tests(self) -> None:
self.maxDiff = None
tests = ["test1", "test2", "test3", "test4", "test5"]
heuristic1 = TestPrioritizations(
tests_being_ranked=tests,
high_relevance=["test3", "test4"],
)
heuristic2 = TestPrioritizations(
tests_being_ranked=tests,
probable_relevance=["test5"],
)
aggregator = AggregatedHeuristics(unranked_tests=tests)
aggregator.add_heuristic_results(HEURISTICS[0], heuristic1)
aggregator.add_heuristic_results(HEURISTICS[1], heuristic2)
expected_test3_stats = {
"test_name": "test3",
"test_filters": "",
"without_heuristics": {
"relevance_group": "UNRANKED",
"order_within_relevance_group": 2,
"num_tests_in_relevance_group": 5,
"order_overall": 2,
"heuristic_name": "baseline",
},
"heuristics": [
{
"relevance_group": "HIGH",
"order_within_relevance_group": 0,
"num_tests_in_relevance_group": 2,
"order_overall": 0,
"heuristic_name": HEURISTICS[0].name,
"trial_mode": False,
},
{
"relevance_group": "UNRANKED",
"order_within_relevance_group": 2,
"num_tests_in_relevance_group": 4,
"order_overall": 3,
"heuristic_name": HEURISTICS[1].name,
"trial_mode": False,
},
],
"num_heuristics_prioritized_by": 1,
"aggregated": {
"relevance_group": "HIGH",
"order_within_relevance_group": 0,
"num_tests_in_relevance_group": 2,
"order_overall": 0,
},
"aggregated_trial": {
"relevance_group": "HIGH",
"order_within_relevance_group": 0,
"num_tests_in_relevance_group": 2,
"order_overall": 0,
},
"highest_ranking_heuristic": HEURISTICS[0].name,
}
test3_stats = aggregator.get_test_stats(TestRun("test3"))
self.assertDictEqual(test3_stats, expected_test3_stats)
def test_get_test_stats_only_contains_allowed_types(self) -> None:
self.maxDiff = None
tests = ["test1", "test2", "test3", "test4", "test5"]
heuristic1 = TestPrioritizations(
tests_being_ranked=tests,
high_relevance=["test3", "test4"],
)
heuristic2 = TestPrioritizations(
tests_being_ranked=tests,
probable_relevance=["test5::classA"],
)
aggregator = AggregatedHeuristics(unranked_tests=tests)
aggregator.add_heuristic_results(HEURISTICS[0], heuristic1)
aggregator.add_heuristic_results(HEURISTICS[1], heuristic2)
stats3 = aggregator.get_test_stats(TestRun("test3"))
stats5 = aggregator.get_test_stats(TestRun("test5::classA"))
def assert_valid_dict(dict_contents: Dict[str, Any]) -> None:
for key, value in dict_contents.items():
self.assertTrue(isinstance(key, str))
self.assertTrue(
isinstance(value, (str, float, int, list, dict)),
f"{value} is not a str, float, or dict",
)
if isinstance(value, dict):
assert_valid_dict(value)
elif isinstance(value, list):
for item in value:
assert_valid_dict(item)
assert_valid_dict(stats3)
assert_valid_dict(stats5)
def test_get_test_stats_gets_rank_for_test_classes(self) -> None:
self.maxDiff = None
tests = ["test1", "test2", "test3", "test4", "test5"]
heuristic1 = TestPrioritizations(
tests_being_ranked=tests,
high_relevance=["test3", "test4"],
)
heuristic2 = TestPrioritizations(
tests_being_ranked=tests,
probable_relevance=["test5::classA"],
)
aggregator = AggregatedHeuristics(unranked_tests=tests)
aggregator.add_heuristic_results(HEURISTICS[0], heuristic1)
aggregator.add_heuristic_results(HEURISTICS[1], heuristic2)
statsInclusive = aggregator.get_test_stats(
TestRun("test5", included=["classA"])
)
statsExclusive = aggregator.get_test_stats(
TestRun("test5", excluded=["classA"])
)
print("h")
# Validate the heuristic level stats are correct
self.assertEqual(
statsInclusive["heuristics"][1]["order_within_relevance_group"], 0
)
self.assertEqual(
statsInclusive["heuristics"][1]["num_tests_in_relevance_group"], 1
)
self.assertEqual(statsInclusive["heuristics"][1]["order_overall"], 0)
self.assertEqual(statsInclusive["heuristics"][1]["relevance_group"], "PROBABLE")
self.assertEqual(statsInclusive["aggregated"]["order_overall"], 2)
self.assertEqual(
statsExclusive["heuristics"][1]["order_within_relevance_group"], 4
)
self.assertEqual(
statsExclusive["heuristics"][1]["num_tests_in_relevance_group"], 5
)
self.assertEqual(statsExclusive["heuristics"][1]["order_overall"], 5)
self.assertEqual(statsExclusive["heuristics"][1]["relevance_group"], "UNRANKED")
self.assertEqual(statsExclusive["aggregated"]["order_overall"], 5)
if __name__ == "__main__":

tools/test/test_test_run.py (new file, 169 lines)

@ -0,0 +1,169 @@
import pathlib
import sys
import unittest
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent
try:
# using tools/ to optimize test run.
sys.path.append(str(REPO_ROOT))
from tools.testing.test_run import ShardedTest, TestRun
except ModuleNotFoundError:
print("Can't import required modules, exiting")
sys.exit(1)
class TestTestRun(unittest.TestCase):
def test_union_with_full_run(self) -> None:
run1 = TestRun("foo")
run2 = TestRun("foo::bar")
self.assertEqual(run1 | run2, run1)
self.assertEqual(run2 | run1, run1)
def test_union_with_inclusions(self) -> None:
run1 = TestRun("foo::bar")
run2 = TestRun("foo::baz")
expected = TestRun("foo")
expected._included.add("bar")
expected._included.add("baz")
self.assertEqual(run1 | run2, expected)
self.assertEqual(run2 | run1, expected)
def test_union_with_non_overlapping_exclusions(self) -> None:
run1 = TestRun("foo", excluded=["bar"])
run2 = TestRun("foo", excluded=["baz"])
expected = TestRun("foo")
self.assertEqual(run1 | run2, expected)
self.assertEqual(run2 | run1, expected)
def test_union_with_overlapping_exclusions(self) -> None:
run1 = TestRun("foo", excluded=["bar", "car"])
run2 = TestRun("foo", excluded=["bar", "caz"])
expected = TestRun("foo", excluded=["bar"])
self.assertEqual(run1 | run2, expected)
self.assertEqual(run2 | run1, expected)
def test_union_with_mixed_inclusion_exclusions(self) -> None:
run1 = TestRun("foo", excluded=["baz", "car"])
run2 = TestRun("foo", included=["baz"])
expected = TestRun("foo", excluded=["car"])
self.assertEqual(run1 | run2, expected)
self.assertEqual(run2 | run1, expected)
def test_union_with_mixed_files_fails(self) -> None:
run1 = TestRun("foo")
run2 = TestRun("bar")
with self.assertRaises(AssertionError):
run1 | run2
def test_union_with_empty_file_yields_orig_file(self) -> None:
run1 = TestRun("foo")
run2 = TestRun.empty()
self.assertEqual(run1 | run2, run1)
self.assertEqual(run2 | run1, run1)
def test_subtracting_full_run_fails(self) -> None:
run1 = TestRun("foo::bar")
run2 = TestRun("foo")
self.assertEqual(run1 - run2, TestRun.empty())
def test_subtracting_empty_file_yields_orig_file(self) -> None:
run1 = TestRun("foo")
run2 = TestRun.empty()
self.assertEqual(run1 - run2, run1)
self.assertEqual(run2 - run1, TestRun.empty())
def test_empty_is_falsey(self) -> None:
self.assertFalse(TestRun.empty())
def test_subtracting_inclusion_from_full_run(self) -> None:
run1 = TestRun("foo")
run2 = TestRun("foo::bar")
expected = TestRun("foo", excluded=["bar"])
self.assertEqual(run1 - run2, expected)
def test_subtracting_inclusion_from_overlapping_inclusion(self) -> None:
run1 = TestRun("foo", included=["bar", "baz"])
run2 = TestRun("foo::baz")
self.assertEqual(run1 - run2, TestRun("foo", included=["bar"]))
def test_subtracting_inclusion_from_nonoverlapping_inclusion(self) -> None:
run1 = TestRun("foo", included=["bar", "baz"])
run2 = TestRun("foo", included=["car"])
self.assertEqual(run1 - run2, TestRun("foo", included=["bar", "baz"]))
def test_subtracting_exclusion_from_full_run(self) -> None:
run1 = TestRun("foo")
run2 = TestRun("foo", excluded=["bar"])
self.assertEqual(run1 - run2, TestRun("foo", included=["bar"]))
def test_subtracting_exclusion_from_superset_exclusion(self) -> None:
run1 = TestRun("foo", excluded=["bar", "baz"])
run2 = TestRun("foo", excluded=["baz"])
self.assertEqual(run1 - run2, TestRun.empty())
self.assertEqual(run2 - run1, TestRun("foo", included=["bar"]))
def test_subtracting_exclusion_from_nonoverlapping_exclusion(self) -> None:
run1 = TestRun("foo", excluded=["bar", "baz"])
run2 = TestRun("foo", excluded=["car"])
self.assertEqual(run1 - run2, TestRun("foo", included=["car"]))
self.assertEqual(run2 - run1, TestRun("foo", included=["bar", "baz"]))
def test_subtracting_inclusion_from_exclusion_without_overlaps(self) -> None:
run1 = TestRun("foo", excluded=["bar", "baz"])
run2 = TestRun("foo", included=["bar"])
self.assertEqual(run1 - run2, run1)
self.assertEqual(run2 - run1, run2)
def test_subtracting_inclusion_from_exclusion_with_overlaps(self) -> None:
run1 = TestRun("foo", excluded=["bar", "baz"])
run2 = TestRun("foo", included=["bar", "car"])
self.assertEqual(run1 - run2, TestRun("foo", excluded=["bar", "baz", "car"]))
self.assertEqual(run2 - run1, TestRun("foo", included=["bar"]))
def test_and(self) -> None:
run1 = TestRun("foo", included=["bar", "baz"])
run2 = TestRun("foo", included=["bar", "car"])
self.assertEqual(run1 & run2, TestRun("foo", included=["bar"]))
def test_and_exclusions(self) -> None:
run1 = TestRun("foo", excluded=["bar", "baz"])
run2 = TestRun("foo", excluded=["bar", "car"])
self.assertEqual(run1 & run2, TestRun("foo", excluded=["bar", "baz", "car"]))
class TestShardedTest(unittest.TestCase):
def test_get_pytest_args(self) -> None:
test = TestRun("foo", included=["bar", "baz"])
sharded_test = ShardedTest(test, 1, 1)
expected_args = ["-k", "bar or baz"]
self.assertListEqual(sharded_test.get_pytest_args(), expected_args)
if __name__ == "__main__":
unittest.main()


@ -9,25 +9,30 @@ REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent
try:
# using tools/ to optimize test run.
sys.path.append(str(REPO_ROOT))
from tools.testing.test_selections import calculate_shards, ShardedTest, THRESHOLD
from tools.testing.test_run import ShardedTest, TestRun
from tools.testing.test_selections import calculate_shards, THRESHOLD
except ModuleNotFoundError:
print("Can't import required modules, exiting")
sys.exit(1)
def gen_class_times(test_times: Dict[str, float]) -> Dict[str, Dict[str, float]]:
return {k: {"class1": v} for k, v in test_times.items()}
class TestCalculateShards(unittest.TestCase):
tests: List[str] = [
"super_long_test",
"long_test1",
"long_test2",
"normal_test1",
"normal_test2",
"normal_test3",
"short_test1",
"short_test2",
"short_test3",
"short_test4",
"short_test5",
tests: List[TestRun] = [
TestRun("super_long_test"),
TestRun("long_test1"),
TestRun("long_test2"),
TestRun("normal_test1"),
TestRun("normal_test2"),
TestRun("normal_test3"),
TestRun("short_test1"),
TestRun("short_test2"),
TestRun("short_test3"),
TestRun("short_test4"),
TestRun("short_test5"),
]
test_times: Dict[str, float] = {
@ -44,6 +49,20 @@ class TestCalculateShards(unittest.TestCase):
"short_test5": 0.01,
}
test_class_times: Dict[str, Dict[str, float]] = {
"super_long_test": {"class1": 55},
"long_test1": {"class1": 1, "class2": 21},
"long_test2": {"class1": 10, "class2": 8},
"normal_test1": {"class1": 9},
"normal_test2": {"class1": 7},
"normal_test3": {"class1": 5},
"short_test1": {"class1": 1},
"short_test2": {"class1": 0.6},
"short_test3": {"class1": 0.4},
"short_test4": {"class1": 0.3},
"short_test5": {"class1": 0.01},
}
def assert_shards_equal(
self,
expected_shards: List[Tuple[float, List[ShardedTest]]],
@ -58,81 +77,92 @@ class TestCalculateShards(unittest.TestCase):
(
60.0,
[
ShardedTest(name="super_long_test", shard=1, num_shards=1, time=55),
ShardedTest(name="normal_test3", shard=1, num_shards=1, time=5),
ShardedTest(test="super_long_test", shard=1, num_shards=1, time=55),
ShardedTest(test="normal_test3", shard=1, num_shards=1, time=5),
],
),
(
58.31,
[
ShardedTest(name="long_test1", shard=1, num_shards=1, time=22),
ShardedTest(name="long_test2", shard=1, num_shards=1, time=18),
ShardedTest(name="normal_test1", shard=1, num_shards=1, time=9),
ShardedTest(name="normal_test2", shard=1, num_shards=1, time=7),
ShardedTest(name="short_test1", shard=1, num_shards=1, time=1),
ShardedTest(name="short_test2", shard=1, num_shards=1, time=0.6),
ShardedTest(name="short_test3", shard=1, num_shards=1, time=0.4),
ShardedTest(name="short_test4", shard=1, num_shards=1, time=0.3),
ShardedTest(name="short_test5", shard=1, num_shards=1, time=0.01),
ShardedTest(test="long_test1", shard=1, num_shards=1, time=22),
ShardedTest(test="long_test2", shard=1, num_shards=1, time=18),
ShardedTest(test="normal_test1", shard=1, num_shards=1, time=9),
ShardedTest(test="normal_test2", shard=1, num_shards=1, time=7),
ShardedTest(test="short_test1", shard=1, num_shards=1, time=1),
ShardedTest(test="short_test2", shard=1, num_shards=1, time=0.6),
ShardedTest(test="short_test3", shard=1, num_shards=1, time=0.4),
ShardedTest(test="short_test4", shard=1, num_shards=1, time=0.3),
ShardedTest(test="short_test5", shard=1, num_shards=1, time=0.01),
],
),
]
self.assert_shards_equal(
expected_shards, calculate_shards(2, self.tests, self.test_times)
expected_shards,
calculate_shards(2, self.tests, self.test_times, self.test_class_times),
)
def test_calculate_1_shard_with_complete_test_times(self) -> None:
tests = self.tests.copy()
class_test1 = TestRun("long_test1", excluded=["class2"])
class_test2 = TestRun("long_test1", included=["class2"])
tests.append(class_test1)
tests.append(class_test2)
expected_shards = [
(
118.31,
140.31,
[
ShardedTest(name="super_long_test", shard=1, num_shards=1, time=55),
ShardedTest(name="long_test1", shard=1, num_shards=1, time=22),
ShardedTest(name="long_test2", shard=1, num_shards=1, time=18),
ShardedTest(name="normal_test1", shard=1, num_shards=1, time=9),
ShardedTest(name="normal_test2", shard=1, num_shards=1, time=7),
ShardedTest(name="normal_test3", shard=1, num_shards=1, time=5),
ShardedTest(name="short_test1", shard=1, num_shards=1, time=1),
ShardedTest(name="short_test2", shard=1, num_shards=1, time=0.6),
ShardedTest(name="short_test3", shard=1, num_shards=1, time=0.4),
ShardedTest(name="short_test4", shard=1, num_shards=1, time=0.3),
ShardedTest(name="short_test5", shard=1, num_shards=1, time=0.01),
ShardedTest(test="super_long_test", shard=1, num_shards=1, time=55),
ShardedTest(test="long_test1", shard=1, num_shards=1, time=22),
ShardedTest(class_test2, shard=1, num_shards=1, time=21),
ShardedTest(test="long_test2", shard=1, num_shards=1, time=18),
ShardedTest(test="normal_test1", shard=1, num_shards=1, time=9),
ShardedTest(test="normal_test2", shard=1, num_shards=1, time=7),
ShardedTest(test="normal_test3", shard=1, num_shards=1, time=5),
ShardedTest(test="short_test1", shard=1, num_shards=1, time=1),
ShardedTest(class_test1, shard=1, num_shards=1, time=1),
ShardedTest(test="short_test2", shard=1, num_shards=1, time=0.6),
ShardedTest(test="short_test3", shard=1, num_shards=1, time=0.4),
ShardedTest(test="short_test4", shard=1, num_shards=1, time=0.3),
ShardedTest(test="short_test5", shard=1, num_shards=1, time=0.01),
],
)
]
self.assert_shards_equal(
expected_shards, calculate_shards(1, self.tests, self.test_times)
expected_shards,
calculate_shards(1, tests, self.test_times, self.test_class_times),
)
def test_calculate_5_shards_with_complete_test_times(self) -> None:
expected_shards = [
(
55.0,
[ShardedTest(name="super_long_test", shard=1, num_shards=1, time=55)],
[ShardedTest(test="super_long_test", shard=1, num_shards=1, time=55)],
),
(22.0, [ShardedTest(name="long_test1", shard=1, num_shards=1, time=22)]),
(18.0, [ShardedTest(name="long_test2", shard=1, num_shards=1, time=18)]),
(22.0, [ShardedTest(test="long_test1", shard=1, num_shards=1, time=22)]),
(18.0, [ShardedTest(test="long_test2", shard=1, num_shards=1, time=18)]),
(
11.31,
[
ShardedTest(name="normal_test1", shard=1, num_shards=1, time=9),
ShardedTest(name="short_test1", shard=1, num_shards=1, time=1),
ShardedTest(name="short_test2", shard=1, num_shards=1, time=0.6),
ShardedTest(name="short_test3", shard=1, num_shards=1, time=0.4),
ShardedTest(name="short_test4", shard=1, num_shards=1, time=0.3),
ShardedTest(name="short_test5", shard=1, num_shards=1, time=0.01),
ShardedTest(test="normal_test1", shard=1, num_shards=1, time=9),
ShardedTest(test="short_test1", shard=1, num_shards=1, time=1),
ShardedTest(test="short_test2", shard=1, num_shards=1, time=0.6),
ShardedTest(test="short_test3", shard=1, num_shards=1, time=0.4),
ShardedTest(test="short_test4", shard=1, num_shards=1, time=0.3),
ShardedTest(test="short_test5", shard=1, num_shards=1, time=0.01),
],
),
(
12.0,
[
ShardedTest(name="normal_test2", shard=1, num_shards=1, time=7),
ShardedTest(name="normal_test3", shard=1, num_shards=1, time=5),
ShardedTest(test="normal_test2", shard=1, num_shards=1, time=7),
ShardedTest(test="normal_test3", shard=1, num_shards=1, time=5),
],
),
]
self.assert_shards_equal(
expected_shards, calculate_shards(5, self.tests, self.test_times)
expected_shards,
calculate_shards(5, self.tests, self.test_times, self.test_class_times),
)
def test_calculate_2_shards_with_incomplete_test_times(self) -> None:
@ -143,29 +173,35 @@ class TestCalculateShards(unittest.TestCase):
(
22.0,
[
ShardedTest(name="long_test1", shard=1, num_shards=1, time=22),
ShardedTest(name="long_test2", shard=1, num_shards=1, time=None),
ShardedTest(name="normal_test3", shard=1, num_shards=1, time=None),
ShardedTest(name="short_test3", shard=1, num_shards=1, time=None),
ShardedTest(name="short_test5", shard=1, num_shards=1, time=None),
ShardedTest(test="long_test1", shard=1, num_shards=1, time=22),
ShardedTest(test="long_test2", shard=1, num_shards=1, time=None),
ShardedTest(test="normal_test3", shard=1, num_shards=1, time=None),
ShardedTest(test="short_test3", shard=1, num_shards=1, time=None),
ShardedTest(test="short_test5", shard=1, num_shards=1, time=None),
],
),
(
10.0,
[
ShardedTest(name="normal_test1", shard=1, num_shards=1, time=9),
ShardedTest(name="short_test1", shard=1, num_shards=1, time=1),
ShardedTest(test="normal_test1", shard=1, num_shards=1, time=9),
ShardedTest(test="short_test1", shard=1, num_shards=1, time=1),
ShardedTest(
name="super_long_test", shard=1, num_shards=1, time=None
test="super_long_test", shard=1, num_shards=1, time=None
),
ShardedTest(name="normal_test2", shard=1, num_shards=1, time=None),
ShardedTest(name="short_test2", shard=1, num_shards=1, time=None),
ShardedTest(name="short_test4", shard=1, num_shards=1, time=None),
ShardedTest(test="normal_test2", shard=1, num_shards=1, time=None),
ShardedTest(test="short_test2", shard=1, num_shards=1, time=None),
ShardedTest(test="short_test4", shard=1, num_shards=1, time=None),
],
),
]
self.assert_shards_equal(
expected_shards, calculate_shards(2, self.tests, incomplete_test_times)
expected_shards,
calculate_shards(
2,
self.tests,
incomplete_test_times,
gen_class_times(incomplete_test_times),
),
)
def test_calculate_5_shards_with_incomplete_test_times(self) -> None:
@ -176,54 +212,66 @@ class TestCalculateShards(unittest.TestCase):
(
22.0,
[
ShardedTest(name="long_test1", shard=1, num_shards=1, time=22),
ShardedTest(name="normal_test2", shard=1, num_shards=1, time=None),
ShardedTest(name="short_test5", shard=1, num_shards=1, time=None),
ShardedTest(test="long_test1", shard=1, num_shards=1, time=22),
ShardedTest(test="normal_test2", shard=1, num_shards=1, time=None),
ShardedTest(test="short_test5", shard=1, num_shards=1, time=None),
],
),
(
9.0,
[
ShardedTest(name="normal_test1", shard=1, num_shards=1, time=9),
ShardedTest(name="normal_test3", shard=1, num_shards=1, time=None),
ShardedTest(test="normal_test1", shard=1, num_shards=1, time=9),
ShardedTest(test="normal_test3", shard=1, num_shards=1, time=None),
],
),
(
1.0,
[
ShardedTest(name="short_test1", shard=1, num_shards=1, time=1),
ShardedTest(name="short_test2", shard=1, num_shards=1, time=None),
ShardedTest(test="short_test1", shard=1, num_shards=1, time=1),
ShardedTest(test="short_test2", shard=1, num_shards=1, time=None),
],
),
(
0.0,
[
ShardedTest(
name="super_long_test", shard=1, num_shards=1, time=None
test="super_long_test", shard=1, num_shards=1, time=None
),
ShardedTest(name="short_test3", shard=1, num_shards=1, time=None),
ShardedTest(test="short_test3", shard=1, num_shards=1, time=None),
],
),
(
0.0,
[
ShardedTest(name="long_test2", shard=1, num_shards=1, time=None),
ShardedTest(name="short_test4", shard=1, num_shards=1, time=None),
ShardedTest(test="long_test2", shard=1, num_shards=1, time=None),
ShardedTest(test="short_test4", shard=1, num_shards=1, time=None),
],
),
]
self.assert_shards_equal(
expected_shards, calculate_shards(5, self.tests, incomplete_test_times)
expected_shards,
calculate_shards(
5,
self.tests,
incomplete_test_times,
gen_class_times(incomplete_test_times),
),
)
def test_split_shards(self) -> None:
test_times: Dict[str, float] = {"test1": THRESHOLD, "test2": THRESHOLD}
expected_shards = [
(600.0, [ShardedTest(name="test1", shard=1, num_shards=1, time=THRESHOLD)]),
(600.0, [ShardedTest(name="test2", shard=1, num_shards=1, time=THRESHOLD)]),
(600.0, [ShardedTest(test="test1", shard=1, num_shards=1, time=THRESHOLD)]),
(600.0, [ShardedTest(test="test2", shard=1, num_shards=1, time=THRESHOLD)]),
]
self.assert_shards_equal(
expected_shards, calculate_shards(2, list(test_times.keys()), test_times)
expected_shards,
calculate_shards(
2,
[TestRun(t) for t in test_times.keys()],
test_times,
gen_class_times(test_times),
),
)
test_times = {"test1": THRESHOLD * 4, "test2": THRESHOLD * 2.5}
@ -231,35 +279,47 @@ class TestCalculateShards(unittest.TestCase):
(
2200.0,
[
ShardedTest(name="test1", shard=1, num_shards=4, time=600.0),
ShardedTest(name="test1", shard=3, num_shards=4, time=600.0),
ShardedTest(name="test2", shard=1, num_shards=3, time=500.0),
ShardedTest(name="test2", shard=3, num_shards=3, time=500.0),
ShardedTest(test="test1", shard=1, num_shards=4, time=600.0),
ShardedTest(test="test1", shard=3, num_shards=4, time=600.0),
ShardedTest(test="test2", shard=1, num_shards=3, time=500.0),
ShardedTest(test="test2", shard=3, num_shards=3, time=500.0),
],
),
(
1700.0,
[
ShardedTest(name="test1", shard=2, num_shards=4, time=600.0),
ShardedTest(name="test1", shard=4, num_shards=4, time=600.0),
ShardedTest(name="test2", shard=2, num_shards=3, time=500.0),
ShardedTest(test="test1", shard=2, num_shards=4, time=600.0),
ShardedTest(test="test1", shard=4, num_shards=4, time=600.0),
ShardedTest(test="test2", shard=2, num_shards=3, time=500.0),
],
),
]
self.assert_shards_equal(
expected_shards, calculate_shards(2, list(test_times.keys()), test_times)
expected_shards,
calculate_shards(
2,
[TestRun(t) for t in test_times.keys()],
test_times,
gen_class_times(test_times),
),
)
test_times = {"test1": THRESHOLD / 2, "test2": THRESHOLD}
expected_shards = [
(600.0, [ShardedTest(name="test2", shard=1, num_shards=1, time=THRESHOLD)]),
(600.0, [ShardedTest(test="test2", shard=1, num_shards=1, time=THRESHOLD)]),
(
300.0,
[ShardedTest(name="test1", shard=1, num_shards=1, time=THRESHOLD / 2)],
[ShardedTest(test="test1", shard=1, num_shards=1, time=THRESHOLD / 2)],
),
]
self.assert_shards_equal(
expected_shards, calculate_shards(2, list(test_times.keys()), test_times)
expected_shards,
calculate_shards(
2,
[TestRun(t) for t in test_times.keys()],
test_times,
gen_class_times(test_times),
),
)
def test_split_shards_random(self) -> None:
@ -272,7 +332,10 @@ class TestCalculateShards(unittest.TestCase):
}
shards = calculate_shards(
num_shards, list(random_times.keys()), random_times
num_shards,
[TestRun(t) for t in random_times.keys()],
random_times,
gen_class_times(random_times),
)
times = [x[0] for x in shards]
@ -300,7 +363,7 @@ class TestCalculateShards(unittest.TestCase):
def test_calculate_2_shards_against_optimal_shards(self) -> None:
random.seed(120)
for _ in range(100):
random_times = {k: random.random() * 10 for k in self.tests}
random_times = {k.test_file: random.random() * 10 for k in self.tests}
# all test times except first two
rest_of_tests = [
i
@ -315,12 +378,14 @@ class TestCalculateShards(unittest.TestCase):
# (sum_of_rest, ['super_long_test', 'long_test1']),
# (sum_of_rest, [i for i in self.tests if i != 'super_long_test' and i != 'long_test1']),
# ]
calculated_shards = calculate_shards(2, self.tests, random_times)
calculated_shards = calculate_shards(
2, self.tests, random_times, gen_class_times(random_times)
)
max_shard_time = max(calculated_shards[0][0], calculated_shards[1][0])
if sum_of_rest != 0:
# The calculated shard should not have a ratio worse than 7/6 for num_shards = 2
self.assertGreaterEqual(7.0 / 6.0, max_shard_time / sum_of_rest)
sorted_tests = sorted(self.tests)
sorted_tests = sorted([t.test_file for t in self.tests])
sorted_shard_tests = sorted(
calculated_shards[0][1] + calculated_shards[1][1]
)


@ -1,20 +1,53 @@
import sys
from abc import abstractmethod
from copy import copy
from enum import Enum
from functools import total_ordering
from itertools import chain
from typing import Any, Dict, FrozenSet, Iterable, List, Optional, Tuple
from typing import (
Any,
Callable,
Dict,
FrozenSet,
Iterable,
Iterator,
List,
Optional,
Tuple,
)
StringTuple = Tuple[str, ...]
from tools.testing.test_run import TestRun, TestRuns
# Note: Keep the implementation of Relevance private to this file so
# that it's easy to change in the future as we discover what's needed
@total_ordering
class Relevance(Enum):
HIGH = 0
PROBABLE = 1
HIGH = 4
PROBABLE = 3
UNRANKED = 2
UNLIKELY = 3 # Not yet supported. Needs more infra to be usable
NONE = 4 # Not yet supported. Needs more infra to be usable
UNLIKELY = 1 # Not yet supported. Needs more infra to be usable
NONE = 0 # Not yet supported. Needs more infra to be usable
def __eq__(self, other: object) -> bool:
if not isinstance(other, Relevance):
return False
return self.value == other.value
def __lt__(self, other: object) -> bool:
if not isinstance(other, Relevance):
raise NotImplementedError(f"Can't compare {self} to {other}")
return self.value < other.value
@staticmethod
def priority_traversal() -> Iterator["Relevance"]:
yield Relevance.HIGH
yield Relevance.PROBABLE
yield Relevance.UNRANKED
yield Relevance.UNLIKELY
yield Relevance.NONE
METRIC_RELEVANCE_GROUP = "relevance_group"
@ -37,7 +70,7 @@ class TestPrioritizations:
otherwise it breaks the test sharding logic
"""
_test_priorities: List[StringTuple] # This list MUST be ordered by Relevance
_test_priorities: List[List[TestRun]] # This list MUST be ordered by Relevance
_original_tests: FrozenSet[str]
def __init__(
@ -51,106 +84,168 @@ class TestPrioritizations:
) -> None:
self._original_tests = frozenset(tests_being_ranked)
self._test_priorities = [tuple() for _ in range(5)]
self._test_priorities = [[] for _ in range(5)]
# Setup the initial priorities
self._test_priorities[Relevance.UNRANKED.value] = [
TestRun(test) for test in tests_being_ranked
]
self._test_priorities[Relevance.HIGH.value] = self.filter_out_extra_tests(
high_relevance
)
self._test_priorities[Relevance.PROBABLE.value] = self.filter_out_extra_tests(
probable_relevance
)
self._test_priorities[Relevance.UNRANKED.value] = self.filter_out_extra_tests(
unranked_relevance
)
self._test_priorities[Relevance.UNLIKELY.value] = self.filter_out_extra_tests(
unlikely_relevance
)
self._test_priorities[Relevance.NONE.value] = self.filter_out_extra_tests(
no_relevance
)
# If any of the original tests were missed from the other lists, add them to the unranked_relevance list
missing_tests = sorted(self._original_tests - set(self.get_all_tests()))
self._test_priorities[Relevance.UNRANKED.value] = self._test_priorities[
Relevance.UNRANKED.value
] + tuple(missing_tests)
for test in high_relevance or []:
self.set_test_relevance(TestRun(test), Relevance.HIGH)
for test in probable_relevance or []:
self.set_test_relevance(TestRun(test), Relevance.PROBABLE)
for test in unranked_relevance or []:
self.set_test_relevance(TestRun(test), Relevance.UNRANKED)
for test in unlikely_relevance or []:
self.set_test_relevance(TestRun(test), Relevance.UNLIKELY)
for test in no_relevance or []:
self.set_test_relevance(TestRun(test), Relevance.NONE)
self.validate_test_priorities()
def filter_out_extra_tests(
self, relevance_group: Optional[List[str]]
) -> StringTuple:
if not relevance_group:
return tuple()
return tuple(filter(lambda test: test in self._original_tests, relevance_group))
def _traverse_priorities(self) -> Iterator[Tuple[Relevance, List[TestRun]]]:
for relevance in Relevance.priority_traversal():
yield (relevance, self._test_priorities[relevance.value])
def get_pointer_to_test(self, test_run: TestRun) -> Iterator[Tuple[Relevance, int]]:
"""
Returns all test runs that contain any subset of the given test_run and their current relevance.
self._test_priorities should NOT have items added or removed from it while iterating over the
results of this function.
"""
# Find a test run that contains the given TestRun and its relevance.
found_match = False
for relevance, tests in self._traverse_priorities():
for idx, existing_test_run in enumerate(tests):
# Does the existing test run contain any of the test we're looking for?
shared_test = existing_test_run & test_run
if not shared_test.is_empty():
found_match = True
yield (Relevance(relevance), idx)
if not found_match:
raise ValueError(f"Test {test_run} not found in any relevance group")
def _update_test_relevance(
self,
test_run: TestRun,
new_relevance: Relevance,
acceptable_relevance_fn: Callable[[Relevance, Relevance], bool],
) -> None:
"""
Updates the test run's relevance to the new relevance.
If the tests in the test run were previously split up into multiple test runs, all the chunks at a lower
relevance will be merged into one new test run at the new relevance, appended to the end of the relevance group.
However, any tests in a test run that are already at the desired relevance will be left alone, keeping its
original place in the relevance group.
"""
if test_run.test_file not in self._original_tests:
return # We don't need this test
# The tests covered by test_run could potentially have been split up into
# multiple test runs, each at a different relevance. Let's make sure to bring
# all of them up to the minimum relevance
upgraded_tests = TestRun.empty()
tests_to_remove = []
for curr_relevance, test_run_idx in self.get_pointer_to_test(test_run):
if acceptable_relevance_fn(curr_relevance, new_relevance):
# This test is already at the desired relevance
continue # no changes needed
test_run_to_rerank = self._test_priorities[curr_relevance.value][
test_run_idx
]
# Remove the requested tests from their current relevance group, to be added to the new one
remaining_tests = test_run_to_rerank - test_run
upgraded_tests |= test_run_to_rerank & test_run
# Remove the tests that are being upgraded
if remaining_tests:
self._test_priorities[curr_relevance.value][
test_run_idx
] = remaining_tests
else:
# List traversal prevents us from deleting these immediately, so note them for later
tests_to_remove.append((curr_relevance, test_run_idx))
for relevance, test_idx in tests_to_remove:
del self._test_priorities[relevance.value][test_idx]
# And add them to the desired relevance group
if upgraded_tests:
self._test_priorities[new_relevance.value].append(upgraded_tests)
def set_test_relevance(self, test_run: TestRun, new_relevance: Relevance) -> None:
return self._update_test_relevance(
test_run, new_relevance, lambda curr, new: curr == new
)
def raise_test_relevance(self, test_run: TestRun, new_relevance: Relevance) -> None:
return self._update_test_relevance(
test_run, new_relevance, lambda curr, new: curr >= new
)
def validate_test_priorities(self) -> None:
# Union all TestRuns that contain include/exclude pairs
all_tests = self.get_all_tests()
files = {}
for test in all_tests:
if test.test_file not in files:
files[test.test_file] = copy(test)
else:
files[test.test_file] |= test
for test in files.values():
assert (
test.is_full_file()
), f"All includes should have been excluded elsewhere, and vice versa. Test run `{test}` violates that"
# Ensure that the set of tests in the TestPrioritizations is identical to the set of tests passed in
assert self._original_tests == set(
self.get_all_tests()
files.keys()
), "The set of tests in the TestPrioritizations must be identical to the set of tests passed in"
@staticmethod
def _merge_tests(
current_tests: Iterable[str],
new_tests: Iterable[str],
higher_pri_tests: Iterable[str],
) -> StringTuple:
"""
We append all new tests to the current tests, while preserving the sorting on the new_tests
However, exclude any specified tests which have now moved to a higher priority list or tests
that weren't originally in the self's TestPrioritizations
"""
merged_tests = [
test
for test in chain(current_tests, new_tests)
if test not in higher_pri_tests
] # skip the excluded tests
return tuple(dict.fromkeys(merged_tests)) # remove dupes while preseving order
def integrate_priorities(self, other: "TestPrioritizations") -> None:
"""
Integrates priorities from another TestPrioritizations object.
The final result takes all tests from the `self` and rearranges them based on priorities from `other`.
If there are tests mentioned in `other` which are not in `self`, those tests are ignored.
(For example, that can happen if a heuristic reports tests that are not run in the current job)
Currently it will only raise the priority of a test, never lower it.
"""
assert (
self._original_tests == other._original_tests
), "Both tests should stem from the same original test list"
higher_pri_tests: List[str] = []
for relevance, _ in enumerate(self._test_priorities):
self._test_priorities[relevance] = TestPrioritizations._merge_tests(
current_tests=self._test_priorities[relevance],
new_tests=other._test_priorities[relevance],
higher_pri_tests=higher_pri_tests,
)
# Don't let the tests we just added to the current relevance group be added to a lower relevance group
higher_pri_tests.extend(self._test_priorities[relevance])
for relevance, _ in other._traverse_priorities():
if relevance > Relevance.UNRANKED:
for test in other._test_priorities[relevance.value]:
self.raise_test_relevance(test, relevance)
# TODO: Handle the case where a test is moved to a lower relevance group (once we support that scenario)
self.validate_test_priorities()
return
def get_all_tests(self) -> StringTuple:
def get_all_tests(self) -> TestRuns:
"""Returns all tests in the TestPrioritizations"""
return tuple(test for test in chain(*self._test_priorities))
return tuple(chain(*self._test_priorities))
def get_prioritized_tests(self) -> StringTuple:
def get_prioritized_tests(self) -> TestRuns:
return self.get_high_relevance_tests() + self.get_probable_relevance_tests()
def get_high_relevance_tests(self) -> StringTuple:
def get_high_relevance_tests(self) -> TestRuns:
return tuple(test for test in self._test_priorities[Relevance.HIGH.value])
def get_probable_relevance_tests(self) -> StringTuple:
def get_probable_relevance_tests(self) -> TestRuns:
return tuple(test for test in self._test_priorities[Relevance.PROBABLE.value])
def get_unranked_relevance_tests(self) -> StringTuple:
def get_unranked_relevance_tests(self) -> TestRuns:
return tuple(test for test in self._test_priorities[Relevance.UNRANKED.value])
def print_info(self) -> None:
def _print_tests(label: str, tests: StringTuple) -> None:
def _print_tests(label: str, tests: List[TestRun]) -> None:
if not tests:
return
@ -159,48 +254,61 @@ class TestPrioritizations:
if test in tests:
print(f" {test}")
for relevance_group, tests in enumerate(self._test_priorities):
for relevance_group, tests in self._traverse_priorities():
_print_tests(f"{Relevance(relevance_group).name.title()} Relevance", tests)
def _get_test_relevance_group(self, test_name: str) -> Relevance:
"""Returns the priority of a test."""
for relevance_group, tests in enumerate(self._test_priorities):
if test_name in tests:
def _get_test_relevance_group(self, test_run: TestRun) -> Relevance:
"""Returns the rank of the given test run."""
for relevance_group, tests in self._traverse_priorities():
if any(t.contains(test_run) for t in tests):
return Relevance(relevance_group)
raise ValueError(f"Test {test_name} not found in any relevance group")
print("holup, retry")
for relevance_group, tests in self._traverse_priorities():
if any(
t.contains(test_run) for t in tests
): # t could be the entire test_run or a superset
return Relevance(relevance_group)
def _get_test_order(self, test_name: str) -> int:
"""Returns the rank of the test specified by this heuristic."""
raise ValueError(f"Test {test_run} not found in any relevance group")
def _get_test_order(self, test_run: TestRun) -> int:
"""Returns the rank this heuristic suggested for the test run."""
base_rank = 0
for relevance_group_tests in self._test_priorities:
if test_name in relevance_group_tests:
return base_rank + relevance_group_tests.index(test_name)
for _, relevance_group_tests in self._traverse_priorities():
for idx, test in enumerate(relevance_group_tests):
if test.contains(
test_run
): # test could be the entire test_run or a superset
return base_rank + idx
base_rank += len(relevance_group_tests)
raise ValueError(f"Test {test_name} not found in any relevance group")
raise ValueError(f"Test {test_run} not found in any relevance group")
def _get_test_order_within_relevance_group(self, test_name: str) -> int:
for relevance_group_tests in self._test_priorities:
if test_name not in relevance_group_tests:
continue
def _get_test_order_within_relevance_group(self, test_run: TestRun) -> int:
"""Returns the highest test order of any test class within the same relevance group."""
for _, relevance_group_tests in self._traverse_priorities():
for idx, test in enumerate(relevance_group_tests):
if test.contains(
test_run
): # test could be the entire test_run or a superset
return idx
return relevance_group_tests.index(test_name)
raise ValueError(f"Test {test_run} not found in any relevance group")
raise ValueError(f"Test {test_name} not found in any relevance group")
def get_priority_info_for_test(self, test_name: str) -> Dict[str, Any]:
def get_priority_info_for_test(self, test_run: TestRun) -> Dict[str, Any]:
"""Given a failing test, returns information about it's prioritization that we want to emit in our metrics."""
relevance = self._get_test_relevance_group(test_run)
return {
METRIC_RELEVANCE_GROUP: self._get_test_relevance_group(test_name).name,
METRIC_RELEVANCE_GROUP: relevance.name,
METRIC_ORDER_WITHIN_RELEVANCE_GROUP: self._get_test_order_within_relevance_group(
test_name
test_run
),
METRIC_NUM_TESTS_IN_RELEVANCE_GROUP: len(
self._test_priorities[self._get_test_relevance_group(test_name).value]
self._test_priorities[relevance.value]
),
METRIC_ORDER_OVERALL: self._get_test_order(test_name),
METRIC_ORDER_OVERALL: self._get_test_order(test_run),
}
@ -247,12 +355,13 @@ class AggregatedHeuristics:
return aggregated_priorities
def get_test_stats(self, test: str) -> Dict[str, Any]:
def get_test_stats(self, test: TestRun) -> Dict[str, Any]:
"""
Returns the aggregated statistics for a given test.
"""
stats: Dict[str, Any] = {
"test_name": test,
"test_name": test.test_file,
"test_filters": test.get_pytest_filter(),
}
# Get baseline metrics assuming we didn't have any TD heuristics

tools/testing/test_run.py (new file, 305 lines)

@ -0,0 +1,305 @@
from copy import copy
from functools import total_ordering
from typing import Iterable, List, Optional, Set, Tuple, Union
class TestRun:
"""
TestRun defines the set of tests that should be run together in a single pytest invocation.
It'll either be a whole test file or a subset of a test file.
This class assumes that we won't always know the full set of TestClasses in a test file.
So it's designed to include or exclude explicitly requested TestClasses, while accepting
that there will be an ambiguous set of "unknown" test classes that are not explicitly called out.
Those manifest as tests that haven't been explicitly excluded.
"""
test_file: str
_excluded: Set[str] # Tests that should be excluded from this test run
_included: Set[str] # If non-empty, only these tests should be run in this test run
def __init__(
self,
name: str,
excluded: Optional[Iterable[str]] = None,
included: Optional[Iterable[str]] = None,
) -> None:
self._excluded = set()
self._included = set()
if excluded and included:
raise ValueError("Can't specify both included and excluded")
if "::" in name:
assert (
not included and not excluded
), "Can't specify included or excluded tests when specifying a test class in the file name"
self.test_file, test_class = name.split("::")
self._included.add(test_class)
else:
self.test_file = name
# For testing purposes
if excluded:
self._excluded = set(excluded)
if included:
self._included = set(included)
@staticmethod
def empty() -> "TestRun":
return TestRun("")
def is_empty(self) -> bool:
# Lack of a test_file means that this is an empty run,
# which means there is nothing to run. It's the zero.
return not self.test_file
def is_full_file(self) -> bool:
return not self._included and not self._excluded
def included(self) -> Set[str]:
return self._included.copy()
def excluded(self) -> Set[str]:
return self._excluded.copy()
def get_pytest_filter(self) -> str:
if self._included:
return " or ".join(sorted(self._included))
elif self._excluded:
return f"not ({' and '.join(sorted(self._excluded))})"
else:
return ""
def contains(self, test: "TestRun") -> bool:
if self.test_file != test.test_file:
return False
if self.is_full_file():
return True # self contains all tests
if test.is_full_file():
return False # test contains all tests, but self doesn't
# Does self exclude a subset of what test excludes?
if test._excluded:
return test._excluded.issubset(self._excluded)
# Does self include everything test includes?
if self._included:
return test._included.issubset(self._included)
# Getting to here means that test includes and self excludes
# Does self exclude anything test includes? If not, we're good
return not self._excluded.intersection(test._included)
def __copy__(self) -> "TestRun":
return TestRun(self.test_file, excluded=self._excluded, included=self._included)
def __bool__(self) -> bool:
return not self.is_empty()
def __repr__(self) -> str:
r: str = f"RunTest({self.test_file}"
r += f", included: {self._included}" if self._included else ""
r += f", excluded: {self._excluded}" if self._excluded else ""
r += ")"
return r
def __str__(self) -> str:
if self.is_empty():
return "Empty"
pytest_filter = self.get_pytest_filter()
if pytest_filter:
return self.test_file + ", " + pytest_filter
return self.test_file
def __eq__(self, other: object) -> bool:
if not isinstance(other, TestRun):
return False
ret = self.test_file == other.test_file
ret = ret and self._included == other._included
ret = ret and self._excluded == other._excluded
return ret
def __ior__( # noqa: PYI034 Method returns `self`
self, other: "TestRun"
) -> "TestRun":
res = self | other
self.test_file = res.test_file
self._included = res._included
self._excluded = res._excluded
return self
def __or__(self, other: "TestRun") -> "TestRun":
"""
To OR/Union test runs means to run all the tests that either of the two runs specify.
"""
# Is any file empty?
if self.is_empty():
return other
if other.is_empty():
return copy(self)
# If not, ensure we have the same file
assert (
self.test_file == other.test_file
), f"Can't exclude {other} from {self} because they're not the same test file"
# 4 possible cases:
# 1. Either file is the full file, so union is everything
if self.is_full_file() or other.is_full_file():
# The union is the whole file
return TestRun(self.test_file)
# 2. Both files only run what's in _included, so union is the union of the two sets
if self._included and other._included:
return TestRun(
self.test_file, included=self._included.union(other._included)
)
# 3. Both files only exclude what's in _excluded, so union is the intersection of the two sets
if self._excluded and other._excluded:
return TestRun(
self.test_file, excluded=self._excluded.intersection(other._excluded)
)
# 4. One file includes and the other excludes, so we then continue excluding the _excluded set minus
# whatever is in the _included set
included = self._included | other._included
excluded = self._excluded | other._excluded
return TestRun(self.test_file, excluded=excluded - included)
def __isub__( # noqa: PYI034 Method returns `self`
self, other: "TestRun"
) -> "TestRun":
res = self - other
self.test_file = res.test_file
self._included = res._included
self._excluded = res._excluded
return self
def __sub__(self, other: "TestRun") -> "TestRun":
"""
To subtract test runs means to run all the tests in the first run except for what the second run specifies.
"""
# Is any file empty?
if self.is_empty():
return TestRun.empty()
if other.is_empty():
return copy(self)
# Are you trying to subtract tests that don't even exist in this test run?
if self.test_file != other.test_file:
return copy(self)
# You're subtracting everything?
if other.is_full_file():
return TestRun.empty()
def return_inclusions_or_empty(inclusions: Set[str]) -> TestRun:
if inclusions:
return TestRun(self.test_file, included=inclusions)
return TestRun.empty()
if other._included:
if self._included:
return return_inclusions_or_empty(self._included - other._included)
else:
return TestRun(
self.test_file, excluded=self._excluded | other._included
)
else:
if self._included:
return return_inclusions_or_empty(self._included & other._excluded)
else:
return return_inclusions_or_empty(other._excluded - self._excluded)
def __and__(self, other: "TestRun") -> "TestRun":
if self.test_file != other.test_file:
return TestRun.empty()
return (self | other) - (self - other) - (other - self)
TestRuns = Tuple[TestRun, ...]
@total_ordering
class ShardedTest:
test: TestRun
shard: int
num_shards: int
time: Optional[float] # In seconds
def __init__(
self,
test: Union[TestRun, str],
shard: int,
num_shards: int,
time: Optional[float] = None,
) -> None:
if isinstance(test, str):
test = TestRun(test)
self.test = test
self.shard = shard
self.num_shards = num_shards
self.time = time
@property
def name(self) -> str:
return self.test.test_file
def __eq__(self, other: object) -> bool:
if not isinstance(other, ShardedTest):
return False
return (
self.test == other.test
and self.shard == other.shard
and self.num_shards == other.num_shards
and self.time == other.time
)
def __repr__(self) -> str:
ret = f"{self.test} {self.shard}/{self.num_shards}"
if self.time:
ret += f" ({self.time}s)"
return ret
def __lt__(self, other: object) -> bool:
if not isinstance(other, ShardedTest):
raise NotImplementedError
# This is how the list was implicitly sorted when it was a NamedTuple
if self.name != other.name:
return self.name < other.name
if self.shard != other.shard:
return self.shard < other.shard
if self.num_shards != other.num_shards:
return self.num_shards < other.num_shards
# None is the smallest value
if self.time is None:
return True
if other.time is None:
return False
return self.time < other.time
def __str__(self) -> str:
return f"{self.test} {self.shard}/{self.num_shards}"
def get_time(self) -> float:
return self.time or 0
def get_pytest_args(self) -> List[str]:
filter = self.test.get_pytest_filter()
if filter:
return ["-k", self.test.get_pytest_filter()]
return []
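
As a worked illustration of the TestRun algebra defined above (a sketch with made-up file and class names; the expected results follow from the operators and match the unit tests in tools/test/test_test_run.py):

from tools.testing.test_run import TestRun

whole_file = TestRun("foo")
just_bar = TestRun("foo::bar")

assert whole_file | just_bar == whole_file                        # union with the full file is the full file
assert whole_file - just_bar == TestRun("foo", excluded=["bar"])  # run everything except the bar class
assert TestRun("foo", included=["bar", "baz"]) & TestRun("foo", included=["bar", "car"]) == TestRun("foo::bar")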


@ -3,9 +3,10 @@ import os
import subprocess
from pathlib import Path
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple
from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple
from tools.stats.import_test_stats import get_disabled_tests, get_slow_tests
from tools.testing.test_run import ShardedTest, TestRun
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
@ -42,19 +43,6 @@ if IS_ROCM and not IS_MEM_LEAK_CHECK:
NUM_PROCS = 1
class ShardedTest(NamedTuple):
name: str
shard: int
num_shards: int
time: Optional[float] # In seconds
def __str__(self) -> str:
return f"{self.name} {self.shard}/{self.num_shards}"
def get_time(self) -> float:
return self.time or 0
class ShardJob:
def __init__(self) -> None:
self.serial: List[ShardedTest] = []
@ -73,11 +61,51 @@ class ShardJob:
def get_with_pytest_shard(
tests: List[str], test_file_times: Dict[str, float]
tests: Sequence[TestRun],
test_file_times: Dict[str, float],
test_class_times: Optional[Dict[str, Dict[str, float]]],
) -> List[ShardedTest]:
sharded_tests: List[ShardedTest] = []
def get_duration_for_classes(
test_file: str, test_classes: Set[str]
) -> Optional[float]:
duration: float = 0
if not test_class_times:
return None
for test_class in test_classes:
class_duration = test_class_times.get(test_file, {}).get(test_class, None)
if class_duration is None:
return None
if class_duration:
duration += class_duration
return duration
for test in tests:
duration = test_file_times.get(test, None)
file_duration = test_file_times.get(test.test_file, None)
included = test.included()
excluded = test.excluded()
included_classes_duration = get_duration_for_classes(test.test_file, included)
excluded_classes_duration = get_duration_for_classes(test.test_file, excluded)
if included:
# If we don't have the time for all included classes, our upper bound is the file duration
duration = (
included_classes_duration
if included_classes_duration is not None
else file_duration
)
elif excluded:
# If we don't have the time for all excluded classes, our upper bound is file duration
duration = (
file_duration - excluded_classes_duration
if excluded_classes_duration is not None and file_duration is not None
else file_duration
)
else:
duration = file_duration
if duration and duration > THRESHOLD:
num_shards = math.ceil(duration / THRESHOLD)
for i in range(num_shards):
@ -91,21 +119,27 @@ def get_with_pytest_shard(
def calculate_shards(
num_shards: int,
tests: List[str],
tests: Sequence[TestRun],
test_file_times: Dict[str, float],
test_class_times: Optional[Dict[str, Dict[str, float]]],
must_serial: Optional[Callable[[str], bool]] = None,
sort_by_time: bool = True,
) -> List[Tuple[float, List[ShardedTest]]]:
must_serial = must_serial or (lambda x: True)
known_tests = tests
unknown_tests = []
known_tests: Sequence[TestRun] = tests
unknown_tests: Sequence[TestRun] = []
if sort_by_time:
known_tests = [x for x in tests if x in test_file_times]
known_tests = [
x
for x in tests
if x.test_file in test_file_times
or (test_class_times and x.test_file in test_class_times)
]
unknown_tests = [x for x in tests if x not in known_tests]
known_tests = get_with_pytest_shard(known_tests, test_file_times)
known_tests = get_with_pytest_shard(known_tests, test_file_times, test_class_times)
if sort_by_time:
known_tests = sorted(known_tests, key=lambda j: j.get_time(), reverse=True)
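
For reference, a minimal sketch of the updated calculate_shards call shape (mirroring the calculate_shards unit tests above; the file names, class names, and times are made up):

from tools.testing.test_run import TestRun
from tools.testing.test_selections import calculate_shards

tests = [TestRun("test_foo"), TestRun("test_bar::ClassB")]
file_times = {"test_foo": 30.0, "test_bar": 12.0}
class_times = {"test_bar": {"ClassB": 4.0}}

# Returns one (estimated_total_time, [ShardedTest, ...]) tuple per shard
shards = calculate_shards(2, tests, file_times, class_times)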