[BE]: enable ruff rules PLR1722 and PLW3301 (#109461)

Enables two ruff rules derived from pylint:
* PLR1722 replaces any exit() calls with sys.exit(). exit() is only designed to be used in repl contexts as may not always be imported by default. This always use the version in the sys module which is better
* PLW3301 replaces nested min / max calls with simplified versions (ie. `min(a, min(b, c))` => `min(a, b. c)`). The new version is more idiomatic and more efficient.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109461
Approved by: https://github.com/ezyang
This commit is contained in:
Aaron Gokaslan 2023-09-18 02:07:18 +00:00 committed by PyTorch MergeBot
parent a9a0f7a4ad
commit 6d725e7d66
21 changed files with 34 additions and 25 deletions

View File

@ -2,6 +2,7 @@
import os
import subprocess
import sys
COMMON_TESTS = [
(
@ -53,4 +54,4 @@ if __name__ == "__main__":
print("Reruning with traceback enabled")
print("Command:", command_string)
subprocess.run(command_args, check=False)
exit(e.returncode)
sys.exit(e.returncode)

View File

@ -1,6 +1,7 @@
#!/usr/bin/env python3
"""Check whether a PR has required labels."""
import sys
from typing import Any
from github_utils import gh_delete_comment, gh_post_pr_comment
@ -46,7 +47,7 @@ def main() -> None:
except Exception as e:
pass
exit(0)
sys.exit(0)
if __name__ == "__main__":

View File

@ -1,6 +1,7 @@
#!/usr/bin/env python3
import argparse
import os
import sys
from typing import Set
@ -87,7 +88,7 @@ if __name__ == "__main__":
torchbench.torchbench_main()
else:
print(f"Illegal model name? {name}")
exit(-1)
sys.exit(-1)
else:
import torchbench

View File

@ -25,7 +25,7 @@ DOWNLOAD_COLUMNS = 70
# Don't let urllib hang up on big downloads
def signalHandler(signal, frame):
print("Killing download...")
exit(0)
sys.exit(0)
signal.signal(signal.SIGINT, signalHandler)
@ -107,7 +107,7 @@ def downloadModel(model, args):
response = input(query)
if response.upper() == 'N' or not response:
print("Cancelling download...")
exit(0)
sys.exit(0)
print("Overwriting existing folder! ({filename})".format(filename=model_folder))
deleteDirectory(model_folder)
@ -122,7 +122,7 @@ def downloadModel(model, args):
print("Abort: {reason}".format(reason=str(e)))
print("Cleaning up...")
deleteDirectory(model_folder)
exit(0)
sys.exit(0)
if args.install:
os.symlink("{folder}/__sym_init__.py".format(folder=dir_path),

View File

@ -74,6 +74,8 @@ select = [
"PIE807",
"PIE810",
"PLE",
"PLR1722", # use sys exit
"PLW3301", # nested min max
"RUF017",
"TRY302",
]

View File

@ -37,7 +37,7 @@ class SomeClass:
print(f"Abort: {e}")
print("Cleaning up...")
deleteDirectory(model_dir)
exit(1)
sys.exit(1)
def _caffe2_model_dir(self, model):
caffe2_home = os.path.expanduser("~/.caffe2")

View File

@ -418,7 +418,7 @@ def test_compose_affine(event_dims):
if transform.domain.event_dim > 1:
base_dist = base_dist.expand((1,) * (transform.domain.event_dim - 1))
dist = TransformedDistribution(base_dist, transforms)
assert dist.support.event_dim == max(1, max(event_dims))
assert dist.support.event_dim == max(1, *event_dims)
@pytest.mark.parametrize("batch_shape", [(), (6,), (5, 4)], ids=str)

View File

@ -5,6 +5,7 @@ import logging
import os
import re
import subprocess
import sys
import time
from enum import Enum
from typing import List, NamedTuple, Optional, Pattern
@ -131,7 +132,7 @@ if __name__ == "__main__":
),
)
print(json.dumps(err_msg._asdict()), flush=True)
exit(0)
sys.exit(0)
with concurrent.futures.ThreadPoolExecutor(
max_workers=os.cpu_count(),

View File

@ -10,6 +10,7 @@ import json
import re
import shlex
import subprocess
import sys
import xml.etree.ElementTree as ET
from enum import Enum
from typing import List, NamedTuple, Optional, Set
@ -182,7 +183,7 @@ def main() -> None:
description=(f"Failed due to {e.__class__.__name__}:\n{e}"),
)
print(json.dumps(err_msg._asdict()), flush=True)
exit(0)
sys.exit(0)
for filename in args.filenames:
for lint_message in check_bazel(filename, disallowed_checksums):

View File

@ -249,7 +249,7 @@ def main() -> None:
),
)
print(json.dumps(err_msg._asdict()), flush=True)
exit(0)
sys.exit(0)
abs_build_dir = Path(args.build_dir).resolve()

