import pathlib
import argparse
import os
import yaml
from typing import List, Dict, Union, Tuple, Sequence
from tools.codegen.gen import FileManager, get_grouped_native_functions, parse_native_yaml
from tools.codegen.model import (ExternalBackendFunction, ExternalBackendFunctionsGroup,
                                 NativeFunction, NativeFunctionsGroup, OperatorName,
                                 ExternalBackendMetadata, assert_never)
from tools.codegen.selective_build.selector import SelectiveBuilder
from tools.codegen.utils import Target, concatMap
import tools.codegen.dest as dest

try:
    # use faster C loader if available
    from yaml import CSafeLoader as Loader
except ImportError:
    from yaml import SafeLoader as Loader  # type: ignore


def parse_backend_yaml(
        backend_yaml_path: str,
        grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]]
) -> Tuple[str, List[Union[ExternalBackendFunction, ExternalBackendFunctionsGroup]]]:
    """Parse an external backend's yaml file and attach its metadata to the native functions.

    The yaml must be a mapping with the required keys ``backend`` and ``cpp_namespace`` and the
    optional keys ``supported`` and ``autograd`` (lists of operator names; an empty/absent value
    means "no ops"). Any other key is rejected.

    Args:
        backend_yaml_path: path to the backend yaml file.
        grouped_native_functions: all native functions, standalone or grouped (``run`` passes
            the result of ``get_grouped_native_functions``).

    Returns:
        ``(cpp_namespace, external_backend_functions)``, where the list has exactly one entry
        per element of ``grouped_native_functions``, in the same order. A standalone function
        the backend did not list is wrapped with ``None`` metadata.

    Raises:
        AssertionError: if the yaml is not a mapping, is missing a required key, has a
            ``supported``/``autograd`` value that is not a list, contains unexpected keys,
            or names an operator that is not among ``grouped_native_functions``.
    """
    with open(backend_yaml_path, 'r') as f:
        yaml_values = yaml.load(f, Loader=Loader)
    assert isinstance(yaml_values, dict)

    valid_keys = ['backend', 'cpp_namespace', 'supported', 'autograd']

    backend = yaml_values.pop('backend', None)
    assert backend is not None, 'You must provide a value for "backend"'

    cpp_namespace = yaml_values.pop('cpp_namespace', None)
    assert cpp_namespace is not None, 'You must provide a value for "cpp_namespace"'

    supported = yaml_values.pop('supported', [])
    if supported is None:
        supported = []  # Allow an empty list of supported ops
    assert isinstance(supported, list), \
        f'expected "supported" to be a list, but got: {supported} (of type {type(supported)})'

    supported_autograd = yaml_values.pop('autograd', [])
    if supported_autograd is None:
        supported_autograd = []  # Allow an empty list of autograd ops
    # Validate the autograd list itself (not `supported`).
    assert isinstance(supported_autograd, list), \
        f'expected "autograd" to be a list, but got: {supported_autograd} (of type {type(supported_autograd)})'

    # Every recognized key has been popped above, so anything still present is unexpected.
    assert len(yaml_values.keys()) == 0, \
        f'{backend_yaml_path} contains unexpected keys: {", ".join(yaml_values.keys())}. ' \
        f'Only the following keys are supported: {", ".join(valid_keys)}'

    metadata: Dict[OperatorName, ExternalBackendMetadata] = {}
    for op in supported:
        op_name = OperatorName.parse(op)
        m = ExternalBackendMetadata(op_name, backend, is_autograd=False)
        metadata[m.operator] = m
    # NOTE(review): this loop runs second, so an op listed under both "supported" and
    # "autograd" would have its entry overwritten by the autograd one (assuming
    # `m.operator` is the parsed op name) — confirm whether that should be an error.
    for op in supported_autograd:
        op_name = OperatorName.parse(op)
        m = ExternalBackendMetadata(op_name, backend, is_autograd=True)
        metadata[m.operator] = m

    # Flatten groups so every individual NativeFunction can be looked up by name; this map is
    # only used to validate the operator names read from the yaml.
    native_functions_map: Dict[OperatorName, NativeFunction] = {
        f.func.name: f
        for f in concatMap(lambda f: [f] if isinstance(f, NativeFunction) else list(f.functions()),
                           grouped_native_functions)
    }

    def native_to_external(
            g: Union[NativeFunction, NativeFunctionsGroup]
    ) -> Union[ExternalBackendFunction, ExternalBackendFunctionsGroup]:
        """Wrap a native function (or group) with the backend metadata collected above."""
        if isinstance(g, NativeFunction):
            f = g
            m = metadata.get(f.func.name, None)
            return ExternalBackendFunction(f, m)
        elif isinstance(g, NativeFunctionsGroup):
            return ExternalBackendFunctionsGroup.from_function_group(g, metadata)
        else:
            assert_never(g)

    # Reject typos / unknown ops in the backend yaml.
    for op_name in metadata.keys():
        assert op_name in native_functions_map, f"Found an invalid operator name: {op_name}"

    return cpp_namespace, [native_to_external(g) for g in grouped_native_functions]


