[torchgen] Add CI job to cover custom ops registration for Executorch (#91291)

As titled. To register a custom op into Executorch, we need:

* `custom_ops.yaml`, which defines the operator schema and names the corresponding native function.
* `custom_ops.cpp`, which implements the kernel.
* `RegisterDispatchKeyCustomOps.cpp`, a template for registering the operator into PyTorch.

Added a new test for custom ops: the custom op `custom::add_3.out` takes three tensors and adds them together. The test makes sure the op is registered correctly and then verifies that its output is correct.
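For orientation, the three pieces connect as follows: the schema in `custom_ops.yaml` is registered with the dispatcher, and its `dispatch` entry binds that schema to the kernel in `custom_ops.cpp`. Below is a minimal sketch of calling the kernel directly, bypassing the registry (hypothetical usage; the declaration mirrors the kernel added in this PR):

```cpp
#include <ATen/ATen.h>

// Declaration matching the kernel defined in test/edge/custom_ops.cpp below.
namespace custom {
namespace native {
at::Tensor& add_3_out(const at::Tensor& a, const at::Tensor& b,
                      const at::Tensor& c, at::Tensor& out);
} // namespace native
} // namespace custom

int main() {
  at::Tensor a = at::ones({2, 3});
  at::Tensor b = at::ones({2, 3});
  at::Tensor c = at::ones({2, 3});
  at::Tensor out = at::zeros({2, 3});
  custom::native::add_3_out(a, b, c, out);  // every element of out is now 3
  return 0;
}
```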

Differential Revision: [D42204263](https://our.internmc.facebook.com/intern/diff/D42204263/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91291
Approved by: https://github.com/ezyang
Larry Liu, 2023-01-13 23:10:59 +00:00, committed by PyTorch MergeBot
commit 7568484d54 (parent 66b324cf06)
10 changed files with 133 additions and 34 deletions


@@ -14,11 +14,14 @@ set(GEN_COMMAND
     --aten_yaml_path=${TORCH_ROOT}/aten/src/ATen/native/native_functions.yaml
     --use_aten_lib
     --op_selection_yaml_path=${TEST_ROOT}/selected_operators.yaml
+    --custom_ops_yaml_path=${TEST_ROOT}/custom_ops.yaml
 )
 set(GEN_COMMAND_sources
     ${OUTPUT_DIRECTORY}/RegisterCodegenUnboxedKernelsEverything.cpp
+    ${OUTPUT_DIRECTORY}/RegisterCPUCustomOps.cpp
     ${OUTPUT_DIRECTORY}/Functions.h
     ${OUTPUT_DIRECTORY}/NativeFunctions.h
+    ${OUTPUT_DIRECTORY}/CustomOpsNativeFunctions.h
 )
 message(STATUS "Generating sources for unboxing kernels ${GEN_COMMAND}")
 add_custom_command(
@@ -32,6 +35,7 @@ add_custom_command(
     ${TEST_ROOT}/templates/Functions.h
     ${TEST_ROOT}/templates/NativeFunctions.h
     ${TEST_ROOT}/templates/RegisterCodegenUnboxedKernels.cpp
+    ${TEST_ROOT}/templates/RegisterDispatchKeyCustomOps.cpp
     WORKING_DIRECTORY ${TORCH_ROOT}
 )
 add_custom_target(unbox_target DEPENDS ${GEN_COMMAND_sources})
@@ -39,6 +43,7 @@ add_custom_target(unbox_target DEPENDS ${GEN_COMMAND_sources})
 add_library(unbox_lib STATIC
     ${GEN_COMMAND_sources}
     ${TEST_ROOT}/operator_registry.cpp
+    ${TEST_ROOT}/custom_ops.cpp
 )
 target_include_directories(unbox_lib PUBLIC ${TEST_ROOT} ${ATen_CPU_INCLUDE})
 target_link_libraries(unbox_lib PUBLIC torch_cpu)

test/edge/custom_ops.cpp (new file)

@@ -0,0 +1,10 @@
#include <ATen/Tensor.h>

namespace custom {
namespace native {
at::Tensor& add_3_out(const at::Tensor& a, const at::Tensor& b, const at::Tensor& c, at::Tensor& out) {
  out = a.add(b).add(c);
  return out;
}
} // namespace native
} // namespace custom
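Note that this kernel rebinds the `out` handle to a freshly allocated result rather than writing into the storage `out` already owns. A strictly in-place formulation would look roughly like this (a hypothetical alternative shown for contrast, not part of this commit):

```cpp
#include <ATen/ATen.h>

namespace custom {
namespace native {
// Hypothetical variant that writes into out's existing storage.
at::Tensor& add_3_out_strict(const at::Tensor& a, const at::Tensor& b,
                             const at::Tensor& c, at::Tensor& out) {
  at::add_out(out, a, b);  // out = a + b, written in place
  out.add_(c);             // out += c, in place
  return out;
}
} // namespace native
} // namespace custom
```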


@@ -0,0 +1,3 @@
- func: custom::add_3.out(Tensor a, Tensor b, Tensor c, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU: custom::add_3_out
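For reference, this yaml entry amounts to a dispatcher registration along the following lines (a hand-written sketch; it assumes torchgen's convention of resolving the `custom::add_3_out` dispatch entry to a kernel in the `custom::native` namespace):

```cpp
#include <ATen/Tensor.h>
#include <torch/library.h>

namespace custom {
namespace native {
// Forward declaration of the kernel from custom_ops.cpp.
at::Tensor& add_3_out(const at::Tensor& a, const at::Tensor& b,
                      const at::Tensor& c, at::Tensor& out);
} // namespace native
} // namespace custom

// Schema registration: the (a!) annotation marks `out` as mutable and
// aliased with the returned tensor, i.e. this is an out variant.
TORCH_LIBRARY_FRAGMENT(custom, m) {
  m.def("add_3.out(Tensor a, Tensor b, Tensor c, *, Tensor(a!) out) -> Tensor(a!)");
}

// Kernel registration for the CPU dispatch key.
TORCH_LIBRARY_IMPL(custom, CPU, m) {
  m.impl("add_3.out", TORCH_FN(custom::native::add_3_out));
}
```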


@@ -448,3 +448,9 @@ operators:
     include_all_overloads: false
     is_root_operator: true
     is_used_for_training: true
+  custom::add_3.out:
+    debug_info:
+      - functions.yaml
+    include_all_overloads: false
+    is_root_operator: true
+    is_used_for_training: true


@@ -0,0 +1,27 @@
// clang-format off
// Generated code for registering custom operators into the dispatcher.
#include <torch/library.h>
#include <ATen/Tensor.h>
$ops_headers
namespace torch {
namespace executor {
namespace function {
${dispatch_anonymous_definitions}
// All out variants ops
${static_init_dispatch_registrations}
namespace ${dispatch_namespace}
{
${dispatch_namespaced_definitions}
} // namespace ${dispatch_namespace}
} // namespace function
} // namespace executor
} // namespace torch
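To make the placeholders concrete: for the `custom::add_3.out` op, the generated `RegisterCPUCustomOps.cpp` might look roughly like this after substitution (a sketch based on the values passed in `gen_custom_ops` below; the codegen's exact output may differ):

```cpp
// clang-format off
// Generated code for registering custom operators into the dispatcher.
#include <torch/library.h>
#include <ATen/Tensor.h>
#include "CustomOpsNativeFunctions.h"  // substituted for $ops_headers

namespace torch {
namespace executor {
namespace function {

// All out variants ops
// Plausible expansion of ${static_init_dispatch_registrations}:
TORCH_LIBRARY_IMPL(custom, CPU, m) {
  m.impl("add_3.out", TORCH_FN(custom::native::add_3_out));
}

namespace cpu {
// ${dispatch_namespaced_definitions} is empty for this test.
} // namespace cpu

} // namespace function
} // namespace executor
} // namespace torch
```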


@@ -0,0 +1,10 @@
// ${generated_comment}
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <torch/library.h>
namespace at {
TORCH_LIBRARY_FRAGMENT(aten, m) {
${aten_schema_registrations};
}
$schema_registrations
} // namespace at


@@ -23,6 +23,27 @@ TEST(OperatorRegistrationTest, Add) {
   expected = at::fill(expected, 2);
   ASSERT_TRUE(expected.equal(kernel_values[3]->toTensor()));
 }

+// custom::add_3.out(Tensor a, Tensor b, Tensor c, *, Tensor(a!) out) -> Tensor(a!)
+TEST(OperatorRegistrationTest, CustomAdd3) {
+  EValue values[4];
+  values[0] = EValue(at::ones({2, 3}));
+  values[1] = EValue(at::ones({2, 3}));
+  values[2] = EValue(at::ones({2, 3}));
+  values[3] = EValue(at::zeros({2, 3}));
+
+  ASSERT_TRUE(hasOpsFn("custom::add_3.out"));
+  auto op = getOpsFn("custom::add_3.out");
+
+  EValue* kernel_values[4];
+  for (size_t i = 0; i < 4; i++) {
+    kernel_values[i] = &values[i];
+  }
+  op(kernel_values);
+
+  at::Tensor expected = at::ones({2, 3});
+  expected = at::fill(expected, 3);
+  ASSERT_TRUE(expected.equal(kernel_values[3]->toTensor()));
+}
+
 } // namespace executor
 } // namespace torch
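The registry call above is boxed: every argument, including the out tensor, is wrapped in an `EValue`, and the op function receives a stack of `EValue` pointers. A hypothetical unboxing wrapper for `custom::add_3.out` might look like the following (a sketch of what the unboxing codegen amounts to, not code from this commit):

```cpp
// Hypothetical unboxing wrapper: unpack the EValue stack, call the kernel,
// and box the result back into the out slot.
void add_3_out_boxed(EValue** stack) {
  const at::Tensor a = stack[0]->toTensor();
  const at::Tensor b = stack[1]->toTensor();
  const at::Tensor c = stack[2]->toTensor();
  at::Tensor out = stack[3]->toTensor();
  custom::native::add_3_out(a, b, c, out);
  *stack[3] = EValue(out);
}
```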


@@ -40,7 +40,7 @@ class TestComputeNativeFunctionStub(expecttest.TestCase):
     def test_function_schema_generates_correct_kernel_tensor_out(self) -> None:
         obj = {"func": "custom::foo.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)"}
         expected = """
-at::Tensor & wrapper_out_foo_out(const at::Tensor & self, at::Tensor & out) {
+at::Tensor & wrapper_CPU_out_foo_out(const at::Tensor & self, at::Tensor & out) {
     return out;
 }
 """
@@ -49,7 +49,7 @@ at::Tensor & wrapper_out_foo_out(const at::Tensor & self, at::Tensor & out) {
     def test_function_schema_generates_correct_kernel_no_out(self) -> None:
         obj = {"func": "custom::foo.Tensor(Tensor self) -> Tensor"}
         expected = """
-at::Tensor wrapper_Tensor_foo(const at::Tensor & self) {
+at::Tensor wrapper_CPU_Tensor_foo(const at::Tensor & self) {
     return self;
 }
 """
@@ -58,7 +58,7 @@ at::Tensor wrapper_Tensor_foo(const at::Tensor & self) {
     def test_function_schema_generates_correct_kernel_no_return(self) -> None:
         obj = {"func": "custom::foo(Tensor self, *, Tensor(a!)[] out) -> ()"}
         expected = f"""
-void wrapper__foo_out(const at::Tensor & self, at::TensorList out) {{
+void wrapper_CPU__foo_out(const at::Tensor & self, at::TensorList out) {{
 {SPACES}
 }}
 """


@@ -23,7 +23,7 @@ class ComputeNativeFunctionStub:
             return None

         sig = DispatcherSignature.from_schema(
-            f.func, prefix=f"wrapper_{f.func.name.overload_name}_", symint=False
+            f.func, prefix=f"wrapper_CPU_{f.func.name.overload_name}_", symint=False
         )
         assert sig is not None
         if len(f.func.returns) == 0:

@@ -46,7 +46,7 @@ from torchgen.utils import (

 def static_dispatch(
-    sig: ExecutorchCppSignature,
+    sig: Union[CppSignature, ExecutorchCppSignature],
     f: NativeFunction,
     backend_indices: List[BackendIndex],
 ) -> str:
@@ -99,12 +99,16 @@ class ComputeFunction:
             return None
         if Variant.function not in f.variants:
             return None

-        if self.use_aten_lib:
-            comma = ", "
-            sig = CppSignatureGroup.from_native_function(
+        sig: Union[CppSignature, ExecutorchCppSignature] = (
+            CppSignatureGroup.from_native_function(
                 f, method=False, fallback_binding=f.manual_cpp_binding
             ).most_faithful_signature()
+            if self.use_aten_lib
+            else ExecutorchCppSignature.from_native_function(f)
+        )
+        if self.use_aten_lib and f.namespace == "aten":
+            comma = ", "
             return f"""
 // {f.namespace}::{f.func}
 TORCH_API inline {sig.decl()} {{
@@ -114,7 +118,7 @@ TORCH_API inline {sig.decl()} {{

         else:
             return static_dispatch(
-                ExecutorchCppSignature.from_native_function(f),
+                sig,
                 f,
                 backend_indices=self.static_dispatch_backend_indices,
             )
@@ -280,9 +284,12 @@ def gen_headers(
     cpu_fm.write(
         "Functions.h",
         lambda: {
-            "static_dispatch_extra_headers": "#include <ATen/Functions.h>"
+            "static_dispatch_extra_headers": [
+                '#include "CustomOpsNativeFunctions.h"',
+                "#include <ATen/Functions.h>",
+            ]
             if use_aten_lib
-            else '#include "NativeFunctions.h"',
+            else ['#include "NativeFunctions.h"'],
             "Functions_declarations": gen_functions_declarations(
                 native_functions=native_functions,
                 static_dispatch_idx=static_dispatch_idx,
@@ -314,7 +321,6 @@ def gen_custom_ops(
     cpu_fm: FileManager,
     rocm: bool,
 ) -> None:
-
     dispatch_key = DispatchKey.CPU
     backend_index = backend_indices[dispatch_key]
     (
@@ -326,11 +332,22 @@ def gen_custom_ops(
         backend_index=backend_index,
         rocm=rocm,
     )
+
+    cpu_fm.write_with_template(
+        "CustomOpsNativeFunctions.h",
+        "NativeFunctions.h",
+        lambda: {
+            "nativeFunctions_declarations": get_native_function_declarations(
+                grouped_native_functions=native_functions,
+                backend_indices=backend_indices,
+                native_function_decl_gen=dest.compute_native_function_declaration,
+            ),
+        },
+    )
     cpu_fm.write_with_template(
         f"Register{dispatch_key}CustomOps.cpp",
         "RegisterDispatchKeyCustomOps.cpp",
         lambda: {
-            "ops_headers": '#include "NativeFunctions.h"',
+            "ops_headers": '#include "CustomOpsNativeFunctions.h"',
             "DispatchKey": dispatch_key,
             "dispatch_namespace": dispatch_key.lower(),
             "dispatch_namespaced_definitions": "",
@@ -482,35 +499,35 @@ def parse_yaml_files(
         native_yaml_path = os.path.join(tmpdirname, "functions.yaml")
         with open(native_yaml_path, "w"):
             pass
-        # If custom_ops_yaml_path exists, combine both files.
-        if custom_ops_yaml_path and os.path.exists(custom_ops_yaml_path):
-            combined_yaml_path = os.path.join(tmpdirname, "combined.yaml")
-            with open(combined_yaml_path, "w") as tmp:
-                with open(native_yaml_path, "r") as native:
-                    for line in native:
-                        tmp.write(line)
-                with open(custom_ops_yaml_path, "r") as custom:
-                    for line in custom:
-                        tmp.write(line)
-            custom_ops_parsed_yaml = parse_native_yaml(
-                custom_ops_yaml_path, tags_yaml_path, None, skip_native_fns_gen=True
-            )
-        else:
-            # No custom_ops; just parse native_yaml_path.
-            custom_ops_parsed_yaml = None
-            combined_yaml_path = native_yaml_path
         # Translate native_yaml_path to the same format of native_functions.yaml
         translated_yaml_path = os.path.join(tmpdirname, "translated.yaml")
         with open(translated_yaml_path, "w") as translated:
             translate_native_yaml(
                 tags_yaml_path,
                 aten_yaml_path,
-                combined_yaml_path,
+                native_yaml_path,
                 use_aten_lib,
                 translated,
             )
+
+        # If custom_ops_yaml_path doesn't exist, point to an empty file.
+        if not custom_ops_yaml_path or not os.path.exists(custom_ops_yaml_path):
+            custom_ops_yaml_path = os.path.join(tmpdirname, "custom_ops.yaml")
+            with open(custom_ops_yaml_path, "w"):
+                pass
+        combined_yaml_path = os.path.join(tmpdirname, "combined.yaml")
+        with open(combined_yaml_path, "w") as tmp, open(
+            translated_yaml_path, "r"
+        ) as native, open(custom_ops_yaml_path, "r") as custom:
+            for line in native.readlines():
+                tmp.write(line)
+            for line in custom.readlines():
+                tmp.write(line)
+        custom_ops_parsed_yaml = parse_native_yaml(
+            custom_ops_yaml_path, tags_yaml_path, None, skip_native_fns_gen=True
+        )
         parsed_yaml = parse_native_yaml(
-            translated_yaml_path,
+            combined_yaml_path,
             tags_yaml_path,
             None,
             skip_native_fns_gen=(not gen_native_fns),