pytorch/torch/csrc/jit/python_ir.cpp
Zach DeVito ef4b19f767 Refactor ir.h to distinguish Nodes and Values
This commit adds a Value type similar to the one @ezyang suggested a while
ago for handling multi-return nodes.

Previously if we had a graph like:

  a = op1(b)
  c, d = op2(a)

Then its in-memory format would look like:

  %0 = op1(b)
  %1 = op2(%0)
  %2 = select(%1, 0)
  %3 = select(%1, 1)

Select nodes existed only to handle the multi-output case; in the
single-output case, consumers referred directly to the producing op.

This required special handling for the single- and multi-output cases,
and was confusing when used with ONNX, which distinguishes values (the
inputs/outputs of a node) from the nodes themselves (e.g. a Conv).

This commit adds the Node/Value distinction to the IR. In the example
above, `a`, `b`, `c`, and `d` are now Value objects, while `op1` and
`op2` are now Node objects. Inputs and outputs of the graph are Values.
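
With the new representation the same graph can be written without select
nodes, roughly like this (an illustrative sketch, not the exact printer
output):

  %a = op1(%b)
  %c, %d = op2(%a)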

* Nodes are now always treated as having multiple outputs, accessible
  through their `outputs()` method.
* Methods exist for adding/removing outputs from a node.
* Nodes own their output Values; destroying a node destroys its outputs,
and it is only valid to destroy a node when no uses of its outputs remain.
* Unlike the old select nodes, Values do not appear in the graph's node list.
* The method `node()` on `Value` retrieves its defining node; calling it
is always valid. For graph inputs, that node's kind is "Param". Like
"Return", there is a single Param node representing all inputs.
* For single-output Nodes, the method `output()` retrieves the single
output Value, asserting that the node is in fact single-output (see the
sketch after this list).
* Most functions are the same, but some, like `type()`, have moved from
Node to Value.
* `replaceAllUsesWith` is now sanely defined for both Values and Nodes.
For Nodes, it replaces all uses of the node's outputs with the
corresponding outputs of the replacement node.
* `stage` is defined on both Node and Value; this is because graph inputs
(which are Values) require a stage.
* Apart from changing data types from Node to Value, most passes remain
  the same. Code that previously assumed single-output nodes now has to
  call `output()` to get the Value.
* This removes the `uses = [...]` field from the printed outputs; it was
getting confusing even before this commit because uses refer to nodes
while we print the names of Values. The lint pass validates the use list,
so printing it out seems less necessary.
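
Sketch of the new C++ API for the example above (a minimal illustration,
assuming the C++ IR headers; "op1"/"op2" are placeholder symbols rather
than real operators):

  Graph graph;
  Value * b = graph.addInput();
  // Single-output node: output() asserts there is exactly one output.
  Node * n1 = graph.appendNode(graph.create(stringToSymbol("op1"), {b}));
  Value * a = n1->output();
  // Multi-output node: ask create() for two outputs, read them via outputs().
  Node * n2 = graph.appendNode(graph.create(stringToSymbol("op2"), {a}, 2));
  Value * c = n2->outputs()[0];
  Value * d = n2->outputs()[1];
  graph.registerOutput(c);
  graph.registerOutput(d);
  // n2->replaceAllUsesWith(other) would redirect all uses of c and d to
  // the corresponding outputs of `other`.
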
2017-11-15 11:47:18 -08:00

#include <Python.h>
#include "torch/csrc/jit/ir.h"
#include "torch/csrc/jit/pybind.h"
#include "torch/csrc/jit/python_tracer.h"
#include "torch/csrc/utils/pybind.h"
#include <iostream>
#include <sstream>
namespace torch { namespace jit {
void initPythonIRBindings(PyObject * module_) {
  auto m = py::handle(module_).cast<py::module>();
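  // GS (and VS/NS below) expands to a pybind11 .def that binds the
  // like-named C++ method on Graph, Value, or Node respectively.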
  #define GS(name) \
    def(#name,&Graph :: name)
  py::class_<Graph,std::shared_ptr<Graph>>(m,"Graph")
    .def(py::init<>())
    .def("__repr__",[](Graph & g) {
      std::stringstream ss;
      ss << g;
      return ss.str();
    })
    .def("inputs",[](Graph &g) {
      return py::make_iterator(g.inputs().begin(), g.inputs().end());
    })
    .def("outputs",[](Graph &g) {
      return py::make_iterator(g.outputs().begin(), g.outputs().end());
    })
    // TODO: Iterator invalidation might make this hazardous
    .def("nodes",[](Graph &g) {
      return py::make_iterator(g.begin(), g.end());
    })
    .def("addInput",[](Graph &g) { return g.addInput(); })
    .GS(advanceStage)
    .GS(stage)
    .GS(eraseInput)
    .GS(registerOutput)
    .def("create",[](Graph & g, const char * str) {
      return g.create(stringToSymbol(str));
    })
    .def("create",[](Graph & g, const char * str, size_t noutputs) {
      return g.create(stringToSymbol(str), noutputs);
    })
    .def("create",[](Graph & g, const char * str, const std::vector<Value*> & inputs) {
      return g.create(stringToSymbol(str),inputs);
    })
    .def("create",[](Graph & g, const char * str, const std::vector<Value*> & inputs, size_t noutputs) {
      return g.create(stringToSymbol(str),inputs, noutputs);
    })
    .GS(createConstant)
    .GS(createFusionGroup)
    .def("createClone",[](Graph & g, Node * n, py::object fn) {
      return g.createClone(n, [&](Value * e) {
        return fn(e).cast<Value*>();
      });
    })
    .GS(appendNode)
    .GS(prependNode)
    .GS(lint)
    ;
#undef GS
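  // Value bindings. Values are owned by the IR (their defining Node/Graph),
  // so py::nodelete keeps Python from ever deleting them.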
  #define VS(name) \
    def(#name,&Value :: name)
  py::class_<Value,std::unique_ptr<Value, py::nodelete>>(m,"Value")
    .def("__repr__",[](Value & n) {
      std::stringstream ss;
      ss << n.uniqueName() << " defined in (" << *n.node() << ")";
      return ss.str();
    })
    .VS(type)
    .VS(typeOption)
    .VS(hasType)
    .VS(setType)
    .VS(inferTypeFrom)
    // skip owningGraph because it returns a raw pointer to a graph object
    // that is otherwise held by a std::shared_ptr, which would cause a double free
    .VS(debugName)
    .VS(setDebugName)
    .VS(unique)
    .VS(uniqueName)
    .VS(setStage)
    .VS(stage)
    .VS(offset)
    .VS(uses)
    .VS(isHandle)
    .VS(replaceAllUsesWith)
    .def("node",[](Value &v) { return v.node(); })
    .def("setTypeAs", [](Value * node, Value * other) {
      node->setType(other->typeOption());
      return node;
    })
    .VS(copyMetadata)
    ;
#undef VS
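  // Node bindings. Nodes are likewise owned by the Graph, hence py::nodelete.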
  #define NS(name) \
    def(#name,&Node :: name)
  py::class_<Node,std::unique_ptr<Node, py::nodelete>>(m,"Node")
    .def("__repr__",[](Node & n) {
      std::stringstream ss;
      ss << n;
      return ss.str();
    })
    .def("hasMultipleOutputs",[](Node&n) {
      return n.outputs().size() > 1;
    })
    .NS(kind)
    .NS(stage)
    .NS(setStage)
    .def("inputs",[](Node &n) {
      return py::make_iterator(n.inputs().begin(), n.inputs().end());
    })
    .def("outputs",[](Node &n) {
      return py::make_iterator(n.outputs().begin(), n.outputs().end());
    })
    .NS(output)
    .NS(addInput)
    .NS(replaceInput)
    .NS(replaceInputWith)
    .NS(replaceAllUsesWith)
    .NS(insertBefore)
    .NS(insertAfter)
    .NS(moveAfter)
    .NS(moveBefore)
    .NS(removeInput)
    .NS(removeAllInputs)
    .NS(destroy)
    .NS(hasUses)
    .NS(eraseOutput)
    .NS(addOutput)
    #define AS(name) def(#name,&Attributes<Node> :: name)
    // methods from Attributes
    .AS(copyAttributes)
    .AS(hasAttribute)
    .AS(kindOf)
    .AS(removeAttribute)
    .AS(hasAttributes)
    .AS(attributeNames)
#undef AS
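  // CREATE_ACCESSOR(Kind,method) emits a setter ("<method>_") and a getter
  // ("<method>") for the given attribute kind, e.g. f_/f for Float attributes.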
  #define CREATE_ACCESSOR(Kind,method) \
    def(#method "_",[](Node & n, const char * name, Kind##Attr::ValueType v) { \
      return n . method ## _(stringToSymbol(name), std::move(v)); \
    }) \
    .def(#method, [](Node & n, const char * name) { \
      return n.method(stringToSymbol(name)); \
    })
    .CREATE_ACCESSOR(Float,f)
    .CREATE_ACCESSOR(Floats,fs)
    .CREATE_ACCESSOR(String,s)
    .CREATE_ACCESSOR(Strings,ss)
    .CREATE_ACCESSOR(Int,i)
    .CREATE_ACCESSOR(Ints,is)
    .CREATE_ACCESSOR(Tensor,t)
    .CREATE_ACCESSOR(Tensors,ts)
    .CREATE_ACCESSOR(Graph,g)
    .CREATE_ACCESSOR(Graphs,gs)
    #undef CREATE_ACCESSOR
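  // z_/zs_ behave like t_/ts_ but first reshape each tensor to a
  // 0-dimensional view via view({}), so scalars are stored as 0-dim tensors.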
.def("z_",[](Node & n, const char * name, at::Tensor v) {
return n.t_(stringToSymbol(name), std::move(v.view({})));
})
.def("z",[](Node & n, const char * name) {
return n.t(stringToSymbol(name));
})
.def("zs_",[](Node & n, const char * name, TensorsAttr::ValueType v) {
for (size_t i = 0; i < v.size(); ++ i) {
v[i] = v[i].view({});
}
return n.ts_(stringToSymbol(name), std::move(v));
})
.def("zs",[](Node & n, const char * name) {
return n.ts(stringToSymbol(name));
})
.def("pyobj",[](Node & n) {
return py::handle(n.expect<PythonOp>()->pyobj.get()).cast<py::object>();
})
.def("cconv",[](Node & n) {
return n.expect<PythonOp>()->cconv;
})
.def("pyname",[](Node & n) {
return n.expect<PythonOp>()->name();
})
.def("scalar_args",[](Node & n) {
auto op = n.expect<PythonOp>();
auto scalars = py::list();
auto append = scalars.attr("append");
for(auto & arg : op->scalar_args) {
append(py::handle(arg.get()));
}
return scalars;
})
;
#define TS(name) \
def(#name,&Node :: name)
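  // Type bindings: report the concrete kind as a string and expose
  // TensorType's size/stride/contiguity/dtype accessors.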
  py::class_<Type,std::shared_ptr<Type>>(m,"Type")
    .def("__repr__",[](Type & t) {
      std::stringstream ss;
      ss << t;
      return ss.str();
    })
    .def("kind",[](Type& t_) {
      Type * t = &t_;
      TYPE_IF(t, HandleType)
        return "HandleType";
      TYPE_ELSEIF(TensorType)
        return "TensorType";
      TYPE_END()
      torch::barf("unknown type kind");
      return "";
    })
    .def("sizes",[](Type& t) {
      return t.expect<TensorType>()->sizes();
    })
    .def("strides",[](Type& t) {
      return t.expect<TensorType>()->strides();
    })
    .def("contiguous",[](Type& t) {
      return t.expect<TensorType>()->contiguous();
    })
    .def("scalarType",[](Type& t) {
      return at::toString(t.expect<TensorType>()->scalarType());
    })
    ;
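  // A Use records the consuming Node (user) and the input index at which
  // the Value is consumed (offset).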
  py::class_<Use>(m,"Use")
    .def_readonly("user",&Use::user)
    .def_readonly("offset",&Use::offset);
m.def("_jit_get_graph", [](tracer::TracingState* s) {
return s->graph;
});
m.def("_jit_is_tracing", [](const autograd::Variable& var) {
return tracer::isTracing(var);
});
}
}}