Add ScalarType -> shim conversion, add stable::Tensor.scalar_type (#160557)

TL;DR: We move user extensions to ScalarType and remove deprecated dtypes. This change _modifies_ the from/to behavior between ScalarType and StableIValue! Whereas before, user extensions could only pass around obfuscated dtypes appearing as int32_ts, users can now confidently use torch::headeronly::ScalarType in their extensions for the major scalar types. This PR enables ABI stability by adding a translation layer through the shim, so that even if the ScalarType enum values change in the future, user extensions need not fear.

We then add a Tensor::scalar_type API, which reuses the from/to logic to return the user a proper ScalarType (vs. an abstracted int32_t), and change the test to exercise the new scalar_type API. This change required some refactoring to break circular dependencies.

## BC-breaking note

This commit is (narrowly) BC-breaking for unpopular dtypes: the `quint*`s, `qint*`s, `Bits*`, `dummy_uint*`s, `dummy_int*`s, `Float8_e8m0fnu`, and `Float4_e2m1fn_x2`, in the narrow use case where an extension retrieves a Tensor dtype of one of the above and passes it into `aoti_torch_call_dispatcher`. As of now, we believe there are 0 users of this use case, so the benefits of this change significantly justify breaking this API.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160557
Approved by: https://github.com/mikaylagawarecki, https://github.com/malfet
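Concretely, extension kernels can now branch on a real enum instead of comparing raw int32_t dtype codes against shim calls. A minimal sketch of the new surface (my_kernel is a hypothetical kernel; Tensor::scalar_type and torch::headeronly::ScalarType are what this PR adds):

#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/ScalarType.h>

using torch::headeronly::ScalarType;

// Hypothetical kernel: dtype dispatch reads as an enum comparison instead of
// comparing int32_t codes against aoti_torch_dtype_*() calls.
void my_kernel(const torch::stable::Tensor& t) {
  ScalarType st = t.scalar_type(); // the new API added by this PR
  if (st == ScalarType::Float) {
    // ... float path ...
  } else if (st == ScalarType::BFloat16) {
    // ... bfloat16 path ...
  }
}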
This commit is contained in: parent 05e8fac4f3, commit 8f766d6839
@@ -9,8 +9,9 @@ This note will eventually contain more details on how to use the APIs in torch/c
 | type in custom extension | StableIValue representation | type in libtorch | Schema Type |
 | -------- | ------- | ------- | ------- |
 | std::optional\<S> | if there is a value, raw bitwise copy into leading bytes of uint64_t of pointer to a new StableIValue representing S. if there is no value, nullptr. | std::optional\<T> | Type? |
-| RAIIATH | raw bitwise copy of underlying AtenTensorHandle into leading bytes of uint64_t | at::Tensor | Tensor |
-| int32_t | raw bitwise copy into leading bytes of uint64_t | at::ScalarType | ScalarType |
+| torch::stable::Tensor | raw bitwise copy of underlying AtenTensorHandle into leading bytes of uint64_t | at::Tensor | Tensor |
+| RAIIATH (outdated) | raw bitwise copy of underlying AtenTensorHandle into leading bytes of uint64_t | at::Tensor | Tensor |
+| torch::headeronly::ScalarType | raw bitwise copy of the translated underlying enum into leading bytes of uint64_t | torch::headeronly::ScalarType | ScalarType |
 | int32_t | raw bitwise copy into leading bytes of uint64_t | at::Layout | Layout |
 | int32_t | raw bitwise copy into leading bytes of uint64_t | at::MemoryFormat | MemoryFormat |
 | bool | raw bitwise copy into leading bytes of uint64_t | bool | bool |
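The representations above compose mechanically. As a worked illustration of the std::optional row, here is a minimal hedged sketch using the from/to helpers this PR houses in torch/csrc/stable/stableivalue_conversions.h (optional_round_trip is a hypothetical function name):

#include <torch/csrc/stable/stableivalue_conversions.h>
#include <cstdint>
#include <optional>

void optional_round_trip() {
  // Per the table: a populated optional packs a pointer to a heap-allocated
  // StableIValue holding the payload; std::nullopt packs a nullptr instead.
  std::optional<int64_t> opt = 7;
  StableIValue siv = from(opt);
  // to<> unpacks the payload and frees the heap slot allocated by from().
  std::optional<int64_t> back = to<std::optional<int64_t>>(siv);
  STD_TORCH_CHECK(back.has_value() && *back == 7);
}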
@@ -139,12 +139,10 @@ Tensor my_ones_like(Tensor t, StableIValue device) {
   const auto num_args = 6;
   StableIValue stack[num_args];
 
-  int32_t t_dtype;
-  aoti_torch_get_dtype(t.get(), &t_dtype);
   auto mf = aoti_torch_memory_format_contiguous_format();
 
   stack[0] = from(t);
-  stack[1] = from(std::optional(t_dtype)); // dtype
+  stack[1] = from(std::optional(t.scalar_type())); // dtype
   stack[2] = from(std::nullopt); // layout
   stack[3] = from(std::optional(device)); // device
   stack[4] = from(std::optional(false)); // pin_memory
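For context, the tail of this function (outside the hunk) boxes the stack through the ABI-stable dispatcher. A hedged reconstruction of how it likely proceeds, not part of this diff:

  stack[5] = from(std::optional(mf)); // memory_format

  // Dispatch through the ABI-stable shim; the boxed kernel consumes the six
  // StableIValues and writes its result back into stack[0].
  aoti_torch_call_dispatcher("aten::ones_like", "", stack);
  return to<Tensor>(stack[0]);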
@@ -1227,7 +1227,7 @@ class TestCppExtensionJIT(common.TestCase):
#include <torch/csrc/inductor/aoti_runtime/utils.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/stableivalue_conversions.h>

using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle;
@@ -4,229 +4,16 @@
// code for better UX.

#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/tensor.h>

#include <optional>
// Technically, this file doesn't use anything from stableivalue_conversions.h,
// but we need to include it here as the contents of stableivalue_conversions.h
// used to live here and so we need to expose them for backwards compatibility.
#include <torch/csrc/stable/stableivalue_conversions.h>

// use anonymous namespace to avoid collisions between differing
// versions of this file that may be included by different sources
namespace {

// =============================================================================
// helpers for converting between StableIValue and T
// =============================================================================

// forward declare so that from/to() calls in detail work
template <typename T>
StableIValue from(T val);
template <typename T>
T to(StableIValue val);

namespace detail {

// =============================================================================
// FROM CONVERSIONS (T -> StableIValue)
// =============================================================================

// Specialization for general copyable types (catch-all) => StableIValue
template <typename T>
struct FromImpl {
  static StableIValue call(T val) {
    static_assert(
        sizeof(T) <= sizeof(StableIValue),
        "StableLibrary stack does not support parameter types larger than 64 bits.");
    static_assert(std::is_trivially_copyable_v<T>);
    // Initialization should be cheap enough; let's give people well-specified
    // reproducible behavior.
    StableIValue result = 0;
    // NOTE [ -Wclass-memaccess ]: reinterpret_cast to suppress
    // overzealous -Wclass-memaccess. (see
    // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=107361) We have a
    // static_assert above that T is trivially copyable, which should be
    // enough.
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
    std::memcpy(&result, reinterpret_cast<const void*>(&val), sizeof(val));
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
    // if value has size less than sizeof(StableIValue), then only lowest bytes
    // have to be updated
    std::memcpy(
        reinterpret_cast<unsigned char*>(&result) + sizeof(StableIValue) -
            sizeof(val),
        reinterpret_cast<const void*>(&val),
        sizeof(val));
#else
#error Unexpected or undefined __BYTE_ORDER__
#endif
    return result;
  }
};

// Specialization for std::nullopt_t => StableIValue
template <>
struct FromImpl<std::nullopt_t> {
  static StableIValue call(std::nullopt_t val) {
    return from(nullptr);
  }
};

// Specialization for std::optional => StableIValue
// [Handling std::optional]
// When the schema is represented by an optional type, say int?, then we
// expect the custom extension representation to be a std::optional<int>
// (critically NOT int!). In order for all parameters to be stably parsed and
// handled by our dispatcher, we liaison custom extension parameters through
// boxed kernels, meaning that every value will make its way to be an IValue:
//
// custom extension value --(from)-> StableIValue --(to_ivalue)-> IValue
//
// When the custom extension value is a literal that can be trivially
// casted to StableIValue, e.g., an int, a float, a pointer, this route is
// ...trivial. The below specialization is for a case when the custom
// extension value would NOT fit within a StableIValue: a std::optional.
//
// If the std::optional has no value, it is treated as std::nullopt,
// whose StableIValue representation is from(nullptr). Otherwise, we:
// 1. unwrap the std::optional<T>
// 2. recursively convert its value of type T to a StableIValue
// 3. allocate heap space for said StableIValue
// 4. convert the resulting StableIValue* into a StableIValue
//
// note that this allocates heap memory! which we expect to be cleaned
// up in the to_ivalue() function defined in shim_common.cpp. We
// purposefully hide this implementation detail from the user so that
// all the user needs to know is:
//
// The schema requests an optional (T?) so I must call `from` on a
// std::optional<T> or a std::nullopt.
template <typename T>
struct FromImpl<std::optional<T>> {
  static StableIValue call(const std::optional<T>& val) {
    if (!val.has_value()) {
      return from(std::nullopt);
    }
    StableIValue* heap_val = new StableIValue(from(val.value()));
    return from(heap_val);
  }
};

// Specialization for torch::stable::Tensor => StableIValue
// Returns a new owning reference of the underlying Tensor.
template <>
struct FromImpl<torch::stable::Tensor> {
  static StableIValue call(const torch::stable::Tensor& val) {
    AtenTensorHandle new_ath;
    aoti_torch_new_tensor_handle(val.get(), &new_ath);
    return from(new_ath);
  }
};

// =============================================================================
// TO CONVERSIONS (StableIValue -> T)
// =============================================================================

// Specialization for StableIValue => general copyable types (catch-all)
template <typename T>
struct ToImpl {
  static T call(StableIValue val) {
    static_assert(std::is_trivially_copyable_v<T>);
    // T may not have a default constructor. (For example, it might be
    // c10::Device.) However, std::memcpy implicitly creates a T at the
    // destination. So, we can use a union to work around this lack of
    // default constructor.
    union Result {
      Result() {}
      T t;
    };
    Result result;
    // See NOTE[ -Wclass-memaccess ] above.
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
    std::memcpy(reinterpret_cast<void*>(&result.t), &val, sizeof(result));
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
    static_assert(
        sizeof(T) <= sizeof(StableIValue),
        "StableLibrary stack does not support parameter types larger than 64 bits.");
    // if value has size less than sizeof(StableIValue), then only lowest bytes
    // have to be updated
    std::memcpy(
        reinterpret_cast<void*>(&result.t),
        reinterpret_cast<unsigned char*>(&val) + sizeof(StableIValue) -
            sizeof(result),
        sizeof(result));
#else
#error Unexpected or undefined __BYTE_ORDER__
#endif
    return result.t;
  }
};

// Specialization for StableIValue => std::nullopt_t
template <>
struct ToImpl<std::nullopt_t> {
  static std::nullopt_t call(StableIValue val) {
    // val should be equivalent to from(nullptr)
    return std::nullopt;
  }
};

// Specialization for StableIValue => std::optional, see [Handling
// std::optional] as the semantic is the same but in reverse direction as we go
// from IValue --(from_ivalue)-> StableIValue --(to<T>)-> T in custom extension
template <typename T>
struct ToImpl<std::optional<T>> {
  static std::optional<T> call(StableIValue val) {
    auto sivp = to<StableIValue*>(val);

    // sivp is either nullptr or a pointer to a StableIValue
    if (sivp == nullptr) {
      return {};
    }
    auto inner_val = to<T>(*sivp);

    // free the memory associated with StableIValue* sivp
    delete sivp;

    return std::make_optional(inner_val);
  }
};

// Specialization for StableIValue => torch::stable::Tensor
// The resulting stable::Tensor steals ownership of the input's
// underlying AtenTensorHandle.
template <>
struct ToImpl<torch::stable::Tensor> {
  static torch::stable::Tensor call(StableIValue val) {
    return torch::stable::Tensor(to<AtenTensorHandle>(val));
  }
};

} // namespace detail

// Expose the partially templated class functions through single functions
template <typename T>
StableIValue from(T val) {
  return detail::FromImpl<T>::call(val);
}

template <typename T>
StableIValue from(const std::optional<T>& val) {
  return detail::FromImpl<std::optional<T>>::call(val);
}

// The below overload is used! See https://godbolt.org/z/859cshxrW
// We are suppressing the warning for versions clang12- and gcc11-
[[maybe_unused]] StableIValue from(const torch::stable::Tensor& val) {
  return detail::FromImpl<torch::stable::Tensor>::call(val);
}

template <typename T>
T to(StableIValue val) {
  return detail::ToImpl<T>::call(val);
}

// =============================================================================
// end to helpers for converting between StableIValue and T
// =============================================================================

class StableLibrary final {
 private:
  TorchLibraryHandle lib_;
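The union trick in ToImpl above is subtle enough to merit an isolated illustration. A self-contained sketch (plain C++, not PyTorch code; Tagged and bits_to are hypothetical stand-ins) of why the empty-constructor union allows memcpy-ing into a type with no default constructor:

#include <cstdint>
#include <cstring>

// A trivially copyable type with no default constructor, standing in for
// something like c10::Device.
struct Tagged {
  explicit Tagged(int8_t v) : tag(v) {}
  int8_t tag;
};

template <typename T>
T bits_to(uint64_t val) {
  // Result() {} leaves `t` uninitialized, so Result is constructible even
  // though T is not default-constructible; memcpy then fills in t's bytes.
  union Result {
    Result() {}
    T t;
  };
  Result result;
  std::memcpy(&result.t, &val, sizeof(result.t));
  return result.t;
}

// usage: Tagged d = bits_to<Tagged>(0x2a); // d.tag == 42 on little endian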
@@ -1,6 +1,6 @@
#pragma once

#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/stableivalue_conversions.h>
#include <array>
#include <cstdint>
#include <optional>
torch/csrc/stable/stableivalue_conversions.h (new file, 345 lines)
@@ -0,0 +1,345 @@
#pragma once

#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/tensor_struct.h>
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/util/Exception.h>
#include <torch/headeronly/util/shim_utils.h>

#include <optional>

// use anonymous namespace to avoid collisions between differing
// versions of this file that may be included by different sources
namespace {

// forward declare so that the from/to() implementations in the detail
// namespace of library.h where the real work is done can compile.
template <typename T>
StableIValue from(T val);
template <typename T>
T to(StableIValue val);

// =============================================================================
// helpers for converting between StableIValue and T
// =============================================================================

// note that the signatures for from and to are forward declared in
// stable/stableivalue_conversions.h but defined below to avoid circular
// dependencies where other headers (like tensor_inl.h) will need to/from.

namespace detail {

// =============================================================================
// FROM CONVERSIONS (T -> StableIValue)
// =============================================================================

// Specialization for general copyable types (catch-all) => StableIValue
template <typename T>
struct FromImpl {
  static StableIValue call(T val) {
    static_assert(
        sizeof(T) <= sizeof(StableIValue),
        "StableLibrary stack does not support parameter types larger than 64 bits.");
    static_assert(std::is_trivially_copyable_v<T>);
    // Initialization should be cheap enough; let's give people well-specified
    // reproducible behavior.
    StableIValue result = 0;
    // NOTE [ -Wclass-memaccess ]: reinterpret_cast to suppress
    // overzealous -Wclass-memaccess. (see
    // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=107361) We have a
    // static_assert above that T is trivially copyable, which should be
    // enough.
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
    std::memcpy(&result, reinterpret_cast<const void*>(&val), sizeof(val));
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
    // if value has size less than sizeof(StableIValue), then only lowest bytes
    // have to be updated
    std::memcpy(
        reinterpret_cast<unsigned char*>(&result) + sizeof(StableIValue) -
            sizeof(val),
        reinterpret_cast<const void*>(&val),
        sizeof(val));
#else
#error "Unexpected or undefined __BYTE_ORDER__"
#endif
    return result;
  }
};

// Specialization for torch::headeronly::ScalarType => StableIValue
// Note that we call into the shim to translate between the user's
// ScalarType and libtorch's ScalarType, which can be different!
// Also note that the list below is not comprehensive, as it does not
// include types that are no longer really used and should probably be
// deprecated (like qint8).
using torch::headeronly::ScalarType;
template <>
struct FromImpl<ScalarType> {
  static StableIValue call(ScalarType val) {
    switch (val) {
      case ScalarType::Byte:
        return from(aoti_torch_dtype_uint8());
      case ScalarType::Char:
        return from(aoti_torch_dtype_int8());
      case ScalarType::Short:
        return from(aoti_torch_dtype_int16());
      case ScalarType::Int:
        return from(aoti_torch_dtype_int32());
      case ScalarType::Long:
        return from(aoti_torch_dtype_int64());
      case ScalarType::Half:
        return from(aoti_torch_dtype_float16());
      case ScalarType::Float:
        return from(aoti_torch_dtype_float32());
      case ScalarType::Double:
        return from(aoti_torch_dtype_float64());
      case ScalarType::ComplexHalf:
        return from(aoti_torch_dtype_complex32());
      case ScalarType::ComplexFloat:
        return from(aoti_torch_dtype_complex64());
      case ScalarType::ComplexDouble:
        return from(aoti_torch_dtype_complex128());
      case ScalarType::Bool:
        return from(aoti_torch_dtype_bool());
      case ScalarType::BFloat16:
        return from(aoti_torch_dtype_bfloat16());
      case ScalarType::Float8_e5m2:
        return from(aoti_torch_dtype_float8_e5m2());
      case ScalarType::Float8_e4m3fn:
        return from(aoti_torch_dtype_float8_e4m3fn());
      case ScalarType::Float8_e5m2fnuz:
        return from(aoti_torch_dtype_float8_e5m2fnuz());
      case ScalarType::Float8_e4m3fnuz:
        return from(aoti_torch_dtype_float8_e4m3fnuz());
      case ScalarType::UInt16:
        return from(aoti_torch_dtype_uint16());
      case ScalarType::UInt32:
        return from(aoti_torch_dtype_uint32());
      case ScalarType::UInt64:
        return from(aoti_torch_dtype_uint64());
      default:
        throw std::runtime_error(
            "Not yet supported ScalarType, please file an issue describing your use case.");
    }
  }
};

// Specialization for std::nullopt_t => StableIValue
template <>
struct FromImpl<std::nullopt_t> {
  static StableIValue call(std::nullopt_t val) {
    return from(nullptr);
  }
};

// Specialization for std::optional => StableIValue
// [Handling std::optional]
// When the schema is represented by an optional type, say int?, then we
// expect the custom extension representation to be a std::optional<int>
// (critically NOT int!). In order for all parameters to be stably parsed and
// handled by our dispatcher, we liaison custom extension parameters through
// boxed kernels, meaning that every value will make its way to be an IValue:
//
// custom extension value --(from)-> StableIValue --(to_ivalue)-> IValue
//
// When the custom extension value is a literal that can be trivially
// casted to StableIValue, e.g., an int, a float, a pointer, this route is
// ...trivial. The below specialization is for a case when the custom
// extension value would NOT fit within a StableIValue: a std::optional.
//
// If the std::optional has no value, it is treated as std::nullopt,
// whose StableIValue representation is from(nullptr). Otherwise, we:
// 1. unwrap the std::optional<T>
// 2. recursively convert its value of type T to a StableIValue
// 3. allocate heap space for said StableIValue
// 4. convert the resulting StableIValue* into a StableIValue
//
// note that this allocates heap memory! which we expect to be cleaned
// up in the to_ivalue() function defined in shim_common.cpp. We
// purposefully hide this implementation detail from the user so that
// all the user needs to know is:
//
// The schema requests an optional (T?) so I must call `from` on a
// std::optional<T> or a std::nullopt.
template <typename T>
struct FromImpl<std::optional<T>> {
  static StableIValue call(const std::optional<T>& val) {
    if (!val.has_value()) {
      return from(std::nullopt);
    }
    return from(new StableIValue(from(val.value())));
  }
};

// Specialization for torch::stable::Tensor => StableIValue
// Returns a new owning reference of the underlying Tensor.
template <>
struct FromImpl<torch::stable::Tensor> {
  static StableIValue call(const torch::stable::Tensor& val) {
    AtenTensorHandle new_ath;
    TORCH_ERROR_CODE_CHECK(aoti_torch_new_tensor_handle(val.get(), &new_ath));
    return from(new_ath);
  }
};

// =============================================================================
// TO CONVERSIONS (StableIValue -> T)
// =============================================================================

// Specialization for StableIValue => general copyable types (catch-all)
template <typename T>
struct ToImpl {
  static T call(StableIValue val) {
    static_assert(std::is_trivially_copyable_v<T>);
    // T may not have a default constructor. (For example, it might be
    // c10::Device.) However, std::memcpy implicitly creates a T at the
    // destination. So, we can use a union to work around this lack of
    // default constructor.
    union Result {
      Result() {}
      T t;
    };
    Result result;
    // See NOTE[ -Wclass-memaccess ] above.
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
    std::memcpy(reinterpret_cast<void*>(&result.t), &val, sizeof(result));
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
    static_assert(
        sizeof(T) <= sizeof(StableIValue),
        "StableLibrary stack does not support parameter types larger than 64 bits.");
    // if value has size less than sizeof(StableIValue), then only lowest bytes
    // have to be updated
    std::memcpy(
        reinterpret_cast<void*>(&result.t),
        reinterpret_cast<unsigned char*>(&val) + sizeof(StableIValue) -
            sizeof(result),
        sizeof(result));
#else
#error "Unexpected or undefined __BYTE_ORDER__"
#endif
    return result.t;
  }
};

// Specialization for StableIValue => torch::headeronly::ScalarType
template <>
struct ToImpl<ScalarType> {
  static ScalarType call(StableIValue val) {
    int32_t shim_scalartype = to<int32_t>(val);
    if (shim_scalartype == aoti_torch_dtype_uint8()) {
      return ScalarType::Byte;
    } else if (shim_scalartype == aoti_torch_dtype_int8()) {
      return ScalarType::Char;
    } else if (shim_scalartype == aoti_torch_dtype_int16()) {
      return ScalarType::Short;
    } else if (shim_scalartype == aoti_torch_dtype_int32()) {
      return ScalarType::Int;
    } else if (shim_scalartype == aoti_torch_dtype_int64()) {
      return ScalarType::Long;
    } else if (shim_scalartype == aoti_torch_dtype_float16()) {
      return ScalarType::Half;
    } else if (shim_scalartype == aoti_torch_dtype_float32()) {
      return ScalarType::Float;
    } else if (shim_scalartype == aoti_torch_dtype_float64()) {
      return ScalarType::Double;
    } else if (shim_scalartype == aoti_torch_dtype_complex32()) {
      return ScalarType::ComplexHalf;
    } else if (shim_scalartype == aoti_torch_dtype_complex64()) {
      return ScalarType::ComplexFloat;
    } else if (shim_scalartype == aoti_torch_dtype_complex128()) {
      return ScalarType::ComplexDouble;
    } else if (shim_scalartype == aoti_torch_dtype_bool()) {
      return ScalarType::Bool;
    } else if (shim_scalartype == aoti_torch_dtype_bfloat16()) {
      return ScalarType::BFloat16;
    } else if (shim_scalartype == aoti_torch_dtype_float8_e5m2()) {
      return ScalarType::Float8_e5m2;
    } else if (shim_scalartype == aoti_torch_dtype_float8_e4m3fn()) {
      return ScalarType::Float8_e4m3fn;
    } else if (shim_scalartype == aoti_torch_dtype_float8_e5m2fnuz()) {
      return ScalarType::Float8_e5m2fnuz;
    } else if (shim_scalartype == aoti_torch_dtype_float8_e4m3fnuz()) {
      return ScalarType::Float8_e4m3fnuz;
    } else if (shim_scalartype == aoti_torch_dtype_uint16()) {
      return ScalarType::UInt16;
    } else if (shim_scalartype == aoti_torch_dtype_uint32()) {
      return ScalarType::UInt32;
    } else if (shim_scalartype == aoti_torch_dtype_uint64()) {
      return ScalarType::UInt64;
    } else {
      throw std::runtime_error(
          "Not yet supported ScalarType " + std::to_string(shim_scalartype) +
          ", please file an issue describing your use case.");
    }
  }
};

// Specialization for StableIValue => std::nullopt_t
template <>
struct ToImpl<std::nullopt_t> {
  static std::nullopt_t call(StableIValue val) {
    // val should be equivalent to from(nullptr)
    return std::nullopt;
  }
};

// Specialization for StableIValue => std::optional, see [Handling
// std::optional] as the semantic is the same but in reverse direction as we go
// from IValue --(from_ivalue)-> StableIValue --(to<T>)-> T in custom extension
template <typename T>
struct ToImpl<std::optional<T>> {
  static std::optional<T> call(StableIValue val) {
    auto sivp = to<StableIValue*>(val);

    // sivp is either nullptr or a pointer to a StableIValue
    if (sivp == nullptr) {
      return {};
    }
    auto inner_val = to<T>(*sivp);

    // free the memory associated with StableIValue* sivp
    delete sivp;

    return std::make_optional(inner_val);
  }
};

// Specialization for StableIValue => torch::stable::Tensor
// The resulting stable::Tensor steals ownership of the input's
// underlying AtenTensorHandle.
template <>
struct ToImpl<torch::stable::Tensor> {
  static torch::stable::Tensor call(StableIValue val) {
    return torch::stable::Tensor(to<AtenTensorHandle>(val));
  }
};

} // namespace detail

// Expose the partially templated class functions through single functions
template <typename T>
StableIValue from(T val) {
  return detail::FromImpl<T>::call(val);
}

template <typename T>
StableIValue from(const std::optional<T>& val) {
  return detail::FromImpl<std::optional<T>>::call(val);
}

// The below overload is used! See https://godbolt.org/z/859cshxrW
// We are suppressing the warning for versions clang12- and gcc11-
[[maybe_unused]] StableIValue from(const torch::stable::Tensor& val) {
  return detail::FromImpl<torch::stable::Tensor>::call(val);
}

template <typename T>
T to(StableIValue val) {
  return detail::ToImpl<T>::call(val);
}

// =============================================================================
// end to helpers for converting between StableIValue and T
// =============================================================================

} // namespace
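To see the translation layer at work, a short hedged sketch of a ScalarType round trip through the specializations above (scalar_type_round_trip is a hypothetical name):

#include <torch/csrc/stable/stableivalue_conversions.h>

void scalar_type_round_trip() {
  using torch::headeronly::ScalarType;
  // The packed value is the shim's dtype code (aoti_torch_dtype_bfloat16()),
  // not the raw enum value, so it stays valid even if the headeronly enum is
  // ever renumbered.
  StableIValue packed = from(ScalarType::BFloat16);
  ScalarType unpacked = to<ScalarType>(packed);
  STD_TORCH_CHECK(unpacked == ScalarType::BFloat16);
}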
@@ -1,166 +1,4 @@
#pragma once

#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/headeronly/util/Exception.h>
#include <torch/headeronly/util/shim_utils.h>
#include <climits>
#include <memory>

#include <torch/csrc/stable/accelerator.h>

namespace torch::stable {

using DeviceIndex = torch::stable::accelerator::DeviceIndex;

// The torch::stable::Tensor class is a high-level C++ wrapper around
// the C shim Tensor APIs. We've modeled this class after TensorBase, as custom
// op kernels only really need to interact with Tensor metadata (think sizes,
// strides, device, dtype). Other functions on Tensor (like empty_like) should
// live like the ATen op that they are and exist outside of this struct.
//
// There are several goals of this class over AtenTensorHandle and
// RAIIAtenTensorHandle:
// 1. torch::stable::Tensor offers a nicer UX, much closer to torch::Tensor,
//    than the C APIs with AtenTensorHandle. Under the hood we still call these
//    C shim APIs to preserve stability.
// 2. RAIIAtenTensorHandle boils down to a unique_ptr that forces the user to
//    pass around ownership. This makes it difficult to pass one input into 2
//    different functions, e.g., doing something like c = a(t) + b(t) for
//    stable::Tensor t. Thus, we use a shared_ptr here.
class Tensor {
 private:
  std::shared_ptr<AtenTensorOpaque> ath_;

 public:
  // Construct a stable::Tensor with an uninitialized AtenTensorHandle (ATH)
  // Steals ownership from the ATH
  Tensor() {
    AtenTensorHandle ret;
    TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&ret));
    ath_ = std::shared_ptr<AtenTensorOpaque>(ret, [](AtenTensorHandle ath) {
      TORCH_ERROR_CODE_CHECK(aoti_torch_delete_tensor_object(ath));
    });
  }

  // Construct a stable::Tensor from an AtenTensorHandle (ATH)
  // Steals ownership from the ATH
  explicit Tensor(AtenTensorHandle ath)
      : ath_(ath, [](AtenTensorHandle ath) {
          TORCH_ERROR_CODE_CHECK(aoti_torch_delete_tensor_object(ath));
        }) {}

  // Copy and move constructors can be default because the underlying handle
  // is a shared_ptr
  Tensor(const Tensor& other) = default;
  Tensor(Tensor&& other) noexcept = default;

  // Copy and move assignment operators can be default because the underlying
  // handle is a shared_ptr
  Tensor& operator=(const Tensor& other) = default;
  Tensor& operator=(Tensor&& other) noexcept = default;

  // Destructor can be default: the shared_ptr has custom deletion logic
  ~Tensor() = default;

  // Returns a borrowed reference to the AtenTensorHandle
  AtenTensorHandle get() const {
    return ath_.get();
  }

  // ===========================================================================
  // C-shimified TensorBase APIs: the below APIs have the same signatures and
  // semantics as their counterparts in TensorBase.h.
  // ===========================================================================

  void* data_ptr() const {
    void* data_ptr;
    TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(ath_.get(), &data_ptr));
    return data_ptr;
  }

  int64_t dim() const {
    int64_t dim;
    TORCH_ERROR_CODE_CHECK(aoti_torch_get_dim(ath_.get(), &dim));
    return dim;
  }

  int64_t numel() const {
    int64_t numel;
    TORCH_ERROR_CODE_CHECK(aoti_torch_get_numel(ath_.get(), &numel));
    return numel;
  }

  // note: this is a subset of the original TensorBase API. It takes no
  // arguments whereas the original API takes in a kwarg of memory format.
  // Here, we assume the default contiguous memory format.
  bool is_contiguous() const {
    bool is_contiguous;
    TORCH_ERROR_CODE_CHECK(
        aoti_torch_is_contiguous(ath_.get(), &is_contiguous));
    return is_contiguous;
  }

  int64_t stride(int64_t dim) const {
    int64_t stride;
    TORCH_ERROR_CODE_CHECK(aoti_torch_get_stride(ath_.get(), dim, &stride));
    return stride;
  }

  // This is almost the same API as the one in TensorBase.h, except
  // we add a check that the returned device_index is within the
  // range of int8_t.
  int8_t get_device() const {
    int32_t device_index;
    TORCH_ERROR_CODE_CHECK(
        aoti_torch_get_device_index(ath_.get(), &device_index));
    STD_TORCH_CHECK(
        device_index >= std::numeric_limits<int8_t>::min() &&
            device_index <= std::numeric_limits<int8_t>::max(),
        "Device index is out of range of return type int8_t, please use get_device_index() instead.");
    return static_cast<int8_t>(device_index);
  }

  // The same as get_device but with two differences:
  // 1. it has a more suiting name
  // 2. it returns a DeviceIndex, which is int32_t in this world
  //    that should be more stable than the likely shifting
  //    DeviceIndex in libtorch (it is int8_t that might become int16_t)
  DeviceIndex get_device_index() const {
    int32_t device_index;
    TORCH_ERROR_CODE_CHECK(
        aoti_torch_get_device_index(ath_.get(), &device_index));
    return device_index;
  }

  bool is_cuda() const {
    int32_t device_type;
    TORCH_ERROR_CODE_CHECK(
        aoti_torch_get_device_type(ath_.get(), &device_type));
    return device_type == aoti_torch_device_type_cuda();
  }

  bool is_cpu() const {
    int32_t device_type;
    TORCH_ERROR_CODE_CHECK(
        aoti_torch_get_device_type(ath_.get(), &device_type));
    return device_type == aoti_torch_device_type_cpu();
  }

  int64_t size(int64_t dim) const {
    int64_t size;
    TORCH_ERROR_CODE_CHECK(aoti_torch_get_size(ath_.get(), dim, &size));
    return size;
  }

  bool defined() const {
    bool defined;
    TORCH_ERROR_CODE_CHECK(aoti_torch_is_defined(ath_.get(), &defined));
    return defined;
  }

  // ===========================================================================
  // END of C-shimified TensorBase APIs
  // ===========================================================================
};

} // namespace torch::stable
#include <torch/csrc/stable/tensor_inl.h>
#include <torch/csrc/stable/tensor_struct.h>
torch/csrc/stable/tensor_inl.h (new file, 24 lines)
@@ -0,0 +1,24 @@
#pragma once

// This file implements tensor.h. We separated out the Tensor struct so that
// other files can depend on the Tensor struct (like library.h) and the
// implementations of the Tensor methods can depend on APIs in library.h
// without circular dependencies.

#include <torch/csrc/stable/stableivalue_conversions.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/util/shim_utils.h>

namespace torch::stable {

using torch::headeronly::ScalarType;

ScalarType Tensor::scalar_type() const {
  int32_t dtype;
  TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(ath_.get(), &dtype));
  return to<ScalarType>(from(dtype));
}

} // namespace torch::stable
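The `to<ScalarType>(from(dtype))` idiom above reuses the translation table in stableivalue_conversions.h instead of duplicating the dtype mapping. A hedged hand-expansion of what it does (decode_dtype is a hypothetical name):

#include <torch/csrc/stable/stableivalue_conversions.h>
#include <cstdint>

torch::headeronly::ScalarType decode_dtype(int32_t dtype) {
  // from() bitwise-packs the shim's int32_t dtype code into a StableIValue;
  // to<ScalarType>() then matches that code against the aoti_torch_dtype_*()
  // constants and returns the corresponding enum value.
  StableIValue packed = from(dtype);
  return to<torch::headeronly::ScalarType>(packed);
}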
torch/csrc/stable/tensor_struct.h (new file, 171 lines)
@@ -0,0 +1,171 @@
#pragma once

#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/util/Exception.h>
#include <torch/headeronly/util/shim_utils.h>
#include <climits>
#include <memory>

#include <torch/csrc/stable/accelerator.h>

namespace torch::stable {

using accelerator::DeviceIndex;
using torch::headeronly::ScalarType;

// The torch::stable::Tensor class is a high-level C++ wrapper around
// the C shim Tensor APIs. We've modeled this class after TensorBase, as custom
// op kernels only really need to interact with Tensor metadata (think sizes,
// strides, device, dtype). Other functions on Tensor (like empty_like) should
// live like the ATen op that they are and exist outside of this struct.
//
// There are several goals of this class over AtenTensorHandle and
// RAIIAtenTensorHandle:
// 1. torch::stable::Tensor offers a nicer UX, much closer to torch::Tensor,
//    than the C APIs with AtenTensorHandle. Under the hood we still call these
//    C shim APIs to preserve stability.
// 2. RAIIAtenTensorHandle boils down to a unique_ptr that forces the user to
//    pass around ownership. This makes it difficult to pass one input into 2
//    different functions, e.g., doing something like c = a(t) + b(t) for
//    stable::Tensor t. Thus, we use a shared_ptr here.
class Tensor {
 private:
  std::shared_ptr<AtenTensorOpaque> ath_;

 public:
  // Construct a stable::Tensor with an uninitialized AtenTensorHandle (ATH)
  // Steals ownership from the ATH
  Tensor() {
    AtenTensorHandle ret;
    TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&ret));
    ath_ = std::shared_ptr<AtenTensorOpaque>(ret, [](AtenTensorHandle ath) {
      TORCH_ERROR_CODE_CHECK(aoti_torch_delete_tensor_object(ath));
    });
  }

  // Construct a stable::Tensor from an AtenTensorHandle (ATH)
  // Steals ownership from the ATH
  explicit Tensor(AtenTensorHandle ath)
      : ath_(ath, [](AtenTensorHandle ath) {
          TORCH_ERROR_CODE_CHECK(aoti_torch_delete_tensor_object(ath));
        }) {}

  // Copy and move constructors can be default because the underlying handle
  // is a shared_ptr
  Tensor(const Tensor& other) = default;
  Tensor(Tensor&& other) noexcept = default;

  // Copy and move assignment operators can be default because the underlying
  // handle is a shared_ptr
  Tensor& operator=(const Tensor& other) = default;
  Tensor& operator=(Tensor&& other) noexcept = default;

  // Destructor can be default: the shared_ptr has custom deletion logic
  ~Tensor() = default;

  // Returns a borrowed reference to the AtenTensorHandle
  AtenTensorHandle get() const {
    return ath_.get();
  }

  // ===========================================================================
  // C-shimified TensorBase APIs: the below APIs have the same signatures and
  // semantics as their counterparts in TensorBase.h.
  // ===========================================================================

  void* data_ptr() const {
    void* data_ptr;
    TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(ath_.get(), &data_ptr));
    return data_ptr;
  }

  int64_t dim() const {
    int64_t dim;
    TORCH_ERROR_CODE_CHECK(aoti_torch_get_dim(ath_.get(), &dim));
    return dim;
  }

  int64_t numel() const {
    int64_t numel;
    TORCH_ERROR_CODE_CHECK(aoti_torch_get_numel(ath_.get(), &numel));
    return numel;
  }

  // note: this is a subset of the original TensorBase API. It takes no
  // arguments whereas the original API takes in a kwarg of memory format.
  // Here, we assume the default contiguous memory format.
  bool is_contiguous() const {
    bool is_contiguous;
    TORCH_ERROR_CODE_CHECK(
        aoti_torch_is_contiguous(ath_.get(), &is_contiguous));
    return is_contiguous;
  }

  int64_t stride(int64_t dim) const {
    int64_t stride;
    TORCH_ERROR_CODE_CHECK(aoti_torch_get_stride(ath_.get(), dim, &stride));
    return stride;
  }

  // This is almost the same API as the one in TensorBase.h, except
  // we add a check that the returned device_index is within the
  // range of int8_t.
  int8_t get_device() const {
    int32_t device_index;
    TORCH_ERROR_CODE_CHECK(
        aoti_torch_get_device_index(ath_.get(), &device_index));
    STD_TORCH_CHECK(
        device_index >= std::numeric_limits<int8_t>::min() &&
            device_index <= std::numeric_limits<int8_t>::max(),
        "Device index is out of range of return type int8_t, please use get_device_index() instead.");
    return static_cast<int8_t>(device_index);
  }

  // The same as get_device but with two differences:
  // 1. it has a more suiting name
  // 2. it returns a DeviceIndex, which is int32_t in this world
  //    that should be more stable than the likely shifting
  //    DeviceIndex in libtorch (it is int8_t that might become int16_t)
  DeviceIndex get_device_index() const {
    int32_t device_index;
    TORCH_ERROR_CODE_CHECK(
        aoti_torch_get_device_index(ath_.get(), &device_index));
    return device_index;
  }

  bool is_cuda() const {
    int32_t device_type;
    TORCH_ERROR_CODE_CHECK(
        aoti_torch_get_device_type(ath_.get(), &device_type));
    return device_type == aoti_torch_device_type_cuda();
  }

  bool is_cpu() const {
    int32_t device_type;
    TORCH_ERROR_CODE_CHECK(
        aoti_torch_get_device_type(ath_.get(), &device_type));
    return device_type == aoti_torch_device_type_cpu();
  }

  int64_t size(int64_t dim) const {
    int64_t size;
    TORCH_ERROR_CODE_CHECK(aoti_torch_get_size(ath_.get(), dim, &size));
    return size;
  }

  bool defined() const {
    bool defined;
    TORCH_ERROR_CODE_CHECK(aoti_torch_is_defined(ath_.get(), &defined));
    return defined;
  }

  // defined in tensor_inl.h to avoid circular dependencies
  ScalarType scalar_type() const;

  // ===========================================================================
  // END of C-shimified TensorBase APIs
  // ===========================================================================
};

} // namespace torch::stable
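The shared_ptr design motivated in the class comment above is what makes multi-use of a single tensor ergonomic. A minimal hedged sketch (a, b, and use_twice are hypothetical):

#include <torch/csrc/stable/tensor_struct.h>

// Hypothetical stable ops that each take a Tensor by value.
torch::stable::Tensor a(torch::stable::Tensor);
torch::stable::Tensor b(torch::stable::Tensor);

void use_twice(const torch::stable::Tensor& t) {
  // Copying a stable::Tensor only bumps the shared_ptr refcount on the
  // underlying AtenTensorHandle, so one input can feed two calls; the handle
  // is deleted when the last copy goes away.
  auto x = a(t);
  auto y = b(t); // t is still valid here
  // ... c = x + y, per the motivating comment above (hypothetical) ...
}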