mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105428 Approved by: https://github.com/albanD, https://github.com/soulitzer, https://github.com/malfet
257 lines
7.5 KiB
Python
257 lines
7.5 KiB
Python
#!/usr/bin/env python3
|
|
|
|
""" Generates PyTorch ONNX Export Diagnostic rules for C++, Python and documentations.
|
|
The rules are defined in torch/onnx/_internal/diagnostics/rules.yaml.
|
|
|
|
Usage:
|
|
|
|
python -m tools.onnx.gen_diagnostics \
|
|
torch/onnx/_internal/diagnostics/rules.yaml \
|
|
torch/onnx/_internal/diagnostics \
|
|
torch/csrc/onnx/diagnostics/generated \
|
|
torch/docs/source
|
|
"""
|
|
|
|
import argparse
|
|
import os
|
|
import string
|
|
import subprocess
|
|
import textwrap
|
|
from typing import Any, Mapping, Sequence
|
|
|
|
import yaml
|
|
|
|
from torchgen import utils as torchgen_utils
|
|
from torchgen.yaml_utils import YamlLoader
|
|
|
|
_RULES_GENERATED_COMMENT = """\
|
|
GENERATED CODE - DO NOT EDIT DIRECTLY
|
|
This file is generated by gen_diagnostics.py.
|
|
See tools/onnx/gen_diagnostics.py for more information.
|
|
|
|
Diagnostic rules for PyTorch ONNX export.
|
|
"""
|
|
|
|
_PY_RULE_CLASS_COMMENT = """\
|
|
GENERATED CODE - DO NOT EDIT DIRECTLY
|
|
The purpose of generating a class for each rule is to override the `format_message`
|
|
method to provide more details in the signature about the format arguments.
|
|
"""
|
|
|
|
_PY_RULE_CLASS_TEMPLATE = """\
|
|
class _{pascal_case_name}(infra.Rule):
|
|
\"\"\"{short_description}\"\"\"
|
|
def format_message( # type: ignore[override]
|
|
self,
|
|
{message_arguments}
|
|
) -> str:
|
|
\"\"\"Returns the formatted default message of this Rule.
|
|
|
|
Message template: {message_template}
|
|
\"\"\"
|
|
return self.message_default_template.format({message_arguments_assigned})
|
|
|
|
def format( # type: ignore[override]
|
|
self,
|
|
level: infra.Level,
|
|
{message_arguments}
|
|
) -> Tuple[infra.Rule, infra.Level, str]:
|
|
\"\"\"Returns a tuple of (Rule, Level, message) for this Rule.
|
|
|
|
Message template: {message_template}
|
|
\"\"\"
|
|
return self, level, self.format_message({message_arguments_assigned})
|
|
|
|
"""
|
|
|
|
_PY_RULE_COLLECTION_FIELD_TEMPLATE = """\
|
|
{snake_case_name}: _{pascal_case_name} = dataclasses.field(
|
|
default=_{pascal_case_name}.from_sarif(**{sarif_dict}),
|
|
init=False,
|
|
)
|
|
\"\"\"{short_description}\"\"\"
|
|
"""
|
|
|
|
_CPP_RULE_TEMPLATE = """\
|
|
/**
|
|
* @brief {short_description}
|
|
*/
|
|
{name},
|
|
"""
|
|
|
|
# A single rule as loaded from the rules YAML file. The formatters in this
# module read its "name", "short_description" and "message_strings" keys.
_RuleType = Mapping[str, Any]
|
|
|
|
|
|
def _kebab_case_to_snake_case(name: str) -> str:
    """Convert a kebab-case identifier to snake_case.

    Every hyphen in ``name`` becomes an underscore; all other characters are
    left untouched.
    """
    hyphen_to_underscore = str.maketrans("-", "_")
    return name.translate(hyphen_to_underscore)
|
|
|
|
|
|
def _kebab_case_to_pascal_case(name: str) -> str:
    """Convert a kebab-case identifier to PascalCase.

    ``name`` is split on hyphens and each piece goes through
    ``str.capitalize`` (first character upper-cased, the remainder
    lower-cased) before the pieces are concatenated without a separator.
    """
    pieces = name.split("-")
    return "".join(map(str.capitalize, pieces))
