mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Since we are using Rockset for all this now, remove the code that used the S3 path. Pull Request resolved: https://github.com/pytorch/pytorch/pull/81163 Approved by: https://github.com/janeyx99
99 lines
3.5 KiB
Python
99 lines
3.5 KiB
Python
import os
|
|
import subprocess
|
|
|
|
from tools.stats.import_test_stats import get_disabled_tests, get_slow_tests
|
|
|
|
from typing import Dict, List, Tuple
|
|
|
|
|
|
def calculate_shards(
    num_shards: int, tests: List[str], job_times: Dict[str, float]
) -> List[Tuple[float, List[str]]]:
    """Split ``tests`` into ``num_shards`` shards with roughly balanced run time.

    Tests that have a recorded duration in ``job_times`` are assigned greedily,
    longest first, to whichever shard currently has the smallest accumulated
    time (see https://en.wikipedia.org/wiki/Greedy_number_partitioning).
    Tests without a recorded duration are then distributed round-robin,
    starting from the shard with the smallest accumulated time.

    Args:
        num_shards: number of shards to produce; must be at least 1.
        tests: names of the tests to shard.
        job_times: recorded durations keyed by test name; entries for names
            not present in ``tests`` are ignored.

    Returns:
        A list of ``num_shards`` ``(total_time, test_names)`` tuples, where
        ``total_time`` only sums the durations of tests found in ``job_times``.

    Raises:
        ValueError: if ``num_shards`` is less than 1.
    """
    if num_shards < 1:
        raise ValueError(f"num_shards must be at least 1, got {num_shards}")

    filtered_job_times: Dict[str, float] = {}
    unknown_jobs: List[str] = []
    for test in tests:
        if test in job_times:
            filtered_job_times[test] = job_times[test]
        else:
            unknown_jobs.append(test)

    # Greedy number partitioning: place the longest remaining job on the
    # currently least-loaded shard.
    sorted_jobs = sorted(
        filtered_job_times, key=lambda j: filtered_job_times[j], reverse=True
    )
    sharded_jobs: List[Tuple[float, List[str]]] = [
        (0.0, []) for _ in range(num_shards)
    ]
    for job in sorted_jobs:
        # min() returns the lowest index among equally loaded shards, in
        # O(num_shards) rather than sorting every time.
        min_shard_index = min(range(num_shards), key=lambda i: sharded_jobs[i][0])
        curr_shard_time, curr_shard_jobs = sharded_jobs[min_shard_index]
        curr_shard_jobs.append(job)
        sharded_jobs[min_shard_index] = (
            curr_shard_time + filtered_job_times[job],
            curr_shard_jobs,
        )

    # Round robin the unknown jobs starting with the smallest shard. Their
    # durations are unknown, so shard totals are left unchanged.
    index = min(range(num_shards), key=lambda i: sharded_jobs[i][0])
    for job in unknown_jobs:
        sharded_jobs[index][1].append(job)
        index = (index + 1) % num_shards
    return sharded_jobs
|
|
|
|
|
|
def _query_changed_test_files() -> List[str]:
    """Return the paths of files that differ between the default branch and HEAD.

    Runs ``git diff --name-only origin/<branch> HEAD``, where ``<branch>`` is
    taken from the ``GIT_DEFAULT_BRANCH`` environment variable and defaults to
    ``master``. Despite the name, every changed file is returned; callers are
    responsible for filtering out non-test files. Returns an empty list when
    nothing changed.

    Raises:
        RuntimeError: if ``git diff`` exits with a non-zero status; git's
            stderr is included in the message.
    """
    default_branch = f"origin/{os.environ.get('GIT_DEFAULT_BRANCH', 'master')}"
    cmd = ["git", "diff", "--name-only", default_branch, "HEAD"]
    proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

    if proc.returncode != 0:
        details = proc.stderr.decode(errors="replace").strip()
        raise RuntimeError(f"Unable to get changed files: {details}")

    # splitlines() plus dropping blanks avoids returning [""] for an empty diff.
    stripped = [line.strip() for line in proc.stdout.decode().splitlines()]
    return [line for line in stripped if line]
|
|
|
|
|
|
def get_reordered_tests(tests: List[str]) -> List[str]:
    """Move tests whose source files changed in this branch to the front.

    Changed files come from ``_query_changed_test_files``. A changed file
    ``test/<name>.py`` prioritizes the test named ``<name>``. Relative order
    is preserved both within the prioritized group and within the rest.

    If the changed files cannot be determined (any exception from the git
    query), ``tests`` is returned unchanged.
    """
    try:
        changed_files = _query_changed_test_files()
    except Exception as e:
        # Best effort: without git information, keep the original order.
        print(f"Unable to get changed files, not reordering tests: {e}")
        return tests

    # git prints paths with "/" on every platform, so the prefix must not be
    # built from os.path.sep (which is "\\" on Windows and never matches).
    prefix = "test/"
    suffix = ".py"
    prioritized_tests = {
        f[len(prefix) : -len(suffix)]
        for f in changed_files
        if f.startswith(prefix) and f.endswith(suffix)
    }
    print("Prioritized test from test file changes.")

    bring_to_front = [test for test in tests if test in prioritized_tests]
    the_rest = [test for test in tests if test not in prioritized_tests]
    print(
        f"reordering tests for PR:\n"
        f"prioritized: {bring_to_front}\nthe rest: {the_rest}\n"
    )
    return bring_to_front + the_rest
|
|
|
|
|
|
def get_test_case_configs(dirpath: str) -> None:
    """Fetch the slow-test and disabled-test lists for ``dirpath``.

    Thin wrapper over ``get_slow_tests`` and ``get_disabled_tests`` from
    ``tools.stats.import_test_stats``. It calls them in that order, each with
    ``dirpath=dirpath``. What each helper actually writes is defined in that
    module; presumably files under ``dirpath`` — confirm there.
    """
    for fetch in (get_slow_tests, get_disabled_tests):
        fetch(dirpath=dirpath)
|