Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Share code between the paths that handle test results in parallel vs. serial mode. The original version was inconsistent between the two: it executed `print_to_stderr(err_message)` for every test that ran in parallel, but for serial tests it only invoked `print_to_stderr(err_message)` if `continue_on_error` was also specified. By sharing code, this PR makes that behavior consistent between the two modes. Also adds some comments. Pull Request resolved: https://github.com/pytorch/pytorch/pull/99467. Approved by: https://github.com/huydhn, https://github.com/malfet
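As a rough illustration of the consolidation described above, here is a minimal sketch. The helper names and signatures (`handle_error_messages`, `handle_test_result`) are hypothetical and not the actual `run_test.py` code; only `print_to_stderr` is a name taken from the PR description. The point is simply that the serial path now prints the error output unconditionally, as the parallel path already did.

```python
import sys
from typing import Optional


def print_to_stderr(message: str) -> None:
    print(message, file=sys.stderr)


def handle_error_messages(err_message: Optional[str]) -> bool:
    """Print a failed test's error output; return True if there was any."""
    if err_message is None:
        return False
    print_to_stderr(err_message)  # now happens for both parallel and serial tests
    return True


def handle_test_result(
    test: str,
    return_code: int,
    err_message: Optional[str],
    continue_on_error: bool = False,
) -> None:
    # Hypothetical shared handler: print errors first, then decide whether to stop.
    handle_error_messages(err_message)
    if return_code != 0 and not continue_on_error:
        # The error text has already been printed either way; stop the run.
        raise RuntimeError(f"{test} failed with exit code {return_code}")
```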
175 lines
6.0 KiB
Python
import math
import os
import subprocess

from typing import Callable, Dict, List, NamedTuple, Optional, Tuple

from tools.stats.import_test_stats import get_disabled_tests, get_slow_tests

IS_MEM_LEAK_CHECK = os.getenv("PYTORCH_TEST_CUDA_MEM_LEAK_CHECK", "0") == "1"

NUM_PROCS = 1 if IS_MEM_LEAK_CHECK else 2
THRESHOLD = 60 * 10  # 10 minutes

# See Note [ROCm parallel CI testing]
# Special logic for ROCm GHA runners to query number of GPUs available.
# torch.version.hip was not available to check if this was a ROCm self-hosted runner.
# Must check for ROCm runner in another way. We look for /opt/rocm directory.
if os.path.exists("/opt/rocm") and not IS_MEM_LEAK_CHECK:
    try:
        # This is the same logic used in GHA health check, see .github/templates/common.yml.j2
        lines = (
            subprocess.check_output(["rocminfo"], encoding="ascii").strip().split("\n")
        )
        count = 0
        for line in lines:
            if " gfx" in line:
                count += 1
        assert count > 0  # there must be at least 1 GPU
        NUM_PROCS = count
    except subprocess.CalledProcessError as e:
        # The safe default for ROCm GHA runners is to run tests serially.
        NUM_PROCS = 1


class ShardedTest(NamedTuple):
    """One pytest shard of a test file: shard `shard` of `num_shards`, with an estimated duration in seconds (None if unknown)."""

    name: str
    shard: int
    num_shards: int
    time: Optional[float]

    def __str__(self) -> str:
        return f"{self.name} {self.shard}/{self.num_shards}"

    def get_time(self) -> float:
        return self.time or 0


class ShardJob:
    """Tests assigned to one CI shard, split into tests that must run serially and tests that can run in parallel."""

    def __init__(self) -> None:
        self.serial: List[ShardedTest] = []
        self.parallel: List[ShardedTest] = []

    def get_total_time(self) -> float:
        # Estimate wall-clock time: parallel tests are greedily packed onto
        # NUM_PROCS workers; serial tests run one after another on top of that.
        procs = [0.0 for _ in range(NUM_PROCS)]
        for test in self.parallel:
            min_index = procs.index(min(procs))
            procs[min_index] += test.get_time()
        time = max(procs) + sum(test.get_time() for test in self.serial)
        return time

    def convert_to_tuple(self) -> Tuple[float, List[ShardedTest]]:
        return (self.get_total_time(), self.serial + self.parallel)


def get_with_pytest_shard(
    tests: List[str], test_file_times: Dict[str, float]
) -> List[ShardedTest]:
    # Split any test file whose expected duration exceeds THRESHOLD into
    # multiple pytest shards, so each shard's estimated time stays below THRESHOLD.
    sharded_tests: List[ShardedTest] = []
    for test in tests:
        duration = test_file_times[test]
        if duration > THRESHOLD:
            num_shards = math.ceil(duration / THRESHOLD)
            for i in range(num_shards):
                sharded_tests.append(
                    ShardedTest(test, i + 1, num_shards, duration / num_shards)
                )
        else:
            sharded_tests.append(ShardedTest(test, 1, 1, duration))
    return sharded_tests


def calculate_shards(
    num_shards: int,
    tests: List[str],
    test_file_times: Dict[str, float],
    must_serial: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[float, List[ShardedTest]]]:
    must_serial = must_serial or (lambda x: True)

    known_tests = [x for x in tests if x in test_file_times]
    unknown_tests: List[str] = [x for x in tests if x not in known_tests]

    # Greedily assign tests with known durations, longest first, always onto
    # the shard with the smallest estimated total time so far.
    sorted_tests = sorted(
        get_with_pytest_shard(known_tests, test_file_times),
        key=lambda j: j.get_time(),
        reverse=True,
    )

    sharded_jobs: List[ShardJob] = [ShardJob() for _ in range(num_shards)]
    for test in sorted_tests:
        min_sharded_job = min(sharded_jobs, key=lambda j: j.get_total_time())
        if must_serial(test.name):
            min_sharded_job.serial.append(test)
        else:
            min_sharded_job.parallel.append(test)

    # Round robin the unknown jobs starting with the smallest shard
    index = min(range(num_shards), key=lambda i: sharded_jobs[i].get_total_time())
    for unknown_test in unknown_tests:
        sharded_jobs[index].serial.append(ShardedTest(unknown_test, 1, 1, None))
        index = (index + 1) % num_shards
    return [job.convert_to_tuple() for job in sharded_jobs]


def _query_changed_test_files() -> List[str]:
    default_branch = f"origin/{os.environ.get('GIT_DEFAULT_BRANCH', 'master')}"
    cmd = ["git", "diff", "--name-only", default_branch, "HEAD"]
    proc = subprocess.run(cmd, capture_output=True)

    if proc.returncode != 0:
        raise RuntimeError("Unable to get changed files")

    lines = proc.stdout.decode().strip().split("\n")
    lines = [line.strip() for line in lines]
    return lines


def get_reordered_tests(tests: List[str]) -> List[str]:
    """
    Get the reordered test filename list based on GitHub PR history or git changed files.
    We prioritize running test files that were changed.
    """
    prioritized_tests: List[str] = []
    if len(prioritized_tests) == 0:
        try:
            changed_files = _query_changed_test_files()
        except Exception:
            # If unable to get changed files from git, quit without doing any sorting
            return tests

        prefix = f"test{os.path.sep}"
        prioritized_tests = [
            f for f in changed_files if f.startswith(prefix) and f.endswith(".py")
        ]
        prioritized_tests = [f[len(prefix) :] for f in prioritized_tests]
        prioritized_tests = [f[: -len(".py")] for f in prioritized_tests]
        print("Prioritized tests from test file changes.")

    bring_to_front = []
    the_rest = []

    for test in tests:
        if test in prioritized_tests:
            bring_to_front.append(test)
        else:
            the_rest.append(test)

    if len(tests) == len(bring_to_front) + len(the_rest):
        print(
            f"reordering tests for PR:\n"
            f"prioritized: {bring_to_front}\nthe rest: {the_rest}\n"
        )
        return bring_to_front + the_rest
    else:
        print(
            f"Something went wrong in CI reordering, expecting total of {len(tests)}:\n"
            f"but found prioritized: {len(bring_to_front)}\nthe rest: {len(the_rest)}\n"
        )
        return tests


def get_test_case_configs(dirpath: str) -> None:
    get_slow_tests(dirpath=dirpath)
    get_disabled_tests(dirpath=dirpath)
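For context, a hedged usage sketch (not part of the file above): this is roughly how a caller such as `run_test.py` might invoke `calculate_shards` from this module. The test names, durations, and the `must_serial` predicate are made up for illustration.

```python
# Hypothetical caller of calculate_shards; names and durations are illustrative only.
test_file_times = {
    "test_torch": 2400.0,    # 40 min -> exceeds THRESHOLD, split into pytest shards
    "test_nn": 900.0,
    "test_autograd": 300.0,
}
tests = ["test_torch", "test_nn", "test_autograd", "test_brand_new"]  # last has no timing data

shards = calculate_shards(
    num_shards=2,
    tests=tests,
    test_file_times=test_file_times,
    must_serial=lambda name: name == "test_nn",  # example predicate
)
for estimated_time, sharded_tests in shards:
    print(f"{estimated_time:.0f}s: {[str(t) for t in sharded_tests]}")
```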