mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Pass install channel when building nightly images
Pass `TRITON_VERSION` argument to install triton for nightly images
Fix `generate_pytorch_version.py` to work with unannotated tags and avoid failures like the following:
```
% git checkout nightly
% ./.github/scripts/generate_pytorch_version.py
fatal: No annotated tags can describe '93f15b1b54ca5fb4a7ca9c21a813b4b86ebaeafa'.
However, there were unannotated tags: try --tags.
Traceback (most recent call last):
File "/Users/nshulga/git/pytorch/pytorch-release/./.github/scripts/generate_pytorch_version.py", line 120, in <module>
main()
File "/Users/nshulga/git/pytorch/pytorch-release/./.github/scripts/generate_pytorch_version.py", line 115, in main
print(version_obj.get_release_version())
File "/Users/nshulga/git/pytorch/pytorch-release/./.github/scripts/generate_pytorch_version.py", line 75, in get_release_version
if not get_tag():
File "/Users/nshulga/git/pytorch/pytorch-release/./.github/scripts/generate_pytorch_version.py", line 37, in get_tag
dirty_tag = subprocess.check_output(
File "/Users/nshulga/miniforge3/lib/python3.9/subprocess.py", line 424, in check_output
return run(*popenargs, stdout=PIPE, timeout=timeout, check=True,
File "/Users/nshulga/miniforge3/lib/python3.9/subprocess.py", line 528, in run
raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['git', 'describe']' returned non-zero exit status 128.
```
After the change, nightly is reported as follows (due to an auto-labelling issue,
which should be fixed by https://github.com/pytorch/test-infra/pull/1047):
```
% ./.github/scripts/generate_pytorch_version.py
ciflow/inductor/26921+cpu
```
Even for tagged release commits, version generation was wrong:
```
% git checkout release/1.13
% ./.github/scripts/generate_pytorch_version.py
ciflow/periodic/79617-4848-g7c98e70d44+cpu
```
After the fix, it is as expected:
```
% ./.github/scripts/generate_pytorch_version.py
1.13.0+cpu
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88741
Approved by: https://github.com/dagitses, https://github.com/msaroufim
116 lines
3.5 KiB
Python
Executable File
116 lines
3.5 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
|
|
import argparse
|
|
import os
|
|
import subprocess
|
|
import re
|
|
|
|
from datetime import datetime
|
|
from distutils.util import strtobool
|
|
from pathlib import Path
|
|
|
|
# Leading "v" that release tags usually carry, e.g. "v1.7.1" -> "1.7.1".
LEADING_V_PATTERN = re.compile(r"^v")

# Release-candidate suffix on tags, e.g. "1.7.1-rc1" -> "1.7.1".
TRAILING_RC_PATTERN = re.compile(r"-rc[0-9]*$")

# Trailing "a0" carried by the version in version.txt; stripped before use.
LEGACY_BASE_VERSION_SUFFIX_PATTERN = re.compile(r"a0$")
|
|
|
|
class NoGitTagException(Exception):
    """Raised when HEAD is not on a usable release tag.

    A release version cannot be produced in that case; ``main`` catches this
    and falls back to printing a nightly version instead.
    """
|
|
|
|
def get_pytorch_root() -> Path:
    """Return the top-level directory of the git checkout containing the cwd.

    Raises:
        subprocess.CalledProcessError: if the current directory is not inside
            a git work tree (``git rev-parse`` exits non-zero).
    """
    toplevel = subprocess.check_output(
        ['git', 'rev-parse', '--show-toplevel']
    )
    # Decode as UTF-8 rather than ASCII so a checkout living under a
    # non-ASCII directory name does not crash with UnicodeDecodeError.
    return Path(toplevel.decode('utf-8').strip())
|
|
|
|
def get_tag() -> str:
    """Return the tag pointing exactly at HEAD, normalised to a bare version.

    The lookup uses ``git describe --tags --exact-match`` so that lightweight
    (unannotated) tags are considered as well as annotated ones. A leading
    ``v`` and a trailing ``-rcN`` suffix are stripped, e.g.
    ``v1.7.1-rc1`` -> ``1.7.1``.

    Returns:
        The normalised tag, or ``""`` when HEAD is not exactly on a tag or
        when the tag found is a ``ciflow/`` tag rather than a release tag.
    """
    root = get_pytorch_root()
    try:
        dirty_tag = subprocess.check_output(
            # Spell the option out in full instead of relying on git's
            # acceptance of the abbreviated "--exact".
            ['git', 'describe', '--tags', '--exact-match'],
            cwd=root,
            # Not being on a tag is an expected outcome (nightly builds) and
            # is handled below; keep git's "fatal: no tag exactly matches"
            # message off the console.
            stderr=subprocess.DEVNULL,
        ).decode('utf-8').strip()
    except subprocess.CalledProcessError:
        return ""
    # Strip leading v that we typically do when we tag branches
    # ie: v1.7.1 -> 1.7.1
    tag = re.sub(LEADING_V_PATTERN, "", dirty_tag)
    # Strip trailing rc pattern
    # ie: 1.7.1-rc1 -> 1.7.1
    tag = re.sub(TRAILING_RC_PATTERN, "", tag)
    # Ignore ciflow tags
    if tag.startswith("ciflow/"):
        return ""
    return tag
|
|
|
|
def get_base_version() -> str:
    """Return the version from ``version.txt`` at the repository root.

    Surrounding whitespace and a trailing ``a0`` suffix are removed,
    e.g. ``1.14.0a0`` -> ``1.14.0``.
    """
    root = get_pytorch_root()
    # read_text() opens and closes the file itself, so no handle is leaked.
    dirty_version = (root / 'version.txt').read_text(encoding='utf-8').strip()
    # Strips trailing a0 from version.txt, not too sure why it's there in the
    # first place
    return re.sub(LEGACY_BASE_VERSION_SUFFIX_PATTERN, "", dirty_version)
|
|
|
|
class PytorchVersion:
    """Build PyTorch version strings for a particular binary flavour.

    Args:
        gpu_arch_type: accelerator family, e.g. ``cpu``, ``cuda`` or ``rocm``.
        gpu_arch_version: accelerator version such as ``11.6``; may be empty.
        no_build_suffix: when true, no ``+<arch>`` suffix is appended.
    """

    def __init__(
        self,
        gpu_arch_type: str,
        gpu_arch_version: str,
        no_build_suffix: bool,
    ) -> None:
        self.gpu_arch_type = gpu_arch_type
        self.gpu_arch_version = gpu_arch_version
        self.no_build_suffix = no_build_suffix

    def get_post_build_suffix(self) -> str:
        """Return the ``+<arch>`` suffix appended to the version.

        ``""`` when ``no_build_suffix`` is set. For CUDA, the dots are
        removed from the version (``11.6`` -> ``+cu116``). Otherwise the type
        and version are concatenated as-is (``+cpu``, ``+rocm5.2``).
        """
        if self.no_build_suffix:
            return ""
        if self.gpu_arch_type == "cuda":
            return f"+cu{self.gpu_arch_version.replace('.', '')}"
        return f"+{self.gpu_arch_type}{self.gpu_arch_version}"

    def get_release_version(self) -> str:
        """Return ``<tag><suffix>`` for a checkout sitting on a release tag.

        Raises:
            NoGitTagException: if ``get_tag()`` finds no usable tag at HEAD.
        """
        # Query git once and reuse the result. Calling get_tag() twice doubled
        # the subprocess work and could let the check and the returned value
        # disagree.
        tag = get_tag()
        if not tag:
            raise NoGitTagException(
                "Not on a git tag, are you sure you want a release version?"
            )
        return f"{tag}{self.get_post_build_suffix()}"

    def get_nightly_version(self) -> str:
        """Return ``<base>.dev<YYYYMMDD><suffix>`` using today's local date."""
        date_str = datetime.today().strftime('%Y%m%d')
        build_suffix = self.get_post_build_suffix()
        return f"{get_base_version()}.dev{date_str}{build_suffix}"
|
|
|
|
def main() -> None:
    """Print the version string to use for the current build.

    Each option defaults to an environment variable: NO_BUILD_SUFFIX,
    GPU_ARCH_TYPE or GPU_ARCH_VERSION. A release version is printed when
    HEAD is on a usable tag; otherwise a dated nightly version is printed.
    """
    env = os.environ
    parser = argparse.ArgumentParser(
        description="Generate pytorch version for binary builds"
    )
    parser.add_argument(
        "--no-build-suffix",
        action="store_true",
        default=strtobool(env.get("NO_BUILD_SUFFIX", "False")),
        help="Whether or not to add a build suffix typically (+cpu)",
    )
    parser.add_argument(
        "--gpu-arch-type",
        type=str,
        default=env.get("GPU_ARCH_TYPE", "cpu"),
        help="GPU arch you are building for, typically (cpu, cuda, rocm)",
    )
    parser.add_argument(
        "--gpu-arch-version",
        type=str,
        default=env.get("GPU_ARCH_VERSION", ""),
        help="GPU arch version, typically (10.2, 4.0), leave blank for CPU",
    )
    options = parser.parse_args()

    version = PytorchVersion(
        gpu_arch_type=options.gpu_arch_type,
        gpu_arch_version=options.gpu_arch_version,
        no_build_suffix=options.no_build_suffix,
    )
    # Prefer the release version; fall back to nightly when not on a tag.
    try:
        version_string = version.get_release_version()
    except NoGitTagException:
        version_string = version.get_nightly_version()
    print(version_string)


if __name__ == "__main__":
    main()
|