mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[torchgen] Add CI job to cover custom ops registration for Executorch (#91291)
As titled. To register a custom op into Executorch, we need: * `custom_ops.yaml`, defines the operator schema and the corresponding native function. * `custom_ops.cpp`, defines the kernel. * `RegisterDispatchKeyCustomOps.cpp`, a template to register operator into PyTorch. Added a new test for custom ops. The custom op `custom::add_3.out` takes 3 tensors and add them together. The test makes sure it is registered correctly and then verifies the outcome is correct. Differential Revision: [D42204263](https://our.internmc.facebook.com/intern/diff/D42204263/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/91291 Approved by: https://github.com/ezyang
This commit is contained in:
parent
66b324cf06
commit
7568484d54
|
|
@ -14,11 +14,14 @@ set(GEN_COMMAND
|
|||
--aten_yaml_path=${TORCH_ROOT}/aten/src/ATen/native/native_functions.yaml
|
||||
--use_aten_lib
|
||||
--op_selection_yaml_path=${TEST_ROOT}/selected_operators.yaml
|
||||
--custom_ops_yaml_path=${TEST_ROOT}/custom_ops.yaml
|
||||
)
|
||||
set(GEN_COMMAND_sources
|
||||
${OUTPUT_DIRECTORY}/RegisterCodegenUnboxedKernelsEverything.cpp
|
||||
${OUTPUT_DIRECTORY}/RegisterCPUCustomOps.cpp
|
||||
${OUTPUT_DIRECTORY}/Functions.h
|
||||
${OUTPUT_DIRECTORY}/NativeFunctions.h
|
||||
${OUTPUT_DIRECTORY}/CustomOpsNativeFunctions.h
|
||||
)
|
||||
message(STATUS "Generating sources for unboxing kernels ${GEN_COMMAND}")
|
||||
add_custom_command(
|
||||
|
|
@ -32,6 +35,7 @@ add_custom_command(
|
|||
${TEST_ROOT}/templates/Functions.h
|
||||
${TEST_ROOT}/templates/NativeFunctions.h
|
||||
${TEST_ROOT}/templates/RegisterCodegenUnboxedKernels.cpp
|
||||
${TEST_ROOT}/templates/RegisterDispatchKeyCustomOps.cpp
|
||||
WORKING_DIRECTORY ${TORCH_ROOT}
|
||||
)
|
||||
add_custom_target(unbox_target DEPENDS ${GEN_COMMAND_sources})
|
||||
|
|
@ -39,6 +43,7 @@ add_custom_target(unbox_target DEPENDS ${GEN_COMMAND_sources})
|
|||
add_library(unbox_lib STATIC
|
||||
${GEN_COMMAND_sources}
|
||||
${TEST_ROOT}/operator_registry.cpp
|
||||
${TEST_ROOT}/custom_ops.cpp
|
||||
)
|
||||
target_include_directories(unbox_lib PUBLIC ${TEST_ROOT} ${ATen_CPU_INCLUDE})
|
||||
target_link_libraries(unbox_lib PUBLIC torch_cpu)
|
||||
|
|
|
|||
10
test/edge/custom_ops.cpp
Normal file
10
test/edge/custom_ops.cpp
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
#include <ATen/Tensor.h>
|
||||
|
||||
namespace custom {
|
||||
namespace native {
|
||||
at::Tensor& add_3_out(const at::Tensor& a, const at::Tensor& b, const at::Tensor& c, at::Tensor& out) {
|
||||
out = a.add(b).add(c);
|
||||
return out;
|
||||
}
|
||||
}
|
||||
}
|
||||
3
test/edge/custom_ops.yaml
Normal file
3
test/edge/custom_ops.yaml
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
- func: custom::add_3.out(Tensor a, Tensor b, Tensor c, *, Tensor(a!) out) -> Tensor(a!)
|
||||
dispatch:
|
||||
CPU: custom::add_3_out
|
||||
|
|
@ -448,3 +448,9 @@ operators:
|
|||
include_all_overloads: false
|
||||
is_root_operator: true
|
||||
is_used_for_training: true
|
||||
custom::add_3.out:
|
||||
debug_info:
|
||||
- functions.yaml
|
||||
include_all_overloads: false
|
||||
is_root_operator: true
|
||||
is_used_for_training: true
|
||||
|
|
|
|||
27
test/edge/templates/RegisterDispatchKeyCustomOps.cpp
Normal file
27
test/edge/templates/RegisterDispatchKeyCustomOps.cpp
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
// clang-format off
|
||||
// Generated code for registering custom operators into the dispatcher.
|
||||
|
||||
#include <torch/library.h>
|
||||
#include <ATen/Tensor.h>
|
||||
|
||||
$ops_headers
|
||||
|
||||
namespace torch {
|
||||
namespace executor {
|
||||
namespace function {
|
||||
|
||||
|
||||
${dispatch_anonymous_definitions}
|
||||
|
||||
// All out variants ops
|
||||
${static_init_dispatch_registrations}
|
||||
|
||||
namespace ${dispatch_namespace}
|
||||
{
|
||||
${dispatch_namespaced_definitions}
|
||||
|
||||
} // namespace ${dispatch_namespace}
|
||||
|
||||
} // namespace function
|
||||
} // namespace executor
|
||||
} // namespace torch
|
||||
10
test/edge/templates/RegisterSchema.cpp
Normal file
10
test/edge/templates/RegisterSchema.cpp
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
// ${generated_comment}
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <torch/library.h>
|
||||
|
||||
namespace at {
|
||||
TORCH_LIBRARY_FRAGMENT(aten, m) {
|
||||
${aten_schema_registrations};
|
||||
}
|
||||
$schema_registrations
|
||||
} // namespace at
|
||||
|
|
@ -23,6 +23,27 @@ TEST(OperatorRegistrationTest, Add) {
|
|||
expected = at::fill(expected, 2);
|
||||
ASSERT_TRUE(expected.equal(kernel_values[3]->toTensor()));
|
||||
|
||||
}
|
||||
|
||||
// custom::add_3.out(Tensor a, Tensor b, Tensor c, *, Tensor(a!) out) -> Tensor(a!)
|
||||
TEST(OperatorRegistrationTest, CustomAdd3) {
|
||||
EValue values[4];
|
||||
values[0] = EValue(at::ones({2, 3}));
|
||||
values[1] = EValue(at::ones({2, 3}));
|
||||
values[2] = EValue(at::ones({2, 3}));
|
||||
values[3] = EValue(at::zeros({2, 3}));
|
||||
ASSERT_TRUE(hasOpsFn("custom::add_3.out"));
|
||||
auto op = getOpsFn("custom::add_3.out");
|
||||
|
||||
EValue* kernel_values[4];
|
||||
for (size_t i = 0; i < 4; i++) {
|
||||
kernel_values[i] = &values[i];
|
||||
}
|
||||
op(kernel_values);
|
||||
at::Tensor expected = at::ones({2, 3});
|
||||
expected = at::fill(expected, 3);
|
||||
ASSERT_TRUE(expected.equal(kernel_values[3]->toTensor()));
|
||||
|
||||
}
|
||||
} // namespace executor
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ class TestComputeNativeFunctionStub(expecttest.TestCase):
|
|||
def test_function_schema_generates_correct_kernel_tensor_out(self) -> None:
|
||||
obj = {"func": "custom::foo.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)"}
|
||||
expected = """
|
||||
at::Tensor & wrapper_out_foo_out(const at::Tensor & self, at::Tensor & out) {
|
||||
at::Tensor & wrapper_CPU_out_foo_out(const at::Tensor & self, at::Tensor & out) {
|
||||
return out;
|
||||
}
|
||||
"""
|
||||
|
|
@ -49,7 +49,7 @@ at::Tensor & wrapper_out_foo_out(const at::Tensor & self, at::Tensor & out) {
|
|||
def test_function_schema_generates_correct_kernel_no_out(self) -> None:
|
||||
obj = {"func": "custom::foo.Tensor(Tensor self) -> Tensor"}
|
||||
expected = """
|
||||
at::Tensor wrapper_Tensor_foo(const at::Tensor & self) {
|
||||
at::Tensor wrapper_CPU_Tensor_foo(const at::Tensor & self) {
|
||||
return self;
|
||||
}
|
||||
"""
|
||||
|
|
@ -58,7 +58,7 @@ at::Tensor wrapper_Tensor_foo(const at::Tensor & self) {
|
|||
def test_function_schema_generates_correct_kernel_no_return(self) -> None:
|
||||
obj = {"func": "custom::foo(Tensor self, *, Tensor(a!)[] out) -> ()"}
|
||||
expected = f"""
|
||||
void wrapper__foo_out(const at::Tensor & self, at::TensorList out) {{
|
||||
void wrapper_CPU__foo_out(const at::Tensor & self, at::TensorList out) {{
|
||||
{SPACES}
|
||||
}}
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ class ComputeNativeFunctionStub:
|
|||
return None
|
||||
|
||||
sig = DispatcherSignature.from_schema(
|
||||
f.func, prefix=f"wrapper_{f.func.name.overload_name}_", symint=False
|
||||
f.func, prefix=f"wrapper_CPU_{f.func.name.overload_name}_", symint=False
|
||||
)
|
||||
assert sig is not None
|
||||
if len(f.func.returns) == 0:
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ from torchgen.utils import (
|
|||
|
||||
|
||||
def static_dispatch(
|
||||
sig: ExecutorchCppSignature,
|
||||
sig: Union[CppSignature, ExecutorchCppSignature],
|
||||
f: NativeFunction,
|
||||
backend_indices: List[BackendIndex],
|
||||
) -> str:
|
||||
|
|
@ -99,12 +99,16 @@ class ComputeFunction:
|
|||
return None
|
||||
if Variant.function not in f.variants:
|
||||
return None
|
||||
|
||||
if self.use_aten_lib:
|
||||
comma = ", "
|
||||
sig = CppSignatureGroup.from_native_function(
|
||||
sig: Union[CppSignature, ExecutorchCppSignature] = (
|
||||
CppSignatureGroup.from_native_function(
|
||||
f, method=False, fallback_binding=f.manual_cpp_binding
|
||||
).most_faithful_signature()
|
||||
if self.use_aten_lib
|
||||
else ExecutorchCppSignature.from_native_function(f)
|
||||
)
|
||||
if self.use_aten_lib and f.namespace == "aten":
|
||||
comma = ", "
|
||||
|
||||
return f"""
|
||||
// {f.namespace}::{f.func}
|
||||
TORCH_API inline {sig.decl()} {{
|
||||
|
|
@ -114,7 +118,7 @@ TORCH_API inline {sig.decl()} {{
|
|||
|
||||
else:
|
||||
return static_dispatch(
|
||||
ExecutorchCppSignature.from_native_function(f),
|
||||
sig,
|
||||
f,
|
||||
backend_indices=self.static_dispatch_backend_indices,
|
||||
)
|
||||
|
|
@ -280,9 +284,12 @@ def gen_headers(
|
|||
cpu_fm.write(
|
||||
"Functions.h",
|
||||
lambda: {
|
||||
"static_dispatch_extra_headers": "#include <ATen/Functions.h>"
|
||||
"static_dispatch_extra_headers": [
|
||||
'#include "CustomOpsNativeFunctions.h"',
|
||||
"#include <ATen/Functions.h>",
|
||||
]
|
||||
if use_aten_lib
|
||||
else '#include "NativeFunctions.h"',
|
||||
else ['#include "NativeFunctions.h"'],
|
||||
"Functions_declarations": gen_functions_declarations(
|
||||
native_functions=native_functions,
|
||||
static_dispatch_idx=static_dispatch_idx,
|
||||
|
|
@ -314,7 +321,6 @@ def gen_custom_ops(
|
|||
cpu_fm: FileManager,
|
||||
rocm: bool,
|
||||
) -> None:
|
||||
|
||||
dispatch_key = DispatchKey.CPU
|
||||
backend_index = backend_indices[dispatch_key]
|
||||
(
|
||||
|
|
@ -326,11 +332,22 @@ def gen_custom_ops(
|
|||
backend_index=backend_index,
|
||||
rocm=rocm,
|
||||
)
|
||||
cpu_fm.write_with_template(
|
||||
"CustomOpsNativeFunctions.h",
|
||||
"NativeFunctions.h",
|
||||
lambda: {
|
||||
"nativeFunctions_declarations": get_native_function_declarations(
|
||||
grouped_native_functions=native_functions,
|
||||
backend_indices=backend_indices,
|
||||
native_function_decl_gen=dest.compute_native_function_declaration,
|
||||
),
|
||||
},
|
||||
)
|
||||
cpu_fm.write_with_template(
|
||||
f"Register{dispatch_key}CustomOps.cpp",
|
||||
"RegisterDispatchKeyCustomOps.cpp",
|
||||
lambda: {
|
||||
"ops_headers": '#include "NativeFunctions.h"',
|
||||
"ops_headers": '#include "CustomOpsNativeFunctions.h"',
|
||||
"DispatchKey": dispatch_key,
|
||||
"dispatch_namespace": dispatch_key.lower(),
|
||||
"dispatch_namespaced_definitions": "",
|
||||
|
|
@ -482,35 +499,35 @@ def parse_yaml_files(
|
|||
native_yaml_path = os.path.join(tmpdirname, "functions.yaml")
|
||||
with open(native_yaml_path, "w"):
|
||||
pass
|
||||
|
||||
# If custom_ops_yaml_path exists, combine both files.
|
||||
if custom_ops_yaml_path and os.path.exists(custom_ops_yaml_path):
|
||||
combined_yaml_path = os.path.join(tmpdirname, "combined.yaml")
|
||||
with open(combined_yaml_path, "w") as tmp:
|
||||
with open(native_yaml_path, "r") as native:
|
||||
for line in native:
|
||||
tmp.write(line)
|
||||
with open(custom_ops_yaml_path, "r") as custom:
|
||||
for line in custom:
|
||||
tmp.write(line)
|
||||
custom_ops_parsed_yaml = parse_native_yaml(
|
||||
custom_ops_yaml_path, tags_yaml_path, None, skip_native_fns_gen=True
|
||||
)
|
||||
else:
|
||||
# No custom_ops; just parse native_yaml_path.
|
||||
custom_ops_parsed_yaml = None
|
||||
combined_yaml_path = native_yaml_path
|
||||
# Translate native_yaml_path to the same format of native_functions.yaml
|
||||
translated_yaml_path = os.path.join(tmpdirname, "translated.yaml")
|
||||
with open(translated_yaml_path, "w") as translated:
|
||||
translate_native_yaml(
|
||||
tags_yaml_path,
|
||||
aten_yaml_path,
|
||||
combined_yaml_path,
|
||||
native_yaml_path,
|
||||
use_aten_lib,
|
||||
translated,
|
||||
)
|
||||
# If custom_ops_yaml_path doesn't exist, point to an empty file.
|
||||
if not custom_ops_yaml_path or not os.path.exists(custom_ops_yaml_path):
|
||||
custom_ops_yaml_path = os.path.join(tmpdirname, "custom_ops.yaml")
|
||||
with open(custom_ops_yaml_path, "w"):
|
||||
pass
|
||||
combined_yaml_path = os.path.join(tmpdirname, "combined.yaml")
|
||||
with open(combined_yaml_path, "w") as tmp, open(
|
||||
translated_yaml_path, "r"
|
||||
) as native, open(custom_ops_yaml_path, "r") as custom:
|
||||
for line in native.readlines():
|
||||
tmp.write(line)
|
||||
for line in custom.readlines():
|
||||
tmp.write(line)
|
||||
custom_ops_parsed_yaml = parse_native_yaml(
|
||||
custom_ops_yaml_path, tags_yaml_path, None, skip_native_fns_gen=True
|
||||
)
|
||||
|
||||
parsed_yaml = parse_native_yaml(
|
||||
translated_yaml_path,
|
||||
combined_yaml_path,
|
||||
tags_yaml_path,
|
||||
None,
|
||||
skip_native_fns_gen=(not gen_native_fns),
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user