Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 00:21:07 +01:00)
Summary:
* Fix the necessary pathways so that tuples and lists can be inputs to the script.
* Prevent linear algebra functions from being run in shape propagation, because they frequently error out on nonsense data.
* Favor schema-driven Python input conversion where possible. The remaining cases where we directly create Stacks without a schema are only for debugging.
* Make the error messages when calling script/trace functions more Pythonic.
* Simplify FlattenTuples -- now that tuples are supported, we can choose to flatten tuples only when needed. This may have to be revisited pending ONNX test results, but it is necessary for making tuple I/O work.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/10812
Differential Revision: D9477982
Pulled By: zdevito
fbshipit-source-id: ed06fc426e6ef6deb404602a26c435a7fc40ea0c
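The headline change is that tuple (and list) values can now flow into and out of script/traced functions. As a rough illustration of the intent only -- a minimal sketch written with modern torch.jit.script type annotations; the annotation syntax available at the time of this commit may have differed, and the function name cat_pair is made up for the example:

    from typing import Tuple

    import torch
    from torch import Tensor

    @torch.jit.script
    def cat_pair(pair: Tuple[Tensor, Tensor]) -> Tensor:
        # The tuple argument survives as a single tuple-typed graph input
        # instead of being eagerly flattened into separate tensor inputs.
        x, y = pair
        return torch.cat([x, y], dim=0)

    out = cat_pair((torch.randn(2, 3), torch.randn(2, 3)))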
269 lines
10 KiB
C++
#include "torch/csrc/utils/pybind.h"
|
|
|
|
#include "torch/csrc/jit/python_tracer.h"
|
|
#include "torch/csrc/jit/tracer.h"
|
|
#include "torch/csrc/jit/python_ir.h"
|
|
#include "torch/csrc/jit/python_arg_flatten.h"
|
|
#include "torch/csrc/jit/export.h"
|
|
#include "torch/csrc/jit/argument_spec.h"
|
|
#include "torch/csrc/jit/passes/remove_expands.h"
|
|
#include "torch/csrc/jit/passes/graph_fuser.h"
|
|
#include "torch/csrc/jit/passes/onnx.h"
|
|
#include "torch/csrc/jit/passes/dead_code_elimination.h"
|
|
#include "torch/csrc/jit/passes/erase_number_types.h"
|
|
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
|
|
#include "torch/csrc/jit/passes/peephole.h"
|
|
#include "torch/csrc/jit/passes/canonicalize.h"
|
|
#include "torch/csrc/jit/passes/onnx/peephole.h"
|
|
#include "torch/csrc/jit/passes/onnx/fixup_onnx_loop.h"
|
|
#include "torch/csrc/jit/passes/shape_analysis.h"
|
|
#include "torch/csrc/jit/passes/decompose_addmm.h"
|
|
#include "torch/csrc/jit/passes/constant_propagation.h"
|
|
#include "torch/csrc/jit/passes/loop_unrolling.h"
|
|
#include "torch/csrc/jit/passes/to_batch.h"
|
|
#include "torch/csrc/jit/passes/lower_tuples.h"
|
|
#include "torch/csrc/jit/passes/specialize_undef.h"
|
|
#include "torch/csrc/jit/graph_executor.h"
|
|
#include "torch/csrc/jit/script/init.h"
|
|
#include "torch/csrc/jit/script/python_tree_views.h"
|
|
#include "torch/csrc/jit/batched/BatchTensor.h"
|
|
#include "torch/csrc/jit/pybind_utils.h"
|
|
#include "torch/csrc/jit/function_schema.h"
|
|
#include "torch/csrc/jit/serialization.h"
|
|
#include "torch/csrc/jit/operator.h"
|
|
|
|
#include <pybind11/functional.h>
|
|
|
|
#include <memory>
|
|
#include <sstream>
|
|
#include <stdexcept>
|
|
#include <string>
|
|
#include <tuple>
|
|
#include <utility>
|
|
|
|
namespace torch { namespace jit {

namespace {

using autograd::variable_list;

bool loadPythonClasses() {
  // Leaving this code here, because it will likely be useful at some point
  //PyObject *jit_module = PyImport_ImportModule("torch.jit");
  //THPUtils_assert(jit_module, "class loader couldn't access "
  //"torch.jit module");
  //PyObject *jit_dict = PyModule_GetDict(jit_module);

  return true;
}

} // anonymous namespace

extern std::string runJITCPPTests();

void initJITBindings(PyObject *module) {
  auto m = py::handle(module).cast<py::module>();

  py::class_<python::IODescriptor>(m, "IODescriptor");

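  // Graph-level passes and debugging utilities exposed to Python under a
  // _jit_ prefix. Entry points whose C++ functions are overloaded or need
  // argument adaptation are wrapped in lambdas so pybind11 can bind them
  // unambiguously.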
m.def("_jit_init", loadPythonClasses)
|
|
.def("_jit_pass_onnx", ToONNX)
|
|
.def("_jit_pass_lower_all_tuples", LowerAllTuples)
|
|
.def("_jit_pass_onnx_peephole", PeepholeOptimizeONNX)
|
|
.def("_jit_pass_fuse", FuseGraph)
|
|
.def("_jit_pass_dce", [](std::shared_ptr<Graph>& g) {
|
|
return EliminateDeadCode(g); // overload resolution
|
|
})
|
|
.def("_jit_pass_cse", [](std::shared_ptr<Graph>& g) {
|
|
return EliminateCommonSubexpression(g); // overload resolution
|
|
})
|
|
.def("_jit_pass_peephole", PeepholeOptimize)
|
|
.def("_jit_pass_canonicalize", [](const std::shared_ptr<Graph>& g) {
|
|
return Canonicalize(g);
|
|
})
|
|
.def("_jit_pass_lint", LintGraph)
|
|
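      // The shape-analysis bindings below build a Stack directly from the
      // Python inputs without consulting an operator schema; per the commit
      // message, such schema-less conversion remains only for debugging paths.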
.def("_jit_pass_shape_analysis", [](Graph& graph, py::tuple inputs, bool with_grad) {
|
|
PropagateInputShapes(graph, ArgumentSpec(with_grad, evilDeprecatedBadCreateStackDoNotUse(inputs, graph.inputs())));
|
|
})
|
|
.def("_jit_pass_complete_shape_analysis", [](Graph& graph, py::tuple inputs, bool with_grad) {
|
|
PropagateInputShapes(graph, CompleteArgumentSpec(with_grad, evilDeprecatedBadCreateStackDoNotUse(inputs, graph.inputs())));
|
|
})
|
|
.def("_jit_pass_remove_expands", RemoveExpands)
|
|
.def("_jit_pass_erase_number_types", EraseNumberTypes)
|
|
.def("_jit_pass_loop_unrolling", UnrollLoops)
|
|
.def("_jit_pass_constant_propagation", [](std::shared_ptr<Graph>& g) {
|
|
return ConstantPropagation(g);
|
|
})
|
|
.def("_jit_run_cpp_tests", [] {
|
|
// We have to release the GIL inside this method, because if we happen to
|
|
// initialize the autograd engine in these tests, the newly spawned worker threads will
|
|
// try to initialize their PyThreadState*, and they need the GIL for this.
|
|
AutoNoGIL _no_gil;
|
|
return runJITCPPTests();
|
|
})
|
|
.def("_jit_flatten", [](py::handle& obj) {
|
|
auto res = python::flatten(obj);
|
|
return std::make_pair(res.vars, res.desc);
|
|
})
|
|
.def("_jit_unflatten", [](autograd::variable_list vars, python::IODescriptor& desc) {
|
|
return py::reinterpret_steal<py::object>(python::unflatten(vars, desc));
|
|
})
|
|
.def("_jit_pass_onnx_block", BlockToONNX)
|
|
.def("_jit_pass_fixup_onnx_loops", FixupONNXLoops)
|
|
.def("_jit_pass_decompose_addmm", DecomposeAddmm)
|
|
.def("_jit_pass_specialize_undef", specializeUndef)
|
|
.def("_jit_differentiate", [](Graph &g, const std::vector<bool>& requires_grad) {
|
|
// the python binding slightly differs in semantics
|
|
// it makes a copy of the input Graph, and works on that
|
|
// jit::differentiate mutates the input Graph
|
|
auto g_clone = g.copy();
|
|
return differentiate(g_clone, requires_grad);
|
|
});
|
|
|
|
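  // Debugging/introspection classes: argument specs, compiled Code, and the
  // per-plan executor state that can be inspected from Python.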
py::class_<CompleteArgumentSpec>(m, "CompleteArgumentSpec")
|
|
.def("__repr__", [](CompleteArgumentSpec& self) {
|
|
std::ostringstream s;
|
|
s << self;
|
|
return s.str();
|
|
});
|
|
py::class_<ArgumentSpec>(m, "ArgumentSpec");
|
|
py::class_<Code>(m, "Code")
|
|
.def("executors", [](Code& c) {
|
|
return py::make_iterator(c.executors().begin(), c.executors().end());
|
|
});
|
|
|
|
py::class_<ExecutionPlanState>(m, "ExecutionPlanState")
|
|
.def_property_readonly("graph", [](ExecutionPlanState& s) {
|
|
return s.graph;
|
|
})
|
|
.def_property_readonly("code", [](ExecutionPlanState& s) {
|
|
return s.f;
|
|
})
|
|
.def_property_readonly("grad_executor", [](ExecutionPlanState& s) {
|
|
return s.grad_executor.get();
|
|
});
|
|
|
|
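  // Read-only view of the forward/backward graph pair produced by
  // _jit_differentiate, along with the bookkeeping indices that describe how
  // the two graphs are stitched together.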
py::class_<Gradient>(m, "Gradient")
|
|
.def_property_readonly("f", [](Gradient& m) {
|
|
return m.f;
|
|
})
|
|
.def_property_readonly("df", [](Gradient& m) {
|
|
return m.df;
|
|
})
|
|
.def_property_readonly("f_real_outputs", [](Gradient& m) {
|
|
return m.f_real_outputs;
|
|
})
|
|
.def_property_readonly("df_input_vjps", [](Gradient& m) {
|
|
return m.df_input_vjps;
|
|
})
|
|
.def_property_readonly("df_input_captured_inputs", [](Gradient& m) {
|
|
return m.df_input_captured_inputs;
|
|
})
|
|
.def_property_readonly("df_input_captured_outputs", [](Gradient& m) {
|
|
return m.df_input_captured_outputs;
|
|
})
|
|
.def_property_readonly("df_output_vjps", [](Gradient& m) {
|
|
return m.df_output_vjps;
|
|
});
|
|
|
|
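  // Snapshot of a GraphExecutor's internals, returned by get_debug_state().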
py::class_<GraphExecutorState>(m, "GraphExecutorState")
|
|
.def_property_readonly("graph", [](GraphExecutorState& s) {
|
|
return s.graph;
|
|
})
|
|
.def_property_readonly("execution_plans", [](GraphExecutorState& s) {
|
|
return s.execution_plans;
|
|
})
|
|
.def_property_readonly("autograd_fallback", [](GraphExecutorState& s) {
|
|
return s.autograd_fallback;
|
|
})
|
|
.def_property_readonly("autograd_fallback_graph", [](GraphExecutorState& s) {
|
|
return s.autograd_fallback_graph;
|
|
});
|
|
|
|
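  // A GraphExecutor can be built either by tracing a Python function over
  // example inputs or directly from an existing Graph. __call__ converts the
  // Python arguments to a Stack, runs the graph, and converts the results
  // back to Python objects.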
py::class_<GraphExecutor>(m, "GraphExecutor", py::dynamic_attr())
|
|
.def(
|
|
py::init([](py::function func,
|
|
py::tuple inputs,
|
|
bool optimize) {
|
|
auto graph = tracer::createGraphByTracing(func, toStack(inputs));
|
|
return GraphExecutor(graph, optimize);
|
|
}),
|
|
py::arg("func"),
|
|
py::arg("inputs"),
|
|
py::arg("optimize") = true)
|
|
.def(
|
|
py::init([](std::shared_ptr<Graph> graph, bool optimize) {
|
|
return GraphExecutor(std::move(graph), optimize);
|
|
}),
|
|
py::arg("graph"),
|
|
py::arg("optimize") = true)
|
|
.def("graph_for", [](GraphExecutor& ge, py::args args) {
|
|
return ge.graphFor(evilDeprecatedBadCreateStackDoNotUse(args, ge.graph()->inputs()));
|
|
})
|
|
.def_property_readonly("graph", [](GraphExecutor& ge) {
|
|
return ge.graph();
|
|
})
|
|
.def("get_debug_state", [](GraphExecutor& ge) {
|
|
return ge.getDebugState();
|
|
})
|
|
.def("__call__", [](GraphExecutor& ge, py::args args) -> py::object {
|
|
const auto & graph = ge.graph();
|
|
auto stack = evilDeprecatedBadCreateStackDoNotUse(args, graph->inputs());
|
|
ge.run(stack);
|
|
return createPyObjectForStack(std::move(stack));
|
|
});
|
|
|
|
|
|
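  // Record-based container format used by JIT serialization: the writer
  // appends raw records, and the reader returns record contents to Python
  // as bytes.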
py::class_<PyTorchFileWriter>(m, "PyTorchFileWriter")
|
|
.def(py::init<std::string>())
|
|
.def("write_record", &PyTorchFileWriter::writeRecord)
|
|
.def("write_end_of_file", &PyTorchFileWriter::writeEndOfFile);
|
|
|
|
py::class_<PyTorchFileReader>(m, "PyTorchFileReader")
|
|
.def(py::init<std::string>())
|
|
.def("get_record_with_key", [](PyTorchFileReader &self, uint64_t key) {
|
|
std::shared_ptr<void> data;
|
|
size_t size;
|
|
std::tie(data, size) = self.getRecordWithKey(key);
|
|
return py::bytes(reinterpret_cast<const char*>(data.get()), size);
|
|
})
|
|
.def("get_last_record", [](PyTorchFileReader &self){
|
|
std::shared_ptr<void> data;
|
|
size_t size;
|
|
std::tie(data, size) = self.getLastRecord();
|
|
return py::bytes(reinterpret_cast<const char*>(data.get()), size);
|
|
});
|
|
|
|
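  // Expose a single registered JIT operator as a Python callable. Argument
  // conversion is schema-driven: invokeOperatorFromPython matches the Python
  // args/kwargs against the operator's FunctionSchema. Operators with multiple
  // overloads are rejected (overloads are not supported from Python here).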
m.def("_jit_get_operation", [](const std::string& qualified_name) {
|
|
try {
|
|
auto symbol = Symbol::fromQualString(qualified_name);
|
|
auto operations = getAllOperatorsFor(std::move(symbol));
|
|
AT_CHECK(!operations.empty(), "No such operator ", qualified_name);
|
|
AT_CHECK(
|
|
operations.size() == 1,
|
|
"Found ", operations.size(), " overloads for operator ",
|
|
qualified_name, "! Overloads are not supported from Python.");
|
|
std::shared_ptr<Operator> op = operations[0];
|
|
AT_ASSERT(op != nullptr);
|
|
std::ostringstream docstring;
|
|
docstring << "Automatically bound operator '" << qualified_name
|
|
<< "' with schema: " << op->schema();
|
|
return py::cpp_function([op](py::args args, py::kwargs kwargs) {
|
|
return invokeOperatorFromPython(
|
|
*op, std::move(args), std::move(kwargs));
|
|
}, py::name(qualified_name.c_str()), py::doc(docstring.str().c_str()));
|
|
} catch (const at::Error& error) {
|
|
throw std::runtime_error(error.what_without_backtrace());
|
|
}
|
|
}, py::arg("qualified_name"));
|
|
|
|
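  // Register the remaining JIT bindings that live in other translation units.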
  initPythonIRBindings(module);
  tracer::initPythonTracerBindings(module);
  script::initTreeViewBindings(module);
  script::initJitScriptBindings(module);
  initBatchTensorBindings(module);
  initRegisterBatchOpsBindings(module);
}

}}