diff --git a/test/edge/CMakeLists.txt b/test/edge/CMakeLists.txt
index b38214ac41a..fa1e5720215 100644
--- a/test/edge/CMakeLists.txt
+++ b/test/edge/CMakeLists.txt
@@ -14,11 +14,14 @@ set(GEN_COMMAND
   --aten_yaml_path=${TORCH_ROOT}/aten/src/ATen/native/native_functions.yaml
   --use_aten_lib
   --op_selection_yaml_path=${TEST_ROOT}/selected_operators.yaml
+  --custom_ops_yaml_path=${TEST_ROOT}/custom_ops.yaml
 )
 set(GEN_COMMAND_sources
   ${OUTPUT_DIRECTORY}/RegisterCodegenUnboxedKernelsEverything.cpp
+  ${OUTPUT_DIRECTORY}/RegisterCPUCustomOps.cpp
   ${OUTPUT_DIRECTORY}/Functions.h
   ${OUTPUT_DIRECTORY}/NativeFunctions.h
+  ${OUTPUT_DIRECTORY}/CustomOpsNativeFunctions.h
 )
 message(STATUS "Generating sources for unboxing kernels ${GEN_COMMAND}")
 add_custom_command(
@@ -32,6 +35,7 @@ add_custom_command(
   ${TEST_ROOT}/templates/Functions.h
   ${TEST_ROOT}/templates/NativeFunctions.h
   ${TEST_ROOT}/templates/RegisterCodegenUnboxedKernels.cpp
+  ${TEST_ROOT}/templates/RegisterDispatchKeyCustomOps.cpp
   WORKING_DIRECTORY ${TORCH_ROOT}
 )
 add_custom_target(unbox_target DEPENDS ${GEN_COMMAND_sources})
@@ -39,6 +43,7 @@ add_custom_target(unbox_target DEPENDS ${GEN_COMMAND_sources})
 add_library(unbox_lib STATIC
   ${GEN_COMMAND_sources}
   ${TEST_ROOT}/operator_registry.cpp
+  ${TEST_ROOT}/custom_ops.cpp
 )
 target_include_directories(unbox_lib PUBLIC ${TEST_ROOT} ${ATen_CPU_INCLUDE})
 target_link_libraries(unbox_lib PUBLIC torch_cpu)
diff --git a/test/edge/custom_ops.cpp b/test/edge/custom_ops.cpp
new file mode 100644
index 00000000000..cce09841127
--- /dev/null
+++ b/test/edge/custom_ops.cpp
@@ -0,0 +1,21 @@
+// NOTE(review): the header name was lost when this patch was extracted;
+// <ATen/Tensor.h> is assumed here -- confirm against the repository.
+#include <ATen/Tensor.h>
+
+namespace custom {
+namespace native {
+// CPU kernel bound to custom::add_3.out in custom_ops.yaml.
+// Computes a + b + c, stores the result in `out`, and returns `out`.
+at::Tensor& add_3_out(const at::Tensor& a, const at::Tensor& b, const at::Tensor& c, at::Tensor& out) {
+  // Materialize the sum first so `out` may alias any of the inputs.
+  const at::Tensor sum = a.add(b).add(c);
+  // Write into the tensor the caller supplied. Plain assignment
+  // (`out = ...`) would only rebind the handle, leaving the caller's
+  // original storage (and any other alias of it) untouched.
+  out.resize_(sum.sizes());
+  out.copy_(sum);
+  return out;
+}
+} // namespace native
+} // namespace custom
diff --git a/test/edge/custom_ops.yaml b/test/edge/custom_ops.yaml
new file mode 100644
index 00000000000..b85fd12bd32
--- /dev/null
+++ b/test/edge/custom_ops.yaml
@@ -0,0 +1,3 @@
+- func: 
custom::add_3.out(Tensor a, Tensor b, Tensor c, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU: custom::add_3_out diff --git a/test/edge/selected_operators.yaml b/test/edge/selected_operators.yaml index d7833576640..b51ab66b8e5 100644 --- a/test/edge/selected_operators.yaml +++ b/test/edge/selected_operators.yaml @@ -448,3 +448,9 @@ operators: include_all_overloads: false is_root_operator: true is_used_for_training: true + custom::add_3.out: + debug_info: + - functions.yaml + include_all_overloads: false + is_root_operator: true + is_used_for_training: true diff --git a/test/edge/templates/RegisterDispatchKeyCustomOps.cpp b/test/edge/templates/RegisterDispatchKeyCustomOps.cpp new file mode 100644 index 00000000000..14c3d085f93 --- /dev/null +++ b/test/edge/templates/RegisterDispatchKeyCustomOps.cpp @@ -0,0 +1,27 @@ +// clang-format off +// Generated code for registering custom operators into the dispatcher. + +#include +#include + +$ops_headers + +namespace torch { +namespace executor { +namespace function { + + +${dispatch_anonymous_definitions} + +// All out variants ops +${static_init_dispatch_registrations} + +namespace ${dispatch_namespace} +{ + ${dispatch_namespaced_definitions} + +} // namespace ${dispatch_namespace} + +} // namespace function +} // namespace executor +} // namespace torch diff --git a/test/edge/templates/RegisterSchema.cpp b/test/edge/templates/RegisterSchema.cpp new file mode 100644 index 00000000000..f2ba92a4305 --- /dev/null +++ b/test/edge/templates/RegisterSchema.cpp @@ -0,0 +1,10 @@ +// ${generated_comment} +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include + +namespace at { +TORCH_LIBRARY_FRAGMENT(aten, m) { + ${aten_schema_registrations}; +} +$schema_registrations +} // namespace at diff --git a/test/edge/test_operator_registration.cpp b/test/edge/test_operator_registration.cpp index 9ab06898297..89aed23df28 100644 --- a/test/edge/test_operator_registration.cpp +++ b/test/edge/test_operator_registration.cpp @@ -23,6 +23,27 @@ 
TEST(OperatorRegistrationTest, Add) {
   expected = at::fill(expected, 2);
   ASSERT_TRUE(expected.equal(kernel_values[3]->toTensor()));
 
+}
+
+// custom::add_3.out(Tensor a, Tensor b, Tensor c, *, Tensor(a!) out) -> Tensor(a!)
+TEST(OperatorRegistrationTest, CustomAdd3) {
+  EValue values[4]; // slots 0..3 hold a, b, c, out -- same order as the schema above
+  values[0] = EValue(at::ones({2, 3}));
+  values[1] = EValue(at::ones({2, 3}));
+  values[2] = EValue(at::ones({2, 3}));
+  values[3] = EValue(at::zeros({2, 3})); // out starts at zero, so the final check cannot pass vacuously
+  ASSERT_TRUE(hasOpsFn("custom::add_3.out")); // guards the lookup below; presumably true only when the op is registered
+  auto op = getOpsFn("custom::add_3.out");
+
+  EValue* kernel_values[4]; // op is called with an array of EValue pointers
+  for (size_t i = 0; i < 4; i++) {
+    kernel_values[i] = &values[i];
+  }
+  op(kernel_values);
+  at::Tensor expected = at::ones({2, 3});
+  expected = at::fill(expected, 3); // 1 + 1 + 1 in every element
+  ASSERT_TRUE(expected.equal(kernel_values[3]->toTensor())); // result is read back from the out slot
+
 }
 } // namespace executor
 } // namespace torch
