#include <gtest/gtest.h>

#include <torch/autograd.h>

#include <torch/utils.h>
#include <test/cpp/api/support.h>

#include <algorithm>
#include <functional>
#include <memory>
#include <string>
#include <vector>

using namespace torch::autograd;

#define ASSERT_VARIABLE_EQ(a,b) ASSERT_TRUE(torch::allclose((a),(b)))
#define EXPECT_VARIABLE_EQ(a,b) EXPECT_TRUE(torch::allclose((a),(b)))

std::string graph_desc(std::shared_ptr<Node> node) {
  if (!node) {
    return "None";
  }
  auto result = node->name() + "(";
  auto next_edges = node->next_edges();
  for (auto& edge : next_edges) {
    result += graph_desc(edge.function);
  }
  return result + ")";
}

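// simple_fn(x, y) = x + 2*y + x*y, so the analytic gradients used throughout
// these tests are d/dx = 1 + y and d/dy = 2 + x.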
Variable simple_fn(const Variable& x, const Variable& y) {
  return x + 2 * y + x * y;
}

TEST(AutogradAPITests, BackwardSimpleTest) {
  Variable x = torch::randn({2, 2}, torch::requires_grad());
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  auto res = simple_fn(x, y);
  backward({res.sum()}, {});

  ASSERT_VARIABLE_EQ(x.grad(), y + torch::ones({2, 2}));
  ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({2, 2}) * 2);
}

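// backward() runs twice here (the first call retains the graph), so the
// accumulated gradients are exactly twice the single-pass values.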
TEST(AutogradAPITests, BackwardTest) {
  Variable x = torch::randn({2, 2}, torch::requires_grad());
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  auto res = simple_fn(x, y);
  backward({res}, {torch::ones({2, 2})}, {}, true);

  backward({res}, {torch::ones({2, 2})});

  ASSERT_VARIABLE_EQ(x.grad(), 2 * (y + torch::ones({2, 2})));
  ASSERT_VARIABLE_EQ(y.grad(), 2 * (x + torch::ones({2, 2}) * 2));
}

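// Unlike backward(), grad() returns the gradients directly and does not
// accumulate anything into the inputs' .grad() fields.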
TEST(AutogradAPITests, GradSimpleTest) {
  // basic grad
  Variable x = torch::randn({2, 2}, torch::requires_grad());
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  auto res = simple_fn(x, y);
  auto grad_res = grad({res}, {x, y}, {torch::ones({2, 2})});

  ASSERT_VARIABLE_EQ(grad_res[0], y + torch::ones({2, 2}));
  ASSERT_VARIABLE_EQ(grad_res[1], x + torch::ones({2, 2}) * 2);
}

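// Second-order check: backward() is run with create_graph=true, so
// grad_sum = 2 * x.grad() + y.grad() = 2 * (y + 1) + (x + 2) is itself
// differentiable, and d(grad_sum)/dx is a tensor of ones.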
TEST(AutogradAPITests, GradTest) {
  Variable x = torch::randn({2, 2}, torch::requires_grad());
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  auto res = simple_fn(x, y);
  res.backward(torch::ones({2, 2}), false, true);

  Variable x_grad = y + torch::ones({2, 2});
  Variable y_grad = x + torch::ones({2, 2}) * 2;
  ASSERT_VARIABLE_EQ(x.grad(), x_grad);
  ASSERT_VARIABLE_EQ(y.grad(), y_grad);

  Variable grad_sum = 2 * x.grad() + y.grad();
  auto x_hv = grad({grad_sum}, {x}, {torch::ones({2, 2})}, {}, true);

  ASSERT_VARIABLE_EQ(x_hv[0], torch::ones({2, 2}));
  ASSERT_VARIABLE_EQ(x.grad(), x_grad);
  ASSERT_VARIABLE_EQ(y.grad(), y_grad);
}

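// grad() is taken with respect to a non-leaf x, which is updated in a short
// gradient-ascent loop. Because grad() does not accumulate into .grad(), the
// leaves' .grad() fields stay undefined until the explicit backward() at the end.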
TEST(AutogradAPITests, GradNonLeafTest) {
  Variable x_init = torch::randn({2, 2}, torch::requires_grad());
  Variable x = x_init;
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  Variable grad_output = torch::ones({2, 2});

  for (int i = 0; i < 5; ++i) {
    auto res = simple_fn(x, y);
    auto input_grads = grad({res}, {x}, {grad_output}, {}, true);

    Variable grad_x_expected = y + torch::ones({2, 2});
    ASSERT_VARIABLE_EQ(input_grads[0], grad_x_expected);
    ASSERT_FALSE(x.grad().defined());
    ASSERT_FALSE(y.grad().defined());
    x = x + 0.05 * input_grads[0];
  }

  float val_init = simple_fn(x_init, y).sum().item().toFloat();
  float val_final = simple_fn(x, y).sum().item().toFloat();
  ASSERT_TRUE(val_final > val_init);

  x.backward(grad_output, false, true);
  ASSERT_TRUE(x_init.grad().defined());
  ASSERT_TRUE(y.grad().defined());
}

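// The trailing `true` passed to grad() is allow_unused: inputs that do not
// participate in the output (y below) get an undefined gradient instead of
// triggering an error.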
TEST(AutogradAPITests, GradUnreachableTest) {
  Variable x = torch::ones({1}, torch::requires_grad());
  Variable y = torch::ones({1}, torch::requires_grad());

  Variable z = x * 2;
  Variable w = y * 2;

  auto grad_res = grad({x * 2}, {x, y}, {}, {}, false, true);
  ASSERT_VARIABLE_EQ(grad_res[0], x * 2);
  ASSERT_FALSE(grad_res[1].defined());

  // This is slightly different than the case above, because z doesn't even
  // have a grad accumulator allocated.
  z = torch::ones({1}, torch::requires_grad());
  grad_res = grad({x * 2}, {x, z}, {}, {}, false, true);

  ASSERT_VARIABLE_EQ(grad_res[0], x * 2);
  ASSERT_FALSE(grad_res[1].defined());
}

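// A custom Function's backward() must return one entry per forward() input, in
// order; non-Variable inputs (the int `mul` here) get an undefined Variable()
// placeholder. For out = var1 + mul*var2 + var1*var2 the gradients are
// d/dvar1 = 1 + var2 and d/dvar2 = mul + var1.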
TEST(CustomAutogradTest, CustomFunction) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext *ctx, Variable var1, int mul, Variable var2) {
      ctx->saved_data["mul"] = mul;
      ctx->save_for_backward({var1, var2});
      return var1 + mul * var2 + var1 * var2;
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      int mul = ctx->saved_data["mul"].toInt();
      auto saved = ctx->get_saved_variables();
      auto var1 = saved[0];
      auto var2 = saved[1];
      variable_list output = {grad_output[0] + grad_output[0] * var2,
                              Variable(),
                              grad_output[0] * mul + grad_output[0] * var1};
      return output;
    }
  };

  Variable x = torch::randn({5, 5}, torch::requires_grad());
  Variable y = torch::randn({5, 5}, torch::requires_grad());
  auto res = MyFunction::apply(x, 2, y);
  auto go = torch::ones({}, torch::requires_grad());
  res.sum().backward(go, false, true);

  ASSERT_VARIABLE_EQ(x.grad(), y + torch::ones({5, 5}));
  ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({5, 5}) * 2);
}

TEST(CustomAutogradTest, FunctionReturnsInput) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext *ctx, Variable var1) {
      return var1;
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      return {grad_output[0] * 2};
    }
  };

  Variable x(torch::ones(1, torch::requires_grad()));
  MyFunction::apply(x).backward(torch::ones(1), true, true);
  ASSERT_VARIABLE_EQ(x.grad(), torch::full(1, 2));
}

TEST(CustomAutogradTest, NoGradCustomFunction) {
  // Custom Function should respect grad mode
  struct MyOp : public Function<MyOp> {
    static Variable forward(AutogradContext *ctx, Variable x) {
      return x + 1;
    }

    static variable_list backward(AutogradContext *ctx, variable_list dy) {
      return dy;
    }
  };

  auto x = torch::ones({5, 5}, torch::requires_grad());
  {
    at::NoGradGuard no_grad;
    auto y = MyOp::apply(x);
    ASSERT_FALSE(y.requires_grad());
  }
}

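// Outputs passed to mark_non_differentiable() behave like constants: they do
// not require grad, and backward() receives a zero gradient in their slot (see
// the Mixed variant below, which asserts exactly that).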
TEST(CustomAutogradTest, MarkNonDifferentiable) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext *ctx, Variable v) {
      Variable output = v > 0;
      ctx->mark_non_differentiable({output});
      return output;
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      return {(grad_output[0] * 0.0)};
    }
  };

  auto x = torch::randn({5, 5}, torch::requires_grad());
  auto mask = MyFunction::apply(x);
  ASSERT_FALSE(mask.requires_grad());
  auto y = x.masked_fill(mask, 0);
  y.sum().backward();
}

TEST(CustomAutogradTest, MarkNonDifferentiableMixed) {
  struct MyFunction : public Function<MyFunction> {
    static variable_list forward(AutogradContext *ctx, Variable input) {
      Variable a = input + 1;
      Variable b = input + 2;
      ctx->mark_non_differentiable({a});
      return {a, b};
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      const Variable &grad_a = grad_output[0], &grad_b = grad_output[1];
      EXPECT_VARIABLE_EQ(grad_a, torch::zeros({5, 5}));
      EXPECT_VARIABLE_EQ(grad_b, torch::ones({5, 5}));
      return {grad_b};
    }
  };

  auto x = torch::randn({5, 5}, torch::requires_grad());
  auto out = MyFunction::apply(x);

  ASSERT_FALSE(out[0].requires_grad());
  ASSERT_TRUE(out[1].requires_grad());
  out[1].sum().backward();
  ASSERT_VARIABLE_EQ(x.grad(), torch::ones({5, 5}));
}

TEST(CustomAutogradTest, MarkNonDifferentiableNone) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext *ctx, Variable input) {
      auto output = input.clone();
      ctx->mark_non_differentiable({output});
      return output;
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_outputs) {
      return {};
    }
  };

  auto x = torch::randn({5, 5}, torch::requires_grad());
  auto r = MyFunction::apply(x * x);
  (r * x).sum().backward();
}

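// mark_dirty() flags inputs that forward() modified in place, so autograd can
// account for the in-place update on the returned alias.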
TEST(CustomAutogradTest, ReturnLeafInplace) {
  struct Inplace : public Function<Inplace> {
    static variable_list forward(AutogradContext *ctx, Variable a, Variable b) {
      ctx->mark_dirty({a});
      return {a.add_(b), b + 2};
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      return {grad_output[0], grad_output[0] + grad_output[1]};
    }
  };

  Variable x = torch::randn({5, 5});
  Variable y = torch::randn({5, 5}, torch::requires_grad());

  auto out = Inplace::apply(x, y);
  auto &q = out[0];
  ASSERT_TRUE(torch::equal(q, x));
  ASSERT_TRUE(q.requires_grad());
  q.sum().backward();
  ASSERT_VARIABLE_EQ(y.grad(), torch::ones({5, 5}));
}

TEST(CustomAutogradTest, ReturnDuplicateInplace) {
  struct DoubleInplace : public Function<DoubleInplace> {
    static variable_list forward(AutogradContext *ctx, Variable x) {
      x.mul_(2);
      ctx->mark_dirty({x});
      return {x, x};
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_outputs) {
      return {grad_outputs[0] * 2 + grad_outputs[1] * 2};
    }
  };

  auto x = torch::randn({5, 5}, torch::requires_grad());

  ASSERT_THROWS_WITH(DoubleInplace::apply(x), "leaf Variable that requires grad");
  // TODO ASSERT_THROWS_WITH(DoubleInplace::apply(x.clone()[0]), "only one output");

  auto out = DoubleInplace::apply(x.clone());
  ASSERT_TRUE(torch::equal(out[0], out[1]));
}

TEST(CustomAutogradTest, ReturnDuplicate) {
  struct DoubleDuplicate : public Function<DoubleDuplicate> {
    static variable_list forward(AutogradContext *ctx, Variable x) {
      auto output = x * 2;
      return {output, output};
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_outputs) {
      return {grad_outputs[0] * 2 + grad_outputs[1] * 2};
    }
  };

  auto x = torch::randn({5, 5}, torch::requires_grad());
  auto out = DoubleDuplicate::apply(x);
  ASSERT_TRUE(torch::equal(out[0], out[1]));
}

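// Undefined Variables round-trip through save_for_backward() unchanged; only
// the defined entry is used. For y = x*x the gradient is 2 * x * grad_output.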
TEST(CustomAutogradTest, SaveEmptyForBackward) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext *ctx, Variable input) {
      ctx->save_for_backward({Variable(), input, Variable()});
      return input * input;
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      auto saved = ctx->get_saved_variables();
      EXPECT_FALSE(saved[0].defined());
      EXPECT_FALSE(saved[2].defined());
      return {saved[1] * 2 * grad_output[0]};
    }
  };

  Variable x = torch::randn({5, 5}, torch::requires_grad());
  auto y = MyFunction::apply(x);
  y.sum().backward();
  ASSERT_VARIABLE_EQ(x.grad(), 2 * x);
}

TEST(CustomAutogradTest, InvalidGradients) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext *ctx, Variable x) {
      return x * 2;
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_outputs) {
      // Deliberately return a gradient whose shape does not match the input.
      return {torch::randn(10, torch::dtype(torch::kFloat).requires_grad(true))};
    }
  };

  auto input1 = torch::randn({5, 5}, torch::dtype(torch::kFloat).requires_grad(true));
  ASSERT_THROWS_WITH(
      MyFunction::apply(input1).sum().backward(), "expected shape");
  // input2 (double dtype) is set up here but the dtype-mismatch case is not
  // exercised in this test.
  auto input2 = torch::randn(10, torch::dtype(torch::kDouble).requires_grad(true));
}

TEST(CustomAutogradTest, NoGradInput) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext*, Variable x) {
      return x;
    }

    static variable_list backward(AutogradContext*, variable_list grad_outputs) {
      return grad_outputs;
    }
  };

  Variable x = torch::randn({5, 5}, torch::requires_grad());
  Variable y;
  {
    at::NoGradGuard no_grad;
    y = MyFunction::apply(x);
  }

  ASSERT_TRUE(x.requires_grad());
  ASSERT_FALSE(y.grad_fn());
}

TEST(CustomAutogradTest, TooManyGrads) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext*, Variable input) {
      return input;
    }

    static variable_list backward(AutogradContext*, variable_list grad_output) {
      grad_output.insert(grad_output.end(), {Variable(), Variable()});
      return grad_output;
    }
  };
}

TEST(CustomAutogradTest, DepNoGrad) {
  struct F1 : public Function<F1> {
    static variable_list forward(AutogradContext *ctx, Variable input) {
      auto out = torch::randn(input.sizes());
      ctx->mark_non_differentiable({out});
      return {input, out};
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      return {grad_output[0]};
    }
  };

  struct F2 : public Function<F2> {
    static Variable forward(AutogradContext*, Variable input, Variable ignore) {
      return input;
    }

    static variable_list backward(AutogradContext*, variable_list grad_output) {
      return {grad_output[0], Variable()};
    }
  };

  auto x = torch::randn(5, torch::requires_grad());
  auto out = F1::apply(x);
  Variable &a = out[0], &b = out[1];
  b = b + 1; // Separate F1 and F2 by another operation
  ASSERT_TRUE(a.requires_grad());
  ASSERT_FALSE(b.requires_grad());

  auto c = F2::apply(a, b);
  c.backward(torch::ones(c.sizes()), false, false);
  ASSERT_VARIABLE_EQ(x.grad(), torch::ones(x.sizes()));
}

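// Reentrant backward: this Function's backward() itself calls backward() on a
// small graph built inside forward(), which re-enters the autograd engine.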
TEST(CustomAutogradTest, Reentrant) {
  static Variable y_data = torch::randn({2, 2});
  struct Reenter : public Function<Reenter> {
    static Variable forward(AutogradContext *ctx, Variable input) {
      Variable output;
      {
        at::AutoGradMode enable_grad(true);
        auto x = make_variable(input.tensor_data(), true);
        auto y = make_variable(y_data.tensor_data(), true);
        output = x * y;

        ctx->saved_data["x"] = x;
        ctx->saved_data["y"] = y;
        ctx->saved_data["output_var"] = output;
      }
      return output.detach();
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      {
        at::AutoGradMode enable_grad(true);
        auto out = ctx->saved_data["output_var"].toTensor();
        out.sum().backward();
      }
      return {ctx->saved_data["x"].toTensor().grad() * grad_output[0]};
    }
  };

  auto x = torch::randn({2, 2}, torch::requires_grad());
  auto out = Reenter::apply(x);
  out.sum().backward();
  ASSERT_VARIABLE_EQ(x.grad(), y_data);
}

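// Each backward() call below re-enters apply() with x-1 until x reaches zero,
// so a start value of 8193 produces thousands of nested backward calls; the
// engine must process these without unbounded C++ stack recursion.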
TEST(CustomAutogradTest, DeepReentrant) {
  struct DeepReenter : public Function<DeepReenter> {
    static Variable forward(AutogradContext *ctx, Variable x) {
      {
        at::AutoGradMode enable_grad(true);
        ctx->saved_data["x"] = make_variable(x.tensor_data(), true) - 1;
      }
      return ctx->saved_data["x"].toTensor().detach();
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      if (!ctx->saved_data["x"].toTensor().is_nonzero()) {
        return grad_output;
      }
      {
        at::AutoGradMode enable_grad(true);
        apply(ctx->saved_data["x"].toTensor())[0].sum().backward();
        return grad_output;
      }
    }
  };

  // This should not stack overflow
  auto v = torch::tensor(8193, torch::requires_grad());
  DeepReenter::apply(v).sum().backward();
}

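// Expected ordering: Reenter's backward runs once for each stored value
// 8, 7, ..., 0 (nine entries of `1` in `order`), and all of those reentrant
// tasks are processed before MyFunction's backward pushes the final `0`.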
TEST(CustomAutogradTest, ReentrantPriority) {
  static std::vector<int> order;

  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext*, Variable x) {
      return x;
    }

    static variable_list backward(AutogradContext*, variable_list grad) {
      order.push_back(0);
      return grad;
    }
  };

  struct Reenter : public Function<Reenter> {
    static Variable forward(AutogradContext *ctx, Variable x) {
      {
        at::AutoGradMode enable_grad(true);
        ctx->saved_data["x"] = make_variable(x.tensor_data(), true) - 1;
      }
      return ctx->saved_data["x"].toTensor().detach();
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      order.push_back(1);
      if (!ctx->saved_data["x"].toTensor().is_nonzero()) {
        return grad_output;
      }
      {
        at::AutoGradMode enable_grad(true);
        apply(ctx->saved_data["x"].toTensor())[0].sum().backward();
        return grad_output;
      }
    }
  };

  auto a = MyFunction::apply(torch::tensor(6, torch::requires_grad()));
  auto b = Reenter::apply(torch::tensor(9, torch::requires_grad()));
  auto v = a * b;
  v.backward();

  // All the reentrant tasks should be prioritized over the MyFunction backward
  // task.
  ASSERT_EQ(order.size(), 10);
  ASSERT_EQ(std::count(order.begin(), order.end(), 1), 9);
  ASSERT_EQ(order.back(), 0);
}

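// Hook accounting: the first backward fires the x hook (+0) and hook_1 (+1),
// giving counter == 1; after hook_2 (+2) is added the second backward adds 3
// (counter == 4); after removing hook_2 the third backward adds 1 (counter == 5).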
TEST(CustomAutogradTest, Hooks) {
  Variable x = torch::ones({5, 5}, torch::requires_grad());
  Variable y = torch::ones({5, 5}) * 4;
  y.set_requires_grad(true);

  int counter = 0;

  std::function<void(int, Variable)> bw_hook([&counter](int inc, Variable grad) {
    counter += inc;
  });

  Variable z = x * x + x * 2 + x * y + y;
  x.register_hook([&bw_hook](Variable grad) {
    bw_hook(0, grad);
  });
  auto hook_1 = z.register_hook([&bw_hook](Variable grad) {
    bw_hook(1, grad);
  });
  z.backward(torch::ones({5, 5}), true, true);
  ASSERT_EQ(counter, 1);

  auto hook_2 = z.register_hook([&bw_hook](Variable grad) {
    bw_hook(2, grad);
  });
  z.backward(torch::ones({5, 5}), true, true);
  ASSERT_EQ(counter, 4);

  z.remove_hook(hook_2);
  z.backward(torch::ones({5, 5}), true, true);
  ASSERT_EQ(counter, 5);

  std::function<Variable(Variable)> bw_hook_modify([](Variable grad) {
    return grad.mul(2);
  });

  z.remove_hook(hook_1);
  z.register_hook(bw_hook_modify);
  y.grad().zero_();
  z.backward(torch::ones({5, 5}), true, false);
  ASSERT_VARIABLE_EQ(y.grad(), (x + 1) * 2);

  y.register_hook(bw_hook_modify);
  y.grad().zero_();
  z.backward(torch::ones({5, 5}), false, false);
  ASSERT_VARIABLE_EQ(y.grad(), (x + 1) * 4);

  ASSERT_THROWS_WITH(y.remove_hook(3), "Invalid index");
}

TEST(CustomAutogradTest, HookNone) {
  struct NoneGradientFunction : public Function<NoneGradientFunction> {
    static variable_list forward(AutogradContext *ctx, Variable x, Variable y) {
      return {x, y};
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad) {
      return {grad[0], Variable()};
    }
  };

  bool was_called = false;

  auto hook = ([&was_called](Variable grad) {
    ASSERT_TRUE(grad.defined());
    was_called = true;
  });

  auto x = torch::randn({5, 5}, torch::requires_grad());
  auto y = torch::randn({5, 5});

  auto out = NoneGradientFunction::apply(x, y);
  Variable rx = x[0], ry = x[1];

  rx.register_hook(hook);
  ry.register_hook(hook);
  (rx + ry).sum().backward();
  ASSERT_TRUE(was_called);
}

// TODO add these tests if needed
// test_once_differentiable
// test_sparse_backward
// test_save_output_nr
// test_free_deep_graph_pyfunction
// test_naughty_anomaly_access
// test_naughty_autograd-function_stashing_ctx
// test_custom_autograd_repeated_grad_grad
// test_return_leaf
// test_anomaly_detect_nan
// test_no_grad_copy