mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
test basic tensor interop
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/12249 Differential Revision: D13469356 Pulled By: li-roy fbshipit-source-id: b49748462aa44ac34b8ce79783f2c895a537a232
This commit is contained in:
parent
70f0c4745b
commit
50fbf79451
|
|
@ -15,6 +15,7 @@ list(APPEND ATen_CPU_TEST_SRCS
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/dlconvertor_test.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/dlconvertor_test.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/native_test.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/native_test.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/scalar_tensor_test.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/scalar_tensor_test.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/tensor_interop_test.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/test_parallel.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/test_parallel.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/undefined_tensor_test.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/undefined_tensor_test.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/verify_api_visibility.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/verify_api_visibility.cpp
|
||||||
|
|
|
||||||
141
aten/src/ATen/test/tensor_interop_test.cpp
Normal file
141
aten/src/ATen/test/tensor_interop_test.cpp
Normal file
|
|
@ -0,0 +1,141 @@
|
||||||
|
#include "gtest/gtest.h"
|
||||||
|
|
||||||
|
#include "ATen/ATen.h"
|
||||||
|
#include <caffe2/core/init.h>
|
||||||
|
#include <caffe2/core/operator.h>
|
||||||
|
|
||||||
|
TEST(TestTensorInterop, Caffe2ToPytorchSimpleLegacy) {
|
||||||
|
caffe2::Tensor c2_tensor(caffe2::CPU);
|
||||||
|
c2_tensor.Resize(4, 4);
|
||||||
|
auto data = c2_tensor.mutable_data<int64_t>();
|
||||||
|
for (int64_t i = 0; i < 16; i++) {
|
||||||
|
data[i] = i;
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: find out why calling data on tensor doesn't work
|
||||||
|
at::Tensor at_tensor(c2_tensor.getIntrusivePtr());
|
||||||
|
at::TensorImpl* impl = at_tensor.unsafeGetTensorImpl();
|
||||||
|
|
||||||
|
auto it = impl->data<int64_t>();
|
||||||
|
for (int64_t i = 0; i < 16; i++) {
|
||||||
|
ASSERT_EQ(it[i], i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TestTensorInterop, Caffe2ToPytorchSimple) {
|
||||||
|
caffe2::Tensor c2_tensor = caffe2::empty({4, 4}, at::kLong);
|
||||||
|
auto data = c2_tensor.mutable_data<int64_t>();
|
||||||
|
for (int64_t i = 0; i < 16; i++) {
|
||||||
|
data[i] = i;
|
||||||
|
}
|
||||||
|
at::Tensor at_tensor(c2_tensor.getIntrusivePtr());
|
||||||
|
at::TensorImpl* impl = at_tensor.unsafeGetTensorImpl();
|
||||||
|
|
||||||
|
auto it = impl->data<int64_t>();
|
||||||
|
for (int64_t i = 0; i < 16; i++) {
|
||||||
|
ASSERT_EQ(it[i], i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TestTensorInterop, Caffe2ToPytorchOp) {
|
||||||
|
caffe2::Tensor c2_tensor(caffe2::CPU);
|
||||||
|
c2_tensor.Resize(3, 3);
|
||||||
|
auto data = c2_tensor.mutable_data<int64_t>();
|
||||||
|
for (int64_t i = 0; i < 9; i++) {
|
||||||
|
data[i] = i;
|
||||||
|
}
|
||||||
|
at::Tensor at_tensor(c2_tensor.getIntrusivePtr());
|
||||||
|
|
||||||
|
ASSERT_EQ(at::sum(at_tensor).item<int64_t>(), 36);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TestTensorInterop, Caffe2ToPytorchUnsupportedDevice) {
|
||||||
|
caffe2::Tensor c2_tensor(caffe2::IDEEP);
|
||||||
|
at::Tensor at_tensor(c2_tensor.getIntrusivePtr());
|
||||||
|
ASSERT_ANY_THROW(at::sum(at_tensor));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TestTensorInterop, PytorchToCaffe2Op) {
|
||||||
|
caffe2::Workspace workspace;
|
||||||
|
caffe2::NetDef net;
|
||||||
|
|
||||||
|
auto at_tensor_a = at::ones({5, 5}, at::dtype(at::kFloat));
|
||||||
|
auto at_tensor_b = at::ones({5, 5}, at::dtype(at::kFloat));
|
||||||
|
auto at_tensor_c = at::ones({5, 5}, at::dtype(at::kFloat));
|
||||||
|
|
||||||
|
auto* c2_tensor_a = BlobSetTensor(workspace.CreateBlob("a"), at_tensor_a.getIntrusivePtr());
|
||||||
|
auto* c2_tensor_b = BlobSetTensor(workspace.CreateBlob("b"), at_tensor_b.getIntrusivePtr());
|
||||||
|
|
||||||
|
// Test ShareData as well
|
||||||
|
{
|
||||||
|
auto c2_tensor_c = XBlobGetMutableTensor(workspace.CreateBlob("c"), {0}, at::kCPU);
|
||||||
|
c2_tensor_c.ResizeLike(at_tensor_c.getIntrusivePtr());
|
||||||
|
c2_tensor_c.ShareData(at_tensor_c.getIntrusivePtr());
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto op = net.add_op();
|
||||||
|
op->set_type("Sum");
|
||||||
|
op->add_input("a");
|
||||||
|
op->add_input("b");
|
||||||
|
op->add_input("c");
|
||||||
|
op->add_output("d");
|
||||||
|
}
|
||||||
|
|
||||||
|
workspace.RunNetOnce(net);
|
||||||
|
|
||||||
|
auto result = XBlobGetMutableTensor(workspace.CreateBlob("d"), {5, 5}, at::kCPU);
|
||||||
|
|
||||||
|
auto it = result.data<float>();
|
||||||
|
for (int64_t i = 0; i < 25; i++) {
|
||||||
|
ASSERT_EQ(it[i], 3.0);
|
||||||
|
}
|
||||||
|
at::Tensor at_result(result.getIntrusivePtr());
|
||||||
|
ASSERT_EQ(at::sum(at_result).item<float>(), 75);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TestTensorInterop, PytorchToCaffe2SharedStorage) {
|
||||||
|
caffe2::Workspace workspace;
|
||||||
|
caffe2::NetDef net;
|
||||||
|
|
||||||
|
auto at_tensor_a = at::ones({5, 5}, at::dtype(at::kFloat));
|
||||||
|
auto at_tensor_b = at_tensor_a.view({5, 5});
|
||||||
|
|
||||||
|
auto* c2_tensor_a = BlobSetTensor(workspace.CreateBlob("a"), at_tensor_a.getIntrusivePtr());
|
||||||
|
auto* c2_tensor_b = BlobSetTensor(workspace.CreateBlob("b"), at_tensor_b.getIntrusivePtr());
|
||||||
|
|
||||||
|
{
|
||||||
|
auto op = net.add_op();
|
||||||
|
op->set_type("Add");
|
||||||
|
op->add_input("a");
|
||||||
|
op->add_input("b");
|
||||||
|
op->add_output("c");
|
||||||
|
}
|
||||||
|
|
||||||
|
workspace.RunNetOnce(net);
|
||||||
|
|
||||||
|
auto result = XBlobGetMutableTensor(workspace.CreateBlob("c"), {5, 5}, at::kCPU);
|
||||||
|
auto it = result.data<float>();
|
||||||
|
for (int64_t i = 0; i < 25; i++) {
|
||||||
|
ASSERT_EQ(it[i], 2.0);
|
||||||
|
}
|
||||||
|
at::Tensor at_result(result.getIntrusivePtr());
|
||||||
|
ASSERT_EQ(at::sum(at_result).item<float>(), 50);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TestTensorInterop, PytorchToCaffe2Strided) {
|
||||||
|
caffe2::Workspace workspace;
|
||||||
|
caffe2::NetDef net;
|
||||||
|
|
||||||
|
auto at_tensor = at::ones({5, 5}, at::dtype(at::kFloat)).t();
|
||||||
|
auto* c2_tensor = BlobSetTensor(workspace.CreateBlob("blob"), at_tensor.getIntrusivePtr());
|
||||||
|
|
||||||
|
{
|
||||||
|
auto op = net.add_op();
|
||||||
|
op->set_type("Sum");
|
||||||
|
op->add_input("blob");
|
||||||
|
op->add_output("out");
|
||||||
|
}
|
||||||
|
|
||||||
|
ASSERT_ANY_THROW(workspace.RunNetOnce(net));
|
||||||
|
}
|
||||||
|
|
@ -15,6 +15,7 @@ VALGRIND=${VALGRIND:=ON}
|
||||||
./dlconvertor_test
|
./dlconvertor_test
|
||||||
./native_test
|
./native_test
|
||||||
./scalar_tensor_test
|
./scalar_tensor_test
|
||||||
|
./tensor_interop_test
|
||||||
./undefined_tensor_test
|
./undefined_tensor_test
|
||||||
if [[ -x ./cudnn_test ]]; then
|
if [[ -x ./cudnn_test ]]; then
|
||||||
./cudnn_test
|
./cudnn_test
|
||||||
|
|
@ -37,6 +38,7 @@ fi
|
||||||
if [ "$VALGRIND" == "ON" ]
|
if [ "$VALGRIND" == "ON" ]
|
||||||
then
|
then
|
||||||
valgrind --suppressions="$VALGRIND_SUP" --error-exitcode=1 ./basic "[cpu]"
|
valgrind --suppressions="$VALGRIND_SUP" --error-exitcode=1 ./basic "[cpu]"
|
||||||
|
valgrind --suppressions="$VALGRIND_SUP" --error-exitcode=1 ./tensor_interop_test
|
||||||
fi
|
fi
|
||||||
|
|
||||||
popd
|
popd
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,12 @@ class CAFFE2_API Tensor final {
|
||||||
|
|
||||||
public:
|
public:
|
||||||
Tensor() : impl_() {}
|
Tensor() : impl_() {}
|
||||||
|
Tensor(c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl)
|
||||||
|
: impl_(std::move(tensor_impl)) {
|
||||||
|
if (impl_.get() == nullptr) {
|
||||||
|
throw std::runtime_error("TensorBaseImpl with nullptr not supported");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
operator bool() const {
|
operator bool() const {
|
||||||
return impl_.defined();
|
return impl_.defined();
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user