Retry at test file level (#97506)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97506
Approved by: https://github.com/huydhn
Authored by Catherine Lee on 2023-03-30 17:12:19 +00:00, committed by PyTorch MergeBot
parent 24a5d006f2
commit 7d5d5beba2
4 changed files with 368 additions and 255 deletions

View File

@ -24,6 +24,7 @@ from torch.testing._internal.common_utils import (
IS_CI,
is_slow_gradcheck_env,
parser as common_parser,
retry_shell,
set_cwd,
shell,
TEST_WITH_ROCM,
@ -41,6 +42,8 @@ try:
get_reordered_tests,
get_test_case_configs,
NUM_PROCS,
ShardedTest,
THRESHOLD,
)
HAVE_TEST_SELECTION_TOOLS = True
@ -51,6 +54,9 @@ except ImportError:
)
RERUN_DISABLED_TESTS = os.getenv("PYTORCH_TEST_RERUN_DISABLED_TESTS", "0") == "1"
# Note [ROCm parallel CI testing]
# https://github.com/pytorch/pytorch/pull/85770 added file-granularity parallel testing.
# In .ci/pytorch/test.sh, TEST_CONFIG == "default", CUDA and HIP_VISIBLE_DEVICES is set to 0.
@ -277,6 +283,13 @@ CI_SERIAL_LIST = [
"_nvfuser/test_torchscript", # OOM on test_issue_1785
"test_schema_check", # Cause CUDA illegal memory access https://github.com/pytorch/pytorch/issues/95749
"functorch/test_memory_efficient_fusion", # Cause CUDA OOM on ROCm
"test_utils", # OOM
"test_sort_and_select", # OOM
"test_backward_compatible_arguments", # OOM
"test_module_init", # OOM
"test_autocast", # OOM
"test_native_mha", # OOM
"test_module_hooks", # OOM
]
# A subset of our TEST list that validates PyTorch's ops, modules, and autograd function as expected
@ -291,19 +304,6 @@ CORE_TEST_LIST = [
"test_torch",
]
# A list of distributed tests that run on multiple backends, i.e. gloo, nccl. These backends are spread out
# among all available shards to speed up the test. The list of backends are copied from the tests themselves
DISTRIBUTED_TESTS_WITH_MULTIPLE_BACKENDS = {
"distributed/test_distributed_spawn": [
"gloo",
"nccl",
"ucc",
],
"distributed/algorithms/quantization/test_quantization": [
"gloo",
"nccl",
],
}
# if a test file takes longer than 5 min, we add it to TARGET_DET_LIST
SLOW_TEST_THRESHOLD = 300
@ -406,9 +406,19 @@ def run_test(
) -> int:
maybe_set_hip_visible_devies()
unittest_args = options.additional_unittest_args.copy()
test_file = test_module
if isinstance(test_file, ShardedTest):
unittest_args.extend(
[
f"--shard-id={test_module.shard - 1}",
f"--num-shards={test_module.num_shards}",
]
)
test_file = test_module.name
if options.verbose:
unittest_args.append(f'-{"v"*options.verbose}') # in case of pytest
if test_module in RUN_PARALLEL_BLOCKLIST:
if test_file in RUN_PARALLEL_BLOCKLIST:
unittest_args = [
arg for arg in unittest_args if not arg.startswith("--run-parallel")
]
@ -422,11 +432,11 @@ def run_test(
unittest_args = [arg if arg != "-f" else "-x" for arg in unittest_args]
if IS_CI:
ci_args = ["--import-slow-tests", "--import-disabled-tests"]
if os.getenv("PYTORCH_TEST_RERUN_DISABLED_TESTS", "0") == "1":
if RERUN_DISABLED_TESTS:
ci_args.append("--rerun-disabled-tests")
# use the downloaded test cases configuration, not supported in pytest
unittest_args.extend(ci_args)
if test_module in PYTEST_SKIP_RETRIES:
if test_file in PYTEST_SKIP_RETRIES:
if not options.pytest:
raise RuntimeError(
"A test running without pytest cannot skip retries using "
@ -439,19 +449,35 @@ def run_test(
# Can't call `python -m unittest test_*` here because it doesn't run code
# in `if __name__ == '__main__': `. So call `python test_*.py` instead.
argv = [test_module + ".py"] + unittest_args
argv = [test_file + ".py"] + unittest_args
os.makedirs(REPO_ROOT / "test" / "test-reports", exist_ok=True)
log_fd, log_path = tempfile.mkstemp(
dir=REPO_ROOT / "test" / "test-reports",
prefix="{}_".format(test_module.replace("\\", "-").replace("/", "-")),
prefix="{}_".format(test_file.replace("\\", "-").replace("/", "-")),
suffix=".log",
)
os.close(log_fd)
command = (launcher_cmd or []) + executable + argv
should_file_rerun = (
"--subprocess" not in command
and not RERUN_DISABLED_TESTS
and isinstance(test_module, ShardedTest)
and test_module.time is not None
)
timeout = THRESHOLD * 2 if should_file_rerun else None
print_to_stderr("Executing {} ... [{}]".format(command, datetime.now()))
with open(log_path, "w") as f:
ret_code = shell(command, test_directory, stdout=f, stderr=f, env=env)
ret_code = retry_shell(
command,
test_directory,
stdout=f,
stderr=f,
env=env,
timeout=timeout,
retries=1 if should_file_rerun else 0,
)
print_log_file(test_module, log_path, failed=(ret_code != 0))
os.remove(log_path)
return ret_code
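
For reference, the new code path in run_test for a sharded test file, condensed into one place (a sketch, not part of the diff; ShardedTest, THRESHOLD, and retry_shell are defined in the later files of this diff):

    # Rerun the whole file at most once, but only when it is safe and useful to do so:
    # a ShardedTest with a known duration, not already isolated via --subprocess, and not
    # running in rerun-disabled-tests mode (which already reruns tests 50 times on its own).
    should_file_rerun = (
        "--subprocess" not in command
        and not RERUN_DISABLED_TESTS
        and isinstance(test_module, ShardedTest)
        and test_module.time is not None
    )
    ret_code = retry_shell(
        command, test_directory,
        stdout=f, stderr=f, env=env,  # f is the per-test log file opened just above
        timeout=THRESHOLD * 2 if should_file_rerun else None,  # 20 minutes per attempt, else no limit
        retries=1 if should_file_rerun else 0,
    )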
@ -549,33 +575,12 @@ def test_distributed(test_module, test_directory, options):
if options.verbose and not mpi_available:
print_to_stderr("MPI not available -- MPI backend tests will be skipped")
if options.shard:
which_shard, num_shards = options.shard
else:
which_shard = num_shards = 1
# Round-robin all backends to different shards
backend_to_shard = {
backend: i % num_shards + 1
for i, backend in enumerate(
DISTRIBUTED_TESTS_WITH_MULTIPLE_BACKENDS[test_module]
)
}
print_to_stderr(
f"Map different backends to different shards for {test_module}: {backend_to_shard}"
)
config = DISTRIBUTED_TESTS_CONFIG
for backend, env_vars in config.items():
if sys.platform == "win32" and backend != "gloo":
continue
if backend == "mpi" and not mpi_available:
continue
# Default to the first shard if seeing an unrecognized backend
if which_shard != backend_to_shard.get(backend, 1):
print_to_stderr(
f"Shard {which_shard}: {backend} should be run in {backend_to_shard.get(backend, 1)}"
)
continue
for with_init_file in {True, False}:
if sys.platform == "win32" and not with_init_file:
continue
@ -584,8 +589,8 @@ def test_distributed(test_module, test_directory, options):
init_str = "with {} init_method"
with_init = init_str.format("file" if with_init_file else "env")
print_to_stderr(
"Running distributed tests for the {} backend {} in shard {} of {}".format(
backend, with_init, which_shard, num_shards
"Running distributed tests for the {} backend {}".format(
backend, with_init
)
)
old_environ = dict(os.environ)
@ -594,7 +599,7 @@ def test_distributed(test_module, test_directory, options):
os.environ["INIT_METHOD"] = "env://"
os.environ.update(env_vars)
if with_init_file:
if test_module == "test_distributed_spawn":
if test_module.name == "test_distributed_spawn":
init_method = f"{FILE_SCHEMA}{tmp_dir}/"
else:
init_method = f"{FILE_SCHEMA}{tmp_dir}/shared_init_file"
@ -778,6 +783,7 @@ def run_doctests(test_module, test_directory, options):
def print_log_file(test: str, file_path: str, failed: bool) -> None:
num_lines = sum(1 for _ in open(file_path, "rb"))
test = str(test)
n = 100
with open(file_path, "r") as f:
print_to_stderr("")
@ -805,7 +811,7 @@ def print_log_file(test: str, file_path: str, failed: bool) -> None:
def get_pytest_args(options):
if os.getenv("PYTORCH_TEST_RERUN_DISABLED_TESTS", "0") == "1":
if RERUN_DISABLED_TESTS:
# When under rerun-disabled-tests mode, run the same tests multiple times to determine their
# flakiness status. Default to 50 re-runs
rerun_options = ["--flake-finder", "--flake-runs=50"]
@ -815,7 +821,7 @@ def get_pytest_args(options):
else:
# When under the normal mode, retry a failed test 2 more times. -x means stop at the first
# failure
rerun_options = ["-x", "--reruns=2"]
rerun_options = ["-x", "--reruns=2", "--sw"]
pytest_args = [
"--use-pytest",
@ -828,55 +834,6 @@ def get_pytest_args(options):
return pytest_args
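
The only functional change in this hunk is the extra "--sw" flag. A hedged reading of the three flags in the normal (non-rerun-disabled) mode, for reference:

    rerun_options = [
        "-x",          # stop the pytest session at the first genuine failure
        "--reruns=2",  # pytest-rerunfailures: retry each failing test up to 2 more times in-process
        "--sw",        # stepwise: presumably so the new file-level retry resumes from the
                       # last failing test instead of re-running tests that already passed
    ]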
def run_test_ops(test_module, test_directory, options):
default_unittest_args = get_pytest_args(options)
return_codes = []
os.environ["NUM_PARALLEL_PROCS"] = str(NUM_PROCS)
pool = get_context("spawn").Pool(NUM_PROCS)
for i in range(NUM_PROCS):
extra_unittest_args = default_unittest_args.copy()
extra_unittest_args.extend(
[
f"--shard-id={i}",
f"--num-shards={NUM_PROCS}",
"-k=not _linalg_cholesky_",
]
)
return_code = pool.apply_async(
run_test,
args=(test_module, test_directory, copy.deepcopy(options)),
kwds={
"extra_unittest_args": extra_unittest_args,
},
)
return_codes.append(return_code)
pool.close()
pool.join()
del os.environ["NUM_PARALLEL_PROCS"]
for return_code in return_codes:
if return_code.get() != 0:
return return_code.get()
extra_unittest_args = default_unittest_args.copy()
extra_unittest_args.extend(
[
"-k=_linalg_cholesky_",
]
)
return_code = run_test(
test_module,
test_directory,
copy.deepcopy(options),
extra_unittest_args=extra_unittest_args,
)
return return_code
CUSTOM_HANDLERS = {
"test_cuda_primary_ctx": test_cuda_primary_ctx,
"test_cuda_nvml_based_avail": get_run_test_with_subprocess_fn(),
@ -898,15 +855,6 @@ CUSTOM_HANDLERS = {
"distributed/rpc/test_share_memory": get_run_test_with_subprocess_fn(),
"distributed/rpc/cuda/test_tensorpipe_agent": get_run_test_with_subprocess_fn(),
"doctests": run_doctests,
"inductor/test_torchinductor_opinfo": run_test_ops,
"test_ops": run_test_ops,
"test_ops_gradients": run_test_ops,
"test_ops_fwd_gradients": run_test_ops,
"test_ops_jit": run_test_ops,
"functorch/test_ops": run_test_ops,
# not a test_ops file, but takes 2 hrs on some architectures and
# run_test_ops is good at parallelizing things
"test_decomp": run_test_ops,
}
@ -1179,13 +1127,7 @@ def get_selected_tests(options):
if options.distributed_tests:
selected_tests = list(
filter(
lambda test_name: (
test_name in DISTRIBUTED_TESTS
and test_name not in DISTRIBUTED_TESTS_WITH_MULTIPLE_BACKENDS
),
selected_tests,
)
filter(lambda test_name: test_name in DISTRIBUTED_TESTS, selected_tests)
)
# Filter to only run core tests when --core option is specified
@ -1246,46 +1188,6 @@ def get_selected_tests(options):
elif TEST_WITH_ROCM:
selected_tests = exclude_tests(ROCM_BLOCKLIST, selected_tests, "on ROCm")
# sharding
if options.shard:
assert len(options.shard) == 2, "Unexpected shard format"
assert min(options.shard) > 0, "Shards must be positive numbers"
which_shard, num_shards = options.shard
assert (
which_shard <= num_shards
), "Selected shard must be less than or equal to total number of shards"
assert num_shards <= len(
selected_tests
), f"Number of shards must be less than {len(selected_tests)}"
if num_shards == 1:
return selected_tests
# Download previous test times to make sharding decisions
path = os.path.join(str(REPO_ROOT), TEST_TIMES_FILE)
if os.path.exists(path):
with open(path, "r") as f:
test_file_times = cast(Dict[str, Any], json.load(f))
else:
test_file_times = {}
test_config = os.environ.get("TEST_CONFIG")
if test_config not in test_file_times:
print(
"::warning:: Gathered no stats from artifacts. Proceeding with default sharding plan."
)
selected_tests = selected_tests[which_shard - 1 :: num_shards]
else:
print("Found test time stats from artifacts")
test_file_times_config = test_file_times[test_config]
shards = calculate_shards(
num_shards,
selected_tests,
test_file_times_config,
must_serial=must_serial,
)
_, tests_from_shard = shards[which_shard - 1]
selected_tests = tests_from_shard
# skip all distributed tests if distributed package is not available.
if not dist.is_available():
selected_tests = exclude_tests(
@ -1311,30 +1213,61 @@ def get_selected_tests(options):
exact_match=True,
)
if options.distributed_tests:
# Run distributed tests with multiple backends across all shards, one per backend
selected_tests.extend(DISTRIBUTED_TESTS_WITH_MULTIPLE_BACKENDS.keys())
selected_tests.reverse()
selected_tests = [parse_test_module(x) for x in selected_tests]
# sharding
which_shard, num_shards = 1, 1
if options.shard:
assert len(options.shard) == 2, "Unexpected shard format"
assert min(options.shard) > 0, "Shards must be positive numbers"
which_shard, num_shards = options.shard
assert (
which_shard <= num_shards
), "Selected shard must be less than or equal to total number of shards"
assert num_shards <= len(
selected_tests
), f"Number of shards must be less than {len(selected_tests)}"
# Download previous test times to make sharding decisions
path = os.path.join(str(REPO_ROOT), TEST_TIMES_FILE)
if os.path.exists(path):
with open(path, "r") as f:
test_file_times = cast(Dict[str, Any], json.load(f))
else:
test_file_times = {}
test_config = os.environ.get("TEST_CONFIG")
if test_config not in test_file_times:
print(
"::warning:: Gathered no stats from artifacts. Proceeding with default sharding plan."
)
else:
print("Found test time stats from artifacts")
# Do sharding
test_file_times_config = test_file_times.get(test_config, {})
shards = calculate_shards(
num_shards, selected_tests, test_file_times_config, must_serial=must_serial
)
_, tests_from_shard = shards[which_shard - 1]
selected_tests = tests_from_shard
return selected_tests
def run_test_module(test: str, test_directory: str, options) -> Optional[str]:
def run_test_module(test: ShardedTest, test_directory: str, options) -> Optional[str]:
maybe_set_hip_visible_devies()
test_module = parse_test_module(test)
# Printing the date here can help diagnose which tests are slow
print_to_stderr("Running {} ... [{}]".format(test, datetime.now()))
handler = CUSTOM_HANDLERS.get(test_module, run_test)
return_code = handler(test_module, test_directory, options)
print_to_stderr("Running {} ... [{}]".format(str(test), datetime.now()))
handler = CUSTOM_HANDLERS.get(test.name, run_test)
return_code = handler(test, test_directory, options)
assert isinstance(return_code, int) and not isinstance(
return_code, bool
), f"While running {test} got non integer return code {return_code}"
), f"While running {str(test)} got non integer return code {return_code}"
if return_code == 0:
return None
message = f"{test} failed!"
message = f"{str(test)} failed!"
if return_code < 0:
# subprocess.Popen returns the child process' exit signal as
# return code -N, where N is the signal number.
@ -1350,7 +1283,9 @@ def main():
selected_tests = get_selected_tests(options)
if options.verbose:
print_to_stderr("Selected tests:\n {}".format("\n ".join(selected_tests)))
print_to_stderr(
"Selected tests:\n {}".format("\n ".join(str(x) for x in selected_tests))
)
if options.dry_run:
return
@ -1372,18 +1307,18 @@ def main():
# parallel = in parallel with other files
# serial = this file on it's own. The file might still be run in parallel with itself (ex test_ops)
selected_tests_parallel = [x for x in selected_tests if not must_serial(x)]
selected_tests_parallel = [x for x in selected_tests if not must_serial(x.name)]
selected_tests_serial = [
x for x in selected_tests if x not in selected_tests_parallel
]
print_to_stderr(
"parallel (file granularity) tests:\n {}".format(
"\n ".join(selected_tests_parallel)
"\n ".join(str(x) for x in selected_tests_parallel)
)
)
print_to_stderr(
"serial (file granularity) tests:\n {}".format(
"\n ".join(selected_tests_serial)
"\n ".join(str(x) for x in selected_tests_serial)
)
)
@ -1403,7 +1338,7 @@ def main():
return False
try:
os.environ["PARALLEL_TESTING"] = "1"
os.environ["NUM_PARALLEL_PROCS"] = str(NUM_PROCS)
for test in selected_tests_parallel:
options_clone = copy.deepcopy(options)
if can_run_in_pytest(test):
@ -1415,7 +1350,7 @@ def main():
)
pool.close()
pool.join()
del os.environ["PARALLEL_TESTING"]
del os.environ["NUM_PARALLEL_PROCS"]
if not options.continue_through_error and len(failure_messages) != 0:
raise RuntimeError(

View File

@ -1,8 +1,18 @@
import pathlib
import random
import sys
import unittest
from collections import defaultdict
from typing import Dict, List, Tuple
from tools.testing.test_selections import calculate_shards
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent
try:
# using tools/ to optimize test run.
sys.path.append(str(REPO_ROOT))
from tools.testing.test_selections import calculate_shards, ShardedTest, THRESHOLD
except ModuleNotFoundError:
print("Can't import required modules, exiting")
exit(1)
class TestCalculateShards(unittest.TestCase):
@ -36,8 +46,8 @@ class TestCalculateShards(unittest.TestCase):
def assert_shards_equal(
self,
expected_shards: List[Tuple[float, List[str]]],
actual_shards: List[Tuple[float, List[str]]],
expected_shards: List[Tuple[float, List[ShardedTest]]],
actual_shards: List[Tuple[float, List[ShardedTest]]],
) -> None:
for expected, actual in zip(expected_shards, actual_shards):
self.assertAlmostEqual(expected[0], actual[0])
@ -45,19 +55,25 @@ class TestCalculateShards(unittest.TestCase):
def test_calculate_2_shards_with_complete_test_times(self) -> None:
expected_shards = [
(60, ["super_long_test", "normal_test3"]),
(
60.0,
[
ShardedTest(name="super_long_test", shard=1, num_shards=1, time=55),
ShardedTest(name="normal_test3", shard=1, num_shards=1, time=5),
],
),
(
58.31,
[
"long_test1",
"long_test2",
"normal_test1",
"normal_test2",
"short_test1",
"short_test2",
"short_test3",
"short_test4",
"short_test5",
ShardedTest(name="long_test1", shard=1, num_shards=1, time=22),
ShardedTest(name="long_test2", shard=1, num_shards=1, time=18),
ShardedTest(name="normal_test1", shard=1, num_shards=1, time=9),
ShardedTest(name="normal_test2", shard=1, num_shards=1, time=7),
ShardedTest(name="short_test1", shard=1, num_shards=1, time=1),
ShardedTest(name="short_test2", shard=1, num_shards=1, time=0.6),
ShardedTest(name="short_test3", shard=1, num_shards=1, time=0.4),
ShardedTest(name="short_test4", shard=1, num_shards=1, time=0.3),
ShardedTest(name="short_test5", shard=1, num_shards=1, time=0.01),
],
),
]
@ -70,19 +86,19 @@ class TestCalculateShards(unittest.TestCase):
(
118.31,
[
"super_long_test",
"long_test1",
"long_test2",
"normal_test1",
"normal_test2",
"normal_test3",
"short_test1",
"short_test2",
"short_test3",
"short_test4",
"short_test5",
ShardedTest(name="super_long_test", shard=1, num_shards=1, time=55),
ShardedTest(name="long_test1", shard=1, num_shards=1, time=22),
ShardedTest(name="long_test2", shard=1, num_shards=1, time=18),
ShardedTest(name="normal_test1", shard=1, num_shards=1, time=9),
ShardedTest(name="normal_test2", shard=1, num_shards=1, time=7),
ShardedTest(name="normal_test3", shard=1, num_shards=1, time=5),
ShardedTest(name="short_test1", shard=1, num_shards=1, time=1),
ShardedTest(name="short_test2", shard=1, num_shards=1, time=0.6),
ShardedTest(name="short_test3", shard=1, num_shards=1, time=0.4),
ShardedTest(name="short_test4", shard=1, num_shards=1, time=0.3),
ShardedTest(name="short_test5", shard=1, num_shards=1, time=0.01),
],
),
)
]
self.assert_shards_equal(
expected_shards, calculate_shards(1, self.tests, self.test_times)
@ -90,31 +106,30 @@ class TestCalculateShards(unittest.TestCase):
def test_calculate_5_shards_with_complete_test_times(self) -> None:
expected_shards = [
(55.0, ["super_long_test"]),
(
22.0,
[
"long_test1",
],
),
(
18.0,
[
"long_test2",
],
55.0,
[ShardedTest(name="super_long_test", shard=1, num_shards=1, time=55)],
),
(22.0, [ShardedTest(name="long_test1", shard=1, num_shards=1, time=22)]),
(18.0, [ShardedTest(name="long_test2", shard=1, num_shards=1, time=18)]),
(
11.31,
[
"normal_test1",
"short_test1",
"short_test2",
"short_test3",
"short_test4",
"short_test5",
ShardedTest(name="normal_test1", shard=1, num_shards=1, time=9),
ShardedTest(name="short_test1", shard=1, num_shards=1, time=1),
ShardedTest(name="short_test2", shard=1, num_shards=1, time=0.6),
ShardedTest(name="short_test3", shard=1, num_shards=1, time=0.4),
ShardedTest(name="short_test4", shard=1, num_shards=1, time=0.3),
ShardedTest(name="short_test5", shard=1, num_shards=1, time=0.01),
],
),
(
12.0,
[
ShardedTest(name="normal_test2", shard=1, num_shards=1, time=7),
ShardedTest(name="normal_test3", shard=1, num_shards=1, time=5),
],
),
(12.0, ["normal_test2", "normal_test3"]),
]
self.assert_shards_equal(
expected_shards, calculate_shards(5, self.tests, self.test_times)
@ -128,22 +143,24 @@ class TestCalculateShards(unittest.TestCase):
(
22.0,
[
"long_test1",
"long_test2",
"normal_test3",
"short_test3",
"short_test5",
ShardedTest(name="long_test1", shard=1, num_shards=1, time=22),
ShardedTest(name="long_test2", shard=1, num_shards=1, time=None),
ShardedTest(name="normal_test3", shard=1, num_shards=1, time=None),
ShardedTest(name="short_test3", shard=1, num_shards=1, time=None),
ShardedTest(name="short_test5", shard=1, num_shards=1, time=None),
],
),
(
10.0,
[
"normal_test1",
"short_test1",
"super_long_test",
"normal_test2",
"short_test2",
"short_test4",
ShardedTest(name="normal_test1", shard=1, num_shards=1, time=9),
ShardedTest(name="short_test1", shard=1, num_shards=1, time=1),
ShardedTest(
name="super_long_test", shard=1, num_shards=1, time=None
),
ShardedTest(name="normal_test2", shard=1, num_shards=1, time=None),
ShardedTest(name="short_test2", shard=1, num_shards=1, time=None),
ShardedTest(name="short_test4", shard=1, num_shards=1, time=None),
],
),
]
@ -156,19 +173,133 @@ class TestCalculateShards(unittest.TestCase):
k: v for k, v in self.test_times.items() if "test1" in k
}
expected_shards = [
(22.0, ["long_test1", "normal_test2", "short_test5"]),
(9.0, ["normal_test1", "normal_test3"]),
(1.0, ["short_test1", "short_test2"]),
(0.0, ["super_long_test", "short_test3"]),
(0.0, ["long_test2", "short_test4"]),
(
22.0,
[
ShardedTest(name="long_test1", shard=1, num_shards=1, time=22),
ShardedTest(name="normal_test2", shard=1, num_shards=1, time=None),
ShardedTest(name="short_test5", shard=1, num_shards=1, time=None),
],
),
(
9.0,
[
ShardedTest(name="normal_test1", shard=1, num_shards=1, time=9),
ShardedTest(name="normal_test3", shard=1, num_shards=1, time=None),
],
),
(
1.0,
[
ShardedTest(name="short_test1", shard=1, num_shards=1, time=1),
ShardedTest(name="short_test2", shard=1, num_shards=1, time=None),
],
),
(
0.0,
[
ShardedTest(
name="super_long_test", shard=1, num_shards=1, time=None
),
ShardedTest(name="short_test3", shard=1, num_shards=1, time=None),
],
),
(
0.0,
[
ShardedTest(name="long_test2", shard=1, num_shards=1, time=None),
ShardedTest(name="short_test4", shard=1, num_shards=1, time=None),
],
),
]
self.assert_shards_equal(
expected_shards, calculate_shards(5, self.tests, incomplete_test_times)
)
def test_calculate_2_shards_against_optimal_shards(self) -> None:
def test_split_shards(self) -> None:
test_times: Dict[str, float] = {"test1": THRESHOLD, "test2": THRESHOLD}
expected_shards = [
(600.0, [ShardedTest(name="test1", shard=1, num_shards=1, time=THRESHOLD)]),
(600.0, [ShardedTest(name="test2", shard=1, num_shards=1, time=THRESHOLD)]),
]
self.assert_shards_equal(
expected_shards, calculate_shards(2, list(test_times.keys()), test_times)
)
test_times = {"test1": THRESHOLD * 4, "test2": THRESHOLD * 2.5}
expected_shards = [
(
2200.0,
[
ShardedTest(name="test1", shard=1, num_shards=4, time=600.0),
ShardedTest(name="test1", shard=3, num_shards=4, time=600.0),
ShardedTest(name="test2", shard=1, num_shards=3, time=500.0),
ShardedTest(name="test2", shard=3, num_shards=3, time=500.0),
],
),
(
1700.0,
[
ShardedTest(name="test1", shard=2, num_shards=4, time=600.0),
ShardedTest(name="test1", shard=4, num_shards=4, time=600.0),
ShardedTest(name="test2", shard=2, num_shards=3, time=500.0),
],
),
]
self.assert_shards_equal(
expected_shards, calculate_shards(2, list(test_times.keys()), test_times)
)
test_times = {"test1": THRESHOLD / 2, "test2": THRESHOLD}
expected_shards = [
(600.0, [ShardedTest(name="test2", shard=1, num_shards=1, time=THRESHOLD)]),
(
300.0,
[ShardedTest(name="test1", shard=1, num_shards=1, time=THRESHOLD / 2)],
),
]
self.assert_shards_equal(
expected_shards, calculate_shards(2, list(test_times.keys()), test_times)
)
def test_split_shards_random(self) -> None:
random.seed(120)
for _ in range(100):
num_shards = random.randint(1, 10)
num_tests = random.randint(1, 100)
random_times: Dict[str, float] = {
str(i): random.randint(0, THRESHOLD * 10) for i in range(num_tests)
}
shards = calculate_shards(
num_shards, list(random_times.keys()), random_times
)
times = [x[0] for x in shards]
max_diff = max(times) - min(times)
self.assertTrue(max_diff <= THRESHOLD)
all_sharded_tests = defaultdict(list)
for time, sharded_tests in shards:
self.assertEqual(time, sum(x.time for x in sharded_tests))
for sharded_test in sharded_tests:
all_sharded_tests[sharded_test.name].append(sharded_test)
self.assertListEqual(
sorted(random_times.keys()), sorted(all_sharded_tests.keys())
)
for test, sharded_tests in all_sharded_tests.items():
self.assertAlmostEqual(
random_times[test], sum(x.time or 0 for x in sharded_tests)
)
self.assertListEqual(
list(range(sharded_tests[0].num_shards)),
sorted(x.shard - 1 for x in sharded_tests),
)
def test_calculate_2_shards_against_optimal_shards(self) -> None:
random.seed(120)
for _ in range(100):
random.seed(120)
random_times = {k: random.random() * 10 for k in self.tests}
# all test times except first two
rest_of_tests = [
@ -194,7 +325,7 @@ class TestCalculateShards(unittest.TestCase):
calculated_shards[0][1] + calculated_shards[1][1]
)
# All the tests should be represented by some shard
self.assertEqual(sorted_tests, sorted_shard_tests)
self.assertEqual(sorted_tests, [x.name for x in sorted_shard_tests])
if __name__ == "__main__":

View File

@ -1,13 +1,15 @@
import math
import os
import subprocess
from typing import Callable, Dict, List, Optional, Tuple
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple
from tools.stats.import_test_stats import get_disabled_tests, get_slow_tests
IS_MEM_LEAK_CHECK = os.getenv("PYTORCH_TEST_CUDA_MEM_LEAK_CHECK", "0") == "1"
NUM_PROCS = 1 if IS_MEM_LEAK_CHECK else 2
THRESHOLD = 60 * 10 # 10 minutes
# See Note [ROCm parallel CI testing]
# Special logic for ROCm GHA runners to query number of GPUs available.
@ -30,43 +32,73 @@ if os.path.exists("/opt/rocm") and not IS_MEM_LEAK_CHECK:
NUM_PROCS = 1
class ShardedTest(NamedTuple):
name: str
shard: int
num_shards: int
time: Optional[float]
def __str__(self) -> str:
return f"{self.name} {self.shard}/{self.num_shards}"
def get_time(self) -> float:
return self.time or 0
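
A minimal sketch of how a ShardedTest is consumed elsewhere in this diff (the values here are invented for illustration):

    t = ShardedTest(name="test_ops", shard=2, num_shards=4, time=600.0)
    str(t)        # "test_ops 2/4" -- the form printed in the run logs
    t.get_time()  # 600.0; a test with time=None reports 0
    # run_test translates the shard fields into pytest-shard style arguments,
    # with a 0-based shard id on the command line:
    # ["--shard-id=1", "--num-shards=4"]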
class ShardJob:
def __init__(self, test_times: Dict[str, float]):
self.test_times = test_times
self.serial: List[str] = []
self.parallel: List[str] = []
def __init__(self) -> None:
self.serial: List[ShardedTest] = []
self.parallel: List[ShardedTest] = []
def get_total_time(self) -> float:
procs = [0.0 for _ in range(NUM_PROCS)]
for test in self.parallel:
test_time = self.test_times.get(test, 0)
min_index = procs.index(min(procs))
procs[min_index] += test_time
time = max(procs) + sum(self.test_times.get(test, 0) for test in self.serial)
procs[min_index] += test.get_time()
time = max(procs) + sum(test.get_time() for test in self.serial)
return time
def convert_to_tuple(self) -> Tuple[float, List[str]]:
def convert_to_tuple(self) -> Tuple[float, List[ShardedTest]]:
return (self.get_total_time(), self.serial + self.parallel)
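
A worked example of the cost model, assuming NUM_PROCS == 2: parallel tests are greedily assigned to the least-loaded of NUM_PROCS simulated workers, and serial tests are summed on top.

    job = ShardJob()
    job.parallel = [
        ShardedTest("a", 1, 1, 9.0),
        ShardedTest("b", 1, 1, 7.0),
        ShardedTest("c", 1, 1, 5.0),
    ]
    job.serial = [ShardedTest("d", 1, 1, 4.0)]
    # the two workers end up at [9.0] and [7.0 + 5.0] -> max is 12.0; plus 4.0 of serial time
    assert job.get_total_time() == 16.0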
def get_with_pytest_shard(
tests: List[str], test_file_times: Dict[str, float]
) -> List[ShardedTest]:
sharded_tests: List[ShardedTest] = []
for test in tests:
duration = test_file_times[test]
if duration > THRESHOLD:
num_shards = math.ceil(duration / THRESHOLD)
for i in range(num_shards):
sharded_tests.append(
ShardedTest(test, i + 1, num_shards, duration / num_shards)
)
else:
sharded_tests.append(ShardedTest(test, 1, 1, duration))
return sharded_tests
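
A worked example of the pytest-shard split (file names and times are invented; THRESHOLD is the 10-minute constant above):

    times = {"test_ops": 1500.0, "test_nn": 200.0}
    get_with_pytest_shard(["test_ops", "test_nn"], times)
    # test_ops exceeds THRESHOLD, so it is planned as ceil(1500 / 600) == 3 pytest shards:
    # [ShardedTest("test_ops", 1, 3, 500.0),
    #  ShardedTest("test_ops", 2, 3, 500.0),
    #  ShardedTest("test_ops", 3, 3, 500.0),
    #  ShardedTest("test_nn", 1, 1, 200.0)]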
def calculate_shards(
num_shards: int,
tests: List[str],
test_file_times: Dict[str, float],
must_serial: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[float, List[str]]]:
) -> List[Tuple[float, List[ShardedTest]]]:
must_serial = must_serial or (lambda x: True)
known_tests = [x for x in tests if x in test_file_times]
unknown_tests: List[str] = [x for x in tests if x not in known_tests]
sorted_tests = sorted(known_tests, key=lambda j: test_file_times[j], reverse=True)
sorted_tests = sorted(
get_with_pytest_shard(known_tests, test_file_times),
key=lambda j: j.get_time(),
reverse=True,
)
sharded_jobs: List[ShardJob] = [
ShardJob(test_file_times) for _ in range(num_shards)
]
sharded_jobs: List[ShardJob] = [ShardJob() for _ in range(num_shards)]
for test in sorted_tests:
if must_serial(test):
if must_serial(test.name):
min_sharded_job = min(sharded_jobs, key=lambda j: j.get_total_time())
min_sharded_job.serial.append(test)
else:
@ -75,8 +107,8 @@ def calculate_shards(
# Round robin the unknown jobs starting with the smallest shard
index = min(range(num_shards), key=lambda i: sharded_jobs[i].get_total_time())
for test in unknown_tests:
sharded_jobs[index].serial.append(test)
for unknown_test in unknown_tests:
sharded_jobs[index].serial.append(ShardedTest(unknown_test, 1, 1, None))
index = (index + 1) % num_shards
return [job.convert_to_tuple() for job in sharded_jobs]
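
A hedged usage sketch of calculate_shards, roughly how get_selected_tests in run_test.py calls it (names and times are illustrative):

    shards = calculate_shards(
        num_shards=2,
        tests=["test_a", "test_b", "test_c"],
        test_file_times={"test_a": 400.0, "test_b": 300.0},  # no recorded time for test_c
        must_serial=lambda name: True,  # treat everything as serial for this sketch
    )
    # Each entry is (estimated_time, [ShardedTest, ...]); files without timing data are
    # round-robined onto the smallest shards with time=None:
    # [(400.0, [ShardedTest("test_a", 1, 1, 400.0)]),
    #  (300.0, [ShardedTest("test_b", 1, 1, 300.0), ShardedTest("test_c", 1, 1, None)])]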

View File

@ -116,11 +116,9 @@ slow_tests_dict = {}
if os.getenv("SLOW_TESTS_FILE", ""):
with open(os.getenv("SLOW_TESTS_FILE"), 'r') as fp:
slow_tests_dict = json.load(fp)
warnings.warn(f"loaded {len(slow_tests_dict)} slow tests")
if os.getenv("DISABLED_TESTS_FILE", ""):
with open(os.getenv("DISABLED_TESTS_FILE"), 'r') as fp:
disabled_tests_dict = json.load(fp)
warnings.warn(f"loaded {len(disabled_tests_dict)} disabled tests")
NATIVE_DEVICES = ('cpu', 'cuda', 'meta')
@ -571,9 +569,9 @@ CI_TEST_PREFIX = str(Path(os.getcwd()))
CI_PT_ROOT = str(Path(os.getcwd()).parent)
CI_FUNCTORCH_ROOT = str(os.path.join(Path(os.getcwd()).parent, "functorch"))
def wait_for_process(p):
def wait_for_process(p, timeout=None):
try:
return p.wait()
return p.wait(timeout=timeout)
except KeyboardInterrupt:
# Give `p` a chance to handle KeyboardInterrupt. Without this,
# `pytest` can't print errors it collected so far upon KeyboardInterrupt.
@ -590,7 +588,7 @@ def wait_for_process(p):
# Always call p.wait() to ensure exit
p.wait()
def shell(command, cwd=None, env=None, stdout=None, stderr=None):
def shell(command, cwd=None, env=None, stdout=None, stderr=None, timeout=None):
sys.stdout.flush()
sys.stderr.flush()
# The following cool snippet is copied from Py3 core library subprocess.call
@ -602,7 +600,22 @@ def shell(command, cwd=None, env=None, stdout=None, stderr=None):
# https://github.com/python/cpython/blob/71b6c1af727fbe13525fb734568057d78cea33f3/Lib/subprocess.py#L309-L323
assert not isinstance(command, str), "Command to shell should be a list or tuple of tokens"
p = subprocess.Popen(command, universal_newlines=True, cwd=cwd, env=env, stdout=stdout, stderr=stderr)
return wait_for_process(p)
return wait_for_process(p, timeout=timeout)
def retry_shell(command, cwd=None, env=None, stdout=None, stderr=None, timeout=None, retries=1):
assert retries >= 0, f"Expecting non negative number for number of retries, got {retries}"
try:
exit_code = shell(command, cwd=cwd, env=env, stdout=stdout, stderr=stderr, timeout=timeout)
if exit_code == 0 or retries == 0:
return exit_code
print(f"Got exit code {exit_code}, retrying (retries left={retries})", file=stdout, flush=True)
except subprocess.TimeoutExpired:
if retries == 0:
print(f"Command took >{timeout // 60}min, returning 124", file=stdout, flush=True)
return 124
print(f"Command took >{timeout // 60}min, retrying (retries left={retries})", file=stdout, flush=True)
return retry_shell(command, cwd=cwd, env=env, stdout=stdout, stderr=stderr, timeout=timeout, retries=retries - 1)
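
A hedged usage sketch of the new helper (command and paths are illustrative): a non-zero exit triggers up to `retries` extra attempts, and a run that exceeds `timeout` on its last attempt returns 124, the same code the GNU timeout utility uses.

    exit_code = retry_shell(
        ["python", "test_ops.py", "--shard-id=0", "--num-shards=2"],
        cwd="test",
        timeout=20 * 60,  # run_test passes THRESHOLD * 2 for sharded files with known times
        retries=1,        # rerun the whole file at most once
    )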
def discover_test_cases_recursively(suite_or_case):
@ -753,7 +766,10 @@ def run_tests(argv=UNITTEST_ARGS):
[test_case_full_name]
)
string_cmd = " ".join(cmd)
exitcode = shell(cmd)
timeout = None if RERUN_DISABLED_TESTS else 15 * 60
exitcode = retry_shell(cmd, timeout=timeout, retries=0 if RERUN_DISABLED_TESTS else 1)
if exitcode != 0:
# This is sort of hacky, but add on relevant env variables for distributed tests.
@ -794,7 +810,6 @@ def run_tests(argv=UNITTEST_ARGS):
exit_code = pytest.main(args=pytest_args)
if TEST_SAVE_XML:
sanitize_pytest_xml(test_report_path)
print("If in CI, skip info is located in the xml test reports, please either go to s3 or the hud to download them")
if not RERUN_DISABLED_TESTS:
# exitcode of 5 means no tests were found, which happens since some test configs don't