pytorch/tools/setup_helpers/gen_version_header.py
Huy Do 347b036350 Apply ufmt linter to all py files under tools (#81285)
With ufmt in place https://github.com/pytorch/pytorch/pull/81157, we can now use it to gradually format all files. I'm breaking this down into multiple smaller batches to avoid too many merge conflicts later on.

This batch (as copied from the current BLACK linter config):
* `tools/**/*.py`

Upcoming batches:
* `torchgen/**/*.py`
* `torch/package/**/*.py`
* `torch/onnx/**/*.py`
* `torch/_refs/**/*.py`
* `torch/_prims/**/*.py`
* `torch/_meta_registrations.py`
* `torch/_decomp/**/*.py`
* `test/onnx/**/*.py`

Once they are all formatted, BLACK linter will be removed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81285
Approved by: https://github.com/suo
2022-07-13 07:59:22 +00:00

90 lines
2.6 KiB
Python

# Ideally, there would be a way in Bazel to parse version.txt
# and use the version numbers from there as substitutions for
# an expand_template action. Since there isn't, this silly script exists.
import argparse
import os
from typing import cast, Dict, Tuple
Version = Tuple[int, int, int]


def parse_version(version: str) -> Version:
    """
    Parses a version string into (major, minor, patch) version numbers.

    Args:
      version: Full version number string, possibly including revision / commit hash.

    Returns:
      An int 3-tuple of (major, minor, patch) version numbers.

    Raises:
      ValueError: If the leading numeric part is empty or contains an empty
        component (e.g. the string does not start with a digit).
    """
    # Locate the first character that cannot belong to the dotted numeric
    # prefix (i.e. toss any revision / hash parts). If there is none, the
    # whole string is the version number.
    end = next(
        (i for i, c in enumerate(version) if not (c.isdigit() or c == ".")),
        len(version),
    )

    # A suffix introduced by a dot (e.g. "1.13.0.dev20220713") leaves a
    # trailing "." on the prefix; strip it so it does not yield an empty
    # component that int() would reject.
    version_number_str = version[:end].rstrip(".")

    return cast(Version, tuple(int(n) for n in version_number_str.split(".")))
def apply_replacements(replacements: Dict[str, str], text: str) -> str:
    """
    Substitutes every occurrence of each mapping key in the text with its value.

    Replacements are applied one after another in the mapping's iteration
    order, so a later replacement sees the output of the earlier ones.

    Args:
      replacements (dict): Mapping of str -> str replacements.
      text (str): Text in which to make replacements.

    Returns:
      Text with replacements applied, if any.
    """
    result = text
    for placeholder in replacements:
        result = result.replace(placeholder, replacements[placeholder])
    return result
def main(args: argparse.Namespace) -> None:
    """
    Expands the version template into the output file.

    Reads the version string from ``args.version_path``, parses it into
    major / minor / patch numbers, and writes ``args.template_path`` to
    ``args.output_path`` with the ``@TORCH_VERSION_*@`` placeholders replaced.

    Args:
      args: Parsed command-line arguments providing ``version_path``,
        ``template_path`` and ``output_path``.
    """
    with open(args.version_path) as f:
        version = f.read().strip()
    major, minor, patch = parse_version(version)

    replacements = {
        "@TORCH_VERSION_MAJOR@": str(major),
        "@TORCH_VERSION_MINOR@": str(minor),
        "@TORCH_VERSION_PATCH@": str(patch),
    }

    # Create the output dir if it doesn't exist. A bare filename has an empty
    # dirname, and os.makedirs("") raises FileNotFoundError, so only create
    # the directory when the output path actually names one.
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Stream the template line by line instead of loading it all at once.
    with open(args.template_path) as template, open(args.output_path, "w") as output:
        for line in template:
            output.write(apply_replacements(replacements, line))
# Script entry point: parse the three required path options and run main().
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate version.h from version.h.in template",
    )
    # Every option is a mandatory path, so register them from one table.
    for flag, help_text in (
        ("--template-path", "Path to the template (i.e. version.h.in)"),
        ("--version-path", "Path to the file specifying the version"),
        ("--output-path", "Output path for expanded template (i.e. version.h)"),
    ):
        parser.add_argument(flag, required=True, help=help_text)
    args = parser.parse_args()
    main(args)