Summary: Partially fixes https://github.com/pytorch/pytorch/issues/394

Implementation detail: codegen is modified to generate code that looks like the following:

```C++
static PyObject * THPVariable_svd(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    "svd(Tensor input, bool some=True, bool compute_uv=True, *, TensorList[3] out=None)",
  }, /*traceable=*/true);
  ParsedArgs<6> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);

  static PyStructSequence_Field fields0[] = {
    {"U", ""}, {"S", ""}, {"V", ""}, {nullptr}
  };
  static PyStructSequence_Desc desc0 = {
    "torch.return_types.svd_out", nullptr, fields0, 3
  };
  static PyTypeObject type0;
  static bool namedtuple_type_initialized0 = false;
  if (!namedtuple_type_initialized0) {
    PyStructSequence_InitType(&type0, &desc0);
    namedtuple_type_initialized0 = true;
  }

  static PyStructSequence_Field fields1[] = {
    {"U", ""}, {"S", ""}, {"V", ""}, {nullptr}
  };
  static PyStructSequence_Desc desc1 = {
    "torch.return_types.svd", nullptr, fields1, 3
  };
  static PyTypeObject type1;
  static bool namedtuple_type_initialized1 = false;
  if (!namedtuple_type_initialized1) {
    PyStructSequence_InitType(&type1, &desc1);
    namedtuple_type_initialized1 = true;
  }

  if (r.idx == 0) {
    if (r.isNone(3)) {
      return wrap(&type1, dispatch_svd(r.tensor(0), r.toBool(1), r.toBool(2)));
    } else {
      auto results = r.tensorlist_n<3>(3);
      return wrap(&type0, dispatch_svd(r.tensor(0), r.toBool(1), r.toBool(2),
                                       results[0], results[1], results[2]));
    }
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
```

The namedtuple types are defined as static locals of the `THPVariable_${op_name}` functions and are initialized the first time each function is called. When parsing function prototypes in `native_functions.yaml`, the parser sets the specified name as `field_name` when it sees declarations like `-> (Tensor t1, ...)`; these names become the field names of the namedtuple. The namedtuple classes are named `torch.return_types.${op_name}`.

On Python 2, `PyStructSequence` is not a subtype of tuple, so for compatibility we add helper functions that check whether an object is a tuple or a namedtuple.

Operators in `native_functions.yaml` are changed so that, for now, only `max` and `svd` return namedtuples. Tests are added for these two operators to verify that the return values work as expected, and the docs for both ops are updated to explicitly mention that the return value is a namedtuple. More ops will be converted in later PRs.

The Windows build hit a linker error resolving `PyStructSequence_UnnamedField`; a workaround is added for this case.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/15429
Differential Revision: D13709678
Pulled By: ezyang
fbshipit-source-id: 23a511c9436977098afc49374e9a748b6e30bccf
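As a rough sketch of the tuple/namedtuple compatibility check described above, the helpers below mirror the spirit of the `six::isTuple` helper used in the file that follows. The name `isStructSeq` and the type-name prefix test are assumptions for illustration (based on the `torch.return_types.${op_name}` naming), not the actual contents of `torch/csrc/utils/six.h`:

```C++
#include <Python.h>
#include <cstring>

// Hypothetical check for the generated namedtuple types. Assumes struct
// sequence types are named "torch.return_types.<op>", as in the codegen
// output above.
inline bool isStructSeq(PyObject* obj) {
  return std::strncmp(Py_TYPE(obj)->tp_name, "torch.return_types.", 19) == 0;
}

// Accepts plain tuples everywhere; on Python 2, where PyStructSequence is
// not a subtype of tuple, also accepts the generated struct sequences.
inline bool isTuple(PyObject* obj) {
  if (PyTuple_Check(obj)) {
    return true;
  }
#if PY_MAJOR_VERSION == 2
  return isStructSeq(obj);
#else
  return false;
#endif
}
```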
120 lines
3.5 KiB
C++
#include <torch/csrc/jit/python_arg_flatten.h>

#include <torch/csrc/utils/six.h>

#include <torch/csrc/autograd/grad_mode.h>

namespace torch {
namespace jit {
namespace python {

using namespace torch::autograd;
using namespace at;
// Alphabet used to describe structure of inputs/outputs (D for desc)
namespace D {
static constexpr char ListOpen = '[';
static constexpr char ListClose = ']';
static constexpr char TupleOpen = '(';
static constexpr char TupleClose = ')';
static constexpr char Variable = 'v';
} // namespace D
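// Illustrative example (an editorial sketch, not part of the original
// source): flattening a Python input ((a, b), [c]) holding three Variables
// yields the structure string "((vv)[v])". flatten_rec below emits these
// tokens while collecting {a, b, c} into args.vars, and unflatten_rec
// replays them to rebuild the nesting.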
namespace {

// Casts a vector of borrowed handles into a sequence of type T
// (py::tuple or py::list).
template <typename T>
py::object cast_handle_sequence(std::vector<py::handle> objs) {
  auto num_objs = objs.size();
  T sequence{num_objs};
  for (size_t i = 0; i < num_objs; ++i)
    sequence[i] = py::reinterpret_borrow<py::object>(objs[i]);
  return sequence;
}

// Recursively walks obj, appending one structure token per element to
// args.desc and collecting the Variables (with their metadata) in order.
void flatten_rec(PyObject* obj, ParsedArgs& args) {
  auto& structure = args.desc.structure;
  if (six::isTuple(obj)) {
    structure.push_back(D::TupleOpen);
    for (auto item : py::reinterpret_borrow<py::tuple>(obj))
      flatten_rec(item.ptr(), args);
    structure.push_back(D::TupleClose);
  } else if (PyList_Check(obj)) {
    structure.push_back(D::ListOpen);
    for (auto item : py::reinterpret_borrow<py::list>(obj))
      flatten_rec(item.ptr(), args);
    structure.push_back(D::ListClose);
  } else if (THPVariable_Check(obj)) {
    auto& var = reinterpret_cast<THPVariable*>(obj)->cdata;
    args.vars.push_back(var);
    args.desc.metadata.emplace_back(var);
    args.desc.structure.push_back(D::Variable);
  } else {
    std::string msg =
        "Only tuples, lists and Variables supported as JIT inputs, but got ";
    msg += THPUtils_typename(obj);
    throw std::runtime_error(msg);
  }
}

} // anonymous namespace

ParsedArgs flatten(py::handle obj) {
  ParsedArgs args;
  args.desc.grad_enabled = autograd::GradMode::is_enabled();
  flatten_rec(obj.ptr(), args);
  return args;
}
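// Hypothetical usage sketch (not in the original source):
//   ParsedArgs parsed = flatten(inputs);  // inputs: e.g. a tuple of Variables
//   // parsed.vars is the flat Variable list; parsed.desc encodes the
//   // nesting and records whether grad mode was enabled at trace time.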
namespace {

// Moves a vector of owned objects into a sequence of type T
// (py::tuple or py::list).
template <typename T>
py::object cast_sequence(std::vector<py::object> objs) {
  auto num_objs = objs.size();
  T sequence{num_objs};
  for (size_t i = 0; i < num_objs; ++i)
    sequence[i] = std::move(objs[i]);
  return sequence;
}

// Rebuilds the Python object described by the structure string, consuming
// one Variable from var_it per 'v' token.
py::object unflatten_rec(
    ArrayRef<Variable>::iterator& var_it,
    ArrayRef<Variable>::iterator& var_it_end,
    std::string::const_iterator& desc_it) {
  char type = *desc_it++;
  if (type == D::TupleOpen) {
    std::vector<py::object> objs;
    while (*desc_it != D::TupleClose)
      objs.push_back(unflatten_rec(var_it, var_it_end, desc_it));
    ++desc_it;
    return cast_sequence<py::tuple>(objs);
  } else if (type == D::ListOpen) {
    std::vector<py::object> objs;
    while (*desc_it != D::ListClose)
      objs.push_back(unflatten_rec(var_it, var_it_end, desc_it));
    ++desc_it;
    return cast_sequence<py::list>(objs);
  } else {
    if (var_it == var_it_end)
      throw std::runtime_error("Not enough Variables given to unflatten");
    auto var = *var_it++;
    return py::reinterpret_steal<py::object>(THPVariable_Wrap(var));
  }
}

} // anonymous namespace
PyObject* unflatten(ArrayRef<Variable> vars, const IODescriptor& desc) {
  // NB: We don't do correctness checking on the descriptor.
  // It has to be a correct descriptor produced by flatten.
  auto vars_it = vars.begin();
  auto vars_it_end = vars.end();
  auto desc_it = desc.structure.begin();
  auto output = unflatten_rec(vars_it, vars_it_end, desc_it);
  if (vars_it != vars_it_end)
    throw std::runtime_error("Too many Variables given to unflatten");
  return output.release().ptr();
}
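// Round-trip note (editorial, not in the original source): a descriptor is
// "correct" exactly when it came from flatten on an input with the same
// number of Variables; e.g. vars = {v1, v2, v3} with structure "(v[vv])"
// rebuilds a tuple (v1, [v2, v3]). Too few or too many Variables hit the
// runtime_error checks above.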
} // namespace python
} // namespace jit
} // namespace torch