[BE]: enable ruff rules PLR1722 and PLW3301 (#109461)

Enables two ruff rules derived from pylint:
* PLR1722 replaces any exit() calls with sys.exit(). exit() is only designed to be used in repl contexts as may not always be imported by default. This always use the version in the sys module which is better
* PLW3301 replaces nested min / max calls with simplified versions (ie. `min(a, min(b, c))` => `min(a, b. c)`). The new version is more idiomatic and more efficient.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109461
Approved by: https://github.com/ezyang
This commit is contained in:
Aaron Gokaslan 2023-09-18 02:07:18 +00:00 committed by PyTorch MergeBot
parent a9a0f7a4ad
commit 6d725e7d66
21 changed files with 34 additions and 25 deletions

View File

@ -2,6 +2,7 @@
import os import os
import subprocess import subprocess
import sys
COMMON_TESTS = [ COMMON_TESTS = [
( (
@ -53,4 +54,4 @@ if __name__ == "__main__":
print("Reruning with traceback enabled") print("Reruning with traceback enabled")
print("Command:", command_string) print("Command:", command_string)
subprocess.run(command_args, check=False) subprocess.run(command_args, check=False)
exit(e.returncode) sys.exit(e.returncode)

View File

@ -1,6 +1,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
"""Check whether a PR has required labels.""" """Check whether a PR has required labels."""
import sys
from typing import Any from typing import Any
from github_utils import gh_delete_comment, gh_post_pr_comment from github_utils import gh_delete_comment, gh_post_pr_comment
@ -46,7 +47,7 @@ def main() -> None:
except Exception as e: except Exception as e:
pass pass
exit(0) sys.exit(0)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,6 +1,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import argparse import argparse
import os import os
import sys
from typing import Set from typing import Set
@ -87,7 +88,7 @@ if __name__ == "__main__":
torchbench.torchbench_main() torchbench.torchbench_main()
else: else:
print(f"Illegal model name? {name}") print(f"Illegal model name? {name}")
exit(-1) sys.exit(-1)
else: else:
import torchbench import torchbench

View File

@ -25,7 +25,7 @@ DOWNLOAD_COLUMNS = 70
# Don't let urllib hang up on big downloads # Don't let urllib hang up on big downloads
def signalHandler(signal, frame): def signalHandler(signal, frame):
print("Killing download...") print("Killing download...")
exit(0) sys.exit(0)
signal.signal(signal.SIGINT, signalHandler) signal.signal(signal.SIGINT, signalHandler)
@ -107,7 +107,7 @@ def downloadModel(model, args):
response = input(query) response = input(query)
if response.upper() == 'N' or not response: if response.upper() == 'N' or not response:
print("Cancelling download...") print("Cancelling download...")
exit(0) sys.exit(0)
print("Overwriting existing folder! ({filename})".format(filename=model_folder)) print("Overwriting existing folder! ({filename})".format(filename=model_folder))
deleteDirectory(model_folder) deleteDirectory(model_folder)
@ -122,7 +122,7 @@ def downloadModel(model, args):
print("Abort: {reason}".format(reason=str(e))) print("Abort: {reason}".format(reason=str(e)))
print("Cleaning up...") print("Cleaning up...")
deleteDirectory(model_folder) deleteDirectory(model_folder)
exit(0) sys.exit(0)
if args.install: if args.install:
os.symlink("{folder}/__sym_init__.py".format(folder=dir_path), os.symlink("{folder}/__sym_init__.py".format(folder=dir_path),

View File

@ -74,6 +74,8 @@ select = [
"PIE807", "PIE807",
"PIE810", "PIE810",
"PLE", "PLE",
"PLR1722", # use sys exit
"PLW3301", # nested min max
"RUF017", "RUF017",
"TRY302", "TRY302",
] ]

View File

@ -37,7 +37,7 @@ class SomeClass:
print(f"Abort: {e}") print(f"Abort: {e}")
print("Cleaning up...") print("Cleaning up...")
deleteDirectory(model_dir) deleteDirectory(model_dir)
exit(1) sys.exit(1)
def _caffe2_model_dir(self, model): def _caffe2_model_dir(self, model):
caffe2_home = os.path.expanduser("~/.caffe2") caffe2_home = os.path.expanduser("~/.caffe2")

View File

@ -418,7 +418,7 @@ def test_compose_affine(event_dims):
if transform.domain.event_dim > 1: if transform.domain.event_dim > 1:
base_dist = base_dist.expand((1,) * (transform.domain.event_dim - 1)) base_dist = base_dist.expand((1,) * (transform.domain.event_dim - 1))
dist = TransformedDistribution(base_dist, transforms) dist = TransformedDistribution(base_dist, transforms)
assert dist.support.event_dim == max(1, max(event_dims)) assert dist.support.event_dim == max(1, *event_dims)
@pytest.mark.parametrize("batch_shape", [(), (6,), (5, 4)], ids=str) @pytest.mark.parametrize("batch_shape", [(), (6,), (5, 4)], ids=str)

View File

@ -5,6 +5,7 @@ import logging
import os import os
import re import re
import subprocess import subprocess
import sys
import time import time
from enum import Enum from enum import Enum
from typing import List, NamedTuple, Optional, Pattern from typing import List, NamedTuple, Optional, Pattern
@ -131,7 +132,7 @@ if __name__ == "__main__":
), ),
) )
print(json.dumps(err_msg._asdict()), flush=True) print(json.dumps(err_msg._asdict()), flush=True)
exit(0) sys.exit(0)
with concurrent.futures.ThreadPoolExecutor( with concurrent.futures.ThreadPoolExecutor(
max_workers=os.cpu_count(), max_workers=os.cpu_count(),

View File

@ -10,6 +10,7 @@ import json
import re import re
import shlex import shlex
import subprocess import subprocess
import sys
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from enum import Enum from enum import Enum
from typing import List, NamedTuple, Optional, Set from typing import List, NamedTuple, Optional, Set
@ -182,7 +183,7 @@ def main() -> None:
description=(f"Failed due to {e.__class__.__name__}:\n{e}"), description=(f"Failed due to {e.__class__.__name__}:\n{e}"),
) )
print(json.dumps(err_msg._asdict()), flush=True) print(json.dumps(err_msg._asdict()), flush=True)
exit(0) sys.exit(0)
for filename in args.filenames: for filename in args.filenames:
for lint_message in check_bazel(filename, disallowed_checksums): for lint_message in check_bazel(filename, disallowed_checksums):

View File

@ -249,7 +249,7 @@ def main() -> None:
), ),
) )
print(json.dumps(err_msg._asdict()), flush=True) print(json.dumps(err_msg._asdict()), flush=True)
exit(0) sys.exit(0)
abs_build_dir = Path(args.build_dir).resolve() abs_build_dir = Path(args.build_dir).resolve()

