[BE] Enable ruff's UP rules and autoformat tools and scripts (#105428)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105428
Approved by: https://github.com/albanD, https://github.com/soulitzer, https://github.com/malfet
This commit is contained in:
Justin Chu 2023-07-18 21:12:48 +00:00 committed by PyTorch MergeBot
parent 5666d20bb8
commit 14d87bb5ff
48 changed files with 125 additions and 146 deletions

View File

@ -88,9 +88,9 @@ def sign_certificate_request(path, csr_cert, ca_cert, private_ca_key):
ca_key = genrsa(temp_dir + "/ca.key")
ca_cert = create_cert(temp_dir + "/ca.pem", u"US", u"New York", u"New York", u"Gloo Certificate Authority", ca_key)
ca_cert = create_cert(temp_dir + "/ca.pem", "US", "New York", "New York", "Gloo Certificate Authority", ca_key)
pkey = genrsa(temp_dir + "/pkey.key")
csr = create_req(temp_dir + "/csr.csr", u"US", u"California", u"San Francisco", u"Gloo Testing Company", pkey)
csr = create_req(temp_dir + "/csr.csr", "US", "California", "San Francisco", "Gloo Testing Company", pkey)
cert = sign_certificate_request(temp_dir + "/cert.pem", csr, ca_cert, ca_key)

View File

@ -19,7 +19,7 @@ if 'cpu' in test_name:
elif 'gpu' in test_name:
backend = 'gpu'
data_file_path = '../{}_runtime.json'.format(backend)
data_file_path = f'../{backend}_runtime.json'
with open(data_file_path) as data_file:
data = json.load(data_file)
@ -69,7 +69,7 @@ else:
print("z-value < 3, no perf regression detected.")
if args.update:
print("We will use these numbers as new baseline.")
new_data_file_path = '../new_{}_runtime.json'.format(backend)
new_data_file_path = f'../new_{backend}_runtime.json'
with open(new_data_file_path) as new_data_file:
new_data = json.load(new_data_file)
new_data[test_name] = {}

View File

@ -5,7 +5,7 @@ import cimodel.data.binary_build_data as binary_build_data
import cimodel.lib.conf_tree as conf_tree
import cimodel.lib.miniutils as miniutils
class Conf(object):
class Conf:
def __init__(self, os, gpu_version, pydistro, parms, smoke, libtorch_variant, gcc_config_variant, libtorch_config_variant):
self.os = os

View File

@ -143,7 +143,7 @@ class Conf:
# TODO This is a hack to special case some configs just for the workflow list
class HiddenConf(object):
class HiddenConf:
def __init__(self, name, parent_build=None, filters=None):
self.name = name
self.parent_build = parent_build
@ -160,7 +160,7 @@ class HiddenConf(object):
def gen_build_name(self, _):
return self.name
class DocPushConf(object):
class DocPushConf:
def __init__(self, name, parent_build=None, branch="master"):
self.name = name
self.parent_build = parent_build

View File

@ -18,7 +18,7 @@ import cimodel.lib.miniutils as miniutils
import cimodel.lib.miniyaml as miniyaml
class File(object):
class File:
"""
Verbatim copy the contents of a file into config.yml
"""
@ -57,7 +57,7 @@ def horizontal_rule():
return "".join("#" * 78)
class Header(object):
class Header:
def __init__(self, title, summary=None):
self.title = title
self.summary_lines = summary or []

View File

