mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: Hi yf225, here is the C++ frontend API MultiMarginLoss implementation and tests for https://github.com/pytorch/pytorch/issues/27198. Could you review it and tell me if it is okay? I am not entirely sure I used `c10::optional` correctly: `options.weight()` resulted in a compilation error, so I went with `options.weight().value()` instead of `value_or()` to follow the logic in `torch.nn._WeightedLoss.register_buffer` (where one can pass a `None` value). Oh, and are the tests supposed to be skipped, or did I do something wrong? I ran `pytest test/test_cpp_api_parity.py -k Loss -v`, and the `L1Loss` test passed but the others were skipped... Thank you for the review in any case! Pull Request resolved: https://github.com/pytorch/pytorch/pull/27424 Differential Revision: D17839963 Pulled By: yf225 fbshipit-source-id: f4b6012590cf22d56d42751c214df80cce717cb8
422 lines
14 KiB
C++
422 lines
14 KiB
C++
#include <gtest/gtest.h>
|
|
|
|
#include <torch/torch.h>
|
|
|
|
#include <test/cpp/api/support.h>
|
|
|
|
namespace F = torch::nn::functional;
|
|
|
|
using namespace torch::nn;
|
|
|
|
/// Shared fixture for the torch::nn::functional tests below. All behavior
/// comes from torch::test::SeedingFixture (presumably it seeds the RNG before
/// each test; see test/cpp/api/support.h to confirm).
class FunctionalTest : public torch::test::SeedingFixture {};
|
|
|
|
TEST_F(FunctionalTest, MaxPool1d) {
  // Window 3, stride 2 over length 5 -> floor((5 - 3) / 2) + 1 = 2 outputs.
  const std::vector<int64_t> expected_shape{1, 1, 2};
  const auto input = torch::ones({1, 1, 5});
  const auto pooled = F::max_pool1d(input, MaxPool1dOptions(3).stride(2));

  // The max over windows of ones is one, so the result is all ones.
  ASSERT_EQ(pooled.dim(), 3);
  ASSERT_TRUE(torch::allclose(pooled, torch::ones(expected_shape)));
  ASSERT_EQ(pooled.sizes(), torch::IntArrayRef(expected_shape));
}
|
|
|
|
TEST_F(FunctionalTest, MaxPool2d) {
  // Window 3, stride 2 over a 5x5 plane -> 2x2 per leading slice.
  const std::vector<int64_t> expected_shape{2, 2, 2};
  const auto input = torch::ones({2, 5, 5});
  const auto pooled = F::max_pool2d(input, MaxPool2dOptions(3).stride(2));

  // The max over windows of ones is one, so the result is all ones.
  ASSERT_EQ(pooled.dim(), 3);
  ASSERT_TRUE(torch::allclose(pooled, torch::ones(expected_shape)));
  ASSERT_EQ(pooled.sizes(), torch::IntArrayRef(expected_shape));
}
|
|
|
|
TEST_F(FunctionalTest, MaxPool3d) {
  // Window 3, stride 2 over a 5x5x5 volume -> 2x2x2 per leading slice.
  const std::vector<int64_t> expected_shape{2, 2, 2, 2};
  const auto input = torch::ones({2, 5, 5, 5});
  const auto pooled = F::max_pool3d(input, MaxPool3dOptions(3).stride(2));

  // The max over windows of ones is one, so the result is all ones.
  ASSERT_EQ(pooled.dim(), 4);
  ASSERT_TRUE(torch::allclose(pooled, torch::ones(expected_shape)));
  ASSERT_EQ(pooled.sizes(), torch::IntArrayRef(expected_shape));
}
|
|
|
|
TEST_F(FunctionalTest, AvgPool1d) {
  // Window 3, stride 2 over length 5 -> 2 outputs.
  const std::vector<int64_t> expected_shape{1, 1, 2};
  const auto input = torch::ones({1, 1, 5});
  const auto pooled = F::avg_pool1d(input, AvgPool1dOptions(3).stride(2));

  // Averaging windows of ones yields ones.
  ASSERT_EQ(pooled.dim(), 3);
  ASSERT_TRUE(torch::allclose(pooled, torch::ones(expected_shape)));
  ASSERT_EQ(pooled.sizes(), torch::IntArrayRef(expected_shape));
}
|
|
|
|
TEST_F(FunctionalTest, AvgPool2d) {
  // Window 3, stride 2 over a 5x5 plane -> 2x2 per leading slice.
  const std::vector<int64_t> expected_shape{2, 2, 2};
  const auto input = torch::ones({2, 5, 5});
  const auto pooled = F::avg_pool2d(input, AvgPool2dOptions(3).stride(2));

  // Averaging windows of ones yields ones.
  ASSERT_EQ(pooled.dim(), 3);
  ASSERT_TRUE(torch::allclose(pooled, torch::ones(expected_shape)));
  ASSERT_EQ(pooled.sizes(), torch::IntArrayRef(expected_shape));
}
|
|
|
|
TEST_F(FunctionalTest, AvgPool3d) {
  // Window 3, stride 2 over a 5x5x5 volume -> 2x2x2 per leading slice.
  const std::vector<int64_t> expected_shape{2, 2, 2, 2};
  const auto input = torch::ones({2, 5, 5, 5});
  const auto pooled = F::avg_pool3d(input, AvgPool3dOptions(3).stride(2));

  // Averaging windows of ones yields ones.
  ASSERT_EQ(pooled.dim(), 4);
  ASSERT_TRUE(torch::allclose(pooled, torch::ones(expected_shape)));
  ASSERT_EQ(pooled.sizes(), torch::IntArrayRef(expected_shape));
}
|
|
|
|
TEST_F(FunctionalTest, CosineSimilarity) {
  // Row-wise (dim 1) cosine similarity of two 2x3 float matrices.
  // Row 0: 26 / sqrt(14 * 74) ~= 0.8078; row 1 ~= 0.8721.
  const auto lhs = torch::tensor({{1, 2, 3}, {4, 5, 6}}, torch::kFloat);
  const auto rhs = torch::tensor({{1, 8, 3}, {2, 1, 6}}, torch::kFloat);
  const auto similarity =
      F::cosine_similarity(lhs, rhs, CosineSimilarityOptions().dim(1));
  const auto reference = torch::tensor({0.8078, 0.8721}, torch::kFloat);
  ASSERT_TRUE(similarity.allclose(reference, 1e-04));
}
|
|
|
|
TEST_F(FunctionalTest, PairwiseDistance) {
  // Row differences are (0, -6, 0) and (2, 4, 0); with the norm order set
  // to 1, both rows are (approximately) at distance 6.
  const auto lhs = torch::tensor({{1, 2, 3}, {4, 5, 6}}, torch::kFloat);
  const auto rhs = torch::tensor({{1, 8, 3}, {2, 1, 6}}, torch::kFloat);
  const auto distance =
      F::pairwise_distance(lhs, rhs, PairwiseDistanceOptions(1));
  const auto reference = torch::tensor({6, 6}, torch::kFloat);
  ASSERT_TRUE(distance.allclose(reference));
}
|
|
|
|
TEST_F(FunctionalTest, PDist) {
  {
    // Default norm: two rows differing by (3, 9, 7) -> sqrt(139) ~= 11.7898.
    const auto points = torch::tensor({{-1.0, -5.0, -1.0}, {2.0, 4.0, 6.0}});
    const auto distances = F::pdist(points);
    const auto reference = torch::tensor({11.7898});
    ASSERT_TRUE(distances.allclose(reference));
  }
  {
    // p = 1.5 over three points: one entry per unordered pair (0-1, 0-2, 1-2).
    const auto points = torch::tensor({{1.0, -1.0}, {1.0, 3.0}, {3.0, 3.0}});
    const auto distances = F::pdist(points, 1.5);
    const auto reference = torch::tensor({4.0, 4.8945, 2.0});
    ASSERT_TRUE(distances.allclose(reference));
  }
}
|
|
|
|
TEST_F(FunctionalTest, AdaptiveMaxPool1d) {
  // Adaptive pooling to a fixed output length of 3, regardless of input length.
  const auto input = torch::ones({1, 1, 5});
  const auto pooled =
      F::adaptive_max_pool1d(input, AdaptiveMaxPool1dOptions(3));

  ASSERT_EQ(pooled.dim(), 3);
  ASSERT_TRUE(torch::allclose(pooled, torch::ones({1, 1, 3})));
  ASSERT_EQ(pooled.sizes(), torch::IntArrayRef({1, 1, 3}));
}
|
|
|
|
TEST_F(FunctionalTest, AdaptiveMaxPool2d) {
  // Adaptive pooling of each 5x5 plane down to 3x3.
  const auto input = torch::ones({2, 5, 5});
  const auto pooled =
      F::adaptive_max_pool2d(input, AdaptiveMaxPool2dOptions(3));

  ASSERT_EQ(pooled.dim(), 3);
  ASSERT_TRUE(torch::allclose(pooled, torch::ones({2, 3, 3})));
  ASSERT_EQ(pooled.sizes(), torch::IntArrayRef({2, 3, 3}));
}
|
|
|
|
TEST_F(FunctionalTest, AdaptiveMaxPool3d) {
  // Adaptive pooling of each 5x5x5 volume down to 3x3x3.
  const auto input = torch::ones({2, 5, 5, 5});
  const auto pooled =
      F::adaptive_max_pool3d(input, AdaptiveMaxPool3dOptions(3));

  ASSERT_EQ(pooled.dim(), 4);
  ASSERT_TRUE(torch::allclose(pooled, torch::ones({2, 3, 3, 3})));
  ASSERT_EQ(pooled.sizes(), torch::IntArrayRef({2, 3, 3, 3}));
}
|
|
|
|
TEST_F(FunctionalTest, AdaptiveAvgPool1d) {
  // Adaptive average pooling to a fixed output length of 3.
  const auto input = torch::ones({1, 1, 5});
  const auto pooled =
      F::adaptive_avg_pool1d(input, AdaptiveAvgPool1dOptions(3));

  ASSERT_EQ(pooled.dim(), 3);
  ASSERT_TRUE(torch::allclose(pooled, torch::ones({1, 1, 3})));
  ASSERT_EQ(pooled.sizes(), torch::IntArrayRef({1, 1, 3}));
}
|
|
|
|
TEST_F(FunctionalTest, AdaptiveAvgPool2d) {
  // Adaptive average pooling of each 5x5 plane down to 3x3.
  const auto input = torch::ones({2, 5, 5});
  const auto pooled =
      F::adaptive_avg_pool2d(input, AdaptiveAvgPool2dOptions(3));

  ASSERT_EQ(pooled.dim(), 3);
  ASSERT_TRUE(torch::allclose(pooled, torch::ones({2, 3, 3})));
  ASSERT_EQ(pooled.sizes(), torch::IntArrayRef({2, 3, 3}));
}
|
|
|
|
TEST_F(FunctionalTest, AdaptiveAvgPool3d) {
  // Adaptive average pooling of each 5x5x5 volume down to 3x3x3.
  const auto input = torch::ones({2, 5, 5, 5});
  const auto pooled =
      F::adaptive_avg_pool3d(input, AdaptiveAvgPool3dOptions(3));

  ASSERT_EQ(pooled.dim(), 4);
  ASSERT_TRUE(torch::allclose(pooled, torch::ones({2, 3, 3, 3})));
  ASSERT_EQ(pooled.sizes(), torch::IntArrayRef({2, 3, 3, 3}));
}
|
|
|
|
TEST_F(FunctionalTest, HingeEmbeddingLoss) {
  // Hinge embedding loss with margin 2 over a 2x3 float input/target pair.
  const auto scores = torch::tensor({{2, 22, 4}, {20, 10, 0}}, torch::kFloat);
  const auto labels = torch::tensor({{2, 6, 4}, {1, 10, 0}}, torch::kFloat);
  const auto loss = F::hinge_embedding_loss(
      scores, labels, HingeEmbeddingLossOptions().margin(2));
  const auto reference = torch::tensor({10}, torch::kFloat);

  ASSERT_TRUE(loss.allclose(reference));
}
|
|
|
|
TEST_F(FunctionalTest, MultiMarginLoss) {
  // Multi-class margin loss with margin 2 and class weights passed through
  // the `weight` option; 3 samples, 3 classes, one target class per sample.
  const auto class_weight = torch::tensor({0.3, 0.3, 0.4}, torch::kFloat);
  const auto scores = torch::tensor(
      {{0.2, 0.2, 0.6}, {0.1, 0.8, 0.1}, {0.9, 0.09, 0.01}},
      torch::requires_grad());
  const auto labels = torch::tensor({2, 1, 0}, torch::kLong);
  const auto loss = F::multi_margin_loss(
      scores, labels, MultiMarginLossOptions().margin(2).weight(class_weight));
  const auto reference = torch::tensor({0.305556}, torch::kFloat);

  ASSERT_TRUE(loss.allclose(reference, 1e-04));
}
|
|
|
|
TEST_F(FunctionalTest, CosineEmbeddingLoss) {
  // Give the inputs an explicit floating-point dtype, like the other tests in
  // this file do. Without a dtype, integer literals may be inferred as an
  // integral tensor, and cosine similarity is only meaningful in floating point.
  auto input1 = torch::tensor({{2, 3, 4}, {6, 2, 4}}, torch::kFloat);
  auto input2 = torch::tensor({{2, 3, 5}, {9, 12, 0}}, torch::kFloat);
  auto target = torch::tensor({1, -1});
  auto output = F::cosine_embedding_loss(
      input1, input2, target, CosineEmbeddingLossOptions().margin(0.5));
  // Row 0 (target 1):  1 - cos = 1 - 33 / sqrt(29 * 38)        ~= 0.0059
  // Row 1 (target -1): max(0, cos - 0.5) = 78 / sqrt(56 * 225) - 0.5 ~= 0.1949
  // Averaged over the two rows                                  ~= 0.1004
  auto expected = torch::tensor({0.1004}, torch::kFloat);

  ASSERT_TRUE(output.allclose(expected, 1e-4));
}
|
|
|
|
TEST_F(FunctionalTest, MaxUnpool1d) {
  {
    // Kernel 3, stride unset: values are scattered to positions 1, 3 and 4
    // of a length-9 output (length 9 implies an effective stride of 3 here).
    auto values = torch::tensor({{{2, 4, 5}}}, torch::requires_grad());
    auto positions = torch::tensor({{{1, 3, 4}}}, torch::kLong);
    auto unpooled = F::max_unpool1d(values, positions, MaxUnpool1dOptions(3));

    ASSERT_EQ(unpooled.ndimension(), 3);
    ASSERT_TRUE(torch::allclose(
        unpooled,
        torch::tensor({{{0, 2, 0, 4, 5, 0, 0, 0, 0}}}, torch::kFloat)));
    ASSERT_EQ(unpooled.sizes(), torch::IntArrayRef({1, 1, 9}));
  }
  {
    // Same setup, but with the output size passed explicitly.
    auto values = torch::tensor({{{2, 4, 5}}}, torch::requires_grad());
    auto positions = torch::tensor({{{1, 3, 4}}}, torch::kLong);
    auto unpooled = F::max_unpool1d(
        values, positions, MaxUnpool1dOptions(3), c10::IntArrayRef({1, 1, 9}));

    ASSERT_EQ(unpooled.ndimension(), 3);
    ASSERT_TRUE(torch::allclose(
        unpooled,
        torch::tensor({{{0, 2, 0, 4, 5, 0, 0, 0, 0}}}, torch::kFloat)));
    ASSERT_EQ(unpooled.sizes(), torch::IntArrayRef({1, 1, 9}));
  }
  {
    // Stride 2 and padding 1: (3 - 1) * 2 - 2 * 1 + 3 = 5 output positions.
    auto values = torch::tensor({{{2, 4, 5}}}, torch::requires_grad());
    auto positions = torch::tensor({{{1, 3, 4}}}, torch::kLong);
    auto unpooled = F::max_unpool1d(
        values, positions, MaxUnpool1dOptions(3).stride(2).padding(1));

    ASSERT_EQ(unpooled.ndimension(), 3);
    ASSERT_TRUE(torch::allclose(
        unpooled, torch::tensor({{{0, 2, 0, 4, 5}}}, torch::kFloat)));
    ASSERT_EQ(unpooled.sizes(), torch::IntArrayRef({1, 1, 5}));
  }
}
|
|
|
|
TEST_F(FunctionalTest, MaxUnpool2d) {
  // Flat positions within each 5x5 output plane; identical for both batches.
  const auto positions = torch::tensor({
    {{{ 6,  8,  9},
      {16, 18, 19},
      {21, 23, 24}}},
    {{{ 6,  8,  9},
      {16, 18, 19},
      {21, 23, 24}}}}, torch::kLong);
  const auto values = torch::tensor({
    {{{ 6,  8,  9},
      {16, 18, 19},
      {21, 23, 24}}},
    {{{31, 33, 34},
      {41, 43, 44},
      {46, 48, 49}}}}, torch::requires_grad());

  // Kernel 3, stride 2, padding 1: (3 - 1) * 2 - 2 * 1 + 3 = 5 per side.
  const auto unpooled = F::max_unpool2d(
      values, positions, MaxUnpool2dOptions(3).stride(2).padding(1));

  const auto reference = torch::tensor(
    {{{{ 0,  0,  0,  0,  0},
       { 0,  6,  0,  8,  9},
       { 0,  0,  0,  0,  0},
       { 0, 16,  0, 18, 19},
       { 0, 21,  0, 23, 24}}},
     {{{ 0,  0,  0,  0,  0},
       { 0, 31,  0, 33, 34},
       { 0,  0,  0,  0,  0},
       { 0, 41,  0, 43, 44},
       { 0, 46,  0, 48, 49}}}}, torch::kFloat);

  ASSERT_EQ(unpooled.dim(), 4);
  ASSERT_TRUE(torch::allclose(unpooled, reference));
  ASSERT_EQ(unpooled.sizes(), torch::IntArrayRef({2, 1, 5, 5}));
}
|
|
|
|
TEST_F(FunctionalTest, ELU) {
  const auto size = 3;
  // Sweep the out-of-place and in-place paths across several alpha values.
  for (const bool inplace : {false, true}) {
    for (const double alpha : {0.0, 0.42, 1.0, 4.2, 42.42}) {
      auto input = torch::linspace(-10.0, 10.0, size * size * size);
      input.resize_({size, size, size});

      // Reference ELU(x) = max(0, x) + min(0, alpha * (exp(x) - 1)). It is
      // computed before the call so an in-place ELU cannot disturb it.
      const auto zeros = torch::zeros_like(input);
      const auto expected = torch::max(zeros, input) +
          torch::min(zeros, alpha * (torch::exp(input) - 1.0));
      const auto output =
          F::elu(input, ELUOptions().alpha(alpha).inplace(inplace));

      ASSERT_EQ(output.ndimension(), 3);
      ASSERT_EQ(output.sizes(), torch::IntArrayRef({size, size, size}));
      ASSERT_TRUE(torch::allclose(output, expected));
      if (inplace) {
        // The in-place variant must also have overwritten its input.
        ASSERT_TRUE(torch::allclose(input, expected));
      }
    }
  }
}
|
|
|
|
TEST_F(FunctionalTest, SELU) {
  {
    // Scale and alpha constants used by the reference SELU formula.
    const double scale = 1.0507009873554804934193349852946;
    const double alpha = 1.6732632423543772848170429916717;
    for (const bool inplace : {false, true}) {
      auto input = torch::randn({5, 5});
      // Reference: scale * (max(0, x) + min(0, alpha * (exp(x) - 1))),
      // computed before the call so an in-place SELU cannot disturb it.
      const auto zeros = torch::zeros_like(input);
      const auto expected = scale *
          (torch::max(zeros, input) +
           torch::min(zeros, alpha * (torch::exp(input) - 1)));
      const auto output = F::selu(input, inplace);

      ASSERT_TRUE(output.allclose(expected));
      if (inplace) {
        ASSERT_TRUE(input.allclose(expected));
      }
    }
  }
  {
    // The overload without `inplace` must agree with an explicit `false`.
    const auto input = torch::arange(0, 9, torch::kDouble).view({3, 3});
    const auto implicit_default = F::selu(input);
    const auto explicit_false = F::selu(input, false);
    ASSERT_TRUE(implicit_default.allclose(explicit_false));
  }
}
|
|
|
|
TEST_F(FunctionalTest, Hardshrink) {
  const auto size = 3;
  for (const double lambda : {-4.2, -1.0, -0.42, 0.0, 0.42, 1.0, 4.2, 42.42}) {
    auto input = torch::linspace(-10.0, 10.0, size * size * size);
    input.resize_({size, size, size}).set_requires_grad(true);
    const auto output =
        F::hardshrink(input, HardshrinkOptions().lambda(lambda));

    // Reduce to a scalar and backpropagate to exercise the backward pass.
    torch::Tensor total = output.sum();
    total.backward();
    ASSERT_EQ(total.ndimension(), 0);

    ASSERT_EQ(output.ndimension(), 3);
    ASSERT_EQ(output.sizes(), torch::IntArrayRef({size, size, size}));
    // Reference: keep x where |x| > lambda, zero elsewhere.
    const auto expected = (input.abs() > lambda) * input;
    ASSERT_TRUE(torch::allclose(output, expected));
  }
}
|
|
|
|
TEST_F(FunctionalTest, OneHot) {
  { // Test #1: class count left unspecified; labels 0..2 give 3 columns.
    const auto labels = torch::arange(0, 5, torch::kLong) % 3;
    const auto encoded = F::one_hot(labels);
    const auto reference = torch::tensor(
        {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}, {1, 0, 0}, {0, 1, 0}}, torch::kLong);

    ASSERT_EQ(encoded.ndimension(), 2);
    ASSERT_TRUE(torch::allclose(encoded, reference));
    ASSERT_EQ(encoded.sizes(), torch::IntArrayRef({5, 3}));
  }

  { // Test #2: explicit class count of 5 pads the unused trailing columns.
    const auto labels = torch::arange(0, 5, torch::kLong) % 3;
    const auto encoded = F::one_hot(labels, 5);
    const auto reference = torch::tensor(
        {{1, 0, 0, 0, 0},
         {0, 1, 0, 0, 0},
         {0, 0, 1, 0, 0},
         {1, 0, 0, 0, 0},
         {0, 1, 0, 0, 0}},
        torch::kLong);

    ASSERT_EQ(encoded.ndimension(), 2);
    ASSERT_TRUE(torch::allclose(encoded, reference));
    ASSERT_EQ(encoded.sizes(), torch::IntArrayRef({5, 5}));
  }

  { // Test #3: a 3x2 label grid gains a trailing class dimension of size 3.
    const auto flat = torch::arange(0, 6, torch::kLong);
    const auto encoded = F::one_hot(flat.view(torch::IntArrayRef({3, 2})) % 3);
    const auto reference = torch::tensor(
        {{{1, 0, 0}, {0, 1, 0}},
         {{0, 0, 1}, {1, 0, 0}},
         {{0, 1, 0}, {0, 0, 1}}},
        torch::kLong);

    ASSERT_EQ(encoded.ndimension(), 3);
    ASSERT_TRUE(torch::allclose(encoded, reference));
    ASSERT_EQ(encoded.sizes(), torch::IntArrayRef({3, 2, 3}));
  }
}
|
|
|
|
TEST_F(FunctionalTest, Hardtanh) {
  const auto size = 3;
  for (const double min_val : {-4.2, -1.0, -0.42, 0.0}) {
    for (const double max_val : {0.0, 0.42, 1.0, 4.2}) {
      for (const bool inplace : {false, true}) {
        auto input = torch::linspace(-10.0, 10.0, size * size * size);
        input.resize_({size, size, size});

        // Reference clamp assembled from masks: min_val below the range,
        // x inside [min_val, max_val], max_val above it.
        const auto below = (input < min_val) * min_val;
        const auto within = ((input >= min_val) * (input <= max_val)) * input;
        const auto above = (input > max_val) * max_val;
        const auto expected = below + within + above;

        const auto output = F::hardtanh(
            input,
            HardtanhOptions().min_val(min_val).max_val(max_val).inplace(inplace));

        ASSERT_EQ(output.ndimension(), 3);
        ASSERT_EQ(output.sizes(), torch::IntArrayRef({size, size, size}));
        ASSERT_TRUE(torch::allclose(output, expected));
        if (inplace) {
          ASSERT_TRUE(torch::allclose(input, expected));
        }
      }
    }
  }
}
|
|
|
|
TEST_F(FunctionalTest, LeakyReLU) {
  const auto size = 3;
  for (const double negative_slope : {0.0, 0.42, 1.0}) {
    for (const bool inplace : {false, true}) {
      auto input = torch::linspace(-10.0, 10.0, size * size * size);
      input.resize_({size, size, size});

      // Reference: negatives are scaled by the slope, non-negatives pass through.
      const auto expected =
          (input < 0) * input * negative_slope + (input >= 0) * input;
      const auto output = F::leaky_relu(
          input,
          LeakyReLUOptions().negative_slope(negative_slope).inplace(inplace));

      ASSERT_EQ(output.ndimension(), 3);
      ASSERT_EQ(output.sizes(), torch::IntArrayRef({size, size, size}));
      ASSERT_TRUE(torch::allclose(output, expected));
      if (inplace) {
        ASSERT_TRUE(torch::allclose(input, expected));
      }
    }
  }
}
|
|
|
|
TEST_F(FunctionalTest, LogSigmoid) {
  // Checks F::logsigmoid against log(1 / (1 + exp(-x))) on a 3x3x3 grid.
  // (The unused `LogSigmoid model;` local, a leftover from the module test,
  // has been removed: this test exercises the functional API only.)
  const auto size = 3;
  auto x = torch::linspace(-10.0, 10.0, size * size * size);
  x.resize_({size, size, size});
  auto y = F::logsigmoid(x);

  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_EQ(y.sizes(), torch::IntArrayRef({size, size, size}));
  auto y_exp = torch::log(
      torch::ones_like(x) / (torch::ones_like(x) + torch::exp(torch::neg(x))));
  ASSERT_TRUE(torch::allclose(y, y_exp, 1e-4, 1e-7));
}
|
|
|
|
TEST_F(FunctionalTest, Softmax) {
  const auto input = torch::arange(10, torch::kFloat).reshape({2, 5});
  const auto output = F::softmax(input, /*dim=*/1);
  // Per-row normaliser: sum of exp along the softmax dimension.
  const auto row_sums = torch::sum(torch::exp(input), 1);

  // Each row must equal exp(row) divided by that row's normaliser.
  for (int64_t row = 0; row < input.size(0); ++row) {
    const auto expected = torch::exp(input[row]) / row_sums[row];
    ASSERT_TRUE(torch::allclose(output[row], expected));
  }
}
|