[AOTI] Add an option to specify custom op C shim (#153851)
Summary: Add an option that tells AOTInductor codegen to generate C shim functions for certain custom ops instead of routing them through ProxyExecutor. The library that defines these custom ops needs to implement the corresponding C shim functions.

Differential Revision: [D75014177](https://our.internmc.facebook.com/intern/diff/D75014177)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153851
Approved by: https://github.com/hl475
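End to end, the intended usage looks roughly like the sketch below, pieced together from the test added in this PR. The aoti_custom_ops library and fn_square op come from PyTorch's test suite; the export and packaging calls are one plausible way to drive AOTI and are an assumption here, not part of this diff.

# Sketch: AOT-compile a model whose custom op is called through its C shim.
# Assumes the aoti_custom_ops library also exports aoti_torch_cpu_fn_square.
import torch
from torch._inductor import config

class Model(torch.nn.Module):
    def forward(self, x):
        return torch.ops.aoti_custom_ops.fn_square(x)

shim_decl = """
AOTITorchError
aoti_torch_cpu_fn_square(
    AtenTensorHandle input,
    AtenTensorHandle* ret)"""

with config.patch(
    "aot_inductor.custom_ops_to_c_shims",
    {torch.ops.aoti_custom_ops.fn_square.default: [shim_decl]},
), config.patch("aot_inductor.custom_op_libs", ["aoti_custom_ops"]):
    ep = torch.export.export(Model(), (torch.randn(2, 3),))
    torch._inductor.aoti_compile_and_package(ep)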
This commit is contained in:
parent 89ebd29fdc
commit 365ac49840
@@ -1,5 +1,8 @@
 #include <torch/csrc/api/include/torch/types.h> // @manual=fbcode//caffe2:libtorch

+#include <torch/csrc/inductor/aoti_torch/c/shim.h>
+#include <torch/csrc/inductor/aoti_torch/utils.h>
+
 #include <cstdint>
 #include <iostream>
 #include <string>
@@ -310,8 +313,40 @@ void fn_out_variant_without_return_meta(
     Tensor& out) {
 }

+Tensor fn_square_impl(const Tensor& tensor) {
+  return tensor * tensor;
+}
+
+Tensor fn_square_meta(const Tensor& tensor) {
+  return at::empty_like(tensor);
+}
+
 } // namespace at

+extern "C" {
+AOTI_TORCH_EXPORT AOTITorchError
+aoti_torch_cpu_fn_square(
+    AtenTensorHandle input,
+    AtenTensorHandle* ret) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
+    auto tmp_result = at::fn_square_impl(
+        torch::aot_inductor::resolve_tensor_dispatch_flags(input));
+    *ret = torch::aot_inductor::new_tensor_handle(std::move(tmp_result));
+  });
+}
+
+AOTI_TORCH_EXPORT AOTITorchError
+aoti_torch_cuda_fn_square(
+    AtenTensorHandle input,
+    AtenTensorHandle* ret) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
+    auto tmp_result = at::fn_square_impl(
+        torch::aot_inductor::resolve_tensor_dispatch_flags(input));
+    *ret = torch::aot_inductor::new_tensor_handle(std::move(tmp_result));
+  });
+}
+}
+
 TORCH_LIBRARY(aoti_custom_ops, m) {
   m.def("custom_add(Tensor t1, Tensor t2) -> Tensor");
   m.def(
@@ -354,6 +389,7 @@ TORCH_LIBRARY(aoti_custom_ops, m) {
       "fn_with_input_mutation(Tensor(a!) t0, Tensor t1, Tensor(b!) t2) -> (Tensor, Tensor)");

   m.def("fn_out_variant_without_return(Tensor x, Tensor(a!) out) -> ()");
+  m.def("fn_square(Tensor x) -> Tensor");
 }

 TORCH_LIBRARY_IMPL(aoti_custom_ops, CompositeExplicitAutograd, m) {
@@ -365,6 +401,7 @@ TORCH_LIBRARY_IMPL(aoti_custom_ops, CompositeExplicitAutograd, m) {
   m.impl("fn_with_mix_outputs", at::fn_with_mix_outputs_impl);
   m.impl("fn_with_input_mutation", at::fn_with_input_mutation_impl);
   m.impl("fn_out_variant_without_return", at::fn_out_variant_without_return_impl);
+  m.impl("fn_square", at::fn_square_impl);
 }

 TORCH_LIBRARY_IMPL(aoti_custom_ops, Meta, m) {
@@ -375,4 +412,5 @@ TORCH_LIBRARY_IMPL(aoti_custom_ops, Meta, m) {
   m.impl("fn_with_mix_outputs", at::fn_with_mix_outputs_meta);
   m.impl("fn_with_input_mutation", at::fn_with_input_mutation_meta);
   m.impl("fn_out_variant_without_return", at::fn_out_variant_without_return_meta);
+  m.impl("fn_square", at::fn_square_meta);
 }
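As a quick sanity check of the library above, the op can be exercised eagerly before involving AOTInductor at all. This is a hypothetical snippet: the shared-library name and path are assumptions about how the test file gets built, not something this diff specifies.

# Hypothetical eager-mode check of the registrations above.
import torch

torch.ops.load_library("build/libaoti_custom_ops.so")  # path is an assumption
x = torch.randn(2, 3)
assert torch.equal(torch.ops.aoti_custom_ops.fn_square(x), x * x)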
@@ -20,6 +20,8 @@ from torch.testing._internal.common_utils import (
     IS_MACOS,
     IS_SANDCASTLE,
     IS_WINDOWS,
+    skipIfRocm,
+    skipIfXpu,
 )
 from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test
 from torch.testing._internal.triton_utils import HAS_CUDA
@@ -356,6 +358,37 @@ class AOTInductorTestsTemplate:
         self.assertEqual(len(inps), 0)
         self.assertTrue(sentinel_seen)

+    @skipIfXpu
+    @skipIfRocm
+    def test_custom_op_square(self) -> None:
+        class Model(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aoti_custom_ops.fn_square(x)
+
+        m = Model().to(device=self.device)
+        args = (torch.randn(2, 3, device=self.device),)
+        with config.patch(
+            "aot_inductor.custom_ops_to_c_shims",
+            {
+                torch.ops.aoti_custom_ops.fn_square.default: [
+                    """
+                    AOTITorchError
+                    aoti_torch_cpu_fn_square(
+                        AtenTensorHandle input,
+                        AtenTensorHandle* ret)""",
+                    """
+                    AOTITorchError
+                    aoti_torch_cuda_fn_square(
+                        AtenTensorHandle input,
+                        AtenTensorHandle* ret)""",
+                ],
+            },
+        ), config.patch(
+            "aot_inductor.custom_op_libs",
+            ["aoti_custom_ops"],
+        ):
+            self.check_model(m, args)
+

 class AOTInductorLoggingTest(LoggingTestCase):
     @make_logging_test(dynamic=logging.DEBUG)
@@ -6,7 +6,7 @@ import math
 import os
 import sys
 import textwrap
-from itertools import count
+from itertools import chain, count
 from typing import Callable, Optional, Protocol, TYPE_CHECKING, Union

 import sympy
@@ -237,6 +237,22 @@ class CppWrapperCpu(PythonWrapperCodegen):
         if V.graph.is_const_graph:
            # We do not write prefix for constant graph, it will be written by main module.
            return

+        if config.aot_inductor.custom_ops_to_c_shims:
+            # custom_ops_to_c_shims contains declaration of custom ops with C shim.
+            # TODO: this could be auto-generated from a passed-in custom op schema
+            custom_c_shims = list(
+                chain(*config.aot_inductor.custom_ops_to_c_shims.values())
+            )
+            declarations = "\n".join(
+                [f"extern {textwrap.dedent(shim)};" for shim in custom_c_shims]
+            )
+            self.prefix.splice(
+                f"""
+                extern "C" {{
+                    {declarations}
+                }}
+                """
+            )
         if V.graph.aot_mode:
             self.prefix.writeline("namespace torch::aot_inductor {")
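To make the prefix codegen concrete, here is a small standalone reproduction of what the new branch emits for the fn_square declaration used in this PR's test. The op key is written as a plain string only to keep the snippet self-contained without the custom op library loaded.

import textwrap
from itertools import chain

# Stand-in for config.aot_inductor.custom_ops_to_c_shims (string key instead
# of an OpOverload so this runs without the aoti_custom_ops library).
custom_ops_to_c_shims = {
    "aoti_custom_ops.fn_square.default": [
        """
        AOTITorchError
        aoti_torch_cpu_fn_square(
            AtenTensorHandle input,
            AtenTensorHandle* ret)""",
    ],
}

custom_c_shims = list(chain(*custom_ops_to_c_shims.values()))
declarations = "\n".join([f"extern {textwrap.dedent(shim)};" for shim in custom_c_shims])
# The generated .cpp prefix then contains roughly:
print(f'extern "C" {{\n{declarations}\n}}')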
@@ -1309,6 +1309,11 @@ class aot_inductor:
     # Embed generated .cubin files into the .so
     embed_cubin: bool = False

+    # Custom ops that have implemented C shim wrappers, defined as an op to C shim declaration dict
+    custom_ops_to_c_shims: dict[torch._ops.OpOverload, list[str]] = {}
+    # custom op libs that have implemented C shim wrappers
+    custom_op_libs: list[str] = []
+

 class cuda:
     """Settings for cuda backend, today this consists of cutlass"""
@@ -1324,6 +1324,8 @@ def get_cpp_torch_device_options(
             # Only add link args, when compile_only is false.
             passthrough_args = ["-Wl,-Bstatic -lcudart_static -Wl,-Bdynamic"]

+    libraries += config.aot_inductor.custom_op_libs
+
     return (
         definitions,
         include_dirs,
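A rough illustration of what that one-line change implies for the build. That each name in custom_op_libs becomes a -l linker flag is an assumption about cpp_builder's handling of the libraries list, not code from this diff; it would mean "aoti_custom_ops" requires libaoti_custom_ops.so to be discoverable by the linker.

libraries = ["torch", "c10"]  # simplified stand-ins for the usual link set
libraries += ["aoti_custom_ops"]  # from config.aot_inductor.custom_op_libs
print(" ".join(f"-l{lib}" for lib in libraries))  # -ltorch -lc10 -laoti_custom_ops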
@@ -6997,7 +6997,10 @@ class FallbackKernel(ExternKernelAlloc):
             assert isinstance(kernel, torch._ops.OpOverload)
         elif V.graph.cpp_wrapper:
             # For non-aten OpOverload, i.e. custom ops
-            self.use_runtime_dispatch = True
+            # If the op is in custom_ops_to_c_shims, generate direct function call
+            self.use_runtime_dispatch = (
+                kernel not in config.aot_inductor.custom_ops_to_c_shims
+            )

         def do_runtime_dispatch() -> None:
             args = None