Add shim.h C API to call dispatcher on our own aten ops (#148832)

This PR still needs testing through a cpp extension

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148832
Approved by: https://github.com/albanD, https://github.com/atalman
ghstack dependencies: #148124
Author: Jane Xu
Date: 2025-03-11 10:53:24 -07:00
Committed by: PyTorch MergeBot
Commit: e6ef0620cc (parent cf19efd3d9)
6 changed files with 163 additions and 0 deletions


@@ -125,3 +125,25 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CUDA, m) {
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
  m.impl("identity", &boxed_identity);
}

// calls aten::abs through the new shim API: pack the argument into a
// StableIValue stack, dispatch, and read the return out of slot 0
RAIIATH my_abs(RAIIATH t) {
  const auto num_args = 1;
  StableIValue stack[num_args];
  stack[0] = from(t.release());
  aoti_torch_call_dispatcher("aten::abs", "", stack);
  return RAIIATH(to<AtenTensorHandle>(stack[0]));
}

void boxed_my_abs(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
  RAIIATH t(to<AtenTensorHandle>(stack[0]));
  RAIIATH raiiath_res = my_abs(std::move(t));
  stack[0] = from(raiiath_res.release());
}

STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
  m.def("my_abs(Tensor t) -> Tensor");
}

STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
  m.impl("my_abs", &boxed_my_abs);
}
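The same pattern generalizes to ops with multiple arguments: size the stack to max(num_args, num_returns), pack each argument with from in schema order, and read returns back starting at index 0. A hedged sketch, not part of this PR: my_transpose is a hypothetical wrapper assuming aten::transpose.int's (Tensor, int, int) -> Tensor schema.

// hypothetical example: aten::transpose.int(Tensor self, int dim0, int dim1) -> Tensor
RAIIATH my_transpose(RAIIATH t, int64_t dim0, int64_t dim1) {
  const auto num_args = 3;  // max(3 args, 1 return)
  StableIValue stack[num_args];
  stack[0] = from(t.release());
  stack[1] = from(dim0);
  stack[2] = from(dim1);
  aoti_torch_call_dispatcher("aten::transpose", "int", stack);
  // ret0 lands at index 0; slots 1 and 2 are dead after the call
  return RAIIATH(to<AtenTensorHandle>(stack[0]));
}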


@@ -36,3 +36,16 @@ def identity(t) -> Tensor:
        a Tensor, the same as input.
    """
    return torch.ops.libtorch_agnostic.identity.default(t)


def my_abs(t) -> Tensor:
    """
    Returns abs on the input tensor, outputs a new Tensor

    Args:
        t: any Tensor

    Returns:
        a Tensor
    """
    return torch.ops.libtorch_agnostic.my_abs.default(t)
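Once the extension is built, the wrapper is callable like any other torch op. A minimal usage sketch, assuming the libtorch_agnostic test package is built and importable:

import torch
import libtorch_agnostic  # importing the extension registers its ops

t = torch.randn(4, 4)
out = libtorch_agnostic.ops.my_abs(t)
assert torch.equal(out, t.abs())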


@@ -52,6 +52,23 @@ class TestLibtorchAgnostic(TestCase):
            curr_mem = torch.cuda.memory_allocated(device)
            self.assertEqual(curr_mem, init_mem)

    def test_my_abs(self, device):
        t = torch.rand(32, 16, device=device)
        cpu_t = libtorch_agnostic.ops.my_abs(t)
        self.assertEqual(cpu_t, torch.abs(t))

        def _make_cuda_tensors(prior_mem):
            cuda_t = libtorch_agnostic.ops.my_abs(t)
            self.assertGreater(torch.cuda.memory_allocated(device), prior_mem)
            self.assertEqual(cuda_t, torch.abs(t))

        if t.is_cuda:
            init_mem = torch.cuda.memory_allocated(device)
            for _ in range(3):
                _make_cuda_tensors(init_mem)
            curr_mem = torch.cuda.memory_allocated(device)
            self.assertEqual(curr_mem, init_mem)

    @onlyCUDA
    def test_z_delete_torch_lib(self, device):
        # Why the z + CUDA? THIS TEST MUST BE RUN LAST
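A note on the test pattern above: _make_cuda_tensors allocates its result inside a scope, so once it returns, the only CUDA memory that should remain is t's own. Comparing curr_mem against init_mem after three iterations catches any AtenTensorHandle that the shim path leaked.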


@@ -270,6 +270,23 @@ class TestCppExtensionAOT(common.TestCase):
            curr_mem = torch.cuda.memory_allocated(device)
            self.assertEqual(curr_mem, init_mem)

        # (3) test calling our dispatcher on abs
        t = torch.rand(32, 16, device=device)
        cpu_t = libtorch_agnostic.ops.my_abs(t)
        self.assertEqual(cpu_t, torch.abs(t))

        def _make_cuda_tensors(prior_mem):
            cuda_t = libtorch_agnostic.ops.my_abs(t)
            self.assertGreater(torch.cuda.memory_allocated(device), prior_mem)
            self.assertEqual(cuda_t, torch.abs(t))

        if t.is_cuda:
            init_mem = torch.cuda.memory_allocated(device)
            for _ in range(3):
                _make_cuda_tensors(init_mem)
            curr_mem = torch.cuda.memory_allocated(device)
            self.assertEqual(curr_mem, init_mem)


@torch.testing._internal.common_utils.markDynamoStrictTest
class TestPybindTypeCasters(common.TestCase):


@@ -677,6 +677,15 @@ aoti_torch_library_def(TorchLibraryHandle self, const char* schema);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_delete_library_object(TorchLibraryHandle tlh);

// calls the op overload defined by a given opName, overloadName, and a
// stack of StableIValues. This call will populate any return values of the
// op into the stack in their StableIValue form, with ret0 at index 0, ret1
// at index 1, and so on.
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_call_dispatcher(
    const char* opName,
    const char* overloadName,
    StableIValue* stack);

#ifdef USE_CUDA
struct CUDAGuardOpaque;
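Like the rest of the shim surface, the call reports failure through its AOTITorchError return rather than by throwing across the ABI boundary. A minimal caller-side sketch (hedged: assumes the zero success code AOTI_TORCH_SUCCESS defined alongside the shim, and a caller-owned AtenTensorHandle t):

StableIValue stack[1];
stack[0] = from(t);  // hand the tensor handle off to the dispatcher
AOTITorchError err = aoti_torch_call_dispatcher("aten::abs", "", stack);
if (err != AOTI_TORCH_SUCCESS) {
  return err;  // the dispatched op threw; conservatively treat stack contents as unspecified
}
AtenTensorHandle result = to<AtenTensorHandle>(stack[0]);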


@@ -1429,3 +1429,88 @@ aoti_torch_delete_library_object(TorchLibraryHandle tlh) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
      { delete reinterpret_cast<torch::Library*>(tlh); });
}

static c10::IValue to_ivalue(
    const c10::TypePtr& arg_type,
    const StableIValue stable_ivalue) {
  switch (arg_type->kind()) {
    case c10::TypeKind::TensorType: {
      // stable_ivalue must be an ATH
      auto ret_raiiath = torch::aot_inductor::RAIIAtenTensorHandle(
          to<AtenTensorHandle>(stable_ivalue));
      at::Tensor arg = *torch::aot_inductor::tensor_handle_to_tensor_pointer(
          ret_raiiath.get());
      return c10::IValue(arg);
    }
    case c10::TypeKind::IntType: {
      return c10::IValue(to<int64_t>(stable_ivalue));
    }
    case c10::TypeKind::FloatType: {
      return c10::IValue(to<double>(stable_ivalue));
    }
    case c10::TypeKind::BoolType: {
      return c10::IValue(to<bool>(stable_ivalue));
    }
    case c10::TypeKind::ScalarTypeType: {
      return c10::IValue(to<c10::ScalarType>(stable_ivalue));
    }
    case c10::TypeKind::LayoutType: {
      return c10::IValue(to<c10::Layout>(stable_ivalue));
    }
    case c10::TypeKind::MemoryFormatType: {
      return c10::IValue(to<c10::MemoryFormat>(stable_ivalue));
    }
    default: {
      TORCH_CHECK(false, "Not yet supported argument type: ", arg_type->str());
    }
  }
}
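to_ivalue leans on the to<T>/from helpers shared with the stable extension headers; conceptually they bit-pack a value into and out of a 64-bit StableIValue slot. A sketch of that pattern, stated as an assumption about those helpers rather than code from this diff:

#include <cstdint>
#include <cstring>

// assumption, matching the shim header: StableIValue is a 64-bit slot
using StableIValue = uint64_t;

// illustrative only: pack/unpack a value of at most 64 bits into a slot
template <typename T>
StableIValue from(T val) {
  static_assert(sizeof(T) <= sizeof(StableIValue), "must fit in a slot");
  StableIValue packed = 0;
  std::memcpy(&packed, &val, sizeof(val));
  return packed;
}

template <typename T>
T to(StableIValue val) {
  T unpacked;
  std::memcpy(&unpacked, &val, sizeof(unpacked));
  return unpacked;
}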
AOTITorchError aoti_torch_call_dispatcher(
    const char* opName,
    const char* overloadName,
    StableIValue* stack) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    static auto op =
        c10::Dispatcher::singleton().findSchemaOrThrow(opName, overloadName);
    const auto& schema = op.schema();
    const auto num_returns = schema.returns().size();
    const auto num_arguments = schema.arguments().size();

    torch::jit::Stack ivalue_stack;
    // we will only need max(num_args, num_returns)
    ivalue_stack.reserve(std::max(num_arguments, num_returns));

    // convert StableIValue stack to c10::IValue stack
    for (const auto idx : c10::irange(num_arguments)) {
      auto stable_ivalue = stack[idx];
      auto arg_type = schema.arguments()[idx].type();
      torch::jit::push(ivalue_stack, to_ivalue(arg_type, stable_ivalue));
    }

    op.callBoxed(ivalue_stack);

    // there should then be num_returns IValues on the stack, which
    // we will convert to StableIValue and repopulate user input stack
    for (const auto idx : c10::irange(num_returns)) {
      const c10::IValue& ret = torch::jit::pop(ivalue_stack);
      const auto stack_idx = num_returns - idx - 1;
      if (ret.isInt()) {
        stack[stack_idx] = from(ret.toInt());
      } else if (ret.isDouble()) {
        stack[stack_idx] = from(ret.toDouble());
      } else if (ret.isBool()) {
        stack[stack_idx] = from(ret.toBool());
      } else if (ret.isNone()) {
        stack[stack_idx] = from(nullptr);
      } else if (ret.isTensor()) {
        AtenTensorHandle ath = torch::aot_inductor::new_tensor_handle(
            std::move(const_cast<at::Tensor&>(ret.toTensor())));
        stack[stack_idx] = from(ath);
      } else {
        TORCH_CHECK(false, "Other types of IValue returns not yet handled!");
      }
    }
  });
}
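One subtlety in the return loop: torch::jit::pop takes from the top of the IValue stack, so multi-return ops come off in reverse order. The stack_idx = num_returns - idx - 1 arithmetic undoes that, restoring the contract documented in shim.h: with num_returns = 2, the first pop yields ret1 and writes stack[1], and the second yields ret0 and writes stack[0].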