mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ET] Add RuntimeContext to ET Aten mode (#96084)
Summary:
In ATen mode, we add the RuntimeContext arg, so we have something like
```
TORCH_API inline at::Tensor & gelu_outf(torch::executor::RuntimeContext & context, const at::Tensor & self, c10::string_view approximate, at::Tensor & out) {
return at::gelu_outf(self, approximate, out);
}
```
and users can use `<namespace like aten>::gelu_outf`; the call is automatically dispatched to the registered ATen kernel via `at::gelu_outf` (resolved through the ATen/Functions.h header)
In optimized kernel tests, we can now automatically switch between the ATen kernel and the optimized kernel.
The implication is that the test must depend on the correctness of codegen; an error in codegen can break the kernel tests.
Test Plan: CI
Differential Revision: D43777848
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96084
Approved by: https://github.com/larryliu0820
This commit is contained in:
parent
c88aa336aa
commit
93ff71ec37
|
|
@ -288,3 +288,20 @@ def define_tools_targets(
|
|||
":autograd",
|
||||
],
|
||||
)
|
||||
|
||||
    # Unit tests for the ExecuTorch (ET) pieces of torchgen: custom-op
    # handling, code generation, C++ signatures, type mapping, and unboxing.
    python_test(
        name = "test_torchgen_executorch",
        srcs = [
            "test/test_executorch_custom_ops.py",
            "test/test_executorch_gen.py",
            "test/test_executorch_signatures.py",
            "test/test_executorch_types.py",
            "test/test_executorch_unboxing.py",
        ],
        contacts = contacts,
        visibility = ["PUBLIC"],
        deps = [
            torchgen_deps,
            # expecttest powers the golden-output style assertions in the tests.
            "fbsource//third-party/pypi/expecttest:expecttest",
        ],
    )
|
||||
|
|
|
|||
|
|
@ -203,3 +203,27 @@ TORCH_API inline bool op_2(torch::executor::RuntimeContext & context) {
|
|||
"""
|
||||
in declarations
|
||||
)
|
||||
|
||||
    def test_aten_lib_has_context_arg(self) -> None:
        """In ATen mode (use_aten_lib=True), the generated wrapper must take a
        torch::executor::RuntimeContext argument even though the underlying
        at:: call does not use it.
        """
        declarations = gen_functions_declarations(
            native_functions=[
                self.custom_1_native_function,
            ],
            static_dispatch_idx=self.static_dispatch_idx,
            # No-op selector: select every operator (no selective build).
            selector=SelectiveBuilder.get_nop_selector(),
            use_aten_lib=True,
        )
        print(declarations)
        # NOTE(review): the expected snippet's internal indentation was
        # reconstructed from a mangled render — confirm against the generator.
        self.assertTrue(
            """
namespace custom_1 {

// custom_1::op_1() -> bool
TORCH_API inline bool op_1(torch::executor::RuntimeContext & context) {
return at::op_1();
}

} // namespace custom_1
"""
            in declarations
        )
|
||||
|
|
|
|||
|
|
@ -48,6 +48,22 @@ from torchgen.utils import (
|
|||
)
|
||||
|
||||
|
||||
def _sig_decl_wrapper(sig: Union[CppSignature, ExecutorchCppSignature]) -> str:
    """Return a C++ declaration string for *sig* that includes the ET context arg.

    An ``ExecutorchCppSignature`` already emits the
    ``torch::executor::RuntimeContext`` parameter itself, so its own
    ``decl()`` is returned unchanged. An ATen ``CppSignature`` knows nothing
    about the ET context, so the declaration is assembled by hand with
    ``contextArg`` prepended to the argument list.
    """
    if isinstance(sig, ExecutorchCppSignature):
        # ET signatures handle the context argument on their own.
        return sig.decl()

    ret_type = aten_cpp.returns_type(sig.func.returns).cpp_type()
    all_args = ", ".join(
        [contextArg.decl()] + [arg.decl() for arg in sig.arguments()]
    )
    return f"{ret_type} {sig.name()}({all_args})"
|
||||
|
||||
|
||||
def static_dispatch(
|
||||
sig: Union[CppSignature, ExecutorchCppSignature],
|
||||
f: NativeFunction,
|
||||
|
|
@ -80,7 +96,7 @@ ET_ASSERT_UNREACHABLE_MSG("The number of native function(s) binding to {f.func.n
|
|||
"""
|
||||
return f"""
|
||||
// {f.namespace}::{f.func}
|
||||
TORCH_API inline {sig.decl()} {{
|
||||
TORCH_API inline {_sig_decl_wrapper(sig)} {{
|
||||
{static_block}
|
||||
}}
|
||||
"""
|
||||
|
|
@ -116,10 +132,10 @@ class ComputeFunction:
|
|||
|
||||
return f"""
|
||||
// {f.namespace}::{f.func}
|
||||
TORCH_API inline {sig.decl()} {{
|
||||
TORCH_API inline {_sig_decl_wrapper(sig)} {{
|
||||
return at::{sig.name()}({comma.join(e.name for e in sig.arguments())});
|
||||
}}
|
||||
"""
|
||||
"""
|
||||
|
||||
else:
|
||||
return static_dispatch(
|
||||
|
|
@ -188,11 +204,10 @@ class ComputeCodegenUnboxedKernels:
|
|||
Operator(
|
||||
"{f.namespace}::{f.func.name}",
|
||||
[]({contextArg.defn()}, EValue** stack) {{
|
||||
{"(void)context;" if self.use_aten_lib else ""}
|
||||
{code_connector.join(code_list)}
|
||||
|
||||
EXECUTORCH_SCOPE_PROF("native_call_{f.func.name}");
|
||||
{ret_prefix}torch::executor::{f.namespace}::{sig.name()}({"" if self.use_aten_lib else "context, "}{args_str});
|
||||
{ret_prefix}torch::executor::{f.namespace}::{sig.name()}({"context, "}{args_str});
|
||||
|
||||
{return_assignment}
|
||||
}}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user