mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
With ufmt in place https://github.com/pytorch/pytorch/pull/81157, we can now use it to gradually format all files. I'm breaking this down into multiple smaller batches to avoid too many merge conflicts later on. This batch (as copied from the current BLACK linter config): * `tools/**/*.py` Upcoming batches: * `torchgen/**/*.py` * `torch/package/**/*.py` * `torch/onnx/**/*.py` * `torch/_refs/**/*.py` * `torch/_prims/**/*.py` * `torch/_meta_registrations.py` * `torch/_decomp/**/*.py` * `test/onnx/**/*.py` Once they are all formatted, the BLACK linter will be removed. Pull Request resolved: https://github.com/pytorch/pytorch/pull/81285 Approved by: https://github.com/suo
194 lines
6.0 KiB
Python
194 lines
6.0 KiB
Python
import modulefinder
|
|
import os
|
|
import pathlib
|
|
import sys
|
|
import warnings
|
|
from typing import Any, Dict, List, Set
|
|
|
|
# Repository checkout root: two directory levels above the directory that
# contains this file.
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent

# Slow tests for which it pays off to first check whether the patch touched any
# file they depend on. This list is maintained by hand; runs that pass
# --determine-from use another generated list based on this one and the
# previous test stats.
TARGET_DET_LIST = [
    # test_autograd.py is fast, so it is intentionally absent. Adding it back
    # would also hit https://bugs.python.org/issue40350, because it imports
    # files from test/autograd/.
    "test_binary_ufuncs",
    "test_cpp_extensions_aot_ninja",
    "test_cpp_extensions_aot_no_ninja",
    "test_cpp_extensions_jit",
    "test_cpp_extensions_open_device_registration",
    "test_cuda",
    "test_cuda_primary_ctx",
    "test_dataloader",
    "test_determination",
    "test_futures",
    "test_jit",
    "test_jit_legacy",
    "test_jit_profiling",
    "test_linalg",
    "test_multiprocessing",
    "test_nn",
    "test_numpy_interop",
    "test_optim",
    "test_overrides",
    "test_pruning_op",
    "test_quantization",
    "test_reductions",
    "test_serialization",
    "test_shape_ops",
    "test_sort_and_select",
    "test_tensorboard",
    "test_testing",
    "test_torch",
    "test_utils",
    "test_view_ops",
]

# Memo for get_dep_modules: test name -> set of module names found for it.
_DEP_MODULES_CACHE: Dict[str, Set[str]] = {}
|
|
|
|
|
|
def should_run_test(
    target_det_list: List[str], test: str, touched_files: List[str], options: Any
) -> bool:
    """Decide whether ``test`` needs to run for a patch touching ``touched_files``.

    Tests that are not in ``target_det_list`` always run: for them, executing
    is cheaper than determining. For the remaining tests every touched file is
    classified with ``test_impact_of_file``:

    * ``"NONE"`` files are ignored.
    * ``"CI"`` and ``"UNKNOWN"`` files force the test to run.
    * ``"TORCH"``, ``"CAFFE2"`` and ``"TEST"`` files force the test to run when
      the file's dotted module name is the test module itself or one of the
      modules reported by ``get_dep_modules`` for the test.

    Args:
        target_det_list: names of the tests eligible for determination.
        test: test name; only the part before the first ``"."`` is used.
        touched_files: paths of the files changed by the patch, split on
            ``os.sep``. They are expected to be repository-relative, because
            classification looks at the first path component.
        options: run options; only ``options.verbose`` is read.

    Returns:
        ``True`` if the test should run, ``False`` if determination skips it.
    """
    test = parse_test_module(test)
    # Some tests are faster to execute than to determine.
    if test not in target_det_list:
        if options.verbose:
            print_to_stderr(f"Running {test} without determination")
        return True

    # HACK: "no_ninja" is not a real module. The "_ninja" and "_no_ninja"
    # variants are mapped back to their shared base test name.
    if test.endswith("_no_ninja"):
        test = test[: -len("_no_ninja")]
    if test.endswith("_ninja"):
        test = test[: -len("_ninja")]

    test_module = test.replace("/", ".")

    for touched_file in touched_files:
        file_type = test_impact_of_file(touched_file)
        if file_type == "NONE":
            continue
        if file_type in ("CI", "UNKNOWN"):
            # Any change to the CI configuration, or to an uncategorized file,
            # is assumed to be able to affect every test.
            log_test_reason(file_type, touched_file, test, options)
            return True
        if file_type in ("TORCH", "CAFFE2", "TEST"):
            parts = os.path.splitext(touched_file)[0].split(os.sep)
            touched_module = ".".join(parts)
            # test/ paths do not have a "test." namespace. Strip only the
            # leading prefix: splitting on "test." would also cut at later
            # occurrences (e.g. "test.onnx.pytorch_test.common").
            if touched_module.startswith("test."):
                touched_module = touched_module[len("test.") :]
            # Check the cheap equality first, so the potentially expensive
            # modulefinder scan in get_dep_modules only runs when needed.
            if touched_module == test_module or touched_module in get_dep_modules(
                test
            ):
                log_test_reason(file_type, touched_file, test, options)
                return True

    # If nothing has determined the test has run, don't run the test.
    if options.verbose:
        print_to_stderr(f"Determination is skipping {test}")
    return False
