[ET] Add RuntimeContext to ET Aten mode (#96084)

Summary:
In ATen mode, we add the RuntimeContext arg, so we have something like
```
TORCH_API inline at::Tensor & gelu_outf(torch::executor::RuntimeContext & context, const at::Tensor & self, c10::string_view approximate, at::Tensor & out) {
    return at::gelu_outf(self, approximate, out);
}
```
and user can use `<namespace like aten>::gelu_outf` and we will automatically dispatch the registered function in aten kernel using `at::gelu_outf` (dispatched by ATen/Functions.h header)

In optimized kernel tests, we can now automatically handle between aten kernel and optimized kernel.

The implication is that the test must depend on the correctness of codegen; an error in codegen can break the kernel tests.

Test Plan: CI

Differential Revision: D43777848

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96084
Approved by: https://github.com/larryliu0820
This commit is contained in:
Hansong Zhang 2023-03-08 02:51:47 +00:00 committed by PyTorch MergeBot
parent c88aa336aa
commit 93ff71ec37
3 changed files with 61 additions and 5 deletions

View File

@ -288,3 +288,20 @@ def define_tools_targets(
":autograd",
],
)
python_test(
name = "test_torchgen_executorch",
srcs = [
"test/test_executorch_custom_ops.py",
"test/test_executorch_gen.py",
"test/test_executorch_signatures.py",
"test/test_executorch_types.py",
"test/test_executorch_unboxing.py",
],
contacts = contacts,
visibility = ["PUBLIC"],
deps = [
torchgen_deps,
"fbsource//third-party/pypi/expecttest:expecttest",
],
)

View File

@ -203,3 +203,27 @@ TORCH_API inline bool op_2(torch::executor::RuntimeContext & context) {
"""
in declarations
)
def test_aten_lib_has_context_arg(self) -> None:
declarations = gen_functions_declarations(
native_functions=[
self.custom_1_native_function,
],
static_dispatch_idx=self.static_dispatch_idx,
selector=SelectiveBuilder.get_nop_selector(),
use_aten_lib=True,
)
print(declarations)
self.assertTrue(
"""
namespace custom_1 {
// custom_1::op_1() -> bool
TORCH_API inline bool op_1(torch::executor::RuntimeContext & context) {
return at::op_1();
}
} // namespace custom_1
"""
in declarations
)

View File

@ -48,6 +48,22 @@ from torchgen.utils import (
)
def _sig_decl_wrapper(sig: Union[CppSignature, ExecutorchCppSignature]) -> str:
"""
A wrapper function to basically get `sig.decl(include_context=True)`.
For ATen kernel, the codegen has no idea about ET contextArg, so we
use this wrapper to add it.
"""
if isinstance(sig, ExecutorchCppSignature):
return sig.decl()
returns_type = aten_cpp.returns_type(sig.func.returns).cpp_type()
cpp_args = [a.decl() for a in sig.arguments()]
cpp_args_str = ", ".join([contextArg.decl()] + cpp_args)
sig_decl = f"{returns_type} {sig.name()}({cpp_args_str})"
return sig_decl
def static_dispatch(
sig: Union[CppSignature, ExecutorchCppSignature],
f: NativeFunction,
@ -80,7 +96,7 @@ ET_ASSERT_UNREACHABLE_MSG("The number of native function(s) binding to {f.func.n
"""
return f"""
// {f.namespace}::{f.func}
TORCH_API inline {sig.decl()} {{
TORCH_API inline {_sig_decl_wrapper(sig)} {{
{static_block}
}}
"""
@ -116,10 +132,10 @@ class ComputeFunction:
return f"""
// {f.namespace}::{f.func}
TORCH_API inline {sig.decl()} {{
TORCH_API inline {_sig_decl_wrapper(sig)} {{
return at::{sig.name()}({comma.join(e.name for e in sig.arguments())});
}}
"""
"""
else:
return static_dispatch(
@ -188,11 +204,10 @@ class ComputeCodegenUnboxedKernels:
Operator(
"{f.namespace}::{f.func.name}",
[]({contextArg.defn()}, EValue** stack) {{
{"(void)context;" if self.use_aten_lib else ""}
{code_connector.join(code_list)}
EXECUTORCH_SCOPE_PROF("native_call_{f.func.name}");
{ret_prefix}torch::executor::{f.namespace}::{sig.name()}({"" if self.use_aten_lib else "context, "}{args_str});
{ret_prefix}torch::executor::{f.namespace}::{sig.name()}({"context, "}{args_str});
{return_assignment}
}}