diff --git a/aten/src/ATen/test/CMakeLists.txt b/aten/src/ATen/test/CMakeLists.txt
index 37f8cf306d2..d3f0d0134bd 100644
--- a/aten/src/ATen/test/CMakeLists.txt
+++ b/aten/src/ATen/test/CMakeLists.txt
@@ -15,6 +15,7 @@ list(APPEND ATen_CPU_TEST_SRCS
   ${CMAKE_CURRENT_SOURCE_DIR}/dlconvertor_test.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/native_test.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/scalar_tensor_test.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/tensor_interop_test.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/test_parallel.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/undefined_tensor_test.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/verify_api_visibility.cpp
diff --git a/aten/src/ATen/test/tensor_interop_test.cpp b/aten/src/ATen/test/tensor_interop_test.cpp
new file mode 100644
index 00000000000..ec3886b3065
--- /dev/null
+++ b/aten/src/ATen/test/tensor_interop_test.cpp
@@ -0,0 +1,141 @@
+#include "gtest/gtest.h"
+
+#include "ATen/ATen.h"
+#include <caffe2/core/init.h>
+#include <caffe2/core/operator.h>
+
+TEST(TestTensorInterop, Caffe2ToPytorchSimpleLegacy) {
+  caffe2::Tensor c2_tensor(caffe2::CPU);
+  c2_tensor.Resize(4, 4);
+  auto data = c2_tensor.mutable_data<int64_t>();
+  for (int64_t i = 0; i < 16; i++) {
+    data[i] = i;
+  }
+
+  // TODO: find out why calling data on tensor doesn't work
+  at::Tensor at_tensor(c2_tensor.getIntrusivePtr());
+  at::TensorImpl* impl = at_tensor.unsafeGetTensorImpl();
+
+  auto it = impl->data<int64_t>();
+  for (int64_t i = 0; i < 16; i++) {
+    ASSERT_EQ(it[i], i);
+  }
+}
+
+TEST(TestTensorInterop, Caffe2ToPytorchSimple) {
+  caffe2::Tensor c2_tensor = caffe2::empty({4, 4}, at::kLong);
+  auto data = c2_tensor.mutable_data<int64_t>();
+  for (int64_t i = 0; i < 16; i++) {
+    data[i] = i;
+  }
+  at::Tensor at_tensor(c2_tensor.getIntrusivePtr());
+  at::TensorImpl* impl = at_tensor.unsafeGetTensorImpl();
+
+  auto it = impl->data<int64_t>();
+  for (int64_t i = 0; i < 16; i++) {
+    ASSERT_EQ(it[i], i);
+  }
+}
+
+TEST(TestTensorInterop, Caffe2ToPytorchOp) {
+  caffe2::Tensor c2_tensor(caffe2::CPU);
+  c2_tensor.Resize(3, 3);
+  auto data = c2_tensor.mutable_data<int64_t>();
+  for (int64_t i = 0; i < 9; i++) {
+    data[i] = i;
+  }
+  at::Tensor at_tensor(c2_tensor.getIntrusivePtr());
+
+  ASSERT_EQ(at::sum(at_tensor).item<int64_t>(), 36);
+}
+
+TEST(TestTensorInterop, Caffe2ToPytorchUnsupportedDevice) {
+  caffe2::Tensor c2_tensor(caffe2::IDEEP);
+  at::Tensor at_tensor(c2_tensor.getIntrusivePtr());
+  ASSERT_ANY_THROW(at::sum(at_tensor));
+}
+
+TEST(TestTensorInterop, PytorchToCaffe2Op) {
+  caffe2::Workspace workspace;
+  caffe2::NetDef net;
+
+  auto at_tensor_a = at::ones({5, 5}, at::dtype(at::kFloat));
+  auto at_tensor_b = at::ones({5, 5}, at::dtype(at::kFloat));
+  auto at_tensor_c = at::ones({5, 5}, at::dtype(at::kFloat));
+
+  auto* c2_tensor_a = BlobSetTensor(workspace.CreateBlob("a"), at_tensor_a.getIntrusivePtr());
+  auto* c2_tensor_b = BlobSetTensor(workspace.CreateBlob("b"), at_tensor_b.getIntrusivePtr());
+
+  // Test ShareData as well
+  {
+    auto c2_tensor_c = XBlobGetMutableTensor(workspace.CreateBlob("c"), {0}, at::kCPU);
+    c2_tensor_c.ResizeLike(at_tensor_c.getIntrusivePtr());
+    c2_tensor_c.ShareData(at_tensor_c.getIntrusivePtr());
+  }
+
+  {
+    auto op = net.add_op();
+    op->set_type("Sum");
+    op->add_input("a");
+    op->add_input("b");
+    op->add_input("c");
+    op->add_output("d");
+  }
+
+  workspace.RunNetOnce(net);
+
+  auto result = XBlobGetMutableTensor(workspace.CreateBlob("d"), {5, 5}, at::kCPU);
+
+  auto it = result.data<float>();
+  for (int64_t i = 0; i < 25; i++) {
+    ASSERT_EQ(it[i], 3.0);
+  }
+  at::Tensor at_result(result.getIntrusivePtr());
+  ASSERT_EQ(at::sum(at_result).item<float>(), 75);
+}
+
+TEST(TestTensorInterop, PytorchToCaffe2SharedStorage) {
+  caffe2::Workspace workspace;
+  caffe2::NetDef net;
+
+  auto at_tensor_a = at::ones({5, 5}, at::dtype(at::kFloat));
+  auto at_tensor_b = at_tensor_a.view({5, 5});
+
+  auto* c2_tensor_a = BlobSetTensor(workspace.CreateBlob("a"), at_tensor_a.getIntrusivePtr());
+  auto* c2_tensor_b = BlobSetTensor(workspace.CreateBlob("b"), at_tensor_b.getIntrusivePtr());
+
+  {
+    auto op = net.add_op();
+    op->set_type("Add");
+    op->add_input("a");
+    op->add_input("b");
+    op->add_output("c");
+  }
+
+  workspace.RunNetOnce(net);
+
+  auto result = XBlobGetMutableTensor(workspace.CreateBlob("c"), {5, 5}, at::kCPU);
+  auto it = result.data<float>();
+  for (int64_t i = 0; i < 25; i++) {
+    ASSERT_EQ(it[i], 2.0);
+  }
+  at::Tensor at_result(result.getIntrusivePtr());
+  ASSERT_EQ(at::sum(at_result).item<float>(), 50);
+}
+
+TEST(TestTensorInterop, PytorchToCaffe2Strided) {
+  caffe2::Workspace workspace;
+  caffe2::NetDef net;
+
+  auto at_tensor = at::ones({5, 5}, at::dtype(at::kFloat)).t();
+  auto* c2_tensor = BlobSetTensor(workspace.CreateBlob("blob"), at_tensor.getIntrusivePtr());
+
+  {
+    auto op = net.add_op();
+    op->set_type("Sum");
+    op->add_input("blob");
+    op->add_output("out");
+  }
+
+  ASSERT_ANY_THROW(workspace.RunNetOnce(net));
+}
diff --git a/aten/tools/run_tests.sh b/aten/tools/run_tests.sh
index 720afe401ed..c2a0d2f47f0 100755
--- a/aten/tools/run_tests.sh
+++ b/aten/tools/run_tests.sh
@@ -15,6 +15,7 @@ VALGRIND=${VALGRIND:=ON}
 ./dlconvertor_test
 ./native_test
 ./scalar_tensor_test
+./tensor_interop_test
 ./undefined_tensor_test
 if [[ -x ./cudnn_test ]]; then
   ./cudnn_test
@@ -37,6 +38,7 @@ fi
 if [ "$VALGRIND" == "ON" ]
 then
   valgrind --suppressions="$VALGRIND_SUP" --error-exitcode=1 ./basic "[cpu]"
+  valgrind --suppressions="$VALGRIND_SUP" --error-exitcode=1 ./tensor_interop_test
 fi
 
 popd
diff --git a/caffe2/core/tensor.h b/caffe2/core/tensor.h
index a91e854a471..feacc64ef93 100644
--- a/caffe2/core/tensor.h
+++ b/caffe2/core/tensor.h
@@ -28,6 +28,12 @@ class CAFFE2_API Tensor final {
  public:
   Tensor() : impl_() {}
 
+  Tensor(c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl)
+      : impl_(std::move(tensor_impl)) {
+    if (impl_.get() == nullptr) {
+      throw std::runtime_error("TensorBaseImpl with nullptr not supported");
+    }
+  }
 
   operator bool() const {
     return impl_.defined();