pytorch/test/cpp/api/autograd.cpp
mal ec13f18390 Allow empty Variables to be saved for backwards (#23618)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/23618

For example, `save_for_backward({Variable(), x, Variable()})` should be allowed, so that the C++ API is consistent with the Python API behaviour.
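Concretely, the pattern this enables in a C++ custom function looks like the following condensed sketch of the `SaveEmptyForBackward` test added in this file (the struct name is illustrative):

    struct Square : public torch::autograd::Function<Square> {
      static variable_list forward(AutogradContext* ctx, Variable x) {
        // Undefined Variables may be interleaved with defined ones.
        ctx->save_for_backward({Variable(), x, Variable()});
        return {x * x};
      }
      static variable_list backward(AutogradContext* ctx, variable_list grad_output) {
        auto saved = ctx->get_saved_variables();
        // saved[0] and saved[2] come back undefined, matching Python's None.
        return {saved[1] * 2 * grad_output[0]};
      }
    };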

Test Plan: Added a test similar to the Python test `test_save_none_for_backward` from test_autograd.py.

Differential Revision: D16589402

fbshipit-source-id: 847544ad8fc10772954d8629ad5a62bfdc1a66c1
2019-07-31 19:51:35 -07:00


#include <gtest/gtest.h>

#include <torch/autograd.h>
#include <torch/utils.h>

#include <test/cpp/api/support.h>

using namespace torch::autograd;

#define ASSERT_VARIABLE_EQ(a,b) ASSERT_TRUE(torch::allclose((a),(b)))

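// Recursively renders the autograd graph rooted at `node` as a nested string
// of node names; a null node prints as "None".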
std::string graph_desc(std::shared_ptr<Node> node) {
  if (!node) {
    return "None";
  }
  auto result = node->name() + "(";
  auto next_edges = node->next_edges();
  for (auto& edge : next_edges) {
    result += graph_desc(edge.function);
  }
  return result + ")";
}

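// A custom Function with a non-Variable argument: the int is stashed in
// saved_data, the Variables in save_for_backward, and backward returns an
// undefined Variable() as the gradient for the non-differentiable int input.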
TEST(CustomAutogradTest, CustomFunction) {
  struct MyFunction : public Function<MyFunction> {
    static variable_list forward(AutogradContext *ctx, Variable var1, int mul, Variable var2) {
      ctx->saved_data["mul"] = mul;
      ctx->save_for_backward({var1, var2});
      return {var1 + mul*var2 + var1*var2};
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      int mul = ctx->saved_data["mul"].toInt();
      auto saved = ctx->get_saved_variables();
      auto var1 = saved[0];
      auto var2 = saved[1];
      variable_list output = {grad_output[0] + grad_output[0]*var2,
                              Variable(),
                              grad_output[0]*mul + grad_output[0]*var1};
      return output;
    }
  };

  Variable x = torch::randn({5,5}, torch::requires_grad());
  Variable y = torch::randn({5,5}, torch::requires_grad());
  auto res = MyFunction::apply(x, 2, y)[0];
  auto go = torch::ones({}, torch::requires_grad());
  res.sum().backward(go, false, true);

  ASSERT_VARIABLE_EQ(x.grad(), y + torch::ones({5,5}));
  ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({5,5})*2);
}

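// A function that returns its input unchanged still participates in the graph:
// its backward (which doubles the incoming gradient) must be invoked.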
TEST(CustomAutogradTest, FunctionReturnsInput) {
  struct MyFunction : public Function<MyFunction> {
    static variable_list forward(AutogradContext *ctx, Variable var1) {
      return {var1};
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      return {grad_output[0]*2};
    }
  };

  Variable x(torch::ones(1, torch::requires_grad()));
  MyFunction::apply(x)[0].backward(torch::ones(1), true, true);
  ASSERT_VARIABLE_EQ(x.grad(), torch::full(1, 2));
}

TEST(CustomAutogradTest, NoGradCustomFunction) {
  // Custom Function should respect grad mode
  struct MyOp : public Function<MyOp> {
    static variable_list forward(AutogradContext *ctx, Variable x) {
      return {x+1};
    }

    static variable_list backward(AutogradContext *ctx, variable_list dy) {
      return dy;
    }
  };

  auto x = torch::ones({5,5}, torch::requires_grad());
  {
    at::NoGradGuard no_grad;
    auto y = MyOp::apply(x)[0];
    ASSERT_FALSE(y.requires_grad());
  }
}

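// Outputs marked non-differentiable must not require grad, and using such an
// output downstream (here as a mask) must not break backward through the graph.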
TEST(CustomAutogradTest, MarkNonDifferentiable) {
  struct MyFunction : public Function<MyFunction> {
    static variable_list forward(AutogradContext *ctx, Variable v) {
      Variable output = v > 0;
      ctx->mark_non_differentiable({output});
      return {output};
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      return {grad_output[0] * 0.0};
    }
  };

  auto x = torch::randn({5,5}, torch::requires_grad());
  auto mask = MyFunction::apply(x)[0];
  ASSERT_FALSE(mask.requires_grad());
  auto y = x.masked_fill(mask, 0);
  y.sum().backward();
}

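// An in-place op on a leaf that does not require grad: after mark_dirty, the
// returned (modified) input aliases the original tensor and requires grad
// through the other, differentiable input.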
TEST(CustomAutogradTest, ReturnLeafInplace) {
  struct Inplace : public Function<Inplace> {
    static variable_list forward(AutogradContext *ctx, Variable a, Variable b) {
      ctx->mark_dirty({a});
      return {a.add_(b), b+2};
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      return {grad_output[0], grad_output[0] + grad_output[1]};
    }
  };

  Variable x = torch::randn({5,5});
  Variable y = torch::randn({5,5}, torch::requires_grad());

  auto out = Inplace::apply(x, y);
  auto &q = out[0];
  ASSERT_TRUE(torch::equal(q, x));
  ASSERT_TRUE(q.requires_grad());
  q.sum().backward();
  ASSERT_VARIABLE_EQ(y.grad(), torch::ones({5,5}));
}

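// The test added by this change: undefined Variables may be interleaved with
// defined ones in save_for_backward and come back undefined (not an error)
// from get_saved_variables. Mirrors test_save_none_for_backward in Python.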
TEST(CustomAutogradTest, SaveEmptyForBackward) {
  struct MyFunction : public Function<MyFunction> {
    static variable_list forward(AutogradContext *ctx, Variable input) {
      ctx->save_for_backward({Variable(), input, Variable()});
      return {input*input};
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      auto saved = ctx->get_saved_variables();
      EXPECT_FALSE(saved[0].defined());
      EXPECT_FALSE(saved[2].defined());
      return {saved[1] * 2 * grad_output[0]};
    }
  };

  Variable x = torch::randn({5,5}, torch::requires_grad());
  auto y = MyFunction::apply(x)[0];
  y.sum().backward();
  ASSERT_VARIABLE_EQ(x.grad(), 2*x);
}