test basic tensor interop

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/12249

Differential Revision: D13469356

Pulled By: li-roy

fbshipit-source-id: b49748462aa44ac34b8ce79783f2c895a537a232
This commit is contained in:
Roy Li 2018-12-27 17:01:19 -08:00 committed by Facebook Github Bot
parent 70f0c4745b
commit 50fbf79451
4 changed files with 150 additions and 0 deletions

View File

@ -15,6 +15,7 @@ list(APPEND ATen_CPU_TEST_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/dlconvertor_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/native_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scalar_tensor_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/tensor_interop_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_parallel.cpp
${CMAKE_CURRENT_SOURCE_DIR}/undefined_tensor_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/verify_api_visibility.cpp

View File

@ -0,0 +1,141 @@
#include "gtest/gtest.h"
#include "ATen/ATen.h"
#include <caffe2/core/init.h>
#include <caffe2/core/operator.h>
TEST(TestTensorInterop, Caffe2ToPytorchSimpleLegacy) {
  // Legacy-style construction: device-tagged ctor followed by Resize.
  caffe2::Tensor src(caffe2::CPU);
  src.Resize(4, 4);
  auto* buf = src.mutable_data<int64_t>();
  for (int64_t idx = 0; idx < 16; ++idx) {
    buf[idx] = idx;
  }
  // Wrap the caffe2 storage in an at::Tensor and check every element
  // is visible unchanged through the ATen view.
  // TODO: find out why calling data on tensor doesn't work
  at::Tensor wrapped(src.getIntrusivePtr());
  auto* seen = wrapped.unsafeGetTensorImpl()->data<int64_t>();
  for (int64_t idx = 0; idx < 16; ++idx) {
    ASSERT_EQ(seen[idx], idx);
  }
}
TEST(TestTensorInterop, Caffe2ToPytorchSimple) {
  // Modern-style construction via caffe2::empty with an ATen dtype.
  caffe2::Tensor src = caffe2::empty({4, 4}, at::kLong);
  auto* buf = src.mutable_data<int64_t>();
  for (int64_t idx = 0; idx < 16; ++idx) {
    buf[idx] = idx;
  }
  // The ATen wrapper must see the same 16 values written above.
  at::Tensor wrapped(src.getIntrusivePtr());
  auto* seen = wrapped.unsafeGetTensorImpl()->data<int64_t>();
  for (int64_t idx = 0; idx < 16; ++idx) {
    ASSERT_EQ(seen[idx], idx);
  }
}
TEST(TestTensorInterop, Caffe2ToPytorchOp) {
  // Fill a 3x3 caffe2 tensor with 0..8 and run an ATen reduction on it:
  // sum(0..8) == 36.
  caffe2::Tensor src(caffe2::CPU);
  src.Resize(3, 3);
  auto* buf = src.mutable_data<int64_t>();
  for (int64_t v = 0; v < 9; ++v) {
    buf[v] = v;
  }
  at::Tensor wrapped(src.getIntrusivePtr());
  ASSERT_EQ(at::sum(wrapped).item<int64_t>(), 36);
}
TEST(TestTensorInterop, Caffe2ToPytorchUnsupportedDevice) {
  // ATen kernels cannot run on an IDEEP-backed caffe2 tensor;
  // invoking an op on the wrapper must throw.
  caffe2::Tensor ideep_tensor(caffe2::IDEEP);
  at::Tensor wrapped(ideep_tensor.getIntrusivePtr());
  ASSERT_ANY_THROW(at::sum(wrapped));
}
TEST(TestTensorInterop, PytorchToCaffe2Op) {
  // Feeds three ATen-created tensors into a caffe2 "Sum" op and verifies
  // the result both through the caffe2 output blob and through an ATen
  // wrapper of that blob.
  caffe2::Workspace workspace;
  caffe2::NetDef net;

  auto at_tensor_a = at::ones({5, 5}, at::dtype(at::kFloat));
  auto at_tensor_b = at::ones({5, 5}, at::dtype(at::kFloat));
  auto at_tensor_c = at::ones({5, 5}, at::dtype(at::kFloat));

  // Hand the ATen storage to workspace blobs. The caffe2::Tensor* returned
  // by BlobSetTensor is not needed, so no unused locals are kept.
  BlobSetTensor(workspace.CreateBlob("a"), at_tensor_a.getIntrusivePtr());
  BlobSetTensor(workspace.CreateBlob("b"), at_tensor_b.getIntrusivePtr());

  // Test ShareData as well
  {
    auto c2_tensor_c = XBlobGetMutableTensor(workspace.CreateBlob("c"), {0}, at::kCPU);
    c2_tensor_c.ResizeLike(at_tensor_c.getIntrusivePtr());
    c2_tensor_c.ShareData(at_tensor_c.getIntrusivePtr());
  }

  {
    auto op = net.add_op();
    op->set_type("Sum");
    op->add_input("a");
    op->add_input("b");
    op->add_input("c");
    op->add_output("d");
  }
  workspace.RunNetOnce(net);

  // d = a + b + c with all inputs ones, so every element is exactly 3
  // (small-integer float sums are exact, ASSERT_EQ is safe here).
  auto result = XBlobGetMutableTensor(workspace.CreateBlob("d"), {5, 5}, at::kCPU);
  auto it = result.data<float>();
  for (int64_t i = 0; i < 25; i++) {
    ASSERT_EQ(it[i], 3.0);
  }

  // The same output is also readable as an at::Tensor: 25 * 3 == 75.
  at::Tensor at_result(result.getIntrusivePtr());
  ASSERT_EQ(at::sum(at_result).item<float>(), 75);
}
TEST(TestTensorInterop, PytorchToCaffe2SharedStorage) {
  // Two ATen tensors sharing one storage ("b" is a view of "a") are fed
  // into a caffe2 "Add" op; the op must read both correctly.
  caffe2::Workspace workspace;
  caffe2::NetDef net;

  auto at_tensor_a = at::ones({5, 5}, at::dtype(at::kFloat));
  auto at_tensor_b = at_tensor_a.view({5, 5});

  // The returned caffe2::Tensor* is unused; don't keep dead locals.
  BlobSetTensor(workspace.CreateBlob("a"), at_tensor_a.getIntrusivePtr());
  BlobSetTensor(workspace.CreateBlob("b"), at_tensor_b.getIntrusivePtr());

  {
    auto op = net.add_op();
    op->set_type("Add");
    op->add_input("a");
    op->add_input("b");
    op->add_output("c");
  }
  workspace.RunNetOnce(net);

  // c = a + b with both ones, so every element is exactly 2.
  auto result = XBlobGetMutableTensor(workspace.CreateBlob("c"), {5, 5}, at::kCPU);
  auto it = result.data<float>();
  for (int64_t i = 0; i < 25; i++) {
    ASSERT_EQ(it[i], 2.0);
  }

  // And the output is also visible from the ATen side: 25 * 2 == 50.
  at::Tensor at_result(result.getIntrusivePtr());
  ASSERT_EQ(at::sum(at_result).item<float>(), 50);
}
TEST(TestTensorInterop, PytorchToCaffe2Strided) {
  // A transposed (non-contiguous) ATen tensor handed to a caffe2 op must
  // be rejected: running the net is expected to throw.
  caffe2::Workspace workspace;
  caffe2::NetDef net;

  auto at_tensor = at::ones({5, 5}, at::dtype(at::kFloat)).t();
  // The returned caffe2::Tensor* is unused; don't keep a dead local.
  BlobSetTensor(workspace.CreateBlob("blob"), at_tensor.getIntrusivePtr());

  {
    auto op = net.add_op();
    op->set_type("Sum");
    op->add_input("blob");
    op->add_output("out");
  }
  ASSERT_ANY_THROW(workspace.RunNetOnce(net));
}

View File

@ -15,6 +15,7 @@ VALGRIND=${VALGRIND:=ON}
./dlconvertor_test
./native_test
./scalar_tensor_test
./tensor_interop_test
./undefined_tensor_test
if [[ -x ./cudnn_test ]]; then
./cudnn_test
@ -37,6 +38,7 @@ fi
if [ "$VALGRIND" == "ON" ]
then
valgrind --suppressions="$VALGRIND_SUP" --error-exitcode=1 ./basic "[cpu]"
valgrind --suppressions="$VALGRIND_SUP" --error-exitcode=1 ./tensor_interop_test
fi
popd

View File

@ -28,6 +28,12 @@ class CAFFE2_API Tensor final {
public:
// Default constructor: impl_ is default-constructed (no TensorImpl attached).
Tensor() : impl_() {}
// Takes ownership of an existing TensorImpl via intrusive_ptr. This ctor is
// non-explicit, so a caffe2::Tensor can be built implicitly from an
// at::Tensor's getIntrusivePtr() (the interop path the new tests exercise).
// A null impl is rejected up front rather than deferred to first use.
Tensor(c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl)
: impl_(std::move(tensor_impl)) {
if (impl_.get() == nullptr) {
throw std::runtime_error("TensorBaseImpl with nullptr not supported");
}
}
operator bool() const {
return impl_.defined();