mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
This commit adds a Value type, similar to the one @ezyang suggested a while ago, for handling multi-return nodes.

Previously, if we had a graph like:

    a = op1(b)
    c, d = op2(a)

then its in-memory format would look like:

    %0 = op1(b)
    %1 = op2(%0)
    %2 = select(%1, 0)
    %3 = select(%1, 1)

Select nodes were used only to handle the multi-output case. In the single-output case, an op referred directly to the node producing its input (e.g. `op2(%0)` above). This required special handling for the single- and multi-output cases, and was confusing when used with ONNX, which distinguishes values (the inputs/outputs of a node) from the nodes themselves (e.g. a Conv).

This commit adds the Node/Value distinction to the IR. In the example above, `a`, `b`, `c`, and `d` are now Value objects, while `op1` and `op2` are now Node objects. Inputs/outputs of the graph are Values.

* Nodes now always have a list of outputs (possibly of length one), accessible through their `outputs()` method.
* Methods exist for adding/removing outputs from a node.
* Nodes own their output Values: destroying a node destroys its outputs, and it is only valid to destroy a node when no uses of its outputs remain.
* Unlike select nodes, Values do not appear in the nodes list.
* The method `node()` on `Value` retrieves its defining node. Calling it is always valid. For graph inputs, the defining node's kind is "Param". Like "Return", there is a single Param node representing all inputs.
* For single-output Nodes, the method `output()` retrieves the single output Value, asserting that the node does in fact have exactly one output.
* Functions are the same, but some functions, like `type()`, have moved to Value.
* `replaceAllUsesWith` is now sanely defined for both Values and Nodes. In the case of Nodes, it replaces all outputs of the node with the outputs of the replacement node.
* `stage` is defined on both Node and Value, because graph inputs (which are Values) require a stage.
* Apart from changing data types from Node to Value, most passes remain the same. Code that previously assumed single-output nodes now has to call `output()` to get the node's output Value.
* This removes the `uses = [...]` field from the printed output. It was getting confusing even before this commit, when uses referred to nodes but we printed the names of Values. The lint pass validates the use list, so printing it out seems less necessary.
241 lines
6.7 KiB
C++
241 lines
6.7 KiB
C++
#include <Python.h>
|
|
|
|
#include "torch/csrc/jit/ir.h"
|
|
#include "torch/csrc/jit/pybind.h"
|
|
#include "torch/csrc/jit/python_tracer.h"
|
|
#include "torch/csrc/utils/pybind.h"
|
|
|
|
#include <iostream>
|
|
#include <sstream>
|
|
|
|
namespace torch { namespace jit {
|
|
|
|
void initPythonIRBindings(PyObject * module_) {
|
|
auto m = py::handle(module_).cast<py::module>();
|
|
#define GS(name) \
|
|
def(#name,&Graph :: name)
|
|
py::class_<Graph,std::shared_ptr<Graph>>(m,"Graph")
|
|
.def(py::init<>())
|
|
.def("__repr__",[](Graph & g) {
|
|
std::stringstream ss;
|
|
ss << g;
|
|
return ss.str();
|
|
})
|
|
.def("inputs",[](Graph &g) {
|
|
return py::make_iterator(g.inputs().begin(), g.inputs().end());
|
|
})
|
|
.def("outputs",[](Graph &g) {
|
|
return py::make_iterator(g.outputs().begin(), g.outputs().end());
|
|
})
|
|
// TODO: Iterator invalidation might make this hazardous
|
|
.def("nodes",[](Graph &g) {
|
|
return py::make_iterator(g.begin(), g.end());
|
|
})
|
|
.def("addInput",[](Graph &g) { return g.addInput(); })
|
|
.GS(advanceStage)
|
|
.GS(stage)
|
|
.GS(eraseInput)
|
|
.GS(registerOutput)
|
|
.def("create",[](Graph & g, const char * str) {
|
|
return g.create(stringToSymbol(str));
|
|
})
|
|
.def("create",[](Graph & g, const char * str, size_t noutputs) {
|
|
return g.create(stringToSymbol(str), noutputs);
|
|
})
|
|
.def("create",[](Graph & g, const char * str, const std::vector<Value*> & inputs) {
|
|
return g.create(stringToSymbol(str),inputs);
|
|
})
|
|
.def("create",[](Graph & g, const char * str, const std::vector<Value*> & inputs, size_t noutputs) {
|
|
return g.create(stringToSymbol(str),inputs, noutputs);
|
|
})
|
|
.GS(createConstant)
|
|
.GS(createFusionGroup)
|
|
.def("createClone",[](Graph & g, Node * n, py::object fn) {
|
|
return g.createClone(n, [&](Value * e) {
|
|
return fn(e).cast<Value*>();
|
|
});
|
|
})
|
|
.GS(appendNode)
|
|
.GS(prependNode)
|
|
.GS(lint)
|
|
;
|
|
#undef GS
|
|
|
|
#define VS(name) \
|
|
def(#name,&Value :: name)
|
|
py::class_<Value,std::unique_ptr<Value, py::nodelete>>(m,"Value")
|
|
.def("__repr__",[](Value & n) {
|
|
std::stringstream ss;
|
|
ss << n.uniqueName() << " defined in (" << *n.node() << ")";
|
|
return ss.str();
|
|
})
|
|
.VS(type)
|
|
.VS(typeOption)
|
|
.VS(hasType)
|
|
.VS(setType)
|
|
.VS(inferTypeFrom)
|
|
// skip owningGraph because it returns a raw pointer to a otherwise
|
|
// std::shared_ptr stored graph object, and would cause a double free
|
|
.VS(debugName)
|
|
.VS(setDebugName)
|
|
.VS(unique)
|
|
.VS(uniqueName)
|
|
.VS(setStage)
|
|
.VS(stage)
|
|
.VS(offset)
|
|
.VS(uses)
|
|
.VS(isHandle)
|
|
.VS(replaceAllUsesWith)
|
|
.def("node",[](Value &v) { return v.node(); })
|
|
.def("setTypeAs", [](Value * node, Value * other) {
|
|
node->setType(other->typeOption());
|
|
return node;
|
|
})
|
|
.VS(copyMetadata)
|
|
;
|
|
|
|
#undef VS
|
|
|
|
#define NS(name) \
|
|
def(#name,&Node :: name)
|
|
py::class_<Node,std::unique_ptr<Node, py::nodelete>>(m,"Node")
|
|
.def("__repr__",[](Node & n) {
|
|
std::stringstream ss;
|
|
ss << n;
|
|
return ss.str();
|
|
})
|
|
.def("hasMultipleOutputs",[](Node&n) {
|
|
return n.outputs().size() > 1;
|
|
})
|
|
.NS(kind)
|
|
.NS(stage)
|
|
.NS(setStage)
|
|
.def("inputs",[](Node &n) {
|
|
return py::make_iterator(n.inputs().begin(), n.inputs().end());
|
|
})
|
|
.def("outputs",[](Node &n) {
|
|
return py::make_iterator(n.outputs().begin(), n.outputs().end());
|
|
})
|
|
.NS(output)
|
|
.NS(addInput)
|
|
.NS(replaceInput)
|
|
.NS(replaceInputWith)
|
|
.NS(replaceAllUsesWith)
|
|
.NS(insertBefore)
|
|
.NS(insertAfter)
|
|
.NS(moveAfter)
|
|
.NS(moveBefore)
|
|
.NS(removeInput)
|
|
.NS(removeAllInputs)
|
|
.NS(destroy)
|
|
.NS(hasUses)
|
|
.NS(eraseOutput)
|
|
.NS(addOutput)
|
|
|
|
#define AS(name) def(#name,&Attributes<Node> :: name)
|
|
// methods from Attributes
|
|
.AS(copyAttributes)
|
|
.AS(hasAttribute)
|
|
.AS(kindOf)
|
|
.AS(removeAttribute)
|
|
.AS(hasAttributes)
|
|
.AS(attributeNames)
|
|
#undef AS
|
|
#define CREATE_ACCESSOR(Kind,method) \
|
|
def(#method "_",[](Node & n, const char * name, Kind##Attr::ValueType v) { \
|
|
return n . method ## _(stringToSymbol(name), std::move(v)); \
|
|
}) \
|
|
.def(#method, [](Node & n, const char * name) { \
|
|
return n.method(stringToSymbol(name)); \
|
|
})
|
|
.CREATE_ACCESSOR(Float,f)
|
|
.CREATE_ACCESSOR(Floats,fs)
|
|
.CREATE_ACCESSOR(String,s)
|
|
.CREATE_ACCESSOR(Strings,ss)
|
|
.CREATE_ACCESSOR(Int,i)
|
|
.CREATE_ACCESSOR(Ints,is)
|
|
.CREATE_ACCESSOR(Tensor,t)
|
|
.CREATE_ACCESSOR(Tensors,ts)
|
|
.CREATE_ACCESSOR(Graph,g)
|
|
.CREATE_ACCESSOR(Graphs,gs)
|
|
#undef CREATE_ACCESSOR
|
|
.def("z_",[](Node & n, const char * name, at::Tensor v) {
|
|
return n.t_(stringToSymbol(name), std::move(v.view({})));
|
|
})
|
|
.def("z",[](Node & n, const char * name) {
|
|
return n.t(stringToSymbol(name));
|
|
})
|
|
.def("zs_",[](Node & n, const char * name, TensorsAttr::ValueType v) {
|
|
for (size_t i = 0; i < v.size(); ++ i) {
|
|
v[i] = v[i].view({});
|
|
}
|
|
return n.ts_(stringToSymbol(name), std::move(v));
|
|
})
|
|
.def("zs",[](Node & n, const char * name) {
|
|
return n.ts(stringToSymbol(name));
|
|
})
|
|
.def("pyobj",[](Node & n) {
|
|
return py::handle(n.expect<PythonOp>()->pyobj.get()).cast<py::object>();
|
|
})
|
|
.def("cconv",[](Node & n) {
|
|
return n.expect<PythonOp>()->cconv;
|
|
})
|
|
.def("pyname",[](Node & n) {
|
|
return n.expect<PythonOp>()->name();
|
|
})
|
|
.def("scalar_args",[](Node & n) {
|
|
auto op = n.expect<PythonOp>();
|
|
auto scalars = py::list();
|
|
auto append = scalars.attr("append");
|
|
for(auto & arg : op->scalar_args) {
|
|
append(py::handle(arg.get()));
|
|
}
|
|
return scalars;
|
|
})
|
|
;
|
|
|
|
#define TS(name) \
|
|
def(#name,&Node :: name)
|
|
py::class_<Type,std::shared_ptr<Type>>(m,"Type")
|
|
.def("__repr__",[](Type & t) {
|
|
std::stringstream ss;
|
|
ss << t;
|
|
return ss.str();
|
|
})
|
|
.def("kind",[](Type& t_) {
|
|
Type * t = &t_;
|
|
TYPE_IF(t, HandleType)
|
|
return "HandleType";
|
|
TYPE_ELSEIF(TensorType)
|
|
return "TensorType";
|
|
TYPE_END()
|
|
torch::barf("unknown type kind");
|
|
return "";
|
|
})
|
|
.def("sizes",[](Type& t) {
|
|
return t.expect<TensorType>()->sizes();
|
|
})
|
|
.def("strides",[](Type& t) {
|
|
return t.expect<TensorType>()->strides();
|
|
})
|
|
.def("contiguous",[](Type& t) {
|
|
return t.expect<TensorType>()->contiguous();
|
|
})
|
|
.def("scalarType",[](Type& t) {
|
|
return at::toString(t.expect<TensorType>()->scalarType());
|
|
})
|
|
;
|
|
|
|
py::class_<Use>(m,"Use")
|
|
.def_readonly("user",&Use::user)
|
|
.def_readonly("offset",&Use::offset);
|
|
|
|
m.def("_jit_get_graph", [](tracer::TracingState* s) {
|
|
return s->graph;
|
|
});
|
|
m.def("_jit_is_tracing", [](const autograd::Variable& var) {
|
|
return tracer::isTracing(var);
|
|
});
|
|
}
|
|
}}
|