mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
PEP585: .github (#145707)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145707 Approved by: https://github.com/huydhn
This commit is contained in:
parent
bfaf76bfc6
commit
60f98262f1
10
.github/scripts/cherry_pick.py
vendored
10
.github/scripts/cherry_pick.py
vendored
|
|
@ -3,7 +3,7 @@
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from typing import Any, cast, Dict, List, Optional
|
from typing import Any, cast, Optional
|
||||||
from urllib.error import HTTPError
|
from urllib.error import HTTPError
|
||||||
|
|
||||||
from github_utils import gh_fetch_url, gh_post_pr_comment, gh_query_issues_by_labels
|
from github_utils import gh_fetch_url, gh_post_pr_comment, gh_query_issues_by_labels
|
||||||
|
|
@ -67,7 +67,7 @@ def get_release_version(onto_branch: str) -> Optional[str]:
|
||||||
|
|
||||||
def get_tracker_issues(
|
def get_tracker_issues(
|
||||||
org: str, project: str, onto_branch: str
|
org: str, project: str, onto_branch: str
|
||||||
) -> List[Dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Find the tracker issue from the repo. The tracker issue needs to have the title
|
Find the tracker issue from the repo. The tracker issue needs to have the title
|
||||||
like [VERSION] Release Tracker following the convention on PyTorch
|
like [VERSION] Release Tracker following the convention on PyTorch
|
||||||
|
|
@ -117,7 +117,7 @@ def cherry_pick(
|
||||||
continue
|
continue
|
||||||
|
|
||||||
res = cast(
|
res = cast(
|
||||||
Dict[str, Any],
|
dict[str, Any],
|
||||||
post_tracker_issue_comment(
|
post_tracker_issue_comment(
|
||||||
org,
|
org,
|
||||||
project,
|
project,
|
||||||
|
|
@ -220,7 +220,7 @@ def submit_pr(
|
||||||
|
|
||||||
def post_pr_comment(
|
def post_pr_comment(
|
||||||
org: str, project: str, pr_num: int, msg: str, dry_run: bool = False
|
org: str, project: str, pr_num: int, msg: str, dry_run: bool = False
|
||||||
) -> List[Dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Post a comment on the PR itself to point to the cherry picking PR when success
|
Post a comment on the PR itself to point to the cherry picking PR when success
|
||||||
or print the error when failure
|
or print the error when failure
|
||||||
|
|
@ -255,7 +255,7 @@ def post_tracker_issue_comment(
|
||||||
classification: str,
|
classification: str,
|
||||||
fixes: str,
|
fixes: str,
|
||||||
dry_run: bool = False,
|
dry_run: bool = False,
|
||||||
) -> List[Dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Post a comment on the tracker issue (if any) to record the cherry pick
|
Post a comment on the tracker issue (if any) to record the cherry pick
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ import re
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Tuple
|
from typing import Any
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from gitutils import retries_decorator
|
from gitutils import retries_decorator
|
||||||
|
|
@ -76,7 +76,7 @@ DISABLED_TESTS_JSON = (
|
||||||
|
|
||||||
|
|
||||||
@retries_decorator()
|
@retries_decorator()
|
||||||
def query_db(query: str, params: Dict[str, Any]) -> List[Dict[str, Any]]:
|
def query_db(query: str, params: dict[str, Any]) -> list[dict[str, Any]]:
|
||||||
return query_clickhouse(query, params)
|
return query_clickhouse(query, params)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -97,7 +97,7 @@ def download_log_worker(temp_dir: str, id: int, name: str) -> None:
|
||||||
f.write(data)
|
f.write(data)
|
||||||
|
|
||||||
|
|
||||||
def printer(item: Tuple[str, Tuple[int, str, List[Any]]], extra: str) -> None:
|
def printer(item: tuple[str, tuple[int, str, list[Any]]], extra: str) -> None:
|
||||||
test, (_, link, _) = item
|
test, (_, link, _) = item
|
||||||
print(f"{link:<55} {test:<120} {extra}")
|
print(f"{link:<55} {test:<120} {extra}")
|
||||||
|
|
||||||
|
|
@ -120,8 +120,8 @@ def close_issue(num: int) -> None:
|
||||||
|
|
||||||
|
|
||||||
def check_if_exists(
|
def check_if_exists(
|
||||||
item: Tuple[str, Tuple[int, str, List[str]]], all_logs: List[str]
|
item: tuple[str, tuple[int, str, list[str]]], all_logs: list[str]
|
||||||
) -> Tuple[bool, str]:
|
) -> tuple[bool, str]:
|
||||||
test, (_, link, _) = item
|
test, (_, link, _) = item
|
||||||
# Test names should look like `test_a (module.path.classname)`
|
# Test names should look like `test_a (module.path.classname)`
|
||||||
reg = re.match(r"(\S+) \((\S*)\)", test)
|
reg = re.match(r"(\S+) \((\S*)\)", test)
|
||||||
|
|
|
||||||
14
.github/scripts/collect_ciflow_labels.py
vendored
14
.github/scripts/collect_ciflow_labels.py
vendored
|
|
@ -2,7 +2,7 @@
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, cast, Dict, List, Set
|
from typing import Any, cast
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
|
@ -10,9 +10,9 @@ import yaml
|
||||||
GITHUB_DIR = Path(__file__).parent.parent
|
GITHUB_DIR = Path(__file__).parent.parent
|
||||||
|
|
||||||
|
|
||||||
def get_workflows_push_tags() -> Set[str]:
|
def get_workflows_push_tags() -> set[str]:
|
||||||
"Extract all known push tags from workflows"
|
"Extract all known push tags from workflows"
|
||||||
rc: Set[str] = set()
|
rc: set[str] = set()
|
||||||
for fname in (GITHUB_DIR / "workflows").glob("*.yml"):
|
for fname in (GITHUB_DIR / "workflows").glob("*.yml"):
|
||||||
with fname.open("r") as f:
|
with fname.open("r") as f:
|
||||||
wf_yml = yaml.safe_load(f)
|
wf_yml = yaml.safe_load(f)
|
||||||
|
|
@ -25,19 +25,19 @@ def get_workflows_push_tags() -> Set[str]:
|
||||||
return rc
|
return rc
|
||||||
|
|
||||||
|
|
||||||
def filter_ciflow_tags(tags: Set[str]) -> List[str]:
|
def filter_ciflow_tags(tags: set[str]) -> list[str]:
|
||||||
"Return sorted list of ciflow tags"
|
"Return sorted list of ciflow tags"
|
||||||
return sorted(
|
return sorted(
|
||||||
tag[:-2] for tag in tags if tag.startswith("ciflow/") and tag.endswith("/*")
|
tag[:-2] for tag in tags if tag.startswith("ciflow/") and tag.endswith("/*")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def read_probot_config() -> Dict[str, Any]:
|
def read_probot_config() -> dict[str, Any]:
|
||||||
with (GITHUB_DIR / "pytorch-probot.yml").open("r") as f:
|
with (GITHUB_DIR / "pytorch-probot.yml").open("r") as f:
|
||||||
return cast(Dict[str, Any], yaml.safe_load(f))
|
return cast(dict[str, Any], yaml.safe_load(f))
|
||||||
|
|
||||||
|
|
||||||
def update_probot_config(labels: Set[str]) -> None:
|
def update_probot_config(labels: set[str]) -> None:
|
||||||
orig = read_probot_config()
|
orig = read_probot_config()
|
||||||
orig["ciflow_push_tags"] = filter_ciflow_tags(labels)
|
orig["ciflow_push_tags"] = filter_ciflow_tags(labels)
|
||||||
with (GITHUB_DIR / "pytorch-probot.yml").open("w") as f:
|
with (GITHUB_DIR / "pytorch-probot.yml").open("w") as f:
|
||||||
|
|
|
||||||
28
.github/scripts/delete_old_branches.py
vendored
28
.github/scripts/delete_old_branches.py
vendored
|
|
@ -4,7 +4,7 @@ import re
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Dict, List, Set
|
from typing import Any, Callable
|
||||||
|
|
||||||
from github_utils import gh_fetch_json_dict, gh_graphql
|
from github_utils import gh_fetch_json_dict, gh_graphql
|
||||||
from gitutils import GitRepo
|
from gitutils import GitRepo
|
||||||
|
|
@ -112,7 +112,7 @@ def convert_gh_timestamp(date: str) -> float:
|
||||||
return datetime.strptime(date, "%Y-%m-%dT%H:%M:%SZ").timestamp()
|
return datetime.strptime(date, "%Y-%m-%dT%H:%M:%SZ").timestamp()
|
||||||
|
|
||||||
|
|
||||||
def get_branches(repo: GitRepo) -> Dict[str, Any]:
|
def get_branches(repo: GitRepo) -> dict[str, Any]:
|
||||||
# Query locally for branches, group by branch base name (e.g. gh/blah/base -> gh/blah), and get the most recent branch
|
# Query locally for branches, group by branch base name (e.g. gh/blah/base -> gh/blah), and get the most recent branch
|
||||||
git_response = repo._run_git(
|
git_response = repo._run_git(
|
||||||
"for-each-ref",
|
"for-each-ref",
|
||||||
|
|
@ -120,7 +120,7 @@ def get_branches(repo: GitRepo) -> Dict[str, Any]:
|
||||||
"--format=%(refname) %(committerdate:iso-strict)",
|
"--format=%(refname) %(committerdate:iso-strict)",
|
||||||
"refs/remotes/origin",
|
"refs/remotes/origin",
|
||||||
)
|
)
|
||||||
branches_by_base_name: Dict[str, Any] = {}
|
branches_by_base_name: dict[str, Any] = {}
|
||||||
for line in git_response.splitlines():
|
for line in git_response.splitlines():
|
||||||
branch, date = line.split(" ")
|
branch, date = line.split(" ")
|
||||||
re_branch = re.match(r"refs/remotes/origin/(.*)", branch)
|
re_branch = re.match(r"refs/remotes/origin/(.*)", branch)
|
||||||
|
|
@ -140,14 +140,14 @@ def get_branches(repo: GitRepo) -> Dict[str, Any]:
|
||||||
|
|
||||||
def paginate_graphql(
|
def paginate_graphql(
|
||||||
query: str,
|
query: str,
|
||||||
kwargs: Dict[str, Any],
|
kwargs: dict[str, Any],
|
||||||
termination_func: Callable[[List[Dict[str, Any]]], bool],
|
termination_func: Callable[[list[dict[str, Any]]], bool],
|
||||||
get_data: Callable[[Dict[str, Any]], List[Dict[str, Any]]],
|
get_data: Callable[[dict[str, Any]], list[dict[str, Any]]],
|
||||||
get_page_info: Callable[[Dict[str, Any]], Dict[str, Any]],
|
get_page_info: Callable[[dict[str, Any]], dict[str, Any]],
|
||||||
) -> List[Any]:
|
) -> list[Any]:
|
||||||
hasNextPage = True
|
hasNextPage = True
|
||||||
endCursor = None
|
endCursor = None
|
||||||
data: List[Dict[str, Any]] = []
|
data: list[dict[str, Any]] = []
|
||||||
while hasNextPage:
|
while hasNextPage:
|
||||||
ESTIMATED_TOKENS[0] += 1
|
ESTIMATED_TOKENS[0] += 1
|
||||||
res = gh_graphql(query, cursor=endCursor, **kwargs)
|
res = gh_graphql(query, cursor=endCursor, **kwargs)
|
||||||
|
|
@ -159,11 +159,11 @@ def paginate_graphql(
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
def get_recent_prs() -> Dict[str, Any]:
|
def get_recent_prs() -> dict[str, Any]:
|
||||||
now = datetime.now().timestamp()
|
now = datetime.now().timestamp()
|
||||||
|
|
||||||
# Grab all PRs updated in last CLOSED_PR_RETENTION days
|
# Grab all PRs updated in last CLOSED_PR_RETENTION days
|
||||||
pr_infos: List[Dict[str, Any]] = paginate_graphql(
|
pr_infos: list[dict[str, Any]] = paginate_graphql(
|
||||||
GRAPHQL_ALL_PRS_BY_UPDATED_AT,
|
GRAPHQL_ALL_PRS_BY_UPDATED_AT,
|
||||||
{"owner": "pytorch", "repo": "pytorch"},
|
{"owner": "pytorch", "repo": "pytorch"},
|
||||||
lambda data: (
|
lambda data: (
|
||||||
|
|
@ -190,7 +190,7 @@ def get_recent_prs() -> Dict[str, Any]:
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=1)
|
@lru_cache(maxsize=1)
|
||||||
def get_open_prs() -> List[Dict[str, Any]]:
|
def get_open_prs() -> list[dict[str, Any]]:
|
||||||
return paginate_graphql(
|
return paginate_graphql(
|
||||||
GRAPHQL_OPEN_PRS,
|
GRAPHQL_OPEN_PRS,
|
||||||
{"owner": "pytorch", "repo": "pytorch"},
|
{"owner": "pytorch", "repo": "pytorch"},
|
||||||
|
|
@ -200,8 +200,8 @@ def get_open_prs() -> List[Dict[str, Any]]:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_branches_with_magic_label_or_open_pr() -> Set[str]:
|
def get_branches_with_magic_label_or_open_pr() -> set[str]:
|
||||||
pr_infos: List[Dict[str, Any]] = paginate_graphql(
|
pr_infos: list[dict[str, Any]] = paginate_graphql(
|
||||||
GRAPHQL_NO_DELETE_BRANCH_LABEL,
|
GRAPHQL_NO_DELETE_BRANCH_LABEL,
|
||||||
{"owner": "pytorch", "repo": "pytorch"},
|
{"owner": "pytorch", "repo": "pytorch"},
|
||||||
lambda data: False,
|
lambda data: False,
|
||||||
|
|
|
||||||
4
.github/scripts/file_io_utils.py
vendored
4
.github/scripts/file_io_utils.py
vendored
|
|
@ -2,7 +2,7 @@ import json
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, List
|
from typing import Any
|
||||||
|
|
||||||
import boto3 # type: ignore[import]
|
import boto3 # type: ignore[import]
|
||||||
|
|
||||||
|
|
@ -77,7 +77,7 @@ def upload_file_to_s3(file_name: Path, bucket: str, key: str) -> None:
|
||||||
|
|
||||||
def download_s3_objects_with_prefix(
|
def download_s3_objects_with_prefix(
|
||||||
bucket_name: str, prefix: str, download_folder: Path
|
bucket_name: str, prefix: str, download_folder: Path
|
||||||
) -> List[Path]:
|
) -> list[Path]:
|
||||||
s3 = boto3.resource("s3")
|
s3 = boto3.resource("s3")
|
||||||
bucket = s3.Bucket(bucket_name)
|
bucket = s3.Bucket(bucket_name)
|
||||||
|
|
||||||
|
|
|
||||||
60
.github/scripts/filter_test_configs.py
vendored
60
.github/scripts/filter_test_configs.py
vendored
|
|
@ -8,9 +8,9 @@ import subprocess
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import lru_cache
|
from functools import cache
|
||||||
from logging import info
|
from logging import info
|
||||||
from typing import Any, Callable, Dict, List, Optional, Set
|
from typing import Any, Callable, Optional
|
||||||
from urllib.request import Request, urlopen
|
from urllib.request import Request, urlopen
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
@ -32,7 +32,7 @@ def is_cuda_or_rocm_job(job_name: Optional[str]) -> bool:
|
||||||
|
|
||||||
# Supported modes when running periodically. Only applying the mode when
|
# Supported modes when running periodically. Only applying the mode when
|
||||||
# its lambda condition returns true
|
# its lambda condition returns true
|
||||||
SUPPORTED_PERIODICAL_MODES: Dict[str, Callable[[Optional[str]], bool]] = {
|
SUPPORTED_PERIODICAL_MODES: dict[str, Callable[[Optional[str]], bool]] = {
|
||||||
# Memory leak check is only needed for CUDA and ROCm jobs which utilize GPU memory
|
# Memory leak check is only needed for CUDA and ROCm jobs which utilize GPU memory
|
||||||
"mem_leak_check": is_cuda_or_rocm_job,
|
"mem_leak_check": is_cuda_or_rocm_job,
|
||||||
"rerun_disabled_tests": lambda job_name: True,
|
"rerun_disabled_tests": lambda job_name: True,
|
||||||
|
|
@ -102,8 +102,8 @@ def parse_args() -> Any:
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=None)
|
@cache
|
||||||
def get_pr_info(pr_number: int) -> Dict[str, Any]:
|
def get_pr_info(pr_number: int) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Dynamically get PR information
|
Dynamically get PR information
|
||||||
"""
|
"""
|
||||||
|
|
@ -116,7 +116,7 @@ def get_pr_info(pr_number: int) -> Dict[str, Any]:
|
||||||
"Accept": "application/vnd.github.v3+json",
|
"Accept": "application/vnd.github.v3+json",
|
||||||
"Authorization": f"token {github_token}",
|
"Authorization": f"token {github_token}",
|
||||||
}
|
}
|
||||||
json_response: Dict[str, Any] = download_json(
|
json_response: dict[str, Any] = download_json(
|
||||||
url=f"{pytorch_github_api}/issues/{pr_number}",
|
url=f"{pytorch_github_api}/issues/{pr_number}",
|
||||||
headers=headers,
|
headers=headers,
|
||||||
)
|
)
|
||||||
|
|
@ -128,7 +128,7 @@ def get_pr_info(pr_number: int) -> Dict[str, Any]:
|
||||||
return json_response
|
return json_response
|
||||||
|
|
||||||
|
|
||||||
def get_labels(pr_number: int) -> Set[str]:
|
def get_labels(pr_number: int) -> set[str]:
|
||||||
"""
|
"""
|
||||||
Dynamically get the latest list of labels from the pull request
|
Dynamically get the latest list of labels from the pull request
|
||||||
"""
|
"""
|
||||||
|
|
@ -138,14 +138,14 @@ def get_labels(pr_number: int) -> Set[str]:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def filter_labels(labels: Set[str], label_regex: Any) -> Set[str]:
|
def filter_labels(labels: set[str], label_regex: Any) -> set[str]:
|
||||||
"""
|
"""
|
||||||
Return the list of matching labels
|
Return the list of matching labels
|
||||||
"""
|
"""
|
||||||
return {l for l in labels if re.match(label_regex, l)}
|
return {l for l in labels if re.match(label_regex, l)}
|
||||||
|
|
||||||
|
|
||||||
def filter(test_matrix: Dict[str, List[Any]], labels: Set[str]) -> Dict[str, List[Any]]:
|
def filter(test_matrix: dict[str, list[Any]], labels: set[str]) -> dict[str, list[Any]]:
|
||||||
"""
|
"""
|
||||||
Select the list of test config to run from the test matrix. The logic works
|
Select the list of test config to run from the test matrix. The logic works
|
||||||
as follows:
|
as follows:
|
||||||
|
|
@ -157,7 +157,7 @@ def filter(test_matrix: Dict[str, List[Any]], labels: Set[str]) -> Dict[str, Lis
|
||||||
|
|
||||||
If the PR has none of the test-config label, all tests are run as usual.
|
If the PR has none of the test-config label, all tests are run as usual.
|
||||||
"""
|
"""
|
||||||
filtered_test_matrix: Dict[str, List[Any]] = {"include": []}
|
filtered_test_matrix: dict[str, list[Any]] = {"include": []}
|
||||||
|
|
||||||
for entry in test_matrix.get("include", []):
|
for entry in test_matrix.get("include", []):
|
||||||
config_name = entry.get("config", "")
|
config_name = entry.get("config", "")
|
||||||
|
|
@ -185,8 +185,8 @@ def filter(test_matrix: Dict[str, List[Any]], labels: Set[str]) -> Dict[str, Lis
|
||||||
|
|
||||||
|
|
||||||
def filter_selected_test_configs(
|
def filter_selected_test_configs(
|
||||||
test_matrix: Dict[str, List[Any]], selected_test_configs: Set[str]
|
test_matrix: dict[str, list[Any]], selected_test_configs: set[str]
|
||||||
) -> Dict[str, List[Any]]:
|
) -> dict[str, list[Any]]:
|
||||||
"""
|
"""
|
||||||
Keep only the selected configs if the list if not empty. Otherwise, keep all test configs.
|
Keep only the selected configs if the list if not empty. Otherwise, keep all test configs.
|
||||||
This filter is used when the workflow is dispatched manually.
|
This filter is used when the workflow is dispatched manually.
|
||||||
|
|
@ -194,7 +194,7 @@ def filter_selected_test_configs(
|
||||||
if not selected_test_configs:
|
if not selected_test_configs:
|
||||||
return test_matrix
|
return test_matrix
|
||||||
|
|
||||||
filtered_test_matrix: Dict[str, List[Any]] = {"include": []}
|
filtered_test_matrix: dict[str, list[Any]] = {"include": []}
|
||||||
for entry in test_matrix.get("include", []):
|
for entry in test_matrix.get("include", []):
|
||||||
config_name = entry.get("config", "")
|
config_name = entry.get("config", "")
|
||||||
if not config_name:
|
if not config_name:
|
||||||
|
|
@ -207,12 +207,12 @@ def filter_selected_test_configs(
|
||||||
|
|
||||||
|
|
||||||
def set_periodic_modes(
|
def set_periodic_modes(
|
||||||
test_matrix: Dict[str, List[Any]], job_name: Optional[str]
|
test_matrix: dict[str, list[Any]], job_name: Optional[str]
|
||||||
) -> Dict[str, List[Any]]:
|
) -> dict[str, list[Any]]:
|
||||||
"""
|
"""
|
||||||
Apply all periodic modes when running under a schedule
|
Apply all periodic modes when running under a schedule
|
||||||
"""
|
"""
|
||||||
scheduled_test_matrix: Dict[str, List[Any]] = {
|
scheduled_test_matrix: dict[str, list[Any]] = {
|
||||||
"include": [],
|
"include": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -229,8 +229,8 @@ def set_periodic_modes(
|
||||||
|
|
||||||
|
|
||||||
def mark_unstable_jobs(
|
def mark_unstable_jobs(
|
||||||
workflow: str, job_name: str, test_matrix: Dict[str, List[Any]]
|
workflow: str, job_name: str, test_matrix: dict[str, list[Any]]
|
||||||
) -> Dict[str, List[Any]]:
|
) -> dict[str, list[Any]]:
|
||||||
"""
|
"""
|
||||||
Check the list of unstable jobs and mark them accordingly. Note that if a job
|
Check the list of unstable jobs and mark them accordingly. Note that if a job
|
||||||
is unstable, all its dependents will also be marked accordingly
|
is unstable, all its dependents will also be marked accordingly
|
||||||
|
|
@ -245,8 +245,8 @@ def mark_unstable_jobs(
|
||||||
|
|
||||||
|
|
||||||
def remove_disabled_jobs(
|
def remove_disabled_jobs(
|
||||||
workflow: str, job_name: str, test_matrix: Dict[str, List[Any]]
|
workflow: str, job_name: str, test_matrix: dict[str, list[Any]]
|
||||||
) -> Dict[str, List[Any]]:
|
) -> dict[str, list[Any]]:
|
||||||
"""
|
"""
|
||||||
Check the list of disabled jobs, remove the current job and all its dependents
|
Check the list of disabled jobs, remove the current job and all its dependents
|
||||||
if it exists in the list
|
if it exists in the list
|
||||||
|
|
@ -261,15 +261,15 @@ def remove_disabled_jobs(
|
||||||
|
|
||||||
|
|
||||||
def _filter_jobs(
|
def _filter_jobs(
|
||||||
test_matrix: Dict[str, List[Any]],
|
test_matrix: dict[str, list[Any]],
|
||||||
issue_type: IssueType,
|
issue_type: IssueType,
|
||||||
target_cfg: Optional[str] = None,
|
target_cfg: Optional[str] = None,
|
||||||
) -> Dict[str, List[Any]]:
|
) -> dict[str, list[Any]]:
|
||||||
"""
|
"""
|
||||||
An utility function used to actually apply the job filter
|
An utility function used to actually apply the job filter
|
||||||
"""
|
"""
|
||||||
# The result will be stored here
|
# The result will be stored here
|
||||||
filtered_test_matrix: Dict[str, List[Any]] = {"include": []}
|
filtered_test_matrix: dict[str, list[Any]] = {"include": []}
|
||||||
|
|
||||||
# This is an issue to disable a CI job
|
# This is an issue to disable a CI job
|
||||||
if issue_type == IssueType.DISABLED:
|
if issue_type == IssueType.DISABLED:
|
||||||
|
|
@ -302,10 +302,10 @@ def _filter_jobs(
|
||||||
def process_jobs(
|
def process_jobs(
|
||||||
workflow: str,
|
workflow: str,
|
||||||
job_name: str,
|
job_name: str,
|
||||||
test_matrix: Dict[str, List[Any]],
|
test_matrix: dict[str, list[Any]],
|
||||||
issue_type: IssueType,
|
issue_type: IssueType,
|
||||||
url: str,
|
url: str,
|
||||||
) -> Dict[str, List[Any]]:
|
) -> dict[str, list[Any]]:
|
||||||
"""
|
"""
|
||||||
Both disabled and unstable jobs are in the following format:
|
Both disabled and unstable jobs are in the following format:
|
||||||
|
|
||||||
|
|
@ -441,7 +441,7 @@ def process_jobs(
|
||||||
return test_matrix
|
return test_matrix
|
||||||
|
|
||||||
|
|
||||||
def download_json(url: str, headers: Dict[str, str], num_retries: int = 3) -> Any:
|
def download_json(url: str, headers: dict[str, str], num_retries: int = 3) -> Any:
|
||||||
for _ in range(num_retries):
|
for _ in range(num_retries):
|
||||||
try:
|
try:
|
||||||
req = Request(url=url, headers=headers)
|
req = Request(url=url, headers=headers)
|
||||||
|
|
@ -462,7 +462,7 @@ def set_output(name: str, val: Any) -> None:
|
||||||
print(f"::set-output name={name}::{val}")
|
print(f"::set-output name={name}::{val}")
|
||||||
|
|
||||||
|
|
||||||
def parse_reenabled_issues(s: Optional[str]) -> List[str]:
|
def parse_reenabled_issues(s: Optional[str]) -> list[str]:
|
||||||
# NB: When the PR body is empty, GitHub API returns a None value, which is
|
# NB: When the PR body is empty, GitHub API returns a None value, which is
|
||||||
# passed into this function
|
# passed into this function
|
||||||
if not s:
|
if not s:
|
||||||
|
|
@ -477,7 +477,7 @@ def parse_reenabled_issues(s: Optional[str]) -> List[str]:
|
||||||
return issue_numbers
|
return issue_numbers
|
||||||
|
|
||||||
|
|
||||||
def get_reenabled_issues(pr_body: str = "") -> List[str]:
|
def get_reenabled_issues(pr_body: str = "") -> list[str]:
|
||||||
default_branch = f"origin/{os.environ.get('GIT_DEFAULT_BRANCH', 'main')}"
|
default_branch = f"origin/{os.environ.get('GIT_DEFAULT_BRANCH', 'main')}"
|
||||||
try:
|
try:
|
||||||
commit_messages = subprocess.check_output(
|
commit_messages = subprocess.check_output(
|
||||||
|
|
@ -489,12 +489,12 @@ def get_reenabled_issues(pr_body: str = "") -> List[str]:
|
||||||
return parse_reenabled_issues(pr_body) + parse_reenabled_issues(commit_messages)
|
return parse_reenabled_issues(pr_body) + parse_reenabled_issues(commit_messages)
|
||||||
|
|
||||||
|
|
||||||
def check_for_setting(labels: Set[str], body: str, setting: str) -> bool:
|
def check_for_setting(labels: set[str], body: str, setting: str) -> bool:
|
||||||
return setting in labels or f"[{setting}]" in body
|
return setting in labels or f"[{setting}]" in body
|
||||||
|
|
||||||
|
|
||||||
def perform_misc_tasks(
|
def perform_misc_tasks(
|
||||||
labels: Set[str], test_matrix: Dict[str, List[Any]], job_name: str, pr_body: str
|
labels: set[str], test_matrix: dict[str, list[Any]], job_name: str, pr_body: str
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
In addition to apply the filter logic, the script also does the following
|
In addition to apply the filter logic, the script also does the following
|
||||||
|
|
|
||||||
22
.github/scripts/generate_binary_build_matrix.py
vendored
22
.github/scripts/generate_binary_build_matrix.py
vendored
|
|
@ -12,7 +12,7 @@ architectures:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
# NOTE: Also update the CUDA sources in tools/nightly.py when changing this list
|
# NOTE: Also update the CUDA sources in tools/nightly.py when changing this list
|
||||||
|
|
@ -181,7 +181,7 @@ CXX11_ABI = "cxx11-abi"
|
||||||
RELEASE = "release"
|
RELEASE = "release"
|
||||||
DEBUG = "debug"
|
DEBUG = "debug"
|
||||||
|
|
||||||
LIBTORCH_CONTAINER_IMAGES: Dict[Tuple[str, str], str] = {
|
LIBTORCH_CONTAINER_IMAGES: dict[tuple[str, str], str] = {
|
||||||
**{
|
**{
|
||||||
(
|
(
|
||||||
gpu_arch,
|
gpu_arch,
|
||||||
|
|
@ -223,16 +223,16 @@ def translate_desired_cuda(gpu_arch_type: str, gpu_arch_version: str) -> str:
|
||||||
}.get(gpu_arch_type, gpu_arch_version)
|
}.get(gpu_arch_type, gpu_arch_version)
|
||||||
|
|
||||||
|
|
||||||
def list_without(in_list: List[str], without: List[str]) -> List[str]:
|
def list_without(in_list: list[str], without: list[str]) -> list[str]:
|
||||||
return [item for item in in_list if item not in without]
|
return [item for item in in_list if item not in without]
|
||||||
|
|
||||||
|
|
||||||
def generate_libtorch_matrix(
|
def generate_libtorch_matrix(
|
||||||
os: str,
|
os: str,
|
||||||
abi_version: str,
|
abi_version: str,
|
||||||
arches: Optional[List[str]] = None,
|
arches: Optional[list[str]] = None,
|
||||||
libtorch_variants: Optional[List[str]] = None,
|
libtorch_variants: Optional[list[str]] = None,
|
||||||
) -> List[Dict[str, str]]:
|
) -> list[dict[str, str]]:
|
||||||
if arches is None:
|
if arches is None:
|
||||||
arches = ["cpu"]
|
arches = ["cpu"]
|
||||||
if os == "linux":
|
if os == "linux":
|
||||||
|
|
@ -248,7 +248,7 @@ def generate_libtorch_matrix(
|
||||||
"static-without-deps",
|
"static-without-deps",
|
||||||
]
|
]
|
||||||
|
|
||||||
ret: List[Dict[str, str]] = []
|
ret: list[dict[str, str]] = []
|
||||||
for arch_version in arches:
|
for arch_version in arches:
|
||||||
for libtorch_variant in libtorch_variants:
|
for libtorch_variant in libtorch_variants:
|
||||||
# one of the values in the following list must be exactly
|
# one of the values in the following list must be exactly
|
||||||
|
|
@ -287,10 +287,10 @@ def generate_libtorch_matrix(
|
||||||
|
|
||||||
def generate_wheels_matrix(
|
def generate_wheels_matrix(
|
||||||
os: str,
|
os: str,
|
||||||
arches: Optional[List[str]] = None,
|
arches: Optional[list[str]] = None,
|
||||||
python_versions: Optional[List[str]] = None,
|
python_versions: Optional[list[str]] = None,
|
||||||
use_split_build: bool = False,
|
use_split_build: bool = False,
|
||||||
) -> List[Dict[str, str]]:
|
) -> list[dict[str, str]]:
|
||||||
package_type = "wheel"
|
package_type = "wheel"
|
||||||
if os == "linux" or os == "linux-aarch64" or os == "linux-s390x":
|
if os == "linux" or os == "linux-aarch64" or os == "linux-s390x":
|
||||||
# NOTE: We only build manywheel packages for x86_64 and aarch64 and s390x linux
|
# NOTE: We only build manywheel packages for x86_64 and aarch64 and s390x linux
|
||||||
|
|
@ -315,7 +315,7 @@ def generate_wheels_matrix(
|
||||||
# uses different build/test scripts
|
# uses different build/test scripts
|
||||||
arches = ["cpu-s390x"]
|
arches = ["cpu-s390x"]
|
||||||
|
|
||||||
ret: List[Dict[str, str]] = []
|
ret: list[dict[str, str]] = []
|
||||||
for python_version in python_versions:
|
for python_version in python_versions:
|
||||||
for arch_version in arches:
|
for arch_version in arches:
|
||||||
gpu_arch_type = arch_type(arch_version)
|
gpu_arch_type = arch_type(arch_version)
|
||||||
|
|
|
||||||
7
.github/scripts/generate_ci_workflows.py
vendored
7
.github/scripts/generate_ci_workflows.py
vendored
|
|
@ -2,9 +2,10 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
from collections.abc import Iterable
|
||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Iterable, List, Literal, Set
|
from typing import Literal
|
||||||
from typing_extensions import TypedDict # Python 3.11+
|
from typing_extensions import TypedDict # Python 3.11+
|
||||||
|
|
||||||
import generate_binary_build_matrix # type: ignore[import]
|
import generate_binary_build_matrix # type: ignore[import]
|
||||||
|
|
@ -27,7 +28,7 @@ LABEL_CIFLOW_BINARIES_WHEEL = "ciflow/binaries_wheel"
|
||||||
class CIFlowConfig:
|
class CIFlowConfig:
|
||||||
# For use to enable workflows to run on pytorch/pytorch-canary
|
# For use to enable workflows to run on pytorch/pytorch-canary
|
||||||
run_on_canary: bool = False
|
run_on_canary: bool = False
|
||||||
labels: Set[str] = field(default_factory=set)
|
labels: set[str] = field(default_factory=set)
|
||||||
# Certain jobs might not want to be part of the ciflow/[all,trunk] workflow
|
# Certain jobs might not want to be part of the ciflow/[all,trunk] workflow
|
||||||
isolated_workflow: bool = False
|
isolated_workflow: bool = False
|
||||||
unstable: bool = False
|
unstable: bool = False
|
||||||
|
|
@ -48,7 +49,7 @@ class Config(TypedDict):
|
||||||
@dataclass
|
@dataclass
|
||||||
class BinaryBuildWorkflow:
|
class BinaryBuildWorkflow:
|
||||||
os: str
|
os: str
|
||||||
build_configs: List[Dict[str, str]]
|
build_configs: list[dict[str, str]]
|
||||||
package_type: str
|
package_type: str
|
||||||
|
|
||||||
# Optional fields
|
# Optional fields
|
||||||
|
|
|
||||||
10
.github/scripts/get_workflow_job_id.py
vendored
10
.github/scripts/get_workflow_job_id.py
vendored
|
|
@ -11,11 +11,11 @@ import sys
|
||||||
import time
|
import time
|
||||||
import urllib
|
import urllib
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
from typing import Any, Callable, Optional
|
||||||
from urllib.request import Request, urlopen
|
from urllib.request import Request, urlopen
|
||||||
|
|
||||||
|
|
||||||
def parse_json_and_links(conn: Any) -> Tuple[Any, Dict[str, Dict[str, str]]]:
|
def parse_json_and_links(conn: Any) -> tuple[Any, dict[str, dict[str, str]]]:
|
||||||
links = {}
|
links = {}
|
||||||
# Extract links which GH uses for pagination
|
# Extract links which GH uses for pagination
|
||||||
# see https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Link
|
# see https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Link
|
||||||
|
|
@ -42,7 +42,7 @@ def parse_json_and_links(conn: Any) -> Tuple[Any, Dict[str, Dict[str, str]]]:
|
||||||
def fetch_url(
|
def fetch_url(
|
||||||
url: str,
|
url: str,
|
||||||
*,
|
*,
|
||||||
headers: Optional[Dict[str, str]] = None,
|
headers: Optional[dict[str, str]] = None,
|
||||||
reader: Callable[[Any], Any] = lambda x: x.read(),
|
reader: Callable[[Any], Any] = lambda x: x.read(),
|
||||||
retries: Optional[int] = 3,
|
retries: Optional[int] = 3,
|
||||||
backoff_timeout: float = 0.5,
|
backoff_timeout: float = 0.5,
|
||||||
|
|
@ -83,7 +83,7 @@ def parse_args() -> Any:
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def fetch_jobs(url: str, headers: Dict[str, str]) -> List[Dict[str, str]]:
|
def fetch_jobs(url: str, headers: dict[str, str]) -> list[dict[str, str]]:
|
||||||
response, links = fetch_url(url, headers=headers, reader=parse_json_and_links)
|
response, links = fetch_url(url, headers=headers, reader=parse_json_and_links)
|
||||||
jobs = response["jobs"]
|
jobs = response["jobs"]
|
||||||
assert type(jobs) is list
|
assert type(jobs) is list
|
||||||
|
|
@ -111,7 +111,7 @@ def fetch_jobs(url: str, headers: Dict[str, str]) -> List[Dict[str, str]]:
|
||||||
# running.
|
# running.
|
||||||
|
|
||||||
|
|
||||||
def find_job_id_name(args: Any) -> Tuple[str, str]:
|
def find_job_id_name(args: Any) -> tuple[str, str]:
|
||||||
# From https://docs.github.com/en/actions/learn-github-actions/environment-variables
|
# From https://docs.github.com/en/actions/learn-github-actions/environment-variables
|
||||||
PYTORCH_REPO = os.environ.get("GITHUB_REPOSITORY", "pytorch/pytorch")
|
PYTORCH_REPO = os.environ.get("GITHUB_REPOSITORY", "pytorch/pytorch")
|
||||||
PYTORCH_GITHUB_API = f"https://api.github.com/repos/{PYTORCH_REPO}"
|
PYTORCH_GITHUB_API = f"https://api.github.com/repos/{PYTORCH_REPO}"
|
||||||
|
|
|
||||||
54
.github/scripts/github_utils.py
vendored
54
.github/scripts/github_utils.py
vendored
|
|
@ -4,7 +4,7 @@ import json
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
|
from typing import Any, Callable, cast, Optional, Union
|
||||||
from urllib.error import HTTPError
|
from urllib.error import HTTPError
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
from urllib.request import Request, urlopen
|
from urllib.request import Request, urlopen
|
||||||
|
|
@ -27,11 +27,11 @@ class GitHubComment:
|
||||||
def gh_fetch_url_and_headers(
|
def gh_fetch_url_and_headers(
|
||||||
url: str,
|
url: str,
|
||||||
*,
|
*,
|
||||||
headers: Optional[Dict[str, str]] = None,
|
headers: Optional[dict[str, str]] = None,
|
||||||
data: Union[Optional[Dict[str, Any]], str] = None,
|
data: Union[Optional[dict[str, Any]], str] = None,
|
||||||
method: Optional[str] = None,
|
method: Optional[str] = None,
|
||||||
reader: Callable[[Any], Any] = lambda x: x.read(),
|
reader: Callable[[Any], Any] = lambda x: x.read(),
|
||||||
) -> Tuple[Any, Any]:
|
) -> tuple[Any, Any]:
|
||||||
if headers is None:
|
if headers is None:
|
||||||
headers = {}
|
headers = {}
|
||||||
token = os.environ.get("GITHUB_TOKEN")
|
token = os.environ.get("GITHUB_TOKEN")
|
||||||
|
|
@ -70,8 +70,8 @@ def gh_fetch_url_and_headers(
|
||||||
def gh_fetch_url(
|
def gh_fetch_url(
|
||||||
url: str,
|
url: str,
|
||||||
*,
|
*,
|
||||||
headers: Optional[Dict[str, str]] = None,
|
headers: Optional[dict[str, str]] = None,
|
||||||
data: Union[Optional[Dict[str, Any]], str] = None,
|
data: Union[Optional[dict[str, Any]], str] = None,
|
||||||
method: Optional[str] = None,
|
method: Optional[str] = None,
|
||||||
reader: Callable[[Any], Any] = json.load,
|
reader: Callable[[Any], Any] = json.load,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
|
|
@ -82,25 +82,25 @@ def gh_fetch_url(
|
||||||
|
|
||||||
def gh_fetch_json(
|
def gh_fetch_json(
|
||||||
url: str,
|
url: str,
|
||||||
params: Optional[Dict[str, Any]] = None,
|
params: Optional[dict[str, Any]] = None,
|
||||||
data: Optional[Dict[str, Any]] = None,
|
data: Optional[dict[str, Any]] = None,
|
||||||
method: Optional[str] = None,
|
method: Optional[str] = None,
|
||||||
) -> List[Dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
headers = {"Accept": "application/vnd.github.v3+json"}
|
headers = {"Accept": "application/vnd.github.v3+json"}
|
||||||
if params is not None and len(params) > 0:
|
if params is not None and len(params) > 0:
|
||||||
url += "?" + "&".join(
|
url += "?" + "&".join(
|
||||||
f"{name}={quote(str(val))}" for name, val in params.items()
|
f"{name}={quote(str(val))}" for name, val in params.items()
|
||||||
)
|
)
|
||||||
return cast(
|
return cast(
|
||||||
List[Dict[str, Any]],
|
list[dict[str, Any]],
|
||||||
gh_fetch_url(url, headers=headers, data=data, reader=json.load, method=method),
|
gh_fetch_url(url, headers=headers, data=data, reader=json.load, method=method),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _gh_fetch_json_any(
|
def _gh_fetch_json_any(
|
||||||
url: str,
|
url: str,
|
||||||
params: Optional[Dict[str, Any]] = None,
|
params: Optional[dict[str, Any]] = None,
|
||||||
data: Optional[Dict[str, Any]] = None,
|
data: Optional[dict[str, Any]] = None,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
headers = {"Accept": "application/vnd.github.v3+json"}
|
headers = {"Accept": "application/vnd.github.v3+json"}
|
||||||
if params is not None and len(params) > 0:
|
if params is not None and len(params) > 0:
|
||||||
|
|
@ -112,21 +112,21 @@ def _gh_fetch_json_any(
|
||||||
|
|
||||||
def gh_fetch_json_list(
|
def gh_fetch_json_list(
|
||||||
url: str,
|
url: str,
|
||||||
params: Optional[Dict[str, Any]] = None,
|
params: Optional[dict[str, Any]] = None,
|
||||||
data: Optional[Dict[str, Any]] = None,
|
data: Optional[dict[str, Any]] = None,
|
||||||
) -> List[Dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
return cast(List[Dict[str, Any]], _gh_fetch_json_any(url, params, data))
|
return cast(list[dict[str, Any]], _gh_fetch_json_any(url, params, data))
|
||||||
|
|
||||||
|
|
||||||
def gh_fetch_json_dict(
|
def gh_fetch_json_dict(
|
||||||
url: str,
|
url: str,
|
||||||
params: Optional[Dict[str, Any]] = None,
|
params: Optional[dict[str, Any]] = None,
|
||||||
data: Optional[Dict[str, Any]] = None,
|
data: Optional[dict[str, Any]] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
return cast(Dict[str, Any], _gh_fetch_json_any(url, params, data))
|
return cast(dict[str, Any], _gh_fetch_json_any(url, params, data))
|
||||||
|
|
||||||
|
|
||||||
def gh_graphql(query: str, **kwargs: Any) -> Dict[str, Any]:
|
def gh_graphql(query: str, **kwargs: Any) -> dict[str, Any]:
|
||||||
rc = gh_fetch_url(
|
rc = gh_fetch_url(
|
||||||
"https://api.github.com/graphql",
|
"https://api.github.com/graphql",
|
||||||
data={"query": query, "variables": kwargs},
|
data={"query": query, "variables": kwargs},
|
||||||
|
|
@ -136,12 +136,12 @@ def gh_graphql(query: str, **kwargs: Any) -> Dict[str, Any]:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"GraphQL query {query}, args {kwargs} failed: {rc['errors']}"
|
f"GraphQL query {query}, args {kwargs} failed: {rc['errors']}"
|
||||||
)
|
)
|
||||||
return cast(Dict[str, Any], rc)
|
return cast(dict[str, Any], rc)
|
||||||
|
|
||||||
|
|
||||||
def _gh_post_comment(
|
def _gh_post_comment(
|
||||||
url: str, comment: str, dry_run: bool = False
|
url: str, comment: str, dry_run: bool = False
|
||||||
) -> List[Dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
if dry_run:
|
if dry_run:
|
||||||
print(comment)
|
print(comment)
|
||||||
return []
|
return []
|
||||||
|
|
@ -150,7 +150,7 @@ def _gh_post_comment(
|
||||||
|
|
||||||
def gh_post_pr_comment(
|
def gh_post_pr_comment(
|
||||||
org: str, repo: str, pr_num: int, comment: str, dry_run: bool = False
|
org: str, repo: str, pr_num: int, comment: str, dry_run: bool = False
|
||||||
) -> List[Dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
return _gh_post_comment(
|
return _gh_post_comment(
|
||||||
f"{GITHUB_API_URL}/repos/{org}/{repo}/issues/{pr_num}/comments",
|
f"{GITHUB_API_URL}/repos/{org}/{repo}/issues/{pr_num}/comments",
|
||||||
comment,
|
comment,
|
||||||
|
|
@ -160,7 +160,7 @@ def gh_post_pr_comment(
|
||||||
|
|
||||||
def gh_post_commit_comment(
|
def gh_post_commit_comment(
|
||||||
org: str, repo: str, sha: str, comment: str, dry_run: bool = False
|
org: str, repo: str, sha: str, comment: str, dry_run: bool = False
|
||||||
) -> List[Dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
return _gh_post_comment(
|
return _gh_post_comment(
|
||||||
f"{GITHUB_API_URL}/repos/{org}/{repo}/commits/{sha}/comments",
|
f"{GITHUB_API_URL}/repos/{org}/{repo}/commits/{sha}/comments",
|
||||||
comment,
|
comment,
|
||||||
|
|
@ -220,8 +220,8 @@ def gh_update_pr_state(org: str, repo: str, pr_num: int, state: str = "open") ->
|
||||||
|
|
||||||
|
|
||||||
def gh_query_issues_by_labels(
|
def gh_query_issues_by_labels(
|
||||||
org: str, repo: str, labels: List[str], state: str = "open"
|
org: str, repo: str, labels: list[str], state: str = "open"
|
||||||
) -> List[Dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
url = f"{GITHUB_API_URL}/repos/{org}/{repo}/issues"
|
url = f"{GITHUB_API_URL}/repos/{org}/{repo}/issues"
|
||||||
return gh_fetch_json(
|
return gh_fetch_json(
|
||||||
url, method="GET", params={"labels": ",".join(labels), "state": state}
|
url, method="GET", params={"labels": ",".join(labels), "state": state}
|
||||||
|
|
|
||||||
42
.github/scripts/gitutils.py
vendored
42
.github/scripts/gitutils.py
vendored
|
|
@ -4,20 +4,10 @@ import os
|
||||||
import re
|
import re
|
||||||
import tempfile
|
import tempfile
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from collections.abc import Iterator
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import (
|
from typing import Any, Callable, cast, Optional, TypeVar, Union
|
||||||
Any,
|
|
||||||
Callable,
|
|
||||||
cast,
|
|
||||||
Dict,
|
|
||||||
Iterator,
|
|
||||||
List,
|
|
||||||
Optional,
|
|
||||||
Tuple,
|
|
||||||
TypeVar,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
@ -35,17 +25,17 @@ def get_git_repo_dir() -> str:
|
||||||
return os.getenv("GIT_REPO_DIR", str(Path(__file__).resolve().parents[2]))
|
return os.getenv("GIT_REPO_DIR", str(Path(__file__).resolve().parents[2]))
|
||||||
|
|
||||||
|
|
||||||
def fuzzy_list_to_dict(items: List[Tuple[str, str]]) -> Dict[str, List[str]]:
|
def fuzzy_list_to_dict(items: list[tuple[str, str]]) -> dict[str, list[str]]:
|
||||||
"""
|
"""
|
||||||
Converts list to dict preserving elements with duplicate keys
|
Converts list to dict preserving elements with duplicate keys
|
||||||
"""
|
"""
|
||||||
rc: Dict[str, List[str]] = defaultdict(list)
|
rc: dict[str, list[str]] = defaultdict(list)
|
||||||
for key, val in items:
|
for key, val in items:
|
||||||
rc[key].append(val)
|
rc[key].append(val)
|
||||||
return dict(rc)
|
return dict(rc)
|
||||||
|
|
||||||
|
|
||||||
def _check_output(items: List[str], encoding: str = "utf-8") -> str:
|
def _check_output(items: list[str], encoding: str = "utf-8") -> str:
|
||||||
from subprocess import CalledProcessError, check_output, STDOUT
|
from subprocess import CalledProcessError, check_output, STDOUT
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -95,7 +85,7 @@ class GitCommit:
|
||||||
return item in self.body or item in self.title
|
return item in self.body or item in self.title
|
||||||
|
|
||||||
|
|
||||||
def parse_fuller_format(lines: Union[str, List[str]]) -> GitCommit:
|
def parse_fuller_format(lines: Union[str, list[str]]) -> GitCommit:
|
||||||
"""
|
"""
|
||||||
Expect commit message generated using `--format=fuller --date=unix` format, i.e.:
|
Expect commit message generated using `--format=fuller --date=unix` format, i.e.:
|
||||||
commit <sha1>
|
commit <sha1>
|
||||||
|
|
@ -142,13 +132,13 @@ class GitRepo:
|
||||||
print(f"+ git -C {self.repo_dir} {' '.join(args)}")
|
print(f"+ git -C {self.repo_dir} {' '.join(args)}")
|
||||||
return _check_output(["git", "-C", self.repo_dir] + list(args))
|
return _check_output(["git", "-C", self.repo_dir] + list(args))
|
||||||
|
|
||||||
def revlist(self, revision_range: str) -> List[str]:
|
def revlist(self, revision_range: str) -> list[str]:
|
||||||
rc = self._run_git("rev-list", revision_range, "--", ".").strip()
|
rc = self._run_git("rev-list", revision_range, "--", ".").strip()
|
||||||
return rc.split("\n") if len(rc) > 0 else []
|
return rc.split("\n") if len(rc) > 0 else []
|
||||||
|
|
||||||
def branches_containing_ref(
|
def branches_containing_ref(
|
||||||
self, ref: str, *, include_remote: bool = True
|
self, ref: str, *, include_remote: bool = True
|
||||||
) -> List[str]:
|
) -> list[str]:
|
||||||
rc = (
|
rc = (
|
||||||
self._run_git("branch", "--remote", "--contains", ref)
|
self._run_git("branch", "--remote", "--contains", ref)
|
||||||
if include_remote
|
if include_remote
|
||||||
|
|
@ -189,7 +179,7 @@ class GitRepo:
|
||||||
def get_merge_base(self, from_ref: str, to_ref: str) -> str:
|
def get_merge_base(self, from_ref: str, to_ref: str) -> str:
|
||||||
return self._run_git("merge-base", from_ref, to_ref).strip()
|
return self._run_git("merge-base", from_ref, to_ref).strip()
|
||||||
|
|
||||||
def patch_id(self, ref: Union[str, List[str]]) -> List[Tuple[str, str]]:
|
def patch_id(self, ref: Union[str, list[str]]) -> list[tuple[str, str]]:
|
||||||
is_list = isinstance(ref, list)
|
is_list = isinstance(ref, list)
|
||||||
if is_list:
|
if is_list:
|
||||||
if len(ref) == 0:
|
if len(ref) == 0:
|
||||||
|
|
@ -198,9 +188,9 @@ class GitRepo:
|
||||||
rc = _check_output(
|
rc = _check_output(
|
||||||
["sh", "-c", f"git -C {self.repo_dir} show {ref}|git patch-id --stable"]
|
["sh", "-c", f"git -C {self.repo_dir} show {ref}|git patch-id --stable"]
|
||||||
).strip()
|
).strip()
|
||||||
return [cast(Tuple[str, str], x.split(" ", 1)) for x in rc.split("\n")]
|
return [cast(tuple[str, str], x.split(" ", 1)) for x in rc.split("\n")]
|
||||||
|
|
||||||
def commits_resolving_gh_pr(self, pr_num: int) -> List[str]:
|
def commits_resolving_gh_pr(self, pr_num: int) -> list[str]:
|
||||||
owner, name = self.gh_owner_and_name()
|
owner, name = self.gh_owner_and_name()
|
||||||
msg = f"Pull Request resolved: https://github.com/{owner}/{name}/pull/{pr_num}"
|
msg = f"Pull Request resolved: https://github.com/{owner}/{name}/pull/{pr_num}"
|
||||||
rc = self._run_git("log", "--format=%H", "--grep", msg).strip()
|
rc = self._run_git("log", "--format=%H", "--grep", msg).strip()
|
||||||
|
|
@ -219,7 +209,7 @@ class GitRepo:
|
||||||
|
|
||||||
def compute_branch_diffs(
|
def compute_branch_diffs(
|
||||||
self, from_branch: str, to_branch: str
|
self, from_branch: str, to_branch: str
|
||||||
) -> Tuple[List[str], List[str]]:
|
) -> tuple[list[str], list[str]]:
|
||||||
"""
|
"""
|
||||||
Returns list of commmits that are missing in each other branch since their merge base
|
Returns list of commmits that are missing in each other branch since their merge base
|
||||||
Might be slow if merge base is between two branches is pretty far off
|
Might be slow if merge base is between two branches is pretty far off
|
||||||
|
|
@ -311,14 +301,14 @@ class GitRepo:
|
||||||
def remote_url(self) -> str:
|
def remote_url(self) -> str:
|
||||||
return self._run_git("remote", "get-url", self.remote)
|
return self._run_git("remote", "get-url", self.remote)
|
||||||
|
|
||||||
def gh_owner_and_name(self) -> Tuple[str, str]:
|
def gh_owner_and_name(self) -> tuple[str, str]:
|
||||||
url = os.getenv("GIT_REMOTE_URL", None)
|
url = os.getenv("GIT_REMOTE_URL", None)
|
||||||
if url is None:
|
if url is None:
|
||||||
url = self.remote_url()
|
url = self.remote_url()
|
||||||
rc = RE_GITHUB_URL_MATCH.match(url)
|
rc = RE_GITHUB_URL_MATCH.match(url)
|
||||||
if rc is None:
|
if rc is None:
|
||||||
raise RuntimeError(f"Unexpected url format {url}")
|
raise RuntimeError(f"Unexpected url format {url}")
|
||||||
return cast(Tuple[str, str], rc.groups())
|
return cast(tuple[str, str], rc.groups())
|
||||||
|
|
||||||
def commit_message(self, ref: str) -> str:
|
def commit_message(self, ref: str) -> str:
|
||||||
return self._run_git("log", "-1", "--format=%B", ref)
|
return self._run_git("log", "-1", "--format=%B", ref)
|
||||||
|
|
@ -366,7 +356,7 @@ class PeekableIterator(Iterator[str]):
|
||||||
return rc
|
return rc
|
||||||
|
|
||||||
|
|
||||||
def patterns_to_regex(allowed_patterns: List[str]) -> Any:
|
def patterns_to_regex(allowed_patterns: list[str]) -> Any:
|
||||||
"""
|
"""
|
||||||
pattern is glob-like, i.e. the only special sequences it has are:
|
pattern is glob-like, i.e. the only special sequences it has are:
|
||||||
- ? - matches single character
|
- ? - matches single character
|
||||||
|
|
@ -437,7 +427,7 @@ def retries_decorator(
|
||||||
) -> Callable[[Callable[..., T]], Callable[..., T]]:
|
) -> Callable[[Callable[..., T]], Callable[..., T]]:
|
||||||
def decorator(f: Callable[..., T]) -> Callable[..., T]:
|
def decorator(f: Callable[..., T]) -> Callable[..., T]:
|
||||||
@wraps(f)
|
@wraps(f)
|
||||||
def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> T:
|
def wrapper(*args: list[Any], **kwargs: dict[str, Any]) -> T:
|
||||||
for idx in range(num_retries):
|
for idx in range(num_retries):
|
||||||
try:
|
try:
|
||||||
return f(*args, **kwargs)
|
return f(*args, **kwargs)
|
||||||
|
|
|
||||||
14
.github/scripts/label_utils.py
vendored
14
.github/scripts/label_utils.py
vendored
|
|
@ -2,7 +2,7 @@
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import Any, List, Tuple, TYPE_CHECKING, Union
|
from typing import Any, TYPE_CHECKING, Union
|
||||||
|
|
||||||
from github_utils import gh_fetch_url_and_headers, GitHubComment
|
from github_utils import gh_fetch_url_and_headers, GitHubComment
|
||||||
|
|
||||||
|
|
@ -28,14 +28,14 @@ https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def request_for_labels(url: str) -> Tuple[Any, Any]:
|
def request_for_labels(url: str) -> tuple[Any, Any]:
|
||||||
headers = {"Accept": "application/vnd.github.v3+json"}
|
headers = {"Accept": "application/vnd.github.v3+json"}
|
||||||
return gh_fetch_url_and_headers(
|
return gh_fetch_url_and_headers(
|
||||||
url, headers=headers, reader=lambda x: x.read().decode("utf-8")
|
url, headers=headers, reader=lambda x: x.read().decode("utf-8")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def update_labels(labels: List[str], info: str) -> None:
|
def update_labels(labels: list[str], info: str) -> None:
|
||||||
labels_json = json.loads(info)
|
labels_json = json.loads(info)
|
||||||
labels.extend([x["name"] for x in labels_json])
|
labels.extend([x["name"] for x in labels_json])
|
||||||
|
|
||||||
|
|
@ -56,10 +56,10 @@ def get_last_page_num_from_header(header: Any) -> int:
|
||||||
|
|
||||||
|
|
||||||
@lru_cache
|
@lru_cache
|
||||||
def gh_get_labels(org: str, repo: str) -> List[str]:
|
def gh_get_labels(org: str, repo: str) -> list[str]:
|
||||||
prefix = f"https://api.github.com/repos/{org}/{repo}/labels?per_page=100"
|
prefix = f"https://api.github.com/repos/{org}/{repo}/labels?per_page=100"
|
||||||
header, info = request_for_labels(prefix + "&page=1")
|
header, info = request_for_labels(prefix + "&page=1")
|
||||||
labels: List[str] = []
|
labels: list[str] = []
|
||||||
update_labels(labels, info)
|
update_labels(labels, info)
|
||||||
|
|
||||||
last_page = get_last_page_num_from_header(header)
|
last_page = get_last_page_num_from_header(header)
|
||||||
|
|
@ -74,7 +74,7 @@ def gh_get_labels(org: str, repo: str) -> List[str]:
|
||||||
|
|
||||||
|
|
||||||
def gh_add_labels(
|
def gh_add_labels(
|
||||||
org: str, repo: str, pr_num: int, labels: Union[str, List[str]], dry_run: bool
|
org: str, repo: str, pr_num: int, labels: Union[str, list[str]], dry_run: bool
|
||||||
) -> None:
|
) -> None:
|
||||||
if dry_run:
|
if dry_run:
|
||||||
print(f"Dryrun: Adding labels {labels} to PR {pr_num}")
|
print(f"Dryrun: Adding labels {labels} to PR {pr_num}")
|
||||||
|
|
@ -97,7 +97,7 @@ def gh_remove_label(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_release_notes_labels(org: str, repo: str) -> List[str]:
|
def get_release_notes_labels(org: str, repo: str) -> list[str]:
|
||||||
return [
|
return [
|
||||||
label
|
label
|
||||||
for label in gh_get_labels(org, repo)
|
for label in gh_get_labels(org, repo)
|
||||||
|
|
|
||||||
6
.github/scripts/pytest_caching_utils.py
vendored
6
.github/scripts/pytest_caching_utils.py
vendored
|
|
@ -1,7 +1,7 @@
|
||||||
import hashlib
|
import hashlib
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, NamedTuple
|
from typing import NamedTuple
|
||||||
|
|
||||||
from file_io_utils import (
|
from file_io_utils import (
|
||||||
copy_file,
|
copy_file,
|
||||||
|
|
@ -219,8 +219,8 @@ def _merge_lastfailed_files(source_pytest_cache: Path, dest_pytest_cache: Path)
|
||||||
|
|
||||||
|
|
||||||
def _merged_lastfailed_content(
|
def _merged_lastfailed_content(
|
||||||
from_lastfailed: Dict[str, bool], to_lastfailed: Dict[str, bool]
|
from_lastfailed: dict[str, bool], to_lastfailed: dict[str, bool]
|
||||||
) -> Dict[str, bool]:
|
) -> dict[str, bool]:
|
||||||
"""
|
"""
|
||||||
The lastfailed files are dictionaries where the key is the test identifier.
|
The lastfailed files are dictionaries where the key is the test identifier.
|
||||||
Each entry's value appears to always be `true`, but let's not count on that.
|
Each entry's value appears to always be `true`, but let's not count on that.
|
||||||
|
|
|
||||||
31
.github/scripts/runner_determinator.py
vendored
31
.github/scripts/runner_determinator.py
vendored
|
|
@ -61,9 +61,10 @@ import random
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from functools import lru_cache
|
from collections.abc import Iterable
|
||||||
|
from functools import cache
|
||||||
from logging import LogRecord
|
from logging import LogRecord
|
||||||
from typing import Any, Dict, FrozenSet, Iterable, List, NamedTuple, Set, Tuple
|
from typing import Any, NamedTuple
|
||||||
from urllib.request import Request, urlopen
|
from urllib.request import Request, urlopen
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
@ -105,7 +106,7 @@ class Settings(NamedTuple):
|
||||||
Settings for the experiments that can be opted into.
|
Settings for the experiments that can be opted into.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
experiments: Dict[str, Experiment] = {}
|
experiments: dict[str, Experiment] = {}
|
||||||
|
|
||||||
|
|
||||||
class ColorFormatter(logging.Formatter):
|
class ColorFormatter(logging.Formatter):
|
||||||
|
|
@ -150,7 +151,7 @@ def set_github_output(key: str, value: str) -> None:
|
||||||
f.write(f"{key}={value}\n")
|
f.write(f"{key}={value}\n")
|
||||||
|
|
||||||
|
|
||||||
def _str_comma_separated_to_set(value: str) -> FrozenSet[str]:
|
def _str_comma_separated_to_set(value: str) -> frozenset[str]:
|
||||||
return frozenset(
|
return frozenset(
|
||||||
filter(lambda itm: itm != "", map(str.strip, value.strip(" \n\t").split(",")))
|
filter(lambda itm: itm != "", map(str.strip, value.strip(" \n\t").split(",")))
|
||||||
)
|
)
|
||||||
|
|
@ -208,12 +209,12 @@ def parse_args() -> Any:
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def get_gh_client(github_token: str) -> Github:
|
def get_gh_client(github_token: str) -> Github: # type: ignore[no-any-unimported]
|
||||||
auth = Auth.Token(github_token)
|
auth = Auth.Token(github_token)
|
||||||
return Github(auth=auth)
|
return Github(auth=auth)
|
||||||
|
|
||||||
|
|
||||||
def get_issue(gh: Github, repo: str, issue_num: int) -> Issue:
|
def get_issue(gh: Github, repo: str, issue_num: int) -> Issue: # type: ignore[no-any-unimported]
|
||||||
repo = gh.get_repo(repo)
|
repo = gh.get_repo(repo)
|
||||||
return repo.get_issue(number=issue_num)
|
return repo.get_issue(number=issue_num)
|
||||||
|
|
||||||
|
|
@ -242,7 +243,7 @@ def get_potential_pr_author(
|
||||||
raise Exception( # noqa: TRY002
|
raise Exception( # noqa: TRY002
|
||||||
f"issue with pull request {pr_number} from repo {repository}"
|
f"issue with pull request {pr_number} from repo {repository}"
|
||||||
) from e
|
) from e
|
||||||
return pull.user.login
|
return pull.user.login # type: ignore[no-any-return]
|
||||||
# In all other cases, return the original input username
|
# In all other cases, return the original input username
|
||||||
return username
|
return username
|
||||||
|
|
||||||
|
|
@ -263,7 +264,7 @@ def load_yaml(yaml_text: str) -> Any:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
def extract_settings_user_opt_in_from_text(rollout_state: str) -> Tuple[str, str]:
|
def extract_settings_user_opt_in_from_text(rollout_state: str) -> tuple[str, str]:
|
||||||
"""
|
"""
|
||||||
Extracts the text with settings, if any, and the opted in users from the rollout state.
|
Extracts the text with settings, if any, and the opted in users from the rollout state.
|
||||||
|
|
||||||
|
|
@ -279,7 +280,7 @@ def extract_settings_user_opt_in_from_text(rollout_state: str) -> Tuple[str, str
|
||||||
return "", rollout_state
|
return "", rollout_state
|
||||||
|
|
||||||
|
|
||||||
class UserOptins(Dict[str, List[str]]):
|
class UserOptins(dict[str, list[str]]):
|
||||||
"""
|
"""
|
||||||
Dictionary of users with a list of features they have opted into
|
Dictionary of users with a list of features they have opted into
|
||||||
"""
|
"""
|
||||||
|
|
@ -420,7 +421,7 @@ def get_runner_prefix(
|
||||||
rollout_state: str,
|
rollout_state: str,
|
||||||
workflow_requestors: Iterable[str],
|
workflow_requestors: Iterable[str],
|
||||||
branch: str,
|
branch: str,
|
||||||
eligible_experiments: FrozenSet[str] = frozenset(),
|
eligible_experiments: frozenset[str] = frozenset(),
|
||||||
is_canary: bool = False,
|
is_canary: bool = False,
|
||||||
) -> str:
|
) -> str:
|
||||||
settings = parse_settings(rollout_state)
|
settings = parse_settings(rollout_state)
|
||||||
|
|
@ -519,7 +520,7 @@ def get_rollout_state_from_issue(github_token: str, repo: str, issue_num: int) -
|
||||||
return str(issue.get_comments()[0].body.strip("\n\t "))
|
return str(issue.get_comments()[0].body.strip("\n\t "))
|
||||||
|
|
||||||
|
|
||||||
def download_json(url: str, headers: Dict[str, str], num_retries: int = 3) -> Any:
|
def download_json(url: str, headers: dict[str, str], num_retries: int = 3) -> Any:
|
||||||
for _ in range(num_retries):
|
for _ in range(num_retries):
|
||||||
try:
|
try:
|
||||||
req = Request(url=url, headers=headers)
|
req = Request(url=url, headers=headers)
|
||||||
|
|
@ -532,8 +533,8 @@ def download_json(url: str, headers: Dict[str, str], num_retries: int = 3) -> An
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=None)
|
@cache
|
||||||
def get_pr_info(github_repo: str, github_token: str, pr_number: int) -> Dict[str, Any]:
|
def get_pr_info(github_repo: str, github_token: str, pr_number: int) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Dynamically get PR information
|
Dynamically get PR information
|
||||||
"""
|
"""
|
||||||
|
|
@ -542,7 +543,7 @@ def get_pr_info(github_repo: str, github_token: str, pr_number: int) -> Dict[str
|
||||||
"Accept": "application/vnd.github.v3+json",
|
"Accept": "application/vnd.github.v3+json",
|
||||||
"Authorization": f"token {github_token}",
|
"Authorization": f"token {github_token}",
|
||||||
}
|
}
|
||||||
json_response: Dict[str, Any] = download_json(
|
json_response: dict[str, Any] = download_json(
|
||||||
url=f"{github_api}/issues/{pr_number}",
|
url=f"{github_api}/issues/{pr_number}",
|
||||||
headers=headers,
|
headers=headers,
|
||||||
)
|
)
|
||||||
|
|
@ -554,7 +555,7 @@ def get_pr_info(github_repo: str, github_token: str, pr_number: int) -> Dict[str
|
||||||
return json_response
|
return json_response
|
||||||
|
|
||||||
|
|
||||||
def get_labels(github_repo: str, github_token: str, pr_number: int) -> Set[str]:
|
def get_labels(github_repo: str, github_token: str, pr_number: int) -> set[str]:
|
||||||
"""
|
"""
|
||||||
Dynamically get the latest list of labels from the pull request
|
Dynamically get the latest list of labels from the pull request
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
import argparse
|
import argparse
|
||||||
import subprocess
|
import subprocess
|
||||||
from typing import Dict
|
|
||||||
|
|
||||||
import generate_binary_build_matrix
|
import generate_binary_build_matrix
|
||||||
|
|
||||||
|
|
@ -10,7 +9,7 @@ def tag_image(
|
||||||
default_tag: str,
|
default_tag: str,
|
||||||
release_version: str,
|
release_version: str,
|
||||||
dry_run: str,
|
dry_run: str,
|
||||||
tagged_images: Dict[str, bool],
|
tagged_images: dict[str, bool],
|
||||||
) -> None:
|
) -> None:
|
||||||
if image in tagged_images:
|
if image in tagged_images:
|
||||||
return
|
return
|
||||||
|
|
@ -41,7 +40,7 @@ def main() -> None:
|
||||||
)
|
)
|
||||||
|
|
||||||
options = parser.parse_args()
|
options = parser.parse_args()
|
||||||
tagged_images: Dict[str, bool] = {}
|
tagged_images: dict[str, bool] = {}
|
||||||
platform_images = [
|
platform_images = [
|
||||||
generate_binary_build_matrix.WHEEL_CONTAINER_IMAGES,
|
generate_binary_build_matrix.WHEEL_CONTAINER_IMAGES,
|
||||||
generate_binary_build_matrix.LIBTORCH_CONTAINER_IMAGES,
|
generate_binary_build_matrix.LIBTORCH_CONTAINER_IMAGES,
|
||||||
|
|
|
||||||
4
.github/scripts/test_check_labels.py
vendored
4
.github/scripts/test_check_labels.py
vendored
|
|
@ -1,6 +1,6 @@
|
||||||
"""test_check_labels.py"""
|
"""test_check_labels.py"""
|
||||||
|
|
||||||
from typing import Any, List
|
from typing import Any
|
||||||
from unittest import main, mock, TestCase
|
from unittest import main, mock, TestCase
|
||||||
|
|
||||||
from check_labels import (
|
from check_labels import (
|
||||||
|
|
@ -31,7 +31,7 @@ def mock_delete_all_label_err_comments(pr: "GitHubPR") -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def mock_get_comments() -> List[GitHubComment]:
|
def mock_get_comments() -> list[GitHubComment]:
|
||||||
return [
|
return [
|
||||||
# Case 1 - a non label err comment
|
# Case 1 - a non label err comment
|
||||||
GitHubComment(
|
GitHubComment(
|
||||||
|
|
|
||||||
6
.github/scripts/test_filter_test_configs.py
vendored
6
.github/scripts/test_filter_test_configs.py
vendored
|
|
@ -3,7 +3,7 @@
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
from typing import Any, Dict, List
|
from typing import Any
|
||||||
from unittest import main, mock, TestCase
|
from unittest import main, mock, TestCase
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
@ -362,7 +362,7 @@ class TestConfigFilter(TestCase):
|
||||||
self.assertEqual(case["expected"], json.dumps(filtered_test_matrix))
|
self.assertEqual(case["expected"], json.dumps(filtered_test_matrix))
|
||||||
|
|
||||||
def test_set_periodic_modes(self) -> None:
|
def test_set_periodic_modes(self) -> None:
|
||||||
testcases: List[Dict[str, str]] = [
|
testcases: list[dict[str, str]] = [
|
||||||
{
|
{
|
||||||
"job_name": "a CI job",
|
"job_name": "a CI job",
|
||||||
"test_matrix": "{include: []}",
|
"test_matrix": "{include: []}",
|
||||||
|
|
@ -702,7 +702,7 @@ class TestConfigFilter(TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
mocked_subprocess.return_value = b""
|
mocked_subprocess.return_value = b""
|
||||||
testcases: List[Dict[str, Any]] = [
|
testcases: list[dict[str, Any]] = [
|
||||||
{
|
{
|
||||||
"labels": {},
|
"labels": {},
|
||||||
"test_matrix": '{include: [{config: "default"}]}',
|
"test_matrix": '{include: [{config: "default"}]}',
|
||||||
|
|
|
||||||
14
.github/scripts/test_trymerge.py
vendored
14
.github/scripts/test_trymerge.py
vendored
|
|
@ -12,7 +12,7 @@ import json
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
from typing import Any, List, Optional
|
from typing import Any, Optional
|
||||||
from unittest import main, mock, skip, TestCase
|
from unittest import main, mock, skip, TestCase
|
||||||
from urllib.error import HTTPError
|
from urllib.error import HTTPError
|
||||||
|
|
||||||
|
|
@ -170,7 +170,7 @@ def mock_gh_get_info() -> Any:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def mocked_read_merge_rules_NE(repo: Any, org: str, project: str) -> List[MergeRule]:
|
def mocked_read_merge_rules_NE(repo: Any, org: str, project: str) -> list[MergeRule]:
|
||||||
return [
|
return [
|
||||||
MergeRule(
|
MergeRule(
|
||||||
name="mock with nonexistent check",
|
name="mock with nonexistent check",
|
||||||
|
|
@ -182,7 +182,7 @@ def mocked_read_merge_rules_NE(repo: Any, org: str, project: str) -> List[MergeR
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def mocked_read_merge_rules(repo: Any, org: str, project: str) -> List[MergeRule]:
|
def mocked_read_merge_rules(repo: Any, org: str, project: str) -> list[MergeRule]:
|
||||||
return [
|
return [
|
||||||
MergeRule(
|
MergeRule(
|
||||||
name="super",
|
name="super",
|
||||||
|
|
@ -211,7 +211,7 @@ def mocked_read_merge_rules(repo: Any, org: str, project: str) -> List[MergeRule
|
||||||
|
|
||||||
def mocked_read_merge_rules_approvers(
|
def mocked_read_merge_rules_approvers(
|
||||||
repo: Any, org: str, project: str
|
repo: Any, org: str, project: str
|
||||||
) -> List[MergeRule]:
|
) -> list[MergeRule]:
|
||||||
return [
|
return [
|
||||||
MergeRule(
|
MergeRule(
|
||||||
name="Core Reviewers",
|
name="Core Reviewers",
|
||||||
|
|
@ -234,11 +234,11 @@ def mocked_read_merge_rules_approvers(
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def mocked_read_merge_rules_raise(repo: Any, org: str, project: str) -> List[MergeRule]:
|
def mocked_read_merge_rules_raise(repo: Any, org: str, project: str) -> list[MergeRule]:
|
||||||
raise RuntimeError("testing")
|
raise RuntimeError("testing")
|
||||||
|
|
||||||
|
|
||||||
def xla_merge_rules(repo: Any, org: str, project: str) -> List[MergeRule]:
|
def xla_merge_rules(repo: Any, org: str, project: str) -> list[MergeRule]:
|
||||||
return [
|
return [
|
||||||
MergeRule(
|
MergeRule(
|
||||||
name=" OSS CI / pytorchbot / XLA",
|
name=" OSS CI / pytorchbot / XLA",
|
||||||
|
|
@ -260,7 +260,7 @@ class DummyGitRepo(GitRepo):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__(get_git_repo_dir(), get_git_remote_name())
|
super().__init__(get_git_repo_dir(), get_git_remote_name())
|
||||||
|
|
||||||
def commits_resolving_gh_pr(self, pr_num: int) -> List[str]:
|
def commits_resolving_gh_pr(self, pr_num: int) -> list[str]:
|
||||||
return ["FakeCommitSha"]
|
return ["FakeCommitSha"]
|
||||||
|
|
||||||
def commit_message(self, ref: str) -> str:
|
def commit_message(self, ref: str) -> str:
|
||||||
|
|
|
||||||
157
.github/scripts/trymerge.py
vendored
157
.github/scripts/trymerge.py
vendored
|
|
@ -17,21 +17,12 @@ import re
|
||||||
import time
|
import time
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from collections.abc import Iterable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import lru_cache
|
from functools import cache
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import (
|
from re import Pattern
|
||||||
Any,
|
from typing import Any, Callable, cast, NamedTuple, Optional
|
||||||
Callable,
|
|
||||||
cast,
|
|
||||||
Dict,
|
|
||||||
Iterable,
|
|
||||||
List,
|
|
||||||
NamedTuple,
|
|
||||||
Optional,
|
|
||||||
Pattern,
|
|
||||||
Tuple,
|
|
||||||
)
|
|
||||||
from warnings import warn
|
from warnings import warn
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
@ -78,7 +69,7 @@ class JobCheckState(NamedTuple):
|
||||||
summary: Optional[str]
|
summary: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
JobNameToStateDict = Dict[str, JobCheckState]
|
JobNameToStateDict = dict[str, JobCheckState]
|
||||||
|
|
||||||
|
|
||||||
class WorkflowCheckState:
|
class WorkflowCheckState:
|
||||||
|
|
@ -468,10 +459,10 @@ def gh_get_pr_info(org: str, proj: str, pr_no: int) -> Any:
|
||||||
return rc["data"]["repository"]["pullRequest"]
|
return rc["data"]["repository"]["pullRequest"]
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=None)
|
@cache
|
||||||
def gh_get_team_members(org: str, name: str) -> List[str]:
|
def gh_get_team_members(org: str, name: str) -> list[str]:
|
||||||
rc: List[str] = []
|
rc: list[str] = []
|
||||||
team_members: Dict[str, Any] = {
|
team_members: dict[str, Any] = {
|
||||||
"pageInfo": {"hasNextPage": "true", "endCursor": None}
|
"pageInfo": {"hasNextPage": "true", "endCursor": None}
|
||||||
}
|
}
|
||||||
while bool(team_members["pageInfo"]["hasNextPage"]):
|
while bool(team_members["pageInfo"]["hasNextPage"]):
|
||||||
|
|
@ -503,14 +494,14 @@ def is_passing_status(status: Optional[str]) -> bool:
|
||||||
|
|
||||||
def add_workflow_conclusions(
|
def add_workflow_conclusions(
|
||||||
checksuites: Any,
|
checksuites: Any,
|
||||||
get_next_checkruns_page: Callable[[List[Dict[str, Dict[str, Any]]], int, Any], Any],
|
get_next_checkruns_page: Callable[[list[dict[str, dict[str, Any]]], int, Any], Any],
|
||||||
get_next_checksuites: Callable[[Any], Any],
|
get_next_checksuites: Callable[[Any], Any],
|
||||||
) -> JobNameToStateDict:
|
) -> JobNameToStateDict:
|
||||||
# graphql seems to favor the most recent workflow run, so in theory we
|
# graphql seems to favor the most recent workflow run, so in theory we
|
||||||
# shouldn't need to account for reruns, but do it just in case
|
# shouldn't need to account for reruns, but do it just in case
|
||||||
|
|
||||||
# workflow -> job -> job info
|
# workflow -> job -> job info
|
||||||
workflows: Dict[str, WorkflowCheckState] = {}
|
workflows: dict[str, WorkflowCheckState] = {}
|
||||||
|
|
||||||
# for the jobs that don't have a workflow
|
# for the jobs that don't have a workflow
|
||||||
no_workflow_obj: WorkflowCheckState = WorkflowCheckState("", "", 0, None)
|
no_workflow_obj: WorkflowCheckState = WorkflowCheckState("", "", 0, None)
|
||||||
|
|
@ -633,8 +624,8 @@ def _revlist_to_prs(
|
||||||
pr: "GitHubPR",
|
pr: "GitHubPR",
|
||||||
rev_list: Iterable[str],
|
rev_list: Iterable[str],
|
||||||
should_skip: Optional[Callable[[int, "GitHubPR"], bool]] = None,
|
should_skip: Optional[Callable[[int, "GitHubPR"], bool]] = None,
|
||||||
) -> List[Tuple["GitHubPR", str]]:
|
) -> list[tuple["GitHubPR", str]]:
|
||||||
rc: List[Tuple[GitHubPR, str]] = []
|
rc: list[tuple[GitHubPR, str]] = []
|
||||||
for idx, rev in enumerate(rev_list):
|
for idx, rev in enumerate(rev_list):
|
||||||
msg = repo.commit_message(rev)
|
msg = repo.commit_message(rev)
|
||||||
m = RE_PULL_REQUEST_RESOLVED.search(msg)
|
m = RE_PULL_REQUEST_RESOLVED.search(msg)
|
||||||
|
|
@ -656,7 +647,7 @@ def _revlist_to_prs(
|
||||||
|
|
||||||
def get_ghstack_prs(
|
def get_ghstack_prs(
|
||||||
repo: GitRepo, pr: "GitHubPR", open_only: bool = True
|
repo: GitRepo, pr: "GitHubPR", open_only: bool = True
|
||||||
) -> List[Tuple["GitHubPR", str]]:
|
) -> list[tuple["GitHubPR", str]]:
|
||||||
"""
|
"""
|
||||||
Get the PRs in the stack that are below this PR (inclusive). Throws error if any of the open PRs are out of sync.
|
Get the PRs in the stack that are below this PR (inclusive). Throws error if any of the open PRs are out of sync.
|
||||||
@:param open_only: Only return open PRs
|
@:param open_only: Only return open PRs
|
||||||
|
|
@ -701,14 +692,14 @@ class GitHubPR:
|
||||||
self.project = project
|
self.project = project
|
||||||
self.pr_num = pr_num
|
self.pr_num = pr_num
|
||||||
self.info = gh_get_pr_info(org, project, pr_num)
|
self.info = gh_get_pr_info(org, project, pr_num)
|
||||||
self.changed_files: Optional[List[str]] = None
|
self.changed_files: Optional[list[str]] = None
|
||||||
self.labels: Optional[List[str]] = None
|
self.labels: Optional[list[str]] = None
|
||||||
self.conclusions: Optional[JobNameToStateDict] = None
|
self.conclusions: Optional[JobNameToStateDict] = None
|
||||||
self.comments: Optional[List[GitHubComment]] = None
|
self.comments: Optional[list[GitHubComment]] = None
|
||||||
self._authors: Optional[List[Tuple[str, str]]] = None
|
self._authors: Optional[list[tuple[str, str]]] = None
|
||||||
self._reviews: Optional[List[Tuple[str, str]]] = None
|
self._reviews: Optional[list[tuple[str, str]]] = None
|
||||||
self.merge_base: Optional[str] = None
|
self.merge_base: Optional[str] = None
|
||||||
self.submodules: Optional[List[str]] = None
|
self.submodules: Optional[list[str]] = None
|
||||||
|
|
||||||
def is_closed(self) -> bool:
|
def is_closed(self) -> bool:
|
||||||
return bool(self.info["closed"])
|
return bool(self.info["closed"])
|
||||||
|
|
@ -763,7 +754,7 @@ class GitHubPR:
|
||||||
|
|
||||||
return self.merge_base
|
return self.merge_base
|
||||||
|
|
||||||
def get_changed_files(self) -> List[str]:
|
def get_changed_files(self) -> list[str]:
|
||||||
if self.changed_files is None:
|
if self.changed_files is None:
|
||||||
info = self.info
|
info = self.info
|
||||||
unique_changed_files = set()
|
unique_changed_files = set()
|
||||||
|
|
@ -786,14 +777,14 @@ class GitHubPR:
|
||||||
raise RuntimeError("Changed file count mismatch")
|
raise RuntimeError("Changed file count mismatch")
|
||||||
return self.changed_files
|
return self.changed_files
|
||||||
|
|
||||||
def get_submodules(self) -> List[str]:
|
def get_submodules(self) -> list[str]:
|
||||||
if self.submodules is None:
|
if self.submodules is None:
|
||||||
rc = gh_graphql(GH_GET_REPO_SUBMODULES, name=self.project, owner=self.org)
|
rc = gh_graphql(GH_GET_REPO_SUBMODULES, name=self.project, owner=self.org)
|
||||||
info = rc["data"]["repository"]["submodules"]
|
info = rc["data"]["repository"]["submodules"]
|
||||||
self.submodules = [s["path"] for s in info["nodes"]]
|
self.submodules = [s["path"] for s in info["nodes"]]
|
||||||
return self.submodules
|
return self.submodules
|
||||||
|
|
||||||
def get_changed_submodules(self) -> List[str]:
|
def get_changed_submodules(self) -> list[str]:
|
||||||
submodules = self.get_submodules()
|
submodules = self.get_submodules()
|
||||||
return [f for f in self.get_changed_files() if f in submodules]
|
return [f for f in self.get_changed_files() if f in submodules]
|
||||||
|
|
||||||
|
|
@ -809,7 +800,7 @@ class GitHubPR:
|
||||||
and all("submodule" not in label for label in self.get_labels())
|
and all("submodule" not in label for label in self.get_labels())
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_reviews(self) -> List[Tuple[str, str]]:
|
def _get_reviews(self) -> list[tuple[str, str]]:
|
||||||
if self._reviews is None:
|
if self._reviews is None:
|
||||||
self._reviews = []
|
self._reviews = []
|
||||||
info = self.info
|
info = self.info
|
||||||
|
|
@ -834,7 +825,7 @@ class GitHubPR:
|
||||||
reviews[author] = state
|
reviews[author] = state
|
||||||
return list(reviews.items())
|
return list(reviews.items())
|
||||||
|
|
||||||
def get_approved_by(self) -> List[str]:
|
def get_approved_by(self) -> list[str]:
|
||||||
return [login for (login, state) in self._get_reviews() if state == "APPROVED"]
|
return [login for (login, state) in self._get_reviews() if state == "APPROVED"]
|
||||||
|
|
||||||
def get_commit_count(self) -> int:
|
def get_commit_count(self) -> int:
|
||||||
|
|
@ -843,12 +834,12 @@ class GitHubPR:
|
||||||
def get_pr_creator_login(self) -> str:
|
def get_pr_creator_login(self) -> str:
|
||||||
return cast(str, self.info["author"]["login"])
|
return cast(str, self.info["author"]["login"])
|
||||||
|
|
||||||
def _fetch_authors(self) -> List[Tuple[str, str]]:
|
def _fetch_authors(self) -> list[tuple[str, str]]:
|
||||||
if self._authors is not None:
|
if self._authors is not None:
|
||||||
return self._authors
|
return self._authors
|
||||||
authors: List[Tuple[str, str]] = []
|
authors: list[tuple[str, str]] = []
|
||||||
|
|
||||||
def add_authors(info: Dict[str, Any]) -> None:
|
def add_authors(info: dict[str, Any]) -> None:
|
||||||
for node in info["commits_with_authors"]["nodes"]:
|
for node in info["commits_with_authors"]["nodes"]:
|
||||||
for author_node in node["commit"]["authors"]["nodes"]:
|
for author_node in node["commit"]["authors"]["nodes"]:
|
||||||
user_node = author_node["user"]
|
user_node = author_node["user"]
|
||||||
|
|
@ -881,7 +872,7 @@ class GitHubPR:
|
||||||
def get_committer_author(self, num: int = 0) -> str:
|
def get_committer_author(self, num: int = 0) -> str:
|
||||||
return self._fetch_authors()[num][1]
|
return self._fetch_authors()[num][1]
|
||||||
|
|
||||||
def get_labels(self) -> List[str]:
|
def get_labels(self) -> list[str]:
|
||||||
if self.labels is not None:
|
if self.labels is not None:
|
||||||
return self.labels
|
return self.labels
|
||||||
labels = (
|
labels = (
|
||||||
|
|
@ -899,7 +890,7 @@ class GitHubPR:
|
||||||
orig_last_commit = self.last_commit()
|
orig_last_commit = self.last_commit()
|
||||||
|
|
||||||
def get_pr_next_check_runs(
|
def get_pr_next_check_runs(
|
||||||
edges: List[Dict[str, Dict[str, Any]]], edge_idx: int, checkruns: Any
|
edges: list[dict[str, dict[str, Any]]], edge_idx: int, checkruns: Any
|
||||||
) -> Any:
|
) -> Any:
|
||||||
rc = gh_graphql(
|
rc = gh_graphql(
|
||||||
GH_GET_PR_NEXT_CHECK_RUNS,
|
GH_GET_PR_NEXT_CHECK_RUNS,
|
||||||
|
|
@ -951,7 +942,7 @@ class GitHubPR:
|
||||||
|
|
||||||
return self.conclusions
|
return self.conclusions
|
||||||
|
|
||||||
def get_authors(self) -> Dict[str, str]:
|
def get_authors(self) -> dict[str, str]:
|
||||||
rc = {}
|
rc = {}
|
||||||
for idx in range(len(self._fetch_authors())):
|
for idx in range(len(self._fetch_authors())):
|
||||||
rc[self.get_committer_login(idx)] = self.get_committer_author(idx)
|
rc[self.get_committer_login(idx)] = self.get_committer_author(idx)
|
||||||
|
|
@ -995,7 +986,7 @@ class GitHubPR:
|
||||||
url=node["url"],
|
url=node["url"],
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_comments(self) -> List[GitHubComment]:
|
def get_comments(self) -> list[GitHubComment]:
|
||||||
if self.comments is not None:
|
if self.comments is not None:
|
||||||
return self.comments
|
return self.comments
|
||||||
self.comments = []
|
self.comments = []
|
||||||
|
|
@ -1069,7 +1060,7 @@ class GitHubPR:
|
||||||
skip_mandatory_checks: bool,
|
skip_mandatory_checks: bool,
|
||||||
comment_id: Optional[int] = None,
|
comment_id: Optional[int] = None,
|
||||||
skip_all_rule_checks: bool = False,
|
skip_all_rule_checks: bool = False,
|
||||||
) -> List["GitHubPR"]:
|
) -> list["GitHubPR"]:
|
||||||
assert self.is_ghstack_pr()
|
assert self.is_ghstack_pr()
|
||||||
ghstack_prs = get_ghstack_prs(
|
ghstack_prs = get_ghstack_prs(
|
||||||
repo, self, open_only=False
|
repo, self, open_only=False
|
||||||
|
|
@ -1099,7 +1090,7 @@ class GitHubPR:
|
||||||
def gen_commit_message(
|
def gen_commit_message(
|
||||||
self,
|
self,
|
||||||
filter_ghstack: bool = False,
|
filter_ghstack: bool = False,
|
||||||
ghstack_deps: Optional[List["GitHubPR"]] = None,
|
ghstack_deps: Optional[list["GitHubPR"]] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Fetches title and body from PR description
|
"""Fetches title and body from PR description
|
||||||
adds reviewed by, pull request resolved and optionally
|
adds reviewed by, pull request resolved and optionally
|
||||||
|
|
@ -1151,7 +1142,7 @@ class GitHubPR:
|
||||||
skip_mandatory_checks: bool = False,
|
skip_mandatory_checks: bool = False,
|
||||||
dry_run: bool = False,
|
dry_run: bool = False,
|
||||||
comment_id: Optional[int] = None,
|
comment_id: Optional[int] = None,
|
||||||
ignore_current_checks: Optional[List[str]] = None,
|
ignore_current_checks: Optional[list[str]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
# Raises exception if matching rule is not found
|
# Raises exception if matching rule is not found
|
||||||
(
|
(
|
||||||
|
|
@ -1223,7 +1214,7 @@ class GitHubPR:
|
||||||
comment_id: Optional[int] = None,
|
comment_id: Optional[int] = None,
|
||||||
branch: Optional[str] = None,
|
branch: Optional[str] = None,
|
||||||
skip_all_rule_checks: bool = False,
|
skip_all_rule_checks: bool = False,
|
||||||
) -> List["GitHubPR"]:
|
) -> list["GitHubPR"]:
|
||||||
"""
|
"""
|
||||||
:param skip_all_rule_checks: If true, skips all rule checks, useful for dry-running merge locally
|
:param skip_all_rule_checks: If true, skips all rule checks, useful for dry-running merge locally
|
||||||
"""
|
"""
|
||||||
|
|
@ -1263,14 +1254,14 @@ class PostCommentError(Exception):
|
||||||
@dataclass
|
@dataclass
|
||||||
class MergeRule:
|
class MergeRule:
|
||||||
name: str
|
name: str
|
||||||
patterns: List[str]
|
patterns: list[str]
|
||||||
approved_by: List[str]
|
approved_by: list[str]
|
||||||
mandatory_checks_name: Optional[List[str]]
|
mandatory_checks_name: Optional[list[str]]
|
||||||
ignore_flaky_failures: bool = True
|
ignore_flaky_failures: bool = True
|
||||||
|
|
||||||
|
|
||||||
def gen_new_issue_link(
|
def gen_new_issue_link(
|
||||||
org: str, project: str, labels: List[str], template: str = "bug-report.yml"
|
org: str, project: str, labels: list[str], template: str = "bug-report.yml"
|
||||||
) -> str:
|
) -> str:
|
||||||
labels_str = ",".join(labels)
|
labels_str = ",".join(labels)
|
||||||
return (
|
return (
|
||||||
|
|
@ -1282,7 +1273,7 @@ def gen_new_issue_link(
|
||||||
|
|
||||||
def read_merge_rules(
|
def read_merge_rules(
|
||||||
repo: Optional[GitRepo], org: str, project: str
|
repo: Optional[GitRepo], org: str, project: str
|
||||||
) -> List[MergeRule]:
|
) -> list[MergeRule]:
|
||||||
"""Returns the list of all merge rules for the repo or project.
|
"""Returns the list of all merge rules for the repo or project.
|
||||||
|
|
||||||
NB: this function is used in Meta-internal workflows, see the comment
|
NB: this function is used in Meta-internal workflows, see the comment
|
||||||
|
|
@ -1312,12 +1303,12 @@ def find_matching_merge_rule(
|
||||||
repo: Optional[GitRepo] = None,
|
repo: Optional[GitRepo] = None,
|
||||||
skip_mandatory_checks: bool = False,
|
skip_mandatory_checks: bool = False,
|
||||||
skip_internal_checks: bool = False,
|
skip_internal_checks: bool = False,
|
||||||
ignore_current_checks: Optional[List[str]] = None,
|
ignore_current_checks: Optional[list[str]] = None,
|
||||||
) -> Tuple[
|
) -> tuple[
|
||||||
MergeRule,
|
MergeRule,
|
||||||
List[Tuple[str, Optional[str], Optional[int]]],
|
list[tuple[str, Optional[str], Optional[int]]],
|
||||||
List[Tuple[str, Optional[str], Optional[int]]],
|
list[tuple[str, Optional[str], Optional[int]]],
|
||||||
Dict[str, List[Any]],
|
dict[str, list[Any]],
|
||||||
]:
|
]:
|
||||||
"""
|
"""
|
||||||
Returns merge rule matching to this pr together with the list of associated pending
|
Returns merge rule matching to this pr together with the list of associated pending
|
||||||
|
|
@ -1504,13 +1495,13 @@ def find_matching_merge_rule(
|
||||||
raise MergeRuleFailedError(reject_reason, rule)
|
raise MergeRuleFailedError(reject_reason, rule)
|
||||||
|
|
||||||
|
|
||||||
def checks_to_str(checks: List[Tuple[str, Optional[str]]]) -> str:
|
def checks_to_str(checks: list[tuple[str, Optional[str]]]) -> str:
|
||||||
return ", ".join(f"[{c[0]}]({c[1]})" if c[1] is not None else c[0] for c in checks)
|
return ", ".join(f"[{c[0]}]({c[1]})" if c[1] is not None else c[0] for c in checks)
|
||||||
|
|
||||||
|
|
||||||
def checks_to_markdown_bullets(
|
def checks_to_markdown_bullets(
|
||||||
checks: List[Tuple[str, Optional[str], Optional[int]]],
|
checks: list[tuple[str, Optional[str], Optional[int]]],
|
||||||
) -> List[str]:
|
) -> list[str]:
|
||||||
return [
|
return [
|
||||||
f"- [{c[0]}]({c[1]})" if c[1] is not None else f"- {c[0]}" for c in checks[:5]
|
f"- [{c[0]}]({c[1]})" if c[1] is not None else f"- {c[0]}" for c in checks[:5]
|
||||||
]
|
]
|
||||||
|
|
@ -1518,7 +1509,7 @@ def checks_to_markdown_bullets(
|
||||||
|
|
||||||
def manually_close_merged_pr(
|
def manually_close_merged_pr(
|
||||||
pr: GitHubPR,
|
pr: GitHubPR,
|
||||||
additional_merged_prs: List[GitHubPR],
|
additional_merged_prs: list[GitHubPR],
|
||||||
merge_commit_sha: str,
|
merge_commit_sha: str,
|
||||||
dry_run: bool,
|
dry_run: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
@ -1551,12 +1542,12 @@ def save_merge_record(
|
||||||
owner: str,
|
owner: str,
|
||||||
project: str,
|
project: str,
|
||||||
author: str,
|
author: str,
|
||||||
pending_checks: List[Tuple[str, Optional[str], Optional[int]]],
|
pending_checks: list[tuple[str, Optional[str], Optional[int]]],
|
||||||
failed_checks: List[Tuple[str, Optional[str], Optional[int]]],
|
failed_checks: list[tuple[str, Optional[str], Optional[int]]],
|
||||||
ignore_current_checks: List[Tuple[str, Optional[str], Optional[int]]],
|
ignore_current_checks: list[tuple[str, Optional[str], Optional[int]]],
|
||||||
broken_trunk_checks: List[Tuple[str, Optional[str], Optional[int]]],
|
broken_trunk_checks: list[tuple[str, Optional[str], Optional[int]]],
|
||||||
flaky_checks: List[Tuple[str, Optional[str], Optional[int]]],
|
flaky_checks: list[tuple[str, Optional[str], Optional[int]]],
|
||||||
unstable_checks: List[Tuple[str, Optional[str], Optional[int]]],
|
unstable_checks: list[tuple[str, Optional[str], Optional[int]]],
|
||||||
last_commit_sha: str,
|
last_commit_sha: str,
|
||||||
merge_base_sha: str,
|
merge_base_sha: str,
|
||||||
merge_commit_sha: str = "",
|
merge_commit_sha: str = "",
|
||||||
|
|
@ -1714,9 +1705,9 @@ def is_invalid_cancel(
|
||||||
def get_classifications(
|
def get_classifications(
|
||||||
pr_num: int,
|
pr_num: int,
|
||||||
project: str,
|
project: str,
|
||||||
checks: Dict[str, JobCheckState],
|
checks: dict[str, JobCheckState],
|
||||||
ignore_current_checks: Optional[List[str]],
|
ignore_current_checks: Optional[list[str]],
|
||||||
) -> Dict[str, JobCheckState]:
|
) -> dict[str, JobCheckState]:
|
||||||
# Get the failure classification from Dr.CI, which is the source of truth
|
# Get the failure classification from Dr.CI, which is the source of truth
|
||||||
# going forward. It's preferable to try calling Dr.CI API directly first
|
# going forward. It's preferable to try calling Dr.CI API directly first
|
||||||
# to get the latest results as well as update Dr.CI PR comment
|
# to get the latest results as well as update Dr.CI PR comment
|
||||||
|
|
@ -1825,7 +1816,7 @@ def get_classifications(
|
||||||
|
|
||||||
def filter_checks_with_lambda(
|
def filter_checks_with_lambda(
|
||||||
checks: JobNameToStateDict, status_filter: Callable[[Optional[str]], bool]
|
checks: JobNameToStateDict, status_filter: Callable[[Optional[str]], bool]
|
||||||
) -> List[JobCheckState]:
|
) -> list[JobCheckState]:
|
||||||
return [check for check in checks.values() if status_filter(check.status)]
|
return [check for check in checks.values() if status_filter(check.status)]
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1841,7 +1832,7 @@ def get_pr_commit_sha(repo: GitRepo, pr: GitHubPR) -> str:
|
||||||
|
|
||||||
def validate_revert(
|
def validate_revert(
|
||||||
repo: GitRepo, pr: GitHubPR, *, comment_id: Optional[int] = None
|
repo: GitRepo, pr: GitHubPR, *, comment_id: Optional[int] = None
|
||||||
) -> Tuple[str, str]:
|
) -> tuple[str, str]:
|
||||||
comment = (
|
comment = (
|
||||||
pr.get_last_comment()
|
pr.get_last_comment()
|
||||||
if comment_id is None
|
if comment_id is None
|
||||||
|
|
@ -1871,7 +1862,7 @@ def validate_revert(
|
||||||
|
|
||||||
def get_ghstack_dependent_prs(
|
def get_ghstack_dependent_prs(
|
||||||
repo: GitRepo, pr: GitHubPR, only_closed: bool = True
|
repo: GitRepo, pr: GitHubPR, only_closed: bool = True
|
||||||
) -> List[Tuple[str, GitHubPR]]:
|
) -> list[tuple[str, GitHubPR]]:
|
||||||
"""
|
"""
|
||||||
Get the PRs in the stack that are above this PR (inclusive).
|
Get the PRs in the stack that are above this PR (inclusive).
|
||||||
Throws error if stack have branched or original branches are gone
|
Throws error if stack have branched or original branches are gone
|
||||||
|
|
@ -1897,7 +1888,7 @@ def get_ghstack_dependent_prs(
|
||||||
# Remove commits original PR depends on
|
# Remove commits original PR depends on
|
||||||
if skip_len > 0:
|
if skip_len > 0:
|
||||||
rev_list = rev_list[:-skip_len]
|
rev_list = rev_list[:-skip_len]
|
||||||
rc: List[Tuple[str, GitHubPR]] = []
|
rc: list[tuple[str, GitHubPR]] = []
|
||||||
for pr_, sha in _revlist_to_prs(repo, pr, rev_list):
|
for pr_, sha in _revlist_to_prs(repo, pr, rev_list):
|
||||||
if not pr_.is_closed():
|
if not pr_.is_closed():
|
||||||
if not only_closed:
|
if not only_closed:
|
||||||
|
|
@ -1910,7 +1901,7 @@ def get_ghstack_dependent_prs(
|
||||||
|
|
||||||
def do_revert_prs(
|
def do_revert_prs(
|
||||||
repo: GitRepo,
|
repo: GitRepo,
|
||||||
shas_and_prs: List[Tuple[str, GitHubPR]],
|
shas_and_prs: list[tuple[str, GitHubPR]],
|
||||||
*,
|
*,
|
||||||
author_login: str,
|
author_login: str,
|
||||||
extra_msg: str = "",
|
extra_msg: str = "",
|
||||||
|
|
@ -2001,7 +1992,7 @@ def check_for_sev(org: str, project: str, skip_mandatory_checks: bool) -> None:
|
||||||
if skip_mandatory_checks:
|
if skip_mandatory_checks:
|
||||||
return
|
return
|
||||||
response = cast(
|
response = cast(
|
||||||
Dict[str, Any],
|
dict[str, Any],
|
||||||
gh_fetch_json_list(
|
gh_fetch_json_list(
|
||||||
"https://api.github.com/search/issues",
|
"https://api.github.com/search/issues",
|
||||||
# Having two label: queries is an AND operation
|
# Having two label: queries is an AND operation
|
||||||
|
|
@ -2019,29 +2010,29 @@ def check_for_sev(org: str, project: str, skip_mandatory_checks: bool) -> None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
def has_label(labels: List[str], pattern: Pattern[str] = CIFLOW_LABEL) -> bool:
|
def has_label(labels: list[str], pattern: Pattern[str] = CIFLOW_LABEL) -> bool:
|
||||||
return len(list(filter(pattern.match, labels))) > 0
|
return len(list(filter(pattern.match, labels))) > 0
|
||||||
|
|
||||||
|
|
||||||
def categorize_checks(
|
def categorize_checks(
|
||||||
check_runs: JobNameToStateDict,
|
check_runs: JobNameToStateDict,
|
||||||
required_checks: List[str],
|
required_checks: list[str],
|
||||||
ok_failed_checks_threshold: Optional[int] = None,
|
ok_failed_checks_threshold: Optional[int] = None,
|
||||||
) -> Tuple[
|
) -> tuple[
|
||||||
List[Tuple[str, Optional[str], Optional[int]]],
|
list[tuple[str, Optional[str], Optional[int]]],
|
||||||
List[Tuple[str, Optional[str], Optional[int]]],
|
list[tuple[str, Optional[str], Optional[int]]],
|
||||||
Dict[str, List[Any]],
|
dict[str, list[Any]],
|
||||||
]:
|
]:
|
||||||
"""
|
"""
|
||||||
Categories all jobs into the list of pending and failing jobs. All known flaky
|
Categories all jobs into the list of pending and failing jobs. All known flaky
|
||||||
failures and broken trunk are ignored by defaults when ok_failed_checks_threshold
|
failures and broken trunk are ignored by defaults when ok_failed_checks_threshold
|
||||||
is not set (unlimited)
|
is not set (unlimited)
|
||||||
"""
|
"""
|
||||||
pending_checks: List[Tuple[str, Optional[str], Optional[int]]] = []
|
pending_checks: list[tuple[str, Optional[str], Optional[int]]] = []
|
||||||
failed_checks: List[Tuple[str, Optional[str], Optional[int]]] = []
|
failed_checks: list[tuple[str, Optional[str], Optional[int]]] = []
|
||||||
|
|
||||||
# failed_checks_categorization is used to keep track of all ignorable failures when saving the merge record on s3
|
# failed_checks_categorization is used to keep track of all ignorable failures when saving the merge record on s3
|
||||||
failed_checks_categorization: Dict[str, List[Any]] = defaultdict(list)
|
failed_checks_categorization: dict[str, list[Any]] = defaultdict(list)
|
||||||
|
|
||||||
# If required_checks is not set or empty, consider all names are relevant
|
# If required_checks is not set or empty, consider all names are relevant
|
||||||
relevant_checknames = [
|
relevant_checknames = [
|
||||||
|
|
|
||||||
13
.github/scripts/trymerge_explainer.py
vendored
13
.github/scripts/trymerge_explainer.py
vendored
|
|
@ -1,6 +1,7 @@
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from typing import List, Optional, Pattern, Tuple
|
from re import Pattern
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
BOT_COMMANDS_WIKI = "https://github.com/pytorch/pytorch/wiki/Bot-commands"
|
BOT_COMMANDS_WIKI = "https://github.com/pytorch/pytorch/wiki/Bot-commands"
|
||||||
|
|
@ -13,13 +14,13 @@ CONTACT_US = f"Questions? Feedback? Please reach out to the [PyTorch DevX Team](
|
||||||
ALTERNATIVES = f"Learn more about merging in the [wiki]({BOT_COMMANDS_WIKI})."
|
ALTERNATIVES = f"Learn more about merging in the [wiki]({BOT_COMMANDS_WIKI})."
|
||||||
|
|
||||||
|
|
||||||
def has_label(labels: List[str], pattern: Pattern[str] = CIFLOW_LABEL) -> bool:
|
def has_label(labels: list[str], pattern: Pattern[str] = CIFLOW_LABEL) -> bool:
|
||||||
return len(list(filter(pattern.match, labels))) > 0
|
return len(list(filter(pattern.match, labels))) > 0
|
||||||
|
|
||||||
|
|
||||||
class TryMergeExplainer:
|
class TryMergeExplainer:
|
||||||
force: bool
|
force: bool
|
||||||
labels: List[str]
|
labels: list[str]
|
||||||
pr_num: int
|
pr_num: int
|
||||||
org: str
|
org: str
|
||||||
project: str
|
project: str
|
||||||
|
|
@ -31,7 +32,7 @@ class TryMergeExplainer:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
force: bool,
|
force: bool,
|
||||||
labels: List[str],
|
labels: list[str],
|
||||||
pr_num: int,
|
pr_num: int,
|
||||||
org: str,
|
org: str,
|
||||||
project: str,
|
project: str,
|
||||||
|
|
@ -47,7 +48,7 @@ class TryMergeExplainer:
|
||||||
def _get_flag_msg(
|
def _get_flag_msg(
|
||||||
self,
|
self,
|
||||||
ignore_current_checks: Optional[
|
ignore_current_checks: Optional[
|
||||||
List[Tuple[str, Optional[str], Optional[int]]]
|
list[tuple[str, Optional[str], Optional[int]]]
|
||||||
] = None,
|
] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
if self.force:
|
if self.force:
|
||||||
|
|
@ -68,7 +69,7 @@ class TryMergeExplainer:
|
||||||
def get_merge_message(
|
def get_merge_message(
|
||||||
self,
|
self,
|
||||||
ignore_current_checks: Optional[
|
ignore_current_checks: Optional[
|
||||||
List[Tuple[str, Optional[str], Optional[int]]]
|
list[tuple[str, Optional[str], Optional[int]]]
|
||||||
] = None,
|
] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
title = "### Merge started"
|
title = "### Merge started"
|
||||||
|
|
|
||||||
3
.github/scripts/tryrebase.py
vendored
3
.github/scripts/tryrebase.py
vendored
|
|
@ -5,7 +5,8 @@ import os
|
||||||
import re
|
import re
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
from typing import Any, Generator
|
from collections.abc import Generator
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from github_utils import gh_post_pr_comment as gh_post_comment
|
from github_utils import gh_post_pr_comment as gh_post_comment
|
||||||
from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo
|
from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo
|
||||||
|
|
|
||||||
31
.github/workflows/_runner-determinator.yml
vendored
31
.github/workflows/_runner-determinator.yml
vendored
|
|
@ -129,9 +129,10 @@ jobs:
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from functools import lru_cache
|
from collections.abc import Iterable
|
||||||
|
from functools import cache
|
||||||
from logging import LogRecord
|
from logging import LogRecord
|
||||||
from typing import Any, Dict, FrozenSet, Iterable, List, NamedTuple, Set, Tuple
|
from typing import Any, NamedTuple
|
||||||
from urllib.request import Request, urlopen
|
from urllib.request import Request, urlopen
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
@ -173,7 +174,7 @@ jobs:
|
||||||
Settings for the experiments that can be opted into.
|
Settings for the experiments that can be opted into.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
experiments: Dict[str, Experiment] = {}
|
experiments: dict[str, Experiment] = {}
|
||||||
|
|
||||||
|
|
||||||
class ColorFormatter(logging.Formatter):
|
class ColorFormatter(logging.Formatter):
|
||||||
|
|
@ -218,7 +219,7 @@ jobs:
|
||||||
f.write(f"{key}={value}\n")
|
f.write(f"{key}={value}\n")
|
||||||
|
|
||||||
|
|
||||||
def _str_comma_separated_to_set(value: str) -> FrozenSet[str]:
|
def _str_comma_separated_to_set(value: str) -> frozenset[str]:
|
||||||
return frozenset(
|
return frozenset(
|
||||||
filter(lambda itm: itm != "", map(str.strip, value.strip(" \n\t").split(",")))
|
filter(lambda itm: itm != "", map(str.strip, value.strip(" \n\t").split(",")))
|
||||||
)
|
)
|
||||||
|
|
@ -276,12 +277,12 @@ jobs:
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def get_gh_client(github_token: str) -> Github:
|
def get_gh_client(github_token: str) -> Github: # type: ignore[no-any-unimported]
|
||||||
auth = Auth.Token(github_token)
|
auth = Auth.Token(github_token)
|
||||||
return Github(auth=auth)
|
return Github(auth=auth)
|
||||||
|
|
||||||
|
|
||||||
def get_issue(gh: Github, repo: str, issue_num: int) -> Issue:
|
def get_issue(gh: Github, repo: str, issue_num: int) -> Issue: # type: ignore[no-any-unimported]
|
||||||
repo = gh.get_repo(repo)
|
repo = gh.get_repo(repo)
|
||||||
return repo.get_issue(number=issue_num)
|
return repo.get_issue(number=issue_num)
|
||||||
|
|
||||||
|
|
@ -310,7 +311,7 @@ jobs:
|
||||||
raise Exception( # noqa: TRY002
|
raise Exception( # noqa: TRY002
|
||||||
f"issue with pull request {pr_number} from repo {repository}"
|
f"issue with pull request {pr_number} from repo {repository}"
|
||||||
) from e
|
) from e
|
||||||
return pull.user.login
|
return pull.user.login # type: ignore[no-any-return]
|
||||||
# In all other cases, return the original input username
|
# In all other cases, return the original input username
|
||||||
return username
|
return username
|
||||||
|
|
||||||
|
|
@ -331,7 +332,7 @@ jobs:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
def extract_settings_user_opt_in_from_text(rollout_state: str) -> Tuple[str, str]:
|
def extract_settings_user_opt_in_from_text(rollout_state: str) -> tuple[str, str]:
|
||||||
"""
|
"""
|
||||||
Extracts the text with settings, if any, and the opted in users from the rollout state.
|
Extracts the text with settings, if any, and the opted in users from the rollout state.
|
||||||
|
|
||||||
|
|
@ -347,7 +348,7 @@ jobs:
|
||||||
return "", rollout_state
|
return "", rollout_state
|
||||||
|
|
||||||
|
|
||||||
class UserOptins(Dict[str, List[str]]):
|
class UserOptins(dict[str, list[str]]):
|
||||||
"""
|
"""
|
||||||
Dictionary of users with a list of features they have opted into
|
Dictionary of users with a list of features they have opted into
|
||||||
"""
|
"""
|
||||||
|
|
@ -488,7 +489,7 @@ jobs:
|
||||||
rollout_state: str,
|
rollout_state: str,
|
||||||
workflow_requestors: Iterable[str],
|
workflow_requestors: Iterable[str],
|
||||||
branch: str,
|
branch: str,
|
||||||
eligible_experiments: FrozenSet[str] = frozenset(),
|
eligible_experiments: frozenset[str] = frozenset(),
|
||||||
is_canary: bool = False,
|
is_canary: bool = False,
|
||||||
) -> str:
|
) -> str:
|
||||||
settings = parse_settings(rollout_state)
|
settings = parse_settings(rollout_state)
|
||||||
|
|
@ -587,7 +588,7 @@ jobs:
|
||||||
return str(issue.get_comments()[0].body.strip("\n\t "))
|
return str(issue.get_comments()[0].body.strip("\n\t "))
|
||||||
|
|
||||||
|
|
||||||
def download_json(url: str, headers: Dict[str, str], num_retries: int = 3) -> Any:
|
def download_json(url: str, headers: dict[str, str], num_retries: int = 3) -> Any:
|
||||||
for _ in range(num_retries):
|
for _ in range(num_retries):
|
||||||
try:
|
try:
|
||||||
req = Request(url=url, headers=headers)
|
req = Request(url=url, headers=headers)
|
||||||
|
|
@ -600,8 +601,8 @@ jobs:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=None)
|
@cache
|
||||||
def get_pr_info(github_repo: str, github_token: str, pr_number: int) -> Dict[str, Any]:
|
def get_pr_info(github_repo: str, github_token: str, pr_number: int) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Dynamically get PR information
|
Dynamically get PR information
|
||||||
"""
|
"""
|
||||||
|
|
@ -610,7 +611,7 @@ jobs:
|
||||||
"Accept": "application/vnd.github.v3+json",
|
"Accept": "application/vnd.github.v3+json",
|
||||||
"Authorization": f"token {github_token}",
|
"Authorization": f"token {github_token}",
|
||||||
}
|
}
|
||||||
json_response: Dict[str, Any] = download_json(
|
json_response: dict[str, Any] = download_json(
|
||||||
url=f"{github_api}/issues/{pr_number}",
|
url=f"{github_api}/issues/{pr_number}",
|
||||||
headers=headers,
|
headers=headers,
|
||||||
)
|
)
|
||||||
|
|
@ -622,7 +623,7 @@ jobs:
|
||||||
return json_response
|
return json_response
|
||||||
|
|
||||||
|
|
||||||
def get_labels(github_repo: str, github_token: str, pr_number: int) -> Set[str]:
|
def get_labels(github_repo: str, github_token: str, pr_number: int) -> set[str]:
|
||||||
"""
|
"""
|
||||||
Dynamically get the latest list of labels from the pull request
|
Dynamically get the latest list of labels from the pull request
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user