mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Move all downloading logic out of common_utils.py (#61479)
Summary: Moves all downloading logic out of common_utils.py and into the tools/ folder. Currently run_tests.py invokes tools/test_selections.py to: 1. download and analyze which test files to run; 2. download and parse S3 stats and pass the info to local files. 3. common_utils.py then uses the downloaded S3 stats to determine which test cases to run. Pull Request resolved: https://github.com/pytorch/pytorch/pull/61479 Reviewed By: janeyx99 Differential Revision: D29661986 Pulled By: walterddr fbshipit-source-id: bebd8c474bcc2444e135bfd2fa4bdd1eefafe595
This commit is contained in:
parent
2aedd17661
commit
a5a10fe353
3
.gitignore
vendored
3
.gitignore
vendored
|
|
@ -16,8 +16,9 @@ coverage.xml
|
||||||
.mypy_cache
|
.mypy_cache
|
||||||
/.extracted_scripts/
|
/.extracted_scripts/
|
||||||
**/.pytorch_specified_test_cases.csv
|
**/.pytorch_specified_test_cases.csv
|
||||||
**/.pytorch-test-times.json
|
**/.pytorch-disabled-tests.json
|
||||||
**/.pytorch-slow-tests.json
|
**/.pytorch-slow-tests.json
|
||||||
|
**/.pytorch-test-times.json
|
||||||
*/*.pyc
|
*/*.pyc
|
||||||
*/*.so*
|
*/*.so*
|
||||||
*/**/__pycache__
|
*/**/__pycache__
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,8 @@ try:
|
||||||
get_shard_based_on_S3,
|
get_shard_based_on_S3,
|
||||||
get_slow_tests_based_on_S3,
|
get_slow_tests_based_on_S3,
|
||||||
get_specified_test_cases,
|
get_specified_test_cases,
|
||||||
get_reordered_tests
|
get_reordered_tests,
|
||||||
|
get_test_case_configs,
|
||||||
)
|
)
|
||||||
HAVE_TEST_SELECTION_TOOLS = True
|
HAVE_TEST_SELECTION_TOOLS = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
|
@ -451,6 +452,9 @@ def run_test(test_module, test_directory, options, launcher_cmd=None, extra_unit
|
||||||
# If using pytest, replace -f with equivalent -x
|
# If using pytest, replace -f with equivalent -x
|
||||||
if options.pytest:
|
if options.pytest:
|
||||||
unittest_args = [arg if arg != '-f' else '-x' for arg in unittest_args]
|
unittest_args = [arg if arg != '-f' else '-x' for arg in unittest_args]
|
||||||
|
elif IS_IN_CI:
|
||||||
|
# use the downloaded test cases configuration, not supported in pytest
|
||||||
|
unittest_args.extend(['--import-slow-tests', '--import-disabled-tests'])
|
||||||
|
|
||||||
# Multiprocessing related tests cannot run with coverage.
|
# Multiprocessing related tests cannot run with coverage.
|
||||||
# Tracking issue: https://github.com/pytorch/pytorch/issues/50661
|
# Tracking issue: https://github.com/pytorch/pytorch/issues/50661
|
||||||
|
|
@ -1044,6 +1048,8 @@ def main():
|
||||||
|
|
||||||
if IS_IN_CI:
|
if IS_IN_CI:
|
||||||
selected_tests = get_reordered_tests(selected_tests, ENABLE_PR_HISTORY_REORDERING)
|
selected_tests = get_reordered_tests(selected_tests, ENABLE_PR_HISTORY_REORDERING)
|
||||||
|
# downloading test cases configuration to local environment
|
||||||
|
get_test_case_configs(dirpath=os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
has_failed = False
|
has_failed = False
|
||||||
failure_messages = []
|
failure_messages = []
|
||||||
|
|
|
||||||
83
tools/stats/import_test_stats.py
Normal file
83
tools/stats/import_test_stats.py
Normal file
|
|
@ -0,0 +1,83 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import pathlib
|
||||||
|
import re
|
||||||
|
from typing import Any, Callable, Dict, Optional, cast
|
||||||
|
from urllib.request import urlopen
|
||||||
|
|
||||||
|
SLOW_TESTS_FILE = '.pytorch-slow-tests.json'
|
||||||
|
DISABLED_TESTS_FILE = '.pytorch-disabled-tests.json'
|
||||||
|
|
||||||
|
FILE_CACHE_LIFESPAN_SECONDS = datetime.timedelta(hours=3).seconds
|
||||||
|
|
||||||
|
def fetch_and_cache(
|
||||||
|
dirpath: str,
|
||||||
|
name: str,
|
||||||
|
url: str,
|
||||||
|
process_fn: Callable[[Dict[str, Any]], Dict[str, Any]]
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
This fetch and cache utils allows sharing between different process.
|
||||||
|
"""
|
||||||
|
path = os.path.join(dirpath, name)
|
||||||
|
|
||||||
|
def is_cached_file_valid() -> bool:
|
||||||
|
# Check if the file is new enough (see: FILE_CACHE_LIFESPAN_SECONDS). A real check
|
||||||
|
# could make a HEAD request and check/store the file's ETag
|
||||||
|
fname = pathlib.Path(path)
|
||||||
|
now = datetime.datetime.now()
|
||||||
|
mtime = datetime.datetime.fromtimestamp(fname.stat().st_mtime)
|
||||||
|
diff = now - mtime
|
||||||
|
return diff.total_seconds() < FILE_CACHE_LIFESPAN_SECONDS
|
||||||
|
|
||||||
|
if os.path.exists(path) and is_cached_file_valid():
|
||||||
|
# Another test process already downloaded the file, so don't re-do it
|
||||||
|
with open(path, "r") as f:
|
||||||
|
return cast(Dict[str, Any], json.load(f))
|
||||||
|
try:
|
||||||
|
contents = urlopen(url, timeout=1).read().decode('utf-8')
|
||||||
|
processed_contents = process_fn(json.loads(contents))
|
||||||
|
with open(path, "w") as f:
|
||||||
|
f.write(json.dumps(processed_contents))
|
||||||
|
return processed_contents
|
||||||
|
except Exception as e:
|
||||||
|
print(f'Could not download {url} because of error {e}.')
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def get_slow_tests(dirpath: str, filename: str = SLOW_TESTS_FILE) -> Optional[Dict[str, float]]:
|
||||||
|
url = "https://raw.githubusercontent.com/pytorch/test-infra/master/stats/slow-tests.json"
|
||||||
|
try:
|
||||||
|
return fetch_and_cache(dirpath, filename, url, lambda x: x)
|
||||||
|
except Exception:
|
||||||
|
print("Couldn't download slow test set, leaving all tests enabled...")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def get_disabled_tests(dirpath: str, filename: str = DISABLED_TESTS_FILE) -> Optional[Dict[str, Any]]:
|
||||||
|
def process_disabled_test(the_response: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
disabled_test_from_issues = dict()
|
||||||
|
for item in the_response['items']:
|
||||||
|
title = item['title']
|
||||||
|
key = 'DISABLED '
|
||||||
|
if title.startswith(key):
|
||||||
|
test_name = title[len(key):].strip()
|
||||||
|
body = item['body']
|
||||||
|
platforms_to_skip = []
|
||||||
|
key = 'platforms:'
|
||||||
|
for line in body.splitlines():
|
||||||
|
line = line.lower()
|
||||||
|
if line.startswith(key):
|
||||||
|
pattern = re.compile(r"^\s+|\s*,\s*|\s+$")
|
||||||
|
platforms_to_skip.extend([x for x in pattern.split(line[len(key):]) if x])
|
||||||
|
disabled_test_from_issues[test_name] = (item['html_url'], platforms_to_skip)
|
||||||
|
return disabled_test_from_issues
|
||||||
|
try:
|
||||||
|
url = 'https://raw.githubusercontent.com/pytorch/test-infra/master/stats/disabled-tests.json'
|
||||||
|
return fetch_and_cache(dirpath, filename, url, process_disabled_test)
|
||||||
|
except Exception:
|
||||||
|
print("Couldn't download test skip set, leaving all tests enabled...")
|
||||||
|
return {}
|
||||||
|
|
@ -8,6 +8,10 @@ from tools.stats.s3_stat_parser import (
|
||||||
get_previous_reports_for_pr,
|
get_previous_reports_for_pr,
|
||||||
Report, Version2Report,
|
Report, Version2Report,
|
||||||
HAVE_BOTO3)
|
HAVE_BOTO3)
|
||||||
|
from tools.stats.import_test_stats import (
|
||||||
|
get_disabled_tests,
|
||||||
|
get_slow_tests
|
||||||
|
)
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple, cast
|
from typing import Any, Dict, List, Optional, Tuple, cast
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
|
|
@ -284,3 +288,8 @@ def export_S3_test_times(test_times_filename: Optional[str] = None) -> Dict[str,
|
||||||
json.dump(job_times_json, file, indent=' ', separators=(',', ': '))
|
json.dump(job_times_json, file, indent=' ', separators=(',', ': '))
|
||||||
file.write('\n')
|
file.write('\n')
|
||||||
return test_times
|
return test_times
|
||||||
|
|
||||||
|
|
||||||
|
def get_test_case_configs(dirpath: str) -> None:
|
||||||
|
get_slow_tests(dirpath=dirpath)
|
||||||
|
get_disabled_tests(dirpath=dirpath)
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,6 @@ import warnings
|
||||||
import random
|
import random
|
||||||
import contextlib
|
import contextlib
|
||||||
import shutil
|
import shutil
|
||||||
import datetime
|
|
||||||
import pathlib
|
import pathlib
|
||||||
import socket
|
import socket
|
||||||
import subprocess
|
import subprocess
|
||||||
|
|
@ -37,7 +36,6 @@ from copy import deepcopy
|
||||||
from numbers import Number
|
from numbers import Number
|
||||||
import tempfile
|
import tempfile
|
||||||
import json
|
import json
|
||||||
from urllib.request import urlopen
|
|
||||||
import __main__ # type: ignore[import]
|
import __main__ # type: ignore[import]
|
||||||
import errno
|
import errno
|
||||||
from typing import cast, Any, Dict, Iterable, Iterator, Optional, Union
|
from typing import cast, Any, Dict, Iterable, Iterator, Optional, Union
|
||||||
|
|
@ -69,6 +67,12 @@ IS_SANDCASTLE = os.getenv('SANDCASTLE') == '1' or os.getenv('TW_JOB_USER') == 's
|
||||||
IS_FBCODE = os.getenv('PYTORCH_TEST_FBCODE') == '1'
|
IS_FBCODE = os.getenv('PYTORCH_TEST_FBCODE') == '1'
|
||||||
IS_REMOTE_GPU = os.getenv('PYTORCH_TEST_REMOTE_GPU') == '1'
|
IS_REMOTE_GPU = os.getenv('PYTORCH_TEST_REMOTE_GPU') == '1'
|
||||||
|
|
||||||
|
DISABLED_TESTS_FILE = '.pytorch-disabled-tests.json'
|
||||||
|
SLOW_TESTS_FILE = '.pytorch-slow-tests.json'
|
||||||
|
|
||||||
|
slow_tests_dict: Optional[Dict[str, Any]] = None
|
||||||
|
disabled_tests_dict: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
class ProfilingMode(Enum):
|
class ProfilingMode(Enum):
|
||||||
LEGACY = 1
|
LEGACY = 1
|
||||||
SIMPLE = 2
|
SIMPLE = 2
|
||||||
|
|
@ -164,6 +168,8 @@ parser.add_argument('--save-xml', nargs='?', type=str,
|
||||||
parser.add_argument('--discover-tests', action='store_true')
|
parser.add_argument('--discover-tests', action='store_true')
|
||||||
parser.add_argument('--log-suffix', type=str, default="")
|
parser.add_argument('--log-suffix', type=str, default="")
|
||||||
parser.add_argument('--run-parallel', type=int, default=1)
|
parser.add_argument('--run-parallel', type=int, default=1)
|
||||||
|
parser.add_argument('--import-slow-tests', type=str, nargs='?', const=SLOW_TESTS_FILE)
|
||||||
|
parser.add_argument('--import-disabled-tests', type=str, nargs='?', const=DISABLED_TESTS_FILE)
|
||||||
|
|
||||||
args, remaining = parser.parse_known_args()
|
args, remaining = parser.parse_known_args()
|
||||||
if args.jit_executor == 'legacy':
|
if args.jit_executor == 'legacy':
|
||||||
|
|
@ -177,6 +183,8 @@ else:
|
||||||
GRAPH_EXECUTOR = cppProfilingFlagsToProfilingMode()
|
GRAPH_EXECUTOR = cppProfilingFlagsToProfilingMode()
|
||||||
|
|
||||||
|
|
||||||
|
IMPORT_SLOW_TESTS = args.import_slow_tests
|
||||||
|
IMPORT_DISABLED_TESTS = args.import_disabled_tests
|
||||||
LOG_SUFFIX = args.log_suffix
|
LOG_SUFFIX = args.log_suffix
|
||||||
RUN_PARALLEL = args.run_parallel
|
RUN_PARALLEL = args.run_parallel
|
||||||
TEST_BAILOUTS = args.test_bailouts
|
TEST_BAILOUTS = args.test_bailouts
|
||||||
|
|
@ -263,6 +271,16 @@ def sanitize_test_filename(filename):
|
||||||
return re.sub('/', r'.', strip_py)
|
return re.sub('/', r'.', strip_py)
|
||||||
|
|
||||||
def run_tests(argv=UNITTEST_ARGS):
|
def run_tests(argv=UNITTEST_ARGS):
|
||||||
|
# import test files.
|
||||||
|
if IMPORT_SLOW_TESTS:
|
||||||
|
global slow_tests_dict
|
||||||
|
with open(IMPORT_SLOW_TESTS, 'r') as fp:
|
||||||
|
slow_tests_dict = json.load(fp)
|
||||||
|
if IMPORT_DISABLED_TESTS:
|
||||||
|
global disabled_tests_dict
|
||||||
|
with open(IMPORT_DISABLED_TESTS, 'r') as fp:
|
||||||
|
disabled_tests_dict = json.load(fp)
|
||||||
|
# Determine the test launch mechanism
|
||||||
if TEST_DISCOVER:
|
if TEST_DISCOVER:
|
||||||
suite = unittest.TestLoader().loadTestsFromModule(__main__)
|
suite = unittest.TestLoader().loadTestsFromModule(__main__)
|
||||||
test_cases = discover_test_cases_recursively(suite)
|
test_cases = discover_test_cases_recursively(suite)
|
||||||
|
|
@ -842,93 +860,16 @@ try:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print('Fail to import hypothesis in common_utils, tests are not derandomized')
|
print('Fail to import hypothesis in common_utils, tests are not derandomized')
|
||||||
|
|
||||||
|
def check_if_enable(test: unittest.TestCase):
|
||||||
FILE_CACHE_LIFESPAN_SECONDS = datetime.timedelta(hours=3).seconds
|
|
||||||
|
|
||||||
def fetch_and_cache(name: str, url: str):
|
|
||||||
"""
|
|
||||||
Some tests run in a different process so globals like `slow_test_dict` won't
|
|
||||||
always be filled even though the test file was already downloaded on this
|
|
||||||
machine, so cache it on disk
|
|
||||||
"""
|
|
||||||
path = os.path.join(tempfile.gettempdir(), name)
|
|
||||||
|
|
||||||
def is_cached_file_valid():
|
|
||||||
# Check if the file is new enough (say 1 hour for now). A real check
|
|
||||||
# could make a HEAD request and check/store the file's ETag
|
|
||||||
fname = pathlib.Path(path)
|
|
||||||
now = datetime.datetime.now()
|
|
||||||
mtime = datetime.datetime.fromtimestamp(fname.stat().st_mtime)
|
|
||||||
diff = now - mtime
|
|
||||||
return diff.total_seconds() < FILE_CACHE_LIFESPAN_SECONDS
|
|
||||||
|
|
||||||
if os.path.exists(path) and is_cached_file_valid():
|
|
||||||
# Another test process already downloaded the file, so don't re-do it
|
|
||||||
with open(path, "r") as f:
|
|
||||||
return json.load(f)
|
|
||||||
try:
|
|
||||||
contents = urlopen(url, timeout=1).read().decode('utf-8')
|
|
||||||
with open(path, "w") as f:
|
|
||||||
f.write(contents)
|
|
||||||
return json.loads(contents)
|
|
||||||
except Exception as e:
|
|
||||||
print(f'Could not download {url} because of error {e}.')
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
slow_tests_dict: Optional[Dict[str, float]] = None
|
|
||||||
def check_slow_test_from_stats(test):
|
|
||||||
global slow_tests_dict
|
|
||||||
if slow_tests_dict is None:
|
|
||||||
if not IS_SANDCASTLE:
|
|
||||||
url = "https://raw.githubusercontent.com/pytorch/test-infra/master/stats/slow-tests.json"
|
|
||||||
slow_tests_dict = fetch_and_cache(".pytorch-slow-tests.json", url)
|
|
||||||
else:
|
|
||||||
slow_tests_dict = {}
|
|
||||||
test_suite = str(test.__class__).split('\'')[1]
|
test_suite = str(test.__class__).split('\'')[1]
|
||||||
test_name = f'{test._testMethodName} ({test_suite})'
|
test_name = f'{test._testMethodName} ({test_suite})'
|
||||||
|
if slow_tests_dict is not None and test_name in slow_tests_dict:
|
||||||
if test_name in slow_tests_dict:
|
|
||||||
getattr(test, test._testMethodName).__dict__['slow_test'] = True
|
getattr(test, test._testMethodName).__dict__['slow_test'] = True
|
||||||
if not TEST_WITH_SLOW:
|
if not TEST_WITH_SLOW:
|
||||||
raise unittest.SkipTest("test is slow; run with PYTORCH_TEST_WITH_SLOW to enable test")
|
raise unittest.SkipTest("test is slow; run with PYTORCH_TEST_WITH_SLOW to enable test")
|
||||||
|
if not IS_SANDCASTLE and disabled_tests_dict is not None:
|
||||||
|
if test_name in disabled_tests_dict:
|
||||||
disabled_test_from_issues: Optional[Dict[str, Any]] = None
|
issue_url, platforms = disabled_tests_dict[test_name]
|
||||||
def check_disabled(test_name):
|
|
||||||
global disabled_test_from_issues
|
|
||||||
if disabled_test_from_issues is None:
|
|
||||||
_disabled_test_from_issues: Dict = {}
|
|
||||||
|
|
||||||
def read_and_process():
|
|
||||||
url = 'https://raw.githubusercontent.com/pytorch/test-infra/master/stats/disabled-tests.json'
|
|
||||||
the_response = fetch_and_cache(".pytorch-disabled-tests", url)
|
|
||||||
for item in the_response['items']:
|
|
||||||
title = item['title']
|
|
||||||
key = 'DISABLED '
|
|
||||||
if title.startswith(key):
|
|
||||||
test_name = title[len(key):].strip()
|
|
||||||
body = item['body']
|
|
||||||
platforms_to_skip = []
|
|
||||||
key = 'platforms:'
|
|
||||||
for line in body.splitlines():
|
|
||||||
line = line.lower()
|
|
||||||
if line.startswith(key):
|
|
||||||
pattern = re.compile(r"^\s+|\s*,\s*|\s+$")
|
|
||||||
platforms_to_skip.extend([x for x in pattern.split(line[len(key):]) if x])
|
|
||||||
_disabled_test_from_issues[test_name] = (item['html_url'], platforms_to_skip)
|
|
||||||
|
|
||||||
if not IS_SANDCASTLE and os.getenv("PYTORCH_RUN_DISABLED_TESTS", "0") != "1":
|
|
||||||
try:
|
|
||||||
read_and_process()
|
|
||||||
disabled_test_from_issues = _disabled_test_from_issues
|
|
||||||
except Exception:
|
|
||||||
print("Couldn't download test skip set, leaving all tests enabled...")
|
|
||||||
disabled_test_from_issues = {}
|
|
||||||
|
|
||||||
if disabled_test_from_issues is not None:
|
|
||||||
if test_name in disabled_test_from_issues:
|
|
||||||
issue_url, platforms = disabled_test_from_issues[test_name]
|
|
||||||
platform_to_conditional: Dict = {
|
platform_to_conditional: Dict = {
|
||||||
"mac": IS_MACOS,
|
"mac": IS_MACOS,
|
||||||
"macos": IS_MACOS,
|
"macos": IS_MACOS,
|
||||||
|
|
@ -940,7 +881,9 @@ def check_disabled(test_name):
|
||||||
f"Test is disabled because an issue exists disabling it: {issue_url}" +
|
f"Test is disabled because an issue exists disabling it: {issue_url}" +
|
||||||
f" for {'all' if platforms == [] else ''}platform(s) {', '.join(platforms)}." +
|
f" for {'all' if platforms == [] else ''}platform(s) {', '.join(platforms)}." +
|
||||||
" To enable, set the environment variable PYTORCH_RUN_DISABLED_TESTS=1")
|
" To enable, set the environment variable PYTORCH_RUN_DISABLED_TESTS=1")
|
||||||
|
if TEST_SKIP_FAST:
|
||||||
|
if not getattr(test, test._testMethodName).__dict__.get('slow_test', False):
|
||||||
|
raise unittest.SkipTest("test is fast; we disabled it with PYTORCH_TEST_SKIP_FAST")
|
||||||
|
|
||||||
# Acquires the comparison dtype, required since isclose
|
# Acquires the comparison dtype, required since isclose
|
||||||
# requires both inputs have the same dtype, and isclose is not supported
|
# requires both inputs have the same dtype, and isclose is not supported
|
||||||
|
|
@ -1106,12 +1049,7 @@ class TestCase(expecttest.TestCase):
|
||||||
result.stop()
|
result.stop()
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
check_slow_test_from_stats(self)
|
check_if_enable(self)
|
||||||
if TEST_SKIP_FAST:
|
|
||||||
if not getattr(self, self._testMethodName).__dict__.get('slow_test', False):
|
|
||||||
raise unittest.SkipTest("test is fast; we disabled it with PYTORCH_TEST_SKIP_FAST")
|
|
||||||
check_disabled(str(self))
|
|
||||||
|
|
||||||
set_rng_seed(SEED)
|
set_rng_seed(SEED)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user