PEP585: .github (#145707)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145707
Approved by: https://github.com/huydhn
Author: Aaron Orenstein, 2025-01-26 11:50:47 -08:00 (committed by PyTorch MergeBot)
parent bfaf76bfc6
commit 60f98262f1
22 changed files with 265 additions and 280 deletions
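
PEP 585 (Python 3.9+) made the builtin containers generic, so annotations can use `dict`, `list`, `set`, `tuple`, and `frozenset` directly instead of the deprecated `typing.Dict`, `typing.List`, `typing.Set`, `typing.Tuple`, and `typing.FrozenSet`. The diffs below apply that rewrite mechanically across the `.github` scripts. A minimal before/after sketch of the pattern (the function here is hypothetical, not from this commit):

    from typing import Any

    # Before (aliases deprecated since Python 3.9):
    #   from typing import Dict, List
    #   def group_jobs(jobs: List[Dict[str, Any]]) -> Dict[str, List[str]]: ...

    # After (PEP 585 builtin generics, no container imports needed):
    def group_jobs(jobs: list[dict[str, Any]]) -> dict[str, list[str]]:
        # Group job names by their workflow key.
        grouped: dict[str, list[str]] = {}
        for job in jobs:
            grouped.setdefault(job["workflow"], []).append(job["name"])
        return grouped

    print(group_jobs([{"workflow": "trunk", "name": "build"}]))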

View File

@@ -3,7 +3,7 @@
 import json
 import os
 import re
-from typing import Any, cast, Dict, List, Optional
+from typing import Any, cast, Optional
 from urllib.error import HTTPError

 from github_utils import gh_fetch_url, gh_post_pr_comment, gh_query_issues_by_labels
@@ -67,7 +67,7 @@ def get_release_version(onto_branch: str) -> Optional[str]:

 def get_tracker_issues(
     org: str, project: str, onto_branch: str
-) -> List[Dict[str, Any]]:
+) -> list[dict[str, Any]]:
     """
     Find the tracker issue from the repo. The tracker issue needs to have the title
     like [VERSION] Release Tracker following the convention on PyTorch
@@ -117,7 +117,7 @@ def cherry_pick(
                 continue

             res = cast(
-                Dict[str, Any],
+                dict[str, Any],
                 post_tracker_issue_comment(
                     org,
                     project,
@@ -220,7 +220,7 @@ def submit_pr(

 def post_pr_comment(
     org: str, project: str, pr_num: int, msg: str, dry_run: bool = False
-) -> List[Dict[str, Any]]:
+) -> list[dict[str, Any]]:
     """
     Post a comment on the PR itself to point to the cherry picking PR when success
     or print the error when failure
@@ -255,7 +255,7 @@ def post_tracker_issue_comment(
     classification: str,
     fixes: str,
     dry_run: bool = False,
-) -> List[Dict[str, Any]]:
+) -> list[dict[str, Any]]:
     """
     Post a comment on the tracker issue (if any) to record the cherry pick
     """

View File

@@ -6,7 +6,7 @@ import re
 import sys
 import tempfile
 from pathlib import Path
-from typing import Any, Dict, List, Tuple
+from typing import Any

 import requests
 from gitutils import retries_decorator
@@ -76,7 +76,7 @@ DISABLED_TESTS_JSON = (

 @retries_decorator()
-def query_db(query: str, params: Dict[str, Any]) -> List[Dict[str, Any]]:
+def query_db(query: str, params: dict[str, Any]) -> list[dict[str, Any]]:
     return query_clickhouse(query, params)
@@ -97,7 +97,7 @@ def download_log_worker(temp_dir: str, id: int, name: str) -> None:
         f.write(data)

-def printer(item: Tuple[str, Tuple[int, str, List[Any]]], extra: str) -> None:
+def printer(item: tuple[str, tuple[int, str, list[Any]]], extra: str) -> None:
     test, (_, link, _) = item
     print(f"{link:<55} {test:<120} {extra}")
@@ -120,8 +120,8 @@ def close_issue(num: int) -> None:

 def check_if_exists(
-    item: Tuple[str, Tuple[int, str, List[str]]], all_logs: List[str]
-) -> Tuple[bool, str]:
+    item: tuple[str, tuple[int, str, list[str]]], all_logs: list[str]
+) -> tuple[bool, str]:
     test, (_, link, _) = item
     # Test names should look like `test_a (module.path.classname)`
     reg = re.match(r"(\S+) \((\S*)\)", test)

View File

@@ -2,7 +2,7 @@
 import sys
 from pathlib import Path
-from typing import Any, cast, Dict, List, Set
+from typing import Any, cast

 import yaml
@@ -10,9 +10,9 @@ import yaml
 GITHUB_DIR = Path(__file__).parent.parent

-def get_workflows_push_tags() -> Set[str]:
+def get_workflows_push_tags() -> set[str]:
     "Extract all known push tags from workflows"
-    rc: Set[str] = set()
+    rc: set[str] = set()
     for fname in (GITHUB_DIR / "workflows").glob("*.yml"):
         with fname.open("r") as f:
             wf_yml = yaml.safe_load(f)
@@ -25,19 +25,19 @@ def get_workflows_push_tags() -> Set[str]:
     return rc

-def filter_ciflow_tags(tags: Set[str]) -> List[str]:
+def filter_ciflow_tags(tags: set[str]) -> list[str]:
     "Return sorted list of ciflow tags"
     return sorted(
         tag[:-2] for tag in tags if tag.startswith("ciflow/") and tag.endswith("/*")
     )

-def read_probot_config() -> Dict[str, Any]:
+def read_probot_config() -> dict[str, Any]:
     with (GITHUB_DIR / "pytorch-probot.yml").open("r") as f:
-        return cast(Dict[str, Any], yaml.safe_load(f))
+        return cast(dict[str, Any], yaml.safe_load(f))

-def update_probot_config(labels: Set[str]) -> None:
+def update_probot_config(labels: set[str]) -> None:
     orig = read_probot_config()
     orig["ciflow_push_tags"] = filter_ciflow_tags(labels)
     with (GITHUB_DIR / "pytorch-probot.yml").open("w") as f:

View File

@@ -4,7 +4,7 @@ import re
 from datetime import datetime
 from functools import lru_cache
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Set
+from typing import Any, Callable

 from github_utils import gh_fetch_json_dict, gh_graphql
 from gitutils import GitRepo
@@ -112,7 +112,7 @@ def convert_gh_timestamp(date: str) -> float:
     return datetime.strptime(date, "%Y-%m-%dT%H:%M:%SZ").timestamp()

-def get_branches(repo: GitRepo) -> Dict[str, Any]:
+def get_branches(repo: GitRepo) -> dict[str, Any]:
     # Query locally for branches, group by branch base name (e.g. gh/blah/base -> gh/blah), and get the most recent branch
     git_response = repo._run_git(
         "for-each-ref",
@@ -120,7 +120,7 @@ def get_branches(repo: GitRepo) -> Dict[str, Any]:
         "--format=%(refname) %(committerdate:iso-strict)",
         "refs/remotes/origin",
     )
-    branches_by_base_name: Dict[str, Any] = {}
+    branches_by_base_name: dict[str, Any] = {}
     for line in git_response.splitlines():
         branch, date = line.split(" ")
         re_branch = re.match(r"refs/remotes/origin/(.*)", branch)
@@ -140,14 +140,14 @@ def get_branches(repo: GitRepo) -> Dict[str, Any]:
 def paginate_graphql(
     query: str,
-    kwargs: Dict[str, Any],
-    termination_func: Callable[[List[Dict[str, Any]]], bool],
-    get_data: Callable[[Dict[str, Any]], List[Dict[str, Any]]],
-    get_page_info: Callable[[Dict[str, Any]], Dict[str, Any]],
-) -> List[Any]:
+    kwargs: dict[str, Any],
+    termination_func: Callable[[list[dict[str, Any]]], bool],
+    get_data: Callable[[dict[str, Any]], list[dict[str, Any]]],
+    get_page_info: Callable[[dict[str, Any]], dict[str, Any]],
+) -> list[Any]:
     hasNextPage = True
     endCursor = None
-    data: List[Dict[str, Any]] = []
+    data: list[dict[str, Any]] = []
     while hasNextPage:
         ESTIMATED_TOKENS[0] += 1
         res = gh_graphql(query, cursor=endCursor, **kwargs)
@@ -159,11 +159,11 @@ def paginate_graphql(
     return data

-def get_recent_prs() -> Dict[str, Any]:
+def get_recent_prs() -> dict[str, Any]:
     now = datetime.now().timestamp()

     # Grab all PRs updated in last CLOSED_PR_RETENTION days
-    pr_infos: List[Dict[str, Any]] = paginate_graphql(
+    pr_infos: list[dict[str, Any]] = paginate_graphql(
         GRAPHQL_ALL_PRS_BY_UPDATED_AT,
         {"owner": "pytorch", "repo": "pytorch"},
         lambda data: (
@@ -190,7 +190,7 @@ def get_recent_prs() -> Dict[str, Any]:

 @lru_cache(maxsize=1)
-def get_open_prs() -> List[Dict[str, Any]]:
+def get_open_prs() -> list[dict[str, Any]]:
     return paginate_graphql(
         GRAPHQL_OPEN_PRS,
         {"owner": "pytorch", "repo": "pytorch"},
@@ -200,8 +200,8 @@ def get_open_prs() -> List[Dict[str, Any]]:
     )

-def get_branches_with_magic_label_or_open_pr() -> Set[str]:
-    pr_infos: List[Dict[str, Any]] = paginate_graphql(
+def get_branches_with_magic_label_or_open_pr() -> set[str]:
+    pr_infos: list[dict[str, Any]] = paginate_graphql(
         GRAPHQL_NO_DELETE_BRANCH_LABEL,
         {"owner": "pytorch", "repo": "pytorch"},
         lambda data: False,
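
The `paginate_graphql` signature above captures the usual GitHub GraphQL cursor loop: call, collect nodes, follow `pageInfo` until `hasNextPage` is false or the caller's termination predicate fires. A self-contained sketch of the same pattern against a fake fetcher (all names here are illustrative, not the script's):

    from typing import Any, Callable

    def paginate(
        fetch_page: Callable[[Any], dict[str, Any]],
        termination_func: Callable[[list[dict[str, Any]]], bool],
    ) -> list[dict[str, Any]]:
        data: list[dict[str, Any]] = []
        cursor = None
        has_next = True
        while has_next and not termination_func(data):
            res = fetch_page(cursor)
            data.extend(res["nodes"])
            has_next = res["pageInfo"]["hasNextPage"]
            cursor = res["pageInfo"]["endCursor"]
        return data

    # Fake two-page API for demonstration.
    pages = {
        None: {"nodes": [{"n": 1}], "pageInfo": {"hasNextPage": True, "endCursor": "a"}},
        "a": {"nodes": [{"n": 2}], "pageInfo": {"hasNextPage": False, "endCursor": None}},
    }
    print(paginate(lambda c: pages[c], lambda d: len(d) >= 10))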

View File

@@ -2,7 +2,7 @@ import json
 import re
 import shutil
 from pathlib import Path
-from typing import Any, List
+from typing import Any

 import boto3  # type: ignore[import]
@@ -77,7 +77,7 @@ def upload_file_to_s3(file_name: Path, bucket: str, key: str) -> None:

 def download_s3_objects_with_prefix(
     bucket_name: str, prefix: str, download_folder: Path
-) -> List[Path]:
+) -> list[Path]:
     s3 = boto3.resource("s3")
     bucket = s3.Bucket(bucket_name)

View File

@@ -8,9 +8,9 @@ import subprocess
 import sys
 import warnings
 from enum import Enum
-from functools import lru_cache
+from functools import cache
 from logging import info
-from typing import Any, Callable, Dict, List, Optional, Set
+from typing import Any, Callable, Optional
 from urllib.request import Request, urlopen

 import yaml
@@ -32,7 +32,7 @@ def is_cuda_or_rocm_job(job_name: Optional[str]) -> bool:

 # Supported modes when running periodically. Only applying the mode when
 # its lambda condition returns true
-SUPPORTED_PERIODICAL_MODES: Dict[str, Callable[[Optional[str]], bool]] = {
+SUPPORTED_PERIODICAL_MODES: dict[str, Callable[[Optional[str]], bool]] = {
     # Memory leak check is only needed for CUDA and ROCm jobs which utilize GPU memory
     "mem_leak_check": is_cuda_or_rocm_job,
     "rerun_disabled_tests": lambda job_name: True,
@@ -102,8 +102,8 @@ def parse_args() -> Any:
     return parser.parse_args()

-@lru_cache(maxsize=None)
-def get_pr_info(pr_number: int) -> Dict[str, Any]:
+@cache
+def get_pr_info(pr_number: int) -> dict[str, Any]:
     """
     Dynamically get PR information
     """
@@ -116,7 +116,7 @@ def get_pr_info(pr_number: int) -> Dict[str, Any]:
         "Accept": "application/vnd.github.v3+json",
         "Authorization": f"token {github_token}",
     }
-    json_response: Dict[str, Any] = download_json(
+    json_response: dict[str, Any] = download_json(
         url=f"{pytorch_github_api}/issues/{pr_number}",
         headers=headers,
     )
@@ -128,7 +128,7 @@ def get_pr_info(pr_number: int) -> Dict[str, Any]:
     return json_response

-def get_labels(pr_number: int) -> Set[str]:
+def get_labels(pr_number: int) -> set[str]:
     """
     Dynamically get the latest list of labels from the pull request
     """
@@ -138,14 +138,14 @@ def get_labels(pr_number: int) -> Set[str]:
     }

-def filter_labels(labels: Set[str], label_regex: Any) -> Set[str]:
+def filter_labels(labels: set[str], label_regex: Any) -> set[str]:
     """
     Return the list of matching labels
     """
     return {l for l in labels if re.match(label_regex, l)}

-def filter(test_matrix: Dict[str, List[Any]], labels: Set[str]) -> Dict[str, List[Any]]:
+def filter(test_matrix: dict[str, list[Any]], labels: set[str]) -> dict[str, list[Any]]:
     """
     Select the list of test config to run from the test matrix. The logic works
     as follows:
@@ -157,7 +157,7 @@ def filter(test_matrix: Dict[str, List[Any]], labels: Set[str]) -> Dict[str, Lis
     If the PR has none of the test-config label, all tests are run as usual.
     """

-    filtered_test_matrix: Dict[str, List[Any]] = {"include": []}
+    filtered_test_matrix: dict[str, list[Any]] = {"include": []}

     for entry in test_matrix.get("include", []):
         config_name = entry.get("config", "")
@@ -185,8 +185,8 @@ def filter(test_matrix: Dict[str, List[Any]], labels: Set[str]) -> Dict[str, Lis

 def filter_selected_test_configs(
-    test_matrix: Dict[str, List[Any]], selected_test_configs: Set[str]
-) -> Dict[str, List[Any]]:
+    test_matrix: dict[str, list[Any]], selected_test_configs: set[str]
+) -> dict[str, list[Any]]:
     """
     Keep only the selected configs if the list if not empty. Otherwise, keep all test configs.
     This filter is used when the workflow is dispatched manually.
@@ -194,7 +194,7 @@ def filter_selected_test_configs(
     if not selected_test_configs:
         return test_matrix

-    filtered_test_matrix: Dict[str, List[Any]] = {"include": []}
+    filtered_test_matrix: dict[str, list[Any]] = {"include": []}
     for entry in test_matrix.get("include", []):
         config_name = entry.get("config", "")
         if not config_name:
@@ -207,12 +207,12 @@ def filter_selected_test_configs(

 def set_periodic_modes(
-    test_matrix: Dict[str, List[Any]], job_name: Optional[str]
-) -> Dict[str, List[Any]]:
+    test_matrix: dict[str, list[Any]], job_name: Optional[str]
+) -> dict[str, list[Any]]:
     """
     Apply all periodic modes when running under a schedule
     """
-    scheduled_test_matrix: Dict[str, List[Any]] = {
+    scheduled_test_matrix: dict[str, list[Any]] = {
         "include": [],
     }
@@ -229,8 +229,8 @@ def set_periodic_modes(

 def mark_unstable_jobs(
-    workflow: str, job_name: str, test_matrix: Dict[str, List[Any]]
-) -> Dict[str, List[Any]]:
+    workflow: str, job_name: str, test_matrix: dict[str, list[Any]]
+) -> dict[str, list[Any]]:
     """
     Check the list of unstable jobs and mark them accordingly. Note that if a job
     is unstable, all its dependents will also be marked accordingly
@@ -245,8 +245,8 @@ def mark_unstable_jobs(

 def remove_disabled_jobs(
-    workflow: str, job_name: str, test_matrix: Dict[str, List[Any]]
-) -> Dict[str, List[Any]]:
+    workflow: str, job_name: str, test_matrix: dict[str, list[Any]]
+) -> dict[str, list[Any]]:
     """
     Check the list of disabled jobs, remove the current job and all its dependents
     if it exists in the list
@@ -261,15 +261,15 @@ def remove_disabled_jobs(

 def _filter_jobs(
-    test_matrix: Dict[str, List[Any]],
+    test_matrix: dict[str, list[Any]],
     issue_type: IssueType,
     target_cfg: Optional[str] = None,
-) -> Dict[str, List[Any]]:
+) -> dict[str, list[Any]]:
     """
     An utility function used to actually apply the job filter
     """
     # The result will be stored here
-    filtered_test_matrix: Dict[str, List[Any]] = {"include": []}
+    filtered_test_matrix: dict[str, list[Any]] = {"include": []}

     # This is an issue to disable a CI job
     if issue_type == IssueType.DISABLED:
@@ -302,10 +302,10 @@ def _filter_jobs(
 def process_jobs(
     workflow: str,
     job_name: str,
-    test_matrix: Dict[str, List[Any]],
+    test_matrix: dict[str, list[Any]],
     issue_type: IssueType,
     url: str,
-) -> Dict[str, List[Any]]:
+) -> dict[str, list[Any]]:
     """
     Both disabled and unstable jobs are in the following format:
@@ -441,7 +441,7 @@ def process_jobs(
     return test_matrix

-def download_json(url: str, headers: Dict[str, str], num_retries: int = 3) -> Any:
+def download_json(url: str, headers: dict[str, str], num_retries: int = 3) -> Any:
     for _ in range(num_retries):
         try:
             req = Request(url=url, headers=headers)
@@ -462,7 +462,7 @@ def set_output(name: str, val: Any) -> None:
         print(f"::set-output name={name}::{val}")

-def parse_reenabled_issues(s: Optional[str]) -> List[str]:
+def parse_reenabled_issues(s: Optional[str]) -> list[str]:
     # NB: When the PR body is empty, GitHub API returns a None value, which is
     # passed into this function
     if not s:
@@ -477,7 +477,7 @@ def parse_reenabled_issues(s: Optional[str]) -> List[str]:
     return issue_numbers

-def get_reenabled_issues(pr_body: str = "") -> List[str]:
+def get_reenabled_issues(pr_body: str = "") -> list[str]:
     default_branch = f"origin/{os.environ.get('GIT_DEFAULT_BRANCH', 'main')}"
     try:
         commit_messages = subprocess.check_output(
@@ -489,12 +489,12 @@ def get_reenabled_issues(pr_body: str = "") -> List[str]:
     return parse_reenabled_issues(pr_body) + parse_reenabled_issues(commit_messages)

-def check_for_setting(labels: Set[str], body: str, setting: str) -> bool:
+def check_for_setting(labels: set[str], body: str, setting: str) -> bool:
     return setting in labels or f"[{setting}]" in body

 def perform_misc_tasks(
-    labels: Set[str], test_matrix: Dict[str, List[Any]], job_name: str, pr_body: str
+    labels: set[str], test_matrix: dict[str, list[Any]], job_name: str, pr_body: str
 ) -> None:
     """
     In addition to apply the filter logic, the script also does the following
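
The `@lru_cache(maxsize=None)` → `@cache` swap in this file is behavior-preserving: `functools.cache` (Python 3.9+) is defined as `lru_cache(maxsize=None)`, an unbounded memoizing cache. A quick sketch:

    from functools import cache

    @cache
    def fib(n: int) -> int:
        # Each distinct n is computed once and memoized for the process lifetime.
        return n if n < 2 else fib(n - 1) + fib(n - 2)

    print(fib(90))  # fast, thanks to the unbounded cache

For a short-lived CI script the unbounded cache is the right trade-off, since entries live only as long as the process.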

View File

@@ -12,7 +12,7 @@ architectures:
 """

 import os
-from typing import Dict, List, Optional, Tuple
+from typing import Optional

 # NOTE: Also update the CUDA sources in tools/nightly.py when changing this list
@@ -181,7 +181,7 @@ CXX11_ABI = "cxx11-abi"
 RELEASE = "release"
 DEBUG = "debug"

-LIBTORCH_CONTAINER_IMAGES: Dict[Tuple[str, str], str] = {
+LIBTORCH_CONTAINER_IMAGES: dict[tuple[str, str], str] = {
     **{
         (
             gpu_arch,
@@ -223,16 +223,16 @@ def translate_desired_cuda(gpu_arch_type: str, gpu_arch_version: str) -> str:
     }.get(gpu_arch_type, gpu_arch_version)

-def list_without(in_list: List[str], without: List[str]) -> List[str]:
+def list_without(in_list: list[str], without: list[str]) -> list[str]:
     return [item for item in in_list if item not in without]

 def generate_libtorch_matrix(
     os: str,
     abi_version: str,
-    arches: Optional[List[str]] = None,
-    libtorch_variants: Optional[List[str]] = None,
-) -> List[Dict[str, str]]:
+    arches: Optional[list[str]] = None,
+    libtorch_variants: Optional[list[str]] = None,
+) -> list[dict[str, str]]:
     if arches is None:
         arches = ["cpu"]
     if os == "linux":
@@ -248,7 +248,7 @@ def generate_libtorch_matrix(
             "static-without-deps",
         ]

-    ret: List[Dict[str, str]] = []
+    ret: list[dict[str, str]] = []
     for arch_version in arches:
         for libtorch_variant in libtorch_variants:
             # one of the values in the following list must be exactly
@@ -287,10 +287,10 @@ def generate_libtorch_matrix(

 def generate_wheels_matrix(
     os: str,
-    arches: Optional[List[str]] = None,
-    python_versions: Optional[List[str]] = None,
+    arches: Optional[list[str]] = None,
+    python_versions: Optional[list[str]] = None,
     use_split_build: bool = False,
-) -> List[Dict[str, str]]:
+) -> list[dict[str, str]]:
     package_type = "wheel"
     if os == "linux" or os == "linux-aarch64" or os == "linux-s390x":
         # NOTE: We only build manywheel packages for x86_64 and aarch64 and s390x linux
@@ -315,7 +315,7 @@ def generate_wheels_matrix(
         # uses different build/test scripts
         arches = ["cpu-s390x"]

-    ret: List[Dict[str, str]] = []
+    ret: list[dict[str, str]] = []
     for python_version in python_versions:
         for arch_version in arches:
             gpu_arch_type = arch_type(arch_version)

View File

@@ -2,9 +2,10 @@
 import os
 import sys
+from collections.abc import Iterable
 from dataclasses import asdict, dataclass, field
 from pathlib import Path
-from typing import Dict, Iterable, List, Literal, Set
+from typing import Literal
 from typing_extensions import TypedDict  # Python 3.11+

 import generate_binary_build_matrix  # type: ignore[import]
@@ -27,7 +28,7 @@ LABEL_CIFLOW_BINARIES_WHEEL = "ciflow/binaries_wheel"
 class CIFlowConfig:
     # For use to enable workflows to run on pytorch/pytorch-canary
     run_on_canary: bool = False
-    labels: Set[str] = field(default_factory=set)
+    labels: set[str] = field(default_factory=set)
     # Certain jobs might not want to be part of the ciflow/[all,trunk] workflow
     isolated_workflow: bool = False
     unstable: bool = False
@@ -48,7 +49,7 @@ class Config(TypedDict):

 @dataclass
 class BinaryBuildWorkflow:
     os: str
-    build_configs: List[Dict[str, str]]
+    build_configs: list[dict[str, str]]
     package_type: str

     # Optional fields
View File

@@ -11,11 +11,11 @@ import sys
 import time
 import urllib
 import urllib.parse
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Optional
 from urllib.request import Request, urlopen

-def parse_json_and_links(conn: Any) -> Tuple[Any, Dict[str, Dict[str, str]]]:
+def parse_json_and_links(conn: Any) -> tuple[Any, dict[str, dict[str, str]]]:
     links = {}
     # Extract links which GH uses for pagination
     # see https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Link
@@ -42,7 +42,7 @@ def parse_json_and_links(conn: Any) -> Tuple[Any, Dict[str, Dict[str, str]]]:
 def fetch_url(
     url: str,
     *,
-    headers: Optional[Dict[str, str]] = None,
+    headers: Optional[dict[str, str]] = None,
     reader: Callable[[Any], Any] = lambda x: x.read(),
     retries: Optional[int] = 3,
     backoff_timeout: float = 0.5,
@@ -83,7 +83,7 @@ def parse_args() -> Any:
     return parser.parse_args()

-def fetch_jobs(url: str, headers: Dict[str, str]) -> List[Dict[str, str]]:
+def fetch_jobs(url: str, headers: dict[str, str]) -> list[dict[str, str]]:
     response, links = fetch_url(url, headers=headers, reader=parse_json_and_links)
     jobs = response["jobs"]
     assert type(jobs) is list
@@ -111,7 +111,7 @@ def fetch_jobs(url: str, headers: Dict[str, str]) -> List[Dict[str, str]]:
 # running.

-def find_job_id_name(args: Any) -> Tuple[str, str]:
+def find_job_id_name(args: Any) -> tuple[str, str]:
     # From https://docs.github.com/en/actions/learn-github-actions/environment-variables
     PYTORCH_REPO = os.environ.get("GITHUB_REPOSITORY", "pytorch/pytorch")
     PYTORCH_GITHUB_API = f"https://api.github.com/repos/{PYTORCH_REPO}"

View File

@@ -4,7 +4,7 @@ import json
 import os
 import warnings
 from dataclasses import dataclass
-from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, cast, Optional, Union
 from urllib.error import HTTPError
 from urllib.parse import quote
 from urllib.request import Request, urlopen
@@ -27,11 +27,11 @@ class GitHubComment:
 def gh_fetch_url_and_headers(
     url: str,
     *,
-    headers: Optional[Dict[str, str]] = None,
-    data: Union[Optional[Dict[str, Any]], str] = None,
+    headers: Optional[dict[str, str]] = None,
+    data: Union[Optional[dict[str, Any]], str] = None,
     method: Optional[str] = None,
     reader: Callable[[Any], Any] = lambda x: x.read(),
-) -> Tuple[Any, Any]:
+) -> tuple[Any, Any]:
     if headers is None:
         headers = {}
     token = os.environ.get("GITHUB_TOKEN")
@@ -70,8 +70,8 @@ def gh_fetch_url_and_headers(
 def gh_fetch_url(
     url: str,
     *,
-    headers: Optional[Dict[str, str]] = None,
-    data: Union[Optional[Dict[str, Any]], str] = None,
+    headers: Optional[dict[str, str]] = None,
+    data: Union[Optional[dict[str, Any]], str] = None,
     method: Optional[str] = None,
     reader: Callable[[Any], Any] = json.load,
 ) -> Any:
@@ -82,25 +82,25 @@ def gh_fetch_url(
 def gh_fetch_json(
     url: str,
-    params: Optional[Dict[str, Any]] = None,
-    data: Optional[Dict[str, Any]] = None,
+    params: Optional[dict[str, Any]] = None,
+    data: Optional[dict[str, Any]] = None,
     method: Optional[str] = None,
-) -> List[Dict[str, Any]]:
+) -> list[dict[str, Any]]:
     headers = {"Accept": "application/vnd.github.v3+json"}
     if params is not None and len(params) > 0:
         url += "?" + "&".join(
             f"{name}={quote(str(val))}" for name, val in params.items()
         )
     return cast(
-        List[Dict[str, Any]],
+        list[dict[str, Any]],
         gh_fetch_url(url, headers=headers, data=data, reader=json.load, method=method),
     )

 def _gh_fetch_json_any(
     url: str,
-    params: Optional[Dict[str, Any]] = None,
-    data: Optional[Dict[str, Any]] = None,
+    params: Optional[dict[str, Any]] = None,
+    data: Optional[dict[str, Any]] = None,
 ) -> Any:
     headers = {"Accept": "application/vnd.github.v3+json"}
     if params is not None and len(params) > 0:
@@ -112,21 +112,21 @@ def _gh_fetch_json_any(
 def gh_fetch_json_list(
     url: str,
-    params: Optional[Dict[str, Any]] = None,
-    data: Optional[Dict[str, Any]] = None,
-) -> List[Dict[str, Any]]:
-    return cast(List[Dict[str, Any]], _gh_fetch_json_any(url, params, data))
+    params: Optional[dict[str, Any]] = None,
+    data: Optional[dict[str, Any]] = None,
+) -> list[dict[str, Any]]:
+    return cast(list[dict[str, Any]], _gh_fetch_json_any(url, params, data))

 def gh_fetch_json_dict(
     url: str,
-    params: Optional[Dict[str, Any]] = None,
-    data: Optional[Dict[str, Any]] = None,
-) -> Dict[str, Any]:
-    return cast(Dict[str, Any], _gh_fetch_json_any(url, params, data))
+    params: Optional[dict[str, Any]] = None,
+    data: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+    return cast(dict[str, Any], _gh_fetch_json_any(url, params, data))

-def gh_graphql(query: str, **kwargs: Any) -> Dict[str, Any]:
+def gh_graphql(query: str, **kwargs: Any) -> dict[str, Any]:
     rc = gh_fetch_url(
         "https://api.github.com/graphql",
         data={"query": query, "variables": kwargs},
@@ -136,12 +136,12 @@ def gh_graphql(query: str, **kwargs: Any) -> Dict[str, Any]:
         raise RuntimeError(
             f"GraphQL query {query}, args {kwargs} failed: {rc['errors']}"
         )
-    return cast(Dict[str, Any], rc)
+    return cast(dict[str, Any], rc)

 def _gh_post_comment(
     url: str, comment: str, dry_run: bool = False
-) -> List[Dict[str, Any]]:
+) -> list[dict[str, Any]]:
     if dry_run:
         print(comment)
         return []
@@ -150,7 +150,7 @@ def _gh_post_comment(

 def gh_post_pr_comment(
     org: str, repo: str, pr_num: int, comment: str, dry_run: bool = False
-) -> List[Dict[str, Any]]:
+) -> list[dict[str, Any]]:
     return _gh_post_comment(
         f"{GITHUB_API_URL}/repos/{org}/{repo}/issues/{pr_num}/comments",
         comment,
@@ -160,7 +160,7 @@ def gh_post_pr_comment(

 def gh_post_commit_comment(
     org: str, repo: str, sha: str, comment: str, dry_run: bool = False
-) -> List[Dict[str, Any]]:
+) -> list[dict[str, Any]]:
     return _gh_post_comment(
         f"{GITHUB_API_URL}/repos/{org}/{repo}/commits/{sha}/comments",
         comment,
@@ -220,8 +220,8 @@ def gh_update_pr_state(org: str, repo: str, pr_num: int, state: str = "open") ->

 def gh_query_issues_by_labels(
-    org: str, repo: str, labels: List[str], state: str = "open"
-) -> List[Dict[str, Any]]:
+    org: str, repo: str, labels: list[str], state: str = "open"
+) -> list[dict[str, Any]]:
     url = f"{GITHUB_API_URL}/repos/{org}/{repo}/issues"
     return gh_fetch_json(
         url, method="GET", params={"labels": ",".join(labels), "state": state}

View File

@@ -4,20 +4,10 @@ import os
 import re
 import tempfile
 from collections import defaultdict
+from collections.abc import Iterator
 from datetime import datetime
 from functools import wraps
-from typing import (
-    Any,
-    Callable,
-    cast,
-    Dict,
-    Iterator,
-    List,
-    Optional,
-    Tuple,
-    TypeVar,
-    Union,
-)
+from typing import Any, Callable, cast, Optional, TypeVar, Union

 T = TypeVar("T")
@@ -35,17 +25,17 @@ def get_git_repo_dir() -> str:
     return os.getenv("GIT_REPO_DIR", str(Path(__file__).resolve().parents[2]))

-def fuzzy_list_to_dict(items: List[Tuple[str, str]]) -> Dict[str, List[str]]:
+def fuzzy_list_to_dict(items: list[tuple[str, str]]) -> dict[str, list[str]]:
     """
     Converts list to dict preserving elements with duplicate keys
     """
-    rc: Dict[str, List[str]] = defaultdict(list)
+    rc: dict[str, list[str]] = defaultdict(list)
     for key, val in items:
         rc[key].append(val)
     return dict(rc)

-def _check_output(items: List[str], encoding: str = "utf-8") -> str:
+def _check_output(items: list[str], encoding: str = "utf-8") -> str:
     from subprocess import CalledProcessError, check_output, STDOUT

     try:
@@ -95,7 +85,7 @@ class GitCommit:
         return item in self.body or item in self.title

-def parse_fuller_format(lines: Union[str, List[str]]) -> GitCommit:
+def parse_fuller_format(lines: Union[str, list[str]]) -> GitCommit:
     """
     Expect commit message generated using `--format=fuller --date=unix` format, i.e.:
         commit <sha1>
@@ -142,13 +132,13 @@ class GitRepo:
             print(f"+ git -C {self.repo_dir} {' '.join(args)}")
         return _check_output(["git", "-C", self.repo_dir] + list(args))

-    def revlist(self, revision_range: str) -> List[str]:
+    def revlist(self, revision_range: str) -> list[str]:
         rc = self._run_git("rev-list", revision_range, "--", ".").strip()
         return rc.split("\n") if len(rc) > 0 else []

     def branches_containing_ref(
         self, ref: str, *, include_remote: bool = True
-    ) -> List[str]:
+    ) -> list[str]:
         rc = (
             self._run_git("branch", "--remote", "--contains", ref)
             if include_remote
@@ -189,7 +179,7 @@ class GitRepo:
     def get_merge_base(self, from_ref: str, to_ref: str) -> str:
         return self._run_git("merge-base", from_ref, to_ref).strip()

-    def patch_id(self, ref: Union[str, List[str]]) -> List[Tuple[str, str]]:
+    def patch_id(self, ref: Union[str, list[str]]) -> list[tuple[str, str]]:
         is_list = isinstance(ref, list)
         if is_list:
             if len(ref) == 0:
@@ -198,9 +188,9 @@ class GitRepo:
         rc = _check_output(
             ["sh", "-c", f"git -C {self.repo_dir} show {ref}|git patch-id --stable"]
         ).strip()
-        return [cast(Tuple[str, str], x.split(" ", 1)) for x in rc.split("\n")]
+        return [cast(tuple[str, str], x.split(" ", 1)) for x in rc.split("\n")]

-    def commits_resolving_gh_pr(self, pr_num: int) -> List[str]:
+    def commits_resolving_gh_pr(self, pr_num: int) -> list[str]:
         owner, name = self.gh_owner_and_name()
         msg = f"Pull Request resolved: https://github.com/{owner}/{name}/pull/{pr_num}"
         rc = self._run_git("log", "--format=%H", "--grep", msg).strip()
@@ -219,7 +209,7 @@ class GitRepo:
     def compute_branch_diffs(
         self, from_branch: str, to_branch: str
-    ) -> Tuple[List[str], List[str]]:
+    ) -> tuple[list[str], list[str]]:
         """
         Returns list of commmits that are missing in each other branch since their merge base
         Might be slow if merge base is between two branches is pretty far off
@@ -311,14 +301,14 @@ class GitRepo:
     def remote_url(self) -> str:
         return self._run_git("remote", "get-url", self.remote)

-    def gh_owner_and_name(self) -> Tuple[str, str]:
+    def gh_owner_and_name(self) -> tuple[str, str]:
         url = os.getenv("GIT_REMOTE_URL", None)
         if url is None:
             url = self.remote_url()
         rc = RE_GITHUB_URL_MATCH.match(url)
         if rc is None:
             raise RuntimeError(f"Unexpected url format {url}")
-        return cast(Tuple[str, str], rc.groups())
+        return cast(tuple[str, str], rc.groups())

     def commit_message(self, ref: str) -> str:
         return self._run_git("log", "-1", "--format=%B", ref)
@@ -366,7 +356,7 @@ class PeekableIterator(Iterator[str]):
         return rc

-def patterns_to_regex(allowed_patterns: List[str]) -> Any:
+def patterns_to_regex(allowed_patterns: list[str]) -> Any:
     """
     pattern is glob-like, i.e. the only special sequences it has are:
       - ? - matches single character
@@ -437,7 +427,7 @@ def retries_decorator(
 ) -> Callable[[Callable[..., T]], Callable[..., T]]:
     def decorator(f: Callable[..., T]) -> Callable[..., T]:
         @wraps(f)
-        def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> T:
+        def wrapper(*args: list[Any], **kwargs: dict[str, Any]) -> T:
             for idx in range(num_retries):
                 try:
                     return f(*args, **kwargs)

View File

@@ -2,7 +2,7 @@
 import json
 from functools import lru_cache
-from typing import Any, List, Tuple, TYPE_CHECKING, Union
+from typing import Any, TYPE_CHECKING, Union

 from github_utils import gh_fetch_url_and_headers, GitHubComment
@@ -28,14 +28,14 @@ https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for
 """

-def request_for_labels(url: str) -> Tuple[Any, Any]:
+def request_for_labels(url: str) -> tuple[Any, Any]:
     headers = {"Accept": "application/vnd.github.v3+json"}
     return gh_fetch_url_and_headers(
         url, headers=headers, reader=lambda x: x.read().decode("utf-8")
     )

-def update_labels(labels: List[str], info: str) -> None:
+def update_labels(labels: list[str], info: str) -> None:
     labels_json = json.loads(info)
     labels.extend([x["name"] for x in labels_json])
@@ -56,10 +56,10 @@ def get_last_page_num_from_header(header: Any) -> int:

 @lru_cache
-def gh_get_labels(org: str, repo: str) -> List[str]:
+def gh_get_labels(org: str, repo: str) -> list[str]:
     prefix = f"https://api.github.com/repos/{org}/{repo}/labels?per_page=100"
     header, info = request_for_labels(prefix + "&page=1")
-    labels: List[str] = []
+    labels: list[str] = []
     update_labels(labels, info)

     last_page = get_last_page_num_from_header(header)
@@ -74,7 +74,7 @@ def gh_get_labels(org: str, repo: str) -> List[str]:

 def gh_add_labels(
-    org: str, repo: str, pr_num: int, labels: Union[str, List[str]], dry_run: bool
+    org: str, repo: str, pr_num: int, labels: Union[str, list[str]], dry_run: bool
 ) -> None:
     if dry_run:
         print(f"Dryrun: Adding labels {labels} to PR {pr_num}")
@@ -97,7 +97,7 @@ def gh_remove_label(
     )

-def get_release_notes_labels(org: str, repo: str) -> List[str]:
+def get_release_notes_labels(org: str, repo: str) -> list[str]:
     return [
         label
         for label in gh_get_labels(org, repo)

View File

@@ -1,7 +1,7 @@
 import hashlib
 import os
 from pathlib import Path
-from typing import Dict, NamedTuple
+from typing import NamedTuple

 from file_io_utils import (
     copy_file,
@@ -219,8 +219,8 @@ def _merge_lastfailed_files(source_pytest_cache: Path, dest_pytest_cache: Path)

 def _merged_lastfailed_content(
-    from_lastfailed: Dict[str, bool], to_lastfailed: Dict[str, bool]
-) -> Dict[str, bool]:
+    from_lastfailed: dict[str, bool], to_lastfailed: dict[str, bool]
+) -> dict[str, bool]:
     """
     The lastfailed files are dictionaries where the key is the test identifier.
     Each entry's value appears to always be `true`, but let's not count on that.

View File

@@ -61,9 +61,10 @@ import random
 import re
 import sys
 from argparse import ArgumentParser
-from functools import lru_cache
+from collections.abc import Iterable
+from functools import cache
 from logging import LogRecord
-from typing import Any, Dict, FrozenSet, Iterable, List, NamedTuple, Set, Tuple
+from typing import Any, NamedTuple
 from urllib.request import Request, urlopen

 import yaml
@@ -105,7 +106,7 @@ class Settings(NamedTuple):
     Settings for the experiments that can be opted into.
     """

-    experiments: Dict[str, Experiment] = {}
+    experiments: dict[str, Experiment] = {}

 class ColorFormatter(logging.Formatter):
@@ -150,7 +151,7 @@ def set_github_output(key: str, value: str) -> None:
         f.write(f"{key}={value}\n")

-def _str_comma_separated_to_set(value: str) -> FrozenSet[str]:
+def _str_comma_separated_to_set(value: str) -> frozenset[str]:
     return frozenset(
         filter(lambda itm: itm != "", map(str.strip, value.strip(" \n\t").split(",")))
     )
@@ -208,12 +209,12 @@ def parse_args() -> Any:
     return parser.parse_args()

-def get_gh_client(github_token: str) -> Github:
+def get_gh_client(github_token: str) -> Github:  # type: ignore[no-any-unimported]
     auth = Auth.Token(github_token)
     return Github(auth=auth)

-def get_issue(gh: Github, repo: str, issue_num: int) -> Issue:
+def get_issue(gh: Github, repo: str, issue_num: int) -> Issue:  # type: ignore[no-any-unimported]
     repo = gh.get_repo(repo)
     return repo.get_issue(number=issue_num)
@@ -242,7 +243,7 @@ def get_potential_pr_author(
             raise Exception(  # noqa: TRY002
                 f"issue with pull request {pr_number} from repo {repository}"
             ) from e
-        return pull.user.login
+        return pull.user.login  # type: ignore[no-any-return]

     # In all other cases, return the original input username
     return username
@@ -263,7 +264,7 @@ def load_yaml(yaml_text: str) -> Any:
         raise

-def extract_settings_user_opt_in_from_text(rollout_state: str) -> Tuple[str, str]:
+def extract_settings_user_opt_in_from_text(rollout_state: str) -> tuple[str, str]:
     """
     Extracts the text with settings, if any, and the opted in users from the rollout state.
@@ -279,7 +280,7 @@ def extract_settings_user_opt_in_from_text(rollout_state: str) -> Tuple[str, str
     return "", rollout_state

-class UserOptins(Dict[str, List[str]]):
+class UserOptins(dict[str, list[str]]):
     """
     Dictionary of users with a list of features they have opted into
     """
@@ -420,7 +421,7 @@ def get_runner_prefix(
     rollout_state: str,
     workflow_requestors: Iterable[str],
     branch: str,
-    eligible_experiments: FrozenSet[str] = frozenset(),
+    eligible_experiments: frozenset[str] = frozenset(),
     is_canary: bool = False,
 ) -> str:
     settings = parse_settings(rollout_state)
@@ -519,7 +520,7 @@ def get_rollout_state_from_issue(github_token: str, repo: str, issue_num: int) -
     return str(issue.get_comments()[0].body.strip("\n\t "))

-def download_json(url: str, headers: Dict[str, str], num_retries: int = 3) -> Any:
+def download_json(url: str, headers: dict[str, str], num_retries: int = 3) -> Any:
     for _ in range(num_retries):
         try:
             req = Request(url=url, headers=headers)
@@ -532,8 +533,8 @@ def download_json(url: str, headers: Dict[str, str], num_retries: int = 3) -> An
     return {}

-@lru_cache(maxsize=None)
-def get_pr_info(github_repo: str, github_token: str, pr_number: int) -> Dict[str, Any]:
+@cache
+def get_pr_info(github_repo: str, github_token: str, pr_number: int) -> dict[str, Any]:
     """
     Dynamically get PR information
     """
@@ -542,7 +543,7 @@ def get_pr_info(github_repo: str, github_token: str, pr_number: int) -> Dict[str
         "Accept": "application/vnd.github.v3+json",
         "Authorization": f"token {github_token}",
     }
-    json_response: Dict[str, Any] = download_json(
+    json_response: dict[str, Any] = download_json(
         url=f"{github_api}/issues/{pr_number}",
         headers=headers,
     )
@@ -554,7 +555,7 @@ def get_pr_info(github_repo: str, github_token: str, pr_number: int) -> Dict[str
     return json_response

-def get_labels(github_repo: str, github_token: str, pr_number: int) -> Set[str]:
+def get_labels(github_repo: str, github_token: str, pr_number: int) -> set[str]:
     """
     Dynamically get the latest list of labels from the pull request
     """

View File

@@ -1,6 +1,5 @@
 import argparse
 import subprocess
-from typing import Dict

 import generate_binary_build_matrix
@@ -10,7 +9,7 @@ def tag_image(
     default_tag: str,
     release_version: str,
     dry_run: str,
-    tagged_images: Dict[str, bool],
+    tagged_images: dict[str, bool],
 ) -> None:
     if image in tagged_images:
         return
@@ -41,7 +40,7 @@ def main() -> None:
     )
     options = parser.parse_args()

-    tagged_images: Dict[str, bool] = {}
+    tagged_images: dict[str, bool] = {}
     platform_images = [
         generate_binary_build_matrix.WHEEL_CONTAINER_IMAGES,
         generate_binary_build_matrix.LIBTORCH_CONTAINER_IMAGES,

View File

@@ -1,6 +1,6 @@
 """test_check_labels.py"""

-from typing import Any, List
+from typing import Any
 from unittest import main, mock, TestCase

 from check_labels import (
@@ -31,7 +31,7 @@ def mock_delete_all_label_err_comments(pr: "GitHubPR") -> None:
     pass

-def mock_get_comments() -> List[GitHubComment]:
+def mock_get_comments() -> list[GitHubComment]:
     return [
         # Case 1 - a non label err comment
         GitHubComment(

View File

@@ -3,7 +3,7 @@
 import json
 import os
 import tempfile
-from typing import Any, Dict, List
+from typing import Any
 from unittest import main, mock, TestCase

 import yaml
@@ -362,7 +362,7 @@ class TestConfigFilter(TestCase):
             self.assertEqual(case["expected"], json.dumps(filtered_test_matrix))

     def test_set_periodic_modes(self) -> None:
-        testcases: List[Dict[str, str]] = [
+        testcases: list[dict[str, str]] = [
             {
                 "job_name": "a CI job",
                 "test_matrix": "{include: []}",
@@ -702,7 +702,7 @@ class TestConfigFilter(TestCase):
         )
         mocked_subprocess.return_value = b""

-        testcases: List[Dict[str, Any]] = [
+        testcases: list[dict[str, Any]] = [
             {
                 "labels": {},
                 "test_matrix": '{include: [{config: "default"}]}',

View File

@@ -12,7 +12,7 @@ import json
 import os
 import warnings
 from hashlib import sha256
-from typing import Any, List, Optional
+from typing import Any, Optional
 from unittest import main, mock, skip, TestCase
 from urllib.error import HTTPError
@@ -170,7 +170,7 @@ def mock_gh_get_info() -> Any:
     }

-def mocked_read_merge_rules_NE(repo: Any, org: str, project: str) -> List[MergeRule]:
+def mocked_read_merge_rules_NE(repo: Any, org: str, project: str) -> list[MergeRule]:
     return [
         MergeRule(
             name="mock with nonexistent check",
@@ -182,7 +182,7 @@ def mocked_read_merge_rules_NE(repo: Any, org: str, project: str) -> List[MergeR
     ]

-def mocked_read_merge_rules(repo: Any, org: str, project: str) -> List[MergeRule]:
+def mocked_read_merge_rules(repo: Any, org: str, project: str) -> list[MergeRule]:
     return [
         MergeRule(
             name="super",
@@ -211,7 +211,7 @@ def mocked_read_merge_rules(repo: Any, org: str, project: str) -> List[MergeRule

 def mocked_read_merge_rules_approvers(
     repo: Any, org: str, project: str
-) -> List[MergeRule]:
+) -> list[MergeRule]:
     return [
         MergeRule(
             name="Core Reviewers",
@@ -234,11 +234,11 @@ def mocked_read_merge_rules_approvers(
     ]

-def mocked_read_merge_rules_raise(repo: Any, org: str, project: str) -> List[MergeRule]:
+def mocked_read_merge_rules_raise(repo: Any, org: str, project: str) -> list[MergeRule]:
     raise RuntimeError("testing")

-def xla_merge_rules(repo: Any, org: str, project: str) -> List[MergeRule]:
+def xla_merge_rules(repo: Any, org: str, project: str) -> list[MergeRule]:
     return [
         MergeRule(
             name=" OSS CI / pytorchbot / XLA",
@@ -260,7 +260,7 @@ class DummyGitRepo(GitRepo):
     def __init__(self) -> None:
         super().__init__(get_git_repo_dir(), get_git_remote_name())

-    def commits_resolving_gh_pr(self, pr_num: int) -> List[str]:
+    def commits_resolving_gh_pr(self, pr_num: int) -> list[str]:
         return ["FakeCommitSha"]

     def commit_message(self, ref: str) -> str:

View File

@ -17,21 +17,12 @@ import re
import time import time
import urllib.parse import urllib.parse
from collections import defaultdict from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass from dataclasses import dataclass
from functools import lru_cache from functools import cache
from pathlib import Path from pathlib import Path
from typing import ( from re import Pattern
Any, from typing import Any, Callable, cast, NamedTuple, Optional
Callable,
cast,
Dict,
Iterable,
List,
NamedTuple,
Optional,
Pattern,
Tuple,
)
from warnings import warn from warnings import warn
import yaml import yaml
@ -78,7 +69,7 @@ class JobCheckState(NamedTuple):
summary: Optional[str] summary: Optional[str]
JobNameToStateDict = Dict[str, JobCheckState] JobNameToStateDict = dict[str, JobCheckState]
class WorkflowCheckState: class WorkflowCheckState:
@ -468,10 +459,10 @@ def gh_get_pr_info(org: str, proj: str, pr_no: int) -> Any:
return rc["data"]["repository"]["pullRequest"] return rc["data"]["repository"]["pullRequest"]
@lru_cache(maxsize=None) @cache
def gh_get_team_members(org: str, name: str) -> List[str]: def gh_get_team_members(org: str, name: str) -> list[str]:
rc: List[str] = [] rc: list[str] = []
team_members: Dict[str, Any] = { team_members: dict[str, Any] = {
"pageInfo": {"hasNextPage": "true", "endCursor": None} "pageInfo": {"hasNextPage": "true", "endCursor": None}
} }
while bool(team_members["pageInfo"]["hasNextPage"]): while bool(team_members["pageInfo"]["hasNextPage"]):
@ -503,14 +494,14 @@ def is_passing_status(status: Optional[str]) -> bool:
def add_workflow_conclusions( def add_workflow_conclusions(
checksuites: Any, checksuites: Any,
get_next_checkruns_page: Callable[[List[Dict[str, Dict[str, Any]]], int, Any], Any], get_next_checkruns_page: Callable[[list[dict[str, dict[str, Any]]], int, Any], Any],
get_next_checksuites: Callable[[Any], Any], get_next_checksuites: Callable[[Any], Any],
) -> JobNameToStateDict: ) -> JobNameToStateDict:
# graphql seems to favor the most recent workflow run, so in theory we # graphql seems to favor the most recent workflow run, so in theory we
# shouldn't need to account for reruns, but do it just in case # shouldn't need to account for reruns, but do it just in case
# workflow -> job -> job info # workflow -> job -> job info
workflows: Dict[str, WorkflowCheckState] = {} workflows: dict[str, WorkflowCheckState] = {}
# for the jobs that don't have a workflow # for the jobs that don't have a workflow
no_workflow_obj: WorkflowCheckState = WorkflowCheckState("", "", 0, None) no_workflow_obj: WorkflowCheckState = WorkflowCheckState("", "", 0, None)
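The Callable parameters above now nest builtin generics. typing.Callable itself also has a PEP 585 successor, collections.abc.Callable, which this patch leaves in place; both spellings are accepted on 3.9+. A sketch of the successor form, names invented:

from collections.abc import Callable
from typing import Any

def fetch_first_page(
    get_next_page: Callable[[list[dict[str, Any]], int], Any],
    edges: list[dict[str, Any]],
) -> Any:
    # collections.abc.Callable subscripts exactly like typing.Callable
    return get_next_page(edges, 0)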
@ -633,8 +624,8 @@ def _revlist_to_prs(
pr: "GitHubPR", pr: "GitHubPR",
rev_list: Iterable[str], rev_list: Iterable[str],
should_skip: Optional[Callable[[int, "GitHubPR"], bool]] = None, should_skip: Optional[Callable[[int, "GitHubPR"], bool]] = None,
) -> List[Tuple["GitHubPR", str]]: ) -> list[tuple["GitHubPR", str]]:
rc: List[Tuple[GitHubPR, str]] = [] rc: list[tuple[GitHubPR, str]] = []
for idx, rev in enumerate(rev_list): for idx, rev in enumerate(rev_list):
msg = repo.commit_message(rev) msg = repo.commit_message(rev)
m = RE_PULL_REQUEST_RESOLVED.search(msg) m = RE_PULL_REQUEST_RESOLVED.search(msg)
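The return type list[tuple["GitHubPR", str]] mixes a builtin generic with a quoted forward reference; the combination is fine because the string is never evaluated at runtime, only resolved by type checkers. A self-contained sketch with an invented class:

class Node:
    # "Node" is still being defined here, so it must stay a string;
    # list[...] and tuple[...] accept string arguments without evaluating them
    def children(self) -> list[tuple["Node", str]]:
        return []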
@ -656,7 +647,7 @@ def _revlist_to_prs(
def get_ghstack_prs( def get_ghstack_prs(
repo: GitRepo, pr: "GitHubPR", open_only: bool = True repo: GitRepo, pr: "GitHubPR", open_only: bool = True
) -> List[Tuple["GitHubPR", str]]: ) -> list[tuple["GitHubPR", str]]:
""" """
Get the PRs in the stack that are below this PR (inclusive). Throws an error if any of the open PRs are out of sync. Get the PRs in the stack that are below this PR (inclusive). Throws an error if any of the open PRs are out of sync.
@:param open_only: Only return open PRs @:param open_only: Only return open PRs
@ -701,14 +692,14 @@ class GitHubPR:
self.project = project self.project = project
self.pr_num = pr_num self.pr_num = pr_num
self.info = gh_get_pr_info(org, project, pr_num) self.info = gh_get_pr_info(org, project, pr_num)
self.changed_files: Optional[List[str]] = None self.changed_files: Optional[list[str]] = None
self.labels: Optional[List[str]] = None self.labels: Optional[list[str]] = None
self.conclusions: Optional[JobNameToStateDict] = None self.conclusions: Optional[JobNameToStateDict] = None
self.comments: Optional[List[GitHubComment]] = None self.comments: Optional[list[GitHubComment]] = None
self._authors: Optional[List[Tuple[str, str]]] = None self._authors: Optional[list[tuple[str, str]]] = None
self._reviews: Optional[List[Tuple[str, str]]] = None self._reviews: Optional[list[tuple[str, str]]] = None
self.merge_base: Optional[str] = None self.merge_base: Optional[str] = None
self.submodules: Optional[List[str]] = None self.submodules: Optional[list[str]] = None
def is_closed(self) -> bool: def is_closed(self) -> bool:
return bool(self.info["closed"]) return bool(self.info["closed"])
@ -763,7 +754,7 @@ class GitHubPR:
return self.merge_base return self.merge_base
def get_changed_files(self) -> List[str]: def get_changed_files(self) -> list[str]:
if self.changed_files is None: if self.changed_files is None:
info = self.info info = self.info
unique_changed_files = set() unique_changed_files = set()
@ -786,14 +777,14 @@ class GitHubPR:
raise RuntimeError("Changed file count mismatch") raise RuntimeError("Changed file count mismatch")
return self.changed_files return self.changed_files
def get_submodules(self) -> List[str]: def get_submodules(self) -> list[str]:
if self.submodules is None: if self.submodules is None:
rc = gh_graphql(GH_GET_REPO_SUBMODULES, name=self.project, owner=self.org) rc = gh_graphql(GH_GET_REPO_SUBMODULES, name=self.project, owner=self.org)
info = rc["data"]["repository"]["submodules"] info = rc["data"]["repository"]["submodules"]
self.submodules = [s["path"] for s in info["nodes"]] self.submodules = [s["path"] for s in info["nodes"]]
return self.submodules return self.submodules
def get_changed_submodules(self) -> List[str]: def get_changed_submodules(self) -> list[str]:
submodules = self.get_submodules() submodules = self.get_submodules()
return [f for f in self.get_changed_files() if f in submodules] return [f for f in self.get_changed_files() if f in submodules]
@ -809,7 +800,7 @@ class GitHubPR:
and all("submodule" not in label for label in self.get_labels()) and all("submodule" not in label for label in self.get_labels())
) )
def _get_reviews(self) -> List[Tuple[str, str]]: def _get_reviews(self) -> list[tuple[str, str]]:
if self._reviews is None: if self._reviews is None:
self._reviews = [] self._reviews = []
info = self.info info = self.info
@ -834,7 +825,7 @@ class GitHubPR:
reviews[author] = state reviews[author] = state
return list(reviews.items()) return list(reviews.items())
def get_approved_by(self) -> List[str]: def get_approved_by(self) -> list[str]:
return [login for (login, state) in self._get_reviews() if state == "APPROVED"] return [login for (login, state) in self._get_reviews() if state == "APPROVED"]
def get_commit_count(self) -> int: def get_commit_count(self) -> int:
@ -843,12 +834,12 @@ class GitHubPR:
def get_pr_creator_login(self) -> str: def get_pr_creator_login(self) -> str:
return cast(str, self.info["author"]["login"]) return cast(str, self.info["author"]["login"])
def _fetch_authors(self) -> List[Tuple[str, str]]: def _fetch_authors(self) -> list[tuple[str, str]]:
if self._authors is not None: if self._authors is not None:
return self._authors return self._authors
authors: List[Tuple[str, str]] = [] authors: list[tuple[str, str]] = []
def add_authors(info: Dict[str, Any]) -> None: def add_authors(info: dict[str, Any]) -> None:
for node in info["commits_with_authors"]["nodes"]: for node in info["commits_with_authors"]["nodes"]:
for author_node in node["commit"]["authors"]["nodes"]: for author_node in node["commit"]["authors"]["nodes"]:
user_node = author_node["user"] user_node = author_node["user"]
@ -881,7 +872,7 @@ class GitHubPR:
def get_committer_author(self, num: int = 0) -> str: def get_committer_author(self, num: int = 0) -> str:
return self._fetch_authors()[num][1] return self._fetch_authors()[num][1]
def get_labels(self) -> List[str]: def get_labels(self) -> list[str]:
if self.labels is not None: if self.labels is not None:
return self.labels return self.labels
labels = ( labels = (
@ -899,7 +890,7 @@ class GitHubPR:
orig_last_commit = self.last_commit() orig_last_commit = self.last_commit()
def get_pr_next_check_runs( def get_pr_next_check_runs(
edges: List[Dict[str, Dict[str, Any]]], edge_idx: int, checkruns: Any edges: list[dict[str, dict[str, Any]]], edge_idx: int, checkruns: Any
) -> Any: ) -> Any:
rc = gh_graphql( rc = gh_graphql(
GH_GET_PR_NEXT_CHECK_RUNS, GH_GET_PR_NEXT_CHECK_RUNS,
@ -951,7 +942,7 @@ class GitHubPR:
return self.conclusions return self.conclusions
def get_authors(self) -> Dict[str, str]: def get_authors(self) -> dict[str, str]:
rc = {} rc = {}
for idx in range(len(self._fetch_authors())): for idx in range(len(self._fetch_authors())):
rc[self.get_committer_login(idx)] = self.get_committer_author(idx) rc[self.get_committer_login(idx)] = self.get_committer_author(idx)
@ -995,7 +986,7 @@ class GitHubPR:
url=node["url"], url=node["url"],
) )
def get_comments(self) -> List[GitHubComment]: def get_comments(self) -> list[GitHubComment]:
if self.comments is not None: if self.comments is not None:
return self.comments return self.comments
self.comments = [] self.comments = []
@ -1069,7 +1060,7 @@ class GitHubPR:
skip_mandatory_checks: bool, skip_mandatory_checks: bool,
comment_id: Optional[int] = None, comment_id: Optional[int] = None,
skip_all_rule_checks: bool = False, skip_all_rule_checks: bool = False,
) -> List["GitHubPR"]: ) -> list["GitHubPR"]:
assert self.is_ghstack_pr() assert self.is_ghstack_pr()
ghstack_prs = get_ghstack_prs( ghstack_prs = get_ghstack_prs(
repo, self, open_only=False repo, self, open_only=False
@ -1099,7 +1090,7 @@ class GitHubPR:
def gen_commit_message( def gen_commit_message(
self, self,
filter_ghstack: bool = False, filter_ghstack: bool = False,
ghstack_deps: Optional[List["GitHubPR"]] = None, ghstack_deps: Optional[list["GitHubPR"]] = None,
) -> str: ) -> str:
"""Fetches title and body from PR description """Fetches title and body from PR description
adds reviewed by, pull request resolved and optionally adds reviewed by, pull request resolved and optionally
@ -1151,7 +1142,7 @@ class GitHubPR:
skip_mandatory_checks: bool = False, skip_mandatory_checks: bool = False,
dry_run: bool = False, dry_run: bool = False,
comment_id: Optional[int] = None, comment_id: Optional[int] = None,
ignore_current_checks: Optional[List[str]] = None, ignore_current_checks: Optional[list[str]] = None,
) -> None: ) -> None:
# Raises exception if matching rule is not found # Raises exception if matching rule is not found
( (
@ -1223,7 +1214,7 @@ class GitHubPR:
comment_id: Optional[int] = None, comment_id: Optional[int] = None,
branch: Optional[str] = None, branch: Optional[str] = None,
skip_all_rule_checks: bool = False, skip_all_rule_checks: bool = False,
) -> List["GitHubPR"]: ) -> list["GitHubPR"]:
""" """
:param skip_all_rule_checks: If true, skips all rule checks, useful for dry-running merge locally :param skip_all_rule_checks: If true, skips all rule checks, useful for dry-running merge locally
""" """
@ -1263,14 +1254,14 @@ class PostCommentError(Exception):
@dataclass @dataclass
class MergeRule: class MergeRule:
name: str name: str
patterns: List[str] patterns: list[str]
approved_by: List[str] approved_by: list[str]
mandatory_checks_name: Optional[List[str]] mandatory_checks_name: Optional[list[str]]
ignore_flaky_failures: bool = True ignore_flaky_failures: bool = True
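MergeRule's fields switch to list[str] and Optional[list[str]]; dataclass annotations are inspected when the class is created, so the builtin forms must be valid there too. A trimmed sketch of the same shape:

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class Rule:
    name: str
    patterns: list[str] = field(default_factory=list)
    mandatory_checks_name: Optional[list[str]] = None

print(Rule(name="superuser", patterns=["*"]))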
def gen_new_issue_link( def gen_new_issue_link(
org: str, project: str, labels: List[str], template: str = "bug-report.yml" org: str, project: str, labels: list[str], template: str = "bug-report.yml"
) -> str: ) -> str:
labels_str = ",".join(labels) labels_str = ",".join(labels)
return ( return (
@ -1282,7 +1273,7 @@ def gen_new_issue_link(
def read_merge_rules( def read_merge_rules(
repo: Optional[GitRepo], org: str, project: str repo: Optional[GitRepo], org: str, project: str
) -> List[MergeRule]: ) -> list[MergeRule]:
"""Returns the list of all merge rules for the repo or project. """Returns the list of all merge rules for the repo or project.
NB: this function is used in Meta-internal workflows, see the comment NB: this function is used in Meta-internal workflows, see the comment
@ -1312,12 +1303,12 @@ def find_matching_merge_rule(
repo: Optional[GitRepo] = None, repo: Optional[GitRepo] = None,
skip_mandatory_checks: bool = False, skip_mandatory_checks: bool = False,
skip_internal_checks: bool = False, skip_internal_checks: bool = False,
ignore_current_checks: Optional[List[str]] = None, ignore_current_checks: Optional[list[str]] = None,
) -> Tuple[ ) -> tuple[
MergeRule, MergeRule,
List[Tuple[str, Optional[str], Optional[int]]], list[tuple[str, Optional[str], Optional[int]]],
List[Tuple[str, Optional[str], Optional[int]]], list[tuple[str, Optional[str], Optional[int]]],
Dict[str, List[Any]], dict[str, list[Any]],
]: ]:
""" """
Returns the merge rule matching this PR together with the list of associated pending Returns the merge rule matching this PR together with the list of associated pending
@ -1504,13 +1495,13 @@ def find_matching_merge_rule(
raise MergeRuleFailedError(reject_reason, rule) raise MergeRuleFailedError(reject_reason, rule)
def checks_to_str(checks: List[Tuple[str, Optional[str]]]) -> str: def checks_to_str(checks: list[tuple[str, Optional[str]]]) -> str:
return ", ".join(f"[{c[0]}]({c[1]})" if c[1] is not None else c[0] for c in checks) return ", ".join(f"[{c[0]}]({c[1]})" if c[1] is not None else c[0] for c in checks)
def checks_to_markdown_bullets( def checks_to_markdown_bullets(
checks: List[Tuple[str, Optional[str], Optional[int]]], checks: list[tuple[str, Optional[str], Optional[int]]],
) -> List[str]: ) -> list[str]:
return [ return [
f"- [{c[0]}]({c[1]})" if c[1] is not None else f"- {c[0]}" for c in checks[:5] f"- [{c[0]}]({c[1]})" if c[1] is not None else f"- {c[0]}" for c in checks[:5]
] ]
@ -1518,7 +1509,7 @@ def checks_to_markdown_bullets(
def manually_close_merged_pr( def manually_close_merged_pr(
pr: GitHubPR, pr: GitHubPR,
additional_merged_prs: List[GitHubPR], additional_merged_prs: list[GitHubPR],
merge_commit_sha: str, merge_commit_sha: str,
dry_run: bool, dry_run: bool,
) -> None: ) -> None:
@ -1551,12 +1542,12 @@ def save_merge_record(
owner: str, owner: str,
project: str, project: str,
author: str, author: str,
pending_checks: List[Tuple[str, Optional[str], Optional[int]]], pending_checks: list[tuple[str, Optional[str], Optional[int]]],
failed_checks: List[Tuple[str, Optional[str], Optional[int]]], failed_checks: list[tuple[str, Optional[str], Optional[int]]],
ignore_current_checks: List[Tuple[str, Optional[str], Optional[int]]], ignore_current_checks: list[tuple[str, Optional[str], Optional[int]]],
broken_trunk_checks: List[Tuple[str, Optional[str], Optional[int]]], broken_trunk_checks: list[tuple[str, Optional[str], Optional[int]]],
flaky_checks: List[Tuple[str, Optional[str], Optional[int]]], flaky_checks: list[tuple[str, Optional[str], Optional[int]]],
unstable_checks: List[Tuple[str, Optional[str], Optional[int]]], unstable_checks: list[tuple[str, Optional[str], Optional[int]]],
last_commit_sha: str, last_commit_sha: str,
merge_base_sha: str, merge_base_sha: str,
merge_commit_sha: str = "", merge_commit_sha: str = "",
@ -1714,9 +1705,9 @@ def is_invalid_cancel(
def get_classifications( def get_classifications(
pr_num: int, pr_num: int,
project: str, project: str,
checks: Dict[str, JobCheckState], checks: dict[str, JobCheckState],
ignore_current_checks: Optional[List[str]], ignore_current_checks: Optional[list[str]],
) -> Dict[str, JobCheckState]: ) -> dict[str, JobCheckState]:
# Get the failure classification from Dr.CI, which is the source of truth # Get the failure classification from Dr.CI, which is the source of truth
# going forward. It's preferable to try calling Dr.CI API directly first # going forward. It's preferable to try calling Dr.CI API directly first
# to get the latest results as well as update Dr.CI PR comment # to get the latest results as well as update Dr.CI PR comment
@ -1825,7 +1816,7 @@ def get_classifications(
def filter_checks_with_lambda( def filter_checks_with_lambda(
checks: JobNameToStateDict, status_filter: Callable[[Optional[str]], bool] checks: JobNameToStateDict, status_filter: Callable[[Optional[str]], bool]
) -> List[JobCheckState]: ) -> list[JobCheckState]:
return [check for check in checks.values() if status_filter(check.status)] return [check for check in checks.values() if status_filter(check.status)]
@ -1841,7 +1832,7 @@ def get_pr_commit_sha(repo: GitRepo, pr: GitHubPR) -> str:
def validate_revert( def validate_revert(
repo: GitRepo, pr: GitHubPR, *, comment_id: Optional[int] = None repo: GitRepo, pr: GitHubPR, *, comment_id: Optional[int] = None
) -> Tuple[str, str]: ) -> tuple[str, str]:
comment = ( comment = (
pr.get_last_comment() pr.get_last_comment()
if comment_id is None if comment_id is None
@ -1871,7 +1862,7 @@ def validate_revert(
def get_ghstack_dependent_prs( def get_ghstack_dependent_prs(
repo: GitRepo, pr: GitHubPR, only_closed: bool = True repo: GitRepo, pr: GitHubPR, only_closed: bool = True
) -> List[Tuple[str, GitHubPR]]: ) -> list[tuple[str, GitHubPR]]:
""" """
Get the PRs in the stack that are above this PR (inclusive). Get the PRs in the stack that are above this PR (inclusive).
Throws an error if the stack has branched or the original branches are gone Throws an error if the stack has branched or the original branches are gone
@ -1897,7 +1888,7 @@ def get_ghstack_dependent_prs(
# Remove commits original PR depends on # Remove commits original PR depends on
if skip_len > 0: if skip_len > 0:
rev_list = rev_list[:-skip_len] rev_list = rev_list[:-skip_len]
rc: List[Tuple[str, GitHubPR]] = [] rc: list[tuple[str, GitHubPR]] = []
for pr_, sha in _revlist_to_prs(repo, pr, rev_list): for pr_, sha in _revlist_to_prs(repo, pr, rev_list):
if not pr_.is_closed(): if not pr_.is_closed():
if not only_closed: if not only_closed:
@ -1910,7 +1901,7 @@ def get_ghstack_dependent_prs(
def do_revert_prs( def do_revert_prs(
repo: GitRepo, repo: GitRepo,
shas_and_prs: List[Tuple[str, GitHubPR]], shas_and_prs: list[tuple[str, GitHubPR]],
*, *,
author_login: str, author_login: str,
extra_msg: str = "", extra_msg: str = "",
@ -2001,7 +1992,7 @@ def check_for_sev(org: str, project: str, skip_mandatory_checks: bool) -> None:
if skip_mandatory_checks: if skip_mandatory_checks:
return return
response = cast( response = cast(
Dict[str, Any], dict[str, Any],
gh_fetch_json_list( gh_fetch_json_list(
"https://api.github.com/search/issues", "https://api.github.com/search/issues",
# Having two label: queries is an AND operation # Having two label: queries is an AND operation
@ -2019,29 +2010,29 @@ def check_for_sev(org: str, project: str, skip_mandatory_checks: bool) -> None:
return return
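typing.cast takes the builtin generic directly, as in the check_for_sev hunk above; cast performs no runtime conversion, it only tells the type checker what to assume. A sketch with an invented payload:

from typing import Any, cast

def first_issue(payload: Any) -> dict[str, Any]:
    response = cast(dict[str, Any], payload)  # no runtime effect
    return response["items"][0]

print(first_issue({"items": [{"number": 1}]}))  # {'number': 1}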
def has_label(labels: List[str], pattern: Pattern[str] = CIFLOW_LABEL) -> bool: def has_label(labels: list[str], pattern: Pattern[str] = CIFLOW_LABEL) -> bool:
return len(list(filter(pattern.match, labels))) > 0 return len(list(filter(pattern.match, labels))) > 0
def categorize_checks( def categorize_checks(
check_runs: JobNameToStateDict, check_runs: JobNameToStateDict,
required_checks: List[str], required_checks: list[str],
ok_failed_checks_threshold: Optional[int] = None, ok_failed_checks_threshold: Optional[int] = None,
) -> Tuple[ ) -> tuple[
List[Tuple[str, Optional[str], Optional[int]]], list[tuple[str, Optional[str], Optional[int]]],
List[Tuple[str, Optional[str], Optional[int]]], list[tuple[str, Optional[str], Optional[int]]],
Dict[str, List[Any]], dict[str, list[Any]],
]: ]:
""" """
Categorizes all jobs into lists of pending and failing jobs. All known flaky Categorizes all jobs into lists of pending and failing jobs. All known flaky
failures and broken trunk failures are ignored by default when ok_failed_checks_threshold failures and broken trunk failures are ignored by default when ok_failed_checks_threshold
is not set (unlimited) is not set (unlimited)
""" """
pending_checks: List[Tuple[str, Optional[str], Optional[int]]] = [] pending_checks: list[tuple[str, Optional[str], Optional[int]]] = []
failed_checks: List[Tuple[str, Optional[str], Optional[int]]] = [] failed_checks: list[tuple[str, Optional[str], Optional[int]]] = []
# failed_checks_categorization is used to keep track of all ignorable failures when saving the merge record on s3 # failed_checks_categorization is used to keep track of all ignorable failures when saving the merge record on s3
failed_checks_categorization: Dict[str, List[Any]] = defaultdict(list) failed_checks_categorization: dict[str, list[Any]] = defaultdict(list)
# If required_checks is not set or empty, consider all names are relevant # If required_checks is not set or empty, consider all names are relevant
relevant_checknames = [ relevant_checknames = [


@ -1,6 +1,7 @@
import os import os
import re import re
from typing import List, Optional, Pattern, Tuple
from re import Pattern
from typing import Optional
BOT_COMMANDS_WIKI = "https://github.com/pytorch/pytorch/wiki/Bot-commands" BOT_COMMANDS_WIKI = "https://github.com/pytorch/pytorch/wiki/Bot-commands"
@ -13,13 +14,13 @@ CONTACT_US = f"Questions? Feedback? Please reach out to the [PyTorch DevX Team](
ALTERNATIVES = f"Learn more about merging in the [wiki]({BOT_COMMANDS_WIKI})." ALTERNATIVES = f"Learn more about merging in the [wiki]({BOT_COMMANDS_WIKI})."
def has_label(labels: List[str], pattern: Pattern[str] = CIFLOW_LABEL) -> bool: def has_label(labels: list[str], pattern: Pattern[str] = CIFLOW_LABEL) -> bool:
return len(list(filter(pattern.match, labels))) > 0 return len(list(filter(pattern.match, labels))) > 0
class TryMergeExplainer: class TryMergeExplainer:
force: bool force: bool
labels: List[str] labels: list[str]
pr_num: int pr_num: int
org: str org: str
project: str project: str
@ -31,7 +32,7 @@ class TryMergeExplainer:
def __init__( def __init__(
self, self,
force: bool, force: bool,
labels: List[str], labels: list[str],
pr_num: int, pr_num: int,
org: str, org: str,
project: str, project: str,
@ -47,7 +48,7 @@ class TryMergeExplainer:
def _get_flag_msg( def _get_flag_msg(
self, self,
ignore_current_checks: Optional[ ignore_current_checks: Optional[
List[Tuple[str, Optional[str], Optional[int]]] list[tuple[str, Optional[str], Optional[int]]]
] = None, ] = None,
) -> str: ) -> str:
if self.force: if self.force:
@ -68,7 +69,7 @@ class TryMergeExplainer:
def get_merge_message( def get_merge_message(
self, self,
ignore_current_checks: Optional[ ignore_current_checks: Optional[
List[Tuple[str, Optional[str], Optional[int]]] list[tuple[str, Optional[str], Optional[int]]]
] = None, ] = None,
) -> str: ) -> str:
title = "### Merge started" title = "### Merge started"

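Throughout the explainer, typing.Optional stays while the containers inside it go lowercase; PEP 585 generics and the remaining typing constructs compose freely. A sketch borrowing the signature shape above, with an invented body:

from typing import Optional

def flag_msg(
    ignore_current_checks: Optional[list[tuple[str, Optional[str], Optional[int]]]] = None,
) -> str:
    return f"ignoring {len(ignore_current_checks or [])} currently failing checks"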

@ -5,7 +5,8 @@ import os
import re import re
import subprocess import subprocess
import sys import sys
from typing import Any, Generator
from collections.abc import Generator
from typing import Any
from github_utils import gh_post_pr_comment as gh_post_comment from github_utils import gh_post_pr_comment as gh_post_comment
from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo
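Generator is one of the aliases PEP 585 deprecates in typing; the live version is collections.abc.Generator, parameterized by the yield, send, and return types. A sketch with an invented helper:

from collections.abc import Generator

def pr_numbers(messages: list[str]) -> Generator[int, None, None]:
    for msg in messages:
        if msg.startswith("Pull Request resolved: #"):
            yield int(msg.rsplit("#", 1)[1])

print(list(pr_numbers(["Pull Request resolved: #145707"])))  # [145707]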


@ -129,9 +129,10 @@ jobs:
import re import re
import sys import sys
from argparse import ArgumentParser from argparse import ArgumentParser
from functools import lru_cache
from logging import LogRecord
from typing import Any, Dict, FrozenSet, Iterable, List, NamedTuple, Set, Tuple
from collections.abc import Iterable
from functools import cache
from logging import LogRecord
from typing import Any, NamedTuple
from urllib.request import Request, urlopen from urllib.request import Request, urlopen
import yaml import yaml
@ -173,7 +174,7 @@ jobs:
Settings for the experiments that can be opted into. Settings for the experiments that can be opted into.
""" """
experiments: Dict[str, Experiment] = {} experiments: dict[str, Experiment] = {}
class ColorFormatter(logging.Formatter): class ColorFormatter(logging.Formatter):
@ -218,7 +219,7 @@ jobs:
f.write(f"{key}={value}\n") f.write(f"{key}={value}\n")
def _str_comma_separated_to_set(value: str) -> FrozenSet[str]: def _str_comma_separated_to_set(value: str) -> frozenset[str]:
return frozenset( return frozenset(
filter(lambda itm: itm != "", map(str.strip, value.strip(" \n\t").split(","))) filter(lambda itm: itm != "", map(str.strip, value.strip(" \n\t").split(",")))
) )
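FrozenSet[str] becomes frozenset[str]; as with the other builtins, the class itself now carries the annotation. A compact sketch of the same comma-splitting pattern, helper name invented:

def to_set(value: str) -> frozenset[str]:
    # filter(None, ...) drops the empty strings left by stray commas
    return frozenset(filter(None, (item.strip() for item in value.split(","))))

assert to_set("lf, canary ,") == frozenset({"lf", "canary"})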
@ -276,12 +277,12 @@ jobs:
return parser.parse_args() return parser.parse_args()
def get_gh_client(github_token: str) -> Github: def get_gh_client(github_token: str) -> Github: # type: ignore[no-any-unimported]
auth = Auth.Token(github_token) auth = Auth.Token(github_token)
return Github(auth=auth) return Github(auth=auth)
def get_issue(gh: Github, repo: str, issue_num: int) -> Issue: def get_issue(gh: Github, repo: str, issue_num: int) -> Issue: # type: ignore[no-any-unimported]
repo = gh.get_repo(repo) repo = gh.get_repo(repo)
return repo.get_issue(number=issue_num) return repo.get_issue(number=issue_num)
@ -310,7 +311,7 @@ jobs:
raise Exception( # noqa: TRY002 raise Exception( # noqa: TRY002
f"issue with pull request {pr_number} from repo {repository}" f"issue with pull request {pr_number} from repo {repository}"
) from e ) from e
return pull.user.login return pull.user.login # type: ignore[no-any-return]
# In all other cases, return the original input username # In all other cases, return the original input username
return username return username
@ -331,7 +332,7 @@ jobs:
raise raise
def extract_settings_user_opt_in_from_text(rollout_state: str) -> Tuple[str, str]: def extract_settings_user_opt_in_from_text(rollout_state: str) -> tuple[str, str]:
""" """
Extracts the text with settings, if any, and the opted in users from the rollout state. Extracts the text with settings, if any, and the opted in users from the rollout state.
@ -347,7 +348,7 @@ jobs:
return "", rollout_state return "", rollout_state
class UserOptins(Dict[str, List[str]]): class UserOptins(dict[str, list[str]]):
""" """
Dictionary of users with a list of features they have opted into Dictionary of users with a list of features they have opted into
""" """
@ -488,7 +489,7 @@ jobs:
rollout_state: str, rollout_state: str,
workflow_requestors: Iterable[str], workflow_requestors: Iterable[str],
branch: str, branch: str,
eligible_experiments: FrozenSet[str] = frozenset(), eligible_experiments: frozenset[str] = frozenset(),
is_canary: bool = False, is_canary: bool = False,
) -> str: ) -> str:
settings = parse_settings(rollout_state) settings = parse_settings(rollout_state)
@ -587,7 +588,7 @@ jobs:
return str(issue.get_comments()[0].body.strip("\n\t ")) return str(issue.get_comments()[0].body.strip("\n\t "))
def download_json(url: str, headers: Dict[str, str], num_retries: int = 3) -> Any: def download_json(url: str, headers: dict[str, str], num_retries: int = 3) -> Any:
for _ in range(num_retries): for _ in range(num_retries):
try: try:
req = Request(url=url, headers=headers) req = Request(url=url, headers=headers)
@ -600,8 +601,8 @@ jobs:
return {} return {}
@lru_cache(maxsize=None) @cache
def get_pr_info(github_repo: str, github_token: str, pr_number: int) -> Dict[str, Any]: def get_pr_info(github_repo: str, github_token: str, pr_number: int) -> dict[str, Any]:
""" """
Dynamically get PR information Dynamically get PR information
""" """
@ -610,7 +611,7 @@ jobs:
"Accept": "application/vnd.github.v3+json", "Accept": "application/vnd.github.v3+json",
"Authorization": f"token {github_token}", "Authorization": f"token {github_token}",
} }
json_response: Dict[str, Any] = download_json( json_response: dict[str, Any] = download_json(
url=f"{github_api}/issues/{pr_number}", url=f"{github_api}/issues/{pr_number}",
headers=headers, headers=headers,
) )
@ -622,7 +623,7 @@ jobs:
return json_response return json_response
def get_labels(github_repo: str, github_token: str, pr_number: int) -> Set[str]: def get_labels(github_repo: str, github_token: str, pr_number: int) -> set[str]:
""" """
Dynamically get the latest list of labels from the pull request Dynamically get the latest list of labels from the pull request
""" """