Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 00:20:18 +01:00)

PEP585: .github (#145707)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145707
Approved by: https://github.com/huydhn

Parent: bfaf76bfc6
Commit: 60f98262f1
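The whole commit is a mechanical application of PEP 585: since Python 3.9 the built-in containers (list, dict, set, tuple, frozenset) accept subscripts directly, so the typing-module aliases (List, Dict, Set, Tuple, FrozenSet) can be dropped from annotations. A minimal sketch of the pattern (function name and payload are hypothetical, not from this commit):

# Before: typing aliases, deprecated since Python 3.9
from typing import Dict, List, Optional

def summarize_old(rows: List[Dict[str, int]]) -> Optional[int]:
    return sum(row.get("count", 0) for row in rows) if rows else None

# After: built-in generics, the style this commit converges on
from typing import Optional

def summarize_new(rows: list[dict[str, int]]) -> Optional[int]:
    return sum(row.get("count", 0) for row in rows) if rows else None

Note that Optional, Union, Callable, Any, and cast still come from typing; PEP 585 only covers the container aliases (plus ABCs like Iterable and Iterator, whose canonical generic home becomes collections.abc).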
.github/scripts/cherry_pick.py (vendored): 10 changed lines

@@ -3,7 +3,7 @@
 import json
 import os
 import re
-from typing import Any, cast, Dict, List, Optional
+from typing import Any, cast, Optional
 from urllib.error import HTTPError

 from github_utils import gh_fetch_url, gh_post_pr_comment, gh_query_issues_by_labels
@@ -67,7 +67,7 @@ def get_release_version(onto_branch: str) -> Optional[str]:

 def get_tracker_issues(
     org: str, project: str, onto_branch: str
-) -> List[Dict[str, Any]]:
+) -> list[dict[str, Any]]:
     """
     Find the tracker issue from the repo. The tracker issue needs to have the title
     like [VERSION] Release Tracker following the convention on PyTorch
@@ -117,7 +117,7 @@ def cherry_pick(
                 continue

             res = cast(
-                Dict[str, Any],
+                dict[str, Any],
                 post_tracker_issue_comment(
                     org,
                     project,
@@ -220,7 +220,7 @@ def submit_pr(

 def post_pr_comment(
     org: str, project: str, pr_num: int, msg: str, dry_run: bool = False
-) -> List[Dict[str, Any]]:
+) -> list[dict[str, Any]]:
     """
     Post a comment on the PR itself to point to the cherry picking PR when success
     or print the error when failure
@@ -255,7 +255,7 @@ def post_tracker_issue_comment(
     classification: str,
     fixes: str,
     dry_run: bool = False,
-) -> List[Dict[str, Any]]:
+) -> list[dict[str, Any]]:
     """
     Post a comment on the tracker issue (if any) to record the cherry pick
     """
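Worth noting in the cast(...) hunk above: typing.cast evaluates its first argument at runtime, so the rewrite relies on dict[str, Any] being a real subscripted type on Python 3.9+. A small sketch (fetch_payload is a hypothetical stand-in for the GitHub call):

from typing import Any, cast

def fetch_payload() -> Any:
    # Hypothetical stand-in for an untyped API response.
    return {"number": 42, "state": "open"}

# cast() returns its second argument unchanged; the first argument only
# guides the type checker, and dict[str, Any] is valid there since 3.9.
payload = cast(dict[str, Any], fetch_payload())
print(payload["number"])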
.github/scripts/close_nonexistent_disable_issues.py (vendored)

@@ -6,7 +6,7 @@ import re
 import sys
 import tempfile
 from pathlib import Path
-from typing import Any, Dict, List, Tuple
+from typing import Any

 import requests
 from gitutils import retries_decorator
@@ -76,7 +76,7 @@ DISABLED_TESTS_JSON = (


 @retries_decorator()
-def query_db(query: str, params: Dict[str, Any]) -> List[Dict[str, Any]]:
+def query_db(query: str, params: dict[str, Any]) -> list[dict[str, Any]]:
     return query_clickhouse(query, params)


@@ -97,7 +97,7 @@ def download_log_worker(temp_dir: str, id: int, name: str) -> None:
         f.write(data)


-def printer(item: Tuple[str, Tuple[int, str, List[Any]]], extra: str) -> None:
+def printer(item: tuple[str, tuple[int, str, list[Any]]], extra: str) -> None:
     test, (_, link, _) = item
     print(f"{link:<55} {test:<120} {extra}")


@@ -120,8 +120,8 @@ def close_issue(num: int) -> None:


 def check_if_exists(
-    item: Tuple[str, Tuple[int, str, List[str]]], all_logs: List[str]
-) -> Tuple[bool, str]:
+    item: tuple[str, tuple[int, str, list[str]]], all_logs: list[str]
+) -> tuple[bool, str]:
     test, (_, link, _) = item
     # Test names should look like `test_a (module.path.classname)`
     reg = re.match(r"(\S+) \((\S*)\)", test)
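The printer/check_if_exists hunks show that nested aliases translate one-for-one: Tuple[str, Tuple[int, str, List[str]]] becomes tuple[str, tuple[int, str, list[str]]]. A hedged sketch of the same shape with invented data:

# A (test name, (issue number, link, captured logs)) pair, mirroring the
# item shape in the diff above; the values here are made up.
item: tuple[str, tuple[int, str, list[str]]] = (
    "test_add (test.test_ops.TestOps)",
    (12345, "https://github.com/pytorch/pytorch/issues/12345", []),
)

test, (num, link, logs) = item  # unpacking is unaffected by the annotation
print(f"{link:<55} {test}")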
.github/scripts/collect_ciflow_labels.py (vendored): 14 changed lines

@@ -2,7 +2,7 @@

 import sys
 from pathlib import Path
-from typing import Any, cast, Dict, List, Set
+from typing import Any, cast

 import yaml


@@ -10,9 +10,9 @@ import yaml
 GITHUB_DIR = Path(__file__).parent.parent


-def get_workflows_push_tags() -> Set[str]:
+def get_workflows_push_tags() -> set[str]:
     "Extract all known push tags from workflows"
-    rc: Set[str] = set()
+    rc: set[str] = set()
     for fname in (GITHUB_DIR / "workflows").glob("*.yml"):
         with fname.open("r") as f:
             wf_yml = yaml.safe_load(f)
@@ -25,19 +25,19 @@ def get_workflows_push_tags() -> Set[str]:
     return rc


-def filter_ciflow_tags(tags: Set[str]) -> List[str]:
+def filter_ciflow_tags(tags: set[str]) -> list[str]:
     "Return sorted list of ciflow tags"
     return sorted(
         tag[:-2] for tag in tags if tag.startswith("ciflow/") and tag.endswith("/*")
     )


-def read_probot_config() -> Dict[str, Any]:
+def read_probot_config() -> dict[str, Any]:
     with (GITHUB_DIR / "pytorch-probot.yml").open("r") as f:
-        return cast(Dict[str, Any], yaml.safe_load(f))
+        return cast(dict[str, Any], yaml.safe_load(f))


-def update_probot_config(labels: Set[str]) -> None:
+def update_probot_config(labels: set[str]) -> None:
     orig = read_probot_config()
     orig["ciflow_push_tags"] = filter_ciflow_tags(labels)
     with (GITHUB_DIR / "pytorch-probot.yml").open("w") as f:
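The same translation applies to set-typed helpers: Set[str] becomes set[str] in signatures and in variable annotations alike. A sketch of the filtering pattern above, with a hypothetical tag set:

def filter_ciflow(tags: set[str]) -> list[str]:
    # Keep the "ciflow/<name>/*" entries and strip the trailing "/*",
    # matching the sorted-list contract in the diff above.
    return sorted(
        tag[:-2] for tag in tags if tag.startswith("ciflow/") and tag.endswith("/*")
    )

print(filter_ciflow({"ciflow/trunk/*", "nightly", "ciflow/inductor/*"}))
# ['ciflow/inductor', 'ciflow/trunk']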
.github/scripts/delete_old_branches.py (vendored): 28 changed lines

@@ -4,7 +4,7 @@ import re
 from datetime import datetime
 from functools import lru_cache
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Set
+from typing import Any, Callable

 from github_utils import gh_fetch_json_dict, gh_graphql
 from gitutils import GitRepo
@@ -112,7 +112,7 @@ def convert_gh_timestamp(date: str) -> float:
     return datetime.strptime(date, "%Y-%m-%dT%H:%M:%SZ").timestamp()


-def get_branches(repo: GitRepo) -> Dict[str, Any]:
+def get_branches(repo: GitRepo) -> dict[str, Any]:
     # Query locally for branches, group by branch base name (e.g. gh/blah/base -> gh/blah), and get the most recent branch
     git_response = repo._run_git(
         "for-each-ref",
@@ -120,7 +120,7 @@ def get_branches(repo: GitRepo) -> Dict[str, Any]:
         "--format=%(refname) %(committerdate:iso-strict)",
         "refs/remotes/origin",
     )
-    branches_by_base_name: Dict[str, Any] = {}
+    branches_by_base_name: dict[str, Any] = {}
     for line in git_response.splitlines():
         branch, date = line.split(" ")
         re_branch = re.match(r"refs/remotes/origin/(.*)", branch)
@@ -140,14 +140,14 @@ def get_branches(repo: GitRepo) -> Dict[str, Any]:

 def paginate_graphql(
     query: str,
-    kwargs: Dict[str, Any],
-    termination_func: Callable[[List[Dict[str, Any]]], bool],
-    get_data: Callable[[Dict[str, Any]], List[Dict[str, Any]]],
-    get_page_info: Callable[[Dict[str, Any]], Dict[str, Any]],
-) -> List[Any]:
+    kwargs: dict[str, Any],
+    termination_func: Callable[[list[dict[str, Any]]], bool],
+    get_data: Callable[[dict[str, Any]], list[dict[str, Any]]],
+    get_page_info: Callable[[dict[str, Any]], dict[str, Any]],
+) -> list[Any]:
     hasNextPage = True
     endCursor = None
-    data: List[Dict[str, Any]] = []
+    data: list[dict[str, Any]] = []
     while hasNextPage:
         ESTIMATED_TOKENS[0] += 1
         res = gh_graphql(query, cursor=endCursor, **kwargs)
@@ -159,11 +159,11 @@ def paginate_graphql(
     return data


-def get_recent_prs() -> Dict[str, Any]:
+def get_recent_prs() -> dict[str, Any]:
     now = datetime.now().timestamp()

     # Grab all PRs updated in last CLOSED_PR_RETENTION days
-    pr_infos: List[Dict[str, Any]] = paginate_graphql(
+    pr_infos: list[dict[str, Any]] = paginate_graphql(
         GRAPHQL_ALL_PRS_BY_UPDATED_AT,
         {"owner": "pytorch", "repo": "pytorch"},
         lambda data: (
@@ -190,7 +190,7 @@ def get_recent_prs() -> Dict[str, Any]:


 @lru_cache(maxsize=1)
-def get_open_prs() -> List[Dict[str, Any]]:
+def get_open_prs() -> list[dict[str, Any]]:
     return paginate_graphql(
         GRAPHQL_OPEN_PRS,
         {"owner": "pytorch", "repo": "pytorch"},
@@ -200,8 +200,8 @@ def get_open_prs() -> List[Dict[str, Any]]:
     )


-def get_branches_with_magic_label_or_open_pr() -> Set[str]:
-    pr_infos: List[Dict[str, Any]] = paginate_graphql(
+def get_branches_with_magic_label_or_open_pr() -> set[str]:
+    pr_infos: list[dict[str, Any]] = paginate_graphql(
         GRAPHQL_NO_DELETE_BRANCH_LABEL,
         {"owner": "pytorch", "repo": "pytorch"},
         lambda data: False,
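paginate_graphql shows the new generics composing inside Callable, which itself still comes from typing here (collections.abc.Callable would also work on 3.9+). A reduced sketch with a fake page source; all names below are invented for illustration:

from typing import Any, Callable

def paginate(
    fetch_page: Callable[[int], list[dict[str, Any]]],
    stop: Callable[[list[dict[str, Any]]], bool],
) -> list[dict[str, Any]]:
    # Accumulate pages until the caller-provided predicate says stop.
    data: list[dict[str, Any]] = []
    page = 0
    while True:
        batch = fetch_page(page)
        data.extend(batch)
        if not batch or stop(data):
            return data
        page += 1

# Hypothetical two-page source for demonstration.
pages = [[{"id": 1}, {"id": 2}], [{"id": 3}], []]
print(paginate(lambda i: pages[i], lambda d: len(d) >= 3))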
.github/scripts/file_io_utils.py (vendored): 4 changed lines

@@ -2,7 +2,7 @@ import json
 import re
 import shutil
 from pathlib import Path
-from typing import Any, List
+from typing import Any

 import boto3  # type: ignore[import]


@@ -77,7 +77,7 @@ def upload_file_to_s3(file_name: Path, bucket: str, key: str) -> None:

 def download_s3_objects_with_prefix(
     bucket_name: str, prefix: str, download_folder: Path
-) -> List[Path]:
+) -> list[Path]:
     s3 = boto3.resource("s3")
     bucket = s3.Bucket(bucket_name)

.github/scripts/filter_test_configs.py (vendored): 60 changed lines

@@ -8,9 +8,9 @@ import subprocess
 import sys
 import warnings
 from enum import Enum
-from functools import lru_cache
+from functools import cache
 from logging import info
-from typing import Any, Callable, Dict, List, Optional, Set
+from typing import Any, Callable, Optional
 from urllib.request import Request, urlopen

 import yaml
@@ -32,7 +32,7 @@ def is_cuda_or_rocm_job(job_name: Optional[str]) -> bool:

 # Supported modes when running periodically. Only applying the mode when
 # its lambda condition returns true
-SUPPORTED_PERIODICAL_MODES: Dict[str, Callable[[Optional[str]], bool]] = {
+SUPPORTED_PERIODICAL_MODES: dict[str, Callable[[Optional[str]], bool]] = {
     # Memory leak check is only needed for CUDA and ROCm jobs which utilize GPU memory
     "mem_leak_check": is_cuda_or_rocm_job,
     "rerun_disabled_tests": lambda job_name: True,
@@ -102,8 +102,8 @@ def parse_args() -> Any:
     return parser.parse_args()


-@lru_cache(maxsize=None)
-def get_pr_info(pr_number: int) -> Dict[str, Any]:
+@cache
+def get_pr_info(pr_number: int) -> dict[str, Any]:
     """
     Dynamically get PR information
     """
@@ -116,7 +116,7 @@ def get_pr_info(pr_number: int) -> Dict[str, Any]:
         "Accept": "application/vnd.github.v3+json",
         "Authorization": f"token {github_token}",
     }
-    json_response: Dict[str, Any] = download_json(
+    json_response: dict[str, Any] = download_json(
         url=f"{pytorch_github_api}/issues/{pr_number}",
         headers=headers,
     )
@@ -128,7 +128,7 @@ def get_pr_info(pr_number: int) -> Dict[str, Any]:
     return json_response


-def get_labels(pr_number: int) -> Set[str]:
+def get_labels(pr_number: int) -> set[str]:
     """
     Dynamically get the latest list of labels from the pull request
     """
@@ -138,14 +138,14 @@ def get_labels(pr_number: int) -> Set[str]:
     }


-def filter_labels(labels: Set[str], label_regex: Any) -> Set[str]:
+def filter_labels(labels: set[str], label_regex: Any) -> set[str]:
     """
     Return the list of matching labels
     """
     return {l for l in labels if re.match(label_regex, l)}


-def filter(test_matrix: Dict[str, List[Any]], labels: Set[str]) -> Dict[str, List[Any]]:
+def filter(test_matrix: dict[str, list[Any]], labels: set[str]) -> dict[str, list[Any]]:
     """
     Select the list of test config to run from the test matrix. The logic works
     as follows:
@@ -157,7 +157,7 @@ def filter(test_matrix: Dict[str, List[Any]], labels: Set[str]) -> Dict[str, Lis

     If the PR has none of the test-config label, all tests are run as usual.
     """
-    filtered_test_matrix: Dict[str, List[Any]] = {"include": []}
+    filtered_test_matrix: dict[str, list[Any]] = {"include": []}

     for entry in test_matrix.get("include", []):
         config_name = entry.get("config", "")
@@ -185,8 +185,8 @@ def filter(test_matrix: Dict[str, List[Any]], labels: Set[str]) -> Dict[str, Lis


 def filter_selected_test_configs(
-    test_matrix: Dict[str, List[Any]], selected_test_configs: Set[str]
-) -> Dict[str, List[Any]]:
+    test_matrix: dict[str, list[Any]], selected_test_configs: set[str]
+) -> dict[str, list[Any]]:
     """
     Keep only the selected configs if the list if not empty. Otherwise, keep all test configs.
     This filter is used when the workflow is dispatched manually.
@@ -194,7 +194,7 @@ def filter_selected_test_configs(
     if not selected_test_configs:
         return test_matrix

-    filtered_test_matrix: Dict[str, List[Any]] = {"include": []}
+    filtered_test_matrix: dict[str, list[Any]] = {"include": []}
     for entry in test_matrix.get("include", []):
         config_name = entry.get("config", "")
         if not config_name:
@@ -207,12 +207,12 @@ def filter_selected_test_configs(


 def set_periodic_modes(
-    test_matrix: Dict[str, List[Any]], job_name: Optional[str]
-) -> Dict[str, List[Any]]:
+    test_matrix: dict[str, list[Any]], job_name: Optional[str]
+) -> dict[str, list[Any]]:
     """
     Apply all periodic modes when running under a schedule
     """
-    scheduled_test_matrix: Dict[str, List[Any]] = {
+    scheduled_test_matrix: dict[str, list[Any]] = {
         "include": [],
     }

@@ -229,8 +229,8 @@ def set_periodic_modes(


 def mark_unstable_jobs(
-    workflow: str, job_name: str, test_matrix: Dict[str, List[Any]]
-) -> Dict[str, List[Any]]:
+    workflow: str, job_name: str, test_matrix: dict[str, list[Any]]
+) -> dict[str, list[Any]]:
     """
     Check the list of unstable jobs and mark them accordingly. Note that if a job
     is unstable, all its dependents will also be marked accordingly
@@ -245,8 +245,8 @@ def mark_unstable_jobs(


 def remove_disabled_jobs(
-    workflow: str, job_name: str, test_matrix: Dict[str, List[Any]]
-) -> Dict[str, List[Any]]:
+    workflow: str, job_name: str, test_matrix: dict[str, list[Any]]
+) -> dict[str, list[Any]]:
     """
     Check the list of disabled jobs, remove the current job and all its dependents
     if it exists in the list
@@ -261,15 +261,15 @@ def remove_disabled_jobs(


 def _filter_jobs(
-    test_matrix: Dict[str, List[Any]],
+    test_matrix: dict[str, list[Any]],
     issue_type: IssueType,
     target_cfg: Optional[str] = None,
-) -> Dict[str, List[Any]]:
+) -> dict[str, list[Any]]:
     """
     An utility function used to actually apply the job filter
     """
     # The result will be stored here
-    filtered_test_matrix: Dict[str, List[Any]] = {"include": []}
+    filtered_test_matrix: dict[str, list[Any]] = {"include": []}

     # This is an issue to disable a CI job
     if issue_type == IssueType.DISABLED:
@@ -302,10 +302,10 @@ def _filter_jobs(
 def process_jobs(
     workflow: str,
     job_name: str,
-    test_matrix: Dict[str, List[Any]],
+    test_matrix: dict[str, list[Any]],
     issue_type: IssueType,
     url: str,
-) -> Dict[str, List[Any]]:
+) -> dict[str, list[Any]]:
     """
     Both disabled and unstable jobs are in the following format:

@@ -441,7 +441,7 @@ def process_jobs(
     return test_matrix


-def download_json(url: str, headers: Dict[str, str], num_retries: int = 3) -> Any:
+def download_json(url: str, headers: dict[str, str], num_retries: int = 3) -> Any:
     for _ in range(num_retries):
         try:
             req = Request(url=url, headers=headers)
@@ -462,7 +462,7 @@ def set_output(name: str, val: Any) -> None:
     print(f"::set-output name={name}::{val}")


-def parse_reenabled_issues(s: Optional[str]) -> List[str]:
+def parse_reenabled_issues(s: Optional[str]) -> list[str]:
     # NB: When the PR body is empty, GitHub API returns a None value, which is
     # passed into this function
     if not s:
@@ -477,7 +477,7 @@ def parse_reenabled_issues(s: Optional[str]) -> List[str]:
     return issue_numbers


-def get_reenabled_issues(pr_body: str = "") -> List[str]:
+def get_reenabled_issues(pr_body: str = "") -> list[str]:
     default_branch = f"origin/{os.environ.get('GIT_DEFAULT_BRANCH', 'main')}"
     try:
         commit_messages = subprocess.check_output(
@@ -489,12 +489,12 @@ def get_reenabled_issues(pr_body: str = "") -> List[str]:
     return parse_reenabled_issues(pr_body) + parse_reenabled_issues(commit_messages)


-def check_for_setting(labels: Set[str], body: str, setting: str) -> bool:
+def check_for_setting(labels: set[str], body: str, setting: str) -> bool:
     return setting in labels or f"[{setting}]" in body


 def perform_misc_tasks(
-    labels: Set[str], test_matrix: Dict[str, List[Any]], job_name: str, pr_body: str
+    labels: set[str], test_matrix: dict[str, list[Any]], job_name: str, pr_body: str
 ) -> None:
     """
     In addition to apply the filter logic, the script also does the following
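Besides the generics, this file swaps @lru_cache(maxsize=None) for @cache. functools.cache (Python 3.9+) is a thin alias for lru_cache(maxsize=None), i.e. an unbounded memoizer, so the behavior is unchanged. A sketch:

from functools import cache

@cache
def fib(n: int) -> int:
    # Memoized; each distinct n is computed once.
    return n if n < 2 else fib(n - 1) + fib(n - 2)

print(fib(40))           # fast, thanks to the unbounded cache
print(fib.cache_info())  # same inspection API as lru_cache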
.github/scripts/generate_binary_build_matrix.py (vendored): 22 changed lines

@@ -12,7 +12,7 @@ architectures:
 """

 import os
-from typing import Dict, List, Optional, Tuple
+from typing import Optional


 # NOTE: Also update the CUDA sources in tools/nightly.py when changing this list
@@ -181,7 +181,7 @@ CXX11_ABI = "cxx11-abi"
 RELEASE = "release"
 DEBUG = "debug"

-LIBTORCH_CONTAINER_IMAGES: Dict[Tuple[str, str], str] = {
+LIBTORCH_CONTAINER_IMAGES: dict[tuple[str, str], str] = {
     **{
         (
             gpu_arch,
@@ -223,16 +223,16 @@ def translate_desired_cuda(gpu_arch_type: str, gpu_arch_version: str) -> str:
     }.get(gpu_arch_type, gpu_arch_version)


-def list_without(in_list: List[str], without: List[str]) -> List[str]:
+def list_without(in_list: list[str], without: list[str]) -> list[str]:
     return [item for item in in_list if item not in without]


 def generate_libtorch_matrix(
     os: str,
     abi_version: str,
-    arches: Optional[List[str]] = None,
-    libtorch_variants: Optional[List[str]] = None,
-) -> List[Dict[str, str]]:
+    arches: Optional[list[str]] = None,
+    libtorch_variants: Optional[list[str]] = None,
+) -> list[dict[str, str]]:
     if arches is None:
         arches = ["cpu"]
         if os == "linux":
@@ -248,7 +248,7 @@ def generate_libtorch_matrix(
             "static-without-deps",
         ]

-    ret: List[Dict[str, str]] = []
+    ret: list[dict[str, str]] = []
     for arch_version in arches:
         for libtorch_variant in libtorch_variants:
             # one of the values in the following list must be exactly
@@ -287,10 +287,10 @@ def generate_libtorch_matrix(

 def generate_wheels_matrix(
     os: str,
-    arches: Optional[List[str]] = None,
-    python_versions: Optional[List[str]] = None,
+    arches: Optional[list[str]] = None,
+    python_versions: Optional[list[str]] = None,
     use_split_build: bool = False,
-) -> List[Dict[str, str]]:
+) -> list[dict[str, str]]:
     package_type = "wheel"
     if os == "linux" or os == "linux-aarch64" or os == "linux-s390x":
         # NOTE: We only build manywheel packages for x86_64 and aarch64 and s390x linux
@@ -315,7 +315,7 @@ def generate_wheels_matrix(
         # uses different build/test scripts
         arches = ["cpu-s390x"]

-    ret: List[Dict[str, str]] = []
+    ret: list[dict[str, str]] = []
     for python_version in python_versions:
         for arch_version in arches:
             gpu_arch_type = arch_type(arch_version)
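The LIBTORCH_CONTAINER_IMAGES hunk annotates a dict keyed by a tuple; tuple[str, str] works in that position exactly as Tuple[str, str] did. A hedged sketch with invented image names:

# Hypothetical (arch, abi) -> container image map, same shape as the diff.
images: dict[tuple[str, str], str] = {
    ("cpu", "cxx11-abi"): "example.io/libtorch-cpu:latest",
    ("cuda", "cxx11-abi"): "example.io/libtorch-cuda:latest",
}
print(images[("cpu", "cxx11-abi")])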
.github/scripts/generate_ci_workflows.py (vendored): 7 changed lines

@@ -2,9 +2,10 @@

 import os
 import sys
+from collections.abc import Iterable
 from dataclasses import asdict, dataclass, field
 from pathlib import Path
-from typing import Dict, Iterable, List, Literal, Set
+from typing import Literal
 from typing_extensions import TypedDict  # Python 3.11+

 import generate_binary_build_matrix  # type: ignore[import]
@@ -27,7 +28,7 @@ LABEL_CIFLOW_BINARIES_WHEEL = "ciflow/binaries_wheel"
 class CIFlowConfig:
     # For use to enable workflows to run on pytorch/pytorch-canary
     run_on_canary: bool = False
-    labels: Set[str] = field(default_factory=set)
+    labels: set[str] = field(default_factory=set)
     # Certain jobs might not want to be part of the ciflow/[all,trunk] workflow
     isolated_workflow: bool = False
     unstable: bool = False
@@ -48,7 +49,7 @@ class Config(TypedDict):
 @dataclass
 class BinaryBuildWorkflow:
     os: str
-    build_configs: List[Dict[str, str]]
+    build_configs: list[dict[str, str]]
     package_type: str

     # Optional fields
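Here Iterable moves from typing to collections.abc, which PEP 585 makes the canonical home for the generic ABCs. A sketch:

from collections.abc import Iterable

def join_labels(labels: Iterable[str]) -> str:
    # Accepts any iterable of strings: list, set, generator, ...
    return ", ".join(sorted(labels))

print(join_labels({"ciflow/trunk", "ciflow/periodic"}))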
.github/scripts/get_workflow_job_id.py (vendored): 10 changed lines

@@ -11,11 +11,11 @@ import sys
 import time
 import urllib
 import urllib.parse
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Optional
 from urllib.request import Request, urlopen


-def parse_json_and_links(conn: Any) -> Tuple[Any, Dict[str, Dict[str, str]]]:
+def parse_json_and_links(conn: Any) -> tuple[Any, dict[str, dict[str, str]]]:
     links = {}
     # Extract links which GH uses for pagination
     # see https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Link
@@ -42,7 +42,7 @@ def parse_json_and_links(conn: Any) -> Tuple[Any, Dict[str, Dict[str, str]]]:
 def fetch_url(
     url: str,
     *,
-    headers: Optional[Dict[str, str]] = None,
+    headers: Optional[dict[str, str]] = None,
     reader: Callable[[Any], Any] = lambda x: x.read(),
     retries: Optional[int] = 3,
     backoff_timeout: float = 0.5,
@@ -83,7 +83,7 @@ def parse_args() -> Any:
     return parser.parse_args()


-def fetch_jobs(url: str, headers: Dict[str, str]) -> List[Dict[str, str]]:
+def fetch_jobs(url: str, headers: dict[str, str]) -> list[dict[str, str]]:
     response, links = fetch_url(url, headers=headers, reader=parse_json_and_links)
     jobs = response["jobs"]
     assert type(jobs) is list
@@ -111,7 +111,7 @@ def fetch_jobs(url: str, headers: Dict[str, str]) -> List[Dict[str, str]]:
 # running.


-def find_job_id_name(args: Any) -> Tuple[str, str]:
+def find_job_id_name(args: Any) -> tuple[str, str]:
     # From https://docs.github.com/en/actions/learn-github-actions/environment-variables
     PYTORCH_REPO = os.environ.get("GITHUB_REPOSITORY", "pytorch/pytorch")
     PYTORCH_GITHUB_API = f"https://api.github.com/repos/{PYTORCH_REPO}"
.github/scripts/github_utils.py (vendored): 54 changed lines

@@ -4,7 +4,7 @@ import json
 import os
 import warnings
 from dataclasses import dataclass
-from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, cast, Optional, Union
 from urllib.error import HTTPError
 from urllib.parse import quote
 from urllib.request import Request, urlopen
@@ -27,11 +27,11 @@ class GitHubComment:
 def gh_fetch_url_and_headers(
     url: str,
     *,
-    headers: Optional[Dict[str, str]] = None,
-    data: Union[Optional[Dict[str, Any]], str] = None,
+    headers: Optional[dict[str, str]] = None,
+    data: Union[Optional[dict[str, Any]], str] = None,
     method: Optional[str] = None,
     reader: Callable[[Any], Any] = lambda x: x.read(),
-) -> Tuple[Any, Any]:
+) -> tuple[Any, Any]:
     if headers is None:
         headers = {}
     token = os.environ.get("GITHUB_TOKEN")
@@ -70,8 +70,8 @@ def gh_fetch_url_and_headers(
 def gh_fetch_url(
     url: str,
     *,
-    headers: Optional[Dict[str, str]] = None,
-    data: Union[Optional[Dict[str, Any]], str] = None,
+    headers: Optional[dict[str, str]] = None,
+    data: Union[Optional[dict[str, Any]], str] = None,
     method: Optional[str] = None,
     reader: Callable[[Any], Any] = json.load,
 ) -> Any:
@@ -82,25 +82,25 @@ def gh_fetch_url(

 def gh_fetch_json(
     url: str,
-    params: Optional[Dict[str, Any]] = None,
-    data: Optional[Dict[str, Any]] = None,
+    params: Optional[dict[str, Any]] = None,
+    data: Optional[dict[str, Any]] = None,
     method: Optional[str] = None,
-) -> List[Dict[str, Any]]:
+) -> list[dict[str, Any]]:
     headers = {"Accept": "application/vnd.github.v3+json"}
     if params is not None and len(params) > 0:
         url += "?" + "&".join(
             f"{name}={quote(str(val))}" for name, val in params.items()
         )
     return cast(
-        List[Dict[str, Any]],
+        list[dict[str, Any]],
         gh_fetch_url(url, headers=headers, data=data, reader=json.load, method=method),
     )


 def _gh_fetch_json_any(
     url: str,
-    params: Optional[Dict[str, Any]] = None,
-    data: Optional[Dict[str, Any]] = None,
+    params: Optional[dict[str, Any]] = None,
+    data: Optional[dict[str, Any]] = None,
 ) -> Any:
     headers = {"Accept": "application/vnd.github.v3+json"}
     if params is not None and len(params) > 0:
@@ -112,21 +112,21 @@ def _gh_fetch_json_any(

 def gh_fetch_json_list(
     url: str,
-    params: Optional[Dict[str, Any]] = None,
-    data: Optional[Dict[str, Any]] = None,
-) -> List[Dict[str, Any]]:
-    return cast(List[Dict[str, Any]], _gh_fetch_json_any(url, params, data))
+    params: Optional[dict[str, Any]] = None,
+    data: Optional[dict[str, Any]] = None,
+) -> list[dict[str, Any]]:
+    return cast(list[dict[str, Any]], _gh_fetch_json_any(url, params, data))


 def gh_fetch_json_dict(
     url: str,
-    params: Optional[Dict[str, Any]] = None,
-    data: Optional[Dict[str, Any]] = None,
-) -> Dict[str, Any]:
-    return cast(Dict[str, Any], _gh_fetch_json_any(url, params, data))
+    params: Optional[dict[str, Any]] = None,
+    data: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+    return cast(dict[str, Any], _gh_fetch_json_any(url, params, data))


-def gh_graphql(query: str, **kwargs: Any) -> Dict[str, Any]:
+def gh_graphql(query: str, **kwargs: Any) -> dict[str, Any]:
     rc = gh_fetch_url(
         "https://api.github.com/graphql",
         data={"query": query, "variables": kwargs},
@@ -136,12 +136,12 @@ def gh_graphql(query: str, **kwargs: Any) -> Dict[str, Any]:
         raise RuntimeError(
             f"GraphQL query {query}, args {kwargs} failed: {rc['errors']}"
         )
-    return cast(Dict[str, Any], rc)
+    return cast(dict[str, Any], rc)


 def _gh_post_comment(
     url: str, comment: str, dry_run: bool = False
-) -> List[Dict[str, Any]]:
+) -> list[dict[str, Any]]:
     if dry_run:
         print(comment)
         return []
@@ -150,7 +150,7 @@ def _gh_post_comment(

 def gh_post_pr_comment(
     org: str, repo: str, pr_num: int, comment: str, dry_run: bool = False
-) -> List[Dict[str, Any]]:
+) -> list[dict[str, Any]]:
     return _gh_post_comment(
         f"{GITHUB_API_URL}/repos/{org}/{repo}/issues/{pr_num}/comments",
         comment,
@@ -160,7 +160,7 @@ def gh_post_pr_comment(

 def gh_post_commit_comment(
     org: str, repo: str, sha: str, comment: str, dry_run: bool = False
-) -> List[Dict[str, Any]]:
+) -> list[dict[str, Any]]:
     return _gh_post_comment(
         f"{GITHUB_API_URL}/repos/{org}/{repo}/commits/{sha}/comments",
         comment,
@@ -220,8 +220,8 @@ def gh_update_pr_state(org: str, repo: str, pr_num: int, state: str = "open") ->


 def gh_query_issues_by_labels(
-    org: str, repo: str, labels: List[str], state: str = "open"
-) -> List[Dict[str, Any]]:
+    org: str, repo: str, labels: list[str], state: str = "open"
+) -> list[dict[str, Any]]:
     url = f"{GITHUB_API_URL}/repos/{org}/{repo}/issues"
     return gh_fetch_json(
         url, method="GET", params={"labels": ",".join(labels), "state": state}
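Note what the github_utils hunks do not touch: Optional[...] and Union[...] stay imported from typing. PEP 585 covers the container generics only; the X | None spelling (PEP 604) needs Python 3.10 and is not adopted here. A sketch of that boundary (fetch is a hypothetical signature, not the module's API):

from typing import Any, Optional, Union

# PEP 585 applies to the containers...
def fetch(
    headers: Optional[dict[str, str]] = None,
    data: Union[dict[str, Any], str, None] = None,
) -> list[dict[str, Any]]:
    # Hypothetical body; a real client would issue an HTTP request here.
    return []

# ...while Optional/Union keep their typing spelling on 3.9.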
.github/scripts/gitutils.py (vendored): 42 changed lines

@@ -4,20 +4,10 @@ import os
 import re
 import tempfile
 from collections import defaultdict
+from collections.abc import Iterator
 from datetime import datetime
 from functools import wraps
-from typing import (
-    Any,
-    Callable,
-    cast,
-    Dict,
-    Iterator,
-    List,
-    Optional,
-    Tuple,
-    TypeVar,
-    Union,
-)
+from typing import Any, Callable, cast, Optional, TypeVar, Union


 T = TypeVar("T")
@@ -35,17 +25,17 @@ def get_git_repo_dir() -> str:
     return os.getenv("GIT_REPO_DIR", str(Path(__file__).resolve().parents[2]))


-def fuzzy_list_to_dict(items: List[Tuple[str, str]]) -> Dict[str, List[str]]:
+def fuzzy_list_to_dict(items: list[tuple[str, str]]) -> dict[str, list[str]]:
     """
     Converts list to dict preserving elements with duplicate keys
     """
-    rc: Dict[str, List[str]] = defaultdict(list)
+    rc: dict[str, list[str]] = defaultdict(list)
     for key, val in items:
         rc[key].append(val)
     return dict(rc)


-def _check_output(items: List[str], encoding: str = "utf-8") -> str:
+def _check_output(items: list[str], encoding: str = "utf-8") -> str:
     from subprocess import CalledProcessError, check_output, STDOUT

     try:
@@ -95,7 +85,7 @@ class GitCommit:
         return item in self.body or item in self.title


-def parse_fuller_format(lines: Union[str, List[str]]) -> GitCommit:
+def parse_fuller_format(lines: Union[str, list[str]]) -> GitCommit:
     """
     Expect commit message generated using `--format=fuller --date=unix` format, i.e.:
         commit <sha1>
@@ -142,13 +132,13 @@ class GitRepo:
             print(f"+ git -C {self.repo_dir} {' '.join(args)}")
         return _check_output(["git", "-C", self.repo_dir] + list(args))

-    def revlist(self, revision_range: str) -> List[str]:
+    def revlist(self, revision_range: str) -> list[str]:
         rc = self._run_git("rev-list", revision_range, "--", ".").strip()
         return rc.split("\n") if len(rc) > 0 else []

     def branches_containing_ref(
         self, ref: str, *, include_remote: bool = True
-    ) -> List[str]:
+    ) -> list[str]:
         rc = (
             self._run_git("branch", "--remote", "--contains", ref)
             if include_remote
@@ -189,7 +179,7 @@ class GitRepo:
     def get_merge_base(self, from_ref: str, to_ref: str) -> str:
         return self._run_git("merge-base", from_ref, to_ref).strip()

-    def patch_id(self, ref: Union[str, List[str]]) -> List[Tuple[str, str]]:
+    def patch_id(self, ref: Union[str, list[str]]) -> list[tuple[str, str]]:
         is_list = isinstance(ref, list)
         if is_list:
             if len(ref) == 0:
@@ -198,9 +188,9 @@ class GitRepo:
         rc = _check_output(
             ["sh", "-c", f"git -C {self.repo_dir} show {ref}|git patch-id --stable"]
         ).strip()
-        return [cast(Tuple[str, str], x.split(" ", 1)) for x in rc.split("\n")]
+        return [cast(tuple[str, str], x.split(" ", 1)) for x in rc.split("\n")]

-    def commits_resolving_gh_pr(self, pr_num: int) -> List[str]:
+    def commits_resolving_gh_pr(self, pr_num: int) -> list[str]:
         owner, name = self.gh_owner_and_name()
         msg = f"Pull Request resolved: https://github.com/{owner}/{name}/pull/{pr_num}"
         rc = self._run_git("log", "--format=%H", "--grep", msg).strip()
@@ -219,7 +209,7 @@ class GitRepo:

     def compute_branch_diffs(
         self, from_branch: str, to_branch: str
-    ) -> Tuple[List[str], List[str]]:
+    ) -> tuple[list[str], list[str]]:
         """
         Returns list of commmits that are missing in each other branch since their merge base
         Might be slow if merge base is between two branches is pretty far off
@@ -311,14 +301,14 @@ class GitRepo:
     def remote_url(self) -> str:
         return self._run_git("remote", "get-url", self.remote)

-    def gh_owner_and_name(self) -> Tuple[str, str]:
+    def gh_owner_and_name(self) -> tuple[str, str]:
         url = os.getenv("GIT_REMOTE_URL", None)
         if url is None:
             url = self.remote_url()
         rc = RE_GITHUB_URL_MATCH.match(url)
         if rc is None:
             raise RuntimeError(f"Unexpected url format {url}")
-        return cast(Tuple[str, str], rc.groups())
+        return cast(tuple[str, str], rc.groups())

     def commit_message(self, ref: str) -> str:
         return self._run_git("log", "-1", "--format=%B", ref)
@@ -366,7 +356,7 @@ class PeekableIterator(Iterator[str]):
         return rc


-def patterns_to_regex(allowed_patterns: List[str]) -> Any:
+def patterns_to_regex(allowed_patterns: list[str]) -> Any:
     """
     pattern is glob-like, i.e. the only special sequences it has are:
       - ? - matches single character
@@ -437,7 +427,7 @@ def retries_decorator(
 ) -> Callable[[Callable[..., T]], Callable[..., T]]:
     def decorator(f: Callable[..., T]) -> Callable[..., T]:
         @wraps(f)
-        def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> T:
+        def wrapper(*args: list[Any], **kwargs: dict[str, Any]) -> T:
             for idx in range(num_retries):
                 try:
                     return f(*args, **kwargs)
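gitutils keeps a class deriving from Iterator[str]; with the new import that base is collections.abc.Iterator, which is both subscriptable and subclassable at runtime on 3.9+. A simplified sketch of the idea (not the file's PeekableIterator implementation):

from collections.abc import Iterator

class Repeater(Iterator[str]):
    # Minimal Iterator[str] subclass: yields a word a fixed number of times.
    # The ABC supplies __iter__ once __next__ is defined.
    def __init__(self, word: str, times: int) -> None:
        self.word, self.left = word, times

    def __next__(self) -> str:
        if self.left == 0:
            raise StopIteration
        self.left -= 1
        return self.word

print(list(Repeater("hi", 3)))  # ['hi', 'hi', 'hi']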
.github/scripts/label_utils.py (vendored): 14 changed lines

@@ -2,7 +2,7 @@

 import json
 from functools import lru_cache
-from typing import Any, List, Tuple, TYPE_CHECKING, Union
+from typing import Any, TYPE_CHECKING, Union

 from github_utils import gh_fetch_url_and_headers, GitHubComment


@@ -28,14 +28,14 @@ https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for
 """


-def request_for_labels(url: str) -> Tuple[Any, Any]:
+def request_for_labels(url: str) -> tuple[Any, Any]:
     headers = {"Accept": "application/vnd.github.v3+json"}
     return gh_fetch_url_and_headers(
         url, headers=headers, reader=lambda x: x.read().decode("utf-8")
     )


-def update_labels(labels: List[str], info: str) -> None:
+def update_labels(labels: list[str], info: str) -> None:
     labels_json = json.loads(info)
     labels.extend([x["name"] for x in labels_json])


@@ -56,10 +56,10 @@ def get_last_page_num_from_header(header: Any) -> int:


 @lru_cache
-def gh_get_labels(org: str, repo: str) -> List[str]:
+def gh_get_labels(org: str, repo: str) -> list[str]:
     prefix = f"https://api.github.com/repos/{org}/{repo}/labels?per_page=100"
     header, info = request_for_labels(prefix + "&page=1")
-    labels: List[str] = []
+    labels: list[str] = []
     update_labels(labels, info)

     last_page = get_last_page_num_from_header(header)
@@ -74,7 +74,7 @@ def gh_get_labels(org: str, repo: str) -> List[str]:


 def gh_add_labels(
-    org: str, repo: str, pr_num: int, labels: Union[str, List[str]], dry_run: bool
+    org: str, repo: str, pr_num: int, labels: Union[str, list[str]], dry_run: bool
 ) -> None:
     if dry_run:
         print(f"Dryrun: Adding labels {labels} to PR {pr_num}")
@@ -97,7 +97,7 @@ def gh_remove_label(
     )


-def get_release_notes_labels(org: str, repo: str) -> List[str]:
+def get_release_notes_labels(org: str, repo: str) -> list[str]:
     return [
         label
         for label in gh_get_labels(org, repo)
.github/scripts/pytest_caching_utils.py (vendored): 6 changed lines

@@ -1,7 +1,7 @@
 import hashlib
 import os
 from pathlib import Path
-from typing import Dict, NamedTuple
+from typing import NamedTuple

 from file_io_utils import (
     copy_file,
@@ -219,8 +219,8 @@ def _merge_lastfailed_files(source_pytest_cache: Path, dest_pytest_cache: Path)


 def _merged_lastfailed_content(
-    from_lastfailed: Dict[str, bool], to_lastfailed: Dict[str, bool]
-) -> Dict[str, bool]:
+    from_lastfailed: dict[str, bool], to_lastfailed: dict[str, bool]
+) -> dict[str, bool]:
     """
     The lastfailed files are dictionaries where the key is the test identifier.
     Each entry's value appears to always be `true`, but let's not count on that.
.github/scripts/runner_determinator.py (vendored): 31 changed lines

@@ -61,9 +61,10 @@ import random
 import re
 import sys
 from argparse import ArgumentParser
-from functools import lru_cache
+from collections.abc import Iterable
+from functools import cache
 from logging import LogRecord
-from typing import Any, Dict, FrozenSet, Iterable, List, NamedTuple, Set, Tuple
+from typing import Any, NamedTuple
 from urllib.request import Request, urlopen

 import yaml
@@ -105,7 +106,7 @@ class Settings(NamedTuple):
     Settings for the experiments that can be opted into.
     """

-    experiments: Dict[str, Experiment] = {}
+    experiments: dict[str, Experiment] = {}


 class ColorFormatter(logging.Formatter):
@@ -150,7 +151,7 @@ def set_github_output(key: str, value: str) -> None:
         f.write(f"{key}={value}\n")


-def _str_comma_separated_to_set(value: str) -> FrozenSet[str]:
+def _str_comma_separated_to_set(value: str) -> frozenset[str]:
     return frozenset(
         filter(lambda itm: itm != "", map(str.strip, value.strip(" \n\t").split(",")))
     )
@@ -208,12 +209,12 @@ def parse_args() -> Any:
     return parser.parse_args()


-def get_gh_client(github_token: str) -> Github:
+def get_gh_client(github_token: str) -> Github:  # type: ignore[no-any-unimported]
     auth = Auth.Token(github_token)
     return Github(auth=auth)


-def get_issue(gh: Github, repo: str, issue_num: int) -> Issue:
+def get_issue(gh: Github, repo: str, issue_num: int) -> Issue:  # type: ignore[no-any-unimported]
     repo = gh.get_repo(repo)
     return repo.get_issue(number=issue_num)


@@ -242,7 +243,7 @@ def get_potential_pr_author(
             raise Exception(  # noqa: TRY002
                 f"issue with pull request {pr_number} from repo {repository}"
             ) from e
-        return pull.user.login
+        return pull.user.login  # type: ignore[no-any-return]
     # In all other cases, return the original input username
     return username

@@ -263,7 +264,7 @@ def load_yaml(yaml_text: str) -> Any:
         raise


-def extract_settings_user_opt_in_from_text(rollout_state: str) -> Tuple[str, str]:
+def extract_settings_user_opt_in_from_text(rollout_state: str) -> tuple[str, str]:
     """
     Extracts the text with settings, if any, and the opted in users from the rollout state.

@@ -279,7 +280,7 @@ def extract_settings_user_opt_in_from_text(rollout_state: str) -> Tuple[str, str
     return "", rollout_state


-class UserOptins(Dict[str, List[str]]):
+class UserOptins(dict[str, list[str]]):
     """
     Dictionary of users with a list of features they have opted into
     """
@@ -420,7 +421,7 @@ def get_runner_prefix(
     rollout_state: str,
     workflow_requestors: Iterable[str],
     branch: str,
-    eligible_experiments: FrozenSet[str] = frozenset(),
+    eligible_experiments: frozenset[str] = frozenset(),
     is_canary: bool = False,
 ) -> str:
     settings = parse_settings(rollout_state)
@@ -519,7 +520,7 @@ def get_rollout_state_from_issue(github_token: str, repo: str, issue_num: int) -
     return str(issue.get_comments()[0].body.strip("\n\t "))


-def download_json(url: str, headers: Dict[str, str], num_retries: int = 3) -> Any:
+def download_json(url: str, headers: dict[str, str], num_retries: int = 3) -> Any:
     for _ in range(num_retries):
         try:
             req = Request(url=url, headers=headers)
@@ -532,8 +533,8 @@ def download_json(url: str, headers: Dict[str, str], num_retries: int = 3) -> An
     return {}


-@lru_cache(maxsize=None)
-def get_pr_info(github_repo: str, github_token: str, pr_number: int) -> Dict[str, Any]:
+@cache
+def get_pr_info(github_repo: str, github_token: str, pr_number: int) -> dict[str, Any]:
     """
     Dynamically get PR information
     """
@@ -542,7 +543,7 @@ def get_pr_info(github_repo: str, github_token: str, pr_number: int) -> Dict[str
         "Accept": "application/vnd.github.v3+json",
         "Authorization": f"token {github_token}",
     }
-    json_response: Dict[str, Any] = download_json(
+    json_response: dict[str, Any] = download_json(
         url=f"{github_api}/issues/{pr_number}",
         headers=headers,
     )
@@ -554,7 +555,7 @@ def get_pr_info(github_repo: str, github_token: str, pr_number: int) -> Dict[str
     return json_response


-def get_labels(github_repo: str, github_token: str, pr_number: int) -> Set[str]:
+def get_labels(github_repo: str, github_token: str, pr_number: int) -> set[str]:
     """
     Dynamically get the latest list of labels from the pull request
     """
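runner_determinator's UserOptins shows that classes can now inherit directly from parametrized built-ins: dict[str, list[str]] is a valid base on Python 3.9+ (the subscript produces a types.GenericAlias, and plain dict is what actually gets subclassed at runtime). A sketch:

class Optins(dict[str, list[str]]):
    # Behaves exactly like dict at runtime; the type parameters exist
    # for static checkers only.
    def opted_in(self, user: str, feature: str) -> bool:
        return feature in self.get(user, [])

o = Optins({"alice": ["lf", "split_build"]})
print(o.opted_in("alice", "lf"))  # True
print(o.opted_in("bob", "lf"))    # False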
.github/scripts/tag_docker_images_for_release.py (vendored)

@@ -1,6 +1,5 @@
 import argparse
 import subprocess
-from typing import Dict

 import generate_binary_build_matrix

@@ -10,7 +9,7 @@ def tag_image(
     default_tag: str,
     release_version: str,
     dry_run: str,
-    tagged_images: Dict[str, bool],
+    tagged_images: dict[str, bool],
 ) -> None:
     if image in tagged_images:
         return
@@ -41,7 +40,7 @@ def main() -> None:
     )

     options = parser.parse_args()
-    tagged_images: Dict[str, bool] = {}
+    tagged_images: dict[str, bool] = {}
     platform_images = [
         generate_binary_build_matrix.WHEEL_CONTAINER_IMAGES,
         generate_binary_build_matrix.LIBTORCH_CONTAINER_IMAGES,
.github/scripts/test_check_labels.py (vendored): 4 changed lines

@@ -1,6 +1,6 @@
 """test_check_labels.py"""

-from typing import Any, List
+from typing import Any
 from unittest import main, mock, TestCase

 from check_labels import (
@@ -31,7 +31,7 @@ def mock_delete_all_label_err_comments(pr: "GitHubPR") -> None:
     pass


-def mock_get_comments() -> List[GitHubComment]:
+def mock_get_comments() -> list[GitHubComment]:
     return [
         # Case 1 - a non label err comment
         GitHubComment(
.github/scripts/test_filter_test_configs.py (vendored): 6 changed lines

@@ -3,7 +3,7 @@
 import json
 import os
 import tempfile
-from typing import Any, Dict, List
+from typing import Any
 from unittest import main, mock, TestCase

 import yaml

@@ -362,7 +362,7 @@ class TestConfigFilter(TestCase):
         self.assertEqual(case["expected"], json.dumps(filtered_test_matrix))

     def test_set_periodic_modes(self) -> None:
-        testcases: List[Dict[str, str]] = [
+        testcases: list[dict[str, str]] = [
             {
                 "job_name": "a CI job",
                 "test_matrix": "{include: []}",
@@ -702,7 +702,7 @@ class TestConfigFilter(TestCase):
         )

         mocked_subprocess.return_value = b""
-        testcases: List[Dict[str, Any]] = [
+        testcases: list[dict[str, Any]] = [
             {
                 "labels": {},
                 "test_matrix": '{include: [{config: "default"}]}',
.github/scripts/test_trymerge.py (vendored): 14 changed lines

@@ -12,7 +12,7 @@ import json
 import os
 import warnings
 from hashlib import sha256
-from typing import Any, List, Optional
+from typing import Any, Optional
 from unittest import main, mock, skip, TestCase
 from urllib.error import HTTPError


@@ -170,7 +170,7 @@ def mock_gh_get_info() -> Any:
     }


-def mocked_read_merge_rules_NE(repo: Any, org: str, project: str) -> List[MergeRule]:
+def mocked_read_merge_rules_NE(repo: Any, org: str, project: str) -> list[MergeRule]:
     return [
         MergeRule(
             name="mock with nonexistent check",
@@ -182,7 +182,7 @@ def mocked_read_merge_rules_NE(repo: Any, org: str, project: str) -> List[MergeR
     ]


-def mocked_read_merge_rules(repo: Any, org: str, project: str) -> List[MergeRule]:
+def mocked_read_merge_rules(repo: Any, org: str, project: str) -> list[MergeRule]:
     return [
         MergeRule(
             name="super",
@@ -211,7 +211,7 @@ def mocked_read_merge_rules(repo: Any, org: str, project: str) -> List[MergeRule

 def mocked_read_merge_rules_approvers(
     repo: Any, org: str, project: str
-) -> List[MergeRule]:
+) -> list[MergeRule]:
     return [
         MergeRule(
             name="Core Reviewers",
@@ -234,11 +234,11 @@ def mocked_read_merge_rules_approvers(
     ]


-def mocked_read_merge_rules_raise(repo: Any, org: str, project: str) -> List[MergeRule]:
+def mocked_read_merge_rules_raise(repo: Any, org: str, project: str) -> list[MergeRule]:
     raise RuntimeError("testing")


-def xla_merge_rules(repo: Any, org: str, project: str) -> List[MergeRule]:
+def xla_merge_rules(repo: Any, org: str, project: str) -> list[MergeRule]:
     return [
         MergeRule(
             name=" OSS CI / pytorchbot / XLA",
@@ -260,7 +260,7 @@ class DummyGitRepo(GitRepo):
     def __init__(self) -> None:
         super().__init__(get_git_repo_dir(), get_git_remote_name())

-    def commits_resolving_gh_pr(self, pr_num: int) -> List[str]:
+    def commits_resolving_gh_pr(self, pr_num: int) -> list[str]:
         return ["FakeCommitSha"]

     def commit_message(self, ref: str) -> str:
157
.github/scripts/trymerge.py
vendored
157
.github/scripts/trymerge.py
vendored
|
|
@ -17,21 +17,12 @@ import re
|
|||
import time
|
||||
import urllib.parse
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from functools import cache
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
cast,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Pattern,
|
||||
Tuple,
|
||||
)
|
||||
from re import Pattern
|
||||
from typing import Any, Callable, cast, NamedTuple, Optional
|
||||
from warnings import warn
|
||||
|
||||
import yaml
|
||||
|
|
@ -78,7 +69,7 @@ class JobCheckState(NamedTuple):
|
|||
summary: Optional[str]
|
||||
|
||||
|
||||
JobNameToStateDict = Dict[str, JobCheckState]
|
||||
JobNameToStateDict = dict[str, JobCheckState]
|
||||
|
||||
|
||||
class WorkflowCheckState:
|
||||
|
|
@ -468,10 +459,10 @@ def gh_get_pr_info(org: str, proj: str, pr_no: int) -> Any:
|
|||
return rc["data"]["repository"]["pullRequest"]
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def gh_get_team_members(org: str, name: str) -> List[str]:
|
||||
rc: List[str] = []
|
||||
team_members: Dict[str, Any] = {
|
||||
@cache
|
||||
def gh_get_team_members(org: str, name: str) -> list[str]:
|
||||
rc: list[str] = []
|
||||
team_members: dict[str, Any] = {
|
||||
"pageInfo": {"hasNextPage": "true", "endCursor": None}
|
||||
}
|
||||
while bool(team_members["pageInfo"]["hasNextPage"]):
|
||||
|
|
@ -503,14 +494,14 @@ def is_passing_status(status: Optional[str]) -> bool:
|
|||
|
||||
def add_workflow_conclusions(
|
||||
checksuites: Any,
|
||||
get_next_checkruns_page: Callable[[List[Dict[str, Dict[str, Any]]], int, Any], Any],
|
||||
get_next_checkruns_page: Callable[[list[dict[str, dict[str, Any]]], int, Any], Any],
|
||||
get_next_checksuites: Callable[[Any], Any],
|
||||
) -> JobNameToStateDict:
|
||||
# graphql seems to favor the most recent workflow run, so in theory we
|
||||
# shouldn't need to account for reruns, but do it just in case
|
||||
|
||||
# workflow -> job -> job info
|
||||
workflows: Dict[str, WorkflowCheckState] = {}
|
||||
workflows: dict[str, WorkflowCheckState] = {}
|
||||
|
||||
# for the jobs that don't have a workflow
|
||||
no_workflow_obj: WorkflowCheckState = WorkflowCheckState("", "", 0, None)
|
||||
|
|
@ -633,8 +624,8 @@ def _revlist_to_prs(
|
|||
pr: "GitHubPR",
|
||||
rev_list: Iterable[str],
|
||||
should_skip: Optional[Callable[[int, "GitHubPR"], bool]] = None,
|
||||
) -> List[Tuple["GitHubPR", str]]:
|
||||
rc: List[Tuple[GitHubPR, str]] = []
|
||||
) -> list[tuple["GitHubPR", str]]:
|
||||
rc: list[tuple[GitHubPR, str]] = []
|
||||
for idx, rev in enumerate(rev_list):
|
||||
msg = repo.commit_message(rev)
|
||||
m = RE_PULL_REQUEST_RESOLVED.search(msg)
|
||||
|
|
@ -656,7 +647,7 @@ def _revlist_to_prs(
|
|||
|
||||
def get_ghstack_prs(
|
||||
repo: GitRepo, pr: "GitHubPR", open_only: bool = True
|
||||
) -> List[Tuple["GitHubPR", str]]:
|
||||
) -> list[tuple["GitHubPR", str]]:
|
||||
"""
|
||||
Get the PRs in the stack that are below this PR (inclusive). Throws error if any of the open PRs are out of sync.
|
||||
@:param open_only: Only return open PRs
|
||||
|
|
@ -701,14 +692,14 @@ class GitHubPR:
|
|||
self.project = project
|
||||
self.pr_num = pr_num
|
||||
self.info = gh_get_pr_info(org, project, pr_num)
|
||||
self.changed_files: Optional[List[str]] = None
|
||||
self.labels: Optional[List[str]] = None
|
||||
self.changed_files: Optional[list[str]] = None
|
||||
self.labels: Optional[list[str]] = None
|
||||
self.conclusions: Optional[JobNameToStateDict] = None
|
||||
self.comments: Optional[List[GitHubComment]] = None
|
||||
self._authors: Optional[List[Tuple[str, str]]] = None
|
||||
self._reviews: Optional[List[Tuple[str, str]]] = None
|
||||
self.comments: Optional[list[GitHubComment]] = None
|
||||
self._authors: Optional[list[tuple[str, str]]] = None
|
||||
self._reviews: Optional[list[tuple[str, str]]] = None
|
||||
self.merge_base: Optional[str] = None
|
||||
self.submodules: Optional[List[str]] = None
|
||||
self.submodules: Optional[list[str]] = None
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
return bool(self.info["closed"])
|
||||
|
|
@ -763,7 +754,7 @@ class GitHubPR:
|
|||
|
||||
return self.merge_base
|
||||
|
||||
def get_changed_files(self) -> List[str]:
|
||||
def get_changed_files(self) -> list[str]:
|
||||
if self.changed_files is None:
|
||||
info = self.info
|
||||
unique_changed_files = set()
|
||||
|
|
@ -786,14 +777,14 @@ class GitHubPR:
|
|||
raise RuntimeError("Changed file count mismatch")
|
||||
return self.changed_files
|
||||
|
||||
def get_submodules(self) -> List[str]:
|
||||
def get_submodules(self) -> list[str]:
|
||||
if self.submodules is None:
|
||||
rc = gh_graphql(GH_GET_REPO_SUBMODULES, name=self.project, owner=self.org)
|
||||
info = rc["data"]["repository"]["submodules"]
|
||||
self.submodules = [s["path"] for s in info["nodes"]]
|
||||
return self.submodules
|
||||
|
||||
def get_changed_submodules(self) -> List[str]:
|
||||
def get_changed_submodules(self) -> list[str]:
|
||||
submodules = self.get_submodules()
|
||||
return [f for f in self.get_changed_files() if f in submodules]
|
||||
|
||||
|
|
@ -809,7 +800,7 @@ class GitHubPR:
and all("submodule" not in label for label in self.get_labels())
)

def _get_reviews(self) -> List[Tuple[str, str]]:
def _get_reviews(self) -> list[tuple[str, str]]:
if self._reviews is None:
self._reviews = []
info = self.info

@ -834,7 +825,7 @@ class GitHubPR:
reviews[author] = state
return list(reviews.items())

def get_approved_by(self) -> List[str]:
def get_approved_by(self) -> list[str]:
return [login for (login, state) in self._get_reviews() if state == "APPROVED"]

def get_commit_count(self) -> int:

@ -843,12 +834,12 @@ class GitHubPR:
def get_pr_creator_login(self) -> str:
return cast(str, self.info["author"]["login"])

def _fetch_authors(self) -> List[Tuple[str, str]]:
def _fetch_authors(self) -> list[tuple[str, str]]:
if self._authors is not None:
return self._authors
authors: List[Tuple[str, str]] = []
authors: list[tuple[str, str]] = []

def add_authors(info: Dict[str, Any]) -> None:
def add_authors(info: dict[str, Any]) -> None:
for node in info["commits_with_authors"]["nodes"]:
for author_node in node["commit"]["authors"]["nodes"]:
user_node = author_node["user"]

@ -881,7 +872,7 @@ class GitHubPR:
def get_committer_author(self, num: int = 0) -> str:
return self._fetch_authors()[num][1]

def get_labels(self) -> List[str]:
def get_labels(self) -> list[str]:
if self.labels is not None:
return self.labels
labels = (

@ -899,7 +890,7 @@ class GitHubPR:
orig_last_commit = self.last_commit()

def get_pr_next_check_runs(
edges: List[Dict[str, Dict[str, Any]]], edge_idx: int, checkruns: Any
edges: list[dict[str, dict[str, Any]]], edge_idx: int, checkruns: Any
) -> Any:
rc = gh_graphql(
GH_GET_PR_NEXT_CHECK_RUNS,

@ -951,7 +942,7 @@ class GitHubPR:

return self.conclusions

def get_authors(self) -> Dict[str, str]:
def get_authors(self) -> dict[str, str]:
rc = {}
for idx in range(len(self._fetch_authors())):
rc[self.get_committer_login(idx)] = self.get_committer_author(idx)

@ -995,7 +986,7 @@ class GitHubPR:
url=node["url"],
)

def get_comments(self) -> List[GitHubComment]:
def get_comments(self) -> list[GitHubComment]:
if self.comments is not None:
return self.comments
self.comments = []

@ -1069,7 +1060,7 @@ class GitHubPR:
skip_mandatory_checks: bool,
comment_id: Optional[int] = None,
skip_all_rule_checks: bool = False,
) -> List["GitHubPR"]:
) -> list["GitHubPR"]:
assert self.is_ghstack_pr()
ghstack_prs = get_ghstack_prs(
repo, self, open_only=False

@ -1099,7 +1090,7 @@ class GitHubPR:
def gen_commit_message(
self,
filter_ghstack: bool = False,
ghstack_deps: Optional[List["GitHubPR"]] = None,
ghstack_deps: Optional[list["GitHubPR"]] = None,
) -> str:
"""Fetches title and body from PR description
adds reviewed by, pull request resolved and optionally

@ -1151,7 +1142,7 @@ class GitHubPR:
skip_mandatory_checks: bool = False,
dry_run: bool = False,
comment_id: Optional[int] = None,
ignore_current_checks: Optional[List[str]] = None,
ignore_current_checks: Optional[list[str]] = None,
) -> None:
# Raises exception if matching rule is not found
(

@ -1223,7 +1214,7 @@ class GitHubPR:
comment_id: Optional[int] = None,
branch: Optional[str] = None,
skip_all_rule_checks: bool = False,
) -> List["GitHubPR"]:
) -> list["GitHubPR"]:
"""
:param skip_all_rule_checks: If true, skips all rule checks, useful for dry-running merge locally
"""

@ -1263,14 +1254,14 @@ class PostCommentError(Exception):
@dataclass
class MergeRule:
name: str
patterns: List[str]
approved_by: List[str]
mandatory_checks_name: Optional[List[str]]
patterns: list[str]
approved_by: list[str]
mandatory_checks_name: Optional[list[str]]
ignore_flaky_failures: bool = True


def gen_new_issue_link(
org: str, project: str, labels: List[str], template: str = "bug-report.yml"
org: str, project: str, labels: list[str], template: str = "bug-report.yml"
) -> str:
labels_str = ",".join(labels)
return (
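Note that the `MergeRule` dataclass above keeps `Optional[...]` from `typing` while its container annotations move to builtins; PEP 585 covers only the containers. A standalone sketch of the same shape, with an invented class name and values:

from dataclasses import dataclass
from typing import Optional

@dataclass
class Rule:  # invented name; mirrors the MergeRule shape above
    name: str
    patterns: list[str]
    approved_by: list[str]
    mandatory_checks_name: Optional[list[str]]
    ignore_flaky_failures: bool = True

r = Rule("superuser", ["*"], ["an-approver"], None)
print(r.patterns, r.mandatory_checks_name)  # ['*'] None

On Python 3.10+ the remaining `typing` import could go too, via PEP 604's `list[str] | None`, but judging by its title this commit scopes itself to PEP 585.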
@ -1282,7 +1273,7 @@ def gen_new_issue_link(

def read_merge_rules(
repo: Optional[GitRepo], org: str, project: str
) -> List[MergeRule]:
) -> list[MergeRule]:
"""Returns the list of all merge rules for the repo or project.

NB: this function is used in Meta-internal workflows, see the comment

@ -1312,12 +1303,12 @@ def find_matching_merge_rule(
repo: Optional[GitRepo] = None,
skip_mandatory_checks: bool = False,
skip_internal_checks: bool = False,
ignore_current_checks: Optional[List[str]] = None,
) -> Tuple[
ignore_current_checks: Optional[list[str]] = None,
) -> tuple[
MergeRule,
List[Tuple[str, Optional[str], Optional[int]]],
List[Tuple[str, Optional[str], Optional[int]]],
Dict[str, List[Any]],
list[tuple[str, Optional[str], Optional[int]]],
list[tuple[str, Optional[str], Optional[int]]],
dict[str, list[Any]],
]:
"""
Returns merge rule matching to this pr together with the list of associated pending

@ -1504,13 +1495,13 @@ def find_matching_merge_rule(
raise MergeRuleFailedError(reject_reason, rule)


def checks_to_str(checks: List[Tuple[str, Optional[str]]]) -> str:
def checks_to_str(checks: list[tuple[str, Optional[str]]]) -> str:
return ", ".join(f"[{c[0]}]({c[1]})" if c[1] is not None else c[0] for c in checks)


def checks_to_markdown_bullets(
checks: List[Tuple[str, Optional[str], Optional[int]]],
) -> List[str]:
checks: list[tuple[str, Optional[str], Optional[int]]],
) -> list[str]:
return [
f"- [{c[0]}]({c[1]})" if c[1] is not None else f"- {c[0]}" for c in checks[:5]
]

@ -1518,7 +1509,7 @@ def checks_to_markdown_bullets(

def manually_close_merged_pr(
pr: GitHubPR,
additional_merged_prs: List[GitHubPR],
additional_merged_prs: list[GitHubPR],
merge_commit_sha: str,
dry_run: bool,
) -> None:

@ -1551,12 +1542,12 @@ def save_merge_record(
owner: str,
project: str,
author: str,
pending_checks: List[Tuple[str, Optional[str], Optional[int]]],
failed_checks: List[Tuple[str, Optional[str], Optional[int]]],
ignore_current_checks: List[Tuple[str, Optional[str], Optional[int]]],
broken_trunk_checks: List[Tuple[str, Optional[str], Optional[int]]],
flaky_checks: List[Tuple[str, Optional[str], Optional[int]]],
unstable_checks: List[Tuple[str, Optional[str], Optional[int]]],
pending_checks: list[tuple[str, Optional[str], Optional[int]]],
failed_checks: list[tuple[str, Optional[str], Optional[int]]],
ignore_current_checks: list[tuple[str, Optional[str], Optional[int]]],
broken_trunk_checks: list[tuple[str, Optional[str], Optional[int]]],
flaky_checks: list[tuple[str, Optional[str], Optional[int]]],
unstable_checks: list[tuple[str, Optional[str], Optional[int]]],
last_commit_sha: str,
merge_base_sha: str,
merge_commit_sha: str = "",

@ -1714,9 +1705,9 @@ def is_invalid_cancel(
def get_classifications(
pr_num: int,
project: str,
checks: Dict[str, JobCheckState],
ignore_current_checks: Optional[List[str]],
) -> Dict[str, JobCheckState]:
checks: dict[str, JobCheckState],
ignore_current_checks: Optional[list[str]],
) -> dict[str, JobCheckState]:
# Get the failure classification from Dr.CI, which is the source of truth
# going forward. It's preferable to try calling Dr.CI API directly first
# to get the latest results as well as update Dr.CI PR comment

@ -1825,7 +1816,7 @@ def get_classifications(

def filter_checks_with_lambda(
checks: JobNameToStateDict, status_filter: Callable[[Optional[str]], bool]
) -> List[JobCheckState]:
) -> list[JobCheckState]:
return [check for check in checks.values() if status_filter(check.status)]


@ -1841,7 +1832,7 @@ def get_pr_commit_sha(repo: GitRepo, pr: GitHubPR) -> str:

def validate_revert(
repo: GitRepo, pr: GitHubPR, *, comment_id: Optional[int] = None
) -> Tuple[str, str]:
) -> tuple[str, str]:
comment = (
pr.get_last_comment()
if comment_id is None

@ -1871,7 +1862,7 @@ def validate_revert(

def get_ghstack_dependent_prs(
repo: GitRepo, pr: GitHubPR, only_closed: bool = True
) -> List[Tuple[str, GitHubPR]]:
) -> list[tuple[str, GitHubPR]]:
"""
Get the PRs in the stack that are above this PR (inclusive).
Throws error if stack have branched or original branches are gone

@ -1897,7 +1888,7 @@ def get_ghstack_dependent_prs(
# Remove commits original PR depends on
if skip_len > 0:
rev_list = rev_list[:-skip_len]
rc: List[Tuple[str, GitHubPR]] = []
rc: list[tuple[str, GitHubPR]] = []
for pr_, sha in _revlist_to_prs(repo, pr, rev_list):
if not pr_.is_closed():
if not only_closed:

@ -1910,7 +1901,7 @@ def get_ghstack_dependent_prs(

def do_revert_prs(
repo: GitRepo,
shas_and_prs: List[Tuple[str, GitHubPR]],
shas_and_prs: list[tuple[str, GitHubPR]],
*,
author_login: str,
extra_msg: str = "",

@ -2001,7 +1992,7 @@ def check_for_sev(org: str, project: str, skip_mandatory_checks: bool) -> None:
if skip_mandatory_checks:
return
response = cast(
Dict[str, Any],
dict[str, Any],
gh_fetch_json_list(
"https://api.github.com/search/issues",
# Having two label: queries is an AND operation

@ -2019,29 +2010,29 @@ def check_for_sev(org: str, project: str, skip_mandatory_checks: bool) -> None:
return


def has_label(labels: List[str], pattern: Pattern[str] = CIFLOW_LABEL) -> bool:
def has_label(labels: list[str], pattern: Pattern[str] = CIFLOW_LABEL) -> bool:
return len(list(filter(pattern.match, labels))) > 0


def categorize_checks(
check_runs: JobNameToStateDict,
required_checks: List[str],
required_checks: list[str],
ok_failed_checks_threshold: Optional[int] = None,
) -> Tuple[
List[Tuple[str, Optional[str], Optional[int]]],
List[Tuple[str, Optional[str], Optional[int]]],
Dict[str, List[Any]],
) -> tuple[
list[tuple[str, Optional[str], Optional[int]]],
list[tuple[str, Optional[str], Optional[int]]],
dict[str, list[Any]],
]:
"""
Categories all jobs into the list of pending and failing jobs. All known flaky
failures and broken trunk are ignored by defaults when ok_failed_checks_threshold
is not set (unlimited)
"""
pending_checks: List[Tuple[str, Optional[str], Optional[int]]] = []
failed_checks: List[Tuple[str, Optional[str], Optional[int]]] = []
pending_checks: list[tuple[str, Optional[str], Optional[int]]] = []
failed_checks: list[tuple[str, Optional[str], Optional[int]]] = []

# failed_checks_categorization is used to keep track of all ignorable failures when saving the merge record on s3
failed_checks_categorization: Dict[str, List[Any]] = defaultdict(list)
failed_checks_categorization: dict[str, list[Any]] = defaultdict(list)

# If required_checks is not set or empty, consider all names are relevant
relevant_checknames = [
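The triple `tuple[str, Optional[str], Optional[int]]`, apparently (check name, URL, job id), recurs in a dozen signatures above. Purely as an illustration, not something this commit does, the repetition could be named once with a plain module-level alias, which PEP 585 types support:

from typing import Optional

# Hypothetical alias; the name CheckInfo is invented here.
CheckInfo = tuple[str, Optional[str], Optional[int]]

def checks_to_str(checks: list[CheckInfo]) -> str:
    # Mirrors the variant above: link the name when a URL is present.
    return ", ".join(
        f"[{name}]({url})" if url is not None else name for name, url, _ in checks
    )

print(checks_to_str([("lint", None, None), ("pull", "https://hud.pytorch.org/1", 7)]))
# lint, [pull](https://hud.pytorch.org/1)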
13
.github/scripts/trymerge_explainer.py
vendored

@ -1,6 +1,7 @@
import os
import re
from typing import List, Optional, Pattern, Tuple
from re import Pattern
from typing import Optional


BOT_COMMANDS_WIKI = "https://github.com/pytorch/pytorch/wiki/Bot-commands"
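The import hunk above swaps the long-deprecated `typing.Pattern` alias for the real `re.Pattern` class, which PEP 585 also made subscriptable. A runnable sketch of the `has_label` helper under the new imports; the regex is an assumed stand-in for the real `CIFLOW_LABEL` constant:

import re
from re import Pattern

CIFLOW_LABEL = re.compile(r"^ciflow/")  # assumed stand-in for the real constant

def has_label(labels: list[str], pattern: Pattern[str] = CIFLOW_LABEL) -> bool:
    return len(list(filter(pattern.match, labels))) > 0

print(has_label(["ciflow/trunk", "triaged"]))  # True
print(has_label(["triaged"]))                  # False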
@ -13,13 +14,13 @@ CONTACT_US = f"Questions? Feedback? Please reach out to the [PyTorch DevX Team](
ALTERNATIVES = f"Learn more about merging in the [wiki]({BOT_COMMANDS_WIKI})."


def has_label(labels: List[str], pattern: Pattern[str] = CIFLOW_LABEL) -> bool:
def has_label(labels: list[str], pattern: Pattern[str] = CIFLOW_LABEL) -> bool:
return len(list(filter(pattern.match, labels))) > 0


class TryMergeExplainer:
force: bool
labels: List[str]
labels: list[str]
pr_num: int
org: str
project: str

@ -31,7 +32,7 @@ class TryMergeExplainer:
def __init__(
self,
force: bool,
labels: List[str],
labels: list[str],
pr_num: int,
org: str,
project: str,

@ -47,7 +48,7 @@ class TryMergeExplainer:
def _get_flag_msg(
self,
ignore_current_checks: Optional[
List[Tuple[str, Optional[str], Optional[int]]]
list[tuple[str, Optional[str], Optional[int]]]
] = None,
) -> str:
if self.force:

@ -68,7 +69,7 @@ class TryMergeExplainer:
def get_merge_message(
self,
ignore_current_checks: Optional[
List[Tuple[str, Optional[str], Optional[int]]]
list[tuple[str, Optional[str], Optional[int]]]
] = None,
) -> str:
title = "### Merge started"
3
.github/scripts/tryrebase.py
vendored

@ -5,7 +5,8 @@ import os
import re
import subprocess
import sys
from typing import Any, Generator
from collections.abc import Generator
from typing import Any

from github_utils import gh_post_pr_comment as gh_post_comment
from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo
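Same idea for `Generator`: `typing.Generator` is a deprecated alias for `collections.abc.Generator`, which is generic in its own right since Python 3.9. A small sketch with the three type parameters (yield, send, return) spelled out; the function and input are invented:

from collections.abc import Generator

def first_words(text: str) -> Generator[str, None, None]:
    # yield type str; send and return types are None
    for line in text.splitlines():
        if line:
            yield line.split()[0]

print(list(first_words("abc123 fix bug\ndef456 add test")))  # ['abc123', 'def456']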
31
.github/workflows/_runner-determinator.yml
vendored

@ -129,9 +129,10 @@ jobs:
import re
import sys
from argparse import ArgumentParser
from functools import lru_cache
from collections.abc import Iterable
from functools import cache
from logging import LogRecord
from typing import Any, Dict, FrozenSet, Iterable, List, NamedTuple, Set, Tuple
from typing import Any, NamedTuple
from urllib.request import Request, urlopen

import yaml

@ -173,7 +174,7 @@ jobs:
Settings for the experiments that can be opted into.
"""

experiments: Dict[str, Experiment] = {}
experiments: dict[str, Experiment] = {}


class ColorFormatter(logging.Formatter):

@ -218,7 +219,7 @@ jobs:
f.write(f"{key}={value}\n")


def _str_comma_separated_to_set(value: str) -> FrozenSet[str]:
def _str_comma_separated_to_set(value: str) -> frozenset[str]:
return frozenset(
filter(lambda itm: itm != "", map(str.strip, value.strip(" \n\t").split(",")))
)
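`_str_comma_separated_to_set` above is self-contained, so its behavior under the new `frozenset[str]` annotation is easy to check; the input string below is illustrative:

# Self-contained copy of the helper above, with an illustrative input.
def _str_comma_separated_to_set(value: str) -> frozenset[str]:
    return frozenset(
        filter(lambda itm: itm != "", map(str.strip, value.strip(" \n\t").split(",")))
    )

print(_str_comma_separated_to_set(" lf, amz2023 ,,lf "))
# frozenset({'lf', 'amz2023'}) -- whitespace trimmed, empties dropped, duplicates collapsed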
@ -276,12 +277,12 @@ jobs:
return parser.parse_args()


def get_gh_client(github_token: str) -> Github:
def get_gh_client(github_token: str) -> Github:  # type: ignore[no-any-unimported]
auth = Auth.Token(github_token)
return Github(auth=auth)


def get_issue(gh: Github, repo: str, issue_num: int) -> Issue:
def get_issue(gh: Github, repo: str, issue_num: int) -> Issue:  # type: ignore[no-any-unimported]
repo = gh.get_repo(repo)
return repo.get_issue(number=issue_num)

@ -310,7 +311,7 @@ jobs:
raise Exception(  # noqa: TRY002
f"issue with pull request {pr_number} from repo {repository}"
) from e
return pull.user.login
return pull.user.login  # type: ignore[no-any-return]
# In all other cases, return the original input username
return username

@ -331,7 +332,7 @@ jobs:
raise


def extract_settings_user_opt_in_from_text(rollout_state: str) -> Tuple[str, str]:
def extract_settings_user_opt_in_from_text(rollout_state: str) -> tuple[str, str]:
"""
Extracts the text with settings, if any, and the opted in users from the rollout state.

@ -347,7 +348,7 @@ jobs:
return "", rollout_state


class UserOptins(Dict[str, List[str]]):
class UserOptins(dict[str, list[str]]):
"""
Dictionary of users with a list of features they have opted into
"""
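`UserOptins` shows a less obvious part of PEP 585: the parameterized builtins can be used as base classes directly, so `Dict[str, List[str]]` as a base becomes `dict[str, list[str]]` with no `typing` import at all. A standalone sketch; the helper method and data are invented:

class UserOptins(dict[str, list[str]]):
    """Map each user to the features they have opted into."""

    def opted_in(self, user: str, feature: str) -> bool:  # invented helper
        return feature in self.get(user, [])

optins = UserOptins({"alice": ["lf"], "bob": []})
print(optins.opted_in("alice", "lf"), optins.opted_in("bob", "lf"))  # True False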
@ -488,7 +489,7 @@ jobs:
rollout_state: str,
workflow_requestors: Iterable[str],
branch: str,
eligible_experiments: FrozenSet[str] = frozenset(),
eligible_experiments: frozenset[str] = frozenset(),
is_canary: bool = False,
) -> str:
settings = parse_settings(rollout_state)

@ -587,7 +588,7 @@ jobs:
return str(issue.get_comments()[0].body.strip("\n\t "))


def download_json(url: str, headers: Dict[str, str], num_retries: int = 3) -> Any:
def download_json(url: str, headers: dict[str, str], num_retries: int = 3) -> Any:
for _ in range(num_retries):
try:
req = Request(url=url, headers=headers)

@ -600,8 +601,8 @@ jobs:
return {}


@lru_cache(maxsize=None)
def get_pr_info(github_repo: str, github_token: str, pr_number: int) -> Dict[str, Any]:
@cache
def get_pr_info(github_repo: str, github_token: str, pr_number: int) -> dict[str, Any]:
"""
Dynamically get PR information
"""
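The last hunk above also modernizes the decorator: `functools.cache`, added in Python 3.9, is documented as equivalent to `lru_cache(maxsize=None)`, just without the size bound. A quick demonstration with an invented function:

from functools import cache

@cache
def get_info(pr_number: int) -> dict[str, int]:  # invented stand-in
    print(f"fetching PR {pr_number}")  # runs once per distinct argument
    return {"number": pr_number}

get_info(1)
get_info(1)  # served from the cache; nothing printed
print(get_info.cache_info())  # CacheInfo(hits=1, misses=1, maxsize=None, currsize=1)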
@ -610,7 +611,7 @@ jobs:
"Accept": "application/vnd.github.v3+json",
"Authorization": f"token {github_token}",
}
json_response: Dict[str, Any] = download_json(
json_response: dict[str, Any] = download_json(
url=f"{github_api}/issues/{pr_number}",
headers=headers,
)

@ -622,7 +623,7 @@ jobs:
return json_response


def get_labels(github_repo: str, github_token: str, pr_number: int) -> Set[str]:
def get_labels(github_repo: str, github_token: str, pr_number: int) -> set[str]:
"""
Dynamically get the latest list of labels from the pull request
"""