[torchgen] Add CI job to cover custom ops registration for Executorch (#91291)

As titled. To register a custom op into Executorch, we need:

* `custom_ops.yaml`, defines the operator schema and the corresponding native function.
* `custom_ops.cpp`, defines the kernel.
* `RegisterDispatchKeyCustomOps.cpp`, a template to register operator into PyTorch.

Added a new test for custom ops. The custom op `custom::add_3.out` takes 3 tensors and adds them together. The test makes sure the op is registered correctly and then verifies that the outcome is correct.

Differential Revision: [D42204263](https://our.internmc.facebook.com/intern/diff/D42204263/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91291
Approved by: https://github.com/ezyang
This commit is contained in:
Larry Liu 2023-01-13 23:10:59 +00:00 committed by PyTorch MergeBot
parent 66b324cf06
commit 7568484d54
10 changed files with 133 additions and 34 deletions

View File

@ -14,11 +14,14 @@ set(GEN_COMMAND
--aten_yaml_path=${TORCH_ROOT}/aten/src/ATen/native/native_functions.yaml --aten_yaml_path=${TORCH_ROOT}/aten/src/ATen/native/native_functions.yaml
--use_aten_lib --use_aten_lib
--op_selection_yaml_path=${TEST_ROOT}/selected_operators.yaml --op_selection_yaml_path=${TEST_ROOT}/selected_operators.yaml
--custom_ops_yaml_path=${TEST_ROOT}/custom_ops.yaml
) )
set(GEN_COMMAND_sources set(GEN_COMMAND_sources
${OUTPUT_DIRECTORY}/RegisterCodegenUnboxedKernelsEverything.cpp ${OUTPUT_DIRECTORY}/RegisterCodegenUnboxedKernelsEverything.cpp
${OUTPUT_DIRECTORY}/RegisterCPUCustomOps.cpp
${OUTPUT_DIRECTORY}/Functions.h ${OUTPUT_DIRECTORY}/Functions.h
${OUTPUT_DIRECTORY}/NativeFunctions.h ${OUTPUT_DIRECTORY}/NativeFunctions.h
${OUTPUT_DIRECTORY}/CustomOpsNativeFunctions.h
) )
message(STATUS "Generating sources for unboxing kernels ${GEN_COMMAND}") message(STATUS "Generating sources for unboxing kernels ${GEN_COMMAND}")
add_custom_command( add_custom_command(
@ -32,6 +35,7 @@ add_custom_command(
${TEST_ROOT}/templates/Functions.h ${TEST_ROOT}/templates/Functions.h
${TEST_ROOT}/templates/NativeFunctions.h ${TEST_ROOT}/templates/NativeFunctions.h
${TEST_ROOT}/templates/RegisterCodegenUnboxedKernels.cpp ${TEST_ROOT}/templates/RegisterCodegenUnboxedKernels.cpp
${TEST_ROOT}/templates/RegisterDispatchKeyCustomOps.cpp
WORKING_DIRECTORY ${TORCH_ROOT} WORKING_DIRECTORY ${TORCH_ROOT}
) )
add_custom_target(unbox_target DEPENDS ${GEN_COMMAND_sources}) add_custom_target(unbox_target DEPENDS ${GEN_COMMAND_sources})
@ -39,6 +43,7 @@ add_custom_target(unbox_target DEPENDS ${GEN_COMMAND_sources})
add_library(unbox_lib STATIC add_library(unbox_lib STATIC
${GEN_COMMAND_sources} ${GEN_COMMAND_sources}
${TEST_ROOT}/operator_registry.cpp ${TEST_ROOT}/operator_registry.cpp
${TEST_ROOT}/custom_ops.cpp
) )
target_include_directories(unbox_lib PUBLIC ${TEST_ROOT} ${ATen_CPU_INCLUDE}) target_include_directories(unbox_lib PUBLIC ${TEST_ROOT} ${ATen_CPU_INCLUDE})
target_link_libraries(unbox_lib PUBLIC torch_cpu) target_link_libraries(unbox_lib PUBLIC torch_cpu)

10
test/edge/custom_ops.cpp Normal file
View File

@ -0,0 +1,10 @@
// Kernel for the custom op custom::add_3.out (schema in custom_ops.yaml).
// Used by the Executorch custom-op registration test.
#include <ATen/Tensor.h>
namespace custom {
namespace native {
// Returns out = a + b + c (elementwise sum of the three inputs).
// NOTE(review): assigning to `out` rebinds the referenced tensor to a newly
// allocated result rather than writing into out's existing storage (e.g. via
// at::add_out) — confirm this matches the intended out-variant contract.
at::Tensor& add_3_out(const at::Tensor& a, const at::Tensor& b, const at::Tensor& c, at::Tensor& out) {
out = a.add(b).add(c);
return out;
}
} // namespace native
} // namespace custom

View File

@ -0,0 +1,3 @@
# Schema for the test custom operator. The op lives in the `custom` namespace,
# takes three input tensors, and writes their sum into the mutable `out`
# argument (out-variant). The CPU kernel is custom::add_3_out in
# custom_ops.cpp.
- func: custom::add_3.out(Tensor a, Tensor b, Tensor c, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU: custom::add_3_out

View File

@ -448,3 +448,9 @@ operators:
include_all_overloads: false include_all_overloads: false
is_root_operator: true is_root_operator: true
is_used_for_training: true is_used_for_training: true
custom::add_3.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true

View File

@ -0,0 +1,27 @@
// clang-format off
// Generated code for registering custom operators into the dispatcher.
// Codegen template: the placeholders named ops_headers,
// dispatch_anonymous_definitions, static_init_dispatch_registrations,
// dispatch_namespace and dispatch_namespaced_definitions are filled in by
// torchgen (placeholder names spelled without template syntax here so the
// template engine does not substitute into this comment).
#include <torch/library.h>
#include <ATen/Tensor.h>
$ops_headers
namespace torch {
namespace executor {
namespace function {
${dispatch_anonymous_definitions}
// All out variants ops
${static_init_dispatch_registrations}
namespace ${dispatch_namespace}
{
${dispatch_namespaced_definitions}
} // namespace ${dispatch_namespace}
} // namespace function
} // namespace executor
} // namespace torch

View File

@ -0,0 +1,10 @@
// ${generated_comment}
// Codegen template for registering operator *schemas* (not kernels) with the
// PyTorch dispatcher. The aten placeholder block presumably expands to
// m.def(...) lines for aten-namespace ops; the trailing placeholder expands
// to TORCH_LIBRARY fragments for non-aten namespaces — confirm against the
// torchgen call site that fills this template.
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <torch/library.h>
namespace at {
TORCH_LIBRARY_FRAGMENT(aten, m) {
${aten_schema_registrations};
}
$schema_registrations
} // namespace at

View File

@ -23,6 +23,27 @@ TEST(OperatorRegistrationTest, Add) {
expected = at::fill(expected, 2); expected = at::fill(expected, 2);
ASSERT_TRUE(expected.equal(kernel_values[3]->toTensor())); ASSERT_TRUE(expected.equal(kernel_values[3]->toTensor()));
}
// custom::add_3.out(Tensor a, Tensor b, Tensor c, *, Tensor(a!) out) -> Tensor(a!)
// Verifies that custom::add_3.out is registered with the Executorch operator
// registry and that its kernel sums the three inputs into the out argument.
TEST(OperatorRegistrationTest, CustomAdd3) {
EValue values[4];
// Three all-ones inputs plus a zero-filled out buffer, all shaped {2, 3}.
values[0] = EValue(at::ones({2, 3}));
values[1] = EValue(at::ones({2, 3}));
values[2] = EValue(at::ones({2, 3}));
values[3] = EValue(at::zeros({2, 3}));
// Registration check: the op must be discoverable by its qualified name.
ASSERT_TRUE(hasOpsFn("custom::add_3.out"));
auto op = getOpsFn("custom::add_3.out");
EValue* kernel_values[4];
for (size_t i = 0; i < 4; i++) {
kernel_values[i] = &values[i];
}
op(kernel_values);
// 1 + 1 + 1 == 3 in every element of the out tensor.
at::Tensor expected = at::ones({2, 3});
expected = at::fill(expected, 3);
ASSERT_TRUE(expected.equal(kernel_values[3]->toTensor()));
} }
} // namespace executor } // namespace executor
} // namespace torch } // namespace torch

