mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Move all downloading logic out of common_utils.py (#61479)
Summary: Move all downloading logic out of common_utils.py and into the tools/ folder. Currently run_tests.py invokes tools/test_selections.py to: 1. download and analyze which test files to run; 2. download and parse the S3 stats and pass the info to local files. 3. common_utils.py then uses the downloaded S3 stats to determine which test cases to run. Pull Request resolved: https://github.com/pytorch/pytorch/pull/61479 Reviewed By: janeyx99 Differential Revision: D29661986 Pulled By: walterddr fbshipit-source-id: bebd8c474bcc2444e135bfd2fa4bdd1eefafe595
This commit is contained in:
parent
2aedd17661
commit
a5a10fe353
3
.gitignore
vendored
3
.gitignore
vendored
|
|
@ -16,8 +16,9 @@ coverage.xml
|
|||
.mypy_cache
|
||||
/.extracted_scripts/
|
||||
**/.pytorch_specified_test_cases.csv
|
||||
**/.pytorch-test-times.json
|
||||
**/.pytorch-disabled-tests.json
|
||||
**/.pytorch-slow-tests.json
|
||||
**/.pytorch-test-times.json
|
||||
*/*.pyc
|
||||
*/*.so*
|
||||
*/**/__pycache__
|
||||
|
|
|
|||
|
|
@ -25,7 +25,8 @@ try:
|
|||
get_shard_based_on_S3,
|
||||
get_slow_tests_based_on_S3,
|
||||
get_specified_test_cases,
|
||||
get_reordered_tests
|
||||
get_reordered_tests,
|
||||
get_test_case_configs,
|
||||
)
|
||||
HAVE_TEST_SELECTION_TOOLS = True
|
||||
except ImportError:
|
||||
|
|
@ -451,6 +452,9 @@ def run_test(test_module, test_directory, options, launcher_cmd=None, extra_unit
|
|||
# If using pytest, replace -f with equivalent -x
|
||||
if options.pytest:
|
||||
unittest_args = [arg if arg != '-f' else '-x' for arg in unittest_args]
|
||||
elif IS_IN_CI:
|
||||
# use the downloaded test cases configuration, not supported in pytest
|
||||
unittest_args.extend(['--import-slow-tests', '--import-disabled-tests'])
|
||||
|
||||
# Multiprocessing related tests cannot run with coverage.
|
||||
# Tracking issue: https://github.com/pytorch/pytorch/issues/50661
|
||||
|
|
@ -1044,6 +1048,8 @@ def main():
|
|||
|
||||
if IS_IN_CI:
|
||||
selected_tests = get_reordered_tests(selected_tests, ENABLE_PR_HISTORY_REORDERING)
|
||||
# downloading test cases configuration to local environment
|
||||
get_test_case_configs(dirpath=os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
has_failed = False
|
||||
failure_messages = []
|
||||
|
|
|
|||
83
tools/stats/import_test_stats.py
Normal file
83
tools/stats/import_test_stats.py
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import pathlib
|
||||
import re
|
||||
from typing import Any, Callable, Dict, Optional, cast
|
||||
from urllib.request import urlopen
|
||||
|
||||
# Default file names used for the on-disk caches of the downloaded test stats.
SLOW_TESTS_FILE = '.pytorch-slow-tests.json'
DISABLED_TESTS_FILE = '.pytorch-disabled-tests.json'

# Maximum age, in seconds, of a cache file before it is downloaded again.
# total_seconds() is used instead of .seconds because .seconds only holds the
# seconds component of the timedelta and wraps around for spans of a day or more.
FILE_CACHE_LIFESPAN_SECONDS = int(datetime.timedelta(hours=3).total_seconds())
|
||||
|
||||
def fetch_and_cache(
    dirpath: str,
    name: str,
    url: str,
    process_fn: Callable[[Dict[str, Any]], Dict[str, Any]]
) -> Dict[str, Any]:
    """
    Return the processed JSON document found at ``url``, caching it on disk.

    The cache file lives at ``os.path.join(dirpath, name)`` so that different
    processes can share a single download.

    Args:
        dirpath: Directory that holds the cache file.
        name: File name of the cache entry inside ``dirpath``.
        url: Location of the JSON document to download.
        process_fn: Transformation applied to the parsed JSON; its result is
            what gets written to the cache and returned.

    Returns:
        The cached contents when the cache file is younger than
        FILE_CACHE_LIFESPAN_SECONDS and can be parsed; otherwise the freshly
        downloaded and processed contents. An empty dict is returned when the
        download, the processing, or the cache write fails.
    """
    path = os.path.join(dirpath, name)

    def is_cached_file_valid() -> bool:
        # Check if the file is new enough (see: FILE_CACHE_LIFESPAN_SECONDS). A real check
        # could make a HEAD request and check/store the file's ETag
        try:
            mtime = datetime.datetime.fromtimestamp(pathlib.Path(path).stat().st_mtime)
        except OSError:
            # Missing (or un-stat-able) file: there is nothing to reuse. Doing
            # the stat directly avoids an exists()/stat() race.
            return False
        age = datetime.datetime.now() - mtime
        return age.total_seconds() < FILE_CACHE_LIFESPAN_SECONDS

    if is_cached_file_valid():
        # Another test process already downloaded the file, so don't re-do it
        try:
            with open(path, "r") as f:
                return cast(Dict[str, Any], json.load(f))
        except (OSError, ValueError) as e:
            # A truncated or otherwise unreadable cache should not be fatal;
            # fall through and download a fresh copy instead.
            print(f'Ignoring unreadable cache file {path} because of error {e}.')

    try:
        # Close the HTTP response deterministically instead of leaking it.
        with urlopen(url, timeout=1) as response:
            contents = response.read().decode('utf-8')
        processed_contents = process_fn(json.loads(contents))
        with open(path, "w") as f:
            f.write(json.dumps(processed_contents))
        return processed_contents
    except Exception as e:
        # Best effort: callers treat an empty dict as "no information".
        print(f'Could not download {url} because of error {e}.')
        return {}
|
||||
|
||||
|
||||
def get_slow_tests(dirpath: str, filename: str = SLOW_TESTS_FILE) -> Optional[Dict[str, float]]:
    """
    Fetch the slow-test stats, reusing the cache file in ``dirpath`` if valid.

    Args:
        dirpath: Directory in which the cache file is stored.
        filename: Name of the cache file inside ``dirpath``.

    Returns:
        Whatever ``fetch_and_cache`` returns for the slow-tests JSON (stored
        unmodified), or an empty dict if it raises.
    """
    def unchanged(stats: Dict[str, Any]) -> Dict[str, Any]:
        # The downloaded document is cached exactly as received.
        return stats

    url = "https://raw.githubusercontent.com/pytorch/test-infra/master/stats/slow-tests.json"
    try:
        slow_tests = fetch_and_cache(dirpath, filename, url, unchanged)
    except Exception:
        print("Couldn't download slow test set, leaving all tests enabled...")
        return {}
    return slow_tests
|
||||
|
||||
|
||||
def get_disabled_tests(dirpath: str, filename: str = DISABLED_TESTS_FILE) -> Optional[Dict[str, Any]]:
    """
    Fetch the disabled-test list, reusing the cache file in ``dirpath`` if valid.

    The downloaded document is reduced to a mapping from test name to
    ``(issue_url, platforms_to_skip)`` before it is cached.

    Args:
        dirpath: Directory in which the cache file is stored.
        filename: Name of the cache file inside ``dirpath``.

    Returns:
        The mapping described above, or an empty dict if it cannot be obtained.
    """
    title_prefix = 'DISABLED '
    platforms_prefix = 'platforms:'
    # Compiled once instead of once per matching body line.
    separator = re.compile(r"^\s+|\s*,\s*|\s+$")

    def process_disabled_test(the_response: Dict[str, Any]) -> Dict[str, Any]:
        # Keep only issues titled "DISABLED <test name>" and collect the
        # platforms listed on any "platforms:" line of the issue body.
        disabled_test_from_issues: Dict[str, Any] = {}
        for item in the_response['items']:
            title = item['title']
            if not title.startswith(title_prefix):
                continue
            test_name = title[len(title_prefix):].strip()
            # An issue without a description has a null body; treat it as
            # empty so one such issue does not discard the whole list.
            body = item.get('body') or ''
            platforms_to_skip = []
            for line in body.splitlines():
                line = line.lower()
                if line.startswith(platforms_prefix):
                    listed = separator.split(line[len(platforms_prefix):])
                    platforms_to_skip.extend(x for x in listed if x)
            disabled_test_from_issues[test_name] = (item['html_url'], platforms_to_skip)
        return disabled_test_from_issues

    try:
        url = 'https://raw.githubusercontent.com/pytorch/test-infra/master/stats/disabled-tests.json'
        return fetch_and_cache(dirpath, filename, url, process_disabled_test)
    except Exception:
        print("Couldn't download test skip set, leaving all tests enabled...")
        return {}
|
||||
|
|
@ -8,6 +8,10 @@ from tools.stats.s3_stat_parser import (
|
|||
get_previous_reports_for_pr,
|
||||
Report, Version2Report,
|
||||
HAVE_BOTO3)
|
||||
from tools.stats.import_test_stats import (
|
||||
get_disabled_tests,
|
||||
get_slow_tests
|
||||
)
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple, cast
|
||||
from typing_extensions import TypedDict
|
||||
|
|
@ -284,3 +288,8 @@ def export_S3_test_times(test_times_filename: Optional[str] = None) -> Dict[str,
|
|||
json.dump(job_times_json, file, indent=' ', separators=(',', ': '))
|
||||
file.write('\n')
|
||||
return test_times
|
||||
|
||||
|
||||
def get_test_case_configs(dirpath: str) -> None:
    """
    Fetch the slow-test and disabled-test configurations for ``dirpath``.

    The return values of the two getters are discarded; the function is
    called for their effect of populating the cache files in ``dirpath``.
    """
    for fetch_config in (get_slow_tests, get_disabled_tests):
        fetch_config(dirpath=dirpath)
|
||||
|
|
|
|||
|
|
@ -23,7 +23,6 @@ import warnings
|
|||
import random
|
||||
import contextlib
|
||||
import shutil
|
||||
import datetime
|
||||
import pathlib
|
||||
import socket
|
||||
import subprocess
|
||||
|
|
@ -37,7 +36,6 @@ from copy import deepcopy
|
|||
from numbers import Number
|
||||
import tempfile
|
||||
import json
|
||||
from urllib.request import urlopen
|
||||
import __main__ # type: ignore[import]
|
||||
import errno
|
||||
from typing import cast, Any, Dict, Iterable, Iterator, Optional, Union
|
||||
|
|
@ -69,6 +67,12 @@ IS_SANDCASTLE = os.getenv('SANDCASTLE') == '1' or os.getenv('TW_JOB_USER') == 's
|
|||
IS_FBCODE = os.getenv('PYTORCH_TEST_FBCODE') == '1'
|
||||
IS_REMOTE_GPU = os.getenv('PYTORCH_TEST_REMOTE_GPU') == '1'
|
||||
|
||||
DISABLED_TESTS_FILE = '.pytorch-disabled-tests.json'
|
||||
SLOW_TESTS_FILE = '.pytorch-slow-tests.json'
|
||||
|
||||
slow_tests_dict: Optional[Dict[str, Any]] = None
|
||||
disabled_tests_dict: Optional[Dict[str, Any]] = None
|
||||
|
||||
class ProfilingMode(Enum):
|
||||
LEGACY = 1
|
||||
SIMPLE = 2
|
||||
|
|
@ -164,6 +168,8 @@ parser.add_argument('--save-xml', nargs='?', type=str,
|
|||
parser.add_argument('--discover-tests', action='store_true')
|
||||
parser.add_argument('--log-suffix', type=str, default="")
|
||||
parser.add_argument('--run-parallel', type=int, default=1)
|
||||
parser.add_argument('--import-slow-tests', type=str, nargs='?', const=SLOW_TESTS_FILE)
|
||||
parser.add_argument('--import-disabled-tests', type=str, nargs='?', const=DISABLED_TESTS_FILE)
|
||||
|
||||
args, remaining = parser.parse_known_args()
|
||||
if args.jit_executor == 'legacy':
|
||||
|
|
@ -177,6 +183,8 @@ else:
|
|||
GRAPH_EXECUTOR = cppProfilingFlagsToProfilingMode()
|
||||
|
||||
|
||||
IMPORT_SLOW_TESTS = args.import_slow_tests
|
||||
IMPORT_DISABLED_TESTS = args.import_disabled_tests
|
||||
LOG_SUFFIX = args.log_suffix
|
||||
RUN_PARALLEL = args.run_parallel
|
||||
TEST_BAILOUTS = args.test_bailouts
|
||||
|
|
@ -263,6 +271,16 @@ def sanitize_test_filename(filename):
|
|||
return re.sub('/', r'.', strip_py)
|
||||
|
||||
def run_tests(argv=UNITTEST_ARGS):
|
||||
# import test files.
|
||||
if IMPORT_SLOW_TESTS:
|
||||
global slow_tests_dict
|
||||
with open(IMPORT_SLOW_TESTS, 'r') as fp:
|
||||
slow_tests_dict = json.load(fp)
|
||||
if IMPORT_DISABLED_TESTS:
|
||||
global disabled_tests_dict
|
||||
with open(IMPORT_DISABLED_TESTS, 'r') as fp:
|
||||
disabled_tests_dict = json.load(fp)
|
||||
# Determine the test launch mechanism
|
||||
if TEST_DISCOVER:
|
||||
suite = unittest.TestLoader().loadTestsFromModule(__main__)
|
||||
test_cases = discover_test_cases_recursively(suite)
|
||||
|
|
@ -842,93 +860,16 @@ try:
|
|||
except ImportError:
|
||||
print('Fail to import hypothesis in common_utils, tests are not derandomized')
|
||||
|
||||
|
||||
FILE_CACHE_LIFESPAN_SECONDS = datetime.timedelta(hours=3).seconds
|
||||
|
||||
def fetch_and_cache(name: str, url: str):
|
||||
"""
|
||||
Some tests run in a different process so globals like `slow_test_dict` won't
|
||||
always be filled even though the test file was already downloaded on this
|
||||
machine, so cache it on disk
|
||||
"""
|
||||
path = os.path.join(tempfile.gettempdir(), name)
|
||||
|
||||
def is_cached_file_valid():
|
||||
# Check if the file is new enough (say 1 hour for now). A real check
|
||||
# could make a HEAD request and check/store the file's ETag
|
||||
fname = pathlib.Path(path)
|
||||
now = datetime.datetime.now()
|
||||
mtime = datetime.datetime.fromtimestamp(fname.stat().st_mtime)
|
||||
diff = now - mtime
|
||||
return diff.total_seconds() < FILE_CACHE_LIFESPAN_SECONDS
|
||||
|
||||
if os.path.exists(path) and is_cached_file_valid():
|
||||
# Another test process already downloaded the file, so don't re-do it
|
||||
with open(path, "r") as f:
|
||||
return json.load(f)
|
||||
try:
|
||||
contents = urlopen(url, timeout=1).read().decode('utf-8')
|
||||
with open(path, "w") as f:
|
||||
f.write(contents)
|
||||
return json.loads(contents)
|
||||
except Exception as e:
|
||||
print(f'Could not download {url} because of error {e}.')
|
||||
return {}
|
||||
|
||||
|
||||
slow_tests_dict: Optional[Dict[str, float]] = None
|
||||
def check_slow_test_from_stats(test):
|
||||
global slow_tests_dict
|
||||
if slow_tests_dict is None:
|
||||
if not IS_SANDCASTLE:
|
||||
url = "https://raw.githubusercontent.com/pytorch/test-infra/master/stats/slow-tests.json"
|
||||
slow_tests_dict = fetch_and_cache(".pytorch-slow-tests.json", url)
|
||||
else:
|
||||
slow_tests_dict = {}
|
||||
def check_if_enable(test: unittest.TestCase):
|
||||
test_suite = str(test.__class__).split('\'')[1]
|
||||
test_name = f'{test._testMethodName} ({test_suite})'
|
||||
|
||||
if test_name in slow_tests_dict:
|
||||
if slow_tests_dict is not None and test_name in slow_tests_dict:
|
||||
getattr(test, test._testMethodName).__dict__['slow_test'] = True
|
||||
if not TEST_WITH_SLOW:
|
||||
raise unittest.SkipTest("test is slow; run with PYTORCH_TEST_WITH_SLOW to enable test")
|
||||
|
||||
|
||||
disabled_test_from_issues: Optional[Dict[str, Any]] = None
|
||||
def check_disabled(test_name):
|
||||
global disabled_test_from_issues
|
||||
if disabled_test_from_issues is None:
|
||||
_disabled_test_from_issues: Dict = {}
|
||||
|
||||
def read_and_process():
|
||||
url = 'https://raw.githubusercontent.com/pytorch/test-infra/master/stats/disabled-tests.json'
|
||||
the_response = fetch_and_cache(".pytorch-disabled-tests", url)
|
||||
for item in the_response['items']:
|
||||
title = item['title']
|
||||
key = 'DISABLED '
|
||||
if title.startswith(key):
|
||||
test_name = title[len(key):].strip()
|
||||
body = item['body']
|
||||
platforms_to_skip = []
|
||||
key = 'platforms:'
|
||||
for line in body.splitlines():
|
||||
line = line.lower()
|
||||
if line.startswith(key):
|
||||
pattern = re.compile(r"^\s+|\s*,\s*|\s+$")
|
||||
platforms_to_skip.extend([x for x in pattern.split(line[len(key):]) if x])
|
||||
_disabled_test_from_issues[test_name] = (item['html_url'], platforms_to_skip)
|
||||
|
||||
if not IS_SANDCASTLE and os.getenv("PYTORCH_RUN_DISABLED_TESTS", "0") != "1":
|
||||
try:
|
||||
read_and_process()
|
||||
disabled_test_from_issues = _disabled_test_from_issues
|
||||
except Exception:
|
||||
print("Couldn't download test skip set, leaving all tests enabled...")
|
||||
disabled_test_from_issues = {}
|
||||
|
||||
if disabled_test_from_issues is not None:
|
||||
if test_name in disabled_test_from_issues:
|
||||
issue_url, platforms = disabled_test_from_issues[test_name]
|
||||
if not IS_SANDCASTLE and disabled_tests_dict is not None:
|
||||
if test_name in disabled_tests_dict:
|
||||
issue_url, platforms = disabled_tests_dict[test_name]
|
||||
platform_to_conditional: Dict = {
|
||||
"mac": IS_MACOS,
|
||||
"macos": IS_MACOS,
|
||||
|
|
@ -940,7 +881,9 @@ def check_disabled(test_name):
|
|||
f"Test is disabled because an issue exists disabling it: {issue_url}" +
|
||||
f" for {'all' if platforms == [] else ''}platform(s) {', '.join(platforms)}." +
|
||||
" To enable, set the environment variable PYTORCH_RUN_DISABLED_TESTS=1")
|
||||
|
||||
if TEST_SKIP_FAST:
|
||||
if not getattr(test, test._testMethodName).__dict__.get('slow_test', False):
|
||||
raise unittest.SkipTest("test is fast; we disabled it with PYTORCH_TEST_SKIP_FAST")
|
||||
|
||||
# Acquires the comparison dtype, required since isclose
|
||||
# requires both inputs have the same dtype, and isclose is not supported
|
||||
|
|
@ -1106,12 +1049,7 @@ class TestCase(expecttest.TestCase):
|
|||
result.stop()
|
||||
|
||||
def setUp(self):
|
||||
check_slow_test_from_stats(self)
|
||||
if TEST_SKIP_FAST:
|
||||
if not getattr(self, self._testMethodName).__dict__.get('slow_test', False):
|
||||
raise unittest.SkipTest("test is fast; we disabled it with PYTORCH_TEST_SKIP_FAST")
|
||||
check_disabled(str(self))
|
||||
|
||||
check_if_enable(self)
|
||||
set_rng_seed(SEED)
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user