As titled. To register a custom op into Executorch, we need:

* `custom_ops.yaml`, which defines the operator schema and the corresponding native function.
* `custom_ops.cpp`, which defines the kernel.
* `RegisterDispatchKeyCustomOps.cpp`, a template to register the operator into PyTorch.

Added a new test for custom ops. The custom op `custom::add_3.out` takes 3 tensors and adds them together. The test makes sure the op is registered correctly and then verifies that the outcome is correct.

Differential Revision: [D42204263](https://our.internmc.facebook.com/intern/diff/D42204263/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91291
Approved by: https://github.com/ezyang
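A minimal sketch of exercising such an op from Python once it is registered; the names follow the `custom::add_3.out` schema described above, and the actual test code lives in the PR:

import torch

a, b, c = torch.randn(2, 3), torch.randn(2, 3), torch.randn(2, 3)
out = torch.empty(2, 3)
# The out= keyword resolves the call to the `custom::add_3.out` overload.
torch.ops.custom.add_3(a, b, c, out=out)
assert torch.allclose(out, a + b + c)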
127 lines · 4.5 KiB · Python
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple

from torchgen import dest

# disable import sorting to avoid circular dependency.
from torchgen.api.types import DispatcherSignature  # isort:skip
from torchgen.context import method_with_native_function
from torchgen.model import BackendIndex, DispatchKey, NativeFunction, Variant
from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.utils import concatMap, Target

# Generates RegisterKernelStub.cpp, which provides placeholder kernels for custom operators.
# This will be used on the model authoring side.
@dataclass(frozen=True)
class ComputeNativeFunctionStub:
    @method_with_native_function
    def __call__(self, f: NativeFunction) -> Optional[str]:
        if Variant.function not in f.variants:
            return None

        sig = DispatcherSignature.from_schema(
            f.func, prefix=f"wrapper_CPU_{f.func.name.overload_name}_", symint=False
        )
        assert sig is not None
        if len(f.func.returns) == 0:
            ret_name = ""
        elif len(f.func.returns) == 1:
            if f.func.arguments.out:
                ret_name = f.func.arguments.out[0].name
            else:
                # Fall back to the first non-out argument whose type matches
                # the return type.
                ret_name = next(
                    (
                        a.name
                        for a in f.func.arguments.flat_non_out
                        if a.type == f.func.returns[0].type
                    ),
                    "",
                )
                if not ret_name:
                    raise Exception(f"Can't handle this return type {f.func}")
        else:
            assert len(f.func.arguments.out) == len(f.func.returns), (
                "Out variant number of returns needs to match the number of out arguments."
                f" Got outs {str(f.func.arguments.out)} but returns {str(f.func.returns)}"
            )
            # returns a tuple of out arguments
            tensor_type = "at::Tensor &"
            comma = ", "
            ret_name = f"""::std::tuple<{comma.join([tensor_type] * len(f.func.returns))}>(
                {comma.join([r.name for r in f.func.arguments.out])}
            )"""
        ret_str = f"return {ret_name};" if len(f.func.returns) > 0 else ""
        return f"""
{sig.defn()} {{
    {ret_str}
}}
    """


def gen_custom_ops_registration(
    *,
    native_functions: Sequence[NativeFunction],
    selector: SelectiveBuilder,
    backend_index: BackendIndex,
    rocm: bool,
) -> Tuple[str, str]:
    """
    Generate custom ops registration code for dest.RegisterDispatchKey.

    :param native_functions: a sequence of `NativeFunction`
    :param selector: for selective build.
    :param backend_index: kernels for all the ops.
    :param rocm: bool for dest.RegisterDispatchKey.
    :return: generated C++ code to register custom operators into PyTorch
    """
    dispatch_key = DispatchKey.CPU

    static_init_dispatch_registrations = ""
    # Group the native functions by namespace; each namespace gets its own
    # TORCH_LIBRARY_IMPL block.
    ns_grouped_native_functions: Dict[str, List[NativeFunction]] = defaultdict(list)
    for native_function in native_functions:
        ns_grouped_native_functions[native_function.namespace].append(native_function)

    for namespace, functions in ns_grouped_native_functions.items():
        if len(functions) == 0:
            continue
        dispatch_registrations_body = "\n".join(
            list(
                concatMap(
                    dest.RegisterDispatchKey(
                        backend_index,
                        Target.REGISTRATION,
                        selector,
                        rocm=rocm,
                        symint=False,
                        class_method_name=None,
                        skip_dispatcher_op_registration=False,
                    ),
                    functions,
                )
            )
        )
        static_init_dispatch_registrations += f"""
TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{
{dispatch_registrations_body}
}};"""
    anonymous_definition = "\n".join(
        list(
            concatMap(
                dest.RegisterDispatchKey(
                    backend_index,
                    Target.ANONYMOUS_DEFINITION,
                    selector,
                    rocm=rocm,
                    symint=False,
                    class_method_name=None,
                    skip_dispatcher_op_registration=False,
                ),
                native_functions,
            )
        )
    )
    return anonymous_definition, static_init_dispatch_registrations
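
For context, a minimal sketch of how this generator might be driven. The yaml paths are assumptions for illustration, not part of this PR; parse_native_yaml and SelectiveBuilder.get_nop_selector are existing torchgen APIs:

from torchgen.gen import parse_native_yaml
from torchgen.model import DispatchKey
from torchgen.selective_build.selector import SelectiveBuilder

# "custom_ops.yaml" / "tags.yaml" are assumed paths; the yaml is expected to
# declare CPU kernels so that a CPU BackendIndex exists.
parsed = parse_native_yaml("custom_ops.yaml", "tags.yaml")
anonymous_definition, static_init_dispatch_registrations = gen_custom_ops_registration(
    native_functions=parsed.native_functions,
    selector=SelectiveBuilder.get_nop_selector(),  # no selective build: keep every op
    backend_index=parsed.backend_indices[DispatchKey.CPU],
    rocm=False,
)
# The two strings fill in the RegisterDispatchKeyCustomOps.cpp template
# mentioned in the PR description.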