mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
ghstack-source-id: 18939fb645
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56597
144 lines
6.3 KiB
Python
144 lines
6.3 KiB
Python
import pathlib
|
|
import argparse
|
|
import os
|
|
import yaml
|
|
from typing import List, Dict, Union, Tuple, Sequence
|
|
from tools.codegen.gen import FileManager, get_grouped_native_functions, parse_native_yaml
|
|
from tools.codegen.model import (ExternalBackendFunction, ExternalBackendFunctionsGroup,
|
|
NativeFunction, NativeFunctionsGroup, OperatorName,
|
|
ExternalBackendMetadata, assert_never)
|
|
from tools.codegen.selective_build.selector import SelectiveBuilder
|
|
from tools.codegen.utils import Target, concatMap
|
|
import tools.codegen.dest as dest
|
|
|
|
try:
|
|
# use faster C loader if available
|
|
from yaml import CSafeLoader as Loader
|
|
except ImportError:
|
|
from yaml import SafeLoader as Loader # type: ignore
|
|
|
|
|
|
def parse_backend_yaml(
        backend_yaml_path: str,
        grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]]
) -> Tuple[str, List[Union[ExternalBackendFunction, ExternalBackendFunctionsGroup]]]:
    """Parse an external backend's yaml file and pair it with the native functions.

    The yaml file must be a mapping with these keys (no others are accepted):
      - backend (required): passed to every ExternalBackendMetadata entry.
      - cpp_namespace (required): returned unchanged to the caller.
      - supported (optional): list of operator names, recorded with is_autograd=False.
      - autograd (optional): list of operator names, recorded with is_autograd=True.
    A key that is present but empty (parsed as None) is treated as an empty list.

    Args:
        backend_yaml_path: path of the backend yaml file to read.
        grouped_native_functions: the native functions / function groups to wrap.

    Returns:
        A tuple (cpp_namespace, external_functions), where external_functions holds
        one ExternalBackendFunction or ExternalBackendFunctionsGroup per entry of
        grouped_native_functions, in the same order.

    Raises:
        AssertionError: if the yaml content is not a mapping, a required key is
            missing, "supported"/"autograd" is not a list, an unknown key is
            present, or a listed operator is not one of the native functions.
    """
    with open(backend_yaml_path, 'r') as f:
        yaml_values = yaml.load(f, Loader=Loader)
    assert isinstance(yaml_values, dict)

    valid_keys = ['backend', 'cpp_namespace', 'supported', 'autograd']

    # Every recognised key is popped, so whatever is left afterwards is unexpected.
    backend = yaml_values.pop('backend', None)
    assert backend is not None, 'You must provide a value for "backend"'
    cpp_namespace = yaml_values.pop('cpp_namespace', None)
    assert cpp_namespace is not None, 'You must provide a value for "cpp_namespace"'

    supported = yaml_values.pop('supported', [])
    if supported is None:
        supported = []  # Allow an empty list of supported ops
    assert isinstance(supported, list), f'expected "supported" to be a list, but got: {supported} (of type {type(supported)})'

    supported_autograd = yaml_values.pop('autograd', [])
    if supported_autograd is None:
        supported_autograd = []  # Allow an empty list of autograd ops, same as "supported"
    # Check the autograd list itself (the previous code re-checked `supported` here).
    assert isinstance(supported_autograd, list), f'expected "autograd" to be a list, but got: {supported_autograd}'

    # str() each key so that a non-string yaml key cannot break the error message.
    unexpected_keys = ', '.join(str(key) for key in yaml_values)
    assert not yaml_values, (
        f'{backend_yaml_path} contains unexpected keys: {unexpected_keys}. '
        f'Only the following keys are supported: {", ".join(valid_keys)}'
    )

    # "autograd" is processed second, so if an operator is listed under both keys
    # its autograd entry is the one that is kept.
    metadata: Dict[OperatorName, ExternalBackendMetadata] = {}
    for op_list, is_autograd in ((supported, False), (supported_autograd, True)):
        for op in op_list:
            m = ExternalBackendMetadata(OperatorName.parse(op), backend, is_autograd=is_autograd)
            metadata[m.operator] = m

    # Flatten function groups so that every individual operator name can be looked up.
    native_functions_map: Dict[OperatorName, NativeFunction] = {
        f.func.name: f
        for f in concatMap(lambda f: [f] if isinstance(f, NativeFunction) else list(f.functions()), grouped_native_functions)
    }

    def native_to_external(
            g: Union[NativeFunction, NativeFunctionsGroup]
    ) -> Union[ExternalBackendFunction, ExternalBackendFunctionsGroup]:
        if isinstance(g, NativeFunction):
            # Operators that the yaml file does not mention get metadata None.
            return ExternalBackendFunction(g, metadata.get(g.func.name, None))
        elif isinstance(g, NativeFunctionsGroup):
            return ExternalBackendFunctionsGroup.from_function_group(g, metadata)
        else:
            assert_never(g)

    # Reject operator names from the yaml file that are not native functions.
    for op_name in metadata:
        assert op_name in native_functions_map, f"Found an invalid operator name: {op_name}"
    return cpp_namespace, [native_to_external(g) for g in grouped_native_functions]
