Add scaffolding for StableIValue FC/BC (no PoC) (#164332)

1. Add `extension_build_version` and `is_internal` to `FromImpl`/`ToImpl` (this will be useful in the future if we need to break the BC of any type; a sketch of such a break follows this list). #163832 has a PoC of how we would actually use this system
2. Add `torch_library_impl`, a version-aware v2 of `aoti_torch_library_impl` that takes an additional `extension_build_version` argument, and update the callsite in `torch/csrc/stable/library.h` to always pass `TORCH_ABI_VERSION` for this argument
3. Add `extension_build_version` to `from_ivalue` and `to_ivalue` and update all callsites
4. Add a private `_from` and `_to` that pass `is_internal=true` to `FromImpl`/`ToImpl`, making it easier to reason about what is being called from libtorch-land vs. extension-land
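
A minimal standalone sketch of how item 1's scaffolding could be used once a type's ABI actually breaks: a `FromImpl` specialization branches on `extension_build_version` to keep emitting the old encoding for extensions built against an older ABI. `MyType`, `kMyTypeV2Version`, and both encodings are hypothetical and not part of this PR (#163832 has the real PoC):

```cpp
#include <cstdint>

using StableIValue = uint64_t;

// Hypothetical: first ABI version that uses MyType's new encoding.
constexpr uint64_t kMyTypeV2Version = 2;

struct MyType {
  int32_t payload;
};

template <typename T>
struct FromImpl;

// Hypothetical versioned specialization: emit the encoding the
// extension that we are talking to actually understands.
template <>
struct FromImpl<MyType> {
  static StableIValue call(
      MyType val,
      uint64_t extension_build_version,
      bool is_internal) {
    (void)is_internal;
    if (extension_build_version < kMyTypeV2Version) {
      // Old wire format: payload stored in the high 32 bits.
      return static_cast<uint64_t>(static_cast<uint32_t>(val.payload)) << 32;
    }
    // New wire format: payload stored in the low 32 bits.
    return static_cast<uint64_t>(static_cast<uint32_t>(val.payload));
  }
};
```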

**Note: This PR does not include a linter that tells the user to update from/to if changing the ABI of a type in headeronly, which I intend to do in https://github.com/pytorch/pytorch/pull/163998**

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164332
Approved by: https://github.com/janeyx99
ghstack dependencies: #164356, #166373, #163683
Mikayla Gawarecki 2025-10-28 13:12:02 -07:00 committed by PyTorch MergeBot
parent 8f51556daa
commit eae701cad0
3 changed files with 188 additions and 54 deletions

File 1 of 3: the C shim implementation (`from_ivalue`/`to_ivalue` and the library-impl entry points)

@@ -12,33 +12,41 @@
 static StableIValue from_ivalue(
     const c10::TypePtr& type,
-    const c10::IValue& ivalue) {
+    const c10::IValue& ivalue,
+    uint64_t extension_build_version) {
   switch (type->kind()) {
     case c10::TypeKind::TensorType: {
      AtenTensorHandle ath = torch::aot_inductor::new_tensor_handle(
          std::move(const_cast<at::Tensor&>(ivalue.toTensor())));
-      return torch::stable::detail::from(ath);
+      return torch::stable::detail::_from(ath, extension_build_version);
    }
    case c10::TypeKind::IntType: {
-      return torch::stable::detail::from(ivalue.toInt());
+      return torch::stable::detail::_from(
+          ivalue.toInt(), extension_build_version);
    }
    case c10::TypeKind::FloatType: {
-      return torch::stable::detail::from(ivalue.toDouble());
+      return torch::stable::detail::_from(
+          ivalue.toDouble(), extension_build_version);
    }
    case c10::TypeKind::BoolType: {
-      return torch::stable::detail::from(ivalue.toBool());
+      return torch::stable::detail::_from(
+          ivalue.toBool(), extension_build_version);
    }
    case c10::TypeKind::ScalarTypeType: {
-      return torch::stable::detail::from(ivalue.toScalarType());
+      return torch::stable::detail::_from(
+          ivalue.toScalarType(), extension_build_version);
    }
    case c10::TypeKind::DeviceObjType: {
-      return torch::stable::detail::from(ivalue.toDevice());
+      return torch::stable::detail::_from(
+          ivalue.toDevice(), extension_build_version);
    }
    case c10::TypeKind::LayoutType: {
-      return torch::stable::detail::from(ivalue.toLayout());
+      return torch::stable::detail::_from(
+          ivalue.toLayout(), extension_build_version);
    }
    case c10::TypeKind::MemoryFormatType: {
-      return torch::stable::detail::from(ivalue.toMemoryFormat());
+      return torch::stable::detail::_from(
+          ivalue.toMemoryFormat(), extension_build_version);
    }
    case c10::TypeKind::OptionalType: {
      auto inner_type = type->castRaw<at::OptionalType>()->getElementType();
@@ -56,10 +64,12 @@ static StableIValue from_ivalue(
      // be kept in sync with torch::stable::detail::from<std::optional<T>>
      // function in torch/csrc/stable/stableivalue_conversions.h
      if (ivalue.isNone()) {
-        return torch::stable::detail::from(std::nullopt);
+        return torch::stable::detail::_from(
+            std::nullopt, extension_build_version);
      }
-      StableIValue* sivp = new StableIValue(from_ivalue(inner_type, ivalue));
-      return torch::stable::detail::from(sivp);
+      StableIValue* sivp = new StableIValue(
+          from_ivalue(inner_type, ivalue, extension_build_version));
+      return torch::stable::detail::_from(sivp, extension_build_version);
    }
    default: {
      TORCH_CHECK(
@@ -72,36 +82,43 @@ static StableIValue from_ivalue(
 static c10::IValue to_ivalue(
     const c10::TypePtr& type,
-    const StableIValue stable_ivalue) {
+    const StableIValue stable_ivalue,
+    uint64_t extension_build_version) {
   switch (type->kind()) {
    case c10::TypeKind::TensorType: {
      auto ret_raiiath = torch::aot_inductor::RAIIAtenTensorHandle(
-          torch::stable::detail::to<AtenTensorHandle>(stable_ivalue));
+          torch::stable::detail::_to<AtenTensorHandle>(
+              stable_ivalue, extension_build_version));
      return (c10::IValue(*torch::aot_inductor::tensor_handle_to_tensor_pointer(
          ret_raiiath.get())));
    }
    case c10::TypeKind::IntType: {
-      return c10::IValue(torch::stable::detail::to<int64_t>(stable_ivalue));
+      return c10::IValue(torch::stable::detail::_to<int64_t>(
+          stable_ivalue, extension_build_version));
    }
    case c10::TypeKind::FloatType: {
-      return c10::IValue(torch::stable::detail::to<double>(stable_ivalue));
+      return c10::IValue(torch::stable::detail::_to<double>(
+          stable_ivalue, extension_build_version));
    }
    case c10::TypeKind::BoolType: {
-      return c10::IValue(torch::stable::detail::to<bool>(stable_ivalue));
+      return c10::IValue(torch::stable::detail::_to<bool>(
+          stable_ivalue, extension_build_version));
    }
    case c10::TypeKind::ScalarTypeType: {
-      return c10::IValue(
-          torch::stable::detail::to<c10::ScalarType>(stable_ivalue));
+      return c10::IValue(torch::stable::detail::_to<c10::ScalarType>(
+          stable_ivalue, extension_build_version));
    }
    case c10::TypeKind::DeviceObjType: {
-      return c10::IValue(torch::stable::detail::to<c10::Device>(stable_ivalue));
+      return c10::IValue(torch::stable::detail::_to<c10::Device>(
+          stable_ivalue, extension_build_version));
    }
    case c10::TypeKind::LayoutType: {
-      return c10::IValue(torch::stable::detail::to<c10::Layout>(stable_ivalue));
+      return c10::IValue(torch::stable::detail::_to<c10::Layout>(
+          stable_ivalue, extension_build_version));
    }
    case c10::TypeKind::MemoryFormatType: {
-      return c10::IValue(
-          torch::stable::detail::to<c10::MemoryFormat>(stable_ivalue));
+      return c10::IValue(torch::stable::detail::_to<c10::MemoryFormat>(
+          stable_ivalue, extension_build_version));
    }
    case c10::TypeKind::OptionalType: {
      auto inner_type = type->castRaw<at::OptionalType>()->getElementType();
@@ -116,13 +133,15 @@ static c10::IValue to_ivalue(
      //
      // BUT we do NOT have that type inner_type::t readily available, so we
      // will manually unwrap and recursively call. This implementation MUST
-      // be kept in sync with the torch::stable::detail::to<T> function in
-      // torch/csrc/stable/stableivalue_conversions.h
-      if (stable_ivalue == torch::stable::detail::from(std::nullopt)) {
+      // be kept in sync with the torch::stable::detail::_to<T> function in
+      // torch/csrc/stable/library.h
+      if (stable_ivalue ==
+          torch::stable::detail::_from(std::nullopt, extension_build_version)) {
        return c10::IValue();
      }
-      auto sivp = torch::stable::detail::to<StableIValue*>(stable_ivalue);
-      auto ival = to_ivalue(inner_type, *sivp);
+      auto sivp = torch::stable::detail::_to<StableIValue*>(
+          stable_ivalue, extension_build_version);
+      auto ival = to_ivalue(inner_type, *sivp, extension_build_version);
      delete sivp;
      return ival;
    }
@@ -137,8 +156,10 @@ static c10::IValue to_ivalue(
 class StableIValueBoxedKernel : public c10::OperatorKernel {
  public:
-  StableIValueBoxedKernel(void (*fn)(StableIValue*, uint64_t, uint64_t))
-      : fn_(fn) {}
+  StableIValueBoxedKernel(
+      void (*fn)(StableIValue*, uint64_t, uint64_t),
+      uint64_t extension_build_version)
+      : fn_(fn), extension_build_version_(extension_build_version) {}
 
   void operator()(
       const c10::OperatorHandle& op,
@@ -154,7 +175,8 @@ class StableIValueBoxedKernel : public c10::OperatorKernel {
    for (const auto idx : c10::irange(num_arguments)) {
      const auto ministack_idx = num_arguments - idx - 1;
      const c10::TypePtr& arg_type = schema.arguments()[ministack_idx].type();
-      ministack[ministack_idx] = from_ivalue(arg_type, torch::jit::pop(stack));
+      ministack[ministack_idx] = from_ivalue(
+          arg_type, torch::jit::pop(stack), extension_build_version_);
    }
 
    // boxed function is going to take a stack of StableIValues, cast them to
@@ -165,12 +187,14 @@ class StableIValueBoxedKernel : public c10::OperatorKernel {
    // IValue from StableIValue
    for (size_t idx = 0; idx < num_returns; idx++) {
      const c10::TypePtr& ret_type = schema.returns()[idx].type();
-      torch::jit::push(stack, to_ivalue(ret_type, ministack[idx]));
+      torch::jit::push(
+          stack, to_ivalue(ret_type, ministack[idx], extension_build_version_));
    }
  }
 
  private:
   void (*fn_)(StableIValue*, uint64_t, uint64_t);
+  uint64_t extension_build_version_;
 };
 
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_impl(
@@ -181,7 +205,23 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_impl(
    reinterpret_cast<torch::Library*>(self)->impl(
        name,
        torch::CppFunction::makeFromBoxedFunctor(
-            std::make_unique<StableIValueBoxedKernel>(fn)));
+            std::make_unique<StableIValueBoxedKernel>(fn, TORCH_ABI_VERSION)));
+  });
+}
+
+// Version-aware variant of aoti_torch_library_impl that takes an
+// extension_build_version parameter for backward compatibility
+AOTI_TORCH_EXPORT AOTITorchError torch_library_impl(
+    TorchLibraryHandle self,
+    const char* name,
+    void (*fn)(StableIValue*, uint64_t, uint64_t),
+    uint64_t extension_build_version) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
+    reinterpret_cast<torch::Library*>(self)->impl(
+        name,
+        torch::CppFunction::makeFromBoxedFunctor(
+            std::make_unique<StableIValueBoxedKernel>(
+                fn, extension_build_version)));
   });
 }
@@ -204,7 +244,8 @@ AOTITorchError aoti_torch_call_dispatcher(
    for (const auto idx : c10::irange(num_arguments)) {
      auto stable_ivalue = stack[idx];
      auto arg_type = schema.arguments()[idx].type();
-      torch::jit::push(ivalue_stack, to_ivalue(arg_type, stable_ivalue));
+      torch::jit::push(
+          ivalue_stack, to_ivalue(arg_type, stable_ivalue, TORCH_ABI_VERSION));
    }
 
    op.callBoxed(ivalue_stack);
@@ -214,7 +255,8 @@ AOTITorchError aoti_torch_call_dispatcher(
    for (const auto idx : c10::irange(num_returns)) {
      const auto stack_idx = num_returns - idx - 1;
      const c10::TypePtr& ret_type = schema.returns()[idx].type();
-      stack[stack_idx] = from_ivalue(ret_type, torch::jit::pop(ivalue_stack));
+      stack[stack_idx] = from_ivalue(
+          ret_type, torch::jit::pop(ivalue_stack), TORCH_ABI_VERSION);
    }
  });
}
@@ -355,7 +397,9 @@ AOTI_TORCH_EXPORT AOTITorchError torch_call_dispatcher(
    for (const auto idx : c10::irange(num_arguments)) {
      auto stable_ivalue = stack[idx];
      auto arg_type = schema.arguments()[idx].type();
-      torch::jit::push(ivalue_stack, to_ivalue(arg_type, stable_ivalue));
+      torch::jit::push(
+          ivalue_stack,
+          to_ivalue(arg_type, stable_ivalue, extension_build_version));
    }
  }
@@ -366,7 +410,8 @@ AOTI_TORCH_EXPORT AOTITorchError torch_call_dispatcher(
    for (const auto idx : c10::irange(num_returns)) {
      const auto stack_idx = num_returns - idx - 1;
      const c10::TypePtr& ret_type = schema.returns()[idx].type();
-      stack[stack_idx] = from_ivalue(ret_type, torch::jit::pop(ivalue_stack));
+      stack[stack_idx] = from_ivalue(
+          ret_type, torch::jit::pop(ivalue_stack), extension_build_version);
    }
  });
}
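
To make the plumbing above concrete, here is a hedged sketch of an extension-side boxed kernel matching the `void (*fn)(StableIValue*, uint64_t, uint64_t)` pointer type that `StableIValueBoxedKernel` wraps: the shim converts the op's arguments into `num_args` StableIValues on `stack`, calls `fn`, then converts the leading `num_returns` slots back into IValues. The kernel name and its assumed `(int x) -> int` schema are illustrative, not part of this PR:

```cpp
#include <cstdint>
#include <torch/csrc/stable/stableivalue_conversions.h>

// Hypothetical kernel for a schema like "myops::add_one(int x) -> int".
void my_add_one_kernel(
    StableIValue* stack,
    uint64_t num_args,
    uint64_t num_returns) {
  (void)num_args;     // 1 for this schema
  (void)num_returns;  // 1 for this schema
  // to/from here are the public, extension-facing conversions, which after
  // this PR stamp in the extension's own ABI version under the hood.
  int64_t x = torch::stable::detail::to<int64_t>(stack[0]);
  stack[0] = torch::stable::detail::from(x + 1);
}
```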

File 2 of 3: torch/csrc/stable/library.h

@@ -4,12 +4,14 @@
 // code for better UX.
 
 #include <torch/csrc/inductor/aoti_torch/c/shim.h>
+#include <torch/csrc/stable/c/shim.h>
 #include <torch/headeronly/macros/Macros.h>
 
 // Technically, this file doesn't use anything from stableivalue_conversions.h,
 // but we need to include it here as the contents of stableivalue_conversions.h
 // used to live here and so we need to expose them for backwards compatibility.
 #include <torch/csrc/stable/stableivalue_conversions.h>
+#include <torch/csrc/stable/version.h>
 
 HIDDEN_NAMESPACE_BEGIN(torch, stable, detail)
@@ -81,7 +83,11 @@ class StableLibrary final {
   StableLibrary& impl(
       const char* name,
       void (*fn)(StableIValue*, uint64_t, uint64_t)) {
+#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
+    torch_library_impl(lib_, name, fn, TORCH_ABI_VERSION);
+#else
    aoti_torch_library_impl(lib_, name, fn);
+#endif
    return *this;
  }
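
A usage sketch for the gated `impl` above, assuming the `STABLE_TORCH_LIBRARY_IMPL` registration macro built on `StableLibrary::impl` (the macro name is an assumption; it is not shown in this diff). When the extension is compiled with `TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0`, the registration reaches `torch_library_impl` and records `TORCH_ABI_VERSION`; otherwise it falls back to the legacy `aoti_torch_library_impl`:

```cpp
#include <torch/csrc/stable/library.h>

// my_add_one_kernel is the hypothetical boxed kernel sketched earlier.
void my_add_one_kernel(StableIValue*, uint64_t, uint64_t);

STABLE_TORCH_LIBRARY_IMPL(myops, CompositeExplicitAutograd, m) {
  m.impl("add_one", &my_add_one_kernel);
}
```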

File 3 of 3: torch/csrc/stable/stableivalue_conversions.h

@@ -24,12 +24,17 @@ T to(StableIValue val);
 // =============================================================================
 // =============================================================================
 //  FROM CONVERSIONS (T -> StableIValue)
 // =============================================================================
 
 // Specialization for general copyable types (catch-all) => StableIValue
 template <typename T>
 struct FromImpl {
-  static StableIValue call(T val) {
+  static StableIValue call(
+      T val,
+      uint64_t extension_build_version,
+      bool is_internal) {
+    (void)extension_build_version; // Unused parameter
+    (void)is_internal; // Unused parameter
    static_assert(
        sizeof(T) <= sizeof(StableIValue),
        "StableLibrary stack does not support parameter types larger than 64 bits.");
@@ -68,7 +73,12 @@ struct FromImpl {
 using torch::headeronly::ScalarType;
 template <>
 struct FromImpl<ScalarType> {
-  static StableIValue call(ScalarType val) {
+  static StableIValue call(
+      ScalarType val,
+      uint64_t extension_build_version,
+      bool is_internal) {
+    (void)extension_build_version; // Unused parameter
+    (void)is_internal; // Unused parameter
    switch (val) {
      case ScalarType::Byte:
        return from(aoti_torch_dtype_uint8());
@@ -121,7 +131,12 @@ struct FromImpl<ScalarType> {
 // Specialization for std::nullopt_t => StableIValue
 template <>
 struct FromImpl<std::nullopt_t> {
-  static StableIValue call(std::nullopt_t val) {
+  static StableIValue call(
+      std::nullopt_t val,
+      uint64_t extension_build_version,
+      bool is_internal) {
+    (void)extension_build_version; // Unused parameter
+    (void)is_internal; // Unused parameter
    return from(nullptr);
  }
 };
@@ -157,11 +172,15 @@ struct FromImpl<std::nullopt_t> {
 // std::optional<T> or a std::nullopt.
 template <typename T>
 struct FromImpl<std::optional<T>> {
-  static StableIValue call(const std::optional<T>& val) {
+  static StableIValue call(
+      const std::optional<T>& val,
+      uint64_t extension_build_version,
+      bool is_internal) {
    if (!val.has_value()) {
      return from(std::nullopt);
    }
-    return from(new StableIValue(from(val.value())));
+    return from(new StableIValue(detail::FromImpl<T>::call(
+        val.value(), extension_build_version, is_internal)));
  }
 };
@@ -169,7 +188,12 @@ struct FromImpl<std::optional<T>> {
 // Returns a new owning reference of the underlying Tensor.
 template <>
 struct FromImpl<torch::stable::Tensor> {
-  static StableIValue call(const torch::stable::Tensor& val) {
+  static StableIValue call(
+      const torch::stable::Tensor& val,
+      uint64_t extension_build_version,
+      bool is_internal) {
+    (void)extension_build_version; // Unused parameter
+    (void)is_internal; // Unused parameter
    AtenTensorHandle new_ath;
    TORCH_ERROR_CODE_CHECK(aoti_torch_new_tensor_handle(val.get(), &new_ath));
    return from(new_ath);
@@ -183,7 +207,12 @@ struct FromImpl<torch::stable::Tensor> {
 // Specialization for StableIValue => general copyable types (catch-all)
 template <typename T>
 struct ToImpl {
-  static T call(StableIValue val) {
+  static T call(
+      StableIValue val,
+      uint64_t extension_build_version,
+      bool is_internal) {
+    (void)extension_build_version; // Unused parameter
+    (void)is_internal; // Unused parameter
    static_assert(std::is_trivially_copyable_v<T>);
    // T may not have a default constructor. (For example, it might be
    // c10::Device.) However, std::memcpy implicitly creates a T at the
@@ -218,7 +247,12 @@ struct ToImpl {
 // Specialization for StableIValue => torch::headeronly::ScalarType
 template <>
 struct ToImpl<ScalarType> {
-  static ScalarType call(StableIValue val) {
+  static ScalarType call(
+      StableIValue val,
+      uint64_t extension_build_version,
+      bool is_internal) {
+    (void)extension_build_version; // Unused parameter
+    (void)is_internal; // Unused parameter
    int32_t shim_scalartype = to<int32_t>(val);
    if (shim_scalartype == aoti_torch_dtype_uint8()) {
      return ScalarType::Byte;
@@ -273,7 +307,12 @@ struct ToImpl<ScalarType> {
 // Specialization for StableIValue => std::nullopt_t
 template <>
 struct ToImpl<std::nullopt_t> {
-  static std::nullopt_t call(StableIValue val) {
+  static std::nullopt_t call(
+      StableIValue val,
+      uint64_t extension_build_version,
+      bool is_internal) {
+    (void)extension_build_version; // Unused parameter
+    (void)is_internal; // Unused parameter
    // val should be equivalent to from(nullptr)
    return std::nullopt;
  }
@@ -284,14 +323,18 @@ struct ToImpl<std::nullopt_t> {
 // from IValue --(from_ivalue)-> StableIValue --(to<T>)-> T in custom extension
 template <typename T>
 struct ToImpl<std::optional<T>> {
-  static std::optional<T> call(StableIValue val) {
+  static std::optional<T> call(
+      StableIValue val,
+      uint64_t extension_build_version,
+      bool is_internal) {
    auto sivp = to<StableIValue*>(val);
 
    // sivp is either nullptr or a pointer to a StableIValue
    if (sivp == nullptr) {
      return {};
    }
-    auto inner_val = to<T>(*sivp);
+    auto inner_val =
+        detail::ToImpl<T>::call(*sivp, extension_build_version, is_internal);
 
    // free the memory associated with StableIValue* sivp
    delete sivp;
@@ -305,7 +348,12 @@ struct ToImpl<std::optional<T>> {
 // underlying AtenTensorHandle.
 template <>
 struct ToImpl<torch::stable::Tensor> {
-  static torch::stable::Tensor call(StableIValue val) {
+  static torch::stable::Tensor call(
+      StableIValue val,
+      uint64_t extension_build_version,
+      bool is_internal) {
+    (void)extension_build_version; // Unused parameter
+    (void)is_internal; // Unused parameter
    return torch::stable::Tensor(to<AtenTensorHandle>(val));
  }
 };
@@ -315,25 +363,60 @@ struct ToImpl<torch::stable::Tensor> {
 // =============================================================================
 //  Expose the partially templated class functions through single functions
+// The non-private versions will be used by the extension or headers that
+// the extension includes.
 template <typename T>
 inline StableIValue from(T val) {
-  return detail::FromImpl<T>::call(val);
+  return detail::FromImpl<T>::call(
+      val, aoti_torch_abi_version(), /*is_internal=*/false);
 }
 
 template <typename T>
 inline StableIValue from(const std::optional<T>& val) {
-  return detail::FromImpl<std::optional<T>>::call(val);
+  return detail::FromImpl<std::optional<T>>::call(
+      val, aoti_torch_abi_version(), /*is_internal=*/false);
 }
 
 // The below overload is used! See https://godbolt.org/z/859cshxrW
 // We are suppressing the warning for versions clang12- and gcc11-
 [[maybe_unused]] inline StableIValue from(const torch::stable::Tensor& val) {
-  return detail::FromImpl<torch::stable::Tensor>::call(val);
+  return detail::FromImpl<torch::stable::Tensor>::call(
+      val, aoti_torch_abi_version(), /*is_internal=*/false);
 }
 
 template <typename T>
 inline T to(StableIValue val) {
-  return detail::ToImpl<T>::call(val);
+  return detail::ToImpl<T>::call(
+      val, aoti_torch_abi_version(), /*is_internal=*/false);
+}
+
+// Internal conversion functions used by from_ivalue and to_ivalue.
+// These are used in libtorch
+template <typename T>
+inline StableIValue _from(T val, uint64_t extension_build_version) {
+  return detail::FromImpl<T>::call(
+      val, extension_build_version, /*is_internal=*/true);
+}
+
+template <typename T>
+inline StableIValue _from(
+    const std::optional<T>& val,
+    uint64_t extension_build_version) {
+  return detail::FromImpl<std::optional<T>>::call(
+      val, extension_build_version, /*is_internal=*/true);
+}
+
+[[maybe_unused]] inline StableIValue _from(
+    const torch::stable::Tensor& val,
+    uint64_t extension_build_version) {
+  return detail::FromImpl<torch::stable::Tensor>::call(
+      val, extension_build_version, /*is_internal=*/true);
+}
+
+template <typename T>
+inline T _to(StableIValue val, uint64_t extension_build_version) {
+  return detail::ToImpl<T>::call(
+      val, extension_build_version, /*is_internal=*/true);
 }
 
 HIDDEN_NAMESPACE_END(torch, stable, detail)
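
A minimal sketch contrasting the two entry points defined above, using only names from this diff. Extension code keeps calling the public `from`/`to`, which now supply the extension's own `aoti_torch_abi_version()` with `is_internal=false`; libtorch code (`from_ivalue`/`to_ivalue`) calls `_from`/`_to`, threading through the `extension_build_version` recorded at registration with `is_internal=true`:

```cpp
#include <cstdint>
#include <torch/csrc/stable/stableivalue_conversions.h>

void round_trip_example(uint64_t extension_build_version) {
  using namespace torch::stable::detail;

  // Extension-land: version supplied implicitly, is_internal=false.
  StableIValue a = from(int64_t{42});
  int64_t x = to<int64_t>(a);

  // libtorch-land: version threaded explicitly, is_internal=true.
  StableIValue b = _from(int64_t{42}, extension_build_version);
  int64_t y = _to<int64_t>(b, extension_build_version);

  (void)x;
  (void)y;
}
```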