Remove GraphExecutor's python bindings (#19141)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19141
ghimport-source-id: 796a41f5514d29959af052fcf5391a2834850a80

Reviewed By: jamesr66a

Differential Revision: D14888702

Pulled By: zdevito

fbshipit-source-id: c280145f08e7bc210434d1c99396a3257b626cf9
This commit is contained in:
Zachary DeVito 2019-04-13 08:28:13 -07:00 committed by Facebook Github Bot
parent ddda563f22
commit e958ceb5d7
3 changed files with 14 additions and 61 deletions

View File

@ -220,7 +220,7 @@ def get_grad_executor(plan_state, diff_graph_idx=None):
pass
else:
raise RuntimeError("Can't get a grad_executor for a non-differentiable graph")
grad_executors = list(plan_state.code.grad_executors())
grad_executors = list(plan_state.code.grad_executor_states())
return grad_executors[diff_graph_idx or 0]
@ -229,8 +229,8 @@ def backward_graph(script_module, diff_graph_idx=None):
raise RuntimeError('Expected ScriptModule')
ge_state = script_module.get_debug_state()
fwd_plan = get_execution_plan(ge_state)
grad_executor = get_grad_executor(fwd_plan, diff_graph_idx=diff_graph_idx)
bwd_plan = get_execution_plan(grad_executor.get_debug_state())
grad_executor_state = get_grad_executor(fwd_plan, diff_graph_idx=diff_graph_idx)
bwd_plan = get_execution_plan(grad_executor_state)
# Running JIT passes requires that we own the graph (with a shared_ptr).
# The debug state struct does not own its graph so we make a copy of it.
return bwd_plan.graph.copy()
@ -540,9 +540,6 @@ class JitTestCase(TestCase):
else:
recording_inputs = reference_tensors
if isinstance(func, torch._C.Graph):
ge = torch._C.GraphExecutor(func, optimize)
else:
ge = torch.jit.trace(func, input_tensors, optimize=optimize, check_tolerance=check_tolerance,
_force_outplace=_force_outplace)
@ -1440,7 +1437,7 @@ graph(%x : Tensor,
return MyInplaceFn.apply(x)
x = torch.randn(5, 5)
ge = torch._C.GraphExecutor(fn, (x,), lambda var: '', _force_outplace=True)
ge = torch.jit.trace(fn, (x,), _force_outplace=True, check_trace=False)
with self.assertRaisesRegex(RuntimeError, 'inplace MyInplaceFn'):
ge(x)
@ -1856,7 +1853,7 @@ graph(%x : Tensor,
return a * b / (a - b) + b
V = Variable
a, b = V(torch.rand(1)), V(torch.rand(1))
ge = torch._C.GraphExecutor(foo, (a, b), lambda var: '')
ge = torch.jit.trace(foo, (a, b))
a, b = V(torch.rand(1), requires_grad=True), V(
torch.rand(1), requires_grad=True)
r, = ge(a, b)

View File

@ -27,7 +27,7 @@ def visualize(graph, name_prefix='', pb_graph=None, executors_it=None):
value_map = {}
pb_graph = pb_graph or graph_pb2.GraphDef()
if isinstance(graph, (torch._C.GraphExecutor, torch._C.GraphExecutorState)):
if isinstance(graph, torch._C.GraphExecutorState):
visualize_graph_executor(graph, name_prefix, pb_graph,
partial(visualize, pb_graph=pb_graph))
return pb_graph
@ -65,9 +65,6 @@ def visualize_graph_executor(state, name_prefix, pb_graph, inline_graph):
The strategy is to embed all different configurations as independent subgraphs,
while inlining the original graph as the one that actually produces the values.
"""
if isinstance(state, torch._C.GraphExecutor):
state = state.get_debug_state()
if state.autograd_fallback_graph is not None:
visualize(graph=state.autograd_fallback_graph,
name_prefix=name_prefix + 'autograd_fallback/',

View File

@ -239,9 +239,12 @@ void initJITBindings(PyObject* module) {
});
// NOLINTNEXTLINE(bugprone-unused-raii)
py::class_<ArgumentSpec>(m, "ArgumentSpec");
py::class_<Code>(m, "Code").def("grad_executors", [](Code& c) {
return py::make_iterator(
c.grad_executors().begin(), c.grad_executors().end());
py::class_<Code>(m, "Code").def("grad_executor_states", [](Code& c) {
std::vector<GraphExecutorState> states;
for (auto& e : c.grad_executors()) {
states.emplace_back(e->getDebugState());
}
return states;
});
py::class_<ExecutionPlanState>(m, "ExecutionPlanState")
@ -275,50 +278,6 @@ void initJITBindings(PyObject* module) {
.def_property_readonly(
"fallback", [](GraphExecutorState& s) { return s.fallback; });
py::class_<GraphExecutor>(m, "GraphExecutor", py::dynamic_attr())
.def(
py::init([](py::function func,
py::tuple inputs,
py::function var_name_lookup_fn,
bool optimize,
bool _force_outplace) {
auto graph = tracer::createGraphByTracing(
func, toStack(inputs), var_name_lookup_fn, _force_outplace);
return GraphExecutor(graph, optimize);
}),
py::arg("func"),
py::arg("inputs"),
py::arg("var_name_lookup_fn"),
py::arg("optimize") = true,
py::arg("_force_outplace") = false)
.def(
py::init([](std::shared_ptr<Graph> graph, bool optimize) {
return GraphExecutor(std::move(graph), optimize);
}),
py::arg("graph"),
py::arg("optimize") = true)
.def(
"graph_for",
[](GraphExecutor& ge, py::args args) {
return ge.graphFor(evilDeprecatedBadCreateStackDoNotUse(
args, ge.graph()->inputs()));
})
.def_property_readonly(
"graph", [](GraphExecutor& ge) { return ge.graph(); })
.def(
"get_debug_state",
[](GraphExecutor& ge) { return ge.getDebugState(); })
.def("__call__", [](GraphExecutor& ge, py::args args) -> py::object {
const auto& graph = ge.graph();
auto stack =
evilDeprecatedBadCreateStackDoNotUse(args, graph->inputs());
{
AutoNoGIL no_gil_guard;
ge.run(stack);
}
return createPyObjectForStack(std::move(stack));
});
py::class_<PyTorchStreamWriter>(m, "PyTorchFileWriter")
.def(py::init<std::string>())
.def(