Add shim.h C API to call dispatcher on our own aten ops (#148832)
This PR still needs testing through a cpp extension.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148832
Approved by: https://github.com/albanD, https://github.com/atalman
ghstack dependencies: #148124
This commit is contained in: parent cf19efd3d9, commit e6ef0620cc
@@ -125,3 +125,25 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CUDA, m) {
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
  m.impl("identity", &boxed_identity);
}

RAIIATH my_abs(RAIIATH t) {
  const auto num_args = 1;
  StableIValue stack[num_args];
  stack[0] = from(t.release());
  aoti_torch_call_dispatcher("aten::abs", "", stack);
  return RAIIATH(to<AtenTensorHandle>(stack[0]));
}

void boxed_my_abs(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
  RAIIATH t(to<AtenTensorHandle>(stack[0]));
  RAIIATH raiiath_res = my_abs(std::move(t));
  stack[0] = from(raiiath_res.release());
}

STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
  m.def("my_abs(Tensor t) -> Tensor");
}

STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
  m.impl("my_abs", &boxed_my_abs);
}
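The pattern above extends to ops that take more than one argument: push each argument onto the stack in schema order, call aoti_torch_call_dispatcher, and read the returns back from the front of the same stack. Below is a hedged sketch, not part of this diff; it assumes the aten::mul.Tensor overload, whose schema is "mul.Tensor(Tensor self, Tensor other) -> Tensor", the my_mul name is hypothetical, and it reuses the RAIIATH/from/to helpers shown above.

// Hedged sketch: my_mul is a hypothetical extension op, not from this PR.
// It assumes the aten::mul.Tensor overload "mul.Tensor(Tensor self, Tensor other) -> Tensor".
RAIIATH my_mul(RAIIATH a, RAIIATH b) {
  const auto num_args = 2;
  StableIValue stack[num_args];
  stack[0] = from(a.release());  // self: first schema argument goes at index 0
  stack[1] = from(b.release());  // other
  aoti_torch_call_dispatcher("aten::mul", "Tensor", stack);
  // the single return (ret0) is written back at index 0
  return RAIIATH(to<AtenTensorHandle>(stack[0]));
}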
@@ -36,3 +36,16 @@ def identity(t) -> Tensor:
        a Tensor, the same as input.
    """
    return torch.ops.libtorch_agnostic.identity.default(t)


def my_abs(t) -> Tensor:
    """
    Returns abs on the input tensor, outputs a new Tensor

    Args:
        t: any Tensor

    Returns:
        a Tensor
    """
    return torch.ops.libtorch_agnostic.my_abs.default(t)
@@ -52,6 +52,23 @@ class TestLibtorchAgnostic(TestCase):
                curr_mem = torch.cuda.memory_allocated(device)
                self.assertEqual(curr_mem, init_mem)

    def test_my_abs(self, device):
        t = torch.rand(32, 16, device=device)
        cpu_t = libtorch_agnostic.ops.my_abs(t)
        self.assertEqual(cpu_t, torch.abs(t))

        def _make_cuda_tensors(prior_mem):
            cuda_t = libtorch_agnostic.ops.my_abs(t)
            self.assertGreater(torch.cuda.memory_allocated(device), prior_mem)
            self.assertEqual(cuda_t, torch.abs(t))

        if t.is_cuda:
            init_mem = torch.cuda.memory_allocated(device)
            for _ in range(3):
                _make_cuda_tensors(init_mem)
                curr_mem = torch.cuda.memory_allocated(device)
                self.assertEqual(curr_mem, init_mem)

    @onlyCUDA
    def test_z_delete_torch_lib(self, device):
        # Why the z + CUDA? THIS TEST MUST BE RUN LAST
@@ -270,6 +270,23 @@ class TestCppExtensionAOT(common.TestCase):
                curr_mem = torch.cuda.memory_allocated(device)
                self.assertEqual(curr_mem, init_mem)

        # (3) test calling our dispatcher on my_abs
        t = torch.rand(32, 16, device=device)
        cpu_t = libtorch_agnostic.ops.my_abs(t)
        self.assertEqual(cpu_t, torch.abs(t))

        def _make_cuda_tensors(prior_mem):
            cuda_t = libtorch_agnostic.ops.my_abs(t)
            self.assertGreater(torch.cuda.memory_allocated(device), prior_mem)
            self.assertEqual(cuda_t, torch.abs(t))

        if t.is_cuda:
            init_mem = torch.cuda.memory_allocated(device)
            for _ in range(3):
                _make_cuda_tensors(init_mem)
                curr_mem = torch.cuda.memory_allocated(device)
                self.assertEqual(curr_mem, init_mem)


@torch.testing._internal.common_utils.markDynamoStrictTest
class TestPybindTypeCasters(common.TestCase):
@@ -677,6 +677,15 @@ aoti_torch_library_def(TorchLibraryHandle self, const char* schema);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_delete_library_object(TorchLibraryHandle tlh);

// calls the op overload defined by a given opName, overloadName, and a
// stack of StableIValues. This call will populate any return values of the
// op into the stack in their StableIValue form, with ret0 at index 0, ret1
// at index 1, and so on.
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_call_dispatcher(
    const char* opName,
    const char* overloadName,
    StableIValue* stack);

#ifdef USE_CUDA

struct CUDAGuardOpaque;
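To make the stack contract concrete, here is a hedged caller-side sketch, not in this diff, for an op with three arguments and two returns. It assumes the aten::max.dim overload, "max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)", and that the from/to conversions also cover int64_t and bool, mirroring the IntType/BoolType cases handled by the implementation further down. The stack must be sized to max(num_args, num_returns), and every schema argument has to be supplied explicitly, defaults included.

// Hedged sketch, not part of this PR: a two-return call through the dispatcher.
// my_max_dim is hypothetical; only aoti_torch_call_dispatcher comes from this change.
void my_max_dim(AtenTensorHandle self, int64_t dim,
                AtenTensorHandle* values, AtenTensorHandle* indices) {
  StableIValue stack[3];            // 3 arguments in, 2 returns out
  stack[0] = from(self);            // Tensor self
  stack[1] = from(dim);             // int dim
  stack[2] = from(false);           // bool keepdim: defaults must still be passed
  aoti_torch_call_dispatcher("aten::max", "dim", stack);
  *values = to<AtenTensorHandle>(stack[0]);   // ret0 at index 0
  *indices = to<AtenTensorHandle>(stack[1]);  // ret1 at index 1
}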
@@ -1429,3 +1429,88 @@ aoti_torch_delete_library_object(TorchLibraryHandle tlh) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
      { delete reinterpret_cast<torch::Library*>(tlh); });
}

static c10::IValue to_ivalue(
    const c10::TypePtr& arg_type,
    const StableIValue stable_ivalue) {
  switch (arg_type->kind()) {
    case c10::TypeKind::TensorType: {
      // stable_ivalue must be an ATH
      auto ret_raiiath = torch::aot_inductor::RAIIAtenTensorHandle(
          to<AtenTensorHandle>(stable_ivalue));
      at::Tensor arg = *torch::aot_inductor::tensor_handle_to_tensor_pointer(
          ret_raiiath.get());
      return (c10::IValue(arg));
    }
    case c10::TypeKind::IntType: {
      return c10::IValue(to<int64_t>(stable_ivalue));
    }
    case c10::TypeKind::FloatType: {
      return c10::IValue(to<double>(stable_ivalue));
    }
    case c10::TypeKind::BoolType: {
      return c10::IValue(to<bool>(stable_ivalue));
    }
    case c10::TypeKind::ScalarTypeType: {
      return c10::IValue(to<c10::ScalarType>(stable_ivalue));
    }
    case c10::TypeKind::LayoutType: {
      return c10::IValue(to<c10::Layout>(stable_ivalue));
    }
    case c10::TypeKind::MemoryFormatType: {
      return c10::IValue(to<c10::MemoryFormat>(stable_ivalue));
    }
    default: {
      TORCH_CHECK(false, "Not yet supported argument type: ", arg_type->str());
    }
  }
}

AOTITorchError aoti_torch_call_dispatcher(
    const char* opName,
    const char* overloadName,
    StableIValue* stack) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    static auto op =
        c10::Dispatcher::singleton().findSchemaOrThrow(opName, overloadName);

    const auto& schema = op.schema();
    const auto num_returns = schema.returns().size();
    const auto num_arguments = schema.arguments().size();

    torch::jit::Stack ivalue_stack;
    // we will only need max(num_args, num_returns)
    ivalue_stack.reserve(std::max(num_arguments, num_returns));

    // convert StableIValue stack to c10::IValue stack
    for (const auto idx : c10::irange(num_arguments)) {
      auto stable_ivalue = stack[idx];
      auto arg_type = schema.arguments()[idx].type();
      torch::jit::push(ivalue_stack, to_ivalue(arg_type, stable_ivalue));
    }

    op.callBoxed(ivalue_stack);

    // there should then be num_returns IValues on the stack, which
    // we will convert to StableIValue and repopulate user input stack
    for (const auto idx : c10::irange(num_returns)) {
      const c10::IValue& ret = torch::jit::pop(ivalue_stack);
      const auto stack_idx = num_returns - idx - 1;
      if (ret.isInt()) {
        stack[stack_idx] = from(ret.toInt());
      } else if (ret.isDouble()) {
        stack[stack_idx] = from(ret.toDouble());
      } else if (ret.isBool()) {
        stack[stack_idx] = from(ret.toBool());
      } else if (ret.isNone()) {
        stack[stack_idx] = from(nullptr);
      } else if (ret.isTensor()) {
        AtenTensorHandle ath = torch::aot_inductor::new_tensor_handle(
            std::move(const_cast<at::Tensor&>(ret.toTensor())));
        stack[stack_idx] = from(ath);
      } else {
        TORCH_CHECK(false, "Other types of IValue returns not yet handled!");
      }
    }
  });
}
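As the to_ivalue switch shows, only Tensor, int, float, bool, ScalarType, Layout, and MemoryFormat arguments are converted today, and only Tensor, int, float, bool, and None returns are written back; anything else raises. Here is a hedged extension-side sketch, not from this diff, that exercises the IntType path; it assumes the aten::transpose.int overload, "transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)", and the my_transpose name is hypothetical.

// Hedged sketch: my_transpose is hypothetical; it exercises the IntType
// argument conversion of to_ivalue above via the new dispatcher entry point.
RAIIATH my_transpose(RAIIATH t, int64_t dim0, int64_t dim1) {
  const auto num_args = 3;
  StableIValue stack[num_args];
  stack[0] = from(t.release());  // TensorType: unwrapped into an at::Tensor
  stack[1] = from(dim0);         // IntType: read back with to<int64_t>
  stack[2] = from(dim1);         // IntType
  aoti_torch_call_dispatcher("aten::transpose", "int", stack);
  // the Tensor return comes back at index 0 as a fresh AtenTensorHandle
  return RAIIATH(to<AtenTensorHandle>(stack[0]));
}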