Add scaffolding for StableIValue FC/BC (no PoC) (#164332)

1. Add `extension_build_version` and `is_internal` to `FromImpl`/`ToImpl` (this will be useful for future if we need to break the BC of any type) #163832 has the PoC of how we would actually use this system
2. Add `aoti_torch_library_impl_v2` that takes in an additional `extension_build_version` argument, updates callsite in `torch/csrc/stable/library.h` to always pass `TORCH_ABI_VERSION` for this argument
3. Add `extension_build_version` to `from_ivalue` and `to_ivalue` and update all callsites
4. Add a private `_from` and `_to` that pass `is_internal=True` to `FromImpl`/`ToImpl`, making it easier to reason about what is being called from libtorch-land / extension-land

**Note: This PR does not include a linter that tells the user to update from/to if changing the ABI of a type in headeronly, which I intend to do in https://github.com/pytorch/pytorch/pull/163998**

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164332
Approved by: https://github.com/janeyx99
ghstack dependencies: #164356, #166373, #163683
This commit is contained in:
Mikayla Gawarecki 2025-10-28 13:12:02 -07:00 committed by PyTorch MergeBot
parent 8f51556daa
commit eae701cad0
3 changed files with 188 additions and 54 deletions

View File

@ -12,33 +12,41 @@
static StableIValue from_ivalue(
const c10::TypePtr& type,
const c10::IValue& ivalue) {
const c10::IValue& ivalue,
uint64_t extension_build_version) {
switch (type->kind()) {
case c10::TypeKind::TensorType: {
AtenTensorHandle ath = torch::aot_inductor::new_tensor_handle(
std::move(const_cast<at::Tensor&>(ivalue.toTensor())));
return torch::stable::detail::from(ath);
return torch::stable::detail::_from(ath, extension_build_version);
}
case c10::TypeKind::IntType: {
return torch::stable::detail::from(ivalue.toInt());
return torch::stable::detail::_from(
ivalue.toInt(), extension_build_version);
}
case c10::TypeKind::FloatType: {
return torch::stable::detail::from(ivalue.toDouble());
return torch::stable::detail::_from(
ivalue.toDouble(), extension_build_version);
}
case c10::TypeKind::BoolType: {
return torch::stable::detail::from(ivalue.toBool());
return torch::stable::detail::_from(
ivalue.toBool(), extension_build_version);
}
case c10::TypeKind::ScalarTypeType: {
return torch::stable::detail::from(ivalue.toScalarType());
return torch::stable::detail::_from(
ivalue.toScalarType(), extension_build_version);
}
case c10::TypeKind::DeviceObjType: {
return torch::stable::detail::from(ivalue.toDevice());
return torch::stable::detail::_from(
ivalue.toDevice(), extension_build_version);
}
case c10::TypeKind::LayoutType: {
return torch::stable::detail::from(ivalue.toLayout());
return torch::stable::detail::_from(
ivalue.toLayout(), extension_build_version);
}
case c10::TypeKind::MemoryFormatType: {
return torch::stable::detail::from(ivalue.toMemoryFormat());
return torch::stable::detail::_from(
ivalue.toMemoryFormat(), extension_build_version);
}
case c10::TypeKind::OptionalType: {
auto inner_type = type->castRaw<at::OptionalType>()->getElementType();
@ -56,10 +64,12 @@ static StableIValue from_ivalue(
// be kept in sync with torch::stable::detail::from<std::optional<T>>
// function in torch/csrc/stable/stableivalue_conversions.h
if (ivalue.isNone()) {
return torch::stable::detail::from(std::nullopt);
return torch::stable::detail::_from(
std::nullopt, extension_build_version);
}
StableIValue* sivp = new StableIValue(from_ivalue(inner_type, ivalue));
return torch::stable::detail::from(sivp);
StableIValue* sivp = new StableIValue(
from_ivalue(inner_type, ivalue, extension_build_version));
return torch::stable::detail::_from(sivp, extension_build_version);
}
default: {
TORCH_CHECK(
@ -72,36 +82,43 @@ static StableIValue from_ivalue(
static c10::IValue to_ivalue(
const c10::TypePtr& type,
const StableIValue stable_ivalue) {
const StableIValue stable_ivalue,
uint64_t extension_build_version) {
switch (type->kind()) {
case c10::TypeKind::TensorType: {
auto ret_raiiath = torch::aot_inductor::RAIIAtenTensorHandle(
torch::stable::detail::to<AtenTensorHandle>(stable_ivalue));
torch::stable::detail::_to<AtenTensorHandle>(
stable_ivalue, extension_build_version));
return (c10::IValue(*torch::aot_inductor::tensor_handle_to_tensor_pointer(
ret_raiiath.get())));
}
case c10::TypeKind::IntType: {
return c10::IValue(torch::stable::detail::to<int64_t>(stable_ivalue));
return c10::IValue(torch::stable::detail::_to<int64_t>(
stable_ivalue, extension_build_version));
}
case c10::TypeKind::FloatType: {
return c10::IValue(torch::stable::detail::to<double>(stable_ivalue));
return c10::IValue(torch::stable::detail::_to<double>(
stable_ivalue, extension_build_version));
}
case c10::TypeKind::BoolType: {
return c10::IValue(torch::stable::detail::to<bool>(stable_ivalue));
return c10::IValue(torch::stable::detail::_to<bool>(
stable_ivalue, extension_build_version));
}
case c10::TypeKind::ScalarTypeType: {
return c10::IValue(
torch::stable::detail::to<c10::ScalarType>(stable_ivalue));
return c10::IValue(torch::stable::detail::_to<c10::ScalarType>(
stable_ivalue, extension_build_version));
}
case c10::TypeKind::DeviceObjType: {
return c10::IValue(torch::stable::detail::to<c10::Device>(stable_ivalue));
return c10::IValue(torch::stable::detail::_to<c10::Device>(
stable_ivalue, extension_build_version));
}
case c10::TypeKind::LayoutType: {
return c10::IValue(torch::stable::detail::to<c10::Layout>(stable_ivalue));
return c10::IValue(torch::stable::detail::_to<c10::Layout>(
stable_ivalue, extension_build_version));
}
case c10::TypeKind::MemoryFormatType: {
return c10::IValue(
torch::stable::detail::to<c10::MemoryFormat>(stable_ivalue));
return c10::IValue(torch::stable::detail::_to<c10::MemoryFormat>(
stable_ivalue, extension_build_version));
}
case c10::TypeKind::OptionalType: {
auto inner_type = type->castRaw<at::OptionalType>()->getElementType();
@ -116,13 +133,15 @@ static c10::IValue to_ivalue(
//
// BUT we do NOT have that type inner_type::t readily available, so we
// will manually unwrap and recursively call. This implementation MUST
// be kept in sync with the torch::stable::detail::to<T> function in
// torch/csrc/stable/stableivalue_conversions.h
if (stable_ivalue == torch::stable::detail::from(std::nullopt)) {
// be kept in sync with the torch::stable::detail::_to<T> function in
// torch/csrc/stable/library.h
if (stable_ivalue ==
torch::stable::detail::_from(std::nullopt, extension_build_version)) {
return c10::IValue();
}
auto sivp = torch::stable::detail::to<StableIValue*>(stable_ivalue);
auto ival = to_ivalue(inner_type, *sivp);
auto sivp = torch::stable::detail::_to<StableIValue*>(
stable_ivalue, extension_build_version);
auto ival = to_ivalue(inner_type, *sivp, extension_build_version);
delete sivp;
return ival;
}
@ -137,8 +156,10 @@ static c10::IValue to_ivalue(
class StableIValueBoxedKernel : public c10::OperatorKernel {
public:
StableIValueBoxedKernel(void (*fn)(StableIValue*, uint64_t, uint64_t))
: fn_(fn) {}
StableIValueBoxedKernel(
void (*fn)(StableIValue*, uint64_t, uint64_t),
uint64_t extension_build_version)
: fn_(fn), extension_build_version_(extension_build_version) {}
void operator()(
const c10::OperatorHandle& op,
@ -154,7 +175,8 @@ class StableIValueBoxedKernel : public c10::OperatorKernel {
for (const auto idx : c10::irange(num_arguments)) {
const auto ministack_idx = num_arguments - idx - 1;
const c10::TypePtr& arg_type = schema.arguments()[ministack_idx].type();
ministack[ministack_idx] = from_ivalue(arg_type, torch::jit::pop(stack));
ministack[ministack_idx] = from_ivalue(
arg_type, torch::jit::pop(stack), extension_build_version_);
}
// boxed function is going to take a stack of StableIValues, cast them to
@ -165,12 +187,14 @@ class StableIValueBoxedKernel : public c10::OperatorKernel {
// IValue from StableIValue
for (size_t idx = 0; idx < num_returns; idx++) {
const c10::TypePtr& ret_type = schema.returns()[idx].type();
torch::jit::push(stack, to_ivalue(ret_type, ministack[idx]));
torch::jit::push(
stack, to_ivalue(ret_type, ministack[idx], extension_build_version_));
}
}
private:
void (*fn_)(StableIValue*, uint64_t, uint64_t);
uint64_t extension_build_version_;
};
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_impl(
@ -181,7 +205,23 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_impl(
reinterpret_cast<torch::Library*>(self)->impl(
name,
torch::CppFunction::makeFromBoxedFunctor(
std::make_unique<StableIValueBoxedKernel>(fn)));
std::make_unique<StableIValueBoxedKernel>(fn, TORCH_ABI_VERSION)));
});
}
// Version-aware variant of aoti_torch_library_impl that takes an
// extension_build_version parameter for backward compatibility
AOTI_TORCH_EXPORT AOTITorchError torch_library_impl(
TorchLibraryHandle self,
const char* name,
void (*fn)(StableIValue*, uint64_t, uint64_t),
uint64_t extension_build_version) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
reinterpret_cast<torch::Library*>(self)->impl(
name,
torch::CppFunction::makeFromBoxedFunctor(
std::make_unique<StableIValueBoxedKernel>(
fn, extension_build_version)));
});
}
@ -204,7 +244,8 @@ AOTITorchError aoti_torch_call_dispatcher(
for (const auto idx : c10::irange(num_arguments)) {
auto stable_ivalue = stack[idx];
auto arg_type = schema.arguments()[idx].type();
torch::jit::push(ivalue_stack, to_ivalue(arg_type, stable_ivalue));
torch::jit::push(
ivalue_stack, to_ivalue(arg_type, stable_ivalue, TORCH_ABI_VERSION));
}
op.callBoxed(ivalue_stack);
@ -214,7 +255,8 @@ AOTITorchError aoti_torch_call_dispatcher(
for (const auto idx : c10::irange(num_returns)) {
const auto stack_idx = num_returns - idx - 1;
const c10::TypePtr& ret_type = schema.returns()[idx].type();
stack[stack_idx] = from_ivalue(ret_type, torch::jit::pop(ivalue_stack));
stack[stack_idx] = from_ivalue(
ret_type, torch::jit::pop(ivalue_stack), TORCH_ABI_VERSION);
}
});
}
@ -355,7 +397,9 @@ AOTI_TORCH_EXPORT AOTITorchError torch_call_dispatcher(
for (const auto idx : c10::irange(num_arguments)) {
auto stable_ivalue = stack[idx];
auto arg_type = schema.arguments()[idx].type();
torch::jit::push(ivalue_stack, to_ivalue(arg_type, stable_ivalue));
torch::jit::push(
ivalue_stack,
to_ivalue(arg_type, stable_ivalue, extension_build_version));
}
}
@ -366,7 +410,8 @@ AOTI_TORCH_EXPORT AOTITorchError torch_call_dispatcher(
for (const auto idx : c10::irange(num_returns)) {
const auto stack_idx = num_returns - idx - 1;
const c10::TypePtr& ret_type = schema.returns()[idx].type();
stack[stack_idx] = from_ivalue(ret_type, torch::jit::pop(ivalue_stack));
stack[stack_idx] = from_ivalue(
ret_type, torch::jit::pop(ivalue_stack), extension_build_version);
}
});
}

View File

@ -4,12 +4,14 @@
// code for better UX.
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/c/shim.h>
#include <torch/headeronly/macros/Macros.h>
// Technically, this file doesn't use anything from stableivalue_conversions.h,
// but we need to include it here as the contents of stableivalue_conversions.h
// used to live here and so we need to expose them for backwards compatibility.
#include <torch/csrc/stable/stableivalue_conversions.h>
#include <torch/csrc/stable/version.h>
HIDDEN_NAMESPACE_BEGIN(torch, stable, detail)
@ -81,7 +83,11 @@ class StableLibrary final {
StableLibrary& impl(
const char* name,
void (*fn)(StableIValue*, uint64_t, uint64_t)) {
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
torch_library_impl(lib_, name, fn, TORCH_ABI_VERSION);
#else
aoti_torch_library_impl(lib_, name, fn);
#endif
return *this;
}

View File

@ -24,12 +24,17 @@ T to(StableIValue val);
// =============================================================================
// =============================================================================
// FROM CONVERSIONS (T -> StableIValue)
// =============================================================================
// ======================================================================
// Specialization for general copyable types (catch-all) => StableIValue
template <typename T>
struct FromImpl {
static StableIValue call(T val) {
static StableIValue call(
T val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
static_assert(
sizeof(T) <= sizeof(StableIValue),
"StableLibrary stack does not support parameter types larger than 64 bits.");
@ -68,7 +73,12 @@ struct FromImpl {
using torch::headeronly::ScalarType;
template <>
struct FromImpl<ScalarType> {
static StableIValue call(ScalarType val) {
static StableIValue call(
ScalarType val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
switch (val) {
case ScalarType::Byte:
return from(aoti_torch_dtype_uint8());
@ -121,7 +131,12 @@ struct FromImpl<ScalarType> {
// Specialization for std::nullopt_t => StableIValue
template <>
struct FromImpl<std::nullopt_t> {
static StableIValue call(std::nullopt_t val) {
static StableIValue call(
std::nullopt_t val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
return from(nullptr);
}
};
@ -157,11 +172,15 @@ struct FromImpl<std::nullopt_t> {
// std::optional<T> or a std::nullopt.
template <typename T>
struct FromImpl<std::optional<T>> {
static StableIValue call(const std::optional<T>& val) {
static StableIValue call(
const std::optional<T>& val,
uint64_t extension_build_version,
bool is_internal) {
if (!val.has_value()) {
return from(std::nullopt);
}
return from(new StableIValue(from(val.value())));
return from(new StableIValue(detail::FromImpl<T>::call(
val.value(), extension_build_version, is_internal)));
}
};
@ -169,7 +188,12 @@ struct FromImpl<std::optional<T>> {
// Returns a new owning reference of the underlying Tensor.
template <>
struct FromImpl<torch::stable::Tensor> {
static StableIValue call(const torch::stable::Tensor& val) {
static StableIValue call(
const torch::stable::Tensor& val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
AtenTensorHandle new_ath;
TORCH_ERROR_CODE_CHECK(aoti_torch_new_tensor_handle(val.get(), &new_ath));
return from(new_ath);
@ -183,7 +207,12 @@ struct FromImpl<torch::stable::Tensor> {
// Specialization for StableIValue => general copyable types (catch-all)
template <typename T>
struct ToImpl {
static T call(StableIValue val) {
static T call(
StableIValue val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
static_assert(std::is_trivially_copyable_v<T>);
// T may not have a default constructor. (For example, it might be
// c10::Device.) However, std::memcpy implicitly creates a T at the
@ -218,7 +247,12 @@ struct ToImpl {
// Specialization for StableIValue => torch::headeronly::ScalarType
template <>
struct ToImpl<ScalarType> {
static ScalarType call(StableIValue val) {
static ScalarType call(
StableIValue val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
int32_t shim_scalartype = to<int32_t>(val);
if (shim_scalartype == aoti_torch_dtype_uint8()) {
return ScalarType::Byte;
@ -273,7 +307,12 @@ struct ToImpl<ScalarType> {
// Specialization for StableIValue => std::nullopt_t
template <>
struct ToImpl<std::nullopt_t> {
static std::nullopt_t call(StableIValue val) {
static std::nullopt_t call(
StableIValue val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
// val should be equivalent to from(nullptr)
return std::nullopt;
}
@ -284,14 +323,18 @@ struct ToImpl<std::nullopt_t> {
// from IValue --(from_ivalue)-> StableIValue --(to<T>)-> T in custom extension
template <typename T>
struct ToImpl<std::optional<T>> {
static std::optional<T> call(StableIValue val) {
static std::optional<T> call(
StableIValue val,
uint64_t extension_build_version,
bool is_internal) {
auto sivp = to<StableIValue*>(val);
// sivp is either nullptr or a pointer to a StableIValue
if (sivp == nullptr) {
return {};
}
auto inner_val = to<T>(*sivp);
auto inner_val =
detail::ToImpl<T>::call(*sivp, extension_build_version, is_internal);
// free the memory associated with StableIValue* sivp
delete sivp;
@ -305,7 +348,12 @@ struct ToImpl<std::optional<T>> {
// underlying AtenTensorHandle.
template <>
struct ToImpl<torch::stable::Tensor> {
static torch::stable::Tensor call(StableIValue val) {
static torch::stable::Tensor call(
StableIValue val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
return torch::stable::Tensor(to<AtenTensorHandle>(val));
}
};
@ -315,25 +363,60 @@ struct ToImpl<torch::stable::Tensor> {
// =============================================================================
// Expose the partially templated class functions through single functions
// The non-private versions will be used by the extension or headers that
// the extension includes.
template <typename T>
inline StableIValue from(T val) {
return detail::FromImpl<T>::call(val);
return detail::FromImpl<T>::call(
val, aoti_torch_abi_version(), /*is_internal=*/false);
}
template <typename T>
inline StableIValue from(const std::optional<T>& val) {
return detail::FromImpl<std::optional<T>>::call(val);
return detail::FromImpl<std::optional<T>>::call(
val, aoti_torch_abi_version(), /*is_internal=*/false);
}
// The below overload is used! See https://godbolt.org/z/859cshxrW
// We are suppressing the warning for versions clang12- and gcc11-
[[maybe_unused]] inline StableIValue from(const torch::stable::Tensor& val) {
return detail::FromImpl<torch::stable::Tensor>::call(val);
return detail::FromImpl<torch::stable::Tensor>::call(
val, aoti_torch_abi_version(), /*is_internal=*/false);
}
template <typename T>
inline T to(StableIValue val) {
return detail::ToImpl<T>::call(val);
return detail::ToImpl<T>::call(
val, aoti_torch_abi_version(), /*is_internal=*/false);
}
// Internal conversion functions used by from_ivalue and to_ivalue.
// These are used in libtorch
template <typename T>
inline StableIValue _from(T val, uint64_t extension_build_version) {
return detail::FromImpl<T>::call(
val, extension_build_version, /*is_internal=*/true);
}
template <typename T>
inline StableIValue _from(
const std::optional<T>& val,
uint64_t extension_build_version) {
return detail::FromImpl<std::optional<T>>::call(
val, extension_build_version, /*is_internal=*/true);
}
[[maybe_unused]] inline StableIValue _from(
const torch::stable::Tensor& val,
uint64_t extension_build_version) {
return detail::FromImpl<torch::stable::Tensor>::call(
val, extension_build_version, /*is_internal=*/true);
}
template <typename T>
inline T _to(StableIValue val, uint64_t extension_build_version) {
return detail::ToImpl<T>::call(
val, extension_build_version, /*is_internal=*/true);
}
HIDDEN_NAMESPACE_END(torch, stable, detail)