View File

@ -40,7 +40,7 @@ class TestComputeNativeFunctionStub(expecttest.TestCase):
def test_function_schema_generates_correct_kernel_tensor_out(self) -> None: def test_function_schema_generates_correct_kernel_tensor_out(self) -> None:
obj = {"func": "custom::foo.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)"} obj = {"func": "custom::foo.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)"}
expected = """ expected = """
at::Tensor & wrapper_out_foo_out(const at::Tensor & self, at::Tensor & out) { at::Tensor & wrapper_CPU_out_foo_out(const at::Tensor & self, at::Tensor & out) {
return out; return out;
} }
""" """
@ -49,7 +49,7 @@ at::Tensor & wrapper_out_foo_out(const at::Tensor & self, at::Tensor & out) {
def test_function_schema_generates_correct_kernel_no_out(self) -> None: def test_function_schema_generates_correct_kernel_no_out(self) -> None:
obj = {"func": "custom::foo.Tensor(Tensor self) -> Tensor"} obj = {"func": "custom::foo.Tensor(Tensor self) -> Tensor"}
expected = """ expected = """
at::Tensor wrapper_Tensor_foo(const at::Tensor & self) { at::Tensor wrapper_CPU_Tensor_foo(const at::Tensor & self) {
return self; return self;
} }
""" """
@ -58,7 +58,7 @@ at::Tensor wrapper_Tensor_foo(const at::Tensor & self) {
def test_function_schema_generates_correct_kernel_no_return(self) -> None: def test_function_schema_generates_correct_kernel_no_return(self) -> None:
obj = {"func": "custom::foo(Tensor self, *, Tensor(a!)[] out) -> ()"} obj = {"func": "custom::foo(Tensor self, *, Tensor(a!)[] out) -> ()"}
expected = f""" expected = f"""
void wrapper__foo_out(const at::Tensor & self, at::TensorList out) {{ void wrapper_CPU__foo_out(const at::Tensor & self, at::TensorList out) {{
{SPACES} {SPACES}
}} }}
""" """

View File

@ -23,7 +23,7 @@ class ComputeNativeFunctionStub:
return None return None
sig = DispatcherSignature.from_schema( sig = DispatcherSignature.from_schema(
f.func, prefix=f"wrapper_{f.func.name.overload_name}_", symint=False f.func, prefix=f"wrapper_CPU_{f.func.name.overload_name}_", symint=False
) )
assert sig is not None assert sig is not None
if len(f.func.returns) == 0: if len(f.func.returns) == 0:

View File

@ -46,7 +46,7 @@ from torchgen.utils import (
def static_dispatch( def static_dispatch(
sig: ExecutorchCppSignature, sig: Union[CppSignature, ExecutorchCppSignature],
f: NativeFunction, f: NativeFunction,
backend_indices: List[BackendIndex], backend_indices: List[BackendIndex],
) -> str: ) -> str:
@ -99,12 +99,16 @@ class ComputeFunction:
return None return None
if Variant.function not in f.variants: if Variant.function not in f.variants:
return None return None
sig: Union[CppSignature, ExecutorchCppSignature] = (
if self.use_aten_lib: CppSignatureGroup.from_native_function(
comma = ", "
sig = CppSignatureGroup.from_native_function(
f, method=False, fallback_binding=f.manual_cpp_binding f, method=False, fallback_binding=f.manual_cpp_binding
).most_faithful_signature() ).most_faithful_signature()
if self.use_aten_lib
else ExecutorchCppSignature.from_native_function(f)
)
if self.use_aten_lib and f.namespace == "aten":
comma = ", "
return f""" return f"""
// {f.namespace}::{f.func} // {f.namespace}::{f.func}
TORCH_API inline {sig.decl()} {{ TORCH_API inline {sig.decl()} {{
@ -114,7 +118,7 @@ TORCH_API inline {sig.decl()} {{
else: else:
return static_dispatch( return static_dispatch(
ExecutorchCppSignature.from_native_function(f), sig,
f, f,
backend_indices=self.static_dispatch_backend_indices, backend_indices=self.static_dispatch_backend_indices,
) )
@ -280,9 +284,12 @@ def gen_headers(
cpu_fm.write( cpu_fm.write(
"Functions.h", "Functions.h",
lambda: { lambda: {
"static_dispatch_extra_headers": "#include <ATen/Functions.h>" "static_dispatch_extra_headers": [
'#include "CustomOpsNativeFunctions.h"',
"#include <ATen/Functions.h>",
]
if use_aten_lib if use_aten_lib
else '#include "NativeFunctions.h"', else ['#include "NativeFunctions.h"'],
"Functions_declarations": gen_functions_declarations( "Functions_declarations": gen_functions_declarations(
native_functions=native_functions, native_functions=native_functions,
static_dispatch_idx=static_dispatch_idx, static_dispatch_idx=static_dispatch_idx,
@ -314,7 +321,6 @@ def gen_custom_ops(
cpu_fm: FileManager, cpu_fm: FileManager,
rocm: bool, rocm: bool,
) -> None: ) -> None:
dispatch_key = DispatchKey.CPU dispatch_key = DispatchKey.CPU
backend_index = backend_indices[dispatch_key] backend_index = backend_indices[dispatch_key]
( (
@ -326,11 +332,22 @@ def gen_custom_ops(
backend_index=backend_index, backend_index=backend_index,
rocm=rocm, rocm=rocm,
) )
cpu_fm.write_with_template(
"CustomOpsNativeFunctions.h",
"NativeFunctions.h",
lambda: {
"nativeFunctions_declarations": get_native_function_declarations(
grouped_native_functions=native_functions,
backend_indices=backend_indices,
native_function_decl_gen=dest.compute_native_function_declaration,
),
},
)
cpu_fm.write_with_template( cpu_fm.write_with_template(
f"Register{dispatch_key}CustomOps.cpp", f"Register{dispatch_key}CustomOps.cpp",
"RegisterDispatchKeyCustomOps.cpp", "RegisterDispatchKeyCustomOps.cpp",
lambda: { lambda: {
"ops_headers": '#include "NativeFunctions.h"', "ops_headers": '#include "CustomOpsNativeFunctions.h"',
"DispatchKey": dispatch_key, "DispatchKey": dispatch_key,
"dispatch_namespace": dispatch_key.lower(), "dispatch_namespace": dispatch_key.lower(),
"dispatch_namespaced_definitions": "", "dispatch_namespaced_definitions": "",
@ -482,35 +499,35 @@ def parse_yaml_files(
native_yaml_path = os.path.join(tmpdirname, "functions.yaml") native_yaml_path = os.path.join(tmpdirname, "functions.yaml")
with open(native_yaml_path, "w"): with open(native_yaml_path, "w"):
pass pass
# Translate native_yaml_path to the same format of native_functions.yaml
# If custom_ops_yaml_path exists, combine both files.
if custom_ops_yaml_path and os.path.exists(custom_ops_yaml_path):
combined_yaml_path = os.path.join(tmpdirname, "combined.yaml")
with open(combined_yaml_path, "w") as tmp:
with open(native_yaml_path, "r") as native:
for line in native:
tmp.write(line)
with open(custom_ops_yaml_path, "r") as custom:
for line in custom:
tmp.write(line)
custom_ops_parsed_yaml = parse_native_yaml(
custom_ops_yaml_path, tags_yaml_path, None, skip_native_fns_gen=True
)
else:
# No custom_ops; just parse native_yaml_path.
custom_ops_parsed_yaml = None
combined_yaml_path = native_yaml_path
translated_yaml_path = os.path.join(tmpdirname, "translated.yaml") translated_yaml_path = os.path.join(tmpdirname, "translated.yaml")
with open(translated_yaml_path, "w") as translated: with open(translated_yaml_path, "w") as translated:
translate_native_yaml( translate_native_yaml(
tags_yaml_path, tags_yaml_path,
aten_yaml_path, aten_yaml_path,
combined_yaml_path, native_yaml_path,
use_aten_lib, use_aten_lib,
translated, translated,
) )
# If custom_ops_yaml_path doesn't exist, point to an empty file.
if not custom_ops_yaml_path or not os.path.exists(custom_ops_yaml_path):
custom_ops_yaml_path = os.path.join(tmpdirname, "custom_ops.yaml")
with open(custom_ops_yaml_path, "w"):
pass
combined_yaml_path = os.path.join(tmpdirname, "combined.yaml")
with open(combined_yaml_path, "w") as tmp, open(
translated_yaml_path, "r"
) as native, open(custom_ops_yaml_path, "r") as custom:
for line in native.readlines():
tmp.write(line)
for line in custom.readlines():
tmp.write(line)
custom_ops_parsed_yaml = parse_native_yaml(
custom_ops_yaml_path, tags_yaml_path, None, skip_native_fns_gen=True
)
parsed_yaml = parse_native_yaml( parsed_yaml = parse_native_yaml(
translated_yaml_path, combined_yaml_path,
tags_yaml_path, tags_yaml_path,
None, None,
skip_native_fns_gen=(not gen_native_fns), skip_native_fns_gen=(not gen_native_fns),