mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Move TensorImpl::CopyFrom to caffe2::Tensor (1/2) (#14656)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/14656 This diff doesn't move it yet, but prepares it to be moved, i.e. removes all access to class internals. dzhulgakov: Please comment on if you think it still makes sense to land this even though it's not blocking anymore since we're going to move at::CopyBytes anyhow. ezyang: There's some changes in the implementation, especially handling undefined dest tensors. Please review carefully. Reviewed By: ezyang Differential Revision: D13287688 fbshipit-source-id: 17800ca8a79ab1633f23be58d96f99a160d8ed24
This commit is contained in:
parent
dc72a5e02c
commit
070f33f154
|
|
@ -830,26 +830,23 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
|||
src.storage_initialized(),
|
||||
"Cannot copy from an uninitialized Tensor");
|
||||
|
||||
if ((void*)&src == (void*)this) {
|
||||
if (&src == this) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Test if we need to allocate a new storage
|
||||
// Uninitialized storages are guaranteed to be uniquely owned,
|
||||
// so we don't need to swap in this case.
|
||||
if (storage_initialized()) {
|
||||
// If the dtype changed, we need to reallocate storage.
|
||||
if (data_type_ != src.dtype()) {
|
||||
// NB: copy preserves device_type
|
||||
// This storage will get initialized by the mutable_data call below.
|
||||
storage_ = Storage(device_type(), src.dtype());
|
||||
}
|
||||
// If the dtype changed, we need to reallocate storage.
|
||||
if (dtype() != src.dtype()) {
|
||||
// NB: copy preserves device_type
|
||||
// This storage will get initialized by the mutable_data call below.
|
||||
set_storage(at::Storage(device_type(), src.dtype()));
|
||||
}
|
||||
data_type_ = src.dtype();
|
||||
Resize(src.sizes());
|
||||
|
||||
if (numel() > 0) {
|
||||
if (data_type_.copy()) {
|
||||
if (dtype().copy()) {
|
||||
AT_ASSERTM(
|
||||
device_type() == DeviceType::CPU,
|
||||
"In CopyFrom source and dest tensors must both be CPU for "
|
||||
|
|
@ -860,7 +857,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
|||
"In CopyFrom source and dest tensors must both be CPU for "
|
||||
"non-POD copy, but src tensor was ",
|
||||
src.device_type());
|
||||
data_type_.copy()(src.data(), raw_mutable_data(data_type_), numel());
|
||||
dtype().copy()(src.data(), raw_mutable_data(data_type_), numel());
|
||||
} else {
|
||||
// The following copy uses the current (thread local) stream for copying
|
||||
// and also takes the GPU id from the device() field passed in.
|
||||
|
|
@ -873,7 +870,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
|||
// properly.
|
||||
//
|
||||
// note: raw_mutable_data initializes device here
|
||||
void* new_data = raw_mutable_data(data_type_);
|
||||
void* new_data = raw_mutable_data(dtype());
|
||||
CopyBytes(
|
||||
numel() * itemsize(),
|
||||
src.data(),
|
||||
|
|
@ -1233,6 +1230,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
|||
return data_type_ != caffe2::TypeMeta();
|
||||
}
|
||||
|
||||
void set_storage(at::Storage storage) {
|
||||
storage_ = std::move(storage);
|
||||
data_type_ = storage_.dtype();
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
// The Caffe2 Resize() method supports being called both as Resize({2,2}) as
|
||||
|
|
|
|||
|
|
@ -105,7 +105,7 @@ class CAFFE2_API Tensor final {
|
|||
}
|
||||
|
||||
void CopyFrom(const Tensor& src, bool async = false) const {
|
||||
impl_.get()->CopyFrom(*src.impl_.get(), async);
|
||||
impl_.get()->CopyFrom(*src.impl_, async);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user