mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert "[AOTI] Add an option to specify custom op C shim (#153851)"
This reverts commit365ac49840. Reverted https://github.com/pytorch/pytorch/pull/153851 on behalf of https://github.com/malfet due to Looks like it broke fuzzer test, but I could be wrong, seec4d1ff02f8/1([comment](https://github.com/pytorch/pytorch/pull/153851#issuecomment-2894619773))
This commit is contained in:
parent
c4d1ff02f8
commit
3102ae6798
|
|
@ -1,8 +1,5 @@
|
||||||
#include <torch/csrc/api/include/torch/types.h> // @manual=fbcode//caffe2:libtorch
|
#include <torch/csrc/api/include/torch/types.h> // @manual=fbcode//caffe2:libtorch
|
||||||
|
|
||||||
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
|
|
||||||
#include <torch/csrc/inductor/aoti_torch/utils.h>
|
|
||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
@ -313,40 +310,8 @@ void fn_out_variant_without_return_meta(
|
||||||
Tensor& out) {
|
Tensor& out) {
|
||||||
}
|
}
|
||||||
|
|
||||||
Tensor fn_square_impl(const Tensor& tensor) {
|
|
||||||
return tensor * tensor;
|
|
||||||
}
|
|
||||||
|
|
||||||
Tensor fn_square_meta(const Tensor& tensor) {
|
|
||||||
return at::empty_like(tensor);
|
|
||||||
}
|
|
||||||
} // namespace at
|
} // namespace at
|
||||||
|
|
||||||
|
|
||||||
extern "C" {
|
|
||||||
AOTI_TORCH_EXPORT AOTITorchError
|
|
||||||
aoti_torch_cpu_fn_square(
|
|
||||||
AtenTensorHandle input,
|
|
||||||
AtenTensorHandle* ret) {
|
|
||||||
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
|
|
||||||
auto tmp_result = at::fn_square_impl(
|
|
||||||
torch::aot_inductor::resolve_tensor_dispatch_flags(input));
|
|
||||||
*ret = torch::aot_inductor::new_tensor_handle(std::move(tmp_result));
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
AOTI_TORCH_EXPORT AOTITorchError
|
|
||||||
aoti_torch_cuda_fn_square(
|
|
||||||
AtenTensorHandle input,
|
|
||||||
AtenTensorHandle* ret) {
|
|
||||||
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
|
|
||||||
auto tmp_result = at::fn_square_impl(
|
|
||||||
torch::aot_inductor::resolve_tensor_dispatch_flags(input));
|
|
||||||
*ret = torch::aot_inductor::new_tensor_handle(std::move(tmp_result));
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
TORCH_LIBRARY(aoti_custom_ops, m) {
|
TORCH_LIBRARY(aoti_custom_ops, m) {
|
||||||
m.def("custom_add(Tensor t1, Tensor t2) -> Tensor");
|
m.def("custom_add(Tensor t1, Tensor t2) -> Tensor");
|
||||||
m.def(
|
m.def(
|
||||||
|
|
@ -389,7 +354,6 @@ TORCH_LIBRARY(aoti_custom_ops, m) {
|
||||||
"fn_with_input_mutation(Tensor(a!) t0, Tensor t1, Tensor(b!) t2) -> (Tensor, Tensor)");
|
"fn_with_input_mutation(Tensor(a!) t0, Tensor t1, Tensor(b!) t2) -> (Tensor, Tensor)");
|
||||||
|
|
||||||
m.def("fn_out_variant_without_return(Tensor x, Tensor(a!) out) -> ()");
|
m.def("fn_out_variant_without_return(Tensor x, Tensor(a!) out) -> ()");
|
||||||
m.def("fn_square(Tensor x) -> Tensor");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TORCH_LIBRARY_IMPL(aoti_custom_ops, CompositeExplicitAutograd, m) {
|
TORCH_LIBRARY_IMPL(aoti_custom_ops, CompositeExplicitAutograd, m) {
|
||||||
|
|
@ -401,7 +365,6 @@ TORCH_LIBRARY_IMPL(aoti_custom_ops, CompositeExplicitAutograd, m) {
|
||||||
m.impl("fn_with_mix_outputs", at::fn_with_mix_outputs_impl);
|
m.impl("fn_with_mix_outputs", at::fn_with_mix_outputs_impl);
|
||||||
m.impl("fn_with_input_mutation", at::fn_with_input_mutation_impl);
|
m.impl("fn_with_input_mutation", at::fn_with_input_mutation_impl);
|
||||||
m.impl("fn_out_variant_without_return", at::fn_out_variant_without_return_impl);
|
m.impl("fn_out_variant_without_return", at::fn_out_variant_without_return_impl);
|
||||||
m.impl("fn_square", at::fn_square_impl);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TORCH_LIBRARY_IMPL(aoti_custom_ops, Meta, m) {
|
TORCH_LIBRARY_IMPL(aoti_custom_ops, Meta, m) {
|
||||||
|
|
@ -412,5 +375,4 @@ TORCH_LIBRARY_IMPL(aoti_custom_ops, Meta, m) {
|
||||||
m.impl("fn_with_mix_outputs", at::fn_with_mix_outputs_meta);
|
m.impl("fn_with_mix_outputs", at::fn_with_mix_outputs_meta);
|
||||||
m.impl("fn_with_input_mutation", at::fn_with_input_mutation_meta);
|
m.impl("fn_with_input_mutation", at::fn_with_input_mutation_meta);
|
||||||
m.impl("fn_out_variant_without_return", at::fn_out_variant_without_return_meta);
|
m.impl("fn_out_variant_without_return", at::fn_out_variant_without_return_meta);
|
||||||
m.impl("fn_square", at::fn_square_meta);
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -20,8 +20,6 @@ from torch.testing._internal.common_utils import (
|
||||||
IS_MACOS,
|
IS_MACOS,
|
||||||
IS_SANDCASTLE,
|
IS_SANDCASTLE,
|
||||||
IS_WINDOWS,
|
IS_WINDOWS,
|
||||||
skipIfRocm,
|
|
||||||
skipIfXpu,
|
|
||||||
)
|
)
|
||||||
from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test
|
from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test
|
||||||
from torch.testing._internal.triton_utils import HAS_CUDA
|
from torch.testing._internal.triton_utils import HAS_CUDA
|
||||||
|
|
@ -358,37 +356,6 @@ class AOTInductorTestsTemplate:
|
||||||
self.assertEqual(len(inps), 0)
|
self.assertEqual(len(inps), 0)
|
||||||
self.assertTrue(sentinel_seen)
|
self.assertTrue(sentinel_seen)
|
||||||
|
|
||||||
@skipIfXpu
|
|
||||||
@skipIfRocm
|
|
||||||
def test_custom_op_square(self) -> None:
|
|
||||||
class Model(torch.nn.Module):
|
|
||||||
def forward(self, x):
|
|
||||||
return torch.ops.aoti_custom_ops.fn_square(x)
|
|
||||||
|
|
||||||
m = Model().to(device=self.device)
|
|
||||||
args = (torch.randn(2, 3, device=self.device),)
|
|
||||||
with config.patch(
|
|
||||||
"aot_inductor.custom_ops_to_c_shims",
|
|
||||||
{
|
|
||||||
torch.ops.aoti_custom_ops.fn_square.default: [
|
|
||||||
"""
|
|
||||||
AOTITorchError
|
|
||||||
aoti_torch_cpu_fn_square(
|
|
||||||
AtenTensorHandle input,
|
|
||||||
AtenTensorHandle* ret)""",
|
|
||||||
"""
|
|
||||||
AOTITorchError
|
|
||||||
aoti_torch_cuda_fn_square(
|
|
||||||
AtenTensorHandle input,
|
|
||||||
AtenTensorHandle* ret)""",
|
|
||||||
],
|
|
||||||
},
|
|
||||||
), config.patch(
|
|
||||||
"aot_inductor.custom_op_libs",
|
|
||||||
["aoti_custom_ops"],
|
|
||||||
):
|
|
||||||
self.check_model(m, args)
|
|
||||||
|
|
||||||
|
|
||||||
class AOTInductorLoggingTest(LoggingTestCase):
|
class AOTInductorLoggingTest(LoggingTestCase):
|
||||||
@make_logging_test(dynamic=logging.DEBUG)
|
@make_logging_test(dynamic=logging.DEBUG)
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ import math
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import textwrap
|
import textwrap
|
||||||
from itertools import chain, count
|
from itertools import count
|
||||||
from typing import Callable, Optional, Protocol, TYPE_CHECKING, Union
|
from typing import Callable, Optional, Protocol, TYPE_CHECKING, Union
|
||||||
|
|
||||||
import sympy
|
import sympy
|
||||||
|
|
@ -237,22 +237,6 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
||||||
if V.graph.is_const_graph:
|
if V.graph.is_const_graph:
|
||||||
# We do not write prefix for constant graph, it will be written by main module.
|
# We do not write prefix for constant graph, it will be written by main module.
|
||||||
return
|
return
|
||||||
if config.aot_inductor.custom_ops_to_c_shims:
|
|
||||||
# custom_ops_to_c_shims contains declaration of custom ops with C shim.
|
|
||||||
# TODO: this could be auto-generated from a passed-in custom op schema
|
|
||||||
custom_c_shims = list(
|
|
||||||
chain(*config.aot_inductor.custom_ops_to_c_shims.values())
|
|
||||||
)
|
|
||||||
declarations = "\n".join(
|
|
||||||
[f"export {textwrap.dedent(shim)};" for shim in custom_c_shims]
|
|
||||||
)
|
|
||||||
self.prefix.splice(
|
|
||||||
f"""
|
|
||||||
extern "C" {{
|
|
||||||
{declarations}
|
|
||||||
}}
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
if V.graph.aot_mode:
|
if V.graph.aot_mode:
|
||||||
self.prefix.writeline("namespace torch::aot_inductor {")
|
self.prefix.writeline("namespace torch::aot_inductor {")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1309,11 +1309,6 @@ class aot_inductor:
|
||||||
# Embed generated .cubin files into the .so
|
# Embed generated .cubin files into the .so
|
||||||
embed_cubin: bool = False
|
embed_cubin: bool = False
|
||||||
|
|
||||||
# Custom ops that have implemented C shim wrappers, defined as an op to C shim declaration dict
|
|
||||||
custom_ops_to_c_shims: dict[torch._ops.OpOverload, list[str]] = {}
|
|
||||||
# custom op libs that have implemented C shim wrappers
|
|
||||||
custom_op_libs: list[str] = []
|
|
||||||
|
|
||||||
|
|
||||||
class cuda:
|
class cuda:
|
||||||
"""Settings for cuda backend, today this consists of cutlass"""
|
"""Settings for cuda backend, today this consists of cutlass"""
|
||||||
|
|
|
||||||
|
|
@ -1324,8 +1324,6 @@ def get_cpp_torch_device_options(
|
||||||
# Only add link args, when compile_only is false.
|
# Only add link args, when compile_only is false.
|
||||||
passthrough_args = ["-Wl,-Bstatic -lcudart_static -Wl,-Bdynamic"]
|
passthrough_args = ["-Wl,-Bstatic -lcudart_static -Wl,-Bdynamic"]
|
||||||
|
|
||||||
libraries += config.aot_inductor.custom_op_libs
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
definitions,
|
definitions,
|
||||||
include_dirs,
|
include_dirs,
|
||||||
|
|
|
||||||
|
|
@ -6997,10 +6997,7 @@ class FallbackKernel(ExternKernelAlloc):
|
||||||
assert isinstance(kernel, torch._ops.OpOverload)
|
assert isinstance(kernel, torch._ops.OpOverload)
|
||||||
elif V.graph.cpp_wrapper:
|
elif V.graph.cpp_wrapper:
|
||||||
# For non-aten OpOverload, i.e. custom ops
|
# For non-aten OpOverload, i.e. custom ops
|
||||||
# If the op is in custom_ops_to_c_shims, generate direct function call
|
self.use_runtime_dispatch = True
|
||||||
self.use_runtime_dispatch = (
|
|
||||||
kernel not in config.aot_inductor.custom_ops_to_c_shims
|
|
||||||
)
|
|
||||||
|
|
||||||
def do_runtime_dispatch() -> None:
|
def do_runtime_dispatch() -> None:
|
||||||
args = None
|
args = None
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user