test in parallel at file granularity (#84961)

run tests in parallel at the test file granularity

Runs 3 test files in parallel using a multiprocessing pool. Each file's output goes to a log file, which is then printed once that file finishes.  Some tests cannot be run in parallel (usually because they run out of memory), so we run those serially afterwards.  Sharding is changed to attempt to mask large files with other large files by running them on the same shard.
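At a high level the runner follows the pattern below. This is a minimal sketch with hypothetical names (run_one_file, the example file list); the real logic lives in test/run_test.py and uses torch.multiprocessing with a spawn context:

# Sketch: run NUM_PROCS test files at a time, redirect each file's output to
# its own log, and print that log once the file has finished.
import multiprocessing
import subprocess
import tempfile

NUM_PROCS = 3

def run_one_file(test_file: str) -> int:
    # hypothetical helper; run_test.py goes through run_test()/shell() instead
    with tempfile.TemporaryFile("w+") as log:
        ret = subprocess.call(["python", test_file], stdout=log, stderr=log)
        log.seek(0)
        print(log.read())  # surface the captured output only after the file is done
    return ret

if __name__ == "__main__":
    files = ["test_torch.py", "test_nn.py", "test_ops.py"]  # illustrative
    with multiprocessing.get_context("spawn").Pool(NUM_PROCS) as pool:
        exit_codes = pool.map(run_one_file, files)
    print("all passed" if all(code == 0 for code in exit_codes) else "some files failed")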

test_ops* gets a custom handler because it is simply too big (2 hrs on Windows) and the linalg_cholesky tests fail when run in parallel (I would really like a proper solution to this, but until then we use the custom handler).
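Concretely, the handler kicks off NUM_PROCS sharded runs of the file and then a serial cholesky-only pass, roughly equivalent to the invocations below (the flags mirror the extra_unittest_args in the diff: pytest-shard provides --shard-id/--num-shards, pytest-rerunfailures provides --reruns):

python test_ops.py --use-pytest -vv -x --reruns=2 -rfEX --shard-id=0 --num-shards=3 "-k=not _linalg_cholesky_"
python test_ops.py --use-pytest -vv -x --reruns=2 -rfEX --shard-id=1 --num-shards=3 "-k=not _linalg_cholesky_"
python test_ops.py --use-pytest -vv -x --reruns=2 -rfEX --shard-id=2 --num-shards=3 "-k=not _linalg_cholesky_"
python test_ops.py --use-pytest -vv -x --reruns=2 -rfEX "-k=_linalg_cholesky_"

The first three run concurrently (with PARALLEL_TESTING=1 set); the last one runs afterwards on its own.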

Reduces CUDA test time by a lot and reduces total Windows test time by ~1 hr.

Ref. https://github.com/pytorch/pytorch/issues/82894
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84961
Approved by: https://github.com/huydhn
Catherine Lee 2022-09-21 16:58:11 +00:00 committed by PyTorch MergeBot
parent 2fb820455c
commit 8107666c6a
10 changed files with 227 additions and 67 deletions

.circleci/config.yml (generated)
View File

@ -693,7 +693,7 @@ jobs:
- run_brew_for_macos_build
- run:
name: Test
no_output_timeout: "1h"
no_output_timeout: "2h"
command: |
set -x

View File

@ -149,6 +149,11 @@ pytest-xdist
#Pinned versions:
#test that import:
pytest-shard
#Description: plugin for splitting up tests in pytest
#Pinned versions:
#test that import:
pytest-rerunfailures
#Description: plugin for rerunning tests in pytest
#Pinned versions:

View File

@ -218,7 +218,7 @@
- run_brew_for_macos_build
- run:
name: Test
no_output_timeout: "1h"
no_output_timeout: "2h"
command: |
set -x

View File

@ -16,6 +16,7 @@ fi
pip install "unittest-xml-reporting<=3.2.0,>=2.0.0" \
pytest \
pytest-xdist \
pytest-shard \
pytest-rerunfailures \
"xdoctest==1.0.2" \
"pygments==2.12.0"

View File

@ -36,7 +36,7 @@ popd
=======
:: Pin unittest-xml-reporting to freeze printing test summary logic, related: https://github.com/pytorch/pytorch/issues/69014
pip install "ninja==1.10.0.post1" future "hypothesis==5.35.1" "expecttest==0.1.3" "librosa>=0.6.2" "scipy==1.6.3" psutil pillow "unittest-xml-reporting<=3.2.0,>=2.0.0" pytest pytest-xdist pytest-rerunfailures "xdoctest==1.0.2" "pygments==2.12.0"
pip install "ninja==1.10.0.post1" future "hypothesis==5.35.1" "expecttest==0.1.3" "librosa>=0.6.2" "scipy==1.6.3" psutil pillow "unittest-xml-reporting<=3.2.0,>=2.0.0" pytest pytest-xdist pytest-shard pytest-rerunfailures "xdoctest==1.0.2" "pygments==2.12.0"
if errorlevel 1 exit /b
if not errorlevel 0 exit /b

View File

@ -1561,6 +1561,9 @@ static void apply_lu_factor_batched_magma(const Tensor& input, const Tensor& piv
input_array[i] = &input_data[i * input_matrix_stride];
}
// needed to run lu tests in parallel, see https://github.com/pytorch/pytorch/issues/82894 for examples
// of failures
c10::cuda::device_synchronize();
MAGMAQueue magma_queue(input.get_device());
if (compute_pivots) {

View File

@ -237,6 +237,9 @@ void apply_ldl_solve_cusolver(
auto pivots_ = pivots.to(kLong);
auto pivots_data = pivots_.data_ptr<int64_t>();
// needed to run ldl_solve tests in parallel
// see https://github.com/pytorch/pytorch/issues/82894 for examples of failures
c10::cuda::device_synchronize();
auto handle = at::cuda::getCurrentCUDASolverDnHandle();
auto datatype = at::cuda::solver::get_cusolver_datatype<scalar_t>();
size_t worksize_device = 0;

View File

@ -27,7 +27,7 @@ from torch.testing._internal.common_utils import (
parser as common_parser,
)
import torch.distributed as dist
from torch.multiprocessing import Pool
from torch.multiprocessing import Pool, get_context
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent
@ -39,6 +39,7 @@ try:
get_reordered_tests,
get_test_case_configs,
calculate_shards,
NUM_PROCS
)
HAVE_TEST_SELECTION_TOOLS = True
except ImportError:
@ -125,7 +126,6 @@ TESTS = discover_tests(
"distributed/elastic/utils/util_test",
"distributed/elastic/utils/distributed_test",
"distributed/elastic/multiprocessing/api_test",
"test_deploy",
]
)
@ -264,6 +264,29 @@ RUN_PARALLEL_BLOCKLIST = [
"test_cuda_trace",
] + FSDP_TEST
CI_SERIAL_LIST = [
'test_nn',
'test_fake_tensor',
'test_cpp_api_parity',
'test_reductions',
'test_cuda',
'test_jit_cuda_fuser', # OOM on test_issue_1785, also profiling?
'test_indexing',
'test_fx_backends',
'test_linalg',
'test_cpp_extensions_jit',
'test_torch',
'test_tensor_creation_ops',
'test_sparse_csr',
'test_dispatch',
'nn/test_pooling',
'distributions/test_distributions',
'test_autograd', # slow gradcheck runs a test that checks the cuda memory allocator
'test_prims', # slow gradcheck runs a test that checks the cuda memory allocator
'test_modules', # failed test due to mismatched elements
]
# A subset of our TEST list that validates PyTorch's ops, modules, and autograd function as expected
CORE_TEST_LIST = [
"test_autograd",
@ -340,6 +363,7 @@ def discover_functorch_tests():
assert len(result) >= 8
return result
FUNCTORCH_TESTS = discover_functorch_tests()
TESTS_REQUIRING_LAPACK = [
@ -374,7 +398,7 @@ def run_test(
launcher_cmd=None,
extra_unittest_args=None,
env=None,
):
) -> int:
unittest_args = options.additional_unittest_args.copy()
if options.verbose:
unittest_args.append(f'-{"v"*options.verbose}') # in case of pytest
@ -402,9 +426,16 @@ def run_test(
# in `if __name__ == '__main__': `. So call `python test_*.py` instead.
argv = [test_module + ".py"] + unittest_args
log_fd, log_path = tempfile.mkstemp(dir=REPO_ROOT / "test" / "test-reports",
prefix=test_module.replace("\\", "-").replace("/", "-"))
os.close(log_fd)
command = (launcher_cmd or []) + executable + argv
print_to_stderr("Executing {} ... [{}]".format(command, datetime.now()))
return shell(command, test_directory, env=env)
with open(log_path, "w") as f:
ret_code = shell(command, test_directory, stdout=f, stderr=f, env=env)
print_log_file(test_module, log_path)
os.remove(log_path)
return ret_code
def test_cuda_primary_ctx(test_module, test_directory, options):
@ -676,6 +707,49 @@ def run_doctests(test_module, test_directory, options):
return result
def print_log_file(test: str, file_path: str) -> None:
with open(file_path, "r") as f:
print_to_stderr("")
print_to_stderr(f"PRINT LOG FILE of {test} ({file_path})")
print_to_stderr(f"##[group]PRINT LOG FILE of {test} ({file_path})")
print_to_stderr(f.read())
print_to_stderr("##[endgroup]")
print_to_stderr(f"FINISHED PRINT LOG FILE of {test} ({file_path})")
print_to_stderr("")
def run_test_ops(test_module, test_directory, options):
if 'slow-gradcheck' in os.getenv("BUILD_ENVIRONMENT", ""):
# there are a lot of tests that take up a lot of space in slow gradcheck, so don't bother parallelizing
# it's also on periodic so we don't care about TTS as much
return run_test(test_module, test_directory, copy.deepcopy(options),
extra_unittest_args=["--use-pytest", '-vv', '-x', '--reruns=2', '-rfEX'],
)
return_codes = []
os.environ["PARALLEL_TESTING"] = "1"
pool = Pool(NUM_PROCS)
for i in range(NUM_PROCS):
return_code = pool.apply_async(run_test, args=(test_module, test_directory, copy.deepcopy(options)),
kwds={"extra_unittest_args": ["--use-pytest", '-vv', '-x', '--reruns=2', '-rfEX',
f'--shard-id={i}', f'--num-shards={NUM_PROCS}',
"-k=not _linalg_cholesky_"],
})
return_codes.append(return_code)
pool.close()
pool.join()
del os.environ['PARALLEL_TESTING']
for return_code in return_codes:
if return_code.get() != 0:
return return_code.get()
return_code = run_test(test_module, test_directory, copy.deepcopy(options),
extra_unittest_args=["--use-pytest", '-vv', '-x', '--reruns=2', '-rfEX',
"-k=_linalg_cholesky_"],
)
return return_code
CUSTOM_HANDLERS = {
"test_cuda_primary_ctx": test_cuda_primary_ctx,
"test_cuda_trace": get_run_test_with_subprocess_fn(),
@ -695,6 +769,9 @@ CUSTOM_HANDLERS = {
"distributed/rpc/test_share_memory": get_run_test_with_subprocess_fn(),
"distributed/rpc/cuda/test_tensorpipe_agent": get_run_test_with_subprocess_fn(),
"doctests": run_doctests,
"test_ops": run_test_ops,
"test_ops_gradients": run_test_ops,
"test_ops_jit": run_test_ops,
}
@ -911,6 +988,18 @@ def exclude_tests(exclude_list, selected_tests, exclude_message=None):
return selected_tests
def must_serial(file: str) -> bool:
return (
"distributed" in os.getenv("TEST_CONFIG", "") or
"functorch" in os.getenv("TEST_CONFIG", "") or
"dynamo" in os.getenv("TEST_CONFIG", "") or
"distributed" in file or
file in CUSTOM_HANDLERS or
file in RUN_PARALLEL_BLOCKLIST or
file in CI_SERIAL_LIST
)
def get_selected_tests(options):
selected_tests = options.include
@ -1010,11 +1099,12 @@ def get_selected_tests(options):
print(
"::warning:: Gathered no stats from artifacts. Proceeding with default sharding plan."
)
selected_tests = selected_tests[which_shard - 1 :: num_shards]
selected_tests = selected_tests[which_shard - 1:: num_shards]
else:
print("Found test time stats from artifacts")
test_file_times_config = test_file_times[test_config]
shards = calculate_shards(num_shards, selected_tests, test_file_times_config)
shards = calculate_shards(num_shards, selected_tests, test_file_times_config,
must_serial=must_serial)
_, tests_from_shard = shards[which_shard - 1]
selected_tests = tests_from_shard
@ -1040,7 +1130,7 @@ def run_test_module(test: str, test_directory: str, options) -> Optional[str]:
return_code = handler(test_module, test_directory, options)
assert isinstance(return_code, int) and not isinstance(
return_code, bool
), "Return code should be an integer"
), f"While running {test} got non integer return code {return_code}"
if return_code == 0:
return None
@ -1073,22 +1163,52 @@ def main():
# downloading test cases configuration to local environment
get_test_case_configs(dirpath=test_directory)
has_failed = False
failure_messages = []
selected_tests_parallel = [x for x in selected_tests if not must_serial(x)]
selected_tests_serial = [x for x in selected_tests if x not in selected_tests_parallel]
print_to_stderr("parallel tests:\n {}".format("\n ".join(selected_tests_parallel)))
print_to_stderr("serial tests:\n {}".format("\n ".join(selected_tests_serial)))
pool = get_context("spawn").Pool(NUM_PROCS, maxtasksperchild=1)
os.makedirs(REPO_ROOT / "test" / "test-reports", exist_ok=True)
def success_callback(err_message):
if err_message is None:
return True
failure_messages.append(err_message)
print_to_stderr(err_message)
if not options.continue_through_error:
pool.terminate()
return False
try:
for test in selected_tests:
os.environ['PARALLEL_TESTING'] = '1'
for test in selected_tests_parallel:
pool.apply_async(run_test_module, args=(test, test_directory,
copy.deepcopy(options)), callback=success_callback)
pool.close()
pool.join()
del os.environ['PARALLEL_TESTING']
if not options.continue_through_error and len(failure_messages) != 0:
raise RuntimeError("\n".join(failure_messages))
for test in selected_tests_serial:
options_clone = copy.deepcopy(options)
if test in USE_PYTEST_LIST:
options_clone.pytest = True
err_message = run_test_module(test, test_directory, options_clone)
if err_message is None:
continue
has_failed = True
failure_messages.append(err_message)
if not options_clone.continue_through_error:
raise RuntimeError(err_message)
print_to_stderr(err_message)
finally:
pool.terminate()
pool.join()
if options.coverage:
from coverage import Coverage
@ -1101,7 +1221,7 @@ def main():
if not PYTORCH_COLLECT_COVERAGE:
cov.html_report()
if options.continue_through_error and has_failed:
if len(failure_messages) != 0:
for err in failure_messages:
print_to_stderr(err)
sys.exit(1)

View File

@ -1,15 +1,24 @@
import os
import subprocess
from typing import Dict, List, Tuple
from typing import Callable, Dict, List, Optional, Tuple
from tools.stats.import_test_stats import get_disabled_tests, get_slow_tests
# mac has 3 CPUs and also received the best speedup with 3 processes. Setting this any larger
# will also force us to further restrict the amount of memory per process for cuda
NUM_PROCS = 3
def calculate_shards(
num_shards: int, tests: List[str], job_times: Dict[str, float]
num_shards: int,
tests: List[str],
job_times: Dict[str, float],
must_serial: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[float, List[str]]]:
filtered_job_times: Dict[str, float] = {}
must_serial = must_serial if callable(must_serial) else lambda x: True
filtered_job_times: Dict[str, float] = dict()
unknown_jobs: List[str] = []
for test in tests:
if test in job_times:
@ -17,18 +26,30 @@ def calculate_shards(
else:
unknown_jobs.append(test)
# The following attempts to implement a partition approximation greedy algorithm
# See more at https://en.wikipedia.org/wiki/Greedy_number_partitioning
sorted_jobs = sorted(
filtered_job_times, key=lambda j: filtered_job_times[j], reverse=True
)
sharded_jobs: List[Tuple[float, List[str]]] = [(0.0, []) for _ in range(num_shards)]
for job in sorted_jobs:
min_shard_index = sorted(range(num_shards), key=lambda i: sharded_jobs[i][0])[0]
serial = [x for x in sorted_jobs if must_serial(x)]
parallel = [x for x in sorted_jobs if x not in serial]
for i in range(0, len(serial)):
min_shard_index = sorted(range(num_shards), key=lambda j: sharded_jobs[j][0])[0]
curr_shard_time, curr_shard_jobs = sharded_jobs[min_shard_index]
curr_shard_jobs.append(job)
curr_shard_jobs.append(serial[i])
sharded_jobs[min_shard_index] = (
curr_shard_time + filtered_job_times[job],
curr_shard_time + filtered_job_times[serial[i]],
curr_shard_jobs,
)
# Not the best idea, but attempt to mask the long jobs with other long jobs
for i in range(0, len(parallel), NUM_PROCS):
min_shard_index = sorted(range(num_shards), key=lambda j: sharded_jobs[j][0])[0]
curr_shard_time, curr_shard_jobs = sharded_jobs[min_shard_index]
curr_shard_jobs.extend(parallel[i : i + NUM_PROCS])
sharded_jobs[min_shard_index] = (
curr_shard_time + filtered_job_times[parallel[i]],
curr_shard_jobs,
)
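A minimal usage sketch of the updated helper follows (file names and times are illustrative; the import path is assumed from the NUM_PROCS import used elsewhere in this PR):

# Serial files are spread greedily by their own runtime; parallel files are
# placed NUM_PROCS at a time and only the first (longest) file in each group
# is counted, since the others run concurrently and are masked behind it.
from tools.testing.test_selections import calculate_shards

job_times = {"test_a": 60.0, "test_b": 50.0, "test_c": 40.0, "test_d": 30.0}
shards = calculate_shards(
    num_shards=2,
    tests=list(job_times.keys()),
    job_times=job_times,
    must_serial=lambda name: name == "test_a",  # pretend only test_a must run serially
)
for estimated_time, files in shards:
    print(estimated_time, files)
# -> 60.0 ['test_a'] and 50.0 ['test_b', 'test_c', 'test_d'] with this input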

View File

@ -95,8 +95,6 @@ from .composite_compliance import no_dispatch
torch.backends.disable_global_flags()
PYTEST_FILES = ["test_ops", "test_ops_gradients", "test_ops_jit"]
FILE_SCHEMA = "file://"
if sys.platform == 'win32':
FILE_SCHEMA = "file:///"
@ -498,6 +496,7 @@ parser.add_argument('--accept', action='store_true')
parser.add_argument('--jit_executor', type=str)
parser.add_argument('--repeat', type=int, default=1)
parser.add_argument('--test_bailouts', action='store_true')
parser.add_argument('--use-pytest', action='store_true')
parser.add_argument('--save-xml', nargs='?', type=str,
const=_get_test_report_path(),
default=_get_test_report_path() if IS_CI else None)
@ -533,6 +532,7 @@ DISABLED_TESTS_FILE = args.import_disabled_tests
LOG_SUFFIX = args.log_suffix
RUN_PARALLEL = args.run_parallel
TEST_BAILOUTS = args.test_bailouts
USE_PYTEST = args.use_pytest
TEST_DISCOVER = args.discover_tests
TEST_IN_SUBPROCESS = args.subprocess
TEST_SAVE_XML = args.save_xml
@ -567,7 +567,7 @@ def wait_for_process(p):
# Always call p.wait() to ensure exit
p.wait()
def shell(command, cwd=None, env=None):
def shell(command, cwd=None, env=None, stdout=None, stderr=None):
sys.stdout.flush()
sys.stderr.flush()
# The following cool snippet is copied from Py3 core library subprocess.call
@ -578,7 +578,7 @@ def shell(command, cwd=None, env=None):
#
# https://github.com/python/cpython/blob/71b6c1af727fbe13525fb734568057d78cea33f3/Lib/subprocess.py#L309-L323
assert not isinstance(command, torch._six.string_classes), "Command to shell should be a list or tuple of tokens"
p = subprocess.Popen(command, universal_newlines=True, cwd=cwd, env=env)
p = subprocess.Popen(command, universal_newlines=True, cwd=cwd, env=env, stdout=stdout, stderr=stderr)
return wait_for_process(p)
@ -638,6 +638,22 @@ def lint_test_case_extension(suite):
succeed = False
return succeed
def get_report_path(pytest=False):
test_filename = inspect.getfile(sys._getframe(2))
test_filename = sanitize_if_functorch_test_filename(test_filename)
test_filename = sanitize_test_filename(test_filename)
test_report_path = TEST_SAVE_XML + LOG_SUFFIX
test_report_path = os.path.join(test_report_path, test_filename)
if pytest:
test_report_path = test_report_path.replace('python-unittest', 'python-pytest')
os.makedirs(test_report_path, exist_ok=True)
test_report_path = os.path.join(test_report_path, f"{test_filename}-{os.urandom(8).hex()}.xml")
return test_report_path
os.makedirs(test_report_path, exist_ok=True)
return test_report_path
def sanitize_pytest_xml(xml_file: str):
# pytest xml is different from unittest xml; this function makes pytest xml more similar to unittest xml
# consider somehow modifying the XML logger in conftest to do this instead
@ -718,6 +734,22 @@ def run_tests(argv=UNITTEST_ARGS):
for p in processes:
failed |= wait_for_process(p) != 0
assert not failed, "Some test shards have failed"
elif USE_PYTEST:
if TEST_SAVE_XML:
test_report_path = get_report_path(pytest=True)
print(f'Test results will be stored in {test_report_path}')
import pytest
os.environ["NO_COLOR"] = "1"
os.environ["USING_PYTEST"] = "1"
exit_code = pytest.main(args=argv + [f'--junit-xml-reruns={test_report_path}'] if TEST_SAVE_XML else [])
del os.environ["USING_PYTEST"]
if TEST_SAVE_XML:
sanitize_pytest_xml(test_report_path)
print("If in CI, skip info is located in the xml test reports, please either go to s3 or the hud to download them")
# exitcode of 5 means no tests were found, which happens since some test configs don't
# run tests from certain files
exit(0 if exit_code == 5 else exit_code)
elif TEST_SAVE_XML is not None:
# import here so that non-CI doesn't need xmlrunner installed
import xmlrunner # type: ignore[import]
@ -744,46 +776,14 @@ def run_tests(argv=UNITTEST_ARGS):
# it stands for `verbose_str` captured in the closure
c.cell_contents = f"skip: {reason}"
test_filename = inspect.getfile(sys._getframe(1))
test_filename = sanitize_if_functorch_test_filename(test_filename)
test_filename = sanitize_test_filename(test_filename)
test_report_path = TEST_SAVE_XML + LOG_SUFFIX
test_report_path = os.path.join(test_report_path, test_filename)
build_environment = os.environ.get("BUILD_ENVIRONMENT", "")
if test_filename in PYTEST_FILES and not IS_SANDCASTLE and not (
"cuda" in build_environment and "linux" in build_environment
):
# exclude linux cuda tests because we run into memory issues when running in parallel
import pytest
os.environ["NO_COLOR"] = "1"
os.environ["USING_PYTEST"] = "1"
pytest_report_path = test_report_path.replace('python-unittest', 'python-pytest')
os.makedirs(pytest_report_path, exist_ok=True)
# part of our xml parsing looks for grandparent folder names
pytest_report_path = os.path.join(pytest_report_path, f"{test_filename}.xml")
print(f'Test results will be stored in {pytest_report_path}')
# mac slower on 4 proc than 3
num_procs = 3 if "macos" in build_environment else 4
# f = failed
# E = error
# X = unexpected success
exit_code = pytest.main(args=[inspect.getfile(sys._getframe(1)), f'-n={num_procs}', '-vv', '-x',
'--reruns=2', '-rfEX', f'--junit-xml-reruns={pytest_report_path}'])
del os.environ["USING_PYTEST"]
sanitize_pytest_xml(f'{pytest_report_path}')
print("Skip info is located in the xml test reports, please either go to s3 or the hud to download them")
# exitcode of 5 means no tests were found, which happens since some test configs don't
# run tests from certain files
exit(0 if exit_code == 5 else exit_code)
else:
os.makedirs(test_report_path, exist_ok=True)
verbose = '--verbose' in argv or '-v' in argv
if verbose:
print(f'Test results will be stored in {test_report_path}')
unittest.main(argv=argv, testRunner=xmlrunner.XMLTestRunner(
output=test_report_path,
verbosity=2 if verbose else 1,
resultclass=XMLTestResultVerbose))
test_report_path = get_report_path()
verbose = '--verbose' in argv or '-v' in argv
if verbose:
print(f'Test results will be stored in {test_report_path}')
unittest.main(argv=argv, testRunner=xmlrunner.XMLTestRunner(
output=test_report_path,
verbosity=2 if verbose else 1,
resultclass=XMLTestResultVerbose))
elif REPEAT_COUNT > 1:
for _ in range(REPEAT_COUNT):
if not unittest.main(exit=False, argv=argv).result.wasSuccessful():
@ -904,6 +904,13 @@ TEST_SKIP_FAST = os.getenv('PYTORCH_TEST_SKIP_FAST', '0') == '1'
# as we had before. By default, we don't run these tests.
TEST_WITH_CROSSREF = os.getenv('PYTORCH_TEST_WITH_CROSSREF', '0') == '1'
if TEST_CUDA and 'NUM_PARALLEL_PROCS' in os.environ:
from tools.testing.test_selections import NUM_PROCS
# other libraries take up about 11% of GPU memory per process
torch.cuda.set_per_process_memory_fraction(round(1 / NUM_PROCS - .11, 2))
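# Worked out (editor's note): with NUM_PROCS == 3 the line above evaluates to
# round(1 / 3 - .11, 2) == 0.22, i.e. each of the three parallel processes is capped
# at roughly 22% of GPU memory, leaving headroom for the ~11% per-process overhead noted above.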
def skipIfCrossRef(fn):
@wraps(fn)
def wrapper(*args, **kwargs):