pytorch/test/cpp/jit/test_autodiff.cpp
David Berard 99bc978b78 [JIT] Propagate requires_grad to autodiff subgraphs (#71666)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71666

When JIT autodiff constructs a gradient computation graph, it only adds gradients for tensors that require grad. Previously, requires_grad information was **not** propagated to the subgraph that autodiff used; as a result, autodiff would compute *all* gradients, even for tensors whose requires_grad was never set during the profiling runs. In some cases this causes real performance overhead: during training, for example, the gradient of the input data is usually not needed but was still computed.

This propagates requires_grad to the subgraph passed into autodiff, so that autodiff will not compute unnecessary gradients.
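As an illustration, the sketch below shows the kind of propagation this change performs. It is a minimal, hypothetical helper (the function name and exact placement are assumptions, not the code in this PR): it copies the profiled `requires_grad` bit from a `prim::DifferentiableGraph` node's inputs onto the corresponding inputs of its subgraph, so that `differentiate()` can skip gradients that are never needed.

```cpp
#include <ATen/core/jit_type.h>
#include <torch/csrc/jit/ir/ir.h>

// Hypothetical sketch, not the exact implementation in this PR: transfer the
// requires_grad information recorded on the outer graph onto the autodiff
// subgraph's inputs before the subgraph is differentiated.
static void propagateRequiresGradToSubgraph(torch::jit::Node* diff_node) {
  using namespace torch::jit;
  auto subgraph = diff_node->g(attr::Subgraph);
  for (size_t i = 0; i < diff_node->inputs().size(); ++i) {
    auto outer = diff_node->input(i)->type()->cast<c10::TensorType>();
    auto inner = subgraph->inputs()[i]->type()->cast<c10::TensorType>();
    if (outer && inner && outer->requiresGrad().has_value()) {
      // Overwrite the subgraph input type with the profiled requires_grad bit.
      subgraph->inputs()[i]->setType(
          inner->withRequiresGrad(outer->requiresGrad()));
    }
  }
}
```

With this information available, `differentiate()` can leave out the backward computation for inputs such as the training data, which is what the new `AutodiffRemoveUnusedGradientsTest.Linear` test below checks.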

Test: `./bin/test_jit --gtest_filter="AutodiffRemoveUnusedGradientsTest.Linear"`

Test Plan: Imported from OSS

Reviewed By: eellison

Differential Revision: D33725304

Pulled By: davidberard98

fbshipit-source-id: ca7ab4c9a6a26f94f93aff2d5a4135e125323ba1
(cherry picked from commit a97fe0556d)
2022-01-28 18:57:36 +00:00

#include <gtest/gtest.h>
#include "test/cpp/jit/test_utils.h"
#include "torch/csrc/jit/frontend/tracer.h"
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
#include "torch/csrc/jit/passes/constant_propagation.h"
#include "torch/csrc/jit/passes/create_autodiff_subgraphs.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"
#include "torch/csrc/jit/passes/graph_fuser.h"
#include "torch/csrc/jit/passes/lower_grad_of.h"
#include "torch/csrc/jit/passes/requires_grad_analysis.h"
#include "torch/csrc/jit/passes/shape_analysis.h"
#include "torch/csrc/jit/passes/utils/subgraph_utils.h"
#include "torch/csrc/jit/runtime/argument_spec.h"
#include "torch/csrc/jit/runtime/autodiff.h"
#include "torch/csrc/jit/runtime/graph_iterator.h"
#include "torch/csrc/jit/runtime/profiling_graph_executor_impl.h"
#include "torch/torch.h"
#include <ATen/ATen.h>
#include "torch/csrc/autograd/engine.h"
#include "torch/csrc/autograd/generated/variable_factories.h"
#include "torch/csrc/autograd/variable.h"
namespace torch {
namespace jit {
using namespace torch::autograd;
using var_meta_type = std::vector<int64_t>;
using var_meta_list = std::vector<var_meta_type>;
using test_fn_type = std::function<variable_list(const variable_list&)>;
struct ADTestSpec {
  ADTestSpec(
      const char* name,
      // NOLINTNEXTLINE(modernize-pass-by-value)
      var_meta_list input_meta,
      // NOLINTNEXTLINE(modernize-pass-by-value)
      test_fn_type test_fn,
      float clampMax = -1.0f)
      : name(name),
        input_meta(input_meta),
        test_fn(test_fn),
        clampMax(clampMax) {}

  variable_list operator()(const variable_list& inputs) const {
    return test_fn(inputs);
  };

  std::vector<Variable> make_vars() const {
    std::vector<Variable> out;
    for (const auto& m : input_meta) {
      if (clampMax > 0.0f) {
        out.push_back(torch::randn(m, at::requires_grad(true))
                          .clamp(-clampMax, clampMax));
        continue;
      }
      out.push_back(torch::randn(m, at::requires_grad(true)));
    }
    return out;
  }

  const char* name;
  var_meta_list input_meta;
  test_fn_type test_fn;
  float clampMax;
};

variable_list get_grad_outputs(const variable_list& vars) {
  return fmap(vars, [](const Variable& v) -> Variable {
    return at::randn(v.sizes(), v.options());
  });
}

variable_list grad(
    const variable_list& outputs,
    const variable_list& inputs,
    const variable_list& grad_outputs) {
  const auto get_edge = [](const Variable& v) {
    return torch::autograd::impl::gradient_edge(v);
  };
  auto& engine = torch::autograd::Engine::get_default_engine();
  return engine.execute(
      fmap(outputs, get_edge),
      grad_outputs,
      /*keep_graph=*/true,
      /*create_graph=*/false,
      /*accumulate_grad=*/false,
      fmap(inputs, get_edge));
}

TEST(AutodiffTest, ADFormulas) {
  const auto cast = [](const Variable& v) {
    return static_cast<at::Tensor>(v);
  };

  using VL = variable_list;
  const var_meta_list binary_pointwise = {{2, 3, 4, 5}, {2, 3, 4, 5}};
  const var_meta_list unary_pointwise = {{2, 3, 4, 5}};
  const var_meta_list unary_pointwise_2d = {{2, 3}};
  const std::vector<ADTestSpec> ad_tests = {
      {"add",
       binary_pointwise,
       [](const VL& v) -> VL { return {v[0] + v[1]}; }},
      {"sub",
       binary_pointwise,
       [](const VL& v) -> VL { return {v[0] - v[1]}; }},
      {"mul",
       binary_pointwise,
       [](const VL& v) -> VL { return {v[0] * v[1]}; }},
      {"sigmoid",
       unary_pointwise,
       [](const VL& v) -> VL { return {v[0].sigmoid()}; }},
      // Clamp tanh input tensor values to [-3, 3]
      // to set a minimum on gradient absolute values
      {"tanh",
       unary_pointwise,
       [](const VL& v) -> VL { return {v[0].tanh()}; },
       3.0f},
      {"t", unary_pointwise_2d, [](const VL& v) -> VL { return {v[0].t()}; }},
      {"view",
       unary_pointwise_2d,
       [](const VL& v) -> VL {
         return {v[0].view({3, 2})};
       }},
      {"expand",
       {{2, 1}},
       [](const VL& v) -> VL {
         return {v[0].expand({2, 3})};
       }},
      {"mm",
       {{10, 12}, {12, 15}},
       [](const VL& v) -> VL { return {v[0].mm(v[1])}; }},
      // TODO: enable once we'll be able to capture lists across
      // forward-backward
      //{"chunk", {{10, 12, 15}}, [](const VL& v) -> VL { return
      // fmap<Variable>(v[0].chunk(4, 1)); }},
      //{"chunk", {{10, 12, 15}}, [](const VL& v) -> VL { return
      // fmap<Variable>(v[0].chunk(3, 2)); }},
      //{"split", {{10, 12, 15}}, [](const VL& v) -> VL { return
      // fmap<Variable>(v[0].split(4, 1)); }},
      //{"split", {{10, 12, 15}}, [](const VL& v) -> VL { return
      // fmap<Variable>(v[0].split(3, 2)); }},
  };

  for (const auto& test : ad_tests) {
    // Get reference values from autograd
    auto vars_in = test.make_vars();
    auto vars_out = test(vars_in);
    auto var_grads_in = get_grad_outputs(vars_out);
    auto var_grads_out = grad(vars_out, vars_in, var_grads_in);

    // Trace and differentiate the op
    auto graph = tracer::trace(
                     fmap<IValue>(vars_in),
                     [&test](Stack in) -> Stack {
                       auto ivalue_inps = fmap(in, [](const IValue& v) {
                         return Variable(v.toTensor());
                       });
                       return fmap<IValue>(test(ivalue_inps));
                     },
                     [](const Variable& var) { return ""; })
                     .first->graph;
    EliminateDeadCode(graph); // Tracing of some ops depends on the DCE trick
    ConstantPropagation(graph);
    auto grad_spec = differentiate(graph);
    LowerGradOf(*grad_spec.df);

    // Get outputs from the interpreter
    auto tensors_in = fmap(vars_in, cast);
    auto tensor_grads_in = fmap(var_grads_in, cast);
    tensor_list tensors_out, tensor_grads_out;
    std::tie(tensors_out, tensor_grads_out) =
        runGradient(grad_spec, tensors_in, tensor_grads_in);

    // Compare results
    auto expected_tensors_out = fmap(vars_out, cast);
    auto expected_tensor_grads_out = fmap(var_grads_out, cast);
    assertAllClose(tensors_out, expected_tensors_out);
    assertAllClose(tensor_grads_out, expected_tensor_grads_out);
  }
}

TEST(AutodiffTest, Differentiate) {
  // Note: can't use IRParser for this test due to issue #23989
  auto graph = std::make_shared<Graph>();
  std::vector<int64_t> sizes{2, 3, 4};
  std::vector<int64_t> strides{12, 4, 1};
  const auto type = TensorType::create(
      at::ScalarType::Float,
      at::kCPU,
      c10::VaryingShape<int64_t>{sizes},
      c10::VaryingShape<int64_t>{strides},
      true);

  // Builds graph a * b * a + b
  auto* a = graph->addInput()->setType(type);
  auto* b = graph->addInput()->setType(type);
  auto* cOne = graph->insertConstant(1);

  auto* ab = graph->insertNode(graph->create(aten::mul, /*num_outputs =*/1));
  ab->addInput(a);
  ab->addInput(b);

  auto* aba = graph->insertNode(graph->create(aten::mul, /*num_outputs =*/1));
  aba->addInput(ab->output());
  aba->addInput(a);

  auto* abaplusb =
      graph->insertNode(graph->create(aten::add, /*num_outputs =*/1));
  abaplusb->addInput(aba->output());
  abaplusb->addInput(b);
  abaplusb->addInput(cOne);

  graph->registerOutput(abaplusb->output());

  auto grad_spec = differentiate(graph);
  std::vector<size_t> expected_captured_inputs = {0, 1};
  std::vector<size_t> expected_captured_outputs = {1, 2, 3, 4, 5, 6, 7};
  std::vector<size_t> expected_input_vjps = {0, 1};
  std::vector<size_t> expected_output_vjps = {0, 1};
  ASSERT_EQ(grad_spec.f_real_outputs, 1);
  ASSERT_EQ(grad_spec.df_input_captured_inputs, expected_captured_inputs);
  ASSERT_EQ(grad_spec.df_input_captured_outputs, expected_captured_outputs);
  ASSERT_EQ(grad_spec.df_input_vjps, expected_input_vjps);
  ASSERT_EQ(grad_spec.df_output_vjps, expected_output_vjps);

  testing::FileCheck()
      .check_count("aten::mul", 2)
      ->check("aten::size")
      ->check("aten::add")
      ->run(*grad_spec.f);
  testing::FileCheck()
      .check("prim::GradOf[name=\"aten::add\"]")
      ->check_count("prim::GradOf[name=\"aten::mul\"]", 2)
      ->check_count("AutogradAdd", 2)
      ->run(*grad_spec.df);
}

TEST(AutodiffTest, DifferentiateWithRequiresGrad) {
  const auto graph_string = R"IR(
    graph(%0 : Tensor,
          %1 : Tensor):
      %2 : int = prim::Constant[value=1]()
      %3 : Tensor = aten::mul(%1, %1)
      %4 : Tensor = aten::add(%3, %1, %2)
      %5 : Tensor = aten::add(%4, %0, %2)
      %6 : Tensor = aten::mul(%5, %0)
      %7 : Tensor = aten::add(%6, %1, %2)
      return (%4, %7))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());

  auto a_var = autograd::make_variable(
      at::empty_strided(2, 2, at::CPU(at::kFloat).options()), true);
  auto b_var = autograd::make_variable(
      at::empty_strided(2, 2, at::CPU(at::kFloat).options()), false);

  ArgumentSpecCreator asc(*g);
  asc.specializeTypes(*g, asc.create(true, {a_var, b_var}));

  PropagateInputShapes(g);
  PropagateRequiresGrad(g);

  auto grad_spec = differentiate(g);
  std::vector<size_t> expected_input_vjps = {1, 2}; // for e and %4 = (d + a)
  std::vector<size_t> expected_output_vjps = {0}; // only a requires grad
  ASSERT_EQ(grad_spec.f_real_outputs, 2);
  ASSERT_EQ(grad_spec.df_input_captured_inputs, std::vector<size_t>({0}));
  ASSERT_EQ(
      grad_spec.df_input_captured_outputs,
      std::vector<size_t>({2, 3, 4, 5, 6}));
  ASSERT_EQ(grad_spec.df_input_vjps, expected_input_vjps);
  ASSERT_EQ(grad_spec.df_output_vjps, expected_output_vjps);

  testing::FileCheck()
      .check("aten::mul")
      ->check_count("aten::add", 2)
      ->check("aten::mul")
      ->check("aten::size")
      ->check("aten::add")
      ->run(*grad_spec.f);
  testing::FileCheck()
      .check_count("prim::GradOf[name=\"aten::mul\"]", 1, /*exactly*/ true)
      ->run(*grad_spec.df);
}

class AutodiffRemoveUnusedGradientsTest : public ::testing::Test {
 protected:
  void SetUp() override {
    prev_exec = getExecutorMode();
    getExecutorMode() = true;
    prev_profiling = getProfilingMode();
    getProfilingMode() = true;
    prev_inline_autodiff = getAutodiffSubgraphInlining();
    debugSetAutodiffSubgraphInlining(false);
  }

  void TearDown() override {
    getExecutorMode() = prev_exec;
    getProfilingMode() = prev_profiling;
    debugSetAutodiffSubgraphInlining(prev_inline_autodiff);
  }

  bool prev_exec;
  bool prev_profiling;
  bool prev_inline_autodiff;
};

TEST_F(AutodiffRemoveUnusedGradientsTest, Linear) {
  auto graph = std::make_shared<Graph>();
  const std::string input =
      R"IR(
graph(%inp.1 : Tensor,
      %weight.1 : Tensor,
      %bias.1 : Tensor):
  %6 : Tensor = aten::linear(%inp.1, %weight.1, %bias.1)
  return (%6))IR";
  parseIR(input, graph.get());

  auto inp = torch::randn({10, 10}).requires_grad_(false);
  auto weight = torch::randn({10, 10}).requires_grad_(true);
  auto bias = torch::randn({1, 10}).requires_grad_(true);
  auto stack = createStack({inp, weight, bias});

  ProfilingGraphExecutorImpl executor(graph, "linear");

  // initial run to profile requires_grad information
  auto plan = executor.getPlanFor(stack, 20);
  InterpreterState is{plan.code};
  is.run(stack);

  // After profiling, the optimized plan is expected to contain a
  // prim::DifferentiableGraph node; find it.
  auto optimized_plan = executor.getPlanFor(stack, 20);
  DepthFirstGraphNodeIterator it(optimized_plan.graph);
  Node* diff_graph_node = nullptr;
  while ((diff_graph_node = it.next()) != nullptr) {
    if (diff_graph_node->kind() == prim::DifferentiableGraph) {
      break;
    }
  }
  ASSERT_NE(nullptr, diff_graph_node);

  auto backward_graph = diff_graph_node->g(attr::ReverseSubgraph);

  // we expect to compute grad_weight (which requires a matmul) but we don't
  // expect to compute grad_input. So, we expect exactly 1 matmul.
  // Note: this could change, e.g. if mm is used instead
  testing::FileCheck().check_count("matmul", 1, true)->run(*backward_graph);
}

} // namespace jit
} // namespace torch