pytorch/test/cpp/api/nn_utils.cpp
jokerkeny aa30176c68 Add C++ API clip_grad_value_ for nn:utils (#28736)
Summary:
Adds the C++ API clip_grad_value_ to the torch::nn::utils module.
Also fixes the indentation level of a for loop in the original test/test_nn.py.

Issue: https://github.com/pytorch/pytorch/issues/25883

Reviewer: yf225
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28736

Differential Revision: D18263807

Pulled By: yf225

fbshipit-source-id: 29282450bd2099df16925e1d0edd3d933f6eeb9b
2019-10-31 19:11:54 -07:00

143 lines
4.8 KiB
C++

#include <gtest/gtest.h>
#include <torch/torch.h>
#include <test/cpp/api/support.h>
using namespace torch::nn;
using namespace torch::test;
// Fixture shared by the torch::nn::utils tests in this file. It declares no
// members of its own; all setup comes from torch::test::SeedingFixture
// (presumably RNG seeding -- see test/cpp/api/support.h).
class NNUtilsTest : public torch::test::SeedingFixture {};
// Checks utils::clip_grad_norm_ against a norm computed directly in the test:
//  1. large gradients are rescaled by one common factor so that their norm
//     becomes max_norm, and the returned value is the pre-clipping norm;
//  2. gradients whose norm is already below max_norm are left unchanged;
//  3. passing a single tensor behaves the same as a one-element vector.
TEST_F(NNUtilsTest, ClipGradNorm) {
  auto linear_layer = Linear(10, 10);
  float max_norm = 2;

  // Reference total p-norm over the gradients of every parameter of
  // `linear_layer`. For norm_type == +inf this is the largest absolute
  // gradient entry instead.
  auto compute_norm = [linear_layer](float norm_type) -> float {
    float total_norm = 0.0;
    if (norm_type != std::numeric_limits<float>::infinity()) {
      for (const auto& p : linear_layer->parameters()) {
        total_norm +=
            p.grad().data().abs().pow(norm_type).sum().item().toFloat();
      }
      return std::pow(total_norm, 1.0 / norm_type);
    } else {
      for (const auto& p : linear_layer->parameters()) {
        auto param_max = p.grad().data().abs().max().item().toFloat();
        if (param_max > total_norm) {
          total_norm = param_max;
        }
      }
      return total_norm;
    }
  };

  // Returns the element-wise ratio (current gradient / `grads`) for every
  // parameter, flattened and concatenated into one tensor. A uniform rescale
  // yields a ratio tensor whose standard deviation is zero.
  auto compare_scaling =
      [linear_layer](const std::vector<torch::Tensor>& grads) -> torch::Tensor {
    // Build the parameter list once instead of once per index.
    const auto params = linear_layer->parameters();
    std::vector<torch::Tensor> p_scale;
    p_scale.reserve(grads.size());
    for (size_t i = 0; i < grads.size(); i++) {
      p_scale.push_back(params[i].grad().data().div(grads[i]).view(-1));
    }
    return torch::cat(p_scale);
  };

  std::vector<torch::Tensor> grads = {
      torch::arange(1.0, 101).view({10, 10}),
      torch::ones({10}).div(1000),
  };
  std::vector<float> norm_types = {
      0.5,
      1.5,
      2.0,
      4.0,
      std::numeric_limits<float>::infinity(),
  };

  // Large gradients: they must be clipped down to exactly max_norm.
  for (auto norm_type : norm_types) {
    // The returned tensors share storage/autograd state with the module's
    // parameters, so assigning their grads updates the module itself.
    auto layer_params = linear_layer->parameters();
    for (size_t i = 0; i < grads.size(); i++) {
      layer_params[i].grad() =
          grads[i].clone().view_as(layer_params[i].data());
    }
    auto norm_before = compute_norm(norm_type);
    auto norm = utils::clip_grad_norm_(layer_params, max_norm, norm_type);
    auto norm_after = compute_norm(norm_type);
    // The value returned by clip_grad_norm_ is the norm before clipping.
    ASSERT_FLOAT_EQ(norm, norm_before);
    ASSERT_FLOAT_EQ(norm_after, max_norm);
    ASSERT_LE(norm_after, max_norm);
    // All gradients must have been scaled by the same factor.
    auto scaled = compare_scaling(grads);
    ASSERT_NEAR(0, scaled.std().item().toFloat(), 1e-7);
  }

  // Small gradients should be left unchanged.
  grads = {
      torch::rand({10, 10}).div(10000),
      torch::ones(10).div(500),
  };
  for (auto norm_type : norm_types) {
    auto layer_params = linear_layer->parameters();
    for (size_t i = 0; i < grads.size(); i++) {
      layer_params[i].grad().data().copy_(grads[i]);
    }
    auto norm_before = compute_norm(norm_type);
    auto norm = utils::clip_grad_norm_(layer_params, max_norm, norm_type);
    auto norm_after = compute_norm(norm_type);
    ASSERT_FLOAT_EQ(norm, norm_before);
    ASSERT_FLOAT_EQ(norm_before, norm_after);
    ASSERT_LE(norm_after, max_norm);
    auto scaled = compare_scaling(grads);
    ASSERT_NEAR(0, scaled.std().item().toFloat(), 1e-7);
    // Untouched gradients give a ratio of exactly 1.
    ASSERT_EQ(scaled[0].item().toFloat(), 1);
  }

  // Should accept a single tensor as input.
  auto p1 = torch::randn({10, 10});
  auto p2 = torch::randn({10, 10});
  auto g = torch::arange(1., 101).view({10, 10});
  p1.grad() = g.clone();
  p2.grad() = g.clone();
  for (const auto norm_type : norm_types) {
    utils::clip_grad_norm_(p1, max_norm, norm_type);
    std::vector<torch::Tensor> params = {p2};
    utils::clip_grad_norm_(params, max_norm, norm_type);
    ASSERT_TRUE(torch::allclose(p1.grad(), p2.grad()));
  }
}
// Checks that utils::clip_grad_value_ leaves every defined gradient inside
// [-clip_value, clip_value], tolerates parameters whose gradient is undefined,
// and treats a single tensor the same as a one-element vector of tensors.
TEST_F(NNUtilsTest, ClipGradValue) {
  auto linear_layer = Linear(10, 10);
  float clip_value = 2.5;

  // The weight gradient spans [-10, 9.8], so it needs clamping. The bias
  // gradient (all 2s) is already within range.
  torch::Tensor grad_w = torch::arange(-50., 50).view({10, 10}).div_(5);
  torch::Tensor grad_b = torch::ones({10}).mul_(2);
  // The second list leaves the bias gradient undefined.
  std::vector<std::vector<torch::Tensor>> grad_lists = {
      {grad_w, grad_b}, {grad_w, torch::Tensor()}};
  for (const auto& grad_list : grad_lists) {
    // The returned tensors share state with the module's parameters, so
    // assigning their grads updates the module itself.
    auto layer_params = linear_layer->parameters();
    for (size_t i = 0; i < grad_list.size(); i++) {
      auto& p = layer_params[i];
      auto g = grad_list[i];
      p.grad() = g.defined() ? g.clone().view_as(p.data()) : g;
    }
    utils::clip_grad_value_(layer_params, clip_value);
    for (const auto& param : layer_params) {
      const auto& grad = param.grad();
      if (grad.defined()) {
        ASSERT_LE(grad.data().max().item().toFloat(), clip_value);
        ASSERT_GE(grad.data().min().item().toFloat(), -clip_value);
      }
    }
  }

  // Should accept a single Tensor as input.
  auto p1 = torch::randn({10, 10});
  auto p2 = torch::randn({10, 10});
  auto g = torch::arange(-50., 50).view({10, 10}).div_(5);
  p1.grad() = g.clone();
  p2.grad() = g.clone();
  utils::clip_grad_value_(p1, clip_value);
  std::vector<torch::Tensor> params = {p2};
  utils::clip_grad_value_(params, clip_value);
  ASSERT_TRUE(torch::allclose(p1.grad(), p2.grad()));
}