pytorch/docs/source/scripts/build_opsets.py
Xuehai Pan 35ea5c6b22 [3/N][Easy] fix typo for usort config in pyproject.toml (kown -> known): sort torchgen (#127124)
The `usort` config in `pyproject.toml` has no effect due to a typo. Fixing the typo makes `usort` do more and generates the changes in this PR. Except for `pyproject.toml`, all changes are generated by `lintrunner -a --take UFMT --all-files`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127124
Approved by: https://github.com/Skylion007
ghstack dependencies: #127122, #127123
2024-05-25 19:20:03 +00:00

76 lines
2.0 KiB
Python

import csv
import os
from collections import OrderedDict
from pathlib import Path

import torch
import torch._prims as prims
from torchgen.gen import parse_native_yaml
ROOT = Path(__file__).absolute().parent.parent.parent.parent
NATIVE_FUNCTION_YAML_PATH = ROOT / Path("aten/src/ATen/native/native_functions.yaml")
TAGS_YAML_PATH = ROOT / Path("aten/src/ATen/native/tags.yaml")
BUILD_DIR = "build/ir"
ATEN_OPS_CSV_FILE = "aten_ops.csv"
PRIMS_OPS_CSV_FILE = "prims_ops.csv"
def get_aten():
    """Collect the ATen operators tagged ``core`` from the native-functions YAML.

    Returns:
        A list of ``(operator_name, schema)`` string pairs sorted by operator
        name. Each name is prefixed with ``aten.`` and every ``*`` in the
        schema is backslash-escaped (presumably so the rendered docs do not
        treat it as markup -- confirm against the consumer of the CSV).
    """
    parsed = parse_native_yaml(NATIVE_FUNCTION_YAML_PATH, TAGS_YAML_PATH)

    # Keyed by the operator's name; if two entries share a name the later one
    # replaces the earlier value.
    core_ops = OrderedDict(
        (str(fn.func.name), fn)
        for fn in parsed.native_functions
        if "core" in fn.tags
    )

    return [
        (f"aten.{name}", str(fn.func).replace("*", r"\*"))
        for name, fn in sorted(core_ops.items())
    ]
def get_prims():
    """Collect the operators exported by ``torch._prims``.

    Walks ``prims.__all__`` and keeps only the attributes that are
    ``torch._ops.OpOverload`` instances; every other exported name is skipped.

    Returns:
        A list of ``(operator_name, schema)`` string pairs in ``__all__``
        order. The name is ``str(overload)`` with ``".default"`` removed, and
        the schema is read from the overload packet with every ``*``
        backslash-escaped.
    """
    pairs = []
    for exported_name in prims.__all__:
        overload = getattr(prims, exported_name, None)
        if isinstance(overload, torch._ops.OpOverload):
            display_name = str(overload).replace(".default", "")
            escaped_schema = overload.overloadpacket.schema.replace("*", r"\*")
            pairs.append((display_name, escaped_schema))
    return pairs
def _write_ops_csv(path, op_schema_pairs):
    """Write ``(operator_name, schema)`` pairs to ``path`` as a two-column CSV.

    The header line is ``Operator,Schema``. Every data field is quoted, and
    the operator name is wrapped in double backticks. Quoting is delegated to
    the ``csv`` module so that a schema containing a double quote still
    yields a well-formed row.
    """
    # newline="" lets the csv writer control line endings, as the csv module
    # documentation recommends; encoding is pinned so output does not depend
    # on the locale.
    with open(path, "w", newline="", encoding="utf-8") as f:
        f.write("Operator,Schema\n")
        writer = csv.writer(f, quoting=csv.QUOTE_ALL, lineterminator="\n")
        writer.writerows(
            (f"``{name}``", schema) for name, schema in op_schema_pairs
        )


def main():
    """Generate the ATen and prims operator CSV files under ``BUILD_DIR``.

    Creates ``BUILD_DIR`` if it does not exist, then writes one CSV for the
    pairs returned by ``get_aten`` and one for those returned by
    ``get_prims``.
    """
    aten_ops_list = get_aten()
    prims_ops_list = get_prims()

    os.makedirs(BUILD_DIR, exist_ok=True)

    _write_ops_csv(os.path.join(BUILD_DIR, ATEN_OPS_CSV_FILE), aten_ops_list)
    _write_ops_csv(os.path.join(BUILD_DIR, PRIMS_OPS_CSV_FILE), prims_ops_list)
# Run the generator only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()