Addresses #8177. A design doc can be found here: [gist](https://gist.github.com/zou3519/4b7f13f03cc9f3612bd9363e6405fa0a) version or [quip](https://fb.quip.com/azL1AqUckBdo) version.

General approach:
- Add NumberType, FloatType, IntType to represent Python numbers, floats and ints.
- Emit these types for Python literals.
- Change aten_schema such that Scalars are NumberType, and int64_t and bool are IntType.
- Emit aten::type_as, prim::NumToTensor, and prim::TensorToNum nodes for tensor-number math (see examples below).
- Erase NumberType, prim::NumToTensor, and prim::TensorToNum for ONNX export.

### Tensor/number math

```
import torch

@torch.jit.script
def fn(x):
    return x + 1
```

```
graph(%x : Dynamic) {
  %1 : int = prim::Constant[value={1}]()
  %2 : Dynamic = prim::NumToTensor(%1)
  %3 : Dynamic = aten::type_as(%2, %x)
  %4 : Dynamic = aten::add[alpha={1}](%x, %3)
  return (%4);
}
```

### Number/Number Math

```
import torch

@torch.jit.script
def fn(zero):
    c = 1 + 1
    return zero + c
```

```
graph(%zero : Dynamic) {
  %1 : int = prim::Constant[value={1}]()
  %2 : int = prim::Constant[value={1}]()
  %3 : Dynamic = prim::NumToTensor(%1)
  %4 : Dynamic = prim::NumToTensor(%2)
  %5 : Dynamic = aten::add[alpha={1}](%3, %4)
  %c : int = prim::TensorToNum(%6) # this is the result of the addition
  ...
  return (%13);
}
```

List of squashed commits:

* Introduce Python Number types. Added IntType, FloatType, and NumberType, with IntType <: NumberType and FloatType <: NumberType. Changed aten_schema so arguments have corresponding types.
* Emit a NumberType for Python literals. Also emit a NumberType for Scalar default values.
* Add prim::NumToTensor and prim::TensorToNum.
* Add a DynamicType -> NumberType implicit cast for backwards compatibility.
* Better ensureTensor error message.
* Add ensureTensorOrNumber. Allow passing a Number to some constructs, like range() and slices.
* Patch IntList to work. IntList is still a DynamicType in the frontend: a tensor gets built from a List[int]. Also, IntList[1] is implemented as a union between int and IntList: if the frontend sees an int being passed for an IntList[1] argument, it converts it to a tensor as well.
* Enforce some order on schemas to avoid overload ambiguity. add(Tensor, Tensor) should appear earlier than add(Tensor, Scalar). This matches the order in which python_arg_parser parses its arguments.
* Disable std_dim and var_dim tests. With the new schema information, std(input, keepdim) and std(input, dim) are ambiguous. This will need to be fixed at a later date.
* Add a NumberType erasure pass. This is used for ONNX export and to ensure that NumberType information doesn't reach the interpreter.
* Add support for mixed tensor/number math ops.
* Tests for new functionality, including tensor/number math, number/number math, and an EraseNumberTypes pass test.
* Patch tests. Update expect tests for decompose_addmm and the loop unrolling tests. Because Python numbers are now NumberType, they can no longer be returned by functions. Work around this by using torch.full, or by adding a tensor([0]) (taken from FIXME_zerol()). Both approaches are used because torch.full is more readable, but it is broken in some cases.
* Add erase_number_types to torch/CMakeLists.txt.
* Move math back to emitSimpleExpr from emitSugaredExpr.
* Remove some dead lines.
* Re-enable some excluded script/trace tests that are now fixed.
* Move some tests to expected failure.
* Address some comments (more addressing to come).
* Erase relevant aten::type_as nodes in EraseNumberTypes. Also change things so that EraseNumberTypes is only called for ONNX export. It is no longer used to prevent prim::NumToTensor/prim::TensorToNum from reaching shape_analysis or interpreter.cpp: shape_analysis infers the type of the output of these nodes to be the same as their input, and interpreter.cpp treats both of these nodes as no-ops.
* Add reminder to fix std/var.
* Call EraseNumberTypes only when exporting a script module.
* Update expects after rebase.
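For reference, these passes are surfaced to Python through the `_jit_pass_*` bindings registered in the file below (`_jit_pass_erase_number_types`, `_jit_pass_lint`, and so on). The following is a minimal sketch of how the lowering could be inspected and the erasure pass run by hand, mimicking what ONNX export does; it assumes the scripted function exposes its IR through a `.graph` attribute and that the pass bindings mutate that graph in place.

```
import torch

@torch.jit.script
def fn(x):
    return x + 1

# The scripted graph should show the prim::NumToTensor / aten::type_as
# lowering for the literal 1, as in the dump above.
# (Assumes the scripted function exposes its IR as `fn.graph`.)
graph = fn.graph
print(graph)

# Mimic what ONNX export does: erase NumberType information, then lint.
# Both bindings are registered in initJITBindings below and act on the
# graph in place.
torch._C._jit_pass_erase_number_types(graph)
torch._C._jit_pass_lint(graph)
print(graph)
```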
#include "torch/csrc/utils/pybind.h"
|
|
|
|
#include "torch/csrc/jit/python_tracer.h"
|
|
#include "torch/csrc/jit/tracer.h"
|
|
#include "torch/csrc/jit/python_ir.h"
|
|
#include "torch/csrc/jit/python_arg_flatten.h"
|
|
#include "torch/csrc/jit/export.h"
|
|
#include "torch/csrc/jit/argument_spec.h"
|
|
#include "torch/csrc/jit/passes/graph_fuser.h"
|
|
#include "torch/csrc/jit/passes/onnx.h"
|
|
#include "torch/csrc/jit/passes/dead_code_elimination.h"
|
|
#include "torch/csrc/jit/passes/erase_number_types.h"
|
|
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
|
|
#include "torch/csrc/jit/passes/peephole.h"
|
|
#include "torch/csrc/jit/passes/canonicalize.h"
|
|
#include "torch/csrc/jit/passes/onnx/peephole.h"
|
|
#include "torch/csrc/jit/passes/onnx/fixup_onnx_loop.h"
|
|
#include "torch/csrc/jit/passes/shape_analysis.h"
|
|
#include "torch/csrc/jit/passes/decompose_addmm.h"
|
|
#include "torch/csrc/jit/passes/loop_unrolling.h"
|
|
#include "torch/csrc/jit/graph_executor.h"
|
|
#include "torch/csrc/jit/script/init.h"
|
|
#include "torch/csrc/jit/script/python_tree_views.h"
|
|
#include "torch/csrc/jit/python_interpreter.h"
|
|
|
|
|
|
namespace torch { namespace jit {
|
|
|
|
namespace {
|
|
|
|
using autograd::variable_list;
|
|
|
|
bool loadPythonClasses() {
|
|
// Leaving this code here, because it will likely be useful at some point
|
|
//PyObject *jit_module = PyImport_ImportModule("torch.jit");
|
|
//THPUtils_assert(jit_module, "class loader couldn't access "
|
|
//"torch.jit module");
|
|
//PyObject *jit_dict = PyModule_GetDict(jit_module);
|
|
|
|
return true;
|
|
}
|
|
|
|
// we cannot use the default py:cast<autograd::Variable> because it currently
|
|
// unwraps the data tensor in the conversion process
|
|
// TODO: replace with bs type
|
|
variable_tensor_list createVariableTensorList(py::tuple tuple, size_t reserve_extra_space = 0) {
|
|
variable_tensor_list result;
|
|
result.reserve(tuple.size() + reserve_extra_space);
|
|
for(auto e : tuple) {
|
|
result.push_back(py::cast<autograd::Variable>(e));
|
|
}
|
|
return result;
|
|
}
|
|
|
|
} // anonymous namespace
|
|
|
|
extern std::string runJITCPPTests();
|
|
|
|
void initJITBindings(PyObject *module) {
|
|
auto m = py::handle(module).cast<py::module>();
|
|
|
|
py::class_<python::IODescriptor>(m, "IODescriptor");
|
|
|
|
m.def("_jit_init", loadPythonClasses)
|
|
.def("_jit_pass_onnx", ToONNX)
|
|
.def("_jit_pass_onnx_peephole", PeepholeOptimizeONNX)
|
|
.def("_jit_pass_fuse", FuseGraph)
|
|
.def("_jit_pass_dce", [](std::shared_ptr<Graph>& g){
|
|
return EliminateDeadCode(g); // overload resolution
|
|
})
|
|
.def("_jit_pass_cse", EliminateCommonSubexpression)
|
|
.def("_jit_pass_peephole", PeepholeOptimize)
|
|
.def("_jit_pass_canonicalize", [](const std::shared_ptr<Graph>& g) {
|
|
return Canonicalize(g);
|
|
})
|
|
.def("_jit_pass_lint", LintGraph)
|
|
.def("_jit_pass_shape_analysis", [](Graph& graph, py::tuple inputs, bool with_grad) {
|
|
auto tensor_inputs = createVariableTensorList(inputs);
|
|
PropagateInputShapes(graph, ArgumentSpec(with_grad, tensor_inputs));
|
|
})
|
|
.def("_jit_pass_erase_number_types", EraseNumberTypes)
|
|
.def("_jit_pass_loop_unrolling", UnrollLoops)
|
|
.def("_jit_run_cpp_tests", [] {
|
|
// We have to release the GIL inside this method, because if we happen to
|
|
// initialize the autograd engine in these tests, the newly spawned worker threads will
|
|
// try to initialize their PyThreadState*, and they need the GIL for this.
|
|
AutoNoGIL _no_gil;
|
|
return runJITCPPTests();
|
|
})
|
|
.def("_jit_flatten", [](py::handle& obj) {
|
|
auto res = python::flatten(obj);
|
|
return std::make_pair(res.vars, res.desc);
|
|
})
|
|
.def("_jit_unflatten", [](autograd::variable_list vars, python::IODescriptor& desc) {
|
|
return py::reinterpret_steal<py::object>(python::unflatten(vars, desc));
|
|
})
|
|
.def("_jit_pass_onnx_block", BlockToONNX)
|
|
.def("_jit_pass_fixup_onnx_loops", FixupONNXLoops)
|
|
.def("_jit_pass_decompose_addmm", DecomposeAddmm);
|
|
|
|
py::class_<ArgumentSpec>(m, "ArgumentSpec")
|
|
.def("__repr__", [](ArgumentSpec& self) {
|
|
std::ostringstream s;
|
|
s << self;
|
|
return s.str();
|
|
});
|
|
py::class_<Code>(m, "Code")
|
|
.def("executors", [](Code& c) {
|
|
return py::make_iterator(c.executors().begin(), c.executors().end());
|
|
});
|
|
|
|
py::class_<ExecutionPlanState>(m, "ExecutionPlanState")
|
|
.def_property_readonly("graph", [](ExecutionPlanState& s) {
|
|
return s.graph;
|
|
})
|
|
.def_property_readonly("code", [](ExecutionPlanState& s) {
|
|
return s.f;
|
|
})
|
|
.def_property_readonly("grad_executor", [](ExecutionPlanState& s) {
|
|
return s.grad_executor.get();
|
|
});
|
|
|
|
py::class_<GraphExecutorState>(m, "GraphExecutorState")
|
|
.def_property_readonly("graph", [](GraphExecutorState& s) {
|
|
return s.graph;
|
|
})
|
|
.def_property_readonly("execution_plans", [](GraphExecutorState& s) {
|
|
return s.execution_plans;
|
|
})
|
|
.def_property_readonly("autograd_fallback", [](GraphExecutorState& s) {
|
|
return s.autograd_fallback;
|
|
})
|
|
.def_property_readonly("autograd_fallback_graph", [](GraphExecutorState& s) {
|
|
return s.autograd_fallback_graph;
|
|
});
|
|
|
|
py::class_<GraphExecutor>(m, "GraphExecutor", py::dynamic_attr())
|
|
.def(
|
|
py::init([](py::function func,
|
|
variable_list inputs,
|
|
bool optimize) {
|
|
size_t num_inputs = inputs.size();
|
|
auto graph = tracer::createGraphByTracing(func, std::move(inputs), num_inputs);
|
|
return GraphExecutor(graph, optimize);
|
|
}),
|
|
py::arg("func"),
|
|
py::arg("inputs"),
|
|
py::arg("optimize") = true)
|
|
.def(
|
|
py::init([](std::shared_ptr<Graph> graph, bool optimize) {
|
|
return GraphExecutor(std::move(graph), optimize);
|
|
}),
|
|
py::arg("graph"),
|
|
py::arg("optimize") = true)
|
|
.def_property_readonly("graph", [](GraphExecutor& ge) {
|
|
return ge.graph();
|
|
})
|
|
.def("graph_for", [](GraphExecutor& ge, py::args args) {
|
|
return ge.graphFor(createVariableTensorList(args));
|
|
})
|
|
.def("get_debug_state", [](GraphExecutor& ge) {
|
|
return ge.getDebugState();
|
|
})
|
|
.def("__call__", [](GraphExecutor& ge, py::args args) -> py::object {
|
|
auto inputs = createVariableTensorList(args);
|
|
auto outputs = ge.run(std::move(inputs));
|
|
// if we don't tell pybind these are variables it chokes on the
|
|
// conversion.
|
|
// TODO: fix conversions to be sane and make sure this works.
|
|
if (outputs.size() == 0) {
|
|
return py::none();
|
|
} else if (outputs.size() == 1) {
|
|
return py::cast(autograd::as_variable_ref(outputs[0]));
|
|
} else {
|
|
py::tuple tuple(outputs.size());
|
|
for(size_t i = 0; i < outputs.size(); i++) {
|
|
tuple[i] = py::cast(autograd::as_variable_ref(outputs[i]));
|
|
}
|
|
return tuple;
|
|
}
|
|
});
|
|
|
|
initPythonIRBindings(module);
|
|
tracer::initPythonTracerBindings(module);
|
|
script::initTreeViewBindings(module);
|
|
script::initJitScriptBindings(module);
|
|
registerPythonInterpreterOps();
|
|
}
|
|
|
|
}}
|
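As a rough usage illustration (not taken from the PR), the `GraphExecutor` class registered above can be constructed from Python by tracing a function and then invoked directly. This is a sketch under the assumption that the module passed to `initJITBindings` is `torch._C` and that tracing a simple single-tensor function succeeds as written.

```
import torch

# initJITBindings registers GraphExecutor on the extension module,
# here assumed to be reachable as torch._C.GraphExecutor.
def f(x):
    return x * 2 + 1

example_inputs = [torch.randn(3, requires_grad=True)]

# First overload of the constructor: (func, inputs, optimize);
# optimize defaults to True.
ge = torch._C.GraphExecutor(f, example_inputs, True)

# The traced IR is exposed as a read-only `graph` property.
print(ge.graph)

# __call__ runs the executor on fresh inputs and returns a Variable
# (or a tuple of Variables when there are multiple outputs).
out = ge(torch.randn(3, requires_grad=True))
print(out)
```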