mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
test basic tensor interop
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/12249 Differential Revision: D13469356 Pulled By: li-roy fbshipit-source-id: b49748462aa44ac34b8ce79783f2c895a537a232
This commit is contained in:
parent
70f0c4745b
commit
50fbf79451
|
|
@ -15,6 +15,7 @@ list(APPEND ATen_CPU_TEST_SRCS
|
|||
${CMAKE_CURRENT_SOURCE_DIR}/dlconvertor_test.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/native_test.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scalar_tensor_test.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/tensor_interop_test.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/test_parallel.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/undefined_tensor_test.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/verify_api_visibility.cpp
|
||||
|
|
|
|||
141
aten/src/ATen/test/tensor_interop_test.cpp
Normal file
141
aten/src/ATen/test/tensor_interop_test.cpp
Normal file
|
|
@ -0,0 +1,141 @@
|
|||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ATen/ATen.h"
|
||||
#include <caffe2/core/init.h>
|
||||
#include <caffe2/core/operator.h>
|
||||
|
||||
TEST(TestTensorInterop, Caffe2ToPytorchSimpleLegacy) {
|
||||
caffe2::Tensor c2_tensor(caffe2::CPU);
|
||||
c2_tensor.Resize(4, 4);
|
||||
auto data = c2_tensor.mutable_data<int64_t>();
|
||||
for (int64_t i = 0; i < 16; i++) {
|
||||
data[i] = i;
|
||||
}
|
||||
|
||||
// TODO: find out why calling data on tensor doesn't work
|
||||
at::Tensor at_tensor(c2_tensor.getIntrusivePtr());
|
||||
at::TensorImpl* impl = at_tensor.unsafeGetTensorImpl();
|
||||
|
||||
auto it = impl->data<int64_t>();
|
||||
for (int64_t i = 0; i < 16; i++) {
|
||||
ASSERT_EQ(it[i], i);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(TestTensorInterop, Caffe2ToPytorchSimple) {
|
||||
caffe2::Tensor c2_tensor = caffe2::empty({4, 4}, at::kLong);
|
||||
auto data = c2_tensor.mutable_data<int64_t>();
|
||||
for (int64_t i = 0; i < 16; i++) {
|
||||
data[i] = i;
|
||||
}
|
||||
at::Tensor at_tensor(c2_tensor.getIntrusivePtr());
|
||||
at::TensorImpl* impl = at_tensor.unsafeGetTensorImpl();
|
||||
|
||||
auto it = impl->data<int64_t>();
|
||||
for (int64_t i = 0; i < 16; i++) {
|
||||
ASSERT_EQ(it[i], i);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(TestTensorInterop, Caffe2ToPytorchOp) {
|
||||
caffe2::Tensor c2_tensor(caffe2::CPU);
|
||||
c2_tensor.Resize(3, 3);
|
||||
auto data = c2_tensor.mutable_data<int64_t>();
|
||||
for (int64_t i = 0; i < 9; i++) {
|
||||
data[i] = i;
|
||||
}
|
||||
at::Tensor at_tensor(c2_tensor.getIntrusivePtr());
|
||||
|
||||
ASSERT_EQ(at::sum(at_tensor).item<int64_t>(), 36);
|
||||
}
|
||||
|
||||
TEST(TestTensorInterop, Caffe2ToPytorchUnsupportedDevice) {
|
||||
caffe2::Tensor c2_tensor(caffe2::IDEEP);
|
||||
at::Tensor at_tensor(c2_tensor.getIntrusivePtr());
|
||||
ASSERT_ANY_THROW(at::sum(at_tensor));
|
||||
}
|
||||
|
||||
TEST(TestTensorInterop, PytorchToCaffe2Op) {
|
||||
caffe2::Workspace workspace;
|
||||
caffe2::NetDef net;
|
||||
|
||||
auto at_tensor_a = at::ones({5, 5}, at::dtype(at::kFloat));
|
||||
auto at_tensor_b = at::ones({5, 5}, at::dtype(at::kFloat));
|
||||
auto at_tensor_c = at::ones({5, 5}, at::dtype(at::kFloat));
|
||||
|
||||
auto* c2_tensor_a = BlobSetTensor(workspace.CreateBlob("a"), at_tensor_a.getIntrusivePtr());
|
||||
auto* c2_tensor_b = BlobSetTensor(workspace.CreateBlob("b"), at_tensor_b.getIntrusivePtr());
|
||||
|
||||
// Test ShareData as well
|
||||
{
|
||||
auto c2_tensor_c = XBlobGetMutableTensor(workspace.CreateBlob("c"), {0}, at::kCPU);
|
||||
c2_tensor_c.ResizeLike(at_tensor_c.getIntrusivePtr());
|
||||
c2_tensor_c.ShareData(at_tensor_c.getIntrusivePtr());
|
||||
}
|
||||
|
||||
{
|
||||
auto op = net.add_op();
|
||||
op->set_type("Sum");
|
||||
op->add_input("a");
|
||||
op->add_input("b");
|
||||
op->add_input("c");
|
||||
op->add_output("d");
|
||||
}
|
||||
|
||||
workspace.RunNetOnce(net);
|
||||
|
||||
auto result = XBlobGetMutableTensor(workspace.CreateBlob("d"), {5, 5}, at::kCPU);
|
||||
|
||||
auto it = result.data<float>();
|
||||
for (int64_t i = 0; i < 25; i++) {
|
||||
ASSERT_EQ(it[i], 3.0);
|
||||
}
|
||||
at::Tensor at_result(result.getIntrusivePtr());
|
||||
ASSERT_EQ(at::sum(at_result).item<float>(), 75);
|
||||
}
|
||||
|
||||
TEST(TestTensorInterop, PytorchToCaffe2SharedStorage) {
|
||||
caffe2::Workspace workspace;
|
||||
caffe2::NetDef net;
|
||||
|
||||
auto at_tensor_a = at::ones({5, 5}, at::dtype(at::kFloat));
|
||||
auto at_tensor_b = at_tensor_a.view({5, 5});
|
||||
|
||||
auto* c2_tensor_a = BlobSetTensor(workspace.CreateBlob("a"), at_tensor_a.getIntrusivePtr());
|
||||
auto* c2_tensor_b = BlobSetTensor(workspace.CreateBlob("b"), at_tensor_b.getIntrusivePtr());
|
||||
|
||||
{
|
||||
auto op = net.add_op();
|
||||
op->set_type("Add");
|
||||
op->add_input("a");
|
||||
op->add_input("b");
|
||||
op->add_output("c");
|
||||
}
|
||||
|
||||
workspace.RunNetOnce(net);
|
||||
|
||||
auto result = XBlobGetMutableTensor(workspace.CreateBlob("c"), {5, 5}, at::kCPU);
|
||||
auto it = result.data<float>();
|
||||
for (int64_t i = 0; i < 25; i++) {
|
||||
ASSERT_EQ(it[i], 2.0);
|
||||
}
|
||||
at::Tensor at_result(result.getIntrusivePtr());
|
||||
ASSERT_EQ(at::sum(at_result).item<float>(), 50);
|
||||
}
|
||||
|
||||
TEST(TestTensorInterop, PytorchToCaffe2Strided) {
|
||||
caffe2::Workspace workspace;
|
||||
caffe2::NetDef net;
|
||||
|
||||
auto at_tensor = at::ones({5, 5}, at::dtype(at::kFloat)).t();
|
||||
auto* c2_tensor = BlobSetTensor(workspace.CreateBlob("blob"), at_tensor.getIntrusivePtr());
|
||||
|
||||
{
|
||||
auto op = net.add_op();
|
||||
op->set_type("Sum");
|
||||
op->add_input("blob");
|
||||
op->add_output("out");
|
||||
}
|
||||
|
||||
ASSERT_ANY_THROW(workspace.RunNetOnce(net));
|
||||
}
|
||||
|
|
@ -15,6 +15,7 @@ VALGRIND=${VALGRIND:=ON}
|
|||
./dlconvertor_test
|
||||
./native_test
|
||||
./scalar_tensor_test
|
||||
./tensor_interop_test
|
||||
./undefined_tensor_test
|
||||
if [[ -x ./cudnn_test ]]; then
|
||||
./cudnn_test
|
||||
|
|
@ -37,6 +38,7 @@ fi
|
|||
if [ "$VALGRIND" == "ON" ]
|
||||
then
|
||||
valgrind --suppressions="$VALGRIND_SUP" --error-exitcode=1 ./basic "[cpu]"
|
||||
valgrind --suppressions="$VALGRIND_SUP" --error-exitcode=1 ./tensor_interop_test
|
||||
fi
|
||||
|
||||
popd
|
||||
|
|
|
|||
|
|
@ -28,6 +28,12 @@ class CAFFE2_API Tensor final {
|
|||
|
||||
public:
|
||||
Tensor() : impl_() {}
|
||||
Tensor(c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl)
|
||||
: impl_(std::move(tensor_impl)) {
|
||||
if (impl_.get() == nullptr) {
|
||||
throw std::runtime_error("TensorBaseImpl with nullptr not supported");
|
||||
}
|
||||
}
|
||||
|
||||
operator bool() const {
|
||||
return impl_.defined();
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user