Add ScalarType -> shim conversion, add stable::Tensor.scalar_type (#160557)

TL;DR: Move user extensions onto ScalarType and drop conversion support for deprecated dtypes.

This change _modifies_ the from/to behavior between ScalarType and StableIValue! Where user extensions previously could only pass dtypes around as opaque int32_ts, they can now use torch::headeronly::ScalarType directly for the major scalar types. This PR preserves ABI stability by adding a translation layer through the shim, so that even if the ScalarType enum values change in the future, user extensions are unaffected.
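
As a rough sketch of what that translation buys an extension author (using the from/to helpers and shim dtype functions added in this PR; the helper function name is illustrative):

```cpp
#include <torch/csrc/stable/stableivalue_conversions.h>
#include <torch/headeronly/core/ScalarType.h>

// Illustrative round trip: the packed value is the shim's stable dtype code
// (e.g. the result of aoti_torch_dtype_float32()), not the raw enum value.
void scalar_type_round_trip() {
  using torch::headeronly::ScalarType;
  StableIValue siv = from(ScalarType::Float);  // translate, then pack
  ScalarType st = to<ScalarType>(siv);         // unpack, then translate back
  // st == ScalarType::Float, regardless of how libtorch numbers its enum
}
```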

We then add a stable::Tensor scalar_type API that reuses the from/to logic to return a proper ScalarType to the user (rather than an opaque int32_t).
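
Concretely, in the pattern used by the docs' my_ones_like example (updated further down in this diff), the dtype argument can now be filled like this (a sketch; the helper name and stack slot are illustrative):

```cpp
#include <optional>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>

// Fills the dtype slot of a boxed-kernel stack for some op being dispatched.
void fill_dtype_arg(StableIValue* stack, const torch::stable::Tensor& t) {
  // Before this PR, the extension had to go through the shim and forward an
  // opaque int32_t:
  //   int32_t t_dtype;
  //   aoti_torch_get_dtype(t.get(), &t_dtype);
  //   stack[1] = from(std::optional(t_dtype));
  // Now the same slot takes a typed ScalarType straight off the Tensor:
  stack[1] = from(std::optional(t.scalar_type()));
}
```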

The existing test is updated to exercise the new scalar_type API.

This change required some refactoring to avoid circular header dependencies.

## BC Breaking note
This commit is BC-breaking, but only for unpopular dtypes (`quint*`s, `qint*`s, `Bits*`, `dummy_uint*`s, `dummy_int*`s, `Float8_e8m0fnu`, and `Float4_e2m1fn_x2`), and only in the narrow use case where an extension retrieves a Tensor dtype of one of the above and passes it into `aoti_torch_call_dispatcher`. As of now, I believe there are 0 users of this use case, so the benefits of this change justify breaking this narrow use of the API.
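
For reference, this is roughly the affected pattern (a sketch with illustrative names; the raw dtype now hits the shim-side translation added in this PR, which has no entry for these dtypes, so the call fails instead of silently forwarding the value):

```cpp
// Sketch (illustrative names): reading the raw shim dtype of, say, a quint8
// tensor and handing it straight to the dispatcher.
void call_with_raw_dtype(const torch::stable::Tensor& t /* quint8 tensor */) {
  int32_t raw_dtype;
  aoti_torch_get_dtype(t.get(), &raw_dtype);
  StableIValue stack[1];
  stack[0] = from(raw_dtype);  // previously forwarded to libtorch verbatim
  // With this PR, ScalarType arguments go through the shim's translation
  // layer, which has no mapping for quint8 and friends, so this now errors
  // instead of quietly passing the raw enum value through.
  aoti_torch_call_dispatcher("aten::some_op", "", stack);
}
```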

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160557
Approved by: https://github.com/mikaylagawarecki, https://github.com/malfet
Jane Xu 2025-08-19 11:55:10 -07:00 committed by PyTorch MergeBot
parent 05e8fac4f3
commit 8f766d6839
9 changed files with 552 additions and 388 deletions


@ -9,8 +9,9 @@ This note will eventually contain more details on how to use the APIs in torch/c
| type in custom extension | StableIValue representation | type in libtorch | Schema Type |
| -------- | ------- | ------- | ------- |
| std::optional\<S> | if there is a value, raw bitwise copy into leading bytes of uint64_t of pointer to a new StableIValue representing S. if there is no value, nullptr. | std::optional\<T> | Type? |
| RAIIATH | raw bitwise copy of underlying AtenTensorHandle into leading bytes of uint64_t | at::Tensor | Tensor |
| int32_t | raw bitwise copy into leading bytes of uint64_t | at::ScalarType | ScalarType |
| torch::stable::Tensor | raw bitwise copy of underlying AtenTensorHandle into leading bytes of uint64_t | at::Tensor | Tensor |
| RAIIATH (outdated) | raw bitwise copy of underlying AtenTensorHandle into leading bytes of uint64_t | at::Tensor | Tensor |
| torch::headeronly::ScalarType | raw bitwise copy of the translated underlying enum into leading bytes of uint64_t | torch::headeronly::ScalarType | ScalarType |
| int32_t | raw bitwise copy into leading bytes of uint64_t | at::Layout | Layout |
| int32_t | raw bitwise copy into leading bytes of uint64_t | at::MemoryFormat | MemoryFormat |
| bool | raw bitwise copy into leading bytes of uint64_t | bool | bool |
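
For the updated torch::stable::Tensor row, a minimal sketch of the ownership semantics described above (using the conversion helpers added in this PR; the function is illustrative):

```cpp
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>

void tensor_round_trip(const torch::stable::Tensor& t) {
  // from() packs a brand-new owning AtenTensorHandle for t ...
  StableIValue siv = from(t);
  // ... and to() steals that handle back, so nothing leaks and t and t2
  // refer to the same underlying tensor.
  torch::stable::Tensor t2 = to<torch::stable::Tensor>(siv);
  STD_TORCH_CHECK(t2.defined() == t.defined());
}
```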


@ -139,12 +139,10 @@ Tensor my_ones_like(Tensor t, StableIValue device) {
const auto num_args = 6;
StableIValue stack[num_args];
int32_t t_dtype;
aoti_torch_get_dtype(t.get(), &t_dtype);
auto mf = aoti_torch_memory_format_contiguous_format();
stack[0] = from(t);
stack[1] = from(std::optional(t_dtype)); // dtype
stack[1] = from(std::optional(t.scalar_type())); // dtype
stack[2] = from(std::nullopt); // layout
stack[3] = from(std::optional(device)); // device
stack[4] = from(std::optional(false)); // pin_memory


@ -1227,7 +1227,7 @@ class TestCppExtensionJIT(common.TestCase):
#include <torch/csrc/inductor/aoti_runtime/utils.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/stableivalue_conversions.h>
using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle;


@ -4,229 +4,16 @@
// code for better UX.
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/tensor.h>
#include <optional>
// Technically, this file doesn't use anything from stableivalue_conversions.h,
// but we need to include it here as the contents of stableivalue_conversions.h
// used to live here and so we need to expose them for backwards compatibility.
#include <torch/csrc/stable/stableivalue_conversions.h>
// use anonymous namespace to avoid collisions between differing
// versions of this file that may be included by different sources
namespace {
// =============================================================================
// helpers for converting between StableIValue and T
// =============================================================================
// forward declare so that from/to() calls in detail work
template <typename T>
StableIValue from(T val);
template <typename T>
T to(StableIValue val);
namespace detail {
// =============================================================================
// FROM CONVERSIONS (T -> StableIValue)
// =============================================================================
// Specialization for general copyable types (catch-all) => StableIValue
template <typename T>
struct FromImpl {
static StableIValue call(T val) {
static_assert(
sizeof(T) <= sizeof(StableIValue),
"StableLibrary stack does not support parameter types larger than 64 bits.");
static_assert(std::is_trivially_copyable_v<T>);
// Initialization should be cheap enough; let's give people well-specified
// reproducible behavior.
StableIValue result = 0;
// NOTE [ -Wclass-memaccess ]: reinterpret_cast to suppress
// overzealous -Wclass-memaccess. (see
// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=107361) We have a
// static_assert above that T is trivially copyable, which should be
// enough.
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
std::memcpy(&result, reinterpret_cast<const void*>(&val), sizeof(val));
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
// if value has size less than sizeof(StableIValue), then only lowest bytes
// have to be updated
std::memcpy(
reinterpret_cast<unsigned char*>(&result) + sizeof(StableIValue) -
sizeof(val),
reinterpret_cast<const void*>(&val),
sizeof(val));
#else
#error Unexpected or undefined __BYTE_ORDER__
#endif
return result;
}
};
// Specialization for std::nullopt_t => StableIValue
template <>
struct FromImpl<std::nullopt_t> {
static StableIValue call(std::nullopt_t val) {
return from(nullptr);
}
};
// Specialization for std::optional => StableIValue
// [Handling std::optional]
// When the schema is represented by an optional type, say int?, then we
// expect the custom extension representation to be a std::optional<int>
// (critically NOT int!). In order for all parameters to be stably parsed and
// handled by our dispatcher, we liaison custom extension parameters through
// boxed kernels, meaning that every value will make its way to be an IValue:
//
// custom extension value --(from)-> StableIValue --(to_ivalue)-> IValue
//
// When the custom extension value is a literal that can be trivially
// casted to StableIValue, e.g., an int, a float, a pointer, this route is
// ...trivial. The below specialization is for a case when the custom
// extension value would NOT fit within a StableIValue: a std::optional.
//
// If the std::optional has no value, it is treated as std::nullopt,
// whose StableIValue representation is from(nullptr). Otherwise, we:
// 1. unwrap the std::optional<T>
// 2. recursively convert its value of type T to a StableIValue
// 3. allocate heap space for said StableIValue
// 4. convert the resulting StableIValue* into a StableIValue
//
// note that this allocates heap memory! which we expect to be cleaned
// up in the to_ivalue() function defined in shim_common.cpp. We
// purposefully hide this implementation detail from the user so that
// all the user needs to know is:
//
// The schema requests an optional (T?) so I must call `from` on a
// std::optional<T> or a std::nullopt.
template <typename T>
struct FromImpl<std::optional<T>> {
static StableIValue call(const std::optional<T>& val) {
if (!val.has_value()) {
return from(std::nullopt);
}
StableIValue* heap_val = new StableIValue(from(val.value()));
return from(heap_val);
}
};
// Specialization for torch::stable::Tensor => StableIValue
// Returns a new owning reference of the underlying Tensor.
template <>
struct FromImpl<torch::stable::Tensor> {
static StableIValue call(const torch::stable::Tensor& val) {
AtenTensorHandle new_ath;
aoti_torch_new_tensor_handle(val.get(), &new_ath);
return from(new_ath);
}
};
// =============================================================================
// TO CONVERSIONS (StableIValue -> T)
// =============================================================================
// Specialization for StableIValue => general copyable types (catch-all)
template <typename T>
struct ToImpl {
static T call(StableIValue val) {
static_assert(std::is_trivially_copyable_v<T>);
// T may not have a default constructor. (For example, it might be
// c10::Device.) However, std::memcpy implicitly creates a T at the
// destination. So, we can use a union to work around this lack of
// default constructor.
union Result {
Result() {}
T t;
};
Result result;
// See NOTE[ -Wclass-memaccess ] above.
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
std::memcpy(reinterpret_cast<void*>(&result.t), &val, sizeof(result));
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
static_assert(
sizeof(T) <= sizeof(StableIValue),
"StableLibrary stack does not support parameter types larger than 64 bits.");
// if value has size less than sizeof(StableIValue), then only lowest bytes
// have to be updated
std::memcpy(
reinterpret_cast<void*>(&result.t),
reinterpret_cast<unsigned char*>(&val) + sizeof(StableIValue) -
sizeof(result),
sizeof(result));
#else
#error Unexpected or undefined __BYTE_ORDER__
#endif
return result.t;
}
};
// Specialization for StableIValue => std::nullopt_t
template <>
struct ToImpl<std::nullopt_t> {
static std::nullopt_t call(StableIValue val) {
// val should be equivalent to from(nullptr)
return std::nullopt;
}
};
// Specialization for StableIValue => std::optional, see [Handling
// std::optional] as the semantic is the same but in reverse direction as we go
// from IValue --(from_ivalue)-> StableIValue --(to<T>)-> T in custom extension
template <typename T>
struct ToImpl<std::optional<T>> {
static std::optional<T> call(StableIValue val) {
auto sivp = to<StableIValue*>(val);
// sivp is either nullptr or a pointer to a StableIValue
if (sivp == nullptr) {
return {};
}
auto inner_val = to<T>(*sivp);
// free the memory associated with StableIValue* sivp
delete sivp;
return std::make_optional(inner_val);
}
};
// Specialization for StableIValue => torch::stable::Tensor
// The resulting stable::Tensor steals ownership of the input's
// underlying AtenTensorHandle.
template <>
struct ToImpl<torch::stable::Tensor> {
static torch::stable::Tensor call(StableIValue val) {
return torch::stable::Tensor(to<AtenTensorHandle>(val));
}
};
} // namespace detail
// Expose the partially templated class functions through single functions
template <typename T>
StableIValue from(T val) {
return detail::FromImpl<T>::call(val);
}
template <typename T>
StableIValue from(const std::optional<T>& val) {
return detail::FromImpl<std::optional<T>>::call(val);
}
// The below overload is used! See https://godbolt.org/z/859cshxrW
// We are suppressing the warning for versions clang12- and gcc11-
[[maybe_unused]] StableIValue from(const torch::stable::Tensor& val) {
return detail::FromImpl<torch::stable::Tensor>::call(val);
}
template <typename T>
T to(StableIValue val) {
return detail::ToImpl<T>::call(val);
}
// =============================================================================
// end to helpers for converting between StableIValue and T
// =============================================================================
class StableLibrary final {
private:
TorchLibraryHandle lib_;


@ -1,6 +1,6 @@
#pragma once
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/stableivalue_conversions.h>
#include <array>
#include <cstdint>
#include <optional>


@ -0,0 +1,345 @@
#pragma once
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/tensor_struct.h>
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/util/Exception.h>
#include <torch/headeronly/util/shim_utils.h>
#include <optional>
// use anonymous namespace to avoid collisions between differing
// versions of this file that may be included by different sources
namespace {
// forward declare so that the from/to() implementations in the detail
// namespace of library.h where the real work is done can compile.
template <typename T>
StableIValue from(T val);
template <typename T>
T to(StableIValue val);
// =============================================================================
// helpers for converting between StableIValue and T
// =============================================================================
// note that the signatures for from and to are forward declared in
// stable/stableivalue_conversions.h but defined below to avoid circular
// dependencies where other headers (like tensor-inl.h) will need to/from.
namespace detail {
// =============================================================================
// FROM CONVERSIONS (T -> StableIValue)
// =============================================================================
// Specialization for general copyable types (catch-all) => StableIValue
template <typename T>
struct FromImpl {
static StableIValue call(T val) {
static_assert(
sizeof(T) <= sizeof(StableIValue),
"StableLibrary stack does not support parameter types larger than 64 bits.");
static_assert(std::is_trivially_copyable_v<T>);
// Initialization should be cheap enough; let's give people well-specified
// reproducible behavior.
StableIValue result = 0;
// NOTE [ -Wclass-memaccess ]: reinterpret_cast to suppress
// overzealous -Wclass-memaccess. (see
// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=107361) We have a
// static_assert above that T is trivially copyable, which should be
// enough.
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
std::memcpy(&result, reinterpret_cast<const void*>(&val), sizeof(val));
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
// if value has size less than sizeof(StableIValue), then only lowest bytes
// have to be updated
std::memcpy(
reinterpret_cast<unsigned char*>(&result) + sizeof(StableIValue) -
sizeof(val),
reinterpret_cast<const void*>(&val),
sizeof(val));
#else
#error "Unexpected or undefined __BYTE_ORDER__"
#endif
return result;
}
};
// Specialization for torch::headeronly::ScalarType => StableIValue
// Note that we call into the shim to translate between the user's
// ScalarType and libtorch's ScalarType, which can be different!
// Also note that the list below is not comprehensive, as it does not
// include types that are no longer really used and should probably be
// deprecated (like qint8).
using torch::headeronly::ScalarType;
template <>
struct FromImpl<ScalarType> {
static StableIValue call(ScalarType val) {
switch (val) {
case ScalarType::Byte:
return from(aoti_torch_dtype_uint8());
case ScalarType::Char:
return from(aoti_torch_dtype_int8());
case ScalarType::Short:
return from(aoti_torch_dtype_int16());
case ScalarType::Int:
return from(aoti_torch_dtype_int32());
case ScalarType::Long:
return from(aoti_torch_dtype_int64());
case ScalarType::Half:
return from(aoti_torch_dtype_float16());
case ScalarType::Float:
return from(aoti_torch_dtype_float32());
case ScalarType::Double:
return from(aoti_torch_dtype_float64());
case ScalarType::ComplexHalf:
return from(aoti_torch_dtype_complex32());
case ScalarType::ComplexFloat:
return from(aoti_torch_dtype_complex64());
case ScalarType::ComplexDouble:
return from(aoti_torch_dtype_complex128());
case ScalarType::Bool:
return from(aoti_torch_dtype_bool());
case ScalarType::BFloat16:
return from(aoti_torch_dtype_bfloat16());
case ScalarType::Float8_e5m2:
return from(aoti_torch_dtype_float8_e5m2());
case ScalarType::Float8_e4m3fn:
return from(aoti_torch_dtype_float8_e4m3fn());
case ScalarType::Float8_e5m2fnuz:
return from(aoti_torch_dtype_float8_e5m2fnuz());
case ScalarType::Float8_e4m3fnuz:
return from(aoti_torch_dtype_float8_e4m3fnuz());
case ScalarType::UInt16:
return from(aoti_torch_dtype_uint16());
case ScalarType::UInt32:
return from(aoti_torch_dtype_uint32());
case ScalarType::UInt64:
return from(aoti_torch_dtype_uint64());
default:
throw std::runtime_error(
"Not yet supported ScalarType, please file an issue describing your use case.");
}
}
};
// Specialization for std::nullopt_t => StableIValue
template <>
struct FromImpl<std::nullopt_t> {
static StableIValue call(std::nullopt_t val) {
return from(nullptr);
}
};
// Specialization for std::optional => StableIValue
// [Handling std::optional]
// When the schema is represented by an optional type, say int?, then we
// expect the custom extension representation to be a std::optional<int>
// (critically NOT int!). In order for all parameters to be stably parsed and
// handled by our dispatcher, we liaison custom extension parameters through
// boxed kernels, meaning that every value will make its way to be an IValue:
//
// custom extension value --(from)-> StableIValue --(to_ivalue)-> IValue
//
// When the custom extension value is a literal that can be trivially
// casted to StableIValue, e.g., an int, a float, a pointer, this route is
// ...trivial. The below specialization is for a case when the custom
// extension value would NOT fit within a StableIValue: a std::optional.
//
// If the std::optional has no value, it is treated as std::nullopt,
// whose StableIValue representation is from(nullptr). Otherwise, we:
// 1. unwrap the std::optional<T>
// 2. recursively convert its value of type T to a StableIValue
// 3. allocate heap space for said StableIValue
// 4. convert the resulting StableIValue* into a StableIValue
//
// note that this allocates heap memory! which we expect to be cleaned
// up in the to_ivalue() function defined in shim_common.cpp. We
// purposefully hide this implementation detail from the user so that
// all the user needs to know is:
//
// The schema requests an optional (T?) so I must call `from` on a
// std::optional<T> or a std::nullopt.
template <typename T>
struct FromImpl<std::optional<T>> {
static StableIValue call(const std::optional<T>& val) {
if (!val.has_value()) {
return from(std::nullopt);
}
return from(new StableIValue(from(val.value())));
}
};
// Specialization for torch::stable::Tensor => StableIValue
// Returns a new owning reference of the underlying Tensor.
template <>
struct FromImpl<torch::stable::Tensor> {
static StableIValue call(const torch::stable::Tensor& val) {
AtenTensorHandle new_ath;
TORCH_ERROR_CODE_CHECK(aoti_torch_new_tensor_handle(val.get(), &new_ath));
return from(new_ath);
}
};
// =============================================================================
// TO CONVERSIONS (StableIValue -> T)
// =============================================================================
// Specialization for StableIValue => general copyable types (catch-all)
template <typename T>
struct ToImpl {
static T call(StableIValue val) {
static_assert(std::is_trivially_copyable_v<T>);
// T may not have a default constructor. (For example, it might be
// c10::Device.) However, std::memcpy implicitly creates a T at the
// destination. So, we can use a union to work around this lack of
// default constructor.
union Result {
Result() {}
T t;
};
Result result;
// See NOTE[ -Wclass-memaccess ] above.
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
std::memcpy(reinterpret_cast<void*>(&result.t), &val, sizeof(result));
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
static_assert(
sizeof(T) <= sizeof(StableIValue),
"StableLibrary stack does not support parameter types larger than 64 bits.");
// if value has size less than sizeof(StableIValue), then only lowest bytes
// have to be updated
std::memcpy(
reinterpret_cast<void*>(&result.t),
reinterpret_cast<unsigned char*>(&val) + sizeof(StableIValue) -
sizeof(result),
sizeof(result));
#else
#error "Unexpected or undefined __BYTE_ORDER__"
#endif
return result.t;
}
};
// Specialization for StableIValue => torch::headeronly::ScalarType
template <>
struct ToImpl<ScalarType> {
static ScalarType call(StableIValue val) {
int32_t shim_scalartype = to<int32_t>(val);
if (shim_scalartype == aoti_torch_dtype_uint8()) {
return ScalarType::Byte;
} else if (shim_scalartype == aoti_torch_dtype_int8()) {
return ScalarType::Char;
} else if (shim_scalartype == aoti_torch_dtype_int16()) {
return ScalarType::Short;
} else if (shim_scalartype == aoti_torch_dtype_int32()) {
return ScalarType::Int;
} else if (shim_scalartype == aoti_torch_dtype_int64()) {
return ScalarType::Long;
} else if (shim_scalartype == aoti_torch_dtype_float16()) {
return ScalarType::Half;
} else if (shim_scalartype == aoti_torch_dtype_float32()) {
return ScalarType::Float;
} else if (shim_scalartype == aoti_torch_dtype_float64()) {
return ScalarType::Double;
} else if (shim_scalartype == aoti_torch_dtype_complex32()) {
return ScalarType::ComplexHalf;
} else if (shim_scalartype == aoti_torch_dtype_complex64()) {
return ScalarType::ComplexFloat;
} else if (shim_scalartype == aoti_torch_dtype_complex128()) {
return ScalarType::ComplexDouble;
} else if (shim_scalartype == aoti_torch_dtype_bool()) {
return ScalarType::Bool;
} else if (shim_scalartype == aoti_torch_dtype_bfloat16()) {
return ScalarType::BFloat16;
} else if (shim_scalartype == aoti_torch_dtype_float8_e5m2()) {
return ScalarType::Float8_e5m2;
} else if (shim_scalartype == aoti_torch_dtype_float8_e4m3fn()) {
return ScalarType::Float8_e4m3fn;
} else if (shim_scalartype == aoti_torch_dtype_float8_e5m2fnuz()) {
return ScalarType::Float8_e5m2fnuz;
} else if (shim_scalartype == aoti_torch_dtype_float8_e4m3fnuz()) {
return ScalarType::Float8_e4m3fnuz;
} else if (shim_scalartype == aoti_torch_dtype_uint16()) {
return ScalarType::UInt16;
} else if (shim_scalartype == aoti_torch_dtype_uint32()) {
return ScalarType::UInt32;
} else if (shim_scalartype == aoti_torch_dtype_uint64()) {
return ScalarType::UInt64;
} else {
throw std::runtime_error(
"Not yet supported ScalarType " + std::to_string(shim_scalartype) +
", please file an issue describing your use case.");
}
}
};
// Specialization for StableIValue => std::nullopt_t
template <>
struct ToImpl<std::nullopt_t> {
static std::nullopt_t call(StableIValue val) {
// val should be equivalent to from(nullptr)
return std::nullopt;
}
};
// Specialization for StableIValue => std::optional, see [Handling
// std::optional] as the semantic is the same but in reverse direction as we go
// from IValue --(from_ivalue)-> StableIValue --(to<T>)-> T in custom extension
template <typename T>
struct ToImpl<std::optional<T>> {
static std::optional<T> call(StableIValue val) {
auto sivp = to<StableIValue*>(val);
// sivp is either nullptr or a pointer to a StableIValue
if (sivp == nullptr) {
return {};
}
auto inner_val = to<T>(*sivp);
// free the memory associated with StableIValue* sivp
delete sivp;
return std::make_optional(inner_val);
}
};
// Specialization for StableIValue => torch::stable::Tensor
// The resulting stable::Tensor steals ownership of the input's
// underlying AtenTensorHandle.
template <>
struct ToImpl<torch::stable::Tensor> {
static torch::stable::Tensor call(StableIValue val) {
return torch::stable::Tensor(to<AtenTensorHandle>(val));
}
};
} // namespace detail
// Expose the partially templated class functions through single functions
template <typename T>
StableIValue from(T val) {
return detail::FromImpl<T>::call(val);
}
template <typename T>
StableIValue from(const std::optional<T>& val) {
return detail::FromImpl<std::optional<T>>::call(val);
}
// The below overload is used! See https://godbolt.org/z/859cshxrW
// We are suppressing the warning for versions clang12- and gcc11-
[[maybe_unused]] StableIValue from(const torch::stable::Tensor& val) {
return detail::FromImpl<torch::stable::Tensor>::call(val);
}
template <typename T>
T to(StableIValue val) {
return detail::ToImpl<T>::call(val);
}
// =============================================================================
// end to helpers for converting between StableIValue and T
// =============================================================================
} // namespace
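
To make the optional-handling comments above concrete, a round trip through these helpers looks roughly like this (a sketch; the function name is illustrative):

```cpp
#include <cstdint>
#include <optional>
#include <torch/csrc/stable/stableivalue_conversions.h>

void optional_round_trip() {
  // from() on an engaged optional heap-allocates one StableIValue and packs
  // the pointer; from(std::nullopt) packs a nullptr instead.
  StableIValue packed = from(std::optional<int64_t>(42));
  // to<std::optional<T>>() unpacks the pointer, converts the inner value, and
  // deletes the heap allocation, so the round trip does not leak.
  std::optional<int64_t> unpacked = to<std::optional<int64_t>>(packed);
  STD_TORCH_CHECK(unpacked.has_value() && unpacked.value() == 42);
}
```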


@ -1,166 +1,4 @@
#pragma once
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/headeronly/util/Exception.h>
#include <torch/headeronly/util/shim_utils.h>
#include <climits>
#include <memory>
#include <torch/csrc/stable/accelerator.h>
namespace torch::stable {
using DeviceIndex = torch::stable::accelerator::DeviceIndex;
// The torch::stable::Tensor class is a highlevel C++ wrapper around
// the C shim Tensor APIs. We've modeled this class after TensorBase, as custom
// op kernels only really need to interact with Tensor metadata (think sizes,
// strides, device, dtype). Other functions on Tensor (like empty_like) should
// live like the ATen op that they are and exist outside of this struct.
//
// There are several goals of this class over AtenTensorHandle and
// RAIIAtenTensorHandle:
// 1. torch::stable::Tensor is a nicer UX much closer to torch::Tensor than the
// C APIs with AtenTensorHandle. Under the hood we still call to these C shim
// APIs to preserve stability.
// 2. RAIIAtenTensorHandle boils down to a uniq_ptr that forces the user to pass
// around ownership. This makes it difficult to pass one input into 2
// different functions, e.g., doing something like c = a(t) + b(t) for
// stable::Tensor t. Thus, we use a shared_ptr here.
class Tensor {
private:
std::shared_ptr<AtenTensorOpaque> ath_;
public:
// Construct a stable::Tensor with an uninitialized AtenTensorHandle (ATH)
// Steals ownership from the ATH
Tensor() {
AtenTensorHandle ret;
TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&ret));
ath_ = std::shared_ptr<AtenTensorOpaque>(ret, [](AtenTensorHandle ath) {
TORCH_ERROR_CODE_CHECK(aoti_torch_delete_tensor_object(ath));
});
}
// Construct a stable::Tensor from an AtenTensorHandle (ATH)
// Steals ownership from the ATH
explicit Tensor(AtenTensorHandle ath)
: ath_(ath, [](AtenTensorHandle ath) {
TORCH_ERROR_CODE_CHECK(aoti_torch_delete_tensor_object(ath));
}) {}
// Copy and move constructors can be default cuz the underlying handle is a
// shared_ptr
Tensor(const Tensor& other) = default;
Tensor(Tensor&& other) noexcept = default;
// Copy and move assignment operators can be default cuz the underlying handle
// is a shared_ptr
Tensor& operator=(const Tensor& other) = default;
Tensor& operator=(Tensor&& other) noexcept = default;
// Destructor can be default: shared ptr has custom deletion logic
~Tensor() = default;
// Returns a borrowed reference to the AtenTensorHandle
AtenTensorHandle get() const {
return ath_.get();
}
// =============================================================================
// C-shimified TensorBase APIs: the below APIs have the same signatures and
// semantics as their counterparts in TensorBase.h.
// =============================================================================
void* data_ptr() const {
void* data_ptr;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(ath_.get(), &data_ptr));
return data_ptr;
}
int64_t dim() const {
int64_t dim;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_dim(ath_.get(), &dim));
return dim;
}
int64_t numel() const {
int64_t numel;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_numel(ath_.get(), &numel));
return numel;
}
// note: this is a subset of the original TensorBase API. It takes no
// arguments whereas the original API takes in a kwarg of memory format.
// Here, we assume the default contiguous memory format.
bool is_contiguous() const {
bool is_contiguous;
TORCH_ERROR_CODE_CHECK(
aoti_torch_is_contiguous(ath_.get(), &is_contiguous));
return is_contiguous;
}
int64_t stride(int64_t dim) const {
int64_t stride;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_stride(ath_.get(), dim, &stride));
return stride;
}
// This is almost the same API as the one in TensorBase.h, except
// we add a check that the returned device_index is within the
// range of int8_t.
int8_t get_device() const {
int32_t device_index;
TORCH_ERROR_CODE_CHECK(
aoti_torch_get_device_index(ath_.get(), &device_index));
STD_TORCH_CHECK(
device_index >= std::numeric_limits<int8_t>::min() &&
device_index <= std::numeric_limits<int8_t>::max(),
"Device index is out of range of return type int8_t, please use get_device_index() instead.");
return static_cast<int8_t>(device_index);
}
// The same as get_device but with two differences:
// 1. it has a more suiting name
// 2. it returns a DeviceIndex, which is int32_t in this world
// that should be more stable than the likely shifting
// DeviceIndex in libtorch (it is int8_t that might become int16_t)
DeviceIndex get_device_index() const {
int32_t device_index;
TORCH_ERROR_CODE_CHECK(
aoti_torch_get_device_index(ath_.get(), &device_index));
return device_index;
}
bool is_cuda() const {
int32_t device_type;
TORCH_ERROR_CODE_CHECK(
aoti_torch_get_device_type(ath_.get(), &device_type));
return device_type == aoti_torch_device_type_cuda();
}
bool is_cpu() const {
int32_t device_type;
TORCH_ERROR_CODE_CHECK(
aoti_torch_get_device_type(ath_.get(), &device_type));
return device_type == aoti_torch_device_type_cpu();
}
int64_t size(int64_t dim) const {
int64_t size;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_size(ath_.get(), dim, &size));
return size;
}
bool defined() const {
bool defined;
TORCH_ERROR_CODE_CHECK(aoti_torch_is_defined(ath_.get(), &defined));
return defined;
}
// =============================================================================
// END of C-shimified TensorBase APIs
// =============================================================================
};
} // namespace torch::stable
#include <torch/csrc/stable/tensor_inl.h>
#include <torch/csrc/stable/tensor_struct.h>


@ -0,0 +1,24 @@
#pragma once
// This file implements tensor.h. We separated out the Tensor struct so that
// other files can depend on the Tensor struct (like library.h) and the
// implementations of the Tensor methods can depend on APIs in library.h
// without circular dependencies.
#include <torch/csrc/stable/stableivalue_conversions.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/util/shim_utils.h>
namespace torch::stable {
using torch::headeronly::ScalarType;
ScalarType Tensor::scalar_type() const {
int32_t dtype;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(ath_.get(), &dtype));
return to<ScalarType>(from(dtype));
}
} // namespace torch::stable


@ -0,0 +1,171 @@
#pragma once
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/util/Exception.h>
#include <torch/headeronly/util/shim_utils.h>
#include <climits>
#include <memory>
#include <torch/csrc/stable/accelerator.h>
namespace torch::stable {
using accelerator::DeviceIndex;
using torch::headeronly::ScalarType;
// The torch::stable::Tensor class is a highlevel C++ wrapper around
// the C shim Tensor APIs. We've modeled this class after TensorBase, as custom
// op kernels only really need to interact with Tensor metadata (think sizes,
// strides, device, dtype). Other functions on Tensor (like empty_like) should
// live like the ATen op that they are and exist outside of this struct.
//
// There are several goals of this class over AtenTensorHandle and
// RAIIAtenTensorHandle:
// 1. torch::stable::Tensor is a nicer UX much closer to torch::Tensor than the
// C APIs with AtenTensorHandle. Under the hood we still call to these C shim
// APIs to preserve stability.
// 2. RAIIAtenTensorHandle boils down to a uniq_ptr that forces the user to pass
// around ownership. This makes it difficult to pass one input into 2
// different functions, e.g., doing something like c = a(t) + b(t) for
// stable::Tensor t. Thus, we use a shared_ptr here.
class Tensor {
private:
std::shared_ptr<AtenTensorOpaque> ath_;
public:
// Construct a stable::Tensor with an uninitialized AtenTensorHandle (ATH)
// Steals ownership from the ATH
Tensor() {
AtenTensorHandle ret;
TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&ret));
ath_ = std::shared_ptr<AtenTensorOpaque>(ret, [](AtenTensorHandle ath) {
TORCH_ERROR_CODE_CHECK(aoti_torch_delete_tensor_object(ath));
});
}
// Construct a stable::Tensor from an AtenTensorHandle (ATH)
// Steals ownership from the ATH
explicit Tensor(AtenTensorHandle ath)
: ath_(ath, [](AtenTensorHandle ath) {
TORCH_ERROR_CODE_CHECK(aoti_torch_delete_tensor_object(ath));
}) {}
// Copy and move constructors can be default cuz the underlying handle is a
// shared_ptr
Tensor(const Tensor& other) = default;
Tensor(Tensor&& other) noexcept = default;
// Copy and move assignment operators can be default cuz the underlying handle
// is a shared_ptr
Tensor& operator=(const Tensor& other) = default;
Tensor& operator=(Tensor&& other) noexcept = default;
// Destructor can be default: shared ptr has custom deletion logic
~Tensor() = default;
// Returns a borrowed reference to the AtenTensorHandle
AtenTensorHandle get() const {
return ath_.get();
}
// =============================================================================
// C-shimified TensorBase APIs: the below APIs have the same signatures and
// semantics as their counterparts in TensorBase.h.
// =============================================================================
void* data_ptr() const {
void* data_ptr;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(ath_.get(), &data_ptr));
return data_ptr;
}
int64_t dim() const {
int64_t dim;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_dim(ath_.get(), &dim));
return dim;
}
int64_t numel() const {
int64_t numel;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_numel(ath_.get(), &numel));
return numel;
}
// note: this is a subset of the original TensorBase API. It takes no
// arguments whereas the original API takes in a kwarg of memory format.
// Here, we assume the default contiguous memory format.
bool is_contiguous() const {
bool is_contiguous;
TORCH_ERROR_CODE_CHECK(
aoti_torch_is_contiguous(ath_.get(), &is_contiguous));
return is_contiguous;
}
int64_t stride(int64_t dim) const {
int64_t stride;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_stride(ath_.get(), dim, &stride));
return stride;
}
// This is almost the same API as the one in TensorBase.h, except
// we add a check that the returned device_index is within the
// range of int8_t.
int8_t get_device() const {
int32_t device_index;
TORCH_ERROR_CODE_CHECK(
aoti_torch_get_device_index(ath_.get(), &device_index));
STD_TORCH_CHECK(
device_index >= std::numeric_limits<int8_t>::min() &&
device_index <= std::numeric_limits<int8_t>::max(),
"Device index is out of range of return type int8_t, please use get_device_index() instead.");
return static_cast<int8_t>(device_index);
}
// The same as get_device but with two differences:
// 1. it has a more suiting name
// 2. it returns a DeviceIndex, which is int32_t in this world
// that should be more stable than the likely shifting
// DeviceIndex in libtorch (it is int8_t that might become int16_t)
DeviceIndex get_device_index() const {
int32_t device_index;
TORCH_ERROR_CODE_CHECK(
aoti_torch_get_device_index(ath_.get(), &device_index));
return device_index;
}
bool is_cuda() const {
int32_t device_type;
TORCH_ERROR_CODE_CHECK(
aoti_torch_get_device_type(ath_.get(), &device_type));
return device_type == aoti_torch_device_type_cuda();
}
bool is_cpu() const {
int32_t device_type;
TORCH_ERROR_CODE_CHECK(
aoti_torch_get_device_type(ath_.get(), &device_type));
return device_type == aoti_torch_device_type_cpu();
}
int64_t size(int64_t dim) const {
int64_t size;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_size(ath_.get(), dim, &size));
return size;
}
bool defined() const {
bool defined;
TORCH_ERROR_CODE_CHECK(aoti_torch_is_defined(ath_.get(), &defined));
return defined;
}
// defined in tensor-inl.h to avoid circular dependencies
ScalarType scalar_type() const;
// =============================================================================
// END of C-shimified TensorBase APIs
// =============================================================================
};
} // namespace torch::stable
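
To illustrate the shared-ownership design called out in the class comment, a small sketch (helper names are illustrative):

```cpp
#include <cstdint>
#include <torch/csrc/stable/tensor_struct.h>

int64_t total_elems(torch::stable::Tensor t) {  // takes a copy on purpose
  return t.numel();
}

void use_twice(const torch::stable::Tensor& t) {
  // Copying a stable::Tensor only bumps the shared_ptr refcount; both calls
  // observe the same underlying AtenTensorHandle, so there is no ownership
  // hand-off to manage (unlike with RAIIAtenTensorHandle).
  int64_t a = total_elems(t);
  int64_t b = total_elems(t);
  STD_TORCH_CHECK(a == b);
}
```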