mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add trymerge workflow (#71488)
Summary:
This one, will react to `repo_dispatch` event sent by PyTorch Probot
when `pytorchbot merge this` command is issued
At the moment, workflow will only attempt to merge PRs which has not
been created from forked repo and that match rules defined in
`.github/merge_rules.json`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71488
Reviewed By: bigfootjon
Differential Revision: D33665142
Pulled By: malfet
fbshipit-source-id: e22daa1892523e62d7b7a941960636a6514cb7d7
(cherry picked from commit 92059bab07)
This commit is contained in:
parent
f45e217c01
commit
61713acb07
66
.github/scripts/gitutils.py
vendored
66
.github/scripts/gitutils.py
vendored
|
|
@ -211,22 +211,54 @@ class GitRepo:
|
||||||
self._run_git("commit", "--amend", "-m", msg)
|
self._run_git("commit", "--amend", "-m", msg)
|
||||||
|
|
||||||
|
|
||||||
def parse_args() -> Any:
|
class PeekableIterator:
|
||||||
from argparse import ArgumentParser
|
def __init__(self, val: str) -> None:
|
||||||
parser = ArgumentParser("Merge PR/branch into default branch")
|
self._val = val
|
||||||
parser.add_argument("--sync-branch", default="sync")
|
self._idx = -1
|
||||||
parser.add_argument("--default-branch", type=str, default="main")
|
|
||||||
parser.add_argument("--dry-run", action="store_true")
|
def peek(self) -> Optional[str]:
|
||||||
return parser.parse_args()
|
if self._idx + 1 >= len(self._val):
|
||||||
|
return None
|
||||||
|
return self._val[self._idx + 1]
|
||||||
|
|
||||||
|
def __iter__(self) -> Any:
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self) -> str:
|
||||||
|
rc = self.peek()
|
||||||
|
if rc is None:
|
||||||
|
raise StopIteration
|
||||||
|
self._idx += 1
|
||||||
|
return rc
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def patterns_to_regex(allowed_patterns: List[str]) -> Any:
|
||||||
args = parse_args()
|
"""
|
||||||
repo = GitRepo(get_git_repo_dir())
|
pattern is glob-like, i.e. the only special sequences it has are:
|
||||||
repo.cherry_pick_commits(args.sync_branch, args.default_branch)
|
- ? - matches single character
|
||||||
if not args.dry_run:
|
- * - matches any non-folder separator characters
|
||||||
repo.push(args.default_branch)
|
- ** - matches any characters
|
||||||
|
Assuming that patterns are free of braces and backslashes
|
||||||
|
the only character that needs to be escaped are dot and plus
|
||||||
if __name__ == '__main__':
|
"""
|
||||||
main()
|
rc = "("
|
||||||
|
for idx, pattern in enumerate(allowed_patterns):
|
||||||
|
if idx > 0:
|
||||||
|
rc += "|"
|
||||||
|
pattern_ = PeekableIterator(pattern)
|
||||||
|
assert not any(c in pattern for c in "{}()[]\\")
|
||||||
|
for c in pattern_:
|
||||||
|
if c == ".":
|
||||||
|
rc += "\\."
|
||||||
|
elif c == "+":
|
||||||
|
rc += "\\+"
|
||||||
|
elif c == "*":
|
||||||
|
if pattern_.peek() == "*":
|
||||||
|
next(pattern_)
|
||||||
|
rc += ".+"
|
||||||
|
else:
|
||||||
|
rc += "[^/]+"
|
||||||
|
else:
|
||||||
|
rc += c
|
||||||
|
rc += ")"
|
||||||
|
return re.compile(rc)
|
||||||
|
|
|
||||||
340
.github/scripts/trymerge.py
vendored
Executable file
340
.github/scripts/trymerge.py
vendored
Executable file
|
|
@ -0,0 +1,340 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from urllib.request import urlopen, Request
|
||||||
|
from urllib.error import HTTPError
|
||||||
|
from typing import cast, Any, Callable, Dict, List, Optional, Tuple
|
||||||
|
from gitutils import get_git_remote_name, get_git_repo_dir, patterns_to_regex, GitRepo
|
||||||
|
|
||||||
|
|
||||||
|
GH_GET_PR_INFO_QUERY = """
|
||||||
|
query ($owner: String!, $name: String!, $number: Int!) {
|
||||||
|
repository(owner: $owner, name: $name) {
|
||||||
|
pullRequest(number: $number) {
|
||||||
|
closed
|
||||||
|
isCrossRepository
|
||||||
|
author {
|
||||||
|
login
|
||||||
|
}
|
||||||
|
title
|
||||||
|
body
|
||||||
|
headRefName
|
||||||
|
headRepository {
|
||||||
|
nameWithOwner
|
||||||
|
}
|
||||||
|
baseRefName
|
||||||
|
baseRepository {
|
||||||
|
nameWithOwner
|
||||||
|
defaultBranchRef {
|
||||||
|
name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
commits(first: 100) {
|
||||||
|
nodes {
|
||||||
|
commit {
|
||||||
|
author {
|
||||||
|
user {
|
||||||
|
login
|
||||||
|
}
|
||||||
|
email
|
||||||
|
name
|
||||||
|
}
|
||||||
|
oid
|
||||||
|
checkSuites(filterBy: {appId: 12274}, first: 1) {
|
||||||
|
nodes {
|
||||||
|
app {
|
||||||
|
databaseId
|
||||||
|
}
|
||||||
|
conclusion
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
totalCount
|
||||||
|
}
|
||||||
|
changedFiles,
|
||||||
|
files(last: 100) {
|
||||||
|
nodes {
|
||||||
|
path
|
||||||
|
}
|
||||||
|
}
|
||||||
|
latestReviews(last: 100) {
|
||||||
|
nodes {
|
||||||
|
author {
|
||||||
|
login
|
||||||
|
},
|
||||||
|
state
|
||||||
|
},
|
||||||
|
totalCount
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
RE_GHSTACK_HEAD_REF = re.compile(r"^(gh/[^/]+/[0-9]+/)head$")
|
||||||
|
RE_GHSTACK_SOURCE_ID = re.compile(r'^ghstack-source-id: (.+)\n?', re.MULTILINE)
|
||||||
|
RE_PULL_REQUEST_RESOLVED = re.compile(
|
||||||
|
r'Pull Request resolved: '
|
||||||
|
r'https://github.com/(?P<owner>[^/]+)/(?P<repo>[^/]+)/pull/(?P<number>[0-9]+)',
|
||||||
|
re.MULTILINE
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _fetch_url(url: str, *,
|
||||||
|
headers: Optional[Dict[str, str]] = None,
|
||||||
|
data: Optional[Dict[str, Any]] = None,
|
||||||
|
method: Optional[str] = None,
|
||||||
|
reader: Callable[[Any], Any] = lambda x: x.read()) -> Any:
|
||||||
|
if headers is None:
|
||||||
|
headers = {}
|
||||||
|
token = os.environ.get("GITHUB_TOKEN")
|
||||||
|
if token is not None and url.startswith('https://api.github.com/'):
|
||||||
|
headers['Authorization'] = f'token {token}'
|
||||||
|
data_ = json.dumps(data).encode() if data is not None else None
|
||||||
|
try:
|
||||||
|
with urlopen(Request(url, headers=headers, data=data_, method=method)) as conn:
|
||||||
|
return reader(conn)
|
||||||
|
except HTTPError as err:
|
||||||
|
if err.code == 403 and all(key in err.headers for key in ['X-RateLimit-Limit', 'X-RateLimit-Used']):
|
||||||
|
print(f"Rate limit exceeded: {err.headers['X-RateLimit-Used']}/{err.headers['X-RateLimit-Limit']}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_json(url: str,
|
||||||
|
params: Optional[Dict[str, Any]] = None,
|
||||||
|
data: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
|
||||||
|
headers = {'Accept': 'application/vnd.github.v3+json'}
|
||||||
|
if params is not None and len(params) > 0:
|
||||||
|
url += '?' + '&'.join(f"{name}={val}" for name, val in params.items())
|
||||||
|
return cast(List[Dict[str, Any]], _fetch_url(url, headers=headers, data=data, reader=json.load))
|
||||||
|
|
||||||
|
|
||||||
|
def gh_post_comment(org: str, project: str, pr_num: int, comment: str, dry_run: bool = False) -> List[Dict[str, Any]]:
|
||||||
|
if dry_run:
|
||||||
|
print(comment)
|
||||||
|
return []
|
||||||
|
return fetch_json(f'https://api.github.com/repos/{org}/{project}/issues/{pr_num}/comments',
|
||||||
|
data={"body": comment})
|
||||||
|
|
||||||
|
|
||||||
|
def gh_graphql(query: str, **kwargs: Any) -> Dict[str, Any]:
|
||||||
|
rc = _fetch_url("https://api.github.com/graphql", data={"query": query, "variables": kwargs}, reader=json.load)
|
||||||
|
if "errors" in rc:
|
||||||
|
raise RuntimeError(f"GraphQL query {query} failed: {rc['errors']}")
|
||||||
|
return cast(Dict[str, Any], rc)
|
||||||
|
|
||||||
|
|
||||||
|
def gh_get_pr_info(org: str, proj: str, pr_no: int) -> Any:
|
||||||
|
rc = gh_graphql(GH_GET_PR_INFO_QUERY, name=proj, owner=org, number=pr_no)
|
||||||
|
return rc["data"]["repository"]["pullRequest"]
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args() -> Any:
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
parser = ArgumentParser("Merge PR into default branch")
|
||||||
|
parser.add_argument("--dry-run", action="store_true")
|
||||||
|
parser.add_argument("pr_num", type=int)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
class GitHubPR:
|
||||||
|
def __init__(self, org: str, project: str, pr_num: int) -> None:
|
||||||
|
assert isinstance(pr_num, int)
|
||||||
|
self.org = org
|
||||||
|
self.project = project
|
||||||
|
self.pr_num = pr_num
|
||||||
|
self.info = gh_get_pr_info(org, project, pr_num)
|
||||||
|
|
||||||
|
def is_closed(self) -> bool:
|
||||||
|
return bool(self.info["closed"])
|
||||||
|
|
||||||
|
def is_cross_repo(self) -> bool:
|
||||||
|
return bool(self.info["isCrossRepository"])
|
||||||
|
|
||||||
|
def base_ref(self) -> str:
|
||||||
|
return cast(str, self.info["baseRefName"])
|
||||||
|
|
||||||
|
def default_branch(self) -> str:
|
||||||
|
return cast(str, self.info["baseRepository"]["defaultBranchRef"]["name"])
|
||||||
|
|
||||||
|
def head_ref(self) -> str:
|
||||||
|
return cast(str, self.info["headRefName"])
|
||||||
|
|
||||||
|
def is_ghstack_pr(self) -> bool:
|
||||||
|
return RE_GHSTACK_HEAD_REF.match(self.head_ref()) is not None
|
||||||
|
|
||||||
|
def get_changed_files_count(self) -> int:
|
||||||
|
return int(self.info["changedFiles"])
|
||||||
|
|
||||||
|
def get_changed_files(self) -> List[str]:
|
||||||
|
rc = [x["path"] for x in self.info["files"]["nodes"]]
|
||||||
|
if len(rc) != self.get_changed_files_count():
|
||||||
|
raise RuntimeError("Changed file count mismatch")
|
||||||
|
return rc
|
||||||
|
|
||||||
|
def _get_reviewers(self) -> List[Tuple[str, str]]:
|
||||||
|
reviews_count = int(self.info["latestReviews"]["totalCount"])
|
||||||
|
if len(self.info["latestReviews"]["nodes"]) != reviews_count:
|
||||||
|
raise RuntimeError("Can't fetch all PR reviews")
|
||||||
|
return [(x["author"]["login"], x["state"]) for x in self.info["latestReviews"]["nodes"]]
|
||||||
|
|
||||||
|
def get_approved_by(self) -> List[str]:
|
||||||
|
return [login for (login, state) in self._get_reviewers() if state == "APPROVED"]
|
||||||
|
|
||||||
|
def get_commit_count(self) -> int:
|
||||||
|
return int(self.info["commits"]["totalCount"])
|
||||||
|
|
||||||
|
def get_pr_creator_login(self) -> str:
|
||||||
|
return cast(str, self.info["author"]["login"])
|
||||||
|
|
||||||
|
def get_committer_login(self, num: int = 0) -> str:
|
||||||
|
return cast(str, self.info["commits"]["nodes"][num]["commit"]["author"]["user"]["login"])
|
||||||
|
|
||||||
|
def get_committer_author(self, num: int = 0) -> str:
|
||||||
|
node = self.info["commits"]["nodes"][num]["commit"]["author"]
|
||||||
|
return f"{node['name']} <{node['email']}>"
|
||||||
|
|
||||||
|
def get_check_suite_conclusions(self) -> Dict[int, str]:
|
||||||
|
last_commit = self.info["commits"]["nodes"][-1]["commit"]
|
||||||
|
rc = {}
|
||||||
|
for node in last_commit["checkSuites"]["nodes"]:
|
||||||
|
rc[int(node["app"]["databaseId"])] = node["conclusion"]
|
||||||
|
return rc
|
||||||
|
|
||||||
|
def get_authors(self) -> Dict[str, str]:
|
||||||
|
rc = {}
|
||||||
|
for idx in range(self.get_commit_count()):
|
||||||
|
rc[self.get_committer_login(idx)] = self.get_committer_author(idx)
|
||||||
|
|
||||||
|
return rc
|
||||||
|
|
||||||
|
def get_author(self) -> str:
|
||||||
|
authors = self.get_authors()
|
||||||
|
if len(authors) == 1:
|
||||||
|
return next(iter(authors.values()))
|
||||||
|
return self.get_authors()[self.get_pr_creator_login()]
|
||||||
|
|
||||||
|
def get_title(self) -> str:
|
||||||
|
return cast(str, self.info["title"])
|
||||||
|
|
||||||
|
def get_body(self) -> str:
|
||||||
|
return cast(str, self.info["body"])
|
||||||
|
|
||||||
|
def get_pr_url(self) -> str:
|
||||||
|
return f"https://github.com/{self.org}/{self.project}/pull/{self.pr_num}"
|
||||||
|
|
||||||
|
def merge_ghstack_into(self, repo: GitRepo) -> None:
|
||||||
|
assert self.is_ghstack_pr()
|
||||||
|
# For ghstack, cherry-pick commits based from origin
|
||||||
|
orig_ref = f"{repo.remote}/{re.sub(r'/head$', '/orig', self.head_ref())}"
|
||||||
|
rev_list = repo.revlist(f"{self.default_branch()}..{orig_ref}")
|
||||||
|
for idx, rev in enumerate(reversed(rev_list)):
|
||||||
|
msg = repo.commit_message(rev)
|
||||||
|
m = RE_PULL_REQUEST_RESOLVED.search(msg)
|
||||||
|
if m is None:
|
||||||
|
raise RuntimeError(f"Could not find PR-resolved string in {msg} of ghstacked PR {self.pr_num}")
|
||||||
|
if self.org != m.group('owner') or self.project != m.group('repo'):
|
||||||
|
raise RuntimeError(f"PR {m.group('number')} resolved to wrong owner/repo pair")
|
||||||
|
pr_num = int(m.group('number'))
|
||||||
|
if pr_num != self.pr_num:
|
||||||
|
pr = GitHubPR(self.org, self.project, pr_num)
|
||||||
|
if pr.is_closed():
|
||||||
|
print(f"Skipping {idx+1} of {len(rev_list)} PR (#{pr_num}) as its already been merged")
|
||||||
|
continue
|
||||||
|
check_if_should_be_merged(pr, repo)
|
||||||
|
repo.cherry_pick(rev)
|
||||||
|
repo.amend_commit_message(re.sub(RE_GHSTACK_SOURCE_ID, "", msg))
|
||||||
|
|
||||||
|
def merge_into(self, repo: GitRepo, dry_run: bool = False) -> None:
|
||||||
|
check_if_should_be_merged(self, repo)
|
||||||
|
if repo.current_branch() != self.default_branch():
|
||||||
|
repo.checkout(self.default_branch())
|
||||||
|
if not self.is_ghstack_pr():
|
||||||
|
msg = self.get_title() + "\n" + self.get_body()
|
||||||
|
msg += f"\nPull Request resolved: {self.get_pr_url()}\n"
|
||||||
|
repo._run_git("merge", "--squash", f"{repo.remote}/{self.head_ref()}")
|
||||||
|
repo._run_git("commit", f"--author=\"{self.get_author()}\"", "-m", msg)
|
||||||
|
else:
|
||||||
|
self.merge_ghstack_into(repo)
|
||||||
|
|
||||||
|
if not dry_run:
|
||||||
|
repo.push(self.default_branch())
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MergeRule:
|
||||||
|
name: str
|
||||||
|
patterns: List[str]
|
||||||
|
approved_by: List[str]
|
||||||
|
mandatory_app_id: Optional[int]
|
||||||
|
|
||||||
|
|
||||||
|
def read_merge_rules(repo: GitRepo) -> List[MergeRule]:
|
||||||
|
from pathlib import Path
|
||||||
|
rules_path = Path(repo.repo_dir) / ".github" / "merge_rules.json"
|
||||||
|
if not rules_path.exists():
|
||||||
|
print(f"{rules_path} does not exist, returning empty rules")
|
||||||
|
return []
|
||||||
|
with open(rules_path) as fp:
|
||||||
|
rc = json.load(fp, object_hook=lambda x: MergeRule(**x))
|
||||||
|
return cast(List[MergeRule], rc)
|
||||||
|
|
||||||
|
|
||||||
|
def check_if_should_be_merged(pr: GitHubPR, repo: GitRepo) -> None:
|
||||||
|
changed_files = pr.get_changed_files()
|
||||||
|
approved_by = set(pr.get_approved_by())
|
||||||
|
rules = read_merge_rules(repo)
|
||||||
|
for rule in rules:
|
||||||
|
rule_name = rule.name
|
||||||
|
rule_approvers_set = set(rule.approved_by)
|
||||||
|
patterns_re = patterns_to_regex(rule.patterns)
|
||||||
|
approvers_intersection = approved_by.intersection(rule_approvers_set)
|
||||||
|
# If rule requires approvers but they aren't the ones that reviewed PR
|
||||||
|
if len(approvers_intersection) == 0 and len(rule_approvers_set) > 0:
|
||||||
|
print(f"Skipping rule {rule_name} due to no approvers overlap")
|
||||||
|
continue
|
||||||
|
if rule.mandatory_app_id is not None:
|
||||||
|
cs_conslusions = pr.get_check_suite_conclusions()
|
||||||
|
mandatory_app_id = rule.mandatory_app_id
|
||||||
|
if mandatory_app_id not in cs_conslusions or cs_conslusions[mandatory_app_id] != "SUCCESS":
|
||||||
|
print(f"Skipping rule {rule_name} as mandatory app {mandatory_app_id} is not in {cs_conslusions}")
|
||||||
|
continue
|
||||||
|
non_matching_files = []
|
||||||
|
for fname in changed_files:
|
||||||
|
if not patterns_re.match(fname):
|
||||||
|
non_matching_files.append(fname)
|
||||||
|
if len(non_matching_files) > 0:
|
||||||
|
print(f"Skipping rule {rule_name} due to non-matching files: {non_matching_files}")
|
||||||
|
continue
|
||||||
|
print(f"Matched rule {rule_name} for {pr.pr_num}")
|
||||||
|
return
|
||||||
|
raise RuntimeError(f"PR {pr.pr_num} does not match merge rules")
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
import sys
|
||||||
|
args = parse_args()
|
||||||
|
repo = GitRepo(get_git_repo_dir(), get_git_remote_name())
|
||||||
|
org, project = repo.gh_owner_and_name()
|
||||||
|
|
||||||
|
pr = GitHubPR(org, project, args.pr_num)
|
||||||
|
if pr.is_closed():
|
||||||
|
print(gh_post_comment(org, project, args.pr_num, f"Can't merge closed PR #{args.pr_num}", dry_run=args.dry_run))
|
||||||
|
sys.exit(-1)
|
||||||
|
|
||||||
|
if pr.is_cross_repo():
|
||||||
|
print(gh_post_comment(org, project, args.pr_num, "Cross-repo merges are not supported at the moment", dry_run=args.dry_run))
|
||||||
|
sys.exit(-1)
|
||||||
|
|
||||||
|
try:
|
||||||
|
pr.merge_into(repo, dry_run=args.dry_run)
|
||||||
|
except Exception as e:
|
||||||
|
gh_post_comment(org, project, args.pr_num, f"Merge failed due to {e}", dry_run=args.dry_run)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
31
.github/workflows/trymerge.yml
vendored
Normal file
31
.github/workflows/trymerge.yml
vendored
Normal file
|
|
@ -0,0 +1,31 @@
|
||||||
|
name: Validate and merge PR
|
||||||
|
|
||||||
|
on:
|
||||||
|
repository_dispatch:
|
||||||
|
types: [try-merge]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
do_merge:
|
||||||
|
runs-on: ubuntu-20.04
|
||||||
|
steps:
|
||||||
|
- name: Setup Python
|
||||||
|
uses: actions/setup-python@v2
|
||||||
|
with:
|
||||||
|
python-version: 3.8
|
||||||
|
architecture: x64
|
||||||
|
- name: Checkout repo
|
||||||
|
uses: actions/checkout@v2
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
token: ${{ secrets.MERGEBOT_TOKEN }}
|
||||||
|
|
||||||
|
- name: Setup committer id
|
||||||
|
run: |
|
||||||
|
git config --global user.email "pytorchmergebot@users.noreply.github.com"
|
||||||
|
git config --global user.name "PyTorch MergeBot"
|
||||||
|
- name: Merge PR
|
||||||
|
env:
|
||||||
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
PR_NUM: ${{ github.event.client_payload.pr_num }}
|
||||||
|
run: |
|
||||||
|
python3 .github/scripts/trymerge.py "${PR_NUM}"
|
||||||
Loading…
Reference in New Issue
Block a user