From eae701cad03feb87908107cb5246d9e13a9d426e Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Tue, 28 Oct 2025 13:12:02 -0700 Subject: [PATCH] Add scaffolding for StableIValue FC/BC (no PoC) (#164332) 1. Add `extension_build_version` and `is_internal` to `FromImpl`/`ToImpl` (this will be useful for future if we need to break the BC of any type) #163832 has the PoC of how we would actually use this system 2. Add `aoti_torch_library_impl_v2` that takes in an additional `extension_build_version` argument, updates callsite in `torch/csrc/stable/library.h` to always pass `TORCH_ABI_VERSION` for this argument 3. Add `extension_build_version` to `from_ivalue` and `to_ivalue` and update all callsites 4. Add a private `_from` and `_to` that pass `is_internal=True` to `FromImpl`/`ToImpl`, making it easier to reason about what is being called from libtorch-land / extension-land **Note: This PR does not include a linter that tells the user to update from/to if changing the ABI of a type in headeronly, which I intend to do in https://github.com/pytorch/pytorch/pull/163998** Pull Request resolved: https://github.com/pytorch/pytorch/pull/164332 Approved by: https://github.com/janeyx99 ghstack dependencies: #164356, #166373, #163683 --- torch/csrc/shim_common.cpp | 119 +++++++++++++------ torch/csrc/stable/library.h | 6 + torch/csrc/stable/stableivalue_conversions.h | 117 +++++++++++++++--- 3 files changed, 188 insertions(+), 54 deletions(-) diff --git a/torch/csrc/shim_common.cpp b/torch/csrc/shim_common.cpp index 6d0fe0e5d96..23effad1a36 100644 --- a/torch/csrc/shim_common.cpp +++ b/torch/csrc/shim_common.cpp @@ -12,33 +12,41 @@ static StableIValue from_ivalue( const c10::TypePtr& type, - const c10::IValue& ivalue) { + const c10::IValue& ivalue, + uint64_t extension_build_version) { switch (type->kind()) { case c10::TypeKind::TensorType: { AtenTensorHandle ath = torch::aot_inductor::new_tensor_handle( std::move(const_cast(ivalue.toTensor()))); - return torch::stable::detail::from(ath); + return torch::stable::detail::_from(ath, extension_build_version); } case c10::TypeKind::IntType: { - return torch::stable::detail::from(ivalue.toInt()); + return torch::stable::detail::_from( + ivalue.toInt(), extension_build_version); } case c10::TypeKind::FloatType: { - return torch::stable::detail::from(ivalue.toDouble()); + return torch::stable::detail::_from( + ivalue.toDouble(), extension_build_version); } case c10::TypeKind::BoolType: { - return torch::stable::detail::from(ivalue.toBool()); + return torch::stable::detail::_from( + ivalue.toBool(), extension_build_version); } case c10::TypeKind::ScalarTypeType: { - return torch::stable::detail::from(ivalue.toScalarType()); + return torch::stable::detail::_from( + ivalue.toScalarType(), extension_build_version); } case c10::TypeKind::DeviceObjType: { - return torch::stable::detail::from(ivalue.toDevice()); + return torch::stable::detail::_from( + ivalue.toDevice(), extension_build_version); } case c10::TypeKind::LayoutType: { - return torch::stable::detail::from(ivalue.toLayout()); + return torch::stable::detail::_from( + ivalue.toLayout(), extension_build_version); } case c10::TypeKind::MemoryFormatType: { - return torch::stable::detail::from(ivalue.toMemoryFormat()); + return torch::stable::detail::_from( + ivalue.toMemoryFormat(), extension_build_version); } case c10::TypeKind::OptionalType: { auto inner_type = type->castRaw()->getElementType(); @@ -56,10 +64,12 @@ static StableIValue from_ivalue( // be kept in sync with torch::stable::detail::from> // 
function in torch/csrc/stable/stableivalue_conversions.h if (ivalue.isNone()) { - return torch::stable::detail::from(std::nullopt); + return torch::stable::detail::_from( + std::nullopt, extension_build_version); } - StableIValue* sivp = new StableIValue(from_ivalue(inner_type, ivalue)); - return torch::stable::detail::from(sivp); + StableIValue* sivp = new StableIValue( + from_ivalue(inner_type, ivalue, extension_build_version)); + return torch::stable::detail::_from(sivp, extension_build_version); } default: { TORCH_CHECK( @@ -72,36 +82,43 @@ static StableIValue from_ivalue( static c10::IValue to_ivalue( const c10::TypePtr& type, - const StableIValue stable_ivalue) { + const StableIValue stable_ivalue, + uint64_t extension_build_version) { switch (type->kind()) { case c10::TypeKind::TensorType: { auto ret_raiiath = torch::aot_inductor::RAIIAtenTensorHandle( - torch::stable::detail::to(stable_ivalue)); + torch::stable::detail::_to( + stable_ivalue, extension_build_version)); return (c10::IValue(*torch::aot_inductor::tensor_handle_to_tensor_pointer( ret_raiiath.get()))); } case c10::TypeKind::IntType: { - return c10::IValue(torch::stable::detail::to(stable_ivalue)); + return c10::IValue(torch::stable::detail::_to( + stable_ivalue, extension_build_version)); } case c10::TypeKind::FloatType: { - return c10::IValue(torch::stable::detail::to(stable_ivalue)); + return c10::IValue(torch::stable::detail::_to( + stable_ivalue, extension_build_version)); } case c10::TypeKind::BoolType: { - return c10::IValue(torch::stable::detail::to(stable_ivalue)); + return c10::IValue(torch::stable::detail::_to( + stable_ivalue, extension_build_version)); } case c10::TypeKind::ScalarTypeType: { - return c10::IValue( - torch::stable::detail::to(stable_ivalue)); + return c10::IValue(torch::stable::detail::_to( + stable_ivalue, extension_build_version)); } case c10::TypeKind::DeviceObjType: { - return c10::IValue(torch::stable::detail::to(stable_ivalue)); + return c10::IValue(torch::stable::detail::_to( + stable_ivalue, extension_build_version)); } case c10::TypeKind::LayoutType: { - return c10::IValue(torch::stable::detail::to(stable_ivalue)); + return c10::IValue(torch::stable::detail::_to( + stable_ivalue, extension_build_version)); } case c10::TypeKind::MemoryFormatType: { - return c10::IValue( - torch::stable::detail::to(stable_ivalue)); + return c10::IValue(torch::stable::detail::_to( + stable_ivalue, extension_build_version)); } case c10::TypeKind::OptionalType: { auto inner_type = type->castRaw()->getElementType(); @@ -116,13 +133,15 @@ static c10::IValue to_ivalue( // // BUT we do NOT have that type inner_type::t readily available, so we // will manually unwrap and recursively call. 
This implementation MUST - // be kept in sync with the torch::stable::detail::to function in - // torch/csrc/stable/stableivalue_conversions.h - if (stable_ivalue == torch::stable::detail::from(std::nullopt)) { + // be kept in sync with the torch::stable::detail::_to function in + // torch/csrc/stable/library.h + if (stable_ivalue == + torch::stable::detail::_from(std::nullopt, extension_build_version)) { return c10::IValue(); } - auto sivp = torch::stable::detail::to(stable_ivalue); - auto ival = to_ivalue(inner_type, *sivp); + auto sivp = torch::stable::detail::_to( + stable_ivalue, extension_build_version); + auto ival = to_ivalue(inner_type, *sivp, extension_build_version); delete sivp; return ival; } @@ -137,8 +156,10 @@ static c10::IValue to_ivalue( class StableIValueBoxedKernel : public c10::OperatorKernel { public: - StableIValueBoxedKernel(void (*fn)(StableIValue*, uint64_t, uint64_t)) - : fn_(fn) {} + StableIValueBoxedKernel( + void (*fn)(StableIValue*, uint64_t, uint64_t), + uint64_t extension_build_version) + : fn_(fn), extension_build_version_(extension_build_version) {} void operator()( const c10::OperatorHandle& op, @@ -154,7 +175,8 @@ class StableIValueBoxedKernel : public c10::OperatorKernel { for (const auto idx : c10::irange(num_arguments)) { const auto ministack_idx = num_arguments - idx - 1; const c10::TypePtr& arg_type = schema.arguments()[ministack_idx].type(); - ministack[ministack_idx] = from_ivalue(arg_type, torch::jit::pop(stack)); + ministack[ministack_idx] = from_ivalue( + arg_type, torch::jit::pop(stack), extension_build_version_); } // boxed function is going to take a stack of StableIValues, cast them to @@ -165,12 +187,14 @@ class StableIValueBoxedKernel : public c10::OperatorKernel { // IValue from StableIValue for (size_t idx = 0; idx < num_returns; idx++) { const c10::TypePtr& ret_type = schema.returns()[idx].type(); - torch::jit::push(stack, to_ivalue(ret_type, ministack[idx])); + torch::jit::push( + stack, to_ivalue(ret_type, ministack[idx], extension_build_version_)); } } private: void (*fn_)(StableIValue*, uint64_t, uint64_t); + uint64_t extension_build_version_; }; AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_impl( @@ -181,7 +205,23 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_impl( reinterpret_cast(self)->impl( name, torch::CppFunction::makeFromBoxedFunctor( - std::make_unique(fn))); + std::make_unique(fn, TORCH_ABI_VERSION))); + }); +} + +// Version-aware variant of aoti_torch_library_impl that takes an +// extension_build_version parameter for backward compatibility +AOTI_TORCH_EXPORT AOTITorchError torch_library_impl( + TorchLibraryHandle self, + const char* name, + void (*fn)(StableIValue*, uint64_t, uint64_t), + uint64_t extension_build_version) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + reinterpret_cast(self)->impl( + name, + torch::CppFunction::makeFromBoxedFunctor( + std::make_unique( + fn, extension_build_version))); }); } @@ -204,7 +244,8 @@ AOTITorchError aoti_torch_call_dispatcher( for (const auto idx : c10::irange(num_arguments)) { auto stable_ivalue = stack[idx]; auto arg_type = schema.arguments()[idx].type(); - torch::jit::push(ivalue_stack, to_ivalue(arg_type, stable_ivalue)); + torch::jit::push( + ivalue_stack, to_ivalue(arg_type, stable_ivalue, TORCH_ABI_VERSION)); } op.callBoxed(ivalue_stack); @@ -214,7 +255,8 @@ AOTITorchError aoti_torch_call_dispatcher( for (const auto idx : c10::irange(num_returns)) { const auto stack_idx = num_returns - idx - 1; const c10::TypePtr& ret_type = 
schema.returns()[idx].type(); - stack[stack_idx] = from_ivalue(ret_type, torch::jit::pop(ivalue_stack)); + stack[stack_idx] = from_ivalue( + ret_type, torch::jit::pop(ivalue_stack), TORCH_ABI_VERSION); } }); } @@ -355,7 +397,9 @@ AOTI_TORCH_EXPORT AOTITorchError torch_call_dispatcher( for (const auto idx : c10::irange(num_arguments)) { auto stable_ivalue = stack[idx]; auto arg_type = schema.arguments()[idx].type(); - torch::jit::push(ivalue_stack, to_ivalue(arg_type, stable_ivalue)); + torch::jit::push( + ivalue_stack, + to_ivalue(arg_type, stable_ivalue, extension_build_version)); } } @@ -366,7 +410,8 @@ AOTI_TORCH_EXPORT AOTITorchError torch_call_dispatcher( for (const auto idx : c10::irange(num_returns)) { const auto stack_idx = num_returns - idx - 1; const c10::TypePtr& ret_type = schema.returns()[idx].type(); - stack[stack_idx] = from_ivalue(ret_type, torch::jit::pop(ivalue_stack)); + stack[stack_idx] = from_ivalue( + ret_type, torch::jit::pop(ivalue_stack), extension_build_version); } }); } diff --git a/torch/csrc/stable/library.h b/torch/csrc/stable/library.h index a2cb0eda9e7..61bc6d7249f 100644 --- a/torch/csrc/stable/library.h +++ b/torch/csrc/stable/library.h @@ -4,12 +4,14 @@ // code for better UX. #include +#include #include // Technically, this file doesn't use anything from stableivalue_conversions.h, // but we need to include it here as the contents of stableivalue_conversions.h // used to live here and so we need to expose them for backwards compatibility. #include +#include HIDDEN_NAMESPACE_BEGIN(torch, stable, detail) @@ -81,7 +83,11 @@ class StableLibrary final { StableLibrary& impl( const char* name, void (*fn)(StableIValue*, uint64_t, uint64_t)) { +#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0 + torch_library_impl(lib_, name, fn, TORCH_ABI_VERSION); +#else aoti_torch_library_impl(lib_, name, fn); +#endif return *this; } diff --git a/torch/csrc/stable/stableivalue_conversions.h b/torch/csrc/stable/stableivalue_conversions.h index 044d31d119c..f35ed50d99b 100644 --- a/torch/csrc/stable/stableivalue_conversions.h +++ b/torch/csrc/stable/stableivalue_conversions.h @@ -24,12 +24,17 @@ T to(StableIValue val); // ============================================================================= // ============================================================================= // FROM CONVERSIONS (T -> StableIValue) -// ============================================================================= +// ====================================================================== // Specialization for general copyable types (catch-all) => StableIValue template struct FromImpl { - static StableIValue call(T val) { + static StableIValue call( + T val, + uint64_t extension_build_version, + bool is_internal) { + (void)extension_build_version; // Unused parameter + (void)is_internal; // Unused parameter static_assert( sizeof(T) <= sizeof(StableIValue), "StableLibrary stack does not support parameter types larger than 64 bits."); @@ -68,7 +73,12 @@ struct FromImpl { using torch::headeronly::ScalarType; template <> struct FromImpl { - static StableIValue call(ScalarType val) { + static StableIValue call( + ScalarType val, + uint64_t extension_build_version, + bool is_internal) { + (void)extension_build_version; // Unused parameter + (void)is_internal; // Unused parameter switch (val) { case ScalarType::Byte: return from(aoti_torch_dtype_uint8()); @@ -121,7 +131,12 @@ struct FromImpl { // Specialization for std::nullopt_t => StableIValue template <> struct FromImpl { - static StableIValue 
call(std::nullopt_t val) { + static StableIValue call( + std::nullopt_t val, + uint64_t extension_build_version, + bool is_internal) { + (void)extension_build_version; // Unused parameter + (void)is_internal; // Unused parameter return from(nullptr); } }; @@ -157,11 +172,15 @@ struct FromImpl { // std::optional or a std::nullopt. template struct FromImpl> { - static StableIValue call(const std::optional& val) { + static StableIValue call( + const std::optional& val, + uint64_t extension_build_version, + bool is_internal) { if (!val.has_value()) { return from(std::nullopt); } - return from(new StableIValue(from(val.value()))); + return from(new StableIValue(detail::FromImpl::call( + val.value(), extension_build_version, is_internal))); } }; @@ -169,7 +188,12 @@ struct FromImpl> { // Returns a new owning reference of the underlying Tensor. template <> struct FromImpl { - static StableIValue call(const torch::stable::Tensor& val) { + static StableIValue call( + const torch::stable::Tensor& val, + uint64_t extension_build_version, + bool is_internal) { + (void)extension_build_version; // Unused parameter + (void)is_internal; // Unused parameter AtenTensorHandle new_ath; TORCH_ERROR_CODE_CHECK(aoti_torch_new_tensor_handle(val.get(), &new_ath)); return from(new_ath); @@ -183,7 +207,12 @@ struct FromImpl { // Specialization for StableIValue => general copyable types (catch-all) template struct ToImpl { - static T call(StableIValue val) { + static T call( + StableIValue val, + uint64_t extension_build_version, + bool is_internal) { + (void)extension_build_version; // Unused parameter + (void)is_internal; // Unused parameter static_assert(std::is_trivially_copyable_v); // T may not have a default constructor. (For example, it might be // c10::Device.) However, std::memcpy implicitly creates a T at the @@ -218,7 +247,12 @@ struct ToImpl { // Specialization for StableIValue => torch::headeronly::ScalarType template <> struct ToImpl { - static ScalarType call(StableIValue val) { + static ScalarType call( + StableIValue val, + uint64_t extension_build_version, + bool is_internal) { + (void)extension_build_version; // Unused parameter + (void)is_internal; // Unused parameter int32_t shim_scalartype = to(val); if (shim_scalartype == aoti_torch_dtype_uint8()) { return ScalarType::Byte; @@ -273,7 +307,12 @@ struct ToImpl { // Specialization for StableIValue => std::nullopt_t template <> struct ToImpl { - static std::nullopt_t call(StableIValue val) { + static std::nullopt_t call( + StableIValue val, + uint64_t extension_build_version, + bool is_internal) { + (void)extension_build_version; // Unused parameter + (void)is_internal; // Unused parameter // val should be equivalent to from(nullptr) return std::nullopt; } @@ -284,14 +323,18 @@ struct ToImpl { // from IValue --(from_ivalue)-> StableIValue --(to)-> T in custom extension template struct ToImpl> { - static std::optional call(StableIValue val) { + static std::optional call( + StableIValue val, + uint64_t extension_build_version, + bool is_internal) { auto sivp = to(val); // sivp is either nullptr or a pointer to a StableIValue if (sivp == nullptr) { return {}; } - auto inner_val = to(*sivp); + auto inner_val = + detail::ToImpl::call(*sivp, extension_build_version, is_internal); // free the memory associated with StableIValue* sivp delete sivp; @@ -305,7 +348,12 @@ struct ToImpl> { // underlying AtenTensorHandle. 
 template <>
 struct ToImpl<torch::stable::Tensor> {
-  static torch::stable::Tensor call(StableIValue val) {
+  static torch::stable::Tensor call(
+      StableIValue val,
+      uint64_t extension_build_version,
+      bool is_internal) {
+    (void)extension_build_version; // Unused parameter
+    (void)is_internal; // Unused parameter
     return torch::stable::Tensor(to<AtenTensorHandle>(val));
   }
 };
@@ -315,25 +363,60 @@ struct ToImpl<torch::stable::Tensor> {
 // =============================================================================
 
 // Expose the partially templated class functions through single functions
+// The non-private versions will be used by the extension or headers that
+// the extension includes.
 template <typename T>
 inline StableIValue from(T val) {
-  return detail::FromImpl<T>::call(val);
+  return detail::FromImpl<T>::call(
+      val, aoti_torch_abi_version(), /*is_internal=*/false);
 }
 
 template <typename T>
 inline StableIValue from(const std::optional<T>& val) {
-  return detail::FromImpl<std::optional<T>>::call(val);
+  return detail::FromImpl<std::optional<T>>::call(
+      val, aoti_torch_abi_version(), /*is_internal=*/false);
 }
 
 // The below overload is used! See https://godbolt.org/z/859cshxrW
 // We are suppressing the warning for versions clang12- and gcc11-
 [[maybe_unused]] inline StableIValue from(const torch::stable::Tensor& val) {
-  return detail::FromImpl<torch::stable::Tensor>::call(val);
+  return detail::FromImpl<torch::stable::Tensor>::call(
+      val, aoti_torch_abi_version(), /*is_internal=*/false);
 }
 
 template <typename T>
 inline T to(StableIValue val) {
-  return detail::ToImpl<T>::call(val);
+  return detail::ToImpl<T>::call(
+      val, aoti_torch_abi_version(), /*is_internal=*/false);
+}
+
+// Internal conversion functions used by from_ivalue and to_ivalue.
+// These are used in libtorch
+template <typename T>
+inline StableIValue _from(T val, uint64_t extension_build_version) {
+  return detail::FromImpl<T>::call(
+      val, extension_build_version, /*is_internal=*/true);
+}
+
+template <typename T>
+inline StableIValue _from(
+    const std::optional<T>& val,
+    uint64_t extension_build_version) {
+  return detail::FromImpl<std::optional<T>>::call(
+      val, extension_build_version, /*is_internal=*/true);
+}
+
+[[maybe_unused]] inline StableIValue _from(
+    const torch::stable::Tensor& val,
+    uint64_t extension_build_version) {
+  return detail::FromImpl<torch::stable::Tensor>::call(
+      val, extension_build_version, /*is_internal=*/true);
+}
+
+template <typename T>
+inline T _to(StableIValue val, uint64_t extension_build_version) {
+  return detail::ToImpl<T>::call(
+      val, extension_build_version, /*is_internal=*/true);
 }
 
 HIDDEN_NAMESPACE_END(torch, stable, detail)
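
**Example (illustrative, not part of this patch):** a minimal sketch of the extension-side calling convention that `StableIValueBoxedKernel` services. Only `StableIValue`, `torch::stable::detail::from`/`to`, `torch::stable::Tensor`, and the `impl(const char*, void (*)(StableIValue*, uint64_t, uint64_t))` signature come from the diff above; the `STABLE_TORCH_LIBRARY`/`STABLE_TORCH_LIBRARY_IMPL` macros and `def()` are assumed from the existing stable library API, and the op `my_ops::identity` is invented for illustration.

```cpp
// Extension-side sketch, assuming the stable library registration macros.
#include <cstdint>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>

using torch::stable::detail::from;
using torch::stable::detail::to;

// Boxed kernel: libtorch packs the arguments into `stack` via
// from_ivalue(..., extension_build_version) and unpacks the returns via
// to_ivalue(..., extension_build_version), so both sides agree on the
// StableIValue encoding this extension was built against.
void boxed_identity(
    StableIValue* stack,
    uint64_t num_args,
    uint64_t num_outputs) {
  (void)num_args; // 1 for this schema
  (void)num_outputs; // 1 for this schema
  // to<Tensor> takes ownership of the handle packed by libtorch ...
  torch::stable::Tensor t = to<torch::stable::Tensor>(stack[0]);
  // ... and from(...) creates a new owning handle for the return slot.
  stack[0] = from(t);
}

STABLE_TORCH_LIBRARY(my_ops, m) {
  m.def("identity(Tensor t) -> Tensor");
}

STABLE_TORCH_LIBRARY_IMPL(my_ops, CompositeExplicitAutograd, m) {
  // With TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0, impl() now routes
  // through torch_library_impl(lib, "identity", fn, TORCH_ABI_VERSION).
  m.impl("identity", &boxed_identity);
}
```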
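
**Example (hypothetical, not the PoC from #163832):** a sketch of how a `FromImpl` specialization could use the new `extension_build_version`/`is_internal` arguments if the `StableIValue` encoding of some type ever has to change. `MyType`, the encoders, and the version constant are invented; only the `FromImpl::call` signature comes from this PR, and the specialization would sit next to the others inside `torch::stable::detail`.

```cpp
// Invented type and encodings, for illustration only.
struct MyType {
  int32_t a;
  int32_t b;
};

// v1 packed only `a`; v2 packs both fields into the 64-bit payload.
inline uint64_t encode_my_type_v1(MyType v) {
  return static_cast<uint32_t>(v.a);
}
inline uint64_t encode_my_type_v2(MyType v) {
  return (static_cast<uint64_t>(static_cast<uint32_t>(v.a)) << 32) |
      static_cast<uint32_t>(v.b);
}

// Placeholder for the TORCH_ABI_VERSION at which the encoding changed.
constexpr uint64_t kVersionWithPackedMyType = 0;

template <>
struct FromImpl<MyType> {
  static StableIValue call(
      MyType val,
      uint64_t extension_build_version,
      bool is_internal) {
    // is_internal distinguishes libtorch-side (from_ivalue/to_ivalue) calls
    // from extension-side ones, should the two ever need different handling.
    (void)is_internal;
    if (extension_build_version < kVersionWithPackedMyType) {
      // Old encoding, kept for extensions built against older ABI versions.
      return from(encode_my_type_v1(val));
    }
    return from(encode_my_type_v2(val));
  }
};
```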