mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add scaffolding for StableIValue FC/BC (no PoC) (#164332)
1. Add `extension_build_version` and `is_internal` to `FromImpl`/`ToImpl` (this will be useful for future if we need to break the BC of any type) #163832 has the PoC of how we would actually use this system 2. Add `aoti_torch_library_impl_v2` that takes in an additional `extension_build_version` argument, updates callsite in `torch/csrc/stable/library.h` to always pass `TORCH_ABI_VERSION` for this argument 3. Add `extension_build_version` to `from_ivalue` and `to_ivalue` and update all callsites 4. Add a private `_from` and `_to` that pass `is_internal=True` to `FromImpl`/`ToImpl`, making it easier to reason about what is being called from libtorch-land / extension-land **Note: This PR does not include a linter that tells the user to update from/to if changing the ABI of a type in headeronly, which I intend to do in https://github.com/pytorch/pytorch/pull/163998** Pull Request resolved: https://github.com/pytorch/pytorch/pull/164332 Approved by: https://github.com/janeyx99 ghstack dependencies: #164356, #166373, #163683
This commit is contained in:
parent
8f51556daa
commit
eae701cad0
|
|
@ -12,33 +12,41 @@
|
||||||
|
|
||||||
static StableIValue from_ivalue(
|
static StableIValue from_ivalue(
|
||||||
const c10::TypePtr& type,
|
const c10::TypePtr& type,
|
||||||
const c10::IValue& ivalue) {
|
const c10::IValue& ivalue,
|
||||||
|
uint64_t extension_build_version) {
|
||||||
switch (type->kind()) {
|
switch (type->kind()) {
|
||||||
case c10::TypeKind::TensorType: {
|
case c10::TypeKind::TensorType: {
|
||||||
AtenTensorHandle ath = torch::aot_inductor::new_tensor_handle(
|
AtenTensorHandle ath = torch::aot_inductor::new_tensor_handle(
|
||||||
std::move(const_cast<at::Tensor&>(ivalue.toTensor())));
|
std::move(const_cast<at::Tensor&>(ivalue.toTensor())));
|
||||||
return torch::stable::detail::from(ath);
|
return torch::stable::detail::_from(ath, extension_build_version);
|
||||||
}
|
}
|
||||||
case c10::TypeKind::IntType: {
|
case c10::TypeKind::IntType: {
|
||||||
return torch::stable::detail::from(ivalue.toInt());
|
return torch::stable::detail::_from(
|
||||||
|
ivalue.toInt(), extension_build_version);
|
||||||
}
|
}
|
||||||
case c10::TypeKind::FloatType: {
|
case c10::TypeKind::FloatType: {
|
||||||
return torch::stable::detail::from(ivalue.toDouble());
|
return torch::stable::detail::_from(
|
||||||
|
ivalue.toDouble(), extension_build_version);
|
||||||
}
|
}
|
||||||
case c10::TypeKind::BoolType: {
|
case c10::TypeKind::BoolType: {
|
||||||
return torch::stable::detail::from(ivalue.toBool());
|
return torch::stable::detail::_from(
|
||||||
|
ivalue.toBool(), extension_build_version);
|
||||||
}
|
}
|
||||||
case c10::TypeKind::ScalarTypeType: {
|
case c10::TypeKind::ScalarTypeType: {
|
||||||
return torch::stable::detail::from(ivalue.toScalarType());
|
return torch::stable::detail::_from(
|
||||||
|
ivalue.toScalarType(), extension_build_version);
|
||||||
}
|
}
|
||||||
case c10::TypeKind::DeviceObjType: {
|
case c10::TypeKind::DeviceObjType: {
|
||||||
return torch::stable::detail::from(ivalue.toDevice());
|
return torch::stable::detail::_from(
|
||||||
|
ivalue.toDevice(), extension_build_version);
|
||||||
}
|
}
|
||||||
case c10::TypeKind::LayoutType: {
|
case c10::TypeKind::LayoutType: {
|
||||||
return torch::stable::detail::from(ivalue.toLayout());
|
return torch::stable::detail::_from(
|
||||||
|
ivalue.toLayout(), extension_build_version);
|
||||||
}
|
}
|
||||||
case c10::TypeKind::MemoryFormatType: {
|
case c10::TypeKind::MemoryFormatType: {
|
||||||
return torch::stable::detail::from(ivalue.toMemoryFormat());
|
return torch::stable::detail::_from(
|
||||||
|
ivalue.toMemoryFormat(), extension_build_version);
|
||||||
}
|
}
|
||||||
case c10::TypeKind::OptionalType: {
|
case c10::TypeKind::OptionalType: {
|
||||||
auto inner_type = type->castRaw<at::OptionalType>()->getElementType();
|
auto inner_type = type->castRaw<at::OptionalType>()->getElementType();
|
||||||
|
|
@ -56,10 +64,12 @@ static StableIValue from_ivalue(
|
||||||
// be kept in sync with torch::stable::detail::from<std::optional<T>>
|
// be kept in sync with torch::stable::detail::from<std::optional<T>>
|
||||||
// function in torch/csrc/stable/stableivalue_conversions.h
|
// function in torch/csrc/stable/stableivalue_conversions.h
|
||||||
if (ivalue.isNone()) {
|
if (ivalue.isNone()) {
|
||||||
return torch::stable::detail::from(std::nullopt);
|
return torch::stable::detail::_from(
|
||||||
|
std::nullopt, extension_build_version);
|
||||||
}
|
}
|
||||||
StableIValue* sivp = new StableIValue(from_ivalue(inner_type, ivalue));
|
StableIValue* sivp = new StableIValue(
|
||||||
return torch::stable::detail::from(sivp);
|
from_ivalue(inner_type, ivalue, extension_build_version));
|
||||||
|
return torch::stable::detail::_from(sivp, extension_build_version);
|
||||||
}
|
}
|
||||||
default: {
|
default: {
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
|
|
@ -72,36 +82,43 @@ static StableIValue from_ivalue(
|
||||||
|
|
||||||
static c10::IValue to_ivalue(
|
static c10::IValue to_ivalue(
|
||||||
const c10::TypePtr& type,
|
const c10::TypePtr& type,
|
||||||
const StableIValue stable_ivalue) {
|
const StableIValue stable_ivalue,
|
||||||
|
uint64_t extension_build_version) {
|
||||||
switch (type->kind()) {
|
switch (type->kind()) {
|
||||||
case c10::TypeKind::TensorType: {
|
case c10::TypeKind::TensorType: {
|
||||||
auto ret_raiiath = torch::aot_inductor::RAIIAtenTensorHandle(
|
auto ret_raiiath = torch::aot_inductor::RAIIAtenTensorHandle(
|
||||||
torch::stable::detail::to<AtenTensorHandle>(stable_ivalue));
|
torch::stable::detail::_to<AtenTensorHandle>(
|
||||||
|
stable_ivalue, extension_build_version));
|
||||||
return (c10::IValue(*torch::aot_inductor::tensor_handle_to_tensor_pointer(
|
return (c10::IValue(*torch::aot_inductor::tensor_handle_to_tensor_pointer(
|
||||||
ret_raiiath.get())));
|
ret_raiiath.get())));
|
||||||
}
|
}
|
||||||
case c10::TypeKind::IntType: {
|
case c10::TypeKind::IntType: {
|
||||||
return c10::IValue(torch::stable::detail::to<int64_t>(stable_ivalue));
|
return c10::IValue(torch::stable::detail::_to<int64_t>(
|
||||||
|
stable_ivalue, extension_build_version));
|
||||||
}
|
}
|
||||||
case c10::TypeKind::FloatType: {
|
case c10::TypeKind::FloatType: {
|
||||||
return c10::IValue(torch::stable::detail::to<double>(stable_ivalue));
|
return c10::IValue(torch::stable::detail::_to<double>(
|
||||||
|
stable_ivalue, extension_build_version));
|
||||||
}
|
}
|
||||||
case c10::TypeKind::BoolType: {
|
case c10::TypeKind::BoolType: {
|
||||||
return c10::IValue(torch::stable::detail::to<bool>(stable_ivalue));
|
return c10::IValue(torch::stable::detail::_to<bool>(
|
||||||
|
stable_ivalue, extension_build_version));
|
||||||
}
|
}
|
||||||
case c10::TypeKind::ScalarTypeType: {
|
case c10::TypeKind::ScalarTypeType: {
|
||||||
return c10::IValue(
|
return c10::IValue(torch::stable::detail::_to<c10::ScalarType>(
|
||||||
torch::stable::detail::to<c10::ScalarType>(stable_ivalue));
|
stable_ivalue, extension_build_version));
|
||||||
}
|
}
|
||||||
case c10::TypeKind::DeviceObjType: {
|
case c10::TypeKind::DeviceObjType: {
|
||||||
return c10::IValue(torch::stable::detail::to<c10::Device>(stable_ivalue));
|
return c10::IValue(torch::stable::detail::_to<c10::Device>(
|
||||||
|
stable_ivalue, extension_build_version));
|
||||||
}
|
}
|
||||||
case c10::TypeKind::LayoutType: {
|
case c10::TypeKind::LayoutType: {
|
||||||
return c10::IValue(torch::stable::detail::to<c10::Layout>(stable_ivalue));
|
return c10::IValue(torch::stable::detail::_to<c10::Layout>(
|
||||||
|
stable_ivalue, extension_build_version));
|
||||||
}
|
}
|
||||||
case c10::TypeKind::MemoryFormatType: {
|
case c10::TypeKind::MemoryFormatType: {
|
||||||
return c10::IValue(
|
return c10::IValue(torch::stable::detail::_to<c10::MemoryFormat>(
|
||||||
torch::stable::detail::to<c10::MemoryFormat>(stable_ivalue));
|
stable_ivalue, extension_build_version));
|
||||||
}
|
}
|
||||||
case c10::TypeKind::OptionalType: {
|
case c10::TypeKind::OptionalType: {
|
||||||
auto inner_type = type->castRaw<at::OptionalType>()->getElementType();
|
auto inner_type = type->castRaw<at::OptionalType>()->getElementType();
|
||||||
|
|
@ -116,13 +133,15 @@ static c10::IValue to_ivalue(
|
||||||
//
|
//
|
||||||
// BUT we do NOT have that type inner_type::t readily available, so we
|
// BUT we do NOT have that type inner_type::t readily available, so we
|
||||||
// will manually unwrap and recursively call. This implementation MUST
|
// will manually unwrap and recursively call. This implementation MUST
|
||||||
// be kept in sync with the torch::stable::detail::to<T> function in
|
// be kept in sync with the torch::stable::detail::_to<T> function in
|
||||||
// torch/csrc/stable/stableivalue_conversions.h
|
// torch/csrc/stable/library.h
|
||||||
if (stable_ivalue == torch::stable::detail::from(std::nullopt)) {
|
if (stable_ivalue ==
|
||||||
|
torch::stable::detail::_from(std::nullopt, extension_build_version)) {
|
||||||
return c10::IValue();
|
return c10::IValue();
|
||||||
}
|
}
|
||||||
auto sivp = torch::stable::detail::to<StableIValue*>(stable_ivalue);
|
auto sivp = torch::stable::detail::_to<StableIValue*>(
|
||||||
auto ival = to_ivalue(inner_type, *sivp);
|
stable_ivalue, extension_build_version);
|
||||||
|
auto ival = to_ivalue(inner_type, *sivp, extension_build_version);
|
||||||
delete sivp;
|
delete sivp;
|
||||||
return ival;
|
return ival;
|
||||||
}
|
}
|
||||||
|
|
@ -137,8 +156,10 @@ static c10::IValue to_ivalue(
|
||||||
|
|
||||||
class StableIValueBoxedKernel : public c10::OperatorKernel {
|
class StableIValueBoxedKernel : public c10::OperatorKernel {
|
||||||
public:
|
public:
|
||||||
StableIValueBoxedKernel(void (*fn)(StableIValue*, uint64_t, uint64_t))
|
StableIValueBoxedKernel(
|
||||||
: fn_(fn) {}
|
void (*fn)(StableIValue*, uint64_t, uint64_t),
|
||||||
|
uint64_t extension_build_version)
|
||||||
|
: fn_(fn), extension_build_version_(extension_build_version) {}
|
||||||
|
|
||||||
void operator()(
|
void operator()(
|
||||||
const c10::OperatorHandle& op,
|
const c10::OperatorHandle& op,
|
||||||
|
|
@ -154,7 +175,8 @@ class StableIValueBoxedKernel : public c10::OperatorKernel {
|
||||||
for (const auto idx : c10::irange(num_arguments)) {
|
for (const auto idx : c10::irange(num_arguments)) {
|
||||||
const auto ministack_idx = num_arguments - idx - 1;
|
const auto ministack_idx = num_arguments - idx - 1;
|
||||||
const c10::TypePtr& arg_type = schema.arguments()[ministack_idx].type();
|
const c10::TypePtr& arg_type = schema.arguments()[ministack_idx].type();
|
||||||
ministack[ministack_idx] = from_ivalue(arg_type, torch::jit::pop(stack));
|
ministack[ministack_idx] = from_ivalue(
|
||||||
|
arg_type, torch::jit::pop(stack), extension_build_version_);
|
||||||
}
|
}
|
||||||
|
|
||||||
// boxed function is going to take a stack of StableIValues, cast them to
|
// boxed function is going to take a stack of StableIValues, cast them to
|
||||||
|
|
@ -165,12 +187,14 @@ class StableIValueBoxedKernel : public c10::OperatorKernel {
|
||||||
// IValue from StableIValue
|
// IValue from StableIValue
|
||||||
for (size_t idx = 0; idx < num_returns; idx++) {
|
for (size_t idx = 0; idx < num_returns; idx++) {
|
||||||
const c10::TypePtr& ret_type = schema.returns()[idx].type();
|
const c10::TypePtr& ret_type = schema.returns()[idx].type();
|
||||||
torch::jit::push(stack, to_ivalue(ret_type, ministack[idx]));
|
torch::jit::push(
|
||||||
|
stack, to_ivalue(ret_type, ministack[idx], extension_build_version_));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void (*fn_)(StableIValue*, uint64_t, uint64_t);
|
void (*fn_)(StableIValue*, uint64_t, uint64_t);
|
||||||
|
uint64_t extension_build_version_;
|
||||||
};
|
};
|
||||||
|
|
||||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_impl(
|
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_impl(
|
||||||
|
|
@ -181,7 +205,23 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_impl(
|
||||||
reinterpret_cast<torch::Library*>(self)->impl(
|
reinterpret_cast<torch::Library*>(self)->impl(
|
||||||
name,
|
name,
|
||||||
torch::CppFunction::makeFromBoxedFunctor(
|
torch::CppFunction::makeFromBoxedFunctor(
|
||||||
std::make_unique<StableIValueBoxedKernel>(fn)));
|
std::make_unique<StableIValueBoxedKernel>(fn, TORCH_ABI_VERSION)));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Version-aware variant of aoti_torch_library_impl that takes an
|
||||||
|
// extension_build_version parameter for backward compatibility
|
||||||
|
AOTI_TORCH_EXPORT AOTITorchError torch_library_impl(
|
||||||
|
TorchLibraryHandle self,
|
||||||
|
const char* name,
|
||||||
|
void (*fn)(StableIValue*, uint64_t, uint64_t),
|
||||||
|
uint64_t extension_build_version) {
|
||||||
|
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
|
||||||
|
reinterpret_cast<torch::Library*>(self)->impl(
|
||||||
|
name,
|
||||||
|
torch::CppFunction::makeFromBoxedFunctor(
|
||||||
|
std::make_unique<StableIValueBoxedKernel>(
|
||||||
|
fn, extension_build_version)));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -204,7 +244,8 @@ AOTITorchError aoti_torch_call_dispatcher(
|
||||||
for (const auto idx : c10::irange(num_arguments)) {
|
for (const auto idx : c10::irange(num_arguments)) {
|
||||||
auto stable_ivalue = stack[idx];
|
auto stable_ivalue = stack[idx];
|
||||||
auto arg_type = schema.arguments()[idx].type();
|
auto arg_type = schema.arguments()[idx].type();
|
||||||
torch::jit::push(ivalue_stack, to_ivalue(arg_type, stable_ivalue));
|
torch::jit::push(
|
||||||
|
ivalue_stack, to_ivalue(arg_type, stable_ivalue, TORCH_ABI_VERSION));
|
||||||
}
|
}
|
||||||
|
|
||||||
op.callBoxed(ivalue_stack);
|
op.callBoxed(ivalue_stack);
|
||||||
|
|
@ -214,7 +255,8 @@ AOTITorchError aoti_torch_call_dispatcher(
|
||||||
for (const auto idx : c10::irange(num_returns)) {
|
for (const auto idx : c10::irange(num_returns)) {
|
||||||
const auto stack_idx = num_returns - idx - 1;
|
const auto stack_idx = num_returns - idx - 1;
|
||||||
const c10::TypePtr& ret_type = schema.returns()[idx].type();
|
const c10::TypePtr& ret_type = schema.returns()[idx].type();
|
||||||
stack[stack_idx] = from_ivalue(ret_type, torch::jit::pop(ivalue_stack));
|
stack[stack_idx] = from_ivalue(
|
||||||
|
ret_type, torch::jit::pop(ivalue_stack), TORCH_ABI_VERSION);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
@ -355,7 +397,9 @@ AOTI_TORCH_EXPORT AOTITorchError torch_call_dispatcher(
|
||||||
for (const auto idx : c10::irange(num_arguments)) {
|
for (const auto idx : c10::irange(num_arguments)) {
|
||||||
auto stable_ivalue = stack[idx];
|
auto stable_ivalue = stack[idx];
|
||||||
auto arg_type = schema.arguments()[idx].type();
|
auto arg_type = schema.arguments()[idx].type();
|
||||||
torch::jit::push(ivalue_stack, to_ivalue(arg_type, stable_ivalue));
|
torch::jit::push(
|
||||||
|
ivalue_stack,
|
||||||
|
to_ivalue(arg_type, stable_ivalue, extension_build_version));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -366,7 +410,8 @@ AOTI_TORCH_EXPORT AOTITorchError torch_call_dispatcher(
|
||||||
for (const auto idx : c10::irange(num_returns)) {
|
for (const auto idx : c10::irange(num_returns)) {
|
||||||
const auto stack_idx = num_returns - idx - 1;
|
const auto stack_idx = num_returns - idx - 1;
|
||||||
const c10::TypePtr& ret_type = schema.returns()[idx].type();
|
const c10::TypePtr& ret_type = schema.returns()[idx].type();
|
||||||
stack[stack_idx] = from_ivalue(ret_type, torch::jit::pop(ivalue_stack));
|
stack[stack_idx] = from_ivalue(
|
||||||
|
ret_type, torch::jit::pop(ivalue_stack), extension_build_version);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -4,12 +4,14 @@
|
||||||
// code for better UX.
|
// code for better UX.
|
||||||
|
|
||||||
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
|
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
|
||||||
|
#include <torch/csrc/stable/c/shim.h>
|
||||||
#include <torch/headeronly/macros/Macros.h>
|
#include <torch/headeronly/macros/Macros.h>
|
||||||
|
|
||||||
// Technically, this file doesn't use anything from stableivalue_conversions.h,
|
// Technically, this file doesn't use anything from stableivalue_conversions.h,
|
||||||
// but we need to include it here as the contents of stableivalue_conversions.h
|
// but we need to include it here as the contents of stableivalue_conversions.h
|
||||||
// used to live here and so we need to expose them for backwards compatibility.
|
// used to live here and so we need to expose them for backwards compatibility.
|
||||||
#include <torch/csrc/stable/stableivalue_conversions.h>
|
#include <torch/csrc/stable/stableivalue_conversions.h>
|
||||||
|
#include <torch/csrc/stable/version.h>
|
||||||
|
|
||||||
HIDDEN_NAMESPACE_BEGIN(torch, stable, detail)
|
HIDDEN_NAMESPACE_BEGIN(torch, stable, detail)
|
||||||
|
|
||||||
|
|
@ -81,7 +83,11 @@ class StableLibrary final {
|
||||||
StableLibrary& impl(
|
StableLibrary& impl(
|
||||||
const char* name,
|
const char* name,
|
||||||
void (*fn)(StableIValue*, uint64_t, uint64_t)) {
|
void (*fn)(StableIValue*, uint64_t, uint64_t)) {
|
||||||
|
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
|
||||||
|
torch_library_impl(lib_, name, fn, TORCH_ABI_VERSION);
|
||||||
|
#else
|
||||||
aoti_torch_library_impl(lib_, name, fn);
|
aoti_torch_library_impl(lib_, name, fn);
|
||||||
|
#endif
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -24,12 +24,17 @@ T to(StableIValue val);
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
// FROM CONVERSIONS (T -> StableIValue)
|
// FROM CONVERSIONS (T -> StableIValue)
|
||||||
// =============================================================================
|
// ======================================================================
|
||||||
|
|
||||||
// Specialization for general copyable types (catch-all) => StableIValue
|
// Specialization for general copyable types (catch-all) => StableIValue
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct FromImpl {
|
struct FromImpl {
|
||||||
static StableIValue call(T val) {
|
static StableIValue call(
|
||||||
|
T val,
|
||||||
|
uint64_t extension_build_version,
|
||||||
|
bool is_internal) {
|
||||||
|
(void)extension_build_version; // Unused parameter
|
||||||
|
(void)is_internal; // Unused parameter
|
||||||
static_assert(
|
static_assert(
|
||||||
sizeof(T) <= sizeof(StableIValue),
|
sizeof(T) <= sizeof(StableIValue),
|
||||||
"StableLibrary stack does not support parameter types larger than 64 bits.");
|
"StableLibrary stack does not support parameter types larger than 64 bits.");
|
||||||
|
|
@ -68,7 +73,12 @@ struct FromImpl {
|
||||||
using torch::headeronly::ScalarType;
|
using torch::headeronly::ScalarType;
|
||||||
template <>
|
template <>
|
||||||
struct FromImpl<ScalarType> {
|
struct FromImpl<ScalarType> {
|
||||||
static StableIValue call(ScalarType val) {
|
static StableIValue call(
|
||||||
|
ScalarType val,
|
||||||
|
uint64_t extension_build_version,
|
||||||
|
bool is_internal) {
|
||||||
|
(void)extension_build_version; // Unused parameter
|
||||||
|
(void)is_internal; // Unused parameter
|
||||||
switch (val) {
|
switch (val) {
|
||||||
case ScalarType::Byte:
|
case ScalarType::Byte:
|
||||||
return from(aoti_torch_dtype_uint8());
|
return from(aoti_torch_dtype_uint8());
|
||||||
|
|
@ -121,7 +131,12 @@ struct FromImpl<ScalarType> {
|
||||||
// Specialization for std::nullopt_t => StableIValue
|
// Specialization for std::nullopt_t => StableIValue
|
||||||
template <>
|
template <>
|
||||||
struct FromImpl<std::nullopt_t> {
|
struct FromImpl<std::nullopt_t> {
|
||||||
static StableIValue call(std::nullopt_t val) {
|
static StableIValue call(
|
||||||
|
std::nullopt_t val,
|
||||||
|
uint64_t extension_build_version,
|
||||||
|
bool is_internal) {
|
||||||
|
(void)extension_build_version; // Unused parameter
|
||||||
|
(void)is_internal; // Unused parameter
|
||||||
return from(nullptr);
|
return from(nullptr);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -157,11 +172,15 @@ struct FromImpl<std::nullopt_t> {
|
||||||
// std::optional<T> or a std::nullopt.
|
// std::optional<T> or a std::nullopt.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct FromImpl<std::optional<T>> {
|
struct FromImpl<std::optional<T>> {
|
||||||
static StableIValue call(const std::optional<T>& val) {
|
static StableIValue call(
|
||||||
|
const std::optional<T>& val,
|
||||||
|
uint64_t extension_build_version,
|
||||||
|
bool is_internal) {
|
||||||
if (!val.has_value()) {
|
if (!val.has_value()) {
|
||||||
return from(std::nullopt);
|
return from(std::nullopt);
|
||||||
}
|
}
|
||||||
return from(new StableIValue(from(val.value())));
|
return from(new StableIValue(detail::FromImpl<T>::call(
|
||||||
|
val.value(), extension_build_version, is_internal)));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -169,7 +188,12 @@ struct FromImpl<std::optional<T>> {
|
||||||
// Returns a new owning reference of the underlying Tensor.
|
// Returns a new owning reference of the underlying Tensor.
|
||||||
template <>
|
template <>
|
||||||
struct FromImpl<torch::stable::Tensor> {
|
struct FromImpl<torch::stable::Tensor> {
|
||||||
static StableIValue call(const torch::stable::Tensor& val) {
|
static StableIValue call(
|
||||||
|
const torch::stable::Tensor& val,
|
||||||
|
uint64_t extension_build_version,
|
||||||
|
bool is_internal) {
|
||||||
|
(void)extension_build_version; // Unused parameter
|
||||||
|
(void)is_internal; // Unused parameter
|
||||||
AtenTensorHandle new_ath;
|
AtenTensorHandle new_ath;
|
||||||
TORCH_ERROR_CODE_CHECK(aoti_torch_new_tensor_handle(val.get(), &new_ath));
|
TORCH_ERROR_CODE_CHECK(aoti_torch_new_tensor_handle(val.get(), &new_ath));
|
||||||
return from(new_ath);
|
return from(new_ath);
|
||||||
|
|
@ -183,7 +207,12 @@ struct FromImpl<torch::stable::Tensor> {
|
||||||
// Specialization for StableIValue => general copyable types (catch-all)
|
// Specialization for StableIValue => general copyable types (catch-all)
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct ToImpl {
|
struct ToImpl {
|
||||||
static T call(StableIValue val) {
|
static T call(
|
||||||
|
StableIValue val,
|
||||||
|
uint64_t extension_build_version,
|
||||||
|
bool is_internal) {
|
||||||
|
(void)extension_build_version; // Unused parameter
|
||||||
|
(void)is_internal; // Unused parameter
|
||||||
static_assert(std::is_trivially_copyable_v<T>);
|
static_assert(std::is_trivially_copyable_v<T>);
|
||||||
// T may not have a default constructor. (For example, it might be
|
// T may not have a default constructor. (For example, it might be
|
||||||
// c10::Device.) However, std::memcpy implicitly creates a T at the
|
// c10::Device.) However, std::memcpy implicitly creates a T at the
|
||||||
|
|
@ -218,7 +247,12 @@ struct ToImpl {
|
||||||
// Specialization for StableIValue => torch::headeronly::ScalarType
|
// Specialization for StableIValue => torch::headeronly::ScalarType
|
||||||
template <>
|
template <>
|
||||||
struct ToImpl<ScalarType> {
|
struct ToImpl<ScalarType> {
|
||||||
static ScalarType call(StableIValue val) {
|
static ScalarType call(
|
||||||
|
StableIValue val,
|
||||||
|
uint64_t extension_build_version,
|
||||||
|
bool is_internal) {
|
||||||
|
(void)extension_build_version; // Unused parameter
|
||||||
|
(void)is_internal; // Unused parameter
|
||||||
int32_t shim_scalartype = to<int32_t>(val);
|
int32_t shim_scalartype = to<int32_t>(val);
|
||||||
if (shim_scalartype == aoti_torch_dtype_uint8()) {
|
if (shim_scalartype == aoti_torch_dtype_uint8()) {
|
||||||
return ScalarType::Byte;
|
return ScalarType::Byte;
|
||||||
|
|
@ -273,7 +307,12 @@ struct ToImpl<ScalarType> {
|
||||||
// Specialization for StableIValue => std::nullopt_t
|
// Specialization for StableIValue => std::nullopt_t
|
||||||
template <>
|
template <>
|
||||||
struct ToImpl<std::nullopt_t> {
|
struct ToImpl<std::nullopt_t> {
|
||||||
static std::nullopt_t call(StableIValue val) {
|
static std::nullopt_t call(
|
||||||
|
StableIValue val,
|
||||||
|
uint64_t extension_build_version,
|
||||||
|
bool is_internal) {
|
||||||
|
(void)extension_build_version; // Unused parameter
|
||||||
|
(void)is_internal; // Unused parameter
|
||||||
// val should be equivalent to from(nullptr)
|
// val should be equivalent to from(nullptr)
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
}
|
}
|
||||||
|
|
@ -284,14 +323,18 @@ struct ToImpl<std::nullopt_t> {
|
||||||
// from IValue --(from_ivalue)-> StableIValue --(to<T>)-> T in custom extension
|
// from IValue --(from_ivalue)-> StableIValue --(to<T>)-> T in custom extension
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct ToImpl<std::optional<T>> {
|
struct ToImpl<std::optional<T>> {
|
||||||
static std::optional<T> call(StableIValue val) {
|
static std::optional<T> call(
|
||||||
|
StableIValue val,
|
||||||
|
uint64_t extension_build_version,
|
||||||
|
bool is_internal) {
|
||||||
auto sivp = to<StableIValue*>(val);
|
auto sivp = to<StableIValue*>(val);
|
||||||
|
|
||||||
// sivp is either nullptr or a pointer to a StableIValue
|
// sivp is either nullptr or a pointer to a StableIValue
|
||||||
if (sivp == nullptr) {
|
if (sivp == nullptr) {
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
auto inner_val = to<T>(*sivp);
|
auto inner_val =
|
||||||
|
detail::ToImpl<T>::call(*sivp, extension_build_version, is_internal);
|
||||||
|
|
||||||
// free the memory associated with StableIValue* sivp
|
// free the memory associated with StableIValue* sivp
|
||||||
delete sivp;
|
delete sivp;
|
||||||
|
|
@ -305,7 +348,12 @@ struct ToImpl<std::optional<T>> {
|
||||||
// underlying AtenTensorHandle.
|
// underlying AtenTensorHandle.
|
||||||
template <>
|
template <>
|
||||||
struct ToImpl<torch::stable::Tensor> {
|
struct ToImpl<torch::stable::Tensor> {
|
||||||
static torch::stable::Tensor call(StableIValue val) {
|
static torch::stable::Tensor call(
|
||||||
|
StableIValue val,
|
||||||
|
uint64_t extension_build_version,
|
||||||
|
bool is_internal) {
|
||||||
|
(void)extension_build_version; // Unused parameter
|
||||||
|
(void)is_internal; // Unused parameter
|
||||||
return torch::stable::Tensor(to<AtenTensorHandle>(val));
|
return torch::stable::Tensor(to<AtenTensorHandle>(val));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -315,25 +363,60 @@ struct ToImpl<torch::stable::Tensor> {
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
// Expose the partially templated class functions through single functions
|
// Expose the partially templated class functions through single functions
|
||||||
|
// The non-private versions will be used by the extension or headers that
|
||||||
|
// the extension includes.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
inline StableIValue from(T val) {
|
inline StableIValue from(T val) {
|
||||||
return detail::FromImpl<T>::call(val);
|
return detail::FromImpl<T>::call(
|
||||||
|
val, aoti_torch_abi_version(), /*is_internal=*/false);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
inline StableIValue from(const std::optional<T>& val) {
|
inline StableIValue from(const std::optional<T>& val) {
|
||||||
return detail::FromImpl<std::optional<T>>::call(val);
|
return detail::FromImpl<std::optional<T>>::call(
|
||||||
|
val, aoti_torch_abi_version(), /*is_internal=*/false);
|
||||||
}
|
}
|
||||||
|
|
||||||
// The below overload is used! See https://godbolt.org/z/859cshxrW
|
// The below overload is used! See https://godbolt.org/z/859cshxrW
|
||||||
// We are suppressing the warning for versions clang12- and gcc11-
|
// We are suppressing the warning for versions clang12- and gcc11-
|
||||||
[[maybe_unused]] inline StableIValue from(const torch::stable::Tensor& val) {
|
[[maybe_unused]] inline StableIValue from(const torch::stable::Tensor& val) {
|
||||||
return detail::FromImpl<torch::stable::Tensor>::call(val);
|
return detail::FromImpl<torch::stable::Tensor>::call(
|
||||||
|
val, aoti_torch_abi_version(), /*is_internal=*/false);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
inline T to(StableIValue val) {
|
inline T to(StableIValue val) {
|
||||||
return detail::ToImpl<T>::call(val);
|
return detail::ToImpl<T>::call(
|
||||||
|
val, aoti_torch_abi_version(), /*is_internal=*/false);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Internal conversion functions used by from_ivalue and to_ivalue.
|
||||||
|
// These are used in libtorch
|
||||||
|
template <typename T>
|
||||||
|
inline StableIValue _from(T val, uint64_t extension_build_version) {
|
||||||
|
return detail::FromImpl<T>::call(
|
||||||
|
val, extension_build_version, /*is_internal=*/true);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline StableIValue _from(
|
||||||
|
const std::optional<T>& val,
|
||||||
|
uint64_t extension_build_version) {
|
||||||
|
return detail::FromImpl<std::optional<T>>::call(
|
||||||
|
val, extension_build_version, /*is_internal=*/true);
|
||||||
|
}
|
||||||
|
|
||||||
|
[[maybe_unused]] inline StableIValue _from(
|
||||||
|
const torch::stable::Tensor& val,
|
||||||
|
uint64_t extension_build_version) {
|
||||||
|
return detail::FromImpl<torch::stable::Tensor>::call(
|
||||||
|
val, extension_build_version, /*is_internal=*/true);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline T _to(StableIValue val, uint64_t extension_build_version) {
|
||||||
|
return detail::ToImpl<T>::call(
|
||||||
|
val, extension_build_version, /*is_internal=*/true);
|
||||||
}
|
}
|
||||||
|
|
||||||
HIDDEN_NAMESPACE_END(torch, stable, detail)
|
HIDDEN_NAMESPACE_END(torch, stable, detail)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user