pytorch/torch/csrc/jit/python_tracer.cpp
anderspapitto 06df037d9a do away with ExportProxy hack in onnx export (#5614)
ExportProxy was a mechanism that reused the code supporting the export of
autograd Functions in order to support overriding arbitrary Python
functions. However, it had some serious downsides:

- only works on some functions (all args must be Variable)
- complicated
- bad error messages in some cases

Instead, just expose enough functionality to python to perform the
necessary logic explicitly.
2018-03-08 22:17:30 -05:00

78 lines
2.6 KiB
C++

#include <Python.h>
#include "torch/csrc/jit/python_tracer.h"
#include "torch/csrc/jit/tracer.h"
#include "torch/csrc/assertions.h"
#include "torch/csrc/jit/export.h"
#include "torch/csrc/jit/pybind.h"
#include "torch/csrc/utils/python_strings.h"
#include <sstream>
using namespace torch::autograd;
using namespace torch::jit;
using namespace torch::jit::tracer;
namespace torch { namespace jit {
#define ASSERT_UNEXPIRED(METHOD_NAME) if (s.is_expired()) throw std::runtime_error("calling " METHOD_NAME " on an expired trace")
void initPythonTracerBindings(PyObject* module_) {
auto m = py::handle(module_).cast<py::module>();
py::class_<TracingState,std::shared_ptr<TracingState>>(m, "TracingState", py::dynamic_attr())
// NB: no constructor; you have to get it from C++ code
.def("__repr__", [](const TracingState& s) {
std::ostringstream ss;
ss << "<TracingState " << (const void*)&s << ">";
return ss.str();
})
.def("__str__", [](const TracingState& s) -> std::string {
if (s.is_expired()) return "<expired TracingState>";
std::ostringstream ss;
ss << *s.graph;
return ss.str();
})
.def("push_scope", [](TracingState& s, const std::string& scope_name) {
ASSERT_UNEXPIRED("push_scope");
s.push_scope(scope_name);
})
.def("pop_scope", [](TracingState& s) {
ASSERT_UNEXPIRED("pop_scope");
s.pop_scope();
})
.def("export", [](TracingState& s, const std::vector<at::Tensor>& initializers, int64_t onnx_opset_version) {
ASSERT_UNEXPIRED("export");
return py::bytes(ExportGraph(s.graph, initializers, onnx_opset_version));
})
.def("graph", [](TracingState& s) {
return s.graph;
})
.def_property_readonly("is_expired", [](TracingState& s) {
return s.is_expired();
})
.def_property_readonly("is_complete", [](TracingState& s) {
return s.is_complete();
});
m.def("_tracer_enter", [](std::vector<TraceInput> trace_inputs, std::size_t num_backwards) {
return enter(std::move(trace_inputs), num_backwards + 1);
});
m.def("_tracer_exit", [](variable_list var_outputs) {
tracer::exit(var_outputs);
});
m.def("_get_tracing_state", [](const variable_list& vars) {
return getTracingState(vars);
});
m.def("_get_value_trace", [](std::shared_ptr<TracingState>& state, const Variable& var) {
return getValueTrace(state, var);
});
m.def("_set_value_trace", [](std::shared_ptr<TracingState>& state, const Variable& var, Value* value) {
return setValueTrace(state, var, value);
});
m.def("_is_tracing", [](const variable_list& vars) {
return isTracingVar(vars);
});
}
}} // namespace torch::jit