[lint] fixes to mypy linter

I thought I landed this already, but:
- Don't run one mypy instance per file; run one per config
- Do the same for flake8
- Properly handle stub files
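
The first two bullets boil down to replacing the adapters' per-file ThreadPoolExecutor fan-out (one subprocess per file) with a single batched invocation per linter config. A rough sketch of the idea only — the signature and name here are illustrative, not the adapters' actual API:

import subprocess
from typing import List


def check_files(binary: str, extra_args: List[str], filenames: List[str]) -> bytes:
    # One subprocess for the whole batch; both mypy and flake8 accept a
    # list of paths on the command line, so a per-file thread pool is
    # unnecessary.
    proc = subprocess.run(
        [binary, *extra_args, *filenames],
        capture_output=True,
        check=False,  # a nonzero exit usually just means lint findings were reported
    )
    return proc.stdout

The likely payoff is that mypy's startup and cache work now happen once per config instead of once per file.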

Pull Request resolved: https://github.com/pytorch/pytorch/pull/68479

Approved by: https://github.com/janeyx99
Author:  Michael Suo (committed by PyTorch MergeBot)
Date:    2022-04-13 09:09:39 -07:00
Commit:  fe1e6de73a
Parent:  d6e6061b98
5 changed files with 57 additions and 61 deletions

View File

@@ -53,6 +53,9 @@ include_patterns = [
     'test/cpp/tensorexpr/**/*.h',
     'test/cpp/tensorexpr/**/*.cpp',
 ]
+exclude_patterns = [
+    'torch/csrc/jit/serialization/mobile_bytecode_generated.h',
+]
 init_command = [
     'python3',
     'tools/linter/adapters/s3_init.py',
@@ -74,13 +77,15 @@ command = [
 code = 'MYPY'
 include_patterns = [
     'torch/**/*.py',
+    'torch/**/*.pyi',
     'caffe2/**/*.py',
+    'caffe2/**/*.pyi',
     'test/test_bundled_images.py',
     'test/test_bundled_inputs.py',
     'test/test_complex.py',
     'test/test_datapipe.py',
     'test/test_futures.py',
-    'test/test_numpy_interop.py',
+    # 'test/test_numpy_interop.py',
     'test/test_torch.py',
     'test/test_type_hints.py',
     'test/test_type_info.py',
@@ -90,6 +95,17 @@ exclude_patterns = [
     'torch/include/**',
     'torch/csrc/**',
     'torch/distributed/elastic/agent/server/api.py',
+    'torch/testing/_internal/**',
+    'torch/distributed/fsdp/fully_sharded_data_parallel.py',
+    # TODO(suo): these exclusions were added just to get lint clean on master.
+    # Follow up to do more targeted suppressions and remove them.
+    'torch/distributed/fsdp/flatten_params_wrapper.py',
+    'torch/ao/quantization/fx/convert.py',
+    'torch/ao/quantization/_dbr/function_fusion.py',
+    'test/test_datapipe.py',
+    'caffe2/contrib/fakelowp/test/test_batchmatmul_nnpi_fp16.py',
+    'test/test_numpy_interop.py',
+    'torch/torch_version.py',
 ]
 command = [
     'python3',
@@ -108,6 +124,7 @@ init_command = [
     'mypy==0.812',
     'junitparser==2.1.1',
     'rich==10.9.0',
+    'pyyaml==6.0',
 ]
 
 [[linter]]
@@ -319,6 +336,7 @@ include_patterns = [
 ]
 exclude_patterns = [
     'aten/src/ATen/native/quantized/cpu/qnnpack/**',
+    'torch/csrc/jit/serialization/mobile_bytecode_generated.h',
 ]
 command = [
     'python3',
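
The include_patterns / exclude_patterns pairs above control which files each [[linter]] entry sees. As a hedged illustration only (this is not lintrunner's actual matching code, and select_files is a made-up name), selection presumably works along these lines, with exclusions taking precedence over inclusions:

# Rough sketch of glob-based file selection; fnmatch's '*' is looser than a
# real '**' glob (it also crosses path separators), which is fine for a sketch.
from fnmatch import fnmatch
from typing import List


def select_files(files: List[str], includes: List[str], excludes: List[str]) -> List[str]:
    def matches(name: str, patterns: List[str]) -> bool:
        return any(fnmatch(name, p) for p in patterns)

    return [f for f in files if matches(f, includes) and not matches(f, excludes)]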

View File

@@ -1,5 +1,4 @@
 import argparse
-import concurrent.futures
 import json
 import logging
 import os
@@ -244,8 +243,8 @@ def get_issue_documentation_url(code: str) -> str:
     return ""
 
 
-def check_file(
-    filename: str,
+def check_files(
+    filenames: List[str],
     binary: str,
     flake8_plugins_path: Optional[str],
     severities: Dict[str, LintSeverity],
@@ -253,7 +252,7 @@ def check_file(
 ) -> List[LintMessage]:
     try:
         proc = run_command(
-            [binary, "--exit-zero", filename],
+            [binary, "--exit-zero"] + filenames,
             extra_env={"FLAKE8_PLUGINS_PATH": flake8_plugins_path}
             if flake8_plugins_path
             else None,
@@ -262,7 +261,7 @@ def check_file(
     except (OSError, subprocess.CalledProcessError) as err:
         return [
             LintMessage(
-                path=filename,
+                path=None,
                 line=None,
                 char=None,
                 code="FLAKE8",
@@ -369,28 +368,9 @@ def main() -> None:
         assert len(parts) == 2, f"invalid severity `{severity}`"
         severities[parts[0]] = LintSeverity(parts[1])
 
-    with concurrent.futures.ThreadPoolExecutor(
-        max_workers=os.cpu_count(),
-        thread_name_prefix="Thread",
-    ) as executor:
-        futures = {
-            executor.submit(
-                check_file,
-                filename,
-                args.binary,
-                flake8_plugins_path,
-                severities,
-                args.retries,
-            ): filename
-            for filename in args.filenames
-        }
-        for future in concurrent.futures.as_completed(futures):
-            try:
-                for lint_message in future.result():
-                    print(json.dumps(lint_message._asdict()), flush=True)
-            except Exception:
-                logging.critical('Failed at "%s".', futures[future])
-                raise
+    lint_messages = check_files(args.filenames, args.binary, flake8_plugins_path, severities, args.retries)
+    for lint_message in lint_messages:
+        print(json.dumps(lint_message._asdict()), flush=True)
 
 
 if __name__ == "__main__":
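
Both adapters keep the same output convention: each LintMessage is printed as one JSON object per line, so the lintrunner harness can consume results as a stream. A hedged, self-contained sketch of that convention — the field set is inferred from the hunks above, and the real adapters use a LintSeverity enum rather than a plain string:

import json
from typing import NamedTuple, Optional


class LintMessage(NamedTuple):
    path: Optional[str]
    line: Optional[int]
    char: Optional[int]
    code: str
    severity: str  # simplified; the real adapters use a LintSeverity enum
    name: str
    original: Optional[str]
    replacement: Optional[str]
    description: Optional[str]


# Example: the "command-failed" message emitted when the batched flake8
# invocation itself fails, serialized the same way the adapter does it.
msg = LintMessage(
    path=None,  # with batching, the failure cannot be pinned to one file
    line=None,
    char=None,
    code="FLAKE8",
    severity="error",
    name="command-failed",
    original=None,
    replacement=None,
    description="Failed due to OSError: ...",
)
print(json.dumps(msg._asdict()), flush=True)

The switch from path=filename to path=None in the hunk above follows directly from batching: a failure of the single batched run is not attributable to any one input file.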

View File

@@ -1,5 +1,4 @@
 import argparse
-import concurrent.futures
 import json
 import logging
 import os
@@ -8,6 +7,7 @@ import subprocess
 import sys
 import time
 from enum import Enum
+from pathlib import Path
 from typing import Any, Dict, List, NamedTuple, Optional, Pattern
 
 
@@ -56,7 +56,6 @@ RESULTS_RE: Pattern[str] = re.compile(
 )
 
 
-
 def run_command(
     args: List[str],
     *,
@@ -82,15 +81,16 @@ severities = {
     "note": LintSeverity.ADVICE,
 }
 
-def check_file(
-    filename: str,
+
+def check_files(
+    filenames: List[str],
     config: str,
     binary: str,
     retries: int,
 ) -> List[LintMessage]:
     try:
         proc = run_command(
-            [binary, f"--config={config}", filename],
+            [binary, f"--config={config}"] + filenames,
             extra_env={},
             retries=retries,
         )
@@ -105,9 +105,7 @@ def check_file(
                 name="command-failed",
                 original=None,
                 replacement=None,
-                description=(
-                    f"Failed due to {err.__class__.__name__}:\n{err}"
-                ),
+                description=(f"Failed due to {err.__class__.__name__}:\n{err}"),
             )
         ]
     stdout = str(proc.stdout, "utf-8").strip()
@@ -172,27 +170,26 @@ def main() -> None:
         stream=sys.stderr,
     )
 
-    with concurrent.futures.ThreadPoolExecutor(
-        max_workers=os.cpu_count(),
-        thread_name_prefix="Thread",
-    ) as executor:
-        futures = {
-            executor.submit(
-                check_file,
-                filename,
-                args.config,
-                args.binary,
-                args.retries,
-            ): filename
-            for filename in args.filenames
-        }
-        for future in concurrent.futures.as_completed(futures):
-            try:
-                for lint_message in future.result():
-                    print(json.dumps(lint_message._asdict()), flush=True)
-            except Exception:
-                logging.critical('Failed at "%s".', futures[future])
-                raise
+    # Use a dictionary here to preserve order. mypy cares about order,
+    # tragically, e.g. https://github.com/python/mypy/issues/2015
+    filenames: Dict[str, bool] = {}
+
+    # If a stub file exists, have mypy check it instead of the original file, in
+    # accordance with PEP-484 (see https://www.python.org/dev/peps/pep-0484/#stub-files)
+    for filename in args.filenames:
+        if filename.endswith(".pyi"):
+            filenames[filename] = True
+            continue
+
+        stub_filename = filename.replace(".py", ".pyi")
+        if Path(stub_filename).exists():
+            filenames[stub_filename] = True
+        else:
+            filenames[filename] = True
+
+    lint_messages = check_files(list(filenames), args.config, args.binary, args.retries)
+    for lint_message in lint_messages:
+        print(json.dumps(lint_message._asdict()), flush=True)
 
 
 if __name__ == "__main__":
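
The new block above is the "properly handle stub files" part of the change: per PEP 484, when a .pyi stub sits next to a .py module, mypy should check the stub rather than the module, and a dict is used because mypy is order-sensitive. A standalone, hedged sketch of the same rule — the file names are hypothetical, and unlike the adapter's str.replace call this version only rewrites the .py suffix, which is slightly more defensive for paths containing ".py" elsewhere:

from pathlib import Path
from typing import Dict, List


def resolve_stubs(filenames: List[str]) -> List[str]:
    # Dict keys keep insertion order (mypy cares about order, see
    # https://github.com/python/mypy/issues/2015) and dedupe the case
    # where both mod.py and mod.pyi are passed in.
    ordered: Dict[str, None] = {}
    for name in filenames:
        if name.endswith(".pyi"):
            ordered[name] = None
        elif name.endswith(".py") and Path(name[:-3] + ".pyi").exists():
            ordered[name[:-3] + ".pyi"] = None
        else:
            ordered[name] = None
    return list(ordered)


# e.g. resolve_stubs(["pkg/mod.py", "pkg/mod.pyi"]) == ["pkg/mod.pyi"]
# when pkg/mod.pyi exists on disk.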

View File

@@ -7,13 +7,12 @@
  */
 #include <torch/csrc/jit/jit_log.h>
 #include <torch/csrc/jit/passes/inliner.h>
-#include <torch/csrc/jit/runtime/operator.h>
-
 #include <torch/csrc/jit/runtime/decomposition_registry_util.h>
+#include <torch/csrc/jit/runtime/operator.h>
 
 namespace torch {
 namespace jit {
 
 const std::string decomp_funcs =
     R"(def var_decomposition(input: Tensor,
             dim: Optional[List[int]]=None,

View File

@@ -456,7 +456,9 @@ void ProfilingGraphExecutorImpl::runNoGradOptimizations(
   for (const auto& passPair : getCustomPostPasses()) {
     passPair.first(graph);
   }
-  GRAPH_DEBUG("After customPostPasses, before RemoveTensorTypeSpecializations \n", *graph);
+  GRAPH_DEBUG(
+      "After customPostPasses, before RemoveTensorTypeSpecializations \n",
+      *graph);
   RemoveTensorTypeSpecializations(graph);
   GRAPH_DEBUG("After RemoveTensorTypeSpecializations\n", *graph);
 }