def _str2bool(value: str) -> bool:
    """Parse a command-line boolean.

    Unlike ``bool(value)``, which is True for every non-empty string (including "False"),
    this maps the usual true/false spellings correctly and rejects anything else.
    """
    normalized = value.strip().lower()
    if normalized in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if normalized in ('', '0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, but got: {value!r}')


def main() -> None:
    """Command-line entry point: parse arguments and forward them to ``run``."""
    parser = argparse.ArgumentParser(description='Generate backend stub files')
    parser.add_argument(
        '-s',
        '--source_yaml',
        required=True,
        help='path to source yaml file containing operator external definitions')
    parser.add_argument(
        '-o', '--output_dir', required=True, help='output directory')
    # `--dry_run`, `--dry_run True` and `--dry_run False` all behave as written;
    # a plain `type=bool` would treat the string "False" as True.
    parser.add_argument(
        '--dry_run', type=_str2bool, nargs='?', const=True, default=False,
        help='run in dry-run mode (forwarded to FileManager as `dry_run`)')
    options = parser.parse_args()

    run(options.source_yaml, options.output_dir, options.dry_run)


def run(source_yaml: str, output_dir: str, dry_run: bool) -> None:
    """Generate aten_xla_type.h, aten_xla_type_default.h and aten_xla_type_default.cpp.

    The files are rendered from the templates in ``aten/src/ATen/templates`` into
    ``output_dir``, using the ops declared in ``source_yaml`` together with
    ``aten/src/ATen/native/native_functions.yaml``.

    Args:
        source_yaml: path to the external backend yaml (see ``parse_backend_yaml``).
        output_dir: directory the generated files are written to.
        dry_run: forwarded unchanged to ``FileManager``.
    """
    # Assumes that this file lives at PYTORCH_ROOT/tools/codegen/gen_backend_stubs.py
    pytorch_root = pathlib.Path(__file__).parent.parent.parent.absolute()
    template_dir = os.path.join(pytorch_root, "aten/src/ATen/templates")

    def make_file_manager(install_dir: str) -> FileManager:
        return FileManager(install_dir=install_dir, template_dir=template_dir, dry_run=dry_run)

    fm = make_file_manager(output_dir)

    native_yaml_path = os.path.join(pytorch_root, 'aten/src/ATen/native/native_functions.yaml')
    grouped_native_functions = get_grouped_native_functions(native_yaml_path)
    cpp_namespace, external_backend_functions = parse_backend_yaml(source_yaml, grouped_native_functions)

    # NOTE(review): `native_functions` and `selector` are not used below.
    native_functions = parse_native_yaml(native_yaml_path)

    selector = SelectiveBuilder.get_nop_selector()

    generated_comment = 'Autogenerated file by gen_backend_stubs.py. Do not edit directly!'
    fm.write('aten_xla_type.h', lambda: {
        'generated_comment': generated_comment,
        'cpp_namespace': cpp_namespace,
        'dispatch_xla_declarations': list(concatMap(dest.compute_native_function_declaration, external_backend_functions)),
    })

    fm.write('aten_xla_type_default.h', lambda: {
        'generated_comment': generated_comment,
        'cpp_namespace': cpp_namespace,
        'dispatch_aten_fallback_declarations': list(concatMap(
            dest.GenExternalAtenFallback(Target.NAMESPACED_DECLARATION), external_backend_functions
        )),
    })

    fm.write('aten_xla_type_default.cpp', lambda: {
        'generated_comment': generated_comment,
        'cpp_namespace': cpp_namespace,
        # TODO: after cpu fallbacks are moved to a boxed kernel,
        # merge registrations / definitions into RegisterDispatchKey
        'dispatch_aten_fallback_definitions': list(concatMap(
            dest.GenExternalAtenFallback(Target.NAMESPACED_DEFINITION), external_backend_functions
        )),
        # Non-autograd and autograd kernels are registered into separate template slots.
        'dispatch_registrations': list(concatMap(
            dest.GenExternalAtenFallback(Target.REGISTRATION),
            [e for e in external_backend_functions if not e.is_autograd_kernel]
        )),
        'dispatch_autograd_registrations': list(concatMap(
            dest.GenExternalAtenFallback(Target.REGISTRATION),
            [e for e in external_backend_functions if e.is_autograd_kernel]
        )),
    })


if __name__ == '__main__':
    main()