@ -17,7 +17,7 @@ EXPECTED_GROUP = (
def should_check(filename: Path) -> bool:
with open(filename, "r") as f:
with open(filename) as f:
content = f.read()
data = yaml.safe_load(content)
@ -37,7 +37,7 @@ if __name__ == "__main__":
files = [f for f in files if should_check(f)]
names = set()
for filename in files:
with open(filename, "r") as f:
with open(filename) as f:
data = yaml.safe_load(f)
name = data.get("name")

View File

@ -44,7 +44,7 @@ def load_json_file(file_path: Path) -> Any:
"""
Returns the deserialized json object
"""
with open(file_path, "r") as f:
with open(file_path) as f:
return json.load(f)

View File

@ -319,7 +319,7 @@ def process_jobs(
try:
# The job name from github is in the PLATFORM / JOB (CONFIG) format, so breaking
# it into its two components first
current_platform, _ = [n.strip() for n in job_name.split(JOB_NAME_SEP, 1) if n]
current_platform, _ = (n.strip() for n in job_name.split(JOB_NAME_SEP, 1) if n)
except ValueError as error:
warnings.warn(f"Invalid job name {job_name}, returning")
return test_matrix

View File

@ -50,7 +50,7 @@ def get_tag() -> str:
def get_base_version() -> str:
root = get_pytorch_root()
dirty_version = open(root / "version.txt", "r").read().strip()
dirty_version = open(root / "version.txt").read().strip()
# Strips trailing a0 from version.txt, not too sure why it's there in the
# first place
return re.sub(LEGACY_BASE_VERSION_SUFFIX_PATTERN, "", dirty_version)

View File

@ -51,7 +51,7 @@ def get_last_page_num_from_header(header: Any) -> int:
)
@lru_cache()
@lru_cache
def gh_get_labels(org: str, repo: str) -> List[str]:
prefix = f"https://api.github.com/repos/{org}/{repo}/labels?per_page=100"
header, info = request_for_labels(prefix + "&page=1")

View File

@ -26,7 +26,7 @@ def fn(base: str) -> str:
return str(base / Path("aten/src/ATen/native/native_functions.yaml"))
with open(Path(__file__).parent.parent.parent / fn("."), "r") as f:
with open(Path(__file__).parent.parent.parent / fn(".")) as f:
contents = f.read()
yaml = ruamel.yaml.YAML() # type: ignore[attr-defined]

View File

@ -129,7 +129,7 @@ def extract_models_from_pr(
model_list = []
userbenchmark_list = []
pr_list = []
with open(prbody_file, "r") as pf:
with open(prbody_file) as pf:
lines = (x.strip() for x in pf.read().splitlines())
magic_lines = list(filter(lambda x: x.startswith(MAGIC_PREFIX), lines))
if magic_lines:
@ -157,7 +157,7 @@ def extract_models_from_pr(
def find_torchbench_branch(prbody_file: str) -> str:
branch_name: str = ""
with open(prbody_file, "r") as pf:
with open(prbody_file) as pf:
lines = (x.strip() for x in pf.read().splitlines())
magic_lines = list(
filter(lambda x: x.startswith(MAGIC_TORCHBENCH_PREFIX), lines)

View File

@ -15,7 +15,7 @@ from trymerge import GitHubPR
def mock_parse_args() -> object:
class Object(object):
class Object:
def __init__(self) -> None:
self.pr_num = 76123

View File

@ -114,7 +114,7 @@ def mocked_rockset_results(head_sha: str, merge_base: str, num_retries: int = 3)
def mock_parse_args(revert: bool = False, force: bool = False) -> Any:
class Object(object):
class Object:
def __init__(self) -> None:
self.revert = revert
self.force = force

View File

@ -1628,10 +1628,8 @@ def validate_revert(
allowed_reverters.append("CONTRIBUTOR")
if author_association not in allowed_reverters:
raise PostCommentError(
(
f"Will not revert as @{author_login} is not one of "
f"[{', '.join(allowed_reverters)}], but instead is {author_association}."
)
f"Will not revert as @{author_login} is not one of "
f"[{', '.join(allowed_reverters)}], but instead is {author_association}."
)
skip_internal_checks = can_skip_internal_checks(pr, comment_id)

View File

@ -17,7 +17,7 @@ def has_label(labels: List[str], pattern: Pattern[str] = CIFLOW_LABEL) -> bool:
return len(list(filter(pattern.match, labels))) > 0
class TryMergeExplainer(object):
class TryMergeExplainer:
force: bool
labels: List[str]
pr_num: int

View File

@ -8,7 +8,7 @@ import shutil
# Module caffe2...caffe2.python.control_test
def insert(originalfile, first_line, description):
with open(originalfile, 'r') as f:
with open(originalfile) as f:
f1 = f.readline()
if(f1.find(first_line) < 0):
docs = first_line + description + f1
@ -30,7 +30,7 @@ for root, dirs, files in os.walk("."):
for file in files:
if (file.endswith(".py") and not file.endswith("_test.py") and not file.endswith("__.py")):
filepath = os.path.join(root, file)
print(("filepath: " + filepath))
print("filepath: " + filepath)
directory = os.path.dirname(filepath)[2:]
directory = directory.replace("/", ".")
print("directory: " + directory)

View File

@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
#
# PyTorch documentation build configuration file, created by
# sphinx-quickstart on Fri Dec 23 13:31:47 2016.
@ -99,7 +98,7 @@ exhale_args = {
############################################################################
# Main library page layout example configuration. #
############################################################################
"afterTitleDescription": textwrap.dedent(u'''
"afterTitleDescription": textwrap.dedent('''
Welcome to the developer reference for the PyTorch C++ API.
'''),
}

View File

@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
#
# PyTorch documentation build configuration file, created by
# sphinx-quickstart on Fri Dec 23 13:31:47 2016.
@ -624,7 +623,7 @@ def replace(Klass):
anchor = ref_anchor[1]
txt = node.parent.astext()
if txt == anchor or txt == anchor.split('.')[-1]:
self.body.append('<p id="{}"/>'.format(ref_anchor[1]))
self.body.append(f'<p id="{ref_anchor[1]}"/>')
return old_call(self, node)
Klass.visit_reference = visit_reference

View File

@ -31,7 +31,7 @@ def generate_example_rst(example_case: ExportCase):
if isinstance(model, torch.nn.Module)
else inspect.getfile(model)
)
with open(source_file, "r") as file:
with open(source_file) as file:
source_code = file.read()
source_code = re.sub(r"from torch\._export\.db\.case import .*\n", "", source_code)
source_code = re.sub(r"@export_case\((.|\n)*?\)\n", "", source_code)
@ -114,7 +114,7 @@ def generate_index_rst(example_cases, tag_to_modules, support_level_to_modules):
tag_names = "\n ".join(t for t in tag_to_modules.keys())
with open(os.path.join(PWD, "blurb.txt"), "r") as file:
with open(os.path.join(PWD, "blurb.txt")) as file:
blurb = file.read()
# Generate contents of the .rst file

View File

@ -323,7 +323,7 @@ cmake_python_include_dir = sysconfig.get_path("include")
package_name = os.getenv('TORCH_PACKAGE_NAME', 'torch')
package_type = os.getenv('PACKAGE_TYPE', 'wheel')
version = get_torch_version()
report("Building wheel {}-{}".format(package_name, version))
report(f"Building wheel {package_name}-{version}")
cmake = CMake()
@ -361,7 +361,7 @@ def check_submodules():
start = time.time()
subprocess.check_call(["git", "submodule", "update", "--init", "--recursive"], cwd=cwd)
end = time.time()
print(' --- Submodule initialization took {:.2f} sec'.format(end - start))
print(f' --- Submodule initialization took {end - start:.2f} sec')
except Exception:
print(' --- Submodule initalization failed')
print('Please run:\n\tgit submodule update --init --recursive')
@ -616,16 +616,16 @@ class build_ext(setuptools.command.build_ext.build_ext):
continue
fullname = self.get_ext_fullname(ext.name)
filename = self.get_ext_filename(fullname)
report("\nCopying extension {}".format(ext.name))
report(f"\nCopying extension {ext.name}")
relative_site_packages = sysconfig.get_path('purelib').replace(sysconfig.get_path('data'), '').lstrip(os.path.sep)
src = os.path.join("torch", relative_site_packages, filename)
if not os.path.exists(src):
report("{} does not exist".format(src))
report(f"{src} does not exist")
del self.extensions[i]
else:
dst = os.path.join(os.path.realpath(self.build_lib), filename)
report("Copying {} from {} to {}".format(ext.name, src, dst))
report(f"Copying {ext.name} from {src} to {dst}")
dst_dir = os.path.dirname(dst)
if not os.path.exists(dst_dir):
os.makedirs(dst_dir)
@ -642,7 +642,7 @@ class build_ext(setuptools.command.build_ext.build_ext):
src = os.path.join(os.path.dirname(filename), "functorch" + fileext)
dst = os.path.join(os.path.realpath(self.build_lib), filename)
if os.path.exists(src):
report("Copying {} from {} to {}".format(ext.name, src, dst))
report(f"Copying {ext.name} from {src} to {dst}")
dst_dir = os.path.dirname(dst)
if not os.path.exists(dst_dir):
os.makedirs(dst_dir)
@ -658,7 +658,7 @@ class build_ext(setuptools.command.build_ext.build_ext):
src = os.path.join(os.path.dirname(filename), "nvfuser" + fileext)
dst = os.path.join(os.path.realpath(self.build_lib), filename)
if os.path.exists(src):
report("Copying {} from {} to {}".format(ext.name, src, dst))
report(f"Copying {ext.name} from {src} to {dst}")
dst_dir = os.path.dirname(dst)
if not os.path.exists(dst_dir):
os.makedirs(dst_dir)
@ -670,7 +670,7 @@ class build_ext(setuptools.command.build_ext.build_ext):
def get_outputs(self):
outputs = setuptools.command.build_ext.build_ext.get_outputs(self)
outputs.append(os.path.join(self.build_lib, "caffe2"))
report("setup.py::get_outputs returning {}".format(outputs))
report(f"setup.py::get_outputs returning {outputs}")
return outputs
def create_compile_commands(self):
@ -694,13 +694,13 @@ class build_ext(setuptools.command.build_ext.build_ext):
new_contents = json.dumps(all_commands, indent=2)
contents = ''
if os.path.exists('compile_commands.json'):
with open('compile_commands.json', 'r') as f:
with open('compile_commands.json') as f:
contents = f.read()
if contents != new_contents:
with open('compile_commands.json', 'w') as f:
f.write(new_contents)
class concat_license_files():
class concat_license_files:
"""Merge LICENSE and LICENSES_BUNDLED.txt as a context manager
LICENSE is the main PyTorch license, LICENSES_BUNDLED.txt is auto-generated
@ -723,7 +723,7 @@ class concat_license_files():
finally:
sys.path = old_path
with open(self.f1, 'r') as f1:
with open(self.f1) as f1:
self.bsd_text = f1.read()
with open(self.f1, 'a') as f1:
@ -771,7 +771,7 @@ class clean(setuptools.Command):
def run(self):
import glob
import re
with open('.gitignore', 'r') as f:
with open('.gitignore') as f:
ignores = f.read()
pat = re.compile(r'^#( BEGIN NOT-CLEAN-FILES )?')
for wildcard in filter(None, ignores.split('\n')):
@ -934,31 +934,31 @@ def configure_extension_build():
if cmake_cache_vars['BUILD_CAFFE2']:
extensions.append(
Extension(
name=str('caffe2.python.caffe2_pybind11_state'),
name='caffe2.python.caffe2_pybind11_state',
sources=[]),
)
if cmake_cache_vars['USE_CUDA']:
extensions.append(
Extension(
name=str('caffe2.python.caffe2_pybind11_state_gpu'),
name='caffe2.python.caffe2_pybind11_state_gpu',
sources=[]),
)
if cmake_cache_vars['USE_ROCM']:
extensions.append(
Extension(
name=str('caffe2.python.caffe2_pybind11_state_hip'),
name='caffe2.python.caffe2_pybind11_state_hip',
sources=[]),
)
if cmake_cache_vars['BUILD_FUNCTORCH']:
extensions.append(
Extension(
name=str('functorch._C'),
name='functorch._C',
sources=[]),
)
if cmake_cache_vars['BUILD_NVFUSER']:
extensions.append(
Extension(
name=str('nvfuser._C'),
name='nvfuser._C',
sources=[]),
)
@ -1271,7 +1271,7 @@ def main():
download_url='https://github.com/pytorch/pytorch/tags',
author='PyTorch Team',
author_email='packages@pytorch.org',
python_requires='>={}'.format(python_min_version_str),
python_requires=f'>={python_min_version_str}',
# PyPI package information.
classifiers=[
'Development Status :: 5 - Production/Stable',
@ -1287,7 +1287,7 @@ def main():
'Topic :: Software Development :: Libraries :: Python Modules',
'Programming Language :: C++',
'Programming Language :: Python :: 3',
] + ['Programming Language :: Python :: 3.{}'.format(i) for i in range(python_min_version[1], version_range_max)],
] + [f'Programming Language :: Python :: 3.{i}' for i in range(python_min_version[1], version_range_max)],
license='BSD-3',
keywords='pytorch, machine learning',
)

View File

@ -140,7 +140,7 @@ def is_hip_clang() -> bool:
hip_path = os.getenv("HIP_PATH", "/opt/rocm/hip")
with open(hip_path + "/lib/.hipInfo") as f:
return "HIP_COMPILER=clang" in f.read()
except IOError:
except OSError:
return False
@ -149,7 +149,7 @@ if is_hip_clang():
gloo_cmake_file = "third_party/gloo/cmake/Hip.cmake"
do_write = False
if os.path.exists(gloo_cmake_file):
with open(gloo_cmake_file, "r") as sources:
with open(gloo_cmake_file) as sources:
lines = sources.readlines()
newlines = [line.replace(" hip_hcc ", " amdhip64 ") for line in lines]
if lines == newlines:
@ -163,7 +163,7 @@ if is_hip_clang():
gloo_cmake_file = "third_party/gloo/cmake/Modules/Findrccl.cmake"
if os.path.exists(gloo_cmake_file):
do_write = False
with open(gloo_cmake_file, "r") as sources:
with open(gloo_cmake_file) as sources:
lines = sources.readlines()
newlines = [line.replace("RCCL_LIBRARY", "RCCL_LIB_PATH") for line in lines]
if lines == newlines:
@ -179,7 +179,7 @@ if is_hip_clang():
gloo_cmake_file = "third_party/gloo/cmake/Dependencies.cmake"
do_write = False
if os.path.exists(gloo_cmake_file):
with open(gloo_cmake_file, "r") as sources:
with open(gloo_cmake_file) as sources:
lines = sources.readlines()
newlines = [line.replace("HIP_HCC_FLAGS", "HIP_CLANG_FLAGS") for line in lines]
if lines == newlines:

View File

@ -553,7 +553,7 @@ def load_deprecated_signatures(
# find matching original signatures for each deprecated signature
results: List[PythonSignatureNativeFunctionPair] = []
with open(deprecated_yaml_path, "r") as f:
with open(deprecated_yaml_path) as f:
deprecated_defs = yaml.load(f, Loader=YamlLoader)
for deprecated in deprecated_defs:
@ -873,7 +873,7 @@ def method_impl(
name=name,
pycname=pycname,
method_header=method_header,
max_args=max((o.signature.arguments_count() for o in overloads)),
max_args=max(o.signature.arguments_count() for o in overloads),
signatures=signatures,
traceable=traceable,
check_has_torch_function=gen_has_torch_function_check(
@ -1255,10 +1255,7 @@ def emit_single_dispatch(
# dispatch lambda signature
name = cpp.name(f.func)
lambda_formals = ", ".join(
(
f"{a.type_str} {a.name}"
for a in dispatch_lambda_args(ps, f, symint=symint)
)
f"{a.type_str} {a.name}" for a in dispatch_lambda_args(ps, f, symint=symint)
)
lambda_return = dispatch_lambda_return_str(f)

View File

@ -98,7 +98,7 @@ def load_derivatives(
global _GLOBAL_LOAD_DERIVATIVE_CACHE
key = (derivatives_yaml_path, native_yaml_path)
if key not in _GLOBAL_LOAD_DERIVATIVE_CACHE:
with open(derivatives_yaml_path, "r") as f:
with open(derivatives_yaml_path) as f:
definitions = yaml.load(f, Loader=YamlLoader)
funcs = parse_native_yaml(native_yaml_path, tags_yaml_path).native_functions

View File

@ -24,7 +24,7 @@ def canonical_name(opname: str) -> str:
def load_op_dep_graph(fname: str) -> DepGraph:
with open(fname, "r") as stream:
with open(fname) as stream:
result = defaultdict(set)
for op in yaml.safe_load(stream):
op_name = canonical_name(op["name"])
@ -36,7 +36,7 @@ def load_op_dep_graph(fname: str) -> DepGraph:
def load_root_ops(fname: str) -> List[str]:
result = []
with open(fname, "r") as stream:
with open(fname) as stream:
for op in yaml.safe_load(stream):
result.append(canonical_name(op))
return result

View File

@ -79,7 +79,7 @@ SupportedMobileModelCheckerRegistry register_model_versions;
supported_hashes = ""
for md5 in md5_hashes:
supported_hashes += '"{}",\n'.format(md5)
supported_hashes += f'"{md5}",\n'
with open(
os.path.join(output_dir, "SupportedMobileModelsRegistration.cpp"), "wb"
) as out_file:

View File

@ -1,6 +1,6 @@
import setuptools # type: ignore[import]
with open("README.md", "r", encoding="utf-8") as fh:
with open("README.md", encoding="utf-8") as fh:
long_description = fh.read()
setuptools.setup(

View File

@ -32,16 +32,16 @@ def report_download_progress(
def download(destination_path: str, resource: str, quiet: bool) -> None:
if os.path.exists(destination_path):
if not quiet:
print("{} already exists, skipping ...".format(destination_path))
print(f"{destination_path} already exists, skipping ...")
else:
for mirror in MIRRORS:
url = mirror + resource
print("Downloading {} ...".format(url))
print(f"Downloading {url} ...")
try:
hook = None if quiet else report_download_progress
urlretrieve(url, destination_path, reporthook=hook)
except (URLError, ConnectionError) as e:
print("Failed to download (trying next):\n{}".format(e))
print(f"Failed to download (trying next):\n{e}")
continue
finally:
if not quiet:
@ -56,13 +56,13 @@ def unzip(zipped_path: str, quiet: bool) -> None:
unzipped_path = os.path.splitext(zipped_path)[0]
if os.path.exists(unzipped_path):
if not quiet:
print("{} already exists, skipping ... ".format(unzipped_path))
print(f"{unzipped_path} already exists, skipping ... ")
return
with gzip.open(zipped_path, "rb") as zipped_file:
with open(unzipped_path, "wb") as unzipped_file:
unzipped_file.write(zipped_file.read())
if not quiet:
print("Unzipped {} ...".format(zipped_path))
print(f"Unzipped {zipped_path} ...")
def main() -> None:

View File

@ -74,7 +74,7 @@ class VulkanShaderGenerator:
def add_params_yaml(self, parameters_yaml_file): # type: ignore[no-untyped-def]
all_template_params = OrderedDict()
with open(parameters_yaml_file, "r") as f:
with open(parameters_yaml_file) as f:
contents = yaml.load(f, Loader=UniqueKeyLoader)
for key in contents:
all_template_params[key] = contents[key]
@ -204,7 +204,7 @@ def determineDescriptorType(lineStr: str) -> str:
def getShaderInfo(srcFilePath: str) -> ShaderInfo:
shader_info = ShaderInfo([], [], "")
with open(srcFilePath, 'r') as srcFile:
with open(srcFilePath) as srcFile:
for line in srcFile:
if isDescriptorLine(line):
shader_info.layouts.append(determineDescriptorType(line))
@ -271,13 +271,13 @@ def genCppH(
if len(f) > 1:
templateSrcPaths.append(f)
templateSrcPaths.sort()
print("templateSrcPaths:{}".format(templateSrcPaths))
print(f"templateSrcPaths:{templateSrcPaths}")
spvPaths = {}
for templateSrcPath in templateSrcPaths:
print("templateSrcPath {}".format(templateSrcPath))
print(f"templateSrcPath {templateSrcPath}")
name = getName(templateSrcPath).replace("_glsl", "")
print("name {}".format(name))
print(f"name {name}")
codeTemplate = CodeTemplate.from_file(templateSrcPath)
srcPath = tmpDirPath + "/" + name + ".glsl"
@ -286,7 +286,7 @@ def genCppH(
fw.write(content)
spvPath = tmpDirPath + "/" + name + ".spv"
print("spvPath {}".format(spvPath))
print(f"spvPath {spvPath}")
cmd = [
glslcPath, "-fshader-stage=compute",
@ -327,7 +327,7 @@ def genCppH(
h += nsend
cpp = "#include <ATen/native/vulkan/api/Shader.h>\n"
cpp += "#include <ATen/native/vulkan/{}>\n".format(H_NAME)
cpp += f"#include <ATen/native/vulkan/{H_NAME}>\n"
cpp += "#include <stdint.h>\n"
cpp += "#include <vector>\n"
cpp += nsbegin
@ -339,7 +339,7 @@ def genCppH(
for spvPath, srcPath in spvPaths.items():
name = getName(spvPath).replace("_spv", "")
print("spvPath:{}".format(spvPath))
print(f"spvPath:{spvPath}")
with open(spvPath, 'rb') as fr:
next_bin = array.array('I', fr.read())
sizeBytes = 4 * len(next_bin)
@ -361,8 +361,8 @@ def genCppH(
shader_info_layouts = "{{{}}}".format(",\n ".join(shader_info.layouts))
shader_info_args = [
"\"vulkan.{}\"".format(name),
"{}_bin".format(name),
f"\"vulkan.{name}\"",
f"{name}_bin",
str(sizeBytes),
shader_info_layouts,
tile_size,

View File

@ -41,7 +41,7 @@ def get_tag(pytorch_root: Union[str, Path]) -> str:
def get_torch_version(sha: Optional[str] = None) -> str:
pytorch_root = Path(__file__).parent.parent
version = open(pytorch_root / "version.txt", "r").read().strip()
version = open(pytorch_root / "version.txt").read().strip()
if os.getenv("PYTORCH_BUILD_VERSION"):
assert os.getenv("PYTORCH_BUILD_NUMBER") is not None
@ -86,11 +86,11 @@ if __name__ == "__main__":
version = tagged_version
with open(version_path, "w") as f:
f.write("__version__ = '{}'\n".format(version))
f.write(f"__version__ = '{version}'\n")
# NB: This is not 100% accurate, because you could have built the
# library code with DEBUG, but csrc without DEBUG (in which case
# this would claim to be a release build when it's not.)
f.write("debug = {}\n".format(repr(bool(args.is_debug))))
f.write("cuda = {}\n".format(repr(args.cuda_version)))
f.write("git_version = {}\n".format(repr(sha)))
f.write("hip = {}\n".format(repr(args.hip_version)))
f.write(f"debug = {repr(bool(args.is_debug))}\n")
f.write(f"cuda = {repr(args.cuda_version)}\n")
f.write(f"git_version = {repr(sha)}\n")
f.write(f"hip = {repr(args.hip_version)}\n")

View File

@ -250,7 +250,7 @@ def main(args: List[str]) -> None:
if options.op_registration_allowlist:
op_registration_allowlist = options.op_registration_allowlist
elif options.TEST_ONLY_op_registration_allowlist_yaml_path:
with open(options.TEST_ONLY_op_registration_allowlist_yaml_path, "r") as f:
with open(options.TEST_ONLY_op_registration_allowlist_yaml_path) as f:
op_registration_allowlist = yaml.safe_load(f)
else:
op_registration_allowlist = None

View File

@ -35,7 +35,7 @@ class LintMessage(NamedTuple):
def check_file(filename: str) -> Optional[LintMessage]:
logging.debug("Checking file %s", filename)
with open(filename, "r") as f:
with open(filename) as f:
lines = f.readlines()
for idx, line in enumerate(lines):

View File

@ -108,7 +108,7 @@ def lint_file(
original = None
replacement = None
if replace_pattern:
with open(filename, "r") as f:
with open(filename) as f:
original = f.read()
try:

View File

@ -105,7 +105,7 @@ class Formatter(logging.Formatter):
self.redactions[needle] = replace
@functools.lru_cache()
@functools.lru_cache
def logging_base_dir() -> str:
meta_dir = os.getcwd()
base_dir = os.path.join(meta_dir, "nightly", "log")
@ -113,17 +113,17 @@ def logging_base_dir() -> str:
return base_dir
@functools.lru_cache()
@functools.lru_cache
def logging_run_dir() -> str:
cur_dir = os.path.join(
logging_base_dir(),
"{}_{}".format(datetime.datetime.now().strftime(DATETIME_FORMAT), uuid.uuid1()),
f"{datetime.datetime.now().strftime(DATETIME_FORMAT)}_{uuid.uuid1()}",
)
os.makedirs(cur_dir, exist_ok=True)
return cur_dir
@functools.lru_cache()
@functools.lru_cache
def logging_record_argv() -> None:
s = subprocess.list2cmdline(sys.argv)
with open(os.path.join(logging_run_dir(), "argv"), "w") as f:

View File

@ -205,7 +205,7 @@ def gen_diagnostics(
out_cpp_dir: str,
out_docs_dir: str,
) -> None:
with open(rules_path, "r") as f:
with open(rules_path) as f:
rules = yaml.load(f, Loader=YamlLoader)
template_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "templates")

View File

@ -23,7 +23,7 @@ from typing import Any
def read_sub_write(path: str, prefix_pat: str, new_default: int) -> None:
with open(path, encoding="utf-8") as f:
content_str = f.read()
content_str = re.sub(prefix_pat, r"\g<1>{}".format(new_default), content_str)
content_str = re.sub(prefix_pat, rf"\g<1>{new_default}", content_str)
with open(path, "w", encoding="utf-8") as f:
f.write(content_str)
print("modified", path)

View File

@ -191,15 +191,15 @@ def sig_for_ops(opname: str) -> List[str]:
name = opname[2:-2]
if name in binary_ops:
return ["def {}(self, other: Any) -> Tensor: ...".format(opname)]
return [f"def {opname}(self, other: Any) -> Tensor: ..."]
elif name in comparison_ops:
sig = "def {}(self, other: Any) -> Tensor: ...".format(opname)
sig = f"def {opname}(self, other: Any) -> Tensor: ..."
if name in symmetric_comparison_ops:
# unsafe override https://github.com/python/mypy/issues/5704
sig += " # type: ignore[override]"
return [sig]
elif name in unary_ops:
return ["def {}(self) -> Tensor: ...".format(opname)]
return [f"def {opname}(self) -> Tensor: ..."]
elif name in to_py_type_ops:
if name in {"bool", "float", "complex"}:
tname = name
@ -209,7 +209,7 @@ def sig_for_ops(opname: str) -> List[str]:
tname = "int"
if tname in {"float", "int", "bool", "complex"}:
tname = "builtins." + tname
return ["def {}(self) -> {}: ...".format(opname, tname)]
return [f"def {opname}(self) -> {tname}: ..."]
else:
raise Exception("unknown op", opname)
@ -1120,9 +1120,7 @@ def gen_pyi(
"bfloat16",
]
for name in simple_conversions:
unsorted_tensor_method_hints[name].append(
"def {}(self) -> Tensor: ...".format(name)
)
unsorted_tensor_method_hints[name].append(f"def {name}(self) -> Tensor: ...")
# pyi tensor methods don't currently include deprecated signatures for some reason
# TODO: we should probably add them in
@ -1151,7 +1149,7 @@ def gen_pyi(
namedtuples[tuple_name] = tuple_def
for op in all_ops:
name = "__{}__".format(op)
name = f"__{op}__"
unsorted_tensor_method_hints[name] += sig_for_ops(name)
tensor_method_hints = []
@ -1165,7 +1163,7 @@ def gen_pyi(
# Generate namedtuple definitions
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
namedtuple_defs = ["{}\n".format(defn) for defn in namedtuples.values()]
namedtuple_defs = [f"{defn}\n" for defn in namedtuples.values()]
# Generate type signatures for legacy classes
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@ -1184,7 +1182,7 @@ def gen_pyi(
"ByteTensor",
"BoolTensor",
):
legacy_class_hints.append("class {}(Tensor): ...".format(c))
legacy_class_hints.append(f"class {c}(Tensor): ...")
# Generate type signatures for dtype classes
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@ -1192,7 +1190,7 @@ def gen_pyi(
# TODO: don't explicitly list dtypes here; get it from canonical
# source
dtype_class_hints = [
"{}: dtype = ...".format(n)
f"{n}: dtype = ..."
for n in [
"float32",
"float",
@ -1233,7 +1231,7 @@ def gen_pyi(
]
all_symbols = sorted(list(namedtuples.keys()) + hinted_function_names)
all_directive = pformat(all_symbols, width=100, compact=True).split("\n")
all_directive[0] = "__all__ = {}".format(all_directive[0])
all_directive[0] = f"__all__ = {all_directive[0]}"
# Dispatch key hints
# ~~~~~~~~~~~~~~~~~~

View File

@ -61,10 +61,8 @@ class CMake:
_cmake_min_version = LooseVersion("3.13.0")
if all(
(
ver is None or ver < _cmake_min_version
for ver in [cmake_version, cmake3_version]
)
ver is None or ver < _cmake_min_version
for ver in [cmake_version, cmake3_version]
):
raise RuntimeError("no cmake or cmake3 with version >= 3.13.0 found")
@ -108,7 +106,7 @@ class CMake:
"Adds definitions to a cmake argument list."
for key, value in sorted(kwargs.items()):
if value is not None:
args.append("-D{}={}".format(key, value))
args.append(f"-D{key}={value}")
def get_cmake_cache_variables(self) -> Dict[str, CMakeValue]:
r"""Gets values in CMakeCache.txt into a dictionary.
@ -172,9 +170,7 @@ class CMake:
args.append("-Ax64")
toolset_dict["host"] = "x64"
if toolset_dict:
toolset_expr = ",".join(
["{}={}".format(k, v) for k, v in toolset_dict.items()]
)
toolset_expr = ",".join([f"{k}={v}" for k, v in toolset_dict.items()])
args.append("-T" + toolset_expr)
base_dir = os.path.dirname(
@ -322,11 +318,9 @@ class CMake:
expected_wrapper = "/usr/local/opt/ccache/libexec"
if IS_DARWIN and os.path.exists(expected_wrapper):
if "CMAKE_C_COMPILER" not in build_options and "CC" not in os.environ:
CMake.defines(args, CMAKE_C_COMPILER="{}/gcc".format(expected_wrapper))
CMake.defines(args, CMAKE_C_COMPILER=f"{expected_wrapper}/gcc")
if "CMAKE_CXX_COMPILER" not in build_options and "CXX" not in os.environ:
CMake.defines(
args, CMAKE_CXX_COMPILER="{}/g++".format(expected_wrapper)
)
CMake.defines(args, CMAKE_CXX_COMPILER=f"{expected_wrapper}/g++")
for env_var_name in my_env:
if env_var_name.startswith("gh"):
@ -335,11 +329,9 @@ class CMake:
try:
my_env[env_var_name] = str(my_env[env_var_name].encode("utf-8"))
except UnicodeDecodeError as e:
shex = ":".join(
"{:02x}".format(ord(c)) for c in my_env[env_var_name]
)
shex = ":".join(f"{ord(c):02x}" for c in my_env[env_var_name])
print(
"Invalid ENV[{}] = {}".format(env_var_name, shex),
f"Invalid ENV[{env_var_name}] = {shex}",
file=sys.stderr,
)
print(e, file=sys.stderr)
@ -396,7 +388,7 @@ class CMake:
build_args += ["--"]
if IS_WINDOWS and not USE_NINJA:
# We are likely using msbuild here
build_args += ["/p:CL_MPCount={}".format(max_jobs)]
build_args += [f"/p:CL_MPCount={max_jobs}"]
else:
build_args += ["-j", max_jobs]
self.run(build_args, my_env)

View File

@ -71,9 +71,7 @@ def get_cmake_cache_variables_from_file(
r'("?)(.+?)\1(?::\s*([a-zA-Z_-][a-zA-Z0-9_-]*)?)?\s*=\s*(.*)', line
)
if matched is None: # Illegal line
raise ValueError(
"Unexpected line {} in {}: {}".format(i, repr(cmake_cache_file), line)
)
raise ValueError(f"Unexpected line {i} in {repr(cmake_cache_file)}: {line}")
_, variable, type_, value = matched.groups()
if type_ is None:
type_ = ""

View File

@ -75,7 +75,7 @@ def generate_code(
def get_selector_from_legacy_operator_selection_list(
selected_op_list_path: str,
) -> Any:
with open(selected_op_list_path, "r") as f:
with open(selected_op_list_path) as f:
# strip out the overload part
# It's only for legacy config - do NOT copy this code!
selected_op_list = {

View File

@ -46,7 +46,7 @@ def fetch_and_cache(
if os.path.exists(path) and is_cached_file_valid():
# Another test process already download the file, so don't re-do it
with open(path, "r") as f:
with open(path) as f:
return cast(Dict[str, Any], json.load(f))
for _ in range(3):

View File

@ -249,10 +249,8 @@ class EnvVarMetric:
value = os.environ.get(self.env_var)
if value is None and self.required:
raise ValueError(
(
f"Missing {self.name}. Please set the {self.env_var}"
"environment variable to pass in this value."
)
f"Missing {self.name}. Please set the {self.env_var}"
"environment variable to pass in this value."
)
if self.type_conversion_fn:
return self.type_conversion_fn(value)

View File

@ -11,7 +11,7 @@ if __name__ == "__main__":
parser.add_argument("--replace", action="append", nargs=2)
options = parser.parse_args()
with open(options.input_file, "r") as f:
with open(options.input_file) as f:
contents = f.read()
output_file = os.path.join(options.install_dir, options.output_file)

View File

@ -181,7 +181,7 @@ class TestParseNativeYaml(unittest.TestCase):
use_aten_lib=False,
out_file=out_file,
)
with open(out_yaml_path, "r") as out_file:
with open(out_yaml_path) as out_file:
es = yaml.load(out_file, Loader=LineLoader)
self.assertTrue(all("func" in e for e in es))
self.assertTrue(all(e.get("variants") == "function" for e in es))
@ -268,7 +268,7 @@ class TestParseKernelYamlFiles(unittest.TestCase):
use_aten_lib=False,
out_file=out_file,
)
with open(out_yaml_path, "r") as out_file:
with open(out_yaml_path) as out_file:
es = yaml.load(out_file, Loader=LineLoader)
self.assertTrue(all("func" in e for e in es))
self.assertTrue(all(e.get("variants") == "function" for e in es))

View File

@ -21,7 +21,7 @@ class ExecutorchCppSignatureTest(unittest.TestCase):
use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
):
args = self.sig.arguments(include_context=True)
self.assertEquals(len(args), 3)
self.assertEqual(len(args), 3)
self.assertTrue(any(a.name == "context" for a in args))
def test_runtime_signature_does_not_contain_runtime_context(self) -> None:
@ -30,7 +30,7 @@ class ExecutorchCppSignatureTest(unittest.TestCase):
use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
):
args = self.sig.arguments(include_context=False)
self.assertEquals(len(args), 2)
self.assertEqual(len(args), 2)
self.assertFalse(any(a.name == "context" for a in args))
def test_runtime_signature_declaration_correct(self) -> None:
@ -38,7 +38,7 @@ class ExecutorchCppSignatureTest(unittest.TestCase):
use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
):
decl = self.sig.decl(include_context=True)
self.assertEquals(
self.assertEqual(
decl,
(
"torch::executor::Tensor & foo_outf("
@ -48,7 +48,7 @@ class ExecutorchCppSignatureTest(unittest.TestCase):
),
)
no_context_decl = self.sig.decl(include_context=False)
self.assertEquals(
self.assertEqual(
no_context_decl,
(
"torch::executor::Tensor & foo_outf("

View File

@ -106,9 +106,9 @@ x = $TILE_SIZE_X + $TILE_SIZE_Y
file_name_2 = os.path.join(tmp_dir, "conv2d_pw_1x2.glsl")
self.assertTrue(os.path.exists(file_name_1))
self.assertTrue(os.path.exists(file_name_2))
with open(file_name_1, "r") as f:
with open(file_name_1) as f:
contents = f.read()
self.assertTrue("1 + 1" in contents)
with open(file_name_2, "r") as f:
with open(file_name_2) as f:
contents = f.read()
self.assertTrue("1 + 2" in contents)

View File

@ -127,7 +127,7 @@ if __name__ == "__main__":
args = parser.parse_args()
touched_files = [CONFIG_YML]
with open(CONFIG_YML, "r") as f:
with open(CONFIG_YML) as f:
config_yml = yaml.safe_load(f.read())
config_yml["workflows"] = get_filtered_circleci_config(

View File

@ -163,7 +163,7 @@ def _get_previously_failing_tests() -> Set[str]:
)
return set()
with open(PYTEST_FAILED_TESTS_CACHE_FILE_PATH, "r") as f:
with open(PYTEST_FAILED_TESTS_CACHE_FILE_PATH) as f:
last_failed_tests = json.load(f)
prioritized_tests = _parse_prev_failing_test_files(last_failed_tests)