mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[BE]: enable ruff rules PLR1722 and PLW3301 (#109461)
Enables two ruff rules derived from pylint: * PLR1722 replaces any exit() calls with sys.exit(). exit() is only designed to be used in repl contexts as may not always be imported by default. This always use the version in the sys module which is better * PLW3301 replaces nested min / max calls with simplified versions (ie. `min(a, min(b, c))` => `min(a, b. c)`). The new version is more idiomatic and more efficient. Pull Request resolved: https://github.com/pytorch/pytorch/pull/109461 Approved by: https://github.com/ezyang
This commit is contained in:
parent
a9a0f7a4ad
commit
6d725e7d66
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
COMMON_TESTS = [
|
||||
(
|
||||
|
|
@ -53,4 +54,4 @@ if __name__ == "__main__":
|
|||
print("Reruning with traceback enabled")
|
||||
print("Command:", command_string)
|
||||
subprocess.run(command_args, check=False)
|
||||
exit(e.returncode)
|
||||
sys.exit(e.returncode)
|
||||
|
|
|
|||
3
.github/scripts/check_labels.py
vendored
3
.github/scripts/check_labels.py
vendored
|
|
@ -1,6 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Check whether a PR has required labels."""
|
||||
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
from github_utils import gh_delete_comment, gh_post_pr_comment
|
||||
|
|
@ -46,7 +47,7 @@ def main() -> None:
|
|||
except Exception as e:
|
||||
pass
|
||||
|
||||
exit(0)
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
|
||||
from typing import Set
|
||||
|
||||
|
|
@ -87,7 +88,7 @@ if __name__ == "__main__":
|
|||
torchbench.torchbench_main()
|
||||
else:
|
||||
print(f"Illegal model name? {name}")
|
||||
exit(-1)
|
||||
sys.exit(-1)
|
||||
else:
|
||||
import torchbench
|
||||
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ DOWNLOAD_COLUMNS = 70
|
|||
# Don't let urllib hang up on big downloads
|
||||
def signalHandler(signal, frame):
|
||||
print("Killing download...")
|
||||
exit(0)
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
signal.signal(signal.SIGINT, signalHandler)
|
||||
|
|
@ -107,7 +107,7 @@ def downloadModel(model, args):
|
|||
response = input(query)
|
||||
if response.upper() == 'N' or not response:
|
||||
print("Cancelling download...")
|
||||
exit(0)
|
||||
sys.exit(0)
|
||||
print("Overwriting existing folder! ({filename})".format(filename=model_folder))
|
||||
deleteDirectory(model_folder)
|
||||
|
||||
|
|
@ -122,7 +122,7 @@ def downloadModel(model, args):
|
|||
print("Abort: {reason}".format(reason=str(e)))
|
||||
print("Cleaning up...")
|
||||
deleteDirectory(model_folder)
|
||||
exit(0)
|
||||
sys.exit(0)
|
||||
|
||||
if args.install:
|
||||
os.symlink("{folder}/__sym_init__.py".format(folder=dir_path),
|
||||
|
|
|
|||
|
|
@ -74,6 +74,8 @@ select = [
|
|||
"PIE807",
|
||||
"PIE810",
|
||||
"PLE",
|
||||
"PLR1722", # use sys exit
|
||||
"PLW3301", # nested min max
|
||||
"RUF017",
|
||||
"TRY302",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ class SomeClass:
|
|||
print(f"Abort: {e}")
|
||||
print("Cleaning up...")
|
||||
deleteDirectory(model_dir)
|
||||
exit(1)
|
||||
sys.exit(1)
|
||||
|
||||
def _caffe2_model_dir(self, model):
|
||||
caffe2_home = os.path.expanduser("~/.caffe2")
|
||||
|
|
|
|||
|
|
@ -418,7 +418,7 @@ def test_compose_affine(event_dims):
|
|||
if transform.domain.event_dim > 1:
|
||||
base_dist = base_dist.expand((1,) * (transform.domain.event_dim - 1))
|
||||
dist = TransformedDistribution(base_dist, transforms)
|
||||
assert dist.support.event_dim == max(1, max(event_dims))
|
||||
assert dist.support.event_dim == max(1, *event_dims)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_shape", [(), (6,), (5, 4)], ids=str)
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import logging
|
|||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import List, NamedTuple, Optional, Pattern
|
||||
|
|
@ -131,7 +132,7 @@ if __name__ == "__main__":
|
|||
),
|
||||
)
|
||||
print(json.dumps(err_msg._asdict()), flush=True)
|
||||
exit(0)
|
||||
sys.exit(0)
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=os.cpu_count(),
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import json
|
|||
import re
|
||||
import shlex
|
||||
import subprocess
|
||||
import sys
|
||||
import xml.etree.ElementTree as ET
|
||||
from enum import Enum
|
||||
from typing import List, NamedTuple, Optional, Set
|
||||
|
|
@ -182,7 +183,7 @@ def main() -> None:
|
|||
description=(f"Failed due to {e.__class__.__name__}:\n{e}"),
|
||||
)
|
||||
print(json.dumps(err_msg._asdict()), flush=True)
|
||||
exit(0)
|
||||
sys.exit(0)
|
||||
|
||||
for filename in args.filenames:
|
||||
for lint_message in check_bazel(filename, disallowed_checksums):
|
||||
|
|
|
|||
|
|
@ -249,7 +249,7 @@ def main() -> None:
|
|||
),
|
||||
)
|
||||
print(json.dumps(err_msg._asdict()), flush=True)
|
||||
exit(0)
|
||||
sys.exit(0)
|
||||
|
||||
abs_build_dir = Path(args.build_dir).resolve()
|
||||
|
||||
|
|
|
|||
|
|
@ -252,7 +252,7 @@ def main() -> None:
|
|||
),
|
||||
)
|
||||
print(json.dumps(err_msg._asdict()), flush=True)
|
||||
exit(0)
|
||||
sys.exit(0)
|
||||
|
||||
lines = proc.stdout.decode().splitlines()
|
||||
for line in lines:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import json
|
||||
import subprocess
|
||||
import sys
|
||||
from enum import Enum
|
||||
from typing import NamedTuple, Optional, Tuple
|
||||
|
||||
|
|
@ -53,7 +54,7 @@ if __name__ == "__main__":
|
|||
replacement=None,
|
||||
description="Lintrunner is not installed, did you forget to run `make setup_lint && make lint`?",
|
||||
)
|
||||
exit(0)
|
||||
sys.exit(0)
|
||||
|
||||
curr_version = int(version_match[1]), int(version_match[2]), int(version_match[3])
|
||||
min_version = (0, 10, 7)
|
||||
|
|
|
|||
|
|
@ -205,7 +205,7 @@ if __name__ == "__main__":
|
|||
# If the host platform is not in platform_to_hash, it is unsupported.
|
||||
if host_platform not in config:
|
||||
logging.error("Unsupported platform: %s/%s", HOST_PLATFORM, HOST_PLATFORM_ARCH)
|
||||
exit(1)
|
||||
sys.exit(1)
|
||||
|
||||
url = config[host_platform]["download_url"]
|
||||
hash = config[host_platform]["hash"]
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import json
|
|||
import logging
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import List, NamedTuple, Optional
|
||||
|
|
@ -108,7 +109,7 @@ if __name__ == "__main__":
|
|||
description="shellcheck is not installed, did you forget to run `lintrunner init`?",
|
||||
)
|
||||
print(json.dumps(err_msg._asdict()), flush=True)
|
||||
exit(0)
|
||||
sys.exit(0)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ def run_cmd(cmd: List[str]) -> None:
|
|||
print(stderr)
|
||||
if result.returncode != 0:
|
||||
print(f"Failed to run {cmd}")
|
||||
exit(1)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def update_submodules() -> None:
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ try:
|
|||
|
||||
except ModuleNotFoundError:
|
||||
print("Can't import required modules, exiting")
|
||||
exit(1)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def mocked_file(contents: Dict[Any, Any]) -> io.IOBase:
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ try:
|
|||
from tools.testing.test_selections import calculate_shards, ShardedTest, THRESHOLD
|
||||
except ModuleNotFoundError:
|
||||
print("Can't import required modules, exiting")
|
||||
exit(1)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
class TestCalculateShards(unittest.TestCase):
|
||||
|
|
@ -308,7 +308,7 @@ class TestCalculateShards(unittest.TestCase):
|
|||
if k != "super_long_test" and k != "long_test1"
|
||||
]
|
||||
sum_of_rest = sum(rest_of_tests)
|
||||
random_times["super_long_test"] = max(sum_of_rest / 2, max(rest_of_tests))
|
||||
random_times["super_long_test"] = max(sum_of_rest / 2, *rest_of_tests)
|
||||
random_times["long_test1"] = sum_of_rest - random_times["super_long_test"]
|
||||
# An optimal sharding would look like the below, but we don't need to compute this for the test:
|
||||
# optimal_shards = [
|
||||
|
|
|
|||
|
|
@ -500,7 +500,7 @@ socket",
|
|||
ncore_per_node,
|
||||
args.ncores_per_instance,
|
||||
)
|
||||
exit(-1)
|
||||
sys.exit(-1)
|
||||
elif num_leftover_cores == 0:
|
||||
# aren't any cross-node cores
|
||||
logger.info(
|
||||
|
|
@ -573,7 +573,7 @@ won't take effect even if it is set explicitly."
|
|||
"Core binding with numactl is not available, and --disable_taskset is set. \
|
||||
Please unset --disable_taskset to use taskset instead of numactl."
|
||||
)
|
||||
exit(-1)
|
||||
sys.exit(-1)
|
||||
|
||||
if not args.disable_taskset:
|
||||
enable_taskset = True
|
||||
|
|
|
|||
|
|
@ -2083,7 +2083,7 @@ def sample_inputs_singular_matrix_factors(op_info, device, dtype, requires_grad=
|
|||
size = [1, 5, 10]
|
||||
|
||||
for batch, m, n in product(batches, size, size):
|
||||
for k in range(min(3, min(m, n))):
|
||||
for k in range(min(3, m, n)):
|
||||
a = make_arg((*batch, m, k))
|
||||
b = make_arg((*batch, n, k))
|
||||
yield SampleInput(a, b, **kwargs)
|
||||
|
|
|
|||
|
|
@ -934,11 +934,11 @@ def run_tests(argv=UNITTEST_ARGS):
|
|||
if not RERUN_DISABLED_TESTS:
|
||||
# exitcode of 5 means no tests were found, which happens since some test configs don't
|
||||
# run tests from certain files
|
||||
exit(0 if exit_code == 5 else exit_code)
|
||||
sys.exit(0 if exit_code == 5 else exit_code)
|
||||
else:
|
||||
# Only record the test report and always return a success code when running under rerun
|
||||
# disabled tests mode
|
||||
exit(0)
|
||||
sys.exit(0)
|
||||
elif TEST_SAVE_XML is not None:
|
||||
# import here so that non-CI doesn't need xmlrunner installed
|
||||
import xmlrunner # type: ignore[import]
|
||||
|
|
|
|||
|
|
@ -535,7 +535,7 @@ def sample_inputs_linalg_pinv_singular(
|
|||
size = [0, 3, 50]
|
||||
|
||||
for batch, m, n in product(batches, size, size):
|
||||
for k in range(min(3, min(m, n))):
|
||||
for k in range(min(3, m, n)):
|
||||
# Note that by making the columns of `a` and `b` orthonormal we make sure that
|
||||
# the product matrix `a @ b.t()` has condition number 1 when restricted to its image
|
||||
a = (
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user