|
|
|
|
|
|
def _format_rule_for_python_class(rule: _RuleType) -> str:
    """Render the Python source of the generated class for one rule.

    Args:
        rule: Rule mapping. This function reads ``rule["name"]``,
            ``rule["short_description"]["text"]`` and
            ``rule["message_strings"]["default"]["text"]``.

    Returns:
        ``_PY_RULE_CLASS_TEMPLATE`` filled in for ``rule``. The replacement
        fields found in the rule's default message template become the
        keyword parameters of the generated ``format_message`` and ``format``
        methods.

    Raises:
        AssertionError: If the message template contains an empty (``{}``) or
            numeric (``{0}``) replacement field; only keyword fields can be
            turned into parameter names.
    """
    pascal_case_name = _kebab_case_to_pascal_case(rule["name"])
    short_description = rule["short_description"]["text"]
    message_template = rule["message_strings"]["default"]["text"]

    # A placeholder may legitimately appear several times in the template, but
    # it must be emitted once as a parameter; dict.fromkeys de-duplicates while
    # keeping first-seen order.
    field_names = list(
        dict.fromkeys(
            field_name
            for _, field_name, _, _ in string.Formatter().parse(message_template)
            if field_name is not None
        )
    )
    for field_name in field_names:
        # The whole message is parenthesized so both halves belong to the
        # assert and the template is actually interpolated.
        assert isinstance(field_name, str), (
            f"Unexpected field type {type(field_name)} from {field_name}. "
            f"Field name must be string.\nFull message template: {message_template}"
        )
        # An empty name comes from auto-numbered "{}" and is positional too.
        assert field_name and not field_name.isnumeric(), (
            f"Unexpected empty or numeric field name {field_name!r}. "
            "Only keyword name formatting is supported.\n"
            f"Full message template: {message_template}"
        )

    message_arguments = ", ".join(field_names)
    message_arguments_assigned = ", ".join(
        f"{field_name}={field_name}" for field_name in field_names
    )
    return _PY_RULE_CLASS_TEMPLATE.format(
        pascal_case_name=pascal_case_name,
        short_description=short_description,
        message_template=repr(message_template),
        message_arguments=message_arguments,
        message_arguments_assigned=message_arguments_assigned,
    )
|
|
|
|
|
|
def _format_rule_for_python_field(rule: _RuleType) -> str:
    """Render the dataclass field declaration for one rule.

    Fills ``_PY_RULE_COLLECTION_FIELD_TEMPLATE`` with the snake_case and
    PascalCase forms of ``rule["name"]``, the rule mapping itself (as
    ``sarif_dict``) and ``rule["short_description"]["text"]``.
    """
    rule_name = rule["name"]
    return _PY_RULE_COLLECTION_FIELD_TEMPLATE.format(
        snake_case_name=_kebab_case_to_snake_case(rule_name),
        pascal_case_name=_kebab_case_to_pascal_case(rule_name),
        sarif_dict=rule,
        short_description=rule["short_description"]["text"],
    )
|
|
|
|
|
|
def _format_rule_for_cpp(rule: _RuleType) -> str:
    """Render the C++ entry for one rule.

    The C++ identifier is ``k`` followed by the PascalCase form of
    ``rule["name"]``; it is substituted into ``_CPP_RULE_TEMPLATE`` together
    with ``rule["short_description"]["text"]``.
    """
    cpp_name = "k" + _kebab_case_to_pascal_case(rule["name"])
    description = rule["short_description"]["text"]
    return _CPP_RULE_TEMPLATE.format(name=cpp_name, short_description=description)
|
|
|
|
|
|
def gen_diagnostics_python(
    rules: Sequence[_RuleType], out_py_dir: str, template_dir: str
) -> None:
    """Write ``_rules.py`` into ``out_py_dir`` and lint it.

    Args:
        rules: Rule mappings to render.
        out_py_dir: Directory that receives ``_rules.py``.
        template_dir: Directory that contains the ``rules.py.in`` template.
    """
    class_sources = list(map(_format_rule_for_python_class, rules))
    field_sources = list(map(_format_rule_for_python_field, rules))

    def _template_env():
        # Evaluated by the file manager when it renders the template.
        joined_fields = "\n".join(field_sources)
        return {
            "generated_comment": _RULES_GENERATED_COMMENT,
            "generated_rule_class_comment": _PY_RULE_CLASS_COMMENT,
            "rule_classes": "\n".join(class_sources),
            "rules": textwrap.indent(joined_fields, " " * 4),
        }

    output_name = "_rules.py"
    file_manager = torchgen_utils.FileManager(
        install_dir=out_py_dir, template_dir=template_dir, dry_run=False
    )
    file_manager.write_with_template(output_name, "rules.py.in", _template_env)
    _lint_file(os.path.join(out_py_dir, output_name))
