Reference: https://docs.astral.sh/ruff/formatter/black/#assert-statements

> Unlike Black, Ruff prefers breaking the message over breaking the assertion, similar to how both Ruff and Black prefer breaking the assignment value over breaking the assignment target:
>
> ```python
> # Input
> assert (
>     len(policy_types) >= priority + num_duplicates
> ), f"This tests needs at least {priority+num_duplicates} many types."
>
>
> # Black
> assert (
>     len(policy_types) >= priority + num_duplicates
> ), f"This tests needs at least {priority+num_duplicates} many types."
>
> # Ruff
> assert len(policy_types) >= priority + num_duplicates, (
>     f"This tests needs at least {priority + num_duplicates} many types."
> )
> ```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144546
Approved by: https://github.com/malfet
from __future__ import annotations

from abc import abstractmethod
from copy import copy
from typing import Any, TYPE_CHECKING

from tools.testing.test_run import TestRun


if TYPE_CHECKING:
    from collections.abc import Iterable, Iterator


class TestPrioritizations:
    """
    Describes the results of a heuristic's evaluation of which tests it considers relevant.

    All the different ranks of tests are disjoint, meaning a test can only be in one category, and they are only
    declared at initialization time.

    A list can be empty if a heuristic doesn't consider any tests to be in that category.

    Important: Lists of tests must always be returned in a deterministic order,
    otherwise it breaks the test sharding logic.
    """

    _original_tests: frozenset[str]
    _test_scores: dict[TestRun, float]

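    # Illustrative example (assuming TestRun's usual `file::TestClass` syntax):
    # constructing with tests ["test_a", "test_b"] and
    # scores {TestRun("test_a::TestFoo"): 1.5} leaves three disjoint runs in
    # _test_scores: TestFoo within test_a at 1.5, the rest of test_a at 0.0,
    # and all of test_b at 0.0.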
    def __init__(
        self,
        tests_being_ranked: Iterable[str],  # The tests that are being prioritized.
        scores: dict[TestRun, float],
    ) -> None:
        self._original_tests = frozenset(tests_being_ranked)
        self._test_scores = {TestRun(test): 0.0 for test in self._original_tests}

        for test, score in scores.items():
            self.set_test_score(test, score)

        self.validate()

    def validate(self) -> None:
        # Union all the TestRuns for each file and check that their include/exclude sets are disjoint
        all_tests = self._test_scores.keys()
        files = {}
        for test in all_tests:
            if test.test_file not in files:
                files[test.test_file] = copy(test)
            else:
                assert (files[test.test_file] & test).is_empty(), (
                    f"Test run `{test}` overlaps with `{files[test.test_file]}`"
                )
                files[test.test_file] |= test

        for test in files.values():
            assert test.is_full_file(), (
                f"All includes should have been excluded elsewhere, and vice versa. Test run `{test}` violates that"
            )  # noqa: B950

        # Ensure that the set of tests in the TestPrioritizations is identical to the set of tests passed in
        assert self._original_tests == set(files.keys()), (
            "The set of tests in the TestPrioritizations must be identical to the set of tests passed in"
        )

    def _traverse_scores(self) -> Iterator[tuple[float, TestRun]]:
        # Sort by descending score, then alphabetically by test name
        for test, score in sorted(
            self._test_scores.items(), key=lambda x: (-x[1], str(x[0]))
        ):
            yield score, test

    def set_test_score(self, test_run: TestRun, new_score: float) -> None:
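        """
        Sets the score for the exact portion of the file covered by `test_run`.

        Any existing run that overlaps it is split: the covered part takes
        `new_score`, and the uncovered remainder keeps its previous score.
        """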
        if test_run.test_file not in self._original_tests:
            return  # We don't need this test

        relevant_test_runs: list[TestRun] = [
            tr for tr in self._test_scores.keys() if tr & test_run and tr != test_run
        ]

        # Set the score of all the tests that are covered by test_run to the same score
        self._test_scores[test_run] = new_score
        # Set the score of all the tests that are not covered by test_run to their original score
        for relevant_test_run in relevant_test_runs:
            old_score = self._test_scores[relevant_test_run]
            del self._test_scores[relevant_test_run]

            not_to_be_updated = relevant_test_run - test_run
            if not not_to_be_updated.is_empty():
                self._test_scores[not_to_be_updated] = old_score
        self.validate()

    def add_test_score(self, test_run: TestRun, score_to_add: float) -> None:
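        """
        Adds `score_to_add` to every existing run that overlaps `test_run`,
        splitting runs where needed so the uncovered remainder keeps its old score.
        """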
        if test_run.test_file not in self._original_tests:
            return

        relevant_test_runs: list[TestRun] = [
            tr for tr in self._test_scores.keys() if tr & test_run
        ]

        for relevant_test_run in relevant_test_runs:
            old_score = self._test_scores[relevant_test_run]
            del self._test_scores[relevant_test_run]

            intersection = relevant_test_run & test_run
            if not intersection.is_empty():
                self._test_scores[intersection] = old_score + score_to_add

            not_to_be_updated = relevant_test_run - test_run
            if not not_to_be_updated.is_empty():
                self._test_scores[not_to_be_updated] = old_score

        self.validate()

    def get_all_tests(self) -> list[TestRun]:
        """Returns all tests in the TestPrioritizations"""
        return [x[1] for x in self._traverse_scores()]

    def get_top_per_tests(self, n: int) -> tuple[list[TestRun], list[TestRun]]:
        """Divides list of tests into two based on the top n% of scores. The
        first list is the top, and the second is the rest."""
        tests = [x[1] for x in self._traverse_scores()]
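        # Integer percentage cutoff; the trailing +1 guarantees at least one
        # test lands in the top bucket whenever any tests exist.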
        index = n * len(tests) // 100 + 1
        return tests[:index], tests[index:]

    def get_info_str(self, verbose: bool = True) -> str:
        info = ""

        for score, test in self._traverse_scores():
            if not verbose and score == 0:
                continue
            info += f"  {test} ({score})\n"

        return info.rstrip()

    def print_info(self) -> None:
        print(self.get_info_str())

    def get_priority_info_for_test(self, test_run: TestRun) -> dict[str, Any]:
        """Given a failing test, returns information about its prioritization that we want to emit in our metrics."""
        for idx, (score, test) in enumerate(self._traverse_scores()):
            # Different heuristics may result in a given test file being split
            # into different test runs, so look for the overlapping tests to
            # find the match
            if test & test_run:
                return {"position": idx, "score": score}
        raise AssertionError(f"Test run {test_run} not found")

    def get_test_stats(self, test: TestRun) -> dict[str, Any]:
        return {
            "test_name": test.test_file,
            "test_filters": test.get_pytest_filter(),
            **self.get_priority_info_for_test(test),
            "max_score": max(score for score, _ in self._traverse_scores()),
            "min_score": min(score for score, _ in self._traverse_scores()),
            "all_scores": {
                str(test): score for test, score in self._test_scores.items()
            },
        }

    def to_json(self) -> dict[str, Any]:
        """
        Returns a JSON dict that describes this TestPrioritizations object.
        """
        json_dict = {
            "_test_scores": [
                (test.to_json(), score)
                for test, score in self._test_scores.items()
                if score != 0
            ],
            "_original_tests": list(self._original_tests),
        }
        return json_dict

    @staticmethod
    def from_json(json_dict: dict[str, Any]) -> TestPrioritizations:
        """
        Returns a TestPrioritizations object from a JSON dict.
        """
        test_prioritizations = TestPrioritizations(
            tests_being_ranked=json_dict["_original_tests"],
            scores={
                TestRun.from_json(testrun_json): score
                for testrun_json, score in json_dict["_test_scores"]
            },
        )
        return test_prioritizations

    def amend_tests(self, tests: list[str]) -> None:
        """
        Removes tests that are not in the given list from the
        TestPrioritizations. Adds tests that are in the list but not in the
        TestPrioritizations.
        """
        valid_scores = {
            test: score
            for test, score in self._test_scores.items()
            if test.test_file in tests
        }
        self._test_scores = valid_scores

        for test in tests:
            if test not in self._original_tests:
                self._test_scores[TestRun(test)] = 0
        self._original_tests = frozenset(tests)

        self.validate()


class AggregatedHeuristics:
    """
    Aggregates the results across all heuristics.

    It saves the individual results from each heuristic and exposes an aggregated view.
    """

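    # Typical flow (illustrative sketch, not from the original file):
    #   aggregated = AggregatedHeuristics(all_tests)
    #   aggregated.add_heuristic_results(heuristic, results)  # once per heuristic
    #   ranked = aggregated.get_aggregated_priorities().get_all_tests()
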
    _heuristic_results: dict[
        HeuristicInterface, TestPrioritizations
    ]  # Keyed by the heuristic. Dicts preserve insertion order, which is important for sharding

    _all_tests: frozenset[str]

    def __init__(self, all_tests: list[str]) -> None:
        self._all_tests = frozenset(all_tests)
        self._heuristic_results = {}
        self.validate()

    def validate(self) -> None:
        for heuristic, heuristic_results in self._heuristic_results.items():
            heuristic_results.validate()
            assert heuristic_results._original_tests == self._all_tests, (
                f"Tests in {heuristic.name} are not the same as the tests in the AggregatedHeuristics"
            )

    def add_heuristic_results(
        self, heuristic: HeuristicInterface, heuristic_results: TestPrioritizations
    ) -> None:
        if heuristic in self._heuristic_results:
            raise ValueError(f"We already have heuristics for {heuristic.name}")

        self._heuristic_results[heuristic] = heuristic_results
        self.validate()

    def get_aggregated_priorities(
        self, include_trial: bool = False
    ) -> TestPrioritizations:
        """
        Returns the aggregated priorities across all heuristics.
        """
        valid_heuristics = {
            heuristic: heuristic_results
            for heuristic, heuristic_results in self._heuristic_results.items()
            if not heuristic.trial_mode or include_trial
        }

        new_tp = TestPrioritizations(self._all_tests, {})

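        # Sum each heuristic's score for every test run; add_test_score splits
        # overlapping runs so totals are tracked per covered portion.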
        for heuristic_results in valid_heuristics.values():
            for score, testrun in heuristic_results._traverse_scores():
                new_tp.add_test_score(testrun, score)
        new_tp.validate()
        return new_tp

    def get_test_stats(self, test: TestRun) -> dict[str, Any]:
        """
        Returns the aggregated statistics for a given test.
        """
        stats: dict[str, Any] = {
            "test_name": test.test_file,
            "test_filters": test.get_pytest_filter(),
        }

        # Get metrics about the heuristics used
        heuristics = []

        for heuristic, heuristic_results in self._heuristic_results.items():
            metrics = heuristic_results.get_priority_info_for_test(test)
            metrics["heuristic_name"] = heuristic.name
            metrics["trial_mode"] = heuristic.trial_mode
            heuristics.append(metrics)

        stats["heuristics"] = heuristics

        stats["aggregated"] = (
            self.get_aggregated_priorities().get_priority_info_for_test(test)
        )

        stats["aggregated_trial"] = self.get_aggregated_priorities(
            include_trial=True
        ).get_priority_info_for_test(test)

        return stats

    def to_json(self) -> dict[str, Any]:
        """
        Returns a JSON dict that describes this AggregatedHeuristics object.
        """
        json_dict: dict[str, Any] = {}
        for heuristic, heuristic_results in self._heuristic_results.items():
            json_dict[heuristic.name] = heuristic_results.to_json()

        return json_dict


class HeuristicInterface:
    """
    Interface for all heuristics.
    """

    description: str

    # When trial mode is set to True, this heuristic's predictions will not be used
    # to reorder tests. Its results will however be emitted in the metrics.
    trial_mode: bool

    @abstractmethod
    def __init__(self, **kwargs: Any) -> None:
        self.trial_mode = kwargs.get("trial_mode", False)  # type: ignore[assignment]

    @property
    def name(self) -> str:
        return self.__class__.__name__

    def __str__(self) -> str:
        return self.name

    @abstractmethod
    def get_prediction_confidence(self, tests: list[str]) -> TestPrioritizations:
        """
        Returns a TestPrioritizations in which each test's score ranges from -1
        to 1, where negative means skip, positive means run, 0 means no idea,
        and the magnitude reflects how confident the heuristic is. The results
        are consumed by AggregatedHeuristics.
        """
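

# Illustrative sketch (hypothetical, not part of the original module): a minimal
# heuristic showing the subclassing contract. It has no opinion on any test, so
# every score stays at the neutral default of 0.0.
class ExampleNoopHeuristic(HeuristicInterface):
    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)

    def get_prediction_confidence(self, tests: list[str]) -> TestPrioritizations:
        # An empty scores dict means every test keeps the default score of 0.0.
        return TestPrioritizations(tests, {})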