mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ONNX] Update the script for version updates (#83283)
This PR updates the `tools/onnx/update_default_opset_version.py` script to ensure files are edited correctly to prepare for the opset 17 support in torch.onnx. - (clean up) Move script to `main()` - Add an `--skip_build` option to avoid building pytorch if we want to rerun the process due to errors after compilation is done - Update to edit the correct files now that the onnx files were refactored Pull Request resolved: https://github.com/pytorch/pytorch/pull/83283 Approved by: https://github.com/thiagocrepaldi, https://github.com/AllenTiTaiWang, https://github.com/abock
This commit is contained in:
parent
d52d2bd5a9
commit
cd68f08992
|
|
@ -9,6 +9,7 @@ Usage:
|
|||
Run with no arguments.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import datetime
|
||||
import os
|
||||
import pathlib
|
||||
|
|
@ -16,14 +17,27 @@ import re
|
|||
import subprocess
|
||||
import sys
|
||||
from subprocess import DEVNULL
|
||||
from typing import Any
|
||||
|
||||
|
||||
def read_sub_write(path: str, prefix_pat: str, new_default: int) -> None:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
content_str = f.read()
|
||||
content_str = re.sub(prefix_pat, r"\g<1>{}".format(new_default), content_str)
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
f.write(content_str)
|
||||
print("modified", path)
|
||||
|
||||
|
||||
def main(args: Any) -> None:
|
||||
pytorch_dir = pathlib.Path(__file__).parent.parent.parent.resolve()
|
||||
onnx_dir = pytorch_dir / "third_party" / "onnx"
|
||||
os.chdir(onnx_dir)
|
||||
|
||||
date = datetime.datetime.now() - datetime.timedelta(days=18 * 30)
|
||||
onnx_commit = subprocess.check_output(
|
||||
("git", "log", f"--until={date}", "--max-count=1", "--format=%H"), encoding="utf-8"
|
||||
("git", "log", f"--until={date}", "--max-count=1", "--format=%H"),
|
||||
encoding="utf-8",
|
||||
).strip()
|
||||
onnx_tags = subprocess.check_output(
|
||||
("git", "tag", "--list", f"--contains={onnx_commit}"), encoding="utf-8"
|
||||
|
|
@ -35,6 +49,7 @@ for tag in onnx_tags.splitlines():
|
|||
if match:
|
||||
tag_tups.append(tuple(int(x) for x in match.groups()))
|
||||
|
||||
# Take the release 18 months ago
|
||||
version_str = "{}.{}.{}".format(*min(tag_tups))
|
||||
|
||||
print("Using ONNX release", version_str)
|
||||
|
|
@ -67,28 +82,31 @@ finally:
|
|||
|
||||
os.chdir(pytorch_dir)
|
||||
|
||||
|
||||
def read_sub_write(path: str, prefix_pat: str) -> None:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
content_str = f.read()
|
||||
content_str = re.sub(prefix_pat, r"\g<1>{}".format(new_default), content_str)
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
f.write(content_str)
|
||||
print("modified", path)
|
||||
|
||||
|
||||
read_sub_write(
|
||||
os.path.join("torch", "onnx", "_constants.py"),
|
||||
r"(onnx_default_opset = )\d+",
|
||||
new_default,
|
||||
)
|
||||
read_sub_write(
|
||||
os.path.join("torch", "onnx", "__init__.py"), r"(opset_version \(int, default )\d+"
|
||||
os.path.join("torch", "onnx", "utils.py"),
|
||||
r"(opset_version \(int, default )\d+",
|
||||
new_default,
|
||||
)
|
||||
|
||||
if not args.skip_build:
|
||||
print("Building PyTorch...")
|
||||
subprocess.check_call(
|
||||
("python", "setup.py", "develop"),
|
||||
)
|
||||
print("Updating operator .expect files")
|
||||
subprocess.check_call(("python", "setup.py", "develop"), stdout=DEVNULL, stderr=DEVNULL)
|
||||
subprocess.check_call(
|
||||
("python", os.path.join("test", "onnx", "test_operators.py"), "--accept"),
|
||||
stdout=DEVNULL,
|
||||
stderr=DEVNULL,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--skip_build", action="store_true", help="Skip building pytorch"
|
||||
)
|
||||
main(parser.parse_args())
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user