distinguish mutability of TensorImpl::data<T>() (#98719)

There already is a mutable_data<T>() with different semantics, so we
introduce new names: TensorImpl::data_dtype_initialized<T>() and
TensorImpl::mutable_data_dtype_initialized<T>().

Differential Revision: [D44824778](https://our.internmc.facebook.com/intern/diff/D44824778/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98719
Approved by: https://github.com/ezyang
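
For illustration, a minimal sketch of how a call site is expected to choose between the two new names. The read_seed/write_seed helpers and their assumptions (a uint8 state tensor with at least sizeof(uint64_t) initialized bytes) are hypothetical, not part of this change:

// Hypothetical call sites, not part of this diff: read-only access goes
// through the const flavor, writes go through the mutable flavor.
#include <cstdint>
#include <cstring>

#include <c10/core/TensorImpl.h>

// Works on a const TensorImpl and yields a read-only pointer; the dtype of
// `state` must already be initialized to uint8.
void read_seed(const c10::TensorImpl& state, uint64_t* out) {
  const uint8_t* src = state.data_dtype_initialized<uint8_t>();
  std::memcpy(out, src, sizeof(uint64_t));  // assumes >= 8 bytes of state
}

// Requires a non-const TensorImpl and yields a writable pointer.
void write_seed(c10::TensorImpl& state, uint64_t seed) {
  uint8_t* dst = state.mutable_data_dtype_initialized<uint8_t>();
  std::memcpy(dst, &seed, sizeof(uint64_t));
}

The CUDA and MPS set_state() call sites below only read from the incoming state, which is why they move to the const data_dtype_initialized<uint8_t>() flavor.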
Author: mikey dagitses
Date: 2023-04-11 23:08:33 -07:00
Committed by: PyTorch MergeBot
Parent: 9c98f2ceb7
Commit: ee0143bf65
4 changed files with 43 additions and 16 deletions


@@ -189,7 +189,7 @@ void CUDAGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
   }
 
   uint64_t input_seed;
-  auto new_rng_state = new_state.data<uint8_t>();
+  auto new_rng_state = new_state.data_dtype_initialized<uint8_t>();
   memcpy(&input_seed, new_rng_state, seed_size);
   this->set_current_seed(input_seed);
   int64_t philox_offset = 0;


@@ -82,7 +82,7 @@ void MPSGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
   TORCH_CHECK(new_state_size == total_size, "RNG state is wrong size");
 
   uint64_t input_seed = default_rng_seed_val;
-  auto new_rng_state = new_state.data<uint8_t>();
+  auto new_rng_state = new_state.data_dtype_initialized<uint8_t>();
   memcpy(&input_seed, new_rng_state + states_size, seed_size);
   this->set_current_seed(input_seed);
   // state.data must be copied after input_seed to not reset the state in set_current_seed()


@@ -1496,17 +1496,46 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    * for you; this class is available from 'Tensor'.
    */
   template <typename T>
-  inline T* data() const {
+  const T* data_dtype_initialized() const {
+    return data_dtype_initialized_impl<const T>(
+        [this] { return static_cast<const T*>(storage_.data()); });
+  }
+
+  /**
+   * Return a mutable typed data pointer to the actual data which this
+   * tensor refers to. This checks that the requested type (from the
+   * template parameter) matches the internal type of the tensor.
+   *
+   * It is invalid to call data() on a dtype-uninitialized tensor, even if
+   * the size is 0.
+   *
+   * WARNING: If a tensor is not contiguous, you MUST use strides when
+   * performing index calculations to determine the location of elements in
+   * the tensor. We recommend using 'TensorAccessor' to handle this computation
+   * for you; this class is available from 'Tensor'.
+   */
+  template <typename T>
+  T* mutable_data_dtype_initialized() {
+    return data_dtype_initialized_impl<T>(
+        [this] { return static_cast<T*>(storage_.mutable_data()); });
+  }
+
+ private:
+  // Shared implementation of data_dtype_initialized() and
+  // mutable_data_dtype_initialized().
+  template <typename T, typename Func>
+  T* data_dtype_initialized_impl(const Func& get_data) const {
     TORCH_CHECK(
-        data_type_.Match<T>(),
+        data_type_.Match<std::remove_const_t<T>>(),
         "Tensor type mismatch, caller expects elements to be ",
-        caffe2::TypeMeta::TypeName<T>(),
+        caffe2::TypeMeta::TypeName<std::remove_const_t<T>>(),
         ", while tensor contains ",
         data_type_.name(),
         ". ");
-    return legacy_mutable_data_ptr_impl<T>();
+    return data_ptr_impl_impl<T>(get_data);
   }
 
+ public:
   /**
    * More efficient helper for Tensor::data_ptr(). Like data<T>(), but
    * does not do a type check. Unlike the untemplated data(), does
@@ -1514,17 +1543,15 @@
    */
   template <typename T>
   inline T* mutable_data_ptr_impl() {
-    return legacy_mutable_data_ptr_impl<T>();
+    return data_ptr_impl_impl<T>(
+        [this] { return static_cast<T*>(storage_.mutable_data()); });
   }
 
  private:
-  // The real implementation of mutable_data_ptr_impl, but in a
-  // non-const method.
-  //
-  // TODO: move the implementation into mutable_data_ptr_impl() and
-  // delete this when data<T>() is no longer const.
-  template <typename T>
-  inline T* legacy_mutable_data_ptr_impl() const {
+  // Shared implementation of mutable_data_ptr_impl() and the future
+  // mutable_data_ptr_impl().
+  template <typename T, typename Func>
+  T* data_ptr_impl_impl(const Func& get_data) const {
     TORCH_CHECK(
         has_storage(),
         "Cannot access data pointer of Tensor that doesn't have storage");
@@ -1534,7 +1561,7 @@
         "Caffe2 uses a lazy allocation, so you will need to call "
         "mutable_data() or raw_mutable_data() to actually allocate memory.");
     // Caller does the type check.
-    return static_cast<T*>(storage_.mutable_data()) + storage_offset_;
+    return get_data() + storage_offset_;
   }
 
  public:
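
The renamed accessors above differ only in constness; the dtype and storage checks live once in a const helper that is parameterized over a getter functor. Below is a stripped-down, self-contained sketch of that pattern (ToyTensor, data_impl and the assert stand in for the real class, helpers and TORCH_CHECKs; none of these names are PyTorch's):

// Sketch only: one const implementation shared by a const and a mutable
// accessor, with the pointer's constness decided by the lambda passed in.
#include <cassert>
#include <cstddef>
#include <vector>

class ToyTensor {
 public:
  explicit ToyTensor(std::size_t n) : buf_(n) {}

  // Const accessor: `this` is const here, so buf_.data() inside the lambda
  // returns `const float*`, and the caller gets a read-only pointer.
  const float* data() const {
    return data_impl<const float>([this] { return buf_.data(); });
  }

  // Mutable accessor: same shared helper, but `this` is non-const, so the
  // lambda hands back a writable `float*`.
  float* mutable_data() {
    return data_impl<float>([this] { return buf_.data(); });
  }

 private:
  // Shared implementation: validate once, then let the caller-supplied
  // getter decide whether the produced pointer is const or mutable.
  template <typename T, typename Func>
  T* data_impl(const Func& get_data) const {
    assert(!buf_.empty() && "no storage allocated");  // stand-in for TORCH_CHECK
    return get_data() + offset_;
  }

  std::vector<float> buf_;
  std::size_t offset_ = 0;
};

int main() {
  ToyTensor t(4);
  *t.mutable_data() = 1.0f;            // write through the non-const overload
  return t.data()[0] == 1.0f ? 0 : 1;  // read through the const overload
}

Threading the getter through as a functor is what lets data_ptr_impl_impl() stay a single const method while still returning a writable pointer for the mutable overloads, without resorting to const_cast.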


@@ -356,7 +356,7 @@ class TORCH_API Tensor final {
   template <typename T>
   inline T* data() const {
-    return impl_.get()->data<T>();
+    return impl_.get()->mutable_data_dtype_initialized<T>();
   }
 
   inline void* raw_mutable_data(const TypeMeta meta) const {