View File

@ -252,7 +252,7 @@ def main() -> None:
),
)
print(json.dumps(err_msg._asdict()), flush=True)
exit(0)
sys.exit(0)
lines = proc.stdout.decode().splitlines()
for line in lines:

View File

@ -1,5 +1,6 @@
import json
import subprocess
import sys
from enum import Enum
from typing import NamedTuple, Optional, Tuple
@ -53,7 +54,7 @@ if __name__ == "__main__":
replacement=None,
description="Lintrunner is not installed, did you forget to run `make setup_lint && make lint`?",
)
exit(0)
sys.exit(0)
curr_version = int(version_match[1]), int(version_match[2]), int(version_match[3])
min_version = (0, 10, 7)

View File

@ -205,7 +205,7 @@ if __name__ == "__main__":
# If the host platform is not in platform_to_hash, it is unsupported.
if host_platform not in config:
logging.error("Unsupported platform: %s/%s", HOST_PLATFORM, HOST_PLATFORM_ARCH)
exit(1)
sys.exit(1)
url = config[host_platform]["download_url"]
hash = config[host_platform]["hash"]

View File

@ -3,6 +3,7 @@ import json
import logging
import shutil
import subprocess
import sys
import time
from enum import Enum
from typing import List, NamedTuple, Optional
@ -108,7 +109,7 @@ if __name__ == "__main__":
description="shellcheck is not installed, did you forget to run `lintrunner init`?",
)
print(json.dumps(err_msg._asdict()), flush=True)
exit(0)
sys.exit(0)
args = parser.parse_args()

View File

@ -18,7 +18,7 @@ def run_cmd(cmd: List[str]) -> None:
print(stderr)
if result.returncode != 0:
print(f"Failed to run {cmd}")
exit(1)
sys.exit(1)
def update_submodules() -> None:

View File

@ -21,7 +21,7 @@ try:
except ModuleNotFoundError:
print("Can't import required modules, exiting")
exit(1)
sys.exit(1)
def mocked_file(contents: Dict[Any, Any]) -> io.IOBase:

View File

@ -12,7 +12,7 @@ try:
from tools.testing.test_selections import calculate_shards, ShardedTest, THRESHOLD
except ModuleNotFoundError:
print("Can't import required modules, exiting")
exit(1)
sys.exit(1)
class TestCalculateShards(unittest.TestCase):
@ -308,7 +308,7 @@ class TestCalculateShards(unittest.TestCase):
if k != "super_long_test" and k != "long_test1"
]
sum_of_rest = sum(rest_of_tests)
random_times["super_long_test"] = max(sum_of_rest / 2, max(rest_of_tests))
random_times["super_long_test"] = max(sum_of_rest / 2, *rest_of_tests)
random_times["long_test1"] = sum_of_rest - random_times["super_long_test"]
# An optimal sharding would look like the below, but we don't need to compute this for the test:
# optimal_shards = [

View File

@ -500,7 +500,7 @@ socket",
ncore_per_node,
args.ncores_per_instance,
)
exit(-1)
sys.exit(-1)
elif num_leftover_cores == 0:
# aren't any cross-node cores
logger.info(
@ -573,7 +573,7 @@ won't take effect even if it is set explicitly."
"Core binding with numactl is not available, and --disable_taskset is set. \
Please unset --disable_taskset to use taskset instead of numactl."
)
exit(-1)
sys.exit(-1)
if not args.disable_taskset:
enable_taskset = True

View File

@ -2083,7 +2083,7 @@ def sample_inputs_singular_matrix_factors(op_info, device, dtype, requires_grad=
size = [1, 5, 10]
for batch, m, n in product(batches, size, size):
for k in range(min(3, min(m, n))):
for k in range(min(3, m, n)):
a = make_arg((*batch, m, k))
b = make_arg((*batch, n, k))
yield SampleInput(a, b, **kwargs)

View File

@ -934,11 +934,11 @@ def run_tests(argv=UNITTEST_ARGS):
if not RERUN_DISABLED_TESTS:
# exitcode of 5 means no tests were found, which happens since some test configs don't
# run tests from certain files
exit(0 if exit_code == 5 else exit_code)
sys.exit(0 if exit_code == 5 else exit_code)
else:
# Only record the test report and always return a success code when running under rerun
# disabled tests mode
exit(0)
sys.exit(0)
elif TEST_SAVE_XML is not None:
# import here so that non-CI doesn't need xmlrunner installed
import xmlrunner # type: ignore[import]

View File

@ -535,7 +535,7 @@ def sample_inputs_linalg_pinv_singular(
size = [0, 3, 50]
for batch, m, n in product(batches, size, size):
for k in range(min(3, min(m, n))):
for k in range(min(3, m, n)):
# Note that by making the columns of `a` and `b` orthonormal we make sure that
# the product matrix `a @ b.t()` has condition number 1 when restricted to its image
a = (