Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Remove GraphExecutor's python bindings (#19141)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19141
ghimport-source-id: 796a41f5514d29959af052fcf5391a2834850a80

Reviewed By: jamesr66a

Differential Revision: D14888702

Pulled By: zdevito

fbshipit-source-id: c280145f08e7bc210434d1c99396a3257b626cf9
This commit is contained in:
parent ddda563f22
commit e958ceb5d7
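For callers affected by this removal, a minimal sketch of the migration, assuming a plain Python function and example inputs (the names here are illustrative, not from the diff):

    import torch

    def fn(x):
        return x * 2 + x

    x = torch.randn(5, 5)

    # Previously one could construct the executor binding directly:
    #   ge = torch._C.GraphExecutor(fn, (x,), lambda var: '')
    # With the binding removed, trace the function instead; tracing
    # produces a callable backed by a GraphExecutor internally.
    ge = torch.jit.trace(fn, (x,))
    out = ge(x)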
@@ -220,7 +220,7 @@ def get_grad_executor(plan_state, diff_graph_idx=None):
             pass
         else:
             raise RuntimeError("Can't get a grad_executor for a non-differentiable graph")
-    grad_executors = list(plan_state.code.grad_executors())
+    grad_executors = list(plan_state.code.grad_executor_states())
     return grad_executors[diff_graph_idx or 0]
 
 
@@ -229,8 +229,8 @@ def backward_graph(script_module, diff_graph_idx=None):
         raise RuntimeError('Expected ScriptModule')
     ge_state = script_module.get_debug_state()
     fwd_plan = get_execution_plan(ge_state)
-    grad_executor = get_grad_executor(fwd_plan, diff_graph_idx=diff_graph_idx)
-    bwd_plan = get_execution_plan(grad_executor.get_debug_state())
+    grad_executor_state = get_grad_executor(fwd_plan, diff_graph_idx=diff_graph_idx)
+    bwd_plan = get_execution_plan(grad_executor_state)
     # Running JIT passes requires that we own the graph (with a shared_ptr).
     # The debug state struct does not own its graph so we make a copy of it.
     return bwd_plan.graph.copy()
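A hedged sketch of how the two helpers above compose after this change; `m` is a hypothetical ScriptModule that has already executed once on differentiable inputs, so its code object holds at least one grad executor state:

    import torch

    x = torch.randn(3, requires_grad=True)
    m(x)  # populate the execution plan and its grad executors

    # backward_graph() walks: debug state -> forward plan ->
    # grad_executor_states() -> backward execution plan -> graph copy
    bwd = backward_graph(m)
    print(bwd)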
@@ -540,11 +540,8 @@ class JitTestCase(TestCase):
         else:
             recording_inputs = reference_tensors
 
-        if isinstance(func, torch._C.Graph):
-            ge = torch._C.GraphExecutor(func, optimize)
-        else:
-            ge = torch.jit.trace(func, input_tensors, optimize=optimize, check_tolerance=check_tolerance,
-                                 _force_outplace=_force_outplace)
+        ge = torch.jit.trace(func, input_tensors, optimize=optimize, check_tolerance=check_tolerance,
+                             _force_outplace=_force_outplace)
 
         if export_import:
             ge = self.getExportImportCopy(ge)
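For reference, a standalone sketch of the call the test harness now always makes; `f` and the inputs are placeholders, and optimize, check_tolerance, and _force_outplace are the torch.jit.trace keyword arguments visible in the hunk above:

    import torch

    def f(a, b):
        return a + b * 2

    inputs = (torch.randn(2, 2), torch.randn(2, 2))
    ge = torch.jit.trace(f, inputs, optimize=True,
                         check_tolerance=1e-5, _force_outplace=False)
    ge(*inputs)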
@@ -1440,7 +1437,7 @@ graph(%x : Tensor,
             return MyInplaceFn.apply(x)
 
         x = torch.randn(5, 5)
-        ge = torch._C.GraphExecutor(fn, (x,), lambda var: '', _force_outplace=True)
+        ge = torch.jit.trace(fn, (x,), _force_outplace=True, check_trace=False)
         with self.assertRaisesRegex(RuntimeError, 'inplace MyInplaceFn'):
             ge(x)
 
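The check_trace=False above matters because torch.jit.trace otherwise re-runs the function and compares results against the trace, which would raise before the test's assertRaisesRegex block is reached. A minimal illustration with a hypothetical in-place function:

    import torch

    def g(x):
        return x.add_(1)  # in-place mutation of the input

    x = torch.randn(4)
    # Skipping the post-trace verification run defers any error
    # to the point where the trace is actually executed.
    traced = torch.jit.trace(g, (x.clone(),), check_trace=False)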
@@ -1856,7 +1853,7 @@ graph(%x : Tensor,
             return a * b / (a - b) + b
         V = Variable
         a, b = V(torch.rand(1)), V(torch.rand(1))
-        ge = torch._C.GraphExecutor(foo, (a, b), lambda var: '')
+        ge = torch.jit.trace(foo, (a, b))
         a, b = V(torch.rand(1), requires_grad=True), V(
             torch.rand(1), requires_grad=True)
         r, = ge(a, b)
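A self-contained sketch of the same pattern, tracing with non-requires-grad examples and then running with gradients enabled (the single-tensor return here is an assumption; the test above unpacks a one-element tuple):

    import torch

    def foo(a, b):
        return a * b / (a - b) + b

    traced = torch.jit.trace(foo, (torch.rand(1), torch.rand(1)))

    a = torch.rand(1, requires_grad=True)
    b = torch.rand(1, requires_grad=True)
    r = traced(a, b)
    r.sum().backward()  # autograd flows through the traced executor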
@@ -27,7 +27,7 @@ def visualize(graph, name_prefix='', pb_graph=None, executors_it=None):
     value_map = {}
     pb_graph = pb_graph or graph_pb2.GraphDef()
 
-    if isinstance(graph, (torch._C.GraphExecutor, torch._C.GraphExecutorState)):
+    if isinstance(graph, torch._C.GraphExecutorState):
         visualize_graph_executor(graph, name_prefix, pb_graph,
                                  partial(visualize, pb_graph=pb_graph))
         return pb_graph
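A hedged sketch of the call site this change implies: pass a debug state rather than a live executor. `m` is a hypothetical ScriptModule, and get_debug_state() is the accessor used in the test changes above:

    # After running m at least once:
    state = m.get_debug_state()  # a torch._C.GraphExecutorState
    pb_graph = visualize(state)  # dispatches to visualize_graph_executor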
@@ -65,9 +65,6 @@ def visualize_graph_executor(state, name_prefix, pb_graph, inline_graph):
     The strategy is to embed all different configurations as independent subgraphs,
     while inlining the original graph as the one that actually produces the values.
     """
-    if isinstance(state, torch._C.GraphExecutor):
-        state = state.get_debug_state()
-
     if state.autograd_fallback_graph is not None:
         visualize(graph=state.autograd_fallback_graph,
                   name_prefix=name_prefix + 'autograd_fallback/',
@@ -239,9 +239,12 @@ void initJITBindings(PyObject* module) {
   });
   // NOLINTNEXTLINE(bugprone-unused-raii)
   py::class_<ArgumentSpec>(m, "ArgumentSpec");
-  py::class_<Code>(m, "Code").def("grad_executors", [](Code& c) {
-    return py::make_iterator(
-        c.grad_executors().begin(), c.grad_executors().end());
-  });
+  py::class_<Code>(m, "Code").def("grad_executor_states", [](Code& c) {
+    std::vector<GraphExecutorState> states;
+    for (auto& e : c.grad_executors()) {
+      states.emplace_back(e->getDebugState());
+    }
+    return states;
+  });
 
   py::class_<ExecutionPlanState>(m, "ExecutionPlanState")
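Seen from Python, the renamed binding returns materialized GraphExecutorState objects rather than an iterator over live executors; a sketch using the test helpers from this same commit (plan_state as in get_grad_executor above):

    states = list(plan_state.code.grad_executor_states())
    # Each entry is a GraphExecutorState snapshot, usable exactly where
    # the old code called .get_debug_state() on a live executor:
    bwd_plan = get_execution_plan(states[0])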
@@ -275,50 +278,6 @@ void initJITBindings(PyObject* module) {
       .def_property_readonly(
           "fallback", [](GraphExecutorState& s) { return s.fallback; });
 
-  py::class_<GraphExecutor>(m, "GraphExecutor", py::dynamic_attr())
-      .def(
-          py::init([](py::function func,
-                      py::tuple inputs,
-                      py::function var_name_lookup_fn,
-                      bool optimize,
-                      bool _force_outplace) {
-            auto graph = tracer::createGraphByTracing(
-                func, toStack(inputs), var_name_lookup_fn, _force_outplace);
-            return GraphExecutor(graph, optimize);
-          }),
-          py::arg("func"),
-          py::arg("inputs"),
-          py::arg("var_name_lookup_fn"),
-          py::arg("optimize") = true,
-          py::arg("_force_outplace") = false)
-      .def(
-          py::init([](std::shared_ptr<Graph> graph, bool optimize) {
-            return GraphExecutor(std::move(graph), optimize);
-          }),
-          py::arg("graph"),
-          py::arg("optimize") = true)
-      .def(
-          "graph_for",
-          [](GraphExecutor& ge, py::args args) {
-            return ge.graphFor(evilDeprecatedBadCreateStackDoNotUse(
-                args, ge.graph()->inputs()));
-          })
-      .def_property_readonly(
-          "graph", [](GraphExecutor& ge) { return ge.graph(); })
-      .def(
-          "get_debug_state",
-          [](GraphExecutor& ge) { return ge.getDebugState(); })
-      .def("__call__", [](GraphExecutor& ge, py::args args) -> py::object {
-        const auto& graph = ge.graph();
-        auto stack =
-            evilDeprecatedBadCreateStackDoNotUse(args, graph->inputs());
-        {
-          AutoNoGIL no_gil_guard;
-          ge.run(stack);
-        }
-        return createPyObjectForStack(std::move(stack));
-      });
-
   py::class_<PyTorchStreamWriter>(m, "PyTorchFileWriter")
       .def(py::init<std::string>())
       .def(
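With the binding removed, its surface area maps onto the traced callable; a hedged correspondence in code (whether each attribute exists depends on what torch.jit.trace returns for the given input, so treat this as orientation, not verified API):

    import torch

    def foo(a, b):
        return a * b

    a, b = torch.rand(1), torch.rand(1)
    traced = torch.jit.trace(foo, (a, b))

    traced(a, b)              # stands in for GraphExecutor.__call__
    traced.graph              # stands in for the .graph property
    traced.graph_for(a, b)    # stands in for GraphExecutor.graph_for
    traced.get_debug_state()  # stands in for GraphExecutor.get_debug_state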