|
|
|
|
|
|
def test_impact_of_file(filename: str) -> str:
|
|
"""Determine what class of impact this file has on test runs.
|
|
|
|
Possible values:
|
|
TORCH - torch python code
|
|
CAFFE2 - caffe2 python code
|
|
TEST - torch test code
|
|
UNKNOWN - may affect all tests
|
|
NONE - known to have no effect on test outcome
|
|
CI - CI configuration files
|
|
"""
|
|
parts = filename.split(os.sep)
|
|
if parts[0] in [".jenkins", ".circleci"]:
|
|
return "CI"
|
|
if parts[0] in ["docs", "scripts", "CODEOWNERS", "README.md"]:
|
|
return "NONE"
|
|
elif parts[0] == "torch":
|
|
if parts[-1].endswith(".py") or parts[-1].endswith(".pyi"):
|
|
return "TORCH"
|
|
elif parts[0] == "caffe2":
|
|
if parts[-1].endswith(".py") or parts[-1].endswith(".pyi"):
|
|
return "CAFFE2"
|
|
elif parts[0] == "test":
|
|
if parts[-1].endswith(".py") or parts[-1].endswith(".pyi"):
|
|
return "TEST"
|
|
|
|
return "UNKNOWN"
|
|
|
|
|
|
def log_test_reason(file_type: str, filename: str, test: str, options: Any) -> None:
    """In verbose mode, report which touched file caused ``test`` to be run.

    Nothing is emitted unless ``options.verbose`` is truthy.
    """
    if not options.verbose:
        return
    print_to_stderr(
        f"Determination found {file_type} file {filename} -- running {test}"
    )
|
|
|
|
|
|
def get_dep_modules(test: str) -> Set[str]:
    """Return the module names that modulefinder finds for a test file.

    The script ``REPO_ROOT / "test" / f"{test}.py"`` is scanned with
    ``modulefinder.ModuleFinder``, skipping the packages listed in
    ``excludes``. The keys of ``finder.modules`` are returned; the scanned
    script itself is registered as ``"__main__"``.

    Results are memoized per ``test`` in ``_DEP_MODULES_CACHE``. The cached set
    object is returned directly, so callers must not mutate it.
    """
    # Cache results in case of repetition
    if test in _DEP_MODULES_CACHE:
        return _DEP_MODULES_CACHE[test]

    test_location = REPO_ROOT / "test" / f"{test}.py"

    finder = modulefinder.ModuleFinder(
        # Ideally exclude all third party modules, to speed up calculation.
        excludes=[
            "scipy",
            "numpy",
            "numba",
            "multiprocessing",
            "sklearn",
            "setuptools",
            "hypothesis",
            "llvmlite",
            "joblib",
            "email",
            "importlib",
            "unittest",
            "urllib",
            "json",
            "collections",
            # Modules below are excluded because they are hitting https://bugs.python.org/issue40350
            # Trigger AttributeError: 'NoneType' object has no attribute 'is_package'
            "mpl_toolkits",
            "google",
            "onnx",
            # Triggers RecursionError
            "mypy",
        ],
    )

    # Silence any warnings raised while scanning the test file.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        # HACK: some platforms default to ascii, so we can't just run_script :(
        # On older Pythons (before 3.8), ModuleFinder.run_script opens the
        # script in text mode with the locale encoding. Decode the source
        # explicitly as UTF-8 and load it as "__main__", which is what
        # run_script does. The trailing 1 in the file-info tuple is the
        # PY_SOURCE module type.
        with open(test_location, encoding="utf-8") as fp:
            finder.load_module("__main__", fp, str(test_location), ("", "r", 1))

    dep_modules = set(finder.modules.keys())
    _DEP_MODULES_CACHE[test] = dep_modules
    return dep_modules
|
|
|
|
|
|
def parse_test_module(test: str) -> str:
    """Return the part of ``test`` before the first ``"."``.

    If ``test`` contains no ``"."``, it is returned unchanged.
    """
    module_name, _separator, _rest = test.partition(".")
    return module_name
|
|
|
|
|
|
def print_to_stderr(message: str) -> None:
    """Print ``message`` (with a trailing newline) to the current ``sys.stderr``."""
    target_stream = sys.stderr
    print(message, file=target_stream)
|