distinguish mutability of TensorImpl::data<T>() (#98719)
There is already a mutable_data<T>() with different semantics, so we introduce new names: TensorImpl::(mutable_)?data_dtype_initialized<T>().

Differential Revision: [D44824778](https://our.internmc.facebook.com/intern/diff/D44824778/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98719
Approved by: https://github.com/ezyang
This commit is contained in:
parent
9c98f2ceb7
commit
ee0143bf65
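For orientation, a minimal caller-side sketch of the two accessors this commit introduces (hypothetical TensorImpl reference impl holding float data; not part of the diff below):

  // Read path: dtype-checked, returns a const pointer, does not allocate.
  const float* ro = impl.data_dtype_initialized<float>();

  // Write path: dtype-checked, returns a mutable pointer obtained through the
  // storage's mutable data.
  float* rw = impl.mutable_data_dtype_initialized<float>();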
@@ -189,7 +189,7 @@ void CUDAGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
   }

   uint64_t input_seed;
-  auto new_rng_state = new_state.data<uint8_t>();
+  auto new_rng_state = new_state.data_dtype_initialized<uint8_t>();
   memcpy(&input_seed, new_rng_state, seed_size);
   this->set_current_seed(input_seed);
   int64_t philox_offset = 0;
@@ -82,7 +82,7 @@ void MPSGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
   TORCH_CHECK(new_state_size == total_size, "RNG state is wrong size");

   uint64_t input_seed = default_rng_seed_val;
-  auto new_rng_state = new_state.data<uint8_t>();
+  auto new_rng_state = new_state.data_dtype_initialized<uint8_t>();
   memcpy(&input_seed, new_rng_state + states_size, seed_size);
   this->set_current_seed(input_seed);
   // state.data must be copied after input_seed to not reset the state in set_current_seed()
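Both generator call sites above only read from new_state, so the const accessor suffices; a minimal sketch of that pattern (hypothetical helper, assuming the c10 headers are available; not the actual PyTorch source):

  #include <cstddef>
  #include <cstdint>
  #include <cstring>
  #include <c10/core/TensorImpl.h>

  // Read a 64-bit seed out of an RNG-state byte tensor. The const accessor
  // yields const uint8_t*, which is all memcpy's source argument needs.
  uint64_t read_seed(const c10::TensorImpl& new_state, size_t seed_offset) {
    const uint8_t* state_bytes = new_state.data_dtype_initialized<uint8_t>();
    uint64_t seed = 0;
    std::memcpy(&seed, state_bytes + seed_offset, sizeof(seed));
    return seed;
  }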
@@ -1496,17 +1496,46 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    * for you; this class is available from 'Tensor'.
    */
   template <typename T>
-  inline T* data() const {
+  const T* data_dtype_initialized() const {
+    return data_dtype_initialized_impl<const T>(
+        [this] { return static_cast<const T*>(storage_.data()); });
+  }
+
+  /**
+   * Return a mutable typed data pointer to the actual data which this
+   * tensor refers to. This checks that the requested type (from the
+   * template parameter) matches the internal type of the tensor.
+   *
+   * It is invalid to call data() on a dtype-uninitialized tensor, even if
+   * the size is 0.
+   *
+   * WARNING: If a tensor is not contiguous, you MUST use strides when
+   * performing index calculations to determine the location of elements in
+   * the tensor. We recommend using 'TensorAccessor' to handle this computation
+   * for you; this class is available from 'Tensor'.
+   */
+  template <typename T>
+  T* mutable_data_dtype_initialized() {
+    return data_dtype_initialized_impl<T>(
+        [this] { return static_cast<T*>(storage_.mutable_data()); });
+  }
+
+ private:
+  // Shared implementation of data_dtype_initialized() and
+  // mutable_data_dtype_initialized().
+  template <typename T, typename Func>
+  T* data_dtype_initialized_impl(const Func& get_data) const {
     TORCH_CHECK(
-        data_type_.Match<T>(),
+        data_type_.Match<std::remove_const_t<T>>(),
         "Tensor type mismatch, caller expects elements to be ",
-        caffe2::TypeMeta::TypeName<T>(),
+        caffe2::TypeMeta::TypeName<std::remove_const_t<T>>(),
         ", while tensor contains ",
         data_type_.name(),
         ". ");
-    return legacy_mutable_data_ptr_impl<T>();
+    return data_ptr_impl_impl<T>(get_data);
   }

+ public:
   /**
    * More efficient helper for Tensor::data_ptr(). Like data<T>(), but
    * does not do a type check. Unlike the untemplated data(), does
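The core of the hunk above is a single const helper that both accessors funnel through, with a lambda supplying a pointer of the right constness. A self-contained sketch of that pattern on a hypothetical class (Buffer, Storage, data_impl and friends are illustrative names, not PyTorch API):

  #include <cstddef>
  #include <vector>

  // Stand-in for the storage with a const and a mutable getter, mirroring
  // storage_.data() vs storage_.mutable_data() in the diff.
  struct Storage {
    std::vector<float> buf;
    const float* data() const { return buf.data(); }
    float* mutable_data() { return buf.data(); }
  };

  class Buffer {
   public:
    // Const accessor: the lambda hands the shared helper a read-only pointer.
    const float* data() const {
      return data_impl<const float>([this] { return storage_.data(); });
    }

    // Mutable accessor: same helper, but the lambda supplies a writable pointer.
    float* mutable_data() {
      return data_impl<float>([this] { return storage_.mutable_data(); });
    }

   private:
    // The shared logic (here just the offset arithmetic) lives in one const
    // helper; the constness of the result is decided by the caller's lambda.
    template <typename T, typename Func>
    T* data_impl(const Func& get_data) const {
      return get_data() + offset_;
    }

    Storage storage_;
    std::ptrdiff_t offset_ = 0;
  };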
@@ -1514,17 +1543,15 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    */
   template <typename T>
   inline T* mutable_data_ptr_impl() {
-    return legacy_mutable_data_ptr_impl<T>();
+    return data_ptr_impl_impl<T>(
+        [this] { return static_cast<T*>(storage_.mutable_data()); });
   }

  private:
-  // The real implementation of mutable_data_ptr_impl, but in a
-  // non-const method.
-  //
-  // TODO: move the implementation into mutable_data_ptr_impl() and
-  // delete this when data<T>() is no longer const.
-  template <typename T>
-  inline T* legacy_mutable_data_ptr_impl() const {
+  // Shared implementation of mutable_data_ptr_impl() and the future
+  // mutable_data_ptr_impl().
+  template <typename T, typename Func>
+  T* data_ptr_impl_impl(const Func& get_data) const {
     TORCH_CHECK(
         has_storage(),
         "Cannot access data pointer of Tensor that doesn't have storage");
@@ -1534,7 +1561,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
         "Caffe2 uses a lazy allocation, so you will need to call "
         "mutable_data() or raw_mutable_data() to actually allocate memory.");
     // Caller does the type check.
-    return static_cast<T*>(storage_.mutable_data()) + storage_offset_;
+    return get_data() + storage_offset_;
   }

  public:
@@ -356,7 +356,7 @@ class TORCH_API Tensor final {

   template <typename T>
   inline T* data() const {
-    return impl_.get()->data<T>();
+    return impl_.get()->mutable_data_dtype_initialized<T>();
   }

   inline void* raw_mutable_data(const TypeMeta meta) const {
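Because the legacy caffe2::Tensor::data<T>() is const-qualified yet hands out a writable pointer, it is routed to the mutable accessor above; a sketch of an existing-style caller that keeps compiling unchanged (hypothetical function, not from this PR):

  #include <cstdint>
  #include <caffe2/core/tensor.h>

  // Callers that write through a const Tensor& keep working: data<T>() still
  // returns a non-const pointer, now via mutable_data_dtype_initialized<T>().
  void scale_in_place(const caffe2::Tensor& t, float k) {
    float* p = t.data<float>();
    for (int64_t i = 0; i < t.numel(); ++i) {
      p[i] *= k;
    }
  }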