[ONNX] Update the script for version updates (#83283)

This PR updates the `tools/onnx/update_default_opset_version.py` script to ensure files are edited correctly to prepare for the opset 17 support in torch.onnx.

- (clean up) Move script to `main()`
- Add an `--skip_build` option to avoid building pytorch if we want to rerun the process due to errors after compilation is done
- Update to edit the correct files now that the onnx files were refactored
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83283
Approved by: https://github.com/thiagocrepaldi, https://github.com/AllenTiTaiWang, https://github.com/abock
This commit is contained in:
Justin Chu 2022-08-16 19:58:48 +00:00 committed by PyTorch MergeBot
parent d52d2bd5a9
commit cd68f08992

View File

@ -9,6 +9,7 @@ Usage:
Run with no arguments. Run with no arguments.
""" """
import argparse
import datetime import datetime
import os import os
import pathlib import pathlib
@ -16,39 +17,53 @@ import re
import subprocess import subprocess
import sys import sys
from subprocess import DEVNULL from subprocess import DEVNULL
from typing import Any
pytorch_dir = pathlib.Path(__file__).parent.parent.parent.resolve()
onnx_dir = pytorch_dir / "third_party" / "onnx"
os.chdir(onnx_dir)
date = datetime.datetime.now() - datetime.timedelta(days=18 * 30) def read_sub_write(path: str, prefix_pat: str, new_default: int) -> None:
onnx_commit = subprocess.check_output( with open(path, encoding="utf-8") as f:
("git", "log", f"--until={date}", "--max-count=1", "--format=%H"), encoding="utf-8" content_str = f.read()
).strip() content_str = re.sub(prefix_pat, r"\g<1>{}".format(new_default), content_str)
onnx_tags = subprocess.check_output( with open(path, "w", encoding="utf-8") as f:
f.write(content_str)
print("modified", path)
def main(args: Any) -> None:
pytorch_dir = pathlib.Path(__file__).parent.parent.parent.resolve()
onnx_dir = pytorch_dir / "third_party" / "onnx"
os.chdir(onnx_dir)
date = datetime.datetime.now() - datetime.timedelta(days=18 * 30)
onnx_commit = subprocess.check_output(
("git", "log", f"--until={date}", "--max-count=1", "--format=%H"),
encoding="utf-8",
).strip()
onnx_tags = subprocess.check_output(
("git", "tag", "--list", f"--contains={onnx_commit}"), encoding="utf-8" ("git", "tag", "--list", f"--contains={onnx_commit}"), encoding="utf-8"
) )
tag_tups = [] tag_tups = []
semver_pat = re.compile(r"v(\d+)\.(\d+)\.(\d+)") semver_pat = re.compile(r"v(\d+)\.(\d+)\.(\d+)")
for tag in onnx_tags.splitlines(): for tag in onnx_tags.splitlines():
match = semver_pat.match(tag) match = semver_pat.match(tag)
if match: if match:
tag_tups.append(tuple(int(x) for x in match.groups())) tag_tups.append(tuple(int(x) for x in match.groups()))
version_str = "{}.{}.{}".format(*min(tag_tups)) # Take the release 18 months ago
version_str = "{}.{}.{}".format(*min(tag_tups))
print("Using ONNX release", version_str) print("Using ONNX release", version_str)
head_commit = subprocess.check_output( head_commit = subprocess.check_output(
("git", "log", "--max-count=1", "--format=%H", "HEAD"), encoding="utf-8" ("git", "log", "--max-count=1", "--format=%H", "HEAD"), encoding="utf-8"
).strip() ).strip()
new_default = None new_default = None
subprocess.check_call( subprocess.check_call(
("git", "checkout", f"v{version_str}"), stdout=DEVNULL, stderr=DEVNULL ("git", "checkout", f"v{version_str}"), stdout=DEVNULL, stderr=DEVNULL
) )
try: try:
from onnx import helper # type: ignore[import] from onnx import helper # type: ignore[import]
for version in helper.VERSION_TABLE: for version in helper.VERSION_TABLE:
@ -60,35 +75,38 @@ try:
sys.exit( sys.exit(
f"failed to find version {version_str} in onnx.helper.VERSION_TABLE at commit {onnx_commit}" f"failed to find version {version_str} in onnx.helper.VERSION_TABLE at commit {onnx_commit}"
) )
finally: finally:
subprocess.check_call( subprocess.check_call(
("git", "checkout", head_commit), stdout=DEVNULL, stderr=DEVNULL ("git", "checkout", head_commit), stdout=DEVNULL, stderr=DEVNULL
) )
os.chdir(pytorch_dir) os.chdir(pytorch_dir)
read_sub_write(
def read_sub_write(path: str, prefix_pat: str) -> None:
with open(path, encoding="utf-8") as f:
content_str = f.read()
content_str = re.sub(prefix_pat, r"\g<1>{}".format(new_default), content_str)
with open(path, "w", encoding="utf-8") as f:
f.write(content_str)
print("modified", path)
read_sub_write(
os.path.join("torch", "onnx", "_constants.py"), os.path.join("torch", "onnx", "_constants.py"),
r"(onnx_default_opset = )\d+", r"(onnx_default_opset = )\d+",
) new_default,
read_sub_write( )
os.path.join("torch", "onnx", "__init__.py"), r"(opset_version \(int, default )\d+" read_sub_write(
) os.path.join("torch", "onnx", "utils.py"),
r"(opset_version \(int, default )\d+",
new_default,
)
print("Updating operator .expect files") if not args.skip_build:
subprocess.check_call(("python", "setup.py", "develop"), stdout=DEVNULL, stderr=DEVNULL) print("Building PyTorch...")
subprocess.check_call( subprocess.check_call(
("python", "setup.py", "develop"),
)
print("Updating operator .expect files")
subprocess.check_call(
("python", os.path.join("test", "onnx", "test_operators.py"), "--accept"), ("python", os.path.join("test", "onnx", "test_operators.py"), "--accept"),
stdout=DEVNULL, )
stderr=DEVNULL,
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--skip_build", action="store_true", help="Skip building pytorch"
)
main(parser.parse_args())