pytorch/functorch/codegen/gen.py

54 lines
1.8 KiB
Python

import argparse
import os
from typing import Optional

from torchgen.gen import FileManager, parse_native_yaml
from torchgen.gen import get_torchgen_root

from gen_vmap_plumbing import gen_all_vmap_plumbing
"""
INSTRUCTIONS
Step 1: You must have a PyTorch installation (in develop mode, i.e.
installed with python setup.py develop) in your current environment.
This script relies on the `torchgen` module from the PyTorch develop installation.
Step 2: Run this script from the functorch root directory (the default output
directory, functorch/csrc, is relative to the current working directory):
python codegen/gen.py
"""
def main() -> None:
    """Parse command-line options and run the functorch code generator.

    Options:
        -s / --source-path: ATen source directory; when omitted (None),
            generate_code infers it from the installed torchgen package.
        -d / --install_dir: output directory for the generated files
            (default: functorch/csrc).
    """
    arg_parser = argparse.ArgumentParser(description='functorch codegen')
    arg_parser.add_argument('-s', '--source-path', default=None,
                            help='path to source directory for ATen')
    arg_parser.add_argument('-d', '--install_dir', default='functorch/csrc',
                            help='output directory')
    args = arg_parser.parse_args()
    generate_code(install_dir=args.install_dir, source_path=args.source_path)
def generate_code(install_dir: str = 'functorch/csrc',
                  source_path: Optional[str] = None) -> None:
    """Generate ``VmapGeneratedPlumbing.h`` from ATen's native function YAML.

    Args:
        install_dir: Directory that ``VmapGeneratedPlumbing.h`` is written to.
        source_path: ATen source directory containing ``native/`` (with
            ``native_functions.yaml`` and ``tags.yaml``) and ``templates/``.
            When ``None``, ``<torchgen root>/packaged/ATen`` is used.
    """
    if source_path is None:
        # Infer the ATen source path from the installed torchgen package.
        source_path = os.path.join(get_torchgen_root(), 'packaged', 'ATen')

    native_yaml_path = os.path.join(source_path, 'native', 'native_functions.yaml')
    tags_path = os.path.join(source_path, 'native', 'tags.yaml')
    template_dir = os.path.join(source_path, 'templates')

    # Only the parsed native functions are needed here; the backend indices
    # that parse_native_yaml also produces are not used by this generator.
    native_functions = parse_native_yaml(native_yaml_path, tags_path).native_functions

    fm = FileManager(install_dir=install_dir, template_dir=template_dir, dry_run=False)
    # FileManager.write is handed a callable that produces the file contents.
    fm.write('VmapGeneratedPlumbing.h', lambda: gen_all_vmap_plumbing(native_functions))
if __name__ == "__main__":
    # Only run the generator when executed as a script, not when imported.
    main()