From 60f98262f17ced1b2da1cddc36f2cbf666a43a12 Mon Sep 17 00:00:00 2001 From: Aaron Orenstein Date: Sun, 26 Jan 2025 11:50:47 -0800 Subject: [PATCH] PEP585: .github (#145707) Pull Request resolved: https://github.com/pytorch/pytorch/pull/145707 Approved by: https://github.com/huydhn --- .github/scripts/cherry_pick.py | 10 +- .../close_nonexistent_disable_issues.py | 10 +- .github/scripts/collect_ciflow_labels.py | 14 +- .github/scripts/delete_old_branches.py | 28 ++-- .github/scripts/file_io_utils.py | 4 +- .github/scripts/filter_test_configs.py | 60 +++---- .../scripts/generate_binary_build_matrix.py | 22 +-- .github/scripts/generate_ci_workflows.py | 7 +- .github/scripts/get_workflow_job_id.py | 10 +- .github/scripts/github_utils.py | 54 +++--- .github/scripts/gitutils.py | 42 ++--- .github/scripts/label_utils.py | 14 +- .github/scripts/pytest_caching_utils.py | 6 +- .github/scripts/runner_determinator.py | 31 ++-- .../scripts/tag_docker_images_for_release.py | 5 +- .github/scripts/test_check_labels.py | 4 +- .github/scripts/test_filter_test_configs.py | 6 +- .github/scripts/test_trymerge.py | 14 +- .github/scripts/trymerge.py | 157 +++++++++--------- .github/scripts/trymerge_explainer.py | 13 +- .github/scripts/tryrebase.py | 3 +- .github/workflows/_runner-determinator.yml | 31 ++-- 22 files changed, 265 insertions(+), 280 deletions(-) diff --git a/.github/scripts/cherry_pick.py b/.github/scripts/cherry_pick.py index 2fecf0bcb63..c2776040d81 100755 --- a/.github/scripts/cherry_pick.py +++ b/.github/scripts/cherry_pick.py @@ -3,7 +3,7 @@ import json import os import re -from typing import Any, cast, Dict, List, Optional +from typing import Any, cast, Optional from urllib.error import HTTPError from github_utils import gh_fetch_url, gh_post_pr_comment, gh_query_issues_by_labels @@ -67,7 +67,7 @@ def get_release_version(onto_branch: str) -> Optional[str]: def get_tracker_issues( org: str, project: str, onto_branch: str -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """ Find the tracker issue from the repo. 
The tracker issue needs to have the title like [VERSION] Release Tracker following the convention on PyTorch @@ -117,7 +117,7 @@ def cherry_pick( continue res = cast( - Dict[str, Any], + dict[str, Any], post_tracker_issue_comment( org, project, @@ -220,7 +220,7 @@ def submit_pr( def post_pr_comment( org: str, project: str, pr_num: int, msg: str, dry_run: bool = False -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """ Post a comment on the PR itself to point to the cherry picking PR when success or print the error when failure @@ -255,7 +255,7 @@ def post_tracker_issue_comment( classification: str, fixes: str, dry_run: bool = False, -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """ Post a comment on the tracker issue (if any) to record the cherry pick """ diff --git a/.github/scripts/close_nonexistent_disable_issues.py b/.github/scripts/close_nonexistent_disable_issues.py index da58078d251..a40e3c851a3 100644 --- a/.github/scripts/close_nonexistent_disable_issues.py +++ b/.github/scripts/close_nonexistent_disable_issues.py @@ -6,7 +6,7 @@ import re import sys import tempfile from pathlib import Path -from typing import Any, Dict, List, Tuple +from typing import Any import requests from gitutils import retries_decorator @@ -76,7 +76,7 @@ DISABLED_TESTS_JSON = ( @retries_decorator() -def query_db(query: str, params: Dict[str, Any]) -> List[Dict[str, Any]]: +def query_db(query: str, params: dict[str, Any]) -> list[dict[str, Any]]: return query_clickhouse(query, params) @@ -97,7 +97,7 @@ def download_log_worker(temp_dir: str, id: int, name: str) -> None: f.write(data) -def printer(item: Tuple[str, Tuple[int, str, List[Any]]], extra: str) -> None: +def printer(item: tuple[str, tuple[int, str, list[Any]]], extra: str) -> None: test, (_, link, _) = item print(f"{link:<55} {test:<120} {extra}") @@ -120,8 +120,8 @@ def close_issue(num: int) -> None: def check_if_exists( - item: Tuple[str, Tuple[int, str, List[str]]], all_logs: List[str] -) -> Tuple[bool, str]: + item: tuple[str, tuple[int, str, list[str]]], all_logs: list[str] +) -> tuple[bool, str]: test, (_, link, _) = item # Test names should look like `test_a (module.path.classname)` reg = re.match(r"(\S+) \((\S*)\)", test) diff --git a/.github/scripts/collect_ciflow_labels.py b/.github/scripts/collect_ciflow_labels.py index 2cd53d14795..920c8a9e524 100755 --- a/.github/scripts/collect_ciflow_labels.py +++ b/.github/scripts/collect_ciflow_labels.py @@ -2,7 +2,7 @@ import sys from pathlib import Path -from typing import Any, cast, Dict, List, Set +from typing import Any, cast import yaml @@ -10,9 +10,9 @@ import yaml GITHUB_DIR = Path(__file__).parent.parent -def get_workflows_push_tags() -> Set[str]: +def get_workflows_push_tags() -> set[str]: "Extract all known push tags from workflows" - rc: Set[str] = set() + rc: set[str] = set() for fname in (GITHUB_DIR / "workflows").glob("*.yml"): with fname.open("r") as f: wf_yml = yaml.safe_load(f) @@ -25,19 +25,19 @@ def get_workflows_push_tags() -> Set[str]: return rc -def filter_ciflow_tags(tags: Set[str]) -> List[str]: +def filter_ciflow_tags(tags: set[str]) -> list[str]: "Return sorted list of ciflow tags" return sorted( tag[:-2] for tag in tags if tag.startswith("ciflow/") and tag.endswith("/*") ) -def read_probot_config() -> Dict[str, Any]: +def read_probot_config() -> dict[str, Any]: with (GITHUB_DIR / "pytorch-probot.yml").open("r") as f: - return cast(Dict[str, Any], yaml.safe_load(f)) + return cast(dict[str, Any], yaml.safe_load(f)) -def update_probot_config(labels: 
Set[str]) -> None: +def update_probot_config(labels: set[str]) -> None: orig = read_probot_config() orig["ciflow_push_tags"] = filter_ciflow_tags(labels) with (GITHUB_DIR / "pytorch-probot.yml").open("w") as f: diff --git a/.github/scripts/delete_old_branches.py b/.github/scripts/delete_old_branches.py index e28d33c642b..b96c3956856 100644 --- a/.github/scripts/delete_old_branches.py +++ b/.github/scripts/delete_old_branches.py @@ -4,7 +4,7 @@ import re from datetime import datetime from functools import lru_cache from pathlib import Path -from typing import Any, Callable, Dict, List, Set +from typing import Any, Callable from github_utils import gh_fetch_json_dict, gh_graphql from gitutils import GitRepo @@ -112,7 +112,7 @@ def convert_gh_timestamp(date: str) -> float: return datetime.strptime(date, "%Y-%m-%dT%H:%M:%SZ").timestamp() -def get_branches(repo: GitRepo) -> Dict[str, Any]: +def get_branches(repo: GitRepo) -> dict[str, Any]: # Query locally for branches, group by branch base name (e.g. gh/blah/base -> gh/blah), and get the most recent branch git_response = repo._run_git( "for-each-ref", @@ -120,7 +120,7 @@ def get_branches(repo: GitRepo) -> Dict[str, Any]: "--format=%(refname) %(committerdate:iso-strict)", "refs/remotes/origin", ) - branches_by_base_name: Dict[str, Any] = {} + branches_by_base_name: dict[str, Any] = {} for line in git_response.splitlines(): branch, date = line.split(" ") re_branch = re.match(r"refs/remotes/origin/(.*)", branch) @@ -140,14 +140,14 @@ def get_branches(repo: GitRepo) -> Dict[str, Any]: def paginate_graphql( query: str, - kwargs: Dict[str, Any], - termination_func: Callable[[List[Dict[str, Any]]], bool], - get_data: Callable[[Dict[str, Any]], List[Dict[str, Any]]], - get_page_info: Callable[[Dict[str, Any]], Dict[str, Any]], -) -> List[Any]: + kwargs: dict[str, Any], + termination_func: Callable[[list[dict[str, Any]]], bool], + get_data: Callable[[dict[str, Any]], list[dict[str, Any]]], + get_page_info: Callable[[dict[str, Any]], dict[str, Any]], +) -> list[Any]: hasNextPage = True endCursor = None - data: List[Dict[str, Any]] = [] + data: list[dict[str, Any]] = [] while hasNextPage: ESTIMATED_TOKENS[0] += 1 res = gh_graphql(query, cursor=endCursor, **kwargs) @@ -159,11 +159,11 @@ def paginate_graphql( return data -def get_recent_prs() -> Dict[str, Any]: +def get_recent_prs() -> dict[str, Any]: now = datetime.now().timestamp() # Grab all PRs updated in last CLOSED_PR_RETENTION days - pr_infos: List[Dict[str, Any]] = paginate_graphql( + pr_infos: list[dict[str, Any]] = paginate_graphql( GRAPHQL_ALL_PRS_BY_UPDATED_AT, {"owner": "pytorch", "repo": "pytorch"}, lambda data: ( @@ -190,7 +190,7 @@ def get_recent_prs() -> Dict[str, Any]: @lru_cache(maxsize=1) -def get_open_prs() -> List[Dict[str, Any]]: +def get_open_prs() -> list[dict[str, Any]]: return paginate_graphql( GRAPHQL_OPEN_PRS, {"owner": "pytorch", "repo": "pytorch"}, @@ -200,8 +200,8 @@ def get_open_prs() -> List[Dict[str, Any]]: ) -def get_branches_with_magic_label_or_open_pr() -> Set[str]: - pr_infos: List[Dict[str, Any]] = paginate_graphql( +def get_branches_with_magic_label_or_open_pr() -> set[str]: + pr_infos: list[dict[str, Any]] = paginate_graphql( GRAPHQL_NO_DELETE_BRANCH_LABEL, {"owner": "pytorch", "repo": "pytorch"}, lambda data: False, diff --git a/.github/scripts/file_io_utils.py b/.github/scripts/file_io_utils.py index faba9f06d2a..9826cdececd 100644 --- a/.github/scripts/file_io_utils.py +++ b/.github/scripts/file_io_utils.py @@ -2,7 +2,7 @@ import json import re import shutil 
from pathlib import Path -from typing import Any, List +from typing import Any import boto3 # type: ignore[import] @@ -77,7 +77,7 @@ def upload_file_to_s3(file_name: Path, bucket: str, key: str) -> None: def download_s3_objects_with_prefix( bucket_name: str, prefix: str, download_folder: Path -) -> List[Path]: +) -> list[Path]: s3 = boto3.resource("s3") bucket = s3.Bucket(bucket_name) diff --git a/.github/scripts/filter_test_configs.py b/.github/scripts/filter_test_configs.py index 6c4d53837f9..27e934fb3b9 100755 --- a/.github/scripts/filter_test_configs.py +++ b/.github/scripts/filter_test_configs.py @@ -8,9 +8,9 @@ import subprocess import sys import warnings from enum import Enum -from functools import lru_cache +from functools import cache from logging import info -from typing import Any, Callable, Dict, List, Optional, Set +from typing import Any, Callable, Optional from urllib.request import Request, urlopen import yaml @@ -32,7 +32,7 @@ def is_cuda_or_rocm_job(job_name: Optional[str]) -> bool: # Supported modes when running periodically. Only applying the mode when # its lambda condition returns true -SUPPORTED_PERIODICAL_MODES: Dict[str, Callable[[Optional[str]], bool]] = { +SUPPORTED_PERIODICAL_MODES: dict[str, Callable[[Optional[str]], bool]] = { # Memory leak check is only needed for CUDA and ROCm jobs which utilize GPU memory "mem_leak_check": is_cuda_or_rocm_job, "rerun_disabled_tests": lambda job_name: True, @@ -102,8 +102,8 @@ def parse_args() -> Any: return parser.parse_args() -@lru_cache(maxsize=None) -def get_pr_info(pr_number: int) -> Dict[str, Any]: +@cache +def get_pr_info(pr_number: int) -> dict[str, Any]: """ Dynamically get PR information """ @@ -116,7 +116,7 @@ def get_pr_info(pr_number: int) -> Dict[str, Any]: "Accept": "application/vnd.github.v3+json", "Authorization": f"token {github_token}", } - json_response: Dict[str, Any] = download_json( + json_response: dict[str, Any] = download_json( url=f"{pytorch_github_api}/issues/{pr_number}", headers=headers, ) @@ -128,7 +128,7 @@ def get_pr_info(pr_number: int) -> Dict[str, Any]: return json_response -def get_labels(pr_number: int) -> Set[str]: +def get_labels(pr_number: int) -> set[str]: """ Dynamically get the latest list of labels from the pull request """ @@ -138,14 +138,14 @@ def get_labels(pr_number: int) -> Set[str]: } -def filter_labels(labels: Set[str], label_regex: Any) -> Set[str]: +def filter_labels(labels: set[str], label_regex: Any) -> set[str]: """ Return the list of matching labels """ return {l for l in labels if re.match(label_regex, l)} -def filter(test_matrix: Dict[str, List[Any]], labels: Set[str]) -> Dict[str, List[Any]]: +def filter(test_matrix: dict[str, list[Any]], labels: set[str]) -> dict[str, list[Any]]: """ Select the list of test config to run from the test matrix. The logic works as follows: @@ -157,7 +157,7 @@ def filter(test_matrix: Dict[str, List[Any]], labels: Set[str]) -> Dict[str, Lis If the PR has none of the test-config label, all tests are run as usual. 
""" - filtered_test_matrix: Dict[str, List[Any]] = {"include": []} + filtered_test_matrix: dict[str, list[Any]] = {"include": []} for entry in test_matrix.get("include", []): config_name = entry.get("config", "") @@ -185,8 +185,8 @@ def filter(test_matrix: Dict[str, List[Any]], labels: Set[str]) -> Dict[str, Lis def filter_selected_test_configs( - test_matrix: Dict[str, List[Any]], selected_test_configs: Set[str] -) -> Dict[str, List[Any]]: + test_matrix: dict[str, list[Any]], selected_test_configs: set[str] +) -> dict[str, list[Any]]: """ Keep only the selected configs if the list if not empty. Otherwise, keep all test configs. This filter is used when the workflow is dispatched manually. @@ -194,7 +194,7 @@ def filter_selected_test_configs( if not selected_test_configs: return test_matrix - filtered_test_matrix: Dict[str, List[Any]] = {"include": []} + filtered_test_matrix: dict[str, list[Any]] = {"include": []} for entry in test_matrix.get("include", []): config_name = entry.get("config", "") if not config_name: @@ -207,12 +207,12 @@ def filter_selected_test_configs( def set_periodic_modes( - test_matrix: Dict[str, List[Any]], job_name: Optional[str] -) -> Dict[str, List[Any]]: + test_matrix: dict[str, list[Any]], job_name: Optional[str] +) -> dict[str, list[Any]]: """ Apply all periodic modes when running under a schedule """ - scheduled_test_matrix: Dict[str, List[Any]] = { + scheduled_test_matrix: dict[str, list[Any]] = { "include": [], } @@ -229,8 +229,8 @@ def set_periodic_modes( def mark_unstable_jobs( - workflow: str, job_name: str, test_matrix: Dict[str, List[Any]] -) -> Dict[str, List[Any]]: + workflow: str, job_name: str, test_matrix: dict[str, list[Any]] +) -> dict[str, list[Any]]: """ Check the list of unstable jobs and mark them accordingly. 
Note that if a job is unstable, all its dependents will also be marked accordingly @@ -245,8 +245,8 @@ def mark_unstable_jobs( def remove_disabled_jobs( - workflow: str, job_name: str, test_matrix: Dict[str, List[Any]] -) -> Dict[str, List[Any]]: + workflow: str, job_name: str, test_matrix: dict[str, list[Any]] +) -> dict[str, list[Any]]: """ Check the list of disabled jobs, remove the current job and all its dependents if it exists in the list @@ -261,15 +261,15 @@ def remove_disabled_jobs( def _filter_jobs( - test_matrix: Dict[str, List[Any]], + test_matrix: dict[str, list[Any]], issue_type: IssueType, target_cfg: Optional[str] = None, -) -> Dict[str, List[Any]]: +) -> dict[str, list[Any]]: """ An utility function used to actually apply the job filter """ # The result will be stored here - filtered_test_matrix: Dict[str, List[Any]] = {"include": []} + filtered_test_matrix: dict[str, list[Any]] = {"include": []} # This is an issue to disable a CI job if issue_type == IssueType.DISABLED: @@ -302,10 +302,10 @@ def _filter_jobs( def process_jobs( workflow: str, job_name: str, - test_matrix: Dict[str, List[Any]], + test_matrix: dict[str, list[Any]], issue_type: IssueType, url: str, -) -> Dict[str, List[Any]]: +) -> dict[str, list[Any]]: """ Both disabled and unstable jobs are in the following format: @@ -441,7 +441,7 @@ def process_jobs( return test_matrix -def download_json(url: str, headers: Dict[str, str], num_retries: int = 3) -> Any: +def download_json(url: str, headers: dict[str, str], num_retries: int = 3) -> Any: for _ in range(num_retries): try: req = Request(url=url, headers=headers) @@ -462,7 +462,7 @@ def set_output(name: str, val: Any) -> None: print(f"::set-output name={name}::{val}") -def parse_reenabled_issues(s: Optional[str]) -> List[str]: +def parse_reenabled_issues(s: Optional[str]) -> list[str]: # NB: When the PR body is empty, GitHub API returns a None value, which is # passed into this function if not s: @@ -477,7 +477,7 @@ def parse_reenabled_issues(s: Optional[str]) -> List[str]: return issue_numbers -def get_reenabled_issues(pr_body: str = "") -> List[str]: +def get_reenabled_issues(pr_body: str = "") -> list[str]: default_branch = f"origin/{os.environ.get('GIT_DEFAULT_BRANCH', 'main')}" try: commit_messages = subprocess.check_output( @@ -489,12 +489,12 @@ def get_reenabled_issues(pr_body: str = "") -> List[str]: return parse_reenabled_issues(pr_body) + parse_reenabled_issues(commit_messages) -def check_for_setting(labels: Set[str], body: str, setting: str) -> bool: +def check_for_setting(labels: set[str], body: str, setting: str) -> bool: return setting in labels or f"[{setting}]" in body def perform_misc_tasks( - labels: Set[str], test_matrix: Dict[str, List[Any]], job_name: str, pr_body: str + labels: set[str], test_matrix: dict[str, list[Any]], job_name: str, pr_body: str ) -> None: """ In addition to apply the filter logic, the script also does the following diff --git a/.github/scripts/generate_binary_build_matrix.py b/.github/scripts/generate_binary_build_matrix.py index 23aec2d3314..c9280a746dc 100644 --- a/.github/scripts/generate_binary_build_matrix.py +++ b/.github/scripts/generate_binary_build_matrix.py @@ -12,7 +12,7 @@ architectures: """ import os -from typing import Dict, List, Optional, Tuple +from typing import Optional # NOTE: Also update the CUDA sources in tools/nightly.py when changing this list @@ -181,7 +181,7 @@ CXX11_ABI = "cxx11-abi" RELEASE = "release" DEBUG = "debug" -LIBTORCH_CONTAINER_IMAGES: Dict[Tuple[str, str], str] = { 
+LIBTORCH_CONTAINER_IMAGES: dict[tuple[str, str], str] = { **{ ( gpu_arch, @@ -223,16 +223,16 @@ def translate_desired_cuda(gpu_arch_type: str, gpu_arch_version: str) -> str: }.get(gpu_arch_type, gpu_arch_version) -def list_without(in_list: List[str], without: List[str]) -> List[str]: +def list_without(in_list: list[str], without: list[str]) -> list[str]: return [item for item in in_list if item not in without] def generate_libtorch_matrix( os: str, abi_version: str, - arches: Optional[List[str]] = None, - libtorch_variants: Optional[List[str]] = None, -) -> List[Dict[str, str]]: + arches: Optional[list[str]] = None, + libtorch_variants: Optional[list[str]] = None, +) -> list[dict[str, str]]: if arches is None: arches = ["cpu"] if os == "linux": @@ -248,7 +248,7 @@ def generate_libtorch_matrix( "static-without-deps", ] - ret: List[Dict[str, str]] = [] + ret: list[dict[str, str]] = [] for arch_version in arches: for libtorch_variant in libtorch_variants: # one of the values in the following list must be exactly @@ -287,10 +287,10 @@ def generate_libtorch_matrix( def generate_wheels_matrix( os: str, - arches: Optional[List[str]] = None, - python_versions: Optional[List[str]] = None, + arches: Optional[list[str]] = None, + python_versions: Optional[list[str]] = None, use_split_build: bool = False, -) -> List[Dict[str, str]]: +) -> list[dict[str, str]]: package_type = "wheel" if os == "linux" or os == "linux-aarch64" or os == "linux-s390x": # NOTE: We only build manywheel packages for x86_64 and aarch64 and s390x linux @@ -315,7 +315,7 @@ def generate_wheels_matrix( # uses different build/test scripts arches = ["cpu-s390x"] - ret: List[Dict[str, str]] = [] + ret: list[dict[str, str]] = [] for python_version in python_versions: for arch_version in arches: gpu_arch_type = arch_type(arch_version) diff --git a/.github/scripts/generate_ci_workflows.py b/.github/scripts/generate_ci_workflows.py index 8512b27f0c0..83169ad438a 100755 --- a/.github/scripts/generate_ci_workflows.py +++ b/.github/scripts/generate_ci_workflows.py @@ -2,9 +2,10 @@ import os import sys +from collections.abc import Iterable from dataclasses import asdict, dataclass, field from pathlib import Path -from typing import Dict, Iterable, List, Literal, Set +from typing import Literal from typing_extensions import TypedDict # Python 3.11+ import generate_binary_build_matrix # type: ignore[import] @@ -27,7 +28,7 @@ LABEL_CIFLOW_BINARIES_WHEEL = "ciflow/binaries_wheel" class CIFlowConfig: # For use to enable workflows to run on pytorch/pytorch-canary run_on_canary: bool = False - labels: Set[str] = field(default_factory=set) + labels: set[str] = field(default_factory=set) # Certain jobs might not want to be part of the ciflow/[all,trunk] workflow isolated_workflow: bool = False unstable: bool = False @@ -48,7 +49,7 @@ class Config(TypedDict): @dataclass class BinaryBuildWorkflow: os: str - build_configs: List[Dict[str, str]] + build_configs: list[dict[str, str]] package_type: str # Optional fields diff --git a/.github/scripts/get_workflow_job_id.py b/.github/scripts/get_workflow_job_id.py index 76ba52fbe37..cfbfe315bf6 100644 --- a/.github/scripts/get_workflow_job_id.py +++ b/.github/scripts/get_workflow_job_id.py @@ -11,11 +11,11 @@ import sys import time import urllib import urllib.parse -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Optional from urllib.request import Request, urlopen -def parse_json_and_links(conn: Any) -> Tuple[Any, Dict[str, Dict[str, str]]]: +def 
parse_json_and_links(conn: Any) -> tuple[Any, dict[str, dict[str, str]]]: links = {} # Extract links which GH uses for pagination # see https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Link @@ -42,7 +42,7 @@ def parse_json_and_links(conn: Any) -> Tuple[Any, Dict[str, Dict[str, str]]]: def fetch_url( url: str, *, - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, reader: Callable[[Any], Any] = lambda x: x.read(), retries: Optional[int] = 3, backoff_timeout: float = 0.5, @@ -83,7 +83,7 @@ def parse_args() -> Any: return parser.parse_args() -def fetch_jobs(url: str, headers: Dict[str, str]) -> List[Dict[str, str]]: +def fetch_jobs(url: str, headers: dict[str, str]) -> list[dict[str, str]]: response, links = fetch_url(url, headers=headers, reader=parse_json_and_links) jobs = response["jobs"] assert type(jobs) is list @@ -111,7 +111,7 @@ def fetch_jobs(url: str, headers: Dict[str, str]) -> List[Dict[str, str]]: # running. -def find_job_id_name(args: Any) -> Tuple[str, str]: +def find_job_id_name(args: Any) -> tuple[str, str]: # From https://docs.github.com/en/actions/learn-github-actions/environment-variables PYTORCH_REPO = os.environ.get("GITHUB_REPOSITORY", "pytorch/pytorch") PYTORCH_GITHUB_API = f"https://api.github.com/repos/{PYTORCH_REPO}" diff --git a/.github/scripts/github_utils.py b/.github/scripts/github_utils.py index ed41b50c942..cd19907189f 100644 --- a/.github/scripts/github_utils.py +++ b/.github/scripts/github_utils.py @@ -4,7 +4,7 @@ import json import os import warnings from dataclasses import dataclass -from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, cast, Optional, Union from urllib.error import HTTPError from urllib.parse import quote from urllib.request import Request, urlopen @@ -27,11 +27,11 @@ class GitHubComment: def gh_fetch_url_and_headers( url: str, *, - headers: Optional[Dict[str, str]] = None, - data: Union[Optional[Dict[str, Any]], str] = None, + headers: Optional[dict[str, str]] = None, + data: Union[Optional[dict[str, Any]], str] = None, method: Optional[str] = None, reader: Callable[[Any], Any] = lambda x: x.read(), -) -> Tuple[Any, Any]: +) -> tuple[Any, Any]: if headers is None: headers = {} token = os.environ.get("GITHUB_TOKEN") @@ -70,8 +70,8 @@ def gh_fetch_url_and_headers( def gh_fetch_url( url: str, *, - headers: Optional[Dict[str, str]] = None, - data: Union[Optional[Dict[str, Any]], str] = None, + headers: Optional[dict[str, str]] = None, + data: Union[Optional[dict[str, Any]], str] = None, method: Optional[str] = None, reader: Callable[[Any], Any] = json.load, ) -> Any: @@ -82,25 +82,25 @@ def gh_fetch_url( def gh_fetch_json( url: str, - params: Optional[Dict[str, Any]] = None, - data: Optional[Dict[str, Any]] = None, + params: Optional[dict[str, Any]] = None, + data: Optional[dict[str, Any]] = None, method: Optional[str] = None, -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: headers = {"Accept": "application/vnd.github.v3+json"} if params is not None and len(params) > 0: url += "?" 
+ "&".join( f"{name}={quote(str(val))}" for name, val in params.items() ) return cast( - List[Dict[str, Any]], + list[dict[str, Any]], gh_fetch_url(url, headers=headers, data=data, reader=json.load, method=method), ) def _gh_fetch_json_any( url: str, - params: Optional[Dict[str, Any]] = None, - data: Optional[Dict[str, Any]] = None, + params: Optional[dict[str, Any]] = None, + data: Optional[dict[str, Any]] = None, ) -> Any: headers = {"Accept": "application/vnd.github.v3+json"} if params is not None and len(params) > 0: @@ -112,21 +112,21 @@ def _gh_fetch_json_any( def gh_fetch_json_list( url: str, - params: Optional[Dict[str, Any]] = None, - data: Optional[Dict[str, Any]] = None, -) -> List[Dict[str, Any]]: - return cast(List[Dict[str, Any]], _gh_fetch_json_any(url, params, data)) + params: Optional[dict[str, Any]] = None, + data: Optional[dict[str, Any]] = None, +) -> list[dict[str, Any]]: + return cast(list[dict[str, Any]], _gh_fetch_json_any(url, params, data)) def gh_fetch_json_dict( url: str, - params: Optional[Dict[str, Any]] = None, - data: Optional[Dict[str, Any]] = None, -) -> Dict[str, Any]: - return cast(Dict[str, Any], _gh_fetch_json_any(url, params, data)) + params: Optional[dict[str, Any]] = None, + data: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + return cast(dict[str, Any], _gh_fetch_json_any(url, params, data)) -def gh_graphql(query: str, **kwargs: Any) -> Dict[str, Any]: +def gh_graphql(query: str, **kwargs: Any) -> dict[str, Any]: rc = gh_fetch_url( "https://api.github.com/graphql", data={"query": query, "variables": kwargs}, @@ -136,12 +136,12 @@ def gh_graphql(query: str, **kwargs: Any) -> Dict[str, Any]: raise RuntimeError( f"GraphQL query {query}, args {kwargs} failed: {rc['errors']}" ) - return cast(Dict[str, Any], rc) + return cast(dict[str, Any], rc) def _gh_post_comment( url: str, comment: str, dry_run: bool = False -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: if dry_run: print(comment) return [] @@ -150,7 +150,7 @@ def _gh_post_comment( def gh_post_pr_comment( org: str, repo: str, pr_num: int, comment: str, dry_run: bool = False -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: return _gh_post_comment( f"{GITHUB_API_URL}/repos/{org}/{repo}/issues/{pr_num}/comments", comment, @@ -160,7 +160,7 @@ def gh_post_pr_comment( def gh_post_commit_comment( org: str, repo: str, sha: str, comment: str, dry_run: bool = False -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: return _gh_post_comment( f"{GITHUB_API_URL}/repos/{org}/{repo}/commits/{sha}/comments", comment, @@ -220,8 +220,8 @@ def gh_update_pr_state(org: str, repo: str, pr_num: int, state: str = "open") -> def gh_query_issues_by_labels( - org: str, repo: str, labels: List[str], state: str = "open" -) -> List[Dict[str, Any]]: + org: str, repo: str, labels: list[str], state: str = "open" +) -> list[dict[str, Any]]: url = f"{GITHUB_API_URL}/repos/{org}/{repo}/issues" return gh_fetch_json( url, method="GET", params={"labels": ",".join(labels), "state": state} diff --git a/.github/scripts/gitutils.py b/.github/scripts/gitutils.py index 42f16366032..43ee063bd63 100644 --- a/.github/scripts/gitutils.py +++ b/.github/scripts/gitutils.py @@ -4,20 +4,10 @@ import os import re import tempfile from collections import defaultdict +from collections.abc import Iterator from datetime import datetime from functools import wraps -from typing import ( - Any, - Callable, - cast, - Dict, - Iterator, - List, - Optional, - Tuple, - TypeVar, - Union, -) +from typing import Any, Callable, cast, 
Optional, TypeVar, Union T = TypeVar("T") @@ -35,17 +25,17 @@ def get_git_repo_dir() -> str: return os.getenv("GIT_REPO_DIR", str(Path(__file__).resolve().parents[2])) -def fuzzy_list_to_dict(items: List[Tuple[str, str]]) -> Dict[str, List[str]]: +def fuzzy_list_to_dict(items: list[tuple[str, str]]) -> dict[str, list[str]]: """ Converts list to dict preserving elements with duplicate keys """ - rc: Dict[str, List[str]] = defaultdict(list) + rc: dict[str, list[str]] = defaultdict(list) for key, val in items: rc[key].append(val) return dict(rc) -def _check_output(items: List[str], encoding: str = "utf-8") -> str: +def _check_output(items: list[str], encoding: str = "utf-8") -> str: from subprocess import CalledProcessError, check_output, STDOUT try: @@ -95,7 +85,7 @@ class GitCommit: return item in self.body or item in self.title -def parse_fuller_format(lines: Union[str, List[str]]) -> GitCommit: +def parse_fuller_format(lines: Union[str, list[str]]) -> GitCommit: """ Expect commit message generated using `--format=fuller --date=unix` format, i.e.: commit @@ -142,13 +132,13 @@ class GitRepo: print(f"+ git -C {self.repo_dir} {' '.join(args)}") return _check_output(["git", "-C", self.repo_dir] + list(args)) - def revlist(self, revision_range: str) -> List[str]: + def revlist(self, revision_range: str) -> list[str]: rc = self._run_git("rev-list", revision_range, "--", ".").strip() return rc.split("\n") if len(rc) > 0 else [] def branches_containing_ref( self, ref: str, *, include_remote: bool = True - ) -> List[str]: + ) -> list[str]: rc = ( self._run_git("branch", "--remote", "--contains", ref) if include_remote @@ -189,7 +179,7 @@ class GitRepo: def get_merge_base(self, from_ref: str, to_ref: str) -> str: return self._run_git("merge-base", from_ref, to_ref).strip() - def patch_id(self, ref: Union[str, List[str]]) -> List[Tuple[str, str]]: + def patch_id(self, ref: Union[str, list[str]]) -> list[tuple[str, str]]: is_list = isinstance(ref, list) if is_list: if len(ref) == 0: @@ -198,9 +188,9 @@ class GitRepo: rc = _check_output( ["sh", "-c", f"git -C {self.repo_dir} show {ref}|git patch-id --stable"] ).strip() - return [cast(Tuple[str, str], x.split(" ", 1)) for x in rc.split("\n")] + return [cast(tuple[str, str], x.split(" ", 1)) for x in rc.split("\n")] - def commits_resolving_gh_pr(self, pr_num: int) -> List[str]: + def commits_resolving_gh_pr(self, pr_num: int) -> list[str]: owner, name = self.gh_owner_and_name() msg = f"Pull Request resolved: https://github.com/{owner}/{name}/pull/{pr_num}" rc = self._run_git("log", "--format=%H", "--grep", msg).strip() @@ -219,7 +209,7 @@ class GitRepo: def compute_branch_diffs( self, from_branch: str, to_branch: str - ) -> Tuple[List[str], List[str]]: + ) -> tuple[list[str], list[str]]: """ Returns list of commmits that are missing in each other branch since their merge base Might be slow if merge base is between two branches is pretty far off @@ -311,14 +301,14 @@ class GitRepo: def remote_url(self) -> str: return self._run_git("remote", "get-url", self.remote) - def gh_owner_and_name(self) -> Tuple[str, str]: + def gh_owner_and_name(self) -> tuple[str, str]: url = os.getenv("GIT_REMOTE_URL", None) if url is None: url = self.remote_url() rc = RE_GITHUB_URL_MATCH.match(url) if rc is None: raise RuntimeError(f"Unexpected url format {url}") - return cast(Tuple[str, str], rc.groups()) + return cast(tuple[str, str], rc.groups()) def commit_message(self, ref: str) -> str: return self._run_git("log", "-1", "--format=%B", ref) @@ -366,7 +356,7 @@ class 
PeekableIterator(Iterator[str]): return rc -def patterns_to_regex(allowed_patterns: List[str]) -> Any: +def patterns_to_regex(allowed_patterns: list[str]) -> Any: """ pattern is glob-like, i.e. the only special sequences it has are: - ? - matches single character @@ -437,7 +427,7 @@ def retries_decorator( ) -> Callable[[Callable[..., T]], Callable[..., T]]: def decorator(f: Callable[..., T]) -> Callable[..., T]: @wraps(f) - def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> T: + def wrapper(*args: list[Any], **kwargs: dict[str, Any]) -> T: for idx in range(num_retries): try: return f(*args, **kwargs) diff --git a/.github/scripts/label_utils.py b/.github/scripts/label_utils.py index e4f2fa9e21a..8da0c49ba92 100644 --- a/.github/scripts/label_utils.py +++ b/.github/scripts/label_utils.py @@ -2,7 +2,7 @@ import json from functools import lru_cache -from typing import Any, List, Tuple, TYPE_CHECKING, Union +from typing import Any, TYPE_CHECKING, Union from github_utils import gh_fetch_url_and_headers, GitHubComment @@ -28,14 +28,14 @@ https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for """ -def request_for_labels(url: str) -> Tuple[Any, Any]: +def request_for_labels(url: str) -> tuple[Any, Any]: headers = {"Accept": "application/vnd.github.v3+json"} return gh_fetch_url_and_headers( url, headers=headers, reader=lambda x: x.read().decode("utf-8") ) -def update_labels(labels: List[str], info: str) -> None: +def update_labels(labels: list[str], info: str) -> None: labels_json = json.loads(info) labels.extend([x["name"] for x in labels_json]) @@ -56,10 +56,10 @@ def get_last_page_num_from_header(header: Any) -> int: @lru_cache -def gh_get_labels(org: str, repo: str) -> List[str]: +def gh_get_labels(org: str, repo: str) -> list[str]: prefix = f"https://api.github.com/repos/{org}/{repo}/labels?per_page=100" header, info = request_for_labels(prefix + "&page=1") - labels: List[str] = [] + labels: list[str] = [] update_labels(labels, info) last_page = get_last_page_num_from_header(header) @@ -74,7 +74,7 @@ def gh_get_labels(org: str, repo: str) -> List[str]: def gh_add_labels( - org: str, repo: str, pr_num: int, labels: Union[str, List[str]], dry_run: bool + org: str, repo: str, pr_num: int, labels: Union[str, list[str]], dry_run: bool ) -> None: if dry_run: print(f"Dryrun: Adding labels {labels} to PR {pr_num}") @@ -97,7 +97,7 @@ def gh_remove_label( ) -def get_release_notes_labels(org: str, repo: str) -> List[str]: +def get_release_notes_labels(org: str, repo: str) -> list[str]: return [ label for label in gh_get_labels(org, repo) diff --git a/.github/scripts/pytest_caching_utils.py b/.github/scripts/pytest_caching_utils.py index e4adfc8699a..0141bfd8da6 100644 --- a/.github/scripts/pytest_caching_utils.py +++ b/.github/scripts/pytest_caching_utils.py @@ -1,7 +1,7 @@ import hashlib import os from pathlib import Path -from typing import Dict, NamedTuple +from typing import NamedTuple from file_io_utils import ( copy_file, @@ -219,8 +219,8 @@ def _merge_lastfailed_files(source_pytest_cache: Path, dest_pytest_cache: Path) def _merged_lastfailed_content( - from_lastfailed: Dict[str, bool], to_lastfailed: Dict[str, bool] -) -> Dict[str, bool]: + from_lastfailed: dict[str, bool], to_lastfailed: dict[str, bool] +) -> dict[str, bool]: """ The lastfailed files are dictionaries where the key is the test identifier. Each entry's value appears to always be `true`, but let's not count on that. 
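
Every file in this patch follows the same three substitutions: builtin generics (`dict`, `list`, `set`, `tuple`, `frozenset`) replace their deprecated `typing` aliases, abstract containers such as `Iterable` move from `typing` to `collections.abc`, and `functools.cache` replaces `functools.lru_cache(maxsize=None)`. A minimal sketch of all three together, assuming Python 3.9+ (the builtin generics must be subscriptable at runtime, e.g. inside `cast(...)`); the names below are illustrative, not taken from these scripts:

```python
# Illustrative sketch of the PEP 585 rewrite pattern; assumes Python 3.9+.
from collections.abc import Iterable  # was: from typing import Iterable
from functools import cache           # was: from functools import lru_cache
from typing import Any, Optional, cast


@cache  # drop-in replacement for @lru_cache(maxsize=None)
def label_names(raw: tuple[str, ...]) -> set[str]:  # was: Tuple[str, ...] -> Set[str]
    return set(raw)


def first_match(labels: Iterable[str], prefix: Optional[str] = None) -> tuple[bool, str]:
    # Builtin generics also work in runtime contexts such as variable
    # annotations and cast(), which is why this needs Python 3.9+.
    found: list[str] = [x for x in labels if prefix is None or x.startswith(prefix)]
    return (True, cast(str, found[0])) if found else (False, "")
```

On Python 3.8 and earlier, evaluating an annotation such as `set[str]` raises `TypeError: 'type' object is not subscriptable` at import time, which is why migrations like this one wait for a 3.9+ interpreter floor.
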
diff --git a/.github/scripts/runner_determinator.py b/.github/scripts/runner_determinator.py index 96ea30fd1f2..e6846e42475 100644 --- a/.github/scripts/runner_determinator.py +++ b/.github/scripts/runner_determinator.py @@ -61,9 +61,10 @@ import random import re import sys from argparse import ArgumentParser -from functools import lru_cache +from collections.abc import Iterable +from functools import cache from logging import LogRecord -from typing import Any, Dict, FrozenSet, Iterable, List, NamedTuple, Set, Tuple +from typing import Any, NamedTuple from urllib.request import Request, urlopen import yaml @@ -105,7 +106,7 @@ class Settings(NamedTuple): Settings for the experiments that can be opted into. """ - experiments: Dict[str, Experiment] = {} + experiments: dict[str, Experiment] = {} class ColorFormatter(logging.Formatter): @@ -150,7 +151,7 @@ def set_github_output(key: str, value: str) -> None: f.write(f"{key}={value}\n") -def _str_comma_separated_to_set(value: str) -> FrozenSet[str]: +def _str_comma_separated_to_set(value: str) -> frozenset[str]: return frozenset( filter(lambda itm: itm != "", map(str.strip, value.strip(" \n\t").split(","))) ) @@ -208,12 +209,12 @@ def parse_args() -> Any: return parser.parse_args() -def get_gh_client(github_token: str) -> Github: +def get_gh_client(github_token: str) -> Github: # type: ignore[no-any-unimported] auth = Auth.Token(github_token) return Github(auth=auth) -def get_issue(gh: Github, repo: str, issue_num: int) -> Issue: +def get_issue(gh: Github, repo: str, issue_num: int) -> Issue: # type: ignore[no-any-unimported] repo = gh.get_repo(repo) return repo.get_issue(number=issue_num) @@ -242,7 +243,7 @@ def get_potential_pr_author( raise Exception( # noqa: TRY002 f"issue with pull request {pr_number} from repo {repository}" ) from e - return pull.user.login + return pull.user.login # type: ignore[no-any-return] # In all other cases, return the original input username return username @@ -263,7 +264,7 @@ def load_yaml(yaml_text: str) -> Any: raise -def extract_settings_user_opt_in_from_text(rollout_state: str) -> Tuple[str, str]: +def extract_settings_user_opt_in_from_text(rollout_state: str) -> tuple[str, str]: """ Extracts the text with settings, if any, and the opted in users from the rollout state. 
@@ -279,7 +280,7 @@ def extract_settings_user_opt_in_from_text(rollout_state: str) -> Tuple[str, str return "", rollout_state -class UserOptins(Dict[str, List[str]]): +class UserOptins(dict[str, list[str]]): """ Dictionary of users with a list of features they have opted into """ @@ -420,7 +421,7 @@ def get_runner_prefix( rollout_state: str, workflow_requestors: Iterable[str], branch: str, - eligible_experiments: FrozenSet[str] = frozenset(), + eligible_experiments: frozenset[str] = frozenset(), is_canary: bool = False, ) -> str: settings = parse_settings(rollout_state) @@ -519,7 +520,7 @@ def get_rollout_state_from_issue(github_token: str, repo: str, issue_num: int) - return str(issue.get_comments()[0].body.strip("\n\t ")) -def download_json(url: str, headers: Dict[str, str], num_retries: int = 3) -> Any: +def download_json(url: str, headers: dict[str, str], num_retries: int = 3) -> Any: for _ in range(num_retries): try: req = Request(url=url, headers=headers) @@ -532,8 +533,8 @@ def download_json(url: str, headers: Dict[str, str], num_retries: int = 3) -> An return {} -@lru_cache(maxsize=None) -def get_pr_info(github_repo: str, github_token: str, pr_number: int) -> Dict[str, Any]: +@cache +def get_pr_info(github_repo: str, github_token: str, pr_number: int) -> dict[str, Any]: """ Dynamically get PR information """ @@ -542,7 +543,7 @@ def get_pr_info(github_repo: str, github_token: str, pr_number: int) -> Dict[str "Accept": "application/vnd.github.v3+json", "Authorization": f"token {github_token}", } - json_response: Dict[str, Any] = download_json( + json_response: dict[str, Any] = download_json( url=f"{github_api}/issues/{pr_number}", headers=headers, ) @@ -554,7 +555,7 @@ def get_pr_info(github_repo: str, github_token: str, pr_number: int) -> Dict[str return json_response -def get_labels(github_repo: str, github_token: str, pr_number: int) -> Set[str]: +def get_labels(github_repo: str, github_token: str, pr_number: int) -> set[str]: """ Dynamically get the latest list of labels from the pull request """ diff --git a/.github/scripts/tag_docker_images_for_release.py b/.github/scripts/tag_docker_images_for_release.py index 19311769416..b2bf474575f 100644 --- a/.github/scripts/tag_docker_images_for_release.py +++ b/.github/scripts/tag_docker_images_for_release.py @@ -1,6 +1,5 @@ import argparse import subprocess -from typing import Dict import generate_binary_build_matrix @@ -10,7 +9,7 @@ def tag_image( default_tag: str, release_version: str, dry_run: str, - tagged_images: Dict[str, bool], + tagged_images: dict[str, bool], ) -> None: if image in tagged_images: return @@ -41,7 +40,7 @@ def main() -> None: ) options = parser.parse_args() - tagged_images: Dict[str, bool] = {} + tagged_images: dict[str, bool] = {} platform_images = [ generate_binary_build_matrix.WHEEL_CONTAINER_IMAGES, generate_binary_build_matrix.LIBTORCH_CONTAINER_IMAGES, diff --git a/.github/scripts/test_check_labels.py b/.github/scripts/test_check_labels.py index 1c921f2eafa..15b9d806b30 100644 --- a/.github/scripts/test_check_labels.py +++ b/.github/scripts/test_check_labels.py @@ -1,6 +1,6 @@ """test_check_labels.py""" -from typing import Any, List +from typing import Any from unittest import main, mock, TestCase from check_labels import ( @@ -31,7 +31,7 @@ def mock_delete_all_label_err_comments(pr: "GitHubPR") -> None: pass -def mock_get_comments() -> List[GitHubComment]: +def mock_get_comments() -> list[GitHubComment]: return [ # Case 1 - a non label err comment GitHubComment( diff --git 
a/.github/scripts/test_filter_test_configs.py b/.github/scripts/test_filter_test_configs.py index 421da22f7e4..2bc30fdc1e2 100755 --- a/.github/scripts/test_filter_test_configs.py +++ b/.github/scripts/test_filter_test_configs.py @@ -3,7 +3,7 @@ import json import os import tempfile -from typing import Any, Dict, List +from typing import Any from unittest import main, mock, TestCase import yaml @@ -362,7 +362,7 @@ class TestConfigFilter(TestCase): self.assertEqual(case["expected"], json.dumps(filtered_test_matrix)) def test_set_periodic_modes(self) -> None: - testcases: List[Dict[str, str]] = [ + testcases: list[dict[str, str]] = [ { "job_name": "a CI job", "test_matrix": "{include: []}", @@ -702,7 +702,7 @@ class TestConfigFilter(TestCase): ) mocked_subprocess.return_value = b"" - testcases: List[Dict[str, Any]] = [ + testcases: list[dict[str, Any]] = [ { "labels": {}, "test_matrix": '{include: [{config: "default"}]}', diff --git a/.github/scripts/test_trymerge.py b/.github/scripts/test_trymerge.py index 3bbf701cb5f..af41345088d 100755 --- a/.github/scripts/test_trymerge.py +++ b/.github/scripts/test_trymerge.py @@ -12,7 +12,7 @@ import json import os import warnings from hashlib import sha256 -from typing import Any, List, Optional +from typing import Any, Optional from unittest import main, mock, skip, TestCase from urllib.error import HTTPError @@ -170,7 +170,7 @@ def mock_gh_get_info() -> Any: } -def mocked_read_merge_rules_NE(repo: Any, org: str, project: str) -> List[MergeRule]: +def mocked_read_merge_rules_NE(repo: Any, org: str, project: str) -> list[MergeRule]: return [ MergeRule( name="mock with nonexistent check", @@ -182,7 +182,7 @@ def mocked_read_merge_rules_NE(repo: Any, org: str, project: str) -> List[MergeR ] -def mocked_read_merge_rules(repo: Any, org: str, project: str) -> List[MergeRule]: +def mocked_read_merge_rules(repo: Any, org: str, project: str) -> list[MergeRule]: return [ MergeRule( name="super", @@ -211,7 +211,7 @@ def mocked_read_merge_rules(repo: Any, org: str, project: str) -> List[MergeRule def mocked_read_merge_rules_approvers( repo: Any, org: str, project: str -) -> List[MergeRule]: +) -> list[MergeRule]: return [ MergeRule( name="Core Reviewers", @@ -234,11 +234,11 @@ def mocked_read_merge_rules_approvers( ] -def mocked_read_merge_rules_raise(repo: Any, org: str, project: str) -> List[MergeRule]: +def mocked_read_merge_rules_raise(repo: Any, org: str, project: str) -> list[MergeRule]: raise RuntimeError("testing") -def xla_merge_rules(repo: Any, org: str, project: str) -> List[MergeRule]: +def xla_merge_rules(repo: Any, org: str, project: str) -> list[MergeRule]: return [ MergeRule( name=" OSS CI / pytorchbot / XLA", @@ -260,7 +260,7 @@ class DummyGitRepo(GitRepo): def __init__(self) -> None: super().__init__(get_git_repo_dir(), get_git_remote_name()) - def commits_resolving_gh_pr(self, pr_num: int) -> List[str]: + def commits_resolving_gh_pr(self, pr_num: int) -> list[str]: return ["FakeCommitSha"] def commit_message(self, ref: str) -> str: diff --git a/.github/scripts/trymerge.py b/.github/scripts/trymerge.py index 21af4ca195b..349db71308b 100755 --- a/.github/scripts/trymerge.py +++ b/.github/scripts/trymerge.py @@ -17,21 +17,12 @@ import re import time import urllib.parse from collections import defaultdict +from collections.abc import Iterable from dataclasses import dataclass -from functools import lru_cache +from functools import cache from pathlib import Path -from typing import ( - Any, - Callable, - cast, - Dict, - Iterable, - List, - 
NamedTuple, - Optional, - Pattern, - Tuple, -) +from re import Pattern +from typing import Any, Callable, cast, NamedTuple, Optional from warnings import warn import yaml @@ -78,7 +69,7 @@ class JobCheckState(NamedTuple): summary: Optional[str] -JobNameToStateDict = Dict[str, JobCheckState] +JobNameToStateDict = dict[str, JobCheckState] class WorkflowCheckState: @@ -468,10 +459,10 @@ def gh_get_pr_info(org: str, proj: str, pr_no: int) -> Any: return rc["data"]["repository"]["pullRequest"] -@lru_cache(maxsize=None) -def gh_get_team_members(org: str, name: str) -> List[str]: - rc: List[str] = [] - team_members: Dict[str, Any] = { +@cache +def gh_get_team_members(org: str, name: str) -> list[str]: + rc: list[str] = [] + team_members: dict[str, Any] = { "pageInfo": {"hasNextPage": "true", "endCursor": None} } while bool(team_members["pageInfo"]["hasNextPage"]): @@ -503,14 +494,14 @@ def is_passing_status(status: Optional[str]) -> bool: def add_workflow_conclusions( checksuites: Any, - get_next_checkruns_page: Callable[[List[Dict[str, Dict[str, Any]]], int, Any], Any], + get_next_checkruns_page: Callable[[list[dict[str, dict[str, Any]]], int, Any], Any], get_next_checksuites: Callable[[Any], Any], ) -> JobNameToStateDict: # graphql seems to favor the most recent workflow run, so in theory we # shouldn't need to account for reruns, but do it just in case # workflow -> job -> job info - workflows: Dict[str, WorkflowCheckState] = {} + workflows: dict[str, WorkflowCheckState] = {} # for the jobs that don't have a workflow no_workflow_obj: WorkflowCheckState = WorkflowCheckState("", "", 0, None) @@ -633,8 +624,8 @@ def _revlist_to_prs( pr: "GitHubPR", rev_list: Iterable[str], should_skip: Optional[Callable[[int, "GitHubPR"], bool]] = None, -) -> List[Tuple["GitHubPR", str]]: - rc: List[Tuple[GitHubPR, str]] = [] +) -> list[tuple["GitHubPR", str]]: + rc: list[tuple[GitHubPR, str]] = [] for idx, rev in enumerate(rev_list): msg = repo.commit_message(rev) m = RE_PULL_REQUEST_RESOLVED.search(msg) @@ -656,7 +647,7 @@ def _revlist_to_prs( def get_ghstack_prs( repo: GitRepo, pr: "GitHubPR", open_only: bool = True -) -> List[Tuple["GitHubPR", str]]: +) -> list[tuple["GitHubPR", str]]: """ Get the PRs in the stack that are below this PR (inclusive). Throws error if any of the open PRs are out of sync. 
@:param open_only: Only return open PRs @@ -701,14 +692,14 @@ class GitHubPR: self.project = project self.pr_num = pr_num self.info = gh_get_pr_info(org, project, pr_num) - self.changed_files: Optional[List[str]] = None - self.labels: Optional[List[str]] = None + self.changed_files: Optional[list[str]] = None + self.labels: Optional[list[str]] = None self.conclusions: Optional[JobNameToStateDict] = None - self.comments: Optional[List[GitHubComment]] = None - self._authors: Optional[List[Tuple[str, str]]] = None - self._reviews: Optional[List[Tuple[str, str]]] = None + self.comments: Optional[list[GitHubComment]] = None + self._authors: Optional[list[tuple[str, str]]] = None + self._reviews: Optional[list[tuple[str, str]]] = None self.merge_base: Optional[str] = None - self.submodules: Optional[List[str]] = None + self.submodules: Optional[list[str]] = None def is_closed(self) -> bool: return bool(self.info["closed"]) @@ -763,7 +754,7 @@ class GitHubPR: return self.merge_base - def get_changed_files(self) -> List[str]: + def get_changed_files(self) -> list[str]: if self.changed_files is None: info = self.info unique_changed_files = set() @@ -786,14 +777,14 @@ class GitHubPR: raise RuntimeError("Changed file count mismatch") return self.changed_files - def get_submodules(self) -> List[str]: + def get_submodules(self) -> list[str]: if self.submodules is None: rc = gh_graphql(GH_GET_REPO_SUBMODULES, name=self.project, owner=self.org) info = rc["data"]["repository"]["submodules"] self.submodules = [s["path"] for s in info["nodes"]] return self.submodules - def get_changed_submodules(self) -> List[str]: + def get_changed_submodules(self) -> list[str]: submodules = self.get_submodules() return [f for f in self.get_changed_files() if f in submodules] @@ -809,7 +800,7 @@ class GitHubPR: and all("submodule" not in label for label in self.get_labels()) ) - def _get_reviews(self) -> List[Tuple[str, str]]: + def _get_reviews(self) -> list[tuple[str, str]]: if self._reviews is None: self._reviews = [] info = self.info @@ -834,7 +825,7 @@ class GitHubPR: reviews[author] = state return list(reviews.items()) - def get_approved_by(self) -> List[str]: + def get_approved_by(self) -> list[str]: return [login for (login, state) in self._get_reviews() if state == "APPROVED"] def get_commit_count(self) -> int: @@ -843,12 +834,12 @@ class GitHubPR: def get_pr_creator_login(self) -> str: return cast(str, self.info["author"]["login"]) - def _fetch_authors(self) -> List[Tuple[str, str]]: + def _fetch_authors(self) -> list[tuple[str, str]]: if self._authors is not None: return self._authors - authors: List[Tuple[str, str]] = [] + authors: list[tuple[str, str]] = [] - def add_authors(info: Dict[str, Any]) -> None: + def add_authors(info: dict[str, Any]) -> None: for node in info["commits_with_authors"]["nodes"]: for author_node in node["commit"]["authors"]["nodes"]: user_node = author_node["user"] @@ -881,7 +872,7 @@ class GitHubPR: def get_committer_author(self, num: int = 0) -> str: return self._fetch_authors()[num][1] - def get_labels(self) -> List[str]: + def get_labels(self) -> list[str]: if self.labels is not None: return self.labels labels = ( @@ -899,7 +890,7 @@ class GitHubPR: orig_last_commit = self.last_commit() def get_pr_next_check_runs( - edges: List[Dict[str, Dict[str, Any]]], edge_idx: int, checkruns: Any + edges: list[dict[str, dict[str, Any]]], edge_idx: int, checkruns: Any ) -> Any: rc = gh_graphql( GH_GET_PR_NEXT_CHECK_RUNS, @@ -951,7 +942,7 @@ class GitHubPR: return self.conclusions - def 
get_authors(self) -> Dict[str, str]: + def get_authors(self) -> dict[str, str]: rc = {} for idx in range(len(self._fetch_authors())): rc[self.get_committer_login(idx)] = self.get_committer_author(idx) @@ -995,7 +986,7 @@ class GitHubPR: url=node["url"], ) - def get_comments(self) -> List[GitHubComment]: + def get_comments(self) -> list[GitHubComment]: if self.comments is not None: return self.comments self.comments = [] @@ -1069,7 +1060,7 @@ class GitHubPR: skip_mandatory_checks: bool, comment_id: Optional[int] = None, skip_all_rule_checks: bool = False, - ) -> List["GitHubPR"]: + ) -> list["GitHubPR"]: assert self.is_ghstack_pr() ghstack_prs = get_ghstack_prs( repo, self, open_only=False @@ -1099,7 +1090,7 @@ class GitHubPR: def gen_commit_message( self, filter_ghstack: bool = False, - ghstack_deps: Optional[List["GitHubPR"]] = None, + ghstack_deps: Optional[list["GitHubPR"]] = None, ) -> str: """Fetches title and body from PR description adds reviewed by, pull request resolved and optionally @@ -1151,7 +1142,7 @@ class GitHubPR: skip_mandatory_checks: bool = False, dry_run: bool = False, comment_id: Optional[int] = None, - ignore_current_checks: Optional[List[str]] = None, + ignore_current_checks: Optional[list[str]] = None, ) -> None: # Raises exception if matching rule is not found ( @@ -1223,7 +1214,7 @@ class GitHubPR: comment_id: Optional[int] = None, branch: Optional[str] = None, skip_all_rule_checks: bool = False, - ) -> List["GitHubPR"]: + ) -> list["GitHubPR"]: """ :param skip_all_rule_checks: If true, skips all rule checks, useful for dry-running merge locally """ @@ -1263,14 +1254,14 @@ class PostCommentError(Exception): @dataclass class MergeRule: name: str - patterns: List[str] - approved_by: List[str] - mandatory_checks_name: Optional[List[str]] + patterns: list[str] + approved_by: list[str] + mandatory_checks_name: Optional[list[str]] ignore_flaky_failures: bool = True def gen_new_issue_link( - org: str, project: str, labels: List[str], template: str = "bug-report.yml" + org: str, project: str, labels: list[str], template: str = "bug-report.yml" ) -> str: labels_str = ",".join(labels) return ( @@ -1282,7 +1273,7 @@ def gen_new_issue_link( def read_merge_rules( repo: Optional[GitRepo], org: str, project: str -) -> List[MergeRule]: +) -> list[MergeRule]: """Returns the list of all merge rules for the repo or project. 
NB: this function is used in Meta-internal workflows, see the comment @@ -1312,12 +1303,12 @@ def find_matching_merge_rule( repo: Optional[GitRepo] = None, skip_mandatory_checks: bool = False, skip_internal_checks: bool = False, - ignore_current_checks: Optional[List[str]] = None, -) -> Tuple[ + ignore_current_checks: Optional[list[str]] = None, +) -> tuple[ MergeRule, - List[Tuple[str, Optional[str], Optional[int]]], - List[Tuple[str, Optional[str], Optional[int]]], - Dict[str, List[Any]], + list[tuple[str, Optional[str], Optional[int]]], + list[tuple[str, Optional[str], Optional[int]]], + dict[str, list[Any]], ]: """ Returns merge rule matching to this pr together with the list of associated pending @@ -1504,13 +1495,13 @@ def find_matching_merge_rule( raise MergeRuleFailedError(reject_reason, rule) -def checks_to_str(checks: List[Tuple[str, Optional[str]]]) -> str: +def checks_to_str(checks: list[tuple[str, Optional[str]]]) -> str: return ", ".join(f"[{c[0]}]({c[1]})" if c[1] is not None else c[0] for c in checks) def checks_to_markdown_bullets( - checks: List[Tuple[str, Optional[str], Optional[int]]], -) -> List[str]: + checks: list[tuple[str, Optional[str], Optional[int]]], +) -> list[str]: return [ f"- [{c[0]}]({c[1]})" if c[1] is not None else f"- {c[0]}" for c in checks[:5] ] @@ -1518,7 +1509,7 @@ def checks_to_markdown_bullets( def manually_close_merged_pr( pr: GitHubPR, - additional_merged_prs: List[GitHubPR], + additional_merged_prs: list[GitHubPR], merge_commit_sha: str, dry_run: bool, ) -> None: @@ -1551,12 +1542,12 @@ def save_merge_record( owner: str, project: str, author: str, - pending_checks: List[Tuple[str, Optional[str], Optional[int]]], - failed_checks: List[Tuple[str, Optional[str], Optional[int]]], - ignore_current_checks: List[Tuple[str, Optional[str], Optional[int]]], - broken_trunk_checks: List[Tuple[str, Optional[str], Optional[int]]], - flaky_checks: List[Tuple[str, Optional[str], Optional[int]]], - unstable_checks: List[Tuple[str, Optional[str], Optional[int]]], + pending_checks: list[tuple[str, Optional[str], Optional[int]]], + failed_checks: list[tuple[str, Optional[str], Optional[int]]], + ignore_current_checks: list[tuple[str, Optional[str], Optional[int]]], + broken_trunk_checks: list[tuple[str, Optional[str], Optional[int]]], + flaky_checks: list[tuple[str, Optional[str], Optional[int]]], + unstable_checks: list[tuple[str, Optional[str], Optional[int]]], last_commit_sha: str, merge_base_sha: str, merge_commit_sha: str = "", @@ -1714,9 +1705,9 @@ def is_invalid_cancel( def get_classifications( pr_num: int, project: str, - checks: Dict[str, JobCheckState], - ignore_current_checks: Optional[List[str]], -) -> Dict[str, JobCheckState]: + checks: dict[str, JobCheckState], + ignore_current_checks: Optional[list[str]], +) -> dict[str, JobCheckState]: # Get the failure classification from Dr.CI, which is the source of truth # going forward. 
It's preferable to try calling Dr.CI API directly first # to get the latest results as well as update Dr.CI PR comment @@ -1825,7 +1816,7 @@ def get_classifications( def filter_checks_with_lambda( checks: JobNameToStateDict, status_filter: Callable[[Optional[str]], bool] -) -> List[JobCheckState]: +) -> list[JobCheckState]: return [check for check in checks.values() if status_filter(check.status)] @@ -1841,7 +1832,7 @@ def get_pr_commit_sha(repo: GitRepo, pr: GitHubPR) -> str: def validate_revert( repo: GitRepo, pr: GitHubPR, *, comment_id: Optional[int] = None -) -> Tuple[str, str]: +) -> tuple[str, str]: comment = ( pr.get_last_comment() if comment_id is None @@ -1871,7 +1862,7 @@ def validate_revert( def get_ghstack_dependent_prs( repo: GitRepo, pr: GitHubPR, only_closed: bool = True -) -> List[Tuple[str, GitHubPR]]: +) -> list[tuple[str, GitHubPR]]: """ Get the PRs in the stack that are above this PR (inclusive). Throws error if stack have branched or original branches are gone @@ -1897,7 +1888,7 @@ def get_ghstack_dependent_prs( # Remove commits original PR depends on if skip_len > 0: rev_list = rev_list[:-skip_len] - rc: List[Tuple[str, GitHubPR]] = [] + rc: list[tuple[str, GitHubPR]] = [] for pr_, sha in _revlist_to_prs(repo, pr, rev_list): if not pr_.is_closed(): if not only_closed: @@ -1910,7 +1901,7 @@ def get_ghstack_dependent_prs( def do_revert_prs( repo: GitRepo, - shas_and_prs: List[Tuple[str, GitHubPR]], + shas_and_prs: list[tuple[str, GitHubPR]], *, author_login: str, extra_msg: str = "", @@ -2001,7 +1992,7 @@ def check_for_sev(org: str, project: str, skip_mandatory_checks: bool) -> None: if skip_mandatory_checks: return response = cast( - Dict[str, Any], + dict[str, Any], gh_fetch_json_list( "https://api.github.com/search/issues", # Having two label: queries is an AND operation @@ -2019,29 +2010,29 @@ def check_for_sev(org: str, project: str, skip_mandatory_checks: bool) -> None: return -def has_label(labels: List[str], pattern: Pattern[str] = CIFLOW_LABEL) -> bool: +def has_label(labels: list[str], pattern: Pattern[str] = CIFLOW_LABEL) -> bool: return len(list(filter(pattern.match, labels))) > 0 def categorize_checks( check_runs: JobNameToStateDict, - required_checks: List[str], + required_checks: list[str], ok_failed_checks_threshold: Optional[int] = None, -) -> Tuple[ - List[Tuple[str, Optional[str], Optional[int]]], - List[Tuple[str, Optional[str], Optional[int]]], - Dict[str, List[Any]], +) -> tuple[ + list[tuple[str, Optional[str], Optional[int]]], + list[tuple[str, Optional[str], Optional[int]]], + dict[str, list[Any]], ]: """ Categories all jobs into the list of pending and failing jobs. 
     All known flaky failures and broken trunk are ignored by default when
     ok_failed_checks_threshold is not set (unlimited)
     """
-    pending_checks: List[Tuple[str, Optional[str], Optional[int]]] = []
-    failed_checks: List[Tuple[str, Optional[str], Optional[int]]] = []
+    pending_checks: list[tuple[str, Optional[str], Optional[int]]] = []
+    failed_checks: list[tuple[str, Optional[str], Optional[int]]] = []
 
     # failed_checks_categorization is used to keep track of all ignorable failures when saving the merge record on s3
-    failed_checks_categorization: Dict[str, List[Any]] = defaultdict(list)
+    failed_checks_categorization: dict[str, list[Any]] = defaultdict(list)
 
     # If required_checks is not set or empty, consider all names are relevant
     relevant_checknames = [
diff --git a/.github/scripts/trymerge_explainer.py b/.github/scripts/trymerge_explainer.py
index 22797909714..0527701291f 100644
--- a/.github/scripts/trymerge_explainer.py
+++ b/.github/scripts/trymerge_explainer.py
@@ -1,6 +1,7 @@
 import os
 import re
-from typing import List, Optional, Pattern, Tuple
+from re import Pattern
+from typing import Optional
 
 
 BOT_COMMANDS_WIKI = "https://github.com/pytorch/pytorch/wiki/Bot-commands"
@@ -13,13 +14,13 @@ CONTACT_US = f"Questions? Feedback? Please reach out to the [PyTorch DevX Team](
 ALTERNATIVES = f"Learn more about merging in the [wiki]({BOT_COMMANDS_WIKI})."
 
 
-def has_label(labels: List[str], pattern: Pattern[str] = CIFLOW_LABEL) -> bool:
+def has_label(labels: list[str], pattern: Pattern[str] = CIFLOW_LABEL) -> bool:
     return len(list(filter(pattern.match, labels))) > 0
 
 
 class TryMergeExplainer:
     force: bool
-    labels: List[str]
+    labels: list[str]
     pr_num: int
     org: str
     project: str
@@ -31,7 +32,7 @@ class TryMergeExplainer:
     def __init__(
         self,
         force: bool,
-        labels: List[str],
+        labels: list[str],
         pr_num: int,
         org: str,
         project: str,
@@ -47,7 +48,7 @@ class TryMergeExplainer:
     def _get_flag_msg(
         self,
         ignore_current_checks: Optional[
-            List[Tuple[str, Optional[str], Optional[int]]]
+            list[tuple[str, Optional[str], Optional[int]]]
         ] = None,
     ) -> str:
         if self.force:
@@ -68,7 +69,7 @@ class TryMergeExplainer:
     def get_merge_message(
         self,
         ignore_current_checks: Optional[
-            List[Tuple[str, Optional[str], Optional[int]]]
+            list[tuple[str, Optional[str], Optional[int]]]
         ] = None,
     ) -> str:
         title = "### Merge started"
diff --git a/.github/scripts/tryrebase.py b/.github/scripts/tryrebase.py
index efc243279ba..0f6d74e8346 100755
--- a/.github/scripts/tryrebase.py
+++ b/.github/scripts/tryrebase.py
@@ -5,7 +5,8 @@ import os
 import re
 import subprocess
 import sys
-from typing import Any, Generator
+from collections.abc import Generator
+from typing import Any
 
 from github_utils import gh_post_pr_comment as gh_post_comment
 from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo
diff --git a/.github/workflows/_runner-determinator.yml b/.github/workflows/_runner-determinator.yml
index 36f5a06da5d..47cd278bb8a 100644
--- a/.github/workflows/_runner-determinator.yml
+++ b/.github/workflows/_runner-determinator.yml
@@ -129,9 +129,10 @@ jobs:
           import re
           import sys
           from argparse import ArgumentParser
-          from functools import lru_cache
+          from collections.abc import Iterable
+          from functools import cache
           from logging import LogRecord
-          from typing import Any, Dict, FrozenSet, Iterable, List, NamedTuple, Set, Tuple
+          from typing import Any, NamedTuple
           from urllib.request import Request, urlopen
 
           import yaml
@@ -173,7 +174,7 @@ jobs:
          Settings for the experiments that can be opted into.
""" - experiments: Dict[str, Experiment] = {} + experiments: dict[str, Experiment] = {} class ColorFormatter(logging.Formatter): @@ -218,7 +219,7 @@ jobs: f.write(f"{key}={value}\n") - def _str_comma_separated_to_set(value: str) -> FrozenSet[str]: + def _str_comma_separated_to_set(value: str) -> frozenset[str]: return frozenset( filter(lambda itm: itm != "", map(str.strip, value.strip(" \n\t").split(","))) ) @@ -276,12 +277,12 @@ jobs: return parser.parse_args() - def get_gh_client(github_token: str) -> Github: + def get_gh_client(github_token: str) -> Github: # type: ignore[no-any-unimported] auth = Auth.Token(github_token) return Github(auth=auth) - def get_issue(gh: Github, repo: str, issue_num: int) -> Issue: + def get_issue(gh: Github, repo: str, issue_num: int) -> Issue: # type: ignore[no-any-unimported] repo = gh.get_repo(repo) return repo.get_issue(number=issue_num) @@ -310,7 +311,7 @@ jobs: raise Exception( # noqa: TRY002 f"issue with pull request {pr_number} from repo {repository}" ) from e - return pull.user.login + return pull.user.login # type: ignore[no-any-return] # In all other cases, return the original input username return username @@ -331,7 +332,7 @@ jobs: raise - def extract_settings_user_opt_in_from_text(rollout_state: str) -> Tuple[str, str]: + def extract_settings_user_opt_in_from_text(rollout_state: str) -> tuple[str, str]: """ Extracts the text with settings, if any, and the opted in users from the rollout state. @@ -347,7 +348,7 @@ jobs: return "", rollout_state - class UserOptins(Dict[str, List[str]]): + class UserOptins(dict[str, list[str]]): """ Dictionary of users with a list of features they have opted into """ @@ -488,7 +489,7 @@ jobs: rollout_state: str, workflow_requestors: Iterable[str], branch: str, - eligible_experiments: FrozenSet[str] = frozenset(), + eligible_experiments: frozenset[str] = frozenset(), is_canary: bool = False, ) -> str: settings = parse_settings(rollout_state) @@ -587,7 +588,7 @@ jobs: return str(issue.get_comments()[0].body.strip("\n\t ")) - def download_json(url: str, headers: Dict[str, str], num_retries: int = 3) -> Any: + def download_json(url: str, headers: dict[str, str], num_retries: int = 3) -> Any: for _ in range(num_retries): try: req = Request(url=url, headers=headers) @@ -600,8 +601,8 @@ jobs: return {} - @lru_cache(maxsize=None) - def get_pr_info(github_repo: str, github_token: str, pr_number: int) -> Dict[str, Any]: + @cache + def get_pr_info(github_repo: str, github_token: str, pr_number: int) -> dict[str, Any]: """ Dynamically get PR information """ @@ -610,7 +611,7 @@ jobs: "Accept": "application/vnd.github.v3+json", "Authorization": f"token {github_token}", } - json_response: Dict[str, Any] = download_json( + json_response: dict[str, Any] = download_json( url=f"{github_api}/issues/{pr_number}", headers=headers, ) @@ -622,7 +623,7 @@ jobs: return json_response - def get_labels(github_repo: str, github_token: str, pr_number: int) -> Set[str]: + def get_labels(github_repo: str, github_token: str, pr_number: int) -> set[str]: """ Dynamically get the latest list of labels from the pull request """