Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
TL;DR: Moving to ScalarType in user extensions and removing deprecated dtypes.

This change _modifies_ the from/to behavior between ScalarType and StableValue! Whereas before, user extensions could only in abstract pass around obfuscated dtypes appearing as int32_ts, now, users can confidently use torch::headeronly::ScalarType in their extensions for major scalar types.

This PR enables ABI stability by adding a translation layer through the shim, so that even if the ScalarType enum values change in the future, user extensions need not fear.

Then we add a Tensor scalar_type API which reuses the from/to logic to return to the user a nice ScalarType (vs an abstracted int32_t). I then changed the test to test the scalar_type API.

This code change required some refactoring because of circular dependencies.

## BC Breaking note

This commit is (narrowly) BC-breaking for unpopular dtypes: `quint*`s, `qint*`s, `Bits*`, `dummy_uint*`s, `dummy_int*`s, `Float8_e8m0fnu`, and `Float4_e2m1fn_x2` in the narrow use case where an extension retrieves a Tensor dtype of the above and passes it into `aoti_torch_call_dispatcher`. As of now, I believe there are 0 users of this use case, so the benefits of this change significantly justify BC-breaking this API.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160557
Approved by: https://github.com/mikaylagawarecki, https://github.com/malfet
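The translation-layer idea described above can be illustrated with a small standalone sketch. This is not the shim's actual code: the enum, the `kStable*` constants, and the `to_stable`/`from_stable` helpers are all hypothetical names invented for illustration. It only shows why routing dtypes through fixed codes keeps extensions insulated from changes to the enum's numeric values.

```cpp
#include <cstdint>
#include <stdexcept>

// Hypothetical internal enum: its numeric values are free to change between
// library releases, so they should not cross the ABI boundary directly.
enum class ScalarType : int8_t { Byte = 0, Float = 6, Double = 7 };

// Hypothetical stable codes: assigned once and never renumbered.
constexpr int32_t kStableByte = 100;
constexpr int32_t kStableFloat = 101;
constexpr int32_t kStableDouble = 102;

// Translate an enum value into its fixed code before it crosses the boundary.
inline int32_t to_stable(ScalarType t) {
  switch (t) {
    case ScalarType::Byte:
      return kStableByte;
    case ScalarType::Float:
      return kStableFloat;
    case ScalarType::Double:
      return kStableDouble;
  }
  throw std::invalid_argument("unsupported ScalarType");
}

// Translate a fixed code back into whatever the enum looks like in this build.
inline ScalarType from_stable(int32_t code) {
  switch (code) {
    case kStableByte:
      return ScalarType::Byte;
    case kStableFloat:
      return ScalarType::Float;
    case kStableDouble:
      return ScalarType::Double;
    default:
      throw std::invalid_argument("unknown stable dtype code");
  }
}
```

In this sketch, renumbering the enum only requires rebuilding the two helper functions; code that exchanged the fixed codes keeps working. Rejecting unknown codes, rather than casting them, mirrors the narrow BC break noted below for dtypes that have no translation.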
25 lines
732 B
C++
#pragma once

// This file implements tensor.h. We separated out the Tensor struct so that
// other files can depend on the Tensor struct (like library.h) and the
// implementations of the Tensor methods can depend on APIs in library.h
// without circular dependencies.

#pragma once
#include <torch/csrc/stable/stableivalue_conversions.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/util/shim_utils.h>

namespace torch::stable {

using torch::headeronly::ScalarType;

ScalarType Tensor::scalar_type() const {
  int32_t dtype;
  TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(ath_.get(), &dtype));
  return to<ScalarType>(from(dtype));
}

} // namespace torch::stable
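The header comment explains that the struct and its method implementations live in separate files to avoid a circular dependency. A minimal single-file sketch of that layering follows. The three "files" are marked by comments, and every name in it (`raw_dtype`, `scalar_code`, `describe`) is a hypothetical stand-in, not part of the real headers.

```cpp
#include <cstdint>

// --- hypothetical "tensor.h": declares the struct, defines no methods ---
struct Tensor {
  int32_t raw_dtype;            // stand-in for the opaque tensor handle
  int32_t scalar_code() const;  // declared here, defined further down
};

// --- hypothetical "library.h": needs the complete Tensor type ---
inline int32_t describe(const Tensor& t) {
  return t.raw_dtype + 1;
}

// --- hypothetical implementation header: can now use both of the above ---
// Marked inline in this sketch so the definition may safely appear in more
// than one translation unit.
inline int32_t Tensor::scalar_code() const {
  return describe(*this);
}
```

Because the method body is only seen after both earlier pieces, neither of them has to include the other's implementation, which is the ordering the comment at the top of this file describes.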