diff --git a/tools/test/test_executorch_custom_ops.py b/tools/test/test_executorch_custom_ops.py
index fedf13cb3ba..5ca261362aa 100644
--- a/tools/test/test_executorch_custom_ops.py
+++ b/tools/test/test_executorch_custom_ops.py
@@ -40,7 +40,7 @@ class TestComputeNativeFunctionStub(expecttest.TestCase):
     def test_function_schema_generates_correct_kernel_tensor_out(self) -> None:
         obj = {"func": "custom::foo.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)"} expected = """ -at::Tensor & wrapper_out_foo_out(const at::Tensor & self, at::Tensor & out) { +at::Tensor & wrapper_CPU_out_foo_out(const at::Tensor & self, at::Tensor & out) { return out; } """ @@ -49,7 +49,7 @@ at::Tensor & wrapper_out_foo_out(const at::Tensor & self, at::Tensor & out) { def test_function_schema_generates_correct_kernel_no_out(self) -> None: obj = {"func": "custom::foo.Tensor(Tensor self) -> Tensor"} expected = """ -at::Tensor wrapper_Tensor_foo(const at::Tensor & self) { +at::Tensor wrapper_CPU_Tensor_foo(const at::Tensor & self) { return self; } """ @@ -58,7 +58,7 @@ at::Tensor wrapper_Tensor_foo(const at::Tensor & self) { def test_function_schema_generates_correct_kernel_no_return(self) -> None: obj = {"func": "custom::foo(Tensor self, *, Tensor(a!)[] out) -> ()"} expected = f""" -void wrapper__foo_out(const at::Tensor & self, at::TensorList out) {{ +void wrapper_CPU__foo_out(const at::Tensor & self, at::TensorList out) {{ {SPACES} }} """ diff --git a/torchgen/executorch/api/custom_ops.py b/torchgen/executorch/api/custom_ops.py index a1bae4feca4..872158cd3da 100644 --- a/torchgen/executorch/api/custom_ops.py +++ b/torchgen/executorch/api/custom_ops.py @@ -23,7 +23,7 @@ class ComputeNativeFunctionStub: return None sig = DispatcherSignature.from_schema( - f.func, prefix=f"wrapper_{f.func.name.overload_name}_", symint=False + f.func, prefix=f"wrapper_CPU_{f.func.name.overload_name}_", symint=False ) assert sig is not None if len(f.func.returns) == 0: diff --git a/torchgen/gen_executorch.py b/torchgen/gen_executorch.py index 03769ee7a5f..47a7fb89ee5 100644 --- a/torchgen/gen_executorch.py +++ b/torchgen/gen_executorch.py @@ -46,7 +46,7 @@ from torchgen.utils import ( def static_dispatch( - sig: ExecutorchCppSignature, + sig: Union[CppSignature, ExecutorchCppSignature], f: NativeFunction, backend_indices: List[BackendIndex], ) -> str: @@ -99,12 +99,16 @@ class ComputeFunction: return None if Variant.function not in f.variants: 
return None - - if self.use_aten_lib: - comma = ", " - sig = CppSignatureGroup.from_native_function( + sig: Union[CppSignature, ExecutorchCppSignature] = ( + CppSignatureGroup.from_native_function( f, method=False, fallback_binding=f.manual_cpp_binding ).most_faithful_signature() + if self.use_aten_lib + else ExecutorchCppSignature.from_native_function(f) + ) + if self.use_aten_lib and f.namespace == "aten": + comma = ", " + return f""" // {f.namespace}::{f.func} TORCH_API inline {sig.decl()} {{ @@ -114,7 +118,7 @@ TORCH_API inline {sig.decl()} {{ else: return static_dispatch( - ExecutorchCppSignature.from_native_function(f), + sig, f, backend_indices=self.static_dispatch_backend_indices, ) @@ -280,9 +284,12 @@ def gen_headers( cpu_fm.write( "Functions.h", lambda: { - "static_dispatch_extra_headers": "#include " + "static_dispatch_extra_headers": [ + '#include "CustomOpsNativeFunctions.h"', + "#include ", + ] if use_aten_lib - else '#include "NativeFunctions.h"', + else ['#include "NativeFunctions.h"'], "Functions_declarations": gen_functions_declarations( native_functions=native_functions, static_dispatch_idx=static_dispatch_idx, @@ -314,7 +321,6 @@ def gen_custom_ops( cpu_fm: FileManager, rocm: bool, ) -> None: - dispatch_key = DispatchKey.CPU backend_index = backend_indices[dispatch_key] ( @@ -326,11 +332,22 @@ def gen_custom_ops( backend_index=backend_index, rocm=rocm, ) + cpu_fm.write_with_template( + "CustomOpsNativeFunctions.h", + "NativeFunctions.h", + lambda: { + "nativeFunctions_declarations": get_native_function_declarations( + grouped_native_functions=native_functions, + backend_indices=backend_indices, + native_function_decl_gen=dest.compute_native_function_declaration, + ), + }, + ) cpu_fm.write_with_template( f"Register{dispatch_key}CustomOps.cpp", "RegisterDispatchKeyCustomOps.cpp", lambda: { - "ops_headers": '#include "NativeFunctions.h"', + "ops_headers": '#include "CustomOpsNativeFunctions.h"', "DispatchKey": dispatch_key, "dispatch_namespace": 
dispatch_key.lower(), "dispatch_namespaced_definitions": "", @@ -482,35 +499,35 @@ def parse_yaml_files( native_yaml_path = os.path.join(tmpdirname, "functions.yaml") with open(native_yaml_path, "w"): pass - - # If custom_ops_yaml_path exists, combine both files. - if custom_ops_yaml_path and os.path.exists(custom_ops_yaml_path): - combined_yaml_path = os.path.join(tmpdirname, "combined.yaml") - with open(combined_yaml_path, "w") as tmp: - with open(native_yaml_path, "r") as native: - for line in native: - tmp.write(line) - with open(custom_ops_yaml_path, "r") as custom: - for line in custom: - tmp.write(line) - custom_ops_parsed_yaml = parse_native_yaml( - custom_ops_yaml_path, tags_yaml_path, None, skip_native_fns_gen=True - ) - else: - # No custom_ops; just parse native_yaml_path. - custom_ops_parsed_yaml = None - combined_yaml_path = native_yaml_path + # Translate native_yaml_path to the same format of native_functions.yaml translated_yaml_path = os.path.join(tmpdirname, "translated.yaml") with open(translated_yaml_path, "w") as translated: translate_native_yaml( tags_yaml_path, aten_yaml_path, - combined_yaml_path, + native_yaml_path, use_aten_lib, translated, ) + # If custom_ops_yaml_path doesn't exist, point to an empty file. 
+ if not custom_ops_yaml_path or not os.path.exists(custom_ops_yaml_path): + custom_ops_yaml_path = os.path.join(tmpdirname, "custom_ops.yaml") + with open(custom_ops_yaml_path, "w"): + pass + combined_yaml_path = os.path.join(tmpdirname, "combined.yaml") + with open(combined_yaml_path, "w") as tmp, open( + translated_yaml_path, "r" + ) as native, open(custom_ops_yaml_path, "r") as custom: + for line in native.readlines(): + tmp.write(line) + for line in custom.readlines(): + tmp.write(line) + custom_ops_parsed_yaml = parse_native_yaml( + custom_ops_yaml_path, tags_yaml_path, None, skip_native_fns_gen=True + ) + parsed_yaml = parse_native_yaml( - translated_yaml_path, + combined_yaml_path, tags_yaml_path, None, skip_native_fns_gen=(not gen_native_fns),