pytorch/tools/testing/modulefinder_determinator.py
Huy Do 347b036350 Apply ufmt linter to all py files under tools (#81285)
With ufmt in place https://github.com/pytorch/pytorch/pull/81157, we can now use it to gradually format all files. I'm breaking this down into multiple smaller batches to avoid too many merge conflicts later on.

This batch (as copied from the current BLACK linter config):
* `tools/**/*.py`

Upcoming batches:
* `torchgen/**/*.py`
* `torch/package/**/*.py`
* `torch/onnx/**/*.py`
* `torch/_refs/**/*.py`
* `torch/_prims/**/*.py`
* `torch/_meta_registrations.py`
* `torch/_decomp/**/*.py`
* `test/onnx/**/*.py`

Once they are all formatted, BLACK linter will be removed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81285
Approved by: https://github.com/suo
2022-07-13 07:59:22 +00:00

194 lines
6.0 KiB
Python

import modulefinder
import os
import pathlib
import sys
import warnings
from typing import Any, Dict, List, Set
# Repository root. This file is tools/testing/modulefinder_determinator.py, so
# the root is three parents above this file's resolved path.
REPO_ROOT: pathlib.Path = pathlib.Path(__file__).resolve().parent.parent.parent
# These tests are slow enough that it's worth calculating whether the patch
# touched any related files first. This list was manually generated, but for every
# run with --determine-from, we use another generated list based on this one and the
# previous test stats.
#
# Entries are test module names (no ".py" suffix). should_run_test compares them
# against the part of a test name before the first ".". It then strips the
# "_ninja" / "_no_ninja" suffixes before computing dependencies, so both
# test_cpp_extensions_aot_* entries resolve to test/test_cpp_extensions_aot.py.
TARGET_DET_LIST: List[str] = [
    # test_autograd.py is not slow, so it does not belong here. But
    # note that if you try to add it back it will run into
    # https://bugs.python.org/issue40350 because it imports files
    # under test/autograd/.
    "test_binary_ufuncs",
    "test_cpp_extensions_aot_ninja",
    "test_cpp_extensions_aot_no_ninja",
    "test_cpp_extensions_jit",
    "test_cpp_extensions_open_device_registration",
    "test_cuda",
    "test_cuda_primary_ctx",
    "test_dataloader",
    "test_determination",
    "test_futures",
    "test_jit",
    "test_jit_legacy",
    "test_jit_profiling",
    "test_linalg",
    "test_multiprocessing",
    "test_nn",
    "test_numpy_interop",
    "test_optim",
    "test_overrides",
    "test_pruning_op",
    "test_quantization",
    "test_reductions",
    "test_serialization",
    "test_shape_ops",
    "test_sort_and_select",
    "test_tensorboard",
    "test_testing",
    "test_torch",
    "test_utils",
    "test_view_ops",
]
# Memoization cache for get_dep_modules. It maps a test module name to the set
# of module names that modulefinder reported for that test script.
_DEP_MODULES_CACHE: Dict[str, Set[str]] = {}
def should_run_test(
    target_det_list: List[str], test: str, touched_files: List[str], options: Any
) -> bool:
    """Decide whether ``test`` must run, given the files touched by a change.

    Args:
        target_det_list: Test module names for which determination is
            performed. Any test not in this list always runs.
        test: Test name. Only the part before the first "." is used.
        touched_files: Paths of changed files, split on ``os.sep``. Presumably
            these are relative to the repository root, since their first
            component is classified (torch/, caffe2/, test/, ...) — confirm
            against callers.
        options: Object exposing a ``verbose`` attribute. When it is truthy,
            the decision is reported on stderr.

    Returns:
        True if the test should run. False if none of the touched files is
        among the test's module dependencies (or is the test itself) and no
        touched file is a CI/uncategorized file.
    """
    test = parse_test_module(test)
    # Some tests are faster to execute than to determine.
    if test not in target_det_list:
        if options.verbose:
            print_to_stderr(f"Running {test} without determination")
        return True

    # HACK: "no_ninja" is not a real module; both ninja variants map back to
    # the underlying test module.
    if test.endswith("_no_ninja"):
        test = test[: -len("_no_ninja")]
    if test.endswith("_ninja"):
        test = test[: -len("_ninja")]

    dep_modules = get_dep_modules(test)
    # Dotted form of the test's own path; constant across the loop.
    test_as_module = test.replace("/", ".")
    test_prefix = "test."

    for touched_file in touched_files:
        file_type = test_impact_of_file(touched_file)
        if file_type == "NONE":
            continue
        if file_type in ("CI", "UNKNOWN"):
            # Force all tests to run if any change is made to the CI
            # configurations; likewise, assume uncategorized source files can
            # affect every test.
            log_test_reason(file_type, touched_file, test, options)
            return True
        if file_type in ("TORCH", "CAFFE2", "TEST"):
            parts = os.path.splitext(touched_file)[0].split(os.sep)
            touched_module = ".".join(parts)
            # test/ path does not have a "test." namespace. Strip only the
            # leading prefix: str.split("test.") would also cut at later
            # occurrences (e.g. "test.latest.foo" would become "la").
            if touched_module.startswith(test_prefix):
                touched_module = touched_module[len(test_prefix) :]
            if touched_module in dep_modules or touched_module == test_as_module:
                log_test_reason(file_type, touched_file, test, options)
                return True

    # If nothing has determined that the test must run, don't run the test.
    if options.verbose:
        print_to_stderr(f"Determination is skipping {test}")
    return False
def test_impact_of_file(filename: str) -> str:
"""Determine what class of impact this file has on test runs.
Possible values:
TORCH - torch python code
CAFFE2 - caffe2 python code
TEST - torch test code
UNKNOWN - may affect all tests
NONE - known to have no effect on test outcome
CI - CI configuration files
"""
parts = filename.split(os.sep)
if parts[0] in [".jenkins", ".circleci"]:
return "CI"
if parts[0] in ["docs", "scripts", "CODEOWNERS", "README.md"]:
return "NONE"
elif parts[0] == "torch":
if parts[-1].endswith(".py") or parts[-1].endswith(".pyi"):
return "TORCH"
elif parts[0] == "caffe2":
if parts[-1].endswith(".py") or parts[-1].endswith(".pyi"):
return "CAFFE2"
elif parts[0] == "test":
if parts[-1].endswith(".py") or parts[-1].endswith(".pyi"):
return "TEST"
return "UNKNOWN"
def log_test_reason(file_type: str, filename: str, test: str, options: Any) -> None:
    """Report on stderr which touched file caused ``test`` to be selected.

    Nothing is printed unless ``options.verbose`` is truthy.
    """
    if not options.verbose:
        return
    print_to_stderr(f"Determination found {file_type} file {filename} -- running {test}")
def get_dep_modules(test: str) -> Set[str]:
    """Return the module names modulefinder reports for ``test``'s script.

    The script analysed is ``REPO_ROOT / "test" / f"{test}.py"``. Results are
    memoized in ``_DEP_MODULES_CACHE``, so repeated calls with the same
    ``test`` return the same set object.
    """
    cached = _DEP_MODULES_CACHE.get(test)
    if cached is not None:
        return cached

    script_path = str(REPO_ROOT / "test" / f"{test}.py")

    excluded_modules = [
        # Ideally exclude all third party modules, to speed up calculation.
        "scipy",
        "numpy",
        "numba",
        "multiprocessing",
        "sklearn",
        "setuptools",
        "hypothesis",
        "llvmlite",
        "joblib",
        "email",
        "importlib",
        "unittest",
        "urllib",
        "json",
        "collections",
        # The next three hit https://bugs.python.org/issue40350:
        # AttributeError: 'NoneType' object has no attribute 'is_package'
        "mpl_toolkits",
        "google",
        "onnx",
        # Triggers RecursionError
        "mypy",
    ]
    finder = modulefinder.ModuleFinder(excludes=excluded_modules)

    # NOTE(review): run_script opens the script itself. An older note here
    # warned that some platforms default to an ASCII encoding, which breaks
    # non-ASCII test files. Confirm this is not an issue on the supported
    # Python versions.
    with warnings.catch_warnings():
        # Silence any warnings emitted while the script's imports are scanned.
        warnings.simplefilter("ignore")
        finder.run_script(script_path)

    found = set(finder.modules)
    _DEP_MODULES_CACHE[test] = found
    return found
def parse_test_module(test: str) -> str:
    """Return everything before the first "." in ``test``.

    For example, "test_nn.TestNN.test_foo" -> "test_nn". A name without a
    "." is returned unchanged.
    """
    module_name, _sep, _rest = test.partition(".")
    return module_name
def print_to_stderr(message: str) -> None:
    """Write ``message`` followed by a newline to the current ``sys.stderr``."""
    stream = sys.stderr
    print(message, file=stream)