pytorch/torch/csrc/autograd/function.cpp
Zach DeVito ef4b19f767 Refactor ir.h to distinguish Nodes and Values
This commit adds a Value type similar to the one @ezyang suggested a while
ago for handling multi-return nodes.

Previously if we had a graph like:

  a = op1(b)
  c, d = op2(a)

Then its in-memory format would look like:

  %0 = op1(b)
  %1 = op2(%0)
  %2 = select(%1, 0)
  %3 = select(%1, 1)

Select nodes were used only to handle the multi-output case; in the
single-output case, ops referred directly to their uses.

This required special handling for the single- and multi-output cases, and it
was confusing when used with ONNX, which distinguishes values (the
inputs/outputs of a node) from the nodes themselves (e.g. a Conv).

This commit adds the Node/Value distinction to the IR. In the example
above, `a`, `b`, `c`, and `d` are now Value objects, while `op1` and
`op2` are now Node objects. The inputs and outputs of the graph are Values.
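
With this representation the example needs no `select` nodes; `op2` simply
owns two output Values (the printed form shown here is illustrative):

  %a = op1(%b)
  %c, %d = op2(%a)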

* Nodes now always have multiple outputs, accessible through their `outputs()`
  method.
* Methods exist for adding/removing outputs from a node.
* Nodes own their output Values; destroying a node destroys its outputs, and it
  is only valid to destroy a node when no uses of its outputs remain.
* Unlike the old select nodes, Values do not appear in the graph's node list.
* The method `node()` on `Value` retrieves its defining Node and is always
  valid to call. For graph inputs, the defining node's kind is "Param"; like
  "Return", there is a single Param node representing all inputs.
* For single-output Nodes, the method `output()` retrieves the single output
  Value, asserting that the node is in fact single-output (see the sketch after
  this list).
* Most member functions are the same, but some, like `type()`, have moved to
  Value.
* `replaceAllUsesWith` is now sanely defined for both Values and Nodes. For a
  Node, it replaces all uses of the node's outputs with the corresponding
  outputs of the replacement node.
* `stage` is defined on both Node and Value, because graph inputs (which are
  Values) also require a stage.
* Apart from changing data types from Node to Value, most passes remain the
  same. Code that previously assumed single-output nodes now has to call
  `output()` to get the Value.
* This removes the `uses = [...]` field from the printed outputs because it was
  getting confusing even before this commit, when uses referred to nodes while
  we print the names of Values. The lint pass validates the use list, so
  printing it out seems less necessary.
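
As a rough illustration of the API described above, here is a minimal C++
sketch that builds the example graph with the new Node/Value types.
`build_example` is an illustrative helper, and the sketch only assumes the
post-refactor `torch::jit` interface named in the bullets (`addInput`,
`create`, `appendNode`, `output`, `outputs`, `node`, `registerOutput`); the
exact way op kinds (`Symbol`) are constructed and the exact signatures may
differ between versions:

  #include "torch/csrc/jit/ir.h"  // include path as used elsewhere in this commit
  #include <cassert>

  void build_example() {
    using namespace torch::jit;
    Graph graph;

    // Graph inputs are Values, owned by the single Param node.
    Value* b = graph.addInput();

    // a = op1(b): one Node with a single output Value.
    Node* op1 = graph.appendNode(graph.create(Symbol("op1"), /*num_outputs=*/1));
    op1->addInput(b);
    Value* a = op1->output();  // output() asserts op1 is single-output

    // c, d = op2(a): one Node owning two output Values; no select() nodes.
    Node* op2 = graph.appendNode(graph.create(Symbol("op2"), /*num_outputs=*/2));
    op2->addInput(a);
    Value* c = op2->outputs()[0];
    Value* d = op2->outputs()[1];

    // Every Value knows its defining Node.
    assert(c->node() == op2 && d->node() == op2);

    // Replacing op2 would rewire all uses of c and d to the replacement's
    // outputs: op2->replaceAllUsesWith(replacement);

    graph.registerOutput(c);
    graph.registerOutput(d);
  }

The key point is that `op2` directly owns `c` and `d`: consumers reference
those Values, so no `select` indirection is needed.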
2017-11-15 11:47:18 -08:00

#include "Python.h"
#include "function.h"
#include <string>
#include "variable.h"
#include "torch/csrc/jit/ir.h"
#include "torch/csrc/autograd/functions/special.h"
namespace torch { namespace autograd {

// Compute the FunctionFlags for a sequence of input Variables: the function is
// executable if any defined input requires grad, volatile if any input is
// volatile, and next_functions records the (grad_fn, output_nr) edge (or the
// grad accumulator) for each input.
template<typename T>
auto makeFlags(const T &inputs) -> FunctionFlags {
  int num_inputs = inputs.size();
  FunctionFlags f;
  f.is_executable = false;
  f.is_volatile = false;
  f.next_functions.resize(num_inputs);
  {
    int i = 0;
    for (auto it = inputs.begin(); it != inputs.end(); ++it, ++i) {
      auto& var = *it;
      if (var.defined()) {
        f.is_executable |= var.requires_grad();
        f.is_volatile |= var.is_volatile();
        if (var.grad_fn()) {
          f.next_functions[i] = std::make_pair<>(var.grad_fn(), var.output_nr());
        } else {
          f.next_functions[i] = std::make_pair<>(var.grad_accumulator(), 0);
        }
      }
    }
  }
  f.is_executable &= !f.is_volatile;
  return f;
}

auto Function::flags(const variable_list& inputs) -> FunctionFlags {
  return makeFlags(inputs);
}

auto Function::flags(const std::initializer_list<Variable>& inputs) -> FunctionFlags {
  return makeFlags(inputs);
}

auto Function::flags(at::TensorList inputs) -> FunctionFlags {
  // TODO: Eliminate the intermediate vector allocation
  return makeFlags(variable_list(inputs.begin(), inputs.end()));
}

auto Function::name() -> std::string {
  return std::string(typeid(*this).name());
}

// This function is analogous to make_trace which operates on PythonOp, but this
// function instead works for C++ implemented autograd Functions, which don't
// actually have any backing Python class. We still need to trace them!
variable_list Function::tracedApply(variable_list inputs) {
  using namespace torch::jit;
  // Traceable Functions are completely transparent to the JIT.
  if (is_traceable()) {
    return apply(inputs);
  }
  auto state = tracer::getTracingState(inputs);
  auto state_lock = state->lock();

  // Insert a CppOp in the trace.
  auto& graph = state->graph;
  auto* this_node = graph->createCppOp(getSharedPtr());
  for (auto& input: inputs) {
    this_node->addInput(tracer::getValueTrace(state, input));
  }
  graph->appendNode(this_node);

  // Finally apply this Function.
  state_lock.unlock();
  variable_list outputs = apply(inputs);
  state_lock.lock();

  // Set up output traces.
  int num_outputs = outputs.size();
  for (int i = 0; i < num_outputs; ++i) {
    auto& output = outputs[i];
    auto sel = this_node->addOutput();
    // TODO: At the moment, C++ does not track shared storage. It
    // should. Update this when that happens.
    if (output.defined()) {
      sel->inferTypeFrom(output.data());
      tracer::setValueTrace(state, output, sel);
    }
  }

  if (!passes_state_transparently()) {
    auto this_eval = dynamic_cast<Eval*>(this);
    // Evals consume handle from a context edge of forward node
    if (this_eval)
      this_node->addInput(this_eval->forward_ctx_select);

    // There's no point in wrapping functions in Eval, if we know they already are
    // part of another Eval subgraph. This is both a small optimization, and
    // it allows us to not implement saved_variables() in many functions.
    bool should_trace_backward = tracing_state->in_eval_subgraph;
    if (!should_trace_backward) {
      auto saved_vars = saved_variables();
      if (!saved_vars)
        throw std::runtime_error(std::string("saved_variables() needed but not implemented in ") + name());
      variable_list bw_subgraph_inputs(inputs);
      for (auto& saved_var : *saved_vars) {
        bw_subgraph_inputs.emplace_back(saved_var.unpack(getSharedPtr()));
      }
      tracer::nontraceableBackwardSubgraph(bw_subgraph_inputs, outputs);
    }
    bool has_backwards_eval = !should_trace_backward || this_eval;
    if (has_backwards_eval)
      setUpContextEdge(this_node, inputs, outputs);
  }
  return outputs;
}

void Function::setUpContextEdge(jit::Node* node,
    const variable_list& inputs, const variable_list& outputs) {
  auto ctx_select = node->addOutput();
  ctx_select->setType(std::make_shared<jit::HandleType>());
  auto backward_eval = Eval::getBackwardEval(inputs, outputs);
  if (backward_eval)
    backward_eval->forward_ctx_select = ctx_select;
}

}} // namespace torch::autograd