|
|
|
|
|
|
def gen_diagnostics_cpp(
    rules: Sequence[_RuleType], out_cpp_dir: str, template_dir: str
) -> None:
    """Write ``rules.h`` into ``out_cpp_dir`` and lint it.

    Args:
        rules: Rule mappings to render.
        out_cpp_dir: Directory that receives ``rules.h``.
        template_dir: Directory that contains the ``rules.h.in`` template.
    """
    cpp_entries = list(map(_format_rule_for_cpp, rules))
    quoted_names = [
        '"' + _kebab_case_to_snake_case(rule["name"]) + '",' for rule in rules
    ]

    def _template_env():
        # Evaluated by the file manager when it renders the template.
        banner = textwrap.indent(
            _RULES_GENERATED_COMMENT,
            " * ",
            predicate=lambda _line: True,  # Don't ignore empty line
        )
        return {
            "generated_comment": banner,
            "rules": textwrap.indent("\n".join(cpp_entries), " " * 2),
            "py_rule_names": textwrap.indent("\n".join(quoted_names), " " * 4),
        }

    header_name = "rules.h"
    file_manager = torchgen_utils.FileManager(
        install_dir=out_cpp_dir, template_dir=template_dir, dry_run=False
    )
    file_manager.write_with_template(header_name, "rules.h.in", _template_env)
    _lint_file(os.path.join(out_cpp_dir, header_name))
|
|
|
|
|
|
def gen_diagnostics_docs(
    rules: Sequence[_RuleType], out_docs_dir: str, template_dir: str
) -> None:
    """Placeholder for documentation generation; currently does nothing.

    All three arguments are accepted and ignored.
    """
    # TODO: Add doc generation in a follow-up PR.
    return None
|
|
|
|
|
|
def _lint_file(file_path: str) -> None:
    """Run ``lintrunner -a`` on ``file_path`` and block until it exits.

    The process's exit status is not inspected.
    """
    lint_command = ["lintrunner", "-a", file_path]
    subprocess.Popen(lint_command).wait()
|
|
|
|
|
|
def gen_diagnostics(
    rules_path: str,
    out_py_dir: str,
    out_cpp_dir: str,
    out_docs_dir: str,
) -> None:
    """Load the rules file and generate every diagnostics artifact.

    Args:
        rules_path: Path to the rules YAML file.
        out_py_dir: Output directory for the generated Python file.
        out_cpp_dir: Output directory for the generated C++ header.
        out_docs_dir: Output directory handed to the docs generator.
    """
    # Read as UTF-8 explicitly so the result does not depend on the platform's
    # locale default encoding.
    with open(rules_path, encoding="utf-8") as f:
        # NOTE(review): safety of this load depends on torchgen's YamlLoader,
        # which is presumably a safe loader — confirm before feeding it
        # untrusted input.
        rules = yaml.load(f, Loader=YamlLoader)

    # Templates live in the "templates" directory next to this script.
    template_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "templates")

    gen_diagnostics_python(
        rules,
        out_py_dir,
        template_dir,
    )

    gen_diagnostics_cpp(
        rules,
        out_cpp_dir,
        template_dir,
    )

    gen_diagnostics_docs(rules, out_docs_dir, template_dir)
|
|
|
|
|
|
def main() -> None:
    """Parse the four positional command-line paths and run the generator."""
    parser = argparse.ArgumentParser(description="Generate ONNX diagnostics files")

    # (dest, metavar, help) for each positional argument, in command-line order.
    positional_specs = (
        ("rules_path", "RULES", "path to rules.yaml"),
        ("out_py_dir", "OUT_PY", "path to output directory for Python"),
        ("out_cpp_dir", "OUT_CPP", "path to output directory for C++"),
        ("out_docs_dir", "OUT_DOCS", "path to output directory for docs"),
    )
    for dest, metavar, help_text in positional_specs:
        parser.add_argument(dest, metavar=metavar, help=help_text)

    args = parser.parse_args()
    gen_diagnostics(
        rules_path=args.rules_path,
        out_py_dir=args.out_py_dir,
        out_cpp_dir=args.out_cpp_dir,
        out_docs_dir=args.out_docs_dir,
    )
|
|
|
|
|
|
# Run the generator only when this file is executed as a script (including
# via `python -m`), not when it is imported.
if __name__ == "__main__":
    main()
|