pytorch/torch/csrc/jit/init.cpp
Richard Zou 8489c4cc6e
Better support for literals in jit script (#8687)
Addresses #8177

A design doc can be found here: [gist](https://gist.github.com/zou3519/4b7f13f03cc9f3612bd9363e6405fa0a) version or [quip](https://fb.quip.com/azL1AqUckBdo) version

General approach:
- Add NumberType, FloatType, IntType to represent Python numbers, floats and ints.
- Emit these types for python literals
- Change aten_schema so that Scalars are NumberType, and int64_t and bool arguments are IntType.
- Emit aten::type_as, prim::NumToTensor, and prim::TensorToNum nodes for tensor-number math. (see examples below)
- Erase NumberType, prim::NumToTensor, and prim::TensorToNum for ONNX export

### Tensor/number math
```
import torch
@torch.jit.script
def fn(x):
    return x + 1
```
```
graph(%x : Dynamic) {
  %1 : int = prim::Constant[value={1}]()
  %2 : Dynamic = prim::NumToTensor(%1)
  %3 : Dynamic = aten::type_as(%2, %x)
  %4 : Dynamic = aten::add[alpha={1}](%x, %3)
  return (%4);
}
```

### Number/Number Math
```
import torch
@torch.jit.script
def fn(zero):
    c = 1 + 1
    return zero + c
```
```
graph(%zero : Dynamic) {
  %1 : int = prim::Constant[value={1}]()
  %2 : int = prim::Constant[value={1}]()
  %3 : Dynamic = prim::NumToTensor(%1)
  %4 : Dynamic = prim::NumToTensor(%2)
  %5 : Dynamic = aten::add[alpha={1}](%3, %4)
  %c : int = prim::TensorToNum(%5)  # this is the result of the addition
  ...
  return (%13);
}
```

List of squashed commits:

* Introduce Python Number types

Added: IntType, FloatType, NumberType with
IntType <: NumberType
FloatType <: NumberType

Changed aten_schema so arguments have corresponding types
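
As a rough sketch of what this buys at the script level (illustrative example, not a test from this PR): int and float literals get IntType and FloatType, and both are accepted wherever a schema expects a Number (Scalar).

```
import torch

@torch.jit.script
def scale_and_shift(x):
    a = 2      # literal typed as IntType
    b = 0.5    # literal typed as FloatType
    # both are Numbers, so they are accepted where the schema expects a Scalar
    return x * a + b
```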

* Emit a NumberType for python literals.

Also emit a NumberType for Scalar default values.

* Add prim::NumToTensor and prim::TensorToNum

* Add DynamicType -> NumberType implicit cast for backwards compatibility

* Better ensureTensor error message

* Add ensureTensorOrNumber. Allow passing Number to some functions

For example, the range() construct and slice expressions.
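
A minimal sketch of the kind of code this enables (hypothetical example; it relies only on the range() and slicing support described here):

```
import torch

@torch.jit.script
def repeat_add(x):
    # 3 is a Number used as a loop bound; 1 is a Number used inside a slice
    for i in range(3):
        x = x + 1
    return x[1:]
```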

* Patch IntList to work.

IntList is still a DynamicType in the frontend: a tensor gets built from
a List[int].

Also, IntList[1] is a "union between int and IntList" the way it is
implemented. If the frontend sees an int being passed for an IntList[1]
arg, it converts it to a tensor as well.
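
A hedged illustration of the List[int] case (the op choice is incidental; any schema with an IntList argument behaves the same way):

```
import torch

@torch.jit.script
def flatten_pair(x):
    # the [2, -1] literal is a List[int]; the frontend packs it into a tensor
    # because IntList arguments are still DynamicType here
    return x.view([2, -1])
```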

* Enforce some order on schemas to avoid overload ambiguity

add(Tensor, Tensor) should appear earlier than add(Tensor, Scalar). This
matches the order in which python_arg_parser parses its arguments.

* Disable std_dim and var_dim tests.

With the new schema information, std(input, keepdim) and std(input, dim)
are ambiguous. This will need to be fixed at a later date.
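
Roughly, the problem (sketch only, not one of the disabled tests): with bool arguments typed as IntType, an int literal can bind to either overload.

```
import torch

@torch.jit.script
def ambiguous_std(x):
    # the literal 1 could be the int `dim` of one std overload
    # or a bool argument of the other, now that bools are IntType
    return x.std(1)
```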

* Add NumberType erasure pass.

This is used for ONNX export and to ensure that NumberType information
doesn't reach the interpreter
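
The pass is exposed through the `_jit_pass_erase_number_types` binding registered in this file. A minimal sketch of running it by hand (assuming the script function exposes its IR via `.graph`):

```
import torch

@torch.jit.script
def f(x):
    return x + 1

g = f.graph
# erases NumberType values and NumToTensor/TensorToNum nodes in place,
# as is done before ONNX export
torch._C._jit_pass_erase_number_types(g)
print(g)
```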

* Add support for mixed tensor/number math ops.

* Tests for new functionality.

Includes:
- Tensor/number math
- number/number math
- EraseNumberTypes pass test

* Patch tests

Update expect tests for:
- decompose_addmm
- loop unrolling tests

Because python numbers are now NumberType, they can no longer be returned by
functions. Work around this by using "torch.full", or by adding a
tensor([0]) (taken from FIXME_zerol()). Both approaches are used:
torch.full is more readable, but it is broken in some cases.
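
For instance, a test that previously ended in `return 1` can keep returning a tensor along these lines (illustrative sketch, not one of the updated expect tests):

```
import torch

@torch.jit.script
def returns_one(x):
    # return a one-element tensor instead of a bare python number
    return torch.full([1], 1.0)
```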

* Add erase_number_types to torch/CMakeLists.txt

* Move math back to emitSimpleExpr from emitSugaredExpr

* Remove some dead lines

* Re-enable some excluded script/trace tests that are fixed.

* Move some tests to expected failure

* Address some comments (more addressing to come)

* Erase relevant aten::type_as nodes in EraseNumberTypes

I also changed it so that EraseNumberTypes is only called for ONNX
export. It is no longer used to prevent
prim::NumToTensor/prim::TensorToNum from reaching shape_analysis or
interpreter.cpp.

shape_analysis infers the type of the output of these nodes to be the
same as their input.

interpreter.cpp treats both of these nodes as no-ops.

* Add reminder to fix std/var

* Call EraseNumberTypes only when exporting a script module

* Update expects after rebase
2018-06-21 15:43:38 -04:00


#include "torch/csrc/utils/pybind.h"
#include "torch/csrc/jit/python_tracer.h"
#include "torch/csrc/jit/tracer.h"
#include "torch/csrc/jit/python_ir.h"
#include "torch/csrc/jit/python_arg_flatten.h"
#include "torch/csrc/jit/export.h"
#include "torch/csrc/jit/argument_spec.h"
#include "torch/csrc/jit/passes/graph_fuser.h"
#include "torch/csrc/jit/passes/onnx.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"
#include "torch/csrc/jit/passes/erase_number_types.h"
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
#include "torch/csrc/jit/passes/peephole.h"
#include "torch/csrc/jit/passes/canonicalize.h"
#include "torch/csrc/jit/passes/onnx/peephole.h"
#include "torch/csrc/jit/passes/onnx/fixup_onnx_loop.h"
#include "torch/csrc/jit/passes/shape_analysis.h"
#include "torch/csrc/jit/passes/decompose_addmm.h"
#include "torch/csrc/jit/passes/loop_unrolling.h"
#include "torch/csrc/jit/graph_executor.h"
#include "torch/csrc/jit/script/init.h"
#include "torch/csrc/jit/script/python_tree_views.h"
#include "torch/csrc/jit/python_interpreter.h"
namespace torch { namespace jit {
namespace {
using autograd::variable_list;
bool loadPythonClasses() {
  // Leaving this code here, because it will likely be useful at some point
  //PyObject *jit_module = PyImport_ImportModule("torch.jit");
  //THPUtils_assert(jit_module, "class loader couldn't access "
  //"torch.jit module");
  //PyObject *jit_dict = PyModule_GetDict(jit_module);
  return true;
}

// we cannot use the default py::cast<autograd::Variable> because it currently
// unwraps the data tensor in the conversion process
// TODO: replace with bs type
variable_tensor_list createVariableTensorList(py::tuple tuple, size_t reserve_extra_space = 0) {
  variable_tensor_list result;
  result.reserve(tuple.size() + reserve_extra_space);
  for(auto e : tuple) {
    result.push_back(py::cast<autograd::Variable>(e));
  }
  return result;
}
} // anonymous namespace
extern std::string runJITCPPTests();
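// Registers the JIT Python bindings on the given module: graph-level passes
// (including the new _jit_pass_erase_number_types), plus the ArgumentSpec,
// Code, GraphExecutor, and debug-state classes.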
void initJITBindings(PyObject *module) {
  auto m = py::handle(module).cast<py::module>();

  py::class_<python::IODescriptor>(m, "IODescriptor");

  m.def("_jit_init", loadPythonClasses)
    .def("_jit_pass_onnx", ToONNX)
    .def("_jit_pass_onnx_peephole", PeepholeOptimizeONNX)
    .def("_jit_pass_fuse", FuseGraph)
    .def("_jit_pass_dce", [](std::shared_ptr<Graph>& g){
      return EliminateDeadCode(g); // overload resolution
    })
    .def("_jit_pass_cse", EliminateCommonSubexpression)
    .def("_jit_pass_peephole", PeepholeOptimize)
    .def("_jit_pass_canonicalize", [](const std::shared_ptr<Graph>& g) {
      return Canonicalize(g);
    })
    .def("_jit_pass_lint", LintGraph)
    .def("_jit_pass_shape_analysis", [](Graph& graph, py::tuple inputs, bool with_grad) {
      auto tensor_inputs = createVariableTensorList(inputs);
      PropagateInputShapes(graph, ArgumentSpec(with_grad, tensor_inputs));
    })
    .def("_jit_pass_erase_number_types", EraseNumberTypes)
    .def("_jit_pass_loop_unrolling", UnrollLoops)
    .def("_jit_run_cpp_tests", [] {
      // We have to release the GIL inside this method, because if we happen to
      // initialize the autograd engine in these tests, the newly spawned worker threads will
      // try to initialize their PyThreadState*, and they need the GIL for this.
      AutoNoGIL _no_gil;
      return runJITCPPTests();
    })
    .def("_jit_flatten", [](py::handle& obj) {
      auto res = python::flatten(obj);
      return std::make_pair(res.vars, res.desc);
    })
    .def("_jit_unflatten", [](autograd::variable_list vars, python::IODescriptor& desc) {
      return py::reinterpret_steal<py::object>(python::unflatten(vars, desc));
    })
    .def("_jit_pass_onnx_block", BlockToONNX)
    .def("_jit_pass_fixup_onnx_loops", FixupONNXLoops)
    .def("_jit_pass_decompose_addmm", DecomposeAddmm);

  py::class_<ArgumentSpec>(m, "ArgumentSpec")
    .def("__repr__", [](ArgumentSpec& self) {
      std::ostringstream s;
      s << self;
      return s.str();
    });

  py::class_<Code>(m, "Code")
    .def("executors", [](Code& c) {
      return py::make_iterator(c.executors().begin(), c.executors().end());
    });

  py::class_<ExecutionPlanState>(m, "ExecutionPlanState")
    .def_property_readonly("graph", [](ExecutionPlanState& s) {
      return s.graph;
    })
    .def_property_readonly("code", [](ExecutionPlanState& s) {
      return s.f;
    })
    .def_property_readonly("grad_executor", [](ExecutionPlanState& s) {
      return s.grad_executor.get();
    });

  py::class_<GraphExecutorState>(m, "GraphExecutorState")
    .def_property_readonly("graph", [](GraphExecutorState& s) {
      return s.graph;
    })
    .def_property_readonly("execution_plans", [](GraphExecutorState& s) {
      return s.execution_plans;
    })
    .def_property_readonly("autograd_fallback", [](GraphExecutorState& s) {
      return s.autograd_fallback;
    })
    .def_property_readonly("autograd_fallback_graph", [](GraphExecutorState& s) {
      return s.autograd_fallback_graph;
    });

  py::class_<GraphExecutor>(m, "GraphExecutor", py::dynamic_attr())
    .def(
      py::init([](py::function func,
                  variable_list inputs,
                  bool optimize) {
        size_t num_inputs = inputs.size();
        auto graph = tracer::createGraphByTracing(func, std::move(inputs), num_inputs);
        return GraphExecutor(graph, optimize);
      }),
      py::arg("func"),
      py::arg("inputs"),
      py::arg("optimize") = true)
    .def(
      py::init([](std::shared_ptr<Graph> graph, bool optimize) {
        return GraphExecutor(std::move(graph), optimize);
      }),
      py::arg("graph"),
      py::arg("optimize") = true)
    .def_property_readonly("graph", [](GraphExecutor& ge) {
      return ge.graph();
    })
    .def("graph_for", [](GraphExecutor& ge, py::args args) {
      return ge.graphFor(createVariableTensorList(args));
    })
    .def("get_debug_state", [](GraphExecutor& ge) {
      return ge.getDebugState();
    })
    .def("__call__", [](GraphExecutor& ge, py::args args) -> py::object {
      auto inputs = createVariableTensorList(args);
      auto outputs = ge.run(std::move(inputs));
      // if we don't tell pybind these are variables it chokes on the
      // conversion.
      // TODO: fix conversions to be sane and make sure this works.
      if (outputs.size() == 0) {
        return py::none();
      } else if (outputs.size() == 1) {
        return py::cast(autograd::as_variable_ref(outputs[0]));
      } else {
        py::tuple tuple(outputs.size());
        for(size_t i = 0; i < outputs.size(); i++) {
          tuple[i] = py::cast(autograd::as_variable_ref(outputs[i]));
        }
        return tuple;
      }
    });

  initPythonIRBindings(module);
  tracer::initPythonTracerBindings(module);
  script::initTreeViewBindings(module);
  script::initJitScriptBindings(module);
  registerPythonInterpreterOps();
}
}}