Tensorimpl cleanup try 2 (#72336)

Summary:
This reverts the previous PR and add some comments to make it clear what the intent is.
Also removes some extra static_assert that are not needed (at least for the compilers I tried).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/72336

Reviewed By: r-barnes

Differential Revision: D34006722

Pulled By: albanD

fbshipit-source-id: 290fb89a2d2c66a0d1c3651198b31d21216ec230
(cherry picked from commit 76f0aaa765)
This commit is contained in:
Alban Desmaison 2022-02-07 08:37:32 -08:00 committed by PyTorch MergeBot
parent 9d8f0c7842
commit 9f9b9c48e5

View File

@ -2721,6 +2721,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
// We use a templatized class to both contain the logic of checking the sizes
// as well as to provide compile-time information that might be useful in
// figuring out why sizes may have changed.
// All the compile time information is given by the template fields that are
// always printed by the compiler when the static_assert fails.
template <
size_t cplusplus = __cplusplus,
size_t clang_ver_major = C10_CLANG_MAJOR_VERSION,
@ -2731,6 +2733,44 @@ template <
size_t cuda_version_major = C10_CUDA_VERSION_MAJOR,
size_t ptr_size = sizeof(void*)>
class C10_TensorImpl_Size_Check_Dummy_Class : private TensorImpl {
// Names of (non-bitfield) fields in TensorImpl; used to provide
// compile-time info about fields whose size changes unexpectedly.
enum class FieldNameEnum {
storage_,
autograd_meta_,
named_tensor_meta_,
version_counter_,
pyobj_interpreter_,
pyobj_,
sizes_and_strides_,
storage_offset_,
numel_,
data_type_,
device_opt_,
key_set_,
TOTAL_SIZE
};
// Provides compile-time equality check that reveals what numbers
// were used and on which quantity
template <size_t Actual, size_t Expected, FieldNameEnum FiledName>
constexpr static bool are_equal() {
static_assert(
Actual == Expected,
"Actual and Expected sizes of a field did not match!");
return true;
}
// Provides compile-time <= check that reveals what numbers
// were used and on which quantity
template <size_t Actual, size_t Expected, FieldNameEnum FiledName>
constexpr static bool is_le() {
static_assert(
Actual <= Expected,
"Actual and Expected sizes of a field did not match!");
return true;
}
public:
// Compile-time check that TensorImpl field sizes are as expected
//
@ -2753,19 +2793,19 @@ class C10_TensorImpl_Size_Check_Dummy_Class : private TensorImpl {
constexpr size_t tsize = 20 * sizeof(int64_t);
// clang-format off
static_assert(sizeof(storage_) == 4, "Size of storage_ changed!");
static_assert(sizeof(autograd_meta_) == 4, "Size of autograd_meta_ changed!");
static_assert(sizeof(named_tensor_meta_) == 4, "Size of named_tensor_meta_ changed!");
static_assert(sizeof(version_counter_) == 4, "Size of version_counter_ changed!");
static_assert(sizeof(pyobj_interpreter_) == 4, "Size of pyobj_interpreter_ changed!");
static_assert(sizeof(pyobj_) == 4, "Size of pyobj_ changed!");
static_assert(sizeof(sizes_and_strides_) <= 88,"Size of sizes_and_strides_ changed!");
static_assert(sizeof(storage_offset_) == 8, "Size of storage_offset_ changed!");
static_assert(sizeof(numel_) == 8, "Size of numel_ changed!");
static_assert(sizeof(data_type_) == 2, "Size of data_type_ changed!");
static_assert(sizeof(device_opt_) == 3, "Size of device_opt_ changed!");
static_assert(sizeof(key_set_) == 8, "Size of key_set_ changed!");
static_assert(sizeof(TensorImpl) <= tsize, "Total size changed!");
are_equal<sizeof(storage_), 4, FieldNameEnum::storage_>();
are_equal<sizeof(autograd_meta_), 4, FieldNameEnum::autograd_meta_>();
are_equal<sizeof(named_tensor_meta_), 4, FieldNameEnum::named_tensor_meta_>();
are_equal<sizeof(version_counter_), 4, FieldNameEnum::version_counter_>();
are_equal<sizeof(pyobj_interpreter_), 4, FieldNameEnum::pyobj_interpreter_>();
are_equal<sizeof(pyobj_), 4, FieldNameEnum::pyobj_>();
is_le<sizeof(sizes_and_strides_), 88, FieldNameEnum::sizes_and_strides_>();
are_equal<sizeof(storage_offset_), 8, FieldNameEnum::storage_offset_>();
are_equal<sizeof(numel_), 8, FieldNameEnum::numel_>();
are_equal<sizeof(data_type_), 2, FieldNameEnum::data_type_>();
are_equal<sizeof(device_opt_), 3, FieldNameEnum::device_opt_>();
are_equal<sizeof(key_set_), 8, FieldNameEnum::key_set_>();
is_le<sizeof(TensorImpl), tsize, FieldNameEnum::TOTAL_SIZE>();
// clang-format on
return true;
@ -2776,22 +2816,22 @@ class C10_TensorImpl_Size_Check_Dummy_Class : private TensorImpl {
constexpr size_t tsize = 26 * sizeof(int64_t);
// clang-format off
static_assert(sizeof(storage_) == 8, "Size of storage_ changed!");
are_equal<sizeof(storage_), 8, FieldNameEnum::storage_>();
// On some systems involving NVCC the size of unique_ptr is 16 bytes. We haven't
// figured out how to detect those via macro preprocessors yet, so we use <=
// comparisons for the relevant fields.
static_assert(sizeof(autograd_meta_) <= 16,"Size of autograd_meta_ changed!");
static_assert(sizeof(named_tensor_meta_) <= 16,"Size of named_tensor_meta_ changed!");
static_assert(sizeof(version_counter_) == 8, "Size of version_counter_ changed!");
static_assert(sizeof(pyobj_interpreter_) == 8, "Size of pyobj_interpreter_ changed!");
static_assert(sizeof(pyobj_) == 8, "Size of pyobj_ changed!");
static_assert(sizeof(sizes_and_strides_) == 88,"Size of sizes_and_strides_ changed!");
static_assert(sizeof(storage_offset_) == 8, "Size of storage_offset_ changed!");
static_assert(sizeof(numel_) == 8, "Size of numel_ changed!");
static_assert(sizeof(data_type_) == 2, "Size of data_type_ changed!");
static_assert(sizeof(device_opt_) == 3, "Size of device_opt_ changed!");
static_assert(sizeof(key_set_) == 8, "Size of key_set_ changed!");
static_assert(sizeof(TensorImpl) <= tsize, "Total size changed!");
is_le<sizeof(autograd_meta_), 16, FieldNameEnum::autograd_meta_>();
is_le<sizeof(named_tensor_meta_), 16, FieldNameEnum::named_tensor_meta_>();
are_equal<sizeof(version_counter_), 8, FieldNameEnum::version_counter_>();
are_equal<sizeof(pyobj_interpreter_), 8, FieldNameEnum::pyobj_interpreter_>();
are_equal<sizeof(pyobj_), 8, FieldNameEnum::pyobj_>();
are_equal<sizeof(sizes_and_strides_), 88, FieldNameEnum::sizes_and_strides_>();
are_equal<sizeof(storage_offset_), 8, FieldNameEnum::storage_offset_>();
are_equal<sizeof(numel_), 8, FieldNameEnum::numel_>();
are_equal<sizeof(data_type_), 2, FieldNameEnum::data_type_>();
are_equal<sizeof(device_opt_), 3, FieldNameEnum::device_opt_>();
are_equal<sizeof(key_set_), 8, FieldNameEnum::key_set_>();
is_le<sizeof(TensorImpl), tsize, FieldNameEnum::TOTAL_SIZE>();
// clang-format on
return true;
@ -2804,7 +2844,7 @@ class C10_TensorImpl_Size_Check_Dummy_Class : private TensorImpl {
// a static assert to prove there is no run-time behaviour.
// Since the methods we call return either true or fail their
// own static_asserts, we should never see the error messages
// below.
// below. We have to provide it though for c++ <17.
static_assert(
C10_TensorImpl_Size_Check_Dummy_Class<>::check_sizes(),
"You should not see this message.");