Revert "[AOTI] Add an option to specify custom op C shim (#153851)"

This reverts commit 365ac49840.

Reverted https://github.com/pytorch/pytorch/pull/153851 on behalf of https://github.com/malfet due to Looks like it broke fuzzer test, but I could be wrong, see c4d1ff02f8/1 ([comment](https://github.com/pytorch/pytorch/pull/153851#issuecomment-2894619773))
PyTorch MergeBot 2025-05-20 14:23:50 +00:00
parent c4d1ff02f8
commit 3102ae6798
6 changed files with 2 additions and 99 deletions
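For context, the reverted option (exercised by the removed test below) let a custom op be paired with a hand-written C shim declaration via aot_inductor.custom_ops_to_c_shims, with aot_inductor.custom_op_libs naming the library implementing those shims, so AOTInductor could emit a direct shim call instead of runtime dispatch. A minimal usage sketch reconstructed from that test follows; the Square module, the example input, and the aoti_compile_and_package call are illustrative assumptions, not part of the original change, and the aoti_custom_ops library must already be built and loaded for the op to resolve:

import torch
from torch._inductor import config

# Illustrative module calling the custom op that has a registered C shim.
class Square(torch.nn.Module):
    def forward(self, x):
        return torch.ops.aoti_custom_ops.fn_square(x)

with config.patch(
    "aot_inductor.custom_ops_to_c_shims",
    {
        # Map the op overload to the C shim declaration(s) to emit.
        torch.ops.aoti_custom_ops.fn_square.default: [
            """
            AOTITorchError
            aoti_torch_cpu_fn_square(
                AtenTensorHandle input,
                AtenTensorHandle* ret)""",
        ],
    },
), config.patch("aot_inductor.custom_op_libs", ["aoti_custom_ops"]):
    ep = torch.export.export(Square(), (torch.randn(2, 3),))
    torch._inductor.aoti_compile_and_package(ep)  # assumed packaging entry point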

View File

@@ -1,8 +1,5 @@
#include <torch/csrc/api/include/torch/types.h> // @manual=fbcode//caffe2:libtorch
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>
#include <cstdint>
#include <iostream>
#include <string>
@@ -313,40 +310,8 @@ void fn_out_variant_without_return_meta(
    Tensor& out) {
}
Tensor fn_square_impl(const Tensor& tensor) {
  return tensor * tensor;
}
Tensor fn_square_meta(const Tensor& tensor) {
  return at::empty_like(tensor);
}
} // namespace at
extern "C" {
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_cpu_fn_square(
    AtenTensorHandle input,
    AtenTensorHandle* ret) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    auto tmp_result = at::fn_square_impl(
        torch::aot_inductor::resolve_tensor_dispatch_flags(input));
    *ret = torch::aot_inductor::new_tensor_handle(std::move(tmp_result));
  });
}
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_cuda_fn_square(
    AtenTensorHandle input,
    AtenTensorHandle* ret) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    auto tmp_result = at::fn_square_impl(
        torch::aot_inductor::resolve_tensor_dispatch_flags(input));
    *ret = torch::aot_inductor::new_tensor_handle(std::move(tmp_result));
  });
}
}
TORCH_LIBRARY(aoti_custom_ops, m) {
  m.def("custom_add(Tensor t1, Tensor t2) -> Tensor");
  m.def(
@@ -389,7 +354,6 @@ TORCH_LIBRARY(aoti_custom_ops, m) {
"fn_with_input_mutation(Tensor(a!) t0, Tensor t1, Tensor(b!) t2) -> (Tensor, Tensor)");
m.def("fn_out_variant_without_return(Tensor x, Tensor(a!) out) -> ()");
m.def("fn_square(Tensor x) -> Tensor");
}
TORCH_LIBRARY_IMPL(aoti_custom_ops, CompositeExplicitAutograd, m) {
@@ -401,7 +365,6 @@ TORCH_LIBRARY_IMPL(aoti_custom_ops, CompositeExplicitAutograd, m) {
m.impl("fn_with_mix_outputs", at::fn_with_mix_outputs_impl);
m.impl("fn_with_input_mutation", at::fn_with_input_mutation_impl);
m.impl("fn_out_variant_without_return", at::fn_out_variant_without_return_impl);
m.impl("fn_square", at::fn_square_impl);
}
TORCH_LIBRARY_IMPL(aoti_custom_ops, Meta, m) {
@@ -412,5 +375,4 @@ TORCH_LIBRARY_IMPL(aoti_custom_ops, Meta, m) {
m.impl("fn_with_mix_outputs", at::fn_with_mix_outputs_meta);
m.impl("fn_with_input_mutation", at::fn_with_input_mutation_meta);
m.impl("fn_out_variant_without_return", at::fn_out_variant_without_return_meta);
m.impl("fn_square", at::fn_square_meta);
}

View File

@@ -20,8 +20,6 @@ from torch.testing._internal.common_utils import (
    IS_MACOS,
    IS_SANDCASTLE,
    IS_WINDOWS,
    skipIfRocm,
    skipIfXpu,
)
from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test
from torch.testing._internal.triton_utils import HAS_CUDA
@@ -358,37 +356,6 @@ class AOTInductorTestsTemplate:
        self.assertEqual(len(inps), 0)
        self.assertTrue(sentinel_seen)
    @skipIfXpu
    @skipIfRocm
    def test_custom_op_square(self) -> None:
        class Model(torch.nn.Module):
            def forward(self, x):
                return torch.ops.aoti_custom_ops.fn_square(x)
        m = Model().to(device=self.device)
        args = (torch.randn(2, 3, device=self.device),)
        with config.patch(
            "aot_inductor.custom_ops_to_c_shims",
            {
                torch.ops.aoti_custom_ops.fn_square.default: [
                    """
                    AOTITorchError
                    aoti_torch_cpu_fn_square(
                        AtenTensorHandle input,
                        AtenTensorHandle* ret)""",
                    """
                    AOTITorchError
                    aoti_torch_cuda_fn_square(
                        AtenTensorHandle input,
                        AtenTensorHandle* ret)""",
                ],
            },
        ), config.patch(
            "aot_inductor.custom_op_libs",
            ["aoti_custom_ops"],
        ):
            self.check_model(m, args)
class AOTInductorLoggingTest(LoggingTestCase):
    @make_logging_test(dynamic=logging.DEBUG)

View File

@@ -6,7 +6,7 @@ import math
import os
import sys
import textwrap
from itertools import chain, count
from itertools import count
from typing import Callable, Optional, Protocol, TYPE_CHECKING, Union
import sympy
@@ -237,22 +237,6 @@ class CppWrapperCpu(PythonWrapperCodegen):
        if V.graph.is_const_graph:
            # We do not write prefix for constant graph, it will be written by main module.
            return
        if config.aot_inductor.custom_ops_to_c_shims:
            # custom_ops_to_c_shims contains declaration of custom ops with C shim.
            # TODO: this could be auto-generated from a passed-in custom op schema
            custom_c_shims = list(
                chain(*config.aot_inductor.custom_ops_to_c_shims.values())
            )
            declarations = "\n".join(
                [f"export {textwrap.dedent(shim)};" for shim in custom_c_shims]
            )
            self.prefix.splice(
                f"""
                extern "C" {{
                    {declarations}
                }}
                """
            )
        if V.graph.aot_mode:
            self.prefix.writeline("namespace torch::aot_inductor {")

View File

@@ -1309,11 +1309,6 @@ class aot_inductor:
    # Embed generated .cubin files into the .so
    embed_cubin: bool = False
    # Custom ops that have implemented C shim wrappers, defined as an op to C shim declaration dict
    custom_ops_to_c_shims: dict[torch._ops.OpOverload, list[str]] = {}
    # custom op libs that have implemented C shim wrappers
    custom_op_libs: list[str] = []
class cuda:
    """Settings for cuda backend, today this consists of cutlass"""

View File

@@ -1324,8 +1324,6 @@ def get_cpp_torch_device_options(
            # Only add link args, when compile_only is false.
            passthrough_args = ["-Wl,-Bstatic -lcudart_static -Wl,-Bdynamic"]
    libraries += config.aot_inductor.custom_op_libs
    return (
        definitions,
        include_dirs,

View File

@@ -6997,10 +6997,7 @@ class FallbackKernel(ExternKernelAlloc):
            assert isinstance(kernel, torch._ops.OpOverload)
        elif V.graph.cpp_wrapper:
            # For non-aten OpOverload, i.e. custom ops
            # If the op is in custom_ops_to_c_shims, generate direct function call
            self.use_runtime_dispatch = (
                kernel not in config.aot_inductor.custom_ops_to_c_shims
            )
            self.use_runtime_dispatch = True
        def do_runtime_dispatch() -> None:
            args = None