mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[WIP] generate in-place/out wrappers for external kernels
ghstack-source-id: 859e07e54a
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56835
This commit is contained in:
parent
10c4bf510e
commit
2fefb36870
|
|
@ -34,6 +34,7 @@
|
|||
#include <torch/library.h>
|
||||
$extra_cuda_headers
|
||||
$legacy_th_headers
|
||||
$external_backend_headers
|
||||
|
||||
namespace at {
|
||||
|
||||
|
|
|
|||
|
|
@ -108,10 +108,6 @@ ${dispatch_aten_fallback_definitions}
|
|||
TORCH_LIBRARY_IMPL(aten, XLA, m) {
|
||||
${dispatch_registrations}
|
||||
|
||||
}
|
||||
TORCH_LIBRARY_IMPL(aten, AutogradXLA, m) {
|
||||
${dispatch_autograd_registrations}
|
||||
|
||||
}
|
||||
|
||||
} // namespace torch_xla
|
||||
|
|
|
|||
|
|
@ -24,8 +24,13 @@ from typing import Sequence, List, Union
|
|||
# arguments.
|
||||
#
|
||||
|
||||
def name(func: FunctionSchema) -> str:
|
||||
return cpp.name(func)
|
||||
def name(func: FunctionSchema, *, append_overload_name: bool = False) -> str:
|
||||
name = cpp.name(func)
|
||||
if append_overload_name and func.name.overload_name != '':
|
||||
# This isn't an important characteristic of the dispatcher API, and could be removed if we want to further unify
|
||||
# The dispatcher and C++ API's. There happen to be a few places in the codegen where we need to guarantee name uniqueness.
|
||||
name = f'{name}_{func.name.overload_name}'
|
||||
return name
|
||||
|
||||
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
|
||||
# This is a faux amis. If it makes sense in the future to add
|
||||
|
|
|
|||
|
|
@ -406,11 +406,15 @@ class DispatcherSignature:
|
|||
# The schema this signature is derived from
|
||||
func: FunctionSchema
|
||||
|
||||
prefix: str
|
||||
|
||||
append_overload_name: bool
|
||||
|
||||
def arguments(self) -> List[Binding]:
|
||||
return dispatcher.arguments(self.func)
|
||||
|
||||
def name(self) -> str:
|
||||
return dispatcher.name(self.func)
|
||||
return self.prefix + dispatcher.name(self.func, append_overload_name=self.append_overload_name)
|
||||
|
||||
def decl(self, name: Optional[str] = None) -> str:
|
||||
args_str = ', '.join(a.decl() for a in self.arguments())
|
||||
|
|
@ -440,8 +444,8 @@ class DispatcherSignature:
|
|||
return f'{self.returns_type().cpp_type()} ({dispatcher_args_types_str})'
|
||||
|
||||
@staticmethod
|
||||
def from_schema(func: FunctionSchema) -> 'DispatcherSignature':
|
||||
return DispatcherSignature(func)
|
||||
def from_schema(func: FunctionSchema, *, prefix: str = '', append_overload_name: bool = False) -> 'DispatcherSignature':
|
||||
return DispatcherSignature(func, prefix, append_overload_name)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class NativeSignature:
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ F = TypeVar(
|
|||
ExternalBackendFunction,
|
||||
ExternalBackendFunctionsGroup,
|
||||
Union[NativeFunction, NativeFunctionsGroup],
|
||||
Union[NativeFunction, ExternalBackendFunction],
|
||||
Union[ExternalBackendFunctionsGroup, ExternalBackendFunction],
|
||||
Union[NativeFunction, NativeFunctionsGroup, ExternalBackendFunction, ExternalBackendFunctionsGroup]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import re
|
|||
from tools.codegen.context import method_with_native_function
|
||||
from tools.codegen.utils import Target, mapMaybe
|
||||
from tools.codegen.model import (Argument, ExternalBackendFunction,
|
||||
ExternalBackendFunctionsGroup, SchemaKind,
|
||||
ExternalBackendFunctionsGroup,
|
||||
assert_never, Return, is_generic_dispatch_key,
|
||||
ListType, OptionalType, BaseType, BaseTy, Variant)
|
||||
from tools.codegen.api.types import DispatcherSignature, CppSignatureGroup
|
||||
|
|
@ -102,7 +102,7 @@ def xla_tensor_creation_api(
|
|||
# do not have full aten coverage.
|
||||
# For operators not implemented by the external backend, our codegen
|
||||
# will register these fallbacks instead.
|
||||
# - Why do we generate fallback for ALL aten ops, including ops that
|
||||
# - Why do we generate fallback for ALL (non-composite) aten ops, including ops that
|
||||
# external backends have already implemented?
|
||||
# Many external backend kernels only work with specific input shapes,
|
||||
# and are written to call into a cpu fallback when given inputs
|
||||
|
|
@ -117,41 +117,6 @@ class GenExternalAtenFallback:
|
|||
|
||||
@method_with_native_function
|
||||
def __call__(self, g: Union[ExternalBackendFunctionsGroup, ExternalBackendFunction]) -> List[str]:
|
||||
|
||||
def gen_out_wrapper(g: ExternalBackendFunctionsGroup) -> Optional[str]:
|
||||
dispatcher_sig = DispatcherSignature.from_schema(g.out.native_function.func)
|
||||
name = dispatcher_sig.name()
|
||||
|
||||
dispatcher_order_args = dispatcher.jit_arguments(g.out.native_function.func)
|
||||
tensors = [a for a in dispatcher_order_args if a.type == BaseType(BaseTy.Tensor)]
|
||||
print_args_str = ''.join([f' << " {a.name}=" << {a.name}.toString()' for a in tensors])
|
||||
|
||||
func_name = f'AtenXlaTypeDefault::{name}'
|
||||
functional_result_name = f'{name}_tmp'
|
||||
return_names = cpp.return_names(g.out.native_function)
|
||||
if len(return_names) > 1:
|
||||
updates = '\n '.join(
|
||||
f'at::_copy_from_and_resize(std::get<{i}>({functional_result_name}), {ret_name});'
|
||||
for i, ret_name in enumerate(return_names))
|
||||
returns = f'{dispatcher_sig.returns_type().cpp_type()}({", ".join(return_names)})'
|
||||
else:
|
||||
ret_name = return_names[0]
|
||||
updates = f'at::_copy_from_and_resize({functional_result_name}, {ret_name});'
|
||||
returns = ret_name
|
||||
|
||||
functional_sig = DispatcherSignature.from_schema(g.functional.native_function.func)
|
||||
|
||||
return f"""\
|
||||
{dispatcher_sig.defn(name=func_name)} {{
|
||||
XLA_FN_TRACK(3);
|
||||
TF_VLOG(3) << "XLA {name} :"{print_args_str};
|
||||
auto {functional_result_name} = AtenXlaType::{functional_sig.name()}({", ".join(a.name for a in functional_sig.arguments())});
|
||||
{updates}
|
||||
return {returns};
|
||||
}}
|
||||
|
||||
"""
|
||||
|
||||
def gen_unstructured_external(f: ExternalBackendFunction) -> Optional[str]:
|
||||
if not requires_backend_wrapper(f):
|
||||
return None
|
||||
|
|
@ -174,8 +139,7 @@ class GenExternalAtenFallback:
|
|||
return device_like[0].name
|
||||
raise AssertionError("Need a tensor-like or device argument in order to determine the output device")
|
||||
|
||||
# XLA appears to have used the dispatcher convention to write their kernel signatures,
|
||||
# probably because they based their signatures off of our RegistrationDeclarations.h
|
||||
# See Note [External Backends Follow Dispatcher convention]
|
||||
dispatcher_sig = DispatcherSignature.from_schema(f.native_function.func)
|
||||
name = dispatcher_sig.name()
|
||||
args = dispatcher_sig.arguments()
|
||||
|
|
@ -184,25 +148,19 @@ class GenExternalAtenFallback:
|
|||
return f" static {dispatcher_sig.decl()};"
|
||||
|
||||
elif self.target is Target.REGISTRATION:
|
||||
if f.metadata is not None:
|
||||
# xla has their own kernel: register it
|
||||
namespace = 'AtenXlaType'
|
||||
else:
|
||||
# xla doesn't have a kernel: register the cpu fallback (or codegen'd out kernel).
|
||||
namespace = 'AtenXlaTypeDefault'
|
||||
payload = f"static_cast<{dispatcher_sig.ptr_type()}>(&{namespace}::{name})"
|
||||
# This codegen is only responsible for registering CPU fallback kernels
|
||||
# We also skip registrations if there is a functional backend kernel,
|
||||
# because we generate out/inplace wrappers in that case (handled in register_dispatch_key.py).
|
||||
if f.metadata is not None or (isinstance(g, ExternalBackendFunctionsGroup) and g.functional.metadata is not None):
|
||||
return ''
|
||||
payload = f"static_cast<{dispatcher_sig.ptr_type()}>(&AtenXlaTypeDefault::{name})"
|
||||
return f' m.impl("{f.native_function.func.name}", {payload});\n'
|
||||
|
||||
if self.target is not Target.NAMESPACED_DEFINITION:
|
||||
assert_never(self.target)
|
||||
|
||||
# Instead of generating a CPU fallback, the xla codegen generates out wrappers for a few hardcoded operators.
|
||||
# TODO: we should generate out wrappers for ALL valid out kernels; not just ones in xla's hardcoded list
|
||||
if f.native_function.func.kind() is SchemaKind.out and str(f.native_function.func.name.name) in _FN_OUT \
|
||||
and isinstance(g, ExternalBackendFunctionsGroup):
|
||||
return gen_out_wrapper(g)
|
||||
|
||||
# Everything below here is where we generate the CPU fallback.
|
||||
# See Note [External Backends Follow Dispatcher convention]
|
||||
dispatcher_order_args = dispatcher.jit_arguments(f.native_function.func)
|
||||
|
||||
# Map each argument to it's intermediate variable name in the fallback
|
||||
|
|
|
|||
|
|
@ -4,12 +4,13 @@ from typing_extensions import Literal
|
|||
from dataclasses import dataclass
|
||||
import textwrap
|
||||
|
||||
from tools.codegen.context import method_with_native_function
|
||||
from tools.codegen.context import method_with_native_function, with_native_function
|
||||
from tools.codegen.utils import Target, mapMaybe
|
||||
from tools.codegen.model import (DispatchKey, NativeFunction,
|
||||
NativeFunctionsGroup, SchemaKind,
|
||||
ExternalBackendFunctionsGroup, ExternalBackendFunction,
|
||||
TensorOptionsArguments, assert_never,
|
||||
is_cuda_dispatch_key,
|
||||
is_cuda_dispatch_key, BaseType, BaseTy,
|
||||
is_structured_dispatch_key)
|
||||
from tools.codegen.api.types import (BaseCType, Binding, ConstRefCType,
|
||||
CppSignature, CppSignatureGroup,
|
||||
|
|
@ -17,6 +18,9 @@ from tools.codegen.api.types import (BaseCType, Binding, ConstRefCType,
|
|||
NativeSignature, tensorT, NamedCType)
|
||||
import tools.codegen.api.meta as meta
|
||||
import tools.codegen.api.structured as structured
|
||||
import tools.codegen.api.dispatcher as dispatcher
|
||||
import tools.codegen.api.cpp as cpp
|
||||
from tools.codegen.dest.gen_external_aten_fallbacks import requires_backend_wrapper
|
||||
from tools.codegen.api.translate import translate
|
||||
from tools.codegen.selective_build.selector import SelectiveBuilder
|
||||
|
||||
|
|
@ -56,14 +60,35 @@ class RegisterDispatchKey:
|
|||
# Whether or not we are actually code-genning for ROCm
|
||||
rocm: bool
|
||||
|
||||
# The namespace that the kernels are written in. This is just `at::native` for in-tree kernels.
|
||||
cpp_namespace: str
|
||||
|
||||
@method_with_native_function
|
||||
def __call__(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]:
|
||||
def __call__(self, f: Union[
|
||||
NativeFunctionsGroup,
|
||||
NativeFunction,
|
||||
ExternalBackendFunctionsGroup,
|
||||
ExternalBackendFunction]
|
||||
) -> List[str]:
|
||||
if isinstance(f, NativeFunctionsGroup):
|
||||
if f.structured:
|
||||
return self.gen_structured(f)
|
||||
else:
|
||||
return list(mapMaybe(self.gen_unstructured, f.functions()))
|
||||
elif isinstance(f, NativeFunction):
|
||||
elif isinstance(f, ExternalBackendFunctionsGroup):
|
||||
if f.structured:
|
||||
raise AssertionError("structured kernels not implemented yet for external backends")
|
||||
elif f.primary == f.functional:
|
||||
# For external backends that specify that they'd like to primarily implement functional kernels (namely XLA),
|
||||
# we can generate anonymous wrappers for the out and in-place kernels for them, even for un-structured operators.
|
||||
# Note that we can't go the other way around (generate the functional using the out), since we don't know
|
||||
# how to create the output tensor without a meta function.
|
||||
return self.gen_out_inplace_wrappers(f)
|
||||
else:
|
||||
# For backends that implement out kernels, they need to port their ops to structured
|
||||
# if they want generated functional/inplace kernels.
|
||||
return list(mapMaybe(self.gen_unstructured, f.functions()))
|
||||
elif isinstance(f, NativeFunction) or isinstance(f, ExternalBackendFunction):
|
||||
r = self.gen_unstructured(f)
|
||||
return [] if r is None else [r]
|
||||
else:
|
||||
|
|
@ -88,12 +113,38 @@ class RegisterDispatchKey:
|
|||
self.target,
|
||||
self.selector,
|
||||
self.rocm,
|
||||
self.cpp_namespace,
|
||||
g
|
||||
)
|
||||
return list(mapMaybe(structured_gen.gen_one, g.functions()))
|
||||
|
||||
@method_with_native_function
|
||||
def gen_unstructured(self, f: NativeFunction) -> Optional[str]:
|
||||
def gen_unstructured(
|
||||
self,
|
||||
native_or_external: Union[NativeFunction, ExternalBackendFunction],
|
||||
*,
|
||||
# Only applies to ExternalBackendFunction objects.
|
||||
# True for inplace/out functions that don't have kernels, but do have corresponding functional kernels.
|
||||
is_generated_wrapper: bool = False,
|
||||
) -> Optional[str]:
|
||||
sig: Union[NativeSignature, DispatcherSignature]
|
||||
if isinstance(native_or_external, ExternalBackendFunction):
|
||||
if not requires_backend_wrapper(native_or_external) and not is_generated_wrapper:
|
||||
return None
|
||||
# If the backend doesn't have a kernel, we don't generate an anonymous wrapper or a dispatcher registration for it
|
||||
# (fallbacks to CPU are generated elsewhere).
|
||||
if native_or_external.metadata is None and not is_generated_wrapper:
|
||||
# TODO: Right now, we generate "fast-path" `at::xla::{op}` kernels for all non-composite ops,
|
||||
# including those with CPU fallbacks. That logic will have to change if CPU fallbacks become a boxed kernel.
|
||||
if self.target is not Target.NAMESPACED_DEFINITION and self.target is not Target.NAMESPACED_DECLARATION:
|
||||
return None
|
||||
f = native_or_external.native_function
|
||||
sig = self.external_backend_wrapper_sig(native_or_external)
|
||||
elif isinstance(native_or_external, NativeFunction):
|
||||
f = native_or_external
|
||||
sig = NativeSignature(f.func, prefix='wrapper_')
|
||||
else:
|
||||
assert_never(f)
|
||||
|
||||
inplace_meta = False
|
||||
if self.dispatch_key not in f.dispatch:
|
||||
if (self.dispatch_key == DispatchKey.Meta and
|
||||
|
|
@ -112,8 +163,6 @@ class RegisterDispatchKey:
|
|||
if self.target is Target.REGISTRATION and not self.selector.is_native_function_selected(f):
|
||||
return None
|
||||
|
||||
sig = NativeSignature(f.func, prefix='wrapper_')
|
||||
|
||||
name = sig.name()
|
||||
returns_type = sig.returns_type().cpp_type()
|
||||
args = sig.arguments()
|
||||
|
|
@ -129,9 +178,19 @@ class RegisterDispatchKey:
|
|||
return result
|
||||
elif self.target is Target.NAMESPACED_DEFINITION:
|
||||
def generate_defn(cpp_sig: CppSignature) -> str:
|
||||
# This is needed in order for namespaced definitions to call into CPU fallbacks,
|
||||
# which live in a different namespace.
|
||||
# TODO: this logic will change if we implement a boxed CPU fallback.
|
||||
if isinstance(native_or_external, ExternalBackendFunction) \
|
||||
and native_or_external.metadata is None \
|
||||
and not is_generated_wrapper:
|
||||
# See Note [External Backends Follow Dispatcher convention]
|
||||
kernel_name = f'{self.cpp_namespace}::AtenXlaTypeDefault::{dispatcher.name(f.func)}'
|
||||
else:
|
||||
kernel_name = sig.name()
|
||||
return f"""
|
||||
{cpp_sig.defn()} {{
|
||||
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
|
||||
return {kernel_name}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
|
||||
}}
|
||||
"""
|
||||
result = generate_defn(cpp_sig_group.signature)
|
||||
|
|
@ -152,7 +211,11 @@ return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), si
|
|||
}}
|
||||
"""
|
||||
|
||||
impl_name = f"at::native::{f.dispatch[self.dispatch_key]}"
|
||||
if isinstance(native_or_external, ExternalBackendFunction):
|
||||
# TODO: remove this difference and merge the two cases when we remove xla-specific logic
|
||||
impl_name = f"{self.cpp_namespace}::AtenXlaType::{f.dispatch[self.dispatch_key]}"
|
||||
else:
|
||||
impl_name = f"{self.cpp_namespace}::{f.dispatch[self.dispatch_key]}"
|
||||
|
||||
args_exprs_str = ', '.join(a.name for a in args)
|
||||
|
||||
|
|
@ -196,12 +259,74 @@ namespace {{
|
|||
if f.manual_kernel_registration:
|
||||
return None
|
||||
else:
|
||||
dispatcher_sig = DispatcherSignature.from_schema(f.func)
|
||||
payload = f"TORCH_FN({name})"
|
||||
return f'm.impl("{f.func.name}",\n{payload});\n'
|
||||
else:
|
||||
assert_never(self.target)
|
||||
|
||||
def external_backend_wrapper_sig(self, f: ExternalBackendFunction) -> DispatcherSignature:
|
||||
# See Note [External Backends Follow Dispatcher convention]
|
||||
return DispatcherSignature.from_schema(f.native_function.func, prefix='wrapper_', append_overload_name=True)
|
||||
|
||||
def gen_out_inplace_wrappers(self, g: ExternalBackendFunctionsGroup) -> List[str]:
|
||||
def gets_generated_out_inplace_wrapper(f: ExternalBackendFunction) -> bool:
|
||||
return f.native_function.func.kind() is not SchemaKind.functional \
|
||||
and f.metadata is None and g.functional.metadata is not None
|
||||
|
||||
@with_native_function
|
||||
def gen_wrapper(f: ExternalBackendFunction) -> Optional[str]:
|
||||
# Only anonymous definitions get "special treatment" for out/inplace wrappers.
|
||||
# All other functionality can be directly pulled from gen_unstructured.
|
||||
if self.target is not Target.ANONYMOUS_DEFINITION:
|
||||
is_generated_wrapper = gets_generated_out_inplace_wrapper(f)
|
||||
return self.gen_unstructured(f, is_generated_wrapper=is_generated_wrapper)
|
||||
|
||||
if f.native_function.func.kind() is SchemaKind.functional:
|
||||
# Wrappers are generated for out/inplace kernels, using the functional kernel,
|
||||
# so the functional kernel itself is generated normally.
|
||||
return self.gen_unstructured(f)
|
||||
if f.metadata is not None:
|
||||
# If the backend provided their own out/inplace kernel, use it.
|
||||
return self.gen_unstructured(f)
|
||||
|
||||
if not gets_generated_out_inplace_wrapper(f):
|
||||
return None
|
||||
|
||||
# Special out/inplace wrapper logic starts here.
|
||||
dispatcher_sig = self.external_backend_wrapper_sig(f)
|
||||
name = dispatcher_sig.name()
|
||||
|
||||
# See Note [External Backends Follow Dispatcher convention]
|
||||
dispatcher_order_args = dispatcher.jit_arguments(f.native_function.func)
|
||||
tensors = [a for a in dispatcher_order_args if a.type == BaseType(BaseTy.Tensor)]
|
||||
print_args_str = ''.join([f' << " {a.name}=" << {a.name}.toString()' for a in tensors])
|
||||
|
||||
functional_result_name = f'{name}_tmp'
|
||||
return_names = cpp.return_names(f.native_function)
|
||||
if len(return_names) > 1:
|
||||
updates = '\n '.join(
|
||||
f'at::_copy_from_and_resize(std::get<{i}>({functional_result_name}), {ret_name});'
|
||||
for i, ret_name in enumerate(return_names))
|
||||
returns = f'{dispatcher_sig.returns_type().cpp_type()}({", ".join(return_names)})'
|
||||
else:
|
||||
ret_name = return_names[0]
|
||||
updates = f'at::_copy_from_and_resize({functional_result_name}, {ret_name});'
|
||||
returns = ret_name
|
||||
|
||||
functional_sig = self.external_backend_wrapper_sig(g.functional)
|
||||
|
||||
return f"""\
|
||||
{dispatcher_sig.defn()} {{
|
||||
XLA_FN_TRACK(3);
|
||||
TF_VLOG(3) << "XLA {name} :"{print_args_str};
|
||||
auto {functional_result_name} = {functional_sig.name()}({", ".join(a.name for a in functional_sig.arguments())});
|
||||
{updates}
|
||||
return {returns};
|
||||
}}
|
||||
|
||||
"""
|
||||
return list(mapMaybe(gen_wrapper, g.functions(functional_first=True)))
|
||||
|
||||
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
|
||||
#
|
||||
|
|
|
|||
|
|
@ -916,20 +916,22 @@ def main() -> None:
|
|||
'#include <ATen/LegacyTHFunctionsCPU.h>' if dispatch_key == DispatchKey.CPU else
|
||||
'#include <ATen/LegacyTHFunctionsCUDA.h>' if dispatch_key == DispatchKey.CUDA else
|
||||
'',
|
||||
'external_backend_headers': '',
|
||||
'DispatchKey': dispatch_key,
|
||||
'dispatch_namespace': dispatch_key.lower(),
|
||||
'dispatch_namespaced_definitions': list(concatMap(
|
||||
dest.RegisterDispatchKey(
|
||||
dispatch_key, Target.NAMESPACED_DEFINITION, selector, rocm=options.rocm),
|
||||
dispatch_key, Target.NAMESPACED_DEFINITION, selector, rocm=options.rocm, cpp_namespace='at::native'),
|
||||
grouped_native_functions
|
||||
)),
|
||||
'dispatch_anonymous_definitions': list(concatMap(
|
||||
dest.RegisterDispatchKey(
|
||||
dispatch_key, Target.ANONYMOUS_DEFINITION, selector, rocm=options.rocm),
|
||||
dispatch_key, Target.ANONYMOUS_DEFINITION, selector, rocm=options.rocm, cpp_namespace='at::native'),
|
||||
grouped_native_functions
|
||||
)),
|
||||
'dispatch_registrations': list(concatMap(
|
||||
dest.RegisterDispatchKey(dispatch_key, Target.REGISTRATION, selector, rocm=options.rocm),
|
||||
dest.RegisterDispatchKey(
|
||||
dispatch_key, Target.REGISTRATION, selector, rocm=options.rocm, cpp_namespace='at::native'),
|
||||
grouped_native_functions
|
||||
)),
|
||||
})
|
||||
|
|
@ -939,7 +941,7 @@ def main() -> None:
|
|||
'dispatch_namespace': dispatch_key.lower(),
|
||||
'dispatch_namespaced_declarations': list(concatMap(
|
||||
dest.RegisterDispatchKey(
|
||||
dispatch_key, Target.NAMESPACED_DECLARATION, selector, rocm=options.rocm),
|
||||
dispatch_key, Target.NAMESPACED_DECLARATION, selector, rocm=options.rocm, cpp_namespace='at::native'),
|
||||
grouped_native_functions
|
||||
)),
|
||||
})
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import pathlib
|
|||
import argparse
|
||||
import os
|
||||
import yaml
|
||||
from typing import List, Dict, Union, Tuple, Sequence
|
||||
from typing import List, Dict, Union, Tuple, Sequence, Optional
|
||||
from tools.codegen.gen import FileManager, get_grouped_native_functions, parse_native_yaml
|
||||
from tools.codegen.model import (ExternalBackendFunction, ExternalBackendFunctionsGroup,
|
||||
NativeFunction, NativeFunctionsGroup, OperatorName,
|
||||
|
|
@ -17,13 +17,14 @@ try:
|
|||
# use faster C loader if available
|
||||
from yaml import CSafeLoader as Loader
|
||||
except ImportError:
|
||||
from yaml import SafeLoader as Loader # type: ignore
|
||||
from yaml import SafeLoader as Loader # type: ignore[misc]
|
||||
|
||||
|
||||
# Parses the backend's yaml file and returns (BackendDispatchKey, AutogradDispatchKey, cpp_namespace, backend_kernel_list)
|
||||
def parse_backend_yaml(
|
||||
backend_yaml_path: str,
|
||||
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]]
|
||||
) -> Tuple[str, List[Union[ExternalBackendFunction, ExternalBackendFunctionsGroup]]]:
|
||||
) -> Tuple[DispatchKey, Optional[DispatchKey], str, List[Union[ExternalBackendFunction, ExternalBackendFunctionsGroup]]]:
|
||||
with open(backend_yaml_path, 'r') as f:
|
||||
yaml_values = yaml.load(f, Loader=Loader)
|
||||
assert isinstance(yaml_values, dict)
|
||||
|
|
@ -71,7 +72,8 @@ Only the following keys are supported: {", ".join(valid_keys)}'
|
|||
for f in concatMap(lambda f: [f] if isinstance(f, NativeFunction) else list(f.functions()), grouped_native_functions)
|
||||
}
|
||||
|
||||
def kernel_name(func: FunctionSchema) -> str:
|
||||
def external_kernel_name(func: FunctionSchema) -> str:
|
||||
# Note [External Backends Follow Dispatcher convention]
|
||||
# For external backends, we enforce that their names and signatures match the dispatcher convention
|
||||
return dispatcher.name(func)
|
||||
|
||||
|
|
@ -83,17 +85,17 @@ Only the following keys are supported: {", ".join(valid_keys)}'
|
|||
m = metadata.get(f.func.name, None)
|
||||
dispatch_key = DispatchKey.parse(f'Autograd{backend}') \
|
||||
if m is not None and m.is_autograd else DispatchKey.parse(backend)
|
||||
kernel = kernel_name(f.func)
|
||||
kernel = external_kernel_name(f.func)
|
||||
return ExternalBackendFunction(NativeFunction.with_dispatch_entry(f, dispatch_key, kernel), dispatch_key, m)
|
||||
elif isinstance(g, NativeFunctionsGroup):
|
||||
out_meta = metadata.get(g.out.func.name, None)
|
||||
kernel = kernel_name(g.out.func)
|
||||
kernel = external_kernel_name(g.out.func)
|
||||
dispatch_key = DispatchKey.parse(f'Autograd{backend}') \
|
||||
if out_meta is not None and out_meta.is_autograd else DispatchKey.parse(backend)
|
||||
out = ExternalBackendFunction(NativeFunction.with_dispatch_entry(g.out, dispatch_key, kernel), dispatch_key, out_meta)
|
||||
|
||||
functional_meta = metadata.get(g.functional.func.name, None)
|
||||
kernel = kernel_name(g.functional.func)
|
||||
kernel = external_kernel_name(g.functional.func)
|
||||
dispatch_key = DispatchKey.parse(f'Autograd{backend}') \
|
||||
if functional_meta is not None and functional_meta.is_autograd else DispatchKey.parse(backend)
|
||||
functional = ExternalBackendFunction(
|
||||
|
|
@ -109,7 +111,7 @@ autograd key. They can not be mix and matched. If this is something you need, fe
|
|||
inplace = None
|
||||
if g.inplace:
|
||||
inplace_meta = metadata.get(g.inplace.func.name, None)
|
||||
kernel = kernel_name(g.inplace.func)
|
||||
kernel = external_kernel_name(g.inplace.func)
|
||||
dispatch_key = DispatchKey.parse(f'Autograd{backend}') \
|
||||
if inplace_meta is not None and inplace_meta.is_autograd else DispatchKey.parse(backend)
|
||||
inplace = ExternalBackendFunction(
|
||||
|
|
@ -128,7 +130,7 @@ autograd key. They can not be mix and matched. If this is something you need, fe
|
|||
assert_never(g)
|
||||
for op_name in metadata.keys():
|
||||
assert op_name in native_functions_map, f"Found an invalid operator name: {op_name}"
|
||||
return cpp_namespace, [native_to_external(g) for g in grouped_native_functions]
|
||||
return backend_key, autograd_key, cpp_namespace, [native_to_external(g) for g in grouped_native_functions]
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description='Generate backend stub files')
|
||||
|
|
@ -157,7 +159,7 @@ def run(source_yaml: str, output_dir: str, dry_run: bool) -> None:
|
|||
|
||||
native_yaml_path = os.path.join(pytorch_root, 'aten/src/ATen/native/native_functions.yaml')
|
||||
grouped_native_functions = get_grouped_native_functions(native_yaml_path)
|
||||
cpp_namespace, external_backend_functions = parse_backend_yaml(source_yaml, grouped_native_functions)
|
||||
backend_key, autograd_key, cpp_namespace, external_backend_functions = parse_backend_yaml(source_yaml, grouped_native_functions)
|
||||
|
||||
native_functions = parse_native_yaml(native_yaml_path)
|
||||
|
||||
|
|
@ -171,6 +173,68 @@ def run(source_yaml: str, output_dir: str, dry_run: bool) -> None:
|
|||
'dispatch_xla_declarations': list(concatMap(dest.compute_native_function_declaration, external_backend_functions)),
|
||||
})
|
||||
|
||||
|
||||
external_backend_functions_no_autograd = [f for f in external_backend_functions if not f.is_autograd_kernel]
|
||||
external_backend_functions_autograd = [f for f in external_backend_functions if f.is_autograd_kernel]
|
||||
|
||||
external_backend_headers = '''\
|
||||
#include <tensorflow/compiler/xla/xla_client/debug_macros.h>
|
||||
#include <tensorflow/compiler/xla/xla_client/metrics.h>
|
||||
#include <tensorflow/compiler/xla/xla_client/tf_logging.h>
|
||||
#include <torch_xla/csrc/function_call_tracker.h>
|
||||
#include <torch_xla/csrc/aten_xla_type.h>
|
||||
#include <torch_xla/csrc/aten_xla_type_default.h>'''
|
||||
|
||||
fm.write_with_template(f'Register{backend_key}.cpp', 'RegisterDispatchKey.cpp', lambda: {
|
||||
'extra_cuda_headers': '',
|
||||
'legacy_th_headers': '',
|
||||
'external_backend_headers': external_backend_headers,
|
||||
'DispatchKey': backend_key,
|
||||
'dispatch_namespace': backend_key.lower(),
|
||||
'dispatch_namespaced_definitions': list(concatMap(
|
||||
dest.RegisterDispatchKey(
|
||||
backend_key, Target.NAMESPACED_DEFINITION, selector, rocm=False, cpp_namespace=cpp_namespace),
|
||||
external_backend_functions_no_autograd
|
||||
)),
|
||||
'dispatch_anonymous_definitions': list(concatMap(
|
||||
dest.RegisterDispatchKey(
|
||||
backend_key, Target.ANONYMOUS_DEFINITION, selector, rocm=False, cpp_namespace=cpp_namespace),
|
||||
external_backend_functions_no_autograd
|
||||
)),
|
||||
'dispatch_registrations': list(concatMap(
|
||||
dest.RegisterDispatchKey(backend_key, Target.REGISTRATION, selector, rocm=False, cpp_namespace=cpp_namespace),
|
||||
external_backend_functions_no_autograd
|
||||
)),
|
||||
})
|
||||
|
||||
# If they have at least one autograd entry in their yaml file
|
||||
if autograd_key is not None:
|
||||
assert len(external_backend_functions_autograd) > 0
|
||||
autograd_dispatchkey: DispatchKey = autograd_key # make mypy happy
|
||||
fm.write_with_template(f'Register{autograd_key}.cpp', 'RegisterDispatchKey.cpp', lambda: {
|
||||
'extra_cuda_headers': '',
|
||||
'legacy_th_headers': '',
|
||||
'external_backend_headers': external_backend_headers,
|
||||
'DispatchKey': autograd_dispatchkey,
|
||||
'dispatch_namespace': autograd_dispatchkey.lower(),
|
||||
'dispatch_namespaced_definitions': list(concatMap(
|
||||
dest.RegisterDispatchKey(
|
||||
autograd_dispatchkey, Target.NAMESPACED_DEFINITION, selector, rocm=False, cpp_namespace=cpp_namespace),
|
||||
external_backend_functions_autograd
|
||||
)),
|
||||
'dispatch_anonymous_definitions': list(concatMap(
|
||||
dest.RegisterDispatchKey(
|
||||
autograd_dispatchkey, Target.ANONYMOUS_DEFINITION, selector, rocm=False, cpp_namespace=cpp_namespace),
|
||||
external_backend_functions_autograd
|
||||
)),
|
||||
'dispatch_registrations': list(concatMap(
|
||||
dest.RegisterDispatchKey(
|
||||
autograd_dispatchkey, Target.REGISTRATION, selector, rocm=False, cpp_namespace=cpp_namespace),
|
||||
external_backend_functions_autograd
|
||||
)),
|
||||
})
|
||||
|
||||
|
||||
fm.write('aten_xla_type_default.h', lambda: {
|
||||
'generated_comment': generated_comment,
|
||||
'cpp_namespace': cpp_namespace,
|
||||
|
|
@ -188,10 +252,7 @@ def run(source_yaml: str, output_dir: str, dry_run: bool) -> None:
|
|||
dest.GenExternalAtenFallback(Target.NAMESPACED_DEFINITION), external_backend_functions
|
||||
)),
|
||||
'dispatch_registrations': list(concatMap(
|
||||
dest.GenExternalAtenFallback(Target.REGISTRATION), [e for e in external_backend_functions if not e.is_autograd_kernel]
|
||||
)),
|
||||
'dispatch_autograd_registrations': list(concatMap(
|
||||
dest.GenExternalAtenFallback(Target.REGISTRATION), [e for e in external_backend_functions if e.is_autograd_kernel]
|
||||
dest.GenExternalAtenFallback(Target.REGISTRATION), external_backend_functions
|
||||
)),
|
||||
})
|
||||
|
||||
|
|
|
|||
|
|
@ -1452,11 +1452,19 @@ class ExternalBackendFunctionsGroup:
|
|||
f"variant, {str(f.native_function.func.name)} will be generated for you " \
|
||||
"and doesn't need to live in the yaml."
|
||||
|
||||
def functions(self) -> Iterator[ExternalBackendFunction]:
|
||||
yield self.out
|
||||
yield self.functional
|
||||
if self.inplace is not None:
|
||||
yield self.inplace
|
||||
def functions(self, *, functional_first: bool = False) -> Iterator[ExternalBackendFunction]:
|
||||
if not functional_first:
|
||||
yield self.out
|
||||
yield self.functional
|
||||
if self.inplace is not None:
|
||||
yield self.inplace
|
||||
else:
|
||||
# When we generate out/inplace wrappers for unstructured kernels, we need the functional kernel
|
||||
# to be defined before we can generate the other two wrappers.
|
||||
yield self.functional
|
||||
yield self.out
|
||||
if self.inplace is not None:
|
||||
yield self.inplace
|
||||
|
||||
|
||||
# Helper functions for parsing argument lists (both inputs and returns)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user