View File

@ -252,7 +252,7 @@ def main() -> None:
), ),
) )
print(json.dumps(err_msg._asdict()), flush=True) print(json.dumps(err_msg._asdict()), flush=True)
exit(0) sys.exit(0)
lines = proc.stdout.decode().splitlines() lines = proc.stdout.decode().splitlines()
for line in lines: for line in lines:

View File

@ -1,5 +1,6 @@
import json import json
import subprocess import subprocess
import sys
from enum import Enum from enum import Enum
from typing import NamedTuple, Optional, Tuple from typing import NamedTuple, Optional, Tuple
@ -53,7 +54,7 @@ if __name__ == "__main__":
replacement=None, replacement=None,
description="Lintrunner is not installed, did you forget to run `make setup_lint && make lint`?", description="Lintrunner is not installed, did you forget to run `make setup_lint && make lint`?",
) )
exit(0) sys.exit(0)
curr_version = int(version_match[1]), int(version_match[2]), int(version_match[3]) curr_version = int(version_match[1]), int(version_match[2]), int(version_match[3])
min_version = (0, 10, 7) min_version = (0, 10, 7)

View File

@ -205,7 +205,7 @@ if __name__ == "__main__":
# If the host platform is not in platform_to_hash, it is unsupported. # If the host platform is not in platform_to_hash, it is unsupported.
if host_platform not in config: if host_platform not in config:
logging.error("Unsupported platform: %s/%s", HOST_PLATFORM, HOST_PLATFORM_ARCH) logging.error("Unsupported platform: %s/%s", HOST_PLATFORM, HOST_PLATFORM_ARCH)
exit(1) sys.exit(1)
url = config[host_platform]["download_url"] url = config[host_platform]["download_url"]
hash = config[host_platform]["hash"] hash = config[host_platform]["hash"]

View File

@ -3,6 +3,7 @@ import json
import logging import logging
import shutil import shutil
import subprocess import subprocess
import sys
import time import time
from enum import Enum from enum import Enum
from typing import List, NamedTuple, Optional from typing import List, NamedTuple, Optional
@ -108,7 +109,7 @@ if __name__ == "__main__":
description="shellcheck is not installed, did you forget to run `lintrunner init`?", description="shellcheck is not installed, did you forget to run `lintrunner init`?",
) )
print(json.dumps(err_msg._asdict()), flush=True) print(json.dumps(err_msg._asdict()), flush=True)
exit(0) sys.exit(0)
args = parser.parse_args() args = parser.parse_args()

View File

@ -18,7 +18,7 @@ def run_cmd(cmd: List[str]) -> None:
print(stderr) print(stderr)
if result.returncode != 0: if result.returncode != 0:
print(f"Failed to run {cmd}") print(f"Failed to run {cmd}")
exit(1) sys.exit(1)
def update_submodules() -> None: def update_submodules() -> None:

