[WIP] generate in-place/out wrappers for external kernels

ghstack-source-id: 859e07e54a
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56835
Brian Hirsh 2021-04-26 15:32:08 -07:00
parent 10c4bf510e
commit 2fefb36870
10 changed files with 256 additions and 95 deletions

View File

@ -34,6 +34,7 @@
#include <torch/library.h>
$extra_cuda_headers
$legacy_th_headers
$external_backend_headers
namespace at {

View File

@ -108,10 +108,6 @@ ${dispatch_aten_fallback_definitions}
TORCH_LIBRARY_IMPL(aten, XLA, m) {
${dispatch_registrations}
}
TORCH_LIBRARY_IMPL(aten, AutogradXLA, m) {
${dispatch_autograd_registrations}
}
} // namespace torch_xla

View File

@ -24,8 +24,13 @@ from typing import Sequence, List, Union
# arguments.
#
def name(func: FunctionSchema) -> str:
return cpp.name(func)
def name(func: FunctionSchema, *, append_overload_name: bool = False) -> str:
name = cpp.name(func)
if append_overload_name and func.name.overload_name != '':
# This isn't an important characteristic of the dispatcher API, and could be removed if we want to further unify
# the dispatcher and C++ APIs. There happen to be a few places in the codegen where we need to guarantee name uniqueness.
name = f'{name}_{func.name.overload_name}'
return name
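To make the uniqueness guarantee concrete, here is a minimal sketch (assuming FunctionSchema.parse from tools.codegen.model; the schema strings are taken from native_functions.yaml):

# Two overloads of `add` collapse to the same dispatcher name unless the
# overload name is appended.
add_tensor = FunctionSchema.parse('add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor')
add_scalar = FunctionSchema.parse('add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor')
name(add_tensor)                             # 'add'
name(add_scalar)                             # 'add'  <- collision
name(add_tensor, append_overload_name=True)  # 'add_Tensor'
name(add_scalar, append_overload_name=True)  # 'add_Scalar'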
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
# This is a faux ami. If it makes sense in the future to add

View File

@ -406,11 +406,15 @@ class DispatcherSignature:
# The schema this signature is derived from
func: FunctionSchema
prefix: str
append_overload_name: bool
def arguments(self) -> List[Binding]:
return dispatcher.arguments(self.func)
def name(self) -> str:
return dispatcher.name(self.func)
return self.prefix + dispatcher.name(self.func, append_overload_name=self.append_overload_name)
def decl(self, name: Optional[str] = None) -> str:
args_str = ', '.join(a.decl() for a in self.arguments())
@ -440,8 +444,8 @@ class DispatcherSignature:
return f'{self.returns_type().cpp_type()} ({dispatcher_args_types_str})'
@staticmethod
def from_schema(func: FunctionSchema) -> 'DispatcherSignature':
return DispatcherSignature(func)
def from_schema(func: FunctionSchema, *, prefix: str = '', append_overload_name: bool = False) -> 'DispatcherSignature':
return DispatcherSignature(func, prefix, append_overload_name)
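A sketch of how the new parameters compose (the exact C++ types come from dispatcher.arguments(); the output shown in comments is approximate):

# The external-backend codegen below combines a 'wrapper_' prefix with the
# overload name to get collision-free anonymous wrapper declarations.
schema = FunctionSchema.parse('add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor')
sig = DispatcherSignature.from_schema(schema, prefix='wrapper_', append_overload_name=True)
sig.name()  # 'wrapper_add_Tensor'
sig.decl()  # roughly: at::Tensor wrapper_add_Tensor(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha)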
@dataclass(frozen=True)
class NativeSignature:

View File

@ -16,6 +16,7 @@ F = TypeVar(
ExternalBackendFunction,
ExternalBackendFunctionsGroup,
Union[NativeFunction, NativeFunctionsGroup],
Union[NativeFunction, ExternalBackendFunction],
Union[ExternalBackendFunctionsGroup, ExternalBackendFunction],
Union[NativeFunction, NativeFunctionsGroup, ExternalBackendFunction, ExternalBackendFunctionsGroup]
)

View File

@ -6,7 +6,7 @@ import re
from tools.codegen.context import method_with_native_function
from tools.codegen.utils import Target, mapMaybe
from tools.codegen.model import (Argument, ExternalBackendFunction,
ExternalBackendFunctionsGroup, SchemaKind,
ExternalBackendFunctionsGroup,
assert_never, Return, is_generic_dispatch_key,
ListType, OptionalType, BaseType, BaseTy, Variant)
from tools.codegen.api.types import DispatcherSignature, CppSignatureGroup
@ -102,7 +102,7 @@ def xla_tensor_creation_api(
# do not have full aten coverage.
# For operators not implemented by the external backend, our codegen
# will register these fallbacks instead.
# - Why do we generate fallbacks for ALL aten ops, including ops that
# - Why do we generate fallbacks for ALL (non-composite) aten ops, including ops that
# external backends have already implemented?
# Many external backend kernels only work with specific input shapes,
# and are written to call into a cpu fallback when given inputs
@ -117,41 +117,6 @@ class GenExternalAtenFallback:
@method_with_native_function
def __call__(self, g: Union[ExternalBackendFunctionsGroup, ExternalBackendFunction]) -> List[str]:
def gen_out_wrapper(g: ExternalBackendFunctionsGroup) -> Optional[str]:
dispatcher_sig = DispatcherSignature.from_schema(g.out.native_function.func)
name = dispatcher_sig.name()
dispatcher_order_args = dispatcher.jit_arguments(g.out.native_function.func)
tensors = [a for a in dispatcher_order_args if a.type == BaseType(BaseTy.Tensor)]
print_args_str = ''.join([f' << " {a.name}=" << {a.name}.toString()' for a in tensors])
func_name = f'AtenXlaTypeDefault::{name}'
functional_result_name = f'{name}_tmp'
return_names = cpp.return_names(g.out.native_function)
if len(return_names) > 1:
updates = '\n '.join(
f'at::_copy_from_and_resize(std::get<{i}>({functional_result_name}), {ret_name});'
for i, ret_name in enumerate(return_names))
returns = f'{dispatcher_sig.returns_type().cpp_type()}({", ".join(return_names)})'
else:
ret_name = return_names[0]
updates = f'at::_copy_from_and_resize({functional_result_name}, {ret_name});'
returns = ret_name
functional_sig = DispatcherSignature.from_schema(g.functional.native_function.func)
return f"""\
{dispatcher_sig.defn(name=func_name)} {{
XLA_FN_TRACK(3);
TF_VLOG(3) << "XLA {name} :"{print_args_str};
auto {functional_result_name} = AtenXlaType::{functional_sig.name()}({", ".join(a.name for a in functional_sig.arguments())});
{updates}
return {returns};
}}
"""
def gen_unstructured_external(f: ExternalBackendFunction) -> Optional[str]:
if not requires_backend_wrapper(f):
return None
@ -174,8 +139,7 @@ class GenExternalAtenFallback:
return device_like[0].name
raise AssertionError("Need a tensor-like or device argument in order to determine the output device")
# XLA appears to have used the dispatcher convention to write their kernel signatures,
# probably because they based their signatures off of our RegistrationDeclarations.h
# See Note [External Backends Follow Dispatcher convention]
dispatcher_sig = DispatcherSignature.from_schema(f.native_function.func)
name = dispatcher_sig.name()
args = dispatcher_sig.arguments()
@ -184,25 +148,19 @@ class GenExternalAtenFallback:
return f" static {dispatcher_sig.decl()};"
elif self.target is Target.REGISTRATION:
if f.metadata is not None:
# xla has their own kernel: register it
namespace = 'AtenXlaType'
else:
# xla doesn't have a kernel: register the cpu fallback (or codegen'd out kernel).
namespace = 'AtenXlaTypeDefault'
payload = f"static_cast<{dispatcher_sig.ptr_type()}>(&{namespace}::{name})"
# This codegen is only responsible for registering CPU fallback kernels
# We also skip registrations if there is a functional backend kernel,
# because we generate out/inplace wrappers in that case (handled in register_dispatch_key.py).
if f.metadata is not None or (isinstance(g, ExternalBackendFunctionsGroup) and g.functional.metadata is not None):
return ''
payload = f"static_cast<{dispatcher_sig.ptr_type()}>(&AtenXlaTypeDefault::{name})"
return f' m.impl("{f.native_function.func.name}", {payload});\n'
if self.target is not Target.NAMESPACED_DEFINITION:
assert_never(self.target)
# Instead of generating a CPU fallback, the xla codegen generates out wrappers for a few hardcoded operators.
# TODO: we should generate out wrappers for ALL valid out kernels, not just the ones in xla's hardcoded list
if f.native_function.func.kind() is SchemaKind.out and str(f.native_function.func.name.name) in _FN_OUT \
and isinstance(g, ExternalBackendFunctionsGroup):
return gen_out_wrapper(g)
# Everything below here is where we generate the CPU fallback.
# See Note [External Backends Follow Dispatcher convention]
dispatcher_order_args = dispatcher.jit_arguments(f.native_function.func)
# Map each argument to its intermediate variable name in the fallback
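The REGISTRATION branch above reduces to a small predicate; a sketch (helper and field names as in this diff):

# When does the fallback file emit m.impl("op", AtenXlaTypeDefault::op)?
def registers_fallback(f, g) -> bool:
    if f.metadata is not None:
        return False  # backend wrote its own kernel; RegisterXLA.cpp registers it
    if isinstance(g, ExternalBackendFunctionsGroup) and g.functional.metadata is not None:
        return False  # register_dispatch_key.py generates and registers an out/inplace wrapper instead
    return True       # no backend kernel at all: register the CPU fallback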

View File

@ -4,12 +4,13 @@ from typing_extensions import Literal
from dataclasses import dataclass
import textwrap
from tools.codegen.context import method_with_native_function
from tools.codegen.context import method_with_native_function, with_native_function
from tools.codegen.utils import Target, mapMaybe
from tools.codegen.model import (DispatchKey, NativeFunction,
NativeFunctionsGroup, SchemaKind,
ExternalBackendFunctionsGroup, ExternalBackendFunction,
TensorOptionsArguments, assert_never,
is_cuda_dispatch_key,
is_cuda_dispatch_key, BaseType, BaseTy,
is_structured_dispatch_key)
from tools.codegen.api.types import (BaseCType, Binding, ConstRefCType,
CppSignature, CppSignatureGroup,
@ -17,6 +18,9 @@ from tools.codegen.api.types import (BaseCType, Binding, ConstRefCType,
NativeSignature, tensorT, NamedCType)
import tools.codegen.api.meta as meta
import tools.codegen.api.structured as structured
import tools.codegen.api.dispatcher as dispatcher
import tools.codegen.api.cpp as cpp
from tools.codegen.dest.gen_external_aten_fallbacks import requires_backend_wrapper
from tools.codegen.api.translate import translate
from tools.codegen.selective_build.selector import SelectiveBuilder
@ -56,14 +60,35 @@ class RegisterDispatchKey:
# Whether or not we are actually code-genning for ROCm
rocm: bool
# The namespace that the kernels are written in. This is just `at::native` for in-tree kernels.
cpp_namespace: str
@method_with_native_function
def __call__(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]:
def __call__(self, f: Union[
NativeFunctionsGroup,
NativeFunction,
ExternalBackendFunctionsGroup,
ExternalBackendFunction]
) -> List[str]:
if isinstance(f, NativeFunctionsGroup):
if f.structured:
return self.gen_structured(f)
else:
return list(mapMaybe(self.gen_unstructured, f.functions()))
elif isinstance(f, NativeFunction):
elif isinstance(f, ExternalBackendFunctionsGroup):
if f.structured:
raise AssertionError("structured kernels not implemented yet for external backends")
elif f.primary == f.functional:
# For external backends that primarily implement functional kernels (namely XLA),
# we can generate anonymous wrappers for their out and in-place kernels, even for unstructured operators.
# Note that we can't go the other way around (generate the functional using the out), since we don't know
# how to create the output tensor without a meta function.
return self.gen_out_inplace_wrappers(f)
else:
# Backends that implement out kernels need to port their ops to structured kernels
# if they want generated functional/inplace kernels.
return list(mapMaybe(self.gen_unstructured, f.functions()))
elif isinstance(f, NativeFunction) or isinstance(f, ExternalBackendFunction):
r = self.gen_unstructured(f)
return [] if r is None else [r]
else:
@ -88,12 +113,38 @@ class RegisterDispatchKey:
self.target,
self.selector,
self.rocm,
self.cpp_namespace,
g
)
return list(mapMaybe(structured_gen.gen_one, g.functions()))
@method_with_native_function
def gen_unstructured(self, f: NativeFunction) -> Optional[str]:
def gen_unstructured(
self,
native_or_external: Union[NativeFunction, ExternalBackendFunction],
*,
# Only applies to ExternalBackendFunction objects.
# True for inplace/out functions that don't have kernels, but do have corresponding functional kernels.
is_generated_wrapper: bool = False,
) -> Optional[str]:
sig: Union[NativeSignature, DispatcherSignature]
if isinstance(native_or_external, ExternalBackendFunction):
if not requires_backend_wrapper(native_or_external) and not is_generated_wrapper:
return None
# If the backend doesn't have a kernel, we don't generate an anonymous wrapper or a dispatcher registration for it
# (fallbacks to CPU are generated elsewhere).
if native_or_external.metadata is None and not is_generated_wrapper:
# TODO: Right now, we generate "fast-path" `at::xla::{op}` kernels for all non-composite ops,
# including those with CPU fallbacks. That logic will have to change if CPU fallbacks become a boxed kernel.
if self.target is not Target.NAMESPACED_DEFINITION and self.target is not Target.NAMESPACED_DECLARATION:
return None
f = native_or_external.native_function
sig = self.external_backend_wrapper_sig(native_or_external)
elif isinstance(native_or_external, NativeFunction):
f = native_or_external
sig = NativeSignature(f.func, prefix='wrapper_')
else:
assert_never(native_or_external)
inplace_meta = False
if self.dispatch_key not in f.dispatch:
if (self.dispatch_key == DispatchKey.Meta and
@ -112,8 +163,6 @@ class RegisterDispatchKey:
if self.target is Target.REGISTRATION and not self.selector.is_native_function_selected(f):
return None
sig = NativeSignature(f.func, prefix='wrapper_')
name = sig.name()
returns_type = sig.returns_type().cpp_type()
args = sig.arguments()
@ -129,9 +178,19 @@ class RegisterDispatchKey:
return result
elif self.target is Target.NAMESPACED_DEFINITION:
def generate_defn(cpp_sig: CppSignature) -> str:
# This is needed in order for namespaced definitions to call into CPU fallbacks,
# which live in a different namespace.
# TODO: this logic will change if we implement a boxed CPU fallback.
if isinstance(native_or_external, ExternalBackendFunction) \
and native_or_external.metadata is None \
and not is_generated_wrapper:
# See Note [External Backends Follow Dispatcher convention]
kernel_name = f'{self.cpp_namespace}::AtenXlaTypeDefault::{dispatcher.name(f.func)}'
else:
kernel_name = sig.name()
return f"""
{cpp_sig.defn()} {{
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
return {kernel_name}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""
result = generate_defn(cpp_sig_group.signature)
@ -152,7 +211,11 @@ return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), si
}}
"""
impl_name = f"at::native::{f.dispatch[self.dispatch_key]}"
if isinstance(native_or_external, ExternalBackendFunction):
# TODO: remove this difference and merge the two cases when we remove xla-specific logic
impl_name = f"{self.cpp_namespace}::AtenXlaType::{f.dispatch[self.dispatch_key]}"
else:
impl_name = f"{self.cpp_namespace}::{f.dispatch[self.dispatch_key]}"
args_exprs_str = ', '.join(a.name for a in args)
@ -196,12 +259,74 @@ namespace {{
if f.manual_kernel_registration:
return None
else:
dispatcher_sig = DispatcherSignature.from_schema(f.func)
payload = f"TORCH_FN({name})"
return f'm.impl("{f.func.name}",\n{payload});\n'
else:
assert_never(self.target)
def external_backend_wrapper_sig(self, f: ExternalBackendFunction) -> DispatcherSignature:
# See Note [External Backends Follow Dispatcher convention]
return DispatcherSignature.from_schema(f.native_function.func, prefix='wrapper_', append_overload_name=True)
def gen_out_inplace_wrappers(self, g: ExternalBackendFunctionsGroup) -> List[str]:
def gets_generated_out_inplace_wrapper(f: ExternalBackendFunction) -> bool:
return f.native_function.func.kind() is not SchemaKind.functional \
and f.metadata is None and g.functional.metadata is not None
@with_native_function
def gen_wrapper(f: ExternalBackendFunction) -> Optional[str]:
# Only anonymous definitions get "special treatment" for out/inplace wrappers.
# All other functionality can be directly pulled from gen_unstructured.
if self.target is not Target.ANONYMOUS_DEFINITION:
is_generated_wrapper = gets_generated_out_inplace_wrapper(f)
return self.gen_unstructured(f, is_generated_wrapper=is_generated_wrapper)
if f.native_function.func.kind() is SchemaKind.functional:
# Wrappers are generated for out/inplace kernels, using the functional kernel,
# so the functional kernel itself is generated normally.
return self.gen_unstructured(f)
if f.metadata is not None:
# If the backend provided their own out/inplace kernel, use it.
return self.gen_unstructured(f)
if not gets_generated_out_inplace_wrapper(f):
return None
# Special out/inplace wrapper logic starts here.
dispatcher_sig = self.external_backend_wrapper_sig(f)
name = dispatcher_sig.name()
# See Note [External Backends Follow Dispatcher convention]
dispatcher_order_args = dispatcher.jit_arguments(f.native_function.func)
tensors = [a for a in dispatcher_order_args if a.type == BaseType(BaseTy.Tensor)]
print_args_str = ''.join([f' << " {a.name}=" << {a.name}.toString()' for a in tensors])
functional_result_name = f'{name}_tmp'
return_names = cpp.return_names(f.native_function)
if len(return_names) > 1:
updates = '\n '.join(
f'at::_copy_from_and_resize(std::get<{i}>({functional_result_name}), {ret_name});'
for i, ret_name in enumerate(return_names))
returns = f'{dispatcher_sig.returns_type().cpp_type()}({", ".join(return_names)})'
else:
ret_name = return_names[0]
updates = f'at::_copy_from_and_resize({functional_result_name}, {ret_name});'
returns = ret_name
functional_sig = self.external_backend_wrapper_sig(g.functional)
return f"""\
{dispatcher_sig.defn()} {{
XLA_FN_TRACK(3);
TF_VLOG(3) << "XLA {name} :"{print_args_str};
auto {functional_result_name} = {functional_sig.name()}({", ".join(a.name for a in functional_sig.arguments())});
{updates}
return {returns};
}}
"""
return list(mapMaybe(gen_wrapper, g.functions(functional_first=True)))
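As a concrete illustration of the template above, here is roughly what gen_out_inplace_wrappers emits for abs.out when the backend implements only the functional abs kernel (a sketch; exact reference types are assumptions, and the name follows prefix + cpp name + overload name):

# Approximate generated C++, shown as the string the codegen would produce.
# Name derivation: 'wrapper_' + cpp name ('abs_out') + '_' + overload ('out').
expected_out_wrapper = """
at::Tensor & wrapper_abs_out_out(const at::Tensor & self, at::Tensor & out) {
  XLA_FN_TRACK(3);
  TF_VLOG(3) << "XLA wrapper_abs_out_out :" << " self=" << self.toString() << " out=" << out.toString();
  auto wrapper_abs_out_out_tmp = wrapper_abs(self);
  at::_copy_from_and_resize(wrapper_abs_out_out_tmp, out);
  return out;
}
"""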
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#

View File

@ -916,20 +916,22 @@ def main() -> None:
'#include <ATen/LegacyTHFunctionsCPU.h>' if dispatch_key == DispatchKey.CPU else
'#include <ATen/LegacyTHFunctionsCUDA.h>' if dispatch_key == DispatchKey.CUDA else
'',
'external_backend_headers': '',
'DispatchKey': dispatch_key,
'dispatch_namespace': dispatch_key.lower(),
'dispatch_namespaced_definitions': list(concatMap(
dest.RegisterDispatchKey(
dispatch_key, Target.NAMESPACED_DEFINITION, selector, rocm=options.rocm),
dispatch_key, Target.NAMESPACED_DEFINITION, selector, rocm=options.rocm, cpp_namespace='at::native'),
grouped_native_functions
)),
'dispatch_anonymous_definitions': list(concatMap(
dest.RegisterDispatchKey(
dispatch_key, Target.ANONYMOUS_DEFINITION, selector, rocm=options.rocm),
dispatch_key, Target.ANONYMOUS_DEFINITION, selector, rocm=options.rocm, cpp_namespace='at::native'),
grouped_native_functions
)),
'dispatch_registrations': list(concatMap(
dest.RegisterDispatchKey(dispatch_key, Target.REGISTRATION, selector, rocm=options.rocm),
dest.RegisterDispatchKey(
dispatch_key, Target.REGISTRATION, selector, rocm=options.rocm, cpp_namespace='at::native'),
grouped_native_functions
)),
})
@ -939,7 +941,7 @@ def main() -> None:
'dispatch_namespace': dispatch_key.lower(),
'dispatch_namespaced_declarations': list(concatMap(
dest.RegisterDispatchKey(
dispatch_key, Target.NAMESPACED_DECLARATION, selector, rocm=options.rocm),
dispatch_key, Target.NAMESPACED_DECLARATION, selector, rocm=options.rocm, cpp_namespace='at::native'),
grouped_native_functions
)),
})

View File

@ -2,7 +2,7 @@ import pathlib
import argparse
import os
import yaml
from typing import List, Dict, Union, Tuple, Sequence
from typing import List, Dict, Union, Tuple, Sequence, Optional
from tools.codegen.gen import FileManager, get_grouped_native_functions, parse_native_yaml
from tools.codegen.model import (ExternalBackendFunction, ExternalBackendFunctionsGroup,
NativeFunction, NativeFunctionsGroup, OperatorName,
@ -17,13 +17,14 @@ try:
# use faster C loader if available
from yaml import CSafeLoader as Loader
except ImportError:
from yaml import SafeLoader as Loader # type: ignore
from yaml import SafeLoader as Loader # type: ignore[misc]
# Parses the backend's yaml file and returns (BackendDispatchKey, AutogradDispatchKey, cpp_namespace, backend_kernel_list)
def parse_backend_yaml(
backend_yaml_path: str,
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]]
) -> Tuple[str, List[Union[ExternalBackendFunction, ExternalBackendFunctionsGroup]]]:
) -> Tuple[DispatchKey, Optional[DispatchKey], str, List[Union[ExternalBackendFunction, ExternalBackendFunctionsGroup]]]:
with open(backend_yaml_path, 'r') as f:
yaml_values = yaml.load(f, Loader=Loader)
assert isinstance(yaml_values, dict)
@ -71,7 +72,8 @@ Only the following keys are supported: {", ".join(valid_keys)}'
for f in concatMap(lambda f: [f] if isinstance(f, NativeFunction) else list(f.functions()), grouped_native_functions)
}
def kernel_name(func: FunctionSchema) -> str:
def external_kernel_name(func: FunctionSchema) -> str:
# Note [External Backends Follow Dispatcher convention]
# For external backends, we enforce that their names and signatures match the dispatcher convention
return dispatcher.name(func)
@ -83,17 +85,17 @@ Only the following keys are supported: {", ".join(valid_keys)}'
m = metadata.get(f.func.name, None)
dispatch_key = DispatchKey.parse(f'Autograd{backend}') \
if m is not None and m.is_autograd else DispatchKey.parse(backend)
kernel = kernel_name(f.func)
kernel = external_kernel_name(f.func)
return ExternalBackendFunction(NativeFunction.with_dispatch_entry(f, dispatch_key, kernel), dispatch_key, m)
elif isinstance(g, NativeFunctionsGroup):
out_meta = metadata.get(g.out.func.name, None)
kernel = kernel_name(g.out.func)
kernel = external_kernel_name(g.out.func)
dispatch_key = DispatchKey.parse(f'Autograd{backend}') \
if out_meta is not None and out_meta.is_autograd else DispatchKey.parse(backend)
out = ExternalBackendFunction(NativeFunction.with_dispatch_entry(g.out, dispatch_key, kernel), dispatch_key, out_meta)
functional_meta = metadata.get(g.functional.func.name, None)
kernel = kernel_name(g.functional.func)
kernel = external_kernel_name(g.functional.func)
dispatch_key = DispatchKey.parse(f'Autograd{backend}') \
if functional_meta is not None and functional_meta.is_autograd else DispatchKey.parse(backend)
functional = ExternalBackendFunction(
@ -109,7 +111,7 @@ autograd key. They can not be mix and matched. If this is something you need, fe
inplace = None
if g.inplace:
inplace_meta = metadata.get(g.inplace.func.name, None)
kernel = kernel_name(g.inplace.func)
kernel = external_kernel_name(g.inplace.func)
dispatch_key = DispatchKey.parse(f'Autograd{backend}') \
if inplace_meta is not None and inplace_meta.is_autograd else DispatchKey.parse(backend)
inplace = ExternalBackendFunction(
@ -128,7 +130,7 @@ autograd key. They can not be mix and matched. If this is something you need, fe
assert_never(g)
for op_name in metadata.keys():
assert op_name in native_functions_map, f"Found an invalid operator name: {op_name}"
return cpp_namespace, [native_to_external(g) for g in grouped_native_functions]
return backend_key, autograd_key, cpp_namespace, [native_to_external(g) for g in grouped_native_functions]
def main() -> None:
parser = argparse.ArgumentParser(description='Generate backend stub files')
@ -157,7 +159,7 @@ def run(source_yaml: str, output_dir: str, dry_run: bool) -> None:
native_yaml_path = os.path.join(pytorch_root, 'aten/src/ATen/native/native_functions.yaml')
grouped_native_functions = get_grouped_native_functions(native_yaml_path)
cpp_namespace, external_backend_functions = parse_backend_yaml(source_yaml, grouped_native_functions)
backend_key, autograd_key, cpp_namespace, external_backend_functions = parse_backend_yaml(source_yaml, grouped_native_functions)
native_functions = parse_native_yaml(native_yaml_path)
@ -171,6 +173,68 @@ def run(source_yaml: str, output_dir: str, dry_run: bool) -> None:
'dispatch_xla_declarations': list(concatMap(dest.compute_native_function_declaration, external_backend_functions)),
})
external_backend_functions_no_autograd = [f for f in external_backend_functions if not f.is_autograd_kernel]
external_backend_functions_autograd = [f for f in external_backend_functions if f.is_autograd_kernel]
external_backend_headers = '''\
#include <tensorflow/compiler/xla/xla_client/debug_macros.h>
#include <tensorflow/compiler/xla/xla_client/metrics.h>
#include <tensorflow/compiler/xla/xla_client/tf_logging.h>
#include <torch_xla/csrc/function_call_tracker.h>
#include <torch_xla/csrc/aten_xla_type.h>
#include <torch_xla/csrc/aten_xla_type_default.h>'''
fm.write_with_template(f'Register{backend_key}.cpp', 'RegisterDispatchKey.cpp', lambda: {
'extra_cuda_headers': '',
'legacy_th_headers': '',
'external_backend_headers': external_backend_headers,
'DispatchKey': backend_key,
'dispatch_namespace': backend_key.lower(),
'dispatch_namespaced_definitions': list(concatMap(
dest.RegisterDispatchKey(
backend_key, Target.NAMESPACED_DEFINITION, selector, rocm=False, cpp_namespace=cpp_namespace),
external_backend_functions_no_autograd
)),
'dispatch_anonymous_definitions': list(concatMap(
dest.RegisterDispatchKey(
backend_key, Target.ANONYMOUS_DEFINITION, selector, rocm=False, cpp_namespace=cpp_namespace),
external_backend_functions_no_autograd
)),
'dispatch_registrations': list(concatMap(
dest.RegisterDispatchKey(backend_key, Target.REGISTRATION, selector, rocm=False, cpp_namespace=cpp_namespace),
external_backend_functions_no_autograd
)),
})
# If the backend has at least one autograd entry in its yaml file, generate a separate registration file for the autograd key.
if autograd_key is not None:
assert len(external_backend_functions_autograd) > 0
autograd_dispatchkey: DispatchKey = autograd_key # make mypy happy
fm.write_with_template(f'Register{autograd_key}.cpp', 'RegisterDispatchKey.cpp', lambda: {
'extra_cuda_headers': '',
'legacy_th_headers': '',
'external_backend_headers': external_backend_headers,
'DispatchKey': autograd_dispatchkey,
'dispatch_namespace': autograd_dispatchkey.lower(),
'dispatch_namespaced_definitions': list(concatMap(
dest.RegisterDispatchKey(
autograd_dispatchkey, Target.NAMESPACED_DEFINITION, selector, rocm=False, cpp_namespace=cpp_namespace),
external_backend_functions_autograd
)),
'dispatch_anonymous_definitions': list(concatMap(
dest.RegisterDispatchKey(
autograd_dispatchkey, Target.ANONYMOUS_DEFINITION, selector, rocm=False, cpp_namespace=cpp_namespace),
external_backend_functions_autograd
)),
'dispatch_registrations': list(concatMap(
dest.RegisterDispatchKey(
autograd_dispatchkey, Target.REGISTRATION, selector, rocm=False, cpp_namespace=cpp_namespace),
external_backend_functions_autograd
)),
})
fm.write('aten_xla_type_default.h', lambda: {
'generated_comment': generated_comment,
'cpp_namespace': cpp_namespace,
@ -188,10 +252,7 @@ def run(source_yaml: str, output_dir: str, dry_run: bool) -> None:
dest.GenExternalAtenFallback(Target.NAMESPACED_DEFINITION), external_backend_functions
)),
'dispatch_registrations': list(concatMap(
dest.GenExternalAtenFallback(Target.REGISTRATION), [e for e in external_backend_functions if not e.is_autograd_kernel]
)),
'dispatch_autograd_registrations': list(concatMap(
dest.GenExternalAtenFallback(Target.REGISTRATION), [e for e in external_backend_functions if e.is_autograd_kernel]
dest.GenExternalAtenFallback(Target.REGISTRATION), external_backend_functions
)),
})
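For orientation, a sketch of the input this consumes and the files it now produces (the backend/cpp_namespace/supported keys follow the parser above; the autograd list and the ops are assumptions for illustration):

# Hypothetical backend yaml:
example_yaml = """
backend: XLA
cpp_namespace: torch_xla
supported:
- abs
- add.Tensor
autograd:
- max_pool2d
"""
# parse_backend_yaml then returns roughly
#   (DispatchKey.XLA, DispatchKey.AutogradXLA, 'torch_xla', external_backend_functions)
# and run() writes RegisterXLA.cpp plus RegisterAutogradXLA.cpp, both instantiated
# from the shared RegisterDispatchKey.cpp template.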

View File

@ -1452,11 +1452,19 @@ class ExternalBackendFunctionsGroup:
f"variant, {str(f.native_function.func.name)} will be generated for you " \
"and doesn't need to live in the yaml."
def functions(self) -> Iterator[ExternalBackendFunction]:
yield self.out
yield self.functional
if self.inplace is not None:
yield self.inplace
def functions(self, *, functional_first: bool = False) -> Iterator[ExternalBackendFunction]:
if not functional_first:
yield self.out
yield self.functional
if self.inplace is not None:
yield self.inplace
else:
# When we generate out/inplace wrappers for unstructured kernels, we need the functional kernel
# to be defined before we can generate the other two wrappers.
yield self.functional
yield self.out
if self.inplace is not None:
yield self.inplace
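A small illustration of the two orderings (hypothetical group for add; a sketch only):

# Default order vs. functional_first, for a group with out/functional/inplace:
#   list(g.functions())                       -> [add.out, add, add_]
#   list(g.functions(functional_first=True))  -> [add, add.out, add_]
# gen_out_inplace_wrappers uses functional_first=True so the functional wrapper
# is emitted before the out/inplace wrappers that call it.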
# Helper functions for parsing argument lists (both inputs and returns)