pytorch/torch/csrc/jit/python_tracer.cpp
Adam Paszke 8cb1eef7b9 Unify IR operator representation (stop using attributes in the JIT) (#9807)
Summary:
Built on top of #9763 (the first 3 commits belong to that PR); this PR's own commits start with "Stop using attributes ..."

I tried to separate the changes into fairly meaningful commits. I can't split them up into smaller PRs, because everything starts working and all tests pass only after the whole sequence, but hopefully this will make reviewing somewhat easier.

Known issues/regressions/future tasks:
- `aten::lerp` and `aten::clamp` are no longer fusable
- `CreateAutodiffSubgraphs` needs a rewrite
  - It is much stricter now and will miss a lot of opportunities, especially when viewing ops are involved. Our previous approach was "ignore the shape-availability requirements of gradient formulas when deciding differentiability, and hope that shape prop will be robust enough to actually deliver the shapes before we differentiate", which obviously doesn't scale to more complex cases. We should either work on reducing the size dependency of grad formulas (feasible e.g. for `view`/`reshape`, infeasible for `squeeze`/`unsqueeze`), or make `CreateAutodiffSubgraphs` integrate some "I could pull this node into an AD subgraph, but will I be able to infer the shapes of its inputs?" reasoning (a kind of limited shape prop that doesn't infer anything, and only tells whether it *could* infer something).
  - It sometimes creates constant-only (or constants + one node) graphs, which is useless
- `aten::add` is broken in auto-batching because it gained a non-tensor input. I changed the test for pointwise operations to use `aten::mul` instead, but I needed to disable the LSTM cell test. I'm not sure how scalar constants should be implemented in this case, because I don't fully understand our format. cc: ChunliF
- Graph import does some hacks to recover the types of constants. This code should be removed once we gain the ability to export the IR along with value types.
- There's still a fair amount of dead code that can be removed. I didn't want to make this diff any bigger, and removing it is an easy task.
- The graph fuser could be improved to use signature matching (possibly using `OperatorSet`) instead of relying on node kinds.
- Manual constant propagation for the `ListConstruct` node in `torch/onnx/utils.py` should be replaced with a proper constant propagation pass (or we should ensure that the one we have handles at least this case before we remove this code).
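
To make the representational change concrete, here is a rough illustrative sketch (not part of this PR) of how it can be observed from Python through the low-level `_tracer_*` bindings registered in `torch/csrc/jit/python_tracer.cpp`, assuming (as is typical) that they are exposed on `torch._C`; the exact printed IR depends on the build:

```python
# Illustrative only: _tracer_enter is assumed to return a (TracingState, inputs)
# pair, matching the pybind11 binding of tracer::enter in python_tracer.cpp.
import torch

state, (x,) = torch._C._tracer_enter([torch.randn(3)])
y = x + 2   # the scalar 2 now enters the graph as an input Value (a constant),
            # not as an attribute on the aten::add node
torch._C._tracer_exit([y])
print(state)  # TracingState.__str__ prints the traced graph
```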

zdevito
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9807

Reviewed By: ezyang

Differential Revision: D9004285

Pulled By: apaszke

fbshipit-source-id: fe88026a765f6b687354add034c86402362508b7
2018-07-26 22:11:50 -07:00


#include "torch/csrc/python_headers.h"
#include "torch/csrc/jit/python_tracer.h"
#include "torch/csrc/jit/tracer.h"
#include "torch/csrc/jit/export.h"
#include "torch/csrc/jit/pybind.h"
#include "torch/csrc/utils/python_strings.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"
#include "ATen/Error.h"
#include <sstream>

using namespace torch::autograd;
using namespace torch::jit;
using namespace torch::jit::tracer;

namespace torch { namespace jit { namespace tracer {

// Python interpreter retrieval routine adapted from
// https://stackoverflow.com/a/8706144
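// Walks the current thread's Python frame stack and renders one
// "filename(line): funcname" entry per frame.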
std::string getPythonInterpreterStackTrace() {
  std::stringstream stack_trace;
  AutoGIL gil;
  PyThreadState *tstate = PyThreadState_GET();
  if (NULL != tstate && NULL != tstate->frame) {
    PyFrameObject *frame = tstate->frame;
    while (NULL != frame) {
      int line = PyCode_Addr2Line(frame->f_code, frame->f_lasti);
      std::string filename = THPUtils_unpackString(frame->f_code->co_filename);
      std::string funcname = THPUtils_unpackString(frame->f_code->co_name);
      stack_trace << filename << "(" << line << "): " << funcname << "\n";
      frame = frame->f_back;
    }
  }
  return stack_trace.str();
}

// This is a temporary constructor so that we can write python tests of
// the executor. It does not have most of the functionality of CompiledFunction
// such as being able to hold parameters...
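// On success it returns the traced graph after dead code elimination; if the
// traced function throws, the partially-built trace is abandoned and the
// exception is rethrown.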
std::shared_ptr<torch::jit::Graph> createGraphByTracing(
    py::function func,
    tracer::variable_list trace_inputs,
    size_t num_func_inputs) {
  auto enter_info = tracer::enter(std::move(trace_inputs));
  try {
    py::tuple py_inputs(num_func_inputs);
    for (size_t i = 0; i < num_func_inputs; ++i) {
      py_inputs[i] = py::cast(enter_info.second[i]);
    }
    auto out = func(*py_inputs);
    std::vector<autograd::Variable> outputs;
    if (PyTuple_Check(out.ptr())) {
      outputs = py::cast<std::vector<autograd::Variable>>(out);
    } else {
      outputs.push_back(py::cast<autograd::Variable>(out));
    }
    tracer::exit(outputs);
    auto graph = enter_info.first->graph;
    EliminateDeadCode(graph);
    return graph;
  } catch (...) {
    tracer::abandon();
    throw;
  }
}

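// Pre-record a PythonOp for an autograd Function application: fetch the
// Function's "apply" object, create a PythonOp node carrying it together with
// the non-Variable scalar arguments, record the Python source location, wire
// up the traced Variable inputs, and append the node to the graph before any
// outputs are recorded on it.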
PreTraceInfo preRecordPythonTrace(
    THPObjectPtr pyobj,
    std::string arg_types,
    at::ArrayRef<Variable> inputs,
    pyobj_list scalar_args) {
  THPObjectPtr apply(PyObject_GetAttrString(pyobj.get(), "apply"));
  if (!apply) {
    throw python_error();
  }

  PreTraceInfo info;
  auto & state = getTracingState();
  auto & graph = state->graph;

  Node *n = info.n = graph->createPythonOp(
      std::move(apply),
      arg_types,
      std::move(scalar_args));
  recordSourceLocation(n);

  for (const Variable & input : inputs) {
    n->addInput(getValueTrace(input));
  }

  // NB: Order matters. This must append after inputs but before outputs.
  graph->appendNode(n);

  return info;
}

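// Record the current Python interpreter stack trace as the node's source
// location.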
void pythonRecordSourceLocation(Node* n) {
  auto sl = std::make_shared<StringSourceLocation>(getPythonInterpreterStackTrace());
  n->setSourceLocation(sl);
}

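// Register the tracer bindings on the given Python module: the TracingState
// class (printable, with scope push/pop and graph get/set) and the
// module-level _tracer_enter/_tracer_exit/_tracer_abandon entry points, plus
// accessors for the Variable -> Value trace mapping.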
void initPythonTracerBindings(PyObject* module_) {
  setRecordSourceLocation(pythonRecordSourceLocation);

  auto m = py::handle(module_).cast<py::module>();
  py::class_<TracingState, std::shared_ptr<TracingState>>(m, "TracingState", py::dynamic_attr())
      // NB: no constructor; you have to get it from C++ code
      .def("__repr__", [](const TracingState& s) {
        std::ostringstream ss;
        ss << "<TracingState " << (const void*)&s << ">";
        return ss.str();
      })
      .def("__str__", [](const TracingState& s) -> std::string {
        std::ostringstream ss;
        ss << *s.graph;
        return ss.str();
      })
      .def("push_scope", [](TracingState& s, const std::string& scope_name) {
        s.graph->push_scope(scope_name);
      })
      .def("pop_scope", [](TracingState& s) {
        s.graph->pop_scope();
      })
      .def("set_graph", [](TracingState& s, std::shared_ptr<Graph> g) {
        s.graph = g;
      })
      .def("graph", [](TracingState& s) {
        return s.graph;
      });

  m.def("_tracer_enter", [](variable_list trace_inputs) {
    return tracer::enter(std::move(trace_inputs));
  });
  m.def("_tracer_exit", [](variable_list var_outputs) {
    tracer::exit(var_outputs);
  });
  m.def("_tracer_abandon", []() {
    tracer::abandon();
  });
  m.def("_get_tracing_state", []() {
    return getTracingState();
  });
  m.def("_get_value_trace", [](const Variable& var) {
    return getValueTrace(var);
  });
  m.def("_set_value_trace", [](const Variable& var, Value* value) {
    return setValueTrace(var, value);
  });
}

}}} // namespace torch::jit::tracer