Revert "Retry at test file level (#97506)"
This reverts commit 7d5d5beba2.
Reverted https://github.com/pytorch/pytorch/pull/97506 on behalf of https://github.com/clee2000 due to test_jit_cuda_fuser having a rough time
This commit is contained in:
parent
3a5ca4bdd4
commit
675dfd2c1f
271
test/run_test.py
271
test/run_test.py
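For context, the reverted PR made run_test.py re-run an entire failing test file once (with a timeout) through a retry_shell helper before reporting failure; the diff below removes that helper and its call site. A condensed, hedged sketch of that control flow follows (not the exact removed code; shell here is a stand-in that just runs a command and returns its exit code):

import subprocess


def shell(command, cwd=None, timeout=None):
    # Stand-in runner: execute the test command and return its exit code.
    return subprocess.run(command, cwd=cwd, timeout=timeout).returncode


def retry_shell(command, cwd=None, timeout=None, retries=1):
    # Re-run the whole test file once more if it fails or times out.
    try:
        exit_code = shell(command, cwd=cwd, timeout=timeout)
        if exit_code == 0 or retries == 0:
            return exit_code
    except subprocess.TimeoutExpired:
        if retries == 0:
            return 124  # conventional "timed out" exit code
    return retry_shell(command, cwd=cwd, timeout=timeout, retries=retries - 1)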
test/run_test.py

@@ -24,7 +24,6 @@ from torch.testing._internal.common_utils import (
IS_CI,
is_slow_gradcheck_env,
parser as common_parser,
retry_shell,
set_cwd,
shell,
TEST_WITH_ROCM,
@@ -42,8 +41,6 @@ try:
get_reordered_tests,
get_test_case_configs,
NUM_PROCS,
ShardedTest,
THRESHOLD,
)

HAVE_TEST_SELECTION_TOOLS = True
@@ -54,9 +51,6 @@ except ImportError:
)


RERUN_DISABLED_TESTS = os.getenv("PYTORCH_TEST_RERUN_DISABLED_TESTS", "0") == "1"


# Note [ROCm parallel CI testing]
# https://github.com/pytorch/pytorch/pull/85770 added file-granularity parallel testing.
# In .ci/pytorch/test.sh, TEST_CONFIG == "default", CUDA and HIP_VISIBLE_DEVICES is set to 0.
@@ -283,13 +277,6 @@ CI_SERIAL_LIST = [
"_nvfuser/test_torchscript", # OOM on test_issue_1785
"test_schema_check", # Cause CUDA illegal memory access https://github.com/pytorch/pytorch/issues/95749
"functorch/test_memory_efficient_fusion", # Cause CUDA OOM on ROCm
"test_utils", # OOM
"test_sort_and_select", # OOM
"test_backward_compatible_arguments", # OOM
"test_module_init", # OOM
"test_autocast", # OOM
"test_native_mha", # OOM
"test_module_hooks", # OOM
]

# A subset of our TEST list that validates PyTorch's ops, modules, and autograd function as expected
@@ -304,6 +291,19 @@ CORE_TEST_LIST = [
"test_torch",
]

# A list of distributed tests that run on multiple backends, i.e. gloo, nccl. These backends are spread out
# among all available shards to speed up the test. The list of backends are copied from the tests themselves
DISTRIBUTED_TESTS_WITH_MULTIPLE_BACKENDS = {
"distributed/test_distributed_spawn": [
"gloo",
"nccl",
"ucc",
],
"distributed/algorithms/quantization/test_quantization": [
"gloo",
"nccl",
],
}

# if a test file takes longer than 5 min, we add it to TARGET_DET_LIST
SLOW_TEST_THRESHOLD = 300
@@ -406,19 +406,9 @@ def run_test(
) -> int:
maybe_set_hip_visible_devies()
unittest_args = options.additional_unittest_args.copy()
test_file = test_module
if isinstance(test_file, ShardedTest):
unittest_args.extend(
[
f"--shard-id={test_module.shard - 1}",
f"--num-shards={test_module.num_shards}",
]
)
test_file = test_module.name

if options.verbose:
unittest_args.append(f'-{"v"*options.verbose}') # in case of pytest
if test_file in RUN_PARALLEL_BLOCKLIST:
if test_module in RUN_PARALLEL_BLOCKLIST:
unittest_args = [
arg for arg in unittest_args if not arg.startswith("--run-parallel")
]
@@ -432,11 +422,11 @@ def run_test(
unittest_args = [arg if arg != "-f" else "-x" for arg in unittest_args]
if IS_CI:
ci_args = ["--import-slow-tests", "--import-disabled-tests"]
if RERUN_DISABLED_TESTS:
if os.getenv("PYTORCH_TEST_RERUN_DISABLED_TESTS", "0") == "1":
ci_args.append("--rerun-disabled-tests")
# use the downloaded test cases configuration, not supported in pytest
unittest_args.extend(ci_args)
if test_file in PYTEST_SKIP_RETRIES:
if test_module in PYTEST_SKIP_RETRIES:
if not options.pytest:
raise RuntimeError(
"A test running without pytest cannot skip retries using "
@@ -449,35 +439,19 @@ def run_test(

# Can't call `python -m unittest test_*` here because it doesn't run code
# in `if __name__ == '__main__': `. So call `python test_*.py` instead.
argv = [test_file + ".py"] + unittest_args
argv = [test_module + ".py"] + unittest_args

os.makedirs(REPO_ROOT / "test" / "test-reports", exist_ok=True)
log_fd, log_path = tempfile.mkstemp(
dir=REPO_ROOT / "test" / "test-reports",
prefix="{}_".format(test_file.replace("\\", "-").replace("/", "-")),
prefix="{}_".format(test_module.replace("\\", "-").replace("/", "-")),
suffix=".log",
)
os.close(log_fd)
command = (launcher_cmd or []) + executable + argv
should_file_rerun = (
"--subprocess" not in command
and not RERUN_DISABLED_TESTS
and isinstance(test_module, ShardedTest)
and test_module.time is not None
)
timeout = THRESHOLD * 2 if should_file_rerun else None
print_to_stderr("Executing {} ... [{}]".format(command, datetime.now()))
with open(log_path, "w") as f:
ret_code = retry_shell(
command,
test_directory,
stdout=f,
stderr=f,
env=env,
timeout=timeout,
retries=1 if should_file_rerun else 0,
)

ret_code = shell(command, test_directory, stdout=f, stderr=f, env=env)
print_log_file(test_module, log_path, failed=(ret_code != 0))
os.remove(log_path)
return ret_code
@@ -575,12 +549,33 @@ def test_distributed(test_module, test_directory, options):
if options.verbose and not mpi_available:
print_to_stderr("MPI not available -- MPI backend tests will be skipped")

if options.shard:
which_shard, num_shards = options.shard
else:
which_shard = num_shards = 1
# Round-robin all backends to different shards
backend_to_shard = {
backend: i % num_shards + 1
for i, backend in enumerate(
DISTRIBUTED_TESTS_WITH_MULTIPLE_BACKENDS[test_module]
)
}
print_to_stderr(
f"Map different backends to different shards for {test_module}: {backend_to_shard}"
)

config = DISTRIBUTED_TESTS_CONFIG
for backend, env_vars in config.items():
if sys.platform == "win32" and backend != "gloo":
continue
if backend == "mpi" and not mpi_available:
continue
# Default to the first shard if seeing an unrecognized backend
if which_shard != backend_to_shard.get(backend, 1):
print_to_stderr(
f"Shard {which_shard}: {backend} should be run in {backend_to_shard.get(backend, 1)}"
)
continue
for with_init_file in {True, False}:
if sys.platform == "win32" and not with_init_file:
continue
@@ -589,8 +584,8 @@ def test_distributed(test_module, test_directory, options):
init_str = "with {} init_method"
with_init = init_str.format("file" if with_init_file else "env")
print_to_stderr(
"Running distributed tests for the {} backend {}".format(
backend, with_init
"Running distributed tests for the {} backend {} in shard {} of {}".format(
backend, with_init, which_shard, num_shards
)
)
old_environ = dict(os.environ)
@@ -599,7 +594,7 @@ def test_distributed(test_module, test_directory, options):
os.environ["INIT_METHOD"] = "env://"
os.environ.update(env_vars)
if with_init_file:
if test_module.name == "test_distributed_spawn":
if test_module == "test_distributed_spawn":
init_method = f"{FILE_SCHEMA}{tmp_dir}/"
else:
init_method = f"{FILE_SCHEMA}{tmp_dir}/shared_init_file"
@@ -783,7 +778,6 @@ def run_doctests(test_module, test_directory, options):

def print_log_file(test: str, file_path: str, failed: bool) -> None:
num_lines = sum(1 for _ in open(file_path, "rb"))
test = str(test)
n = 100
with open(file_path, "r") as f:
print_to_stderr("")
@@ -811,7 +805,7 @@ def print_log_file(test: str, file_path: str, failed: bool) -> None:


def get_pytest_args(options):
if RERUN_DISABLED_TESTS:
if os.getenv("PYTORCH_TEST_RERUN_DISABLED_TESTS", "0") == "1":
# When under rerun-disabled-tests mode, run the same tests multiple times to determine their
# flakiness status. Default to 50 re-runs
rerun_options = ["--flake-finder", "--flake-runs=50"]
@@ -821,7 +815,7 @@ def get_pytest_args(options):
else:
# When under the normal mode, retry a failed test 2 more times. -x means stop at the first
# failure
rerun_options = ["-x", "--reruns=2", "--sw"]
rerun_options = ["-x", "--reruns=2"]

pytest_args = [
"--use-pytest",
@@ -834,6 +828,55 @@ def get_pytest_args(options):
return pytest_args


def run_test_ops(test_module, test_directory, options):
default_unittest_args = get_pytest_args(options)

return_codes = []
os.environ["NUM_PARALLEL_PROCS"] = str(NUM_PROCS)
pool = get_context("spawn").Pool(NUM_PROCS)
for i in range(NUM_PROCS):
extra_unittest_args = default_unittest_args.copy()
extra_unittest_args.extend(
[
f"--shard-id={i}",
f"--num-shards={NUM_PROCS}",
"-k=not _linalg_cholesky_",
]
)

return_code = pool.apply_async(
run_test,
args=(test_module, test_directory, copy.deepcopy(options)),
kwds={
"extra_unittest_args": extra_unittest_args,
},
)
return_codes.append(return_code)

pool.close()
pool.join()
del os.environ["NUM_PARALLEL_PROCS"]

for return_code in return_codes:
if return_code.get() != 0:
return return_code.get()

extra_unittest_args = default_unittest_args.copy()
extra_unittest_args.extend(
[
"-k=_linalg_cholesky_",
]
)

return_code = run_test(
test_module,
test_directory,
copy.deepcopy(options),
extra_unittest_args=extra_unittest_args,
)
return return_code


CUSTOM_HANDLERS = {
"test_cuda_primary_ctx": test_cuda_primary_ctx,
"test_cuda_nvml_based_avail": get_run_test_with_subprocess_fn(),
@@ -855,6 +898,15 @@ CUSTOM_HANDLERS = {
"distributed/rpc/test_share_memory": get_run_test_with_subprocess_fn(),
"distributed/rpc/cuda/test_tensorpipe_agent": get_run_test_with_subprocess_fn(),
"doctests": run_doctests,
"inductor/test_torchinductor_opinfo": run_test_ops,
"test_ops": run_test_ops,
"test_ops_gradients": run_test_ops,
"test_ops_fwd_gradients": run_test_ops,
"test_ops_jit": run_test_ops,
"functorch/test_ops": run_test_ops,
# not a test_ops file, but takes 2 hrs on some architectures and
# run_test_ops is good at parallelizing things
"test_decomp": run_test_ops,
}


@@ -1127,7 +1179,13 @@ def get_selected_tests(options):

if options.distributed_tests:
selected_tests = list(
filter(lambda test_name: test_name in DISTRIBUTED_TESTS, selected_tests)
filter(
lambda test_name: (
test_name in DISTRIBUTED_TESTS
and test_name not in DISTRIBUTED_TESTS_WITH_MULTIPLE_BACKENDS
),
selected_tests,
)
)

# Filter to only run core tests when --core option is specified
@@ -1188,6 +1246,46 @@ def get_selected_tests(options):
elif TEST_WITH_ROCM:
selected_tests = exclude_tests(ROCM_BLOCKLIST, selected_tests, "on ROCm")

# sharding
if options.shard:
assert len(options.shard) == 2, "Unexpected shard format"
assert min(options.shard) > 0, "Shards must be positive numbers"
which_shard, num_shards = options.shard
assert (
which_shard <= num_shards
), "Selected shard must be less than or equal to total number of shards"
assert num_shards <= len(
selected_tests
), f"Number of shards must be less than {len(selected_tests)}"

if num_shards == 1:
return selected_tests

# Download previous test times to make sharding decisions
path = os.path.join(str(REPO_ROOT), TEST_TIMES_FILE)
if os.path.exists(path):
with open(path, "r") as f:
test_file_times = cast(Dict[str, Any], json.load(f))
else:
test_file_times = {}
test_config = os.environ.get("TEST_CONFIG")
if test_config not in test_file_times:
print(
"::warning:: Gathered no stats from artifacts. Proceeding with default sharding plan."
)
selected_tests = selected_tests[which_shard - 1 :: num_shards]
else:
print("Found test time stats from artifacts")
test_file_times_config = test_file_times[test_config]
shards = calculate_shards(
num_shards,
selected_tests,
test_file_times_config,
must_serial=must_serial,
)
_, tests_from_shard = shards[which_shard - 1]
selected_tests = tests_from_shard

# skip all distributed tests if distributed package is not available.
if not dist.is_available():
selected_tests = exclude_tests(
@@ -1213,61 +1311,30 @@ def get_selected_tests(options):
exact_match=True,
)

selected_tests = [parse_test_module(x) for x in selected_tests]

# sharding
which_shard, num_shards = 1, 1
if options.shard:
assert len(options.shard) == 2, "Unexpected shard format"
assert min(options.shard) > 0, "Shards must be positive numbers"
which_shard, num_shards = options.shard
assert (
which_shard <= num_shards
), "Selected shard must be less than or equal to total number of shards"
assert num_shards <= len(
selected_tests
), f"Number of shards must be less than {len(selected_tests)}"

# Download previous test times to make sharding decisions
path = os.path.join(str(REPO_ROOT), TEST_TIMES_FILE)
if os.path.exists(path):
with open(path, "r") as f:
test_file_times = cast(Dict[str, Any], json.load(f))
else:
test_file_times = {}
test_config = os.environ.get("TEST_CONFIG")
if test_config not in test_file_times:
print(
"::warning:: Gathered no stats from artifacts. Proceeding with default sharding plan."
)
else:
print("Found test time stats from artifacts")

# Do sharding
test_file_times_config = test_file_times.get(test_config, {})
shards = calculate_shards(
num_shards, selected_tests, test_file_times_config, must_serial=must_serial
)
_, tests_from_shard = shards[which_shard - 1]
selected_tests = tests_from_shard
if options.distributed_tests:
# Run distributed tests with multiple backends across all shards, one per backend
selected_tests.extend(DISTRIBUTED_TESTS_WITH_MULTIPLE_BACKENDS.keys())
selected_tests.reverse()

return selected_tests


def run_test_module(test: ShardedTest, test_directory: str, options) -> Optional[str]:
def run_test_module(test: str, test_directory: str, options) -> Optional[str]:
maybe_set_hip_visible_devies()

test_module = parse_test_module(test)

# Printing the date here can help diagnose which tests are slow
print_to_stderr("Running {} ... [{}]".format(str(test), datetime.now()))
handler = CUSTOM_HANDLERS.get(test.name, run_test)
return_code = handler(test, test_directory, options)
print_to_stderr("Running {} ... [{}]".format(test, datetime.now()))
handler = CUSTOM_HANDLERS.get(test_module, run_test)
return_code = handler(test_module, test_directory, options)
assert isinstance(return_code, int) and not isinstance(
return_code, bool
), f"While running {str(test)} got non integer return code {return_code}"
), f"While running {test} got non integer return code {return_code}"
if return_code == 0:
return None

message = f"{str(test)} failed!"
message = f"{test} failed!"
if return_code < 0:
# subprocess.Popen returns the child process' exit signal as
# return code -N, where N is the signal number.
@@ -1283,9 +1350,7 @@ def main():
selected_tests = get_selected_tests(options)

if options.verbose:
print_to_stderr(
"Selected tests:\n {}".format("\n ".join(str(x) for x in selected_tests))
)
print_to_stderr("Selected tests:\n {}".format("\n ".join(selected_tests)))

if options.dry_run:
return
@@ -1307,18 +1372,18 @@ def main():

# parallel = in parallel with other files
# serial = this file on it's own. The file might still be run in parallel with itself (ex test_ops)
selected_tests_parallel = [x for x in selected_tests if not must_serial(x.name)]
selected_tests_parallel = [x for x in selected_tests if not must_serial(x)]
selected_tests_serial = [
x for x in selected_tests if x not in selected_tests_parallel
]
print_to_stderr(
"parallel (file granularity) tests:\n {}".format(
"\n ".join(str(x) for x in selected_tests_parallel)
"\n ".join(selected_tests_parallel)
)
)
print_to_stderr(
"serial (file granularity) tests:\n {}".format(
"\n ".join(str(x) for x in selected_tests_serial)
"\n ".join(selected_tests_serial)
)
)

@@ -1338,7 +1403,7 @@ def main():
return False

try:
os.environ["NUM_PARALLEL_PROCS"] = str(NUM_PROCS)
os.environ["PARALLEL_TESTING"] = "1"
for test in selected_tests_parallel:
options_clone = copy.deepcopy(options)
if can_run_in_pytest(test):
@@ -1350,7 +1415,7 @@ def main():
)
pool.close()
pool.join()
del os.environ["NUM_PARALLEL_PROCS"]
del os.environ["PARALLEL_TESTING"]

if not options.continue_through_error and len(failure_messages) != 0:
raise RuntimeError(
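The hunks above also restore a DISTRIBUTED_TESTS_WITH_MULTIPLE_BACKENDS table and round-robin each backend of a multi-backend distributed test onto a shard (shards are numbered from 1). A small self-contained illustration of that mapping, using the table from the diff:

# Backend i of a test lands on shard (i % num_shards) + 1, mirroring test_distributed().
DISTRIBUTED_TESTS_WITH_MULTIPLE_BACKENDS = {
    "distributed/test_distributed_spawn": ["gloo", "nccl", "ucc"],
    "distributed/algorithms/quantization/test_quantization": ["gloo", "nccl"],
}


def map_backends_to_shards(test_module: str, num_shards: int) -> dict:
    return {
        backend: i % num_shards + 1
        for i, backend in enumerate(DISTRIBUTED_TESTS_WITH_MULTIPLE_BACKENDS[test_module])
    }


print(map_backends_to_shards("distributed/test_distributed_spawn", 2))
# -> {'gloo': 1, 'nccl': 2, 'ucc': 1}: gloo and ucc run in shard 1, nccl in shard 2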
tools/test/test_test_selections.py (filename inferred from the imports and test class in this hunk)
@@ -1,18 +1,8 @@
import pathlib
import random
import sys
import unittest
from collections import defaultdict
from typing import Dict, List, Tuple

REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent
try:
# using tools/ to optimize test run.
sys.path.append(str(REPO_ROOT))
from tools.testing.test_selections import calculate_shards, ShardedTest, THRESHOLD
except ModuleNotFoundError:
print("Can't import required modules, exiting")
exit(1)
from tools.testing.test_selections import calculate_shards


class TestCalculateShards(unittest.TestCase):
@@ -46,8 +36,8 @@ class TestCalculateShards(unittest.TestCase):

def assert_shards_equal(
self,
expected_shards: List[Tuple[float, List[ShardedTest]]],
actual_shards: List[Tuple[float, List[ShardedTest]]],
expected_shards: List[Tuple[float, List[str]]],
actual_shards: List[Tuple[float, List[str]]],
) -> None:
for expected, actual in zip(expected_shards, actual_shards):
self.assertAlmostEqual(expected[0], actual[0])
@@ -55,25 +45,19 @@ class TestCalculateShards(unittest.TestCase):

def test_calculate_2_shards_with_complete_test_times(self) -> None:
expected_shards = [
(
60.0,
[
ShardedTest(name="super_long_test", shard=1, num_shards=1, time=55),
ShardedTest(name="normal_test3", shard=1, num_shards=1, time=5),
],
),
(60, ["super_long_test", "normal_test3"]),
(
58.31,
[
ShardedTest(name="long_test1", shard=1, num_shards=1, time=22),
ShardedTest(name="long_test2", shard=1, num_shards=1, time=18),
ShardedTest(name="normal_test1", shard=1, num_shards=1, time=9),
ShardedTest(name="normal_test2", shard=1, num_shards=1, time=7),
ShardedTest(name="short_test1", shard=1, num_shards=1, time=1),
ShardedTest(name="short_test2", shard=1, num_shards=1, time=0.6),
ShardedTest(name="short_test3", shard=1, num_shards=1, time=0.4),
ShardedTest(name="short_test4", shard=1, num_shards=1, time=0.3),
ShardedTest(name="short_test5", shard=1, num_shards=1, time=0.01),
"long_test1",
"long_test2",
"normal_test1",
"normal_test2",
"short_test1",
"short_test2",
"short_test3",
"short_test4",
"short_test5",
],
),
]
@@ -86,19 +70,19 @@ class TestCalculateShards(unittest.TestCase):
(
118.31,
[
ShardedTest(name="super_long_test", shard=1, num_shards=1, time=55),
ShardedTest(name="long_test1", shard=1, num_shards=1, time=22),
ShardedTest(name="long_test2", shard=1, num_shards=1, time=18),
ShardedTest(name="normal_test1", shard=1, num_shards=1, time=9),
ShardedTest(name="normal_test2", shard=1, num_shards=1, time=7),
ShardedTest(name="normal_test3", shard=1, num_shards=1, time=5),
ShardedTest(name="short_test1", shard=1, num_shards=1, time=1),
ShardedTest(name="short_test2", shard=1, num_shards=1, time=0.6),
ShardedTest(name="short_test3", shard=1, num_shards=1, time=0.4),
ShardedTest(name="short_test4", shard=1, num_shards=1, time=0.3),
ShardedTest(name="short_test5", shard=1, num_shards=1, time=0.01),
"super_long_test",
"long_test1",
"long_test2",
"normal_test1",
"normal_test2",
"normal_test3",
"short_test1",
"short_test2",
"short_test3",
"short_test4",
"short_test5",
],
)
),
]
self.assert_shards_equal(
expected_shards, calculate_shards(1, self.tests, self.test_times)
@@ -106,30 +90,31 @@ class TestCalculateShards(unittest.TestCase):

def test_calculate_5_shards_with_complete_test_times(self) -> None:
expected_shards = [
(55.0, ["super_long_test"]),
(
55.0,
[ShardedTest(name="super_long_test", shard=1, num_shards=1, time=55)],
22.0,
[
"long_test1",
],
),
(
18.0,
[
"long_test2",
],
),
(22.0, [ShardedTest(name="long_test1", shard=1, num_shards=1, time=22)]),
(18.0, [ShardedTest(name="long_test2", shard=1, num_shards=1, time=18)]),
(
11.31,
[
ShardedTest(name="normal_test1", shard=1, num_shards=1, time=9),
ShardedTest(name="short_test1", shard=1, num_shards=1, time=1),
ShardedTest(name="short_test2", shard=1, num_shards=1, time=0.6),
ShardedTest(name="short_test3", shard=1, num_shards=1, time=0.4),
ShardedTest(name="short_test4", shard=1, num_shards=1, time=0.3),
ShardedTest(name="short_test5", shard=1, num_shards=1, time=0.01),
],
),
(
12.0,
[
ShardedTest(name="normal_test2", shard=1, num_shards=1, time=7),
ShardedTest(name="normal_test3", shard=1, num_shards=1, time=5),
"normal_test1",
"short_test1",
"short_test2",
"short_test3",
"short_test4",
"short_test5",
],
),
(12.0, ["normal_test2", "normal_test3"]),
]
self.assert_shards_equal(
expected_shards, calculate_shards(5, self.tests, self.test_times)
@@ -143,24 +128,22 @@ class TestCalculateShards(unittest.TestCase):
(
22.0,
[
ShardedTest(name="long_test1", shard=1, num_shards=1, time=22),
ShardedTest(name="long_test2", shard=1, num_shards=1, time=None),
ShardedTest(name="normal_test3", shard=1, num_shards=1, time=None),
ShardedTest(name="short_test3", shard=1, num_shards=1, time=None),
ShardedTest(name="short_test5", shard=1, num_shards=1, time=None),
"long_test1",
"long_test2",
"normal_test3",
"short_test3",
"short_test5",
],
),
(
10.0,
[
ShardedTest(name="normal_test1", shard=1, num_shards=1, time=9),
ShardedTest(name="short_test1", shard=1, num_shards=1, time=1),
ShardedTest(
name="super_long_test", shard=1, num_shards=1, time=None
),
ShardedTest(name="normal_test2", shard=1, num_shards=1, time=None),
ShardedTest(name="short_test2", shard=1, num_shards=1, time=None),
ShardedTest(name="short_test4", shard=1, num_shards=1, time=None),
"normal_test1",
"short_test1",
"super_long_test",
"normal_test2",
"short_test2",
"short_test4",
],
),
]
@@ -173,133 +156,19 @@ class TestCalculateShards(unittest.TestCase):
k: v for k, v in self.test_times.items() if "test1" in k
}
expected_shards = [
(
22.0,
[
ShardedTest(name="long_test1", shard=1, num_shards=1, time=22),
ShardedTest(name="normal_test2", shard=1, num_shards=1, time=None),
ShardedTest(name="short_test5", shard=1, num_shards=1, time=None),
],
),
(
9.0,
[
ShardedTest(name="normal_test1", shard=1, num_shards=1, time=9),
ShardedTest(name="normal_test3", shard=1, num_shards=1, time=None),
],
),
(
1.0,
[
ShardedTest(name="short_test1", shard=1, num_shards=1, time=1),
ShardedTest(name="short_test2", shard=1, num_shards=1, time=None),
],
),
(
0.0,
[
ShardedTest(
name="super_long_test", shard=1, num_shards=1, time=None
),
ShardedTest(name="short_test3", shard=1, num_shards=1, time=None),
],
),
(
0.0,
[
ShardedTest(name="long_test2", shard=1, num_shards=1, time=None),
ShardedTest(name="short_test4", shard=1, num_shards=1, time=None),
],
),
(22.0, ["long_test1", "normal_test2", "short_test5"]),
(9.0, ["normal_test1", "normal_test3"]),
(1.0, ["short_test1", "short_test2"]),
(0.0, ["super_long_test", "short_test3"]),
(0.0, ["long_test2", "short_test4"]),
]
self.assert_shards_equal(
expected_shards, calculate_shards(5, self.tests, incomplete_test_times)
)

def test_split_shards(self) -> None:
test_times: Dict[str, float] = {"test1": THRESHOLD, "test2": THRESHOLD}
expected_shards = [
(600.0, [ShardedTest(name="test1", shard=1, num_shards=1, time=THRESHOLD)]),
(600.0, [ShardedTest(name="test2", shard=1, num_shards=1, time=THRESHOLD)]),
]
self.assert_shards_equal(
expected_shards, calculate_shards(2, list(test_times.keys()), test_times)
)

test_times = {"test1": THRESHOLD * 4, "test2": THRESHOLD * 2.5}
expected_shards = [
(
2200.0,
[
ShardedTest(name="test1", shard=1, num_shards=4, time=600.0),
ShardedTest(name="test1", shard=3, num_shards=4, time=600.0),
ShardedTest(name="test2", shard=1, num_shards=3, time=500.0),
ShardedTest(name="test2", shard=3, num_shards=3, time=500.0),
],
),
(
1700.0,
[
ShardedTest(name="test1", shard=2, num_shards=4, time=600.0),
ShardedTest(name="test1", shard=4, num_shards=4, time=600.0),
ShardedTest(name="test2", shard=2, num_shards=3, time=500.0),
],
),
]
self.assert_shards_equal(
expected_shards, calculate_shards(2, list(test_times.keys()), test_times)
)

test_times = {"test1": THRESHOLD / 2, "test2": THRESHOLD}
expected_shards = [
(600.0, [ShardedTest(name="test2", shard=1, num_shards=1, time=THRESHOLD)]),
(
300.0,
[ShardedTest(name="test1", shard=1, num_shards=1, time=THRESHOLD / 2)],
),
]
self.assert_shards_equal(
expected_shards, calculate_shards(2, list(test_times.keys()), test_times)
)

def test_split_shards_random(self) -> None:
random.seed(120)
for _ in range(100):
num_shards = random.randint(1, 10)
num_tests = random.randint(1, 100)
random_times: Dict[str, float] = {
str(i): random.randint(0, THRESHOLD * 10) for i in range(num_tests)
}

shards = calculate_shards(
num_shards, list(random_times.keys()), random_times
)

times = [x[0] for x in shards]
max_diff = max(times) - min(times)
self.assertTrue(max_diff <= THRESHOLD)

all_sharded_tests = defaultdict(list)
for time, sharded_tests in shards:
self.assertEqual(time, sum(x.time for x in sharded_tests))
for sharded_test in sharded_tests:
all_sharded_tests[sharded_test.name].append(sharded_test)

self.assertListEqual(
sorted(random_times.keys()), sorted(all_sharded_tests.keys())
)
for test, sharded_tests in all_sharded_tests.items():
self.assertAlmostEqual(
random_times[test], sum(x.time or 0 for x in sharded_tests)
)
self.assertListEqual(
list(range(sharded_tests[0].num_shards)),
sorted(x.shard - 1 for x in sharded_tests),
)

def test_calculate_2_shards_against_optimal_shards(self) -> None:
random.seed(120)
for _ in range(100):
random.seed(120)
random_times = {k: random.random() * 10 for k in self.tests}
# all test times except first two
rest_of_tests = [
@@ -325,7 +194,7 @@ class TestCalculateShards(unittest.TestCase):
calculated_shards[0][1] + calculated_shards[1][1]
)
# All the tests should be represented by some shard
self.assertEqual(sorted_tests, [x.name for x in sorted_shard_tests])
self.assertEqual(sorted_tests, sorted_shard_tests)


if __name__ == "__main__":
tools/testing/test_selections.py (filename inferred from the import in the test file above)
@@ -1,15 +1,13 @@
import math
import os
import subprocess

from typing import Callable, Dict, List, NamedTuple, Optional, Tuple
from typing import Callable, Dict, List, Optional, Tuple

from tools.stats.import_test_stats import get_disabled_tests, get_slow_tests

IS_MEM_LEAK_CHECK = os.getenv("PYTORCH_TEST_CUDA_MEM_LEAK_CHECK", "0") == "1"

NUM_PROCS = 1 if IS_MEM_LEAK_CHECK else 2
THRESHOLD = 60 * 10 # 10 minutes

# See Note [ROCm parallel CI testing]
# Special logic for ROCm GHA runners to query number of GPUs available.
@@ -32,73 +30,43 @@ if os.path.exists("/opt/rocm") and not IS_MEM_LEAK_CHECK:
NUM_PROCS = 1


class ShardedTest(NamedTuple):
name: str
shard: int
num_shards: int
time: Optional[float]

def __str__(self) -> str:
return f"{self.name} {self.shard}/{self.num_shards}"

def get_time(self) -> float:
return self.time or 0


class ShardJob:
def __init__(self) -> None:
self.serial: List[ShardedTest] = []
self.parallel: List[ShardedTest] = []
def __init__(self, test_times: Dict[str, float]):
self.test_times = test_times
self.serial: List[str] = []
self.parallel: List[str] = []

def get_total_time(self) -> float:
procs = [0.0 for _ in range(NUM_PROCS)]
for test in self.parallel:
test_time = self.test_times.get(test, 0)
min_index = procs.index(min(procs))
procs[min_index] += test.get_time()
time = max(procs) + sum(test.get_time() for test in self.serial)
procs[min_index] += test_time
time = max(procs) + sum(self.test_times.get(test, 0) for test in self.serial)
return time

def convert_to_tuple(self) -> Tuple[float, List[ShardedTest]]:
def convert_to_tuple(self) -> Tuple[float, List[str]]:
return (self.get_total_time(), self.serial + self.parallel)


def get_with_pytest_shard(
tests: List[str], test_file_times: Dict[str, float]
) -> List[ShardedTest]:
sharded_tests: List[ShardedTest] = []
for test in tests:
duration = test_file_times[test]
if duration > THRESHOLD:
num_shards = math.ceil(duration / THRESHOLD)
for i in range(num_shards):
sharded_tests.append(
ShardedTest(test, i + 1, num_shards, duration / num_shards)
)
else:
sharded_tests.append(ShardedTest(test, 1, 1, duration))
return sharded_tests


def calculate_shards(
num_shards: int,
tests: List[str],
test_file_times: Dict[str, float],
must_serial: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[float, List[ShardedTest]]]:
) -> List[Tuple[float, List[str]]]:
must_serial = must_serial or (lambda x: True)

known_tests = [x for x in tests if x in test_file_times]
unknown_tests: List[str] = [x for x in tests if x not in known_tests]

sorted_tests = sorted(
get_with_pytest_shard(known_tests, test_file_times),
key=lambda j: j.get_time(),
reverse=True,
)
sorted_tests = sorted(known_tests, key=lambda j: test_file_times[j], reverse=True)

sharded_jobs: List[ShardJob] = [ShardJob() for _ in range(num_shards)]
sharded_jobs: List[ShardJob] = [
ShardJob(test_file_times) for _ in range(num_shards)
]
for test in sorted_tests:
if must_serial(test.name):
if must_serial(test):
min_sharded_job = min(sharded_jobs, key=lambda j: j.get_total_time())
min_sharded_job.serial.append(test)
else:
@@ -107,8 +75,8 @@ def calculate_shards(

# Round robin the unknown jobs starting with the smallest shard
index = min(range(num_shards), key=lambda i: sharded_jobs[i].get_total_time())
for unknown_test in unknown_tests:
sharded_jobs[index].serial.append(ShardedTest(unknown_test, 1, 1, None))
for test in unknown_tests:
sharded_jobs[index].serial.append(test)
index = (index + 1) % num_shards
return [job.convert_to_tuple() for job in sharded_jobs]
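After the revert, calculate_shards works on plain test-file names again: known tests are sorted by recorded time (longest first) and greedily placed on the currently smallest shard, while tests without timing data are round-robined with an assumed time of zero. A hedged usage sketch (timings invented; assumes it is run from a PyTorch checkout so the tools package is importable):

from tools.testing.test_selections import calculate_shards

# Invented timings in seconds; the last test has no recorded time.
test_times = {"test_torch": 300.0, "test_nn": 200.0, "test_autograd": 150.0}
tests = ["test_torch", "test_nn", "test_autograd", "test_without_timing"]

# Returns [(estimated_seconds, [test file names]), ...], one tuple per shard.
for seconds, names in calculate_shards(2, tests, test_times):
    print(f"{seconds:.0f}s: {names}")
# One shard gets test_torch plus the untimed file (counted as 0s),
# the other gets test_nn and test_autograd.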
torch/testing/_internal/common_utils.py (filename inferred from the import at the top of run_test.py above)
@@ -116,9 +116,11 @@ slow_tests_dict = {}
if os.getenv("SLOW_TESTS_FILE", ""):
with open(os.getenv("SLOW_TESTS_FILE"), 'r') as fp:
slow_tests_dict = json.load(fp)
warnings.warn(f"loaded {len(slow_tests_dict)} slow tests")
if os.getenv("DISABLED_TESTS_FILE", ""):
with open(os.getenv("DISABLED_TESTS_FILE"), 'r') as fp:
disabled_tests_dict = json.load(fp)
warnings.warn(f"loaded {len(disabled_tests_dict)} disabled tests")

NATIVE_DEVICES = ('cpu', 'cuda', 'meta')

@@ -569,9 +571,9 @@ CI_TEST_PREFIX = str(Path(os.getcwd()))
CI_PT_ROOT = str(Path(os.getcwd()).parent)
CI_FUNCTORCH_ROOT = str(os.path.join(Path(os.getcwd()).parent, "functorch"))

def wait_for_process(p, timeout=None):
def wait_for_process(p):
try:
return p.wait(timeout=timeout)
return p.wait()
except KeyboardInterrupt:
# Give `p` a chance to handle KeyboardInterrupt. Without this,
# `pytest` can't print errors it collected so far upon KeyboardInterrupt.
@@ -588,7 +590,7 @@ def wait_for_process(p, timeout=None):
# Always call p.wait() to ensure exit
p.wait()

def shell(command, cwd=None, env=None, stdout=None, stderr=None, timeout=None):
def shell(command, cwd=None, env=None, stdout=None, stderr=None):
sys.stdout.flush()
sys.stderr.flush()
# The following cool snippet is copied from Py3 core library subprocess.call
@@ -600,22 +602,7 @@ def shell(command, cwd=None, env=None, stdout=None, stderr=None, timeout=None):
# https://github.com/python/cpython/blob/71b6c1af727fbe13525fb734568057d78cea33f3/Lib/subprocess.py#L309-L323
assert not isinstance(command, str), "Command to shell should be a list or tuple of tokens"
p = subprocess.Popen(command, universal_newlines=True, cwd=cwd, env=env, stdout=stdout, stderr=stderr)
return wait_for_process(p, timeout=timeout)


def retry_shell(command, cwd=None, env=None, stdout=None, stderr=None, timeout=None, retries=1):
assert retries >= 0, f"Expecting non negative number for number of retries, got {retries}"
try:
exit_code = shell(command, cwd=cwd, env=env, stdout=stdout, stderr=stderr, timeout=timeout)
if exit_code == 0 or retries == 0:
return exit_code
print(f"Got exit code {exit_code}, retrying (retries left={retries})", file=stdout, flush=True)
except subprocess.TimeoutExpired:
if retries == 0:
print(f"Command took >{timeout // 60}min, returning 124", file=stdout, flush=True)
return 124
print(f"Command took >{timeout // 60}min, retrying (retries left={retries})", file=stdout, flush=True)
return retry_shell(command, cwd=cwd, env=env, stdout=stdout, stderr=stderr, timeout=timeout, retries=retries - 1)
return wait_for_process(p)


def discover_test_cases_recursively(suite_or_case):
@@ -766,10 +753,7 @@ def run_tests(argv=UNITTEST_ARGS):
[test_case_full_name]
)
string_cmd = " ".join(cmd)

timeout = None if RERUN_DISABLED_TESTS else 15 * 60

exitcode = retry_shell(cmd, timeout=timeout, retries=0 if RERUN_DISABLED_TESTS else 1)
exitcode = shell(cmd)

if exitcode != 0:
# This is sort of hacky, but add on relevant env variables for distributed tests.
@@ -810,6 +794,7 @@ def run_tests(argv=UNITTEST_ARGS):
exit_code = pytest.main(args=pytest_args)
if TEST_SAVE_XML:
sanitize_pytest_xml(test_report_path)
print("If in CI, skip info is located in the xml test reports, please either go to s3 or the hud to download them")

if not RERUN_DISABLED_TESTS:
# exitcode of 5 means no tests were found, which happens since some test configs don't