diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/libtorch_agnostic_kernel.cpp b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/libtorch_agnostic_kernel.cpp
index cba0366dceb..85800c115b3 100644
--- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/libtorch_agnostic_kernel.cpp
+++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/libtorch_agnostic_kernel.cpp
@@ -125,3 +125,25 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CUDA, m) {
 STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
   m.impl("identity", &boxed_identity);
 }
+
+RAIIATH my_abs(RAIIATH t) {
+  const auto num_args = 1;
+  StableIValue stack[num_args];
+  stack[0] = from(t.release());
+  aoti_torch_call_dispatcher("aten::abs", "", stack);
+  return RAIIATH(to<AtenTensorHandle>(stack[0]));
+}
+
+void boxed_my_abs(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+  RAIIATH t(to<AtenTensorHandle>(stack[0]));
+  RAIIATH raiiath_res = my_abs(std::move(t));
+  stack[0] = from(raiiath_res.release());
+}
+
+STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
+  m.def("my_abs(Tensor t) -> Tensor");
+}
+
+STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
+  m.impl("my_abs", &boxed_my_abs);
+}
diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py
index 2a76d0f4b17..6eed13920ee 100644
--- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py
+++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py
@@ -36,3 +36,16 @@ def identity(t) -> Tensor:
         a Tensor, the same as input.
     """
     return torch.ops.libtorch_agnostic.identity.default(t)
+
+
+def my_abs(t) -> Tensor:
+    """
+    Computes the absolute value of the input tensor, returning a new Tensor
+
+    Args:
+        t: any Tensor
+
+    Returns:
+        a Tensor
+    """
+    return torch.ops.libtorch_agnostic.my_abs.default(t)
diff --git a/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py b/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py
index a33099175f1..bc4f96d4ed0 100644
--- a/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py
+++ b/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py
@@ -52,6 +52,23 @@ class TestLibtorchAgnostic(TestCase):
             curr_mem = torch.cuda.memory_allocated(device)
             self.assertEqual(curr_mem, init_mem)
 
+    def test_my_abs(self, device):
+        t = torch.rand(32, 16, device=device)
+        cpu_t = libtorch_agnostic.ops.my_abs(t)
+        self.assertEqual(cpu_t, torch.abs(t))
+
+        def _make_cuda_tensors(prior_mem):
+            cuda_t = libtorch_agnostic.ops.my_abs(t)
+            self.assertGreater(torch.cuda.memory_allocated(device), prior_mem)
+            self.assertEqual(cuda_t, torch.abs(t))
+
+        if t.is_cuda:
+            init_mem = torch.cuda.memory_allocated(device)
+            for _ in range(3):
+                _make_cuda_tensors(init_mem)
+            curr_mem = torch.cuda.memory_allocated(device)
+            self.assertEqual(curr_mem, init_mem)
+
     @onlyCUDA
     def test_z_delete_torch_lib(self, device):
         # Why the z + CUDA? THIS TEST MUST BE RUN LAST
diff --git a/test/test_cpp_extensions_aot.py b/test/test_cpp_extensions_aot.py
index 57fc0ecb46e..b017c61bb69 100644
--- a/test/test_cpp_extensions_aot.py
+++ b/test/test_cpp_extensions_aot.py
@@ -270,6 +270,23 @@ class TestCppExtensionAOT(common.TestCase):
             curr_mem = torch.cuda.memory_allocated(device)
             self.assertEqual(curr_mem, init_mem)
 
+        # (3) test calling our dispatcher on my_abs
+        t = torch.rand(32, 16, device=device)
+        cpu_t = libtorch_agnostic.ops.my_abs(t)
+        self.assertEqual(cpu_t, torch.abs(t))
+
+        def _make_cuda_tensors(prior_mem):
+            cuda_t = libtorch_agnostic.ops.my_abs(t)
+            self.assertGreater(torch.cuda.memory_allocated(device), prior_mem)
+            self.assertEqual(cuda_t, torch.abs(t))
+
+        if t.is_cuda:
+            init_mem = torch.cuda.memory_allocated(device)
+            for _ in range(3):
+                _make_cuda_tensors(init_mem)
+            curr_mem = torch.cuda.memory_allocated(device)
+            self.assertEqual(curr_mem, init_mem)
+
 
 @torch.testing._internal.common_utils.markDynamoStrictTest
 class TestPybindTypeCasters(common.TestCase):
diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h
index 1069e0bc728..f56f6eca744 100644
--- a/torch/csrc/inductor/aoti_torch/c/shim.h
+++ b/torch/csrc/inductor/aoti_torch/c/shim.h
@@ -677,6 +677,15 @@ aoti_torch_library_def(TorchLibraryHandle self, const char* schema);
 AOTI_TORCH_EXPORT AOTITorchError
 aoti_torch_delete_library_object(TorchLibraryHandle tlh);
 
+// Calls the op overload identified by the given opName and overloadName,
+// passing arguments via a stack of StableIValues. The call populates the
+// op's return values into the same stack in their StableIValue form, with
+// ret0 at index 0, ret1 at index 1, and so on.
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_call_dispatcher(
+    const char* opName,
+    const char* overloadName,
+    StableIValue* stack);
+
 #ifdef USE_CUDA
 
 struct CUDAGuardOpaque;
diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp
index 924d734a101..d48b321b496 100644
--- a/torch/csrc/inductor/aoti_torch/shim_common.cpp
+++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp
@@ -1429,3 +1429,88 @@ aoti_torch_delete_library_object(TorchLibraryHandle tlh) {
   AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
       { delete reinterpret_cast<torch::Library*>(tlh); });
 }
+
+static c10::IValue to_ivalue(
+    const c10::TypePtr& arg_type,
+    const StableIValue stable_ivalue) {
+  switch (arg_type->kind()) {
+    case c10::TypeKind::TensorType: {
+      // stable_ivalue must be an AtenTensorHandle (ATH)
+      auto ret_raiiath = torch::aot_inductor::RAIIAtenTensorHandle(
+          to<AtenTensorHandle>(stable_ivalue));
+      at::Tensor arg = *torch::aot_inductor::tensor_handle_to_tensor_pointer(
+          ret_raiiath.get());
+      return (c10::IValue(arg));
+    }
+    case c10::TypeKind::IntType: {
+      return c10::IValue(to<int64_t>(stable_ivalue));
+    }
+    case c10::TypeKind::FloatType: {
+      return c10::IValue(to<double>(stable_ivalue));
+    }
+    case c10::TypeKind::BoolType: {
+      return c10::IValue(to<bool>(stable_ivalue));
+    }
+    case c10::TypeKind::ScalarTypeType: {
+      return c10::IValue(to<c10::ScalarType>(stable_ivalue));
+    }
+    case c10::TypeKind::LayoutType: {
+      return c10::IValue(to<c10::Layout>(stable_ivalue));
+    }
+    case c10::TypeKind::MemoryFormatType: {
+      return c10::IValue(to<c10::MemoryFormat>(stable_ivalue));
+    }
+    default: {
+      TORCH_CHECK(false, "Not yet supported argument type: ", arg_type->str());
+    }
+  }
+}
+
+AOTITorchError aoti_torch_call_dispatcher(
+    const char* opName,
+    const char* overloadName,
+    StableIValue* stack) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
+    const auto op =
+        c10::Dispatcher::singleton().findSchemaOrThrow(opName, overloadName);
+
+    const auto& schema = op.schema();
+    const auto num_returns = schema.returns().size();
+    const auto num_arguments = schema.arguments().size();
+
+    torch::jit::Stack ivalue_stack;
+    // we only ever need max(num_arguments, num_returns) slots
+    ivalue_stack.reserve(std::max(num_arguments, num_returns));
+
+    // convert StableIValue stack to c10::IValue stack
+    for (const auto idx : c10::irange(num_arguments)) {
+      auto stable_ivalue = stack[idx];
+      auto arg_type = schema.arguments()[idx].type();
+      torch::jit::push(ivalue_stack, to_ivalue(arg_type, stable_ivalue));
+    }
+
+    op.callBoxed(ivalue_stack);
+
+    // the stack now holds num_returns IValues, which we convert to
+    // StableIValue and write back into the user's input stack
+    for (const auto idx : c10::irange(num_returns)) {
+      const c10::IValue& ret = torch::jit::pop(ivalue_stack);
+      const auto stack_idx = num_returns - idx - 1;
+      if (ret.isInt()) {
+        stack[stack_idx] = from(ret.toInt());
+      } else if (ret.isDouble()) {
+        stack[stack_idx] = from(ret.toDouble());
+      } else if (ret.isBool()) {
+        stack[stack_idx] = from(ret.toBool());
+      } else if (ret.isNone()) {
+        stack[stack_idx] = from(nullptr);
+      } else if (ret.isTensor()) {
+        AtenTensorHandle ath = torch::aot_inductor::new_tensor_handle(
+            std::move(const_cast<at::Tensor&>(ret.toTensor())));
+        stack[stack_idx] = from(ath);
+      } else {
+        TORCH_CHECK(false, "Other types of IValue returns not yet handled!");
+      }
+    }
+  });
+}
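
Note on usage (not part of the diff): since aoti_torch_call_dispatcher packs both arguments and returns through a single StableIValue array, ops with several arguments and multiple returns follow the same pattern as my_abs above. Below is a minimal, hypothetical sketch in the style of the extension code, calling aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices); the name my_max_dim is invented, while RAIIATH, from, to<>, and aoti_torch_call_dispatcher are the same helpers the diff uses.

    #include <utility>  // std::pair

    // hypothetical sketch: all schema arguments (including defaulted ones
    // like keepdim) must be pushed, in schema order, at indices 0..num_args-1
    std::pair<RAIIATH, RAIIATH> my_max_dim(RAIIATH t, int64_t dim, bool keepdim) {
      // the stack needs max(num_args, num_returns) = max(3, 2) slots
      StableIValue stack[3];
      stack[0] = from(t.release());
      stack[1] = from(dim);
      stack[2] = from(keepdim);
      aoti_torch_call_dispatcher("aten::max", "dim", stack);
      // per the shim.h contract, ret0 (values) is now at index 0 and
      // ret1 (indices) at index 1
      return {RAIIATH(to<AtenTensorHandle>(stack[0])),
              RAIIATH(to<AtenTensorHandle>(stack[1]))};
    }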
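Similarly hedged: the kernels in this diff discard the AOTITorchError that aoti_torch_call_dispatcher returns. A sketch of a more defensive variant of my_abs follows, assuming only the AOTI_TORCH_SUCCESS code that shim.h already defines; my_abs_checked and its error message are invented for illustration.

    #include <stdexcept>

    RAIIATH my_abs_checked(RAIIATH t) {
      StableIValue stack[1];
      stack[0] = from(t.release());
      // surface shim failures instead of silently reading a stale stack slot
      AOTITorchError err = aoti_torch_call_dispatcher("aten::abs", "", stack);
      if (err != AOTI_TORCH_SUCCESS) {
        throw std::runtime_error("aoti_torch_call_dispatcher(aten::abs) failed");
      }
      return RAIIATH(to<AtenTensorHandle>(stack[0]));
    }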