PEP585: .github (#145707)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145707
Approved by: https://github.com/huydhn
Aaron Orenstein 2025-01-26 11:50:47 -08:00 committed by PyTorch MergeBot
parent bfaf76bfc6
commit 60f98262f1
22 changed files with 265 additions and 280 deletions

View File

@ -3,7 +3,7 @@
import json
import os
import re
from typing import Any, cast, Dict, List, Optional
from typing import Any, cast, Optional
from urllib.error import HTTPError
from github_utils import gh_fetch_url, gh_post_pr_comment, gh_query_issues_by_labels
@ -67,7 +67,7 @@ def get_release_version(onto_branch: str) -> Optional[str]:
def get_tracker_issues(
org: str, project: str, onto_branch: str
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""
Find the tracker issue from the repo. The tracker issue needs to have a title
like [VERSION] Release Tracker following the convention on PyTorch
@ -117,7 +117,7 @@ def cherry_pick(
continue
res = cast(
Dict[str, Any],
dict[str, Any],
post_tracker_issue_comment(
org,
project,
@ -220,7 +220,7 @@ def submit_pr(
def post_pr_comment(
org: str, project: str, pr_num: int, msg: str, dry_run: bool = False
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""
Post a comment on the PR itself to point to the cherry-pick PR on success,
or print the error on failure
@ -255,7 +255,7 @@ def post_tracker_issue_comment(
classification: str,
fixes: str,
dry_run: bool = False,
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""
Post a comment on the tracker issue (if any) to record the cherry pick
"""

View File

@ -6,7 +6,7 @@ import re
import sys
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Tuple
from typing import Any
import requests
from gitutils import retries_decorator
@ -76,7 +76,7 @@ DISABLED_TESTS_JSON = (
@retries_decorator()
def query_db(query: str, params: Dict[str, Any]) -> List[Dict[str, Any]]:
def query_db(query: str, params: dict[str, Any]) -> list[dict[str, Any]]:
return query_clickhouse(query, params)
@ -97,7 +97,7 @@ def download_log_worker(temp_dir: str, id: int, name: str) -> None:
f.write(data)
def printer(item: Tuple[str, Tuple[int, str, List[Any]]], extra: str) -> None:
def printer(item: tuple[str, tuple[int, str, list[Any]]], extra: str) -> None:
test, (_, link, _) = item
print(f"{link:<55} {test:<120} {extra}")
@ -120,8 +120,8 @@ def close_issue(num: int) -> None:
def check_if_exists(
item: Tuple[str, Tuple[int, str, List[str]]], all_logs: List[str]
) -> Tuple[bool, str]:
item: tuple[str, tuple[int, str, list[str]]], all_logs: list[str]
) -> tuple[bool, str]:
test, (_, link, _) = item
# Test names should look like `test_a (module.path.classname)`
reg = re.match(r"(\S+) \((\S*)\)", test)
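The rewrite nests naturally; a sketch with invented values showing the tuple[str, tuple[int, str, list[str]]] shape used above:

item: tuple[str, tuple[int, str, list[str]]] = ("test_a", (1, "https://example.com", ["log"]))
test, (_, link, _) = item  # unpacks exactly as it did with typing.Tuple
print(f"{link:<55} {test}")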

View File

@ -2,7 +2,7 @@
import sys
from pathlib import Path
from typing import Any, cast, Dict, List, Set
from typing import Any, cast
import yaml
@ -10,9 +10,9 @@ import yaml
GITHUB_DIR = Path(__file__).parent.parent
def get_workflows_push_tags() -> Set[str]:
def get_workflows_push_tags() -> set[str]:
"Extract all known push tags from workflows"
rc: Set[str] = set()
rc: set[str] = set()
for fname in (GITHUB_DIR / "workflows").glob("*.yml"):
with fname.open("r") as f:
wf_yml = yaml.safe_load(f)
@ -25,19 +25,19 @@ def get_workflows_push_tags() -> Set[str]:
return rc
def filter_ciflow_tags(tags: Set[str]) -> List[str]:
def filter_ciflow_tags(tags: set[str]) -> list[str]:
"Return sorted list of ciflow tags"
return sorted(
tag[:-2] for tag in tags if tag.startswith("ciflow/") and tag.endswith("/*")
)
def read_probot_config() -> Dict[str, Any]:
def read_probot_config() -> dict[str, Any]:
with (GITHUB_DIR / "pytorch-probot.yml").open("r") as f:
return cast(Dict[str, Any], yaml.safe_load(f))
return cast(dict[str, Any], yaml.safe_load(f))
def update_probot_config(labels: Set[str]) -> None:
def update_probot_config(labels: set[str]) -> None:
orig = read_probot_config()
orig["ciflow_push_tags"] = filter_ciflow_tags(labels)
with (GITHUB_DIR / "pytorch-probot.yml").open("w") as f:
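Note that cast takes the builtin generic directly: on 3.9+ dict[str, Any] is a real runtime object (a types.GenericAlias), so no string quoting is needed. A standalone sketch:

from typing import Any, cast

raw: Any = {"ciflow_push_tags": []}
config = cast(dict[str, Any], raw)  # no-op at runtime; informs the type checker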

View File

@ -4,7 +4,7 @@ import re
from datetime import datetime
from functools import lru_cache
from pathlib import Path
from typing import Any, Callable, Dict, List, Set
from typing import Any, Callable
from github_utils import gh_fetch_json_dict, gh_graphql
from gitutils import GitRepo
@ -112,7 +112,7 @@ def convert_gh_timestamp(date: str) -> float:
return datetime.strptime(date, "%Y-%m-%dT%H:%M:%SZ").timestamp()
def get_branches(repo: GitRepo) -> Dict[str, Any]:
def get_branches(repo: GitRepo) -> dict[str, Any]:
# Query locally for branches, group by branch base name (e.g. gh/blah/base -> gh/blah), and get the most recent branch
git_response = repo._run_git(
"for-each-ref",
@ -120,7 +120,7 @@ def get_branches(repo: GitRepo) -> Dict[str, Any]:
"--format=%(refname) %(committerdate:iso-strict)",
"refs/remotes/origin",
)
branches_by_base_name: Dict[str, Any] = {}
branches_by_base_name: dict[str, Any] = {}
for line in git_response.splitlines():
branch, date = line.split(" ")
re_branch = re.match(r"refs/remotes/origin/(.*)", branch)
@ -140,14 +140,14 @@ def get_branches(repo: GitRepo) -> Dict[str, Any]:
def paginate_graphql(
query: str,
kwargs: Dict[str, Any],
termination_func: Callable[[List[Dict[str, Any]]], bool],
get_data: Callable[[Dict[str, Any]], List[Dict[str, Any]]],
get_page_info: Callable[[Dict[str, Any]], Dict[str, Any]],
) -> List[Any]:
kwargs: dict[str, Any],
termination_func: Callable[[list[dict[str, Any]]], bool],
get_data: Callable[[dict[str, Any]], list[dict[str, Any]]],
get_page_info: Callable[[dict[str, Any]], dict[str, Any]],
) -> list[Any]:
hasNextPage = True
endCursor = None
data: List[Dict[str, Any]] = []
data: list[dict[str, Any]] = []
while hasNextPage:
ESTIMATED_TOKENS[0] += 1
res = gh_graphql(query, cursor=endCursor, **kwargs)
@ -159,11 +159,11 @@ def paginate_graphql(
return data
def get_recent_prs() -> Dict[str, Any]:
def get_recent_prs() -> dict[str, Any]:
now = datetime.now().timestamp()
# Grab all PRs updated in last CLOSED_PR_RETENTION days
pr_infos: List[Dict[str, Any]] = paginate_graphql(
pr_infos: list[dict[str, Any]] = paginate_graphql(
GRAPHQL_ALL_PRS_BY_UPDATED_AT,
{"owner": "pytorch", "repo": "pytorch"},
lambda data: (
@ -190,7 +190,7 @@ def get_recent_prs() -> Dict[str, Any]:
@lru_cache(maxsize=1)
def get_open_prs() -> List[Dict[str, Any]]:
def get_open_prs() -> list[dict[str, Any]]:
return paginate_graphql(
GRAPHQL_OPEN_PRS,
{"owner": "pytorch", "repo": "pytorch"},
@ -200,8 +200,8 @@ def get_open_prs() -> List[Dict[str, Any]]:
)
def get_branches_with_magic_label_or_open_pr() -> Set[str]:
pr_infos: List[Dict[str, Any]] = paginate_graphql(
def get_branches_with_magic_label_or_open_pr() -> set[str]:
pr_infos: list[dict[str, Any]] = paginate_graphql(
GRAPHQL_NO_DELETE_BRANCH_LABEL,
{"owner": "pytorch", "repo": "pytorch"},
lambda data: False,
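One nuance: Callable has no builtin spelling, so it stays on the typing import here; PEP 585 also makes collections.abc.Callable subscriptable on 3.9+, which would be the equivalent alternative. A sketch of the callback signatures above (the body is elided):

from collections.abc import Callable
from typing import Any

def paginate(
    get_data: Callable[[dict[str, Any]], list[dict[str, Any]]],
    termination_func: Callable[[list[dict[str, Any]]], bool],
) -> list[Any]:
    data: list[dict[str, Any]] = []
    # ... fetch pages and extend data until termination_func(data) is true ...
    return data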

View File

@ -2,7 +2,7 @@ import json
import re
import shutil
from pathlib import Path
from typing import Any, List
from typing import Any
import boto3 # type: ignore[import]
@ -77,7 +77,7 @@ def upload_file_to_s3(file_name: Path, bucket: str, key: str) -> None:
def download_s3_objects_with_prefix(
bucket_name: str, prefix: str, download_folder: Path
) -> List[Path]:
) -> list[Path]:
s3 = boto3.resource("s3")
bucket = s3.Bucket(bucket_name)

View File

@ -8,9 +8,9 @@ import subprocess
import sys
import warnings
from enum import Enum
from functools import lru_cache
from functools import cache
from logging import info
from typing import Any, Callable, Dict, List, Optional, Set
from typing import Any, Callable, Optional
from urllib.request import Request, urlopen
import yaml
@ -32,7 +32,7 @@ def is_cuda_or_rocm_job(job_name: Optional[str]) -> bool:
# Supported modes when running periodically. A mode is applied only when
# its lambda condition returns true
SUPPORTED_PERIODICAL_MODES: Dict[str, Callable[[Optional[str]], bool]] = {
SUPPORTED_PERIODICAL_MODES: dict[str, Callable[[Optional[str]], bool]] = {
# Memory leak check is only needed for CUDA and ROCm jobs which utilize GPU memory
"mem_leak_check": is_cuda_or_rocm_job,
"rerun_disabled_tests": lambda job_name: True,
@ -102,8 +102,8 @@ def parse_args() -> Any:
return parser.parse_args()
@lru_cache(maxsize=None)
def get_pr_info(pr_number: int) -> Dict[str, Any]:
@cache
def get_pr_info(pr_number: int) -> dict[str, Any]:
"""
Dynamically get PR information
"""
@ -116,7 +116,7 @@ def get_pr_info(pr_number: int) -> Dict[str, Any]:
"Accept": "application/vnd.github.v3+json",
"Authorization": f"token {github_token}",
}
json_response: Dict[str, Any] = download_json(
json_response: dict[str, Any] = download_json(
url=f"{pytorch_github_api}/issues/{pr_number}",
headers=headers,
)
@ -128,7 +128,7 @@ def get_pr_info(pr_number: int) -> Dict[str, Any]:
return json_response
def get_labels(pr_number: int) -> Set[str]:
def get_labels(pr_number: int) -> set[str]:
"""
Dynamically get the latest list of labels from the pull request
"""
@ -138,14 +138,14 @@ def get_labels(pr_number: int) -> Set[str]:
}
def filter_labels(labels: Set[str], label_regex: Any) -> Set[str]:
def filter_labels(labels: set[str], label_regex: Any) -> set[str]:
"""
Return the list of matching labels
"""
return {l for l in labels if re.match(label_regex, l)}
def filter(test_matrix: Dict[str, List[Any]], labels: Set[str]) -> Dict[str, List[Any]]:
def filter(test_matrix: dict[str, list[Any]], labels: set[str]) -> dict[str, list[Any]]:
"""
Select the list of test configs to run from the test matrix. The logic works
as follows:
@ -157,7 +157,7 @@ def filter(test_matrix: Dict[str, List[Any]], labels: Set[str]) -> Dict[str, Lis
If the PR has none of the test-config label, all tests are run as usual.
"""
filtered_test_matrix: Dict[str, List[Any]] = {"include": []}
filtered_test_matrix: dict[str, list[Any]] = {"include": []}
for entry in test_matrix.get("include", []):
config_name = entry.get("config", "")
@ -185,8 +185,8 @@ def filter(test_matrix: Dict[str, List[Any]], labels: Set[str]) -> Dict[str, Lis
def filter_selected_test_configs(
test_matrix: Dict[str, List[Any]], selected_test_configs: Set[str]
) -> Dict[str, List[Any]]:
test_matrix: dict[str, list[Any]], selected_test_configs: set[str]
) -> dict[str, list[Any]]:
"""
Keep only the selected configs if the list is not empty. Otherwise, keep all test configs.
This filter is used when the workflow is dispatched manually.
@ -194,7 +194,7 @@ def filter_selected_test_configs(
if not selected_test_configs:
return test_matrix
filtered_test_matrix: Dict[str, List[Any]] = {"include": []}
filtered_test_matrix: dict[str, list[Any]] = {"include": []}
for entry in test_matrix.get("include", []):
config_name = entry.get("config", "")
if not config_name:
@ -207,12 +207,12 @@ def filter_selected_test_configs(
def set_periodic_modes(
test_matrix: Dict[str, List[Any]], job_name: Optional[str]
) -> Dict[str, List[Any]]:
test_matrix: dict[str, list[Any]], job_name: Optional[str]
) -> dict[str, list[Any]]:
"""
Apply all periodic modes when running under a schedule
"""
scheduled_test_matrix: Dict[str, List[Any]] = {
scheduled_test_matrix: dict[str, list[Any]] = {
"include": [],
}
@ -229,8 +229,8 @@ def set_periodic_modes(
def mark_unstable_jobs(
workflow: str, job_name: str, test_matrix: Dict[str, List[Any]]
) -> Dict[str, List[Any]]:
workflow: str, job_name: str, test_matrix: dict[str, list[Any]]
) -> dict[str, list[Any]]:
"""
Check the list of unstable jobs and mark them accordingly. Note that if a job
is unstable, all its dependents will also be marked accordingly
@ -245,8 +245,8 @@ def mark_unstable_jobs(
def remove_disabled_jobs(
workflow: str, job_name: str, test_matrix: Dict[str, List[Any]]
) -> Dict[str, List[Any]]:
workflow: str, job_name: str, test_matrix: dict[str, list[Any]]
) -> dict[str, list[Any]]:
"""
Check the list of disabled jobs, and remove the current job and all its dependents
if the job appears in the list
@ -261,15 +261,15 @@ def remove_disabled_jobs(
def _filter_jobs(
test_matrix: Dict[str, List[Any]],
test_matrix: dict[str, list[Any]],
issue_type: IssueType,
target_cfg: Optional[str] = None,
) -> Dict[str, List[Any]]:
) -> dict[str, list[Any]]:
"""
A utility function used to actually apply the job filter
"""
# The result will be stored here
filtered_test_matrix: Dict[str, List[Any]] = {"include": []}
filtered_test_matrix: dict[str, list[Any]] = {"include": []}
# This is an issue to disable a CI job
if issue_type == IssueType.DISABLED:
@ -302,10 +302,10 @@ def _filter_jobs(
def process_jobs(
workflow: str,
job_name: str,
test_matrix: Dict[str, List[Any]],
test_matrix: dict[str, list[Any]],
issue_type: IssueType,
url: str,
) -> Dict[str, List[Any]]:
) -> dict[str, list[Any]]:
"""
Both disabled and unstable jobs are in the following format:
@ -441,7 +441,7 @@ def process_jobs(
return test_matrix
def download_json(url: str, headers: Dict[str, str], num_retries: int = 3) -> Any:
def download_json(url: str, headers: dict[str, str], num_retries: int = 3) -> Any:
for _ in range(num_retries):
try:
req = Request(url=url, headers=headers)
@ -462,7 +462,7 @@ def set_output(name: str, val: Any) -> None:
print(f"::set-output name={name}::{val}")
def parse_reenabled_issues(s: Optional[str]) -> List[str]:
def parse_reenabled_issues(s: Optional[str]) -> list[str]:
# NB: When the PR body is empty, GitHub API returns a None value, which is
# passed into this function
if not s:
@ -477,7 +477,7 @@ def parse_reenabled_issues(s: Optional[str]) -> List[str]:
return issue_numbers
def get_reenabled_issues(pr_body: str = "") -> List[str]:
def get_reenabled_issues(pr_body: str = "") -> list[str]:
default_branch = f"origin/{os.environ.get('GIT_DEFAULT_BRANCH', 'main')}"
try:
commit_messages = subprocess.check_output(
@ -489,12 +489,12 @@ def get_reenabled_issues(pr_body: str = "") -> List[str]:
return parse_reenabled_issues(pr_body) + parse_reenabled_issues(commit_messages)
def check_for_setting(labels: Set[str], body: str, setting: str) -> bool:
def check_for_setting(labels: set[str], body: str, setting: str) -> bool:
return setting in labels or f"[{setting}]" in body
def perform_misc_tasks(
labels: Set[str], test_matrix: Dict[str, List[Any]], job_name: str, pr_body: str
labels: set[str], test_matrix: dict[str, list[Any]], job_name: str, pr_body: str
) -> None:
"""
In addition to applying the filter logic, the script also does the following

View File

@ -12,7 +12,7 @@ architectures:
"""
import os
from typing import Dict, List, Optional, Tuple
from typing import Optional
# NOTE: Also update the CUDA sources in tools/nightly.py when changing this list
@ -181,7 +181,7 @@ CXX11_ABI = "cxx11-abi"
RELEASE = "release"
DEBUG = "debug"
LIBTORCH_CONTAINER_IMAGES: Dict[Tuple[str, str], str] = {
LIBTORCH_CONTAINER_IMAGES: dict[tuple[str, str], str] = {
**{
(
gpu_arch,
@ -223,16 +223,16 @@ def translate_desired_cuda(gpu_arch_type: str, gpu_arch_version: str) -> str:
}.get(gpu_arch_type, gpu_arch_version)
def list_without(in_list: List[str], without: List[str]) -> List[str]:
def list_without(in_list: list[str], without: list[str]) -> list[str]:
return [item for item in in_list if item not in without]
def generate_libtorch_matrix(
os: str,
abi_version: str,
arches: Optional[List[str]] = None,
libtorch_variants: Optional[List[str]] = None,
) -> List[Dict[str, str]]:
arches: Optional[list[str]] = None,
libtorch_variants: Optional[list[str]] = None,
) -> list[dict[str, str]]:
if arches is None:
arches = ["cpu"]
if os == "linux":
@ -248,7 +248,7 @@ def generate_libtorch_matrix(
"static-without-deps",
]
ret: List[Dict[str, str]] = []
ret: list[dict[str, str]] = []
for arch_version in arches:
for libtorch_variant in libtorch_variants:
# one of the values in the following list must be exactly
@ -287,10 +287,10 @@ def generate_libtorch_matrix(
def generate_wheels_matrix(
os: str,
arches: Optional[List[str]] = None,
python_versions: Optional[List[str]] = None,
arches: Optional[list[str]] = None,
python_versions: Optional[list[str]] = None,
use_split_build: bool = False,
) -> List[Dict[str, str]]:
) -> list[dict[str, str]]:
package_type = "wheel"
if os == "linux" or os == "linux-aarch64" or os == "linux-s390x":
# NOTE: We only build manywheel packages for x86_64 and aarch64 and s390x linux
@ -315,7 +315,7 @@ def generate_wheels_matrix(
# uses different build/test scripts
arches = ["cpu-s390x"]
ret: List[Dict[str, str]] = []
ret: list[dict[str, str]] = []
for python_version in python_versions:
for arch_version in arches:
gpu_arch_type = arch_type(arch_version)
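Container types compose the same way with builtin generics; a sketch of a dict[tuple[str, str], str] lookup table in the style of LIBTORCH_CONTAINER_IMAGES above (entries are made up):

IMAGES: dict[tuple[str, str], str] = {
    ("cuda", "12.4"): "example.io/libtorch:cuda12.4",  # hypothetical image names
    ("cpu", ""): "example.io/libtorch:cpu",
}
print(IMAGES[("cpu", "")])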

View File

@ -2,9 +2,10 @@
import os
import sys
from collections.abc import Iterable
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Dict, Iterable, List, Literal, Set
from typing import Literal
from typing_extensions import TypedDict # Python 3.11+
import generate_binary_build_matrix # type: ignore[import]
@ -27,7 +28,7 @@ LABEL_CIFLOW_BINARIES_WHEEL = "ciflow/binaries_wheel"
class CIFlowConfig:
# For use to enable workflows to run on pytorch/pytorch-canary
run_on_canary: bool = False
labels: Set[str] = field(default_factory=set)
labels: set[str] = field(default_factory=set)
# Certain jobs might not want to be part of the ciflow/[all,trunk] workflow
isolated_workflow: bool = False
unstable: bool = False
@ -48,7 +49,7 @@ class Config(TypedDict):
@dataclass
class BinaryBuildWorkflow:
os: str
build_configs: List[Dict[str, str]]
build_configs: list[dict[str, str]]
package_type: str
# Optional fields
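Builtin generics also work as dataclass field annotations; a mutable default such as set still needs field(default_factory=...), as CIFlowConfig shows above. A self-contained sketch:

from dataclasses import dataclass, field

@dataclass
class FlowConfig:  # illustrative stand-in for CIFlowConfig
    run_on_canary: bool = False
    labels: set[str] = field(default_factory=set)

cfg = FlowConfig(labels={"ciflow/trunk"})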

View File

@ -11,11 +11,11 @@ import sys
import time
import urllib
import urllib.parse
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Optional
from urllib.request import Request, urlopen
def parse_json_and_links(conn: Any) -> Tuple[Any, Dict[str, Dict[str, str]]]:
def parse_json_and_links(conn: Any) -> tuple[Any, dict[str, dict[str, str]]]:
links = {}
# Extract links which GH uses for pagination
# see https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Link
@ -42,7 +42,7 @@ def parse_json_and_links(conn: Any) -> Tuple[Any, Dict[str, Dict[str, str]]]:
def fetch_url(
url: str,
*,
headers: Optional[Dict[str, str]] = None,
headers: Optional[dict[str, str]] = None,
reader: Callable[[Any], Any] = lambda x: x.read(),
retries: Optional[int] = 3,
backoff_timeout: float = 0.5,
@ -83,7 +83,7 @@ def parse_args() -> Any:
return parser.parse_args()
def fetch_jobs(url: str, headers: Dict[str, str]) -> List[Dict[str, str]]:
def fetch_jobs(url: str, headers: dict[str, str]) -> list[dict[str, str]]:
response, links = fetch_url(url, headers=headers, reader=parse_json_and_links)
jobs = response["jobs"]
assert type(jobs) is list
@ -111,7 +111,7 @@ def fetch_jobs(url: str, headers: Dict[str, str]) -> List[Dict[str, str]]:
# running.
def find_job_id_name(args: Any) -> Tuple[str, str]:
def find_job_id_name(args: Any) -> tuple[str, str]:
# From https://docs.github.com/en/actions/learn-github-actions/environment-variables
PYTORCH_REPO = os.environ.get("GITHUB_REPOSITORY", "pytorch/pytorch")
PYTORCH_GITHUB_API = f"https://api.github.com/repos/{PYTORCH_REPO}"

View File

@ -4,7 +4,7 @@ import json
import os
import warnings
from dataclasses import dataclass
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, cast, Optional, Union
from urllib.error import HTTPError
from urllib.parse import quote
from urllib.request import Request, urlopen
@ -27,11 +27,11 @@ class GitHubComment:
def gh_fetch_url_and_headers(
url: str,
*,
headers: Optional[Dict[str, str]] = None,
data: Union[Optional[Dict[str, Any]], str] = None,
headers: Optional[dict[str, str]] = None,
data: Union[Optional[dict[str, Any]], str] = None,
method: Optional[str] = None,
reader: Callable[[Any], Any] = lambda x: x.read(),
) -> Tuple[Any, Any]:
) -> tuple[Any, Any]:
if headers is None:
headers = {}
token = os.environ.get("GITHUB_TOKEN")
@ -70,8 +70,8 @@ def gh_fetch_url_and_headers(
def gh_fetch_url(
url: str,
*,
headers: Optional[Dict[str, str]] = None,
data: Union[Optional[Dict[str, Any]], str] = None,
headers: Optional[dict[str, str]] = None,
data: Union[Optional[dict[str, Any]], str] = None,
method: Optional[str] = None,
reader: Callable[[Any], Any] = json.load,
) -> Any:
@ -82,25 +82,25 @@ def gh_fetch_url(
def gh_fetch_json(
url: str,
params: Optional[Dict[str, Any]] = None,
data: Optional[Dict[str, Any]] = None,
params: Optional[dict[str, Any]] = None,
data: Optional[dict[str, Any]] = None,
method: Optional[str] = None,
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
headers = {"Accept": "application/vnd.github.v3+json"}
if params is not None and len(params) > 0:
url += "?" + "&".join(
f"{name}={quote(str(val))}" for name, val in params.items()
)
return cast(
List[Dict[str, Any]],
list[dict[str, Any]],
gh_fetch_url(url, headers=headers, data=data, reader=json.load, method=method),
)
def _gh_fetch_json_any(
url: str,
params: Optional[Dict[str, Any]] = None,
data: Optional[Dict[str, Any]] = None,
params: Optional[dict[str, Any]] = None,
data: Optional[dict[str, Any]] = None,
) -> Any:
headers = {"Accept": "application/vnd.github.v3+json"}
if params is not None and len(params) > 0:
@ -112,21 +112,21 @@ def _gh_fetch_json_any(
def gh_fetch_json_list(
url: str,
params: Optional[Dict[str, Any]] = None,
data: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
return cast(List[Dict[str, Any]], _gh_fetch_json_any(url, params, data))
params: Optional[dict[str, Any]] = None,
data: Optional[dict[str, Any]] = None,
) -> list[dict[str, Any]]:
return cast(list[dict[str, Any]], _gh_fetch_json_any(url, params, data))
def gh_fetch_json_dict(
url: str,
params: Optional[Dict[str, Any]] = None,
data: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
return cast(Dict[str, Any], _gh_fetch_json_any(url, params, data))
params: Optional[dict[str, Any]] = None,
data: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
return cast(dict[str, Any], _gh_fetch_json_any(url, params, data))
def gh_graphql(query: str, **kwargs: Any) -> Dict[str, Any]:
def gh_graphql(query: str, **kwargs: Any) -> dict[str, Any]:
rc = gh_fetch_url(
"https://api.github.com/graphql",
data={"query": query, "variables": kwargs},
@ -136,12 +136,12 @@ def gh_graphql(query: str, **kwargs: Any) -> Dict[str, Any]:
raise RuntimeError(
f"GraphQL query {query}, args {kwargs} failed: {rc['errors']}"
)
return cast(Dict[str, Any], rc)
return cast(dict[str, Any], rc)
def _gh_post_comment(
url: str, comment: str, dry_run: bool = False
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
if dry_run:
print(comment)
return []
@ -150,7 +150,7 @@ def _gh_post_comment(
def gh_post_pr_comment(
org: str, repo: str, pr_num: int, comment: str, dry_run: bool = False
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
return _gh_post_comment(
f"{GITHUB_API_URL}/repos/{org}/{repo}/issues/{pr_num}/comments",
comment,
@ -160,7 +160,7 @@ def gh_post_pr_comment(
def gh_post_commit_comment(
org: str, repo: str, sha: str, comment: str, dry_run: bool = False
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
return _gh_post_comment(
f"{GITHUB_API_URL}/repos/{org}/{repo}/commits/{sha}/comments",
comment,
@ -220,8 +220,8 @@ def gh_update_pr_state(org: str, repo: str, pr_num: int, state: str = "open") ->
def gh_query_issues_by_labels(
org: str, repo: str, labels: List[str], state: str = "open"
) -> List[Dict[str, Any]]:
org: str, repo: str, labels: list[str], state: str = "open"
) -> list[dict[str, Any]]:
url = f"{GITHUB_API_URL}/repos/{org}/{repo}/issues"
return gh_fetch_json(
url, method="GET", params={"labels": ",".join(labels), "state": state}
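Optional and Union have no PEP 585 builtin equivalents, which is why they remain imported from typing throughout this commit; the X | Y union syntax of PEP 604 would subsume them, but only on Python 3.10+. A sketch of both spellings (the function is illustrative):

from typing import Any, Optional, Union

def fetch(
    headers: Optional[dict[str, str]] = None,
    data: Union[dict[str, Any], str, None] = None,
) -> Any:
    ...  # on 3.10+ these could be written dict[str, str] | None (PEP 604)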

View File

@ -4,20 +4,10 @@ import os
import re
import tempfile
from collections import defaultdict
from collections.abc import Iterator
from datetime import datetime
from functools import wraps
from typing import (
Any,
Callable,
cast,
Dict,
Iterator,
List,
Optional,
Tuple,
TypeVar,
Union,
)
from typing import Any, Callable, cast, Optional, TypeVar, Union
T = TypeVar("T")
@ -35,17 +25,17 @@ def get_git_repo_dir() -> str:
return os.getenv("GIT_REPO_DIR", str(Path(__file__).resolve().parents[2]))
def fuzzy_list_to_dict(items: List[Tuple[str, str]]) -> Dict[str, List[str]]:
def fuzzy_list_to_dict(items: list[tuple[str, str]]) -> dict[str, list[str]]:
"""
Converts a list to a dict, preserving elements with duplicate keys
"""
rc: Dict[str, List[str]] = defaultdict(list)
rc: dict[str, list[str]] = defaultdict(list)
for key, val in items:
rc[key].append(val)
return dict(rc)
def _check_output(items: List[str], encoding: str = "utf-8") -> str:
def _check_output(items: list[str], encoding: str = "utf-8") -> str:
from subprocess import CalledProcessError, check_output, STDOUT
try:
@ -95,7 +85,7 @@ class GitCommit:
return item in self.body or item in self.title
def parse_fuller_format(lines: Union[str, List[str]]) -> GitCommit:
def parse_fuller_format(lines: Union[str, list[str]]) -> GitCommit:
"""
Expect commit message generated using `--format=fuller --date=unix` format, i.e.:
commit <sha1>
@ -142,13 +132,13 @@ class GitRepo:
print(f"+ git -C {self.repo_dir} {' '.join(args)}")
return _check_output(["git", "-C", self.repo_dir] + list(args))
def revlist(self, revision_range: str) -> List[str]:
def revlist(self, revision_range: str) -> list[str]:
rc = self._run_git("rev-list", revision_range, "--", ".").strip()
return rc.split("\n") if len(rc) > 0 else []
def branches_containing_ref(
self, ref: str, *, include_remote: bool = True
) -> List[str]:
) -> list[str]:
rc = (
self._run_git("branch", "--remote", "--contains", ref)
if include_remote
@ -189,7 +179,7 @@ class GitRepo:
def get_merge_base(self, from_ref: str, to_ref: str) -> str:
return self._run_git("merge-base", from_ref, to_ref).strip()
def patch_id(self, ref: Union[str, List[str]]) -> List[Tuple[str, str]]:
def patch_id(self, ref: Union[str, list[str]]) -> list[tuple[str, str]]:
is_list = isinstance(ref, list)
if is_list:
if len(ref) == 0:
@ -198,9 +188,9 @@ class GitRepo:
rc = _check_output(
["sh", "-c", f"git -C {self.repo_dir} show {ref}|git patch-id --stable"]
).strip()
return [cast(Tuple[str, str], x.split(" ", 1)) for x in rc.split("\n")]
return [cast(tuple[str, str], x.split(" ", 1)) for x in rc.split("\n")]
def commits_resolving_gh_pr(self, pr_num: int) -> List[str]:
def commits_resolving_gh_pr(self, pr_num: int) -> list[str]:
owner, name = self.gh_owner_and_name()
msg = f"Pull Request resolved: https://github.com/{owner}/{name}/pull/{pr_num}"
rc = self._run_git("log", "--format=%H", "--grep", msg).strip()
@ -219,7 +209,7 @@ class GitRepo:
def compute_branch_diffs(
self, from_branch: str, to_branch: str
) -> Tuple[List[str], List[str]]:
) -> tuple[list[str], list[str]]:
"""
Returns the lists of commits that are missing from each other's branch since their merge base.
Might be slow if the merge base between the two branches is pretty far off
@ -311,14 +301,14 @@ class GitRepo:
def remote_url(self) -> str:
return self._run_git("remote", "get-url", self.remote)
def gh_owner_and_name(self) -> Tuple[str, str]:
def gh_owner_and_name(self) -> tuple[str, str]:
url = os.getenv("GIT_REMOTE_URL", None)
if url is None:
url = self.remote_url()
rc = RE_GITHUB_URL_MATCH.match(url)
if rc is None:
raise RuntimeError(f"Unexpected url format {url}")
return cast(Tuple[str, str], rc.groups())
return cast(tuple[str, str], rc.groups())
def commit_message(self, ref: str) -> str:
return self._run_git("log", "-1", "--format=%B", ref)
@ -366,7 +356,7 @@ class PeekableIterator(Iterator[str]):
return rc
def patterns_to_regex(allowed_patterns: List[str]) -> Any:
def patterns_to_regex(allowed_patterns: list[str]) -> Any:
"""
pattern is glob-like, i.e. the only special sequences it has are:
- ? - matches single character
@ -437,7 +427,7 @@ def retries_decorator(
) -> Callable[[Callable[..., T]], Callable[..., T]]:
def decorator(f: Callable[..., T]) -> Callable[..., T]:
@wraps(f)
def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> T:
def wrapper(*args: list[Any], **kwargs: dict[str, Any]) -> T:
for idx in range(num_retries):
try:
return f(*args, **kwargs)
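PEP 585 covers the collections.abc classes as well, so Iterator[str] can be subclassed directly, as PeekableIterator does above; a minimal sketch (the class is invented):

from collections.abc import Iterator

class Repeat(Iterator[str]):  # abc classes are generic at runtime on 3.9+
    def __init__(self, word: str, times: int) -> None:
        self.word, self.times = word, times

    def __next__(self) -> str:
        if self.times <= 0:
            raise StopIteration
        self.times -= 1
        return self.word

print(list(Repeat("hi", 2)))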

View File

@ -2,7 +2,7 @@
import json
from functools import lru_cache
from typing import Any, List, Tuple, TYPE_CHECKING, Union
from typing import Any, TYPE_CHECKING, Union
from github_utils import gh_fetch_url_and_headers, GitHubComment
@ -28,14 +28,14 @@ https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for
"""
def request_for_labels(url: str) -> Tuple[Any, Any]:
def request_for_labels(url: str) -> tuple[Any, Any]:
headers = {"Accept": "application/vnd.github.v3+json"}
return gh_fetch_url_and_headers(
url, headers=headers, reader=lambda x: x.read().decode("utf-8")
)
def update_labels(labels: List[str], info: str) -> None:
def update_labels(labels: list[str], info: str) -> None:
labels_json = json.loads(info)
labels.extend([x["name"] for x in labels_json])
@ -56,10 +56,10 @@ def get_last_page_num_from_header(header: Any) -> int:
@lru_cache
def gh_get_labels(org: str, repo: str) -> List[str]:
def gh_get_labels(org: str, repo: str) -> list[str]:
prefix = f"https://api.github.com/repos/{org}/{repo}/labels?per_page=100"
header, info = request_for_labels(prefix + "&page=1")
labels: List[str] = []
labels: list[str] = []
update_labels(labels, info)
last_page = get_last_page_num_from_header(header)
@ -74,7 +74,7 @@ def gh_get_labels(org: str, repo: str) -> List[str]:
def gh_add_labels(
org: str, repo: str, pr_num: int, labels: Union[str, List[str]], dry_run: bool
org: str, repo: str, pr_num: int, labels: Union[str, list[str]], dry_run: bool
) -> None:
if dry_run:
print(f"Dryrun: Adding labels {labels} to PR {pr_num}")
@ -97,7 +97,7 @@ def gh_remove_label(
)
def get_release_notes_labels(org: str, repo: str) -> List[str]:
def get_release_notes_labels(org: str, repo: str) -> list[str]:
return [
label
for label in gh_get_labels(org, repo)

View File

@ -1,7 +1,7 @@
import hashlib
import os
from pathlib import Path
from typing import Dict, NamedTuple
from typing import NamedTuple
from file_io_utils import (
copy_file,
@ -219,8 +219,8 @@ def _merge_lastfailed_files(source_pytest_cache: Path, dest_pytest_cache: Path)
def _merged_lastfailed_content(
from_lastfailed: Dict[str, bool], to_lastfailed: Dict[str, bool]
) -> Dict[str, bool]:
from_lastfailed: dict[str, bool], to_lastfailed: dict[str, bool]
) -> dict[str, bool]:
"""
The lastfailed files are dictionaries where the key is the test identifier.
Each entry's value appears to always be `true`, but let's not count on that.
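Python 3.9, the release that delivered PEP 585, also added the dict union operator (PEP 584), which expresses exactly this kind of merge; a sketch with made-up test identifiers:

from_lastfailed: dict[str, bool] = {"test_a": True}
to_lastfailed: dict[str, bool] = {"test_b": True}
merged = to_lastfailed | from_lastfailed  # right operand wins on duplicate keys
print(merged)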

View File

@ -61,9 +61,10 @@ import random
import re
import sys
from argparse import ArgumentParser
from functools import lru_cache
from collections.abc import Iterable
from functools import cache
from logging import LogRecord
from typing import Any, Dict, FrozenSet, Iterable, List, NamedTuple, Set, Tuple
from typing import Any, NamedTuple
from urllib.request import Request, urlopen
import yaml
@ -105,7 +106,7 @@ class Settings(NamedTuple):
Settings for the experiments that can be opted into.
"""
experiments: Dict[str, Experiment] = {}
experiments: dict[str, Experiment] = {}
class ColorFormatter(logging.Formatter):
@ -150,7 +151,7 @@ def set_github_output(key: str, value: str) -> None:
f.write(f"{key}={value}\n")
def _str_comma_separated_to_set(value: str) -> FrozenSet[str]:
def _str_comma_separated_to_set(value: str) -> frozenset[str]:
return frozenset(
filter(lambda itm: itm != "", map(str.strip, value.strip(" \n\t").split(",")))
)
@ -208,12 +209,12 @@ def parse_args() -> Any:
return parser.parse_args()
def get_gh_client(github_token: str) -> Github:
def get_gh_client(github_token: str) -> Github: # type: ignore[no-any-unimported]
auth = Auth.Token(github_token)
return Github(auth=auth)
def get_issue(gh: Github, repo: str, issue_num: int) -> Issue:
def get_issue(gh: Github, repo: str, issue_num: int) -> Issue: # type: ignore[no-any-unimported]
repo = gh.get_repo(repo)
return repo.get_issue(number=issue_num)
@ -242,7 +243,7 @@ def get_potential_pr_author(
raise Exception( # noqa: TRY002
f"issue with pull request {pr_number} from repo {repository}"
) from e
return pull.user.login
return pull.user.login # type: ignore[no-any-return]
# In all other cases, return the original input username
return username
@ -263,7 +264,7 @@ def load_yaml(yaml_text: str) -> Any:
raise
def extract_settings_user_opt_in_from_text(rollout_state: str) -> Tuple[str, str]:
def extract_settings_user_opt_in_from_text(rollout_state: str) -> tuple[str, str]:
"""
Extracts the text with settings, if any, and the opted-in users from the rollout state.
@ -279,7 +280,7 @@ def extract_settings_user_opt_in_from_text(rollout_state: str) -> Tuple[str, str
return "", rollout_state
class UserOptins(Dict[str, List[str]]):
class UserOptins(dict[str, list[str]]):
"""
Dictionary of users with a list of features they have opted into
"""
@ -420,7 +421,7 @@ def get_runner_prefix(
rollout_state: str,
workflow_requestors: Iterable[str],
branch: str,
eligible_experiments: FrozenSet[str] = frozenset(),
eligible_experiments: frozenset[str] = frozenset(),
is_canary: bool = False,
) -> str:
settings = parse_settings(rollout_state)
@ -519,7 +520,7 @@ def get_rollout_state_from_issue(github_token: str, repo: str, issue_num: int) -
return str(issue.get_comments()[0].body.strip("\n\t "))
def download_json(url: str, headers: Dict[str, str], num_retries: int = 3) -> Any:
def download_json(url: str, headers: dict[str, str], num_retries: int = 3) -> Any:
for _ in range(num_retries):
try:
req = Request(url=url, headers=headers)
@ -532,8 +533,8 @@ def download_json(url: str, headers: Dict[str, str], num_retries: int = 3) -> An
return {}
@lru_cache(maxsize=None)
def get_pr_info(github_repo: str, github_token: str, pr_number: int) -> Dict[str, Any]:
@cache
def get_pr_info(github_repo: str, github_token: str, pr_number: int) -> dict[str, Any]:
"""
Dynamically get PR information
"""
@ -542,7 +543,7 @@ def get_pr_info(github_repo: str, github_token: str, pr_number: int) -> Dict[str
"Accept": "application/vnd.github.v3+json",
"Authorization": f"token {github_token}",
}
json_response: Dict[str, Any] = download_json(
json_response: dict[str, Any] = download_json(
url=f"{github_api}/issues/{pr_number}",
headers=headers,
)
@ -554,7 +555,7 @@ def get_pr_info(github_repo: str, github_token: str, pr_number: int) -> Dict[str
return json_response
def get_labels(github_repo: str, github_token: str, pr_number: int) -> Set[str]:
def get_labels(github_repo: str, github_token: str, pr_number: int) -> set[str]:
"""
Dynamically get the latest list of labels from the pull request
"""

View File

@ -1,6 +1,5 @@
import argparse
import subprocess
from typing import Dict
import generate_binary_build_matrix
@ -10,7 +9,7 @@ def tag_image(
default_tag: str,
release_version: str,
dry_run: str,
tagged_images: Dict[str, bool],
tagged_images: dict[str, bool],
) -> None:
if image in tagged_images:
return
@ -41,7 +40,7 @@ def main() -> None:
)
options = parser.parse_args()
tagged_images: Dict[str, bool] = {}
tagged_images: dict[str, bool] = {}
platform_images = [
generate_binary_build_matrix.WHEEL_CONTAINER_IMAGES,
generate_binary_build_matrix.LIBTORCH_CONTAINER_IMAGES,

View File

@ -1,6 +1,6 @@
"""test_check_labels.py"""
from typing import Any, List
from typing import Any
from unittest import main, mock, TestCase
from check_labels import (
@ -31,7 +31,7 @@ def mock_delete_all_label_err_comments(pr: "GitHubPR") -> None:
pass
def mock_get_comments() -> List[GitHubComment]:
def mock_get_comments() -> list[GitHubComment]:
return [
# Case 1 - a non label err comment
GitHubComment(

View File

@ -3,7 +3,7 @@
import json
import os
import tempfile
from typing import Any, Dict, List
from typing import Any
from unittest import main, mock, TestCase
import yaml
@ -362,7 +362,7 @@ class TestConfigFilter(TestCase):
self.assertEqual(case["expected"], json.dumps(filtered_test_matrix))
def test_set_periodic_modes(self) -> None:
testcases: List[Dict[str, str]] = [
testcases: list[dict[str, str]] = [
{
"job_name": "a CI job",
"test_matrix": "{include: []}",
@ -702,7 +702,7 @@ class TestConfigFilter(TestCase):
)
mocked_subprocess.return_value = b""
testcases: List[Dict[str, Any]] = [
testcases: list[dict[str, Any]] = [
{
"labels": {},
"test_matrix": '{include: [{config: "default"}]}',

View File

@ -12,7 +12,7 @@ import json
import os
import warnings
from hashlib import sha256
from typing import Any, List, Optional
from typing import Any, Optional
from unittest import main, mock, skip, TestCase
from urllib.error import HTTPError
@ -170,7 +170,7 @@ def mock_gh_get_info() -> Any:
}
def mocked_read_merge_rules_NE(repo: Any, org: str, project: str) -> List[MergeRule]:
def mocked_read_merge_rules_NE(repo: Any, org: str, project: str) -> list[MergeRule]:
return [
MergeRule(
name="mock with nonexistent check",
@ -182,7 +182,7 @@ def mocked_read_merge_rules_NE(repo: Any, org: str, project: str) -> List[MergeR
]
def mocked_read_merge_rules(repo: Any, org: str, project: str) -> List[MergeRule]:
def mocked_read_merge_rules(repo: Any, org: str, project: str) -> list[MergeRule]:
return [
MergeRule(
name="super",
@ -211,7 +211,7 @@ def mocked_read_merge_rules(repo: Any, org: str, project: str) -> List[MergeRule
def mocked_read_merge_rules_approvers(
repo: Any, org: str, project: str
) -> List[MergeRule]:
) -> list[MergeRule]:
return [
MergeRule(
name="Core Reviewers",
@ -234,11 +234,11 @@ def mocked_read_merge_rules_approvers(
]
def mocked_read_merge_rules_raise(repo: Any, org: str, project: str) -> List[MergeRule]:
def mocked_read_merge_rules_raise(repo: Any, org: str, project: str) -> list[MergeRule]:
raise RuntimeError("testing")
def xla_merge_rules(repo: Any, org: str, project: str) -> List[MergeRule]:
def xla_merge_rules(repo: Any, org: str, project: str) -> list[MergeRule]:
return [
MergeRule(
name=" OSS CI / pytorchbot / XLA",
@ -260,7 +260,7 @@ class DummyGitRepo(GitRepo):
def __init__(self) -> None:
super().__init__(get_git_repo_dir(), get_git_remote_name())
def commits_resolving_gh_pr(self, pr_num: int) -> List[str]:
def commits_resolving_gh_pr(self, pr_num: int) -> list[str]:
return ["FakeCommitSha"]
def commit_message(self, ref: str) -> str:

View File

@ -17,21 +17,12 @@ import re
import time
import urllib.parse
from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass
from functools import lru_cache
from functools import cache
from pathlib import Path
from typing import (
Any,
Callable,
cast,
Dict,
Iterable,
List,
NamedTuple,
Optional,
Pattern,
Tuple,
)
from re import Pattern
from typing import Any, Callable, cast, NamedTuple, Optional
from warnings import warn
import yaml
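re.Pattern and re.Match are on the PEP 585 list too, hence the switch from the deprecated typing.Pattern alias to importing Pattern from re. A sketch (the regex is illustrative):

import re
from re import Pattern

CIFLOW: Pattern[str] = re.compile(r"^ciflow/")  # re.Pattern is generic on 3.9+
print(bool(CIFLOW.match("ciflow/trunk")))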
@ -78,7 +69,7 @@ class JobCheckState(NamedTuple):
summary: Optional[str]
JobNameToStateDict = Dict[str, JobCheckState]
JobNameToStateDict = dict[str, JobCheckState]
class WorkflowCheckState:
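The JobNameToStateDict assignment above is a plain (implicit) type alias, and builtin generics work there as well; a trimmed sketch (field names abbreviated from the real NamedTuple):

from typing import Any, NamedTuple

class JobCheckState(NamedTuple):  # trimmed: the real class has more fields
    status: Any
    summary: Any

JobNameToStateDict = dict[str, JobCheckState]  # alias usable in annotations

def pending(checks: JobNameToStateDict) -> list[str]:
    return [name for name, c in checks.items() if c.status is None]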
@ -468,10 +459,10 @@ def gh_get_pr_info(org: str, proj: str, pr_no: int) -> Any:
return rc["data"]["repository"]["pullRequest"]
@lru_cache(maxsize=None)
def gh_get_team_members(org: str, name: str) -> List[str]:
rc: List[str] = []
team_members: Dict[str, Any] = {
@cache
def gh_get_team_members(org: str, name: str) -> list[str]:
rc: list[str] = []
team_members: dict[str, Any] = {
"pageInfo": {"hasNextPage": "true", "endCursor": None}
}
while bool(team_members["pageInfo"]["hasNextPage"]):
@ -503,14 +494,14 @@ def is_passing_status(status: Optional[str]) -> bool:
def add_workflow_conclusions(
checksuites: Any,
get_next_checkruns_page: Callable[[List[Dict[str, Dict[str, Any]]], int, Any], Any],
get_next_checkruns_page: Callable[[list[dict[str, dict[str, Any]]], int, Any], Any],
get_next_checksuites: Callable[[Any], Any],
) -> JobNameToStateDict:
# graphql seems to favor the most recent workflow run, so in theory we
# shouldn't need to account for reruns, but do it just in case
# workflow -> job -> job info
workflows: Dict[str, WorkflowCheckState] = {}
workflows: dict[str, WorkflowCheckState] = {}
# for the jobs that don't have a workflow
no_workflow_obj: WorkflowCheckState = WorkflowCheckState("", "", 0, None)
@ -633,8 +624,8 @@ def _revlist_to_prs(
pr: "GitHubPR",
rev_list: Iterable[str],
should_skip: Optional[Callable[[int, "GitHubPR"], bool]] = None,
) -> List[Tuple["GitHubPR", str]]:
rc: List[Tuple[GitHubPR, str]] = []
) -> list[tuple["GitHubPR", str]]:
rc: list[tuple[GitHubPR, str]] = []
for idx, rev in enumerate(rev_list):
msg = repo.commit_message(rev)
m = RE_PULL_REQUEST_RESOLVED.search(msg)
@ -656,7 +647,7 @@ def _revlist_to_prs(
def get_ghstack_prs(
repo: GitRepo, pr: "GitHubPR", open_only: bool = True
) -> List[Tuple["GitHubPR", str]]:
) -> list[tuple["GitHubPR", str]]:
"""
Get the PRs in the stack that are below this PR (inclusive). Throws an error if any of the open PRs are out of sync.
@:param open_only: Only return open PRs
@ -701,14 +692,14 @@ class GitHubPR:
self.project = project
self.pr_num = pr_num
self.info = gh_get_pr_info(org, project, pr_num)
self.changed_files: Optional[List[str]] = None
self.labels: Optional[List[str]] = None
self.changed_files: Optional[list[str]] = None
self.labels: Optional[list[str]] = None
self.conclusions: Optional[JobNameToStateDict] = None
self.comments: Optional[List[GitHubComment]] = None
self._authors: Optional[List[Tuple[str, str]]] = None
self._reviews: Optional[List[Tuple[str, str]]] = None
self.comments: Optional[list[GitHubComment]] = None
self._authors: Optional[list[tuple[str, str]]] = None
self._reviews: Optional[list[tuple[str, str]]] = None
self.merge_base: Optional[str] = None
self.submodules: Optional[List[str]] = None
self.submodules: Optional[list[str]] = None
def is_closed(self) -> bool:
return bool(self.info["closed"])
@ -763,7 +754,7 @@ class GitHubPR:
return self.merge_base
def get_changed_files(self) -> List[str]:
def get_changed_files(self) -> list[str]:
if self.changed_files is None:
info = self.info
unique_changed_files = set()
@ -786,14 +777,14 @@ class GitHubPR:
raise RuntimeError("Changed file count mismatch")
return self.changed_files
def get_submodules(self) -> List[str]:
def get_submodules(self) -> list[str]:
if self.submodules is None:
rc = gh_graphql(GH_GET_REPO_SUBMODULES, name=self.project, owner=self.org)
info = rc["data"]["repository"]["submodules"]
self.submodules = [s["path"] for s in info["nodes"]]
return self.submodules
def get_changed_submodules(self) -> List[str]:
def get_changed_submodules(self) -> list[str]:
submodules = self.get_submodules()
return [f for f in self.get_changed_files() if f in submodules]
@ -809,7 +800,7 @@ class GitHubPR:
and all("submodule" not in label for label in self.get_labels())
)
def _get_reviews(self) -> List[Tuple[str, str]]:
def _get_reviews(self) -> list[tuple[str, str]]:
if self._reviews is None:
self._reviews = []
info = self.info
@ -834,7 +825,7 @@ class GitHubPR:
reviews[author] = state
return list(reviews.items())
def get_approved_by(self) -> List[str]:
def get_approved_by(self) -> list[str]:
return [login for (login, state) in self._get_reviews() if state == "APPROVED"]
def get_commit_count(self) -> int:
@ -843,12 +834,12 @@ class GitHubPR:
def get_pr_creator_login(self) -> str:
return cast(str, self.info["author"]["login"])
def _fetch_authors(self) -> List[Tuple[str, str]]:
def _fetch_authors(self) -> list[tuple[str, str]]:
if self._authors is not None:
return self._authors
authors: List[Tuple[str, str]] = []
authors: list[tuple[str, str]] = []
def add_authors(info: Dict[str, Any]) -> None:
def add_authors(info: dict[str, Any]) -> None:
for node in info["commits_with_authors"]["nodes"]:
for author_node in node["commit"]["authors"]["nodes"]:
user_node = author_node["user"]
@ -881,7 +872,7 @@ class GitHubPR:
def get_committer_author(self, num: int = 0) -> str:
return self._fetch_authors()[num][1]
def get_labels(self) -> List[str]:
def get_labels(self) -> list[str]:
if self.labels is not None:
return self.labels
labels = (
@ -899,7 +890,7 @@ class GitHubPR:
orig_last_commit = self.last_commit()
def get_pr_next_check_runs(
edges: List[Dict[str, Dict[str, Any]]], edge_idx: int, checkruns: Any
edges: list[dict[str, dict[str, Any]]], edge_idx: int, checkruns: Any
) -> Any:
rc = gh_graphql(
GH_GET_PR_NEXT_CHECK_RUNS,
@ -951,7 +942,7 @@ class GitHubPR:
return self.conclusions
def get_authors(self) -> Dict[str, str]:
def get_authors(self) -> dict[str, str]:
rc = {}
for idx in range(len(self._fetch_authors())):
rc[self.get_committer_login(idx)] = self.get_committer_author(idx)
@ -995,7 +986,7 @@ class GitHubPR:
url=node["url"],
)
def get_comments(self) -> List[GitHubComment]:
def get_comments(self) -> list[GitHubComment]:
if self.comments is not None:
return self.comments
self.comments = []
@ -1069,7 +1060,7 @@ class GitHubPR:
skip_mandatory_checks: bool,
comment_id: Optional[int] = None,
skip_all_rule_checks: bool = False,
) -> List["GitHubPR"]:
) -> list["GitHubPR"]:
assert self.is_ghstack_pr()
ghstack_prs = get_ghstack_prs(
repo, self, open_only=False
@ -1099,7 +1090,7 @@ class GitHubPR:
def gen_commit_message(
self,
filter_ghstack: bool = False,
ghstack_deps: Optional[List["GitHubPR"]] = None,
ghstack_deps: Optional[list["GitHubPR"]] = None,
) -> str:
"""Fetches title and body from PR description
adds reviewed by, pull request resolved and optionally
@ -1151,7 +1142,7 @@ class GitHubPR:
skip_mandatory_checks: bool = False,
dry_run: bool = False,
comment_id: Optional[int] = None,
ignore_current_checks: Optional[List[str]] = None,
ignore_current_checks: Optional[list[str]] = None,
) -> None:
# Raises exception if matching rule is not found
(
@ -1223,7 +1214,7 @@ class GitHubPR:
comment_id: Optional[int] = None,
branch: Optional[str] = None,
skip_all_rule_checks: bool = False,
) -> List["GitHubPR"]:
) -> list["GitHubPR"]:
"""
:param skip_all_rule_checks: If true, skips all rule checks, useful for dry-running merge locally
"""
@ -1263,14 +1254,14 @@ class PostCommentError(Exception):
@dataclass
class MergeRule:
name: str
patterns: List[str]
approved_by: List[str]
mandatory_checks_name: Optional[List[str]]
patterns: list[str]
approved_by: list[str]
mandatory_checks_name: Optional[list[str]]
ignore_flaky_failures: bool = True
def gen_new_issue_link(
org: str, project: str, labels: List[str], template: str = "bug-report.yml"
org: str, project: str, labels: list[str], template: str = "bug-report.yml"
) -> str:
labels_str = ",".join(labels)
return (
@ -1282,7 +1273,7 @@ def gen_new_issue_link(
def read_merge_rules(
repo: Optional[GitRepo], org: str, project: str
) -> List[MergeRule]:
) -> list[MergeRule]:
"""Returns the list of all merge rules for the repo or project.
NB: this function is used in Meta-internal workflows, see the comment
@ -1312,12 +1303,12 @@ def find_matching_merge_rule(
repo: Optional[GitRepo] = None,
skip_mandatory_checks: bool = False,
skip_internal_checks: bool = False,
ignore_current_checks: Optional[List[str]] = None,
) -> Tuple[
ignore_current_checks: Optional[list[str]] = None,
) -> tuple[
MergeRule,
List[Tuple[str, Optional[str], Optional[int]]],
List[Tuple[str, Optional[str], Optional[int]]],
Dict[str, List[Any]],
list[tuple[str, Optional[str], Optional[int]]],
list[tuple[str, Optional[str], Optional[int]]],
dict[str, list[Any]],
]:
"""
Returns the merge rule matching this PR together with the list of associated pending
@ -1504,13 +1495,13 @@ def find_matching_merge_rule(
raise MergeRuleFailedError(reject_reason, rule)
def checks_to_str(checks: List[Tuple[str, Optional[str]]]) -> str:
def checks_to_str(checks: list[tuple[str, Optional[str]]]) -> str:
return ", ".join(f"[{c[0]}]({c[1]})" if c[1] is not None else c[0] for c in checks)
def checks_to_markdown_bullets(
checks: List[Tuple[str, Optional[str], Optional[int]]],
) -> List[str]:
checks: list[tuple[str, Optional[str], Optional[int]]],
) -> list[str]:
return [
f"- [{c[0]}]({c[1]})" if c[1] is not None else f"- {c[0]}" for c in checks[:5]
]
@ -1518,7 +1509,7 @@ def checks_to_markdown_bullets(
def manually_close_merged_pr(
pr: GitHubPR,
additional_merged_prs: List[GitHubPR],
additional_merged_prs: list[GitHubPR],
merge_commit_sha: str,
dry_run: bool,
) -> None:
@ -1551,12 +1542,12 @@ def save_merge_record(
owner: str,
project: str,
author: str,
pending_checks: List[Tuple[str, Optional[str], Optional[int]]],
failed_checks: List[Tuple[str, Optional[str], Optional[int]]],
ignore_current_checks: List[Tuple[str, Optional[str], Optional[int]]],
broken_trunk_checks: List[Tuple[str, Optional[str], Optional[int]]],
flaky_checks: List[Tuple[str, Optional[str], Optional[int]]],
unstable_checks: List[Tuple[str, Optional[str], Optional[int]]],
pending_checks: list[tuple[str, Optional[str], Optional[int]]],
failed_checks: list[tuple[str, Optional[str], Optional[int]]],
ignore_current_checks: list[tuple[str, Optional[str], Optional[int]]],
broken_trunk_checks: list[tuple[str, Optional[str], Optional[int]]],
flaky_checks: list[tuple[str, Optional[str], Optional[int]]],
unstable_checks: list[tuple[str, Optional[str], Optional[int]]],
last_commit_sha: str,
merge_base_sha: str,
merge_commit_sha: str = "",
@ -1714,9 +1705,9 @@ def is_invalid_cancel(
def get_classifications(
pr_num: int,
project: str,
checks: Dict[str, JobCheckState],
ignore_current_checks: Optional[List[str]],
) -> Dict[str, JobCheckState]:
checks: dict[str, JobCheckState],
ignore_current_checks: Optional[list[str]],
) -> dict[str, JobCheckState]:
# Get the failure classification from Dr.CI, which is the source of truth
# going forward. It's preferable to try calling Dr.CI API directly first
# to get the latest results as well as update Dr.CI PR comment
@ -1825,7 +1816,7 @@ def get_classifications(
def filter_checks_with_lambda(
checks: JobNameToStateDict, status_filter: Callable[[Optional[str]], bool]
) -> List[JobCheckState]:
) -> list[JobCheckState]:
return [check for check in checks.values() if status_filter(check.status)]
@ -1841,7 +1832,7 @@ def get_pr_commit_sha(repo: GitRepo, pr: GitHubPR) -> str:
def validate_revert(
repo: GitRepo, pr: GitHubPR, *, comment_id: Optional[int] = None
) -> Tuple[str, str]:
) -> tuple[str, str]:
comment = (
pr.get_last_comment()
if comment_id is None
@ -1871,7 +1862,7 @@ def validate_revert(
def get_ghstack_dependent_prs(
repo: GitRepo, pr: GitHubPR, only_closed: bool = True
) -> List[Tuple[str, GitHubPR]]:
) -> list[tuple[str, GitHubPR]]:
"""
Get the PRs in the stack that are above this PR (inclusive).
Throws an error if the stack has branched or the original branches are gone
@ -1897,7 +1888,7 @@ def get_ghstack_dependent_prs(
# Remove commits original PR depends on
if skip_len > 0:
rev_list = rev_list[:-skip_len]
rc: List[Tuple[str, GitHubPR]] = []
rc: list[tuple[str, GitHubPR]] = []
for pr_, sha in _revlist_to_prs(repo, pr, rev_list):
if not pr_.is_closed():
if not only_closed:
@ -1910,7 +1901,7 @@ def get_ghstack_dependent_prs(
def do_revert_prs(
repo: GitRepo,
shas_and_prs: List[Tuple[str, GitHubPR]],
shas_and_prs: list[tuple[str, GitHubPR]],
*,
author_login: str,
extra_msg: str = "",
@ -2001,7 +1992,7 @@ def check_for_sev(org: str, project: str, skip_mandatory_checks: bool) -> None:
if skip_mandatory_checks:
return
response = cast(
Dict[str, Any],
dict[str, Any],
gh_fetch_json_list(
"https://api.github.com/search/issues",
# Having two label: queries is an AND operation
@ -2019,29 +2010,29 @@ def check_for_sev(org: str, project: str, skip_mandatory_checks: bool) -> None:
return
def has_label(labels: List[str], pattern: Pattern[str] = CIFLOW_LABEL) -> bool:
def has_label(labels: list[str], pattern: Pattern[str] = CIFLOW_LABEL) -> bool:
return len(list(filter(pattern.match, labels))) > 0
def categorize_checks(
check_runs: JobNameToStateDict,
required_checks: List[str],
required_checks: list[str],
ok_failed_checks_threshold: Optional[int] = None,
) -> Tuple[
List[Tuple[str, Optional[str], Optional[int]]],
List[Tuple[str, Optional[str], Optional[int]]],
Dict[str, List[Any]],
) -> tuple[
list[tuple[str, Optional[str], Optional[int]]],
list[tuple[str, Optional[str], Optional[int]]],
dict[str, list[Any]],
]:
"""
Categorizes all jobs into the lists of pending and failing jobs. All known flaky
failures and broken trunk failures are ignored by default when ok_failed_checks_threshold
is not set (unlimited)
"""
pending_checks: List[Tuple[str, Optional[str], Optional[int]]] = []
failed_checks: List[Tuple[str, Optional[str], Optional[int]]] = []
pending_checks: list[tuple[str, Optional[str], Optional[int]]] = []
failed_checks: list[tuple[str, Optional[str], Optional[int]]] = []
# failed_checks_categorization is used to keep track of all ignorable failures when saving the merge record on s3
failed_checks_categorization: Dict[str, List[Any]] = defaultdict(list)
failed_checks_categorization: dict[str, list[Any]] = defaultdict(list)
# If required_checks is not set or empty, consider all names are relevant
relevant_checknames = [

View File

@ -1,6 +1,7 @@
import os
import re
from typing import List, Optional, Pattern, Tuple
from re import Pattern
from typing import Optional
BOT_COMMANDS_WIKI = "https://github.com/pytorch/pytorch/wiki/Bot-commands"
@ -13,13 +14,13 @@ CONTACT_US = f"Questions? Feedback? Please reach out to the [PyTorch DevX Team](
ALTERNATIVES = f"Learn more about merging in the [wiki]({BOT_COMMANDS_WIKI})."
def has_label(labels: List[str], pattern: Pattern[str] = CIFLOW_LABEL) -> bool:
def has_label(labels: list[str], pattern: Pattern[str] = CIFLOW_LABEL) -> bool:
return len(list(filter(pattern.match, labels))) > 0
class TryMergeExplainer:
force: bool
labels: List[str]
labels: list[str]
pr_num: int
org: str
project: str
@ -31,7 +32,7 @@ class TryMergeExplainer:
def __init__(
self,
force: bool,
labels: List[str],
labels: list[str],
pr_num: int,
org: str,
project: str,
@ -47,7 +48,7 @@ class TryMergeExplainer:
def _get_flag_msg(
self,
ignore_current_checks: Optional[
List[Tuple[str, Optional[str], Optional[int]]]
list[tuple[str, Optional[str], Optional[int]]]
] = None,
) -> str:
if self.force:
@ -68,7 +69,7 @@ class TryMergeExplainer:
def get_merge_message(
self,
ignore_current_checks: Optional[
List[Tuple[str, Optional[str], Optional[int]]]
list[tuple[str, Optional[str], Optional[int]]]
] = None,
) -> str:
title = "### Merge started"

View File

@ -5,7 +5,8 @@ import os
import re
import subprocess
import sys
from typing import Any, Generator
from collections.abc import Generator
from typing import Any
from github_utils import gh_post_pr_comment as gh_post_comment
from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo
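Generator moved the same way: the typing.Generator alias is deprecated in favor of collections.abc.Generator, subscriptable on 3.9+. A minimal sketch:

from collections.abc import Generator

def read_lines(text: str) -> Generator[str, None, None]:
    # yield type str, send type None, return type None
    yield from text.splitlines()

print(list(read_lines("a\nb")))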

View File

@ -129,9 +129,10 @@ jobs:
import re
import sys
from argparse import ArgumentParser
from functools import lru_cache
from collections.abc import Iterable
from functools import cache
from logging import LogRecord
from typing import Any, Dict, FrozenSet, Iterable, List, NamedTuple, Set, Tuple
from typing import Any, NamedTuple
from urllib.request import Request, urlopen
import yaml
@ -173,7 +174,7 @@ jobs:
Settings for the experiments that can be opted into.
"""
experiments: Dict[str, Experiment] = {}
experiments: dict[str, Experiment] = {}
class ColorFormatter(logging.Formatter):
@ -218,7 +219,7 @@ jobs:
f.write(f"{key}={value}\n")
def _str_comma_separated_to_set(value: str) -> FrozenSet[str]:
def _str_comma_separated_to_set(value: str) -> frozenset[str]:
return frozenset(
filter(lambda itm: itm != "", map(str.strip, value.strip(" \n\t").split(",")))
)
@ -276,12 +277,12 @@ jobs:
return parser.parse_args()
def get_gh_client(github_token: str) -> Github:
def get_gh_client(github_token: str) -> Github: # type: ignore[no-any-unimported]
auth = Auth.Token(github_token)
return Github(auth=auth)
def get_issue(gh: Github, repo: str, issue_num: int) -> Issue:
def get_issue(gh: Github, repo: str, issue_num: int) -> Issue: # type: ignore[no-any-unimported]
repo = gh.get_repo(repo)
return repo.get_issue(number=issue_num)
@ -310,7 +311,7 @@ jobs:
raise Exception( # noqa: TRY002
f"issue with pull request {pr_number} from repo {repository}"
) from e
return pull.user.login
return pull.user.login # type: ignore[no-any-return]
# In all other cases, return the original input username
return username
@ -331,7 +332,7 @@ jobs:
raise
def extract_settings_user_opt_in_from_text(rollout_state: str) -> Tuple[str, str]:
def extract_settings_user_opt_in_from_text(rollout_state: str) -> tuple[str, str]:
"""
Extracts the text with settings, if any, and the opted-in users from the rollout state.
@ -347,7 +348,7 @@ jobs:
return "", rollout_state
class UserOptins(Dict[str, List[str]]):
class UserOptins(dict[str, list[str]]):
"""
Dictionary of users with a list of features they have opted into
"""
@ -488,7 +489,7 @@ jobs:
rollout_state: str,
workflow_requestors: Iterable[str],
branch: str,
eligible_experiments: FrozenSet[str] = frozenset(),
eligible_experiments: frozenset[str] = frozenset(),
is_canary: bool = False,
) -> str:
settings = parse_settings(rollout_state)
@ -587,7 +588,7 @@ jobs:
return str(issue.get_comments()[0].body.strip("\n\t "))
def download_json(url: str, headers: Dict[str, str], num_retries: int = 3) -> Any:
def download_json(url: str, headers: dict[str, str], num_retries: int = 3) -> Any:
for _ in range(num_retries):
try:
req = Request(url=url, headers=headers)
@ -600,8 +601,8 @@ jobs:
return {}
@lru_cache(maxsize=None)
def get_pr_info(github_repo: str, github_token: str, pr_number: int) -> Dict[str, Any]:
@cache
def get_pr_info(github_repo: str, github_token: str, pr_number: int) -> dict[str, Any]:
"""
Dynamically get PR information
"""
@ -610,7 +611,7 @@ jobs:
"Accept": "application/vnd.github.v3+json",
"Authorization": f"token {github_token}",
}
json_response: Dict[str, Any] = download_json(
json_response: dict[str, Any] = download_json(
url=f"{github_api}/issues/{pr_number}",
headers=headers,
)
@ -622,7 +623,7 @@ jobs:
return json_response
def get_labels(github_repo: str, github_token: str, pr_number: int) -> Set[str]:
def get_labels(github_repo: str, github_token: str, pr_number: int) -> set[str]:
"""
Dynamically get the latest list of labels from the pull request
"""