Summary: Leave undefined tensors / `None` returned from custom backward functions as undefined/`None` instead of creating a tensor full of zeros. This change improves performance in some cases.

**This is BC-breaking:** custom backward functions that return `None` will now see it potentially propagated all the way up to AccumulateGrad nodes. As a result, the `.grad` field of leaf tensors, as well as the result of `autograd.grad`, may be undefined/`None` where it used to be a tensor full of zeros. `autograd.grad` may also raise an error; if so, consider passing `allow_unused=True` ([see doc](https://pytorch.org/docs/stable/autograd.html?highlight=autograd%20grad#torch.autograd.grad)) if it applies to your case.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/41490
Reviewed By: albanD
Differential Revision: D22578241
Pulled By: heitorschueroff
fbshipit-source-id: f4966f4cb520069294f8c5c1691eeea799cc0abe
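As an illustration of the new contract (a minimal sketch using the C++ frontend's `torch::autograd::grad`, whose argument order is visible in the tests below; the variable names here are made up for this example, not taken from the test file):

    auto a = torch::ones({2}, torch::requires_grad());
    auto b = torch::ones({2}, torch::requires_grad());
    auto out = a * 2;  // b does not participate in the graph
    // allow_unused=true: the gradient for b now comes back undefined
    // rather than materialized as zeros.
    auto grads = torch::autograd::grad(
        {out}, {a, b}, {torch::ones({2})},
        /*retain_graph=*/{}, /*create_graph=*/false, /*allow_unused=*/true);
    // grads[0].defined() == true; grads[1].defined() == false.
    // With allow_unused=false the same call throws instead.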
#include <gtest/gtest.h>

#include <torch/torch.h>

#include <test/cpp/api/support.h>

using namespace torch::autograd;

#define ASSERT_VARIABLE_EQ(a, b) ASSERT_TRUE(torch::allclose((a), (b)))
#define EXPECT_VARIABLE_EQ(a, b) EXPECT_TRUE(torch::allclose((a), (b)))

std::string graph_desc(std::shared_ptr<Node> node) {
  if (!node) {
    return "None";
  }
  auto result = node->name() + "(";
  auto next_edges = node->next_edges();
  for (auto& edge : next_edges) {
    result += graph_desc(edge.function);
  }
  return result + ")";
}

Variable simple_fn(const Variable& x, const Variable& y) {
  return x + 2 * y + x * y;
}

TEST(AutogradAPITests, BackwardSimpleTest) {
  Variable x = torch::randn({2, 2}, torch::requires_grad());
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  auto res = simple_fn(x, y);
  backward({res.sum()}, {});

  ASSERT_VARIABLE_EQ(x.grad(), y + torch::ones({2, 2}));
  ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({2, 2}) * 2);
}

TEST(AutogradAPITests, BackwardTest) {
  Variable x = torch::randn({2, 2}, torch::requires_grad());
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  auto res = simple_fn(x, y);
  backward({res}, {torch::ones({2, 2})}, {}, true);

  backward({res}, {torch::ones({2, 2})});

  // backward() ran twice, so each accumulated gradient is doubled.
  ASSERT_VARIABLE_EQ(x.grad(), 2 * (y + torch::ones({2, 2})));
  ASSERT_VARIABLE_EQ(y.grad(), 2 * (x + torch::ones({2, 2}) * 2));
}

TEST(AutogradAPITests, GradSimpleTest) {
  // basic grad
  Variable x = torch::randn({2, 2}, torch::requires_grad());
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  auto res = simple_fn(x, y);
  auto grad_res = grad({res}, {x, y}, {torch::ones({2, 2})});

  ASSERT_VARIABLE_EQ(grad_res[0], y + torch::ones({2, 2}));
  ASSERT_VARIABLE_EQ(grad_res[1], x + torch::ones({2, 2}) * 2);
}

TEST(AutogradAPITests, GradTest) {
  Variable x = torch::randn({2, 2}, torch::requires_grad());
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  auto res = simple_fn(x, y);
  res.backward(torch::ones({2, 2}), false, true);

  Variable x_grad = y + torch::ones({2, 2});
  Variable y_grad = x + torch::ones({2, 2}) * 2;
  ASSERT_VARIABLE_EQ(x.grad(), x_grad);
  ASSERT_VARIABLE_EQ(y.grad(), y_grad);

  Variable grad_sum = 2 * x.grad() + y.grad();
  auto x_hv = grad({grad_sum}, {x}, {torch::ones({2, 2})}, {}, true);

  ASSERT_VARIABLE_EQ(x_hv[0], torch::ones({2, 2}));
  ASSERT_VARIABLE_EQ(x.grad(), x_grad);
  ASSERT_VARIABLE_EQ(y.grad(), y_grad);
}

TEST(AutogradAPITests, GradNonLeafTest) {
  Variable x_init = torch::randn({2, 2}, torch::requires_grad());
  Variable x = x_init;
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  Variable grad_output = torch::ones({2, 2});

  for (int i = 0; i < 5; ++i) {
    auto res = simple_fn(x, y);
    auto input_grads = grad({res}, {x}, {grad_output}, {}, true);

    Variable grad_x_expected = y + torch::ones({2, 2});
    ASSERT_VARIABLE_EQ(input_grads[0], grad_x_expected);
    ASSERT_FALSE(x.grad().defined());
    ASSERT_FALSE(y.grad().defined());
    x = x + 0.05 * input_grads[0];
  }

  float val_init = simple_fn(x_init, y).sum().item().toFloat();
  float val_final = simple_fn(x, y).sum().item().toFloat();
  ASSERT_TRUE(val_final > val_init);

  x.backward(grad_output, false, true);
  ASSERT_TRUE(x_init.grad().defined());
  ASSERT_TRUE(y.grad().defined());
}

TEST(AutogradAPITests, GradUnreachableTest) {
  Variable x = torch::ones({1}, torch::requires_grad());
  Variable y = torch::ones({1}, torch::requires_grad());

  Variable z = x * 2;
  Variable w = y * 2;

  auto grad_res = grad({x * 2}, {x, y}, {}, {}, false, true);
  ASSERT_VARIABLE_EQ(grad_res[0], x * 2);
  ASSERT_FALSE(grad_res[1].defined());

  // This is slightly different from the case above, because z doesn't even
  // have a grad accumulator allocated.
  z = torch::ones({1}, torch::requires_grad());
  grad_res = grad({x * 2}, {x, z}, {}, {}, false, true);

  ASSERT_VARIABLE_EQ(grad_res[0], x * 2);
  ASSERT_FALSE(grad_res[1].defined());

  // allow_unused=false with an unreachable input should throw.
  ASSERT_THROWS_WITH(grad({x * 2}, {x, y}, {}, {}, false, false), "Set allow_unused=True");
}

TEST(AutogradAPITests, RetainGrad) {
  auto input = torch::rand({1, 3}, torch::requires_grad());
  auto h1 = input * 3;
  auto out = (h1 * h1).sum();

  // It should be possible to call retain_grad() multiple times
  h1.retain_grad();
  h1.retain_grad();

  // Gradient should be accumulated
  out.backward({}, /*keep_graph=*/true);
  ASSERT_VARIABLE_EQ(h1 * 2, h1.grad());
  out.backward({}, /*keep_graph=*/true);
  ASSERT_VARIABLE_EQ(h1 * 4, h1.grad());

  {
    torch::NoGradGuard no_grad;
    input.grad().zero_();
  }
  // It should be a no-op for leaves
  input.retain_grad();
  input.retain_grad();
  out.backward();
  ASSERT_VARIABLE_EQ(input * 18, input.grad());
}

TEST(CustomAutogradTest, CustomFunction) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext *ctx, Variable var1, int mul, Variable var2) {
      ctx->saved_data["mul"] = mul;
      ctx->save_for_backward({var1, var2});
      return var1 + mul * var2 + var1 * var2;
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      int mul = ctx->saved_data["mul"].toInt();
      auto saved = ctx->get_saved_variables();
      auto var1 = saved[0];
      auto var2 = saved[1];
      // The non-differentiable `mul` argument gets an undefined Variable in
      // its gradient slot.
      variable_list output = {grad_output[0] + grad_output[0] * var2,
                              Variable(),
                              grad_output[0] * mul + grad_output[0] * var1};
      return output;
    }
  };

  Variable x = torch::randn({5, 5}, torch::requires_grad());
  Variable y = torch::randn({5, 5}, torch::requires_grad());
  auto res = MyFunction::apply(x, 2, y);
  auto go = torch::ones({}, torch::requires_grad());
  res.sum().backward(go, false, true);

  ASSERT_VARIABLE_EQ(x.grad(), y + torch::ones({5, 5}));
  ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({5, 5}) * 2);
}

TEST(CustomAutogradTest, FunctionReturnsInput) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext *ctx, Variable var1) {
      return var1;
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      return {grad_output[0] * 2};
    }
  };

  Variable x(torch::ones(1, torch::requires_grad()));
  MyFunction::apply(x).backward(torch::ones(1), true, true);
  ASSERT_VARIABLE_EQ(x.grad(), torch::full(1, 2.));
}

// Covers the behavior described in the summary above: an undefined tensor
// returned from a custom backward is left as-is, so the input's grad stays
// undefined instead of being materialized as zeros.
TEST(CustomAutogradTest, FunctionReturnsUndefined) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext *ctx, Variable var) {
      return var * 2;
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      at::Tensor undefined_tensor;
      return {undefined_tensor};
    }
  };

  auto x = torch::ones(1, torch::requires_grad());

  MyFunction::apply(x).backward();
  ASSERT_FALSE(x.grad().defined());

  MyFunction::apply(x.pow(2)).backward();
  ASSERT_FALSE(x.grad().defined());

  MyFunction::apply(x).sum().backward();
  ASSERT_FALSE(x.grad().defined());

  ASSERT_FALSE(torch::autograd::grad(
      {MyFunction::apply(x)}, {x}, {}, false, false, true)[0].defined());
}

TEST(CustomAutogradTest, NoGradCustomFunction) {
  // Custom Functions should respect grad mode
  struct MyOp : public Function<MyOp> {
    static Variable forward(AutogradContext *ctx, Variable x) {
      return x + 1;
    }

    static variable_list backward(AutogradContext *ctx, variable_list dy) {
      return dy;
    }
  };

  auto x = torch::ones({5, 5}, torch::requires_grad());
  {
    at::NoGradGuard no_grad;
    auto y = MyOp::apply(x);
    ASSERT_FALSE(y.requires_grad());
  }
}

TEST(CustomAutogradTest, MarkDirty) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext *ctx, Variable v) {
      // Change the value in place
      auto v_data = v.data_ptr<float>();
      v_data[0] = 2;
      ctx->mark_dirty({v});
      return v;
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      return {(grad_output[0] * 2.0)};
    }
  };

  // Clone here because modifying leaves in place is not allowed
  auto x = torch::randn({5, 5}, torch::requires_grad()).clone();
  auto version_before = x._version();
  auto out = MyFunction::apply(x);
  auto version_after = x._version();
  ASSERT_TRUE(version_after >= (version_before + 1));
  out.sum().backward();
}

TEST(CustomAutogradTest, MarkNonDifferentiable) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext *ctx, Variable v) {
      Variable output = v > 0;
      ctx->mark_non_differentiable({output});
      return output;
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      return {(grad_output[0] * 0.0)};
    }
  };

  auto x = torch::randn({5, 5}, torch::requires_grad());
  auto mask = MyFunction::apply(x);
  ASSERT_FALSE(mask.requires_grad());
  auto y = x.masked_fill(mask, 0);
  y.sum().backward();
}

TEST(CustomAutogradTest, MarkNonDifferentiableMixed) {
  struct MyFunction : public Function<MyFunction> {
    static variable_list forward(AutogradContext *ctx, Variable input) {
      Variable a = input + 1;
      Variable b = input + 2;
      ctx->mark_non_differentiable({a});
      return {a, b};
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      const Variable &grad_a = grad_output[0], &grad_b = grad_output[1];
      EXPECT_VARIABLE_EQ(grad_a, torch::zeros({5, 5}));
      EXPECT_VARIABLE_EQ(grad_b, torch::ones({5, 5}));
      return {grad_b};
    }
  };

  auto x = torch::randn({5, 5}, torch::requires_grad());
  auto out = MyFunction::apply(x);

  ASSERT_FALSE(out[0].requires_grad());
  ASSERT_TRUE(out[1].requires_grad());
  out[1].sum().backward();
  ASSERT_VARIABLE_EQ(x.grad(), torch::ones({5, 5}));
}

TEST(CustomAutogradTest, MarkNonDifferentiableNone) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext *ctx, Variable input) {
      auto output = input.clone();
      ctx->mark_non_differentiable({output});
      return output;
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_outputs) {
      return {};
    }
  };

  auto x = torch::randn({5, 5}, torch::requires_grad());
  auto r = MyFunction::apply(x * x);
  (r * x).sum().backward();
}

TEST(CustomAutogradTest, ReturnLeafInplace) {
  struct Inplace : public Function<Inplace> {
    static variable_list forward(AutogradContext *ctx, Variable a, Variable b) {
      ctx->mark_dirty({a});
      return {a.add_(b), b + 2};
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      return {grad_output[0], grad_output[0] + grad_output[1]};
    }
  };

  Variable x = torch::randn({5, 5});
  Variable y = torch::randn({5, 5}, torch::requires_grad());

  auto out = Inplace::apply(x, y);
  auto &q = out[0];
  ASSERT_TRUE(torch::equal(q, x));
  ASSERT_TRUE(q.requires_grad());
  q.sum().backward();
  ASSERT_VARIABLE_EQ(y.grad(), torch::ones({5, 5}));
}

TEST(CustomAutogradTest, ReturnDuplicateInplace) {
  struct DoubleInplace : public Function<DoubleInplace> {
    static variable_list forward(AutogradContext *ctx, Variable x) {
      x.mul_(2);
      ctx->mark_dirty({x});
      return {x, x};
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_outputs) {
      return {grad_outputs[0] * 2 + grad_outputs[1] * 2};
    }
  };

  auto x = torch::randn({5, 5}, torch::requires_grad());

  ASSERT_THROWS_WITH(DoubleInplace::apply(x), "leaf Variable that requires grad");
  // TODO ASSERT_THROWS_WITH(DoubleInplace::apply(x.clone()[0]), "only one output");

  auto out = DoubleInplace::apply(x.clone());
  ASSERT_TRUE(torch::equal(out[0], out[1]));
}

TEST(CustomAutogradTest, ReturnDuplicate) {
  struct DoubleDuplicate : public Function<DoubleDuplicate> {
    static variable_list forward(AutogradContext *ctx, Variable x) {
      auto output = x * 2;
      return {output, output};
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_outputs) {
      return {grad_outputs[0] * 2 + grad_outputs[1] * 2};
    }
  };

  auto x = torch::randn({5, 5}, torch::requires_grad());
  auto out = DoubleDuplicate::apply(x);
  ASSERT_TRUE(torch::equal(out[0], out[1]));
}

TEST(CustomAutogradTest, SaveEmptyForBackward) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext *ctx, Variable input) {
      ctx->save_for_backward({Variable(), input, Variable()});
      return input * input;
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      auto saved = ctx->get_saved_variables();
      EXPECT_FALSE(saved[0].defined());
      EXPECT_FALSE(saved[2].defined());
      return {saved[1] * 2 * grad_output[0]};
    }
  };

  Variable x = torch::randn({5, 5}, torch::requires_grad());
  auto y = MyFunction::apply(x);
  y.sum().backward();
  ASSERT_VARIABLE_EQ(x.grad(), 2 * x);
}

TEST(CustomAutogradTest, InvalidGradients) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext *ctx, Variable x) {
      return x * 2;
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_outputs) {
      // Wrong shape on purpose: the engine should reject this gradient.
      return {torch::randn(10, torch::dtype(torch::kFloat).requires_grad(true))};
    }
  };

  auto input1 = torch::randn({5, 5}, torch::dtype(torch::kFloat).requires_grad(true));
  ASSERT_THROWS_WITH(
      MyFunction::apply(input1).sum().backward(), "expected shape");
  auto input2 = torch::randn(10, torch::dtype(torch::kDouble).requires_grad(true));
}

TEST(CustomAutogradTest, NoGradInput) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext*, Variable x) {
      return x;
    }

    static variable_list backward(AutogradContext*, variable_list grad_outputs) {
      return grad_outputs;
    }
  };

  Variable x = torch::randn({5, 5}, torch::requires_grad());
  Variable y;
  {
    at::NoGradGuard no_grad;
    y = MyFunction::apply(x);
  }

  ASSERT_TRUE(x.requires_grad());
  ASSERT_FALSE(y.grad_fn());
}

TEST(CustomAutogradTest, TooManyGrads) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext*, Variable input) {
      return input;
    }

    static variable_list backward(AutogradContext*, variable_list grad_output) {
      // Return extra (undefined) gradients beyond the number of inputs.
      grad_output.insert(grad_output.end(), {Variable(), Variable()});
      return grad_output;
    }
  };
}

TEST(CustomAutogradTest, DepNoGrad) {
  struct F1 : public Function<F1> {
    static variable_list forward(AutogradContext *ctx, Variable input) {
      auto out = torch::randn(input.sizes());
      ctx->mark_non_differentiable({out});
      return {input, out};
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      return {grad_output[0]};
    }
  };

  struct F2 : public Function<F2> {
    static Variable forward(AutogradContext*, Variable input, Variable ignore) {
      return input;
    }

    static variable_list backward(AutogradContext*, variable_list grad_output) {
      return {grad_output[0], Variable()};
    }
  };

  auto x = torch::randn(5, torch::requires_grad());
  auto out = F1::apply(x);
  Variable &a = out[0], &b = out[1];
  b = b + 1; // Separate F1 and F2 by another operation
  ASSERT_TRUE(a.requires_grad());
  ASSERT_FALSE(b.requires_grad());

  auto c = F2::apply(a, b);
  c.backward(torch::ones(c.sizes()), false, false);
  ASSERT_VARIABLE_EQ(x.grad(), torch::ones(x.sizes()));
}

TEST(CustomAutogradTest, Reentrant) {
  static Variable y_data = torch::randn({2, 2});
  struct Reenter : public Function<Reenter> {
    static Variable forward(AutogradContext *ctx, Variable input) {
      Variable output;
      {
        at::AutoGradMode enable_grad(true);
        auto x = make_variable(input.tensor_data(), true);
        auto y = make_variable(y_data.tensor_data(), true);
        output = x * y;

        ctx->saved_data["x"] = x;
        ctx->saved_data["y"] = y;
        ctx->saved_data["output_var"] = output;
      }
      return output.detach();
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      {
        // Reentrant backward: run a nested backward pass inside this one.
        at::AutoGradMode enable_grad(true);
        auto out = ctx->saved_data["output_var"].toTensor();
        out.sum().backward();
      }
      return {ctx->saved_data["x"].toTensor().grad() * grad_output[0]};
    }
  };

  auto x = torch::randn({2, 2}, torch::requires_grad());
  auto out = Reenter::apply(x);
  out.sum().backward();
  ASSERT_VARIABLE_EQ(x.grad(), y_data);
}

// NOTE: If this fails for apparently unrelated reasons in TSAN be aware of
// the TSAN limit on mutex: https://github.com/google/sanitizers/issues/950
TEST(CustomAutogradTest, DeepReentrant) {
  struct DeepReenter : public Function<DeepReenter> {
    static Variable forward(AutogradContext *ctx, Variable x) {
      {
        at::AutoGradMode enable_grad(true);
        ctx->saved_data["x"] = make_variable(x.tensor_data(), true) - 1;
      }
      return ctx->saved_data["x"].toTensor().detach();
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      if (!ctx->saved_data["x"].toTensor().is_nonzero()) {
        return grad_output;
      }
      {
        at::AutoGradMode enable_grad(true);
        apply(ctx->saved_data["x"].toTensor())[0].sum().backward();
        return grad_output;
      }
    }
  };

  // This should not stack overflow
  auto v = torch::tensor({8193}, torch::dtype(torch::kFloat).requires_grad(true));
  DeepReenter::apply(v).sum().backward();
}

TEST(CustomAutogradTest, ReentrantPriority) {
  static std::vector<int> order;

  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext*, Variable x) {
      return x;
    }

    static variable_list backward(AutogradContext*, variable_list grad) {
      order.push_back(0);
      return grad;
    }
  };

  struct Reenter : public Function<Reenter> {
    static Variable forward(AutogradContext *ctx, Variable x) {
      {
        at::AutoGradMode enable_grad(true);
        ctx->saved_data["x"] = make_variable(x.tensor_data(), true) - 1;
      }
      return ctx->saved_data["x"].toTensor().detach();
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      order.push_back(1);
      if (!ctx->saved_data["x"].toTensor().is_nonzero()) {
        return grad_output;
      }
      {
        at::AutoGradMode enable_grad(true);
        apply(ctx->saved_data["x"].toTensor())[0].sum().backward();
        return grad_output;
      }
    }
  };

  auto a = MyFunction::apply(torch::tensor({6}, torch::dtype(torch::kFloat).requires_grad(true)));
  auto b = Reenter::apply(torch::tensor({9}, torch::dtype(torch::kFloat).requires_grad(true)));
  auto v = a * b;
  v.backward();

  // All the reentrant tasks should be prioritized over the MyFunction backward
  // task: Reenter's backward runs nine times (its input counts down from 9 to
  // 0) before MyFunction's backward runs once.
  ASSERT_EQ(order.size(), 10);
  ASSERT_EQ(std::count(order.begin(), order.end(), 1), 9);
  ASSERT_EQ(order.back(), 0);
}

TEST(CustomAutogradTest, Hooks) {
  Variable x = torch::ones({5, 5}, torch::requires_grad());
  Variable y = torch::ones({5, 5}) * 4;
  y.set_requires_grad(true);

  int counter = 0;

  std::function<void(int, Variable)> bw_hook([&counter](int inc, Variable grad) {
    counter += inc;
  });

  Variable z = x * x + x * 2 + x * y + y;
  x.register_hook([&bw_hook](Variable grad) {
    bw_hook(0, grad);
  });
  auto hook_1 = z.register_hook([&bw_hook](Variable grad) {
    bw_hook(1, grad);
  });
  z.backward(torch::ones({5, 5}), true, true);
  ASSERT_EQ(counter, 1);

  auto hook_2 = z.register_hook([&bw_hook](Variable grad) {
    bw_hook(2, grad);
  });
  z.backward(torch::ones({5, 5}), true, true);
  ASSERT_EQ(counter, 4);

  z.remove_hook(hook_2);
  z.backward(torch::ones({5, 5}), true, true);
  ASSERT_EQ(counter, 5);

  std::function<Variable(Variable)> bw_hook_modify([](Variable grad) {
    return grad.mul(2);
  });

  z.remove_hook(hook_1);
  z.register_hook(bw_hook_modify);
  y.grad().zero_();
  z.backward(torch::ones({5, 5}), true, false);
  ASSERT_VARIABLE_EQ(y.grad(), (x + 1) * 2);

  y.register_hook(bw_hook_modify);
  y.grad().zero_();
  z.backward(torch::ones({5, 5}), false, false);
  ASSERT_VARIABLE_EQ(y.grad(), (x + 1) * 4);

  ASSERT_THROWS_WITH(y.remove_hook(3), "Invalid index");
}

TEST(CustomAutogradTest, HookNone) {
  struct NoneGradientFunction : public Function<NoneGradientFunction> {
    static variable_list forward(AutogradContext *ctx, Variable x, Variable y) {
      return {x, y};
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad) {
      return {grad[0], Variable()};
    }
  };

  bool was_called = false;

  auto hook = ([&was_called](Variable grad) {
    ASSERT_TRUE(grad.defined());
    was_called = true;
  });

  auto x = torch::randn({5, 5}, torch::requires_grad());
  auto y = torch::randn({5, 5});

  auto out = NoneGradientFunction::apply(x, y);
  Variable rx = x[0], ry = x[1];

  rx.register_hook(hook);
  ry.register_hook(hook);
  (rx + ry).sum().backward();
  ASSERT_TRUE(was_called);
}

// TODO add these tests if needed
// test_once_differentiable
// test_sparse_backward
// test_save_output_nr
// test_free_deep_graph_pyfunction
// test_naughty_anomaly_access
// test_naughty_autograd-function_stashing_ctx
// test_custom_autograd_repeated_grad_grad
// test_return_leaf
// test_anomaly_detect_nan
// test_no_grad_copy