pytorch/test/cpp/api/misc.cpp
Will Feng 085bd15880 Add TORCH_WARN_ONCE, and use it in Tensor.data<T>() (#25207)

#include <gtest/gtest.h>

#include <torch/nn/init.h>
#include <torch/nn/modules/linear.h>
#include <torch/types.h>
#include <torch/utils.h>

#include <test/cpp/api/support.h>

#include <functional>
#include <sstream>

using namespace torch::test;
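
// Helpers for the WarnOnce test below: the *_once variants use
// TORCH_WARN_ONCE, which emits its message only the first time a given call
// site runs, while torch_warn() uses TORCH_WARN and emits on every call.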
void torch_warn_once_A() {
  TORCH_WARN_ONCE("warn once");
}

void torch_warn_once_B() {
  TORCH_WARN_ONCE("warn something else once");
}

void torch_warn() {
  TORCH_WARN("warn multiple times");
}
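
// Redirects std::cerr with the CerrRedirect test helper and counts message
// occurrences: each TORCH_WARN_ONCE call site should log exactly once, while
// TORCH_WARN logs on every call.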
TEST(UtilsTest, WarnOnce) {
  {
    std::stringstream buffer;
    CerrRedirect cerr_redirect(buffer.rdbuf());

    torch_warn_once_A();
    torch_warn_once_A();
    torch_warn_once_B();
    torch_warn_once_B();

    ASSERT_EQ(count_substr_occurrences(buffer.str(), "warn once"), 1);
    ASSERT_EQ(
        count_substr_occurrences(buffer.str(), "warn something else once"), 1);
  }
  {
    std::stringstream buffer;
    CerrRedirect cerr_redirect(buffer.rdbuf());

    torch_warn();
    torch_warn();
    torch_warn();

    ASSERT_EQ(count_substr_occurrences(buffer.str(), "warn multiple times"), 3);
  }
}
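
// With a NoGradGuard in scope, the forward pass is not recorded by autograd,
// so no gradient should ever be accumulated into the module's weight.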
TEST(NoGradTest, SetsGradModeCorrectly) {
  torch::manual_seed(0);
  torch::NoGradGuard guard;
  torch::nn::Linear model(5, 2);
  auto x = torch::randn({10, 5}, torch::requires_grad());
  auto y = model->forward(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_FALSE(model->weight.grad().defined());
}
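
// Fixture for the derivative tests: SeedingFixture reseeds the RNG per test,
// and z = x * y is set up with x requiring grad.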
struct AutogradTest : torch::test::SeedingFixture {
  AutogradTest() {
    x = torch::randn({3, 3}, torch::requires_grad());
    y = torch::randn({3, 3});
    z = x * y;
  }
  torch::Tensor x, y, z;
};
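
// The element-wise gradient of z = x * y with respect to x is y, so x.grad()
// should match y after backward().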
TEST_F(AutogradTest, CanTakeDerivatives) {
  z.backward();
  ASSERT_TRUE(x.grad().allclose(y));
}
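
// Backward from the scalar z.sum() propagates a gradient of ones, so x.grad()
// is again y.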
TEST_F(AutogradTest, CanTakeDerivativesOfZeroDimTensors) {
  z.sum().backward();
  ASSERT_TRUE(x.grad().allclose(y));
}
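
// Passing an explicit upstream gradient of 2 scales the result, so x.grad()
// should be 2 * y.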
TEST_F(AutogradTest, CanPassCustomGradientInputs) {
  z.sum().backward(torch::ones({}) * 2);
  ASSERT_TRUE(x.grad().allclose(y * 2));
}