Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
[BE] Enable ruff's UP rules and autoformat tools and scripts (#105428)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105428 Approved by: https://github.com/albanD, https://github.com/soulitzer, https://github.com/malfet
parent 5666d20bb8
commit 14d87bb5ff
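The commit title sums up the change: ruff is run with its UP ("pyupgrade") rule family in autofix mode over the tools and scripts, replacing Python 2-era idioms with their modern equivalents. As a rough, hypothetical sketch of the rewrites that recur throughout the hunks below (the snippet is illustrative only; names and file paths are made up and not taken from the commit):

# Illustrative sketch only; identifiers and paths here are invented, not from this commit.
from functools import lru_cache

name, version = "torch", "2.1.0"

# --- before: idioms the UP rules flag ------------------------------------
class Conf(object):              # explicit (object) base class
    pass

label_before = u"US"             # redundant u"" prefix; str is already unicode
msg_before = "Building wheel {}-{}".format(name, version)   # str.format()

@lru_cache()                     # decorator invoked with empty parentheses
def cached_before():
    return 42

try:
    with open("version.txt", "r") as f:   # "r" is already the default mode
        v_before = f.read().strip()
except IOError:                  # legacy alias of OSError
    v_before = "unknown"

# --- after: what the autofix produces -------------------------------------
class Conf2:
    pass

label_after = "US"
msg_after = f"Building wheel {name}-{version}"               # f-string

@lru_cache
def cached_after():
    return 42

try:
    with open("version.txt") as f:
        v_after = f.read().strip()
except OSError:
    v_after = "unknown"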
@@ -88,9 +88,9 @@ def sign_certificate_request(path, csr_cert, ca_cert, private_ca_key):
 ca_key = genrsa(temp_dir + "/ca.key")
-ca_cert = create_cert(temp_dir + "/ca.pem", u"US", u"New York", u"New York", u"Gloo Certificate Authority", ca_key)
+ca_cert = create_cert(temp_dir + "/ca.pem", "US", "New York", "New York", "Gloo Certificate Authority", ca_key)
 pkey = genrsa(temp_dir + "/pkey.key")
-csr = create_req(temp_dir + "/csr.csr", u"US", u"California", u"San Francisco", u"Gloo Testing Company", pkey)
+csr = create_req(temp_dir + "/csr.csr", "US", "California", "San Francisco", "Gloo Testing Company", pkey)
 cert = sign_certificate_request(temp_dir + "/cert.pem", csr, ca_cert, ca_key)
@@ -19,7 +19,7 @@ if 'cpu' in test_name:
 elif 'gpu' in test_name:
 backend = 'gpu'
-data_file_path = '../{}_runtime.json'.format(backend)
+data_file_path = f'../{backend}_runtime.json'
 with open(data_file_path) as data_file:
 data = json.load(data_file)
@@ -69,7 +69,7 @@ else:
 print("z-value < 3, no perf regression detected.")
 if args.update:
 print("We will use these numbers as new baseline.")
-new_data_file_path = '../new_{}_runtime.json'.format(backend)
+new_data_file_path = f'../new_{backend}_runtime.json'
 with open(new_data_file_path) as new_data_file:
 new_data = json.load(new_data_file)
 new_data[test_name] = {}
@@ -5,7 +5,7 @@ import cimodel.data.binary_build_data as binary_build_data
 import cimodel.lib.conf_tree as conf_tree
 import cimodel.lib.miniutils as miniutils
-class Conf(object):
+class Conf:
 def __init__(self, os, gpu_version, pydistro, parms, smoke, libtorch_variant, gcc_config_variant, libtorch_config_variant):
 self.os = os
@@ -143,7 +143,7 @@ class Conf:
 # TODO This is a hack to special case some configs just for the workflow list
-class HiddenConf(object):
+class HiddenConf:
 def __init__(self, name, parent_build=None, filters=None):
 self.name = name
 self.parent_build = parent_build
@@ -160,7 +160,7 @@ class HiddenConf(object):
 def gen_build_name(self, _):
 return self.name
-class DocPushConf(object):
+class DocPushConf:
 def __init__(self, name, parent_build=None, branch="master"):
 self.name = name
 self.parent_build = parent_build
@@ -18,7 +18,7 @@ import cimodel.lib.miniutils as miniutils
 import cimodel.lib.miniyaml as miniyaml
-class File(object):
+class File:
 """
 Verbatim copy the contents of a file into config.yml
 """
@@ -57,7 +57,7 @@ def horizontal_rule():
 return "".join("#" * 78)
-class Header(object):
+class Header:
 def __init__(self, title, summary=None):
 self.title = title
 self.summary_lines = summary or []
@@ -17,7 +17,7 @@ EXPECTED_GROUP = (
 def should_check(filename: Path) -> bool:
-with open(filename, "r") as f:
+with open(filename) as f:
 content = f.read()
 data = yaml.safe_load(content)
@@ -37,7 +37,7 @@ if __name__ == "__main__":
 files = [f for f in files if should_check(f)]
 names = set()
 for filename in files:
-with open(filename, "r") as f:
+with open(filename) as f:
 data = yaml.safe_load(f)
 name = data.get("name")
.github/scripts/file_io_utils.py (vendored, 2 changed lines)
@@ -44,7 +44,7 @@ def load_json_file(file_path: Path) -> Any:
 """
 Returns the deserialized json object
 """
-with open(file_path, "r") as f:
+with open(file_path) as f:
 return json.load(f)
.github/scripts/filter_test_configs.py (vendored, 2 changed lines)
@@ -319,7 +319,7 @@ def process_jobs(
 try:
 # The job name from github is in the PLATFORM / JOB (CONFIG) format, so breaking
 # it into its two components first
-current_platform, _ = [n.strip() for n in job_name.split(JOB_NAME_SEP, 1) if n]
+current_platform, _ = (n.strip() for n in job_name.split(JOB_NAME_SEP, 1) if n)
 except ValueError as error:
 warnings.warn(f"Invalid job name {job_name}, returning")
 return test_matrix
.github/scripts/generate_pytorch_version.py (vendored, 2 changed lines)
@@ -50,7 +50,7 @@ def get_tag() -> str:
 def get_base_version() -> str:
 root = get_pytorch_root()
-dirty_version = open(root / "version.txt", "r").read().strip()
+dirty_version = open(root / "version.txt").read().strip()
 # Strips trailing a0 from version.txt, not too sure why it's there in the
 # first place
 return re.sub(LEGACY_BASE_VERSION_SUFFIX_PATTERN, "", dirty_version)
.github/scripts/label_utils.py (vendored, 2 changed lines)
@@ -51,7 +51,7 @@ def get_last_page_num_from_header(header: Any) -> int:
 )
-@lru_cache()
+@lru_cache
 def gh_get_labels(org: str, repo: str) -> List[str]:
 prefix = f"https://api.github.com/repos/{org}/{repo}/labels?per_page=100"
 header, info = request_for_labels(prefix + "&page=1")
.github/scripts/lint_native_functions.py (vendored, 2 changed lines)
@@ -26,7 +26,7 @@ def fn(base: str) -> str:
 return str(base / Path("aten/src/ATen/native/native_functions.yaml"))
-with open(Path(__file__).parent.parent.parent / fn("."), "r") as f:
+with open(Path(__file__).parent.parent.parent / fn(".")) as f:
 contents = f.read()
 yaml = ruamel.yaml.YAML() # type: ignore[attr-defined]
.github/scripts/run_torchbench.py (vendored, 4 changed lines)
@@ -129,7 +129,7 @@ def extract_models_from_pr(
 model_list = []
 userbenchmark_list = []
 pr_list = []
-with open(prbody_file, "r") as pf:
+with open(prbody_file) as pf:
 lines = (x.strip() for x in pf.read().splitlines())
 magic_lines = list(filter(lambda x: x.startswith(MAGIC_PREFIX), lines))
 if magic_lines:
@@ -157,7 +157,7 @@ def extract_models_from_pr(
 def find_torchbench_branch(prbody_file: str) -> str:
 branch_name: str = ""
-with open(prbody_file, "r") as pf:
+with open(prbody_file) as pf:
 lines = (x.strip() for x in pf.read().splitlines())
 magic_lines = list(
 filter(lambda x: x.startswith(MAGIC_TORCHBENCH_PREFIX), lines)
.github/scripts/test_check_labels.py (vendored, 2 changed lines)
@@ -15,7 +15,7 @@ from trymerge import GitHubPR
 def mock_parse_args() -> object:
-class Object(object):
+class Object:
 def __init__(self) -> None:
 self.pr_num = 76123
.github/scripts/test_trymerge.py (vendored, 2 changed lines)
@@ -114,7 +114,7 @@ def mocked_rockset_results(head_sha: str, merge_base: str, num_retries: int = 3)
 def mock_parse_args(revert: bool = False, force: bool = False) -> Any:
-class Object(object):
+class Object:
 def __init__(self) -> None:
 self.revert = revert
 self.force = force
.github/scripts/trymerge.py (vendored, 6 changed lines)
@@ -1628,10 +1628,8 @@ def validate_revert(
 allowed_reverters.append("CONTRIBUTOR")
 if author_association not in allowed_reverters:
 raise PostCommentError(
-(
-f"Will not revert as @{author_login} is not one of "
-f"[{', '.join(allowed_reverters)}], but instead is {author_association}."
-)
+f"Will not revert as @{author_login} is not one of "
+f"[{', '.join(allowed_reverters)}], but instead is {author_association}."
 )
 skip_internal_checks = can_skip_internal_checks(pr, comment_id)
.github/scripts/trymerge_explainer.py (vendored, 2 changed lines)
@@ -17,7 +17,7 @@ def has_label(labels: List[str], pattern: Pattern[str] = CIFLOW_LABEL) -> bool:
 return len(list(filter(pattern.match, labels))) > 0
-class TryMergeExplainer(object):
+class TryMergeExplainer:
 force: bool
 labels: List[str]
 pr_num: int
@@ -8,7 +8,7 @@ import shutil
 # Module caffe2...caffe2.python.control_test
 def insert(originalfile, first_line, description):
-with open(originalfile, 'r') as f:
+with open(originalfile) as f:
 f1 = f.readline()
 if(f1.find(first_line) < 0):
 docs = first_line + description + f1
@@ -30,7 +30,7 @@ for root, dirs, files in os.walk("."):
 for file in files:
 if (file.endswith(".py") and not file.endswith("_test.py") and not file.endswith("__.py")):
 filepath = os.path.join(root, file)
-print(("filepath: " + filepath))
+print("filepath: " + filepath)
 directory = os.path.dirname(filepath)[2:]
 directory = directory.replace("/", ".")
 print("directory: " + directory)
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 #
 # PyTorch documentation build configuration file, created by
 # sphinx-quickstart on Fri Dec 23 13:31:47 2016.
@@ -99,7 +98,7 @@ exhale_args = {
 ############################################################################
 # Main library page layout example configuration. #
 ############################################################################
-"afterTitleDescription": textwrap.dedent(u'''
+"afterTitleDescription": textwrap.dedent('''
 Welcome to the developer reference for the PyTorch C++ API.
 '''),
 }
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 #
 # PyTorch documentation build configuration file, created by
 # sphinx-quickstart on Fri Dec 23 13:31:47 2016.
@@ -624,7 +623,7 @@ def replace(Klass):
 anchor = ref_anchor[1]
 txt = node.parent.astext()
 if txt == anchor or txt == anchor.split('.')[-1]:
-self.body.append('<p id="{}"/>'.format(ref_anchor[1]))
+self.body.append(f'<p id="{ref_anchor[1]}"/>')
 return old_call(self, node)
 Klass.visit_reference = visit_reference
@@ -31,7 +31,7 @@ def generate_example_rst(example_case: ExportCase):
 if isinstance(model, torch.nn.Module)
 else inspect.getfile(model)
 )
-with open(source_file, "r") as file:
+with open(source_file) as file:
 source_code = file.read()
 source_code = re.sub(r"from torch\._export\.db\.case import .*\n", "", source_code)
 source_code = re.sub(r"@export_case\((.|\n)*?\)\n", "", source_code)
@@ -114,7 +114,7 @@ def generate_index_rst(example_cases, tag_to_modules, support_level_to_modules):
 tag_names = "\n ".join(t for t in tag_to_modules.keys())
-with open(os.path.join(PWD, "blurb.txt"), "r") as file:
+with open(os.path.join(PWD, "blurb.txt")) as file:
 blurb = file.read()
 # Generate contents of the .rst file
setup.py (38 changed lines)
@@ -323,7 +323,7 @@ cmake_python_include_dir = sysconfig.get_path("include")
 package_name = os.getenv('TORCH_PACKAGE_NAME', 'torch')
 package_type = os.getenv('PACKAGE_TYPE', 'wheel')
 version = get_torch_version()
-report("Building wheel {}-{}".format(package_name, version))
+report(f"Building wheel {package_name}-{version}")
 cmake = CMake()
@@ -361,7 +361,7 @@ def check_submodules():
 start = time.time()
 subprocess.check_call(["git", "submodule", "update", "--init", "--recursive"], cwd=cwd)
 end = time.time()
-print(' --- Submodule initialization took {:.2f} sec'.format(end - start))
+print(f' --- Submodule initialization took {end - start:.2f} sec')
 except Exception:
 print(' --- Submodule initalization failed')
 print('Please run:\n\tgit submodule update --init --recursive')
@@ -616,16 +616,16 @@ class build_ext(setuptools.command.build_ext.build_ext):
 continue
 fullname = self.get_ext_fullname(ext.name)
 filename = self.get_ext_filename(fullname)
-report("\nCopying extension {}".format(ext.name))
+report(f"\nCopying extension {ext.name}")
 relative_site_packages = sysconfig.get_path('purelib').replace(sysconfig.get_path('data'), '').lstrip(os.path.sep)
 src = os.path.join("torch", relative_site_packages, filename)
 if not os.path.exists(src):
-report("{} does not exist".format(src))
+report(f"{src} does not exist")
 del self.extensions[i]
 else:
 dst = os.path.join(os.path.realpath(self.build_lib), filename)
-report("Copying {} from {} to {}".format(ext.name, src, dst))
+report(f"Copying {ext.name} from {src} to {dst}")
 dst_dir = os.path.dirname(dst)
 if not os.path.exists(dst_dir):
 os.makedirs(dst_dir)
@@ -642,7 +642,7 @@ class build_ext(setuptools.command.build_ext.build_ext):
 src = os.path.join(os.path.dirname(filename), "functorch" + fileext)
 dst = os.path.join(os.path.realpath(self.build_lib), filename)
 if os.path.exists(src):
-report("Copying {} from {} to {}".format(ext.name, src, dst))
+report(f"Copying {ext.name} from {src} to {dst}")
 dst_dir = os.path.dirname(dst)
 if not os.path.exists(dst_dir):
 os.makedirs(dst_dir)
@@ -658,7 +658,7 @@ class build_ext(setuptools.command.build_ext.build_ext):
 src = os.path.join(os.path.dirname(filename), "nvfuser" + fileext)
 dst = os.path.join(os.path.realpath(self.build_lib), filename)
 if os.path.exists(src):
-report("Copying {} from {} to {}".format(ext.name, src, dst))
+report(f"Copying {ext.name} from {src} to {dst}")
 dst_dir = os.path.dirname(dst)
 if not os.path.exists(dst_dir):
 os.makedirs(dst_dir)
@@ -670,7 +670,7 @@ class build_ext(setuptools.command.build_ext.build_ext):
 def get_outputs(self):
 outputs = setuptools.command.build_ext.build_ext.get_outputs(self)
 outputs.append(os.path.join(self.build_lib, "caffe2"))
-report("setup.py::get_outputs returning {}".format(outputs))
+report(f"setup.py::get_outputs returning {outputs}")
 return outputs
 def create_compile_commands(self):
@@ -694,13 +694,13 @@ class build_ext(setuptools.command.build_ext.build_ext):
 new_contents = json.dumps(all_commands, indent=2)
 contents = ''
 if os.path.exists('compile_commands.json'):
-with open('compile_commands.json', 'r') as f:
+with open('compile_commands.json') as f:
 contents = f.read()
 if contents != new_contents:
 with open('compile_commands.json', 'w') as f:
 f.write(new_contents)
-class concat_license_files():
+class concat_license_files:
 """Merge LICENSE and LICENSES_BUNDLED.txt as a context manager
 LICENSE is the main PyTorch license, LICENSES_BUNDLED.txt is auto-generated
@@ -723,7 +723,7 @@ class concat_license_files():
 finally:
 sys.path = old_path
-with open(self.f1, 'r') as f1:
+with open(self.f1) as f1:
 self.bsd_text = f1.read()
 with open(self.f1, 'a') as f1:
@@ -771,7 +771,7 @@ class clean(setuptools.Command):
 def run(self):
 import glob
 import re
-with open('.gitignore', 'r') as f:
+with open('.gitignore') as f:
 ignores = f.read()
 pat = re.compile(r'^#( BEGIN NOT-CLEAN-FILES )?')
 for wildcard in filter(None, ignores.split('\n')):
@@ -934,31 +934,31 @@ def configure_extension_build():
 if cmake_cache_vars['BUILD_CAFFE2']:
 extensions.append(
 Extension(
-name=str('caffe2.python.caffe2_pybind11_state'),
+name='caffe2.python.caffe2_pybind11_state',
 sources=[]),
 )
 if cmake_cache_vars['USE_CUDA']:
 extensions.append(
 Extension(
-name=str('caffe2.python.caffe2_pybind11_state_gpu'),
+name='caffe2.python.caffe2_pybind11_state_gpu',
 sources=[]),
 )
 if cmake_cache_vars['USE_ROCM']:
 extensions.append(
 Extension(
-name=str('caffe2.python.caffe2_pybind11_state_hip'),
+name='caffe2.python.caffe2_pybind11_state_hip',
 sources=[]),
 )
 if cmake_cache_vars['BUILD_FUNCTORCH']:
 extensions.append(
 Extension(
-name=str('functorch._C'),
+name='functorch._C',
 sources=[]),
 )
 if cmake_cache_vars['BUILD_NVFUSER']:
 extensions.append(
 Extension(
-name=str('nvfuser._C'),
+name='nvfuser._C',
 sources=[]),
 )
@@ -1271,7 +1271,7 @@ def main():
 download_url='https://github.com/pytorch/pytorch/tags',
 author='PyTorch Team',
 author_email='packages@pytorch.org',
-python_requires='>={}'.format(python_min_version_str),
+python_requires=f'>={python_min_version_str}',
 # PyPI package information.
 classifiers=[
 'Development Status :: 5 - Production/Stable',
@@ -1287,7 +1287,7 @@ def main():
 'Topic :: Software Development :: Libraries :: Python Modules',
 'Programming Language :: C++',
 'Programming Language :: Python :: 3',
-] + ['Programming Language :: Python :: 3.{}'.format(i) for i in range(python_min_version[1], version_range_max)],
+] + [f'Programming Language :: Python :: 3.{i}' for i in range(python_min_version[1], version_range_max)],
 license='BSD-3',
 keywords='pytorch, machine learning',
 )
@@ -140,7 +140,7 @@ def is_hip_clang() -> bool:
 hip_path = os.getenv("HIP_PATH", "/opt/rocm/hip")
 with open(hip_path + "/lib/.hipInfo") as f:
 return "HIP_COMPILER=clang" in f.read()
-except IOError:
+except OSError:
 return False
@@ -149,7 +149,7 @@ if is_hip_clang():
 gloo_cmake_file = "third_party/gloo/cmake/Hip.cmake"
 do_write = False
 if os.path.exists(gloo_cmake_file):
-with open(gloo_cmake_file, "r") as sources:
+with open(gloo_cmake_file) as sources:
 lines = sources.readlines()
 newlines = [line.replace(" hip_hcc ", " amdhip64 ") for line in lines]
 if lines == newlines:
@@ -163,7 +163,7 @@ if is_hip_clang():
 gloo_cmake_file = "third_party/gloo/cmake/Modules/Findrccl.cmake"
 if os.path.exists(gloo_cmake_file):
 do_write = False
-with open(gloo_cmake_file, "r") as sources:
+with open(gloo_cmake_file) as sources:
 lines = sources.readlines()
 newlines = [line.replace("RCCL_LIBRARY", "RCCL_LIB_PATH") for line in lines]
 if lines == newlines:
@@ -179,7 +179,7 @@ if is_hip_clang():
 gloo_cmake_file = "third_party/gloo/cmake/Dependencies.cmake"
 do_write = False
 if os.path.exists(gloo_cmake_file):
-with open(gloo_cmake_file, "r") as sources:
+with open(gloo_cmake_file) as sources:
 lines = sources.readlines()
 newlines = [line.replace("HIP_HCC_FLAGS", "HIP_CLANG_FLAGS") for line in lines]
 if lines == newlines:
@@ -553,7 +553,7 @@ def load_deprecated_signatures(
 # find matching original signatures for each deprecated signature
 results: List[PythonSignatureNativeFunctionPair] = []
-with open(deprecated_yaml_path, "r") as f:
+with open(deprecated_yaml_path) as f:
 deprecated_defs = yaml.load(f, Loader=YamlLoader)
 for deprecated in deprecated_defs:
@@ -873,7 +873,7 @@ def method_impl(
 name=name,
 pycname=pycname,
 method_header=method_header,
-max_args=max((o.signature.arguments_count() for o in overloads)),
+max_args=max(o.signature.arguments_count() for o in overloads),
 signatures=signatures,
 traceable=traceable,
 check_has_torch_function=gen_has_torch_function_check(
@@ -1255,10 +1255,7 @@ def emit_single_dispatch(
 # dispatch lambda signature
 name = cpp.name(f.func)
 lambda_formals = ", ".join(
-(
-f"{a.type_str} {a.name}"
-for a in dispatch_lambda_args(ps, f, symint=symint)
-)
+f"{a.type_str} {a.name}" for a in dispatch_lambda_args(ps, f, symint=symint)
 )
 lambda_return = dispatch_lambda_return_str(f)
@@ -98,7 +98,7 @@ def load_derivatives(
 global _GLOBAL_LOAD_DERIVATIVE_CACHE
 key = (derivatives_yaml_path, native_yaml_path)
 if key not in _GLOBAL_LOAD_DERIVATIVE_CACHE:
-with open(derivatives_yaml_path, "r") as f:
+with open(derivatives_yaml_path) as f:
 definitions = yaml.load(f, Loader=YamlLoader)
 funcs = parse_native_yaml(native_yaml_path, tags_yaml_path).native_functions
@@ -24,7 +24,7 @@ def canonical_name(opname: str) -> str:
 def load_op_dep_graph(fname: str) -> DepGraph:
-with open(fname, "r") as stream:
+with open(fname) as stream:
 result = defaultdict(set)
 for op in yaml.safe_load(stream):
 op_name = canonical_name(op["name"])
@@ -36,7 +36,7 @@ def load_op_dep_graph(fname: str) -> DepGraph:
 def load_root_ops(fname: str) -> List[str]:
 result = []
-with open(fname, "r") as stream:
+with open(fname) as stream:
 for op in yaml.safe_load(stream):
 result.append(canonical_name(op))
 return result
@@ -79,7 +79,7 @@ SupportedMobileModelCheckerRegistry register_model_versions;
 supported_hashes = ""
 for md5 in md5_hashes:
-supported_hashes += '"{}",\n'.format(md5)
+supported_hashes += f'"{md5}",\n'
 with open(
 os.path.join(output_dir, "SupportedMobileModelsRegistration.cpp"), "wb"
 ) as out_file:
@@ -1,6 +1,6 @@
 import setuptools # type: ignore[import]
-with open("README.md", "r", encoding="utf-8") as fh:
+with open("README.md", encoding="utf-8") as fh:
 long_description = fh.read()
 setuptools.setup(
@@ -32,16 +32,16 @@ def report_download_progress(
 def download(destination_path: str, resource: str, quiet: bool) -> None:
 if os.path.exists(destination_path):
 if not quiet:
-print("{} already exists, skipping ...".format(destination_path))
+print(f"{destination_path} already exists, skipping ...")
 else:
 for mirror in MIRRORS:
 url = mirror + resource
-print("Downloading {} ...".format(url))
+print(f"Downloading {url} ...")
 try:
 hook = None if quiet else report_download_progress
 urlretrieve(url, destination_path, reporthook=hook)
 except (URLError, ConnectionError) as e:
-print("Failed to download (trying next):\n{}".format(e))
+print(f"Failed to download (trying next):\n{e}")
 continue
 finally:
 if not quiet:
@@ -56,13 +56,13 @@ def unzip(zipped_path: str, quiet: bool) -> None:
 unzipped_path = os.path.splitext(zipped_path)[0]
 if os.path.exists(unzipped_path):
 if not quiet:
-print("{} already exists, skipping ... ".format(unzipped_path))
+print(f"{unzipped_path} already exists, skipping ... ")
 return
 with gzip.open(zipped_path, "rb") as zipped_file:
 with open(unzipped_path, "wb") as unzipped_file:
 unzipped_file.write(zipped_file.read())
 if not quiet:
-print("Unzipped {} ...".format(zipped_path))
+print(f"Unzipped {zipped_path} ...")
 def main() -> None:
@@ -74,7 +74,7 @@ class VulkanShaderGenerator:
 def add_params_yaml(self, parameters_yaml_file): # type: ignore[no-untyped-def]
 all_template_params = OrderedDict()
-with open(parameters_yaml_file, "r") as f:
+with open(parameters_yaml_file) as f:
 contents = yaml.load(f, Loader=UniqueKeyLoader)
 for key in contents:
 all_template_params[key] = contents[key]
@@ -204,7 +204,7 @@ def determineDescriptorType(lineStr: str) -> str:
 def getShaderInfo(srcFilePath: str) -> ShaderInfo:
 shader_info = ShaderInfo([], [], "")
-with open(srcFilePath, 'r') as srcFile:
+with open(srcFilePath) as srcFile:
 for line in srcFile:
 if isDescriptorLine(line):
 shader_info.layouts.append(determineDescriptorType(line))
@@ -271,13 +271,13 @@ def genCppH(
 if len(f) > 1:
 templateSrcPaths.append(f)
 templateSrcPaths.sort()
-print("templateSrcPaths:{}".format(templateSrcPaths))
+print(f"templateSrcPaths:{templateSrcPaths}")
 spvPaths = {}
 for templateSrcPath in templateSrcPaths:
-print("templateSrcPath {}".format(templateSrcPath))
+print(f"templateSrcPath {templateSrcPath}")
 name = getName(templateSrcPath).replace("_glsl", "")
-print("name {}".format(name))
+print(f"name {name}")
 codeTemplate = CodeTemplate.from_file(templateSrcPath)
 srcPath = tmpDirPath + "/" + name + ".glsl"
@@ -286,7 +286,7 @@ def genCppH(
 fw.write(content)
 spvPath = tmpDirPath + "/" + name + ".spv"
-print("spvPath {}".format(spvPath))
+print(f"spvPath {spvPath}")
 cmd = [
 glslcPath, "-fshader-stage=compute",
@@ -327,7 +327,7 @@ def genCppH(
 h += nsend
 cpp = "#include <ATen/native/vulkan/api/Shader.h>\n"
-cpp += "#include <ATen/native/vulkan/{}>\n".format(H_NAME)
+cpp += f"#include <ATen/native/vulkan/{H_NAME}>\n"
 cpp += "#include <stdint.h>\n"
 cpp += "#include <vector>\n"
 cpp += nsbegin
@@ -339,7 +339,7 @@ def genCppH(
 for spvPath, srcPath in spvPaths.items():
 name = getName(spvPath).replace("_spv", "")
-print("spvPath:{}".format(spvPath))
+print(f"spvPath:{spvPath}")
 with open(spvPath, 'rb') as fr:
 next_bin = array.array('I', fr.read())
 sizeBytes = 4 * len(next_bin)
@@ -361,8 +361,8 @@ def genCppH(
 shader_info_layouts = "{{{}}}".format(",\n ".join(shader_info.layouts))
 shader_info_args = [
-"\"vulkan.{}\"".format(name),
-"{}_bin".format(name),
+f"\"vulkan.{name}\"",
+f"{name}_bin",
 str(sizeBytes),
 shader_info_layouts,
 tile_size,
@@ -41,7 +41,7 @@ def get_tag(pytorch_root: Union[str, Path]) -> str:
 def get_torch_version(sha: Optional[str] = None) -> str:
 pytorch_root = Path(__file__).parent.parent
-version = open(pytorch_root / "version.txt", "r").read().strip()
+version = open(pytorch_root / "version.txt").read().strip()
 if os.getenv("PYTORCH_BUILD_VERSION"):
 assert os.getenv("PYTORCH_BUILD_NUMBER") is not None
@@ -86,11 +86,11 @@ if __name__ == "__main__":
 version = tagged_version
 with open(version_path, "w") as f:
-f.write("__version__ = '{}'\n".format(version))
+f.write(f"__version__ = '{version}'\n")
 # NB: This is not 100% accurate, because you could have built the
 # library code with DEBUG, but csrc without DEBUG (in which case
 # this would claim to be a release build when it's not.)
-f.write("debug = {}\n".format(repr(bool(args.is_debug))))
-f.write("cuda = {}\n".format(repr(args.cuda_version)))
-f.write("git_version = {}\n".format(repr(sha)))
-f.write("hip = {}\n".format(repr(args.hip_version)))
+f.write(f"debug = {repr(bool(args.is_debug))}\n")
+f.write(f"cuda = {repr(args.cuda_version)}\n")
+f.write(f"git_version = {repr(sha)}\n")
+f.write(f"hip = {repr(args.hip_version)}\n")
@@ -250,7 +250,7 @@ def main(args: List[str]) -> None:
 if options.op_registration_allowlist:
 op_registration_allowlist = options.op_registration_allowlist
 elif options.TEST_ONLY_op_registration_allowlist_yaml_path:
-with open(options.TEST_ONLY_op_registration_allowlist_yaml_path, "r") as f:
+with open(options.TEST_ONLY_op_registration_allowlist_yaml_path) as f:
 op_registration_allowlist = yaml.safe_load(f)
 else:
 op_registration_allowlist = None
@@ -35,7 +35,7 @@ class LintMessage(NamedTuple):
 def check_file(filename: str) -> Optional[LintMessage]:
 logging.debug("Checking file %s", filename)
-with open(filename, "r") as f:
+with open(filename) as f:
 lines = f.readlines()
 for idx, line in enumerate(lines):
@@ -108,7 +108,7 @@ def lint_file(
 original = None
 replacement = None
 if replace_pattern:
-with open(filename, "r") as f:
+with open(filename) as f:
 original = f.read()
 try:
@@ -105,7 +105,7 @@ class Formatter(logging.Formatter):
 self.redactions[needle] = replace
-@functools.lru_cache()
+@functools.lru_cache
 def logging_base_dir() -> str:
 meta_dir = os.getcwd()
 base_dir = os.path.join(meta_dir, "nightly", "log")
@@ -113,17 +113,17 @@ def logging_base_dir() -> str:
 return base_dir
-@functools.lru_cache()
+@functools.lru_cache
 def logging_run_dir() -> str:
 cur_dir = os.path.join(
 logging_base_dir(),
-"{}_{}".format(datetime.datetime.now().strftime(DATETIME_FORMAT), uuid.uuid1()),
+f"{datetime.datetime.now().strftime(DATETIME_FORMAT)}_{uuid.uuid1()}",
 )
 os.makedirs(cur_dir, exist_ok=True)
 return cur_dir
-@functools.lru_cache()
+@functools.lru_cache
 def logging_record_argv() -> None:
 s = subprocess.list2cmdline(sys.argv)
 with open(os.path.join(logging_run_dir(), "argv"), "w") as f:
@@ -205,7 +205,7 @@ def gen_diagnostics(
 out_cpp_dir: str,
 out_docs_dir: str,
 ) -> None:
-with open(rules_path, "r") as f:
+with open(rules_path) as f:
 rules = yaml.load(f, Loader=YamlLoader)
 template_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "templates")
@@ -23,7 +23,7 @@ from typing import Any
 def read_sub_write(path: str, prefix_pat: str, new_default: int) -> None:
 with open(path, encoding="utf-8") as f:
 content_str = f.read()
-content_str = re.sub(prefix_pat, r"\g<1>{}".format(new_default), content_str)
+content_str = re.sub(prefix_pat, rf"\g<1>{new_default}", content_str)
 with open(path, "w", encoding="utf-8") as f:
 f.write(content_str)
 print("modified", path)
@@ -191,15 +191,15 @@ def sig_for_ops(opname: str) -> List[str]:
 name = opname[2:-2]
 if name in binary_ops:
-return ["def {}(self, other: Any) -> Tensor: ...".format(opname)]
+return [f"def {opname}(self, other: Any) -> Tensor: ..."]
 elif name in comparison_ops:
-sig = "def {}(self, other: Any) -> Tensor: ...".format(opname)
+sig = f"def {opname}(self, other: Any) -> Tensor: ..."
 if name in symmetric_comparison_ops:
 # unsafe override https://github.com/python/mypy/issues/5704
 sig += " # type: ignore[override]"
 return [sig]
 elif name in unary_ops:
-return ["def {}(self) -> Tensor: ...".format(opname)]
+return [f"def {opname}(self) -> Tensor: ..."]
 elif name in to_py_type_ops:
 if name in {"bool", "float", "complex"}:
 tname = name
@@ -209,7 +209,7 @@ def sig_for_ops(opname: str) -> List[str]:
 tname = "int"
 if tname in {"float", "int", "bool", "complex"}:
 tname = "builtins." + tname
-return ["def {}(self) -> {}: ...".format(opname, tname)]
+return [f"def {opname}(self) -> {tname}: ..."]
 else:
 raise Exception("unknown op", opname)
@@ -1120,9 +1120,7 @@ def gen_pyi(
 "bfloat16",
 ]
 for name in simple_conversions:
-unsorted_tensor_method_hints[name].append(
-"def {}(self) -> Tensor: ...".format(name)
-)
+unsorted_tensor_method_hints[name].append(f"def {name}(self) -> Tensor: ...")
 # pyi tensor methods don't currently include deprecated signatures for some reason
 # TODO: we should probably add them in
@@ -1151,7 +1149,7 @@ def gen_pyi(
 namedtuples[tuple_name] = tuple_def
 for op in all_ops:
-name = "__{}__".format(op)
+name = f"__{op}__"
 unsorted_tensor_method_hints[name] += sig_for_ops(name)
 tensor_method_hints = []
@@ -1165,7 +1163,7 @@ def gen_pyi(
 # Generate namedtuple definitions
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-namedtuple_defs = ["{}\n".format(defn) for defn in namedtuples.values()]
+namedtuple_defs = [f"{defn}\n" for defn in namedtuples.values()]
 # Generate type signatures for legacy classes
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -1184,7 +1182,7 @@ def gen_pyi(
 "ByteTensor",
 "BoolTensor",
 ):
-legacy_class_hints.append("class {}(Tensor): ...".format(c))
+legacy_class_hints.append(f"class {c}(Tensor): ...")
 # Generate type signatures for dtype classes
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -1192,7 +1190,7 @@ def gen_pyi(
 # TODO: don't explicitly list dtypes here; get it from canonical
 # source
 dtype_class_hints = [
-"{}: dtype = ...".format(n)
+f"{n}: dtype = ..."
 for n in [
 "float32",
 "float",
@@ -1233,7 +1231,7 @@ def gen_pyi(
 ]
 all_symbols = sorted(list(namedtuples.keys()) + hinted_function_names)
 all_directive = pformat(all_symbols, width=100, compact=True).split("\n")
-all_directive[0] = "__all__ = {}".format(all_directive[0])
+all_directive[0] = f"__all__ = {all_directive[0]}"
 # Dispatch key hints
 # ~~~~~~~~~~~~~~~~~~
@@ -61,10 +61,8 @@ class CMake:
 _cmake_min_version = LooseVersion("3.13.0")
 if all(
-(
-ver is None or ver < _cmake_min_version
-for ver in [cmake_version, cmake3_version]
-)
+ver is None or ver < _cmake_min_version
+for ver in [cmake_version, cmake3_version]
 ):
 raise RuntimeError("no cmake or cmake3 with version >= 3.13.0 found")
@@ -108,7 +106,7 @@ class CMake:
 "Adds definitions to a cmake argument list."
 for key, value in sorted(kwargs.items()):
 if value is not None:
-args.append("-D{}={}".format(key, value))
+args.append(f"-D{key}={value}")
 def get_cmake_cache_variables(self) -> Dict[str, CMakeValue]:
 r"""Gets values in CMakeCache.txt into a dictionary.
@@ -172,9 +170,7 @@ class CMake:
 args.append("-Ax64")
 toolset_dict["host"] = "x64"
 if toolset_dict:
-toolset_expr = ",".join(
-["{}={}".format(k, v) for k, v in toolset_dict.items()]
-)
+toolset_expr = ",".join([f"{k}={v}" for k, v in toolset_dict.items()])
 args.append("-T" + toolset_expr)
 base_dir = os.path.dirname(
@@ -322,11 +318,9 @@ class CMake:
 expected_wrapper = "/usr/local/opt/ccache/libexec"
 if IS_DARWIN and os.path.exists(expected_wrapper):
 if "CMAKE_C_COMPILER" not in build_options and "CC" not in os.environ:
-CMake.defines(args, CMAKE_C_COMPILER="{}/gcc".format(expected_wrapper))
+CMake.defines(args, CMAKE_C_COMPILER=f"{expected_wrapper}/gcc")
 if "CMAKE_CXX_COMPILER" not in build_options and "CXX" not in os.environ:
-CMake.defines(
-args, CMAKE_CXX_COMPILER="{}/g++".format(expected_wrapper)
-)
+CMake.defines(args, CMAKE_CXX_COMPILER=f"{expected_wrapper}/g++")
 for env_var_name in my_env:
 if env_var_name.startswith("gh"):
@@ -335,11 +329,9 @@ class CMake:
 try:
 my_env[env_var_name] = str(my_env[env_var_name].encode("utf-8"))
 except UnicodeDecodeError as e:
-shex = ":".join(
-"{:02x}".format(ord(c)) for c in my_env[env_var_name]
-)
+shex = ":".join(f"{ord(c):02x}" for c in my_env[env_var_name])
 print(
-"Invalid ENV[{}] = {}".format(env_var_name, shex),
+f"Invalid ENV[{env_var_name}] = {shex}",
 file=sys.stderr,
 )
 print(e, file=sys.stderr)
@@ -396,7 +388,7 @@ class CMake:
 build_args += ["--"]
 if IS_WINDOWS and not USE_NINJA:
 # We are likely using msbuild here
-build_args += ["/p:CL_MPCount={}".format(max_jobs)]
+build_args += [f"/p:CL_MPCount={max_jobs}"]
 else:
 build_args += ["-j", max_jobs]
 self.run(build_args, my_env)
@@ -71,9 +71,7 @@ def get_cmake_cache_variables_from_file(
 r'("?)(.+?)\1(?::\s*([a-zA-Z_-][a-zA-Z0-9_-]*)?)?\s*=\s*(.*)', line
 )
 if matched is None: # Illegal line
-raise ValueError(
-"Unexpected line {} in {}: {}".format(i, repr(cmake_cache_file), line)
-)
+raise ValueError(f"Unexpected line {i} in {repr(cmake_cache_file)}: {line}")
 _, variable, type_, value = matched.groups()
 if type_ is None:
 type_ = ""
@@ -75,7 +75,7 @@ def generate_code(
 def get_selector_from_legacy_operator_selection_list(
 selected_op_list_path: str,
 ) -> Any:
-with open(selected_op_list_path, "r") as f:
+with open(selected_op_list_path) as f:
 # strip out the overload part
 # It's only for legacy config - do NOT copy this code!
 selected_op_list = {
@@ -46,7 +46,7 @@ def fetch_and_cache(
 if os.path.exists(path) and is_cached_file_valid():
 # Another test process already download the file, so don't re-do it
-with open(path, "r") as f:
+with open(path) as f:
 return cast(Dict[str, Any], json.load(f))
 for _ in range(3):
@@ -249,10 +249,8 @@ class EnvVarMetric:
 value = os.environ.get(self.env_var)
 if value is None and self.required:
 raise ValueError(
-(
-f"Missing {self.name}. Please set the {self.env_var}"
-"environment variable to pass in this value."
-)
+f"Missing {self.name}. Please set the {self.env_var}"
+"environment variable to pass in this value."
 )
 if self.type_conversion_fn:
 return self.type_conversion_fn(value)
@@ -11,7 +11,7 @@ if __name__ == "__main__":
 parser.add_argument("--replace", action="append", nargs=2)
 options = parser.parse_args()
-with open(options.input_file, "r") as f:
+with open(options.input_file) as f:
 contents = f.read()
 output_file = os.path.join(options.install_dir, options.output_file)
@@ -181,7 +181,7 @@ class TestParseNativeYaml(unittest.TestCase):
 use_aten_lib=False,
 out_file=out_file,
 )
-with open(out_yaml_path, "r") as out_file:
+with open(out_yaml_path) as out_file:
 es = yaml.load(out_file, Loader=LineLoader)
 self.assertTrue(all("func" in e for e in es))
 self.assertTrue(all(e.get("variants") == "function" for e in es))
@@ -268,7 +268,7 @@ class TestParseKernelYamlFiles(unittest.TestCase):
 use_aten_lib=False,
 out_file=out_file,
 )
-with open(out_yaml_path, "r") as out_file:
+with open(out_yaml_path) as out_file:
 es = yaml.load(out_file, Loader=LineLoader)
 self.assertTrue(all("func" in e for e in es))
 self.assertTrue(all(e.get("variants") == "function" for e in es))
@@ -21,7 +21,7 @@ class ExecutorchCppSignatureTest(unittest.TestCase):
 use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
 ):
 args = self.sig.arguments(include_context=True)
-self.assertEquals(len(args), 3)
+self.assertEqual(len(args), 3)
 self.assertTrue(any(a.name == "context" for a in args))
 def test_runtime_signature_does_not_contain_runtime_context(self) -> None:
@@ -30,7 +30,7 @@ class ExecutorchCppSignatureTest(unittest.TestCase):
 use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
 ):
 args = self.sig.arguments(include_context=False)
-self.assertEquals(len(args), 2)
+self.assertEqual(len(args), 2)
 self.assertFalse(any(a.name == "context" for a in args))
 def test_runtime_signature_declaration_correct(self) -> None:
@@ -38,7 +38,7 @@ class ExecutorchCppSignatureTest(unittest.TestCase):
 use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
 ):
 decl = self.sig.decl(include_context=True)
-self.assertEquals(
+self.assertEqual(
 decl,
 (
 "torch::executor::Tensor & foo_outf("
@@ -48,7 +48,7 @@ class ExecutorchCppSignatureTest(unittest.TestCase):
 ),
 )
 no_context_decl = self.sig.decl(include_context=False)
-self.assertEquals(
+self.assertEqual(
 no_context_decl,
 (
 "torch::executor::Tensor & foo_outf("
@@ -106,9 +106,9 @@ x = $TILE_SIZE_X + $TILE_SIZE_Y
 file_name_2 = os.path.join(tmp_dir, "conv2d_pw_1x2.glsl")
 self.assertTrue(os.path.exists(file_name_1))
 self.assertTrue(os.path.exists(file_name_2))
-with open(file_name_1, "r") as f:
+with open(file_name_1) as f:
 contents = f.read()
 self.assertTrue("1 + 1" in contents)
-with open(file_name_2, "r") as f:
+with open(file_name_2) as f:
 contents = f.read()
 self.assertTrue("1 + 2" in contents)
@@ -127,7 +127,7 @@ if __name__ == "__main__":
 args = parser.parse_args()
 touched_files = [CONFIG_YML]
-with open(CONFIG_YML, "r") as f:
+with open(CONFIG_YML) as f:
 config_yml = yaml.safe_load(f.read())
 config_yml["workflows"] = get_filtered_circleci_config(
@@ -163,7 +163,7 @@ def _get_previously_failing_tests() -> Set[str]:
 )
 return set()
-with open(PYTEST_FAILED_TESTS_CACHE_FILE_PATH, "r") as f:
+with open(PYTEST_FAILED_TESTS_CACHE_FILE_PATH) as f:
 last_failed_tests = json.load(f)
 prioritized_tests = _parse_prev_failing_test_files(last_failed_tests)