mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add scaffolding for StableIValue FC/BC (no PoC) (#164332)
1. Add `extension_build_version` and `is_internal` to `FromImpl`/`ToImpl` (this will be useful for future if we need to break the BC of any type) #163832 has the PoC of how we would actually use this system 2. Add `aoti_torch_library_impl_v2` that takes in an additional `extension_build_version` argument, updates callsite in `torch/csrc/stable/library.h` to always pass `TORCH_ABI_VERSION` for this argument 3. Add `extension_build_version` to `from_ivalue` and `to_ivalue` and update all callsites 4. Add a private `_from` and `_to` that pass `is_internal=True` to `FromImpl`/`ToImpl`, making it easier to reason about what is being called from libtorch-land / extension-land **Note: This PR does not include a linter that tells the user to update from/to if changing the ABI of a type in headeronly, which I intend to do in https://github.com/pytorch/pytorch/pull/163998** Pull Request resolved: https://github.com/pytorch/pytorch/pull/164332 Approved by: https://github.com/janeyx99 ghstack dependencies: #164356, #166373, #163683
This commit is contained in:
parent
8f51556daa
commit
eae701cad0
|
|
@ -12,33 +12,41 @@
|
|||
|
||||
static StableIValue from_ivalue(
|
||||
const c10::TypePtr& type,
|
||||
const c10::IValue& ivalue) {
|
||||
const c10::IValue& ivalue,
|
||||
uint64_t extension_build_version) {
|
||||
switch (type->kind()) {
|
||||
case c10::TypeKind::TensorType: {
|
||||
AtenTensorHandle ath = torch::aot_inductor::new_tensor_handle(
|
||||
std::move(const_cast<at::Tensor&>(ivalue.toTensor())));
|
||||
return torch::stable::detail::from(ath);
|
||||
return torch::stable::detail::_from(ath, extension_build_version);
|
||||
}
|
||||
case c10::TypeKind::IntType: {
|
||||
return torch::stable::detail::from(ivalue.toInt());
|
||||
return torch::stable::detail::_from(
|
||||
ivalue.toInt(), extension_build_version);
|
||||
}
|
||||
case c10::TypeKind::FloatType: {
|
||||
return torch::stable::detail::from(ivalue.toDouble());
|
||||
return torch::stable::detail::_from(
|
||||
ivalue.toDouble(), extension_build_version);
|
||||
}
|
||||
case c10::TypeKind::BoolType: {
|
||||
return torch::stable::detail::from(ivalue.toBool());
|
||||
return torch::stable::detail::_from(
|
||||
ivalue.toBool(), extension_build_version);
|
||||
}
|
||||
case c10::TypeKind::ScalarTypeType: {
|
||||
return torch::stable::detail::from(ivalue.toScalarType());
|
||||
return torch::stable::detail::_from(
|
||||
ivalue.toScalarType(), extension_build_version);
|
||||
}
|
||||
case c10::TypeKind::DeviceObjType: {
|
||||
return torch::stable::detail::from(ivalue.toDevice());
|
||||
return torch::stable::detail::_from(
|
||||
ivalue.toDevice(), extension_build_version);
|
||||
}
|
||||
case c10::TypeKind::LayoutType: {
|
||||
return torch::stable::detail::from(ivalue.toLayout());
|
||||
return torch::stable::detail::_from(
|
||||
ivalue.toLayout(), extension_build_version);
|
||||
}
|
||||
case c10::TypeKind::MemoryFormatType: {
|
||||
return torch::stable::detail::from(ivalue.toMemoryFormat());
|
||||
return torch::stable::detail::_from(
|
||||
ivalue.toMemoryFormat(), extension_build_version);
|
||||
}
|
||||
case c10::TypeKind::OptionalType: {
|
||||
auto inner_type = type->castRaw<at::OptionalType>()->getElementType();
|
||||
|
|
@ -56,10 +64,12 @@ static StableIValue from_ivalue(
|
|||
// be kept in sync with torch::stable::detail::from<std::optional<T>>
|
||||
// function in torch/csrc/stable/stableivalue_conversions.h
|
||||
if (ivalue.isNone()) {
|
||||
return torch::stable::detail::from(std::nullopt);
|
||||
return torch::stable::detail::_from(
|
||||
std::nullopt, extension_build_version);
|
||||
}
|
||||
StableIValue* sivp = new StableIValue(from_ivalue(inner_type, ivalue));
|
||||
return torch::stable::detail::from(sivp);
|
||||
StableIValue* sivp = new StableIValue(
|
||||
from_ivalue(inner_type, ivalue, extension_build_version));
|
||||
return torch::stable::detail::_from(sivp, extension_build_version);
|
||||
}
|
||||
default: {
|
||||
TORCH_CHECK(
|
||||
|
|
@ -72,36 +82,43 @@ static StableIValue from_ivalue(
|
|||
|
||||
static c10::IValue to_ivalue(
|
||||
const c10::TypePtr& type,
|
||||
const StableIValue stable_ivalue) {
|
||||
const StableIValue stable_ivalue,
|
||||
uint64_t extension_build_version) {
|
||||
switch (type->kind()) {
|
||||
case c10::TypeKind::TensorType: {
|
||||
auto ret_raiiath = torch::aot_inductor::RAIIAtenTensorHandle(
|
||||
torch::stable::detail::to<AtenTensorHandle>(stable_ivalue));
|
||||
torch::stable::detail::_to<AtenTensorHandle>(
|
||||
stable_ivalue, extension_build_version));
|
||||
return (c10::IValue(*torch::aot_inductor::tensor_handle_to_tensor_pointer(
|
||||
ret_raiiath.get())));
|
||||
}
|
||||
case c10::TypeKind::IntType: {
|
||||
return c10::IValue(torch::stable::detail::to<int64_t>(stable_ivalue));
|
||||
return c10::IValue(torch::stable::detail::_to<int64_t>(
|
||||
stable_ivalue, extension_build_version));
|
||||
}
|
||||
case c10::TypeKind::FloatType: {
|
||||
return c10::IValue(torch::stable::detail::to<double>(stable_ivalue));
|
||||
return c10::IValue(torch::stable::detail::_to<double>(
|
||||
stable_ivalue, extension_build_version));
|
||||
}
|
||||
case c10::TypeKind::BoolType: {
|
||||
return c10::IValue(torch::stable::detail::to<bool>(stable_ivalue));
|
||||
return c10::IValue(torch::stable::detail::_to<bool>(
|
||||
stable_ivalue, extension_build_version));
|
||||
}
|
||||
case c10::TypeKind::ScalarTypeType: {
|
||||
return c10::IValue(
|
||||
torch::stable::detail::to<c10::ScalarType>(stable_ivalue));
|
||||
return c10::IValue(torch::stable::detail::_to<c10::ScalarType>(
|
||||
stable_ivalue, extension_build_version));
|
||||
}
|
||||
case c10::TypeKind::DeviceObjType: {
|
||||
return c10::IValue(torch::stable::detail::to<c10::Device>(stable_ivalue));
|
||||
return c10::IValue(torch::stable::detail::_to<c10::Device>(
|
||||
stable_ivalue, extension_build_version));
|
||||
}
|
||||
case c10::TypeKind::LayoutType: {
|
||||
return c10::IValue(torch::stable::detail::to<c10::Layout>(stable_ivalue));
|
||||
return c10::IValue(torch::stable::detail::_to<c10::Layout>(
|
||||
stable_ivalue, extension_build_version));
|
||||
}
|
||||
case c10::TypeKind::MemoryFormatType: {
|
||||
return c10::IValue(
|
||||
torch::stable::detail::to<c10::MemoryFormat>(stable_ivalue));
|
||||
return c10::IValue(torch::stable::detail::_to<c10::MemoryFormat>(
|
||||
stable_ivalue, extension_build_version));
|
||||
}
|
||||
case c10::TypeKind::OptionalType: {
|
||||
auto inner_type = type->castRaw<at::OptionalType>()->getElementType();
|
||||
|
|
@ -116,13 +133,15 @@ static c10::IValue to_ivalue(
|
|||
//
|
||||
// BUT we do NOT have that type inner_type::t readily available, so we
|
||||
// will manually unwrap and recursively call. This implementation MUST
|
||||
// be kept in sync with the torch::stable::detail::to<T> function in
|
||||
// torch/csrc/stable/stableivalue_conversions.h
|
||||
if (stable_ivalue == torch::stable::detail::from(std::nullopt)) {
|
||||
// be kept in sync with the torch::stable::detail::_to<T> function in
|
||||
// torch/csrc/stable/library.h
|
||||
if (stable_ivalue ==
|
||||
torch::stable::detail::_from(std::nullopt, extension_build_version)) {
|
||||
return c10::IValue();
|
||||
}
|
||||
auto sivp = torch::stable::detail::to<StableIValue*>(stable_ivalue);
|
||||
auto ival = to_ivalue(inner_type, *sivp);
|
||||
auto sivp = torch::stable::detail::_to<StableIValue*>(
|
||||
stable_ivalue, extension_build_version);
|
||||
auto ival = to_ivalue(inner_type, *sivp, extension_build_version);
|
||||
delete sivp;
|
||||
return ival;
|
||||
}
|
||||
|
|
@ -137,8 +156,10 @@ static c10::IValue to_ivalue(
|
|||
|
||||
class StableIValueBoxedKernel : public c10::OperatorKernel {
|
||||
public:
|
||||
StableIValueBoxedKernel(void (*fn)(StableIValue*, uint64_t, uint64_t))
|
||||
: fn_(fn) {}
|
||||
StableIValueBoxedKernel(
|
||||
void (*fn)(StableIValue*, uint64_t, uint64_t),
|
||||
uint64_t extension_build_version)
|
||||
: fn_(fn), extension_build_version_(extension_build_version) {}
|
||||
|
||||
void operator()(
|
||||
const c10::OperatorHandle& op,
|
||||
|
|
@ -154,7 +175,8 @@ class StableIValueBoxedKernel : public c10::OperatorKernel {
|
|||
for (const auto idx : c10::irange(num_arguments)) {
|
||||
const auto ministack_idx = num_arguments - idx - 1;
|
||||
const c10::TypePtr& arg_type = schema.arguments()[ministack_idx].type();
|
||||
ministack[ministack_idx] = from_ivalue(arg_type, torch::jit::pop(stack));
|
||||
ministack[ministack_idx] = from_ivalue(
|
||||
arg_type, torch::jit::pop(stack), extension_build_version_);
|
||||
}
|
||||
|
||||
// boxed function is going to take a stack of StableIValues, cast them to
|
||||
|
|
@ -165,12 +187,14 @@ class StableIValueBoxedKernel : public c10::OperatorKernel {
|
|||
// IValue from StableIValue
|
||||
for (size_t idx = 0; idx < num_returns; idx++) {
|
||||
const c10::TypePtr& ret_type = schema.returns()[idx].type();
|
||||
torch::jit::push(stack, to_ivalue(ret_type, ministack[idx]));
|
||||
torch::jit::push(
|
||||
stack, to_ivalue(ret_type, ministack[idx], extension_build_version_));
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
void (*fn_)(StableIValue*, uint64_t, uint64_t);
|
||||
uint64_t extension_build_version_;
|
||||
};
|
||||
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_impl(
|
||||
|
|
@ -181,7 +205,23 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_impl(
|
|||
reinterpret_cast<torch::Library*>(self)->impl(
|
||||
name,
|
||||
torch::CppFunction::makeFromBoxedFunctor(
|
||||
std::make_unique<StableIValueBoxedKernel>(fn)));
|
||||
std::make_unique<StableIValueBoxedKernel>(fn, TORCH_ABI_VERSION)));
|
||||
});
|
||||
}
|
||||
|
||||
// Version-aware variant of aoti_torch_library_impl that takes an
|
||||
// extension_build_version parameter for backward compatibility
|
||||
AOTI_TORCH_EXPORT AOTITorchError torch_library_impl(
|
||||
TorchLibraryHandle self,
|
||||
const char* name,
|
||||
void (*fn)(StableIValue*, uint64_t, uint64_t),
|
||||
uint64_t extension_build_version) {
|
||||
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
|
||||
reinterpret_cast<torch::Library*>(self)->impl(
|
||||
name,
|
||||
torch::CppFunction::makeFromBoxedFunctor(
|
||||
std::make_unique<StableIValueBoxedKernel>(
|
||||
fn, extension_build_version)));
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -204,7 +244,8 @@ AOTITorchError aoti_torch_call_dispatcher(
|
|||
for (const auto idx : c10::irange(num_arguments)) {
|
||||
auto stable_ivalue = stack[idx];
|
||||
auto arg_type = schema.arguments()[idx].type();
|
||||
torch::jit::push(ivalue_stack, to_ivalue(arg_type, stable_ivalue));
|
||||
torch::jit::push(
|
||||
ivalue_stack, to_ivalue(arg_type, stable_ivalue, TORCH_ABI_VERSION));
|
||||
}
|
||||
|
||||
op.callBoxed(ivalue_stack);
|
||||
|
|
@ -214,7 +255,8 @@ AOTITorchError aoti_torch_call_dispatcher(
|
|||
for (const auto idx : c10::irange(num_returns)) {
|
||||
const auto stack_idx = num_returns - idx - 1;
|
||||
const c10::TypePtr& ret_type = schema.returns()[idx].type();
|
||||
stack[stack_idx] = from_ivalue(ret_type, torch::jit::pop(ivalue_stack));
|
||||
stack[stack_idx] = from_ivalue(
|
||||
ret_type, torch::jit::pop(ivalue_stack), TORCH_ABI_VERSION);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
|
@ -355,7 +397,9 @@ AOTI_TORCH_EXPORT AOTITorchError torch_call_dispatcher(
|
|||
for (const auto idx : c10::irange(num_arguments)) {
|
||||
auto stable_ivalue = stack[idx];
|
||||
auto arg_type = schema.arguments()[idx].type();
|
||||
torch::jit::push(ivalue_stack, to_ivalue(arg_type, stable_ivalue));
|
||||
torch::jit::push(
|
||||
ivalue_stack,
|
||||
to_ivalue(arg_type, stable_ivalue, extension_build_version));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -366,7 +410,8 @@ AOTI_TORCH_EXPORT AOTITorchError torch_call_dispatcher(
|
|||
for (const auto idx : c10::irange(num_returns)) {
|
||||
const auto stack_idx = num_returns - idx - 1;
|
||||
const c10::TypePtr& ret_type = schema.returns()[idx].type();
|
||||
stack[stack_idx] = from_ivalue(ret_type, torch::jit::pop(ivalue_stack));
|
||||
stack[stack_idx] = from_ivalue(
|
||||
ret_type, torch::jit::pop(ivalue_stack), extension_build_version);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,12 +4,14 @@
|
|||
// code for better UX.
|
||||
|
||||
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
|
||||
#include <torch/csrc/stable/c/shim.h>
|
||||
#include <torch/headeronly/macros/Macros.h>
|
||||
|
||||
// Technically, this file doesn't use anything from stableivalue_conversions.h,
|
||||
// but we need to include it here as the contents of stableivalue_conversions.h
|
||||
// used to live here and so we need to expose them for backwards compatibility.
|
||||
#include <torch/csrc/stable/stableivalue_conversions.h>
|
||||
#include <torch/csrc/stable/version.h>
|
||||
|
||||
HIDDEN_NAMESPACE_BEGIN(torch, stable, detail)
|
||||
|
||||
|
|
@ -81,7 +83,11 @@ class StableLibrary final {
|
|||
StableLibrary& impl(
|
||||
const char* name,
|
||||
void (*fn)(StableIValue*, uint64_t, uint64_t)) {
|
||||
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
|
||||
torch_library_impl(lib_, name, fn, TORCH_ABI_VERSION);
|
||||
#else
|
||||
aoti_torch_library_impl(lib_, name, fn);
|
||||
#endif
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -24,12 +24,17 @@ T to(StableIValue val);
|
|||
// =============================================================================
|
||||
// =============================================================================
|
||||
// FROM CONVERSIONS (T -> StableIValue)
|
||||
// =============================================================================
|
||||
// ======================================================================
|
||||
|
||||
// Specialization for general copyable types (catch-all) => StableIValue
|
||||
template <typename T>
|
||||
struct FromImpl {
|
||||
static StableIValue call(T val) {
|
||||
static StableIValue call(
|
||||
T val,
|
||||
uint64_t extension_build_version,
|
||||
bool is_internal) {
|
||||
(void)extension_build_version; // Unused parameter
|
||||
(void)is_internal; // Unused parameter
|
||||
static_assert(
|
||||
sizeof(T) <= sizeof(StableIValue),
|
||||
"StableLibrary stack does not support parameter types larger than 64 bits.");
|
||||
|
|
@ -68,7 +73,12 @@ struct FromImpl {
|
|||
using torch::headeronly::ScalarType;
|
||||
template <>
|
||||
struct FromImpl<ScalarType> {
|
||||
static StableIValue call(ScalarType val) {
|
||||
static StableIValue call(
|
||||
ScalarType val,
|
||||
uint64_t extension_build_version,
|
||||
bool is_internal) {
|
||||
(void)extension_build_version; // Unused parameter
|
||||
(void)is_internal; // Unused parameter
|
||||
switch (val) {
|
||||
case ScalarType::Byte:
|
||||
return from(aoti_torch_dtype_uint8());
|
||||
|
|
@ -121,7 +131,12 @@ struct FromImpl<ScalarType> {
|
|||
// Specialization for std::nullopt_t => StableIValue
|
||||
template <>
|
||||
struct FromImpl<std::nullopt_t> {
|
||||
static StableIValue call(std::nullopt_t val) {
|
||||
static StableIValue call(
|
||||
std::nullopt_t val,
|
||||
uint64_t extension_build_version,
|
||||
bool is_internal) {
|
||||
(void)extension_build_version; // Unused parameter
|
||||
(void)is_internal; // Unused parameter
|
||||
return from(nullptr);
|
||||
}
|
||||
};
|
||||
|
|
@ -157,11 +172,15 @@ struct FromImpl<std::nullopt_t> {
|
|||
// std::optional<T> or a std::nullopt.
|
||||
template <typename T>
|
||||
struct FromImpl<std::optional<T>> {
|
||||
static StableIValue call(const std::optional<T>& val) {
|
||||
static StableIValue call(
|
||||
const std::optional<T>& val,
|
||||
uint64_t extension_build_version,
|
||||
bool is_internal) {
|
||||
if (!val.has_value()) {
|
||||
return from(std::nullopt);
|
||||
}
|
||||
return from(new StableIValue(from(val.value())));
|
||||
return from(new StableIValue(detail::FromImpl<T>::call(
|
||||
val.value(), extension_build_version, is_internal)));
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -169,7 +188,12 @@ struct FromImpl<std::optional<T>> {
|
|||
// Returns a new owning reference of the underlying Tensor.
|
||||
template <>
|
||||
struct FromImpl<torch::stable::Tensor> {
|
||||
static StableIValue call(const torch::stable::Tensor& val) {
|
||||
static StableIValue call(
|
||||
const torch::stable::Tensor& val,
|
||||
uint64_t extension_build_version,
|
||||
bool is_internal) {
|
||||
(void)extension_build_version; // Unused parameter
|
||||
(void)is_internal; // Unused parameter
|
||||
AtenTensorHandle new_ath;
|
||||
TORCH_ERROR_CODE_CHECK(aoti_torch_new_tensor_handle(val.get(), &new_ath));
|
||||
return from(new_ath);
|
||||
|
|
@ -183,7 +207,12 @@ struct FromImpl<torch::stable::Tensor> {
|
|||
// Specialization for StableIValue => general copyable types (catch-all)
|
||||
template <typename T>
|
||||
struct ToImpl {
|
||||
static T call(StableIValue val) {
|
||||
static T call(
|
||||
StableIValue val,
|
||||
uint64_t extension_build_version,
|
||||
bool is_internal) {
|
||||
(void)extension_build_version; // Unused parameter
|
||||
(void)is_internal; // Unused parameter
|
||||
static_assert(std::is_trivially_copyable_v<T>);
|
||||
// T may not have a default constructor. (For example, it might be
|
||||
// c10::Device.) However, std::memcpy implicitly creates a T at the
|
||||
|
|
@ -218,7 +247,12 @@ struct ToImpl {
|
|||
// Specialization for StableIValue => torch::headeronly::ScalarType
|
||||
template <>
|
||||
struct ToImpl<ScalarType> {
|
||||
static ScalarType call(StableIValue val) {
|
||||
static ScalarType call(
|
||||
StableIValue val,
|
||||
uint64_t extension_build_version,
|
||||
bool is_internal) {
|
||||
(void)extension_build_version; // Unused parameter
|
||||
(void)is_internal; // Unused parameter
|
||||
int32_t shim_scalartype = to<int32_t>(val);
|
||||
if (shim_scalartype == aoti_torch_dtype_uint8()) {
|
||||
return ScalarType::Byte;
|
||||
|
|
@ -273,7 +307,12 @@ struct ToImpl<ScalarType> {
|
|||
// Specialization for StableIValue => std::nullopt_t
|
||||
template <>
|
||||
struct ToImpl<std::nullopt_t> {
|
||||
static std::nullopt_t call(StableIValue val) {
|
||||
static std::nullopt_t call(
|
||||
StableIValue val,
|
||||
uint64_t extension_build_version,
|
||||
bool is_internal) {
|
||||
(void)extension_build_version; // Unused parameter
|
||||
(void)is_internal; // Unused parameter
|
||||
// val should be equivalent to from(nullptr)
|
||||
return std::nullopt;
|
||||
}
|
||||
|
|
@ -284,14 +323,18 @@ struct ToImpl<std::nullopt_t> {
|
|||
// from IValue --(from_ivalue)-> StableIValue --(to<T>)-> T in custom extension
|
||||
template <typename T>
|
||||
struct ToImpl<std::optional<T>> {
|
||||
static std::optional<T> call(StableIValue val) {
|
||||
static std::optional<T> call(
|
||||
StableIValue val,
|
||||
uint64_t extension_build_version,
|
||||
bool is_internal) {
|
||||
auto sivp = to<StableIValue*>(val);
|
||||
|
||||
// sivp is either nullptr or a pointer to a StableIValue
|
||||
if (sivp == nullptr) {
|
||||
return {};
|
||||
}
|
||||
auto inner_val = to<T>(*sivp);
|
||||
auto inner_val =
|
||||
detail::ToImpl<T>::call(*sivp, extension_build_version, is_internal);
|
||||
|
||||
// free the memory associated with StableIValue* sivp
|
||||
delete sivp;
|
||||
|
|
@ -305,7 +348,12 @@ struct ToImpl<std::optional<T>> {
|
|||
// underlying AtenTensorHandle.
|
||||
template <>
|
||||
struct ToImpl<torch::stable::Tensor> {
|
||||
static torch::stable::Tensor call(StableIValue val) {
|
||||
static torch::stable::Tensor call(
|
||||
StableIValue val,
|
||||
uint64_t extension_build_version,
|
||||
bool is_internal) {
|
||||
(void)extension_build_version; // Unused parameter
|
||||
(void)is_internal; // Unused parameter
|
||||
return torch::stable::Tensor(to<AtenTensorHandle>(val));
|
||||
}
|
||||
};
|
||||
|
|
@ -315,25 +363,60 @@ struct ToImpl<torch::stable::Tensor> {
|
|||
// =============================================================================
|
||||
|
||||
// Expose the partially templated class functions through single functions
|
||||
// The non-private versions will be used by the extension or headers that
|
||||
// the extension includes.
|
||||
template <typename T>
|
||||
inline StableIValue from(T val) {
|
||||
return detail::FromImpl<T>::call(val);
|
||||
return detail::FromImpl<T>::call(
|
||||
val, aoti_torch_abi_version(), /*is_internal=*/false);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline StableIValue from(const std::optional<T>& val) {
|
||||
return detail::FromImpl<std::optional<T>>::call(val);
|
||||
return detail::FromImpl<std::optional<T>>::call(
|
||||
val, aoti_torch_abi_version(), /*is_internal=*/false);
|
||||
}
|
||||
|
||||
// The below overload is used! See https://godbolt.org/z/859cshxrW
|
||||
// We are suppressing the warning for versions clang12- and gcc11-
|
||||
[[maybe_unused]] inline StableIValue from(const torch::stable::Tensor& val) {
|
||||
return detail::FromImpl<torch::stable::Tensor>::call(val);
|
||||
return detail::FromImpl<torch::stable::Tensor>::call(
|
||||
val, aoti_torch_abi_version(), /*is_internal=*/false);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline T to(StableIValue val) {
|
||||
return detail::ToImpl<T>::call(val);
|
||||
return detail::ToImpl<T>::call(
|
||||
val, aoti_torch_abi_version(), /*is_internal=*/false);
|
||||
}
|
||||
|
||||
// Internal conversion functions used by from_ivalue and to_ivalue.
|
||||
// These are used in libtorch
|
||||
template <typename T>
|
||||
inline StableIValue _from(T val, uint64_t extension_build_version) {
|
||||
return detail::FromImpl<T>::call(
|
||||
val, extension_build_version, /*is_internal=*/true);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline StableIValue _from(
|
||||
const std::optional<T>& val,
|
||||
uint64_t extension_build_version) {
|
||||
return detail::FromImpl<std::optional<T>>::call(
|
||||
val, extension_build_version, /*is_internal=*/true);
|
||||
}
|
||||
|
||||
[[maybe_unused]] inline StableIValue _from(
|
||||
const torch::stable::Tensor& val,
|
||||
uint64_t extension_build_version) {
|
||||
return detail::FromImpl<torch::stable::Tensor>::call(
|
||||
val, extension_build_version, /*is_internal=*/true);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline T _to(StableIValue val, uint64_t extension_build_version) {
|
||||
return detail::ToImpl<T>::call(
|
||||
val, extension_build_version, /*is_internal=*/true);
|
||||
}
|
||||
|
||||
HIDDEN_NAMESPACE_END(torch, stable, detail)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user