mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/61466

## Goal

Per #55126 the performance of `reshape` is worse than `alias` in cases where they are performing the same operation (i.e. where reshape is returning a view) because `reshape` delegates to `view` and duplicates some of the operations (specifically `infer_size_dv` and `computeStride`).

The goal of this pull-request is to reduce or remove the additional overhead that `reshape` has.

### Proposed Implementation

Instead of using `view` we implement a private/internal operator (`_reshape_alias`) that `reshape` dispatches to which skips the relevant checks. This is functionally equivalent to `as_strided` however it is a lot simpler because it's specialized to this use-case, and importantly the `backward` implementation is a lot faster.

Note that we have to dispatch (`reshape` is a composite operator) because `reshape` can return either a view or a copy of the Tensor depending on the parameters, and this complicates implementing a derivative/backward for `reshape`.

### Why not `as_strided`?

Using `as_strided` directly slows down autograd. If we use a custom function equivalent to `_reshape_alias` but with a simpler backward function then `view` has the same performance as `reshape`. If we delegate to `as_strided` it is about 56% slower (and this holds against our custom function).

This is also the reason we make an internal operator named `_reshape_alias` instead of exposing a new operator since this should only be used in the `reshape` case and it is effectively a more limited version of `view`, `alias`, and `as_strided`.
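To make the duplicated work more concrete, here is a minimal, self-contained sketch of the two kinds of computation named above: resolving a `-1` dimension and computing strides. It is plain C++ and is not the ATen implementation. The helper names `infer_size_sketch` and `contiguous_strides_sketch` are hypothetical, and the stride helper only covers the contiguous case, whereas the real `computeStride` also has to handle non-contiguous inputs.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical helper (illustration only): resolve a single -1 entry in
// `shape` so that the product of all dimensions equals `numel`.
std::vector<int64_t> infer_size_sketch(std::vector<int64_t> shape, int64_t numel) {
  int64_t known = 1;
  int infer_index = -1;
  for (int i = 0; i < static_cast<int>(shape.size()); ++i) {
    if (shape[i] == -1) {
      if (infer_index != -1) {
        throw std::invalid_argument("only one dimension can be inferred");
      }
      infer_index = i;
    } else {
      known *= shape[i];
    }
  }
  if (infer_index != -1) {
    if (known == 0 || numel % known != 0) {
      throw std::invalid_argument("shape is invalid for input size");
    }
    shape[infer_index] = numel / known;
  } else if (known != numel) {
    throw std::invalid_argument("shape is invalid for input size");
  }
  return shape;
}

// Hypothetical helper (illustration only): row-major strides for a
// contiguous tensor with the given sizes.
std::vector<int64_t> contiguous_strides_sketch(const std::vector<int64_t>& sizes) {
  std::vector<int64_t> strides(sizes.size());
  int64_t running = 1;
  for (int i = static_cast<int>(sizes.size()) - 1; i >= 0; --i) {
    strides[i] = running;
    running *= sizes[i];
  }
  return strides;
}
```

For the `{2,2}` tensor used in the benchmarks, `infer_size_sketch({-1}, 4)` resolves to `{4}` and `contiguous_strides_sketch({4})` gives `{1}`. Doing this kind of work once in `reshape` and again in `view` is the overhead the pull request aims to avoid.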
## Benchmarks

In a micro-benchmark for `backward` running:

```cpp
// Setup
at::Tensor x=torch::empty({2,2}, torch::requires_grad(true));

// Benchmark loop
// `reshape(-1)` replaced with a call to view(-1) for view baseline
x.pow(4).reshape(-1).mean().backward();
```

I also benchmarked simple operations without gradients using:

```cpp
// Setup
at::Tensor x=torch::empty({2,2}, torch::requires_grad(true));

// Benchmark loop
x.reshape(-1) // replaced with a call to view(-1) for view baseline
```

Baselined to `view`:

* Original `reshape`: `+3.3%` (without gradients `+20.8%`)
* Using `as_strided`: `+55.1%` (without gradients `+1.0%`)
* Using custom `_reshape_alias`: `-1.0%` (without gradients `+6.2%`)

In absolute terms (note the percentages above were generated comparing between runs/tests rather than to a single baseline):

* Original `view`: `53.66 us` (without gradients `582.78 ns`)
* Original `reshape`: `55.46 us` (without gradients `704.24 ns`)
* Using `as_strided`: `83.24 us` (without gradients `576.49 ns`)
* Using custom `_reshape_alias`: `53.13 us` (without gradients `536.01 ns`)

Note that the first set of benchmarks performs a backward operation as well. When compared without any gradient computation, the performance differences are more pronounced because the `reshape` call itself takes up more of the measured time.
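As a worked example of how such percentages can be derived from the medians, the sketch below computes the relative difference of a measurement against a baseline. The function name `relative_overhead_percent` is hypothetical and is not part of the benchmark tooling; the numbers in the comments are the `reshape` and `view` medians quoted above.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical helper (illustration only): relative difference of
// `measured` against `baseline`, expressed in percent.
double relative_overhead_percent(double measured, double baseline) {
  return (measured - baseline) / baseline * 100.0;
}

// With gradients:    reshape 55.46 us  vs view 53.66 us  -> about +3.35%
// Without gradients: reshape 704.24 ns vs view 582.78 ns -> about +20.84%
```

These two results are consistent with the `+3.3%` and `+20.8%` figures reported for the original `reshape`. The other rows were compared between different runs, so they may not reproduce exactly from the absolute values listed.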
### Original performance

<details>
<summary>Benchmark results</summary>

```
[<torch.utils.benchmark.utils.common.Measurement object at 0x7f0e4d393160>
x.pow(4).view(-1).mean().backward();
setup: at::Tensor x=torch::empty({2,2}, torch::requires_grad(true));
  Median: 53.66 us
  IQR: 2.70 us (52.54 to 55.24)
  884 measurements, 100 runs per measurement, 1 thread]
[<torch.utils.benchmark.utils.common.Measurement object at 0x7f0e2ebd4fa0>
x.pow(4).reshape(-1).mean().backward();
setup: at::Tensor x=torch::empty({2,2}, torch::requires_grad(true));
  Median: 55.46 us
  IQR: 2.61 us (54.39 to 57.01)
  889 measurements, 100 runs per measurement, 1 thread]
2276116
2286256
<torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7f0e5b2e3e20>
2640 ???:at::detail::computeStride(c10::ArrayRef<long>, c10::ArrayRef<long>, c10::SmallVector<long, 5u> const&)
1920 ???:at::native::reshape(at::Tensor const&, c10::ArrayRef<long>)
1520 ???:at::_ops::reshape::call(at::Tensor const&, c10::ArrayRef<long>)
1040 ???:c10::SmallVectorImpl<long>::operator=(c10::SmallVectorImpl<long>&&)
980 ???:void at::infer_size_impl<c10::SmallVector<long, 5u> >(c10::ArrayRef<long>, long, c10::SmallVector<long, 5u>&)
720 ???:__tls_get_addr
520 ???:at::shouldRunRecordFunction(bool*)
520 ???:__memcpy_avx_unaligned_erms
200 ???:c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10:: ... g>)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)
100 ???:c10::TensorImpl::strides() const
100 ???:c10::TensorImpl::sizes() const
100 ???:at::(anonymous namespace)::manager()
77 /tmp/benchmark_utils_jit_build__1626465284__8a34e7ff-cd37-4a82-be28-7f19e081e771/timer_cpp_7815557938202456331/timer_src.cpp:main
40 ???:c10::TensorImpl::numel() const
-77 /tmp/benchmark_utils_jit_build__1626465284__8a34e7ff-cd37-4a82-be28-7f19e081e771/timer_cpp_8055217880649990171/timer_src.cpp:main
-260 ???:at::native::view(at::Tensor const&, c10::ArrayRef<long>)
Total: 10140
```

```
[<torch.utils.benchmark.utils.common.Measurement object at 0x7f850dd66c10>
x.view(-1);
setup: at::Tensor x=torch::empty({2,2});
  Median: 582.78 ns
  IQR: 33.80 ns (573.80 to 607.61)
  833 measurements, 10000 runs per measurement, 1 thread]
[<torch.utils.benchmark.utils.common.Measurement object at 0x7f850de31e20>
x.reshape(-1);
setup: at::Tensor x=torch::empty({2,2});
  Median: 704.24 ns
  IQR: 24.42 ns (697.20 to 721.62)
  679 measurements, 10000 runs per measurement, 1 thread]
56896
67036
<torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7f84e1930bb0>
2640 ???:at::detail::computeStride(c10::ArrayRef<long>, c10::ArrayRef<long>, c10::SmallVector<long, 5u> const&)
1920 ???:at::native::reshape(at::Tensor const&, c10::ArrayRef<long>)
1520 ???:at::_ops::reshape::call(at::Tensor const&, c10::ArrayRef<long>)
1040 ???:c10::SmallVectorImpl<long>::operator=(c10::SmallVectorImpl<long>&&)
980 ???:void at::infer_size_impl<c10::SmallVector<long, 5u> >(c10::ArrayRef<long>, long, c10::SmallVector<long, 5u>&)
720 ???:__tls_get_addr
520 ???:at::shouldRunRecordFunction(bool*)
520 ???:__memcpy_avx_unaligned_erms
200 ???:c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10:: ... g>)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)
100 ???:c10::TensorImpl::strides() const
100 ???:c10::TensorImpl::sizes() const
100 ???:at::(anonymous namespace)::manager()
76 /tmp/benchmark_utils_jit_build__1626466038__15fbbac0-2072-4459-8f8e-08121a905b99/timer_cpp_547407365342278353/timer_src.cpp:main
40 ???:c10::TensorImpl::numel() const
-76 /tmp/benchmark_utils_jit_build__1626466038__15fbbac0-2072-4459-8f8e-08121a905b99/timer_cpp_3457873755756181226/timer_src.cpp:main
-260 ???:at::native::view(at::Tensor const&, c10::ArrayRef<long>)
Total: 10140
```

</details>

### Using `as_strided`

<details>
<summary>Benchmark results</summary>

```
[<torch.utils.benchmark.utils.common.Measurement object at 0x7f8b13bb5b50>
x.pow(4).view(-1).mean().backward();
setup: at::Tensor x=torch::empty({2,2}, torch::requires_grad(true));
  Median: 53.37 us
  IQR: 3.15 us (51.73 to 54.88)
  936 measurements, 100 runs per measurement, 1 thread]
[<torch.utils.benchmark.utils.common.Measurement object at 0x7f8af55f8490>
x.pow(4).reshape(-1).mean().backward();
setup: at::Tensor x=torch::empty({2,2}, torch::requires_grad(true));
  Median: 83.24 us
  IQR: 4.05 us (81.20 to 85.25)
  609 measurements, 100 runs per measurement, 1 thread]
2267916
2525061
<torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7f8af55f8e50>
31930 ???:_int_free
15940 ???:malloc
11595 ???:_int_malloc
10100 ???:torch::autograd::generated::details::as_strided_backward(at::Tensor, at::TensorGeometry, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::optional<long>)
9360 ???:__tls_get_addr
8280 ???:free
8100 ???:torch::autograd::VariableType::(anonymous namespace)::as_strided(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::optional<long>)
4520 ???:c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::reset_()
4080 ???:operator new(unsigned long)
...
-780 ???:at::_ops::view::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)
-920 ???:c10::SmallVectorImpl<long>::operator=(c10::SmallVectorImpl<long> const&)
-1220 ???:torch::autograd::generated::ViewBackward::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&)
-1520 ???:at::_ops::view::call(at::Tensor const&, c10::ArrayRef<long>)
-1580 ???:torch::ADInplaceOrView::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)
-1680 ???:at::Tensor at::native::alias_with_sizes_and_strides<c10::SmallVector<long, 5u> >(at::Tensor const&, c10::SmallVector<long, 5u> const&, c10::SmallVector<long, 5u> const&)
-2560 ???:at::detail::computeStride(c10::ArrayRef<long>, c10::ArrayRef<long>, c10::SmallVector<long, 5u> const&)
-2640 ???:at::native::view(at::Tensor const&, c10::ArrayRef<long>)
-4860 ???:torch::autograd::VariableType::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)
Total: 257145
```

```
[<torch.utils.benchmark.utils.common.Measurement object at 0x7f93176a0160>
x.view(-1);
setup: at::Tensor x=torch::empty({2,2});
  Median: 570.55 ns
  IQR: 32.69 ns (552.87 to 585.56)
  874 measurements, 10000 runs per measurement, 1 thread]
[<torch.utils.benchmark.utils.common.Measurement object at 0x7f92f8f29490>
x.reshape(-1);
setup: at::Tensor x=torch::empty({2,2});
  Median: 576.49 ns
  IQR: 37.95 ns (559.51 to 597.46)
  861 measurements, 10000 runs per measurement, 1 thread]
56896
58556
<torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7f932556ca60>
2140 ???:at::native::reshape(at::Tensor const&, c10::ArrayRef<long>)
1940 ???:torch::autograd::VariableType::(anonymous namespace)::as_strided(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::optional<long>)
1880 ???:torch::ADInplaceOrView::(anonymous namespace)::as_strided(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::optional<long>)
1720 ???:at::_ops::as_strided::call(at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::optional<long>)
1520 ???:at::_ops::reshape::call(at::Tensor const&, c10::ArrayRef<long>)
1400 ???:at::native::as_strided_tensorimpl(at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::optional<long>)
1260 ???:at::_ops::as_strided::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::optional<long>)'2
1260 ???:at::_ops::as_strided::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::optional<long>)
980 ???:void at::infer_size_impl<c10::SmallVector<long, 5u> >(c10::ArrayRef<long>, long, c10::SmallVector<long, 5u>&)
...
-620 ???:at::Tensor c10::Dispatcher::redispatch<at::Tensor, at::Tensor const&, c10::ArrayRef<long ... ::ArrayRef<long>)> const&, c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>) const
-780 ???:at::_ops::view::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)'2
-780 ???:at::_ops::view::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)
-920 ???:c10::SmallVectorImpl<long>::operator=(c10::SmallVectorImpl<long> const&)
-1520 ???:at::_ops::view::call(at::Tensor const&, c10::ArrayRef<long>)
-1580 ???:torch::ADInplaceOrView::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)
-1680 ???:at::Tensor at::native::alias_with_sizes_and_strides<c10::SmallVector<long, 5u> >(at::Tensor const&, c10::SmallVector<long, 5u> const&, c10::SmallVector<long, 5u> const&)
-1740 ???:torch::autograd::VariableType::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)
-2640 ???:at::native::view(at::Tensor const&, c10::ArrayRef<long>)
Total: 1660
```

</details>

### Using custom function (`_reshape_alias`)

<details>
<summary>Benchmark results</summary>

```
[<torch.utils.benchmark.utils.common.Measurement object at 0x7f16861d6b50>
x.pow(4).view(-1).mean().backward();
setup: at::Tensor x=torch::empty({2,2}, torch::requires_grad(true));
  Median: 53.50 us
  IQR: 2.64 us (52.32 to 54.96)
  906 measurements, 100 runs per measurement, 1 thread]
[<torch.utils.benchmark.utils.common.Measurement object at 0x7f1667b2ed60>
x.pow(4).reshape(-1).mean().backward();
setup: at::Tensor x=torch::empty({2,2}, torch::requires_grad(true));
  Median: 53.13 us
  IQR: 3.40 us (51.72 to 55.13)
  914 measurements, 100 runs per measurement, 1 thread]
2269736
2273236
<torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7f1693f8dc10>
5060 ???:torch::autograd::VariableType::(anonymous namespace)::_reshape_alias(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>)
2000 ???:at::native::reshape(at::Tensor const&, c10::ArrayRef<long>)
1780 ???:torch::ADInplaceOrView::(anonymous namespace)::_reshape_alias(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>)
1660 ???:at::_ops::_reshape_alias::call(at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>)
1600 ???:at::Tensor at::native::alias_with_sizes_and_strides<c10::ArrayRef<long> >(at::Tensor const&, c10::ArrayRef<long> const&, c10::ArrayRef<long> const&)
1520 ???:at::_ops::reshape::call(at::Tensor const&, c10::ArrayRef<long>)
1240 ???:at::_ops::_reshape_alias::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>)'2
1240 ???:at::_ops::_reshape_alias::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>)
1220 ???:torch::autograd::generated::AliasToShapeBackward::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&)
...
-780 ???:at::_ops::view::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)'2
-780 ???:at::_ops::view::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)
-920 ???:c10::SmallVectorImpl<long>::operator=(c10::SmallVectorImpl<long> const&)
-1220 ???:torch::autograd::generated::ViewBackward::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&)
-1520 ???:at::_ops::view::call(at::Tensor const&, c10::ArrayRef<long>)
-1580 ???:torch::ADInplaceOrView::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)
-1680 ???:at::Tensor at::native::alias_with_sizes_and_strides<c10::SmallVector<long, 5u> >(at::Tensor const&, c10::SmallVector<long, 5u> const&, c10::SmallVector<long, 5u> const&)
-2640 ???:at::native::view(at::Tensor const&, c10::ArrayRef<long>)
-4860 ???:torch::autograd::VariableType::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)
Total: 3500
```

```
[<torch.utils.benchmark.utils.common.Measurement object at 0x7f5287adfb20>
x.view(-1);
setup: at::Tensor x=torch::empty({2,2});
  Median: 505.10 ns
  IQR: 20.04 ns (500.41 to 520.45)
  944 measurements, 10000 runs per measurement, 1 thread]
[<torch.utils.benchmark.utils.common.Measurement object at 0x7f526951b430>
x.reshape(-1);
setup: at::Tensor x=torch::empty({2,2});
  Median: 536.01 ns
  IQR: 17.81 ns (531.34 to 549.16)
  916 measurements, 10000 runs per measurement, 1 thread]
56896
60376
<torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7f5295896c10>
2000 ???:at::native::reshape(at::Tensor const&, c10::ArrayRef<long>)
1860 ???:torch::autograd::VariableType::(anonymous namespace)::_reshape_alias(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>)
1780 ???:torch::ADInplaceOrView::(anonymous namespace)::_reshape_alias(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>)
1660 ???:at::_ops::_reshape_alias::call(at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>)
1600 ???:at::Tensor at::native::alias_with_sizes_and_strides<c10::ArrayRef<long> >(at::Tensor const&, c10::ArrayRef<long> const&, c10::ArrayRef<long> const&)
1520 ???:at::_ops::reshape::call(at::Tensor const&, c10::ArrayRef<long>)
1240 ???:at::_ops::_reshape_alias::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>)'2
1240 ???:at::_ops::_reshape_alias::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>)
980 ???:void at::infer_size_impl<c10::SmallVector<long, 5u> >(c10::ArrayRef<long>, long, c10::SmallVector<long, 5u>&)
...
-620 ???:at::Tensor c10::Dispatcher::redispatch<at::Tensor, at::Tensor const&, c10::ArrayRef<long ... ::ArrayRef<long>)> const&, c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>) const
-780 ???:at::_ops::view::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)'2
-780 ???:at::_ops::view::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)
-920 ???:c10::SmallVectorImpl<long>::operator=(c10::SmallVectorImpl<long> const&)
-1520 ???:at::_ops::view::call(at::Tensor const&, c10::ArrayRef<long>)
-1580 ???:torch::ADInplaceOrView::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)
-1680 ???:at::Tensor at::native::alias_with_sizes_and_strides<c10::SmallVector<long, 5u> >(at::Tensor const&, c10::SmallVector<long, 5u> const&, c10::SmallVector<long, 5u> const&)
-1740 ???:torch::autograd::VariableType::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)
-2640 ???:at::native::view(at::Tensor const&, c10::ArrayRef<long>)
Total: 3480
```

</details>

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D29792126

Pulled By: laurencer

fbshipit-source-id: f0519b45b65f868aa3e8651679354558bd761dfd
1145 lines
44 KiB
C++
#include <gtest/gtest.h>
#include <test/cpp/api/support.h>

#include <torch/torch.h>

#include <cmath>
#include <cstddef>
#include <vector>

#include <test/cpp/common/support.h>

using namespace torch::test;

template <typename T>
bool exactly_equal(at::Tensor left, T right) {
  return left.item<T>() == right;
}

template <typename T>
bool almost_equal(at::Tensor left, T right, double tolerance = 1e-4) {
  return std::abs(left.item<T>() - right) < tolerance;
}

#define REQUIRE_TENSOR_OPTIONS(device_, index_, type_, layout_)            \
  ASSERT_TRUE(                                                             \
      tensor.device().type() == at::Device((device_), (index_)).type());   \
  ASSERT_TRUE(                                                             \
      tensor.device().index() == at::Device((device_), (index_)).index()); \
  ASSERT_EQ(tensor.dtype(), (type_));                                      \
  ASSERT_TRUE(tensor.layout() == (layout_))

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, ToDtype) {
  auto tensor = at::empty({3, 4});
  REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kFloat, at::kStrided);

  tensor = tensor.to(at::kInt);
  REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kInt, at::kStrided);

  tensor = tensor.to(at::kChar);
  REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kChar, at::kStrided);

  tensor = tensor.to(at::kDouble);
  REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kDouble, at::kStrided);

  tensor = tensor.to(at::TensorOptions(at::kInt));
  REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kInt, at::kStrided);

  tensor = tensor.to(at::TensorOptions(at::kChar));
  REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kChar, at::kStrided);

  tensor = tensor.to(at::TensorOptions(at::kDouble));
  REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kDouble, at::kStrided);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, ToTensorAndTensorAttributes) {
  auto tensor = at::empty({3, 4});
  REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kFloat, at::kStrided);

  auto other = at::empty({3, 4}, at::kInt);
  tensor = tensor.to(other);
  REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kInt, at::kStrided);

  other = at::empty({3, 4}, at::kDouble);
  tensor = tensor.to(other.dtype());
  REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kDouble, at::kStrided);
  tensor = tensor.to(other.device());
  REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kDouble, at::kStrided);

  other = at::empty({3, 4}, at::kLong);
  tensor = tensor.to(other.device(), other.dtype());
  REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kLong, at::kStrided);

  other = at::empty({3, 4}, at::kInt);
  tensor = tensor.to(other.options());
  REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kInt, at::kStrided);
}

// Not currently supported.
// TEST(TensorTest, ToLayout) {
//   auto tensor = at::empty({3, 4});
//   REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kFloat, at::kStrided);
//
//   tensor = tensor.to(at::kSparse);
//   REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kFloat, at::kSparse);
//
//   tensor = tensor.to(at::kStrided);
//   REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kFloat, at::kStrided);
// }

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, ToOptionsWithRequiresGrad) {
  {
    // Respects requires_grad
    auto tensor = torch::empty({3, 4}, at::requires_grad());
    ASSERT_TRUE(tensor.requires_grad());

    tensor = tensor.to(at::kDouble);
    ASSERT_TRUE(tensor.requires_grad());

    // Throws if requires_grad is set in TensorOptions
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
    ASSERT_THROW(
        tensor.to(at::TensorOptions().requires_grad(true)), c10::Error);

    // Doesn't throw if requires_grad is not set
    tensor.to(at::TensorOptions());
    tensor.to(at::TensorOptions().requires_grad(false));
  }
  {
    auto tensor = torch::empty({3, 4});
    ASSERT_FALSE(tensor.requires_grad());

    // Respects requires_grad
    tensor = tensor.to(at::kDouble);
    ASSERT_FALSE(tensor.requires_grad());

    // Throws if requires_grad is set in TensorOptions
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
    ASSERT_THROW(
        tensor.to(at::TensorOptions().requires_grad(true)), c10::Error);

    // Doesn't throw if requires_grad is not set
    tensor.to(at::TensorOptions());
    tensor.to(at::TensorOptions().requires_grad(false));
  }
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, ToDoesNotCopyWhenOptionsAreAllTheSame) {
  {
    auto tensor = at::empty({3, 4}, at::kFloat);
    auto hopefully_not_copy = tensor.to(at::kFloat);
    ASSERT_EQ(hopefully_not_copy.data_ptr<float>(), tensor.data_ptr<float>());
  }
  {
    auto tensor = at::empty({3, 4}, at::kFloat);
    auto hopefully_not_copy = tensor.to(tensor.options());
    ASSERT_EQ(hopefully_not_copy.data_ptr<float>(), tensor.data_ptr<float>());
  }
  {
    auto tensor = at::empty({3, 4}, at::kFloat);
    auto hopefully_not_copy = tensor.to(tensor.dtype());
    ASSERT_EQ(hopefully_not_copy.data_ptr<float>(), tensor.data_ptr<float>());
  }
  {
    auto tensor = at::empty({3, 4}, at::kFloat);
    auto hopefully_not_copy = tensor.to(tensor.device());
    ASSERT_EQ(hopefully_not_copy.data_ptr<float>(), tensor.data_ptr<float>());
  }
  {
    auto tensor = at::empty({3, 4}, at::kFloat);
    auto hopefully_not_copy = tensor.to(tensor);
    ASSERT_EQ(hopefully_not_copy.data_ptr<float>(), tensor.data_ptr<float>());
  }
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, AtTensorCtorScalar) {
  auto tensor = at::tensor(123);
  ASSERT_EQ(tensor.numel(), 1);
  ASSERT_EQ(tensor.dtype(), at::kInt);
  ASSERT_EQ(tensor[0].item<int32_t>(), 123);

  tensor = at::tensor(123.456f);
  ASSERT_EQ(tensor.numel(), 1);
  ASSERT_EQ(tensor.dtype(), at::kFloat);
  ASSERT_TRUE(almost_equal(tensor[0], 123.456f));

  tensor = at::tensor(123.456);
  ASSERT_EQ(tensor.numel(), 1);
  ASSERT_EQ(tensor.dtype(), at::kDouble);
  ASSERT_TRUE(almost_equal(tensor[0], 123.456));

  tensor = at::tensor(123, at::dtype(at::kFloat)) + 0.5;
  ASSERT_EQ(tensor.numel(), 1);
  ASSERT_EQ(tensor.dtype(), at::kFloat);
  ASSERT_TRUE(almost_equal(tensor[0], 123.5));

  tensor = at::tensor(c10::complex<float>(1.0, 2.0)) + 0.5;
  ASSERT_EQ(tensor.numel(), 1);
  ASSERT_EQ(tensor.dtype(), at::kComplexFloat);
  ASSERT_TRUE(almost_equal(tensor[0], c10::complex<float>(1.5, 2.0)));

  tensor = at::tensor(c10::complex<float>(1.0, 2.0), at::dtype(at::kComplexFloat)) + 0.5;
  ASSERT_EQ(tensor.numel(), 1);
  ASSERT_EQ(tensor.dtype(), at::kComplexFloat);
  ASSERT_TRUE(almost_equal(tensor[0], c10::complex<float>(1.5, 2.0)));

  tensor = at::tensor(c10::complex<double>(1.0, 2.0)) + 0.5;
  ASSERT_EQ(tensor.numel(), 1);
  ASSERT_EQ(tensor.dtype(), at::kComplexDouble);
  ASSERT_TRUE(almost_equal(tensor[0], c10::complex<double>(1.5, 2.0)));

  tensor = at::tensor(c10::complex<float>(1.0, 2.0), at::dtype(at::kComplexDouble)) + 0.5;
  ASSERT_EQ(tensor.numel(), 1);
  ASSERT_EQ(tensor.dtype(), at::kComplexDouble);
  ASSERT_TRUE(almost_equal(tensor[0], c10::complex<double>(1.5, 2.0)));
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, AtTensorCtorSingleDim) {
  auto tensor = at::tensor({1, 2, 3});
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.dtype(), at::kInt);
  ASSERT_TRUE(exactly_equal(tensor[0], 1));
  ASSERT_TRUE(exactly_equal(tensor[1], 2));
  ASSERT_TRUE(exactly_equal(tensor[2], 3));

  tensor = at::tensor(std::vector<int>({1, 2, 3}));
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.dtype(), at::kInt);
  ASSERT_TRUE(exactly_equal(tensor[0], 1));
  ASSERT_TRUE(exactly_equal(tensor[1], 2));
  ASSERT_TRUE(exactly_equal(tensor[2], 3));

  tensor = at::tensor({1.5, 2.25, 3.125});
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.dtype(), at::kDouble);
  ASSERT_TRUE(almost_equal(tensor[0], 1.5));
  ASSERT_TRUE(almost_equal(tensor[1], 2.25));
  ASSERT_TRUE(almost_equal(tensor[2], 3.125));

  tensor = at::tensor({c10::complex<float>(1.5, 0.15), c10::complex<float>(1.5, 0.15), c10::complex<float>(3.125, 0.3125)});
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.dtype(), at::kComplexFloat);
  ASSERT_TRUE(almost_equal(tensor[0], c10::complex<float>(1.5, 0.15)));
  ASSERT_TRUE(almost_equal(tensor[1], c10::complex<float>(1.5, 0.15)));
  ASSERT_TRUE(almost_equal(tensor[2], c10::complex<float>(3.125, 0.3125)));

  tensor = at::tensor({c10::complex<double>(1.5, 0.15), c10::complex<double>(1.5, 0.15), c10::complex<double>(3.125, 0.3125)});
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.dtype(), at::kComplexDouble);
  ASSERT_TRUE(almost_equal(tensor[0], c10::complex<double>(1.5, 0.15)));
  ASSERT_TRUE(almost_equal(tensor[1], c10::complex<double>(1.5, 0.15)));
  ASSERT_TRUE(almost_equal(tensor[2], c10::complex<double>(3.125, 0.3125)));

  tensor = at::tensor({1.1, 2.2, 3.3}, at::dtype(at::kInt));
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.dtype(), at::kInt);
  ASSERT_EQ(tensor.layout(), at::kStrided);
  ASSERT_TRUE(exactly_equal(tensor[0], 1));
  ASSERT_TRUE(exactly_equal(tensor[1], 2));
  ASSERT_TRUE(exactly_equal(tensor[2], 3));

  tensor = at::tensor(std::vector<double>({1.5, 2.25, 3.125}));
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.dtype(), at::kDouble);
  ASSERT_TRUE(almost_equal(tensor[0], 1.5));
  ASSERT_TRUE(almost_equal(tensor[1], 2.25));
  ASSERT_TRUE(almost_equal(tensor[2], 3.125));

  tensor = at::tensor(std::vector<c10::complex<float>>({c10::complex<float>(1.5, 0.15), c10::complex<float>(1.5, 0.15), c10::complex<float>(3.125, 0.3125)}));
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.dtype(), at::kComplexFloat);
  ASSERT_TRUE(almost_equal(tensor[0], c10::complex<float>(1.5, 0.15)));
  ASSERT_TRUE(almost_equal(tensor[1], c10::complex<float>(1.5, 0.15)));
  ASSERT_TRUE(almost_equal(tensor[2], c10::complex<float>(3.125, 0.3125)));

  tensor = at::tensor(std::vector<c10::complex<double>>({c10::complex<double>(1.5, 0.15), c10::complex<double>(1.5, 0.15), c10::complex<double>(3.125, 0.3125)}));
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.dtype(), at::kComplexDouble);
  ASSERT_TRUE(almost_equal(tensor[0], c10::complex<double>(1.5, 0.15)));
  ASSERT_TRUE(almost_equal(tensor[1], c10::complex<double>(1.5, 0.15)));
  ASSERT_TRUE(almost_equal(tensor[2], c10::complex<double>(3.125, 0.3125)));

  std::vector<int> v = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
  tensor = at::tensor(v);
  ASSERT_EQ(tensor.numel(), v.size());
  ASSERT_EQ(tensor.dtype(), at::kInt);
  for (size_t i = 0; i < v.size(); ++i) {
    ASSERT_TRUE(exactly_equal(tensor[i], v.at(i)));
  }

  std::vector<double> w = {1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.0};
  tensor = at::tensor(w);
  ASSERT_EQ(tensor.numel(), w.size());
  ASSERT_EQ(tensor.dtype(), at::kDouble);
  for (size_t i = 0; i < w.size(); ++i) {
    ASSERT_TRUE(almost_equal(tensor[i], w.at(i)));
  }

  std::vector<c10::complex<double>> x = {
    {1.1, -1.1}, {2.2, -2.2}, {3.3, -3.3}, {4.4, -4.4}, {5.5, -5.5},
    {6.6, -6.6}, {7.7, -7.7}, {8.8, -8.8}, {9.9, -9.9}, {10.0, -10.0}
  };
  tensor = at::tensor(x);
  ASSERT_EQ(tensor.numel(), x.size());
  ASSERT_EQ(tensor.dtype(), at::kComplexDouble);
  for (size_t i = 0; i < x.size(); ++i) {
    ASSERT_TRUE(almost_equal(tensor[i], x.at(i)));
  }
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, AtTensorCastRealToComplex) {
  auto tensor = at::tensor(std::vector<double>({1.5, 2.5, 3.5}), at::kComplexDouble);
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.dtype(), at::kComplexDouble);
  ASSERT_TRUE(almost_equal(tensor[0], c10::complex<double>(1.5)));
  ASSERT_TRUE(almost_equal(tensor[1], c10::complex<double>(2.5)));
  ASSERT_TRUE(almost_equal(tensor[2], c10::complex<double>(3.5)));

  tensor = at::tensor({1.5, 2.5, 3.5}, at::kComplexDouble);
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.dtype(), at::kComplexDouble);
  ASSERT_TRUE(almost_equal(tensor[0], c10::complex<double>(1.5)));
  ASSERT_TRUE(almost_equal(tensor[1], c10::complex<double>(2.5)));
  ASSERT_TRUE(almost_equal(tensor[2], c10::complex<double>(3.5)));

  tensor = at::tensor(1.5, at::kComplexDouble);
  ASSERT_EQ(tensor.numel(), 1);
  ASSERT_EQ(tensor.dtype(), at::kComplexDouble);
  ASSERT_TRUE(almost_equal(tensor[0], c10::complex<double>(1.5)));
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, AtTensorCastComplexToRealErrorChecks) {
  {
    ASSERT_THROWS_WITH(at::tensor(c10::complex<float>(0.1, 0.2), at::kFloat),
                       "\"tensor_cpu\" not implemented for 'Float'");
  }
  {
    ASSERT_THROWS_WITH(at::tensor({c10::complex<float>(0.1, 0.2)}, at::kFloat),
                       "\"tensor_cpu\" not implemented for 'Float'");
  }
  {
    ASSERT_THROWS_WITH(at::tensor(std::vector<c10::complex<float>>{c10::complex<float>(0.1, 0.2)}, at::kFloat),
                       "\"tensor_cpu\" not implemented for 'Float'");
  }
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
|
TEST(TensorTest, TorchTensorCtorScalarIntegralType) {
|
|
auto tensor = torch::tensor(123);
|
|
ASSERT_EQ(tensor.numel(), 1);
|
|
ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({}));
|
|
ASSERT_EQ(tensor.dtype(), at::kLong);
|
|
ASSERT_EQ(tensor.item<int64_t>(), 123);
|
|
}
|
|
|
|
void test_TorchTensorCtorScalarFloatingType_expected_dtype(c10::ScalarType default_dtype) {
  AutoDefaultDtypeMode dtype_mode(default_dtype);

  auto tensor = torch::tensor(123.456f);
  ASSERT_EQ(tensor.numel(), 1);
  ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({}));
  ASSERT_EQ(tensor.dtype(), default_dtype);
  ASSERT_TRUE(almost_equal(tensor, 123.456f));

  tensor = torch::tensor(123.456);
  ASSERT_EQ(tensor.numel(), 1);
  ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({}));
  ASSERT_EQ(tensor.dtype(), default_dtype);
  ASSERT_TRUE(almost_equal(tensor, 123.456));

  tensor = torch::tensor({123.456});
  ASSERT_EQ(tensor.numel(), 1);
  ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1}));
  ASSERT_EQ(tensor.dtype(), default_dtype);
  ASSERT_TRUE(almost_equal(tensor[0], 123.456));
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, TorchTensorCtorScalarFloatingType) {
  test_TorchTensorCtorScalarFloatingType_expected_dtype(/*default_dtype=*/torch::kFloat);
  test_TorchTensorCtorScalarFloatingType_expected_dtype(/*default_dtype=*/torch::kDouble);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, TorchTensorCtorScalarBoolType) {
  auto tensor = torch::tensor(true);
  ASSERT_EQ(tensor.numel(), 1);
  ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({}));
  ASSERT_EQ(tensor.dtype(), at::kBool);
  ASSERT_TRUE(exactly_equal(tensor, true));

  tensor = torch::tensor({true});
  ASSERT_EQ(tensor.numel(), 1);
  ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1}));
  ASSERT_EQ(tensor.dtype(), at::kBool);
  ASSERT_TRUE(exactly_equal(tensor[0], true));
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, TorchTensorCtorSingleDimIntegralType) {
  auto tensor = torch::tensor({1, 2, 3});
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({3}));
  ASSERT_EQ(tensor.dtype(), at::kLong);
  ASSERT_TRUE(exactly_equal(tensor[0], 1));
  ASSERT_TRUE(exactly_equal(tensor[1], 2));
  ASSERT_TRUE(exactly_equal(tensor[2], 3));

  tensor = torch::tensor(at::ArrayRef<int>({1, 2, 3}));
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({3}));
  ASSERT_EQ(tensor.dtype(), at::kLong);
  ASSERT_TRUE(exactly_equal(tensor[0], 1));
  ASSERT_TRUE(exactly_equal(tensor[1], 2));
  ASSERT_TRUE(exactly_equal(tensor[2], 3));

  tensor = torch::tensor(std::vector<int>({1, 2, 3}));
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({3}));
  ASSERT_EQ(tensor.dtype(), at::kLong);
  ASSERT_TRUE(exactly_equal(tensor[0], 1));
  ASSERT_TRUE(exactly_equal(tensor[1], 2));
  ASSERT_TRUE(exactly_equal(tensor[2], 3));

  tensor = torch::tensor(at::ArrayRef<int64_t>({1, 2, 3}));
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({3}));
  ASSERT_EQ(tensor.dtype(), at::kLong);
  ASSERT_TRUE(exactly_equal(tensor[0], 1));
  ASSERT_TRUE(exactly_equal(tensor[1], 2));
  ASSERT_TRUE(exactly_equal(tensor[2], 3));

  tensor = torch::tensor(std::vector<int64_t>({1, 2, 3}));
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({3}));
  ASSERT_EQ(tensor.dtype(), at::kLong);
  ASSERT_TRUE(exactly_equal(tensor[0], 1));
  ASSERT_TRUE(exactly_equal(tensor[1], 2));
  ASSERT_TRUE(exactly_equal(tensor[2], 3));
}

void test_TorchTensorCtorSingleDimFloatingType_expected_dtype(c10::ScalarType default_dtype) {
  AutoDefaultDtypeMode dtype_mode(default_dtype);

  auto tensor = torch::tensor({1.5, 2.25, 3.125});
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({3}));
  ASSERT_EQ(tensor.dtype(), default_dtype);
  ASSERT_TRUE(almost_equal(tensor[0], 1.5));
  ASSERT_TRUE(almost_equal(tensor[1], 2.25));
  ASSERT_TRUE(almost_equal(tensor[2], 3.125));

  tensor = torch::tensor({1.5f, 2.25f, 3.125f});
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({3}));
  ASSERT_EQ(tensor.dtype(), default_dtype);
  ASSERT_TRUE(almost_equal(tensor[0], 1.5f));
  ASSERT_TRUE(almost_equal(tensor[1], 2.25f));
  ASSERT_TRUE(almost_equal(tensor[2], 3.125f));

  tensor = torch::tensor(at::ArrayRef<float>({1.5f, 2.25f, 3.125f}));
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.dtype(), default_dtype);
  ASSERT_TRUE(almost_equal(tensor[0], 1.5));
  ASSERT_TRUE(almost_equal(tensor[1], 2.25));
  ASSERT_TRUE(almost_equal(tensor[2], 3.125));

  tensor = torch::tensor(std::vector<float>({1.5f, 2.25f, 3.125f}));
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({3}));
  ASSERT_EQ(tensor.dtype(), default_dtype);
  ASSERT_TRUE(almost_equal(tensor[0], 1.5));
  ASSERT_TRUE(almost_equal(tensor[1], 2.25));
  ASSERT_TRUE(almost_equal(tensor[2], 3.125));

  tensor = torch::tensor(at::ArrayRef<double>({1.5, 2.25, 3.125}));
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.dtype(), default_dtype);
  ASSERT_TRUE(almost_equal(tensor[0], 1.5));
  ASSERT_TRUE(almost_equal(tensor[1], 2.25));
  ASSERT_TRUE(almost_equal(tensor[2], 3.125));

  tensor = torch::tensor(std::vector<double>({1.5, 2.25, 3.125}));
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({3}));
  ASSERT_EQ(tensor.dtype(), default_dtype);
  ASSERT_TRUE(almost_equal(tensor[0], 1.5));
  ASSERT_TRUE(almost_equal(tensor[1], 2.25));
  ASSERT_TRUE(almost_equal(tensor[2], 3.125));
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, TorchTensorCtorSingleDimFloatingType) {
  test_TorchTensorCtorSingleDimFloatingType_expected_dtype(/*default_dtype=*/torch::kFloat);
  test_TorchTensorCtorSingleDimFloatingType_expected_dtype(/*default_dtype=*/torch::kDouble);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, TorchTensorCtorSingleDimBoolType) {
  auto tensor = torch::tensor({true, false, true});
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({3}));
  ASSERT_EQ(tensor.dtype(), at::kBool);
  ASSERT_TRUE(exactly_equal(tensor[0], true));
  ASSERT_TRUE(exactly_equal(tensor[1], false));
  ASSERT_TRUE(exactly_equal(tensor[2], true));

  tensor = torch::tensor(at::ArrayRef<bool>({true, false, true}));
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({3}));
  ASSERT_EQ(tensor.dtype(), at::kBool);
  ASSERT_TRUE(exactly_equal(tensor[0], true));
  ASSERT_TRUE(exactly_equal(tensor[1], false));
  ASSERT_TRUE(exactly_equal(tensor[2], true));
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, TorchTensorCtorMultiDimIntegralType) {
  {
    auto tensor = torch::tensor({{1, 2}});
    ASSERT_EQ(tensor.dtype(), torch::kLong);
    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 2}));
    ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 3, torch::kLong).view(tensor.sizes())));
    ASSERT_FALSE(tensor.requires_grad());
  }
  {
    auto tensor = torch::tensor({{1}, {2}});
    ASSERT_EQ(tensor.dtype(), torch::kLong);
    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({2, 1}));
    ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 3, torch::kLong).view(tensor.sizes())));
    ASSERT_FALSE(tensor.requires_grad());
  }
  {
    auto tensor = torch::tensor({{{1, 2}}});
    ASSERT_EQ(tensor.dtype(), torch::kLong);
    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 1, 2}));
    ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 3, torch::kLong).view(tensor.sizes())));
    ASSERT_FALSE(tensor.requires_grad());
  }
  {
    auto tensor = torch::tensor({{{1}, {2}}});
    ASSERT_EQ(tensor.dtype(), torch::kLong);
    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 2, 1}));
    ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 3, torch::kLong).view(tensor.sizes())));
    ASSERT_FALSE(tensor.requires_grad());
  }
  {
    auto tensor = torch::tensor({{1, 2}, {3, 4}});
    ASSERT_EQ(tensor.dtype(), torch::kLong);
    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({2, 2}));
    ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 5, torch::kLong).view(tensor.sizes())));
    ASSERT_FALSE(tensor.requires_grad());
  }
  {
    auto tensor = torch::tensor({{{{{{{{{{1}}}}}}}}}});
    ASSERT_EQ(tensor.dtype(), torch::kLong);
    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 1, 1, 1, 1, 1, 1, 1, 1, 1}));
    ASSERT_TRUE(torch::allclose(tensor, torch::full({1}, 1, torch::kLong).view(tensor.sizes())));
    ASSERT_FALSE(tensor.requires_grad());
  }
  {
    auto tensor = torch::tensor({{{{{{{{{{1, 2}}}}}}}}}});
    ASSERT_EQ(tensor.dtype(), torch::kLong);
    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 1, 1, 1, 1, 1, 1, 1, 1, 2}));
    ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 3, torch::kLong).view(tensor.sizes())));
    ASSERT_FALSE(tensor.requires_grad());
  }
}

void test_TorchTensorCtorMultiDimFloatingType_expected_dtype(c10::ScalarType default_dtype) {
  AutoDefaultDtypeMode dtype_mode(default_dtype);
  {
    auto tensor = torch::tensor({{1.0, 2.0}});
    ASSERT_EQ(tensor.dtype(), default_dtype);
    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 2}));
    ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 3, default_dtype).view(tensor.sizes())));
    ASSERT_FALSE(tensor.requires_grad());
  }
  {
    auto tensor = torch::tensor({{{{{{{{1.0, 2.0, 3.0}}}}}, {{{{{4.0, 5.0, 6.0}}}}}, {{{{{7.0, 8.0, 9.0}}}}}}}});
    ASSERT_EQ(tensor.dtype(), default_dtype);
    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 1, 3, 1, 1, 1, 1, 3}));
    ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 10, default_dtype).view(tensor.sizes())));
    ASSERT_FALSE(tensor.requires_grad());
  }
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, TorchTensorCtorMultiDimFloatingType) {
  test_TorchTensorCtorMultiDimFloatingType_expected_dtype(/*default_dtype=*/torch::kFloat);
  test_TorchTensorCtorMultiDimFloatingType_expected_dtype(/*default_dtype=*/torch::kDouble);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, TorchTensorCtorMultiDimBoolType) {
  {
    auto tensor = torch::tensor({{true, false}});
    ASSERT_EQ(tensor.dtype(), torch::kBool);
    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 2}));
    auto expected = torch::empty(tensor.sizes(), torch::kBool);
    expected[0][0] = true;
    expected[0][1] = false;
    ASSERT_TRUE(torch::equal(tensor, expected));
    ASSERT_FALSE(tensor.requires_grad());
  }
  {
    auto tensor = torch::tensor({{true}, {false}});
    ASSERT_EQ(tensor.dtype(), torch::kBool);
    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({2, 1}));
    auto expected = torch::empty(tensor.sizes(), torch::kBool);
    expected[0][0] = true;
    expected[1][0] = false;
    ASSERT_TRUE(torch::equal(tensor, expected));
    ASSERT_FALSE(tensor.requires_grad());
  }
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, TorchTensorCtorMultiDimWithOptions) {
  {
    auto tensor = torch::tensor({{1, 2}}, torch::dtype(torch::kInt));
    ASSERT_EQ(tensor.dtype(), torch::kInt);
    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 2}));
    ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 3, torch::kInt).view(tensor.sizes())));
    ASSERT_FALSE(tensor.requires_grad());
  }
  {
    auto tensor = torch::tensor({{1, 2}, {3, 4}}, torch::dtype(torch::kFloat).requires_grad(true));
    ASSERT_EQ(tensor.dtype(), torch::kFloat);
    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({2, 2}));
    ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 5, torch::kFloat).view(tensor.sizes())));
    ASSERT_TRUE(tensor.requires_grad());
  }
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, TorchTensorCtorMultiDimErrorChecks) {
  {
    ASSERT_THROWS_WITH(torch::tensor({{{2, 3, 4}, {{5, 6}, {7}}}}),
        "Expected all sub-lists to have sizes: 2 (e.g. {5, 6}), but got sub-list {7} with sizes: 1");
  }
  {
    ASSERT_THROWS_WITH(torch::tensor({{{1, 2.0}, {1, 2.0}}}),
        "Expected all elements of the tensor to have the same scalar type: Int, but got element of scalar type: Double");
  }
  {
    ASSERT_THROWS_WITH(torch::tensor({{{true, 2.0, 3}, {true, 2.0, 3}}}),
        "Expected all elements of the tensor to have the same scalar type: Bool, but got element of scalar type: Double");
  }
  {
    ASSERT_THROWS_WITH(torch::tensor({{{true}, {2}}}),
        "Expected all elements of the tensor to have the same scalar type: Bool, but got element of scalar type: Int");
  }
  {
    ASSERT_THROWS_WITH(torch::tensor({{{true, 2}}}),
        "Expected all elements of the tensor to have the same scalar type: Bool, but got element of scalar type: Int");
  }
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, TorchTensorCastRealToComplex) {
  auto tensor = torch::tensor(std::vector<double>({1.5, 2.5, 3.5}), torch::kComplexDouble);
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.dtype(), torch::kComplexDouble);
  ASSERT_TRUE(almost_equal(tensor[0], c10::complex<double>(1.5)));
  ASSERT_TRUE(almost_equal(tensor[1], c10::complex<double>(2.5)));
  ASSERT_TRUE(almost_equal(tensor[2], c10::complex<double>(3.5)));

  tensor = torch::tensor({1.5, 2.5, 3.5}, torch::kComplexDouble);
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.dtype(), torch::kComplexDouble);
  ASSERT_TRUE(almost_equal(tensor[0], c10::complex<double>(1.5)));
  ASSERT_TRUE(almost_equal(tensor[1], c10::complex<double>(2.5)));
  ASSERT_TRUE(almost_equal(tensor[2], c10::complex<double>(3.5)));

  tensor = torch::tensor(1.5, torch::kComplexDouble);
  ASSERT_EQ(tensor.numel(), 1);
  ASSERT_EQ(tensor.dtype(), torch::kComplexDouble);
  ASSERT_TRUE(almost_equal(tensor, c10::complex<double>(1.5)));
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, TorchTensorCastComplexToRealErrorChecks) {
  {
    ASSERT_THROWS_WITH(torch::tensor(c10::complex<float>(0.1, 0.2), torch::kFloat),
        "value cannot be converted to type float without overflow");
  }
  {
    ASSERT_THROWS_WITH(torch::tensor({c10::complex<float>(0.1, 0.2), c10::complex<float>(0.3, 0.4)}, torch::kFloat),
        "value cannot be converted to type float without overflow");
  }
  {
    ASSERT_THROWS_WITH(torch::tensor(std::vector<c10::complex<float>>{c10::complex<float>(0.1, 0.2), c10::complex<float>(0.3, 0.4)}, torch::kFloat),
        "can not do torch::tensor(complex, dtype=non-complex) because complex can not be casted to real number without loss of information");
  }
}

void test_TorchTensorCtorMultiDim_CUDA_expected_dtype(c10::ScalarType default_dtype) {
  AutoDefaultDtypeMode dtype_mode(default_dtype);

  auto tensor = torch::tensor(
      {{{{{{{{1.0, 2.0, 3.0}}}}}, {{{{{4.0, 5.0, 6.0}}}}}, {{{{{7.0, 8.0, 9.0}}}}}}}},
      torch::dtype(default_dtype).device(torch::kCUDA));
  ASSERT_TRUE(tensor.device().is_cuda());
  ASSERT_EQ(tensor.dtype(), default_dtype);
  ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 1, 3, 1, 1, 1, 1, 3}));
  ASSERT_TRUE(torch::allclose(
      tensor,
      torch::arange(1, 10, default_dtype).view(tensor.sizes()).to(torch::kCUDA)));
  ASSERT_FALSE(tensor.requires_grad());
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, TorchTensorCtorMultiDim_CUDA) {
  test_TorchTensorCtorMultiDim_CUDA_expected_dtype(/*default_dtype=*/torch::kFloat);
  test_TorchTensorCtorMultiDim_CUDA_expected_dtype(/*default_dtype=*/torch::kDouble);
}

void test_TorchTensorCtorZeroSizedDim_expected_dtype(c10::ScalarType default_dtype) {
  AutoDefaultDtypeMode dtype_mode(default_dtype);
  {
    auto tensor = torch::tensor({});
    ASSERT_EQ(tensor.numel(), 0);
    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({0}));
    ASSERT_EQ(tensor.dtype(), default_dtype);
    ASSERT_FALSE(tensor.requires_grad());
  }
  {
    auto tensor = torch::tensor({{}, {}});
    ASSERT_EQ(tensor.numel(), 0);
    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({2, 0}));
    ASSERT_EQ(tensor.dtype(), default_dtype);
    ASSERT_FALSE(tensor.requires_grad());
  }
  {
    auto tensor = torch::tensor({{{}, {}}});
    ASSERT_EQ(tensor.numel(), 0);
    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 2, 0}));
    ASSERT_EQ(tensor.dtype(), default_dtype);
    ASSERT_FALSE(tensor.requires_grad());
  }
  {
    auto tensor = torch::tensor({{{}}});
    ASSERT_EQ(tensor.numel(), 0);
    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 1, 0}));
    ASSERT_EQ(tensor.dtype(), default_dtype);
    ASSERT_FALSE(tensor.requires_grad());
  }
  {
    auto tensor = torch::tensor({{{{{{{{}}}}}}}});
    ASSERT_EQ(tensor.numel(), 0);
    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 1, 1, 1, 1, 1, 1, 0}));
    ASSERT_EQ(tensor.dtype(), default_dtype);
    ASSERT_FALSE(tensor.requires_grad());
  }
  {
    auto tensor = torch::tensor({{{{{{{{}}}}, {{{{}}}}}}}});
    ASSERT_EQ(tensor.numel(), 0);
    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 1, 1, 2, 1, 1, 1, 0}));
    ASSERT_EQ(tensor.dtype(), default_dtype);
    ASSERT_FALSE(tensor.requires_grad());
  }
  {
    auto tensor = torch::tensor({{{{{{{{{{}}}}}}}}}});
    ASSERT_EQ(tensor.numel(), 0);
    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 1, 1, 1, 1, 1, 1, 1, 1, 0}));
    ASSERT_EQ(tensor.dtype(), default_dtype);
    ASSERT_FALSE(tensor.requires_grad());
  }
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, TorchTensorCtorZeroSizedDim) {
  test_TorchTensorCtorZeroSizedDim_expected_dtype(/*default_dtype=*/torch::kFloat);
  test_TorchTensorCtorZeroSizedDim_expected_dtype(/*default_dtype=*/torch::kDouble);
}

void test_TorchTensorCtorWithoutSpecifyingDtype_expected_dtype(c10::ScalarType default_dtype) {
  AutoDefaultDtypeMode dtype_mode(default_dtype);

  ASSERT_EQ(torch::tensor({1., 2., 3.}).dtype(), default_dtype);
  ASSERT_EQ(torch::tensor({{1., 2., 3.}}).dtype(), default_dtype);
  ASSERT_EQ(torch::tensor({1., 2., 3.}, torch::TensorOptions()).dtype(), default_dtype);
  ASSERT_EQ(torch::tensor({{1., 2., 3.}}, torch::TensorOptions()).dtype(), default_dtype);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, TorchTensorCtorWithoutSpecifyingDtype) {
  ASSERT_EQ(torch::tensor({1, 2, 3}).dtype(), torch::kLong);
  ASSERT_EQ(torch::tensor({{1, 2, 3}}).dtype(), torch::kLong);
  ASSERT_EQ(torch::tensor({1, 2, 3}, torch::TensorOptions()).dtype(), torch::kLong);
  ASSERT_EQ(torch::tensor({{1, 2, 3}}, torch::TensorOptions()).dtype(), torch::kLong);

  test_TorchTensorCtorWithoutSpecifyingDtype_expected_dtype(/*default_dtype=*/torch::kFloat);
  test_TorchTensorCtorWithoutSpecifyingDtype_expected_dtype(/*default_dtype=*/torch::kDouble);
}

void test_TorchTensorCtorWithNonDtypeOptions_expected_dtype(c10::ScalarType default_dtype) {
  AutoDefaultDtypeMode dtype_mode(default_dtype);

  ASSERT_EQ(torch::tensor({1, 2, 3}, torch::TensorOptions()).dtype(), torch::kLong);
  ASSERT_EQ(torch::tensor(at::ArrayRef<int>({1, 2, 3}), torch::TensorOptions()).dtype(), torch::kLong);
  ASSERT_EQ(torch::tensor(std::vector<int>({1, 2, 3}), torch::TensorOptions()).dtype(), torch::kLong);

  ASSERT_EQ(torch::tensor({1., 2., 3.}, torch::TensorOptions()).dtype(), default_dtype);
  ASSERT_EQ(torch::tensor(at::ArrayRef<double>({1., 2., 3.}), torch::TensorOptions()).dtype(), default_dtype);
  ASSERT_EQ(torch::tensor(std::vector<double>({1., 2., 3.}), torch::TensorOptions()).dtype(), default_dtype);

  ASSERT_EQ(torch::tensor({1.f, 2.f, 3.f}, torch::TensorOptions()).dtype(), default_dtype);
  ASSERT_EQ(torch::tensor(at::ArrayRef<float>({1.f, 2.f, 3.f}), torch::TensorOptions()).dtype(), default_dtype);
  ASSERT_EQ(torch::tensor(std::vector<float>({1.f, 2.f, 3.f}), torch::TensorOptions()).dtype(), default_dtype);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, TorchTensorCtorWithNonDtypeOptions) {
  test_TorchTensorCtorWithNonDtypeOptions_expected_dtype(/*default_dtype=*/torch::kFloat);
  test_TorchTensorCtorWithNonDtypeOptions_expected_dtype(/*default_dtype=*/torch::kDouble);
}

void test_Arange_expected_dtype(c10::ScalarType default_dtype) {
  AutoDefaultDtypeMode dtype_mode(default_dtype);

  ASSERT_EQ(torch::arange(0., 5).dtype(), default_dtype);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, Arange) {
  {
    auto x = torch::arange(0, 5);
    ASSERT_EQ(x.dtype(), torch::kLong);
  }
  test_Arange_expected_dtype(torch::kFloat);
  test_Arange_expected_dtype(torch::kDouble);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, PrettyPrintTensorDataContainer) {
  {
    ASSERT_EQ(
        c10::str(torch::detail::TensorDataContainer(1.1)),
        "1.1");
  }
  {
    ASSERT_EQ(
        c10::str(torch::detail::TensorDataContainer({1.1, 2.2})),
        "{1.1, 2.2}");
  }
  {
    ASSERT_EQ(
        c10::str(torch::detail::TensorDataContainer({{1, 2}, {3, 4}})),
        "{{1, 2}, {3, 4}}");
  }
  {
    ASSERT_EQ(
        c10::str(torch::detail::TensorDataContainer({{{{{{{{1.1, 2.2, 3.3}}}}}, {{{{{4.4, 5.5, 6.6}}}}}, {{{{{7.7, 8.8, 9.9}}}}}}}})),
        "{{{{{{{{1.1, 2.2, 3.3}}}}}, {{{{{4.4, 5.5, 6.6}}}}}, {{{{{7.7, 8.8, 9.9}}}}}}}}");
  }
  {
    ASSERT_EQ(
        c10::str(torch::detail::TensorDataContainer({{{{{{{{{{1}}}}}}}}}})),
        "{{{{{{{{{{1}}}}}}}}}}");
  }
  {
    ASSERT_EQ(
        c10::str(torch::detail::TensorDataContainer({{{{{{{{{{}}}}}}}}}})),
        "{{{{{{{{{{}}}}}}}}}}");
  }
  {
    ASSERT_EQ(
        c10::str(torch::detail::TensorDataContainer({{{{{{{{{{1, 2}}}}}}}}}})),
        "{{{{{{{{{{1, 2}}}}}}}}}}");
  }
  {
    ASSERT_EQ(
        c10::str(torch::detail::TensorDataContainer(at::ArrayRef<double>({1.1, 2.2}))),
        "{1.1, 2.2}");
  }
  {
    ASSERT_EQ(
        c10::str(torch::detail::TensorDataContainer(std::vector<double>({1.1, 2.2}))),
        "{1.1, 2.2}");
  }
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, TensorDataContainerCallingAccessorOfWrongType) {
  {
    ASSERT_THROWS_WITH(
        torch::detail::TensorDataContainer(1.1).init_list(),
        "Can only call `init_list()` on a TensorDataContainer that has `is_init_list() == true`");
    ASSERT_THROWS_WITH(
        torch::detail::TensorDataContainer(1.1).tensor(),
        "Can only call `tensor()` on a TensorDataContainer that has `is_tensor() == true`");
  }
  {
    ASSERT_THROWS_WITH(
        torch::detail::TensorDataContainer({1.1, 2.2}).scalar(),
        "Can only call `scalar()` on a TensorDataContainer that has `is_scalar() == true`");
    ASSERT_THROWS_WITH(
        torch::detail::TensorDataContainer({1.1, 2.2}).tensor(),
        "Can only call `tensor()` on a TensorDataContainer that has `is_tensor() == true`");
  }
  {
    ASSERT_THROWS_WITH(
        torch::detail::TensorDataContainer(at::ArrayRef<double>({1.1, 2.2})).scalar(),
        "Can only call `scalar()` on a TensorDataContainer that has `is_scalar() == true`");
    ASSERT_THROWS_WITH(
        torch::detail::TensorDataContainer(at::ArrayRef<double>({1.1, 2.2})).init_list(),
        "Can only call `init_list()` on a TensorDataContainer that has `is_init_list() == true`");
  }
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, FromBlob) {
  std::vector<double> v = {1.0, 2.0, 3.0};
  auto tensor = torch::from_blob(
      v.data(), v.size(), torch::dtype(torch::kFloat64).requires_grad(true));
  ASSERT_TRUE(tensor.requires_grad());
  ASSERT_EQ(tensor.dtype(), torch::kFloat64);
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor[0].item<double>(), 1);
  ASSERT_EQ(tensor[1].item<double>(), 2);
  ASSERT_EQ(tensor[2].item<double>(), 3);
  // Above syntax did not copy the data, and has nullptr deleter context.
  ASSERT_EQ(tensor.storage().data_ptr().get_context(), nullptr);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, FromBlobUsesDeleter) {
  bool called = false;
  {
    std::vector<int32_t> v = {1, 2, 3};
    auto tensor = torch::from_blob(
        v.data(),
        v.size(),
        /*deleter=*/[&called](void* data) { called = true; },
        torch::kInt32);
  }
  ASSERT_TRUE(called);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, FromBlobWithStrides) {
  // clang-format off
  std::vector<int32_t> v = {
    1, 2, 3,
    4, 5, 6,
    7, 8, 9
  };
  // clang-format on
  auto tensor = torch::from_blob(
      v.data(),
      /*sizes=*/{3, 3},
      /*strides=*/{1, 3},
      torch::kInt32);
  ASSERT_EQ(tensor.dtype(), torch::kInt32);
  ASSERT_EQ(tensor.numel(), 9);
  const std::vector<int64_t> expected_strides = {1, 3};
  ASSERT_EQ(tensor.strides(), expected_strides);
  for (int64_t i = 0; i < tensor.size(0); ++i) {
    for (int64_t j = 0; j < tensor.size(1); ++j) {
      // NOTE: This is column major because the strides are swapped.
      EXPECT_EQ(tensor[i][j].item<int32_t>(), 1 + (j * tensor.size(1)) + i);
    }
  }
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, Item) {
  {
    torch::Tensor tensor = torch::tensor(3.14);
    torch::Scalar scalar = tensor.item();
    ASSERT_NEAR(scalar.to<float>(), 3.14, 1e-5);
  }
  {
    torch::Tensor tensor = torch::tensor(123);
    torch::Scalar scalar = tensor.item();
    ASSERT_EQ(scalar.to<int>(), 123);
  }
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, Item_CUDA) {
  {
    torch::Tensor tensor = torch::tensor(3.14, torch::kCUDA);
    torch::Scalar scalar = tensor.item();
    ASSERT_NEAR(scalar.to<float>(), 3.14, 1e-5);
  }
  {
    torch::Tensor tensor = torch::tensor(123, torch::kCUDA);
    torch::Scalar scalar = tensor.item();
    ASSERT_EQ(scalar.to<int>(), 123);
  }
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, DataPtr) {
  auto tensor = at::empty({3, 4}, at::kFloat);
  auto tensor_not_copy = tensor.to(tensor.options());
  ASSERT_EQ(tensor_not_copy.data_ptr<float>(), tensor.data_ptr<float>());
  ASSERT_EQ(tensor_not_copy.data_ptr(), tensor.data_ptr());
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, Data) {
  const auto tensor = torch::rand({3, 3});
  ASSERT_TRUE(torch::equal(tensor, tensor.data()));
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, BackwardAndGrad) {
  auto x = torch::tensor({5}, torch::dtype(torch::kFloat).requires_grad(true));
  auto y = x * x;
  y.backward();
  ASSERT_EQ(x.grad().item<float>(), 10.0);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, BackwardCreatesOnesGrad) {
  const auto x = torch::tensor({5}, torch::dtype(torch::kFloat).requires_grad(true));
  x.backward();
  ASSERT_TRUE(torch::equal(x.grad(),
      torch::ones_like(x)));
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, BackwardNonScalarOutputs) {
  auto x = torch::randn({5, 5}, torch::requires_grad());
  auto y = x * x;
  ASSERT_THROWS_WITH(y.backward(),
      "grad can be implicitly created only for scalar outputs");
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, IsLeaf) {
  auto x = torch::tensor({5}, torch::dtype(torch::kFloat).requires_grad(true));
  auto y = x * x;
  ASSERT_TRUE(x.is_leaf());
  ASSERT_FALSE(y.is_leaf());
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, OutputNr) {
  auto x = torch::tensor({5}, torch::dtype(torch::kFloat).requires_grad(true));
  auto y = x * x;
  ASSERT_EQ(x.output_nr(), 0);
  ASSERT_EQ(y.output_nr(), 0);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, Version) {
  auto x = torch::ones(3);
  ASSERT_EQ(x._version(), 0);
  x.mul_(2);
  ASSERT_EQ(x._version(), 1);
  x.add_(1);
  ASSERT_EQ(x._version(), 2);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, Detach) {
  auto x = torch::tensor({5}, torch::dtype(torch::kFloat).requires_grad(true));
  auto y = x * x;
  const auto y_detached = y.detach();
  ASSERT_FALSE(y.is_leaf());
  ASSERT_TRUE(y_detached.is_leaf());
  ASSERT_FALSE(y_detached.requires_grad());
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, DetachInplace) {
  auto x = torch::tensor({5}, torch::dtype(torch::kFloat).requires_grad(true));
  auto y = x * x;
  auto y_detached = y.detach_();
  ASSERT_TRUE(y.is_leaf());
  ASSERT_FALSE(y.requires_grad());
  ASSERT_TRUE(y_detached.is_leaf());
  ASSERT_FALSE(y_detached.requires_grad());
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, SetData) {
  auto x = torch::randn({5});
  auto y = torch::randn({5});
  ASSERT_FALSE(torch::equal(x, y));
  ASSERT_NE(x.data_ptr<float>(), y.data_ptr<float>());

  x.set_data(y);
  ASSERT_TRUE(torch::equal(x, y));
  ASSERT_EQ(x.data_ptr<float>(), y.data_ptr<float>());
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, RequiresGradInplace) {
  auto x = torch::tensor({5.0});
  x.requires_grad_(true);
  ASSERT_TRUE(x.requires_grad());

  auto y = x * x;
  ASSERT_THROWS_WITH(y.requires_grad_(false),
      "you can only change requires_grad flags of leaf variables.");

  x.requires_grad_(false);
  ASSERT_FALSE(x.requires_grad());

  const auto int_tensor = torch::tensor({5}, at::TensorOptions().dtype(torch::kInt));
  ASSERT_THROWS_WITH(int_tensor.requires_grad_(true),
      "Only Tensors of floating point and complex dtype can require gradients");
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, StdDimension) {
  // Test that std(0) doesn't select the std(unbiased=False) overload (gh-40287)
  auto x = torch::randn({4, 3});
  auto std = x.std(0);

  ASSERT_EQ(x.var(0).numel(), 3);
  ASSERT_EQ(x.std(0).numel(), 3);

  ASSERT_EQ(x.var(0, /*unbiased=*/true).numel(), 3);
  ASSERT_EQ(x.std(0, /*unbiased=*/true).numel(), 3);

  ASSERT_EQ(torch::var(x, 0).numel(), 3);
  ASSERT_EQ(std::get<0>(torch::var_mean(x, 0)).numel(), 3);
  ASSERT_EQ(torch::std(x, 0).numel(), 3);
  ASSERT_EQ(std::get<0>(torch::std_mean(x, 0)).numel(), 3);

  ASSERT_EQ(torch::var(x, 0, /*unbiased=*/true).numel(), 3);
  ASSERT_EQ(std::get<0>(torch::var_mean(x, 0, /*unbiased=*/true)).numel(), 3);
  ASSERT_EQ(torch::std(x, 0, /*unbiased=*/true).numel(), 3);
  ASSERT_EQ(std::get<0>(torch::std_mean(x, 0, /*unbiased=*/true)).numel(), 3);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, ReshapeAlias) {
  // Tests the behavior of the _reshape_alias private operator so
  // that it matches the behavior of as_strided and view.
  auto x = torch::randn({3, 3});
  ASSERT_TRUE(torch::equal(
    torch::_reshape_alias(x, {2, 2}, {1, 2}),
    torch::as_strided(x, {2, 2}, {1, 2})
  ));
  ASSERT_TRUE(torch::equal(
    torch::_reshape_alias(x, {9}, {1}),
    x.view({-1})
  ));

  // Test that the backward works fine.
  auto y = torch::randn({3, 3}, torch::requires_grad(true));
  auto z = torch::clone(y).detach().requires_grad_(true);
  (y * y).view({-1}).mean().backward();
  torch::_reshape_alias((z * z), {9}, {1}).mean().backward();
  ASSERT_TRUE(torch::equal(
    y.grad(),
    z.grad()
  ));
}