mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add trymerge workflow (#71488)
Summary:
This one will react to the `repository_dispatch` event sent by PyTorch Probot
when the `pytorchbot merge this` command is issued.
At the moment, the workflow will only attempt to merge PRs which have not
been created from a forked repo and that match the rules defined in
`.github/merge_rules.json`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71488
Reviewed By: bigfootjon
Differential Revision: D33665142
Pulled By: malfet
fbshipit-source-id: e22daa1892523e62d7b7a941960636a6514cb7d7
(cherry picked from commit 92059bab07)
This commit is contained in:
parent
f45e217c01
commit
61713acb07
66
.github/scripts/gitutils.py
vendored
66
.github/scripts/gitutils.py
vendored
|
|
@ -211,22 +211,54 @@ class GitRepo:
|
|||
self._run_git("commit", "--amend", "-m", msg)
|
||||
|
||||
|
||||
def parse_args() -> Any:
    """Build and evaluate the CLI for the branch-sync entry point.

    Options: --sync-branch (source of commits, default "sync"),
    --default-branch (merge target, default "main") and --dry-run
    (skip the final push).
    """
    from argparse import ArgumentParser

    parser = ArgumentParser("Merge PR/branch into default branch")
    for flag, extra in (
        ("--sync-branch", {"default": "sync"}),
        ("--default-branch", {"type": str, "default": "main"}),
        ("--dry-run", {"action": "store_true"}),
    ):
        parser.add_argument(flag, **extra)
    return parser.parse_args()
|
||||
class PeekableIterator:
    """Iterator over the characters of a string that supports one-character
    lookahead via :meth:`peek` without consuming the character."""

    def __init__(self, val: str) -> None:
        # Index of the last character yielded; -1 means nothing consumed yet.
        self._val = val
        self._idx = -1

    def peek(self) -> Optional[str]:
        """Return the next character without consuming it, or None at the end."""
        nxt = self._idx + 1
        return self._val[nxt] if nxt < len(self._val) else None

    def __iter__(self) -> Any:
        return self

    def __next__(self) -> str:
        ch = self.peek()
        if ch is None:
            raise StopIteration
        self._idx += 1
        return ch
|
||||
|
||||
|
||||
def main() -> None:
    # Entry point: cherry-pick every commit from the sync branch onto the
    # default branch and, unless --dry-run was given, push the result.
    args = parse_args()
    repo = GitRepo(get_git_repo_dir())
    repo.cherry_pick_commits(args.sync_branch, args.default_branch)
    if not args.dry_run:
        repo.push(args.default_branch)


if __name__ == '__main__':
    main()
|
||||
def patterns_to_regex(allowed_patterns: List[str]) -> Any:
    """Compile a list of glob-like patterns into one alternation regex.

    pattern is glob-like, i.e. the only special sequences it has are:
    - ? - matches single character
    - * - matches any non-folder separator characters
    - ** - matches any characters
    Assuming that patterns are free of braces and backslashes
    the only character that needs to be escaped are dot and plus

    Returns a compiled pattern; callers use .match(), so matching is
    anchored at the start of the file name only.
    """
    rc = "("
    for idx, pattern in enumerate(allowed_patterns):
        if idx > 0:
            rc += "|"
        assert not any(c in pattern for c in "{}()[]\\")
        pos = 0
        plen = len(pattern)
        while pos < plen:
            c = pattern[pos]
            pos += 1
            if c == ".":
                rc += "\\."
            elif c == "+":
                rc += "\\+"
            elif c == "?":
                # Fix: '?' used to fall through and be emitted verbatim, which
                # regex interprets as an optional quantifier instead of the
                # single-character wildcard the docstring promises.
                rc += "."
            elif c == "*":
                if pos < plen and pattern[pos] == "*":
                    pos += 1  # consume the second '*'
                    # Fix: glob '**' matches any characters *including none*;
                    # '.+' wrongly rejected empty matches.
                    rc += ".*"
                else:
                    # Fix: glob '*' matches a possibly-empty run of
                    # non-separator characters; '[^/]+' required at least one.
                    rc += "[^/]*"
            else:
                rc += c
    rc += ")"
    return re.compile(rc)
|
||||
|
|
|
|||
340
.github/scripts/trymerge.py
vendored
Executable file
340
.github/scripts/trymerge.py
vendored
Executable file
|
|
@ -0,0 +1,340 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from urllib.request import urlopen, Request
|
||||
from urllib.error import HTTPError
|
||||
from typing import cast, Any, Callable, Dict, List, Optional, Tuple
|
||||
from gitutils import get_git_remote_name, get_git_repo_dir, patterns_to_regex, GitRepo
|
||||
|
||||
|
||||
# GraphQL query that fetches everything GitHubPR needs about a pull request in
# a single round-trip: open/closed state, cross-repo flag, authorship, head and
# base refs, the first 100 commits with their check-suite results (filtered to
# appId 12274 by the query itself), the last 100 changed files and the last 100
# latest reviews.  Variables: $owner, $name, $number.
GH_GET_PR_INFO_QUERY = """
query ($owner: String!, $name: String!, $number: Int!) {
  repository(owner: $owner, name: $name) {
    pullRequest(number: $number) {
      closed
      isCrossRepository
      author {
        login
      }
      title
      body
      headRefName
      headRepository {
        nameWithOwner
      }
      baseRefName
      baseRepository {
        nameWithOwner
        defaultBranchRef {
          name
        }
      }
      commits(first: 100) {
        nodes {
          commit {
            author {
              user {
                login
              }
              email
              name
            }
            oid
            checkSuites(filterBy: {appId: 12274}, first: 1) {
              nodes {
                app {
                  databaseId
                }
                conclusion
              }
            }
          }
        }
        totalCount
      }
      changedFiles,
      files(last: 100) {
        nodes {
          path
        }
      }
      latestReviews(last: 100) {
        nodes {
          author {
            login
          },
          state
        },
        totalCount
      }
    }
  }
}
"""

# Matches head refs created by ghstack, e.g. "gh/<user>/<n>/head".
RE_GHSTACK_HEAD_REF = re.compile(r"^(gh/[^/]+/[0-9]+/)head$")
# Matches the "ghstack-source-id: ..." trailer line inside a commit message.
RE_GHSTACK_SOURCE_ID = re.compile(r'^ghstack-source-id: (.+)\n?', re.MULTILINE)
# Extracts owner, repo and PR number from a "Pull Request resolved:" trailer.
RE_PULL_REQUEST_RESOLVED = re.compile(
    r'Pull Request resolved: '
    r'https://github.com/(?P<owner>[^/]+)/(?P<repo>[^/]+)/pull/(?P<number>[0-9]+)',
    re.MULTILINE
)
|
||||
|
||||
|
||||
def _fetch_url(url: str, *,
               headers: Optional[Dict[str, str]] = None,
               data: Optional[Dict[str, Any]] = None,
               method: Optional[str] = None,
               reader: Callable[[Any], Any] = lambda x: x.read()) -> Any:
    """Issue an HTTP request and return ``reader(connection)``.

    Adds an ``Authorization`` header from the GITHUB_TOKEN environment
    variable for api.github.com URLs (note: mutates the caller-supplied
    headers dict in place, as before).  ``data`` is JSON-encoded into the
    request body.  On HTTP 403 a rate-limit diagnostic is printed before
    the HTTPError is re-raised.
    """
    headers = {} if headers is None else headers
    token = os.environ.get("GITHUB_TOKEN")
    if url.startswith('https://api.github.com/') and token is not None:
        headers['Authorization'] = f'token {token}'
    payload = None if data is None else json.dumps(data).encode()
    request = Request(url, headers=headers, data=payload, method=method)
    try:
        with urlopen(request) as conn:
            return reader(conn)
    except HTTPError as err:
        rate_keys = ('X-RateLimit-Limit', 'X-RateLimit-Used')
        if err.code == 403 and all(key in err.headers for key in rate_keys):
            print(f"Rate limit exceeded: {err.headers['X-RateLimit-Used']}/{err.headers['X-RateLimit-Limit']}")
        raise
|
||||
|
||||
|
||||
def fetch_json(url: str,
               params: Optional[Dict[str, Any]] = None,
               data: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
    """GET/POST a GitHub REST v3 endpoint and parse the JSON response.

    ``params`` are appended to the URL as a naive (unescaped) query string;
    ``data`` becomes the JSON request body.
    """
    if params:
        query = '&'.join(f"{name}={val}" for name, val in params.items())
        url += '?' + query
    return cast(List[Dict[str, Any]],
                _fetch_url(url,
                           headers={'Accept': 'application/vnd.github.v3+json'},
                           data=data,
                           reader=json.load))
|
||||
|
||||
|
||||
def gh_post_comment(org: str, project: str, pr_num: int, comment: str, dry_run: bool = False) -> List[Dict[str, Any]]:
    """Post *comment* on the PR's issue thread.

    In dry-run mode the comment is only printed and an empty list is
    returned; otherwise the REST API response is returned.
    """
    if dry_run:
        print(comment)
        return []
    endpoint = f'https://api.github.com/repos/{org}/{project}/issues/{pr_num}/comments'
    return fetch_json(endpoint, data={"body": comment})
|
||||
|
||||
|
||||
def gh_graphql(query: str, **kwargs: Any) -> Dict[str, Any]:
    """Run a GraphQL query against the GitHub API; kwargs become variables.

    Raises RuntimeError if the response carries an ``errors`` key.
    """
    payload = {"query": query, "variables": kwargs}
    response = _fetch_url("https://api.github.com/graphql", data=payload, reader=json.load)
    if "errors" in response:
        raise RuntimeError(f"GraphQL query {query} failed: {response['errors']}")
    return cast(Dict[str, Any], response)
|
||||
|
||||
|
||||
def gh_get_pr_info(org: str, proj: str, pr_no: int) -> Any:
    """Fetch the GraphQL pull-request payload that GitHubPR consumes."""
    response = gh_graphql(GH_GET_PR_INFO_QUERY, name=proj, owner=org, number=pr_no)
    return response["data"]["repository"]["pullRequest"]
|
||||
|
||||
|
||||
def parse_args() -> Any:
    """Parse the merge-bot CLI: a positional PR number plus --dry-run."""
    from argparse import ArgumentParser

    parser = ArgumentParser("Merge PR into default branch")
    # When set, print comments instead of posting them and skip the push.
    parser.add_argument("--dry-run", action="store_true")
    parser.add_argument("pr_num", type=int)
    return parser.parse_args()
|
||||
|
||||
|
||||
class GitHubPR:
    """Wrapper around the GraphQL pull-request payload (GH_GET_PR_INFO_QUERY)
    plus the git operations needed to land the PR locally."""
    def __init__(self, org: str, project: str, pr_num: int) -> None:
        assert isinstance(pr_num, int)
        self.org = org
        self.project = project
        self.pr_num = pr_num
        # Raw GraphQL payload; all accessors below read from this dict.
        self.info = gh_get_pr_info(org, project, pr_num)

    def is_closed(self) -> bool:
        """True if the PR is closed (merged or abandoned)."""
        return bool(self.info["closed"])

    def is_cross_repo(self) -> bool:
        """True if the head branch lives in a fork rather than this repo."""
        return bool(self.info["isCrossRepository"])

    def base_ref(self) -> str:
        """Name of the branch the PR targets."""
        return cast(str, self.info["baseRefName"])

    def default_branch(self) -> str:
        """Default branch of the base repository (e.g. "main")."""
        return cast(str, self.info["baseRepository"]["defaultBranchRef"]["name"])

    def head_ref(self) -> str:
        """Name of the branch the PR is made from."""
        return cast(str, self.info["headRefName"])

    def is_ghstack_pr(self) -> bool:
        """True when the head ref follows the ghstack "gh/<user>/<n>/head" scheme."""
        return RE_GHSTACK_HEAD_REF.match(self.head_ref()) is not None

    def get_changed_files_count(self) -> int:
        return int(self.info["changedFiles"])

    def get_changed_files(self) -> List[str]:
        """Paths changed by the PR; raises if the query (last 100 files)
        did not return every changed file."""
        rc = [x["path"] for x in self.info["files"]["nodes"]]
        if len(rc) != self.get_changed_files_count():
            raise RuntimeError("Changed file count mismatch")
        return rc

    def _get_reviewers(self) -> List[Tuple[str, str]]:
        """(login, state) pairs of the latest review per reviewer; raises if
        the query (last 100 reviews) did not return all of them."""
        reviews_count = int(self.info["latestReviews"]["totalCount"])
        if len(self.info["latestReviews"]["nodes"]) != reviews_count:
            raise RuntimeError("Can't fetch all PR reviews")
        return [(x["author"]["login"], x["state"]) for x in self.info["latestReviews"]["nodes"]]

    def get_approved_by(self) -> List[str]:
        """Logins whose latest review state is APPROVED."""
        return [login for (login, state) in self._get_reviewers() if state == "APPROVED"]

    def get_commit_count(self) -> int:
        return int(self.info["commits"]["totalCount"])

    def get_pr_creator_login(self) -> str:
        return cast(str, self.info["author"]["login"])

    def get_committer_login(self, num: int = 0) -> str:
        """GitHub login of the author of commit *num* (query fetches first 100)."""
        return cast(str, self.info["commits"]["nodes"][num]["commit"]["author"]["user"]["login"])

    def get_committer_author(self, num: int = 0) -> str:
        """Git "Name <email>" string for the author of commit *num*."""
        node = self.info["commits"]["nodes"][num]["commit"]["author"]
        return f"{node['name']} <{node['email']}>"

    def get_check_suite_conclusions(self) -> Dict[int, str]:
        """Map app databaseId -> conclusion for the check suites of the last
        fetched commit (NOTE(review): with >100 commits this is not necessarily
        the PR's last commit — the query only fetches the first 100)."""
        last_commit = self.info["commits"]["nodes"][-1]["commit"]
        rc = {}
        for node in last_commit["checkSuites"]["nodes"]:
            rc[int(node["app"]["databaseId"])] = node["conclusion"]
        return rc

    def get_authors(self) -> Dict[str, str]:
        """Map committer login -> git author string across all PR commits."""
        rc = {}
        for idx in range(self.get_commit_count()):
            rc[self.get_committer_login(idx)] = self.get_committer_author(idx)

        return rc

    def get_author(self) -> str:
        """Git author string to credit the squashed commit with: the sole
        committer if there is exactly one, else the PR creator's entry."""
        authors = self.get_authors()
        if len(authors) == 1:
            return next(iter(authors.values()))
        return self.get_authors()[self.get_pr_creator_login()]

    def get_title(self) -> str:
        return cast(str, self.info["title"])

    def get_body(self) -> str:
        return cast(str, self.info["body"])

    def get_pr_url(self) -> str:
        return f"https://github.com/{self.org}/{self.project}/pull/{self.pr_num}"

    def merge_ghstack_into(self, repo: GitRepo) -> None:
        """Land a ghstack PR by cherry-picking its /orig commits (oldest first)
        onto the current branch, validating each stacked PR along the way and
        stripping the ghstack-source-id trailer from every message."""
        assert self.is_ghstack_pr()
        # For ghstack, cherry-pick commits based from origin
        orig_ref = f"{repo.remote}/{re.sub(r'/head$', '/orig', self.head_ref())}"
        rev_list = repo.revlist(f"{self.default_branch()}..{orig_ref}")
        for idx, rev in enumerate(reversed(rev_list)):
            msg = repo.commit_message(rev)
            m = RE_PULL_REQUEST_RESOLVED.search(msg)
            if m is None:
                raise RuntimeError(f"Could not find PR-resolved string in {msg} of ghstacked PR {self.pr_num}")
            if self.org != m.group('owner') or self.project != m.group('repo'):
                raise RuntimeError(f"PR {m.group('number')} resolved to wrong owner/repo pair")
            pr_num = int(m.group('number'))
            if pr_num != self.pr_num:
                # A different PR in the stack: skip if already merged,
                # otherwise it must pass the merge rules too.
                pr = GitHubPR(self.org, self.project, pr_num)
                if pr.is_closed():
                    print(f"Skipping {idx+1} of {len(rev_list)} PR (#{pr_num}) as its already been merged")
                    continue
                check_if_should_be_merged(pr, repo)
            repo.cherry_pick(rev)
            repo.amend_commit_message(re.sub(RE_GHSTACK_SOURCE_ID, "", msg))

    def merge_into(self, repo: GitRepo, dry_run: bool = False) -> None:
        """Validate the PR against the merge rules and land it on the default
        branch: squash-merge for regular PRs, cherry-picking for ghstack ones;
        push unless *dry_run* is set."""
        check_if_should_be_merged(self, repo)
        if repo.current_branch() != self.default_branch():
            repo.checkout(self.default_branch())
        if not self.is_ghstack_pr():
            msg = self.get_title() + "\n" + self.get_body()
            msg += f"\nPull Request resolved: {self.get_pr_url()}\n"
            repo._run_git("merge", "--squash", f"{repo.remote}/{self.head_ref()}")
            repo._run_git("commit", f"--author=\"{self.get_author()}\"", "-m", msg)
        else:
            self.merge_ghstack_into(repo)

        if not dry_run:
            repo.push(self.default_branch())
|
||||
|
||||
|
||||
@dataclass
class MergeRule:
    """One entry of .github/merge_rules.json describing when a PR may be merged."""
    # Human-readable rule name, used in log messages.
    name: str
    # Glob-like file patterns; every changed file must match one of them
    # (compiled via gitutils.patterns_to_regex).
    patterns: List[str]
    # Logins of which at least one must have approved the PR
    # (empty list means no approver requirement).
    approved_by: List[str]
    # If set, the check suite of this GitHub app must have concluded SUCCESS.
    mandatory_app_id: Optional[int]
|
||||
|
||||
|
||||
def read_merge_rules(repo: GitRepo) -> List[MergeRule]:
    """Load .github/merge_rules.json from *repo*; return [] when absent."""
    from pathlib import Path

    rules_path = Path(repo.repo_dir) / ".github" / "merge_rules.json"
    if not rules_path.exists():
        print(f"{rules_path} does not exist, returning empty rules")
        return []
    with open(rules_path) as fp:
        # object_hook turns every parsed JSON object straight into a MergeRule.
        rules = json.load(fp, object_hook=lambda x: MergeRule(**x))
    return cast(List[MergeRule], rules)
|
||||
|
||||
|
||||
def check_if_should_be_merged(pr: GitHubPR, repo: GitRepo) -> None:
    """Check *pr* against the repository's merge rules.

    Returns silently on the first rule the PR satisfies; raises RuntimeError
    when no rule matches.  A rule is satisfied when (a) at least one of its
    required approvers approved the PR (if any are listed), (b) its mandatory
    app's check suite concluded SUCCESS (if one is set) and (c) every changed
    file matches one of its patterns.
    """
    changed_files = pr.get_changed_files()
    approved_by = set(pr.get_approved_by())
    for rule in read_merge_rules(repo):
        rule_name = rule.name
        required_approvers = set(rule.approved_by)
        # Compile up-front (as before) so malformed patterns fail loudly
        # even when the rule is skipped for other reasons.
        files_re = patterns_to_regex(rule.patterns)
        # Rule names specific approvers, but none of them reviewed this PR.
        if required_approvers and not (approved_by & required_approvers):
            print(f"Skipping rule {rule_name} due to no approvers overlap")
            continue
        if rule.mandatory_app_id is not None:
            suite_conclusions = pr.get_check_suite_conclusions()
            app_id = rule.mandatory_app_id
            if suite_conclusions.get(app_id) != "SUCCESS":
                print(f"Skipping rule {rule_name} as mandatory app {app_id} is not in {suite_conclusions}")
                continue
        rejected = [fname for fname in changed_files if not files_re.match(fname)]
        if rejected:
            print(f"Skipping rule {rule_name} due to non-matching files: {rejected}")
            continue
        print(f"Matched rule {rule_name} for {pr.pr_num}")
        return
    raise RuntimeError(f"PR {pr.pr_num} does not match merge rules")
|
||||
|
||||
|
||||
def main() -> None:
    """Entry point: look up the PR, refuse the unsupported cases with a
    comment, then validate and merge it into the default branch.  Any merge
    failure is reported back to the PR as a comment."""
    import sys

    args = parse_args()
    repo = GitRepo(get_git_repo_dir(), get_git_remote_name())
    org, project = repo.gh_owner_and_name()
    pr = GitHubPR(org, project, args.pr_num)

    if pr.is_closed():
        msg = f"Can't merge closed PR #{args.pr_num}"
        print(gh_post_comment(org, project, args.pr_num, msg, dry_run=args.dry_run))
        sys.exit(-1)

    if pr.is_cross_repo():
        msg = "Cross-repo merges are not supported at the moment"
        print(gh_post_comment(org, project, args.pr_num, msg, dry_run=args.dry_run))
        sys.exit(-1)

    try:
        pr.merge_into(repo, dry_run=args.dry_run)
    except Exception as e:
        gh_post_comment(org, project, args.pr_num, f"Merge failed due to {e}", dry_run=args.dry_run)


if __name__ == "__main__":
    main()
|
||||
31
.github/workflows/trymerge.yml
vendored
Normal file
31
.github/workflows/trymerge.yml
vendored
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
# Triggered by PyTorch Probot via repository_dispatch when a
# "pytorchbot merge this" comment is issued; runs
# .github/scripts/trymerge.py to validate and merge the PR.
name: Validate and merge PR

on:
  repository_dispatch:
    # Probot sends a "try-merge" event with the PR number in client_payload.
    types: [try-merge]

jobs:
  do_merge:
    runs-on: ubuntu-20.04
    steps:
      - name: Setup Python
        uses: actions/setup-python@v2
        with:
          python-version: 3.8
          architecture: x64
      - name: Checkout repo
        uses: actions/checkout@v2
        with:
          # Full history so trymerge can cherry-pick / squash-merge locally.
          fetch-depth: 0
          # NOTE(review): presumably a bot token with push rights to the
          # default branch -- confirm the secret's scope.
          token: ${{ secrets.MERGEBOT_TOKEN }}

      - name: Setup committer id
        run: |
          git config --global user.email "pytorchmergebot@users.noreply.github.com"
          git config --global user.name "PyTorch MergeBot"
      - name: Merge PR
        env:
          # Used by trymerge.py for api.github.com requests.
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          PR_NUM: ${{ github.event.client_payload.pr_num }}
        run: |
          python3 .github/scripts/trymerge.py "${PR_NUM}"
|
||||
Loading…
Reference in New Issue
Block a user