[WIP] generate in-place/out wrappers for external kernels

ghstack-source-id: 859e07e54a
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56835
Brian Hirsh 2021-04-26 15:32:08 -07:00
parent 10c4bf510e
commit 2fefb36870
10 changed files with 256 additions and 95 deletions

View File

@@ -34,6 +34,7 @@
 #include <torch/library.h>
 $extra_cuda_headers
 $legacy_th_headers
+$external_backend_headers
 namespace at {

View File

@@ -108,10 +108,6 @@ ${dispatch_aten_fallback_definitions}
 TORCH_LIBRARY_IMPL(aten, XLA, m) {
 ${dispatch_registrations}
-}
-TORCH_LIBRARY_IMPL(aten, AutogradXLA, m) {
-${dispatch_autograd_registrations}
 }
 } // namespace torch_xla

View File

@@ -24,8 +24,13 @@ from typing import Sequence, List, Union
 # arguments.
 #
-def name(func: FunctionSchema) -> str:
-    return cpp.name(func)
+def name(func: FunctionSchema, *, append_overload_name: bool = False) -> str:
+    name = cpp.name(func)
+    if append_overload_name and func.name.overload_name != '':
+        # This isn't an important characteristic of the dispatcher API, and could be removed if we want to further unify
+        # the dispatcher and C++ APIs. There happen to be a few places in the codegen where we need to guarantee name uniqueness.
+        name = f'{name}_{func.name.overload_name}'
+    return name

 def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
     # This is a faux amis. If it makes sense in the future to add
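
A minimal standalone sketch of the name-uniqueness rule introduced in the hunk above. The helper below is hypothetical and only mimics the append_overload_name branch; it does not use the real tools.codegen API:

# Hypothetical stand-in for dispatcher.name(): append the overload name only when asked to,
# so that two overloads sharing a base name map to distinct wrapper names.
def wrapper_name(base: str, overload_name: str, *, append_overload_name: bool = False) -> str:
    name = base
    if append_overload_name and overload_name != '':
        name = f'{name}_{overload_name}'
    return name

# 'add.Tensor' and 'add.Scalar' would otherwise both produce 'add':
assert wrapper_name('add', 'Tensor', append_overload_name=True) == 'add_Tensor'
assert wrapper_name('add', 'Scalar', append_overload_name=True) == 'add_Scalar'
assert wrapper_name('add', 'Tensor') == 'add'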

View File

@@ -406,11 +406,15 @@ class DispatcherSignature:
     # The schema this signature is derived from
     func: FunctionSchema

+    prefix: str
+    append_overload_name: bool
+
     def arguments(self) -> List[Binding]:
         return dispatcher.arguments(self.func)

     def name(self) -> str:
-        return dispatcher.name(self.func)
+        return self.prefix + dispatcher.name(self.func, append_overload_name=self.append_overload_name)

     def decl(self, name: Optional[str] = None) -> str:
         args_str = ', '.join(a.decl() for a in self.arguments())
@@ -440,8 +444,8 @@ class DispatcherSignature:
         return f'{self.returns_type().cpp_type()} ({dispatcher_args_types_str})'

     @staticmethod
-    def from_schema(func: FunctionSchema) -> 'DispatcherSignature':
-        return DispatcherSignature(func)
+    def from_schema(func: FunctionSchema, *, prefix: str = '', append_overload_name: bool = False) -> 'DispatcherSignature':
+        return DispatcherSignature(func, prefix, append_overload_name)

 @dataclass(frozen=True)
 class NativeSignature:

View File

@@ -16,6 +16,7 @@ F = TypeVar(
     ExternalBackendFunction,
     ExternalBackendFunctionsGroup,
     Union[NativeFunction, NativeFunctionsGroup],
+    Union[NativeFunction, ExternalBackendFunction],
     Union[ExternalBackendFunctionsGroup, ExternalBackendFunction],
     Union[NativeFunction, NativeFunctionsGroup, ExternalBackendFunction, ExternalBackendFunctionsGroup]
 )

View File

@@ -6,7 +6,7 @@ import re
 from tools.codegen.context import method_with_native_function
 from tools.codegen.utils import Target, mapMaybe
 from tools.codegen.model import (Argument, ExternalBackendFunction,
-                                 ExternalBackendFunctionsGroup, SchemaKind,
+                                 ExternalBackendFunctionsGroup,
                                  assert_never, Return, is_generic_dispatch_key,
                                  ListType, OptionalType, BaseType, BaseTy, Variant)
 from tools.codegen.api.types import DispatcherSignature, CppSignatureGroup
@@ -102,7 +102,7 @@ def xla_tensor_creation_api(
 #   do not have full aten coverage.
 #   For operators not implemented by the external backend, our codegen
 #   will register these fallbacks instead.
-# - Why do we generate fallback for ALL aten ops, including ops that
+# - Why do we generate fallback for ALL (non-composite) aten ops, including ops that
 #   external backends have already implemented?
 #   Many external backend kernels only work with specific input shapes,
 #   and are written to call into a cpu fallback when given inputs
@@ -117,41 +117,6 @@ class GenExternalAtenFallback:
     @method_with_native_function
     def __call__(self, g: Union[ExternalBackendFunctionsGroup, ExternalBackendFunction]) -> List[str]:
-        def gen_out_wrapper(g: ExternalBackendFunctionsGroup) -> Optional[str]:
-            dispatcher_sig = DispatcherSignature.from_schema(g.out.native_function.func)
-            name = dispatcher_sig.name()
-
-            dispatcher_order_args = dispatcher.jit_arguments(g.out.native_function.func)
-            tensors = [a for a in dispatcher_order_args if a.type == BaseType(BaseTy.Tensor)]
-            print_args_str = ''.join([f' << " {a.name}=" << {a.name}.toString()' for a in tensors])
-
-            func_name = f'AtenXlaTypeDefault::{name}'
-            functional_result_name = f'{name}_tmp'
-            return_names = cpp.return_names(g.out.native_function)
-            if len(return_names) > 1:
-                updates = '\n '.join(
-                    f'at::_copy_from_and_resize(std::get<{i}>({functional_result_name}), {ret_name});'
-                    for i, ret_name in enumerate(return_names))
-                returns = f'{dispatcher_sig.returns_type().cpp_type()}({", ".join(return_names)})'
-            else:
-                ret_name = return_names[0]
-                updates = f'at::_copy_from_and_resize({functional_result_name}, {ret_name});'
-                returns = ret_name
-
-            functional_sig = DispatcherSignature.from_schema(g.functional.native_function.func)
-
-            return f"""\
-{dispatcher_sig.defn(name=func_name)} {{
-  XLA_FN_TRACK(3);
-  TF_VLOG(3) << "XLA {name} :"{print_args_str};
-  auto {functional_result_name} = AtenXlaType::{functional_sig.name()}({", ".join(a.name for a in functional_sig.arguments())});
-  {updates}
-  return {returns};
-}}
-"""
-
         def gen_unstructured_external(f: ExternalBackendFunction) -> Optional[str]:
             if not requires_backend_wrapper(f):
                 return None
@@ -174,8 +139,7 @@ class GenExternalAtenFallback:
                    return device_like[0].name
                raise AssertionError("Need a tensor-like or device argument in order to determine the output device")

-            # XLA appears to have used the dispatcher convention to write their kernel signatures,
-            # probably because they based their signatures off of our RegistrationDeclarations.h
+            # See Note [External Backends Follow Dispatcher convention]
            dispatcher_sig = DispatcherSignature.from_schema(f.native_function.func)
            name = dispatcher_sig.name()
            args = dispatcher_sig.arguments()
@@ -184,25 +148,19 @@ class GenExternalAtenFallback:
                return f" static {dispatcher_sig.decl()};"

            elif self.target is Target.REGISTRATION:
-                if f.metadata is not None:
-                    # xla has their own kernel: register it
-                    namespace = 'AtenXlaType'
-                else:
-                    # xla doesn't have a kernel: register the cpu fallback (or codegen'd out kernel).
-                    namespace = 'AtenXlaTypeDefault'
-                payload = f"static_cast<{dispatcher_sig.ptr_type()}>(&{namespace}::{name})"
+                # This codegen is only responsible for registering CPU fallback kernels.
+                # We also skip registrations if there is a functional backend kernel,
+                # because we generate out/inplace wrappers in that case (handled in register_dispatch_key.py).
+                if f.metadata is not None or (isinstance(g, ExternalBackendFunctionsGroup) and g.functional.metadata is not None):
+                    return ''
+                payload = f"static_cast<{dispatcher_sig.ptr_type()}>(&AtenXlaTypeDefault::{name})"
                return f' m.impl("{f.native_function.func.name}", {payload});\n'

            if self.target is not Target.NAMESPACED_DEFINITION:
                assert_never(self.target)

-            # Instead of generating a CPU fallback, the xla codegen generates out wrappers for a few hardcoded operators.
-            # TODO: we should generate out wrappers for ALL valid out kernels; not just ones in xla's hardcoded list
-            if f.native_function.func.kind() is SchemaKind.out and str(f.native_function.func.name.name) in _FN_OUT \
-                    and isinstance(g, ExternalBackendFunctionsGroup):
-                return gen_out_wrapper(g)
-
            # Everything below here is where we generate the CPU fallback.
+            # See Note [External Backends Follow Dispatcher convention]
            dispatcher_order_args = dispatcher.jit_arguments(f.native_function.func)

            # Map each argument to it's intermediate variable name in the fallback
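
A small sketch of the registration rule described in the new comments above, using hypothetical stand-ins rather than the real ExternalBackendFunction/ExternalBackendFunctionsGroup types; "metadata" is non-None when the backend supplies its own kernel:

from typing import Optional

def registers_cpu_fallback(op_metadata: Optional[object], functional_metadata: Optional[object]) -> bool:
    # Skip if the backend has its own kernel for this op, or if the group's functional
    # kernel exists (an out/inplace wrapper will be generated from it instead).
    if op_metadata is not None or functional_metadata is not None:
        return False
    return True

assert registers_cpu_fallback(None, None) is True          # no backend kernel at all -> CPU fallback
assert registers_cpu_fallback(object(), None) is False     # backend kernel -> no fallback registration
assert registers_cpu_fallback(None, object()) is False     # functional kernel -> wrapper generated instead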

View File

@@ -4,12 +4,13 @@ from typing_extensions import Literal
 from dataclasses import dataclass
 import textwrap

-from tools.codegen.context import method_with_native_function
+from tools.codegen.context import method_with_native_function, with_native_function
 from tools.codegen.utils import Target, mapMaybe
 from tools.codegen.model import (DispatchKey, NativeFunction,
                                  NativeFunctionsGroup, SchemaKind,
+                                 ExternalBackendFunctionsGroup, ExternalBackendFunction,
                                  TensorOptionsArguments, assert_never,
-                                 is_cuda_dispatch_key,
+                                 is_cuda_dispatch_key, BaseType, BaseTy,
                                  is_structured_dispatch_key)
 from tools.codegen.api.types import (BaseCType, Binding, ConstRefCType,
                                      CppSignature, CppSignatureGroup,
@@ -17,6 +18,9 @@ from tools.codegen.api.types import (BaseCType, Binding, ConstRefCType,
                                      NativeSignature, tensorT, NamedCType)
 import tools.codegen.api.meta as meta
 import tools.codegen.api.structured as structured
+import tools.codegen.api.dispatcher as dispatcher
+import tools.codegen.api.cpp as cpp
+from tools.codegen.dest.gen_external_aten_fallbacks import requires_backend_wrapper
 from tools.codegen.api.translate import translate
 from tools.codegen.selective_build.selector import SelectiveBuilder
@@ -56,14 +60,35 @@ class RegisterDispatchKey:
     # Whether or not we are actually code-genning for ROCm
     rocm: bool

+    # The namespace that the kernels are written in. This is just `at::native` for in-tree kernels.
+    cpp_namespace: str
+
     @method_with_native_function
-    def __call__(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]:
+    def __call__(self, f: Union[
+            NativeFunctionsGroup,
+            NativeFunction,
+            ExternalBackendFunctionsGroup,
+            ExternalBackendFunction]
+    ) -> List[str]:
         if isinstance(f, NativeFunctionsGroup):
             if f.structured:
                 return self.gen_structured(f)
             else:
                 return list(mapMaybe(self.gen_unstructured, f.functions()))
-        elif isinstance(f, NativeFunction):
+        elif isinstance(f, ExternalBackendFunctionsGroup):
+            if f.structured:
+                raise AssertionError("structured kernels not implemented yet for external backends")
+            elif f.primary == f.functional:
+                # For external backends that specify that they'd like to primarily implement functional kernels (namely XLA),
+                # we can generate anonymous wrappers for the out and in-place kernels for them, even for unstructured operators.
+                # Note that we can't go the other way around (generate the functional using the out), since we don't know
+                # how to create the output tensor without a meta function.
+                return self.gen_out_inplace_wrappers(f)
+            else:
+                # For backends that implement out kernels, they need to port their ops to structured
+                # if they want generated functional/inplace kernels.
+                return list(mapMaybe(self.gen_unstructured, f.functions()))
+        elif isinstance(f, NativeFunction) or isinstance(f, ExternalBackendFunction):
             r = self.gen_unstructured(f)
             return [] if r is None else [r]
         else:
@@ -88,12 +113,38 @@ class RegisterDispatchKey:
             self.target,
             self.selector,
             self.rocm,
+            self.cpp_namespace,
             g
         )
         return list(mapMaybe(structured_gen.gen_one, g.functions()))

-    @method_with_native_function
-    def gen_unstructured(self, f: NativeFunction) -> Optional[str]:
+    def gen_unstructured(
+            self,
+            native_or_external: Union[NativeFunction, ExternalBackendFunction],
+            *,
+            # Only applies to ExternalBackendFunction objects.
+            # True for inplace/out functions that don't have kernels, but do have corresponding functional kernels.
+            is_generated_wrapper: bool = False,
+    ) -> Optional[str]:
+        sig: Union[NativeSignature, DispatcherSignature]
+        if isinstance(native_or_external, ExternalBackendFunction):
+            if not requires_backend_wrapper(native_or_external) and not is_generated_wrapper:
+                return None
+            # If the backend doesn't have a kernel, we don't generate an anonymous wrapper or a dispatcher registration for it
+            # (fallbacks to CPU are generated elsewhere).
+            if native_or_external.metadata is None and not is_generated_wrapper:
+                # TODO: Right now, we generate "fast-path" `at::xla::{op}` kernels for all non-composite ops,
+                # including those with CPU fallbacks. That logic will have to change if CPU fallbacks become a boxed kernel.
+                if self.target is not Target.NAMESPACED_DEFINITION and self.target is not Target.NAMESPACED_DECLARATION:
+                    return None
+            f = native_or_external.native_function
+            sig = self.external_backend_wrapper_sig(native_or_external)
+        elif isinstance(native_or_external, NativeFunction):
+            f = native_or_external
+            sig = NativeSignature(f.func, prefix='wrapper_')
+        else:
+            assert_never(f)
+
         inplace_meta = False
         if self.dispatch_key not in f.dispatch:
             if (self.dispatch_key == DispatchKey.Meta and
@@ -112,8 +163,6 @@ class RegisterDispatchKey:
         if self.target is Target.REGISTRATION and not self.selector.is_native_function_selected(f):
             return None

-        sig = NativeSignature(f.func, prefix='wrapper_')
-
         name = sig.name()
         returns_type = sig.returns_type().cpp_type()
         args = sig.arguments()
@@ -129,9 +178,19 @@ class RegisterDispatchKey:
             return result

         elif self.target is Target.NAMESPACED_DEFINITION:
             def generate_defn(cpp_sig: CppSignature) -> str:
+                # This is needed in order for namespaced definitions to call into CPU fallbacks,
+                # which live in a different namespace.
+                # TODO: this logic will change if we implement a boxed CPU fallback.
+                if isinstance(native_or_external, ExternalBackendFunction) \
+                        and native_or_external.metadata is None \
+                        and not is_generated_wrapper:
+                    # See Note [External Backends Follow Dispatcher convention]
+                    kernel_name = f'{self.cpp_namespace}::AtenXlaTypeDefault::{dispatcher.name(f.func)}'
+                else:
+                    kernel_name = sig.name()
                 return f"""
 {cpp_sig.defn()} {{
-return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
+return {kernel_name}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
 }}
 """
             result = generate_defn(cpp_sig_group.signature)
@@ -152,7 +211,11 @@ return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), si
 }}
 """

-            impl_name = f"at::native::{f.dispatch[self.dispatch_key]}"
+            if isinstance(native_or_external, ExternalBackendFunction):
+                # TODO: remove this difference and merge the two cases when we remove xla-specific logic
+                impl_name = f"{self.cpp_namespace}::AtenXlaType::{f.dispatch[self.dispatch_key]}"
+            else:
+                impl_name = f"{self.cpp_namespace}::{f.dispatch[self.dispatch_key]}"

             args_exprs_str = ', '.join(a.name for a in args)
@@ -196,12 +259,74 @@ namespace {{
             if f.manual_kernel_registration:
                 return None
             else:
-                dispatcher_sig = DispatcherSignature.from_schema(f.func)
                 payload = f"TORCH_FN({name})"
                 return f'm.impl("{f.func.name}",\n{payload});\n'
         else:
             assert_never(self.target)

+    def external_backend_wrapper_sig(self, f: ExternalBackendFunction) -> DispatcherSignature:
+        # See Note [External Backends Follow Dispatcher convention]
+        return DispatcherSignature.from_schema(f.native_function.func, prefix='wrapper_', append_overload_name=True)
+
+    def gen_out_inplace_wrappers(self, g: ExternalBackendFunctionsGroup) -> List[str]:
+        def gets_generated_out_inplace_wrapper(f: ExternalBackendFunction) -> bool:
+            return f.native_function.func.kind() is not SchemaKind.functional \
+                and f.metadata is None and g.functional.metadata is not None
+
+        @with_native_function
+        def gen_wrapper(f: ExternalBackendFunction) -> Optional[str]:
+            # Only anonymous definitions get "special treatment" for out/inplace wrappers.
+            # All other functionality can be directly pulled from gen_unstructured.
+            if self.target is not Target.ANONYMOUS_DEFINITION:
+                is_generated_wrapper = gets_generated_out_inplace_wrapper(f)
+                return self.gen_unstructured(f, is_generated_wrapper=is_generated_wrapper)
+            if f.native_function.func.kind() is SchemaKind.functional:
+                # Wrappers are generated for out/inplace kernels, using the functional kernel,
+                # so the functional kernel itself is generated normally.
+                return self.gen_unstructured(f)
+            if f.metadata is not None:
+                # If the backend provided their own out/inplace kernel, use it.
+                return self.gen_unstructured(f)
+            if not gets_generated_out_inplace_wrapper(f):
+                return None
+
+            # Special out/inplace wrapper logic starts here.
+            dispatcher_sig = self.external_backend_wrapper_sig(f)
+            name = dispatcher_sig.name()
+            # See Note [External Backends Follow Dispatcher convention]
+            dispatcher_order_args = dispatcher.jit_arguments(f.native_function.func)
+            tensors = [a for a in dispatcher_order_args if a.type == BaseType(BaseTy.Tensor)]
+            print_args_str = ''.join([f' << " {a.name}=" << {a.name}.toString()' for a in tensors])
+
+            functional_result_name = f'{name}_tmp'
+            return_names = cpp.return_names(f.native_function)
+            if len(return_names) > 1:
+                updates = '\n '.join(
+                    f'at::_copy_from_and_resize(std::get<{i}>({functional_result_name}), {ret_name});'
+                    for i, ret_name in enumerate(return_names))
+                returns = f'{dispatcher_sig.returns_type().cpp_type()}({", ".join(return_names)})'
+            else:
+                ret_name = return_names[0]
+                updates = f'at::_copy_from_and_resize({functional_result_name}, {ret_name});'
+                returns = ret_name
+
+            functional_sig = self.external_backend_wrapper_sig(g.functional)
+
+            return f"""\
+{dispatcher_sig.defn()} {{
+  XLA_FN_TRACK(3);
+  TF_VLOG(3) << "XLA {name} :"{print_args_str};
+  auto {functional_result_name} = {functional_sig.name()}({", ".join(a.name for a in functional_sig.arguments())});
+  {updates}
+  return {returns};
+}}
+"""
+
+        return list(mapMaybe(gen_wrapper, g.functions(functional_first=True)))

 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
 #
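
To make the generated-wrapper shape concrete: a conceptual Python sketch of what the anonymous out wrapper does at runtime. The generated code is actually C++ emitted by gen_out_inplace_wrappers above; the names below are illustrative only and not part of the codegen:

from typing import Callable, Sequence

def out_wrapper(functional_kernel: Callable, copy_from_and_resize: Callable, args: Sequence, outs: Sequence):
    # Run the backend's functional kernel to produce fresh result tensors...
    tmp_results = functional_kernel(*args)
    if not isinstance(tmp_results, tuple):
        tmp_results = (tmp_results,)
    # ...then copy/resize each result into the caller-provided out tensors,
    # mirroring the generated at::_copy_from_and_resize calls.
    for tmp, out in zip(tmp_results, outs):
        copy_from_and_resize(tmp, out)
    # The wrapper returns the out arguments, matching the out-variant schema.
    return outs[0] if len(outs) == 1 else tuple(outs)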

View File

@@ -916,20 +916,22 @@ def main() -> None:
             '#include <ATen/LegacyTHFunctionsCPU.h>' if dispatch_key == DispatchKey.CPU else
             '#include <ATen/LegacyTHFunctionsCUDA.h>' if dispatch_key == DispatchKey.CUDA else
             '',
+        'external_backend_headers': '',
         'DispatchKey': dispatch_key,
         'dispatch_namespace': dispatch_key.lower(),
         'dispatch_namespaced_definitions': list(concatMap(
             dest.RegisterDispatchKey(
-                dispatch_key, Target.NAMESPACED_DEFINITION, selector, rocm=options.rocm),
+                dispatch_key, Target.NAMESPACED_DEFINITION, selector, rocm=options.rocm, cpp_namespace='at::native'),
             grouped_native_functions
         )),
         'dispatch_anonymous_definitions': list(concatMap(
             dest.RegisterDispatchKey(
-                dispatch_key, Target.ANONYMOUS_DEFINITION, selector, rocm=options.rocm),
+                dispatch_key, Target.ANONYMOUS_DEFINITION, selector, rocm=options.rocm, cpp_namespace='at::native'),
             grouped_native_functions
         )),
         'dispatch_registrations': list(concatMap(
-            dest.RegisterDispatchKey(dispatch_key, Target.REGISTRATION, selector, rocm=options.rocm),
+            dest.RegisterDispatchKey(
+                dispatch_key, Target.REGISTRATION, selector, rocm=options.rocm, cpp_namespace='at::native'),
             grouped_native_functions
         )),
     })
@@ -939,7 +941,7 @@ def main() -> None:
         'dispatch_namespace': dispatch_key.lower(),
         'dispatch_namespaced_declarations': list(concatMap(
             dest.RegisterDispatchKey(
-                dispatch_key, Target.NAMESPACED_DECLARATION, selector, rocm=options.rocm),
+                dispatch_key, Target.NAMESPACED_DECLARATION, selector, rocm=options.rocm, cpp_namespace='at::native'),
             grouped_native_functions
         )),
     })

View File

@@ -2,7 +2,7 @@ import pathlib
 import argparse
 import os
 import yaml
-from typing import List, Dict, Union, Tuple, Sequence
+from typing import List, Dict, Union, Tuple, Sequence, Optional
 from tools.codegen.gen import FileManager, get_grouped_native_functions, parse_native_yaml
 from tools.codegen.model import (ExternalBackendFunction, ExternalBackendFunctionsGroup,
                                  NativeFunction, NativeFunctionsGroup, OperatorName,
@@ -17,13 +17,14 @@ try:
     # use faster C loader if available
     from yaml import CSafeLoader as Loader
 except ImportError:
-    from yaml import SafeLoader as Loader  # type: ignore
+    from yaml import SafeLoader as Loader  # type: ignore[misc]

+# Parses the backend's yaml file and returns (BackendDispatchKey, AutogradDispatchKey, cpp_namespace, backend_kernel_list)
 def parse_backend_yaml(
         backend_yaml_path: str,
         grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]]
-) -> Tuple[str, List[Union[ExternalBackendFunction, ExternalBackendFunctionsGroup]]]:
+) -> Tuple[DispatchKey, Optional[DispatchKey], str, List[Union[ExternalBackendFunction, ExternalBackendFunctionsGroup]]]:
     with open(backend_yaml_path, 'r') as f:
         yaml_values = yaml.load(f, Loader=Loader)
     assert isinstance(yaml_values, dict)
@@ -71,7 +72,8 @@ Only the following keys are supported: {", ".join(valid_keys)}'
         for f in concatMap(lambda f: [f] if isinstance(f, NativeFunction) else list(f.functions()), grouped_native_functions)
     }

-    def kernel_name(func: FunctionSchema) -> str:
+    def external_kernel_name(func: FunctionSchema) -> str:
+        # Note [External Backends Follow Dispatcher convention]
         # For external backends, we enforce that their names and signatures match the dispatcher convention
         return dispatcher.name(func)
@@ -83,17 +85,17 @@ Only the following keys are supported: {", ".join(valid_keys)}'
             m = metadata.get(f.func.name, None)
             dispatch_key = DispatchKey.parse(f'Autograd{backend}') \
                 if m is not None and m.is_autograd else DispatchKey.parse(backend)
-            kernel = kernel_name(f.func)
+            kernel = external_kernel_name(f.func)
             return ExternalBackendFunction(NativeFunction.with_dispatch_entry(f, dispatch_key, kernel), dispatch_key, m)
         elif isinstance(g, NativeFunctionsGroup):
             out_meta = metadata.get(g.out.func.name, None)
-            kernel = kernel_name(g.out.func)
+            kernel = external_kernel_name(g.out.func)
             dispatch_key = DispatchKey.parse(f'Autograd{backend}') \
                 if out_meta is not None and out_meta.is_autograd else DispatchKey.parse(backend)
             out = ExternalBackendFunction(NativeFunction.with_dispatch_entry(g.out, dispatch_key, kernel), dispatch_key, out_meta)

             functional_meta = metadata.get(g.functional.func.name, None)
-            kernel = kernel_name(g.functional.func)
+            kernel = external_kernel_name(g.functional.func)
             dispatch_key = DispatchKey.parse(f'Autograd{backend}') \
                 if functional_meta is not None and functional_meta.is_autograd else DispatchKey.parse(backend)
             functional = ExternalBackendFunction(
@@ -109,7 +111,7 @@ autograd key. They can not be mix and matched. If this is something you need, fe
             inplace = None
             if g.inplace:
                 inplace_meta = metadata.get(g.inplace.func.name, None)
-                kernel = kernel_name(g.inplace.func)
+                kernel = external_kernel_name(g.inplace.func)
                 dispatch_key = DispatchKey.parse(f'Autograd{backend}') \
                     if inplace_meta is not None and inplace_meta.is_autograd else DispatchKey.parse(backend)
                 inplace = ExternalBackendFunction(
@@ -128,7 +130,7 @@ autograd key. They can not be mix and matched. If this is something you need, fe
         assert_never(g)

     for op_name in metadata.keys():
         assert op_name in native_functions_map, f"Found an invalid operator name: {op_name}"
-    return cpp_namespace, [native_to_external(g) for g in grouped_native_functions]
+    return backend_key, autograd_key, cpp_namespace, [native_to_external(g) for g in grouped_native_functions]

 def main() -> None:
     parser = argparse.ArgumentParser(description='Generate backend stub files')
@@ -157,7 +159,7 @@ def run(source_yaml: str, output_dir: str, dry_run: bool) -> None:
     native_yaml_path = os.path.join(pytorch_root, 'aten/src/ATen/native/native_functions.yaml')
     grouped_native_functions = get_grouped_native_functions(native_yaml_path)

-    cpp_namespace, external_backend_functions = parse_backend_yaml(source_yaml, grouped_native_functions)
+    backend_key, autograd_key, cpp_namespace, external_backend_functions = parse_backend_yaml(source_yaml, grouped_native_functions)
     native_functions = parse_native_yaml(native_yaml_path)
@@ -171,6 +173,68 @@ def run(source_yaml: str, output_dir: str, dry_run: bool) -> None:
         'dispatch_xla_declarations': list(concatMap(dest.compute_native_function_declaration, external_backend_functions)),
     })

+    external_backend_functions_no_autograd = [f for f in external_backend_functions if not f.is_autograd_kernel]
+    external_backend_functions_autograd = [f for f in external_backend_functions if f.is_autograd_kernel]
+
+    external_backend_headers = '''\
+#include <tensorflow/compiler/xla/xla_client/debug_macros.h>
+#include <tensorflow/compiler/xla/xla_client/metrics.h>
+#include <tensorflow/compiler/xla/xla_client/tf_logging.h>
+#include <torch_xla/csrc/function_call_tracker.h>
+#include <torch_xla/csrc/aten_xla_type.h>
+#include <torch_xla/csrc/aten_xla_type_default.h>'''
+
+    fm.write_with_template(f'Register{backend_key}.cpp', 'RegisterDispatchKey.cpp', lambda: {
+        'extra_cuda_headers': '',
+        'legacy_th_headers': '',
+        'external_backend_headers': external_backend_headers,
+        'DispatchKey': backend_key,
+        'dispatch_namespace': backend_key.lower(),
+        'dispatch_namespaced_definitions': list(concatMap(
+            dest.RegisterDispatchKey(
+                backend_key, Target.NAMESPACED_DEFINITION, selector, rocm=False, cpp_namespace=cpp_namespace),
+            external_backend_functions_no_autograd
+        )),
+        'dispatch_anonymous_definitions': list(concatMap(
+            dest.RegisterDispatchKey(
+                backend_key, Target.ANONYMOUS_DEFINITION, selector, rocm=False, cpp_namespace=cpp_namespace),
+            external_backend_functions_no_autograd
+        )),
+        'dispatch_registrations': list(concatMap(
+            dest.RegisterDispatchKey(backend_key, Target.REGISTRATION, selector, rocm=False, cpp_namespace=cpp_namespace),
+            external_backend_functions_no_autograd
+        )),
+    })
+
+    # If they have at least one autograd entry in their yaml file
+    if autograd_key is not None:
+        assert len(external_backend_functions_autograd) > 0
+        autograd_dispatchkey: DispatchKey = autograd_key  # make mypy happy
+        fm.write_with_template(f'Register{autograd_key}.cpp', 'RegisterDispatchKey.cpp', lambda: {
+            'extra_cuda_headers': '',
+            'legacy_th_headers': '',
+            'external_backend_headers': external_backend_headers,
+            'DispatchKey': autograd_dispatchkey,
+            'dispatch_namespace': autograd_dispatchkey.lower(),
+            'dispatch_namespaced_definitions': list(concatMap(
+                dest.RegisterDispatchKey(
+                    autograd_dispatchkey, Target.NAMESPACED_DEFINITION, selector, rocm=False, cpp_namespace=cpp_namespace),
+                external_backend_functions_autograd
+            )),
+            'dispatch_anonymous_definitions': list(concatMap(
+                dest.RegisterDispatchKey(
+                    autograd_dispatchkey, Target.ANONYMOUS_DEFINITION, selector, rocm=False, cpp_namespace=cpp_namespace),
+                external_backend_functions_autograd
+            )),
+            'dispatch_registrations': list(concatMap(
+                dest.RegisterDispatchKey(
+                    autograd_dispatchkey, Target.REGISTRATION, selector, rocm=False, cpp_namespace=cpp_namespace),
+                external_backend_functions_autograd
+            )),
+        })
+
     fm.write('aten_xla_type_default.h', lambda: {
         'generated_comment': generated_comment,
         'cpp_namespace': cpp_namespace,
@@ -188,10 +252,7 @@ def run(source_yaml: str, output_dir: str, dry_run: bool) -> None:
             dest.GenExternalAtenFallback(Target.NAMESPACED_DEFINITION), external_backend_functions
         )),
         'dispatch_registrations': list(concatMap(
-            dest.GenExternalAtenFallback(Target.REGISTRATION), [e for e in external_backend_functions if not e.is_autograd_kernel]
-        )),
-        'dispatch_autograd_registrations': list(concatMap(
-            dest.GenExternalAtenFallback(Target.REGISTRATION), [e for e in external_backend_functions if e.is_autograd_kernel]
+            dest.GenExternalAtenFallback(Target.REGISTRATION), external_backend_functions
         )),
     })
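
For orientation, a hypothetical example of the kind of backend yaml that parse_backend_yaml consumes, shown as a Python string. Only 'backend' and 'cpp_namespace' are directly visible in the hunks above; the other key names are assumptions for illustration (the real schema is enforced by the valid_keys check in this file):

# Hypothetical input for parse_backend_yaml; key names beyond 'backend' and 'cpp_namespace'
# are assumed here for illustration only.
example_backend_yaml = """\
backend: XLA
cpp_namespace: torch_xla
supported:
  - abs
  - add.Tensor
autograd:
  - max_pool2d
"""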

View File

@@ -1452,11 +1452,19 @@ class ExternalBackendFunctionsGroup:
                 f"variant, {str(f.native_function.func.name)} will be generated for you " \
                 "and doesn't need to live in the yaml."

-    def functions(self) -> Iterator[ExternalBackendFunction]:
-        yield self.out
-        yield self.functional
-        if self.inplace is not None:
-            yield self.inplace
+    def functions(self, *, functional_first: bool = False) -> Iterator[ExternalBackendFunction]:
+        if not functional_first:
+            yield self.out
+            yield self.functional
+            if self.inplace is not None:
+                yield self.inplace
+        else:
+            # When we generate out/inplace wrappers for unstructured kernels, we need the functional kernel
+            # to be defined before we can generate the other two wrappers.
+            yield self.functional
+            yield self.out
+            if self.inplace is not None:
+                yield self.inplace

 # Helper functions for parsing argument lists (both inputs and returns)
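
A toy illustration of the functional_first ordering contract introduced above, using a hypothetical stand-in class rather than the real model.py dataclass: with functional_first=True the functional entry is yielded before out/inplace, so its definition is emitted before the wrappers that call it.

from typing import Iterator, Optional

class FakeGroup:
    # Hypothetical stand-in mirroring the ordering behavior of ExternalBackendFunctionsGroup.functions().
    def __init__(self, functional: str, out: str, inplace: Optional[str] = None) -> None:
        self.functional, self.out, self.inplace = functional, out, inplace

    def functions(self, *, functional_first: bool = False) -> Iterator[str]:
        order = [self.functional, self.out] if functional_first else [self.out, self.functional]
        yield from order
        if self.inplace is not None:
            yield self.inplace

g = FakeGroup('add', 'add.out', 'add_')
assert list(g.functions(functional_first=True)) == ['add', 'add.out', 'add_']
assert list(g.functions()) == ['add.out', 'add', 'add_']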