mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/9718 This patch switches the interpreter to use IValue's primitive numbers rather than tensors for computing on integers and floats. In addition to preparing the interpreter for first-class support of other types, this cleans up the handling of primitive numbers, making it possible to just use the normal operator overloading dispatch to find the right implementation for numbers. As a result of this change, a lot of other functionality needed to be updated, since this was the first time non-tensors were used in many places in the code base. Notes: * Fixes code_template.py so that multi-line strings are indented correctly when used on a standalone line * Cast operators (`int(x)`) are now functional. Some tests have additional conversions to integers because we no longer allow implicit tensor -> integer conversions, following the same convention as Python * prim::ListConstruct/createList has been added to the interpreter for creating lists, and it has replaced aten::stack for integer lists * gen_jit_dispatch.py has been refactored so that non-tensor types use operators on IValues to extract the primitives * IValue gains a .to<T> method that is the equivalent of tensor_as but for IValue instead of at::Tensor * `constant_as<T>` is switched over to using IValue's `.to<T>` method, to make conversion from constant->IValue->C++ type more consistent. This functionality combined with `toIValue(Value*)` replaces the `tensor_as` and `as_tensor` family of functions. * Conditional expressions (if, loop) and operators related to them are now computed on integers rather than tensors * IValue gains constructors for constructing from at::Scalar and converting to it. However, IValue itself will always store the scalars as a double or int64. 
* To align with Python 3 syntax, TK_INT, TK_FLOAT, and TK_BOOL have been removed from the parser, and int/float/bool are just treated as special identifiers in the compiler, along with print. These are represented as special sugared values with a `call` method implemented. For int/float/bool this implements casting behavior. * Dropped shared_from_this from Type/Module. They were not needed, and they made debugging harder because they internally throw/catch exceptions. * Shape propagation has been updated to support running nodes that include floating-point primitive types; this required some refactoring of internal functions. * TensorToNum and NumToTensor now have actual implementations as operators * register_prim_ops now contains implementations of math operators for float/int primitive types, and for mixed (prim <+> tensor) versions. This removes the need for special handling in compiler.cpp * Primitive math is now entirely handled by letting the compiler choose the right overloads. This removes tons of special casing in the compiler. * Incorporates eellison's change to allow casting from return values. Due to the addition of primitive support, the code needed slight modifications, so I just pre-merged it here. * stack.h gains generic vararg versions of push/pop that know how to convert to/from C++ types: `at::Tensor a; at::Scalar b; pop(stack, a, b); at::Tensor c = a + b; push(stack, c);` apaszke Pull Request resolved: https://github.com/pytorch/pytorch/pull/9584 Reviewed By: apaszke Differential Revision: D8910546 Pulled By: zdevito fbshipit-source-id: 0f3e60d4d22217f196a8f606549430e43b7e7e30
94 lines
2.7 KiB
C++
94 lines
2.7 KiB
C++
#include "torch/csrc/python_headers.h"
|
|
#include "torch/csrc/jit/interpreter.h"
|
|
|
|
#include "torch/csrc/autograd/edge.h"
|
|
#include "torch/csrc/autograd/function.h"
|
|
#include "torch/csrc/autograd/profiler.h"
|
|
#include "torch/csrc/autograd/variable.h"
|
|
#include "torch/csrc/jit/fusion_compiler.h"
|
|
#include "torch/csrc/jit/operator.h"
|
|
#include "torch/csrc/jit/graph_executor.h"
|
|
#include "torch/csrc/jit/ir.h"
|
|
|
|
#include "torch/csrc/variable_tensor_functions.h"
|
|
|
|
#include <typeinfo>
|
|
|
|
#include "torch/csrc/autograd/python_engine.h"
|
|
#include "torch/csrc/autograd/python_variable.h"
|
|
#include "torch/csrc/jit/pybind.h"
|
|
#include "torch/csrc/utils/auto_gil.h"
|
|
|
|
namespace py = pybind11;
|
|
|
|
namespace torch { namespace jit {
|
|
|
|
namespace {
|
|
|
|
Operation createPythonOperation(Node* op_) {
|
|
PythonOp* op = static_cast<PythonOp*>(op_);
|
|
py::function func = py::reinterpret_borrow<py::function>(py::handle(op->pyobj.get()));
|
|
size_t num_inputs = 0;
|
|
for(auto arg_type : op->cconv) {
|
|
if(arg_type == 't')
|
|
num_inputs++;
|
|
}
|
|
return [=](Stack & stack) {
|
|
AutoGIL gil;
|
|
py::tuple py_inputs(op->cconv.size());
|
|
size_t i = 0;
|
|
size_t next_scalar = 0;
|
|
size_t next_tensor = 0;
|
|
for (auto arg_type : op->cconv) {
|
|
if (arg_type == 's') {
|
|
py_inputs[i] = py::reinterpret_borrow<py::object>(
|
|
op->scalar_args[next_scalar++].get());
|
|
} else if (arg_type == 't') {
|
|
auto var = std::move(peek(stack, next_tensor, num_inputs)).toTensor();
|
|
py_inputs[i] =
|
|
py::reinterpret_steal<py::object>(THPVariable_Wrap(var));
|
|
next_tensor++;
|
|
}
|
|
i++;
|
|
}
|
|
drop(stack, num_inputs);
|
|
py::object py_outputs(func(*py_inputs));
|
|
|
|
auto num_outputs = op->outputs().size();
|
|
auto addOutput = [&](py::handle entry) {
|
|
if (!THPVariable_Check(entry.ptr())) {
|
|
throw std::runtime_error(
|
|
"Function application returned a non-Variable output");
|
|
}
|
|
THPVariable* var = (THPVariable*)entry.ptr();
|
|
auto cdata = var->cdata;
|
|
stack.push_back(std::move(cdata));
|
|
};
|
|
|
|
if (!PyTuple_Check(py_outputs.ptr())) {
|
|
if (num_outputs != 1) {
|
|
throw std::runtime_error(
|
|
"Function.apply returned the wrong number of outputs.");
|
|
}
|
|
addOutput(py_outputs);
|
|
} else {
|
|
auto output_tuple = py::tuple(py_outputs);
|
|
if (output_tuple.size() != num_outputs) {
|
|
throw std::runtime_error(
|
|
"Function application returned the wrong number of outputs.");
|
|
}
|
|
for (py::handle entry : py::tuple(py_outputs)) {
|
|
addOutput(entry);
|
|
}
|
|
}
|
|
return 0;
|
|
};
|
|
}
|
|
|
|
|
|
RegisterOperators reg({
|
|
Operator(prim::PythonOp, createPythonOperation)
|
|
});
|
|
|
|
}}} // torch::jit::anon
|