View File

@ -21,7 +21,7 @@ try:
except ModuleNotFoundError: except ModuleNotFoundError:
print("Can't import required modules, exiting") print("Can't import required modules, exiting")
exit(1) sys.exit(1)
def mocked_file(contents: Dict[Any, Any]) -> io.IOBase: def mocked_file(contents: Dict[Any, Any]) -> io.IOBase:

View File

@ -12,7 +12,7 @@ try:
from tools.testing.test_selections import calculate_shards, ShardedTest, THRESHOLD from tools.testing.test_selections import calculate_shards, ShardedTest, THRESHOLD
except ModuleNotFoundError: except ModuleNotFoundError:
print("Can't import required modules, exiting") print("Can't import required modules, exiting")
exit(1) sys.exit(1)
class TestCalculateShards(unittest.TestCase): class TestCalculateShards(unittest.TestCase):
@ -308,7 +308,7 @@ class TestCalculateShards(unittest.TestCase):
if k != "super_long_test" and k != "long_test1" if k != "super_long_test" and k != "long_test1"
] ]
sum_of_rest = sum(rest_of_tests) sum_of_rest = sum(rest_of_tests)
random_times["super_long_test"] = max(sum_of_rest / 2, max(rest_of_tests)) random_times["super_long_test"] = max(sum_of_rest / 2, *rest_of_tests)
random_times["long_test1"] = sum_of_rest - random_times["super_long_test"] random_times["long_test1"] = sum_of_rest - random_times["super_long_test"]
# An optimal sharding would look like the below, but we don't need to compute this for the test: # An optimal sharding would look like the below, but we don't need to compute this for the test:
# optimal_shards = [ # optimal_shards = [

View File

@ -500,7 +500,7 @@ socket",
ncore_per_node, ncore_per_node,
args.ncores_per_instance, args.ncores_per_instance,
) )
exit(-1) sys.exit(-1)
elif num_leftover_cores == 0: elif num_leftover_cores == 0:
# aren't any cross-node cores # aren't any cross-node cores
logger.info( logger.info(
@ -573,7 +573,7 @@ won't take effect even if it is set explicitly."
"Core binding with numactl is not available, and --disable_taskset is set. \ "Core binding with numactl is not available, and --disable_taskset is set. \
Please unset --disable_taskset to use taskset instead of numactl." Please unset --disable_taskset to use taskset instead of numactl."
) )
exit(-1) sys.exit(-1)
if not args.disable_taskset: if not args.disable_taskset:
enable_taskset = True enable_taskset = True

View File

@ -2083,7 +2083,7 @@ def sample_inputs_singular_matrix_factors(op_info, device, dtype, requires_grad=
size = [1, 5, 10] size = [1, 5, 10]
for batch, m, n in product(batches, size, size): for batch, m, n in product(batches, size, size):
for k in range(min(3, min(m, n))): for k in range(min(3, m, n)):
a = make_arg((*batch, m, k)) a = make_arg((*batch, m, k))
b = make_arg((*batch, n, k)) b = make_arg((*batch, n, k))
yield SampleInput(a, b, **kwargs) yield SampleInput(a, b, **kwargs)

View File

@ -934,11 +934,11 @@ def run_tests(argv=UNITTEST_ARGS):
if not RERUN_DISABLED_TESTS: if not RERUN_DISABLED_TESTS:
# exitcode of 5 means no tests were found, which happens since some test configs don't # exitcode of 5 means no tests were found, which happens since some test configs don't
# run tests from certain files # run tests from certain files
exit(0 if exit_code == 5 else exit_code) sys.exit(0 if exit_code == 5 else exit_code)
else: else:
# Only record the test report and always return a success code when running under rerun # Only record the test report and always return a success code when running under rerun
# disabled tests mode # disabled tests mode
exit(0) sys.exit(0)
elif TEST_SAVE_XML is not None: elif TEST_SAVE_XML is not None:
# import here so that non-CI doesn't need xmlrunner installed # import here so that non-CI doesn't need xmlrunner installed
import xmlrunner # type: ignore[import] import xmlrunner # type: ignore[import]

View File

@ -535,7 +535,7 @@ def sample_inputs_linalg_pinv_singular(
size = [0, 3, 50] size = [0, 3, 50]
for batch, m, n in product(batches, size, size): for batch, m, n in product(batches, size, size):
for k in range(min(3, min(m, n))): for k in range(min(3, m, n)):
# Note that by making the columns of `a` and `b` orthonormal we make sure that # Note that by making the columns of `a` and `b` orthonormal we make sure that
# the product matrix `a @ b.t()` has condition number 1 when restricted to its image # the product matrix `a @ b.t()` has condition number 1 when restricted to its image
a = ( a = (