|
|
|
|
def main() -> None:
    """Command-line entry point: parse the arguments and forward them to run().

    Options:
        -s / --source_yaml: path to the backend yaml file.
        -o / --output_dir: directory handed to run() as the output directory.
        --dry_run: boolean forwarded to run(). Accepts true/false style values
            (case-insensitive); a bare ``--dry_run`` means true. Defaults to False.

    Any other --dry_run value makes argparse report an error and exit with status 2.
    """
    def parse_bool(value: str) -> bool:
        # argparse's type=bool is a trap: bool('False') is True because any
        # non-empty string is truthy. Parse the text explicitly instead.
        normalized = value.strip().lower()
        if normalized in ('true', 't', 'yes', 'y', '1', 'on'):
            return True
        if normalized in ('false', 'f', 'no', 'n', '0', 'off', ''):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, but got: {value!r}')

    parser = argparse.ArgumentParser(description='Generate backend stub files')
    parser.add_argument(
        '-s',
        '--source_yaml',
        help='path to source yaml file containing operator external definitions')
    parser.add_argument(
        '-o', '--output_dir', help='output directory')
    parser.add_argument(
        '--dry_run', type=parse_bool, nargs='?', const=True, default=False,
        help='run in dry-run mode (accepts true/false; a bare --dry_run means true)')
    options = parser.parse_args()

    run(options.source_yaml, options.output_dir, options.dry_run)
|
|
|
|
def run(source_yaml: str, output_dir: str, dry_run: bool) -> None:
    """Write the backend stub files for the backend described by source_yaml.

    Three files are written through a FileManager rooted at output_dir:
    aten_xla_type.h, aten_xla_type_default.h and aten_xla_type_default.cpp.

    Args:
        source_yaml: path to the backend yaml file, parsed by parse_backend_yaml.
        output_dir: install directory given to the FileManager.
        dry_run: forwarded unchanged to the FileManager.
    """
    # Assumes that this file lives at PYTORCH_ROOT/tools/codegen/gen_backend_stubs.py
    repo_root = pathlib.Path(__file__).parent.parent.parent.absolute()
    templates_path = os.path.join(repo_root, "aten/src/ATen/templates")
    file_manager = FileManager(install_dir=output_dir, template_dir=templates_path, dry_run=dry_run)

    native_yaml_path = os.path.join(repo_root, 'aten/src/ATen/native/native_functions.yaml')
    grouped = get_grouped_native_functions(native_yaml_path)
    cpp_namespace, backend_functions = parse_backend_yaml(source_yaml, grouped)

    # NOTE(review): the two results below are not used anywhere in this function;
    # the calls are kept so that behaviour stays exactly as before.
    native_functions = parse_native_yaml(native_yaml_path)
    selector = SelectiveBuilder.get_nop_selector()

    generated_comment = 'Autogenerated file by gen_backend_stubs.py. Do not edit directly!'

    # Each environment is passed to FileManager.write as a zero-argument callable
    # rather than a ready-made dict, so it is only computed if the FileManager calls it.
    def xla_type_header_env() -> Dict[str, Union[str, Sequence[str]]]:
        declarations = list(concatMap(dest.compute_native_function_declaration, backend_functions))
        return {
            'generated_comment': generated_comment,
            'cpp_namespace': cpp_namespace,
            'dispatch_xla_declarations': declarations,
        }

    def fallback_header_env() -> Dict[str, Union[str, Sequence[str]]]:
        declarations = list(concatMap(
            dest.GenExternalAtenFallback(Target.NAMESPACED_DECLARATION), backend_functions
        ))
        return {
            'generated_comment': generated_comment,
            'cpp_namespace': cpp_namespace,
            'dispatch_aten_fallback_declarations': declarations,
        }

    def fallback_source_env() -> Dict[str, Union[str, Sequence[str]]]:
        # TODO: after cpu fallbacks are moved to a boxed kernel,
        # merge registrations / definitions into RegisterDispatchKey
        definitions = list(concatMap(
            dest.GenExternalAtenFallback(Target.NAMESPACED_DEFINITION), backend_functions
        ))
        # Registrations are split by whether the entry is an autograd kernel.
        regular_kernels = [e for e in backend_functions if not e.is_autograd_kernel]
        registrations = list(concatMap(
            dest.GenExternalAtenFallback(Target.REGISTRATION), regular_kernels
        ))
        autograd_kernels = [e for e in backend_functions if e.is_autograd_kernel]
        autograd_registrations = list(concatMap(
            dest.GenExternalAtenFallback(Target.REGISTRATION), autograd_kernels
        ))
        return {
            'generated_comment': generated_comment,
            'cpp_namespace': cpp_namespace,
            'dispatch_aten_fallback_definitions': definitions,
            'dispatch_registrations': registrations,
            'dispatch_autograd_registrations': autograd_registrations,
        }

    file_manager.write('aten_xla_type.h', xla_type_header_env)
    file_manager.write('aten_xla_type_default.h', fallback_header_env)
    file_manager.write('aten_xla_type_default.cpp', fallback_source_env)
|
|
|
|
# Run the generator only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
|