pytorch/torch/csrc/jit/interpreter_autograd_function.cpp
Zachary DeVito 38bc732b2d
[jit] Change interpreter/fuser to work on Variables only (#7489)
* this removes the flag controlling whether the interpreter works on variables.
* now the interpreter _always_ works on variables
* constants in the IR are still _always_ non-variables, and an assert was added to ensure this.
* as_tensor was split into as_variable and as_tensor since it is sometimes used
  to construct constants in the IR
* I tried changing the IR to also always use variables, but that change was much more
  cross-cutting and fragile, and I never got it working
2018-05-11 13:33:47 -07:00

#include "torch/csrc/autograd/edge.h"
#include "torch/csrc/autograd/function.h"
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/jit/interpreter.h"
#include "torch/csrc/jit/interpreter_autograd_function.h"
#include "torch/csrc/jit/ir.h"
#include "torch/csrc/jit/tracer.h"
#include "torch/csrc/jit/tracer_state.h"
#include "torch/csrc/jit/variable_flags.h"
#include <ATen/ATen.h>
#include <algorithm>
#include <cstdint>
#include <memory>
#include <tuple>
#include <utility>
#include <vector>
namespace torch { namespace jit {

using namespace torch::jit::tracer;
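
// zeroTensorWithType: builds a zero-filled Variable matching a traced
// TensorType (device, scalar type, sizes). apply() uses it below as a
// stand-in when the trace expects a defined input but the caller passes
// an undefined (i.e. zero) gradient.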
static at::Tensor zeroTensorWithType(const TensorType & type) {
  auto device = (type.device() < 0) ? at::kCPU : at::kCUDA;
  auto & at_type = at::getType(device, type.scalarType());
  // note: this has to be a contiguous tensor of zeros, because the fusion engine
  // is specialized to what is normally here, which might be fully dense
  return autograd::make_variable(at::zeros(at_type, type.sizes()));
}
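
// apply() runs one interpreter stage: it validates the incoming Variables
// against the flags recorded at trace time, executes the stage on a value
// stack, lazily builds an InterpreterAutogradFunction for the next stage as
// the grad_fn of the outputs, and wraps the results as Variables.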
autograd::variable_list InterpreterAutogradFunction::apply(
    const autograd::variable_list& inputs) {
  // Initial correctness checks.
  if (stage_ == stage_details_.size()) {
    throw std::runtime_error(std::string("Function compiled only for ") +
        std::to_string(stage_details_.size() - 1) + " derivatives. Use nderivs argument " +
        "to request more.");
  }
  if (used_) throw std::runtime_error(autograd::ERR_BACKWARD_TWICE);
  used_ |= !keep_graph_;

  const auto & details = stage_details_[stage_];

  // Validate inputs
  std::vector<at::Tensor> stack;
  stack.reserve(inputs.size());
  TORCH_ASSERT(inputs.size() == num_inputs_);
  TORCH_ASSERT(inputs.size() == details.input_flags.size());
  for (std::size_t i = 0; i < (std::size_t)inputs.size(); ++i) {
    auto actual_flags = VariableFlags::of(inputs[i]);
    auto traced_flags = details.input_flags[i];
    // Check that this trace is general enough to handle the input
    // flags of the actual tensor. We can't handle the following two cases,
    // because we won't have a trace containing computation of either
    // the tensor itself (not defined) or the stage for its gradient (requires_grad=False).
    if (!traced_flags.defined && actual_flags.defined) {
      throw std::runtime_error("JIT interpreter received a defined input, but the"
        " trace was compiled with the input being undefined.");
    }
    if (!traced_flags.requires_grad && actual_flags.requires_grad) {
      throw std::runtime_error("JIT interpreter received an input with"
        " requires_grad=True, but was compiled with requires_grad=False.");
    }
    // The remaining cases we can handle. If the gradient was not
    // required but the trace will compute it, then we just compute it and
    // ignore the result.
    // However, if we are passed an undefined tensor, but the trace
    // expects a defined tensor, then we have to give it one.
    // Undefined tensors are used as stand-ins for zero tensors, so
    // we create a zero-filled tensor of the right size.
    if (!actual_flags.defined) {
      // [Temporary workaround for variants] until tracer produces all variants:
      // This case appears commonly when you have a function
      //   x, y = fn(z)
      // and only use x; then the gradient for y will be undefined.
      // If you reuse the same trace and only _sometimes_ use y,
      // then in the cases where you don't use it, the grad_y input in stage 1
      // will be undefined. To ensure we can continue, we create a zero gradient,
      // using trace information to figure out what shape it should be.
      if (traced_flags.defined) {
        stack.push_back(zeroTensorWithType(interp_.tensorTypeForInput(i)));
      } else {
        stack.push_back(at::Tensor());
      }
    } else {
      stack.push_back(inputs[i].detach());
    }
  }
  // Run the interpreter. If the graph is kept for re-use (keep_graph_), run on a
  // clone so that interp_ itself is not consumed and this stage can run again.
  InterpreterState interp = (keep_graph_) ? interp_.clone() : interp_;
  interp.runOneStage(stack);

  // Lazily create grad_fn
  std::shared_ptr<Function> grad_fn;
  auto make_grad_fn = [&]() {
    grad_fn = std::make_shared<InterpreterAutogradFunction>(
        std::move(interp), stage_details_, stage_ + 1);
    // Running this next stage is actually not valid (nderivs is too low),
    // but we don't know if the user will ever ask for it, so we don't error out here.
    // Instead we have to return early, because we rely on stage_details_[stage_ + 1]
    // in the remaining code.
    if (stage_ + 1 == stage_details_.size())
      return;
    // Patch next_edges to include the previous stage's next_edges.
    // This is needed because stage N is really a derivative of
    // all stages from 1 to N-1. If a part of stage x's graph is
    // reused in stage y (y > x), it is inlined by the tracer,
    // and so we need to copy next_fns, because those Variables
    // aren't real inputs to that stage, so that's the only place
    // where we can get them.
    for (auto copied_idx : stage_details_[stage_ + 1].copied_next_fns) {
      grad_fn->add_next_edge(next_edges_[copied_idx]);
    }
    // Add grad_fns corresponding to inputs
    for (size_t i = 0; i < inputs.size(); ++i) {
      // If an input isn't used, there's no gradient for it, and the next stage
      // won't even have its grad in the trace. Don't create an entry for it.
      if (!details.used_inputs[i]) continue;
      auto & input = inputs[i];
      if (!details.input_flags[i].requires_grad) {
        continue; // See Note [Null-edge pruning]
      } else if (!input.defined() || !input.requires_grad()) {
        // See Note [Temporary workaround for variants]
        grad_fn->add_next_edge({});
        continue;
      }
      grad_fn->add_next_edge(input.gradient_edge());
    }
  };

  // Wrap the outputs
  // TODO: handle views
  autograd::variable_list result;
  JIT_ASSERT(stack.size() == details.output_flags.size());
  auto num_outputs = stack.size();
  for (std::size_t i = 0; i < num_outputs; ++i) {
    auto & flags = details.output_flags[i];
    if (flags.requires_grad) { // See Note [Null-edge pruning]
      if (!grad_fn) make_grad_fn();
      auto variable = static_cast<const Variable&>(stack[i]);
      autograd::create_gradient_edge(variable, grad_fn);
      result.push_back(std::move(variable));
    } else {
      result.push_back(static_cast<const Variable&>(stack[i]));
    }
  }
  return result;
}
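
// The factory compiles the traced graph once into jit::Code and records, per
// stage, the traced input/output VariableFlags, which inputs are actually
// used, and which output edges are shared with the previous stage
// (copied_next_fns). construct() then stamps out a fresh stage-0 autograd
// function that shares this compiled code.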
InterpreterFunctionFactory::InterpreterFunctionFactory(TracingState *state) {
  code_ = jit::Code(state->graph);
  stage_details_.resize(state->graph->stage() + 1);
  auto graph_inputs = state->graph->inputs();
  auto inputs_it = graph_inputs.begin();
  for (std::size_t stage = 0; stage < state->graph->stage() + 1; ++stage) {
    auto & details = stage_details_[stage];
    std::tie(details.input_flags, details.output_flags) = std::move(state->var_flags[stage]);
    for (std::size_t i = 0; inputs_it != graph_inputs.end() && (*inputs_it)->stage() == stage; ++i, ++inputs_it) {
      details.used_inputs.push_back((*inputs_it)->uses().size() > 0);
    }
    if (stage >= 1) {
      auto & current_outputs = state->output_edges[stage];
      auto & prev_outputs = state->output_edges[stage - 1];
      for (auto & output : current_outputs) {
        // Check if output appears in outputs of previous stage
        auto prev_it = std::find(prev_outputs.begin(), prev_outputs.end(), output);
        if (prev_it == prev_outputs.end()) continue;
        // If yes, find its index and append that to the list of edges that will need
        // to be copied in InterpreterAutogradFunction.
        details.copied_next_fns.push_back(std::distance(prev_outputs.begin(), prev_it));
      }
    }
  }
}

std::shared_ptr<InterpreterAutogradFunction> InterpreterFunctionFactory::construct() {
  return std::make_shared<InterpreterAutogradFunction>(code_, stage_details_);
}
}}
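
// A minimal usage sketch (hypothetical call site, not part of this file),
// assuming a populated tracer::TracingState* `state` and a variable_list
// `inputs` whose flags match the trace:
//
//   torch::jit::InterpreterFunctionFactory factory(state);
//   auto fn = factory.construct();                    // fresh stage-0 function
//   autograd::variable_list outputs = fn->apply(inputs);
//
// Calling apply() on the grad_fn attached to `outputs` would then run stage 1,
// provided the trace was compiled with enough derivative stages (nderivs).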