mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[WIP] generate in-place/out wrappers for external kernels
ghstack-source-id: 859e07e54a
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56835
This commit is contained in:
parent
10c4bf510e
commit
2fefb36870
|
|
@ -34,6 +34,7 @@
|
||||||
#include <torch/library.h>
|
#include <torch/library.h>
|
||||||
$extra_cuda_headers
|
$extra_cuda_headers
|
||||||
$legacy_th_headers
|
$legacy_th_headers
|
||||||
|
$external_backend_headers
|
||||||
|
|
||||||
namespace at {
|
namespace at {
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -108,10 +108,6 @@ ${dispatch_aten_fallback_definitions}
|
||||||
TORCH_LIBRARY_IMPL(aten, XLA, m) {
|
TORCH_LIBRARY_IMPL(aten, XLA, m) {
|
||||||
${dispatch_registrations}
|
${dispatch_registrations}
|
||||||
|
|
||||||
}
|
|
||||||
TORCH_LIBRARY_IMPL(aten, AutogradXLA, m) {
|
|
||||||
${dispatch_autograd_registrations}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace torch_xla
|
} // namespace torch_xla
|
||||||
|
|
|
||||||
|
|
@ -24,8 +24,13 @@ from typing import Sequence, List, Union
|
||||||
# arguments.
|
# arguments.
|
||||||
#
|
#
|
||||||
|
|
||||||
def name(func: FunctionSchema) -> str:
|
def name(func: FunctionSchema, *, append_overload_name: bool = False) -> str:
|
||||||
return cpp.name(func)
|
name = cpp.name(func)
|
||||||
|
if append_overload_name and func.name.overload_name != '':
|
||||||
|
# This isn't an important characteristic of the dispatcher API, and could be removed if we want to further unify
|
||||||
|
# The dispatcher and C++ API's. There happen to be a few places in the codegen where we need to guarantee name uniqueness.
|
||||||
|
name = f'{name}_{func.name.overload_name}'
|
||||||
|
return name
|
||||||
|
|
||||||
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
|
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
|
||||||
# This is a faux amis. If it makes sense in the future to add
|
# This is a faux amis. If it makes sense in the future to add
|
||||||
|
|
|
||||||
|
|
@ -406,11 +406,15 @@ class DispatcherSignature:
|
||||||
# The schema this signature is derived from
|
# The schema this signature is derived from
|
||||||
func: FunctionSchema
|
func: FunctionSchema
|
||||||
|
|
||||||
|
prefix: str
|
||||||
|
|
||||||
|
append_overload_name: bool
|
||||||
|
|
||||||
def arguments(self) -> List[Binding]:
|
def arguments(self) -> List[Binding]:
|
||||||
return dispatcher.arguments(self.func)
|
return dispatcher.arguments(self.func)
|
||||||
|
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
return dispatcher.name(self.func)
|
return self.prefix + dispatcher.name(self.func, append_overload_name=self.append_overload_name)
|
||||||
|
|
||||||
def decl(self, name: Optional[str] = None) -> str:
|
def decl(self, name: Optional[str] = None) -> str:
|
||||||
args_str = ', '.join(a.decl() for a in self.arguments())
|
args_str = ', '.join(a.decl() for a in self.arguments())
|
||||||
|
|
@ -440,8 +444,8 @@ class DispatcherSignature:
|
||||||
return f'{self.returns_type().cpp_type()} ({dispatcher_args_types_str})'
|
return f'{self.returns_type().cpp_type()} ({dispatcher_args_types_str})'
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_schema(func: FunctionSchema) -> 'DispatcherSignature':
|
def from_schema(func: FunctionSchema, *, prefix: str = '', append_overload_name: bool = False) -> 'DispatcherSignature':
|
||||||
return DispatcherSignature(func)
|
return DispatcherSignature(func, prefix, append_overload_name)
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class NativeSignature:
|
class NativeSignature:
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ F = TypeVar(
|
||||||
ExternalBackendFunction,
|
ExternalBackendFunction,
|
||||||
ExternalBackendFunctionsGroup,
|
ExternalBackendFunctionsGroup,
|
||||||
Union[NativeFunction, NativeFunctionsGroup],
|
Union[NativeFunction, NativeFunctionsGroup],
|
||||||
|
Union[NativeFunction, ExternalBackendFunction],
|
||||||
Union[ExternalBackendFunctionsGroup, ExternalBackendFunction],
|
Union[ExternalBackendFunctionsGroup, ExternalBackendFunction],
|
||||||
Union[NativeFunction, NativeFunctionsGroup, ExternalBackendFunction, ExternalBackendFunctionsGroup]
|
Union[NativeFunction, NativeFunctionsGroup, ExternalBackendFunction, ExternalBackendFunctionsGroup]
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ import re
|
||||||
from tools.codegen.context import method_with_native_function
|
from tools.codegen.context import method_with_native_function
|
||||||
from tools.codegen.utils import Target, mapMaybe
|
from tools.codegen.utils import Target, mapMaybe
|
||||||
from tools.codegen.model import (Argument, ExternalBackendFunction,
|
from tools.codegen.model import (Argument, ExternalBackendFunction,
|
||||||
ExternalBackendFunctionsGroup, SchemaKind,
|
ExternalBackendFunctionsGroup,
|
||||||
assert_never, Return, is_generic_dispatch_key,
|
assert_never, Return, is_generic_dispatch_key,
|
||||||
ListType, OptionalType, BaseType, BaseTy, Variant)
|
ListType, OptionalType, BaseType, BaseTy, Variant)
|
||||||
from tools.codegen.api.types import DispatcherSignature, CppSignatureGroup
|
from tools.codegen.api.types import DispatcherSignature, CppSignatureGroup
|
||||||
|
|
@ -102,7 +102,7 @@ def xla_tensor_creation_api(
|
||||||
# do not have full aten coverage.
|
# do not have full aten coverage.
|
||||||
# For operators not implemented by the external backend, our codegen
|
# For operators not implemented by the external backend, our codegen
|
||||||
# will register these fallbacks instead.
|
# will register these fallbacks instead.
|
||||||
# - Why do we generate fallback for ALL aten ops, including ops that
|
# - Why do we generate fallback for ALL (non-composite) aten ops, including ops that
|
||||||
# external backends have already implemented?
|
# external backends have already implemented?
|
||||||
# Many external backend kernels only work with specific input shapes,
|
# Many external backend kernels only work with specific input shapes,
|
||||||
# and are written to call into a cpu fallback when given inputs
|
# and are written to call into a cpu fallback when given inputs
|
||||||
|
|
@ -117,41 +117,6 @@ class GenExternalAtenFallback:
|
||||||
|
|
||||||
@method_with_native_function
|
@method_with_native_function
|
||||||
def __call__(self, g: Union[ExternalBackendFunctionsGroup, ExternalBackendFunction]) -> List[str]:
|
def __call__(self, g: Union[ExternalBackendFunctionsGroup, ExternalBackendFunction]) -> List[str]:
|
||||||
|
|
||||||
def gen_out_wrapper(g: ExternalBackendFunctionsGroup) -> Optional[str]:
|
|
||||||
dispatcher_sig = DispatcherSignature.from_schema(g.out.native_function.func)
|
|
||||||
name = dispatcher_sig.name()
|
|
||||||
|
|
||||||
dispatcher_order_args = dispatcher.jit_arguments(g.out.native_function.func)
|
|
||||||
tensors = [a for a in dispatcher_order_args if a.type == BaseType(BaseTy.Tensor)]
|
|
||||||
print_args_str = ''.join([f' << " {a.name}=" << {a.name}.toString()' for a in tensors])
|
|
||||||
|
|
||||||
func_name = f'AtenXlaTypeDefault::{name}'
|
|
||||||
functional_result_name = f'{name}_tmp'
|
|
||||||
return_names = cpp.return_names(g.out.native_function)
|
|
||||||
if len(return_names) > 1:
|
|
||||||
updates = '\n '.join(
|
|
||||||
f'at::_copy_from_and_resize(std::get<{i}>({functional_result_name}), {ret_name});'
|
|
||||||
for i, ret_name in enumerate(return_names))
|
|
||||||
returns = f'{dispatcher_sig.returns_type().cpp_type()}({", ".join(return_names)})'
|
|
||||||
else:
|
|
||||||
ret_name = return_names[0]
|
|
||||||
updates = f'at::_copy_from_and_resize({functional_result_name}, {ret_name});'
|
|
||||||
returns = ret_name
|
|
||||||
|
|
||||||
functional_sig = DispatcherSignature.from_schema(g.functional.native_function.func)
|
|
||||||
|
|
||||||
return f"""\
|
|
||||||
{dispatcher_sig.defn(name=func_name)} {{
|
|
||||||
XLA_FN_TRACK(3);
|
|
||||||
TF_VLOG(3) << "XLA {name} :"{print_args_str};
|
|
||||||
auto {functional_result_name} = AtenXlaType::{functional_sig.name()}({", ".join(a.name for a in functional_sig.arguments())});
|
|
||||||
{updates}
|
|
||||||
return {returns};
|
|
||||||
}}
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
def gen_unstructured_external(f: ExternalBackendFunction) -> Optional[str]:
|
def gen_unstructured_external(f: ExternalBackendFunction) -> Optional[str]:
|
||||||
if not requires_backend_wrapper(f):
|
if not requires_backend_wrapper(f):
|
||||||
return None
|
return None
|
||||||
|
|
@ -174,8 +139,7 @@ class GenExternalAtenFallback:
|
||||||
return device_like[0].name
|
return device_like[0].name
|
||||||
raise AssertionError("Need a tensor-like or device argument in order to determine the output device")
|
raise AssertionError("Need a tensor-like or device argument in order to determine the output device")
|
||||||
|
|
||||||
# XLA appears to have used the dispatcher convention to write their kernel signatures,
|
# See Note [External Backends Follow Dispatcher convention]
|
||||||
# probably because they based their signatures off of our RegistrationDeclarations.h
|
|
||||||
dispatcher_sig = DispatcherSignature.from_schema(f.native_function.func)
|
dispatcher_sig = DispatcherSignature.from_schema(f.native_function.func)
|
||||||
name = dispatcher_sig.name()
|
name = dispatcher_sig.name()
|
||||||
args = dispatcher_sig.arguments()
|
args = dispatcher_sig.arguments()
|
||||||
|
|
@ -184,25 +148,19 @@ class GenExternalAtenFallback:
|
||||||
return f" static {dispatcher_sig.decl()};"
|
return f" static {dispatcher_sig.decl()};"
|
||||||
|
|
||||||
elif self.target is Target.REGISTRATION:
|
elif self.target is Target.REGISTRATION:
|
||||||
if f.metadata is not None:
|
# This codegen is only responsible for registering CPU fallback kernels
|
||||||
# xla has their own kernel: register it
|
# We also skip registrations if there is a functional backend kernel,
|
||||||
namespace = 'AtenXlaType'
|
# because we generate out/inplace wrappers in that case (handled in register_dispatch_key.py).
|
||||||
else:
|
if f.metadata is not None or (isinstance(g, ExternalBackendFunctionsGroup) and g.functional.metadata is not None):
|
||||||
# xla doesn't have a kernel: register the cpu fallback (or codegen'd out kernel).
|
return ''
|
||||||
namespace = 'AtenXlaTypeDefault'
|
payload = f"static_cast<{dispatcher_sig.ptr_type()}>(&AtenXlaTypeDefault::{name})"
|
||||||
payload = f"static_cast<{dispatcher_sig.ptr_type()}>(&{namespace}::{name})"
|
|
||||||
return f' m.impl("{f.native_function.func.name}", {payload});\n'
|
return f' m.impl("{f.native_function.func.name}", {payload});\n'
|
||||||
|
|
||||||
if self.target is not Target.NAMESPACED_DEFINITION:
|
if self.target is not Target.NAMESPACED_DEFINITION:
|
||||||
assert_never(self.target)
|
assert_never(self.target)
|
||||||
|
|
||||||
# Instead of generating a CPU fallback, the xla codegen generates out wrappers for a few hardcoded operators.
|
|
||||||
# TODO: we should generate out wrappers for ALL valid out kernels; not just ones in xla's hardcoded list
|
|
||||||
if f.native_function.func.kind() is SchemaKind.out and str(f.native_function.func.name.name) in _FN_OUT \
|
|
||||||
and isinstance(g, ExternalBackendFunctionsGroup):
|
|
||||||
return gen_out_wrapper(g)
|
|
||||||
|
|
||||||
# Everything below here is where we generate the CPU fallback.
|
# Everything below here is where we generate the CPU fallback.
|
||||||
|
# See Note [External Backends Follow Dispatcher convention]
|
||||||
dispatcher_order_args = dispatcher.jit_arguments(f.native_function.func)
|
dispatcher_order_args = dispatcher.jit_arguments(f.native_function.func)
|
||||||
|
|
||||||
# Map each argument to it's intermediate variable name in the fallback
|
# Map each argument to it's intermediate variable name in the fallback
|
||||||
|
|
|
||||||
|
|
@ -4,12 +4,13 @@ from typing_extensions import Literal
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import textwrap
|
import textwrap
|
||||||
|
|
||||||
from tools.codegen.context import method_with_native_function
|
from tools.codegen.context import method_with_native_function, with_native_function
|
||||||
from tools.codegen.utils import Target, mapMaybe
|
from tools.codegen.utils import Target, mapMaybe
|
||||||
from tools.codegen.model import (DispatchKey, NativeFunction,
|
from tools.codegen.model import (DispatchKey, NativeFunction,
|
||||||
NativeFunctionsGroup, SchemaKind,
|
NativeFunctionsGroup, SchemaKind,
|
||||||
|
ExternalBackendFunctionsGroup, ExternalBackendFunction,
|
||||||
TensorOptionsArguments, assert_never,
|
TensorOptionsArguments, assert_never,
|
||||||
is_cuda_dispatch_key,
|
is_cuda_dispatch_key, BaseType, BaseTy,
|
||||||
is_structured_dispatch_key)
|
is_structured_dispatch_key)
|
||||||
from tools.codegen.api.types import (BaseCType, Binding, ConstRefCType,
|
from tools.codegen.api.types import (BaseCType, Binding, ConstRefCType,
|
||||||
CppSignature, CppSignatureGroup,
|
CppSignature, CppSignatureGroup,
|
||||||
|
|
@ -17,6 +18,9 @@ from tools.codegen.api.types import (BaseCType, Binding, ConstRefCType,
|
||||||
NativeSignature, tensorT, NamedCType)
|
NativeSignature, tensorT, NamedCType)
|
||||||
import tools.codegen.api.meta as meta
|
import tools.codegen.api.meta as meta
|
||||||
import tools.codegen.api.structured as structured
|
import tools.codegen.api.structured as structured
|
||||||
|
import tools.codegen.api.dispatcher as dispatcher
|
||||||
|
import tools.codegen.api.cpp as cpp
|
||||||
|
from tools.codegen.dest.gen_external_aten_fallbacks import requires_backend_wrapper
|
||||||
from tools.codegen.api.translate import translate
|
from tools.codegen.api.translate import translate
|
||||||
from tools.codegen.selective_build.selector import SelectiveBuilder
|
from tools.codegen.selective_build.selector import SelectiveBuilder
|
||||||
|
|
||||||
|
|
@ -56,14 +60,35 @@ class RegisterDispatchKey:
|
||||||
# Whether or not we are actually code-genning for ROCm
|
# Whether or not we are actually code-genning for ROCm
|
||||||
rocm: bool
|
rocm: bool
|
||||||
|
|
||||||
|
# The namespace that the kernels are written in. This is just `at::native` for in-tree kernels.
|
||||||
|
cpp_namespace: str
|
||||||
|
|
||||||
@method_with_native_function
|
@method_with_native_function
|
||||||
def __call__(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]:
|
def __call__(self, f: Union[
|
||||||
|
NativeFunctionsGroup,
|
||||||
|
NativeFunction,
|
||||||
|
ExternalBackendFunctionsGroup,
|
||||||
|
ExternalBackendFunction]
|
||||||
|
) -> List[str]:
|
||||||
if isinstance(f, NativeFunctionsGroup):
|
if isinstance(f, NativeFunctionsGroup):
|
||||||
if f.structured:
|
if f.structured:
|
||||||
return self.gen_structured(f)
|
return self.gen_structured(f)
|
||||||
else:
|
else:
|
||||||
return list(mapMaybe(self.gen_unstructured, f.functions()))
|
return list(mapMaybe(self.gen_unstructured, f.functions()))
|
||||||
elif isinstance(f, NativeFunction):
|
elif isinstance(f, ExternalBackendFunctionsGroup):
|
||||||
|
if f.structured:
|
||||||
|
raise AssertionError("structured kernels not implemented yet for external backends")
|
||||||
|
elif f.primary == f.functional:
|
||||||
|
# For external backends that specify that they'd like to primarily implement functional kernels (namely XLA),
|
||||||
|
# we can generate anonymous wrappers for the out and in-place kernels for them, even for un-structured operators.
|
||||||
|
# Note that we can't go the other way around (generate the functional using the out), since we don't know
|
||||||
|
# how to create the output tensor without a meta function.
|
||||||
|
return self.gen_out_inplace_wrappers(f)
|
||||||
|
else:
|
||||||
|
# For backends that implement out kernels, they need to port their ops to structured
|
||||||
|
# if they want generated functional/inplace kernels.
|
||||||
|
return list(mapMaybe(self.gen_unstructured, f.functions()))
|
||||||
|
elif isinstance(f, NativeFunction) or isinstance(f, ExternalBackendFunction):
|
||||||
r = self.gen_unstructured(f)
|
r = self.gen_unstructured(f)
|
||||||
return [] if r is None else [r]
|
return [] if r is None else [r]
|
||||||
else:
|
else:
|
||||||
|
|
@ -88,12 +113,38 @@ class RegisterDispatchKey:
|
||||||
self.target,
|
self.target,
|
||||||
self.selector,
|
self.selector,
|
||||||
self.rocm,
|
self.rocm,
|
||||||
|
self.cpp_namespace,
|
||||||
g
|
g
|
||||||
)
|
)
|
||||||
return list(mapMaybe(structured_gen.gen_one, g.functions()))
|
return list(mapMaybe(structured_gen.gen_one, g.functions()))
|
||||||
|
|
||||||
@method_with_native_function
|
def gen_unstructured(
|
||||||
def gen_unstructured(self, f: NativeFunction) -> Optional[str]:
|
self,
|
||||||
|
native_or_external: Union[NativeFunction, ExternalBackendFunction],
|
||||||
|
*,
|
||||||
|
# Only applies to ExternalBackendFunction objects.
|
||||||
|
# True for inplace/out functions that don't have kernels, but do have corresponding functional kernels.
|
||||||
|
is_generated_wrapper: bool = False,
|
||||||
|
) -> Optional[str]:
|
||||||
|
sig: Union[NativeSignature, DispatcherSignature]
|
||||||
|
if isinstance(native_or_external, ExternalBackendFunction):
|
||||||
|
if not requires_backend_wrapper(native_or_external) and not is_generated_wrapper:
|
||||||
|
return None
|
||||||
|
# If the backend doesn't have a kernel, we don't generate an anonymous wrapper or a dispatcher registration for it
|
||||||
|
# (fallbacks to CPU are generated elsewhere).
|
||||||
|
if native_or_external.metadata is None and not is_generated_wrapper:
|
||||||
|
# TODO: Right now, we generate "fast-path" `at::xla::{op}` kernels for all non-composite ops,
|
||||||
|
# including those with CPU fallbacks. That logic will have to change if CPU fallbacks become a boxed kernel.
|
||||||
|
if self.target is not Target.NAMESPACED_DEFINITION and self.target is not Target.NAMESPACED_DECLARATION:
|
||||||
|
return None
|
||||||
|
f = native_or_external.native_function
|
||||||
|
sig = self.external_backend_wrapper_sig(native_or_external)
|
||||||
|
elif isinstance(native_or_external, NativeFunction):
|
||||||
|
f = native_or_external
|
||||||
|
sig = NativeSignature(f.func, prefix='wrapper_')
|
||||||
|
else:
|
||||||
|
assert_never(f)
|
||||||
|
|
||||||
inplace_meta = False
|
inplace_meta = False
|
||||||
if self.dispatch_key not in f.dispatch:
|
if self.dispatch_key not in f.dispatch:
|
||||||
if (self.dispatch_key == DispatchKey.Meta and
|
if (self.dispatch_key == DispatchKey.Meta and
|
||||||
|
|
@ -112,8 +163,6 @@ class RegisterDispatchKey:
|
||||||
if self.target is Target.REGISTRATION and not self.selector.is_native_function_selected(f):
|
if self.target is Target.REGISTRATION and not self.selector.is_native_function_selected(f):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
sig = NativeSignature(f.func, prefix='wrapper_')
|
|
||||||
|
|
||||||
name = sig.name()
|
name = sig.name()
|
||||||
returns_type = sig.returns_type().cpp_type()
|
returns_type = sig.returns_type().cpp_type()
|
||||||
args = sig.arguments()
|
args = sig.arguments()
|
||||||
|
|
@ -129,9 +178,19 @@ class RegisterDispatchKey:
|
||||||
return result
|
return result
|
||||||
elif self.target is Target.NAMESPACED_DEFINITION:
|
elif self.target is Target.NAMESPACED_DEFINITION:
|
||||||
def generate_defn(cpp_sig: CppSignature) -> str:
|
def generate_defn(cpp_sig: CppSignature) -> str:
|
||||||
|
# This is needed in order for namespaced definitions to call into CPU fallbacks,
|
||||||
|
# which live in a different namespace.
|
||||||
|
# TODO: this logic will change if we implement a boxed CPU fallback.
|
||||||
|
if isinstance(native_or_external, ExternalBackendFunction) \
|
||||||
|
and native_or_external.metadata is None \
|
||||||
|
and not is_generated_wrapper:
|
||||||
|
# See Note [External Backends Follow Dispatcher convention]
|
||||||
|
kernel_name = f'{self.cpp_namespace}::AtenXlaTypeDefault::{dispatcher.name(f.func)}'
|
||||||
|
else:
|
||||||
|
kernel_name = sig.name()
|
||||||
return f"""
|
return f"""
|
||||||
{cpp_sig.defn()} {{
|
{cpp_sig.defn()} {{
|
||||||
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
|
return {kernel_name}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
|
||||||
}}
|
}}
|
||||||
"""
|
"""
|
||||||
result = generate_defn(cpp_sig_group.signature)
|
result = generate_defn(cpp_sig_group.signature)
|
||||||
|
|
@ -152,7 +211,11 @@ return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), si
|
||||||
}}
|
}}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
impl_name = f"at::native::{f.dispatch[self.dispatch_key]}"
|
if isinstance(native_or_external, ExternalBackendFunction):
|
||||||
|
# TODO: remove this difference and merge the two cases when we remove xla-specific logic
|
||||||
|
impl_name = f"{self.cpp_namespace}::AtenXlaType::{f.dispatch[self.dispatch_key]}"
|
||||||
|
else:
|
||||||
|
impl_name = f"{self.cpp_namespace}::{f.dispatch[self.dispatch_key]}"
|
||||||
|
|
||||||
args_exprs_str = ', '.join(a.name for a in args)
|
args_exprs_str = ', '.join(a.name for a in args)
|
||||||
|
|
||||||
|
|
@ -196,12 +259,74 @@ namespace {{
|
||||||
if f.manual_kernel_registration:
|
if f.manual_kernel_registration:
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
dispatcher_sig = DispatcherSignature.from_schema(f.func)
|
|
||||||
payload = f"TORCH_FN({name})"
|
payload = f"TORCH_FN({name})"
|
||||||
return f'm.impl("{f.func.name}",\n{payload});\n'
|
return f'm.impl("{f.func.name}",\n{payload});\n'
|
||||||
else:
|
else:
|
||||||
assert_never(self.target)
|
assert_never(self.target)
|
||||||
|
|
||||||
|
def external_backend_wrapper_sig(self, f: ExternalBackendFunction) -> DispatcherSignature:
|
||||||
|
# See Note [External Backends Follow Dispatcher convention]
|
||||||
|
return DispatcherSignature.from_schema(f.native_function.func, prefix='wrapper_', append_overload_name=True)
|
||||||
|
|
||||||
|
def gen_out_inplace_wrappers(self, g: ExternalBackendFunctionsGroup) -> List[str]:
|
||||||
|
def gets_generated_out_inplace_wrapper(f: ExternalBackendFunction) -> bool:
|
||||||
|
return f.native_function.func.kind() is not SchemaKind.functional \
|
||||||
|
and f.metadata is None and g.functional.metadata is not None
|
||||||
|
|
||||||
|
@with_native_function
|
||||||
|
def gen_wrapper(f: ExternalBackendFunction) -> Optional[str]:
|
||||||
|
# Only anonymous definitions get "special treatment" for out/inplace wrappers.
|
||||||
|
# All other functionality can be directly pulled from gen_unstructured.
|
||||||
|
if self.target is not Target.ANONYMOUS_DEFINITION:
|
||||||
|
is_generated_wrapper = gets_generated_out_inplace_wrapper(f)
|
||||||
|
return self.gen_unstructured(f, is_generated_wrapper=is_generated_wrapper)
|
||||||
|
|
||||||
|
if f.native_function.func.kind() is SchemaKind.functional:
|
||||||
|
# Wrappers are generated for out/inplace kernels, using the functional kernel,
|
||||||
|
# so the functional kernel itself is generated normally.
|
||||||
|
return self.gen_unstructured(f)
|
||||||
|
if f.metadata is not None:
|
||||||
|
# If the backend provided their own out/inplace kernel, use it.
|
||||||
|
return self.gen_unstructured(f)
|
||||||
|
|
||||||
|
if not gets_generated_out_inplace_wrapper(f):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Special out/inplace wrapper logic starts here.
|
||||||
|
dispatcher_sig = self.external_backend_wrapper_sig(f)
|
||||||
|
name = dispatcher_sig.name()
|
||||||
|
|
||||||
|
# See Note [External Backends Follow Dispatcher convention]
|
||||||
|
dispatcher_order_args = dispatcher.jit_arguments(f.native_function.func)
|
||||||
|
tensors = [a for a in dispatcher_order_args if a.type == BaseType(BaseTy.Tensor)]
|
||||||
|
print_args_str = ''.join([f' << " {a.name}=" << {a.name}.toString()' for a in tensors])
|
||||||
|
|
||||||
|
functional_result_name = f'{name}_tmp'
|
||||||
|
return_names = cpp.return_names(f.native_function)
|
||||||
|
if len(return_names) > 1:
|
||||||
|
updates = '\n '.join(
|
||||||
|
f'at::_copy_from_and_resize(std::get<{i}>({functional_result_name}), {ret_name});'
|
||||||
|
for i, ret_name in enumerate(return_names))
|
||||||
|
returns = f'{dispatcher_sig.returns_type().cpp_type()}({", ".join(return_names)})'
|
||||||
|
else:
|
||||||
|
ret_name = return_names[0]
|
||||||
|
updates = f'at::_copy_from_and_resize({functional_result_name}, {ret_name});'
|
||||||
|
returns = ret_name
|
||||||
|
|
||||||
|
functional_sig = self.external_backend_wrapper_sig(g.functional)
|
||||||
|
|
||||||
|
return f"""\
|
||||||
|
{dispatcher_sig.defn()} {{
|
||||||
|
XLA_FN_TRACK(3);
|
||||||
|
TF_VLOG(3) << "XLA {name} :"{print_args_str};
|
||||||
|
auto {functional_result_name} = {functional_sig.name()}({", ".join(a.name for a in functional_sig.arguments())});
|
||||||
|
{updates}
|
||||||
|
return {returns};
|
||||||
|
}}
|
||||||
|
|
||||||
|
"""
|
||||||
|
return list(mapMaybe(gen_wrapper, g.functions(functional_first=True)))
|
||||||
|
|
||||||
|
|
||||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
|
||||||
#
|
#
|
||||||
|
|
|
||||||
|
|
@ -916,20 +916,22 @@ def main() -> None:
|
||||||
'#include <ATen/LegacyTHFunctionsCPU.h>' if dispatch_key == DispatchKey.CPU else
|
'#include <ATen/LegacyTHFunctionsCPU.h>' if dispatch_key == DispatchKey.CPU else
|
||||||
'#include <ATen/LegacyTHFunctionsCUDA.h>' if dispatch_key == DispatchKey.CUDA else
|
'#include <ATen/LegacyTHFunctionsCUDA.h>' if dispatch_key == DispatchKey.CUDA else
|
||||||
'',
|
'',
|
||||||
|
'external_backend_headers': '',
|
||||||
'DispatchKey': dispatch_key,
|
'DispatchKey': dispatch_key,
|
||||||
'dispatch_namespace': dispatch_key.lower(),
|
'dispatch_namespace': dispatch_key.lower(),
|
||||||
'dispatch_namespaced_definitions': list(concatMap(
|
'dispatch_namespaced_definitions': list(concatMap(
|
||||||
dest.RegisterDispatchKey(
|
dest.RegisterDispatchKey(
|
||||||
dispatch_key, Target.NAMESPACED_DEFINITION, selector, rocm=options.rocm),
|
dispatch_key, Target.NAMESPACED_DEFINITION, selector, rocm=options.rocm, cpp_namespace='at::native'),
|
||||||
grouped_native_functions
|
grouped_native_functions
|
||||||
)),
|
)),
|
||||||
'dispatch_anonymous_definitions': list(concatMap(
|
'dispatch_anonymous_definitions': list(concatMap(
|
||||||
dest.RegisterDispatchKey(
|
dest.RegisterDispatchKey(
|
||||||
dispatch_key, Target.ANONYMOUS_DEFINITION, selector, rocm=options.rocm),
|
dispatch_key, Target.ANONYMOUS_DEFINITION, selector, rocm=options.rocm, cpp_namespace='at::native'),
|
||||||
grouped_native_functions
|
grouped_native_functions
|
||||||
)),
|
)),
|
||||||
'dispatch_registrations': list(concatMap(
|
'dispatch_registrations': list(concatMap(
|
||||||
dest.RegisterDispatchKey(dispatch_key, Target.REGISTRATION, selector, rocm=options.rocm),
|
dest.RegisterDispatchKey(
|
||||||
|
dispatch_key, Target.REGISTRATION, selector, rocm=options.rocm, cpp_namespace='at::native'),
|
||||||
grouped_native_functions
|
grouped_native_functions
|
||||||
)),
|
)),
|
||||||
})
|
})
|
||||||
|
|
@ -939,7 +941,7 @@ def main() -> None:
|
||||||
'dispatch_namespace': dispatch_key.lower(),
|
'dispatch_namespace': dispatch_key.lower(),
|
||||||
'dispatch_namespaced_declarations': list(concatMap(
|
'dispatch_namespaced_declarations': list(concatMap(
|
||||||
dest.RegisterDispatchKey(
|
dest.RegisterDispatchKey(
|
||||||
dispatch_key, Target.NAMESPACED_DECLARATION, selector, rocm=options.rocm),
|
dispatch_key, Target.NAMESPACED_DECLARATION, selector, rocm=options.rocm, cpp_namespace='at::native'),
|
||||||
grouped_native_functions
|
grouped_native_functions
|
||||||
)),
|
)),
|
||||||
})
|
})
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ import pathlib
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import yaml
|
import yaml
|
||||||
from typing import List, Dict, Union, Tuple, Sequence
|
from typing import List, Dict, Union, Tuple, Sequence, Optional
|
||||||
from tools.codegen.gen import FileManager, get_grouped_native_functions, parse_native_yaml
|
from tools.codegen.gen import FileManager, get_grouped_native_functions, parse_native_yaml
|
||||||
from tools.codegen.model import (ExternalBackendFunction, ExternalBackendFunctionsGroup,
|
from tools.codegen.model import (ExternalBackendFunction, ExternalBackendFunctionsGroup,
|
||||||
NativeFunction, NativeFunctionsGroup, OperatorName,
|
NativeFunction, NativeFunctionsGroup, OperatorName,
|
||||||
|
|
@ -17,13 +17,14 @@ try:
|
||||||
# use faster C loader if available
|
# use faster C loader if available
|
||||||
from yaml import CSafeLoader as Loader
|
from yaml import CSafeLoader as Loader
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from yaml import SafeLoader as Loader # type: ignore
|
from yaml import SafeLoader as Loader # type: ignore[misc]
|
||||||
|
|
||||||
|
|
||||||
|
# Parses the backend's yaml file and returns (BackendDispatchKey, AutogradDispatchKey, cpp_namespace, backend_kernel_list)
|
||||||
def parse_backend_yaml(
|
def parse_backend_yaml(
|
||||||
backend_yaml_path: str,
|
backend_yaml_path: str,
|
||||||
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]]
|
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]]
|
||||||
) -> Tuple[str, List[Union[ExternalBackendFunction, ExternalBackendFunctionsGroup]]]:
|
) -> Tuple[DispatchKey, Optional[DispatchKey], str, List[Union[ExternalBackendFunction, ExternalBackendFunctionsGroup]]]:
|
||||||
with open(backend_yaml_path, 'r') as f:
|
with open(backend_yaml_path, 'r') as f:
|
||||||
yaml_values = yaml.load(f, Loader=Loader)
|
yaml_values = yaml.load(f, Loader=Loader)
|
||||||
assert isinstance(yaml_values, dict)
|
assert isinstance(yaml_values, dict)
|
||||||
|
|
@ -71,7 +72,8 @@ Only the following keys are supported: {", ".join(valid_keys)}'
|
||||||
for f in concatMap(lambda f: [f] if isinstance(f, NativeFunction) else list(f.functions()), grouped_native_functions)
|
for f in concatMap(lambda f: [f] if isinstance(f, NativeFunction) else list(f.functions()), grouped_native_functions)
|
||||||
}
|
}
|
||||||
|
|
||||||
def kernel_name(func: FunctionSchema) -> str:
|
def external_kernel_name(func: FunctionSchema) -> str:
|
||||||
|
# Note [External Backends Follow Dispatcher convention]
|
||||||
# For external backends, we enforce that their names and signatures match the dispatcher convention
|
# For external backends, we enforce that their names and signatures match the dispatcher convention
|
||||||
return dispatcher.name(func)
|
return dispatcher.name(func)
|
||||||
|
|
||||||
|
|
@ -83,17 +85,17 @@ Only the following keys are supported: {", ".join(valid_keys)}'
|
||||||
m = metadata.get(f.func.name, None)
|
m = metadata.get(f.func.name, None)
|
||||||
dispatch_key = DispatchKey.parse(f'Autograd{backend}') \
|
dispatch_key = DispatchKey.parse(f'Autograd{backend}') \
|
||||||
if m is not None and m.is_autograd else DispatchKey.parse(backend)
|
if m is not None and m.is_autograd else DispatchKey.parse(backend)
|
||||||
kernel = kernel_name(f.func)
|
kernel = external_kernel_name(f.func)
|
||||||
return ExternalBackendFunction(NativeFunction.with_dispatch_entry(f, dispatch_key, kernel), dispatch_key, m)
|
return ExternalBackendFunction(NativeFunction.with_dispatch_entry(f, dispatch_key, kernel), dispatch_key, m)
|
||||||
elif isinstance(g, NativeFunctionsGroup):
|
elif isinstance(g, NativeFunctionsGroup):
|
||||||
out_meta = metadata.get(g.out.func.name, None)
|
out_meta = metadata.get(g.out.func.name, None)
|
||||||
kernel = kernel_name(g.out.func)
|
kernel = external_kernel_name(g.out.func)
|
||||||
dispatch_key = DispatchKey.parse(f'Autograd{backend}') \
|
dispatch_key = DispatchKey.parse(f'Autograd{backend}') \
|
||||||
if out_meta is not None and out_meta.is_autograd else DispatchKey.parse(backend)
|
if out_meta is not None and out_meta.is_autograd else DispatchKey.parse(backend)
|
||||||
out = ExternalBackendFunction(NativeFunction.with_dispatch_entry(g.out, dispatch_key, kernel), dispatch_key, out_meta)
|
out = ExternalBackendFunction(NativeFunction.with_dispatch_entry(g.out, dispatch_key, kernel), dispatch_key, out_meta)
|
||||||
|
|
||||||
functional_meta = metadata.get(g.functional.func.name, None)
|
functional_meta = metadata.get(g.functional.func.name, None)
|
||||||
kernel = kernel_name(g.functional.func)
|
kernel = external_kernel_name(g.functional.func)
|
||||||
dispatch_key = DispatchKey.parse(f'Autograd{backend}') \
|
dispatch_key = DispatchKey.parse(f'Autograd{backend}') \
|
||||||
if functional_meta is not None and functional_meta.is_autograd else DispatchKey.parse(backend)
|
if functional_meta is not None and functional_meta.is_autograd else DispatchKey.parse(backend)
|
||||||
functional = ExternalBackendFunction(
|
functional = ExternalBackendFunction(
|
||||||
|
|
@ -109,7 +111,7 @@ autograd key. They can not be mix and matched. If this is something you need, fe
|
||||||
inplace = None
|
inplace = None
|
||||||
if g.inplace:
|
if g.inplace:
|
||||||
inplace_meta = metadata.get(g.inplace.func.name, None)
|
inplace_meta = metadata.get(g.inplace.func.name, None)
|
||||||
kernel = kernel_name(g.inplace.func)
|
kernel = external_kernel_name(g.inplace.func)
|
||||||
dispatch_key = DispatchKey.parse(f'Autograd{backend}') \
|
dispatch_key = DispatchKey.parse(f'Autograd{backend}') \
|
||||||
if inplace_meta is not None and inplace_meta.is_autograd else DispatchKey.parse(backend)
|
if inplace_meta is not None and inplace_meta.is_autograd else DispatchKey.parse(backend)
|
||||||
inplace = ExternalBackendFunction(
|
inplace = ExternalBackendFunction(
|
||||||
|
|
@ -128,7 +130,7 @@ autograd key. They can not be mix and matched. If this is something you need, fe
|
||||||
assert_never(g)
|
assert_never(g)
|
||||||
for op_name in metadata.keys():
|
for op_name in metadata.keys():
|
||||||
assert op_name in native_functions_map, f"Found an invalid operator name: {op_name}"
|
assert op_name in native_functions_map, f"Found an invalid operator name: {op_name}"
|
||||||
return cpp_namespace, [native_to_external(g) for g in grouped_native_functions]
|
return backend_key, autograd_key, cpp_namespace, [native_to_external(g) for g in grouped_native_functions]
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
parser = argparse.ArgumentParser(description='Generate backend stub files')
|
parser = argparse.ArgumentParser(description='Generate backend stub files')
|
||||||
|
|
@ -157,7 +159,7 @@ def run(source_yaml: str, output_dir: str, dry_run: bool) -> None:
|
||||||
|
|
||||||
native_yaml_path = os.path.join(pytorch_root, 'aten/src/ATen/native/native_functions.yaml')
|
native_yaml_path = os.path.join(pytorch_root, 'aten/src/ATen/native/native_functions.yaml')
|
||||||
grouped_native_functions = get_grouped_native_functions(native_yaml_path)
|
grouped_native_functions = get_grouped_native_functions(native_yaml_path)
|
||||||
cpp_namespace, external_backend_functions = parse_backend_yaml(source_yaml, grouped_native_functions)
|
backend_key, autograd_key, cpp_namespace, external_backend_functions = parse_backend_yaml(source_yaml, grouped_native_functions)
|
||||||
|
|
||||||
native_functions = parse_native_yaml(native_yaml_path)
|
native_functions = parse_native_yaml(native_yaml_path)
|
||||||
|
|
||||||
|
|
@ -171,6 +173,68 @@ def run(source_yaml: str, output_dir: str, dry_run: bool) -> None:
|
||||||
'dispatch_xla_declarations': list(concatMap(dest.compute_native_function_declaration, external_backend_functions)),
|
'dispatch_xla_declarations': list(concatMap(dest.compute_native_function_declaration, external_backend_functions)),
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
external_backend_functions_no_autograd = [f for f in external_backend_functions if not f.is_autograd_kernel]
|
||||||
|
external_backend_functions_autograd = [f for f in external_backend_functions if f.is_autograd_kernel]
|
||||||
|
|
||||||
|
external_backend_headers = '''\
|
||||||
|
#include <tensorflow/compiler/xla/xla_client/debug_macros.h>
|
||||||
|
#include <tensorflow/compiler/xla/xla_client/metrics.h>
|
||||||
|
#include <tensorflow/compiler/xla/xla_client/tf_logging.h>
|
||||||
|
#include <torch_xla/csrc/function_call_tracker.h>
|
||||||
|
#include <torch_xla/csrc/aten_xla_type.h>
|
||||||
|
#include <torch_xla/csrc/aten_xla_type_default.h>'''
|
||||||
|
|
||||||
|
fm.write_with_template(f'Register{backend_key}.cpp', 'RegisterDispatchKey.cpp', lambda: {
|
||||||
|
'extra_cuda_headers': '',
|
||||||
|
'legacy_th_headers': '',
|
||||||
|
'external_backend_headers': external_backend_headers,
|
||||||
|
'DispatchKey': backend_key,
|
||||||
|
'dispatch_namespace': backend_key.lower(),
|
||||||
|
'dispatch_namespaced_definitions': list(concatMap(
|
||||||
|
dest.RegisterDispatchKey(
|
||||||
|
backend_key, Target.NAMESPACED_DEFINITION, selector, rocm=False, cpp_namespace=cpp_namespace),
|
||||||
|
external_backend_functions_no_autograd
|
||||||
|
)),
|
||||||
|
'dispatch_anonymous_definitions': list(concatMap(
|
||||||
|
dest.RegisterDispatchKey(
|
||||||
|
backend_key, Target.ANONYMOUS_DEFINITION, selector, rocm=False, cpp_namespace=cpp_namespace),
|
||||||
|
external_backend_functions_no_autograd
|
||||||
|
)),
|
||||||
|
'dispatch_registrations': list(concatMap(
|
||||||
|
dest.RegisterDispatchKey(backend_key, Target.REGISTRATION, selector, rocm=False, cpp_namespace=cpp_namespace),
|
||||||
|
external_backend_functions_no_autograd
|
||||||
|
)),
|
||||||
|
})
|
||||||
|
|
||||||
|
# If they have at least one autograd entry in their yaml file
|
||||||
|
if autograd_key is not None:
|
||||||
|
assert len(external_backend_functions_autograd) > 0
|
||||||
|
autograd_dispatchkey: DispatchKey = autograd_key # make mypy happy
|
||||||
|
fm.write_with_template(f'Register{autograd_key}.cpp', 'RegisterDispatchKey.cpp', lambda: {
|
||||||
|
'extra_cuda_headers': '',
|
||||||
|
'legacy_th_headers': '',
|
||||||
|
'external_backend_headers': external_backend_headers,
|
||||||
|
'DispatchKey': autograd_dispatchkey,
|
||||||
|
'dispatch_namespace': autograd_dispatchkey.lower(),
|
||||||
|
'dispatch_namespaced_definitions': list(concatMap(
|
||||||
|
dest.RegisterDispatchKey(
|
||||||
|
autograd_dispatchkey, Target.NAMESPACED_DEFINITION, selector, rocm=False, cpp_namespace=cpp_namespace),
|
||||||
|
external_backend_functions_autograd
|
||||||
|
)),
|
||||||
|
'dispatch_anonymous_definitions': list(concatMap(
|
||||||
|
dest.RegisterDispatchKey(
|
||||||
|
autograd_dispatchkey, Target.ANONYMOUS_DEFINITION, selector, rocm=False, cpp_namespace=cpp_namespace),
|
||||||
|
external_backend_functions_autograd
|
||||||
|
)),
|
||||||
|
'dispatch_registrations': list(concatMap(
|
||||||
|
dest.RegisterDispatchKey(
|
||||||
|
autograd_dispatchkey, Target.REGISTRATION, selector, rocm=False, cpp_namespace=cpp_namespace),
|
||||||
|
external_backend_functions_autograd
|
||||||
|
)),
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
fm.write('aten_xla_type_default.h', lambda: {
|
fm.write('aten_xla_type_default.h', lambda: {
|
||||||
'generated_comment': generated_comment,
|
'generated_comment': generated_comment,
|
||||||
'cpp_namespace': cpp_namespace,
|
'cpp_namespace': cpp_namespace,
|
||||||
|
|
@ -188,10 +252,7 @@ def run(source_yaml: str, output_dir: str, dry_run: bool) -> None:
|
||||||
dest.GenExternalAtenFallback(Target.NAMESPACED_DEFINITION), external_backend_functions
|
dest.GenExternalAtenFallback(Target.NAMESPACED_DEFINITION), external_backend_functions
|
||||||
)),
|
)),
|
||||||
'dispatch_registrations': list(concatMap(
|
'dispatch_registrations': list(concatMap(
|
||||||
dest.GenExternalAtenFallback(Target.REGISTRATION), [e for e in external_backend_functions if not e.is_autograd_kernel]
|
dest.GenExternalAtenFallback(Target.REGISTRATION), external_backend_functions
|
||||||
)),
|
|
||||||
'dispatch_autograd_registrations': list(concatMap(
|
|
||||||
dest.GenExternalAtenFallback(Target.REGISTRATION), [e for e in external_backend_functions if e.is_autograd_kernel]
|
|
||||||
)),
|
)),
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1452,11 +1452,19 @@ class ExternalBackendFunctionsGroup:
|
||||||
f"variant, {str(f.native_function.func.name)} will be generated for you " \
|
f"variant, {str(f.native_function.func.name)} will be generated for you " \
|
||||||
"and doesn't need to live in the yaml."
|
"and doesn't need to live in the yaml."
|
||||||
|
|
||||||
def functions(self) -> Iterator[ExternalBackendFunction]:
|
def functions(self, *, functional_first: bool = False) -> Iterator[ExternalBackendFunction]:
|
||||||
yield self.out
|
if not functional_first:
|
||||||
yield self.functional
|
yield self.out
|
||||||
if self.inplace is not None:
|
yield self.functional
|
||||||
yield self.inplace
|
if self.inplace is not None:
|
||||||
|
yield self.inplace
|
||||||
|
else:
|
||||||
|
# When we generate out/inplace wrappers for unstructured kernels, we need the functional kernel
|
||||||
|
# to be defined before we can generate the other two wrappers.
|
||||||
|
yield self.functional
|
||||||
|
yield self.out
|
||||||
|
if self.inplace is not None:
|
||||||
|
yield self.inplace
|
||||||
|
|
||||||
|
|
||||||
# Helper functions for parsing argument lists (both inputs and returns)
|
# Helper functions for parsing argument lists (both inputs and returns)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user