allow sharing tensor of simple types
Summary: If the blob type switches between simple types such as fp32 and fp16, we should share the tensor buffer instead of reallocating it. This kind of switching can happen with memonger and with in-place conversions.

Reviewed By: bddppq

Differential Revision: D5812333

fbshipit-source-id: 44d54bfe52cbda734db8c7f20d6970e4b51ee1e1
This commit is contained in:
parent bd17684252
commit ec2ee181c1
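The mechanism, roughly: when the tensor is asked for data of a new type, it can hand back the existing buffer as long as neither the old nor the new element type needs special construction or destruction and the already-allocated capacity covers the new element size. The sketch below is a minimal standalone restatement of that rule; `can_reuse_buffer` and its parameters are illustrative names, not Caffe2 API.

#include <cstddef>

// Illustrative only: restates the reuse condition this commit adds to
// Tensor::raw_mutable_data(), with hypothetical names.
bool can_reuse_buffer(
    bool old_type_has_dtor,     // do the current elements need a destructor?
    bool new_type_has_ctor,     // do the new elements need a constructor?
    std::size_t capacity_bytes, // bytes already allocated for the tensor
    std::size_t num_elements,   // number of elements the tensor holds
    std::size_t new_itemsize) { // sizeof() of the new element type
  // Reuse is safe only for "simple" types (no ctor/dtor) and only if the
  // existing allocation is large enough for the new element size.
  return !old_type_has_dtor && !new_type_has_ctor &&
      capacity_bytes >= num_elements * new_itemsize;
}

Under this rule a buffer allocated for fp32 can be reused when the blob switches to fp16 (2 bytes per element fits in 4), but switching to fp64 forces a fresh allocation, which is exactly what the updated TensorChangeType test below checks.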
@@ -136,13 +136,32 @@ TEST(TensorNonTypedTest, TensorChangeType) {
  dims[1] = 3;
  dims[2] = 5;
  TensorCPU tensor(dims);
  EXPECT_TRUE(tensor.mutable_data<int>() != nullptr);

  auto* ptr = tensor.mutable_data<int>();
  EXPECT_TRUE(ptr != nullptr);
  EXPECT_TRUE(tensor.data<int>() != nullptr);
  EXPECT_TRUE(tensor.meta().Match<int>());

  EXPECT_TRUE(tensor.mutable_data<float>() != nullptr);
  EXPECT_TRUE(tensor.data<float>() != nullptr);
  // int and float are same size, so should retain the pointer
  EXPECT_TRUE(tensor.mutable_data<float>() == (float*)ptr);
  EXPECT_TRUE(tensor.data<float>() == (const float*)ptr);
  EXPECT_TRUE(tensor.meta().Match<float>());

  // float16 is smaller, so still should share buffer
  EXPECT_TRUE(tensor.mutable_data<float16>() == (float16*)ptr);
  EXPECT_TRUE(tensor.data<float16>() == (const float16*)ptr);
  EXPECT_TRUE(tensor.meta().Match<float16>());

  // share the data with other tensor so that the pointer won't be reused
  // when we reallocate
  TensorCPU other_tensor(dims);
  other_tensor.ShareData(tensor);
  // but double is bigger, so it should allocate a new one
  auto* doubleptr = tensor.mutable_data<double>();
  EXPECT_TRUE(doubleptr != (double*)ptr);
  EXPECT_TRUE(doubleptr != nullptr);
  EXPECT_TRUE(tensor.data<double>() != nullptr);
  EXPECT_TRUE(tensor.meta().Match<double>());
}

template <typename T> class TensorCPUTest : public ::testing::Test {};
@@ -492,12 +492,19 @@ class Tensor {
     if (meta_ == meta && (data_.get() || size_ == 0)) {
       return data_.get();
     } else {
+      bool had_special_dtor = meta_.dtor() != nullptr;
       meta_ = meta;
       CAFFE_ENFORCE_WITH_CALLER(
           size_ >= 0,
           "Tensor is not initialized. You probably need to call Resize() "
           "before calling mutable_data()");
-      if (size_ == 0) {
+
+      // We can reuse the existing buffer if the current data does not have
+      // a special destructor and the new data doesn't have a special
+      // constructor.
+      if (size_ == 0 ||
+          (meta.ctor() == nullptr && !had_special_dtor &&
+           capacity_ >= size_ * meta_.itemsize())) {
         return data_.get();
       }
       if (meta.ctor()) {
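An illustrative aside, not part of this diff: the `meta.ctor()` and `had_special_dtor` guards restrict reuse to element types that need no per-element construction or destruction. A tensor of `std::string`, for example, registers a constructor and destructor in its TypeMeta, so its buffer is never handed out as raw float storage; reinterpreting it would skip the string destructors. The standalone snippet below, using only standard type traits and a hypothetical helper name, captures that distinction.

#include <string>
#include <type_traits>

// Hypothetical helper, not Caffe2 code: element types that need per-element
// construction or destruction cannot have their raw buffer reinterpreted as
// a different dtype, which is what the ctor()/dtor() checks above guard
// against.
template <typename T>
constexpr bool reusable_as_raw_buffer() {
  return std::is_trivially_copyable<T>::value &&
      std::is_trivially_destructible<T>::value;
}

static_assert(reusable_as_raw_buffer<float>(), "simple types can share buffers");
static_assert(!reusable_as_raw_buffer<std::string>(), "std::string cannot");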
@@ -544,16 +551,16 @@ class Tensor {
   /**
    * Returns a typed pointer of the underlying storage.
    *
-   * If the existing data does not match the desired type, it will be deleted
-   * and a new storage will be created.
+   * For fundamental types, we reuse possible existing storage if there
+   * is sufficient capacity.
    */
   template <typename T>
   inline T* mutable_data() {
     if ((size_ == 0 || data_.get()) && IsType<T>()) {
       return static_cast<T*>(data_.get());
     }
     return static_cast<T*>(raw_mutable_data(TypeMeta::Make<T>()));
   }

   /**