Move static from_ivalue/to_ivalue to new shim_common.cpp (#166373)

Move `from_ivalue` and `to_ivalue` and their dependents `StableIValueBoxedKernel`, `aoti_torch_library_impl` `aoti_torch_call_dispatcher` into new (non-aoti shim_common.cpp)

This is in prep for the above PRs where I add v2s (`torch_call_dispatcher` and `torch_library_impl`) that are versioning aware

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166373
Approved by: https://github.com/janeyx99
ghstack dependencies: #164356
This commit is contained in:
Mikayla Gawarecki 2025-10-28 13:12:01 -07:00 committed by PyTorch MergeBot
parent fefb546b91
commit c0bbda37e8
3 changed files with 219 additions and 209 deletions

View File

@ -482,6 +482,7 @@ inductor_core_resources = [
"torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp", "torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp",
"torch/csrc/inductor/inductor_ops.cpp", "torch/csrc/inductor/inductor_ops.cpp",
"torch/csrc/jit/serialization/pickle.cpp", "torch/csrc/jit/serialization/pickle.cpp",
"torch/csrc/shim_common.cpp",
] ]
libtorch_core_sources = sorted( libtorch_core_sources = sorted(

View File

@ -1406,169 +1406,6 @@ AOTITorchError aoti_torch_zero_(AtenTensorHandle tensor) {
}); });
} }
static StableIValue from_ivalue(
const c10::TypePtr& type,
const c10::IValue& ivalue) {
switch (type->kind()) {
case c10::TypeKind::TensorType: {
AtenTensorHandle ath = torch::aot_inductor::new_tensor_handle(
std::move(const_cast<at::Tensor&>(ivalue.toTensor())));
return torch::stable::detail::from(ath);
}
case c10::TypeKind::IntType: {
return torch::stable::detail::from(ivalue.toInt());
}
case c10::TypeKind::FloatType: {
return torch::stable::detail::from(ivalue.toDouble());
}
case c10::TypeKind::BoolType: {
return torch::stable::detail::from(ivalue.toBool());
}
case c10::TypeKind::ScalarTypeType: {
return torch::stable::detail::from(ivalue.toScalarType());
}
case c10::TypeKind::DeviceObjType: {
return torch::stable::detail::from(ivalue.toDevice());
}
case c10::TypeKind::LayoutType: {
return torch::stable::detail::from(ivalue.toLayout());
}
case c10::TypeKind::MemoryFormatType: {
return torch::stable::detail::from(ivalue.toMemoryFormat());
}
case c10::TypeKind::OptionalType: {
auto inner_type = type->castRaw<at::OptionalType>()->getElementType();
// ideally, if we had the C++ type corresponding to inner_type, which we
// will denote as inner_type::t (does not actually exist), we would be
// able to follow the patterned semantic of every other case here in one
// line:
//
// return
// torch::stable::detail::from<std::optional<inner_type::t>>(ivalue.toInnerTypeT()));
//
// BUT we do NOT have that type inner_type::t readily available, so we
// will manually unwrap and recursively call. This implementation MUST
// be kept in sync with torch::stable::detail::from<std::optional<T>>
// function in torch/csrc/stable/stableivalue_conversions.h
if (ivalue.isNone()) {
return torch::stable::detail::from(std::nullopt);
}
StableIValue* sivp = new StableIValue(from_ivalue(inner_type, ivalue));
return torch::stable::detail::from(sivp);
}
default: {
TORCH_CHECK(
false,
"Not yet supported conversion from IValue to StableIValue for schema type: ",
type->str());
}
}
}
static c10::IValue to_ivalue(
const c10::TypePtr& type,
const StableIValue stable_ivalue) {
switch (type->kind()) {
case c10::TypeKind::TensorType: {
auto ret_raiiath = torch::aot_inductor::RAIIAtenTensorHandle(
torch::stable::detail::to<AtenTensorHandle>(stable_ivalue));
return (c10::IValue(*torch::aot_inductor::tensor_handle_to_tensor_pointer(
ret_raiiath.get())));
}
case c10::TypeKind::IntType: {
return c10::IValue(torch::stable::detail::to<int64_t>(stable_ivalue));
}
case c10::TypeKind::FloatType: {
return c10::IValue(torch::stable::detail::to<double>(stable_ivalue));
}
case c10::TypeKind::BoolType: {
return c10::IValue(torch::stable::detail::to<bool>(stable_ivalue));
}
case c10::TypeKind::ScalarTypeType: {
return c10::IValue(
torch::stable::detail::to<c10::ScalarType>(stable_ivalue));
}
case c10::TypeKind::DeviceObjType: {
return c10::IValue(torch::stable::detail::to<c10::Device>(stable_ivalue));
}
case c10::TypeKind::LayoutType: {
return c10::IValue(torch::stable::detail::to<c10::Layout>(stable_ivalue));
}
case c10::TypeKind::MemoryFormatType: {
return c10::IValue(
torch::stable::detail::to<c10::MemoryFormat>(stable_ivalue));
}
case c10::TypeKind::OptionalType: {
auto inner_type = type->castRaw<at::OptionalType>()->getElementType();
// ideally, if we had the C++ type corresponding to inner_type, which we
// will denote as inner_type::t (does not actually exist), we would be
// able to follow the patterned semantic of every other case here in one
// line:
//
// return
// c10::IValue(torch::stable::detail::to<std::optional<inner_type::t>>(stable_ivalue));
//
// BUT we do NOT have that type inner_type::t readily available, so we
// will manually unwrap and recursively call. This implementation MUST
// be kept in sync with the torch::stable::detail::to<T> function in
// torch/csrc/stable/stableivalue_conversions.h
if (stable_ivalue == torch::stable::detail::from(std::nullopt)) {
return c10::IValue();
}
auto sivp = torch::stable::detail::to<StableIValue*>(stable_ivalue);
auto ival = to_ivalue(inner_type, *sivp);
delete sivp;
return ival;
}
default: {
TORCH_CHECK(
false,
"Not yet supported conversion from StableIValue to IValue for schema type: ",
type->str());
}
}
}
class StableIValueBoxedKernel : public c10::OperatorKernel {
public:
StableIValueBoxedKernel(void (*fn)(StableIValue*, uint64_t, uint64_t))
: fn_(fn) {}
void operator()(
const c10::OperatorHandle& op,
c10::DispatchKeySet keyset,
torch::jit::Stack* stack) {
const auto& schema = op.schema();
const auto num_returns = schema.returns().size();
const auto num_arguments = schema.arguments().size();
auto ministack =
std::make_unique<StableIValue[]>(std::max(num_arguments, num_returns));
for (const auto idx : c10::irange(num_arguments)) {
const auto ministack_idx = num_arguments - idx - 1;
const c10::TypePtr& arg_type = schema.arguments()[ministack_idx].type();
ministack[ministack_idx] = from_ivalue(arg_type, torch::jit::pop(stack));
}
// boxed function is going to take a stack of StableIValues, cast them to
// our schema values, and run the function and modify the StableIValue stack
fn_(ministack.get(), num_arguments, num_returns);
// read the output from the end of the stack and wrap that back into
// IValue from StableIValue
for (size_t idx = 0; idx < num_returns; idx++) {
const c10::TypePtr& ret_type = schema.returns()[idx].type();
torch::jit::push(stack, to_ivalue(ret_type, ministack[idx]));
}
}
private:
void (*fn_)(StableIValue*, uint64_t, uint64_t);
};
AOTITorchError aoti_torch_library_init_impl( AOTITorchError aoti_torch_library_init_impl(
const char* ns, const char* ns,
const char* k, const char* k,
@ -1618,18 +1455,6 @@ AOTITorchError aoti_torch_library_init_fragment(
}); });
} }
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_impl(
TorchLibraryHandle self,
const char* name,
void (*fn)(StableIValue*, uint64_t, uint64_t)) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
reinterpret_cast<torch::Library*>(self)->impl(
name,
torch::CppFunction::makeFromBoxedFunctor(
std::make_unique<StableIValueBoxedKernel>(fn)));
});
}
AOTI_TORCH_EXPORT AOTITorchError AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_library_def(TorchLibraryHandle self, const char* name) { aoti_torch_library_def(TorchLibraryHandle self, const char* name) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE( AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
@ -1642,40 +1467,6 @@ aoti_torch_delete_library_object(TorchLibraryHandle tlh) {
{ delete reinterpret_cast<torch::Library*>(tlh); }); { delete reinterpret_cast<torch::Library*>(tlh); });
} }
AOTITorchError aoti_torch_call_dispatcher(
const char* opName,
const char* overloadName,
StableIValue* stack) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
const auto op =
c10::Dispatcher::singleton().findSchemaOrThrow(opName, overloadName);
const auto& schema = op.schema();
const auto num_returns = schema.returns().size();
const auto num_arguments = schema.arguments().size();
torch::jit::Stack ivalue_stack;
// we will only need max(num_args, num_returns)
ivalue_stack.reserve(std::max(num_arguments, num_returns));
// convert StableIValue stack to c10::IValue stack
for (const auto idx : c10::irange(num_arguments)) {
auto stable_ivalue = stack[idx];
auto arg_type = schema.arguments()[idx].type();
torch::jit::push(ivalue_stack, to_ivalue(arg_type, stable_ivalue));
}
op.callBoxed(ivalue_stack);
// there should then be num_returns IValues on the stack, which
// we will convert to StableIValue and repopulate user input stack
for (const auto idx : c10::irange(num_returns)) {
const auto stack_idx = num_returns - idx - 1;
const c10::TypePtr& ret_type = schema.returns()[idx].type();
stack[stack_idx] = from_ivalue(ret_type, torch::jit::pop(ivalue_stack));
}
});
}
AOTITorchError aoti_torch_create_device_guard( AOTITorchError aoti_torch_create_device_guard(
int32_t device_index, int32_t device_index,
DeviceGuardHandle* ret_guard // returns new reference DeviceGuardHandle* ret_guard // returns new reference

218
torch/csrc/shim_common.cpp Normal file
View File

@ -0,0 +1,218 @@
#include <c10/core/DispatchKey.h>
#include <c10/util/Exception.h>
#include <torch/csrc/inductor/aoti_runtime/utils.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>
#include <torch/csrc/jit/serialization/pickle.h>
#include <torch/csrc/stable/library.h>
#include <torch/library.h>
static StableIValue from_ivalue(
const c10::TypePtr& type,
const c10::IValue& ivalue) {
switch (type->kind()) {
case c10::TypeKind::TensorType: {
AtenTensorHandle ath = torch::aot_inductor::new_tensor_handle(
std::move(const_cast<at::Tensor&>(ivalue.toTensor())));
return torch::stable::detail::from(ath);
}
case c10::TypeKind::IntType: {
return torch::stable::detail::from(ivalue.toInt());
}
case c10::TypeKind::FloatType: {
return torch::stable::detail::from(ivalue.toDouble());
}
case c10::TypeKind::BoolType: {
return torch::stable::detail::from(ivalue.toBool());
}
case c10::TypeKind::ScalarTypeType: {
return torch::stable::detail::from(ivalue.toScalarType());
}
case c10::TypeKind::DeviceObjType: {
return torch::stable::detail::from(ivalue.toDevice());
}
case c10::TypeKind::LayoutType: {
return torch::stable::detail::from(ivalue.toLayout());
}
case c10::TypeKind::MemoryFormatType: {
return torch::stable::detail::from(ivalue.toMemoryFormat());
}
case c10::TypeKind::OptionalType: {
auto inner_type = type->castRaw<at::OptionalType>()->getElementType();
// ideally, if we had the C++ type corresponding to inner_type, which we
// will denote as inner_type::t (does not actually exist), we would be
// able to follow the patterned semantic of every other case here in one
// line:
//
// return
// torch::stable::detail::from<std::optional<inner_type::t>>(ivalue.toInnerTypeT()));
//
// BUT we do NOT have that type inner_type::t readily available, so we
// will manually unwrap and recursively call. This implementation MUST
// be kept in sync with torch::stable::detail::from<std::optional<T>>
// function in torch/csrc/stable/stableivalue_conversions.h
if (ivalue.isNone()) {
return torch::stable::detail::from(std::nullopt);
}
StableIValue* sivp = new StableIValue(from_ivalue(inner_type, ivalue));
return torch::stable::detail::from(sivp);
}
default: {
TORCH_CHECK(
false,
"Not yet supported conversion from IValue to StableIValue for schema type: ",
type->str());
}
}
}
static c10::IValue to_ivalue(
const c10::TypePtr& type,
const StableIValue stable_ivalue) {
switch (type->kind()) {
case c10::TypeKind::TensorType: {
auto ret_raiiath = torch::aot_inductor::RAIIAtenTensorHandle(
torch::stable::detail::to<AtenTensorHandle>(stable_ivalue));
return (c10::IValue(*torch::aot_inductor::tensor_handle_to_tensor_pointer(
ret_raiiath.get())));
}
case c10::TypeKind::IntType: {
return c10::IValue(torch::stable::detail::to<int64_t>(stable_ivalue));
}
case c10::TypeKind::FloatType: {
return c10::IValue(torch::stable::detail::to<double>(stable_ivalue));
}
case c10::TypeKind::BoolType: {
return c10::IValue(torch::stable::detail::to<bool>(stable_ivalue));
}
case c10::TypeKind::ScalarTypeType: {
return c10::IValue(
torch::stable::detail::to<c10::ScalarType>(stable_ivalue));
}
case c10::TypeKind::DeviceObjType: {
return c10::IValue(torch::stable::detail::to<c10::Device>(stable_ivalue));
}
case c10::TypeKind::LayoutType: {
return c10::IValue(torch::stable::detail::to<c10::Layout>(stable_ivalue));
}
case c10::TypeKind::MemoryFormatType: {
return c10::IValue(
torch::stable::detail::to<c10::MemoryFormat>(stable_ivalue));
}
case c10::TypeKind::OptionalType: {
auto inner_type = type->castRaw<at::OptionalType>()->getElementType();
// ideally, if we had the C++ type corresponding to inner_type, which we
// will denote as inner_type::t (does not actually exist), we would be
// able to follow the patterned semantic of every other case here in one
// line:
//
// return
// c10::IValue(torch::stable::detail::to<std::optional<inner_type::t>>(stable_ivalue));
//
// BUT we do NOT have that type inner_type::t readily available, so we
// will manually unwrap and recursively call. This implementation MUST
// be kept in sync with the torch::stable::detail::to<T> function in
// torch/csrc/stable/stableivalue_conversions.h
if (stable_ivalue == torch::stable::detail::from(std::nullopt)) {
return c10::IValue();
}
auto sivp = torch::stable::detail::to<StableIValue*>(stable_ivalue);
auto ival = to_ivalue(inner_type, *sivp);
delete sivp;
return ival;
}
default: {
TORCH_CHECK(
false,
"Not yet supported conversion from StableIValue to IValue for schema type: ",
type->str());
}
}
}
class StableIValueBoxedKernel : public c10::OperatorKernel {
public:
StableIValueBoxedKernel(void (*fn)(StableIValue*, uint64_t, uint64_t))
: fn_(fn) {}
void operator()(
const c10::OperatorHandle& op,
c10::DispatchKeySet keyset,
torch::jit::Stack* stack) {
const auto& schema = op.schema();
const auto num_returns = schema.returns().size();
const auto num_arguments = schema.arguments().size();
auto ministack =
std::make_unique<StableIValue[]>(std::max(num_arguments, num_returns));
for (const auto idx : c10::irange(num_arguments)) {
const auto ministack_idx = num_arguments - idx - 1;
const c10::TypePtr& arg_type = schema.arguments()[ministack_idx].type();
ministack[ministack_idx] = from_ivalue(arg_type, torch::jit::pop(stack));
}
// boxed function is going to take a stack of StableIValues, cast them to
// our schema values, and run the function and modify the StableIValue stack
fn_(ministack.get(), num_arguments, num_returns);
// read the output from the end of the stack and wrap that back into
// IValue from StableIValue
for (size_t idx = 0; idx < num_returns; idx++) {
const c10::TypePtr& ret_type = schema.returns()[idx].type();
torch::jit::push(stack, to_ivalue(ret_type, ministack[idx]));
}
}
private:
void (*fn_)(StableIValue*, uint64_t, uint64_t);
};
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_impl(
TorchLibraryHandle self,
const char* name,
void (*fn)(StableIValue*, uint64_t, uint64_t)) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
reinterpret_cast<torch::Library*>(self)->impl(
name,
torch::CppFunction::makeFromBoxedFunctor(
std::make_unique<StableIValueBoxedKernel>(fn)));
});
}
AOTITorchError aoti_torch_call_dispatcher(
const char* opName,
const char* overloadName,
StableIValue* stack) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
const auto op =
c10::Dispatcher::singleton().findSchemaOrThrow(opName, overloadName);
const auto& schema = op.schema();
const auto num_returns = schema.returns().size();
const auto num_arguments = schema.arguments().size();
torch::jit::Stack ivalue_stack;
// we will only need max(num_args, num_returns)
ivalue_stack.reserve(std::max(num_arguments, num_returns));
// convert StableIValue stack to c10::IValue stack
for (const auto idx : c10::irange(num_arguments)) {
auto stable_ivalue = stack[idx];
auto arg_type = schema.arguments()[idx].type();
torch::jit::push(ivalue_stack, to_ivalue(arg_type, stable_ivalue));
}
op.callBoxed(ivalue_stack);
// there should then be num_returns IValues on the stack, which
// we will convert to StableIValue and repopulate user input stack
for (const auto idx : c10::irange(num_returns)) {
const auto stack_idx = num_returns - idx - 1;
const c10::TypePtr& ret_type = schema.returns()[idx].type();
stack[stack_idx] = from_ivalue(ret_type, torch::jit::pop(ivalue_stack));
}
});
}