mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[lint] fixes to mypy linter
I thought I landed this already, but: - Don't run one mypy instance per file, run one per config - Do the same for flake8 - Properly handle stub files Pull Request resolved: https://github.com/pytorch/pytorch/pull/68479 Approved by: https://github.com/janeyx99
This commit is contained in:
parent
d6e6061b98
commit
fe1e6de73a
|
|
@ -53,6 +53,9 @@ include_patterns = [
|
||||||
'test/cpp/tensorexpr/**/*.h',
|
'test/cpp/tensorexpr/**/*.h',
|
||||||
'test/cpp/tensorexpr/**/*.cpp',
|
'test/cpp/tensorexpr/**/*.cpp',
|
||||||
]
|
]
|
||||||
|
exclude_patterns = [
|
||||||
|
'torch/csrc/jit/serialization/mobile_bytecode_generated.h',
|
||||||
|
]
|
||||||
init_command = [
|
init_command = [
|
||||||
'python3',
|
'python3',
|
||||||
'tools/linter/adapters/s3_init.py',
|
'tools/linter/adapters/s3_init.py',
|
||||||
|
|
@ -74,13 +77,15 @@ command = [
|
||||||
code = 'MYPY'
|
code = 'MYPY'
|
||||||
include_patterns = [
|
include_patterns = [
|
||||||
'torch/**/*.py',
|
'torch/**/*.py',
|
||||||
|
'torch/**/*.pyi',
|
||||||
'caffe2/**/*.py',
|
'caffe2/**/*.py',
|
||||||
|
'caffe2/**/*.pyi',
|
||||||
'test/test_bundled_images.py',
|
'test/test_bundled_images.py',
|
||||||
'test/test_bundled_inputs.py',
|
'test/test_bundled_inputs.py',
|
||||||
'test/test_complex.py',
|
'test/test_complex.py',
|
||||||
'test/test_datapipe.py',
|
'test/test_datapipe.py',
|
||||||
'test/test_futures.py',
|
'test/test_futures.py',
|
||||||
'test/test_numpy_interop.py',
|
# 'test/test_numpy_interop.py',
|
||||||
'test/test_torch.py',
|
'test/test_torch.py',
|
||||||
'test/test_type_hints.py',
|
'test/test_type_hints.py',
|
||||||
'test/test_type_info.py',
|
'test/test_type_info.py',
|
||||||
|
|
@ -90,6 +95,17 @@ exclude_patterns = [
|
||||||
'torch/include/**',
|
'torch/include/**',
|
||||||
'torch/csrc/**',
|
'torch/csrc/**',
|
||||||
'torch/distributed/elastic/agent/server/api.py',
|
'torch/distributed/elastic/agent/server/api.py',
|
||||||
|
'torch/testing/_internal/**',
|
||||||
|
'torch/distributed/fsdp/fully_sharded_data_parallel.py',
|
||||||
|
# TODO(suo): these exclusions were added just to get lint clean on master.
|
||||||
|
# Follow up to do more target suppressions and remove them.
|
||||||
|
'torch/distributed/fsdp/flatten_params_wrapper.py',
|
||||||
|
'torch/ao/quantization/fx/convert.py',
|
||||||
|
'torch/ao/quantization/_dbr/function_fusion.py',
|
||||||
|
'test/test_datapipe.py',
|
||||||
|
'caffe2/contrib/fakelowp/test/test_batchmatmul_nnpi_fp16.py',
|
||||||
|
'test/test_numpy_interop.py',
|
||||||
|
'torch/torch_version.py',
|
||||||
]
|
]
|
||||||
command = [
|
command = [
|
||||||
'python3',
|
'python3',
|
||||||
|
|
@ -108,6 +124,7 @@ init_command = [
|
||||||
'mypy==0.812',
|
'mypy==0.812',
|
||||||
'junitparser==2.1.1',
|
'junitparser==2.1.1',
|
||||||
'rich==10.9.0',
|
'rich==10.9.0',
|
||||||
|
'pyyaml==6.0',
|
||||||
]
|
]
|
||||||
|
|
||||||
[[linter]]
|
[[linter]]
|
||||||
|
|
@ -319,6 +336,7 @@ include_patterns = [
|
||||||
]
|
]
|
||||||
exclude_patterns = [
|
exclude_patterns = [
|
||||||
'aten/src/ATen/native/quantized/cpu/qnnpack/**',
|
'aten/src/ATen/native/quantized/cpu/qnnpack/**',
|
||||||
|
'torch/csrc/jit/serialization/mobile_bytecode_generated.h',
|
||||||
]
|
]
|
||||||
command = [
|
command = [
|
||||||
'python3',
|
'python3',
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,4 @@
|
||||||
import argparse
|
import argparse
|
||||||
import concurrent.futures
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|
@ -244,8 +243,8 @@ def get_issue_documentation_url(code: str) -> str:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
def check_file(
|
def check_files(
|
||||||
filename: str,
|
filenames: List[str],
|
||||||
binary: str,
|
binary: str,
|
||||||
flake8_plugins_path: Optional[str],
|
flake8_plugins_path: Optional[str],
|
||||||
severities: Dict[str, LintSeverity],
|
severities: Dict[str, LintSeverity],
|
||||||
|
|
@ -253,7 +252,7 @@ def check_file(
|
||||||
) -> List[LintMessage]:
|
) -> List[LintMessage]:
|
||||||
try:
|
try:
|
||||||
proc = run_command(
|
proc = run_command(
|
||||||
[binary, "--exit-zero", filename],
|
[binary, "--exit-zero"] + filenames,
|
||||||
extra_env={"FLAKE8_PLUGINS_PATH": flake8_plugins_path}
|
extra_env={"FLAKE8_PLUGINS_PATH": flake8_plugins_path}
|
||||||
if flake8_plugins_path
|
if flake8_plugins_path
|
||||||
else None,
|
else None,
|
||||||
|
|
@ -262,7 +261,7 @@ def check_file(
|
||||||
except (OSError, subprocess.CalledProcessError) as err:
|
except (OSError, subprocess.CalledProcessError) as err:
|
||||||
return [
|
return [
|
||||||
LintMessage(
|
LintMessage(
|
||||||
path=filename,
|
path=None,
|
||||||
line=None,
|
line=None,
|
||||||
char=None,
|
char=None,
|
||||||
code="FLAKE8",
|
code="FLAKE8",
|
||||||
|
|
@ -369,28 +368,9 @@ def main() -> None:
|
||||||
assert len(parts) == 2, f"invalid severity `{severity}`"
|
assert len(parts) == 2, f"invalid severity `{severity}`"
|
||||||
severities[parts[0]] = LintSeverity(parts[1])
|
severities[parts[0]] = LintSeverity(parts[1])
|
||||||
|
|
||||||
with concurrent.futures.ThreadPoolExecutor(
|
lint_messages = check_files(args.filenames, args.binary, flake8_plugins_path, severities, args.retries)
|
||||||
max_workers=os.cpu_count(),
|
for lint_message in lint_messages:
|
||||||
thread_name_prefix="Thread",
|
print(json.dumps(lint_message._asdict()), flush=True)
|
||||||
) as executor:
|
|
||||||
futures = {
|
|
||||||
executor.submit(
|
|
||||||
check_file,
|
|
||||||
filename,
|
|
||||||
args.binary,
|
|
||||||
flake8_plugins_path,
|
|
||||||
severities,
|
|
||||||
args.retries,
|
|
||||||
): filename
|
|
||||||
for filename in args.filenames
|
|
||||||
}
|
|
||||||
for future in concurrent.futures.as_completed(futures):
|
|
||||||
try:
|
|
||||||
for lint_message in future.result():
|
|
||||||
print(json.dumps(lint_message._asdict()), flush=True)
|
|
||||||
except Exception:
|
|
||||||
logging.critical('Failed at "%s".', futures[future])
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,4 @@
|
||||||
import argparse
|
import argparse
|
||||||
import concurrent.futures
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|
@ -8,6 +7,7 @@ import subprocess
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, NamedTuple, Optional, Pattern
|
from typing import Any, Dict, List, NamedTuple, Optional, Pattern
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -56,7 +56,6 @@ RESULTS_RE: Pattern[str] = re.compile(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def run_command(
|
def run_command(
|
||||||
args: List[str],
|
args: List[str],
|
||||||
*,
|
*,
|
||||||
|
|
@ -82,15 +81,16 @@ severities = {
|
||||||
"note": LintSeverity.ADVICE,
|
"note": LintSeverity.ADVICE,
|
||||||
}
|
}
|
||||||
|
|
||||||
def check_file(
|
|
||||||
filename: str,
|
def check_files(
|
||||||
|
filenames: List[str],
|
||||||
config: str,
|
config: str,
|
||||||
binary: str,
|
binary: str,
|
||||||
retries: int,
|
retries: int,
|
||||||
) -> List[LintMessage]:
|
) -> List[LintMessage]:
|
||||||
try:
|
try:
|
||||||
proc = run_command(
|
proc = run_command(
|
||||||
[binary, f"--config={config}", filename],
|
[binary, f"--config={config}"] + filenames,
|
||||||
extra_env={},
|
extra_env={},
|
||||||
retries=retries,
|
retries=retries,
|
||||||
)
|
)
|
||||||
|
|
@ -105,9 +105,7 @@ def check_file(
|
||||||
name="command-failed",
|
name="command-failed",
|
||||||
original=None,
|
original=None,
|
||||||
replacement=None,
|
replacement=None,
|
||||||
description=(
|
description=(f"Failed due to {err.__class__.__name__}:\n{err}"),
|
||||||
f"Failed due to {err.__class__.__name__}:\n{err}"
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
stdout = str(proc.stdout, "utf-8").strip()
|
stdout = str(proc.stdout, "utf-8").strip()
|
||||||
|
|
@ -172,27 +170,26 @@ def main() -> None:
|
||||||
stream=sys.stderr,
|
stream=sys.stderr,
|
||||||
)
|
)
|
||||||
|
|
||||||
with concurrent.futures.ThreadPoolExecutor(
|
# Use a dictionary here to preserve order. mypy cares about order,
|
||||||
max_workers=os.cpu_count(),
|
# tragically, e.g. https://github.com/python/mypy/issues/2015
|
||||||
thread_name_prefix="Thread",
|
filenames: Dict[str, bool] = {}
|
||||||
) as executor:
|
|
||||||
futures = {
|
# If a stub file exists, have mypy check it instead of the original file, in
|
||||||
executor.submit(
|
# accordance with PEP-484 (see https://www.python.org/dev/peps/pep-0484/#stub-files)
|
||||||
check_file,
|
for filename in args.filenames:
|
||||||
filename,
|
if filename.endswith(".pyi"):
|
||||||
args.config,
|
filenames[filename] = True
|
||||||
args.binary,
|
continue
|
||||||
args.retries,
|
|
||||||
): filename
|
stub_filename = filename.replace(".py", ".pyi")
|
||||||
for filename in args.filenames
|
if Path(stub_filename).exists():
|
||||||
}
|
filenames[stub_filename] = True
|
||||||
for future in concurrent.futures.as_completed(futures):
|
else:
|
||||||
try:
|
filenames[filename] = True
|
||||||
for lint_message in future.result():
|
|
||||||
print(json.dumps(lint_message._asdict()), flush=True)
|
lint_messages = check_files(list(filenames), args.config, args.binary, args.retries)
|
||||||
except Exception:
|
for lint_message in lint_messages:
|
||||||
logging.critical('Failed at "%s".', futures[future])
|
print(json.dumps(lint_message._asdict()), flush=True)
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -7,15 +7,14 @@
|
||||||
*/
|
*/
|
||||||
#include <torch/csrc/jit/jit_log.h>
|
#include <torch/csrc/jit/jit_log.h>
|
||||||
#include <torch/csrc/jit/passes/inliner.h>
|
#include <torch/csrc/jit/passes/inliner.h>
|
||||||
#include <torch/csrc/jit/runtime/operator.h>
|
|
||||||
#include <torch/csrc/jit/runtime/decomposition_registry_util.h>
|
#include <torch/csrc/jit/runtime/decomposition_registry_util.h>
|
||||||
|
#include <torch/csrc/jit/runtime/operator.h>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
|
|
||||||
|
|
||||||
const std::string decomp_funcs =
|
const std::string decomp_funcs =
|
||||||
R"(def var_decomposition(input: Tensor,
|
R"(def var_decomposition(input: Tensor,
|
||||||
dim: Optional[List[int]]=None,
|
dim: Optional[List[int]]=None,
|
||||||
correction: Optional[int]=None,
|
correction: Optional[int]=None,
|
||||||
keepdim: bool=False) -> Tensor:
|
keepdim: bool=False) -> Tensor:
|
||||||
|
|
|
||||||
|
|
@ -456,7 +456,9 @@ void ProfilingGraphExecutorImpl::runNoGradOptimizations(
|
||||||
for (const auto& passPair : getCustomPostPasses()) {
|
for (const auto& passPair : getCustomPostPasses()) {
|
||||||
passPair.first(graph);
|
passPair.first(graph);
|
||||||
}
|
}
|
||||||
GRAPH_DEBUG("After customPostPasses, before RemoveTensorTypeSpecializations \n", *graph);
|
GRAPH_DEBUG(
|
||||||
|
"After customPostPasses, before RemoveTensorTypeSpecializations \n",
|
||||||
|
*graph);
|
||||||
RemoveTensorTypeSpecializations(graph);
|
RemoveTensorTypeSpecializations(graph);
|
||||||
GRAPH_DEBUG("After RemoveTensorTypeSpecializations\n", *graph);
|
GRAPH_DEBUG("After RemoveTensorTypeSpecializations\n", *graph);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user