mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Please review the expects carefully to make sure there are no regressions. I tried to go over them one by one when they changed, but it's sometimes easy to miss finer details. Summary of changes: - Renamed `TensorType` to `CompleteTensorType`. Added a new `TensorType` which records only the scalar type, number of dimensions, and device of a value. The argument behind the rename is to encourage people to use `CompleteTensorType` less, as most passes will only have limited information available. To make the transition easier, `complete_type->cast<TensorType>()` works, and makes our passes work with both kinds of specialization if they don't need the extra detail. - Renamed `ArgumentSpec` to `CompleteArgumentSpec`. Added a new `ArgumentSpec`, which matches arguments only at the level of the new `TensorType`. - Shape analysis can process graphs with both `CompleteTensorType` and `TensorType`. - The fuser was a part that heavily relied on full shape information being available. Now, we simply try to fuse the largest possible graphs, and have to do run-time checks to make sure they match the code we generate. If they don't, we fall back to regular interpretation. The shape checks are implemented using an optimized method exploiting algebraic properties of shapes with broadcasting, and the relations of broadcasting with pointwise ops. A full written proof of correctness of the shape checking algorithm is included in a comment in `graph_fuser.cpp`. zdevito ezyang mruberry ngimel csarofeen Pull Request resolved: https://github.com/pytorch/pytorch/pull/10844 Differential Revision: D9498705 Pulled By: apaszke fbshipit-source-id: 0c53c2fcebd871cc2a29c260f8d012276479cc61
267 lines
10 KiB
C++
267 lines
10 KiB
C++
#include "torch/csrc/utils/pybind.h"
|
|
|
|
#include "torch/csrc/jit/python_tracer.h"
|
|
#include "torch/csrc/jit/tracer.h"
|
|
#include "torch/csrc/jit/python_ir.h"
|
|
#include "torch/csrc/jit/python_arg_flatten.h"
|
|
#include "torch/csrc/jit/export.h"
|
|
#include "torch/csrc/jit/argument_spec.h"
|
|
#include "torch/csrc/jit/passes/remove_expands.h"
|
|
#include "torch/csrc/jit/passes/graph_fuser.h"
|
|
#include "torch/csrc/jit/passes/onnx.h"
|
|
#include "torch/csrc/jit/passes/dead_code_elimination.h"
|
|
#include "torch/csrc/jit/passes/erase_number_types.h"
|
|
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
|
|
#include "torch/csrc/jit/passes/peephole.h"
|
|
#include "torch/csrc/jit/passes/canonicalize.h"
|
|
#include "torch/csrc/jit/passes/onnx/peephole.h"
|
|
#include "torch/csrc/jit/passes/onnx/fixup_onnx_loop.h"
|
|
#include "torch/csrc/jit/passes/shape_analysis.h"
|
|
#include "torch/csrc/jit/passes/decompose_addmm.h"
|
|
#include "torch/csrc/jit/passes/constant_propagation.h"
|
|
#include "torch/csrc/jit/passes/loop_unrolling.h"
|
|
#include "torch/csrc/jit/passes/to_batch.h"
|
|
#include "torch/csrc/jit/passes/specialize_undef.h"
|
|
#include "torch/csrc/jit/graph_executor.h"
|
|
#include "torch/csrc/jit/script/init.h"
|
|
#include "torch/csrc/jit/script/python_tree_views.h"
|
|
#include "torch/csrc/jit/batched/BatchTensor.h"
|
|
#include "torch/csrc/jit/pybind_utils.h"
|
|
#include "torch/csrc/jit/function_schema.h"
|
|
#include "torch/csrc/jit/serialization.h"
|
|
#include "torch/csrc/jit/operator.h"
|
|
|
|
#include <pybind11/functional.h>
|
|
|
|
#include <memory>
|
|
#include <sstream>
|
|
#include <stdexcept>
|
|
#include <string>
|
|
#include <tuple>
|
|
#include <utility>
|
|
|
|
namespace torch { namespace jit {
|
|
|
|
namespace {
|
|
|
|
using autograd::variable_list;
|
|
|
|
// Implementation behind the `_jit_init` Python binding.
//
// Nothing needs to be loaded from Python at the moment, so this always
// reports success. The sketch below is kept on purpose because it will
// likely be useful again at some point:
//
//   PyObject *jit_module = PyImport_ImportModule("torch.jit");
//   THPUtils_assert(jit_module, "class loader couldn't access "
//                               "torch.jit module");
//   PyObject *jit_dict = PyModule_GetDict(jit_module);
bool loadPythonClasses() {
  constexpr bool kLoaded = true;
  return kLoaded;
}
|
|
|
|
} // anonymous namespace
|
|
|
|
extern std::string runJITCPPTests();
|
|
|
|
void initJITBindings(PyObject *module) {
|
|
auto m = py::handle(module).cast<py::module>();
|
|
|
|
py::class_<python::IODescriptor>(m, "IODescriptor");
|
|
|
|
m.def("_jit_init", loadPythonClasses)
|
|
.def("_jit_pass_onnx", ToONNX)
|
|
.def("_jit_pass_onnx_peephole", PeepholeOptimizeONNX)
|
|
.def("_jit_pass_fuse", FuseGraph)
|
|
.def("_jit_pass_dce", [](std::shared_ptr<Graph>& g) {
|
|
return EliminateDeadCode(g); // overload resolution
|
|
})
|
|
.def("_jit_pass_cse", [](std::shared_ptr<Graph>& g) {
|
|
return EliminateCommonSubexpression(g); // overload resolution
|
|
})
|
|
.def("_jit_pass_peephole", PeepholeOptimize)
|
|
.def("_jit_pass_canonicalize", [](const std::shared_ptr<Graph>& g) {
|
|
return Canonicalize(g);
|
|
})
|
|
.def("_jit_pass_lint", LintGraph)
|
|
.def("_jit_pass_shape_analysis", [](Graph& graph, py::tuple inputs, bool with_grad) {
|
|
PropagateInputShapes(graph, ArgumentSpec(with_grad, evilDeprecatedBadCreateStackDoNotUse(inputs, graph.inputs())));
|
|
})
|
|
.def("_jit_pass_complete_shape_analysis", [](Graph& graph, py::tuple inputs, bool with_grad) {
|
|
PropagateInputShapes(graph, CompleteArgumentSpec(with_grad, evilDeprecatedBadCreateStackDoNotUse(inputs, graph.inputs())));
|
|
})
|
|
.def("_jit_pass_remove_expands", RemoveExpands)
|
|
.def("_jit_pass_erase_number_types", EraseNumberTypes)
|
|
.def("_jit_pass_loop_unrolling", UnrollLoops)
|
|
.def("_jit_pass_constant_propagation", [](std::shared_ptr<Graph>& g) {
|
|
return ConstantPropagation(g);
|
|
})
|
|
.def("_jit_run_cpp_tests", [] {
|
|
// We have to release the GIL inside this method, because if we happen to
|
|
// initialize the autograd engine in these tests, the newly spawned worker threads will
|
|
// try to initialize their PyThreadState*, and they need the GIL for this.
|
|
AutoNoGIL _no_gil;
|
|
return runJITCPPTests();
|
|
})
|
|
.def("_jit_flatten", [](py::handle& obj) {
|
|
auto res = python::flatten(obj);
|
|
return std::make_pair(res.vars, res.desc);
|
|
})
|
|
.def("_jit_unflatten", [](autograd::variable_list vars, python::IODescriptor& desc) {
|
|
return py::reinterpret_steal<py::object>(python::unflatten(vars, desc));
|
|
})
|
|
.def("_jit_pass_onnx_block", BlockToONNX)
|
|
.def("_jit_pass_fixup_onnx_loops", FixupONNXLoops)
|
|
.def("_jit_pass_decompose_addmm", DecomposeAddmm)
|
|
.def("_jit_pass_specialize_undef", specializeUndef)
|
|
.def("_jit_differentiate", [](Graph &g, const std::vector<bool>& requires_grad) {
|
|
// the python binding slightly differs in semantics
|
|
// it makes a copy of the input Graph, and works on that
|
|
// jit::differentiate mutates the input Graph
|
|
auto g_clone = g.copy();
|
|
return differentiate(g_clone, requires_grad);
|
|
});
|
|
|
|
py::class_<CompleteArgumentSpec>(m, "CompleteArgumentSpec")
|
|
.def("__repr__", [](CompleteArgumentSpec& self) {
|
|
std::ostringstream s;
|
|
s << self;
|
|
return s.str();
|
|
});
|
|
py::class_<ArgumentSpec>(m, "ArgumentSpec");
|
|
py::class_<Code>(m, "Code")
|
|
.def("executors", [](Code& c) {
|
|
return py::make_iterator(c.executors().begin(), c.executors().end());
|
|
});
|
|
|
|
py::class_<ExecutionPlanState>(m, "ExecutionPlanState")
|
|
.def_property_readonly("graph", [](ExecutionPlanState& s) {
|
|
return s.graph;
|
|
})
|
|
.def_property_readonly("code", [](ExecutionPlanState& s) {
|
|
return s.f;
|
|
})
|
|
.def_property_readonly("grad_executor", [](ExecutionPlanState& s) {
|
|
return s.grad_executor.get();
|
|
});
|
|
|
|
py::class_<Gradient>(m, "Gradient")
|
|
.def_property_readonly("f", [](Gradient& m) {
|
|
return m.f;
|
|
})
|
|
.def_property_readonly("df", [](Gradient& m) {
|
|
return m.df;
|
|
})
|
|
.def_property_readonly("f_real_outputs", [](Gradient& m) {
|
|
return m.f_real_outputs;
|
|
})
|
|
.def_property_readonly("df_input_vjps", [](Gradient& m) {
|
|
return m.df_input_vjps;
|
|
})
|
|
.def_property_readonly("df_input_captured_inputs", [](Gradient& m) {
|
|
return m.df_input_captured_inputs;
|
|
})
|
|
.def_property_readonly("df_input_captured_outputs", [](Gradient& m) {
|
|
return m.df_input_captured_outputs;
|
|
})
|
|
.def_property_readonly("df_output_vjps", [](Gradient& m) {
|
|
return m.df_output_vjps;
|
|
});
|
|
|
|
py::class_<GraphExecutorState>(m, "GraphExecutorState")
|
|
.def_property_readonly("graph", [](GraphExecutorState& s) {
|
|
return s.graph;
|
|
})
|
|
.def_property_readonly("execution_plans", [](GraphExecutorState& s) {
|
|
return s.execution_plans;
|
|
})
|
|
.def_property_readonly("autograd_fallback", [](GraphExecutorState& s) {
|
|
return s.autograd_fallback;
|
|
})
|
|
.def_property_readonly("autograd_fallback_graph", [](GraphExecutorState& s) {
|
|
return s.autograd_fallback_graph;
|
|
});
|
|
|
|
py::class_<GraphExecutor>(m, "GraphExecutor", py::dynamic_attr())
|
|
.def(
|
|
py::init([](py::function func,
|
|
py::tuple inputs,
|
|
bool optimize) {
|
|
auto graph = tracer::createGraphByTracing(func, toStack(inputs));
|
|
return GraphExecutor(graph, optimize);
|
|
}),
|
|
py::arg("func"),
|
|
py::arg("inputs"),
|
|
py::arg("optimize") = true)
|
|
.def(
|
|
py::init([](std::shared_ptr<Graph> graph, bool optimize) {
|
|
return GraphExecutor(std::move(graph), optimize);
|
|
}),
|
|
py::arg("graph"),
|
|
py::arg("optimize") = true)
|
|
.def("graph_for", [](GraphExecutor& ge, py::args args) {
|
|
return ge.graphFor(evilDeprecatedBadCreateStackDoNotUse(args, ge.graph()->inputs()));
|
|
})
|
|
.def_property_readonly("graph", [](GraphExecutor& ge) {
|
|
return ge.graph();
|
|
})
|
|
.def("get_debug_state", [](GraphExecutor& ge) {
|
|
return ge.getDebugState();
|
|
})
|
|
.def("__call__", [](GraphExecutor& ge, py::args args) -> py::object {
|
|
const auto & graph = ge.graph();
|
|
auto stack = evilDeprecatedBadCreateStackDoNotUse(args, graph->inputs());
|
|
ge.run(stack);
|
|
return createPyObjectForStack(std::move(stack));
|
|
});
|
|
|
|
|
|
py::class_<PyTorchFileWriter>(m, "PyTorchFileWriter")
|
|
.def(py::init<std::string>())
|
|
.def("write_record", &PyTorchFileWriter::writeRecord)
|
|
.def("write_end_of_file", &PyTorchFileWriter::writeEndOfFile);
|
|
|
|
py::class_<PyTorchFileReader>(m, "PyTorchFileReader")
|
|
.def(py::init<std::string>())
|
|
.def("get_record_with_key", [](PyTorchFileReader &self, uint64_t key) {
|
|
std::shared_ptr<void> data;
|
|
size_t size;
|
|
std::tie(data, size) = self.getRecordWithKey(key);
|
|
return py::bytes(reinterpret_cast<const char*>(data.get()), size);
|
|
})
|
|
.def("get_last_record", [](PyTorchFileReader &self){
|
|
std::shared_ptr<void> data;
|
|
size_t size;
|
|
std::tie(data, size) = self.getLastRecord();
|
|
return py::bytes(reinterpret_cast<const char*>(data.get()), size);
|
|
});
|
|
|
|
m.def("_jit_get_operation", [](const std::string& qualified_name) {
|
|
try {
|
|
auto symbol = Symbol::fromQualString(qualified_name);
|
|
auto operations = getAllOperatorsFor(std::move(symbol));
|
|
AT_CHECK(!operations.empty(), "No such operator ", qualified_name);
|
|
AT_CHECK(
|
|
operations.size() == 1,
|
|
"Found ", operations.size(), " overloads for operator ",
|
|
qualified_name, "! Overloads are not supported from Python.");
|
|
std::shared_ptr<Operator> op = operations[0];
|
|
AT_ASSERT(op != nullptr);
|
|
std::ostringstream docstring;
|
|
docstring << "Automatically bound operator '" << qualified_name
|
|
<< "' with schema: " << op->schema();
|
|
return py::cpp_function([op](py::args args, py::kwargs kwargs) {
|
|
return invokeOperatorFromPython(
|
|
*op, std::move(args), std::move(kwargs));
|
|
}, py::name(qualified_name.c_str()), py::doc(docstring.str().c_str()));
|
|
} catch (const at::Error& error) {
|
|
throw std::runtime_error(error.what_without_backtrace());
|
|
}
|
|
}, py::arg("qualified_name"));
|
|
|
|
initPythonIRBindings(module);
|
|
tracer::initPythonTracerBindings(module);
|
|
script::initTreeViewBindings(module);
|
|
script::initJitScriptBindings(module);
|
|
initBatchTensorBindings(module);
|
|
initRegisterBatchOpsBindings(module);
|
|
}
|
|
|
|
}}
|